Commit 426d824c authored by Martin Reinecke's avatar Martin Reinecke

add Rosenbrock test

parent 7e59ee98
Pipeline #25395 passed with stages
in 41 minutes and 54 seconds
......@@ -29,6 +29,9 @@ minimizers = [ift.SteepestDescent, ift.RelaxedNewton, ift.VL_BFGS,
ift.ConjugateGradient, ift.NonlinearCG,
ift.NewtonCG, ift.L_BFGS_B]
minimizers2 = [ift.RelaxedNewton, ift.VL_BFGS, ift.NonlinearCG,
ift.NewtonCG, ift.L_BFGS_B]
class Test_Minimizers(unittest.TestCase):
......@@ -57,5 +60,67 @@ class Test_Minimizers(unittest.TestCase):
1./covariance_diagonal.to_global_data(),
rtol=1e-3, atol=1e-3)
@expand(product(minimizers2))
def test_rosenbrock(self, minimizer_class):
try:
from scipy.optimize import rosen, rosen_der, rosen_hess_prod
except ImportError:
raise SkipTest
np.random.seed(42)
space = ift.UnstructuredDomain((2,))
starting_point = ift.Field.from_random('normal', domain=space)*10
class RBEnergy(ift.Energy):
def __init__(self, position):
super(RBEnergy, self).__init__(position)
@property
def value(self):
return rosen(self._position.to_global_data().copy())
@property
def gradient(self):
inp = self._position.to_global_data().copy()
out = ift.Field.from_global_data(space, rosen_der(inp))
return out
@property
def curvature(self):
class RBCurv(ift.EndomorphicOperator):
def __init__(self, loc):
self._loc = loc.to_global_data().copy()
@property
def domain(self):
return space
@property
def capability(self):
return self.TIMES
def apply(self, x, mode):
self._check_input(x, mode)
inp = x.to_global_data().copy()
out = ift.Field.from_global_data(
space, rosen_hess_prod(self._loc.copy(), inp))
return out
# MR FIXME: add Rosenbrock test
t1 = ift.GradientNormController(tol_abs_gradnorm=1e-5,
iteration_limit=1000)
t2 = ift.ConjugateGradient(controller=t1)
return ift.InversionEnabler(RBCurv(self._position),
inverter=t2)
IC = ift.GradientNormController(tol_abs_gradnorm=1e-5,
iteration_limit=10000)
try:
minimizer = minimizer_class(controller=IC)
energy = RBEnergy(position=starting_point)
(energy, convergence) = minimizer(energy)
except NotImplementedError:
raise SkipTest
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.to_global_data(), 1.,
rtol=1e-3, atol=1e-3)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment