From 8122ea7a6cf4b46b9abc43ec9b79ff0cf08237b3 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Tue, 28 Nov 2017 21:11:42 +0100
Subject: [PATCH] cleanup

---
 demos/wiener_filter_via_curvature.py        |  62 +++++++---
 nifty/library/nonlinear_signal_curvature.py |  28 -----
 nifty/library/nonlinear_signal_energy.py    |  33 ++---
 nifty/library/nonlinearities.py             |   1 -
 nifty/library/response_operators.py         |  56 ---------
 nifty/plotting/plot.py                      | 130 +++++++++++++++++---
 setup.py                                    |   6 +-
 7 files changed, 175 insertions(+), 141 deletions(-)
 delete mode 100644 nifty/library/nonlinear_signal_curvature.py

diff --git a/demos/wiener_filter_via_curvature.py b/demos/wiener_filter_via_curvature.py
index 7d4f41593..fad1d3f0c 100644
--- a/demos/wiener_filter_via_curvature.py
+++ b/demos/wiener_filter_via_curvature.py
@@ -1,9 +1,17 @@
+use_nifty2go = True
+
 import numpy as np
-import nifty2go as ift
+if use_nifty2go:
+    import nifty2go as ift
+else:
+    import nifty as ift
 import numericalunits as nu
 
 if __name__ == "__main__":
     # In MPI mode, the random seed for numericalunits must be set by hand
+    if not use_nifty2go:
+        ift.nifty_configuration['default_distribution_strategy'] = 'fftw'
+        ift.nifty_configuration['harmonic_rg_base'] = 'real'
     nu.reset_units(43)
     dimensionality = 2
     np.random.seed(43)
@@ -32,11 +40,14 @@ if __name__ == "__main__":
     # Total side-length of the domain
     L = 2.*nu.m
     # Grid resolution (pixels per axis)
-    N_pixels = 512
+    N_pixels = 4096
     shape = [N_pixels]*dimensionality
 
     signal_space = ift.RGSpace(shape, distances=L/N_pixels)
-    harmonic_space = signal_space.get_default_codomain()
+    if use_nifty2go:
+        harmonic_space = signal_space.get_default_codomain()
+    else:
+        harmonic_space = ift.FFTOperator.get_default_codomain(signal_space)
     fft = ift.FFTOperator(harmonic_space, target=signal_space)
     power_space = ift.PowerSpace(harmonic_space)
 
@@ -45,8 +56,12 @@ if __name__ == "__main__":
                                   power_spectrum=power_spectrum)
     np.random.seed(43)
 
-    mock_power = ift.PS_field(power_space, power_spectrum)
-    mock_harmonic = ift.power_synthesize(mock_power, real_signal=True)
+    if use_nifty2go:
+        mock_power = ift.PS_field(power_space, power_spectrum)
+        mock_harmonic = ift.power_synthesize(mock_power, real_signal=True)
+    else:
+        mock_power = ift.Field(power_space, val=power_spectrum)
+        mock_harmonic = mock_power.power_synthesize(real_signal=True)
     mock_harmonic = mock_harmonic.real
     mock_signal = fft(mock_harmonic)
 
@@ -54,11 +69,19 @@ if __name__ == "__main__":
     R = ift.ResponseOperator(signal_space, sigma=(response_sigma,),
                              exposure=(exposure,))
     data_domain = R.target[0]
-    R_harmonic = ift.ComposedOperator([fft, R])
+    if use_nifty2go:
+        R_harmonic = ift.ComposedOperator([fft, R])
+    else:
+        R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0])
+
+    if use_nifty2go:
+        N = ift.DiagonalOperator(
+            ift.Field.full(data_domain,
+                           mock_signal.var()/signal_to_noise).weight(1))
+    else:
+        ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
+        N = ift.DiagonalOperator(data_domain, ndiag)
 
-    N = ift.DiagonalOperator(
-        ift.Field.full(data_domain,
-                       mock_signal.var()/signal_to_noise).weight(1))
     noise = ift.Field.from_random(
         domain=data_domain, random_type='normal',
         std=mock_signal.std()/np.sqrt(signal_to_noise), mean=0)
@@ -67,12 +90,23 @@ if __name__ == "__main__":
     # Wiener filter
 
     j = R_harmonic.adjoint_times(N.inverse_times(data))
-    ctrl = ift.GradientNormController(
-        verbose=True, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality)))
-    wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N,
-                                                         R=R_harmonic)
+    if use_nifty2go:
+        ctrl = ift.GradientNormController(
+            verbose=True, iteration_limit=10, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality)))
+    else:
+        def print_stats(a_energy, iteration):  # returns current energy
+            x = a_energy.value
+            print(x, iteration)
+        ctrl = ift.GradientNormController(
+            callback=print_stats, iteration_limit=10, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality)))
+
     inverter = ift.ConjugateGradient(controller=ctrl)
-    wiener_curvature = ift.InversionEnabler(wiener_curvature, inverter)
+    if use_nifty2go:
+        wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N,
+                                                             R=R_harmonic)
+        wiener_curvature = ift.InversionEnabler(wiener_curvature, inverter)
+    else:
+        wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
 
     m = wiener_curvature.inverse_times(j)
     m_s = fft(m)
diff --git a/nifty/library/nonlinear_signal_curvature.py b/nifty/library/nonlinear_signal_curvature.py
deleted file mode 100644
index 29b3698be..000000000
--- a/nifty/library/nonlinear_signal_curvature.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from ..operators.endomorphic_operator import EndomorphicOperator
-
-
-class NonlinearSignalCurvature(EndomorphicOperator):
-    def __init__(self, R, N, S, inverter=None):
-        self.R = R
-        self.N = N
-        self.S = S
-        # if preconditioner is None:
-        #     preconditioner = self.S.times
-        self._domain = self.S.domain
-        super(NonlinearSignalCurvature, self).__init__(inverter=inverter)
-
-    @property
-    def domain(self):
-        return self._domain
-
-    @property
-    def self_adjoint(self):
-        return True
-
-    @property
-    def unitary(self):
-        return False
-
-    # ---Added properties and methods---
-    def _times(self, x, spaces):
-        return self.R.adjoint_times(self.N.inverse_times(self.R(x))) + self.S.inverse_times(x)
diff --git a/nifty/library/nonlinear_signal_energy.py b/nifty/library/nonlinear_signal_energy.py
index a151abda6..9e6aed8e9 100644
--- a/nifty/library/nonlinear_signal_energy.py
+++ b/nifty/library/nonlinear_signal_energy.py
@@ -1,4 +1,4 @@
-from .nonlinear_signal_curvature import NonlinearSignalCurvature
+from .wiener_filter_curvature import WienerFilterCurvature
 from .. import Field, exp
 from ..utilities import memo
 from ..sugar import generate_posterior_sample
@@ -8,33 +8,17 @@ from .response_operators import LinearizedSignalResponse
 
 
 class NonlinearWienerFilterEnergy(Energy):
-    """The Energy for the Gaussian lognormal case.
-
-    It describes the situation of linear measurement  of a
-    lognormal signal with Gaussian noise and Gaussain signal prior.
-
-    Parameters
-    ----------
-    d : Field,
-        the data.
-    R : Operator,
-        The nonlinear response operator, describtion of the measurement process.
-    N : EndomorphicOperator,
-        The noise covariance in data space.
-    S : EndomorphicOperator,
-        The prior signal covariance in harmonic space.
-    """
-
-    def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S, inverter=None):
+    def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S,
+                 inverter=None):
         super(NonlinearWienerFilterEnergy, self).__init__(position=position)
-        # print "init", position.norm()
         self.d = d
         self.Instrument = Instrument
         self.nonlinearity = nonlinearity
         self.FFT = FFT
         self.power = power
-        self.LinearizedResponse = LinearizedSignalResponse(Instrument, nonlinearity,
-                                                           FFT, power, self.position)
+        self.LinearizedResponse = \
+            LinearizedSignalResponse(Instrument, nonlinearity, FFT, power,
+                                     self.position)
 
         position_map = FFT.adjoint_times(self.power * self.position)
         # position_map = (Field(FFT.domain,val=position_map.val.real+0j))
@@ -68,7 +52,6 @@ class NonlinearWienerFilterEnergy(Energy):
     @property
     @memo
     def curvature(self):
-        curvature = NonlinearSignalCurvature(R=self.LinearizedResponse,
-                                             N=self.N,
-                                             S=self.S, inverter=self.inverter)
+        curvature = WienerFilterCurvature(R=self.LinearizedResponse,
+                                          N=self.N, S=self.S)
         return InversionEnabler(curvature, inverter=self.inverter)
diff --git a/nifty/library/nonlinearities.py b/nifty/library/nonlinearities.py
index 3a1c78983..bd0e1cbc9 100644
--- a/nifty/library/nonlinearities.py
+++ b/nifty/library/nonlinearities.py
@@ -1,5 +1,4 @@
 from numpy import logical_and, where
-
 from ... import Field, exp, tanh
 
 
diff --git a/nifty/library/response_operators.py b/nifty/library/response_operators.py
index afcbebc05..dbf77fc1e 100644
--- a/nifty/library/response_operators.py
+++ b/nifty/library/response_operators.py
@@ -2,33 +2,6 @@ from .. import exp
 from ..operators.linear_operator import LinearOperator
 
 
-class AdjointFFTResponse(LinearOperator):
-    def __init__(self, FFT, R, default_spaces=None):
-        super(AdjointFFTResponse, self).__init__(default_spaces)
-        self._domain = FFT.target
-        self._target = R.target
-        self.R = R
-        self.FFT = FFT
-
-    def _times(self, x, spaces=None):
-        return self.R(self.FFT.adjoint_times(x))
-
-    def _adjoint_times(self, x, spaces=None):
-        return self.FFT(self.R.adjoint_times(x))
-
-    @property
-    def domain(self):
-        return self._domain
-
-    @property
-    def target(self):
-        return self._target
-
-    @property
-    def unitary(self):
-        return False
-
-
 class LinearizedSignalResponse(LinearOperator):
     def __init__(self, Instrument, nonlinearity, FFT, power, m, default_spaces=None):
         super(LinearizedSignalResponse, self).__init__(default_spaces)
@@ -94,32 +67,3 @@ class LinearizedPowerResponse(LinearOperator):
     @property
     def unitary(self):
         return False
-
-
-class SignalResponse(LinearOperator):
-    def __init__(self, t, FFT, R, default_spaces=None):
-        super(SignalResponse, self).__init__(default_spaces)
-        self._domain = FFT.target
-        self._target = R.target
-        self.power = exp(t).power_synthesize(
-            mean=1, std=0, real_signal=False)
-        self.R = R
-        self.FFT = FFT
-
-    def _times(self, x, spaces=None):
-        return self.R(self.FFT.adjoint_times(self.power * x))
-
-    def _adjoint_times(self, x, spaces=None):
-        return self.power * self.FFT(self.R.adjoint_times(x))
-
-    @property
-    def domain(self):
-        return self._domain
-
-    @property
-    def target(self):
-        return self._target
-
-    @property
-    def unitary(self):
-        return False
diff --git a/nifty/plotting/plot.py b/nifty/plotting/plot.py
index d0a9921eb..3c3a16b0b 100644
--- a/nifty/plotting/plot.py
+++ b/nifty/plotting/plot.py
@@ -42,7 +42,7 @@ def _find_closest(A, target):
     return idx
 
 
-def _makeplot(name):
+def _mpl_makeplot(name):
     import matplotlib.pyplot as plt
     if dobj.rank != 0:
         return
@@ -70,7 +70,7 @@ def _makeplot(name):
         raise ValueError("file format not understood")
 
 
-def _limit_xy(**kwargs):
+def _mpl_limit_xy(**kwargs):
     import matplotlib.pyplot as plt
     x1, x2, y1, y2 = plt.axis()
     x1 = _get_kw("xmin", x1, **kwargs)
@@ -145,12 +145,11 @@ def _register_cmaps():
 
 
 def _get_kw(kwname, kwdefault=None, **kwargs):
-    if kwargs.get(kwname) is not None:
-        return kwargs.get(kwname)
-    return kwdefault
+    res = kwargs.get(kwname)
+    return kwdefault if res is None else res
 
 
-def plot(f, **kwargs):
+def _mpl_plot(f, **kwargs):
     import matplotlib.pyplot as plt
     _register_cmaps()
     if not isinstance(f, Field):
@@ -176,8 +175,8 @@ def plot(f, **kwargs):
             xcoord = np.arange(npoints, dtype=np.float64)*dist
             ycoord = dobj.to_global_data(f.val)
             plt.plot(xcoord, ycoord)
-            _limit_xy(**kwargs)
-            _makeplot(kwargs.get("name"))
+            _mpl_limit_xy(**kwargs)
+            _mpl_makeplot(kwargs.get("name"))
             return
         elif len(dom.shape) == 2:
             nx = dom.shape[0]
@@ -195,8 +194,8 @@ def plot(f, **kwargs):
             # cax = divider.append_axes("right", size="5%", pad=0.05)
             # plt.colorbar(im,cax=cax)
             plt.colorbar(im)
-            _limit_xy(**kwargs)
-            _makeplot(kwargs.get("name"))
+            _mpl_limit_xy(**kwargs)
+            _mpl_makeplot(kwargs.get("name"))
             return
     elif isinstance(dom, PowerSpace):
         xcoord = dom.k_lengths
@@ -205,8 +204,8 @@ def plot(f, **kwargs):
         plt.yscale('log')
         plt.title('power')
         plt.plot(xcoord, ycoord)
-        _limit_xy(**kwargs)
-        _makeplot(kwargs.get("name"))
+        _mpl_limit_xy(**kwargs)
+        _mpl_makeplot(kwargs.get("name"))
         return
     elif isinstance(dom, HPSpace):
         import pyHealpix
@@ -222,7 +221,7 @@ def plot(f, **kwargs):
         plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
                    cmap=cmap, origin="lower")
         plt.colorbar(orientation="horizontal")
-        _makeplot(kwargs.get("name"))
+        _mpl_makeplot(kwargs.get("name"))
         return
     elif isinstance(dom, GLSpace):
         import pyHealpix
@@ -239,7 +238,110 @@ def plot(f, **kwargs):
         plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
                    cmap=cmap, origin="lower")
         plt.colorbar(orientation="horizontal")
-        _makeplot(kwargs.get("name"))
+        _mpl_makeplot(kwargs.get("name"))
         return
 
     raise ValueError("Field type not(yet) supported")
