Skip to content
Snippets Groups Projects
Commit 617966ac authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Fixed a bug on simple data projectors

parent 521ff2b9
No related branches found
No related tags found
No related merge requests found
Pipeline #92774 passed
...@@ -20,6 +20,7 @@ from typing import Dict, List, Tuple, Union ...@@ -20,6 +20,7 @@ from typing import Dict, List, Tuple, Union
from multiprocessing import cpu_count from multiprocessing import cpu_count
from sklearn import random_projection from sklearn import random_projection
from sklearn.decomposition import KernelPCA from sklearn.decomposition import KernelPCA
from sklearn.impute import SimpleImputer
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from tqdm import tqdm from tqdm import tqdm
...@@ -1049,6 +1050,7 @@ class table_dict(dict): ...@@ -1049,6 +1050,7 @@ class table_dict(dict):
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
X = X[np.random.choice(X.shape[0], sample, replace=False), :] X = X[np.random.choice(X.shape[0], sample, replace=False), :]
X = SimpleImputer(strategy="median").fit_transform(X)
rproj = random_projection.GaussianRandomProjection(n_components=n_components) rproj = random_projection.GaussianRandomProjection(n_components=n_components)
X = rproj.fit_transform(X) X = rproj.fit_transform(X)
...@@ -1063,6 +1065,7 @@ class table_dict(dict): ...@@ -1063,6 +1065,7 @@ class table_dict(dict):
performance or visualization reasons""" performance or visualization reasons"""
X = self.get_training_set()[0] X = self.get_training_set()[0]
X = SimpleImputer(strategy="median").fit_transform(X)
# Takes care of propagated labels if present # Takes care of propagated labels if present
if self._propagate_labels: if self._propagate_labels:
...@@ -1084,6 +1087,7 @@ class table_dict(dict): ...@@ -1084,6 +1087,7 @@ class table_dict(dict):
performance or visualization reasons""" performance or visualization reasons"""
X = self.get_training_set()[0] X = self.get_training_set()[0]
X = SimpleImputer(strategy="median").fit_transform(X)
# Takes care of propagated labels if present # Takes care of propagated labels if present
if self._propagate_labels: if self._propagate_labels:
......
Source diff could not be displayed: it is too large. Options to address this: view the blob.
...@@ -137,23 +137,16 @@ def test_climb_wall(center, axes, angle, tol): ...@@ -137,23 +137,16 @@ def test_climb_wall(center, axes, angle, tol):
"y4", "y4",
"X5", "X5",
"y5", "y5",
"X6",
"y6",
"X7",
"y7",
"X8",
"y8",
], ],
dtype=float, dtype=float,
elements=st.floats(min_value=-20, max_value=20), elements=st.floats(min_value=-20, max_value=20),
), ),
), ),
tol_forward=st.floats(min_value=0.01, max_value=4.98), tol_forward=st.floats(min_value=0.01, max_value=4.98),
tol_spine=st.floats(min_value=0.01, max_value=4.98),
tol_speed=st.floats(min_value=0.01, max_value=4.98), tol_speed=st.floats(min_value=0.01, max_value=4.98),
animal_id=st.text(min_size=0, max_size=15, alphabet=string.ascii_lowercase), animal_id=st.text(min_size=0, max_size=15, alphabet=string.ascii_lowercase),
) )
def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id): def test_huddle(pos_dframe, tol_forward, tol_speed, animal_id):
_id = animal_id _id = animal_id
if animal_id != "": if animal_id != "":
...@@ -162,14 +155,11 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id): ...@@ -162,14 +155,11 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id):
idx = pd.MultiIndex.from_product( idx = pd.MultiIndex.from_product(
[ [
[ [
_id + "Left_ear", _id + "Left_bhip",
_id + "Right_ear", _id + "Right_bhip",
_id + "Left_fhip", _id + "Left_fhip",
_id + "Right_fhip", _id + "Right_fhip",
_id + "Spine_1",
_id + "Center", _id + "Center",
_id + "Spine_2",
_id + "Tail_base",
], ],
["X", "y"], ["X", "y"],
], ],
...@@ -180,7 +170,6 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id): ...@@ -180,7 +170,6 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id):
pos_dframe, pos_dframe,
pos_dframe.xs("X", level="coords", axis=1, drop_level=True), pos_dframe.xs("X", level="coords", axis=1, drop_level=True),
tol_forward, tol_forward,
tol_spine,
tol_speed, tol_speed,
animal_id, animal_id,
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment