Commit a35319be authored by Martin Reinecke's avatar Martin Reinecke
Browse files

performance tweak

parent b94097df
Pipeline #22730 passed with stage
in 4 minutes and 51 seconds
from .. import Field, exp from .. import Field, exp
from ..operators.diagonal_operator import DiagonalOperator from ..operators.diagonal_operator import DiagonalOperator
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..utilities import memo
class NoiseEnergy(Energy): class NoiseEnergy(Energy):
def __init__(self, position, d, m, D, t, FFT, Instrument, nonlinearity, def __init__(self, position, d, m, D, t, FFT, Instrument, nonlinearity,
alpha, q, Projection, samples=3, sample_list=None, alpha, q, Projection, samples=3, sample_list=None,
inverter=None): inverter=None):
super(NoiseEnergy, self).__init__(position=position.copy()) super(NoiseEnergy, self).__init__(position=position)
self.m = m self.m = m
self.D = D self.D = D
self.d = d self.d = d
...@@ -25,17 +24,14 @@ class NoiseEnergy(Energy): ...@@ -25,17 +24,14 @@ class NoiseEnergy(Energy):
self.power = self.Projection.adjoint_times(exp(0.5 * self.t)) self.power = self.Projection.adjoint_times(exp(0.5 * self.t))
self.one = Field(self.position.domain, val=1.) self.one = Field(self.position.domain, val=1.)
if sample_list is None: if sample_list is None:
sample_list = [] if samples is None or samples == 0:
if samples is None: sample_list = [m]
sample_list.append(self.m)
else: else:
for i in range(samples): sample_list = [D.generate_posterior_sample(m)
sample = D.generate_posterior_sample(m) for _ in range(samples)]
sample = FFT(Field(FFT.domain, val=(
FFT.adjoint_times(sample).val)))
sample_list.append(sample)
self.sample_list = sample_list self.sample_list = sample_list
self.inverter = inverter self.inverter = inverter
self._value, self._gradient = self._value_and_grad()
def at(self, position): def at(self, position):
return self.__class__( return self.__class__(
...@@ -44,43 +40,35 @@ class NoiseEnergy(Energy): ...@@ -44,43 +40,35 @@ class NoiseEnergy(Energy):
self.Projection, sample_list=self.sample_list, self.Projection, sample_list=self.sample_list,
samples=self.samples, inverter=self.inverter) samples=self.samples, inverter=self.inverter)
@property def _value_and_grad(self):
@memo likelihood_gradient = None
def value(self):
likelihood = 0.
for sample in self.sample_list: for sample in self.sample_list:
likelihood += self._likelihood(sample) residual = self.d - \
return ((likelihood / float(len(self.sample_list))) + self.Instrument(self.nonlinearity(
0.5 * self.one.vdot(self.position) + self.FFT.adjoint_times(self.power*sample)))
(self.alpha - self.one).vdot(self.position) + lh = 0.5 * residual.vdot(self.N.inverse_times(residual))
self.q.vdot(exp(-self.position))) grad = -0.5 * self.N.inverse_times(residual.conjugate() * residual)
if likelihood_gradient is None:
likelihood = lh
likelihood_gradient = grad.copy()
else:
likelihood += lh
likelihood_gradient += grad
def _likelihood(self, m): likelihood = ((likelihood / float(len(self.sample_list))) +
residual = self.d - \ 0.5 * self.position.integrate() +
self.Instrument(self.nonlinearity( (self.alpha - 1.).vdot(self.position) +
self.FFT.adjoint_times(self.power * m))) self.q.vdot(exp(-self.position)))
energy = 0.5 * residual.vdot(self.N.inverse_times(residual)) likelihood_gradient = (
return energy.real likelihood_gradient / float(len(self.sample_list)) +
(self.alpha-0.5) -
self.q * (exp(-self.position)))
return likelihood, likelihood_gradient
@property @property
@memo def value(self):
def gradient(self): return self._value
likelihood_gradient = Field(self.position.domain, val=0.)
for sample in self.sample_list:
likelihood_gradient += self._likelihood_gradient(sample)
return (likelihood_gradient / float(len(self.sample_list)) +
0.5 * self.one + (self.alpha - self.one) -
self.q * (exp(-self.position)))
def _likelihood_gradient(self, m):
residual = self.d - \
self.Instrument(self.nonlinearity(
self.FFT.adjoint_times(self.power * m)))
gradient = - 0.5 * \
self.N.inverse_times(residual.conjugate() * residual)
return gradient
@property @property
@memo def gradient(self):
def curvature(self): return self._gradient
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment