Commit 545d089a authored by lucas_miranda's avatar lucas_miranda
Browse files

Changed all type() checks for isinstance() to take inheritance into account

parent 48ef8ea8
......@@ -40,8 +40,8 @@ import warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# DEFINE CUSTOM ANNOTATED TYPES #
Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)
# Annotation-only aliases. The factory is `NewType`; `Newisinstance` does not exist
# (it was produced by a blanket type -> isinstance text replacement).
Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)
# CLASSES FOR PREPROCESSING AND DATA WRANGLING
......@@ -549,7 +549,7 @@ class coordinates:
- self._scales[i][1] / 2
)
elif type(center) == str and center != "arena":
elif isinstance(center, str) and center != "arena":
for i, (key, value) in enumerate(tabs.items()):
......@@ -583,7 +583,7 @@ class coordinates:
for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype("timedelta64[s]")
).asisinstance("timedelta64[s]")
if align:
assert (
......@@ -667,7 +667,7 @@ class coordinates:
for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype("timedelta64[s]")
).asisinstance("timedelta64[s]")
if propagate_labels:
for key, tab in tabs.items():
......@@ -732,7 +732,7 @@ class coordinates:
for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype("timedelta64[s]")
).asisinstance("timedelta64[s]")
if propagate_labels:
for key, tab in tabs.items():
......@@ -833,7 +833,7 @@ class coordinates:
)
pbar.update(1)
if type(video_output) == list:
if isinstance(video_output, list):
vid_idxs = video_output
elif video_output == "all":
vid_idxs = list(self._tables.keys())
......
......@@ -25,7 +25,7 @@ import warnings
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
# Create custom string type
Coordinates = NewType("Coordinates", Any)
# Annotation-only alias built with typing.NewType (`Newisinstance` is not a real name).
Coordinates = NewType("Coordinates", Any)
def close_single_contact(
......@@ -53,12 +53,12 @@ def close_single_contact(
close_contact = None
if type(right) == str:
if isinstance(right, str):
close_contact = (
np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
) / arena_rel < tol
elif type(right) == list:
elif isinstance(right, list):
close_contact = np.any(
[
(np.linalg.norm(pos_dframe[left] - pos_dframe[r], axis=1) * arena_abs)
......@@ -528,7 +528,7 @@ def max_behaviour(
speeds = [col for col in behaviour_dframe.columns if "speed" in col.lower()]
behaviour_dframe = behaviour_dframe.drop(speeds, axis=1).astype("float")
# Drop the speed columns and cast what is left to float (DataFrame.astype, not .asisinstance).
behaviour_dframe = behaviour_dframe.drop(speeds, axis=1).astype("float")
win_array = behaviour_dframe.rolling(window_size, center=True).sum()
if stepped:
win_array = win_array[::window_size]
......@@ -678,8 +678,8 @@ def rule_based_tagging(
return deepof.utils.smooth_boolean_array(
close_single_contact(
coords,
(left if type(left) != list else right),
(right if type(left) != list else left),
(left if not isinstance(left, list) else right),
(right if not isinstance(left, list) else left),
params["close_contact_tol"],
arena_abs,
arena[1][1],
......
......@@ -26,7 +26,7 @@ from typing import Tuple, Any, List, Union, NewType
# DEFINE CUSTOM ANNOTATED TYPES #
Coordinates = NewType("Coordinates", Any)
# Annotation-only alias built with typing.NewType (`Newisinstance` is not a real name).
Coordinates = NewType("Coordinates", Any)
# CONNECTIVITY FOR DLC MODELS
......@@ -750,7 +750,7 @@ def cluster_transition_matrix(
# Stores all possible transitions between clusters
clusters = [str(i) for i in range(nclusts)]
cluster_sequence = cluster_sequence.astype(str)
# Cast the labels to str so they compare equal to the string cluster names built above.
cluster_sequence = cluster_sequence.astype(str)
trans = {t: 0 for t in product(clusters, clusters)}
k = len(clusters)
......
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
import os
# Move the working directory one level up; os.path.dirname("../") evaluates to "..".
parent_dir = os.path.dirname("../")
os.chdir(parent_dir)
```
%% Cell type:code id: tags:
``` python
import cv2
import deepof.data
import deepof.models
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import pandas as pd
import re
import seaborn as sns
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import tensorflow as tf
import tqdm.notebook as tqdm
from ipywidgets import interact
```
%% Cell type:code id: tags:
``` python
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
```
%% Cell type:code id: tags:
``` python
import umap
```
%% Cell type:markdown id: tags:
# Retrieve phenotypes
%% Cell type:code id: tags:
``` python
def flatten(t):
    """Flatten one level of nesting: return a list with the items of every sub-iterable of `t`, in order."""
    return [item for sublist in t for item in sublist]
```
%% Cell type:code id: tags:
``` python
# Load first batch
# Hand the paths straight to read_excel so pandas opens and closes each workbook itself;
# the previous pd.ExcelFile handles were never closed.
dset11 = pd.read_excel(
    "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_1/DLC_single_CDR1_1/1.Openfield_data-part1/JB05.1-OF-SI-part1.xlsx",
    sheet_name="Tabelle2",
)
dset12 = pd.read_excel(
    "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_1/DLC_single_CDR1_1/2.Openfielddata-part2/AnimalID's-JB05.1-part2.xlsx",
    sheet_name="Tabelle2",
)

# Append the batch/part suffix to every test ID.
dset11.Test = dset11.Test.apply(lambda x: "Test {}_s11".format(x))
dset12.Test = dset12.Test.apply(lambda x: "Test {}_s12".format(x))

# Condition label -> list of test IDs, pooled over both parts.
dset1 = {
    "CSDS": list(dset11.loc[dset11.Treatment.isin(["CTR+CSDS", "NatCre+CSDS"]), "Test"])
    + list(dset12.loc[dset12.Treatment.isin(["CTR+CSDS", "NatCre+CSDS"]), "Test"]),
    "NS": list(dset11.loc[dset11.Treatment.isin(["CTR+nonstressed", "NatCre+nonstressed"]), "Test"])
    + list(dset12.loc[dset12.Treatment.isin(["CTR+nonstressed", "NatCre+nonstressed"]), "Test"]),
}

# Inverted view: test ID -> condition label.
dset1inv = {
    test_id: ("CSDS" if test_id in dset1["CSDS"] else "NS")
    for test_id in flatten(dset1.values())
}

# Every spreadsheet row must have produced exactly one labelled test ID.
assert len(dset1inv) == dset11.shape[0] + dset12.shape[0], "You missed some labels!"
```
%% Cell type:code id: tags:
``` python
# Load second batch
dset21 = pd.read_excel(
    "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part1/2_Single/stressproject22.04.2020genotypes-openfieldday1.xlsx"
)
dset22 = pd.read_excel(
    "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part2/2_Single/OpenFieldvideos-part2.xlsx"
)

# Append the batch/part suffix to every test ID.
dset21.Test = dset21.Test.apply(lambda test: "Test {}_s21".format(test))
dset22.Test = dset22.Test.apply(lambda test: "Test {}_s22".format(test))

# Condition label -> list of test IDs. Part 1 selects "Stress", part 2 selects "Stressed".
dset2 = {
    "CSDS": [
        *dset21.loc[dset21.Treatment == "Stress", "Test"],
        *dset22.loc[dset22.Treatment == "Stressed", "Test"],
    ],
    "NS": [
        *dset21.loc[dset21.Treatment == "Nonstressed", "Test"],
        *dset22.loc[dset22.Treatment == "Nonstressed", "Test"],
    ],
}

# Inverted view: test ID -> condition label.
dset2inv = {
    test_id: ("CSDS" if test_id in dset2["CSDS"] else "NS")
    for test_id in flatten(list(dset2.values()))
}

assert len(dset2inv) == dset21.shape[0] + dset22.shape[0], "You missed some labels!"
```
%% Cell type:code id: tags:
``` python
# Load third batch (second sheet of each workbook: sheet_name=1)
dset31 = pd.read_excel(
    "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/1.Day2OF-SIpart1/JB05 2Female-ELS-OF-SIpart1.xlsx",
    sheet_name=1,
)
dset32 = pd.read_excel(
    "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/2.Day3OF-SIpart2/JB05 2FEMALE-ELS-OF-SIpart2.xlsx",
    sheet_name=1,
)

# Append the batch/part suffix to every test ID.
dset31.Test = dset31.Test.apply(lambda test: "Test {}_s31".format(test))
dset32.Test = dset32.Test.apply(lambda test: "Test {}_s32".format(test))

# Every test in this batch is labelled "NS"; the "CSDS" list stays empty.
dset3 = {
    "CSDS": [],
    "NS": [*dset31.loc[:, "Test"], *dset32.loc[:, "Test"]],
}

# Inverted view: test ID -> condition label.
dset3inv = {
    test_id: ("CSDS" if test_id in dset3["CSDS"] else "NS")
    for test_id in flatten(list(dset3.values()))
}

assert len(dset3inv) == dset31.shape[0] + dset32.shape[0], "You missed some labels!"
```
%% Cell type:code id: tags:
``` python
# Load fourth batch: labels come from the file names in the video folder.
dset41 = [
    vid
    for vid in os.listdir(
        "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_4/JB05.4-OpenFieldvideos/"
    )
    # Remove empty video! (any file whose name contains "52")
    if "52" not in vid
]

# Strip the last four characters of each file name and append the batch suffix.
dset4 = {
    "CSDS": [],
    "NS": [vid[:-4] + "_s41" for vid in dset41],
}

# Inverted view: test ID -> condition label.
dset4inv = {
    test_id: ("CSDS" if test_id in dset4["CSDS"] else "NS")
    for test_id in flatten(list(dset4.values()))
}

assert len(dset4inv) == len(dset41), "You missed some labels!"
```
%% Cell type:code id: tags:
``` python
# Merge phenotype dicts and serialise!
# Batches are applied in order, so on a duplicate key the later batch's label wins.
aggregated_dset = {}
for batch_labels in (dset1inv, dset2inv, dset3inv, dset4inv):
    aggregated_dset.update(batch_labels)
```
%% Cell type:code id: tags:
``` python
from collections import Counter

# Number of videos per condition label.
condition_counts = Counter(aggregated_dset.values())
print(condition_counts)
# Total derived from the data instead of a hand-typed sum, so it cannot go stale.
print(sum(condition_counts.values()))
```
%%%% Output: stream
Counter({'NS': 115, 'CSDS': 52})
167
%% Cell type:code id: tags:
``` python
# Save aggregated dataset to disk
import pickle

conditions_path = "../../Desktop/deepof-data/deepof_single_topview/deepof_exp_conditions.pkl"
with open(conditions_path, "wb") as handle:
    # Serialise the test-ID -> condition mapping with the newest pickle protocol available.
    pickle.dump(aggregated_dset, handle, protocol=pickle.HIGHEST_PROTOCOL)
```
%% Cell type:markdown id: tags:
# Define and run project
%% Cell type:code id: tags:
``` python
%%time
# Configure the deepof project; nothing is loaded until .run() is called in a later cell.
project_path = os.path.join("..", "..", "Desktop", "deepoftesttemp")
deepof_main = deepof.data.project(
    path=project_path,
    smooth_alpha=0.99,
    arena_dims=[380],
    exclude_bodyparts=["Tail_1", "Tail_2", "Tail_tip", "Tail_base"],
    exp_conditions=aggregated_dset,
)
```
%%%% Output: stream
CPU times: user 111 ms, sys: 14 ms, total: 125 ms
Wall time: 123 ms
%% Cell type:code id: tags:
``` python
%%time
# Run the project pipeline and rebind the name to whatever .run() returns
# (the recorded output prints it as "Coordinates of 2 videos across 2 conditions").
deepof_main = deepof_main.run(verbose=True)
print(deepof_main)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Interpolating outliers...
Iterative imputation of ocluded bodyparts...
Computing distances...
Computing angles...
Done!
Coordinates of 2 videos across 2 conditions
CPU times: user 4.8 s, sys: 806 ms, total: 5.61 s
Wall time: 4.32 s
%% Cell type:code id: tags:
``` python
# Stack the per-video quality tables into a single frame.
quality_tables = list(deepof_main.get_quality().values())
all_quality = pd.concat(quality_tables)
```
%% Cell type:code id: tags:
``` python
# One box per column of the stacked quality table, with x tick labels rotated 45 degrees.
all_quality.boxplot(rot=45)
# Zoom the y-axis into the narrow band around 1.0.
plt.ylim(0.99985, 1.00001)
plt.show()
```
%%%% Output: display_data
![]()
%% Cell type:code id: tags:
``` python
@interact(quality_top=(0., 1., 0.01))
def low_quality_tags(quality_top):
    """Bar-plot, per body part, the percentage of values in `all_quality` below `quality_top`.

    `quality_top` is driven by the ipywidgets slider declared in the decorator
    (range 0 to 1, step 0.01). The function draws a figure and returns nothing.
    """

    def percent_below(scores):
        # Share of entries under the threshold, expressed as a percentage.
        return sum(scores < quality_top) / len(scores) * 100

    long_quality = pd.melt(all_quality)
    percent_per_part = pd.DataFrame(
        long_quality.groupby("bodyparts").value.apply(percent_below)
    )
    ranked = percent_per_part.sort_values(by="value", ascending=False)
    ranked.plot.bar(rot=45)

    plt.xlabel("body part")
    plt.ylabel("Tags with quality under {} (%)".format(quality_top * 100))
    plt.tight_layout()
    plt.legend([])
    plt.show()
```
%%%% Output: display_data
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
# Options forwarded unchanged to get_coords.
coords_options = dict(
    center="Center",
    polar=False,
    speed=0,
    align="Spine_1",
    align_inplace=True,
    propagate_labels=False,
)
deepof_coords = deepof_main.get_coords(**coords_options)
#deepof_dists = deepof_main.get_distances(propagate_labels=False)
#deepof_angles = deepof_main.get_angles(propagate_labels=False)
```
%%%% Output: stream
CPU times: user 624 ms, sys: 27 ms, total: 651 ms
Wall time: 662 ms
%% Cell type:markdown id: tags:
# Visualization
%% Cell type:code id: tags:
``` python
%%time
# Reset Keras global state before building anything new in this session.
tf.keras.backend.clear_session()

print("Preprocessing training set...")
# Options forwarded unchanged to preprocess; only the first element of its result is kept.
preprocess_options = dict(
    window_size=24,
    window_step=24,
    conv_filter=None,
    scale="standard",
    shuffle=False,
    test_videos=0,
)
deepof_train = deepof_coords.preprocess(**preprocess_options)[0]
# print("Loading pre-trained model...")
# encoder, decoder, grouper, gmvaep, = deepof.models.SEQ_2_SEQ_GMVAE(
# loss="ELBO",
# number_of_components=20,
# compile_model=True,
# kl_warmup_epochs=20,
# montecarlo_kl=10,
# encoding=6,
# mmd_warmup_epochs=20,
# predictor=0,
# phenotype_prediction=0,
# ).build(deepof_train.shape)[:4]
```
%%%% Output: stream
Preprocessing training set...
CPU times: user 18.1 ms, sys: 13 ms, total: 31.1 ms
Wall time: 37.4 ms
%% Cell type:code id: tags:
``` python
# Collect the weight files whose name contains "encoding=8".
# The os.listdir order is kept as-is because a later cell picks an entry by position.
weights_dir = "./latreg_trained_weights/"
weights = [
    weights_dir + fname
    for fname in os.listdir(weights_dir)
    if "encoding=8" in fname
]
weights
```
%%%% Output: execute_result
['./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=none_20210212-021944_final_weights.h5',
'./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=categorical_20210212-031749_final_weights.h5',
'./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=categorical+variance_20210212-022008_final_weights.h5',
'./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=variance_20210212-023839_final_weights.h5']
%% Cell type:code id: tags:
``` python
trained_network = weights[2]
print(trained_network)

# Recover the hyper-parameters encoded in the weights file name.
# Raw strings keep `\d` a regex digit class instead of an invalid string escape.
l = int(re.findall(r"encoding=(\d+)_", trained_network)[0])
k = int(re.findall(r"k=(\d+)_", trained_network)[0])
pheno = 0

# Rebuild the model with matching settings; the latent-regularisation flags are
# switched on according to the keywords present in the file name.
encoder, decoder, grouper, gmvaep, = deepof.models.SEQ_2_SEQ_GMVAE(
    loss="ELBO",
    number_of_components=k,
    compile_model=True,
    kl_warmup_epochs=20,
    montecarlo_kl=10,
    encoding=l,
    mmd_warmup_epochs=20,
    predictor=0,
    phenotype_prediction=pheno,
    reg_cat_clusters=("categorical" in trained_network),
    reg_cluster_variance=("variance" in trained_network),
).build(deepof_train.shape)[:4]

gmvaep.load_weights(trained_network)
```
%%%% Output: stream
./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=categorical+variance_20210212-022008_final_weights.h5
%% Cell type:code id: tags:
``` python
# Get data to pass through the models
trained_distribution = encoder(deepof_train)
# Output of one intermediate encoder layer, evaluated on the training data.
# NOTE(review): index 15 is presumably the cluster-assignment layer — confirm against the model definition.
categories = tf.keras.models.Model(encoder.input, encoder.layers[15].output)(deepof_train).numpy()
# Fit a scaler to unscale the reconstructions later on
# Draw one random key of deepof_coords (not used again in the cells visible here).
video_key = np.random.choice(list(deepof_coords.keys()), 1)[0]
scaler = StandardScaler()
# Fit on all coordinate tables concatenated together.
scaler.fit(np.array(pd.concat(list(deepof_coords.values()))))
```
%%%% Output: execute_result
StandardScaler()
%% Cell type:code id: tags:
``` python
# Retrieve latent distribution parameters and sample from posterior
def get_median_params(component, categories, cluster, param):
    """Median of one mixture-component parameter over the samples assigned to `cluster`.

    Args:
        component: object exposing ``mean()`` and ``stddev()``, each returning a
            value with a ``.numpy()`` method that yields an array with one row per sample.
        categories: 2-D array-like of per-sample cluster scores; a sample is assigned
            to the cluster with the highest score in its row.
        cluster: index of the cluster whose samples are selected.
        param: ``"mean"`` or ``"stddev"``, the parameter to summarise.

    Returns:
        The median over axis 0 of the selected rows, or ``None`` when no sample is
        assigned to `cluster`.

    Raises:
        ValueError: if `param` is neither ``"mean"`` nor ``"stddev"``.
    """
    if param == "mean":
        values = component.mean().numpy()
    elif param == "stddev":
        values = component.stddev().numpy()
    else:
        # Fail loudly instead of carrying the raw distribution object forward.
        raise ValueError(
            "param must be 'mean' or 'stddev', got {!r}".format(param)
        )

    assigned = np.argmax(categories, axis=1) == cluster
    if not assigned.any():
        return None

    return np.median(values[assigned], axis=0)
```
%% Cell type:code id: tags:
``` python
def retrieve_latent_parameters(distribution, reduce=False, plot=False, categories=None, filt=0, save=True):
    """Summarise each mixture component by its median mean / stddev and optionally plot them.

    Args:
        distribution: object with a ``components`` sequence; each component is passed
            to ``get_median_params``.
        reduce: when true, draw 1000 normal samples per kept component, project them to
            3 dimensions with PCA and (depending on `plot`) draw them.
        plot: 2 for a 2-D scatter, 3 for a 3-D scatter; values above 3 raise. Only
            consulted when `reduce` is true.
        categories: 2-D array-like of per-sample cluster scores (samples x components).
            Required.
        filt: a component is kept only if its highest score over all samples exceeds
            this threshold.
        save: when plotting, also save the figure as a PNG in the working directory.

    Returns:
        ``(means, stddevs)``: two lists with one median vector per kept component.

    Raises:
        ValueError: if `categories` is missing, or if `reduce` is true and `plot` > 3.
    """
    if categories is None:
        raise ValueError("categories is required to assign samples to mixture components")

    mix_components = distribution.components

    # Median parameters over the samples assigned to each component (None if it has none).
    means = [get_median_params(component, categories, i, "mean") for i, component in enumerate(mix_components)]
    stddevs = [get_median_params(component, categories, i, "stddev") for i, component in enumerate(mix_components)]

    # Decide what to keep per component index BEFORE dropping anything, so the
    # threshold mask stays aligned with the component it belongs to.
    above_threshold = np.max(categories, axis=0) > filt
    kept = [
        i
        for i, (mean_vec, passes) in enumerate(zip(means, above_threshold))
        if mean_vec is not None and passes
    ]
    means = [means[i] for i in kept]
    stddevs = [stddevs[i] for i in kept]

    # Nothing survived the filter: there is nothing to reduce or plot.
    if reduce and means:
        n_kept, latent_dim = len(means), len(means[0])

        data = [
            np.random.normal(size=[1000, latent_dim], loc=meanvec, scale=stddevvec)[:, np.newaxis]
            for meanvec, stddevvec in zip(means, stddevs)
        ]
        data = np.concatenate(data, axis=1).reshape([1000 * n_kept, latent_dim])

        reducer = PCA(n_components=3)
        data = reducer.fit_transform(data)
        data = data.reshape([1000, n_kept, 3])

        title = "Mean representation of latent space - K={}/{} - L={} - filt={}".format(
            n_kept, len(mix_components), latent_dim, filt
        )
        figure_file = "Mean representation of latent space - K={}.{} - L={} - filt={}.png".format(
            n_kept, len(mix_components), latent_dim, filt
        ).replace(" ", "_")

        if plot == 2:
            for i in range(n_kept):
                plt.scatter(data[:, i, 0], data[:, i, 1], label=i)
            plt.title(title)
            plt.xlabel("PCA 1")
            plt.ylabel("PCA 2")
            if save:
                plt.savefig(figure_file)
            plt.show()

        elif plot == 3:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            for i in range(n_kept):
                ax.scatter(data[:, i, 0], data[:, i, 1], data[:, i, 2], label=i)
            plt.title(title)
            ax.set_xlabel("PCA 1")
            ax.set_ylabel("PCA 2")
            ax.set_zlabel("PCA 3")
            if save:
                plt.savefig(figure_file)
            plt.show()

        elif plot > 3:
            raise ValueError("Can't plot in more than 3 dimensions!")

    return means, stddevs
def sample_from_posterior(decoder, parameters, component, enable_variance=False, video_output=False, samples=1):
means, stddevs = parameters
sample = np.random.normal(size=[samples, len(means[component])], loc=means[component], scale=(stddevs[component] if enable_variance else 0))