diff --git a/demos/critical_filtering.py b/demos/critical_filtering.py index 28d9f0e3e5ce73c0cd414b92928bdb81a34cc1e8..19fcb4a8876fb94d70f9e96d83fe3bcc67aa92f1 100644 --- a/demos/critical_filtering.py +++ b/demos/critical_filtering.py @@ -96,7 +96,7 @@ if __name__ == "__main__": IC2 = ift.GradientNormController(verbose=True, iteration_limit=100, tol_abs_gradnorm=0.1) minimizer2 = ift.VL_BFGS(IC2, max_history_length=20) - IC3 = ift.GradientNormController(verbose=True, iteration_limit=100, + IC3 = ift.GradientNormController(verbose=True, iteration_limit=1000, tol_abs_gradnorm=0.1) minimizer3 = ift.SteepestDescent(IC3) diff --git a/nifty/library/critical_filter/critical_power_curvature.py b/nifty/library/critical_filter/critical_power_curvature.py index 0f355b1fa9bd378f1ad0eadab68bb8295dc14777..ae8bb6a9bff22349f9675cdfcd516f8a1791d519 100644 --- a/nifty/library/critical_filter/critical_power_curvature.py +++ b/nifty/library/critical_filter/critical_power_curvature.py @@ -21,7 +21,7 @@ class CriticalPowerCurvature(EndomorphicOperator): # ---Overwritten properties and methods--- def __init__(self, theta, T): - self.theta = DiagonalOperator(theta.weight(1)) + self.theta = DiagonalOperator(theta) self.T = T super(CriticalPowerCurvature, self).__init__() diff --git a/nifty/library/critical_filter/critical_power_energy.py b/nifty/library/critical_filter/critical_power_energy.py index 8d3bebc75f73f5fd255d194dee3d69f8188bb3b5..0c7af23a579b6b0c5834339b2148f9fa3f2ee8a2 100644 --- a/nifty/library/critical_filter/critical_power_energy.py +++ b/nifty/library/critical_filter/critical_power_energy.py @@ -1,11 +1,11 @@ from ...energies.energy import Energy from ...operators.smoothness_operator import SmoothnessOperator +from ...operators.power_projection_operator import PowerProjectionOperator from ...operators.inversion_enabler import InversionEnabler from . import CriticalPowerCurvature from ...memoization import memo - -from ...sugar import generate_posterior_sample, power_analyze from ... import Field, exp +from ...sugar import generate_posterior_sample class CriticalPowerEnergy(Energy): @@ -66,7 +66,8 @@ class CriticalPowerEnergy(Energy): self.T = SmoothnessOperator(domain=self.position.domain[0], strength=smoothness_prior, logarithmic=logarithmic) - self.rho = self.position.domain[0].rho + self.P = PowerProjectionOperator(domain=self.m.domain, + power_space=self.position.domain[0]) self._w = w self._inverter = inverter @@ -80,26 +81,27 @@ class CriticalPowerEnergy(Energy): inverter=self._inverter) @property + @memo def value(self): - energy = self._theta.sum() - energy += self.position.weight(-1).vdot(self._rho_prime) + energy = Field.ones_like(self.position).vdot(self._theta) + energy += self.position.vdot(self.alpha-0.5) energy += 0.5 * self.position.vdot(self._Tt) return energy.real @property + @memo def gradient(self): - gradient = -self._theta.weight(-1) - gradient += self._rho_prime.weight(-1) + gradient = -self._theta + gradient += self.alpha-0.5 gradient += self._Tt - gradient = gradient.real - return gradient + return gradient.real @property + @memo def curvature(self): - curvature = InversionEnabler(CriticalPowerCurvature( - theta=self._theta.weight(-1), - T=self.T), inverter=self._inverter) - return curvature + curv = CriticalPowerCurvature(theta=self._theta, T=self.T) + return InversionEnabler(curv, inverter=self._inverter, + preconditioner=curv.preconditioner) # ---Added properties and methods--- @@ -121,15 +123,11 @@ class CriticalPowerEnergy(Energy): # self.logger.info("Drawing sample %i" % i) posterior_sample = generate_posterior_sample( self.m, self.D) - projected_sample = power_analyze( - posterior_sample, - binbounds=self.position.domain[0].binbounds) - w += (projected_sample) * self.rho + w += self.P(abs(posterior_sample) ** 2) + w /= float(self.samples) else: - w = self.m.power_analyze( - binbounds=self.position.domain[0].binbounds) - w *= self.rho + w = self.P(abs(self.m)**2) self._w = w return self._w @@ -138,11 +136,6 @@ class CriticalPowerEnergy(Energy): def _theta(self): return exp(-self.position) * (self.q + self.w / 2.) - @property - @memo - def _rho_prime(self): - return self.alpha - 1. + self.rho / 2. - @property @memo def _Tt(self):