Commit def8d460 authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed bug on rule_based_annotation

parent 32be0d7b
Pipeline #83397 canceled with stage
in 4 minutes and 10 seconds
......@@ -56,7 +56,7 @@ class project:
path: str = deepof.utils.os.path.join("."),
exp_conditions: dict = None,
arena: str = "circular",
smooth_alpha: float = 0.1,
smooth_alpha: float = 1.0,
arena_dims: tuple = (1,),
model: str = "mouse_topview",
animal_ids: List = tuple([""]),
......@@ -84,7 +84,7 @@ class project:
self.video_format = video_format
self.arena = arena
self.arena_dims = arena_dims
self.smooth_alpha = smooth_alpha
self.smooth_alpha = 0.99
self.scales = self.get_scale
self.animal_ids = animal_ids
......@@ -652,7 +652,6 @@ class coordinates:
for key in tqdm(self._tables.keys()):
video = [vid for vid in self._videos if key + "DLC" in vid][0]
print(key, video)
tag_dict[key] = deepof.pose_utils.rule_based_tagging(
list(self._tables.keys()),
self._videos,
......
......@@ -538,7 +538,7 @@ def rule_based_tagging(
for _id in animal_ids:
tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
climb_wall(arena_type, arena, coords, 0, _id + undercond + "Nose")
climb_wall(arena_type, arena, coords, w / 100, _id + undercond + "Nose")
)
tag_dict[_id + undercond + "speed"] = speeds[_id + undercond + "Center"]
tag_dict[_id + undercond + "huddle"] = deepof.utils.smooth_boolean_array(
......
......@@ -264,8 +264,6 @@ def align_trajectories(data: np.array, mode: str = "all") -> np.array:
Returns:
- aligned_trajs (2D np.array): aligned positions over time"""
print(data.shape, mode)
angles = np.zeros(data.shape[0])
data = deepcopy(data)
dshape = data.shape
......@@ -326,7 +324,7 @@ def rolling_window(a: np.array, window_size: int, window_step: int) -> np.array:
return rolled_a
def smooth_mult_trajectory(series: np.array, alpha: float = 0.15) -> np.array:
def smooth_mult_trajectory(series: np.array, alpha: float = 0.99) -> np.array:
"""Returns a smooths a trajectory using exponentially weighted averages
Parameters:
......@@ -386,7 +384,7 @@ def recognize_arena(
# Detect arena and extract positions
arena = circular_arena_recognition(frame)[0]
if h is None and w is None:
h, w = frame.shape[0], frame.shape[1]
w, h = frame.shape[0], frame.shape[1]
fnum += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment