Commit 72ac0aac authored by lucas_miranda's avatar lucas_miranda
Browse files

Added nose2body to rule_based_annotation()

parent c1020aa1
Pipeline #95417 passed with stages
in 37 minutes and 21 seconds
......@@ -622,7 +622,7 @@ class coordinates:
center=center,
polar=polar,
propagate_labels=propagate_labels,
propagate_annotations=bool(propagate_annotations),
propagate_annotations=propagate_annotations,
)
def get_distances(
......@@ -678,7 +678,7 @@ class coordinates:
return table_dict(
tabs,
propagate_labels=propagate_labels,
propagate_annotations=bool(propagate_annotations),
propagate_annotations=propagate_annotations,
typ="dists",
)
......@@ -743,7 +743,7 @@ class coordinates:
return table_dict(
tabs,
propagate_labels=propagate_labels,
propagate_annotations=bool(propagate_annotations),
propagate_annotations=propagate_annotations,
typ="angles",
)
......@@ -953,7 +953,7 @@ class table_dict(dict):
center: str = None,
polar: bool = None,
propagate_labels: bool = False,
propagate_annotations: bool = False,
propagate_annotations: Dict = False,
):
super().__init__(tabs)
self._type = typ
......@@ -1016,6 +1016,7 @@ class table_dict(dict):
"""Generates training and test sets as numpy.array objects for model training"""
# Padding of videos with slightly different lengths
# Making sure that the training and test sets end up balanced
raw_data = np.array([np.array(v) for v in self.values()], dtype=object)
if self._propagate_labels:
concat_raw = np.concatenate(raw_data, axis=0)
......@@ -1047,12 +1048,25 @@ class table_dict(dict):
except TypeError:
pass
if self._propagate_annotations:
n_annot = list(self._propagate_annotations.values())[0].shape[1]
propagated_annots = X_train[:, -n_annot:]
X_train = X_train[:, :-n_annot]
try:
X_test = X_test[:, :-n_annot]
except TypeError:
pass
if encode_labels:
le = LabelEncoder()
y_train = le.fit_transform(y_train)
y_test = le.transform(y_test)
return X_train, y_train, X_test, y_test
try:
# noinspection PyUnboundLocalVariable
return X_train, y_train, X_test, y_test, propagated_annots
except NameError:
return X_train, y_train, X_test, y_test
# noinspection PyTypeChecker,PyGlobalUndefined
def preprocess(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment