diff --git a/nifty4/data_objects/distributed_do.py b/nifty4/data_objects/distributed_do.py index 9085762cbfbd0804c28107459fc3ba73148b32d1..25947573cf91535d639035a2a688b34c75362579 100644 --- a/nifty4/data_objects/distributed_do.py +++ b/nifty4/data_objects/distributed_do.py @@ -19,6 +19,7 @@ import numpy as np from .random import Random from mpi4py import MPI +import sys _comm = MPI.COMM_WORLD ntask = _comm.Get_size() @@ -185,75 +186,6 @@ class data_object(object): else: return data_object(self._shape, tval, self._distaxis) - def __add__(self, other): - return self._binary_helper(other, op='__add__') - - def __radd__(self, other): - return self._binary_helper(other, op='__radd__') - - def __iadd__(self, other): - return self._binary_helper(other, op='__iadd__') - - def __sub__(self, other): - return self._binary_helper(other, op='__sub__') - - def __rsub__(self, other): - return self._binary_helper(other, op='__rsub__') - - def __isub__(self, other): - return self._binary_helper(other, op='__isub__') - - def __mul__(self, other): - return self._binary_helper(other, op='__mul__') - - def __rmul__(self, other): - return self._binary_helper(other, op='__rmul__') - - def __imul__(self, other): - return self._binary_helper(other, op='__imul__') - - def __div__(self, other): - return self._binary_helper(other, op='__div__') - - def __rdiv__(self, other): - return self._binary_helper(other, op='__rdiv__') - - def __idiv__(self, other): - return self._binary_helper(other, op='__idiv__') - - def __truediv__(self, other): - return self._binary_helper(other, op='__truediv__') - - def __rtruediv__(self, other): - return self._binary_helper(other, op='__rtruediv__') - - def __pow__(self, other): - return self._binary_helper(other, op='__pow__') - - def __rpow__(self, other): - return self._binary_helper(other, op='__rpow__') - - def __ipow__(self, other): - return self._binary_helper(other, op='__ipow__') - - def __lt__(self, other): - return self._binary_helper(other, op='__lt__') - - def __le__(self, other): - return self._binary_helper(other, op='__le__') - - def __ne__(self, other): - return self._binary_helper(other, op='__ne__') - - def __eq__(self, other): - return self._binary_helper(other, op='__eq__') - - def __ge__(self, other): - return self._binary_helper(other, op='__ge__') - - def __gt__(self, other): - return self._binary_helper(other, op='__gt__') - def __neg__(self): return data_object(self._shape, -self._data, self._distaxis) @@ -269,6 +201,20 @@ class data_object(object): def fill(self, value): self._data.fill(value) +for op in ["__add__", "__radd__", "__iadd__", + "__sub__", "__rsub__", "__isub__", + "__mul__", "__rmul__", "__imul__", + "__div__", "__rdiv__", "__idiv__", + "__truediv__", "__rtruediv__", "__itruediv__", + "__floordiv__", "__rfloordiv__", "__ifloordiv__", + "__pow__", "__rpow__", "__ipow__", + "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]: + def func(op): + def func2(self, other): + return self._binary_helper(other, op=op) + return func2 + setattr(data_object, op, func(op)) + def full(shape, fill_value, dtype=None, distaxis=0): return data_object(shape, np.full(local_shape(shape, distaxis), @@ -302,6 +248,7 @@ def vdot(a, b): def _math_helper(x, function, out): + function = getattr(np, function) if out is not None: function(x._data, out=out._data) return out @@ -309,24 +256,14 @@ def _math_helper(x, function, out): return data_object(x.shape, function(x._data), x._distaxis) -def abs(a, out=None): - return _math_helper(a, np.abs, out) - - -def exp(a, out=None): - return _math_helper(a, np.exp, out) - - -def log(a, out=None): - return _math_helper(a, np.log, out) - - -def tanh(a, out=None): - return _math_helper(a, np.tanh, out) - +_current_module = sys.modules[__name__] -def sqrt(a, out=None): - return _math_helper(a, np.sqrt, out) +for f in ["sqrt", "exp", "log", "tanh", "conjugate", "abs"]: + def func(f): + def func2(x, out=None): + return _math_helper(x, f, out) + return func2 + setattr(_current_module, f, func(f)) def from_object(object, dtype, copy, set_locked): diff --git a/nifty4/field.py b/nifty4/field.py index 2f6d30f478ee730c07b3a7252845a20ecd7a32bb..cdb66d4736380867d34577ec64a86d58e42dbf4f 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -23,6 +23,7 @@ from . import utilities from .domain_tuple import DomainTuple from functools import reduce from . import dobj +import sys __all__ = ["Field", "sqrt", "exp", "log", "conjugate"] @@ -759,75 +760,6 @@ class Field(object): return NotImplemented - def __add__(self, other): - return self._binary_helper(other, op='__add__') - - def __radd__(self, other): - return self._binary_helper(other, op='__radd__') - - def __iadd__(self, other): - return self._binary_helper(other, op='__iadd__') - - def __sub__(self, other): - return self._binary_helper(other, op='__sub__') - - def __rsub__(self, other): - return self._binary_helper(other, op='__rsub__') - - def __isub__(self, other): - return self._binary_helper(other, op='__isub__') - - def __mul__(self, other): - return self._binary_helper(other, op='__mul__') - - def __rmul__(self, other): - return self._binary_helper(other, op='__rmul__') - - def __imul__(self, other): - return self._binary_helper(other, op='__imul__') - - def __div__(self, other): - return self._binary_helper(other, op='__div__') - - def __truediv__(self, other): - return self._binary_helper(other, op='__truediv__') - - def __rdiv__(self, other): - return self._binary_helper(other, op='__rdiv__') - - def __rtruediv__(self, other): - return self._binary_helper(other, op='__rtruediv__') - - def __idiv__(self, other): - return self._binary_helper(other, op='__idiv__') - - def __pow__(self, other): - return self._binary_helper(other, op='__pow__') - - def __rpow__(self, other): - return self._binary_helper(other, op='__rpow__') - - def __ipow__(self, other): - return self._binary_helper(other, op='__ipow__') - - def __lt__(self, other): - return self._binary_helper(other, op='__lt__') - - def __le__(self, other): - return self._binary_helper(other, op='__le__') - - def __ne__(self, other): - return self._binary_helper(other, op='__ne__') - - def __eq__(self, other): - return self._binary_helper(other, op='__eq__') - - def __ge__(self, other): - return self._binary_helper(other, op='__ge__') - - def __gt__(self, other): - return self._binary_helper(other, op='__gt__') - def __repr__(self): return "<nifty4.Field>" @@ -836,10 +768,25 @@ class Field(object): self._domain.__str__() + \ "\n- val = " + repr(self.val) +for op in ["__add__", "__radd__", "__iadd__", + "__sub__", "__rsub__", "__isub__", + "__mul__", "__rmul__", "__imul__", + "__div__", "__rdiv__", "__idiv__", + "__truediv__", "__rtruediv__", "__itruediv__", + "__floordiv__", "__rfloordiv__", "__ifloordiv__", + "__pow__", "__rpow__", "__ipow__", + "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]: + def func(op): + def func2(self, other): + return self._binary_helper(other, op=op) + return func2 + setattr(Field, op, func(op)) + # Arithmetic functions working on Fields def _math_helper(x, function, out): + function = getattr(dobj, function) if not isinstance(x, Field): raise TypeError("This function only accepts Field objects.") if out is not None: @@ -850,22 +797,11 @@ def _math_helper(x, function, out): else: return Field(domain=x._domain, val=function(x.val)) +_current_module = sys.modules[__name__] -def sqrt(x, out=None): - return _math_helper(x, dobj.sqrt, out) - - -def exp(x, out=None): - return _math_helper(x, dobj.exp, out) - - -def log(x, out=None): - return _math_helper(x, dobj.log, out) - - -def tanh(x, out=None): - return _math_helper(x, dobj.tanh, out) - - -def conjugate(x, out=None): - return _math_helper(x, dobj.conjugate, out) +for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: + def func(f): + def func2(x, out=None): + return _math_helper(x, f, out) + return func2 + setattr(_current_module, f, func(f))