Commit c960121b authored by Martin Reinecke's avatar Martin Reinecke

more

parent 60bf8aa9
......@@ -105,3 +105,15 @@ class MultiDomain(object):
for key, dom in zip(self._keys, self._domains):
res += key+": "+str(dom)+"\n"
return res
@staticmethod
def union(inp):
res = {}
for dom in inp:
for key, subdom in zip(dom._keys, dom._domains):
if key in res:
if res[key] is not subdom:
raise ValueError("domain mismatch")
else:
res[key] = subdom
return MultiDomain.make(res)
......@@ -121,6 +121,7 @@ class MultiField(object):
@staticmethod
def full(domain, val):
domain = MultiDomain.make(domain)
return MultiField(domain, tuple(Field.full(dom, val)
for dom in domain._domains))
......
......@@ -38,7 +38,10 @@ class Linearization(object):
def __add__(self, other):
if isinstance(other, Linearization):
return Linearization(self._val+other._val, self._jac+other._jac)
from .operators.relaxed_sum_operator import RelaxedSumOperator
return Linearization(
MultiField.combine((self._val, other._val)),
RelaxedSumOperator((self._jac, other._jac)))
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac)
......@@ -52,10 +55,10 @@ class Linearization(object):
return (-self).__add__(other)
def __mul__(self, other):
from .operators.diagonal_operator import DiagonalOperator
from .sugar import makeOp
if isinstance(other, Linearization):
d1 = DiagonalOperator(self._val)
d2 = DiagonalOperator(other._val)
d1 = makeOp(self._val)
d2 = makeOp(other._val)
return Linearization(self._val*other._val,
self._jac*d2 + d1*other._jac)
if isinstance(other, (int, float, complex)):
......@@ -63,15 +66,16 @@ class Linearization(object):
# return ...
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d2 = DiagonalOperator(other)
d2 = makeOp(other)
return Linearization(self._val*other, self._jac*d2)
raise TypeError
def __rmul__(self, other):
from .sugar import makeOp
if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d1 = DiagonalOperator(other)
d1 = makeOp(other)
return Linearization(self._val*other, d1*self._jac)
@staticmethod
......@@ -80,8 +84,8 @@ class Linearization(object):
return Linearization(field, ScalingOperator(1., field.domain))
@staticmethod
def make_const(field):
from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(0., {}))
from .operators.null_operator import NullOperator
return Linearization(field, NullOperator({}, field.domain))
class Operator(NiftyMetaBase()):
"""Transforms values living on one domain into values living on another
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
from ..compat import *
from ..utilities import my_sum
from .linear_operator import LinearOperator
from ..multi.multi_domain import MultiDomain
class RelaxedSumOperator(LinearOperator):
"""Class representing sums of operators with compatible MultiDomains."""
def __init__(self, ops):
super(RelaxedSumOperator, self).__init__()
self._ops = ops
self._domain = MultiDomain.union([op.domain for op in ops])
self._target = MultiDomain.union([op.target for op in ops])
self._capability = self.TIMES | self.ADJOINT_TIMES
for op in ops:
self._capability &= op.capability
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def adjoint(self):
return RelaxedSumOperator([op.adjoint for op in self._ops])
@property
def capability(self):
return self._capability
def apply(self, x, mode):
self._check_mode(mode)
res = None
for op in self._ops:
tmp = x.extract(op._dom(mode), mode)
if res is None:
res = tmp
else:
res = MultiField.combine([res, tmp])
return res
......@@ -251,7 +251,8 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x):
if isinstance(x, MultiField):
return MultiField({key: func2(val) for key, val in x.items()})
return MultiField(x.domain,
tuple(func2(val) for val in x.values()))
elif isinstance(x, Field):
fu = getattr(dobj, f)
return Field(domain=x._domain, val=fu(x.val))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment