Commit 8bb5b571 authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'master' into 'docstring_operators'

Master

See merge request !107
parents 765ba57e 7bb482cf
Pipeline #12384 passed with stage
in 8 minutes and 41 seconds
from nifty import *
from mpi4py import MPI
import plotly.offline as py
import plotly.graph_objs as go
comm = MPI.COMM_WORLD
rank = comm.rank
def plot_maps(x, name):
trace = [None]*len(x)
keys = x.keys()
field = x[keys[0]]
domain = field.domain[0]
shape = len(domain.shape)
max_n = domain.shape[0]*domain.distances[0]
step = domain.distances[0]
x_axis = np.arange(0, max_n, step)
if shape == 1:
for ii in xrange(len(x)):
trace[ii] = go.Scatter(x= x_axis, y=x[keys[ii]].val.get_full_data(), name=keys[ii])
fig = go.Figure(data=trace)
py.plot(fig, filename=name)
elif shape == 2:
for ii in xrange(len(x)):
py.plot([go.Heatmap(z=x[keys[ii]].val.get_full_data().real)], filename=keys[ii])
else:
raise TypeError("Only 1D and 2D field plots are supported")
def plot_power(x, name):
layout = go.Layout(
xaxis=dict(
type='log',
autorange=True
),
yaxis=dict(
type='log',
autorange=True
)
)
trace = [None]*len(x)
keys = x.keys()
field = x[keys[0]]
domain = field.domain[0]
x_axis = domain.kindex
for ii in xrange(len(x)):
trace[ii] = go.Scatter(x= x_axis, y=x[keys[ii]].val.get_full_data(), name=keys[ii])
fig = go.Figure(data=trace, layout=layout)
py.plot(fig, filename=name)
np.random.seed(42)
if __name__ == "__main__":
distribution_strategy = 'not'
# setting spaces
npix = np.array([500]) # number of pixels
total_volume = 1. # total length
# setting signal parameters
lambda_s = .05 # signal correlation length
sigma_s = 10. # signal variance
#setting response operator parameters
length_convolution = .025
exposure = 1.
# calculating parameters
k_0 = 4. / (2 * np.pi * lambda_s)
a_s = sigma_s ** 2. * lambda_s * total_volume
# creation of spaces
# x1 = RGSpace([npix,npix], distances=total_volume / npix,
# zerocenter=False)
# k1 = RGRGTransformation.get_codomain(x1)
x1 = HPSpace(32)
k1 = HPLMTransformation.get_codomain(x1)
p1 = PowerSpace(harmonic_partner=k1, logarithmic=False)
# creating Power Operator with given spectrum
spec = (lambda k: a_s / (1 + (k / k_0) ** 2) ** 2)
p_field = Field(p1, val=spec)
S_op = create_power_operator(k1, spec)
# creating FFT-Operator and Response-Operator with Gaussian convolution
Fft_op = FFTOperator(domain=x1, target=k1,
domain_dtype=np.float64,
target_dtype=np.complex128)
R_op = ResponseOperator(x1, sigma=[length_convolution],
exposure=[exposure])
# drawing a random field
sk = p_field.power_synthesize(real_signal=True, mean=0.)
s = Fft_op.adjoint_times(sk)
signal_to_noise = 1
N_op = DiagonalOperator(R_op.target, diagonal=s.var()/signal_to_noise, bare=True)
n = Field.from_random(domain=R_op.target,
random_type='normal',
std=s.std()/np.sqrt(signal_to_noise),
mean=0.)
d = R_op(s) + n
# Wiener filter
j = Fft_op.times(R_op.adjoint_times(N_op.inverse_times(d)))
D = HarmonicPropagatorOperator(S=S_op, N=N_op, R=R_op)
mk = D(j)
m = Fft_op.adjoint_times(mk)
# z={}
# z["signal"] = s
# z["reconstructed_map"] = m
# z["data"] = d
# z["lambda"] = R_op(s)
# z["j"] = j
#
# plot_maps(z, "Wiener_filter.html")
......@@ -50,8 +50,8 @@ from spaces import *
from operators import *
from plotting import *
from probing import *
from sugar import *
import plotting
......@@ -24,6 +24,21 @@ from keepers import Loggable,\
class DomainObject(Versionable, Loggable, object):
"""The abstract class that can be used as a domain for a field.
This holds all the information and functionality a field needs to know
about its domain and how the data of the field are stored.
Attributes
----------
dim : int
Number of pixel-dimensions of the underlying data object.
shape : tuple
Shape of the array that stores the degrees of freedom for any field
on this domain.
"""
__metaclass__ = NiftyMeta
def __init__(self):
......
......@@ -280,49 +280,7 @@ class LinearOperator(Loggable, object):
return y
def inverse_adjoint_times(self, x, spaces=None, **kwargs):
""" Applies the inverse-adjoint Operator to a given Field.
Operator and Field have to live over the same domain.
Parameters
----------
x : NIFTY.Field
applies the Operator to the given Field
spaces : integer (default: None)
defines on which space of the given Field the Operator acts
**kwargs
Additional keyword arguments get passed to the used copy_empty
routine.
Returns
-------
out : NIFTy.Field
the processed Field living on the target space
See Also
--------
"""
if self.unitary:
return self.times(x, spaces, **kwargs)
spaces = self._check_input_compatibility(x, spaces)
try:
y = self._inverse_adjoint_times(x, spaces, **kwargs)
except(NotImplementedError):
if self.unitary:
y = self._times(x, spaces, **kwargs)
else:
raise
try:
y = self._inverse_adjoint_times(x, spaces, **kwargs)
except(NotImplementedError):
if self.unitary:
y = self._times(x, spaces, **kwargs)
else:
raise
return y
return self.adjoint_inverse_times(x, spaces, **kwargs)
def _times(self, x, spaces):
raise NotImplementedError(
......
# -*- coding: utf-8 -*-
import abc
import os
import plotly
from plotly import tools
import plotly.offline as ply
from keepers import Loggable
from nifty.spaces.space import Space
from nifty.field_types.field_type import FieldType
import nifty.nifty_utilities as utilities
plotly.offline.init_notebook_mode()
class Plotter(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, interactive=False, path='.', stack_subplots=False,
color_scale=None):
self.interactive = interactive
self.path = path
self.stack_subplots = stack_subplots
self.color_scale = color_scale
self.title = 'uiae'
@abc.abstractproperty
def domain(self):
return (Space,)
@abc.abstractproperty
def field_type(self):
return (FieldType,)
@property
def interactive(self):
return self._interactive
@interactive.setter
def interactive(self, interactive):
self._interactive = bool(interactive)
@property
def path(self):
return self._path
@path.setter
def path(self, new_path):
self._path = os.path.normpath(new_path)
@property
def stack_subplots(self):
return self._stack_subplots
@stack_subplots.setter
def stack_subplots(self, stack_subplots):
self._stack_subplots = bool(stack_subplots)
def plot(self, field, spaces=None, types=None, slice=None):
data = self._get_data_from_field(field, spaces, types, slice)
figures = self._create_individual_plot(data)
self._finalize_figure(figures)
@abc.abstractmethod
def _get_data_from_field(self, field, spaces=None, types=None, slice=None):
# if fields is a list, create a new field with appended
# field_type = field_array and copy individual parts into the new field
spaces = utilities.cast_axis_to_tuple(spaces, len(field.domain))
types = utilities.cast_axis_to_tuple(types, len(field.field_type))
if field.domain[spaces] != self.domain:
raise AttributeError("Given space(s) of input field-domain do not "
"match the plotters domain.")
if field.field_type[spaces] != self.field_type:
raise AttributeError("Given field_type(s) of input field-domain "
"do not match the plotters field_type.")
# iterate over the individual slices in order to compose the figure
# -> make a d2o.get_full_data() (for rank==0 only?)
# add
return [1,2,3]
def _create_individual_plot(self, data):
pass
def _finalize_figure(self, figure):
pass
# is there a use for ply.plot when one has no interest in
# saving a file?
# -> check for different file types
# -> store the file to disk (MPI awareness?)
from descriptors import *
from plots import *
from figures import *
from colormap import *
\ No newline at end of file
from colormap import *
from plotter import *
......@@ -13,8 +13,9 @@ class Colormap(PlotlyWrapper):
#TODO: implement validation
pass
# no discontinuities only
@staticmethod
def from_matplotlib_colormap_internal(name, mpl_cmap): # no discontinuities only
def from_matplotlib_colormap_internal(name, mpl_cmap):
red = [(c[0], c[2]) for c in mpl_cmap['red']]
green = [(c[0], c[2]) for c in mpl_cmap['green']]
blue = [(c[0], c[2]) for c in mpl_cmap['blue']]
......@@ -40,7 +41,8 @@ class Colormap(PlotlyWrapper):
green_val = self.green[g][1]
g += 1
else:
slope = (self.green[g][1] - prev_g) / (self.green[g][0] - prev_split)
slope = ((self.green[g][1] - prev_g) /
(self.green[g][0] - prev_split))
y = prev_g - slope * prev_split
green_val = slope * next_split + y
......@@ -48,11 +50,16 @@ class Colormap(PlotlyWrapper):
blue_val = self.blue[b][1]
b += 1
else:
slope = (self.blue[b][1] - prev_b) / (self.blue[b][0] - prev_split)
slope = ((self.blue[b][1] - prev_b) /
(self.blue[b][0] - prev_split))
y = prev_r - slope * prev_split
blue_val = slope * next_split + y
prev_split, prev_r, prev_g, prev_b = next_split, red_val, green_val, blue_val
prev_split = next_split
prev_r = red_val
prev_g = green_val
prev_b = blue_val
converted.append([next_split,
'rgb(' +
str(int(red_val*255)) + "," +
......
......@@ -31,29 +31,46 @@ class MultiFigure(FigureBase):
def to_plotly(self):
title_extractor = lambda z: z.title if z else ""
sub_titles = \
tuple(np.vectorize(title_extractor)(self.subfigures.flatten()))
sub_titles = tuple(np.vectorize(title_extractor)(self.subfigures.flatten()))
specs_setter = \
lambda z: {'is_3d': True} if isinstance(z, Figure3D) else {}
sub_specs = \
list(map(list, np.vectorize(specs_setter)(self.subfigures)))
specs_setter = lambda z: {'is_3d': True} if isinstance(z, Figure3D) else {}
sub_specs = list(map(list, np.vectorize(specs_setter)(self.subfigures)))
multi_figure_plotly_object = plotly.tools.make_subplots(
self.rows,
self.columns,
subplot_titles=sub_titles,
specs=sub_specs)
# TODO resolve bug with titles and 3D subplots
multi_figure_plotly_object['layout'].update(height=self.height,
width=self.width,
title=self.title)
#TODO resolve bug with titles and 3D subplots
i = 1
for index, fig in np.ndenumerate(self.subfigures):
if fig:
for plot in fig.plots:
multi_figure_plotly_object.append_trace(plot.to_plotly(),
index[0]+1,
index[1]+1)
if isinstance(fig, Figure3D):
scene = dict()
if fig.xaxis:
scene['xaxis'] = fig.xaxis.to_plotly()
if fig.yaxis:
scene['yaxis'] = fig.yaxis.to_plotly()
if fig.zaxis:
scene['zaxis'] = fig.zaxis.to_plotly()
multi_figure_plotly_object['layout'].update(height=self.height,
width=self.width,
title=self.title)
multi_figure_plotly_object['layout']['scene'+str(i)] = scene
else:
if fig.xaxis:
multi_figure_plotly_object['layout']['xaxis'+str(i)] = fig.xaxis.to_plotly()
if fig.yaxis:
multi_figure_plotly_object['layout']['yaxis'+str(i)] = fig.yaxis.to_plotly()
i += 1
return multi_figure_plotly_object
# -*- coding: utf-8 -*-
import numpy as np
from nifty.plotting.colormap import Colormap
from nifty.plotting.plotly_wrapper import PlotlyWrapper
class Heatmap(PlotlyWrapper):
def __init__(self, data, color_map=None, webgl=False,
smoothing=False): # smoothing 'best', 'fast', False
self.data = data
if isinstance(data, list):
self.data = np.zeros((data[0].shape))
for arr in data:
self.data = np.add(self.data, arr)
else:
self.data = data
if color_map is not None:
if not isinstance(color_map, Colormap):
raise TypeError("Provided color_map must be an instance of "
"the NIFTy Colormap class.")
self.color_map = color_map
self.webgl = webgl
self.smoothing = smoothing
@property
def figure_dimension(self):
return 2
def to_plotly(self):
plotly_object = dict()
plotly_object['z'] = self.data
......
......@@ -14,8 +14,10 @@ class Mollweide(Heatmap):
raise ImportError("The module pylab is needed but not available.")
if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available.")
data = self._mollview(data)
if isinstance(data, list):
data = [self._mollview(d) for d in data]
else:
data = self._mollview(data)
super(Mollweide, self).__init__(data, color_map, webgl, smoothing)
def _mollview(self, x, xsize=800):
......
......@@ -4,14 +4,16 @@ from scatter_plot import ScatterPlot
class Cartesian(ScatterPlot):
def __init__(self, x, y, label, line, marker):
def __init__(self, x, y, label, line, marker, showlegend=True):
super(Cartesian, self).__init__(label, line, marker)
self.x = x
self.y = y
self.showlegend = showlegend
@abstractmethod
def to_plotly(self):
plotly_object = super(Cartesian, self).to_plotly()
plotly_object['x'] = self.x
plotly_object['y'] = self.y
plotly_object['showlegend'] = self.showlegend
return plotly_object
......@@ -5,14 +5,20 @@ from cartesian import Cartesian
class Cartesian2D(Cartesian):
def __init__(self, x=None, y=None, x_start=0, x_step=1,
label='', line=None, marker=None, webgl=True):
label='', line=None, marker=None, showlegend=True,
webgl=True):
if y is None:
raise Exception('Error: no y data to plot')
if x is None:
x = range(x_start, len(y) * x_step, x_step)
super(Cartesian2D, self).__init__(x, y, label, line, marker)
super(Cartesian2D, self).__init__(x, y, label, line, marker,
showlegend)
self.webgl = webgl
@property
def figure_dimension(self):
return 2
def to_plotly(self):
plotly_object = super(Cartesian2D, self).to_plotly()
if self.webgl:
......
......@@ -4,10 +4,16 @@ from cartesian import Cartesian
class Cartesian3D(Cartesian):
def __init__(self, x, y, z, label='', line=None, marker=None):
super(Cartesian3D, self).__init__(x, y, label, line, marker)
def __init__(self, x, y, z, label='', line=None, marker=None,
showlegend=True):
super(Cartesian3D, self).__init__(x, y, label, line, marker,
showlegend)
self.z = z
@property
def figure_dimension(self):
return 3
def to_plotly(self):
plotly_object = super(Cartesian3D, self).to_plotly()
plotly_object['z'] = self.z
......
from scatter_plot import ScatterPlot
......@@ -13,6 +14,10 @@ class Geo(ScatterPlot):
self.lat = lat
self.projection = proj
@property
def figure_dimension(self):
return 2
def _to_plotly(self):
plotly_object = super(Geo, self).to_plotly()
plotly_object['type'] = 'scattergeo'
......
# -*- coding: utf-8 -*-
from abc import abstractmethod
import abc
from nifty.plotting.plotly_wrapper import PlotlyWrapper
from nifty.plotting.descriptors import Marker
from nifty.plotting.descriptors import Marker,\
Line
class ScatterPlot(PlotlyWrapper):
......@@ -12,8 +13,13 @@ class ScatterPlot(PlotlyWrapper):
self.marker = marker
if not self.line and not self.marker:
self.marker = Marker()
self.line = Line()
@abstractmethod
@abc.abstractproperty
def figure_dimension(self):
raise NotImplementedError
@abc.abstractmethod
def to_plotly(self):
ply_object = dict()
ply_object['name'] = self.label
......
from healpix_plotter import HealpixPlotter
from power_plotter import PowerPlotter