Commit db23017b authored by Theo Steininger's avatar Theo Steininger

plotting now produces semi-reasonable results

parent d11f0851
Pipeline #15188 passed with stages
in 12 minutes and 43 seconds
......@@ -5,12 +5,14 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Axis(PlotlyWrapper):
def __init__(self, text=None, font='', color='', log=False,
show_grid=True):
font_size=18, show_grid=True, visible=True):
self.text = text
self.font = font
self.color = color
self.log = log
self.font_size = int(font_size)
self.show_grid = show_grid
self.visible = visible
def to_plotly(self):
ply_object = dict()
......@@ -19,11 +21,14 @@ class Axis(PlotlyWrapper):
title=self.text,
titlefont=dict(
family=self.font,
color=self.color
color=self.color,
size=self.font_size
)
))
if self.log:
ply_object['type'] = 'log'
if not self.show_grid:
ply_object['showgrid'] = False
ply_object['visible'] = self.visible
ply_object['tickfont'] = {'size': self.font_size}
return ply_object
......@@ -7,8 +7,13 @@ from nifty.plotting.plots import Heatmap, HPMollweide, GLMollweide
class Figure2D(FigureFromPlot):
def __init__(self, plots, title=None, width=None, height=None,
xaxis=None, yaxis=None):
if plots is not None:
width = width if width is not None else plots[0].default_width()
height = \
height if height is not None else plots[0].default_height()
(xaxis, yaxis) = \
xaxis if xaxis is not None else plots[0].default_axes()
if isinstance(plots[0], Heatmap) and width is None and \
height is None:
(x, y) = plots[0].data.shape
......@@ -29,9 +34,10 @@ class Figure2D(FigureFromPlot):
self.xaxis = xaxis
self.yaxis = yaxis
def at(self, plots):
def at(self, plots, title=None):
title = title if title is not None else self.title
return Figure2D(plots=plots,
title=self.title,
title=title,
width=self.width,
height=self.height,
xaxis=self.xaxis,
......
......@@ -5,14 +5,21 @@ from figure_from_plot import FigureFromPlot
class Figure3D(FigureFromPlot):
def __init__(self, plots, title=None, width=None, height=None,
xaxis=None, yaxis=None, zaxis=None):
if plots is not None:
width = width if width is not None else plots[0].default_width()
height = \
height if height is not None else plots[0].default_height()
(xaxis, yaxis, zaxis) = \
xaxis if xaxis is not None else plots[0].default_axes()
super(Figure3D, self).__init__(plots, title, width, height)
self.xaxis = xaxis
self.yaxis = yaxis
self.zaxis = zaxis
def at(self, plots):
def at(self, plots, title=None):
title = title if title is not None else self.title
return Figure3D(plots=plots,
title=self.title,
title=title,
width=self.width,
height=self.height,
xaxis=self.xaxis,
......
......@@ -12,7 +12,7 @@ class FigureBase(PlotlyWrapper):
self.height = height
@abc.abstractmethod
def at(self):
def at(self, title=None):
raise NotImplementedError
@abc.abstractmethod
......
......@@ -4,6 +4,8 @@ from nifty import dependency_injector as gdi
from heatmap import Heatmap
import numpy as np
from nifty.plotting.descriptors import Axis
from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix')
......@@ -11,14 +13,15 @@ pyHealpix = gdi.get('pyHealpix')
class GLMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None,
webgl=False, smoothing=False):
webgl=False, smoothing=False, zmin=None, zmax=None):
# smoothing 'best', 'fast', False
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
self.xsize = xsize
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing)
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing,
zmin, zmax)
def at(self, data):
if isinstance(data, list):
......@@ -55,3 +58,12 @@ class GLMollweide(Heatmap):
ilon = np.where(ilon == nlon, 0, ilon)
res[mask] = x[ilat, ilon]
return res
def default_width(self):
return 1400
def default_height(self):
return 700
def default_axes(self):
return (Axis(visible=False), Axis(visible=False))
# -*- coding: utf-8 -*-
import numpy as np
from nifty.plotting.descriptors import Axis
from nifty.plotting.colormap import Colormap
from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Heatmap(PlotlyWrapper):
def __init__(self, data, color_map=None, webgl=False, smoothing=False):
def __init__(self, data, color_map=None, webgl=False, smoothing=False,
zmin=None, zmax=None):
# smoothing 'best', 'fast', False
if color_map is not None:
......@@ -17,6 +20,9 @@ class Heatmap(PlotlyWrapper):
self.webgl = webgl
self.smoothing = smoothing
self.data = data
self.zmin = zmin
self.zmax = zmax
self._font_size = 18
def at(self, data):
if isinstance(data, list):
......@@ -28,7 +34,9 @@ class Heatmap(PlotlyWrapper):
return Heatmap(data=temp_data,
color_map=self.color_map,
webgl=self.webgl,
smoothing=self.smoothing)
smoothing=self.smoothing,
zmin=self.zmin,
zmax=self.zmax)
@property
def figure_dimension(self):
......@@ -38,11 +46,13 @@ class Heatmap(PlotlyWrapper):
plotly_object = dict()
plotly_object['z'] = self.data
plotly_object['zmin'] = self.zmin
plotly_object['zmax'] = self.zmax
plotly_object['showscale'] = False
plotly_object['showscale'] = True
plotly_object['colorbar'] = {'tickfont': {'size': self._font_size}}
if self.color_map:
plotly_object['colorscale'] = self.color_map.to_plotly()
plotly_object['colorbar'] = dict(title=self.color_map.name, x=0.42)
if self.webgl:
plotly_object['type'] = 'heatmapgl'
else:
......@@ -50,3 +60,14 @@ class Heatmap(PlotlyWrapper):
if self.smoothing:
plotly_object['zsmooth'] = self.smoothing
return plotly_object
def default_width(self):
return 700
def default_height(self):
(x, y) = self.data.shape
return int(700 * y / x)
def default_axes(self):
return (Axis(font_size=self._font_size),
Axis(font_size=self._font_size))
......@@ -4,6 +4,8 @@ from nifty import dependency_injector as gdi
from heatmap import Heatmap
import numpy as np
from nifty.plotting.descriptors import Axis
from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix')
......@@ -11,12 +13,13 @@ pyHealpix = gdi.get('pyHealpix')
class HPMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
smoothing=False, zmin=None, zmax=None): # smoothing 'best', 'fast', False
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
self.xsize = xsize
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing)
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing,
zmin, zmax)
def at(self, data):
if isinstance(data, list):
......@@ -39,3 +42,12 @@ class HPMollweide(Heatmap):
base = pyHealpix.Healpix_Base(int(np.sqrt(x.size/12)), "RING")
res[mask] = x[base.ang2pix(ptg)]
return res
def default_width(self):
return 1400
def default_height(self):
return 700
def default_axes(self):
return (Axis(visible=False), Axis(visible=False))
# -*- coding: utf-8 -*-
from nifty.plotting.descriptors import Axis
from cartesian import Cartesian
......@@ -30,3 +31,6 @@ class Cartesian2D(Cartesian):
plotly_object['type'] = 'scatter'
return plotly_object
def default_axes(self):
return (Axis(), Axis())
# -*- coding: utf-8 -*-
from nifty.plotting.descriptors import Axis
from cartesian import Cartesian
......@@ -25,3 +26,6 @@ class Cartesian3D(Cartesian):
plotly_object['z'] = self.data[2]
plotly_object['type'] = 'scatter3d'
return plotly_object
def default_axes(self):
return (Axis(), Axis(), Axis())
from nifty.plotting.descriptors import Axis
from scatter_plot import ScatterPlot
......@@ -31,3 +32,6 @@ class Geo(ScatterPlot):
if self.line:
plotly_object['mode'] = 'lines'
return plotly_object
def default_axes(self):
return (Axis(), Axis())
......@@ -40,3 +40,13 @@ class ScatterPlot(PlotlyWrapper):
ply_object['marker'] = self.marker.to_plotly()
return ply_object
def default_width(self):
return 700
def default_height(self):
return 700
@abc.abstractmethod
def default_axes(self):
raise NotImplementedError
......@@ -9,9 +9,9 @@ from .plotter_base import PlotterBase
class GLPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None):
def __init__(self, interactive=False, path='plot.html', color_map=None):
self.color_map = color_map
super(GLPlotter, self).__init__(interactive, path, title)
super(GLPlotter, self).__init__(interactive, path)
@property
def domain_classes(self):
......
......@@ -6,9 +6,9 @@ from .plotter_base import PlotterBase
class HealpixPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None):
def __init__(self, interactive=False, path='plot.html', color_map=None):
self.color_map = color_map
super(HealpixPlotter, self).__init__(interactive, path, title)
super(HealpixPlotter, self).__init__(interactive, path)
@property
def domain_classes(self):
......
......@@ -30,12 +30,11 @@ rank = d2o.config.dependency_injector[
class PlotterBase(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', title=""):
def __init__(self, interactive=False, path='plot.html'):
if plotly is None:
raise ImportError("The module plotly is needed but not available.")
self.interactive = interactive
self.path = path
self.title = str(title)
self.plot = self._initialize_plot()
self.figure = self._initialize_figure()
......@@ -61,7 +60,8 @@ class PlotterBase(Loggable, object):
def path(self, new_path):
self._path = os.path.normpath(new_path)
def __call__(self, fields, spaces=None, data_extractor=None, labels=None):
def __call__(self, fields, spaces=None, data_extractor=None, labels=None,
path=None, title=None):
if isinstance(fields, Field):
fields = [fields]
elif not isinstance(fields, list):
......@@ -72,6 +72,9 @@ class PlotterBase(Loggable, object):
if spaces is None:
spaces = tuple(range(len(fields[0].domain)))
if len(spaces) != len(self.domain_classes):
raise ValueError("Domain mismatch between input and plotter.")
axes = []
plot_domain = []
for space_index in spaces:
......@@ -91,9 +94,9 @@ class PlotterBase(Loggable, object):
spaces))
for (current_data, field) in zip(data_list, fields)]]
figures = [self.figure.at(plots) for plots in plots_list]
figures = [self.figure.at(plots, title=title) for plots in plots_list]
self._finalize_figure(figures)
self._finalize_figure(figures, path=path)
def _get_data_from_field(self, field, spaces, data_extractor):
for i, space_index in enumerate(spaces):
......@@ -117,7 +120,7 @@ class PlotterBase(Loggable, object):
def _initialize_multifigure(self):
return MultiFigure(subfigures=None)
def _finalize_figure(self, figures):
def _finalize_figure(self, figures, path=None):
if len(figures) > 1:
rows = (len(figures) + 1)//2
figure_array = np.empty((2*rows), dtype=np.object)
......@@ -128,5 +131,6 @@ class PlotterBase(Loggable, object):
else:
final_figure = figures[0]
path = self.path if path is None else path
plotly.offline.plot(final_figure.to_plotly(),
filename=os.path.join(self.path, self.title))
filename=path)
......@@ -11,9 +11,9 @@ from .plotter_base import PlotterBase
class PowerPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="",
line=None, marker=None):
super(PowerPlotter, self).__init__(interactive, path, title)
def __init__(self, interactive=False, path='plot.html', line=None,
marker=None):
super(PowerPlotter, self).__init__(interactive, path)
self.line = line
self.marker = marker
......
......@@ -11,9 +11,9 @@ from .plotter_base import PlotterBase
class RG1DPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="",
line=None, marker=None):
super(RG1DPlotter, self).__init__(interactive, path, title)
def __init__(self, interactive=False, path='plot.html', line=None,
marker=None):
super(RG1DPlotter, self).__init__(interactive, path)
self.line = line
self.marker = marker
......
......@@ -8,9 +8,9 @@ from .plotter_base import PlotterBase
class RG2DPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None):
def __init__(self, interactive=False, path='plot.html', color_map=None):
self.color_map = color_map
super(RG2DPlotter, self).__init__(interactive, path, title)
super(RG2DPlotter, self).__init__(interactive, path)
@property
def domain_classes(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment