Commit 128b3fcc authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent e29a40c8
......@@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 650,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
......@@ -24,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 651,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -39,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 652,
"execution_count": null,
"metadata": {
"tags": [
"parameters"
......@@ -60,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 653,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -70,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 654,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -86,18 +77,9 @@
},
{
"cell_type": "code",
"execution_count": 655,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.85 s, sys: 1.02 s, total: 3.87 s\n",
"Wall time: 1.61 s\n"
]
}
],
"outputs": [],
"source": [
"%%time\n",
"DLC_social_1 = project(path=path,#Path where to find the required files\n",
......@@ -116,18 +98,9 @@
},
{
"cell_type": "code",
"execution_count": 656,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.9 s, sys: 1.1 s, total: 8 s\n",
"Wall time: 1.8 s\n"
]
}
],
"outputs": [],
"source": [
"%%time\n",
"DLC_social_2 = project(path=path2,#Path where to find the required files\n",
......@@ -152,34 +125,9 @@
},
{
"cell_type": "code",
"execution_count": 657,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading trajectories...\n",
"Smoothing trajectories...\n",
"Computing distances...\n",
"Computing angles...\n",
"Done!\n",
"Coordinates of 47 videos across 4 conditions\n",
"CPU times: user 11 s, sys: 1.2 s, total: 12.2 s\n",
"Wall time: 12.9 s\n"
]
},
{
"data": {
"text/plain": [
"source.preprocess.coordinates"
]
},
"execution_count": 657,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"DLC_social_1_coords = DLC_social_1.run(verbose=True)\n",
......@@ -189,34 +137,9 @@
},
{
"cell_type": "code",
"execution_count": 658,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading trajectories...\n",
"Smoothing trajectories...\n",
"Computing distances...\n",
"Computing angles...\n",
"Done!\n",
"DLC analysis of 31 videos\n",
"CPU times: user 14.3 s, sys: 1.56 s, total: 15.8 s\n",
"Wall time: 16.7 s\n"
]
},
{
"data": {
"text/plain": [
"source.preprocess.coordinates"
]
},
"execution_count": 658,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"DLC_social_2_coords = DLC_social_2.run(verbose=True)\n",
......@@ -233,30 +156,11 @@
},
{
"cell_type": "code",
"execution_count": 659,
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.79 s, sys: 217 ms, total: 3.01 s\n",
"Wall time: 2.93 s\n"
]
},
{
"data": {
"text/plain": [
"'coords'"
]
},
"execution_count": 659,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"%%time\n",
"ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')\n",
......@@ -268,28 +172,9 @@
},
{
"cell_type": "code",
"execution_count": 660,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 105 ms, sys: 125 ms, total: 230 ms\n",
"Wall time: 287 ms\n"
]
},
{
"data": {
"text/plain": [
"'dists'"
]
},
"execution_count": 660,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"%%time\n",
"dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')\n",
......@@ -301,28 +186,9 @@
},
{
"cell_type": "code",
"execution_count": 661,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 254 ms, sys: 180 ms, total: 433 ms\n",
"Wall time: 469 ms\n"
]
},
{
"data": {
"text/plain": [
"'angles'"
]
},
"execution_count": 661,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"%%time\n",
"atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')\n",
......@@ -341,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 662,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -350,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 663,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -374,7 +240,7 @@
},
{
"cell_type": "code",
"execution_count": 664,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -383,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 665,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -400,7 +266,7 @@
},
{
"cell_type": "code",
"execution_count": 666,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -413,7 +279,7 @@
},
{
"cell_type": "code",
"execution_count": 667,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -426,7 +292,7 @@
},
{
"cell_type": "code",
"execution_count": 668,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -435,20 +301,9 @@
},
{
"cell_type": "code",
"execution_count": 669,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70504, 11, 28)"
]
},
"execution_count": 669,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"pttest = mtest.preprocess(window_size=11, window_step=10, filter=\"gaussian\", sigma=10,\n",
" shift=0, standard_scaler=True)\n",
......@@ -457,20 +312,9 @@
},
{
"cell_type": "code",
"execution_count": 670,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(465021, 11, 28)"
]
},
"execution_count": 670,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"pttest2 = mtest2.preprocess(window_size=11, window_step=1, filter=\"gaussian\", sigma=10,\n",
" shift=0, standard_scaler=True)\n",
......@@ -479,7 +323,7 @@
},
{
"cell_type": "code",
"execution_count": 671,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -505,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": 672,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -516,7 +360,7 @@
},
{
"cell_type": "code",
"execution_count": 673,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -529,7 +373,7 @@
},
{
"cell_type": "code",
"execution_count": 674,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -538,7 +382,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -548,35 +392,16 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"SEQ_2_SEQ_AE\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"SEQ_2_SEQ_Encoder (Sequentia (None, 32) 1791200 \n",
"_________________________________________________________________\n",
"SEQ_2_SEQ_Decoder (Sequentia multiple 2687420 \n",
"=================================================================\n",
"Total params: 4,435,388\n",
"Trainable params: 4,431,036\n",
"Non-trainable params: 4,352\n",
"_________________________________________________________________\n"
]
}
],
"outputs": [],
"source": [
"ae.summary()"
]
},
{
"cell_type": "code",
"execution_count": 718,
"execution_count": null,
"metadata": {
"scrolled": false
},
......@@ -638,7 +463,7 @@
},
{
"cell_type": "code",
"execution_count": 677,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -651,81 +476,9 @@
},
{
"cell_type": "code",
"execution_count": 646,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"SEQ_2_SEQ_VAE\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_4 (InputLayer) [(None, 11, 28)] 0 \n",
"__________________________________________________________________________________________________\n",
"conv1d_2 (Conv1D) (None, 11, 256) 36096 input_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_22 (BatchNo (None, 11, 256) 1024 conv1d_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_8 (Bidirectional) (None, 11, 512) 1050624 batch_normalization_22[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_23 (BatchNo (None, 11, 512) 2048 bidirectional_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_9 (Bidirectional) (None, 256) 656384 batch_normalization_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_24 (BatchNo (None, 256) 1024 bidirectional_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_12 (Dense) (None, 128) 32896 batch_normalization_24[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_25 (BatchNo (None, 128) 512 dense_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_2 (Dropout) (None, 128) 0 batch_normalization_25[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_13 (Dense) (None, 64) 8256 dropout_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_26 (BatchNo (None, 64) 256 dense_13[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_17 (Dense) (None, 80) 5200 batch_normalization_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_16 (Dense) (None, 5) 325 batch_normalization_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"reshape_1 (Reshape) (None, 16, 5) 0 dense_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"distribution_lambda_1 (Distribu ((None, 8), (None, 8 0 dense_16[0][0] \n",
" reshape_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"kl_divergence_layer_1 (KLDiverg (None, 8) 1 distribution_lambda_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"mm_discrepancy_layer_1 (MMDiscr (None, 8) 1 kl_divergence_layer_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_14 (Dense) (None, 64) 576 mm_discrepancy_layer_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_18 (BatchNo (None, 64) 256 dense_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_15 (Dense) (None, 128) 8320 batch_normalization_18[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_19 (BatchNo (None, 128) 512 dense_15[0][0] \n",
"__________________________________________________________________________________________________\n",
"repeat_vector_2 (RepeatVector) (None, 11, 128) 0 batch_normalization_19[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_10 (Bidirectional (None, 11, 256) 263168 repeat_vector_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_20 (BatchNo (None, 11, 256) 1024 bidirectional_10[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_11 (Bidirectional (None, 11, 512) 1050624 batch_normalization_20[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_21 (BatchNo (None, 11, 512) 2048 bidirectional_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"vaep_reconstruction (TimeDistri (None, 11, 28) 14364 batch_normalization_21[0][0] \n",
"==================================================================================================\n",
"Total params: 3,135,539\n",
"Trainable params: 3,131,185\n",
"Non-trainable params: 4,354\n",
"__________________________________________________________________________________________________\n"
]
}
],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gmvaep.summary()"
]
......@@ -752,7 +505,7 @@
},
{
"cell_type": "code",
"execution_count": 736,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
```
%%%% Output: stream
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
%% Cell type:code id: tags:
``` python
#from source.utils import *
from source.preprocess import *
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
```
%% Cell type:code id: tags:parameters
``` python
path = "../../Desktop/DLC_social_1/"
path2 = "../../Desktop/DLC_social_2/"
```
%% Cell type:markdown id: tags:
# Set up and design the project
%% Cell type:code id: tags:
``` python
with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:
Treatment_dict = pickle.load(handle)
```
%% Cell type:code id: tags:
``` python
#Which angles to compute?
bp_dict = {'B_Nose':['B_Left_ear','B_Right_ear'],
'B_Left_ear':['B_Nose','B_Right_ear','B_Center','B_Left_flank'],
'B_Right_ear':['B_Nose','B_Left_ear','B_Center','B_Right_flank'],
'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],
'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],
'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],
'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}
```
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1 = project(path=path,#Path where to find the required files
smooth_alpha=0.85, #Alpha value for exponentially weighted smoothing
distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
'B_Right_flank','B_Tail_base'],
ego=False,
angles=True,
connectivity=bp_dict,
arena='circular', #Type of arena used in the experiments
arena_dims=[380], #Dimensions of the arena. Just one if it's circular
video_format='.mp4',
table_format='.h5',
exp_conditions=Treatment_dict)
```
%%%% Output: stream
CPU times: user 2.85 s, sys: 1.02 s, total: 3.87 s
Wall time: 1.61 s
%% Cell type:code id: tags:
``` python
%%time
DLC_social_2 = project(path=path2,#Path where to find the required files
smooth_alpha=0.85, #Alpha value for exponentially weighted smoothing
distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
'B_Right_flank','B_Tail_base'],
ego=False,
angles=True,
connectivity=bp_dict,
arena='circular', #Type of arena used in the experiments
arena_dims=[380], #Dimensions of the arena. Just one if it's circular
video_format='.mp4',
table_format='.h5')
```
%%%% Output: stream
CPU times: user 6.9 s, sys: 1.1 s, total: 8 s
Wall time: 1.8 s
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1_coords = DLC_social_1.run(verbose=True)
print(DLC_social_1_coords)
type(DLC_social_1_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
Computing angles...
Done!
Coordinates of 47 videos across 4 conditions
CPU times: user 11 s, sys: 1.2 s, total: 12.2 s
Wall time: 12.9 s
%%%% Output: execute_result
source.preprocess.coordinates
%% Cell type:code id: tags:
``` python
%%time
DLC_social_2_coords = DLC_social_2.run(verbose=True)
print(DLC_social_2_coords)
type(DLC_social_2_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
Computing angles...
Done!
DLC analysis of 31 videos
CPU times: user 14.3 s, sys: 1.56 s, total: 15.8 s
Wall time: 16.7 s
%%%% Output: execute_result
source.preprocess.coordinates
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
ptest._type
ptest2 = DLC_social_2_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
ptest2._type
```
%%%% Output: stream
CPU times: user 2.79 s, sys: 217 ms, total: 3.01 s
Wall time: 2.93 s
%%%% Output: execute_result
'coords'
%% Cell type:code id: tags:
``` python
%%time
dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')
dtest._type
dtest2 = DLC_social_2_coords.get_distances(speed=0, length='00:10:00')
dtest2._type
```
%%%% Output: stream
CPU times: user 105 ms, sys: 125 ms, total: 230 ms
Wall time: 287 ms
%%%% Output: execute_result
'dists'
%% Cell type:code id: tags:
``` python
%%time
atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')
atest._type
atest2 = DLC_social_2_coords.get_angles(degrees=True, speed=0, length='00:10:00')
atest2._type
```
%%%% Output: stream
CPU times: user 254 ms, sys: 180 ms, total: 433 ms
Wall time: 469 ms
%%%% Output: execute_result
'angles'
%% Cell type:markdown id: tags:
# Visualization playground
%% Cell type:code id: tags:
``` python
#ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)
```
%% Cell type:code id: tags:
``` python
#Plot animation of trajectory over time with different smoothings
#plt.plot(ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['x'],