Skip to content
Snippets Groups Projects
Commit 432752bd authored by Philipp Arras's avatar Philipp Arras
Browse files

Add noise inference

parent 28d4befd
No related branches found
No related tags found
No related merge requests found
Pipeline #60399 passed
......@@ -52,15 +52,30 @@ signal_response = R @ signal
N = ift.ScalingOperator(0.1, data_space)
data, ground_truth = generate_gaussian_data(signal_response, N)
# Corrupt data
if True:
arr = data.to_global_data_rw()
arr[1001] *= 1000
arr[48] *= 1000
data = ift.from_global_data(data.domain, arr)
plot_prior_samples_2d(5, signal, R, signal, A, 'gauss', N=N)
likelihood = ift.GaussianEnergy(
mean=data, inverse_covariance=N.inverse)(signal_response)
residual = ift.Adder(data, neg=True) @ signal_response
# Enable noise inference
if True:
scale_q = 0.5*N.inverse
N_amp = scale_q @ ift.InverseGammaOperator(data_space, 0.5, 1) @ scale_q
sqrtnop = N_amp.ducktape('eta')**-0.5
residual = sqrtnop*residual
likelihood = ift.GaussianEnergy(domain=data_space) @ residual
# Solve inference problem
ic_sampling = ift.GradientNormController(iteration_limit=100)
ic_newton = ift.GradInfNormController(
name='Newton', tol=1e-6, iteration_limit=30)
ic_newton = ift.GradInfNormController(name='Newton',
tol=1e-6,
iteration_limit=30)
minimizer = ift.NewtonCG(ic_newton)
H = ift.StandardHamiltonian(likelihood, ic_sampling)
......@@ -68,12 +83,12 @@ initial_mean = ift.MultiField.full(H.domain, 0.)
mean = initial_mean
# Draw five samples and minimize KL, iterate 10 times
for _ in range(10):
for _ in range(3):
KL = ift.MetricGaussianKL(mean, H, 5)
KL, convergence = minimizer(KL)
mean = KL.position
# Draw posterior samples and plot
N_posterior_samples = 30
N_posterior_samples = 10
KL = ift.MetricGaussianKL(mean, H, N_posterior_samples)
plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, 'criticalfilter')
......@@ -162,7 +162,7 @@ def plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name):
sc = ift.StatCalculator()
sky_samples, pspec_samples = [], []
for sample in KL.samples:
tmp = signal(sample + KL.position)
tmp = signal.force(sample + KL.position)
sc.add(tmp)
sky_samples.append(tmp)
pspec_samples.append(A.force(sample)**2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment