Skip to content
Snippets Groups Projects
Commit a06601c5 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

adjust to nightly branch

parent df3b073b
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -96,7 +96,7 @@ if __name__ == "__main__":
IC2 = ift.GradientNormController(verbose=True, iteration_limit=100,
tol_abs_gradnorm=0.1)
minimizer2 = ift.VL_BFGS(IC2, max_history_length=20)
IC3 = ift.GradientNormController(verbose=True, iteration_limit=100,
IC3 = ift.GradientNormController(verbose=True, iteration_limit=1000,
tol_abs_gradnorm=0.1)
minimizer3 = ift.SteepestDescent(IC3)
......
......@@ -21,7 +21,7 @@ class CriticalPowerCurvature(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, theta, T):
self.theta = DiagonalOperator(theta.weight(1))
self.theta = DiagonalOperator(theta)
self.T = T
super(CriticalPowerCurvature, self).__init__()
......
from ...energies.energy import Energy
from ...operators.smoothness_operator import SmoothnessOperator
from ...operators.power_projection_operator import PowerProjectionOperator
from ...operators.inversion_enabler import InversionEnabler
from . import CriticalPowerCurvature
from ...memoization import memo
from ...sugar import generate_posterior_sample, power_analyze
from ... import Field, exp
from ...sugar import generate_posterior_sample
class CriticalPowerEnergy(Energy):
......@@ -66,7 +66,8 @@ class CriticalPowerEnergy(Energy):
self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=smoothness_prior,
logarithmic=logarithmic)
self.rho = self.position.domain[0].rho
self.P = PowerProjectionOperator(domain=self.m.domain,
power_space=self.position.domain[0])
self._w = w
self._inverter = inverter
......@@ -80,26 +81,27 @@ class CriticalPowerEnergy(Energy):
inverter=self._inverter)
@property
@memo
def value(self):
energy = self._theta.sum()
energy += self.position.weight(-1).vdot(self._rho_prime)
energy = Field.ones_like(self.position).vdot(self._theta)
energy += self.position.vdot(self.alpha-0.5)
energy += 0.5 * self.position.vdot(self._Tt)
return energy.real
@property
@memo
def gradient(self):
gradient = -self._theta.weight(-1)
gradient += self._rho_prime.weight(-1)
gradient = -self._theta
gradient += self.alpha-0.5
gradient += self._Tt
gradient = gradient.real
return gradient
return gradient.real
@property
@memo
def curvature(self):
curvature = InversionEnabler(CriticalPowerCurvature(
theta=self._theta.weight(-1),
T=self.T), inverter=self._inverter)
return curvature
curv = CriticalPowerCurvature(theta=self._theta, T=self.T)
return InversionEnabler(curv, inverter=self._inverter,
preconditioner=curv.preconditioner)
# ---Added properties and methods---
......@@ -121,15 +123,11 @@ class CriticalPowerEnergy(Energy):
# self.logger.info("Drawing sample %i" % i)
posterior_sample = generate_posterior_sample(
self.m, self.D)
projected_sample = power_analyze(
posterior_sample,
binbounds=self.position.domain[0].binbounds)
w += (projected_sample) * self.rho
w += self.P(abs(posterior_sample) ** 2)
w /= float(self.samples)
else:
w = self.m.power_analyze(
binbounds=self.position.domain[0].binbounds)
w *= self.rho
w = self.P(abs(self.m)**2)
self._w = w
return self._w
......@@ -138,11 +136,6 @@ class CriticalPowerEnergy(Energy):
def _theta(self):
return exp(-self.position) * (self.q + self.w / 2.)
@property
@memo
def _rho_prime(self):
return self.alpha - 1. + self.rho / 2.
@property
@memo
def _Tt(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment