diff --git a/nifty5/field.py b/nifty5/field.py index e22382d8081bc2b8f597b5ee796f6e0e57c74afd..5ba5fba3ae79c920fd503ea4b819429e0f184580 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -626,6 +626,11 @@ class Field(object): raise ValueError("domain mismatch") return self + def extract_part(self, dom): + if dom != self._domain: + raise ValueError("domain mismatch") + return self + def unite(self, other): return self+other diff --git a/nifty5/multi_field.py b/nifty5/multi_field.py index a465e1b403043d6e2c052056c64b433e99ec71a1..fe541469708b9e5c89409030b8ee19e394360d09 100644 --- a/nifty5/multi_field.py +++ b/nifty5/multi_field.py @@ -217,6 +217,12 @@ class MultiField(object): return MultiField(subset, tuple(self[key] for key in subset.keys())) + def extract_part(self, subset): + if subset is self._domain: + return self + return MultiField.from_dict({key: self[key] for key in subset.keys() + if key in self}) + def unite(self, other): """Merges two MultiFields on potentially different MultiDomains. diff --git a/nifty5/operators/chain_operator.py b/nifty5/operators/chain_operator.py index 154fc7098174c17a2d1ceb0d65b44f339c7e0f9d..af46f4fefe5d1914745fd8164bbdb6549b05ed2a 100644 --- a/nifty5/operators/chain_operator.py +++ b/nifty5/operators/chain_operator.py @@ -138,6 +138,17 @@ class ChainOperator(LinearOperator): subs = "\n".join(sub.__repr__() for sub in self._ops) return "ChainOperator:\n" + utilities.indent(subs) + def _simplify_for_constant_input_nontrivial(self, c_inp): + from ..multi_domain import MultiDomain + if not isinstance(self._domain, MultiDomain): + return None, self + + newop = None + for op in reversed(self._ops): + c_inp, t_op = op.simplify_for_constant_input(c_inp) + newop = t_op if newop is None else op(newop) + return c_inp, newop + # def draw_sample(self, from_inverse=False, dtype=np.float64): # from ..sugar import from_random # if len(self._ops) == 1: diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 073579043d9104b53fb97baee3fa84275e5fa0b7..a674ccd5f03ac66d268cdd38d2ca46e1519b9908 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -146,6 +146,17 @@ class Operator(metaclass=NiftyMeta): def __repr__(self): return self.__class__.__name__ + def simplify_for_constant_input(self, c_inp): + if c_inp is None: + return None, self + if c_inp.domain == self.domain: + op = _ConstantOperator(self.domain, self(c_inp)) + return op(c_inp), op + return self._simplify_for_constant_input_nontrivial(c_inp) + + def _simplify_for_constant_input_nontrivial(self, c_inp): + return None, self + for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan', 'sinh', 'cosh', 'absolute', 'sinc', 'one_over']: @@ -157,6 +168,72 @@ for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan', setattr(Operator, f, func(f)) +class _ConstCollector(object): + def __init__(self): + self._const = None + self._nc = set() + + def mult(self, const, fulldom): + if const is None: + self._nc |= set(fulldom) + else: + self._nc |= set(fulldom) - set(const) + if self._const is None: + from ..multi_field import MultiField + self._const = MultiField.from_dict( + {key: const[key] for key in const if key not in self._nc}) + else: + from ..multi_field import MultiField + self._const = MultiField.from_dict( + {key: self._const[key]*const[key] + for key in const if key not in self._nc}) + + def add(self, const, fulldom): + if const is None: + self._nc |= set(fulldom.keys()) + else: + from ..multi_field import MultiField + self._nc |= set(fulldom.keys()) - set(const.keys()) + if self._const is None: + self._const = MultiField.from_dict( + {key: const[key] + for key in const.keys() if key not in self._nc}) + else: + self._const = self._const.unite(const) + self._const = MultiField.from_dict( + {key: self._const[key] + for key in self._const if key not in self._nc}) + + @property + def constfield(self): + return self._const + + +class _ConstantOperator(Operator): + def __init__(self, dom, output): + from ..sugar import makeDomain + self._domain = makeDomain(dom) + self._target = output.domain + self._output = output + + def apply(self, x): + from ..linearization import Linearization + from .simple_linear_operators import NullOperator + from ..domain_tuple import DomainTuple + self._check_input(x) + if not isinstance(x, Linearization): + return self._output + if x.want_metric and self._target is DomainTuple.scalar_domain(): + met = NullOperator(self._domain, self._domain) + else: + met = None + return x.new(self._output, NullOperator(self._domain, self._target), + met) + + def __repr__(self): + return 'ConstantOperator <- {}'.format(self.domain.keys()) + + class _FunctionApplier(Operator): def __init__(self, domain, funcname): from ..sugar import makeDomain @@ -229,6 +306,17 @@ class _OpChain(_CombinedOperator): x = op(x) return x + def _simplify_for_constant_input_nontrivial(self, c_inp): + from ..multi_domain import MultiDomain + if not isinstance(self._domain, MultiDomain): + return None, self + + newop = None + for op in reversed(self._ops): + c_inp, t_op = op.simplify_for_constant_input(c_inp) + newop = t_op if newop is None else op(newop) + return c_inp, newop + def __repr__(self): subs = "\n".join(sub.__repr__() for sub in self._ops) return "_OpChain:\n" + indent(subs) @@ -261,6 +349,21 @@ class _OpProd(Operator): makeOp(lin2._val)(lin1._jac), False) return lin1.new(lin1._val*lin2._val, op(x.jac)) + def _simplify_for_constant_input_nontrivial(self, c_inp): + f1, o1 = self._op1.simplify_for_constant_input( + c_inp.extract_part(self._op1.domain)) + f2, o2 = self._op2.simplify_for_constant_input( + c_inp.extract_part(self._op2.domain)) + + from ..multi_domain import MultiDomain + if not isinstance(self._target, MultiDomain): + return None, _OpProd(o1, o2) + + cc = _ConstCollector() + cc.mult(f1, o1.target) + cc.mult(f2, o2.target) + return cc.constfield, _OpProd(o1, o2) + def __repr__(self): subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2)) return "_OpProd:\n"+indent(subs) @@ -293,6 +396,21 @@ class _OpSum(Operator): res = res.add_metric(lin1._metric + lin2._metric) return res + def _simplify_for_constant_input_nontrivial(self, c_inp): + f1, o1 = self._op1.simplify_for_constant_input( + c_inp.extract_part(self._op1.domain)) + f2, o2 = self._op2.simplify_for_constant_input( + c_inp.extract_part(self._op2.domain)) + + from ..multi_domain import MultiDomain + if not isinstance(self._target, MultiDomain): + return None, _OpSum(o1, o2) + + cc = _ConstCollector() + cc.add(f1, o1.target) + cc.add(f2, o2.target) + return cc.constfield, _OpSum(o1, o2) + def __repr__(self): subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2)) return "_OpSum:\n"+indent(subs) diff --git a/nifty5/operators/scaling_operator.py b/nifty5/operators/scaling_operator.py index 1e14b62ed13de6c5e3cddfa0b1f8b6f6c3366ee3..a4b20ab07044e05c2e2eaaf24711e5688bd69401 100644 --- a/nifty5/operators/scaling_operator.py +++ b/nifty5/operators/scaling_operator.py @@ -35,14 +35,6 @@ class ScalingOperator(EndomorphicOperator): ----- :class:`Operator` supports the multiplication with a scalar. So one does not need instantiate :class:`ScalingOperator` explicitly in most cases. - - Formally, this operator always supports all operation modes (times, - adjoint_times, inverse_times and inverse_adjoint_times), even if `factor` - is 0 or infinity. It is the user's responsibility to apply the operator - only in appropriate ways (e.g. call inverse_times only if `factor` is - nonzero). - - This shortcoming will hopefully be fixed in the future. """ def __init__(self, factor, domain): @@ -52,7 +44,10 @@ class ScalingOperator(EndomorphicOperator): raise TypeError("Scalar required") self._factor = factor self._domain = makeDomain(domain) - self._capability = self._all_ops + if self._factor == 0.: + self._capability = self.TIMES | self.ADJOINT_TIMES + else: + self._capability = self._all_ops def apply(self, x, mode): self._check_input(x, mode) diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index 3c1b54f4be5a62f7e7a51483a672b87746312d4f..096ebb2a24fb03d0a99bff835d79a9dd63a50b00 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -315,3 +315,23 @@ class NullOperator(LinearOperator): def apply(self, x, mode): self._check_input(x, mode) return self._nullfield(self._tgt(mode)) + + +class _PartialExtractor(LinearOperator): + def __init__(self, domain, target): + if not isinstance(domain, MultiDomain): + raise TypeError("MultiDomain expected") + if not isinstance(target, MultiDomain): + raise TypeError("MultiDomain expected") + self._domain = domain + self._target = target + for key in self._target.keys(): + if not (self._domain[key] is not self._target[key]): + raise ValueError("domain mismatch") + self._capability = self.TIMES | self.ADJOINT_TIMES + + def apply(self, x, mode): + self._check_input(x, mode) + if mode == self.TIMES: + return x.extract(self._target) + return MultiField.from_dict({key: x[key] for key in x.domain.keys()}) diff --git a/nifty5/operators/sum_operator.py b/nifty5/operators/sum_operator.py index 9cfe1328c6eaa4b59401a606fa6279473d881a8c..d9a3c00c51a8cf3fdcd77a313e3d5f93c661e13d 100644 --- a/nifty5/operators/sum_operator.py +++ b/nifty5/operators/sum_operator.py @@ -23,6 +23,7 @@ from ..sugar import domain_union from ..utilities import indent from .block_diagonal_operator import BlockDiagonalOperator from .linear_operator import LinearOperator +from .simple_linear_operators import NullOperator class SumOperator(LinearOperator): @@ -59,6 +60,9 @@ class SumOperator(LinearOperator): negnew += [not n for n in op._neg] else: negnew += list(op._neg) +# FIXME: this needs some more work to keep the domain and target unchanged! +# elif isinstance(op, NullOperator): +# pass else: opsnew.append(op) negnew.append(ng) @@ -193,6 +197,9 @@ class SumOperator(LinearOperator): "cannot draw from inverse of this operator") res = None for op in self._ops: + from .simple_linear_operators import NullOperator + if isinstance(op, NullOperator): + continue tmp = op.draw_sample(from_inverse, dtype) res = tmp if res is None else res.unite(tmp) return res @@ -200,3 +207,29 @@ class SumOperator(LinearOperator): def __repr__(self): subs = "\n".join(sub.__repr__() for sub in self._ops) return "SumOperator:\n"+indent(subs) + + def _simplify_for_constant_input_nontrivial(self, c_inp): + f = [] + o = [] + for op in self._ops: + tf, to = op.simplify_for_constant_input( + c_inp.extract_part(op.domain)) + f.append(tf) + o.append(to) + + from ..multi_domain import MultiDomain + if not isinstance(self._target, MultiDomain): + fullop = None + for to, n in zip(o, self._neg): + op = to if not n else -to + fullop = op if fullop is None else fullop + op + return None, fullop + + from .operator import _ConstCollector + cc = _ConstCollector() + fullop = None + for tf, to, n in zip(f, o, self._neg): + cc.add(tf, to.target) + op = to if not n else -to + fullop = op if fullop is None else fullop + op + return cc.constfield, fullop diff --git a/test/test_operators/test_simplification.py b/test/test_operators/test_simplification.py new file mode 100644 index 0000000000000000000000000000000000000000..bce790f27c226383479802e13239395b31f17573 --- /dev/null +++ b/test/test_operators/test_simplification.py @@ -0,0 +1,55 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2019 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. + +import pytest +from numpy.testing import assert_allclose, assert_equal + +import nifty5 as ift + + +def test_simplification(): + from nifty5.operators.operator import _ConstantOperator + f1 = ift.Field.full(ift.RGSpace(10),2.) + op = ift.FFTOperator(f1.domain) + _, op2 = op.simplify_for_constant_input(f1) + assert_equal(isinstance(op2, _ConstantOperator), True) + assert_allclose(op(f1).local_data, op2(f1).local_data) + + dom = {"a": ift.RGSpace(10)} + f1 = ift.full(dom,2.) + op = ift.FFTOperator(f1.domain["a"]).ducktape("a") + _, op2 = op.simplify_for_constant_input(f1) + assert_equal(isinstance(op2, _ConstantOperator), True) + assert_allclose(op(f1).local_data, op2(f1).local_data) + + dom = {"a": ift.RGSpace(10), "b": ift.RGSpace(5)} + f1 = ift.full(dom,2.) + pdom = {"a": ift.RGSpace(10)} + f2 = ift.full(pdom,2.) + o1 = ift.FFTOperator(f1.domain["a"]) + o2 = ift.FFTOperator(f1.domain["b"]) + op = (o1.ducktape("a").ducktape_left("a") + + o2.ducktape("b").ducktape_left("b")) + _, op2 = op.simplify_for_constant_input(f2) + assert_equal(isinstance(op2._op1, _ConstantOperator), True) + assert_allclose(op(f1)["a"].local_data, op2(f1)["a"].local_data) + assert_allclose(op(f1)["b"].local_data, op2(f1)["b"].local_data) + lin = ift.Linearization.make_var(ift.MultiField.full(op2.domain, 2.), True) + assert_allclose(op(lin).val["a"].local_data, + op2(lin).val["a"].local_data) + assert_allclose(op(lin).val["b"].local_data, + op2(lin).val["b"].local_data)