plotter_base.py 3.94 KB
Newer Older
1 2 3 4
# -*- coding: utf-8 -*-

import abc
import os
5
import sys
6

7 8
import numpy as np

9 10
import d2o

11 12
from keepers import Loggable

13 14
from nifty.config import dependency_injector as gdi

15
from nifty.spaces.space import Space
16
from nifty.field import Field
17 18
import nifty.nifty_utilities as utilities

19 20 21
from nifty.plotting.figures import MultiFigure

plotly = gdi.get('plotly')

# Set up plotly's offline notebook mode only when an IPython session has been
# imported and plotly itself is available.
if 'IPython' in sys.modules and plotly is not None:
    plotly.offline.init_notebook_mode()

# Rank of this process in COMM_WORLD of the MPI module d2o is configured with.
rank = (d2o.config
        .dependency_injector[d2o.configuration['mpi_module']]
        .COMM_WORLD
        .rank)


class PlotterBase(Loggable, object):
    """Abstract base class for plotly-based plotters of NIFTy fields.

    Subclasses supply the plot and figure prototypes via `_initialize_plot`
    and `_initialize_figure`, the accepted space types via `domain_classes`,
    and a `_parse_data(data, field, spaces)` method that turns field data into
    the argument passed to ``self.plot.at``. Calling an instance renders the
    given field(s) and writes the result with ``plotly.offline.plot`` to
    ``os.path.join(path, title)``.
    """

    # NOTE(review): a `__metaclass__` class attribute only takes effect on
    # Python 2; on Python 3 the abstract members below are not enforced.
    __metaclass__ = abc.ABCMeta

    def __init__(self, interactive=False, path='.', title=""):
        """
        Parameters
        ----------
        interactive : bool
            Stored on the instance (as bool); not consulted by this base
            class.
        path : str
            Output directory, normalized with ``os.path.normpath``.
        title : str
            Converted to ``str``; used as the output file name.

        Raises
        ------
        ImportError
            If plotly is not available.
        """
        if plotly is None:
            raise ImportError("The module plotly is needed but not available.")

        self.interactive = interactive
        self.path = path
        self.title = str(title)

        self.plot = self._initialize_plot()
        self.figure = self._initialize_figure()
        self.multi_figure = self._initialize_multifigure()

    @abc.abstractproperty
    def domain_classes(self):
        """Space classes accepted, position by position, for the plotted
        spaces (checked in `_get_data_from_field`)."""
        return (Space,)

    @property
    def interactive(self):
        """Whether interactive display was requested."""
        return self._interactive

    @interactive.setter
    def interactive(self, interactive):
        self._interactive = bool(interactive)

    @property
    def path(self):
        """Normalized output directory."""
        return self._path

    @path.setter
    def path(self, new_path):
        self._path = os.path.normpath(new_path)

    def __call__(self, fields, spaces=None, data_extractor=None, labels=None):
        """Plot one or several fields and write the result with plotly.

        For every index combination of the data axes that do NOT belong to
        the plotted spaces, one row of plots (one plot per field) is created
        from the corresponding slice of each field's data.

        Parameters
        ----------
        fields : Field or iterable of Field
            The field(s) to plot. ``fields[0]`` defines the domain layout.
        spaces : int, tuple of int or None
            Indices into the fields' domain selecting the spaces to plot;
            None selects every space of ``fields[0].domain``.
        data_extractor : optional
            Forwarded to `_get_data_from_field`, which does not use it yet.
        labels : optional
            Currently unused.

        Raises
        ------
        ValueError
            If `fields` is empty.
        AttributeError
            If a selected space does not match `domain_classes`.
        """
        if isinstance(fields, Field):
            fields = [fields]
        elif not isinstance(fields, list):
            fields = list(fields)

        if not fields:
            raise ValueError("At least one field must be given.")

        reference = fields[0]
        spaces = utilities.cast_axis_to_tuple(spaces, len(reference.domain))
        if spaces is None:
            spaces = tuple(range(len(reference.domain)))

        # Data axes that belong to the plotted spaces. get_slice_list keeps
        # these whole and iterates over the indices of all other axes.
        axes = []
        for space_index in spaces:
            axes += list(reference.domain_axes[space_index])

        # prepare data
        data_list = [self._get_data_from_field(field, spaces, data_extractor)
                     for field in fields]

        # create plots: one row (one plot per field) per slice
        plots_list = []
        for slice_list in utilities.get_slice_list(data_list[0].shape, axes):
            # NumPy requires a tuple (not a list) for multidimensional
            # indexing with slices.
            index = tuple(slice_list)
            plots_list.append(
                [self.plot.at(self._parse_data(data[index], field, spaces))
                 for data, field in zip(data_list, fields)])

        figures = [self.figure.at(plots) for plots in plots_list]

        self._finalize_figure(figures)

    def _get_data_from_field(self, field, spaces, data_extractor):
        """Check the selected spaces of `field` and return its full data.

        Parameters
        ----------
        field : Field
            Field whose values are gathered.
        spaces : tuple of int
            Indices into ``field.domain``; the i-th selected space must be an
            instance of ``self.domain_classes[i]``.
        data_extractor : optional
            Not used yet (see TODO below).

        Returns
        -------
        The result of ``field.val.get_full_data(target_rank=0)``.

        Raises
        ------
        AttributeError
            If more spaces are selected than `domain_classes` describes, or
            a selected space is not an instance of its expected class.
        """
        domain_classes = self.domain_classes
        # Guard against an IndexError on domain_classes below.
        if len(spaces) > len(domain_classes):
            raise AttributeError("Given space(s) of input field-domain do "
                                 "not match the plotters domain.")

        for i, space_index in enumerate(spaces):
            if not isinstance(field.domain[space_index], domain_classes[i]):
                raise AttributeError("Given space(s) of input field-domain do "
                                     "not match the plotters domain.")

        # TODO: add data_extractor functionality here
        # Data is collected on rank 0 only.
        data = field.val.get_full_data(target_rank=0)
        return data

    @abc.abstractmethod
    def _initialize_plot(self):
        """Return the plot prototype whose ``at`` method builds each plot."""
        raise NotImplementedError

    @abc.abstractmethod
    def _initialize_figure(self):
        """Return the figure prototype whose ``at`` method wraps a plot row."""
        raise NotImplementedError

    def _initialize_multifigure(self):
        """Return the (empty) MultiFigure used to combine several figures."""
        return MultiFigure(subfigures=None)

    def _finalize_figure(self, figures):
        """Combine `figures` and write them to disk with plotly.

        A single figure is written as-is. Several figures are placed, in
        row-major order, into a 2 x ceil(n/2) object array; an unused
        trailing cell stays None. The result is wrapped in a MultiFigure.
        Only MPI rank 0 writes, because the data is gathered with
        ``target_rank=0`` and all ranks would otherwise write the same file.
        """
        if rank != 0:
            return

        if len(figures) > 1:
            columns = (len(figures) + 1) // 2
            # Plain `object`: the `np.object` alias was removed in NumPy 1.24.
            figure_array = np.empty(2 * columns, dtype=object)
            # Assign element-wise so NumPy never tries to unpack or broadcast
            # figure objects that happen to be sequence-like.
            for position, figure in enumerate(figures):
                figure_array[position] = figure
            figure_array = figure_array.reshape((2, columns))

            # NOTE(review): the other prototypes are used via `.at(...)`;
            # confirm that MultiFigure instances are callable like this.
            final_figure = self.multi_figure(subfigures=figure_array)
        else:
            final_figure = figures[0]

        plotly.offline.plot(final_figure.to_plotly(),
                            filename=os.path.join(self.path, self.title))