plotter.py 3.79 KB
Newer Older
1
2
3
4
5
# -*- coding: utf-8 -*-

import abc
import os

6
7
import numpy as np

8
9
import d2o

10
11
from keepers import Loggable

12
13
from nifty.config import dependency_injector as gdi

14
from nifty.spaces.space import Space
15
from nifty.field import Field
16
17
import nifty.nifty_utilities as utilities

18
19
20
from nifty.plotting.figures import MultiFigure

# Resolve plotly through NIFTy's dependency injector. Presumably this yields
# None when plotly is not installed -- TODO confirm against gdi.get.
plotly = gdi.get('plotly')

# Best-effort setup of plotly's offline notebook mode. If plotly is missing
# (``plotly.offline`` raises AttributeError), carry on without it. Likewise if
# ``init_notebook_mode`` raises ImportError (presumably when not running
# under IPython).
# NOTE: the exception types must be given as a tuple. The Python 2 form
# ``except AttributeError, ImportError:`` catches only AttributeError. It also
# binds the caught instance to the module-global name ``ImportError``, which
# shadows the builtin that ``Plotter.__init__`` raises.
try:
    plotly.offline.init_notebook_mode()
except (AttributeError, ImportError):
    pass

# MPI rank of this process, taken from the MPI module configured in d2o.
rank = d2o.config.dependency_injector[
        d2o.configuration['mpi_module']].COMM_WORLD.rank
29

30
31
32
33

class Plotter(Loggable, object):
    """Abstract base class for plotly-based plotters of NIFTy fields.

    Subclasses declare which space types they accept (``domain_classes``).
    They also implement ``_create_individual_plot`` and
    ``_create_individual_figure``. ``plot`` drives data extraction, slicing,
    figure creation and the final plotly output.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, interactive=False, path='.', title=""):
        """
        Parameters
        ----------
        interactive : bool
            Stored (coerced to bool) as ``self.interactive``. Not consulted
            elsewhere in this class.
        path : str
            Output directory. It is normalized with ``os.path.normpath``.
        title : str
            Used as the output file name and as the title of multi-panel
            figures.

        Raises
        ------
        ImportError
            If plotly is not available through the dependency injector.
        """
        if 'plotly' not in gdi:
            raise ImportError("The module plotly is needed but not available.")

        self.interactive = interactive
        self.path = path
        self.title = str(title)

    @abc.abstractproperty
    def domain_classes(self):
        """Tuple of space classes accepted by the plotter.

        ``_get_data_from_field`` checks the i-th selected space against the
        i-th entry of this tuple.
        """
        return (Space,)

    @property
    def interactive(self):
        return self._interactive

    @interactive.setter
    def interactive(self, interactive):
        self._interactive = bool(interactive)

    @property
    def path(self):
        return self._path

    @path.setter
    def path(self, new_path):
        self._path = os.path.normpath(new_path)

    def plot(self, fields, spaces=None,  data_extractor=None, labels=None):
        """Plot one or several fields and write the result with plotly.

        Parameters
        ----------
        fields : Field or iterable of Field
            Fields to plot. Axes and domain are taken from the first field.
            The others are assumed to share it -- TODO confirm.
        spaces : int, tuple of int or None
            Indices of the domain spaces to plot. ``None`` selects every
            space of the first field's domain.
        data_extractor : optional
            Forwarded to ``_get_data_from_field``, which does not use it yet.
        labels : optional
            Currently unused.

        Raises
        ------
        ValueError
            If ``fields`` is empty.
        AttributeError
            If a selected space does not match ``domain_classes``.
        """
        if isinstance(fields, Field):
            fields = [fields]
        elif not isinstance(fields, list):
            fields = list(fields)

        # Fail with a clear message instead of an IndexError at fields[0].
        if not fields:
            raise ValueError("At least one field is needed for plotting.")

        domain = fields[0].domain
        spaces = utilities.cast_axis_to_tuple(spaces, len(domain))
        if spaces is None:
            spaces = tuple(range(len(domain)))

        # Gather the array axes and the space objects of the plotted spaces.
        axes = []
        plot_domain = []
        for space_index in spaces:
            axes += list(fields[0].domain_axes[space_index])
            plot_domain += [domain[space_index]]

        # prepare data
        data_list = [self._get_data_from_field(field, spaces, data_extractor)
                     for field in fields]

        # Presumably get_slice_list yields one slice per index combination of
        # the non-plotted axes. Each slice becomes one figure holding one plot
        # per field.
        plots_list = []
        for slice_list in utilities.get_slice_list(data_list[0].shape, axes):
            plots_list += \
                    [[self._create_individual_plot(current_data[slice_list],
                                                   plot_domain)
                      for current_data in data_list]]

        figures = [self._create_individual_figure(plots)
                   for plots in plots_list]

        self._finalize_figure(figures)

    def _get_data_from_field(self, field, spaces, data_extractor):
        """Validate the selected spaces of ``field`` and return its full data.

        Raises
        ------
        AttributeError
            If ``field.domain[spaces[i]]`` is not an instance of
            ``self.domain_classes[i]``.
        """
        for i, space_index in enumerate(spaces):
            if not isinstance(field.domain[space_index],
                              self.domain_classes[i]):
                raise AttributeError("Given space(s) of input field-domain do "
                                     "not match the plotters domain.")

        # TODO: add data_extractor functionality here
        # NOTE(review): the data is gathered onto rank 0 only. Other MPI ranks
        # may not receive an array here, which would break ``plot`` at
        # ``.shape`` -- confirm against d2o's get_full_data.
        data = field.val.get_full_data(target_rank=0)
        return data

    @abc.abstractmethod
    def _create_individual_figure(self, plots):
        """Build one figure from the list of plots (one per field)."""
        raise NotImplementedError

    @abc.abstractmethod
    def _create_individual_plot(self, data, fields):
        """Build one plot from a data slice; ``fields`` receives the plotted
        spaces (``plot_domain``)."""
        raise NotImplementedError

    def _finalize_figure(self, figures):
        """Combine ``figures`` into one figure and write it via plotly.

        The output file is ``os.path.join(self.path, self.title)``.
        """
        if len(figures) > 1:
            rows = (len(figures) + 1)//2
            # Pad to an even count. Object arrays from np.empty start as None.
            figure_array = np.empty((2*rows), dtype=object)
            figure_array[:len(figures)] = figures
            # NOTE(review): the array is shaped (2, rows), while MultiFigure
            # is called with (rows, 2). Confirm which axis MultiFigure treats
            # as rows.
            figure_array = figure_array.reshape((2, rows))

            final_figure = MultiFigure(rows, 2,
                                       title=self.title,
                                       subfigures=figure_array)
        else:
            final_figure = figures[0]

        plotly.offline.plot(final_figure.to_plotly(),
                            filename=os.path.join(self.path, self.title))