Commit 88045e14 authored by lucas_miranda's avatar lucas_miranda
Browse files

Moved model_training.py to main deepof directory, and instantiated testing...

Moved model_training.py to main deepof directory, and instantiated testing module for train_utils.py
parent 23b86b09
......@@ -19,13 +19,26 @@ parser = argparse.ArgumentParser(
description="Autoencoder training for DeepOF animal pose recognition"
)
parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
parser.add_argument(
"--val-num",
"-vn",
help="set number of videos of the training" "set to use for validation",
"--arena-dims",
"-adim",
help="diameter in mm of the utilised arena. Used for scaling purposes",
type=int,
default=1,
default=380,
)
parser.add_argument(
"--batch-size",
"-bs",
help="set training batch size. Defaults to 512",
type=int,
default=512,
)
parser.add_argument(
"--bayopt",
"-n",
help="sets the number of Bayesian optimization iterations to run. Default is 25",
type=int,
default=25,
)
parser.add_argument(
"--components",
......@@ -35,35 +48,50 @@ parser.add_argument(
default=1,
)
parser.add_argument(
"--input-type",
"-d",
help="Select an input type for the autoencoder hypermodels. \
It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle. \
Defaults to coords.",
"--exclude-bodyparts",
"-exc",
help="Excludes the indicated bodyparts from all analyses. It should consist of several values separated by commas",
type=str,
default="dists",
default="",
)
parser.add_argument(
"--predictor",
"-pred",
help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True",
default=0,
type=float,
"--gaussian-filter",
"-gf",
help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model",
type=str2bool,
default=False,
)
parser.add_argument(
"--variational",
"-v",
help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
default=True,
"--hypermodel",
"-m",
help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, S2SVAEP, "
"S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
type=str,
default="S2SVAE",
)
parser.add_argument(
"--hyperparameter-tuning",
"-tune",
help="If True, hyperparameter tuning is performed. See documentation for details",
type=str2bool,
default=False,
)
parser.add_argument(
"--loss",
"-l",
help="Sets the loss function for the variational model. "
"It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
default="ELBO+MMD",
"--hyperparameters",
"-hp",
help="Path pointing to a pickled dictionary of network hyperparameters. "
"Thought to be used with the output of hyperparameter tuning",
type=str,
default=None,
)
parser.add_argument(
"--input-type",
"-d",
help="Select an input type for the autoencoder hypermodels. "
"It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
"Defaults to coords.",
type=str,
default="dists",
)
parser.add_argument(
"--kl-warmup",
......@@ -72,6 +100,14 @@ parser.add_argument(
default=10,
type=int,
)
parser.add_argument(
"--loss",
"-l",
help="Sets the loss function for the variational model. "
"It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
default="ELBO+MMD",
type=str,
)
parser.add_argument(
"--mmd-warmup",
"-mmdw",
......@@ -79,47 +115,50 @@ parser.add_argument(
default=10,
type=int,
)
parser.add_argument(
"--hyperparameters",
"-hp",
help="Path pointing to a pickled dictionary of network hyperparameters. "
"Thought to be used with the output of hyperparameter tuning",
)
parser.add_argument(
"--encoding-size",
"-e",
help="Sets the dimensionality of the latent space. Defaults to 16.",
default=16,
type=int,
)
parser.add_argument(
"--overlap-loss",
"-ol",
help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
default=False,
type=str2bool,
default=False,
)
parser.add_argument(
"--batch-size",
"-bs",
help="set training batch size. Defaults to 512",
type=int,
default=512,
"--predictor",
"-pred",
help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True",
default=0,
type=float,
)
parser.add_argument(
"--smooth-alpha",
"-sa",
help="Sets the exponential smoothing factor to apply to the input data. "
"Float between 0 and 1 (lower is more smooting)",
type=float,
default=0.99,
)
parser.add_argument(
"--stability-check",
"-s",
help="Sets the number of times that the model is trained and initialised. If greater than 1 (the default), "
"saves the cluster assignments to a dataframe on disk",
help="Sets the number of times that the model is trained and initialised. "
"If greater than 1 (the default), saves the cluster assignments to a dataframe on disk",
type=int,
default=1,
)
parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
parser.add_argument(
"--gaussian-filter",
"-gf",
help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model",
"--val-num",
"-vn",
help="set number of videos of the training" "set to use for validation",
type=int,
default=1,
)
parser.add_argument(
"--variational",
"-v",
help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
default=True,
type=str2bool,
default=False,
)
parser.add_argument(
"--window-size",
......@@ -135,58 +174,12 @@ parser.add_argument(
type=int,
default=5,
)
parser.add_argument(
"--smooth-alpha",
"-sa",
help="Sets the exponential smoothing factor to apply to the input data. "
"Float between 0 and 1 (lower is more smooting)",
type=float,
default=0.99,
)
parser.add_argument(
"--exclude-bodyparts",
"-exc",
help="Excludes the indicated bodyparts from all analyses. "
"It should consist of several values separated by commas",
type=str,
default="",
)
parser.add_argument(
"--arena-dims",
"-adim",
help="diameter in mm of the utilised arena. Used for scaling purposes",
type=int,
default=380,
)
parser.add_argument(
"--hyperparameter-tuning",
"-tune",
help="If True, hyperparameter tuning is performed. See documentation for details",
type=str2bool,
default=False,
)
parser.add_argument(
"--bayopt",
"-n",
help="sets the number of Bayesian optimization iterations to run. Default is 25",
default=25,
type=int,
)
parser.add_argument(
"--hypermodel",
"-m",
help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, "
"S2SVAEP, S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
default="S2SVAE",
type=str,
)
args = parser.parse_args()
arena_dims = args.arena_dims
batch_size = args.batch_size
bayopt_trials = args.bayopt
encoding = args.encoding_size
exclude_bodyparts = tuple(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters
......@@ -225,7 +218,7 @@ assert input_type in [
], "Invalid input type. Type python model_training.py -h for help."
# Loads model hyperparameters and treatment conditions, if available
hparams = load_hparams(hparams, encoding)
hparams = load_hparams(hparams)
treatment_dict = load_treatments(train_path)
# noinspection PyTypeChecker
......@@ -407,7 +400,7 @@ else:
overlap_loss=overlap_loss,
predictor=predictor,
project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
tensorboard_callback=tensorboard_callback,
callbacks=[tensorboard_callback, cp_callback, onecycle],
)
# Saves a compiled, untrained version of the best model
......
......@@ -10,7 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t
from datetime import datetime
from kerastuner import BayesianOptimization
from typing import Tuple, Union, Any
from typing import Tuple, Union, Any, List
import deepof.hypermodels
import deepof.model_utils
import keras
......@@ -20,21 +20,20 @@ import pickle
import tensorflow as tf
def load_hparams(hparams, encoding):
def load_hparams(hparams):
"""Loads hyperparameters from a custom dictionary pickled on disc.
Thought to be used with the output of hyperparameter_tuning.py"""
if hparams is not None:
with open(hparams, "rb") as handle:
hparams = pickle.load(handle)
hparams["encoding"] = encoding
else:
hparams = {
"units_conv": 256,
"units_lstm": 256,
"units_dense2": 64,
"dropout_rate": 0.25,
"encoding": encoding,
"encoding": 16,
"learning_rate": 1e-3,
}
return hparams
......@@ -47,7 +46,7 @@ def load_treatments(train_path):
with open(
os.path.join(
train_path,
[i for i in os.listdir(train_path) if i.endswith(".pickle")][0],
[i for i in os.listdir(train_path) if i.endswith(".pkl")][0],
),
"rb",
) as handle:
......@@ -89,14 +88,12 @@ def get_callbacks(
log_dir=log_dir, histogram_freq=1, profile_batch=2,
)
cp_callback = (
tf.keras.callbacks.ModelCheckpoint(
"./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
verbose=1,
save_best_only=False,
save_weights_only=True,
save_freq="epoch",
),
cp_callback = tf.keras.callbacks.ModelCheckpoint(
"./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
verbose=1,
save_best_only=False,
save_weights_only=True,
save_freq="epoch",
)
onecycle = deepof.model_utils.one_cycle_scheduler(
......@@ -118,9 +115,34 @@ def tune_search(
overlap_loss: float,
predictor: float,
project_name: str,
tensorboard_callback: tf.keras.callbacks,
callbacks: List,
) -> Union[bool, Tuple[Any, Any]]:
"""Define the search space using keras-tuner and bayesian optimization"""
"""Define the search space using keras-tuner and bayesian optimization
Parameters:
- train (np.array): dataset to train the model on
- test (np.array): dataset to validate the model on
- bayopt_trials (int): number of Bayesian optimization iterations to run
- hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder) or
S2SGMVAE (Gaussian Mixture Variational autoencoder).
- k (int): number of components of the Gaussian Mixture
- kl_wu (int): number of epochs for KL divergence warm up
- loss (str): one of [ELBO, MMD, ELBO+MMD]
- mmd_wu (int): number of epochs for MMD warm up
- overlap_loss (float): assigns a weight to an extra loss term which
penalizes overlap between GM components
- predictor (float): adds an extra regularizing neural network to the model,
which tries to predict the next frame from the current one
- project_name (str): ID of the current run
- callbacks (list): list of callbacks for the training loop
Returns:
- best_hparams (dict): dictionary with the best retrieved hyperparameters
- best_run (tf.keras.Model): trained instance of the best model found
"""
tensorboard_callback, cp_callback, onecycle = callbacks
if hypermodel == "S2SAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
......
......@@ -12,8 +12,10 @@ from hypothesis import given
from hypothesis import settings
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays
import deepof.model_utils
import deepof.train_utils
import numpy as np
import keras
import os
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.framework.ops import EagerTensor
......@@ -26,20 +28,66 @@ tfpl = tfp.layers
tfd = tfp.distributions
@given(encoding=st.integers(min_value=1, max_value=128))
def test_load_hparams(encoding):
params = deepof.train_utils.load_hparams(None, encoding)
assert type(params) == dict
assert params["encoding"] == encoding
def test_load_hparams():
assert type(deepof.train_utils.load_hparams(None)) == dict
assert (
type(
deepof.train_utils.load_hparams(
os.path.join("tests", "test_examples", "Others", "test_hparams.pkl")
)
)
== dict
)
def test_load_treatments():
pass
assert deepof.train_utils.load_treatments(".") is None
assert (
type(
deepof.train_utils.load_treatments(
os.path.join("tests", "test_examples", "Others")
)
)
== dict
)
def test_get_callbacks():
pass
@given(
X_train=arrays(
shape=st.tuples(st.integers(min_value=1, max_value=1000)), dtype=float
),
batch_size=st.integers(min_value=128, max_value=512),
k=st.integers(min_value=1, max_value=50),
kl_wu=st.integers(min_value=0, max_value=25),
loss=st.one_of(st.just("test_A"), st.just("test_B")),
mmd_wu=st.integers(min_value=0, max_value=25),
predictor=st.floats(min_value=0.0, max_value=1.0),
variational=st.booleans(),
)
def test_get_callbacks(
X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
):
runID, tbc, cpc, cycle1c = deepof.train_utils.get_callbacks(
X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu,
)
assert type(runID) == str
assert type(tbc) == keras.callbacks.tensorboard_v2.TensorBoard
assert type(cpc) == tf.python.keras.callbacks.ModelCheckpoint
assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
def test_tune_search():
pass
deepof.train_utils.tune_search(
train,
test,
bayopt_trials,
hypermodel,
k,
kl_wu,
loss,
mmd_wu,
overlap_loss,
predictor,
project_name,
callbacks,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment