Commit 1aaf3335 authored by Theo Steininger's avatar Theo Steininger
Browse files

Added plotter class. Not fully implemented yet.

parent ea6c6e12
Pipeline #9991 failed with stage
in 35 minutes and 58 seconds
......@@ -12,7 +12,8 @@ dependency_injector = keepers.DependencyInjector(
'gfft',
('nifty.dummys.gfft_dummy', 'gfft_dummy'),
'healpy',
'libsharp_wrapper_gl'])
'libsharp_wrapper_gl',
'plotly'])
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
......
# -*- coding: utf-8 -*-
import abc
import os
import plotly
from plotly import tools
import plotly.offline as ply
from keepers import Loggable
from nifty.spaces.space import Space
from nifty.field_types.field_type import FieldType
import nifty.nifty_utilities as utilities
plotly.offline.init_notebook_mode()
class Plotter(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', stack_subplots=False,
color_scale):
self.interactive = interactive
self.path = path
self.stack_subplots = stack_subplots
self.color_scale = None
self.title = 'uiae'
@abc.abstractproperty
def domain(self):
return (Space,)
@abc.abstractproperty
def field_type(self):
return (FieldType,)
@property
def interactive(self):
return self._interactive
@interactive.setter
def interactive(self, interactive):
self._interactive = bool(interactive)
@property
def path(self):
return self._path
@path.setter
def path(self, new_path):
self._path = os.path.normpath(new_path)
@property
def stack_subplots(self):
return self._stack_subplots
@stack_subplots.setter
def stack_subplots(self, stack_subplots):
self._stack_subplots = bool(stack_subplots)
@abc.abstractmethod
def plot(self, field, spaces=None, types=None, data_preselector=None):
# if fields is a list, create a new field with appended
# field_type = field_array and copy individual parts into the new field
spaces = utilities.cast_axis_to_tuple(spaces, len(field.domain))
types = utilities.cast_axis_to_tuple(types, len(field.field_type))
if field.domain[spaces] != self.domain:
raise AttributeError("Given space(s) of input field-domain do not "
"match the plotters domain.")
if field.field_type[spaces] != self.field_type:
raise AttributeError("Given field_type(s) of input field-domain "
"do not match the plotters field_type.")
# iterate over the individual slices in order to compose the figure
# -> make a d2o.get_full_data() (for rank==0 only?)
# add clipping
# no_subplot
result_figure = self._create_individual_plot(data)
# non-trivial subplots
result_figure = tools.make_subplots(cols=2, rows='total_iterator%2 + 1',
subplot_titles='iterator_index')
self._finalize_figure(result_figure)
def _create_individual_plot(self, data):
pass
def _finalize_figure(self, figure):
if self.interactive:
ply.iplot(figure)
else:
# is there a use for ply.plot when one has no intereset in
# saving a file?
# -> check for different file types
# -> store the file to disk (MPI awareness?)
import descriptors
import plots
import figures
from plotting import plot, plot_image
from plottable import Plottable
__all__ = ['descriptors', 'plots', 'figures', 'plot', 'plot_image', 'Plottable']
# -*- coding: utf-8 -*-
from plotly_wrapper import _PlotlyWrapper
class Marker(_PlotlyWrapper):
# find symbols at: https://plot.ly/python/reference/#scatter-marker-symbol
def __init__(self, color=None, size=None, symbol=None, opacity=None):
self.color = color
self.size = size
self.symbol = symbol
self.opacity = opacity
def _to_plotly(self):
return dict(color=self.color, size=self.size, symbol=self.symbol, opacity=self.opacity)
class Line(_PlotlyWrapper):
def __init__(self, color=None, width=None):
self.color = color
self.width = width
def _to_plotly(self):
return dict(color=self.color, width=self.width)
class Axis(_PlotlyWrapper):
def __init__(self, text=None, font='', color='', log=False, aspect_ratio=None):
self.text = text
self.font = font
self.color = color
self.log = log
self.aspect_ratio = aspect_ratio
def _to_plotly(self):
ply_object = dict()
if self.text:
ply_object.update(dict(
title=self.text,
titlefont=dict(
family=self.font,
color=self.color
)
))
if self.log:
ply_object['type'] = 'log'
return ply_object
from figure import Figure
__all__ = ['Figure']
\ No newline at end of file
from nifty.plotting.plotly_wrapper import _PlotlyWrapper
from nifty.plotting.plots.private import _Plot2D, _Plot3D
from figure_internal import _2dFigure, _3dFigure, _MapFigure
from nifty.plotting.plots import HeatMap
from nifty.plotting.figures.util import validate_plots
class Figure(_PlotlyWrapper):
def __init__(self, data, title=None, width=None, height=None, xaxis=None, yaxis=None, zaxis=None):
kind, data = validate_plots(data)
if kind == _Plot2D:
if isinstance(data[0], HeatMap) and not width and not height:
x = len(data[0].data)
y = len(data[0].data[0])
if x > y:
width = 1000
height = int(1000*y/x)
else:
height = 1000
width = int(1000 * y / x)
self.internal = _2dFigure(data, title, width, height, xaxis, yaxis)
elif kind == _Plot3D:
self.internal = _3dFigure(data, title, width, height, xaxis, yaxis, zaxis)
else:
self.internal = _MapFigure(data, title)
def _to_plotly(self):
return self.internal._to_plotly()
from abc import ABCMeta, abstractmethod
from nifty.plotting.plotly_wrapper import _PlotlyWrapper
class _BaseFigure(_PlotlyWrapper):
__metaclass__ = ABCMeta
def __init__(self, data, title, width, height):
self.data = data
self.title = title
self.width = width
self.height = height
@abstractmethod
def _to_plotly(self):
ply_object = dict(
data=[plt._to_plotly() for plt in self.data],
layout=dict(
title=self.title,
scene = dict(
aspectmode='cube'
),
autosize=True,
width=self.width,
height=self.height
)
)
return ply_object
class _2dFigure(_BaseFigure):
def __init__(self, data, title=None, width=None, height=None, xaxis=None, yaxis=None):
_BaseFigure.__init__(self, data, title, width, height)
self.xaxis = xaxis
self.yaxis = yaxis
def _to_plotly(self):
ply_object = _BaseFigure._to_plotly(self)
if self.xaxis or self.yaxis:
ply_object['layout']['scene']['aspectratio'] = dict()
if self.xaxis:
ply_object['layout']['xaxis'] = self.xaxis._to_plotly()
ply_object['layout']['scene']['aspectratio']['x'] = self.xaxis.aspect_ratio
if self.yaxis:
ply_object['layout']['yaxis'] = self.yaxis._to_plotly()
ply_object['layout']['scene']['aspectratio']['y'] = self.yaxis.aspect_ratio
return ply_object
class _3dFigure(_2dFigure):
def __init__(self, data, title=None, width=None, height=None, xaxis=None, yaxis=None, zaxis=None):
_2dFigure.__init__(self, data, title, width, height, xaxis, yaxis)
self.zaxis = zaxis
def _to_plotly(self):
ply_object = _BaseFigure._to_plotly(self)
if self.xaxis or self.yaxis or self.zaxis:
ply_object['layout']['scene']['aspectratio'] = dict()
if self.xaxis:
ply_object['layout']['scene']['xaxis'] = self.xaxis._to_plotly()
ply_object['layout']['scene']['aspectratio']['x'] = self.xaxis.aspect_ratio
if self.yaxis:
ply_object['layout']['scene']['yaxis'] = self.yaxis._to_plotly()
ply_object['layout']['scene']['aspectratio']['y'] = self.yaxis.aspect_ratio
if self.zaxis:
ply_object['layout']['scene']['zaxis'] = self.zaxis._to_plotly()
ply_object['layout']['scene']['aspectratio']['z'] = self.zaxis.aspect_ratio
return ply_object
class _MapFigure(_BaseFigure):
def __init__(self, data, title, width=None, height=None):
_BaseFigure.__init__(self, data, title, width, height)
def _to_plotly(self):
ply_object = _BaseFigure._to_plotly(self)
# print(ply_object, ply_object['layout'])
# ply_object['layout']['geo'] = dict(
# projection=dict(type=self.data.projection),
# showcoastlines=False
# )
return ply_object
from nifty.plotting.plotly_wrapper import _PlotlyWrapper
from nifty.plotting.figures.util import validate_plots
class MultiFigure(_PlotlyWrapper):
def __init__(self, cols, rows, title=None, width=None, height=None):
self.cols = cols
self.rows = rows
self.title = title
self.width = width
self.height = height
def addSubfigure(self, data, title=None, width=None, height=None):
kind, data = validate_plots(data)
def _to_plotly(self):
pass
from nifty.plotting.plots.private import _Plot2D, _Plot3D
from nifty.plotting.plots import ScatterGeoMap
def validate_plots(data):
if not data:
raise Exception('Error: no plots given')
if type(data) != list:
data = [data]
if isinstance(data[0], _Plot2D):
kind = _Plot2D
elif isinstance(data[0], _Plot3D):
kind = _Plot3D
elif isinstance(data[0], ScatterGeoMap):
kind = ScatterGeoMap
else:
kind = None
if kind:
for plt in data:
if not isinstance(plt, kind):
raise Exception(
"""Error: Plots are not of the right kind!
Compatible types are:
- Scatter2D and HeatMap
- Scatter3D
- ScatterMap""")
else:
raise Exception('Error: plot type unknown')
return kind, data
from abc import ABCMeta, abstractmethod
class _PlotlyWrapper:
__metaclass__ = ABCMeta
@abstractmethod
def _to_plotly(self):
pass
\ No newline at end of file
from scatter import *
from heatmap import *
from geomap import *
__all__ = ['Scatter2D', 'Scatter3D', 'ScatterGeoMap', 'HeatMap']
from nifty.plotting.plots.private import _PlotBase
class ScatterGeoMap(_PlotBase):
def __init__(self, lon, lat, label='', line=None, marker=None, proj='mollweide'): # or 'mercator'
_PlotBase.__init__(self, label, line, marker)
self.lon = lon
self.lat = lat
self.projection = proj
def _to_plotly(self):
ply_object = _PlotBase._to_plotly(self)
ply_object['type'] = 'scattergeo'
ply_object['lon'] = self.lon
ply_object['lat'] = self.lat
if self.line:
ply_object['mode'] = 'lines'
return ply_object
from nifty.plotting.plots.private import _PlotBase, _Plot2D
class HeatMap(_PlotBase, _Plot2D):
def __init__(self, data, label='', line=None, marker=None, webgl=False, smoothing=False): # smoothing 'best', 'fast', False
_PlotBase.__init__(self, label, line, marker)
self.data = data
self.webgl = webgl
self.smoothing = smoothing
def _to_plotly(self):
ply_object = _PlotBase._to_plotly(self)
ply_object['z'] = self.data
if self.webgl:
ply_object['type'] = 'heatmapgl'
else:
ply_object['type'] = 'heatmap'
if self.smoothing:
ply_object['zsmooth'] = self.smoothing
return ply_object
\ No newline at end of file
from abc import ABCMeta, abstractmethod
from nifty.plotting.plotly_wrapper import _PlotlyWrapper
class _PlotBase(_PlotlyWrapper):
__metaclass__ = ABCMeta
def __init__(self, label, line, marker):
self.label = label
self.line = line
self.marker = marker
@abstractmethod
def _to_plotly(self):
ply_object = dict()
ply_object['name'] = self.label
if self.line and self.marker:
ply_object['mode'] = 'lines+markers'
ply_object['line'] = self.line._to_plotly()
ply_object['marker'] = self.marker._to_plotly()
elif self.line:
ply_object['mode'] = 'markers'
ply_object['line'] = self.line._to_plotly()
elif self.marker:
ply_object['mode'] = 'line'
ply_object['marker'] = self.marker._to_plotly()
return ply_object
class _Scatter2DBase(_PlotBase):
__metaclass__ = ABCMeta
def __init__(self, x, y, label, line, marker):
_PlotBase.__init__(self, label, line, marker)
self.x = x
self.y = y
@abstractmethod
def _to_plotly(self):
ply_object = _PlotBase._to_plotly(self)
ply_object['x'] = self.x
ply_object['y'] = self.y
return ply_object
class _Plot2D:
pass # only used as a labeling system for the plots represntation
class _Plot3D:
pass # only used as a labeling system for the plots represntation
from nifty.plotting.plots.private import _Scatter2DBase, _Plot2D, _Plot3D
class Scatter2D(_Scatter2DBase, _Plot2D):
def __init__(self, x=None, y=None, x_start=0, x_step=1,
label='', line=None, marker=None, webgl=True):
if y is None:
raise Exception('Error: no y data to plot')
if x is None:
x = range(x_start, len(y) * x_step, x_step)
_Scatter2DBase.__init__(self, x, y, label, line, marker)
self.webgl = webgl
def _to_plotly(self):
ply_object = _Scatter2DBase._to_plotly(self)
if self.webgl:
ply_object['type'] = 'scattergl'
else:
ply_object['type'] = 'scatter'
return ply_object
class Scatter3D(_Scatter2DBase, _Plot3D):
def __init__(self, x, y, z, label='', line=None, marker=None):
_Scatter2DBase.__init__(self, x, y, label, line, marker)
self.z = z
def _to_plotly(self):
ply_object = _Scatter2DBase._to_plotly(self)
ply_object['z'] = self.z
ply_object['type'] = 'scatter3d'
return ply_object
from abc import ABCMeta, abstractmethod
class Plottable(object):
__metaclass__ = ABCMeta
@abstractmethod
def plot(self):
pass
import os
from PIL import Image
import plotly.offline as ply_offline
import plotly.plotly as ply
def plot(figure, filename=None):
if not filename:
filename = os.path.abspath('/tmp/temp-plot.html')
ply_offline.plot(figure._to_plotly(), filename=filename)
def plot_image(figure, filename=None, show=False):
if not filename:
filename = os.path.abspath('temp-plot.jpeg')
ply_obj = figure._to_plotly()
ply.image.save_as(ply_obj, filename=filename)
if show:
img = Image.open(filename)
img.show()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment