Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
2e5933f5
Commit
2e5933f5
authored
Feb 08, 2018
by
Martin Reinecke
Browse files
cleanup
parent
82fae1d6
Pipeline
#24621
passed with stage
in 14 minutes and 36 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
README.md
View file @
2e5933f5
...
@@ -81,7 +81,6 @@ The current version of Nifty4 can be obtained by cloning the repository and
...
@@ -81,7 +81,6 @@ The current version of Nifty4 can be obtained by cloning the repository and
switching to the NIFTy_4 branch:
switching to the NIFTy_4 branch:
git clone https://gitlab.mpcdf.mpg.de/ift/NIFTy.git
git clone https://gitlab.mpcdf.mpg.de/ift/NIFTy.git
git checkout NIFTy_4
### Installation
### Installation
...
...
nifty4/minimization/conjugate_gradient.py
View file @
2e5933f5
...
@@ -20,7 +20,6 @@ from __future__ import division
...
@@ -20,7 +20,6 @@ from __future__ import division
from
.minimizer
import
Minimizer
from
.minimizer
import
Minimizer
from
..field
import
Field
from
..field
import
Field
from
..
import
dobj
from
..
import
dobj
from
..utilities
import
general_axpy
class
ConjugateGradient
(
Minimizer
):
class
ConjugateGradient
(
Minimizer
):
...
@@ -68,15 +67,12 @@ class ConjugateGradient(Minimizer):
...
@@ -68,15 +67,12 @@ class ConjugateGradient(Minimizer):
return
energy
,
status
return
energy
,
status
r
=
energy
.
gradient
r
=
energy
.
gradient
if
preconditioner
is
not
None
:
d
=
r
.
copy
()
if
preconditioner
is
None
else
preconditioner
(
r
)
d
=
preconditioner
(
r
)
else
:
d
=
r
.
copy
()
previous_gamma
=
(
r
.
vdot
(
d
)).
real
previous_gamma
=
(
r
.
vdot
(
d
)).
real
if
previous_gamma
==
0
:
if
previous_gamma
==
0
:
return
energy
,
controller
.
CONVERGED
return
energy
,
controller
.
CONVERGED
tpos
=
Field
(
d
.
domain
,
dtype
=
d
.
dtype
)
# temporary buffer
while
True
:
while
True
:
q
=
energy
.
curvature
(
d
)
q
=
energy
.
curvature
(
d
)
ddotq
=
d
.
vdot
(
q
).
real
ddotq
=
d
.
vdot
(
q
).
real
...
@@ -89,15 +85,12 @@ class ConjugateGradient(Minimizer):
...
@@ -89,15 +85,12 @@ class ConjugateGradient(Minimizer):
dobj
.
mprint
(
"Error: ConjugateGradient: alpha<0."
)
dobj
.
mprint
(
"Error: ConjugateGradient: alpha<0."
)
return
energy
,
controller
.
ERROR
return
energy
,
controller
.
ERROR
general_axpy
(
-
alpha
,
q
,
r
,
out
=
r
)
q
*=
-
alpha
r
+=
q
general_axpy
(
-
alpha
,
d
,
energy
.
position
,
out
=
tpos
)
energy
=
energy
.
at_with_grad
(
energy
.
position
-
alpha
*
d
,
r
)
energy
=
energy
.
at_with_grad
(
tpos
,
r
)
if
preconditioner
is
not
None
:
s
=
r
if
preconditioner
is
None
else
preconditioner
(
r
)
s
=
preconditioner
(
r
)
else
:
s
=
r
gamma
=
r
.
vdot
(
s
).
real
gamma
=
r
.
vdot
(
s
).
real
if
gamma
<
0
:
if
gamma
<
0
:
...
@@ -111,6 +104,7 @@ class ConjugateGradient(Minimizer):
...
@@ -111,6 +104,7 @@ class ConjugateGradient(Minimizer):
if
status
!=
controller
.
CONTINUE
:
if
status
!=
controller
.
CONTINUE
:
return
energy
,
status
return
energy
,
status
general_axpy
(
max
(
0
,
gamma
/
previous_gamma
),
d
,
s
,
out
=
d
)
d
*=
max
(
0
,
gamma
/
previous_gamma
)
d
+=
s
previous_gamma
=
gamma
previous_gamma
=
gamma
nifty4/minimization/scipy_minimizer.py.bak
deleted
100644 → 0
View file @
82fae1d6
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from .minimizer import Minimizer
from ..field import Field
from .. import dobj
class ScipyMinimizer(Minimizer):
"""Scipy-based minimizer
Parameters
----------
controller : IterationController
Object that decides when to terminate the minimization.
"""
def __init__(self, controller, method="trust-ncg"):
super(ScipyMinimizer, self).__init__()
if not dobj.is_numpy():
raise NotImplementedError
self._controller = controller
self._method = method
def __call__(self, energy):
class _MinimizationDone:
pass
class _MinHelper(object):
def __init__(self, controller, energy):
self._controller = controller
self._energy = energy
self._domain = energy.position.domain
def _update(self, x):
pos = Field(self._domain, x.reshape(self._domain.shape))
if (pos.val != self._energy.position.val).any():
self._energy = self._energy.at(pos)
status = self._controller.check(self._energy)
if status != self._controller.CONTINUE:
raise _MinimizationDone
def fun(self, x):
self._update(x)
return self._energy.value
def jac(self, x):
self._update(x)
return self._energy.gradient.val.reshape(-1)
def hessp(self, x, p):
self._update(x)
vec = Field(self._domain, p.reshape(self._domain.shape))
res = self._energy.curvature(vec)
return res.val.reshape(-1)
import scipy.optimize as opt
status = self._controller.start(energy)
if status != self._controller.CONTINUE:
return energy, status
hlp = _MinHelper(self._controller, energy)
options = {'disp': False,
'xtol': 1e-15,
'eps': 1.4901161193847656e-08,
'return_all': False,
'maxiter': None}
options = {'disp': False,
'ftol': 1e-15,
'gtol': 1e-15,
'eps': 1.4901161193847656e-08}
try:
opt.minimize(hlp.fun, energy.position.val.reshape(-1),
method=self._method, jac=hlp.jac,
hessp=hlp.hessp,
options=options)
except _MinimizationDone:
energy = hlp._energy
status = self._controller.check(energy)
return energy, status
return hlp._energy, self._controller.ERROR
nifty4/utilities.py
View file @
2e5933f5
...
@@ -232,25 +232,3 @@ def my_fftn_r2c(a, axes=None):
...
@@ -232,25 +232,3 @@ def my_fftn_r2c(a, axes=None):
return
res
return
res
return
_fill_complex_array
(
tmp
,
np
.
empty_like
(
a
,
dtype
=
tmp
.
dtype
),
axes
)
return
_fill_complex_array
(
tmp
,
np
.
empty_like
(
a
,
dtype
=
tmp
.
dtype
),
axes
)
def
general_axpy
(
a
,
x
,
y
,
out
):
if
x
.
domain
!=
y
.
domain
or
x
.
domain
!=
out
.
domain
:
raise
ValueError
(
"Incompatible domains"
)
if
out
is
x
:
if
a
!=
1.
:
out
*=
a
out
+=
y
elif
out
is
y
:
if
a
!=
1.
:
out
+=
a
*
x
else
:
out
+=
x
else
:
out
.
copy_content_from
(
y
)
if
a
!=
1.
:
out
+=
a
*
x
else
:
out
+=
x
return
out
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment