Commit d8c9ccc5 authored by Theo Steininger's avatar Theo Steininger

Merge branch 'spherical_plots' into 'master'

Spherical plots

See merge request !128
parents d541b2ca c5494e13
Pipeline #12818 passed with stages
in 11 minutes and 35 seconds
......@@ -29,9 +29,7 @@ __all__ = ['dependency_injector', 'nifty_configuration']
dependency_injector = keepers.DependencyInjector(
[('mpi4py.MPI', 'MPI'),
'pyHealpix',
'plotly',
'pylab',
'healpy'])
'plotly'])
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
......
# -*- coding: utf-8 -*-
from figure_from_plot import FigureFromPlot
from nifty.plotting.plots import Heatmap, Mollweide
from nifty.plotting.plots import Heatmap, HPMollweide, GLMollweide
class Figure2D(FigureFromPlot):
def __init__(self, plots, title=None, width=None, height=None,
xaxis=None, yaxis=None):
super(Figure2D, self).__init__(plots, title, width, height)
# TODO: add sanitization of plots input
if isinstance(plots[0], Heatmap) and not width and not height:
(x, y) = plots[0].data.shape
if plots is not None:
if isinstance(plots[0], Heatmap) and width is None and \
height is None:
(x, y) = plots[0].data.shape
if x > y:
width = 500
height = int(500*y/x)
else:
height = 500
width = int(500 * y / x)
if isinstance(plots[0], GLMollweide) or \
isinstance(plots[0], HPMollweide):
xaxis = False if (xaxis is None) else xaxis
yaxis = False if (yaxis is None) else yaxis
if x > y:
width = 500
height = int(500*y/x)
else:
height = 500
width = int(500 * y / x)
if isinstance(plots[0], Mollweide):
if not xaxis:
xaxis = False
if not yaxis:
yaxis = False
width = None
height = None
super(Figure2D, self).__init__(plots, title, width, height)
self.xaxis = xaxis
self.yaxis = yaxis
def at(self, plots):
return Figure2D(plots=plots,
title=self.title,
width=self.width,
height=self.height,
xaxis=self.xaxis,
yaxis=self.yaxis)
def to_plotly(self):
plotly_object = super(Figure2D, self).to_plotly()
if self.xaxis or self.yaxis:
plotly_object['layout']['scene']['aspectratio'] = {}
if self.xaxis:
......
......@@ -10,6 +10,15 @@ class Figure3D(FigureFromPlot):
self.yaxis = yaxis
self.zaxis = zaxis
def at(self, plots):
return Figure3D(plots=plots,
title=self.title,
width=self.width,
height=self.height,
xaxis=self.xaxis,
yaxis=self.yaxis,
zaxis=self.zaxis)
def to_plotly(self):
plotly_object = super(Figure3D, self).to_plotly()
if self.xaxis or self.yaxis or self.zaxis:
......
# -*- coding: utf-8 -*-
from abc import abstractmethod
import abc
from nifty.plotting.plotly_wrapper import PlotlyWrapper
......@@ -11,6 +11,10 @@ class FigureBase(PlotlyWrapper):
self.width = width
self.height = height
@abstractmethod
@abc.abstractmethod
def at(self):
raise NotImplementedError
@abc.abstractmethod
def to_plotly(self):
raise NotImplementedError
......@@ -10,13 +10,21 @@ plotly = gdi.get('plotly')
# TODO: add nice height and width defaults for multifigure
class MultiFigure(FigureBase):
def __init__(self, rows, columns, title=None, width=None, height=None,
subfigures=None):
def __init__(self, subfigures, title=None, width=None, height=None):
if 'plotly' not in gdi:
raise ImportError("The module plotly is needed but not available.")
super(MultiFigure, self).__init__(title, width, height)
self.subfigures = np.empty((rows, columns), dtype=np.object)
self.subfigures[:] = subfigures
if subfigures is not None:
self.subfigures = np.asarray(subfigures, dtype=np.object)
if len(self.subfigures.shape) != 2:
raise ValueError("Subfigures must be a two-dimensional array.")
def at(self, subfigures):
return MultiFigure(subfigures=subfigures,
title=self.title,
width=self.width,
height=self.height)
@property
def rows(self):
......@@ -26,15 +34,15 @@ class MultiFigure(FigureBase):
def columns(self):
return self.subfigures.shape[1]
def add_subfigure(self, figure, row, column):
self.subfigures[row, column] = figure
def to_plotly(self):
title_extractor = lambda z: z.title if z else ""
sub_titles = tuple(np.vectorize(title_extractor)(self.subfigures.flatten()))
sub_titles = tuple(np.vectorize(title_extractor)(
self.subfigures.flatten()))
specs_setter = lambda z: {'is_3d': True} if isinstance(z, Figure3D) else {}
sub_specs = list(map(list, np.vectorize(specs_setter)(self.subfigures)))
specs_setter = lambda z: ({'is_3d': True}
if isinstance(z, Figure3D) else {})
sub_specs = list(map(list, np.vectorize(specs_setter)(
self.subfigures)))
multi_figure_plotly_object = plotly.tools.make_subplots(
self.rows,
......@@ -46,7 +54,7 @@ class MultiFigure(FigureBase):
width=self.width,
title=self.title)
#TODO resolve bug with titles and 3D subplots
# TODO resolve bug with titles and 3D subplots
i = 1
for index, fig in np.ndenumerate(self.subfigures):
......
# -*- coding: utf-8 -*-
from mollweide import Mollweide
from hpmollweide import HPMollweide
from glmollweide import GLMollweide
from heatmap import Heatmap
# -*- coding: utf-8 -*-
from nifty import dependency_injector as gdi
from heatmap import Heatmap
import numpy as np
from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix')
class GLMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None,
webgl=False, smoothing=False):
# smoothing 'best', 'fast', False
if 'pyHealpix' not in gdi:
raise ImportError(
"The module pyHealpix is needed but not available.")
self.xsize = xsize
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing)
def at(self, data):
if isinstance(data, list):
data = [self._mollview(d) for d in data]
else:
data = self._mollview(data)
return GLMollweide(data=data,
xsize=self.xsize,
color_map=self.color_map,
webgl=self.webgl,
smoothing=self.smoothing)
@staticmethod
def _find_closest(A, target):
# A must be sorted
idx = A.searchsorted(target)
idx = np.clip(idx, 1, len(A)-1)
left = A[idx-1]
right = A[idx]
idx -= target - left < right - target
return idx
def _mollview(self, x):
xsize = self.xsize
nlat = x.shape[0]
nlon = x.shape[1]
res, mask, theta, phi = mollweide_helper(xsize)
ra = np.linspace(0, 2*np.pi, nlon+1)
dec = pyHealpix.GL_thetas(nlat)
ilat = self._find_closest(dec, theta)
ilon = self._find_closest(ra, phi)
ilon = np.where(ilon == nlon, 0, ilon)
res[mask] = x[ilat, ilon]
return res
......@@ -6,14 +6,8 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Heatmap(PlotlyWrapper):
def __init__(self, data, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
if isinstance(data, list):
self.data = np.zeros((data[0].shape))
for arr in data:
self.data = np.add(self.data, arr)
else:
self.data = data
def __init__(self, data, color_map=None, webgl=False, smoothing=False):
# smoothing 'best', 'fast', False
if color_map is not None:
if not isinstance(color_map, Colormap):
......@@ -22,6 +16,19 @@ class Heatmap(PlotlyWrapper):
self.color_map = color_map
self.webgl = webgl
self.smoothing = smoothing
self.data = data
def at(self, data):
if isinstance(data, list):
temp_data = np.zeros((data[0].shape))
for arr in data:
temp_data = np.add(temp_data, arr)
else:
temp_data = data
return Heatmap(data=temp_data,
color_map=self.color_map,
webgl=self.webgl,
smoothing=self.smoothing)
@property
def figure_dimension(self):
......@@ -29,7 +36,9 @@ class Heatmap(PlotlyWrapper):
def to_plotly(self):
plotly_object = dict()
plotly_object['z'] = self.data
plotly_object['showscale'] = False
if self.color_map:
plotly_object['colorscale'] = self.color_map.to_plotly()
......
# -*- coding: utf-8 -*-
from nifty import dependency_injector as gdi
from heatmap import Heatmap
import numpy as np
from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix')
class HPMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
if 'pyHealpix' not in gdi:
raise ImportError(
"The module pyHealpix is needed but not available.")
self.xsize = xsize
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing)
def at(self, data):
if isinstance(data, list):
data = [self._mollview(d) for d in data]
else:
data = self._mollview(data)
return HPMollweide(data=data,
xsize=self.xsize,
color_map=self.color_map,
webgl=self.webgl,
smoothing=self.smoothing)
def _mollview(self, x):
xsize = self.xsize
res, mask, theta, phi = mollweide_helper(xsize)
ptg = np.empty((phi.size, 2), dtype=np.float64)
ptg[:, 0] = theta
ptg[:, 1] = phi
base = pyHealpix.Healpix_Base(int(np.sqrt(x.size/12)), "RING")
res[mask] = x[base.ang2pix(ptg)]
return res
# -*- coding: utf-8 -*-
from nifty import dependency_injector as gdi
from heatmap import Heatmap
pylab = gdi.get('pylab')
healpy = gdi.get('healpy')
class Mollweide(Heatmap):
def __init__(self, data, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
if 'pylab' not in gdi:
raise ImportError("The module pylab is needed but not available.")
if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available.")
if isinstance(data, list):
data = [self._mollview(d) for d in data]
else:
data = self._mollview(data)
super(Mollweide, self).__init__(data, color_map, webgl, smoothing)
def _mollview(self, x, xsize=800):
x = healpy.pixelfunc.ma_to_array(x)
f = pylab.figure(None, figsize=(8.5, 5.4))
extent = (0.02, 0.05, 0.96, 0.9)
ax = healpy.projaxes.HpxMollweideAxes(f, extent)
img = ax.projmap(x, nest=False, xsize=xsize)
return img
# -*- coding: utf-8 -*-
import numpy as np
def mollweide_helper(xsize):
xsize = int(xsize)
ysize = int(xsize/2)
res = np.full(shape=(ysize, xsize), fill_value=np.nan,
dtype=np.float64)
xc = (xsize-1)*0.5
yc = (ysize-1)*0.5
u, v = np.meshgrid(np.arange(xsize), np.arange(ysize))
u = 2*(u-xc)/(xc/1.02)
v = (v-yc)/(yc/1.02)
mask = np.where((u*u*0.25 + v*v) <= 1.)
t1 = v[mask]
theta = 0.5*np.pi-(
np.arcsin(2/np.pi*(np.arcsin(t1) + t1*np.sqrt((1.-t1)*(1+t1)))))
phi = -0.5*np.pi*u[mask]/np.maximum(np.sqrt((1-t1)*(1+t1)), 1e-6)
phi = np.where(phi < 0, phi+2*np.pi, phi)
return res, mask, theta, phi
......@@ -4,16 +4,14 @@ from scatter_plot import ScatterPlot
class Cartesian(ScatterPlot):
def __init__(self, x, y, label, line, marker, showlegend=True):
super(Cartesian, self).__init__(label, line, marker)
self.x = x
self.y = y
def __init__(self, data, label, line, marker, showlegend=True):
super(Cartesian, self).__init__(data, label, line, marker)
self.showlegend = showlegend
@abstractmethod
def to_plotly(self):
plotly_object = super(Cartesian, self).to_plotly()
plotly_object['x'] = self.x
plotly_object['y'] = self.y
plotly_object['x'] = self.data[0]
plotly_object['y'] = self.data[1]
plotly_object['showlegend'] = self.showlegend
return plotly_object
......@@ -4,17 +4,20 @@ from cartesian import Cartesian
class Cartesian2D(Cartesian):
def __init__(self, x=None, y=None, x_start=0, x_step=1,
label='', line=None, marker=None, showlegend=True,
def __init__(self, data, label='', line=None, marker=None, showlegend=True,
webgl=True):
if y is None:
raise Exception('Error: no y data to plot')
if x is None:
x = range(x_start, len(y) * x_step, x_step)
super(Cartesian2D, self).__init__(x, y, label, line, marker,
super(Cartesian2D, self).__init__(data, label, line, marker,
showlegend)
self.webgl = webgl
def at(self, data):
return Cartesian2D(data=data,
label=self.label,
line=self.line,
marker=self.marker,
showlegend=self.showlegend,
webgl=self.webgl)
@property
def figure_dimension(self):
return 2
......
......@@ -4,11 +4,17 @@ from cartesian import Cartesian
class Cartesian3D(Cartesian):
def __init__(self, x, y, z, label='', line=None, marker=None,
def __init__(self, data, label='', line=None, marker=None,
showlegend=True):
super(Cartesian3D, self).__init__(x, y, label, line, marker,
super(Cartesian3D, self).__init__(data, label, line, marker,
showlegend)
self.z = z
def at(self, data):
return Cartesian3D(data=data,
label=self.label,
line=self.line,
marker=self.marker,
showlegend=self.showlegend)
@property
def figure_dimension(self):
......@@ -16,6 +22,6 @@ class Cartesian3D(Cartesian):
def to_plotly(self):
plotly_object = super(Cartesian3D, self).to_plotly()
plotly_object['z'] = self.z
plotly_object['z'] = self.data[2]
plotly_object['type'] = 'scatter3d'
return plotly_object
......@@ -3,26 +3,31 @@ from scatter_plot import ScatterPlot
class Geo(ScatterPlot):
def __init__(self, lon, lat, label='', line=None, marker=None,
proj='mollweide'):
def __init__(self, data, label='', line=None, marker=None,
projection='mollweide'):
"""
proj: mollweide or mercator
"""
super.__init__(label, line, marker)
self.lon = lon
self.lat = lat
self.projection = proj
super(Geo, self).__init__(label, line, marker)
self.projection = projection
def at(self, data):
return Geo(data=data,
label=self.label,
line=self.line,
marker=self.marker,
projection=self.projection)
@property
def figure_dimension(self):
return 2
def _to_plotly(self):
def _to_plotly(self, data):
plotly_object = super(Geo, self).to_plotly()
plotly_object['type'] = 'scattergeo'
plotly_object['lon'] = self.lon
plotly_object['lat'] = self.lat
plotly_object['type'] = self.projection
plotly_object['lon'] = data[0]
plotly_object['lat'] = data[1]
if self.line:
plotly_object['mode'] = 'lines'
return plotly_object
......@@ -7,7 +7,8 @@ from nifty.plotting.descriptors import Marker,\
class ScatterPlot(PlotlyWrapper):
def __init__(self, label, line, marker):
def __init__(self, data, label, line, marker):
self.data = data
self.label = label
self.line = line
self.marker = marker
......@@ -15,6 +16,10 @@ class ScatterPlot(PlotlyWrapper):
self.marker = Marker()
self.line = Line()
@abc.abstractmethod
def at(self, data):
raise NotImplementedError
@abc.abstractproperty
def figure_dimension(self):
raise NotImplementedError
......
from healpix_plotter import HealpixPlotter
from gl_plotter import GLPlotter
from power_plotter import PowerPlotter
from rg2d_plotter import RG2DPlotter
import numpy as np
from nifty.spaces import GLSpace
from nifty.plotting.figures import Figure2D
from nifty.plotting.plots import GLMollweide
from .plotter_base import PlotterBase
class GLPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None):
self.color_map = color_map
super(GLPlotter, self).__init__(interactive, path, title)
@property
def domain_classes(self):
return (GLSpace, )
def _initialize_plot(self):
result_plot = GLMollweide(data=None,
color_map=self.color_map)
return result_plot
def _initialize_figure(self):
return Figure2D(plots=None)
def _parse_data(self, data, field, spaces):
gl_space = field.domain[spaces[0]]
data = np.reshape(data, (gl_space.nlat, gl_space.nlon))
return data
from nifty.spaces import HPSpace
from nifty.plotting.figures import Figure2D
from nifty.plotting.plots import Mollweide
from .plotter import Plotter
from nifty.plotting.plots import HPMollweide
from .plotter_base import PlotterBase
class HealpixPlotter(Plotter):
class HealpixPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None):
super(HealpixPlotter, self).__init__(interactive, path, title)
self.color_map = color_map
super(HealpixPlotter, self).__init__(interactive, path, title)
@property
def domain_classes(self):
return (HPSpace, )
def _create_individual_figure(self, plots):
return Figure2D(plots)
def _create_individual_plot(self, data, plot_domain):
result_plot = Mollweide(data=data,
color_map=self.color_map)
def _initialize_plot(self):
result_plot = HPMollweide(data=None,
color_map=self.color_map)
return result_plot
def _initialize_figure(self):
return Figure2D(plots=None)
......@@ -27,7 +27,7 @@ rank = d2o.config.dependency_injector[
d2o.configuration['mpi_module']].COMM_WORLD.rank
class Plotter(Loggable, object):
class PlotterBase(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', title=""):
......@@ -37,6 +37,10 @@ class Plotter(Loggable, object):
self.path = path
self.title = str(title)
self.plot = self._initialize_plot()
self.figure = self._initialize_figure()
self.multi_figure = self._initialize_multifigure()
@abc.abstractproperty
def domain_classes(self):
return (Space,)
......@@ -57,7 +61,7 @@ class Plotter(Loggable, object):
def path(self, new_path):
self._path = os.path.normpath(new_path)
def plot(self, fields, spaces=None, data_extractor=None, labels=None):
def __call__(self, fields, spaces=None, data_extractor=None, labels=None):
if isinstance(fields, Field):
fields = [fields]
elif not isinstance(fields, list):
......@@ -82,12 +86,12 @@ class Plotter(Loggable, object):
plots_list = []
for slice_list in utilities.get_slice_list(data_list[0].shape, axes):
plots_list += \
[[self._create_individual_plot(current_data[slice_list],
plot_domain)
for current_data in data_list]]
[[self.plot.at(self._parse_data(current_data,
field,
spaces))
for (current_data, field) in zip(data_list, fields)]]
figures = [self._create_individual_figure(plots)
for plots in plots_list]
figures = [self.figure.at(plots) for plots in plots_list]
self._finalize_figure(figures)
......@@ -103,13 +107,16 @@ class Plotter(Loggable, object):
return data
@abc.abstractmethod
def _create_individual_figure(self, plots):
def _initialize_plot(self):
raise NotImplementedError
@abc.abstractmethod
def _create_individual_plot(self