diff --git a/deepof/model_training.py b/deepof/model_training.py index 34b58f3986642fa92b0fe2c92a5a18d1fd685fc3..c959cb704cf2b1eed803dd1d03c087468147e715 100644 --- a/deepof/model_training.py +++ b/deepof/model_training.py @@ -19,13 +19,26 @@ parser = argparse.ArgumentParser( description="Autoencoder training for DeepOF animal pose recognition" ) -parser.add_argument("--train-path", "-tp", help="set training set path", type=str) parser.add_argument( - "--val-num", - "-vn", - help="set number of videos of the training" "set to use for validation", + "--arena-dims", + "-adim", + help="diameter in mm of the utilised arena. Used for scaling purposes", type=int, - default=1, + default=380, +) +parser.add_argument( + "--batch-size", + "-bs", + help="set training batch size. Defaults to 512", + type=int, + default=512, +) +parser.add_argument( + "--bayopt", + "-n", + help="sets the number of Bayesian optimization iterations to run. Default is 25", + type=int, + default=25, ) parser.add_argument( "--components", @@ -35,35 +48,50 @@ parser.add_argument( default=1, ) parser.add_argument( - "--input-type", - "-d", - help="Select an input type for the autoencoder hypermodels. \ - It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle. \ - Defaults to coords.", + "--exclude-bodyparts", + "-exc", + help="Excludes the indicated bodyparts from all analyses. It should consist of several values separated by commas", type=str, - default="dists", + default="", ) parser.add_argument( - "--predictor", - "-pred", - help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True", - default=0, - type=float, + "--gaussian-filter", + "-gf", + help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model", + type=str2bool, + default=False, ) parser.add_argument( - "--variational", - "-v", - help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True", - default=True, + "--hypermodel", + "-m", + help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, S2SVAEP, " + "S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.", + type=str, + default="S2SVAE", +) +parser.add_argument( + "--hyperparameter-tuning", + "-tune", + help="If True, hyperparameter tuning is performed. See documentation for details", type=str2bool, + default=False, ) parser.add_argument( - "--loss", - "-l", - help="Sets the loss function for the variational model. " - "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD", - default="ELBO+MMD", + "--hyperparameters", + "-hp", + help="Path pointing to a pickled dictionary of network hyperparameters. " + "Thought to be used with the output of hyperparameter tuning", + type=str, + default=None, +) +parser.add_argument( + "--input-type", + "-d", + help="Select an input type for the autoencoder hypermodels. " + "It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle." + "Defaults to coords.", type=str, + default="dists", ) parser.add_argument( "--kl-warmup", @@ -72,6 +100,14 @@ parser.add_argument( default=10, type=int, ) +parser.add_argument( + "--loss", + "-l", + help="Sets the loss function for the variational model. " + "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD", + default="ELBO+MMD", + type=str, +) parser.add_argument( "--mmd-warmup", "-mmdw", @@ -79,47 +115,50 @@ parser.add_argument( default=10, type=int, ) -parser.add_argument( - "--hyperparameters", - "-hp", - help="Path pointing to a pickled dictionary of network hyperparameters. " - "Thought to be used with the output of hyperparameter tuning", -) -parser.add_argument( - "--encoding-size", - "-e", - help="Sets the dimensionality of the latent space. Defaults to 16.", - default=16, - type=int, -) parser.add_argument( "--overlap-loss", "-ol", help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function", - default=False, type=str2bool, + default=False, ) parser.add_argument( - "--batch-size", - "-bs", - help="set training batch size. Defaults to 512", - type=int, - default=512, + "--predictor", + "-pred", + help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True", + default=0, + type=float, +) +parser.add_argument( + "--smooth-alpha", + "-sa", + help="Sets the exponential smoothing factor to apply to the input data. " + "Float between 0 and 1 (lower is more smooting)", + type=float, + default=0.99, ) parser.add_argument( "--stability-check", "-s", - help="Sets the number of times that the model is trained and initialised. If greater than 1 (the default), " - "saves the cluster assignments to a dataframe on disk", + help="Sets the number of times that the model is trained and initialised. " + "If greater than 1 (the default), saves the cluster assignments to a dataframe on disk", type=int, default=1, ) +parser.add_argument("--train-path", "-tp", help="set training set path", type=str) parser.add_argument( - "--gaussian-filter", - "-gf", - help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model", + "--val-num", + "-vn", + help="set number of videos of the training" "set to use for validation", + type=int, + default=1, +) +parser.add_argument( + "--variational", + "-v", + help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True", + default=True, type=str2bool, - default=False, ) parser.add_argument( "--window-size", @@ -135,58 +174,12 @@ parser.add_argument( type=int, default=5, ) -parser.add_argument( - "--smooth-alpha", - "-sa", - help="Sets the exponential smoothing factor to apply to the input data. " - "Float between 0 and 1 (lower is more smooting)", - type=float, - default=0.99, -) -parser.add_argument( - "--exclude-bodyparts", - "-exc", - help="Excludes the indicated bodyparts from all analyses. " - "It should consist of several values separated by commas", - type=str, - default="", -) -parser.add_argument( - "--arena-dims", - "-adim", - help="diameter in mm of the utilised arena. Used for scaling purposes", - type=int, - default=380, -) -parser.add_argument( - "--hyperparameter-tuning", - "-tune", - help="If True, hyperparameter tuning is performed. See documentation for details", - type=str2bool, - default=False, -) -parser.add_argument( - "--bayopt", - "-n", - help="sets the number of Bayesian optimization iterations to run. Default is 25", - default=25, - type=int, -) -parser.add_argument( - "--hypermodel", - "-m", - help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, " - "S2SVAEP, S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.", - default="S2SVAE", - type=str, -) args = parser.parse_args() arena_dims = args.arena_dims batch_size = args.batch_size bayopt_trials = args.bayopt -encoding = args.encoding_size exclude_bodyparts = tuple(args.exclude_bodyparts.split(",")) gaussian_filter = args.gaussian_filter hparams = args.hyperparameters @@ -225,7 +218,7 @@ assert input_type in [ ], "Invalid input type. Type python model_training.py -h for help." # Loads model hyperparameters and treatment conditions, if available -hparams = load_hparams(hparams, encoding) +hparams = load_hparams(hparams) treatment_dict = load_treatments(train_path) # noinspection PyTypeChecker @@ -407,7 +400,7 @@ else: overlap_loss=overlap_loss, predictor=predictor, project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp), - tensorboard_callback=tensorboard_callback, + callbacks=[tensorboard_callback, cp_callback, onecycle], ) # Saves a compiled, untrained version of the best model diff --git a/deepof/train_utils.py b/deepof/train_utils.py index b4cc6da755cc9b619e8bba78c3ccd9bc3f9f61fe..90f475fb4b878517f7a20a5ff683ded09823ce19 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -10,7 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t from datetime import datetime from kerastuner import BayesianOptimization -from typing import Tuple, Union, Any +from typing import Tuple, Union, Any, List import deepof.hypermodels import deepof.model_utils import keras @@ -20,21 +20,20 @@ import pickle import tensorflow as tf -def load_hparams(hparams, encoding): +def load_hparams(hparams): """Loads hyperparameters from a custom dictionary pickled on disc. Thought to be used with the output of hyperparameter_tuning.py""" if hparams is not None: with open(hparams, "rb") as handle: hparams = pickle.load(handle) - hparams["encoding"] = encoding else: hparams = { "units_conv": 256, "units_lstm": 256, "units_dense2": 64, "dropout_rate": 0.25, - "encoding": encoding, + "encoding": 16, "learning_rate": 1e-3, } return hparams @@ -47,7 +46,7 @@ def load_treatments(train_path): with open( os.path.join( train_path, - [i for i in os.listdir(train_path) if i.endswith(".pickle")][0], + [i for i in os.listdir(train_path) if i.endswith(".pkl")][0], ), "rb", ) as handle: @@ -89,14 +88,12 @@ def get_callbacks( log_dir=log_dir, histogram_freq=1, profile_batch=2, ) - cp_callback = ( - tf.keras.callbacks.ModelCheckpoint( - "./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt", - verbose=1, - save_best_only=False, - save_weights_only=True, - save_freq="epoch", - ), + cp_callback = tf.keras.callbacks.ModelCheckpoint( + "./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt", + verbose=1, + save_best_only=False, + save_weights_only=True, + save_freq="epoch", ) onecycle = deepof.model_utils.one_cycle_scheduler( @@ -118,9 +115,34 @@ def tune_search( overlap_loss: float, predictor: float, project_name: str, - tensorboard_callback: tf.keras.callbacks, + callbacks: List, ) -> Union[bool, Tuple[Any, Any]]: - """Define the search space using keras-tuner and bayesian optimization""" + """Define the search space using keras-tuner and bayesian optimization + + Parameters: + - train (np.array): dataset to train the model on + - test (np.array): dataset to validate the model on + - bayopt_trials (int): number of Bayesian optimization iterations to run + - hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder) or + S2SGMVAE (Gaussian Mixture Variational autoencoder). + - k (int) number of components of the Gaussian Mixture + - kl_wu (int): number of epochs for KL divergence warm up + - loss (str): one of [ELBO, MMD, ELBO+MMD] + - mmd_wu (int): number of epochs for MMD warm up + - overlap_loss (float): assigns as weight to an extra loss term which + penalizes overlap between GM components + - predictor (float): adds an extra regularizing neural network to the model, + which tries to predict the next frame from the current one + - project_name (str): ID of the current run + - callbacks (list): list of callbacks for the training loop + + Returns: + - best_hparams (dict): dictionary with the best retrieved hyperparameters + - best_run (tf.keras.Model): trained instance of the best model found + + """ + + tensorboard_callback, cp_callback, onecycle = callbacks if hypermodel == "S2SAE": hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape) diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index e1e96d368e7442f2d43d7f0fa984e8614da73bc6..8255781750e38485d59f9ea03841d3301ffeae6c 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -12,8 +12,10 @@ from hypothesis import given from hypothesis import settings from hypothesis import strategies as st from hypothesis.extra.numpy import arrays +import deepof.model_utils import deepof.train_utils -import numpy as np +import keras +import os import tensorflow as tf import tensorflow_probability as tfp from tensorflow.python.framework.ops import EagerTensor @@ -26,20 +28,66 @@ tfpl = tfp.layers tfd = tfp.distributions -@given(encoding=st.integers(min_value=1, max_value=128)) -def test_load_hparams(encoding): - params = deepof.train_utils.load_hparams(None, encoding) - assert type(params) == dict - assert params["encoding"] == encoding +def test_load_hparams(): + assert type(deepof.train_utils.load_hparams(None)) == dict + assert ( + type( + deepof.train_utils.load_hparams( + os.path.join("tests", "test_examples", "Others", "test_hparams.pkl") + ) + ) + == dict + ) def test_load_treatments(): - pass + assert deepof.train_utils.load_treatments(".") is None + assert ( + type( + deepof.train_utils.load_treatments( + os.path.join("tests", "test_examples", "Others") + ) + ) + == dict + ) -def test_get_callbacks(): - pass +@given( + X_train=arrays( + shape=st.tuples(st.integers(min_value=1, max_value=1000)), dtype=float + ), + batch_size=st.integers(min_value=128, max_value=512), + k=st.integers(min_value=1, max_value=50), + kl_wu=st.integers(min_value=0, max_value=25), + loss=st.one_of(st.just("test_A"), st.just("test_B")), + mmd_wu=st.integers(min_value=0, max_value=25), + predictor=st.floats(min_value=0.0, max_value=1.0), + variational=st.booleans(), +) +def test_get_callbacks( + X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu +): + runID, tbc, cpc, cycle1c = deepof.train_utils.get_callbacks( + X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu, + ) + assert type(runID) == str + assert type(tbc) == keras.callbacks.tensorboard_v2.TensorBoard + assert type(cpc) == tf.python.keras.callbacks.ModelCheckpoint + assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler def test_tune_search(): - pass + deepof.train_utils.tune_search( + train, + test, + bayopt_trials, + hypermodel, + k, + kl_wu, + loss, + mmd_wu, + overlap_loss, + predictor, + project_name, + callbacks, + )