Commit 4f0c5156 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'master' into line_search

parents fe2c9d98 b8bd4934
Pipeline #15417 passed with stage
in 7 minutes and 18 seconds
...@@ -612,39 +612,47 @@ class Field(Loggable, Versionable, object): ...@@ -612,39 +612,47 @@ class Field(Loggable, Versionable, object):
# correct variance # correct variance
if preserve_gaussian_variance: if preserve_gaussian_variance:
assert issubclass(val.dtype.type, np.complexfloating),\
"complex input field is needed here"
h *= np.sqrt(2) h *= np.sqrt(2)
a *= np.sqrt(2) a *= np.sqrt(2)
if not issubclass(val.dtype.type, np.complexfloating): # The code below should not be needed in practice, since it would
# in principle one must not correct the variance for the fixed # only ever be called when hermitianizing a purely real field.
# points of the hermitianization. However, for a complex field # However it might be of educational use and keep us from forgetting
# the input field loses half of its power at its fixed points # how these things are done ...
# in the `hermitian` part. Hence, here a factor of sqrt(2) is
# also necessary! # if not issubclass(val.dtype.type, np.complexfloating):
# => The hermitianization can be done on a space level since # # in principle one must not correct the variance for the fixed
# either nothing must be done (LMSpace) or ALL points need a # # points of the hermitianization. However, for a complex field
# factor of sqrt(2) # # the input field loses half of its power at its fixed points
# => use the preserve_gaussian_variance flag in the # # in the `hermitian` part. Hence, here a factor of sqrt(2) is
# hermitian_decomposition method above. # # also necessary!
# # => The hermitianization can be done on a space level since
# This code is for educational purposes: # # either nothing must be done (LMSpace) or ALL points need a
fixed_points = [domain[i].hermitian_fixed_points() # # factor of sqrt(2)
for i in spaces] # # => use the preserve_gaussian_variance flag in the
fixed_points = [[fp] if fp is None else fp # # hermitian_decomposition method above.
for fp in fixed_points] #
# # This code is for educational purposes:
for product_point in itertools.product(*fixed_points): # fixed_points = [domain[i].hermitian_fixed_points()
slice_object = np.array((slice(None), )*len(val.shape), # for i in spaces]
dtype=np.object) # fixed_points = [[fp] if fp is None else fp
for i, sp in enumerate(spaces): # for fp in fixed_points]
point_component = product_point[i] #
if point_component is None: # for product_point in itertools.product(*fixed_points):
point_component = slice(None) # slice_object = np.array((slice(None), )*len(val.shape),
slice_object[list(domain_axes[sp])] = point_component # dtype=np.object)
# for i, sp in enumerate(spaces):
slice_object = tuple(slice_object) # point_component = product_point[i]
h[slice_object] /= np.sqrt(2) # if point_component is None:
a[slice_object] /= np.sqrt(2) # point_component = slice(None)
# slice_object[list(domain_axes[sp])] = point_component
#
# slice_object = tuple(slice_object)
# h[slice_object] /= np.sqrt(2)
# a[slice_object] /= np.sqrt(2)
return (h, a) return (h, a)
def _spec_to_rescaler(self, spec, result_list, power_space_index): def _spec_to_rescaler(self, spec, result_list, power_space_index):
......
...@@ -5,12 +5,14 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper ...@@ -5,12 +5,14 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Axis(PlotlyWrapper): class Axis(PlotlyWrapper):
def __init__(self, text=None, font='', color='', log=False, def __init__(self, text=None, font='', color='', log=False,
show_grid=True): font_size=18, show_grid=True, visible=True):
self.text = text self.text = text
self.font = font self.font = font
self.color = color self.color = color
self.log = log self.log = log
self.font_size = int(font_size)
self.show_grid = show_grid self.show_grid = show_grid
self.visible = visible
def to_plotly(self): def to_plotly(self):
ply_object = dict() ply_object = dict()
...@@ -19,11 +21,14 @@ class Axis(PlotlyWrapper): ...@@ -19,11 +21,14 @@ class Axis(PlotlyWrapper):
title=self.text, title=self.text,
titlefont=dict( titlefont=dict(
family=self.font, family=self.font,
color=self.color color=self.color,
size=self.font_size
) )
)) ))
if self.log: if self.log:
ply_object['type'] = 'log' ply_object['type'] = 'log'
if not self.show_grid: if not self.show_grid:
ply_object['showgrid'] = False ply_object['showgrid'] = False
ply_object['visible'] = self.visible
ply_object['tickfont'] = {'size': self.font_size}
return ply_object return ply_object
...@@ -7,8 +7,13 @@ from nifty.plotting.plots import Heatmap, HPMollweide, GLMollweide ...@@ -7,8 +7,13 @@ from nifty.plotting.plots import Heatmap, HPMollweide, GLMollweide
class Figure2D(FigureFromPlot): class Figure2D(FigureFromPlot):
def __init__(self, plots, title=None, width=None, height=None, def __init__(self, plots, title=None, width=None, height=None,
xaxis=None, yaxis=None): xaxis=None, yaxis=None):
if plots is not None: if plots is not None:
width = width if width is not None else plots[0].default_width()
height = \
height if height is not None else plots[0].default_height()
(xaxis, yaxis) = \
xaxis if xaxis is not None else plots[0].default_axes()
if isinstance(plots[0], Heatmap) and width is None and \ if isinstance(plots[0], Heatmap) and width is None and \
height is None: height is None:
(x, y) = plots[0].data.shape (x, y) = plots[0].data.shape
...@@ -29,9 +34,10 @@ class Figure2D(FigureFromPlot): ...@@ -29,9 +34,10 @@ class Figure2D(FigureFromPlot):
self.xaxis = xaxis self.xaxis = xaxis
self.yaxis = yaxis self.yaxis = yaxis
def at(self, plots): def at(self, plots, title=None):
title = title if title is not None else self.title
return Figure2D(plots=plots, return Figure2D(plots=plots,
title=self.title, title=title,
width=self.width, width=self.width,
height=self.height, height=self.height,
xaxis=self.xaxis, xaxis=self.xaxis,
......
...@@ -5,14 +5,21 @@ from figure_from_plot import FigureFromPlot ...@@ -5,14 +5,21 @@ from figure_from_plot import FigureFromPlot
class Figure3D(FigureFromPlot): class Figure3D(FigureFromPlot):
def __init__(self, plots, title=None, width=None, height=None, def __init__(self, plots, title=None, width=None, height=None,
xaxis=None, yaxis=None, zaxis=None): xaxis=None, yaxis=None, zaxis=None):
if plots is not None:
width = width if width is not None else plots[0].default_width()
height = \
height if height is not None else plots[0].default_height()
(xaxis, yaxis, zaxis) = \
xaxis if xaxis is not None else plots[0].default_axes()
super(Figure3D, self).__init__(plots, title, width, height) super(Figure3D, self).__init__(plots, title, width, height)
self.xaxis = xaxis self.xaxis = xaxis
self.yaxis = yaxis self.yaxis = yaxis
self.zaxis = zaxis self.zaxis = zaxis
def at(self, plots): def at(self, plots, title=None):
title = title if title is not None else self.title
return Figure3D(plots=plots, return Figure3D(plots=plots,
title=self.title, title=title,
width=self.width, width=self.width,
height=self.height, height=self.height,
xaxis=self.xaxis, xaxis=self.xaxis,
......
...@@ -12,7 +12,7 @@ class FigureBase(PlotlyWrapper): ...@@ -12,7 +12,7 @@ class FigureBase(PlotlyWrapper):
self.height = height self.height = height
@abc.abstractmethod @abc.abstractmethod
def at(self): def at(self, title=None):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
......
...@@ -4,6 +4,8 @@ from nifty import dependency_injector as gdi ...@@ -4,6 +4,8 @@ from nifty import dependency_injector as gdi
from heatmap import Heatmap from heatmap import Heatmap
import numpy as np import numpy as np
from nifty.plotting.descriptors import Axis
from .mollweide_helper import mollweide_helper from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix') pyHealpix = gdi.get('pyHealpix')
...@@ -11,14 +13,15 @@ pyHealpix = gdi.get('pyHealpix') ...@@ -11,14 +13,15 @@ pyHealpix = gdi.get('pyHealpix')
class GLMollweide(Heatmap): class GLMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None, def __init__(self, data, xsize=800, color_map=None,
webgl=False, smoothing=False): webgl=False, smoothing=False, zmin=None, zmax=None):
# smoothing 'best', 'fast', False # smoothing 'best', 'fast', False
if pyHealpix is None: if pyHealpix is None:
raise ImportError( raise ImportError(
"The module pyHealpix is needed but not available.") "The module pyHealpix is needed but not available.")
self.xsize = xsize self.xsize = xsize
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing) super(GLMollweide, self).__init__(data, color_map, webgl, smoothing,
zmin, zmax)
def at(self, data): def at(self, data):
if isinstance(data, list): if isinstance(data, list):
...@@ -55,3 +58,12 @@ class GLMollweide(Heatmap): ...@@ -55,3 +58,12 @@ class GLMollweide(Heatmap):
ilon = np.where(ilon == nlon, 0, ilon) ilon = np.where(ilon == nlon, 0, ilon)
res[mask] = x[ilat, ilon] res[mask] = x[ilat, ilon]
return res return res
def default_width(self):
return 1400
def default_height(self):
return 700
def default_axes(self):
return (Axis(visible=False), Axis(visible=False))
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np import numpy as np
from nifty.plotting.descriptors import Axis
from nifty.plotting.colormap import Colormap from nifty.plotting.colormap import Colormap
from nifty.plotting.plotly_wrapper import PlotlyWrapper from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Heatmap(PlotlyWrapper): class Heatmap(PlotlyWrapper):
def __init__(self, data, color_map=None, webgl=False, smoothing=False): def __init__(self, data, color_map=None, webgl=False, smoothing=False,
zmin=None, zmax=None):
# smoothing 'best', 'fast', False # smoothing 'best', 'fast', False
if color_map is not None: if color_map is not None:
...@@ -17,6 +20,9 @@ class Heatmap(PlotlyWrapper): ...@@ -17,6 +20,9 @@ class Heatmap(PlotlyWrapper):
self.webgl = webgl self.webgl = webgl
self.smoothing = smoothing self.smoothing = smoothing
self.data = data self.data = data
self.zmin = zmin
self.zmax = zmax
self._font_size = 18
def at(self, data): def at(self, data):
if isinstance(data, list): if isinstance(data, list):
...@@ -28,7 +34,9 @@ class Heatmap(PlotlyWrapper): ...@@ -28,7 +34,9 @@ class Heatmap(PlotlyWrapper):
return Heatmap(data=temp_data, return Heatmap(data=temp_data,
color_map=self.color_map, color_map=self.color_map,
webgl=self.webgl, webgl=self.webgl,
smoothing=self.smoothing) smoothing=self.smoothing,
zmin=self.zmin,
zmax=self.zmax)
@property @property
def figure_dimension(self): def figure_dimension(self):
...@@ -38,11 +46,13 @@ class Heatmap(PlotlyWrapper): ...@@ -38,11 +46,13 @@ class Heatmap(PlotlyWrapper):
plotly_object = dict() plotly_object = dict()
plotly_object['z'] = self.data plotly_object['z'] = self.data
plotly_object['zmin'] = self.zmin
plotly_object['zmax'] = self.zmax
plotly_object['showscale'] = False plotly_object['showscale'] = True
plotly_object['colorbar'] = {'tickfont': {'size': self._font_size}}
if self.color_map: if self.color_map:
plotly_object['colorscale'] = self.color_map.to_plotly() plotly_object['colorscale'] = self.color_map.to_plotly()
plotly_object['colorbar'] = dict(title=self.color_map.name, x=0.42)
if self.webgl: if self.webgl:
plotly_object['type'] = 'heatmapgl' plotly_object['type'] = 'heatmapgl'
else: else:
...@@ -50,3 +60,14 @@ class Heatmap(PlotlyWrapper): ...@@ -50,3 +60,14 @@ class Heatmap(PlotlyWrapper):
if self.smoothing: if self.smoothing:
plotly_object['zsmooth'] = self.smoothing plotly_object['zsmooth'] = self.smoothing
return plotly_object return plotly_object
def default_width(self):
return 700
def default_height(self):
(x, y) = self.data.shape
return int(700 * y / x)
def default_axes(self):
return (Axis(font_size=self._font_size),
Axis(font_size=self._font_size))
...@@ -4,6 +4,8 @@ from nifty import dependency_injector as gdi ...@@ -4,6 +4,8 @@ from nifty import dependency_injector as gdi
from heatmap import Heatmap from heatmap import Heatmap
import numpy as np import numpy as np
from nifty.plotting.descriptors import Axis
from .mollweide_helper import mollweide_helper from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix') pyHealpix = gdi.get('pyHealpix')
...@@ -11,12 +13,13 @@ pyHealpix = gdi.get('pyHealpix') ...@@ -11,12 +13,13 @@ pyHealpix = gdi.get('pyHealpix')
class HPMollweide(Heatmap): class HPMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None, webgl=False, def __init__(self, data, xsize=800, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False smoothing=False, zmin=None, zmax=None): # smoothing 'best', 'fast', False
if pyHealpix is None: if pyHealpix is None:
raise ImportError( raise ImportError(
"The module pyHealpix is needed but not available.") "The module pyHealpix is needed but not available.")
self.xsize = xsize self.xsize = xsize
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing) super(HPMollweide, self).__init__(data, color_map, webgl, smoothing,
zmin, zmax)
def at(self, data): def at(self, data):
if isinstance(data, list): if isinstance(data, list):
...@@ -39,3 +42,12 @@ class HPMollweide(Heatmap): ...@@ -39,3 +42,12 @@ class HPMollweide(Heatmap):
base = pyHealpix.Healpix_Base(int(np.sqrt(x.size/12)), "RING") base = pyHealpix.Healpix_Base(int(np.sqrt(x.size/12)), "RING")
res[mask] = x[base.ang2pix(ptg)] res[mask] = x[base.ang2pix(ptg)]
return res return res
def default_width(self):
return 1400
def default_height(self):
return 700
def default_axes(self):
return (Axis(visible=False), Axis(visible=False))
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from nifty.plotting.descriptors import Axis
from cartesian import Cartesian from cartesian import Cartesian
...@@ -30,3 +31,6 @@ class Cartesian2D(Cartesian): ...@@ -30,3 +31,6 @@ class Cartesian2D(Cartesian):
plotly_object['type'] = 'scatter' plotly_object['type'] = 'scatter'
return plotly_object return plotly_object
def default_axes(self):
return (Axis(), Axis())
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from nifty.plotting.descriptors import Axis
from cartesian import Cartesian from cartesian import Cartesian
...@@ -25,3 +26,6 @@ class Cartesian3D(Cartesian): ...@@ -25,3 +26,6 @@ class Cartesian3D(Cartesian):
plotly_object['z'] = self.data[2] plotly_object['z'] = self.data[2]
plotly_object['type'] = 'scatter3d' plotly_object['type'] = 'scatter3d'
return plotly_object return plotly_object
def default_axes(self):
return (Axis(), Axis(), Axis())
from nifty.plotting.descriptors import Axis
from scatter_plot import ScatterPlot from scatter_plot import ScatterPlot
...@@ -31,3 +32,6 @@ class Geo(ScatterPlot): ...@@ -31,3 +32,6 @@ class Geo(ScatterPlot):
if self.line: if self.line:
plotly_object['mode'] = 'lines' plotly_object['mode'] = 'lines'
return plotly_object return plotly_object
def default_axes(self):
return (Axis(), Axis())
...@@ -40,3 +40,13 @@ class ScatterPlot(PlotlyWrapper): ...@@ -40,3 +40,13 @@ class ScatterPlot(PlotlyWrapper):
ply_object['marker'] = self.marker.to_plotly() ply_object['marker'] = self.marker.to_plotly()
return ply_object return ply_object
def default_width(self):
return 700
def default_height(self):
return 700
@abc.abstractmethod
def default_axes(self):
raise NotImplementedError
...@@ -9,9 +9,9 @@ from .plotter_base import PlotterBase ...@@ -9,9 +9,9 @@ from .plotter_base import PlotterBase
class GLPlotter(PlotterBase): class GLPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None): def __init__(self, interactive=False, path='plot.html', color_map=None):
self.color_map = color_map self.color_map = color_map
super(GLPlotter, self).__init__(interactive, path, title) super(GLPlotter, self).__init__(interactive, path)
@property @property
def domain_classes(self): def domain_classes(self):
......
...@@ -6,9 +6,9 @@ from .plotter_base import PlotterBase ...@@ -6,9 +6,9 @@ from .plotter_base import PlotterBase
class HealpixPlotter(PlotterBase): class HealpixPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None): def __init__(self, interactive=False, path='plot.html', color_map=None):
self.color_map = color_map self.color_map = color_map
super(HealpixPlotter, self).__init__(interactive, path, title) super(HealpixPlotter, self).__init__(interactive, path)
@property @property
def domain_classes(self): def domain_classes(self):
......
...@@ -30,12 +30,11 @@ rank = d2o.config.dependency_injector[ ...@@ -30,12 +30,11 @@ rank = d2o.config.dependency_injector[
class PlotterBase(Loggable, object): class PlotterBase(Loggable, object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', title=""): def __init__(self, interactive=False, path='plot.html'):
if plotly is None: if plotly is None:
raise ImportError("The module plotly is needed but not available.") raise ImportError("The module plotly is needed but not available.")
self.interactive = interactive self.interactive = interactive
self.path = path self.path = path
self.title = str(title)
self.plot = self._initialize_plot() self.plot = self._initialize_plot()
self.figure = self._initialize_figure() self.figure = self._initialize_figure()
...@@ -61,7 +60,8 @@ class PlotterBase(Loggable, object): ...@@ -61,7 +60,8 @@ class PlotterBase(Loggable, object):
def path(self, new_path): def path(self, new_path):
self._path = os.path.normpath(new_path) self._path = os.path.normpath(new_path)
def __call__(self, fields, spaces=None, data_extractor=None, labels=None): def __call__(self, fields, spaces=None, data_extractor=None, labels=None,
path=None, title=None):
if isinstance(fields, Field): if isinstance(fields, Field):
fields = [fields] fields = [fields]
elif not isinstance(fields, list): elif not isinstance(fields, list):
...@@ -72,6 +72,9 @@ class PlotterBase(Loggable, object): ...@@ -72,6 +72,9 @@ class PlotterBase(Loggable, object):
if spaces is None: if spaces is None:
spaces = tuple(range(len(fields[0].domain))) spaces = tuple(range(len(fields[0].domain)))
if len(spaces) != len(self.domain_classes):
raise ValueError("Domain mismatch between input and plotter.")
axes = [] axes = []
plot_domain = [] plot_domain = []
for space_index in spaces: for space_index in spaces:
...@@ -91,9 +94,9 @@ class PlotterBase(Loggable, object): ...@@ -91,9 +94,9 @@ class PlotterBase(Loggable, object):
spaces)) spaces))
for (current_data, field) in zip(data_list, fields)]] for (current_data, field) in zip(data_list, fields)]]
figures = [self.figure.at(plots) for plots in plots_list] figures = [self.figure.at(plots, title=title) for plots in plots_list]
self._finalize_figure(figures) self._finalize_figure(figures, path=path)
def _get_data_from_field(self, field, spaces, data_extractor): def _get_data_from_field(self, field, spaces, data_extractor):
for i, space_index in enumerate(spaces): for i, space_index in enumerate(spaces):
...@@ -117,7 +120,7 @@ class PlotterBase(Loggable, object): ...@@ -117,7 +120,7 @@ class PlotterBase(Loggable, object):
def _initialize_multifigure(self): def _initialize_multifigure(self):
return MultiFigure(subfigures=None) return MultiFigure(subfigures=None)
def _finalize_figure(self, figures): def _finalize_figure(self, figures, path=None):
if len(figures) > 1: if len(figures) > 1:
rows = (len(figures) + 1)//2 rows = (len(figures) + 1)//2
figure_array = np.empty((2*rows), dtype=np.object) figure_array = np.empty((2*rows), dtype=np.object)
...@@ -128,5 +131,6 @@ class PlotterBase(Loggable, object): ...@@ -128,5 +131,6 @@ class PlotterBase(Loggable, object):
else: else: