Commit 9c7ce5fc authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'plotting' into 'master'

Plotting

See merge request !94
parents 88033e38 0327c36a
Pipeline #12346 failed with stages
in 6 minutes and 19 seconds
from nifty import *
from mpi4py import MPI
import plotly.offline as py
import plotly.graph_objs as go
comm = MPI.COMM_WORLD
rank = comm.rank
def plot_maps(x, name):
trace = [None]*len(x)
keys = x.keys()
field = x[keys[0]]
domain = field.domain[0]
shape = len(domain.shape)
max_n = domain.shape[0]*domain.distances[0]
step = domain.distances[0]
x_axis = np.arange(0, max_n, step)
if shape == 1:
for ii in xrange(len(x)):
trace[ii] = go.Scatter(x= x_axis, y=x[keys[ii]].val.get_full_data(), name=keys[ii])
fig = go.Figure(data=trace)
py.plot(fig, filename=name)
elif shape == 2:
for ii in xrange(len(x)):
py.plot([go.Heatmap(z=x[keys[ii]].val.get_full_data().real)], filename=keys[ii])
else:
raise TypeError("Only 1D and 2D field plots are supported")
def plot_power(x, name):
layout = go.Layout(
xaxis=dict(
type='log',
autorange=True
),
yaxis=dict(
type='log',
autorange=True
)
)
trace = [None]*len(x)
keys = x.keys()
field = x[keys[0]]
domain = field.domain[0]
x_axis = domain.kindex
for ii in xrange(len(x)):
trace[ii] = go.Scatter(x= x_axis, y=x[keys[ii]].val.get_full_data(), name=keys[ii])
fig = go.Figure(data=trace, layout=layout)
py.plot(fig, filename=name)
np.random.seed(42)
if __name__ == "__main__":
distribution_strategy = 'not'
# setting spaces
npix = np.array([500]) # number of pixels
total_volume = 1. # total length
# setting signal parameters
lambda_s = .05 # signal correlation length
sigma_s = 10. # signal variance
#setting response operator parameters
length_convolution = .025
exposure = 1.
# calculating parameters
k_0 = 4. / (2 * np.pi * lambda_s)
a_s = sigma_s ** 2. * lambda_s * total_volume
# creation of spaces
# x1 = RGSpace([npix,npix], distances=total_volume / npix,
# zerocenter=False)
# k1 = RGRGTransformation.get_codomain(x1)
x1 = HPSpace(32)
k1 = HPLMTransformation.get_codomain(x1)
p1 = PowerSpace(harmonic_partner=k1, logarithmic=False)
# creating Power Operator with given spectrum
spec = (lambda k: a_s / (1 + (k / k_0) ** 2) ** 2)
p_field = Field(p1, val=spec)
S_op = create_power_operator(k1, spec)
# creating FFT-Operator and Response-Operator with Gaussian convolution
Fft_op = FFTOperator(domain=x1, target=k1,
domain_dtype=np.float64,
target_dtype=np.complex128)
R_op = ResponseOperator(x1, sigma=[length_convolution],
exposure=[exposure])
# drawing a random field
sk = p_field.power_synthesize(real_signal=True, mean=0.)
s = Fft_op.adjoint_times(sk)
signal_to_noise = 1
N_op = DiagonalOperator(R_op.target, diagonal=s.var()/signal_to_noise, bare=True)
n = Field.from_random(domain=R_op.target,
random_type='normal',
std=s.std()/np.sqrt(signal_to_noise),
mean=0.)
d = R_op(s) + n
# Wiener filter
j = Fft_op.times(R_op.adjoint_times(N_op.inverse_times(d)))
D = HarmonicPropagatorOperator(S=S_op, N=N_op, R=R_op)
mk = D(j)
m = Fft_op.adjoint_times(mk)
# z={}
# z["signal"] = s
# z["reconstructed_map"] = m
# z["data"] = d
# z["lambda"] = R_op(s)
# z["j"] = j
#
# plot_maps(z, "Wiener_filter.html")
...@@ -50,8 +50,8 @@ from spaces import * ...@@ -50,8 +50,8 @@ from spaces import *
from operators import * from operators import *
from plotting import *
from probing import * from probing import *
from sugar import * from sugar import *
import plotting
from healpix_plotter import HealpixPlotter
from power_plotter import PowerPlotter
...@@ -3,39 +3,41 @@ ...@@ -3,39 +3,41 @@
import abc import abc
import os import os
import numpy as np
import plotly import plotly
from plotly import tools import plotly.offline as plotly_offline
import plotly.offline as ply
from keepers import Loggable from keepers import Loggable
from nifty.spaces.space import Space from nifty.spaces.space import Space
from nifty.field_types.field_type import FieldType from nifty.field import Field
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from nifty.plotting.figures import Figure2D,\
Figure3D,\
MultiFigure
plotly.offline.init_notebook_mode() plotly.offline.init_notebook_mode()
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank
class Plotter(Loggable, object): class Plotter(Loggable, object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', stack_subplots=False, def __init__(self, interactive=False, path='.', title=""):
color_scale=None):
self.interactive = interactive self.interactive = interactive
self.path = path self.path = path
self.stack_subplots = stack_subplots self.title = str(title)
self.color_scale = color_scale
self.title = 'uiae'
@abc.abstractproperty @abc.abstractproperty
def domain(self): def domain_classes(self):
return (Space,) return (Space,)
@abc.abstractproperty
def field_type(self):
return (FieldType,)
@property @property
def interactive(self): def interactive(self):
return self._interactive return self._interactive
...@@ -52,49 +54,71 @@ class Plotter(Loggable, object): ...@@ -52,49 +54,71 @@ class Plotter(Loggable, object):
def path(self, new_path): def path(self, new_path):
self._path = os.path.normpath(new_path) self._path = os.path.normpath(new_path)
@property def plot(self, fields, spaces=None, data_extractor=None, labels=None):
def stack_subplots(self): if isinstance(fields, Field):
return self._stack_subplots fields = [fields]
elif not isinstance(fields, list):
@stack_subplots.setter fields = list(fields)
def stack_subplots(self, stack_subplots):
self._stack_subplots = bool(stack_subplots)
def plot(self, field, spaces=None, types=None, slice=None): spaces = utilities.cast_axis_to_tuple(spaces, len(fields[0].domain))
data = self._get_data_from_field(field, spaces, types, slice)
figures = self._create_individual_plot(data)
self._finalize_figure(figures)
@abc.abstractmethod if spaces is None:
def _get_data_from_field(self, field, spaces=None, types=None, slice=None): spaces = tuple(range(len(fields[0].domain)))
# if fields is a list, create a new field with appended
# field_type = field_array and copy individual parts into the new field
spaces = utilities.cast_axis_to_tuple(spaces, len(field.domain)) axes = []
types = utilities.cast_axis_to_tuple(types, len(field.field_type)) plot_domain = []
if field.domain[spaces] != self.domain: for space_index in spaces:
raise AttributeError("Given space(s) of input field-domain do not " axes += list(fields[0].domain_axes[space_index])
"match the plotters domain.") plot_domain += [fields[0].domain[space_index]]
if field.field_type[spaces] != self.field_type:
raise AttributeError("Given field_type(s) of input field-domain "
"do not match the plotters field_type.")
# iterate over the individual slices in order to compose the figure # prepare data
# -> make a d2o.get_full_data() (for rank==0 only?) data_list = [self._get_data_from_field(field, spaces, data_extractor)
# add for field in fields]
return [1,2,3]
def _create_individual_plot(self, data): # create plots
pass plots_list = []
for slice_list in utilities.get_slice_list(data_list[0].shape, axes):
plots_list += \
[[self._create_individual_plot(current_data[slice_list],
plot_domain)
for current_data in data_list]]
def _finalize_figure(self, figure): figures = [self._create_individual_figure(plots)
pass for plots in plots_list]
# is there a use for ply.plot when one has no interest in
# saving a file?
# -> check for different file types self._finalize_figure(figures)
# -> store the file to disk (MPI awareness?)
def _get_data_from_field(self, field, spaces, data_extractor):
for i, space_index in enumerate(spaces):
if not isinstance(field.domain[space_index],
self.domain_classes[i]):
raise AttributeError("Given space(s) of input field-domain do "
"not match the plotters domain.")
# TODO: add data_extractor functionality here
data = field.val.get_full_data(target_rank=0)
return data
@abc.abstractmethod
def _create_individual_figure(self, plots):
raise NotImplementedError
@abc.abstractmethod
def _create_individual_plot(self, data, fields):
raise NotImplementedError
def _finalize_figure(self, figures):
if len(figures) > 1:
rows = (len(figures) + 1)//2
figure_array = np.empty((2*rows), dtype=np.object)
figure_array[:len(figures)] = figures
figure_array = figure_array.reshape((2, rows))
final_figure = MultiFigure(rows, 2,
title='Test',
subfigures=figure_array)
else:
final_figure = figures[0]
plotly_offline.plot(final_figure.to_plotly(),
filename=os.path.join(self.path, self.title))
from descriptors import * from descriptors import *
from plots import * from plots import *
from figures import * from figures import *
from colormap import * from colormap import *
\ No newline at end of file from plotter import *
...@@ -13,8 +13,9 @@ class Colormap(PlotlyWrapper): ...@@ -13,8 +13,9 @@ class Colormap(PlotlyWrapper):
#TODO: implement validation #TODO: implement validation
pass pass
# no discontinuities only
@staticmethod @staticmethod
def from_matplotlib_colormap_internal(name, mpl_cmap): # no discontinuities only def from_matplotlib_colormap_internal(name, mpl_cmap):
red = [(c[0], c[2]) for c in mpl_cmap['red']] red = [(c[0], c[2]) for c in mpl_cmap['red']]
green = [(c[0], c[2]) for c in mpl_cmap['green']] green = [(c[0], c[2]) for c in mpl_cmap['green']]
blue = [(c[0], c[2]) for c in mpl_cmap['blue']] blue = [(c[0], c[2]) for c in mpl_cmap['blue']]
...@@ -40,7 +41,8 @@ class Colormap(PlotlyWrapper): ...@@ -40,7 +41,8 @@ class Colormap(PlotlyWrapper):
green_val = self.green[g][1] green_val = self.green[g][1]
g += 1 g += 1
else: else:
slope = (self.green[g][1] - prev_g) / (self.green[g][0] - prev_split) slope = ((self.green[g][1] - prev_g) /
(self.green[g][0] - prev_split))
y = prev_g - slope * prev_split y = prev_g - slope * prev_split
green_val = slope * next_split + y green_val = slope * next_split + y
...@@ -48,11 +50,16 @@ class Colormap(PlotlyWrapper): ...@@ -48,11 +50,16 @@ class Colormap(PlotlyWrapper):
blue_val = self.blue[b][1] blue_val = self.blue[b][1]
b += 1 b += 1
else: else:
slope = (self.blue[b][1] - prev_b) / (self.blue[b][0] - prev_split) slope = ((self.blue[b][1] - prev_b) /
(self.blue[b][0] - prev_split))
y = prev_r - slope * prev_split y = prev_r - slope * prev_split
blue_val = slope * next_split + y blue_val = slope * next_split + y
prev_split, prev_r, prev_g, prev_b = next_split, red_val, green_val, blue_val prev_split = next_split
prev_r = red_val
prev_g = green_val
prev_b = blue_val
converted.append([next_split, converted.append([next_split,
'rgb(' + 'rgb(' +
str(int(red_val*255)) + "," + str(int(red_val*255)) + "," +
......
...@@ -31,29 +31,46 @@ class MultiFigure(FigureBase): ...@@ -31,29 +31,46 @@ class MultiFigure(FigureBase):
def to_plotly(self): def to_plotly(self):
title_extractor = lambda z: z.title if z else "" title_extractor = lambda z: z.title if z else ""
sub_titles = \ sub_titles = tuple(np.vectorize(title_extractor)(self.subfigures.flatten()))
tuple(np.vectorize(title_extractor)(self.subfigures.flatten()))
specs_setter = \ specs_setter = lambda z: {'is_3d': True} if isinstance(z, Figure3D) else {}
lambda z: {'is_3d': True} if isinstance(z, Figure3D) else {} sub_specs = list(map(list, np.vectorize(specs_setter)(self.subfigures)))
sub_specs = \
list(map(list, np.vectorize(specs_setter)(self.subfigures)))
multi_figure_plotly_object = plotly.tools.make_subplots( multi_figure_plotly_object = plotly.tools.make_subplots(
self.rows, self.rows,
self.columns, self.columns,
subplot_titles=sub_titles, subplot_titles=sub_titles,
specs=sub_specs) specs=sub_specs)
# TODO resolve bug with titles and 3D subplots
multi_figure_plotly_object['layout'].update(height=self.height,
width=self.width,
title=self.title)
#TODO resolve bug with titles and 3D subplots
i = 1
for index, fig in np.ndenumerate(self.subfigures): for index, fig in np.ndenumerate(self.subfigures):
if fig: if fig:
for plot in fig.plots: for plot in fig.plots:
multi_figure_plotly_object.append_trace(plot.to_plotly(), multi_figure_plotly_object.append_trace(plot.to_plotly(),
index[0]+1, index[0]+1,
index[1]+1) index[1]+1)
if isinstance(fig, Figure3D):
scene = dict()
if fig.xaxis:
scene['xaxis'] = fig.xaxis.to_plotly()
if fig.yaxis:
scene['yaxis'] = fig.yaxis.to_plotly()
if fig.zaxis:
scene['zaxis'] = fig.zaxis.to_plotly()
multi_figure_plotly_object['layout'].update(height=self.height, multi_figure_plotly_object['layout']['scene'+str(i)] = scene
width=self.width, else:
title=self.title) if fig.xaxis:
multi_figure_plotly_object['layout']['xaxis'+str(i)] = fig.xaxis.to_plotly()
if fig.yaxis:
multi_figure_plotly_object['layout']['yaxis'+str(i)] = fig.yaxis.to_plotly()
i += 1
return multi_figure_plotly_object return multi_figure_plotly_object
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np
from nifty.plotting.colormap import Colormap
from nifty.plotting.plotly_wrapper import PlotlyWrapper from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Heatmap(PlotlyWrapper): class Heatmap(PlotlyWrapper):
def __init__(self, data, color_map=None, webgl=False, def __init__(self, data, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False smoothing=False): # smoothing 'best', 'fast', False
self.data = data if isinstance(data, list):
self.data = np.zeros((data[0].shape))
for arr in data:
self.data = np.add(self.data, arr)
else:
self.data = data
if color_map is not None:
if not isinstance(color_map, Colormap):
raise TypeError("Provided color_map must be an instance of "
"the NIFTy Colormap class.")
self.color_map = color_map self.color_map = color_map
self.webgl = webgl self.webgl = webgl
self.smoothing = smoothing self.smoothing = smoothing
@property
def figure_dimension(self):
return 2
def to_plotly(self): def to_plotly(self):
plotly_object = dict() plotly_object = dict()
plotly_object['z'] = self.data plotly_object['z'] = self.data
......
...@@ -14,8 +14,10 @@ class Mollweide(Heatmap): ...@@ -14,8 +14,10 @@ class Mollweide(Heatmap):
raise ImportError("The module pylab is needed but not available.") raise ImportError("The module pylab is needed but not available.")
if 'healpy' not in gdi: if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available.") raise ImportError("The module healpy is needed but not available.")
if isinstance(data, list):
data = self._mollview(data) data = [self._mollview(d) for d in data]
else:
data = self._mollview(data)
super(Mollweide, self).__init__(data, color_map, webgl, smoothing) super(Mollweide, self).__init__(data, color_map, webgl, smoothing)
def _mollview(self, x, xsize=800): def _mollview(self, x, xsize=800):
......
...@@ -4,14 +4,16 @@ from scatter_plot import ScatterPlot ...@@ -4,14 +4,16 @@ from scatter_plot import ScatterPlot
class Cartesian(ScatterPlot): class Cartesian(ScatterPlot):
def __init__(self, x, y, label, line, marker): def __init__(self, x, y, label, line, marker, showlegend=True):
super(Cartesian, self).__init__(label, line, marker) super(Cartesian, self).__init__(label, line, marker)
self.x = x self.x = x
self.y = y self.y = y
self.showlegend = showlegend
@abstractmethod @abstractmethod
def to_plotly(self): def to_plotly(self):
plotly_object = super(Cartesian, self).to_plotly() plotly_object = super(Cartesian, self).to_plotly()
plotly_object['x'] = self.x plotly_object['x'] = self.x
plotly_object['y'] = self.y plotly_object['y'] = self.y
plotly_object['showlegend'] = self.showlegend
return plotly_object return plotly_object
...@@ -5,14 +5,20 @@ from cartesian import Cartesian ...@@ -5,14 +5,20 @@ from cartesian import Cartesian
class Cartesian2D(Cartesian): class Cartesian2D(Cartesian):
def __init__(self, x=None, y=None, x_start=0, x_step=1, def __init__(self, x=None, y=None, x_start=0, x_step=1,
label='', line=None, marker=None, webgl=True): label='', line=None, marker=None, showlegend=True,
webgl=True):
if y is None: if y is None:
raise Exception('Error: no y data to plot')