Commit d05cb5ce authored by Martin Reinecke's avatar Martin Reinecke
Browse files

introduce domain and target properties

parent e7df2052
...@@ -28,9 +28,17 @@ class Hamiltonian(Operator): ...@@ -28,9 +28,17 @@ class Hamiltonian(Operator):
def __init__(self, lh, ic_samp=None): def __init__(self, lh, ic_samp=None):
super(Hamiltonian, self).__init__() super(Hamiltonian, self).__init__()
self._lh = lh self._lh = lh
self._prior = GaussianEnergy() self._prior = GaussianEnergy(domain=lh.domain)
self._ic_samp = ic_samp self._ic_samp = ic_samp
@property
def domain(self):
return self._lh.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x): def __call__(self, x):
res = self._lh(x) + self._prior(x) res = self._lh(x) + self._prior(x)
if self._ic_samp is None: if self._ic_samp is None:
......
...@@ -34,6 +34,14 @@ class SampledKullbachLeiblerDivergence(Operator): ...@@ -34,6 +34,14 @@ class SampledKullbachLeiblerDivergence(Operator):
self._h = h self._h = h
self._res_samples = tuple(res_samples) self._res_samples = tuple(res_samples)
@property
def domain(self):
return self._h.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x): def __call__(self, x):
return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) * return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) *
(1./len(self._res_samples))) (1./len(self._res_samples)))
...@@ -133,6 +133,14 @@ class AmplitudeModel(Operator): ...@@ -133,6 +133,14 @@ class AmplitudeModel(Operator):
self._smooth_op = sym * qht * ceps self._smooth_op = sym * qht * ceps
self._keys = tuple(keys) self._keys = tuple(keys)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
def __call__(self, x): def __call__(self, x):
smooth_spec = self._smooth_op(x[self._keys[0]]) smooth_spec = self._smooth_op(x[self._keys[0]])
phi = x[self._keys[1]] + self._norm_phi_mean phi = x[self._keys[1]] + self._norm_phi_mean
......
...@@ -30,6 +30,14 @@ class BernoulliEnergy(Operator): ...@@ -30,6 +30,14 @@ class BernoulliEnergy(Operator):
self._p = p self._p = p
self._d = d self._d = d
@property
def domain(self):
return self._p.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x): def __call__(self, x):
x = self._p(x) x = self._p(x)
v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum() v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum()
......
...@@ -25,11 +25,34 @@ from ..domain_tuple import DomainTuple ...@@ -25,11 +25,34 @@ from ..domain_tuple import DomainTuple
class GaussianEnergy(Operator): class GaussianEnergy(Operator):
def __init__(self, mean=None, covariance=None): def __init__(self, mean=None, covariance=None, domain=None):
super(GaussianEnergy, self).__init__() super(GaussianEnergy, self).__init__()
self._domain = None
if mean is not None:
self._checkEquivalence(mean.domain)
if covariance is not None:
self._checkEquivalence(covariance.domain)
if domain is not None:
self._checkEquivalence(domain)
if self._domain is None:
raise ValueError("no domain given")
self._mean = mean self._mean = mean
self._icov = None if covariance is None else covariance.inverse self._icov = None if covariance is None else covariance.inverse
self._target = DomainTuple.scalar_domain()
def _checkEquivalence(self, newdom):
if self._domain is None:
self._domain = newdom
else:
if self._domain is not newdom:
raise ValueError("domain mismatch")
@property
def domain(self):
return self._domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x): def __call__(self, x):
residual = x if self._mean is None else x-self._mean residual = x if self._mean is None else x-self._mean
......
...@@ -32,6 +32,14 @@ class PoissonianEnergy(Operator): ...@@ -32,6 +32,14 @@ class PoissonianEnergy(Operator):
self._op = op self._op = op
self._d = d self._d = d
@property
def domain(self):
return self._op.domain
@property
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x): def __call__(self, x):
x = self._op(x) x = self._op(x)
res = (x - self._d*x.log()).sum() res = (x - self._d*x.log()).sum()
......
...@@ -13,6 +13,14 @@ class FieldAdapter(LinearOperator): ...@@ -13,6 +13,14 @@ class FieldAdapter(LinearOperator):
self._name = name_dom self._name = name_dom
self._target = dom[name_dom] self._target = dom[name_dom]
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property @property
def capability(self): def capability(self):
return self._all_ops return self._all_ops
......
...@@ -10,19 +10,19 @@ class Operator(NiftyMetaBase()): ...@@ -10,19 +10,19 @@ class Operator(NiftyMetaBase()):
domain, and can also provide the Jacobian. domain, and can also provide the Jacobian.
""" """
@property @abc.abstractproperty
def domain(self): def domain(self):
"""DomainTuple or MultiDomain : the operator's input domain """DomainTuple or MultiDomain : the operator's input domain
The domain on which the Operator's input Field lives.""" The domain on which the Operator's input Field lives."""
return self._domain raise NotImplementedError
@property @abc.abstractproperty
def target(self): def target(self):
"""DomainTuple or MultiDomain : the operator's output domain """DomainTuple or MultiDomain : the operator's output domain
The domain on which the Operator's output Field lives.""" The domain on which the Operator's output Field lives."""
return self._target raise NotImplementedError
def __matmul__(self, x): def __matmul__(self, x):
if not isinstance(x, Operator): if not isinstance(x, Operator):
...@@ -59,7 +59,7 @@ class Operator(NiftyMetaBase()): ...@@ -59,7 +59,7 @@ class Operator(NiftyMetaBase()):
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
def func(f): def func(f):
def func2(self): def func2(self):
fa = _FunctionApplier(self._target, f) fa = _FunctionApplier(self.target, f)
return _OpChain.make((fa, self)) return _OpChain.make((fa, self))
return func2 return func2
setattr(Operator, f, func(f)) setattr(Operator, f, func(f))
...@@ -68,9 +68,17 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: ...@@ -68,9 +68,17 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
class _FunctionApplier(Operator): class _FunctionApplier(Operator):
def __init__(self, domain, funcname): def __init__(self, domain, funcname):
from ..sugar import makeDomain from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain) self._domain = makeDomain(domain)
self._funcname = funcname self._funcname = funcname
@property
def domain(self):
return self._domain
@property
def target(self):
return self._domain
def __call__(self, x): def __call__(self, x):
return getattr(x, self._funcname)() return getattr(x, self._funcname)()
...@@ -101,8 +109,14 @@ class _CombinedOperator(Operator): ...@@ -101,8 +109,14 @@ class _CombinedOperator(Operator):
class _OpChain(_CombinedOperator): class _OpChain(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False): def __init__(self, ops, _callingfrommake=False):
super(_OpChain, self).__init__(ops, _callingfrommake) super(_OpChain, self).__init__(ops, _callingfrommake)
self._domain = self._ops[-1].domain
self._target = self._ops[0].target @property
def domain(self):
return self._ops[-1].domain
@property
def target(self):
return self._ops[0].target
def __call__(self, x): def __call__(self, x):
for op in reversed(self._ops): for op in reversed(self._ops):
...@@ -113,8 +127,14 @@ class _OpChain(_CombinedOperator): ...@@ -113,8 +127,14 @@ class _OpChain(_CombinedOperator):
class _OpProd(_CombinedOperator): class _OpProd(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False): def __init__(self, ops, _callingfrommake=False):
super(_OpProd, self).__init__(ops, _callingfrommake) super(_OpProd, self).__init__(ops, _callingfrommake)
self._domain = self._ops[0].domain
self._target = self._ops[0].target @property
def domain(self):
return self._ops[0].domain
@property
def target(self):
return self._ops[0].target
def __call__(self, x): def __call__(self, x):
from ..utilities import my_product from ..utilities import my_product
...@@ -127,5 +147,13 @@ class _OpSum(_CombinedOperator): ...@@ -127,5 +147,13 @@ class _OpSum(_CombinedOperator):
self._domain = domain_union([op.domain for op in self._ops]) self._domain = domain_union([op.domain for op in self._ops])
self._target = domain_union([op.target for op in self._ops]) self._target = domain_union([op.target for op in self._ops])
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
def __call__(self, x): def __call__(self, x):
raise NotImplementedError raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment