gradient_norm_controller.py 2.84 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from .iteration_controller import IterationController
Martin Reinecke's avatar
Martin Reinecke committed
20
from .. import dobj
Martin Reinecke's avatar
Martin Reinecke committed
21

Martin Reinecke's avatar
Martin Reinecke committed
22

23
class GradientNormController(IterationController):
    """Iteration controller that decides convergence from the gradient norm.

    On every call to `check`, the current iteration "looks converged" if the
    energy's gradient norm is at or below `tol_abs_gradnorm`, or at or below
    `tol_rel_gradnorm` times the gradient norm seen in `start`. Each such
    iteration increments a convergence counter. Every other iteration
    decrements it, but never below zero. Iteration is declared converged
    once the counter reaches `convergence_level`. It is also declared
    converged, with a warning, once `iteration_limit` iterations have been
    performed.

    Parameters
    ----------
    tol_abs_gradnorm : float, optional
        Absolute gradient-norm threshold. Not used if None.
    tol_rel_gradnorm : float, optional
        Gradient-norm threshold relative to the gradient norm of the energy
        passed to `start`. Not used if None.
    convergence_level : int
        Value the convergence counter must reach before convergence is
        declared. Default: 1.
    iteration_limit : int, optional
        Iteration count at which convergence is assumed regardless of the
        gradient criteria. Not used if None.
    name : str, optional
        If not None, a one-line progress report prefixed with this name is
        printed on every call to `check`.
    """

    def __init__(self, tol_abs_gradnorm=None, tol_rel_gradnorm=None,
                 convergence_level=1, iteration_limit=None, name=None):
        super(GradientNormController, self).__init__()
        self._tol_abs_gradnorm = tol_abs_gradnorm
        self._tol_rel_gradnorm = tol_rel_gradnorm
        self._convergence_level = convergence_level
        self._iteration_limit = iteration_limit
        self._name = name

    def start(self, energy):
        """Reset the controller and evaluate the starting point.

        Parameters
        ----------
        energy : object
            Must expose `gradient_norm` and, when a name is set, `value`.
            Its gradient norm turns `tol_rel_gradnorm` into the absolute
            threshold used by all subsequent `check` calls.

        Returns
        -------
        The status returned by `check(energy)`: `self.CONVERGED` or
        `self.CONTINUE`.
        """
        # Start at -1 so that the check() call below reports iteration #0.
        self._itcount = -1
        self._ccount = 0
        if self._tol_rel_gradnorm is not None:
            # Freeze the relative tolerance into an absolute threshold based
            # on the gradient norm at the starting point.
            self._tol_rel_gradnorm_now = (self._tol_rel_gradnorm
                                          * energy.gradient_norm)
        return self.check(energy)

    def check(self, energy):
        """Advance the iteration counter and decide whether to stop.

        `start` must have been called first; it initializes the counters
        and the relative threshold.

        Parameters
        ----------
        energy : object
            Must expose `gradient_norm` and, when a name is set, `value`.

        Returns
        -------
        `self.CONVERGED` if the convergence counter has reached
        `convergence_level` or the iteration limit has been reached,
        otherwise `self.CONTINUE`. Both constants are expected to come from
        the `IterationController` base class.
        """
        self._itcount += 1

        # Short-circuit evaluation only touches energy.gradient_norm when a
        # tolerance is configured, and skips the relative test once the
        # absolute one has already succeeded.
        inclvl = ((self._tol_abs_gradnorm is not None
                   and energy.gradient_norm <= self._tol_abs_gradnorm)
                  or (self._tol_rel_gradnorm is not None
                      and energy.gradient_norm <= self._tol_rel_gradnorm_now))
        if inclvl:
            self._ccount += 1
        else:
            self._ccount = max(0, self._ccount - 1)

        # Progress report.
        if self._name is not None:
            msg = "{}: Iteration #{} energy={:.6E} gradnorm={:.2E} clvl={}" \
                .format(self._name, self._itcount, energy.value,
                        energy.gradient_norm, self._ccount)
            dobj.mprint(msg)

        # Genuine convergence takes precedence. The "assuming convergence"
        # warning should only appear when the iteration limit is what
        # actually stops the iteration.
        if self._ccount >= self._convergence_level:
            return self.CONVERGED
        if (self._iteration_limit is not None
                and self._itcount >= self._iteration_limit):
            dobj.mprint(
                "Warning:Iteration limit reached. Assuming convergence")
            return self.CONVERGED

        return self.CONTINUE