diff --git a/descriptor_role/visualizer.py b/descriptor_role/visualizer.py index 56945fde9c95858b2897afc1f44c224179c99edf..673027d96b9e60846f776a8c35b264dcf69a0b1f 100644 --- a/descriptor_role/visualizer.py +++ b/descriptor_role/visualizer.py @@ -6,6 +6,7 @@ from IPython.display import display, HTML, FileLink import os import pandas as pd + class Visualizer: def __init__(self, df, sisso, feat_space): @@ -46,18 +47,18 @@ class Visualizer: self.sisso = sisso self.feat_space = feat_space self.total_features = sisso.n_dim - self.feat_val_list = [feat.value for feat in sisso.models[sisso.n_dim-1][0].feats] - self.features = [str(feat) for feat in sisso.models[sisso.n_dim-1][0].feats] + self.feat_val_list = [feat.value for feat in sisso.models[sisso.n_dim - 1][0].feats] + self.features = [str(feat) for feat in sisso.models[sisso.n_dim - 1][0].feats] self.df_selected = pd.DataFrame() for feat, values in zip(self.features, self.feat_val_list): self.df_selected[feat] = values self.df_selected['Structure'] = df.reset_index()['min_struc_type'] self.df_selected['Chem Formula'] = df.index.to_numpy() - self.coefficients = self.sisso.models[sisso.n_dim-1][0].coefs[0][:-1] - self.intercept = self.sisso.models[sisso.n_dim-1][0].coefs[0][-1] - self.target_predict = sisso.models[sisso.n_dim-1][0].fit - self.target_train = sisso.models[sisso.n_dim-1][0].prop_train + self.coefficients = self.sisso.models[sisso.n_dim - 1][0].coefs[0][:-1] + self.intercept = self.sisso.models[sisso.n_dim - 1][0].coefs[0][-1] + self.target_predict = sisso.models[sisso.n_dim - 1][0].fit + self.target_train = sisso.models[sisso.n_dim - 1][0].prop_train self.bg_toggle = True self.compounds_list = df.index.tolist() self.df_RS = df['min_struc_type'] == 'RS' @@ -170,7 +171,7 @@ class Visualizer: self.scatter_RS = self.fig.data[0] self.scatter_ZB = self.fig.data[1] self.scatter_line = self.fig.data[2] - if (self.total_features==2): + if self.total_features == 2: self.scatter_line.visible = True self.update_markers() @@ -187,7 +188,8 @@ class Visualizer: else: # print('x', self.coefficients[self.current_features[idx_x]]) # print('y', self.coefficients[self.current_features[idx_y]]) - line_y = -line_x * self.coefficients[idx_x] / self.coefficients[idx_y] - self.intercept / self.coefficients[idx_y] + line_y = -line_x * self.coefficients[idx_x] / self.coefficients[idx_y] - self.intercept / self.coefficients[ + idx_y] return line_x, line_y def update_markers(self): @@ -246,7 +248,7 @@ class Visualizer: self.RS_sizes = sizes_RS self.ZB_sizes = sizes_ZB - def make_colors(self, feature): + def make_colors(self, feature, gradient): if feature == 'Default color': @@ -257,20 +259,70 @@ class Visualizer: min_value = self.df_selected[feature].min() max_value = self.df_selected[feature].max() - shade_RS = 0.7*(self.df_selected.loc[self.df_selected['Structure'] == 'RS'][feature].to_numpy() - min_value)/\ (max_value-min_value) shade_ZB = 0.7*(self.df_selected.loc[self.df_selected['Structure'] == 'ZB'][feature].to_numpy() - min_value)/\ (max_value-min_value) - for i, e in enumerate(shade_RS): - value = 255*(0.7-e) - string = 'rgb('+str(value)+","+str(value)+","+str(value)+')' - self.RS_colors[i] = string - for i, e in enumerate(shade_ZB): - value = 255*(0.7-e) - string = 'rgb('+str(value)+","+str(value)+","+str(value)+')' - self.ZB_colors[i] = string + if gradient == 'Grey scale': + for i, e in enumerate(shade_RS): + value = 255*(0.7-e) + string = 'rgb('+str(value)+","+str(value)+","+str(value)+')' + self.RS_colors[i] = string + for i, e in enumerate(shade_ZB): + value = 255*(0.7-e) + string = 'rgb('+str(value)+","+str(value)+","+str(value)+')' + self.ZB_colors[i] = string + + if gradient == 'Purple scale': + for i, e in enumerate(shade_RS): + value = 255 * (0.7 - e) + string = 'rgb(' + str(value) + "," + str(0) + "," + str(value) + ')' + self.RS_colors[i] = string + for i, e in enumerate(shade_ZB): + value = 255 * (0.7 - e) + string = 'rgb(' + str(value) + "," + str(0) + "," + str(value) + ')' + self.ZB_colors[i] = string + + if gradient == 'Turquoise scale': + for i, e in enumerate(shade_RS): + value = 255 * (0.7 - e) + string = 'rgb(' + str(0) + "," + str(value) + "," + str(value) + ')' + self.RS_colors[i] = string + for i, e in enumerate(shade_ZB): + value = 255 * (0.7 - e) + string = 'rgb(' + str(0) + "," + str(value) + "," + str(value) + ')' + self.ZB_colors[i] = string + + if gradient == 'Red scale': + for i, e in enumerate(shade_RS): + value = 255 * (0.7 - e) + string = 'rgb(' + str(value) + "," + str(0) + "," + str(0) + ')' + self.RS_colors[i] = string + for i, e in enumerate(shade_ZB): + value = 255 * (0.7 - e) + string = 'rgb(' + str(value) + "," + str(0) + "," + str(0) + ')' + self.ZB_colors[i] = string + + if gradient == 'Blue scale': + for i, e in enumerate(shade_RS): + value = 255 * (0.7 - e) + string = 'rgb(' + str(0) + "," + str(0) + "," + str(value) + ')' + self.RS_colors[i] = string + for i, e in enumerate(shade_ZB): + value = 255 * (0.7 - e) + string = 'rgb(' + str(0) + "," + str(0) + "," + str(value) + ')' + self.ZB_colors[i] = string + + if gradient == 'Green scale': + for i, e in enumerate(shade_RS): + value = 255 * (0.7 - e) + string = 'rgb(' + str(0) + "," + str(value) + "," + str(0) + ')' + self.RS_colors[i] = string + for i, e in enumerate(shade_ZB): + value = 255 * (0.7 - e) + string = 'rgb(' + str(0) + "," + str(value) + "," + str(0) + ')' + self.ZB_colors[i] = string def handle_xfeat_change(self, change): # changes the feature plotted on the x-axis @@ -311,7 +363,17 @@ class Visualizer: self.update_markers() def handle_colorfeat_change(self, change): - self.make_colors(feature=change.new) + if change.new == 'Default color': + self.widg_gradient.layout.visibility = 'hidden' + self.RS_colors = [self.rs_color] * self.RS_npoints + self.ZB_colors = [self.zb_color] * self.ZB_npoints + else: + self.widg_gradient.layout.visibility = 'visible' + self.make_colors(feature=change.new, gradient=self.widg_gradient.value) + self.update_markers() + + def handle_gradient_change(self, change): + self.make_colors(feature=self.widg_featcolor.value, gradient=change.new) self.update_markers() def display_button_l_clicked(self, button): @@ -319,7 +381,7 @@ class Visualizer: # Actions are performed only if the string inserted in the text widget corresponds to an existing compound if self.widg_compound_text_l.value in self.df_selected['Chem Formula'].tolist(): structure_l = self.df_selected[self.df_selected['Chem Formula'] == - self.widg_compound_text_l.value]['Structure'].values[0] + self.widg_compound_text_l.value]['Structure'].values[0] self.viewer_l.script( "load data/descriptor_role/structures/" + structure_l + "_structures/" + self.widg_compound_text_l.value + ".xyz") @@ -353,7 +415,7 @@ class Visualizer: # Actions are performed only if the string inserted in the text widget corresponds to an existing compound if self.widg_compound_text_r.value in self.df_selected['Chem Formula'].tolist(): structure_r = self.df_selected[self.df_selected['Chem Formula'] == - self.widg_compound_text_r.value]['Structure'].values[0] + self.widg_compound_text_r.value]['Structure'].values[0] self.viewer_r.script( "load data/descriptor_role/structures/" + structure_r + "_structures/" + self.widg_compound_text_r.value + ".xyz") @@ -384,25 +446,26 @@ class Visualizer: def updatecolor_button_clicked(self, button): - try: - self.scatter_RS.update(marker=dict(color=self.widg_rscolor.value)) - self.rs_color = self.widg_rscolor.value - self.RS_colors = self.RS_npoints * [self.rs_color] - except: - pass - try: - self.scatter_ZB.update(marker=dict(color=self.widg_zbcolor.value)) - self.rs_color = self.widg_rscolor.value - self.RS_colors = self.RS_npoints * [self.rs_color] - except: - pass - - if self.bg_toggle: + if self.widg_featcolor.value == 'Default color': try: - self.fig.update_layout(plot_bgcolor=self.widg_bgcolor.value) - self.bg_color = self.widg_bgcolor.value + self.scatter_RS.update(marker=dict(color=self.widg_rscolor.value)) + self.rs_color = self.widg_rscolor.value + self.RS_colors = self.RS_npoints * [self.rs_color] except: pass + try: + self.scatter_ZB.update(marker=dict(color=self.widg_zbcolor.value)) + self.rs_color = self.widg_rscolor.value + self.RS_colors = self.RS_npoints * [self.rs_color] + except: + pass + + if self.bg_toggle: + try: + self.fig.update_layout(plot_bgcolor=self.widg_bgcolor.value) + self.bg_color = self.widg_bgcolor.value + except: + pass def handle_fontfamily_change(self, change): @@ -498,12 +561,14 @@ class Visualizer: def plotappearance_button_clicked(self, button): if self.widg_box_utils.layout.visibility == 'visible': self.widg_box_utils.layout.visibility = 'hidden' + for i in range(270, -1, -1): + self.widg_box_viewers.layout.top = str(i) + 'px' self.widg_box_utils.layout.bottom = '0px' - self.widg_box_viewers.layout.top = '0px' else: - self.widg_box_utils.layout.visibility = 'visible' + for i in range(271): + self.widg_box_viewers.layout.top = str(i) + 'px' self.widg_box_utils.layout.bottom = '400px' - self.widg_box_viewers.layout.top = '320px' + self.widg_box_utils.layout.visibility = 'visible' def handle_checkbox_l(self, change): if change.new: @@ -630,6 +695,7 @@ class Visualizer: self.widg_featy.observe(self.handle_yfeat_change, names='value') self.widg_featmarker.observe(self.handle_markerfeat_change, names='value') self.widg_featcolor.observe(self.handle_colorfeat_change, names='value') + self.widg_gradient.observe(self.handle_gradient_change, names='value') self.widg_checkbox_l.observe(self.handle_checkbox_l, names='value') self.widg_checkbox_r.observe(self.handle_checkbox_r, names='value') self.widg_display_button_l.on_click(self.display_button_l_clicked) @@ -659,6 +725,7 @@ class Visualizer: display(self.viewer_r) self.widg_box_utils.layout.visibility = 'hidden' + self.widg_gradient.layout.visibility = 'hidden' box_print = widgets.VBox([self.widg_printdescription, widgets.HBox([self.widg_empty, self.widg_print_button, self.widg_print_out]), widgets.HBox([self.widg_plot_name, self.widg_plot_format, self.widg_scale])]) @@ -666,7 +733,9 @@ class Visualizer: box_print.layout.max_width = '750px' box_feat = widgets.HBox([widgets.VBox([self.widg_featx, self.widg_featy]), - widgets.VBox([self.widg_featmarker, self.widg_featcolor]) + widgets.VBox([self.widg_featmarker, + widgets.HBox([self.widg_featcolor, self.widg_gradient]) + ]) ]) box_print.layout.height = '110px' @@ -677,7 +746,7 @@ class Visualizer: self.widg_box_utils.layout.border = 'dashed 1px' self.widg_box_utils.right = '100px' - self.widg_box_utils.layout.max_width = '750px' + self.widg_box_utils.layout.max_width = '700px' container = widgets.VBox([box_print, box_feat, self.fig, self.widg_plotutils_button, @@ -709,6 +778,17 @@ class Visualizer: options=['Default color'] + self.features, value='Default color' ) + self.widg_gradient = widgets.Dropdown( + description='-gradient', + options=['Grey scale', + 'Purple scale', + 'Turquoise scale', + 'Red scale', + 'Blue scale', + 'Green scale'], + value='Grey scale', + layout=widgets.Layout(width='150px', right='20px') + ) self.widg_compound_text_l = widgets.Combobox( placeholder='...', description='Compound:', @@ -744,69 +824,82 @@ class Visualizer: self.widg_markersize = widgets.BoundedIntText( placeholder=str(self.marker_size), description='Marker size', - value=str(self.marker_size) + value=str(self.marker_size), + layout=widgets.Layout(left='30px', width='200px') ) self.widg_crosssize = widgets.BoundedIntText( placeholder=str(self.cross_size), description='Cross size', - value=str(self.cross_size) + value=str(self.cross_size), + layout=widgets.Layout(left='30px', width='200px') ) self.widg_fontsize = widgets.BoundedIntText( placeholder=str(self.font_size), description='Font size', - value=str(self.font_size) + value=str(self.font_size), + layout = widgets.Layout(left='30px', width='200px') ) self.widg_linewidth = widgets.BoundedIntText( placeholder=str(self.line_width), description='Line width', - value=str(self.line_width) + value=str(self.line_width), + layout = widgets.Layout(left='30px', width='200px') ) self.widg_linestyle = widgets.Dropdown( options=self.line_styles, description='Line style', value=self.line_styles[0], + layout=widgets.Layout(left='30px', width='200px') ) self.widg_fontfamily = widgets.Dropdown( options=self.font_families, description='Font family', - value=self.font_families[0] - ) - self.widg_bgtoggle_button = widgets.Button( - description='Toggle on/off background', - layout=widgets.Layout(width='300px'), + value=self.font_families[0], + layout=widgets.Layout(left='30px', width='200px') ) + self.widg_bgcolor = widgets.Text( placeholder=str(self.bg_color), - description='Color', + description='Background', value=str(self.bg_color), + layout=widgets.Layout(left='30px', width='200px'), + ) self.widg_rscolor = widgets.Text( placeholder=str(self.rs_color), description='RS color', value=str(self.rs_color), + layout=widgets.Layout(left='30px', width='200px'), ) self.widg_zbcolor = widgets.Text( placeholder=str(self.zb_color), description='ZB color', value=str(self.zb_color), + layout=widgets.Layout(left='30px', width='200px'), ) self.widg_rsmarkersymbol = widgets.Dropdown( description='RS symbol', options=self.symbols, - value=self.marker_symbol_RS + value=self.marker_symbol_RS, + layout=widgets.Layout(left='30px', width='200px') ) self.widg_zbmarkersymbol = widgets.Dropdown( description='ZB symbol', options=self.symbols, - value=self.marker_symbol_ZB + value=self.marker_symbol_ZB, + layout=widgets.Layout(left='30px', width='200px') + ) + self.widg_bgtoggle_button = widgets.Button( + description='Toggle on/off background', + layout=widgets.Layout(left='50px', width='200px'), ) self.widg_updatecolor_button = widgets.Button( description='Update colors', - layout=widgets.Layout(width='150px') + layout=widgets.Layout(left='50px', width='200px') ) self.widg_reset_button = widgets.Button( description='Reset symbols', - layout=widgets.Layout(width='150px') + layout=widgets.Layout(left='50px',width='200px') ) self.widg_plot_name = widgets.Text( placeholder='plot', @@ -839,7 +932,7 @@ class Visualizer: value="Click 'Print' to export the plot in the desired format. The resolution of the image can be increased" " by increasing the 'Scale' value." ) - self.widg_featuredescription = widgets.Label( + self.widg_featuredescription = widgets.Label( value="The dropdown menus select the features to visualize." ) self.widg_description = widgets.Label( @@ -847,12 +940,15 @@ class Visualizer: 'structure selected in the map above.' ) self.widg_colordescription = widgets.Label( - value='In the boxes below, the colors used in the plot. Colors can be written as a text string, i.e. red, ' - 'green,..., or in a rgb/a, hex format. ' + value='Colors in the boxes below can be written as a text string, i.e. red, ' + 'green,..., or in a rgb/a, hex format. ', + layout=widgets.Layout(left='50px', width='640px') + ) self.widg_colordescription2 = widgets.Label( value="After modifying a specific field, click on the 'Update colors' button to display the changes in " - "the plot." + "the plot.", + layout=widgets.Layout(left='50px', width='640px') ) self.widg_printdescription = widgets.Label( value="Click 'Print' to export the plot in the desired format. The resolution of the image can be increased" @@ -862,14 +958,16 @@ class Visualizer: description='Toggle on/off the plot appearance utils', layout=widgets.Layout(width='600px') ) - self.widg_box_utils = widgets.VBox([widgets.HBox([self.widg_markersize, self.widg_crosssize]), - widgets.HBox([self.widg_linewidth, self.widg_linestyle]), - widgets.HBox([self.widg_fontsize, self.widg_fontfamily]), - widgets.HBox([self.widg_rsmarkersymbol, self.widg_zbmarkersymbol]), - self.widg_colordescription, self.widg_colordescription2, - widgets.HBox([self.widg_rscolor, self.widg_zbcolor]), - widgets.HBox([self.widg_bgtoggle_button, self.widg_bgcolor]), - widgets.HBox([self.widg_updatecolor_button, self.widg_reset_button])]) + self.widg_box_utils = widgets.VBox([widgets.HBox([self.widg_markersize, self.widg_crosssize, + self.widg_rsmarkersymbol]), + widgets.HBox([self.widg_linewidth, self.widg_linestyle, + self.widg_zbmarkersymbol]), + widgets.HBox([self.widg_fontsize, self.widg_fontfamily]), + self.widg_colordescription, self.widg_colordescription2, + widgets.HBox([self.widg_rscolor, self.widg_zbcolor, self.widg_bgcolor]), + widgets.HBox([self.widg_bgtoggle_button,self.widg_updatecolor_button, + self.widg_reset_button]) + ]) file1 = open("./assets/descriptor_role/cross.png", "rb") image1 = file1.read() @@ -891,15 +989,12 @@ class Visualizer: self.output_r = widgets.Output() self.widg_box_viewers = widgets.VBox([self.widg_description, widgets.HBox([ - widgets.VBox([ - widgets.HBox([self.widg_compound_text_l, self.widg_display_button_l, - self.widg_img1, self.widg_checkbox_l]), - self.output_l]), - widgets.VBox( - [widgets.HBox([self.widg_compound_text_r, self.widg_display_button_r, - self.widg_img2, self.widg_checkbox_r]), - self.output_r]) - ])]) - - - + widgets.VBox([ + widgets.HBox([self.widg_compound_text_l, self.widg_display_button_l, + self.widg_img1, self.widg_checkbox_l]), + self.output_l]), + widgets.VBox( + [widgets.HBox([self.widg_compound_text_r, self.widg_display_button_r, + self.widg_img2, self.widg_checkbox_r]), + self.output_r]) + ])])