Commit 24e27544 authored by Luigi Sbailo
Browse files

Update notebook to latest SISSO version

parent b1f7ef94
......@@ -92,26 +92,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2021-06-21T15:58:01.856952Z",
"start_time": "2021-06-21T15:58:01.835979Z"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 83,
"metadata": {
"ExecuteTime": {
"end_time": "2021-06-22T13:47:48.732529Z",
"start_time": "2021-06-22T13:47:48.724077Z"
"end_time": "2021-09-15T12:06:55.215052Z",
"start_time": "2021-09-15T12:06:55.209182Z"
},
"init_cell": true
},
......@@ -170,17 +155,19 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 84,
"metadata": {
"ExecuteTime": {
"end_time": "2021-06-22T13:47:48.744505Z",
"start_time": "2021-06-22T13:47:48.737146Z"
"end_time": "2021-09-15T12:06:55.220065Z",
"start_time": "2021-09-15T12:06:55.216481Z"
},
"init_cell": true
},
"outputs": [],
"source": [
"from sissopp import get_max_number_feats, get_estimate_n_feat_next_rung, generate_fs, SISSOClassifier, generate_phi_0_from_csv, FeatureSpace\n",
"from sissopp import Inputs, FeatureSpace, SISSOClassifier, FeatureNode, Unit\n",
"from sissopp.py_interface import read_csv\n",
"from sissopp.py_interface.import_dataframe import get_unit\n",
"from tetradymite_PRM2020.visualizer import Visualizer\n",
"import numpy as np\n",
"import pandas as pd\n",
......@@ -210,11 +197,11 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 85,
"metadata": {
"ExecuteTime": {
"end_time": "2021-06-22T13:47:48.768658Z",
"start_time": "2021-06-22T13:47:48.747843Z"
"end_time": "2021-09-15T12:06:55.233765Z",
"start_time": "2021-09-15T12:06:55.221460Z"
},
"init_cell": true
},
......@@ -236,7 +223,7 @@
"outputs": [],
"source": [
"# This piece of code is not run at initialization. \n",
"# It serves to create the molecular structures which are visualized.\n",
"# It can create the molecular structures which are visualized.\n",
"\n",
"path_structure = './data/tetradymite_PRM2020/structures/'\n",
"try:\n",
......@@ -283,11 +270,11 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 86,
"metadata": {
"ExecuteTime": {
"end_time": "2021-06-22T13:47:48.893507Z",
"start_time": "2021-06-22T13:47:48.770063Z"
"end_time": "2021-09-15T12:06:55.365899Z",
"start_time": "2021-09-15T12:06:55.235398Z"
},
"init_cell": true
},
......@@ -327,93 +314,75 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 87,
"metadata": {
"ExecuteTime": {
"end_time": "2021-06-22T13:47:48.901657Z",
"start_time": "2021-06-22T13:47:48.894931Z"
"end_time": "2021-09-15T12:06:55.374165Z",
"start_time": "2021-09-15T12:06:55.367318Z"
},
"init_cell": true
},
"outputs": [],
"source": [
"def get_featspace_sisso(\n",
" df,\n",
" ops= ['add', 'sub', 'abs_diff', 'mult', 'div', 'exp', 'neg_exp', 'inv', 'sq', 'cb', \n",
" 'sqrt', 'cbrt', 'log', 'abs'],\n",
" cols=\"all\",\n",
" max_phi=2,\n",
"def get_feat_space_and_sisso_regressor(\n",
" selected_ops=[\"add\", \"abs_diff\", \"div\", \"sq\", \"exp\"],\n",
" selected_features = 'all',\n",
" max_rung=2,\n",
" n_sis_select=50,\n",
" remove_double_divison=True,\n",
" max_dim=3,\n",
" n_residual=1,\n",
" n_dim=2,\n",
" n_residual=10,\n",
" default=True,\n",
"):\n",
"\n",
" if default:\n",
" phi_0, prop_label, prop_unit, prop, prop_test, task_sizes_train, task_sizes_test, leave_out_inds = generate_phi_0_from_csv(\n",
" \n",
" selected_ops = [\"add\", \"sub\", \"mult\", \"div\", \"abs_diff\", \"sq\", \"cb\", \"sqrt\", \"cbrt\", \"inv\", \"abs\"] \n",
" selected_features = 'all'\n",
" inputs = read_csv(\n",
" df_train, \n",
" \"Class\",\n",
" prop_key=\"Class\",\n",
" cols='all',\n",
" task_key=None,\n",
" max_rung=max_rung,\n",
" leave_out_frac=0.0,\n",
" leave_out_inds=None,\n",
" max_rung=1\n",
" )\n",
" feat_space = generate_fs(\n",
" phi_0, \n",
" prop, \n",
" task_sizes_train, \n",
" [\"add\", \"sub\", \"mult\", \"div\", \"abs_diff\", \"sq\", \"cb\", \"sqrt\", \"cbrt\", \"inv\", \"abs\"], \n",
" [],\n",
" \"classification\",\n",
" 0, \n",
" n_sis_select\n",
" )\n",
" )\n",
" else:\n",
" phi_0, prop_label, prop_unit, prop, prop_test, task_sizes_train, task_sizes_test, leave_out_inds = generate_phi_0_from_csv(\n",
" \n",
" inputs = read_csv(\n",
" df_feat, \n",
" \"Class\", \n",
" cols=cols, \n",
" task_key=None, \n",
" leave_out_frac=0.0, \n",
" leave_out_inds=None, \n",
" max_rung=max_phi\n",
" )\n",
" feat_space = generate_fs(\n",
" phi_0, \n",
" prop, \n",
" task_sizes_train, \n",
" ops,\n",
" [],\n",
" \"classification\",\n",
" max_phi, \n",
" n_sis_select\n",
" )\n",
" prop_key=\"Class\",\n",
" cols=selected_features,\n",
" max_rung=max_rung,\n",
" leave_out_frac=0.0\n",
" )\n",
" \n",
" inputs.max_rung = max_rung\n",
" inputs.allowed_ops = selected_ops\n",
" inputs.n_sis_select = n_sis_select\n",
" inputs.n_dim = n_dim\n",
" inputs.n_residual = n_residual\n",
" inputs.n_model_store = 1\n",
" inputs.calc_type = \"classification\"\n",
" inputs.sample_ids_train = df_feat.index.tolist()\n",
" inputs.prop_train = df_feat[\"Class\"].to_numpy()\n",
" inputs.prop_test = np.array([])\n",
" inputs.prop_label = \"Class\"\n",
" inputs.task_names = [\"all_mats\"]\n",
"\n",
" \n",
" feat_space = FeatureSpace(inputs)\n",
" \n",
" sisso = SISSOClassifier(inputs, feat_space)\n",
" \n",
" sisso = SISSOClassifier(\n",
" feat_space,\n",
" prop_label,\n",
" prop_unit,\n",
" prop,\n",
" prop_test,\n",
" task_sizes_train,\n",
" task_sizes_test,\n",
" leave_out_inds,\n",
" max_dim,\n",
" 10,\n",
" 10\n",
" )\n",
" return feat_space, sisso"
" return feat_space, sisso "
]
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 88,
"metadata": {
"ExecuteTime": {
"end_time": "2021-06-22T13:47:48.918838Z",
"start_time": "2021-06-22T13:47:48.903533Z"
"end_time": "2021-09-15T12:06:55.393669Z",
"start_time": "2021-09-15T12:06:55.375598Z"
},
"init_cell": true
},
......@@ -498,18 +467,19 @@
" \n",
" global feat_space\n",
" global sisso\n",
" \n",
"\n",
" try:\n",
" feat_space, sisso = get_featspace_sisso(\n",
" df = df_train,\n",
" ops = allowed_operations,\n",
" cols = selected_features,\n",
" max_phi = tier,\n",
" n_sis_select = feat_per_iter_selection.value,\n",
" remove_double_divison=True,\n",
" max_dim = dimension_selection.value,\n",
" n_residual = 1,\n",
" default = default)\n",
" feat_space, sisso = get_feat_space_and_sisso_regressor(\n",
" selected_ops = allowed_operations,\n",
" selected_features = selected_features,\n",
" max_rung = tier,\n",
" n_sis_select = feat_per_iter_selection.value,\n",
" n_dim = dimension_selection.value,\n",
" n_residual = 10,\n",
" default = default\n",
" )\n",
"\n",
"\n",
" clear_output()\n",
" if (dimension_selection.value>1):\n",
" plot_button.disabled=False\n",
......@@ -544,11 +514,11 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 89,
"metadata": {
"ExecuteTime": {
"end_time": "2021-06-22T13:47:49.206736Z",
"start_time": "2021-06-22T13:47:48.920724Z"
"end_time": "2021-09-15T12:06:55.691702Z",
"start_time": "2021-09-15T12:06:55.395093Z"
},
"init_cell": true,
"scrolled": false
......@@ -557,7 +527,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3f97794bc3d14b55a26cb7b3abc9cd78",
"model_id": "a6173b0cb1eb48fd82eb15d5f54f1207",
"version_major": 2,
"version_minor": 0
},
......@@ -617,7 +587,7 @@
"\n",
"rung_selection = widgets.Dropdown(options=['PRM2020', 1,2,3], value=2,layout=thin_layout)\n",
"rung_selection.value = 'PRM2020'\n",
"feat_per_iter_selection = widgets.BoundedIntText(value = 50, min=10, max=100, step=1, layout=thin_layout)\n",
"feat_per_iter_selection = widgets.BoundedIntText(value = 50, min=10, max=200, step=1, layout=thin_layout)\n",
"dimension_selection = widgets.BoundedIntText(value = 2, min=1, max=4, step=1, layout = thin_layout)\n",
"settings_box = widgets.VBox([\n",
" widgets.Label(value='Settings:', layout=wide_layout),\n",
......@@ -667,7 +637,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
"version": "3.7.3"
}
},
"nbformat": 4,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment