Commit 5b237675 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

streamlining

parent 6200a09c
...@@ -619,6 +619,16 @@ class Field(object): ...@@ -619,6 +619,16 @@ class Field(object):
def positive_tanh(self): def positive_tanh(self):
return 0.5*(1.+self.tanh()) return 0.5*(1.+self.tanh())
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
return Field(self._domain, f(other._val))
if np.isscalar(other):
return Field(self._domain, f(other))
return NotImplemented
for op in ["__add__", "__radd__", for op in ["__add__", "__radd__",
"__sub__", "__rsub__", "__sub__", "__rsub__",
...@@ -630,16 +640,7 @@ for op in ["__add__", "__radd__", ...@@ -630,16 +640,7 @@ for op in ["__add__", "__radd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]: "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op): def func(op):
def func2(self, other): def func2(self, other):
# if other is a field, make sure that the domains match return self._binary_op(other, op)
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self._val, op)(other._val)
return Field(self._domain, tval)
if np.isscalar(other):
tval = getattr(self._val, op)(other)
return Field(self._domain, tval)
return NotImplemented
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
...@@ -655,7 +656,6 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -655,7 +656,6 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
for f in ["sqrt", "exp", "log", "tanh"]: for f in ["sqrt", "exp", "log", "tanh"]:
def func(f): def func(f):
def func2(self): def func2(self):
fu = getattr(dobj, f) return Field(self._domain, getattr(dobj, f)(self.val))
return Field(domain=self._domain, val=fu(self.val))
return func2 return func2
setattr(Field, f, func(f)) setattr(Field, f, func(f))
...@@ -200,27 +200,30 @@ class MultiField(object): ...@@ -200,27 +200,30 @@ class MultiField(object):
return False return False
def extract(self, subset): def extract(self, subset):
if isinstance(subset, MultiDomain): if subset is self._domain:
if subset is self._domain: return self
return self return MultiField(subset,
return MultiField(subset, tuple(self[key] for key in subset.keys()))
tuple(self[key] for key in subset.keys()))
else:
return MultiField.from_dict({key: self[key] for key in subset})
def unite(self, other): def unite(self, other):
if self._domain is other._domain: if self._domain is other._domain:
return self + other return self + other
return self.combine((self, other)) res = self.to_dict()
for key, val in other.items():
@staticmethod res[key] = res[key]+val if key in res else val
def combine(fields):
res = {}
for f in fields:
for key, val in f.items():
res[key] = res[key]+val if key in res else val
return MultiField.from_dict(res) return MultiField.from_dict(res)
def _binary_op(self, other, op):
f = getattr(Field, op)
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(f(v1, v2)
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(f(v1, other) for v1 in self._val)
return MultiField(self._domain, val)
for op in ["__add__", "__radd__", for op in ["__add__", "__radd__",
"__sub__", "__rsub__", "__sub__", "__rsub__",
...@@ -232,14 +235,7 @@ for op in ["__add__", "__radd__", ...@@ -232,14 +235,7 @@ for op in ["__add__", "__radd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]: "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op): def func(op):
def func2(self, other): def func2(self, other):
if isinstance(other, MultiField): return self._binary_op(other, op)
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(getattr(v1, op)(v2)
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(getattr(v1, op)(other) for v1 in self._val)
return MultiField(self._domain, val)
return func2 return func2
setattr(MultiField, op, func(op)) setattr(MultiField, op, func(op))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment