Commit 617966ac authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed a bug on simple data projectors

parent 521ff2b9
Pipeline #92774 passed with stage
in 21 minutes and 55 seconds
...@@ -20,6 +20,7 @@ from typing import Dict, List, Tuple, Union ...@@ -20,6 +20,7 @@ from typing import Dict, List, Tuple, Union
from multiprocessing import cpu_count from multiprocessing import cpu_count
from sklearn import random_projection from sklearn import random_projection
from sklearn.decomposition import KernelPCA from sklearn.decomposition import KernelPCA
from sklearn.impute import SimpleImputer
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from tqdm import tqdm from tqdm import tqdm
...@@ -1049,6 +1050,7 @@ class table_dict(dict): ...@@ -1049,6 +1050,7 @@ class table_dict(dict):
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
X = X[np.random.choice(X.shape[0], sample, replace=False), :] X = X[np.random.choice(X.shape[0], sample, replace=False), :]
X = SimpleImputer(strategy="median").fit_transform(X)
rproj = random_projection.GaussianRandomProjection(n_components=n_components) rproj = random_projection.GaussianRandomProjection(n_components=n_components)
X = rproj.fit_transform(X) X = rproj.fit_transform(X)
...@@ -1063,6 +1065,7 @@ class table_dict(dict): ...@@ -1063,6 +1065,7 @@ class table_dict(dict):
performance or visualization reasons""" performance or visualization reasons"""
X = self.get_training_set()[0] X = self.get_training_set()[0]
X = SimpleImputer(strategy="median").fit_transform(X)
# Takes care of propagated labels if present # Takes care of propagated labels if present
if self._propagate_labels: if self._propagate_labels:
...@@ -1084,6 +1087,7 @@ class table_dict(dict): ...@@ -1084,6 +1087,7 @@ class table_dict(dict):
performance or visualization reasons""" performance or visualization reasons"""
X = self.get_training_set()[0] X = self.get_training_set()[0]
X = SimpleImputer(strategy="median").fit_transform(X)
# Takes care of propagated labels if present # Takes care of propagated labels if present
if self._propagate_labels: if self._propagate_labels:
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -137,23 +137,16 @@ def test_climb_wall(center, axes, angle, tol): ...@@ -137,23 +137,16 @@ def test_climb_wall(center, axes, angle, tol):
"y4", "y4",
"X5", "X5",
"y5", "y5",
"X6",
"y6",
"X7",
"y7",
"X8",
"y8",
], ],
dtype=float, dtype=float,
elements=st.floats(min_value=-20, max_value=20), elements=st.floats(min_value=-20, max_value=20),
), ),
), ),
tol_forward=st.floats(min_value=0.01, max_value=4.98), tol_forward=st.floats(min_value=0.01, max_value=4.98),
tol_spine=st.floats(min_value=0.01, max_value=4.98),
tol_speed=st.floats(min_value=0.01, max_value=4.98), tol_speed=st.floats(min_value=0.01, max_value=4.98),
animal_id=st.text(min_size=0, max_size=15, alphabet=string.ascii_lowercase), animal_id=st.text(min_size=0, max_size=15, alphabet=string.ascii_lowercase),
) )
def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id): def test_huddle(pos_dframe, tol_forward, tol_speed, animal_id):
_id = animal_id _id = animal_id
if animal_id != "": if animal_id != "":
...@@ -162,14 +155,11 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id): ...@@ -162,14 +155,11 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id):
idx = pd.MultiIndex.from_product( idx = pd.MultiIndex.from_product(
[ [
[ [
_id + "Left_ear", _id + "Left_bhip",
_id + "Right_ear", _id + "Right_bhip",
_id + "Left_fhip", _id + "Left_fhip",
_id + "Right_fhip", _id + "Right_fhip",
_id + "Spine_1",
_id + "Center", _id + "Center",
_id + "Spine_2",
_id + "Tail_base",
], ],
["X", "y"], ["X", "y"],
], ],
...@@ -180,7 +170,6 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id): ...@@ -180,7 +170,6 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id):
pos_dframe, pos_dframe,
pos_dframe.xs("X", level="coords", axis=1, drop_level=True), pos_dframe.xs("X", level="coords", axis=1, drop_level=True),
tol_forward, tol_forward,
tol_spine,
tol_speed, tol_speed,
animal_id, animal_id,
) )
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment