Skip to content
Snippets Groups Projects
Commit 1e83611f authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Refactored functions in visuals.py

parent 0587b97a
Branches
Tags
No related merge requests found
......@@ -5,19 +5,20 @@ import numpy as np
import pandas as pd
import seaborn as sns
from itertools import cycle
from typing import List, Dict
# PLOTTING FUNCTIONS #
def plot_speed(Behaviour_dict, Treatments):
def plot_speed(behaviour_dict: dict, treatments: Dict[List]) -> plt.figure:
"""Plots a histogram with the speed of the specified mouse.
Treatments is expected to be a list of lists with mice keys per treatment"""
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))
for Treatment, Mice_list in Treatments.items():
hist = pd.concat([Behaviour_dict[mouse] for mouse in Mice_list])
for Treatment, Mice_list in treatments.items():
hist = pd.concat([behaviour_dict[mouse] for mouse in Mice_list])
sns.kdeplot(hist["bspeed"], shade=True, label=Treatment, ax=ax1)
sns.kdeplot(hist["wspeed"], shade=True, label=Treatment, ax=ax2)
......@@ -30,10 +31,13 @@ def plot_speed(Behaviour_dict, Treatments):
plt.show()
def plot_heatmap(dframe, bodyparts, xlim, ylim, save=False):
def plot_heatmap(
dframe: pd.DataFrame, bodyparts: List, xlim: float, ylim: float, save: str = False
) -> plt.figure:
"""Returns a heatmap of the movement of a specific bodypart in the arena.
If more than one bodypart is passed, it returns one subplot for each"""
# noinspection PyTypeChecker
fig, ax = plt.subplots(1, len(bodyparts), sharex=True, sharey=True)
for i, bpart in enumerate(bodyparts):
......@@ -48,26 +52,24 @@ def plot_heatmap(dframe, bodyparts, xlim, ylim, save=False):
[x.set_ylim(ylim) for x in ax]
[x.set_title(bp) for x, bp in zip(ax, bodyparts)]
if save != False:
if save:
plt.savefig(save)
plt.show()
def model_comparison_plot(
bic,
m_bic,
best_bic_gmm,
n_components_range,
cov_plot,
save,
cv_types=["spherical", "tied", "diag", "full"],
):
bic: list,
m_bic: list,
n_components_range: range,
cov_plot: str,
save: str,
cv_types: tuple = ("spherical", "tied", "diag", "full"),
) -> plt.figure:
"""Plots model comparison statistics over all tests"""
m_bic = np.array(m_bic)
color_iter = cycle(["navy", "turquoise", "cornflowerblue", "darkorange"])
clf = best_bic_gmm
bars = []
# Plot the BIC scores
......@@ -93,6 +95,7 @@ def model_comparison_plot(
+ 0.5
+ 0.2 * np.floor(m_bic.argmin() / len(n_components_range))
)
# noinspection PyArgumentList
spl.text(xpos, m_bic.min() * 0.97 + 0.1 * m_bic.max(), "*", fontsize=14)
spl.legend([b[0] for b in bars], cv_types)
spl.set_ylabel("BIC value")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment