plotter_base.py 3.94 KB
Newer Older
1
2
3
4
# -*- coding: utf-8 -*-

import abc
import os
5
import sys
6

7
8
import numpy as np

9
10
import d2o

11
12
from keepers import Loggable

13
14
from nifty.config import dependency_injector as gdi

15
from nifty.spaces.space import Space
16
from nifty.field import Field
17
18
import nifty.nifty_utilities as utilities

19
20
21
from nifty.plotting.figures import MultiFigure

# plotly is optional: the dependency injector yields None when it is missing,
# in which case PlotterBase.__init__ raises an ImportError.
plotly = gdi.get('plotly')

# Inside an IPython session, enable plotly's offline notebook mode.
if plotly is not None and 'IPython' in sys.modules:
    plotly.offline.init_notebook_mode()

# Rank of this process in COMM_WORLD of the MPI module configured for d2o.
rank = d2o.config.dependency_injector[
        d2o.configuration['mpi_module']].COMM_WORLD.rank


class PlotterBase(Loggable, object):
    """Abstract base class for plotly-based plotters of NIFTy fields.

    Subclasses provide the plot and figure factories (``_initialize_plot``,
    ``_initialize_figure``), the expected space types (``domain_classes``)
    and a ``_parse_data`` method that converts a data array into plot input
    (``_parse_data`` is called below but not defined in this class).

    Calling an instance builds one figure per slice over the axes that are
    not plotted and writes the combined result with
    ``plotly.offline.plot`` to ``os.path.join(path, title)``.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, interactive=False, path='.', title=""):
        """Create the plotter and its plot/figure/multifigure objects.

        Parameters
        ----------
        interactive : bool
            Stored (cast to bool) as ``self.interactive``; not read by this
            base class.
        path : str
            Output directory; normalised with ``os.path.normpath``.
        title : str
            Converted with ``str``; used as the output file name in ``path``.

        Raises
        ------
        ImportError
            If plotly is not available.
        """
        if plotly is None:
            raise ImportError("The module plotly is needed but not available.")

        self.interactive = interactive
        self.path = path
        self.title = str(title)

        self.plot = self._initialize_plot()
        self.figure = self._initialize_figure()
        self.multi_figure = self._initialize_multifigure()

    @abc.abstractproperty
    def domain_classes(self):
        """Tuple of space classes; entry i must match the i-th plotted space."""
        return (Space,)

    @property
    def interactive(self):
        return self._interactive

    @interactive.setter
    def interactive(self, interactive):
        self._interactive = bool(interactive)

    @property
    def path(self):
        return self._path

    @path.setter
    def path(self, new_path):
        self._path = os.path.normpath(new_path)

    def __call__(self, fields, spaces=None, data_extractor=None, labels=None):
        """Plot one or several fields and write the result to disk.

        Parameters
        ----------
        fields : Field or iterable of Field
            A single Field is wrapped into a list.
            NOTE(review): only ``fields[0]`` is inspected for the domain
            and axes, so all fields are assumed to share it; this is not
            validated.
        spaces : int, tuple of int or None
            Indices into the field domain that are plotted. ``None`` plots
            every space of the domain.
        data_extractor : optional
            Passed to ``_get_data_from_field``, which does not use it yet.
        labels : optional
            Currently unused by this base class.

        Raises
        ------
        ValueError
            If ``fields`` is empty.
        """
        if isinstance(fields, Field):
            fields = [fields]
        elif not isinstance(fields, list):
            fields = list(fields)

        if not fields:
            raise ValueError("At least one field must be given for plotting.")

        domain = fields[0].domain
        spaces = utilities.cast_axis_to_tuple(spaces, len(domain))
        if spaces is None:
            spaces = tuple(range(len(domain)))

        # Array axes that belong to the plotted spaces; the remaining axes
        # are iterated over by get_slice_list below.
        axes = []
        for space_index in spaces:
            axes += list(fields[0].domain_axes[space_index])

        # prepare data
        data_list = [self._get_data_from_field(field, spaces, data_extractor)
                     for field in fields]

        # create plots: one list of plots (one per field) for every slice.
        # The slice must actually be applied to the data; numpy requires a
        # tuple (not a list) for multi-dimensional indexing.
        plots_list = []
        for slice_list in utilities.get_slice_list(data_list[0].shape, axes):
            index = tuple(slice_list)
            plots_list.append(
                [self.plot.at(self._parse_data(data[index], field, spaces))
                 for data, field in zip(data_list, fields)])

        figures = [self.figure.at(plots) for plots in plots_list]
        self._finalize_figure(figures)

    def _get_data_from_field(self, field, spaces, data_extractor):
        """Check the field's plotted spaces and return its gathered data.

        Raises
        ------
        AttributeError
            If there are more plotted spaces than entries in
            ``self.domain_classes``, or a plotted space is not an instance of
            the corresponding entry.
        """
        domain_classes = self.domain_classes
        mismatch_message = ("Given space(s) of input field-domain do "
                            "not match the plotters domain.")
        if len(spaces) > len(domain_classes):
            raise AttributeError(mismatch_message)
        for i, space_index in enumerate(spaces):
            if not isinstance(field.domain[space_index], domain_classes[i]):
                raise AttributeError(mismatch_message)

        # TODO: add data_extractor functionality here
        # NOTE(review): target_rank=0 presumably gathers the full array on
        # MPI rank 0 only; confirm what other ranks receive.
        data = field.val.get_full_data(target_rank=0)
        return data

    @abc.abstractmethod
    def _initialize_plot(self):
        """Return the plot object whose ``at`` method builds single plots."""
        raise NotImplementedError

    @abc.abstractmethod
    def _initialize_figure(self):
        """Return the figure object whose ``at`` method wraps a plot list."""
        raise NotImplementedError

    def _initialize_multifigure(self):
        """Return the container used to combine several figures."""
        return MultiFigure(subfigures=None)

    def _finalize_figure(self, figures):
        """Combine ``figures`` and write them to ``path/title`` via plotly.

        Several figures are placed in a 2 x ceil(n/2) object array (filled
        row by row; a leftover slot stays None) and passed to
        ``self.multi_figure``; a single figure is written directly.
        """
        if len(figures) > 1:
            columns = (len(figures) + 1) // 2
            # Builtin ``object``: the ``np.object`` alias was removed in
            # NumPy 1.24.
            figure_array = np.empty(2 * columns, dtype=object)
            figure_array[:len(figures)] = figures
            figure_array = figure_array.reshape((2, columns))

            final_figure = self.multi_figure(subfigures=figure_array)
        else:
            final_figure = figures[0]

        plotly.offline.plot(final_figure.to_plotly(),
                            filename=os.path.join(self.path, self.title))