From b975d989dafee16bc5fd969fa067a03358770a9c Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Thu, 17 Sep 2020 16:04:41 +0200 Subject: [PATCH] Added tests for deepof.visuals --- deepof/preprocess.py | 3 ++- deepof/utils.py | 2 +- tests/test_preprocess.py | 5 +++++ tests/test_utils.py | 14 ++++++++++---- tests/test_visuals.py | 10 +++++++--- 5 files changed, 25 insertions(+), 9 deletions(-) diff --git a/deepof/preprocess.py b/deepof/preprocess.py index 010437a6..34ba64b4 100644 --- a/deepof/preprocess.py +++ b/deepof/preprocess.py @@ -617,7 +617,8 @@ class table_dict(dict): if self._type != "coords" or self._polar: raise NotImplementedError( - "Heatmaps only available for cartesian coordinates. Set polar to False in get_coordinates and try again" + "Heatmaps only available for cartesian coordinates. " + "Set polar to False in get_coordinates and try again" ) if not self._center: diff --git a/deepof/utils.py b/deepof/utils.py index 4f699329..4c95e3d0 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -765,7 +765,7 @@ def single_behaviour_analysis( if save is not None: plt.savefig(save) - return_list.append(ax) + return_list.append(fig) if stat_tests: stat_dict = {} diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index a3cdfc07..06eb3427 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -14,6 +14,7 @@ from hypothesis import strategies as st from collections import defaultdict from deepof.utils import * import deepof.preprocess +import matplotlib.figure import pytest @@ -208,9 +209,13 @@ def test_get_table_dicts(nodes, ego, sampler): assert type(tset[0]) == np.ndarray if table._type == "coords" and algn == "Nose" and polar is False and speed == 0: + + assert type(table.plot_heatmaps(bodyparts=["Spine_1"])) == matplotlib.figure.Figure + align = sampler.draw( st.one_of(st.just(False), st.just("all"), st.just("center")) ) + else: align = False diff --git a/tests/test_utils.py b/tests/test_utils.py index a4a2091b..fd9418ec 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ from hypothesis.extra.pandas import range_indexes, columns, data_frames from scipy.spatial import distance from deepof.utils import * import deepof.preprocess +import matplotlib.figure import pytest import string @@ -25,6 +26,7 @@ import string def autocorr(x, t=1): + """Computes autocorrelation of the given array with a lag of t""" return np.round(np.corrcoef(np.array([x[:-t], x[t:]]))[0, 1], 5) @@ -429,7 +431,7 @@ def test_climb_wall(arena, tol): deepof.preprocess.project( path=os.path.join(".", "tests", "test_examples"), arena="circular", - arena_dims=[arena[2]], + arena_dims=tuple([arena[2]]), angles=False, video_format=".mp4", table_format=".h5", @@ -636,18 +638,22 @@ def test_single_behaviour_analysis(sampler): ylim = sampler.draw(st.floats(min_value=0, max_value=10)) stat_tests = sampler.draw(st.booleans()) + plot = sampler.draw(st.integers(min_value=0, max_value=200)) + out = single_behaviour_analysis( behaviours[0], treatment_dict, behavioural_dict, - plot=0, + plot=plot, stat_tests=stat_tests, save=None, ylim=ylim, ) - assert len(out) == 1 if stat_tests == 0 else len(out) == 2 + assert len(out) == 1 if (stat_tests == 0 and plot == 0) else len(out) >= 2 assert type(out[0]) == dict + if plot: + assert np.any(np.array([type(i) for i in out]) == matplotlib.figure.Figure) if stat_tests: assert type(out[0]) == dict @@ -768,7 +774,7 @@ def test_rule_based_tagging(): prun = deepof.preprocess.project( path=os.path.join(".", "tests", "test_examples"), arena="circular", - arena_dims=[380], + arena_dims=tuple([380]), angles=False, video_format=".mp4", table_format=".h5", diff --git a/tests/test_visuals.py b/tests/test_visuals.py index 85054dc1..446602e1 100644 --- a/tests/test_visuals.py +++ b/tests/test_visuals.py @@ -8,14 +8,18 @@ Testing module for deepof.visuals """ - +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st from deepof.utils import * import deepof.preprocess import deepof.visuals import matplotlib.figure -def test_plot_heatmap(): +@settings(deadline=None) +@given(bparts=st.one_of(st.just(["Center"]), st.just(["Center", "Nose"]))) +def test_plot_heatmap(bparts): prun = ( deepof.preprocess.project( path=os.path.join(".", "tests", "test_examples"), @@ -33,7 +37,7 @@ def test_plot_heatmap(): type( deepof.visuals.plot_heatmap( prun["test"], - ["Center"], + bparts, tuple([-100, 100]), tuple([-100, 100]), dpi=200, -- GitLab