From 5d2550354257c66b771f10b660db5140c003f07a Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Tue, 24 Aug 2021 13:47:54 +0200
Subject: [PATCH] Last minute changes

---
 1_wiener_filter_solution.py | 44 +++++++++++++++----------------------
 helpers/generate_data.py    |  2 +-
 2 files changed, 19 insertions(+), 27 deletions(-)

diff --git a/1_wiener_filter_solution.py b/1_wiener_filter_solution.py
index ea2c535..21e32b3 100644
--- a/1_wiener_filter_solution.py
+++ b/1_wiener_filter_solution.py
@@ -15,8 +15,6 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
-import numpy as np
-
 import nifty7 as ift
 from helpers import generate_wf_data, plot_WF
 
@@ -27,38 +25,32 @@ position_space = ift.RGSpace(256)
 prior_spectrum = lambda k: 1/(10. + k**2.5)
 data, ground_truth = generate_wf_data(position_space, prior_spectrum)
 
-R = ift.GeometryRemover(position_space)
-data_space = R.target
-data = ift.makeField(data_space, data)
-
 ground_truth = ift.makeField(position_space, ground_truth)
+data_space = ift.UnstructuredDomain(data.shape)
+data = ift.makeField(data_space, data)
 plot_WF('1_data', ground_truth, data)
 
-N = ift.ScalingOperator(data_space, 0.1)
+h_space = position_space.get_default_codomain()
+HT = ift.HarmonicTransformOperator(h_space, target=position_space)
 
-harmonic_space = position_space.get_default_codomain()
-HT = ift.HartleyOperator(harmonic_space, target=position_space)
+# Operators
+Sh = ift.create_power_operator(h_space, power_spectrum=prior_spectrum)
+R = ift.GeometryRemover(position_space) @ HT
 
-S_h = ift.create_power_operator(harmonic_space, prior_spectrum)
-S = HT @ S_h @ HT.adjoint
-
-D_inv = S.inverse + R.adjoint @ N.inverse @ R
-j = (R.adjoint @ N.inverse)(data)
+# Fields and data
+N = ift.ScalingOperator(data_space, 0.1)
+j = R.adjoint(N.inverse(data))
 
 IC = ift.GradientNormController(iteration_limit=100, tol_abs_gradnorm=1e-7)
-D = ift.InversionEnabler(D_inv.inverse, IC, approximation=S)
+# WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy
+# helper methods.
+curv = ift.WienerFilterCurvature(R, N, Sh, iteration_controller=IC,
+                                 iteration_controller_sampling=IC)
+D = curv.inverse
 
 m = D(j)
+plot_WF('1_result', ground_truth, data, HT(m))
 
-plot_WF('1_result', ground_truth, data, m)
-
-N = ift.SamplingDtypeSetter(N, np.float64)
-# The following S adds information necessary for sampling to the above
-# defined S.
-S = ift.SandwichOperator.make(HT.adjoint,
-                              ift.SamplingDtypeSetter(S_h, np.float64))
-Dinv = ift.WienerFilterCurvature(R, N, S, IC, IC, None, None)
 N_samples = 10
-samples = [Dinv.draw_sample(from_inverse=True) + m for i in range(N_samples)]
-
-plot_WF('1_result_with_uncertainty', ground_truth, data, m=m, samples=samples)
+samples = [HT(D.draw_sample() + m) for i in range(N_samples)]
+plot_WF('1_result_with_uncertainty', ground_truth, data, m=HT(m), samples=samples)
diff --git a/helpers/generate_data.py b/helpers/generate_data.py
index aa95586..8849de9 100644
--- a/helpers/generate_data.py
+++ b/helpers/generate_data.py
@@ -42,7 +42,7 @@ def generate_bernoulli_data(signal_response):
 
 def generate_wf_data(domain, spectrum):
     harmonic_space = domain.get_default_codomain()
-    HT = ift.HartleyOperator(harmonic_space, target=domain)
+    HT = ift.HarmonicTransformOperator(harmonic_space, target=domain)
     N = ift.ScalingOperator(domain, 0.1)
     S_k = ift.create_power_operator(harmonic_space, spectrum)
     s = HT(S_k.draw_sample_with_dtype(np.float64)).val
-- 
GitLab