Commit 6cc67976 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented version of SEQ_2_SEQ VAE based on tensorflow_probability, and...

Implemented version of SEQ_2_SEQ VAE based on tensorflow_probability, and SEQ_2_SEQ_GMVAE draft using tensorflow_probability
parent d2096481
......@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -30,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"tags": [
"parameters"
......@@ -50,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -76,18 +76,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.59 s, sys: 818 ms, total: 3.41 s\n",
"Wall time: 1.1 s\n"
]
}
],
"outputs": [],
"source": [
"%%time\n",
"DLC_social_1 = project(path=path,#Path where to find the required files\n",
......@@ -115,17 +106,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading trajectories...\n",
"Smoothing trajectories...\n",
"Computing distances...\n"
]
}
],
"outputs": [],
"source": [
"%%time\n",
"DLC_social_1_coords = DLC_social_1.run(verbose=True)\n",
......@@ -355,7 +336,6 @@
"metadata": {},
"outputs": [],
"source": [
"k.backend.clear_session()\n",
"encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n",
" loss='ELBO+MMD',\n",
" kl_warmup_epochs=10,\n",
......@@ -378,11 +358,11 @@
"metadata": {},
"outputs": [],
"source": [
"encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,\n",
" loss='ELBO+MMD',\n",
" kl_warmup_epochs=10,\n",
" mmd_warmup_epochs=10).build()\n",
"vaep.build(pttest.shape)"
"# encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,\n",
"# loss='ELBO+MMD',\n",
"# kl_warmup_epochs=10,\n",
"# mmd_warmup_epochs=10).build()\n",
"# vaep.build(pttest.shape)"
]
},
{
......@@ -391,7 +371,7 @@
"metadata": {},
"outputs": [],
"source": [
"vaep.summary()"
"# vaep.summary()"
]
},
{
......@@ -400,24 +380,12 @@
"metadata": {},
"outputs": [],
"source": [
"encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_MMVAEP(pttest.shape,\n",
" loss='ELBO+MMD',\n",
" number_of_components=2,\n",
" kl_warmup_epochs=10,\n",
" mmd_warmup_epochs=10).build()\n",
"gmvaep.build(pttest.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"from tensorflow.keras.utils import plot_model\n",
"plot_model(gmvaep, show_shapes=True)"
"# encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_MMVAEP(pttest.shape,\n",
"# loss='ELBO+MMD',\n",
"# number_of_components=2,\n",
"# kl_warmup_epochs=10,\n",
"# mmd_warmup_epochs=10).build()\n",
"# gmvaep.build(pttest.shape)"
]
},
{
......@@ -428,8 +396,7 @@
"source": [
"#np.random.shuffle(pttest)\n",
"pttrain = pttest[:-15000]\n",
"pttest = pttest[-15000:]\n",
"pttrain = pttrain[:15000]"
"pttest = pttest[-15000:]"
]
},
{
......@@ -449,7 +416,7 @@
"outputs": [],
"source": [
"# tf.config.experimental_run_functions_eagerly(False)\n",
"history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=2, batch_size=512, verbose=1,\n",
"history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n",
" validation_data=(pttest[:-1], pttest[:-1]),\n",
" callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
]
......@@ -463,10 +430,53 @@
"outputs": [],
"source": [
"# tf.config.experimental_run_functions_eagerly(False)\n",
"# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=500, batch_size=512, verbose=1,\n",
"# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=2, batch_size=512, verbose=1,\n",
"# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
"# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# tf.config.experimental_run_functions_eagerly(False)\n",
"# history = gmvaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=2, batch_size=512, verbose=1,\n",
"# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
"# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Probability playground"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#I need to find a way of using DistributionLambda in my settings, \n",
"#to build a gaussian mixture likelihood with the proper categorical prior for clustering"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
......
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
#from source.utils import *
from source.preprocess import *
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
```
%% Cell type:code id: tags:parameters
``` python
path = "../../Desktop/DLC_social_1/"
```
%% Cell type:markdown id: tags:
# Set up and design the project
%% Cell type:code id: tags:
``` python
with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:
Treatment_dict = pickle.load(handle)
```
%% Cell type:code id: tags:
``` python
#Which angles to compute?
bp_dict = {'B_Nose':['B_Left_ear','B_Right_ear'],
'B_Left_ear':['B_Nose','B_Right_ear','B_Center','B_Left_flank'],
'B_Right_ear':['B_Nose','B_Left_ear','B_Center','B_Right_flank'],
'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],
'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],
'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],
'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}
```
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1 = project(path=path,#Path where to find the required files
smooth_alpha=0.85, #Alpha value for exponentially weighted smoothing
distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
'B_Right_flank','B_Tail_base'],
ego=False,
angles=True,
connectivity=bp_dict,
arena='circular', #Type of arena used in the experiments
arena_dims=[380], #Dimensions of the arena. Just one if it's circular
video_format='.mp4',
table_format='.h5',
exp_conditions=Treatment_dict)
```
%%%% Output: stream
CPU times: user 2.59 s, sys: 818 ms, total: 3.41 s
Wall time: 1.1 s
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1_coords = DLC_social_1.run(verbose=True)
print(DLC_social_1_coords)
type(DLC_social_1_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
ptest._type
```
%% Cell type:code id: tags:
``` python
%%time
dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')
dtest._type
```
%% Cell type:code id: tags:
``` python
%%time
atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')
atest._type
```
%% Cell type:markdown id: tags:
# Visualization playground
%% Cell type:code id: tags:
``` python
#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)
```
%% Cell type:code id: tags:
``` python
#Plot animation of trajectory over time with different smoothings
#plt.plot(ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['x'],
# ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['y'], label='alpha=0.85')
#plt.xlabel('x')
#plt.ylabel('y')
#plt.title('Mouse Center Trajectory using different exponential smoothings')
#plt.legend()
#plt.show()
```
%% Cell type:markdown id: tags:
# Dimensionality reduction playground
%% Cell type:code id: tags:
``` python
#pca = ptest.pca(4, 1000)
```
%% Cell type:code id: tags:
``` python
#plt.scatter(*pca[0].T)
#plt.show()
```
%% Cell type:markdown id: tags:
# Preprocessing playground
%% Cell type:code id: tags:
``` python
mtest = merge_tables(DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'))#,
# DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),
# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))
```
%% Cell type:code id: tags:
``` python
#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)
```
%% Cell type:code id: tags:
``` python
pttest = mtest.preprocess(window_size=11, window_step=6, filter=None, standard_scaler=True)
pttest.shape
```
%% Cell type:code id: tags:
``` python
#plt.plot(pttest[2,:,2], label='normal')
#plt.plot(pptest[2,:,2], label='gaussian')
#plt.legend()
#plt.show()
```
%% Cell type:markdown id: tags:
# Trained models playground
%% Cell type:markdown id: tags:
### Seq 2 seq Variational Auto Encoder
%% Cell type:code id: tags:
``` python
from datetime import datetime
import tensorflow.keras as k
import tensorflow as tf
```
%% Cell type:code id: tags:
``` python
NAME = 'Baseline_VAE_short_512_10=warmup_begin'
log_dir = os.path.abspath(
"logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
```
%% Cell type:code id: tags:
``` python
from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE, SEQ_2_SEQ_VAEP, SEQ_2_SEQ_MMVAEP
```
%% Cell type:code id: tags:
``` python
encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()
ae.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
ae.summary()
```
%% Cell type:code id: tags:
``` python
k.backend.clear_session()
encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,
loss='ELBO+MMD',
kl_warmup_epochs=10,
mmd_warmup_epochs=10).build()
#vae.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
vae.summary()
```
%% Cell type:code id: tags:
``` python
encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,
loss='ELBO+MMD',
kl_warmup_epochs=10,
mmd_warmup_epochs=10).build()
vaep.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
vaep.summary()
# encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,
# loss='ELBO+MMD',
# kl_warmup_epochs=10,
# mmd_warmup_epochs=10).build()
# vaep.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_MMVAEP(pttest.shape,
loss='ELBO+MMD',
number_of_components=2,
kl_warmup_epochs=10,
mmd_warmup_epochs=10).build()
gmvaep.build(pttest.shape)
# vaep.summary()
```
%% Cell type:code id: tags:
``` python
from tensorflow.keras.utils import plot_model
plot_model(gmvaep, show_shapes=True)
# encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_MMVAEP(pttest.shape,
# loss='ELBO+MMD',
# number_of_components=2,
# kl_warmup_epochs=10,
# mmd_warmup_epochs=10).build()
# gmvaep.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
#np.random.shuffle(pttest)
pttrain = pttest[:-15000]
pttest = pttest[-15000:]
pttrain = pttrain[:15000]
```
%% Cell type:code id: tags:
``` python
#lr_schedule = tf.keras.callbacks.LearningRateScheduler(
# lambda epoch: 1e-3 * 10**(epoch / 20))
```
%% Cell type:code id: tags:
``` python
# tf.config.experimental_run_functions_eagerly(False)
history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=2, batch_size=512, verbose=1,
history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,
validation_data=(pttest[:-1], pttest[:-1]),
callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
%% Cell type:code id: tags:
``` python
# tf.config.experimental_run_functions_eagerly(False)
# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=500, batch_size=512, verbose=1,
# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=2, batch_size=512, verbose=1,
# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),
# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
%% Cell type:code id: tags:
``` python
# tf.config.experimental_run_functions_eagerly(False)
# history = gmvaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=2, batch_size=512, verbose=1,
# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),
# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
%% Cell type:markdown id: tags:
## Probability playground
%% Cell type:code id: tags:
``` python
#I need to find a way of using DistributionLambda in my settings,
#to build a gaussian mixture likelihood with the proper categorical prior for clustering
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
......
......@@ -109,6 +109,15 @@ class UncorrelatedFeaturesConstraint(Constraint):
return self.weightage * self.uncorrelated_feature(x)
class GaussianMixtureLayer(Layer):
    """Placeholder Keras layer for a Gaussian-mixture latent space.

    NOTE(review): this is a stub — ``call`` is not implemented yet, so the
    layer cannot be used in a functional model as-is. It presumably will
    host the mixture likelihood / categorical-prior logic sketched in the
    surrounding TODOs; confirm before wiring it into a model.
    """

    def __init__(self, *args, **kwargs):
        # BUG FIX: the constructor was spelled ``__init`` (missing the
        # trailing double underscore), so Python never called it as the
        # initializer and ``is_placeholder`` was never set. Renamed to the
        # proper ``__init__`` dunder; behavior of the base Layer init is
        # preserved via super().
        self.is_placeholder = True
        super(GaussianMixtureLayer, self).__init__(*args, **kwargs)

    def call(self, inputs, **kwargs):
        # Not implemented yet: returns None, which Keras will reject if the
        # layer is actually invoked in a graph.
        pass
class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
def __init__(self, *args, **kwargs):
self.is_placeholder = True
......
......@@ -528,7 +528,7 @@ class SEQ_2_SEQ_VAEP:
)
)
z = MMDiscrepancyLayer(beta=mmd_beta)(z)
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
# Define and instantiate generator
generator = Model_D0(z)
......@@ -648,9 +648,17 @@ class SEQ_2_SEQ_MMVAEP:
self.number_of_components = number_of_components
if self.prior == "standard_normal":
self.prior = tfd.Independent(
tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
reinterpreted_batch_ndims=1,
self.prior = tfd.mixture.Mixture(
tfd.categorical.Categorical(
probs=tf.ones(self.number_of_components) / self.number_of_components
),
[
tfd.Independent(
tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
reinterpreted_batch_ndims=1,
)
for _ in range(self.number_of_components)
],
)
assert (
......@@ -747,8 +755,14 @@ class SEQ_2_SEQ_MMVAEP:
encoder = BatchNormalization()(encoder)
encoder = Model_E5(encoder)
encoder = Dense(
tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
# Map encoder to a categorical distribution over the components
zcat = Dense(self.number_of_components, activation="softmax")(encoder)
# Map encoder to a dense layer representing the parameters of
# the gaussian mixture latent space
zgauss = Dense(
tfpl.MixtureNormal.params_size(self.number_of_components, self.ENCODING),
activation=None,
)(encoder)
# Define and control custom loss functions
......@@ -764,7 +778,7 @@ class SEQ_2_SEQ_MMVAEP:
)
)
z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
z = tfpl.MixtureNormal(self.number_of_components, self.ENCODING)(zgauss)
if "ELBO" in self.loss:
z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
......@@ -781,7 +795,7 @@ class SEQ_2_SEQ_MMVAEP:
)
)
z = MMDiscrepancyLayer(beta=mmd_beta)(z)
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
# Define and instantiate generator
generator = Model_D0(z)
......@@ -869,8 +883,11 @@ class SEQ_2_SEQ_MMVAEP:
# TODO:
# - Try sample, mean and mode for MMDiscrepancyLayer
# - Gaussian Mixture + Categorical priors -> Deep Clustering
# - MCMC sampling (n>1)
# - prior of equal gaussians
# - prior of equal gaussians + gaussian noise on the means (not exactly the same init)
# - MCMC sampling (n>1) (already supported by tfp! we should try it)
#
# TODO (in the non-immediate future):
# - free bits paper
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment