diff --git a/nn_regression.ipynb b/nn_regression.ipynb
index d8f1d311c69dcc3af5965be9db4355959d3737e9..00879dcee6a1bb6b0a8b34e0ff0083ee35632406 100644
--- a/nn_regression.ipynb
+++ b/nn_regression.ipynb
@@ -75,8 +75,8 @@
    "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-05-22T14:48:01.350451Z",
-     "start_time": "2020-05-22T14:47:59.404436Z"
+     "end_time": "2021-01-29T13:36:50.877631Z",
+     "start_time": "2021-01-29T13:36:50.869148Z"
     },
     "scrolled": true
    },
@@ -84,6 +84,10 @@
    "source": [
     "# Plotting\n",
     "%matplotlib inline\n",
+    "\n",
+    "import warnings\n",
+    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+    "\n",
     "import matplotlib.pyplot as plt\n",
     "import seaborn as sns\n",
     "\n",
@@ -94,16 +98,16 @@
     "# Keras for neural networks\n",
     "from keras.models import load_model\n",
     "from keras.layers import Input, Dense, Dropout\n",
-    "from keras.models import Model\n",
-    "\n",
-    "# pymatgen\n",
-    "from pymatgen.core.periodic_table import Element\n",
+    "from keras.models import Model, load_model\n",
     "\n",
     "# other packages\n",
     "import os\n",
     "import numpy as np\n",
     "from collections import Counter\n",
     "\n",
+    "# pymatgen\n",
+    "from pymatgen.core.periodic_table import Element\n",
+    "\n",
     "# sklearn\n",
     "from sklearn.metrics import r2_score, mean_absolute_error\n",
     "from sklearn.model_selection import KFold\n",
@@ -111,7 +115,10 @@
     "from sklearn.preprocessing import StandardScaler\n",
     "\n",
     "# pandas\n",
-    "import pandas as pd"
+    "import pandas as pd\n",
+    "\n",
+    "# json\n",
+    "import json"
    ]
   },
   {
@@ -356,8 +363,8 @@
    "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-05-22T14:08:25.380307Z",
-     "start_time": "2020-05-22T14:08:25.007228Z"
+     "end_time": "2021-01-29T13:36:51.267156Z",
+     "start_time": "2021-01-29T13:36:50.879337Z"
     },
     "scrolled": true
    },
@@ -391,8 +398,8 @@
    "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-05-22T14:08:25.710996Z",
-     "start_time": "2020-05-22T14:08:25.382690Z"
+     "end_time": "2021-01-29T13:36:51.693354Z",
+     "start_time": "2021-01-29T13:36:51.269399Z"
     }
    },
    "outputs": [],
@@ -415,8 +422,8 @@
    "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-05-22T14:08:26.547127Z",
-     "start_time": "2020-05-22T14:08:25.713822Z"
+     "end_time": "2021-01-29T13:36:53.555987Z",
+     "start_time": "2021-01-29T13:36:51.695774Z"
     }
    },
    "outputs": [],
@@ -465,8 +472,8 @@
    "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-05-22T14:08:26.640466Z",
-     "start_time": "2020-05-22T14:08:26.548884Z"
+     "end_time": "2021-01-29T13:36:53.646679Z",
+     "start_time": "2021-01-29T13:36:53.558131Z"
     }
    },
    "outputs": [],
@@ -488,8 +495,8 @@
    "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-05-22T14:08:26.645227Z",
-     "start_time": "2020-05-22T14:08:26.642010Z"
+     "end_time": "2021-01-29T13:36:53.651474Z",
+     "start_time": "2021-01-29T13:36:53.648126Z"
     }
    },
    "outputs": [],
@@ -516,8 +523,8 @@
    "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-05-22T14:08:26.793126Z",
-     "start_time": "2020-05-22T14:08:26.646886Z"
+     "end_time": "2021-01-29T13:36:53.812557Z",
+     "start_time": "2021-01-29T13:36:53.654271Z"
     }
    },
    "outputs": [],
@@ -544,8 +551,8 @@
    "execution_count": null,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2020-05-22T14:08:26.799383Z",
-     "start_time": "2020-05-22T14:08:26.795789Z"
+     "end_time": "2021-01-29T13:36:53.818325Z",
+     "start_time": "2021-01-29T13:36:53.815133Z"
     }
    },
    "outputs": [],
@@ -569,8 +576,8 @@
    "execution_count": null,
"metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:08:26.921060Z", - "start_time": "2020-05-22T14:08:26.801212Z" + "end_time": "2021-01-29T13:36:53.954336Z", + "start_time": "2021-01-29T13:36:53.820003Z" } }, "outputs": [], @@ -605,8 +612,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:08:26.928103Z", - "start_time": "2020-05-22T14:08:26.922685Z" + "end_time": "2021-01-29T13:36:53.963378Z", + "start_time": "2021-01-29T13:36:53.956231Z" } }, "outputs": [], @@ -678,8 +685,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:08:27.092710Z", - "start_time": "2020-05-22T14:08:26.930206Z" + "end_time": "2021-01-29T13:36:54.130309Z", + "start_time": "2021-01-29T13:36:53.965483Z" } }, "outputs": [], @@ -705,7 +712,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now we fit the model to the training data, while computing the mean squared error on the validation set for each epoch:" + "Now we fit the model to the training data, while computing the mean squared error on the validation set for each epoch. A pretrained model is provided, while you can also start from a fresh one (just replace \"reload=True\" with \"reload=False\"):" ] }, { @@ -713,17 +720,29 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:15:47.803232Z", - "start_time": "2020-05-22T14:08:27.094355Z" + "end_time": "2021-01-29T13:36:55.675381Z", + "start_time": "2021-01-29T13:36:54.131740Z" }, "scrolled": true }, "outputs": [], "source": [ - "history = model.fit(X_train, y_train,\n", - " validation_data = (X_val, y_val),\n", - " epochs=params[\"epochs\"],\n", - " batch_size=params[\"batch_size\"], verbose=True)" + "reload = True\n", + "\n", + "if reload:\n", + " model = load_model('./data/nn_regression/model.h5')\n", + " with open('./data/nn_regression/history.json') as json_file:\n", + " history = json.load(json_file)\n", + "else:\n", + " history = model.fit(X_train, y_train,\n", + " validation_data = (X_val, y_val),\n", + " epochs=params[\"epochs\"],\n", + " batch_size=params[\"batch_size\"], verbose=True)\n", + " # save new model\n", + " model.save('./data/nn_regression/new_model.h5')\n", + " with open('./data/nn_regression/new_history.json', 'w') as outfile:\n", + " json.dump(history.history, outfile)\n", + " " ] }, { @@ -738,8 +757,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:15:47.945244Z", - "start_time": "2020-05-22T14:15:47.805420Z" + "end_time": "2021-01-29T13:36:55.892054Z", + "start_time": "2021-01-29T13:36:55.677104Z" } }, "outputs": [], @@ -747,8 +766,8 @@ "import matplotlib.pyplot as plt\n", "\n", "# summarize history for loss: A plot of loss on the training and validation datasets over training epochs.\n", - "plt.plot(history.history['loss'])\n", - "plt.plot(history.history['val_loss'])\n", + "plt.plot(history['loss'])\n", + "plt.plot(history['val_loss'])\n", "plt.title('model loss')\n", "plt.ylabel('loss')\n", "plt.xlabel('epoch')\n", @@ -783,8 +802,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:15:53.924363Z", - "start_time": "2020-05-22T14:15:47.947273Z" + "end_time": "2021-01-29T13:37:02.714771Z", + "start_time": "2021-01-29T13:36:55.894010Z" } }, "outputs": [], @@ -845,8 +864,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:15:56.919219Z", - "start_time": "2020-05-22T14:15:53.926422Z" + "end_time": "2021-01-29T13:37:06.094492Z", + "start_time": 
"2021-01-29T13:37:02.716800Z" } }, "outputs": [], @@ -892,8 +911,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:15:56.923919Z", - "start_time": "2020-05-22T14:15:56.921595Z" + "end_time": "2021-01-29T13:37:06.098758Z", + "start_time": "2021-01-29T13:37:06.096204Z" } }, "outputs": [], @@ -906,8 +925,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2020-05-22T14:16:02.573238Z", - "start_time": "2020-05-22T14:15:56.925596Z" + "end_time": "2021-01-29T13:37:12.129710Z", + "start_time": "2021-01-29T13:37:06.100997Z" }, "scrolled": true }, diff --git a/setup.py b/setup.py index de85809933a517d13f09d6f0622f0657fb1a94b1..33e5bc18ff2cee2afeb4c661949173ebe462346a 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,9 @@ import json from setuptools import setup, find_packages +# Succesfull test of this tutorial was conducted for: +# 'tensorflow==1.13.1', 'keras==2.2.4', 'numpy==1.16.4', 'scipy==1.1.0', 'matplotlib', 'pandas', 'seaborn', 'pymatgen==2020.3.13', 'sklearn' + with open('metainfo.json') as file: metainfo = json.load(file) @@ -13,5 +16,5 @@ setup( description=metainfo['title'], long_description=metainfo['description'], packages=find_packages(), - install_requires=['tensorflow==1.13.1', 'keras==2.2.4', 'numpy==1.16.4', 'scipy==1.1.0', 'matplotlib', 'pandas', 'seaborn', 'pymatgen==2020.3.13', 'sklearn'], + install_requires=['tensorflow', 'keras', 'numpy', 'scipy', 'matplotlib', 'pandas', 'seaborn', 'pymatgen', 'sklearn'], )