diff --git a/resolve/__init__.py b/resolve/__init__.py index fadc803f26d1ea7207a44529460770d600f58865..d69bc4b0d0fee51d513924dd75a0ccf391164df4 100644 --- a/resolve/__init__.py +++ b/resolve/__init__.py @@ -1,9 +1,12 @@ +from .antenna_positions import AntennaPositions from .calibration import calibration_distribution from .constants import * -from .fits import field2fits +from .direction import * +from .fits import field2fits, fits2field from .global_config import * from .likelihood import * from .minimization import Minimization, MinimizationState, simple_minimize +from .mosaicing import * from .mpi import onlymaster from .mpi_operators import * from .ms_import import ms2observations, ms_n_spectral_windows @@ -13,11 +16,11 @@ from .multi_frequency.operators import ( MfWeightingInterpolation, WienerIntegrations, ) -from .observation import Observation, tmin_tmax, unique_antennas, unique_times +from .observation import * from .plotter import MfPlotter, Plotter from .points import PointInserter -from .polarization import polarization_matrix_exponential -from .primary_beam import vla_beam +from .polarization import Polarization, polarization_matrix_exponential +from .primary_beam import * from .response import MfResponse, ResponseDistributor, StokesIResponse, SingleResponse from .simple_operators import * from .util import ( diff --git a/resolve/direction.py b/resolve/direction.py index cc1480548d84d21bd06a9eb8586424959f7ae078..5d2acdad4fd64b8558a448caeddfa4f9a1e8de08 100644 --- a/resolve/direction.py +++ b/resolve/direction.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: GPL-3.0-or-later -# Copyright(C) 2020 Max-Planck-Society +# Copyright(C) 2020-2021 Max-Planck-Society # Author: Philipp Arras from .util import compare_attributes, my_asserteq @@ -42,3 +42,40 @@ class Direction: if not isinstance(other, Direction): return False return compare_attributes(self, other, ("_pc", "_e")) + + +class Directions: + def __init__(self, phase_centers, equinox): + assert phase_centers.ndim == 2 + assert phase_centers.shape[1] == 2 + self._pc = phase_centers + self._e = float(equinox) + + @property + def phase_centers(self): + return self._pc + + @property + def equinox(self): + return self._e + + def __repr__(self): + return f"Directions({self._pc}, equinox={self._e})" + + def to_list(self): + return [self._pc, self._e] + + def __len__(self): + return self._pc.shape[0] + + @staticmethod + def from_list(lst): + return Directions(lst[0], lst[1]) + + def __eq__(self, other): + if not isinstance(other, Direction): + return False + return compare_attributes(self, other, ("_pc", "_e")) + + def __getitem__(self, slc): + return Directions(self._pc[slc], self._e) diff --git a/resolve/fits.py b/resolve/fits.py index 4f2ecc6d3abc686c0adfb23da8def5932bf3377f..9370015924285f4f31d1f776bf22a9586591c077 100644 --- a/resolve/fits.py +++ b/resolve/fits.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: GPL-3.0-or-later -# Copyright(C) 2019-2020 Max-Planck-Society +# Copyright(C) 2019-2021 Max-Planck-Society # Author: Philipp Arras import time @@ -7,6 +7,9 @@ from os.path import splitext import numpy as np +import nifty7 as ift + +from .constants import DEG2RAD from .mpi import onlymaster @@ -41,35 +44,68 @@ def field2fits(field, file_name, overwrite, direction=None): base, ext = splitext(file_name) hdulist.writeto(base + ext, overwrite=overwrite) - # @staticmethod - # def make_from_file(file_name): - # with pyfits.open(file_name) as hdu_list: - # lst = hdu_list[0] - # pcx = lst.header['CRVAL1']/180*np.pi - # pcy = lst.header['CRVAL2']/180*np.pi - # equ = lst.header['EQUINOX'] - # return FitsWriter([pcx, pcy], equ) - # @staticmethod - # def fits2field(file_name, ignore_units=False, from_wsclean=False): - # with pyfits.open(file_name) as hdu_list: - # image_data = np.squeeze(hdu_list[0].data).astype(np.float64) - # head = hdu_list[0].header - # dstx = abs(head['CDELT1']*np.pi/180) - # dsty = abs(head['CDELT2']*np.pi/180) - # if not ignore_units: - # if head['BUNIT'] == 'JY/BEAM': - # fac = np.pi/4/np.log(2) - # scale = fac*head['BMAJ']*head['BMIN']*(np.pi/180)**2 - # elif head['BUNIT'] == 'JY/PIXEL': - # scale = dstx*dsty - # else: - # scale = 1 - # image_data /= scale - # if from_wsclean: - # image_data = image_data[::-1].T[:, :-1] - # image_data = np.pad(image_data, ((0, 0), (1, 0)), mode='constant') - # else: - # image_data = image_data.T[:, ::-1] - # dom = ift.RGSpace(image_data.shape, (dstx, dsty)) - # return ift.makeField(dom, image_data) +def fits2field(file_name, ignore_units=False, from_wsclean=False): + import astropy.io.fits as pyfits + + with pyfits.open(file_name) as hdu_list: + image_data = hdu_list[0].data.astype(np.float64) + assert image_data.shape[0] == 1 # Only one Stokes component + image_data = image_data[0] + head = hdu_list[0].header + assert head["CUNIT1"].strip() == "deg" + assert head["CUNIT2"].strip() == "deg" + assert head["CUNIT3"].strip() == "Hz" + refs = [] + refs.append([float(head["CRVAL3"]), int(head["CRPIX3"]), head["CDELT3"]]) + refs.append( + [ + float(head["CRVAL2"]) * DEG2RAD, + int(head["CRPIX2"]), + head["CDELT2"] * DEG2RAD, + ] + ) + refs.append( + [ + float(head["CRVAL1"]) * DEG2RAD, + int(head["CRPIX1"]), + head["CDELT1"] * DEG2RAD, + ] + ) + + if not ignore_units: + if head["BUNIT"].upper() == "JY/BEAM": + fac = np.pi / 4 / np.log(2) + scale = fac * head["BMAJ"] * head["BMIN"] * (np.pi / 180) ** 2 + elif head["BUNIT"].upper() == "JY/PIXEL": + scale = abs(refs[0][2] * refs[1][2]) + else: + scale = 1 + image_data /= scale + + # Convert CASA conventions to resolve conventions + inds = 0, 2, 1 + image_data = np.transpose(image_data, inds) + refs = [refs[ii] for ii in inds] + refs[1][2] *= -1 + + for ii, (_, mypx, mydst) in enumerate(refs): + if mydst == 0.0: + raise RuntimeError + if mydst > 0: + continue + image_data = np.flip(image_data, ii) + refs[ii][2] *= -1 + # FIXME Assume pixel counting start at 0. Maybe also 1? + refs[ii][1] = image_data.shape[ii] - mypx + + refval = tuple(refs[ii][0] for ii in range(3)) + refpx = tuple(refs[ii][1] for ii in range(3)) + dsts = tuple(refs[ii][2] for ii in range(3)) + dom = ( + ift.RGSpace(image_data.shape[0], dsts[0]), + ift.RGSpace(image_data.shape[1:], dsts[1:]), + ) + refval = tuple(refval[ii] - refpx[ii] * dsts[ii] for ii in range(3)) + del refpx + return ift.makeField(dom, image_data), refval diff --git a/resolve/likelihood.py b/resolve/likelihood.py index 44c3ad5e1b823a372e88545177bd7fc62f9a0f02..49d1b4b14ec106f3ee1ecec1dd369a21e45c9d34 100644 --- a/resolve/likelihood.py +++ b/resolve/likelihood.py @@ -2,19 +2,21 @@ # Copyright(C) 2020 Max-Planck-Society # Author: Philipp Arras +from functools import reduce +from operator import add + import numpy as np import nifty7 as ift from .observation import Observation from .response import FullPolResponse, MfResponse, StokesIResponse -from .util import my_assert_isinstance, my_asserteq, my_assert +from .util import my_assert, my_assert_isinstance, my_asserteq def _get_mask(observation): # Only needed for variable covariance gaussian energy my_assert_isinstance(observation, Observation) - vis = observation.vis flags = observation.flags if not np.any(flags): @@ -23,6 +25,21 @@ def _get_mask(observation): return mask, mask(vis), mask(observation.weight) +def get_mask_multi_field(weight): + assert isinstance(weight, ift.MultiField) + op = [] + for kk, ww in weight.items(): + flags = ww.val == 0.0 + if np.any(flags): + myop = ift.MaskOperator(ift.makeField(ww.domain, flags)) + else: + myop = ift.ScalingOperator(ww.domain, 1.0) + op.append(myop.ducktape(kk).ducktape_left(kk)) + op = reduce(add, op) + assert op.domain == weight.domain + return op + + def _Likelihood(operator, normalized_residual_operator): my_assert_isinstance(operator, ift.Operator) my_asserteq(operator.target, ift.DomainTuple.scalar_domain()) @@ -91,7 +108,7 @@ def ImagingLikelihood( Parameters ---------- - observation : Observation + observation : Observation or dict(Observation) Observation object from which observation.vis and potentially observation.weight is used for computing the likelihood. @@ -109,11 +126,12 @@ def ImagingLikelihood( calibration_operator : Operator Optional. Target needs to be the same as observation.vis. """ - my_assert_isinstance(observation, Observation) my_assert_isinstance(sky_operator, ift.Operator) sdom = sky_operator.target - if isinstance(sdom, ift.MultiDomain): + mosaicing = isinstance(observation, dict) + + if isinstance(sdom, ift.MultiDomain) and not mosaicing: if len(sdom["I"].shape) == 3: raise NotImplementedError( "Polarization and multi-frequency at the same time not supported yet." @@ -121,14 +139,20 @@ def ImagingLikelihood( else: R = FullPolResponse(observation, sky_operator.target) else: - if len(sdom.shape) == 3: + if not mosaicing and len(sdom.shape) == 3: R = MfResponse(observation, sdom[0], sdom[1]) else: R = StokesIResponse(observation, sdom) model_data = R @ sky_operator - if inverse_covariance_operator is None: - return _build_gauss_lh_nres(model_data, observation.vis, observation.weight) + if mosaicing: + vis = ift.MultiField.from_dict({kk: o.vis for kk, o in observation.items()}) + weight = ift.MultiField.from_dict( + {kk: o.weight for kk, o in observation.items()} + ) + else: + vis, weight = observation.vis, observation.weight + return _build_gauss_lh_nres(model_data, vis, weight) return _varcov(observation, model_data, inverse_covariance_operator) diff --git a/resolve/mosaicing/__init__.py b/resolve/mosaicing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a346d91e3cfe4b1d84394ca966f6ddc700c8f647 --- /dev/null +++ b/resolve/mosaicing/__init__.py @@ -0,0 +1,2 @@ +from .sky_slicer import * +from .single_dish_response import * diff --git a/resolve/mosaicing/single_dish_response.py b/resolve/mosaicing/single_dish_response.py new file mode 100644 index 0000000000000000000000000000000000000000..1d584882c42ffa2010a0fc8b694394c371b1d871 --- /dev/null +++ b/resolve/mosaicing/single_dish_response.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# Copyright(C) 2021 Max-Planck-Society +# Author: Philipp Arras + +import numpy as np + +import nifty7 as ift + +from ..observation import SingleDishObservation + + +def SingleDishResponse( + observation, domain, beam_function, global_phase_center, additive_term=None +): + assert isinstance(observation, SingleDishObservation) + domain = ift.makeDomain(domain) + assert len(domain) == 1 + codomain = domain[0].get_default_codomain() + kernel = codomain.get_conv_kernel_from_func(beam_function) + HT = ift.HartleyOperator(domain, codomain) + conv = HT.inverse @ ift.makeOp(kernel) @ HT.scale(domain.total_volume()) + # FIXME Move into tests + fld = ift.from_random(conv.domain) + ift.extra.assert_allclose(conv(fld).integrate(), fld.integrate()) + + pc = observation.pointings.phase_centers.T - np.array(global_phase_center)[:, None] + pc = pc + (np.array(domain.shape) * np.array(domain[0].distances) / 2)[:, None] + # Convention: pointing convention (see also BeamDirection) + pc[0] *= -1 + interp = ift.LinearInterpolator(domain, pc) + bc = ift.ContractionOperator(observation.vis.domain, (0, 2)).adjoint + # NOTE The volume factor above `domain.total_volume()` and the volume factor + # below `domain[0].scalar_dvol` cancel each other. They are left in the + # code such that the convolution leaves the integral invariant. + + convsky = conv.scale(domain[0].scalar_dvol).ducktape("sky") + if additive_term is not None: + convsky = convsky + additive_term + return bc @ interp @ convsky diff --git a/resolve/mosaicing/sky_slicer.py b/resolve/mosaicing/sky_slicer.py new file mode 100644 index 0000000000000000000000000000000000000000..fb951280f683503d0ee01be9e8cf4ee9451169b5 --- /dev/null +++ b/resolve/mosaicing/sky_slicer.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# Copyright(C) 2020 Max-Planck-Society +# Author: Philipp Arras + +import numpy as np + +import nifty7 as ift + + +class SkySlicer(ift.LinearOperator): + """Maps from the total sky domain to individual sky domains and applies the + primary beam pattern. + + Parameters + ---------- + domain : RGSpace + Two-dimensional RG-Space which serves as domain. The distances are + in pseudo-radian. + + beam_directions : dict(key: BeamDirection) + Dictionary of BeamDirection that contains the information of the + directions of the different observations and the beam pattern. + """ + + def __init__(self, domain, beam_directions): + self._bd = dict(beam_directions) + self._domain = ift.makeDomain(domain) + t, b, s = {}, {}, {} + for kk, vv in self._bd.items(): + print("\r" + kk, end="") + t[kk], s[kk], b[kk] = vv.slice_target(self._domain) + print() + self._beams = b + self._slices = s + self._target = ift.makeDomain(t) + self._capability = self.TIMES | self.ADJOINT_TIMES + + def apply(self, x, mode): + self._check_input(x, mode) + x = x.val + if mode == self.TIMES: + res = {} + for kk in self._target.keys(): + res[kk] = x[self._slices[kk]] * self._beams[kk].val + else: + res = np.zeros(self._domain.shape) + for kk, xx in x.items(): + res[self._slices[kk]] += xx * self._beams[kk].val + return ift.makeField(self._tgt(mode), res) + + +class BeamDirection: + """Represent direction information of one pointing. + + Parameters + ---------- + dx : float + Pointing offset in pseudo radian. + dy : float + Pointing offset in pseudo radian. + beam_func : function + Function that takes the two directions on the sky in pseudo radian and + returns beam pattern. + cutoff : float + Relative area of beam pattern that is cut off. 0 corresponds to no + cutoff at all. 1 corresponds to cut off everything. + """ + + def __init__(self, dx, dy, beam_func, cutoff): + self._dx, self._dy = float(dx), float(dy) + self._f = beam_func + self._cutoff = float(cutoff) + assert 0. <= cutoff < 1 + + def slice_target(self, domain): + """ + Parameters + ---------- + domain : RGSpace + Total sky domain. + """ + dom = ift.makeDomain(domain)[0] + dst = np.array(dom.distances) + assert abs(self._dx) < dst[0] * dom.shape[0] + assert abs(self._dy) < dst[1] * dom.shape[1] + npix = np.array(domain.shape) + + xs = np.linspace(0, max(npix * dst), num=4*max(npix)) + # Technically not 100% correct since we integrate circles here. + # The actual cutoff is lower than the one that is induced by the + # following computation. + ys = self._f(xs)*xs + cond = np.cumsum(ys)/np.sum(ys) + np.testing.assert_allclose(cond[-1], 1.) + ind = np.searchsorted(cond, 1.-self._cutoff) + fov = 2*xs[ind] + patch_npix = fov/dst + patch_npix = (np.ceil(patch_npix/2)*2).astype(int) + + # Convention: pointing convention (see also SingleDishResponse) + shift = np.array([-self._dx, self._dy]) + patch_center_unrounded = np.array(npix) / 2 + shift / dst + ipix_patch_center = np.round(patch_center_unrounded).astype(int) + # Or maybe use ceil? + assert np.all(patch_npix % 2 == 0) + mi = (ipix_patch_center - patch_npix / 2).astype(int) + ma = mi + patch_npix + + assert np.all(mi >= 0) + assert np.all(ma < npix) + slc = slice(mi[0], ma[0]), slice(mi[1], ma[1]) + tgt = ift.RGSpace(patch_npix, dst) + assert tgt.shape[0] % 2 == 0 + assert tgt.shape[1] % 2 == 0 + test = np.empty(dom.shape) + assert test[slc].shape == tgt.shape + + # Create coordinate field + mgrid = np.mgrid[: patch_npix[0], : patch_npix[1]].astype(float) + mgrid -= patch_npix[..., None, None] / 2 + # FIXME Add subpixel offset + # subpixel_offset = ipix_patch_center - patch_center_unrounded + # assert np.all(subpixel_offset < 1.0) + # assert np.all(subpixel_offset >= 0.0) + mgrid *= dst[..., None, None] + beam = ift.makeField(tgt, self._f(np.linalg.norm(mgrid, axis=0))) + return tgt, slc, beam diff --git a/resolve/observation.py b/resolve/observation.py index 69735a556c9ccc6ab33efe716f6bb4192b86abe4..3d2a0a28c098d1fed1493cec155a2d40900cc316 100644 --- a/resolve/observation.py +++ b/resolve/observation.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: GPL-3.0-or-later -# Copyright(C) 2019-2020 Max-Planck-Society +# Copyright(C) 2019-2021 Max-Planck-Society # Author: Philipp Arras import numpy as np @@ -8,13 +8,138 @@ import nifty7 as ift from .antenna_positions import AntennaPositions from .constants import SPEEDOFLIGHT -from .direction import Direction +from .direction import Direction, Directions from .mpi import onlymaster from .polarization import Polarization from .util import compare_attributes, my_assert, my_assert_isinstance, my_asserteq -class Observation: +class _Observation: + @property + def vis(self): + dom = [ift.UnstructuredDomain(ss) for ss in self._vis.shape] + return ift.makeField(dom, self._vis) + + @property + def weight(self): + dom = [ift.UnstructuredDomain(ss) for ss in self._weight.shape] + return ift.makeField(dom, self._weight) + + @property + def freq(self): + return self._freq + + @property + def polarization(self): + return self._polarization + + @property + def direction(self): + return self._direction + + @property + def npol(self): + return self._vis.shape[0] + + @property + def nrow(self): + return self._vis.shape[1] + + @property + def nfreq(self): + return self._vis.shape[2] + + def apply_flags(self, arr): + return arr[self._weight != 0.0] + + @property + def flags(self): + return self._weight == 0.0 + + @property + def mask(self): + return self._weight > 0.0 + + def max_snr(self): + return np.max(np.abs(self.apply_flags(self._vis * np.sqrt(self._weight)))) + + def fraction_useful(self): + return self.apply_flags(self._weight).size / self._weight.size + + +class SingleDishObservation(_Observation): + def __init__(self, pointings, data, weight, polarization, freq): + my_assert_isinstance(pointings, Directions) + my_assert_isinstance(polarization, Polarization) + my_assert(data.dtype in [np.float32, np.float64]) + nrows = len(pointings) + my_asserteq(weight.shape, data.shape) + my_asserteq(data.shape, (len(polarization), nrows, len(freq))) + my_asserteq(nrows, data.shape[1]) + + data.flags.writeable = False + weight.flags.writeable = False + + my_assert(np.all(weight >= 0.0)) + my_assert(np.all(np.isfinite(data))) + my_assert(np.all(np.isfinite(weight))) + + self._pointings = pointings + self._vis = data + self._weight = weight + self._polarization = polarization + self._freq = freq + + @onlymaster + def save(self, file_name, compress): + p = self._pointings.to_list() + dct = dict( + vis=self._vis, + weight=self._weight, + freq=self._freq, + polarization=self._polarization.to_list(), + pointings0=p[0], + pointings1=p[1], + ) + f = np.savez_compressed if compress else np.savez + f(file_name, **dct) + + @staticmethod + def load(file_name): + dct = dict(np.load(file_name)) + pol = Polarization.from_list(dct["polarization"]) + pointings = Directions.from_list([dct["pointings0"], dct["pointings1"]]) + return SingleDishObservation( + pointings, dct["vis"], dct["weight"], pol, dct["freq"] + ) + + def __eq__(self, other): + if not isinstance(other, Observation): + return False + if ( + self._vis.dtype != other._vis.dtype + or self._weight.dtype != other._weight.dtype + ): + return False + return compare_attributes( + self, other, ("_polarization", "_freq", "_pointings", "_vis", "_weight") + ) + + def __getitem__(self, slc): + return SingleDishObservation( + self._pointings[slc], + self._vis[:, slc], + self._weight[:, slc], + self._polarization, + self._freq, + ) + + @property + def pointings(self): + return self._pointings + + +class Observation(_Observation): """Observation data This class contains all the data and information about an observation. @@ -48,6 +173,8 @@ class Observation: my_asserteq(vis.shape, (len(polarization), nrows, len(freq))) my_asserteq(nrows, vis.shape[1]) my_assert(np.all(weight >= 0.0)) + my_assert(np.all(np.isfinite(vis))) + my_assert(np.all(np.isfinite(weight))) vis.flags.writeable = False weight.flags.writeable = False @@ -59,23 +186,6 @@ class Observation: self._freq = freq self._direction = direction - def apply_flags(self, arr): - return arr[self._weight != 0.0] - - @property - def flags(self): - return self._weight == 0.0 - - @property - def mask(self): - return self._weight > 0.0 - - def max_snr(self): - return np.max(np.abs(self.apply_flags(self._vis * np.sqrt(self._weight)))) - - def fraction_useful(self): - return self.apply_flags(self._weight).size / self._weight.size - @onlymaster def save(self, file_name, compress): dct = dict( @@ -170,40 +280,6 @@ class Observation: uvlen = np.linalg.norm(self.uvw, axis=1) return np.outer(uvlen, self._freq / SPEEDOFLIGHT) - @property - def vis(self): - dom = [ift.UnstructuredDomain(ss) for ss in self._vis.shape] - return ift.makeField(dom, self._vis) - - @property - def weight(self): - dom = [ift.UnstructuredDomain(ss) for ss in self._weight.shape] - return ift.makeField(dom, self._weight) - - @property - def freq(self): - return self._freq - - @property - def polarization(self): - return self._polarization - - @property - def direction(self): - return self._direction - - @property - def npol(self): - return self._vis.shape[0] - - @property - def nrow(self): - return self._vis.shape[1] - - @property - def nfreq(self): - return self._vis.shape[2] - def tmin_tmax(*args): """ diff --git a/resolve/primary_beam.py b/resolve/primary_beam.py index 71f8ebaa7979974be7d24d9265883e85faf26281..e13e7738351aeac465b846d559259432762e4d55 100644 --- a/resolve/primary_beam.py +++ b/resolve/primary_beam.py @@ -3,10 +3,11 @@ # Author: Philipp Arras import numpy as np +import scipy.special as sc import nifty7 as ift -from .constants import ARCMIN2RAD +from .constants import ARCMIN2RAD, SPEEDOFLIGHT from .util import my_assert @@ -176,3 +177,31 @@ def vla_beam(domain, freq): beam = rweight * upper + (1 - rweight) * lower beam[beam < 0] = 0 return ift.makeOp(ift.makeField(dom, beam)) + + +def alma_beam_func(D, d, freq, x, use_cache=False): + assert isinstance(x, np.ndarray) + assert x.ndim < 3 + assert np.max(np.abs(x)) < np.pi / np.sqrt(2) + + if not use_cache: + return _compute_alma_beam(D, d, freq, x) + + iden = "_".join([str(ll) for ll in [D, d, freq]] + [str(ll) for ll in x.shape]) + fname = f".beamcache{iden}.npy" + try: + return np.load(fname) + except FileNotFoundError: + arr = _compute_alma_beam(D, d, freq, x) + np.save(fname, arr) + + +def _compute_alma_beam(D, d, freq, x): + a = freq / SPEEDOFLIGHT + b = d / D + x = np.pi * a * D * x + mask = x == 0.0 + x[mask] = 1 + sol = 2 / (x * (1 - b ** 2)) * (sc.jn(1, x) - b * sc.jn(1, x * b)) + sol[mask] = 1 + return sol * sol diff --git a/resolve/response.py b/resolve/response.py index 89dbe37499bea61dc3ca1c2c314fa6f4d36da032..43fade8b86d0cd009aaaac3b7769482d93619f2b 100644 --- a/resolve/response.py +++ b/resolve/response.py @@ -2,11 +2,15 @@ # Copyright(C) 2019-2020 Max-Planck-Society # Author: Philipp Arras +from functools import reduce +from operator import add + import numpy as np from ducc0.wgridder.experimental import dirty2vis, vis2dirty import nifty7 as ift +from .constants import SPEEDOFLIGHT from .global_config import epsilon, nthreads, wgridding from .multi_frequency.irg_space import IRGSpace from .observation import Observation @@ -14,6 +18,14 @@ from .util import my_assert, my_assert_isinstance, my_asserteq def StokesIResponse(observation, domain): + if isinstance(observation, dict): + # TODO Possibly add subpixel offset here + d = ift.MultiDomain.make(domain) + res = ( + StokesIResponse(o, d[kk]).ducktape(kk).ducktape_left(kk) + for kk, o in observation.items() + ) + return reduce(add, res) my_assert_isinstance(observation, Observation) domain = ift.DomainTuple.make(domain) my_asserteq(len(domain), 1) @@ -227,6 +239,7 @@ class SingleResponse(ift.LinearOperator): self._target_dtype = np.complex64 if single_precision else np.complex128 self._domain_dtype = np.float32 if single_precision else np.float64 self._verbt, self._verbadj = verbose, verbose + self._ofac = None def apply(self, x, mode): self._check_input(x, mode) @@ -239,6 +252,9 @@ class SingleResponse(ift.LinearOperator): args1 = {"dirty": x} if self._verbt: args1["verbosity"] = True + print( + f"\nINFO: Oversampling factors in response: {self.oversampling_factors()}\n" + ) self._verbt = False f = dirty2vis # FIXME Use vis_out keyword of wgridder @@ -252,6 +268,9 @@ class SingleResponse(ift.LinearOperator): } if self._verbadj: args1["verbosity"] = True + print( + f"\nINFO: Oversampling factors in response: {self.oversampling_factors()}\n" + ) self._verbadj = False f = vis2dirty res = ift.makeField(self._tgt(mode), f(**self._args, **args1) * self._vol) @@ -259,3 +278,16 @@ class SingleResponse(ift.LinearOperator): res.dtype, self._target_dtype if mode == self.TIMES else self._domain_dtype ) return res + + def oversampling_factors(self): + if self._ofac is not None: + return self._ofac + maxuv = ( + np.max(np.abs(self._args["uvw"][:, 0:2]), axis=0) + * np.max(self._args["freq"]) + / SPEEDOFLIGHT + ) + hspace = self._domain[0].get_default_codomain() + hvol = np.array(hspace.shape) * np.array(hspace.distances) / 2 + self._ofac = hvol / maxuv + return self._ofac diff --git a/resolve/simple_operators.py b/resolve/simple_operators.py index 7a900e9f0ac283ff5c3af6a21e9b771a35ff8f41..ad834f4c7b5b8b1fbc07ef09416bc88e2bcfe86c 100644 --- a/resolve/simple_operators.py +++ b/resolve/simple_operators.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: GPL-3.0-or-later -# Copyright(C) 2019-2020 Max-Planck-Society +# Copyright(C) 2019-2021 Max-Planck-Society # Author: Philipp Arras +from functools import reduce +from operator import add + import numpy as np import nifty7 as ift -from .util import my_assert_isinstance, my_asserteq +from .util import my_assert, my_assert_isinstance, my_asserteq class AddEmptyDimension(ift.LinearOperator): @@ -64,3 +67,49 @@ class AddEmptyDimensionAtEnd(ift.LinearOperator): else: x = x.val[..., 0] return ift.makeField(self._tgt(mode), x) + + +class KeyPrefixer(ift.LinearOperator): + def __init__(self, domain, prefix): + self._domain = ift.MultiDomain.make(domain) + self._target = ift.MultiDomain.make( + {prefix + kk: vv for kk, vv in self._domain.items()} + ) + self._capability = self.TIMES | self.ADJOINT_TIMES + self._prefix = prefix + + def apply(self, x, mode): + self._check_input(x, mode) + if mode == self.TIMES: + res = {self._prefix + kk: vv for kk, vv in x.items()} + else: + res = {kk[len(self._prefix) :]: vv for kk, vv in x.items()} + return ift.MultiField.from_dict(res) + + def __repr__(self): + return f"{self.domain.keys()} -> {self.target.keys()}" + + +def MultiDomainVariableCovarianceGaussianEnergy(data, signal_response, invcov): + from .likelihood import get_mask_multi_field + + my_asserteq(data.domain, signal_response.target, invcov.target) + my_assert_isinstance(data.domain, ift.MultiDomain) + my_assert_isinstance(signal_response.domain, ift.MultiDomain) + my_assert(ift.is_operator(invcov)) + my_assert(ift.is_operator(signal_response)) + res = [] + invcovfld = invcov(ift.full(invcov.domain, 1.0)) + mask = get_mask_multi_field(invcovfld) + data = mask(data) + signal_response = mask @ signal_response + invcov = mask @ invcov + for kk in data.keys(): + res.append( + ift.VariableCovarianceGaussianEnergy( + data.domain[kk], "resi" + kk, "icov" + kk, data[kk].dtype + ) + ) + resi = KeyPrefixer(data.domain, "resi") @ ift.Adder(data, True) @ signal_response + invcov = KeyPrefixer(data.domain, "icov") @ invcov + return reduce(add, res) @ (resi + invcov) diff --git a/test/test_general.py b/test/test_general.py index 3511ce1280ae05ca0eea04589555748f2971d67c..61a17224523f1d86c0e3b9fc87f651cbe994806f 100644 --- a/test/test_general.py +++ b/test/test_general.py @@ -2,8 +2,9 @@ # Copyright(C) 2020 Max-Planck-Society # Author: Philipp Arras -import numpy as np from os.path import join + +import numpy as np import pytest import nifty7 as ift @@ -310,3 +311,11 @@ def test_intop(): dom = ift.RGSpace((12, 12)) op = rve.WienerIntegrations(freqdomain, dom) ift.extra.check_linear_operator(op) + + +def test_prefixer(): + op = rve.KeyPrefixer( + ift.MultiDomain.make({"a": ift.UnstructuredDomain(10), "b": ift.RGSpace(190)}), + "invcov_inp", + ).adjoint + ift.extra.check_linear_operator(op)