Commit 19ab4fa3 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add absolute and sinc to MultiField

parent 2f9f16af
......@@ -699,6 +699,20 @@ class Field(object):
return Field(self._domain, f(other))
return NotImplemented
def _sinc_withjac(self):
res = self.sinc()
jac = (((np.pi*self).cos()-res)/self).val_rw()
jac[self.val == 0] = 0
return res, Field(self._domain, jac)
def _absolute_withjac(self):
if utilities.iscomplextype(self.dtype):
raise TypeError("Argument must not be complex")
res = self.absolute()
jac = self.sign().val_rw().astype(float)
jac[self.val == 0] = np.nan
return res, Field(self._domain, jac)
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......
......@@ -318,13 +318,12 @@ class Linearization(object):
return self.new(tmp, makeOp(tmp2)(self._jac))
def sinc(self):
tmp = self._val.sinc()
tmp2 = ((np.pi*self._val).cos()-tmp)/self._val
ind = self._val.val == 0
loc = tmp2.val_rw()
loc[ind] = 0
tmp2 = Field(tmp.domain, loc)
return self.new(tmp, makeOp(tmp2)(self._jac))
val, jac = self._val._sinc_withjac()
return self.new(val, makeOp(jac)(self._jac))
def absolute(self):
val, jac = self._val._absolute_withjac()
return self.new(val, makeOp(jac)(self._jac))
def log(self):
tmp = self._val.log()
......@@ -364,19 +363,6 @@ class Linearization(object):
tmp2 = 0.5*(1.+tmp)
return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def absolute(self):
if utilities.iscomplextype(self._val.dtype):
raise TypeError("Argument must not be complex")
tmp = self._val.absolute()
tmp2 = self._val.sign()
ind = self._val.val == 0
loc = tmp2.val_rw().astype(float)
loc[ind] = np.nan
tmp2 = Field(tmp.domain, loc)
return self.new(tmp, makeOp(tmp2)(self._jac))
def one_over(self):
tmp = 1./self._val
tmp2 = - tmp/self._val
......
......@@ -324,6 +324,21 @@ class MultiField(object):
val = tuple(f(v1, other) for v1 in self._val)
return MultiField(self._domain, val)
def _sinc_withjac(self):
return self._nontrivial_jac_helper('_sinc_withjac')
def _absolute_withjac(self):
return self._nontrivial_jac_helper('_absolute_withjac')
def _nontrivial_jac_helper(self, funcname):
val = self.to_dict()
jac = {}
for kk, vv in val.items():
val[kk], jac[kk] = getattr(vv, funcname)()
val = MultiField.from_dict(val, self._domain)
jac = MultiField.from_dict(jac, self._domain)
return val, jac
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment