Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
N
NIFTy
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Monitor
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
ift
NIFTy
Commits
cd349c61
Commit
cd349c61
authored
2 years ago
by
Gordian Edenhofer
Browse files
Options
Downloads
Patches
Plain Diff
Prune geomap demo
parent
26b503d0
Branches
Branches containing commit
Tags
Tags containing commit
1 merge request
!832
Better Lanczos interface
Pipeline
#152735
passed
2 years ago
Stage: build_docker
Stage: test
Stage: demo_runs
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
demos/re/geomap.py
+0
-313
0 additions, 313 deletions
demos/re/geomap.py
with
0 additions
and
313 deletions
demos/re/geomap.py
deleted
100644 → 0
+
0
−
313
View file @
26b503d0
#!/usr/bin/env python3
# Copyright(C) 2013-2021 Max-Planck-Society
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
from
functools
import
partial
import
sys
from
jax
import
numpy
as
jnp
from
jax
import
random
from
jax
import
jit
import
jax
from
jax
import
random
import
matplotlib.pyplot
as
plt
import
nifty8.re
as
jft
jax
.
config
.
update
(
"
jax_enable_x64
"
,
True
)
# %%
def
lanczos_logdet
(
mat
,
v
,
order
:
int
,
):
"""
Computes a stochastic estimate of the log-determinate of the Lanczos
decomposed matrix. This is not the same as applying the stochastic Lanczos
quadrature algorithm as it estimates the log-determinate for the
decomposition only.
"""
mat
=
mat
.
__matmul__
if
not
hasattr
(
mat
,
"
__call__
"
)
else
mat
tridiag
,
vecs
=
jft
.
lanczos
.
lanczos_tridiag
(
mat
,
v
,
order
=
order
)
eig_vals
=
jnp
.
linalg
.
eigvalsh
(
tridiag
)
return
jnp
.
log
(
eig_vals
).
sum
(),
vecs
def
_metric_sample
(
hamiltonian
:
jft
.
StandardHamiltonian
,
primals
,
key
,
):
if
not
isinstance
(
hamiltonian
,
jft
.
StandardHamiltonian
):
te
=
f
"
`hamiltonian` of invalid type; got
'
{
type
(
hamiltonian
)
}
'"
raise
TypeError
(
te
)
subkey_nll
,
subkey_prr
=
random
.
split
(
key
,
2
)
nll_smpl
=
jft
.
kl
.
sample_likelihood
(
hamiltonian
.
likelihood
,
primals
,
key
=
subkey_nll
)
prr_inv_metric_smpl
=
jft
.
random_like
(
key
=
subkey_prr
,
primals
=
primals
)
# One may transform any metric sample to a sample of the inverse
# metric by simply applying the inverse metric to it
prr_smpl
=
prr_inv_metric_smpl
met_smpl
=
nll_smpl
+
prr_smpl
return
met_smpl
,
prr_smpl
def
geomap
(
hamiltonian
:
jft
.
StandardHamiltonian
,
order
:
int
,
key
,
sample_orthonormally
=
True
):
from
jax
import
flatten_util
def
geomap_energy
(
pos
,
return_aux
=
False
):
p
,
unflatten
=
flatten_util
.
ravel_pytree
(
pos
)
def
mat
(
x
):
# Hack to stomp arbitrary objects into a 1D array
o
,
_
=
flatten_util
.
ravel_pytree
(
hamiltonian
.
metric
(
pos
,
unflatten
(
x
))
)
return
o
probe
,
smpl
=
_metric_sample
(
hamiltonian
,
pos
,
key
)
probe
=
flatten_util
.
ravel_pytree
(
probe
)[
0
]
smpl
=
flatten_util
.
ravel_pytree
(
smpl
)[
0
]
logdet
,
vecs
=
lanczos_logdet
(
mat
,
probe
,
order
)
if
not
sample_orthonormally
:
energy
=
hamiltonian
(
pos
)
smpl_orig
,
smpl
=
None
,
None
else
:
#smpl = random.normal(smpl_key, p.shape)
smpl_orig
=
unflatten
(
smpl
.
copy
())
# TODO: Pull into new lanczos method which computes orthoganlized smpls
# for vecs
ortho_smpl
=
vecs
@
smpl
# One could add an additional `jnp.linalg.inv(vecs @ vecs.T)` in
# between the vecs to ensure proper projection
# ortho_smpl = jnp.linalg.inv(vecs @ vecs.T) @ ortho_smpl
ortho_smpl
=
vecs
.
T
@
ortho_smpl
smpl
-=
ortho_smpl
smpl
=
unflatten
(
smpl
)
# GeoMAP requires the sample to be mirrored as to perform MAP along
# the subspace in the (near) linear regime. With samples, the
# solution is not only much less noisy in this regime but is
# actually the true posterior.
energy
=
0.5
*
(
hamiltonian
(
pos
+
smpl
)
+
hamiltonian
(
pos
-
smpl
))
energy
+=
0.5
*
logdet
if
return_aux
:
return
energy
,
(
smpl_orig
,
smpl
)
return
energy
return
geomap_energy
# %%
def
hartley
(
p
,
axes
=
None
):
from
jax.numpy
import
fft
tmp
=
fft
.
fftn
(
p
,
axes
)
return
tmp
.
real
+
tmp
.
imag
seed
=
42
key
=
random
.
PRNGKey
(
seed
)
dims
=
(
1024
,
)
absdelta
=
1e-4
*
jnp
.
prod
(
jnp
.
array
(
dims
))
cf
=
{
"
loglogavgslope
"
:
2.
}
loglogslope
=
cf
[
"
loglogavgslope
"
]
power_spectrum
=
lambda
k
:
1.
/
(
k
**
loglogslope
+
1.
)
modes
=
jnp
.
arange
((
dims
[
0
]
/
2
)
+
1.
,
dtype
=
float
)
harmonic_power
=
power_spectrum
(
modes
)
# Every mode appears exactly two times, first ascending then descending
# Save a little on the computational side by mirroring the ascending part
harmonic_power
=
jnp
.
concatenate
((
harmonic_power
,
harmonic_power
[
-
2
:
0
:
-
1
]))
# Specify the model
correlated_field
=
jft
.
Model
(
lambda
x
:
hartley
(
harmonic_power
*
x
),
domain
=
jft
.
ShapeWithDtype
(
dims
)
)
signal_response
=
lambda
x
:
correlated_field
(
x
)
noise_cov
=
lambda
x
:
0.1
**
2
*
x
noise_cov_inv
=
lambda
x
:
0.1
**-
2
*
x
# Create synthetic data
key
,
subkey
=
random
.
split
(
key
)
pos_truth
=
jft
.
random_like
(
subkey
,
correlated_field
.
domain
)
signal_response_truth
=
signal_response
(
pos_truth
)
key
,
subkey
=
random
.
split
(
key
)
noise_truth
=
jnp
.
sqrt
(
noise_cov
(
jnp
.
ones
(
dims
))
)
*
random
.
normal
(
shape
=
dims
,
key
=
key
)
data
=
signal_response_truth
+
noise_truth
nll
=
jft
.
Gaussian
(
data
,
noise_cov_inv
)
@
signal_response
ham
=
jft
.
StandardHamiltonian
(
likelihood
=
nll
).
jit
()
plt
.
plot
(
jnp
.
array
([
signal_response_truth
,
data
]).
T
,
label
=
(
"
truth
"
,
"
data
"
))
plt
.
legend
()
plt
.
show
()
# %%
key
,
subkey
,
subkey_geomap
=
random
.
split
(
key
,
3
)
pos_init
=
jft
.
random_like
(
subkey
,
correlated_field
.
domain
)
pos
=
1e-2
*
pos_init
.
copy
()
# %%
print
(
"
!!! HAM
"
,
ham
(
pos
))
print
(
"
!!! metric
"
,
ham
.
metric
(
pos
,
pos
)
@
pos
)
# This is 50 times slower in compile time than ham.metric
geomap_order
=
40
geomap_energy
=
geomap
(
ham
,
geomap_order
,
subkey_geomap
,
sample_orthonormally
=
True
)
geomap_energy
=
jax
.
jit
(
geomap_energy
,
static_argnames
=
(
"
return_aux
"
,
))
print
(
"
!!! geomap_energy
"
,
geomap_energy
(
pos
))
# %%
pos
=
1e-2
*
pos_init
.
copy
()
opt_state_geomap
=
jft
.
minimize
(
geomap_energy
,
pos
,
method
=
"
newton-cg
"
,
options
=
{
"
name
"
:
"
N
"
,
"
maxiter
"
:
30
,
"
cg_kwargs
"
:
{
"
name
"
:
None
},
}
)
# %%
_
,
(
prr_smpl
,
ortho_smpl
)
=
geomap_energy
(
opt_state_geomap
.
x
,
return_aux
=
True
)
plt
.
plot
(
prr_smpl
,
label
=
"
prior sample
"
,
alpha
=
0.7
)
plt
.
plot
(
ortho_smpl
,
label
=
"
ortho sample
"
,
alpha
=
0.7
)
plt
.
plot
(
jnp
.
abs
(
prr_smpl
-
ortho_smpl
),
label
=
"
abs diff
"
,
alpha
=
0.3
)
plt
.
legend
()
plt
.
show
()
# %%
smpls_by_order
=
[]
for
i
in
range
(
1
,
geomap_order
):
_
,
(
_
,
s
)
=
geomap
(
ham
,
i
,
subkey_geomap
,
sample_orthonormally
=
True
)(
opt_state_geomap
.
x
,
return_aux
=
True
)
smpls_by_order
+=
[
s
]
smpls_by_order
=
jnp
.
array
(
smpls_by_order
)
# %%
fig
,
axs
=
plt
.
subplots
(
2
,
1
,
sharex
=
True
)
d
=
jnp
.
diff
(
smpls_by_order
,
axis
=
0
)
axs
.
flat
[
0
].
plot
(
smpls_by_order
.
T
,
label
=
jnp
.
arange
(
1
,
geomap_order
),
alpha
=
0.3
,
marker
=
"
.
"
)
axs
.
flat
[
0
].
axhline
(
0.
,
color
=
"
red
"
)
axs
.
flat
[
0
].
legend
()
axs
.
flat
[
1
].
plot
(
d
.
T
,
label
=
jnp
.
arange
(
1
,
geomap_order
-
1
),
alpha
=
0.3
,
marker
=
"
.
"
)
axs
.
flat
[
1
].
axhline
(
0.
,
color
=
"
red
"
)
axs
.
flat
[
1
].
legend
()
plt
.
show
()
# %%
plt
.
plot
(
jnp
.
array
(
[
signal_response_truth
,
data
,
signal_response
(
opt_state_geomap
.
x
),
signal_response
(
opt_state_geomap
.
x
+
ortho_smpl
),
]
).
T
,
label
=
(
"
truth
"
,
"
data
"
,
"
rec
"
,
"
rec + smpl
"
)
)
plt
.
legend
()
plt
.
show
()
# %%
n_samples
=
1
n_newton_iterations
=
10
n_mgvi_iterations
=
6
ham_vg
=
jit
(
jft
.
mean_value_and_grad
(
ham
))
ham_metric
=
jit
(
jft
.
mean_metric
(
ham
.
metric
))
MetricKL
=
jit
(
partial
(
jft
.
MetricKL
,
ham
),
static_argnames
=
(
"
n_samples
"
,
"
mirror_samples
"
,
"
linear_sampling_name
"
)
)
# %%
pos
=
1e-2
*
pos_init
.
copy
()
# Minimize the potential
for
i
in
range
(
n_mgvi_iterations
):
print
(
f
"
MGVI Iteration
{
i
}
"
,
file
=
sys
.
stderr
)
print
(
"
Sampling...
"
,
file
=
sys
.
stderr
)
key
,
subkey
=
random
.
split
(
key
,
2
)
samples
=
MetricKL
(
pos
,
n_samples
=
n_samples
,
key
=
subkey
,
mirror_samples
=
False
,
linear_sampling_kwargs
=
{
"
absdelta
"
:
absdelta
/
10.
,
"
maxiter
"
:
geomap_order
},
# linear_sampling_name="S",
)
print
(
"
Minimizing...
"
,
file
=
sys
.
stderr
)
opt_state_mgvi
=
jft
.
minimize
(
None
,
pos
,
method
=
"
newton-cg
"
,
options
=
{
"
fun_and_grad
"
:
partial
(
ham_vg
,
primals_samples
=
samples
),
"
hessp
"
:
partial
(
ham_metric
,
primals_samples
=
samples
),
"
absdelta
"
:
absdelta
,
"
maxiter
"
:
n_newton_iterations
}
)
pos
=
opt_state_mgvi
.
x
msg
=
f
"
Post MGVI Iteration
{
i
}
: Energy
{
samples
.
at
(
pos
).
mean
(
ham
)
:
2.4
e
}
"
print
(
msg
,
file
=
sys
.
stderr
)
# %%
plt
.
plot
(
jnp
.
array
(
[
signal_response_truth
,
data
,
signal_response
(
opt_state_geomap
.
x
),
signal_response
(
opt_state_mgvi
.
x
),
*
samples
.
at
(
opt_state_mgvi
.
x
).
apply
(
signal_response
),
]
).
T
,
label
=
(
"
truth
"
,
"
data
"
,
"
rec geomap
"
,
"
rec mgvi
"
,
)
+
(
"
smpls
"
,
)
*
len
(
samples
)
)
plt
.
legend
()
plt
.
show
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment