Commit 2e5933f5 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent 82fae1d6
Pipeline #24621 passed with stage
in 14 minutes and 36 seconds
...@@ -81,7 +81,6 @@ The current version of Nifty4 can be obtained by cloning the repository and ...@@ -81,7 +81,6 @@ The current version of Nifty4 can be obtained by cloning the repository and
switching to the NIFTy_4 branch: switching to the NIFTy_4 branch:
git clone https://gitlab.mpcdf.mpg.de/ift/NIFTy.git git clone https://gitlab.mpcdf.mpg.de/ift/NIFTy.git
git checkout NIFTy_4
### Installation ### Installation
......
...@@ -20,7 +20,6 @@ from __future__ import division ...@@ -20,7 +20,6 @@ from __future__ import division
from .minimizer import Minimizer from .minimizer import Minimizer
from ..field import Field from ..field import Field
from .. import dobj from .. import dobj
from ..utilities import general_axpy
class ConjugateGradient(Minimizer): class ConjugateGradient(Minimizer):
...@@ -68,15 +67,12 @@ class ConjugateGradient(Minimizer): ...@@ -68,15 +67,12 @@ class ConjugateGradient(Minimizer):
return energy, status return energy, status
r = energy.gradient r = energy.gradient
if preconditioner is not None: d = r.copy() if preconditioner is None else preconditioner(r)
d = preconditioner(r)
else:
d = r.copy()
previous_gamma = (r.vdot(d)).real previous_gamma = (r.vdot(d)).real
if previous_gamma == 0: if previous_gamma == 0:
return energy, controller.CONVERGED return energy, controller.CONVERGED
tpos = Field(d.domain, dtype=d.dtype) # temporary buffer
while True: while True:
q = energy.curvature(d) q = energy.curvature(d)
ddotq = d.vdot(q).real ddotq = d.vdot(q).real
...@@ -89,15 +85,12 @@ class ConjugateGradient(Minimizer): ...@@ -89,15 +85,12 @@ class ConjugateGradient(Minimizer):
dobj.mprint("Error: ConjugateGradient: alpha<0.") dobj.mprint("Error: ConjugateGradient: alpha<0.")
return energy, controller.ERROR return energy, controller.ERROR
general_axpy(-alpha, q, r, out=r) q *= -alpha
r += q
general_axpy(-alpha, d, energy.position, out=tpos) energy = energy.at_with_grad(energy.position - alpha*d, r)
energy = energy.at_with_grad(tpos, r)
if preconditioner is not None: s = r if preconditioner is None else preconditioner(r)
s = preconditioner(r)
else:
s = r
gamma = r.vdot(s).real gamma = r.vdot(s).real
if gamma < 0: if gamma < 0:
...@@ -111,6 +104,7 @@ class ConjugateGradient(Minimizer): ...@@ -111,6 +104,7 @@ class ConjugateGradient(Minimizer):
if status != controller.CONTINUE: if status != controller.CONTINUE:
return energy, status return energy, status
general_axpy(max(0, gamma/previous_gamma), d, s, out=d) d *= max(0, gamma/previous_gamma)
d += s
previous_gamma = gamma previous_gamma = gamma
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from .minimizer import Minimizer
from ..field import Field
from .. import dobj
class ScipyMinimizer(Minimizer):
"""Scipy-based minimizer
Parameters
----------
controller : IterationController
Object that decides when to terminate the minimization.
"""
def __init__(self, controller, method="trust-ncg"):
super(ScipyMinimizer, self).__init__()
if not dobj.is_numpy():
raise NotImplementedError
self._controller = controller
self._method = method
def __call__(self, energy):
class _MinimizationDone:
pass
class _MinHelper(object):
def __init__(self, controller, energy):
self._controller = controller
self._energy = energy
self._domain = energy.position.domain
def _update(self, x):
pos = Field(self._domain, x.reshape(self._domain.shape))
if (pos.val != self._energy.position.val).any():
self._energy = self._energy.at(pos)
status = self._controller.check(self._energy)
if status != self._controller.CONTINUE:
raise _MinimizationDone
def fun(self, x):
self._update(x)
return self._energy.value
def jac(self, x):
self._update(x)
return self._energy.gradient.val.reshape(-1)
def hessp(self, x, p):
self._update(x)
vec = Field(self._domain, p.reshape(self._domain.shape))
res = self._energy.curvature(vec)
return res.val.reshape(-1)
import scipy.optimize as opt
status = self._controller.start(energy)
if status != self._controller.CONTINUE:
return energy, status
hlp = _MinHelper(self._controller, energy)
options = {'disp': False,
'xtol': 1e-15,
'eps': 1.4901161193847656e-08,
'return_all': False,
'maxiter': None}
options = {'disp': False,
'ftol': 1e-15,
'gtol': 1e-15,
'eps': 1.4901161193847656e-08}
try:
opt.minimize(hlp.fun, energy.position.val.reshape(-1),
method=self._method, jac=hlp.jac,
hessp=hlp.hessp,
options=options)
except _MinimizationDone:
energy = hlp._energy
status = self._controller.check(energy)
return energy, status
return hlp._energy, self._controller.ERROR
...@@ -232,25 +232,3 @@ def my_fftn_r2c(a, axes=None): ...@@ -232,25 +232,3 @@ def my_fftn_r2c(a, axes=None):
return res return res
return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes) return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes)
def general_axpy(a, x, y, out):
if x.domain != y.domain or x.domain != out.domain:
raise ValueError("Incompatible domains")
if out is x:
if a != 1.:
out *= a
out += y
elif out is y:
if a != 1.:
out += a*x
else:
out += x
else:
out.copy_content_from(y)
if a != 1.:
out += a*x
else:
out += x
return out
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment