Commit cec5db60 authored by lucas_miranda
Browse files

Fixed bug in preprocess.py

parent 53ced794
......@@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
......@@ -24,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
......@@ -39,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 10,
"metadata": {
"tags": [
"parameters"
......@@ -60,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
......@@ -70,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
......@@ -86,14 +77,14 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.68 s, sys: 673 ms, total: 3.35 s\n",
"CPU times: user 2.61 s, sys: 794 ms, total: 3.4 s\n",
"Wall time: 1.13 s\n"
]
}
......@@ -116,15 +107,15 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.9 s, sys: 965 ms, total: 7.87 s\n",
"Wall time: 1.47 s\n"
"CPU times: user 6.57 s, sys: 1.02 s, total: 7.59 s\n",
"Wall time: 1.56 s\n"
]
}
],
......@@ -152,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 15,
"metadata": {},
"outputs": [
{
......@@ -165,8 +156,8 @@
"Computing angles...\n",
"Done!\n",
"Coordinates of 47 videos across 4 conditions\n",
"CPU times: user 9.37 s, sys: 633 ms, total: 10 s\n",
"Wall time: 10.1 s\n"
"CPU times: user 9.65 s, sys: 787 ms, total: 10.4 s\n",
"Wall time: 10.7 s\n"
]
},
{
......@@ -175,7 +166,7 @@
"source.preprocess.coordinates"
]
},
"execution_count": 20,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
......@@ -189,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 16,
"metadata": {},
"outputs": [
{
......@@ -202,8 +193,8 @@
"Computing angles...\n",
"Done!\n",
"DLC analysis of 31 videos\n",
"CPU times: user 5.72 s, sys: 457 ms, total: 6.17 s\n",
"Wall time: 6.34 s\n"
"CPU times: user 5.97 s, sys: 541 ms, total: 6.51 s\n",
"Wall time: 6.66 s\n"
]
},
{
......@@ -212,7 +203,7 @@
"source.preprocess.coordinates"
]
},
"execution_count": 21,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
......@@ -233,7 +224,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 17,
"metadata": {
"scrolled": true
},
......@@ -242,8 +233,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.39 s, sys: 88.2 ms, total: 1.48 s\n",
"Wall time: 1.4 s\n"
"CPU times: user 1.31 s, sys: 97.5 ms, total: 1.41 s\n",
"Wall time: 1.34 s\n"
]
},
{
......@@ -252,7 +243,7 @@
"'coords'"
]
},
"execution_count": 22,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
......@@ -268,15 +259,15 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 77.7 ms, sys: 37.4 ms, total: 115 ms\n",
"Wall time: 115 ms\n"
"CPU times: user 70.7 ms, sys: 41 ms, total: 112 ms\n",
"Wall time: 111 ms\n"
]
},
{
......@@ -285,7 +276,7 @@
"'dists'"
]
},
"execution_count": 23,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
......@@ -301,9 +292,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 199 ms, sys: 120 ms, total: 320 ms\n",
"Wall time: 319 ms\n"
]
},
{
"data": {
"text/plain": [
"'angles'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')\n",
......@@ -322,7 +332,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
......@@ -331,7 +341,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
......@@ -355,7 +365,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
......@@ -364,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
......@@ -381,33 +391,33 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"mtest = merge_tables(\n",
" #DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'),\n",
" DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'),\n",
" DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),\n",
" #DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'),\n",
" DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"mtest2 = merge_tables(\n",
" #DLC_social_2_coords.get_coords(center=True, polar=True, length='00:10:00'),\n",
" DLC_social_2_coords.get_coords(center=True, polar=True, length='00:10:00'),\n",
" DLC_social_2_coords.get_distances(speed=0, length='00:10:00'),\n",
" #DLC_social_2_coords.get_angles(degrees=True, speed=0, length='00:10:00'),\n",
" DLC_social_2_coords.get_angles(degrees=True, speed=0, length='00:10:00'),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
......@@ -416,9 +426,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"(70504, 11, 67)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pttest = mtest.preprocess(window_size=11, window_step=10, filter=\"gaussian\", sigma=10,\n",
" shift=0, standard_scaler=True)\n",
......@@ -427,9 +448,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"(465021, 11, 67)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pttest2 = mtest2.preprocess(window_size=11, window_step=1, filter=\"gaussian\", sigma=10,\n",
" shift=0, standard_scaler=True)\n",
......@@ -464,7 +496,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
......@@ -475,7 +507,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
......@@ -488,16 +520,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 31,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_GMVAE"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
......@@ -507,16 +547,71 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 33,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"SEQ_2_SEQ_AE\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"SEQ_2_SEQ_Encoder (Sequentia (None, 32) 1450656 \n",
"_________________________________________________________________\n",
"SEQ_2_SEQ_Decoder (Sequentia multiple 2415587 \n",
"=================================================================\n",
"Total params: 3,851,747\n",
"Trainable params: 3,848,163\n",
"Non-trainable params: 3,584\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"ae.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"with open(\"./coords-based_S2SAE_BAYESIAN_OPT_params.pickle\", \"rb\") as handle:\n",
" nparams = pickle.load(handle)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'units_conv': 160,\n",
" 'units_lstm': 128,\n",
" 'units_dense1': 128,\n",
" 'dropout_rate': 0.15000000000000002,\n",
" 'units_dense2': 128,\n",
" 'learning_rate': 0.000505466655750672}"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nparams.values"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"scrolled": false
},
......@@ -524,10 +619,11 @@
"source": [
"encoder, generator, grouper, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_GMVAE(pttest.shape,\n",
" loss='ELBO+MMD',\n",
" number_of_components=1,\n",
" number_of_components=5,\n",
" kl_warmup_epochs=10,\n",
" mmd_warmup_epochs=10,\n",
" predictor=False).build()\n",
" predictor=False,\n",
" **nparams.values).build()\n",
"gmvaep.build(pttest.shape)"
]
},
......@@ -614,8 +710,28 @@
"metadata": {},
"outputs": [],
"source": [
"# encodings = encoder.predict(pttest)\n",
"# clusters = np.argmax(grouper.predict(pttest), axis=1)"
"encodings = encoder.predict(pttest)\n",
"clusters = np.argmax(grouper.predict(pttest), axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"Counter(clusters)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.manifold import TSNE\n",
"dimred = TSNE(n_components=3, perplexity=30)"
]
},
{
......@@ -624,8 +740,8 @@
"metadata": {},
"outputs": [],
"source": [
"# from collections import Counter\n",
"# Counter(clusters)"
"encs = encodings[:5000]\n",
"encs = dimred.fit_transform(encs)"
]
},
{
......@@ -634,24 +750,23 @@
"metadata": {},
"outputs": [],
"source": [
"# %matplotlib notebook\n",
"# # This import registers the 3D projection, but is otherwise unused.\n",
"# from mpl_toolkits.mplot3d import Axes3D \n",
"%matplotlib notebook\n",
"# This import registers the 3D projection, but is otherwise unused.\n",
"from mpl_toolkits.mplot3d import Axes3D \n",
"\n",
"# import matplotlib.pyplot as plt\n",
"# import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# fig = plt.figure()\n",
"# ax = fig.add_subplot(111, projection='3d')\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"\n",
"# encs = encodings[:]\n",
"# ax.scatter(encs[:,0],encs[:,1],encs[:,2], c=clusters[:])\n",
"ax.scatter(encs[:,0],encs[:,1],encs[:,2], c=clusters[:5000])\n",
"\n",
"# ax.set_xlabel('Encoding 0')\n",
"# ax.set_ylabel('Encoding 1')\n",
"# ax.set_zlabel('Encoding 2')\n",
"ax.set_xlabel('Encoding 0')\n",
"ax.set_ylabel('Encoding 1')\n",
"ax.set_zlabel('Encoding 2')\n",
"\n",
"# plt.show()"
"plt.show()"
]
},
{
......@@ -667,6 +782,18 @@
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
......
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
```
%%%% Output: stream
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
%% Cell type:code id: tags:
``` python
#from source.utils import *
from source.preprocess import *
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
```
%% Cell type:code id: tags:parameters
``` python
# Locations of the two DLC_social datasets, relative to the notebook.
# `path` is used as a raw string prefix when the experimental-conditions pickle
# is opened, so the trailing slash must be kept.
# NOTE(review): this cell carries the "parameters" tag, presumably so a notebook
# runner can override these values; keep them as plain assignments. Confirm.
path = "../../Desktop/DLC_social_1/"
path2 = "../../Desktop/DLC_social_2/"
```
%% Cell type:markdown id: tags:
# Set up and design the project
%% Cell type:code id: tags:
``` python
# Load the experimental-conditions object stored next to the DLC_social_1 data.
# The `with` block closes the file handle as soon as the object is read.
# NOTE: unpickling can execute arbitrary code; only load pickle files that
# were produced by a trusted source.
with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:
    Treatment_dict = pickle.load(handle)
```
%% Cell type:code id: tags:
``` python
# Which angles to compute?
# Each key is a body-part label; its value lists the body-part labels paired
# with it. The mapping is handed to the projects as their `connectivity`.
bp_dict = {
    'B_Nose': [
        'B_Left_ear', 'B_Right_ear',
    ],
    'B_Left_ear': [
        'B_Nose', 'B_Right_ear', 'B_Center', 'B_Left_flank',
    ],
    'B_Right_ear': [
        'B_Nose', 'B_Left_ear', 'B_Center', 'B_Right_flank',
    ],
    'B_Center': [
        'B_Left_ear', 'B_Right_ear', 'B_Left_flank', 'B_Right_flank', 'B_Tail_base',
    ],
    'B_Left_flank': [
        'B_Left_ear', 'B_Center', 'B_Tail_base',
    ],
    'B_Right_flank': [
        'B_Right_ear', 'B_Center', 'B_Tail_base',
    ],
    'B_Tail_base': [
        'B_Center', 'B_Left_flank', 'B_Right_flank',
    ],
}
```
%% Cell type:code id: tags:
``` python
%%time
# Build the project object for the DLC_social_1 dataset.
DLC_social_1 = project(
    # Path where to find the required files
    path=path,
    # Alpha value for exponentially weighted smoothing
    smooth_alpha=0.85,
    # Body-part labels handed to the `distances` option
    distances=[
        'B_Center',
        'B_Nose',
        'B_Left_ear',
        'B_Right_ear',
        'B_Left_flank',
        'B_Right_flank',
        'B_Tail_base',
    ],
    ego=False,
    angles=True,
    connectivity=bp_dict,
    # Type of arena used in the experiments
    arena='circular',
    # Dimensions of the arena. Just one if it's circular
    arena_dims=[380],
    video_format='.mp4',
    table_format='.h5',
    # Experimental conditions loaded from the *_exp_conditions.pickle file
    exp_conditions=Treatment_dict,
)
```
%%%% Output: stream
CPU times: user 2.68 s, sys: 673 ms, total: 3.35 s
CPU times: user 2.61 s, sys: 794 ms, total: 3.4 s
Wall time: 1.13 s
%% Cell type:code id: tags:
``` python
%%time
# Build the project object for the DLC_social_2 dataset.
# No exp_conditions argument is supplied for this dataset.
DLC_social_2 = project(
    # Path where to find the required files
    path=path2,
    # Alpha value for exponentially weighted smoothing
    smooth_alpha=0.85,
    # Body-part labels handed to the `distances` option
    distances=[
        'B_Center',
        'B_Nose',
        'B_Left_ear',
        'B_Right_ear',
        'B_Left_flank',
        'B_Right_flank',
        'B_Tail_base',
    ],
    ego=False,
    angles=True,
    connectivity=bp_dict,
    # Type of arena used in the experiments
    arena='circular',
    # Dimensions of the arena. Just one if it's circular
    arena_dims=[380],
    video_format='.mp4',
    table_format='.h5',
)
```
%%%% Output: stream
CPU times: user 6.9 s, sys: 965 ms, total: 7.87 s
Wall time: 1.47 s
CPU times: user 6.57 s, sys: 1.02 s, total: 7.59 s
Wall time: 1.56 s
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
# Run the DLC_social_1 project; verbose=True prints the progress messages
# ("Loading trajectories...", "Smoothing trajectories...", ...).
DLC_social_1_coords = DLC_social_1.run(verbose=True)
# Print the one-line summary of the returned object.
print(DLC_social_1_coords)
# Last expression of the cell: the notebook echoes the type of the result.
type(DLC_social_1_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
Computing angles...
Done!
Coordinates of 47 videos across 4 conditions
CPU times: user 9.37 s, sys: 633 ms, total: 10 s
Wall time: 10.1 s
CPU times: user 9.65 s, sys: 787 ms, total: 10.4 s
Wall time: 10.7 s
%%%% Output: execute_result
source.preprocess.coordinates
%% Cell type:code id: tags:
``` python
%%time
# Run the DLC_social_2 project; verbose=True prints the progress messages
# ("Loading trajectories...", "Smoothing trajectories...", ...).
DLC_social_2_coords = DLC_social_2.run(verbose=True)
# Print the one-line summary of the returned object.
print(DLC_social_2_coords)
# Last expression of the cell: the notebook echoes the type of the result.
type(DLC_social_2_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
Computing angles...
Done!
DLC analysis of 31 videos
CPU times: user 5.72 s, sys: 457 ms, total: 6.17 s
Wall time: 6.34 s
CPU times: user 5.97 s, sys: 541 ms, total: 6.51 s
Wall time: 6.66 s
%%%% Output: execute_result
source.preprocess.coordinates
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
# Request coordinate tables from both runs with center=True and polar=False.
# NOTE(review): length='00:10:00' presumably limits each video to its first
# ten minutes; confirm against get_coords.
ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
# Bare expression in the middle of the cell: evaluated but not displayed,
# because the notebook only echoes the final expression.
ptest._type
ptest2 = DLC_social_2_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')
# Final expression: shown as the cell result ('coords').
ptest2._type
```
%%%% Output: stream
CPU times: user 1.39 s, sys: 88.2 ms, total: 1.48 s
Wall time: 1.4 s
CPU times: user 1.31 s, sys: 97.5 ms, total: 1.41 s
Wall time: 1.34 s
%%%% Output: execute_result
'coords'
%% Cell type:code id: tags:
``` python
%%time
# Request distance tables from both runs.
# NOTE(review): length='00:10:00' presumably limits each video to its first
# ten minutes; confirm against get_distances.
dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')
# Bare expression in the middle of the cell: evaluated but not displayed,
# because the notebook only echoes the final expression.
dtest._type
dtest2 = DLC_social_2_coords.get_distances(speed=0, length='00:10:00')
# Final expression: shown as the cell result ('dists').
dtest2._type
```
%%%% Output: stream
CPU times: user 77.7 ms, sys: 37.4 ms, total: 115 ms
Wall time: 115 ms
CPU times: user 70.7 ms, sys: 41 ms, total: 112 ms
Wall time: 111 ms
%%%% Output: execute_result
'dists'
%% Cell type:code id: tags:
``` python
%%time