Commit 69a50640 authored by Martin Reinecke's avatar Martin Reinecke

intermediate commit

parent 75f50648
......@@ -22,12 +22,12 @@ from .model import Model
def _joint_position(model1, model2):
a = model1.position._val
b = model2.position._val
a = model1.position.to_dict()
b = model2.position.to_dict()
# Note: In python >3.5 one could do {**a, **b}
ab = dict(a)
ab = a
ab.update(b)
return MultiField(ab)
return MultiField.from_dict(ab)
class ScalarMul(Model):
......
......@@ -31,12 +31,11 @@ class BlockDiagonalOperator(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return MultiField({key: op.apply(x[key], mode=mode)
for key, op in self._operators.items()})
return MultiField(x.domain, tuple(self._operators[key].apply(x._val[i], mode=mode) for i, key in enumerate(x.keys())))
def draw_sample(self, from_inverse=False, dtype=np.float64):
dtype = MultiField.build_dtype(dtype, self._domain)
return MultiField({key: op.draw_sample(from_inverse, dtype[key])
return MultiField.from_dict({key: op.draw_sample(from_inverse, dtype[key])
for key, op in self._operators.items()})
def _combine_chain(self, op):
......
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..utilities import frozendict
class MultiDomain(frozendict):
class MultiDomain(object):
_domainCache = {}
_subsetCache = set()
_compatCache = set()
def __init__(self, domain, _callingfrommake=False):
def __init__(self, dict, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError(
'To create a MultiDomain call `MultiDomain.make()`.')
super(MultiDomain, self).__init__(domain)
self._keys = tuple(sorted(dict.keys()))
self._domains = tuple(dict[key] for key in self._keys)
self._dict = frozendict({key: i for i, key in enumerate(self._keys)})
@staticmethod
def make(domain):
if isinstance(domain, MultiDomain):
return domain
if not isinstance(domain, dict):
def make(inp):
if isinstance(inp, MultiDomain):
return inp
if not isinstance(inp, dict):
raise TypeError("dict expected")
tmp = {}
for key, value in domain.items():
for key, value in inp.items():
if not isinstance(key, str):
raise TypeError("keys must be strings")
tmp[key] = DomainTuple.make(value)
domain = frozendict(tmp)
obj = MultiDomain._domainCache.get(domain)
tmp = frozendict(tmp)
obj = MultiDomain._domainCache.get(tmp)
if obj is not None:
return obj
obj = MultiDomain(domain, _callingfrommake=True)
MultiDomain._domainCache[domain] = obj
obj = MultiDomain(tmp, _callingfrommake=True)
MultiDomain._domainCache[tmp] = obj
return obj
def keys(self):
return self._keys
def domains(self):
return self._domains
def items(self):
return zip(self._keys, self._domains)
def __getitem__(self, key):
return self._domains[self._dict[key]]
def __len__(self):
return len(self._keys)
def __hash__(self):
return self._keys.__hash__() ^ self._domains.__hash__()
def __eq__(self, x):
if self is x:
return True
x = MultiDomain.make(x)
return self is x
return self is MultiDomain.make(x)
def __ne__(self, x):
return not self.__eq__(x)
def __hash__(self):
return super(MultiDomain, self).__hash__()
def compatibleTo(self, x):
if self is x:
return True
......
......@@ -22,86 +22,122 @@ from .multi_domain import MultiDomain
from ..utilities import frozendict
# ways of creating MultiFields:
# - (Field)
# - (Field, name)
# - (dict {string, Field})
# - MultiDomain, dict(string, Field)
# new methods
# .field(name)
# .domain(name)
class MultiField(object):
def __init__(self, val):
def __init__(self, domain, val):
"""
Parameters
----------
val : dict
domain: MultiDomain
val: tuple of Fields
"""
self._val = frozendict(val)
self._domain = MultiDomain.make(
{key: val.domain for key, val in self._val.items()})
if not isinstance(domain, MultiDomain):
raise TypeError("domain must be of type MultiDomain")
if not isinstance(val, tuple):
raise TypeError("val must be a tuple")
if len(val) != len(domain):
raise ValueError("length mismatch")
for i, v in enumerate(val):
if isinstance(v, Field):
if v._domain is not domain._domains[i]:
raise ValueError("domain mismatch")
elif v is not None:
raise TypeError("bad entry in val")
self._domain = domain
self._val = val
@staticmethod
def from_dict(dict):
domain = MultiDomain.make({key: v._domain for key, v in dict.items()})
return MultiField(domain, tuple(dict[key] for key in domain._keys))
def to_dict(self):
return {key: val for key, val in zip(self._domain._keys, self._val)}
def __getitem__(self, key):
return self._val[key]
return self._val[self._domain._dict[key]]
def keys(self):
return self._val.keys()
return self._domain.keys()
def items(self):
return self._val.items()
return zip(self._domain._keys, self._val)
def values(self):
return self._val.values()
return self._val
@property
def domain(self):
return self._domain
@property
def dtype(self):
return {key: val.dtype for key, val in self._val.items()}
# @property
# def dtype(self):
# return {key: val.dtype for key, val in self._val.items()}
def _transform(self, op):
return MultiField(
self._domain,
tuple(op(v) if v is not None else None for v in self._val))
@property
def real(self):
"""MultiField : The real part of the multi field"""
return MultiField({key: field.real for key, field in self.items()})
return self._transform(lambda x: x.real)
@property
def imag(self):
"""MultiField : The imaginary part of the multi field"""
return MultiField({key: field.imag for key, field in self.items()})
return self._transform(lambda x: x.imag)
@staticmethod
def from_random(random_type, domain, dtype=np.float64, **kwargs):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.from_random(random_type, domain[key],
dtype[key], **kwargs)
for key in sorted(domain.keys())})
domain = MultiDomain.make(domain)
# dtype = MultiField.build_dtype(dtype, domain)
return MultiField(
domain, tuple(Field.from_random(random_type, dom, dtype, **kwargs)
for dom in domain._domains))
def _check_domain(self, other):
if other._domain != self._domain:
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
def vdot(self, x):
result = 0.
self._check_domain(x)
for key, sub_field in self.items():
result += sub_field.vdot(x[key])
for v1, v2 in zip(self._val, x._val):
if v1 is not None and v2 is not None:
result += v1.vdot(v2)
return result
@staticmethod
def build_dtype(dtype, domain):
if isinstance(dtype, dict):
return dtype
if dtype is None:
dtype = np.float64
return {key: dtype for key in domain.keys()}
# @staticmethod
# def build_dtype(dtype, domain):
# if isinstance(dtype, dict):
# return dtype
# if dtype is None:
# dtype = np.float64
# return {key: dtype for key in domain.keys()}
@staticmethod
def full(domain, val):
return MultiField({key: Field.full(dom, val)
for key, dom in domain.items()})
return MultiField(domain, tuple(Field.full(dom, val)
for dom in domain._domains))
def to_global_data(self):
return {key: val.to_global_data() for key, val in self._val.items()}
return {key: val.to_global_data() for key, val in zip(self._domain.keys(), self._val)}
@staticmethod
def from_global_data(domain, arr, sum_up=False):
return MultiField({key: Field.from_global_data(domain[key],
val, sum_up)
for key, val in arr.items()})
return MultiField(domain, tuple(Field.from_global_data(domain[key],
arr[key], sum_up)
for key in domain.keys()))
def norm(self):
""" Computes the L2-norm of the field values.
......@@ -124,24 +160,23 @@ class MultiField(object):
return abs(self.vdot(x=self))
def __neg__(self):
return MultiField({key: -val for key, val in self.items()})
return self._transform(lambda x: -x)
def __abs__(self):
return MultiField({key: abs(val) for key, val in self.items()})
return self._transform(lambda x: abs(x))
def conjugate(self):
return MultiField({key: sub_field.conjugate()
for key, sub_field in self.items()})
return self._transform(lambda x: x.conjugate())
def all(self):
for v in self.values():
if not v.all():
for v in self._val:
if v is None or not v.all():
return False
return True
def any(self):
for v in self.values():
if v.any():
for v in self._val:
if v is not None and v.any():
return True
return False
......@@ -152,10 +187,10 @@ class MultiField(object):
return True
if not isinstance(other, MultiField):
return False
if self._domain != other._domain:
if self._domain is not other._domain:
return False
for key, val in self._val.items():
if not val.isEquivalentTo(other[key]):
for v1, v2 in zip(self._val, other._val):
if not v1.isEquivalentTo(v2):
return False
return True
......@@ -186,39 +221,24 @@ for op in ["__add__", "__radd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
res = []
if isinstance(other, MultiField):
if self._domain == other._domain:
result_val = {key: getattr(sub_field, op)(other[key])
for key, sub_field in self.items()}
else:
if not self._domain.compatibleTo(other.domain):
raise ValueError("domain mismatch")
s1 = set(self._domain.keys())
s2 = set(other._domain.keys())
common_keys = s1 & s2
only_self_keys = s1 - s2
only_other_keys = s2 - s1
result_val = {}
for key in common_keys:
result_val[key] = getattr(self[key], op)(other[key])
if op in ("__add__", "__radd__"):
for key in only_self_keys:
result_val[key] = self[key]
for key in only_other_keys:
result_val[key] = other[key]
elif op in ("__mul__", "__rmul__"):
pass
if self._domain is not other._domain:
raise ValueError("domain mismatch")
for v1, v2 in zip(self._val, other._val):
if v1 is not None:
if v2 is None:
res.append(getattr(v1, op)(v1*0))
else:
res.append(getattr(v1, op)(v2))
else:
for key in only_self_keys:
result_val[key] = getattr(
self[key], op)(self[key]*0.)
for key in only_other_keys:
result_val[key] = getattr(
other[key]*0., op)(other[key])
if v2 is None:
res.append(None)
else:
res.append(getattr(v2*0, op)(v2))
return MultiField(self._domain, tuple(res))
else:
result_val = {key: getattr(val, op)(other)
for key, val in self.items()}
return MultiField(result_val)
return self._transform(lambda x: getattr(x, op)(other))
return func2
setattr(MultiField, op, func(op))
......
......@@ -55,4 +55,4 @@ class SelectionOperator(LinearOperator):
return x[self._key]
else:
from ..multi.multi_field import MultiField
return MultiField({self._key: x})
return MultiField.from_dict({self._key: x})
......@@ -77,7 +77,7 @@ class Energy_Tests(unittest.TestCase):
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
Dist = ift.PowerDistributor(target=hspace, power_space=pspace)
xi0 = ift.Field.from_random(domain=hspace, random_type='normal')
xi0_var = ift.Variable(ift.MultiField({'xi': xi0}))['xi']
xi0_var = ift.Variable(ift.MultiField.from_dict({'xi': xi0}))['xi']
def pspec(k): return 1 / (1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec)
......
......@@ -34,6 +34,6 @@ class Model_Tests(unittest.TestCase):
S = ift.ScalingOperator(1., space)
s1 = S.draw_sample()
s2 = S.draw_sample()
s1_var = ift.Variable(ift.MultiField({'s1': s1}))['s1']
s2_var = ift.Variable(ift.MultiField({'s2': s2}))['s2']
s1_var = ift.Variable(ift.MultiField.from_dict({'s1': s1}))['s1']
s2_var = ift.Variable(ift.MultiField.from_dict({'s2': s2}))['s2']
ift.extra.check_value_gradient_consistency(s1_var*s2_var)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment