Commit 21024ed0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'check_inputs_for_operators' into 'NIFTy_5'

Check inputs for operators

See merge request ift/nifty-dev!104
parents b55598c9 5b54f5d3
...@@ -65,7 +65,7 @@ if __name__ == '__main__': ...@@ -65,7 +65,7 @@ if __name__ == '__main__':
# Compute likelihood and Hamiltonian # Compute likelihood and Hamiltonian
position = ift.from_random('normal', harmonic_space) position = ift.from_random('normal', harmonic_space)
likelihood = ift.BernoulliEnergy(p, data) likelihood = ift.BernoulliEnergy(data)(p)
ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=100, ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=100,
tol_rel_deltaE=1e-8) tol_rel_deltaE=1e-8)
minimizer = ift.NewtonCG(ic_newton) minimizer = ift.NewtonCG(ic_newton)
......
...@@ -86,9 +86,9 @@ if __name__ == '__main__': ...@@ -86,9 +86,9 @@ if __name__ == '__main__':
data = ift.Field.from_global_data(d_space, data) data = ift.Field.from_global_data(d_space, data)
# Compute likelihood and Hamiltonian # Compute likelihood and Hamiltonian
likelihood = ift.PoissonianEnergy(lamb, data)
ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=100, ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=100,
tol_rel_deltaE=1e-8) tol_rel_deltaE=1e-8)
likelihood = ift.PoissonianEnergy(data)(lamb)
minimizer = ift.NewtonCG(ic_newton) minimizer = ift.NewtonCG(ic_newton)
# Minimize the Hamiltonian # Minimize the Hamiltonian
......
...@@ -85,7 +85,7 @@ if __name__ == '__main__': ...@@ -85,7 +85,7 @@ if __name__ == '__main__':
plot = ift.Plot() plot = ift.Plot()
plot.add(signal(MOCK_POSITION), title='Ground Truth') plot.add(signal(MOCK_POSITION), title='Ground Truth')
plot.add(R.adjoint_times(data), title='Data') plot.add(R.adjoint_times(data), title='Data')
plot.add([A(MOCK_POSITION)], title='Power Spectrum') plot.add([A(MOCK_POSITION.extract(A.domain))], title='Power Spectrum')
plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png") plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png")
# number of samples used to estimate the KL # number of samples used to estimate the KL
...@@ -97,18 +97,24 @@ if __name__ == '__main__': ...@@ -97,18 +97,24 @@ if __name__ == '__main__':
plot = ift.Plot() plot = ift.Plot()
plot.add(signal(KL.position), title="reconstruction") plot.add(signal(KL.position), title="reconstruction")
plot.add([A(KL.position), A(MOCK_POSITION)], title="power") plot.add(
[
A(KL.position.extract(A.domain)),
A(MOCK_POSITION.extract(A.domain))
],
title="power")
plot.output(ny=1, ysize=6, xsize=16, name="loop.png") plot.output(ny=1, ysize=6, xsize=16, name="loop.png")
plot = ift.Plot() plot = ift.Plot()
sc = ift.StatCalculator() sc = ift.StatCalculator()
for sample in KL.samples: for sample in KL.samples:
sc.add(signal(sample+KL.position)) sc.add(signal(sample + KL.position))
plot.add(sc.mean, title="Posterior Mean") plot.add(sc.mean, title="Posterior Mean")
plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation") plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")
powers = [A(s+KL.position) for s in KL.samples] powers = [A((s + KL.position).extract(A.domain)) for s in KL.samples]
plot.add( plot.add(
[A(KL.position), A(MOCK_POSITION)]+powers, [A(KL.position.extract(A.domain)),
A(MOCK_POSITION.extract(A.domain))] + powers,
title="Sampled Posterior Power Spectrum") title="Sampled Posterior Power Spectrum")
plot.output(ny=1, nx=3, xsize=24, ysize=6, name="results.png") plot.output(ny=1, nx=3, xsize=24, ysize=6, name="results.png")
...@@ -66,4 +66,4 @@ def make_adjust_variances(a, xi, position, samples=[], scaling=None, ...@@ -66,4 +66,4 @@ def make_adjust_variances(a, xi, position, samples=[], scaling=None,
if scaling is not None: if scaling is not None:
x = ScalingOperator(scaling, x.target)(x) x = ScalingOperator(scaling, x.target)(x)
return Hamiltonian(InverseGammaLikelihood(x, d_eval), ic_samp=ic_samp) return Hamiltonian(InverseGammaLikelihood(d_eval)(x), ic_samp=ic_samp)
...@@ -132,6 +132,7 @@ class AmplitudeModel(Operator): ...@@ -132,6 +132,7 @@ class AmplitudeModel(Operator):
self._ceps = makeOp(sqrt(cepstrum)) self._ceps = makeOp(sqrt(cepstrum))
def apply(self, x): def apply(self, x):
self._check_input(x)
smooth_spec = self._smooth_op(x[self._keys[0]]) smooth_spec = self._smooth_op(x[self._keys[0]])
phi = x[self._keys[1]] + self._norm_phi_mean phi = x[self._keys[1]] + self._norm_phi_mean
linear_spec = self._slope(phi) linear_spec = self._slope(phi)
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
from scipy.stats import invgamma, norm from scipy.stats import invgamma, norm
from ..compat import * from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field from ..field import Field
from ..linearization import Linearization from ..linearization import Linearization
from ..operators.operator import Operator from ..operators.operator import Operator
...@@ -30,11 +31,12 @@ from ..sugar import makeOp ...@@ -30,11 +31,12 @@ from ..sugar import makeOp
class InverseGammaModel(Operator): class InverseGammaModel(Operator):
def __init__(self, domain, alpha, q): def __init__(self, domain, alpha, q):
self._domain = self._target = domain self._domain = self._target = DomainTuple.make(domain)
self._alpha = alpha self._alpha = alpha
self._q = q self._q = q
def apply(self, x): def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization) lin = isinstance(x, Linearization)
val = x.val.local_data if lin else x.local_data val = x.val.local_data if lin else x.local_data
# MR FIXME?! # MR FIXME?!
......
...@@ -28,6 +28,7 @@ from .operator import Operator ...@@ -28,6 +28,7 @@ from .operator import Operator
from .sampling_enabler import SamplingEnabler from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator from .sandwich_operator import SandwichOperator
from .simple_linear_operators import VdotOperator from .simple_linear_operators import VdotOperator
from ..sugar import makeDomain
class EnergyOperator(Operator): class EnergyOperator(Operator):
...@@ -39,6 +40,7 @@ class SquaredNormOperator(EnergyOperator): ...@@ -39,6 +40,7 @@ class SquaredNormOperator(EnergyOperator):
self._domain = domain self._domain = domain
def apply(self, x): def apply(self, x):
self._check_input(x)
if isinstance(x, Linearization): if isinstance(x, Linearization):
val = Field.scalar(x.val.vdot(x.val)) val = Field.scalar(x.val.vdot(x.val))
jac = VdotOperator(2*x.val)(x.jac) jac = VdotOperator(2*x.val)(x.jac)
...@@ -55,6 +57,7 @@ class QuadraticFormOperator(EnergyOperator): ...@@ -55,6 +57,7 @@ class QuadraticFormOperator(EnergyOperator):
self._domain = op.domain self._domain = op.domain
def apply(self, x): def apply(self, x):
self._check_input(x)
if isinstance(x, Linearization): if isinstance(x, Linearization):
t1 = self._op(x.val) t1 = self._op(x.val)
jac = VdotOperator(t1)(x.jac) jac = VdotOperator(t1)(x.jac)
...@@ -82,6 +85,7 @@ class GaussianEnergy(EnergyOperator): ...@@ -82,6 +85,7 @@ class GaussianEnergy(EnergyOperator):
self._icov = None if covariance is None else covariance.inverse self._icov = None if covariance is None else covariance.inverse
def _checkEquivalence(self, newdom): def _checkEquivalence(self, newdom):
newdom = makeDomain(newdom)
if self._domain is None: if self._domain is None:
self._domain = newdom self._domain = newdom
else: else:
...@@ -89,6 +93,7 @@ class GaussianEnergy(EnergyOperator): ...@@ -89,6 +93,7 @@ class GaussianEnergy(EnergyOperator):
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
def apply(self, x): def apply(self, x):
self._check_input(x)
residual = x if self._mean is None else x-self._mean residual = x if self._mean is None else x-self._mean
res = self._op(residual).real res = self._op(residual).real
if not isinstance(x, Linearization) or not x.want_metric: if not isinstance(x, Linearization) or not x.want_metric:
...@@ -98,12 +103,12 @@ class GaussianEnergy(EnergyOperator): ...@@ -98,12 +103,12 @@ class GaussianEnergy(EnergyOperator):
class PoissonianEnergy(EnergyOperator): class PoissonianEnergy(EnergyOperator):
def __init__(self, op, d): def __init__(self, d):
self._op, self._d = op, d self._d = d
self._domain = d.domain self._domain = DomainTuple.make(d.domain)
def apply(self, x): def apply(self, x):
x = self._op(x) self._check_input(x)
res = x.sum() - x.log().vdot(self._d) res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(res) return Field.scalar(res)
...@@ -114,12 +119,12 @@ class PoissonianEnergy(EnergyOperator): ...@@ -114,12 +119,12 @@ class PoissonianEnergy(EnergyOperator):
class InverseGammaLikelihood(EnergyOperator): class InverseGammaLikelihood(EnergyOperator):
def __init__(self, op, d): def __init__(self, d):
self._op, self._d = op, d self._d = d
self._domain = d.domain self._domain = DomainTuple.make(d.domain)
def apply(self, x): def apply(self, x):
x = self._op(x) self._check_input(x)
res = 0.5*(x.log().sum() + (1./x).vdot(self._d)) res = 0.5*(x.log().sum() + (1./x).vdot(self._d))
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(res) return Field.scalar(res)
...@@ -130,13 +135,12 @@ class InverseGammaLikelihood(EnergyOperator): ...@@ -130,13 +135,12 @@ class InverseGammaLikelihood(EnergyOperator):
class BernoulliEnergy(EnergyOperator): class BernoulliEnergy(EnergyOperator):
def __init__(self, p, d): def __init__(self, d):
self._p = p
self._d = d self._d = d
self._domain = d.domain self._domain = DomainTuple.make(d.domain)
def apply(self, x): def apply(self, x):
x = self._p(x) self._check_input(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(v) return Field.scalar(v)
...@@ -155,6 +159,7 @@ class Hamiltonian(EnergyOperator): ...@@ -155,6 +159,7 @@ class Hamiltonian(EnergyOperator):
self._domain = lh.domain self._domain = lh.domain
def apply(self, x): def apply(self, x):
self._check_input(x)
if (self._ic_samp is None or not isinstance(x, Linearization) or if (self._ic_samp is None or not isinstance(x, Linearization) or
not x.want_metric): not x.want_metric):
return self._lh(x)+self._prior(x) return self._lh(x)+self._prior(x)
...@@ -177,5 +182,6 @@ class SampledKullbachLeiblerDivergence(EnergyOperator): ...@@ -177,5 +182,6 @@ class SampledKullbachLeiblerDivergence(EnergyOperator):
self._res_samples = tuple(res_samples) self._res_samples = tuple(res_samples)
def apply(self, x): def apply(self, x):
self._check_input(x)
mymap = map(lambda v: self._h(x+v), self._res_samples) mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap) * (1./len(self._res_samples)) return utilities.my_sum(mymap) * (1./len(self._res_samples))
...@@ -59,6 +59,12 @@ class Operator(NiftyMetaBase()): ...@@ -59,6 +59,12 @@ class Operator(NiftyMetaBase()):
def apply(self, x): def apply(self, x):
raise NotImplementedError raise NotImplementedError
def _check_input(self, x):
from ..linearization import Linearization
d = x.target if isinstance(x, Linearization) else x.domain
if self._domain != d:
raise ValueError("The operator's and field's domains don't match.")
def __call__(self, x): def __call__(self, x):
if isinstance(x, Operator): if isinstance(x, Operator):
return _OpChain.make((self, x)) return _OpChain.make((self, x))
...@@ -84,6 +90,7 @@ class _FunctionApplier(Operator): ...@@ -84,6 +90,7 @@ class _FunctionApplier(Operator):
self._funcname = funcname self._funcname = funcname
def apply(self, x): def apply(self, x):
self._check_input(x)
return getattr(x, self._funcname)() return getattr(x, self._funcname)()
...@@ -120,6 +127,7 @@ class _OpChain(_CombinedOperator): ...@@ -120,6 +127,7 @@ class _OpChain(_CombinedOperator):
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
def apply(self, x): def apply(self, x):
self._check_input(x)
for op in reversed(self._ops): for op in reversed(self._ops):
x = op(x) x = op(x)
return x return x
...@@ -138,6 +146,7 @@ class _OpProd(Operator): ...@@ -138,6 +146,7 @@ class _OpProd(Operator):
def apply(self, x): def apply(self, x):
from ..linearization import Linearization from ..linearization import Linearization
from ..sugar import makeOp from ..sugar import makeOp
self._check_input(x)
lin = isinstance(x, Linearization) lin = isinstance(x, Linearization)
v = x._val if lin else x v = x._val if lin else x
v1 = v.extract(self._op1.domain) v1 = v.extract(self._op1.domain)
...@@ -162,6 +171,7 @@ class _OpSum(Operator): ...@@ -162,6 +171,7 @@ class _OpSum(Operator):
def apply(self, x): def apply(self, x):
from ..linearization import Linearization from ..linearization import Linearization
self._check_input(x)
lin = isinstance(x, Linearization) lin = isinstance(x, Linearization)
v = x._val if lin else x v = x._val if lin else x
v1 = v.extract(self._op1.domain) v1 = v.extract(self._op1.domain)
......
...@@ -64,9 +64,10 @@ class Energy_Tests(unittest.TestCase): ...@@ -64,9 +64,10 @@ class Energy_Tests(unittest.TestCase):
], [4, 78, 23])) ], [4, 78, 23]))
def testInverseGammaLikelihood(self, space, seed): def testInverseGammaLikelihood(self, space, seed):
model = self.make_model(space_key='s1', space=space, seed=seed)['s1'] model = self.make_model(space_key='s1', space=space, seed=seed)['s1']
model = model.exp()
d = np.random.normal(10, size=space.shape)**2 d = np.random.normal(10, size=space.shape)**2
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.InverseGammaLikelihood(ift.exp, d) energy = ift.InverseGammaLikelihood(d)
ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7) ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7)
@expand(product( @expand(product(
...@@ -78,9 +79,10 @@ class Energy_Tests(unittest.TestCase): ...@@ -78,9 +79,10 @@ class Energy_Tests(unittest.TestCase):
def testPoissonian(self, space, seed): def testPoissonian(self, space, seed):
model = self.make_model( model = self.make_model(
space_key='s1', space=space, seed=seed)['s1'] space_key='s1', space=space, seed=seed)['s1']
model = model.exp()
d = np.random.poisson(120, size=space.shape) d = np.random.poisson(120, size=space.shape)
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.PoissonianEnergy(ift.exp, d) energy = ift.PoissonianEnergy(d)
ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7) ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7)
@expand(product( @expand(product(
...@@ -113,5 +115,5 @@ class Energy_Tests(unittest.TestCase): ...@@ -113,5 +115,5 @@ class Energy_Tests(unittest.TestCase):
model = model.positive_tanh() model = model.positive_tanh()
d = np.random.binomial(1, 0.1, size=space.shape) d = np.random.binomial(1, 0.1, size=space.shape)
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.BernoulliEnergy(ift.positive_tanh, d) energy = ift.BernoulliEnergy(d)
ift.extra.check_value_gradient_consistency(energy, model, tol=2e-7) ift.extra.check_value_gradient_consistency(energy, model, tol=1e-6)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment