diff --git a/deepof/train_model.py b/deepof/train_model.py index f2c69eac751d68260af9d5787c32cf3c9e3d7728..c604c5f943a7d36a6993ffbd8ab31afe9b3d6fa8 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -13,6 +13,7 @@ from deepof.data import * from deepof.models import * from deepof.utils import * from deepof.train_utils import * +from tensorboard.plugins.hparams import api as hp parser = argparse.ArgumentParser( description="Autoencoder training for DeepOF animal pose recognition" @@ -106,6 +107,14 @@ parser.add_argument( default=10, type=int, ) +parser.add_argument( + "--logparam", + "-lp", + help="Logs the specified parameter to tensorboard. None by default. " + "Supported values are: encoding", + default=None, + type=str, +) parser.add_argument( "--loss", "-l", @@ -202,6 +211,7 @@ hparams = args.hyperparameters input_type = args.input_type k = args.components kl_wu = args.kl_warmup +logparam = args.logparam loss = args.loss mmd_wu = args.mmd_warmup overlap_loss = args.overlap_loss @@ -237,6 +247,10 @@ assert input_type in [ hparams = load_hparams(hparams) treatment_dict = load_treatments(train_path) +# Logs hyperparameters if specified on the --logparam CLI argument +if logparam == "encoding": + logparam = {"encoding": encoding_size} + # noinspection PyTypeChecker project_coords = project( animal_ids=tuple([animal_id]), @@ -317,14 +331,19 @@ if not tune: tf.keras.backend.clear_session() run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks( - X_train, - batch_size, - True, - variational, - predictor, - loss, + X_train=X_train, + batch_size=batch_size, + cp=True, + variational=variational, + phenotype_class=pheno_class, + predictor=predictor, + loss=loss, + logparam=logparam, ) + if logparam == "encoding": + encoding_size = hp.HParam("encoding_dim", hp.Discrete(encoding_size)) + if not variational: encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape) print(ae.summary()) @@ -414,7 +433,9 @@ if not tune: ) gmvaep.save_weights( - "GMVAE_loss={}_encoding={}_final_weights.h5".format(loss, encoding_size) + "GMVAE_loss={}_encoding={}_run_{}_final_weights.h5".format( + loss, encoding_size, run + ) ) # To avoid stability issues @@ -426,7 +447,14 @@ else: hyp = "S2SGMVAE" if variational else "S2SAE" run_ID, tensorboard_callback, onecycle = get_callbacks( - X_train, batch_size, False, variational, predictor, loss + X_train=X_train, + batch_size=batch_size, + cp=False, + variational=variational, + phenotype_class=pheno_class, + predictor=predictor, + loss=loss, + logparam=None, ) best_hyperparameters, best_model = tune_search( diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 2b10736dec2f574650bae673864ce241c632dfb6..f8d1deee4aa223c682ed4b58602c459e71ac813b 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -68,8 +68,10 @@ def get_callbacks( batch_size: int, cp: bool, variational: bool, + phenotype_class: float, predictor: float, loss: str, + logparam: dict = None, ) -> List[Union[Any]]: """Generates callbacks for model training, including: - run_ID: run name, with coarse parameter details; @@ -77,10 +79,16 @@ def get_callbacks( - cp_callback: for checkpoint saving, - onecycle: for learning rate scheduling""" - run_ID = "{}{}{}_{}".format( + run_ID = "{}{}{}{}{}_{}".format( ("GMVAE" if variational else "AE"), ("P" if predictor > 0 and variational else ""), + ("_Pheno" if phenotype_class > 0 else ""), ("_loss={}".format(loss) if variational else ""), + ( + "_{}={}".format(list(logparam.keys())[0], list(logparam.values())[0]) + if logparam is not None + else "" + ), datetime.now().strftime("%Y%m%d-%H%M%S"), )