Commit 4f0c5156 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'master' into line_search

parents fe2c9d98 b8bd4934
Pipeline #15417 passed with stage
in 7 minutes and 18 seconds
......@@ -612,39 +612,47 @@ class Field(Loggable, Versionable, object):
# correct variance
if preserve_gaussian_variance:
assert issubclass(val.dtype.type, np.complexfloating),\
"complex input field is needed here"
h *= np.sqrt(2)
a *= np.sqrt(2)
if not issubclass(val.dtype.type, np.complexfloating):
# in principle one must not correct the variance for the fixed
# points of the hermitianization. However, for a complex field
# the input field loses half of its power at its fixed points
# in the `hermitian` part. Hence, here a factor of sqrt(2) is
# also necessary!
# => The hermitianization can be done on a space level since
# either nothing must be done (LMSpace) or ALL points need a
# factor of sqrt(2)
# => use the preserve_gaussian_variance flag in the
# hermitian_decomposition method above.
# This code is for educational purposes:
fixed_points = [domain[i].hermitian_fixed_points()
for i in spaces]
fixed_points = [[fp] if fp is None else fp
for fp in fixed_points]
for product_point in itertools.product(*fixed_points):
slice_object = np.array((slice(None), )*len(val.shape),
dtype=np.object)
for i, sp in enumerate(spaces):
point_component = product_point[i]
if point_component is None:
point_component = slice(None)
slice_object[list(domain_axes[sp])] = point_component
slice_object = tuple(slice_object)
h[slice_object] /= np.sqrt(2)
a[slice_object] /= np.sqrt(2)
# The code below should not be needed in practice, since it would
# only ever be called when hermitianizing a purely real field.
# However it might be of educational use and keep us from forgetting
# how these things are done ...
# if not issubclass(val.dtype.type, np.complexfloating):
# # in principle one must not correct the variance for the fixed
# # points of the hermitianization. However, for a complex field
# # the input field loses half of its power at its fixed points
# # in the `hermitian` part. Hence, here a factor of sqrt(2) is
# # also necessary!
# # => The hermitianization can be done on a space level since
# # either nothing must be done (LMSpace) or ALL points need a
# # factor of sqrt(2)
# # => use the preserve_gaussian_variance flag in the
# # hermitian_decomposition method above.
#
# # This code is for educational purposes:
# fixed_points = [domain[i].hermitian_fixed_points()
# for i in spaces]
# fixed_points = [[fp] if fp is None else fp
# for fp in fixed_points]
#
# for product_point in itertools.product(*fixed_points):
# slice_object = np.array((slice(None), )*len(val.shape),
# dtype=np.object)
# for i, sp in enumerate(spaces):
# point_component = product_point[i]
# if point_component is None:
# point_component = slice(None)
# slice_object[list(domain_axes[sp])] = point_component
#
# slice_object = tuple(slice_object)
# h[slice_object] /= np.sqrt(2)
# a[slice_object] /= np.sqrt(2)
return (h, a)
def _spec_to_rescaler(self, spec, result_list, power_space_index):
......
......@@ -5,12 +5,14 @@ from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Axis(PlotlyWrapper):
def __init__(self, text=None, font='', color='', log=False,
show_grid=True):
font_size=18, show_grid=True, visible=True):
self.text = text
self.font = font
self.color = color
self.log = log
self.font_size = int(font_size)
self.show_grid = show_grid
self.visible = visible
def to_plotly(self):
ply_object = dict()
......@@ -19,11 +21,14 @@ class Axis(PlotlyWrapper):
title=self.text,
titlefont=dict(
family=self.font,
color=self.color
color=self.color,
size=self.font_size
)
))
if self.log:
ply_object['type'] = 'log'
if not self.show_grid:
ply_object['showgrid'] = False
ply_object['visible'] = self.visible
ply_object['tickfont'] = {'size': self.font_size}
return ply_object
......@@ -7,8 +7,13 @@ from nifty.plotting.plots import Heatmap, HPMollweide, GLMollweide
class Figure2D(FigureFromPlot):
def __init__(self, plots, title=None, width=None, height=None,
xaxis=None, yaxis=None):
if plots is not None:
width = width if width is not None else plots[0].default_width()
height = \
height if height is not None else plots[0].default_height()
(xaxis, yaxis) = \
xaxis if xaxis is not None else plots[0].default_axes()
if isinstance(plots[0], Heatmap) and width is None and \
height is None:
(x, y) = plots[0].data.shape
......@@ -29,9 +34,10 @@ class Figure2D(FigureFromPlot):
self.xaxis = xaxis
self.yaxis = yaxis
def at(self, plots):
def at(self, plots, title=None):
title = title if title is not None else self.title
return Figure2D(plots=plots,
title=self.title,
title=title,
width=self.width,
height=self.height,
xaxis=self.xaxis,
......
......@@ -5,14 +5,21 @@ from figure_from_plot import FigureFromPlot
class Figure3D(FigureFromPlot):
def __init__(self, plots, title=None, width=None, height=None,
xaxis=None, yaxis=None, zaxis=None):
if plots is not None:
width = width if width is not None else plots[0].default_width()
height = \
height if height is not None else plots[0].default_height()
(xaxis, yaxis, zaxis) = \
xaxis if xaxis is not None else plots[0].default_axes()
super(Figure3D, self).__init__(plots, title, width, height)
self.xaxis = xaxis
self.yaxis = yaxis
self.zaxis = zaxis
def at(self, plots):
def at(self, plots, title=None):
title = title if title is not None else self.title
return Figure3D(plots=plots,
title=self.title,
title=title,
width=self.width,
height=self.height,
xaxis=self.xaxis,
......
......@@ -12,7 +12,7 @@ class FigureBase(PlotlyWrapper):
self.height = height
@abc.abstractmethod
def at(self):
def at(self, title=None):
raise NotImplementedError
@abc.abstractmethod
......
......@@ -4,6 +4,8 @@ from nifty import dependency_injector as gdi
from heatmap import Heatmap
import numpy as np
from nifty.plotting.descriptors import Axis
from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix')
......@@ -11,14 +13,15 @@ pyHealpix = gdi.get('pyHealpix')
class GLMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None,
webgl=False, smoothing=False):
webgl=False, smoothing=False, zmin=None, zmax=None):
# smoothing 'best', 'fast', False
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
self.xsize = xsize
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing)
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing,
zmin, zmax)
def at(self, data):
if isinstance(data, list):
......@@ -55,3 +58,12 @@ class GLMollweide(Heatmap):
ilon = np.where(ilon == nlon, 0, ilon)
res[mask] = x[ilat, ilon]
return res
def default_width(self):
return 1400
def default_height(self):
return 700
def default_axes(self):
return (Axis(visible=False), Axis(visible=False))
# -*- coding: utf-8 -*-
import numpy as np
from nifty.plotting.descriptors import Axis
from nifty.plotting.colormap import Colormap
from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Heatmap(PlotlyWrapper):
def __init__(self, data, color_map=None, webgl=False, smoothing=False):
def __init__(self, data, color_map=None, webgl=False, smoothing=False,
zmin=None, zmax=None):
# smoothing 'best', 'fast', False
if color_map is not None:
......@@ -17,6 +20,9 @@ class Heatmap(PlotlyWrapper):
self.webgl = webgl
self.smoothing = smoothing
self.data = data
self.zmin = zmin
self.zmax = zmax
self._font_size = 18
def at(self, data):
if isinstance(data, list):
......@@ -28,7 +34,9 @@ class Heatmap(PlotlyWrapper):
return Heatmap(data=temp_data,
color_map=self.color_map,
webgl=self.webgl,
smoothing=self.smoothing)
smoothing=self.smoothing,
zmin=self.zmin,
zmax=self.zmax)
@property
def figure_dimension(self):
......@@ -38,11 +46,13 @@ class Heatmap(PlotlyWrapper):
plotly_object = dict()
plotly_object['z'] = self.data
plotly_object['zmin'] = self.zmin
plotly_object['zmax'] = self.zmax
plotly_object['showscale'] = False
plotly_object['showscale'] = True
plotly_object['colorbar'] = {'tickfont': {'size': self._font_size}}
if self.color_map:
plotly_object['colorscale'] = self.color_map.to_plotly()
plotly_object['colorbar'] = dict(title=self.color_map.name, x=0.42)
if self.webgl:
plotly_object['type'] = 'heatmapgl'
else:
......@@ -50,3 +60,14 @@ class Heatmap(PlotlyWrapper):
if self.smoothing:
plotly_object['zsmooth'] = self.smoothing
return plotly_object
def default_width(self):
return 700
def default_height(self):
(x, y) = self.data.shape
return int(700 * y / x)
def default_axes(self):
return (Axis(font_size=self._font_size),
Axis(font_size=self._font_size))
......@@ -4,6 +4,8 @@ from nifty import dependency_injector as gdi
from heatmap import Heatmap
import numpy as np
from nifty.plotting.descriptors import Axis
from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix')
......@@ -11,12 +13,13 @@ pyHealpix = gdi.get('pyHealpix')
class HPMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
smoothing=False, zmin=None, zmax=None): # smoothing 'best', 'fast', False
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
self.xsize = xsize
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing)
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing,
zmin, zmax)
def at(self, data):
if isinstance(data, list):
......@@ -39,3 +42,12 @@ class HPMollweide(Heatmap):
base = pyHealpix.Healpix_Base(int(np.sqrt(x.size/12)), "RING")
res[mask] = x[base.ang2pix(ptg)]
return res
def default_width(self):
return 1400
def default_height(self):
return 700
def default_axes(self):
return (Axis(visible=False), Axis(visible=False))
# -*- coding: utf-8 -*-
from nifty.plotting.descriptors import Axis
from cartesian import Cartesian
......@@ -30,3 +31,6 @@ class Cartesian2D(Cartesian):
plotly_object['type'] = 'scatter'
return plotly_object
def default_axes(self):
return (Axis(), Axis())
# -*- coding: utf-8 -*-
from nifty.plotting.descriptors import Axis
from cartesian import Cartesian
......@@ -25,3 +26,6 @@ class Cartesian3D(Cartesian):
plotly_object['z'] = self.data[2]
plotly_object['type'] = 'scatter3d'
return plotly_object
def default_axes(self):
return (Axis(), Axis(), Axis())
from nifty.plotting.descriptors import Axis
from scatter_plot import ScatterPlot
......@@ -31,3 +32,6 @@ class Geo(ScatterPlot):
if self.line:
plotly_object['mode'] = 'lines'
return plotly_object
def default_axes(self):
return (Axis(), Axis())
......@@ -40,3 +40,13 @@ class ScatterPlot(PlotlyWrapper):
ply_object['marker'] = self.marker.to_plotly()
return ply_object
def default_width(self):
return 700
def default_height(self):
return 700
@abc.abstractmethod
def default_axes(self):
raise NotImplementedError
......@@ -9,9 +9,9 @@ from .plotter_base import PlotterBase
class GLPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None):
def __init__(self, interactive=False, path='plot.html', color_map=None):
self.color_map = color_map
super(GLPlotter, self).__init__(interactive, path, title)
super(GLPlotter, self).__init__(interactive, path)
@property
def domain_classes(self):
......
......@@ -6,9 +6,9 @@ from .plotter_base import PlotterBase
class HealpixPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None):
def __init__(self, interactive=False, path='plot.html', color_map=None):
self.color_map = color_map
super(HealpixPlotter, self).__init__(interactive, path, title)
super(HealpixPlotter, self).__init__(interactive, path)
@property
def domain_classes(self):
......
......@@ -30,12 +30,11 @@ rank = d2o.config.dependency_injector[
class PlotterBase(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', title=""):
def __init__(self, interactive=False, path='plot.html'):
if plotly is None:
raise ImportError("The module plotly is needed but not available.")
self.interactive = interactive
self.path = path
self.title = str(title)
self.plot = self._initialize_plot()
self.figure = self._initialize_figure()
......@@ -61,7 +60,8 @@ class PlotterBase(Loggable, object):
def path(self, new_path):
self._path = os.path.normpath(new_path)
def __call__(self, fields, spaces=None, data_extractor=None, labels=None):
def __call__(self, fields, spaces=None, data_extractor=None, labels=None,
path=None, title=None):
if isinstance(fields, Field):
fields = [fields]
elif not isinstance(fields, list):
......@@ -72,6 +72,9 @@ class PlotterBase(Loggable, object):
if spaces is None:
spaces = tuple(range(len(fields[0].domain)))
if len(spaces) != len(self.domain_classes):
raise ValueError("Domain mismatch between input and plotter.")
axes = []
plot_domain = []
for space_index in spaces:
......@@ -91,9 +94,9 @@ class PlotterBase(Loggable, object):
spaces))
for (current_data, field) in zip(data_list, fields)]]
figures = [self.figure.at(plots) for plots in plots_list]
figures = [self.figure.at(plots, title=title) for plots in plots_list]
self._finalize_figure(figures)
self._finalize_figure(figures, path=path)
def _get_data_from_field(self, field, spaces, data_extractor):
for i, space_index in enumerate(spaces):
......@@ -117,7 +120,7 @@ class PlotterBase(Loggable, object):
def _initialize_multifigure(self):
return MultiFigure(subfigures=None)
def _finalize_figure(self, figures):
def _finalize_figure(self, figures, path=None):
if len(figures) > 1:
rows = (len(figures) + 1)//2
figure_array = np.empty((2*rows), dtype=np.object)
......@@ -128,5 +131,6 @@ class PlotterBase(Loggable, object):
else:
final_figure = figures[0]
path = self.path if path is None else path
plotly.offline.plot(final_figure.to_plotly(),
filename=os.path.join(self.path, self.title))
filename=path)
......@@ -11,9 +11,9 @@ from .plotter_base import PlotterBase
class PowerPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="",
line=None, marker=None):
super(PowerPlotter, self).__init__(interactive, path, title)
def __init__(self, interactive=False, path='plot.html', line=None,
marker=None):
super(PowerPlotter, self).__init__(interactive, path)
self.line = line
self.marker = marker
......
......@@ -11,9 +11,9 @@ from .plotter_base import PlotterBase
class RG1DPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="",
line=None, marker=None):
super(RG1DPlotter, self).__init__(interactive, path, title)
def __init__(self, interactive=False, path='plot.html', line=None,
marker=None):
super(RG1DPlotter, self).__init__(interactive, path)
self.line = line
self.marker = marker
......
......@@ -8,9 +8,9 @@ from .plotter_base import PlotterBase
class RG2DPlotter(PlotterBase):
def __init__(self, interactive=False, path='.', title="", color_map=None):
def __init__(self, interactive=False, path='plot.html', color_map=None):
self.color_map = color_map
super(RG2DPlotter, self).__init__(interactive, path, title)
super(RG2DPlotter, self).__init__(interactive, path)
@property
def domain_classes(self):
......
......@@ -100,23 +100,26 @@ class RGSpace(Space):
self._distances = self._parse_distances(distances)
self._zerocenter = self._parse_zerocenter(zerocenter)
def hermitian_fixed_points(self):
dimensions = len(self.shape)
mid_index = np.array(self.shape)//2
ndlist = [1]*dimensions
for k in range(dimensions):
if self.shape[k] % 2 == 0:
ndlist[k] = 2
ndlist = tuple(ndlist)
fixed_points = []
for index in np.ndindex(ndlist):
for k in range(dimensions):
if self.shape[k] % 2 != 0 and self.zerocenter[k]:
index = list(index)
index[k] = 1
index = tuple(index)
fixed_points += [tuple(index * mid_index)]
return fixed_points
# This code is unused but may be useful to keep around if it is ever needed
# again in the future ...
# def hermitian_fixed_points(self):
# dimensions = len(self.shape)
# mid_index = np.array(self.shape)//2
# ndlist = [1]*dimensions
# for k in range(dimensions):
# if self.shape[k] % 2 == 0:
# ndlist[k] = 2
# ndlist = tuple(ndlist)
# fixed_points = []
# for index in np.ndindex(ndlist):
# for k in range(dimensions):
# if self.shape[k] % 2 != 0 and self.zerocenter[k]:
# index = list(index)
# index[k] = 1
# index = tuple(index)
# fixed_points += [tuple(index * mid_index)]
# return fixed_points
def hermitianize_inverter(self, x, axes):
# calculate the number of dimensions the input array has
......
......@@ -161,19 +161,6 @@ class Space(DomainObject):
raise NotImplementedError(
"There is no generic co-smoothing kernel for Space base class.")
def hermitian_fixed_points(self):
""" Returns the array points which remain invariant under the action
of `hermitianize_inverter`
Returns
-------
list of index-tuples
The list contains the index-coordinates of the invariant points.
"""
return None
def hermitianize_inverter(self, x, axes):
""" Inverts/flips x in the context of Hermitian decomposition.
......
......@@ -67,6 +67,8 @@ class Test_Functionality(unittest.TestCase):
r2 = RGSpace(s2, harmonic=True, zerocenter=(z2,))
ra = RGSpace(s1+s2, harmonic=True, zerocenter=(z1, z2))
if preserve:
complexdata=True
v = np.random.random(s1+s2)
if complexdata:
v = v + 1j*np.random.random(s1+s2)
......
......@@ -127,7 +127,3 @@ class LMSpaceFunctionalityTests(unittest.TestCase):
def test_distance_array(self, lmax, expected):
l = LMSpace(lmax)
assert_almost_equal(l.get_distance_array('not').data, expected)
def test_hermitian_fixed_points(self):
x = LMSpace(5)
assert_equal(x.hermitian_fixed_points(), None)
......@@ -190,8 +190,3 @@ class RGSpaceFunctionalityTests(unittest.TestCase):
assert_almost_equal(res, expected)
if inplace:
assert_(x is res)
def test_hermitian_fixed_points(self):
x = RGSpace((5, 6, 5, 6), zerocenter=[False, False, True, True])
assert_equal(x.hermitian_fixed_points(),
[(0, 0, 2, 0), (0, 0, 2, 3), (0, 3, 2, 0), (0, 3, 2, 3)])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment