Commit 74c750ca authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent 95ec8d23
Pipeline #22051 passed with stage
in 4 minutes and 43 seconds
......@@ -82,3 +82,6 @@ Significant differences between NIFTy nightly and nifty2go
14) A new approach is used for FFTs along axes that are distributed among
MPI tasks. As a consequence, nifty2go works well with the standard version
of pyfftw and does not need the MPI-enabled fork.
15) Arithmetic functions working on Fields have been moved from
basic_arithmetics.py to field.py.
......@@ -73,7 +73,7 @@ class CriticalPowerEnergy(Energy):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior,
logarithmic=self.logarithmic,
w=self.w, samples=self.samples,
samples=self.samples, w=self.w,
inverter=self._inverter)
@property
......
......@@ -20,10 +20,10 @@ class WienerFilterCurvature(EndomorphicOperator):
"""
def __init__(self, R, N, S):
super(WienerFilterCurvature, self).__init__()
self.R = R
self.N = N
self.S = S
super(WienerFilterCurvature, self).__init__()
@property
def preconditioner(self):
......
from ..minimization.energy import Energy
from ..utilities import memo
from ..operators.inversion_enabler import InversionEnabler
from .wiener_filter_curvature import WienerFilterCurvature
......@@ -26,41 +25,31 @@ class WienerFilterEnergy(Energy):
def __init__(self, position, d, R, N, S, inverter, _j=None):
super(WienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
self.N = N
self.S = S
self._curvature = InversionEnabler(WienerFilterCurvature(R, N, S),
inverter=inverter)
self._inverter = inverter
self._jpre = _j
if _j is None:
_j = self.R.adjoint_times(self.N.inverse_times(d))
self._j = _j
Dx = self._curvature(self.position)
self._value = 0.5*self.position.vdot(Dx) - self._j.vdot(self.position)
self._gradient = Dx - self._j
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S, inverter=self._inverter, _j=self._jpre)
return self.__class__(position=position, d=None, R=self.R, N=self.N,
S=self.S, inverter=self._inverter, _j=self._j)
@property
@memo
def value(self):
return 0.5*self.position.vdot(self._Dx) - self._j.vdot(self.position)
return self._value
@property
@memo
def gradient(self):
return self._Dx - self._j
return self._gradient
@property
@memo
def curvature(self):
return InversionEnabler(WienerFilterCurvature(R=self.R, N=self.N,
S=self.S),
inverter=self._inverter)
@property
@memo
def _Dx(self):
return self.curvature(self.position)
@property
def _j(self):
if self._jpre is None:
self._jpre = self.R.adjoint_times(self.N.inverse_times(self.d))
return self._jpre
return self._curvature
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment