Commit 8d13e434 authored by Martin Reinecke's avatar Martin Reinecke

introduce (and use) new general consistency check for energies

parent 38a1494f
Pipeline #26575 canceled with stage
in 1 minute and 25 seconds
from .operator_tests import consistency_check
from .energy_tests import check_value_gradient_consistency
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field
__all__ = ["check_value_gradient_consistency"]
def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
if not np.isfinite(E.value):
raise ValueError
for _ in range(ntries):
dir = Field.from_random("normal", E.position.domain)
# find a step length that leads to a "reasonable" energy
for i in range(50):
try:
E2 = E.at(E.position+dir)
if np.isfinite(E2.value) and abs(E2.value) < 1e20:
break
except FloatingPointError:
pass
dir *= 0.5
else:
raise ValueError("could not find a reasonable initial step")
Enext = E2
dirder = E.gradient.vdot(dir)
for i in range(50):
Ediff = E2.value - E.value
eps = 1e-10*max(abs(E.value), abs(E2.value))
if abs(Ediff-dirder) < max([tol*abs(Ediff), tol*abs(dirder), eps]):
break
dir *= 0.5
dirder *= 0.5
E2 = E2.at(E.position+dir)
else:
print i, Ediff, dirder, eps, E.value
raise ValueError("gradient and value seem inconsistent")
E = Enext
......@@ -63,14 +63,9 @@ class Energy_Tests(unittest.TestCase):
inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
energy0 = ift.library.WienerFilterEnergy(
energy = ift.library.WienerFilterEnergy(
position=s0, d=d, R=R, N=N, S=S, inverter=inverter)
energy1 = energy0.at(s1)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-5
assert_allclose(a, b, rtol=tol, atol=tol)
ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10)
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
......@@ -103,15 +98,10 @@ class Energy_Tests(unittest.TestCase):
xi1 = xi0 + eps * direction
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
energy0 = ift.library.NonlinearWienerFilterEnergy(
energy = ift.library.NonlinearWienerFilterEnergy(
position=xi0, d=d, Instrument=R, nonlinearity=f, ht=ht, power=A,
N=N, S=S)
energy1 = energy0.at(xi1)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-5
assert_allclose(a, b, rtol=tol, atol=tol)
ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10)
class Curvature_Tests(unittest.TestCase):
......
......@@ -87,10 +87,5 @@ class Noise_Energy_Tests(unittest.TestCase):
res_sample_list = [d - R(f(ht(C.inverse_draw_sample() + xi)))
for _ in range(10)]
energy0 = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
energy1 = energy0.at(eta1)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-5
assert_allclose(a, b, rtol=tol, atol=tol)
energy = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10)
......@@ -76,7 +76,7 @@ class Energy_Tests(unittest.TestCase):
ht=ht,
inverter=inverter).curvature
energy0 = ift.library.NonlinearPowerEnergy(
energy = ift.library.NonlinearPowerEnergy(
position=tau0,
d=d,
xi=xi,
......@@ -87,9 +87,4 @@ class Energy_Tests(unittest.TestCase):
ht=ht,
N=N,
samples=10)
energy1 = energy0.at(tau1)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-4
assert_allclose(a, b, rtol=tol, atol=tol)
ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment