Commit 2e5933f5 authored by Martin Reinecke's avatar Martin Reinecke

cleanup

parent 82fae1d6
Pipeline #24621 passed with stage
in 14 minutes and 36 seconds
......@@ -81,7 +81,6 @@ The current version of Nifty4 can be obtained by cloning the repository and
switching to the NIFTy_4 branch:
git clone https://gitlab.mpcdf.mpg.de/ift/NIFTy.git
git checkout NIFTy_4
### Installation
......
......@@ -20,7 +20,6 @@ from __future__ import division
from .minimizer import Minimizer
from ..field import Field
from .. import dobj
from ..utilities import general_axpy
class ConjugateGradient(Minimizer):
......@@ -68,15 +67,12 @@ class ConjugateGradient(Minimizer):
return energy, status
r = energy.gradient
if preconditioner is not None:
d = preconditioner(r)
else:
d = r.copy()
d = r.copy() if preconditioner is None else preconditioner(r)
previous_gamma = (r.vdot(d)).real
if previous_gamma == 0:
return energy, controller.CONVERGED
tpos = Field(d.domain, dtype=d.dtype) # temporary buffer
while True:
q = energy.curvature(d)
ddotq = d.vdot(q).real
......@@ -89,15 +85,12 @@ class ConjugateGradient(Minimizer):
dobj.mprint("Error: ConjugateGradient: alpha<0.")
return energy, controller.ERROR
general_axpy(-alpha, q, r, out=r)
q *= -alpha
r += q
general_axpy(-alpha, d, energy.position, out=tpos)
energy = energy.at_with_grad(tpos, r)
energy = energy.at_with_grad(energy.position - alpha*d, r)
if preconditioner is not None:
s = preconditioner(r)
else:
s = r
s = r if preconditioner is None else preconditioner(r)
gamma = r.vdot(s).real
if gamma < 0:
......@@ -111,6 +104,7 @@ class ConjugateGradient(Minimizer):
if status != controller.CONTINUE:
return energy, status
general_axpy(max(0, gamma/previous_gamma), d, s, out=d)
d *= max(0, gamma/previous_gamma)
d += s
previous_gamma = gamma
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from .minimizer import Minimizer
from ..field import Field
from .. import dobj
class ScipyMinimizer(Minimizer):
"""Scipy-based minimizer
Parameters
----------
controller : IterationController
Object that decides when to terminate the minimization.
"""
def __init__(self, controller, method="trust-ncg"):
super(ScipyMinimizer, self).__init__()
if not dobj.is_numpy():
raise NotImplementedError
self._controller = controller
self._method = method
def __call__(self, energy):
class _MinimizationDone:
pass
class _MinHelper(object):
def __init__(self, controller, energy):
self._controller = controller
self._energy = energy
self._domain = energy.position.domain
def _update(self, x):
pos = Field(self._domain, x.reshape(self._domain.shape))
if (pos.val != self._energy.position.val).any():
self._energy = self._energy.at(pos)
status = self._controller.check(self._energy)
if status != self._controller.CONTINUE:
raise _MinimizationDone
def fun(self, x):
self._update(x)
return self._energy.value
def jac(self, x):
self._update(x)
return self._energy.gradient.val.reshape(-1)
def hessp(self, x, p):
self._update(x)
vec = Field(self._domain, p.reshape(self._domain.shape))
res = self._energy.curvature(vec)
return res.val.reshape(-1)
import scipy.optimize as opt
status = self._controller.start(energy)
if status != self._controller.CONTINUE:
return energy, status
hlp = _MinHelper(self._controller, energy)
options = {'disp': False,
'xtol': 1e-15,
'eps': 1.4901161193847656e-08,
'return_all': False,
'maxiter': None}
options = {'disp': False,
'ftol': 1e-15,
'gtol': 1e-15,
'eps': 1.4901161193847656e-08}
try:
opt.minimize(hlp.fun, energy.position.val.reshape(-1),
method=self._method, jac=hlp.jac,
hessp=hlp.hessp,
options=options)
except _MinimizationDone:
energy = hlp._energy
status = self._controller.check(energy)
return energy, status
return hlp._energy, self._controller.ERROR
......@@ -232,25 +232,3 @@ def my_fftn_r2c(a, axes=None):
return res
return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes)
def general_axpy(a, x, y, out):
if x.domain != y.domain or x.domain != out.domain:
raise ValueError("Incompatible domains")
if out is x:
if a != 1.:
out *= a
out += y
elif out is y:
if a != 1.:
out += a*x
else:
out += x
else:
out.copy_content_from(y)
if a != 1.:
out += a*x
else:
out += x
return out
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment