Commit 7cf5b0b8 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for preprocess.py

parent 1d2cf7cf
......@@ -31,7 +31,7 @@ class project:
subset_condition=None,
arena="circular",
smooth_alpha=0.1,
arena_dims=[1],
arena_dims=(1,),
distances="All",
ego=False,
angles=True,
......@@ -60,10 +60,6 @@ class project:
self.connectivity = connectivity
self.scales = self.get_scale
# assert [re.findall("(.*)_", vid)[0] for vid in self.videos] == [
# re.findall("(.*)\.", tab)[0] for tab in self.tables
# ], "Video files should match table files"
def __str__(self):
if self.exp_conditions:
return "DLC analysis of {} videos across {} conditions".format(
......@@ -72,7 +68,7 @@ class project:
else:
return "DLC analysis of {} videos".format(len(self.videos))
def load_tables(self, verbose):
def load_tables(self, verbose=False):
"""Loads videos and tables into dictionaries"""
if self.table_format not in [".h5", ".csv"]:
......@@ -86,17 +82,29 @@ class project:
if self.table_format == ".h5":
table_dict = {
re.findall("(.*?)_", tab)[0]: pd.read_hdf(
self.table_path + tab, dtype=float
os.path.join(self.table_path, tab), dtype=float
)
for tab in self.tables
}
elif self.table_format == ".csv":
table_dict = {
re.findall("(.*?)_", tab)[0]: pd.read_csv(
self.table_path + tab, dtype=float
table_dict = {}
for tab in self.tables:
head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2)
data = pd.read_csv(
os.path.join(self.table_path, tab),
skiprows=2,
index_col="coords",
dtype={"coords": int},
).drop("1", axis=1)
data.columns = pd.MultiIndex.from_product(
[
[head.columns[2]],
set(list(head.iloc[0])[2:]),
["x", "y", "likelihood"],
],
names=["scorer", "bodyparts", "coords"],
)
for tab in self.tables
}
table_dict[re.findall("(.*?)_", tab)[0]] = data
lik_dict = defaultdict()
......
......@@ -10,3 +10,26 @@ from scipy.spatial import distance
from deepof.utils import *
import deepof.preprocess
import pytest
@given(table_type=st.integers(min_value=0, max_value=2))
def test_project_init(table_type):
table_type = [".h5", ".csv", ".foo"][table_type]
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
angles=False,
video_format=".mp4",
table_format=table_type,
)
if table_type != ".foo":
assert type(prun) == deepof.preprocess.project
assert type(prun.load_tables(verbose=True)) == tuple
print(prun)
else:
with pytest.raises(NotImplementedError):
prun.load_tables(verbose=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment