diff --git a/deepof/data.py b/deepof/data.py
index 58685a5323da6c30aded0981bf4f1f06f627c3de..0679e680af6ebb30155b4459096c61f4f1f3205d 100644
--- a/deepof/data.py
+++ b/deepof/data.py
@@ -14,29 +14,30 @@ Contains methods for generating training and test sets ready for model training.
 
 """
 
+import os
+import warnings
 from collections import defaultdict
-from joblib import delayed, Parallel, parallel_backend
-from typing import Dict, List, Tuple, Union
 from multiprocessing import cpu_count
+from typing import Dict, List, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from joblib import delayed, Parallel, parallel_backend
 from pkg_resources import resource_filename
 from sklearn import random_projection
 from sklearn.decomposition import KernelPCA
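+# enable_iterative_imputer must be imported before sklearn's experimental IterativeImputer can be used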
-from sklearn.experimental import enable_iterative_imputer
+from sklearn.experimental import enable_iterative_imputer  # noqa: F401
 from sklearn.impute import IterativeImputer
 from sklearn.manifold import TSNE
 from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
 from tqdm import tqdm
+
 import deepof.models
 import deepof.pose_utils
+import deepof.train_utils
 import deepof.utils
 import deepof.visuals
-import deepof.train_utils
-import matplotlib.pyplot as plt
-import numpy as np
-import os
-import pandas as pd
-import tensorflow as tf
-import warnings
 
 # Remove excessive logging from tensorflow
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
@@ -58,23 +59,23 @@ class project:
     """
 
     def __init__(
-        self,
-        animal_ids: List = tuple([""]),
-        arena: str = "circular",
-        arena_detection: str = "cnn",
-        arena_dims: tuple = (1,),
-        enable_iterative_imputation: bool = None,
-        exclude_bodyparts: List = tuple([""]),
-        exp_conditions: dict = None,
-        interpolate_outliers: bool = True,
-        interpolation_limit: int = 5,
-        interpolation_std: int = 5,
-        likelihood_tol: float = 0.25,
-        model: str = "mouse_topview",
-        path: str = deepof.utils.os.path.join("."),
-        smooth_alpha: float = 0.99,
-        table_format: str = "autodetect",
-        video_format: str = ".mp4",
+            self,
+            animal_ids: List = tuple([""]),
+            arena: str = "circular",
+            arena_detection: str = "cnn",
+            arena_dims: tuple = (1,),
+            enable_iterative_imputation: bool = None,
+            exclude_bodyparts: List = tuple([""]),
+            exp_conditions: dict = None,
+            interpolate_outliers: bool = True,
+            interpolation_limit: int = 5,
+            interpolation_std: int = 5,
+            likelihood_tol: float = 0.25,
+            model: str = "mouse_topview",
+            path: str = deepof.utils.os.path.join("."),
+            smooth_alpha: float = 0.99,
+            table_format: str = "autodetect",
+            video_format: str = ".mp4",
     ):
 
         # Set working paths
@@ -115,7 +116,6 @@ class project:
         self.arena_dims = arena_dims
         self.ellipse_detection = None
         if arena == "circular" and arena_detection == "cnn":
-
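+            # Load the pretrained model used for CNN-based arena (ellipse) detection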
             self.ellipse_detection = tf.keras.models.load_model(
                 [
                     os.path.join(self.trained_path, i)
@@ -286,8 +286,8 @@ class project:
                 ).T.index.remove_unused_levels()
 
                 tab = value.loc[
-                    :, [i for i in value.columns.levels[0] if i not in lablist]
-                ]
+                      :, [i for i in value.columns.levels[0] if i not in lablist]
+                      ]
 
                 tab.columns = tabcols
 
@@ -361,14 +361,14 @@ class project:
 
         for key in distance_dict.keys():
             distance_dict[key] = distance_dict[key].loc[
-                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
-            ]
+                                 :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
+                                 ]
 
         if self.ego:
             for key, val in distance_dict.items():
                 distance_dict[key] = val.loc[
-                    :, [dist for dist in val.columns if self.ego in dist]
-                ]
+                                     :, [dist for dist in val.columns if self.ego in dist]
+                                     ]
 
         return distance_dict
 
@@ -471,20 +471,20 @@ class coordinates:
     """
 
     def __init__(
-        self,
-        arena: str,
-        arena_detection: str,
-        arena_dims: np.array,
-        path: str,
-        quality: dict,
-        scales: np.array,
-        tables: dict,
-        videos: list,
-        angles: dict = None,
-        animal_ids: List = tuple([""]),
-        distances: dict = None,
-        exp_conditions: dict = None,
-        ellipse_detection: tf.keras.models.Model = None,
+            self,
+            arena: str,
+            arena_detection: str,
+            arena_dims: np.array,
+            path: str,
+            quality: dict,
+            scales: np.array,
+            tables: dict,
+            videos: list,
+            angles: dict = None,
+            animal_ids: List = tuple([""]),
+            distances: dict = None,
+            exp_conditions: dict = None,
+            ellipse_detection: tf.keras.models.Model = None,
     ):
         self._animal_ids = animal_ids
         self._arena = arena
@@ -509,15 +509,15 @@ class coordinates:
             return "deepof analysis of {} videos".format(len(self._videos))
 
     def get_coords(
-        self,
-        center: str = "arena",
-        polar: bool = False,
-        speed: int = 0,
-        length: str = None,
-        align: bool = False,
-        align_inplace: bool = False,
-        propagate_labels: bool = False,
-        propagate_annotations: Dict = False,
+            self,
+            center: str = "arena",
+            polar: bool = False,
+            speed: int = 0,
+            length: str = None,
+            align: bool = False,
+            align_inplace: bool = False,
+            propagate_labels: bool = False,
+            propagate_annotations: Dict = False,
     ) -> Table_dict:
         """
         Returns a table_dict object with the coordinates of each animal as values.
@@ -557,23 +557,23 @@ class coordinates:
 
                     try:
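+                        # Center on the arena: shift x/y (or rho/phi in the polar case) by half the arena size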
                         value.loc[:, (slice("coords"), ["x"])] = (
-                            value.loc[:, (slice("coords"), ["x"])]
-                            - self._scales[i][0] / 2
+                                value.loc[:, (slice("coords"), ["x"])]
+                                - self._scales[i][0] / 2
                         )
 
                         value.loc[:, (slice("coords"), ["y"])] = (
-                            value.loc[:, (slice("coords"), ["y"])]
-                            - self._scales[i][1] / 2
+                                value.loc[:, (slice("coords"), ["y"])]
+                                - self._scales[i][1] / 2
                         )
                     except KeyError:
                         value.loc[:, (slice("coords"), ["rho"])] = (
-                            value.loc[:, (slice("coords"), ["rho"])]
-                            - self._scales[i][0] / 2
+                                value.loc[:, (slice("coords"), ["rho"])]
+                                - self._scales[i][0] / 2
                         )
 
                         value.loc[:, (slice("coords"), ["phi"])] = (
-                            value.loc[:, (slice("coords"), ["phi"])]
-                            - self._scales[i][1] / 2
+                                value.loc[:, (slice("coords"), ["phi"])]
+                                - self._scales[i][1] / 2
                         )
 
         elif isinstance(center, str) and center != "arena":
@@ -582,24 +582,24 @@ class coordinates:
 
                 try:
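+                    # Center on the chosen body part by subtracting its position from every other body part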
                     value.loc[:, (slice("coords"), ["x"])] = value.loc[
-                        :, (slice("coords"), ["x"])
-                    ].subtract(value[center]["x"], axis=0)
+                                                             :, (slice("coords"), ["x"])
+                                                             ].subtract(value[center]["x"], axis=0)
 
                     value.loc[:, (slice("coords"), ["y"])] = value.loc[
-                        :, (slice("coords"), ["y"])
-                    ].subtract(value[center]["y"], axis=0)
+                                                             :, (slice("coords"), ["y"])
+                                                             ].subtract(value[center]["y"], axis=0)
                 except KeyError:
                     value.loc[:, (slice("coords"), ["rho"])] = value.loc[
-                        :, (slice("coords"), ["rho"])
-                    ].subtract(value[center]["rho"], axis=0)
+                                                               :, (slice("coords"), ["rho"])
+                                                               ].subtract(value[center]["rho"], axis=0)
 
                     value.loc[:, (slice("coords"), ["phi"])] = value.loc[
-                        :, (slice("coords"), ["phi"])
-                    ].subtract(value[center]["phi"], axis=0)
+                                                               :, (slice("coords"), ["phi"])
+                                                               ].subtract(value[center]["phi"], axis=0)
 
                 tabs[key] = value.loc[
-                    :, [tab for tab in value.columns if tab[0] != center]
-                ]
+                            :, [tab for tab in value.columns if tab[0] != center]
+                            ]
 
         if speed:
             for key, tab in tabs.items():
@@ -614,16 +614,16 @@ class coordinates:
 
         if align:
             assert (
-                align in list(tabs.values())[0].columns.levels[0]
+                    align in list(tabs.values())[0].columns.levels[0]
             ), "align must be set to the name of a bodypart"
 
             for key, tab in tabs.items():
                 # Bring forward the column to align
                 columns = [i for i in tab.columns if align not in i]
                 columns = [
-                    (align, ("phi" if polar else "x")),
-                    (align, ("rho" if polar else "y")),
-                ] + columns
+                              (align, ("phi" if polar else "x")),
+                              (align, ("rho" if polar else "y")),
+                          ] + columns
                 tab = tab[columns]
                 tabs[key] = tab
 
@@ -658,11 +658,11 @@ class coordinates:
         )
 
     def get_distances(
-        self,
-        speed: int = 0,
-        length: str = None,
-        propagate_labels: bool = False,
-        propagate_annotations: Dict = False,
+            self,
+            speed: int = 0,
+            length: str = None,
+            propagate_labels: bool = False,
+            propagate_annotations: Dict = False,
     ) -> Table_dict:
         """
-        Returns a table_dict object with the distances between body parts animal as values.
+        Returns a table_dict object with the distances between the body parts of each animal as values.
@@ -719,12 +719,12 @@ class coordinates:
         )  # pragma: no cover
 
     def get_angles(
-        self,
-        degrees: bool = False,
-        speed: int = 0,
-        length: str = None,
-        propagate_labels: bool = False,
-        propagate_annotations: Dict = False,
+            self,
+            degrees: bool = False,
+            speed: int = 0,
+            length: str = None,
+            propagate_labels: bool = False,
+            propagate_annotations: Dict = False,
     ) -> Table_dict:
         """
-        Returns a table_dict object with the angles between body parts animal as values.
+        Returns a table_dict object with the angles between the body parts of each animal as values.
@@ -810,11 +810,11 @@ class coordinates:
 
     # noinspection PyDefaultArgument
     def rule_based_annotation(
-        self,
-        params: Dict = {},
-        video_output: bool = False,
-        frame_limit: int = np.inf,
-        debug: bool = False,
+            self,
+            params: Dict = {},
+            video_output: bool = False,
+            frame_limit: int = np.inf,
+            debug: bool = False,
     ) -> Table_dict:
         """Annotates coordinates using a simple rule-based pipeline"""
 
@@ -882,29 +882,29 @@ class coordinates:
 
     @staticmethod
     def deep_unsupervised_embedding(
-        preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
-        batch_size: int = 256,
-        encoding_size: int = 4,
-        epochs: int = 35,
-        hparams: dict = None,
-        kl_warmup: int = 0,
-        log_history: bool = True,
-        log_hparams: bool = False,
-        loss: str = "ELBO",
-        mmd_warmup: int = 0,
-        montecarlo_kl: int = 10,
-        n_components: int = 25,
-        output_path: str = ".",
-        phenotype_class: float = 0,
-        predictor: float = 0,
-        pretrained: str = False,
-        save_checkpoints: bool = False,
-        save_weights: bool = True,
-        variational: bool = True,
-        reg_cat_clusters: bool = False,
-        reg_cluster_variance: bool = False,
-        entropy_samples: int = 10000,
-        entropy_knn: int = 100,
+            preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
+            batch_size: int = 256,
+            encoding_size: int = 4,
+            epochs: int = 35,
+            hparams: dict = None,
+            kl_warmup: int = 0,
+            log_history: bool = True,
+            log_hparams: bool = False,
+            loss: str = "ELBO",
+            mmd_warmup: int = 0,
+            montecarlo_kl: int = 10,
+            n_components: int = 25,
+            output_path: str = ".",
+            phenotype_class: float = 0,
+            predictor: float = 0,
+            pretrained: str = False,
+            save_checkpoints: bool = False,
+            save_weights: bool = True,
+            variational: bool = True,
+            reg_cat_clusters: bool = False,
+            reg_cluster_variance: bool = False,
+            entropy_samples: int = 10000,
+            entropy_knn: int = 100,
     ) -> Tuple:
         """
         Annotates coordinates using an unsupervised autoencoder.
@@ -982,15 +982,15 @@ class table_dict(dict):
     """
 
     def __init__(
-        self,
-        tabs: Dict,
-        typ: str,
-        arena: str = None,
-        arena_dims: np.array = None,
-        center: str = None,
-        polar: bool = None,
-        propagate_labels: bool = False,
-        propagate_annotations: Dict = False,
+            self,
+            tabs: Dict,
+            typ: str,
+            arena: str = None,
+            arena_dims: np.array = None,
+            center: str = None,
+            polar: bool = None,
+            propagate_labels: bool = False,
+            propagate_annotations: Dict = False,
     ):
         super().__init__(tabs)
         self._type = typ
@@ -1013,13 +1013,13 @@ class table_dict(dict):
 
     # noinspection PyTypeChecker
     def plot_heatmaps(
-        self,
-        bodyparts: list,
-        xlim: float = None,
-        ylim: float = None,
-        save: bool = False,
-        i: int = 0,
-        dpi: int = 100,
+            self,
+            bodyparts: list,
+            xlim: float = None,
+            ylim: float = None,
+            save: bool = False,
+            i: int = 0,
+            dpi: int = 100,
     ) -> plt.figure:
         """Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)"""
 
@@ -1045,9 +1045,9 @@ class table_dict(dict):
             return heatmaps
 
     def get_training_set(
-        self,
-        test_videos: int = 0,
-        encode_labels: bool = True,
+            self,
+            test_videos: int = 0,
+            encode_labels: bool = True,
     ) -> Tuple[np.ndarray, list, Union[np.ndarray, list], list]:
         """Generates training and test sets as numpy.array objects for model training"""
 
@@ -1106,17 +1106,17 @@ class table_dict(dict):
 
     # noinspection PyTypeChecker,PyGlobalUndefined
     def preprocess(
-        self,
-        window_size: int = 1,
-        window_step: int = 1,
-        scale: str = "standard",
-        test_videos: int = 0,
-        verbose: bool = False,
-        conv_filter: bool = None,
-        sigma: float = 1.0,
-        shift: float = 0.0,
-        shuffle: bool = False,
-        align: str = False,
+            self,
+            window_size: int = 1,
+            window_step: int = 1,
+            scale: str = "standard",
+            test_videos: int = 0,
+            verbose: bool = False,
+            conv_filter: bool = None,
+            sigma: float = 1.0,
+            shift: float = 0.0,
+            shuffle: bool = False,
+            align: str = False,
     ) -> np.ndarray:
         """
 
@@ -1243,7 +1243,7 @@ class table_dict(dict):
         return X_train, y_train, np.array(X_test), np.array(y_test)
 
     def random_projection(
-        self, n_components: int = None, sample: int = 1000
+            self, n_components: int = None, sample: int = 1000
     ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
         """Returns a training set generated from the 2D original data (time x features) and a random projection
         to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
@@ -1265,7 +1265,7 @@ class table_dict(dict):
         return X, rproj
 
     def pca(
-        self, n_components: int = None, sample: int = 1000, kernel: str = "linear"
+            self, n_components: int = None, sample: int = 1000, kernel: str = "linear"
     ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
         """Returns a training set generated from the 2D original data (time x features) and a PCA projection
         to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
@@ -1287,7 +1287,7 @@ class table_dict(dict):
         return X, pca
 
     def tsne(
-        self, n_components: int = None, sample: int = 1000, perplexity: int = 30
+            self, n_components: int = None, sample: int = 1000, perplexity: int = 30
     ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
         """Returns a training set generated from the 2D original data (time x features) and a PCA projection
         to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
@@ -1331,7 +1331,6 @@ def merge_tables(*args):
 
     return merged_tables
 
-
 # TODO:
 #   - Generate ragged training array using a metric (acceleration, maybe?)
 #   - Use something like Dynamic Time Warping to put all instances in the same length
diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index 2ab5fc46407348152c00801e9a8047585dea3c53..51a24f7e443fcc37826d7686663b3591e08ef2de 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -8,10 +8,11 @@ keras hypermodels for hyperparameter tuning of deep autoencoders
 
 """
 
+import tensorflow_probability as tfp
 from kerastuner import HyperModel
-import deepof.models
+
 import deepof.model_utils
-import tensorflow_probability as tfp
+import deepof.models
 
 tfd = tfp.distributions
 tfpl = tfp.layers
@@ -92,18 +93,18 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
     """Hyperparameter tuning pipeline for deepof.models.SEQ_2_SEQ_GMVAE"""
 
     def __init__(
-        self,
-        input_shape: tuple,
-        encoding: int,
-        kl_warmup_epochs: int = 0,
-        learn_rate: float = 1e-3,
-        loss: str = "ELBO+MMD",
-        mmd_warmup_epochs: int = 0,
-        number_of_components: int = 10,
-        overlap_loss: float = False,
-        phenotype_predictor: float = 0.0,
-        predictor: float = 0.0,
-        prior: str = "standard_normal",
+            self,
+            input_shape: tuple,
+            encoding: int,
+            kl_warmup_epochs: int = 0,
+            learn_rate: float = 1e-3,
+            loss: str = "ELBO+MMD",
+            mmd_warmup_epochs: int = 0,
+            number_of_components: int = 10,
+            overlap_loss: float = False,
+            phenotype_predictor: float = 0.0,
+            predictor: float = 0.0,
+            prior: str = "standard_normal",
     ):
         super().__init__()
         self.input_shape = input_shape
@@ -119,7 +120,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         self.prior = prior
 
         assert (
-            "ELBO" in self.loss or "MMD" in self.loss
+                "ELBO" in self.loss or "MMD" in self.loss
         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
 
     def get_hparams(self, hp):
@@ -191,7 +192,6 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
 
         return gmvaep
 
-
 # TODO:
 #    - We can add as many parameters as we want to the hypermodel!
 #    with this implementation, predictor, warmup, loss and even number of components can be tuned using BayOpt
diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index dfa1e3ec0ffe97a5ab833b395a3d12bd331e5d71..f5f54311441378729baaaeb5a54bc6abcb043ddb 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -10,15 +10,16 @@ Functions and general utilities for the deepof tensorflow models. See documentat
 
 from itertools import combinations
 from typing import Any, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+import tensorflow_probability as tfp
 from scipy.stats import entropy
 from sklearn.neighbors import NearestNeighbors
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
-import matplotlib.pyplot as plt
-import numpy as np
-import tensorflow as tf
-import tensorflow_probability as tfp
 
 tfd = tfp.distributions
 tfpl = tfp.layers
@@ -44,7 +45,7 @@ class exponential_learning_rate(tf.keras.callbacks.Callback):
 
 
 def find_learning_rate(
-    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
+        model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
 ):
     """Trains the provided model for an epoch with an exponentially increasing learning rate"""
 
@@ -123,9 +124,9 @@ def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
     y_kernel = compute_kernel(y, y)
     xy_kernel = compute_kernel(x, y)
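+    # MMD^2 estimate: mean(k(x, x')) + mean(k(y, y')) - 2 * mean(k(x, y))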
     mmd = (
-        tf.reduce_mean(x_kernel)
-        + tf.reduce_mean(y_kernel)
-        - 2 * tf.reduce_mean(xy_kernel)
+            tf.reduce_mean(x_kernel)
+            + tf.reduce_mean(y_kernel)
+            - 2 * tf.reduce_mean(xy_kernel)
     )
     return mmd
 
@@ -140,13 +141,13 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
     """
 
     def __init__(
-        self,
-        iterations: int,
-        max_rate: float,
-        start_rate: float = None,
-        last_iterations: int = None,
-        last_rate: float = None,
-        log_dir: str = ".",
+            self,
+            iterations: int,
+            max_rate: float,
+            start_rate: float = None,
+            last_iterations: int = None,
+            last_rate: float = None,
+            log_dir: str = ".",
     ):
         super().__init__()
         self.iterations = iterations
@@ -212,13 +213,13 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
     """
 
     def __init__(
-        self,
-        encoding_dim: int,
-        variational: bool = True,
-        validation_data: np.ndarray = None,
-        k: int = 100,
-        samples: int = 10000,
-        log_dir: str = ".",
+            self,
+            encoding_dim: int,
+            variational: bool = True,
+            validation_data: np.ndarray = None,
+            k: int = 100,
+            samples: int = 10000,
+            log_dir: str = ".",
     ):
         super().__init__()
         self.enc = encoding_dim
@@ -270,7 +271,6 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
             purity_vector = np.zeros(self.samples)
 
             for i, sample in enumerate(random_idxs):
-
                 neighborhood = knn.kneighbors(
                     encoding[sample][np.newaxis, :], self.k, return_distance=False
                 ).flatten()
@@ -531,7 +531,7 @@ class Cluster_overlap(Layer):
         dists = []
         for k in range(self.n_components):
             locs = (target[..., : self.lat_dims, k],)
-            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
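+            # locs and scales parameterize one diagonal Gaussian per component; softplus keeps scales positive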
 
             dists.append(
                 tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
diff --git a/deepof/models.py b/deepof/models.py
index ba305c7f23f3d90b4f9eb7cb5d0f762d8351039a..4c0648b4f38b8ae97783075db4e0b33af7c86c74 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -9,9 +9,11 @@ deep autoencoder models for unsupervised pose detection
 """
 
 from typing import Any, Dict, Tuple
+
+import tensorflow as tf
+import tensorflow_probability as tfp
 from tensorflow.keras import Input, Model, Sequential
 from tensorflow.keras.activations import softplus
-from tensorflow.keras.callbacks import LambdaCallback
 from tensorflow.keras.constraints import UnitNorm
 from tensorflow.keras.initializers import he_uniform, Orthogonal
 from tensorflow.keras.layers import BatchNormalization, Bidirectional
@@ -19,9 +21,8 @@ from tensorflow.keras.layers import Dense, Dropout, LSTM
 from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
 from tensorflow.keras.losses import Huber
 from tensorflow.keras.optimizers import Nadam
+
 import deepof.model_utils
-import tensorflow as tf
-import tensorflow_probability as tfp
 
 tfb = tfp.bijectors
 tfd = tfp.distributions
@@ -33,9 +34,9 @@ class SEQ_2_SEQ_AE:
     """  Simple sequence to sequence autoencoder implemented with tf.keras """
 
     def __init__(
-        self,
-        architecture_hparams: Dict = {},
-        huber_delta: float = 1.0,
+            self,
+            architecture_hparams: Dict = {},
+            huber_delta: float = 1.0,
     ):
         self.hparams = self.get_hparams(architecture_hparams)
         self.CONV_filters = self.hparams["units_conv"]
@@ -170,8 +171,8 @@ class SEQ_2_SEQ_AE:
         )
 
     def build(
-        self,
-        input_shape: tuple,
+            self,
+            input_shape: tuple,
     ) -> Tuple[Any, Any, Any]:
         """Builds the tf.keras model"""
 
@@ -241,22 +242,22 @@ class SEQ_2_SEQ_GMVAE:
     """  Gaussian Mixture Variational Autoencoder for pose motif elucidation.  """
 
     def __init__(
-        self,
-        architecture_hparams: dict = {},
-        batch_size: int = 256,
-        compile_model: bool = True,
-        encoding: int = 6,
-        kl_warmup_epochs: int = 20,
-        loss: str = "ELBO",
-        mmd_warmup_epochs: int = 20,
-        montecarlo_kl: int = 1,
-        neuron_control: bool = False,
-        number_of_components: int = 1,
-        overlap_loss: float = 0.0,
-        phenotype_prediction: float = 0.0,
-        predictor: float = 0.0,
-        reg_cat_clusters: bool = False,
-        reg_cluster_variance: bool = False,
+            self,
+            architecture_hparams: dict = {},
+            batch_size: int = 256,
+            compile_model: bool = True,
+            encoding: int = 6,
+            kl_warmup_epochs: int = 20,
+            loss: str = "ELBO",
+            mmd_warmup_epochs: int = 20,
+            montecarlo_kl: int = 1,
+            neuron_control: bool = False,
+            number_of_components: int = 1,
+            overlap_loss: float = 0.0,
+            phenotype_prediction: float = 0.0,
+            predictor: float = 0.0,
+            reg_cat_clusters: bool = False,
+            reg_cluster_variance: bool = False,
     ):
         self.hparams = self.get_hparams(architecture_hparams)
         self.batch_size = batch_size
@@ -289,7 +290,7 @@ class SEQ_2_SEQ_GMVAE:
         self.reg_cluster_variance = reg_cluster_variance
 
         assert (
-            "ELBO" in self.loss or "MMD" in self.loss
+                "ELBO" in self.loss or "MMD" in self.loss
         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
 
     @property
@@ -612,7 +613,7 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
-                            scale=softplus(gauss[1][..., self.ENCODING :, k]) + 1e-5,
+                            scale=softplus(gauss[1][..., self.ENCODING:, k]) + 1e-5,
                         ),
                         reinterpreted_batch_ndims=1,
                     )
@@ -624,7 +625,6 @@ class SEQ_2_SEQ_GMVAE:
 
         # Define and control custom loss functions
         if "ELBO" in self.loss:
-
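+            # Express the KL warm-up period in iterations (batches per epoch * warm-up epochs)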
             kl_warm_up_iters = tf.cast(
                 self.kl_warmup * (input_shape[0] // self.batch_size + 1),
                 tf.int64,
@@ -640,7 +640,6 @@ class SEQ_2_SEQ_GMVAE:
             )(z)
 
         if "MMD" in self.loss:
-
             mmd_warm_up_iters = tf.cast(
                 self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
                 tf.int64,
@@ -762,7 +761,6 @@ class SEQ_2_SEQ_GMVAE:
     def prior(self, value):
         self._prior = value
 
-
 # TODO:
 #       - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
 #       - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py
index 5deb463bafa961642acd25dca9a08ad26a0bd515..7da15d5e69ed4c74d4c7d3bee0abdfa727519e1c 100644
--- a/deepof/pose_utils.py
+++ b/deepof/pose_utils.py
@@ -8,19 +8,21 @@ Functions and general utilities for rule-based pose estimation. See documentatio
 
 """
 
+import os
+import warnings
 from itertools import combinations
-from scipy import stats
 from typing import Any, List, NewType
+
 import cv2
-import deepof.utils
 import matplotlib.pyplot as plt
 import numpy as np
-import os
 import pandas as pd
 import regex as re
 import seaborn as sns
 import tensorflow as tf
-import warnings
+from scipy import stats
+
+import deepof.utils
 
 # Ignore warning with no downstream effect
 warnings.filterwarnings("ignore", message="All-NaN slice encountered")
@@ -30,12 +32,12 @@ Coordinates = NewType("Coordinates", Any)
 
 
 def close_single_contact(
-    pos_dframe: pd.DataFrame,
-    left: str,
-    right: str,
-    tol: float,
-    arena_abs: int,
-    arena_rel: int,
+        pos_dframe: pd.DataFrame,
+        left: str,
+        right: str,
+        tol: float,
+        arena_abs: int,
+        arena_rel: int,
 ) -> np.array:
     """Returns a boolean array that's True if the specified body parts are closer than tol.
 
@@ -56,8 +58,8 @@ def close_single_contact(
 
     if isinstance(right, str):
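+        # Scale the pixel distance by arena_abs / arena_rel to absolute units before applying tol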
         close_contact = (
-            np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
-        ) / arena_rel < tol
+                                np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
+                        ) / arena_rel < tol
 
     elif isinstance(right, list):
         close_contact = np.any(
@@ -74,15 +76,15 @@ def close_single_contact(
 
 
 def close_double_contact(
-    pos_dframe: pd.DataFrame,
-    left1: str,
-    left2: str,
-    right1: str,
-    right2: str,
-    tol: float,
-    arena_abs: int,
-    arena_rel: int,
-    rev: bool = False,
+        pos_dframe: pd.DataFrame,
+        left1: str,
+        left2: str,
+        right1: str,
+        right2: str,
+        tol: float,
+        arena_abs: int,
+        arena_rel: int,
+        rev: bool = False,
 ) -> np.array:
     """Returns a boolean array that's True if the specified body parts are closer than tol.
 
@@ -104,25 +106,25 @@ def close_double_contact(
 
     if rev:
         double_contact = (
-            (np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
-            / arena_rel
-            < tol
-        ) & (
-            (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
-            / arena_rel
-            < tol
-        )
+                                 (np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
+                                 / arena_rel
+                                 < tol
+                         ) & (
+                                 (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
+                                 / arena_rel
+                                 < tol
+                         )
 
     else:
         double_contact = (
-            (np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
-            / arena_rel
-            < tol
-        ) & (
-            (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
-            / arena_rel
-            < tol
-        )
+                                 (np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
+                                 / arena_rel
+                                 < tol
+                         ) & (
+                                 (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
+                                 / arena_rel
+                                 < tol
+                         )
 
     return double_contact
 
@@ -150,12 +152,12 @@ def outside_ellipse(x, y, e_center, e_axes, e_angle, threshold=0.0):
 
 
 def climb_wall(
-    arena_type: str,
-    arena: np.array,
-    pos_dict: pd.DataFrame,
-    tol: float,
-    nose: str,
-    centered_data: bool = False,
+        arena_type: str,
+        arena: np.array,
+        pos_dict: pd.DataFrame,
+        tol: float,
+        nose: str,
+        centered_data: bool = False,
 ) -> np.array:
     """Returns True if the specified mouse is climbing the wall
 
@@ -195,16 +197,16 @@ def climb_wall(
 
 
 def sniff_object(
-    speed_dframe: pd.DataFrame,
-    arena_type: str,
-    arena: np.array,
-    pos_dict: pd.DataFrame,
-    tol: float,
-    tol_speed: float,
-    nose: str,
-    centered_data: bool = False,
-    object: str = "arena",
-    animal_id: str = "",
+        speed_dframe: pd.DataFrame,
+        arena_type: str,
+        arena: np.array,
+        pos_dict: pd.DataFrame,
+        tol: float,
+        tol_speed: float,
+        nose: str,
+        centered_data: bool = False,
+        object: str = "arena",
+        animal_id: str = "",
 ):
     """Returns True if the specified mouse is sniffing an object
 
@@ -267,11 +269,11 @@ def sniff_object(
 
 
 def huddle(
-    pos_dframe: pd.DataFrame,
-    speed_dframe: pd.DataFrame,
-    tol_forward: float,
-    tol_speed: float,
-    animal_id: str = "",
+        pos_dframe: pd.DataFrame,
+        speed_dframe: pd.DataFrame,
+        tol_forward: float,
+        tol_speed: float,
+        animal_id: str = "",
 ) -> np.array:
     """Returns true when the mouse is huddling using simple rules.
 
@@ -292,18 +294,18 @@ def huddle(
         animal_id += "_"
 
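+    # The body is folded forward if either hip pair (front to back) is closer than tol_forward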
     forward = (
-        np.linalg.norm(
-            pos_dframe[animal_id + "Left_bhip"] - pos_dframe[animal_id + "Left_fhip"],
-            axis=1,
-        )
-        < tol_forward
-    ) | (
-        np.linalg.norm(
-            pos_dframe[animal_id + "Right_bhip"] - pos_dframe[animal_id + "Right_fhip"],
-            axis=1,
-        )
-        < tol_forward
-    )
+                      np.linalg.norm(
+                          pos_dframe[animal_id + "Left_bhip"] - pos_dframe[animal_id + "Left_fhip"],
+                          axis=1,
+                      )
+                      < tol_forward
+              ) | (
+                      np.linalg.norm(
+                          pos_dframe[animal_id + "Right_bhip"] - pos_dframe[animal_id + "Right_fhip"],
+                          axis=1,
+                      )
+                      < tol_forward
+              )
 
     speed = speed_dframe[animal_id + "Center"] < tol_speed
     hudd = forward & speed
@@ -312,11 +314,11 @@ def huddle(
 
 
 def dig(
-    speed_dframe: pd.DataFrame,
-    likelihood_dframe: pd.DataFrame,
-    tol_speed: float,
-    tol_likelihood: float,
-    animal_id: str = "",
+        speed_dframe: pd.DataFrame,
+        likelihood_dframe: pd.DataFrame,
+        tol_speed: float,
+        tol_likelihood: float,
+        animal_id: str = "",
 ):
     """Returns true when the mouse is digging using simple rules.
 
@@ -343,11 +345,11 @@ def dig(
 
 
 def look_around(
-    speed_dframe: pd.DataFrame,
-    likelihood_dframe: pd.DataFrame,
-    tol_speed: float,
-    tol_likelihood: float,
-    animal_id: str = "",
+        speed_dframe: pd.DataFrame,
+        likelihood_dframe: pd.DataFrame,
+        tol_speed: float,
+        tol_likelihood: float,
+        animal_id: str = "",
 ):
     """Returns true when the mouse is digging using simple rules.
 
@@ -376,12 +378,12 @@ def look_around(
 
 
 def following_path(
-    distance_dframe: pd.DataFrame,
-    position_dframe: pd.DataFrame,
-    follower: str,
-    followed: str,
-    frames: int = 20,
-    tol: float = 0,
+        distance_dframe: pd.DataFrame,
+        position_dframe: pd.DataFrame,
+        follower: str,
+        followed: str,
+        frames: int = 20,
+        tol: float = 0,
 ) -> np.array:
     """For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
     followed has walked over the last specified number of frames
@@ -412,15 +414,15 @@ def following_path(
 
     # Check that the animals are oriented follower's nose -> followed's tail
     right_orient1 = (
-        distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
-        < distance_dframe[
-            tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
-        ]
+            distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
+            < distance_dframe[
+                tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
+            ]
     )
 
     right_orient2 = (
-        distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
-        < distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
+            distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
+            < distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
     )
 
     follow = np.all(
@@ -432,13 +434,13 @@ def following_path(
 
 
 def single_behaviour_analysis(
-    behaviour_name: str,
-    treatment_dict: dict,
-    behavioural_dict: dict,
-    plot: int = 0,
-    stat_tests: bool = True,
-    save: str = None,
-    ylim: float = None,
+        behaviour_name: str,
+        treatment_dict: dict,
+        behavioural_dict: dict,
+        plot: int = 0,
+        stat_tests: bool = True,
+        save: str = None,
+        ylim: float = None,
 ) -> list:
     """Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
     with the actual tags, outputs a box plot and a series of significance tests amongst the groups
@@ -493,13 +495,13 @@ def single_behaviour_analysis(
         for i in combinations(treatment_dict.keys(), 2):
             # Solves issue with automatically generated examples
             if np.any(
-                np.array(
-                    [
-                        beh_dict[i[0]] == beh_dict[i[1]],
-                        np.var(beh_dict[i[0]]) == 0,
-                        np.var(beh_dict[i[1]]) == 0,
-                    ]
-                )
+                    np.array(
+                        [
+                            beh_dict[i[0]] == beh_dict[i[1]],
+                            np.var(beh_dict[i[0]]) == 0,
+                            np.var(beh_dict[i[1]]) == 0,
+                        ]
+                    )
             ):
                 stat_dict[i] = "Identical sources. Couldn't run"
             else:
@@ -512,7 +514,7 @@ def single_behaviour_analysis(
 
 
 def max_behaviour(
-    behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
+        behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
 ) -> np.array:
     """Returns the most frequent behaviour in a window of window_size frames
 
@@ -596,19 +598,19 @@ def frame_corners(w, h, corners: dict = {}):
 
 # noinspection PyDefaultArgument,PyProtectedMember
 def rule_based_tagging(
-    tracks: List,
-    videos: List,
-    coordinates: Coordinates,
-    coords: Any,
-    dists: Any,
-    speeds: Any,
-    vid_index: int,
-    arena_type: str,
-    arena_detection_mode: str,
-    ellipse_detection_model: tf.keras.models.Model = None,
-    recog_limit: int = 100,
-    path: str = os.path.join("."),
-    params: dict = {},
+        tracks: List,
+        videos: List,
+        coordinates: Coordinates,
+        coords: Any,
+        dists: Any,
+        speeds: Any,
+        vid_index: int,
+        arena_type: str,
+        arena_detection_mode: str,
+        ellipse_detection_model: tf.keras.models.Model = None,
+        recog_limit: int = 100,
+        path: str = os.path.join("."),
+        params: dict = {},
 ) -> pd.DataFrame:
     """Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
     video displaying the information in real time
@@ -828,18 +830,18 @@ def rule_based_tagging(
 
 
 def tag_rulebased_frames(
-    frame,
-    font,
-    frame_speeds,
-    animal_ids,
-    corners,
-    tag_dict,
-    fnum,
-    undercond,
-    hparams,
-    arena,
-    debug,
-    coords,
+        frame,
+        font,
+        frame_speeds,
+        animal_ids,
+        corners,
+        tag_dict,
+        fnum,
+        undercond,
+        hparams,
+        arena,
+        debug,
+        coords,
 ):
     """Helper function for rule_based_video. Annotates a given frame with on-screen information
     about the recognised patterns"""
@@ -995,16 +997,16 @@ def tag_rulebased_frames(
 
 # noinspection PyProtectedMember,PyDefaultArgument
 def rule_based_video(
-    coordinates: Coordinates,
-    tracks: List,
-    videos: List,
-    vid_index: int,
-    tag_dict: pd.DataFrame,
-    frame_limit: int = np.inf,
-    recog_limit: int = 100,
-    path: str = os.path.join("."),
-    params: dict = {},
-    debug: bool = False,
+        coordinates: Coordinates,
+        tracks: List,
+        videos: List,
+        vid_index: int,
+        tag_dict: pd.DataFrame,
+        frame_limit: int = np.inf,
+        recog_limit: int = 100,
+        path: str = os.path.join("."),
+        params: dict = {},
+        debug: bool = False,
-) -> True:
+) -> bool:
     """Renders a version of the input video with all rule-based taggings in place.
 
@@ -1078,8 +1080,8 @@ def rule_based_video(
         # Capture speeds
         try:
             if (
-                list(frame_speeds.values())[0] == -np.inf
-                or fnum % params["speed_pause"] == 0
+                    list(frame_speeds.values())[0] == -np.inf
+                    or fnum % params["speed_pause"] == 0
             ):
                 for _id in animal_ids:
                     frame_speeds[_id] = tag_dict[_id + undercond + "speed"][fnum]
@@ -1123,6 +1125,5 @@ def rule_based_video(
 
     return True
 
-
 # TODO:
 #    - Is border sniffing anything you might consider interesting?
diff --git a/deepof/train_model.py b/deepof/train_model.py
index 95e4804ce1611023f20ce2f4ebe1562a25dd01f7..fd3de2e913f3f719519a5ca8f5b956f2ce120ea5 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -10,11 +10,8 @@ usage: python -m examples.model_training -h
 """
 
 from deepof.data import *
-from deepof.models import *
-from deepof.utils import *
 from deepof.train_utils import *
-from tensorboard.plugins.hparams import api as hp
-from sklearn.metrics import roc_auc_score
+from deepof.utils import *
 
 parser = argparse.ArgumentParser(
     description="Autoencoder training for DeepOF animal pose recognition"
@@ -80,7 +77,7 @@ parser.add_argument(
     "--hyperparameter-tuning",
     "-tune",
     help="Indicates whether hyperparameters should be tuned either using 'bayopt' of 'hyperband'. "
-    "See documentation for details",
+         "See documentation for details",
     type=str,
     default=False,
 )
@@ -88,7 +85,7 @@ parser.add_argument(
     "--hyperparameters",
     "-hp",
     help="Path pointing to a pickled dictionary of network hyperparameters. "
-    "Thought to be used with the output of hyperparameter tuning",
+         "Thought to be used with the output of hyperparameter tuning",
     type=str,
     default=None,
 )
@@ -96,8 +93,8 @@ parser.add_argument(
     "--input-type",
     "-d",
     help="Select an input type for the autoencoder hypermodels. "
-    "It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
-    "Defaults to coords.",
+         "It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
+         "Defaults to coords.",
     type=str,
     default="dists",
 )
@@ -126,8 +123,8 @@ parser.add_argument(
     "--latent-reg",
     "-lreg",
     help="Sets the strategy to regularize the latent mixture of Gaussians. "
-    "It has to be one of none, categorical (an elastic net penalty is applied to the categorical distribution),"
-    "variance (l2 penalty to the variance of the clusters) or categorical+variance. Defaults to none.",
+         "It has to be one of none, categorical (an elastic net penalty is applied to the categorical distribution),"
+         "variance (l2 penalty to the variance of the clusters) or categorical+variance. Defaults to none.",
     default="none",
     type=str,
 )
@@ -135,7 +132,7 @@ parser.add_argument(
     "--loss",
     "-l",
     help="Sets the loss function for the variational model. "
-    "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
+         "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
     default="ELBO+MMD",
     type=str,
 )
@@ -185,7 +182,7 @@ parser.add_argument(
     "--predictor",
     "-pred",
     help="Activates the prediction branch of the variational Seq 2 Seq model with the specified weight. "
-    "Defaults to 0.0 (inactive)",
+         "Defaults to 0.0 (inactive)",
     default=0.0,
     type=float,
 )
@@ -193,7 +190,7 @@ parser.add_argument(
     "--smooth-alpha",
     "-sa",
     help="Sets the exponential smoothing factor to apply to the input data. "
-    "Float between 0 and 1 (lower is more smooting)",
+         "Float between 0 and 1 (lower is more smooting)",
     type=float,
     default=0.99,
 )
@@ -439,11 +436,11 @@ else:
 
     # Saves the best hyperparameters
     with open(
-        os.path.join(
-            output_path,
-            "{}-based_{}_{}_params.pickle".format(input_type, hyp, tune.capitalize()),
-        ),
-        "wb",
+            os.path.join(
+                output_path,
+                "{}-based_{}_{}_params.pickle".format(input_type, hyp, tune.capitalize()),
+            ),
+            "wb",
     ) as handle:
         pickle.dump(
             best_hyperparameters.values, handle, protocol=pickle.HIGHEST_PROTOCOL
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 84580f616b7687f7673d72e90b69c439d75a3644..fe51d0fe7a5033e9c0d8cfc974b8ce6df23dbc4d 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -8,19 +8,20 @@ Simple utility functions used in deepof example scripts. These are not part of t
 
 """
 
+import json
+import os
 from datetime import date, datetime
+from typing import Tuple, Union, Any, List
+
+import numpy as np
+import tensorflow as tf
 from kerastuner import BayesianOptimization, Hyperband
-from kerastuner import HyperParameters
 from kerastuner_tensorboard_logger import TensorBoardLogger
 from sklearn.metrics import roc_auc_score
 from tensorboard.plugins.hparams import api as hp
-from typing import Tuple, Union, Any, List
+
 import deepof.hypermodels
 import deepof.model_utils
-import json
-import numpy as np
-import os
-import tensorflow as tf
 
 # Ignore warning with no downstream effect
 tf.get_logger().setLevel("ERROR")
@@ -51,11 +52,11 @@ def load_treatments(train_path):
     to be loaded as metadata in the coordinates class"""
     try:
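+        # Open the first .json file found in train_path; the IndexError handler covers the no-file case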
         with open(
-            os.path.join(
-                train_path,
-                [i for i in os.listdir(train_path) if i.endswith(".json")][0],
-            ),
-            "r",
+                os.path.join(
+                    train_path,
+                    [i for i in os.listdir(train_path) if i.endswith(".json")][0],
+                ),
+                "r",
         ) as handle:
             treatment_dict = json.load(handle)
     except IndexError:
@@ -65,20 +66,20 @@ def load_treatments(train_path):
 
 
 def get_callbacks(
-    X_train: np.array,
-    batch_size: int,
-    variational: bool,
-    phenotype_class: float,
-    predictor: float,
-    loss: str,
-    X_val: np.array = None,
-    cp: bool = False,
-    reg_cat_clusters: bool = False,
-    reg_cluster_variance: bool = False,
-    entropy_samples: int = 15000,
-    entropy_knn: int = 100,
-    logparam: dict = None,
-    outpath: str = ".",
+        X_train: np.array,
+        batch_size: int,
+        variational: bool,
+        phenotype_class: float,
+        predictor: float,
+        loss: str,
+        X_val: np.array = None,
+        cp: bool = False,
+        reg_cat_clusters: bool = False,
+        reg_cluster_variance: bool = False,
+        entropy_samples: int = 15000,
+        entropy_knn: int = 100,
+        logparam: dict = None,
+        outpath: str = ".",
 ) -> List[Union[Any]]:
     """Generates callbacks for model training, including:
     - run_ID: run name, with coarse parameter details;
@@ -196,14 +197,14 @@ def log_hyperparameters(phenotype_class: float, rec: str):
 
 # noinspection PyUnboundLocalVariable
 def tensorboard_metric_logging(
-    run_dir: str,
-    hpms: Any,
-    ae: Any,
-    X_val: np.ndarray,
-    y_val: np.ndarray,
-    phenotype_class: float,
-    predictor: float,
-    rec: str,
+        run_dir: str,
+        hpms: Any,
+        ae: Any,
+        X_val: np.ndarray,
+        y_val: np.ndarray,
+        phenotype_class: float,
+        predictor: float,
+        rec: str,
 ):
     """Autoencoder metric logging in tensorboard"""
 
@@ -245,29 +246,29 @@ def tensorboard_metric_logging(
 
 
 def autoencoder_fitting(
-    preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
-    batch_size: int,
-    encoding_size: int,
-    epochs: int,
-    hparams: dict,
-    kl_warmup: int,
-    log_history: bool,
-    log_hparams: bool,
-    loss: str,
-    mmd_warmup: int,
-    montecarlo_kl: int,
-    n_components: int,
-    output_path: str,
-    phenotype_class: float,
-    predictor: float,
-    pretrained: str,
-    save_checkpoints: bool,
-    save_weights: bool,
-    variational: bool,
-    reg_cat_clusters: bool,
-    reg_cluster_variance: bool,
-    entropy_samples: int,
-    entropy_knn: int,
+        preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
+        batch_size: int,
+        encoding_size: int,
+        epochs: int,
+        hparams: dict,
+        kl_warmup: int,
+        log_history: bool,
+        log_hparams: bool,
+        loss: str,
+        mmd_warmup: int,
+        montecarlo_kl: int,
+        n_components: int,
+        output_path: str,
+        phenotype_class: float,
+        predictor: float,
+        pretrained: str,
+        save_checkpoints: bool,
+        save_weights: bool,
+        variational: bool,
+        reg_cat_clusters: bool,
+        reg_cluster_variance: bool,
+        entropy_samples: int,
+        entropy_knn: int,
 ):
     """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
 
@@ -312,7 +313,7 @@ def autoencoder_fitting(
         logparams, metrics = log_hyperparameters(phenotype_class, rec)
 
         with tf.summary.create_file_writer(
-            os.path.join(output_path, "hparams", run_ID)
+                os.path.join(output_path, "hparams", run_ID)
         ).as_default():
             hp.hparams_config(
                 hparams=logparams,
@@ -362,14 +363,14 @@ def autoencoder_fitting(
                 verbose=1,
                 validation_data=(X_val, X_val),
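+                # Early stopping only starts once the KL/MMD warm-up period is over (see start_epoch)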
                 callbacks=cbacks
-                + [
-                    CustomStopper(
-                        monitor="val_loss",
-                        patience=5,
-                        restore_best_weights=True,
-                        start_epoch=max(kl_warmup, mmd_warmup),
-                    ),
-                ],
+                          + [
+                              CustomStopper(
+                                  monitor="val_loss",
+                                  patience=5,
+                                  restore_best_weights=True,
+                                  start_epoch=max(kl_warmup, mmd_warmup),
+                              ),
+                          ],
             )
 
             if save_weights:
@@ -439,23 +440,23 @@ def autoencoder_fitting(
 
 
 def tune_search(
-    data: List[np.array],
-    encoding_size: int,
-    hypertun_trials: int,
-    hpt_type: str,
-    hypermodel: str,
-    k: int,
-    kl_warmup_epochs: int,
-    loss: str,
-    mmd_warmup_epochs: int,
-    overlap_loss: float,
-    phenotype_class: float,
-    predictor: float,
-    project_name: str,
-    callbacks: List,
-    n_epochs: int = 30,
-    n_replicas: int = 1,
-    outpath: str = ".",
+        data: List[np.array],
+        encoding_size: int,
+        hypertun_trials: int,
+        hpt_type: str,
+        hypermodel: str,
+        k: int,
+        kl_warmup_epochs: int,
+        loss: str,
+        mmd_warmup_epochs: int,
+        overlap_loss: float,
+        phenotype_class: float,
+        predictor: float,
+        project_name: str,
+        callbacks: List,
+        n_epochs: int = 30,
+        n_replicas: int = 1,
+        outpath: str = ".",
 ) -> Union[bool, Tuple[Any, Any]]:
     """Define the search space using keras-tuner and bayesian optimization
 
@@ -494,7 +495,7 @@ def tune_search(
 
     if hypermodel == "S2SAE":  # pragma: no cover
         assert (
-            predictor == 0.0 and phenotype_class == 0.0
+                predictor == 0.0 and phenotype_class == 0.0
         ), "Prediction branches are only available for variational models. See documentation for more details"
         batch_size = 1
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
diff --git a/deepof/utils.py b/deepof/utils.py
index d24905bfd55baccf4882acad00e574be862db2e0..3fe7a8305551b6f0959f3bb9c09c0f7256b22d1c 100644
--- a/deepof/utils.py
+++ b/deepof/utils.py
@@ -9,20 +9,21 @@ Functions and general utilities for the deepof package. See documentation for de
 """
 
 import argparse
-import cv2
 import multiprocessing
+import os
+from copy import deepcopy
+from itertools import combinations, product
+from typing import Tuple, Any, List, Union, NewType
+
+import cv2
 import networkx as nx
 import numpy as np
-import os
 import pandas as pd
 import regex as re
 import tensorflow as tf
-from copy import deepcopy
-from itertools import combinations, product
 from joblib import Parallel, delayed
 from sklearn import mixture
 from tqdm import tqdm
-from typing import Tuple, Any, List, Union, NewType
 
 # DEFINE CUSTOM ANNOTATED TYPES #
 
@@ -145,7 +146,7 @@ def tab2polar(cartesian_df: pd.DataFrame) -> pd.DataFrame:
 
 
 def compute_dist(
-    pair_array: np.array, arena_abs: int = 1, arena_rel: int = 1
+        pair_array: np.array, arena_abs: int = 1, arena_rel: int = 1
 ) -> pd.DataFrame:
     """Returns a pandas.DataFrame with the scaled distances between a pair of body parts.
 
@@ -168,7 +169,7 @@ def compute_dist(
 
 
 def bpart_distance(
-    dataframe: pd.DataFrame, arena_abs: int = 1, arena_rel: int = 1
+        dataframe: pd.DataFrame, arena_abs: int = 1, arena_rel: int = 1
 ) -> pd.DataFrame:
     """Returns a pandas.DataFrame with the scaled distances between all pairs of body parts.
 
@@ -207,7 +208,7 @@ def angle(a: np.array, b: np.array, c: np.array) -> np.array:
     bc = c - b
 
     cosine_angle = np.einsum("...i,...i", ba, bc) / (
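+    # cos(angle) = <ba, bc> / (|ba| * |bc|), evaluated row-wise via einsum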
-        np.linalg.norm(ba, axis=1) * np.linalg.norm(bc, axis=1)
+            np.linalg.norm(ba, axis=1) * np.linalg.norm(bc, axis=1)
     )
     ang = np.arccos(cosine_angle)
 
@@ -230,7 +231,7 @@ def angle_trio(bpart_array: np.array) -> np.array:
 
 
 def rotate(
-    p: np.array, angles: np.array, origin: np.array = np.array([0, 0])
+        p: np.array, angles: np.array, origin: np.array = np.array([0, 0])
 ) -> np.array:
     """Returns a numpy.array with the initial values rotated by angles radians
 
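
The hunk only touches rotate's signature; for reference, a single-point version of the standard 2-D rotation the signature implies (a re-derivation, not deepof's verbatim body):

import numpy as np

def rotate_point(p, angle, origin=np.array([0.0, 0.0])):
    # classic rotation matrix, applied around an arbitrary origin
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    return origin + rot @ (p - origin)

print(rotate_point(np.array([1.0, 0.0]), np.pi / 2))  # ~[0., 1.]
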
@@ -362,12 +363,12 @@ def moving_average(time_series: pd.Series, N: int = 5):
 
 
 def mask_outliers(
-    time_series: pd.DataFrame,
-    likelihood: pd.DataFrame,
-    likelihood_tolerance: float,
-    lag: int,
-    n_std: int,
-    mode: str,
+        time_series: pd.DataFrame,
+        likelihood: pd.DataFrame,
+        likelihood_tolerance: float,
+        lag: int,
+        n_std: int,
+        mode: str,
 ):
     """Returns a mask over the bivariate trajectory of a body part, identifying as True all detected outliers
 
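
Only the signature is visible in this hunk, so the following is one plausible masking rule consistent with the likelihood_tolerance / lag / n_std parameters — a rolling z-score OR'd with a likelihood cutoff — rather than deepof's exact logic:

import pandas as pd

def sketch_mask(x: pd.Series, lik: pd.Series, tol: float, lag: int, n_std: int) -> pd.Series:
    roll_mean = x.rolling(lag, min_periods=1).mean()
    roll_std = x.rolling(lag, min_periods=1).std().fillna(0.0)
    jump = (x - roll_mean).abs() > n_std * roll_std  # sudden positional jumps
    return jump | (lik < tol)  # 'or' mode: either criterion flags the frame
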
@@ -403,13 +404,13 @@ def mask_outliers(
 
 
 def full_outlier_mask(
-    experiment: pd.DataFrame,
-    likelihood: pd.DataFrame,
-    likelihood_tolerance: float,
-    exclude: str,
-    lag: int,
-    n_std: int,
-    mode: str,
+        experiment: pd.DataFrame,
+        likelihood: pd.DataFrame,
+        likelihood_tolerance: float,
+        exclude: str,
+        lag: int,
+        n_std: int,
+        mode: str,
 ):
     """Iterates over all body parts of experiment, and outputs a dataframe where all x, y positions are
     replaced by a boolean mask, where True indicates an outlier
@@ -454,14 +455,14 @@ def full_outlier_mask(
 
 
 def interpolate_outliers(
-    experiment: pd.DataFrame,
-    likelihood: pd.DataFrame,
-    likelihood_tolerance: float,
-    exclude: str = "",
-    lag: int = 5,
-    n_std: int = 3,
-    mode: str = "or",
-    limit: int = 10,
+        experiment: pd.DataFrame,
+        likelihood: pd.DataFrame,
+        likelihood_tolerance: float,
+        exclude: str = "",
+        lag: int = 5,
+        n_std: int = 3,
+        mode: str = "or",
+        limit: int = 10,
 ):
     """Marks all outliers in experiment and replaces them using a univariate linear interpolation approach.
     Note that this approach only works for equally spaced data (constant camera acquisition rates).
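
The replacement step the docstring describes can be reproduced with pandas' built-in linear interpolation, which is exactly why the equally-spaced-data caveat applies (toy data; the outlier mask is fabricated):

import numpy as np
import pandas as pd

x = pd.Series([0.0, 1.0, 50.0, 3.0, 4.0])  # frame 2 is an obvious outlier
mask = pd.Series([False, False, True, False, False])
x[mask] = np.nan
print(x.interpolate(method="linear", limit=10))  # frame 2 becomes 2.0
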
@@ -495,13 +496,13 @@ def interpolate_outliers(
 
 
 def recognize_arena(
-    videos: list,
-    vid_index: int,
-    path: str = ".",
-    recoglimit: int = 100,
-    arena_type: str = "circular",
-    detection_mode: str = "cnn",
-    cnn_model: tf.keras.models.Model = None,
+        videos: list,
+        vid_index: int,
+        path: str = ".",
+        recoglimit: int = 100,
+        arena_type: str = "circular",
+        detection_mode: str = "cnn",
+        cnn_model: tf.keras.models.Model = None,
 ) -> Tuple[np.array, int, int]:
     """Returns numpy.array with information about the arena recognised from the first frames
     of the video. WARNING: estimates won't be reliable if the camera moves along the video.
@@ -566,9 +567,9 @@ def recognize_arena(
 
 
 def circular_arena_recognition(
-    frame: np.array,
-    detection_mode: str = "cnn",
-    cnn_model: tf.keras.models.Model = None,
+        frame: np.array,
+        detection_mode: str = "cnn",
+        cnn_model: tf.keras.models.Model = None,
 ) -> np.array:
     """Returns x,y position of the center, the lengths of the major and minor axes,
     and the angle of the recognised arena
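
For the non-CNN detection mode, a classical baseline for locating a circular arena in a single frame is OpenCV's Hough transform; this sketch shows the general technique only, since the function body lies outside the hunk:

import cv2
import numpy as np

def hough_arena(frame: np.ndarray):
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    gray = cv2.medianBlur(gray, 5)  # denoise before the edge-based voting
    circles = cv2.HoughCircles(
        gray, cv2.HOUGH_GRADIENT, dp=1, minDist=gray.shape[0] // 2
    )
    return None if circles is None else circles[0, 0]  # x, y, radius
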
@@ -631,13 +632,13 @@ def circular_arena_recognition(
 
 
 def rolling_speed(
-    dframe: pd.DatetimeIndex,
-    window: int = 3,
-    rounds: int = 3,
-    deriv: int = 1,
-    center: str = None,
-    shift: int = 2,
-    typ: str = "coords",
+        dframe: pd.DatetimeIndex,
+        window: int = 3,
+        rounds: int = 3,
+        deriv: int = 1,
+        center: str = None,
+        shift: int = 2,
+        typ: str = "coords",
 ) -> pd.DataFrame:
     """Returns the average speed over n frames in pixels per frame
 
@@ -667,18 +668,17 @@ def rolling_speed(
     speeds = pd.DataFrame()
 
     for der in range(deriv):
-
         features = 2 if der == 0 and typ == "coords" else 1
 
         distances = (
-            np.concatenate(
-                [
-                    np.array(dframe).reshape([-1, features], order="C"),
-                    np.array(dframe.shift(shift)).reshape([-1, features], order="C"),
-                ],
-                axis=1,
-            )
-            / shift
+                np.concatenate(
+                    [
+                        np.array(dframe).reshape([-1, features], order="C"),
+                        np.array(dframe.shift(shift)).reshape([-1, features], order="C"),
+                    ],
+                    axis=1,
+                )
+                / shift
         )
 
         distances = np.array(compute_dist(distances))
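
The concatenate/shift/compute_dist block above is a finite-difference speed estimate: the distance between each point and its position `shift` frames earlier, divided by `shift` (dividing the coordinates first is equivalent, since euclidean distance scales linearly). A standalone check with fabricated coordinates:

import numpy as np
import pandas as pd

coords = pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0], "y": [0.0, 0.0, 0.0, 0.0]})
shift = 2
speed = np.sqrt(((coords - coords.shift(shift)) ** 2).sum(axis=1, skipna=False)) / shift
print(speed.tolist())  # [nan, nan, 1.0, 1.0] pixels per frame
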
@@ -726,12 +726,12 @@ def gmm_compute(x: np.array, n_components: int, cv_type: str) -> list:
 
 
 def gmm_model_selection(
-    x: pd.DataFrame,
-    n_components_range: range,
-    part_size: int,
-    n_runs: int = 100,
-    n_cores: int = False,
-    cv_types: Tuple = ("spherical", "tied", "diag", "full"),
+        x: pd.DataFrame,
+        n_components_range: range,
+        part_size: int,
+        n_runs: int = 100,
+        n_cores: int = False,
+        cv_types: Tuple = ("spherical", "tied", "diag", "full"),
 ) -> Tuple[List[list], List[np.ndarray], Union[int, Any]]:
     """Runs GMM clustering model selection on the specified X dataframe, outputs the bic distribution per model,
     a vector with the median BICs and an object with the overall best model
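
Underneath the selection loop, each candidate is a scikit-learn mixture model scored by BIC; a compressed sketch of that pattern on random data (grid sizes are illustrative):

import numpy as np
from sklearn import mixture

X = np.random.normal(size=(500, 4))
best, best_bic = None, np.inf
for cv_type in ("spherical", "tied", "diag", "full"):
    for n in range(1, 4):
        gmm = mixture.GaussianMixture(n_components=n, covariance_type=cv_type).fit(X)
        bic = gmm.bic(X)  # lower is better
        if bic < best_bic:
            best, best_bic = gmm, bic
print(best.covariance_type, best.n_components)
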
@@ -788,10 +788,10 @@ def gmm_model_selection(
 
 
 def cluster_transition_matrix(
-    cluster_sequence: np.array,
-    nclusts: int,
-    autocorrelation: bool = True,
-    return_graph: bool = False,
+        cluster_sequence: np.array,
+        nclusts: int,
+        autocorrelation: bool = True,
+        return_graph: bool = False,
 ) -> Tuple[Union[nx.Graph, Any], np.ndarray]:
     """Computes the transition matrix between clusters and the autocorrelation in the sequence.
 
@@ -840,7 +840,6 @@ def cluster_transition_matrix(
 
     return trans_normed
 
-
 # TODO:
 #    - Add sequence plot to single_behaviour_analysis (show how the condition varies across a specified time window)
 #    - Add digging to rule_based_tagging
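
The count-and-row-normalise idea behind cluster_transition_matrix can be sanity-checked in a few lines of numpy (this mirrors the docstring's description, not the function's verbatim body):

import numpy as np

seq = np.array([0, 0, 1, 1, 0, 2, 2, 2, 1])  # fabricated cluster labels
nclusts = 3
trans = np.zeros((nclusts, nclusts))
for i, j in zip(seq[:-1], seq[1:]):
    trans[i, j] += 1  # count observed i -> j transitions
trans_normed = trans / trans.sum(axis=1, keepdims=True)  # rows sum to 1
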
diff --git a/deepof/visuals.py b/deepof/visuals.py
index cee6cf4f5f42f863810e29501f11267c027c212b..ad445ad82a6b8883bef26a0906ffed0f0e548834 100644
--- a/deepof/visuals.py
+++ b/deepof/visuals.py
@@ -8,25 +8,25 @@ General plotting functions for the deepof package
 
 """
 
+from itertools import cycle
+from typing import List
 
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import seaborn as sns
-from itertools import cycle
-from typing import List
 
 
 # PLOTTING FUNCTIONS #
 
 
 def plot_heatmap(
-    dframe: pd.DataFrame,
-    bodyparts: List,
-    xlim: tuple,
-    ylim: tuple,
-    save: str = False,
-    dpi: int = 200,
+        dframe: pd.DataFrame,
+        bodyparts: List,
+        xlim: tuple,
+        ylim: tuple,
+        save: str = False,
+        dpi: int = 200,
 ) -> plt.figure:
     """Returns a heatmap of the movement of a specific bodypart in the arena.
     If more than one bodypart is passed, it returns one subplot for each
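
A hedged usage sketch for plot_heatmap, assuming the deepof convention of MultiIndex columns (body part, coordinate); the body-part name and axis limits are placeholders:

import numpy as np
import pandas as pd

cols = pd.MultiIndex.from_product([["Nose"], ["x", "y"]])
dframe = pd.DataFrame(np.random.uniform(-50, 50, size=(1000, 2)), columns=cols)
fig = plot_heatmap(dframe, bodyparts=["Nose"], xlim=(-60, 60), ylim=(-60, 60), dpi=100)
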
@@ -73,13 +73,13 @@ def plot_heatmap(
 
 
 def model_comparison_plot(
-    bic: list,
-    m_bic: list,
-    n_components_range: range,
-    cov_plot: str,
-    save: str = False,
-    cv_types: tuple = ("spherical", "tied", "diag", "full"),
-    dpi: int = 200,
+        bic: list,
+        m_bic: list,
+        n_components_range: range,
+        cov_plot: str,
+        save: str = False,
+        cv_types: tuple = ("spherical", "tied", "diag", "full"),
+        dpi: int = 200,
 ) -> plt.figure:
     """
 
@@ -119,7 +119,7 @@ def model_comparison_plot(
         bars.append(
             spl.bar(
                 xpos,
-                m_bic[i * len(n_components_range) : (i + 1) * len(n_components_range)],
+                m_bic[i * len(n_components_range):(i + 1) * len(n_components_range)],
                 color=color,
                 width=0.2,
             )
@@ -128,9 +128,9 @@ def model_comparison_plot(
     spl.set_xticks(n_components_range)
     plt.title("BIC score per model")
     xpos = (
-        np.mod(m_bic.argmin(), len(n_components_range))
-        + 0.5
-        + 0.2 * np.floor(m_bic.argmin() / len(n_components_range))
+            np.mod(m_bic.argmin(), len(n_components_range))
+            + 0.5
+            + 0.2 * np.floor(m_bic.argmin() / len(n_components_range))
     )
     # noinspection PyArgumentList
     spl.text(xpos, m_bic.min() * 0.97 + 0.1 * m_bic.max(), "*", fontsize=14)
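
The xpos arithmetic above places the asterisk over the winning bar: argmin modulo the grid length picks the component count, and the floor-divide term offsets by covariance type. A tiny worked check with a fabricated m_bic vector:

import numpy as np

m_bic = np.array([10.0, 9.0, 8.0,   # e.g. spherical, 1-3 components
                  7.0, 6.0, 5.0])   # e.g. tied, 1-3 components
n_components_range = range(1, 4)
best = m_bic.argmin()  # index 5: second covariance type, third component count
xpos = (np.mod(best, len(n_components_range)) + 0.5
        + 0.2 * np.floor(best / len(n_components_range)))
print(xpos)  # 2.7
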