Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
N
NIFTy
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Model registry
Monitor
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
ift
NIFTy
Commits
db221234
Commit
db221234
authored
1 month ago
by
Jakob Roth
Browse files
Options
Downloads
Patches
Plain Diff
optimize_kl: map_over_devices -> devices
parent
e5d865bd
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!993
Multi gpu
Pipeline
#241421
failed
1 month ago
Stage: build_docker
Stage: test
Stage: demo_runs
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/re/optimize_kl.py
+15
-15
15 additions, 15 deletions
src/re/optimize_kl.py
with
15 additions
and
15 deletions
src/re/optimize_kl.py
+
15
−
15
View file @
db221234
...
@@ -269,7 +269,7 @@ class OptimizeVI:
...
@@ -269,7 +269,7 @@ class OptimizeVI:
residual_map
=
"
lmap
"
,
residual_map
=
"
lmap
"
,
kl_reduce
=
_reduce
,
kl_reduce
=
_reduce
,
mirror_samples
=
True
,
mirror_samples
=
True
,
map_over_
devices
:
Optional
[
list
]
=
None
,
devices
:
Optional
[
list
]
=
None
,
kl_device_map
=
"
shard_map
"
,
kl_device_map
=
"
shard_map
"
,
residual_device_map
=
"
pmap
"
,
residual_device_map
=
"
pmap
"
,
_kl_value_and_grad
:
Optional
[
Callable
]
=
None
,
_kl_value_and_grad
:
Optional
[
Callable
]
=
None
,
...
@@ -300,7 +300,7 @@ class OptimizeVI:
...
@@ -300,7 +300,7 @@ class OptimizeVI:
Reduce function used for the KL minimization.
Reduce function used for the KL minimization.
mirror_samples: bool
mirror_samples: bool
Whether to mirror the samples or not.
Whether to mirror the samples or not.
map_over_
devices : list of devices or None
devices : list of devices or None
Devices over which the samples are mapped. If `None` only the
Devices over which the samples are mapped. If `None` only the
default device is used. To use all detected devices pass
default device is used. To use all detected devices pass
jax.devices(). Generally the samples needs to be evenly
jax.devices(). Generally the samples needs to be evenly
...
@@ -308,17 +308,17 @@ class OptimizeVI:
...
@@ -308,17 +308,17 @@ class OptimizeVI:
the `kl_device_map` and `residual_device_map` arguments.
the `kl_device_map` and `residual_device_map` arguments.
kl_device_map : str
kl_device_map : str
Map function used for mapping KL minimization over the devices
Map function used for mapping KL minimization over the devices
listed in `
map_over_
devices`. `kl_device_map` can be
'
shard_map
'
,
listed in `devices`. `kl_device_map` can be
'
shard_map
'
,
'
pmap
'
, or
'
pmap
'
, or
'
jit
'
. If set to
'
pmap
'
, 2*n_samples need to be equal to
'
jit
'
. If set to
'
pmap
'
, 2*n_samples need to be equal to
the number
the number
of devices. For all other maps the samples needs to be
of devices. For all other maps the samples needs to be
equally
equally
distributable over the devices.
distributable over the devices.
residual_device_map : str
residual_device_map : str
Map function used for mapping sampling over the devices listed in
Map function used for mapping sampling over the devices listed in
`
map_over_
devices`. `residual_device_map` can be
'
shard_map
'
,
`devices`. `residual_device_map` can be
'
shard_map
'
,
'
pmap
'
, or
'
pmap
'
, or
'
jit
'
. If set to
'
pmap
'
, 2*n_samples need to be equal to
'
jit
'
. If set to
'
pmap
'
, 2*n_samples need to be equal to
the number
the number
of devices. If only linear samples are drawn ,
'
pmap
'
also
of devices. If only linear samples are drawn ,
'
pmap
'
also
works if
works if
n_samples equals the number of devices. For the other maps
n_samples equals the number of devices. For the other maps
it is
it is
sufficient if 2*n_samples equals the number of devices, or
sufficient if 2*n_samples equals the number of devices, or
n_samples can be evenly divided by the number of devices.
n_samples can be evenly divided by the number of devices.
Notes
Notes
...
@@ -340,8 +340,8 @@ class OptimizeVI:
...
@@ -340,8 +340,8 @@ class OptimizeVI:
residual_map
=
get_map
(
residual_map
)
residual_map
=
get_map
(
residual_map
)
self
.
mesh
=
None
self
.
mesh
=
None
self
.
pspec
=
None
self
.
pspec
=
None
if
(
not
map_over_
devices
is
None
)
and
len
(
map_over_
devices
)
>
1
:
if
(
not
devices
is
None
)
and
len
(
devices
)
>
1
:
self
.
mesh
=
Mesh
(
map_over_
devices
,
(
"
x
"
,))
self
.
mesh
=
Mesh
(
devices
,
(
"
x
"
,))
self
.
pspec
=
Pspec
(
"
x
"
)
self
.
pspec
=
Pspec
(
"
x
"
)
self
.
residual_device_map
=
residual_device_map
self
.
residual_device_map
=
residual_device_map
...
@@ -795,7 +795,7 @@ def optimize_kl(
...
@@ -795,7 +795,7 @@ def optimize_kl(
resume
:
Union
[
str
,
bool
]
=
False
,
resume
:
Union
[
str
,
bool
]
=
False
,
callback
:
Optional
[
Callable
[[
Samples
,
OptimizeVIState
],
None
]]
=
None
,
callback
:
Optional
[
Callable
[[
Samples
,
OptimizeVIState
],
None
]]
=
None
,
odir
:
Optional
[
str
]
=
None
,
odir
:
Optional
[
str
]
=
None
,
map_over_
devices
:
Optional
[
list
]
=
None
,
devices
:
Optional
[
list
]
=
None
,
kl_device_map
=
"
shard_map
"
,
kl_device_map
=
"
shard_map
"
,
residual_device_map
=
"
pmap
"
,
residual_device_map
=
"
pmap
"
,
_optimize_vi
=
None
,
_optimize_vi
=
None
,
...
@@ -835,7 +835,7 @@ def optimize_kl(
...
@@ -835,7 +835,7 @@ def optimize_kl(
residual_map
=
residual_map
,
residual_map
=
residual_map
,
kl_reduce
=
kl_reduce
,
kl_reduce
=
kl_reduce
,
mirror_samples
=
mirror_samples
,
mirror_samples
=
mirror_samples
,
map_over_devices
=
map_over_
devices
,
devices
=
devices
,
kl_device_map
=
kl_device_map
,
kl_device_map
=
kl_device_map
,
residual_device_map
=
residual_device_map
,
residual_device_map
=
residual_device_map
,
)
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment