Commit 7a7617ba authored by Jakob Knollmüller

refactor likelihood building

parent 1b8baa3c
Pipeline #91891 passed with stages in 14 minutes and 22 seconds
@@ -27,7 +27,7 @@ import src as vlbi
 from config import comm, nranks, rank, master
 from config import doms, dt, eps, min_timestamps_per_bin, npixt, nthreads
 from config import sky_movie_mf as sky
-from config import startt
+from config import startt, dt, npix


 def stat_plotting(pos, KL):
     if master:
@@ -46,8 +46,8 @@ def stat_plotting(pos, KL):
             sc_mean.add(0.5*(samp['hi']+samp['lo']).val)
             sc_spectral.add(2 * (samp['lo']-samp['hi']).val / (samp['lo']+samp['hi']).val )


-def optimization_heuristic(ii, lh_full, lh_amp, lh_ph, ind_full, ind_amp, ind_ph):
-    cut = 2 if dt == 24 else 6
+def optimization_heuristic(ii, likelihoods):
+    lh_full, lh_full_amp, lh_full_ph, lh_cut, lh_cut_amp, lh_cut_ph = likelihoods
     N_samples = 10 * (1 + ii // 8)
     N_iterations = 4 * (4 + ii // 4) if ii<50 else 20
@@ -63,115 +63,86 @@ def optimization_heuristic(ii, lh_full, lh_amp, lh_ph, ind_full, ind_amp, ind_ph
     else:
         minimizer = ift.NewtonCG(ic_newton)
     if ii < 30:
-        lh = []
-        active_inds = []
         if ii % 2 == 0 or ii < 10:
-            for key in ind_full.keys():
-                if int(key[0])<cut and key[1] == '_': #FIXME: dirty fix to select
-                                                      # the data of the first two days
-                    active_inds.append(key)
-                    lh.append(ind_full[key])
+            lh = lh_cut
         elif ii % 4 == 1:
-            for key in ind_amp.keys():
-                if int(key[0])<cut and key[1] == '_': #FIXME
-                    active_inds.append(key)
-                    lh.append(ind_amp[key])
+            lh = lh_cut_amp
         else:
-            for key in ind_ph.keys():
-                if int(key[0])<cut and key[1] == '_': #FIXME
-                    active_inds.append(key)
-                    lh.append(ind_ph[key])
-        conv = vlbi.DomainTuple2MultiField(sky.target, active_inds)
-        lh = reduce(add, lh) @ conv
+            lh = lh_cut_ph
     else:
         if ii % 2 == 0 or ii > 50:
             lh = lh_full
         elif ii % 4 == 1:
-            lh = lh_amp
+            lh = lh_full_amp
         else:
-            lh = lh_ph
+            lh = lh_full_ph
     return minimizer, N_samples, N_iterations, lh
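
For reference, a self-contained sketch of the schedule that the refactored optimization_heuristic follows; pick_likelihood is a hypothetical stand-in name, and plain strings replace the actual NIFTy likelihood operators:

    # Sketch only: strings stand in for the likelihood operators selected per iteration.
    def pick_likelihood(ii, likelihoods):
        lh_full, lh_full_amp, lh_full_ph, lh_cut, lh_cut_amp, lh_cut_ph = likelihoods
        if ii < 30:
            # early iterations alternate over the "cut" likelihoods (built with npix//2 in setup below)
            if ii % 2 == 0 or ii < 10:
                return lh_cut
            elif ii % 4 == 1:
                return lh_cut_amp
            return lh_cut_ph
        # later iterations alternate over the likelihoods covering the full time range
        if ii % 2 == 0 or ii > 50:
            return lh_full
        elif ii % 4 == 1:
            return lh_full_amp
        return lh_full_ph

    names = ('full', 'full_amp', 'full_ph', 'cut', 'cut_amp', 'cut_ph')
    for ii in (0, 9, 11, 30, 33, 51):
        N_samples = 10 * (1 + ii // 8)
        N_iterations = 4 * (4 + ii // 4) if ii < 50 else 20
        print(ii, N_samples, N_iterations, pick_likelihood(ii, names))
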
-def setup():
-    if len(sys.argv) != 3:
-        raise RuntimeError
-    _, pre_data, fname_input = sys.argv
-    pre_output = pre_data
-    ndata = {}
-    lh_full = []
-    lh_amp = []
-    lh_ph = []
-    ind_full = {}
-    ind_amp = {}
-    ind_ph = {}
+def build_likelihood(rawdd, startt, npix, dt, mode):
+    lh = []
     active_inds = []
     for freq in vlbi.data.FREQS:
-        rawd = vlbi.combined_data(f'data/{pre_data}', [freq], min_timestamps_per_bin) if master else None
-        if nranks > 1:
-            rawd = comm.bcast(rawd)
-        ndata[freq] = []
+        rawd = rawdd[freq]
         args = {'tmin': startt, 'tmax': npixt*dt, 'delta_t': dt}
         for ii, dd in enumerate(vlbi.time_binning(rawd, **args)):
             if len(dd) == 0:
-                ndata[freq] += [0, ]
                 continue
             ind = str(ii) + "_" + freq
             active_inds.append(ind)
-            vis2closph, evalsph, _ = vlbi.Visibilities2ClosurePhases(dd)
-            vis2closampl, evalsampl = vlbi.Visibilities2ClosureAmplitudes(dd)
             nfft = vlbi.RadioResponse(doms, dd['uv'], eps, nthreads)
             vis = ift.makeField(nfft.target, dd['vis'])
+            llh = []
+            if mode in ['amp', 'full']:
+                vis2closampl, evalsampl = vlbi.Visibilities2ClosureAmplitudes(dd)
+                llh.append(ift.GaussianEnergy(mean=vis2closampl(vis)) @ vis2closampl)
+            if mode in ['ph', 'full']:
+                vis2closph, evalsph, _ = vlbi.Visibilities2ClosurePhases(dd)
+                llh.append(ift.GaussianEnergy(mean=vis2closph(vis)) @ vis2closph)
+            llh_op = reduce(add, llh) @ nfft.ducktape(ind)
+            ift.extra.check_jacobian_consistency(llh_op, ift.from_random(llh_op.domain),
+                                                 tol=1e-5, ntries=10)
+            lh.append(llh_op)
+    conv = vlbi.DomainTuple2MultiField(sky.target, active_inds)
+    lh = reduce(add, lh) @ conv
+    return lh
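
As a reading aid, a toy, self-contained analogue of the composition pattern in build_likelihood: per-bin closure terms are selected by mode, combined with reduce(add, ...), and the per-bin results are summed into one total. Plain floats stand in for the NIFTy energy operators; all names below are illustrative only:

    from functools import reduce
    from operator import add

    def build_total(bins, mode):
        # bins maps '<timebin>_<freq>' keys to toy per-bin "energies"
        terms = []
        for ind, energies in bins.items():
            per_bin = []
            if mode in ['amp', 'full']:
                per_bin.append(energies['amp'])
            if mode in ['ph', 'full']:
                per_bin.append(energies['ph'])
            terms.append(reduce(add, per_bin))
        return reduce(add, terms)

    toy = {'0_hi': {'amp': 1.5, 'ph': 0.5}, '1_lo': {'amp': 2.0, 'ph': 1.0}}
    print(build_total(toy, 'full'), build_total(toy, 'amp'), build_total(toy, 'ph'))
    # prints: 5.0 3.5 1.5
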
-            ndata[freq] += [vis2closph.target.size + vis2closampl.target.size,]
-            lhph = ift.GaussianEnergy(mean=vis2closph(vis)) @ vis2closph
-            lhamp = ift.GaussianEnergy(mean=vis2closampl(vis)) @ vis2closampl
-            llh_full = reduce(add, [lhamp, lhph]) @ nfft.ducktape(ind)
-            lh_full.append(llh_full)
-            llh_amp = reduce(add, [lhamp]) @ nfft.ducktape(ind)
-            lh_amp.append(llh_amp)
-            llh_ph = reduce(add, [lhph]) @ nfft.ducktape(ind)
-            lh_ph.append(llh_ph)
-            ind_ph[ind] = llh_ph
-            ind_amp[ind] = llh_amp
-            ind_full[ind] = llh_full
-            foo = reduce(add, [lhph, lhamp]) @ nfft.ducktape(ind)
-            ift.extra.check_jacobian_consistency(foo, ift.from_random(foo.domain),
-                                                 tol=1e-5, ntries=10)
-    conv = vlbi.DomainTuple2MultiField(sky.target, active_inds)
-    lh_full = reduce(add, lh_full) @ conv
-    lh_amp = reduce(add, lh_amp) @ conv
-    lh_ph = reduce(add, lh_ph) @ conv
+def setup():
+    if len(sys.argv) != 3:
+        raise RuntimeError
+    _, pre_data, fname_input = sys.argv
+    rawd = {}
+    for freq in vlbi.data.FREQS:
+        rawd[freq] = vlbi.combined_data(f'data/{pre_data}', [freq], min_timestamps_per_bin) if master else None
+    if nranks > 1:
+        rawd = comm.bcast(rawd)
+    pre_output = pre_data
+    lh_full = build_likelihood(rawd, startt, npix, dt,'full')
+    lh_full_ph = build_likelihood(rawd, startt, npix, dt,'ph')
+    lh_full_amp = build_likelihood(rawd, startt, npix, dt,'amp')
+    lh_cut = build_likelihood(rawd, startt, npix//2, dt,'full')
+    lh_cut_ph = build_likelihood(rawd, startt, npix//2, dt,'ph')
+    lh_cut_amp = build_likelihood(rawd, startt, npix//2, dt,'amp')
     pos = vlbi.load_hdf5(fname_input, sky.domain) if master else None
     if nranks > 1:
         pos = comm.bcast(pos)
     ic = ift.AbsDeltaEnergyController(0.5, iteration_limit=200, name=f'Sampling(task {rank})', convergence_level=3)
     if master:
         t0 = time()
         (lh_full @ sky)(pos)
         print(f'Likelihood call: {1000*(time()-t0):.0f} ms')
-    return pos, lh_full, lh_amp, lh_ph, sky, ic, pre_output, ind_full, ind_amp, ind_ph
+    return pos, sky, ic, pre_output, (lh_full, lh_full_amp, lh_full_ph, lh_cut, lh_cut_amp, lh_cut_ph)
 # Encapsulate everything in functions, to avoid as many (unintended) global variables as possible
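
The new setup() loads the raw data on the master task only and broadcasts the complete dictionary once, instead of broadcasting once per frequency inside the loop as the old setup() did. A minimal mpi4py sketch of that pattern, with expensive_load as a placeholder for vlbi.combined_data:

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, nranks = comm.Get_rank(), comm.Get_size()
    master = rank == 0

    def expensive_load():
        # placeholder for vlbi.combined_data(...); only the master reads from disk
        return {'hi': [1, 2, 3], 'lo': [4, 5, 6]}

    rawd = expensive_load() if master else None
    if nranks > 1:
        rawd = comm.bcast(rawd)  # root defaults to 0, i.e. the master task
    print(rank, rawd)
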
@@ -180,12 +151,11 @@ def main():
     with open("time_averaging.txt", 'w') as f:
         # delete the file such that new lines can be appended
         f.write("min max avg med\n")
-    pos, lh_full, lh_amp, lh_ph, sky, ic, pre_output, ind_full, ind_amp, ind_ph = setup()
+    pos, sky, ic, pre_output, likelihoods= setup()
     for ii in range(60):
         gc.collect()
-        minimizer, N_samples, N_iterations, lh = optimization_heuristic(
-            ii, lh_full, lh_amp, lh_ph,ind_full, ind_amp, ind_ph)
+        minimizer, N_samples, N_iterations, lh = optimization_heuristic(ii, likelihoods)
         if master:
             print(f'Iter: {ii}, N_samples: {N_samples}, N_iter: {N_iterations}')