diff --git a/nifty4/library/noise_energy.py b/nifty4/library/noise_energy.py index 88bfa41be4f10d9f7b3a40e8d761de26225993bf..6bf8cd1b435a63442ac272a63d02a19956680033 100644 --- a/nifty4/library/noise_energy.py +++ b/nifty4/library/noise_energy.py @@ -17,51 +17,26 @@ # and financially supported by the Studienstiftung des deutschen Volkes. from .. import Field, exp -from ..operators.diagonal_operator import DiagonalOperator from ..minimization.energy import Energy - -# TODO Take only residual_sample_list as argument +from ..operators.diagonal_operator import DiagonalOperator class NoiseEnergy(Energy): - def __init__(self, position, d, xi, D, t, ht, Instrument, - nonlinearity, alpha, q, Distributor, samples=3, - xi_sample_list=None, inverter=None): + def __init__(self, position, alpha, q, res_sample_list): super(NoiseEnergy, self).__init__(position=position) - self.xi = xi - self.D = D - self.d = d - self.N = DiagonalOperator(diagonal=exp(self.position)) - self.t = t - self.samples = samples - self.ht = ht - self.Instrument = Instrument - self.nonlinearity = nonlinearity + self.N = DiagonalOperator(diagonal=exp(self.position)) self.alpha = alpha self.q = q - self.Distributor = Distributor - self.power = self.Distributor(exp(0.5 * self.t)) - if xi_sample_list is None: - if samples is None or samples == 0: - xi_sample_list = [xi] - else: - xi_sample_list = [D.draw_sample() + xi - for _ in range(samples)] - self.xi_sample_list = xi_sample_list - self.inverter = inverter - - A = Distributor(exp(.5*self.t)) + alpha_field = Field(self.position.domain, val=alpha) + q_field = Field(self.position.domain, val=q) + self.res_sample_list = res_sample_list self._gradient = None - for sample in self.xi_sample_list: - map_s = self.ht(A * sample) - - residual = self.d - \ - self.Instrument(self.nonlinearity(map_s)) - lh = .5 * residual.vdot(self.N.inverse_times(residual)) - grad = -.5 * self.N.inverse_times(residual.conjugate()*residual) + for s in self.res_sample_list: + lh = .5 * s.vdot(self.N.inverse_times(s)) + grad = -.5 * self.N.inverse_times(s.conjugate()*s) if self._gradient is None: self._value = lh self._gradient = grad.copy() @@ -69,20 +44,16 @@ class NoiseEnergy(Energy): self._value += lh self._gradient += grad - self._value *= 1. / len(self.xi_sample_list) + self._value /= len(self.res_sample_list) self._value += .5 * self.position.sum() - self._value += (self.alpha - 1.).vdot(self.position) + \ - self.q.vdot(exp(-self.position)) + self._value += (alpha_field - 1.).vdot(self.position) + \ + q_field.vdot(exp(-self.position)) - self._gradient *= 1. / len(self.xi_sample_list) - self._gradient += (self.alpha-0.5) - self.q*(exp(-self.position)) + self._gradient /= len(self.res_sample_list) + self._gradient += (alpha_field-0.5) - q_field*(exp(-self.position)) def at(self, position): - return self.__class__( - position, self.d, self.xi, self.D, self.t, self.ht, - self.Instrument, self.nonlinearity, self.alpha, self.q, - self.Distributor, xi_sample_list=self.xi_sample_list, - samples=self.samples, inverter=self.inverter) + return self.__class__(position, self.alpha, self.q, self.res_sample_list) @property def value(self):