From 991c797ce17f22ca13d7cb3e5c37ce3a8408b0ee Mon Sep 17 00:00:00 2001 From: Luigi Sbailo <sbailo@fhi-berlin.mpg.de> Date: Mon, 28 Sep 2020 18:31:53 +0200 Subject: [PATCH] First prototype of the notebook --- tetradymite_PRM2020.ipynb | 539 ++++++++++++++++++++---------- tetradymite_PRM2020/visualizer.py | 55 +-- 2 files changed, 398 insertions(+), 196 deletions(-) diff --git a/tetradymite_PRM2020.ipynb b/tetradymite_PRM2020.ipynb index 384ff31..6fd6baf 100644 --- a/tetradymite_PRM2020.ipynb +++ b/tetradymite_PRM2020.ipynb @@ -45,8 +45,71 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2020-09-26T12:14:16.892451Z", - "start_time": "2020-09-26T12:14:16.862246Z" + "end_time": "2020-09-28T16:29:21.821237Z", + "start_time": "2020-09-28T16:29:21.811355Z" + }, + "init_cell": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "<script>\n", + " code_show=true; \n", + " function code_toggle() {\n", + " if (code_show)\n", + " {\n", + " $('div.input').hide();\n", + " } \n", + " else \n", + " {\n", + " $('div.input').show();\n", + " }\n", + " code_show = !code_show\n", + " } \n", + " $( document ).ready(code_toggle);\n", + " window.runCells(\"startup\");\n", + "</script>\n", + "The raw code for this notebook is by default hidden for easier reading.\n", + "To toggle on/off the raw code, click <a href=\"javascript:code_toggle()\">here</a>.\n" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%HTML\n", + "<script>\n", + " code_show=true; \n", + " function code_toggle() {\n", + " if (code_show)\n", + " {\n", + " $('div.input').hide();\n", + " } \n", + " else \n", + " {\n", + " $('div.input').show();\n", + " }\n", + " code_show = !code_show\n", + " } \n", + " $( document ).ready(code_toggle);\n", + " window.runCells(\"startup\");\n", + "</script>\n", + "The raw code for this notebook is by default hidden for easier reading.\n", + "To toggle on/off the raw code, click <a href=\"javascript:code_toggle()\">here</a>." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2020-09-28T15:50:43.008324Z", + "start_time": "2020-09-28T15:50:42.991266Z" } }, "outputs": [], @@ -60,9 +123,10 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2020-09-26T12:14:18.033512Z", - "start_time": "2020-09-26T12:14:17.339188Z" - } + "end_time": "2020-09-28T16:29:22.312658Z", + "start_time": "2020-09-28T16:29:21.823064Z" + }, + "init_cell": true }, "outputs": [], "source": [ @@ -78,25 +142,11 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2020-09-26T12:14:36.475512Z", - "start_time": "2020-09-26T12:14:18.035077Z" + "end_time": "2020-09-28T15:50:45.224040Z", + "start_time": "2020-09-28T15:50:45.202499Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number queries entries: 1581\n", - "Number of entries loaded in the last api call: 100\n", - "Bytes loaded in the last api call: 502327341\n", - "Bytes loaded from this query: 502327341\n", - "Number of downloaded entries: 100\n", - "Number of made api calls: 1\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "from nomad import client, config\n", "config.client.url = 'http://nomad-lab.eu/prod/rae/api'\n", @@ -109,26 +159,28 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2020-09-26T12:16:04.914649Z", - "start_time": "2020-09-26T12:14:36.477143Z" - } + "end_time": "2020-09-28T16:29:22.335555Z", + "start_time": "2020-09-28T16:29:22.314325Z" + }, + "init_cell": true }, "outputs": [], "source": [ - "df_train = pd.read_csv(\"./data/tetradymite_PRM2020/train.csv\", index_col=0).astype(float)" + "df_train = pd.read_pickle('./data/tetradymite_PRM2020/training_set')" ] }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2020-09-26T15:03:37.344693Z", - "start_time": "2020-09-26T15:03:37.315508Z" - } + "end_time": "2020-09-28T15:50:51.751396Z", + "start_time": "2020-09-28T15:50:51.727769Z" + }, + "scrolled": true }, "outputs": [], "source": [ @@ -136,30 +188,15 @@ "try:\n", " os.mkdir(path_structure)\n", "except OSError:\n", - " !rm ./data/tetradymite_PRM2020/structures/*" - ] - }, - { - "cell_type": "code", - "execution_count": 137, - "metadata": { - "ExecuteTime": { - "end_time": "2020-09-26T15:11:48.378665Z", - "start_time": "2020-09-26T15:11:45.885626Z" - }, - "scrolled": true - }, - "outputs": [], - "source": [ + " !rm ./data/tetradymite_PRM2020/structures/*\n", "compounds=df_train.index.to_list()\n", "scale_factor = 10**10\n", "alist = []\n", "for compound in compounds:\n", " for entry in range (1581):\n", - " \n", " labels = query[entry].section_run[0].section_system[-1].atom_labels\n", " if (len(labels)>5):\n", - " continue\n", + " continue\n", " \n", " labels_1 = str(labels[0])+'_'+str(labels[1])+'_'+str(labels[3])+'_'+str(labels[4])+'_'+str(labels[2])\n", " labels_2 = str(labels[0])+'_'+str(labels[1])+'_'+str(labels[4])+'_'+str(labels[3])+'_'+str(labels[2])\n", @@ -192,11 +229,11 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2020-09-24T15:05:00.227743Z", - "start_time": "2020-09-24T15:05:00.213980Z" + "end_time": "2020-09-28T15:50:53.305285Z", + "start_time": "2020-09-28T15:50:53.291671Z" } }, "outputs": [], @@ -205,45 +242,19 @@ "# col_01 = 10000\n", "# col_10 = 95000\n", "# col_11 = 100000\n", - "# df_red = pd.concat([df_train[df_train.columns[col_00:col_01].to_list()],df_train[df_train.columns[col_10:col_11].to_list()],df_train['Class']],axis=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "ExecuteTime": { - "end_time": "2020-09-24T15:05:00.433166Z", - "start_time": "2020-09-24T15:05:00.418153Z" - } - }, - "outputs": [], - "source": [ + "# df_red = pd.concat([df_train[df_train.columns[col_00:col_01].to_list()],df_train[df_train.columns[col_10:col_11].to_list()],df_train['Class']],axis=1)\n", "# df_red.to_pickle('./data/tetradymite_PRM2020/training_set')" ] }, { "cell_type": "code", - "execution_count": 173, - "metadata": { - "ExecuteTime": { - "end_time": "2020-09-25T12:37:48.925176Z", - "start_time": "2020-09-25T12:37:48.876683Z" - } - }, - "outputs": [], - "source": [ - "df_train = pd.read_pickle('./data/tetradymite_PRM2020/training_set')" - ] - }, - { - "cell_type": "code", - "execution_count": 81, + "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2020-09-24T18:05:41.584081Z", - "start_time": "2020-09-24T18:05:41.427515Z" - } + "end_time": "2020-09-28T16:29:22.463936Z", + "start_time": "2020-09-28T16:29:22.337130Z" + }, + "init_cell": true }, "outputs": [], "source": [ @@ -258,11 +269,12 @@ " ])\n", "for comp in df_train.index:\n", " ablmn = comp.split('_')\n", - " df_feat.loc[comp] = pd.Series({'A':ablmn[0],\n", - " 'B':ablmn[1],\n", - " 'L':ablmn[2],\n", - " 'M':ablmn[3],\n", - " 'N':ablmn[4],\n", + " df_feat.loc[comp] = pd.Series({\n", + "# 'A':ablmn[0],\n", + "# 'B':ablmn[1],\n", + "# 'L':ablmn[2],\n", + "# 'M':ablmn[3],\n", + "# 'N':ablmn[4],\n", " 'z_A':zeta[ablmn[0]],\n", " 'z_B':zeta[ablmn[1]],\n", " 'z_L':zeta[ablmn[2]],\n", @@ -278,133 +290,226 @@ " 'l_L':lambd[ablmn[2]],\n", " 'l_M':lambd[ablmn[3]],\n", " 'l_N':lambd[ablmn[4]],\n", - " }) \n" - ] - }, - { - "cell_type": "code", - "execution_count": 138, - "metadata": { - "ExecuteTime": { - "end_time": "2020-09-26T15:14:04.072271Z", - "start_time": "2020-09-26T15:14:04.041798Z" - } - }, - "outputs": [], - "source": [ - "# df_feat" + " }) \n", + "\n", + "df_feat['Class'] = df_train['Class']" ] }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2020-09-26T15:04:11.632031Z", - "start_time": "2020-09-26T15:04:07.252979Z" - } + "end_time": "2020-09-28T16:29:22.470868Z", + "start_time": "2020-09-28T16:29:22.465327Z" + }, + "init_cell": true }, "outputs": [], "source": [ - "phi_0, prop_unit, prop, prop_test, task_sizes_train, task_sizes_test, leave_out_inds = generate_phi_0_from_csv(\n", - " df_train, \"Class\", cols=\"all\", task_key=None, leave_out_frac=0.0\n", - ")\n", + "def get_feat_space_and_sr(\n", + " df,\n", + " ops= ['add', 'sub', 'abs_diff', 'mult', 'div', 'exp', 'neg_exp', 'inv', 'sq', 'cb', \n", + " 'sqrt', 'cbrt', 'log', 'abs'],\n", + " cols=\"all\",\n", + " max_phi=2,\n", + " n_sis_select=50,\n", + " remove_double_divison=True,\n", + " max_dim=3,\n", + " n_residual=1,\n", + " default=True,\n", + "):\n", "\n", - "feat_space = generate_fs(\n", - " phi_0, \n", - " prop, \n", - " task_sizes_train, \n", - " [\"add\", \"sub\", \"mult\", \"div\", \"abs_diff\", \"sq\", \"cb\", \"sqrt\", \"cbrt\", \"inv\", \"abs\"], \n", - " \"classification\",\n", - " 0, \n", - " 50\n", - ")\n", + " if default:\n", + " phi_0, prop_unit, prop, prop_test, task_sizes_train, task_sizes_test, leave_out_inds = generate_phi_0_from_csv(\n", + " df_train, \"Class\", cols='all', task_key=None, leave_out_frac=0.0\n", "\n", - "sisso = SISSOClassifier(\n", - " feat_space,\n", - " prop_unit,\n", - " prop,\n", - " prop_test,\n", - " task_sizes_train,\n", - " task_sizes_test,\n", - " leave_out_inds,\n", - " 2,\n", - " 10,\n", - " 10,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 130, - "metadata": { - "ExecuteTime": { - "end_time": "2020-09-26T15:04:28.767213Z", - "start_time": "2020-09-26T15:04:11.633705Z" - } - }, - "outputs": [], - "source": [ - "sisso.fit()" + " )\n", + " feat_space = generate_fs(\n", + " phi_0, \n", + " prop, \n", + " task_sizes_train, \n", + " [\"add\", \"sub\", \"mult\", \"div\", \"abs_diff\", \"sq\", \"cb\", \"sqrt\", \"cbrt\", \"inv\", \"abs\"], \n", + " \"classification\",\n", + " 0, \n", + " 50\n", + " )\n", + " else:\n", + " phi_0, prop_unit, prop, prop_test, task_sizes_train, task_sizes_test, leave_out_inds = generate_phi_0_from_csv(\n", + " df_feat, \"Class\", cols=cols, task_key=None, leave_out_frac=0.0, leave_out_inds=None\n", + " )\n", + " feat_space = generate_fs(\n", + " phi_0, \n", + " prop, \n", + " task_sizes_train, \n", + " ops,\n", + " \"classification\",\n", + " max_phi, \n", + " n_sis_select\n", + " )\n", + " \n", + " sisso = SISSOClassifier(\n", + " feat_space,\n", + " prop_unit,\n", + " prop,\n", + " prop_test,\n", + " task_sizes_train,\n", + " task_sizes_test,\n", + " leave_out_inds,\n", + " max_dim,\n", + " 10,\n", + " 10\n", + " )\n", + " return feat_space, sisso" ] }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2020-09-26T15:04:28.784474Z", - "start_time": "2020-09-26T15:04:28.768858Z" + "end_time": "2020-09-28T16:29:22.491252Z", + "start_time": "2020-09-28T16:29:22.472488Z" }, - "scrolled": true + "init_cell": true }, "outputs": [], "source": [ - "model = sisso.models[1][0]\n", - "feat_1=model.feats[0].value\n", - "feat_0=model.feats[1].value\n", - "classified=model.prop_train\n", - "compounds=df_train.index.to_list()" - ] - }, - { - "cell_type": "code", - "execution_count": 132, - "metadata": { - "ExecuteTime": { - "end_time": "2020-09-26T15:04:28.802666Z", - "start_time": "2020-09-26T15:04:28.786184Z" - } - }, - "outputs": [], - "source": [ - "df=pd.DataFrame(data={\"Compound\":compounds,\n", - " \"Classification\":classified})\n", + "from ipywidgets import widgets, interactive\n", + "from IPython.display import HTML, clear_output\n", + "\n", + "def plot_2d_solution(b):\n", + " with out2:\n", + " model = sisso.models[1][0]\n", + " classified=model.prop_train\n", + " compounds = df_train.index.to_list()\n", + " df=pd.DataFrame(data={\n", + " \"Compound\":compounds,\n", + " \"Classification\":classified})\n", + " for feat in sisso.models[sisso.n_dim-1][0].feats:\n", + " df[str(feat)]=feat.value\n", + " visualizer=Visualizer(df, sisso)\n", + " visualizer.show()\n", + " \n", + "def prm_select(change):\n", + " if change['new'] == 'PRM2015':\n", + " default_operations = ['add', 'sub', 'abs_diff', 'mult', 'div', 'exp', 'neg_exp', 'inv', 'sq', 'cb', \n", + " 'sqrt', 'cbrt', 'log', 'abs']\n", + " default_features = ['z_AB','chi_AB','lambda_AB','z_LMN','chi_LMN','lambda_LMN']\n", + "\n", + " for op, widget in zip(possible_operations, op_list):\n", + " widget.value = op in default_operations\n", + " widget.disabled = True\n", + " for feat, widget in zip(possible_features, feat_list):\n", + " widget.value = feat in default_features\n", + " widget.disabled = True\n", + " tier_selection.value = 'PRM2015'\n", + " feat_per_iter_selection.value = 50\n", + " dimension_selection.value = 2 \n", + " else:\n", + " for widget in op_list+feat_list:\n", + " widget.disabled = False\n", "\n", - "for feat in sisso.models[sisso.n_dim - 1][0].feats:\n", - " df[str(feat)]=feat.value" + "def default_selection(b):\n", + " \n", + " default_operations = ['add', 'sub', 'abs_diff', 'mult', 'div', 'exp', 'neg_exp', 'inv', 'sq', 'cb', \n", + " 'sqrt', 'cbrt', 'log', 'abs']\n", + " default_features = ['z_AB','chi_AB','lambda_AB','z_LMN','chi_LMN','lambda_LMN']\n", + " for op, widget in zip(possible_operations, op_list):\n", + " widget.value = op in default_operations\n", + " widget.disabled = True\n", + " for feat, widget in zip(possible_features, feat_list):\n", + " widget.value = feat in default_features\n", + " widget.disabled = True\n", + " tier_selection.value = 'PRM2020'\n", + " feat_per_iter_selection.value = 50\n", + " dimension_selection.value = 2\n", + " \n", + "def find_descriptors(b):\n", + " with out2:\n", + " clear_output() \n", + " with out1:\n", + " clear_output()\n", + " print('Calculating...', flush=True)\n", + " selected_features = []\n", + " allowed_operations = []\n", + " for op, widget in zip(possible_operations, op_list):\n", + " if widget.value:\n", + " allowed_operations.append(op)\n", + " for sel_feat, widget in zip(possible_features, feat_list):\n", + " if widget.value:\n", + " feat = sel_feat.split('_')[0]\n", + " for char in sel_feat.split('_')[1]:\n", + " selected_features.append(feat + '_'+ char) \n", + " \n", + " if tier_selection.value == 'PRM2020':\n", + " selected_features = \"all\"\n", + " tier = 0\n", + " default = True\n", + " else:\n", + " tier = tier_selection.value\n", + " default = False\n", + " \n", + " global feat_space\n", + " global sisso\n", + " \n", + "# try:\n", + " feat_space, sisso = get_feat_space_and_sr(\n", + " df = df_train,\n", + " ops = allowed_operations,\n", + " cols = selected_features,\n", + " max_phi = tier,\n", + " n_sis_select = feat_per_iter_selection.value,\n", + " remove_double_divison=True,\n", + " max_dim = dimension_selection.value,\n", + " n_residual = 1,\n", + " default = default)\n", + " clear_output()\n", + " if (dimension_selection.value>1):\n", + " plot_button.disabled=False\n", + " else:\n", + " plot_button.disabled=True\n", + "\n", + " print(\"Number of features generated: \" + str(feat_space.n_feat))\n", + "\n", + " try:\n", + " sisso.fit()\n", + " for i in range(dimension_selection.value):\n", + " print(str(i+1)+'D model')\n", + "# print(\"RMSE: {:.4} | Descriptor: {}\".format(sisso.models[i][0].rmse, sisso.models[i][0]))\n", + " string = \"c0:{:.4}\".format(sisso.models[i][0].coefs[0][-1])\n", + " for j in range(i+1):\n", + " string = string + str(\" | a\"+str(j)+\":{:.4}\".format(sisso.models[i][0].coefs[0][j]))\n", + " print(string + '\\n')\n", + " global df\n", + "\n", + " except RuntimeError:\n", + " print(\"\\nThe number of selected features per SIS iteration is bigger than the number of features available. Please reduce the number of selected features per SIS iteration (number of features generated / max number of dimensions) or increase the number of selected features and operations.\")\n", + "# except:\n", + "# print('The present selection does not lead to the creation of any derived features in the highest selected rung, please select at least one binary or power operator, or reduce the maximum rung')" ] }, { "cell_type": "code", - "execution_count": 139, + "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2020-09-26T15:14:36.063894Z", - "start_time": "2020-09-26T15:14:35.462458Z" + "end_time": "2020-09-28T16:29:22.749289Z", + "start_time": "2020-09-28T16:29:22.492553Z" }, + "init_cell": true, "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ec3056dff7ad49ceaeea04bcd583f307", + "model_id": "db66a946fda449bf9fb691a810eedb38", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(HBox(children=(VBox(children=(Dropdown(description='x-axis', options=('[[[Z11A+Z11B]*[Z11C+Z11D…" + "VBox(children=(HBox(children=(VBox(children=(Label(value=''), Checkbox(value=True, disabled=True, indent=False…" ] }, "metadata": {}, @@ -412,11 +517,91 @@ } ], "source": [ - "Visualizer(df, sisso).show()" + "cb_layout = widgets.Layout(width = '15px')\n", + "thin_layout = widgets.Layout(width = '100px')\n", + "mid_layout = widgets.Layout(width = '200px')\n", + "wide_layout = widgets.Layout(width = '300px')\n", + "\n", + "possible_operations = ['add', 'sub', 'abs_diff', 'mult', 'div', 'exp', 'neg_exp', 'inv', 'sq', 'cb', \n", + " 'sqrt', 'cbrt', 'log', 'abs']\n", + "\n", + "possible_features = ['z_AB','chi_AB','lambda_AB','z_LMN','chi_LMN','lambda_LMN']\n", + "\n", + "tooltips = {\n", + " \"z_AB\" : \"Atomic number cations\",\n", + " \"chi_AB\" : \"Pauling electronegativit cations\",\n", + " \"lambda_AB\" : \"Atomic SOC constant cations\",\n", + " \"z_LMN\" : \"Atomic number anions\",\n", + " \"chi_LMN\" : \"Pauling electronegativity anions\",\n", + " \"lambda_LMN\" : \"Atomic SOC constant anions\",\n", + "}\n", + "\n", + "labels = {\n", + " 'add' : '$x + y$', 'sub' : '$x - y$', 'abs_diff' : '$|x - y|$', 'mult' : '$x \\cdot y$', 'div' : '$x / y$',\n", + " 'exp' : '$\\exp(x)$', 'neg_exp' : '$\\exp(-x)$', 'inv' : '$1/x$', 'sq' : '$x^2$', 'cb' : '$x^3$', \n", + " 'six_pow' : '$x^6$', 'sqrt' : '$\\sqrt{x}$', 'cbrt' : '$\\sqrt[3]{x}$', 'log' : '$\\log(x)$',\n", + " 'abs' : '$|x|$', 'sin' : '$\\sin(x)$', 'cos' : '$\\cos(x)$', 'z_AB' : '$Z_{AB}$', 'chi_AB' : '$\\chi_{AB}$', \n", + " 'lambda_AB' : '$\\lambda_{AB}$', 'z_LMN' : '$Z_{LMN}$', 'chi_LMN' : '$\\chi_{LMN}$', 'lambda_LMN' : '$\\lambda_{LMN}$' \n", + "}\n", + "\n", + "op_list = []\n", + "op_labels = []\n", + "feat_list = []\n", + "feat_labels = []\n", + "for operation in possible_operations:\n", + " op_list.append(widgets.Checkbox(description='', value=True, indent=False, layout=cb_layout))\n", + " op_labels.append(widgets.Label(value=labels[operation]))\n", + "for feature in possible_features:\n", + " feat_list.append(widgets.Checkbox(description=tooltips[feature], value=True, indent=False, layout=cb_layout))\n", + " feat_labels.append(widgets.Label(value=labels[feature]))\n", + " \n", + "op_box = widgets.VBox([widgets.Label()]+op_list)\n", + "op_label_box = widgets.VBox([widgets.Label(value='Operations:', layout=thin_layout)]+op_labels)\n", + "feat_box = widgets.VBox([widgets.Label()]+feat_list)\n", + "feat_label_box = widgets.VBox([widgets.Label(value='Features:', layout=thin_layout)]+feat_labels)\n", + "\n", + "tier_selection = widgets.Dropdown(options=['PRM2020', 1,2,3], layout=thin_layout)\n", + "feat_per_iter_selection = widgets.BoundedIntText(value=26, min=1, max=100, step=1, layout=thin_layout)\n", + "dimension_selection = widgets.BoundedIntText(value = 3, min=1, max=4, step=1, layout = thin_layout)\n", + "settings_box = widgets.VBox([\n", + " widgets.Label(value='Settings:', layout=wide_layout),\n", + " widgets.Label(value='SISSO rung:', layout=wide_layout),\n", + " tier_selection,\n", + " widgets.Label(value='Number of selected features per SIS iteration:', layout=wide_layout),\n", + " feat_per_iter_selection,\n", + " widgets.Label(value='Maximum number of dimensions:', layout=wide_layout),\n", + " dimension_selection])\n", + "\n", + "default_button = widgets.Button(description = 'Default selection', layout=mid_layout)\n", + "descriptor_button = widgets.Button(description = 'Run', layout=mid_layout)\n", + "plot_button = widgets.Button(description = 'Plot interactive map', disabled=True, layout=mid_layout)\n", + "default_button.on_click(default_selection)\n", + "descriptor_button.on_click(find_descriptors)\n", + "plot_button.on_click(plot_2d_solution)\n", + "button_box = widgets.VBox([default_button, descriptor_button, plot_button])\n", + "\n", + "out1 = widgets.Output()\n", + "out2 = widgets.Output()\n", + "\n", + "gui_box = widgets.HBox([op_box, op_label_box, feat_box, feat_label_box, settings_box, button_box])\n", + "out_box = widgets.VBox([gui_box, out1, out2])\n", + "\n", + "tier_selection.observe(prm_select, names='value')\n", + "\n", + "default_selection('')\n", + "display(out_box)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { + "celltoolbar": "Initialization Cell", "kernelspec": { "display_name": "Python 3", "language": "python", diff --git a/tetradymite_PRM2020/visualizer.py b/tetradymite_PRM2020/visualizer.py index 6483cf5..c58a2fd 100644 --- a/tetradymite_PRM2020/visualizer.py +++ b/tetradymite_PRM2020/visualizer.py @@ -80,7 +80,8 @@ class Visualizer: x_cls1 = df[self.features[0]][self.df_cls1] y_cls1 = df[self.features[1]][self.df_cls1] line_x, line_y = self.f_x(self.features[0], self.features[1]) - hullx_cls0, hully_cls0, hullx_cls1, hully_cls1 = self.make_hull(self.features[0], self.features[1]) + + # custom_cls0 = np.dstack((self.target_train[self.df_cls0], # self.target_predict[self.df_cls0]))[0] # custom_cls1 = np.dstack((self.target_train[self.df_cls1], @@ -122,25 +123,41 @@ class Visualizer: marker=dict(color=self.colors_cls1), ) )) - self.fig.add_trace( - go.Scatter( - x=hullx_cls0, - y=hully_cls0, - line=dict(color='Grey', width=1, dash=self.line_styles[0]), - name=r'Convex' + '<br>' + 'hull 0', - visible=False - ), - ) - self.fig.add_trace( - go.Scatter( - x=hullx_cls1, - y=hully_cls1, - line=dict(color='Grey', width=1, dash=self.line_styles[0]), - name=r'Convex' + '<br>' + 'hull 1', - visible=False - ), + try: + hullx_cls0, hully_cls0, hullx_cls1, hully_cls1 = self.make_hull(self.features[0], self.features[1]) + + self.fig.add_trace( + go.Scatter( + x=hullx_cls0, + y=hully_cls0, + line=dict(color='Grey', width=1, dash=self.line_styles[0]), + name=r'Convex' + '<br>' + 'hull 0', + visible=False + ), + ) + self.fig.add_trace( + go.Scatter( + x=hullx_cls1, + y=hully_cls1, + line=dict(color='Grey', width=1, dash=self.line_styles[0]), + name=r'Convex' + '<br>' + 'hull 1', + visible=False + ), + + ) + except: + self.fig.add_trace( + go.Scatter( + visible=False + ), + ) + self.fig.add_trace( + go.Scatter( + visible=False + ), + + ) - ) self.fig.add_trace( go.Scatter( x=line_x, -- GitLab