Commit 1200b36a authored by Martin Reinecke

Merge branch 'various_pa' into 'NIFTy_7'

Support complex data in `VariableCovarianceGaussianEnergy` and use simplify for constant input for KL and `EnergyAdapter`

See merge request !509
parents 6c25c285 8f95edad
Pipeline #75766 passed with stages
in 13 minutes and 34 seconds
......@@ -11,12 +11,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..linearization import Linearization
from ..minimization.energy import Energy
from ..sugar import makeDomain
class EnergyAdapter(Energy):
......@@ -40,12 +41,23 @@ class EnergyAdapter(Energy):
additional resources. Default: False.
"""
def __init__(self, position, op, constants=[], want_metric=False):
def __init__(self, position, op, constants=[], want_metric=False,
_op4eval=None):
super(EnergyAdapter, self).__init__(position)
self._op = op
self._constants = constants
self._op4eval = _op4eval
if self._op4eval is None:
if len(constants) > 0:
dom = {kk: vv for kk, vv in position.domain.items()
if kk in constants}
dom = makeDomain(dom)
cstpos = position.extract(dom)
_, self._op4eval = op.simplify_for_constant_input(cstpos)
else:
self._op4eval = op
self._want_metric = want_metric
lin = Linearization.make_partial_var(position, constants, want_metric)
lin = Linearization.make_var(position, want_metric)
tmp = self._op(lin)
self._val = tmp.val.val[()]
self._grad = tmp.gradient
......@@ -53,7 +65,7 @@ class EnergyAdapter(Energy):
def at(self, position):
return EnergyAdapter(position, self._op, self._constants,
self._want_metric)
self._want_metric, self._op4eval)
@property
def value(self):
......
......@@ -17,14 +17,14 @@
import numpy as np
from .. import random, utilities
from ..field import Field
from .. import random
from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import full, makeOp
from ..sugar import makeDomain, makeOp
from .energy import Energy
......@@ -125,13 +125,19 @@ class MetricGaussianKL(Energy):
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._constants = tuple(constants)
self._point_estimates = tuple(point_estimates)
self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool):
raise TypeError
if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
raise RuntimeError(
'Point estimates for whole domain. Use EnergyAdapter instead.')
self._hamiltonian = hamiltonian
if len(constants) > 0:
dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
self._n_samples = int(n_samples)
if comm is not None:
......@@ -149,8 +155,13 @@ class MetricGaussianKL(Energy):
self._n_eff_samples *= 2
if _local_samples is None:
met = hamiltonian(Linearization.make_partial_var(
mean, self._point_estimates, True)).metric
if len(point_estimates) > 0:
dom = {kk: vv for kk, vv in mean.domain.items()
if kk in point_estimates}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
met = hamiltonian(Linearization.make_var(mean, True)).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_local_samples = []
......@@ -163,7 +174,7 @@ class MetricGaussianKL(Energy):
if len(_local_samples) != self._hi-self._lo:
raise ValueError("# of samples mismatch")
self._local_samples = _local_samples
self._lin = Linearization.make_partial_var(mean, self._constants)
self._lin = Linearization.make_var(mean)
v, g = [], []
for s in self._local_samples:
tmp = self._hamiltonian(self._lin+s)
......@@ -176,14 +187,14 @@ class MetricGaussianKL(Energy):
v.append(tv)
g.append(tg)
self._val = self._sumup(v)[()]/self._n_eff_samples
if np.isnan(self._val) and self._mitigate_nans:
if self._mitigate_nans and np.isnan(self._val):
self._val = np.inf
self._grad = self._sumup(g)/self._n_eff_samples
def at(self, position):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, comm=self._comm,
position, self._hamiltonian, self._n_samples,
mirror_samples=self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples, nanisinf=self._mitigate_nans)
@property
......@@ -287,6 +298,10 @@ class MetricGaussianKL(Energy):
def _metric_sample(self, from_inverse=False):
if from_inverse:
raise NotImplementedError()
s = ('This draws from the Hamiltonian used for evaluation and does '
' not take point_estimates into accout. Make sure that this '
'is your intended use.')
logger.warning(s)
lin = self._lin.with_want_metric()
samp = []
sseq = random.spawn_sseq(self._n_samples)
......
......@@ -251,8 +251,10 @@ class MultiField(Operator):
def extract_part(self, subset):
if subset is self._domain:
return self
return MultiField.from_dict({key: self[key] for key in subset.keys()
if key in self})
dct = {key: self[key] for key in subset.keys() if key in self}
if len(dct) == 0:
return None
return MultiField.from_dict(dct)
def unite(self, other):
"""Merges two MultiFields on potentially different MultiDomains.
......
......@@ -150,21 +150,21 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
"""
def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype):
self._r = str(residual_key)
self._icov = str(inverse_covariance_key)
self._kr = str(residual_key)
self._ki = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
self._sampling_dtype = sampling_dtype
self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
self._dt = sampling_dtype
_check_sampling_dtype(self._domain, sampling_dtype)
def apply(self, x):
self._check_input(x)
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].ptw("log").sum())
r, i = x[self._kr], x[self._ki]
res = 0.5*(r.vdot(r*i.real).real - i.ptw("log").sum())
if not x.want_metric:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
met = makeOp(MultiField.from_dict(mf))
return res.add_metric(SamplingDtypeSetter(met, self._sampling_dtype))
met = MultiField.from_dict({self._kr: i.val, self._ki: .5*i.val**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
class GaussianEnergy(EnergyOperator):
......@@ -223,9 +223,11 @@ class GaussianEnergy(EnergyOperator):
if inverse_covariance is None:
self._op = Squared2NormOperator(self._domain).scale(0.5)
self._met = ScalingOperator(self._domain, 1)
self._trivial_invcov = True
else:
self._op = QuadraticFormOperator(inverse_covariance)
self._met = inverse_covariance
self._trivial_invcov = False
if sampling_dtype is not None:
self._met = SamplingDtypeSetter(self._met, sampling_dtype)
......@@ -245,6 +247,10 @@ class GaussianEnergy(EnergyOperator):
return res.add_metric(self._met)
return res
def __repr__(self):
    """Return 'GaussianEnergy ' followed by the keys of the domain.

    For a `DomainTuple` domain the placeholder '()' is used in place of
    the keys.
    """
    if isinstance(self.domain, DomainTuple):
        label = '()'
    else:
        label = self.domain.keys()
    return 'GaussianEnergy {}'.format(label)
class PoissonianEnergy(EnergyOperator):
"""Computes likelihood Hamiltonians of expected count field constrained by
......
......@@ -11,13 +11,15 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from .. import pointwise
from ..logger import logger
from ..multi_domain import MultiDomain
from ..utilities import NiftyMeta, indent
......@@ -269,15 +271,35 @@ class Operator(metaclass=NiftyMeta):
return self.__class__.__name__
def simplify_for_constant_input(self, c_inp):
from .energy_operators import EnergyOperator
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
if isinstance(self.domain, MultiDomain):
assert isinstance(c_inp.domain, MultiDomain)
if set(c_inp.keys()) > set(self.domain.keys()):
raise ValueError
if c_inp.domain is self.domain:
if isinstance(self, EnergyOperator):
op = ConstantEnergyOperator(self.domain, self(c_inp))
else:
op = ConstantOperator(self.domain, self(c_inp))
op = ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
if not isinstance(c_inp.domain, MultiDomain):
raise RuntimeError
return self._simplify_for_constant_input_nontrivial(c_inp)
def _simplify_for_constant_input_nontrivial(self, c_inp):
return None, self
from .simplify_for_const import SlowPartialConstantOperator
s = ('SlowPartialConstantOperator used. You might want to consider'
' implementing `_simplify_for_constant_input_nontrivial()` for'
' this operator:')
logger.warning(s)
logger.warning(self.__repr__())
return None, self @ SlowPartialConstantOperator(self.domain, c_inp.keys())
def ptw(self, op, *args, **kwargs):
return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
......@@ -291,65 +313,6 @@ for f in pointwise.ptw_dict.keys():
setattr(Operator, f, func(f))
class _ConstCollector(object):
def __init__(self):
self._const = None
self._nc = set()
def mult(self, const, fulldom):
if const is None:
self._nc |= set(fulldom)
else:
self._nc |= set(fulldom) - set(const)
if self._const is None:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: const[key] for key in const if key not in self._nc})
else:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: self._const[key]*const[key]
for key in const if key not in self._nc})
def add(self, const, fulldom):
if const is None:
self._nc |= set(fulldom.keys())
else:
from ..multi_field import MultiField
self._nc |= set(fulldom.keys()) - set(const.keys())
if self._const is None:
self._const = MultiField.from_dict(
{key: const[key]
for key in const.keys() if key not in self._nc})
else:
self._const = self._const.unite(const)
self._const = MultiField.from_dict(
{key: self._const[key]
for key in self._const if key not in self._nc})
@property
def constfield(self):
return self._const
class _ConstantOperator(Operator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
self._target = output.domain
self._output = output
def apply(self, x):
from .simple_linear_operators import NullOperator
self._check_input(x)
if x.jac is not None:
return x.new(self._output, NullOperator(self._domain, self._target))
return self._output
def __repr__(self):
return 'ConstantOperator <- {}'.format(self.domain.keys())
class _FunctionApplier(Operator):
def __init__(self, domain, funcname, *args, **kwargs):
from ..sugar import makeDomain
......@@ -444,16 +407,16 @@ class _OpProd(Operator):
return lin1.new(lin1._val*lin2._val, jac)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpProd(o1, o2)
cc = _ConstCollector()
cc = ConstCollector()
cc.mult(f1, o1.target)
cc.mult(f2, o2.target)
return cc.constfield, _OpProd(o1, o2)
......@@ -490,16 +453,16 @@ class _OpSum(Operator):
return res
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
return None, _OpSum(o1, o2)
cc = _ConstCollector()
cc = ConstCollector()
cc.add(f1, o1.target)
cc.add(f2, o2.target)
return cc.constfield, _OpSum(o1, o2)
......
......@@ -349,6 +349,11 @@ class NullOperator(LinearOperator):
self._check_input(x, mode)
return self._nullfield(self._tgt(mode))
def __repr__(self):
    """Return '<target keys> <- NullOperator <- <domain keys>'.

    A side that is not a `MultiDomain` is rendered as '()'.
    """
    def _label(space):
        # Only MultiDomains are asked for their keys.
        if isinstance(space, MultiDomain):
            return space.keys()
        return '()'

    source = _label(self.domain)
    dest = _label(self.target)
    return '{} <- NullOperator <- {}'.format(dest, source)
class PartialExtractor(LinearOperator):
def __init__(self, domain, target):
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..multi_domain import MultiDomain
from .block_diagonal_operator import BlockDiagonalOperator
from .energy_operators import EnergyOperator
from .operator import Operator
from .scaling_operator import ScalingOperator
from .simple_linear_operators import NullOperator
class ConstCollector(object):
    """Accumulate the constant part of the output of a sum or a product.

    Every operand is registered through `add` or `mult` together with its
    full target domain. Keys of that domain for which the operand supplies
    no constant value are remembered as non-constant and are left out of the
    field built by later registrations.
    """

    def __init__(self):
        # Constant values gathered so far; stays `None` until the first
        # operand with a non-`None` constant has been registered.
        self._const = None
        # Keys that are known not to be constant.
        self._nc = set()

    def mult(self, const, fulldom):
        """Register one factor of a product.

        Parameters
        ----------
        const : keyed field or None
            Constant output of the factor (must offer `keys()` and item
            access), or `None` if the factor has no constant output.
        fulldom : keyed domain
            Full target domain of the factor; only `keys()` is used.
        """
        if const is None:
            # NOTE(review): keys already stored in `_const` are not removed
            # here, so the result depends on the registration order --
            # confirm this is intended.
            self._nc |= set(fulldom.keys())
            return
        from ..multi_field import MultiField
        # Use `.keys()` explicitly (as `add` does) instead of iterating the
        # containers directly.
        self._nc |= set(fulldom.keys()) - set(const.keys())
        if self._const is None:
            dct = {key: const[key]
                   for key in const.keys() if key not in self._nc}
        else:
            # Combine key by key with the product gathered so far.
            dct = {key: self._const[key]*const[key]
                   for key in const.keys() if key not in self._nc}
        self._const = MultiField.from_dict(dct)

    def add(self, const, fulldom):
        """Register one summand of a sum.

        Parameters
        ----------
        const : keyed field or None
            Constant output of the summand (must offer `keys()` and item
            access), or `None` if the summand has no constant output.
        fulldom : keyed domain
            Full target domain of the summand; only `keys()` is used.
        """
        if const is None:
            # NOTE(review): see `mult` -- `_const` is not re-filtered here.
            self._nc |= set(fulldom.keys())
            return
        from ..multi_field import MultiField
        self._nc |= set(fulldom.keys()) - set(const.keys())
        if self._const is None:
            merged = const
        else:
            merged = self._const.unite(const)
        # Drop every key that has been marked as non-constant.
        self._const = MultiField.from_dict(
            {key: merged[key]
             for key in merged.keys() if key not in self._nc})

    @property
    def constfield(self):
        """The constant field gathered so far, or `None` if there is none."""
        return self._const
class ConstantOperator(Operator):
    """Operator that returns a fixed, stored output for every input.

    Parameters
    ----------
    dom : domain-like
        Passed through `makeDomain` and used as the domain of the operator.
    output : object with a `domain` attribute
        Value returned on every application; its domain becomes the target.
    """

    def __init__(self, dom, output):
        from ..sugar import makeDomain
        self._domain = makeDomain(dom)
        self._target, self._output = output.domain, output

    def apply(self, x):
        """Return the stored output, paired with a `NullOperator` as
        Jacobian when the input carries a Jacobian."""
        from .simple_linear_operators import NullOperator
        self._check_input(x)
        if x.jac is None:
            # No Jacobian requested: hand back the stored value as is.
            return self._output
        null_jac = NullOperator(self._domain, self._target)
        return x.new(self._output, null_jac)

    def __repr__(self):
        def _label(space):
            # Only MultiDomains are asked for their keys.
            if isinstance(space, MultiDomain):
                return space.keys()
            return '()'

        source = _label(self.domain)
        dest = _label(self.target)
        return '{} <- ConstantOperator <- {}'.format(dest, source)
class SlowPartialConstantOperator(Operator):
    """Pass the input through unchanged, scaling the Jacobian blocks of the
    selected keys by 0 and all other blocks by 1.

    Parameters
    ----------
    domain : MultiDomain
        Domain and target of the operator.
    constant_keys : sized collection of keys
        Keys whose Jacobian blocks are scaled by 0.

    Raises
    ------
    TypeError
        If `domain` is not a `MultiDomain`.
    ValueError
        If `constant_keys` is empty or a proper superset of the domain keys.
    """

    def __init__(self, domain, constant_keys):
        from ..sugar import makeDomain
        if not isinstance(domain, MultiDomain):
            raise TypeError
        requested = set(constant_keys)
        available = set(domain.keys())
        # NOTE(review): `>` only rejects proper supersets; key sets that
        # merely overlap the domain pass and are intersected below --
        # confirm this is the intended contract.
        if requested > available or len(constant_keys) == 0:
            raise ValueError
        self._keys = requested & available
        dom = makeDomain(domain)
        self._domain = dom
        self._target = dom

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            # Nothing to mask when no Jacobian is attached.
            return x
        blocks = {}
        for key, subdom in self._domain.items():
            factor = 0 if key in self._keys else 1
            blocks[key] = ScalingOperator(subdom, factor)
        mask = BlockDiagonalOperator(x.jac.domain, blocks)
        return x.prepend_jac(mask)

    def __repr__(self):
        return 'SlowPartialConstantOperator ({})'.format(self._keys)
class ConstantEnergyOperator(EnergyOperator):
    """Energy operator that returns a fixed, stored value for every input.

    Parameters
    ----------
    dom : domain-like
        Passed through `makeDomain` and used as the domain of the operator.
    output : object with a `domain` attribute
        Value returned on every application. Its domain must be the very
        same object as the operator's target, which this class does not set
        itself (presumably provided by `EnergyOperator`).

    Raises
    ------
    TypeError
        If `output.domain` is not the operator's target.
    """

    def __init__(self, dom, output):
        from ..sugar import makeDomain
        self._domain = makeDomain(dom)
        if self.target is not output.domain:
            raise TypeError
        self._output = output

    def apply(self, x):
        """Return the stored value; when the input carries a Jacobian, attach
        a `NullOperator` Jacobian and, if a metric is wanted, a
        `NullOperator` metric."""
        self._check_input(x)
        if x.jac is None:
            return self._output
        jac = NullOperator(self._domain, self._target)
        met = NullOperator(self._domain, self._domain) if x.want_metric else None
        return x.new(self._output, jac, met)

    def __repr__(self):
        # Only MultiDomains are asked for their keys, as in the other
        # `__repr__` implementations; anything else is rendered as '()'.
        dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
        return 'ConstantEnergyOperator <- {}'.format(dom)
......@@ -224,8 +224,8 @@ class SumOperator(LinearOperator):
fullop = op if fullop is None else fullop + op
return None, fullop
from .operator import _ConstCollector
cc = _ConstCollector()
from .simplify_for_const import ConstCollector
cc = ConstCollector()
fullop = None
for tf, to, n in zip(f, o, self._neg):
cc.add(tf, to.target)
......
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from numpy.testing import assert_, assert_allclose
from numpy.testing import assert_, assert_allclose, assert_raises
import nifty7 as ift
......@@ -44,13 +44,17 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
mean0 = ift.from_random(h.domain, 'normal')
nsamps = 2
kl = ift.MetricGaussianKL(mean0,
h,
nsamps,
constants=constants,
point_estimates=point_estimates,
mirror_samples=mirror_samples,
napprox=0)
args = {'constants': constants,
'point_estimates': point_estimates,
'mirror_samples': mirror_samples,
'n_samples': nsamps,
'mean': mean0,
'hamiltonian': h}
if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
with assert_raises(RuntimeError):
ift.MetricGaussianKL(**args)
return
kl = ift.MetricGaussianKL(**args)
assert_(len(ic.history) > 0)
assert_(len(ic.history) == len(ic.history.time_stamps))
assert_(len(ic.history) == len(ic.history.energy_values))
......@@ -64,7 +68,8 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
h,
nsamps,
mirror_samples=mirror_samples,
napprox=0,
constants=constants,
point_estimates=point_estimates,
_local_samples=locsamp)
# Test number of samples
......
......@@ -18,7 +18,7 @@
import numpy as np
import pytest
from mpi4py import MPI
from numpy.testing import assert_, assert_equal
from numpy.testing import assert_, assert_equal, assert_raises
import nifty7 as ift
......@@ -58,6 +58,10 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
'n_samples': 2,
'mean': mean0,
'hamiltonian': h}
if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
with assert_raises(RuntimeError):
ift.MetricGaussianKL(**args, comm=comm)
return
if mode == 0:
kl0 = ift.MetricGaussianKL(**args, comm=comm)
locsamp = kl0._local_samples
......
......@@ -22,18 +22,18 @@ from ..common import setup_function, teardown_function
def test_simplification():
from nifty7.operators.operator import _ConstantOperator
from nifty7.operators.simplify_for_const import ConstantOperator
f1 = ift.Field.full(ift.RGSpace(10), 2.)
op = ift.FFTOperator(f1.domain)