From 8122ea7a6cf4b46b9abc43ec9b79ff0cf08237b3 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Tue, 28 Nov 2017 21:11:42 +0100 Subject: [PATCH] cleanup --- demos/wiener_filter_via_curvature.py | 62 +++++++--- nifty/library/nonlinear_signal_curvature.py | 28 ----- nifty/library/nonlinear_signal_energy.py | 33 ++--- nifty/library/nonlinearities.py | 1 - nifty/library/response_operators.py | 56 --------- nifty/plotting/plot.py | 130 +++++++++++++++++--- setup.py | 6 +- 7 files changed, 175 insertions(+), 141 deletions(-) delete mode 100644 nifty/library/nonlinear_signal_curvature.py diff --git a/demos/wiener_filter_via_curvature.py b/demos/wiener_filter_via_curvature.py index 7d4f41593..fad1d3f0c 100644 --- a/demos/wiener_filter_via_curvature.py +++ b/demos/wiener_filter_via_curvature.py @@ -1,9 +1,17 @@ +use_nifty2go = True + import numpy as np -import nifty2go as ift +if use_nifty2go: + import nifty2go as ift +else: + import nifty as ift import numericalunits as nu if __name__ == "__main__": # In MPI mode, the random seed for numericalunits must be set by hand + if not use_nifty2go: + ift.nifty_configuration['default_distribution_strategy'] = 'fftw' + ift.nifty_configuration['harmonic_rg_base'] = 'real' nu.reset_units(43) dimensionality = 2 np.random.seed(43) @@ -32,11 +40,14 @@ if __name__ == "__main__": # Total side-length of the domain L = 2.*nu.m # Grid resolution (pixels per axis) - N_pixels = 512 + N_pixels = 4096 shape = [N_pixels]*dimensionality signal_space = ift.RGSpace(shape, distances=L/N_pixels) - harmonic_space = signal_space.get_default_codomain() + if use_nifty2go: + harmonic_space = signal_space.get_default_codomain() + else: + harmonic_space = ift.FFTOperator.get_default_codomain(signal_space) fft = ift.FFTOperator(harmonic_space, target=signal_space) power_space = ift.PowerSpace(harmonic_space) @@ -45,8 +56,12 @@ if __name__ == "__main__": power_spectrum=power_spectrum) np.random.seed(43) - mock_power = ift.PS_field(power_space, power_spectrum) - mock_harmonic = ift.power_synthesize(mock_power, real_signal=True) + if use_nifty2go: + mock_power = ift.PS_field(power_space, power_spectrum) + mock_harmonic = ift.power_synthesize(mock_power, real_signal=True) + else: + mock_power = ift.Field(power_space, val=power_spectrum) + mock_harmonic = mock_power.power_synthesize(real_signal=True) mock_harmonic = mock_harmonic.real mock_signal = fft(mock_harmonic) @@ -54,11 +69,19 @@ if __name__ == "__main__": R = ift.ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(exposure,)) data_domain = R.target[0] - R_harmonic = ift.ComposedOperator([fft, R]) + if use_nifty2go: + R_harmonic = ift.ComposedOperator([fft, R]) + else: + R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0]) + + if use_nifty2go: + N = ift.DiagonalOperator( + ift.Field.full(data_domain, + mock_signal.var()/signal_to_noise).weight(1)) + else: + ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1) + N = ift.DiagonalOperator(data_domain, ndiag) - N = ift.DiagonalOperator( - ift.Field.full(data_domain, - mock_signal.var()/signal_to_noise).weight(1)) noise = ift.Field.from_random( domain=data_domain, random_type='normal', std=mock_signal.std()/np.sqrt(signal_to_noise), mean=0) @@ -67,12 +90,23 @@ if __name__ == "__main__": # Wiener filter j = R_harmonic.adjoint_times(N.inverse_times(data)) - ctrl = ift.GradientNormController( - verbose=True, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality))) - wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, - R=R_harmonic) + if use_nifty2go: + ctrl = ift.GradientNormController( + verbose=True, iteration_limit=10, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality))) + else: + def print_stats(a_energy, iteration): # returns current energy + x = a_energy.value + print(x, iteration) + ctrl = ift.GradientNormController( + callback=print_stats, iteration_limit=10, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality))) + inverter = ift.ConjugateGradient(controller=ctrl) - wiener_curvature = ift.InversionEnabler(wiener_curvature, inverter) + if use_nifty2go: + wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, + R=R_harmonic) + wiener_curvature = ift.InversionEnabler(wiener_curvature, inverter) + else: + wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter) m = wiener_curvature.inverse_times(j) m_s = fft(m) diff --git a/nifty/library/nonlinear_signal_curvature.py b/nifty/library/nonlinear_signal_curvature.py deleted file mode 100644 index 29b3698be..000000000 --- a/nifty/library/nonlinear_signal_curvature.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..operators.endomorphic_operator import EndomorphicOperator - - -class NonlinearSignalCurvature(EndomorphicOperator): - def __init__(self, R, N, S, inverter=None): - self.R = R - self.N = N - self.S = S - # if preconditioner is None: - # preconditioner = self.S.times - self._domain = self.S.domain - super(NonlinearSignalCurvature, self).__init__(inverter=inverter) - - @property - def domain(self): - return self._domain - - @property - def self_adjoint(self): - return True - - @property - def unitary(self): - return False - - # ---Added properties and methods--- - def _times(self, x, spaces): - return self.R.adjoint_times(self.N.inverse_times(self.R(x))) + self.S.inverse_times(x) diff --git a/nifty/library/nonlinear_signal_energy.py b/nifty/library/nonlinear_signal_energy.py index a151abda6..9e6aed8e9 100644 --- a/nifty/library/nonlinear_signal_energy.py +++ b/nifty/library/nonlinear_signal_energy.py @@ -1,4 +1,4 @@ -from .nonlinear_signal_curvature import NonlinearSignalCurvature +from .wiener_filter_curvature import WienerFilterCurvature from .. import Field, exp from ..utilities import memo from ..sugar import generate_posterior_sample @@ -8,33 +8,17 @@ from .response_operators import LinearizedSignalResponse class NonlinearWienerFilterEnergy(Energy): - """The Energy for the Gaussian lognormal case. - - It describes the situation of linear measurement of a - lognormal signal with Gaussian noise and Gaussain signal prior. - - Parameters - ---------- - d : Field, - the data. - R : Operator, - The nonlinear response operator, describtion of the measurement process. - N : EndomorphicOperator, - The noise covariance in data space. - S : EndomorphicOperator, - The prior signal covariance in harmonic space. - """ - - def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S, inverter=None): + def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S, + inverter=None): super(NonlinearWienerFilterEnergy, self).__init__(position=position) - # print "init", position.norm() self.d = d self.Instrument = Instrument self.nonlinearity = nonlinearity self.FFT = FFT self.power = power - self.LinearizedResponse = LinearizedSignalResponse(Instrument, nonlinearity, - FFT, power, self.position) + self.LinearizedResponse = \ + LinearizedSignalResponse(Instrument, nonlinearity, FFT, power, + self.position) position_map = FFT.adjoint_times(self.power * self.position) # position_map = (Field(FFT.domain,val=position_map.val.real+0j)) @@ -68,7 +52,6 @@ class NonlinearWienerFilterEnergy(Energy): @property @memo def curvature(self): - curvature = NonlinearSignalCurvature(R=self.LinearizedResponse, - N=self.N, - S=self.S, inverter=self.inverter) + curvature = WienerFilterCurvature(R=self.LinearizedResponse, + N=self.N, S=self.S) return InversionEnabler(curvature, inverter=self.inverter) diff --git a/nifty/library/nonlinearities.py b/nifty/library/nonlinearities.py index 3a1c78983..bd0e1cbc9 100644 --- a/nifty/library/nonlinearities.py +++ b/nifty/library/nonlinearities.py @@ -1,5 +1,4 @@ from numpy import logical_and, where - from ... import Field, exp, tanh diff --git a/nifty/library/response_operators.py b/nifty/library/response_operators.py index afcbebc05..dbf77fc1e 100644 --- a/nifty/library/response_operators.py +++ b/nifty/library/response_operators.py @@ -2,33 +2,6 @@ from .. import exp from ..operators.linear_operator import LinearOperator -class AdjointFFTResponse(LinearOperator): - def __init__(self, FFT, R, default_spaces=None): - super(AdjointFFTResponse, self).__init__(default_spaces) - self._domain = FFT.target - self._target = R.target - self.R = R - self.FFT = FFT - - def _times(self, x, spaces=None): - return self.R(self.FFT.adjoint_times(x)) - - def _adjoint_times(self, x, spaces=None): - return self.FFT(self.R.adjoint_times(x)) - - @property - def domain(self): - return self._domain - - @property - def target(self): - return self._target - - @property - def unitary(self): - return False - - class LinearizedSignalResponse(LinearOperator): def __init__(self, Instrument, nonlinearity, FFT, power, m, default_spaces=None): super(LinearizedSignalResponse, self).__init__(default_spaces) @@ -94,32 +67,3 @@ class LinearizedPowerResponse(LinearOperator): @property def unitary(self): return False - - -class SignalResponse(LinearOperator): - def __init__(self, t, FFT, R, default_spaces=None): - super(SignalResponse, self).__init__(default_spaces) - self._domain = FFT.target - self._target = R.target - self.power = exp(t).power_synthesize( - mean=1, std=0, real_signal=False) - self.R = R - self.FFT = FFT - - def _times(self, x, spaces=None): - return self.R(self.FFT.adjoint_times(self.power * x)) - - def _adjoint_times(self, x, spaces=None): - return self.power * self.FFT(self.R.adjoint_times(x)) - - @property - def domain(self): - return self._domain - - @property - def target(self): - return self._target - - @property - def unitary(self): - return False diff --git a/nifty/plotting/plot.py b/nifty/plotting/plot.py index d0a9921eb..3c3a16b0b 100644 --- a/nifty/plotting/plot.py +++ b/nifty/plotting/plot.py @@ -42,7 +42,7 @@ def _find_closest(A, target): return idx -def _makeplot(name): +def _mpl_makeplot(name): import matplotlib.pyplot as plt if dobj.rank != 0: return @@ -70,7 +70,7 @@ def _makeplot(name): raise ValueError("file format not understood") -def _limit_xy(**kwargs): +def _mpl_limit_xy(**kwargs): import matplotlib.pyplot as plt x1, x2, y1, y2 = plt.axis() x1 = _get_kw("xmin", x1, **kwargs) @@ -145,12 +145,11 @@ def _register_cmaps(): def _get_kw(kwname, kwdefault=None, **kwargs): - if kwargs.get(kwname) is not None: - return kwargs.get(kwname) - return kwdefault + res = kwargs.get(kwname) + return kwdefault if res is None else res -def plot(f, **kwargs): +def _mpl_plot(f, **kwargs): import matplotlib.pyplot as plt _register_cmaps() if not isinstance(f, Field): @@ -176,8 +175,8 @@ def plot(f, **kwargs): xcoord = np.arange(npoints, dtype=np.float64)*dist ycoord = dobj.to_global_data(f.val) plt.plot(xcoord, ycoord) - _limit_xy(**kwargs) - _makeplot(kwargs.get("name")) + _mpl_limit_xy(**kwargs) + _mpl_makeplot(kwargs.get("name")) return elif len(dom.shape) == 2: nx = dom.shape[0] @@ -195,8 +194,8 @@ def plot(f, **kwargs): # cax = divider.append_axes("right", size="5%", pad=0.05) # plt.colorbar(im,cax=cax) plt.colorbar(im) - _limit_xy(**kwargs) - _makeplot(kwargs.get("name")) + _mpl_limit_xy(**kwargs) + _mpl_makeplot(kwargs.get("name")) return elif isinstance(dom, PowerSpace): xcoord = dom.k_lengths @@ -205,8 +204,8 @@ def plot(f, **kwargs): plt.yscale('log') plt.title('power') plt.plot(xcoord, ycoord) - _limit_xy(**kwargs) - _makeplot(kwargs.get("name")) + _mpl_limit_xy(**kwargs) + _mpl_makeplot(kwargs.get("name")) return elif isinstance(dom, HPSpace): import pyHealpix @@ -222,7 +221,7 @@ def plot(f, **kwargs): plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") plt.colorbar(orientation="horizontal") - _makeplot(kwargs.get("name")) + _mpl_makeplot(kwargs.get("name")) return elif isinstance(dom, GLSpace): import pyHealpix @@ -239,7 +238,110 @@ def plot(f, **kwargs): plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") plt.colorbar(orientation="horizontal") - _makeplot(kwargs.get("name")) + _mpl_makeplot(kwargs.get("name")) return raise ValueError("Field type not(yet) supported") + + +def _plotly_plot(f, **kwargs): + if not isinstance(f, Field): + raise TypeError("incorrect data type") + if len(f.domain) != 1: + raise ValueError("input field must have exactly one domain") + + dom = f.domain[0] + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + + xsize = _get_kw("xsize", 6, **kwargs) + ysize = _get_kw("ysize", 6, **kwargs) + fig.set_size_inches(xsize, ysize) + ax.set_title(_get_kw("title", "", **kwargs)) + ax.set_xlabel(_get_kw("xlabel", "", **kwargs)) + ax.set_ylabel(_get_kw("ylabel", "", **kwargs)) + cmap = _get_kw("colormap", plt.rcParams['image.cmap'], **kwargs) + if isinstance(dom, RGSpace): + if len(dom.shape) == 1: + npoints = dom.shape[0] + dist = dom.distances[0] + xcoord = np.arange(npoints, dtype=np.float64)*dist + ycoord = dobj.to_global_data(f.val) + plt.plot(xcoord, ycoord) + _mpl_limit_xy(**kwargs) + _mpl_makeplot(kwargs.get("name")) + return + elif len(dom.shape) == 2: + nx = dom.shape[0] + ny = dom.shape[1] + dx = dom.distances[0] + dy = dom.distances[1] + xc = np.arange(nx, dtype=np.float64)*dx + yc = np.arange(ny, dtype=np.float64)*dy + im = ax.imshow(dobj.to_global_data(f.val), + extent=[xc[0], xc[-1], yc[0], yc[-1]], + vmin=kwargs.get("zmin"), + vmax=kwargs.get("zmax"), cmap=cmap, origin="lower") + # from mpl_toolkits.axes_grid1 import make_axes_locatable + # divider = make_axes_locatable(ax) + # cax = divider.append_axes("right", size="5%", pad=0.05) + # plt.colorbar(im,cax=cax) + plt.colorbar(im) + _mpl_limit_xy(**kwargs) + _mpl_makeplot(kwargs.get("name")) + return + elif isinstance(dom, PowerSpace): + xcoord = dom.k_lengths + ycoord = dobj.to_global_data(f.val) + plt.xscale('log') + plt.yscale('log') + plt.title('power') + plt.plot(xcoord, ycoord) + _mpl_limit_xy(**kwargs) + _mpl_makeplot(kwargs.get("name")) + return + elif isinstance(dom, HPSpace): + import pyHealpix + xsize = 800 + res, mask, theta, phi = _mollweide_helper(xsize) + + ptg = np.empty((phi.size, 2), dtype=np.float64) + ptg[:, 0] = theta + ptg[:, 1] = phi + base = pyHealpix.Healpix_Base(int(np.sqrt(f.val.size//12)), "RING") + res[mask] = dobj.to_global_data(f.val)[base.ang2pix(ptg)] + plt.axis('off') + plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), + cmap=cmap, origin="lower") + plt.colorbar(orientation="horizontal") + _mpl_makeplot(kwargs.get("name")) + return + elif isinstance(dom, GLSpace): + import pyHealpix + xsize = 800 + res, mask, theta, phi = _mollweide_helper(xsize) + ra = np.linspace(0, 2*np.pi, dom.nlon+1) + dec = pyHealpix.GL_thetas(dom.nlat) + ilat = _find_closest(dec, theta) + ilon = _find_closest(ra, phi) + ilon = np.where(ilon == dom.nlon, 0, ilon) + res[mask] = dobj.to_global_data(f.val)[ilat*dom.nlon + ilon] + + plt.axis('off') + plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), + cmap=cmap, origin="lower") + plt.colorbar(orientation="horizontal") + _mpl_makeplot(kwargs.get("name")) + return + + raise ValueError("Field type not(yet) supported") + + +def plot(f, **kwargs): + extension = os.path.splitext(kwargs.get("name"))[1] + if extension in [".html"]: + _plotly_plot(f, **kwargs) + elif extension in [".pdf", ".png"]: + _mpl_plot(f, **kwargs) + else: + raise ValueError("unknown file name extension: " + extension) diff --git a/setup.py b/setup.py index 4ddc06ca0..5fa9d4fc4 100644 --- a/setup.py +++ b/setup.py @@ -30,10 +30,10 @@ setup(name="nifty2go", packages=["nifty2go"] + ["nifty2go."+p for p in find_packages("nifty")], zip_safe=False, dependency_links=[ - 'git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git@setuptools_test#egg=pyHealpix-0.0.1'], + 'git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git#egg=pyHealpix-0.0.1'], license="GPLv3", - setup_requires=['future', 'pyHealpix>=0.0.1', 'numpy', 'pyfftw>=0.10.4'], - install_requires=['future', 'pyHealpix>=0.0.1', 'numpy', 'pyfftw>=0.10.4'], + setup_requires=['future', 'numpy'], + install_requires=['future', 'numpy'], classifiers=[ "Development Status :: 4 - Beta", "Topic :: Utilities", -- GitLab