Commit 9c7ce5fc authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'plotting' into 'master'

Plotting

See merge request !94
parents 88033e38 0327c36a
Pipeline #12346 failed with stages
in 6 minutes and 19 seconds
from nifty import *
from mpi4py import MPI
import plotly.offline as py
import plotly.graph_objs as go
comm = MPI.COMM_WORLD
rank = comm.rank
def plot_maps(x, name):
trace = [None]*len(x)
keys = x.keys()
field = x[keys[0]]
domain = field.domain[0]
shape = len(domain.shape)
max_n = domain.shape[0]*domain.distances[0]
step = domain.distances[0]
x_axis = np.arange(0, max_n, step)
if shape == 1:
for ii in xrange(len(x)):
trace[ii] = go.Scatter(x= x_axis, y=x[keys[ii]].val.get_full_data(), name=keys[ii])
fig = go.Figure(data=trace)
py.plot(fig, filename=name)
elif shape == 2:
for ii in xrange(len(x)):
py.plot([go.Heatmap(z=x[keys[ii]].val.get_full_data().real)], filename=keys[ii])
else:
raise TypeError("Only 1D and 2D field plots are supported")
def plot_power(x, name):
layout = go.Layout(
xaxis=dict(
type='log',
autorange=True
),
yaxis=dict(
type='log',
autorange=True
)
)
trace = [None]*len(x)
keys = x.keys()
field = x[keys[0]]
domain = field.domain[0]
x_axis = domain.kindex
for ii in xrange(len(x)):
trace[ii] = go.Scatter(x= x_axis, y=x[keys[ii]].val.get_full_data(), name=keys[ii])
fig = go.Figure(data=trace, layout=layout)
py.plot(fig, filename=name)
np.random.seed(42)
if __name__ == "__main__":
distribution_strategy = 'not'
# setting spaces
npix = np.array([500]) # number of pixels
total_volume = 1. # total length
# setting signal parameters
lambda_s = .05 # signal correlation length
sigma_s = 10. # signal variance
#setting response operator parameters
length_convolution = .025
exposure = 1.
# calculating parameters
k_0 = 4. / (2 * np.pi * lambda_s)
a_s = sigma_s ** 2. * lambda_s * total_volume
# creation of spaces
# x1 = RGSpace([npix,npix], distances=total_volume / npix,
# zerocenter=False)
# k1 = RGRGTransformation.get_codomain(x1)
x1 = HPSpace(32)
k1 = HPLMTransformation.get_codomain(x1)
p1 = PowerSpace(harmonic_partner=k1, logarithmic=False)
# creating Power Operator with given spectrum
spec = (lambda k: a_s / (1 + (k / k_0) ** 2) ** 2)
p_field = Field(p1, val=spec)
S_op = create_power_operator(k1, spec)
# creating FFT-Operator and Response-Operator with Gaussian convolution
Fft_op = FFTOperator(domain=x1, target=k1,
domain_dtype=np.float64,
target_dtype=np.complex128)
R_op = ResponseOperator(x1, sigma=[length_convolution],
exposure=[exposure])
# drawing a random field
sk = p_field.power_synthesize(real_signal=True, mean=0.)
s = Fft_op.adjoint_times(sk)
signal_to_noise = 1
N_op = DiagonalOperator(R_op.target, diagonal=s.var()/signal_to_noise, bare=True)
n = Field.from_random(domain=R_op.target,
random_type='normal',
std=s.std()/np.sqrt(signal_to_noise),
mean=0.)
d = R_op(s) + n
# Wiener filter
j = Fft_op.times(R_op.adjoint_times(N_op.inverse_times(d)))
D = HarmonicPropagatorOperator(S=S_op, N=N_op, R=R_op)
mk = D(j)
m = Fft_op.adjoint_times(mk)
# z={}
# z["signal"] = s
# z["reconstructed_map"] = m
# z["data"] = d
# z["lambda"] = R_op(s)
# z["j"] = j
#
# plot_maps(z, "Wiener_filter.html")
......@@ -50,8 +50,8 @@ from spaces import *
from operators import *
from plotting import *
from probing import *
from sugar import *
import plotting
from healpix_plotter import HealpixPlotter
from power_plotter import PowerPlotter
......@@ -3,39 +3,41 @@
import abc
import os
import numpy as np
import plotly
from plotly import tools
import plotly.offline as ply
import plotly.offline as plotly_offline
from keepers import Loggable
from nifty.spaces.space import Space
from nifty.field_types.field_type import FieldType
from nifty.field import Field
import nifty.nifty_utilities as utilities
from nifty.plotting.figures import Figure2D,\
Figure3D,\
MultiFigure
plotly.offline.init_notebook_mode()
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank
class Plotter(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', stack_subplots=False,
color_scale=None):
def __init__(self, interactive=False, path='.', title=""):
self.interactive = interactive
self.path = path
self.stack_subplots = stack_subplots
self.color_scale = color_scale
self.title = 'uiae'
self.title = str(title)
@abc.abstractproperty
def domain(self):
def domain_classes(self):
return (Space,)
@abc.abstractproperty
def field_type(self):
return (FieldType,)
@property
def interactive(self):
return self._interactive
......@@ -52,49 +54,71 @@ class Plotter(Loggable, object):
def path(self, new_path):
self._path = os.path.normpath(new_path)
@property
def stack_subplots(self):
return self._stack_subplots
@stack_subplots.setter
def stack_subplots(self, stack_subplots):
self._stack_subplots = bool(stack_subplots)
def plot(self, fields, spaces=None, data_extractor=None, labels=None):
if isinstance(fields, Field):
fields = [fields]
elif not isinstance(fields, list):
fields = list(fields)
def plot(self, field, spaces=None, types=None, slice=None):
data = self._get_data_from_field(field, spaces, types, slice)
figures = self._create_individual_plot(data)
self._finalize_figure(figures)
spaces = utilities.cast_axis_to_tuple(spaces, len(fields[0].domain))
@abc.abstractmethod
def _get_data_from_field(self, field, spaces=None, types=None, slice=None):
# if fields is a list, create a new field with appended
# field_type = field_array and copy individual parts into the new field
if spaces is None:
spaces = tuple(range(len(fields[0].domain)))
spaces = utilities.cast_axis_to_tuple(spaces, len(field.domain))
types = utilities.cast_axis_to_tuple(types, len(field.field_type))
if field.domain[spaces] != self.domain:
raise AttributeError("Given space(s) of input field-domain do not "
"match the plotters domain.")
if field.field_type[spaces] != self.field_type:
raise AttributeError("Given field_type(s) of input field-domain "
"do not match the plotters field_type.")
axes = []
plot_domain = []
for space_index in spaces:
axes += list(fields[0].domain_axes[space_index])
plot_domain += [fields[0].domain[space_index]]
# iterate over the individual slices in order to compose the figure
# -> make a d2o.get_full_data() (for rank==0 only?)
# add
return [1,2,3]
# prepare data
data_list = [self._get_data_from_field(field, spaces, data_extractor)
for field in fields]
def _create_individual_plot(self, data):
pass
# create plots
plots_list = []
for slice_list in utilities.get_slice_list(data_list[0].shape, axes):
plots_list += \
[[self._create_individual_plot(current_data[slice_list],
plot_domain)
for current_data in data_list]]
def _finalize_figure(self, figure):
pass
# is there a use for ply.plot when one has no interest in
# saving a file?
figures = [self._create_individual_figure(plots)
for plots in plots_list]
# -> check for different file types
# -> store the file to disk (MPI awareness?)
self._finalize_figure(figures)
def _get_data_from_field(self, field, spaces, data_extractor):
for i, space_index in enumerate(spaces):
if not isinstance(field.domain[space_index],
self.domain_classes[i]):
raise AttributeError("Given space(s) of input field-domain do "
"not match the plotters domain.")
# TODO: add data_extractor functionality here
data = field.val.get_full_data(target_rank=0)
return data
@abc.abstractmethod
def _create_individual_figure(self, plots):
raise NotImplementedError
@abc.abstractmethod
def _create_individual_plot(self, data, fields):
raise NotImplementedError
def _finalize_figure(self, figures):
if len(figures) > 1:
rows = (len(figures) + 1)//2
figure_array = np.empty((2*rows), dtype=np.object)
figure_array[:len(figures)] = figures
figure_array = figure_array.reshape((2, rows))
final_figure = MultiFigure(rows, 2,
title='Test',
subfigures=figure_array)
else:
final_figure = figures[0]
plotly_offline.plot(final_figure.to_plotly(),
filename=os.path.join(self.path, self.title))
......@@ -2,3 +2,4 @@ from descriptors import *
from plots import *
from figures import *
from colormap import *
from plotter import *
......@@ -13,8 +13,9 @@ class Colormap(PlotlyWrapper):
#TODO: implement validation
pass
# no discontinuities only
@staticmethod
def from_matplotlib_colormap_internal(name, mpl_cmap): # no discontinuities only
def from_matplotlib_colormap_internal(name, mpl_cmap):
red = [(c[0], c[2]) for c in mpl_cmap['red']]
green = [(c[0], c[2]) for c in mpl_cmap['green']]
blue = [(c[0], c[2]) for c in mpl_cmap['blue']]
......@@ -40,7 +41,8 @@ class Colormap(PlotlyWrapper):
green_val = self.green[g][1]
g += 1
else:
slope = (self.green[g][1] - prev_g) / (self.green[g][0] - prev_split)
slope = ((self.green[g][1] - prev_g) /
(self.green[g][0] - prev_split))
y = prev_g - slope * prev_split
green_val = slope * next_split + y
......@@ -48,11 +50,16 @@ class Colormap(PlotlyWrapper):
blue_val = self.blue[b][1]
b += 1
else:
slope = (self.blue[b][1] - prev_b) / (self.blue[b][0] - prev_split)
slope = ((self.blue[b][1] - prev_b) /
(self.blue[b][0] - prev_split))
y = prev_r - slope * prev_split
blue_val = slope * next_split + y
prev_split, prev_r, prev_g, prev_b = next_split, red_val, green_val, blue_val
prev_split = next_split
prev_r = red_val
prev_g = green_val
prev_b = blue_val
converted.append([next_split,
'rgb(' +
str(int(red_val*255)) + "," +
......
......@@ -31,29 +31,46 @@ class MultiFigure(FigureBase):
def to_plotly(self):
title_extractor = lambda z: z.title if z else ""
sub_titles = \
tuple(np.vectorize(title_extractor)(self.subfigures.flatten()))
sub_titles = tuple(np.vectorize(title_extractor)(self.subfigures.flatten()))
specs_setter = \
lambda z: {'is_3d': True} if isinstance(z, Figure3D) else {}
sub_specs = \
list(map(list, np.vectorize(specs_setter)(self.subfigures)))
specs_setter = lambda z: {'is_3d': True} if isinstance(z, Figure3D) else {}
sub_specs = list(map(list, np.vectorize(specs_setter)(self.subfigures)))
multi_figure_plotly_object = plotly.tools.make_subplots(
self.rows,
self.columns,
subplot_titles=sub_titles,
specs=sub_specs)
# TODO resolve bug with titles and 3D subplots
multi_figure_plotly_object['layout'].update(height=self.height,
width=self.width,
title=self.title)
#TODO resolve bug with titles and 3D subplots
i = 1
for index, fig in np.ndenumerate(self.subfigures):
if fig:
for plot in fig.plots:
multi_figure_plotly_object.append_trace(plot.to_plotly(),
index[0]+1,
index[1]+1)
if isinstance(fig, Figure3D):
scene = dict()
if fig.xaxis:
scene['xaxis'] = fig.xaxis.to_plotly()
if fig.yaxis:
scene['yaxis'] = fig.yaxis.to_plotly()
if fig.zaxis:
scene['zaxis'] = fig.zaxis.to_plotly()
multi_figure_plotly_object['layout'].update(height=self.height,
width=self.width,
title=self.title)
multi_figure_plotly_object['layout']['scene'+str(i)] = scene
else:
if fig.xaxis:
multi_figure_plotly_object['layout']['xaxis'+str(i)] = fig.xaxis.to_plotly()
if fig.yaxis:
multi_figure_plotly_object['layout']['yaxis'+str(i)] = fig.yaxis.to_plotly()
i += 1
return multi_figure_plotly_object
# -*- coding: utf-8 -*-
import numpy as np
from nifty.plotting.colormap import Colormap
from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Heatmap(PlotlyWrapper):
def __init__(self, data, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
if isinstance(data, list):
self.data = np.zeros((data[0].shape))
for arr in data:
self.data = np.add(self.data, arr)
else:
self.data = data
if color_map is not None:
if not isinstance(color_map, Colormap):
raise TypeError("Provided color_map must be an instance of "
"the NIFTy Colormap class.")
self.color_map = color_map
self.webgl = webgl
self.smoothing = smoothing
@property
def figure_dimension(self):
return 2
def to_plotly(self):
plotly_object = dict()
plotly_object['z'] = self.data
......
......@@ -14,7 +14,9 @@ class Mollweide(Heatmap):
raise ImportError("The module pylab is needed but not available.")
if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available.")
if isinstance(data, list):
data = [self._mollview(d) for d in data]
else:
data = self._mollview(data)
super(Mollweide, self).__init__(data, color_map, webgl, smoothing)
......
......@@ -4,14 +4,16 @@ from scatter_plot import ScatterPlot
class Cartesian(ScatterPlot):
def __init__(self, x, y, label, line, marker):
def __init__(self, x, y, label, line, marker, showlegend=True):
super(Cartesian, self).__init__(label, line, marker)
self.x = x
self.y = y
self.showlegend = showlegend
@abstractmethod
def to_plotly(self):
plotly_object = super(Cartesian, self).to_plotly()
plotly_object['x'] = self.x
plotly_object['y'] = self.y
plotly_object['showlegend'] = self.showlegend
return plotly_object
......@@ -5,14 +5,20 @@ from cartesian import Cartesian
class Cartesian2D(Cartesian):
def __init__(self, x=None, y=None, x_start=0, x_step=1,
label='', line=None, marker=None, webgl=True):
label='', line=None, marker=None, showlegend=True,
webgl=True):
if y is None:
raise Exception('Error: no y data to plot')
if x is None:
x = range(x_start, len(y) * x_step, x_step)
super(Cartesian2D, self).__init__(x, y, label, line, marker)
super(Cartesian2D, self).__init__(x, y, label, line, marker,
showlegend)
self.webgl = webgl
@property
def figure_dimension(self):
return 2
def to_plotly(self):
plotly_object = super(Cartesian2D, self).to_plotly()
if self.webgl:
......
......@@ -4,10 +4,16 @@ from cartesian import Cartesian
class Cartesian3D(Cartesian):
def __init__(self, x, y, z, label='', line=None, marker=None):
super(Cartesian3D, self).__init__(x, y, label, line, marker)
def __init__(self, x, y, z, label='', line=None, marker=None,
showlegend=True):
super(Cartesian3D, self).__init__(x, y, label, line, marker,
showlegend)
self.z = z
@property
def figure_dimension(self):
return 3
def to_plotly(self):
plotly_object = super(Cartesian3D, self).to_plotly()
plotly_object['z'] = self.z
......
from scatter_plot import ScatterPlot
......@@ -13,6 +14,10 @@ class Geo(ScatterPlot):
self.lat = lat
self.projection = proj
@property
def figure_dimension(self):
return 2
def _to_plotly(self):
plotly_object = super(Geo, self).to_plotly()
plotly_object['type'] = 'scattergeo'
......
# -*- coding: utf-8 -*-
from abc import abstractmethod
import abc
from nifty.plotting.plotly_wrapper import PlotlyWrapper
from nifty.plotting.descriptors import Marker
from nifty.plotting.descriptors import Marker,\
Line
class ScatterPlot(PlotlyWrapper):
......@@ -12,8 +13,13 @@ class ScatterPlot(PlotlyWrapper):
self.marker = marker
if not self.line and not self.marker:
self.marker = Marker()
self.line = Line()
@abstractmethod
@abc.abstractproperty
def figure_dimension(self):
raise NotImplementedError
@abc.abstractmethod
def to_plotly(self):
ply_object = dict()
ply_object['name'] = self.label
......
from nifty.spaces import HPSpace
from nifty.plotting.figures import Figure2D
from nifty.plotting.plots import Mollweide
from .plotter import Plotter
class HealpixPlotter(Plotter):
def __init__(self, interactive=False, path='.', title="", color_map=None):
super(HealpixPlotter, self).__init__(interactive, path, title)
self.color_map = color_map
@property
def domain_classes(self):
return (HPSpace, )
def _create_individual_figure(self, plots):
return Figure2D(plots)