# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from .. import Field, exp
from ..operators.diagonal_operator import DiagonalOperator
from ..minimization.energy import Energy


class NoiseEnergy(Energy):
    """Energy of a log noise variance for a nonlinear response model.

    The position ``eta`` parameterizes a diagonal noise covariance
    ``N = diag(exp(eta))``. The energy value is

        mean_s[0.5 * r_s^dagger N^{-1} r_s] + 0.5 * eta.integrate()
            + (alpha - 1) . eta + q . exp(-eta)

    where the mean runs over ``sample_list`` and

        r_s = d - Instrument(nonlinearity(FFT.adjoint_times(power * s)))
        power = Projection.adjoint_times(exp(0.5 * t)).

    The ``alpha``/``q`` terms presumably encode an inverse-gamma-type prior
    on the noise variance -- NOTE(review): confirm against the derivation.

    Parameters
    ----------
    position : Field
        Log noise variance ``eta``.
    d : Field
        Data.
    m : Field
        Signal estimate. Used as the only sample when no sampling is
        requested; otherwise added to every drawn posterior sample.
    D : operator
        Must provide ``generate_posterior_sample()``. Only used when
        ``sample_list`` is None and ``samples`` is neither None nor 0.
    t : Field
        Presumably the log power spectrum (``exp(0.5 * t)`` is used as an
        amplitude) -- confirm with callers.
    FFT : operator
        Its ``adjoint_times`` is applied to ``power * sample``.
    Instrument : callable
        Response applied after ``nonlinearity``.
    nonlinearity : callable
        Applied to the output of ``FFT.adjoint_times``.
    alpha, q : Field
        Prior parameters entering the value as shown above.
    Projection : operator
        Its ``adjoint_times`` maps ``exp(0.5 * t)`` to ``power``.
    samples : int or None, optional
        Number of posterior samples to draw when ``sample_list`` is None.
        ``None`` or ``0`` means "use ``m`` only".
    sample_list : list of Field or None, optional
        Pre-drawn signal samples. Used as-is when given, and shared with
        energies created by ``at``.
    inverter : object or None, optional
        Stored and forwarded by ``at``; not used otherwise in this class.

    Raises
    ------
    ValueError
        If the sample list is empty (an explicit ``sample_list=[]``, or a
        negative ``samples``).
    """

    def __init__(self, position, d, m, D, t, FFT, Instrument, nonlinearity,
                 alpha, q, Projection, samples=3, sample_list=None,
                 inverter=None):
        super(NoiseEnergy, self).__init__(position=position)
        self.m = m
        self.D = D
        self.d = d
        # Noise covariance: diagonal with entries exp(position).
        self.N = DiagonalOperator(diagonal=exp(self.position))
        self.t = t
        self.samples = samples
        self.FFT = FFT
        self.Instrument = Instrument
        self.nonlinearity = nonlinearity

        self.alpha = alpha
        self.q = q
        self.Projection = Projection
        self.power = self.Projection.adjoint_times(exp(0.5 * self.t))
        # NOTE(review): not referenced anywhere in this class; kept because
        # outside code may read it.
        self.one = Field(self.position.domain, val=1.)

        if sample_list is None:
            if samples is None or samples == 0:
                # No sampling requested: evaluate at the estimate m alone.
                sample_list = [m]
            else:
                sample_list = [D.generate_posterior_sample() + m
                               for _ in range(samples)]
        self.sample_list = sample_list
        self.inverter = inverter
        self._value, self._gradient = self._value_and_grad()

    def at(self, position):
        """Return a new NoiseEnergy at `position` with all other inputs reused.

        The existing ``sample_list`` is passed on, so no new samples are
        drawn.
        """
        return self.__class__(
            position, self.d, self.m, self.D, self.t, self.FFT,
            self.Instrument, self.nonlinearity, self.alpha, self.q,
            self.Projection, sample_list=self.sample_list,
            samples=self.samples, inverter=self.inverter)

    def _value_and_grad(self):
        """Compute the energy value and its gradient at ``self.position``.

        Returns
        -------
        tuple
            ``(value, gradient)``.

        Raises
        ------
        ValueError
            If ``self.sample_list`` is empty (the sample average would be
            undefined).
        """
        if not self.sample_list:
            raise ValueError("NoiseEnergy needs at least one signal sample")
        n_samples = float(len(self.sample_list))

        likelihood = 0.
        likelihood_gradient = None
        for sample in self.sample_list:
            residual = self.d - self.Instrument(self.nonlinearity(
                self.FFT.adjoint_times(self.power * sample)))
            # r^dagger N^{-1} r is real-valued, but with complex residuals
            # vdot returns a complex number; keep the real part so the
            # energy value is a real scalar (no-op for real data).
            likelihood += 0.5 * residual.vdot(
                self.N.inverse_times(residual)).real
            # d/d(eta) of 0.5 * |r|^2 * exp(-eta) is -0.5 * |r|^2 * exp(-eta).
            # NOTE(review): for complex data this field has complex dtype
            # (zero imaginary part); confirm the minimizer copes with that.
            grad = -0.5 * self.N.inverse_times(residual.conjugate() * residual)
            if likelihood_gradient is None:
                likelihood_gradient = grad.copy()
            else:
                likelihood_gradient += grad

        exp_neg_position = exp(-self.position)
        likelihood = (likelihood / n_samples +
                      0.5 * self.position.integrate() +
                      (self.alpha - 1.).vdot(self.position) +
                      self.q.vdot(exp_neg_position))
        # (alpha - 0.5) = (alpha - 1) from the linear prior term + 0.5 from
        # the integrate() term. NOTE(review): that 0.5 is the exact
        # derivative only if the domain's pixel volumes are 1 -- confirm.
        likelihood_gradient = (likelihood_gradient / n_samples +
                               (self.alpha - 0.5) -
                               self.q * exp_neg_position)
        return likelihood, likelihood_gradient

    @property
    def value(self):
        """Energy value, computed once at construction."""
        return self._value

    @property
    def gradient(self):
        """Gradient of the energy w.r.t. the position, computed at construction."""
        return self._gradient