Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
092bf7fd
Commit
092bf7fd
authored
Jun 20, 2020
by
Philipp Arras
Browse files
Implement proper constant support 6/n
parent
9dea1d88
Changes
5
Hide whitespace changes
Inline
Side-by-side
src/extra.py
View file @
092bf7fd
...
@@ -344,7 +344,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
...
@@ -344,7 +344,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
assert
oplin
.
jac
.
target
is
oplin0
.
jac
.
target
assert
oplin
.
jac
.
target
is
oplin0
.
jac
.
target
rndinp
=
from_random
(
oplin
.
jac
.
target
)
rndinp
=
from_random
(
oplin
.
jac
.
target
)
assert_equal
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
oplin0
.
jac
.
adjoint
(
rndinp
))
assert_allclose
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
oplin0
.
jac
.
adjoint
(
rndinp
),
1e-13
,
1e-13
)
foo
=
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
cstloc
.
domain
)
foo
=
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
cstloc
.
domain
)
assert_equal
(
foo
,
0
*
foo
)
assert_equal
(
foo
,
0
*
foo
)
...
@@ -352,7 +353,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
...
@@ -352,7 +353,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
oplin
.
metric
.
draw_sample
()
oplin
.
metric
.
draw_sample
()
assert
op0
.
domain
is
varloc
.
domain
assert
op0
.
domain
is
varloc
.
domain
_jac_vs_finite_differences
(
op0
,
varloc
,
np
.
sqrt
(
tol
),
ntries
,
only_r_differentiable
)
_jac_vs_finite_differences
(
op0
,
varloc
,
np
.
sqrt
(
tol
),
ntries
,
only_r_differentiable
)
def
_jac_vs_finite_differences
(
op
,
loc
,
tol
,
ntries
,
only_r_differentiable
):
def
_jac_vs_finite_differences
(
op
,
loc
,
tol
,
ntries
,
only_r_differentiable
):
...
...
src/operators/chain_operator.py
View file @
092bf7fd
...
@@ -138,13 +138,12 @@ class ChainOperator(LinearOperator):
...
@@ -138,13 +138,12 @@ class ChainOperator(LinearOperator):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"ChainOperator:
\n
"
+
utilities
.
indent
(
subs
)
return
"ChainOperator:
\n
"
+
utilities
.
indent
(
subs
)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
# from ..multi_domain import MultiDomain
from
..multi_domain
import
MultiDomain
# if not isinstance(self._domain, MultiDomain):
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
# return None, self
return
None
,
self
newop
=
None
# newop = None
for
op
in
reversed
(
self
.
_ops
):
# for op in reversed(self._ops):
c_inp
,
t_op
=
op
.
simplify_for_constant_input
(
c_inp
)
# c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop
=
t_op
if
newop
is
None
else
op
(
newop
)
# newop = t_op if newop is None else op(newop)
return
c_inp
,
newop
# return c_inp, newop
src/operators/energy_operators.py
View file @
092bf7fd
...
@@ -175,26 +175,25 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
...
@@ -175,26 +175,25 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
**
(
-
2
)})
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
**
(
-
2
)})
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
))
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
))
# def _simplify_for_constant_input_nontrivial(self, c_inp):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
# from .simplify_for_const import ConstantEnergyOperator
from
.simplify_for_const
import
ConstantEnergyOperator
# assert len(c_inp.keys()) == 1
assert
len
(
c_inp
.
keys
())
==
1
# key = c_inp.keys()[0]
key
=
c_inp
.
keys
()[
0
]
# assert key in self._domain.keys()
assert
key
in
self
.
_domain
.
keys
()
# cst = c_inp[key]
cst
=
c_inp
[
key
]
# if key == self._kr:
if
key
==
self
.
_kr
:
# res = _SpecialGammaEnergy(cst).ducktape(self._ki)
res
=
_SpecialGammaEnergy
(
cst
).
ducktape
(
self
.
_ki
)
# else:
else
:
# dt = self._dt[self._kr]
dt
=
self
.
_dt
[
self
.
_kr
]
# res = GaussianEnergy(inverse_covariance=makeOp(cst),
res
=
GaussianEnergy
(
inverse_covariance
=
makeOp
(
cst
),
# sampling_dtype=dt).ducktape(self._kr)
sampling_dtype
=
dt
).
ducktape
(
self
.
_kr
)
# trlog = cst.log().sum().val_rw()
trlog
=
cst
.
log
().
sum
().
val_rw
()
# if not _iscomplex(dt):
if
not
_iscomplex
(
dt
):
# trlog /= 2
trlog
/=
2
# res = res + ConstantEnergyOperator(res.domain, -trlog)
res
=
res
+
ConstantEnergyOperator
(
-
trlog
)
# res = res + ConstantEnergyOperator(self._domain, 0.)
res
=
res
+
ConstantEnergyOperator
(
0.
)
# assert res.domain is self.domain
assert
res
.
target
is
self
.
target
# assert res.target is self.target
return
None
,
res
# return None, res
class
_SpecialGammaEnergy
(
EnergyOperator
):
class
_SpecialGammaEnergy
(
EnergyOperator
):
...
...
src/operators/operator.py
View file @
092bf7fd
...
@@ -371,16 +371,15 @@ class _OpChain(_CombinedOperator):
...
@@ -371,16 +371,15 @@ class _OpChain(_CombinedOperator):
x
=
op
(
x
)
x
=
op
(
x
)
return
x
return
x
# def _simplify_for_constant_input_nontrivial(self, c_inp):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
# from ..multi_domain import MultiDomain
from
..multi_domain
import
MultiDomain
# if not isinstance(self._domain, MultiDomain):
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
# return None, self
return
None
,
self
newop
=
None
# newop = None
for
op
in
reversed
(
self
.
_ops
):
# for op in reversed(self._ops):
c_inp
,
t_op
=
op
.
simplify_for_constant_input
(
c_inp
)
# c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop
=
t_op
if
newop
is
None
else
op
(
newop
)
# newop = t_op if newop is None else op(newop)
return
c_inp
,
newop
# return c_inp, newop
def
__repr__
(
self
):
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
...
@@ -413,20 +412,19 @@ class _OpProd(Operator):
...
@@ -413,20 +412,19 @@ class _OpProd(Operator):
jac
=
(
makeOp
(
lin1
.
_val
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_val
)(
lin1
.
_jac
),
False
)
jac
=
(
makeOp
(
lin1
.
_val
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_val
)(
lin1
.
_jac
),
False
)
return
lin1
.
new
(
lin1
.
_val
*
lin2
.
_val
,
jac
)
return
lin1
.
new
(
lin1
.
_val
*
lin2
.
_val
,
jac
)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
# from ..multi_domain import MultiDomain
from
..multi_domain
import
MultiDomain
# from .simplify_for_const import ConstCollector
from
.simplify_for_const
import
ConstCollector
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
# f1, o1 = self._op1.simplify_for_constant_input(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
# c_inp.extract_part(self._op1.domain))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
# f2, o2 = self._op2.simplify_for_constant_input(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
# c_inp.extract_part(self._op2.domain))
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
# if not isinstance(self._target, MultiDomain):
return
None
,
_OpProd
(
o1
,
o2
)
# return None, _OpProd(o1, o2)
cc
=
ConstCollector
()
# cc = ConstCollector()
cc
.
mult
(
f1
,
o1
.
target
)
# cc.mult(f1, o1.target)
cc
.
mult
(
f2
,
o2
.
target
)
# cc.mult(f2, o2.target)
return
cc
.
constfield
,
_OpProd
(
o1
,
o2
)
# return cc.constfield, _OpProd(o1, o2)
def
__repr__
(
self
):
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
...
@@ -459,20 +457,19 @@ class _OpSum(Operator):
...
@@ -459,20 +457,19 @@ class _OpSum(Operator):
res
=
res
.
add_metric
(
lin1
.
_metric
.
_myadd
(
lin2
.
_metric
,
False
))
res
=
res
.
add_metric
(
lin1
.
_metric
.
_myadd
(
lin2
.
_metric
,
False
))
return
res
return
res
# def _simplify_for_constant_input_nontrivial(self, c_inp):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
# from ..multi_domain import MultiDomain
from
..multi_domain
import
MultiDomain
# from .simplify_for_const import ConstCollector
from
.simplify_for_const
import
ConstCollector
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
# f1, o1 = self._op1.simplify_for_constant_input(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
# c_inp.extract_part(self._op1.domain))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
# f2, o2 = self._op2.simplify_for_constant_input(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
# c_inp.extract_part(self._op2.domain))
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
# if not isinstance(self._target, MultiDomain):
return
None
,
_OpSum
(
o1
,
o2
)
# return None, _OpSum(o1, o2)
cc
=
ConstCollector
()
# cc = ConstCollector()
cc
.
add
(
f1
,
o1
.
target
)
# cc.add(f1, o1.target)
cc
.
add
(
f2
,
o2
.
target
)
# cc.add(f2, o2.target)
return
cc
.
constfield
,
_OpSum
(
o1
,
o2
)
# return cc.constfield, _OpSum(o1, o2)
def
__repr__
(
self
):
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
...
...
src/operators/sum_operator.py
View file @
092bf7fd
...
@@ -207,28 +207,28 @@ class SumOperator(LinearOperator):
...
@@ -207,28 +207,28 @@ class SumOperator(LinearOperator):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"SumOperator:
\n
"
+
indent
(
subs
)
return
"SumOperator:
\n
"
+
indent
(
subs
)
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
#
f = []
f
=
[]
#
o = []
o
=
[]
#
for op in self._ops:
for
op
in
self
.
_ops
:
#
tf, to = op.simplify_for_constant_input(
tf
,
to
=
op
.
simplify_for_constant_input
(
#
c_inp.extract_part(op.domain))
c_inp
.
extract_part
(
op
.
domain
))
#
f.append(tf)
f
.
append
(
tf
)
#
o.append(to)
o
.
append
(
to
)
#
from ..multi_domain import MultiDomain
from
..multi_domain
import
MultiDomain
#
if not isinstance(self._target, MultiDomain):
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
#
fullop = None
fullop
=
None
#
for to, n in zip(o, self._neg):
for
to
,
n
in
zip
(
o
,
self
.
_neg
):
#
op = to if not n else -to
op
=
to
if
not
n
else
-
to
#
fullop = op if fullop is None else fullop + op
fullop
=
op
if
fullop
is
None
else
fullop
+
op
#
return None, fullop
return
None
,
fullop
#
from .simplify_for_const import ConstCollector
from
.simplify_for_const
import
ConstCollector
#
cc = ConstCollector()
cc
=
ConstCollector
()
#
fullop = None
fullop
=
None
#
for tf, to, n in zip(f, o, self._neg):
for
tf
,
to
,
n
in
zip
(
f
,
o
,
self
.
_neg
):
#
cc.add(tf, to.target)
cc
.
add
(
tf
,
to
.
target
)
#
op = to if not n else -to
op
=
to
if
not
n
else
-
to
#
fullop = op if fullop is None else fullop + op
fullop
=
op
if
fullop
is
None
else
fullop
+
op
#
return cc.constfield, fullop
return
cc
.
constfield
,
fullop
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment