Commit a0e5a346 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent bbdb9944
Pipeline #28368 passed with stages
in 2 minutes and 36 seconds
......@@ -746,20 +746,6 @@ class Field(object):
raise ValueError("domains are incompatible.")
self.local_data[()] = other.local_data[()]
def _binary_helper(self, other, op):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self._domain, tval)
return NotImplemented
def __repr__(self):
return "<nifty4.Field>"
......@@ -778,30 +764,38 @@ for op in ["__add__", "__radd__", "__iadd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
return self._binary_helper(other, op=op)
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self._domain, tval)
return NotImplemented
return func2
setattr(Field, op, func(op))
# Arithmetic functions working on Fields
def _math_helper(x, function, out):
function = getattr(dobj, function)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
function(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=function(x.val))
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
fu = getattr(dobj, f)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
......@@ -50,17 +50,17 @@ class MultiField(object):
@staticmethod
def zeros(domain, dtype=None):
return MultiField({key: Field.zeros(dom, dtype=dtype)
for key, dom in domain.items()})
for key, dom in domain.items()})
@staticmethod
def ones(domain, dtype=None):
return MultiField({key: Field.ones(dom, dtype=dtype)
for key, dom in domain.items()})
for key, dom in domain.items()})
@staticmethod
def empty(domain, dtype=None):
return MultiField({key: Field.empty(dom, dtype=dtype)
for key, dom in domain.items()})
for key, dom in domain.items()})
def norm(self):
""" Computes the L2-norm of the field values.
......@@ -70,23 +70,14 @@ class MultiField(object):
norm : float
The L2-norm of the field values.
"""
return np.sqrt(np.abs(self.vdot(x=self)))
def _binary_helper(self, other, op):
if isinstance(other, MultiField):
self._check_domain(other)
result_val = {key: getattr(sub_field,op)(other[key])
for key, sub_field in self.items()}
else:
result_val = {key: getattr(val,op)(other) for key, val in self.items()}
return MultiField(result_val)
def __neg__(self):
return MultiField({key: -val for key, val in self.items()})
def conjugate(self):
return MultiField({key: sub_field.conjugate() for key, sub_field in self.items()})
return MultiField({key: sub_field.conjugate()
for key, sub_field in self.items()})
for op in ["__add__", "__radd__", "__iadd__",
......@@ -99,6 +90,13 @@ for op in ["__add__", "__radd__", "__iadd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
return self._binary_helper(other, op=op)
if isinstance(other, MultiField):
self._check_domain(other)
result_val = {key: getattr(sub_field, op)(other[key])
for key, sub_field in self.items()}
else:
result_val = {key: getattr(val, op)(other)
for key, val in self.items()}
return MultiField(result_val)
return func2
setattr(MultiField, op, func(op))
from ..operators.linear_operator import LinearOperator
class MultiLinearOperator(LinearOperator):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment