Commit 4b11da0d authored by Theo Steininger's avatar Theo Steininger
Browse files

Refactored plotting classes.

parent ad8d2948
Pipeline #12650 passed with stage
in 4 minutes and 24 seconds
...@@ -7,31 +7,44 @@ from nifty.plotting.plots import Heatmap, HPMollweide, GLMollweide ...@@ -7,31 +7,44 @@ from nifty.plotting.plots import Heatmap, HPMollweide, GLMollweide
class Figure2D(FigureFromPlot): class Figure2D(FigureFromPlot):
def __init__(self, plots, title=None, width=None, height=None, def __init__(self, plots, title=None, width=None, height=None,
xaxis=None, yaxis=None): xaxis=None, yaxis=None):
super(Figure2D, self).__init__(plots, title, width, height)
# TODO: add sanitization of plots input if plots is not None:
if isinstance(plots[0], Heatmap) and not width and not height: if isinstance(plots[0], Heatmap) and width is None and \
(x, y) = plots[0].data.shape height is None:
(x, y) = plots[0].data.shape
if x > y:
width = 500
height = int(500*y/x)
else:
height = 500
width = int(500 * y / x)
if isinstance(plots[0], GLMollweide) or \
isinstance(plots[0], HPMollweide):
xaxis = False if (xaxis is None) else xaxis
yaxis = False if (yaxis is None) else yaxis
if x > y:
width = 500
height = int(500*y/x)
else: else:
height = 500 width = None
width = int(500 * y / x) height = None
if isinstance(plots[0], GLMollweide) or isinstance(plots[0],
HPMollweide):
if not xaxis:
xaxis = False
if not yaxis:
yaxis = False
super(Figure2D, self).__init__(plots, title, width, height) super(Figure2D, self).__init__(plots, title, width, height)
self.xaxis = xaxis self.xaxis = xaxis
self.yaxis = yaxis self.yaxis = yaxis
def at(self, plots):
return Figure2D(plots=plots,
title=self.title,
width=self.width,
height=self.height,
xaxis=self.xaxis,
yaxis=self.yaxis)
def to_plotly(self): def to_plotly(self):
plotly_object = super(Figure2D, self).to_plotly() plotly_object = super(Figure2D, self).to_plotly()
if self.xaxis or self.yaxis: if self.xaxis or self.yaxis:
plotly_object['layout']['scene']['aspectratio'] = {} plotly_object['layout']['scene']['aspectratio'] = {}
if self.xaxis: if self.xaxis:
......
...@@ -10,6 +10,15 @@ class Figure3D(FigureFromPlot): ...@@ -10,6 +10,15 @@ class Figure3D(FigureFromPlot):
self.yaxis = yaxis self.yaxis = yaxis
self.zaxis = zaxis self.zaxis = zaxis
def at(self, plots):
return Figure3D(plots=plots,
title=self.title,
width=self.width,
height=self.height,
xaxis=self.xaxis,
yaxis=self.yaxis,
zaxis=self.zaxis)
def to_plotly(self): def to_plotly(self):
plotly_object = super(Figure3D, self).to_plotly() plotly_object = super(Figure3D, self).to_plotly()
if self.xaxis or self.yaxis or self.zaxis: if self.xaxis or self.yaxis or self.zaxis:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from abc import abstractmethod import abc
from nifty.plotting.plotly_wrapper import PlotlyWrapper from nifty.plotting.plotly_wrapper import PlotlyWrapper
...@@ -11,6 +11,10 @@ class FigureBase(PlotlyWrapper): ...@@ -11,6 +11,10 @@ class FigureBase(PlotlyWrapper):
self.width = width self.width = width
self.height = height self.height = height
@abstractmethod @abc.abstractmethod
def at(self):
raise NotImplementedError
@abc.abstractmethod
def to_plotly(self): def to_plotly(self):
raise NotImplementedError raise NotImplementedError
...@@ -10,13 +10,21 @@ plotly = gdi.get('plotly') ...@@ -10,13 +10,21 @@ plotly = gdi.get('plotly')
# TODO: add nice height and width defaults for multifigure # TODO: add nice height and width defaults for multifigure
class MultiFigure(FigureBase): class MultiFigure(FigureBase):
def __init__(self, rows, columns, title=None, width=None, height=None, def __init__(self, subfigures, title=None, width=None, height=None):
subfigures=None):
if 'plotly' not in gdi: if 'plotly' not in gdi:
raise ImportError("The module plotly is needed but not available.") raise ImportError("The module plotly is needed but not available.")
super(MultiFigure, self).__init__(title, width, height) super(MultiFigure, self).__init__(title, width, height)
self.subfigures = np.empty((rows, columns), dtype=np.object) if subfigures is not None:
self.subfigures[:] = subfigures self.subfigures = np.asarray(subfigures, dtype=np.object)
if len(self.subfigures.shape) != 2:
raise ValueError("Subfigures must be a two-dimensional array.")
def at(self, subfigures):
return MultiFigure(subfigures=subfigures,
title=self.title,
width=self.width,
height=self.height)
@property @property
def rows(self): def rows(self):
...@@ -26,15 +34,15 @@ class MultiFigure(FigureBase): ...@@ -26,15 +34,15 @@ class MultiFigure(FigureBase):
def columns(self): def columns(self):
return self.subfigures.shape[1] return self.subfigures.shape[1]
def add_subfigure(self, figure, row, column):
self.subfigures[row, column] = figure
def to_plotly(self): def to_plotly(self):
title_extractor = lambda z: z.title if z else "" title_extractor = lambda z: z.title if z else ""
sub_titles = tuple(np.vectorize(title_extractor)(self.subfigures.flatten())) sub_titles = tuple(np.vectorize(title_extractor)(
self.subfigures.flatten()))
specs_setter = lambda z: {'is_3d': True} if isinstance(z, Figure3D) else {} specs_setter = lambda z: ({'is_3d': True}
sub_specs = list(map(list, np.vectorize(specs_setter)(self.subfigures))) if isinstance(z, Figure3D) else {})
sub_specs = list(map(list, np.vectorize(specs_setter)(
self.subfigures)))
multi_figure_plotly_object = plotly.tools.make_subplots( multi_figure_plotly_object = plotly.tools.make_subplots(
self.rows, self.rows,
...@@ -46,7 +54,7 @@ class MultiFigure(FigureBase): ...@@ -46,7 +54,7 @@ class MultiFigure(FigureBase):
width=self.width, width=self.width,
title=self.title) title=self.title)
#TODO resolve bug with titles and 3D subplots # TODO resolve bug with titles and 3D subplots
i = 1 i = 1
for index, fig in np.ndenumerate(self.subfigures): for index, fig in np.ndenumerate(self.subfigures):
......
...@@ -10,16 +10,26 @@ pyHealpix = gdi.get('pyHealpix') ...@@ -10,16 +10,26 @@ pyHealpix = gdi.get('pyHealpix')
class GLMollweide(Heatmap): class GLMollweide(Heatmap):
def __init__(self, data, nlat, nlon, color_map=None, webgl=False, def __init__(self, data, xsize=800, color_map=None,
smoothing=False): # smoothing 'best', 'fast', False webgl=False, smoothing=False):
# smoothing 'best', 'fast', False
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise ImportError( raise ImportError(
"The module pyHealpix is needed but not available.") "The module pyHealpix is needed but not available.")
self.xsize = xsize
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing)
def at(self, data):
if isinstance(data, list): if isinstance(data, list):
data = [self._mollview(d) for d in data] data = [self._mollview(d) for d in data]
else: else:
data = self._mollview(data, nlat, nlon) data = self._mollview(data)
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing) return GLMollweide(data=data,
xsize=self.xsize,
color_map=self.color_map,
webgl=self.webgl,
smoothing=self.smoothing)
@staticmethod @staticmethod
def _find_closest(A, target): def _find_closest(A, target):
...@@ -31,10 +41,13 @@ class GLMollweide(Heatmap): ...@@ -31,10 +41,13 @@ class GLMollweide(Heatmap):
idx -= target - left < right - target idx -= target - left < right - target
return idx return idx
def _mollview(self, x, nlat, nlon, xsize=800): def _mollview(self, x):
xsize = self.xsize
nlat = x.shape[0]
nlon = x.shape[1]
res, mask, theta, phi = mollweide_helper(xsize) res, mask, theta, phi = mollweide_helper(xsize)
x = np.reshape(x, (nlat, nlon))
ra = np.linspace(0, 2*np.pi, nlon+1) ra = np.linspace(0, 2*np.pi, nlon+1)
dec = pyHealpix.GL_thetas(nlat) dec = pyHealpix.GL_thetas(nlat)
ilat = self._find_closest(dec, theta) ilat = self._find_closest(dec, theta)
......
...@@ -6,14 +6,8 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper ...@@ -6,14 +6,8 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Heatmap(PlotlyWrapper): class Heatmap(PlotlyWrapper):
def __init__(self, data, color_map=None, webgl=False, def __init__(self, data, color_map=None, webgl=False, smoothing=False):
smoothing=False): # smoothing 'best', 'fast', False # smoothing 'best', 'fast', False
if isinstance(data, list):
self.data = np.zeros((data[0].shape))
for arr in data:
self.data = np.add(self.data, arr)
else:
self.data = data
if color_map is not None: if color_map is not None:
if not isinstance(color_map, Colormap): if not isinstance(color_map, Colormap):
...@@ -22,6 +16,19 @@ class Heatmap(PlotlyWrapper): ...@@ -22,6 +16,19 @@ class Heatmap(PlotlyWrapper):
self.color_map = color_map self.color_map = color_map
self.webgl = webgl self.webgl = webgl
self.smoothing = smoothing self.smoothing = smoothing
self.data = data
def at(self, data):
if isinstance(data, list):
temp_data = np.zeros((data[0].shape))
for arr in data:
temp_data = np.add(temp_data, arr)
else:
temp_data = data
return Heatmap(data=temp_data,
color_map=self.color_map,
webgl=self.webgl,
smoothing=self.smoothing)
@property @property
def figure_dimension(self): def figure_dimension(self):
...@@ -29,7 +36,9 @@ class Heatmap(PlotlyWrapper): ...@@ -29,7 +36,9 @@ class Heatmap(PlotlyWrapper):
def to_plotly(self): def to_plotly(self):
plotly_object = dict() plotly_object = dict()
plotly_object['z'] = self.data plotly_object['z'] = self.data
plotly_object['showscale'] = False plotly_object['showscale'] = False
if self.color_map: if self.color_map:
plotly_object['colorscale'] = self.color_map.to_plotly() plotly_object['colorscale'] = self.color_map.to_plotly()
......
...@@ -10,18 +10,27 @@ pyHealpix = gdi.get('pyHealpix') ...@@ -10,18 +10,27 @@ pyHealpix = gdi.get('pyHealpix')
class HPMollweide(Heatmap): class HPMollweide(Heatmap):
def __init__(self, data, color_map=None, webgl=False, def __init__(self, data, xsize=800, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False smoothing=False): # smoothing 'best', 'fast', False
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise ImportError( raise ImportError(
"The module pyHealpix is needed but not available.") "The module pyHealpix is needed but not available.")
self.xsize = xsize
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing)
def at(self, data):
if isinstance(data, list): if isinstance(data, list):
data = [self._mollview(d) for d in data] data = [self._mollview(d) for d in data]
else: else:
data = self._mollview(data) data = self._mollview(data)
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing) return HPMollweide(data=data,
xsize=self.xsize,
color_map=self.color_map,
webgl=self.webgl,
smoothing=self.smoothing)
def _mollview(self, x, xsize=800): def _mollview(self, x):
xsize = self.xsize
res, mask, theta, phi = mollweide_helper(xsize) res, mask, theta, phi = mollweide_helper(xsize)
ptg = np.empty((phi.size, 2), dtype=np.float64) ptg = np.empty((phi.size, 2), dtype=np.float64)
......
...@@ -4,16 +4,14 @@ from scatter_plot import ScatterPlot ...@@ -4,16 +4,14 @@ from scatter_plot import ScatterPlot
class Cartesian(ScatterPlot): class Cartesian(ScatterPlot):
def __init__(self, x, y, label, line, marker, showlegend=True): def __init__(self, data, label, line, marker, showlegend=True):
super(Cartesian, self).__init__(label, line, marker) super(Cartesian, self).__init__(data, label, line, marker)
self.x = x
self.y = y
self.showlegend = showlegend self.showlegend = showlegend
@abstractmethod @abstractmethod
def to_plotly(self): def to_plotly(self):
plotly_object = super(Cartesian, self).to_plotly() plotly_object = super(Cartesian, self).to_plotly()
plotly_object['x'] = self.x plotly_object['x'] = self.data[0]
plotly_object['y'] = self.y plotly_object['y'] = self.data[1]
plotly_object['showlegend'] = self.showlegend plotly_object['showlegend'] = self.showlegend
return plotly_object return plotly_object
...@@ -4,17 +4,20 @@ from cartesian import Cartesian ...@@ -4,17 +4,20 @@ from cartesian import Cartesian
class Cartesian2D(Cartesian): class Cartesian2D(Cartesian):
def __init__(self, x=None, y=None, x_start=0, x_step=1, def __init__(self, data, label='', line=None, marker=None, showlegend=True,
label='', line=None, marker=None, showlegend=True,
webgl=True): webgl=True):
if y is None: super(Cartesian2D, self).__init__(data, label, line, marker,
raise Exception('Error: no y data to plot')
if x is None:
x = range(x_start, len(y) * x_step, x_step)
super(Cartesian2D, self).__init__(x, y, label, line, marker,
showlegend) showlegend)
self.webgl = webgl self.webgl = webgl
def at(self, data):
return Cartesian2D(data=data,
label=self.label,
line=self.line,
marker=self.marker,
showlegend=self.showlegend,
webgl=self.webgl)
@property @property
def figure_dimension(self): def figure_dimension(self):
return 2 return 2
......
...@@ -4,11 +4,17 @@ from cartesian import Cartesian ...@@ -4,11 +4,17 @@ from cartesian import Cartesian
class Cartesian3D(Cartesian): class Cartesian3D(Cartesian):
def __init__(self, x, y, z, label='', line=None, marker=None, def __init__(self, data, label='', line=None, marker=None,
showlegend=True): showlegend=True):
super(Cartesian3D, self).__init__(x, y, label, line, marker, super(Cartesian3D, self).__init__(data, label, line, marker,
showlegend) showlegend)
self.z = z
def at(self, data):
return Cartesian3D(data=data,
label=self.label,
line=self.line,
marker=self.marker,
showlegend=self.showlegend)
@property @property
def figure_dimension(self): def figure_dimension(self):
...@@ -16,6 +22,6 @@ class Cartesian3D(Cartesian): ...@@ -16,6 +22,6 @@ class Cartesian3D(Cartesian):
def to_plotly(self): def to_plotly(self):
plotly_object = super(Cartesian3D, self).to_plotly() plotly_object = super(Cartesian3D, self).to_plotly()
plotly_object['z'] = self.z plotly_object['z'] = self.data[2]
plotly_object['type'] = 'scatter3d' plotly_object['type'] = 'scatter3d'
return plotly_object return plotly_object
...@@ -3,26 +3,31 @@ from scatter_plot import ScatterPlot ...@@ -3,26 +3,31 @@ from scatter_plot import ScatterPlot
class Geo(ScatterPlot): class Geo(ScatterPlot):
def __init__(self, lon, lat, label='', line=None, marker=None, def __init__(self, data, label='', line=None, marker=None,
proj='mollweide'): projection='mollweide'):
""" """
proj: mollweide or mercator proj: mollweide or mercator
""" """
super.__init__(label, line, marker) super(Geo, self).__init__(label, line, marker)
self.lon = lon self.projection = projection
self.lat = lat
self.projection = proj def at(self, data):
return Geo(data=data,
label=self.label,
line=self.line,
marker=self.marker,
projection=self.projection)
@property @property
def figure_dimension(self): def figure_dimension(self):
return 2 return 2
def _to_plotly(self): def _to_plotly(self, data):
plotly_object = super(Geo, self).to_plotly() plotly_object = super(Geo, self).to_plotly()
plotly_object['type'] = 'scattergeo' plotly_object['type'] = self.projection
plotly_object['lon'] = self.lon plotly_object['lon'] = data[0]
plotly_object['lat'] = self.lat plotly_object['lat'] = data[1]
if self.line: if self.line:
plotly_object['mode'] = 'lines' plotly_object['mode'] = 'lines'
return plotly_object return plotly_object
...@@ -7,7 +7,8 @@ from nifty.plotting.descriptors import Marker,\ ...@@ -7,7 +7,8 @@ from nifty.plotting.descriptors import Marker,\
class ScatterPlot(PlotlyWrapper): class ScatterPlot(PlotlyWrapper):
def __init__(self, label, line, marker): def __init__(self, data, label, line, marker):
self.data = data
self.label = label self.label = label
self.line = line self.line = line
self.marker = marker self.marker = marker
...@@ -15,6 +16,10 @@ class ScatterPlot(PlotlyWrapper): ...@@ -15,6 +16,10 @@ class ScatterPlot(PlotlyWrapper):
self.marker = Marker() self.marker = Marker()
self.line = Line() self.line = Line()
@abc.abstractmethod
def at(self, data):
raise NotImplementedError
@abc.abstractproperty @abc.abstractproperty
def figure_dimension(self): def figure_dimension(self):
raise NotImplementedError raise NotImplementedError
......
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
from healpix_plotter import HealpixPlotter from healpix_plotter import HealpixPlotter
from gl_plotter import GLPlotter from gl_plotter import GLPlotter
from power_plotter import PowerPlotter from power_plotter import PowerPlotter
from rg2d_plotter import RG2DPlotter
import numpy as np
from nifty.spaces import GLSpace from nifty.spaces import GLSpace
from nifty.plotting.figures import Figure2D from nifty.plotting.figures import Figure2D
from nifty.plotting.plots import GLMollweide from nifty.plotting.plots import GLMollweide
from .plotter import Plotter from .plotter_base import PlotterBase
class GLPlotter(Plotter): class GLPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None): def __init__(self, interactive=False, path='.', title="", color_map=None):
super(GLPlotter, self).__init__(interactive, path, title)
self.color_map = color_map self.color_map = color_map