Commit db3e8424 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for deepof.visuals

parent 15b9eb4d
......@@ -440,7 +440,7 @@ class coordinates:
for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype('timedelta64[s]')
).astype("timedelta64[s]")
if align:
assert (
......@@ -492,7 +492,7 @@ class coordinates:
for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype('timedelta64[s]')
).astype("timedelta64[s]")
return table_dict(tabs, typ="dists")
......@@ -532,7 +532,7 @@ class coordinates:
for key, tab in tabs.items():
tabs[key].index = pd.timedelta_range(
"00:00:00", length, periods=tab.shape[0] + 1, closed="left"
).astype('timedelta64[s]')
).astype("timedelta64[s]")
return table_dict(tabs, typ="angles")
......@@ -571,6 +571,7 @@ class coordinates:
return self._arena, self._arena_dims, self._scales
def rule_based_annotation(self):
"""Annotates coordinates using a simple rule-based pipeline"""
pass
......@@ -634,10 +635,12 @@ class table_dict(dict):
else [0, self._arena_dims[i][1]]
)
plot_heatmap(
heatmaps = plot_heatmap(
list(self.values())[i], bodyparts, xlim=x_lim, ylim=y_lim, save=save,
)
return heatmaps
def get_training_set(self, test_videos: int = 0) -> Tuple[np.ndarray, np.ndarray]:
"""Generates training and test sets as numpy.array objects for model training"""
......
......@@ -13,42 +13,36 @@ import numpy as np
import pandas as pd
import seaborn as sns
from itertools import cycle
from typing import List, Dict
from typing import List
# PLOTTING FUNCTIONS #
def plot_speed(
behaviour_dict: Dict[str, pd.DataFrame], treatments: Dict[str, List]
) -> plt.figure:
"""Plots a histogram with the speed of the specified mouse.
Treatments is expected to be a list of lists with mice keys per treatment"""
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))
for Treatment, Mice_list in treatments.items():
hist = pd.concat([behaviour_dict[mouse] for mouse in Mice_list])
sns.kdeplot(hist["bspeed"], shade=True, label=Treatment, ax=ax1)
sns.kdeplot(hist["wspeed"], shade=True, label=Treatment, ax=ax2)
ax1.set_xlim(0, 7)
ax2.set_xlim(0, 7)
ax1.set_title("Average speed density for black mouse")
ax2.set_title("Average speed density for white mouse")
plt.xlabel("Average speed")
plt.ylabel("Density")
plt.show()
def plot_heatmap(
dframe: pd.DataFrame, bodyparts: List, xlim: float, ylim: float, save: str = False
dframe: pd.DataFrame,
bodyparts: List,
xlim: tuple,
ylim: tuple,
save: str = False,
dpi: int = 200,
) -> plt.figure:
"""Returns a heatmap of the movement of a specific bodypart in the arena.
If more than one bodypart is passed, it returns one subplot for each"""
If more than one bodypart is passed, it returns one subplot for each
Parameters:
- dframe (pandas.DataFrame): table_dict value with info to plot
- bodyparts (List): bodyparts to represent (at least 1)
- xlim (float): limits of the x-axis
- ylim (float): limits of the y-axis
- save (str): name of the file to which the figure should be saved
- dpi (int): dots per inch of the returned image
Returns:
- heatmaps (plt.figure): figure with the specified characteristics"""
# noinspection PyTypeChecker
fig, ax = plt.subplots(1, len(bodyparts), sharex=True, sharey=True)
heatmaps, ax = plt.subplots(1, len(bodyparts), sharex=True, sharey=True, dpi=dpi)
for i, bpart in enumerate(bodyparts):
heatmap = dframe[bpart]
......@@ -65,7 +59,7 @@ def plot_heatmap(
if save:
plt.savefig(save)
plt.show()
return heatmaps
def model_comparison_plot(
......@@ -73,17 +67,40 @@ def model_comparison_plot(
m_bic: list,
n_components_range: range,
cov_plot: str,
save: str,
save: str = False,
cv_types: tuple = ("spherical", "tied", "diag", "full"),
dpi: int = 200,
) -> plt.figure:
"""Plots model comparison statistics over all tests"""
"""
Plots model comparison statistics for Gaussian Mixture Model analysis.
Similar to https://scikit-learn.org/stable/modules/mixture.html, it shows
an upper panel with BIC per number of components and covariance matrix type
in a bar plot, and a lower panel with box plots showing bootstrap runs of the
models corresponding to one of the covariance types.
Parameters:
- bic (list): list with BIC for all used models
- m_bic (list): list with minimum bic across cov matrices
for all used models
- n_components_range (range): range of components to evaluate
- cov_plot (str): covariance matrix to use in the lower panel
- save (str): name of the file to which the figure should be saved
- cv_types (tuple): tuple indicating which covariance matrix types
to use. All (spherical, tied, diag and full) used by default.
- dpi (int): dots per inch of the returned image
Returns:
- modelcomp (plt.figure): figure with all specified characteristics
"""
m_bic = np.array(m_bic)
color_iter = cycle(["navy", "turquoise", "cornflowerblue", "darkorange"])
bars = []
# Plot the BIC scores
plt.figure(figsize=(12, 8))
modelcomp = plt.figure(dpi=dpi)
spl = plt.subplot(2, 1, 1)
covplot = np.repeat(cv_types, len(m_bic) / 4)
......@@ -115,9 +132,7 @@ def model_comparison_plot(
spl2.set_xlabel("Number of components")
spl2.set_ylabel("BIC value")
plt.tight_layout()
if save:
plt.savefig(save)
plt.show()
return modelcomp
......@@ -14,4 +14,4 @@ from hypothesis import strategies as st
from collections import defaultdict
from deepof.utils import *
import deepof.preprocess
import pytest
\ No newline at end of file
import pytest
......@@ -32,7 +32,7 @@ def test_project_init(table_type, arena_type):
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena=arena_type,
arena_dims=[380],
arena_dims=tuple([380]),
angles=False,
video_format=".mp4",
table_format=table_type,
......@@ -41,7 +41,7 @@ def test_project_init(table_type, arena_type):
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena=arena_type,
arena_dims=[380],
arena_dims=tuple([380]),
angles=False,
video_format=".mp4",
table_format=table_type,
......@@ -72,7 +72,7 @@ def test_get_distances(nodes, ego):
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
arena_dims=tuple([380]),
angles=False,
video_format=".mp4",
table_format=".h5",
......@@ -98,7 +98,7 @@ def test_get_angles(nodes, ego):
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
arena_dims=tuple([380]),
video_format=".mp4",
table_format=".h5",
distances=nodes,
......@@ -123,7 +123,7 @@ def test_run(nodes, ego):
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
arena_dims=tuple([380]),
video_format=".mp4",
table_format=".h5",
distances=nodes,
......@@ -147,7 +147,7 @@ def test_get_table_dicts(nodes, ego, sampler):
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
arena_dims=tuple([380]),
video_format=".mp4",
table_format=".h5",
distances=nodes,
......
......@@ -11,7 +11,63 @@ Testing module for deepof.visuals
from hypothesis import given
from hypothesis import settings
from hypothesis import strategies as st
from collections import defaultdict
from deepof.utils import *
import deepof.preprocess
import pytest
\ No newline at end of file
import deepof.visuals
import matplotlib.figure
def test_plot_heatmap():
prun = (
deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=tuple([380]),
angles=False,
video_format=".mp4",
table_format=".h5",
)
.run()
.get_coords()
)
assert (
type(
deepof.visuals.plot_heatmap(
prun["test"],
["Center"],
tuple([-100, 100]),
tuple([-100, 100]),
dpi=200,
)
)
== matplotlib.figure.Figure
)
def test_model_comparison_plot():
prun = (
deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=tuple([380]),
angles=False,
video_format=".mp4",
table_format=".h5",
)
.run()
.get_coords()
)
gmm_run = gmm_model_selection(
prun["test"], n_components_range=range(1, 3), n_runs=1, part_size=100
)
assert (
type(
deepof.visuals.model_comparison_plot(
gmm_run[0], gmm_run[1], range(1, 3), cov_plot="full"
)
)
== matplotlib.figure.Figure
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment