Skip to content
Snippets Groups Projects
Commit 88045e14 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Moved model_training.py to main deepof directory, and instanciated testing...

Moved model_training.py to main deepof directory, and instanciated testing module for train_utils.py
parent 23b86b09
No related branches found
No related tags found
No related merge requests found
...@@ -19,13 +19,26 @@ parser = argparse.ArgumentParser( ...@@ -19,13 +19,26 @@ parser = argparse.ArgumentParser(
description="Autoencoder training for DeepOF animal pose recognition" description="Autoencoder training for DeepOF animal pose recognition"
) )
parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
parser.add_argument( parser.add_argument(
"--val-num", "--arena-dims",
"-vn", "-adim",
help="set number of videos of the training" "set to use for validation", help="diameter in mm of the utilised arena. Used for scaling purposes",
type=int, type=int,
default=1, default=380,
)
parser.add_argument(
"--batch-size",
"-bs",
help="set training batch size. Defaults to 512",
type=int,
default=512,
)
parser.add_argument(
"--bayopt",
"-n",
help="sets the number of Bayesian optimization iterations to run. Default is 25",
type=int,
default=25,
) )
parser.add_argument( parser.add_argument(
"--components", "--components",
...@@ -35,35 +48,50 @@ parser.add_argument( ...@@ -35,35 +48,50 @@ parser.add_argument(
default=1, default=1,
) )
parser.add_argument( parser.add_argument(
"--input-type", "--exclude-bodyparts",
"-d", "-exc",
help="Select an input type for the autoencoder hypermodels. \ help="Excludes the indicated bodyparts from all analyses. It should consist of several values separated by commas",
It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle. \
Defaults to coords.",
type=str, type=str,
default="dists", default="",
) )
parser.add_argument( parser.add_argument(
"--predictor", "--gaussian-filter",
"-pred", "-gf",
help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True", help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model",
default=0, type=str2bool,
type=float, default=False,
) )
parser.add_argument( parser.add_argument(
"--variational", "--hypermodel",
"-v", "-m",
help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True", help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, S2SVAEP, "
default=True, "S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
type=str,
default="S2SVAE",
)
parser.add_argument(
"--hyperparameter-tuning",
"-tune",
help="If True, hyperparameter tuning is performed. See documentation for details",
type=str2bool, type=str2bool,
default=False,
) )
parser.add_argument( parser.add_argument(
"--loss", "--hyperparameters",
"-l", "-hp",
help="Sets the loss function for the variational model. " help="Path pointing to a pickled dictionary of network hyperparameters. "
"It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD", "Thought to be used with the output of hyperparameter tuning",
default="ELBO+MMD", type=str,
default=None,
)
parser.add_argument(
"--input-type",
"-d",
help="Select an input type for the autoencoder hypermodels. "
"It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
"Defaults to coords.",
type=str, type=str,
default="dists",
) )
parser.add_argument( parser.add_argument(
"--kl-warmup", "--kl-warmup",
...@@ -72,6 +100,14 @@ parser.add_argument( ...@@ -72,6 +100,14 @@ parser.add_argument(
default=10, default=10,
type=int, type=int,
) )
parser.add_argument(
"--loss",
"-l",
help="Sets the loss function for the variational model. "
"It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
default="ELBO+MMD",
type=str,
)
parser.add_argument( parser.add_argument(
"--mmd-warmup", "--mmd-warmup",
"-mmdw", "-mmdw",
...@@ -79,47 +115,50 @@ parser.add_argument( ...@@ -79,47 +115,50 @@ parser.add_argument(
default=10, default=10,
type=int, type=int,
) )
parser.add_argument(
"--hyperparameters",
"-hp",
help="Path pointing to a pickled dictionary of network hyperparameters. "
"Thought to be used with the output of hyperparameter tuning",
)
parser.add_argument(
"--encoding-size",
"-e",
help="Sets the dimensionality of the latent space. Defaults to 16.",
default=16,
type=int,
)
parser.add_argument( parser.add_argument(
"--overlap-loss", "--overlap-loss",
"-ol", "-ol",
help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function", help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
default=False,
type=str2bool, type=str2bool,
default=False,
) )
parser.add_argument( parser.add_argument(
"--batch-size", "--predictor",
"-bs", "-pred",
help="set training batch size. Defaults to 512", help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True",
type=int, default=0,
default=512, type=float,
)
parser.add_argument(
"--smooth-alpha",
"-sa",
help="Sets the exponential smoothing factor to apply to the input data. "
"Float between 0 and 1 (lower is more smooting)",
type=float,
default=0.99,
) )
parser.add_argument( parser.add_argument(
"--stability-check", "--stability-check",
"-s", "-s",
help="Sets the number of times that the model is trained and initialised. If greater than 1 (the default), " help="Sets the number of times that the model is trained and initialised. "
"saves the cluster assignments to a dataframe on disk", "If greater than 1 (the default), saves the cluster assignments to a dataframe on disk",
type=int, type=int,
default=1, default=1,
) )
parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
parser.add_argument( parser.add_argument(
"--gaussian-filter", "--val-num",
"-gf", "-vn",
help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model", help="set number of videos of the training" "set to use for validation",
type=int,
default=1,
)
parser.add_argument(
"--variational",
"-v",
help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
default=True,
type=str2bool, type=str2bool,
default=False,
) )
parser.add_argument( parser.add_argument(
"--window-size", "--window-size",
...@@ -135,58 +174,12 @@ parser.add_argument( ...@@ -135,58 +174,12 @@ parser.add_argument(
type=int, type=int,
default=5, default=5,
) )
parser.add_argument(
"--smooth-alpha",
"-sa",
help="Sets the exponential smoothing factor to apply to the input data. "
"Float between 0 and 1 (lower is more smooting)",
type=float,
default=0.99,
)
parser.add_argument(
"--exclude-bodyparts",
"-exc",
help="Excludes the indicated bodyparts from all analyses. "
"It should consist of several values separated by commas",
type=str,
default="",
)
parser.add_argument(
"--arena-dims",
"-adim",
help="diameter in mm of the utilised arena. Used for scaling purposes",
type=int,
default=380,
)
parser.add_argument(
"--hyperparameter-tuning",
"-tune",
help="If True, hyperparameter tuning is performed. See documentation for details",
type=str2bool,
default=False,
)
parser.add_argument(
"--bayopt",
"-n",
help="sets the number of Bayesian optimization iterations to run. Default is 25",
default=25,
type=int,
)
parser.add_argument(
"--hypermodel",
"-m",
help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, "
"S2SVAEP, S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
default="S2SVAE",
type=str,
)
args = parser.parse_args() args = parser.parse_args()
arena_dims = args.arena_dims arena_dims = args.arena_dims
batch_size = args.batch_size batch_size = args.batch_size
bayopt_trials = args.bayopt bayopt_trials = args.bayopt
encoding = args.encoding_size
exclude_bodyparts = tuple(args.exclude_bodyparts.split(",")) exclude_bodyparts = tuple(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters hparams = args.hyperparameters
...@@ -225,7 +218,7 @@ assert input_type in [ ...@@ -225,7 +218,7 @@ assert input_type in [
], "Invalid input type. Type python model_training.py -h for help." ], "Invalid input type. Type python model_training.py -h for help."
# Loads model hyperparameters and treatment conditions, if available # Loads model hyperparameters and treatment conditions, if available
hparams = load_hparams(hparams, encoding) hparams = load_hparams(hparams)
treatment_dict = load_treatments(train_path) treatment_dict = load_treatments(train_path)
# noinspection PyTypeChecker # noinspection PyTypeChecker
...@@ -407,7 +400,7 @@ else: ...@@ -407,7 +400,7 @@ else:
overlap_loss=overlap_loss, overlap_loss=overlap_loss,
predictor=predictor, predictor=predictor,
project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp), project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
tensorboard_callback=tensorboard_callback, callbacks=[tensorboard_callback, cp_callback, onecycle],
) )
# Saves a compiled, untrained version of the best model # Saves a compiled, untrained version of the best model
......
...@@ -10,7 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t ...@@ -10,7 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t
from datetime import datetime from datetime import datetime
from kerastuner import BayesianOptimization from kerastuner import BayesianOptimization
from typing import Tuple, Union, Any from typing import Tuple, Union, Any, List
import deepof.hypermodels import deepof.hypermodels
import deepof.model_utils import deepof.model_utils
import keras import keras
...@@ -20,21 +20,20 @@ import pickle ...@@ -20,21 +20,20 @@ import pickle
import tensorflow as tf import tensorflow as tf
def load_hparams(hparams, encoding): def load_hparams(hparams):
"""Loads hyperparameters from a custom dictionary pickled on disc. """Loads hyperparameters from a custom dictionary pickled on disc.
Thought to be used with the output of hyperparameter_tuning.py""" Thought to be used with the output of hyperparameter_tuning.py"""
if hparams is not None: if hparams is not None:
with open(hparams, "rb") as handle: with open(hparams, "rb") as handle:
hparams = pickle.load(handle) hparams = pickle.load(handle)
hparams["encoding"] = encoding
else: else:
hparams = { hparams = {
"units_conv": 256, "units_conv": 256,
"units_lstm": 256, "units_lstm": 256,
"units_dense2": 64, "units_dense2": 64,
"dropout_rate": 0.25, "dropout_rate": 0.25,
"encoding": encoding, "encoding": 16,
"learning_rate": 1e-3, "learning_rate": 1e-3,
} }
return hparams return hparams
...@@ -47,7 +46,7 @@ def load_treatments(train_path): ...@@ -47,7 +46,7 @@ def load_treatments(train_path):
with open( with open(
os.path.join( os.path.join(
train_path, train_path,
[i for i in os.listdir(train_path) if i.endswith(".pickle")][0], [i for i in os.listdir(train_path) if i.endswith(".pkl")][0],
), ),
"rb", "rb",
) as handle: ) as handle:
...@@ -89,14 +88,12 @@ def get_callbacks( ...@@ -89,14 +88,12 @@ def get_callbacks(
log_dir=log_dir, histogram_freq=1, profile_batch=2, log_dir=log_dir, histogram_freq=1, profile_batch=2,
) )
cp_callback = ( cp_callback = tf.keras.callbacks.ModelCheckpoint(
tf.keras.callbacks.ModelCheckpoint( "./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
"./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt", verbose=1,
verbose=1, save_best_only=False,
save_best_only=False, save_weights_only=True,
save_weights_only=True, save_freq="epoch",
save_freq="epoch",
),
) )
onecycle = deepof.model_utils.one_cycle_scheduler( onecycle = deepof.model_utils.one_cycle_scheduler(
...@@ -118,9 +115,34 @@ def tune_search( ...@@ -118,9 +115,34 @@ def tune_search(
overlap_loss: float, overlap_loss: float,
predictor: float, predictor: float,
project_name: str, project_name: str,
tensorboard_callback: tf.keras.callbacks, callbacks: List,
) -> Union[bool, Tuple[Any, Any]]: ) -> Union[bool, Tuple[Any, Any]]:
"""Define the search space using keras-tuner and bayesian optimization""" """Define the search space using keras-tuner and bayesian optimization
Parameters:
- train (np.array): dataset to train the model on
- test (np.array): dataset to validate the model on
- bayopt_trials (int): number of Bayesian optimization iterations to run
- hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder) or
S2SGMVAE (Gaussian Mixture Variational autoencoder).
- k (int) number of components of the Gaussian Mixture
- kl_wu (int): number of epochs for KL divergence warm up
- loss (str): one of [ELBO, MMD, ELBO+MMD]
- mmd_wu (int): number of epochs for MMD warm up
- overlap_loss (float): assigns as weight to an extra loss term which
penalizes overlap between GM components
- predictor (float): adds an extra regularizing neural network to the model,
which tries to predict the next frame from the current one
- project_name (str): ID of the current run
- callbacks (list): list of callbacks for the training loop
Returns:
- best_hparams (dict): dictionary with the best retrieved hyperparameters
- best_run (tf.keras.Model): trained instance of the best model found
"""
tensorboard_callback, cp_callback, onecycle = callbacks
if hypermodel == "S2SAE": if hypermodel == "S2SAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape) hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
......
...@@ -12,8 +12,10 @@ from hypothesis import given ...@@ -12,8 +12,10 @@ from hypothesis import given
from hypothesis import settings from hypothesis import settings
from hypothesis import strategies as st from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays from hypothesis.extra.numpy import arrays
import deepof.model_utils
import deepof.train_utils import deepof.train_utils
import numpy as np import keras
import os
import tensorflow as tf import tensorflow as tf
import tensorflow_probability as tfp import tensorflow_probability as tfp
from tensorflow.python.framework.ops import EagerTensor from tensorflow.python.framework.ops import EagerTensor
...@@ -26,20 +28,66 @@ tfpl = tfp.layers ...@@ -26,20 +28,66 @@ tfpl = tfp.layers
tfd = tfp.distributions tfd = tfp.distributions
@given(encoding=st.integers(min_value=1, max_value=128)) def test_load_hparams():
def test_load_hparams(encoding): assert type(deepof.train_utils.load_hparams(None)) == dict
params = deepof.train_utils.load_hparams(None, encoding) assert (
assert type(params) == dict type(
assert params["encoding"] == encoding deepof.train_utils.load_hparams(
os.path.join("tests", "test_examples", "Others", "test_hparams.pkl")
)
)
== dict
)
def test_load_treatments(): def test_load_treatments():
pass assert deepof.train_utils.load_treatments(".") is None
assert (
type(
deepof.train_utils.load_treatments(
os.path.join("tests", "test_examples", "Others")
)
)
== dict
)
def test_get_callbacks(): @given(
pass X_train=arrays(
shape=st.tuples(st.integers(min_value=1, max_value=1000)), dtype=float
),
batch_size=st.integers(min_value=128, max_value=512),
k=st.integers(min_value=1, max_value=50),
kl_wu=st.integers(min_value=0, max_value=25),
loss=st.one_of(st.just("test_A"), st.just("test_B")),
mmd_wu=st.integers(min_value=0, max_value=25),
predictor=st.floats(min_value=0.0, max_value=1.0),
variational=st.booleans(),
)
def test_get_callbacks(
X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
):
runID, tbc, cpc, cycle1c = deepof.train_utils.get_callbacks(
X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu,
)
assert type(runID) == str
assert type(tbc) == keras.callbacks.tensorboard_v2.TensorBoard
assert type(cpc) == tf.python.keras.callbacks.ModelCheckpoint
assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
def test_tune_search(): def test_tune_search():
pass deepof.train_utils.tune_search(
train,
test,
bayopt_trials,
hypermodel,
k,
kl_wu,
loss,
mmd_wu,
overlap_loss,
predictor,
project_name,
callbacks,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment