diff --git a/compressed_sensing/visualizer.py b/compressed_sensing/visualizer.py index bd8a1393f630004beb56cdbe85e5bfeffb3a3258..9fd704efef5d54937f21e064b84d2ee829acc592 100644 --- a/compressed_sensing/visualizer.py +++ b/compressed_sensing/visualizer.py @@ -3,6 +3,7 @@ import ipywidgets as widgets from jupyter_jsmol import JsmolView import numpy as np + class Visualizer: def __init__(self, df_D, sisso, D_selected_df): @@ -110,8 +111,12 @@ class Visualizer: y_max = max(max(y_RS), max(y_ZB)) x_delta = 0.05 * abs(x_max - x_min) y_delta = 0.05 * abs(y_max - y_min) - self.fig.update_layout( + plot_bgcolor=self.bg_color, + font=dict( + size=int(self.font_size), + family=self.font_families[0] + ), xaxis_title=self.features[0], yaxis_title=self.features[1], xaxis_range=[x_min - x_delta, x_max + x_delta], @@ -135,20 +140,13 @@ class Visualizer: self.scatter_RS = self.fig.data[0] self.scatter_ZB = self.fig.data[1] self.scatter_line = self.fig.data[2] - - self.fig.update_layout( - plot_bgcolor=self.bg_color, - font=dict( - size=int(self.font_size), - family=self.font_families[0] - ) - ) self.RS_npoints = len(D_selected_df.loc[D_selected_df['Structure'] == 'RS']) self.ZB_npoints = len(D_selected_df.loc[D_selected_df['Structure'] == 'ZB']) - - self.scatter_RS.marker.symbol = [self.marker_symbol] * self.RS_npoints - self.scatter_ZB.marker.symbol = [self.marker_symbol] * self.ZB_npoints - self.set_markers_size() + self.RS_symbols = [self.marker_symbol] * self.RS_npoints + self.ZB_symbols = [self.marker_symbol] * self.ZB_npoints + self.RS_sizes = [self.marker_size] * self.RS_npoints + self.ZB_sizes = [self.marker_size] * self.ZB_npoints + self.update_markers() self.widg_featx = widgets.Dropdown( description='x-axis', @@ -296,6 +294,13 @@ class Visualizer: height=30, ) + def update_markers(self): + with self.fig.batch_update(): + self.scatter_RS.marker.size = self.RS_sizes + self.scatter_ZB.marker.size = self.ZB_sizes + self.scatter_RS.marker.symbol = self.RS_symbols + self.scatter_ZB.marker.symbol = self.ZB_symbols + def f_x(self, x): if self.current_features[0] == self.current_features[1]: return x @@ -312,8 +317,8 @@ class Visualizer: sizes_RS = [self.marker_size] * self.RS_npoints sizes_ZB = [self.marker_size] * self.ZB_npoints - symbols_RS = list(self.scatter_RS.marker.symbol) - symbols_ZB = list(self.scatter_ZB.marker.symbol) + symbols_RS = self.RS_symbols + symbols_ZB = self.ZB_symbols try: point = symbols_RS.index('x') @@ -333,9 +338,9 @@ class Visualizer: sizes_ZB[point] = self.cross_size except: pass - with self.fig.batch_update(): - self.scatter_RS.marker.size = sizes_RS - self.scatter_ZB.marker.size = sizes_ZB + + self.RS_sizes = sizes_RS + self.ZB_sizes = sizes_ZB else: min_value = min(min(self.D_selected_df.loc[self.D_selected_df['Structure'] == 'RS'][feature]), @@ -347,11 +352,10 @@ class Visualizer: feature] sizes_ZB = self.marker_size / 2 + coeff * self.D_selected_df.loc[self.D_selected_df['Structure'] == 'ZB'][ feature] - with self.fig.batch_update(): - self.scatter_RS.marker.size = sizes_RS - self.scatter_ZB.marker.size = sizes_ZB + self.RS_sizes = sizes_RS + self.ZB_sizes = sizes_ZB - def handle_xfeat_change (self, change): + def handle_xfeat_change(self, change): # changes the feature plotted on the x-axis # separating line is modified accordingly self.fig.update_layout( @@ -390,6 +394,7 @@ class Visualizer: def handle_markerfeat_change(self, change): self.set_markers_size(feature=change.new) + self.update_markers() def display_button_l_clicked(self, button): @@ -401,8 +406,9 @@ class Visualizer: "load data/compressed_sensing/structures/" + structure_l + "_structures/" + self.widg_compound_text_l.value + ".xyz") - symbols_RS = list(self.scatter_RS.marker.symbol) - symbols_ZB = list(self.scatter_ZB.marker.symbol) + symbols_RS = self.RS_symbols + symbols_ZB = self.ZB_symbols + try: point = symbols_RS.index('x') symbols_RS[point] = self.marker_symbol @@ -418,10 +424,11 @@ class Visualizer: if structure_l == 'ZB': point = np.where(self.scatter_ZB['text'] == self.widg_compound_text_l.value)[0][0] symbols_ZB[point] = 'x' - with self.fig.batch_update(): - self.scatter_RS.marker.symbol = symbols_RS - self.scatter_ZB.marker.symbol = symbols_ZB + + self.ZB_symbols = symbols_ZB + self.RS_symbols = symbols_RS self.set_markers_size(feature=self.widg_featmarker.value) + self.update_markers() def display_button_r_clicked(self, button): @@ -433,8 +440,9 @@ class Visualizer: "load data/compressed_sensing/structures/" + structure_r + "_structures/" + self.widg_compound_text_r.value + ".xyz") - symbols_RS = list(self.scatter_RS.marker.symbol) - symbols_ZB = list(self.scatter_ZB.marker.symbol) + symbols_RS = self.RS_symbols + symbols_ZB = self.ZB_symbols + try: point = symbols_RS.index('cross') symbols_RS[point] = self.marker_symbol @@ -450,10 +458,11 @@ class Visualizer: if structure_r == 'ZB': point = np.where(self.scatter_ZB['text'] == self.widg_compound_text_r.value)[0][0] symbols_ZB[point] = 'cross' - with self.fig.batch_update(): - self.scatter_RS.marker.symbol = symbols_RS - self.scatter_ZB.marker.symbol = symbols_ZB + + self.RS_symbols = symbols_RS + self.ZB_symbols = symbols_ZB self.set_markers_size(feature=self.widg_featmarker.value) + self.update_markers() def update_button_clicked(self, button): @@ -461,7 +470,7 @@ class Visualizer: self.cross_size = int(self.widg_crosssize.value) self.set_markers_size(feature=self.widg_featmarker.value) - + self.update_markers() try: self.scatter_RS.update(marker=dict(color=self.widg_rscolor.value)) except: @@ -505,9 +514,10 @@ class Visualizer: def reset_button_clicked(self, button): - self.scatter_RS.marker.symbol = [self.marker_symbol] * self.RS_npoints - self.scatter_ZB.marker.symbol = [self.marker_symbol] * self.ZB_npoints + self.RS_symbols = [self.marker_symbol] * self.RS_npoints + self.ZB_symbols = [self.marker_symbol] * self.ZB_npoints self.set_markers_size(self.widg_featmarker.value) + self.update_markers() def handle_checkbox_l(self, change): if change.new: @@ -538,8 +548,8 @@ class Visualizer: if not points.point_inds: return - symbols_RS = list(self.scatter_RS.marker.symbol) - symbols_ZB = list(self.scatter_ZB.marker.symbol) + symbols_RS = self.RS_symbols + symbols_ZB = self.ZB_symbols # The element previously marked with x/cross is marked with circle as default value if self.widg_checkbox_l.value: @@ -568,13 +578,12 @@ class Visualizer: if self.widg_checkbox_r.value: symbols_RS[points.point_inds[0]] = 'cross' - with self.fig.batch_update(): - self.scatter_RS.marker.symbol = symbols_RS - self.scatter_ZB.marker.symbol = symbols_ZB - + self.RS_symbols = symbols_RS + self.ZB_symbols = symbols_ZB self.set_markers_size(feature=self.widg_featmarker.value) - formula = trace['text'][points.point_inds[0]][0] + self.update_markers() + formula = trace['text'][points.point_inds[0]][0] if self.widg_checkbox_l.value: self.widg_compound_text_l.value = formula self.view_structure_RS_l(formula) @@ -586,8 +595,8 @@ class Visualizer: if not points.point_inds: return - symbols_RS = list(self.scatter_RS.marker.symbol) - symbols_ZB = list(self.scatter_ZB.marker.symbol) + symbols_RS = self.RS_symbols + symbols_ZB = self.ZB_symbols # The element previously marked with x/cross is marked with circle as default value if self.widg_checkbox_l.value: @@ -616,13 +625,12 @@ class Visualizer: if self.widg_checkbox_r.value: symbols_ZB[points.point_inds[0]] = 'cross' - with self.fig.batch_update(): - self.scatter_RS.marker.symbol = symbols_RS - self.scatter_ZB.marker.symbol = symbols_ZB - + self.RS_symbols = symbols_RS + self.ZB_symbols = symbols_ZB self.set_markers_size(feature=self.widg_featmarker.value) - formula = trace['text'][points.point_inds[0]][0] + self.update_markers() + formula = trace['text'][points.point_inds[0]][0] if self.widg_checkbox_l.value: self.widg_compound_text_l.value = formula self.view_structure_ZB_l(formula) @@ -630,7 +638,7 @@ class Visualizer: self.widg_compound_text_r.value = formula self.view_structure_ZB_r(formula) - def view(self): + def show(self): self.widg_featx.observe(self.handle_xfeat_change, names='value') self.widg_featy.observe(self.handle_yfeat_change, names='value') @@ -659,7 +667,6 @@ class Visualizer: box_print = widgets.HBox([self.widg_plot_name, self.widg_plot_format, self.widg_scale, self.widg_print_button]) box_features = widgets.HBox([self.widg_featx, self.widg_featy, self.widg_featmarker]) - container = widgets.VBox([box_print, box_features, self.fig, self.widg_description, widgets.HBox([ @@ -671,7 +678,7 @@ class Visualizer: [widgets.HBox([self.widg_compound_text_r, self.widg_display_button_r, self.widg_img2, self.widg_checkbox_r]), output_r]), - ]) + ]) ]) display(container)