Commit 713a86b0 authored by Jakob Knollmueller's avatar Jakob Knollmueller

now working with default NIFTy_4

parent 34357ff5
...@@ -43,11 +43,11 @@ def generate_mock_data(): ...@@ -43,11 +43,11 @@ def generate_mock_data():
if __name__ == '__main__': if __name__ == '__main__':
np.random.seed(42) np.random.seed(42)
data = generate_mock_data() data = generate_mock_data()
myStarblade = sb.build_starblade(data=data, alpha=1.5, newton_steps=200, cg_steps=5, q=1e-3) myStarblade = sb.build_starblade(data=data, alpha=1.5, q=1e-3)
for i in range(3): # not fully converged after 3 steps. for i in range(5): # not fully converged after 5 steps.
myStarblade = sb.starblade_iteration(myStarblade, samples=5, cg_steps=10, myStarblade = sb.starblade_iteration(myStarblade, samples=5, cg_steps=100,
newton_steps=100, sampling_steps=1000) newton_steps=50)
### PLOTS ### ### PLOTS ###
vmin = data.min() vmin = data.min()
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik # Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, log, exp, DiagonalOperator,\ from nifty4 import Energy, Field, log, exp, DiagonalOperator,\
create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler, SamplingEnabler create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler#, SamplingEnabler
from nifty4.library.nonlinearities import PositiveTanh, Tanh from nifty4.library.nonlinearities import PositiveTanh, Tanh
...@@ -59,7 +59,6 @@ class StarbladeEnergy(Energy): ...@@ -59,7 +59,6 @@ class StarbladeEnergy(Energy):
self.parameters = parameters self.parameters = parameters
self.inverter = parameters['inverter'] self.inverter = parameters['inverter']
self.sampling_inverter = parameters['sampling_inverter']
self.d = parameters['data'] self.d = parameters['data']
self.FFT = parameters['FFT'] self.FFT = parameters['FFT']
self.power_spectrum = parameters['power_spectrum'] self.power_spectrum = parameters['power_spectrum']
...@@ -116,5 +115,5 @@ class StarbladeEnergy(Energy): ...@@ -116,5 +115,5 @@ class StarbladeEnergy(Energy):
R = ScalingOperator(1., point.domain) R = ScalingOperator(1., point.domain)
S_p = DiagonalOperator(self.s_p) S_p = DiagonalOperator(self.s_p)
my_S_inv = SandwichOperator.make(self.FFT.adjoint.inverse.adjoint*S_p, self.correlation.inverse) my_S_inv = SandwichOperator.make(self.FFT.adjoint.inverse.adjoint*S_p, self.correlation.inverse)
curv = InversionEnabler(SamplingEnabler(my_S_inv+N_inv, O_x, self.sampling_inverter), self.inverter) curv = InversionEnabler(my_S_inv+N_inv, self.inverter,O_x)
return curv return curv
...@@ -46,8 +46,6 @@ class StarbladeKL(Energy): ...@@ -46,8 +46,6 @@ class StarbladeKL(Energy):
the minimization strategy to use for operator inversion the minimization strategy to use for operator inversion
newton_iterations : newton_iterations :
Number of Newton optimization steps. Number of Newton optimization steps.
sampling_inverter :
Inverter which is used to generate samples.
""" """
def __init__(self, position, samples, parameters): def __init__(self, position, samples, parameters):
......
...@@ -23,8 +23,8 @@ from .starblade_energy import StarbladeEnergy ...@@ -23,8 +23,8 @@ from .starblade_energy import StarbladeEnergy
from .starblade_kl import StarbladeKL from .starblade_kl import StarbladeKL
def build_starblade(data, alpha=1.5, q=1e-10, cg_steps=10, newton_steps=100, def build_starblade(data, alpha=1.5, q=1e-10, cg_steps=100, newton_steps=100,
manual_power_spectrum=None, sampling_steps=100): manual_power_spectrum=None):
""" Setting up the StarbladeEnergy for the given data and parameters """ Setting up the StarbladeEnergy for the given data and parameters
Parameters Parameters
---------- ----------
...@@ -64,19 +64,16 @@ def build_starblade(data, alpha=1.5, q=1e-10, cg_steps=10, newton_steps=100, ...@@ -64,19 +64,16 @@ def build_starblade(data, alpha=1.5, q=1e-10, cg_steps=10, newton_steps=100,
ICI = ift.GradientNormController(iteration_limit=cg_steps, ICI = ift.GradientNormController(iteration_limit=cg_steps,
tol_abs_gradnorm=1e-5) tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(controller=ICI) inverter = ift.ConjugateGradient(controller=ICI)
IC_samples = ift.GradientNormController(iteration_limit=sampling_steps,
tol_abs_gradnorm=1e-5)
sampling_inverter = ift.ConjugateGradient(controller=IC_samples)
parameters = dict(data=data, power_spectrum=initial_spectrum, parameters = dict(data=data, power_spectrum=initial_spectrum,
alpha=alpha, q=q, alpha=alpha, q=q,
inverter=inverter, FFT=FFT, inverter=inverter, FFT=FFT,
newton_iterations=newton_steps, sampling_inverter=sampling_inverter, update_power=update_power) newton_iterations=newton_steps, update_power=update_power)
Starblade = StarbladeEnergy(position=initial_x, parameters=parameters) Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
return Starblade return Starblade
def starblade_iteration(starblade, samples=5, cg_steps=10, newton_steps=100, sampling_steps=1000): def starblade_iteration(starblade, samples=5, cg_steps=100, newton_steps=100):
""" Performing one Newton minimization step """ Performing one Newton minimization step
Parameters Parameters
---------- ----------
...@@ -86,11 +83,9 @@ def starblade_iteration(starblade, samples=5, cg_steps=10, newton_steps=100, sam ...@@ -86,11 +83,9 @@ def starblade_iteration(starblade, samples=5, cg_steps=10, newton_steps=100, sam
Number of samples drawn in order to estimate the KL. If zero the MAP is calculated (default: 5). Number of samples drawn in order to estimate the KL. If zero the MAP is calculated (default: 5).
cg_steps : int cg_steps : int
Maximum number of conjugate gradient iterations for Maximum number of conjugate gradient iterations for
numerical operator inversion for each Newton step (default: 10). numerical operator inversion (default: 100).
newton_steps : int newton_steps : int
Number of consecutive Newton steps within one algorithmic step.(default: 100) Number of consecutive Newton steps within one algorithmic step.(default: 100)
sampling_steps : int
Number of conjugate gradient steps for each sample (default: 1000).
""" """
controller = ift.GradientNormController(name="Newton", controller = ift.GradientNormController(name="Newton",
tol_abs_gradnorm=1e-8, tol_abs_gradnorm=1e-8,
...@@ -98,17 +93,13 @@ def starblade_iteration(starblade, samples=5, cg_steps=10, newton_steps=100, sam ...@@ -98,17 +93,13 @@ def starblade_iteration(starblade, samples=5, cg_steps=10, newton_steps=100, sam
minimizer = ift.RelaxedNewton(controller=controller) minimizer = ift.RelaxedNewton(controller=controller)
ICI = ift.GradientNormController(iteration_limit=cg_steps, tol_abs_gradnorm=1e-5) ICI = ift.GradientNormController(iteration_limit=cg_steps, tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(controller=ICI) inverter = ift.ConjugateGradient(controller=ICI)
IC_samples = ift.GradientNormController(iteration_limit=sampling_steps,
tol_abs_gradnorm=1e-5)
sampling_inverter = ift.ConjugateGradient(controller=IC_samples)
para = starblade.parameters para = starblade.parameters
para['inverter'] = inverter para['inverter'] = inverter
para['sampling_inverter'] = sampling_inverter
energy = StarbladeEnergy(starblade.position, parameters=para) energy = StarbladeEnergy(starblade.position, parameters=para)
sample_list = [] sample_list = []
for i in range(samples): for i in range(samples):
sample = energy.curvature.inverse.draw_sample() sample = energy.curvature.draw_sample(from_inverse=True)
sample_list.append(sample) sample_list.append(sample)
if len(sample_list) > 0: if len(sample_list) > 0:
energy = StarbladeKL(starblade.position, samples=sample_list, parameters=energy.parameters) energy = StarbladeKL(starblade.position, samples=sample_list, parameters=energy.parameters)
...@@ -118,7 +109,7 @@ def starblade_iteration(starblade, samples=5, cg_steps=10, newton_steps=100, sam ...@@ -118,7 +109,7 @@ def starblade_iteration(starblade, samples=5, cg_steps=10, newton_steps=100, sam
energy = StarbladeEnergy(energy.position, parameters=energy.parameters) energy = StarbladeEnergy(energy.position, parameters=energy.parameters)
sample_list = [] sample_list = []
for i in range(samples): for i in range(samples):
sample = energy.curvature.inverse.draw_sample() sample = energy.curvature.draw_sample(from_inverse=True)
sample_list.append(sample) sample_list.append(sample)
if len(sample_list) == 0: if len(sample_list) == 0:
sample_list.append(energy.position) sample_list.append(energy.position)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment