Commit 8122ea7a authored by Martin Reinecke's avatar Martin Reinecke

cleanup

parent 1fc0eeed
Pipeline #22292 failed with stage
in 4 minutes
use_nifty2go = True
import numpy as np
import nifty2go as ift
if use_nifty2go:
import nifty2go as ift
else:
import nifty as ift
import numericalunits as nu
if __name__ == "__main__":
# In MPI mode, the random seed for numericalunits must be set by hand
if not use_nifty2go:
ift.nifty_configuration['default_distribution_strategy'] = 'fftw'
ift.nifty_configuration['harmonic_rg_base'] = 'real'
nu.reset_units(43)
dimensionality = 2
np.random.seed(43)
......@@ -32,11 +40,14 @@ if __name__ == "__main__":
# Total side-length of the domain
L = 2.*nu.m
# Grid resolution (pixels per axis)
N_pixels = 512
N_pixels = 4096
shape = [N_pixels]*dimensionality
signal_space = ift.RGSpace(shape, distances=L/N_pixels)
if use_nifty2go:
harmonic_space = signal_space.get_default_codomain()
else:
harmonic_space = ift.FFTOperator.get_default_codomain(signal_space)
fft = ift.FFTOperator(harmonic_space, target=signal_space)
power_space = ift.PowerSpace(harmonic_space)
......@@ -45,8 +56,12 @@ if __name__ == "__main__":
power_spectrum=power_spectrum)
np.random.seed(43)
if use_nifty2go:
mock_power = ift.PS_field(power_space, power_spectrum)
mock_harmonic = ift.power_synthesize(mock_power, real_signal=True)
else:
mock_power = ift.Field(power_space, val=power_spectrum)
mock_harmonic = mock_power.power_synthesize(real_signal=True)
mock_harmonic = mock_harmonic.real
mock_signal = fft(mock_harmonic)
......@@ -54,11 +69,19 @@ if __name__ == "__main__":
R = ift.ResponseOperator(signal_space, sigma=(response_sigma,),
exposure=(exposure,))
data_domain = R.target[0]
if use_nifty2go:
R_harmonic = ift.ComposedOperator([fft, R])
else:
R_harmonic = ift.ComposedOperator([fft, R], default_spaces=[0, 0])
if use_nifty2go:
N = ift.DiagonalOperator(
ift.Field.full(data_domain,
mock_signal.var()/signal_to_noise).weight(1))
else:
ndiag = ift.Field(data_domain, mock_signal.var()/signal_to_noise).weight(1)
N = ift.DiagonalOperator(data_domain, ndiag)
noise = ift.Field.from_random(
domain=data_domain, random_type='normal',
std=mock_signal.std()/np.sqrt(signal_to_noise), mean=0)
......@@ -67,12 +90,23 @@ if __name__ == "__main__":
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
if use_nifty2go:
ctrl = ift.GradientNormController(
verbose=True, iteration_limit=10, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality)))
else:
def print_stats(a_energy, iteration): # returns current energy
x = a_energy.value
print(x, iteration)
ctrl = ift.GradientNormController(
verbose=True, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality)))
callback=print_stats, iteration_limit=10, tol_abs_gradnorm=1e-4/nu.K/(nu.m**(0.5*dimensionality)))
inverter = ift.ConjugateGradient(controller=ctrl)
if use_nifty2go:
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N,
R=R_harmonic)
inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.InversionEnabler(wiener_curvature, inverter)
else:
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R_harmonic, inverter=inverter)
m = wiener_curvature.inverse_times(j)
m_s = fft(m)
......
from ..operators.endomorphic_operator import EndomorphicOperator
class NonlinearSignalCurvature(EndomorphicOperator):
def __init__(self, R, N, S, inverter=None):
self.R = R
self.N = N
self.S = S
# if preconditioner is None:
# preconditioner = self.S.times
self._domain = self.S.domain
super(NonlinearSignalCurvature, self).__init__(inverter=inverter)
@property
def domain(self):
return self._domain
@property
def self_adjoint(self):
return True
@property
def unitary(self):
return False
# ---Added properties and methods---
def _times(self, x, spaces):
return self.R.adjoint_times(self.N.inverse_times(self.R(x))) + self.S.inverse_times(x)
from .nonlinear_signal_curvature import NonlinearSignalCurvature
from .wiener_filter_curvature import WienerFilterCurvature
from .. import Field, exp
from ..utilities import memo
from ..sugar import generate_posterior_sample
......@@ -8,33 +8,17 @@ from .response_operators import LinearizedSignalResponse
class NonlinearWienerFilterEnergy(Energy):
"""The Energy for the Gaussian lognormal case.
It describes the situation of linear measurement of a
lognormal signal with Gaussian noise and Gaussain signal prior.
Parameters
----------
d : Field,
the data.
R : Operator,
The nonlinear response operator, describtion of the measurement process.
N : EndomorphicOperator,
The noise covariance in data space.
S : EndomorphicOperator,
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S, inverter=None):
def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S,
inverter=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position)
# print "init", position.norm()
self.d = d
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.FFT = FFT
self.power = power
self.LinearizedResponse = LinearizedSignalResponse(Instrument, nonlinearity,
FFT, power, self.position)
self.LinearizedResponse = \
LinearizedSignalResponse(Instrument, nonlinearity, FFT, power,
self.position)
position_map = FFT.adjoint_times(self.power * self.position)
# position_map = (Field(FFT.domain,val=position_map.val.real+0j))
......@@ -68,7 +52,6 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def curvature(self):
curvature = NonlinearSignalCurvature(R=self.LinearizedResponse,
N=self.N,
S=self.S, inverter=self.inverter)
curvature = WienerFilterCurvature(R=self.LinearizedResponse,
N=self.N, S=self.S)
return InversionEnabler(curvature, inverter=self.inverter)
from numpy import logical_and, where
from ... import Field, exp, tanh
......
......@@ -2,33 +2,6 @@ from .. import exp
from ..operators.linear_operator import LinearOperator
class AdjointFFTResponse(LinearOperator):
def __init__(self, FFT, R, default_spaces=None):
super(AdjointFFTResponse, self).__init__(default_spaces)
self._domain = FFT.target
self._target = R.target
self.R = R
self.FFT = FFT
def _times(self, x, spaces=None):
return self.R(self.FFT.adjoint_times(x))
def _adjoint_times(self, x, spaces=None):
return self.FFT(self.R.adjoint_times(x))
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def unitary(self):
return False
class LinearizedSignalResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, power, m, default_spaces=None):
super(LinearizedSignalResponse, self).__init__(default_spaces)
......@@ -94,32 +67,3 @@ class LinearizedPowerResponse(LinearOperator):
@property
def unitary(self):
return False
class SignalResponse(LinearOperator):
def __init__(self, t, FFT, R, default_spaces=None):
super(SignalResponse, self).__init__(default_spaces)
self._domain = FFT.target
self._target = R.target
self.power = exp(t).power_synthesize(
mean=1, std=0, real_signal=False)
self.R = R
self.FFT = FFT
def _times(self, x, spaces=None):
return self.R(self.FFT.adjoint_times(self.power * x))
def _adjoint_times(self, x, spaces=None):
return self.power * self.FFT(self.R.adjoint_times(x))
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def unitary(self):
return False
......@@ -42,7 +42,7 @@ def _find_closest(A, target):
return idx
def _makeplot(name):
def _mpl_makeplot(name):
import matplotlib.pyplot as plt
if dobj.rank != 0:
return
......@@ -70,7 +70,7 @@ def _makeplot(name):
raise ValueError("file format not understood")
def _limit_xy(**kwargs):
def _mpl_limit_xy(**kwargs):
import matplotlib.pyplot as plt
x1, x2, y1, y2 = plt.axis()
x1 = _get_kw("xmin", x1, **kwargs)
......@@ -145,12 +145,11 @@ def _register_cmaps():
def _get_kw(kwname, kwdefault=None, **kwargs):
if kwargs.get(kwname) is not None:
return kwargs.get(kwname)
return kwdefault
res = kwargs.get(kwname)
return kwdefault if res is None else res
def plot(f, **kwargs):
def _mpl_plot(f, **kwargs):
import matplotlib.pyplot as plt
_register_cmaps()
if not isinstance(f, Field):
......@@ -176,8 +175,8 @@ def plot(f, **kwargs):
xcoord = np.arange(npoints, dtype=np.float64)*dist
ycoord = dobj.to_global_data(f.val)
plt.plot(xcoord, ycoord)
_limit_xy(**kwargs)
_makeplot(kwargs.get("name"))
_mpl_limit_xy(**kwargs)
_mpl_makeplot(kwargs.get("name"))
return
elif len(dom.shape) == 2:
nx = dom.shape[0]
......@@ -195,8 +194,8 @@ def plot(f, **kwargs):
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im,cax=cax)
plt.colorbar(im)
_limit_xy(**kwargs)
_makeplot(kwargs.get("name"))
_mpl_limit_xy(**kwargs)
_mpl_makeplot(kwargs.get("name"))
return
elif isinstance(dom, PowerSpace):
xcoord = dom.k_lengths
......@@ -205,8 +204,8 @@ def plot(f, **kwargs):
plt.yscale('log')
plt.title('power')
plt.plot(xcoord, ycoord)
_limit_xy(**kwargs)
_makeplot(kwargs.get("name"))
_mpl_limit_xy(**kwargs)
_mpl_makeplot(kwargs.get("name"))
return
elif isinstance(dom, HPSpace):
import pyHealpix
......@@ -222,7 +221,7 @@ def plot(f, **kwargs):
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
_makeplot(kwargs.get("name"))
_mpl_makeplot(kwargs.get("name"))
return
elif isinstance(dom, GLSpace):
import pyHealpix
......@@ -239,7 +238,110 @@ def plot(f, **kwargs):
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
_makeplot(kwargs.get("name"))
_mpl_makeplot(kwargs.get("name"))
return
raise ValueError("Field type not(yet) supported")
def _plotly_plot(f, **kwargs):
if not isinstance(f, Field):
raise TypeError("incorrect data type")
if len(f.domain) != 1:
raise ValueError("input field must have exactly one domain")
dom = f.domain[0]
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
xsize = _get_kw("xsize", 6, **kwargs)
ysize = _get_kw("ysize", 6, **kwargs)
fig.set_size_inches(xsize, ysize)
ax.set_title(_get_kw("title", "", **kwargs))
ax.set_xlabel(_get_kw("xlabel", "", **kwargs))
ax.set_ylabel(_get_kw("ylabel", "", **kwargs))
cmap = _get_kw("colormap", plt.rcParams['image.cmap'], **kwargs)
if isinstance(dom, RGSpace):
if len(dom.shape) == 1:
npoints = dom.shape[0]
dist = dom.distances[0]
xcoord = np.arange(npoints, dtype=np.float64)*dist
ycoord = dobj.to_global_data(f.val)
plt.plot(xcoord, ycoord)
_mpl_limit_xy(**kwargs)
_mpl_makeplot(kwargs.get("name"))
return
elif len(dom.shape) == 2:
nx = dom.shape[0]
ny = dom.shape[1]
dx = dom.distances[0]
dy = dom.distances[1]
xc = np.arange(nx, dtype=np.float64)*dx
yc = np.arange(ny, dtype=np.float64)*dy
im = ax.imshow(dobj.to_global_data(f.val),
extent=[xc[0], xc[-1], yc[0], yc[-1]],
vmin=kwargs.get("zmin"),
vmax=kwargs.get("zmax"), cmap=cmap, origin="lower")
# from mpl_toolkits.axes_grid1 import make_axes_locatable
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im,cax=cax)
plt.colorbar(im)
_mpl_limit_xy(**kwargs)
_mpl_makeplot(kwargs.get("name"))
return
elif isinstance(dom, PowerSpace):
xcoord = dom.k_lengths
ycoord = dobj.to_global_data(f.val)
plt.xscale('log')
plt.yscale('log')
plt.title('power')
plt.plot(xcoord, ycoord)
_mpl_limit_xy(**kwargs)
_mpl_makeplot(kwargs.get("name"))
return
elif isinstance(dom, HPSpace):
import pyHealpix
xsize = 800
res, mask, theta, phi = _mollweide_helper(xsize)
ptg = np.empty((phi.size, 2), dtype=np.float64)
ptg[:, 0] = theta
ptg[:, 1] = phi
base = pyHealpix.Healpix_Base(int(np.sqrt(f.val.size//12)), "RING")
res[mask] = dobj.to_global_data(f.val)[base.ang2pix(ptg)]
plt.axis('off')
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
_mpl_makeplot(kwargs.get("name"))
return
elif isinstance(dom, GLSpace):
import pyHealpix
xsize = 800
res, mask, theta, phi = _mollweide_helper(xsize)
ra = np.linspace(0, 2*np.pi, dom.nlon+1)
dec = pyHealpix.GL_thetas(dom.nlat)
ilat = _find_closest(dec, theta)
ilon = _find_closest(ra, phi)
ilon = np.where(ilon == dom.nlon, 0, ilon)
res[mask] = dobj.to_global_data(f.val)[ilat*dom.nlon + ilon]
plt.axis('off')
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
_mpl_makeplot(kwargs.get("name"))
return
raise ValueError("Field type not(yet) supported")
def plot(f, **kwargs):
extension = os.path.splitext(kwargs.get("name"))[1]
if extension in [".html"]:
_plotly_plot(f, **kwargs)
elif extension in [".pdf", ".png"]:
_mpl_plot(f, **kwargs)
else:
raise ValueError("unknown file name extension: " + extension)
......@@ -30,10 +30,10 @@ setup(name="nifty2go",
packages=["nifty2go"] + ["nifty2go."+p for p in find_packages("nifty")],
zip_safe=False,
dependency_links=[
'git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git@setuptools_test#egg=pyHealpix-0.0.1'],
'git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git#egg=pyHealpix-0.0.1'],
license="GPLv3",
setup_requires=['future', 'pyHealpix>=0.0.1', 'numpy', 'pyfftw>=0.10.4'],
install_requires=['future', 'pyHealpix>=0.0.1', 'numpy', 'pyfftw>=0.10.4'],
setup_requires=['future', 'numpy'],
install_requires=['future', 'numpy'],
classifiers=[
"Development Status :: 4 - Beta",
"Topic :: Utilities",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment