Commit 545d089a authored by lucas_miranda's avatar lucas_miranda
Browse files

Changed all type() checks for isinstance() to take inheritance into account

parent 48ef8ea8
...@@ -40,8 +40,8 @@ import warnings ...@@ -40,8 +40,8 @@ import warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# DEFINE CUSTOM ANNOTATED TYPES # # DEFINE CUSTOM ANNOTATED TYPES #
Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any) Coordinates = deepof.utils.Newisinstance("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any) Table_dict = deepof.utils.Newisinstance("Table_dict", deepof.utils.Any)
# CLASSES FOR PREPROCESSING AND DATA WRANGLING # CLASSES FOR PREPROCESSING AND DATA WRANGLING
...@@ -549,7 +549,7 @@ class coordinates: ...@@ -549,7 +549,7 @@ class coordinates:
- self._scales[i][1] / 2 - self._scales[i][1] / 2
) )
elif type(center) == str and center != "arena": elif isinstance(center, str) and center != "arena":
for i, (key, value) in enumerate(tabs.items()): for i, (key, value) in enumerate(tabs.items()):
...@@ -583,7 +583,7 @@ class coordinates: ...@@ -583,7 +583,7 @@ class coordinates:
for key, tab in tabs.items(): for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range( tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left" "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype("timedelta64[s]") ).asisinstance("timedelta64[s]")
if align: if align:
assert ( assert (
...@@ -667,7 +667,7 @@ class coordinates: ...@@ -667,7 +667,7 @@ class coordinates:
for key, tab in tabs.items(): for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range( tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left" "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype("timedelta64[s]") ).asisinstance("timedelta64[s]")
if propagate_labels: if propagate_labels:
for key, tab in tabs.items(): for key, tab in tabs.items():
...@@ -732,7 +732,7 @@ class coordinates: ...@@ -732,7 +732,7 @@ class coordinates:
for key, tab in tabs.items(): for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range( tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left" "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype("timedelta64[s]") ).asisinstance("timedelta64[s]")
if propagate_labels: if propagate_labels:
for key, tab in tabs.items(): for key, tab in tabs.items():
...@@ -833,7 +833,7 @@ class coordinates: ...@@ -833,7 +833,7 @@ class coordinates:
) )
pbar.update(1) pbar.update(1)
if type(video_output) == list: if isinstance(video_output, list):
vid_idxs = video_output vid_idxs = video_output
elif video_output == "all": elif video_output == "all":
vid_idxs = list(self._tables.keys()) vid_idxs = list(self._tables.keys())
......
...@@ -25,7 +25,7 @@ import warnings ...@@ -25,7 +25,7 @@ import warnings
warnings.filterwarnings("ignore", message="All-NaN slice encountered") warnings.filterwarnings("ignore", message="All-NaN slice encountered")
# Create custom string type # Create custom string type
Coordinates = NewType("Coordinates", Any) Coordinates = Newisinstance("Coordinates", Any)
def close_single_contact( def close_single_contact(
...@@ -53,12 +53,12 @@ def close_single_contact( ...@@ -53,12 +53,12 @@ def close_single_contact(
close_contact = None close_contact = None
if type(right) == str: if isinstance(right, str):
close_contact = ( close_contact = (
np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
) / arena_rel < tol ) / arena_rel < tol
elif type(right) == list: elif isinstance(right, list):
close_contact = np.any( close_contact = np.any(
[ [
(np.linalg.norm(pos_dframe[left] - pos_dframe[r], axis=1) * arena_abs) (np.linalg.norm(pos_dframe[left] - pos_dframe[r], axis=1) * arena_abs)
...@@ -528,7 +528,7 @@ def max_behaviour( ...@@ -528,7 +528,7 @@ def max_behaviour(
speeds = [col for col in behaviour_dframe.columns if "speed" in col.lower()] speeds = [col for col in behaviour_dframe.columns if "speed" in col.lower()]
behaviour_dframe = behaviour_dframe.drop(speeds, axis=1).astype("float") behaviour_dframe = behaviour_dframe.drop(speeds, axis=1).asisinstance("float")
win_array = behaviour_dframe.rolling(window_size, center=True).sum() win_array = behaviour_dframe.rolling(window_size, center=True).sum()
if stepped: if stepped:
win_array = win_array[::window_size] win_array = win_array[::window_size]
...@@ -678,8 +678,8 @@ def rule_based_tagging( ...@@ -678,8 +678,8 @@ def rule_based_tagging(
return deepof.utils.smooth_boolean_array( return deepof.utils.smooth_boolean_array(
close_single_contact( close_single_contact(
coords, coords,
(left if type(left) != list else right), (left if not isinstance(left, list) else right),
(right if type(left) != list else left), (right if not isinstance(left, list) else left),
params["close_contact_tol"], params["close_contact_tol"],
arena_abs, arena_abs,
arena[1][1], arena[1][1],
......
...@@ -26,7 +26,7 @@ from typing import Tuple, Any, List, Union, NewType ...@@ -26,7 +26,7 @@ from typing import Tuple, Any, List, Union, NewType
# DEFINE CUSTOM ANNOTATED TYPES # # DEFINE CUSTOM ANNOTATED TYPES #
Coordinates = NewType("Coordinates", Any) Coordinates = Newisinstance("Coordinates", Any)
# CONNECTIVITY FOR DLC MODELS # CONNECTIVITY FOR DLC MODELS
...@@ -750,7 +750,7 @@ def cluster_transition_matrix( ...@@ -750,7 +750,7 @@ def cluster_transition_matrix(
# Stores all possible transitions between clusters # Stores all possible transitions between clusters
clusters = [str(i) for i in range(nclusts)] clusters = [str(i) for i in range(nclusts)]
cluster_sequence = cluster_sequence.astype(str) cluster_sequence = cluster_sequence.asisinstance(str)
trans = {t: 0 for t in product(clusters, clusters)} trans = {t: 0 for t in product(clusters, clusters)}
k = len(clusters) k = len(clusters)
......
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%load_ext autoreload %load_ext autoreload
%autoreload 2 %autoreload 2
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import os import os
os.chdir(os.path.dirname("../")) os.chdir(os.path.dirname("../"))
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import cv2 import cv2
import deepof.data import deepof.data
import deepof.models import deepof.models
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import re import re
import seaborn as sns import seaborn as sns
from sklearn.preprocessing import StandardScaler, MinMaxScaler from sklearn.preprocessing import StandardScaler, MinMaxScaler
import tensorflow as tf import tensorflow as tf
import tqdm.notebook as tqdm import tqdm.notebook as tqdm
from ipywidgets import interact from ipywidgets import interact
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import umap import umap
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Retrieve phenotypes # Retrieve phenotypes
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
flatten = lambda t: [item for sublist in t for item in sublist] flatten = lambda t: [item for sublist in t for item in sublist]
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Load first batch # Load first batch
dset11 = pd.ExcelFile( dset11 = pd.ExcelFile(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_1/DLC_single_CDR1_1/1.Openfield_data-part1/JB05.1-OF-SI-part1.xlsx" "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_1/DLC_single_CDR1_1/1.Openfield_data-part1/JB05.1-OF-SI-part1.xlsx"
) )
dset12 = pd.ExcelFile( dset12 = pd.ExcelFile(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_1/DLC_single_CDR1_1/2.Openfielddata-part2/AnimalID's-JB05.1-part2.xlsx" "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_1/DLC_single_CDR1_1/2.Openfielddata-part2/AnimalID's-JB05.1-part2.xlsx"
) )
dset11 = pd.read_excel(dset11, "Tabelle2") dset11 = pd.read_excel(dset11, "Tabelle2")
dset12 = pd.read_excel(dset12, "Tabelle2") dset12 = pd.read_excel(dset12, "Tabelle2")
dset11.Test = dset11.Test.apply(lambda x: "Test {}_s11".format(x)) dset11.Test = dset11.Test.apply(lambda x: "Test {}_s11".format(x))
dset12.Test = dset12.Test.apply(lambda x: "Test {}_s12".format(x)) dset12.Test = dset12.Test.apply(lambda x: "Test {}_s12".format(x))
dset1 = {"CSDS":list(dset11.loc[dset11.Treatment.isin(["CTR+CSDS","NatCre+CSDS"]), "Test"]) + dset1 = {"CSDS":list(dset11.loc[dset11.Treatment.isin(["CTR+CSDS","NatCre+CSDS"]), "Test"]) +
list(dset12.loc[dset12.Treatment.isin(["CTR+CSDS","NatCre+CSDS"]), "Test"]), list(dset12.loc[dset12.Treatment.isin(["CTR+CSDS","NatCre+CSDS"]), "Test"]),
"NS": list(dset11.loc[dset11.Treatment.isin(["CTR+nonstressed","NatCre+nonstressed"]), "Test"]) + "NS": list(dset11.loc[dset11.Treatment.isin(["CTR+nonstressed","NatCre+nonstressed"]), "Test"]) +
list(dset12.loc[dset12.Treatment.isin(["CTR+nonstressed","NatCre+nonstressed"]), "Test"]),} list(dset12.loc[dset12.Treatment.isin(["CTR+nonstressed","NatCre+nonstressed"]), "Test"]),}
dset1inv = {} dset1inv = {}
for i in flatten(list(dset1.values())): for i in flatten(list(dset1.values())):
if i in dset1["CSDS"]: if i in dset1["CSDS"]:
dset1inv[i] = "CSDS" dset1inv[i] = "CSDS"
else: else:
dset1inv[i] = "NS" dset1inv[i] = "NS"
assert len(dset1inv) == dset11.shape[0] + dset12.shape[0], "You missed some labels!" assert len(dset1inv) == dset11.shape[0] + dset12.shape[0], "You missed some labels!"
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Load second batch # Load second batch
dset21 = pd.read_excel( dset21 = pd.read_excel(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part1/2_Single/stressproject22.04.2020genotypes-openfieldday1.xlsx" "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part1/2_Single/stressproject22.04.2020genotypes-openfieldday1.xlsx"
) )
dset22 = pd.read_excel( dset22 = pd.read_excel(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part2/2_Single/OpenFieldvideos-part2.xlsx" "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part2/2_Single/OpenFieldvideos-part2.xlsx"
) )
dset21.Test = dset21.Test.apply(lambda x: "Test {}_s21".format(x)) dset21.Test = dset21.Test.apply(lambda x: "Test {}_s21".format(x))
dset22.Test = dset22.Test.apply(lambda x: "Test {}_s22".format(x)) dset22.Test = dset22.Test.apply(lambda x: "Test {}_s22".format(x))
dset2 = {"CSDS":list(dset21.loc[dset21.Treatment == "Stress", "Test"]) + dset2 = {"CSDS":list(dset21.loc[dset21.Treatment == "Stress", "Test"]) +
list(dset22.loc[dset22.Treatment == "Stressed", "Test"]), list(dset22.loc[dset22.Treatment == "Stressed", "Test"]),
"NS": list(dset21.loc[dset21.Treatment == "Nonstressed", "Test"]) + "NS": list(dset21.loc[dset21.Treatment == "Nonstressed", "Test"]) +
list(dset22.loc[dset22.Treatment == "Nonstressed", "Test"])} list(dset22.loc[dset22.Treatment == "Nonstressed", "Test"])}
dset2inv = {} dset2inv = {}
for i in flatten(list(dset2.values())): for i in flatten(list(dset2.values())):
if i in dset2["CSDS"]: if i in dset2["CSDS"]:
dset2inv[i] = "CSDS" dset2inv[i] = "CSDS"
else: else:
dset2inv[i] = "NS" dset2inv[i] = "NS"
assert len(dset2inv) == dset21.shape[0] + dset22.shape[0], "You missed some labels!" assert len(dset2inv) == dset21.shape[0] + dset22.shape[0], "You missed some labels!"
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Load third batch # Load third batch
dset31 = pd.read_excel( dset31 = pd.read_excel(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/1.Day2OF-SIpart1/JB05 2Female-ELS-OF-SIpart1.xlsx", "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/1.Day2OF-SIpart1/JB05 2Female-ELS-OF-SIpart1.xlsx",
sheet_name=1 sheet_name=1
) )
dset32 = pd.read_excel( dset32 = pd.read_excel(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/2.Day3OF-SIpart2/JB05 2FEMALE-ELS-OF-SIpart2.xlsx", "../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/2.Day3OF-SIpart2/JB05 2FEMALE-ELS-OF-SIpart2.xlsx",
sheet_name=1 sheet_name=1
) )
dset31.Test = dset31.Test.apply(lambda x: "Test {}_s31".format(x)) dset31.Test = dset31.Test.apply(lambda x: "Test {}_s31".format(x))
dset32.Test = dset32.Test.apply(lambda x: "Test {}_s32".format(x)) dset32.Test = dset32.Test.apply(lambda x: "Test {}_s32".format(x))
dset3 = {"CSDS":[], dset3 = {"CSDS":[],
"NS": list(dset31.loc[:, "Test"]) + "NS": list(dset31.loc[:, "Test"]) +
list(dset32.loc[:, "Test"])} list(dset32.loc[:, "Test"])}
dset3inv = {} dset3inv = {}
for i in flatten(list(dset3.values())): for i in flatten(list(dset3.values())):
if i in dset3["CSDS"]: if i in dset3["CSDS"]:
dset3inv[i] = "CSDS" dset3inv[i] = "CSDS"
else: else:
dset3inv[i] = "NS" dset3inv[i] = "NS"
assert len(dset3inv) == dset31.shape[0] + dset32.shape[0], "You missed some labels!" assert len(dset3inv) == dset31.shape[0] + dset32.shape[0], "You missed some labels!"
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Load fourth batch # Load fourth batch
dset41 = os.listdir("../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_4/JB05.4-OpenFieldvideos/") dset41 = os.listdir("../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_4/JB05.4-OpenFieldvideos/")
# Remove empty video! # Remove empty video!
dset41 = [vid for vid in dset41 if "52" not in vid] dset41 = [vid for vid in dset41 if "52" not in vid]
dset4 = {"CSDS":[], dset4 = {"CSDS":[],
"NS": [i[:-4]+"_s41" for i in dset41]} "NS": [i[:-4]+"_s41" for i in dset41]}
dset4inv = {} dset4inv = {}
for i in flatten(list(dset4.values())): for i in flatten(list(dset4.values())):
if i in dset4["CSDS"]: if i in dset4["CSDS"]:
dset4inv[i] = "CSDS" dset4inv[i] = "CSDS"
else: else:
dset4inv[i] = "NS" dset4inv[i] = "NS"
assert len(dset4inv) == len(dset41), "You missed some labels!" assert len(dset4inv) == len(dset41), "You missed some labels!"
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Merge phenotype dicts and serialise! # Merge phenotype dicts and serialise!
aggregated_dset = {**dset1inv, **dset2inv, **dset3inv, **dset4inv} aggregated_dset = {**dset1inv, **dset2inv, **dset3inv, **dset4inv}
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from collections import Counter from collections import Counter
print(Counter(aggregated_dset.values())) print(Counter(aggregated_dset.values()))
print(115+52) print(115+52)
``` ```
%%%% Output: stream %%%% Output: stream
Counter({'NS': 115, 'CSDS': 52}) Counter({'NS': 115, 'CSDS': 52})
167 167
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Save aggregated dataset to disk # Save aggregated dataset to disk
import pickle import pickle
with open("../../Desktop/deepof-data/deepof_single_topview/deepof_exp_conditions.pkl", "wb") as handle: with open("../../Desktop/deepof-data/deepof_single_topview/deepof_exp_conditions.pkl", "wb") as handle:
pickle.dump(aggregated_dset, handle, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(aggregated_dset, handle, protocol=pickle.HIGHEST_PROTOCOL)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Define and run project # Define and run project
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%%time %%time
deepof_main = deepof.data.project(path=os.path.join("..","..","Desktop","deepoftesttemp"), deepof_main = deepof.data.project(path=os.path.join("..","..","Desktop","deepoftesttemp"),
smooth_alpha=0.99, smooth_alpha=0.99,
arena_dims=[380], arena_dims=[380],
exclude_bodyparts=["Tail_1", "Tail_2", "Tail_tip", "Tail_base"], exclude_bodyparts=["Tail_1", "Tail_2", "Tail_tip", "Tail_base"],
exp_conditions=aggregated_dset exp_conditions=aggregated_dset
) )
``` ```
%%%% Output: stream %%%% Output: stream
CPU times: user 111 ms, sys: 14 ms, total: 125 ms CPU times: user 111 ms, sys: 14 ms, total: 125 ms
Wall time: 123 ms Wall time: 123 ms
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%%time %%time
deepof_main = deepof_main.run(verbose=True) deepof_main = deepof_main.run(verbose=True)
print(deepof_main) print(deepof_main)
``` ```
%%%% Output: stream %%%% Output: stream
Loading trajectories... Loading trajectories...
Smoothing trajectories... Smoothing trajectories...
Interpolating outliers... Interpolating outliers...
Iterative imputation of ocluded bodyparts... Iterative imputation of ocluded bodyparts...
Computing distances... Computing distances...
Computing angles... Computing angles...
Done! Done!
Coordinates of 2 videos across 2 conditions Coordinates of 2 videos across 2 conditions
CPU times: user 4.8 s, sys: 806 ms, total: 5.61 s CPU times: user 4.8 s, sys: 806 ms, total: 5.61 s
Wall time: 4.32 s Wall time: 4.32 s
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python