+
+
+def _plotly_plot(f, **kwargs):
+    if not isinstance(f, Field):
+        raise TypeError("incorrect data type")
+    if len(f.domain) != 1:
+        raise ValueError("input field must have exactly one domain")
+
+    dom = f.domain[0]
+    fig = plt.figure()
+    ax = fig.add_subplot(1, 1, 1)
+
+    xsize = _get_kw("xsize", 6, **kwargs)
+    ysize = _get_kw("ysize", 6, **kwargs)
+    fig.set_size_inches(xsize, ysize)
+    ax.set_title(_get_kw("title", "", **kwargs))
+    ax.set_xlabel(_get_kw("xlabel", "", **kwargs))
+    ax.set_ylabel(_get_kw("ylabel", "", **kwargs))
+    cmap = _get_kw("colormap", plt.rcParams['image.cmap'], **kwargs)
+    if isinstance(dom, RGSpace):
+        if len(dom.shape) == 1:
+            npoints = dom.shape[0]
+            dist = dom.distances[0]
+            xcoord = np.arange(npoints, dtype=np.float64)*dist
+            ycoord = dobj.to_global_data(f.val)
+            plt.plot(xcoord, ycoord)
+            _mpl_limit_xy(**kwargs)
+            _mpl_makeplot(kwargs.get("name"))
+            return
+        elif len(dom.shape) == 2:
+            nx = dom.shape[0]
+            ny = dom.shape[1]
+            dx = dom.distances[0]
+            dy = dom.distances[1]
+            xc = np.arange(nx, dtype=np.float64)*dx
+            yc = np.arange(ny, dtype=np.float64)*dy
+            im = ax.imshow(dobj.to_global_data(f.val),
+                           extent=[xc[0], xc[-1], yc[0], yc[-1]],
+                           vmin=kwargs.get("zmin"),
+                           vmax=kwargs.get("zmax"), cmap=cmap, origin="lower")
+            # from mpl_toolkits.axes_grid1 import make_axes_locatable
+            # divider = make_axes_locatable(ax)
+            # cax = divider.append_axes("right", size="5%", pad=0.05)
+            # plt.colorbar(im,cax=cax)
+            plt.colorbar(im)
+            _mpl_limit_xy(**kwargs)
+            _mpl_makeplot(kwargs.get("name"))
+            return
+    elif isinstance(dom, PowerSpace):
+        xcoord = dom.k_lengths
+        ycoord = dobj.to_global_data(f.val)
+        plt.xscale('log')
+        plt.yscale('log')
+        plt.title('power')
+        plt.plot(xcoord, ycoord)
+        _mpl_limit_xy(**kwargs)
+        _mpl_makeplot(kwargs.get("name"))
+        return
+    elif isinstance(dom, HPSpace):
+        import pyHealpix
+        xsize = 800
+        res, mask, theta, phi = _mollweide_helper(xsize)
+
+        ptg = np.empty((phi.size, 2), dtype=np.float64)
+        ptg[:, 0] = theta
+        ptg[:, 1] = phi
+        base = pyHealpix.Healpix_Base(int(np.sqrt(f.val.size//12)), "RING")
+        res[mask] = dobj.to_global_data(f.val)[base.ang2pix(ptg)]
+        plt.axis('off')
+        plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
+                   cmap=cmap, origin="lower")
+        plt.colorbar(orientation="horizontal")
+        _mpl_makeplot(kwargs.get("name"))
+        return
+    elif isinstance(dom, GLSpace):
+        import pyHealpix
+        xsize = 800
+        res, mask, theta, phi = _mollweide_helper(xsize)
+        ra = np.linspace(0, 2*np.pi, dom.nlon+1)
+        dec = pyHealpix.GL_thetas(dom.nlat)
+        ilat = _find_closest(dec, theta)
+        ilon = _find_closest(ra, phi)
+        ilon = np.where(ilon == dom.nlon, 0, ilon)
+        res[mask] = dobj.to_global_data(f.val)[ilat*dom.nlon + ilon]
+
+        plt.axis('off')
+        plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
+                   cmap=cmap, origin="lower")
+        plt.colorbar(orientation="horizontal")
+        _mpl_makeplot(kwargs.get("name"))
+        return
+
+    raise ValueError("Field type not(yet) supported")
+
+
+def plot(f, **kwargs):
+    extension = os.path.splitext(kwargs.get("name"))[1]
+    if extension in [".html"]:
+        _plotly_plot(f, **kwargs)
+    elif extension in [".pdf", ".png"]:
+        _mpl_plot(f, **kwargs)
+    else:
+        raise ValueError("unknown file name extension: " + extension)
diff --git a/setup.py b/setup.py
index 4ddc06ca0..5fa9d4fc4 100644
--- a/setup.py
+++ b/setup.py
@@ -30,10 +30,10 @@ setup(name="nifty2go",
       packages=["nifty2go"] + ["nifty2go."+p for p in find_packages("nifty")],
       zip_safe=False,
       dependency_links=[
-               'git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git@setuptools_test#egg=pyHealpix-0.0.1'],
+               'git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git#egg=pyHealpix-0.0.1'],
       license="GPLv3",
-      setup_requires=['future', 'pyHealpix>=0.0.1', 'numpy', 'pyfftw>=0.10.4'],
-      install_requires=['future', 'pyHealpix>=0.0.1', 'numpy', 'pyfftw>=0.10.4'],
+      setup_requires=['future', 'numpy'],
+      install_requires=['future', 'numpy'],
       classifiers=[
         "Development Status :: 4 - Beta",
         "Topic :: Utilities",
-- 
GitLab