diff --git a/resolve/sky_model.py b/resolve/sky_model.py index 60468998568c6624d9db69b23b24e651f716ccd0..4395edfbfbf5442ed983e72df96dfb0a877bc88b 100644 --- a/resolve/sky_model.py +++ b/resolve/sky_model.py @@ -121,7 +121,10 @@ def sky_model_points(cfg, observations=[], nthreads=1): alpha = cfg.getfloat("point sources alpha") q = cfg.getfloat("point sources q") - inserter = PointInserter(sky.target, ppos) + freq = _get_frequencies(cfg, observations) + fdom = IRGSpace(freq) + sky_dom = default_sky_domain(pdom=pdom, fdom=fdom, sdom=sdom) + inserter = PointInserter(sky_dom, ppos) udom = inserter.domain[-1] p_i0 = ift.InverseGammaOperator(udom, alpha=alpha, q=q/sdom.scalar_dvol) @@ -138,7 +141,6 @@ def sky_model_points(cfg, observations=[], nthreads=1): if p_asp is not None: p_asp = ift.LognormalTransform(*p_asp, "points asperity", 0) - freq = _get_frequencies(cfg, observations) log_fdom = IRGSpace(np.sort(np.log(freq))) nfreq = len(freq) npoints = udom.size @@ -259,7 +261,10 @@ def _integrated_wiener_process(i0, alpha, irg_space, flexibility, asperity, freq vasp[1] = 0 vasp = ift.DiagonalOperator(ift.makeField(dom, vasp), domain=broadcast.target, spaces=0) sig_asp = broadcast_full @ vasp @ broadcast @ asperity - shift = ift.makeField(intop.domain, np.broadcast_to(shift[..., None, None], intop.domain.shape)) + if len(i0.target.shape) == 1: + shift = ift.makeField(intop.domain, np.broadcast_to(shift[..., None], intop.domain.shape)) + else: + shift = ift.makeField(intop.domain, np.broadcast_to(shift[..., None, None], intop.domain.shape)) increments = freq_xi * sig_flex * (ift.Adder(shift) @ sig_asp).ptw("sqrt") return IntWProcessInitialConditions(i0, alpha, intop @ increments)