Commit 6a5ec66d authored by Philipp Arras's avatar Philipp Arras

Add domain checks for operators

parent b55598c9
......@@ -132,6 +132,7 @@ class AmplitudeModel(Operator):
self._ceps = makeOp(sqrt(cepstrum))
def apply(self, x):
self._check_input(x)
smooth_spec = self._smooth_op(x[self._keys[0]])
phi = x[self._keys[1]] + self._norm_phi_mean
linear_spec = self._slope(phi)
......
......@@ -22,6 +22,7 @@ import numpy as np
from scipy.stats import invgamma, norm
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..operators.operator import Operator
......@@ -30,11 +31,12 @@ from ..sugar import makeOp
class InverseGammaModel(Operator):
def __init__(self, domain, alpha, q):
self._domain = self._target = domain
self._domain = self._target = DomainTuple.make(domain)
self._alpha = alpha
self._q = q
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
val = x.val.local_data if lin else x.local_data
# MR FIXME?!
......
......@@ -39,6 +39,7 @@ class SquaredNormOperator(EnergyOperator):
self._domain = domain
def apply(self, x):
self._check_input(x)
if isinstance(x, Linearization):
val = Field.scalar(x.val.vdot(x.val))
jac = VdotOperator(2*x.val)(x.jac)
......@@ -55,6 +56,7 @@ class QuadraticFormOperator(EnergyOperator):
self._domain = op.domain
def apply(self, x):
self._check_input(x)
if isinstance(x, Linearization):
t1 = self._op(x.val)
jac = VdotOperator(t1)(x.jac)
......@@ -83,12 +85,13 @@ class GaussianEnergy(EnergyOperator):
def _checkEquivalence(self, newdom):
if self._domain is None:
self._domain = newdom
self._domain = DomainTuple.make(newdom)
else:
if self._domain != newdom:
if self._domain != DomainTuple.make(newdom):
raise ValueError("domain mismatch")
def apply(self, x):
self._check_input(x)
residual = x if self._mean is None else x-self._mean
res = self._op(residual).real
if not isinstance(x, Linearization) or not x.want_metric:
......@@ -103,6 +106,7 @@ class PoissonianEnergy(EnergyOperator):
self._domain = d.domain
def apply(self, x):
self._check_input(x)
x = self._op(x)
res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization):
......@@ -119,6 +123,7 @@ class InverseGammaLikelihood(EnergyOperator):
self._domain = d.domain
def apply(self, x):
self._check_input(x)
x = self._op(x)
res = 0.5*(x.log().sum() + (1./x).vdot(self._d))
if not isinstance(x, Linearization):
......@@ -136,6 +141,7 @@ class BernoulliEnergy(EnergyOperator):
self._domain = d.domain
def apply(self, x):
self._check_input(x)
x = self._p(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization):
......@@ -155,6 +161,7 @@ class Hamiltonian(EnergyOperator):
self._domain = lh.domain
def apply(self, x):
self._check_input(x)
if (self._ic_samp is None or not isinstance(x, Linearization) or
not x.want_metric):
return self._lh(x)+self._prior(x)
......@@ -177,5 +184,6 @@ class SampledKullbachLeiblerDivergence(EnergyOperator):
self._res_samples = tuple(res_samples)
def apply(self, x):
self._check_input(x)
mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap) * (1./len(self._res_samples))
......@@ -59,6 +59,16 @@ class Operator(NiftyMetaBase()):
def apply(self, x):
raise NotImplementedError
def _check_input(self, x):
from ..linearization import Linearization
print('checkinput')
d = x.target if isinstance(x, Linearization) else x.domain
print(d)
print(self._domain)
print()
if self._domain != d:
raise ValueError("The operator's and field's domains don't match.")
def __call__(self, x):
if isinstance(x, Operator):
return _OpChain.make((self, x))
......@@ -84,6 +94,7 @@ class _FunctionApplier(Operator):
self._funcname = funcname
def apply(self, x):
self._check_input(x)
return getattr(x, self._funcname)()
......@@ -120,6 +131,7 @@ class _OpChain(_CombinedOperator):
raise ValueError("domain mismatch")
def apply(self, x):
self._check_input(x)
for op in reversed(self._ops):
x = op(x)
return x
......@@ -138,6 +150,7 @@ class _OpProd(Operator):
def apply(self, x):
from ..linearization import Linearization
from ..sugar import makeOp
self._check_input(x)
lin = isinstance(x, Linearization)
v = x._val if lin else x
v1 = v.extract(self._op1.domain)
......@@ -162,6 +175,7 @@ class _OpSum(Operator):
def apply(self, x):
from ..linearization import Linearization
self._check_input(x)
lin = isinstance(x, Linearization)
v = x._val if lin else x
v1 = v.extract(self._op1.domain)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment