Commit d3ef2730 authored by Jakob Knollmueller's avatar Jakob Knollmueller

finally all errors gone? back to KL, samples now reasonable, performs excellent

parent 8bdaa000
This diff is collapsed.
......@@ -16,9 +16,10 @@
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, log, exp, DiagonalOperator, create_power_operator
from nifty4 import Energy, Field, log, exp, DiagonalOperator,\
create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler
from nifty4.library import WienerFilterCurvature
from nifty4.library.nonlinearities import PositiveTanh
from nifty4.library.nonlinearities import PositiveTanh, Tanh
class StarbladeEnergy(Energy):
......@@ -64,12 +65,16 @@ class StarbladeEnergy(Energy):
self.update_power = parameters['update_power']
self.newton_iterations = parameters['newton_iterations']
pos_tanh = PositiveTanh()
self.S = self.FFT * self.correlation * self.FFT.adjoint
tanh = Tanh()
self.S = SandwichOperator.make(self.FFT.adjoint, self.correlation)
# self.S = self.FFT * self.correlation * self.FFT.adjoint
self.a = pos_tanh(self.position)
self.a_p = pos_tanh.derivative(self.position)
self.a_pp = -tanh(position)*tanh.derivative(self.position)
self.u = log(self.d * self.a)
self.u_p = self.a_p/self.a
self.u_a = -log(self.a)
self.u_ap = - self.a_p/self.a
one_m_a = 1 - self.a
self.s = log(self.d * one_m_a)
self.s_p = - self.a_p / one_m_a
......@@ -80,24 +85,40 @@ class StarbladeEnergy(Energy):
@property
def value(self):
point = 0
diffuse = 0
det = 0
diffuse = 0.5 * self.s.vdot(self.S.inverse(self.s))
point = (self.alpha-1).vdot(self.u) + self.q.vdot(exp(-self.u))
det = self.s.integrate()
det = - self.s.sum()
det += - self.u_a.sum()
det += -log(self.a_p).sum()
det += 0.5 / self.var_x * self.position.vdot(self.position)
return diffuse + point + det
return diffuse + point + det
@property
def gradient(self):
point = 0
diffuse = 0
det = 0
diffuse = self.S.inverse(self.s) * self.s_p
point = (self.alpha - 1) * self.u_p - self.q * exp(-self.u) * self.u_p
det = self.position / self.var_x
det += self.s_p
return diffuse + point + det
det += - self.s_p
det += - self.u_ap
det += -1./self.a_p * self.a_pp
return +diffuse + point +det
@property
def curvature(self):
point = self.q * exp(-self.u) * self.u_p ** 2
R = self.FFT.inverse * self.s_p
N = self.correlation
S = DiagonalOperator(1/(point + 1/self.var_x))
return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)
# R = self.FFT.inverse * self.s_p
# N = self.correlation
N_inv = DiagonalOperator(point + 1/self.var_x )#+ 2*self.a_p))
R = ScalingOperator(1., point.domain)
S_p = DiagonalOperator(self.s_p)
my_S_inv = SandwichOperator.make(self.FFT.adjoint.inverse.adjoint * S_p, self.correlation.inverse)
curv = InversionEnabler(N_inv + my_S_inv, self.inverter)
return curv
# return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)
......@@ -16,7 +16,7 @@
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, DiagonalOperator, InversionEnabler
from nifty4 import Energy, Field, DiagonalOperator, InversionEnabler, full
from starblade_energy import StarbladeEnergy
class StarbladeKL(Energy):
......@@ -68,7 +68,7 @@ class StarbladeKL(Energy):
@property
def gradient(self):
gradient = Field.zeros(self.position.domain)
gradient = full(self.position.domain,0.)
for energy in self.energy_list:
gradient += energy.gradient
gradient /= len(self.energy_list)
......@@ -76,7 +76,7 @@ class StarbladeKL(Energy):
@property
def curvature(self):
curvature = DiagonalOperator(Field.zeros(self.position.domain))
curvature = DiagonalOperator(full(self.position.domain, 0.))
for energy in self.energy_list:
curvature += energy.curvature
curvature *= Field(self.position.domain,val=1./len(self.energy_list))
......
......@@ -43,16 +43,17 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio
If it is not specified, the algorithm will try to infer it via critical filtering.
"""
s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
s_space = ift.RGSpace(data.shape)#, distances=len(data.shape) * [1])
h_space = s_space.get_default_codomain()
data = ift.Field(s_space,val=data)
FFT = ift.FFTOperator(h_space, target=s_space)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = True)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
p_space = ift.PowerSpace(h_space, binbounds=binbounds)
if manual_power_spectrum is None:
initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)),
initial_spectrum = ift.power_analyze(FFT.adjoint(ift.log(data)),
binbounds=p_space.binbounds)
initial_spectrum /= (p_space.k_lengths+1.)**4
initial_spectrum /= 100*(p_space.k_lengths+1.)**4
# initial_spectrum = ift.Field(p_space,val=1e-3)
update_power = True
......@@ -89,19 +90,24 @@ def starblade_iteration(starblade, samples=3):
tol_abs_gradnorm=1e-8,
iteration_limit=starblade.newton_iterations)
minimizer = ift.RelaxedNewton(controller=controller)
# if len(sample_list)>0:
# energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
# else:
# minimizer = ift.VL_BFGS(controller=controller)
energy = starblade
sample_list = []
for i in range(samples):
sample = energy.curvature.inverse.draw_sample()
sample_list.append(sample)
if len(sample_list)>0:
energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
else:
energy = starblade
energy, convergence = minimizer(energy)
energy = StarbladeEnergy(energy.position, parameters=energy.parameters)
sample_list = []
for i in range(samples):
sample = energy.curvature.inverse.draw_sample()
sample_list.append(sample)
if len(sample_list) == 0:
sample_list.append(energy.position)
# energy = StarbladeKL(energy.position, samples=sample_list, parameters=energy.parameters)
new_position = energy.position
new_parameters = energy.parameters
......@@ -171,11 +177,11 @@ def update_power(energy):
if isinstance(energy, StarbladeKL):
power = 0.
for en in energy.energy_list:
power = ift.power_analyze(energy.parameters['FFT'].inverse_times(en.s),
power += ift.power_analyze(energy.parameters['FFT'].inverse(en.s),
binbounds=en.parameters['power_spectrum'].domain[0].binbounds)
# power /= len(energy.energy_list)
power /= len(energy.energy_list)
else:
power = ift.power_analyze(energy.FFT.inverse_times(energy.s),
power = ift.power_analyze(energy.FFT.inverse(energy.s),
binbounds=energy.parameters['power_spectrum'].domain[0].binbounds)
return power
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment