Commit acd0215f authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented VAEP hypermodel in hypermodels.py

parent b2e2a664
......@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
......@@ -10,12 +10,12 @@
"%autoreload 2\n",
"\n",
"import warnings\n",
'warnings.filterwarnings("ignore")',
],
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
......@@ -25,34 +25,42 @@
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from collections import defaultdict\n",
"from tqdm import tqdm_notebook as tqdm",
],
"from tqdm import tqdm_notebook as tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {"tags": ["parameters"]},
"execution_count": 3,
"metadata": {
"tags": [
"parameters"
]
},
"outputs": [],
"source": ['path = "../../Desktop/DLC_social_1/"'],
"source": [
"path = \"../../Desktop/DLC_social_1/\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": ["# Set up and design the project"],
"source": [
"# Set up and design the project"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:\n",
" Treatment_dict = pickle.load(handle)",
],
" Treatment_dict = pickle.load(handle)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
......@@ -63,14 +71,23 @@
" 'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],\n",
" 'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],\n",
" 'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],\n",
" 'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}",
],
" 'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.7 s, sys: 833 ms, total: 3.54 s\n",
"Wall time: 1.23 s\n"
]
}
],
"source": [
"%%time\n",
"DLC_social_1 = project(path=path,#Path where to find the required files\n",
......@@ -84,71 +101,171 @@
" arena_dims=[380], #Dimensions of the arena. Just one if it's circular\n",
" video_format='.mp4',\n",
" table_format='.h5',\n",
" exp_conditions=Treatment_dict)",
],
" exp_conditions=Treatment_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Run project"
]
},
{"cell_type": "markdown", "metadata": {}, "source": ["# Run project"]},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading trajectories...\n",
"Smoothing trajectories...\n",
"Computing distances...\n",
"Computing angles...\n",
"Done!\n",
"Coordinates of 47 videos across 4 conditions\n",
"CPU times: user 9.08 s, sys: 636 ms, total: 9.72 s\n",
"Wall time: 11.7 s\n"
]
},
{
"data": {
"text/plain": [
"source.preprocess.coordinates"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"DLC_social_1_coords = DLC_social_1.run(verbose=True)\n",
"print(DLC_social_1_coords)\n",
"type(DLC_social_1_coords)",
],
"type(DLC_social_1_coords)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generate coords"
]
},
{"cell_type": "markdown", "metadata": {}, "source": ["# Generate coords"]},
{
"cell_type": "code",
"execution_count": null,
"metadata": {"scrolled": true},
"outputs": [],
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 804 ms, sys: 66.5 ms, total: 870 ms\n",
"Wall time: 835 ms\n"
]
},
{
"data": {
"text/plain": [
"'coords'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')\n",
"ptest._type",
],
"ptest._type"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 563 ms, sys: 402 ms, total: 965 ms\n",
"Wall time: 971 ms\n"
]
},
{
"data": {
"text/plain": [
"'dists'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')\n",
"dtest._type",
],
"dtest._type"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 134 ms, sys: 80.2 ms, total: 214 ms\n",
"Wall time: 214 ms\n"
]
},
{
"data": {
"text/plain": [
"'angles'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')\n",
"atest._type",
],
"atest._type"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": ["# Visualization playground"],
"source": [
"# Visualization playground"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": ["#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)"],
"source": [
"#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
......@@ -160,90 +277,122 @@
"#plt.ylabel('y')\n",
"#plt.title('Mouse Center Trajectory using different exponential smoothings')\n",
"#plt.legend()\n",
"#plt.show()",
],
"#plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": ["# Dimensionality reduction playground"],
"source": [
"# Dimensionality reduction playground"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": ["#pca = ptest.pca(4, 1000)"],
"source": [
"#pca = ptest.pca(4, 1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": ["#plt.scatter(*pca[0].T)\n", "#plt.show()"],
"source": [
"#plt.scatter(*pca[0].T)\n",
"#plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": ["# Preprocessing playground"],
"source": [
"# Preprocessing playground"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"mtest = merge_tables(DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'))#,\n",
"# DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),\n",
"# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))",
],
"# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)"
],
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"(117507, 11, 28)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pttest = mtest.preprocess(window_size=11, window_step=6, filter=None, standard_scaler=True)\n",
"pttest.shape",
],
"pttest.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"#plt.plot(pttest[2,:,2], label='normal')\n",
"#plt.plot(pptest[2,:,2], label='gaussian')\n",
"#plt.legend()\n",
"#plt.show()",
],
"#plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": ["# Trained models playground"],
"source": [
"# Trained models playground"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": ["### Seq 2 seq Variational Auto Encoder"],
"source": [
"### Seq 2 seq Variational Auto Encoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from datetime import datetime\n",
"from tensorflow.keras import Input, Model, Sequential\n",
......@@ -255,130 +404,171 @@
"from tensorflow.keras.optimizers import Adam\n",
"from source.model_utils import *\n",
"import keras as k\n",
"import tensorflow as tf",
],
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"NAME = 'Baseline_VAE_short_512'\n",
"log_dir = os.path.abspath(\n",
' "logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))\n',
" \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
")\n",
"tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)",
],
"tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE, SEQ_2_SEQ_VAEP"
],
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": ["encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()"],
"source": [
"encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": ["encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()"],
"source": [
"encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"encoder, generator, vaep = SEQ_2_SEQ_VAEP(pttest.shape).build()"
],
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": ["#ae.summary()\n", "#vae.summary()\n", "#vaep.summary()"],
"source": [
"#ae.summary()\n",
"#vae.summary()\n",
"#vaep.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {"scrolled": false},
"execution_count": 26,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"#from tensorflow.keras.utils import plot_model\n",
"#plot_model(vaep, show_shapes=True)",
],
"#plot_model(vaep, show_shapes=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {"scrolled": false},
"execution_count": 27,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": ["#plot_model(vae)"],
"source": [
"#plot_model(vae)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"#np.random.shuffle(pttest)\n",
"pttrain = pttest[:-15000]\n",
"pttest = pttest[-15000:]",
],
"pttest = pttest[-15000:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"lr_schedule = tf.keras.callbacks.LearningRateScheduler(\n",
" lambda epoch: 1e-3 * 10**(epoch / 20))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 102506 samples, validate on 14999 samples\n",
"Epoch 1/100\n",
"102506/102506 [==============================] - 234s 2ms/sample - loss: 20766.2278 - mae: 0.3502 - val_loss: 13873.9038 - val_mae: 0.2742\n",
"Epoch 2/100\n",
" 9216/102506 [=>............................] - ETA: 3:10 - loss: 15431.4450 - mae: 0.2930"
]
}
],
"source": [
"#tf.config.experimental_run_functions_eagerly(False)\n",
"history = ae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n",
" validation_data=(pttest[:-1], pttest[:-1]),\n",
" callbacks=[tensorboard_callback])",
],
" callbacks=[tensorboard_callback, lr_schedule])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {"scrolled": true},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"#tf.config.experimental_run_functions_eagerly(False)\n",
"# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n",
"# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
"# callbacks=[tensorboard_callback])",
],
},
"# callbacks=[tensorboard_callback])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {"name": "ipython", "version": 3},
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10",
},
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 4,
"nbformat_minor": 4
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment