Commit 83b0eccc authored by Martin Reinecke's avatar Martin Reinecke

move Operator

parent 73a85004
......@@ -16,6 +16,7 @@ from .domains.log_rg_space import LogRGSpace
from .domain_tuple import DomainTuple
from .field import Field
from .operators.operator import Operator
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.dof_distributor import DOFDistributor
......@@ -92,7 +93,6 @@ from .energies.kl import SampledKullbachLeiblerDivergence
from .energies.hamiltonian import Hamiltonian
from .energies.energy_adapter import EnergyAdapter
from .operator import Operator
from .linearization import Linearization
# We deliberately don't set __all__ here, because we don't want people to do a
......
......@@ -19,7 +19,7 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operator import Operator
from ..operators.operator import Operator
from ..library.gaussian_energy import GaussianEnergy
from ..operators.sampling_enabler import SamplingEnabler
......
......@@ -19,7 +19,7 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operator import Operator
from ..operators.operator import Operator
from ..utilities import my_sum
......
......@@ -27,7 +27,7 @@ from ..field import Field
from ..multi.multi_field import MultiField
from ..multi.multi_domain import MultiDomain
from ..sugar import makeOp, sqrt
from ..operator import Operator
from ..operators.operator import Operator
def _ceps_kernel(dof_space, k, a, k0):
......
......@@ -19,7 +19,7 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operator import Operator
from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator
from ..sugar import makeOp
......
......@@ -25,7 +25,7 @@ from ..multi.multi_domain import MultiDomain
from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_transform_operator import HarmonicTransformOperator
from ..operators.power_distributor import PowerDistributor
from ..operator import Operator
from ..operators.operator import Operator
class CorrelatedField(Operator):
......
......@@ -19,8 +19,9 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..operator import Operator
from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator
from ..domain_tuple import DomainTuple
class GaussianEnergy(Operator):
......@@ -28,6 +29,7 @@ class GaussianEnergy(Operator):
super(GaussianEnergy, self).__init__()
self._mean = mean
self._icov = None if covariance is None else covariance.inverse
self._target = DomainTuple.scalar_domain()
def __call__(self, x):
residual = x if self._mean is None else x-self._mean
......
......@@ -21,7 +21,7 @@ from __future__ import absolute_import, division, print_function
from numpy import inf, isnan
from ..compat import *
from ..operator import Operator
from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator
from ..sugar import makeOp
......
from __future__ import absolute_import, division, print_function
from .compat import *
from .utilities import NiftyMetaBase
class Operator(NiftyMetaBase()):
"""Transforms values living on one domain into values living on another
domain, and can also provide the Jacobian.
"""
def chain(self, x):
if not callable(x):
raise TypeError("callable needed")
ops1 = self._ops if isinstance(self, OpChain) else (self,)
ops2 = x._ops if isinstance(x, OpChain) else (x,)
return OpChain(ops1+ops2)
def __call__(self, x):
"""Returns transformed x
Parameters
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise NotImplementedError
class OpChain(Operator):
def __init__(self, ops):
self._ops = tuple(ops)
def __call__(self, x):
for op in reversed(self._ops):
x = op(x)
return x
......@@ -23,7 +23,7 @@ import abc
import numpy as np
from ..compat import *
from ..operator import Operator
from .operator import Operator
class LinearOperator(Operator):
......@@ -86,21 +86,6 @@ class LinearOperator(Operator):
def __init__(self):
pass
@abc.abstractproperty
def domain(self):
# FIXME Adopt documentation to MultiDomains
"""DomainTuple : the operator's input domain
The domain on which the Operator's input Field lives."""
raise NotImplementedError
@abc.abstractproperty
def target(self):
"""DomainTuple : the operator's output domain
The domain on which the Operator's output Field lives."""
raise NotImplementedError
def _flip_modes(self, trafo):
from .operator_adapter import OperatorAdapter
return self if trafo == 0 else OperatorAdapter(self, trafo)
......
from __future__ import absolute_import, division, print_function
import abc
from ..compat import *
from ..utilities import NiftyMetaBase
class Operator(NiftyMetaBase()):
"""Transforms values living on one domain into values living on another
domain, and can also provide the Jacobian.
"""
def domain(self):
"""DomainTuple or MultiDomain : the operator's input domain
The domain on which the Operator's input Field lives."""
return self._domain
def target(self):
"""DomainTuple or MultiDomain : the operator's output domain
The domain on which the Operator's output Field lives."""
return self._target
def __matmul__(self, x):
if not isinstance(x, Operator):
return NotImplemented
return OpChain.make((self, x))
ops1 = self._ops if isinstance(self, OpChain) else (self,)
ops2 = x._ops if isinstance(x, OpChain) else (x,)
return OpChain(ops1+ops2)
def chain(self, x):
res = self.__matmul__(x)
if res == NotImplemented:
raise TypeError("operator expected")
return res
def __call__(self, x):
"""Returns transformed x
Parameters
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise NotImplementedError
class _CombinedOperator(Operator):
def __init__(self, ops, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
self._ops = tuple(ops)
@classmethod
def unpack(cls, ops, res):
for op in ops:
if isinstance(op, cls):
res = cls.unpack(op, res)
else:
res = res + [op]
return res
@classmethod
def make(cls, ops):
res = cls.unpack(ops, [])
if len(res) == 1:
return res[0]
return cls(res, _callingfrommake=True)
class _OpChain(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
super(_OpChain, self).__init__(ops, _callingfrommake)
self._domain = self._ops[-1].domain
self._target = self._ops[0].target
def __call__(self, x):
for op in reversed(self._ops):
x = op(x)
return x
class _OpProd(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
super(_OpProd, self).__init__(ops, _callingfrommake)
self._domain = self._ops[0].domain
self._target = self._ops[0].target
def __call__(self, x):
return my_prod(map(lambda op: op(x) for op in self._ops))
class _OpSum(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
super(_OpSum, self).__init__(ops, _callingfrommake)
self._domain = domain_union([op.domain for op in self._ops])
self._target = domain_union([op.target for op in self._ops])
def __call__(self, x):
raise NotImplementedError
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment