Commit 2158e4b7 authored by lucas_miranda

Refactored rule_based_tagging and fixed type annotation problems

parent 1e83611f
@@ -12,6 +12,7 @@ import warnings
import networkx as nx
from deepof.utils import *
from deepof.visuals import *
class project:
@@ -16,7 +16,12 @@ from scipy import spatial
from scipy import stats
from sklearn import mixture
from tqdm import tqdm
-from typing import Tuple, Any, List, Union
+from typing import Tuple, Any, List, Union, Dict, NewType
+# DEFINE CUSTOM ANNOTATED TYPES #
+TableDict = NewType("TableDict", Any)
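For readers unfamiliar with NewType: at runtime the alias is just an identity function, so it only affects annotations. A minimal illustration (merge_tables and its body are hypothetical, not part of deepof):

from typing import Any, NewType

TableDict = NewType("TableDict", Any)

def merge_tables(tables: dict) -> TableDict:
    # NewType values behave exactly like their base type at runtime; the alias
    # only documents intent in signatures such as rule_based_tagging's below.
    return TableDict(tables)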
# QUALITY CONTROL AND PREPROCESSING #
@@ -281,7 +286,12 @@ def smooth_mult_trajectory(series: np.array, alpha: float = 0.15) -> np.array:
def close_single_contact(
-pos_dframe: pd.DataFrame, left: str, right: str, tol: float
+pos_dframe: pd.DataFrame,
+left: str,
+right: str,
+tol: float,
+arena_abs: int,
+arena_rel: int,
) -> np.array:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
@@ -290,13 +300,17 @@ def close_single_contact(
to two-animal experiments.
- left (string): First member of the potential contact
- right (string): Second member of the potential contact
-- tol (float)
+- tol (float): maximum distance for which a contact is reported
+- arena_abs (int): length in mm of the diameter of the real arena
+- arena_rel (int): length in pixels of the diameter of the arena in the video
Returns:
- contact_array (np.array): True if the distance between the two specified points
is less than tol, False otherwise"""
-close_contact = np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) < tol
+close_contact = (
+np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
+) / arena_rel < tol
return close_contact
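For context, a minimal sketch of the pixel-to-millimetre scaling that the new arena_abs / arena_rel arguments introduce (the DataFrame layout follows deepof's (body_part, coordinate) column convention; the numbers and the 380 mm / 500 px arena are illustrative):

import numpy as np
import pandas as pd

# Two body parts tracked over three frames
columns = pd.MultiIndex.from_product([["B_Nose", "W_Nose"], ["x", "y"]])
pos_dframe = pd.DataFrame(
    [[0, 0, 30, 40], [0, 0, 3, 4], [0, 0, 300, 400]], columns=columns
)

arena_abs, arena_rel = 380, 500  # assumed: 380 mm arena spanning 500 px in the video
pixel_dist = np.linalg.norm(pos_dframe["B_Nose"] - pos_dframe["W_Nose"], axis=1)
mm_dist = pixel_dist * arena_abs / arena_rel  # rescale px -> mm
print(mm_dist < 15.0)  # [False  True  False] for distances of 38, 3.8 and 380 mm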
@@ -308,6 +322,8 @@ def close_double_contact(
right1: str,
right2: str,
tol: float,
+arena_abs: int,
+arena_rel: int,
rev: bool = False,
) -> np.array:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
@@ -319,7 +335,10 @@ def close_double_contact(
- left2 (string): Second contact point of animal 1
- right1 (string): First contact point of animal 2
- right2 (string): Second contact point of animal 2
-- tol (float)
+- tol (float): maximum distance for which a contact is reported
+- arena_abs (int): length in mm of the diameter of the real arena
+- arena_rel (int): length in pixels of the diameter of the arena in the video
- rev (bool): reverses the default behaviour (nose2tail contact for both mice)
Returns:
- double_contact (np.array): True if the distance between the two specified points
@@ -327,13 +346,25 @@ def close_double_contact(
if rev:
double_contact = (
-np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) < tol
-) & (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) < tol)
+(np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
+/ arena_rel
+< tol
+) & (
+(np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
+/ arena_rel
+< tol
+)
else:
double_contact = (
-np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) < tol
-) & (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) < tol)
+(np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
+/ arena_rel
+< tol
+) & (
+(np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
+/ arena_rel
+< tol
+)
return double_contact
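The only effect of rev is which body parts are paired before the same mm-converted threshold is applied. A schematic pair of calls, reusing the conventions above and assuming a DataFrame with the four body parts tracked (animal ids "B" and "W" and all values are hypothetical):

# rev=False pairs right1-left1 and right2-left2:
#   W_Nose ~ B_Nose and W_Tail_base ~ B_Tail_base -> parallel ("sidebyside")
side_by_side = close_double_contact(
    pos_dframe, "B_Nose", "B_Tail_base", "W_Nose", "W_Tail_base",
    15.0, arena_abs=380, arena_rel=500, rev=False,
)
# rev=True pairs right1-left2 and right2-left1:
#   W_Nose ~ B_Tail_base and W_Tail_base ~ B_Nose -> antiparallel ("sidereside")
side_rev_side = close_double_contact(
    pos_dframe, "B_Nose", "B_Tail_base", "W_Nose", "W_Tail_base",
    15.0, arena_abs=380, arena_rel=500, rev=True,
)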
@@ -344,7 +375,7 @@ def recognize_arena(
path: str = ".",
recoglimit: int = 1,
arena_type: str = "circular",
-) -> np.array:
+) -> Tuple[np.array, int, int]:
"""Returns numpy.array with information about the arena recognised from the first frames
of the video. WARNING: estimates won't be reliable if the camera moves during the video.
@@ -357,7 +388,9 @@ def recognize_arena(
Returns:
- arena (np.array): 1D-array containing information about the arena.
"circular" (3-element-array) -> x-y position of the center and the radius"""
"circular" (3-element-array) -> x-y position of the center and the radius
- h (int): height of the video in pixels
- w (int): width of the video in pixels"""
cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
@@ -380,7 +413,7 @@ def recognize_arena(
fnum += 1
-return arena
+return arena, h, w
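With the extended return value, callers can unpack the detected geometry and frame size in one call; a usage sketch (the file name and the 380 mm real diameter are placeholders):

arena, h, w = recognize_arena(
    ["Test_video.mp4"], 0, path=".", recoglimit=1, arena_type="circular"
)
center_x, center_y, radius_px = arena    # circular arena: centre and radius, in pixels
arena_abs = 380                          # real arena diameter in mm (assumed)
mm_per_px = arena_abs / (2 * radius_px)  # scale factor, if arena_abs is a diameter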
def circular_arena_recognition(frame: np.array) -> np.array:
@@ -493,7 +526,9 @@ def rolling_speed(
return speeds
-def huddle(pos_dframe: pd.DataFrame, tol_forward: float, tol_spine: float) -> np.array:
+def huddle(
+pos_dframe: pd.DataFrame, tol_forward: float, tol_spine: float, tol_speed: float
+) -> np.array:
"""Returns true when the mouse is huddling using simple rules. (!!!) Designed to
work with deepof's default DLC mice models; not guaranteed to work otherwise.
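The body of huddle is outside this hunk; purely as an illustration of what "simple rules" gated by the new tol_speed argument could look like (every body part name and rule below is a guess, not the committed implementation):

# Illustrative only -- NOT deepof's actual huddle logic:
forward = np.linalg.norm(pos_dframe["Nose"] - pos_dframe["Center"], axis=1)
spine = np.linalg.norm(pos_dframe["Center"] - pos_dframe["Tail_base"], axis=1)
speed = rolling_speed(pos_dframe["Center"], window=5)
huddling = (forward < tol_forward) & (spine < tol_spine) & (speed < tol_speed)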
@@ -660,312 +695,6 @@ def single_behaviour_analysis(
return return_list
# MAIN BEHAVIOUR TAGGING FUNCTION #
def tag_video(
Tracks,
Videos,
Track_dict,
Distance_dict,
Like_QC_dict,
vid_index,
show=False,
save=False,
fps=25.0,
speedpause=50,
framelimit=np.inf,
recoglimit=1,
path="./",
classifiers={},
):
"""Outputs a dataframe with the motives registered per frame. If mp4==True, outputs a video in mp4 format"""
vid_name = re.findall("(.*?)_", Tracks[vid_index])[0]
cap = cv2.VideoCapture(path + Videos[vid_index])
dframe = Track_dict[vid_name]
h, w = None, None
bspeed, wspeed = None, None
# Dictionary with motifs per frame
tagdict = {
func: np.zeros(dframe.shape[0])
for func in [
"nose2nose",
"bnose2tail",
"wnose2tail",
"sidebyside",
"sidereside",
"bclimbwall",
"wclimbwall",
"bspeed",
"wspeed",
"bhuddle",
"whuddle",
"bfollowing",
"wfollowing",
]
}
# Keep track of the frame number, to align with the tracking data
fnum = 0
if save:
writer = None
# Loop over the first frames in the video to get resolution and center of the arena
while cap.isOpened() and fnum < recoglimit:
ret, frame = cap.read()
# if frame is read correctly ret is True
if not ret:
print("Can't receive frame (stream end?). Exiting ...")
break
# Detect arena and extract positions
arena = circular_arena_recognition(frame)[0]
if h == None and w == None:
h, w = frame.shape[0], frame.shape[1]
fnum += 1
# Define behaviours that can be computed on the fly from the distance matrix
tagdict["nose2nose"] = smooth_boolean_array(
Distance_dict[vid_name][("B_Nose", "W_Nose")] < 15
)
tagdict["bnose2tail"] = smooth_boolean_array(
Distance_dict[vid_name][("B_Nose", "W_Tail_base")] < 15
)
tagdict["wnose2tail"] = smooth_boolean_array(
Distance_dict[vid_name][("B_Tail_base", "W_Nose")] < 15
)
tagdict["sidebyside"] = smooth_boolean_array(
(Distance_dict[vid_name][("B_Nose", "W_Nose")] < 40)
& (Distance_dict[vid_name][("B_Tail_base", "W_Tail_base")] < 40)
)
tagdict["sidereside"] = smooth_boolean_array(
(Distance_dict[vid_name][("B_Nose", "W_Tail_base")] < 40)
& (Distance_dict[vid_name][("B_Tail_base", "W_Nose")] < 40)
)
B_mouse_X = np.array(
Distance_dict[vid_name][
[j for j in Distance_dict[vid_name].keys() if "B_" in j[0] and "B_" in j[1]]
]
)
W_mouse_X = np.array(
Distance_dict[vid_name][
[j for j in Distance_dict[vid_name].keys() if "W_" in j[0] and "W_" in j[1]]
]
)
tagdict["bhuddle"] = smooth_boolean_array(classifiers["huddle"].predict(B_mouse_X))
tagdict["whuddle"] = smooth_boolean_array(classifiers["huddle"].predict(W_mouse_X))
tagdict["bclimbwall"] = smooth_boolean_array(
pd.Series(
(
spatial.distance.cdist(
np.array(dframe["B_Nose"]), np.array([arena[:2]])
)
> (w / 200 + arena[2])
).reshape(dframe.shape[0]),
index=dframe.index,
)
)
tagdict["wclimbwall"] = smooth_boolean_array(
pd.Series(
(
spatial.distance.cdist(
np.array(dframe["W_Nose"]), np.array([arena[:2]])
)
> (w / 200 + arena[2])
).reshape(dframe.shape[0]),
index=dframe.index,
)
)
tagdict["bfollowing"] = smooth_boolean_array(
following_path(
Distance_dict[vid_name],
dframe,
follower="B",
followed="W",
frames=20,
tol=20,
)
)
tagdict["wfollowing"] = smooth_boolean_array(
following_path(
Distance_dict[vid_name],
dframe,
follower="W",
followed="B",
frames=20,
tol=20,
)
)
# Compute speed on a rolling window
tagdict["bspeed"] = rolling_speed(dframe["B_Center"], window=speedpause)
tagdict["wspeed"] = rolling_speed(dframe["W_Center"], window=speedpause)
if any([show, save]):
# Loop over the frames in the video
pbar = tqdm(total=min(dframe.shape[0] - recoglimit, framelimit))
while cap.isOpened() and fnum < framelimit:
ret, frame = cap.read()
# if frame is read correctly ret is True
if not ret:
print("Can't receive frame (stream end?). Exiting ...")
break
font = cv2.FONT_HERSHEY_COMPLEX_SMALL
if Like_QC_dict[vid_name][fnum]:
# Extract positions
pos_dict = {
i: np.array([dframe[i]["x"][fnum], dframe[i]["y"][fnum]])
for i in dframe.columns.levels[0]
if i != "Like_QC"
}
if h == None and w == None:
h, w = frame.shape[0], frame.shape[1]
# Label positions
downleft = (int(w * 0.3 / 10), int(h / 1.05))
downright = (int(w * 6.5 / 10), int(h / 1.05))
upleft = (int(w * 0.3 / 10), int(h / 20))
upright = (int(w * 6.3 / 10), int(h / 20))
# Display all annotations in the output video
if tagdict["nose2nose"][fnum] and not tagdict["sidebyside"][fnum]:
cv2.putText(
frame,
"Nose-Nose",
(downleft if bspeed > wspeed else downright),
font,
1,
(255, 255, 255),
2,
)
if tagdict["bnose2tail"][fnum] and not tagdict["sidereside"][fnum]:
cv2.putText(
frame, "Nose-Tail", downleft, font, 1, (255, 255, 255), 2
)
if tagdict["wnose2tail"][fnum] and not tagdict["sidereside"][fnum]:
cv2.putText(
frame, "Nose-Tail", downright, font, 1, (255, 255, 255), 2
)
if tagdict["sidebyside"][fnum]:
cv2.putText(
frame,
"Side-side",
(downleft if bspeed > wspeed else downright),
font,
1,
(255, 255, 255),
2,
)
if tagdict["sidereside"][fnum]:
cv2.putText(
frame,
"Side-Rside",
(downleft if bspeed > wspeed else downright),
font,
1,
(255, 255, 255),
2,
)
if tagdict["bclimbwall"][fnum]:
cv2.putText(
frame, "Climbing", downleft, font, 1, (255, 255, 255), 2
)
if tagdict["wclimbwall"][fnum]:
cv2.putText(
frame, "Climbing", downright, font, 1, (255, 255, 255), 2
)
if tagdict["bhuddle"][fnum] and not tagdict["bclimbwall"][fnum]:
cv2.putText(frame, "huddle", downleft, font, 1, (255, 255, 255), 2)
if tagdict["whuddle"][fnum] and not tagdict["wclimbwall"][fnum]:
cv2.putText(frame, "huddle", downright, font, 1, (255, 255, 255), 2)
if tagdict["bfollowing"][fnum] and not tagdict["bclimbwall"][fnum]:
cv2.putText(
frame,
"*f",
(int(w * 0.3 / 10), int(h / 10)),
font,
1,
((150, 150, 255) if wspeed > bspeed else (150, 255, 150)),
2,
)
if tagdict["wfollowing"][fnum] and not tagdict["wclimbwall"][fnum]:
cv2.putText(
frame,
"*f",
(int(w * 6.3 / 10), int(h / 10)),
font,
1,
((150, 150, 255) if wspeed < bspeed else (150, 255, 150)),
2,
)
if (bspeed == None and wspeed == None) or fnum % speedpause == 0:
bspeed = tagdict["bspeed"][fnum]
wspeed = tagdict["wspeed"][fnum]
cv2.putText(
frame,
"W: " + str(np.round(wspeed, 2)) + " mmpf",
(upright[0] - 20, upright[1]),
font,
1,
((150, 150, 255) if wspeed < bspeed else (150, 255, 150)),
2,
)
cv2.putText(
frame,
"B: " + str(np.round(bspeed, 2)) + " mmpf",
upleft,
font,
1,
((150, 150, 255) if bspeed < wspeed else (150, 255, 150)),
2,
)
if show:
cv2.imshow("frame", frame)
if save:
if writer is None:
# Define the codec and create a VideoWriter object; the output is stored
# as '<video_name>_tagged.avi'. FPS and frame size are also passed.
writer = cv2.VideoWriter()
writer.open(
re.findall("(.*?)_", Tracks[vid_index])[0] + "_tagged.avi",
cv2.VideoWriter_fourcc(*"MJPG"),
fps,
(frame.shape[1], frame.shape[0]),
True,
)
writer.write(frame)
if cv2.waitKey(1) == ord("q"):
break
pbar.update(1)
fnum += 1
cap.release()
cv2.destroyAllWindows()
tagdf = pd.DataFrame(tagdict)
return tagdf, arena
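tag_video above is removed in this commit in favour of rule_based_tagging, defined further down. A hypothetical call to the new entry point, assuming it mirrors tag_video's (dataframe, arena) return and that a trained huddle classifier is available:

tag_df, arena = rule_based_tagging(
    tracks, videos, table_dict, vid_index=0,
    arena_abs=380,  # real arena diameter in mm; 380 is a placeholder
    animal_ids=["B", "W"], show=False, save=False,
    speed_pause=50, recog_limit=1, path="./",
    arena_type="circular",
    classifiers={"huddle": trained_huddle_clf},  # trained_huddle_clf is hypothetical
)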
def max_behaviour(
behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
) -> np.array:
@@ -1136,5 +865,349 @@ def cluster_transition_matrix(
return trans_normed
# MAIN BEHAVIOUR TAGGING FUNCTION #
def rule_based_tagging(
tracks: List,
videos: List,
table_dict: TableDict,
vid_index: int,
arena_abs: int,
animal_ids: List = None,
show: bool = False,
save: bool = False,
fps: float = 25.0,
speed_pause: int = 50,
frame_limit: float = np.inf,
recog_limit: int = 1,
path: str = os.path.join("./"),
arena_type: str = "circular",
classifiers: Dict = None,
) -> Tuple[pd.DataFrame, Any]:
"""Outputs a dataframe with the motives registered per frame. If mp4==True, outputs a video in mp4 format"""
# noinspection PyProtectedMember
assert table_dict._type == "merged", (
"Table_dict must be of merged type, "
"and contain at least position, speed and distance information"
)
vid_name = re.findall("(.*?)_", tracks[vid_index])[0]
dframe = table_dict[vid_name]
arena, h, w = recognize_arena(videos, vid_index, path, recog_limit, arena_type)
# Dictionary with motifs per frame
behavioural_tags = []
if animal_ids:
behavioural_tags.append(["nose2nose", "sidebyside", "sidereside"])
for _id in animal_ids:
for behaviour in [
"_nose2tail",
"_climbing",
"_huddle",
"_following",
"_speed",
]:
behavioural_tags.append(_id + behaviour)
else:
behavioural_tags.append(["huddle", "climbing", "speed"])
tag_dict = {tag: np.zeros(dframe.shape[0]) for tag in behavioural_tags}
if animal_ids:
# Define behaviours that can be computed on the fly from the distance matrix
tag_dict["nose2nose"] = smooth_boolean_array(
close_single_contact(
dframe,
animal_ids[0] + "_Nose",
animal_ids[1] + "_Nose",
15.0,
arena_abs,
arena[2],
)
)
tag_dict[animal_ids[0] + "_nose2tail"] = smooth_boolean_array(
close_single_contact(
dframe,
animal_ids[0] + "_Nose",
animal_ids[1] + "_Tail_base",
15.0,
arena_abs,
arena[2],
)
)
tag_dict[animal_ids[1] + "_nose2tail"] = smooth_boolean_array(
close_single_contact(
dframe,
animal_ids[1] + "_Nose",
animal_ids[0] + "_Tail_base",
15.0,
arena_abs,
arena[2],
)
)
tag_dict["sidebyside"] = smooth_boolean_array(
close_double_contact(
dframe,
animal_ids[0] + "_Nose",
animal_ids[0] + "_Tail_base",
animal_ids[1] + "_Nose",
animal_ids[1] + "_Tail_base",
15.0,
rev=False,
arena_abs=arena_abs,
arena_rel=arena[2],
)
)
tag_dict["sidereside"] = smooth_boolean_array(
close_double_contact(
dframe,
animal_ids[0] + "_Nose",
animal_ids[0] + "_Tail_base",
animal_ids[1] + "_Nose",
animal_ids[1] + "_Tail_base",
15.0,
rev=True,
arena_abs=arena_abs,
arena_rel=arena[2],
)
)
for _id in animal_ids:
tag_dict[_id + "_following"] = smooth_boolean_array(
following_path(
dframe[vid_name],
dframe,
follower=_id,
followed=[i for i in animal_ids if i != _id][0],
frames=20,
tol=20,
)
)
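# Climbing: True when the nose lies farther from the arena centre than
# the detected radius (arena[2]) plus a w / 200 pixel margin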
tag_dict[_id + "_climbwall"] = smooth_boolean_array(
pd.Series(
(
spatial.distance.cdist(
np.array(dframe[_id + "_Nose"]), np.array([arena[:2]])
)
> (w / 200 + arena[2])
).reshape(dframe.shape[0]),
index=dframe.index,
)
)
tag_dict[_id + "speed"] = rolling_speed(
dframe[_id + "_Center"], window=speed_pause
)
else:
tag_dict["climbwall"] = smooth_boolean_array(
pd.Series(
(
spatial.distance.cdist(
np.array(dframe["Nose"]), np.array([arena[:2]])
)
> (w / 200 + arena[2])
).reshape(dframe.shape[0]),
index=dframe.index,
)
)
tag_dict["speed"] = rolling_speed(dframe["Center"], window=speed_pause)
if "huddle" in classifiers:
mouse_X = {
_id: np.array(
dframe[vid_name][
[
j
for j in dframe[vid_name].keys()
if (len(j) == 2 and _id in j[0] and _id in j[1])
]
]