Commit ffd17471 authored by lucas_miranda

Integrated autoencoder training into main module

parent 19df12a5
@@ -790,7 +790,7 @@ class coordinates:
         mmd_warmup: int = 0,
         montecarlo_kl: int = 10,
         n_components: int = 25,
-        outpath: str = ".",
+        output_path: str = ".",
         phenotype_class: float = 0,
         predictor: float = 0,
         pretrained: str = False,
@@ -843,7 +843,7 @@ class coordinates:
             mmd_warmup,
             montecarlo_kl,
             n_components,
-            outpath,
+            output_path,
             phenotype_class,
             predictor,
             pretrained,
@@ -622,7 +622,7 @@ else:
         loss=loss,
         mmd_warmup_epochs=mmd_wu,
         overlap_loss=overlap_loss,
-        pheno_class=pheno_class,
+        phenotype_class=pheno_class,
         predictor=predictor,
         project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
         callbacks=[
@@ -21,9 +21,6 @@ import os
 import pickle
 import tensorflow as tf

 hp = HyperParameters()


 class CustomStopper(tf.keras.callbacks.EarlyStopping):
     """ Custom callback for """
@@ -147,7 +144,7 @@ def deep_unsupervised_embedding(
     mmd_warmup: int = 0,
     montecarlo_kl: int = 10,
     n_components: int = 25,
-    outpath: str = ".",
+    output_path: str = ".",
     phenotype_class: float = 0,
     predictor: float = 0,
     pretrained: str = False,
@@ -162,10 +159,10 @@ def deep_unsupervised_embedding(
     # To avoid stability issues
     tf.keras.backend.clear_session()

-    # defines what to log on tensorboard (useful for trying out different models)
+    # Defines what to log on tensorboard (useful for trying out different models)
     logparam = {
         "encoding": encoding_size,
-        "k": k,
+        "k": n_components,
         "loss": loss,
     }
     if phenotype_class:
@@ -184,6 +181,7 @@ def deep_unsupervised_embedding(
         outpath=output_path,
     )

+    # Logs hyperparameters to tensorboard
     logparams = [
         hp.HParam(
             "encoding",
@@ -203,12 +201,6 @@ def deep_unsupervised_embedding(
             display_name="loss function",
             description="loss function",
         ),
-        hp.HParam(
-            "run",
-            hp.Discrete([0, 1, 2]),
-            display_name="trial run",
-            description="trial run",
-        ),
     ]

     rec = "reconstruction_" if phenotype_class else ""
@@ -216,7 +208,6 @@ def deep_unsupervised_embedding(
         hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
         hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
     ]

-    logparam["run"] = run
     if phenotype_class:
         logparams.append(
             hp.HParam(
@@ -238,7 +229,7 @@ def deep_unsupervised_embedding(
     ]

     with tf.summary.create_file_writer(
-            os.path.join(output_path, "hparams", run_ID)
+        os.path.join(output_path, "hparams", run_ID)
     ).as_default():
         hp.hparams_config(
             hparams=logparams,
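For reference, the schema-registration half of the HParams workflow used in this hunk can be reproduced in isolation. In the sketch below the value domains and the log directory are illustrative assumptions, not values taken from this commit:

import os
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

logparams = [
    hp.HParam("encoding", hp.Discrete([2, 4, 6, 8]), display_name="encoding size"),
    hp.HParam("k", hp.IntInterval(1, 25), display_name="number of components"),
    hp.HParam("loss", hp.Discrete(["ELBO", "MMD", "ELBO+MMD"]), display_name="loss function"),
]
metrics = [
    hp.Metric("val_mae", display_name="val_mae"),
    hp.Metric("val_mse", display_name="val_mse"),
]

# Register the experiment-wide schema once; each trial later records its own
# values with hp.hparams() under the same directory
with tf.summary.create_file_writer(os.path.join(".", "hparams")).as_default():
    hp.hparams_config(hparams=logparams, metrics=metrics)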
@@ -302,7 +293,7 @@ def deep_unsupervised_embedding(
                 monitor="val_loss",
                 patience=5,
                 restore_best_weights=True,
-                start_epoch=max(kl_wu, mmd_wu),
+                start_epoch=max(kl_warmup, mmd_warmup),
             ),
         ],
     )
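The CustomStopper configured above extends tf.keras.callbacks.EarlyStopping with a start_epoch argument; its class body is elided in the earlier hunk. A minimal sketch consistent with this call site could look like the following, though the actual implementation in the repository may differ:

import tensorflow as tf

class CustomStopper(tf.keras.callbacks.EarlyStopping):
    """Early stopping that only becomes active after start_epoch."""

    def __init__(self, start_epoch: int = 0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_epoch = start_epoch

    def on_epoch_end(self, epoch, logs=None):
        # Ignore the warm-up epochs, during which the annealed loss terms
        # make validation loss non-comparable across epochs
        if epoch >= self.start_epoch:
            super().on_epoch_end(epoch, logs)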
@@ -311,20 +302,20 @@ def deep_unsupervised_embedding(
     callbacks_ = [
         tensorboard_callback,
-        # cp_callback,
+        cp_callback,
         onecycle,
         CustomStopper(
             monitor="val_loss",
             patience=5,
             restore_best_weights=True,
-            start_epoch=max(kl_wu, mmd_wu),
+            start_epoch=max(kl_warmup, mmd_warmup),
         ),
     ]

-    if "ELBO" in loss and kl_wu > 0:
+    if "ELBO" in loss and kl_warmup > 0:
         # noinspection PyUnboundLocalVariable
         callbacks_.append(kl_warmup_callback)
-    if "MMD" in loss and mmd_wu > 0:
+    if "MMD" in loss and mmd_warmup > 0:
         # noinspection PyUnboundLocalVariable
         callbacks_.append(mmd_warmup_callback)
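The kl_warmup_callback and mmd_warmup_callback appended here are defined elsewhere in the module (hence the PyUnboundLocalVariable suppressions). A common way to implement this kind of annealing, shown purely as an assumed sketch, is a LambdaCallback that ramps a loss-weight variable from 0 to 1 over the warm-up epochs:

import tensorflow as tf
import tensorflow.keras.backend as K

kl_warmup = 10  # illustrative number of warm-up epochs
kl_beta = K.variable(0.0, name="kl_beta")  # weight referenced by the KL loss term

kl_warmup_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: K.set_value(
        kl_beta, min(1.0, epoch / kl_warmup)
    )
)

Passing such a callback alongside CustomStopper(start_epoch=max(kl_warmup, mmd_warmup)) ensures early stopping never fires while the loss weights are still changing.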
@@ -339,7 +330,7 @@ def deep_unsupervised_embedding(
         ys += [y_train]
         yvals += [y_val]

-    gmvaep.fit(
+    ae.fit(
         x=Xs,
         y=ys,
         epochs=35,
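Xs and ys are lists here because the model can expose several output heads (reconstruction, plus the optional predictor and phenotype branches). A self-contained sketch of that multi-output fit pattern, with invented shapes and layer names:

import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
code = tf.keras.layers.Dense(2, name="encoding")(inputs)
reconstruction = tf.keras.layers.Dense(8, name="reconstruction")(code)
phenotype = tf.keras.layers.Dense(1, activation="sigmoid", name="phenotype")(code)

ae = tf.keras.Model(inputs, [reconstruction, phenotype])
ae.compile(optimizer="adam", loss=["mse", "binary_crossentropy"])

X_train = np.random.normal(size=(256, 8)).astype("float32")
y_pheno = np.random.randint(0, 2, size=(256, 1)).astype("float32")

# One target per output head, in the same order as the model's outputs
ae.fit(x=X_train, y=[X_train, y_pheno], epochs=2, batch_size=64, verbose=0)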
@@ -352,6 +343,58 @@ def deep_unsupervised_embedding(
         callbacks=callbacks_,
     )

+    # noinspection PyUnboundLocalVariable
+    def tensorboard_metric_logging(run_dir: str, hpms: Any):
+        output = gmvaep.predict(X_val)
+        if phenotype_class or predictor:
+            reconstruction = output[0]
+            prediction = output[1]
+            pheno = output[-1]
+        else:
+            reconstruction = output
+
+        with tf.summary.create_file_writer(run_dir).as_default():
+            hp.hparams(hpms)  # record the values used in this trial
+            val_mae = tf.reduce_mean(
+                tf.keras.metrics.mean_absolute_error(X_val, reconstruction)
+            )
+            val_mse = tf.reduce_mean(
+                tf.keras.metrics.mean_squared_error(X_val, reconstruction)
+            )
+            tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
+            tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
+
+            if predictor:
+                pred_mae = tf.reduce_mean(
+                    tf.keras.metrics.mean_absolute_error(X_val, prediction)
+                )
+                pred_mse = tf.reduce_mean(
+                    tf.keras.metrics.mean_squared_error(X_val, prediction)
+                )
+                tf.summary.scalar("val_prediction_mae", pred_mae, step=1)
+                tf.summary.scalar("val_prediction_mse", pred_mse, step=1)
+
+            if phenotype_class:
+                pheno_acc = tf.keras.metrics.binary_accuracy(
+                    y_val, tf.squeeze(pheno)
+                )
+                pheno_auc = roc_auc_score(y_val, pheno)
+
+                tf.summary.scalar(
+                    "phenotype_prediction_accuracy", pheno_acc, step=1
+                )
+                tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)
+
+    # Log hyperparameters and validation metrics to tensorboard
+    tensorboard_metric_logging(
+        os.path.join(output_path, "hparams", run_ID),
+        logparam,
+    )

     return return_list
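The per-trial half of the HParams pattern (what tensorboard_metric_logging does above) reduces to the sketch below; the arrays and the run directory are invented for illustration:

import numpy as np
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

X_val = np.random.normal(size=(100, 8)).astype("float32")
reconstruction = X_val + np.random.normal(scale=0.1, size=X_val.shape).astype("float32")

with tf.summary.create_file_writer("hparams/example_trial").as_default():
    hp.hparams({"encoding": 4, "k": 25, "loss": "ELBO"})  # values used in this trial
    val_mae = tf.reduce_mean(tf.keras.metrics.mean_absolute_error(X_val, reconstruction))
    val_mse = tf.reduce_mean(tf.keras.metrics.mean_squared_error(X_val, reconstruction))
    tf.summary.scalar("val_mae", val_mae, step=1)
    tf.summary.scalar("val_mse", val_mse, step=1)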
@@ -410,7 +453,7 @@ def tune_search(
     if hypermodel == "S2SAE":  # pragma: no cover
         assert (
-                predictor == 0.0 and phenotype_class == 0.0
+            predictor == 0.0 and phenotype_class == 0.0
         ), "Prediction branches are only available for variational models. See documentation for more details"

         hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)