Skip to content
Snippets Groups Projects
Commit 7cf5b0b8 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Added tests for preprocess.py

parent 1d2cf7cf
No related branches found
No related tags found
No related merge requests found
......@@ -31,7 +31,7 @@ class project:
subset_condition=None,
arena="circular",
smooth_alpha=0.1,
arena_dims=[1],
arena_dims=(1,),
distances="All",
ego=False,
angles=True,
......@@ -60,10 +60,6 @@ class project:
self.connectivity = connectivity
self.scales = self.get_scale
# assert [re.findall("(.*)_", vid)[0] for vid in self.videos] == [
# re.findall("(.*)\.", tab)[0] for tab in self.tables
# ], "Video files should match table files"
def __str__(self):
if self.exp_conditions:
return "DLC analysis of {} videos across {} conditions".format(
......@@ -72,7 +68,7 @@ class project:
else:
return "DLC analysis of {} videos".format(len(self.videos))
def load_tables(self, verbose):
def load_tables(self, verbose=False):
"""Loads videos and tables into dictionaries"""
if self.table_format not in [".h5", ".csv"]:
......@@ -86,17 +82,29 @@ class project:
if self.table_format == ".h5":
table_dict = {
re.findall("(.*?)_", tab)[0]: pd.read_hdf(
self.table_path + tab, dtype=float
os.path.join(self.table_path, tab), dtype=float
)
for tab in self.tables
}
elif self.table_format == ".csv":
table_dict = {
re.findall("(.*?)_", tab)[0]: pd.read_csv(
self.table_path + tab, dtype=float
table_dict = {}
for tab in self.tables:
head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2)
data = pd.read_csv(
os.path.join(self.table_path, tab),
skiprows=2,
index_col="coords",
dtype={"coords": int},
).drop("1", axis=1)
data.columns = pd.MultiIndex.from_product(
[
[head.columns[2]],
set(list(head.iloc[0])[2:]),
["x", "y", "likelihood"],
],
names=["scorer", "bodyparts", "coords"],
)
for tab in self.tables
}
table_dict[re.findall("(.*?)_", tab)[0]] = data
lik_dict = defaultdict()
......
......@@ -10,3 +10,26 @@ from scipy.spatial import distance
from deepof.utils import *
import deepof.preprocess
import pytest
@given(table_type=st.integers(min_value=0, max_value=2))
def test_project_init(table_type):
table_type = [".h5", ".csv", ".foo"][table_type]
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
angles=False,
video_format=".mp4",
table_format=table_type,
)
if table_type != ".foo":
assert type(prun) == deepof.preprocess.project
assert type(prun.load_tables(verbose=True)) == tuple
print(prun)
else:
with pytest.raises(NotImplementedError):
prun.load_tables(verbose=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment