Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
3d526bfa
Commit
3d526bfa
authored
Aug 26, 2018
by
Martin Reinecke
Browse files
Merge branch 'NIFTy_5' into adjust_variances_but_right
parents
926503ff
b9f09974
Changes
20
Hide whitespace changes
Inline
Side-by-side
demos/bernoulli_demo.py
View file @
3d526bfa
...
...
@@ -73,7 +73,7 @@ if __name__ == '__main__':
# Minimize the Hamiltonian
H
=
ift
.
Hamiltonian
(
likelihood
,
ic_sampling
)
H
=
ift
.
EnergyAdapter
(
position
,
H
)
H
=
ift
.
EnergyAdapter
(
position
,
H
,
want_metric
=
True
)
# minimizer = ift.L_BFGS(ic_newton)
H
,
convergence
=
minimizer
(
H
)
...
...
demos/getting_started_2.py
View file @
3d526bfa
...
...
@@ -93,7 +93,7 @@ if __name__ == '__main__':
# Minimize the Hamiltonian
H
=
ift
.
Hamiltonian
(
likelihood
)
H
=
ift
.
EnergyAdapter
(
position
,
H
)
H
=
ift
.
EnergyAdapter
(
position
,
H
,
want_metric
=
True
)
H
,
convergence
=
minimizer
(
H
)
# Plot results
...
...
demos/getting_started_3.py
View file @
3d526bfa
...
...
@@ -91,30 +91,24 @@ if __name__ == '__main__':
# number of samples used to estimate the KL
N_samples
=
20
for
i
in
range
(
2
):
metric
=
H
(
ift
.
Linearization
.
make_var
(
position
)).
metric
samples
=
[
metric
.
draw_sample
(
from_inverse
=
True
)
for
_
in
range
(
N_samples
)
]
KL
=
ift
.
SampledKullbachLeiblerDivergence
(
H
,
samples
)
KL
=
ift
.
EnergyAdapter
(
position
,
KL
)
KL
=
ift
.
KL_Energy
(
position
,
H
,
N_samples
,
want_metric
=
True
)
KL
,
convergence
=
minimizer
(
KL
)
position
=
KL
.
position
plot
=
ift
.
Plot
()
plot
.
add
(
signal
(
position
),
title
=
"reconstruction"
)
plot
.
add
([
A
(
position
),
A
(
MOCK_POSITION
)],
title
=
"power"
)
plot
.
add
(
signal
(
KL
.
position
),
title
=
"reconstruction"
)
plot
.
add
([
A
(
KL
.
position
),
A
(
MOCK_POSITION
)],
title
=
"power"
)
plot
.
output
(
ny
=
1
,
ysize
=
6
,
xsize
=
16
,
name
=
"loop.png"
)
plot
=
ift
.
Plot
()
sc
=
ift
.
StatCalculator
()
for
sample
in
samples
:
sc
.
add
(
signal
(
sample
+
position
))
for
sample
in
KL
.
samples
:
sc
.
add
(
signal
(
sample
+
KL
.
position
))
plot
.
add
(
sc
.
mean
,
title
=
"Posterior Mean"
)
plot
.
add
(
ift
.
sqrt
(
sc
.
var
),
title
=
"Posterior Standard Deviation"
)
powers
=
[
A
(
s
+
position
)
for
s
in
samples
]
powers
=
[
A
(
s
+
KL
.
position
)
for
s
in
KL
.
samples
]
plot
.
add
(
[
A
(
position
),
A
(
MOCK_POSITION
)]
+
powers
,
[
A
(
KL
.
position
),
A
(
MOCK_POSITION
)]
+
powers
,
title
=
"Sampled Posterior Power Spectrum"
)
plot
.
output
(
ny
=
1
,
nx
=
3
,
xsize
=
24
,
ysize
=
6
,
name
=
"results.png"
)
demos/polynomial_fit.py
View file @
3d526bfa
...
...
@@ -86,15 +86,16 @@ N = ift.DiagonalOperator(ift.from_global_data(d_space, var))
IC
=
ift
.
GradientNormController
(
tol_abs_gradnorm
=
1e-8
)
likelihood
=
ift
.
GaussianEnergy
(
d
,
N
)(
R
)
H
=
ift
.
Hamiltonian
(
likelihood
,
IC
)
H
=
ift
.
EnergyAdapter
(
params
,
H
,
IC
)
H
am
=
ift
.
Hamiltonian
(
likelihood
,
IC
)
H
=
ift
.
EnergyAdapter
(
params
,
H
am
,
want_metric
=
True
)
# Minimize
minimizer
=
ift
.
NewtonCG
(
IC
)
H
,
_
=
minimizer
(
H
)
# Draw posterior samples
samples
=
[
H
.
metric
.
draw_sample
(
from_inverse
=
True
)
+
H
.
position
metric
=
Ham
(
ift
.
Linearization
.
make_var
(
H
.
position
,
want_metric
=
True
)).
metric
samples
=
[
metric
.
draw_sample
(
from_inverse
=
True
)
+
H
.
position
for
_
in
range
(
N_samples
)]
# Plotting
...
...
nifty5/__init__.py
View file @
3d526bfa
...
...
@@ -67,6 +67,7 @@ from .minimization.energy import Energy
from
.minimization.quadratic_energy
import
QuadraticEnergy
from
.minimization.line_energy
import
LineEnergy
from
.minimization.energy_adapter
import
EnergyAdapter
from
.minimization.kl_energy
import
KL_Energy
from
.sugar
import
*
from
.plotting.plot
import
Plot
...
...
nifty5/extra/energy_and_model_tests.py
View file @
3d526bfa
...
...
@@ -41,7 +41,7 @@ def _get_acceptable_location(op, loc, lin):
for
i
in
range
(
50
):
try
:
loc2
=
loc
+
dir
lin2
=
op
(
Linearization
.
make_var
(
loc2
))
lin2
=
op
(
Linearization
.
make_var
(
loc2
,
lin
.
want_metric
))
if
np
.
isfinite
(
lin2
.
val
.
sum
())
and
abs
(
lin2
.
val
.
sum
())
<
1e20
:
break
except
FloatingPointError
:
...
...
@@ -54,14 +54,14 @@ def _get_acceptable_location(op, loc, lin):
def
_check_consistency
(
op
,
loc
,
tol
,
ntries
,
do_metric
):
for
_
in
range
(
ntries
):
lin
=
op
(
Linearization
.
make_var
(
loc
))
lin
=
op
(
Linearization
.
make_var
(
loc
,
do_metric
))
loc2
,
lin2
=
_get_acceptable_location
(
op
,
loc
,
lin
)
dir
=
loc2
-
loc
locnext
=
loc2
dirnorm
=
dir
.
norm
()
for
i
in
range
(
50
):
locmid
=
loc
+
0.5
*
dir
linmid
=
op
(
Linearization
.
make_var
(
locmid
))
linmid
=
op
(
Linearization
.
make_var
(
locmid
,
do_metric
))
dirder
=
linmid
.
jac
(
dir
)
numgrad
=
(
lin2
.
val
-
lin
.
val
)
xtol
=
tol
*
dirder
.
norm
()
/
np
.
sqrt
(
dirder
.
size
)
...
...
nifty5/library/inverse_gamma_model.py
View file @
3d526bfa
...
...
@@ -53,7 +53,7 @@ class InverseGammaModel(Operator):
outer
=
1
/
outer_inv
jac
=
makeOp
(
Field
.
from_local_data
(
self
.
_domain
,
inner
*
outer
))
jac
=
jac
(
x
.
jac
)
return
Linearization
(
points
,
jac
)
return
x
.
new
(
points
,
jac
)
@
staticmethod
def
IG
(
field
,
alpha
,
q
):
...
...
nifty5/linearization.py
View file @
3d526bfa
...
...
@@ -9,13 +9,17 @@ from .sugar import makeOp
class
Linearization
(
object
):
def
__init__
(
self
,
val
,
jac
,
metric
=
None
):
def
__init__
(
self
,
val
,
jac
,
metric
=
None
,
want_metric
=
False
):
self
.
_val
=
val
self
.
_jac
=
jac
if
self
.
_val
.
domain
!=
self
.
_jac
.
target
:
raise
ValueError
(
"domain mismatch"
)
self
.
_want_metric
=
want_metric
self
.
_metric
=
metric
def
new
(
self
,
val
,
jac
,
metric
=
None
):
return
Linearization
(
val
,
jac
,
metric
,
self
.
_want_metric
)
@
property
def
domain
(
self
):
return
self
.
_jac
.
domain
...
...
@@ -37,6 +41,10 @@ class Linearization(object):
"""Only available if target is a scalar"""
return
self
.
_jac
.
adjoint_times
(
Field
.
scalar
(
1.
))
@
property
def
want_metric
(
self
):
return
self
.
_want_metric
@
property
def
metric
(
self
):
"""Only available if target is a scalar"""
...
...
@@ -44,35 +52,34 @@ class Linearization(object):
def
__getitem__
(
self
,
name
):
from
.operators.simple_linear_operators
import
FieldAdapter
return
Linearization
(
self
.
_val
[
name
],
FieldAdapter
(
self
.
domain
,
name
))
return
self
.
new
(
self
.
_val
[
name
],
FieldAdapter
(
self
.
domain
,
name
))
def
__neg__
(
self
):
return
Linearization
(
-
self
.
_val
,
-
self
.
_jac
,
None
if
self
.
_metric
is
None
else
-
self
.
_metric
)
return
self
.
new
(
-
self
.
_val
,
-
self
.
_jac
,
None
if
self
.
_metric
is
None
else
-
self
.
_metric
)
def
conjugate
(
self
):
return
Linearization
(
return
self
.
new
(
self
.
_val
.
conjugate
(),
self
.
_jac
.
conjugate
(),
None
if
self
.
_metric
is
None
else
self
.
_metric
.
conjugate
())
@
property
def
real
(
self
):
return
Linearization
(
self
.
_val
.
real
,
self
.
_jac
.
real
)
return
self
.
new
(
self
.
_val
.
real
,
self
.
_jac
.
real
)
def
_myadd
(
self
,
other
,
neg
):
if
isinstance
(
other
,
Linearization
):
met
=
None
if
self
.
_metric
is
not
None
and
other
.
_metric
is
not
None
:
met
=
self
.
_metric
.
_myadd
(
other
.
_metric
,
neg
)
return
Linearization
(
return
self
.
new
(
self
.
_val
.
flexible_addsub
(
other
.
_val
,
neg
),
self
.
_jac
.
_myadd
(
other
.
_jac
,
neg
),
met
)
if
isinstance
(
other
,
(
int
,
float
,
complex
,
Field
,
MultiField
)):
if
neg
:
return
Linearization
(
self
.
_val
-
other
,
self
.
_jac
,
self
.
_metric
)
return
self
.
new
(
self
.
_val
-
other
,
self
.
_jac
,
self
.
_metric
)
else
:
return
Linearization
(
self
.
_val
+
other
,
self
.
_jac
,
self
.
_metric
)
return
self
.
new
(
self
.
_val
+
other
,
self
.
_jac
,
self
.
_metric
)
def
__add__
(
self
,
other
):
return
self
.
_myadd
(
other
,
False
)
...
...
@@ -98,7 +105,7 @@ class Linearization(object):
if
isinstance
(
other
,
Linearization
):
if
self
.
target
!=
other
.
target
:
raise
ValueError
(
"domain mismatch"
)
return
Linearization
(
return
self
.
new
(
self
.
_val
*
other
.
_val
,
(
makeOp
(
other
.
_val
)(
self
.
_jac
)).
_myadd
(
makeOp
(
self
.
_val
)(
other
.
_jac
),
False
))
...
...
@@ -106,11 +113,11 @@ class Linearization(object):
if
other
==
1
:
return
self
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
.
scale
(
other
)
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
return
self
.
new
(
self
.
_val
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
if
self
.
target
!=
other
.
domain
:
raise
ValueError
(
"domain mismatch"
)
return
Linearization
(
self
.
_val
*
other
,
makeOp
(
other
)(
self
.
_jac
))
return
self
.
new
(
self
.
_val
*
other
,
makeOp
(
other
)(
self
.
_jac
))
def
__rmul__
(
self
,
other
):
return
self
.
__mul__
(
other
)
...
...
@@ -118,46 +125,48 @@ class Linearization(object):
def
vdot
(
self
,
other
):
from
.operators.simple_linear_operators
import
VdotOperator
if
isinstance
(
other
,
(
Field
,
MultiField
)):
return
Linearization
(
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
vdot
(
other
)),
VdotOperator
(
other
)(
self
.
_jac
))
return
Linearization
(
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
vdot
(
other
.
_val
)),
VdotOperator
(
self
.
_val
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
def
sum
(
self
):
from
.operators.simple_linear_operators
import
SumReductionOperator
return
Linearization
(
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
sum
()),
SumReductionOperator
(
self
.
_jac
.
target
)(
self
.
_jac
))
def
exp
(
self
):
tmp
=
self
.
_val
.
exp
()
return
Linearization
(
tmp
,
makeOp
(
tmp
)(
self
.
_jac
))
return
self
.
new
(
tmp
,
makeOp
(
tmp
)(
self
.
_jac
))
def
log
(
self
):
tmp
=
self
.
_val
.
log
()
return
Linearization
(
tmp
,
makeOp
(
1.
/
self
.
_val
)(
self
.
_jac
))
return
self
.
new
(
tmp
,
makeOp
(
1.
/
self
.
_val
)(
self
.
_jac
))
def
tanh
(
self
):
tmp
=
self
.
_val
.
tanh
()
return
Linearization
(
tmp
,
makeOp
(
1.
-
tmp
**
2
)(
self
.
_jac
))
return
self
.
new
(
tmp
,
makeOp
(
1.
-
tmp
**
2
)(
self
.
_jac
))
def
positive_tanh
(
self
):
tmp
=
self
.
_val
.
tanh
()
tmp2
=
0.5
*
(
1.
+
tmp
)
return
Linearization
(
tmp2
,
makeOp
(
0.5
*
(
1.
-
tmp
**
2
))(
self
.
_jac
))
return
self
.
new
(
tmp2
,
makeOp
(
0.5
*
(
1.
-
tmp
**
2
))(
self
.
_jac
))
def
add_metric
(
self
,
metric
):
return
Linearization
(
self
.
_val
,
self
.
_jac
,
metric
)
return
self
.
new
(
self
.
_val
,
self
.
_jac
,
metric
)
@
staticmethod
def
make_var
(
field
):
def
make_var
(
field
,
want_metric
=
False
):
from
.operators.scaling_operator
import
ScalingOperator
return
Linearization
(
field
,
ScalingOperator
(
1.
,
field
.
domain
))
return
Linearization
(
field
,
ScalingOperator
(
1.
,
field
.
domain
),
want_metric
=
want_metric
)
@
staticmethod
def
make_const
(
field
):
def
make_const
(
field
,
want_metric
=
False
):
from
.operators.simple_linear_operators
import
NullOperator
return
Linearization
(
field
,
NullOperator
(
field
.
domain
,
field
.
domain
))
return
Linearization
(
field
,
NullOperator
(
field
.
domain
,
field
.
domain
),
want_metric
=
want_metric
)
nifty5/minimization/conjugate_gradient.py
View file @
3d526bfa
...
...
@@ -75,7 +75,7 @@ class ConjugateGradient(Minimizer):
return
energy
,
controller
.
CONVERGED
while
True
:
q
=
energy
.
metric
(
d
)
q
=
energy
.
apply_
metric
(
d
)
ddotq
=
d
.
vdot
(
q
).
real
if
ddotq
==
0.
:
logger
.
error
(
"Error: ConjugateGradient: ddotq==0."
)
...
...
nifty5/minimization/descent_minimizers.py
View file @
3d526bfa
...
...
@@ -180,7 +180,7 @@ class NewtonCG(DescentMinimizer):
while
True
:
if
abs
(
ri
).
sum
()
<=
termcond
:
return
xsupi
Ap
=
energy
.
metric
(
psupi
)
Ap
=
energy
.
apply_
metric
(
psupi
)
# check curvature
curv
=
psupi
.
vdot
(
Ap
)
if
0
<=
curv
<=
3
*
float64eps
:
...
...
nifty5/minimization/energy.py
View file @
3d526bfa
...
...
@@ -109,6 +109,20 @@ class Energy(NiftyMetaBase()):
"""
raise
NotImplementedError
def
apply_metric
(
self
,
x
):
"""
Parameters
----------
x: Field/MultiField
Argument for the metric operator
Returns
-------
Field/MultiField:
Output of the metric operator
"""
raise
NotImplementedError
def
longest_step
(
self
,
dir
):
"""Returns the longest allowed step size along `dir`
...
...
nifty5/minimization/energy_adapter.py
View file @
3d526bfa
...
...
@@ -8,58 +8,38 @@ from ..operators.scaling_operator import ScalingOperator
class
EnergyAdapter
(
Energy
):
def
__init__
(
self
,
position
,
op
,
controller
=
None
,
preconditioner
=
None
,
constants
=
[]):
def
__init__
(
self
,
position
,
op
,
constants
=
[],
want_metric
=
False
):
super
(
EnergyAdapter
,
self
).
__init__
(
position
)
self
.
_op
=
op
self
.
_val
=
self
.
_grad
=
self
.
_metric
=
None
self
.
_controller
=
controller
self
.
_preconditioner
=
preconditioner
self
.
_constants
=
constants
def
at
(
self
,
position
):
return
EnergyAdapter
(
position
,
self
.
_op
,
self
.
_controller
,
self
.
_preconditioner
,
self
.
_constants
)
def
_fill_all
(
self
):
self
.
_want_metric
=
want_metric
if
len
(
self
.
_constants
)
==
0
:
tmp
=
self
.
_op
(
Linearization
.
make_var
(
self
.
_position
))
tmp
=
self
.
_op
(
Linearization
.
make_var
(
self
.
_position
,
want_metric
))
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
self
.
_constants
else
1.
,
dom
)
for
key
,
dom
in
self
.
_position
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
self
.
_position
.
domain
,
tuple
(
ops
))
tmp
=
self
.
_op
(
Linearization
(
self
.
_position
,
bdop
))
tmp
=
self
.
_op
(
Linearization
(
self
.
_position
,
bdop
,
want_metric
=
want_metric
))
self
.
_val
=
tmp
.
val
.
local_data
[()]
self
.
_grad
=
tmp
.
gradient
if
self
.
_controller
is
not
None
:
from
..operators.linear_operator
import
LinearOperator
from
..operators.inversion_enabler
import
InversionEnabler
self
.
_metric
=
tmp
.
_metric
if
self
.
_preconditioner
is
None
:
precond
=
None
elif
isinstance
(
self
.
_preconditioner
,
LinearOperator
):
precond
=
self
.
_preconditioner
elif
isinstance
(
self
.
_preconditioner
,
Energy
):
precond
=
self
.
_preconditioner
.
at
(
self
.
_position
).
metric
self
.
_metric
=
InversionEnabler
(
tmp
.
_metric
,
self
.
_controller
,
precond
)
else
:
self
.
_metric
=
tmp
.
_metric
def
at
(
self
,
position
):
return
EnergyAdapter
(
position
,
self
.
_op
,
self
.
_constants
,
self
.
_want_metric
)
@
property
def
value
(
self
):
if
self
.
_val
is
None
:
self
.
_val
=
self
.
_op
(
self
.
_position
).
local_data
[()]
return
self
.
_val
@
property
def
gradient
(
self
):
if
self
.
_grad
is
None
:
self
.
_fill_all
()
return
self
.
_grad
@
property
def
metric
(
self
):
if
self
.
_metric
is
None
:
self
.
_fill_all
()
return
self
.
_metric
def
apply_metric
(
self
,
x
):
return
self
.
_metric
(
x
)
nifty5/minimization/kl_energy.py
0 → 100644
View file @
3d526bfa
from
__future__
import
absolute_import
,
division
,
print_function
from
..compat
import
*
from
.energy
import
Energy
from
..linearization
import
Linearization
from
..operators.scaling_operator
import
ScalingOperator
from
..operators.block_diagonal_operator
import
BlockDiagonalOperator
from
..
import
utilities
class
KL_Energy
(
Energy
):
def
__init__
(
self
,
position
,
h
,
nsamp
,
constants
=
[],
_samples
=
None
,
want_metric
=
False
):
super
(
KL_Energy
,
self
).
__init__
(
position
)
self
.
_h
=
h
self
.
_constants
=
constants
self
.
_want_metric
=
want_metric
if
_samples
is
None
:
met
=
h
(
Linearization
.
make_var
(
position
,
True
)).
metric
_samples
=
tuple
(
met
.
draw_sample
(
from_inverse
=
True
)
for
_
in
range
(
nsamp
))
self
.
_samples
=
_samples
if
len
(
constants
)
==
0
:
tmp
=
Linearization
.
make_var
(
position
,
want_metric
)
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
constants
else
1.
,
dom
)
for
key
,
dom
in
position
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
position
.
domain
,
tuple
(
ops
))
tmp
=
Linearization
(
position
,
bdop
,
want_metric
=
want_metric
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
tmp
+
v
),
self
.
_samples
)
tmp
=
utilities
.
my_sum
(
mymap
)
*
(
1.
/
len
(
self
.
_samples
))
self
.
_val
=
tmp
.
val
.
local_data
[()]
self
.
_grad
=
tmp
.
gradient
self
.
_metric
=
tmp
.
metric
def
at
(
self
,
position
):
return
KL_Energy
(
position
,
self
.
_h
,
0
,
self
.
_constants
,
self
.
_samples
,
self
.
_want_metric
)
@
property
def
value
(
self
):
return
self
.
_val
@
property
def
gradient
(
self
):
return
self
.
_grad
def
apply_metric
(
self
,
x
):
return
self
.
_metric
(
x
)
@
property
def
metric
(
self
):
return
self
.
_metric
@
property
def
samples
(
self
):
return
self
.
_samples
nifty5/minimization/quadratic_energy.py
View file @
3d526bfa
...
...
@@ -77,3 +77,6 @@ class QuadraticEnergy(Energy):
@
property
def
metric
(
self
):
return
self
.
_A
def
apply_metric
(
self
,
x
):
return
self
.
_A
(
x
)
nifty5/minimization/scipy_minimizer.py
View file @
3d526bfa
...
...
@@ -93,7 +93,7 @@ class _MinHelper(object):
def
hessp
(
self
,
x
,
p
):
self
.
_update
(
x
)
res
=
self
.
_energy
.
metric
(
_toField
(
p
,
self
.
_energy
.
position
))
res
=
self
.
_energy
.
apply_
metric
(
_toField
(
p
,
self
.
_energy
.
position
))
return
_toArray_rw
(
res
)
...
...
nifty5/operators/energy_operators.py
View file @
3d526bfa
...
...
@@ -42,7 +42,7 @@ class SquaredNormOperator(EnergyOperator):
if
isinstance
(
x
,
Linearization
):
val
=
Field
.
scalar
(
x
.
val
.
vdot
(
x
.
val
))
jac
=
VdotOperator
(
2
*
x
.
val
)(
x
.
jac
)
return
Linearization
(
val
,
jac
)
return
x
.
new
(
val
,
jac
)
return
Field
.
scalar
(
x
.
vdot
(
x
))
...
...
@@ -59,7 +59,7 @@ class QuadraticFormOperator(EnergyOperator):
t1
=
self
.
_op
(
x
.
val
)
jac
=
VdotOperator
(
t1
)(
x
.
jac
)
val
=
Field
.
scalar
(
0.5
*
x
.
val
.
vdot
(
t1
))
return
Linearization
(
val
,
jac
)
return
x
.
new
(
val
,
jac
)
return
Field
.
scalar
(
0.5
*
x
.
vdot
(
self
.
_op
(
x
)))
...
...
@@ -91,7 +91,7 @@ class GaussianEnergy(EnergyOperator):
def
apply
(
self
,
x
):
residual
=
x
if
self
.
_mean
is
None
else
x
-
self
.
_mean
res
=
self
.
_op
(
residual
).
real
if
not
isinstance
(
x
,
Linearization
):
if
not
isinstance
(
x
,
Linearization
)
or
not
x
.
want_metric
:
return
res
metric
=
SandwichOperator
.
make
(
x
.
jac
,
self
.
_icov
)
return
res
.
add_metric
(
metric
)
...
...
@@ -107,6 +107,8 @@ class PoissonianEnergy(EnergyOperator):
res
=
x
.
sum
()
-
x
.
log
().
vdot
(
self
.
_d
)
if
not
isinstance
(
x
,
Linearization
):
return
Field
.
scalar
(
res
)
if
not
x
.
want_metric
:
return
res
metric
=
SandwichOperator
.
make
(
x
.
jac
,
makeOp
(
1.
/
x
.
val
))
return
res
.
add_metric
(
metric
)
...
...
@@ -136,6 +138,8 @@ class BernoulliEnergy(EnergyOperator):
v
=
x
.
log
().
vdot
(
-
self
.
_d
)
-
(
1.
-
x
).
log
().
vdot
(
1.
-
self
.
_d
)
if
not
isinstance
(
x
,
Linearization
):
return
Field
.
scalar
(
v
)
if
not
x
.
want_metric
:
return
v
met
=
makeOp
(
1.
/
(
x
.
val
*
(
1.
-
x
.
val
)))
met
=
SandwichOperator
.
make
(
x
.
jac
,
met
)
return
v
.
add_metric
(
met
)
...
...
@@ -149,11 +153,11 @@ class Hamiltonian(EnergyOperator):
self
.
_domain
=
lh
.
domain
def
apply
(
self
,
x
):
if
self
.
_ic_samp
is
None
or
not
isinstance
(
x
,
Linearization
):
if
(
self
.
_ic_samp
is
None
or
not
isinstance
(
x
,
Linearization
)
or
not
x
.
want_metric
):
return
self
.
_lh
(
x
)
+
self
.
_prior
(
x
)
else
:
lhx
=
self
.
_lh
(
x
)
prx
=
self
.
_prior
(
x
)
lhx
,
prx
=
self
.
_lh
(
x
),
self
.
_prior
(
x
)
mtr
=
SamplingEnabler
(
lhx
.
metric
,
prx
.
metric
.
inverse
,
self
.
_ic_samp
,
prx
.
metric
.
inverse
)
return
(
lhx
+
prx
).
add_metric
(
mtr
)
...
...
nifty5/operators/linear_operator.py
View file @
3d526bfa
...
...
@@ -175,7 +175,7 @@ class LinearOperator(Operator):
return
self
.
apply
(
x
,
self
.
TIMES
)
from
..linearization
import
Linearization
if
isinstance
(
x
,
Linearization
):
return
Linearization
(
self
(
x
.
_val
),
self
(
x
.
_jac
))
return
x
.
new
(
self
(
x
.
_val
),
self
(
x
.
_jac
))
return
self
.
__matmul__
(
x
)
def
times
(
self
,
x
):
...
...
nifty5/operators/operator.py
View file @
3d526bfa
...
...
@@ -144,11 +144,12 @@ class _OpProd(Operator):
v2
=
v
.
extract
(
self
.
_op2
.
domain
)
if
not
lin
:
return
self
.
_op1
(
v1
)
*
self
.
_op2
(
v2
)
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
))
wm
=
x
.
want_metric
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
,
wm
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
,
wm
))
op
=
(
makeOp
(
lin1
.
_val
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_val
)(
lin1
.
_jac
),
False
)
return
L
in
earization
(
lin1
.
_val
*
lin2
.
_val
,
op
(
x
.
jac
))
return
l
in
1
.
new
(
lin1
.
_val
*
lin2
.
_val
,
op
(
x
.
jac
))
class
_OpSum
(
Operator
):
...
...
@@ -168,10 +169,11 @@ class _OpSum(Operator):
res
=
None
if
not
lin
:
return
self
.
_op1
(
v1
).
unite
(
self
.
_op2
(
v2
))
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
))
wm
=
x
.
want_metric
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
,
wm
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
,
wm
))
op
=
lin1
.
_jac
.
_myadd
(
lin2
.
_jac
,
False
)
res
=
L
in
earization
(
lin1
.
_val
+
lin2
.
_val
,
op
(
x
.
jac
))
res
=
l
in
1
.
new
(
lin1
.
_val
+
lin2
.
_val
,
op
(
x
.
jac
))
if
lin1
.
_metric
is
not
None
and
lin2
.
_metric
is
not
None
:
res
=
res
.
add_metric
(
lin1
.
_metric
+
lin2
.
_metric
)
return
res
nifty5/plotting/plot.py
View file @
3d526bfa
...
...
@@ -267,7 +267,6 @@ class Plot(object):
self
.
_plots
=
[]
self
.
_kwargs
=
[]
def
add
(
self
,
f
,
**
kwargs
):
"""Add a figure to the current list of plots.
...
...
@@ -303,7 +302,6 @@ class Plot(object):
self
.
_plots
.
append
(
f
)
self
.
_kwargs
.
append
(
kwargs
)
def
output
(
self
,
**
kwargs
):
"""Plot the accumulated list of figures.
...
...