Skip to content
Snippets Groups Projects
Commit 5d255035 authored by Philipp Arras's avatar Philipp Arras
Browse files

Last minute changes

parent aab304b5
Branches master nifty5_to_nifty7
No related tags found
1 merge request!2Draft: Nifty5 to nifty7
Pipeline #108272 passed
......@@ -15,8 +15,6 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import nifty7 as ift
from helpers import generate_wf_data, plot_WF
......@@ -27,38 +25,32 @@ position_space = ift.RGSpace(256)
prior_spectrum = lambda k: 1/(10. + k**2.5)
data, ground_truth = generate_wf_data(position_space, prior_spectrum)
R = ift.GeometryRemover(position_space)
data_space = R.target
data = ift.makeField(data_space, data)
ground_truth = ift.makeField(position_space, ground_truth)
data_space = ift.UnstructuredDomain(data.shape)
data = ift.makeField(data_space, data)
plot_WF('1_data', ground_truth, data)
N = ift.ScalingOperator(data_space, 0.1)
h_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space, target=position_space)
harmonic_space = position_space.get_default_codomain()
HT = ift.HartleyOperator(harmonic_space, target=position_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=prior_spectrum)
R = ift.GeometryRemover(position_space) @ HT
S_h = ift.create_power_operator(harmonic_space, prior_spectrum)
S = HT @ S_h @ HT.adjoint
D_inv = S.inverse + R.adjoint @ N.inverse @ R
j = (R.adjoint @ N.inverse)(data)
# Fields and data
N = ift.ScalingOperator(data_space, 0.1)
j = R.adjoint(N.inverse(data))
IC = ift.GradientNormController(iteration_limit=100, tol_abs_gradnorm=1e-7)
D = ift.InversionEnabler(D_inv.inverse, IC, approximation=S)
# WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy
# helper methods.
curv = ift.WienerFilterCurvature(R, N, Sh, iteration_controller=IC,
iteration_controller_sampling=IC)
D = curv.inverse
m = D(j)
plot_WF('1_result', ground_truth, data, HT(m))
plot_WF('1_result', ground_truth, data, m)
N = ift.SamplingDtypeSetter(N, np.float64)
# The following S adds information necessary for sampling to the above
# defined S.
S = ift.SandwichOperator.make(HT.adjoint,
ift.SamplingDtypeSetter(S_h, np.float64))
Dinv = ift.WienerFilterCurvature(R, N, S, IC, IC, None, None)
N_samples = 10
samples = [Dinv.draw_sample(from_inverse=True) + m for i in range(N_samples)]
plot_WF('1_result_with_uncertainty', ground_truth, data, m=m, samples=samples)
samples = [HT(D.draw_sample() + m) for i in range(N_samples)]
plot_WF('1_result_with_uncertainty', ground_truth, data, m=HT(m), samples=samples)
......@@ -42,7 +42,7 @@ def generate_bernoulli_data(signal_response):
def generate_wf_data(domain, spectrum):
harmonic_space = domain.get_default_codomain()
HT = ift.HartleyOperator(harmonic_space, target=domain)
HT = ift.HarmonicTransformOperator(harmonic_space, target=domain)
N = ift.ScalingOperator(domain, 0.1)
S_k = ift.create_power_operator(harmonic_space, spectrum)
s = HT(S_k.draw_sample_with_dtype(np.float64)).val
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment