# noise_energy.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from .. import Field, exp
from ..operators.diagonal_operator import DiagonalOperator
from ..minimization.energy import Energy


# TODO: Take only residual_sample_list as an argument.


class NoiseEnergy(Energy):
    """Energy functional for the log-diagonal of a noise covariance.

    The position is the logarithm of the diagonal of the noise covariance,
    ``N = diag(exp(position))``. For every excitation sample ``xi_k`` the
    residual ``r_k = d - Instrument(nonlinearity(ht(A * xi_k)))`` is formed,
    with amplitude ``A = Distributor(exp(0.5 * t))``. The energy is

        mean_k(0.5 * r_k^dagger N^{-1} r_k) + 0.5 * sum(position)
        + (alpha - 1)^dagger position + q^dagger exp(-position)

    and the gradient is the corresponding derivative with respect to
    ``position``. The last two terms look like a prior on the noise level
    (presumably inverse-gamma-like; TODO confirm against the model).

    Parameters
    ----------
    position : Field
        Log of the diagonal of the noise covariance.
    d : Field
        Data.
    xi : Field
        Excitation field around which samples are drawn.
    D : operator
        Provides ``draw_sample()``; each draw is added to ``xi``.
    t : Field
        Log-power field; ``Distributor(exp(0.5 * t))`` is the amplitude.
    ht : callable
        Applied to ``A * sample`` to produce the signal map.
    Instrument, nonlinearity : callable
        Response chain: ``Instrument(nonlinearity(map_s))``.
    alpha, q : Field
        Coefficients of the ``(alpha - 1)^dagger position`` and
        ``q^dagger exp(-position)`` terms.
    Distributor : callable
        Maps ``exp(0.5 * t)`` to the amplitude multiplied onto samples.
    samples : int or None, optional
        Number of samples to draw when ``xi_sample_list`` is None. If None
        or 0, ``xi`` itself is used as the only sample. Default 3.
    xi_sample_list : list of Field or None, optional
        Pre-drawn samples. ``at`` forwards them so that the same samples
        are reused at a new position.
    inverter : optional
        Stored and forwarded by ``at``; not used within this class.

    Raises
    ------
    ValueError
        If the effective sample list is empty (e.g. ``xi_sample_list=[]``
        or a negative ``samples``).
    """

    def __init__(self, position, d, xi, D, t, ht, Instrument,
                 nonlinearity, alpha, q, Distributor, samples=3,
                 xi_sample_list=None, inverter=None):
        super(NoiseEnergy, self).__init__(position=position)
        self.xi = xi
        self.D = D
        self.d = d
        # Noise covariance parameterized by its log-diagonal (the position).
        self.N = DiagonalOperator(diagonal=exp(self.position))
        self.t = t
        self.samples = samples
        self.ht = ht
        self.Instrument = Instrument
        self.nonlinearity = nonlinearity
        self.alpha = alpha
        self.q = q
        self.Distributor = Distributor
        # Amplitude applied to every sample; computed once and reused in the
        # likelihood loop below instead of being re-evaluated.
        self.power = self.Distributor(exp(0.5 * self.t))

        if xi_sample_list is None:
            if samples is None or samples == 0:
                xi_sample_list = [xi]
            else:
                xi_sample_list = [D.draw_sample() + xi
                                  for _ in range(samples)]
        # Without at least one sample the average below is undefined; fail
        # with a clear message instead of an AttributeError on _value.
        if len(xi_sample_list) == 0:
            raise ValueError(
                "NoiseEnergy needs at least one sample in xi_sample_list")
        self.xi_sample_list = xi_sample_list
        self.inverter = inverter

        # Accumulate the likelihood and its gradient over all samples.
        self._gradient = None
        for sample in self.xi_sample_list:
            lh, grad = self._sample_likelihood(sample)
            if self._gradient is None:
                self._value = lh
                self._gradient = grad.copy()
            else:
                self._value += lh
                self._gradient += grad

        n_samples = len(self.xi_sample_list)

        # Sample average of the likelihood, then 0.5 * log det N
        # (= 0.5 * sum(position)) and the alpha/q terms.
        self._value *= 1. / n_samples
        self._value += .5 * self.position.sum()
        self._value += (self.alpha - 1.).vdot(self.position) + \
            self.q.vdot(exp(-self.position))

        # d/dposition of the non-likelihood terms:
        # 0.5 + (alpha - 1) - q * exp(-position).
        self._gradient *= 1. / n_samples
        self._gradient += (self.alpha-0.5) - self.q*(exp(-self.position))

    def _sample_likelihood(self, sample):
        """Return ``(0.5 r^dagger N^{-1} r, gradient)`` for one sample."""
        map_s = self.ht(self.power * sample)
        residual = self.d - self.Instrument(self.nonlinearity(map_s))
        lh = .5 * residual.vdot(self.N.inverse_times(residual))
        # Derivative of 0.5 * |r|^2 * exp(-position) w.r.t. position.
        grad = -.5 * self.N.inverse_times(residual.conjugate()*residual)
        return lh, grad

    def at(self, position):
        """Return a new NoiseEnergy at ``position``, reusing the samples."""
        return self.__class__(
            position, self.d, self.xi, self.D, self.t, self.ht,
            self.Instrument, self.nonlinearity, self.alpha, self.q,
            self.Distributor, xi_sample_list=self.xi_sample_list,
            samples=self.samples, inverter=self.inverter)

    @property
    def value(self):
        """Energy value computed in ``__init__``."""
        return self._value

    @property
    def gradient(self):
        """Gradient of the energy with respect to the position."""
        return self._gradient