Commit 7cf5b0b8 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for preprocess.py

parent 1d2cf7cf
...@@ -31,7 +31,7 @@ class project: ...@@ -31,7 +31,7 @@ class project:
subset_condition=None, subset_condition=None,
arena="circular", arena="circular",
smooth_alpha=0.1, smooth_alpha=0.1,
arena_dims=[1], arena_dims=(1,),
distances="All", distances="All",
ego=False, ego=False,
angles=True, angles=True,
...@@ -60,10 +60,6 @@ class project: ...@@ -60,10 +60,6 @@ class project:
self.connectivity = connectivity self.connectivity = connectivity
self.scales = self.get_scale self.scales = self.get_scale
# assert [re.findall("(.*)_", vid)[0] for vid in self.videos] == [
# re.findall("(.*)\.", tab)[0] for tab in self.tables
# ], "Video files should match table files"
def __str__(self): def __str__(self):
if self.exp_conditions: if self.exp_conditions:
return "DLC analysis of {} videos across {} conditions".format( return "DLC analysis of {} videos across {} conditions".format(
...@@ -72,7 +68,7 @@ class project: ...@@ -72,7 +68,7 @@ class project:
else: else:
return "DLC analysis of {} videos".format(len(self.videos)) return "DLC analysis of {} videos".format(len(self.videos))
def load_tables(self, verbose): def load_tables(self, verbose=False):
"""Loads videos and tables into dictionaries""" """Loads videos and tables into dictionaries"""
if self.table_format not in [".h5", ".csv"]: if self.table_format not in [".h5", ".csv"]:
...@@ -86,17 +82,29 @@ class project: ...@@ -86,17 +82,29 @@ class project:
if self.table_format == ".h5": if self.table_format == ".h5":
table_dict = { table_dict = {
re.findall("(.*?)_", tab)[0]: pd.read_hdf( re.findall("(.*?)_", tab)[0]: pd.read_hdf(
self.table_path + tab, dtype=float os.path.join(self.table_path, tab), dtype=float
) )
for tab in self.tables for tab in self.tables
} }
elif self.table_format == ".csv": elif self.table_format == ".csv":
table_dict = { table_dict = {}
re.findall("(.*?)_", tab)[0]: pd.read_csv( for tab in self.tables:
self.table_path + tab, dtype=float head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2)
data = pd.read_csv(
os.path.join(self.table_path, tab),
skiprows=2,
index_col="coords",
dtype={"coords": int},
).drop("1", axis=1)
data.columns = pd.MultiIndex.from_product(
[
[head.columns[2]],
set(list(head.iloc[0])[2:]),
["x", "y", "likelihood"],
],
names=["scorer", "bodyparts", "coords"],
) )
for tab in self.tables table_dict[re.findall("(.*?)_", tab)[0]] = data
}
lik_dict = defaultdict() lik_dict = defaultdict()
......
...@@ -10,3 +10,26 @@ from scipy.spatial import distance ...@@ -10,3 +10,26 @@ from scipy.spatial import distance
from deepof.utils import * from deepof.utils import *
import deepof.preprocess import deepof.preprocess
import pytest import pytest
@given(table_type=st.integers(min_value=0, max_value=2))
def test_project_init(table_type):
table_type = [".h5", ".csv", ".foo"][table_type]
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
angles=False,
video_format=".mp4",
table_format=table_type,
)
if table_type != ".foo":
assert type(prun) == deepof.preprocess.project
assert type(prun.load_tables(verbose=True)) == tuple
print(prun)
else:
with pytest.raises(NotImplementedError):
prun.load_tables(verbose=True)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment