Commit ec451a6f authored by lucas_miranda's avatar lucas_miranda
Browse files

Added better logging on tensorboard for deepof_experiments

parent 59abb003
Pipeline #88579 failed with stage
in 21 minutes and 40 seconds
...@@ -13,6 +13,7 @@ from deepof.data import * ...@@ -13,6 +13,7 @@ from deepof.data import *
from deepof.models import * from deepof.models import *
from deepof.utils import * from deepof.utils import *
from deepof.train_utils import * from deepof.train_utils import *
from tensorboard.plugins.hparams import api as hp
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Autoencoder training for DeepOF animal pose recognition" description="Autoencoder training for DeepOF animal pose recognition"
...@@ -106,6 +107,14 @@ parser.add_argument( ...@@ -106,6 +107,14 @@ parser.add_argument(
default=10, default=10,
type=int, type=int,
) )
# CLI flag selecting which (single) parameter gets logged to tensorboard.
parser.add_argument(
    "--logparam",
    "-lp",
    type=str,
    default=None,
    help=(
        "Logs the specified parameter to tensorboard. None by default. "
        "Supported values are: encoding"
    ),
)
parser.add_argument( parser.add_argument(
"--loss", "--loss",
"-l", "-l",
...@@ -202,6 +211,7 @@ hparams = args.hyperparameters ...@@ -202,6 +211,7 @@ hparams = args.hyperparameters
input_type = args.input_type input_type = args.input_type
k = args.components k = args.components
kl_wu = args.kl_warmup kl_wu = args.kl_warmup
logparam = args.logparam
loss = args.loss loss = args.loss
mmd_wu = args.mmd_warmup mmd_wu = args.mmd_warmup
overlap_loss = args.overlap_loss overlap_loss = args.overlap_loss
...@@ -237,6 +247,10 @@ assert input_type in [ ...@@ -237,6 +247,10 @@ assert input_type in [
hparams = load_hparams(hparams) hparams = load_hparams(hparams)
treatment_dict = load_treatments(train_path) treatment_dict = load_treatments(train_path)
# Map the --logparam CLI string onto a {name: value} dict that the
# tensorboard callbacks can log as a hyperparameter.
if logparam == "encoding":
    logparam = dict(encoding=encoding_size)
# noinspection PyTypeChecker # noinspection PyTypeChecker
project_coords = project( project_coords = project(
animal_ids=tuple([animal_id]), animal_ids=tuple([animal_id]),
...@@ -317,14 +331,19 @@ if not tune: ...@@ -317,14 +331,19 @@ if not tune:
tf.keras.backend.clear_session() tf.keras.backend.clear_session()
run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks( run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
X_train, X_train=X_train,
batch_size, batch_size=batch_size,
True, cp=True,
variational, variational=variational,
predictor, phenotype_class=pheno_class,
loss, predictor=predictor,
loss=loss,
logparam=logparam,
) )
# NOTE(review): a few lines above, `logparam` is replaced with the dict
# {"encoding": encoding_size} whenever it equalled "encoding", so this
# string comparison can never be True here and the HParam is never built
# — presumably the intended guard was
# `if logparam is not None and "encoding" in logparam:`; confirm.
# NOTE(review): hp.Discrete expects an iterable of values; passing the
# bare int `encoding_size` would raise a TypeError if this branch ever
# ran — hp.Discrete([encoding_size]) looks like what was meant. Also
# verify that rebinding `encoding_size` to an HParam object does not
# break the later weights-filename format that interpolates it.
if logparam == "encoding":
    encoding_size = hp.HParam("encoding_dim", hp.Discrete(encoding_size))
if not variational: if not variational:
encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape) encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape)
print(ae.summary()) print(ae.summary())
...@@ -414,7 +433,9 @@ if not tune: ...@@ -414,7 +433,9 @@ if not tune:
) )
gmvaep.save_weights( gmvaep.save_weights(
"GMVAE_loss={}_encoding={}_final_weights.h5".format(loss, encoding_size) "GMVAE_loss={}_encoding={}_run_{}_final_weights.h5".format(
loss, encoding_size, run
)
) )
# To avoid stability issues # To avoid stability issues
...@@ -426,7 +447,14 @@ else: ...@@ -426,7 +447,14 @@ else:
hyp = "S2SGMVAE" if variational else "S2SAE" hyp = "S2SGMVAE" if variational else "S2SAE"
run_ID, tensorboard_callback, onecycle = get_callbacks( run_ID, tensorboard_callback, onecycle = get_callbacks(
X_train, batch_size, False, variational, predictor, loss X_train=X_train,
batch_size=batch_size,
cp=False,
variational=variational,
phenotype_class=pheno_class,
predictor=predictor,
loss=loss,
logparam=None,
) )
best_hyperparameters, best_model = tune_search( best_hyperparameters, best_model = tune_search(
......
...@@ -68,8 +68,10 @@ def get_callbacks( ...@@ -68,8 +68,10 @@ def get_callbacks(
batch_size: int, batch_size: int,
cp: bool, cp: bool,
variational: bool, variational: bool,
phenotype_class: float,
predictor: float, predictor: float,
loss: str, loss: str,
logparam: dict = None,
) -> List[Union[Any]]: ) -> List[Union[Any]]:
"""Generates callbacks for model training, including: """Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details; - run_ID: run name, with coarse parameter details;
...@@ -77,10 +79,16 @@ def get_callbacks( ...@@ -77,10 +79,16 @@ def get_callbacks(
- cp_callback: for checkpoint saving, - cp_callback: for checkpoint saving,
- onecycle: for learning rate scheduling""" - onecycle: for learning rate scheduling"""
run_ID = "{}{}{}_{}".format( run_ID = "{}{}{}{}{}_{}".format(
("GMVAE" if variational else "AE"), ("GMVAE" if variational else "AE"),
("P" if predictor > 0 and variational else ""), ("P" if predictor > 0 and variational else ""),
("_Pheno" if phenotype_class > 0 else ""),
("_loss={}".format(loss) if variational else ""), ("_loss={}".format(loss) if variational else ""),
(
"_{}={}".format(list(logparam.keys())[0], list(logparam.values())[0])
if logparam is not None
else ""
),
datetime.now().strftime("%Y%m%d-%H%M%S"), datetime.now().strftime("%Y%m%d-%H%M%S"),
) )
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment