From b975d989dafee16bc5fd969fa067a03358770a9c Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Thu, 17 Sep 2020 16:04:41 +0200
Subject: [PATCH] Added tests for deepof.visuals
---
deepof/preprocess.py | 3 ++-
deepof/utils.py | 2 +-
tests/test_preprocess.py | 5 +++++
tests/test_utils.py | 14 ++++++++++----
tests/test_visuals.py | 10 +++++++---
5 files changed, 25 insertions(+), 9 deletions(-)
diff --git a/deepof/preprocess.py b/deepof/preprocess.py
index 010437a6..34ba64b4 100644
--- a/deepof/preprocess.py
+++ b/deepof/preprocess.py
@@ -617,7 +617,8 @@ class table_dict(dict):
if self._type != "coords" or self._polar:
raise NotImplementedError(
- "Heatmaps only available for cartesian coordinates. Set polar to False in get_coordinates and try again"
+ "Heatmaps only available for cartesian coordinates. "
+ "Set polar to False in get_coordinates and try again"
)
if not self._center:
diff --git a/deepof/utils.py b/deepof/utils.py
index 4f699329..4c95e3d0 100644
--- a/deepof/utils.py
+++ b/deepof/utils.py
@@ -765,7 +765,7 @@ def single_behaviour_analysis(
if save is not None:
plt.savefig(save)
- return_list.append(ax)
+ return_list.append(fig)
if stat_tests:
stat_dict = {}
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index a3cdfc07..06eb3427 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -14,6 +14,7 @@ from hypothesis import strategies as st
from collections import defaultdict
from deepof.utils import *
import deepof.preprocess
+import matplotlib.figure
import pytest
@@ -208,9 +209,13 @@ def test_get_table_dicts(nodes, ego, sampler):
assert type(tset[0]) == np.ndarray
if table._type == "coords" and algn == "Nose" and polar is False and speed == 0:
+
+ assert type(table.plot_heatmaps(bodyparts=["Spine_1"])) == matplotlib.figure.Figure
+
align = sampler.draw(
st.one_of(st.just(False), st.just("all"), st.just("center"))
)
+
else:
align = False
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a4a2091b..fd9418ec 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -17,6 +17,7 @@ from hypothesis.extra.pandas import range_indexes, columns, data_frames
from scipy.spatial import distance
from deepof.utils import *
import deepof.preprocess
+import matplotlib.figure
import pytest
import string
@@ -25,6 +26,7 @@ import string
def autocorr(x, t=1):
+ """Computes autocorrelation of the given array with a lag of t"""
return np.round(np.corrcoef(np.array([x[:-t], x[t:]]))[0, 1], 5)
@@ -429,7 +431,7 @@ def test_climb_wall(arena, tol):
deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
- arena_dims=[arena[2]],
+ arena_dims=tuple([arena[2]]),
angles=False,
video_format=".mp4",
table_format=".h5",
@@ -636,18 +638,22 @@ def test_single_behaviour_analysis(sampler):
ylim = sampler.draw(st.floats(min_value=0, max_value=10))
stat_tests = sampler.draw(st.booleans())
+ plot = sampler.draw(st.integers(min_value=0, max_value=200))
+
out = single_behaviour_analysis(
behaviours[0],
treatment_dict,
behavioural_dict,
- plot=0,
+ plot=plot,
stat_tests=stat_tests,
save=None,
ylim=ylim,
)
- assert len(out) == 1 if stat_tests == 0 else len(out) == 2
+ assert len(out) == 1 if (stat_tests == 0 and plot == 0) else len(out) >= 2
assert type(out[0]) == dict
+ if plot:
+ assert np.any(np.array([type(i) for i in out]) == matplotlib.figure.Figure)
if stat_tests:
assert type(out[0]) == dict
@@ -768,7 +774,7 @@ def test_rule_based_tagging():
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
- arena_dims=[380],
+ arena_dims=tuple([380]),
angles=False,
video_format=".mp4",
table_format=".h5",
diff --git a/tests/test_visuals.py b/tests/test_visuals.py
index 85054dc1..446602e1 100644
--- a/tests/test_visuals.py
+++ b/tests/test_visuals.py
@@ -8,14 +8,18 @@ Testing module for deepof.visuals
"""
-
+from hypothesis import given
+from hypothesis import settings
+from hypothesis import strategies as st
from deepof.utils import *
import deepof.preprocess
import deepof.visuals
import matplotlib.figure
-def test_plot_heatmap():
+@settings(deadline=None)
+@given(bparts=st.one_of(st.just(["Center"]), st.just(["Center", "Nose"])))
+def test_plot_heatmap(bparts):
prun = (
deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
@@ -33,7 +37,7 @@ def test_plot_heatmap():
type(
deepof.visuals.plot_heatmap(
prun["test"],
- ["Center"],
+ bparts,
tuple([-100, 100]),
tuple([-100, 100]),
dpi=200,
--
GitLab