Skip to content
Snippets Groups Projects
Commit a06601c5 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

adjust to nightly branch

parent df3b073b
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -96,7 +96,7 @@ if __name__ == "__main__": ...@@ -96,7 +96,7 @@ if __name__ == "__main__":
IC2 = ift.GradientNormController(verbose=True, iteration_limit=100, IC2 = ift.GradientNormController(verbose=True, iteration_limit=100,
tol_abs_gradnorm=0.1) tol_abs_gradnorm=0.1)
minimizer2 = ift.VL_BFGS(IC2, max_history_length=20) minimizer2 = ift.VL_BFGS(IC2, max_history_length=20)
IC3 = ift.GradientNormController(verbose=True, iteration_limit=100, IC3 = ift.GradientNormController(verbose=True, iteration_limit=1000,
tol_abs_gradnorm=0.1) tol_abs_gradnorm=0.1)
minimizer3 = ift.SteepestDescent(IC3) minimizer3 = ift.SteepestDescent(IC3)
......
...@@ -21,7 +21,7 @@ class CriticalPowerCurvature(EndomorphicOperator): ...@@ -21,7 +21,7 @@ class CriticalPowerCurvature(EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, theta, T): def __init__(self, theta, T):
self.theta = DiagonalOperator(theta.weight(1)) self.theta = DiagonalOperator(theta)
self.T = T self.T = T
super(CriticalPowerCurvature, self).__init__() super(CriticalPowerCurvature, self).__init__()
......
from ...energies.energy import Energy from ...energies.energy import Energy
from ...operators.smoothness_operator import SmoothnessOperator from ...operators.smoothness_operator import SmoothnessOperator
from ...operators.power_projection_operator import PowerProjectionOperator
from ...operators.inversion_enabler import InversionEnabler from ...operators.inversion_enabler import InversionEnabler
from . import CriticalPowerCurvature from . import CriticalPowerCurvature
from ...memoization import memo from ...memoization import memo
from ...sugar import generate_posterior_sample, power_analyze
from ... import Field, exp from ... import Field, exp
from ...sugar import generate_posterior_sample
class CriticalPowerEnergy(Energy): class CriticalPowerEnergy(Energy):
...@@ -66,7 +66,8 @@ class CriticalPowerEnergy(Energy): ...@@ -66,7 +66,8 @@ class CriticalPowerEnergy(Energy):
self.T = SmoothnessOperator(domain=self.position.domain[0], self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=smoothness_prior, strength=smoothness_prior,
logarithmic=logarithmic) logarithmic=logarithmic)
self.rho = self.position.domain[0].rho self.P = PowerProjectionOperator(domain=self.m.domain,
power_space=self.position.domain[0])
self._w = w self._w = w
self._inverter = inverter self._inverter = inverter
...@@ -80,26 +81,27 @@ class CriticalPowerEnergy(Energy): ...@@ -80,26 +81,27 @@ class CriticalPowerEnergy(Energy):
inverter=self._inverter) inverter=self._inverter)
@property @property
@memo
def value(self): def value(self):
energy = self._theta.sum() energy = Field.ones_like(self.position).vdot(self._theta)
energy += self.position.weight(-1).vdot(self._rho_prime) energy += self.position.vdot(self.alpha-0.5)
energy += 0.5 * self.position.vdot(self._Tt) energy += 0.5 * self.position.vdot(self._Tt)
return energy.real return energy.real
@property @property
@memo
def gradient(self): def gradient(self):
gradient = -self._theta.weight(-1) gradient = -self._theta
gradient += self._rho_prime.weight(-1) gradient += self.alpha-0.5
gradient += self._Tt gradient += self._Tt
gradient = gradient.real return gradient.real
return gradient
@property @property
@memo
def curvature(self): def curvature(self):
curvature = InversionEnabler(CriticalPowerCurvature( curv = CriticalPowerCurvature(theta=self._theta, T=self.T)
theta=self._theta.weight(-1), return InversionEnabler(curv, inverter=self._inverter,
T=self.T), inverter=self._inverter) preconditioner=curv.preconditioner)
return curvature
# ---Added properties and methods--- # ---Added properties and methods---
...@@ -121,15 +123,11 @@ class CriticalPowerEnergy(Energy): ...@@ -121,15 +123,11 @@ class CriticalPowerEnergy(Energy):
# self.logger.info("Drawing sample %i" % i) # self.logger.info("Drawing sample %i" % i)
posterior_sample = generate_posterior_sample( posterior_sample = generate_posterior_sample(
self.m, self.D) self.m, self.D)
projected_sample = power_analyze( w += self.P(abs(posterior_sample) ** 2)
posterior_sample,
binbounds=self.position.domain[0].binbounds)
w += (projected_sample) * self.rho
w /= float(self.samples) w /= float(self.samples)
else: else:
w = self.m.power_analyze( w = self.P(abs(self.m)**2)
binbounds=self.position.domain[0].binbounds)
w *= self.rho
self._w = w self._w = w
return self._w return self._w
...@@ -138,11 +136,6 @@ class CriticalPowerEnergy(Energy): ...@@ -138,11 +136,6 @@ class CriticalPowerEnergy(Energy):
def _theta(self): def _theta(self):
return exp(-self.position) * (self.q + self.w / 2.) return exp(-self.position) * (self.q + self.w / 2.)
@property
@memo
def _rho_prime(self):
return self.alpha - 1. + self.rho / 2.
@property @property
@memo @memo
def _Tt(self): def _Tt(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment