Commit d05cb5ce authored by Martin Reinecke's avatar Martin Reinecke
Browse files

introduce domain and target properties

parent e7df2052
......@@ -28,9 +28,17 @@ class Hamiltonian(Operator):
def __init__(self, lh, ic_samp=None):
super(Hamiltonian, self).__init__()
self._lh = lh
self._prior = GaussianEnergy()
self._prior = GaussianEnergy(domain=lh.domain)
self._ic_samp = ic_samp
@property
def domain(self):
return self._lh.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
res = self._lh(x) + self._prior(x)
if self._ic_samp is None:
......
......@@ -34,6 +34,14 @@ class SampledKullbachLeiblerDivergence(Operator):
self._h = h
self._res_samples = tuple(res_samples)
@property
def domain(self):
return self._h.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) *
(1./len(self._res_samples)))
......@@ -133,6 +133,14 @@ class AmplitudeModel(Operator):
self._smooth_op = sym * qht * ceps
self._keys = tuple(keys)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
def __call__(self, x):
smooth_spec = self._smooth_op(x[self._keys[0]])
phi = x[self._keys[1]] + self._norm_phi_mean
......
......@@ -30,6 +30,14 @@ class BernoulliEnergy(Operator):
self._p = p
self._d = d
@property
def domain(self):
return self._p.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
x = self._p(x)
v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum()
......
......@@ -25,11 +25,34 @@ from ..domain_tuple import DomainTuple
class GaussianEnergy(Operator):
def __init__(self, mean=None, covariance=None):
def __init__(self, mean=None, covariance=None, domain=None):
super(GaussianEnergy, self).__init__()
self._domain = None
if mean is not None:
self._checkEquivalence(mean.domain)
if covariance is not None:
self._checkEquivalence(covariance.domain)
if domain is not None:
self._checkEquivalence(domain)
if self._domain is None:
raise ValueError("no domain given")
self._mean = mean
self._icov = None if covariance is None else covariance.inverse
self._target = DomainTuple.scalar_domain()
def _checkEquivalence(self, newdom):
if self._domain is None:
self._domain = newdom
else:
if self._domain is not newdom:
raise ValueError("domain mismatch")
@property
def domain(self):
return self._domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
residual = x if self._mean is None else x-self._mean
......
......@@ -32,6 +32,14 @@ class PoissonianEnergy(Operator):
self._op = op
self._d = d
@property
def domain(self):
return self._op.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
x = self._op(x)
res = (x - self._d*x.log()).sum()
......
......@@ -13,6 +13,14 @@ class FieldAdapter(LinearOperator):
self._name = name_dom
self._target = dom[name_dom]
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self._all_ops
......
......@@ -10,19 +10,19 @@ class Operator(NiftyMetaBase()):
domain, and can also provide the Jacobian.
"""
@property
@abc.abstractproperty
def domain(self):
"""DomainTuple or MultiDomain : the operator's input domain
The domain on which the Operator's input Field lives."""
return self._domain
raise NotImplementedError
@property
@abc.abstractproperty
def target(self):
"""DomainTuple or MultiDomain : the operator's output domain
The domain on which the Operator's output Field lives."""
return self._target
raise NotImplementedError
def __matmul__(self, x):
if not isinstance(x, Operator):
......@@ -59,7 +59,7 @@ class Operator(NiftyMetaBase()):
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
def func(f):
def func2(self):
fa = _FunctionApplier(self._target, f)
fa = _FunctionApplier(self.target, f)
return _OpChain.make((fa, self))
return func2
setattr(Operator, f, func(f))
......@@ -68,9 +68,17 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._domain = makeDomain(domain)
self._funcname = funcname
@property
def domain(self):
return self._domain
@property
def target(self):
return self._domain
def __call__(self, x):
return getattr(x, self._funcname)()
......@@ -101,8 +109,14 @@ class _CombinedOperator(Operator):
class _OpChain(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
super(_OpChain, self).__init__(ops, _callingfrommake)
self._domain = self._ops[-1].domain
self._target = self._ops[0].target
@property
def domain(self):
return self._ops[-1].domain
@property
def target(self):
return self._ops[0].target
def __call__(self, x):
for op in reversed(self._ops):
......@@ -113,8 +127,14 @@ class _OpChain(_CombinedOperator):
class _OpProd(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
super(_OpProd, self).__init__(ops, _callingfrommake)
self._domain = self._ops[0].domain
self._target = self._ops[0].target
@property
def domain(self):
return self._ops[0].domain
@property
def target(self):
return self._ops[0].target
def __call__(self, x):
from ..utilities import my_product
......@@ -127,5 +147,13 @@ class _OpSum(_CombinedOperator):
self._domain = domain_union([op.domain for op in self._ops])
self._target = domain_union([op.target for op in self._ops])
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
def __call__(self, x):
raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment