Commit ec451a6f authored by lucas_miranda's avatar lucas_miranda
Browse files

Added better logging on tensorboard for deepof_experiments

parent 59abb003
Pipeline #88579 failed with stage
in 21 minutes and 40 seconds
......@@ -13,6 +13,7 @@ from deepof.data import *
from deepof.models import *
from deepof.utils import *
from deepof.train_utils import *
from tensorboard.plugins.hparams import api as hp
parser = argparse.ArgumentParser(
description="Autoencoder training for DeepOF animal pose recognition"
......@@ -106,6 +107,14 @@ parser.add_argument(
default=10,
type=int,
)
parser.add_argument(
"--logparam",
"-lp",
help="Logs the specified parameter to tensorboard. None by default. "
"Supported values are: encoding",
default=None,
type=str,
)
parser.add_argument(
"--loss",
"-l",
......@@ -202,6 +211,7 @@ hparams = args.hyperparameters
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
logparam = args.logparam
loss = args.loss
mmd_wu = args.mmd_warmup
overlap_loss = args.overlap_loss
......@@ -237,6 +247,10 @@ assert input_type in [
hparams = load_hparams(hparams)
treatment_dict = load_treatments(train_path)
# Logs hyperparameters if specified on the --logparam CLI argument
if logparam == "encoding":
logparam = {"encoding": encoding_size}
# noinspection PyTypeChecker
project_coords = project(
animal_ids=tuple([animal_id]),
......@@ -317,14 +331,19 @@ if not tune:
tf.keras.backend.clear_session()
run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
X_train,
batch_size,
True,
variational,
predictor,
loss,
X_train=X_train,
batch_size=batch_size,
cp=True,
variational=variational,
phenotype_class=pheno_class,
predictor=predictor,
loss=loss,
logparam=logparam,
)
if logparam == "encoding":
encoding_size = hp.HParam("encoding_dim", hp.Discrete(encoding_size))
if not variational:
encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape)
print(ae.summary())
......@@ -414,7 +433,9 @@ if not tune:
)
gmvaep.save_weights(
"GMVAE_loss={}_encoding={}_final_weights.h5".format(loss, encoding_size)
"GMVAE_loss={}_encoding={}_run_{}_final_weights.h5".format(
loss, encoding_size, run
)
)
# To avoid stability issues
......@@ -426,7 +447,14 @@ else:
hyp = "S2SGMVAE" if variational else "S2SAE"
run_ID, tensorboard_callback, onecycle = get_callbacks(
X_train, batch_size, False, variational, predictor, loss
X_train=X_train,
batch_size=batch_size,
cp=False,
variational=variational,
phenotype_class=pheno_class,
predictor=predictor,
loss=loss,
logparam=None,
)
best_hyperparameters, best_model = tune_search(
......
......@@ -68,8 +68,10 @@ def get_callbacks(
batch_size: int,
cp: bool,
variational: bool,
phenotype_class: float,
predictor: float,
loss: str,
logparam: dict = None,
) -> List[Union[Any]]:
"""Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details;
......@@ -77,10 +79,16 @@ def get_callbacks(
- cp_callback: for checkpoint saving,
- onecycle: for learning rate scheduling"""
run_ID = "{}{}{}_{}".format(
run_ID = "{}{}{}{}{}_{}".format(
("GMVAE" if variational else "AE"),
("P" if predictor > 0 and variational else ""),
("_Pheno" if phenotype_class > 0 else ""),
("_loss={}".format(loss) if variational else ""),
(
"_{}={}".format(list(logparam.keys())[0], list(logparam.values())[0])
if logparam is not None
else ""
),
datetime.now().strftime("%Y%m%d-%H%M%S"),
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment