diff --git a/nifty5/library/amplitude_model.py b/nifty5/library/amplitude_model.py index 7277808793bf8b304e0b132ce753b14ed329d0e0..09e2aa08a63e50eab8a4dcd00aeca6400633d1df 100644 --- a/nifty5/library/amplitude_model.py +++ b/nifty5/library/amplitude_model.py @@ -132,6 +132,7 @@ class AmplitudeModel(Operator): self._ceps = makeOp(sqrt(cepstrum)) def apply(self, x): + self._check_input(x) smooth_spec = self._smooth_op(x[self._keys[0]]) phi = x[self._keys[1]] + self._norm_phi_mean linear_spec = self._slope(phi) diff --git a/nifty5/library/inverse_gamma_model.py b/nifty5/library/inverse_gamma_model.py index 831546453062e734b0e708dd33245e5fd13304e2..5d95c3d46c38cbc1296cbfe699bc6c20b5cabcae 100644 --- a/nifty5/library/inverse_gamma_model.py +++ b/nifty5/library/inverse_gamma_model.py @@ -22,6 +22,7 @@ import numpy as np from scipy.stats import invgamma, norm from ..compat import * +from ..domain_tuple import DomainTuple from ..field import Field from ..linearization import Linearization from ..operators.operator import Operator @@ -30,11 +31,12 @@ from ..sugar import makeOp class InverseGammaModel(Operator): def __init__(self, domain, alpha, q): - self._domain = self._target = domain + self._domain = self._target = DomainTuple.make(domain) self._alpha = alpha self._q = q def apply(self, x): + self._check_input(x) lin = isinstance(x, Linearization) val = x.val.local_data if lin else x.local_data # MR FIXME?! diff --git a/nifty5/operators/energy_operators.py b/nifty5/operators/energy_operators.py index dbde8b1db46cbd7da0f3ffff6a8427c5ebcc1c11..55ccdc95ed5cff2b7df29daa2131c7d82c7cb827 100644 --- a/nifty5/operators/energy_operators.py +++ b/nifty5/operators/energy_operators.py @@ -39,6 +39,7 @@ class SquaredNormOperator(EnergyOperator): self._domain = domain def apply(self, x): + self._check_input(x) if isinstance(x, Linearization): val = Field.scalar(x.val.vdot(x.val)) jac = VdotOperator(2*x.val)(x.jac) @@ -55,6 +56,7 @@ class QuadraticFormOperator(EnergyOperator): self._domain = op.domain def apply(self, x): + self._check_input(x) if isinstance(x, Linearization): t1 = self._op(x.val) jac = VdotOperator(t1)(x.jac) @@ -83,12 +85,13 @@ class GaussianEnergy(EnergyOperator): def _checkEquivalence(self, newdom): if self._domain is None: - self._domain = newdom + self._domain = DomainTuple.make(newdom) else: - if self._domain != newdom: + if self._domain != DomainTuple.make(newdom): raise ValueError("domain mismatch") def apply(self, x): + self._check_input(x) residual = x if self._mean is None else x-self._mean res = self._op(residual).real if not isinstance(x, Linearization) or not x.want_metric: @@ -103,6 +106,7 @@ class PoissonianEnergy(EnergyOperator): self._domain = d.domain def apply(self, x): + self._check_input(x) x = self._op(x) res = x.sum() - x.log().vdot(self._d) if not isinstance(x, Linearization): @@ -119,6 +123,7 @@ class InverseGammaLikelihood(EnergyOperator): self._domain = d.domain def apply(self, x): + self._check_input(x) x = self._op(x) res = 0.5*(x.log().sum() + (1./x).vdot(self._d)) if not isinstance(x, Linearization): @@ -136,6 +141,7 @@ class BernoulliEnergy(EnergyOperator): self._domain = d.domain def apply(self, x): + self._check_input(x) x = self._p(x) v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) if not isinstance(x, Linearization): @@ -155,6 +161,7 @@ class Hamiltonian(EnergyOperator): self._domain = lh.domain def apply(self, x): + self._check_input(x) if (self._ic_samp is None or not isinstance(x, Linearization) or not x.want_metric): return self._lh(x)+self._prior(x) @@ -177,5 +184,6 @@ class SampledKullbachLeiblerDivergence(EnergyOperator): self._res_samples = tuple(res_samples) def apply(self, x): + self._check_input(x) mymap = map(lambda v: self._h(x+v), self._res_samples) return utilities.my_sum(mymap) * (1./len(self._res_samples)) diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 13b7e6fb7bfb1e27f12e3f292744b239f868b8b5..4fc611a6adedb487fea853e8340129d3e57a59c2 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -59,6 +59,16 @@ class Operator(NiftyMetaBase()): def apply(self, x): raise NotImplementedError + def _check_input(self, x): + from ..linearization import Linearization + print('checkinput') + d = x.target if isinstance(x, Linearization) else x.domain + print(d) + print(self._domain) + print() + if self._domain != d: + raise ValueError("The operator's and field's domains don't match.") + def __call__(self, x): if isinstance(x, Operator): return _OpChain.make((self, x)) @@ -84,6 +94,7 @@ class _FunctionApplier(Operator): self._funcname = funcname def apply(self, x): + self._check_input(x) return getattr(x, self._funcname)() @@ -120,6 +131,7 @@ class _OpChain(_CombinedOperator): raise ValueError("domain mismatch") def apply(self, x): + self._check_input(x) for op in reversed(self._ops): x = op(x) return x @@ -138,6 +150,7 @@ class _OpProd(Operator): def apply(self, x): from ..linearization import Linearization from ..sugar import makeOp + self._check_input(x) lin = isinstance(x, Linearization) v = x._val if lin else x v1 = v.extract(self._op1.domain) @@ -162,6 +175,7 @@ class _OpSum(Operator): def apply(self, x): from ..linearization import Linearization + self._check_input(x) lin = isinstance(x, Linearization) v = x._val if lin else x v1 = v.extract(self._op1.domain)