heatmap.py 2.39 KB
Newer Older
1
# -*- coding: utf-8 -*-
2

3
import numpy as np
4

5
from nifty.plotting.descriptors import Axis
6
from nifty.plotting.colormap import Colormap
7
from nifty.plotting.plotly_wrapper import PlotlyWrapper
8

9

10
class Heatmap(PlotlyWrapper):
    """Plotly heatmap wrapper for two-dimensional data.

    Parameters
    ----------
    data : array-like
        Values to display. Forwarded unchanged as plotly's ``z``;
        ``default_height`` reads its two-dimensional shape.
    color_map : Colormap, optional
        NIFTy colormap converted into plotly's ``colorscale``. When None,
        no ``colorscale`` key is emitted.
    webgl : bool
        If True the trace type is ``'heatmapgl'``, otherwise ``'heatmap'``.
    smoothing : {'best', 'fast', False}
        Forwarded as plotly's ``zsmooth`` when truthy; omitted otherwise.
    zmin, zmax : float, optional
        Color-scale bounds, forwarded unchanged to plotly.

    Raises
    ------
    TypeError
        If ``color_map`` is given but is not a ``Colormap`` instance.
    """

    def __init__(self, data, color_map=None, webgl=False, smoothing=False,
                 zmin=None, zmax=None):
        if color_map is not None and not isinstance(color_map, Colormap):
            raise TypeError("Provided color_map must be an instance of "
                            "the NIFTy Colormap class.")
        self.color_map = color_map

        self.webgl = webgl
        # smoothing: 'best', 'fast', or False (disabled)
        self.smoothing = smoothing
        self.data = data

        self.zmin = zmin
        self.zmax = zmax

        # Font settings shared by the colorbar ticks and the default axes.
        self._font_size = 22
        self._font_family = 'Balto'

    def at(self, data):
        """Return a new Heatmap with the same settings but different data.

        If ``data`` is a list, its elements are summed elementwise, starting
        from a float zero array shaped like ``data[0]``. The sum becomes the
        new heatmap's data. An empty list raises IndexError (``data[0]``).
        Any other value is used as-is.
        """
        if isinstance(data, list):
            temp_data = np.zeros(data[0].shape)
            for arr in data:
                temp_data = np.add(temp_data, arr)
        else:
            temp_data = data
        return Heatmap(data=temp_data,
                       color_map=self.color_map,
                       webgl=self.webgl,
                       smoothing=self.smoothing,
                       zmin=self.zmin,
                       zmax=self.zmax)

    @property
    def figure_dimension(self):
        """Number of plot dimensions this trace occupies (always 2)."""
        return 2

    def to_plotly(self):
        """Build the plotly trace dictionary for this heatmap.

        Returns
        -------
        dict
            Always contains ``z``, ``zmin``, ``zmax``, ``showscale``,
            ``colorbar`` and ``type``. Contains ``colorscale`` only if a
            color map was given, and ``zsmooth`` only if smoothing is truthy.
        """
        plotly_object = dict()

        plotly_object['z'] = self.data
        plotly_object['zmin'] = self.zmin
        plotly_object['zmax'] = self.zmax

        plotly_object['showscale'] = True
        plotly_object['colorbar'] = {'tickfont': {'size': self._font_size,
                                                  'family': self._font_family},
                                     'exponentformat': 'power'}
        # Explicit None test, matching the validation in __init__; avoids
        # depending on how Colormap defines truthiness.
        if self.color_map is not None:
            plotly_object['colorscale'] = self.color_map.to_plotly()
        if self.webgl:
            plotly_object['type'] = 'heatmapgl'
        else:
            plotly_object['type'] = 'heatmap'
        if self.smoothing:
            plotly_object['zsmooth'] = self.smoothing
        return plotly_object

    def default_width(self):
        """Default figure width."""
        return 700

    def default_height(self):
        """Default figure height: 700 scaled by the data's rows/columns ratio."""
        # np.shape also accepts nested lists, which to_plotly forwards as-is,
        # whereas ``self.data.shape`` only works for ndarray-like data.
        (y, x) = np.shape(self.data)
        return int(700 * y / x)

    def default_axes(self):
        """Return two axes (x, y) using this heatmap's font size."""
        return (Axis(font_size=self._font_size),
                Axis(font_size=self._font_size))