Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
b5d4cf64
Commit
b5d4cf64
authored
Jun 05, 2020
by
lucas_miranda
Browse files
Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py
parent
442ed01b
Changes
1
Hide whitespace changes
Inline
Side-by-side
main.ipynb
View file @
b5d4cf64
...
@@ -326,10 +326,10 @@
...
@@ -326,10 +326,10 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n",
"
#
encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n",
" loss='ELBO+MMD',\n",
"
#
loss='ELBO+MMD',\n",
" kl_warmup_epochs=10,\n",
"
#
kl_warmup_epochs=10,\n",
" mmd_warmup_epochs=10).build()"
"
#
mmd_warmup_epochs=10).build()"
]
]
},
},
{
{
...
@@ -338,10 +338,10 @@
...
@@ -338,10 +338,10 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"
#
encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,\n",
"encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,\n",
"
#
loss='ELBO+MMD',\n",
" loss='ELBO+MMD',\n",
"
#
kl_warmup_epochs=10,\n",
" kl_warmup_epochs=10,\n",
"
#
mmd_warmup_epochs=10).build()"
" mmd_warmup_epochs=10).build()"
]
]
},
},
{
{
...
@@ -405,10 +405,10 @@
...
@@ -405,10 +405,10 @@
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"#tf.config.experimental_run_functions_eagerly(False)\n",
"#
tf.config.experimental_run_functions_eagerly(False)\n",
"history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n",
"
#
history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n",
" validation_data=(pttest[:-1], pttest[:-1]),\n",
"
#
validation_data=(pttest[:-1], pttest[:-1]),\n",
" callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
"
#
callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
]
]
},
},
{
{
...
@@ -419,10 +419,10 @@
...
@@ -419,10 +419,10 @@
},
},
"outputs": [],
"outputs": [],
"source": [
"source": [
"
#
tf.config.experimental_run_functions_eagerly(False)\n",
"tf.config.experimental_run_functions_eagerly(False)\n",
"
#
history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n",
"history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n",
"
#
validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
" validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
"
#
callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
" callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
]
]
}
}
],
],
...
...
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
%
load_ext
autoreload
%
load_ext
autoreload
%
autoreload
2
%
autoreload
2
import
warnings
import
warnings
warnings
.
filterwarnings
(
"ignore"
)
warnings
.
filterwarnings
(
"ignore"
)
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#from source.utils import *
#from source.utils import *
from
source.preprocess
import
*
from
source.preprocess
import
*
import
pickle
import
pickle
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
pandas
as
pd
import
pandas
as
pd
from
collections
import
defaultdict
from
collections
import
defaultdict
from
tqdm
import
tqdm_notebook
as
tqdm
from
tqdm
import
tqdm_notebook
as
tqdm
```
```
%% Cell type:code id: tags:parameters
%% Cell type:code id: tags:parameters
```
python
```
python
path
=
"../../Desktop/DLC_social_1/"
path
=
"../../Desktop/DLC_social_1/"
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
# Set up and design the project
# Set up and design the project
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
with
open
(
'{}DLC_social_1_exp_conditions.pickle'
.
format
(
path
),
'rb'
)
as
handle
:
with
open
(
'{}DLC_social_1_exp_conditions.pickle'
.
format
(
path
),
'rb'
)
as
handle
:
Treatment_dict
=
pickle
.
load
(
handle
)
Treatment_dict
=
pickle
.
load
(
handle
)
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#Which angles to compute?
#Which angles to compute?
bp_dict
=
{
'B_Nose'
:[
'B_Left_ear'
,
'B_Right_ear'
],
bp_dict
=
{
'B_Nose'
:[
'B_Left_ear'
,
'B_Right_ear'
],
'B_Left_ear'
:[
'B_Nose'
,
'B_Right_ear'
,
'B_Center'
,
'B_Left_flank'
],
'B_Left_ear'
:[
'B_Nose'
,
'B_Right_ear'
,
'B_Center'
,
'B_Left_flank'
],
'B_Right_ear'
:[
'B_Nose'
,
'B_Left_ear'
,
'B_Center'
,
'B_Right_flank'
],
'B_Right_ear'
:[
'B_Nose'
,
'B_Left_ear'
,
'B_Center'
,
'B_Right_flank'
],
'B_Center'
:[
'B_Left_ear'
,
'B_Right_ear'
,
'B_Left_flank'
,
'B_Right_flank'
,
'B_Tail_base'
],
'B_Center'
:[
'B_Left_ear'
,
'B_Right_ear'
,
'B_Left_flank'
,
'B_Right_flank'
,
'B_Tail_base'
],
'B_Left_flank'
:[
'B_Left_ear'
,
'B_Center'
,
'B_Tail_base'
],
'B_Left_flank'
:[
'B_Left_ear'
,
'B_Center'
,
'B_Tail_base'
],
'B_Right_flank'
:[
'B_Right_ear'
,
'B_Center'
,
'B_Tail_base'
],
'B_Right_flank'
:[
'B_Right_ear'
,
'B_Center'
,
'B_Tail_base'
],
'B_Tail_base'
:[
'B_Center'
,
'B_Left_flank'
,
'B_Right_flank'
]}
'B_Tail_base'
:[
'B_Center'
,
'B_Left_flank'
,
'B_Right_flank'
]}
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
%%
time
%%
time
DLC_social_1
=
project
(
path
=
path
,
#Path where to find the required files
DLC_social_1
=
project
(
path
=
path
,
#Path where to find the required files
smooth_alpha
=
0.85
,
#Alpha value for exponentially weighted smoothing
smooth_alpha
=
0.85
,
#Alpha value for exponentially weighted smoothing
distances
=
[
'B_Center'
,
'B_Nose'
,
'B_Left_ear'
,
'B_Right_ear'
,
'B_Left_flank'
,
distances
=
[
'B_Center'
,
'B_Nose'
,
'B_Left_ear'
,
'B_Right_ear'
,
'B_Left_flank'
,
'B_Right_flank'
,
'B_Tail_base'
],
'B_Right_flank'
,
'B_Tail_base'
],
ego
=
False
,
ego
=
False
,
angles
=
True
,
angles
=
True
,
connectivity
=
bp_dict
,
connectivity
=
bp_dict
,
arena
=
'circular'
,
#Type of arena used in the experiments
arena
=
'circular'
,
#Type of arena used in the experiments
arena_dims
=
[
380
],
#Dimensions of the arena. Just one if it's circular
arena_dims
=
[
380
],
#Dimensions of the arena. Just one if it's circular
video_format
=
'.mp4'
,
video_format
=
'.mp4'
,
table_format
=
'.h5'
,
table_format
=
'.h5'
,
exp_conditions
=
Treatment_dict
)
exp_conditions
=
Treatment_dict
)
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
# Run project
# Run project
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
%%
time
%%
time
DLC_social_1_coords
=
DLC_social_1
.
run
(
verbose
=
True
)
DLC_social_1_coords
=
DLC_social_1
.
run
(
verbose
=
True
)
print
(
DLC_social_1_coords
)
print
(
DLC_social_1_coords
)
type
(
DLC_social_1_coords
)
type
(
DLC_social_1_coords
)
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
# Generate coords
# Generate coords
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
%%
time
%%
time
ptest
=
DLC_social_1_coords
.
get_coords
(
center
=
True
,
polar
=
False
,
speed
=
0
,
length
=
'00:10:00'
)
ptest
=
DLC_social_1_coords
.
get_coords
(
center
=
True
,
polar
=
False
,
speed
=
0
,
length
=
'00:10:00'
)
ptest
.
_type
ptest
.
_type
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
%%
time
%%
time
dtest
=
DLC_social_1_coords
.
get_distances
(
speed
=
0
,
length
=
'00:10:00'
)
dtest
=
DLC_social_1_coords
.
get_distances
(
speed
=
0
,
length
=
'00:10:00'
)
dtest
.
_type
dtest
.
_type
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
%%
time
%%
time
atest
=
DLC_social_1_coords
.
get_angles
(
degrees
=
True
,
speed
=
0
,
length
=
'00:10:00'
)
atest
=
DLC_social_1_coords
.
get_angles
(
degrees
=
True
,
speed
=
0
,
length
=
'00:10:00'
)
atest
.
_type
atest
.
_type
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
# Visualization playground
# Visualization playground
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)
#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#Plot animation of trajectory over time with different smoothings
#Plot animation of trajectory over time with different smoothings
#plt.plot(ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['x'],
#plt.plot(ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['x'],
# ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['y'], label='alpha=0.85')
# ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['y'], label='alpha=0.85')
#plt.xlabel('x')
#plt.xlabel('x')
#plt.ylabel('y')
#plt.ylabel('y')
#plt.title('Mouse Center Trajectory using different exponential smoothings')
#plt.title('Mouse Center Trajectory using different exponential smoothings')
#plt.legend()
#plt.legend()
#plt.show()
#plt.show()
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
# Dimensionality reduction playground
# Dimensionality reduction playground
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#pca = ptest.pca(4, 1000)
#pca = ptest.pca(4, 1000)
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#plt.scatter(*pca[0].T)
#plt.scatter(*pca[0].T)
#plt.show()
#plt.show()
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
# Preprocessing playground
# Preprocessing playground
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
mtest
=
merge_tables
(
DLC_social_1_coords
.
get_coords
(
center
=
True
,
polar
=
True
,
length
=
'00:10:00'
))
#,
mtest
=
merge_tables
(
DLC_social_1_coords
.
get_coords
(
center
=
True
,
polar
=
True
,
length
=
'00:10:00'
))
#,
# DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),
# DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),
# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))
# DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)
#pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
pttest
=
mtest
.
preprocess
(
window_size
=
11
,
window_step
=
6
,
filter
=
None
,
standard_scaler
=
True
)
pttest
=
mtest
.
preprocess
(
window_size
=
11
,
window_step
=
6
,
filter
=
None
,
standard_scaler
=
True
)
pttest
.
shape
pttest
.
shape
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#plt.plot(pttest[2,:,2], label='normal')
#plt.plot(pttest[2,:,2], label='normal')
#plt.plot(pptest[2,:,2], label='gaussian')
#plt.plot(pptest[2,:,2], label='gaussian')
#plt.legend()
#plt.legend()
#plt.show()
#plt.show()
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
# Trained models playground
# Trained models playground
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
### Seq 2 seq Variational Auto Encoder
### Seq 2 seq Variational Auto Encoder
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
from
datetime
import
datetime
from
datetime
import
datetime
import
tensorflow.keras
as
k
import
tensorflow.keras
as
k
import
tensorflow
as
tf
import
tensorflow
as
tf
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
NAME
=
'Baseline_VAE_short_512_10=warmup_begin'
NAME
=
'Baseline_VAE_short_512_10=warmup_begin'
log_dir
=
os
.
path
.
abspath
(
log_dir
=
os
.
path
.
abspath
(
"logs/fit/{}_{}"
.
format
(
NAME
,
datetime
.
now
().
strftime
(
"%Y%m%d-%H%M%S"
))
"logs/fit/{}_{}"
.
format
(
NAME
,
datetime
.
now
().
strftime
(
"%Y%m%d-%H%M%S"
))
)
)
tensorboard_callback
=
k
.
callbacks
.
TensorBoard
(
log_dir
=
log_dir
,
histogram_freq
=
1
)
tensorboard_callback
=
k
.
callbacks
.
TensorBoard
(
log_dir
=
log_dir
,
histogram_freq
=
1
)
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
from
source.models
import
SEQ_2_SEQ_AE
,
SEQ_2_SEQ_VAE
,
SEQ_2_SEQ_VAEP
from
source.models
import
SEQ_2_SEQ_AE
,
SEQ_2_SEQ_VAE
,
SEQ_2_SEQ_VAEP
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
encoder
,
decoder
,
ae
=
SEQ_2_SEQ_AE
(
pttest
.
shape
).
build
()
encoder
,
decoder
,
ae
=
SEQ_2_SEQ_AE
(
pttest
.
shape
).
build
()
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
encoder
,
generator
,
vae
,
kl_warmup_callback
,
mmd_warmup_callback
=
SEQ_2_SEQ_VAE
(
pttest
.
shape
,
#
encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,
loss
=
'ELBO+MMD'
,
#
loss='ELBO+MMD',
kl_warmup_epochs
=
10
,
#
kl_warmup_epochs=10,
mmd_warmup_epochs
=
10
).
build
()
#
mmd_warmup_epochs=10).build()
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#
encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,
encoder
,
generator
,
vaep
,
kl_warmup_callback
,
mmd_warmup_callback
=
SEQ_2_SEQ_VAEP
(
pttest
.
shape
,
#
loss='ELBO+MMD',
loss
=
'ELBO+MMD'
,
#
kl_warmup_epochs=10,
kl_warmup_epochs
=
10
,
#
mmd_warmup_epochs=10).build()
mmd_warmup_epochs
=
10
).
build
()
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#ae.summary()
#ae.summary()
#vae.summary()
#vae.summary()
#vaep.summary()
#vaep.summary()
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#from tensorflow.keras.utils import plot_model
#from tensorflow.keras.utils import plot_model
#plot_model(vaep, show_shapes=True)
#plot_model(vaep, show_shapes=True)
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#plot_model(vae)
#plot_model(vae)
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#np.random.shuffle(pttest)
#np.random.shuffle(pttest)
pttrain
=
pttest
[:
-
15000
]
pttrain
=
pttest
[:
-
15000
]
pttest
=
pttest
[
-
15000
:]
pttest
=
pttest
[
-
15000
:]
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#lr_schedule = tf.keras.callbacks.LearningRateScheduler(
#lr_schedule = tf.keras.callbacks.LearningRateScheduler(
# lambda epoch: 1e-3 * 10**(epoch / 20))
# lambda epoch: 1e-3 * 10**(epoch / 20))
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#tf.config.experimental_run_functions_eagerly(False)
#
tf.config.experimental_run_functions_eagerly(False)
history
=
vae
.
fit
(
x
=
pttrain
[:
-
1
],
y
=
pttrain
[:
-
1
],
epochs
=
100
,
batch_size
=
512
,
verbose
=
1
,
#
history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,
validation_data
=
(
pttest
[:
-
1
],
pttest
[:
-
1
]),
#
validation_data=(pttest[:-1], pttest[:-1]),
callbacks
=
[
tensorboard_callback
,
kl_warmup_callback
,
mmd_warmup_callback
])
#
callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
#
tf.config.experimental_run_functions_eagerly(False)
tf
.
config
.
experimental_run_functions_eagerly
(
False
)
#
history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,
history
=
vaep
.
fit
(
x
=
pttrain
[:
-
1
],
y
=
[
pttrain
[:
-
1
],
pttrain
[
1
:]],
epochs
=
100
,
batch_size
=
512
,
verbose
=
1
,
#
validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),
validation_data
=
(
pttest
[:
-
1
],
[
pttest
[:
-
1
],
pttest
[
1
:]]),
#
callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
callbacks
=
[
tensorboard_callback
,
kl_warmup_callback
,
mmd_warmup_callback
])
```
```
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment