diff --git a/nifty5/energies/hamiltonian.py b/nifty5/energies/hamiltonian.py index 1fc3867cdef409082512111ca92fe6dc1e3763bf..d5cb473828c09d45621472285cf7ab73c9396d8c 100644 --- a/nifty5/energies/hamiltonian.py +++ b/nifty5/energies/hamiltonian.py @@ -28,9 +28,17 @@ class Hamiltonian(Operator): def __init__(self, lh, ic_samp=None): super(Hamiltonian, self).__init__() self._lh = lh - self._prior = GaussianEnergy() + self._prior = GaussianEnergy(domain=lh.domain) self._ic_samp = ic_samp + @property + def domain(self): + return self._lh.domain + + @property + def target(self): + return DomainTuple.scalar_domain() + def __call__(self, x): res = self._lh(x) + self._prior(x) if self._ic_samp is None: diff --git a/nifty5/energies/kl.py b/nifty5/energies/kl.py index d1fbf1d2dfcde1d27d6fc3cbc35fa58bbfc4b0ac..6c9bb07fdca5f663139695530aff32a63f851d8b 100644 --- a/nifty5/energies/kl.py +++ b/nifty5/energies/kl.py @@ -34,6 +34,14 @@ class SampledKullbachLeiblerDivergence(Operator): self._h = h self._res_samples = tuple(res_samples) + @property + def domain(self): + return self._h.domain + + @property + def target(self): + return DomainTuple.scalar_domain() + def __call__(self, x): return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) * (1./len(self._res_samples))) diff --git a/nifty5/library/amplitude_model.py b/nifty5/library/amplitude_model.py index 3464bdce28824885f89409977db623ff0a6adbfb..54477381ac2b1dc142c93823069f523d2f40bce6 100644 --- a/nifty5/library/amplitude_model.py +++ b/nifty5/library/amplitude_model.py @@ -133,6 +133,14 @@ class AmplitudeModel(Operator): self._smooth_op = sym * qht * ceps self._keys = tuple(keys) + @property + def domain(self): + return self._domain + + @property + def target(self): + return self._target + def __call__(self, x): smooth_spec = self._smooth_op(x[self._keys[0]]) phi = x[self._keys[1]] + self._norm_phi_mean diff --git a/nifty5/library/bernoulli_energy.py b/nifty5/library/bernoulli_energy.py index 39a6de3b503605c0160af335085fff83fdebefba..78041592df677efb24d0c488a8f417ec2e4cc139 100644 --- a/nifty5/library/bernoulli_energy.py +++ b/nifty5/library/bernoulli_energy.py @@ -30,6 +30,14 @@ class BernoulliEnergy(Operator): self._p = p self._d = d + @property + def domain(self): + return self._p.domain + + @property + def target(self): + return DomainTuple.scalar_domain() + def __call__(self, x): x = self._p(x) v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum() diff --git a/nifty5/library/gaussian_energy.py b/nifty5/library/gaussian_energy.py index 56a64ea8047ef76970ae152d0d8a117a3f34fbff..fbb9cf72457d1af997964ec1909ee35d5addabbd 100644 --- a/nifty5/library/gaussian_energy.py +++ b/nifty5/library/gaussian_energy.py @@ -25,11 +25,34 @@ from ..domain_tuple import DomainTuple class GaussianEnergy(Operator): - def __init__(self, mean=None, covariance=None): + def __init__(self, mean=None, covariance=None, domain=None): super(GaussianEnergy, self).__init__() + self._domain = None + if mean is not None: + self._checkEquivalence(mean.domain) + if covariance is not None: + self._checkEquivalence(covariance.domain) + if domain is not None: + self._checkEquivalence(domain) + if self._domain is None: + raise ValueError("no domain given") self._mean = mean self._icov = None if covariance is None else covariance.inverse - self._target = DomainTuple.scalar_domain() + + def _checkEquivalence(self, newdom): + if self._domain is None: + self._domain = newdom + else: + if self._domain is not newdom: + raise ValueError("domain mismatch") + + @property + def domain(self): + return self._domain + + @property + def target(self): + return DomainTuple.scalar_domain() def __call__(self, x): residual = x if self._mean is None else x-self._mean diff --git a/nifty5/library/poissonian_energy.py b/nifty5/library/poissonian_energy.py index 1358ea653ffc23be8d1549ea6f6eac816b95983c..14410dc5a4414ce455a8dc6b2bcb262fb4a1cd12 100644 --- a/nifty5/library/poissonian_energy.py +++ b/nifty5/library/poissonian_energy.py @@ -32,6 +32,14 @@ class PoissonianEnergy(Operator): self._op = op self._d = d + @property + def domain(self): + return self._op.domain + + @property + def target(self): + return DomainTuple.scalar_domain() + def __call__(self, x): x = self._op(x) res = (x - self._d*x.log()).sum() diff --git a/nifty5/operators/field_adapter.py b/nifty5/operators/field_adapter.py index 24abb0a0b03e2c7f7f7700d1d9f612b1a6a095b5..c9038ffbf3cf1c1a7d1ddcdf8057f033fc88856c 100644 --- a/nifty5/operators/field_adapter.py +++ b/nifty5/operators/field_adapter.py @@ -13,6 +13,14 @@ class FieldAdapter(LinearOperator): self._name = name_dom self._target = dom[name_dom] + @property + def domain(self): + return self._domain + + @property + def target(self): + return self._target + @property def capability(self): return self._all_ops diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index c7ea21ff35605552b398149b1eeac8bf397ac52b..edb0d8a06cbc660376d53b2df471f82b1ceb0d38 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -10,19 +10,19 @@ class Operator(NiftyMetaBase()): domain, and can also provide the Jacobian. """ - @property + @abc.abstractproperty def domain(self): """DomainTuple or MultiDomain : the operator's input domain The domain on which the Operator's input Field lives.""" - return self._domain + raise NotImplementedError - @property + @abc.abstractproperty def target(self): """DomainTuple or MultiDomain : the operator's output domain The domain on which the Operator's output Field lives.""" - return self._target + raise NotImplementedError def __matmul__(self, x): if not isinstance(x, Operator): @@ -59,7 +59,7 @@ class Operator(NiftyMetaBase()): for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: def func(f): def func2(self): - fa = _FunctionApplier(self._target, f) + fa = _FunctionApplier(self.target, f) return _OpChain.make((fa, self)) return func2 setattr(Operator, f, func(f)) @@ -68,9 +68,17 @@ for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: class _FunctionApplier(Operator): def __init__(self, domain, funcname): from ..sugar import makeDomain - self._domain = self._target = makeDomain(domain) + self._domain = makeDomain(domain) self._funcname = funcname + @property + def domain(self): + return self._domain + + @property + def target(self): + return self._domain + def __call__(self, x): return getattr(x, self._funcname)() @@ -101,8 +109,14 @@ class _CombinedOperator(Operator): class _OpChain(_CombinedOperator): def __init__(self, ops, _callingfrommake=False): super(_OpChain, self).__init__(ops, _callingfrommake) - self._domain = self._ops[-1].domain - self._target = self._ops[0].target + + @property + def domain(self): + return self._ops[-1].domain + + @property + def target(self): + return self._ops[0].target def __call__(self, x): for op in reversed(self._ops): @@ -113,8 +127,14 @@ class _OpChain(_CombinedOperator): class _OpProd(_CombinedOperator): def __init__(self, ops, _callingfrommake=False): super(_OpProd, self).__init__(ops, _callingfrommake) - self._domain = self._ops[0].domain - self._target = self._ops[0].target + + @property + def domain(self): + return self._ops[0].domain + + @property + def target(self): + return self._ops[0].target def __call__(self, x): from ..utilities import my_product @@ -127,5 +147,13 @@ class _OpSum(_CombinedOperator): self._domain = domain_union([op.domain for op in self._ops]) self._target = domain_union([op.target for op in self._ops]) + @property + def domain(self): + return self._domain + + @property + def target(self): + return self._target + def __call__(self, x): raise NotImplementedError