From 0587b97abeaf9c0474c34bd59de1e6600058768c Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Mon, 14 Sep 2020 13:09:41 +0200
Subject: [PATCH] Added tests for Markov transition function in utils.py

---
 deepof/utils.py   | 105 +------------------------------------------
 deepof/visuals.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 111 insertions(+), 104 deletions(-)
 create mode 100644 deepof/visuals.py

diff --git a/deepof/utils.py b/deepof/utils.py
index 04f2f06f..2b8f9692 100644
--- a/deepof/utils.py
+++ b/deepof/utils.py
@@ -10,7 +10,7 @@ import pandas as pd
 import regex as re
 import seaborn as sns
 from copy import deepcopy
-from itertools import cycle, combinations, product
+from itertools import combinations, product
 from joblib import Parallel, delayed
 from scipy import spatial
 from scipy import stats
@@ -1136,108 +1136,5 @@ def cluster_transition_matrix(
     return trans_normed
 
 
-# PLOTTING FUNCTIONS #
-
-
-def plot_speed(Behaviour_dict, Treatments):
-    """Plots a histogram with the speed of the specified mouse.
-       Treatments is expected to be a list of lists with mice keys per treatment"""
-
-    fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))
-
-    for Treatment, Mice_list in Treatments.items():
-        hist = pd.concat([Behaviour_dict[mouse] for mouse in Mice_list])
-        sns.kdeplot(hist["bspeed"], shade=True, label=Treatment, ax=ax1)
-        sns.kdeplot(hist["wspeed"], shade=True, label=Treatment, ax=ax2)
-
-    ax1.set_xlim(0, 7)
-    ax2.set_xlim(0, 7)
-    ax1.set_title("Average speed density for black mouse")
-    ax2.set_title("Average speed density for white mouse")
-    plt.xlabel("Average speed")
-    plt.ylabel("Density")
-    plt.show()
-
-
-def plot_heatmap(dframe, bodyparts, xlim, ylim, save=False):
-    """Returns a heatmap of the movement of a specific bodypart in the arena.
-       If more than one bodypart is passed, it returns one subplot for each"""
-
-    fig, ax = plt.subplots(1, len(bodyparts), sharex=True, sharey=True)
-
-    for i, bpart in enumerate(bodyparts):
-        heatmap = dframe[bpart]
-        if len(bodyparts) > 1:
-            sns.kdeplot(heatmap.x, heatmap.y, cmap="jet", shade=True, alpha=1, ax=ax[i])
-        else:
-            sns.kdeplot(heatmap.x, heatmap.y, cmap="jet", shade=True, alpha=1, ax=ax)
-            ax = np.array([ax])
-
-    [x.set_xlim(xlim) for x in ax]
-    [x.set_ylim(ylim) for x in ax]
-    [x.set_title(bp) for x, bp in zip(ax, bodyparts)]
-
-    if save != False:
-        plt.savefig(save)
-
-    plt.show()
-
-
-def model_comparison_plot(
-    bic,
-    m_bic,
-    best_bic_gmm,
-    n_components_range,
-    cov_plot,
-    save,
-    cv_types=["spherical", "tied", "diag", "full"],
-):
-    """Plots model comparison statistics over all tests"""
-
-    m_bic = np.array(m_bic)
-    color_iter = cycle(["navy", "turquoise", "cornflowerblue", "darkorange"])
-    clf = best_bic_gmm
-    bars = []
-
-    # Plot the BIC scores
-    plt.figure(figsize=(12, 8))
-    spl = plt.subplot(2, 1, 1)
-    covplot = np.repeat(cv_types, len(m_bic) / 4)
-
-    for i, (cv_type, color) in enumerate(zip(cv_types, color_iter)):
-        xpos = np.array(n_components_range) + 0.2 * (i - 2)
-        bars.append(
-            spl.bar(
-                xpos,
-                m_bic[i * len(n_components_range) : (i + 1) * len(n_components_range)],
-                color=color,
-                width=0.2,
-            )
-        )
-
-    spl.set_xticks(n_components_range)
-    plt.title("BIC score per model")
-    xpos = (
-        np.mod(m_bic.argmin(), len(n_components_range))
-        + 0.5
-        + 0.2 * np.floor(m_bic.argmin() / len(n_components_range))
-    )
-    spl.text(xpos, m_bic.min() * 0.97 + 0.1 * m_bic.max(), "*", fontsize=14)
-    spl.legend([b[0] for b in bars], cv_types)
-    spl.set_ylabel("BIC value")
-
-    spl2 = plt.subplot(2, 1, 2, sharex=spl)
-    spl2.boxplot(list(np.array(bic)[covplot == cov_plot]), positions=n_components_range)
-    spl2.set_xlabel("Number of components")
-    spl2.set_ylabel("BIC value")
-
-    plt.tight_layout()
-
-    if save:
-        plt.savefig(save)
-
-    plt.show()
-
-
 # TODO:
 #    - Add sequence plot to single_behaviour_analysis (show how the condition varies across a specified time window)
diff --git a/deepof/visuals.py b/deepof/visuals.py
new file mode 100644
index 00000000..1e808051
--- /dev/null
+++ b/deepof/visuals.py
@@ -0,0 +1,110 @@
+# @author lucasmiranda42
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from itertools import cycle
+
+
+# PLOTTING FUNCTIONS #
+
+
+def plot_speed(Behaviour_dict, Treatments):
+    """Plots a histogram with the speed of the specified mouse.
+       Treatments is expected to be a list of lists with mice keys per treatment"""
+
+    fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))
+
+    for Treatment, Mice_list in Treatments.items():
+        hist = pd.concat([Behaviour_dict[mouse] for mouse in Mice_list])
+        sns.kdeplot(hist["bspeed"], shade=True, label=Treatment, ax=ax1)
+        sns.kdeplot(hist["wspeed"], shade=True, label=Treatment, ax=ax2)
+
+    ax1.set_xlim(0, 7)
+    ax2.set_xlim(0, 7)
+    ax1.set_title("Average speed density for black mouse")
+    ax2.set_title("Average speed density for white mouse")
+    plt.xlabel("Average speed")
+    plt.ylabel("Density")
+    plt.show()
+
+
+def plot_heatmap(dframe, bodyparts, xlim, ylim, save=False):
+    """Returns a heatmap of the movement of a specific bodypart in the arena.
+       If more than one bodypart is passed, it returns one subplot for each"""
+
+    fig, ax = plt.subplots(1, len(bodyparts), sharex=True, sharey=True)
+
+    for i, bpart in enumerate(bodyparts):
+        heatmap = dframe[bpart]
+        if len(bodyparts) > 1:
+            sns.kdeplot(heatmap.x, heatmap.y, cmap="jet", shade=True, alpha=1, ax=ax[i])
+        else:
+            sns.kdeplot(heatmap.x, heatmap.y, cmap="jet", shade=True, alpha=1, ax=ax)
+            ax = np.array([ax])
+
+    [x.set_xlim(xlim) for x in ax]
+    [x.set_ylim(ylim) for x in ax]
+    [x.set_title(bp) for x, bp in zip(ax, bodyparts)]
+
+    if save != False:
+        plt.savefig(save)
+
+    plt.show()
+
+
+def model_comparison_plot(
+    bic,
+    m_bic,
+    best_bic_gmm,
+    n_components_range,
+    cov_plot,
+    save,
+    cv_types=["spherical", "tied", "diag", "full"],
+):
+    """Plots model comparison statistics over all tests"""
+
+    m_bic = np.array(m_bic)
+    color_iter = cycle(["navy", "turquoise", "cornflowerblue", "darkorange"])
+    clf = best_bic_gmm
+    bars = []
+
+    # Plot the BIC scores
+    plt.figure(figsize=(12, 8))
+    spl = plt.subplot(2, 1, 1)
+    covplot = np.repeat(cv_types, len(m_bic) / 4)
+
+    for i, (cv_type, color) in enumerate(zip(cv_types, color_iter)):
+        xpos = np.array(n_components_range) + 0.2 * (i - 2)
+        bars.append(
+            spl.bar(
+                xpos,
+                m_bic[i * len(n_components_range) : (i + 1) * len(n_components_range)],
+                color=color,
+                width=0.2,
+            )
+        )
+
+    spl.set_xticks(n_components_range)
+    plt.title("BIC score per model")
+    xpos = (
+        np.mod(m_bic.argmin(), len(n_components_range))
+        + 0.5
+        + 0.2 * np.floor(m_bic.argmin() / len(n_components_range))
+    )
+    spl.text(xpos, m_bic.min() * 0.97 + 0.1 * m_bic.max(), "*", fontsize=14)
+    spl.legend([b[0] for b in bars], cv_types)
+    spl.set_ylabel("BIC value")
+
+    spl2 = plt.subplot(2, 1, 2, sharex=spl)
+    spl2.boxplot(list(np.array(bic)[covplot == cov_plot]), positions=n_components_range)
+    spl2.set_xlabel("Number of components")
+    spl2.set_ylabel("BIC value")
+
+    plt.tight_layout()
+
+    if save:
+        plt.savefig(save)
+
+    plt.show()
-- 
GitLab