Visualise.py 4.42 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
import io
import uuid

from nglview import register_backend, Structure
from ipywidgets import Dropdown, FloatSlider, IntSlider, HBox, VBox, Output

import matplotlib.pyplot as plt
import numpy as np

    

@register_backend('ase')
class MyASEStructure(Structure):
    """nglview structure that serialises an ASE ``Atoms`` object to PDB text.

    Registered with nglview through ``register_backend('ase')``. Optional
    per-atom ``bfactor`` and ``occupancy`` values are written into the PDB
    temperature-factor and occupancy columns, so NGL can colour atoms by an
    arbitrary scalar (e.g. with ``color_scheme='bfactor'``).
    """

    # Atom serials are wrapped so they always fit the 5-character PDB
    # serial field; otherwise every following column would be shifted.
    _MAX_SERIAL = 100000

    def __init__(self, atoms, bfactor=None, occupancy=None):
        """
        :param atoms: ASE ``Atoms`` object to serialise.
        :param bfactor: optional per-atom values for the B-factor column;
            atoms beyond its length get 1.0.
        :param occupancy: optional per-atom values for the occupancy column;
            atoms beyond its length get 1.0.
        """
        # NOTE(review): the base-class __init__ is not called (it was
        # commented out originally); the attributes nglview reads (ext,
        # params, id) are assigned explicitly below instead.
        self.ext = 'pdb'
        self.params = {}
        self._atoms = atoms
        # Per-atom sequences indexed by atom position; None means "none given".
        self.bfactor = [] if bfactor is None else bfactor
        self.occupancy = [] if occupancy is None else occupancy
        self.id = str(uuid.uuid4())

    def get_structure_string(self):
        """Return the atoms as PDB-formatted text.

        Example output::

            CRYST1   16.980   62.517  124.864  90.00  90.00  90.00 P 1
            MODEL     1
            ATOM      0   Fe MOL     1      15.431  60.277   6.801  1.00  0.00          FE
            ATOM      1   Fe MOL     1       1.273   3.392  93.940  1.00  0.00          FE
            ENDMDL

        The CRYST1 record is only written when at least one direction is
        periodic. Every atom is placed in residue ``MOL`` number 1.
        """
        lines = []

        if self._atoms.get_pbc().any():
            cellpar = self._atoms.get_cell_lengths_and_angles()
            cryst_format = 'CRYST1' + '{:9.3f}' * 3 + '{:7.2f}' * 3 + ' P 1'
            lines.append(cryst_format.format(*cellpar.tolist()))

        lines.append('MODEL     1')

        # Fixed-column ATOM record; the element symbol (last field) is
        # right-justified as the PDB format requires for columns 77-78.
        atom_format = ('ATOM  {:5d} {:>4s} MOL     1    '
                       '{:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f}          {:>2s}')
        n_occupancy = len(self.occupancy)
        n_bfactor = len(self.bfactor)
        for index, atom in enumerate(self._atoms):
            x, y, z = atom.position.tolist()
            lines.append(atom_format.format(
                index % self._MAX_SERIAL,
                atom.symbol,
                x, y, z,
                self.occupancy[index] if index < n_occupancy else 1.0,
                self.bfactor[index] if index < n_bfactor else 1.0,
                atom.symbol.upper(),
            ))

        lines.append('ENDMDL')

        return '\n'.join(lines) + '\n'


def ViewStructure(atoms):
    """Return a fresh NGL widget showing ``atoms``.

    The atoms are wrapped in :class:`MyASEStructure` (PDB serialisation)
    and added to a newly created ``nglview.NGLWidget``.

    :param atoms: ASE ``Atoms`` object to display.
    :returns: the ``NGLWidget`` holding the structure.
    """
    from nglview import NGLWidget

    widget = NGLWidget()
    widget.add_structure(MyASEStructure(atoms))
    return widget



class AtomViewer(object):
    """Interactive viewer drawing atoms as spheres coloured by per-atom data.

    ``self.gui`` is a ``VBox`` stacking the NGL view, a matplotlib colour bar
    spanning ``min(data)``..``max(data)`` and a slider for the sphere radius.
    """

    # Colormap shared by the NGL colour scale and the matplotlib colour bar,
    # so the two cannot drift apart.
    CMAP = 'rainbow'

    def __init__(self, atoms, data=None, xsize=1000, ysize=500):
        """
        :param atoms: ASE ``Atoms`` object to display.
        :param data: optional per-atom scalar values used for colouring. When
            omitted or empty, atoms are drawn without per-atom colouring data
            and no colour bar is shown.
        :param xsize: widget width in pixels.
        :param ysize: widget height in pixels.
        """
        data = np.asarray([] if data is None else data, dtype=float)

        self.widgets = {
            'radius': FloatSlider(
                value=0.8, min=0.0, max=1.5, step=0.01,
                description='Ball size'
            ),
            # NOTE(review): this dropdown is never populated nor added to gui.
            'color_scheme': Dropdown(description='Color scheme:'),
            'colorbar': Output()
        }

        # Render with the slider's initial value so view and slider agree.
        self.view = self._init_nglview(
            atoms, data, xsize, ysize,
            radius=self.widgets['radius'].value
        )
        self.show_colorbar(data)

        # Only react to changes of the slider's value; observe() without
        # `names` fires for every trait change on the widget.
        self.widgets['radius'].observe(self._update_repr, names='value')

        self.gui = VBox([
            self.view,
            self.widgets['colorbar'],
            self.widgets['radius']])

    def _update_repr(self, chg=None):
        """Apply the slider's current radius to the spacefill representation.

        :param chg: traitlets change notification (unused).
        """
        self.view.update_spacefill(
            radiusType='radius',
            radius=self.widgets['radius'].value
        )

    def show_colorbar(self, data):
        """Draw a horizontal colour bar for ``data`` into the colorbar output.

        Does nothing when ``data`` is empty (there is no range to show).
        """
        if len(data) == 0:
            return

        with self.widgets['colorbar']:
            # Colormaps grouped by category:
            # http://matplotlib.org/examples/color/colormaps_reference.html
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 2))
            # The imshow only supplies a mappable carrying the colormap and
            # the data range; its axes is removed so only the bar remains.
            img = ax1.imshow(
                [[min(data), max(data)]],
                aspect='auto',
                cmap=plt.get_cmap(self.CMAP)
            )
            ax1.remove()
            fig.colorbar(img, cax=ax2, orientation='horizontal')

            plt.show()

    @staticmethod
    def _init_nglview(atoms, data, xsize, ysize, radius=1.0):
        """Create the NGL widget with a B-factor-coloured spacefill of ``atoms``.

        :param atoms: ASE ``Atoms`` object to display.
        :param data: per-atom scalar values (may be empty).
        :param xsize: widget width in pixels.
        :param ysize: widget height in pixels.
        :param radius: initial sphere radius of the spacefill representation.
        :returns: the configured ``nglview.NGLWidget``.
        """
        import nglview

        view = nglview.NGLWidget(gui=False)
        # NOTE(review): `_remote_call` is a private nglview API.
        view._remote_call(
            'setSize',
            target='Widget',
            args=[
                '{:d}px'.format(xsize),
                '{:d}px'.format(ysize)
            ]
        )

        data = np.asarray(data, dtype=float)
        # Values are flipped (max - x) before being stored as B-factors.
        # NOTE(review): presumably so NGL's B-factor colour scale runs in the
        # same direction as the matplotlib colour bar — confirm visually.
        bfactor = np.max(data) - data if data.size else []

        structure = MyASEStructure(atoms, bfactor=bfactor)
        view.add_structure(structure)

        view.clear_representations()
        view.add_unitcell()

        view.add_spacefill(
            color_scheme='bfactor',
            color_scale=AtomViewer.CMAP
        )
        view.update_spacefill(
            radiusType='radius',
            radius=radius
        )

        # Rotate the camera 90 degrees about x, then 90 degrees about z.
        view.control.spin([1, 0, 0], np.pi / 2)
        view.control.spin([0, 0, 1], np.pi / 2)
        # Switch to an orthographic camera and re-centre the view.
        view.camera = 'orthographic'
        view.center()

        return view