Commit d866466b authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored train_utils.py

parent 8a8bd3d4
......@@ -139,6 +139,99 @@ def get_callbacks(
return callbacks
def log_hyperparameters():
"""Blueprint for hyperparameter and metric logging in tensorboard during hyperparameter tuning"""
logparams = [
hp.HParam(
"encoding",
hp.Discrete([2, 4, 6, 8, 12, 16]),
display_name="encoding",
description="encoding size dimensionality",
),
hp.HParam(
"k",
hp.IntInterval(min_value=1, max_value=25),
display_name="k",
description="cluster_number",
),
hp.HParam(
"loss",
hp.Discrete(["ELBO", "MMD", "ELBO+MMD"]),
display_name="loss function",
description="loss function",
),
]
rec = "reconstruction_" if phenotype_class else ""
metrics = [
hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
]
if phenotype_class:
logparams.append(
hp.HParam(
"pheno_weight",
hp.RealInterval(min_value=0.0, max_value=1000.0),
display_name="pheno weight",
description="weight applied to phenotypic classifier from the latent space",
)
)
metrics += [
hp.Metric(
"phenotype_prediction_accuracy",
display_name="phenotype_prediction_accuracy",
),
hp.Metric(
"phenotype_prediction_auc",
display_name="phenotype_prediction_auc",
),
]
return logparams, metrics
# noinspection PyUnboundLocalVariable
def tensorboard_metric_logging(run_dir: str, hpms: Any):
"""Autoencoder metric logging in tensorboard"""
output = ae.predict(X_val)
if phenotype_class or predictor:
reconstruction = output[0]
prediction = output[1]
pheno = output[-1]
else:
reconstruction = output
with tf.summary.create_file_writer(run_dir).as_default():
hp.hparams(hpms) # record the values used in this trial
val_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, reconstruction)
)
val_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, reconstruction)
)
tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
if predictor:
pred_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, prediction)
)
pred_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, prediction)
)
tf.summary.scalar("val_prediction_mae".format(rec), pred_mae, step=1)
tf.summary.scalar("val_prediction_mse".format(rec), pred_mse, step=1)
if phenotype_class:
pheno_acc = tf.keras.metrics.binary_accuracy(y_val, tf.squeeze(pheno))
pheno_auc = roc_auc_score(y_val, pheno)
tf.summary.scalar("phenotype_prediction_accuracy", pheno_acc, step=1)
tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)
def autoencoder_fitting(
preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
batch_size: int,
......@@ -195,51 +288,7 @@ def autoencoder_fitting(
# Logs hyperparameters to tensorboard
if log_hparams:
logparams = [
hp.HParam(
"encoding",
hp.Discrete([2, 4, 6, 8, 12, 16]),
display_name="encoding",
description="encoding size dimensionality",
),
hp.HParam(
"k",
hp.IntInterval(min_value=1, max_value=25),
display_name="k",
description="cluster_number",
),
hp.HParam(
"loss",
hp.Discrete(["ELBO", "MMD", "ELBO+MMD"]),
display_name="loss function",
description="loss function",
),
]
rec = "reconstruction_" if phenotype_class else ""
metrics = [
hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
]
if phenotype_class:
logparams.append(
hp.HParam(
"pheno_weight",
hp.RealInterval(min_value=0.0, max_value=1000.0),
display_name="pheno weight",
description="weight applied to phenotypic classifier from the latent space",
)
)
metrics += [
hp.Metric(
"phenotype_prediction_accuracy",
display_name="phenotype_prediction_accuracy",
),
hp.Metric(
"phenotype_prediction_auc",
display_name="phenotype_prediction_auc",
),
]
logparams, metrics = log_hyperparameters()
with tf.summary.create_file_writer(
os.path.join(output_path, "hparams", run_ID)
......@@ -358,54 +407,6 @@ def autoencoder_fitting(
ae.save_weights("{}_final_weights.h5".format(run_ID))
if log_hparams:
# noinspection PyUnboundLocalVariable
def tensorboard_metric_logging(run_dir: str, hpms: Any):
output = ae.predict(X_val)
if phenotype_class or predictor:
reconstruction = output[0]
prediction = output[1]
pheno = output[-1]
else:
reconstruction = output
with tf.summary.create_file_writer(run_dir).as_default():
hp.hparams(hpms) # record the values used in this trial
val_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, reconstruction)
)
val_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, reconstruction)
)
tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
if predictor:
pred_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, prediction)
)
pred_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, prediction)
)
tf.summary.scalar(
"val_prediction_mae".format(rec), pred_mae, step=1
)
tf.summary.scalar(
"val_prediction_mse".format(rec), pred_mse, step=1
)
if phenotype_class:
pheno_acc = tf.keras.metrics.binary_accuracy(
y_val, tf.squeeze(pheno)
)
pheno_auc = roc_auc_score(y_val, pheno)
tf.summary.scalar(
"phenotype_prediction_accuracy", pheno_acc, step=1
)
tf.summary.scalar(
"phenotype_prediction_auc", pheno_auc, step=1
)
# Logparams to tensorboard
tensorboard_metric_logging(
os.path.join(output_path, "hparams", run_ID),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment