From 1c5f501440055fa62a5fb8b3601cf14306be025f Mon Sep 17 00:00:00 2001 From: pfrank <philipp@mpa-garching.mpg.de> Date: Thu, 21 Jan 2021 14:41:41 +0100 Subject: [PATCH] new version --- config.py | 205 ++++++++++++++++++++++---- movie_start.py | 98 +++++++++++++ reconstruction.py | 241 +++++++++++++++++++++--------- src/__init__.py | 5 +- src/closure.py | 296 ++++++++++++++++++------------------- src/constants.py | 11 +- src/data.py | 108 ++++++++++---- src/response.py | 11 +- src/sugar.py | 234 +++++++++++++++++++++++++----- test.py | 363 +++++++++++++++++++++++++++++++--------------- 10 files changed, 1136 insertions(+), 436 deletions(-) create mode 100644 movie_start.py diff --git a/config.py b/config.py index fb45872..dbb5f8e 100644 --- a/config.py +++ b/config.py @@ -12,44 +12,188 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. # # Copyright(C) 2019-2020 Max-Planck-Society -# Author: Philipp Arras, Philipp Frank, Philipp Haim, Reimar Leike, -# Jakob Knollmueller + +import numpy as np import nifty6 as ift import src as vlbi +# MPI information +comm, nranks, rank, master = ift.utilities.get_MPI_params() -ift.fft.set_nthreads(1) -npix = 128 -dt = 6 # hours -fov = 256*vlbi.MUAS2RAD # RAD +np.seterr(all='raise', under='warn') +if master: + print('Numpy mode:', np.geterr()) -doms = ift.RGSpace(2*(npix,), 2*(fov/npix,)) -npixt = 7*24//dt +oversampling_factor = 2 +nthreads = 1 +ift.fft.set_nthreads(nthreads) +eps = 1e-7 +radius_fitting_init = {'m87': [5, 0, 1, 15], # cx, cy, step, dist + 'crescent': [10, -25, 1, 15], + 'ehtcrescent': [0, 0, 1, 15]} +radius_fitting_dtheta, radius_fitting_sample_r = 1, 0.5*vlbi.MUAS2RAD + +# vlbi.set_SNR_cutoff(0.2) + +npix = 256 +dt = 6 # hours +# npix = 128 +# dt = 24 # hours startt = 0 +npixt = 7*24//dt +fov = 256*vlbi.MUAS2RAD +min_timestamps_per_bin = 12 # how many different time stamps should be at least in a time bin +npix = int(npix) +doms = ift.RGSpace(2*(npix,), 2*(fov/npix,)) + +zm = 0, 0.2, 0.1, '' +spaceparams = { + 'target_subdomain': doms, + 'fluctuations_mean': 1.5, + 'fluctuations_stddev': 1., + 'loglogavgslope_mean': -1.5, + 'loglogavgslope_stddev': 0.5, + 'flexibility_mean': 0.01, + 'flexibility_stddev': 0.001, + 'asperity_mean': 0.0001, + 'asperity_stddev': 0.00001, + 'prefix': 'space' +} +timeparams = { + 'fluctuations_mean': 0.2, + 'fluctuations_stddev': 1., + 'loglogavgslope_mean': -4, + 'loglogavgslope_stddev': 0.5, + 'flexibility_mean': 0.01, + 'flexibility_stddev': 0.001, + 'asperity_mean': 0.001, + 'asperity_stddev': 0.0001, + 'prefix': 'time' +} + +zm_single = 0, 1, 0.3, '' +spaceparams_single = { + 'target_subdomain': doms, + 'fluctuations_mean': 0.75, + 'fluctuations_stddev': 0.3, + 'flexibility_mean': 0.001, + 'flexibility_stddev': 0.0001, + 'asperity_mean': 0.0001, + 'asperity_stddev': 0.00001, + 'loglogavgslope_mean': -1.5, + 'loglogavgslope_stddev': 0.5, + 'prefix': 'space' +} + +cfm_single = ift.CorrelatedFieldMaker.make(*zm_single) +cfm_single.add_fluctuations(**spaceparams_single) + +logsky = cfm_single.finalize(prior_info=0) + + +sky_single = vlbi.normalize(doms) @ logsky.exp() +def smoothed_sky_model(sigma): + cfm_single = ift.CorrelatedFieldMaker.make(*zm_single) + cfm_single.add_fluctuations(**spaceparams_single) + + logsky = cfm_single.finalize(prior_info=0) + + SMO = ift.HarmonicSmoothingOperator(logsky.target,sigma*vlbi.MUAS2RAD) + + sky_single = vlbi.normalize(doms) @ SMO @ (SMO @ logsky).exp() + return sky_single + +def excitation_smoother(sigma,sigma_old, xi): + codomain = xi.domain[0] + 
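    # Reweighting step (as suggested by the function name): when the smoothing
    # scale changes between optimization stages, multiplying the latent
    # excitations xi by old_kernel/new_kernel in harmonic space keeps the
    # smoothed field approximately unchanged, so a reconstruction can be
    # warm-started from a position obtained with a different sigma.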
kernel = codomain.get_k_length_array() + old_sig = codomain.get_fft_smoothing_kernel_function(sigma_old*vlbi.MUAS2RAD) + new_sig = codomain.get_fft_smoothing_kernel_function(sigma*vlbi.MUAS2RAD) + weights = old_sig(kernel).clip(1e-200,1e300)/new_sig(kernel).clip(1e-200,1e300) + + return weights * xi + +def smoothed_movie_model(sigma_x,sigma_t): + + domt = ift.RGSpace(npixt, dt) + dom = ift.makeDomain((domt, doms)) + domt_zeropadded = ift.RGSpace(int(oversampling_factor * npixt), dt) + + cfm = ift.CorrelatedFieldMaker.make(*zm) + cfm.add_fluctuations(domt_zeropadded, **timeparams) + cfm.add_fluctuations(**spaceparams) + logsky = cfm.finalize(prior_info=0) + + FA_lo = ift.FieldAdapter(logsky.target, 'lo') + FA_hi = ift.FieldAdapter(logsky.target, 'hi') + xi_hi = ift.FieldAdapter(logsky.domain['xi'], 'xi_hi') + id_hi = ift.FieldAdapter(logsky.domain['xi'], 'xi').adjoint @ xi_hi + xi_lo = ift.FieldAdapter(logsky.domain['xi'], 'xi_lo') + id_lo = ift.FieldAdapter(logsky.domain['xi'], 'xi').adjoint @ xi_lo + + expected_difference = 1e-2 + logsky_1 = (FA_hi.adjoint @ logsky).partial_insert(id_hi) + logsky_2 = (FA_lo.adjoint + @ logsky).partial_insert(id_lo).scale(expected_difference) + + ls_lo = (FA_hi + FA_lo) + ls_hi = (FA_hi - FA_lo) + t_SMO = ift.HarmonicSmoothingOperator(ls_lo.target,sigma_t,space=0) + x_SMO = ift.HarmonicSmoothingOperator(ls_lo.target,sigma_x*vlbi.MUAS2RAD,space=1) + + SMO = t_SMO @ x_SMO + logsky_lo_smo = SMO @ ls_lo + logsky_hi_smo = SMO @ ls_hi + + sky_lo_smo = SMO @ logsky_lo_smo.exp() + sky_hi_smo = SMO @ logsky_hi_smo.exp() + + sky_lo = FA_lo.adjoint @ sky_lo_smo + sky_hi = FA_hi.adjoint @ sky_hi_smo + + sky_mf = (sky_hi + sky_lo) @ (logsky_1 + logsky_2) + smooth_movie = vlbi.normalize(logsky_mf.target) @ sky_mf + return smooth_movie + +def movie_excitation_smoother(sigma_x, sigma_t, sigma_x_old, sigma_t_old, xi_hi,xi_lo): + + def get_kernel(sigma_x, sigma_t,domain): + + def get_single_kernel(domain, sigma): + k = domain.get_k_length_array() + kernel = domain.get_fft_smoothing_kernel_function(sigma) + sig = kernel(k) + return sig + + sig_t = get_single_kernel(domain[0],sigma_t).val + sig_x = get_single_kernel(domain[1],sigma_x*vlbi.MUAS2RAD).val + sig = np.outer(sig_t, sig_x) + sig = sig.reshape(sig_t.shape + sig_x.shape ) + sig = ift.makeField(domain, sig) + return sig + sig_old = get_kernel(sigma_x_old, sigma_t_old, xi_hi.domain) + sig_new = get_kernel(sigma_x, sigma_t, xi_hi.domain) + + weights = sig_old.clip(1e-200,1e300)/sig_new.clip(1e-200,1e300) + + return xi_hi * weights, xi_lo * weights + + + + + + + +pspec_single = cfm_single.amplitude + +domt = ift.RGSpace(npixt, dt) +dom = ift.makeDomain((domt, doms)) +domt_zeropadded = ift.RGSpace(int(oversampling_factor*npixt), dt) -cfm = ift.CorrelatedFieldMaker.make(0.2, 1e-1, '') -cfm.add_fluctuations(ift.RGSpace(int(2*npixt), dt), - fluctuations_mean=.2, - fluctuations_stddev=1., - flexibility_mean=0.1, - flexibility_stddev=0.001, - asperity_mean=0.01, - asperity_stddev=0.001, - loglogavgslope_mean=-4, - loglogavgslope_stddev=0.5, - prefix='time') -cfm.add_fluctuations(doms, - fluctuations_mean=cfm.moment_slice_to_average(1.5), - fluctuations_stddev=1., - flexibility_mean=0.3, - flexibility_stddev=0.001, - asperity_mean=0.01, - asperity_stddev=0.001, - loglogavgslope_mean=-1.5, - loglogavgslope_stddev=0.5, - prefix='space') +cfm = ift.CorrelatedFieldMaker.make(*zm) +cfm.add_fluctuations(domt_zeropadded, **timeparams) +cfm.add_fluctuations(**spaceparams) logsky = cfm.finalize(prior_info=0) FA_lo = 
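# The two frequency bands share a single correlated field evaluated at two
# latent excitations: 'xi_hi' carries the common log-sky and 'xi_lo' a
# difference damped by expected_difference, so the 'hi' and 'lo' skies are
# exponentials of mean-minus-difference and mean-plus-difference.
# Note: smoothed_movie_model above returns vlbi.normalize(logsky_mf.target)
# @ sky_mf, but no logsky_mf is defined inside that function; the name only
# resolves to the module-level logsky_mf defined further down, and
# sky_mf.target is presumably what is meant there.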
ift.FieldAdapter(logsky.target, 'lo') @@ -59,8 +203,9 @@ id_hi = ift.FieldAdapter(logsky.domain['xi'], 'xi').adjoint @ xi_hi xi_lo = ift.FieldAdapter(logsky.domain['xi'], 'xi_lo') id_lo = ift.FieldAdapter(logsky.domain['xi'], 'xi').adjoint @ xi_lo +expected_difference = 1e-2 logsky_1 = (FA_hi.adjoint @ logsky).partial_insert(id_hi) -logsky_2 = (FA_lo.adjoint @ logsky).partial_insert(id_lo).scale(0.01) +logsky_2 = (FA_lo.adjoint @ logsky).partial_insert(id_lo).scale(expected_difference) logsky_lo = FA_lo.adjoint @ (FA_hi + FA_lo) logsky_hi = FA_hi.adjoint @ (FA_hi - FA_lo) logsky_mf = (logsky_hi + logsky_lo) @ (logsky_1 + logsky_2) diff --git a/movie_start.py b/movie_start.py new file mode 100644 index 0000000..d3e662c --- /dev/null +++ b/movie_start.py @@ -0,0 +1,98 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2020 Max-Planck-Society +# Author: Philipp Arras + +import matplotlib.pyplot as plt +import numpy as np + +import nifty6 as ift +import src as vlbi +from config import comm, rank, master +from config import oversampling_factor as ofac +from config import sky_movie_mf +from config import sky_movie_mf as full_sky + + + +def main(): + + # Prior plots + if master: + with ift.random.Context(31): + p = ift.Plot() + n = 5 + for _ in range(5): # FIXME: should this be 'range(n)'? 
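            # The literal 5 duplicates n defined just above; looping over
            # range(n) would keep the sample count and the plot grid in sync.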
+ pos = ift.from_random(sky_movie_mf.domain, 'normal') + ss = sky_movie_mf(pos)['hi'] + mi, ma = 0, np.max(ss.val) + for ii in range(0, ss.shape[0], 4): + extr = ift.DomainTupleFieldInserter(ss.domain, 0, (ii,)).adjoint + p.add(extr(ss), zmin=mi, zmax=ma) + p.output(name=f'prior_samples.png', ny=n, xsize=28, ysize=9) + + print('Start inf check') + for ii in range(20): + pp = ift.from_random(sky_movie_mf.domain, 'normal') #* 2.5 + sky_movie_mf(pp) + + sky = full_sky + fld = vlbi.gaussian_profile(sky.target['hi'][1], 30 *vlbi.MUAS2RAD) + fld = ift.ContractionOperator(sky.target['hi'], 0).adjoint(fld).val + multi = ift.makeField(sky.target, {'lo': fld, 'hi': fld}) + output = vlbi.normalize(multi.domain)(multi) + + cov = 1e-3*max([vv.max() for vv in output.val.values()])**2 + dtype = list(set([ff.dtype for ff in output.values()])) + if len(dtype) != 1: + raise ValueError('Only MultiFields with one dtype supported.') + dtype = dtype[0] + invcov = ift.ScalingOperator(output.domain, cov).inverse + d = output + invcov.draw_sample_with_dtype(dtype, from_inverse=True) + lh = ift.GaussianEnergy(d, invcov) @ sky + H = ift.StandardHamiltonian( + lh, ic_samp=ift.AbsDeltaEnergyController(0.5, iteration_limit=100, convergence_level=2, name=f'CG(task {rank})')) + pos = 0.1*ift.from_random(sky.domain, 'normal') + cst = ('spaceasperity', 'spaceflexibility', 'spacefluctuations', 'spaceloglogavgslope', 'spacespectrum', 'timeasperity', 'timeflexibility', 'timefluctuations', 'timeloglogavgslope', 'timespectrum') + minimizer = ift.NewtonCG(ift.GradientNormController(iteration_limit=10, name='findpos' if master else None)) + n = 2 + for ii in range(n): + if master: + ift.logger.info(f'Start iteration {ii+1}/{n}') + kl = ift.MetricGaussianKL(pos, H, 2, comm=comm, mirror_samples=True, constants=cst, point_estimates=cst) + kl, _ = minimizer(kl) + pos = kl.position + + if master: + movie = vlbi.freq_avg(vlbi.strip_oversampling(sky(pos), ofac)).val + ma = np.max(movie) + fig, axs = plt.subplots(5, 6, figsize=(12, 12)) + axs = axs.ravel() + for ii in range(movie.shape[0]): + axx = axs[ii] + axx.imshow(movie[ii].T, vmin=0, vmax=ma) + axx.set_title(f'Frame {ii}') + fig.savefig(f'movie_start.png') + plt.close() + + if master: + vlbi.save_hdf5('initial.h5', pos) + with open("initial_random_state.txt", "wb") as f: + f.write(ift.random.getState()) + + +if __name__ == '__main__': + # Change seed for the whole reconstruction here + # ift.random.push_sseq_from_seed(42) + main() diff --git a/reconstruction.py b/reconstruction.py index 50268d1..733831b 100644 --- a/reconstruction.py +++ b/reconstruction.py @@ -12,89 +12,198 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. 
# # Copyright(C) 2019-2020 Max-Planck-Society -# Author: Philipp Arras, Philipp Frank, Philipp Haim, Reimar Leike, -# Jakob Knollmueller +# Author: Philipp Arras, Philipp Haim, Reimar Leike, Jakob Knollmueller -import os -os.environ["MKL_NUM_THREADS"] = "1" -os.environ["OMP_NUM_THREADS"] = "1" import sys +import gc +from mpi4py import MPI from functools import reduce from operator import add - -import numpy as np +from time import time import nifty6 as ift import src as vlbi -from config import doms, dt, npixt +from config import comm, nranks, rank, master +from config import doms, dt, eps, min_timestamps_per_bin, npixt, nthreads from config import sky_movie_mf as sky from config import startt -if __name__ == '__main__': - np.seterr(all='raise') - np.random.seed(42) +def stat_plotting(pos, KL): + if master: + sc = ift.StatCalculator() + sc_mean = ift.StatCalculator() + sc_spectral = ift.StatCalculator() + skysamples = {'hi': [], 'lo': []} + + # CAUTION: the loop and the if-clause cannot be interchanged! + for ss in KL.samples: + if master: + samp = sky(pos+ss) + skysamples['hi'].append(samp['hi']) + skysamples['lo'].append(samp['lo']) + sc.add(samp) + sc_mean.add(0.5*(samp['hi']+samp['lo']).val) + sc_spectral.add(2 * (samp['lo']-samp['hi']).val / (samp['lo']+samp['hi']).val ) + +def optimization_heuristic(ii, lh_full, lh_amp, lh_ph, ind_full, ind_amp, ind_ph): + cut = 2 if dt == 24 else 6 + N_samples = 10 * (1 + ii // 8) + N_iterations = 4 * (4 + ii // 4) if ii<50 else 20 + + eps = 0.1 + clvl = 3 if ii < 20 else 2 + + ic_newton = ift.AbsDeltaEnergyController(eps, iteration_limit=N_iterations, + name=f'it_{ii}' if master else None, + convergence_level=clvl) + + if ii < 50: + minimizer = ift.VL_BFGS(ic_newton) + else: + minimizer = ift.NewtonCG(ic_newton) + + + + if ii < 30: + lh = [] + active_inds = [] + if ii % 2 == 0 or ii < 10: + for key in ind_full.keys(): + if int(key[0])<cut and key[1] == '_': #FIXME: dirty fix to select + # the data of the first two days + active_inds.append(key) + lh.append(ind_full[key]) + elif ii % 4 == 1: + for key in ind_amp.keys(): + if int(key[0])<cut and key[1] == '_': #FIXME + active_inds.append(key) + lh.append(ind_amp[key]) + else: + for key in ind_ph.keys(): + if int(key[0])<cut and key[1] == '_': #FIXME + active_inds.append(key) + lh.append(ind_ph[key]) - pre = sys.argv[1] - if pre not in ['crescent', 'disk', 'blobs']: + conv = vlbi.DomainTuple2MultiField(sky.target, active_inds) + lh = reduce(add, lh) @ conv + else: + if ii % 2 == 0 or ii > 50: + lh = lh_full + elif ii % 4 == 1: + lh = lh_amp + else: + lh = lh_ph + + return minimizer, N_samples, N_iterations, lh + + +def setup(): + if len(sys.argv) != 3: raise RuntimeError - path = f'data/{pre}' + _, pre_data, fname_input = sys.argv + + pre_output = pre_data + + ndata = {} + lh_full = [] + lh_amp = [] + lh_ph = [] + ind_full = {} + ind_amp = {} + ind_ph = {} + active_inds = [] - # Build data model - lh, active_inds = [], [] for freq in vlbi.data.FREQS: - rawd = vlbi.combined_data(path, [freq], identify_short_baselines=['APAA', 'SMJC']) - for ii, dd in enumerate(vlbi.time_binning(rawd, dt, startt, npixt*dt)): + rawd = vlbi.combined_data(f'data/{pre_data}', [freq], min_timestamps_per_bin) if master else None + if nranks > 1: + rawd = comm.bcast(rawd) + ndata[freq] = [] + args = {'tmin': startt, 'tmax': npixt*dt, 'delta_t': dt} + + for ii, dd in enumerate(vlbi.time_binning(rawd, **args)): if len(dd) == 0: + ndata[freq] += [0, ] continue ind = str(ii) + "_" + freq active_inds.append(ind) - v2ph, icovph = 
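# A compact, runnable sketch of the schedule encoded in optimization_heuristic
# above (constants copied from the function): the number of KL samples grows
# with the global iteration index, the Newton iteration budget is capped late
# in the run, and the minimizer switches from VL-BFGS to Newton-CG.
def schedule(ii):
    N_samples = 10 * (1 + ii // 8)
    N_iterations = 4 * (4 + ii // 4) if ii < 50 else 20
    minimizer = 'VL_BFGS' if ii < 50 else 'NewtonCG'
    return N_samples, N_iterations, minimizer

assert schedule(0) == (10, 16, 'VL_BFGS')
assert schedule(55) == (70, 20, 'NewtonCG')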
vlbi.Visibilities2ClosurePhases(dd) - v2ampl, icovampl = vlbi.Visibilities2ClosureAmplitudes(dd) - R = vlbi.RadioResponse(doms, dd['uv'], 1e-7).ducktape(ind) - vis = ift.makeField(R.target, dd['vis']) - ll_ph = ift.GaussianEnergy(v2ph(vis), icovph) @ v2ph - ll_ampl = ift.GaussianEnergy(v2ampl(vis), icovampl) @ v2ampl - lh.append((ll_ph + ll_ampl) @ R) - lh = reduce(add, lh) @ vlbi.DomainTuple2MultiField(sky.target, active_inds) - pb = vlbi.gaussian_profile(sky.target['hi'][1], 50*vlbi.MUAS2RAD) - pb = ift.ContractionOperator(sky.target['hi'], 0).adjoint(pb) - pb = ift.makeOp(ift.MultiField.from_dict({'hi': pb, 'lo': pb})) - - # Plot prior samples - for ii in range(4): - pp = ift.from_random('normal', sky.domain) - vlbi.save_state(sky, pp, pre, f'prior{ii}', []) - - # Minimization - pos = 0.1*ift.from_random('normal', sky.domain) - vlbi.save_state(sky, pos, pre, 'initial', []) - for ii in range(30): - if ii < 10: - N_samples = 1 - N_iterations = 15 - elif ii < 20: - N_samples = 2 - N_iterations = 15 - elif ii < 25: - N_samples = 3 - N_iterations = 15 - elif ii < 28: - N_samples = 10 - N_iterations = 20 - else: - N_samples = 20 - N_iterations = 30 - print(f'Iter: {ii}, N_samples: {N_samples}, N_iter: {N_iterations}') - cstkeys = set(pos.domain.keys()) - set(['xi_lo', 'xi_hi']) - cst = cstkeys if ii < 20 else [] - H = ift.StandardHamiltonian(lh @ pb @ sky if ii < 15 else lh @ sky, - ift.GradientNormController(iteration_limit=40)) - KL = ift.MetricGaussianKL(pos, H, N_samples, mirror_samples=True, - lh_sampling_dtype=np.complex128, - point_estimates=cst, constants=cst) - dct = {'deltaE': 0.1, 'iteration_limit': N_iterations, - 'name': f'it_{ii}', 'convergence_level': 2} - minimizer = ift.NewtonCG(ift.AbsDeltaEnergyController(**dct)) + + vis2closph, evalsph, _ = vlbi.Visibilities2ClosurePhases(dd) + vis2closampl, evalsampl = vlbi.Visibilities2ClosureAmplitudes(dd) + nfft = vlbi.RadioResponse(doms, dd['uv'], eps, nthreads) + vis = ift.makeField(nfft.target, dd['vis']) + + + ndata[freq] += [vis2closph.target.size + vis2closampl.target.size,] + + lhph = ift.GaussianEnergy(mean=vis2closph(vis)) @ vis2closph + lhamp = ift.GaussianEnergy(mean=vis2closampl(vis)) @ vis2closampl + + llh_full = reduce(add, [lhamp, lhph]) @ nfft.ducktape(ind) + lh_full.append(llh_full) + llh_amp = reduce(add, [lhamp]) @ nfft.ducktape(ind) + lh_amp.append(llh_amp) + llh_ph = reduce(add, [lhph]) @ nfft.ducktape(ind) + lh_ph.append(llh_ph) + + ind_ph[ind] = llh_ph + ind_amp[ind] = llh_amp + ind_full[ind] = llh_full + + foo = reduce(add, [lhph, lhamp]) @ nfft.ducktape(ind) + ift.extra.check_jacobian_consistency(foo, ift.from_random(foo.domain), + tol=1e-5, ntries=10) + + + conv = vlbi.DomainTuple2MultiField(sky.target, active_inds) + lh_full = reduce(add, lh_full) @ conv + lh_amp = reduce(add, lh_amp) @ conv + lh_ph = reduce(add, lh_ph) @ conv + + pos = vlbi.load_hdf5(fname_input, sky.domain) if master else None + if nranks > 1: + pos = comm.bcast(pos) + + + ic = ift.AbsDeltaEnergyController(0.5, iteration_limit=200, name=f'Sampling(task {rank})', convergence_level=3) + + if master: + t0 = time() + (lh_full @ sky)(pos) + print(f'Likelihood call: {1000*(time()-t0):.0f} ms') + return pos, lh_full, lh_amp, lh_ph, sky, ic, pre_output, ind_full, ind_amp, ind_ph + + +# Encapsulate everything in functions, to avoid as many (unintended) global variables as possible +def main(): + # These following lines are to verify the scan averaging + with open("time_averaging.txt", 'w') as f: + # delete the file such that new lines can be 
appended + f.write("min max avg med\n") + pos, lh_full, lh_amp, lh_ph, sky, ic, pre_output, ind_full, ind_amp, ind_ph = setup() + + for ii in range(60): + gc.collect() + minimizer, N_samples, N_iterations, lh = optimization_heuristic( + ii, lh_full, lh_amp, lh_ph,ind_full, ind_amp, ind_ph) + + if master: + print(f'Iter: {ii}, N_samples: {N_samples}, N_iter: {N_iterations}') + + ll = lh @ sky + H = ift.StandardHamiltonian(ll, ic) + KL = ift.MetricGaussianKL(pos, H, N_samples, comm=comm, mirror_samples=True) + del ll, H + gc.collect() + KL, _ = minimizer(KL) pos = KL.position - vlbi.save_state(sky, pos, pre, ii, KL.samples) + + vlbi.save_state(sky, KL.position, f'{pre_output}_', ii, samples=KL.samples, master=master) + del KL + gc.collect() + + +if __name__ == '__main__': + with open("initial_random_state.txt", "rb") as f: + ift.random.setState(f.read()) + main() diff --git a/src/__init__.py b/src/__init__.py index 44d1339..10c7ac0 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,5 +1,8 @@ -from .closure import Visibilities2ClosureAmplitudes, Visibilities2ClosurePhases +from .closure import Visibilities2ClosureAmplitudes, Visibilities2ClosurePhases, set_SNR_cutoff from .constants import * from .data import combined_data, read_data, read_data_raw, time_binning +#from .fitting import radius_fitting +#from .plotting import * +#from .validation_models import * from .response import RadioResponse from .sugar import * diff --git a/src/closure.py b/src/closure.py index 6f8a7d8..23e15a4 100644 --- a/src/closure.py +++ b/src/closure.py @@ -12,138 +12,164 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. # # Copyright(C) 2019-2020 Max-Planck-Society -# Author: Philipp Arras, Philipp Frank, Philipp Haim, Reimar Leike, -# Jakob Knollmueller from itertools import combinations import numpy as np -from scipy.linalg import eigh from scipy.sparse import coo_matrix from scipy.sparse.linalg import aslinearoperator import nifty6 as ift -from .sugar import baselines, binom, n_baselines +from .sugar import baselines, binom2, binom3, binom4 +from .constants import MAXSHORTBASELINE +_SNR_cutoff = 0 -def Visibilities2ClosurePhases(d): - rows = [] - cols0, cols1, cols2 = [], [], [] - icov_rows, icov_cols, icov_data = [], [], [] - offset_closure, offset_vis = 0, 0 - for tt in np.unique(d['time']): - ind = tt == d['time'] - aa1 = d['ant1'][ind] - aa2 = d['ant2'][ind] - ww = (d['sigma'][ind]/abs(d['vis'][ind]))**2 - aa = set(aa1) | set(aa2) - if len(aa) < 3: - offset_vis += len(aa1) - continue - psi = closure_phase_design_matrix(len(aa)) - ww, missing_inds = insert_missing_baselines_into_weights(aa1, aa2, ww) - clos_desgn_mat = psi @ np.diag(ww) - goodclosures = np.sum(clos_desgn_mat != 0, axis=1) == 3 - clos_desgn_mat = clos_desgn_mat[goodclosures] - if clos_desgn_mat.shape[0] == 0: - offset_vis += len(aa1) - continue - clos_desgn_mat = nonredundant_closure_set(clos_desgn_mat) - ww[ww == 0] = np.nan - clos_desgn_mat = np.round(clos_desgn_mat @ np.diag(1/ww)).astype( - np.int) - for ii in missing_inds[::-1]: - clos_desgn_mat = np.delete(clos_desgn_mat, ii, axis=1) - ww = (d['sigma'][ind]/abs(d['vis'][ind]))**2 - assert clos_desgn_mat.shape[1] == len(aa1) - for i in range(clos_desgn_mat.shape[0]): - row = clos_desgn_mat[i] - rows += [offset_closure + i] - cols0 += [np.where(row == -1)[0][0] + offset_vis] - cols1 += [np.where(row == 1)[0][0] + offset_vis] - cols2 += [np.where(row == 1)[0][1] + offset_vis] - tmp_dat, tmp_rows, tmp_cols = sparse_cov(clos_desgn_mat, ww) - icov_data += tmp_dat 
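# Self-contained check (stand-in numbers, independent of the code around it)
# of the invariance that both the old and the new closure-phase construction
# rely on: the bispectrum V01*V12*conj(V02) is unchanged by arbitrary
# station-based phase gains.
import numpy as np

rng = np.random.default_rng(1)
g = np.exp(1j * rng.uniform(0, 2 * np.pi, 3))   # unit-modulus station gains
v01, v12, v02 = rng.standard_normal(3) + 1j * rng.standard_normal(3)
bispec = lambda a, b, c: a * b * np.conj(c)
corrupted = bispec(v01 * g[0] * np.conj(g[1]),
                   v12 * g[1] * np.conj(g[2]),
                   v02 * g[0] * np.conj(g[2]))
assert np.allclose(corrupted, bispec(v01, v12, v02))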
- icov_rows += [xx + offset_closure for xx in tmp_rows] - icov_cols += [xx + offset_closure for xx in tmp_cols] - offset_closure += clos_desgn_mat.shape[0] - offset_vis += clos_desgn_mat.shape[1] - assert len(d['vis']) == offset_vis - shp = (offset_closure, offset_vis) - term0 = SparseMatrixOperator((np.ones(len(rows)), (rows, cols0)), shp) - term1 = SparseMatrixOperator((np.ones(len(rows)), (rows, cols1)), shp) - term2 = SparseMatrixOperator((np.ones(len(rows)), (rows, cols2)), shp) - invcov = SparseMatrixOperator((icov_data, (icov_rows, icov_cols)), - (shp[0], shp[0])) - invcov = ift.SandwichOperator.make(invcov.adjoint) - return (term0.conjugate()*term1*term2) @ ToUnitCircle(term0.domain), invcov - - -def Visibilities2ClosureAmplitudes(d): - rows, cols, sign = [], [], [] - icov_rows, icov_cols, icov_data = [], [], [] + +def set_SNR_cutoff(x): + global _SNR_cutoff + _SNR_cutoff = x + + +def Visibilities2ClosureMat(d, amplitudes): + rows, cols, values = [], [], [] offset_closure, offset_vis = 0, 0 + if not amplitudes: + rowsp, colsp, valuesp = [], [], [] + offset_closurep, offset_visp = 0, 0 + evals = [] + relweights = (d['sigma']/abs(d['vis']))**2 + timestamps = [] for tt in np.unique(d['time']): ind = tt == d['time'] - aa1 = d['ant1'][ind] - aa2 = d['ant2'][ind] - ww = (d['sigma'][ind]/abs(d['vis'][ind]))**2 - aa = set(aa1) | set(aa2) - if len(aa) < 4: - offset_vis += len(aa1) + aa1, aa2 = d['ant1'][ind], d['ant2'][ind] + nstations = len(set(aa1) | set(aa2)) + tm, missing_inds = insert_missing_baselines_into_weights(aa1, aa2, np.ones(sum(ind))) + tm = np.diag(tm) + if nstations < (4 if amplitudes else 3): + offset_vis += sum(ind) continue - psi = closure_amplitude_design_matrix(len(aa)) - ww, missing_inds = insert_missing_baselines_into_weights(aa1, aa2, ww) - zero_baselines = np.zeros(len(ww), dtype=int) - jj = 0 - for ii in range(len(zero_baselines)): - if ii in missing_inds: + if amplitudes: + psi = closure_amplitude_design_matrix(nstations) + nontrivial = remove_short_diagonals(psi, aa1, aa2, d['uv'][ind]) + psi = psi @ tm + goodclosures = np.logical_and(np.sum(psi != 0, axis=1) == 4, nontrivial) + psi = psi[goodclosures] + if psi.shape[0] == 0: + offset_vis += sum(ind) continue - zero_baselines[ii] = np.linalg.norm(d['uv'][jj]) < 1e7 - jj += 1 - - nontrivial = (np.abs(psi) @ zero_baselines) < 2 - clos_desgn_mat = psi @ np.diag(ww) - goodclosures = np.sum(clos_desgn_mat != 0, axis=1) == 4 - goodclosures = np.logical_and(goodclosures, nontrivial) - clos_desgn_mat = clos_desgn_mat[goodclosures] - if clos_desgn_mat.shape[0] == 0: - offset_vis += len(aa1) - continue - clos_desgn_mat = nonredundant_closure_set(clos_desgn_mat) - ww[ww == 0] = np.nan - clos_desgn_mat = np.round(clos_desgn_mat @ np.diag(1/ww)).astype( - np.int) - for ii in missing_inds[::-1]: - clos_desgn_mat = np.delete(clos_desgn_mat, ii, axis=1) - ww = (d['sigma'][ind]/abs(d['vis'][ind]))**2 - for i in range(clos_desgn_mat.shape[0]): - row = clos_desgn_mat[i] - for c in np.where(row != 0)[0]: + mdecomp = psi = np.delete(psi, missing_inds, axis=1) + else: + psi = closure_phase_design_matrix(nstations) @ tm + #Unused as removed from data directly: + #psi = remove_short_baselines_from_matrix(psi, aa1, aa2, d['uv'][ind]) + psi = psi[np.sum(psi != 0, axis=1) == 3] + if psi.shape[0] == 0: + offset_vis += sum(ind) + continue + psi = to_nonredundant(np.delete(psi, missing_inds, axis=1)) + mdecomp = np.diag(1.j*np.exp(1.j*psi @ np.log(d['vis'][ind]).imag)) @ psi + + U, ivsq = get_decomp(mdecomp, relweights[ind]) + evals += 
list(1./ivsq.diagonal()) + if amplitudes: + psi = ivsq @ U @ psi + else: + proj = ivsq @ U + + for i in range(psi.shape[0]): + row = psi[i] + for c in range(psi.shape[1]): rows += [i + offset_closure] cols += [c + offset_vis] - sign += [row[c]] - tmp_dat, tmp_rows, tmp_cols = sparse_cov(clos_desgn_mat, ww) - icov_data += tmp_dat - icov_rows += [xx + offset_closure for xx in tmp_rows] - icov_cols += [xx + offset_closure for xx in tmp_cols] - offset_closure += clos_desgn_mat.shape[0] - offset_vis += clos_desgn_mat.shape[1] - smo = SparseMatrixOperator((sign, (rows, cols)), - (offset_closure, offset_vis)) - invcov = SparseMatrixOperator((icov_data, (icov_rows, icov_cols)), - (offset_closure,)*2) - invcov = ift.SandwichOperator.make(invcov.adjoint) - inp = ift.ScalingOperator(ift.UnstructuredDomain(d['vis'].shape), 1.) - return smo @ ((inp*inp.conjugate()).real).log().scale(0.5), invcov + values += [row[c]] + offset_closure += psi.shape[0] + offset_vis += psi.shape[1] + if not amplitudes: + for i in range(proj.shape[0]): + row = proj[i] + for c in range(proj.shape[1]): + rowsp += [i + offset_closurep] + colsp += [c + offset_visp] + valuesp += [row[c]] + offset_closurep += proj.shape[0] + offset_visp += proj.shape[1] + timestamps += [tt,]*proj.shape[0] + else: + timestamps += [tt,]*psi.shape[0] + + smo = SparseMatrixOperator((values, (rows, cols)), (offset_closure, offset_vis)) + evals = ift.makeField(smo.target, np.array(evals)) + if amplitudes: + timestamps = ift.makeField(smo.target, np.array(timestamps)) + return smo, evals, timestamps + clos2eig = SparseMatrixOperator((valuesp, (rowsp, colsp)), (offset_closurep, offset_visp)) + timestamps = ift.makeField(clos2eig.target, np.array(timestamps)) + return smo, clos2eig, evals, timestamps + + +def Visibilities2ClosureAmplitudes(d, want_timestamps = False): + vis2closeig, evals, times = Visibilities2ClosureMat(d, True) + inp = ift.ScalingOperator(vis2closeig.domain, 1.) + if want_timestamps: + return vis2closeig @ inp.log().real, evals, times + return vis2closeig @ inp.log().real, evals + + +def Visibilities2ClosurePhases(d, want_timestamps = False): + smo, clos2eig, evals, times = Visibilities2ClosureMat(d, False) + inp = ift.ScalingOperator(smo.domain, 1.) 
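# Plain-numpy sketch (stand-in sizes and weights) of the whitening performed
# by get_decomp, which is called above and defined below: project the closure
# covariance psi @ diag(w) @ psi.T onto its non-degenerate eigendirections and
# rescale them to unit variance.
import numpy as np

rng = np.random.default_rng(0)
psi = rng.standard_normal((4, 6))    # stand-in closure design matrix
w = rng.random(6) + 0.1              # stand-in relative visibility weights
m = psi @ np.diag(w) @ psi.T
ev, eh = np.linalg.eigh(m)
keep = ev > 1e-9                     # drop numerically trivial directions
U = eh.T[keep]
ivsq = np.diag(1.0 / np.sqrt(ev[keep]))
assert np.allclose(ivsq @ U @ m @ U.T @ ivsq, np.eye(U.shape[0]))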
+ op = (smo @ inp.log().imag) + ima = ift.Imaginizer(op.target).adjoint + vis2clos = ima(op).exp() + if want_timestamps: + return clos2eig @ vis2clos, evals, vis2clos, times + return clos2eig @ vis2clos, evals, vis2clos + + +def get_decomp(psi, diag, triv_cutoff=1e-9): + m = psi@np.diag(diag)@psi.conj().T + ev, eh = np.linalg.eigh(m) + inds = ev > triv_cutoff + U = eh.conj().T[inds] + ivsq = np.diag(1./np.sqrt(ev[inds])) + assert np.allclose(np.eye(U.shape[0]), ivsq @ U @ m @ U.conj().T @ ivsq) + assert ivsq.dtype == float + return U, ivsq + + +def remove_short_diagonals(psi, aa1, aa2, uv): + assert aa1.shape == aa2.shape + assert (aa1.size, 2) == uv.shape + assert psi.shape[1] >= aa1.size + + # gets an array with all possible baselines where there is a 0 if the + # baselines is missing and ii+1 if it is the ii-th in the data + missing_baseline_index, _ = insert_missing_baselines_into_weights(aa1, aa2, np.arange(len(aa1))+1) + nontrivial = np.ones(psi.shape[0], dtype=bool) + for ii in range(len(nontrivial)): + bls = np.where(psi[ii] != 0)[0] + bls = missing_baseline_index[bls] + if np.any(0 == bls): + continue + bls = (bls-1).astype(np.int64) + b0 = bls[0] + for b in bls[1:]: + if aa1[b] == aa1[b0] or aa2[b] == aa2[b0]: + if np.linalg.norm(uv[b0]-uv[b]) < MAXSHORTBASELINE: + nontrivial[ii] = False + if aa2[b] == aa1[b0] or aa1[b] == aa2[b0]: + if np.linalg.norm(uv[b0]+uv[b]) < MAXSHORTBASELINE: + nontrivial[ii] = False + return nontrivial def visibility_design_matrix(n): if n < 2: raise ValueError - lst = range(n*(n - 1)//2) - x = np.zeros((n_baselines(n), n), dtype=int) + lst = range(binom2(n)) + x = np.zeros((binom2(n), n), dtype=int) x[(lst, [ii for ii in range(n) for _ in range(ii + 1, n)])] = 1 x[(lst, [jj for ii in range(n) for jj in range(ii + 1, n)])] = -1 return x @@ -152,7 +178,7 @@ def visibility_design_matrix(n): def closure_phase_design_matrix(n): if n < 3: raise ValueError - x = np.zeros((binom(n, 3), n_baselines(n)), dtype=int) + x = np.zeros((binom3(n), binom2(n)), dtype=int) n = n - 1 vdm = visibility_design_matrix(n) nb = vdm.shape[0] @@ -171,7 +197,7 @@ def closure_amplitude_design_matrix(n): block[([0, 0, 1, 1, 2, 2], [0, 5, 0, 5, 2, 3])] = 1 block[([0, 0, 1, 1, 2, 2], [1, 4, 2, 3, 1, 4])] = -1 bl = baselines(range(n)) - x = np.zeros((m*binom(n, 4), n_baselines(n))) + x = np.zeros((m*binom4(n), binom2(n))) for ii, ants in enumerate(combinations(range(n), 4)): inds = [bl.index(bb) for bb in baselines(ants)] x[ii*m:(ii + 1)*m, inds] = block @@ -179,16 +205,16 @@ def closure_amplitude_design_matrix(n): def insert_missing_baselines_into_weights(ants1, ants2, weights): + weights = np.copy(weights) missing_inds = [] aa = set(ants1) | set(ants2) - if n_baselines(len(aa)) != len(weights): + if binom2(len(aa)) != len(weights): baselines = list(zip(ants1, ants2)) ants = np.sort(list(aa)) counter, missing_inds = 0, [] for ii, xx in enumerate(ants): for yy in ants[ii + 1:]: if (xx, yy) not in baselines: - print('Missing baseline') missing_inds.append(counter) counter += 1 for ii in missing_inds: @@ -196,39 +222,21 @@ def insert_missing_baselines_into_weights(ants1, ants2, weights): return weights, missing_inds -def nonredundant_closure_set(weighted_design_matrix): - closureweights = np.sum(np.abs(weighted_design_matrix), axis=1) - inds = np.argsort(closureweights) - sorted_design_matrix = weighted_design_matrix[inds] - rank = np.linalg.matrix_rank(weighted_design_matrix) - result = sorted_design_matrix[0:1] +def to_nonredundant(mat): + rnk = np.linalg.matrix_rank(mat) + result = 
mat[0:1] current_rank = 1 - for ii in range(1, weighted_design_matrix.shape[0]): - if current_rank >= rank: + for i in range(1, mat.shape[0]): + if current_rank >= rnk: break - tmp = np.append(result, sorted_design_matrix[ii:ii + 1], axis=0) - newrank = np.linalg.matrix_rank(tmp) - if newrank > current_rank: + tmp = np.append(result, mat[i:i+1], axis=0) + new_rank = np.linalg.matrix_rank(tmp) + if new_rank > current_rank: result = tmp - current_rank = newrank - if rank > current_rank: - raise RuntimeError('Not full basis found') + current_rank = new_rank return result -def sparse_cov(design_matrix, sigma_sq): - sigma = np.diag(sigma_sq) - cov = design_matrix @ sigma @ design_matrix.T - eig_val, eig_vec = eigh(cov) - L = np.diag(eig_val**(-.5)) @ eig_vec.T - n = L.shape[0] - assert n == L.shape[1] - d = [L[ii, jj] for ii in range(n) for jj in range(n)] - c = n*list(range(n)) - r = [x for x in range(n) for _ in range(n)] - return d, r, c - - class SparseMatrixOperator(ift.LinearOperator): def __init__(self, arg1, shape): assert len(shape) == 2 @@ -241,11 +249,3 @@ class SparseMatrixOperator(ift.LinearOperator): self._check_input(x, mode) f = self._smat.matvec if mode == self.TIMES else self._smat.rmatvec return ift.makeField(self._tgt(mode), f(x.val)) - - -class ToUnitCircle(ift.Operator): - def __init__(self, domain): - self._domain = self._target = ift.DomainTuple.make(domain) - - def apply(self, x): - return x*((x*x.conjugate()).real.sqrt().one_over()) diff --git a/src/constants.py b/src/constants.py index fd7093d..37b64b3 100644 --- a/src/constants.py +++ b/src/constants.py @@ -11,13 +11,20 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. # -# Copyright(C) 2019-2020 Max-Planck-Society +# Copyright(C) 2019 Max-Planck-Society import numpy as np +DEG2RAD = np.pi/180. SPEEDOFLIGHT = 299792458. -MUAS2RAD = 1e-6/3600*np.pi/180 + +ARCMIN2RAD = 1/60*DEG2RAD +AS2RAD = 1/3600*DEG2RAD +MUAS2RAD = 1/1e6/3600*DEG2RAD + MAXSHORTBASELINE = 2000000 KEYS = ['time', 'ant1', 'ant2', 'uv', 'vis', 'sigma'] FREQS = ['lo', 'hi'] DAYS = ['095', '096', '100', '101'] +_inches_per_pt = 1/72.27 +TEXTWIDTH = 455.24417*_inches_per_pt diff --git a/src/data.py b/src/data.py index d2bd62e..af36e8b 100644 --- a/src/data.py +++ b/src/data.py @@ -11,9 +11,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. 
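# Small numpy illustration (toy matrix) of the greedy rank construction in
# to_nonredundant above: rows are kept only while they increase the matrix
# rank, so the result spans the same closure space without redundancy.
import numpy as np

mat = np.array([[1, -1, 0], [0, 1, -1], [1, 0, -1]])  # row 2 = row 0 + row 1
rnk = np.linalg.matrix_rank(mat)                      # 2
kept = mat[0:1]
for row in mat[1:]:
    tmp = np.vstack([kept, row])
    if np.linalg.matrix_rank(tmp) > np.linalg.matrix_rank(kept):
        kept = tmp
assert kept.shape[0] == rnk == 2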
# -# Copyright(C) 2019-2020 Max-Planck-Society -# Author: Philipp Arras, Philipp Frank, Philipp Haim, Reimar Leike, -# Jakob Knollmueller +# Copyright(C) 2019 Max-Planck-Society import csv from collections import defaultdict @@ -22,7 +20,11 @@ import numpy as np from scipy.sparse import coo_matrix from scipy.sparse.linalg import aslinearoperator -from .constants import DAYS, FREQS, KEYS, MAXSHORTBASELINE +from .constants import MAXSHORTBASELINE, KEYS, FREQS, DAYS + + +VIS_SNR_THRESHOLD = 1 +REL_SYS_NOISE = 0.01 def _short_baseline_indices(uv): @@ -60,33 +62,65 @@ def read_data_raw(fname): } -def read_data(prefix, day, freq, identify_short_baselines): +def read_data(prefix, day, freq, ts_per_bin): assert freq in FREQS assert day in DAYS fname = f'{prefix}_{day}_{freq}.csv' d = read_data_raw(fname) - ant1, ant2 = d['ant1'], d['ant2'] - all_antennas = list(set(ant1) | set(ant2)) - if len(identify_short_baselines) > 0: - dct = {aa: ii for ii, aa in enumerate(all_antennas)} - if 'APAA' in identify_short_baselines: - dct['AP'] = dct['AA'] - if 'SMJC' in identify_short_baselines: - dct['SM'] = dct['JC'] - ant1 = [dct[ele] for ele in ant1] - ant2 = [dct[ele] for ele in ant2] + + # add_systematic_noise_budget(d) + # remove_low_snr(d) + + ant1 = d['ant1'] + ant2 = d['ant2'] ddct = defaultdict(lambda: len(ddct)) d['ant1'] = np.array([ddct[ele] for ele in ant1]) d['ant2'] = np.array([ddct[ele] for ele in ant2]) # Combine over scan - for ii in range(1, len(d['time'])): - if d['time'][ii] - d['time'][ii - 1] < 2/60: - d['time'][ii] = d['time'][ii - 1] + if ts_per_bin > 1: + tbins = [] + times = d['time'] + nval = len(d['time']) + i0 = 0 + + def fair_share(n, nshare, ishare): + return n//nshare + (ishare < (n % nshare)) + + while i0 < nval: + i = i0+1 + nscan = 1 # number of different time stamps in the scan + while i < nval and times[i]-times[i-1] < 20./3600: # as long as there are less than 20s between time stamps, we assume we are in the same scan + if times[i] != times[i-1]: + nscan += 1 + i += 1 + nbin = max(1, nscan//ts_per_bin) # how many bins to use for this scan + for j in range(nbin): + n = fair_share(nscan, nbin, j) + i = i0+1 + icnt = 0 + oldtime = times[i0] + while i < nval and icnt < n: + if times[i] != oldtime: + icnt += 1 + oldtime = times[i] + if icnt < n: + if icnt == n-1: + tbins += [(times[i0], + times[i], + times[i]-times[i0])] + times[i] = times[i0] # give all values in this bin the time stamp of the first value + i += 1 + i0 = i + tbsize = np.array([t[2] for t in tbins]) + print("min, max, avg, med time bins:") + print("{} {} {} {}".format(np.amin(tbsize), np.amax(tbsize), np.mean(tbsize), np.median(tbsize))) + else: + print('No temporal averaging.') # End combine over scan data_ordering(d) - remove_zero_baselines(d) + remove_short_baselines(d) identifier = [] for i in range(len(d['time'])): @@ -116,13 +150,24 @@ def combine_baselines(d, to_combine): N = d['sigma']**2 cols = [] rows = [] + assumed_N = [] + actual_std = [] for i, l in enumerate(to_combine): cols += [i]*len(l) rows += l + assumed_N += [np.mean(N[l])] + actual_std += [np.std(d['vis'][l])] R = aslinearoperator( coo_matrix((np.ones_like(rows), (rows, cols)), (to_combine[-1][-1] + 1, len(to_combine)))) D = 1/R.rmatvec(1/N) + assumed_N = np.array(assumed_N) + actual_std = np.array(actual_std) + quot = actual_std/np.sqrt(assumed_N) + with open("time_averaging.txt", 'a') as f: + f.write("{} {} {} {}\n".format(np.amin(quot), np.amax(quot), np.mean(quot), np.median(quot))) + print("min max avg med") + print("{} {} {} 
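# The noise-weighted average computed by combine_baselines, reduced to a
# single bin in plain numpy (made-up values): visibilities are combined with
# weights 1/N and the combined variance is D = 1/sum(1/N).
import numpy as np

vis = np.array([1.0 + 2.0j, 1.2 + 1.8j, 0.9 + 2.1j])  # one bin of visibilities
N = np.array([0.5, 1.0, 2.0])                         # their noise variances
D = 1.0 / np.sum(1.0 / N)
m = D * np.sum(vis / N)                               # combined visibility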
{}".format(np.amin(quot), np.amax(quot), np.mean(quot), np.median(quot))) m = D*(R.rmatvec(d['vis']/N)) res = {} @@ -139,18 +184,30 @@ def combine_baselines(d, to_combine): return res -def remove_zero_baselines(d): - """Removes all entries in the data dictionary which have the same antenna - labels for ant1 and ant2""" - ind = d['ant1'] != d['ant2'] +def remove_short_baselines(d): + uvlen = np.linalg.norm(d['uv'], axis=1) + ind = np.logical_and(uvlen >= MAXSHORTBASELINE, d['ant1'] != d['ant2']) + for kk in d: + if kk == 'freq': + continue + d[kk] = d[kk][ind] + + +def remove_low_snr(d): + ind = np.abs(d['vis'])/d['sigma'] > VIS_SNR_THRESHOLD for kk in d: if kk == 'freq': continue d[kk] = d[kk][ind] +def add_systematic_noise_budget(d): + d['sigma'] = np.sqrt((abs(d['vis'])*REL_SYS_NOISE)**2 + d['sigma']**2) + + def data_ordering(d): """Swap ant1, ant2 such that ant1 < ant2, sort after time, ant1, ant2""" + # FIXME Check: Every baseline only once per timestamp? foo = d['ant1'] > d['ant2'] tmp = d['ant1'][foo] d['ant1'][foo] = d['ant2'][foo] @@ -162,14 +219,13 @@ def data_ordering(d): d[kk] = d[kk][foo] -def combined_data(path, freqs, identify_short_baselines=[]): +def combined_data(prefix, freqs, ts_per_bin): d = {kk: [] for kk in KEYS} time_offsets = [0, 1, 5, 6] for ii, day in enumerate(DAYS): time_offset = time_offsets[ii]*24. for freq in freqs: - dd = read_data(path, day, freq, - identify_short_baselines=identify_short_baselines) + dd = read_data(prefix, day, freq, ts_per_bin) dd['time'] += time_offset if freq == 'lo': dd['time'] += 1e-6 diff --git a/src/response.py b/src/response.py index 5645968..4eb5686 100644 --- a/src/response.py +++ b/src/response.py @@ -11,8 +11,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. # -# Copyright(C) 2019-2020 Max-Planck-Society -# Author: Philipp Arras, Martin Reinecke +# Copyright(C) 2019 Max-Planck-Society +# Author: Philipp Arras import nifty_gridder as ng import numpy as np @@ -23,10 +23,11 @@ from .constants import SPEEDOFLIGHT class RadioResponse(ift.LinearOperator): - def __init__(self, domain, uv, epsilon): + def __init__(self, domain, uv, epsilon, j=1): ndata = uv.shape[0] self._uvw = np.zeros((ndata, 3), dtype=uv.dtype) self._uvw[:, 0:2] = uv + self._j = int(j) self._eps = float(epsilon) self._domain = ift.makeDomain(domain) self._target = ift.makeDomain(ift.UnstructuredDomain(ndata)) @@ -39,8 +40,8 @@ class RadioResponse(ift.LinearOperator): nx, ny = self._domain[0].shape f = np.array([SPEEDOFLIGHT]) if mode == self.TIMES: - res = ng.dirty2ms(self._uvw, f, x, None, dx, dy, self._eps, - False, nthreads=1) + res = ng.dirty2ms(self._uvw, f, x, None, dx, dy, self._eps, False, + nthreads=1) res = res[:, 0] else: x = x[:, None] diff --git a/src/sugar.py b/src/sugar.py index 919e144..0d59ed3 100644 --- a/src/sugar.py +++ b/src/sugar.py @@ -12,65 +12,147 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. # # Copyright(C) 2019-2020 Max-Planck-Society -# Author: Philipp Arras, Philipp Frank, Philipp Haim, Reimar Leike, -# Jakob Knollmueller +import h5py as h5 import numpy as np -import scipy.special as ss import nifty6 as ift def normalize(domain): domain = ift.makeDomain(domain) - one = ift.full(domain, 1.) 
- vdot = ift.VdotOperator(one) if isinstance(domain, ift.DomainTuple): - assert len(domain) == 1 - fac = domain[0].scalar_dvol + fac = domain.scalar_weight() + vdot = ift.ContractionOperator(domain, spaces=None) else: + assert domain['hi'] is domain['lo'] fac = domain['hi'].scalar_weight() - return ift.ScalingOperator(domain, 1/fac)*(vdot.adjoint @ vdot).one_over() - - -def save_state(sky, pos, pre, name, samples): - p = ift.Plot() - dom = sky.target['hi'] - dt = dom[0].distances[0] - sc = ift.StatCalculator() - if len(samples) == 0: - samples = 2*(0*pos,) - if len(samples) == 1: - samples.append(samples[0]) - for sam in samples: - sk = sky(pos+sam) - sk = 0.5*(sk['hi']+sk['lo']) - sc.add(sk) - vlim = [0., np.max(sc.mean.val)] - rel = sc.var.sqrt()/sc.mean - varlim = [0., np.max(rel.val)] - for ii in range(7): - fi = ift.DomainTupleFieldInserter(dom, 0, (int(ii*24/dt),)).adjoint - p.add(fi(sc.mean), zmin=vlim[0], zmax=vlim[1], title=f'Day {ii}, Posterior Mean') - for ii in range(7): - fi = ift.DomainTupleFieldInserter(dom, 0, (int(ii*24/dt),)).adjoint - p.add(fi(rel), zmin=varlim[0], zmax=varlim[1], title=f'Day {ii}, Rel. Stddev') - for ii in range(7): - fi = ift.DomainTupleFieldInserter(dom, 0, (int(ii*24/dt),)).adjoint - p.add(fi(sk), zmin=vlim[0], zmax=vlim[1], title=f'Day {ii}, Latent Mean') - p.output(nx=7, ny=3, name=f'{pre}{name}.png', xsize=25, ysize=10) + vdot_hi = ift.ContractionOperator(domain['hi'], spaces=None) + vdot_lo = ift.ContractionOperator(domain['lo'], spaces=None) + vdot = vdot_hi.ducktape('hi') + vdot_lo.ducktape('lo') + return ift.ScalingOperator(domain, 1)*(vdot.adjoint @ vdot.ptw('reciprocal').scale(1/fac)) + + +def save_hdf5(path, field): + print(f'Save {path}') + if isinstance(field, ift.Field): + field = ift.MultiField.from_dict({'single_field': field}) + fh = h5.File(path, 'w') + for key, f in field.items(): + fh.create_dataset(key, data=f.val, dtype=f.dtype) + fh.close() + + +def save_hdf5_iterable(path, iterable, master=True): + if master: + print(f'Save {path}') + fh = h5.File(path, 'w') + # CAUTION: the loop and the if-clause cannot be interchanged! 
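    # (Presumably because the samples can come from an MPI-distributed
    # iterable: every task must advance the iterator even though only the
    # master writes, otherwise the collective communication behind the
    # samples desynchronizes or deadlocks.)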
+ for kk, field in enumerate(iterable): + if master: + grp = fh.create_group(str(kk)) + if isinstance(field, ift.Field): + field = ift.MultiField.from_dict({'single_field': field}) + for key, f in field.items(): + vv = f.val if isinstance(f, ift.Field) else f + grp.create_dataset(key, data=vv, dtype=f.dtype) + if master: + fh.close() + + +def save_state(sky, pos, pre, current_iter, samples=None, master=True): + if master: + save_random_state(pre) # FIXME + save_hdf5(f'{pre}position.h5', pos) + if samples is not None: + save_hdf5_iterable(f'{pre}samples.h5', samples, master) + if master: + np.savetxt(f'{pre}current_iteration.txt', np.array([int(current_iter)]), + fmt='%i') + + +# FIXME: not MPI-compatible yet +def load_state(domain, pre): + load_random_state(pre) # FIXME + pos = load_hdf5(f'{pre}position.h5', domain) + samples = load_hdf5_lst(f'{pre}samples.h5', domain) + current_iter = np.loadtxt(f'{pre}current_iteration.txt')[()] + return pos, samples, current_iter + + +def save_random_state(pre): + with open(f'{pre}random.pickle', 'wb') as f: + f.write(ift.random.getState()) + + +def load_random_state(pre): + with open(f'{pre}random.pickle', 'rb') as f: + ift.random.setState(f.read()) + + +def load_hdf5(path, domain=None): + print(f'Load {path} ', end='', flush=True) + fh = h5.File(path, 'r') + if len(fh) == 1 and 'single_field' in fh.keys(): + res = np.array(fh['single_field']) + else: + res = {kk: np.array(vv) for kk, vv in fh.items()} + fh.close() + if domain is None: + return res + print('ok') + return ift.makeField(domain, res) + + +def load_hdf5_lst(path, domain=None, index=None): + ss = 'all' if index is None else index + print(f'Load {path} {ss} ', end='', flush=True) + fh = h5.File(path, 'r') + res = [] + for kk in fh: + if index is not None and int(kk) != index: + continue + obj = fh[kk] + if len(obj) == 1 and 'single_field' in obj.keys(): + res.append(np.array(obj['single_field'])) + else: + res.append({kk: np.array(vv) for kk, vv in obj.items()}) + fh.close() + print('ok') + if domain is not None: + res = [ift.makeField(domain, rr) for rr in res] + if index is None: + return res + assert len(res) == 1 + return res[0] + + +def len_hdf5_lst(path): + print(f'Open {path} ', end='', flush=True) + with h5.File(path, 'r') as f: + n = len(f) + print('ok') + return n def baselines(alst): return [(aa, bb) for ii, aa in enumerate(alst) for bb in alst[ii + 1:]] -def n_baselines(n_antennas): - return binom(n_antennas, 2) +def binom2(n): + return (n*(n-1))//2 + + +def binom3(n): + return (n*(n-1)*(n-2))//6 + +def binom4(n): + return (n*(n-1)*(n-2)*(n-3))//24 -def binom(n, m): - return int(round(ss.binom(n, m))) + +def sigmoid(x): + return .5*(1 + np.tanh(x)) def gaussian_profile(dom, rad): @@ -83,6 +165,17 @@ def gaussian_profile(dom, rad): profile = 1/(2*np.pi*rad**2)*np.exp(-0.5*(xx**2 + yy**2)/rad**2) return ift.makeField(dom, profile) +#FIXME unused +def lognormal_to_normal(mean, sig): + tmp = np.log((sig/mean)**2 + 1) + return np.log(mean)-0.5*tmp, np.sqrt(tmp) + +#FIXME unused +def normal_to_lognormal(mu, sig): + tmp = np.exp(sig**2) + logmu = np.exp(mu) * np.sqrt(tmp) + return logmu, logmu*np.sqrt(tmp-1) + class DomainTuple2MultiField(ift.LinearOperator): def __init__(self, domain, active_inds): @@ -111,3 +204,64 @@ class DomainTuple2MultiField(ift.LinearOperator): res[ii[-2:]][int(ii[:-3])] = x[ii] res = {k: ift.Field.from_raw(self.dom0, v) for k, v in res.items()} return ift.MultiField.from_dict(res) + +#FIXME unused +class DomainTuple2MultiField_sf(ift.LinearOperator): + def 
__init__(self, domain, active_inds): + self._domain = ift.DomainTuple.make(domain) + tgt = {str(ii): self._domain[1:] for ii in set(active_inds)} + self._target = ift.makeDomain(tgt) + self._capability = self.TIMES | self.ADJOINT_TIMES + + def apply(self, x, mode): + self._check_input(x, mode) + x = x.val + if mode == self.TIMES: + return ift.MultiField.from_dict({ + ii: ift.makeField(self._domain[1:], x[int(ii)]) + for ii in self._target.keys() + }) + res = np.zeros(self._domain.shape, + dtype=x[list(self._target.keys())[0]].dtype) + for ii in self._target.keys(): + res[int(ii)] = x[ii] + return ift.Field.from_raw(self._domain, res) + + +def freq_avg(fld): + return 0.5*(fld['hi'] + fld['lo']) + + +def strip_oversampling(res, ofac): + res = res.to_dict() + for kk, vv in res.items(): + n = int(np.round(vv.shape[0]/ofac)) + domt = ift.RGSpace(n, vv.domain[0].distances[0]) + dom = domt, vv.domain[1] + res[kk] = ift.makeField(dom, vv.val[:n]) + return ift.MultiField.from_dict(res) + +#FIXME unused +def load_skysample(pre, ii, sky, ofac, strip=True): + fname = f'{pre}samples.h5' + dom = sky.domain + ss = load_hdf5_lst(fname, domain=dom, index=ii) + pos = load_hdf5(f'{pre}position.h5', dom) + res = sky(pos+ss) + if strip: + res = strip_oversampling(res, ofac) + return res + +#FIXME unused +def load_skysamples(pre, sky, ofac=1): + fname = f'{pre}samples.h5' + dom = sky.domain + samples = load_hdf5_lst(fname, domain=dom) + pos = load_hdf5(f'{pre}position.h5', dom) + res = [] + for samp in samples: + rr = sky(pos+samp) + if ofac != 1: + rr = strip_oversampling(rr, ofac) + res.append(rr) + return res diff --git a/test.py b/test.py index 2384b62..6e9b314 100644 --- a/test.py +++ b/test.py @@ -13,6 +13,7 @@ # # Copyright(C) 2019-2020 Max-Planck-Society +import os from collections import defaultdict from itertools import product @@ -20,24 +21,39 @@ import numpy as np import pytest import nifty6 as ift -import src as eht +import src as vlbi +from config import sky_movie_mf, sky_single + + +def setup_function(): + import nifty6 as ift + ift.random.push_sseq_from_seed(42) + + +def teardown_function(): + import nifty6 as ift + ift.random.pop_sseq() + pmp = pytest.mark.parametrize -cb = [[], ['APAA', 'SMJC']] -ds = [{ - 'uv': np.random.randn(11, 2)*1e7, - 'vis': np.random.randn(11) + 1j*np.random.rand(11), - 'sigma': np.random.randn(11)**2, +bools = [False, True] +avg_ints = [8] # FIXME Test without averaging are superslow +d3 = { + 'uv': ift.random.current_rng().random((11, 2))*1e7, + 'vis': ift.random.current_rng().random(11) + 1j*ift.random.current_rng().random(11), + 'sigma': ift.random.current_rng().random(11)**2, 'freq': 1e11, 'time': np.array(5*[7.24027777] + 6*[7.4236114]), 'ant1': np.array([0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 2]), 'ant2': np.array([1, 2, 3, 2, 3, 1, 2, 3, 2, 3, 3]) -}] -ds += [eht.read_data('data/disk', dd, ff, o2) for dd, ff, o2 in product(eht.DAYS, eht.FREQS, cb)] -ds += [eht.combined_data('data/disk', eht.FREQS, identify_short_baselines=o2) for o2 in cb] +} +ds = [vlbi.read_data('data/m87', dd, ff, aa) for dd, ff, aa in product(vlbi.DAYS, vlbi.FREQS, avg_ints)] +ds.extend([vlbi.combined_data('data/m87', vlbi.FREQS, aa) for aa in avg_ints]) +ds.append(d3) seeds = [189, 4938] dtypes = [np.complex128, np.float64] +dom = ift.RGSpace(10) def AntennaBasedCalibration(tspace, time, ant1, ant2, amplkey, phkey): @@ -72,47 +88,60 @@ def CalibrationDistributor(domain, ant, time): @pmp('n', range(3, 40)) def test_visibility_closure_matrices(n): - np.random.seed(42) - Phi = 
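    # Why Psi @ Phi == 0 is asserted below: the columns of the visibility
    # design matrix Phi span all station-based phase patterns, and every
    # closure-phase row of Psi combines three baselines so that any such
    # pattern cancels.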
eht.closure.visibility_design_matrix(n) - Psi = eht.closure.closure_phase_design_matrix(n) + Phi = vlbi.closure.visibility_design_matrix(n) + Psi = vlbi.closure.closure_phase_design_matrix(n) np.testing.assert_allclose(Psi @ Phi, 0) @pmp('seed', seeds) @pmp('shape', [(3, 4), (10, 20)]) def test_normalize(seed, shape): - np.random.seed(seed) - sp = ift.RGSpace(shape, distances=np.exp(np.random.randn(len(shape)))) - s = ift.from_random('normal', sp).exp() - s_normalized = eht.normalize(s.domain)(s) - np.testing.assert_allclose(s_normalized.integrate(), 1.) + with ift.random.Context(seed): + sp = ift.RGSpace(shape, distances=np.exp(ift.random.current_rng().random(len(shape)))) + s = ift.from_random(sp, 'normal').exp() + s_normalized = vlbi.normalize(s.domain)(s) + np.testing.assert_allclose(s_normalized.s_integrate(), 1.) -@pmp('seed', seeds) -@pmp('d', ds) -@pmp('mode', ['ph', 'ampl']) -def test_closure_consistency(seed, d, mode): - np.random.seed(seed) - f = eht.Visibilities2ClosureAmplitudes - if mode == 'ph': - f = eht.Visibilities2ClosurePhases - op = f(d)[0] - pos = ift.from_random('normal', op.domain, dtype=np.complex128) - ift.extra.check_jacobian_consistency(op, pos) +@pmp('sky', [sky_single, sky_movie_mf]) +def test_normalize2(sky): + fld = sky(ift.from_random(sky.domain, 'normal')) + if isinstance(fld, ift.Field): + np.testing.assert_allclose(fld.s_integrate(), 1) + else: + val = 0 + for ff in fld.values(): + val += ff.s_integrate() + np.testing.assert_allclose(val, 1) @pmp('seed', seeds) -@pmp('shape', [(3, 4), (10, 20)]) -def test_sparse_cov(seed, shape): - np.random.seed(seed) - matrix = np.random.randn(*shape) - sigma_sq = np.exp(np.random.randn(shape[1])) - d, r, c = eht.closure.sparse_cov(matrix, sigma_sq) - mat = np.zeros((shape[0],)*2) - mat[(r, c)] = d - res0 = mat.T @ mat @ matrix @ np.diag(sigma_sq) @ matrix.T - res1 = np.eye(shape[0]) - np.testing.assert_allclose(res0, res1, rtol=1e-7, atol=1e-7) +@pmp('d', ds) +@pmp('mode', ['ph', 'ampl']) +def test_closure_gradient_consistency(seed, d, mode): + with ift.random.Context(seed): + f = vlbi.Visibilities2ClosureAmplitudes + if mode == 'ph': + f = vlbi.Visibilities2ClosurePhases + op = f(d)[0] + pos = ift.from_random(op.domain, 'normal', dtype=np.complex128) + ift.extra.check_jacobian_consistency(op, pos, tol=1e-7, ntries=10) + + +def test_saving_hdf5(): + sp1 = ift.UnstructuredDomain(2) + sp2 = ift.UnstructuredDomain(2) + tmp_path = 'mf_test_asehefuyiq2rgui.h5' + mf = {'a': ift.full(sp1, 1.3), 'b': ift.full(sp2, 1.5)} + mf = ift.MultiField.from_dict(mf) + vlbi.save_hdf5(tmp_path, mf) + mf2 = vlbi.load_hdf5(tmp_path, mf.domain) + assert ((mf2 - mf)**2).s_sum() == 0 + f = ift.full(sp1, 1.3) + vlbi.save_hdf5(tmp_path, f) + f2 = vlbi.load_hdf5(tmp_path, f.domain) + assert ((f2 - f)**2).s_sum() == 0 + os.system(f'rm {tmp_path}') @pmp('seed', seeds) @@ -120,119 +149,97 @@ def test_sparse_cov(seed, shape): @pmp('corrupt_phase', [False, True]) @pmp('corrupt_ampl', [False, True]) def test_closure_property(seed, dd, corrupt_phase, corrupt_ampl): - np.random.seed(seed) - CP, _ = eht.Visibilities2ClosurePhases(dd) - CA, _ = eht.Visibilities2ClosureAmplitudes(dd) - vis = ift.makeField(CP.domain, dd['vis'].copy()) - resph0 = CP(vis).val - resam0 = CA(vis).val - corruptedvis = dd['vis'].copy() - nants = len(set(dd['ant1']) | set(dd['ant2'])) - for tt in np.unique(dd['time']): - corruption = np.ones(nants) - if corrupt_ampl: - corruption *= np.abs(np.random.normal(size=nants)) - if corrupt_phase: - corruption = 
corruption*np.exp(1j*np.random.normal(size=nants)) - ind = tt == dd['time'] - aa1 = dd['ant1'][ind] - aa2 = dd['ant2'][ind] - corruptedvis[ind] = corruptedvis[ind]*corruption[aa1] - corruptedvis[ind] = corruptedvis[ind]*corruption[aa2].conjugate() - vis = ift.makeField(CP.domain, corruptedvis) - np.testing.assert_allclose(resam0, CA(vis).val) - np.testing.assert_allclose(resph0, CP(vis).val) - - -@pmp('seed', seeds) -def test_abs_operator(seed): - np.random.seed(seed) - dom = ift.UnstructuredDomain(5) - op = eht.closure.ToUnitCircle(dom) - pos = ift.from_random('normal', op.domain) - op(pos) - # currently unfortunately the only way to test complex differentiability - build_complex_op = 1.j*ift.FieldAdapter( - op.domain, 'Im') + ift.FieldAdapter(op.domain, 'Re') - op = op @ build_complex_op - pos = ift.from_random('normal', op.domain) - ift.extra.check_jacobian_consistency(op, pos) - res = op(pos).val - np.testing.assert_allclose(np.abs(res), 1) - np.iscomplexobj(res) - np.testing.assert_(np.sum(np.abs(res.imag)) > 0) + with ift.random.Context(seed): + CP = vlbi.Visibilities2ClosurePhases(dd)[0] + CA = vlbi.Visibilities2ClosureAmplitudes(dd)[0] + vis = ift.makeField(CP.domain, dd['vis'].copy()) + resph0 = CP(vis).val + resam0 = CA(vis).val + corruptedvis = dd['vis'].copy() + nants = len(set(dd['ant1']) | set(dd['ant2'])) + for tt in np.unique(dd['time']): + corruption = np.ones(nants) + if corrupt_ampl: + corruption *= np.abs(ift.random.current_rng().random(nants)) + if corrupt_phase: + corruption = corruption*np.exp(1j*ift.random.current_rng().random(nants)) + ind = tt == dd['time'] + aa1 = dd['ant1'][ind] + aa2 = dd['ant2'][ind] + corruptedvis[ind] = corruptedvis[ind]*corruption[aa1] + corruptedvis[ind] = corruptedvis[ind]*corruption[aa2].conjugate() + vis = ift.makeField(CP.domain, corruptedvis) + np.testing.assert_allclose(resam0, CA(vis).val) + np.testing.assert_allclose(resph0, CP(vis).val) @pmp('npix', [64, 98, 256]) @pmp('epsilon', [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]) @pmp('dd', ds) def test_nfft(npix, epsilon, dd): - np.random.seed(42) - fov = 200*eht.MUAS2RAD + fov = 200*vlbi.MUAS2RAD dom = ift.RGSpace(2*(npix,), 2*(fov/npix,)) - nfft = eht.RadioResponse(dom, dd['uv'], epsilon) + nfft = vlbi.RadioResponse(dom, dd['uv'], epsilon) ift.extra.consistency_check(nfft, domain_dtype=np.float64, target_dtype=np.complex128, only_r_linear=True) -def _vis2closure(d, ind): - ind = np.sort(ind) - i, j, k = ind - vis_ind = tuple( - np.where(np.logical_and(d['ant1'] == a, d['ant2'] == b) == 1)[0] - for a, b in ((i, j), (j, k), (i, k))) - phase = d['vis']/abs(d['vis']) - return phase[vis_ind[0]]*phase[vis_ind[1]]*phase[vis_ind[2]].conjugate() - - def test_closure(): - np.random.seed(42) - ant1 = [] - ant2 = [] + def vis2closure(d, ind): + ind = np.sort(ind) + i, j, k = ind + vis_ind = tuple( + np.where(np.logical_and(d['ant1'] == a, d['ant2'] == b) == 1)[0] + for a, b in ((i, j), (j, k), (i, k))) + phase = d['vis']/abs(d['vis']) + return phase[vis_ind[0]]*phase[vis_ind[1]]*phase[vis_ind[2]].conjugate() + + # Antenna setup + # 0 -- 3 -- 4 + # | \/ | + # | /\ | + # 1 -- 2 + ant1, ant2 = [], [] for ii in range(4): for jj in range(ii + 1, 4): ant1.extend([ii]) ant2.extend([jj]) - ant1.extend([3]) - ant2.extend([4]) + ant1.append(3) + ant2.append(4) d = {} - d['time'] = np.array([ - 0, - ]*len(ant1)) + d['uv'] = ift.random.current_rng().random((len(ant1), 2))*1e7 + d['time'] = np.array([0]*len(ant1)) d['ant1'] = np.array(ant1) d['ant2'] = np.array(ant2) - d['vis'] = np.random.random(len(ant1))*np.exp( 
- np.random.random(len(ant1))*2*np.pi*1j) - ww = 1 + np.arange(len(ant1)) - d['sigma'] = abs(d['vis'])/np.sqrt(ww) + randn = ift.random.current_rng().random + d['vis'] = randn(len(ant1))*np.exp(randn(len(ant1))*2*np.pi*1j) + d['sigma'] = abs(d['vis'])/np.sqrt(1 + np.arange(len(ant1))) closure = np.empty(3, dtype=np.complex128) - closure[0] = _vis2closure(d, [1, 2, 3])[0] - closure[1] = _vis2closure(d, [0, 2, 3])[0] - closure[2] = _vis2closure(d, [0, 1, 3])[0] + closure[0] = vis2closure(d, [0, 1, 2])[0] + closure[1] = vis2closure(d, [0, 1, 3])[0] + closure[2] = vis2closure(d, [0, 2, 3])[0] - closure_op = eht.Visibilities2ClosurePhases(d)[0] + _, _, closure_op = vlbi.Visibilities2ClosurePhases(d) vis = ift.makeField(ift.UnstructuredDomain(len(ant1)), d['vis']) closure_from_op = closure_op(vis).val - np.testing.assert_allclose(closure_from_op - closure, 0, atol=1e-10) + np.testing.assert_allclose(closure_from_op, closure) @pmp('d', ds) @pmp('dt', [1/60, 1/6]) def test_calibration_dist(d, dt): - np.random.seed(42) - cph, _ = eht.Visibilities2ClosurePhases(d) - ca, _ = eht.Visibilities2ClosureAmplitudes(d) + cph = vlbi.Visibilities2ClosurePhases(d)[0] + ca = vlbi.Visibilities2ClosureAmplitudes(d)[0] time = np.array(d['time']) - min(d['time']) npix = max(time)/dt + 1 tspace = ift.RGSpace(npix, distances=dt) - cal_op = AntennaBasedCalibration(tspace, time, d['ant1'], d['ant2'], 'a', - 'p') + cal_op = AntennaBasedCalibration(tspace, time, d['ant1'], d['ant2'], 'a', 'p') g0 = cal_op(ift.full(cal_op.domain, 0)).val np.testing.assert_allclose(g0, np.ones_like(g0)) - g = cal_op(ift.from_random('normal', cal_op.domain)) + g = cal_op(ift.from_random(cal_op.domain, 'normal')) vis = ift.makeField(g.domain, d['vis']) cal_vis = g*vis np.testing.assert_(np.all(vis.val != cal_vis.val)) @@ -240,12 +247,132 @@ def test_calibration_dist(d, dt): np.testing.assert_allclose(ca(vis).val, ca(cal_vis).val) +@pmp('d', ds) +def test_skyscaling_invariance(d): + cph = vlbi.Visibilities2ClosurePhases(d)[0] + ca = vlbi.Visibilities2ClosureAmplitudes(d)[0] + vis = ift.from_random(cph.domain, 'normal', dtype=np.complex128) + ift.extra.assert_allclose(cph(vis), cph(0.78*vis), 0, 1e-11) + ift.extra.assert_allclose(ca(vis), ca(0.878*vis), 0, 1e-11) + + @pmp('ddtype', dtypes) @pmp('tdtype', dtypes) def test_DomainTupleMultiField(ddtype, tdtype): - np.random.seed(42) dom = {'lo': ift.RGSpace(5, 3), 'hi': ift.RGSpace(5, 3)} dom = ift.MultiDomain.make(dom) active_inds = ["0_lo", "4_hi"] - foo = eht.DomainTuple2MultiField(dom, active_inds) + foo = vlbi.DomainTuple2MultiField(dom, active_inds) + ift.extra.consistency_check(foo, domain_dtype=ddtype, target_dtype=tdtype) + dom = ift.RGSpace(5, 3) + active_inds = [0, 4] + foo = vlbi.DomainTuple2MultiField_sf(dom, active_inds) ift.extra.consistency_check(foo, domain_dtype=ddtype, target_dtype=tdtype) + + +@pmp('seed', seeds) +def test_random_states(seed): + vlbi.save_random_state('pref') + arr0 = ift.random.current_rng().random(5) + ift.random.push_sseq_from_seed(seed) + arr1 = ift.random.current_rng().random(5) + vlbi.load_random_state('pref') + arr2 = ift.random.current_rng().random(5) + np.testing.assert_equal(arr0, arr2) + with pytest.raises(AssertionError): + np.testing.assert_equal(arr0, arr1) + + +@pmp('op', [ift.GeometryRemover(dom), ift.GeometryRemover(dom).ducktape('aa')]) +@pmp('pre', ['', 'foo']) +@pmp('iteration', [0, 32]) +def test_position_states(op, pre, iteration): + pos0 = ift.from_random(op.domain, 'normal') + samples0 = [ift.from_random(op.domain, 'normal') for _ in 
range(2)] + vlbi.save_state(op, pos0, pre, iteration, samples0) + pos2 = ift.from_random(op.domain, 'normal') + pos1, samples1, current_iter = vlbi.load_state(op.domain, pre) + assert iteration == current_iter + ift.extra.assert_allclose(pos2, ift.from_random(op.domain, 'normal'), 0, 1e-7) + ift.extra.assert_allclose(pos0, pos1, 0, 1e-7) + for ss0, ss1 in zip(samples0, samples1): + ift.extra.assert_allclose(ss0, ss1, 0, 1e-7) + + +@pmp('mean', [0, 1.2]) +@pmp('sig', [1, 1.3]) +def test_normal_lognormal(mean, sig): + mean1, sig1 = vlbi.lognormal_to_normal(*vlbi.normal_to_lognormal(mean, sig)) + np.testing.assert_allclose(mean, mean1) + np.testing.assert_allclose(sig, sig1) + + +def _std_check(arr): + tol = np.std(arr) if arr.size > 3 else 0.5 + np.testing.assert_allclose(np.mean(arr), 1, atol=tol) + + +def _cov_check(mat): + mat = np.copy(mat) + for i in range(mat.shape[0]): + mat[i] = np.roll(mat[i], -i) + mm = np.zeros(mat.shape[1]) + mm[0] = 1. + tol = np.std(mat) if mat.shape[0] > 3 else 0.5 + np.testing.assert_allclose(np.mean(mat, axis=0), mm, atol=tol) + + +def _make_vis_N(dom, d): + vis = ift.makeField(dom, d['vis']) + # Assume that the sigma in the data describes the noise variation for real + # and imaginary part individually. + N = ift.makeOp(ift.makeField(dom, d['sigma']**2)) + return vis, N + + +@pmp('d', ds) +def test_closure_noise(d): + nsamples = 100 + + # Phases + vis2clos = vlbi.Visibilities2ClosurePhases(d)[0] + vis, N = _make_vis_N(vis2clos.domain, d) + sc = [] + for _ in range(nsamples): + n = N.draw_sample_with_dtype(np.complex128) + clos = vis2clos(vis + n) + sc.append(clos.val) + sc = np.array(sc) + scm = np.mean(sc, axis=0) + # np.std adds var of imag and real part + # however the Gauss is not really 2D + val = np.std(sc, axis=0) + mat = 0. + for s in sc: + tm = s-scm + mat += np.outer(tm, tm.conjugate()) + # the line before adds var of imag and real part + # however the Gauss is not really 2D + mat /= nsamples + + _std_check(val) + _cov_check(mat) + + # Amplitudes + vis2clos = vlbi.Visibilities2ClosureAmplitudes(d)[0] + vis, N = _make_vis_N(vis2clos.domain, d) + sc = [] + for _ in range(nsamples): + n = N.draw_sample_with_dtype(np.complex128) + clos = vis2clos(vis + n) + sc.append(clos.val) + sc = np.array(sc) + scm = np.mean(sc, axis=0) + val = np.std(sc, axis=0) + mat = 0. + for s in sc: + tm = s-scm + mat += np.outer(tm, tm) + mat /= nsamples + _std_check(val) + _cov_check(mat) -- GitLab
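# Hedged usage sketch, inferred from the scripts in this patch (data prefix
# and MPI task count are examples, not prescribed anywhere above):
#
#   python movie_start.py                        # writes initial.h5 and
#                                                # initial_random_state.txt
#   mpirun -np 4 python reconstruction.py m87 initial.h5
#                                                # argv: data prefix under
#                                                # data/, start position file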