Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
01779e03
Commit
01779e03
authored
Apr 08, 2020
by
Martin Reinecke
Browse files
more
parent
f24e26e9
Changes
14
Show whitespace changes
Inline
Side-by-side
nifty6/extra.py
View file @
01779e03
...
...
@@ -32,7 +32,8 @@ __all__ = ["consistency_check", "check_jacobian_consistency",
def
assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
val
,
f2
.
val
,
atol
=
atol
,
rtol
=
rtol
)
np
.
testing
.
assert_allclose
(
f1
.
val
,
f2
.
val
,
atol
=
atol
,
rtol
=
rtol
)
else
:
for
key
,
val
in
f1
.
items
():
assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
...
...
@@ -103,10 +104,10 @@ def _actual_domain_check_nonlinear(op, loc):
reslin
=
op
(
lin
)
assert_
(
lin
.
domain
is
op
.
domain
)
assert_
(
lin
.
target
is
op
.
domain
)
assert_
(
lin
.
val
.
domain
is
lin
.
domain
)
assert_
(
lin
.
fld
.
domain
is
lin
.
domain
)
assert_
(
reslin
.
domain
is
op
.
domain
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
val
.
domain
is
reslin
.
target
)
assert_
(
reslin
.
fld
.
domain
is
reslin
.
target
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
jac
.
domain
is
reslin
.
domain
)
assert_
(
reslin
.
jac
.
target
is
reslin
.
target
)
...
...
@@ -150,7 +151,7 @@ def _performance_check(op, pos, raise_on_fail):
cond
.
append
(
cop
.
count
!=
2
)
lin
.
jac
(
pos
)
cond
.
append
(
cop
.
count
!=
3
)
lin
.
jac
.
adjoint
(
lin
.
val
)
lin
.
jac
.
adjoint
(
lin
.
fld
)
cond
.
append
(
cop
.
count
!=
4
)
if
lin
.
metric
is
not
None
:
lin
.
metric
(
pos
)
...
...
@@ -217,20 +218,20 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
def
_get_acceptable_location
(
op
,
loc
,
lin
):
if
not
np
.
isfinite
(
lin
.
val
.
s_sum
()):
if
not
np
.
isfinite
(
lin
.
fld
.
s_sum
()):
raise
ValueError
(
'Initial value must be finite'
)
dir
=
from_random
(
"normal"
,
loc
.
domain
)
dirder
=
lin
.
jac
(
dir
)
if
dirder
.
norm
()
==
0
:
dir
=
dir
*
(
lin
.
val
.
norm
()
*
1e-5
)
dir
=
dir
*
(
lin
.
fld
.
norm
()
*
1e-5
)
else
:
dir
=
dir
*
(
lin
.
val
.
norm
()
*
1e-5
/
dirder
.
norm
())
dir
=
dir
*
(
lin
.
fld
.
norm
()
*
1e-5
/
dirder
.
norm
())
# Find a step length that leads to a "reasonable" location
for
i
in
range
(
50
):
try
:
loc2
=
loc
+
dir
lin2
=
op
(
Linearization
.
make_var
(
loc2
,
lin
.
want_metric
))
if
np
.
isfinite
(
lin2
.
val
.
s_sum
())
and
abs
(
lin2
.
val
.
s_sum
())
<
1e20
:
if
np
.
isfinite
(
lin2
.
fld
.
s_sum
())
and
abs
(
lin2
.
fld
.
s_sum
())
<
1e20
:
break
except
FloatingPointError
:
pass
...
...
@@ -244,7 +245,7 @@ def _linearization_value_consistency(op, loc):
for
wm
in
[
False
,
True
]:
lin
=
Linearization
.
make_var
(
loc
,
wm
)
fld0
=
op
(
loc
)
fld1
=
op
(
lin
).
val
fld1
=
op
(
lin
).
fld
assert_allclose
(
fld0
,
fld1
,
0
,
1e-7
)
...
...
@@ -283,7 +284,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
locmid
=
loc
+
0.5
*
dir
linmid
=
op
(
Linearization
.
make_var
(
locmid
))
dirder
=
linmid
.
jac
(
dir
)
numgrad
=
(
lin2
.
val
-
lin
.
val
)
numgrad
=
(
lin2
.
fld
-
lin
.
fld
)
xtol
=
tol
*
dirder
.
norm
()
/
np
.
sqrt
(
dirder
.
size
)
hist
.
append
((
numgrad
-
dirder
).
norm
())
# print(len(hist),hist[-1])
...
...
nifty6/field.py
View file @
01779e03
...
...
@@ -147,6 +147,10 @@ class Field(Operator):
arr
=
generator_function
(
dtype
=
dtype
,
shape
=
domain
.
shape
,
**
kwargs
)
return
Field
(
domain
,
arr
)
@
property
def
fld
(
self
):
return
self
@
property
def
val
(
self
):
"""numpy.ndarray : the array storing the field's entries.
...
...
@@ -172,6 +176,11 @@ class Field(Operator):
"""DomainTuple : the field's domain"""
return
self
.
_domain
@
property
def
target
(
self
):
"""DomainTuple : the field's domain"""
return
self
.
_domain
@
property
def
shape
(
self
):
"""tuple of int : the concatenated shapes of all sub-domains"""
...
...
nifty6/library/light_cone_operator.py
View file @
01779e03
...
...
@@ -132,7 +132,7 @@ class LightConeOperator(Operator):
def
apply
(
self
,
x
):
lin
=
x
.
jac
is
not
None
a
,
derivs
=
_cone_arrays
(
x
.
val
.
val
if
lin
else
x
.
val
,
self
.
target
,
self
.
_sigx
,
lin
)
a
,
derivs
=
_cone_arrays
(
x
.
val
,
self
.
target
,
self
.
_sigx
,
lin
)
res
=
Field
(
self
.
target
,
a
)
if
not
lin
:
return
res
...
...
nifty6/library/special_distributions.py
View file @
01779e03
...
...
@@ -79,11 +79,10 @@ class _InterpolationOperator(Operator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
lin
=
x
.
jac
is
not
None
xval
=
x
.
val
.
val
if
lin
else
x
.
val
res
=
self
.
_interpolator
(
xval
)
res
=
self
.
_interpolator
(
x
.
val
)
res
=
Field
(
self
.
_domain
,
res
)
if
lin
:
res
=
x
.
new
(
res
,
makeOp
(
Field
(
self
.
_domain
,
self
.
_deriv
(
xval
))))
res
=
x
.
new
(
res
,
makeOp
(
Field
(
self
.
_domain
,
self
.
_deriv
(
x
.
val
))))
if
self
.
_inv_table_func
is
not
None
:
res
=
self
.
_inv_table_func
(
res
)
return
res
...
...
@@ -148,11 +147,10 @@ class UniformOperator(Operator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
lin
=
x
.
jac
is
not
None
xval
=
x
.
val
.
val
if
lin
else
x
.
val
res
=
Field
(
self
.
_target
,
self
.
_scale
*
norm
.
_cdf
(
xval
)
+
self
.
_loc
)
res
=
Field
(
self
.
_target
,
self
.
_scale
*
norm
.
_cdf
(
x
.
val
)
+
self
.
_loc
)
if
not
lin
:
return
res
jac
=
makeOp
(
Field
(
self
.
_domain
,
norm
.
_pdf
(
xval
)
*
self
.
_scale
))
jac
=
makeOp
(
Field
(
self
.
_domain
,
norm
.
_pdf
(
x
.
val
)
*
self
.
_scale
))
return
x
.
new
(
res
,
jac
)
def
inverse
(
self
,
field
):
...
...
nifty6/linearization.py
View file @
01779e03
...
...
@@ -29,7 +29,7 @@ class Linearization(Operator):
Parameters
----------
val
: Field or MultiField
fld
: Field or MultiField
The value of the operator application.
jac : LinearOperator
The Jacobian.
...
...
@@ -39,38 +39,38 @@ class Linearization(Operator):
If True, the metric will be computed for other Linearizations derived
from this one. Default: False.
"""
def
__init__
(
self
,
val
,
jac
,
metric
=
None
,
want_metric
=
False
):
self
.
_
val
=
val
def
__init__
(
self
,
fld
,
jac
,
metric
=
None
,
want_metric
=
False
):
self
.
_
fld
=
fld
self
.
_jac
=
jac
if
self
.
_
val
.
domain
!=
self
.
_jac
.
target
:
if
self
.
_
fld
.
domain
!=
self
.
_jac
.
target
:
raise
ValueError
(
"domain mismatch"
)
self
.
_want_metric
=
want_metric
self
.
_metric
=
metric
def
new
(
self
,
val
,
jac
,
metric
=
None
):
def
new
(
self
,
fld
,
jac
,
metric
=
None
):
"""Create a new Linearization, taking the `want_metric` property from
this one.
Parameters
----------
val
: Field or MultiField
fld
: Field or MultiField
the value of the operator application
jac : LinearOperator
the Jacobian
metric : LinearOperator or None
The metric. Default: None.
"""
return
Linearization
(
val
,
jac
,
metric
,
self
.
_want_metric
)
return
Linearization
(
fld
,
jac
,
metric
,
self
.
_want_metric
)
def
trivial_jac
(
self
):
return
self
.
make_var
(
self
.
_
val
,
self
.
_want_metric
)
return
self
.
make_var
(
self
.
_
fld
,
self
.
_want_metric
)
def
prepend_jac
(
self
,
jac
):
metric
=
None
if
self
.
_metric
is
not
None
:
from
.operators.sandwich_operator
import
SandwichOperator
metric
=
None
if
self
.
_metric
is
None
else
SandwichOperator
.
make
(
jac
,
self
.
_metric
)
return
self
.
new
(
self
.
_
val
,
self
.
_jac
@
jac
,
metric
)
return
self
.
new
(
self
.
_
fld
,
self
.
_jac
@
jac
,
metric
)
@
property
def
domain
(
self
):
...
...
@@ -82,10 +82,19 @@ class Linearization(Operator):
"""DomainTuple or MultiDomain : the Jacobian's target (i.e. the value's domain)"""
return
self
.
_jac
.
target
@
property
def
fld
(
self
):
"""Field or MultiField : the pure field-like part of this object"""
return
self
.
_fld
@
property
def
val
(
self
):
"""Field or MultiField : the value"""
return
self
.
_val
"""numpy.ndarray or {key: numpy.ndarray} : the numerical value data"""
return
self
.
_fld
.
val
def
val_rw
(
self
):
"""numpy.ndarray or {key: numpy.ndarray} : the numerical value data"""
return
self
.
_fld
.
val_rw
()
@
property
def
jac
(
self
):
...
...
@@ -119,30 +128,30 @@ class Linearization(Operator):
return
self
.
_metric
def
__getitem__
(
self
,
name
):
return
self
.
new
(
self
.
_
val
[
name
],
self
.
_jac
.
ducktape_left
(
name
))
return
self
.
new
(
self
.
_
fld
[
name
],
self
.
_jac
.
ducktape_left
(
name
))
def
__neg__
(
self
):
return
self
.
new
(
-
self
.
_
val
,
-
self
.
_jac
,
return
self
.
new
(
-
self
.
_
fld
,
-
self
.
_jac
,
None
if
self
.
_metric
is
None
else
-
self
.
_metric
)
def
conjugate
(
self
):
return
self
.
new
(
self
.
_
val
.
conjugate
(),
self
.
_jac
.
conjugate
(),
self
.
_
fld
.
conjugate
(),
self
.
_jac
.
conjugate
(),
None
if
self
.
_metric
is
None
else
self
.
_metric
.
conjugate
())
@
property
def
real
(
self
):
return
self
.
new
(
self
.
_
val
.
real
,
self
.
_jac
.
real
)
return
self
.
new
(
self
.
_
fld
.
real
,
self
.
_jac
.
real
)
def
_myadd
(
self
,
other
,
neg
):
if
np
.
isscalar
(
other
)
or
other
.
jac
is
None
:
return
self
.
new
(
self
.
_val
-
other
if
neg
else
self
.
_val
+
other
,
return
self
.
new
(
self
.
fld
-
other
if
neg
else
self
.
fld
+
other
,
self
.
_jac
,
self
.
_metric
)
met
=
None
if
self
.
_metric
is
not
None
and
other
.
_metric
is
not
None
:
met
=
self
.
_metric
.
_myadd
(
other
.
_metric
,
neg
)
return
self
.
new
(
self
.
val
.
flexible_addsub
(
other
.
val
,
neg
),
self
.
fld
.
flexible_addsub
(
other
.
fld
,
neg
),
self
.
jac
.
_myadd
(
other
.
jac
,
neg
),
met
)
def
__add__
(
self
,
other
):
...
...
@@ -175,18 +184,18 @@ class Linearization(Operator):
if
other
==
1
:
return
self
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
.
scale
(
other
)
return
self
.
new
(
self
.
_val
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
return
self
.
new
(
self
.
fld
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
from
.sugar
import
makeOp
if
other
.
jac
is
None
:
if
self
.
target
!=
other
.
domain
:
raise
ValueError
(
"domain mismatch"
)
return
self
.
new
(
self
.
_val
*
other
,
makeOp
(
other
)(
self
.
_jac
))
return
self
.
new
(
self
.
fld
*
other
,
makeOp
(
other
)(
self
.
_jac
))
if
self
.
target
!=
other
.
target
:
raise
ValueError
(
"domain mismatch"
)
return
self
.
new
(
self
.
val
*
other
.
val
,
(
makeOp
(
other
.
val
)(
self
.
jac
)).
_myadd
(
makeOp
(
self
.
val
)(
other
.
jac
),
False
))
self
.
fld
*
other
.
fld
,
(
makeOp
(
other
.
fld
)(
self
.
jac
)).
_myadd
(
makeOp
(
self
.
fld
)(
other
.
jac
),
False
))
def
__rmul__
(
self
,
other
):
return
self
.
__mul__
(
other
)
...
...
@@ -208,12 +217,12 @@ class Linearization(Operator):
return
self
.
__mul__
(
other
)
from
.operators.outer_product_operator
import
OuterProduct
if
other
.
jac
is
None
:
return
self
.
new
(
OuterProduct
(
self
.
_
val
,
other
.
domain
)(
other
),
OuterProduct
(
self
.
_jac
(
self
.
_
val
),
other
.
domain
))
return
self
.
new
(
OuterProduct
(
self
.
_
fld
,
other
.
domain
)(
other
),
OuterProduct
(
self
.
_jac
(
self
.
_
fld
),
other
.
domain
))
return
self
.
new
(
OuterProduct
(
self
.
_
val
,
other
.
target
)(
other
.
_
val
),
OuterProduct
(
self
.
_jac
(
self
.
_
val
),
other
.
target
).
_myadd
(
OuterProduct
(
self
.
_
val
,
other
.
target
)(
other
.
_jac
),
False
))
OuterProduct
(
self
.
_
fld
,
other
.
target
)(
other
.
_
fld
),
OuterProduct
(
self
.
_jac
(
self
.
_
fld
),
other
.
target
).
_myadd
(
OuterProduct
(
self
.
_
fld
,
other
.
target
)(
other
.
_jac
),
False
))
def
vdot
(
self
,
other
):
"""Computes the inner product of this Linearization with a Field or
...
...
@@ -229,14 +238,18 @@ class Linearization(Operator):
the inner product of self and other
"""
from
.operators.simple_linear_operators
import
VdotOperator
if
other
is
self
:
return
self
.
new
(
self
.
_fld
.
vdot
(
self
.
_fld
),
VdotOperator
(
2
*
self
.
_fld
)(
self
.
_jac
))
if
other
.
jac
is
None
:
return
self
.
new
(
self
.
_
val
.
vdot
(
other
),
self
.
_
fld
.
vdot
(
other
),
VdotOperator
(
other
)(
self
.
_jac
))
return
self
.
new
(
self
.
_
val
.
vdot
(
other
.
_
val
),
VdotOperator
(
self
.
_
val
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_
val
)(
self
.
_jac
))
self
.
_
fld
.
vdot
(
other
.
_
fld
),
VdotOperator
(
self
.
_
fld
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_
fld
)(
self
.
_jac
))
def
sum
(
self
,
spaces
=
None
):
"""Computes the (partial) sum over self
...
...
@@ -254,7 +267,7 @@ class Linearization(Operator):
"""
from
.operators.contraction_operator
import
ContractionOperator
return
self
.
new
(
self
.
_
val
.
sum
(
spaces
),
self
.
_
fld
.
sum
(
spaces
),
ContractionOperator
(
self
.
_jac
.
target
,
spaces
)(
self
.
_jac
))
def
integrate
(
self
,
spaces
=
None
):
...
...
@@ -273,12 +286,12 @@ class Linearization(Operator):
"""
from
.operators.contraction_operator
import
ContractionOperator
return
self
.
new
(
self
.
_
val
.
integrate
(
spaces
),
self
.
_
fld
.
integrate
(
spaces
),
ContractionOperator
(
self
.
_jac
.
target
,
spaces
,
1
)(
self
.
_jac
))
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
from
.pointwise
import
ptw_dict
t1
,
t2
=
self
.
_
val
.
ptw_with_deriv
(
op
,
*
args
,
**
kwargs
)
t1
,
t2
=
self
.
_
fld
.
ptw_with_deriv
(
op
,
*
args
,
**
kwargs
)
return
self
.
new
(
t1
,
makeOp
(
t2
)(
self
.
_jac
))
def
clip
(
self
,
a_min
=
None
,
a_max
=
None
):
...
...
@@ -291,10 +304,10 @@ class Linearization(Operator):
return
self
.
ptw
(
"clip"
,
a_min
,
a_max
)
def
add_metric
(
self
,
metric
):
return
self
.
new
(
self
.
_
val
,
self
.
_jac
,
metric
)
return
self
.
new
(
self
.
_
fld
,
self
.
_jac
,
metric
)
def
with_want_metric
(
self
):
return
Linearization
(
self
.
_
val
,
self
.
_jac
,
self
.
_metric
,
True
)
return
Linearization
(
self
.
_
fld
,
self
.
_jac
,
self
.
_metric
,
True
)
@
staticmethod
def
make_var
(
field
,
want_metric
=
False
):
...
...
nifty6/minimization/energy_adapter.py
View file @
01779e03
...
...
@@ -47,7 +47,7 @@ class EnergyAdapter(Energy):
self
.
_want_metric
=
want_metric
lin
=
Linearization
.
make_partial_var
(
position
,
constants
,
want_metric
)
tmp
=
self
.
_op
(
lin
)
self
.
_val
=
tmp
.
val
.
val
[()]
self
.
_val
=
tmp
.
val
[()]
self
.
_grad
=
tmp
.
gradient
self
.
_metric
=
tmp
.
_metric
...
...
nifty6/minimization/metric_gaussian_kl.py
View file @
01779e03
...
...
@@ -198,10 +198,10 @@ class MetricGaussianKL(Energy):
if
self
.
_mirror_samples
:
tmp
=
tmp
+
self
.
_hamiltonian
(
self
.
_lin
-
s
)
if
v
is
None
:
v
=
tmp
.
val
.
val
_rw
()
v
=
tmp
.
val_rw
()
g
=
tmp
.
gradient
else
:
v
+=
tmp
.
val
.
val
v
+=
tmp
.
val
g
=
g
+
tmp
.
gradient
self
.
_val
=
_np_allreduce_sum
(
self
.
_comm
,
v
)[()]
/
self
.
_n_eff_samples
self
.
_grad
=
_allreduce_sum_field
(
self
.
_comm
,
g
)
/
self
.
_n_eff_samples
...
...
nifty6/multi_field.py
View file @
01779e03
...
...
@@ -83,6 +83,10 @@ class MultiField(Operator):
def
domain
(
self
):
return
self
.
_domain
@
property
def
target
(
self
):
return
self
.
_domain
# @property
# def dtype(self):
# return {key: val.dtype for key, val in self._val.items()}
...
...
@@ -136,6 +140,10 @@ class MultiField(Operator):
return
MultiField
(
domain
,
tuple
(
Field
(
dom
,
val
)
for
dom
in
domain
.
_domains
))
@
property
def
fld
(
self
):
return
self
@
property
def
val
(
self
):
return
{
key
:
val
.
val
...
...
nifty6/operators/energy_operators.py
View file @
01779e03
...
...
@@ -58,10 +58,10 @@ class Squared2NormOperator(EnergyOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
res
=
x
.
fld
.
vdot
(
x
.
fld
)
if
x
.
jac
is
None
:
return
x
.
vdot
(
x
)
res
=
x
.
val
.
vdot
(
x
.
val
)
return
x
.
new
(
res
,
VdotOperator
(
2
*
x
.
val
))
return
res
return
x
.
new
(
res
,
VdotOperator
(
2
*
x
.
fld
))
class
QuadraticFormOperator
(
EnergyOperator
):
...
...
@@ -86,10 +86,10 @@ class QuadraticFormOperator(EnergyOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
res
=
0.5
*
x
.
fld
.
vdot
(
self
.
_op
(
x
.
fld
))
if
x
.
jac
is
None
:
return
0.5
*
x
.
vdot
(
self
.
_op
(
x
))
res
=
0.5
*
x
.
val
.
vdot
(
self
.
_op
(
x
.
val
))
return
x
.
new
(
res
,
VdotOperator
(
self
.
_op
(
x
.
val
)))
return
res
return
x
.
new
(
res
,
VdotOperator
(
self
.
_op
(
x
.
fld
)))
class
VariableCovarianceGaussianEnergy
(
EnergyOperator
):
...
...
@@ -128,7 +128,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res
=
0.5
*
(
x
[
self
.
_r
].
vdot
(
x
[
self
.
_r
]
*
x
[
self
.
_icov
]).
real
-
x
[
self
.
_icov
].
ptw
(
"log"
).
sum
())
if
not
x
.
want_metric
:
return
res
mf
=
{
self
.
_r
:
x
.
val
[
self
.
_icov
],
self
.
_icov
:
.
5
*
x
.
val
[
self
.
_icov
]
**
(
-
2
)}
mf
=
{
self
.
_r
:
x
.
fld
[
self
.
_icov
],
self
.
_icov
:
.
5
*
x
.
fld
[
self
.
_icov
]
**
(
-
2
)}
return
res
.
add_metric
(
makeOp
(
MultiField
.
from_dict
(
mf
)))
...
...
@@ -230,7 +230,7 @@ class PoissonianEnergy(EnergyOperator):
res
=
x
.
sum
()
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
makeOp
(
1.
/
x
.
val
))
return
res
.
add_metric
(
makeOp
(
x
.
fld
.
ptw
(
"reciprocal"
)
))
class
InverseGammaLikelihood
(
EnergyOperator
):
...
...
@@ -270,7 +270,7 @@ class InverseGammaLikelihood(EnergyOperator):
res
=
x
.
ptw
(
"log"
).
vdot
(
self
.
_alphap1
)
+
x
.
ptw
(
"reciprocal"
).
vdot
(
self
.
_beta
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
makeOp
(
self
.
_alphap1
/
(
x
.
val
**
2
)))
return
res
.
add_metric
(
makeOp
(
self
.
_alphap1
/
(
x
.
fld
**
2
)))
class
StudentTEnergy
(
EnergyOperator
):
...
...
@@ -333,7 +333,7 @@ class BernoulliEnergy(EnergyOperator):
res
=
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
+
(
1.
-
x
).
ptw
(
"log"
).
vdot
(
self
.
_d
-
1.
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
makeOp
(
1.
/
(
x
.
val
*
(
1.
-
x
.
val
))))
return
res
.
add_metric
(
makeOp
(
1.
/
(
x
.
fld
*
(
1.
-
x
.
fld
))))
class
StandardHamiltonian
(
EnergyOperator
):
...
...
nifty6/operators/linear_operator.py
View file @
01779e03
...
...
@@ -172,7 +172,7 @@ class LinearOperator(Operator):
"""Same as :meth:`times`"""
from
..linearization
import
Linearization
if
x
.
jac
is
not
None
:
return
x
.
new
(
self
(
x
.
_val
),
self
).
prepend_jac
(
x
.
jac
)
return
x
.
new
(
self
(
x
.
fld
),
self
).
prepend_jac
(
x
.
jac
)
if
x
.
val
is
not
None
:
return
self
.
apply
(
x
,
self
.
TIMES
)
return
self
@
x
...
...
nifty6/operators/operator.py
View file @
01779e03
...
...
@@ -45,11 +45,23 @@ class Operator(metaclass=NiftyMeta):
"""
return
self
.
_target
@
property
def
fld
(
self
):
"""The field associated with this object
For "pure" operators this is `None`. For Field-like objects this
is a `Field` or a `MultiField` matching the object's `target`.
Returns
-------
None or Field or MultiField : the field object
"""
return
None
@
property
def
val
(
self
):
"""The numerical value associated with this object
For "pure" operators this is `None`. For Field-like objects this
is a `numpy.ndarray` or a dictionary of `numpy.ndarray`s mat
h
cing the
is a `numpy.ndarray` or a dictionary of `numpy.ndarray`s matc
h
ing the
object's `target`.
Returns
...
...
@@ -421,16 +433,16 @@ class _OpProd(Operator):
from
..sugar
import
makeOp
self
.
_check_input
(
x
)
lin
=
x
.
jac
is
not
None
wm
=
x
.
want_metric
if
lin
else
False
x
=
x
.
val
if
lin
else
x
wm
=
x
.
want_metric
x
=
x
.
fld
if
lin
else
x
v1
=
x
.
extract
(
self
.
_op1
.
domain
)
v2
=
x
.
extract
(
self
.
_op2
.
domain
)
if
not
lin
:
return
self
.
_op1
(
v1
)
*
self
.
_op2
(
v2
)
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
,
wm
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
,
wm
))
jac
=
(
makeOp
(
lin1
.
_
val
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_
val
)(
lin1
.
_jac
),
False
)
return
lin1
.
new
(
lin1
.
_
val
*
lin2
.
_
val
,
jac
)
jac
=
(
makeOp
(
lin1
.
_
fld
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_
fld
)(
lin1
.
_jac
),
False
)
return
lin1
.
new
(
lin1
.
_
fld
*
lin2
.
_
fld
,
jac
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
...
...
@@ -467,13 +479,13 @@ class _OpSum(Operator):
v1
=
x
.
extract
(
self
.
_op1
.
domain
)
v2
=
x
.
extract
(
self
.
_op2
.
domain
)
return
self
.
_op1
(
v1
).
unite
(
self
.
_op2
(
v2
))
v1
=
x
.
val
.
extract
(
self
.
_op1
.
domain
)
v2
=
x
.
val
.
extract
(
self
.
_op2
.
domain
)
v1
=
x
.
fld
.
extract
(
self
.
_op1
.
domain
)
v2
=
x
.
fld
.
extract
(
self
.
_op2
.
domain
)
wm
=
x
.
want_metric
lin1
=
self
.
_op1
(
Linearization
.
make_var
(
v1
,
wm
))
lin2
=
self
.
_op2
(
Linearization
.
make_var
(
v2
,
wm
))
op
=
lin1
.
_jac
.
_myadd
(
lin2
.
_jac
,
False
)
res
=
lin1
.
new
(
lin1
.
_
val
.
unite
(
lin2
.
_
val
),
op
)
res
=
lin1
.
new
(
lin1
.
_
fld
.
unite
(
lin2
.
_
fld
),
op
)
if
lin1
.
_metric
is
not
None
and
lin2
.
_metric
is
not
None
:
res
=
res
.
add_metric
(
lin1
.
_metric
.
_myadd
(
lin2
.
_metric
,
False
))
return
res
...
...
test/test_linearization.py
View file @
01779e03
...
...
@@ -63,8 +63,8 @@ def test_actual_gradients(f):
eps
=
1e-8
var0
=
ift
.
Linearization
.
make_var
(
fld
)
var1
=
ift
.
Linearization
.
make_var
(
fld
+
eps
)
f0
=
var0
.
ptw
(
f
).
val
.
val
f1
=
var1
.
ptw
(
f
).
val
.
val
f0
=
var0
.
ptw
(
f
).
val
f1
=
var1
.
ptw
(
f
).
val
df0
=
(
f1
-
f0
)
/
eps
df1
=
_lin2grad
(
var0
.
ptw
(
f
))
assert_allclose
(
df0
,
df1
,
rtol
=
100
*
eps
)
test/test_operators/test_jacobian.py
View file @
01779e03
...
...
@@ -43,7 +43,7 @@ def testBasics(space, seed):
s
=
S
.
draw_sample
()
var
=
ift
.
Linearization
.
make_var
(
s
)
model
=
ift
.
ScalingOperator
(
var
.
target
,
6.
)
ift
.
extra
.
check_jacobian_consistency
(
model
,
var
.
val
)
ift
.
extra
.
check_jacobian_consistency
(
model
,
var
.
fld
)
@
pmp
(
'type1'
,
[
'Variable'
,
'Constant'
])
...
...
test/test_operators/test_simplification.py
View file @
01779e03
...
...
@@ -46,10 +46,10 @@ def test_simplification():
o2
.
ducktape
(
"b"
).
ducktape_left
(
"b"
))
_
,
op2
=
op
.
simplify_for_constant_input
(
f2
)
assert_equal
(
isinstance
(
op2
.
_op1
,
_ConstantOperator
),
True
)
assert_allclose
(
op
(
f1
)[
"a"
]
.
val
,
op2
(
f1
)[
"a"
]
.
val
)
assert_allclose
(
op
(
f1
)[
"b"
]
.
val
,
op2
(
f1
)[
"b"
]
.
val
)
assert_allclose
(
op
(
f1
)
.
val
[
"a"
],
op2
(
f1
)
.
val
[
"a"
])
assert_allclose
(
op
(
f1
)
.
val
[
"b"
],
op2
(
f1
)
.
val
[
"b"
])
lin
=
ift
.
Linearization
.
make_var
(
ift
.
MultiField
.
full
(
op2
.
domain
,
2.
),
True
)
assert_allclose
(
op
(
lin
).
val
[
"a"
]
.
val
,
op2
(
lin
).
val
[
"a"
]
.
val
)
assert_allclose
(
op
(
lin
).
val
[
"b"
]
.
val
,
op2
(
lin
).
val
[
"b"
]
.
val
)
assert_allclose
(
op
(
lin
).
val
[
"a"
],
op2
(
lin
).
val
[
"a"
])
assert_allclose
(
op
(
lin
).
val
[
"b"
],
op2
(
lin
).
val
[
"b"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment