From b975d989dafee16bc5fd969fa067a03358770a9c Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Thu, 17 Sep 2020 16:04:41 +0200
Subject: [PATCH] Added tests for deepof.visuals

---
 deepof/preprocess.py     |  3 ++-
 deepof/utils.py          |  2 +-
 tests/test_preprocess.py |  5 +++++
 tests/test_utils.py      | 14 ++++++++++----
 tests/test_visuals.py    | 10 +++++++---
 5 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/deepof/preprocess.py b/deepof/preprocess.py
index 010437a6..34ba64b4 100644
--- a/deepof/preprocess.py
+++ b/deepof/preprocess.py
@@ -617,7 +617,8 @@ class table_dict(dict):
 
         if self._type != "coords" or self._polar:
             raise NotImplementedError(
-                "Heatmaps only available for cartesian coordinates. Set polar to False in get_coordinates and try again"
+                "Heatmaps only available for cartesian coordinates. "
+                "Set polar to False in get_coordinates and try again"
             )
 
         if not self._center:
diff --git a/deepof/utils.py b/deepof/utils.py
index 4f699329..4c95e3d0 100644
--- a/deepof/utils.py
+++ b/deepof/utils.py
@@ -765,7 +765,7 @@ def single_behaviour_analysis(
         if save is not None:
             plt.savefig(save)
 
-        return_list.append(ax)
+        return_list.append(fig)
 
     if stat_tests:
         stat_dict = {}
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index a3cdfc07..06eb3427 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -14,6 +14,7 @@ from hypothesis import strategies as st
 from collections import defaultdict
 from deepof.utils import *
 import deepof.preprocess
+import matplotlib.figure
 import pytest
 
 
@@ -208,9 +209,13 @@ def test_get_table_dicts(nodes, ego, sampler):
     assert type(tset[0]) == np.ndarray
 
     if table._type == "coords" and algn == "Nose" and polar is False and speed == 0:
+
+        assert type(table.plot_heatmaps(bodyparts=["Spine_1"])) == matplotlib.figure.Figure
+
         align = sampler.draw(
             st.one_of(st.just(False), st.just("all"), st.just("center"))
         )
+
     else:
         align = False
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a4a2091b..fd9418ec 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -17,6 +17,7 @@ from hypothesis.extra.pandas import range_indexes, columns, data_frames
 from scipy.spatial import distance
 from deepof.utils import *
 import deepof.preprocess
+import matplotlib.figure
 import pytest
 import string
 
@@ -25,6 +26,7 @@ import string
 
 
 def autocorr(x, t=1):
+    """Computes autocorrelation of the given array with a lag of t"""
     return np.round(np.corrcoef(np.array([x[:-t], x[t:]]))[0, 1], 5)
 
 
@@ -429,7 +431,7 @@ def test_climb_wall(arena, tol):
         deepof.preprocess.project(
             path=os.path.join(".", "tests", "test_examples"),
             arena="circular",
-            arena_dims=[arena[2]],
+            arena_dims=tuple([arena[2]]),
             angles=False,
             video_format=".mp4",
             table_format=".h5",
@@ -636,18 +638,22 @@ def test_single_behaviour_analysis(sampler):
     ylim = sampler.draw(st.floats(min_value=0, max_value=10))
     stat_tests = sampler.draw(st.booleans())
 
+    plot = sampler.draw(st.integers(min_value=0, max_value=200))
+
     out = single_behaviour_analysis(
         behaviours[0],
         treatment_dict,
         behavioural_dict,
-        plot=0,
+        plot=plot,
         stat_tests=stat_tests,
         save=None,
         ylim=ylim,
     )
 
-    assert len(out) == 1 if stat_tests == 0 else len(out) == 2
+    assert len(out) == 1 if (stat_tests == 0 and plot == 0) else len(out) >= 2
     assert type(out[0]) == dict
+    if plot:
+        assert np.any(np.array([type(i) for i in out]) == matplotlib.figure.Figure)
     if stat_tests:
         assert type(out[0]) == dict
 
@@ -768,7 +774,7 @@ def test_rule_based_tagging():
     prun = deepof.preprocess.project(
         path=os.path.join(".", "tests", "test_examples"),
         arena="circular",
-        arena_dims=[380],
+        arena_dims=tuple([380]),
         angles=False,
         video_format=".mp4",
         table_format=".h5",
diff --git a/tests/test_visuals.py b/tests/test_visuals.py
index 85054dc1..446602e1 100644
--- a/tests/test_visuals.py
+++ b/tests/test_visuals.py
@@ -8,14 +8,18 @@ Testing module for deepof.visuals
 
 """
 
-
+from hypothesis import given
+from hypothesis import settings
+from hypothesis import strategies as st
 from deepof.utils import *
 import deepof.preprocess
 import deepof.visuals
 import matplotlib.figure
 
 
-def test_plot_heatmap():
+@settings(deadline=None)
+@given(bparts=st.one_of(st.just(["Center"]), st.just(["Center", "Nose"])))
+def test_plot_heatmap(bparts):
     prun = (
         deepof.preprocess.project(
             path=os.path.join(".", "tests", "test_examples"),
@@ -33,7 +37,7 @@ def test_plot_heatmap():
         type(
             deepof.visuals.plot_heatmap(
                 prun["test"],
-                ["Center"],
+                bparts,
                 tuple([-100, 100]),
                 tuple([-100, 100]),
                 dpi=200,
-- 
GitLab