From d07a537eda98c052d849087e93083e703d7230c8 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Sat, 7 Jul 2018 16:08:50 +0200 Subject: [PATCH] more work --- nifty5/library/correlated_fields.py | 4 +-- nifty5/library/point_sources.py | 4 ++- nifty5/multi/block_diagonal_operator.py | 40 ++++++++++++------------- nifty5/multi/multi_domain.py | 10 ++++--- nifty5/multi/multi_field.py | 15 ++++++---- nifty5/operators/selection_operator.py | 4 +-- nifty5/sugar.py | 4 ++- 7 files changed, 44 insertions(+), 37 deletions(-) diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index a6f643925..5f92291fe 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -22,7 +22,7 @@ def make_correlated_field(s_space, amplitude_model): position['xi'] = Field.from_random('normal', h_space) position['tau'] = amplitude_model.position['tau'] position['phi'] = amplitude_model.position['phi'] - position = MultiField(position) + position = MultiField.from_dict(position) xi = Variable(position)['xi'] A = power_distributor(amplitude_model) @@ -70,7 +70,7 @@ def make_mf_correlated_field(s_space_spatial, s_space_energy, a = a_spatial*a_energy A = pd(a) - position = MultiField({'xi': Field.from_random('normal', h_space)}) + position = MultiField.from_dict({'xi': Field.from_random('normal', h_space)}) xi = Variable(position)['xi'] correlated_field_h = A*xi correlated_field = ht(correlated_field_h) diff --git a/nifty5/library/point_sources.py b/nifty5/library/point_sources.py index 67e2b162e..4ba9572cb 100644 --- a/nifty5/library/point_sources.py +++ b/nifty5/library/point_sources.py @@ -22,6 +22,7 @@ class PointSources(Model): @memo def value(self): points = self.position['points'].local_data + # MR FIXME?! points = np.clip(points, None, 8.2) points = Field.from_local_data(self.position['points'].domain, points) return self.IG(points, self._alpha, self._q) @@ -40,7 +41,8 @@ class PointSources(Model): outer = 1/outer_inv grad = Field.from_local_data(self.position['points'].domain, inner*outer) - grad = makeOp(MultiField({'points': grad})) + grad = makeOp(MultiField.from_dict({"points": grad}, + self.position._domain)) return SelectionOperator(grad.target, 'points')*grad @staticmethod diff --git a/nifty5/multi/block_diagonal_operator.py b/nifty5/multi/block_diagonal_operator.py index 99877d430..0546f35a1 100644 --- a/nifty5/multi/block_diagonal_operator.py +++ b/nifty5/multi/block_diagonal_operator.py @@ -5,7 +5,7 @@ from .multi_field import MultiField class BlockDiagonalOperator(EndomorphicOperator): - def __init__(self, operators): + def __init__(self, domain, operators): """ Parameters ---------- @@ -14,12 +14,12 @@ class BlockDiagonalOperator(EndomorphicOperator): LinearOperators as items """ super(BlockDiagonalOperator, self).__init__() - self._operators = operators - self._domain = MultiDomain.make( - {key: op.domain for key, op in self._operators.items()}) + self._domain = domain + self._ops = tuple(operators[key] for key in self.domain.keys()) self._cap = self._all_ops - for op in self._operators.values(): - self._cap &= op.capability + for op in self._ops: + if op is not None: + self._cap &= op.capability @property def domain(self): @@ -31,27 +31,27 @@ class BlockDiagonalOperator(EndomorphicOperator): def apply(self, x, mode): self._check_input(x, mode) - val = tuple(self._operators[key].apply(x._val[i], mode=mode) - for i, key in enumerate(x.keys())) + val = tuple(op.apply(v, mode=mode) if op is not None else None + for op, v in zip(self._ops, x.values())) return MultiField(self._domain, val) - def draw_sample(self, from_inverse=False, dtype=np.float64): - dtype = MultiField.build_dtype(dtype, self._domain) - val = tuple(self._operators[key].draw_sample(from_inverse, dtype[key]) - for key in self._domain._keys) - return MultiField(self._domain, val) +# def draw_sample(self, from_inverse=False, dtype=np.float64): +# dtype = MultiField.build_dtype(dtype, self._domain) +# val = tuple(op.draw_sample(from_inverse, dtype) +# for op in self._op) +# return MultiField(self._domain, val) def _combine_chain(self, op): - res = {} - for key in self._operators.keys(): - res[key] = self._operators[key]*op._operators[key] + if self._domain is not op._domain: + raise ValueError("domain mismatch") + res = {key : v1*v2 for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)} return BlockDiagonalOperator(res) def _combine_sum(self, op, selfneg, opneg): from ..operators.sum_operator import SumOperator + if self._domain is not op._domain: + raise ValueError("domain mismatch") res = {} - for key in self._operators.keys(): - res[key] = SumOperator.make([self._operators[key], - op._operators[key]], - [selfneg, opneg]) + for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops): + res[key] = SumOperator.make([v1, v2], [selfneg, opneg]) return BlockDiagonalOperator(res) diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py index 6497c780e..6687f165d 100644 --- a/nifty5/multi/multi_domain.py +++ b/nifty5/multi/multi_domain.py @@ -6,8 +6,6 @@ from ..utilities import frozendict class MultiDomain(object): _domainCache = {} - _subsetCache = set() - _compatCache = set() def __init__(self, dict, _callingfrommake=False): if not _callingfrommake: @@ -15,7 +13,7 @@ class MultiDomain(object): 'To create a MultiDomain call `MultiDomain.make()`.') self._keys = tuple(sorted(dict.keys())) self._domains = tuple(dict[key] for key in self._keys) - self._dict = frozendict({key: i for i, key in enumerate(self._keys)}) + self._idx = frozendict({key: i for i, key in enumerate(self._keys)}) @staticmethod def make(inp): @@ -42,11 +40,15 @@ class MultiDomain(object): def domains(self): return self._domains + @property + def idx(self): + return self._idx + def items(self): return zip(self._keys, self._domains) def __getitem__(self, key): - return self._domains[self._dict[key]] + return self._domains[self._idx[key]] def __len__(self): return len(self._keys) diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py index d43e7f58c..bb6004c5f 100644 --- a/nifty5/multi/multi_field.py +++ b/nifty5/multi/multi_field.py @@ -55,21 +55,24 @@ class MultiField(object): self._val = val @staticmethod - def from_dict(dict): - domain = MultiDomain.make({key: v._domain for key, v in dict.items()}) - return MultiField(domain, tuple(dict[key] for key in domain._keys)) + def from_dict(dict, domain=None): + if domain is None: + domain = MultiDomain.make({key: v._domain + for key, v in dict.items()}) + return MultiField(domain, tuple(dict[key] if key in dict else None + for key in domain.keys())) def to_dict(self): - return {key: val for key, val in zip(self._domain._keys, self._val)} + return {key: val for key, val in zip(self._domain.keys(), self._val)} def __getitem__(self, key): - return self._val[self._domain._dict[key]] + return self._val[self._domain.idx[key]] def keys(self): return self._domain.keys() def items(self): - return zip(self._domain._keys, self._val) + return zip(self._domain.keys(), self._val) def values(self): return self._val diff --git a/nifty5/operators/selection_operator.py b/nifty5/operators/selection_operator.py index daf31b92a..71fc3e860 100644 --- a/nifty5/operators/selection_operator.py +++ b/nifty5/operators/selection_operator.py @@ -53,6 +53,4 @@ class SelectionOperator(LinearOperator): return x[self._key] else: from ..multi.multi_field import MultiField - rval = [None]*len(self._domain) - rval[self._domain._dict[self._key]] = x - return MultiField(self._domain, tuple(rval)) + return MultiField.from_dict({self._key: x}, self._domain) diff --git a/nifty5/sugar.py b/nifty5/sugar.py index 0dd149ac5..277cb3b62 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -228,10 +228,12 @@ def makeDomain(domain): def makeOp(input): + if input is None: + return None if isinstance(input, Field): return DiagonalOperator(input) if isinstance(input, MultiField): - return BlockDiagonalOperator({key: makeOp(val) + return BlockDiagonalOperator(input.domain, {key: makeOp(val) for key, val in input.items()}) raise NotImplementedError -- GitLab