Commit 55ec681a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more streamlining

parent fc902894
Pipeline #72511 passed with stages
in 19 minutes and 54 seconds
......@@ -677,9 +677,6 @@ class Field(Operator):
def flexible_addsub(self, other, neg):
return self-other if neg else self+other
def clip(self, a_min=None, a_max=None):
return self.ptw("clip", a_min, a_max)
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
......@@ -691,20 +688,24 @@ class Field(Operator):
return Field(self._domain, f(other))
return NotImplemented
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
def _prep_args(self, args, kwargs):
for arg in args + tuple(kwargs.values()):
if not (arg is None or np.isscalar(arg) or arg.jac is None):
raise TypeError("bad argument")
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
for key, val in kwargs.items()}
return argstmp, kwargstmp
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp, kwargstmp = self._prep_args(args, kwargs)
return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp))
def ptw_with_deriv(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
for key, val in kwargs.items()}
argstmp, kwargstmp = self._prep_args(args, kwargs)
tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp)
return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1]))
......
......@@ -281,15 +281,6 @@ class Linearization(Operator):
t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs)
return self.new(t1, makeOp(t2)(self._jac))
def clip(self, a_min=None, a_max=None):
if a_min is None and a_max is None:
return self
if not (a_min is None or np.isscalar(a_min) or a_min.jac is None):
return NotImplemented
if not (a_max is None or np.isscalar(a_max) or a_max.jac is None):
return NotImplemented
return self.ptw("clip", a_min, a_max)
def add_metric(self, metric):
return self.new(self._val, self._jac, metric)
......
......@@ -306,25 +306,27 @@ class MultiField(Operator):
res[key] = -val if neg else val
return MultiField.from_dict(res)
def _prep_args(self, args, kwargs, i):
for arg in args + tuple(kwargs.values()):
if not (arg is None or np.isscalar(arg) or arg.jac is None):
raise TypeError("bad argument")
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i]
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
return argstmp, kwargstmp
def ptw(self, op, *args, **kwargs):
# _check_args(args, kwargs)
tmp = []
for i in range(len(self._val)):
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i]
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
argstmp, kwargstmp = self._prep_args(args, kwargs, i)
tmp.append(self._val[i].ptw(op, *argstmp, **kwargstmp))
return MultiField(self.domain, tuple(tmp))
def ptw_with_deriv(self, op, *args, **kwargs):
# _check_args(args, kwargs)
tmp = []
for i in range(len(self._val)):
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i]
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
argstmp, kwargstmp = self._prep_args(args, kwargs, i)
tmp.append(self._val[i].ptw_with_deriv(op, *argstmp, **kwargstmp))
return (MultiField(self.domain, tuple(v[0] for v in tmp)),
MultiField(self.domain, tuple(v[1] for v in tmp)))
......
......@@ -210,15 +210,6 @@ class Operator(metaclass=NiftyMeta):
return NotImplemented
return self.ptw("power", power)
def clip(self, a_min=None, a_max=None):
if a_min is None and a_max is None:
return self
if not (a_min is None or np.isscalar(a_min) or a_min.jac is None):
return NotImplemented
if not (a_max is None or np.isscalar(a_max) or a_max.jac is None):
return NotImplemented
return self.ptw("clip", a_min, a_max)
def apply(self, x):
"""Applies the operator to a Field or MultiField.
......
......@@ -67,7 +67,7 @@ def _power_helper(v, expo):
return (np.power(v, expo), expo*np.power(v, expo-1))
def _clip_helper(v, a_min=None, a_max=None):
def _clip_helper(v, a_min, a_max):
if np.issubdtype(v.dtype, np.complexfloating):
raise TypeError("Argument must not be complex")
tmp = np.clip(v, a_min, a_max)
......
......@@ -193,8 +193,8 @@ def test_empty_domain():
def test_trivialities():
s1 = ift.RGSpace((10,))
f1 = ift.Field.full(s1, 27)
assert_equal(f1.clip(a_min=29).val, 29.)
assert_equal(f1.clip(a_max=25).val, 25.)
assert_equal(f1.clip(a_min=29, a_max=50).val, 29.)
assert_equal(f1.clip(a_min=0, a_max=25).val, 25.)
assert_equal(f1.val, f1.real.val)
assert_equal(f1.val, (+f1).val)
f1 = ift.Field.full(s1, 27. + 3j)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment