Commit c233eba4 authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

Merge branch 'refactor_likelihood_pf' into 'refactor_likelihood'

small fixes

See merge request !3
parents 7a7617ba 1b98ae46
Pipeline #91934 passed with stages
in 17 minutes and 49 seconds
...@@ -25,9 +25,9 @@ from time import time ...@@ -25,9 +25,9 @@ from time import time
import nifty6 as ift import nifty6 as ift
import src as vlbi import src as vlbi
from config import comm, nranks, rank, master from config import comm, nranks, rank, master
from config import doms, dt, eps, min_timestamps_per_bin, npixt, nthreads from config import doms, eps, min_timestamps_per_bin, nthreads
from config import sky_movie_mf as sky from config import sky_movie_mf as sky
from config import startt, dt, npix from config import startt, dt, npixt
def stat_plotting(pos, KL): def stat_plotting(pos, KL):
if master: if master:
...@@ -79,8 +79,7 @@ def optimization_heuristic(ii, likelihoods): ...@@ -79,8 +79,7 @@ def optimization_heuristic(ii, likelihoods):
lh = lh_full_ph lh = lh_full_ph
return minimizer, N_samples, N_iterations, lh return minimizer, N_samples, N_iterations, lh
def build_likelihood(rawdd, startt, npixt, dt, mode):
def build_likelihood(rawdd, startt, npix, dt, mode):
lh = [] lh = []
active_inds = [] active_inds = []
for freq in vlbi.data.FREQS: for freq in vlbi.data.FREQS:
...@@ -103,13 +102,13 @@ def build_likelihood(rawdd, startt, npix, dt, mode): ...@@ -103,13 +102,13 @@ def build_likelihood(rawdd, startt, npix, dt, mode):
vis2closph, evalsph, _ = vlbi.Visibilities2ClosurePhases(dd) vis2closph, evalsph, _ = vlbi.Visibilities2ClosurePhases(dd)
llh.append(ift.GaussianEnergy(mean=vis2closph(vis)) @ vis2closph) llh.append(ift.GaussianEnergy(mean=vis2closph(vis)) @ vis2closph)
llh_op = reduce(add, llh) @ nfft.ducktape(ind) llh_op = reduce(add, llh) @ nfft.ducktape(ind)
ift.extra.check_jacobian_consistency(llh_op, ift.from_random(llh_op.domain), if mode == 'full' and freq == vlbi.data.FREQS[0]:
tol=1e-5, ntries=10) ift.extra.check_jacobian_consistency(llh_op, ift.from_random(llh_op.domain),
tol=1e-5, ntries=10)
lh.append(llh_op) lh.append(llh_op)
conv = vlbi.DomainTuple2MultiField(sky.target, active_inds) conv = vlbi.DomainTuple2MultiField(sky.target, active_inds)
lh= reduce(add, lh) @ conv return reduce(add, lh) @ conv
return lh
def setup(): def setup():
if len(sys.argv) != 3: if len(sys.argv) != 3:
...@@ -124,13 +123,13 @@ def setup(): ...@@ -124,13 +123,13 @@ def setup():
pre_output = pre_data pre_output = pre_data
lh_full = build_likelihood(rawd, startt, npix, dt,'full') lh_full = build_likelihood(rawd, startt, npixt, dt, 'full')
lh_full_ph = build_likelihood(rawd, startt, npix, dt,'ph') lh_full_ph = build_likelihood(rawd, startt, npixt, dt, 'ph')
lh_full_amp = build_likelihood(rawd, startt, npix, dt,'amp') lh_full_amp = build_likelihood(rawd, startt, npixt, dt, 'amp')
lh_cut = build_likelihood(rawd, startt, npix//2, dt,'full') lh_cut = build_likelihood(rawd, startt, npixt//2, dt, 'full')
lh_cut_ph = build_likelihood(rawd, startt, npix//2, dt,'ph') lh_cut_ph = build_likelihood(rawd, startt, npixt//2, dt, 'ph')
lh_cut_amp = build_likelihood(rawd, startt, npix//2, dt,'amp') lh_cut_amp = build_likelihood(rawd, startt, npixt//2, dt, 'amp')
pos = vlbi.load_hdf5(fname_input, sky.domain) if master else None pos = vlbi.load_hdf5(fname_input, sky.domain) if master else None
if nranks > 1: if nranks > 1:
...@@ -151,7 +150,7 @@ def main(): ...@@ -151,7 +150,7 @@ def main():
with open("time_averaging.txt", 'w') as f: with open("time_averaging.txt", 'w') as f:
# delete the file such that new lines can be appended # delete the file such that new lines can be appended
f.write("min max avg med\n") f.write("min max avg med\n")
pos, sky, ic, pre_output, likelihoods= setup() pos, sky, ic, pre_output, likelihoods = setup()
for ii in range(60): for ii in range(60):
gc.collect() gc.collect()
...@@ -159,7 +158,7 @@ def main(): ...@@ -159,7 +158,7 @@ def main():
if master: if master:
print(f'Iter: {ii}, N_samples: {N_samples}, N_iter: {N_iterations}') print(f'Iter: {ii}, N_samples: {N_samples}, N_iter: {N_iterations}')
ll = lh @ sky ll = lh @ sky
H = ift.StandardHamiltonian(ll, ic) H = ift.StandardHamiltonian(ll, ic)
KL = ift.MetricGaussianKL(pos, H, N_samples, comm=comm, mirror_samples=True) KL = ift.MetricGaussianKL(pos, H, N_samples, comm=comm, mirror_samples=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment