Skip to content
Snippets Groups Projects
Commit e0f0f20a authored by Philipp Arras's avatar Philipp Arras
Browse files

Improvements demo 3

parent 09456c35
Branches
Tags
No related merge requests found
...@@ -16,130 +16,133 @@ ...@@ -16,130 +16,133 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
############################################################ ############################################################
# Non-linear tomography # Non-linear tomography
# data is line of sight (LOS) field # The data is integrated lines of sight
# random lines (set mode=0), radial lines (mode=1) # Random lines (set mode=0), radial lines (mode=1)
############################################################# #############################################################
mode = 0
import nifty5 as ift
import numpy as np import numpy as np
import nifty5 as ift
def get_random_LOS(n_los): def random_los(n_los):
# Setting up LOS
starts = list(np.random.uniform(0, 1, (n_los, 2)).T) starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
if mode == 0: ends = list(0.5 + 0*np.random.uniform(0, 1, (n_los, 2)).T)
ends = list(np.random.uniform(0, 1, (n_los, 2)).T) return starts, ends
else:
ends = list(0.5+0*np.random.uniform(0, 1, (n_los, 2)).T)
def radial_los(n_los):
starts = list(np.random.uniform(0, 1, (n_los, 2)).T)
ends = list(np.random.uniform(0, 1, (n_los, 2)).T)
return starts, ends return starts, ends
if __name__ == '__main__': if __name__ == '__main__':
# FIXME description of the tutorial
np.random.seed(420) np.random.seed(420)
np.seterr(all='raise')
# Choose between random line-of-sight response (mode=1) and radial lines
# of sight (mode=2)
mode = 1
position_space = ift.RGSpace([128, 128]) position_space = ift.RGSpace([128, 128])
# Setting up an amplitude model for the field # Set up an amplitude model for the field
A = ift.AmplitudeModel(position_space, 64, 3, 0.4, -5., 0.5, 0.4, 0.3) # The parameters mean:
# made choices: # 64 spectral bins
# 64 spectral bins
# #
# Spectral smoothness (affects Gaussian process part) # Spectral smoothness (affects Gaussian process part)
# 3 = relatively high variance of spectral curbvature # 3 = relatively high variance of spectral curbvature
# 0.4 = quefrency mode below which cepstrum flattens # 0.4 = quefrency mode below which cepstrum flattens
# #
# power law part of spectrum: # Power-law part of spectrum:
# -5= preferred power-law slope # -5 = preferred power-law slope
# 0.5 = low variance of power-law slope # 0.5 = low variance of power-law slope
# # 0.4 = y-intercept mean
# Gaussian process part of log-spectrum # 0.3 = relatively high y-intercept variance
# 0.4 = y-intercept mean of additional power A = ift.AmplitudeModel(position_space, 64, 3, 0.4, -5., 0.5, 0.4, 0.3)
# 0.3 = y-intercept variance of additional power
# Build the model for a correlated signal
# Building the model for a correlated signal
harmonic_space = position_space.get_default_codomain() harmonic_space = position_space.get_default_codomain()
ht = ift.HarmonicTransformOperator(harmonic_space, position_space) ht = ift.HarmonicTransformOperator(harmonic_space, position_space)
power_space = A.target[0] power_space = A.target[0]
power_distributor = ift.PowerDistributor(harmonic_space, power_space) power_distributor = ift.PowerDistributor(harmonic_space, power_space)
vol = harmonic_space.scalar_dvol vol = ift.ScalingOperator(harmonic_space.scalar_dvol**(-0.5),
vol = ift.ScalingOperator(vol**(-0.5), harmonic_space) harmonic_space)
correlated_field = ht( correlated_field = ht(
vol(power_distributor(A))*ift.ducktape(harmonic_space, None, 'xi')) vol(power_distributor(A))*ift.ducktape(harmonic_space, None, 'xi'))
# alternatively to the block above one can do: # Alternatively, one can use:
#correlated_field = ift.CorrelatedField(position_space, A) # correlated_field = ift.CorrelatedField(position_space, A)
# apply some nonlinearity # Apply a nonlinearity
signal = ift.positive_tanh(correlated_field) signal = ift.positive_tanh(correlated_field)
# Building the Line of Sight response # Build the line-of-sight response and define signal response
LOS_starts, LOS_ends = get_random_LOS(100) LOS_starts, LOS_ends = random_los(100) if mode == 1 else radial_los(100)
R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends) R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends)
# build signal response model and model likelihood
signal_response = R(signal) signal_response = R(signal)
# specify noise
# Specify noise
data_space = R.target data_space = R.target
noise = .001 noise = .001
N = ift.ScalingOperator(noise, data_space) N = ift.ScalingOperator(noise, data_space)
# generate mock signal and data # Generate mock signal and data
MOCK_POSITION = ift.from_random('normal', signal_response.domain) mock_position = ift.from_random('normal', signal_response.domain)
data = signal_response(MOCK_POSITION) + N.draw_sample() data = signal_response(mock_position) + N.draw_sample()
# set up model likelihood # Minimization parameters
likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response)
# set up minimization and inversion schemes
ic_sampling = ift.GradientNormController(iteration_limit=100) ic_sampling = ift.GradientNormController(iteration_limit=100)
ic_newton = ift.GradInfNormController( ic_newton = ift.GradInfNormController(
name='Newton', tol=1e-7, iteration_limit=35) name='Newton', tol=1e-7, iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton) minimizer = ift.NewtonCG(ic_newton)
# build model Hamiltonian # Set up model likelihood and information Hamiltonian
likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response)
H = ift.Hamiltonian(likelihood, ic_sampling) H = ift.Hamiltonian(likelihood, ic_sampling)
INITIAL_POSITION = ift.MultiField.full(H.domain, 0.) initial_position = ift.MultiField.full(H.domain, 0.)
position = INITIAL_POSITION position = initial_position
plot = ift.Plot() plot = ift.Plot()
plot.add(signal(MOCK_POSITION), title='Ground Truth') plot.add(signal(mock_position), title='Ground Truth')
plot.add(R.adjoint_times(data), title='Data') plot.add(R.adjoint_times(data), title='Data')
plot.add([A.force(MOCK_POSITION)], title='Power Spectrum') plot.add([A.force(mock_position)], title='Power Spectrum')
plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png") plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png")
# number of samples used to estimate the KL # number of samples used to estimate the KL
N_samples = 20 N_samples = 20
# five intermediate steps in minimization to illustrate progress # Draw new samples to approximate the KL five times
for i in range(5): for i in range(5):
# set up KL
KL = ift.KL_Energy(position, H, N_samples) KL = ift.KL_Energy(position, H, N_samples)
# minimize KL until iteration limit reached # Minimize KL
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
# read out position # Update position
position = KL.position position = KL.position
# plot momentariy field
# Plot current reconstruction
plot = ift.Plot() plot = ift.Plot()
plot.add(signal(KL.position), title="reconstruction") plot.add(signal(KL.position), title="reconstruction")
plot.add([A.force(KL.position), A.force(MOCK_POSITION)], title="power") plot.add([A.force(KL.position), A.force(mock_position)], title="power")
plot.output(ny=1, ysize=6, xsize=16, name="loop-{:02}.png".format(i)) plot.output(ny=1, ysize=6, xsize=16, name="loop-{:02}.png".format(i))
# final plot # Draw posterior samples
KL = ift.KL_Energy(position, H, N_samples) KL = ift.KL_Energy(position, H, N_samples)
plot = ift.Plot()
# do statistics
sc = ift.StatCalculator() sc = ift.StatCalculator()
for sample in KL.samples: for sample in KL.samples:
sc.add(signal(sample + KL.position)) sc.add(signal(sample + KL.position))
# Plotting
plot = ift.Plot()
plot.add(sc.mean, title="Posterior Mean") plot.add(sc.mean, title="Posterior Mean")
plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation") plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")
powers = [A.force(s + KL.position) for s in KL.samples] powers = [A.force(s + KL.position) for s in KL.samples]
plot.add( plot.add(
powers + [A.force(KL.position), A.force(MOCK_POSITION)], powers + [A.force(KL.position),
A.force(mock_position)],
title="Sampled Posterior Power Spectrum", title="Sampled Posterior Power Spectrum",
linewidth=[1.]*len(powers) + [3., 3.]) linewidth=[1.]*len(powers) + [3., 3.])
plot.output(ny=1, nx=3, xsize=24, ysize=6, name="results.png") plot.output(ny=1, nx=3, xsize=24, ysize=6, name="results.png")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment