# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from .. import Field, exp
from ..operators.diagonal_operator import DiagonalOperator
from ..minimization.energy import Energy


# TODO: Take only residual_sample_list as argument.


class NoiseEnergy(Energy):
    """Energy for a diagonal noise covariance parametrized by its log.

    The noise covariance is ``N = diag(dunit**2 * exp(position))``, so
    ``position`` is the log noise variance relative to ``dunit**2``.

    For every sample ``s`` in ``sample_list`` a residual is formed::

        A = Projection.adjoint_times(munit * exp(0.5 * t))
        r = d - Instrument(sunit * nonlinearity(ht(A * s)))

    The energy value is the sample average of ``0.5 * r^dagger N^{-1} r``.
    To that average the following position-dependent terms are added:

    * ``0.5 * sum(position)``, i.e. ``0.5 * log det N`` up to a constant.
    * ``(alpha - 1) . position + q . exp(-position)``, the prior terms
      (NOTE(review): presumably an inverse-gamma-type prior; confirm).

    Parameters
    ----------
    position : Field
        Log noise variance (see above).
    d : Field
        Data.
    m : Field
        Used as the single sample when ``samples`` is None or 0. Otherwise
        it is added to each posterior sample drawn from ``D``.
    D : operator
        Must provide ``generate_posterior_sample()``. Only used when
        ``sample_list`` is None and ``samples`` is nonzero.
    t : Field
        ``munit * exp(0.5 * t)`` is mapped through
        ``Projection.adjoint_times`` to give the amplitude ``A``.
    ht : operator
        Applied to ``A * sample`` to obtain the signal map
        (presumably a harmonic transform; confirm with caller).
    Instrument : operator
        Response applied to ``sunit * nonlinearity(map)``.
    nonlinearity : callable
        Applied to the signal map before the instrument response.
    alpha, q : Field
        Coefficients of the prior terms.
    Projection : operator
        Provides ``adjoint_times`` used to build ``A``.
    munit, sunit, dunit : float
        Unit scale factors for amplitude, signal and data/noise.
    samples : int or None
        Number of posterior samples to draw if ``sample_list`` is None.
    sample_list : list of Field or None
        Precomputed samples. ``at`` reuses these so all positions share
        the same samples.
    inverter : object or None
        Stored and forwarded by ``at``; not used by this class itself.

    Raises
    ------
    ValueError
        If ``sample_list`` (given or generated) is empty.
    """

    def __init__(self, position, d, m, D, t, ht, Instrument,
                 nonlinearity, alpha, q, Projection, munit=1., sunit=1.,
                 dunit=1., samples=3, sample_list=None, inverter=None):
        super(NoiseEnergy, self).__init__(position=position)
        self.m = m
        self.D = D
        self.d = d
        self.N = DiagonalOperator(diagonal=dunit**2 * exp(self.position))
        self.t = t
        self.samples = samples
        self.ht = ht
        self.Instrument = Instrument
        self.nonlinearity = nonlinearity
        self.munit = munit
        self.sunit = sunit
        self.dunit = dunit

        self.alpha = alpha
        self.q = q
        self.Projection = Projection
        # Amplitude applied to every sample before ``ht`` (unit: munit).
        self.power = self.Projection.adjoint_times(munit * exp(0.5 * self.t))
        # Constant-one field on the position domain; not used in this class.
        self.one = Field(self.position.domain, val=1.)

        if sample_list is None:
            if samples is None or samples == 0:
                sample_list = [m]
            else:
                sample_list = [D.generate_posterior_sample() + m
                               for _ in range(samples)]
        if len(sample_list) == 0:
            # Averaging over zero samples is undefined; fail loudly here
            # instead of with an AttributeError on ``_value`` later.
            raise ValueError("sample_list must contain at least one sample")
        self.sample_list = sample_list
        self.inverter = inverter

        # Identical to the amplitude above; reuse instead of recomputing.
        A = self.power

        value = None
        gradient = None
        for sample in self.sample_list:
            map_s = self.ht(A * sample)
            residual = self.d - \
                self.Instrument(sunit * self.nonlinearity(map_s))

            lh = .5 * residual.vdot(self.N.inverse_times(residual))
            # d/dposition of 0.5 |r|^2 exp(-position) / dunit**2
            grad = -.5 * self.N.inverse_times(residual.conjugate() * residual)

            if gradient is None:
                value = lh
                # Copy so in-place accumulation never mutates ``grad``.
                gradient = grad.copy()
            else:
                value += lh
                gradient += grad

        n_samples = len(self.sample_list)

        value *= 1. / n_samples
        value += .5 * self.position.sum()
        value += (self.alpha - 1.).vdot(self.position) + \
            self.q.vdot(exp(-self.position))
        self._value = value

        gradient *= 1. / n_samples
        # 0.5 from the log-determinant plus (alpha - 1) from the prior.
        gradient += (self.alpha - 0.5) - self.q * (exp(-self.position))
        self._gradient = gradient

    def at(self, position):
        """Return a new NoiseEnergy at ``position``.

        The same ``sample_list`` and all other parameters are reused, so
        no new posterior samples are drawn.
        """
        return self.__class__(
            position, self.d, self.m, self.D, self.t, self.ht,
            self.Instrument, self.nonlinearity, self.alpha, self.q,
            self.Projection, munit=self.munit, sunit=self.sunit,
            dunit=self.dunit, sample_list=self.sample_list,
            samples=self.samples, inverter=self.inverter)

    @property
    def value(self):
        """Energy value computed in the constructor."""
        return self._value

    @property
    def gradient(self):
        """Gradient of the energy w.r.t. ``position``, computed in the constructor."""
        return self._gradient