Commit 3d1fadb6 authored by Martin Reinecke's avatar Martin Reinecke

first try

parent 5abb0ed8
Pipeline #31617 failed with stages
in 3 minutes and 19 seconds
......@@ -128,13 +128,24 @@ class DomainTuple(object):
def __eq__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple.make(x)
if self is x:
return True
return self._dom == x._dom
return self is x
def __ne__(self, x):
return not self.__eq__(x)
def compatibleTo(self, x):
return self.__eq__(x)
def subsetOf(self, x):
return self.__eq__(x)
def unitedWith(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple.make(x)
if self != x:
raise ValueError("domain mismatch")
return self
def __str__(self):
res = "DomainTuple, len: " + str(len(self))
for i in self:
......
......@@ -71,3 +71,44 @@ class MultiDomain(frozendict):
obj = MultiDomain(domain, _callingfrommake=True)
MultiDomain._domainCache[domain] = obj
return obj
def __eq__(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
return self is x
def __ne__(self, x):
return not self.__eq__(x)
def compatibleTo(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
commonKeys = set(self.keys()) & set(x.keys())
for key in commonKeys:
if self[key] != x[key]:
return False
return True
def subsetOf(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
for key in self.keys():
if not key in x:
return False
if self[key] != x[key]:
return False
return True
def unitedWith(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if self == x:
return self
if not self.compatibleTo(x):
raise ValueError("domain mismatch")
res = {}
for key, val in self.items():
res[key] = val
for key, val in x.items():
res[key] = val
return MultiDomain.make(res)
......@@ -199,9 +199,18 @@ for op in ["__add__", "__radd__", "__iadd__",
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
self._check_domain(other)
result_val = {key: getattr(sub_field, op)(other[key])
for key, sub_field in self.items()}
if self._domain == other._domain:
result_val = {key: getattr(sub_field, op)(other[key])
for key, sub_field in self.items()}
else:
if not self._domain.compatibleTo(other.domain):
raise ValueError("domain mismatch")
fullkeys = set(self._domain.keys()) | set(other._domain.keys())
result_val = {}
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else other[key]*0
f2 = other[key] if key in other._domain.keys() else self[key]*0
result_val[key] = getattr(f1, op)(f2)
else:
result_val = {key: getattr(val, op)(other)
for key, val in self.items()}
......
......@@ -280,5 +280,5 @@ class LinearOperator(NiftyMetaBase()):
def _check_input(self, x, mode):
self._check_mode(mode)
if x.domain != self._dom(mode):
if not self._dom(mode).subsetOf(x.domain):
raise ValueError("The operator's and field's domains don't match.")
......@@ -23,12 +23,14 @@ import numpy as np
class SumOperator(LinearOperator):
"""Class representing sums of operators."""
def __init__(self, ops, neg, _callingfrommake=False):
def __init__(self, ops, neg, dom, tgt, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(SumOperator, self).__init__()
self._ops = ops
self._neg = neg
self._domain = dom
self._target = tgt
self._capability = self.TIMES | self.ADJOINT_TIMES
for op in ops:
self._capability &= op.capability
......@@ -38,9 +40,12 @@ class SumOperator(LinearOperator):
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
# Step 1: verify domains
dom = ops[0].domain
tgt = ops[0].target
for op in ops[1:]:
if op.domain != ops[0].domain or op.target != ops[0].target:
raise ValueError("domain mismatch")
dom = dom.unitedWith(op.domain)
tgt = tgt.unitedWith(op.target)
# Step 2: unpack SumOperators
opsnew = []
negnew = []
......@@ -124,7 +129,7 @@ class SumOperator(LinearOperator):
negnew.append(neg[i])
ops = opsnew
neg = negnew
return ops, neg
return ops, neg, dom, tgt
@staticmethod
def make(ops, neg):
......@@ -134,18 +139,18 @@ class SumOperator(LinearOperator):
raise ValueError("ops is empty")
if len(ops) != len(neg):
raise ValueError("length mismatch between ops and neg")
ops, neg = SumOperator.simplify(ops, neg)
ops, neg, dom, tgt = SumOperator.simplify(ops, neg)
if len(ops) == 1 and not neg[0]:
return ops[0]
return SumOperator(ops, neg, _callingfrommake=True)
return SumOperator(ops, neg, dom, tgt, _callingfrommake=True)
@property
def domain(self):
return self._ops[0].domain
return self._domain
@property
def target(self):
return self._ops[0].target
return self._target
@property
def adjoint(self):
......
......@@ -76,7 +76,7 @@ class Test_Minimizers(unittest.TestCase):
except ImportError:
raise SkipTest
np.random.seed(42)
space = ift.UnstructuredDomain((2,))
space = ift.DomainTuple.make(ift.UnstructuredDomain((2,)))
starting_point = ift.Field.from_random('normal', domain=space)*10
class RBEnergy(ift.Energy):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment