Commit c9ffecb6 authored by lucas_miranda's avatar lucas_miranda
Browse files

Unified all VAE variants in a single model with different settings

parent d4d469a0
......@@ -295,7 +295,7 @@
"metadata": {},
"outputs": [],
"source": [
"NAME = 'Baseline_MMVAEP_512_wu10'\n",
"NAME = 'Baseline_MMVAE_512_wu10'\n",
"log_dir = os.path.abspath(\n",
" \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
")\n",
......@@ -388,7 +388,8 @@
" loss='ELBO+MMD',\n",
" number_of_components=1,\n",
" kl_warmup_epochs=10,\n",
" mmd_warmup_epochs=10).build()\n",
" mmd_warmup_epochs=10,\n",
" predictor=False).build()\n",
"gmvaep.build(pttest.shape)"
]
},
......@@ -396,11 +397,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [],
"source": [
"# tf.keras.utils.plot_model(gmvaep, show_shapes=True)"
"tf.keras.utils.plot_model(gmvaep, show_shapes=True)"
]
},
{
......@@ -431,9 +432,9 @@
"outputs": [],
"source": [
"# tf.config.experimental_run_functions_eagerly(False)\n",
"# history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n",
"# validation_data=(pttest[:-1], pttest[:-1]),\n",
"# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
"history = gmvaep.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n",
" validation_data=(pttest[:-1], pttest[:-1]),\n",
" callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
]
},
{
......@@ -457,9 +458,9 @@
"outputs": [],
"source": [
"# tf.config.experimental_run_functions_eagerly(False)\n",
"history = gmvaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n",
" validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
" callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
"# history = gmvaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n",
"# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
"# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
]
},
{
......
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
#from source.utils import *
from source.preprocess import *
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
```
%% Cell type:code id: tags:parameters
``` python
# Root folder holding the DLC tables, videos and the conditions pickle.
# This is the papermill-style "parameters" cell — override `path` when the
# notebook is executed headless. NOTE(review): hardcoded relative path to a
# user's Desktop; consider a configurable data dir.
path = "../../Desktop/DLC_social_1/"
```
%% Cell type:markdown id: tags:
# Set up and design the project
%% Cell type:code id: tags:
``` python
# Load the per-video experimental-condition mapping produced upstream.
# NOTE(review): pickle.load can execute arbitrary code from untrusted files —
# acceptable for a local research artifact, but never point `path` at
# externally supplied data.
with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:
    Treatment_dict = pickle.load(handle)
```
%% Cell type:code id: tags:
``` python
# Which angles to compute?
# Adjacency map over tracked body parts: for each key, the listed neighbours
# define the vertex pairs whose angles the project derives (passed below as
# `connectivity=bp_dict`).
bp_dict = {'B_Nose':['B_Left_ear','B_Right_ear'],
'B_Left_ear':['B_Nose','B_Right_ear','B_Center','B_Left_flank'],
'B_Right_ear':['B_Nose','B_Left_ear','B_Center','B_Right_flank'],
'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],
'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],
'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],
'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}
```
%% Cell type:code id: tags:
``` python
%%time
# Declare the project: where to find the raw files, how much to smooth the
# trajectories, and which derived features (pairwise distances, angles from
# bp_dict) to compute when .run() is called below.
DLC_social_1 = project(path=path,#Path where to find the required files
smooth_alpha=0.85, #Alpha value for exponentially weighted smoothing
distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
'B_Right_flank','B_Tail_base'],
ego=False,
angles=True,
connectivity=bp_dict,
arena='circular', #Type of arena used in the experiments
arena_dims=[380], #Dimensions of the arena. Just one if it's circular
video_format='.mp4',
table_format='.h5',
exp_conditions=Treatment_dict)
```
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
# Execute the pipeline declared above; the result exposes get_coords /
# get_distances / get_angles (used in the cells below).
DLC_social_1_coords = DLC_social_1.run(verbose=True)
print(DLC_social_1_coords)
type(DLC_social_1_coords)
```
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
# Centered Cartesian coordinates, positions only (speed=0), first 10 minutes.
ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
ptest._type
```
%% Cell type:code id: tags:
``` python
%%time
# Pairwise body-part distances, positions only (speed=0), first 10 minutes.
dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')
dtest._type
```
%% Cell type:code id: tags:
``` python
%%time
# Joint angles in degrees, positions only (speed=0), first 10 minutes.
atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')
atest._type
```
%% Cell type:markdown id: tags:
# Visualization playground
%% Cell type:code id: tags:
``` python
#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)
```
%% Cell type:code id: tags:
``` python
#Plot animation of trajectory over time with different smoothings
#plt.plot(ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['x'],
# ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['y'], label='alpha=0.85')
#plt.xlabel('x')
#plt.ylabel('y')
#plt.title('Mouse Center Trajectory using different exponential smoothings')
#plt.legend()
#plt.show()
```
%% Cell type:markdown id: tags:
# Dimensionality reduction playground
%% Cell type:code id: tags:
``` python
#pca = ptest.pca(4, 1000)
```
%% Cell type:code id: tags:
``` python
#plt.scatter(*pca[0].T)
#plt.show()
```
%% Cell type:markdown id: tags:
# Preprocessing playground
%% Cell type:code id: tags:
``` python
# Merge the selected feature tables into one training table. Only centered
# polar coordinates are used; the distance and angle tables are currently
# excluded (left commented below).
mtest = merge_tables(DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'))#,
# DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),
# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))
```
%% Cell type:code id: tags:
``` python
#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)
```
%% Cell type:code id: tags:
``` python
# Slice the merged table into sliding windows (length 11, stride 6) with no
# smoothing filter and per-feature standard scaling. The result is a 3-D
# array (windows, time, features) — see the pttest[2,:,2] indexing below.
pttest = mtest.preprocess(window_size=11, window_step=6, filter=None, standard_scaler=True)
pttest.shape
```
%% Cell type:code id: tags:
``` python
#plt.plot(pttest[2,:,2], label='normal')
#plt.plot(pptest[2,:,2], label='gaussian')
#plt.legend()
#plt.show()
```
%% Cell type:markdown id: tags:
# Trained models playground
%% Cell type:markdown id: tags:
### Seq 2 seq Variational Auto Encoder
%% Cell type:code id: tags:
``` python
from datetime import datetime
import tensorflow.keras as k
import tensorflow as tf
```
%% Cell type:code id: tags:
``` python
# Run identifier for TensorBoard: model variant, latent size, warm-up epochs.
# (The stale pre-commit 'Baseline_MMVAEP_512_wu10' assignment — a diff-merge
# leftover that was silently shadowed — is removed.)
NAME = 'Baseline_MMVAE_512_wu10'
log_dir = os.path.abspath(
    "logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
```
%% Cell type:code id: tags:
``` python
from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_VAE, SEQ_2_SEQ_VAEP, SEQ_2_SEQ_MMVAEP
```
%% Cell type:code id: tags:
``` python
# encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()
# ae.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
# ae.summary()
```
%% Cell type:code id: tags:
``` python
# encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,
# loss='ELBO+MMD',
# kl_warmup_epochs=10,
# mmd_warmup_epochs=10).build()
# vae.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
# vae.summary()
```
%% Cell type:code id: tags:
``` python
# encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,
# loss='ELBO+MMD',
# kl_warmup_epochs=10,
# mmd_warmup_epochs=10).build()
# vaep.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
# tf.keras.utils.plot_model(vaep)
```
%% Cell type:code id: tags:
``` python
# Build the unified Gaussian-mixture VAE (one mixture component, predictor
# branch disabled) together with its KL/MMD warm-up callbacks. The original
# cell kept BOTH the old `mmd_warmup_epochs=10).build()` line and the new
# `predictor=False).build()` lines — a diff-merge artifact that made the call
# syntactically invalid; only the post-commit call is kept here.
encoder, generator, grouper, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_MMVAEP(
    pttest.shape,
    loss='ELBO+MMD',
    number_of_components=1,
    kl_warmup_epochs=10,
    mmd_warmup_epochs=10,
    predictor=False,
).build()
gmvaep.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
# Render the model graph with layer shapes (requires pydot + graphviz).
# The stale commented duplicate of this line (pre-commit state) is dropped.
tf.keras.utils.plot_model(gmvaep, show_shapes=True)
```
%% Cell type:code id: tags:
``` python
# Chronological hold-out: everything but the last 15,000 windows trains, the
# last 15,000 validate (shuffling is deliberately disabled above).
# NOTE(review): `pttest` is rebound here, so this cell is not idempotent —
# re-running it mid-session shrinks the data; restart the kernel first.
#np.random.shuffle(pttest)
pttrain = pttest[:-15000]
pttest = pttest[-15000:]
```
%% Cell type:code id: tags:
``` python
#lr_schedule = tf.keras.callbacks.LearningRateScheduler(
# lambda epoch: 1e-3 * 10**(epoch / 20))
```
%% Cell type:code id: tags:
``` python
# Train the reconstruction-only model (built with predictor=False): inputs
# and targets are the same windows; the warm-up callbacks anneal the KL and
# MMD terms over their first 10 epochs. The commented lines are the previous
# plain-VAE run kept for reference.
# tf.config.experimental_run_functions_eagerly(False)
# history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,
# validation_data=(pttest[:-1], pttest[:-1]),
# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
history = gmvaep.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,
validation_data=(pttest[:-1], pttest[:-1]),
callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
%% Cell type:code id: tags:
``` python
# tf.config.experimental_run_functions_eagerly(False)
# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,
# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),
# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
%% Cell type:code id: tags:
``` python
# Superseded two-head (reconstruction + next-step prediction) training run,
# kept for reference only. The model above is built with predictor=False and
# has a single output, so the two-target call would fail — the stale ACTIVE
# copy of these lines (a diff-merge leftover that also clobbered `history`)
# is removed; only the commented post-commit version remains.
# tf.config.experimental_run_functions_eagerly(False)
# history = gmvaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,
#                      validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),
#                      callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
%% Cell type:markdown id: tags:
## Encoding interpretation playground
%% Cell type:code id: tags:
``` python
# Project the held-out windows into the latent space, and take the arg-max
# over the grouper's mixture responsibilities as a hard cluster label per
# window.
encodings = encoder.predict(pttest)
clusters = np.argmax(grouper.predict(pttest), axis=1)
```
%% Cell type:code id: tags:
``` python
%matplotlib notebook
# Interactive 3-D scatter of the first three latent dimensions for the first
# 3000 windows, coloured by cluster assignment.
# NOTE(review): assumes the latent space has >= 3 dimensions — confirm
# against the encoder's ENCODING setting.
# This import registers the 3D projection, but is otherwise unused.
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
encs = encodings[:3000]
ax.scatter(encs[:,0],encs[:,1],encs[:,2], c=clusters[:3000])
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.show()
```
%% Cell type:code id: tags:
``` python
```
......
......@@ -6,7 +6,7 @@ from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional, Concatenate
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
from tensorflow.keras.losses import Huber
......@@ -606,6 +606,7 @@ class SEQ_2_SEQ_MMVAEP:
mmd_warmup_epochs=0,
prior="standard_normal",
number_of_components=1,
predictor=True,
):
self.input_shape = input_shape
self.CONV_filters = CONV_filters
......@@ -621,6 +622,7 @@ class SEQ_2_SEQ_MMVAEP:
self.kl_warmup = kl_warmup_epochs
self.mmd_warmup = mmd_warmup_epochs
self.number_of_components = number_of_components
self.predictor = predictor
if self.prior == "standard_normal":
self.prior = tfd.mixture.Mixture(
......@@ -790,43 +792,48 @@ class SEQ_2_SEQ_MMVAEP:
Dense(self.input_shape[2]), name="vaep_reconstruction"
)(generator)
# Define and instantiate predictor
predictor = Dense(
self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
)(z)
predictor = BatchNormalization()(predictor)
predictor = Dense(
self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
)(predictor)
predictor = BatchNormalization()(predictor)
predictor = RepeatVector(self.input_shape[1])(predictor)
predictor = Bidirectional(
LSTM(
self.LSTM_units_1,
activation="tanh",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
)
)(predictor)
predictor = BatchNormalization()(predictor)
predictor = Bidirectional(
LSTM(
self.LSTM_units_1,
activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
)
)(predictor)
predictor = BatchNormalization()(predictor)
x_predicted_mean = TimeDistributed(
Dense(self.input_shape[2]), name="vaep_prediction"
)(predictor)
if self.predictor:
# Define and instantiate predictor
predictor = Dense(
self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
)(z)
predictor = BatchNormalization()(predictor)
predictor = Dense(
self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
)(predictor)
predictor = BatchNormalization()(predictor)
predictor = RepeatVector(self.input_shape[1])(predictor)
predictor = Bidirectional(
LSTM(
self.LSTM_units_1,
activation="tanh",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
)
)(predictor)
predictor = BatchNormalization()(predictor)
predictor = Bidirectional(
LSTM(
self.LSTM_units_1,
activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
)
)(predictor)
predictor = BatchNormalization()(predictor)
x_predicted_mean = TimeDistributed(
Dense(self.input_shape[2]), name="vaep_prediction"
)(predictor)
# end-to-end autoencoder
encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
grouper = Model(x, z_cat, name="Deep_Gaussian_Mixture_clustering")
gmvaep = Model(
inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAEP"
inputs=x,
outputs=(
[x_decoded_mean, x_predicted_mean] if self.predictor else x_decoded_mean
),
name="SEQ_2_SEQ_VAE",
)
# Build generator as a separate entity
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment