Commit 2db0d555 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_5' into even_more_operator_work

parents 1cd93ef6 7c9500fa
......@@ -88,6 +88,7 @@ if __name__ == '__main__':
reconstruction = sky.at(H.position).value
ift.plot(reconstruction, title='reconstruction', name='reconstruction.png')
ift.plot(GR.adjoint_times(data), title='data', name='data.png')
ift.plot(sky.at(mock_position).value, title='truth', name='truth.png')
ift.plot(reconstruction, title='reconstruction')
ift.plot(GR.adjoint_times(data), title='data')
ift.plot(sky.at(mock_position).value, title='truth')
ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", name="bernoulli.png")
......@@ -46,7 +46,7 @@ if __name__ == '__main__':
# FIXME description of the tutorial
# Choose problem geometry and masking
mode = 0
mode = 1
if mode == 0:
# One dimensional regular grid
position_space = ift.RGSpace([1024])
......@@ -106,11 +106,14 @@ if __name__ == '__main__':
if rg and len(position_space.shape) == 1:
ift.plot([HT(MOCK_SIGNAL), GR.adjoint(data), HT(m)],
label=['Mock signal', 'Data', 'Reconstruction'],
alpha=[1, .3, 1],
name='getting_started_1.png')
alpha=[1, .3, 1])
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)))
ift.plot_finish(nx=2, ny=1, xsize=10, ysize=4,
title="getting_started_1")
else:
ift.plot(HT(MOCK_SIGNAL), title='Mock Signal', name='mock_signal.png')
ift.plot(mask_to_nan(mask, (GR*Mask).adjoint(data)),
title='Data', name='data.png')
ift.plot(HT(m), title='Reconstruction', name='reconstruction.png')
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), name='residuals.png')
ift.plot(HT(MOCK_SIGNAL), title='Mock Signal')
ift.plot(mask_to_nan(mask, (GR*Mask).adjoint(data)), title='Data')
ift.plot(HT(m), title='Reconstruction')
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)))
ift.plot_finish(nx=2, ny=2, xsize=10, ysize=10,
title="getting_started_1")
......@@ -83,13 +83,14 @@ if __name__ == '__main__':
INITIAL_POSITION = ift.from_random('normal', H.position.domain)
position = INITIAL_POSITION
ift.plot(signal.at(MOCK_POSITION).value, name='truth.png')
ift.plot(R.adjoint_times(data), name='data.png')
ift.plot([A.at(MOCK_POSITION).value], name='power.png')
ift.plot(signal.at(MOCK_POSITION).value, title='ground truth')
ift.plot(R.adjoint_times(data), title='data')
ift.plot([A.at(MOCK_POSITION).value], title='power')
ift.plot_finish(nx=3, xsize=16, ysize=5, title="setup", name="setup.png")
# number of samples used to estimate the KL
N_samples = 20
for i in range(5):
for i in range(2):
H = H.at(position)
samples = [H.metric.draw_sample(from_inverse=True)
for _ in range(N_samples)]
......@@ -99,17 +100,19 @@ if __name__ == '__main__':
KL, convergence = minimizer(KL)
position = KL.position
ift.plot(signal.at(position).value, name='reconstruction.png')
ift.plot(signal.at(position).value, title="reconstruction")
ift.plot([A.at(position).value, A.at(MOCK_POSITION).value],
name='power.png')
title="power")
ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png")
sc = ift.StatCalculator()
for sample in samples:
sc.add(signal.at(sample+position).value)
ift.plot(sc.mean, name='avrg.png')
ift.plot(ift.sqrt(sc.var), name='std.png')
ift.plot(sc.mean, title="mean")
ift.plot(ift.sqrt(sc.var), title="std deviation")
powers = [A.at(s+position).value for s in samples]
ift.plot([A.at(position).value, A.at(MOCK_POSITION).value]+powers,
name='power.png')
title="power")
ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", name="results.png")
import nifty5 as ift
import numpy as np
def plot_test():
rg_space1 = ift.makeDomain(ift.RGSpace((100,)))
rg_space2 = ift.makeDomain(ift.RGSpace((80, 80)))
hp_space = ift.makeDomain(ift.HPSpace(64))
gl_space = ift.makeDomain(ift.GLSpace(128))
fft = ift.FFTOperator(rg_space2)
field_rg1_1 = ift.Field.from_global_data(rg_space1, np.random.randn(100))
field_rg1_2 = ift.Field.from_global_data(rg_space1, np.random.randn(100))
field_rg2 = ift.Field.from_global_data(
rg_space2, np.random.randn(80 ** 2).reshape((80, 80)))
field_hp = ift.Field.from_global_data(hp_space, np.random.randn(12*64**2))
field_gl = ift.Field.from_global_data(gl_space, np.random.randn(32640))
field_ps = ift.power_analyze(fft.times(field_rg2))
## Start various plotting tests
ift.plot(field_rg1_1, title='Single plot')
ift.plot_finish()
ift.plot(field_rg2, title='2d rg')
ift.plot([field_rg1_1, field_rg1_2], title='list 1d rg', label=['1', '2'])
ift.plot(field_rg1_2, title='1d rg, xmin, ymin', xmin=0.5, ymin=0.,
xlabel='xmin=0.5', ylabel='ymin=0')
ift.plot_finish(title='Three plots')
ift.plot(field_hp, title='HP planck-color', colormap='Planck-like')
ift.plot(field_rg1_2, title='1d rg')
ift.plot(field_ps)
ift.plot(field_gl, title='GL')
ift.plot(field_rg2, title='2d rg')
ift.plot_finish(nx=2, ny=3, title='Five plots')
if __name__ == '__main__':
plot_test()
......@@ -75,7 +75,7 @@ from .minimization.line_energy import LineEnergy
from .minimization.energy_sum import EnergySum
from .sugar import *
from .plotting.plot import plot
from .plotting.plot import plot, plot_finish
from .library.amplitude_model import make_amplitude_model
from .library.gaussian_energy import GaussianEnergy
......
......@@ -27,6 +27,8 @@ class DomainTuple(object):
This class holds a tuple of :class:`Domain` objects, which together form
the space on which a :class:`Field` is defined.
This corresponds to a tensor product of the corresponding vector
fields.
Notes
-----
......
......@@ -74,11 +74,9 @@ class LogRGSpace(StructuredDomain):
% (self.shape, self.harmonic))
def get_default_codomain(self):
if self._harmonic:
raise ValueError("only supported for nonharmonic space")
codomain_bindistances = 1. / (self.bindistances * self.shape)
return LogRGSpace(self.shape, codomain_bindistances,
np.zeros(len(self.shape)), True)
self._t_0, True)
def get_k_length_array(self):
ib = dobj.ibegin_from_shape(self._shape)
......
......@@ -65,9 +65,9 @@ def make_amplitude_model(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
p_space = PowerSpace(h_space)
exp_transform = ExpTransform(p_space, Npixdof)
logk_space = exp_transform.domain[0]
dof_space = logk_space.get_default_codomain()
qht = QHTOperator(target=logk_space)
dof_space = qht.domain[0]
param_space = UnstructuredDomain(2)
qht = QHTOperator(dof_space, logk_space)
sym = SymmetrizingOperator(logk_space)
phi_mean = np.array([sm, im])
......
......@@ -62,7 +62,7 @@ def _comp_traverse(start, end, shp, dist, lo, mid, hi, erf):
# hack: move away from potential grid crossings
dmin += 1e-7
dmax -= 1e-7
if dmin > dmax: # no intersection
if dmin >= dmax: # no intersection
out[i] = (np.full(0, 0), np.full(0, 0.))
continue
# determine coordinates of first cell crossing
......
......@@ -24,6 +24,12 @@ from ..utilities import frozendict
class MultiDomain(object):
"""A tuple of domains corresponding to a direct sum.
This class is the domain of the direct sum of fields living
over (possibly different) domains. To make an instance
of this class, call `MultiDomain.make(inp)`.
"""
_domainCache = {}
def __init__(self, dict, _callingfrommake=False):
......@@ -36,6 +42,17 @@ class MultiDomain(object):
@staticmethod
def make(inp):
"""Creates a MultiDomain object from a dictionary of names and domains
Parameters
----------
inp : MultiDomain or dict{name: DomainTuple}
The already built MultiDomain or a dictionary of DomainTuples
Returns
------
A MultiDomain with the input Domains as domains
"""
if isinstance(inp, MultiDomain):
return inp
if not isinstance(inp, dict):
......
......@@ -36,34 +36,28 @@ class QHTOperator(LinearOperator):
Parameters
----------
domain : domain, tuple of domains or DomainTuple
The full input domain
target : domain, tuple of domains or DomainTuple
The full output domain
space : int
The index of the domain on which the operator acts.
domain[space] must be a harmonic LogRGSpace.
target : LogRGSpace
The target codomain of domain[space]
Must be a nonharmonic LogRGSpace.
target[space] must be a nonharmonic LogRGSpace.
"""
def __init__(self, domain, target, space=0):
self._domain = DomainTuple.make(domain)
self._space = infer_space(self._domain, space)
def __init__(self, target, space=0):
self._target = DomainTuple.make(target)
self._space = infer_space(self._target, space)
from ..domains.log_rg_space import LogRGSpace
if not isinstance(self._domain[self._space], LogRGSpace):
raise ValueError("Domain[space] has to be a LogRGSpace!")
if not isinstance(target, LogRGSpace):
raise ValueError("Target has to be a LogRGSpace!")
if not isinstance(self._target[self._space], LogRGSpace):
raise ValueError("target[space] has to be a LogRGSpace!")
if not self._domain[self._space].harmonic:
if self._target[self._space].harmonic:
raise TypeError(
"QHTOperator only works on a harmonic space")
if target.harmonic:
raise TypeError("Target is not a codomain of domain")
"target[space] must be a nonharmonic space")
self._target = [dom for dom in self._domain]
self._target[self._space] = target
self._target = DomainTuple.make(self._target)
self._domain = [dom for dom in self._target]
self._domain[self._space] = \
self._target[self._space].get_default_codomain()
self._domain = DomainTuple.make(self._domain)
@property
def domain(self):
......
......@@ -23,7 +23,12 @@ import os
import numpy as np
from ..compat import *
from .. import Field, GLSpace, HPSpace, PowerSpace, RGSpace, dobj
from ..field import Field
from ..domains.gl_space import GLSpace
from ..domains.hp_space import HPSpace
from ..domains.power_space import PowerSpace
from ..domains.rg_space import RGSpace
from .. import dobj
# relevant properties:
# - x/y size
......@@ -81,16 +86,6 @@ def _makeplot(name):
elif extension == ".png":
plt.savefig(name)
plt.close()
# elif extension==".html":
# import mpld3
# mpld3.save_html(plt.gcf(),fileobj=name,no_extras=True)
# import plotly.offline as py
# import plotly.tools as tls
# plotly_fig = tls.mpl_to_plotly(plt.gcf())
# py.plot(plotly_fig,filename=name)
# py.plot_mpl(plt.gcf(),filename=name)
# import bokeh
# bokeh.mpl.to_bokeh(plt.gcf())
else:
raise ValueError("file format not understood")
......@@ -169,7 +164,7 @@ def _register_cmaps():
plt.register_cmap(cmap=LinearSegmentedColormap("Plus Minus", pm_cmap))
def plot(f, **kwargs):
def _plot(f, ax, **kwargs):
import matplotlib.pyplot as plt
_register_cmaps()
if isinstance(f, Field):
......@@ -209,12 +204,6 @@ def plot(f, **kwargs):
alpha = [alpha]
dom = dom[0]
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
xsize = kwargs.pop("xsize", 6)
ysize = kwargs.pop("ysize", 6)
fig.set_size_inches(xsize, ysize)
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
......@@ -231,7 +220,6 @@ def plot(f, **kwargs):
_limit_xy(**kwargs)
if label != ([None]*len(f)):
plt.legend()
_makeplot(kwargs.get("name"))
return
elif len(dom.shape) == 2:
f = f[0]
......@@ -251,7 +239,6 @@ def plot(f, **kwargs):
# plt.colorbar(im,cax=cax)
plt.colorbar(im)
_limit_xy(**kwargs)
_makeplot(kwargs.get("name"))
return
elif isinstance(dom, PowerSpace):
plt.xscale('log')
......@@ -265,7 +252,6 @@ def plot(f, **kwargs):
_limit_xy(**kwargs)
if label != ([None]*len(f)):
plt.legend()
_makeplot(kwargs.get("name"))
return
elif isinstance(dom, HPSpace):
f = f[0]
......@@ -282,7 +268,6 @@ def plot(f, **kwargs):
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
_makeplot(kwargs.get("name"))
return
elif isinstance(dom, GLSpace):
f = f[0]
......@@ -300,7 +285,85 @@ def plot(f, **kwargs):
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
_makeplot(kwargs.get("name"))
return
raise ValueError("Field type not(yet) supported")
_plots = []
_kwargs = []
def plot(f, **kwargs):
"""Add a figure to the current list of plots.
Notes
-----
After doing one or more calls `plot()`, one also needs to call
`plot_finish()` to output the result.
Parameters
----------
f: Field, or list of Field objects
If `f` is a single Field, it must live over a single `RGSpace`,
`PowerSpace`, `HPSpace`, `GLSPace`.
If it is a list, all list members must be Fields living over the same
one-dimensional `RGSpace` or `PowerSpace`.
title: string
title of the plot
xlabel: string
label for the x axis
ylabel: string
label for the y axis
[xyz]min, [xyz]max: float
limits for the values to plot
colormap: string
color map to use for the plot (if it is a 2D plot)
linewidth: float or list of floats
line width
label: string of list of strings
annotation string
alpha: float or list of floats
transparency value
"""
_plots.append(f)
_kwargs.append(kwargs)
def plot_finish(**kwargs):
"""Plot the accumulated list of figures.
Parameters
----------
title: string
title of the full plot
nx, ny: integer (default: square root of the numer of plots, rounded up)
number of subplots to use in x- and y-direction
xsize, ysize: float (default: 6)
dimensions of the full plot in inches
name: string (default: "")
if left empty, the plot will be shown on the screen,
otherwise it will be written to a file with the given name.
Supported extensions: .png and .pdf
"""
global _plots, _kwargs
import matplotlib.pyplot as plt
nplot = len(_plots)
fig = plt.figure()
if "title" in kwargs:
plt.suptitle(kwargs.pop("title"))
nx = kwargs.pop("nx", int(np.ceil(np.sqrt(nplot))))
ny = kwargs.pop("ny", int(np.ceil(np.sqrt(nplot))))
if nx*ny < nplot:
raise ValueError(
'Figure dimensions not sufficient for number of plots')
xsize = kwargs.pop("xsize", 6)
ysize = kwargs.pop("ysize", 6)
fig.set_size_inches(xsize, ysize)
for i in range(nplot):
ax = fig.add_subplot(ny, nx, i+1)
_plot(_plots[i], ax, **_kwargs[i])
fig.tight_layout()
_makeplot(kwargs.pop("name", None))
_plots = []
_kwargs = []
......@@ -144,8 +144,6 @@ class Consistency_Tests(unittest.TestCase):
ift.LogRGSpace(17, [3.], [.7])), 1)],
[np.float64]))
def testQHTOperator(self, args, dtype):
dom = ift.DomainTuple.make(args[0])
tgt = list(dom)
tgt[args[1]] = tgt[args[1]].get_default_codomain()
op = ift.QHTOperator(tgt, dom[args[1]], args[1])
tgt = ift.DomainTuple.make(args[0])
op = ift.QHTOperator(tgt, args[1])
ift.extra.consistency_check(op, dtype, dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment