Commit dcd20ea4 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented VAEP hypermodel in hypermodels.py

parent b9f7567a
......@@ -10,8 +10,8 @@
"%autoreload 2\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
'warnings.filterwarnings("ignore")',
],
},
{
"cell_type": "code",
......@@ -25,28 +25,20 @@
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from collections import defaultdict\n",
"from tqdm import tqdm_notebook as tqdm"
]
"from tqdm import tqdm_notebook as tqdm",
],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"parameters"
]
},
"metadata": {"tags": ["parameters"]},
"outputs": [],
"source": [
"path = \"../../Desktop/DLC_social_1/\""
]
"source": ['path = "../../Desktop/DLC_social_1/"'],
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set up and design the project"
]
"source": ["# Set up and design the project"],
},
{
"cell_type": "code",
......@@ -55,8 +47,8 @@
"outputs": [],
"source": [
"with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:\n",
" Treatment_dict = pickle.load(handle)"
]
" Treatment_dict = pickle.load(handle)",
],
},
{
"cell_type": "code",
......@@ -71,8 +63,8 @@
" 'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],\n",
" 'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],\n",
" 'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],\n",
" 'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}"
]
" 'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}",
],
},
{
"cell_type": "code",
......@@ -92,16 +84,10 @@
" arena_dims=[380], #Dimensions of the arena. Just one if it's circular\n",
" video_format='.mp4',\n",
" table_format='.h5',\n",
" exp_conditions=Treatment_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Run project"
]
" exp_conditions=Treatment_dict)",
],
},
{"cell_type": "markdown", "metadata": {}, "source": ["# Run project"]},
{
"cell_type": "code",
"execution_count": null,
......@@ -111,28 +97,20 @@
"%%time\n",
"DLC_social_1_coords = DLC_social_1.run(verbose=True)\n",
"print(DLC_social_1_coords)\n",
"type(DLC_social_1_coords)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generate coords"
]
"type(DLC_social_1_coords)",
],
},
{"cell_type": "markdown", "metadata": {}, "source": ["# Generate coords"]},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {"scrolled": true},
"outputs": [],
"source": [
"%%time\n",
"ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')\n",
"ptest._type"
]
"ptest._type",
],
},
{
"cell_type": "code",
......@@ -142,8 +120,8 @@
"source": [
"%%time\n",
"dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')\n",
"dtest._type"
]
"dtest._type",
],
},
{
"cell_type": "code",
......@@ -153,24 +131,20 @@
"source": [
"%%time\n",
"atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')\n",
"atest._type"
]
"atest._type",
],
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualization playground"
]
"source": ["# Visualization playground"],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)"
]
"source": ["#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)"],
},
{
"cell_type": "code",
......@@ -186,41 +160,32 @@
"#plt.ylabel('y')\n",
"#plt.title('Mouse Center Trajectory using different exponential smoothings')\n",
"#plt.legend()\n",
"#plt.show()"
]
"#plt.show()",
],
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dimensionality reduction playground"
]
"source": ["# Dimensionality reduction playground"],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#pca = ptest.pca(4, 1000)"
]
"source": ["#pca = ptest.pca(4, 1000)"],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#plt.scatter(*pca[0].T)\n",
"#plt.show()"
]
"source": ["#plt.scatter(*pca[0].T)\n", "#plt.show()"],
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Preprocessing playground"
]
"source": ["# Preprocessing playground"],
},
{
"cell_type": "code",
......@@ -230,8 +195,8 @@
"source": [
"mtest = merge_tables(DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'))#,\n",
"# DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),\n",
"# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))"
]
"# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))",
],
},
{
"cell_type": "code",
......@@ -240,7 +205,7 @@
"outputs": [],
"source": [
"#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)"
]
],
},
{
"cell_type": "code",
......@@ -249,8 +214,8 @@
"outputs": [],
"source": [
"pttest = mtest.preprocess(window_size=11, window_step=6, filter=None, standard_scaler=True)\n",
"pttest.shape"
]
"pttest.shape",
],
},
{
"cell_type": "code",
......@@ -261,22 +226,18 @@
"#plt.plot(pttest[2,:,2], label='normal')\n",
"#plt.plot(pptest[2,:,2], label='gaussian')\n",
"#plt.legend()\n",
"#plt.show()"
]
"#plt.show()",
],
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Trained models playground"
]
"source": ["# Trained models playground"],
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Seq 2 seq Variational Auto Encoder"
]
"source": ["### Seq 2 seq Variational Auto Encoder"],
},
{
"cell_type": "code",
......@@ -294,8 +255,8 @@
"from tensorflow.keras.optimizers import Adam\n",
"from source.model_utils import *\n",
"import keras as k\n",
"import tensorflow as tf"
]
"import tensorflow as tf",
],
},
{
"cell_type": "code",
......@@ -305,10 +266,10 @@
"source": [
"NAME = 'Baseline_VAE_short_512'\n",
"log_dir = os.path.abspath(\n",
" \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
' "logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))\n',
")\n",
"tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
]
"tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)",
],
},
{
"cell_type": "code",
......@@ -317,25 +278,21 @@
"outputs": [],
"source": [
"from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE, SEQ_2_SEQ_VAEP"
]
],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()"
]
"source": ["encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()"],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()"
]
"source": ["encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()"],
},
{
"cell_type": "code",
......@@ -344,41 +301,31 @@
"outputs": [],
"source": [
"encoder, generator, vaep = SEQ_2_SEQ_VAEP(pttest.shape).build()"
]
],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#ae.summary()\n",
"#vae.summary()\n",
"#vaep.summary()"
]
"source": ["#ae.summary()\n", "#vae.summary()\n", "#vaep.summary()"],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {"scrolled": false},
"outputs": [],
"source": [
"#from tensorflow.keras.utils import plot_model\n",
"#plot_model(vaep, show_shapes=True)"
]
"#plot_model(vaep, show_shapes=True)",
],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {"scrolled": false},
"outputs": [],
"source": [
"#plot_model(vae)"
]
"source": ["#plot_model(vae)"],
},
{
"cell_type": "code",
......@@ -388,8 +335,8 @@
"source": [
"#np.random.shuffle(pttest)\n",
"pttrain = pttest[:-15000]\n",
"pttest = pttest[-15000:]"
]
"pttest = pttest[-15000:]",
],
},
{
"cell_type": "code",
......@@ -400,43 +347,38 @@
"#tf.config.experimental_run_functions_eagerly(False)\n",
"history = ae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n",
" validation_data=(pttest[:-1], pttest[:-1]),\n",
" callbacks=[tensorboard_callback])"
]
" callbacks=[tensorboard_callback])",
],
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {"scrolled": true},
"outputs": [],
"source": [
"#tf.config.experimental_run_functions_eagerly(False)\n",
"# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n",
"# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
"# callbacks=[tensorboard_callback])"
]
}
"# callbacks=[tensorboard_callback])",
],
},
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"name": "python3",
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"codemirror_mode": {"name": "ipython", "version": 3},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
"version": "3.6.10",
},
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 4,
}
......@@ -531,7 +531,10 @@ class SEQ_2_SEQ_MMVAE:
pass
# TODO next:
# - MERGE BatchNormalization layers in generator and _generator in SEQ_2_SEQ_VAE
# - VAE loss function (though this should be analysed later on taking the encodings into account)
# - Smaller input sliding window (10-15 frames)
# TODO:
# - cleaner implementation of reparameterization trick (sampling on input, outside the main graph)
# - KL / MMD warmup (Ladder Variational Autoencoders)
# - Gaussian Mixture + Categorical priors -> Deep Clustering
# - free bits paper
# - Attention mechanism for encoder / decoder (does it make sense?)
# - Transformer encoder/decoder (does it make sense?)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment