noise_energy.py 3.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Philipp Arras's avatar
Philipp Arras committed
19
20
21
from .. import Field, exp
from ..operators.diagonal_operator import DiagonalOperator
from ..minimization.energy import Energy
22

23
24
# TODO: Take only residual_sample_list as an argument.

25
26

class NoiseEnergy(Energy):
    """Sample-averaged energy of a log noise-variance field.

    The position ``tau`` parametrizes a diagonal noise covariance
    ``N = diag(dunit**2 * exp(tau))``.  With the amplitude field
    ``A = Projection.adjoint_times(munit * exp(0.5 * t))`` (stored as
    ``self.power``), every excitation sample ``xi_s`` yields the residual

        r_s = d - Instrument(sunit * nonlinearity(ht(A * xi_s)))

    and the energy is

        E(tau) = mean_s[0.5 * r_s.vdot(N^{-1} r_s)]
                 + 0.5 * sum(tau)
                 + (alpha - 1).vdot(tau) + q.vdot(exp(-tau))

    with gradient

        mean_s[-0.5 * N^{-1} (conj(r_s) * r_s)]
                 + (alpha - 0.5) - q * exp(-tau).

    Parameters
    ----------
    position : Field
        Log noise variance ``tau`` (in units of ``dunit**2``).
    d : Field
        Data the residuals are computed against.
    xi : Field
        Excitation field.  Used directly as the single sample when
        ``samples`` is None or 0; otherwise posterior samples are drawn
        as ``D.generate_posterior_sample() + xi``.
    D : operator
        Object providing ``generate_posterior_sample()``.
    t : Field
        Enters only through ``exp(0.5 * t)`` (presumably a log power
        spectrum -- confirm).
    ht : callable
        Applied to ``A * xi_s`` (presumably a harmonic transform -- confirm).
    Instrument : callable
        Response applied to ``sunit * nonlinearity(...)``.
    nonlinearity : callable
        Applied to the output of ``ht``.
    alpha, q : Field
        Parameters of the position-dependent terms shown above
        (presumably prior hyperparameters -- confirm).
    Projection : operator
        Its adjoint maps ``munit * exp(0.5 * t)`` to the amplitude ``A``.
    munit, sunit, dunit : float, optional
        Unit factors, all default to 1.
    samples : int or None, optional
        Number of posterior samples drawn when ``xi_sample_list`` is None.
    xi_sample_list : list of Field or None, optional
        Pre-drawn excitation samples.  If given, they are used as-is.
    inverter : optional
        Stored and forwarded by :meth:`at`; not used by this class.

    Raises
    ------
    ValueError
        If ``xi_sample_list`` is empty.
    """

    def __init__(self, position, d, xi, D, t, ht, Instrument,
                 nonlinearity, alpha, q, Projection, munit=1., sunit=1.,
                 dunit=1., samples=3, xi_sample_list=None, inverter=None):
        super(NoiseEnergy, self).__init__(position=position)
        self.xi = xi
        self.D = D
        self.d = d
        # Diagonal noise covariance, parametrized by the log variance.
        self.N = DiagonalOperator(diagonal=dunit**2 * exp(self.position))
        self.t = t
        self.samples = samples
        self.ht = ht
        self.Instrument = Instrument
        self.nonlinearity = nonlinearity
        self.munit = munit
        self.sunit = sunit
        self.dunit = dunit
        self.alpha = alpha
        self.q = q
        self.Projection = Projection
        # Amplitude field A (unit: munit).  Despite the attribute name this
        # is munit * exp(0.5 * t) projected, i.e. not squared.  Computed once
        # here and reused for every sample below.
        self.power = self.Projection.adjoint_times(munit * exp(0.5 * self.t))

        if xi_sample_list is None:
            if samples is None or samples == 0:
                xi_sample_list = [xi]
            else:
                xi_sample_list = [D.generate_posterior_sample() + xi
                                  for _ in range(samples)]
        self.xi_sample_list = xi_sample_list
        self.inverter = inverter

        n_samples = len(self.xi_sample_list)
        if n_samples == 0:
            raise ValueError("xi_sample_list must not be empty")

        value = None
        gradient = None
        for sample in self.xi_sample_list:
            map_s = self.ht(self.power * sample)
            residual = self.d - \
                self.Instrument(sunit * self.nonlinearity(map_s))

            # Likelihood term 0.5 * r^dagger N^{-1} r and its derivative with
            # respect to the log variance: -0.5 * N^{-1} |r|^2.
            lh = .5 * residual.vdot(self.N.inverse_times(residual))
            grad = -.5 * self.N.inverse_times(residual.conjugate() * residual)

            if gradient is None:
                value = lh
                # Copy so the in-place accumulation below never aliases grad.
                gradient = grad.copy()
            else:
                value += lh
                gradient += grad

        # Average the likelihood over samples, then add the sample-independent
        # terms: 0.5 * sum(position) is 0.5 * log det N up to a constant, and
        # the alpha/q terms depend on the position only.
        value *= 1. / n_samples
        value += .5 * self.position.sum()
        value += (self.alpha - 1.).vdot(self.position) + \
            self.q.vdot(exp(-self.position))
        self._value = value

        # d/dtau of 0.5 * sum(tau) + (alpha - 1) * tau gives (alpha - 0.5).
        gradient *= 1. / n_samples
        gradient += (self.alpha - 0.5) - self.q * (exp(-self.position))
        self._gradient = gradient

    def at(self, position):
        """Return a new energy at ``position`` reusing all other inputs.

        The excitation samples in ``self.xi_sample_list`` are passed on, so
        the new energy is evaluated on the same sample set.
        """
        return self.__class__(
            position, self.d, self.xi, self.D, self.t, self.ht,
            self.Instrument, self.nonlinearity, self.alpha, self.q,
            self.Projection, munit=self.munit, sunit=self.sunit,
            dunit=self.dunit, xi_sample_list=self.xi_sample_list,
            samples=self.samples, inverter=self.inverter)

    @property
    def value(self):
        """Energy value computed at construction."""
        return self._value

    @property
    def gradient(self):
        """Gradient of the energy with respect to the position."""
        return self._gradient