Commit a535e0e8 authored by Martin Reinecke's avatar Martin Reinecke

cleanups and fixes

parent 69a50640
......@@ -24,7 +24,7 @@ if __name__ == '__main__':
power_distributor = ift.PowerDistributor(harmonic_space, power_space)
position = {}
position['xi'] = ift.Field.from_random('normal', harmonic_space)
position = ift.MultiField(position)
position = ift.MultiField.from_dict(position)
xi = ift.Variable(position)['xi']
Amp = power_distributor(A)
......
......@@ -142,20 +142,6 @@ class DomainTuple(object):
def __ne__(self, x):
return not self.__eq__(x)
def compatibleTo(self, x):
return self.__eq__(x)
def subsetOf(self, x):
return self.__eq__(x)
def unitedWith(self, x):
if self is x:
return self
x = DomainTuple.make(x)
if self is not x:
raise ValueError("domain mismatch")
return self
def __str__(self):
res = "DomainTuple, len: " + str(len(self))
for i in self:
......
......@@ -109,7 +109,7 @@ class Field(object):
@staticmethod
def from_local_data(domain, arr):
return Field(DomainTuple.make(domain),
dobj.from_local_data(domain.shape, arr))
dobj.from_local_data(domain.shape, arr))
def to_global_data(self):
"""Returns an array containing the full data of the field.
......
......@@ -58,7 +58,7 @@ def make_amplitude_model(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
fields = {keys[0]: Field.from_random('normal', dof_space),
keys[1]: Field.from_random('normal', param_space)}
position = MultiField(fields)
position = MultiField.from_dict(fields)
dof_space = position[keys[0]].domain[0]
kern = lambda k: _ceps_kernel(dof_space, k, ceps_a, ceps_k)
......
......@@ -31,12 +31,15 @@ class BlockDiagonalOperator(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return MultiField(x.domain, tuple(self._operators[key].apply(x._val[i], mode=mode) for i, key in enumerate(x.keys())))
val = tuple(self._operators[key].apply(x._val[i], mode=mode)
for i, key in enumerate(x.keys()))
return MultiField(self._domain, val)
def draw_sample(self, from_inverse=False, dtype=np.float64):
dtype = MultiField.build_dtype(dtype, self._domain)
return MultiField.from_dict({key: op.draw_sample(from_inverse, dtype[key])
for key, op in self._operators.items()})
val = tuple(self._operators[key].draw_sample(from_inverse, dtype[key])
for key in self._domain._keys)
return MultiField(self._domain, val)
def _combine_chain(self, op):
res = {}
......
......@@ -61,52 +61,3 @@ class MultiDomain(object):
def __ne__(self, x):
return not self.__eq__(x)
def compatibleTo(self, x):
if self is x:
return True
x = MultiDomain.make(x)
if self is x:
return True
if (self, x) in MultiDomain._compatCache:
return True
commonKeys = set(self.keys()) & set(x.keys())
for key in commonKeys:
if self[key] is not x[key]:
return False
MultiDomain._compatCache.add((self, x))
MultiDomain._compatCache.add((x, self))
return True
def subsetOf(self, x):
if self is x:
return True
x = MultiDomain.make(x)
if self is x:
return True
if len(x) == 0:
return True
if (self, x) in MultiDomain._subsetCache:
return True
for key in self.keys():
if key not in x:
return False
if self[key] is not x[key]:
return False
MultiDomain._subsetCache.add((self, x))
return True
def unitedWith(self, x):
if self is x:
return self
x = MultiDomain.make(x)
if self is x:
return self
if not self.compatibleTo(x):
raise ValueError("domain mismatch")
res = {}
for key, val in self.items():
res[key] = val
for key, val in x.items():
res[key] = val
return MultiDomain.make(res)
......@@ -103,7 +103,7 @@ class MultiField(object):
# dtype = MultiField.build_dtype(dtype, domain)
return MultiField(
domain, tuple(Field.from_random(random_type, dom, dtype, **kwargs)
for dom in domain._domains))
for dom in domain._domains))
def _check_domain(self, other):
if other._domain is not self._domain:
......@@ -131,13 +131,14 @@ class MultiField(object):
for dom in domain._domains))
def to_global_data(self):
return {key: val.to_global_data() for key, val in zip(self._domain.keys(), self._val)}
return {key: val.to_global_data()
for key, val in zip(self._domain.keys(), self._val)}
@staticmethod
def from_global_data(domain, arr, sum_up=False):
return MultiField(domain, tuple(Field.from_global_data(domain[key],
arr[key], sum_up)
for key in domain.keys()))
arr[key], sum_up)
for key in domain.keys()))
def norm(self):
""" Computes the L2-norm of the field values.
......
......@@ -282,5 +282,5 @@ class LinearOperator(NiftyMetaBase()):
def _check_input(self, x, mode):
self._check_mode(mode)
if not self._dom(mode).subsetOf(x.domain):
if self._dom(mode) is not x.domain:
raise ValueError("The operator's and field's domains don't match.")
......@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
from ..multi.multi_domain import MultiDomain
class SelectionOperator(LinearOperator):
......@@ -31,10 +32,7 @@ class SelectionOperator(LinearOperator):
String identifier of the wanted subdomain
"""
def __init__(self, domain, key):
from ..multi.multi_domain import MultiDomain
if not isinstance(domain, MultiDomain):
raise TypeError("Domain must be a MultiDomain")
self._domain = domain
self._domain = MultiDomain.make(domain)
self._key = key
@property
......@@ -55,4 +53,6 @@ class SelectionOperator(LinearOperator):
return x[self._key]
else:
from ..multi.multi_field import MultiField
return MultiField.from_dict({self._key: x})
rval = [None]*len(self._domain)
rval[self._domain._dict[self._key]] = x
return MultiField(self._domain, tuple(rval))
......@@ -46,8 +46,8 @@ class SumOperator(LinearOperator):
dom = ops[0].domain
tgt = ops[0].target
for op in ops[1:]:
dom = dom.unitedWith(op.domain)
tgt = tgt.unitedWith(op.target)
if dom is not op.domain or tgt is not op.target:
raise ValueError("Domain mismatch")
# Step 2: unpack SumOperators
opsnew = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment