Commit 23168afd authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented iterative imputing for ocluded body parts

parent be629717
Pipeline #93119 canceled with stage
in 40 minutes and 40 seconds
......@@ -20,7 +20,8 @@ from typing import Dict, List, Tuple, Union
from multiprocessing import cpu_count
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.impute import SimpleImputer
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from tqdm import tqdm
......@@ -56,6 +57,7 @@ class project:
animal_ids: List = tuple([""]),
arena: str = "circular",
arena_dims: tuple = (1,),
enable_iterative_imputation: bool = None,
exclude_bodyparts: List = tuple([""]),
exp_conditions: dict = None,
interpolate_outliers: bool = True,
......@@ -110,6 +112,10 @@ class project:
self.smooth_alpha = smooth_alpha
self.subset_condition = None
self.video_format = video_format
if enable_iterative_imputation is None:
self.enable_iterative_imputation = self.animal_ids == tuple([""])
else:
self.enable_iterative_imputation = enable_iterative_imputation
model_dict = {
"mouse_topview": deepof.utils.connect_mouse_topview(animal_ids[0])
......@@ -284,6 +290,17 @@ class project:
n_std=self.interpolation_std,
)
if self.enable_iterative_imputation:
if verbose:
print("Iterative imputation of ocluded bodyparts...")
for k, value in tab_dict.items():
imputed = IterativeImputer(max_iter=250, skip_complete=True).fit_transform(value)
tab_dict[k] = pd.DataFrame(
imputed, index=value.index, columns=value.columns
)
return tab_dict, lik_dict
def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
......@@ -1133,7 +1150,7 @@ class table_dict(dict):
# noinspection PyUnresolvedReferences
X = X[np.random.choice(X.shape[0], sample, replace=False), :]
X = SimpleImputer(strategy="median").fit_transform(X)
X = IterativeImputer().fit_transform(X)
rproj = random_projection.GaussianRandomProjection(n_components=n_components)
X = rproj.fit_transform(X)
......@@ -1148,7 +1165,7 @@ class table_dict(dict):
performance or visualization reasons"""
X = self.get_training_set()[0]
X = SimpleImputer(strategy="median").fit_transform(X)
X = IterativeImputer().fit_transform(X)
# Takes care of propagated labels if present
if self._propagate_labels:
......@@ -1170,7 +1187,7 @@ class table_dict(dict):
performance or visualization reasons"""
X = self.get_training_set()[0]
X = SimpleImputer(strategy="median").fit_transform(X)
X = IterativeImputer().fit_transform(X)
# Takes care of propagated labels if present
if self._propagate_labels:
......
......@@ -81,7 +81,7 @@ def test_get_callbacks(
assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
@settings(max_examples=2, deadline=None)
@settings(max_examples=10, deadline=None)
@given(
X_train=arrays(
dtype=float,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment