Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
997cdaef
Commit
997cdaef
authored
Aug 21, 2018
by
Martin Reinecke
Browse files
make metric calculation optional
parent
f30e88f9
Changes
8
Hide whitespace changes
Inline
Side-by-side
nifty5/extra/energy_and_model_tests.py
View file @
997cdaef
...
...
@@ -41,7 +41,7 @@ def _get_acceptable_location(op, loc, lin):
for
i
in
range
(
50
):
try
:
loc2
=
loc
+
dir
lin2
=
op
(
Linearization
.
make_var
(
loc2
))
lin2
=
op
(
Linearization
.
make_var
(
loc2
,
lin
.
want_metric
))
if
np
.
isfinite
(
lin2
.
val
.
sum
())
and
abs
(
lin2
.
val
.
sum
())
<
1e20
:
break
except
FloatingPointError
:
...
...
@@ -54,14 +54,14 @@ def _get_acceptable_location(op, loc, lin):
def
_check_consistency
(
op
,
loc
,
tol
,
ntries
,
do_metric
):
for
_
in
range
(
ntries
):
lin
=
op
(
Linearization
.
make_var
(
loc
))
lin
=
op
(
Linearization
.
make_var
(
loc
,
do_metric
))
loc2
,
lin2
=
_get_acceptable_location
(
op
,
loc
,
lin
)
dir
=
loc2
-
loc
locnext
=
loc2
dirnorm
=
dir
.
norm
()
for
i
in
range
(
50
):
locmid
=
loc
+
0.5
*
dir
linmid
=
op
(
Linearization
.
make_var
(
locmid
))
linmid
=
op
(
Linearization
.
make_var
(
locmid
,
do_metric
))
dirder
=
linmid
.
jac
(
dir
)
numgrad
=
(
lin2
.
val
-
lin
.
val
)
xtol
=
tol
*
dirder
.
norm
()
/
np
.
sqrt
(
dirder
.
size
)
...
...
nifty5/library/inverse_gamma_model.py
View file @
997cdaef
...
...
@@ -53,7 +53,7 @@ class InverseGammaModel(Operator):
outer
=
1
/
outer_inv
jac
=
makeOp
(
Field
.
from_local_data
(
self
.
_domain
,
inner
*
outer
))
jac
=
jac
(
x
.
jac
)
return
Linearization
(
points
,
jac
)
return
x
.
new
(
points
,
jac
)
@
staticmethod
def
IG
(
field
,
alpha
,
q
):
...
...
nifty5/linearization.py
View file @
997cdaef
...
...
@@ -9,13 +9,17 @@ from .sugar import makeOp
class
Linearization
(
object
):
def
__init__
(
self
,
val
,
jac
,
metric
=
None
):
def
__init__
(
self
,
val
,
jac
,
metric
=
None
,
want_metric
=
False
):
self
.
_val
=
val
self
.
_jac
=
jac
if
self
.
_val
.
domain
!=
self
.
_jac
.
target
:
raise
ValueError
(
"domain mismatch"
)
self
.
_want_metric
=
want_metric
self
.
_metric
=
metric
def
new
(
self
,
val
,
jac
,
metric
=
None
):
return
Linearization
(
val
,
jac
,
metric
,
self
.
_want_metric
)
@
property
def
domain
(
self
):
return
self
.
_jac
.
domain
...
...
@@ -37,6 +41,10 @@ class Linearization(object):
"""Only available if target is a scalar"""
return
self
.
_jac
.
adjoint_times
(
Field
.
scalar
(
1.
))
@
property
def
want_metric
(
self
):
return
self
.
_want_metric
@
property
def
metric
(
self
):
"""Only available if target is a scalar"""
...
...
@@ -44,35 +52,34 @@ class Linearization(object):
def
__getitem__
(
self
,
name
):
from
.operators.simple_linear_operators
import
FieldAdapter
return
Linearization
(
self
.
_val
[
name
],
FieldAdapter
(
self
.
domain
,
name
))
return
self
.
new
(
self
.
_val
[
name
],
FieldAdapter
(
self
.
domain
,
name
))
def
__neg__
(
self
):
return
Linearization
(
-
self
.
_val
,
-
self
.
_jac
,
None
if
self
.
_metric
is
None
else
-
self
.
_metric
)
return
self
.
new
(
-
self
.
_val
,
-
self
.
_jac
,
None
if
self
.
_metric
is
None
else
-
self
.
_metric
)
def
conjugate
(
self
):
return
Linearization
(
return
self
.
new
(
self
.
_val
.
conjugate
(),
self
.
_jac
.
conjugate
(),
None
if
self
.
_metric
is
None
else
self
.
_metric
.
conjugate
())
@
property
def
real
(
self
):
return
Linearization
(
self
.
_val
.
real
,
self
.
_jac
.
real
)
return
self
.
new
(
self
.
_val
.
real
,
self
.
_jac
.
real
)
def
_myadd
(
self
,
other
,
neg
):
if
isinstance
(
other
,
Linearization
):
met
=
None
if
self
.
_metric
is
not
None
and
other
.
_metric
is
not
None
:
met
=
self
.
_metric
.
_myadd
(
other
.
_metric
,
neg
)
return
Linearization
(
return
self
.
new
(
self
.
_val
.
flexible_addsub
(
other
.
_val
,
neg
),
self
.
_jac
.
_myadd
(
other
.
_jac
,
neg
),
met
)
if
isinstance
(
other
,
(
int
,
float
,
complex
,
Field
,
MultiField
)):
if
neg
:
return
Linearization
(
self
.
_val
-
other
,
self
.
_jac
,
self
.
_metric
)
return
self
.
new
(
self
.
_val
-
other
,
self
.
_jac
,
self
.
_metric
)
else
:
return
Linearization
(
self
.
_val
+
other
,
self
.
_jac
,
self
.
_metric
)
return
self
.
new
(
self
.
_val
+
other
,
self
.
_jac
,
self
.
_metric
)
def
__add__
(
self
,
other
):
return
self
.
_myadd
(
other
,
False
)
...
...
@@ -91,7 +98,7 @@ class Linearization(object):
if
isinstance
(
other
,
Linearization
):
if
self
.
target
!=
other
.
target
:
raise
ValueError
(
"domain mismatch"
)
return
Linearization
(
return
self
.
new
(
self
.
_val
*
other
.
_val
,
(
makeOp
(
other
.
_val
)(
self
.
_jac
)).
_myadd
(
makeOp
(
self
.
_val
)(
other
.
_jac
),
False
))
...
...
@@ -99,11 +106,11 @@ class Linearization(object):
if
other
==
1
:
return
self
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
.
scale
(
other
)
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
return
self
.
new
(
self
.
_val
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
if
self
.
target
!=
other
.
domain
:
raise
ValueError
(
"domain mismatch"
)
return
Linearization
(
self
.
_val
*
other
,
makeOp
(
other
)(
self
.
_jac
))
return
self
.
new
(
self
.
_val
*
other
,
makeOp
(
other
)(
self
.
_jac
))
def
__rmul__
(
self
,
other
):
return
self
.
__mul__
(
other
)
...
...
@@ -111,46 +118,48 @@ class Linearization(object):
def
vdot
(
self
,
other
):
from
.operators.simple_linear_operators
import
VdotOperator
if
isinstance
(
other
,
(
Field
,
MultiField
)):
return
Linearization
(
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
vdot
(
other
)),
VdotOperator
(
other
)(
self
.
_jac
))
return
Linearization
(
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
vdot
(
other
.
_val
)),
VdotOperator
(
self
.
_val
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
def
sum
(
self
):
from
.operators.simple_linear_operators
import
SumReductionOperator
return
Linearization
(
return
self
.
new
(
Field
.
scalar
(
self
.
_val
.
sum
()),
SumReductionOperator
(
self
.
_jac
.
target
)(
self
.
_jac
))
def
exp
(
self
):
tmp
=
self
.
_val
.
exp
()
return
Linearization
(
tmp
,
makeOp
(
tmp
)(
self
.
_jac
))
return
self
.
new
(
tmp
,
makeOp
(
tmp
)(
self
.
_jac
))
def
log
(
self
):
tmp
=
self
.
_val
.
log
()
return
Linearization
(
tmp
,
makeOp
(
1.
/
self
.
_val
)(
self
.
_jac
))
return
self
.
new
(
tmp
,
makeOp
(
1.
/
self
.
_val
)(
self
.
_jac
))
def
tanh
(
self
):
tmp
=
self
.
_val
.
tanh
()
return
Linearization
(
tmp
,
makeOp
(
1.
-
tmp
**
2
)(
self
.
_jac
))
return
self
.
new
(
tmp
,
makeOp
(
1.
-
tmp
**
2
)(
self
.
_jac
))
def
positive_tanh
(
self
):
tmp
=
self
.
_val
.
tanh
()
tmp2
=
0.5
*
(
1.
+
tmp
)
return
Linearization
(
tmp2
,
makeOp
(
0.5
*
(
1.
-
tmp
**
2
))(
self
.
_jac
))
return
self
.
new
(
tmp2
,
makeOp
(
0.5
*
(
1.
-
tmp
**
2
))(
self
.
_jac
))
def
add_metric
(
self
,
metric
):
return
Linearization
(
self
.
_val
,
self
.
_jac
,
metric
)
return
self
.
new
(
self
.
_val
,
self
.
_jac
,
metric
)
@
staticmethod
def
make_var
(
field
):
def
make_var
(
field
,
want_metric
=
False
):
from
.operators.scaling_operator
import
ScalingOperator
return
Linearization
(
field
,
ScalingOperator
(
1.
,
field
.
domain
))
return
Linearization
(
field
,
ScalingOperator
(
1.
,
field
.
domain
),
want_metric
=
want_metric
)
@
staticmethod
def
make_const
(
field
):
def
make_const
(
field
,
want_metric
=
False
):
from
.operators.simple_linear_operators
import
NullOperator
return
Linearization
(
field
,
NullOperator
(
field
.
domain
,
field
.
domain
))
return
Linearization
(
field
,
NullOperator
(
field
.
domain
,
field
.
domain
),
want_metric
=
want_metric
)
nifty5/minimization/energy_adapter.py
View file @
997cdaef
...
...
@@ -8,17 +8,18 @@ from ..operators.scaling_operator import ScalingOperator
class
EnergyAdapter
(
Energy
):
def
__init__
(
self
,
position
,
op
,
constants
=
[]):
def
__init__
(
self
,
position
,
op
,
constants
=
[]
,
want_metric
=
False
):
super
(
EnergyAdapter
,
self
).
__init__
(
position
)
self
.
_op
=
op
self
.
_constants
=
constants
if
len
(
self
.
_constants
)
==
0
:
tmp
=
self
.
_op
(
Linearization
.
make_var
(
self
.
_position
))
tmp
=
self
.
_op
(
Linearization
.
make_var
(
self
.
_position
,
want_metric
))
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
self
.
_constants
else
1.
,
dom
)
for
key
,
dom
in
self
.
_position
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
self
.
_position
.
domain
,
tuple
(
ops
))
tmp
=
self
.
_op
(
Linearization
(
self
.
_position
,
bdop
))
tmp
=
self
.
_op
(
Linearization
(
self
.
_position
,
bdop
,
want_metric
=
want_metric
))
self
.
_val
=
tmp
.
val
.
local_data
[()]
self
.
_grad
=
tmp
.
gradient
self
.
_metric
=
tmp
.
_metric
...
...
nifty5/minimization/kl_energy.py
View file @
997cdaef
...
...
@@ -9,22 +9,24 @@ from .. import utilities
class
KL_Energy
(
Energy
):
def
__init__
(
self
,
position
,
h
,
nsamp
,
constants
=
[],
_samples
=
None
):
def
__init__
(
self
,
position
,
h
,
nsamp
,
constants
=
[],
_samples
=
None
,
want_metric
=
False
):
super
(
KL_Energy
,
self
).
__init__
(
position
)
self
.
_h
=
h
self
.
_constants
=
constants
self
.
_want_metric
=
want_metric
if
_samples
is
None
:
met
=
h
(
Linearization
.
make_var
(
position
)).
metric
met
=
h
(
Linearization
.
make_var
(
position
,
True
)).
metric
_samples
=
tuple
(
met
.
draw_sample
(
from_inverse
=
True
)
for
_
in
range
(
nsamp
))
self
.
_samples
=
_samples
if
len
(
constants
)
==
0
:
tmp
=
Linearization
.
make_var
(
position
)
tmp
=
Linearization
.
make_var
(
position
,
want_metric
)
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
constants
else
1.
,
dom
)
for
key
,
dom
in
position
.
domain
.
items
()]
bdop
=
BlockDiagonalOperator
(
position
.
domain
,
tuple
(
ops
))
tmp
=
Linearization
(
position
,
bdop
)
tmp
=
Linearization
(
position
,
bdop
,
want_metric
=
want_metric
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
tmp
+
v
),
self
.
_samples
)
tmp
=
utilities
.
my_sum
(
mymap
)
*
(
1.
/
len
(
self
.
_samples
))
self
.
_val
=
tmp
.
val
.
local_data
[()]
...
...
@@ -32,7 +34,8 @@ class KL_Energy(Energy):
self
.
_metric
=
tmp
.
metric
def
at
(
self
,
position
):
return
KL_Energy
(
position
,
self
.
_h
,
0
,
self
.
_constants
,
self
.
_samples
)
return
KL_Energy
(
position
,
self
.
_h
,
0
,
self
.
_constants
,
self
.
_samples
,
self
.
_want_metric
)
@
property
def
value
(
self
):
...
...
nifty5/operators/energy_operators.py
View file @
997cdaef
...
...
@@ -42,7 +42,7 @@ class SquaredNormOperator(EnergyOperator):
if
isinstance
(
x
,
Linearization
):
val
=
Field
.
scalar
(
x
.
val
.
vdot
(
x
.
val
))
jac
=
VdotOperator
(
2
*
x
.
val
)(
x
.
jac
)
return
Linearization
(
val
,
jac
)
return
x
.
new
(
val
,
jac
)
return
Field
.
scalar
(
x
.
vdot
(
x
))
...
...
@@ -59,7 +59,7 @@ class QuadraticFormOperator(EnergyOperator):
t1
=
self
.
_op
(
x
.
val
)
jac
=
VdotOperator
(
t1
)(
x
.
jac
)
val
=
Field
.
scalar
(
0.5
*
x
.
val
.
vdot
(
t1
))
return
Linearization
(
val
,
jac
)
return
x
.
new
(
val
,
jac
)
return
Field
.
scalar
(
0.5
*
x
.
vdot
(
self
.
_op
(
x
)))
...
...
@@ -91,7 +91,7 @@ class GaussianEnergy(EnergyOperator):
def
apply
(
self
,
x
):
residual
=
x
if
self
.
_mean
is
None
else
x
-
self
.
_mean
res
=
self
.
_op
(
residual
).
real
if
not
isinstance
(
x
,
Linearization
):
if
not
isinstance
(
x
,
Linearization
)
or
not
x
.
want_metric
:
return
res
metric
=
SandwichOperator
.
make
(
x
.
jac
,
self
.
_icov
)
return
res
.
add_metric
(
metric
)
...
...
@@ -107,6 +107,8 @@ class PoissonianEnergy(EnergyOperator):
res
=
x
.
sum
()
-
x
.
log
().
vdot
(
self
.
_d
)
if
not
isinstance
(
x
,
Linearization
):
return
Field
.
scalar
(
res
)
if
not
x
.
want_metric
:
return
res
metric
=
SandwichOperator
.
make
(
x
.
jac
,
makeOp
(
1.
/
x
.
val
))
return
res
.
add_metric
(
metric
)
...
...
@@ -122,6 +124,8 @@ class BernoulliEnergy(EnergyOperator):
v
=
x
.
log
().
vdot
(
-
self
.
_d
)
-
(
1.
-
x
).
log
().
vdot
(
1.
-
self
.
_d
)
if
not
isinstance
(
x
,
Linearization
):
return
Field
.
scalar
(
v
)
if
not
x
.
want_metric
:
return
v
met
=
makeOp
(
1.
/
(
x
.
val
*
(
1.
-
x
.
val
)))
met
=
SandwichOperator
.
make
(
x
.
jac
,
met
)
return
v
.
add_metric
(
met
)
...
...
@@ -135,11 +139,11 @@ class Hamiltonian(EnergyOperator):
self
.
_domain
=
lh
.
domain
def
apply
(
self
,
x
):
if
self
.
_ic_samp
is
None
or
not
isinstance
(
x
,
Linearization
):
if
(
self
.
_ic_samp
is
None
or
not
isinstance
(
x
,
Linearization
)
or
not
x
.
want_metric
):
return
self
.
_lh
(
x
)
+
self
.
_prior
(
x
)
else
:
lhx
=
self
.
_lh
(
x
)
prx
=
self
.
_prior
(
x
)
lhx
,
prx
=
self
.
_lh
(
x
),
self
.
_prior
(
x
)
mtr
=
SamplingEnabler
(
lhx
.
metric
,
prx
.
metric
.
inverse
,
self
.
_ic_samp
,
prx
.
metric
.
inverse
)
return
(
lhx
+
prx
).
add_metric
(
mtr
)
...
...
nifty5/operators/linear_operator.py
View file @
997cdaef
...
...
@@ -175,7 +175,7 @@ class LinearOperator(Operator):
return
self
.
apply
(
x
,
self
.
TIMES
)
from
..linearization
import
Linearization
if
isinstance
(
x
,
Linearization
):
return
Linearization
(
self
(
x
.
_val
),
self
(
x
.
_jac
))
return
x
.
new
(
self
(
x
.
_val
),
self
(
x
.
_jac
))
return
self
.
__matmul__
(
x
)
def
times
(
self
,
x
):
...
...
nifty5/operators/operator.py
View file @
997cdaef
...
...
@@ -144,11 +144,12 @@ class _OpProd(Operator):
v2
=
v
.
extract
(
self
.
_op2
.
domain
)
if
not
lin
:
return
self
.
_op1
(
v1
)
*
self
.
_op2
(
v2
)
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
))
wm
=
x
.
want_metric
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
,
wm
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
,
wm
))
op
=
(
makeOp
(
lin1
.
_val
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_val
)(
lin1
.
_jac
),
False
)
return
L
in
earization
(
lin1
.
_val
*
lin2
.
_val
,
op
(
x
.
jac
))
return
l
in
1
.
new
(
lin1
.
_val
*
lin2
.
_val
,
op
(
x
.
jac
))
class
_OpSum
(
Operator
):
...
...
@@ -168,10 +169,11 @@ class _OpSum(Operator):
res
=
None
if
not
lin
:
return
self
.
_op1
(
v1
).
unite
(
self
.
_op2
(
v2
))
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
))
wm
=
x
.
want_metric
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
,
wm
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
,
wm
))
op
=
lin1
.
_jac
.
_myadd
(
lin2
.
_jac
,
False
)
res
=
L
in
earization
(
lin1
.
_val
+
lin2
.
_val
,
op
(
x
.
jac
))
res
=
l
in
1
.
new
(
lin1
.
_val
+
lin2
.
_val
,
op
(
x
.
jac
))
if
lin1
.
_metric
is
not
None
and
lin2
.
_metric
is
not
None
:
res
=
res
.
add_metric
(
lin1
.
_metric
+
lin2
.
_metric
)
return
res
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment