Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
01779e03
Commit
01779e03
authored
Apr 08, 2020
by
Martin Reinecke
Browse files
more
parent
f24e26e9
Changes
14
Hide whitespace changes
Inline
Side-by-side
nifty6/extra.py
View file @
01779e03
...
@@ -32,9 +32,10 @@ __all__ = ["consistency_check", "check_jacobian_consistency",
...
@@ -32,9 +32,10 @@ __all__ = ["consistency_check", "check_jacobian_consistency",
def
assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
def
assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
val
,
f2
.
val
,
atol
=
atol
,
rtol
=
rtol
)
np
.
testing
.
assert_allclose
(
f1
.
val
,
f2
.
val
,
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
else
:
assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
def
_adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
,
def
_adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
,
...
@@ -103,10 +104,10 @@ def _actual_domain_check_nonlinear(op, loc):
...
@@ -103,10 +104,10 @@ def _actual_domain_check_nonlinear(op, loc):
reslin
=
op
(
lin
)
reslin
=
op
(
lin
)
assert_
(
lin
.
domain
is
op
.
domain
)
assert_
(
lin
.
domain
is
op
.
domain
)
assert_
(
lin
.
target
is
op
.
domain
)
assert_
(
lin
.
target
is
op
.
domain
)
assert_
(
lin
.
val
.
domain
is
lin
.
domain
)
assert_
(
lin
.
fld
.
domain
is
lin
.
domain
)
assert_
(
reslin
.
domain
is
op
.
domain
)
assert_
(
reslin
.
domain
is
op
.
domain
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
val
.
domain
is
reslin
.
target
)
assert_
(
reslin
.
fld
.
domain
is
reslin
.
target
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
jac
.
domain
is
reslin
.
domain
)
assert_
(
reslin
.
jac
.
domain
is
reslin
.
domain
)
assert_
(
reslin
.
jac
.
target
is
reslin
.
target
)
assert_
(
reslin
.
jac
.
target
is
reslin
.
target
)
...
@@ -150,7 +151,7 @@ def _performance_check(op, pos, raise_on_fail):
...
@@ -150,7 +151,7 @@ def _performance_check(op, pos, raise_on_fail):
cond
.
append
(
cop
.
count
!=
2
)
cond
.
append
(
cop
.
count
!=
2
)
lin
.
jac
(
pos
)
lin
.
jac
(
pos
)
cond
.
append
(
cop
.
count
!=
3
)
cond
.
append
(
cop
.
count
!=
3
)
lin
.
jac
.
adjoint
(
lin
.
val
)
lin
.
jac
.
adjoint
(
lin
.
fld
)
cond
.
append
(
cop
.
count
!=
4
)
cond
.
append
(
cop
.
count
!=
4
)
if
lin
.
metric
is
not
None
:
if
lin
.
metric
is
not
None
:
lin
.
metric
(
pos
)
lin
.
metric
(
pos
)
...
@@ -217,20 +218,20 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
...
@@ -217,20 +218,20 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
def
_get_acceptable_location
(
op
,
loc
,
lin
):
def
_get_acceptable_location
(
op
,
loc
,
lin
):
if
not
np
.
isfinite
(
lin
.
val
.
s_sum
()):
if
not
np
.
isfinite
(
lin
.
fld
.
s_sum
()):
raise
ValueError
(
'Initial value must be finite'
)
raise
ValueError
(
'Initial value must be finite'
)
dir
=
from_random
(
"normal"
,
loc
.
domain
)
dir
=
from_random
(
"normal"
,
loc
.
domain
)
dirder
=
lin
.
jac
(
dir
)
dirder
=
lin
.
jac
(
dir
)
if
dirder
.
norm
()
==
0
:
if
dirder
.
norm
()
==
0
:
dir
=
dir
*
(
lin
.
val
.
norm
()
*
1e-5
)
dir
=
dir
*
(
lin
.
fld
.
norm
()
*
1e-5
)
else
:
else
:
dir
=
dir
*
(
lin
.
val
.
norm
()
*
1e-5
/
dirder
.
norm
())
dir
=
dir
*
(
lin
.
fld
.
norm
()
*
1e-5
/
dirder
.
norm
())
# Find a step length that leads to a "reasonable" location
# Find a step length that leads to a "reasonable" location
for
i
in
range
(
50
):
for
i
in
range
(
50
):
try
:
try
:
loc2
=
loc
+
dir
loc2
=
loc
+
dir
lin2
=
op
(
Linearization
.
make_var
(
loc2
,
lin
.
want_metric
))
lin2
=
op
(
Linearization
.
make_var
(
loc2
,
lin
.
want_metric
))
if
np
.
isfinite
(
lin2
.
val
.
s_sum
())
and
abs
(
lin2
.
val
.
s_sum
())
<
1e20
:
if
np
.
isfinite
(
lin2
.
fld
.
s_sum
())
and
abs
(
lin2
.
fld
.
s_sum
())
<
1e20
:
break
break
except
FloatingPointError
:
except
FloatingPointError
:
pass
pass
...
@@ -244,7 +245,7 @@ def _linearization_value_consistency(op, loc):
...
@@ -244,7 +245,7 @@ def _linearization_value_consistency(op, loc):
for
wm
in
[
False
,
True
]:
for
wm
in
[
False
,
True
]:
lin
=
Linearization
.
make_var
(
loc
,
wm
)
lin
=
Linearization
.
make_var
(
loc
,
wm
)
fld0
=
op
(
loc
)
fld0
=
op
(
loc
)
fld1
=
op
(
lin
).
val
fld1
=
op
(
lin
).
fld
assert_allclose
(
fld0
,
fld1
,
0
,
1e-7
)
assert_allclose
(
fld0
,
fld1
,
0
,
1e-7
)
...
@@ -283,7 +284,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
...
@@ -283,7 +284,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
locmid
=
loc
+
0.5
*
dir
locmid
=
loc
+
0.5
*
dir
linmid
=
op
(
Linearization
.
make_var
(
locmid
))
linmid
=
op
(
Linearization
.
make_var
(
locmid
))
dirder
=
linmid
.
jac
(
dir
)
dirder
=
linmid
.
jac
(
dir
)
numgrad
=
(
lin2
.
val
-
lin
.
val
)
numgrad
=
(
lin2
.
fld
-
lin
.
fld
)
xtol
=
tol
*
dirder
.
norm
()
/
np
.
sqrt
(
dirder
.
size
)
xtol
=
tol
*
dirder
.
norm
()
/
np
.
sqrt
(
dirder
.
size
)
hist
.
append
((
numgrad
-
dirder
).
norm
())
hist
.
append
((
numgrad
-
dirder
).
norm
())
# print(len(hist),hist[-1])
# print(len(hist),hist[-1])
...
...
nifty6/field.py
View file @
01779e03
...
@@ -147,6 +147,10 @@ class Field(Operator):
...
@@ -147,6 +147,10 @@ class Field(Operator):
arr
=
generator_function
(
dtype
=
dtype
,
shape
=
domain
.
shape
,
**
kwargs
)
arr
=
generator_function
(
dtype
=
dtype
,
shape
=
domain
.
shape
,
**
kwargs
)
return
Field
(
domain
,
arr
)
return
Field
(
domain
,
arr
)
@
property
def
fld
(
self
):
return
self
@
property
@
property
def
val
(
self
):
def
val
(
self
):
"""numpy.ndarray : the array storing the field's entries.
"""numpy.ndarray : the array storing the field's entries.
...
@@ -172,6 +176,11 @@ class Field(Operator):
...
@@ -172,6 +176,11 @@ class Field(Operator):
"""DomainTuple : the field's domain"""
"""DomainTuple : the field's domain"""
return
self
.
_domain
return
self
.
_domain
@
property
def
target
(
self
):
"""DomainTuple : the field's domain"""
return
self
.
_domain
@
property
@
property
def
shape
(
self
):
def
shape
(
self
):
"""tuple of int : the concatenated shapes of all sub-domains"""
"""tuple of int : the concatenated shapes of all sub-domains"""
...
...
nifty6/library/light_cone_operator.py
View file @
01779e03
...
@@ -132,7 +132,7 @@ class LightConeOperator(Operator):
...
@@ -132,7 +132,7 @@ class LightConeOperator(Operator):
def
apply
(
self
,
x
):
def
apply
(
self
,
x
):
lin
=
x
.
jac
is
not
None
lin
=
x
.
jac
is
not
None
a
,
derivs
=
_cone_arrays
(
x
.
val
.
val
if
lin
else
x
.
val
,
self
.
target
,
self
.
_sigx
,
lin
)
a
,
derivs
=
_cone_arrays
(
x
.
val
,
self
.
target
,
self
.
_sigx
,
lin
)
res
=
Field
(
self
.
target
,
a
)
res
=
Field
(
self
.
target
,
a
)
if
not
lin
:
if
not
lin
:
return
res
return
res
...
...
nifty6/library/special_distributions.py
View file @
01779e03
...
@@ -79,11 +79,10 @@ class _InterpolationOperator(Operator):
...
@@ -79,11 +79,10 @@ class _InterpolationOperator(Operator):
def
apply
(
self
,
x
):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
self
.
_check_input
(
x
)
lin
=
x
.
jac
is
not
None
lin
=
x
.
jac
is
not
None
xval
=
x
.
val
.
val
if
lin
else
x
.
val
res
=
self
.
_interpolator
(
x
.
val
)
res
=
self
.
_interpolator
(
xval
)
res
=
Field
(
self
.
_domain
,
res
)
res
=
Field
(
self
.
_domain
,
res
)
if
lin
:
if
lin
:
res
=
x
.
new
(
res
,
makeOp
(
Field
(
self
.
_domain
,
self
.
_deriv
(
xval
))))
res
=
x
.
new
(
res
,
makeOp
(
Field
(
self
.
_domain
,
self
.
_deriv
(
x
.
val
))))
if
self
.
_inv_table_func
is
not
None
:
if
self
.
_inv_table_func
is
not
None
:
res
=
self
.
_inv_table_func
(
res
)
res
=
self
.
_inv_table_func
(
res
)
return
res
return
res
...
@@ -148,11 +147,10 @@ class UniformOperator(Operator):
...
@@ -148,11 +147,10 @@ class UniformOperator(Operator):
def
apply
(
self
,
x
):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
self
.
_check_input
(
x
)
lin
=
x
.
jac
is
not
None
lin
=
x
.
jac
is
not
None
xval
=
x
.
val
.
val
if
lin
else
x
.
val
res
=
Field
(
self
.
_target
,
self
.
_scale
*
norm
.
_cdf
(
x
.
val
)
+
self
.
_loc
)
res
=
Field
(
self
.
_target
,
self
.
_scale
*
norm
.
_cdf
(
xval
)
+
self
.
_loc
)
if
not
lin
:
if
not
lin
:
return
res
return
res
jac
=
makeOp
(
Field
(
self
.
_domain
,
norm
.
_pdf
(
xval
)
*
self
.
_scale
))
jac
=
makeOp
(
Field
(
self
.
_domain
,
norm
.
_pdf
(
x
.
val
)
*
self
.
_scale
))
return
x
.
new
(
res
,
jac
)
return
x
.
new
(
res
,
jac
)
def
inverse
(
self
,
field
):
def
inverse
(
self
,
field
):
...
...
nifty6/linearization.py
View file @
01779e03
...
@@ -29,7 +29,7 @@ class Linearization(Operator):
...
@@ -29,7 +29,7 @@ class Linearization(Operator):
Parameters
Parameters
----------
----------
val
: Field or MultiField
fld
: Field or MultiField
The value of the operator application.
The value of the operator application.
jac : LinearOperator
jac : LinearOperator
The Jacobian.
The Jacobian.
...
@@ -39,38 +39,38 @@ class Linearization(Operator):
...
@@ -39,38 +39,38 @@ class Linearization(Operator):
If True, the metric will be computed for other Linearizations derived
If True, the metric will be computed for other Linearizations derived
from this one. Default: False.
from this one. Default: False.
"""
"""
def
__init__
(
self
,
val
,
jac
,
metric
=
None
,
want_metric
=
False
):
def
__init__
(
self
,
fld
,
jac
,
metric
=
None
,
want_metric
=
False
):
self
.
_
val
=
val
self
.
_
fld
=
fld
self
.
_jac
=
jac
self
.
_jac
=
jac
if
self
.
_
val
.
domain
!=
self
.
_jac
.
target
:
if
self
.
_
fld
.
domain
!=
self
.
_jac
.
target
:
raise
ValueError
(
"domain mismatch"
)
raise
ValueError
(
"domain mismatch"
)
self
.
_want_metric
=
want_metric
self
.
_want_metric
=
want_metric
self
.
_metric
=
metric
self
.
_metric
=
metric
def
new
(
self
,
val
,
jac
,
metric
=
None
):
def
new
(
self
,
fld
,
jac
,
metric
=
None
):
"""Create a new Linearization, taking the `want_metric` property from
"""Create a new Linearization, taking the `want_metric` property from
this one.
this one.
Parameters
Parameters
----------
----------
val
: Field or MultiField
fld
: Field or MultiField
the value of the operator application
the value of the operator application
jac : LinearOperator
jac : LinearOperator
the Jacobian
the Jacobian
metric : LinearOperator or None
metric : LinearOperator or None
The metric. Default: None.
The metric. Default: None.
"""
"""
return
Linearization
(
val
,
jac
,
metric
,
self
.
_want_metric
)
return
Linearization
(
fld
,
jac
,
metric
,
self
.
_want_metric
)
def
trivial_jac
(
self
):
def
trivial_jac
(
self
):
return
self
.
make_var
(
self
.
_
val
,
self
.
_want_metric
)
return
self
.
make_var
(
self
.
_
fld
,
self
.
_want_metric
)
def
prepend_jac
(
self
,
jac
):
def
prepend_jac
(
self
,
jac
):
metric
=
None
metric
=
None
if
self
.
_metric
is
not
None
:
if
self
.
_metric
is
not
None
:
from
.operators.sandwich_operator
import
SandwichOperator
from
.operators.sandwich_operator
import
SandwichOperator
metric
=
None
if
self
.
_metric
is
None
else
SandwichOperator
.
make
(
jac
,
self
.
_metric
)
metric
=
None
if
self
.
_metric
is
None
else
SandwichOperator
.
make
(
jac
,
self
.
_metric
)
return
self
.
new
(
self
.
_
val
,
self
.
_jac
@
jac
,
metric
)
return
self
.
new
(
self
.
_
fld
,
self
.
_jac
@
jac
,
metric
)
@
property
@
property
def
domain
(
self
):
def
domain
(
self
):
...
@@ -82,10 +82,19 @@ class Linearization(Operator):
...
@@ -82,10 +82,19 @@ class Linearization(Operator):
"""DomainTuple or MultiDomain : the Jacobian's target (i.e. the value's domain)"""
"""DomainTuple or MultiDomain : the Jacobian's target (i.e. the value's domain)"""
return
self
.
_jac
.
target
return
self
.
_jac
.
target
@
property
def
fld
(
self
):
"""Field or MultiField : the pure field-like part of this object"""
return
self
.
_fld
@
property
@
property
def
val
(
self
):
def
val
(
self
):
"""Field or MultiField : the value"""
"""numpy.ndarray or {key: numpy.ndarray} : the numerical value data"""
return
self
.
_val
return
self
.
_fld
.
val
def
val_rw
(
self
):
"""numpy.ndarray or {key: numpy.ndarray} : the numerical value data"""
return
self
.
_fld
.
val_rw
()
@
property
@
property
def
jac
(
self
):
def
jac
(
self
):
...
@@ -119,30 +128,30 @@ class Linearization(Operator):
...
@@ -119,30 +128,30 @@ class Linearization(Operator):
return
self
.
_metric
return
self
.
_metric
def
__getitem__
(
self
,
name
):
def
__getitem__
(
self
,
name
):
return
self
.
new
(
self
.
_
val
[
name
],
self
.
_jac
.
ducktape_left
(
name
))
return
self
.
new
(
self
.
_
fld
[
name
],
self
.
_jac
.
ducktape_left
(
name
))
def
__neg__
(
self
):
def
__neg__
(
self
):
return
self
.
new
(
-
self
.
_
val
,
-
self
.
_jac
,
return
self
.
new
(
-
self
.
_
fld
,
-
self
.
_jac
,
None
if
self
.
_metric
is
None
else
-
self
.
_metric
)
None
if
self
.
_metric
is
None
else
-
self
.
_metric
)
def
conjugate
(
self
):
def
conjugate
(
self
):
return
self
.
new
(
return
self
.
new
(
self
.
_
val
.
conjugate
(),
self
.
_jac
.
conjugate
(),
self
.
_
fld
.
conjugate
(),
self
.
_jac
.
conjugate
(),
None
if
self
.
_metric
is
None
else
self
.
_metric
.
conjugate
())
None
if
self
.
_metric
is
None
else
self
.
_metric
.
conjugate
())
@
property
@
property
def
real
(
self
):
def
real
(
self
):
return
self
.
new
(
self
.
_
val
.
real
,
self
.
_jac
.
real
)
return
self
.
new
(
self
.
_
fld
.
real
,
self
.
_jac
.
real
)
def
_myadd
(
self
,
other
,
neg
):
def
_myadd
(
self
,
other
,
neg
):
if
np
.
isscalar
(
other
)
or
other
.
jac
is
None
:
if
np
.
isscalar
(
other
)
or
other
.
jac
is
None
:
return
self
.
new
(
self
.
_val
-
other
if
neg
else
self
.
_val
+
other
,
return
self
.
new
(
self
.
fld
-
other
if
neg
else
self
.
fld
+
other
,
self
.
_jac
,
self
.
_metric
)
self
.
_jac
,
self
.
_metric
)
met
=
None
met
=
None
if
self
.
_metric
is
not
None
and
other
.
_metric
is
not
None
:
if
self
.
_metric
is
not
None
and
other
.
_metric
is
not
None
:
met
=
self
.
_metric
.
_myadd
(
other
.
_metric
,
neg
)
met
=
self
.
_metric
.
_myadd
(
other
.
_metric
,
neg
)
return
self
.
new
(
return
self
.
new
(
self
.
val
.
flexible_addsub
(
other
.
val
,
neg
),
self
.
fld
.
flexible_addsub
(
other
.
fld
,
neg
),
self
.
jac
.
_myadd
(
other
.
jac
,
neg
),
met
)
self
.
jac
.
_myadd
(
other
.
jac
,
neg
),
met
)
def
__add__
(
self
,
other
):
def
__add__
(
self
,
other
):
...
@@ -175,18 +184,18 @@ class Linearization(Operator):
...
@@ -175,18 +184,18 @@ class Linearization(Operator):
if
other
==
1
:
if
other
==
1
:
return
self
return
self
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
.
scale
(
other
)
met
=
None
if
self
.
_metric
is
None
else
self
.
_metric
.
scale
(
other
)
return
self
.
new
(
self
.
_val
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
return
self
.
new
(
self
.
fld
*
other
,
self
.
_jac
.
scale
(
other
),
met
)
from
.sugar
import
makeOp
from
.sugar
import
makeOp
if
other
.
jac
is
None
:
if
other
.
jac
is
None
:
if
self
.
target
!=
other
.
domain
:
if
self
.
target
!=
other
.
domain
:
raise
ValueError
(
"domain mismatch"
)
raise
ValueError
(
"domain mismatch"
)
return
self
.
new
(
self
.
_val
*
other
,
makeOp
(
other
)(
self
.
_jac
))
return
self
.
new
(
self
.
fld
*
other
,
makeOp
(
other
)(
self
.
_jac
))
if
self
.
target
!=
other
.
target
:
if
self
.
target
!=
other
.
target
:
raise
ValueError
(
"domain mismatch"
)
raise
ValueError
(
"domain mismatch"
)
return
self
.
new
(
return
self
.
new
(
self
.
val
*
other
.
val
,
self
.
fld
*
other
.
fld
,
(
makeOp
(
other
.
val
)(
self
.
jac
)).
_myadd
(
(
makeOp
(
other
.
fld
)(
self
.
jac
)).
_myadd
(
makeOp
(
self
.
val
)(
other
.
jac
),
False
))
makeOp
(
self
.
fld
)(
other
.
jac
),
False
))
def
__rmul__
(
self
,
other
):
def
__rmul__
(
self
,
other
):
return
self
.
__mul__
(
other
)
return
self
.
__mul__
(
other
)
...
@@ -208,12 +217,12 @@ class Linearization(Operator):
...
@@ -208,12 +217,12 @@ class Linearization(Operator):
return
self
.
__mul__
(
other
)
return
self
.
__mul__
(
other
)
from
.operators.outer_product_operator
import
OuterProduct
from
.operators.outer_product_operator
import
OuterProduct
if
other
.
jac
is
None
:
if
other
.
jac
is
None
:
return
self
.
new
(
OuterProduct
(
self
.
_
val
,
other
.
domain
)(
other
),
return
self
.
new
(
OuterProduct
(
self
.
_
fld
,
other
.
domain
)(
other
),
OuterProduct
(
self
.
_jac
(
self
.
_
val
),
other
.
domain
))
OuterProduct
(
self
.
_jac
(
self
.
_
fld
),
other
.
domain
))
return
self
.
new
(
return
self
.
new
(
OuterProduct
(
self
.
_
val
,
other
.
target
)(
other
.
_
val
),
OuterProduct
(
self
.
_
fld
,
other
.
target
)(
other
.
_
fld
),
OuterProduct
(
self
.
_jac
(
self
.
_
val
),
other
.
target
).
_myadd
(
OuterProduct
(
self
.
_jac
(
self
.
_
fld
),
other
.
target
).
_myadd
(
OuterProduct
(
self
.
_
val
,
other
.
target
)(
other
.
_jac
),
False
))
OuterProduct
(
self
.
_
fld
,
other
.
target
)(
other
.
_jac
),
False
))
def
vdot
(
self
,
other
):
def
vdot
(
self
,
other
):
"""Computes the inner product of this Linearization with a Field or
"""Computes the inner product of this Linearization with a Field or
...
@@ -229,14 +238,18 @@ class Linearization(Operator):
...
@@ -229,14 +238,18 @@ class Linearization(Operator):
the inner product of self and other
the inner product of self and other
"""
"""
from
.operators.simple_linear_operators
import
VdotOperator
from
.operators.simple_linear_operators
import
VdotOperator
if
other
is
self
:
return
self
.
new
(
self
.
_fld
.
vdot
(
self
.
_fld
),
VdotOperator
(
2
*
self
.
_fld
)(
self
.
_jac
))
if
other
.
jac
is
None
:
if
other
.
jac
is
None
:
return
self
.
new
(
return
self
.
new
(
self
.
_
val
.
vdot
(
other
),
self
.
_
fld
.
vdot
(
other
),
VdotOperator
(
other
)(
self
.
_jac
))
VdotOperator
(
other
)(
self
.
_jac
))
return
self
.
new
(
return
self
.
new
(
self
.
_
val
.
vdot
(
other
.
_
val
),
self
.
_
fld
.
vdot
(
other
.
_
fld
),
VdotOperator
(
self
.
_
val
)(
other
.
_jac
)
+
VdotOperator
(
self
.
_
fld
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_
val
)(
self
.
_jac
))
VdotOperator
(
other
.
_
fld
)(
self
.
_jac
))
def
sum
(
self
,
spaces
=
None
):
def
sum
(
self
,
spaces
=
None
):
"""Computes the (partial) sum over self
"""Computes the (partial) sum over self
...
@@ -254,7 +267,7 @@ class Linearization(Operator):
...
@@ -254,7 +267,7 @@ class Linearization(Operator):
"""
"""
from
.operators.contraction_operator
import
ContractionOperator
from
.operators.contraction_operator
import
ContractionOperator
return
self
.
new
(
return
self
.
new
(
self
.
_
val
.
sum
(
spaces
),
self
.
_
fld
.
sum
(
spaces
),
ContractionOperator
(
self
.
_jac
.
target
,
spaces
)(
self
.
_jac
))
ContractionOperator
(
self
.
_jac
.
target
,
spaces
)(
self
.
_jac
))
def
integrate
(
self
,
spaces
=
None
):
def
integrate
(
self
,
spaces
=
None
):
...
@@ -273,12 +286,12 @@ class Linearization(Operator):
...
@@ -273,12 +286,12 @@ class Linearization(Operator):
"""
"""
from
.operators.contraction_operator
import
ContractionOperator
from
.operators.contraction_operator
import
ContractionOperator
return
self
.
new
(
return
self
.
new
(
self
.
_
val
.
integrate
(
spaces
),
self
.
_
fld
.
integrate
(
spaces
),
ContractionOperator
(
self
.
_jac
.
target
,
spaces
,
1
)(
self
.
_jac
))
ContractionOperator
(
self
.
_jac
.
target
,
spaces
,
1
)(
self
.
_jac
))
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
from
.pointwise
import
ptw_dict
from
.pointwise
import
ptw_dict
t1
,
t2
=
self
.
_
val
.
ptw_with_deriv
(
op
,
*
args
,
**
kwargs
)
t1
,
t2
=
self
.
_
fld
.
ptw_with_deriv
(
op
,
*
args
,
**
kwargs
)
return
self
.
new
(
t1
,
makeOp
(
t2
)(
self
.
_jac
))
return
self
.
new
(
t1
,
makeOp
(
t2
)(
self
.
_jac
))
def
clip
(
self
,
a_min
=
None
,
a_max
=
None
):
def
clip
(
self
,
a_min
=
None
,
a_max
=
None
):
...
@@ -291,10 +304,10 @@ class Linearization(Operator):
...
@@ -291,10 +304,10 @@ class Linearization(Operator):
return
self
.
ptw
(
"clip"
,
a_min
,
a_max
)
return
self
.
ptw
(
"clip"
,
a_min
,
a_max
)
def
add_metric
(
self
,
metric
):
def
add_metric
(
self
,
metric
):
return
self
.
new
(
self
.
_
val
,
self
.
_jac
,
metric
)
return
self
.
new
(
self
.
_
fld
,
self
.
_jac
,
metric
)
def
with_want_metric
(
self
):
def
with_want_metric
(
self
):
return
Linearization
(
self
.
_
val
,
self
.
_jac
,
self
.
_metric
,
True
)
return
Linearization
(
self
.
_
fld
,
self
.
_jac
,
self
.
_metric
,
True
)
@
staticmethod
@
staticmethod
def
make_var
(
field
,
want_metric
=
False
):
def
make_var
(
field
,
want_metric
=
False
):
...
...
nifty6/minimization/energy_adapter.py
View file @
01779e03
...
@@ -47,7 +47,7 @@ class EnergyAdapter(Energy):
...
@@ -47,7 +47,7 @@ class EnergyAdapter(Energy):
self
.
_want_metric
=
want_metric
self
.
_want_metric
=
want_metric
lin
=
Linearization
.
make_partial_var
(
position
,
constants
,
want_metric
)
lin
=
Linearization
.
make_partial_var
(
position
,
constants
,
want_metric
)
tmp
=
self
.
_op
(
lin
)
tmp
=
self
.
_op
(
lin
)
self
.
_val
=
tmp
.
val
.
val
[()]
self
.
_val
=
tmp
.
val
[()]
self
.
_grad
=
tmp
.
gradient
self
.
_grad
=
tmp
.
gradient
self
.
_metric
=
tmp
.
_metric
self
.
_metric
=
tmp
.
_metric
...
...
nifty6/minimization/metric_gaussian_kl.py
View file @
01779e03
...
@@ -198,10 +198,10 @@ class MetricGaussianKL(Energy):
...
@@ -198,10 +198,10 @@ class MetricGaussianKL(Energy):
if
self
.
_mirror_samples
:
if
self
.
_mirror_samples
:
tmp
=
tmp
+
self
.
_hamiltonian
(
self
.
_lin
-
s
)
tmp
=
tmp
+
self
.
_hamiltonian
(
self
.
_lin
-
s
)
if
v
is
None
:
if
v
is
None
:
v
=
tmp
.
val
.
val
_rw
()
v
=
tmp
.
val_rw
()
g
=
tmp
.
gradient
g
=
tmp
.
gradient
else
:
else
:
v
+=
tmp
.
val
.
val
v
+=
tmp
.
val
g
=
g
+
tmp
.
gradient
g
=
g
+
tmp
.
gradient
self
.
_val
=
_np_allreduce_sum
(
self
.
_comm
,
v
)[()]
/
self
.
_n_eff_samples
self
.
_val
=
_np_allreduce_sum
(
self
.
_comm
,
v
)[()]
/
self
.
_n_eff_samples
self
.
_grad
=
_allreduce_sum_field
(
self
.
_comm
,
g
)
/
self
.
_n_eff_samples
self
.
_grad
=
_allreduce_sum_field
(
self
.
_comm
,
g
)
/
self
.
_n_eff_samples
...
...
nifty6/multi_field.py
View file @
01779e03
...
@@ -83,6 +83,10 @@ class MultiField(Operator):
...
@@ -83,6 +83,10 @@ class MultiField(Operator):
def
domain
(
self
):
def
domain
(
self
):
return
self
.
_domain
return
self
.
_domain
@
property
def
target
(
self
):
return
self
.
_domain
# @property
# @property
# def dtype(self):
# def dtype(self):
# return {key: val.dtype for key, val in self._val.items()}
# return {key: val.dtype for key, val in self._val.items()}
...
@@ -136,6 +140,10 @@ class MultiField(Operator):
...
@@ -136,6 +140,10 @@ class MultiField(Operator):
return
MultiField
(
domain
,
tuple
(
Field
(
dom
,
val
)
return
MultiField
(
domain
,
tuple
(
Field
(
dom
,
val
)
for
dom
in
domain
.
_domains
))
for
dom
in
domain
.
_domains
))
@
property
def
fld
(
self
):
return
self
@
property
@
property
def
val
(
self
):
def
val
(
self
):
return
{
key
:
val
.
val
return
{
key
:
val
.
val
...
...
nifty6/operators/energy_operators.py
View file @
01779e03
...
@@ -58,10 +58,10 @@ class Squared2NormOperator(EnergyOperator):
...
@@ -58,10 +58,10 @@ class Squared2NormOperator(EnergyOperator):
def
apply
(
self
,
x
):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
self
.
_check_input
(
x
)
res
=
x
.
fld
.
vdot
(
x
.
fld
)
if
x
.
jac
is
None
:
if
x
.
jac
is
None
:
return
x
.
vdot
(
x
)
return
res
res
=
x
.
val
.
vdot
(
x
.
val
)
return
x
.
new
(
res
,
VdotOperator
(
2
*
x
.
fld
))
return
x
.
new
(
res
,
VdotOperator
(
2
*
x
.
val
))