Commit 12219de3 authored by Philipp Arras's avatar Philipp Arras

Clean up noise energy

parent e110e7d5
...@@ -17,51 +17,26 @@ ...@@ -17,51 +17,26 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .. import Field, exp from .. import Field, exp
from ..operators.diagonal_operator import DiagonalOperator
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator
# TODO Take only residual_sample_list as argument
class NoiseEnergy(Energy): class NoiseEnergy(Energy):
def __init__(self, position, d, xi, D, t, ht, Instrument, def __init__(self, position, alpha, q, res_sample_list):
nonlinearity, alpha, q, Distributor, samples=3,
xi_sample_list=None, inverter=None):
super(NoiseEnergy, self).__init__(position=position) super(NoiseEnergy, self).__init__(position=position)
self.xi = xi
self.D = D
self.d = d
self.N = DiagonalOperator(diagonal=exp(self.position))
self.t = t
self.samples = samples
self.ht = ht
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.N = DiagonalOperator(diagonal=exp(self.position))
self.alpha = alpha self.alpha = alpha
self.q = q self.q = q
self.Distributor = Distributor alpha_field = Field(self.position.domain, val=alpha)
self.power = self.Distributor(exp(0.5 * self.t)) q_field = Field(self.position.domain, val=q)
if xi_sample_list is None: self.res_sample_list = res_sample_list
if samples is None or samples == 0:
xi_sample_list = [xi]
else:
xi_sample_list = [D.draw_sample() + xi
for _ in range(samples)]
self.xi_sample_list = xi_sample_list
self.inverter = inverter
A = Distributor(exp(.5*self.t))
self._gradient = None self._gradient = None
for sample in self.xi_sample_list:
map_s = self.ht(A * sample)
residual = self.d - \
self.Instrument(self.nonlinearity(map_s))
lh = .5 * residual.vdot(self.N.inverse_times(residual))
grad = -.5 * self.N.inverse_times(residual.conjugate()*residual)
for s in self.res_sample_list:
lh = .5 * s.vdot(self.N.inverse_times(s))
grad = -.5 * self.N.inverse_times(s.conjugate()*s)
if self._gradient is None: if self._gradient is None:
self._value = lh self._value = lh
self._gradient = grad.copy() self._gradient = grad.copy()
...@@ -69,20 +44,16 @@ class NoiseEnergy(Energy): ...@@ -69,20 +44,16 @@ class NoiseEnergy(Energy):
self._value += lh self._value += lh
self._gradient += grad self._gradient += grad
self._value *= 1. / len(self.xi_sample_list) self._value /= len(self.res_sample_list)
self._value += .5 * self.position.sum() self._value += .5 * self.position.sum()
self._value += (self.alpha - 1.).vdot(self.position) + \ self._value += (alpha_field - 1.).vdot(self.position) + \
self.q.vdot(exp(-self.position)) q_field.vdot(exp(-self.position))
self._gradient *= 1. / len(self.xi_sample_list) self._gradient /= len(self.res_sample_list)
self._gradient += (self.alpha-0.5) - self.q*(exp(-self.position)) self._gradient += (alpha_field-0.5) - q_field*(exp(-self.position))
def at(self, position): def at(self, position):
return self.__class__( return self.__class__(position, self.alpha, self.q, self.res_sample_list)
position, self.d, self.xi, self.D, self.t, self.ht,
self.Instrument, self.nonlinearity, self.alpha, self.q,
self.Distributor, xi_sample_list=self.xi_sample_list,
samples=self.samples, inverter=self.inverter)
@property @property
def value(self): def value(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment