Skip to content
Snippets Groups Projects
Commit 9f96cfc8 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add command line interface

parent 6ca346b0
No related branches found
No related tags found
No related merge requests found
......@@ -15,31 +15,49 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("-j", type=int, default=1)
parser.add_argument("--use-cached", action="store_true")
parser.add_argument("ms")
parser.add_argument("--use-wgridding", action="store_true")
parser.add_argument(
"--data-column",
default="DATA",
help="Only active if a measurement set is read.",
)
parser.add_argument("--point", action="append", nargs=2)
parser.add_argument("ms", type=str)
parser.add_argument("xfov", type=str)
parser.add_argument("yfov", type=str)
parser.add_argument("xpix", type=int)
parser.add_argument("ypix", type=int)
parser.add_argument("diffusefluxlevel", type=float)
args = parser.parse_args()
rve.set_nthreads(args.j)
rve.set_wgridding(False)
if splitext(args.ms)[1] == ".npz":
obs = rve.Observation.load(args.ms)
else:
obs = rve.ms2observations(args.ms, "DATA", False, 0, "stokesiavg")[0]
obs = rve.ms2observations(args.ms, args.data_column, False, 0, "stokesiavg")[0]
rve.set_nthreads(args.j)
rve.set_wgridding(args.use_wgridding)
fov = np.array([rve.str2rad(args.xfov), rve.str2rad(args.yfov)])
npix = np.array([args.xpix, args.ypix])
rve.set_epsilon(1 / 10 / obs.max_snr())
fov = np.array([3, 1.5]) * rve.ARCMIN2RAD
npix = np.array([4096, 2048])
npix = np.array([4096, 2048]) / 4 # FIXME QUICK
ppos = []
for point in args.point:
ppos.append([rve.str2rad(point[0]), rve.str2rad(point[1])])
dom = ift.RGSpace(npix, fov / npix)
logsky = ift.SimpleCorrelatedField(
dom, 21, (1, 0.1), (5, 1), (1.2, 0.4), (0.2, 0.2), (-2, 0.5)
dom, args.diffusefluxlevel, (1, 0.1), (5, 1), (1.2, 0.4), (0.2, 0.2), (-2, 0.5)
)
diffuse = logsky.exp()
inserter = rve.PointInserter(dom, np.array([[0, 0], [0.7, -0.34]]) * rve.AS2RAD)
if len(ppos) > 0:
inserter = rve.PointInserter(dom, ppos)
points = ift.InverseGammaOperator(
inserter.domain, alpha=0.5, q=0.2 / dom.scalar_dvol
).ducktape("points")
points = inserter @ points
sky = points + diffuse
sky = diffuse + points
else:
sky = diffuse
npix = 2500
effuv = np.linalg.norm(obs.effective_uv().T, axis=1)
assert obs.nfreq == obs.npol == 1
......@@ -63,7 +81,9 @@ def main():
lh = rve.ImagingLikelihood(obs, points)
ham = ift.StandardHamiltonian(lh)
state = rve.MinimizationState(0.1 * ift.from_random(ham.domain), [])
mini = ift.NewtonCG(ift.GradientNormController(name="newton", iteration_limit=4))
mini = ift.NewtonCG(
ift.GradientNormController(name="newton", iteration_limit=4)
)
if args.use_cached and isfile("stage0"):
state = rve.MinimizationState.load("stage0")
else:
......@@ -78,9 +98,12 @@ def main():
)
ham = ift.StandardHamiltonian(lh)
state = rve.MinimizationState(
ift.MultiField.union([0.1 * ift.from_random(diffuse.domain), state.mean]), []
ift.MultiField.union([0.1 * ift.from_random(diffuse.domain), state.mean]),
[],
)
mini = ift.NewtonCG(
ift.GradientNormController(name="newton", iteration_limit=20)
)
mini = ift.NewtonCG(ift.GradientNormController(name="newton", iteration_limit=20))
if args.use_cached and isfile("stage1"):
state = rve.MinimizationState.load("stage1")
else:
......@@ -97,7 +120,8 @@ def main():
ham = ift.StandardHamiltonian(lh, ic)
cst = sky.domain.keys()
state = rve.MinimizationState(
ift.MultiField.union([0.1 * ift.from_random(weightop.domain), state.mean]), []
ift.MultiField.union([0.1 * ift.from_random(weightop.domain), state.mean]),
[],
)
mini = ift.VL_BFGS(ift.GradientNormController(name="bfgs", iteration_limit=20))
if args.use_cached and isfile("stage2"):
......@@ -134,9 +158,13 @@ def main():
ham = ift.StandardHamiltonian(lh, ic)
for ii in range(30):
if ii < 5:
mini = ift.VL_BFGS(ift.GradientNormController(name="newton", iteration_limit=15))
mini = ift.VL_BFGS(
ift.GradientNormController(name="newton", iteration_limit=15)
)
else:
mini = ift.NewtonCG(ift.GradientNormController(name="newton", iteration_limit=15))
mini = ift.NewtonCG(
ift.GradientNormController(name="newton", iteration_limit=15)
)
fname = f"stage4_{ii}"
if args.use_cached and isfile(fname):
state = rve.MinimizationState.load(fname)
......
......
......@@ -8,3 +8,33 @@ ARCMIN2RAD = np.pi / 60 / 180
AS2RAD = ARCMIN2RAD / 60
DEG2RAD = np.pi / 180
SPEEDOFLIGHT = 299792458
def str2rad(s):
"""Convert string of number and unit to radian.
Support the following units: muas mas as amin deg rad.
Parameters
----------
s : str
TODO
"""
c = {
"muas": AS2RAD * 1e-6,
"mas": AS2RAD * 1e-3,
"as": AS2RAD,
"amin": ARCMIN2RAD,
"deg": DEG2RAD,
"rad": 1,
}
keys = list(c.keys())
keys.sort(key=len)
for kk in reversed(keys):
nn = -len(kk)
unit = s[nn:]
print(unit, kk)
if unit == kk:
return float(s[:nn])*c[kk]
raise RuntimeError("Unit not understood")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment