Commit 8a165ed4 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent c2f7d805
Pipeline #98294 failed with stages
in 22 minutes and 54 seconds
......@@ -1167,13 +1167,16 @@ class table_dict(dict):
when calling get_coords) axis with the y-axis of the cartesian plane. If 'center', rotates all instances
using the angle of the central frame of the sliding window. This way rotations of the animal are caught
as well. It doesn't do anything if False.
- propagate_labels (bool): If True, returns a label vector acompaigning each training instance
Returns:
- X_train (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
generated from all training videos
- X_test (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
generated from all test videos (if test_videos > 0)
- y_train (np.ndarray): 2d dataset with a shape dependent in the type of labels the model uses
(phenotypes, rule-based tags).
- y_test (np.ndarray): 2d dataset with a shape dependent in the type of labels the model uses
(phenotypes, rule-based tags).
"""
......@@ -1216,8 +1219,9 @@ class table_dict(dict):
X_train = deepof.utils.align_trajectories(X_train, align)
X_train = deepof.utils.rolling_window(X_train, window_size, window_step)
if self._propagate_labels:
y_train = y_train[::window_step][: X_train.shape[0]]
if self._propagate_labels or self._propagate_annotations:
y_train = deepof.utils.rolling_window(y_train, window_size, window_step)
y_train = y_train.mean(axis=1)
if align == "center":
X_train = deepof.utils.align_trajectories(X_train, align)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment