Skip to content
Snippets Groups Projects
Commit 2e5933f5 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent 82fae1d6
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -81,7 +81,6 @@ The current version of Nifty4 can be obtained by cloning the repository and
switching to the NIFTy_4 branch:
git clone https://gitlab.mpcdf.mpg.de/ift/NIFTy.git
git checkout NIFTy_4
### Installation
......
......@@ -20,7 +20,6 @@ from __future__ import division
from .minimizer import Minimizer
from ..field import Field
from .. import dobj
from ..utilities import general_axpy
class ConjugateGradient(Minimizer):
......@@ -68,15 +67,12 @@ class ConjugateGradient(Minimizer):
return energy, status
r = energy.gradient
if preconditioner is not None:
d = preconditioner(r)
else:
d = r.copy()
d = r.copy() if preconditioner is None else preconditioner(r)
previous_gamma = (r.vdot(d)).real
if previous_gamma == 0:
return energy, controller.CONVERGED
tpos = Field(d.domain, dtype=d.dtype) # temporary buffer
while True:
q = energy.curvature(d)
ddotq = d.vdot(q).real
......@@ -89,15 +85,12 @@ class ConjugateGradient(Minimizer):
dobj.mprint("Error: ConjugateGradient: alpha<0.")
return energy, controller.ERROR
general_axpy(-alpha, q, r, out=r)
q *= -alpha
r += q
general_axpy(-alpha, d, energy.position, out=tpos)
energy = energy.at_with_grad(tpos, r)
energy = energy.at_with_grad(energy.position - alpha*d, r)
if preconditioner is not None:
s = preconditioner(r)
else:
s = r
s = r if preconditioner is None else preconditioner(r)
gamma = r.vdot(s).real
if gamma < 0:
......@@ -111,6 +104,7 @@ class ConjugateGradient(Minimizer):
if status != controller.CONTINUE:
return energy, status
general_axpy(max(0, gamma/previous_gamma), d, s, out=d)
d *= max(0, gamma/previous_gamma)
d += s
previous_gamma = gamma
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from .minimizer import Minimizer
from ..field import Field
from .. import dobj
class ScipyMinimizer(Minimizer):
"""Scipy-based minimizer
Parameters
----------
controller : IterationController
Object that decides when to terminate the minimization.
"""
def __init__(self, controller, method="trust-ncg"):
super(ScipyMinimizer, self).__init__()
if not dobj.is_numpy():
raise NotImplementedError
self._controller = controller
self._method = method
def __call__(self, energy):
class _MinimizationDone:
pass
class _MinHelper(object):
def __init__(self, controller, energy):
self._controller = controller
self._energy = energy
self._domain = energy.position.domain
def _update(self, x):
pos = Field(self._domain, x.reshape(self._domain.shape))
if (pos.val != self._energy.position.val).any():
self._energy = self._energy.at(pos)
status = self._controller.check(self._energy)
if status != self._controller.CONTINUE:
raise _MinimizationDone
def fun(self, x):
self._update(x)
return self._energy.value
def jac(self, x):
self._update(x)
return self._energy.gradient.val.reshape(-1)
def hessp(self, x, p):
self._update(x)
vec = Field(self._domain, p.reshape(self._domain.shape))
res = self._energy.curvature(vec)
return res.val.reshape(-1)
import scipy.optimize as opt
status = self._controller.start(energy)
if status != self._controller.CONTINUE:
return energy, status
hlp = _MinHelper(self._controller, energy)
options = {'disp': False,
'xtol': 1e-15,
'eps': 1.4901161193847656e-08,
'return_all': False,
'maxiter': None}
options = {'disp': False,
'ftol': 1e-15,
'gtol': 1e-15,
'eps': 1.4901161193847656e-08}
try:
opt.minimize(hlp.fun, energy.position.val.reshape(-1),
method=self._method, jac=hlp.jac,
hessp=hlp.hessp,
options=options)
except _MinimizationDone:
energy = hlp._energy
status = self._controller.check(energy)
return energy, status
return hlp._energy, self._controller.ERROR
......@@ -232,25 +232,3 @@ def my_fftn_r2c(a, axes=None):
return res
return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes)
def general_axpy(a, x, y, out):
if x.domain != y.domain or x.domain != out.domain:
raise ValueError("Incompatible domains")
if out is x:
if a != 1.:
out *= a
out += y
elif out is y:
if a != 1.:
out += a*x
else:
out += x
else:
out.copy_content_from(y)
if a != 1.:
out += a*x
else:
out += x
return out
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment