Commit 46c85f0b authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py

parent 61e81481
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
#from source.utils import *
from source.preprocess import *
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
```
%% Cell type:code id: tags:parameters
``` python
path = "../../Desktop/DLC_social_1/"
```
%% Cell type:markdown id: tags:
# Set up and design the project
%% Cell type:code id: tags:
``` python
with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:
Treatment_dict = pickle.load(handle)
```
%% Cell type:code id: tags:
``` python
#Which angles to compute?
bp_dict = {'B_Nose':['B_Left_ear','B_Right_ear'],
'B_Left_ear':['B_Nose','B_Right_ear','B_Center','B_Left_flank'],
'B_Right_ear':['B_Nose','B_Left_ear','B_Center','B_Right_flank'],
'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],
'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],
'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],
'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}
```
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1 = project(path=path,#Path where to find the required files
smooth_alpha=0.85, #Alpha value for exponentially weighted smoothing
distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
'B_Right_flank','B_Tail_base'],
ego=False,
angles=True,
connectivity=bp_dict,
arena='circular', #Type of arena used in the experiments
arena_dims=[380], #Dimensions of the arena. Just one if it's circular
video_format='.mp4',
table_format='.h5',
exp_conditions=Treatment_dict)
```
%%%% Output: stream
CPU times: user 2.81 s, sys: 853 ms, total: 3.66 s
Wall time: 1.28 s
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1_coords = DLC_social_1.run(verbose=True)
print(DLC_social_1_coords)
type(DLC_social_1_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
Computing angles...
Done!
Coordinates of 47 videos across 4 conditions
CPU times: user 9.61 s, sys: 593 ms, total: 10.2 s
Wall time: 10.6 s
%%%% Output: execute_result
source.preprocess.coordinates
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
ptest._type
```
%%%% Output: stream
CPU times: user 889 ms, sys: 60.2 ms, total: 949 ms
Wall time: 876 ms
%%%% Output: execute_result
'coords'
%% Cell type:code id: tags:
``` python
%%time
dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')
dtest._type
```
%%%% Output: stream
CPU times: user 539 ms, sys: 358 ms, total: 897 ms
Wall time: 897 ms
%%%% Output: execute_result
'dists'
%% Cell type:code id: tags:
``` python
%%time
atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')
atest._type
```
%%%% Output: stream
CPU times: user 132 ms, sys: 72.4 ms, total: 204 ms
Wall time: 204 ms
%%%% Output: execute_result
'angles'
%% Cell type:markdown id: tags:
# Visualization playground
%% Cell type:code id: tags:
``` python
#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)
```
%% Cell type:code id: tags:
``` python
#Plot animation of trajectory over time with different smoothings
#plt.plot(ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['x'],
# ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['y'], label='alpha=0.85')
#plt.xlabel('x')
#plt.ylabel('y')
#plt.title('Mouse Center Trajectory using different exponential smoothings')
#plt.legend()
#plt.show()
```
%% Cell type:markdown id: tags:
# Dimensionality reduction playground
%% Cell type:code id: tags:
``` python
#pca = ptest.pca(4, 1000)
```
%% Cell type:code id: tags:
``` python
#plt.scatter(*pca[0].T)
#plt.show()
```
%% Cell type:markdown id: tags:
# Preprocessing playground
%% Cell type:code id: tags:
``` python
mtest = merge_tables(DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'))#,
# DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),
# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))
```
%% Cell type:code id: tags:
``` python
#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)
```
%% Cell type:code id: tags:
``` python
pttest = mtest.preprocess(window_size=11, window_step=6, filter=None, standard_scaler=True)
pttest.shape
```
%%%% Output: execute_result
(117507, 11, 28)
%% Cell type:code id: tags:
``` python
#plt.plot(pttest[2,:,2], label='normal')
#plt.plot(pptest[2,:,2], label='gaussian')
#plt.legend()
#plt.show()
```
%% Cell type:markdown id: tags:
# Trained models playground
%% Cell type:markdown id: tags:
### Seq 2 seq Variational Auto Encoder
%% Cell type:code id: tags:
``` python
from datetime import datetime
import tensorflow.keras as k
import tensorflow as tf
```
%% Cell type:code id: tags:
``` python
NAME = 'Baseline_VAE_short_512_0=warmup_begin'
NAME = 'Baseline_VAEP_short_512_0=warmup_begin'
log_dir = os.path.abspath(
"logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
```
%% Cell type:code id: tags:
``` python
from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE, SEQ_2_SEQ_VAEP
```
%%%% Output: stream
Using TensorFlow backend.
%% Cell type:code id: tags:
``` python
encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()
```
%% Cell type:code id: tags:
``` python
encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,
loss='ELBO+MMD',
kl_warmup_epochs=0,
mmd_warmup_epochs=0).build()
# encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,
# loss='ELBO+MMD',
# kl_warmup_epochs=0,
# mmd_warmup_epochs=0).build()
```
%% Cell type:code id: tags:
``` python
# encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,
# loss='ELBO+MMD',
# kl_warmup_epochs=10,
# mmd_warmup_epochs=10).build()
encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,
loss='ELBO+MMD',
kl_warmup_epochs=0,
mmd_warmup_epochs=0).build()
```
%% Cell type:code id: tags:
``` python
#ae.summary()
#vae.summary()
#vaep.summary()
```
%% Cell type:code id: tags:
``` python
#from tensorflow.keras.utils import plot_model
#plot_model(vaep, show_shapes=True)
```
%% Cell type:code id: tags:
``` python
#plot_model(vae)
```
%% Cell type:code id: tags:
``` python
#np.random.shuffle(pttest)
pttrain = pttest[:-15000]
pttest = pttest[-15000:]
```
%% Cell type:code id: tags:
``` python
#lr_schedule = tf.keras.callbacks.LearningRateScheduler(
# lambda epoch: 1e-3 * 10**(epoch / 20))
```
%% Cell type:code id: tags:
``` python
tf.config.experimental_run_functions_eagerly(False)
history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,
validation_data=(pttest[:-1], pttest[:-1]),
callbacks=[tensorboard_callback])
# tf.config.experimental_run_functions_eagerly(False)
# history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,
# validation_data=(pttest[:-1], pttest[:-1]),
# callbacks=[tensorboard_callback])
```
%%%% Output: stream
Train on 102506 samples, validate on 14999 samples
WARNING:tensorflow:Model failed to serialize as JSON. Ignoring... ('Not JSON Serializable:', <tf.Variable 'kl_beta:0' shape=() dtype=float32, numpy=1.0>)
Epoch 1/100
19968/102506 [====>.........................] - ETA: 3:49 - loss: 1231520.0176 - mae: 0.6727 - kl_divergence: 44.2783 - kl_rate: 1.0000 - mmd: 0.0851 - mmd_rate: 1.0000
%% Cell type:code id: tags:
``` python
# tf.config.experimental_run_functions_eagerly(False)
# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,
# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),
# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
tf.config.experimental_run_functions_eagerly(False)
history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,
validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),
callbacks=[tensorboard_callback])
```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment