Commit 55458cb3 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'various_fixes' into 'NIFTy_5'

Various fixes

See merge request ift/nifty-dev!173
parents c14efdc6 ee63db3a
...@@ -26,33 +26,33 @@ from ..sugar import makeOp ...@@ -26,33 +26,33 @@ from ..sugar import makeOp
class InverseGammaOperator(Operator): class InverseGammaOperator(Operator):
"""Operator which transforms a Gaussian into an inverse gamma distribution.
The pdf of the inverse gamma distribution is defined as follows:
.. math ::
\\frac{\\beta^\\alpha}{\\Gamma(\\alpha)}x^{-\\alpha -1}\\exp \\left(-\\frac{\\beta }{x}\\right)
That means that for large x the pdf falls off like x^(-alpha -1).
The mean of the pdf is at q / (alpha - 1) if alpha > 1.
The mode is q / (alpha + 1).
This transformation is implemented as a linear interpolation which maps a
Gaussian onto a inverse gamma distribution.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
The domain on which the field shall be defined. This is at the same
time the domain and the target of the operator.
alpha : float
The alpha-parameter of the inverse-gamma distribution.
q : float
The q-parameter of the inverse-gamma distribution.
delta : float
distance between sampling points for linear interpolation.
"""
def __init__(self, domain, alpha, q, delta=0.001): def __init__(self, domain, alpha, q, delta=0.001):
"""Operator which transforms a Gaussian into an inverse gamma distribution.
The pdf of the inverse gamma distribution is defined as follows:
.. math::
\frac {\beta ^{\alpha }}{\Gamma (\alpha )}}x^{-\alpha -1}\exp \left(-{\frac {\beta }{x}}\right)
That means that for large x the pdf falls off like x^(-alpha -1).
The mean of the pdf is at q / (alpha - 1) if alpha > 1.
The mode is q / (alpha + 1).
This transformation is implemented as a linear interpolation which
maps a Gaussian onto a inverse gamma distribution.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
The domain on which the field shall be defined. This is at the same
time the domain and the target of the operator.
alpha : float
The alpha-parameter of the inverse-gamma distribution.
q : float
The q-parameter of the inverse-gamma distribution.
delta : float
distance between sampling points for linear interpolation.
"""
self._domain = self._target = DomainTuple.make(domain) self._domain = self._target = DomainTuple.make(domain)
self._alpha, self._q, self._delta = float(alpha), float(q), float(delta) self._alpha, self._q, self._delta = float(alpha), float(q), float(delta)
self._xmin, self._xmax = -8.2, 8.2 self._xmin, self._xmax = -8.2, 8.2
......
...@@ -47,9 +47,9 @@ def _make_coords(domain, absolute=False): ...@@ -47,9 +47,9 @@ def _make_coords(domain, absolute=False):
return k_array return k_array
class LightConeDerivative(LinearOperator): class _LightConeDerivative(LinearOperator):
def __init__(self, domain, target, derivatives): def __init__(self, domain, target, derivatives):
super(LightConeDerivative, self).__init__() super(_LightConeDerivative, self).__init__()
self._domain = domain self._domain = domain
self._target = target self._target = target
self._derivatives = derivatives self._derivatives = derivatives
...@@ -67,7 +67,7 @@ class LightConeDerivative(LinearOperator): ...@@ -67,7 +67,7 @@ class LightConeDerivative(LinearOperator):
return Field.from_global_data(self._tgt(mode), res) return Field.from_global_data(self._tgt(mode), res)
def cone_arrays(c, domain, sigx, want_gradient): def _cone_arrays(c, domain, sigx, want_gradient):
x = _make_coords(domain) x = _make_coords(domain)
a = np.zeros(domain.shape, dtype=np.complex) a = np.zeros(domain.shape, dtype=np.complex)
if want_gradient: if want_gradient:
...@@ -96,6 +96,9 @@ def cone_arrays(c, domain, sigx, want_gradient): ...@@ -96,6 +96,9 @@ def cone_arrays(c, domain, sigx, want_gradient):
class LightConeOperator(Operator): class LightConeOperator(Operator):
'''
FIXME
'''
def __init__(self, domain, target, sigx): def __init__(self, domain, target, sigx):
self._domain = domain self._domain = domain
self._target = target self._target = target
...@@ -104,9 +107,9 @@ class LightConeOperator(Operator): ...@@ -104,9 +107,9 @@ class LightConeOperator(Operator):
def apply(self, x): def apply(self, x):
islin = isinstance(x, Linearization) islin = isinstance(x, Linearization)
val = x.val.to_global_data() if islin else x.to_global_data() val = x.val.to_global_data() if islin else x.to_global_data()
a, derivs = cone_arrays(val, self.target, self._sigx, islin) a, derivs = _cone_arrays(val, self.target, self._sigx, islin)
res = Field.from_global_data(self.target, a) res = Field.from_global_data(self.target, a)
if not islin: if not islin:
return res return res
jac = LightConeDerivative(x.jac.target, self.target, derivs)(x.jac) jac = _LightConeDerivative(x.jac.target, self.target, derivs)(x.jac)
return Linearization(res, jac, want_metric=x.want_metric) return Linearization(res, jac, want_metric=x.want_metric)
...@@ -112,7 +112,7 @@ def CepstrumOperator(target, a, k0): ...@@ -112,7 +112,7 @@ def CepstrumOperator(target, a, k0):
return sym @ qht @ makeOp(cepstrum.sqrt()) return sym @ qht @ makeOp(cepstrum.sqrt())
def SLAmplitude(target, n_pix, a, k0, sm, sv, im, iv, keys=['tau', 'phi']): def SLAmplitude(*, target, n_pix, a, k0, sm, sv, im, iv, keys=['tau', 'phi']):
'''Operator for parametrizing smooth amplitudes (square roots of power '''Operator for parametrizing smooth amplitudes (square roots of power
spectra). spectra).
......
...@@ -24,6 +24,8 @@ from .domains.gl_space import GLSpace ...@@ -24,6 +24,8 @@ from .domains.gl_space import GLSpace
from .domains.hp_space import HPSpace from .domains.hp_space import HPSpace
from .domains.power_space import PowerSpace from .domains.power_space import PowerSpace
from .domains.rg_space import RGSpace from .domains.rg_space import RGSpace
from .domains.log_rg_space import LogRGSpace
from .domain_tuple import DomainTuple
from .field import Field from .field import Field
# relevant properties: # relevant properties:
...@@ -152,26 +154,20 @@ def _register_cmaps(): ...@@ -152,26 +154,20 @@ def _register_cmaps():
plt.register_cmap(cmap=LinearSegmentedColormap("Plus Minus", pm_cmap)) plt.register_cmap(cmap=LinearSegmentedColormap("Plus Minus", pm_cmap))
def _plot(f, ax, **kwargs): def _plot1D(f, ax, **kwargs):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
_register_cmaps()
if isinstance(f, Field):
f = [f]
if not isinstance(f, list):
raise TypeError("incorrect data type")
for i, fld in enumerate(f): for i, fld in enumerate(f):
if not isinstance(fld, Field): if not isinstance(fld, Field):
raise TypeError("incorrect data type") raise TypeError("incorrect data type")
if i == 0: if i == 0:
dom = fld.domain dom = fld.domain
if len(dom) != 1: if (len(dom) != 1):
raise ValueError("input field must have exactly one domain") raise ValueError("input field must have exactly one domain")
else: else:
if fld.domain != dom: if fld.domain != dom:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
if not (isinstance(dom[0], PowerSpace) or dom = dom[0]
(isinstance(dom[0], RGSpace) and len(dom[0].shape) == 1)):
raise ValueError("PowerSpace or 1D RGSpace required")
label = kwargs.pop("label", None) label = kwargs.pop("label", None)
if not isinstance(label, list): if not isinstance(label, list):
...@@ -185,44 +181,38 @@ def _plot(f, ax, **kwargs): ...@@ -185,44 +181,38 @@ def _plot(f, ax, **kwargs):
if not isinstance(alpha, list): if not isinstance(alpha, list):
alpha = [alpha] * len(f) alpha = [alpha] * len(f)
foo = kwargs.pop("norm", None)
norm = {} if foo is None else {'norm': foo}
dom = dom[0]
ax.set_title(kwargs.pop("title", "")) ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", "")) ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", "")) ax.set_ylabel(kwargs.pop("ylabel", ""))
cmap = kwargs.pop("colormap", plt.rcParams['image.cmap'])
if isinstance(dom, RGSpace): if isinstance(dom, RGSpace):
if len(dom.shape) == 1: plt.yscale(kwargs.pop("yscale", "linear"))
npoints = dom.shape[0] npoints = dom.shape[0]
dist = dom.distances[0] dist = dom.distances[0]
xcoord = np.arange(npoints, dtype=np.float64)*dist xcoord = np.arange(npoints, dtype=np.float64)*dist
for i, fld in enumerate(f): for i, fld in enumerate(f):
ycoord = fld.to_global_data() ycoord = fld.to_global_data()
plt.plot(xcoord, ycoord, label=label[i], plt.plot(xcoord, ycoord, label=label[i],
linewidth=linewidth[i], alpha=alpha[i]) linewidth=linewidth[i], alpha=alpha[i])
_limit_xy(**kwargs) _limit_xy(**kwargs)
if label != ([None]*len(f)): if label != ([None]*len(f)):
plt.legend() plt.legend()
return return
elif len(dom.shape) == 2: elif isinstance(dom, LogRGSpace):
nx, ny = dom.shape plt.yscale(kwargs.pop("yscale", "log"))
dx, dy = dom.distances npoints = dom.shape[0]
im = ax.imshow( xcoord = dom.t_0 + np.arange(npoints-1)*dom.bindistances[0]
f[0].to_global_data().T, extent=[0, nx*dx, 0, ny*dy], for i, fld in enumerate(f):
vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), ycoord = fld.to_global_data()[1:]
cmap=cmap, origin="lower", **norm) plt.plot(xcoord, ycoord, label=label[i],
# from mpl_toolkits.axes_grid1 import make_axes_locatable linewidth=linewidth[i], alpha=alpha[i])
# divider = make_axes_locatable(ax) _limit_xy(**kwargs)
# cax = divider.append_axes("right", size="5%", pad=0.05) if label != ([None]*len(f)):
# plt.colorbar(im,cax=cax) plt.legend()
plt.colorbar(im) return
_limit_xy(**kwargs)
return
elif isinstance(dom, PowerSpace): elif isinstance(dom, PowerSpace):
plt.xscale('log') plt.xscale(kwargs.pop("xscale", "log"))
plt.yscale('log') plt.yscale(kwargs.pop("yscale", "log"))
xcoord = dom.k_lengths xcoord = dom.k_lengths
for i, fld in enumerate(f): for i, fld in enumerate(f):
ycoord = fld.to_global_data() ycoord = fld.to_global_data()
...@@ -232,6 +222,38 @@ def _plot(f, ax, **kwargs): ...@@ -232,6 +222,38 @@ def _plot(f, ax, **kwargs):
if label != ([None]*len(f)): if label != ([None]*len(f)):
plt.legend() plt.legend()
return return
raise ValueError("Field type not(yet) supported")
def _plot2D(f, ax, **kwargs):
import matplotlib.pyplot as plt
dom = f.domain
if len(dom) > 1:
raise ValueError("DomainTuple must have exactly one entry.")
label = kwargs.pop("label", None)
foo = kwargs.pop("norm", None)
norm = {} if foo is None else {'norm': foo}
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
dom = dom[0]
cmap = kwargs.pop("colormap", plt.rcParams['image.cmap'])
if isinstance(dom, RGSpace):
nx, ny = dom.shape
dx, dy = dom.distances
im = ax.imshow(
f.to_global_data().T, extent=[0, nx*dx, 0, ny*dy],
vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower", **norm)
plt.colorbar(im)
_limit_xy(**kwargs)
return
elif isinstance(dom, (HPSpace, GLSpace)): elif isinstance(dom, (HPSpace, GLSpace)):
import pyHealpix import pyHealpix
xsize = 800 xsize = 800
...@@ -240,21 +262,44 @@ def _plot(f, ax, **kwargs): ...@@ -240,21 +262,44 @@ def _plot(f, ax, **kwargs):
ptg = np.empty((phi.size, 2), dtype=np.float64) ptg = np.empty((phi.size, 2), dtype=np.float64)
ptg[:, 0] = theta ptg[:, 0] = theta
ptg[:, 1] = phi ptg[:, 1] = phi
base = pyHealpix.Healpix_Base(int(np.sqrt(f[0].size//12)), "RING") base = pyHealpix.Healpix_Base(int(np.sqrt(dom.size//12)), "RING")
res[mask] = f[0].to_global_data()[base.ang2pix(ptg)] res[mask] = f.to_global_data()[base.ang2pix(ptg)]
else: else:
ra = np.linspace(0, 2*np.pi, dom.nlon+1) ra = np.linspace(0, 2*np.pi, dom.nlon+1)
dec = pyHealpix.GL_thetas(dom.nlat) dec = pyHealpix.GL_thetas(dom.nlat)
ilat = _find_closest(dec, theta) ilat = _find_closest(dec, theta)
ilon = _find_closest(ra, phi) ilon = _find_closest(ra, phi)
ilon = np.where(ilon == dom.nlon, 0, ilon) ilon = np.where(ilon == dom.nlon, 0, ilon)
res[mask] = f[0].to_global_data()[ilat*dom.nlon + ilon] res[mask] = f.to_global_data()[ilat*dom.nlon + ilon]
plt.axis('off') plt.axis('off')
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"), plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower") cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal") plt.colorbar(orientation="horizontal")
return return
raise ValueError("Field type not(yet) supported")
def _plot(f, ax, **kwargs):
_register_cmaps()
if isinstance(f, Field):
f = [f]
f = list(f)
if len(f) == 0:
raise ValueError("need something to plot")
if not isinstance(f[0], Field):
raise TypeError("incorrect data type")
dom1 = f[0].domain
if (len(dom1)==1 and
(isinstance(dom1[0],PowerSpace) or
(isinstance(dom1[0], (RGSpace, LogRGSpace)) and
len(dom1[0].shape) == 1))):
_plot1D(f, ax, **kwargs)
return
else:
if len(f) != 1:
raise ValueError("need exactly one Field for 2D plot")
_plot2D(f[0], ax, **kwargs)
return
raise ValueError("Field type not(yet) supported") raise ValueError("Field type not(yet) supported")
......
...@@ -89,10 +89,10 @@ def testBinary(type1, type2, space, seed): ...@@ -89,10 +89,10 @@ def testBinary(type1, type2, space, seed):
def testModelLibrary(space, seed): def testModelLibrary(space, seed):
# Tests amplitude model and coorelated field model # Tests amplitude model and coorelated field model
Npixdof, ceps_a, ceps_k, sm, sv, im, iv = 4, 0.5, 2., 3., 1.5, 1.75, 1.3
np.random.seed(seed) np.random.seed(seed)
domain = ift.PowerSpace(space.get_default_codomain()) domain = ift.PowerSpace(space.get_default_codomain())
model = ift.SLAmplitude(domain, Npixdof, ceps_a, ceps_k, sm, sv, im, iv) model = ift.SLAmplitude(target=domain, n_pix=4, a=.5, k0=2, sm=3, sv=1.5,
im=1.75, iv=1.3)
assert_(isinstance(model, ift.Operator)) assert_(isinstance(model, ift.Operator))
S = ift.ScalingOperator(1., model.domain) S = ift.ScalingOperator(1., model.domain)
pos = S.draw_sample() pos = S.draw_sample()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment