diff --git a/deepof/model_training.py b/deepof/model_training.py
index 34b58f3986642fa92b0fe2c92a5a18d1fd685fc3..c959cb704cf2b1eed803dd1d03c087468147e715 100644
--- a/deepof/model_training.py
+++ b/deepof/model_training.py
@@ -19,13 +19,26 @@ parser = argparse.ArgumentParser(
     description="Autoencoder training for DeepOF animal pose recognition"
 )
 
-parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
 parser.add_argument(
-    "--val-num",
-    "-vn",
-    help="set number of videos of the training" "set to use for validation",
+    "--arena-dims",
+    "-adim",
+    help="diameter in mm of the utilised arena. Used for scaling purposes",
     type=int,
-    default=1,
+    default=380,
+)
+parser.add_argument(
+    "--batch-size",
+    "-bs",
+    help="set training batch size. Defaults to 512",
+    type=int,
+    default=512,
+)
+parser.add_argument(
+    "--bayopt",
+    "-n",
+    help="sets the number of Bayesian optimization iterations to run. Default is 25",
+    type=int,
+    default=25,
 )
 parser.add_argument(
     "--components",
@@ -35,35 +48,50 @@ parser.add_argument(
     default=1,
 )
 parser.add_argument(
-    "--input-type",
-    "-d",
-    help="Select an input type for the autoencoder hypermodels. \
-    It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle. \
-    Defaults to coords.",
+    "--exclude-bodyparts",
+    "-exc",
+    help="Excludes the indicated bodyparts from all analyses. It should consist of several values separated by commas",
     type=str,
-    default="dists",
+    default="",
 )
 parser.add_argument(
-    "--predictor",
-    "-pred",
-    help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True",
-    default=0,
-    type=float,
+    "--gaussian-filter",
+    "-gf",
+    help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model",
+    type=str2bool,
+    default=False,
 )
 parser.add_argument(
-    "--variational",
-    "-v",
-    help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
-    default=True,
+    "--hypermodel",
+    "-m",
+    help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, S2SVAEP, "
+    "S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
+    type=str,
+    default="S2SVAE",
+)
+parser.add_argument(
+    "--hyperparameter-tuning",
+    "-tune",
+    help="If True, hyperparameter tuning is performed. See documentation for details",
     type=str2bool,
+    default=False,
 )
 parser.add_argument(
-    "--loss",
-    "-l",
-    help="Sets the loss function for the variational model. "
-    "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
-    default="ELBO+MMD",
+    "--hyperparameters",
+    "-hp",
+    help="Path pointing to a pickled dictionary of network hyperparameters. "
+    "Thought to be used with the output of hyperparameter tuning",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--input-type",
+    "-d",
+    help="Select an input type for the autoencoder hypermodels. "
+    "It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
+    "Defaults to coords.",
     type=str,
+    default="dists",
 )
 parser.add_argument(
     "--kl-warmup",
@@ -72,6 +100,14 @@ parser.add_argument(
     default=10,
     type=int,
 )
+parser.add_argument(
+    "--loss",
+    "-l",
+    help="Sets the loss function for the variational model. "
+    "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
+    default="ELBO+MMD",
+    type=str,
+)
 parser.add_argument(
     "--mmd-warmup",
     "-mmdw",
@@ -79,47 +115,50 @@ parser.add_argument(
     default=10,
     type=int,
 )
-parser.add_argument(
-    "--hyperparameters",
-    "-hp",
-    help="Path pointing to a pickled dictionary of network hyperparameters. "
-    "Thought to be used with the output of hyperparameter tuning",
-)
-parser.add_argument(
-    "--encoding-size",
-    "-e",
-    help="Sets the dimensionality of the latent space. Defaults to 16.",
-    default=16,
-    type=int,
-)
 parser.add_argument(
     "--overlap-loss",
     "-ol",
     help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
-    default=False,
     type=str2bool,
+    default=False,
 )
 parser.add_argument(
-    "--batch-size",
-    "-bs",
-    help="set training batch size. Defaults to 512",
-    type=int,
-    default=512,
+    "--predictor",
+    "-pred",
+    help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True",
+    default=0,
+    type=float,
+)
+parser.add_argument(
+    "--smooth-alpha",
+    "-sa",
+    help="Sets the exponential smoothing factor to apply to the input data. "
+    "Float between 0 and 1 (lower is more smooting)",
+    type=float,
+    default=0.99,
 )
 parser.add_argument(
     "--stability-check",
     "-s",
-    help="Sets the number of times that the model is trained and initialised. If greater than 1 (the default), "
-    "saves the cluster assignments to a dataframe on disk",
+    help="Sets the number of times that the model is trained and initialised. "
+    "If greater than 1 (the default), saves the cluster assignments to a dataframe on disk",
     type=int,
     default=1,
 )
+parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
 parser.add_argument(
-    "--gaussian-filter",
-    "-gf",
-    help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model",
+    "--val-num",
+    "-vn",
+    help="set number of videos of the training" "set to use for validation",
+    type=int,
+    default=1,
+)
+parser.add_argument(
+    "--variational",
+    "-v",
+    help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
+    default=True,
     type=str2bool,
-    default=False,
 )
 parser.add_argument(
     "--window-size",
@@ -135,58 +174,12 @@ parser.add_argument(
     type=int,
     default=5,
 )
-parser.add_argument(
-    "--smooth-alpha",
-    "-sa",
-    help="Sets the exponential smoothing factor to apply to the input data. "
-    "Float between 0 and 1 (lower is more smooting)",
-    type=float,
-    default=0.99,
-)
-parser.add_argument(
-    "--exclude-bodyparts",
-    "-exc",
-    help="Excludes the indicated bodyparts from all analyses. "
-    "It should consist of several values separated by commas",
-    type=str,
-    default="",
-)
-parser.add_argument(
-    "--arena-dims",
-    "-adim",
-    help="diameter in mm of the utilised arena. Used for scaling purposes",
-    type=int,
-    default=380,
-)
-parser.add_argument(
-    "--hyperparameter-tuning",
-    "-tune",
-    help="If True, hyperparameter tuning is performed. See documentation for details",
-    type=str2bool,
-    default=False,
-)
-parser.add_argument(
-    "--bayopt",
-    "-n",
-    help="sets the number of Bayesian optimization iterations to run. Default is 25",
-    default=25,
-    type=int,
-)
-parser.add_argument(
-    "--hypermodel",
-    "-m",
-    help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, "
-    "S2SVAEP, S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
-    default="S2SVAE",
-    type=str,
-)
 
 args = parser.parse_args()
 
 arena_dims = args.arena_dims
 batch_size = args.batch_size
 bayopt_trials = args.bayopt
-encoding = args.encoding_size
 exclude_bodyparts = tuple(args.exclude_bodyparts.split(","))
 gaussian_filter = args.gaussian_filter
 hparams = args.hyperparameters
@@ -225,7 +218,7 @@ assert input_type in [
 ], "Invalid input type. Type python model_training.py -h for help."
 
 # Loads model hyperparameters and treatment conditions, if available
-hparams = load_hparams(hparams, encoding)
+hparams = load_hparams(hparams)
 treatment_dict = load_treatments(train_path)
 
 # noinspection PyTypeChecker
@@ -407,7 +400,7 @@ else:
         overlap_loss=overlap_loss,
         predictor=predictor,
         project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
-        tensorboard_callback=tensorboard_callback,
+        callbacks=[tensorboard_callback, cp_callback, onecycle],
     )
 
     # Saves a compiled, untrained version of the best model
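
Note on the flags above: the patch parses booleans with `type=str2bool`, a helper it uses but does not define here. A minimal sketch of such a converter, assuming the common argparse idiom (deepof's actual helper may differ):

    import argparse

    def str2bool(v: str) -> bool:
        # Accepts the usual boolean spellings; rejects anything else
        # with an argparse-friendly error.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")
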
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index b4cc6da755cc9b619e8bba78c3ccd9bc3f9f61fe..90f475fb4b878517f7a20a5ff683ded09823ce19 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -10,7 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t
 from datetime import datetime
 
 from kerastuner import BayesianOptimization
-from typing import Tuple, Union, Any
+from typing import Tuple, Union, Any, List
 import deepof.hypermodels
 import deepof.model_utils
 import keras
@@ -20,21 +20,20 @@ import pickle
 import tensorflow as tf
 
 
-def load_hparams(hparams, encoding):
+def load_hparams(hparams):
     """Loads hyperparameters from a custom dictionary pickled on disc.
     Thought to be used with the output of hyperparameter_tuning.py"""
 
     if hparams is not None:
         with open(hparams, "rb") as handle:
             hparams = pickle.load(handle)
-        hparams["encoding"] = encoding
     else:
         hparams = {
             "units_conv": 256,
             "units_lstm": 256,
             "units_dense2": 64,
             "dropout_rate": 0.25,
-            "encoding": encoding,
+            "encoding": 16,
             "learning_rate": 1e-3,
         }
     return hparams
@@ -47,7 +46,7 @@ def load_treatments(train_path):
         with open(
             os.path.join(
                 train_path,
-                [i for i in os.listdir(train_path) if i.endswith(".pickle")][0],
+                [i for i in os.listdir(train_path) if i.endswith(".pkl")][0],
             ),
             "rb",
         ) as handle:
@@ -89,14 +88,12 @@ def get_callbacks(
         log_dir=log_dir, histogram_freq=1, profile_batch=2,
     )
 
-    cp_callback = (
-        tf.keras.callbacks.ModelCheckpoint(
-            "./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
-            verbose=1,
-            save_best_only=False,
-            save_weights_only=True,
-            save_freq="epoch",
-        ),
+    cp_callback = tf.keras.callbacks.ModelCheckpoint(
+        "./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
+        verbose=1,
+        save_best_only=False,
+        save_weights_only=True,
+        save_freq="epoch",
     )
 
     onecycle = deepof.model_utils.one_cycle_scheduler(
@@ -118,9 +115,34 @@ def tune_search(
     overlap_loss: float,
     predictor: float,
     project_name: str,
-    tensorboard_callback: tf.keras.callbacks,
+    callbacks: List,
 ) -> Union[bool, Tuple[Any, Any]]:
-    """Define the search space using keras-tuner and bayesian optimization"""
+    """Define the search space using keras-tuner and bayesian optimization
+
+        Parameters:
+            - train (np.array): dataset to train the model on
+            - test (np.array): dataset to validate the model on
+            - bayopt_trials (int): number of Bayesian optimization iterations to run
+            - hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder) or
+            S2SGMVAE (Gaussian Mixture Variational autoencoder).
+            - k (int): number of components of the Gaussian Mixture
+            - kl_wu (int): number of epochs for KL divergence warm up
+            - loss (str): one of [ELBO, MMD, ELBO+MMD]
+            - mmd_wu (int): number of epochs for MMD warm up
+            - overlap_loss (float): weight assigned to an extra loss term which
+            penalizes overlap between GM components
+            - predictor (float): adds an extra regularizing neural network to the model,
+            which tries to predict the next frame from the current one
+            - project_name (str): ID of the current run
+            - callbacks (list): list of callbacks for the training loop
+
+        Returns:
+            - best_hparams (dict): dictionary with the best retrieved hyperparameters
+            - best_run (tf.keras.Model): trained instance of the best model found
+
+    """
+
+    tensorboard_callback, cp_callback, onecycle = callbacks
 
     if hypermodel == "S2SAE":
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index e1e96d368e7442f2d43d7f0fa984e8614da73bc6..8255781750e38485d59f9ea03841d3301ffeae6c 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -12,8 +12,10 @@ from hypothesis import given
 from hypothesis import settings
 from hypothesis import strategies as st
 from hypothesis.extra.numpy import arrays
+import deepof.model_utils
 import deepof.train_utils
-import numpy as np
+import keras
+import os
 import tensorflow as tf
 import tensorflow_probability as tfp
 from tensorflow.python.framework.ops import EagerTensor
@@ -26,20 +28,67 @@ tfpl = tfp.layers
 tfd = tfp.distributions
 
 
-@given(encoding=st.integers(min_value=1, max_value=128))
-def test_load_hparams(encoding):
-    params = deepof.train_utils.load_hparams(None, encoding)
-    assert type(params) == dict
-    assert params["encoding"] == encoding
+def test_load_hparams():
+    assert type(deepof.train_utils.load_hparams(None)) == dict
+    assert (
+        type(
+            deepof.train_utils.load_hparams(
+                os.path.join("tests", "test_examples", "Others", "test_hparams.pkl")
+            )
+        )
+        == dict
+    )
 
 
 def test_load_treatments():
-    pass
+    assert deepof.train_utils.load_treatments(".") is None
+    assert (
+        type(
+            deepof.train_utils.load_treatments(
+                os.path.join("tests", "test_examples", "Others")
+            )
+        )
+        == dict
+    )
 
 
-def test_get_callbacks():
-    pass
+@given(
+    X_train=arrays(
+        shape=st.tuples(st.integers(min_value=1, max_value=1000)), dtype=float
+    ),
+    batch_size=st.integers(min_value=128, max_value=512),
+    k=st.integers(min_value=1, max_value=50),
+    kl_wu=st.integers(min_value=0, max_value=25),
+    loss=st.one_of(st.just("test_A"), st.just("test_B")),
+    mmd_wu=st.integers(min_value=0, max_value=25),
+    predictor=st.floats(min_value=0.0, max_value=1.0),
+    variational=st.booleans(),
+)
+def test_get_callbacks(
+    X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
+):
+    runID, tbc, cpc, cycle1c = deepof.train_utils.get_callbacks(
+        X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu,
+    )
+    assert type(runID) == str
+    assert type(tbc) == keras.callbacks.tensorboard_v2.TensorBoard
+    assert type(cpc) == tf.python.keras.callbacks.ModelCheckpoint
+    assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
 
 
 def test_tune_search():
-    pass
+    deepof.train_utils.tune_search(
+        train,
+        test,
+        bayopt_trials,
+        hypermodel,
+        k,
+        kl_wu,
+        loss,
+        mmd_wu,
+        overlap_loss,
+        predictor,
+        project_name,
+        callbacks,
+    )