diff --git a/test/test_energies/test_map.py b/test/test_energies/test_map.py index 1b7ca92269a52d5e7db9f9d3f8b1b21989504e6c..67c2f317ec7e825f6b0a7bdfbf3697742249196a 100644 --- a/test/test_energies/test_map.py +++ b/test/test_energies/test_map.py @@ -67,7 +67,7 @@ class Energy_Tests(unittest.TestCase): ift.RGSpace([32, 32], distances=.789)], [ift.Tanh, ift.Exponential, ift.Linear], [4, 78, 23])) - def testNonlinearMap(self, space, nonlinearity, seed): + def testGaussianEnergy(self, space, nonlinearity, seed): np.random.seed(seed) f = nonlinearity() dim = len(space.shape) @@ -85,14 +85,11 @@ class Energy_Tests(unittest.TestCase): n = ift.Field.from_random(domain=space, random_type='normal') s = ht(ift.makeOp(A)(xi0_var)) R = ift.ScalingOperator(10., space) - sqrtN = ift.ScalingOperator(1., space) + N = ift.ScalingOperator(1., space) d_model = R(ift.LocalModel(s, nonlinearity())) d = d_model.value + n - IC = ift.GradientNormController(iteration_limit=100, - tol_abs_gradnorm=1e-5) - energy = ift.NonlinearWienerFilterEnergy( - d, d_model, sqrtN, IC) + energy = ift.GaussianEnergy(d_model, d, N) if isinstance(nonlinearity, ift.Linear): ift.extra.check_value_gradient_curvature_consistency( energy, ntries=10)