Commit 103043d1 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'noise_energy' into 'NIFTy_4'

Noise energy

See merge request ift/NIFTy!226
parents 61fe9b08 99ecec99
Pipeline #25683 passed with stages
in 5 minutes and 51 seconds
...@@ -16,52 +16,27 @@ ...@@ -16,52 +16,27 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .. import Field, exp from ..field import Field, exp
from ..operators.diagonal_operator import DiagonalOperator
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator
# TODO Take only residual_sample_list as argument
class NoiseEnergy(Energy): class NoiseEnergy(Energy):
def __init__(self, position, d, xi, D, t, ht, Instrument, def __init__(self, position, alpha, q, res_sample_list):
nonlinearity, alpha, q, Distributor, samples=3, super(NoiseEnergy, self).__init__(position)
xi_sample_list=None, inverter=None):
super(NoiseEnergy, self).__init__(position=position)
self.xi = xi
self.D = D
self.d = d
self.N = DiagonalOperator(diagonal=exp(self.position))
self.t = t
self.samples = samples
self.ht = ht
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.N = DiagonalOperator(diagonal=exp(self.position))
self.alpha = alpha self.alpha = alpha
self.q = q self.q = q
self.Distributor = Distributor alpha_field = Field(self.position.domain, val=alpha)
self.power = self.Distributor(exp(0.5 * self.t)) q_field = Field(self.position.domain, val=q)
if xi_sample_list is None: self.res_sample_list = res_sample_list
if samples is None or samples == 0:
xi_sample_list = [xi]
else:
xi_sample_list = [D.draw_sample() + xi
for _ in range(samples)]
self.xi_sample_list = xi_sample_list
self.inverter = inverter
A = Distributor(exp(.5*self.t))
self._gradient = None self._gradient = None
for sample in self.xi_sample_list:
map_s = self.ht(A * sample)
residual = self.d - \
self.Instrument(self.nonlinearity(map_s))
lh = .5 * residual.vdot(self.N.inverse_times(residual))
grad = -.5 * self.N.inverse_times(residual.conjugate()*residual)
for s in self.res_sample_list:
lh = .5 * s.vdot(self.N.inverse_times(s))
grad = -.5 * self.N.inverse_times(s.conjugate()*s)
if self._gradient is None: if self._gradient is None:
self._value = lh self._value = lh
self._gradient = grad.copy() self._gradient = grad.copy()
...@@ -69,20 +44,19 @@ class NoiseEnergy(Energy): ...@@ -69,20 +44,19 @@ class NoiseEnergy(Energy):
self._value += lh self._value += lh
self._gradient += grad self._gradient += grad
self._value *= 1. / len(self.xi_sample_list) expmpos = exp(-position)
self._value *= 1./len(self.res_sample_list)
self._value += .5 * self.position.sum() self._value += .5 * self.position.sum()
self._value += (self.alpha - 1.).vdot(self.position) + \ self._value += (alpha_field-1.).vdot(self.position) + \
self.q.vdot(exp(-self.position)) q_field.vdot(expmpos)
self._gradient *= 1. / len(self.xi_sample_list) self._gradient *= 1./len(self.res_sample_list)
self._gradient += (self.alpha-0.5) - self.q*(exp(-self.position)) self._gradient += (alpha_field-0.5) - q_field*expmpos
self._gradient.lock()
def at(self, position): def at(self, position):
return self.__class__( return self.__class__(position, self.alpha, self.q,
position, self.d, self.xi, self.D, self.t, self.ht, self.res_sample_list)
self.Instrument, self.nonlinearity, self.alpha, self.q,
self.Distributor, xi_sample_list=self.xi_sample_list,
samples=self.samples, inverter=self.inverter)
@property @property
def value(self): def value(self):
......
...@@ -73,7 +73,7 @@ class Noise_Energy_Tests(unittest.TestCase): ...@@ -73,7 +73,7 @@ class Noise_Energy_Tests(unittest.TestCase):
inverter = ift.ConjugateGradient(IC) inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS) S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
D = ift.library.NonlinearWienerFilterEnergy( C = ift.library.NonlinearWienerFilterEnergy(
position=xi, position=xi,
d=d, d=d,
Instrument=R, Instrument=R,
...@@ -84,10 +84,10 @@ class Noise_Energy_Tests(unittest.TestCase): ...@@ -84,10 +84,10 @@ class Noise_Energy_Tests(unittest.TestCase):
S=S, S=S,
inverter=inverter).curvature inverter=inverter).curvature
energy0 = ift.library.NoiseEnergy( res_sample_list = [d - R(f(ht(C.draw_sample() + xi)))
position=eta0, d=d, xi=xi, D=D, t=tau, Instrument=R, for _ in range(10)]
alpha=alpha, q=q, Distributor=Dist, nonlinearity=f,
ht=ht, samples=10) energy0 = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
energy1 = energy0.at(eta1) energy1 = energy0.at(eta1)
a = (energy1.value - energy0.value) / eps a = (energy1.value - energy0.value) / eps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment