diff --git a/nifty5/__init__.py b/nifty5/__init__.py index da3a6cf244fc709ca23c9653580cd42d9ad8b258..e4c782f32ea3203ce24451a4f35c33e7718e2f48 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -16,6 +16,7 @@ from .domains.log_rg_space import LogRGSpace from .domain_tuple import DomainTuple from .field import Field +from .operators.operator import Operator from .operators.central_zero_padder import CentralZeroPadder from .operators.diagonal_operator import DiagonalOperator from .operators.dof_distributor import DOFDistributor @@ -92,7 +93,6 @@ from .energies.kl import SampledKullbachLeiblerDivergence from .energies.hamiltonian import Hamiltonian from .energies.energy_adapter import EnergyAdapter -from .operator import Operator from .linearization import Linearization # We deliberately don't set __all__ here, because we don't want people to do a diff --git a/nifty5/energies/hamiltonian.py b/nifty5/energies/hamiltonian.py index 659cefb66ac7a0d6e9bda169bb7496245b9e07a7..1fc3867cdef409082512111ca92fe6dc1e3763bf 100644 --- a/nifty5/energies/hamiltonian.py +++ b/nifty5/energies/hamiltonian.py @@ -19,7 +19,7 @@ from __future__ import absolute_import, division, print_function from ..compat import * -from ..operator import Operator +from ..operators.operator import Operator from ..library.gaussian_energy import GaussianEnergy from ..operators.sampling_enabler import SamplingEnabler diff --git a/nifty5/energies/kl.py b/nifty5/energies/kl.py index 99fb600ed192e8698419a4e60baca529cc544104..d1fbf1d2dfcde1d27d6fc3cbc35fa58bbfc4b0ac 100644 --- a/nifty5/energies/kl.py +++ b/nifty5/energies/kl.py @@ -19,7 +19,7 @@ from __future__ import absolute_import, division, print_function from ..compat import * -from ..operator import Operator +from ..operators.operator import Operator from ..utilities import my_sum diff --git a/nifty5/library/amplitude_model.py b/nifty5/library/amplitude_model.py index ef53fe22101a5474b0a269da5300caeedd32b471..846b3f721cb9b02d68567b4579a69f608ca3ec55 100644 --- a/nifty5/library/amplitude_model.py +++ b/nifty5/library/amplitude_model.py @@ -27,7 +27,7 @@ from ..field import Field from ..multi.multi_field import MultiField from ..multi.multi_domain import MultiDomain from ..sugar import makeOp, sqrt -from ..operator import Operator +from ..operators.operator import Operator def _ceps_kernel(dof_space, k, a, k0): diff --git a/nifty5/library/bernoulli_energy.py b/nifty5/library/bernoulli_energy.py index cbf7739026cd8d98f431fc454bf1e86feeee52cd..39a6de3b503605c0160af335085fff83fdebefba 100644 --- a/nifty5/library/bernoulli_energy.py +++ b/nifty5/library/bernoulli_energy.py @@ -19,7 +19,7 @@ from __future__ import absolute_import, division, print_function from ..compat import * -from ..operator import Operator +from ..operators.operator import Operator from ..operators.sandwich_operator import SandwichOperator from ..sugar import makeOp diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index e7ad3fe58f056d88de25f1b058c05582cbe2dbd9..c89fb04fe319c1596e85b375d0058f3996b40eff 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -25,7 +25,7 @@ from ..multi.multi_domain import MultiDomain from ..operators.domain_distributor import DomainDistributor from ..operators.harmonic_transform_operator import HarmonicTransformOperator from ..operators.power_distributor import PowerDistributor -from ..operator import Operator +from ..operators.operator import Operator class CorrelatedField(Operator): diff --git a/nifty5/library/gaussian_energy.py b/nifty5/library/gaussian_energy.py index 68a4ee0b31b8a4572a4f126615936a925f9d0a6e..56a64ea8047ef76970ae152d0d8a117a3f34fbff 100644 --- a/nifty5/library/gaussian_energy.py +++ b/nifty5/library/gaussian_energy.py @@ -19,8 +19,9 @@ from __future__ import absolute_import, division, print_function from ..compat import * -from ..operator import Operator +from ..operators.operator import Operator from ..operators.sandwich_operator import SandwichOperator +from ..domain_tuple import DomainTuple class GaussianEnergy(Operator): @@ -28,6 +29,7 @@ class GaussianEnergy(Operator): super(GaussianEnergy, self).__init__() self._mean = mean self._icov = None if covariance is None else covariance.inverse + self._target = DomainTuple.scalar_domain() def __call__(self, x): residual = x if self._mean is None else x-self._mean diff --git a/nifty5/library/poissonian_energy.py b/nifty5/library/poissonian_energy.py index 5d8578014d10ea9d553e6a87978a2de812527825..1358ea653ffc23be8d1549ea6f6eac816b95983c 100644 --- a/nifty5/library/poissonian_energy.py +++ b/nifty5/library/poissonian_energy.py @@ -21,7 +21,7 @@ from __future__ import absolute_import, division, print_function from numpy import inf, isnan from ..compat import * -from ..operator import Operator +from ..operators.operator import Operator from ..operators.sandwich_operator import SandwichOperator from ..sugar import makeOp diff --git a/nifty5/operator.py b/nifty5/operator.py deleted file mode 100644 index 1fe232cfd2d3c8847f15ece5b7833a6ac2e27879..0000000000000000000000000000000000000000 --- a/nifty5/operator.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import absolute_import, division, print_function - -from .compat import * -from .utilities import NiftyMetaBase - - -class Operator(NiftyMetaBase()): - """Transforms values living on one domain into values living on another - domain, and can also provide the Jacobian. - """ - - def chain(self, x): - if not callable(x): - raise TypeError("callable needed") - ops1 = self._ops if isinstance(self, OpChain) else (self,) - ops2 = x._ops if isinstance(x, OpChain) else (x,) - return OpChain(ops1+ops2) - - def __call__(self, x): - """Returns transformed x - - Parameters - ---------- - x : Linearization - input - - Returns - ------- - Linearization - output - """ - raise NotImplementedError - - -class OpChain(Operator): - def __init__(self, ops): - self._ops = tuple(ops) - - def __call__(self, x): - for op in reversed(self._ops): - x = op(x) - return x diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py index eac78bd8323a7c89a11e26a91990c26a2bf86287..46301f4f5a22e8f33a0000aeb206d3c27fa21a53 100644 --- a/nifty5/operators/linear_operator.py +++ b/nifty5/operators/linear_operator.py @@ -23,7 +23,7 @@ import abc import numpy as np from ..compat import * -from ..operator import Operator +from .operator import Operator class LinearOperator(Operator): @@ -86,21 +86,6 @@ class LinearOperator(Operator): def __init__(self): pass - @abc.abstractproperty - def domain(self): - # FIXME Adopt documentation to MultiDomains - """DomainTuple : the operator's input domain - - The domain on which the Operator's input Field lives.""" - raise NotImplementedError - - @abc.abstractproperty - def target(self): - """DomainTuple : the operator's output domain - - The domain on which the Operator's output Field lives.""" - raise NotImplementedError - def _flip_modes(self, trafo): from .operator_adapter import OperatorAdapter return self if trafo == 0 else OperatorAdapter(self, trafo) diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..99f54d4947e4ee23dea9fd356a521540cd255ab6 --- /dev/null +++ b/nifty5/operators/operator.py @@ -0,0 +1,108 @@ +from __future__ import absolute_import, division, print_function +import abc + +from ..compat import * +from ..utilities import NiftyMetaBase + + +class Operator(NiftyMetaBase()): + """Transforms values living on one domain into values living on another + domain, and can also provide the Jacobian. + """ + + def domain(self): + """DomainTuple or MultiDomain : the operator's input domain + + The domain on which the Operator's input Field lives.""" + return self._domain + + def target(self): + """DomainTuple or MultiDomain : the operator's output domain + + The domain on which the Operator's output Field lives.""" + return self._target + + def __matmul__(self, x): + if not isinstance(x, Operator): + return NotImplemented + return OpChain.make((self, x)) + ops1 = self._ops if isinstance(self, OpChain) else (self,) + ops2 = x._ops if isinstance(x, OpChain) else (x,) + return OpChain(ops1+ops2) + + def chain(self, x): + res = self.__matmul__(x) + if res == NotImplemented: + raise TypeError("operator expected") + return res + + def __call__(self, x): + """Returns transformed x + + Parameters + ---------- + x : Linearization + input + + Returns + ------- + Linearization + output + """ + raise NotImplementedError + + +class _CombinedOperator(Operator): + def __init__(self, ops, _callingfrommake=False): + if not _callingfrommake: + raise NotImplementedError + self._ops = tuple(ops) + + @classmethod + def unpack(cls, ops, res): + for op in ops: + if isinstance(op, cls): + res = cls.unpack(op, res) + else: + res = res + [op] + return res + + @classmethod + def make(cls, ops): + res = cls.unpack(ops, []) + if len(res) == 1: + return res[0] + return cls(res, _callingfrommake=True) + + +class _OpChain(_CombinedOperator): + def __init__(self, ops, _callingfrommake=False): + super(_OpChain, self).__init__(ops, _callingfrommake) + self._domain = self._ops[-1].domain + self._target = self._ops[0].target + + def __call__(self, x): + for op in reversed(self._ops): + x = op(x) + return x + + +class _OpProd(_CombinedOperator): + def __init__(self, ops, _callingfrommake=False): + super(_OpProd, self).__init__(ops, _callingfrommake) + self._domain = self._ops[0].domain + self._target = self._ops[0].target + + def __call__(self, x): + return my_prod(map(lambda op: op(x) for op in self._ops)) + + +class _OpSum(_CombinedOperator): + def __init__(self, ops, _callingfrommake=False): + super(_OpSum, self).__init__(ops, _callingfrommake) + self._domain = domain_union([op.domain for op in self._ops]) + self._target = domain_union([op.target for op in self._ops]) + + + def __call__(self, x): + raise NotImplementedError