diff --git a/deepof/preprocess.py b/deepof/preprocess.py index 010437a685dabf17b170e6c920863de8328d2ac5..34ba64b4744b13ac90a2cdb97e4ca27af78c45a0 100644 --- a/deepof/preprocess.py +++ b/deepof/preprocess.py @@ -617,7 +617,8 @@ class table_dict(dict): if self._type != "coords" or self._polar: raise NotImplementedError( - "Heatmaps only available for cartesian coordinates. Set polar to False in get_coordinates and try again" + "Heatmaps only available for cartesian coordinates. " + "Set polar to False in get_coordinates and try again" ) if not self._center: diff --git a/deepof/utils.py b/deepof/utils.py index 4f699329d3fa32776d2fe3d25e6507aaf81ea837..4c95e3d0fa07baf230162592766d7ba8e8fe3089 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -765,7 +765,7 @@ def single_behaviour_analysis( if save is not None: plt.savefig(save) - return_list.append(ax) + return_list.append(fig) if stat_tests: stat_dict = {} diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index a3cdfc07ce58f52d9a11e708139561ac833303fb..06eb3427f63fd10829b226329e6520c0090c9a07 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -14,6 +14,7 @@ from hypothesis import strategies as st from collections import defaultdict from deepof.utils import * import deepof.preprocess +import matplotlib.figure import pytest @@ -208,9 +209,13 @@ def test_get_table_dicts(nodes, ego, sampler): assert type(tset[0]) == np.ndarray if table._type == "coords" and algn == "Nose" and polar is False and speed == 0: + + assert type(table.plot_heatmaps(bodyparts=["Spine_1"])) == matplotlib.figure.Figure + align = sampler.draw( st.one_of(st.just(False), st.just("all"), st.just("center")) ) + else: align = False diff --git a/tests/test_utils.py b/tests/test_utils.py index a4a2091be06f7091644b481e76e894d89dc9c9d1..fd9418ec2ed3ada5ca45cf7c47116230a88154fa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ from hypothesis.extra.pandas import range_indexes, columns, data_frames from scipy.spatial import distance from deepof.utils import * import deepof.preprocess +import matplotlib.figure import pytest import string @@ -25,6 +26,7 @@ import string def autocorr(x, t=1): + """Computes autocorrelation of the given array with a lag of t""" return np.round(np.corrcoef(np.array([x[:-t], x[t:]]))[0, 1], 5) @@ -429,7 +431,7 @@ def test_climb_wall(arena, tol): deepof.preprocess.project( path=os.path.join(".", "tests", "test_examples"), arena="circular", - arena_dims=[arena[2]], + arena_dims=tuple([arena[2]]), angles=False, video_format=".mp4", table_format=".h5", @@ -636,18 +638,22 @@ def test_single_behaviour_analysis(sampler): ylim = sampler.draw(st.floats(min_value=0, max_value=10)) stat_tests = sampler.draw(st.booleans()) + plot = sampler.draw(st.integers(min_value=0, max_value=200)) + out = single_behaviour_analysis( behaviours[0], treatment_dict, behavioural_dict, - plot=0, + plot=plot, stat_tests=stat_tests, save=None, ylim=ylim, ) - assert len(out) == 1 if stat_tests == 0 else len(out) == 2 + assert len(out) == 1 if (stat_tests == 0 and plot == 0) else len(out) >= 2 assert type(out[0]) == dict + if plot: + assert np.any(np.array([type(i) for i in out]) == matplotlib.figure.Figure) if stat_tests: assert type(out[0]) == dict @@ -768,7 +774,7 @@ def test_rule_based_tagging(): prun = deepof.preprocess.project( path=os.path.join(".", "tests", "test_examples"), arena="circular", - arena_dims=[380], + arena_dims=tuple([380]), angles=False, video_format=".mp4", table_format=".h5", diff --git a/tests/test_visuals.py b/tests/test_visuals.py index 85054dc13a28489605ad8cbab5a8da8719bafe9a..446602e1a4251855f2d839c610f9636c2781c0cf 100644 --- a/tests/test_visuals.py +++ b/tests/test_visuals.py @@ -8,14 +8,18 @@ Testing module for deepof.visuals """ - +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st from deepof.utils import * import deepof.preprocess import deepof.visuals import matplotlib.figure -def test_plot_heatmap(): +@settings(deadline=None) +@given(bparts=st.one_of(st.just(["Center"]), st.just(["Center", "Nose"]))) +def test_plot_heatmap(bparts): prun = ( deepof.preprocess.project( path=os.path.join(".", "tests", "test_examples"), @@ -33,7 +37,7 @@ def test_plot_heatmap(): type( deepof.visuals.plot_heatmap( prun["test"], - ["Center"], + bparts, tuple([-100, 100]), tuple([-100, 100]), dpi=200,