Commit 12219de3 authored by Philipp Arras's avatar Philipp Arras

Clean up noise energy

parent e110e7d5
......@@ -17,51 +17,26 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from .. import Field, exp
from ..operators.diagonal_operator import DiagonalOperator
from ..minimization.energy import Energy
# TODO Take only residual_sample_list as argument
from ..operators.diagonal_operator import DiagonalOperator
class NoiseEnergy(Energy):
def __init__(self, position, d, xi, D, t, ht, Instrument,
nonlinearity, alpha, q, Distributor, samples=3,
xi_sample_list=None, inverter=None):
def __init__(self, position, alpha, q, res_sample_list):
super(NoiseEnergy, self).__init__(position=position)
self.xi = xi
self.D = D
self.d = d
self.N = DiagonalOperator(diagonal=exp(self.position))
self.t = t
self.samples = samples
self.ht = ht
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.N = DiagonalOperator(diagonal=exp(self.position))
self.alpha = alpha
self.q = q
self.Distributor = Distributor
self.power = self.Distributor(exp(0.5 * self.t))
if xi_sample_list is None:
if samples is None or samples == 0:
xi_sample_list = [xi]
else:
xi_sample_list = [D.draw_sample() + xi
for _ in range(samples)]
self.xi_sample_list = xi_sample_list
self.inverter = inverter
A = Distributor(exp(.5*self.t))
alpha_field = Field(self.position.domain, val=alpha)
q_field = Field(self.position.domain, val=q)
self.res_sample_list = res_sample_list
self._gradient = None
for sample in self.xi_sample_list:
map_s = self.ht(A * sample)
residual = self.d - \
self.Instrument(self.nonlinearity(map_s))
lh = .5 * residual.vdot(self.N.inverse_times(residual))
grad = -.5 * self.N.inverse_times(residual.conjugate()*residual)
for s in self.res_sample_list:
lh = .5 * s.vdot(self.N.inverse_times(s))
grad = -.5 * self.N.inverse_times(s.conjugate()*s)
if self._gradient is None:
self._value = lh
self._gradient = grad.copy()
......@@ -69,20 +44,16 @@ class NoiseEnergy(Energy):
self._value += lh
self._gradient += grad
self._value *= 1. / len(self.xi_sample_list)
self._value /= len(self.res_sample_list)
self._value += .5 * self.position.sum()
self._value += (self.alpha - 1.).vdot(self.position) + \
self.q.vdot(exp(-self.position))
self._value += (alpha_field - 1.).vdot(self.position) + \
q_field.vdot(exp(-self.position))
self._gradient *= 1. / len(self.xi_sample_list)
self._gradient += (self.alpha-0.5) - self.q*(exp(-self.position))
self._gradient /= len(self.res_sample_list)
self._gradient += (alpha_field-0.5) - q_field*(exp(-self.position))
def at(self, position):
return self.__class__(
position, self.d, self.xi, self.D, self.t, self.ht,
self.Instrument, self.nonlinearity, self.alpha, self.q,
self.Distributor, xi_sample_list=self.xi_sample_list,
samples=self.samples, inverter=self.inverter)
return self.__class__(position, self.alpha, self.q, self.res_sample_list)
@property
def value(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment