Commit fe6bf4e5 authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'plotting' into 'master'

# Conflicts:
#   nifty/config/nifty_config.py
parents 2f498dbf 2794762d
Pipeline #11673 failed with stages
in 10 minutes and 1 second
......@@ -72,4 +72,4 @@ nifty_configuration = keepers.get_Configuration(
try:
nifty_configuration.load()
except:
pass
pass
\ No newline at end of file
......@@ -21,7 +21,7 @@ import numpy as np
import nifty.nifty_utilities as utilities
from nifty.operators.endomorphic_operator import EndomorphicOperator
from nifty.operators.fft_operator import FFTOperator
import smooth_util as su
from nifty.operators.smoothing_operator import smooth_util as su
from d2o import STRATEGIES
......
from plotter import Plotter
# -*- coding: utf-8 -*-
import abc
import os
import plotly
from plotly import tools
import plotly.offline as ply
from keepers import Loggable
from nifty.spaces.space import Space
from nifty.field_types.field_type import FieldType
import nifty.nifty_utilities as utilities
plotly.offline.init_notebook_mode()
class Plotter(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', stack_subplots=False,
color_scale=None):
self.interactive = interactive
self.path = path
self.stack_subplots = stack_subplots
self.color_scale = color_scale
self.title = 'uiae'
@abc.abstractproperty
def domain(self):
return (Space,)
@abc.abstractproperty
def field_type(self):
return (FieldType,)
@property
def interactive(self):
return self._interactive
@interactive.setter
def interactive(self, interactive):
self._interactive = bool(interactive)
@property
def path(self):
return self._path
@path.setter
def path(self, new_path):
self._path = os.path.normpath(new_path)
@property
def stack_subplots(self):
return self._stack_subplots
@stack_subplots.setter
def stack_subplots(self, stack_subplots):
self._stack_subplots = bool(stack_subplots)
def plot(self, field, spaces=None, types=None, slice=None):
data = self._get_data_from_field(field, spaces, types, slice)
figures = self._create_individual_plot(data)
self._finalize_figure(figures)
@abc.abstractmethod
def _get_data_from_field(self, field, spaces=None, types=None, slice=None):
# if fields is a list, create a new field with appended
# field_type = field_array and copy individual parts into the new field
spaces = utilities.cast_axis_to_tuple(spaces, len(field.domain))
types = utilities.cast_axis_to_tuple(types, len(field.field_type))
if field.domain[spaces] != self.domain:
raise AttributeError("Given space(s) of input field-domain do not "
"match the plotters domain.")
if field.field_type[spaces] != self.field_type:
raise AttributeError("Given field_type(s) of input field-domain "
"do not match the plotters field_type.")
# iterate over the individual slices in order to compose the figure
# -> make a d2o.get_full_data() (for rank==0 only?)
# add
return [1,2,3]
def _create_individual_plot(self, data):
pass
def _finalize_figure(self, figure):
pass
# is there a use for ply.plot when one has no interest in
# saving a file?
# -> check for different file types
# -> store the file to disk (MPI awareness?)
import descriptors
import plots
import figures
from plotting import plot, plot_image
from plottable import Plottable
__all__ = ['descriptors', 'plots', 'figures', 'plot', 'plot_image', 'Plottable']
from plotly_wrapper import _PlotlyWrapper
class Marker(_PlotlyWrapper):
# find symbols at: https://plot.ly/python/reference/#scatter-marker-symbol
def __init__(self, color=None, size=None, symbol=None, opacity=None):
self.color = color
self.size = size
self.symbol = symbol
self.opacity = opacity
def _to_plotly(self):
return dict(color=self.color, size=self.size, symbol=self.symbol, opacity=self.opacity)
class Line(_PlotlyWrapper):
def __init__(self, color=None, width=None):
self.color = color
self.width = width
def _to_plotly(self):
return dict(color=self.color, width=self.width)
class Axis(_PlotlyWrapper):
def __init__(self, text=None, font='', color='', log=False, show_grid=True):
self.text = text
self.font = font
self.color = color
self.log = log
self.show_grid = show_grid
def _to_plotly(self):
ply_object = dict()
if self.text:
ply_object.update(dict(
title=self.text,
titlefont=dict(
family=self.font,
color=self.color
)
))
if self.log:
ply_object['type'] = 'log'
if not self.show_grid:
ply_object['showgrid'] = False
return ply_object
from figure import Figure, MultiFigure
__all__ = ['Figure', 'MultiFigure']
\ No newline at end of file
from nifty.plotting.plots.private import _Plot2D, _Plot3D
from figure_internal import _2dFigure, _3dFigure, _MapFigure, _BaseFigure
from nifty.plotting.plots import HeatMap, MollweideHeatmap
from nifty.plotting.figures.util import validate_plots
from plotly.tools import make_subplots
class Figure(_BaseFigure):
def __init__(self, data, title=None, xaxis=None, yaxis=None, zaxis=None, width=None, height=None):
_BaseFigure.__init__(self, data, title, width, height)
kind, self.data = validate_plots(data)
if kind == _Plot2D:
if isinstance(self.data[0], HeatMap) and not width and not height:
x = len(self.data[0].data)
y = len(self.data[0].data[0])
if x > y:
width = 500
height = int(500*y/x)
else:
height = 500
width = int(500 * y / x)
if isinstance(self.data[0], MollweideHeatmap):
if not xaxis:
xaxis = False
if not yaxis:
yaxis = False
self.internal = _2dFigure(self.data, title, width, height, xaxis, yaxis)
elif kind == _Plot3D:
self.internal = _3dFigure(self.data, title, width, height, xaxis, yaxis, zaxis)
elif kind:
self.internal = _MapFigure(self.data, title)
def _to_plotly(self):
return self.internal._to_plotly()
class MultiFigure(_BaseFigure):
def __init__(self, rows, cols, title=None, width=None, height=None):
_BaseFigure.__init__(self, None, title, width, height)
self.cols = cols
self.rows = rows
self.subfigures = []
def get_subfigure(self, row, col):
for fig, r, c, _, _ in self.subfigures:
if r == row and c == col:
return fig
else:
return None
def add_subfigure(self, figure, row, col, row_span=1, col_span=1):
self.subfigures.append((figure, row, col, row_span, col_span))
def _to_plotly(self):
sub_titles = tuple([a[0].title for a in self.subfigures])
sub_specs = [[None]*self.cols for _ in range(self.rows)]
for fig, r, c, rs, cs in self.subfigures:
sub_specs[r][c] = dict(colspan=cs, rowspan=rs)
if isinstance(fig.internal, _3dFigure):
sub_specs[r][c]['is_3d'] = True
multi_figure_ply = make_subplots(self.rows,self.cols, subplot_titles=sub_titles, specs=sub_specs)
for fig, r, c, _, _ in self.subfigures:
for plot in fig.data:
multi_figure_ply.append_trace(plot._to_plotly(), r+1, c+1)
multi_figure_ply['layout'].update(height=self.height, width=self.width, title=self.title)
return multi_figure_ply
@staticmethod
def from_figures_2cols(figures, title=None, width=None, height=None):
multi_figure = MultiFigure((len(figures)+1)/2, 2, title, width, height)
for i in range(0, len(figures), 2):
multi_figure.add_subfigure(figures[i], i/2, 0)
for i in range(1, len(figures), 2):
multi_figure.add_subfigure(figures[i], i/2, 1)
return multi_figure
from abc import ABCMeta, abstractmethod
from nifty.plotting.plotly_wrapper import _PlotlyWrapper
class _BaseFigure(_PlotlyWrapper):
__metaclass__ = ABCMeta
def __init__(self, data, title, width, height):
self.data = data
self.title = title
self.width = width
self.height = height
@abstractmethod
def _to_plotly(self):
ply_object = dict(
data=[plt._to_plotly() for plt in self.data],
layout=dict(
title=self.title,
scene = dict(
aspectmode='cube'
),
autosize=False,
width=self.width,
height=self.height,
)
)
return ply_object
class _2dFigure(_BaseFigure):
def __init__(self, data, title=None, width=None, height=None, xaxis=None, yaxis=None):
_BaseFigure.__init__(self, data, title, width, height)
self.xaxis = xaxis
self.yaxis = yaxis
def _to_plotly(self):
ply_object = _BaseFigure._to_plotly(self)
if self.xaxis or self.yaxis:
ply_object['layout']['scene']['aspectratio'] = dict()
if self.xaxis:
ply_object['layout']['xaxis'] = self.xaxis._to_plotly()
elif self.xaxis == False:
ply_object['layout']['xaxis'] = dict(
autorange=True,
showgrid=False,
zeroline=False,
showline=False,
autotick=True,
ticks='',
showticklabels=False
)
if self.yaxis:
ply_object['layout']['yaxis'] = self.yaxis._to_plotly()
elif self.yaxis == False:
ply_object['layout']['yaxis'] = dict(showline=False)
return ply_object
class _3dFigure(_2dFigure):
def __init__(self, data, title=None, width=None, height=None, xaxis=None, yaxis=None, zaxis=None):
_2dFigure.__init__(self, data, title, width, height, xaxis, yaxis)
self.zaxis = zaxis
def _to_plotly(self):
ply_object = _BaseFigure._to_plotly(self)
if self.xaxis or self.yaxis or self.zaxis:
ply_object['layout']['scene']['aspectratio'] = dict()
if self.xaxis:
ply_object['layout']['scene']['xaxis'] = self.xaxis._to_plotly()
elif self.xaxis == False:
ply_object['layout']['scene']['xaxis'] = dict(showline=False)
if self.yaxis:
ply_object['layout']['scene']['yaxis'] = self.yaxis._to_plotly()
elif self.yaxis == False:
ply_object['layout']['scene']['yaxis'] = dict(showline=False)
if self.zaxis:
ply_object['layout']['scene']['zaxis'] = self.zaxis._to_plotly()
elif self.zaxis == False:
ply_object['layout']['scene']['zaxis'] = dict(showline=False)
return ply_object
class _MapFigure(_BaseFigure):
def __init__(self, data, title, width=None, height=None):
_BaseFigure.__init__(self, data, title, width, height)
def _to_plotly(self):
ply_object = _BaseFigure._to_plotly(self)
# print(ply_object, ply_object['layout'])
# ply_object['layout']['geo'] = dict(
# projection=dict(type=self.data.projection),
# showcoastlines=False
# )
return ply_object
from nifty.plotting.plots.private import _Plot2D, _Plot3D
from nifty.plotting.plots import ScatterGeoMap
def validate_plots(data, except_empty=True):
if not data:
if except_empty:
raise Exception('Error: no plots given')
else:
return True
if type(data) != list:
data = [data]
if isinstance(data[0], _Plot2D):
kind = _Plot2D
elif isinstance(data[0], _Plot3D):
kind = _Plot3D
elif isinstance(data[0], ScatterGeoMap):
kind = ScatterGeoMap
else:
kind = None
if kind:
for plt in data:
if not isinstance(plt, kind):
raise Exception(
"""Error: Plots are not of the right kind!
Compatible types are:
- Scatter2D and HeatMap
- Scatter3D
- ScatterMap""")
else:
raise Exception('Error: plot type unknown')
return kind, data
from abc import ABCMeta, abstractmethod
class _PlotlyWrapper(object):
__metaclass__ = ABCMeta
@abstractmethod
def _to_plotly(self):
pass
\ No newline at end of file
from scatter import *
from heatmap import *
from geomap import *
__all__ = ['Scatter2D', 'Scatter3D', 'ScatterGeoMap', 'HeatMap', 'MollweideHeatmap']
from nifty.plotting.plots.private import _PlotBase
class ScatterGeoMap(_PlotBase):
def __init__(self, lon, lat, label='', line=None, marker=None, proj='mollweide'): # or 'mercator'
_PlotBase.__init__(self, label, line, marker)
self.lon = lon
self.lat = lat
self.projection = proj
def _to_plotly(self):
ply_object = _PlotBase._to_plotly(self)
ply_object['type'] = 'scattergeo'
ply_object['lon'] = self.lon
ply_object['lat'] = self.lat
if self.line:
ply_object['mode'] = 'lines'
return ply_object
from nifty.plotting.plots.private import _PlotBase, _Plot2D
import healpy.projaxes as PA
import healpy.pixelfunc as pixelfunc
class HeatMap(_PlotBase, _Plot2D):
def __init__(self, data, label='', line=None, marker=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
_PlotBase.__init__(self, label, line, marker)
self.data = data
self.webgl = webgl
self.smoothing = smoothing
def _to_plotly(self):
ply_object = _PlotBase._to_plotly(self)
ply_object['z'] = self.data
if self.webgl:
ply_object['type'] = 'heatmapgl'
else:
ply_object['type'] = 'heatmap'
if self.smoothing:
ply_object['zsmooth'] = self.smoothing
return ply_object
class MollweideHeatmap(HeatMap):
def __init__(self, data, label='', line=None, marker=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
HeatMap.__init__(self, _mollview(data), label, line, marker, webgl, smoothing)
def _mollview(x, xsize=800):
import pylab
x = pixelfunc.ma_to_array(x)
f = pylab.figure(None, figsize=(8.5, 5.4))
extent = (0.02, 0.05, 0.96, 0.9)
ax = PA.HpxMollweideAxes(f, extent)
img = ax.projmap(x, nest=False, xsize=xsize)
return img
from abc import ABCMeta, abstractmethod
from nifty.plotting.plotly_wrapper import _PlotlyWrapper
class _PlotBase(_PlotlyWrapper):
__metaclass__ = ABCMeta
def __init__(self, label, line, marker):
self.label = label
self.line = line
self.marker = marker
@abstractmethod
def _to_plotly(self):
ply_object = dict()
ply_object['name'] = self.label
if self.line and self.marker:
ply_object['mode'] = 'lines+markers'
ply_object['line'] = self.line._to_plotly()
ply_object['marker'] = self.marker._to_plotly()
elif self.line:
ply_object['mode'] = 'markers'
ply_object['line'] = self.line._to_plotly()
elif self.marker:
ply_object['mode'] = 'line'
ply_object['marker'] = self.marker._to_plotly()
return ply_object
class _Scatter2DBase(_PlotBase):
__metaclass__ = ABCMeta
def __init__(self, x, y, label, line, marker):
_PlotBase.__init__(self, label, line, marker)
self.x = x
self.y = y
@abstractmethod
def _to_plotly(self):
ply_object = _PlotBase._to_plotly(self)
ply_object['x'] = self.x
ply_object['y'] = self.y
return ply_object
class _Plot2D:
pass # only used as a labeling system for the plots represntation
class _Plot3D:
pass # only used as a labeling system for the plots represntation
from nifty.plotting.plots.private import _Scatter2DBase
class Scatter2D(_Scatter2DBase):
def __init__(self, x=None, y=None, x_start=0, x_step=1,
label='', line=None, marker=None, webgl=True):
if y is None:
raise Exception('Error: no y data to plot')
if x is None:
x = range(x_start, len(y) * x_step, x_step)
_Scatter2DBase.__init__(self, x, y, label, line, marker)
self.webgl = webgl
def _to_plotly(self):
ply_object = _Scatter2DBase._to_plotly(self)
if self.webgl:
ply_object['type'] = 'scattergl'
else:
ply_object['type'] = 'scatter'
return ply_object
class Scatter3D(_Scatter2DBase):
def __init__(self, x, y, z, label='', line=None, marker=None):
_Scatter2DBase.__init__(self, x, y, label, line, marker)
self.z = z
def _to_plotly(self):
ply_object = _Scatter2DBase._to_plotly(self)
ply_object['z'] = self.z
ply_object['type'] = 'scatter3d'
return ply_object
from abc import ABCMeta, abstractmethod
class Plottable:
__metaclass__ = ABCMeta
@abstractmethod
def plot(self):
pass
import os
from PIL import Image
import plotly.offline as ply_offline
import plotly.plotly as ply
def plot(figure, filename=None, interactive=False):
if not filename:
filename = os.path.abspath('/tmp/temp-plot.html')
if interactive:
try:
__IPYTHON__
ply_offline.init_notebook_mode(connected=True)
ply_offline.iplot(figure._to_plotly(), filename=filename)
except NameError:
ply_offline.plot(figure._to_plotly(), filename=filename)