Commit 347a5b77 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more scipy work

parent e4e5039f
apt-get install -y build-essential git autoconf libtool pkg-config libfftw3-dev openmpi-bin libopenmpi-dev \
python python-pip python-dev python-nose python-numpy python-matplotlib python-future python-mpi4py \
python3 python3-pip python3-dev python3-nose python3-numpy python3-matplotlib python3-future python3-mpi4py
python python-pip python-dev python-nose python-numpy python-matplotlib python-future python-mpi4py python-scipy \
python3 python3-pip python3-dev python3-nose python3-numpy python3-matplotlib python3-future python3-mpi4py python3-scipy
......@@ -44,6 +44,7 @@ from .minimization.descent_minimizer import DescentMinimizer
from .minimization.steepest_descent import SteepestDescent
from .minimization.vl_bfgs import VL_BFGS
from .minimization.relaxed_newton import RelaxedNewton
from .minimization.scipy_minimizer import NewtonCG, L_BFGS_B
from import Energy
from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.line_energy import LineEnergy
......@@ -27,6 +27,10 @@ rank = _comm.Get_rank()
master = (rank == 0)
def is_numpy():
return False
def mprint(*args):
if master:
......@@ -259,6 +263,7 @@ class data_object(object):
def fill(self, value):
def full(shape, fill_value, dtype=None, distaxis=0):
return data_object(shape, np.full(local_shape(shape, distaxis),
fill_value, dtype), distaxis)
......@@ -346,7 +351,7 @@ def local_data(arr):
def ibegin_from_shape(glob_shape, distaxis=0):
res = [0] * len(glob_shape)
if distaxis<0:
if distaxis < 0:
return res
res[distaxis] = _shareRange(glob_shape[distaxis], ntask, rank)[0]
return tuple(res)
......@@ -30,6 +30,10 @@ rank = 0
master = True
def is_numpy():
return True
def mprint(*args):
......@@ -33,4 +33,4 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"log", "tanh", "sqrt", "from_object", "from_random",
"local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "mprint"]
"redistribute", "default_distaxis", "mprint", "is_numpy"]
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2017 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from .minimizer import Minimizer
from ..field import Field
from .. import dobj
class ScipyMinimizer(Minimizer):
"""Scipy-based minimizer
controller : IterationController
Object that decides when to terminate the minimization.
method : str
The selected Scipy minimization method.
options : dictionary
A set of custom options for the selected minimizer.
def __init__(self, controller, method, options, need_hessp):
super(ScipyMinimizer, self).__init__()
if not dobj.is_numpy():
raise NotImplementedError
self._controller = controller
self._method = method
self._options = options
self._need_hessp = need_hessp
def __call__(self, energy):
class _MinimizationDone:
class _MinHelper(object):
def __init__(self, controller, energy):
self._controller = controller
self._energy = energy
self._domain = energy.position.domain
def _update(self, x):
pos = Field(self._domain, x.reshape(self._domain.shape))
if (pos.val != self._energy.position.val).any():
self._energy =
status = self._controller.check(self._energy)
if status != self._controller.CONTINUE:
raise _MinimizationDone
def fun(self, x):
return self._energy.value
def jac(self, x):
return self._energy.gradient.val.reshape(-1)
def hessp(self, x, p):
vec = Field(self._domain, p.reshape(self._domain.shape))
res = self._energy.curvature(vec)
return res.val.reshape(-1)
import scipy.optimize as opt
hlp = _MinHelper(self._controller, energy)
energy = None
status = self._controller.start(hlp._energy)
if status != self._controller.CONTINUE:
return hlp._energy, status
if self._need_hessp:
opt.minimize(, hlp._energy.position.val.reshape(-1),
method=self._method, jac=hlp.jac,
opt.minimize(, hlp._energy.position.val.reshape(-1),
method=self._method, jac=hlp.jac,
except _MinimizationDone:
status = self._controller.check(hlp._energy)
return hlp._energy, self._controller.check(hlp._energy)
return hlp._energy, self._controller.ERROR
def NewtonCG(controller):
return ScipyMinimizer(controller, "Newton-CG",
{"xtol": 1e-20, "maxiter": None}, True)
def L_BFGS_B(controller, maxcor=10):
return ScipyMinimizer(controller, "L-BFGS-B",
{"ftol": 1e-20, "gtol": 1e-20, "maxcor": maxcor},
......@@ -18,14 +18,16 @@
import unittest
import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal
import nifty4 as ift
from itertools import product
from test.common import expand
from nose.plugins.skip import SkipTest
spaces = [ift.RGSpace([1024], distances=0.123), ift.HPSpace(32)]
minimizers = [ift.SteepestDescent, ift.RelaxedNewton, ift.VL_BFGS,
ift.ConjugateGradient, ift.NonlinearCG]
ift.ConjugateGradient, ift.NonlinearCG,
ift.NewtonCG, ift.L_BFGS_B]
class Test_Minimizers(unittest.TestCase):
......@@ -39,13 +41,20 @@ class Test_Minimizers(unittest.TestCase):
covariance = ift.DiagonalOperator(covariance_diagonal)
required_result = ift.Field.ones(space, dtype=np.float64)
IC = ift.GradientNormController(tol_abs_gradnorm=1e-5)
IC = ift.GradientNormController(verbose=True,tol_abs_gradnorm=1e-5, iteration_limit=1000)
minimizer = minimizer_class(controller=IC)
energy = ift.QuadraticEnergy(A=covariance, b=required_result,
(energy, convergence) = minimizer(energy)
assert convergence == IC.CONVERGED
except NotImplementedError:
raise SkipTest
assert_equal(convergence, IC.CONVERGED)
rtol=1e-3, atol=1e-3)
#MR FIXME: add Rosenbrock test
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment