From f784cb7ad22d9ce0103b80c86b5ec3f657c92fa1 Mon Sep 17 00:00:00 2001
From: Simon Ding <simon.ding@iap.fr>
Date: Wed, 11 Oct 2023 17:04:18 +0200
Subject: [PATCH] added iwp point sources

---
 resolve/sky_model.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/resolve/sky_model.py b/resolve/sky_model.py
index 60468998..4395edfb 100644
--- a/resolve/sky_model.py
+++ b/resolve/sky_model.py
@@ -121,7 +121,10 @@ def sky_model_points(cfg, observations=[], nthreads=1):
         alpha = cfg.getfloat("point sources alpha")
         q = cfg.getfloat("point sources q")
 
-        inserter = PointInserter(sky.target, ppos)
+        freq = _get_frequencies(cfg, observations)
+        fdom = IRGSpace(freq)
+        sky_dom = default_sky_domain(pdom=pdom, fdom=fdom, sdom=sdom)
+        inserter = PointInserter(sky_dom, ppos)
         udom = inserter.domain[-1]
 
         p_i0 = ift.InverseGammaOperator(udom, alpha=alpha, q=q/sdom.scalar_dvol)
@@ -138,7 +141,6 @@ def sky_model_points(cfg, observations=[], nthreads=1):
         if p_asp is not None:
             p_asp = ift.LognormalTransform(*p_asp, "points asperity", 0)
 
-        freq = _get_frequencies(cfg, observations)
         log_fdom = IRGSpace(np.sort(np.log(freq)))
         nfreq = len(freq)
         npoints = udom.size
@@ -259,7 +261,10 @@ def _integrated_wiener_process(i0, alpha, irg_space, flexibility, asperity, freq
         vasp[1] = 0
         vasp = ift.DiagonalOperator(ift.makeField(dom, vasp), domain=broadcast.target, spaces=0)
         sig_asp = broadcast_full @ vasp @ broadcast @ asperity
-        shift = ift.makeField(intop.domain, np.broadcast_to(shift[..., None, None], intop.domain.shape))
+        if len(i0.target.shape) == 1:
+            shift = ift.makeField(intop.domain, np.broadcast_to(shift[..., None], intop.domain.shape))
+        else:
+            shift = ift.makeField(intop.domain, np.broadcast_to(shift[..., None, None], intop.domain.shape))
         increments = freq_xi * sig_flex * (ift.Adder(shift) @ sig_asp).ptw("sqrt")
 
     return IntWProcessInitialConditions(i0, alpha, intop @ increments)
-- 
GitLab