Skip to content
Snippets Groups Projects
Commit 1b98ae46 authored by Philipp Frank's avatar Philipp Frank
Browse files

small fixes

parent 7a7617ba
No related branches found
No related tags found
3 merge requests!4Revision,!3small fixes,!2refactor likelihood building
Pipeline #91901 passed
...@@ -25,9 +25,9 @@ from time import time ...@@ -25,9 +25,9 @@ from time import time
import nifty6 as ift import nifty6 as ift
import src as vlbi import src as vlbi
from config import comm, nranks, rank, master from config import comm, nranks, rank, master
from config import doms, dt, eps, min_timestamps_per_bin, npixt, nthreads from config import doms, eps, min_timestamps_per_bin, nthreads
from config import sky_movie_mf as sky from config import sky_movie_mf as sky
from config import startt, dt, npix from config import startt, dt, npixt
def stat_plotting(pos, KL): def stat_plotting(pos, KL):
if master: if master:
...@@ -79,8 +79,7 @@ def optimization_heuristic(ii, likelihoods): ...@@ -79,8 +79,7 @@ def optimization_heuristic(ii, likelihoods):
lh = lh_full_ph lh = lh_full_ph
return minimizer, N_samples, N_iterations, lh return minimizer, N_samples, N_iterations, lh
def build_likelihood(rawdd, startt, npixt, dt, mode):
def build_likelihood(rawdd, startt, npix, dt, mode):
lh = [] lh = []
active_inds = [] active_inds = []
for freq in vlbi.data.FREQS: for freq in vlbi.data.FREQS:
...@@ -103,13 +102,13 @@ def build_likelihood(rawdd, startt, npix, dt, mode): ...@@ -103,13 +102,13 @@ def build_likelihood(rawdd, startt, npix, dt, mode):
vis2closph, evalsph, _ = vlbi.Visibilities2ClosurePhases(dd) vis2closph, evalsph, _ = vlbi.Visibilities2ClosurePhases(dd)
llh.append(ift.GaussianEnergy(mean=vis2closph(vis)) @ vis2closph) llh.append(ift.GaussianEnergy(mean=vis2closph(vis)) @ vis2closph)
llh_op = reduce(add, llh) @ nfft.ducktape(ind) llh_op = reduce(add, llh) @ nfft.ducktape(ind)
ift.extra.check_jacobian_consistency(llh_op, ift.from_random(llh_op.domain), if mode == 'full' and freq == vlbi.data.FREQS[0]:
tol=1e-5, ntries=10) ift.extra.check_jacobian_consistency(llh_op, ift.from_random(llh_op.domain),
tol=1e-5, ntries=10)
lh.append(llh_op) lh.append(llh_op)
conv = vlbi.DomainTuple2MultiField(sky.target, active_inds) conv = vlbi.DomainTuple2MultiField(sky.target, active_inds)
lh= reduce(add, lh) @ conv return reduce(add, lh) @ conv
return lh
def setup(): def setup():
if len(sys.argv) != 3: if len(sys.argv) != 3:
...@@ -124,13 +123,13 @@ def setup(): ...@@ -124,13 +123,13 @@ def setup():
pre_output = pre_data pre_output = pre_data
lh_full = build_likelihood(rawd, startt, npix, dt,'full') lh_full = build_likelihood(rawd, startt, npixt, dt, 'full')
lh_full_ph = build_likelihood(rawd, startt, npix, dt,'ph') lh_full_ph = build_likelihood(rawd, startt, npixt, dt, 'ph')
lh_full_amp = build_likelihood(rawd, startt, npix, dt,'amp') lh_full_amp = build_likelihood(rawd, startt, npixt, dt, 'amp')
lh_cut = build_likelihood(rawd, startt, npix//2, dt,'full') lh_cut = build_likelihood(rawd, startt, npixt//2, dt, 'full')
lh_cut_ph = build_likelihood(rawd, startt, npix//2, dt,'ph') lh_cut_ph = build_likelihood(rawd, startt, npixt//2, dt, 'ph')
lh_cut_amp = build_likelihood(rawd, startt, npix//2, dt,'amp') lh_cut_amp = build_likelihood(rawd, startt, npixt//2, dt, 'amp')
pos = vlbi.load_hdf5(fname_input, sky.domain) if master else None pos = vlbi.load_hdf5(fname_input, sky.domain) if master else None
if nranks > 1: if nranks > 1:
...@@ -151,7 +150,7 @@ def main(): ...@@ -151,7 +150,7 @@ def main():
with open("time_averaging.txt", 'w') as f: with open("time_averaging.txt", 'w') as f:
# delete the file such that new lines can be appended # delete the file such that new lines can be appended
f.write("min max avg med\n") f.write("min max avg med\n")
pos, sky, ic, pre_output, likelihoods= setup() pos, sky, ic, pre_output, likelihoods = setup()
for ii in range(60): for ii in range(60):
gc.collect() gc.collect()
...@@ -159,7 +158,7 @@ def main(): ...@@ -159,7 +158,7 @@ def main():
if master: if master:
print(f'Iter: {ii}, N_samples: {N_samples}, N_iter: {N_iterations}') print(f'Iter: {ii}, N_samples: {N_samples}, N_iter: {N_iterations}')
ll = lh @ sky ll = lh @ sky
H = ift.StandardHamiltonian(ll, ic) H = ift.StandardHamiltonian(ll, ic)
KL = ift.MetricGaussianKL(pos, H, N_samples, comm=comm, mirror_samples=True) KL = ift.MetricGaussianKL(pos, H, N_samples, comm=comm, mirror_samples=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment