Commit 6079d9bf authored by Martin Reinecke's avatar Martin Reinecke
Browse files

less is more

parent 04c80477
Pipeline #26556 failed with stages
in 4 minutes and 45 seconds
......@@ -19,39 +19,40 @@
from ..field import Field, exp
from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator
import numpy as np
class NoiseEnergy(Energy):
def __init__(self, position, alpha, q, res_sample_list):
super(NoiseEnergy, self).__init__(position)
self.N = DiagonalOperator(diagonal=exp(self.position))
self.N = DiagonalOperator(exp(self.position))
self.alpha = alpha
self.q = q
alpha_field = Field(self.position.domain, val=alpha)
q_field = Field(self.position.domain, val=q)
self.res_sample_list = res_sample_list
self._gradient = None
for s in self.res_sample_list:
for s in res_sample_list:
lh = .5 * s.vdot(self.N.inverse_times(s))
grad = -.5 * self.N.inverse_times(s.conjugate()*s)
if self._gradient is None:
self._value = lh
self._gradient = grad.copy()
self._gradient = grad
else:
self._value += lh
self._gradient += grad
expmpos = exp(-position)
self._value *= 1./len(self.res_sample_list)
self._value += .5 * self.position.sum()
self._value += (alpha_field-1.).vdot(self.position) + \
q_field.vdot(expmpos)
possum = position.sum()
s1 = (alpha-1.)*possum if np.isscalar(alpha) \
else (alpha-1.).vdot(position)
s2 = q*expmpos.sum() if np.isscalar(q) else q.vdot(expmpos)
self._value += .5*possum + s1 + s2
self._gradient *= 1./len(self.res_sample_list)
self._gradient += (alpha_field-0.5) - q_field*expmpos
self._gradient *= 1./len(res_sample_list)
self._gradient += (alpha-0.5) - q*expmpos
self._gradient.lock()
def at(self, position):
......
......@@ -18,7 +18,6 @@
from ..operators.sandwich_operator import SandwichOperator
from ..operators.inversion_enabler import InversionEnabler
import numpy as np
def WienerFilterCurvature(R, N, S, inverter):
......
......@@ -16,11 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..minimization.energy import Energy
from ..minimization.quadratic_energy import QuadraticEnergy
from .wiener_filter_curvature import WienerFilterCurvature
class WienerFilterEnergy(Energy):
def WienerFilterEnergy(position, d, R, N, S, inverter):
"""The Energy for the Wiener filter.
It covers the case of linear measurement with
......@@ -42,33 +42,6 @@ class WienerFilterEnergy(Energy):
inverter : Minimizer
the minimization strategy to use for operator inversion
"""
def __init__(self, position, d, R, N, S, inverter, _j=None):
super(WienerFilterEnergy, self).__init__(position=position)
self.R = R
self.N = N
self.S = S
self._curvature = WienerFilterCurvature(R, N, S, inverter)
self._inverter = inverter
if _j is None:
_j = R.adjoint_times(N.inverse_times(d))
self._j = _j
Dx = self._curvature(self.position)
self._value = 0.5*position.vdot(Dx) - self._j.vdot(position)
self._gradient = (Dx - self._j).lock()
def at(self, position):
return self.__class__(position=position, d=None, R=self.R, N=self.N,
S=self.S, inverter=self._inverter, _j=self._j)
@property
def value(self):
return self._value
@property
def gradient(self):
return self._gradient
@property
def curvature(self):
return self._curvature
op = WienerFilterCurvature(R, N, S, inverter)
vec = R.adjoint_times(N.inverse_times(d))
return QuadraticEnergy(position, op, vec)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment