Commit 3ad02747 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'update_multi_fields' into 'NIFTy_5'

Update multi fields

See merge request ift/nifty-dev!99
parents 66934b83 31960d21
......@@ -24,6 +24,7 @@ from . import utilities
from .compat import *
from .field import Field
from .multi_domain import MultiDomain
from .domain_tuple import DomainTuple
class MultiField(object):
......@@ -52,6 +53,10 @@ class MultiField(object):
@staticmethod
def from_dict(dict, domain=None):
if domain is None:
for dd in dict.values():
if not isinstance(dd.domain, DomainTuple):
raise TypeError('Values of dictionary need to be Fields '
'defined on DomainTuples.')
domain = MultiDomain.make({key: v._domain
for key, v in dict.items()})
res = tuple(dict[key] if key in dict else Field(dom, 0)
......@@ -61,6 +66,11 @@ class MultiField(object):
def to_dict(self):
return {key: val for key, val in zip(self._domain.keys(), self._val)}
def update(self, other):
foo = self.to_dict()
foo.update(other.to_dict())
return MultiField.from_dict(foo)
def __getitem__(self, key):
return self._val[self._domain.idx[key]]
......
......@@ -56,3 +56,19 @@ class Test_Functionality(unittest.TestCase):
f1 = op2(ift.full(dom, 1))
for val in f1.values():
assert_equal((val == 40).all(), True)
def test_update(self):
dom = ift.RGSpace(10)
f1 = ift.from_random('normal', domain=dom)
f2 = ift.from_random('normal', domain=dom)
f_new = ift.MultiField.from_dict({'dom1': f1, 'dom2': f2})
f3 = ift.from_random('normal', domain=dom)
f4 = ift.from_random('normal', domain=dom)
f5 = ift.from_random('normal', domain=dom)
f_old = ift.MultiField.from_dict({'dom1': f3, 'dom2': f4, 'dom3': f5})
updated = f_old.update(f_new).to_dict()
updated_true = f_old.to_dict()
updated_true.update(f_new.to_dict())
for key, val in updated.items():
assert_equal((val == updated_true[key]).all(), True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment