From 7cf5b0b8771a1bfff0922aeabceb82772cf6499e Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 16 Sep 2020 01:40:41 +0200
Subject: [PATCH] Added tests for preprocess.py

---
 deepof/preprocess.py     | 32 ++++++++++++++++++++------------
 tests/test_preprocess.py | 23 +++++++++++++++++++++++
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/deepof/preprocess.py b/deepof/preprocess.py
index cff9d1b8..3d8c4c3d 100644
--- a/deepof/preprocess.py
+++ b/deepof/preprocess.py
@@ -31,7 +31,7 @@ class project:
         subset_condition=None,
         arena="circular",
         smooth_alpha=0.1,
-        arena_dims=[1],
+        arena_dims=(1,),
         distances="All",
         ego=False,
         angles=True,
@@ -60,10 +60,6 @@ class project:
         self.connectivity = connectivity
         self.scales = self.get_scale
 
-        # assert [re.findall("(.*)_", vid)[0] for vid in self.videos] == [
-        #     re.findall("(.*)\.", tab)[0] for tab in self.tables
-        # ], "Video files should match table files"
-
     def __str__(self):
         if self.exp_conditions:
             return "DLC analysis of {} videos across {} conditions".format(
@@ -72,7 +68,7 @@ class project:
         else:
             return "DLC analysis of {} videos".format(len(self.videos))
 
-    def load_tables(self, verbose):
+    def load_tables(self, verbose=False):
         """Loads videos and tables into dictionaries"""
 
         if self.table_format not in [".h5", ".csv"]:
@@ -86,17 +82,29 @@ class project:
         if self.table_format == ".h5":
             table_dict = {
                 re.findall("(.*?)_", tab)[0]: pd.read_hdf(
-                    self.table_path + tab, dtype=float
+                    os.path.join(self.table_path, tab), dtype=float
                 )
                 for tab in self.tables
             }
         elif self.table_format == ".csv":
-            table_dict = {
-                re.findall("(.*?)_", tab)[0]: pd.read_csv(
-                    self.table_path + tab, dtype=float
+            table_dict = {}
+            for tab in self.tables:
+                head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2)
+                data = pd.read_csv(
+                    os.path.join(self.table_path, tab),
+                    skiprows=2,
+                    index_col="coords",
+                    dtype={"coords": int},
+                ).drop("1", axis=1)
+                data.columns = pd.MultiIndex.from_product(
+                    [
+                        [head.columns[2]],
+                        set(list(head.iloc[0])[2:]),
+                        ["x", "y", "likelihood"],
+                    ],
+                    names=["scorer", "bodyparts", "coords"],
                 )
-                for tab in self.tables
-            }
+                table_dict[re.findall("(.*?)_", tab)[0]] = data
 
         lik_dict = defaultdict()
 
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index 0bdf4f9d..2fdb0d26 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -10,3 +10,26 @@ from scipy.spatial import distance
 from deepof.utils import *
 import deepof.preprocess
 import pytest
+
+
+@given(table_type=st.integers(min_value=0, max_value=2))
+def test_project_init(table_type):
+
+    table_type = [".h5", ".csv", ".foo"][table_type]
+
+    prun = deepof.preprocess.project(
+        path=os.path.join(".", "tests", "test_examples"),
+        arena="circular",
+        arena_dims=[380],
+        angles=False,
+        video_format=".mp4",
+        table_format=table_type,
+    )
+
+    if table_type != ".foo":
+        assert type(prun) == deepof.preprocess.project
+        assert type(prun.load_tables(verbose=True)) == tuple
+        print(prun)
+    else:
+        with pytest.raises(NotImplementedError):
+            prun.load_tables(verbose=True)
-- 
GitLab