Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
092bf7fd
Commit
092bf7fd
authored
Jun 20, 2020
by
Philipp Arras
Browse files
Implement proper constant support 6/n
parent
9dea1d88
Changes
5
Hide whitespace changes
Inline
Side-by-side
src/extra.py
View file @
092bf7fd
...
...
@@ -344,7 +344,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
assert
oplin
.
jac
.
target
is
oplin0
.
jac
.
target
rndinp
=
from_random
(
oplin
.
jac
.
target
)
assert_equal
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
oplin0
.
jac
.
adjoint
(
rndinp
))
assert_allclose
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
oplin0
.
jac
.
adjoint
(
rndinp
),
1e-13
,
1e-13
)
foo
=
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
cstloc
.
domain
)
assert_equal
(
foo
,
0
*
foo
)
...
...
@@ -352,7 +353,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
oplin
.
metric
.
draw_sample
()
assert
op0
.
domain
is
varloc
.
domain
_jac_vs_finite_differences
(
op0
,
varloc
,
np
.
sqrt
(
tol
),
ntries
,
only_r_differentiable
)
_jac_vs_finite_differences
(
op0
,
varloc
,
np
.
sqrt
(
tol
),
ntries
,
only_r_differentiable
)
def
_jac_vs_finite_differences
(
op
,
loc
,
tol
,
ntries
,
only_r_differentiable
):
...
...
src/operators/chain_operator.py
View file @
092bf7fd
...
...
@@ -138,13 +138,12 @@ class ChainOperator(LinearOperator):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"ChainOperator:
\n
"
+
utilities
.
indent
(
subs
)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# if not isinstance(self._domain, MultiDomain):
# return None, self
# newop = None
# for op in reversed(self._ops):
# c_inp, t_op = op.simplify_for_constant_input(c_inp)
# newop = t_op if newop is None else op(newop)
# return c_inp, newop
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
return
None
,
self
newop
=
None
for
op
in
reversed
(
self
.
_ops
):
c_inp
,
t_op
=
op
.
simplify_for_constant_input
(
c_inp
)
newop
=
t_op
if
newop
is
None
else
op
(
newop
)
return
c_inp
,
newop
src/operators/energy_operators.py
View file @
092bf7fd
...
...
@@ -175,26 +175,25 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
**
(
-
2
)})
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
))
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from .simplify_for_const import ConstantEnergyOperator
# assert len(c_inp.keys()) == 1
# key = c_inp.keys()[0]
# assert key in self._domain.keys()
# cst = c_inp[key]
# if key == self._kr:
# res = _SpecialGammaEnergy(cst).ducktape(self._ki)
# else:
# dt = self._dt[self._kr]
# res = GaussianEnergy(inverse_covariance=makeOp(cst),
# sampling_dtype=dt).ducktape(self._kr)
# trlog = cst.log().sum().val_rw()
# if not _iscomplex(dt):
# trlog /= 2
# res = res + ConstantEnergyOperator(res.domain, -trlog)
# res = res + ConstantEnergyOperator(self._domain, 0.)
# assert res.domain is self.domain
# assert res.target is self.target
# return None, res
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
ConstantEnergyOperator
assert
len
(
c_inp
.
keys
())
==
1
key
=
c_inp
.
keys
()[
0
]
assert
key
in
self
.
_domain
.
keys
()
cst
=
c_inp
[
key
]
if
key
==
self
.
_kr
:
res
=
_SpecialGammaEnergy
(
cst
).
ducktape
(
self
.
_ki
)
else
:
dt
=
self
.
_dt
[
self
.
_kr
]
res
=
GaussianEnergy
(
inverse_covariance
=
makeOp
(
cst
),
sampling_dtype
=
dt
).
ducktape
(
self
.
_kr
)
trlog
=
cst
.
log
().
sum
().
val_rw
()
if
not
_iscomplex
(
dt
):
trlog
/=
2
res
=
res
+
ConstantEnergyOperator
(
-
trlog
)
res
=
res
+
ConstantEnergyOperator
(
0.
)
assert
res
.
target
is
self
.
target
return
None
,
res
class
_SpecialGammaEnergy
(
EnergyOperator
):
...
...
src/operators/operator.py
View file @
092bf7fd
...
...
@@ -371,16 +371,15 @@ class _OpChain(_CombinedOperator):
x
=
op
(
x
)
return
x
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# if not isinstance(self._domain, MultiDomain):
# return None, self
# newop = None
# for op in reversed(self._ops):
# c_inp, t_op = op.simplify_for_constant_input(c_inp)
# newop = t_op if newop is None else op(newop)
# return c_inp, newop
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
return
None
,
self
newop
=
None
for
op
in
reversed
(
self
.
_ops
):
c_inp
,
t_op
=
op
.
simplify_for_constant_input
(
c_inp
)
newop
=
t_op
if
newop
is
None
else
op
(
newop
)
return
c_inp
,
newop
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
...
...
@@ -413,20 +412,19 @@ class _OpProd(Operator):
jac
=
(
makeOp
(
lin1
.
_val
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_val
)(
lin1
.
_jac
),
False
)
return
lin1
.
new
(
lin1
.
_val
*
lin2
.
_val
,
jac
)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# from .simplify_for_const import ConstCollector
# f1, o1 = self._op1.simplify_for_constant_input(
# c_inp.extract_part(self._op1.domain))
# f2, o2 = self._op2.simplify_for_constant_input(
# c_inp.extract_part(self._op2.domain))
# if not isinstance(self._target, MultiDomain):
# return None, _OpProd(o1, o2)
# cc = ConstCollector()
# cc.mult(f1, o1.target)
# cc.mult(f2, o2.target)
# return cc.constfield, _OpProd(o1, o2)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
from
.simplify_for_const
import
ConstCollector
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
return
None
,
_OpProd
(
o1
,
o2
)
cc
=
ConstCollector
()
cc
.
mult
(
f1
,
o1
.
target
)
cc
.
mult
(
f2
,
o2
.
target
)
return
cc
.
constfield
,
_OpProd
(
o1
,
o2
)
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
...
...
@@ -459,20 +457,19 @@ class _OpSum(Operator):
res
=
res
.
add_metric
(
lin1
.
_metric
.
_myadd
(
lin2
.
_metric
,
False
))
return
res
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# from .simplify_for_const import ConstCollector
# f1, o1 = self._op1.simplify_for_constant_input(
# c_inp.extract_part(self._op1.domain))
# f2, o2 = self._op2.simplify_for_constant_input(
# c_inp.extract_part(self._op2.domain))
# if not isinstance(self._target, MultiDomain):
# return None, _OpSum(o1, o2)
# cc = ConstCollector()
# cc.add(f1, o1.target)
# cc.add(f2, o2.target)
# return cc.constfield, _OpSum(o1, o2)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
from
.simplify_for_const
import
ConstCollector
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
return
None
,
_OpSum
(
o1
,
o2
)
cc
=
ConstCollector
()
cc
.
add
(
f1
,
o1
.
target
)
cc
.
add
(
f2
,
o2
.
target
)
return
cc
.
constfield
,
_OpSum
(
o1
,
o2
)
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
...
...
src/operators/sum_operator.py
View file @
092bf7fd
...
...
@@ -207,28 +207,28 @@ class SumOperator(LinearOperator):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"SumOperator:
\n
"
+
indent
(
subs
)
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
#
f = []
#
o = []
#
for op in self._ops:
#
tf, to = op.simplify_for_constant_input(
#
c_inp.extract_part(op.domain))
#
f.append(tf)
#
o.append(to)
#
from ..multi_domain import MultiDomain
#
if not isinstance(self._target, MultiDomain):
#
fullop = None
#
for to, n in zip(o, self._neg):
#
op = to if not n else -to
#
fullop = op if fullop is None else fullop + op
#
return None, fullop
#
from .simplify_for_const import ConstCollector
#
cc = ConstCollector()
#
fullop = None
#
for tf, to, n in zip(f, o, self._neg):
#
cc.add(tf, to.target)
#
op = to if not n else -to
#
fullop = op if fullop is None else fullop + op
#
return cc.constfield, fullop
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
f
=
[]
o
=
[]
for
op
in
self
.
_ops
:
tf
,
to
=
op
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
op
.
domain
))
f
.
append
(
tf
)
o
.
append
(
to
)
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
fullop
=
None
for
to
,
n
in
zip
(
o
,
self
.
_neg
):
op
=
to
if
not
n
else
-
to
fullop
=
op
if
fullop
is
None
else
fullop
+
op
return
None
,
fullop
from
.simplify_for_const
import
ConstCollector
cc
=
ConstCollector
()
fullop
=
None
for
tf
,
to
,
n
in
zip
(
f
,
o
,
self
.
_neg
):
cc
.
add
(
tf
,
to
.
target
)
op
=
to
if
not
n
else
-
to
fullop
=
op
if
fullop
is
None
else
fullop
+
op
return
cc
.
constfield
,
fullop
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment