Commit ab127321 authored by Martin Reinecke's avatar Martin Reinecke

cosmetics

parent ff8dcc0b
......@@ -272,9 +272,9 @@ def test_err():
with assert_raises(TypeError):
ift.full(s1, [2, 3])
with assert_raises(TypeError):
ift.Field(s2, [0,1])
ift.Field(s2, [0, 1])
with assert_raises(TypeError):
f1.outer([0,1])
f1.outer([0, 1])
with assert_raises(ValueError):
f1.extract(s2)
with assert_raises(TypeError):
......@@ -340,12 +340,14 @@ def test_funcs(num, dom, func):
res2 = getattr(np, func)(num)
assert_allclose(res.local_data, res2)
@pmp('rtype', ['normal', 'pm1', 'uniform'])
@pmp('dtype', [np.float64, np.complex128])
def test_from_random(rtype, dtype):
sp = ift.RGSpace(3)
f = ift.Field.from_random(rtype, sp, dtype=dtype)
def test_field_of_objects():
arr = np.array(['x', 'y', 'z'])
sp = ift.RGSpace(3)
......
......@@ -72,6 +72,7 @@ def test_quadratic_minimization(minimizer, space):
rtol=1e-3,
atol=1e-3)
@pmp('space', spaces)
def test_WF_curvature(space):
np.random.seed(42)
......@@ -85,7 +86,8 @@ def test_WF_curvature(space):
n = ift.Field.from_random('uniform', domain=space) + 0.5
N = ift.DiagonalOperator(n)
all_diag = 1./s + r**2/n
curv = ift.WienerFilterCurvature(R,N,S, iteration_controller=IC, iteration_controller_sampling=IC)
curv = ift.WienerFilterCurvature(R, N, S, iteration_controller=IC,
iteration_controller_sampling=IC)
m = curv.inverse(required_result)
assert_allclose(
m.local_data,
......@@ -100,7 +102,9 @@ def test_WF_curvature(space):
n = ift.from_random('uniform', R.domain) + 0.5
N = ift.DiagonalOperator(n)
all_diag = 1./s + R(1/n)
curv = ift.WienerFilterCurvature(R.adjoint,N,S, iteration_controller=IC, iteration_controller_sampling=IC)
curv = ift.WienerFilterCurvature(R.adjoint, N, S,
iteration_controller=IC,
iteration_controller_sampling=IC)
m = curv.inverse(required_result)
assert_allclose(
m.local_data,
......@@ -111,8 +115,6 @@ def test_WF_curvature(space):
curv.draw_sample(from_inverse=True)
@pmp('minimizer', minimizers + newton_minimizers)
def test_rosenbrock(minimizer):
try:
......
......@@ -34,13 +34,13 @@ def test_func():
assert_allclose(
ift.log(ift.exp((f1)))["d1"].local_data, f1["d1"].local_data)
def test_multifield_field_consistency():
f1 = ift.full(dom, 27)
f2 = ift.from_global_data(dom, f1.to_global_data())
assert_equal(f1.sum(), f2.sum())
assert_equal(-f1, (-f2)['d1'])
assert_equal(f1.__abs__(), (f2.__abs__())['d1'])
assert_equal(abs(f1), (abs(f2))['d1'])
def test_dataconv():
......@@ -48,7 +48,7 @@ def test_dataconv():
f2 = ift.from_global_data(dom, f1.to_global_data())
for key, val in f1.items():
assert_equal(val.local_data, f2[key].local_data)
if not "d1" in f2:
if "d1" not in f2:
raise KeyError()
assert_equal({"d1": f1}, f2.to_dict())
f3 = ift.full(dom, 27+1.j)
......
......@@ -81,10 +81,10 @@ def testBinary(type1, type2, space, seed):
pos = ift.from_random("normal", dom)
model = ift.OuterProduct(pos['s1'], ift.makeDomain(space))
ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)
model = select_s1 **2
model = select_s1**2
pos = ift.from_random("normal", dom1)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = select_s1.clip(-1,1)
model = select_s1.clip(-1, 1)
pos = ift.from_random("normal", dom1)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
if isinstance(space, ift.RGSpace):
......@@ -127,9 +127,9 @@ def testPointModel(space, seed):
@pmp('target', [
ift.RGSpace(64, distances=.789,harmonic=True),
ift.RGSpace([32, 32], distances=.789,harmonic=True),
ift.RGSpace([32, 32, 8], distances=.789,harmonic=True)
ift.RGSpace(64, distances=.789, harmonic=True),
ift.RGSpace([32, 32], distances=.789, harmonic=True),
ift.RGSpace([32, 32, 8], distances=.789, harmonic=True)
])
@pmp('causal', [True, False])
@pmp('minimum_phase', [True, False])
......
......@@ -25,13 +25,14 @@ import nifty5 as ift
def test_get_signal_variance():
space = ift.RGSpace(3)
hspace = space.get_default_codomain()
spec1 = lambda x : np.ones_like(x)
spec1 = lambda x: np.ones_like(x)
assert_equal(ift.get_signal_variance(spec1, hspace), 3.)
space = ift.RGSpace(3, distances=1.)
hspace = space.get_default_codomain()
def spec2(k):
t = np.zeros_like(k)
t[k==0] = 1.
t[k == 0] = 1.
return t
assert_equal(ift.get_signal_variance(spec2, hspace), 1/9.)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment