Commit 640ed6ea authored by Martin Reinecke's avatar Martin Reinecke
Browse files

introduce method

parent 9161a4c6
......@@ -690,6 +690,11 @@ class Field(object):
max = max.val if isinstance(max, Field) else max
return Field(self._domain, np.clip(self._val, min, max))
def where(self, iftrue, iffalse):
iftrue = itrue.val if isinstance(iftrue, Field) else iftrue
iffalse = iffalse.val if isinstance(iffalse, Field) else iffalse
return Field(self._domain, np.where(self._val, iftrue, iffalse))
def one_over(self):
return 1/self
......@@ -711,7 +716,8 @@ for op in ["__add__", "__radd__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
"__pow__", "__rpow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__",
"__and__", "__or__", "__xor__"]:
def func(op):
def func2(self, other):
return self._binary_op(other, op)
......
......@@ -300,13 +300,13 @@ class Linearization(Operator):
tmp = self._val.clip(min, max)
if (min is None) and (max is None):
return self
elif max is None:
tmp2 = makeOp(1. - (tmp == min))
elif min is None:
tmp2 = makeOp(1. - (tmp == max))
else:
tmp2 = makeOp(1. - (tmp == min) - (tmp == max))
return self.new(tmp, tmp2(self._jac))
from .sugar import full
mask = full(tmp._domain, 1.)
if max is not None:
mask = (tmp == max).where(0., mask)
if min is not None:
mask = (tmp == min).where(0., mask)
return self.new(tmp, makeOp(mask)(self._jac))
def sqrt(self):
tmp = self._val.sqrt()
......@@ -330,10 +330,7 @@ class Linearization(Operator):
def sinc(self):
tmp = self._val.sinc()
tmp2 = ((np.pi*self._val).cos()-tmp)/self._val
ind = self._val.val == 0
loc = tmp2.val_rw()
loc[ind] = 0
tmp2 = makeField(tmp.domain, loc)
tmp2 = (self._val == 0.).where(0., tmp2)
return self.new(tmp, makeOp(tmp2)(self._jac))
def log(self):
......@@ -375,16 +372,11 @@ class Linearization(Operator):
return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def absolute(self):
if utilities.iscomplextype(self._val.dtype):
raise TypeError("Argument must not be complex")
# FIXME
# if utilities.iscomplextype(self._val.dtype):
# raise TypeError("Argument must not be complex")
tmp = self._val.absolute()
tmp2 = self._val.sign()
ind = self._val.val == 0
loc = tmp2.val_rw().astype(float)
loc[ind] = np.nan
tmp2 = Field(tmp.domain, loc)
tmp2 = (self._val == 0).where(np.nan, self._val.sign())
return self.new(tmp, makeOp(tmp2)(self._jac))
def one_over(self):
......
......@@ -235,6 +235,14 @@ class MultiField(Operator):
self._domain,
tuple(self._val[i].clip(lmin[i], lmax[i]) for i in range(ncomp)))
def where(self, iftrue, iffalse):
ncomp = len(self._val)
iftrue = iftrue._val if isinstance(iftrue, MultiField) else (iftrue,)*ncomp
iffalse = iffalse._val if isinstance(iffalse, MultiField) else (iffalse,)*ncomp
return MultiField(
self._domain,
tuple(self._val[i].where(iftrue[i], iffalse[i]) for i in range(ncomp)))
def s_all(self):
for v in self._val:
if not v.s_all():
......@@ -360,7 +368,8 @@ for op in ["__add__", "__radd__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
"__pow__", "__rpow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__",
"__and__", "__or__", "__xor__"]:
def func(op):
def func2(self, other):
return self._binary_op(other, op)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment