diff --git a/demo/filaments_correlated_infer_a.py b/demo/filaments_correlated_infer_a.py
index 35efe720869cc0689ad3e3c2ced3c9426d9c66f5..bbfa859f43825713a221b81a90dd1faf2bfc8d42 100644
--- a/demo/filaments_correlated_infer_a.py
+++ b/demo/filaments_correlated_infer_a.py
@@ -51,15 +51,15 @@ def get_filament_prior(domain):
     cfmaker_phi0.set_amplitude_total_offset(0., (1.0, 0.1))
     Correlated_field_phi0 = cfmaker_phi0.finalize()
     # linear field phi0
-    #Phi0 = Correlated_field_phi0
+    Phi0 = Correlated_field_phi0
     # minus lognormal field phi0
-    Phi0_ = Correlated_field_phi0
-    Phi0 = -1 * ift.exp(Phi0_)
+    #Phi0_ = Correlated_field_phi0
+    #Phi0 = -1 * ift.exp(Phi0_)
 
     ### 2.Calculate initial wave function operator Psi_0
 
     hbar = 5 * 10 ** -3
-    a = 0.05 # time scale
+    #a = 0.05 # time scale
     Half_operator_ = ift.ScalingOperator(C0.target, 0.5)
     Hbar_operator = ift.ScalingOperator(Phi0.target, -1j / hbar)
     Complexifier = ift.Realizer(Phi0.target).adjoint
@@ -79,22 +79,20 @@ def get_filament_prior(domain):
     # infer time a
     # A_operator in harmonic space
     A_operator_scalar = ift.LognormalTransform(0.05, 0.025, 'time_a', 0)
+    # expander(ContractionOperator.adjoint)
     ContractionOp_adj = ift.ContractionOperator(harmonic_space, None).adjoint
     A_operator = ContractionOp_adj(A_operator_scalar)
 
-    Hbar_half_operator = ift.ScalingOperator(A_operator.target, -1j * hbar * 0.5)
-
-    Complexifier_a = ift.Realizer(A_operator.target).adjoint
-    Hbar_half_operator_ = Hbar_half_operator @ Complexifier_a
-    Hbar_half_operator__ = ift.exp(Hbar_half_operator_)
-    Hbar_half_operator_exp = Hbar_half_operator__(A_operator)
 
     # length of k vector for each pixel
     k_values = harmonic_space.get_k_length_array()
-    k_values_squared_exp = ift.exp(k_values ** 2)
-    K_values_squared_exp = ift.makeOp(k_values_squared_exp)
+    K_values_squared = ift.makeOp(k_values ** 2)
+    Hbar_half_operator_ = ift.ScalingOperator(K_values_squared.target, -1j * hbar * 0.5)
+    Hbar_half_operator__ = Hbar_half_operator_(K_values_squared)
+    Complexifier_k = ift.Realizer(K_values_squared.target).adjoint
+    Hbar_half_operator = Hbar_half_operator__ @ Complexifier_k
 
-    Propagator_h = K_values_squared_exp(Hbar_half_operator_exp)
+    Propagator_h = ift.exp(Hbar_half_operator(A_operator))
 
     Psi_1h = Propagator_h * Psi_0h
     Psi_1 = ifft(Psi_1h)
@@ -146,7 +144,7 @@ def main():
 
     xfov = yfov = "250as"
     #npix = 4000
-    npix = 3000
+    npix = 1000
     #npix = 30
 
     fov = np.array([rve.str2rad(xfov), rve.str2rad(yfov)])
@@ -188,12 +186,12 @@ def main():
     # ) ** (-2)
 
 
-    mini = ift.NewtonCG(ift.GradientNormController(name="newton", iteration_limit=15))
+    mini = ift.NewtonCG(ift.GradientNormController(name="newton", iteration_limit=5))
     # Fit point source only
     state = rve.MinimizationState(0.1 * ift.from_random(sky.domain), [])
     lh = rve.ImagingLikelihood(obs, sky)
     ham = ift.StandardHamiltonian(
-        lh, ift.AbsDeltaEnergyController(0.5, iteration_limit=500)
+        lh, ift.AbsDeltaEnergyController(0.5, iteration_limit=100)
     )
     cst = filaments.domain.keys()
     state = rve.simple_minimize(
@@ -203,7 +201,7 @@ def main():
 
     # Fit diffuse + points
     for ii in range(20):
-        state = rve.simple_minimize(ham, state.mean, 2, mini)
+        state = rve.simple_minimize(ham, state.mean, 0, mini)
         if ii >= 19:
             state.save(f"filaments{ii}")