Skip to content
Snippets Groups Projects
Commit bcc66910 authored by Julian Rüstig's avatar Julian Rüstig
Browse files

multi: beam as relative field with no slicing

parent bc1c193a
Branches
No related tags found
1 merge request!50Draft: Mosaic imaging
import nifty8 as ift
import numpy as np
class SkyBeamer(ift.LinearOperator):
"""Maps from the total sky domain to individual sky domains and applies the
primary beam pattern.
Parameters
----------
domain : RGSpace
Two-dimensional RG-Space which serves as domain. The distances are
in pseudo-radian.
beam_directions : dict(key: beam)
Dictionary with beam patterns (in same RGSpace as domain)
"""
def __init__(self, domain, beam_directions):
self._bd = dict(beam_directions)
self._domain = ift.makeDomain(domain)
t, b = {}, {}
for kk, vv in self._bd.items():
print("\r" + kk, end="")
t[kk], b[kk] = self._domain, vv['beam']
assert t[kk] == b[kk].domain
print()
self._beams = b
self._target = ift.makeDomain(t)
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
if mode == self.TIMES:
res = {}
for kk in self._target.keys():
res[kk] = x * self._beams[kk].val
else:
res = np.zeros(self._domain.shape)
for kk, xx in x.items():
res += xx * self._beams[kk].val
return ift.makeField(self._tgt(mode), res)
......@@ -4,7 +4,7 @@ from astropy import units as u
fov = 25 * u.arcsec
npix = 512
npix = 128
distance = fov/npix
dvol = distance**2
......
from numpy.typing import ArrayLike
from functools import reduce
from resolve.library.primary_beams import alma_beam_func
import resolve as rve
import nifty8 as ift
from resolve.sky_model import sky_model_diffuse
import configparser
import numpy as np
from beamer import SkyBeamer
from os.path import join
from sys import exit
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from astropy import units as u
from astropy.coordinates import SkyCoord
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
master = comm.Get_rank() == 0
except ImportError:
comm = None
master = True
data_filenames = ['small/inter_data/m51c.ALMA_0.5arcsec.ms_fld00.npz',
'small/inter_data/m51c.ALMA_0.5arcsec.ms_fld01.npz',
'small/inter_data/m51c.ALMA_0.5arcsec.ms_fld02.npz']
def coords(shape: int, distance: float) -> ArrayLike:
'''Returns coordinates such that the edge of the array is
shape/2*distance'''
halfside = shape/2 * distance
return np.linspace(-halfside+distance/2, halfside-distance/2, shape)
all_obs = []
for file in data_filenames:
obs = rve.Observation.load(file)
obs = obs.to_double_precision()
obs = obs.average_stokesi() # FIXME: Needs to be adjusted for polarization
all_obs.append(obs)
cfg = configparser.ConfigParser()
cfg.read("multi_obs.cfg")
sky, sky_diffuse_operators = rve.sky_model_diffuse(cfg['sky'])
center_ra = cfg['sky']['image center ra']
center_dec = cfg['sky']['image center dec']
center_frame = cfg['sky']['image center frame']
npix = cfg['sky']['space npix x']
sdom = sky.target[3]
x_direction = coords(sdom.shape[0], sdom.distances[0])
y_direction = coords(sdom.shape[1], sdom.distances[1])
sky_coordinates = np.array(np.meshgrid(
x_direction, y_direction, indexing='xy'))
output_directory = f"output/combined_corrected_{npix}"
sky_center = SkyCoord(center_ra, center_dec, frame=center_frame)
beam_directions = {}
for fldid, oo in enumerate(all_obs):
# Calculate phase center
o_phase_center = SkyCoord(oo.direction.phase_center[0]*u.rad,
oo.direction.phase_center[1]*u.rad,
frame=center_frame)
r = sky_center.separation(o_phase_center)
phi = sky_center.position_angle(o_phase_center)
dy = r.to(u.rad).value * np.cos(phi.to(u.rad).value)
dx = r.to(u.rad).value * np.sin(phi.to(u.rad).value)
x = np.sqrt((sky_coordinates[0] - dx)**2 + (sky_coordinates[1] - dy)**2)
beam = alma_beam_func(D=12.0, d=0.75, freq=oo.freq.mean(), x=x)
beam = ift.makeField(sdom, beam)
# beam_operator = ift.DiagonalOperator(beam)
beam_direction = f'fld{fldid}'
beam_directions[beam_direction] = dict(
dx=dx,
dy=dy,
beam=beam
)
# Used for the dtypes
tmp_sky = sky(ift.from_random(sky.domain))
SKY_BEAMER = SkyBeamer(sky.target[3], beam_directions=beam_directions)
REDUCER = ift.JaxLinearOperator(
sky.target,
SKY_BEAMER.domain,
lambda x: x[0, 0, 0], # FIXME: How to do this on all polarizations??
domain_dtype=tmp_sky.dtype
)
def build_response(field_key, dx, dy, obs, sky_dtype):
RADIO_RESPONSE = rve.SingleResponse(
SKY_BEAMER.target[field_key],
obs.uvw,
obs.freq,
do_wgridding=False,
epsilon=1e-3,
# center of the dirty image relative to the phase_center
# (in projected radians)
center_x=dx,
center_y=dy,
)
FIELD_EXTRACTOR = ift.JaxLinearOperator(
SKY_BEAMER.target,
RADIO_RESPONSE.domain,
lambda x: x[field_key],
domain_dtype={k: sky_dtype for k, v in SKY_BEAMER.target.items()}
)
# FIXME: This is a hack for making stokes I work, see above
UPCAST_TO_STOKES_I = ift.JaxLinearOperator(
RADIO_RESPONSE.target,
obs.vis.domain,
lambda x: x[None],
domain_dtype=obs.vis.dtype
)
# FIXME: This response operator makes unnecessary multiple beam pattern
# multiplications. As the SKY_BEAMER calculates the beam pattern for all fields.
# Hence, we throw away all other fields.
return UPCAST_TO_STOKES_I @ RADIO_RESPONSE @ FIELD_EXTRACTOR @ SKY_BEAMER @ REDUCER
def build_energy(response, obs):
return rve.DiagonalGaussianLikelihood(
data=obs.vis,
inverse_covariance=obs.weight,
mask=obs.mask
) @ response
responses = []
energies = []
for kk, vv, obs in zip(beam_directions.keys(), beam_directions.values(), all_obs):
R = build_response(kk, vv['dx'], vv['dy'], obs, tmp_sky.dtype)
responses.append(R)
energies.append(build_energy(R, obs))
for ii in range(1):
rnd = ift.from_random(sky.domain)
f = sky(rnd)
tot_response_adjoint = np.zeros_like(f.val)
for rr, oo in zip(responses, all_obs):
tot_response_adjoint += rr.adjoint(oo.vis).val
plt.imshow(tot_response_adjoint[0, 0, 0], origin='lower')
plt.show()
# ff = (SKY_BEAMER @ REDUCER)(f).val
# fig, axes = plt.subplots(2, 2)
# ((a00, a01), (a10, a11)) = axes
# i00 = a00.imshow(f.val[0, 0, 0], origin='lower',
# vmin=f.val.min(), vmax=f.val.max())
# i01 = a01.imshow(tot_response_adjoint[0, 0, 0], origin='lower')
# i10 = a10.imshow(f.val[0, 0, 0, xi:xe, yi:ye], origin='lower',
# vmin=f.val.min(), vmax=f.val.max())
# i11 = a11.imshow(ff['fld0'], origin='lower',
# vmin=f.val.min(), vmax=f.val.max())
# plt.colorbar(i00, ax=a00)
# plt.colorbar(i01, ax=a01)
# plt.colorbar(i10, ax=a10)
# plt.colorbar(i11, ax=a11)
# a00.scatter([xi, xi, xe, xe], [yi, ye, yi, ye], c='orange')
# plt.show()
# lh = rve.ImagingLikelihood(obs, sky, 1e-7, False, nthreads=1)
lh = reduce(lambda x, y: x+y, energies)
lh = lh @ sky
def callback(samples, i):
sky_mean = samples.average(sky)
plt.imshow(sky_mean.val[0, 0, 0, :, :].T, origin="lower", norm=LogNorm())
plt.colorbar()
if master:
plt.savefig(f"{output_directory}/resovle_iteration_{i}.png")
plt.close()
ic_sampling_early = ift.AbsDeltaEnergyController(
name="Sampling (linear)", deltaE=0.05, iteration_limit=100
)
ic_sampling_late = ift.AbsDeltaEnergyController(
name="Sampling (linear)", deltaE=0.05, iteration_limit=500
)
ic_newton_early = ift.AbsDeltaEnergyController(
name="Newton", deltaE=0.5, convergence_level=2, iteration_limit=10
)
ic_newton_late = ift.AbsDeltaEnergyController(
name="Newton", deltaE=0.5, convergence_level=2, iteration_limit=30
)
minimizer_early = ift.NewtonCG(ic_newton_early)
minimizer_late = ift.NewtonCG(ic_newton_late)
n_iterations = 7
def ic_sampling(i): return ic_sampling_early if i < 15 else ic_sampling_late
def minimizer(i): return minimizer_early if i < 15 else minimizer_late
def n_samples(i): return 2 if i < 7 else 4
samples = ift.optimize_kl(
lh,
n_iterations,
n_samples,
minimizer,
ic_sampling,
None,
output_directory=output_directory,
comm=comm,
inspect_callback=callback,
export_operator_outputs=dict(
logdiffuse_stokesI=sky_diffuse_operators['logdiffuse stokesI']),
resume=True
)
sky_mean = samples.average(sky)
rve.ubik_tools.field2fits(sky_mean, join(
output_directory, f'sky_reso_{npix}.fits'))
......@@ -99,7 +99,11 @@ def build_response(field_key, obs, sky_dtype):
obs.uvw,
obs.freq,
do_wgridding=False,
epsilon=1e-3
epsilon=1e-3,
# center_x=center of the dirty image relative to the phase_center (in projected radians)
# center_x=, # FIXME: How to implement the offset shifts correctly.
# center_y=
)
FIELD_EXTRACTOR = ift.JaxLinearOperator(
......
[sky]
freq mode = single
polarization=I
space npix x = 512
space npix y = 512
space npix x = 128
space npix y = 128
space fov x = 25as
space fov y = 25as
image center ra = 13h37m00.75s
......
# https://casaguides.nrao.edu/index.php?title=Simalma_CASA_6.5.4
# Set simobserve to default parameters
default("simobserve")
# Our project name will be m51c, and all simulation products will be placed in a subdirectory m51c/
......
......@@ -271,7 +271,8 @@ def alma_beam_func(D, d, freq, x, use_cache=False):
if not use_cache:
return _compute_alma_beam(D, d, freq, x)
iden = "_".join([str(ll) for ll in [D, d, freq]] + [str(ll) for ll in x.shape])
iden = "_".join([str(ll) for ll in [D, d, freq]] + [str(ll)
for ll in x.shape])
fname = f".beamcache{iden}.npy"
try:
return np.load(fname)
......@@ -282,6 +283,15 @@ def alma_beam_func(D, d, freq, x, use_cache=False):
def _compute_alma_beam(D, d, freq, x):
"""Compute the theoretical primary beam pattern.
Parameters
----------
D = Dish diameter (in m)
d = blockage diameter (in m)
freq = frequency of the observation (in 1/s, Hz)
x = sin(theta) = angle from pointing on sky (theta in rad)
"""
import scipy.special as sc
a = freq / SPEEDOFLIGHT
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment