Commit d2096481 authored by lucas_miranda

Implemented version of SEQ_2_SEQ VAE based on tensorflow_probability

parent b99b3ace
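At the core of the change: the hand-rolled reparameterization sampler (the `sampling` helper removed from the models module in the diff below) gives way to tensorflow_probability distribution objects, whose `sample()` applies the reparameterization trick natively. A minimal before/after sketch with illustrative shapes (not code from this commit):

``` python
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import backend as K

tfd = tfp.distributions

# Illustrative (batch, latent) Gaussian parameters
z_mean = tf.zeros((8, 16))
z_log_sigma = tf.zeros((8, 16))

# Before: manual reparameterization trick
epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=1.0)
z_manual = z_mean + K.exp(z_log_sigma) * epsilon

# After: a tfp distribution object; sample() is reparameterized internally,
# and KL terms against a prior can be computed analytically
q_z = tfd.Normal(loc=z_mean, scale=tf.exp(z_log_sigma))
z_tfp = q_z.sample()
```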
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
#from source.utils import *
from source.preprocess import *
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
```
%% Cell type:code id: tags:parameters
``` python
path = "../../Desktop/DLC_social_1/"
```
%% Cell type:markdown id: tags:
# Set up and design the project
%% Cell type:code id: tags:
``` python
with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:
    Treatment_dict = pickle.load(handle)
```
%% Cell type:code id: tags:
``` python
#Which angles to compute?
bp_dict = {'B_Nose':['B_Left_ear','B_Right_ear'],
           'B_Left_ear':['B_Nose','B_Right_ear','B_Center','B_Left_flank'],
           'B_Right_ear':['B_Nose','B_Left_ear','B_Center','B_Right_flank'],
           'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],
           'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],
           'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],
           'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}
```
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1 = project(path=path,            # Path where to find the required files
                       smooth_alpha=0.85,    # Alpha value for exponentially weighted smoothing
                       distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
                                  'B_Right_flank','B_Tail_base'],
                       ego=False,
                       angles=True,
                       connectivity=bp_dict,
                       arena='circular',     # Type of arena used in the experiments
                       arena_dims=[380],     # Dimensions of the arena. Just one if it's circular
                       video_format='.mp4',
                       table_format='.h5',
                       exp_conditions=Treatment_dict)
```
%%%% Output: stream
CPU times: user 2.59 s, sys: 818 ms, total: 3.41 s
Wall time: 1.1 s
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1_coords = DLC_social_1.run(verbose=True)
print(DLC_social_1_coords)
type(DLC_social_1_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
ptest._type
```
%% Cell type:code id: tags:
``` python
%%time
dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')
dtest._type
```
%% Cell type:code id: tags:
``` python
%%time
atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')
atest._type
```
%% Cell type:markdown id: tags:
# Visualization playground
%% Cell type:code id: tags:
``` python
#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)
```
%% Cell type:code id: tags:
``` python
#Plot animation of trajectory over time with different smoothings
#plt.plot(ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['x'],
#         ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['y'], label='alpha=0.85')
#plt.xlabel('x')
#plt.ylabel('y')
#plt.title('Mouse Center Trajectory using different exponential smoothings')
#plt.legend()
#plt.show()
```
%% Cell type:markdown id: tags:
# Dimensionality reduction playground
%% Cell type:code id: tags:
``` python
#pca = ptest.pca(4, 1000)
```
%% Cell type:code id: tags:
``` python
#plt.scatter(*pca[0].T)
#plt.show()
```
%% Cell type:markdown id: tags:
# Preprocessing playground
%% Cell type:code id: tags:
``` python
mtest = merge_tables(DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'))#,
#                    DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),
#                    DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))
```
%% Cell type:code id: tags:
``` python
#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)
```
%% Cell type:code id: tags:
``` python
pttest = mtest.preprocess(window_size=11, window_step=6, filter=None, standard_scaler=True)
pttest.shape
```
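For orientation: `preprocess(window_size=11, window_step=6, ...)` returns a 3-D array of overlapping windows (instances × window length × features), which is what the sequence models below consume. A minimal sketch of that kind of windowing, assuming a plain (frames × features) input; `sliding_windows` is a hypothetical helper, not the deepof API:

``` python
import numpy as np

def sliding_windows(X: np.ndarray, window_size: int, window_step: int) -> np.ndarray:
    """Cut a (frames, features) array into overlapping windows,
    returning an array of shape (n_windows, window_size, features)."""
    starts = range(0, X.shape[0] - window_size + 1, window_step)
    return np.stack([X[s : s + window_size] for s in starts])

X = np.random.rand(100, 4)      # 100 frames, 4 features
W = sliding_windows(X, 11, 6)
print(W.shape)                  # (15, 11, 4)
```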
%% Cell type:code id: tags:
``` python
#plt.plot(pttest[2,:,2], label='normal')
#plt.plot(pptest[2,:,2], label='gaussian')
#plt.legend()
#plt.show()
```
%% Cell type:markdown id: tags:
# Trained models playground
%% Cell type:markdown id: tags:
### Seq 2 seq Variational Auto Encoder
%% Cell type:code id: tags:
``` python
from datetime import datetime
import tensorflow.keras as k
import tensorflow as tf
```
%% Cell type:code id: tags:
``` python
import os  # used below; may already be in scope via the star import above

NAME = 'Baseline_VAE_short_512_10=warmup_begin'
log_dir = os.path.abspath(
    "logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
```
%% Cell type:code id: tags:
``` python
from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE, SEQ_2_SEQ_VAEP, SEQ_2_SEQ_MMVAEP
```
%% Cell type:code id: tags:
``` python
encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()
ae.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
ae.summary()
```
%% Cell type:code id: tags:
``` python
k.backend.clear_session()
encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,
                                                                                 loss='ELBO+MMD',
                                                                                 kl_warmup_epochs=10,
                                                                                 mmd_warmup_epochs=10).build()
#vae.build(pttest.shape)
```
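The `kl_warmup_epochs`/`mmd_warmup_epochs` arguments suggest that the KL and MMD loss weights are annealed from 0 to 1 over the first training epochs, which keeps the regularizers from collapsing the latent space early on. A minimal sketch of such a warm-up callback, assuming the weight lives in a `tf.Variable` referenced by the loss (the callbacks actually returned by `build()` may be implemented differently):

``` python
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class WarmupWeight(Callback):
    """Linearly anneal a loss weight from 0 to 1 over the first epochs."""

    def __init__(self, weight: tf.Variable, warmup_epochs: int):
        super().__init__()
        self.weight = weight
        self.warmup_epochs = warmup_epochs

    def on_epoch_begin(self, epoch, logs=None):
        # epoch is 0-based; full weight is reached after `warmup_epochs` epochs
        self.weight.assign(min(1.0, epoch / self.warmup_epochs))

# The variable would be captured by the KL/MMD term of the model loss,
# and the callback passed to model.fit(callbacks=[...]):
kl_weight = tf.Variable(0.0, trainable=False, dtype=tf.float32)
kl_warmup = WarmupWeight(kl_weight, warmup_epochs=10)
```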
%% Cell type:code id: tags:
``` python
vae.summary()
```
%% Cell type:code id: tags:
``` python
encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,
                                                                                   loss='ELBO+MMD',
                                                                                   kl_warmup_epochs=10,
                                                                                   mmd_warmup_epochs=10).build()
vaep.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
vaep.summary()
```
%% Cell type:code id: tags:
``` python
encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_MMVAEP(pttest.shape,
                                                                                        loss='ELBO+MMD',
                                                                                        number_of_components=2,
                                                                                        kl_warmup_epochs=10,
                                                                                        mmd_warmup_epochs=10).build()
gmvaep.build(pttest.shape)
```
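`number_of_components=2` points at a Gaussian-mixture latent prior. As a rough sketch of what such a prior looks like in tensorflow_probability (illustrative sizes, not the deepof implementation):

``` python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

latent_dim = 16     # illustrative
n_components = 2

# Mixture-of-Gaussians prior: equal weights, randomly placed components
gmm_prior = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=tf.zeros(n_components)),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=tf.random.normal([n_components, latent_dim]),
        scale_diag=tf.ones([n_components, latent_dim]),
    ),
)

z = gmm_prior.sample(5)   # (5, latent_dim) samples from the mixture
print(z.shape)
```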
%% Cell type:code id: tags:
``` python
from tensorflow.keras.utils import plot_model
plot_model(gmvaep, show_shapes=True)
```
%% Cell type:code id: tags:
``` python
#np.random.shuffle(pttest)
pttrain = pttest[:-15000]
pttest = pttest[-15000:]
pttrain = pttrain[:15000]
```
%% Cell type:code id: tags:
``` python
#lr_schedule = tf.keras.callbacks.LearningRateScheduler(
#    lambda epoch: 1e-3 * 10**(epoch / 20))
```
%% Cell type:code id: tags:
``` python
# tf.config.experimental_run_functions_eagerly(False)
history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=2, batch_size=512, verbose=1,
                  validation_data=(pttest[:-1], pttest[:-1]),
                  callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
%% Cell type:code id: tags:
``` python
# tf.config.experimental_run_functions_eagerly(False)
# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=500, batch_size=512, verbose=1,
#                    validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),
#                    callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
``` diff
@@ -7,20 +7,9 @@ import tensorflow as tf
 import tensorflow_probability as tfp

 tfd = tfp.distributions
+tfpl = tfp.layers

 # Helper functions
-def sampling(args, epsilon_std=1.0, number_of_components=1, categorical=None):
-    z_mean, z_log_sigma = args
-    if number_of_components == 1:
-        epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
-        return z_mean + K.exp(z_log_sigma) * epsilon
-    else:
-        # Implement mixture of gaussians encoding and sampling
-        pass
 def compute_kernel(x, y):
     x_size = K.shape(x)[0]
     y_size = K.shape(y)[0]
@@ -120,35 +109,20 @@ class UncorrelatedFeaturesConstraint(Constraint):
         return self.weightage * self.uncorrelated_feature(x)

-class KLDivergenceLayer(Layer):
-    """ Identity transform layer that adds KL divergence
-    to the final model loss.
-    """
-
-    def __init__(self, beta=1.0, *args, **kwargs):
+class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
+    def __init__(self, *args, **kwargs):
         self.is_placeholder = True
-        self.beta = beta
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)

-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({"beta": self.beta})
-        return config
-
-    def call(self, inputs, **kwargs):
-        mu, log_var = inputs
-        KL_batch = (
-            -0.5
-            * self.beta
+    def call(self, distribution_a):
+        kl_batch = self._regularizer(distribution_a)
+        self.add_loss(kl_batch, inputs=[distribution_a])
+        self.add_metric(
+            kl_batch, aggregation="mean", name="kl_divergence",
```
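For context on the rewrite above: the layer now wraps a tfp distribution output instead of raw `(mu, log_var)` tensors, delegating the KL computation to `tfpl.KLDivergenceAddLoss`. A hypothetical usage sketch; the latent size, prior, and wiring are illustrative assumptions, not code from this commit:

``` python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

latent_dim = 16   # illustrative
prior = tfd.Independent(
    tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0),
    reinterpreted_batch_ndims=1,
)

inputs = tf.keras.Input(shape=(32,))
params = tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(latent_dim))(inputs)
q_z = tfpl.IndependentNormal(latent_dim)(params)

# Acts as an identity on the distribution while registering KL(q || prior)
# as a model loss and logging it as the "kl_divergence" metric
z = KLDivergenceLayer(prior, weight=1.0)(q_z)
```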