diff --git a/tetradymite_PRM2020.ipynb b/tetradymite_PRM2020.ipynb index f78b8a9b176c4f9ca9c9293c71ce73abf33bd94d..8cd924456e53254a611f02130d5dec64e84c87a1 100644 --- a/tetradymite_PRM2020.ipynb +++ b/tetradymite_PRM2020.ipynb @@ -92,26 +92,11 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2021-06-21T15:58:01.856952Z", - "start_time": "2021-06-21T15:58:01.835979Z" - } - }, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 74, + "execution_count": 83, "metadata": { "ExecuteTime": { - "end_time": "2021-06-22T13:47:48.732529Z", - "start_time": "2021-06-22T13:47:48.724077Z" + "end_time": "2021-09-15T12:06:55.215052Z", + "start_time": "2021-09-15T12:06:55.209182Z" }, "init_cell": true }, @@ -170,17 +155,19 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 84, "metadata": { "ExecuteTime": { - "end_time": "2021-06-22T13:47:48.744505Z", - "start_time": "2021-06-22T13:47:48.737146Z" + "end_time": "2021-09-15T12:06:55.220065Z", + "start_time": "2021-09-15T12:06:55.216481Z" }, "init_cell": true }, "outputs": [], "source": [ - "from sissopp import get_max_number_feats, get_estimate_n_feat_next_rung, generate_fs, SISSOClassifier, generate_phi_0_from_csv, FeatureSpace\n", + "from sissopp import Inputs, FeatureSpace, SISSOClassifier, FeatureNode, Unit\n", + "from sissopp.py_interface import read_csv\n", + "from sissopp.py_interface.import_dataframe import get_unit\n", "from tetradymite_PRM2020.visualizer import Visualizer\n", "import numpy as np\n", "import pandas as pd\n", @@ -210,11 +197,11 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 85, "metadata": { "ExecuteTime": { - "end_time": "2021-06-22T13:47:48.768658Z", - "start_time": "2021-06-22T13:47:48.747843Z" + "end_time": "2021-09-15T12:06:55.233765Z", + "start_time": "2021-09-15T12:06:55.221460Z" }, "init_cell": true }, @@ -236,7 +223,7 @@ "outputs": [], "source": [ "# This piece of code is not run at initialization. \n", - "# It serves to create the molecular structures which are visualized.\n", + "# It can create the molecular structures which are visualized.\n", "\n", "path_structure = './data/tetradymite_PRM2020/structures/'\n", "try:\n", @@ -283,11 +270,11 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 86, "metadata": { "ExecuteTime": { - "end_time": "2021-06-22T13:47:48.893507Z", - "start_time": "2021-06-22T13:47:48.770063Z" + "end_time": "2021-09-15T12:06:55.365899Z", + "start_time": "2021-09-15T12:06:55.235398Z" }, "init_cell": true }, @@ -327,93 +314,75 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 87, "metadata": { "ExecuteTime": { - "end_time": "2021-06-22T13:47:48.901657Z", - "start_time": "2021-06-22T13:47:48.894931Z" + "end_time": "2021-09-15T12:06:55.374165Z", + "start_time": "2021-09-15T12:06:55.367318Z" }, "init_cell": true }, "outputs": [], "source": [ - "def get_featspace_sisso(\n", - " df,\n", - " ops= ['add', 'sub', 'abs_diff', 'mult', 'div', 'exp', 'neg_exp', 'inv', 'sq', 'cb', \n", - " 'sqrt', 'cbrt', 'log', 'abs'],\n", - " cols=\"all\",\n", - " max_phi=2,\n", + "def get_feat_space_and_sisso_regressor(\n", + " selected_ops=[\"add\", \"abs_diff\", \"div\", \"sq\", \"exp\"],\n", + " selected_features = 'all',\n", + " max_rung=2,\n", " n_sis_select=50,\n", - " remove_double_divison=True,\n", - " max_dim=3,\n", - " n_residual=1,\n", + " n_dim=2,\n", + " n_residual=10,\n", " default=True,\n", "):\n", "\n", " if default:\n", - " phi_0, prop_label, prop_unit, prop, prop_test, task_sizes_train, task_sizes_test, leave_out_inds = generate_phi_0_from_csv(\n", + " \n", + " selected_ops = [\"add\", \"sub\", \"mult\", \"div\", \"abs_diff\", \"sq\", \"cb\", \"sqrt\", \"cbrt\", \"inv\", \"abs\"] \n", + " selected_features = 'all'\n", + " inputs = read_csv(\n", " df_train, \n", - " \"Class\",\n", + " prop_key=\"Class\",\n", " cols='all',\n", - " task_key=None,\n", + " max_rung=max_rung,\n", " leave_out_frac=0.0,\n", - " leave_out_inds=None,\n", - " max_rung=1\n", - " )\n", - " feat_space = generate_fs(\n", - " phi_0, \n", - " prop, \n", - " task_sizes_train, \n", - " [\"add\", \"sub\", \"mult\", \"div\", \"abs_diff\", \"sq\", \"cb\", \"sqrt\", \"cbrt\", \"inv\", \"abs\"], \n", - " [],\n", - " \"classification\",\n", - " 0, \n", - " n_sis_select\n", - " )\n", + " )\n", " else:\n", - " phi_0, prop_label, prop_unit, prop, prop_test, task_sizes_train, task_sizes_test, leave_out_inds = generate_phi_0_from_csv(\n", + " \n", + " inputs = read_csv(\n", " df_feat, \n", - " \"Class\", \n", - " cols=cols, \n", - " task_key=None, \n", - " leave_out_frac=0.0, \n", - " leave_out_inds=None, \n", - " max_rung=max_phi\n", - " )\n", - " feat_space = generate_fs(\n", - " phi_0, \n", - " prop, \n", - " task_sizes_train, \n", - " ops,\n", - " [],\n", - " \"classification\",\n", - " max_phi, \n", - " n_sis_select\n", - " )\n", + " prop_key=\"Class\",\n", + " cols=selected_features,\n", + " max_rung=max_rung,\n", + " leave_out_frac=0.0\n", + " )\n", + " \n", + " inputs.max_rung = max_rung\n", + " inputs.allowed_ops = selected_ops\n", + " inputs.n_sis_select = n_sis_select\n", + " inputs.n_dim = n_dim\n", + " inputs.n_residual = n_residual\n", + " inputs.n_model_store = 1\n", + " inputs.calc_type = \"classification\"\n", + " inputs.sample_ids_train = df_feat.index.tolist()\n", + " inputs.prop_train = df_feat[\"Class\"].to_numpy()\n", + " inputs.prop_test = np.array([])\n", + " inputs.prop_label = \"Class\"\n", + " inputs.task_names = [\"all_mats\"]\n", + "\n", + " \n", + " feat_space = FeatureSpace(inputs)\n", + " \n", + " sisso = SISSOClassifier(inputs, feat_space)\n", " \n", - " sisso = SISSOClassifier(\n", - " feat_space,\n", - " prop_label,\n", - " prop_unit,\n", - " prop,\n", - " prop_test,\n", - " task_sizes_train,\n", - " task_sizes_test,\n", - " leave_out_inds,\n", - " max_dim,\n", - " 10,\n", - " 10\n", - " )\n", - " return feat_space, sisso" + " return feat_space, sisso " ] }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 88, "metadata": { "ExecuteTime": { - "end_time": "2021-06-22T13:47:48.918838Z", - "start_time": "2021-06-22T13:47:48.903533Z" + "end_time": "2021-09-15T12:06:55.393669Z", + "start_time": "2021-09-15T12:06:55.375598Z" }, "init_cell": true }, @@ -498,18 +467,19 @@ " \n", " global feat_space\n", " global sisso\n", - " \n", + "\n", " try:\n", - " feat_space, sisso = get_featspace_sisso(\n", - " df = df_train,\n", - " ops = allowed_operations,\n", - " cols = selected_features,\n", - " max_phi = tier,\n", - " n_sis_select = feat_per_iter_selection.value,\n", - " remove_double_divison=True,\n", - " max_dim = dimension_selection.value,\n", - " n_residual = 1,\n", - " default = default)\n", + " feat_space, sisso = get_feat_space_and_sisso_regressor(\n", + " selected_ops = allowed_operations,\n", + " selected_features = selected_features,\n", + " max_rung = tier,\n", + " n_sis_select = feat_per_iter_selection.value,\n", + " n_dim = dimension_selection.value,\n", + " n_residual = 10,\n", + " default = default\n", + " )\n", + "\n", + "\n", " clear_output()\n", " if (dimension_selection.value>1):\n", " plot_button.disabled=False\n", @@ -544,11 +514,11 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 89, "metadata": { "ExecuteTime": { - "end_time": "2021-06-22T13:47:49.206736Z", - "start_time": "2021-06-22T13:47:48.920724Z" + "end_time": "2021-09-15T12:06:55.691702Z", + "start_time": "2021-09-15T12:06:55.395093Z" }, "init_cell": true, "scrolled": false @@ -557,7 +527,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3f97794bc3d14b55a26cb7b3abc9cd78", + "model_id": "a6173b0cb1eb48fd82eb15d5f54f1207", "version_major": 2, "version_minor": 0 }, @@ -617,7 +587,7 @@ "\n", "rung_selection = widgets.Dropdown(options=['PRM2020', 1,2,3], value=2,layout=thin_layout)\n", "rung_selection.value = 'PRM2020'\n", - "feat_per_iter_selection = widgets.BoundedIntText(value = 50, min=10, max=100, step=1, layout=thin_layout)\n", + "feat_per_iter_selection = widgets.BoundedIntText(value = 50, min=10, max=200, step=1, layout=thin_layout)\n", "dimension_selection = widgets.BoundedIntText(value = 2, min=1, max=4, step=1, layout = thin_layout)\n", "settings_box = widgets.VBox([\n", " widgets.Label(value='Settings:', layout=wide_layout),\n", @@ -667,7 +637,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.10" + "version": "3.7.3" } }, "nbformat": 4,