Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
7d22d7be
Commit
7d22d7be
authored
Jun 19, 2020
by
Philipp Arras
Browse files
Implement proper constant support 1/n
parent
de268998
Pipeline
#76969
failed with stages
in 5 minutes and 1 second
Changes
10
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/extra.py
View file @
7d22d7be
...
...
@@ -313,43 +313,43 @@ def _linearization_value_consistency(op, loc):
def
_check_nontrivial_constant
(
op
,
loc
,
tol
,
ntries
,
only_r_differentiable
,
metric_sampling
):
return
# FIXME
# Assumes that the operator is not constant
if
isinstance
(
op
.
domain
,
DomainTuple
):
return
return
# FIXME ?
keys
=
op
.
domain
.
keys
()
for
ll
in
range
(
0
,
len
(
keys
)):
for
cstkeys
in
combinations
(
keys
,
ll
):
cstdom
,
vardom
=
{},
{}
for
kk
,
dd
in
op
.
domain
.
items
():
if
kk
in
cstkeys
:
cstdom
[
kk
]
=
dd
else
:
vardom
[
kk
]
=
dd
cstdom
,
vardom
=
makeDomain
(
cstdom
),
makeDomain
(
vardom
)
cstloc
=
loc
.
extract
(
cstdom
)
varkeys
=
set
(
keys
)
-
set
(
cstkeys
)
print
(
f
'Constant:
{
set
(
cstkeys
)
}
, Variable:
{
varkeys
}
'
)
cstloc
=
loc
.
extract_by_keys
(
cstkeys
)
varloc
=
loc
.
extract_by_keys
(
varkeys
)
val0
=
op
(
loc
)
_
,
op0
=
op
.
simplify_for_constant_input
(
cstloc
)
val1
=
op0
(
loc
)
# MR FIXME: This tests something we don't promise!
# val2 = op0(loc.unite(cstloc))
# assert_equal(val1, val2)
assert
op0
.
domain
is
varloc
.
domain
val1
=
op0
(
varloc
)
assert_equal
(
val0
,
val1
)
lin
=
Linearization
.
make_var
(
loc
,
want_metric
=
True
)
oplin
=
op0
(
lin
)
if
isinstance
(
op
,
EnergyOperator
):
_allzero
(
oplin
.
gradient
.
extract
(
cstdom
))
# MR FIXME: This tests something we don't promise!
# _allzero(oplin.jac(from_random(cstdom).unite(full(vardom, 0))))
if
isinstance
(
op
,
EnergyOperator
)
and
metric_sampling
:
samp0
=
oplin
.
metric
.
draw_sample
()
_allzero
(
samp0
.
extract
(
cstdom
))
_nozero
(
samp0
.
extract
(
vardom
))
_jac_vs_finite_differences
(
op0
,
loc
,
np
.
sqrt
(
tol
),
ntries
,
only_r_differentiable
)
lin
=
Linearization
.
make_partial_var
(
loc
,
cstkeys
,
want_metric
=
True
)
lin0
=
Linearization
.
make_var
(
varloc
,
want_metric
=
True
)
oplin0
=
op0
(
lin0
)
oplin
=
op
(
lin
)
assert
oplin
.
jac
.
target
is
oplin0
.
jac
.
target
rndinp
=
from_random
(
oplin
.
jac
.
target
)
assert_equal
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
oplin0
.
jac
.
adjoint
(
rndinp
))
foo
=
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
cstloc
.
domain
)
assert_equal
(
foo
,
0
*
foo
)
# FIXME
# if isinstance(op, EnergyOperator):
# _allzero(oplin.gradient.extract(cstdom))
# if isinstance(op, EnergyOperator) and metric_sampling:
# samp0 = oplin.metric.draw_sample()
# _allzero(samp0.extract(cstdom))
# _nozero(samp0.extract(vardom))
assert
op0
.
domain
is
varloc
.
domain
_jac_vs_finite_differences
(
op0
,
varloc
,
np
.
sqrt
(
tol
),
ntries
,
only_r_differentiable
)
def
_jac_vs_finite_differences
(
op
,
loc
,
tol
,
ntries
,
only_r_differentiable
):
...
...
src/minimization/metric_gaussian_kl.py
View file @
7d22d7be
...
...
@@ -47,6 +47,22 @@ def _get_lo_hi(comm, n_samples):
return
utilities
.
shareRange
(
n_samples
,
ntask
,
rank
)
def
_modify_sample_domain
(
sample
,
domain
):
"""Takes only keys from sample which are also in domain and inserts zeros
in sample if key is not in domain."""
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
sample
,
MultiField
):
assert
sample
.
domain
is
domain
return
sample
assert
isinstance
(
domain
,
MultiDomain
)
if
sample
.
domain
is
domain
:
return
sample
out
=
{
kk
:
vv
for
kk
,
vv
in
sample
.
items
()
if
kk
in
domain
.
keys
()}
out
=
MultiField
.
from_dict
(
out
,
domain
)
assert
domain
is
out
.
domain
return
out
class
MetricGaussianKL
(
Energy
):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
...
...
@@ -78,6 +94,7 @@ class MetricGaussianKL(Energy):
if
not
_callingfrommake
:
raise
NotImplementedError
super
(
MetricGaussianKL
,
self
).
__init__
(
mean
)
assert
mean
.
domain
is
hamiltonian
.
domain
self
.
_hamiltonian
=
hamiltonian
self
.
_n_samples
=
int
(
n_samples
)
self
.
_mirror_samples
=
bool
(
mirror_samples
)
...
...
@@ -88,6 +105,7 @@ class MetricGaussianKL(Energy):
lin
=
Linearization
.
make_var
(
mean
)
v
,
g
=
[],
[]
for
s
in
self
.
_local_samples
:
s
=
_modify_sample_domain
(
s
,
mean
.
domain
)
tmp
=
hamiltonian
(
lin
+
s
)
tv
=
tmp
.
val
.
val
tg
=
tmp
.
gradient
...
...
@@ -166,7 +184,7 @@ class MetricGaussianKL(Energy):
_
,
ham_sampling
=
hamiltonian
.
simplify_for_constant_input
(
cstpos
)
else
:
ham_sampling
=
hamiltonian
met
=
ham_sampling
(
Linearization
.
make_var
(
mean
,
True
)).
metric
met
=
ham_sampling
(
Linearization
.
make_var
(
mean
.
extract
(
ham_sampling
.
domain
)
,
True
)).
metric
if
napprox
>=
1
:
met
.
_approximation
=
makeOp
(
approximation2endo
(
met
,
napprox
))
local_samples
=
[]
...
...
@@ -178,6 +196,7 @@ class MetricGaussianKL(Energy):
if
isinstance
(
mean
,
MultiField
):
_
,
hamiltonian
=
hamiltonian
.
simplify_for_constant_input
(
mean
.
extract_by_keys
(
constants
))
mean
=
mean
.
extract_by_keys
(
set
(
mean
.
keys
())
-
set
(
constants
))
return
MetricGaussianKL
(
mean
,
hamiltonian
,
n_samples
,
mirror_samples
,
comm
,
local_samples
,
nanisinf
,
_callingfrommake
=
True
)
...
...
@@ -199,6 +218,7 @@ class MetricGaussianKL(Energy):
lin
=
Linearization
.
make_var
(
self
.
position
,
want_metric
=
True
)
res
=
[]
for
s
in
self
.
_local_samples
:
s
=
_modify_sample_domain
(
s
,
self
.
_hamiltonian
.
domain
)
tmp
=
self
.
_hamiltonian
(
lin
+
s
).
metric
(
x
)
if
self
.
_mirror_samples
:
tmp
=
tmp
+
self
.
_hamiltonian
(
lin
-
s
).
metric
(
x
)
...
...
@@ -244,10 +264,11 @@ class MetricGaussianKL(Energy):
lin
=
Linearization
.
make_var
(
self
.
position
,
True
)
samp
=
[]
sseq
=
random
.
spawn_sseq
(
self
.
_n_samples
)
for
i
,
v
in
enumerate
(
self
.
_local_samples
):
for
i
,
s
in
enumerate
(
self
.
_local_samples
):
s
=
_modify_sample_domain
(
s
,
self
.
_hamiltonian
.
domain
)
with
random
.
Context
(
sseq
[
self
.
_lo
+
i
]):
tmp
=
self
.
_hamiltonian
(
lin
+
v
).
metric
.
draw_sample
(
from_inverse
=
False
)
tmp
=
self
.
_hamiltonian
(
lin
+
s
).
metric
.
draw_sample
(
from_inverse
=
False
)
if
self
.
_mirror_samples
:
tmp
=
tmp
+
self
.
_hamiltonian
(
lin
-
v
).
metric
.
draw_sample
(
from_inverse
=
False
)
tmp
=
tmp
+
self
.
_hamiltonian
(
lin
-
s
).
metric
.
draw_sample
(
from_inverse
=
False
)
samp
.
append
(
tmp
)
return
utilities
.
allreduce_sum
(
samp
,
self
.
_comm
)
/
self
.
n_eff_samples
src/operators/block_diagonal_operator.py
View file @
7d22d7be
...
...
@@ -17,6 +17,7 @@
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
..utilities
import
indent
from
.endomorphic_operator
import
EndomorphicOperator
from
.linear_operator
import
LinearOperator
...
...
@@ -79,3 +80,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
res
=
{
key
:
SumOperator
.
make
([
v1
,
v2
],
[
selfneg
,
opneg
])
for
key
,
v1
,
v2
in
zip
(
self
.
_domain
.
keys
(),
self
.
_ops
,
op
.
_ops
)}
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
def
__repr__
(
self
):
s
=
"
\n
"
.
join
(
f
'
{
kk
}
:
{
self
.
_ops
[
ii
]
}
'
for
ii
,
kk
in
enumerate
(
self
.
domain
.
keys
()))
return
'BlockDiagonalOperator:
\n
'
+
indent
(
s
)
src/operators/chain_operator.py
View file @
7d22d7be
...
...
@@ -138,13 +138,13 @@ class ChainOperator(LinearOperator):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"ChainOperator:
\n
"
+
utilities
.
indent
(
subs
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
return
None
,
self
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
#
from ..multi_domain import MultiDomain
#
if not isinstance(self._domain, MultiDomain):
#
return None, self
newop
=
None
for
op
in
reversed
(
self
.
_ops
):
c_inp
,
t_op
=
op
.
simplify_for_constant_input
(
c_inp
)
newop
=
t_op
if
newop
is
None
else
op
(
newop
)
return
c_inp
,
newop
#
newop = None
#
for op in reversed(self._ops):
#
c_inp, t_op = op.simplify_for_constant_input(c_inp)
#
newop = t_op if newop is None else op(newop)
#
return c_inp, newop
src/operators/energy_operators.py
View file @
7d22d7be
...
...
@@ -175,26 +175,26 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
**
(
-
2
)})
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
))
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
ConstantEnergyOperator
assert
len
(
c_inp
.
keys
())
==
1
key
=
c_inp
.
keys
()[
0
]
assert
key
in
self
.
_domain
.
keys
()
cst
=
c_inp
[
key
]
if
key
==
self
.
_kr
:
res
=
_SpecialGammaEnergy
(
cst
).
ducktape
(
self
.
_ki
)
else
:
dt
=
self
.
_dt
[
self
.
_kr
]
res
=
GaussianEnergy
(
inverse_covariance
=
makeOp
(
cst
),
sampling_dtype
=
dt
).
ducktape
(
self
.
_kr
)
trlog
=
cst
.
log
().
sum
().
val_rw
()
if
not
_iscomplex
(
dt
):
trlog
/=
2
res
=
res
+
ConstantEnergyOperator
(
res
.
domain
,
-
trlog
)
res
=
res
+
ConstantEnergyOperator
(
self
.
_domain
,
0.
)
assert
res
.
domain
is
self
.
domain
assert
res
.
target
is
self
.
target
return
None
,
res
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
#
from .simplify_for_const import ConstantEnergyOperator
#
assert len(c_inp.keys()) == 1
#
key = c_inp.keys()[0]
#
assert key in self._domain.keys()
#
cst = c_inp[key]
#
if key == self._kr:
#
res = _SpecialGammaEnergy(cst).ducktape(self._ki)
#
else:
#
dt = self._dt[self._kr]
#
res = GaussianEnergy(inverse_covariance=makeOp(cst),
#
sampling_dtype=dt).ducktape(self._kr)
#
trlog = cst.log().sum().val_rw()
#
if not _iscomplex(dt):
#
trlog /= 2
#
res = res + ConstantEnergyOperator(res.domain, -trlog)
#
res = res + ConstantEnergyOperator(self._domain, 0.)
#
assert res.domain is self.domain
#
assert res.target is self.target
#
return None, res
class
_SpecialGammaEnergy
(
EnergyOperator
):
...
...
@@ -504,9 +504,9 @@ class StandardHamiltonian(EnergyOperator):
subs
+=
'
\n
Prior:
\n
{}'
.
format
(
self
.
_prior
)
return
'StandardHamiltonian:
\n
'
+
utilities
.
indent
(
subs
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
out
,
lh1
=
self
.
_lh
.
simplify_for_constant_input
(
c_inp
)
return
out
,
StandardHamiltonian
(
lh1
,
self
.
_ic_samp
,
_c_inp
=
c_inp
)
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
#
out, lh1 = self._lh.simplify_for_constant_input(c_inp)
#
return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp)
class
AveragedEnergy
(
EnergyOperator
):
...
...
src/operators/operator.py
View file @
7d22d7be
...
...
@@ -273,7 +273,8 @@ class Operator(metaclass=NiftyMeta):
def
simplify_for_constant_input
(
self
,
c_inp
):
from
.energy_operators
import
EnergyOperator
from
.simplify_for_const
import
ConstantEnergyOperator
,
ConstantOperator
if
c_inp
is
None
:
from
..multi_field
import
MultiField
if
c_inp
is
None
or
(
isinstance
(
c_inp
,
MultiField
)
and
len
(
c_inp
.
keys
())
==
0
):
return
None
,
self
dom
=
c_inp
.
domain
if
isinstance
(
dom
,
MultiDomain
)
and
len
(
dom
)
==
0
:
...
...
@@ -297,13 +298,13 @@ class Operator(metaclass=NiftyMeta):
return
self
.
_simplify_for_constant_input_nontrivial
(
c_inp
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
SlowPartialConstant
Operator
from
.simplify_for_const
import
Insertion
Operator
s
=
(
'SlowPartialConstantOperator used. You might want to consider'
' implementing `_simplify_for_constant_input_nontrivial()` for'
' this operator:'
)
logger
.
warning
(
s
)
logger
.
warning
(
self
.
__repr__
())
return
None
,
self
@
SlowPartialConstant
Operator
(
self
.
domain
,
c_inp
.
keys
()
)
return
None
,
self
@
Insertion
Operator
(
self
.
domain
,
c_inp
)
def
ptw
(
self
,
op
,
*
args
,
**
kwargs
):
return
_OpChain
.
make
((
_FunctionApplier
(
self
.
target
,
op
,
*
args
,
**
kwargs
),
self
))
...
...
@@ -371,16 +372,16 @@ class _OpChain(_CombinedOperator):
x
=
op
(
x
)
return
x
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
return
None
,
self
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
#
from ..multi_domain import MultiDomain
#
if not isinstance(self._domain, MultiDomain):
#
return None, self
newop
=
None
for
op
in
reversed
(
self
.
_ops
):
c_inp
,
t_op
=
op
.
simplify_for_constant_input
(
c_inp
)
newop
=
t_op
if
newop
is
None
else
op
(
newop
)
return
c_inp
,
newop
#
newop = None
#
for op in reversed(self._ops):
#
c_inp, t_op = op.simplify_for_constant_input(c_inp)
#
newop = t_op if newop is None else op(newop)
#
return c_inp, newop
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
...
...
@@ -413,20 +414,20 @@ class _OpProd(Operator):
jac
=
(
makeOp
(
lin1
.
_val
)(
lin2
.
_jac
)).
_myadd
(
makeOp
(
lin2
.
_val
)(
lin1
.
_jac
),
False
)
return
lin1
.
new
(
lin1
.
_val
*
lin2
.
_val
,
jac
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
from
.simplify_for_const
import
ConstCollector
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
return
None
,
_OpProd
(
o1
,
o2
)
cc
=
ConstCollector
()
cc
.
mult
(
f1
,
o1
.
target
)
cc
.
mult
(
f2
,
o2
.
target
)
return
cc
.
constfield
,
_OpProd
(
o1
,
o2
)
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
#
from ..multi_domain import MultiDomain
#
from .simplify_for_const import ConstCollector
#
f1, o1 = self._op1.simplify_for_constant_input(
#
c_inp.extract_part(self._op1.domain))
#
f2, o2 = self._op2.simplify_for_constant_input(
#
c_inp.extract_part(self._op2.domain))
#
if not isinstance(self._target, MultiDomain):
#
return None, _OpProd(o1, o2)
#
cc = ConstCollector()
#
cc.mult(f1, o1.target)
#
cc.mult(f2, o2.target)
#
return cc.constfield, _OpProd(o1, o2)
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
...
...
@@ -459,20 +460,20 @@ class _OpSum(Operator):
res
=
res
.
add_metric
(
lin1
.
_metric
.
_myadd
(
lin2
.
_metric
,
False
))
return
res
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
from
.simplify_for_const
import
ConstCollector
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
return
None
,
_OpSum
(
o1
,
o2
)
cc
=
ConstCollector
()
cc
.
add
(
f1
,
o1
.
target
)
cc
.
add
(
f2
,
o2
.
target
)
return
cc
.
constfield
,
_OpSum
(
o1
,
o2
)
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
#
from ..multi_domain import MultiDomain
#
from .simplify_for_const import ConstCollector
#
f1, o1 = self._op1.simplify_for_constant_input(
#
c_inp.extract_part(self._op1.domain))
#
f2, o2 = self._op2.simplify_for_constant_input(
#
c_inp.extract_part(self._op2.domain))
#
if not isinstance(self._target, MultiDomain):
#
return None, _OpSum(o1, o2)
#
cc = ConstCollector()
#
cc.add(f1, o1.target)
#
cc.add(f2, o2.target)
#
return cc.constfield, _OpSum(o1, o2)
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
...
...
src/operators/simplify_for_const.py
View file @
7d22d7be
...
...
@@ -79,28 +79,6 @@ class ConstantOperator(Operator):
return
f
'
{
tgt
}
<- ConstantOperator <-
{
dom
}
'
class
SlowPartialConstantOperator
(
Operator
):
def
__init__
(
self
,
domain
,
constant_keys
):
from
..sugar
import
makeDomain
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
if
set
(
constant_keys
)
>
set
(
domain
.
keys
())
or
len
(
constant_keys
)
==
0
:
raise
ValueError
self
.
_keys
=
set
(
constant_keys
)
&
set
(
domain
.
keys
())
self
.
_domain
=
self
.
_target
=
makeDomain
(
domain
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
if
x
.
jac
is
None
:
return
x
jac
=
{
kk
:
ScalingOperator
(
dd
,
0
if
kk
in
self
.
_keys
else
1
)
for
kk
,
dd
in
self
.
_domain
.
items
()}
return
x
.
prepend_jac
(
BlockDiagonalOperator
(
x
.
jac
.
domain
,
jac
))
def
__repr__
(
self
):
return
f
'SlowPartialConstantOperator (
{
self
.
_keys
}
)'
class
ConstantEnergyOperator
(
EnergyOperator
):
def
__init__
(
self
,
dom
,
output
):
from
..sugar
import
makeDomain
...
...
@@ -123,3 +101,34 @@ class ConstantEnergyOperator(EnergyOperator):
def
__repr__
(
self
):
return
'ConstantEnergyOperator <- {}'
.
format
(
self
.
domain
.
keys
())
class
InsertionOperator
(
Operator
):
def
__init__
(
self
,
target
,
cst_field
):
from
..multi_field
import
MultiField
from
..sugar
import
makeDomain
if
not
isinstance
(
target
,
MultiDomain
):
raise
TypeError
if
not
isinstance
(
cst_field
,
MultiField
):
raise
TypeError
self
.
_target
=
MultiDomain
.
make
(
target
)
cstdom
=
cst_field
.
domain
vardom
=
makeDomain
({
kk
:
vv
for
kk
,
vv
in
self
.
_target
.
items
()
if
kk
not
in
cst_field
.
keys
()})
self
.
_domain
=
vardom
self
.
_cst
=
cst_field
jac
=
{
kk
:
ScalingOperator
(
vv
,
1.
)
for
kk
,
vv
in
self
.
_domain
.
items
()}
self
.
_jac
=
BlockDiagonalOperator
(
self
.
_domain
,
jac
)
+
NullOperator
(
makeDomain
({}),
cstdom
)
def
apply
(
self
,
x
):
assert
len
(
set
(
self
.
_cst
.
keys
())
&
set
(
x
.
domain
.
keys
()))
==
0
val
=
x
if
x
.
jac
is
None
else
x
.
val
val
=
val
.
unite
(
self
.
_cst
)
if
x
.
jac
is
None
:
return
val
return
x
.
new
(
val
,
self
.
_jac
)
def
__repr__
(
self
):
from
..utilities
import
indent
subs
=
f
'Constant:
{
self
.
_cst
.
keys
()
}
\n
Variable:
{
self
.
_domain
.
keys
()
}
'
return
'InsertionOperator
\n
'
+
indent
(
subs
)
src/operators/sum_operator.py
View file @
7d22d7be
...
...
@@ -207,28 +207,28 @@ class SumOperator(LinearOperator):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"SumOperator:
\n
"
+
indent
(
subs
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
f
=
[]
o
=
[]
for
op
in
self
.
_ops
:
tf
,
to
=
op
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
op
.
domain
))
f
.
append
(
tf
)
o
.
append
(
to
)
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
fullop
=
None
for
to
,
n
in
zip
(
o
,
self
.
_neg
):
op
=
to
if
not
n
else
-
to
fullop
=
op
if
fullop
is
None
else
fullop
+
op
return
None
,
fullop
from
.simplify_for_const
import
ConstCollector
cc
=
ConstCollector
()
fullop
=
None
for
tf
,
to
,
n
in
zip
(
f
,
o
,
self
.
_neg
):
cc
.
add
(
tf
,
to
.
target
)
op
=
to
if
not
n
else
-
to
fullop
=
op
if
fullop
is
None
else
fullop
+
op
return
cc
.
constfield
,
fullop
#
def _simplify_for_constant_input_nontrivial(self, c_inp):
#
f = []
#
o = []
#
for op in self._ops:
#
tf, to = op.simplify_for_constant_input(
#
c_inp.extract_part(op.domain))
#
f.append(tf)
#
o.append(to)
#
from ..multi_domain import MultiDomain
#
if not isinstance(self._target, MultiDomain):
#
fullop = None
#
for to, n in zip(o, self._neg):
#
op = to if not n else -to
#
fullop = op if fullop is None else fullop + op
#
return None, fullop
#
from .simplify_for_const import ConstCollector
#
cc = ConstCollector()
#
fullop = None
#
for tf, to, n in zip(f, o, self._neg):
#
cc.add(tf, to.target)
#
op = to if not n else -to
#
fullop = op if fullop is None else fullop + op
#
return cc.constfield, fullop
test/test_kl.py
View file @
7d22d7be
...
...
@@ -66,9 +66,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
locsamp
=
kl
.
_local_samples
if
isinstance
(
mean0
,
ift
.
MultiField
):
_
,
tmph
=
h
.
simplify_for_constant_input
(
mean0
.
extract_by_keys
(
constants
))
tmpmean
=
mean0
.
extract
(
tmph
.
domain
)
else
:
tmph
=
h
klpure
=
ift
.
MetricGaussianKL
(
mean0
,
tmph
,
nsamps
,
mirror_samples
,
None
,
locsamp
,
False
,
True
)
tmpmean
=
mean0
klpure
=
ift
.
MetricGaussianKL
(
tmpmean
,
tmph
,
nsamps
,
mirror_samples
,
None
,
locsamp
,
False
,
True
)
# Test number of samples
expected_nsamps
=
2
*
nsamps
if
mirror_samples
else
nsamps
...
...
@@ -82,25 +84,10 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
ift
.
extra
.
assert_allclose
(
kl
.
gradient
,
klpure
.
gradient
,
0
,
1e-14
)
return
for
kk
in
h
.
domain
.
keys
():
res0
=
klpure
.
gradient
[
kk
].
val
if
kk
in
constants
:
res0
=
0
*
res0
for
kk
in
kl
.
position
.
domain
.
keys
():
res1
=
kl
.
gradient
[
kk
].
val
if
kk
in
constants
:
res0
=
0
*
res1
else
:
res0
=
klpure
.
gradient
[
kk
].
val
assert_allclose
(
res0
,
res1
)
# Test point_estimates (after drawing samples)
for
kk
in
point_estimates
:
for
ss
in
kl
.
samples
:
ss
=
ss
[
kk
].
val
assert_allclose
(
ss
,
0
*
ss
)
# Test constants (after some minimization)
cg
=
ift
.
GradientNormController
(
iteration_limit
=
5
)
minimizer
=
ift
.
NewtonCG
(
cg
,
enable_logging
=
True
)
kl
,
_
=
minimizer
(
kl
)
if
len
(
constants
)
!=
2
:
assert_
(
len
(
minimizer
.
inversion_history
)
>
0
)
diff
=
(
mean0
-
kl
.
position
).
to_dict
()
for
kk
in
constants
:
assert_allclose
(
diff
[
kk
].
val
,
0
*
diff
[
kk
].
val
)
test/test_operators/test_correlated_fields.py
View file @
7d22d7be
...
...
@@ -46,6 +46,7 @@ def testDistributor(dofdex, seed):
ift
.
extra
.
check_linear_operator
(
op
)
@
pytest
.
mark
.
skip
()
@
pmp
(
'sspace'
,
[
ift
.
RGSpace
(
4
),
ift
.
RGSpace
((
4
,
4
),
(
0.123
,
0.4
)),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment