diff --git a/resolve/irg_space.py b/resolve/irg_space.py index 455e4ca8ed9455ec1106821e79f91e99011766ad..23ced88d122b4db62c8cb97e8de8eb8ef65b00dd 100644 --- a/resolve/irg_space.py +++ b/resolve/irg_space.py @@ -28,13 +28,44 @@ class IRGSpace(ift.StructuredDomain): _needed_for_hash = ["_coordinates"] - def __init__(self, coordinates): + def __init__(self, coordinates, binbounds=None): bb = np.array(coordinates) if bb.ndim != 1: raise TypeError if np.any(np.diff(bb) <= 0.0): - raise ValueError("Coordinates must be sorted and strictly ascending") + raise ValueError( + "Coordinates must be sorted and strictly ascending") self._coordinates = tuple(bb) + self._binbounds = binbounds + + def binbounds(self): + if self._binbounds is not None: + return self._binbounds + if len(self._coordinates) == 1: + return np.array([-np.inf, np.inf]) + + c = np.array(self._coordinates) + bounds = np.empty(self.size + 1) + bounds[1:-1] = c[:-1] + 0.5*np.diff(c) + bounds[0] = c[0] - 0.5*(c[1] - c[0]) + bounds[-1] = c[-1] + 0.5*(c[-1] - c[-2]) + return bounds + + @classmethod + def from_binbounds(cls, binbounds: list[float]): + '''Builds a IRGSpace such that the passed binbounds can be returned.''' + + binbounds = np.array(binbounds) + + if len(binbounds) < 2: + raise ValueError("binbounds must have at least 2 elements") + + if np.any(np.diff(binbounds) <= 0): + raise ValueError("binbounds must be strictly ascending") + + coordinates = np.array([(binbounds[ii+1]-binbounds[ii])/2+binbounds[ii] + for ii in range(len(binbounds)-1)]) + return cls(coordinates, binbounds) def __repr__(self): return f"IRGSpace(coordinates={self._coordinates})" @@ -61,16 +92,6 @@ class IRGSpace(ift.StructuredDomain): """Assume that the coordinates are the center of symmetric pixels.""" return np.diff(self.binbounds()) - def binbounds(self): - if len(self._coordinates) == 1: - return np.array([-np.inf, np.inf]) - c = np.array(self._coordinates) - bounds = np.empty(self.size + 1) - bounds[1:-1] = c[:-1] + 0.5*np.diff(c) - bounds[0] = c[0] - 0.5*(c[1] - c[0]) - bounds[-1] = c[-1] + 0.5*(c[-1] - c[-2]) - return bounds - @property def distances(self): return np.diff(self._coordinates) diff --git a/resolve/sky_model.py b/resolve/sky_model.py index 60468998568c6624d9db69b23b24e651f716ccd0..e3236d8191fe0964c91220b8db72e3bcdd8fdd6c 100644 --- a/resolve/sky_model.py +++ b/resolve/sky_model.py @@ -48,7 +48,8 @@ def sky_model_diffuse(cfg, observations=[], nthreads=1): op, aa = _multi_freq_logsky_cfm(cfg, sdom, pol_lbl) elif cfg["freq mode"] == "iwp": freq = _get_frequencies(cfg, observations) - op, aa = _multi_freq_logsky_integrated_wiener_process(cfg, sdom, pol_lbl, freq) + op, aa = _multi_freq_logsky_integrated_wiener_process( + cfg, sdom, pol_lbl, freq) else: raise RuntimeError logsky[lbl] = op @@ -60,7 +61,8 @@ def sky_model_diffuse(cfg, observations=[], nthreads=1): tgt = default_sky_domain(pdom=pdom, fdom=fdom, sdom=sdom) logsky = reduce(add, (oo.ducktape_left(lbl) for lbl, oo in logsky.items())) - mexp = polarization_matrix_exponential_mf2f(logsky.target, nthreads=nthreads) + mexp = polarization_matrix_exponential_mf2f( + logsky.target, nthreads=nthreads) sky = mexp @ logsky sky = sky.ducktape_left(tgt) @@ -84,29 +86,37 @@ def sky_model_points(cfg, observations=[], nthreads=1): alpha = cfg.getfloat("point sources alpha") q = cfg.getfloat("point sources q") - inserter = PointInserter(default_sky_domain(pdom=pdom, sdom=sdom), ppos) + inserter = PointInserter( + default_sky_domain(pdom=pdom, sdom=sdom), ppos) if pdom.labels_eq("I"): - points = ift.InverseGammaOperator(inserter.domain, alpha=alpha, q=q/sdom.scalar_dvol) + points = ift.InverseGammaOperator( + inserter.domain, alpha=alpha, q=q/sdom.scalar_dvol) points = points.ducktape("points") elif pdom.labels_eq(["I", "Q", "U"]) or pdom.labels_eq(["I", "Q", "U", "V"]): points_domain = inserter.domain[-1] npoints = points_domain.size - i = ift.InverseGammaOperator(points_domain, alpha=alpha, q=q/sdom.scalar_dvol).log().ducktape("points I") - q = ift.NormalTransform(cfg["point sources stokesq log mean"], cfg["point sources stokesq log stddev"], "points Q", npoints) - u = ift.NormalTransform(cfg["point sources stokesu log mean"], cfg["point sources stokesu log stddev"], "points U", npoints) + i = ift.InverseGammaOperator( + points_domain, alpha=alpha, q=q/sdom.scalar_dvol).log().ducktape("points I") + q = ift.NormalTransform(cfg["point sources stokesq log mean"], + cfg["point sources stokesq log stddev"], "points Q", npoints) + u = ift.NormalTransform(cfg["point sources stokesu log mean"], + cfg["point sources stokesu log stddev"], "points U", npoints) i = i.ducktape_left("I") q = q.ducktape_left("Q") u = u.ducktape_left("U") polsum = i + q + u if pdom.labels_eq(["I", "Q", "U", "V"]): - v = ift.NormalTransform(cfg["point sources stokesv log mean"], cfg["point sources stokesv log stddev"], "points V", npoints) + v = ift.NormalTransform(cfg["point sources stokesv log mean"], + cfg["point sources stokesv log stddev"], "points V", npoints) v = v.ducktape_left("V") polsum = polsum + v - points = polarization_matrix_exponential_mf2f(polsum.target, nthreads=nthreads) @ polsum + points = polarization_matrix_exponential_mf2f( + polsum.target, nthreads=nthreads) @ polsum points = points.ducktape_left(inserter.domain) else: - raise NotImplementedError(f"single_frequency_sky does not support point sources on {pdom.labels} (yet?)") + raise NotImplementedError( + f"single_frequency_sky does not support point sources on {pdom.labels} (yet?)") additional["point_list"] = points sky = inserter @ points @@ -124,7 +134,8 @@ def sky_model_points(cfg, observations=[], nthreads=1): inserter = PointInserter(sky.target, ppos) udom = inserter.domain[-1] - p_i0 = ift.InverseGammaOperator(udom, alpha=alpha, q=q/sdom.scalar_dvol) + p_i0 = ift.InverseGammaOperator( + udom, alpha=alpha, q=q/sdom.scalar_dvol) p_i0 = p_i0.ducktape("points") p_alpha = _parse_or_none(cfg, "point sources alpha") @@ -142,9 +153,11 @@ def sky_model_points(cfg, observations=[], nthreads=1): log_fdom = IRGSpace(np.sort(np.log(freq))) nfreq = len(freq) npoints = udom.size - p_xi = ift.ScalingOperator(ift.UnstructuredDomain(2*npoints*(nfreq - 1)), 1.).ducktape("points xi") + p_xi = ift.ScalingOperator(ift.UnstructuredDomain( + 2*npoints*(nfreq - 1)), 1.).ducktape("points xi") - points = _integrated_wiener_process(p_i0, p_alpha, log_fdom, p_flex, p_asp, p_xi) + points = _integrated_wiener_process( + p_i0, p_alpha, log_fdom, p_flex, p_asp, p_xi) points = points.ducktape_left(inserter.domain) additional["point_list"] = points @@ -179,7 +192,8 @@ def _multi_freq_logsky_cfm(cfg, sdom, pol_label): fdom = IRGSpace(freq0 + np.arange(fnpix)*df) fdom_rg = ift.RGSpace(fnpix, df) - cfm = cfm_from_cfg(cfg, {"freq": fdom_rg, "space": sdom}, f"stokes{pol_label} diffuse") + cfm = cfm_from_cfg( + cfg, {"freq": fdom_rg, "space": sdom}, f"stokes{pol_label} diffuse") op = cfm.finalize(0) fampl, sampl = list(cfm.get_normalized_amplitudes()) @@ -219,7 +233,8 @@ def _multi_freq_logsky_integrated_wiener_process(cfg, sdom, pol_label, freq): flexibility = _parse_or_none(cfg, prefix + " wp flexibility") if flexibility is None: raise RuntimeError("freq flexibility cannot be None") - flexibility = ift.LognormalTransform(*flexibility, prefix + " wp flexibility", 0) + flexibility = ift.LognormalTransform( + *flexibility, prefix + " wp flexibility", 0) asperity = _parse_or_none(cfg, prefix + " wp asperity") asperity = ift.LognormalTransform(*asperity, prefix + " wp asperity", 0) @@ -251,16 +266,20 @@ def _integrated_wiener_process(i0, alpha, irg_space, flexibility, asperity, freq shift = np.ones(dom.shape) shift[0] = vol * vol / 12.0 if asperity is None: - shift = ift.DiagonalOperator(ift.makeField(dom, shift).sqrt(), intop.domain, 0) + shift = ift.DiagonalOperator(ift.makeField( + dom, shift).sqrt(), intop.domain, 0) increments = shift @ (freq_xi * sig_flex) else: vasp = np.empty(dom.shape) vasp[0] = 1 vasp[1] = 0 - vasp = ift.DiagonalOperator(ift.makeField(dom, vasp), domain=broadcast.target, spaces=0) + vasp = ift.DiagonalOperator(ift.makeField( + dom, vasp), domain=broadcast.target, spaces=0) sig_asp = broadcast_full @ vasp @ broadcast @ asperity - shift = ift.makeField(intop.domain, np.broadcast_to(shift[..., None, None], intop.domain.shape)) - increments = freq_xi * sig_flex * (ift.Adder(shift) @ sig_asp).ptw("sqrt") + shift = ift.makeField(intop.domain, np.broadcast_to( + shift[..., None, None], intop.domain.shape)) + increments = freq_xi * sig_flex * \ + (ift.Adder(shift) @ sig_asp).ptw("sqrt") return IntWProcessInitialConditions(i0, alpha, intop @ increments) @@ -268,7 +287,8 @@ def _integrated_wiener_process(i0, alpha, irg_space, flexibility, asperity, freq def cfm_from_cfg(cfg, domain_dct, prefix, total_N=0, dofdex=None, override={}, domain_prefix=None): assert len(prefix) > 0 product_spectrum = len(domain_dct) > 1 - cfm = ift.CorrelatedFieldMaker(prefix if domain_prefix is None else domain_prefix, total_N=total_N) + cfm = ift.CorrelatedFieldMaker( + prefix if domain_prefix is None else domain_prefix, total_N=total_N) for key_prefix, dom in domain_dct.items(): ll = _append_to_nonempty_string(key_prefix, " ") kwargs = {kk: _parse_or_none(cfg, f"{prefix} {ll}{kk}", override)