Commit d08d0cbd authored by Martin Reinecke's avatar Martin Reinecke

tests, cosmetics, missing stuff

parent 071fc76e
......@@ -35,7 +35,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"]
"clipped_exp", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign"]
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
......@@ -296,7 +297,8 @@ def _math_helper(x, function, out):
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign"]:
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
......@@ -304,10 +306,14 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
setattr(_current_module, f, func(f))
def clipped_exp(a):
def clipped_exp(x):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis))
def hardplus(x):
return data_object(x.shape, np.clip(x.data, 1e-20, None), x.distaxis)
def from_object(object, dtype, copy, set_locked):
if dtype is None:
dtype = object.dtype
......
......@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp", "hardplus", "sin", "cos", "tan", "sinh",
"cosh","absolute", "sign", "sinc"]
"cosh", "absolute", "sign", "sinc"]
ntask = 1
rank = 0
......@@ -159,4 +159,4 @@ def clipped_exp(arr):
def hardplus(arr):
return np.clip(arr, 1e-20, None)
\ No newline at end of file
return np.clip(arr, 1e-20, None)
......@@ -33,7 +33,7 @@ from ..operators.linear_operator import LinearOperator
def _gaussian_error_function(x):
return 0.5*erfc(x*np.sqrt(2.))
return 0.5/erfc(x*np.sqrt(2.))
def _comp_traverse(start, end, shp, dist, lo, mid, hi, erf):
......
......@@ -189,7 +189,7 @@ class Linearization(object):
def hardplus(self):
tmp = self._val.hardplus()
tmp2 = makeOp(1.-(tmp==1e-20))
tmp2 = makeOp(1.-(tmp == 1e-20))
return self.new(tmp, tmp2(self._jac))
def sin(self):
......
......@@ -178,7 +178,6 @@ class _OpChain(_CombinedOperator):
x = op(x)
return x
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "_OpChain:\n" + indent(subs)
......@@ -211,7 +210,6 @@ class _OpProd(Operator):
makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, op(x.jac))
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
return "_OpProd:\n"+indent(subs)
......
......@@ -292,3 +292,14 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f.local_data.shape, ())
assert_equal(f.local_data.size, 1)
assert_equal(f.vdot(f), 9.)
@expand(product([float(5), 5.],
[ift.RGSpace((8,), harmonic=True), ()],
["exp", "log", "sin", "cos", "tan", "sinh", "cosh", "sinc",
"absolute", "sign"]))
def test_funcs(self, num, dom, func):
num = 5
f = ift.Field.full(dom, num)
res = getattr(f, func)()
res2 = getattr(np, func)(num)
assert_allclose(res.local_data, res2)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment