# operator.py (16.15 KiB), last authored by Martin Reinecke
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..utilities import NiftyMeta, indent
from .. import pointwise
class Operator(metaclass=NiftyMeta):
    """Transforms values defined on one domain into values defined on another
    domain, and can also provide the Jacobian.

    Field-like objects (those whose `val` is not None) are Operators as well;
    for them `val`, `jac`, `want_metric` and `metric` carry actual data.
    """

    @property
    def domain(self):
        """The domain on which the Operator's input Field is defined.

        Returns
        -------
        domain : DomainTuple or MultiDomain
        """
        return self._domain

    @property
    def target(self):
        """The domain on which the Operator's output Field is defined.

        Returns
        -------
        target : DomainTuple or MultiDomain
        """
        return self._target

    @property
    def val(self):
        """The numerical value associated with this object.

        For "pure" operators this is `None`. For Field-like objects this
        is a `numpy.ndarray` or a dictionary of `numpy.ndarray`s matching the
        object's `target`.

        Returns
        -------
        None or numpy.ndarray or dictionary of np.ndarrays : the numerical value
        """
        return None

    @property
    def jac(self):
        """The Jacobian associated with this object.

        For "pure" operators this is `None`. For Field-like objects this
        can be `None` (in which case the object is a constant), or it can be a
        `LinearOperator` with `domain` and `target` matching the object's.

        Returns
        -------
        None or LinearOperator : the Jacobian

        Notes
        -----
        If `val` is `None`, this must be `None` as well!
        """
        return None

    @property
    def want_metric(self):
        """Whether a metric should be computed for the full expression.

        This is `False` whenever `jac` is `None`. In other cases it signals
        that operators processing this object should compute the metric.

        Returns
        -------
        bool : whether the metric should be computed
        """
        return False

    @property
    def metric(self):
        """The metric associated with the object.

        This is `None`, except when all the following conditions hold:

        - `want_metric` is `True`
        - `target` is the scalar domain
        - the operator chain contained an operator which could compute the
          metric

        Returns
        -------
        None or LinearOperator : the metric
        """
        return None

    @staticmethod
    def _check_domain_equality(dom_op, dom_field):
        """Raise a ValueError if `dom_op` and `dom_field` differ.

        If `dom_op` is neither a DomainTuple nor a MultiDomain, the error
        message additionally points this out.
        """
        if dom_op != dom_field:
            s = "The operator's and field's domains don't match."
            from ..domain_tuple import DomainTuple
            from ..multi_domain import MultiDomain
            if not isinstance(dom_op, (DomainTuple, MultiDomain,)):
                s += " Your operator's domain is neither a `DomainTuple`" \
                     " nor a `MultiDomain`."
            raise ValueError(s)

    def scale(self, factor):
        """Return this operator followed by a `ScalingOperator` with `factor`.

        A factor equal to 1 returns `self` unchanged (no extra operator).
        """
        if factor == 1:
            return self
        from .scaling_operator import ScalingOperator
        return ScalingOperator(self.target, factor)(self)

    def conjugate(self):
        """Return this operator followed by a `ConjugationOperator` on its
        target."""
        from .simple_linear_operators import ConjugationOperator
        return ConjugationOperator(self.target)(self)

    def sum(self, spaces=None):
        """Return this operator followed by a `ContractionOperator` on its
        target; `spaces` is forwarded unchanged to that operator."""
        from .contraction_operator import ContractionOperator
        return ContractionOperator(self.target, spaces)(self)

    def vdot(self, other):
        """Return ``(self.conjugate() * other).sum()``-like scalar product.

        Parameters
        ----------
        other : Operator
            If `other.jac` is None it is multiplied with ``self.conjugate()``
            directly; otherwise it is converted with `makeOp` and applied
            after ``self.conjugate()``. The result is then summed.

        Raises
        ------
        TypeError
            If `other` is not an Operator.
        """
        from ..sugar import makeOp
        if not isinstance(other, Operator):
            raise TypeError("vdot: `other` must be an Operator")
        if other.jac is None:
            res = self.conjugate()*other
        else:
            res = makeOp(other) @ self.conjugate()
        return res.sum()

    @property
    def real(self):
        """Return this operator followed by a `Realizer` on its target
        (as the name suggests, meant to yield the real part of the output)."""
        from .simple_linear_operators import Realizer
        return Realizer(self.target)(self)

    def __neg__(self):
        return self.scale(-1)

    def __matmul__(self, x):
        # `self @ x`: x is applied first, then self.
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpChain.make((self, x))

    def __rmatmul__(self, x):
        # `x @ self`: self is applied first, then x.
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpChain.make((x, self))

    def partial_insert(self, x):
        """Compose `self` after `x`, padding both with identity operators.

        Keys of ``x.target`` that `self` does not consume are passed through
        to the output, and keys of ``self.domain`` that `x` does not produce
        are taken directly from the input.

        Raises
        ------
        TypeError
            If `x` is not an Operator, or `self.domain` / `x.target` is not
            a MultiDomain.
        """
        from ..multi_domain import MultiDomain
        if not isinstance(x, Operator):
            raise TypeError("partial_insert: `x` must be an Operator")
        if not isinstance(self.domain, MultiDomain):
            raise TypeError("partial_insert: self.domain must be a MultiDomain")
        if not isinstance(x.target, MultiDomain):
            raise TypeError("partial_insert: x.target must be a MultiDomain")
        bigdom = MultiDomain.union([self.domain, x.target])
        k1, k2 = set(self.domain.keys()), set(x.target.keys())
        # le: produced by x but not consumed by self -> pass through on the
        #     left; ri: needed by self but not produced by x -> feed from input
        le, ri = k2 - k1, k1 - k2
        leop, riop = self, x
        if len(ri) > 0:
            riop = riop + self.identity_operator(
                MultiDomain.make({kk: bigdom[kk] for kk in ri}))
        if len(le) > 0:
            leop = leop + self.identity_operator(
                MultiDomain.make({kk: bigdom[kk] for kk in le}))
        return leop @ riop

    @staticmethod
    def identity_operator(dom):
        """Return a `BlockDiagonalOperator` on the MultiDomain `dom` whose
        blocks are all `ScalingOperator`s with factor 1."""
        from .block_diagonal_operator import BlockDiagonalOperator
        from .scaling_operator import ScalingOperator
        idops = {kk: ScalingOperator(dd, 1.) for kk, dd in dom.items()}
        return BlockDiagonalOperator(dom, idops)

    def __mul__(self, x):
        # Operator * Operator -> pointwise product; Operator * scalar -> scale.
        if isinstance(x, Operator):
            return _OpProd(self, x)
        if np.isscalar(x):
            return self.scale(x)
        return NotImplemented

    def __rmul__(self, x):
        return self.__mul__(x)

    def __add__(self, x):
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpSum(self, x)

    def __sub__(self, x):
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpSum(self, -x)

    def __pow__(self, power):
        """Raise the output to `power` pointwise via ``self.ptw("power", ...)``.

        `power` must be a scalar or a constant (Jacobian-free) Operator-like
        object; for anything else `NotImplemented` is returned so Python can
        try the reflected operation or raise a proper TypeError.
        """
        # The isinstance check avoids an AttributeError when `power` is a
        # non-scalar object that has no `jac` attribute.
        if not (np.isscalar(power)
                or (isinstance(power, Operator) and power.jac is None)):
            return NotImplemented
        return self.ptw("power", power)

    def apply(self, x):
        """Applies the operator to a Field or MultiField.

        Parameters
        ----------
        x : Field or MultiField
            Input on which the operator shall act. Needs to be defined on
            :attr:`domain`.
        """
        raise NotImplementedError

    def force(self, x):
        """Extract subset of domain of x according to `self.domain` and apply
        operator."""
        return self.apply(x.extract(self.domain))

    def _check_input(self, x):
        """Validate the argument passed to `apply`.

        `x` must be an Operator with a non-None `val`. If it carries a
        Jacobian, that Jacobian must be a `ScalingOperator` with factor 1.
        Finally, `x.domain` must equal this operator's domain.

        Raises
        ------
        TypeError
            If `x` is not an Operator or its `val` is None.
        ValueError
            If the Jacobian is not a unit `ScalingOperator`, or the domains
            differ.
        """
        from .scaling_operator import ScalingOperator
        if not (isinstance(x, Operator) and x.val is not None):
            raise TypeError("input must be an Operator with a non-None `val`")
        if x.jac is not None:
            if not isinstance(x.jac, ScalingOperator):
                raise ValueError(
                    "the Jacobian of the input must be a ScalingOperator")
            if x.jac._factor != 1:
                raise ValueError(
                    "the Jacobian of the input must have factor 1")
        self._check_domain_equality(self._domain, x.domain)

    def __call__(self, x):
        """Apply this operator to `x`, or compose it with `x`.

        - If `x` has a Jacobian, the operator is applied to
          ``x.trivial_jac()`` and ``prepend_jac(x.jac)`` is called on the
          result.
        - If `x` has a value but no Jacobian, it is applied to `x` directly.
        - Otherwise (`x` is a pure operator) ``self @ x`` is returned.

        Raises
        ------
        TypeError
            If `x` is not an Operator.
        """
        if not isinstance(x, Operator):
            raise TypeError("argument must be an Operator")
        if x.jac is not None:
            return self.apply(x.trivial_jac()).prepend_jac(x.jac)
        elif x.val is not None:
            return self.apply(x)
        return self @ x

    def ducktape(self, name):
        """Return ``self @ ducktape(self, None, name)`` (adapter on the
        input side)."""
        from .simple_linear_operators import ducktape
        return self @ ducktape(self, None, name)

    def ducktape_left(self, name):
        """Return ``ducktape(None, self, name) @ self`` (adapter on the
        output side)."""
        from .simple_linear_operators import ducktape
        return ducktape(None, self, name) @ self

    def __repr__(self):
        return self.__class__.__name__

    def simplify_for_constant_input(self, c_inp):
        """Simplify this operator assuming the input part `c_inp` is constant.

        Returns a pair ``(const_out, op)`` where `op` replaces `self` and
        `const_out` is the part of the output known to be constant (or None).
        If `c_inp` is None nothing is simplified. If `c_inp` covers the whole
        domain, `self` is replaced by a `_ConstantOperator` returning
        ``self(c_inp)``. Partial cases are delegated to
        `_simplify_for_constant_input_nontrivial`.
        """
        if c_inp is None:
            return None, self
        if c_inp.domain == self.domain:
            op = _ConstantOperator(self.domain, self(c_inp))
            return op(c_inp), op
        return self._simplify_for_constant_input_nontrivial(c_inp)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        # Default: no simplification possible for a partially constant input.
        return None, self

    def ptw(self, op, *args, **kwargs):
        """Return this operator followed by the pointwise function named `op`;
        extra arguments are forwarded to that function."""
        return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
def _make_ptw_method(name):
    """Build an `Operator` method that forwards to ``self.ptw(name, ...)``.

    A factory is used so that each generated method captures its own `name`;
    a closure defined directly in the loop below would late-bind to the last
    loop value.
    """
    def method(self, *args, **kwargs):
        return self.ptw(name, *args, **kwargs)
    # Give each generated method a meaningful name and docstring for
    # introspection and help(), instead of a shared anonymous helper name.
    method.__name__ = name
    method.__qualname__ = "Operator." + name
    method.__doc__ = ("Apply the pointwise function ``{0}`` to the output of "
                      "this operator.\n\nShorthand for "
                      "``self.ptw('{0}', *args, **kwargs)``.").format(name)
    return method


# Attach one convenience method per entry of `pointwise.ptw_dict`. The loop
# variable is private so it is not exported by `from ... import *`.
for _ptw_name in pointwise.ptw_dict.keys():
    setattr(Operator, _ptw_name, _make_ptw_method(_ptw_name))
class _ConstCollector(object):
    """Tracks which output keys of a sum or product of two operators are
    constant, and their values.

    Used by `_OpProd` (via `mult`) and `_OpSum` (via `add`) when simplifying
    for a partially constant input. Each call contributes one operand's
    constant output `const` (or None if nothing is known to be constant)
    together with that operand's full target `fulldom`.
    """

    def __init__(self):
        # MultiField with the values of the keys currently believed constant;
        # None until a non-None contribution arrives.
        self._const = None
        # Keys known to be non-constant.
        self._nc = set()

    def _drop_nonconst(self):
        """Remove every key listed in `self._nc` from `self._const`."""
        if self._const is None:
            return
        from ..multi_field import MultiField
        self._const = MultiField.from_dict(
            {key: self._const[key]
             for key in self._const if key not in self._nc})

    def mult(self, const, fulldom):
        """Register one factor of a product.

        Keys of `fulldom` not present in `const` become non-constant; the
        remaining constant values are multiplied into the collected field.
        """
        if const is None:
            self._nc |= set(fulldom)
            # Previously collected constants on these keys are no longer
            # constant once multiplied with a non-constant factor.
            self._drop_nonconst()
        else:
            from ..multi_field import MultiField
            self._nc |= set(fulldom) - set(const)
            if self._const is None:
                self._const = MultiField.from_dict(
                    {key: const[key] for key in const if key not in self._nc})
            else:
                # Product operands share the same target (see _OpProd).
                self._const = MultiField.from_dict(
                    {key: self._const[key]*const[key]
                     for key in const if key not in self._nc})

    def add(self, const, fulldom):
        """Register one summand of a sum.

        Keys of `fulldom` not present in `const` become non-constant; the
        remaining constant values are combined with `unite`.
        """
        if const is None:
            self._nc |= set(fulldom.keys())
            # A non-constant summand makes previously collected constants on
            # the same keys non-constant as well.
            self._drop_nonconst()
        else:
            from ..multi_field import MultiField
            self._nc |= set(fulldom.keys()) - set(const.keys())
            if self._const is None:
                self._const = MultiField.from_dict(
                    {key: const[key]
                     for key in const.keys() if key not in self._nc})
            else:
                self._const = self._const.unite(const)
                self._drop_nonconst()

    @property
    def constfield(self):
        """The collected constant part (MultiField), or None."""
        return self._const
class _ConstantOperator(Operator):
    """Operator that ignores its input and always returns `output`.

    Parameters
    ----------
    dom : domain-like
        Domain of the (ignored) input; normalized with `makeDomain`.
    output : Field or MultiField
        The value returned by `apply`; its domain becomes the target.
    """

    def __init__(self, dom, output):
        from ..sugar import makeDomain
        self._domain = makeDomain(dom)
        self._target = output.domain
        self._output = output

    def apply(self, x):
        from .simple_linear_operators import NullOperator
        self._check_input(x)
        if x.jac is not None:
            # The output does not depend on `x`, so a NullOperator from
            # domain to target serves as the Jacobian.
            return x.new(self._output, NullOperator(self._domain, self._target))
        return self._output

    def __repr__(self):
        from ..multi_domain import MultiDomain
        # Only a MultiDomain has `keys()`; `simplify_for_constant_input` can
        # also create this operator on a non-multi domain.
        if isinstance(self.domain, MultiDomain):
            dom = self.domain.keys()
        else:
            dom = self.domain
        return 'ConstantOperator <- {}'.format(dom)
class _FunctionApplier(Operator):
    """Applies the pointwise function named `funcname` to its input.

    Domain and target are the same object. Any extra positional and keyword
    arguments are stored and forwarded to ``x.ptw`` on every application.
    """

    def __init__(self, domain, funcname, *args, **kwargs):
        from ..sugar import makeDomain
        dom = makeDomain(domain)
        self._domain = dom
        self._target = dom
        self._funcname, self._args, self._kwargs = funcname, args, kwargs

    def apply(self, x):
        self._check_input(x)
        name, extra_args, extra_kwargs = self._funcname, self._args, self._kwargs
        return x.ptw(name, *extra_args, **extra_kwargs)
class _CombinedOperator(Operator):
    """Base class for operators built from a flat sequence of sub-operators.

    Instances must be created through `make`, which flattens nested instances
    of the same class; constructing one directly raises NotImplementedError.
    """

    def __init__(self, ops, _callingfrommake=False):
        if _callingfrommake:
            self._ops = tuple(ops)
            return
        raise NotImplementedError

    @classmethod
    def unpack(cls, ops, res):
        """Return `res` extended by `ops`, where every nested `cls` instance
        is replaced by its own (recursively unpacked) sub-operators.

        `res` itself is never mutated; a new list is built on extension.
        """
        flat = res
        for sub in ops:
            if isinstance(sub, cls):
                flat = cls.unpack(sub._ops, flat)
            else:
                flat = flat + [sub]
        return flat

    @classmethod
    def make(cls, ops):
        """Create a flattened instance from `ops`; if flattening leaves a
        single operator, that operator is returned unwrapped."""
        flat = cls.unpack(ops, [])
        if len(flat) != 1:
            return cls(flat, _callingfrommake=True)
        return flat[0]
class _OpChain(_CombinedOperator):
    """Composition of operators: `_ops[-1]` is applied first, `_ops[0]` last.

    Raises ValueError on construction if the domain of an operator does not
    match the target of the operator applied right before it.
    """

    def __init__(self, ops, _callingfrommake=False):
        super(_OpChain, self).__init__(ops, _callingfrommake)
        chain = self._ops
        self._domain = chain[-1].domain
        self._target = chain[0].target
        for outer, inner in zip(chain[:-1], chain[1:]):
            if outer.domain != inner.target:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        result = x
        for op in self._ops[::-1]:
            result = op(result)
        return result

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        from ..multi_domain import MultiDomain
        if not isinstance(self._domain, MultiDomain):
            return None, self
        simplified = None
        for op in self._ops[::-1]:
            # The constant part of each operator's output becomes the
            # constant input of the next one.
            c_inp, t_op = op.simplify_for_constant_input(c_inp)
            if simplified is None:
                # First (innermost) operator: use its simplified form.
                simplified = t_op
            else:
                # Later operators are composed unmodified.
                simplified = op(simplified)
        return c_inp, simplified

    def __repr__(self):
        body = "\n".join(repr(sub) for sub in self._ops)
        return "_OpChain:\n" + indent(body)
class _OpProd(Operator):
    """Pointwise product of the outputs of two operators with equal targets.

    The domain is the union of both operands' domains; each operand only
    receives the part of the input on its own domain.
    """

    def __init__(self, op1, op2):
        from ..sugar import domain_union
        self._domain = domain_union((op1.domain, op2.domain))
        self._target = op1.target
        if op1.target != op2.target:
            raise ValueError("target mismatch")
        self._op1, self._op2 = op1, op2

    def apply(self, x):
        from ..linearization import Linearization
        from ..sugar import makeOp
        self._check_input(x)
        if x.jac is None:
            v1 = x.extract(self._op1.domain)
            v2 = x.extract(self._op2.domain)
            return self._op1(v1) * self._op2(v2)
        wm = x.want_metric
        point = x.val
        v1 = point.extract(self._op1.domain)
        v2 = point.extract(self._op2.domain)
        lin1 = self._op1(Linearization.make_var(v1, wm))
        lin2 = self._op2(Linearization.make_var(v2, wm))
        # Product rule: d(f1*f2) = f1*d(f2) + f2*d(f1), with each value
        # turned into an operator by makeOp.
        term_a = makeOp(lin1._val)(lin2._jac)
        term_b = makeOp(lin2._val)(lin1._jac)
        jac = term_a._myadd(term_b, False)
        return lin1.new(lin1._val*lin2._val, jac)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        part1 = c_inp.extract_part(self._op1.domain)
        f1, o1 = self._op1.simplify_for_constant_input(part1)
        part2 = c_inp.extract_part(self._op2.domain)
        f2, o2 = self._op2.simplify_for_constant_input(part2)
        from ..multi_domain import MultiDomain
        if isinstance(self._target, MultiDomain):
            collector = _ConstCollector()
            collector.mult(f1, o1.target)
            collector.mult(f2, o2.target)
            return collector.constfield, _OpProd(o1, o2)
        return None, _OpProd(o1, o2)

    def __repr__(self):
        body = "\n".join(repr(sub) for sub in (self._op1, self._op2))
        return "_OpProd:\n" + indent(body)
class _OpSum(Operator):
    """Sum of the outputs of two operators, combined with `unite`.

    Domain and target are the unions of the operands' domains and targets;
    each operand only receives the part of the input on its own domain.
    """

    def __init__(self, op1, op2):
        from ..sugar import domain_union
        self._domain = domain_union((op1.domain, op2.domain))
        self._target = domain_union((op1.target, op2.target))
        self._op1, self._op2 = op1, op2

    def apply(self, x):
        from ..linearization import Linearization
        self._check_input(x)
        if x.jac is None:
            v1 = x.extract(self._op1.domain)
            v2 = x.extract(self._op2.domain)
            return self._op1(v1).unite(self._op2(v2))
        point = x.val
        v1 = point.extract(self._op1.domain)
        v2 = point.extract(self._op2.domain)
        wm = x.want_metric
        lin1 = self._op1(Linearization.make_var(v1, wm))
        lin2 = self._op2(Linearization.make_var(v2, wm))
        jac = lin1._jac._myadd(lin2._jac, False)
        res = lin1.new(lin1._val.unite(lin2._val), jac)
        m1, m2 = lin1._metric, lin2._metric
        # A metric is attached only if both summands provide one.
        if m1 is None or m2 is None:
            return res
        return res.add_metric(m1._myadd(m2, False))

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        part1 = c_inp.extract_part(self._op1.domain)
        f1, o1 = self._op1.simplify_for_constant_input(part1)
        part2 = c_inp.extract_part(self._op2.domain)
        f2, o2 = self._op2.simplify_for_constant_input(part2)
        from ..multi_domain import MultiDomain
        if isinstance(self._target, MultiDomain):
            collector = _ConstCollector()
            collector.add(f1, o1.target)
            collector.add(f2, o2.target)
            return collector.constfield, _OpSum(o1, o2)
        return None, _OpSum(o1, o2)

    def __repr__(self):
        body = "\n".join(repr(sub) for sub in (self._op1, self._op2))
        return "_OpSum:\n" + indent(body)