stochastic_minimizer.py

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from .minimizer import Minimizer
from .energy import Energy


class ADVIOptimizer(Minimizer):
    """Provide an implementation of an adaptive step-size sequence optimizer,
    following https://arxiv.org/abs/1603.00788.

    Parameters
    ----------
    controller: IterationController
        The iteration controller that determines how many consecutive steps
        are taken during one call of the optimizer and when it terminates.
    eta: positive float
        The scale of the step-size sequence. It might have to be adapted to the
        application to increase performance. Default: 1.
    alpha: float between 0 and 1
        The fraction with which the current squared gradient enters the
        moving average of squared gradients; smaller values correspond to a
        longer memory.
    tau: positive float
        Offset in the denominator of the step size; prevents division by
        zero.
    epsilon: positive float
        A small value that guarantees the Robbins and Monro conditions.
    resample: bool
        Whether the loss function is resampled for the next iteration.
        Stochastic losses require resampling, deterministic ones do not.
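
    Notes
    -----
    The step-size sequence is the adaptive schedule of
    https://arxiv.org/abs/1603.00788: the squared gradient is accumulated in
    an exponentially weighted moving average

        s_i = alpha * g_i**2 + (1 - alpha) * s_(i-1),

    the step size at iteration i is

        rho_i = eta * i**(-1/2 + epsilon) / (tau + sqrt(s_i)),

    and the position update is x_(i+1) = x_i - rho_i * g_i.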
    """

    def __init__(self, controller, eta=1, alpha=0.1, tau=1, epsilon=1e-16, resample=True):
        self.alpha = alpha
        self.eta = eta
        self.tau = tau
        self.epsilon = epsilon
        self.counter = 1
        self._controller = controller
        self.s = None
        self.resample = resample

    def _step(self, position, gradient):
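        # Exponentially weighted moving average of the squared gradient;
        # alpha sets how strongly the newest gradient enters the average.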
        self.s = self.alpha * gradient ** 2 + (1 - self.alpha) * self.s
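        # Adaptive step size: Robbins-Monro-style decay in the iteration
        # counter, scaled elementwise by the inverse root of the gradient
        # moving average; tau guards against division by zero.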
        self.rho = self.eta * self.counter ** (-0.5 + self.epsilon) \
                / (self.tau + self.s.sqrt())
        new_position = position - self.rho * gradient
        self.counter += 1
        return new_position

    def __call__(self, energy):
        from ..utilities import myassert

        controller = self._controller
        status = controller.start(energy)
        if status != controller.CONTINUE:
            return energy, status

        if self.s is None:
            self.s = energy.gradient ** 2
        while True:
            # check if position is at a flat point
            if energy.gradient_norm == 0:
                return energy, controller.CONVERGED

            x = self._step(energy.position, energy.gradient)
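            # For stochastic losses, draw new samples at the proposed
            # position; deterministic energies keep their old state.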
            if self.resample:
                energy = energy.resample_at(x)
            myassert(isinstance(energy, Energy))
            myassert(x.domain is energy.position.domain)

            energy = energy.at(x)
            status = self._controller.check(energy)
            if status != controller.CONTINUE:
                return energy, status

    def reset(self):
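        # Restart the step-size schedule and forget the accumulated
        # squared-gradient statistics.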
        self.counter = 1
        self.s = None
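
# A minimal usage sketch (hypothetical, not part of this module): `SomeEnergy`
# stands in for any object implementing the Energy interface, including
# `resample_at()` for stochastic losses, and the controller import assumes
# the usual layout of the surrounding `minimization` package.
#
#     from .iteration_controllers import GradientNormController
#
#     controller = GradientNormController(iteration_limit=100)
#     minimizer = ADVIOptimizer(controller, eta=1, alpha=0.1, resample=True)
#     energy = SomeEnergy(initial_position)
#     energy, status = minimizer(energy)
#     minimizer.reset()  # forget state before minimizing a fresh problem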