From 7cf5b0b8771a1bfff0922aeabceb82772cf6499e Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Wed, 16 Sep 2020 01:40:41 +0200 Subject: [PATCH] Added tests for preprocess.py --- deepof/preprocess.py | 32 ++++++++++++++++++++------------ tests/test_preprocess.py | 23 +++++++++++++++++++++++ 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/deepof/preprocess.py b/deepof/preprocess.py index cff9d1b8..3d8c4c3d 100644 --- a/deepof/preprocess.py +++ b/deepof/preprocess.py @@ -31,7 +31,7 @@ class project: subset_condition=None, arena="circular", smooth_alpha=0.1, - arena_dims=[1], + arena_dims=(1,), distances="All", ego=False, angles=True, @@ -60,10 +60,6 @@ class project: self.connectivity = connectivity self.scales = self.get_scale - # assert [re.findall("(.*)_", vid)[0] for vid in self.videos] == [ - # re.findall("(.*)\.", tab)[0] for tab in self.tables - # ], "Video files should match table files" - def __str__(self): if self.exp_conditions: return "DLC analysis of {} videos across {} conditions".format( @@ -72,7 +68,7 @@ class project: else: return "DLC analysis of {} videos".format(len(self.videos)) - def load_tables(self, verbose): + def load_tables(self, verbose=False): """Loads videos and tables into dictionaries""" if self.table_format not in [".h5", ".csv"]: @@ -86,17 +82,29 @@ class project: if self.table_format == ".h5": table_dict = { re.findall("(.*?)_", tab)[0]: pd.read_hdf( - self.table_path + tab, dtype=float + os.path.join(self.table_path, tab), dtype=float ) for tab in self.tables } elif self.table_format == ".csv": - table_dict = { - re.findall("(.*?)_", tab)[0]: pd.read_csv( - self.table_path + tab, dtype=float + table_dict = {} + for tab in self.tables: + head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2) + data = pd.read_csv( + os.path.join(self.table_path, tab), + skiprows=2, + index_col="coords", + dtype={"coords": int}, + ).drop("1", axis=1) + data.columns = pd.MultiIndex.from_product( + [ + [head.columns[2]], + set(list(head.iloc[0])[2:]), + ["x", "y", "likelihood"], + ], + names=["scorer", "bodyparts", "coords"], ) - for tab in self.tables - } + table_dict[re.findall("(.*?)_", tab)[0]] = data lik_dict = defaultdict() diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 0bdf4f9d..2fdb0d26 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -10,3 +10,26 @@ from scipy.spatial import distance from deepof.utils import * import deepof.preprocess import pytest + + +@given(table_type=st.integers(min_value=0, max_value=2)) +def test_project_init(table_type): + + table_type = [".h5", ".csv", ".foo"][table_type] + + prun = deepof.preprocess.project( + path=os.path.join(".", "tests", "test_examples"), + arena="circular", + arena_dims=[380], + angles=False, + video_format=".mp4", + table_format=table_type, + ) + + if table_type != ".foo": + assert type(prun) == deepof.preprocess.project + assert type(prun.load_tables(verbose=True)) == tuple + print(prun) + else: + with pytest.raises(NotImplementedError): + prun.load_tables(verbose=True) -- GitLab