Commit 8d13e434 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

introduce (and use) new general consistency check for energies

parent 38a1494f
Pipeline #26575 canceled with stage
in 1 minute and 25 seconds
from .operator_tests import consistency_check from .operator_tests import consistency_check
from .energy_tests import check_value_gradient_consistency
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field
__all__ = ["check_value_gradient_consistency"]
def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
if not np.isfinite(E.value):
raise ValueError
for _ in range(ntries):
dir = Field.from_random("normal", E.position.domain)
# find a step length that leads to a "reasonable" energy
for i in range(50):
try:
E2 = E.at(E.position+dir)
if np.isfinite(E2.value) and abs(E2.value) < 1e20:
break
except FloatingPointError:
pass
dir *= 0.5
else:
raise ValueError("could not find a reasonable initial step")
Enext = E2
dirder = E.gradient.vdot(dir)
for i in range(50):
Ediff = E2.value - E.value
eps = 1e-10*max(abs(E.value), abs(E2.value))
if abs(Ediff-dirder) < max([tol*abs(Ediff), tol*abs(dirder), eps]):
break
dir *= 0.5
dirder *= 0.5
E2 = E2.at(E.position+dir)
else:
print i, Ediff, dirder, eps, E.value
raise ValueError("gradient and value seem inconsistent")
E = Enext
...@@ -63,14 +63,9 @@ class Energy_Tests(unittest.TestCase): ...@@ -63,14 +63,9 @@ class Energy_Tests(unittest.TestCase):
inverter = ift.ConjugateGradient(IC) inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS) S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
energy0 = ift.library.WienerFilterEnergy( energy = ift.library.WienerFilterEnergy(
position=s0, d=d, R=R, N=N, S=S, inverter=inverter) position=s0, d=d, R=R, N=N, S=S, inverter=inverter)
energy1 = energy0.at(s1) ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-5
assert_allclose(a, b, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(64, distances=.789), @expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)], ift.RGSpace([32, 32], distances=.789)],
...@@ -103,15 +98,10 @@ class Energy_Tests(unittest.TestCase): ...@@ -103,15 +98,10 @@ class Energy_Tests(unittest.TestCase):
xi1 = xi0 + eps * direction xi1 = xi0 + eps * direction
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS) S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
energy0 = ift.library.NonlinearWienerFilterEnergy( energy = ift.library.NonlinearWienerFilterEnergy(
position=xi0, d=d, Instrument=R, nonlinearity=f, ht=ht, power=A, position=xi0, d=d, Instrument=R, nonlinearity=f, ht=ht, power=A,
N=N, S=S) N=N, S=S)
energy1 = energy0.at(xi1) ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-5
assert_allclose(a, b, rtol=tol, atol=tol)
class Curvature_Tests(unittest.TestCase): class Curvature_Tests(unittest.TestCase):
......
...@@ -87,10 +87,5 @@ class Noise_Energy_Tests(unittest.TestCase): ...@@ -87,10 +87,5 @@ class Noise_Energy_Tests(unittest.TestCase):
res_sample_list = [d - R(f(ht(C.inverse_draw_sample() + xi))) res_sample_list = [d - R(f(ht(C.inverse_draw_sample() + xi)))
for _ in range(10)] for _ in range(10)]
energy0 = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list) energy = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
energy1 = energy0.at(eta1) ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-5
assert_allclose(a, b, rtol=tol, atol=tol)
...@@ -76,7 +76,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -76,7 +76,7 @@ class Energy_Tests(unittest.TestCase):
ht=ht, ht=ht,
inverter=inverter).curvature inverter=inverter).curvature
energy0 = ift.library.NonlinearPowerEnergy( energy = ift.library.NonlinearPowerEnergy(
position=tau0, position=tau0,
d=d, d=d,
xi=xi, xi=xi,
...@@ -87,9 +87,4 @@ class Energy_Tests(unittest.TestCase): ...@@ -87,9 +87,4 @@ class Energy_Tests(unittest.TestCase):
ht=ht, ht=ht,
N=N, N=N,
samples=10) samples=10)
energy1 = energy0.at(tau1) ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-4
assert_allclose(a, b, rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment