From a535e0e87cc35709cd7277e6780d21ff31cf2aa2 Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Sat, 7 Jul 2018 11:57:01 +0200 Subject: [PATCH] cleanups and fixes --- demos/getting_started_3.py | 2 +- nifty5/domain_tuple.py | 14 ------- nifty5/field.py | 2 +- nifty5/library/amplitude_model.py | 2 +- nifty5/multi/block_diagonal_operator.py | 9 +++-- nifty5/multi/multi_domain.py | 49 ------------------------- nifty5/multi/multi_field.py | 9 +++-- nifty5/operators/linear_operator.py | 2 +- nifty5/operators/selection_operator.py | 10 ++--- nifty5/operators/sum_operator.py | 4 +- 10 files changed, 22 insertions(+), 81 deletions(-) diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index 6035aa95..b70fc6a6 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -24,7 +24,7 @@ if __name__ == '__main__': power_distributor = ift.PowerDistributor(harmonic_space, power_space) position = {} position['xi'] = ift.Field.from_random('normal', harmonic_space) - position = ift.MultiField(position) + position = ift.MultiField.from_dict(position) xi = ift.Variable(position)['xi'] Amp = power_distributor(A) diff --git a/nifty5/domain_tuple.py b/nifty5/domain_tuple.py index fd63f599..e12cab39 100644 --- a/nifty5/domain_tuple.py +++ b/nifty5/domain_tuple.py @@ -142,20 +142,6 @@ class DomainTuple(object): def __ne__(self, x): return not self.__eq__(x) - def compatibleTo(self, x): - return self.__eq__(x) - - def subsetOf(self, x): - return self.__eq__(x) - - def unitedWith(self, x): - if self is x: - return self - x = DomainTuple.make(x) - if self is not x: - raise ValueError("domain mismatch") - return self - def __str__(self): res = "DomainTuple, len: " + str(len(self)) for i in self: diff --git a/nifty5/field.py b/nifty5/field.py index bd7102d0..87514351 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -109,7 +109,7 @@ class Field(object): @staticmethod def from_local_data(domain, arr): return Field(DomainTuple.make(domain), - dobj.from_local_data(domain.shape, arr)) + dobj.from_local_data(domain.shape, arr)) def to_global_data(self): """Returns an array containing the full data of the field. diff --git a/nifty5/library/amplitude_model.py b/nifty5/library/amplitude_model.py index 6878e536..c3f3e4f9 100644 --- a/nifty5/library/amplitude_model.py +++ b/nifty5/library/amplitude_model.py @@ -58,7 +58,7 @@ def make_amplitude_model(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv, fields = {keys[0]: Field.from_random('normal', dof_space), keys[1]: Field.from_random('normal', param_space)} - position = MultiField(fields) + position = MultiField.from_dict(fields) dof_space = position[keys[0]].domain[0] kern = lambda k: _ceps_kernel(dof_space, k, ceps_a, ceps_k) diff --git a/nifty5/multi/block_diagonal_operator.py b/nifty5/multi/block_diagonal_operator.py index 4a41ac6f..99877d43 100644 --- a/nifty5/multi/block_diagonal_operator.py +++ b/nifty5/multi/block_diagonal_operator.py @@ -31,12 +31,15 @@ class BlockDiagonalOperator(EndomorphicOperator): def apply(self, x, mode): self._check_input(x, mode) - return MultiField(x.domain, tuple(self._operators[key].apply(x._val[i], mode=mode) for i, key in enumerate(x.keys()))) + val = tuple(self._operators[key].apply(x._val[i], mode=mode) + for i, key in enumerate(x.keys())) + return MultiField(self._domain, val) def draw_sample(self, from_inverse=False, dtype=np.float64): dtype = MultiField.build_dtype(dtype, self._domain) - return MultiField.from_dict({key: op.draw_sample(from_inverse, dtype[key]) - for key, op in self._operators.items()}) + val = tuple(self._operators[key].draw_sample(from_inverse, dtype[key]) + for key in self._domain._keys) + return MultiField(self._domain, val) def _combine_chain(self, op): res = {} diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py index 85c1b597..6497c780 100644 --- a/nifty5/multi/multi_domain.py +++ b/nifty5/multi/multi_domain.py @@ -61,52 +61,3 @@ class MultiDomain(object): def __ne__(self, x): return not self.__eq__(x) - - def compatibleTo(self, x): - if self is x: - return True - x = MultiDomain.make(x) - if self is x: - return True - if (self, x) in MultiDomain._compatCache: - return True - commonKeys = set(self.keys()) & set(x.keys()) - for key in commonKeys: - if self[key] is not x[key]: - return False - MultiDomain._compatCache.add((self, x)) - MultiDomain._compatCache.add((x, self)) - return True - - def subsetOf(self, x): - if self is x: - return True - x = MultiDomain.make(x) - if self is x: - return True - if len(x) == 0: - return True - if (self, x) in MultiDomain._subsetCache: - return True - for key in self.keys(): - if key not in x: - return False - if self[key] is not x[key]: - return False - MultiDomain._subsetCache.add((self, x)) - return True - - def unitedWith(self, x): - if self is x: - return self - x = MultiDomain.make(x) - if self is x: - return self - if not self.compatibleTo(x): - raise ValueError("domain mismatch") - res = {} - for key, val in self.items(): - res[key] = val - for key, val in x.items(): - res[key] = val - return MultiDomain.make(res) diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py index 65ee8287..d43e7f58 100644 --- a/nifty5/multi/multi_field.py +++ b/nifty5/multi/multi_field.py @@ -103,7 +103,7 @@ class MultiField(object): # dtype = MultiField.build_dtype(dtype, domain) return MultiField( domain, tuple(Field.from_random(random_type, dom, dtype, **kwargs) - for dom in domain._domains)) + for dom in domain._domains)) def _check_domain(self, other): if other._domain is not self._domain: @@ -131,13 +131,14 @@ class MultiField(object): for dom in domain._domains)) def to_global_data(self): - return {key: val.to_global_data() for key, val in zip(self._domain.keys(), self._val)} + return {key: val.to_global_data() + for key, val in zip(self._domain.keys(), self._val)} @staticmethod def from_global_data(domain, arr, sum_up=False): return MultiField(domain, tuple(Field.from_global_data(domain[key], - arr[key], sum_up) - for key in domain.keys())) + arr[key], sum_up) + for key in domain.keys())) def norm(self): """ Computes the L2-norm of the field values. diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py index d6f74fc0..00660f17 100644 --- a/nifty5/operators/linear_operator.py +++ b/nifty5/operators/linear_operator.py @@ -282,5 +282,5 @@ class LinearOperator(NiftyMetaBase()): def _check_input(self, x, mode): self._check_mode(mode) - if not self._dom(mode).subsetOf(x.domain): + if self._dom(mode) is not x.domain: raise ValueError("The operator's and field's domains don't match.") diff --git a/nifty5/operators/selection_operator.py b/nifty5/operators/selection_operator.py index 73ecdf1b..daf31b92 100644 --- a/nifty5/operators/selection_operator.py +++ b/nifty5/operators/selection_operator.py @@ -17,6 +17,7 @@ # and financially supported by the Studienstiftung des deutschen Volkes. from .linear_operator import LinearOperator +from ..multi.multi_domain import MultiDomain class SelectionOperator(LinearOperator): @@ -31,10 +32,7 @@ class SelectionOperator(LinearOperator): String identifier of the wanted subdomain """ def __init__(self, domain, key): - from ..multi.multi_domain import MultiDomain - if not isinstance(domain, MultiDomain): - raise TypeError("Domain must be a MultiDomain") - self._domain = domain + self._domain = MultiDomain.make(domain) self._key = key @property @@ -55,4 +53,6 @@ class SelectionOperator(LinearOperator): return x[self._key] else: from ..multi.multi_field import MultiField - return MultiField.from_dict({self._key: x}) + rval = [None]*len(self._domain) + rval[self._domain._dict[self._key]] = x + return MultiField(self._domain, tuple(rval)) diff --git a/nifty5/operators/sum_operator.py b/nifty5/operators/sum_operator.py index c8292e30..c5946c4d 100644 --- a/nifty5/operators/sum_operator.py +++ b/nifty5/operators/sum_operator.py @@ -46,8 +46,8 @@ class SumOperator(LinearOperator): dom = ops[0].domain tgt = ops[0].target for op in ops[1:]: - dom = dom.unitedWith(op.domain) - tgt = tgt.unitedWith(op.target) + if dom is not op.domain or tgt is not op.target: + raise ValueError("Domain mismatch") # Step 2: unpack SumOperators opsnew = [] -- GitLab