There is a maintenance of MPCDF Gitlab on Thursday, April 22st 2020, 9:00 am CEST - Expect some service interruptions during this time

Commit 02c18676 authored by Martin Reinecke's avatar Martin Reinecke

fix tests for pointwise functions that take an extra argument

parent 87b680d2
Pipeline #91928 passed with stages
in 11 minutes and 9 seconds
......@@ -55,7 +55,8 @@ def test_special_gradients():
@pmp('f', [
'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
'absolute', 'reciprocal', 'sigmoid', 'log10', 'log1p', 'expm1', 'softplus', 'exponentiate'
'absolute', 'reciprocal', 'sigmoid', 'log10', 'log1p', 'expm1', 'softplus',
('power', 2.), ('exponentiate', 1.1)
])
@pmp('cplxpos', [True, False])
@pmp('cplxdir', [True, False])
......@@ -77,8 +78,13 @@ def test_actual_gradients(f, cplxpos, cplxdir, holomorphic):
eps *= (1+0.78j)
var0 = ift.Linearization.make_var(fld)
var1 = ift.Linearization.make_var(fld + eps)
f0 = var0.ptw(f).val.val
f1 = var1.ptw(f).val.val
if isinstance(f, tuple):
f0 = var0.ptw(*f).val.val
f1 = var1.ptw(*f).val.val
df1 = _lin2grad(var0.ptw(*f))
else:
f0 = var0.ptw(f).val.val
f1 = var1.ptw(f).val.val
df1 = _lin2grad(var0.ptw(f))
df0 = (f1 - f0)/eps
df1 = _lin2grad(var0.ptw(f))
assert_allclose(df0, df1, rtol=100*np.abs(eps))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment