Commit 7a8ca8cd authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'relaxed_multidomain' into 'NIFTy_5'

Relaxed multidomain

See merge request ift/NIFTy!272
parents 5abb0ed8 29b199ce
Pipeline #31638 failed with stages
in 4 minutes and 8 seconds
......@@ -128,13 +128,24 @@ class DomainTuple(object):
def __eq__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple.make(x)
if self is x:
return True
return self._dom == x._dom
return self is x
def __ne__(self, x):
return not self.__eq__(x)
def compatibleTo(self, x):
return self.__eq__(x)
def subsetOf(self, x):
return self.__eq__(x)
def unitedWith(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple.make(x)
if self != x:
raise ValueError("domain mismatch")
return self
def __str__(self):
res = "DomainTuple, len: " + str(len(self))
for i in self:
......
......@@ -747,6 +747,7 @@ for op in ["__add__", "__radd__", "__iadd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
global COUNTER
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
......
......@@ -16,7 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..operators.model_gradient_operator import ModelGradientOperator
from .model import Model
......@@ -26,9 +25,7 @@ class Constant(Model):
self._constant = constant
self._value = self._constant
self._gradient = ModelGradientOperator({}, position.domain,
self.value.domain)
self._gradient = 0.
def at(self, position):
return self.__class__(position, self._constant)
......@@ -71,3 +71,46 @@ class MultiDomain(frozendict):
obj = MultiDomain(domain, _callingfrommake=True)
MultiDomain._domainCache[domain] = obj
return obj
def __eq__(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
return self is x
def __ne__(self, x):
return not self.__eq__(x)
def compatibleTo(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
commonKeys = set(self.keys()) & set(x.keys())
for key in commonKeys:
if self[key] != x[key]:
return False
return True
def subsetOf(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if len(x) == 0:
return True
for key in self.keys():
if key not in x:
return False
if self[key] != x[key]:
return False
return True
def unitedWith(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if self == x:
return self
if not self.compatibleTo(x):
raise ValueError("domain mismatch")
res = {}
for key, val in self.items():
res[key] = val
for key, val in x.items():
res[key] = val
return MultiDomain.make(res)
......@@ -199,9 +199,37 @@ for op in ["__add__", "__radd__", "__iadd__",
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
self._check_domain(other)
result_val = {key: getattr(sub_field, op)(other[key])
for key, sub_field in self.items()}
if self._domain == other._domain:
result_val = {key: getattr(sub_field, op)(other[key])
for key, sub_field in self.items()}
else:
if not self._domain.compatibleTo(other.domain):
raise ValueError("domain mismatch")
fullkeys = set(self._domain.keys()) | set(other._domain.keys())
result_val = {}
if op in ["__iadd__", "__add__"]:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else None
f2 = other[key] if key in other._domain.keys() else None
if f1 is None:
result_val[key] = f2
elif f2 is None:
result_val[key] = f1
else:
result_val[key] = getattr(f1, op)(f2)
elif op in ["__mul__"]:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else None
f2 = other[key] if key in other._domain.keys() else None
if f1 is None or f2 is None:
continue
else:
result_val[key] = getattr(f1, op)(f2)
else:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else other[key]*0
f2 = other[key] if key in other._domain.keys() else self[key]*0
result_val[key] = getattr(f1, op)(f2)
else:
result_val = {key: getattr(val, op)(other)
for key, val in self.items()}
......
......@@ -8,7 +8,6 @@ from .harmonic_transform_operator import HarmonicTransformOperator
from .inversion_enabler import InversionEnabler
from .laplace_operator import LaplaceOperator
from .linear_operator import LinearOperator
from .model_gradient_operator import ModelGradientOperator
from .power_distributor import PowerDistributor
from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator
......@@ -21,4 +20,4 @@ __all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
"FFTSmoothingOperator", "GeometryRemover",
"LaplaceOperator", "SmoothnessOperator", "PowerDistributor",
"InversionEnabler", "SandwichOperator", "SamplingEnabler",
"DOFDistributor", "ModelGradientOperator"]
"DOFDistributor", "SelectionOperator"]
......@@ -280,5 +280,5 @@ class LinearOperator(NiftyMetaBase()):
def _check_input(self, x, mode):
self._check_mode(mode)
if x.domain != self._dom(mode):
if not self._dom(mode).subsetOf(x.domain):
raise ValueError("The operator's and field's domains don't match.")
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..multi.multi_domain import MultiDomain
from ..multi.multi_field import MultiField
from ..sugar import full
from .linear_operator import LinearOperator
class ModelGradientOperator(LinearOperator):
def __init__(self, gradients, domain, target):
super(ModelGradientOperator, self).__init__()
self._gradients = gradients
gradients_domain = MultiField(self._gradients).domain
self._domain = MultiDomain.make(domain)
# Check compatibility
if not (set(gradients_domain.items()) <= set(self.domain.items())):
raise ValueError
self._target = target
for grad in gradients.values():
if self._target != grad.target:
raise TypeError(
'All gradients have to have the same target domain')
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
@property
def gradients(self):
return self._gradients
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = None
for key, op in self._gradients.items():
if res is None:
res = op(x[key])
else:
res += op(x[key])
# Needed if gradients == {}
if res is None:
res = full(self.target, 0.)
if not res.domain == self.target:
raise TypeError
else:
grad_keys = self._gradients.keys()
res = {}
for dd in self.domain:
if dd in grad_keys:
res[dd] = self._gradients[dd].adjoint_times(x)
else:
res[dd] = full(self.domain[dd], 0.)
res = MultiField(res)
if not res.domain == self.domain:
raise TypeError
return res
......@@ -17,7 +17,6 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..operators import LinearOperator
from ..sugar import full
class SelectionOperator(LinearOperator):
......@@ -42,15 +41,10 @@ class SelectionOperator(LinearOperator):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
# FIXME Is the copying necessary?
self._check_input(x, mode)
if mode == self.TIMES:
return x[self._key].copy()
else:
result = {}
for key, val in self.domain.items():
if key != self._key:
result[key] = full(val, 0.)
else:
result[key] = x.copy()
from ..multi import MultiField
return MultiField(result)
return MultiField({self._key: x.copy()})
......@@ -23,12 +23,14 @@ import numpy as np
class SumOperator(LinearOperator):
"""Class representing sums of operators."""
def __init__(self, ops, neg, _callingfrommake=False):
def __init__(self, ops, neg, dom, tgt, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(SumOperator, self).__init__()
self._ops = ops
self._neg = neg
self._domain = dom
self._target = tgt
self._capability = self.TIMES | self.ADJOINT_TIMES
for op in ops:
self._capability &= op.capability
......@@ -38,9 +40,12 @@ class SumOperator(LinearOperator):
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
# Step 1: verify domains
dom = ops[0].domain
tgt = ops[0].target
for op in ops[1:]:
if op.domain != ops[0].domain or op.target != ops[0].target:
raise ValueError("domain mismatch")
dom = dom.unitedWith(op.domain)
tgt = tgt.unitedWith(op.target)
# Step 2: unpack SumOperators
opsnew = []
negnew = []
......@@ -124,7 +129,7 @@ class SumOperator(LinearOperator):
negnew.append(neg[i])
ops = opsnew
neg = negnew
return ops, neg
return ops, neg, dom, tgt
@staticmethod
def make(ops, neg):
......@@ -134,18 +139,18 @@ class SumOperator(LinearOperator):
raise ValueError("ops is empty")
if len(ops) != len(neg):
raise ValueError("length mismatch between ops and neg")
ops, neg = SumOperator.simplify(ops, neg)
ops, neg, dom, tgt = SumOperator.simplify(ops, neg)
if len(ops) == 1 and not neg[0]:
return ops[0]
return SumOperator(ops, neg, _callingfrommake=True)
return SumOperator(ops, neg, dom, tgt, _callingfrommake=True)
@property
def domain(self):
return self._ops[0].domain
return self._domain
@property
def target(self):
return self._ops[0].target
return self._target
@property
def adjoint(self):
......
......@@ -76,7 +76,7 @@ class Test_Minimizers(unittest.TestCase):
except ImportError:
raise SkipTest
np.random.seed(42)
space = ift.UnstructuredDomain((2,))
space = ift.DomainTuple.make(ift.UnstructuredDomain((2,)))
starting_point = ift.Field.from_random('normal', domain=space)*10
class RBEnergy(ift.Energy):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment