Commit 721d0d1b authored by Reimar H Leike's avatar Reimar H Leike
Browse files

Merge branch 'refactor_likelihood' into 'revision'

refactor likelihood building

See merge request !2
parents 1b8baa3c 28c89338
Pipeline #91950 passed with stages
in 15 minutes and 40 seconds
......@@ -25,9 +25,9 @@ from time import time
import nifty6 as ift
import src as vlbi
from config import comm, nranks, rank, master
from config import doms, dt, eps, min_timestamps_per_bin, npixt, nthreads
from config import doms, eps, min_timestamps_per_bin, nthreads
from config import sky_movie_mf as sky
from config import startt
from config import startt, dt, npixt
def stat_plotting(pos, KL):
if master:
......@@ -46,8 +46,8 @@ def stat_plotting(pos, KL):
sc_mean.add(0.5*(samp['hi']+samp['lo']).val)
sc_spectral.add(2 * (samp['lo']-samp['hi']).val / (samp['lo']+samp['hi']).val )
def optimization_heuristic(ii, lh_full, lh_amp, lh_ph, ind_full, ind_amp, ind_ph):
cut = 2 if dt == 24 else 6
def optimization_heuristic(ii, likelihoods):
lh_full, lh_full_amp, lh_full_ph, lh_cut, lh_cut_amp, lh_cut_ph = likelihoods
N_samples = 10 * (1 + ii // 8)
N_iterations = 4 * (4 + ii // 4) if ii<50 else 20
......@@ -63,133 +63,98 @@ def optimization_heuristic(ii, lh_full, lh_amp, lh_ph, ind_full, ind_amp, ind_ph
else:
minimizer = ift.NewtonCG(ic_newton)
if ii < 30:
lh = []
active_inds = []
if ii % 2 == 0 or ii < 10:
for key in ind_full.keys():
if int(key[0])<cut and key[1] == '_': #FIXME: dirty fix to select
# the data of the first two days
active_inds.append(key)
lh.append(ind_full[key])
lh = lh_cut
elif ii % 4 == 1:
for key in ind_amp.keys():
if int(key[0])<cut and key[1] == '_': #FIXME
active_inds.append(key)
lh.append(ind_amp[key])
lh = lh_cut_amp
else:
for key in ind_ph.keys():
if int(key[0])<cut and key[1] == '_': #FIXME
active_inds.append(key)
lh.append(ind_ph[key])
conv = vlbi.DomainTuple2MultiField(sky.target, active_inds)
lh = reduce(add, lh) @ conv
lh = lh_cut_ph
else:
if ii % 2 == 0 or ii > 50:
lh = lh_full
elif ii % 4 == 1:
lh = lh_amp
lh = lh_full_amp
else:
lh = lh_ph
lh = lh_full_ph
return minimizer, N_samples, N_iterations, lh
def setup():
if len(sys.argv) != 3:
raise RuntimeError
_, pre_data, fname_input = sys.argv
pre_output = pre_data
ndata = {}
lh_full = []
lh_amp = []
lh_ph = []
ind_full = {}
ind_amp = {}
ind_ph = {}
def build_likelihood(rawdd, startt, npixt, dt, mode):
    """Assemble the closure-quantity likelihood operator over all frequencies.

    Parameters
    ----------
    rawdd : dict
        Mapping freq -> raw data, one entry per element of ``vlbi.data.FREQS``.
    startt : float
        Start time handed to ``vlbi.time_binning`` as ``tmin``.
    npixt : int
        Number of temporal pixels; fixes ``tmax = npixt*dt``.
    dt : float
        Width of one temporal bin.
    mode : str
        'amp'  -> closure amplitudes only,
        'ph'   -> closure phases only,
        'full' -> both.

    Returns
    -------
    Operator
        Sum of the Gaussian energies of all non-empty time bins, composed
        with ``vlbi.DomainTuple2MultiField`` restricted to the active bins.
    """
    lh = []
    active_inds = []
    for freq in vlbi.data.FREQS:
        rawd = rawdd[freq]
        args = {'tmin': startt, 'tmax': npixt*dt, 'delta_t': dt}
        for ii, dd in enumerate(vlbi.time_binning(rawd, **args)):
            if len(dd) == 0:
                # Empty time bin: contributes no data, hence no likelihood term.
                continue
            ind = str(ii) + "_" + freq
            active_inds.append(ind)
            nfft = vlbi.RadioResponse(doms, dd['uv'], eps, nthreads)
            vis = ift.makeField(nfft.target, dd['vis'])
            llh = []
            if mode in ['amp', 'full']:
                vis2closampl, evalsampl = vlbi.Visibilities2ClosureAmplitudes(dd)
                llh.append(ift.GaussianEnergy(mean=vis2closampl(vis)) @ vis2closampl)
            if mode in ['ph', 'full']:
                vis2closph, evalsph, _ = vlbi.Visibilities2ClosurePhases(dd)
                llh.append(ift.GaussianEnergy(mean=vis2closph(vis)) @ vis2closph)
            llh_op = reduce(add, llh) @ nfft.ducktape(ind)
            if mode == 'full' and freq == vlbi.data.FREQS[0]:
                # Spot-check Jacobian consistency on the first frequency only;
                # checking every frequency would be needlessly expensive.
                ift.extra.check_jacobian_consistency(llh_op, ift.from_random(llh_op.domain),
                                                     tol=1e-5, ntries=10)
            lh.append(llh_op)
    conv = vlbi.DomainTuple2MultiField(sky.target, active_inds)
    return reduce(add, lh) @ conv
def setup():
    """Parse CLI arguments, load data, and build all likelihood variants.

    Expects exactly two command-line arguments: the data prefix and the
    input position file name.  Raises ``RuntimeError`` otherwise.

    Returns
    -------
    tuple
        ``(pos, sky, ic, pre_output, likelihoods)`` where ``likelihoods`` is
        the 6-tuple ``(lh_full, lh_full_amp, lh_full_ph, lh_cut, lh_cut_amp,
        lh_cut_ph)`` consumed by ``optimization_heuristic``.
    """
    if len(sys.argv) != 3:
        raise RuntimeError
    _, pre_data, fname_input = sys.argv

    # Load the raw data once (on the master task) and broadcast it, so every
    # MPI rank builds identical likelihood operators.
    rawd = {}
    for freq in vlbi.data.FREQS:
        rawd[freq] = vlbi.combined_data(f'data/{pre_data}', [freq], min_timestamps_per_bin) if master else None
    if nranks > 1:
        rawd = comm.bcast(rawd)
    pre_output = pre_data

    # Full-time-range likelihoods and "cut" variants restricted to the first
    # half of the temporal pixels (npixt//2), each in all three modes.
    lh_full = build_likelihood(rawd, startt, npixt, dt, 'full')
    lh_full_ph = build_likelihood(rawd, startt, npixt, dt, 'ph')
    lh_full_amp = build_likelihood(rawd, startt, npixt, dt, 'amp')
    lh_cut = build_likelihood(rawd, startt, npixt//2, dt, 'full')
    lh_cut_ph = build_likelihood(rawd, startt, npixt//2, dt, 'ph')
    lh_cut_amp = build_likelihood(rawd, startt, npixt//2, dt, 'amp')

    pos = vlbi.load_hdf5(fname_input, sky.domain) if master else None
    if nranks > 1:
        pos = comm.bcast(pos)
    ic = ift.AbsDeltaEnergyController(0.5, iteration_limit=200, name=f'Sampling(task {rank})', convergence_level=3)
    if master:
        # Time one likelihood evaluation as a quick sanity/performance check.
        t0 = time()
        (lh_full @ sky)(pos)
        print(f'Likelihood call: {1000*(time()-t0):.0f} ms')
    return pos, sky, ic, pre_output, (lh_full, lh_full_amp, lh_full_ph, lh_cut, lh_cut_amp, lh_cut_ph)
# Encapsulate everything in functions, to avoid as many (unintended) global variables as possible
def main():
# These following lines are to verify the scan averaging
with open("time_averaging.txt", 'w') as f:
# delete the file such that new lines can be appended
f.write("min max avg med\n")
pos, lh_full, lh_amp, lh_ph, sky, ic, pre_output, ind_full, ind_amp, ind_ph = setup()
pos, sky, ic, pre_output, likelihoods = setup()
for ii in range(60):
gc.collect()
minimizer, N_samples, N_iterations, lh = optimization_heuristic(
ii, lh_full, lh_amp, lh_ph,ind_full, ind_amp, ind_ph)
minimizer, N_samples, N_iterations, lh = optimization_heuristic(ii, likelihoods)
if master:
print(f'Iter: {ii}, N_samples: {N_samples}, N_iter: {N_iterations}')
ll = lh @ sky
H = ift.StandardHamiltonian(ll, ic)
KL = ift.MetricGaussianKL(pos, H, N_samples, comm=comm, mirror_samples=True)
......
......@@ -115,8 +115,6 @@ def read_data(prefix, day, freq, ts_per_bin):
i += 1
i0 = i
tbsize = np.array([t[2] for t in tbins])
print("min, max, avg, med time bins:")
print("{} {} {} {}".format(np.amin(tbsize), np.amax(tbsize), np.mean(tbsize), np.median(tbsize)))
else:
print('No temporal averaging.')
# End combine over scan
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment