Commit d08d0cbd authored by Martin Reinecke's avatar Martin Reinecke

tests, cosmetics, missing stuff

parent 071fc76e
...@@ -35,7 +35,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -35,7 +35,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw", "lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"] "clipped_exp", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign"]
_comm = MPI.COMM_WORLD _comm = MPI.COMM_WORLD
ntask = _comm.Get_size() ntask = _comm.Get_size()
...@@ -296,7 +297,8 @@ def _math_helper(x, function, out): ...@@ -296,7 +297,8 @@ def _math_helper(x, function, out):
_current_module = sys.modules[__name__] _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign"]:
def func(f): def func(f):
def func2(x, out=None): def func2(x, out=None):
return _math_helper(x, f, out) return _math_helper(x, f, out)
...@@ -304,10 +306,14 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: ...@@ -304,10 +306,14 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
setattr(_current_module, f, func(f)) setattr(_current_module, f, func(f))
def clipped_exp(a): def clipped_exp(x):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis)) return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis))
def hardplus(x):
return data_object(x.shape, np.clip(x.data, 1e-20, None), x.distaxis)
def from_object(object, dtype, copy, set_locked): def from_object(object, dtype, copy, set_locked):
if dtype is None: if dtype is None:
dtype = object.dtype dtype = object.dtype
......
...@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"lock", "locked", "uniform_full", "to_global_data_rw", "lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp", "hardplus", "sin", "cos", "tan", "sinh", "clipped_exp", "hardplus", "sin", "cos", "tan", "sinh",
"cosh","absolute", "sign", "sinc"] "cosh", "absolute", "sign", "sinc"]
ntask = 1 ntask = 1
rank = 0 rank = 0
...@@ -159,4 +159,4 @@ def clipped_exp(arr): ...@@ -159,4 +159,4 @@ def clipped_exp(arr):
def hardplus(arr): def hardplus(arr):
return np.clip(arr, 1e-20, None) return np.clip(arr, 1e-20, None)
\ No newline at end of file
...@@ -33,7 +33,7 @@ from ..operators.linear_operator import LinearOperator ...@@ -33,7 +33,7 @@ from ..operators.linear_operator import LinearOperator
def _gaussian_error_function(x): def _gaussian_error_function(x):
return 0.5*erfc(x*np.sqrt(2.)) return 0.5/erfc(x*np.sqrt(2.))
def _comp_traverse(start, end, shp, dist, lo, mid, hi, erf): def _comp_traverse(start, end, shp, dist, lo, mid, hi, erf):
......
...@@ -189,7 +189,7 @@ class Linearization(object): ...@@ -189,7 +189,7 @@ class Linearization(object):
def hardplus(self): def hardplus(self):
tmp = self._val.hardplus() tmp = self._val.hardplus()
tmp2 = makeOp(1.-(tmp==1e-20)) tmp2 = makeOp(1.-(tmp == 1e-20))
return self.new(tmp, tmp2(self._jac)) return self.new(tmp, tmp2(self._jac))
def sin(self): def sin(self):
......
...@@ -178,7 +178,6 @@ class _OpChain(_CombinedOperator): ...@@ -178,7 +178,6 @@ class _OpChain(_CombinedOperator):
x = op(x) x = op(x)
return x return x
def __repr__(self): def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops) subs = "\n".join(sub.__repr__() for sub in self._ops)
return "_OpChain:\n" + indent(subs) return "_OpChain:\n" + indent(subs)
...@@ -211,7 +210,6 @@ class _OpProd(Operator): ...@@ -211,7 +210,6 @@ class _OpProd(Operator):
makeOp(lin2._val)(lin1._jac), False) makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, op(x.jac)) return lin1.new(lin1._val*lin2._val, op(x.jac))
def __repr__(self): def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2)) subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
return "_OpProd:\n"+indent(subs) return "_OpProd:\n"+indent(subs)
......
...@@ -292,3 +292,14 @@ class Test_Functionality(unittest.TestCase): ...@@ -292,3 +292,14 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f.local_data.shape, ()) assert_equal(f.local_data.shape, ())
assert_equal(f.local_data.size, 1) assert_equal(f.local_data.size, 1)
assert_equal(f.vdot(f), 9.) assert_equal(f.vdot(f), 9.)
@expand(product([float(5), 5.],
[ift.RGSpace((8,), harmonic=True), ()],
["exp", "log", "sin", "cos", "tan", "sinh", "cosh", "sinc",
"absolute", "sign"]))
def test_funcs(self, num, dom, func):
num = 5
f = ift.Field.full(dom, num)
res = getattr(f, func)()
res2 = getattr(np, func)(num)
assert_allclose(res.local_data, res2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment