Commit 3c4fa0b5 authored by lucas_miranda's avatar lucas_miranda
Browse files

Integrated outlier interpolation to data.py

parent 70fb3adf
......@@ -55,6 +55,10 @@ class project:
arena_dims: tuple = (1,),
exclude_bodyparts: List = tuple([""]),
exp_conditions: dict = None,
interpolate_outliers: str = "MA",
interpolation_limit: int = 15,
interpolation_std: int = 2,
likelihood_tol: float = 0.9,
model: str = "mouse_topview",
path: str = deepof.utils.os.path.join("."),
smooth_alpha: float = 0.99,
......@@ -88,17 +92,21 @@ class project:
if tab.endswith(self.table_format) and not tab.startswith(".")
]
)
self.angles = True
self.animal_ids = animal_ids
self.arena = arena
self.arena_dims = arena_dims
self.distances = "all"
self.ego = False
self.exp_conditions = exp_conditions
self.interpolate_outliers = interpolate_outliers
self.interpolation_limit = interpolation_limit
self.interpolation_std = interpolation_std
self.likelihood_tolerance = likelihood_tol
self.scales = self.get_scale
self.smooth_alpha = smooth_alpha
self.video_format = video_format
self.angles = True
self.distances = "all"
self.ego = False
self.subset_condition = None
self.video_format = video_format
model_dict = {
"mouse_topview": deepof.utils.connect_mouse_topview(animal_ids[0])
......@@ -258,6 +266,21 @@ class project:
)
tab_dict[k] = temp.sort_index(axis=1)
if self.interpolate_outliers:
if verbose:
print("Interpolating outliers...")
for k, value in tab_dict.items():
tab_dict[k] = deepof.utils.interpolate_outliers(
value,
lik_dict[k],
likelihood_tolerance=self.likelihood_tolerance,
mode="or",
limit=self.interpolation_limit,
n_std=self.interpolation_std,
)
return tab_dict, lik_dict
def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
......
......@@ -430,7 +430,11 @@ def full_outlier_mask(
"""
body_parts = experiment.columns.levels[0]
full_mask = experiment.copy().drop(exclude, axis=1)
full_mask = experiment.copy()
if exclude:
full_mask.drop(exclude, axis=1, inplace=True)
for bpart in body_parts:
if bpart != exclude:
mask = mask_outliers(
......@@ -453,7 +457,7 @@ def interpolate_outliers(
experiment: pd.DataFrame,
likelihood: pd.DataFrame,
likelihood_tolerance: float,
exclude: str,
exclude: str = "",
lag: int = 5,
n_std: int = 3,
mode: str = "or",
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment