Commit db3e8424 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for deepof.visuals

parent 15b9eb4d
...@@ -440,7 +440,7 @@ class coordinates: ...@@ -440,7 +440,7 @@ class coordinates:
for key, tab in tabs.items(): for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range( tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left" "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype('timedelta64[s]') ).astype("timedelta64[s]")
if align: if align:
assert ( assert (
...@@ -492,7 +492,7 @@ class coordinates: ...@@ -492,7 +492,7 @@ class coordinates:
for key, tab in tabs.items(): for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range( tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left" "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype('timedelta64[s]') ).astype("timedelta64[s]")
return table_dict(tabs, typ="dists") return table_dict(tabs, typ="dists")
...@@ -532,7 +532,7 @@ class coordinates: ...@@ -532,7 +532,7 @@ class coordinates:
for key, tab in tabs.items(): for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range( tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left" "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype('timedelta64[s]') ).astype("timedelta64[s]")
return table_dict(tabs, typ="angles") return table_dict(tabs, typ="angles")
...@@ -571,6 +571,7 @@ class coordinates: ...@@ -571,6 +571,7 @@ class coordinates:
return self._arena, self._arena_dims, self._scales return self._arena, self._arena_dims, self._scales
def rule_based_annotation(self): def rule_based_annotation(self):
"""Annotates coordinates using a simple rule-based pipeline"""
pass pass
...@@ -634,10 +635,12 @@ class table_dict(dict): ...@@ -634,10 +635,12 @@ class table_dict(dict):
else [0, self._arena_dims[i][1]] else [0, self._arena_dims[i][1]]
) )
plot_heatmap( heatmaps = plot_heatmap(
list(self.values())[i], bodyparts, xlim=x_lim, ylim=y_lim, save=save, list(self.values())[i], bodyparts, xlim=x_lim, ylim=y_lim, save=save,
) )
return heatmaps
def get_training_set(self, test_videos: int = 0) -> Tuple[np.ndarray, np.ndarray]: def get_training_set(self, test_videos: int = 0) -> Tuple[np.ndarray, np.ndarray]:
"""Generates training and test sets as numpy.array objects for model training""" """Generates training and test sets as numpy.array objects for model training"""
......
...@@ -13,42 +13,36 @@ import numpy as np ...@@ -13,42 +13,36 @@ import numpy as np
import pandas as pd import pandas as pd
import seaborn as sns import seaborn as sns
from itertools import cycle from itertools import cycle
from typing import List, Dict from typing import List
# PLOTTING FUNCTIONS # # PLOTTING FUNCTIONS #
def plot_speed(
behaviour_dict: Dict[str, pd.DataFrame], treatments: Dict[str, List]
) -> plt.figure:
"""Plots a histogram with the speed of the specified mouse.
Treatments is expected to be a list of lists with mice keys per treatment"""
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))
for Treatment, Mice_list in treatments.items():
hist = pd.concat([behaviour_dict[mouse] for mouse in Mice_list])
sns.kdeplot(hist["bspeed"], shade=True, label=Treatment, ax=ax1)
sns.kdeplot(hist["wspeed"], shade=True, label=Treatment, ax=ax2)
ax1.set_xlim(0, 7)
ax2.set_xlim(0, 7)
ax1.set_title("Average speed density for black mouse")
ax2.set_title("Average speed density for white mouse")
plt.xlabel("Average speed")
plt.ylabel("Density")
plt.show()
def plot_heatmap( def plot_heatmap(
dframe: pd.DataFrame, bodyparts: List, xlim: float, ylim: float, save: str = False dframe: pd.DataFrame,
bodyparts: List,
xlim: tuple,
ylim: tuple,
save: str = False,
dpi: int = 200,
) -> plt.figure: ) -> plt.figure:
"""Returns a heatmap of the movement of a specific bodypart in the arena. """Returns a heatmap of the movement of a specific bodypart in the arena.
If more than one bodypart is passed, it returns one subplot for each""" If more than one bodypart is passed, it returns one subplot for each
Parameters:
- dframe (pandas.DataFrame): table_dict value with info to plot
- bodyparts (List): bodyparts to represent (at least 1)
- xlim (float): limits of the x-axis
- ylim (float): limits of the y-axis
- save (str): name of the file to which the figure should be saved
- dpi (int): dots per inch of the returned image
Returns:
- heatmaps (plt.figure): figure with the specified characteristics"""
# noinspection PyTypeChecker # noinspection PyTypeChecker
fig, ax = plt.subplots(1, len(bodyparts), sharex=True, sharey=True) heatmaps, ax = plt.subplots(1, len(bodyparts), sharex=True, sharey=True, dpi=dpi)
for i, bpart in enumerate(bodyparts): for i, bpart in enumerate(bodyparts):
heatmap = dframe[bpart] heatmap = dframe[bpart]
...@@ -65,7 +59,7 @@ def plot_heatmap( ...@@ -65,7 +59,7 @@ def plot_heatmap(
if save: if save:
plt.savefig(save) plt.savefig(save)
plt.show() return heatmaps
def model_comparison_plot( def model_comparison_plot(
...@@ -73,17 +67,40 @@ def model_comparison_plot( ...@@ -73,17 +67,40 @@ def model_comparison_plot(
m_bic: list, m_bic: list,
n_components_range: range, n_components_range: range,
cov_plot: str, cov_plot: str,
save: str, save: str = False,
cv_types: tuple = ("spherical", "tied", "diag", "full"), cv_types: tuple = ("spherical", "tied", "diag", "full"),
dpi: int = 200,
) -> plt.figure: ) -> plt.figure:
"""Plots model comparison statistics over all tests""" """
Plots model comparison statistics for Gaussian Mixture Model analysis.
Similar to https://scikit-learn.org/stable/modules/mixture.html, it shows
an upper panel with BIC per number of components and covariance matrix type
in a bar plot, and a lower panel with box plots showing bootstrap runs of the
models corresponding to one of the covariance types.
Parameters:
- bic (list): list with BIC for all used models
- m_bic (list): list with minimum bic across cov matrices
for all used models
- n_components_range (range): range of components to evaluate
- cov_plot (str): covariance matrix to use in the lower panel
- save (str): name of the file to which the figure should be saved
- cv_types (tuple): tuple indicating which covariance matrix types
to use. All (spherical, tied, diag and full) used by default.
- dpi (int): dots per inch of the returned image
Returns:
- modelcomp (plt.figure): figure with all specified characteristics
"""
m_bic = np.array(m_bic) m_bic = np.array(m_bic)
color_iter = cycle(["navy", "turquoise", "cornflowerblue", "darkorange"]) color_iter = cycle(["navy", "turquoise", "cornflowerblue", "darkorange"])
bars = [] bars = []
# Plot the BIC scores # Plot the BIC scores
plt.figure(figsize=(12, 8)) modelcomp = plt.figure(dpi=dpi)
spl = plt.subplot(2, 1, 1) spl = plt.subplot(2, 1, 1)
covplot = np.repeat(cv_types, len(m_bic) / 4) covplot = np.repeat(cv_types, len(m_bic) / 4)
...@@ -115,9 +132,7 @@ def model_comparison_plot( ...@@ -115,9 +132,7 @@ def model_comparison_plot(
spl2.set_xlabel("Number of components") spl2.set_xlabel("Number of components")
spl2.set_ylabel("BIC value") spl2.set_ylabel("BIC value")
plt.tight_layout()
if save: if save:
plt.savefig(save) plt.savefig(save)
plt.show() return modelcomp
...@@ -14,4 +14,4 @@ from hypothesis import strategies as st ...@@ -14,4 +14,4 @@ from hypothesis import strategies as st
from collections import defaultdict from collections import defaultdict
from deepof.utils import * from deepof.utils import *
import deepof.preprocess import deepof.preprocess
import pytest import pytest
\ No newline at end of file
...@@ -32,7 +32,7 @@ def test_project_init(table_type, arena_type): ...@@ -32,7 +32,7 @@ def test_project_init(table_type, arena_type):
prun = deepof.preprocess.project( prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"), path=os.path.join(".", "tests", "test_examples"),
arena=arena_type, arena=arena_type,
arena_dims=[380], arena_dims=tuple([380]),
angles=False, angles=False,
video_format=".mp4", video_format=".mp4",
table_format=table_type, table_format=table_type,
...@@ -41,7 +41,7 @@ def test_project_init(table_type, arena_type): ...@@ -41,7 +41,7 @@ def test_project_init(table_type, arena_type):
prun = deepof.preprocess.project( prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"), path=os.path.join(".", "tests", "test_examples"),
arena=arena_type, arena=arena_type,
arena_dims=[380], arena_dims=tuple([380]),
angles=False, angles=False,
video_format=".mp4", video_format=".mp4",
table_format=table_type, table_format=table_type,
...@@ -72,7 +72,7 @@ def test_get_distances(nodes, ego): ...@@ -72,7 +72,7 @@ def test_get_distances(nodes, ego):
prun = deepof.preprocess.project( prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"), path=os.path.join(".", "tests", "test_examples"),
arena="circular", arena="circular",
arena_dims=[380], arena_dims=tuple([380]),
angles=False, angles=False,
video_format=".mp4", video_format=".mp4",
table_format=".h5", table_format=".h5",
...@@ -98,7 +98,7 @@ def test_get_angles(nodes, ego): ...@@ -98,7 +98,7 @@ def test_get_angles(nodes, ego):
prun = deepof.preprocess.project( prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"), path=os.path.join(".", "tests", "test_examples"),
arena="circular", arena="circular",
arena_dims=[380], arena_dims=tuple([380]),
video_format=".mp4", video_format=".mp4",
table_format=".h5", table_format=".h5",
distances=nodes, distances=nodes,
...@@ -123,7 +123,7 @@ def test_run(nodes, ego): ...@@ -123,7 +123,7 @@ def test_run(nodes, ego):
prun = deepof.preprocess.project( prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"), path=os.path.join(".", "tests", "test_examples"),
arena="circular", arena="circular",
arena_dims=[380], arena_dims=tuple([380]),
video_format=".mp4", video_format=".mp4",
table_format=".h5", table_format=".h5",
distances=nodes, distances=nodes,
...@@ -147,7 +147,7 @@ def test_get_table_dicts(nodes, ego, sampler): ...@@ -147,7 +147,7 @@ def test_get_table_dicts(nodes, ego, sampler):
prun = deepof.preprocess.project( prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"), path=os.path.join(".", "tests", "test_examples"),
arena="circular", arena="circular",
arena_dims=[380], arena_dims=tuple([380]),
video_format=".mp4", video_format=".mp4",
table_format=".h5", table_format=".h5",
distances=nodes, distances=nodes,
......
...@@ -11,7 +11,63 @@ Testing module for deepof.visuals ...@@ -11,7 +11,63 @@ Testing module for deepof.visuals
from hypothesis import given from hypothesis import given
from hypothesis import settings from hypothesis import settings
from hypothesis import strategies as st from hypothesis import strategies as st
from collections import defaultdict
from deepof.utils import * from deepof.utils import *
import deepof.preprocess import deepof.preprocess
import pytest import deepof.visuals
\ No newline at end of file import matplotlib.figure
def test_plot_heatmap():
prun = (
deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=tuple([380]),
angles=False,
video_format=".mp4",
table_format=".h5",
)
.run()
.get_coords()
)
assert (
type(
deepof.visuals.plot_heatmap(
prun["test"],
["Center"],
tuple([-100, 100]),
tuple([-100, 100]),
dpi=200,
)
)
== matplotlib.figure.Figure
)
def test_model_comparison_plot():
prun = (
deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=tuple([380]),
angles=False,
video_format=".mp4",
table_format=".h5",
)
.run()
.get_coords()
)
gmm_run = gmm_model_selection(
prun["test"], n_components_range=range(1, 3), n_runs=1, part_size=100
)
assert (
type(
deepof.visuals.model_comparison_plot(
gmm_run[0], gmm_run[1], range(1, 3), cov_plot="full"
)
)
== matplotlib.figure.Figure
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment