Commit d3f0bf6c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent 9b072104
Pipeline #24381 passed with stage
in 6 minutes and 8 seconds
...@@ -93,17 +93,17 @@ ...@@ -93,17 +93,17 @@
### Implement Propagator ### Implement Propagator
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def PropagatorOperator(R, N, Sh): def Curvature(R, N, Sh):
IC = ift.GradientNormController(iteration_limit=50000, IC = ift.GradientNormController(iteration_limit=50000,
tol_abs_gradnorm=0.1) tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC) inverter = ift.ConjugateGradient(controller=IC)
D = (R.adjoint*N.inverse*R + Sh.inverse).inverse # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy
# MR FIXME: we can/should provide a preconditioner here as well! # helper methods.
return ift.InversionEnabler(D, inverter) return ift.library.WienerFilterCurvature(R,N,Sh,inverter)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
...@@ -156,11 +156,12 @@ ...@@ -156,11 +156,12 @@
n = ift.Field.from_random(domain=s_space, random_type='normal', n = ift.Field.from_random(domain=s_space, random_type='normal',
std=noise_amplitude, mean=0) std=noise_amplitude, mean=0)
d = noiseless_data + n d = noiseless_data + n
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
D = PropagatorOperator(R=R, N=N, Sh=Sh) curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Run Wiener Filter ### Run Wiener Filter
...@@ -281,11 +282,12 @@ ...@@ -281,11 +282,12 @@
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
D = PropagatorOperator(R=R, N=N, Sh=Sh) curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
m = D(j) m = D(j)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
...@@ -295,16 +297,10 @@ ...@@ -295,16 +297,10 @@
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
sc = ift.probing.utils.StatCalculator() sc = ift.probing.utils.StatCalculator()
IC = ift.GradientNormController(iteration_limit=50000,
tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC)
curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter)
for i in range(200): for i in range(200):
print i print i
sc.add(HT(curv.generate_posterior_sample())) sc.add(HT(curv.generate_posterior_sample()))
m_var = sc.var m_var = sc.var
...@@ -384,12 +380,10 @@ ...@@ -384,12 +380,10 @@
p_space = ift.PowerSpace(h_space) p_space = ift.PowerSpace(h_space)
# Operators # Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec) Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(sigma2,s_space) N = ift.ScalingOperator(sigma2,s_space)
R = ift.FFTSmoothingOperator(s_space, sigma=.01)
#D = PropagatorOperator(R=R, N=N, Sh=Sh)
# Fields and data # Fields and data
sh = ift.power_synthesize(ift.PS_field(p_space,pow_spec),real_signal=True) sh = ift.power_synthesize(ift.PS_field(p_space,pow_spec),real_signal=True)
n = ift.Field.from_random(domain=s_space, random_type='normal', n = ift.Field.from_random(domain=s_space, random_type='normal',
std=np.sqrt(sigma2), mean=0) std=np.sqrt(sigma2), mean=0)
...@@ -402,11 +396,12 @@ ...@@ -402,11 +396,12 @@
mask = ift.Field(s_space, val=1) mask = ift.Field(s_space, val=1)
mask.val[l:h,l:h] = 0 mask.val[l:h,l:h] = 0
R = ift.DiagonalOperator(mask)*HT R = ift.DiagonalOperator(mask)*HT
n.val[l:h, l:h] = 0 n.val[l:h, l:h] = 0
D = PropagatorOperator(R=R, N=N, Sh=Sh) curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
d = R(sh) + n d = R(sh) + n
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
# Run Wiener filter # Run Wiener filter
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment