Commit d07a537e authored by Martin Reinecke's avatar Martin Reinecke

more work

parent a535e0e8
......@@ -22,7 +22,7 @@ def make_correlated_field(s_space, amplitude_model):
position['xi'] = Field.from_random('normal', h_space)
position['tau'] = amplitude_model.position['tau']
position['phi'] = amplitude_model.position['phi']
position = MultiField(position)
position = MultiField.from_dict(position)
xi = Variable(position)['xi']
A = power_distributor(amplitude_model)
......@@ -70,7 +70,7 @@ def make_mf_correlated_field(s_space_spatial, s_space_energy,
a = a_spatial*a_energy
A = pd(a)
position = MultiField({'xi': Field.from_random('normal', h_space)})
position = MultiField.from_dict({'xi': Field.from_random('normal', h_space)})
xi = Variable(position)['xi']
correlated_field_h = A*xi
correlated_field = ht(correlated_field_h)
......
......@@ -22,6 +22,7 @@ class PointSources(Model):
@memo
def value(self):
points = self.position['points'].local_data
# MR FIXME?!
points = np.clip(points, None, 8.2)
points = Field.from_local_data(self.position['points'].domain, points)
return self.IG(points, self._alpha, self._q)
......@@ -40,7 +41,8 @@ class PointSources(Model):
outer = 1/outer_inv
grad = Field.from_local_data(self.position['points'].domain,
inner*outer)
grad = makeOp(MultiField({'points': grad}))
grad = makeOp(MultiField.from_dict({"points": grad},
self.position._domain))
return SelectionOperator(grad.target, 'points')*grad
@staticmethod
......
......@@ -5,7 +5,7 @@ from .multi_field import MultiField
class BlockDiagonalOperator(EndomorphicOperator):
def __init__(self, operators):
def __init__(self, domain, operators):
"""
Parameters
----------
......@@ -14,12 +14,12 @@ class BlockDiagonalOperator(EndomorphicOperator):
LinearOperators as items
"""
super(BlockDiagonalOperator, self).__init__()
self._operators = operators
self._domain = MultiDomain.make(
{key: op.domain for key, op in self._operators.items()})
self._domain = domain
self._ops = tuple(operators[key] for key in self.domain.keys())
self._cap = self._all_ops
for op in self._operators.values():
self._cap &= op.capability
for op in self._ops:
if op is not None:
self._cap &= op.capability
@property
def domain(self):
......@@ -31,27 +31,27 @@ class BlockDiagonalOperator(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
val = tuple(self._operators[key].apply(x._val[i], mode=mode)
for i, key in enumerate(x.keys()))
val = tuple(op.apply(v, mode=mode) if op is not None else None
for op, v in zip(self._ops, x.values()))
return MultiField(self._domain, val)
def draw_sample(self, from_inverse=False, dtype=np.float64):
dtype = MultiField.build_dtype(dtype, self._domain)
val = tuple(self._operators[key].draw_sample(from_inverse, dtype[key])
for key in self._domain._keys)
return MultiField(self._domain, val)
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# dtype = MultiField.build_dtype(dtype, self._domain)
# val = tuple(op.draw_sample(from_inverse, dtype)
# for op in self._op)
# return MultiField(self._domain, val)
def _combine_chain(self, op):
res = {}
for key in self._operators.keys():
res[key] = self._operators[key]*op._operators[key]
if self._domain is not op._domain:
raise ValueError("domain mismatch")
res = {key : v1*v2 for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
if self._domain is not op._domain:
raise ValueError("domain mismatch")
res = {}
for key in self._operators.keys():
res[key] = SumOperator.make([self._operators[key],
op._operators[key]],
[selfneg, opneg])
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops):
res[key] = SumOperator.make([v1, v2], [selfneg, opneg])
return BlockDiagonalOperator(res)
......@@ -6,8 +6,6 @@ from ..utilities import frozendict
class MultiDomain(object):
_domainCache = {}
_subsetCache = set()
_compatCache = set()
def __init__(self, dict, _callingfrommake=False):
if not _callingfrommake:
......@@ -15,7 +13,7 @@ class MultiDomain(object):
'To create a MultiDomain call `MultiDomain.make()`.')
self._keys = tuple(sorted(dict.keys()))
self._domains = tuple(dict[key] for key in self._keys)
self._dict = frozendict({key: i for i, key in enumerate(self._keys)})
self._idx = frozendict({key: i for i, key in enumerate(self._keys)})
@staticmethod
def make(inp):
......@@ -42,11 +40,15 @@ class MultiDomain(object):
def domains(self):
return self._domains
@property
def idx(self):
return self._idx
def items(self):
return zip(self._keys, self._domains)
def __getitem__(self, key):
return self._domains[self._dict[key]]
return self._domains[self._idx[key]]
def __len__(self):
return len(self._keys)
......
......@@ -55,21 +55,24 @@ class MultiField(object):
self._val = val
@staticmethod
def from_dict(dict):
domain = MultiDomain.make({key: v._domain for key, v in dict.items()})
return MultiField(domain, tuple(dict[key] for key in domain._keys))
def from_dict(dict, domain=None):
if domain is None:
domain = MultiDomain.make({key: v._domain
for key, v in dict.items()})
return MultiField(domain, tuple(dict[key] if key in dict else None
for key in domain.keys()))
def to_dict(self):
return {key: val for key, val in zip(self._domain._keys, self._val)}
return {key: val for key, val in zip(self._domain.keys(), self._val)}
def __getitem__(self, key):
return self._val[self._domain._dict[key]]
return self._val[self._domain.idx[key]]
def keys(self):
return self._domain.keys()
def items(self):
return zip(self._domain._keys, self._val)
return zip(self._domain.keys(), self._val)
def values(self):
return self._val
......
......@@ -53,6 +53,4 @@ class SelectionOperator(LinearOperator):
return x[self._key]
else:
from ..multi.multi_field import MultiField
rval = [None]*len(self._domain)
rval[self._domain._dict[self._key]] = x
return MultiField(self._domain, tuple(rval))
return MultiField.from_dict({self._key: x}, self._domain)
......@@ -228,10 +228,12 @@ def makeDomain(domain):
def makeOp(input):
if input is None:
return None
if isinstance(input, Field):
return DiagonalOperator(input)
if isinstance(input, MultiField):
return BlockDiagonalOperator({key: makeOp(val)
return BlockDiagonalOperator(input.domain, {key: makeOp(val)
for key, val in input.items()})
raise NotImplementedError
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment