diff --git a/deepof/preprocess.py b/deepof/preprocess.py index cff9d1b8547eca8d92ae91344cbb111e855fd013..3d8c4c3de9cdb615fbc8c79124237f182858af73 100644 --- a/deepof/preprocess.py +++ b/deepof/preprocess.py @@ -31,7 +31,7 @@ class project: subset_condition=None, arena="circular", smooth_alpha=0.1, - arena_dims=[1], + arena_dims=(1,), distances="All", ego=False, angles=True, @@ -60,10 +60,6 @@ class project: self.connectivity = connectivity self.scales = self.get_scale - # assert [re.findall("(.*)_", vid)[0] for vid in self.videos] == [ - # re.findall("(.*)\.", tab)[0] for tab in self.tables - # ], "Video files should match table files" - def __str__(self): if self.exp_conditions: return "DLC analysis of {} videos across {} conditions".format( @@ -72,7 +68,7 @@ class project: else: return "DLC analysis of {} videos".format(len(self.videos)) - def load_tables(self, verbose): + def load_tables(self, verbose=False): """Loads videos and tables into dictionaries""" if self.table_format not in [".h5", ".csv"]: @@ -86,17 +82,29 @@ class project: if self.table_format == ".h5": table_dict = { re.findall("(.*?)_", tab)[0]: pd.read_hdf( - self.table_path + tab, dtype=float + os.path.join(self.table_path, tab), dtype=float ) for tab in self.tables } elif self.table_format == ".csv": - table_dict = { - re.findall("(.*?)_", tab)[0]: pd.read_csv( - self.table_path + tab, dtype=float + table_dict = {} + for tab in self.tables: + head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2) + data = pd.read_csv( + os.path.join(self.table_path, tab), + skiprows=2, + index_col="coords", + dtype={"coords": int}, + ).drop("1", axis=1) + data.columns = pd.MultiIndex.from_product( + [ + [head.columns[2]], + set(list(head.iloc[0])[2:]), + ["x", "y", "likelihood"], + ], + names=["scorer", "bodyparts", "coords"], ) - for tab in self.tables - } + table_dict[re.findall("(.*?)_", tab)[0]] = data lik_dict = defaultdict() diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 0bdf4f9d1d69d0b3061eb77ad0c4a437745c4387..2fdb0d26a5668b748c0e4ff224543c12023db260 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -10,3 +10,26 @@ from scipy.spatial import distance from deepof.utils import * import deepof.preprocess import pytest + + +@given(table_type=st.integers(min_value=0, max_value=2)) +def test_project_init(table_type): + + table_type = [".h5", ".csv", ".foo"][table_type] + + prun = deepof.preprocess.project( + path=os.path.join(".", "tests", "test_examples"), + arena="circular", + arena_dims=[380], + angles=False, + video_format=".mp4", + table_format=table_type, + ) + + if table_type != ".foo": + assert type(prun) == deepof.preprocess.project + assert type(prun.load_tables(verbose=True)) == tuple + print(prun) + else: + with pytest.raises(NotImplementedError): + prun.load_tables(verbose=True)