Commit d05cb5ce by Martin Reinecke

### introduce domain and target properties

parent e7df2052
 ... @@ -28,9 +28,17 @@ class Hamiltonian(Operator): ... @@ -28,9 +28,17 @@ class Hamiltonian(Operator): def __init__(self, lh, ic_samp=None): def __init__(self, lh, ic_samp=None): super(Hamiltonian, self).__init__() super(Hamiltonian, self).__init__() self._lh = lh self._lh = lh self._prior = GaussianEnergy() self._prior = GaussianEnergy(domain=lh.domain) self._ic_samp = ic_samp self._ic_samp = ic_samp @property def domain(self): return self._lh.domain @property def target(self): return DomainTuple.scalar_domain() def __call__(self, x): def __call__(self, x): res = self._lh(x) + self._prior(x) res = self._lh(x) + self._prior(x) if self._ic_samp is None: if self._ic_samp is None: ... ...
 ... @@ -34,6 +34,14 @@ class SampledKullbachLeiblerDivergence(Operator): ... @@ -34,6 +34,14 @@ class SampledKullbachLeiblerDivergence(Operator): self._h = h self._h = h self._res_samples = tuple(res_samples) self._res_samples = tuple(res_samples) @property def domain(self): return self._h.domain @property def target(self): return DomainTuple.scalar_domain() def __call__(self, x): def __call__(self, x): return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) * return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) * (1./len(self._res_samples))) (1./len(self._res_samples)))
 ... @@ -133,6 +133,14 @@ class AmplitudeModel(Operator): ... @@ -133,6 +133,14 @@ class AmplitudeModel(Operator): self._smooth_op = sym * qht * ceps self._smooth_op = sym * qht * ceps self._keys = tuple(keys) self._keys = tuple(keys) @property def domain(self): return self._domain @property def target(self): return self._target def __call__(self, x): def __call__(self, x): smooth_spec = self._smooth_op(x[self._keys[0]]) smooth_spec = self._smooth_op(x[self._keys[0]]) phi = x[self._keys[1]] + self._norm_phi_mean phi = x[self._keys[1]] + self._norm_phi_mean ... ...
 ... @@ -30,6 +30,14 @@ class BernoulliEnergy(Operator): ... @@ -30,6 +30,14 @@ class BernoulliEnergy(Operator): self._p = p self._p = p self._d = d self._d = d @property def domain(self): return self._p.domain @property def target(self): return DomainTuple.scalar_domain() def __call__(self, x): def __call__(self, x): x = self._p(x) x = self._p(x) v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum() v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum() ... ...
 ... @@ -25,11 +25,34 @@ from ..domain_tuple import DomainTuple ... @@ -25,11 +25,34 @@ from ..domain_tuple import DomainTuple class GaussianEnergy(Operator): class GaussianEnergy(Operator): def __init__(self, mean=None, covariance=None): def __init__(self, mean=None, covariance=None, domain=None): super(GaussianEnergy, self).__init__() super(GaussianEnergy, self).__init__() self._domain = None if mean is not None: self._checkEquivalence(mean.domain) if covariance is not None: self._checkEquivalence(covariance.domain) if domain is not None: self._checkEquivalence(domain) if self._domain is None: raise ValueError("no domain given") self._mean = mean self._mean = mean self._icov = None if covariance is None else covariance.inverse self._icov = None if covariance is None else covariance.inverse self._target = DomainTuple.scalar_domain() def _checkEquivalence(self, newdom): if self._domain is None: self._domain = newdom else: if self._domain is not newdom: raise ValueError("domain mismatch") @property def domain(self): return self._domain @property def target(self): return DomainTuple.scalar_domain() def __call__(self, x): def __call__(self, x): residual = x if self._mean is None else x-self._mean residual = x if self._mean is None else x-self._mean ... ...
 ... @@ -32,6 +32,14 @@ class PoissonianEnergy(Operator): ... @@ -32,6 +32,14 @@ class PoissonianEnergy(Operator): self._op = op self._op = op self._d = d self._d = d @property def domain(self): return self._op.domain @property def target(self): return DomainTuple.scalar_domain() def __call__(self, x): def __call__(self, x): x = self._op(x) x = self._op(x) res = (x - self._d*x.log()).sum() res = (x - self._d*x.log()).sum() ... ...
 ... @@ -13,6 +13,14 @@ class FieldAdapter(LinearOperator): ... @@ -13,6 +13,14 @@ class FieldAdapter(LinearOperator): self._name = name_dom self._name = name_dom self._target = dom[name_dom] self._target = dom[name_dom] @property def domain(self): return self._domain @property def target(self): return self._target @property @property def capability(self): def capability(self): return self._all_ops return self._all_ops ... ...
 ... @@ -10,19 +10,19 @@ class Operator(NiftyMetaBase()): ... @@ -10,19 +10,19 @@ class Operator(NiftyMetaBase()): domain, and can also provide the Jacobian. domain, and can also provide the Jacobian. """ """ @property @abc.abstractproperty def domain(self): def domain(self): """DomainTuple or MultiDomain : the operator's input domain """DomainTuple or MultiDomain : the operator's input domain The domain on which the Operator's input Field lives.""" The domain on which the Operator's input Field lives.""" return self._domain raise NotImplementedError @property @abc.abstractproperty def target(self): def target(self): """DomainTuple or MultiDomain : the operator's output domain """DomainTuple or MultiDomain : the operator's output domain The domain on which the Operator's output Field lives.""" The domain on which the Operator's output Field lives.""" return self._target raise NotImplementedError def __matmul__(self, x): def __matmul__(self, x): if not isinstance(x, Operator): if not isinstance(x, Operator): ... @@ -59,7 +59,7 @@ class Operator(NiftyMetaBase()): ... @@ -59,7 +59,7 @@ class Operator(NiftyMetaBase()): for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: def func(f): def func(f): def func2(self): def func2(self): fa = _FunctionApplier(self._target, f) fa = _FunctionApplier(self.target, f) return _OpChain.make((fa, self)) return _OpChain.make((fa, self)) return func2 return func2 setattr(Operator, f, func(f)) setattr(Operator, f, func(f)) ... @@ -68,9 +68,17 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: ... @@ -68,9 +68,17 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: class _FunctionApplier(Operator): class _FunctionApplier(Operator): def __init__(self, domain, funcname): def __init__(self, domain, funcname): from ..sugar import makeDomain from ..sugar import makeDomain self._domain = self._target = makeDomain(domain) self._domain = makeDomain(domain) self._funcname = funcname self._funcname = funcname @property def domain(self): return self._domain @property def target(self): return self._domain def __call__(self, x): def __call__(self, x): return getattr(x, self._funcname)() return getattr(x, self._funcname)() ... @@ -101,8 +109,14 @@ class _CombinedOperator(Operator): ... @@ -101,8 +109,14 @@ class _CombinedOperator(Operator): class _OpChain(_CombinedOperator): class _OpChain(_CombinedOperator): def __init__(self, ops, _callingfrommake=False): def __init__(self, ops, _callingfrommake=False): super(_OpChain, self).__init__(ops, _callingfrommake) super(_OpChain, self).__init__(ops, _callingfrommake) self._domain = self._ops[-1].domain self._target = self._ops[0].target @property def domain(self): return self._ops[-1].domain @property def target(self): return self._ops[0].target def __call__(self, x): def __call__(self, x): for op in reversed(self._ops): for op in reversed(self._ops): ... @@ -113,8 +127,14 @@ class _OpChain(_CombinedOperator): ... @@ -113,8 +127,14 @@ class _OpChain(_CombinedOperator): class _OpProd(_CombinedOperator): class _OpProd(_CombinedOperator): def __init__(self, ops, _callingfrommake=False): def __init__(self, ops, _callingfrommake=False): super(_OpProd, self).__init__(ops, _callingfrommake) super(_OpProd, self).__init__(ops, _callingfrommake) self._domain = self._ops[0].domain self._target = self._ops[0].target @property def domain(self): return self._ops[0].domain @property def target(self): return self._ops[0].target def __call__(self, x): def __call__(self, x): from ..utilities import my_product from ..utilities import my_product ... @@ -127,5 +147,13 @@ class _OpSum(_CombinedOperator): ... @@ -127,5 +147,13 @@ class _OpSum(_CombinedOperator): self._domain = domain_union([op.domain for op in self._ops]) self._domain = domain_union([op.domain for op in self._ops]) self._target = domain_union([op.target for op in self._ops]) self._target = domain_union([op.target for op in self._ops]) @property def domain(self): return self._domain @property def target(self): return self._target def __call__(self, x): def __call__(self, x): raise NotImplementedError raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!