From c960121b31b1187bb11d3beb79e64af3a2461c5a Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Tue, 24 Jul 2018 21:21:59 +0200 Subject: [PATCH] more --- nifty5/multi/multi_domain.py | 12 +++++ nifty5/multi/multi_field.py | 1 + nifty5/operator.py | 20 ++++--- nifty5/operators/relaxed_sum_operator.py | 66 ++++++++++++++++++++++++ nifty5/sugar.py | 3 +- 5 files changed, 93 insertions(+), 9 deletions(-) create mode 100644 nifty5/operators/relaxed_sum_operator.py diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py index 715e51570..7876184af 100644 --- a/nifty5/multi/multi_domain.py +++ b/nifty5/multi/multi_domain.py @@ -105,3 +105,15 @@ class MultiDomain(object): for key, dom in zip(self._keys, self._domains): res += key+": "+str(dom)+"\n" return res + + @staticmethod + def union(inp): + res = {} + for dom in inp: + for key, subdom in zip(dom._keys, dom._domains): + if key in res: + if res[key] is not subdom: + raise ValueError("domain mismatch") + else: + res[key] = subdom + return MultiDomain.make(res) diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py index c79e2767a..3c4457264 100644 --- a/nifty5/multi/multi_field.py +++ b/nifty5/multi/multi_field.py @@ -121,6 +121,7 @@ class MultiField(object): @staticmethod def full(domain, val): + domain = MultiDomain.make(domain) return MultiField(domain, tuple(Field.full(dom, val) for dom in domain._domains)) diff --git a/nifty5/operator.py b/nifty5/operator.py index ecccefbf0..a6abdf5ee 100644 --- a/nifty5/operator.py +++ b/nifty5/operator.py @@ -38,7 +38,10 @@ class Linearization(object): def __add__(self, other): if isinstance(other, Linearization): - return Linearization(self._val+other._val, self._jac+other._jac) + from .operators.relaxed_sum_operator import RelaxedSumOperator + return Linearization( + MultiField.combine((self._val, other._val)), + RelaxedSumOperator((self._jac, other._jac))) if isinstance(other, (int, float, complex, Field, MultiField)): return Linearization(self._val+other, self._jac) @@ -52,10 +55,10 @@ class Linearization(object): return (-self).__add__(other) def __mul__(self, other): - from .operators.diagonal_operator import DiagonalOperator + from .sugar import makeOp if isinstance(other, Linearization): - d1 = DiagonalOperator(self._val) - d2 = DiagonalOperator(other._val) + d1 = makeOp(self._val) + d2 = makeOp(other._val) return Linearization(self._val*other._val, self._jac*d2 + d1*other._jac) if isinstance(other, (int, float, complex)): @@ -63,15 +66,16 @@ class Linearization(object): # return ... return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): - d2 = DiagonalOperator(other) + d2 = makeOp(other) return Linearization(self._val*other, self._jac*d2) raise TypeError def __rmul__(self, other): + from .sugar import makeOp if isinstance(other, (int, float, complex)): return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): - d1 = DiagonalOperator(other) + d1 = makeOp(other) return Linearization(self._val*other, d1*self._jac) @staticmethod @@ -80,8 +84,8 @@ class Linearization(object): return Linearization(field, ScalingOperator(1., field.domain)) @staticmethod def make_const(field): - from .operators.scaling_operator import ScalingOperator - return Linearization(field, ScalingOperator(0., {})) + from .operators.null_operator import NullOperator + return Linearization(field, NullOperator({}, field.domain)) class Operator(NiftyMetaBase()): """Transforms values living on one domain into values living on another diff --git a/nifty5/operators/relaxed_sum_operator.py b/nifty5/operators/relaxed_sum_operator.py new file mode 100644 index 000000000..f7f50a6ec --- /dev/null +++ b/nifty5/operators/relaxed_sum_operator.py @@ -0,0 +1,66 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from __future__ import absolute_import, division, print_function + +import numpy as np + +from ..compat import * +from ..utilities import my_sum +from .linear_operator import LinearOperator +from ..multi.multi_domain import MultiDomain + + +class RelaxedSumOperator(LinearOperator): + """Class representing sums of operators with compatible MultiDomains.""" + + def __init__(self, ops): + super(RelaxedSumOperator, self).__init__() + self._ops = ops + self._domain = MultiDomain.union([op.domain for op in ops]) + self._target = MultiDomain.union([op.target for op in ops]) + self._capability = self.TIMES | self.ADJOINT_TIMES + for op in ops: + self._capability &= op.capability + + @property + def domain(self): + return self._domain + + @property + def target(self): + return self._target + + @property + def adjoint(self): + return RelaxedSumOperator([op.adjoint for op in self._ops]) + + @property + def capability(self): + return self._capability + + def apply(self, x, mode): + self._check_mode(mode) + res = None + for op in self._ops: + tmp = x.extract(op._dom(mode), mode) + if res is None: + res = tmp + else: + res = MultiField.combine([res, tmp]) + return res diff --git a/nifty5/sugar.py b/nifty5/sugar.py index 942d11a4e..cc2a2963c 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -251,7 +251,8 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: def func(f): def func2(x): if isinstance(x, MultiField): - return MultiField({key: func2(val) for key, val in x.items()}) + return MultiField(x.domain, + tuple(func2(val) for val in x.values())) elif isinstance(x, Field): fu = getattr(dobj, f) return Field(domain=x._domain, val=fu(x.val)) -- GitLab