Commit 617966ac authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed a bug on simple data projectors

parent 521ff2b9
Pipeline #92774 passed with stage
in 21 minutes and 55 seconds
......@@ -20,6 +20,7 @@ from typing import Dict, List, Tuple, Union
from multiprocessing import cpu_count
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.impute import SimpleImputer
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from tqdm import tqdm
......@@ -1049,6 +1050,7 @@ class table_dict(dict):
# noinspection PyUnresolvedReferences
X = X[np.random.choice(X.shape[0], sample, replace=False), :]
X = SimpleImputer(strategy="median").fit_transform(X)
rproj = random_projection.GaussianRandomProjection(n_components=n_components)
X = rproj.fit_transform(X)
......@@ -1063,6 +1065,7 @@ class table_dict(dict):
performance or visualization reasons"""
X = self.get_training_set()[0]
X = SimpleImputer(strategy="median").fit_transform(X)
# Takes care of propagated labels if present
if self._propagate_labels:
......@@ -1084,6 +1087,7 @@ class table_dict(dict):
performance or visualization reasons"""
X = self.get_training_set()[0]
X = SimpleImputer(strategy="median").fit_transform(X)
# Takes care of propagated labels if present
if self._propagate_labels:
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -137,23 +137,16 @@ def test_climb_wall(center, axes, angle, tol):
"y4",
"X5",
"y5",
"X6",
"y6",
"X7",
"y7",
"X8",
"y8",
],
dtype=float,
elements=st.floats(min_value=-20, max_value=20),
),
),
tol_forward=st.floats(min_value=0.01, max_value=4.98),
tol_spine=st.floats(min_value=0.01, max_value=4.98),
tol_speed=st.floats(min_value=0.01, max_value=4.98),
animal_id=st.text(min_size=0, max_size=15, alphabet=string.ascii_lowercase),
)
def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id):
def test_huddle(pos_dframe, tol_forward, tol_speed, animal_id):
_id = animal_id
if animal_id != "":
......@@ -162,14 +155,11 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id):
idx = pd.MultiIndex.from_product(
[
[
_id + "Left_ear",
_id + "Right_ear",
_id + "Left_bhip",
_id + "Right_bhip",
_id + "Left_fhip",
_id + "Right_fhip",
_id + "Spine_1",
_id + "Center",
_id + "Spine_2",
_id + "Tail_base",
],
["X", "y"],
],
......@@ -180,7 +170,6 @@ def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed, animal_id):
pos_dframe,
pos_dframe.xs("X", level="coords", axis=1, drop_level=True),
tol_forward,
tol_spine,
tol_speed,
animal_id,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment