Commit f133d5a3 authored by lucas_miranda

Refactored preprocess.py and pose_utils.py

parent 6703d219
Pipeline #83092 failed in 14 minutes and 18 seconds
@@ -390,10 +390,7 @@ def rule_based_tagging(
     videos: List,
     coordinates: Coordinates,
     vid_index: int,
-    frame_limit: float = np.inf,
     recog_limit: int = 1,
-    mode: str = None,
-    fps: float = 0.0,
     path: str = os.path.join("."),
     hparams: dict = {},
 ) -> pd.DataFrame:
@@ -405,11 +402,7 @@ def rule_based_tagging(
     - videos (list): list of videos to load, in the same order as tracks
     - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
     - vid_index (int): index in videos of the experiment to annotate
-    - mode (str): if "show", displays the annotated video in a separate window;
-      if "save", saves it to an mp4 file
-    - fps (float): frames per second of the analysed video. Same as input by default
     - path (str): directory in which the experimental data is stored
-    - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
     - recog_limit (int): number of frames to use for arena recognition (1 by default)
     - hparams (dict): dictionary to overwrite the default values of the hyperparameters of the functions
       that the rule-based pose estimation utilizes. Values can be:
@@ -429,6 +422,7 @@ def rule_based_tagging(
     hparams = get_hparameters(hparams)
     animal_ids = coordinates._animal_ids
+    undercond = "_" if len(animal_ids) > 1 else ""
     vid_name = re.findall("(.*?)_", tracks[vid_index])[0]
@@ -474,7 +468,7 @@
         )
     )
-    if animal_ids:
+    if len(animal_ids) == 2:
         # Define behaviours that can be computed on the fly from the distance matrix
         tag_dict["nose2nose"] = onebyone_contact(bparts=["_Nose"])
@@ -497,40 +491,21 @@ def rule_based_tagging(
                     tol=hparams["follow_tol"],
                 )
             )
-            tag_dict[_id + "_climbing"] = deepof.utils.smooth_boolean_array(
-                pd.Series(
-                    (
-                        spatial.distance.cdist(
-                            np.array(coords[_id + "_Nose"]), np.zeros([1, 2])
-                        )
-                        > (w / 200 + arena[2])
-                    ).reshape(coords.shape[0]),
-                    index=coords.index,
-                ).astype(bool)
-            )
-            tag_dict[_id + "_speed"] = speeds[_id + "_speed"]
-            tag_dict[_id + "_huddle"] = deepof.utils.smooth_boolean_array(
-                huddle(
-                    coords,
-                    speeds,
-                    hparams["huddle_forward"],
-                    hparams["huddle_spine"],
-                    hparams["huddle_speed"],
-                )
-            )
-    else:
-        tag_dict["climbing"] = deepof.utils.smooth_boolean_array(
+    for _id in animal_ids:
+        tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
             pd.Series(
                 (
-                    spatial.distance.cdist(np.array(coords["Nose"]), np.zeros([1, 2]))
+                    spatial.distance.cdist(
+                        np.array(coords[_id + undercond + "Nose"]), np.zeros([1, 2])
+                    )
                     > (w / 200 + arena[2])
                 ).reshape(coords.shape[0]),
                 index=coords.index,
            ).astype(bool)
        )
-        tag_dict["speed"] = speeds["Center"]
-        tag_dict["huddle"] = deepof.utils.smooth_boolean_array(
+        tag_dict[_id + undercond + "speed"] = speeds[_id + undercond + "Center"]
+        tag_dict[_id + undercond + "huddle"] = deepof.utils.smooth_boolean_array(
             huddle(
                 coords,
                 speeds,
@@ -540,6 +515,65 @@ def rule_based_tagging(
             )
         )
     tag_df = pd.DataFrame(tag_dict)
     return tag_df
+
+
+def rule_based_video(
+    coordinates,
+    tracks,
+    videos,
+    vid_index,
+    tag_dict,
+    mode,
+    path,
+    fps,
+    frame_limit,
+    recog_limit,
+    hparams,
+):
+    """Renders a version of the input video with all rule-based taggings in place.
+    Parameters:
+        - tracks (list): list containing experiment IDs as strings
+        - videos (list): list of videos to load, in the same order as tracks
+        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
+        - vid_index (int): index in videos of the experiment to annotate
+        - mode (str): if "show", displays the annotated video in a separate window;
+          if "save", saves it to an mp4 file
+        - fps (float): frames per second of the analysed video. Same as input by default
+        - path (str): directory in which the experimental data is stored
+        - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
+        - recog_limit (int): number of frames to use for arena recognition (1 by default)
+        - hparams (dict): dictionary to overwrite the default values of the hyperparameters of the functions
+          that the rule-based pose estimation utilizes. Values can be:
+            - speed_pause (int): size of the rolling window to use when computing speeds
+            - close_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
+            - side_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
+            - follow_frames (int): number of frames during which the following trait is tracked
+            - follow_tol (int): maximum distance between follower and followed's path during the last follow_frames,
+              in order to report a detection
+            - huddle_forward (int): maximum distance between ears and forward limbs to report a huddle detection
+            - huddle_spine (int): maximum average distance between spine body parts to report a huddle detection
+            - huddle_speed (int): maximum speed to report a huddle detection
+    Returns:
+        True
+    """
+    animal_ids = coordinates._animal_ids
+    # undercond = "_" if len(animal_ids) > 1 else ""
+    vid_name = re.findall("(.*?)_", tracks[vid_index])[0]
+    coords = coordinates.get_coords()[vid_name]
+    speeds = coordinates.get_coords(speed=1)[vid_name]
+    arena, h, w = deepof.utils.recognize_arena(
+        videos, vid_index, path, recog_limit, coordinates._arena
+    )
+    if mode in ["show", "save"]:
+        cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
@@ -732,7 +766,3 @@ def rule_based_tagging(
     cap.release()
     cv2.destroyAllWindows()
-    tag_df = pd.DataFrame(tag_dict)
-    return tag_df
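For orientation, a sketch of how the split functions might be called after this refactor. The argument names follow the signatures in the diff above; the concrete values (video index, paths, the hparams override) are illustrative only:

``` python
# Illustrative call sequence; values are placeholders, signatures follow the diff above.
tag_df = rule_based_tagging(
    tracks,
    videos,
    coordinates,
    vid_index=0,
    path=os.path.join(".", "Videos"),
    hparams={"huddle_speed": 1},  # override a single hyperparameter; the rest keep defaults
)
rule_based_video(  # rendering is now delegated to this new helper
    coordinates,
    tracks,
    videos,
    vid_index=0,
    tag_dict=tag_df,
    mode="save",  # or "show" to display in a separate window
    path=os.path.join(".", "Videos"),
    fps=30.0,
    frame_limit=100,
    recog_limit=1,
    hparams={},
)
```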
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
import os
os.chdir("../")
```
%% Cell type:code id: tags:
``` python
#from source.utils import *
from deepof.preprocess import *
from deepof.model_utils import *
import pickle
import matplotlib.pyplot as plt
import numpy as np  # used throughout (np.stack, np.argmax, ...)
import pandas as pd
import seaborn as sns  # used by the KDE and box plots below
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
```
%% Cell type:code id: tags:parameters
``` python
path = "../../PycharmProjects/deepof/tests/test_examples"
path2 = "../../Desktop/DLC_social_2/"
```
%% Cell type:markdown id: tags:
# Set up and design the project
%% Cell type:code id: tags:
``` python
with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:
Treatment_dict = pickle.load(handle)
```
%%%% Output: error
---------------------------------------------------------------------------
FileNotFoundError Traceback (most recent call last)
<ipython-input-27-2f7d5a25baed> in <module>
----> 1 with open('{}DLC_social_1_exp_conditions.pickle'.format(path), 'rb') as handle:
2 Treatment_dict = pickle.load(handle)
FileNotFoundError: [Errno 2] No such file or directory: '../../PycharmProjects/deepof/tests/test_examplesDLC_social_1_exp_conditions.pickle'
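The two path components are fused in the message ("test_examplesDLC_social_1…"), i.e. `path` lacks a trailing separator. A minimal correction would join them explicitly:

``` python
# Sketch of the fix: os.path.join supplies the missing separator.
with open(os.path.join(path, 'DLC_social_1_exp_conditions.pickle'), 'rb') as handle:
    Treatment_dict = pickle.load(handle)
```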
%% Cell type:code id: tags:
``` python
Treatment_dict["WT+NS"]
```
%%%% Output: execute_result
['Test 6DLC',
'Test 15DLC',
'Test 24DLC',
'Test 29DLC',
'Test 38DLC',
'Test 47DLC',
'Day2Test8DLC',
'Day2Test13DLC',
'Day2Test22DLC',
'Day2Test31DLC',
'Day2Test40DLC']
%% Cell type:code id: tags:
``` python
#Which angles to compute?
bp_dict = {'B_Nose':['B_Left_ear','B_Right_ear'],
'B_Left_ear':['B_Nose','B_Right_ear','B_Center','B_Left_flank'],
'B_Right_ear':['B_Nose','B_Left_ear','B_Center','B_Right_flank'],
'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],
'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],
'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],
'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}
```
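The connectivity dictionary drives both distance and angle computation. As a rough illustration (an assumption about the convention, not deepof's exact code), the angles can be enumerated as (neighbour, vertex, neighbour) triplets around each body part:

``` python
from itertools import combinations

# Hypothetical sketch: each body part acts as the vertex of an angle,
# with every pair of its listed neighbours as the endpoints.
angle_triplets = [
    (a, vertex, b)
    for vertex, neighbours in bp_dict.items()
    for a, b in combinations(neighbours, 2)
]
print(len(angle_triplets), angle_triplets[:3])
```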
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1 = project(path=path,#Path where to find the required files
smooth_alpha=0.5, #Alpha value for exponentially weighted smoothing
distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
'B_Right_flank','B_Tail_base'],
ego=False,
angles=True,
connectivity=bp_dict,
arena='circular', #Type of arena used in the experiments
arena_dims=[380], #Dimensions of the arena. Just one if it's circular
subset_condition="B",
video_format='.mp4',
table_format='.h5',
exp_conditions=Treatment_dict)
```
%%%% Output: stream
CPU times: user 61.8 ms, sys: 23.3 ms, total: 85.1 ms
Wall time: 31.2 ms
%% Cell type:code id: tags:
``` python
%%time
DLC_social_2 = project(path=path2,#Path where to find the required files
smooth_alpha=0.5, #Alpha value for exponentially weighted smoothing
distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',
'B_Right_flank','B_Tail_base'],
ego=False,
angles=True,
connectivity=bp_dict,
arena='circular', #Type of arena used in the experiments
arena_dims=[380], #Dimensions of the arena. Just one if it's circular
subset_condition="B",
video_format='.mp4',
table_format='.h5')
```
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
DLC_social_1_coords = DLC_social_1.run(verbose=True)
print(DLC_social_1_coords)
type(DLC_social_1_coords)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Computing distances...
%%%% Output: error
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<timed exec> in <module>
~/PycharmProjects/deepof/deepof/preprocess.py in run(self, verbose)
257
258 if self.distances:
--> 259 distances = self.get_distances(tables, verbose)
260
261 if self.angles:
~/PycharmProjects/deepof/deepof/preprocess.py in get_distances(self, table_dict, verbose)
193 distance_dict = {
194 key: bpart_distance(tab, scales[i, 1], scales[i, 0],)
--> 195 for i, (key, tab) in enumerate(table_dict.items())
196 }
197
~/PycharmProjects/deepof/deepof/preprocess.py in <dictcomp>(.0)
193 distance_dict = {
194 key: bpart_distance(tab, scales[i, 1], scales[i, 0],)
--> 195 for i, (key, tab) in enumerate(table_dict.items())
196 }
197
~/PycharmProjects/deepof/deepof/utils.py in bpart_distance(dataframe, arena_abs, arena_rel)
119 dists.append(dist)
120
--> 121 return pd.concat(dists, axis=1)
122
123
~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/pandas/core/reshape/concat.py in concat(objs, axis, join, ignore_index, keys, levels, names, verify_integrity, sort, copy)
279 verify_integrity=verify_integrity,
280 copy=copy,
--> 281 sort=sort,
282 )
283
~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/pandas/core/reshape/concat.py in __init__(self, objs, axis, join, keys, levels, names, ignore_index, verify_integrity, copy, sort)
327
328 if len(objs) == 0:
--> 329 raise ValueError("No objects to concatenate")
330
331 if keys is None:
ValueError: No objects to concatenate
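This failure is downstream of the same path problem: with no tables loaded, `bpart_distance` builds an empty list and `pd.concat` raises. A guard of this shape (hypothetical, not in the codebase) would surface the root cause earlier:

``` python
import pandas as pd

def safe_concat(dists):
    # Hypothetical guard: fail with a clearer message than "No objects to concatenate".
    if not dists:
        raise ValueError("No distance frames computed; check that `path` points at the DLC tables")
    return pd.concat(dists, axis=1)
```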
%% Cell type:code id: tags:
``` python
%%time
DLC_social_2_coords = DLC_social_2.run(verbose=True)
print(DLC_social_2_coords)
type(DLC_social_2_coords)
```
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
ptest = DLC_social_1_coords.get_coords(center="B_Center", polar=False, speed=0, length='00:10:00')
ptest._type
ptest2 = DLC_social_2_coords.get_coords(center="B_Center", polar=False, speed=0, length='00:10:00')
ptest2._type
```
%%%% Output: stream
CPU times: user 558 ms, sys: 47.1 ms, total: 605 ms
Wall time: 566 ms
%%%% Output: execute_result
'coords'
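For reference, centering on B_Center amounts to subtracting that body part's (x, y) from every other body part, frame by frame. A minimal numpy sketch with made-up data (not deepof's implementation):

``` python
import numpy as np

# frames x bodyparts x 2 array of (x, y); bodypart 0 plays the role of B_Center.
xy = np.random.rand(100, 7, 2)
centered = xy - xy[:, [0], :]          # broadcast-subtract the centre per frame
assert np.allclose(centered[:, 0], 0)  # the centre now sits at the origin
```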
%% Cell type:code id: tags:
``` python
ptest['Test 13DLC'].columns.levels
```
%% Cell type:code id: tags:
``` python
%%time
dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')
dtest._type
dtest2 = DLC_social_2_coords.get_distances(speed=0, length='00:10:00')
dtest2._type
```
%% Cell type:code id: tags:
``` python
%%time
atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')
atest._type
atest2 = DLC_social_2_coords.get_angles(degrees=True, speed=0, length='00:10:00')
atest2._type
```
%% Cell type:markdown id: tags:
# Visualization playground
%% Cell type:code id: tags:
``` python
# ptest.plot_heatmaps(['B_Nose'], i=2)
```
%% Cell type:code id: tags:
``` python
ptest['Day2Test13DLC']['B_Nose'].iloc[:5000]
```
%% Cell type:code id: tags:
``` python
#Plot animation of trajectory over time with different smoothings
# plt.plot(ptestb['Day2Test13DLC']['B_Nose'].iloc[:50]['x'],
# ptestb['Day2Test13DLC']['B_Nose'].iloc[:50]['y'], label='alpha=0.95')
# plt.plot(ptestd['Day2Test13DLC']['B_Nose'].iloc[:50]['x'],
# ptestd['Day2Test13DLC']['B_Nose'].iloc[:50]['y'], label='alpha=0.65')
# plt.xlabel('x')
# plt.ylabel('y')
# plt.title('Mouse Center Trajectory using different exponential smoothings')
# plt.legend()
# plt.show()
```
%% Cell type:markdown id: tags:
# Dimensionality reduction playground
%% Cell type:code id: tags:
``` python
#pca = ptest.pca(4, 1000)
```
%% Cell type:code id: tags:
``` python
#plt.scatter(*pca[0].T)
#plt.show()
```
%% Cell type:markdown id: tags:
# Preprocessing playground
%% Cell type:code id: tags:
``` python
mtest = merge_tables(
DLC_social_1_coords.get_coords(center="B_Center", polar=False, length='00:10:00', align='B_Nose')
#DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),
#DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'),
)
```
%% Cell type:code id: tags:
``` python
mtest2 = merge_tables(
DLC_social_2_coords.get_coords(center="B_Center", polar=False, length='00:10:00', align='B_Nose'),
#DLC_social_2_coords.get_distances(speed=0, length='00:10:00'),
#DLC_social_2_coords.get_angles(degrees=True, speed=0, length='00:10:00'),
)
```
%% Cell type:code id: tags:
``` python
%%time
pttest = mtest.preprocess(window_size=13, window_step=10, filter=None, sigma=55,
shift=0, scale='standard', align='center', shuffle=True, test_videos=0)
print(pttest.shape)
#print(pttrain.shape)
```
%% Cell type:code id: tags:
``` python
%%time
pttest2 = mtest2.preprocess(window_size=13, window_step=1, filter=None, sigma=55,
shift=0, scale="standard", align='all', shuffle=False)
pttest2.shape
```
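The preprocess calls cut each trajectory into overlapping windows, which is why `pttest` is indexed as (instance, frame-in-window, feature) in the next cell. A rough numpy equivalent of the windowing step (assumed behaviour, not the library code):

``` python
import numpy as np

def sliding_windows(arr, window_size, window_step):
    # arr: (frames, features) -> (windows, window_size, features)
    starts = range(0, arr.shape[0] - window_size + 1, window_step)
    return np.stack([arr[s:s + window_size] for s in starts])

demo = np.random.rand(1000, 12)             # e.g. 6 body parts x (x, y)
print(sliding_windows(demo, 13, 10).shape)  # -> (99, 13, 12)
```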
%% Cell type:code id: tags:
``` python
n = 100
plt.scatter(pttest[:n,10,0], pttest[:n,10,1], label='Nose')
plt.scatter(pttest[:n,10,2], pttest[:n,10,3], label='Right ear')
plt.scatter(pttest[:n,10,4], pttest[:n,10,5], label='Right hips')
plt.scatter(pttest[:n,10,6], pttest[:n,10,7], label='Left ear')
plt.scatter(pttest[:n,10,8], pttest[:n,10,9], label='Left hips')
plt.scatter(pttest[:n,10,10], pttest[:n,10,11], label='Tail base')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
# Trained models playground
%% Cell type:markdown id: tags:
### Seq 2 seq Variational Auto Encoder
%% Cell type:code id: tags:
``` python
from datetime import datetime
import tensorflow.keras as k
import tensorflow as tf
```
%% Cell type:code id: tags:
``` python
NAME = 'Baseline_AE_512_wu10_slide10_gauss_fullval'
log_dir = os.path.abspath(
"logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
```
%% Cell type:code id: tags:
``` python
from source.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_GMVAE
```
%% Cell type:code id: tags:
``` python
encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()
ae.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
ae.summary()
```
%% Cell type:code id: tags:
``` python
%%time
tf.keras.backend.clear_session()
encoder, generator, grouper, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_GMVAE(pttest.shape,
loss='ELBO',
number_of_components=30,
kl_warmup_epochs=10,
mmd_warmup_epochs=10,
encoding=16,
predictor=False).build()
# gmvaep.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
import tensorflow as tf
from tensorflow import keras as keras
K = tf.keras.backend
class ExponentialLearningRate(tf.keras.callbacks.Callback):
def __init__(self, factor):
self.factor = factor
self.rates = []
self.losses = []
def on_batch_end(self, batch, logs):
self.rates.append(K.get_value(self.model.optimizer.lr))
self.losses.append(logs["loss"])
K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)
def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=10**-5, max_rate=10):
init_weights = model.get_weights()
iterations = len(X) // batch_size * epochs
factor = np.exp(np.log(max_rate / min_rate) / iterations)
init_lr = K.get_value(model.optimizer.lr)
K.set_value(model.optimizer.lr, min_rate)
exp_lr = ExponentialLearningRate(factor)
history = model.fit(X, y, epochs=epochs, batch_size=batch_size,
callbacks=[exp_lr])
K.set_value(model.optimizer.lr, init_lr)
model.set_weights(init_weights)
return exp_lr.rates, exp_lr.losses
def plot_lr_vs_loss(rates, losses):
plt.plot(rates, losses)
plt.gca().set_xscale('log')
plt.hlines(min(losses), min(rates), max(rates))
plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
plt.xlabel("Learning rate")
plt.ylabel("Loss")
```
%% Cell type:code id: tags:
``` python
class OneCycleScheduler(tf.keras.callbacks.Callback):
def __init__(self, iterations, max_rate, start_rate=None,
last_iterations=None, last_rate=None):
self.iterations = iterations
self.max_rate = max_rate
self.start_rate = start_rate or max_rate / 10
self.last_iterations = last_iterations or iterations // 10 + 1
self.half_iteration = (iterations - self.last_iterations) // 2
self.last_rate = last_rate or self.start_rate / 1000
self.iteration = 0
def _interpolate(self, iter1, iter2, rate1, rate2):
return ((rate2 - rate1) * (self.iteration - iter1)
/ (iter2 - iter1) + rate1)
def on_batch_begin(self, batch, logs):
if self.iteration < self.half_iteration:
rate = self._interpolate(0, self.half_iteration, self.start_rate, self.max_rate)
elif self.iteration < 2 * self.half_iteration:
rate = self._interpolate(self.half_iteration, 2 * self.half_iteration,
self.max_rate, self.start_rate)
else:
rate = self._interpolate(2 * self.half_iteration, self.iterations,
self.start_rate, self.last_rate)
rate = max(rate, self.last_rate)
self.iteration += 1
K.set_value(self.model.optimizer.lr, rate)
```
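OneCycleScheduler is defined but never exercised in this notebook. A hedged usage sketch, assuming `gmvaep` is compiled and `pttest` holds the preprocessed windows, with `max_rate` taken from the learning-rate sweep below:

``` python
# Hypothetical training run with the 1cycle policy; n_epochs and max_rate are placeholders.
n_epochs, batch_size = 25, 512
onecycle = OneCycleScheduler(
    iterations=len(pttest) // batch_size * n_epochs, max_rate=0.05
)
history = gmvaep.fit(
    pttest, pttest, epochs=n_epochs, batch_size=batch_size, callbacks=[onecycle]
)
```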
%% Cell type:code id: tags:
``` python
batch_size = 512
rates, losses = find_learning_rate(gmvaep, pttest[:512*10], pttest[:512*10], epochs=1, batch_size=batch_size)
plot_lr_vs_loss(rates, losses)
plt.title("Learning rate tuning")
plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 1.4])
plt.show()
```
%% Cell type:markdown id: tags:
# Encoding plots
%% Cell type:code id: tags:
``` python
import umap
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import plotly.express as px
```
%% Cell type:code id: tags:
``` python
data = pttest
samples = 15000
montecarlo = 10
```
%% Cell type:code id: tags:
``` python
weights = "GMVAE_components=30_loss=ELBO_kl_warmup=30_mmd_warmup=30_20200804-225526_final_weights.h5"
gmvaep.load_weights(weights)
if montecarlo:
clusts = np.stack([grouper(data[:samples]) for sample in (tqdm(range(montecarlo)))])
clusters = clusts.mean(axis=0)
clusters = np.argmax(clusters, axis=1)
else:
clusters = grouper(data[:samples], training=False)
clusters = np.argmax(clusters, axis=1)
```
%% Cell type:code id: tags:
``` python
def plot_encodings(data, samples, n, clusters, threshold):
reducer = PCA(n_components=n)
clusters = clusters[:, :samples]
filter = np.max(np.mean(clusters, axis=0), axis=1) > threshold
encoder.predict(data[:samples][filter])
print("{}/{} samples used ({}%); confidence threshold={}".format(sum(filter),
samples,
sum(filter)/samples*100,
threshold))
clusters = np.argmax(np.mean(clusters, axis=0), axis=1)[filter]
rep = reducer.fit_transform(encoder.predict(data[:samples][filter]))
if n == 2:
df = pd.DataFrame({"encoding-1":rep[:,0],"encoding-2":rep[:,1],"clusters":["A"+str(i) for i in clusters]})
enc = px.scatter(data_frame=df, x="encoding-1", y="encoding-2",
color="clusters", width=600, height=600,
color_discrete_sequence=px.colors.qualitative.T10)
elif n == 3:
df3d = pd.DataFrame({"encoding-1":rep[:,0],"encoding-2":rep[:,1],"encoding-3":rep[:,2],
"clusters":["A"+str(i) for i in clusters]})
enc = px.scatter_3d(data_frame=df3d, x="encoding-1", y="encoding-2", z="encoding-3",
color="clusters", width=600, height=600,
color_discrete_sequence=px.colors.qualitative.T10)
return enc
plot_encodings(data, 5000, 2, clusts, 0.5)
```
%% Cell type:markdown id: tags:
# Confidence per cluster
%% Cell type:code id: tags:
``` python
from collections import Counter
Counter(clusters)
```
%% Cell type:code id: tags:
``` python
# Confidence distribution per cluster
for cl in range(5):
cl_select = np.argmax(np.mean(clusts, axis=0), axis=1) == cl
dt = np.mean(clusts[:,cl_select,cl], axis=0)
sns.kdeplot(dt, shade=True, label=cl)
plt.xlabel('MC Dropout confidence')
plt.ylabel('Density')
plt.show()
```
%% Cell type:code id: tags:
``` python
def animated_cluster_heatmap(data, clust, clusters, threshold=0.75, samples=False):
if not samples:
samples = data.shape[0]
tpoints = data.shape[1]
bdparts = data.shape[2] // 2
cls = clusters[:,:samples,:]
filt = np.max(np.mean(cls, axis=0), axis=1) > threshold
cls = np.argmax(np.mean(cls, axis=0), axis=1)[filt]
clust_series = data[:samples][filt][cls==clust]
rshape = clust_series.reshape(clust_series.shape[0]*clust_series.shape[1],
clust_series.shape[2])
cluster_df = pd.DataFrame()
cluster_df['x'] = rshape[:,[0,2,4,6,8,10]].flatten(order='F')
cluster_df['y'] = rshape[:,[1,3,5,7,9,11]].flatten(order='F')
cluster_df['bpart'] = np.tile(np.repeat(np.arange(bdparts),
clust_series.shape[0]), tpoints)
cluster_df['frame'] = np.tile(np.repeat(np.arange(tpoints),
clust_series.shape[0]), bdparts)
fig = px.density_contour(data_frame=cluster_df, x='x', y='y', animation_frame='frame',
width=600, height=600,
color='bpart',color_discrete_sequence=px.colors.qualitative.T10)
fig.update_traces(contours_coloring="fill",
contours_showlabels = True)
fig.update_xaxes(range=[-3, 3])
fig.update_yaxes(range=[-3, 3])
return fig
```
%% Cell type:code id: tags:
``` python
# animated_cluster_heatmap(pttest, 4, clusts, samples=10)
```
%% Cell type:markdown id: tags:
# Stability across runs
%% Cell type:code id: tags:
``` python
weights = [i for i in os.listdir() if "GMVAE" in i and ".h5" in i]
mult_clusters = np.zeros([len(weights), samples])
mean_conf = []
for k,i in tqdm(enumerate(sorted(weights))):
print(i)
gmvaep.load_weights(i)
if montecarlo:
clusters = np.stack([grouper(data[:samples]) for sample in (tqdm(range(montecarlo)))])
clusters = clusters.mean(axis=0)
mean_conf.append(clusters.max(axis=1))
clusters = np.argmax(clusters, axis=1)
else:
clusters = grouper(data[:samples], training=False)
mean_conf.append(clusters.max(axis=1))
clusters = np.argmax(clusters, axis=1)
mult_clusters[k] = clusters
```
%% Cell type:code id: tags:
``` python
clusts.shape
```
%% Cell type:code id: tags:
``` python
import pandas as pd
from itertools import combinations
from sklearn.metrics import adjusted_rand_score
```
%% Cell type:code id: tags:
``` python
mult_clusters
```
%% Cell type:code id: tags:
``` python
thr = 0.95
ari_dist = []
for i,k in enumerate(combinations(range(len(weights)),2)):
filt = ((mean_conf[k[0]] > thr) & (mean_conf[k[1]]>thr))
ari = adjusted_rand_score(mult_clusters[k[0]][filt],
mult_clusters[k[1]][filt])
ari_dist.append(ari)
```
%% Cell type:code id: tags:
``` python
ari_dist
```
%% Cell type:code id: tags:
``` python
random_ari = []
for i in tqdm(range(6)):
random_ari.append(adjusted_rand_score(np.random.uniform(0,6,50).astype(int),
np.random.uniform(0,6,50).astype(int)))
```
%% Cell type:code id: tags:
``` python
sns.kdeplot(ari_dist, label="ARI gmvaep", shade=True)
sns.kdeplot(random_ari, label="ARI random", shade=True)
plt.xlabel("Normalised Adjusted Rand Index")
plt.ylabel("Density")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
# Cluster differences across conditions
%% Cell type:code id: tags:
``` python
%%time
DLCS1_coords = DLC_social_1_coords.get_coords(center="B_Center",polar=False, length='00:10:00', align='B_Nose')
Treatment_coords = {}
for cond in Treatment_dict.keys():
Treatment_coords[cond] = DLCS1_coords.filter(Treatment_dict[cond]).preprocess(window_size=13,
window_step=10, filter=None, scale='standard', align='center')
```
%% Cell type:code id: tags:
``` python
%%time
montecarlo = 10
Predictions_per_cond = {}
Confidences_per_cond = {}
for cond in Treatment_dict.keys():
Predictions_per_cond[cond] = np.stack([grouper(Treatment_coords[cond]
) for sample in (tqdm(range(montecarlo)))])
Confidences_per_cond[cond] = np.mean(Predictions_per_cond[cond], axis=0)
Predictions_per_cond[cond] = np.argmax(Confidences_per_cond[cond], axis=1)
```
%% Cell type:code id: tags:
``` python
Predictions_per_condition = {k:{cl:[] for cl in range(1,31)} for k in Treatment_dict.keys()}
for k in Predictions_per_cond.values():
print(Counter(k))
```
%% Cell type:code id: tags:
``` python
for cond in Treatment_dict.keys():
start = 0
for i,j in enumerate(DLCS1_coords.filter(Treatment_dict[cond]).values()):
update = start + j.shape[0]//10
counter = Counter(Predictions_per_cond[cond][start:update])
start += j.shape[0]//10
for num in counter.keys():
Predictions_per_condition[cond][num+1].append(counter[num+1])
```
%% Cell type:code id: tags:
``` python
counts = []
clusters = []
conditions = []
for cond,v in Predictions_per_condition.items():
for cluster,i in v.items():
counts+=i
clusters+=list(np.repeat(cluster, len(i)))
conditions+=list(np.repeat(cond, len(i)))
Prediction_per_cond_df = pd.DataFrame({'condition':conditions,
'cluster':clusters,
'count':counts})
```
%% Cell type:code id: tags:
``` python
px.box(data_frame=Prediction_per_cond_df, x='cluster', y='count', color='condition')
```
%% Cell type:markdown id: tags:
# Others
%% Cell type:code id: tags:
``` python
for i in range(5):
print(Counter(labels[str(i)]))
```
%% Cell type:code id: tags:
``` python
adjusted_rand_score(labels[0], labels[3])
```
%% Cell type:code id: tags:
``` python
sns.distplot(ari_dist)
plt.xlabel("Adjusted Rand Index")
plt.ylabel("Count")
plt.show()
```
%% Cell type:code id: tags:
``` python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
```
%% Cell type:code id: tags:
``` python
from scipy.stats import entropy
```
%% Cell type:code id: tags:
``` python
entropy(np.array([0.5,0,0.5,0]))
```
%% Cell type:code id: tags:
``` python
tfd.Categorical(np.array([0.5,0.5,0.5,0.5])).entropy()
```
%% Cell type:code id: tags:
``` python
pk = np.array([0.5,0,0.5,0])
```
%% Cell type:code id: tags:
``` python
np.log(pk)
```
%% Cell type:code id: tags:
``` python
np.clip(np.log(pk), 0, 1)
```
%% Cell type:code id: tags:
``` python
-np.sum(pk*np.array([-0.69314718, 0, -0.69314718, 0]))
```
%% Cell type:code id: tags:
``` python
import tensorflow.keras.backend as K
entropy = K.sum(tf.multiply(pk, tf.where(~tf.math.is_inf(K.log(pk)), K.log(pk), 0)), axis=0)
entropy
```
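Note that `K.sum(pk * log(pk))` is the negative of the Shannon entropy, so the cell above yields approximately -0.693 rather than the 0.693 returned by `scipy.stats.entropy`. Negating the sum reconciles the two:

``` python
# Same masked computation with the conventional sign (-sum p log p):
entropy = -K.sum(tf.multiply(pk, tf.where(~tf.math.is_inf(K.log(pk)), K.log(pk), 0)), axis=0)
entropy  # ~0.6931, matching scipy.stats.entropy([0.5, 0, 0.5, 0])
```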
%% Cell type:code id: tags:
``` python
sns.distplot(np.max(clusts, axis=1))
sns.distplot(clusts.reshape(clusts.shape[0] * clusts.shape[1]))
plt.axvline(1/10)
plt.show()
```
%% Cell type:code id: tags:
``` python
gauss_means = gmvaep.get_layer(name="dense_4").get_weights()[0][:32]
gauss_variances = tf.keras.activations.softplus(gmvaep.get_layer(name="dense_4").get_weights()[0][32:]).numpy()
```
%% Cell type:code id: tags:
``` python
gauss_means.shape == gauss_variances.shape
```
%% Cell type:code id: tags:
``` python
k=10
n=100
samples = []
for i in range(k):
samples.append(np.random.normal(gauss_means[:,i], gauss_variances[:,i], size=(100,32)))
```
%% Cell type:code id: tags:
``` python
from scipy.stats import ttest_ind
test_matrix = np.zeros([k,k])
for i in range(k):
for j in range(k):
test_matrix[i][j] = np.mean(ttest_ind(samples[i], samples[j], equal_var=False)[1])
```
%% Cell type:code id: tags:
``` python
threshold = 0.55
np.sum(test_matrix > threshold)
```
%% Cell type:code id: tags:
``` python
# Transition matrix
```
%% Cell type:code id: tags:
``` python
Treatment_dict
```
%% Cell type:code id: tags:
``` python
# Anomaly detection - the model was trained in the WT - NS mice alone
gmvaep.load_weights("GMVAE_components=10_loss=ELBO_kl_warmup=20_mmd_warmup=5_20200721-043310_final_weights.h5")
```
%% Cell type:code id: tags:
``` python
WT_NS = table_dict({k:v for k,v in mtest2.items() if k in Treatment_dict['WT+NS']}, typ="coords")
WT_WS = table_dict({k:v for k,v in mtest2.items() if k in Treatment_dict['WT+CSDS']}, typ="coords")
MU_NS = table_dict({k:v for k,v in mtest2.items() if k in Treatment_dict['NatCre+NS']}, typ="coords")
MU_WS = table_dict({k:v for k,v in mtest2.items() if k in Treatment_dict['NatCre+CSDS']}, typ="coords")
preps = [WT_NS.preprocess(window_size=11, window_step=10, filter="gaussian", sigma=55,shift=0, scale="standard", align=True),
WT_WS.preprocess(window_size=11, window_step=10, filter="gaussian", sigma=55,shift=0, scale="standard", align=True),
MU_NS.preprocess(window_size=11, window_step=10, filter="gaussian", sigma=55,shift=0, scale="standard", align=True),
MU_WS.preprocess(window_size=11, window_step=10, filter="gaussian", sigma=55,shift=0, scale="standard", align=True)]
```
%% Cell type:code id: tags:
``` python
preds = [gmvaep.predict(i) for i in preps]
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
from sklearn.metrics import mean_absolute_error
reconst_error = {k:mean_absolute_error(preps[i].reshape(preps[i].shape[0]*preps[i].shape[1],12).T,
preds[i].reshape(preds[i].shape[0]*preds[i].shape[1],12).T,
multioutput='raw_values') for i,k in enumerate(Treatment_dict.keys())}
reconst_error
```
%% Cell type:code id: tags:
``` python
reconst_df = pd.concat([pd.DataFrame(np.concatenate([np.repeat(k, len(v)).reshape(len(v),1), v.reshape(len(v),1)],axis=1)) for k,v in reconst_error.items()])
reconst_df = reconst_df.astype({0:str,1:float})
```
%% Cell type:code id: tags:
``` python
sns.boxplot(data=reconst_df, x=0, y=1, orient='vertical')
plt.ylabel('Mean Absolute Error')
plt.ylim(0,0.35)
plt.show()
```
%% Cell type:code id: tags:
``` python
# Check frame rates
```
@@ -347,7 +347,7 @@ def test_rule_based_tagging():
         arena_dims=tuple([380]),
         video_format=".mp4",
         table_format=".h5",
-        animal_ids=None,
+        animal_ids=[""],
     ).run(verbose=True)
     hardcoded_tags = rule_based_tagging(
@@ -356,8 +356,6 @@ def test_rule_based_tagging():
         prun,
         vid_index=0,
         path=os.path.join(".", "tests", "test_examples", "Videos"),
-        mode="save",
-        frame_limit=100,
     )
assert type(hardcoded_tags) == pd.DataFrame