From aa7eed54c0bfa5c9d5ec5342ef4c02e9c6c143c3 Mon Sep 17 00:00:00 2001
From: Philipp Arras <c@philipp-arras.de>
Date: Wed, 27 Apr 2022 15:58:57 +0200
Subject: [PATCH] Try to implement spectral index fit for plotting

---
 resolve/ubik_tools/plot_sky_hdf5.py | 52 ++++++++++++++++++++++++++---
 1 file changed, 48 insertions(+), 4 deletions(-)

diff --git a/resolve/ubik_tools/plot_sky_hdf5.py b/resolve/ubik_tools/plot_sky_hdf5.py
index 712140a6..d0556d75 100644
--- a/resolve/ubik_tools/plot_sky_hdf5.py
+++ b/resolve/ubik_tools/plot_sky_hdf5.py
@@ -36,12 +36,13 @@ def cmdline_visualize_sky_hdf5():
     parser.add_argument('file_name')
     parser.add_argument('what', help="Can be 'mean', 'stddev', 'stddev/mean', 'mean/stddev', "
                                       "'sample[i]'")
-    parser.add_argument('stokes', help="Can be 'I', 'Q', 'U', 'V', 'polarizatedfraction' "
+    parser.add_argument('stokes', help="Can be 'I', 'Q', 'U', 'V', 'polarizedfraction' "
                                        "or nothing if output file ends with '.fits'. Then all "
                                        "Stokes parameters are written to the fits file.", nargs="?")
     parser.add_argument('--norm', help="Can be 'log', 'symmetric', 'linear'", default='linear')
     parser.add_argument('--vmin', type=float)
     parser.add_argument('--vmax', type=float)
+    parser.add_argument('--prepare', help="Can be 'freq-diff', 'freq-frac', 'spectral-index', None", default=None)
     parser.add_argument('-o')
     parser.add_argument('--dpi', type=int, default=300)
     args = parser.parse_args()
@@ -50,11 +51,12 @@ def cmdline_visualize_sky_hdf5():
                               what=args.what,
                               stokes=args.stokes,
                               norm=args.norm, vmin=args.vmin, vmax=args.vmax,
-                              dpi=args.dpi)
+                              dpi=args.dpi,
+                              prepare=args.prepare)
 
 
 def visualize_sky_hdf5(hdf5_file, output_file, what, stokes, norm="linear", vmin=None, vmax=None,
-                       dpi=300):
+                       dpi=300, prepare=None):
     """
 
     Parameters
@@ -67,7 +69,7 @@ def visualize_sky_hdf5(hdf5_file, output_file, what, stokes, norm="linear", vmin
         Can be 'mean', 'stddev', 'stddev/mean', 'mean/stddev', 'sample[i]'.
 
     stokes : str
-        Can be 'I', 'Q', 'U', 'V', 'polarizatedfraction' or nothing if output
+        Can be 'I', 'Q', 'U', 'V', 'polarizedfraction' or nothing if output
         file ends with '.fits'. Then all Stokes parameters are written to the
         fits file.
 
@@ -79,6 +81,9 @@ def visualize_sky_hdf5(hdf5_file, output_file, what, stokes, norm="linear", vmin
     vmax : 
 
     dpi : int
+
+    prepare : str
+        Can be 'freq-diff', 'freq-frac', 'spectral-index' or None, default is None.
     """
     try:
         import h5py
@@ -106,6 +111,45 @@ def visualize_sky_hdf5(hdf5_file, output_file, what, stokes, norm="linear", vmin
         else:
             raise RuntimeError()
 
+        if prepare == "freq-diff":
+            newarr = np.empty_like(arr)
+            newarr[:, :, :-1] = arr[:, :, :-1] - arr[:, :, 1:]
+            newarr[:, :, -1] = np.nan
+            arr = newarr
+        elif prepare == "freq-frac":
+            newarr = np.empty_like(arr)
+            newarr[:, :, :-1] = arr[:, :, :-1] / arr[:, :, 1:]
+            newarr[:, :, -1] = np.nan
+            arr = newarr
+        elif prepare == "spectral-index":
+            dom = eval(f.attrs["nifty domain"])
+            fdom = dom[2]
+
+            bc_freq = ift.ContractionOperator(dom, 2).adjoint
+            inp = bc_freq @ ift.Operator.identity_operator(bc_freq.domain)
+            freq = np.array(fdom.coordinates)
+            x = ift.makeOp(ift.ContractionOperator(dom, (0, 1, 3)).adjoint(ift.makeField(fdom, np.log(freq/np.mean(freq)))))
+            model = (inp.ducktape("c0") + x @ inp.ducktape("c1")).exp()
+            data = ift.makeField(dom, np.array(arr))
+            invcov = ift.ScalingOperator(dom, (np.max(data.val)/10)**2).inverse
+
+            mini = ift.NewtonCG(ift.GradientNormController(iteration_limit=100, name="newton"))
+            e = ift.GaussianEnergy(data=data, inverse_covariance=invcov) @ model
+            e = ift.EnergyAdapter(ift.from_random(e.domain), e, want_metric=True)
+            e, _ = mini(e)
+            out = e.position
+
+            p = ift.Plot()
+            for kk, vv in out.items():
+                p.add(vv, title=kk)
+            p.output()
+            exit()
+        elif prepare is None:
+            pass
+        else:
+            raise ValueError(f"prepare='{prepare}' not understood.")
+
+
         fits = output_file is not None and os.path.splitext(output_file)[1] == ".fits"
 
         if "nifty domain" not in f.attrs and arr.shape[0] == 1:
-- 
GitLab