Commit ea24a625 authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored

parent 391b7da8
......@@ -317,27 +317,27 @@
"metadata": {},
"outputs": [],
"source": [
"pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)"
"#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(704997, 51, 28)"
"(70500, 51, 28)"
]
},
"execution_count": 12,
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pttest = mtest.preprocess(window_size=51, filter=None)\n",
"pttest = mtest.preprocess(window_size=51, window_step=10, filter=None)\n",
"pttest.shape"
]
},
......@@ -380,136 +380,71 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE"
"from datetime import datetime\n",
"from tensorflow.keras import Input, Model, Sequential\n",
"from tensorflow.keras.constraints import UnitNorm\n",
"from tensorflow.keras.layers import Bidirectional, Dense, Dropout\n",
"from tensorflow.keras.layers import Lambda, LSTM\n",
"from tensorflow.keras.layers import RepeatVector, TimeDistributed\n",
"from tensorflow.keras.losses import Huber\n",
"from tensorflow.keras.optimizers import Adam\n",
"from source.model_utils import *\n",
"import keras as k\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()"
"NAME = 'Baseline_AE'\n",
"log_dir = os.path.abspath(\n",
" \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
")\n",
"tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 73,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"SEQ_2_SEQ_AE\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"SEQ_2_SEQ_Encoder (Sequentia (None, 32) 1396640 \n",
"_________________________________________________________________\n",
"SEQ_2_SEQ_Decoder (Sequentia (None, 51, 28) 2392508 \n",
"=================================================================\n",
"Total params: 3,774,652\n",
"Trainable params: 3,774,652\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"outputs": [],
"source": [
"ae.summary()"
"from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE"
]
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()"
"encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()"
]
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"SEQ_2_SEQ_VAE\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_14 (InputLayer) [(None, 51, 28)] 0 \n",
"__________________________________________________________________________________________________\n",
"conv1d_26 (Conv1D) (None, 51, 256) 36096 input_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_102 (Bidirectiona (None, 51, 512) 1050624 conv1d_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_103 (Bidirectiona (None, 128) 295424 bidirectional_102[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_126 (Dense) (None, 64) 8256 bidirectional_103[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_25 (Dropout) (None, 64) 0 dense_126[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_127 (Dense) (None, 64) 4160 dropout_25[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_128 (Dense) (None, 32) 2080 dense_127[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_129 (Dense) (None, 32) 1056 dense_128[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_130 (Dense) (None, 32) 1056 dense_128[0][0] \n",
"__________________________________________________________________________________________________\n",
"kl_divergence_layer_12 (KLDiver [(None, 32), (None, 0 dense_129[0][0] \n",
" dense_130[0][0] \n",
"__________________________________________________________________________________________________\n",
"lambda_12 (Lambda) (None, 32) 0 kl_divergence_layer_12[0][0] \n",
" kl_divergence_layer_12[0][1] \n",
"__________________________________________________________________________________________________\n",
"mm_discrepancy_layer_12 (MMDisc (None, 32) 0 lambda_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_transpose_74 (DenseTransp (None, 64) 2144 mm_discrepancy_layer_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_transpose_75 (DenseTransp (None, 64) 4224 dense_transpose_74[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_transpose_76 (DenseTransp (None, 128) 8384 dense_transpose_75[0][0] \n",
"__________________________________________________________________________________________________\n",
"repeat_vector_24 (RepeatVector) (None, 51, 128) 0 dense_transpose_76[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_104 (Bidirectiona (None, 51, 512) 788480 repeat_vector_24[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_105 (Bidirectiona (None, 51, 512) 1574912 bidirectional_104[0][0] \n",
"__________________________________________________________________________________________________\n",
"time_distributed_24 (TimeDistri (None, 51, 28) 14364 bidirectional_105[0][0] \n",
"==================================================================================================\n",
"Total params: 3,776,764\n",
"Trainable params: 3,776,764\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"outputs": [],
"source": [
"vae.summary()"
"encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"#tf.config.experimental_run_functions_eagerly(False)\n",
"ptrain = pttest[np.random.choice(pttest.shape[0], 1000, replace=False), :, :]\n",
"history = vae.fit(ptrain, ptrain, epochs=50, batch_size=batch_size, verbose=1)"
"pttrain = pttest[:-1500]\n",
"pttest = pttest[-1500:]"
]
},
{
......@@ -518,27 +453,19 @@
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"#plt.plot(history.history['mae'], label='Huber + MMD mae')\n",
"plt.plot(history.history['mae'], label='Huber + KL mae')\n",
"\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 69000 samples, validate on 1500 samples\n",
"Epoch 1/50\n"
]
}
],
"source": [
"#Huber loss + MMD/ELBO in training data\n",
"plt.plot(pttest[:2000,0,0], label='data')\n",
"plt.plot(vae.predict(pttest[:2000])[:,0,0], label='MMD reconstruction')\n",
"\n",
"plt.legend()\n",
"plt.show()"
"#tf.config.experimental_run_functions_eagerly(False)\n",
"history = ae.fit(pttrain, pttrain, epochs=50, batch_size=256, verbose=1, validation_data=(pttest, pttest))"
]
},
{
......
%% Cell type:code id: tags:
``` python
# Reload local modules (e.g. source/) automatically on every cell execution,
# so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import warnings
# Silence library warnings to keep notebook output readable.
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
#from source.utils import *
from source.preprocess import *
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
```
%% Cell type:markdown id: tags:
# Set up and design the project
%% Cell type:code id: tags:
``` python
# Load the experimental-condition mapping for the DLC_social_1 dataset.
# NOTE(review): hardcoded machine-specific relative path — consider a configurable DATA_DIR.
# NOTE(review): pickle.load executes arbitrary code; safe only because this file is locally produced.
with open('../../Desktop/DLC_social_1/DLC_social_1_exp_conditions.pickle', 'rb') as handle:
    Treatment_dict = pickle.load(handle)
```
%% Cell type:code id: tags:
``` python
# Which angles to compute?
# Connectivity map for the tracked body parts of the 'B_' (black) mouse:
# each key is a body part, each value the list of neighbouring parts used
# to form the angle triplets (passed to project(...) as `connectivity`).
bp_dict = {'B_Nose':['B_Left_ear','B_Right_ear'],
'B_Left_ear':['B_Nose','B_Right_ear','B_Center','B_Left_flank'],
'B_Right_ear':['B_Nose','B_Left_ear','B_Center','B_Right_flank'],
'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],
'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],
'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],
'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}
```
%% Cell type:code id: tags:
``` python
%%time
# Configure the social-behavior tracking project: which features to derive
# (distances, angles), the arena geometry, and the experimental conditions.
# NOTE(review): hardcoded machine-specific path — consider a configurable DATA_DIR.
DLC_social_1 = project(path='../../Desktop/DLC_social_1/',#Path where to find the required files
smooth_alpha=0.85, #Alpha value for exponentially weighted smoothing
distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
'B_Right_flank','B_Tail_base'], #Body parts for which to compute pairwise distances
ego=False, #No egocentric alignment
angles=True, #Compute angles for the triplets defined in bp_dict
connectivity=bp_dict,
arena='circular', #Type of arena used in the experiments
arena_dims=[380], #Dimensions of the arena. Just one if it's circular
video_format='.mp4',
table_format='.h5',
exp_conditions=Treatment_dict)
```
%%%% Output: stream
CPU times: user 2.76 s, sys: 847 ms, total: 3.61 s
Wall time: 1.23 s
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
# Execute the full preprocessing pipeline (per the logged output: loading,
# smoothing, distances, angles) and keep the resulting coordinates object.
DLC_social_1_coords = DLC_social_1.run(verbose=True)
print(DLC_social_1_coords)
# Display the concrete class of the result (source.preprocess.coordinates).
type(DLC_social_1_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
Computing angles...
Done!
Coordinates of 47 videos across 4 conditions
CPU times: user 10 s, sys: 684 ms, total: 10.7 s
Wall time: 11.1 s
%%%% Output: execute_result
source.preprocess.coordinates
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
# Extract centered Cartesian coordinates for the first 10 minutes of each video.
ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
# NOTE(review): accessing a private attribute; a public accessor would be preferable.
ptest._type
```
%%%% Output: stream
CPU times: user 928 ms, sys: 76.5 ms, total: 1 s
Wall time: 943 ms
%%%% Output: execute_result
'coords'
%% Cell type:code id: tags:
``` python
%%time
# Extract pairwise body-part distances over the same 10-minute window.
dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')
# NOTE(review): accessing a private attribute; a public accessor would be preferable.
dtest._type
```
%%%% Output: stream
CPU times: user 579 ms, sys: 434 ms, total: 1.01 s
Wall time: 1.04 s
%%%% Output: execute_result
'dists'
%% Cell type:code id: tags:
``` python
%%time
# Extract body-part angles (in degrees) over the same 10-minute window.
atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')
# NOTE(review): accessing a private attribute; a public accessor would be preferable.
atest._type
```
%%%% Output: stream
CPU times: user 179 ms, sys: 110 ms, total: 289 ms
Wall time: 310 ms
%%%% Output: execute_result
'angles'
%% Cell type:markdown id: tags:
# Visualization playground
%% Cell type:code id: tags:
``` python
ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)
```
%% Cell type:code id: tags:
``` python
# Trajectory of the black mouse's center over the first 5000 frames,
# under the exponential smoothing configured for the project (alpha=0.85).
center_traj = ptest['Day2Test13DLC']['B_Center'].iloc[:5000]
plt.plot(center_traj['x'], center_traj['y'], label='alpha=0.85')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Mouse Center Trajectory using different exponential smoothings')
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
# Dimensionality reduction playground
%% Cell type:code id: tags:
``` python
pca = ptest.pca(4, 1000)
```
%% Cell type:code id: tags:
``` python
# Scatter the first element of the pca result, unpacking its transposed
# columns as x/y (so only the first two components are shown).
plt.scatter(*pca[0].T)
plt.show()
```
%% Cell type:markdown id: tags:
# Preprocessing playground
%% Cell type:code id: tags:
``` python
# Build the model-input table from centered polar coordinates only;
# distances and angles are currently excluded (kept commented for reference).
mtest = merge_tables(DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'))#,
# DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),
# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))
```
%% Cell type:code id: tags:
``` python
pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)
#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)
```
%% Cell type:code id: tags:
``` python
pttest = mtest.preprocess(window_size=51, filter=None)
pttest = mtest.preprocess(window_size=51, window_step=10, filter=None)
pttest.shape
```
%%%% Output: execute_result
(704997, 51, 28)
(70500, 51, 28)
%% Cell type:code id: tags:
``` python
# Compare one feature of one window with and without Gaussian filtering.
# NOTE(review): the cell producing `pptest` is commented out in this revision,
# so this cell will fail on a fresh Restart & Run All unless it is re-enabled.
plt.plot(pttest[2,:,2], label='normal')
plt.plot(pptest[2,:,2], label='gaussian')
plt.legend()
plt.show()
```
%%%% Output: display_data
%% Cell type:markdown id: tags:
# Trained models playground
%% Cell type:markdown id: tags:
### Seq 2 seq Variational Auto Encoder
%% Cell type:code id: tags:
``` python
from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE
from datetime import datetime
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.layers import Bidirectional, Dense, Dropout
from tensorflow.keras.layers import Lambda, LSTM
from tensorflow.keras.layers import RepeatVector, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Adam
from source.model_utils import *
import keras as k
import tensorflow as tf
```
%% Cell type:code id: tags:
``` python
encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()
NAME = 'Baseline_AE'
log_dir = os.path.abspath(
"logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
```
%% Cell type:code id: tags:
``` python
ae.summary()
from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE
```
%%%% Output: stream
Model: "SEQ_2_SEQ_AE"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
SEQ_2_SEQ_Encoder (Sequentia (None, 32) 1396640
_________________________________________________________________
SEQ_2_SEQ_Decoder (Sequentia (None, 51, 28) 2392508
=================================================================
Total params: 3,774,652
Trainable params: 3,774,652
Non-trainable params: 0
_________________________________________________________________
%% Cell type:code id: tags:
``` python
encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()
encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()
```
%% Cell type:code id: tags:
``` python
vae.summary()
encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()
```
%%%% Output: stream
Model: "SEQ_2_SEQ_VAE"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_14 (InputLayer) [(None, 51, 28)] 0
__________________________________________________________________________________________________
conv1d_26 (Conv1D) (None, 51, 256) 36096 input_14[0][0]
__________________________________________________________________________________________________
bidirectional_102 (Bidirectiona (None, 51, 512) 1050624 conv1d_26[0][0]
__________________________________________________________________________________________________
bidirectional_103 (Bidirectiona (None, 128) 295424 bidirectional_102[0][0]
__________________________________________________________________________________________________
dense_126 (Dense) (None, 64) 8256 bidirectional_103[0][0]
__________________________________________________________________________________________________
dropout_25 (Dropout) (None, 64) 0 dense_126[0][0]
__________________________________________________________________________________________________
dense_127 (Dense) (None, 64) 4160 dropout_25[0][0]
__________________________________________________________________________________________________
dense_128 (Dense) (None, 32) 2080 dense_127[0][0]
__________________________________________________________________________________________________
dense_129 (Dense) (None, 32) 1056 dense_128[0][0]
__________________________________________________________________________________________________
dense_130 (Dense) (None, 32) 1056 dense_128[0][0]
__________________________________________________________________________________________________
kl_divergence_layer_12 (KLDiver [(None, 32), (None, 0 dense_129[0][0]
dense_130[0][0]
__________________________________________________________________________________________________
lambda_12 (Lambda) (None, 32) 0 kl_divergence_layer_12[0][0]
kl_divergence_layer_12[0][1]
__________________________________________________________________________________________________
mm_discrepancy_layer_12 (MMDisc (None, 32) 0 lambda_12[0][0]
__________________________________________________________________________________________________
dense_transpose_74 (DenseTransp (None, 64) 2144 mm_discrepancy_layer_12[0][0]
__________________________________________________________________________________________________
dense_transpose_75 (DenseTransp (None, 64) 4224 dense_transpose_74[0][0]
__________________________________________________________________________________________________
dense_transpose_76 (DenseTransp (None, 128) 8384 dense_transpose_75[0][0]
__________________________________________________________________________________________________
repeat_vector_24 (RepeatVector) (None, 51, 128) 0 dense_transpose_76[0][0]
__________________________________________________________________________________________________
bidirectional_104 (Bidirectiona (None, 51, 512) 788480 repeat_vector_24[0][0]
__________________________________________________________________________________________________
bidirectional_105 (Bidirectiona (None, 51, 512) 1574912 bidirectional_104[0][0]
__________________________________________________________________________________________________
time_distributed_24 (TimeDistri (None, 51, 28) 14364 bidirectional_105[0][0]
==================================================================================================
Total params: 3,776,764
Trainable params: 3,776,764
Non-trainable params: 0
__________________________________________________________________________________________________
%% Cell type:code id: tags:
``` python
#tf.config.experimental_run_functions_eagerly(False)
ptrain = pttest[np.random.choice(pttest.shape[0], 1000, replace=False), :, :]
history = vae.fit(ptrain, ptrain, epochs=50, batch_size=batch_size, verbose=1)
pttrain = pttest[:-1500]
pttest = pttest[-1500:]
```
%% Cell type:code id: tags:
``` python
#plt.plot(history.history['mae'], label='Huber + MMD mae')
plt.plot(history.history['mae'], label='Huber + KL mae')
plt.legend()
plt.show()
#tf.config.experimental_run_functions_eagerly(False)
history = ae.fit(pttrain, pttrain, epochs=50, batch_size=256, verbose=1, validation_data=(pttest, pttest))
```
%% Cell type:code id: tags:
``` python
#Huber loss + MMD/ELBO in training data
plt.plot(pttest[:2000,0,0], label='data')
plt.plot(vae.predict(pttest[:2000])[:,0,0], label='MMD reconstruction')
%%%% Output: stream
plt.legend()
plt.show()
```
Train on 69000 samples, validate on 1500 samples
Epoch 1/50
%% Cell type:code id: tags:
``` python
```
......
......@@ -280,4 +280,5 @@ class SEQ_2_SEQ_MMVAE:
# - Initial Convnet switch
# - Bidirectional LSTM switches
# - Change LSTMs for GRU
# - VAE loss function (though this should be analysed later on taking the encodings into account)
\ No newline at end of file
# - VAE loss function (though this should be analysed later on taking the encodings into account)
# - Tied/Untied weights!
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment