Commit 5b237675 authored by Martin Reinecke's avatar Martin Reinecke

streamlining

parent 6200a09c
......@@ -619,6 +619,16 @@ class Field(object):
def positive_tanh(self):
return 0.5*(1.+self.tanh())
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
return Field(self._domain, f(other._val))
if np.isscalar(other):
return Field(self._domain, f(other))
return NotImplemented
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......@@ -630,16 +640,7 @@ for op in ["__add__", "__radd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self._val, op)(other._val)
return Field(self._domain, tval)
if np.isscalar(other):
tval = getattr(self._val, op)(other)
return Field(self._domain, tval)
return NotImplemented
return self._binary_op(other, op)
return func2
setattr(Field, op, func(op))
......@@ -655,7 +656,6 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
fu = getattr(dobj, f)
return Field(domain=self._domain, val=fu(self.val))
return Field(self._domain, getattr(dobj, f)(self.val))
return func2
setattr(Field, f, func(f))
......@@ -200,27 +200,30 @@ class MultiField(object):
return False
def extract(self, subset):
if isinstance(subset, MultiDomain):
if subset is self._domain:
return self
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
else:
return MultiField.from_dict({key: self[key] for key in subset})
if subset is self._domain:
return self
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
def unite(self, other):
if self._domain is other._domain:
return self + other
return self.combine((self, other))
@staticmethod
def combine(fields):
res = {}
for f in fields:
for key, val in f.items():
res[key] = res[key]+val if key in res else val
res = self.to_dict()
for key, val in other.items():
res[key] = res[key]+val if key in res else val
return MultiField.from_dict(res)
def _binary_op(self, other, op):
f = getattr(Field, op)
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(f(v1, v2)
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(f(v1, other) for v1 in self._val)
return MultiField(self._domain, val)
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......@@ -232,14 +235,7 @@ for op in ["__add__", "__radd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(getattr(v1, op)(v2)
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(getattr(v1, op)(other) for v1 in self._val)
return MultiField(self._domain, val)
return self._binary_op(other, op)
return func2
setattr(MultiField, op, func(op))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment