diff --git a/deepof/data.py b/deepof/data.py
index f111696109f6961637b05a1ab3cc32e41dcd30e5..ed52fd790d0459d1c21404e1dc83173101fd0e8d 100644
--- a/deepof/data.py
+++ b/deepof/data.py
@@ -60,23 +60,23 @@ class project:
     """
 
     def __init__(
-            self,
-            animal_ids: List = tuple([""]),
-            arena: str = "circular",
-            arena_detection: str = "rule-based",
-            arena_dims: tuple = (1,),
-            enable_iterative_imputation: bool = None,
-            exclude_bodyparts: List = tuple([""]),
-            exp_conditions: dict = None,
-            interpolate_outliers: bool = True,
-            interpolation_limit: int = 5,
-            interpolation_std: int = 5,
-            likelihood_tol: float = 0.25,
-            model: str = "mouse_topview",
-            path: str = deepof.utils.os.path.join("."),
-            smooth_alpha: float = 0.99,
-            table_format: str = "autodetect",
-            video_format: str = ".mp4",
+        self,
+        animal_ids: List = tuple([""]),
+        arena: str = "circular",
+        arena_detection: str = "rule-based",
+        arena_dims: tuple = (1,),
+        enable_iterative_imputation: bool = None,
+        exclude_bodyparts: List = tuple([""]),
+        exp_conditions: dict = None,
+        interpolate_outliers: bool = True,
+        interpolation_limit: int = 5,
+        interpolation_std: int = 5,
+        likelihood_tol: float = 0.25,
+        model: str = "mouse_topview",
+        path: str = deepof.utils.os.path.join("."),
+        smooth_alpha: float = 0.99,
+        table_format: str = "autodetect",
+        video_format: str = ".mp4",
     ):
 
         # Set working paths
@@ -287,8 +287,8 @@ class project:
                 ).T.index.remove_unused_levels()
 
                 tab = value.loc[
-                      :, [i for i in value.columns.levels[0] if i not in lablist]
-                      ]
+                    :, [i for i in value.columns.levels[0] if i not in lablist]
+                ]
 
                 tab.columns = tabcols
 
@@ -362,14 +362,14 @@ class project:
 
         for key in distance_dict.keys():
             distance_dict[key] = distance_dict[key].loc[
-                                 :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
-                                 ]
+                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
+            ]
 
         if self.ego:
             for key, val in distance_dict.items():
                 distance_dict[key] = val.loc[
-                                     :, [dist for dist in val.columns if self.ego in dist]
-                                     ]
+                    :, [dist for dist in val.columns if self.ego in dist]
+                ]
 
         return distance_dict
 
@@ -472,20 +472,20 @@ class coordinates:
     """
 
     def __init__(
-            self,
-            arena: str,
-            arena_detection: str,
-            arena_dims: np.array,
-            path: str,
-            quality: dict,
-            scales: np.array,
-            tables: dict,
-            videos: list,
-            angles: dict = None,
-            animal_ids: List = tuple([""]),
-            distances: dict = None,
-            exp_conditions: dict = None,
-            ellipse_detection: tf.keras.models.Model = None,
+        self,
+        arena: str,
+        arena_detection: str,
+        arena_dims: np.array,
+        path: str,
+        quality: dict,
+        scales: np.array,
+        tables: dict,
+        videos: list,
+        angles: dict = None,
+        animal_ids: List = tuple([""]),
+        distances: dict = None,
+        exp_conditions: dict = None,
+        ellipse_detection: tf.keras.models.Model = None,
     ):
         self._animal_ids = animal_ids
         self._arena = arena
@@ -510,15 +510,15 @@ class coordinates:
             return "deepof analysis of {} videos".format(len(self._videos))
 
     def get_coords(
-            self,
-            center: str = "arena",
-            polar: bool = False,
-            speed: int = 0,
-            length: str = None,
-            align: bool = False,
-            align_inplace: bool = False,
-            propagate_labels: bool = False,
-            propagate_annotations: Dict = False,
+        self,
+        center: str = "arena",
+        polar: bool = False,
+        speed: int = 0,
+        length: str = None,
+        align: bool = False,
+        align_inplace: bool = False,
+        propagate_labels: bool = False,
+        propagate_annotations: Dict = False,
     ) -> Table_dict:
         """
         Returns a table_dict object with the coordinates of each animal as values.
@@ -558,23 +558,23 @@ class coordinates:
 
                     try:
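+                        # Shift coordinates so that the arena center lies at the origin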
                         value.loc[:, (slice("coords"), ["x"])] = (
-                                value.loc[:, (slice("coords"), ["x"])]
-                                - self._scales[i][0] / 2
+                            value.loc[:, (slice("coords"), ["x"])]
+                            - self._scales[i][0] / 2
                         )
 
                         value.loc[:, (slice("coords"), ["y"])] = (
-                                value.loc[:, (slice("coords"), ["y"])]
-                                - self._scales[i][1] / 2
+                            value.loc[:, (slice("coords"), ["y"])]
+                            - self._scales[i][1] / 2
                         )
                     except KeyError:
                         value.loc[:, (slice("coords"), ["rho"])] = (
-                                value.loc[:, (slice("coords"), ["rho"])]
-                                - self._scales[i][0] / 2
+                            value.loc[:, (slice("coords"), ["rho"])]
+                            - self._scales[i][0] / 2
                         )
 
                         value.loc[:, (slice("coords"), ["phi"])] = (
-                                value.loc[:, (slice("coords"), ["phi"])]
-                                - self._scales[i][1] / 2
+                            value.loc[:, (slice("coords"), ["phi"])]
+                            - self._scales[i][1] / 2
                         )
 
         elif isinstance(center, str) and center != "arena":
@@ -583,24 +583,24 @@ class coordinates:
 
                 try:
                     value.loc[:, (slice("coords"), ["x"])] = value.loc[
-                                                             :, (slice("coords"), ["x"])
-                                                             ].subtract(value[center]["x"], axis=0)
+                        :, (slice("coords"), ["x"])
+                    ].subtract(value[center]["x"], axis=0)
 
                     value.loc[:, (slice("coords"), ["y"])] = value.loc[
-                                                             :, (slice("coords"), ["y"])
-                                                             ].subtract(value[center]["y"], axis=0)
+                        :, (slice("coords"), ["y"])
+                    ].subtract(value[center]["y"], axis=0)
                 except KeyError:
                     value.loc[:, (slice("coords"), ["rho"])] = value.loc[
-                                                               :, (slice("coords"), ["rho"])
-                                                               ].subtract(value[center]["rho"], axis=0)
+                        :, (slice("coords"), ["rho"])
+                    ].subtract(value[center]["rho"], axis=0)
 
                     value.loc[:, (slice("coords"), ["phi"])] = value.loc[
-                                                               :, (slice("coords"), ["phi"])
-                                                               ].subtract(value[center]["phi"], axis=0)
+                        :, (slice("coords"), ["phi"])
+                    ].subtract(value[center]["phi"], axis=0)
 
                 tabs[key] = value.loc[
-                            :, [tab for tab in value.columns if tab[0] != center]
-                            ]
+                    :, [tab for tab in value.columns if tab[0] != center]
+                ]
 
         if speed:
             for key, tab in tabs.items():
@@ -615,16 +615,16 @@ class coordinates:
 
         if align:
             assert (
-                    align in list(tabs.values())[0].columns.levels[0]
+                align in list(tabs.values())[0].columns.levels[0]
             ), "align must be set to the name of a bodypart"
 
             for key, tab in tabs.items():
                 # Bring forward the column to align
                 columns = [i for i in tab.columns if align not in i]
                 columns = [
-                              (align, ("phi" if polar else "x")),
-                              (align, ("rho" if polar else "y")),
-                          ] + columns
+                    (align, ("phi" if polar else "x")),
+                    (align, ("rho" if polar else "y")),
+                ] + columns
                 tab = tab[columns]
                 tabs[key] = tab
 
@@ -659,11 +659,11 @@ class coordinates:
         )
 
     def get_distances(
-            self,
-            speed: int = 0,
-            length: str = None,
-            propagate_labels: bool = False,
-            propagate_annotations: Dict = False,
+        self,
+        speed: int = 0,
+        length: str = None,
+        propagate_labels: bool = False,
+        propagate_annotations: Dict = False,
     ) -> Table_dict:
         """
-        Returns a table_dict object with the distances between body parts animal as values.
+        Returns a table_dict object with the distances between body parts of the animal as values.
@@ -720,12 +720,12 @@ class coordinates:
         )  # pragma: no cover
 
     def get_angles(
-            self,
-            degrees: bool = False,
-            speed: int = 0,
-            length: str = None,
-            propagate_labels: bool = False,
-            propagate_annotations: Dict = False,
+        self,
+        degrees: bool = False,
+        speed: int = 0,
+        length: str = None,
+        propagate_labels: bool = False,
+        propagate_annotations: Dict = False,
     ) -> Table_dict:
         """
-        Returns a table_dict object with the angles between body parts animal as values.
+        Returns a table_dict object with the angles between body parts of the animal as values.
@@ -811,11 +811,11 @@ class coordinates:
 
     # noinspection PyDefaultArgument
     def rule_based_annotation(
-            self,
-            params: Dict = {},
-            video_output: bool = False,
-            frame_limit: int = np.inf,
-            debug: bool = False,
+        self,
+        params: Dict = {},
+        video_output: bool = False,
+        frame_limit: int = np.inf,
+        debug: bool = False,
     ) -> Table_dict:
         """Annotates coordinates using a simple rule-based pipeline"""
 
@@ -883,29 +883,29 @@ class coordinates:
 
     @staticmethod
     def deep_unsupervised_embedding(
-            preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
-            batch_size: int = 256,
-            encoding_size: int = 4,
-            epochs: int = 35,
-            hparams: dict = None,
-            kl_warmup: int = 0,
-            log_history: bool = True,
-            log_hparams: bool = False,
-            loss: str = "ELBO",
-            mmd_warmup: int = 0,
-            montecarlo_kl: int = 10,
-            n_components: int = 25,
-            output_path: str = ".",
-            phenotype_class: float = 0,
-            predictor: float = 0,
-            pretrained: str = False,
-            save_checkpoints: bool = False,
-            save_weights: bool = True,
-            variational: bool = True,
-            reg_cat_clusters: bool = False,
-            reg_cluster_variance: bool = False,
-            entropy_samples: int = 10000,
-            entropy_knn: int = 100,
+        preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
+        batch_size: int = 256,
+        encoding_size: int = 4,
+        epochs: int = 35,
+        hparams: dict = None,
+        kl_warmup: int = 0,
+        log_history: bool = True,
+        log_hparams: bool = False,
+        loss: str = "ELBO",
+        mmd_warmup: int = 0,
+        montecarlo_kl: int = 10,
+        n_components: int = 25,
+        output_path: str = ".",
+        phenotype_class: float = 0,
+        predictor: float = 0,
+        pretrained: str = False,
+        save_checkpoints: bool = False,
+        save_weights: bool = True,
+        variational: bool = True,
+        reg_cat_clusters: bool = False,
+        reg_cluster_variance: bool = False,
+        entropy_samples: int = 10000,
+        entropy_knn: int = 100,
     ) -> Tuple:
         """
         Annotates coordinates using an unsupervised autoencoder.
@@ -983,15 +983,15 @@ class table_dict(dict):
     """
 
     def __init__(
-            self,
-            tabs: Dict,
-            typ: str,
-            arena: str = None,
-            arena_dims: np.array = None,
-            center: str = None,
-            polar: bool = None,
-            propagate_labels: bool = False,
-            propagate_annotations: Dict = False,
+        self,
+        tabs: Dict,
+        typ: str,
+        arena: str = None,
+        arena_dims: np.array = None,
+        center: str = None,
+        polar: bool = None,
+        propagate_labels: bool = False,
+        propagate_annotations: Dict = False,
     ):
         super().__init__(tabs)
         self._type = typ
@@ -1014,13 +1014,13 @@ class table_dict(dict):
 
     # noinspection PyTypeChecker
     def plot_heatmaps(
-            self,
-            bodyparts: list,
-            xlim: float = None,
-            ylim: float = None,
-            save: bool = False,
-            i: int = 0,
-            dpi: int = 100,
+        self,
+        bodyparts: list,
+        xlim: float = None,
+        ylim: float = None,
+        save: bool = False,
+        i: int = 0,
+        dpi: int = 100,
     ) -> plt.figure:
         """Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)"""
 
@@ -1046,9 +1046,9 @@ class table_dict(dict):
             return heatmaps
 
     def get_training_set(
-            self,
-            test_videos: int = 0,
-            encode_labels: bool = True,
+        self,
+        test_videos: int = 0,
+        encode_labels: bool = True,
     ) -> Tuple[np.ndarray, list, Union[np.ndarray, list], list]:
         """Generates training and test sets as numpy.array objects for model training"""
 
@@ -1107,17 +1107,17 @@ class table_dict(dict):
 
     # noinspection PyTypeChecker,PyGlobalUndefined
     def preprocess(
-            self,
-            window_size: int = 1,
-            window_step: int = 1,
-            scale: str = "standard",
-            test_videos: int = 0,
-            verbose: bool = False,
-            conv_filter: bool = None,
-            sigma: float = 1.0,
-            shift: float = 0.0,
-            shuffle: bool = False,
-            align: str = False,
+        self,
+        window_size: int = 1,
+        window_step: int = 1,
+        scale: str = "standard",
+        test_videos: int = 0,
+        verbose: bool = False,
+        conv_filter: bool = None,
+        sigma: float = 1.0,
+        shift: float = 0.0,
+        shuffle: bool = False,
+        align: str = False,
     ) -> np.ndarray:
         """
 
@@ -1244,7 +1244,7 @@ class table_dict(dict):
         return X_train, y_train, np.array(X_test), np.array(y_test)
 
     def random_projection(
-            self, n_components: int = None, sample: int = 1000
+        self, n_components: int = None, sample: int = 1000
     ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
         """Returns a training set generated from the 2D original data (time x features) and a random projection
         to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
@@ -1266,7 +1266,7 @@ class table_dict(dict):
         return X, rproj
 
     def pca(
-            self, n_components: int = None, sample: int = 1000, kernel: str = "linear"
+        self, n_components: int = None, sample: int = 1000, kernel: str = "linear"
     ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
         """Returns a training set generated from the 2D original data (time x features) and a PCA projection
         to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
@@ -1288,7 +1288,7 @@ class table_dict(dict):
         return X, pca
 
     def tsne(
-            self, n_components: int = None, sample: int = 1000, perplexity: int = 30
+        self, n_components: int = None, sample: int = 1000, perplexity: int = 30
     ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
         """Returns a training set generated from the 2D original data (time x features) and a PCA projection
         to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
@@ -1332,6 +1332,7 @@ def merge_tables(*args):
 
     return merged_tables
 
+
 # TODO:
 #   - Generate ragged training array using a metric (acceleration, maybe?)
 #   - Use something like Dynamic Time Warping to put all instances in the same length
diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index 51a24f7e443fcc37826d7686663b3591e08ef2de..d7e93f11180baa7f097c3b6994bb52d2e68dd699 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -93,18 +93,18 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
     """Hyperparameter tuning pipeline for deepof.models.SEQ_2_SEQ_GMVAE"""
 
     def __init__(
-            self,
-            input_shape: tuple,
-            encoding: int,
-            kl_warmup_epochs: int = 0,
-            learn_rate: float = 1e-3,
-            loss: str = "ELBO+MMD",
-            mmd_warmup_epochs: int = 0,
-            number_of_components: int = 10,
-            overlap_loss: float = False,
-            phenotype_predictor: float = 0.0,
-            predictor: float = 0.0,
-            prior: str = "standard_normal",
+        self,
+        input_shape: tuple,
+        encoding: int,
+        kl_warmup_epochs: int = 0,
+        learn_rate: float = 1e-3,
+        loss: str = "ELBO+MMD",
+        mmd_warmup_epochs: int = 0,
+        number_of_components: int = 10,
+        overlap_loss: float = False,
+        phenotype_predictor: float = 0.0,
+        predictor: float = 0.0,
+        prior: str = "standard_normal",
     ):
         super().__init__()
         self.input_shape = input_shape
@@ -120,7 +120,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         self.prior = prior
 
         assert (
-                "ELBO" in self.loss or "MMD" in self.loss
+            "ELBO" in self.loss or "MMD" in self.loss
         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
 
     def get_hparams(self, hp):
@@ -192,6 +192,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
 
         return gmvaep
 
+
 # TODO:
 #    - We can add as many parameters as we want to the hypermodel!
 #    with this implementation, predictor, warmup, loss and even number of components can be tuned using BayOpt
diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index f5f54311441378729baaaeb5a54bc6abcb043ddb..a0316a591225f47ae966516cb4864dfcc0fc7ba9 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -45,7 +45,7 @@ class exponential_learning_rate(tf.keras.callbacks.Callback):
 
 
 def find_learning_rate(
-        model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
+    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
 ):
     """Trains the provided model for an epoch with an exponentially increasing learning rate"""
 
@@ -124,9 +124,9 @@ def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
     y_kernel = compute_kernel(y, y)
     xy_kernel = compute_kernel(x, y)
     mmd = (
-            tf.reduce_mean(x_kernel)
-            + tf.reduce_mean(y_kernel)
-            - 2 * tf.reduce_mean(xy_kernel)
+        tf.reduce_mean(x_kernel)
+        + tf.reduce_mean(y_kernel)
+        - 2 * tf.reduce_mean(xy_kernel)
     )
     return mmd
 
@@ -141,13 +141,13 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
     """
 
     def __init__(
-            self,
-            iterations: int,
-            max_rate: float,
-            start_rate: float = None,
-            last_iterations: int = None,
-            last_rate: float = None,
-            log_dir: str = ".",
+        self,
+        iterations: int,
+        max_rate: float,
+        start_rate: float = None,
+        last_iterations: int = None,
+        last_rate: float = None,
+        log_dir: str = ".",
     ):
         super().__init__()
         self.iterations = iterations
@@ -213,13 +213,13 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
     """
 
     def __init__(
-            self,
-            encoding_dim: int,
-            variational: bool = True,
-            validation_data: np.ndarray = None,
-            k: int = 100,
-            samples: int = 10000,
-            log_dir: str = ".",
+        self,
+        encoding_dim: int,
+        variational: bool = True,
+        validation_data: np.ndarray = None,
+        k: int = 100,
+        samples: int = 10000,
+        log_dir: str = ".",
     ):
         super().__init__()
         self.enc = encoding_dim
@@ -531,7 +531,7 @@ class Cluster_overlap(Layer):
         dists = []
         for k in range(self.n_components):
             locs = (target[..., : self.lat_dims, k],)
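+            # First half of the parameter axis holds the means; softplus of the second half gives positive scales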
-            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
 
             dists.append(
                 tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
diff --git a/deepof/models.py b/deepof/models.py
index 4c0648b4f38b8ae97783075db4e0b33af7c86c74..ee5affbe07f0b58d4f77ae7cef9de6c52a8e2097 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -34,9 +34,9 @@ class SEQ_2_SEQ_AE:
     """  Simple sequence to sequence autoencoder implemented with tf.keras """
 
     def __init__(
-            self,
-            architecture_hparams: Dict = {},
-            huber_delta: float = 1.0,
+        self,
+        architecture_hparams: Dict = {},
+        huber_delta: float = 1.0,
     ):
         self.hparams = self.get_hparams(architecture_hparams)
         self.CONV_filters = self.hparams["units_conv"]
@@ -171,8 +171,8 @@ class SEQ_2_SEQ_AE:
         )
 
     def build(
-            self,
-            input_shape: tuple,
+        self,
+        input_shape: tuple,
     ) -> Tuple[Any, Any, Any]:
         """Builds the tf.keras model"""
 
@@ -242,22 +242,22 @@ class SEQ_2_SEQ_GMVAE:
     """  Gaussian Mixture Variational Autoencoder for pose motif elucidation.  """
 
     def __init__(
-            self,
-            architecture_hparams: dict = {},
-            batch_size: int = 256,
-            compile_model: bool = True,
-            encoding: int = 6,
-            kl_warmup_epochs: int = 20,
-            loss: str = "ELBO",
-            mmd_warmup_epochs: int = 20,
-            montecarlo_kl: int = 1,
-            neuron_control: bool = False,
-            number_of_components: int = 1,
-            overlap_loss: float = 0.0,
-            phenotype_prediction: float = 0.0,
-            predictor: float = 0.0,
-            reg_cat_clusters: bool = False,
-            reg_cluster_variance: bool = False,
+        self,
+        architecture_hparams: dict = {},
+        batch_size: int = 256,
+        compile_model: bool = True,
+        encoding: int = 6,
+        kl_warmup_epochs: int = 20,
+        loss: str = "ELBO",
+        mmd_warmup_epochs: int = 20,
+        montecarlo_kl: int = 1,
+        neuron_control: bool = False,
+        number_of_components: int = 1,
+        overlap_loss: float = 0.0,
+        phenotype_prediction: float = 0.0,
+        predictor: float = 0.0,
+        reg_cat_clusters: bool = False,
+        reg_cluster_variance: bool = False,
     ):
         self.hparams = self.get_hparams(architecture_hparams)
         self.batch_size = batch_size
@@ -290,7 +290,7 @@ class SEQ_2_SEQ_GMVAE:
         self.reg_cluster_variance = reg_cluster_variance
 
         assert (
-                "ELBO" in self.loss or "MMD" in self.loss
+            "ELBO" in self.loss or "MMD" in self.loss
         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
 
     @property
@@ -613,7 +613,7 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
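+                            # softplus keeps the scales positive; 1e-5 avoids zero-variance components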
-                            scale=softplus(gauss[1][..., self.ENCODING:, k]) + 1e-5,
+                            scale=softplus(gauss[1][..., self.ENCODING :, k]) + 1e-5,
                         ),
                         reinterpreted_batch_ndims=1,
                     )
@@ -761,6 +761,7 @@ class SEQ_2_SEQ_GMVAE:
     def prior(self, value):
         self._prior = value
 
+
 # TODO:
 #       - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
 #       - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py
index 7da15d5e69ed4c74d4c7d3bee0abdfa727519e1c..db37e25010ecd6ad9aa969d4221e15550ede40da 100644
--- a/deepof/pose_utils.py
+++ b/deepof/pose_utils.py
@@ -32,12 +32,12 @@ Coordinates = NewType("Coordinates", Any)
 
 
 def close_single_contact(
-        pos_dframe: pd.DataFrame,
-        left: str,
-        right: str,
-        tol: float,
-        arena_abs: int,
-        arena_rel: int,
+    pos_dframe: pd.DataFrame,
+    left: str,
+    right: str,
+    tol: float,
+    arena_abs: int,
+    arena_rel: int,
 ) -> np.array:
     """Returns a boolean array that's True if the specified body parts are closer than tol.
 
@@ -58,8 +58,8 @@ def close_single_contact(
 
     if isinstance(right, str):
         close_contact = (
-                                np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
-                        ) / arena_rel < tol
+            np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
+        ) / arena_rel < tol
 
     elif isinstance(right, list):
         close_contact = np.any(
@@ -76,15 +76,15 @@ def close_single_contact(
 
 
 def close_double_contact(
-        pos_dframe: pd.DataFrame,
-        left1: str,
-        left2: str,
-        right1: str,
-        right2: str,
-        tol: float,
-        arena_abs: int,
-        arena_rel: int,
-        rev: bool = False,
+    pos_dframe: pd.DataFrame,
+    left1: str,
+    left2: str,
+    right1: str,
+    right2: str,
+    tol: float,
+    arena_abs: int,
+    arena_rel: int,
+    rev: bool = False,
 ) -> np.array:
     """Returns a boolean array that's True if the specified body parts are closer than tol.
 
@@ -106,25 +106,25 @@ def close_double_contact(
 
     if rev:
         double_contact = (
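+            # rev=True pairs right1 with left2 and right2 with left1 (reversed relative orientation)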
-                                 (np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
-                                 / arena_rel
-                                 < tol
-                         ) & (
-                                 (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
-                                 / arena_rel
-                                 < tol
-                         )
+            (np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
+            / arena_rel
+            < tol
+        ) & (
+            (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
+            / arena_rel
+            < tol
+        )
 
     else:
         double_contact = (
-                                 (np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
-                                 / arena_rel
-                                 < tol
-                         ) & (
-                                 (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
-                                 / arena_rel
-                                 < tol
-                         )
+            (np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
+            / arena_rel
+            < tol
+        ) & (
+            (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
+            / arena_rel
+            < tol
+        )
 
     return double_contact
 
@@ -152,12 +152,12 @@ def outside_ellipse(x, y, e_center, e_axes, e_angle, threshold=0.0):
 
 
 def climb_wall(
-        arena_type: str,
-        arena: np.array,
-        pos_dict: pd.DataFrame,
-        tol: float,
-        nose: str,
-        centered_data: bool = False,
+    arena_type: str,
+    arena: np.array,
+    pos_dict: pd.DataFrame,
+    tol: float,
+    nose: str,
+    centered_data: bool = False,
 ) -> np.array:
     """Returns True if the specified mouse is climbing the wall
 
@@ -197,16 +197,16 @@ def climb_wall(
 
 
 def sniff_object(
-        speed_dframe: pd.DataFrame,
-        arena_type: str,
-        arena: np.array,
-        pos_dict: pd.DataFrame,
-        tol: float,
-        tol_speed: float,
-        nose: str,
-        centered_data: bool = False,
-        object: str = "arena",
-        animal_id: str = "",
+    speed_dframe: pd.DataFrame,
+    arena_type: str,
+    arena: np.array,
+    pos_dict: pd.DataFrame,
+    tol: float,
+    tol_speed: float,
+    nose: str,
+    centered_data: bool = False,
+    object: str = "arena",
+    animal_id: str = "",
 ):
     """Returns True if the specified mouse is sniffing an object
 
@@ -269,11 +269,11 @@ def sniff_object(
 
 
 def huddle(
-        pos_dframe: pd.DataFrame,
-        speed_dframe: pd.DataFrame,
-        tol_forward: float,
-        tol_speed: float,
-        animal_id: str = "",
+    pos_dframe: pd.DataFrame,
+    speed_dframe: pd.DataFrame,
+    tol_forward: float,
+    tol_speed: float,
+    animal_id: str = "",
 ) -> np.array:
     """Returns true when the mouse is huddling using simple rules.
 
@@ -294,18 +294,18 @@ def huddle(
         animal_id += "_"
 
     forward = (
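+        # Posture check: back and front hips nearly overlap on the left or the right flank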
-                      np.linalg.norm(
-                          pos_dframe[animal_id + "Left_bhip"] - pos_dframe[animal_id + "Left_fhip"],
-                          axis=1,
-                      )
-                      < tol_forward
-              ) | (
-                      np.linalg.norm(
-                          pos_dframe[animal_id + "Right_bhip"] - pos_dframe[animal_id + "Right_fhip"],
-                          axis=1,
-                      )
-                      < tol_forward
-              )
+        np.linalg.norm(
+            pos_dframe[animal_id + "Left_bhip"] - pos_dframe[animal_id + "Left_fhip"],
+            axis=1,
+        )
+        < tol_forward
+    ) | (
+        np.linalg.norm(
+            pos_dframe[animal_id + "Right_bhip"] - pos_dframe[animal_id + "Right_fhip"],
+            axis=1,
+        )
+        < tol_forward
+    )
 
     speed = speed_dframe[animal_id + "Center"] < tol_speed
     hudd = forward & speed
@@ -314,11 +314,11 @@ def huddle(
 
 
 def dig(
-        speed_dframe: pd.DataFrame,
-        likelihood_dframe: pd.DataFrame,
-        tol_speed: float,
-        tol_likelihood: float,
-        animal_id: str = "",
+    speed_dframe: pd.DataFrame,
+    likelihood_dframe: pd.DataFrame,
+    tol_speed: float,
+    tol_likelihood: float,
+    animal_id: str = "",
 ):
     """Returns true when the mouse is digging using simple rules.
 
@@ -345,11 +345,11 @@ def dig(
 
 
 def look_around(
-        speed_dframe: pd.DataFrame,
-        likelihood_dframe: pd.DataFrame,
-        tol_speed: float,
-        tol_likelihood: float,
-        animal_id: str = "",
+    speed_dframe: pd.DataFrame,
+    likelihood_dframe: pd.DataFrame,
+    tol_speed: float,
+    tol_likelihood: float,
+    animal_id: str = "",
 ):
     """Returns true when the mouse is digging using simple rules.
 
@@ -378,12 +378,12 @@ def look_around(
 
 
 def following_path(
-        distance_dframe: pd.DataFrame,
-        position_dframe: pd.DataFrame,
-        follower: str,
-        followed: str,
-        frames: int = 20,
-        tol: float = 0,
+    distance_dframe: pd.DataFrame,
+    position_dframe: pd.DataFrame,
+    follower: str,
+    followed: str,
+    frames: int = 20,
+    tol: float = 0,
 ) -> np.array:
     """For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
-    followed has walked over the last specified number of frames
+    'followed' has walked over the last specified number of frames
@@ -414,15 +414,15 @@ def following_path(
 
     # Check that the animals are oriented follower's nose -> followed's tail
     right_orient1 = (
-            distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
-            < distance_dframe[
-                tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
-            ]
+        distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
+        < distance_dframe[
+            tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
+        ]
     )
 
     right_orient2 = (
-            distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
-            < distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
+        distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
+        < distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
     )
 
     follow = np.all(
@@ -434,13 +434,13 @@ def following_path(
 
 
 def single_behaviour_analysis(
-        behaviour_name: str,
-        treatment_dict: dict,
-        behavioural_dict: dict,
-        plot: int = 0,
-        stat_tests: bool = True,
-        save: str = None,
-        ylim: float = None,
+    behaviour_name: str,
+    treatment_dict: dict,
+    behavioural_dict: dict,
+    plot: int = 0,
+    stat_tests: bool = True,
+    save: str = None,
+    ylim: float = None,
 ) -> list:
     """Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
     with the actual tags, outputs a box plot and a series of significance tests amongst the groups
@@ -495,13 +495,13 @@ def single_behaviour_analysis(
         for i in combinations(treatment_dict.keys(), 2):
             # Solves issue with automatically generated examples
             if np.any(
-                    np.array(
-                        [
-                            beh_dict[i[0]] == beh_dict[i[1]],
-                            np.var(beh_dict[i[0]]) == 0,
-                            np.var(beh_dict[i[1]]) == 0,
-                        ]
-                    )
+                np.array(
+                    [
+                        beh_dict[i[0]] == beh_dict[i[1]],
+                        np.var(beh_dict[i[0]]) == 0,
+                        np.var(beh_dict[i[1]]) == 0,
+                    ]
+                )
             ):
                 stat_dict[i] = "Identical sources. Couldn't run"
             else:
@@ -514,7 +514,7 @@ def single_behaviour_analysis(
 
 
 def max_behaviour(
-        behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
+    behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
 ) -> np.array:
     """Returns the most frequent behaviour in a window of window_size frames
 
@@ -598,19 +598,19 @@ def frame_corners(w, h, corners: dict = {}):
 
 # noinspection PyDefaultArgument,PyProtectedMember
 def rule_based_tagging(
-        tracks: List,
-        videos: List,
-        coordinates: Coordinates,
-        coords: Any,
-        dists: Any,
-        speeds: Any,
-        vid_index: int,
-        arena_type: str,
-        arena_detection_mode: str,
-        ellipse_detection_model: tf.keras.models.Model = None,
-        recog_limit: int = 100,
-        path: str = os.path.join("."),
-        params: dict = {},
+    tracks: List,
+    videos: List,
+    coordinates: Coordinates,
+    coords: Any,
+    dists: Any,
+    speeds: Any,
+    vid_index: int,
+    arena_type: str,
+    arena_detection_mode: str,
+    ellipse_detection_model: tf.keras.models.Model = None,
+    recog_limit: int = 100,
+    path: str = os.path.join("."),
+    params: dict = {},
 ) -> pd.DataFrame:
     """Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
     video displaying the information in real time
@@ -830,18 +830,18 @@ def rule_based_tagging(
 
 
 def tag_rulebased_frames(
-        frame,
-        font,
-        frame_speeds,
-        animal_ids,
-        corners,
-        tag_dict,
-        fnum,
-        undercond,
-        hparams,
-        arena,
-        debug,
-        coords,
+    frame,
+    font,
+    frame_speeds,
+    animal_ids,
+    corners,
+    tag_dict,
+    fnum,
+    undercond,
+    hparams,
+    arena,
+    debug,
+    coords,
 ):
     """Helper function for rule_based_video. Annotates a given frame with on-screen information
     about the recognised patterns"""
@@ -997,16 +997,16 @@ def tag_rulebased_frames(
 
 # noinspection PyProtectedMember,PyDefaultArgument
 def rule_based_video(
-        coordinates: Coordinates,
-        tracks: List,
-        videos: List,
-        vid_index: int,
-        tag_dict: pd.DataFrame,
-        frame_limit: int = np.inf,
-        recog_limit: int = 100,
-        path: str = os.path.join("."),
-        params: dict = {},
-        debug: bool = False,
+    coordinates: Coordinates,
+    tracks: List,
+    videos: List,
+    vid_index: int,
+    tag_dict: pd.DataFrame,
+    frame_limit: int = np.inf,
+    recog_limit: int = 100,
+    path: str = os.path.join("."),
+    params: dict = {},
+    debug: bool = False,
-) -> True:
+) -> bool:
     """Renders a version of the input video with all rule-based taggings in place.
 
@@ -1080,8 +1080,8 @@ def rule_based_video(
         # Capture speeds
         try:
             if (
-                    list(frame_speeds.values())[0] == -np.inf
-                    or fnum % params["speed_pause"] == 0
+                list(frame_speeds.values())[0] == -np.inf
+                or fnum % params["speed_pause"] == 0
             ):
                 for _id in animal_ids:
                     frame_speeds[_id] = tag_dict[_id + undercond + "speed"][fnum]
@@ -1125,5 +1125,6 @@ def rule_based_video(
 
     return True
 
+
 # TODO:
 #    - Is border sniffing anything you might consider interesting?
diff --git a/deepof/train_model.py b/deepof/train_model.py
index fd3de2e913f3f719519a5ca8f5b956f2ce120ea5..4b55ef4570e29c0c5b42ab69a84a5e679569297b 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -77,7 +77,7 @@ parser.add_argument(
     "--hyperparameter-tuning",
     "-tune",
     help="Indicates whether hyperparameters should be tuned either using 'bayopt' of 'hyperband'. "
-         "See documentation for details",
+    "See documentation for details",
     type=str,
     default=False,
 )
@@ -85,7 +85,7 @@ parser.add_argument(
     "--hyperparameters",
     "-hp",
     help="Path pointing to a pickled dictionary of network hyperparameters. "
-         "Thought to be used with the output of hyperparameter tuning",
+    "Thought to be used with the output of hyperparameter tuning",
     type=str,
     default=None,
 )
@@ -93,8 +93,8 @@ parser.add_argument(
     "--input-type",
     "-d",
     help="Select an input type for the autoencoder hypermodels. "
-         "It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
-         "Defaults to coords.",
+    "It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
+    "Defaults to coords.",
     type=str,
     default="dists",
 )
@@ -123,8 +123,8 @@ parser.add_argument(
     "--latent-reg",
     "-lreg",
     help="Sets the strategy to regularize the latent mixture of Gaussians. "
-         "It has to be one of none, categorical (an elastic net penalty is applied to the categorical distribution),"
-         "variance (l2 penalty to the variance of the clusters) or categorical+variance. Defaults to none.",
+    "It has to be one of none, categorical (an elastic net penalty is applied to the categorical distribution),"
+    "variance (l2 penalty to the variance of the clusters) or categorical+variance. Defaults to none.",
     default="none",
     type=str,
 )
@@ -132,7 +132,7 @@ parser.add_argument(
     "--loss",
     "-l",
     help="Sets the loss function for the variational model. "
-         "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
+    "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
     default="ELBO+MMD",
     type=str,
 )
@@ -182,7 +182,7 @@ parser.add_argument(
     "--predictor",
     "-pred",
     help="Activates the prediction branch of the variational Seq 2 Seq model with the specified weight. "
-         "Defaults to 0.0 (inactive)",
+    "Defaults to 0.0 (inactive)",
     default=0.0,
     type=float,
 )
@@ -190,7 +190,7 @@ parser.add_argument(
     "--smooth-alpha",
     "-sa",
     help="Sets the exponential smoothing factor to apply to the input data. "
-         "Float between 0 and 1 (lower is more smooting)",
+    "Float between 0 and 1 (lower is more smooting)",
     type=float,
     default=0.99,
 )
@@ -436,11 +436,11 @@ else:
 
     # Saves the best hyperparameters
     with open(
-            os.path.join(
-                output_path,
-                "{}-based_{}_{}_params.pickle".format(input_type, hyp, tune.capitalize()),
-            ),
-            "wb",
+        os.path.join(
+            output_path,
+            "{}-based_{}_{}_params.pickle".format(input_type, hyp, tune.capitalize()),
+        ),
+        "wb",
     ) as handle:
         pickle.dump(
             best_hyperparameters.values, handle, protocol=pickle.HIGHEST_PROTOCOL
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index fe51d0fe7a5033e9c0d8cfc974b8ce6df23dbc4d..f623a6ebafaecedf5356260b98f334cebbcc734d 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -52,11 +52,11 @@ def load_treatments(train_path):
     to be loaded as metadata in the coordinates class"""
     try:
         with open(
-                os.path.join(
-                    train_path,
-                    [i for i in os.listdir(train_path) if i.endswith(".json")][0],
-                ),
-                "r",
+            os.path.join(
+                train_path,
+                [i for i in os.listdir(train_path) if i.endswith(".json")][0],
+            ),
+            "r",
         ) as handle:
             treatment_dict = json.load(handle)
     except IndexError:
@@ -66,20 +66,20 @@ def load_treatments(train_path):
 
 
 def get_callbacks(
-        X_train: np.array,
-        batch_size: int,
-        variational: bool,
-        phenotype_class: float,
-        predictor: float,
-        loss: str,
-        X_val: np.array = None,
-        cp: bool = False,
-        reg_cat_clusters: bool = False,
-        reg_cluster_variance: bool = False,
-        entropy_samples: int = 15000,
-        entropy_knn: int = 100,
-        logparam: dict = None,
-        outpath: str = ".",
+    X_train: np.array,
+    batch_size: int,
+    variational: bool,
+    phenotype_class: float,
+    predictor: float,
+    loss: str,
+    X_val: np.array = None,
+    cp: bool = False,
+    reg_cat_clusters: bool = False,
+    reg_cluster_variance: bool = False,
+    entropy_samples: int = 15000,
+    entropy_knn: int = 100,
+    logparam: dict = None,
+    outpath: str = ".",
 ) -> List[Union[Any]]:
     """Generates callbacks for model training, including:
     - run_ID: run name, with coarse parameter details;
@@ -197,14 +197,14 @@ def log_hyperparameters(phenotype_class: float, rec: str):
 
 # noinspection PyUnboundLocalVariable
 def tensorboard_metric_logging(
-        run_dir: str,
-        hpms: Any,
-        ae: Any,
-        X_val: np.ndarray,
-        y_val: np.ndarray,
-        phenotype_class: float,
-        predictor: float,
-        rec: str,
+    run_dir: str,
+    hpms: Any,
+    ae: Any,
+    X_val: np.ndarray,
+    y_val: np.ndarray,
+    phenotype_class: float,
+    predictor: float,
+    rec: str,
 ):
     """Autoencoder metric logging in tensorboard"""
 
@@ -246,29 +246,29 @@ def tensorboard_metric_logging(
 
 
 def autoencoder_fitting(
-        preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
-        batch_size: int,
-        encoding_size: int,
-        epochs: int,
-        hparams: dict,
-        kl_warmup: int,
-        log_history: bool,
-        log_hparams: bool,
-        loss: str,
-        mmd_warmup: int,
-        montecarlo_kl: int,
-        n_components: int,
-        output_path: str,
-        phenotype_class: float,
-        predictor: float,
-        pretrained: str,
-        save_checkpoints: bool,
-        save_weights: bool,
-        variational: bool,
-        reg_cat_clusters: bool,
-        reg_cluster_variance: bool,
-        entropy_samples: int,
-        entropy_knn: int,
+    preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
+    batch_size: int,
+    encoding_size: int,
+    epochs: int,
+    hparams: dict,
+    kl_warmup: int,
+    log_history: bool,
+    log_hparams: bool,
+    loss: str,
+    mmd_warmup: int,
+    montecarlo_kl: int,
+    n_components: int,
+    output_path: str,
+    phenotype_class: float,
+    predictor: float,
+    pretrained: str,
+    save_checkpoints: bool,
+    save_weights: bool,
+    variational: bool,
+    reg_cat_clusters: bool,
+    reg_cluster_variance: bool,
+    entropy_samples: int,
+    entropy_knn: int,
 ):
     """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
 
@@ -313,7 +313,7 @@ def autoencoder_fitting(
         logparams, metrics = log_hyperparameters(phenotype_class, rec)
 
         with tf.summary.create_file_writer(
-                os.path.join(output_path, "hparams", run_ID)
+            os.path.join(output_path, "hparams", run_ID)
         ).as_default():
             hp.hparams_config(
                 hparams=logparams,
@@ -363,14 +363,14 @@ def autoencoder_fitting(
                 verbose=1,
                 validation_data=(X_val, X_val),
                 callbacks=cbacks
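+                # CustomStopper delays early stopping until the KL/MMD warmup phase is over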
-                          + [
-                              CustomStopper(
-                                  monitor="val_loss",
-                                  patience=5,
-                                  restore_best_weights=True,
-                                  start_epoch=max(kl_warmup, mmd_warmup),
-                              ),
-                          ],
+                + [
+                    CustomStopper(
+                        monitor="val_loss",
+                        patience=5,
+                        restore_best_weights=True,
+                        start_epoch=max(kl_warmup, mmd_warmup),
+                    ),
+                ],
             )
 
             if save_weights:
@@ -440,23 +440,23 @@ def autoencoder_fitting(
 
 
 def tune_search(
-        data: List[np.array],
-        encoding_size: int,
-        hypertun_trials: int,
-        hpt_type: str,
-        hypermodel: str,
-        k: int,
-        kl_warmup_epochs: int,
-        loss: str,
-        mmd_warmup_epochs: int,
-        overlap_loss: float,
-        phenotype_class: float,
-        predictor: float,
-        project_name: str,
-        callbacks: List,
-        n_epochs: int = 30,
-        n_replicas: int = 1,
-        outpath: str = ".",
+    data: List[np.array],
+    encoding_size: int,
+    hypertun_trials: int,
+    hpt_type: str,
+    hypermodel: str,
+    k: int,
+    kl_warmup_epochs: int,
+    loss: str,
+    mmd_warmup_epochs: int,
+    overlap_loss: float,
+    phenotype_class: float,
+    predictor: float,
+    project_name: str,
+    callbacks: List,
+    n_epochs: int = 30,
+    n_replicas: int = 1,
+    outpath: str = ".",
 ) -> Union[bool, Tuple[Any, Any]]:
     """Define the search space using keras-tuner and bayesian optimization
 
@@ -495,7 +495,7 @@ def tune_search(
 
     if hypermodel == "S2SAE":  # pragma: no cover
         assert (
-                predictor == 0.0 and phenotype_class == 0.0
+            predictor == 0.0 and phenotype_class == 0.0
         ), "Prediction branches are only available for variational models. See documentation for more details"
         batch_size = 1
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
diff --git a/deepof/utils.py b/deepof/utils.py
index 929db44e467f5813f0d57d2581e5a33398cb08d7..a92db30b60137c3607ad37d52872282a4ef96e69 100644
--- a/deepof/utils.py
+++ b/deepof/utils.py
@@ -146,7 +146,7 @@ def tab2polar(cartesian_df: pd.DataFrame) -> pd.DataFrame:
 
 
 def compute_dist(
-        pair_array: np.array, arena_abs: int = 1, arena_rel: int = 1
+    pair_array: np.array, arena_abs: int = 1, arena_rel: int = 1
 ) -> pd.DataFrame:
     """Returns a pandas.DataFrame with the scaled distances between a pair of body parts.
 
@@ -169,7 +169,7 @@ def compute_dist(
 
 
 def bpart_distance(
-        dataframe: pd.DataFrame, arena_abs: int = 1, arena_rel: int = 1
+    dataframe: pd.DataFrame, arena_abs: int = 1, arena_rel: int = 1
 ) -> pd.DataFrame:
     """Returns a pandas.DataFrame with the scaled distances between all pairs of body parts.
 
@@ -208,7 +208,7 @@ def angle(a: np.array, b: np.array, c: np.array) -> np.array:
     bc = c - b
 
     cosine_angle = np.einsum("...i,...i", ba, bc) / (
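+        # Divide the row-wise dot products by the product of the vector norms to get cos(angle)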
-            np.linalg.norm(ba, axis=1) * np.linalg.norm(bc, axis=1)
+        np.linalg.norm(ba, axis=1) * np.linalg.norm(bc, axis=1)
     )
     ang = np.arccos(cosine_angle)
 
@@ -231,7 +231,7 @@ def angle_trio(bpart_array: np.array) -> np.array:
 
 
 def rotate(
-        p: np.array, angles: np.array, origin: np.array = np.array([0, 0])
+    p: np.array, angles: np.array, origin: np.array = np.array([0, 0])
 ) -> np.array:
     """Returns a numpy.array with the initial values rotated by angles radians
 
@@ -363,12 +363,12 @@ def moving_average(time_series: pd.Series, N: int = 5):
 
 
 def mask_outliers(
-        time_series: pd.DataFrame,
-        likelihood: pd.DataFrame,
-        likelihood_tolerance: float,
-        lag: int,
-        n_std: int,
-        mode: str,
+    time_series: pd.DataFrame,
+    likelihood: pd.DataFrame,
+    likelihood_tolerance: float,
+    lag: int,
+    n_std: int,
+    mode: str,
 ):
     """Returns a mask over the bivariate trajectory of a body part, identifying as True all detected outliers
 
@@ -404,13 +404,13 @@ def mask_outliers(
 
 
 def full_outlier_mask(
-        experiment: pd.DataFrame,
-        likelihood: pd.DataFrame,
-        likelihood_tolerance: float,
-        exclude: str,
-        lag: int,
-        n_std: int,
-        mode: str,
+    experiment: pd.DataFrame,
+    likelihood: pd.DataFrame,
+    likelihood_tolerance: float,
+    exclude: str,
+    lag: int,
+    n_std: int,
+    mode: str,
 ):
     """Iterates over all body parts of experiment, and outputs a dataframe where all x, y positions are
     replaced by a boolean mask, where True indicates an outlier
@@ -455,14 +455,14 @@ def full_outlier_mask(
 
 
 def interpolate_outliers(
-        experiment: pd.DataFrame,
-        likelihood: pd.DataFrame,
-        likelihood_tolerance: float,
-        exclude: str = "",
-        lag: int = 5,
-        n_std: int = 3,
-        mode: str = "or",
-        limit: int = 10,
+    experiment: pd.DataFrame,
+    likelihood: pd.DataFrame,
+    likelihood_tolerance: float,
+    exclude: str = "",
+    lag: int = 5,
+    n_std: int = 3,
+    mode: str = "or",
+    limit: int = 10,
 ):
     """Marks all outliers in experiment and replaces them using a univariate linear interpolation approach.
     Note that this approach only works for equally spaced data (constant camera acquisition rates).
@@ -496,13 +496,13 @@ def interpolate_outliers(
 
 
 def recognize_arena(
-        videos: list,
-        vid_index: int,
-        path: str = ".",
-        recoglimit: int = 100,
-        arena_type: str = "circular",
-        detection_mode: str = "cnn",
-        cnn_model: tf.keras.models.Model = None,
+    videos: list,
+    vid_index: int,
+    path: str = ".",
+    recoglimit: int = 100,
+    arena_type: str = "circular",
+    detection_mode: str = "cnn",
+    cnn_model: tf.keras.models.Model = None,
 ) -> Tuple[np.array, int, int]:
     """Returns numpy.array with information about the arena recognised from the first frames
-    of the video. WARNING: estimates won't be reliable if the camera moves along the video.
+    of the video. WARNING: estimates won't be reliable if the camera moves during the video.
@@ -567,9 +567,9 @@ def recognize_arena(
 
 
 def circular_arena_recognition(
-        frame: np.array,
-        detection_mode: str = "cnn",
-        cnn_model: tf.keras.models.Model = None,
+    frame: np.array,
+    detection_mode: str = "cnn",
+    cnn_model: tf.keras.models.Model = None,
 ) -> np.array:
     """Returns x,y position of the center, the lengths of the major and minor axes,
     and the angle of the recognised arena
@@ -596,7 +596,9 @@ def circular_arena_recognition(
             thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
 
         # Obtain contours from the image, and retain the largest one
-        cnts, _ = cv2.findContours(thresh.astype(np.int64), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_TC89_KCOS)
+        cnts, _ = cv2.findContours(
+            thresh.astype(np.int64), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_TC89_KCOS
+        )
         main_cnt = np.argmax([len(c) for c in cnts])
 
         # Detect the main ellipse containing the arena
@@ -635,13 +637,13 @@ def circular_arena_recognition(
 
 
 def rolling_speed(
-        dframe: pd.DatetimeIndex,
-        window: int = 3,
-        rounds: int = 3,
-        deriv: int = 1,
-        center: str = None,
-        shift: int = 2,
-        typ: str = "coords",
+    dframe: pd.DataFrame,
+    window: int = 3,
+    rounds: int = 3,
+    deriv: int = 1,
+    center: str = None,
+    shift: int = 2,
+    typ: str = "coords",
 ) -> pd.DataFrame:
     """Returns the average speed over n frames in pixels per frame
 
@@ -674,14 +676,14 @@ def rolling_speed(
         features = 2 if der == 0 and typ == "coords" else 1
 
         distances = (
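+            # Pair each frame's coordinates with those from shift frames earlier; dividing by shift gives a per-frame rate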
-                np.concatenate(
-                    [
-                        np.array(dframe).reshape([-1, features], order="C"),
-                        np.array(dframe.shift(shift)).reshape([-1, features], order="C"),
-                    ],
-                    axis=1,
-                )
-                / shift
+            np.concatenate(
+                [
+                    np.array(dframe).reshape([-1, features], order="C"),
+                    np.array(dframe.shift(shift)).reshape([-1, features], order="C"),
+                ],
+                axis=1,
+            )
+            / shift
         )
 
         distances = np.array(compute_dist(distances))
@@ -729,12 +731,12 @@ def gmm_compute(x: np.array, n_components: int, cv_type: str) -> list:
 
 
 def gmm_model_selection(
-        x: pd.DataFrame,
-        n_components_range: range,
-        part_size: int,
-        n_runs: int = 100,
-        n_cores: int = False,
-        cv_types: Tuple = ("spherical", "tied", "diag", "full"),
+    x: pd.DataFrame,
+    n_components_range: range,
+    part_size: int,
+    n_runs: int = 100,
+    n_cores: int = False,
+    cv_types: Tuple = ("spherical", "tied", "diag", "full"),
 ) -> Tuple[List[list], List[np.ndarray], Union[int, Any]]:
     """Runs GMM clustering model selection on the specified X dataframe, outputs the bic distribution per model,
     a vector with the median BICs and an object with the overall best model
@@ -791,10 +793,10 @@ def gmm_model_selection(
 
 
 def cluster_transition_matrix(
-        cluster_sequence: np.array,
-        nclusts: int,
-        autocorrelation: bool = True,
-        return_graph: bool = False,
+    cluster_sequence: np.array,
+    nclusts: int,
+    autocorrelation: bool = True,
+    return_graph: bool = False,
 ) -> Tuple[Union[nx.Graph, Any], np.ndarray]:
     """Computes the transition matrix between clusters and the autocorrelation in the sequence.
 
@@ -843,6 +845,7 @@ def cluster_transition_matrix(
 
     return trans_normed
 
+
 # TODO:
 #    - Add sequence plot to single_behaviour_analysis (show how the condition varies across a specified time window)
 #    - Add digging to rule_based_tagging
diff --git a/deepof/visuals.py b/deepof/visuals.py
index ad445ad82a6b8883bef26a0906ffed0f0e548834..f061a677ab380956973170c4faa4a1714e08d5dc 100644
--- a/deepof/visuals.py
+++ b/deepof/visuals.py
@@ -21,12 +21,12 @@ import seaborn as sns
 
 
 def plot_heatmap(
-        dframe: pd.DataFrame,
-        bodyparts: List,
-        xlim: tuple,
-        ylim: tuple,
-        save: str = False,
-        dpi: int = 200,
+    dframe: pd.DataFrame,
+    bodyparts: List,
+    xlim: tuple,
+    ylim: tuple,
+    save: str = False,
+    dpi: int = 200,
 ) -> plt.figure:
     """Returns a heatmap of the movement of a specific bodypart in the arena.
     If more than one bodypart is passed, it returns one subplot for each
@@ -73,13 +73,13 @@ def plot_heatmap(
 
 
 def model_comparison_plot(
-        bic: list,
-        m_bic: list,
-        n_components_range: range,
-        cov_plot: str,
-        save: str = False,
-        cv_types: tuple = ("spherical", "tied", "diag", "full"),
-        dpi: int = 200,
+    bic: list,
+    m_bic: list,
+    n_components_range: range,
+    cov_plot: str,
+    save: str = False,
+    cv_types: tuple = ("spherical", "tied", "diag", "full"),
+    dpi: int = 200,
 ) -> plt.figure:
     """
 
@@ -119,7 +119,7 @@ def model_comparison_plot(
         bars.append(
             spl.bar(
                 xpos,
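+                # Median BIC scores for the i-th covariance type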
-                m_bic[i * len(n_components_range): (i + 1) * len(n_components_range)],
+                m_bic[i * len(n_components_range) : (i + 1) * len(n_components_range)],
                 color=color,
                 width=0.2,
             )
@@ -128,9 +128,9 @@ def model_comparison_plot(
     spl.set_xticks(n_components_range)
     plt.title("BIC score per model")
     xpos = (
-            np.mod(m_bic.argmin(), len(n_components_range))
-            + 0.5
-            + 0.2 * np.floor(m_bic.argmin() / len(n_components_range))
+        np.mod(m_bic.argmin(), len(n_components_range))
+        + 0.5
+        + 0.2 * np.floor(m_bic.argmin() / len(n_components_range))
     )
     # noinspection PyArgumentList
     spl.text(xpos, m_bic.min() * 0.97 + 0.1 * m_bic.max(), "*", fontsize=14)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 65a234d8eef46ca98c49453963b02bb8cefdbb75..c2373d8f880846732cdd0fc7b694ac95ab277921 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -407,42 +407,51 @@ def test_interpolate_outliers(mode):
 
 
 @settings(deadline=None)
-@given(indexes=st.data())
-def test_recognize_arena_and_subfunctions(indexes):
+@given(
+    indexes=st.data(), detection_type=st.one_of(st.just("rule-based"), st.just("cnn"))
+)
+def test_recognize_arena_and_subfunctions(indexes, detection_type):
 
     path = os.path.join(".", "tests", "test_examples", "test_single_topview", "Videos")
     videos = [i for i in os.listdir(path) if i.endswith("mp4")]
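+    # Locate the pretrained ellipse-detection CNN bundled under deepof/trained_models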
+    cnn_path = os.path.join("deepof", "trained_models")
+    cnn_model = os.path.join(
+        cnn_path, [i for i in os.listdir(cnn_path) if i.startswith("elliptic")][0]
+    )
 
     vid_index = indexes.draw(st.integers(min_value=0, max_value=len(videos) - 1))
     recoglimit = indexes.draw(st.integers(min_value=1, max_value=10))
 
-    assert deepof.utils.recognize_arena(videos, vid_index, path, recoglimit, "")[0] == 0
     assert (
-        len(
-            deepof.utils.recognize_arena(
-                videos, vid_index, path, recoglimit, "circular"
-            )
-        )
-        == 3
+        deepof.utils.recognize_arena(
+            videos,
+            vid_index,
+            path,
+            recoglimit,
+            "",
+            detection_mode=detection_type,
+            cnn_model=cnn_model,
+        )[0]
+        == 0
     )
-    assert (
-        len(
-            deepof.utils.recognize_arena(
-                videos, vid_index, path, recoglimit, "circular"
-            )[0]
-        )
-        == 3
+
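+    # Run arena recognition once and reuse the result across the assertions below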
+    arena = deepof.utils.recognize_arena(
+        videos,
+        vid_index,
+        path,
+        recoglimit,
+        "circular",
+        detection_mode=detection_type,
+        cnn_model=cnn_model,
     )
+    assert len(arena) == 3
+    assert len(arena[0]) == 3
-    assert isinstance(
-        deepof.utils.recognize_arena(videos, vid_index, path, recoglimit, "circular")[
-            1
-        ],
-        int,
-    )
+    assert isinstance(arena[1], int)
-    assert isinstance(
-        deepof.utils.recognize_arena(videos, vid_index, path, recoglimit, "circular")[
-            2
-        ],
-        int,
-    )
+    assert isinstance(arena[2], int)