There is a maintenance of MPCDF Gitlab on Thursday, April 22st 2020, 9:00 am CEST - Expect some service interruptions during this time

Commit 426d824c authored by Martin Reinecke's avatar Martin Reinecke

add Rosenbrock test

parent 7e59ee98
Pipeline #25395 passed with stages
in 41 minutes and 54 seconds
......@@ -29,6 +29,9 @@ minimizers = [ift.SteepestDescent, ift.RelaxedNewton, ift.VL_BFGS,
ift.ConjugateGradient, ift.NonlinearCG,
ift.NewtonCG, ift.L_BFGS_B]
minimizers2 = [ift.RelaxedNewton, ift.VL_BFGS, ift.NonlinearCG,
ift.NewtonCG, ift.L_BFGS_B]
class Test_Minimizers(unittest.TestCase):
......@@ -57,5 +60,67 @@ class Test_Minimizers(unittest.TestCase):
rtol=1e-3, atol=1e-3)
def test_rosenbrock(self, minimizer_class):
from scipy.optimize import rosen, rosen_der, rosen_hess_prod
except ImportError:
raise SkipTest
space = ift.UnstructuredDomain((2,))
starting_point = ift.Field.from_random('normal', domain=space)*10
class RBEnergy(ift.Energy):
def __init__(self, position):
super(RBEnergy, self).__init__(position)
def value(self):
return rosen(self._position.to_global_data().copy())
def gradient(self):
inp = self._position.to_global_data().copy()
out = ift.Field.from_global_data(space, rosen_der(inp))
return out
def curvature(self):
class RBCurv(ift.EndomorphicOperator):
def __init__(self, loc):
self._loc = loc.to_global_data().copy()
def domain(self):
return space
def capability(self):
return self.TIMES
def apply(self, x, mode):
self._check_input(x, mode)
inp = x.to_global_data().copy()
out = ift.Field.from_global_data(
space, rosen_hess_prod(self._loc.copy(), inp))
return out
# MR FIXME: add Rosenbrock test
t1 = ift.GradientNormController(tol_abs_gradnorm=1e-5,
t2 = ift.ConjugateGradient(controller=t1)
return ift.InversionEnabler(RBCurv(self._position),
IC = ift.GradientNormController(tol_abs_gradnorm=1e-5,
minimizer = minimizer_class(controller=IC)
energy = RBEnergy(position=starting_point)
(energy, convergence) = minimizer(energy)
except NotImplementedError:
raise SkipTest
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.to_global_data(), 1.,
rtol=1e-3, atol=1e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment