Commit ba5967e6 authored by lucas_miranda

Implemented autoencoder fitting as part of main module in data.py
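
The training entry point now lives on the `coordinates` class (see the
`deepof.data.coordinates.deep_unsupervised_embedding` docstring below), so a fit
can be launched from the main module rather than the standalone training script.
A minimal usage sketch, assuming `coords` is an already-built `coordinates`
instance (its construction is outside this diff, and the argument values are
illustrative only):

    # Hypothetical call: only the keyword names come from this commit's diff
    trained_models = coords.deep_unsupervised_embedding(
        predictor=0,             # weight of the forward-prediction branch
        pretrained=False,        # or, presumably, a path to previously trained weights
        save_checkpoints=False,  # skip per-epoch checkpoint files
        save_weights=True,       # new in this commit: persist the final weights
        variational=True,        # train the GMVAE rather than the plain AE
    )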

parent d79fb2d0
@@ -797,6 +797,7 @@ class coordinates:
        predictor: float = 0,
        pretrained: str = False,
        save_checkpoints: bool = False,
        save_weights: bool = True,
        variational: bool = True,
    ) -> Tuple:
        """
@@ -852,6 +853,7 @@ class coordinates:
            predictor=predictor,
            pretrained=pretrained,
            save_checkpoints=save_checkpoints,
            save_weights=save_weights,
            variational=variational,
        )
......
@@ -346,254 +346,25 @@ print("Done!")
# as many times as specified by runs
if not tune:
    # Training loop
    for run in range(runs):

        # To avoid stability issues
        tf.keras.backend.clear_session()

        run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
            X_train=X_train,
            batch_size=batch_size,
            cp=True,
            variational=variational,
            phenotype_class=pheno_class,
            predictor=predictor,
            loss=loss,
            logparam=logparam,
            outpath=output_path,
        )
        logparams = [
            hp.HParam(
                "encoding",
                hp.Discrete([2, 4, 6, 8, 12, 16]),
                display_name="encoding",
                description="encoding size dimensionality",
            ),
            hp.HParam(
                "k",
                hp.IntInterval(min_value=1, max_value=15),
                display_name="k",
                description="cluster_number",
            ),
            hp.HParam(
                "loss",
                hp.Discrete(["ELBO", "MMD", "ELBO+MMD"]),
                display_name="loss function",
                description="loss function",
            ),
            hp.HParam(
                "run",
                hp.Discrete([0, 1, 2]),
                display_name="trial run",
                description="trial run",
            ),
        ]
        rec = "reconstruction_" if pheno_class else ""
        metrics = [
            hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
            hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
        ]
        logparam["run"] = run

        if pheno_class:
            logparams.append(
                hp.HParam(
                    "pheno_weight",
                    hp.RealInterval(min_value=0.0, max_value=1000.0),
                    display_name="pheno weight",
                    description="weight applied to phenotypic classifier from the latent space",
                )
            )
            metrics += [
                hp.Metric(
                    "phenotype_prediction_accuracy",
                    display_name="phenotype_prediction_accuracy",
                ),
                hp.Metric(
                    "phenotype_prediction_auc",
                    display_name="phenotype_prediction_auc",
                ),
            ]

        with tf.summary.create_file_writer(
            os.path.join(output_path, "hparams", run_ID)
        ).as_default():
            hp.hparams_config(
                hparams=logparams,
                metrics=metrics,
            )
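
        # Note: hp.hparams_config only declares the hyperparameter domains and
        # the metrics to display in TensorBoard's HParams dashboard; the values
        # for this particular trial are logged later through hp.hparams()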
        if not variational:
            encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape)
            print(ae.summary())

            ae.save_weights(
                os.path.join(
                    output_path, "checkpoints", "cp-{epoch:04d}.ckpt".format(epoch=0)
                )
            )

            # Fit the specified model to the data
            history = ae.fit(
                x=X_train,
                y=X_train,
                epochs=35,
                batch_size=batch_size,
                verbose=1,
                validation_data=(X_val, X_val),
                callbacks=[
                    tensorboard_callback,
                    cp_callback,
                    onecycle,
                    CustomStopper(
                        monitor="val_loss",
                        patience=5,
                        restore_best_weights=True,
                        start_epoch=max(kl_wu, mmd_wu),
                    ),
                ],
            )

            ae.save_weights("{}_final_weights.h5".format(run_ID))
        else:
            (
                encoder,
                generator,
                grouper,
                gmvaep,
                kl_warmup_callback,
                mmd_warmup_callback,
            ) = SEQ_2_SEQ_GMVAE(
                architecture_hparams=hparams,
                batch_size=batch_size,
                compile_model=True,
                encoding=encoding_size,
                kl_warmup_epochs=kl_wu,
                loss=loss,
                mmd_warmup_epochs=mmd_wu,
                montecarlo_kl=mc_kl,
                neuron_control=neuron_control,
                number_of_components=k,
                overlap_loss=overlap_loss,
                phenotype_prediction=pheno_class,
                predictor=predictor,
            ).build(X_train.shape)
            print(gmvaep.summary())

            callbacks_ = [
                tensorboard_callback,
                # cp_callback,
                onecycle,
                CustomStopper(
                    monitor="val_loss",
                    patience=5,
                    restore_best_weights=True,
                    start_epoch=max(kl_wu, mmd_wu),
                ),
            ]
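
            # The warmup callbacks returned by the model builder anneal the KL
            # and MMD regularization terms in over the first epochs; attach
            # them only when the corresponding term is active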
if "ELBO" in loss and kl_wu > 0:
callbacks_.append(kl_warmup_callback)
if "MMD" in loss and mmd_wu > 0:
callbacks_.append(mmd_warmup_callback)
            Xs, ys = [X_train], [X_train]
            Xvals, yvals = [X_val], [X_val]

            if predictor > 0.0:
                # With the predictor branch active, inputs drop the last window
                # and a one-step-shifted copy is added as a second target, so
                # the model both reconstructs the current window and forecasts
                # the next one
                Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
                Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]

            if pheno_class > 0.0:
                ys += [y_train]
                yvals += [y_val]

            history = gmvaep.fit(
                x=Xs,
                y=ys,
                epochs=35,
                batch_size=batch_size,
                verbose=1,
                validation_data=(Xvals, yvals),
                callbacks=callbacks_,
            )
            gmvaep.save_weights(
                os.path.join(
                    output_path,
                    "trained_weights",
                    "GMVAE_loss={}_encoding={}_k={}_{}{}run_{}_final_weights.h5".format(
                        loss,
                        encoding_size,
                        k,
                        ("pheno={}_".format(pheno_class) if pheno_class else ""),
                        ("predictor={}_".format(predictor) if predictor else ""),
                        run,
                    ),
                )
            )
            # noinspection PyUnboundLocalVariable
            def tensorboard_metric_logging(run_dir: str, hpms: Any):
                # Defined inside the training loop so it closes over gmvaep,
                # the validation data and the current run's variables
                output = gmvaep.predict(X_val)
                if pheno_class or predictor:
                    reconstruction = output[0]
                    prediction = output[1]
                    pheno = output[-1]
                else:
                    reconstruction = output

                with tf.summary.create_file_writer(run_dir).as_default():
                    hp.hparams(hpms)  # record the values used in this trial
                    val_mae = tf.reduce_mean(
                        tf.keras.metrics.mean_absolute_error(X_val, reconstruction)
                    )
                    val_mse = tf.reduce_mean(
                        tf.keras.metrics.mean_squared_error(X_val, reconstruction)
                    )
                    tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
                    tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)

                    if predictor:
                        pred_mae = tf.reduce_mean(
                            tf.keras.metrics.mean_absolute_error(X_val, prediction)
                        )
                        pred_mse = tf.reduce_mean(
                            tf.keras.metrics.mean_squared_error(X_val, prediction)
                        )
                        tf.summary.scalar("val_prediction_mae", pred_mae, step=1)
                        tf.summary.scalar("val_prediction_mse", pred_mse, step=1)

                    if pheno_class:
                        pheno_acc = tf.keras.metrics.binary_accuracy(
                            y_val, tf.squeeze(pheno)
                        )
                        pheno_auc = roc_auc_score(y_val, pheno)
                        tf.summary.scalar(
                            "phenotype_prediction_accuracy", pheno_acc, step=1
                        )
                        tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)
            # Logparams to tensorboard
            tensorboard_metric_logging(
                os.path.join(output_path, "hparams", run_ID),
                logparam,
            )
    # To avoid stability issues
    tf.keras.backend.clear_session()

    trained_models = deep_unsupervised_embedding(
        (X_train, y_train, X_val, y_val),
        batch_size=batch_size,
        encoding_size=encoding_size,
        hparams=hparams,
        kl_warmup=kl_wu,
        log_history=True,
        log_hparams=True,
        loss=loss,
        mmd_warmup=mmd_wu,
        montecarlo_kl=mc_kl,
        n_components=k,
        output_path=output_path,
        phenotype_class=pheno_class,
        predictor=predictor,
        save_checkpoints=False,
        save_weights=True,
        variational=variational,
    )
else:
    # Runs hyperparameter tuning with the specified parameters and saves the results
......
@@ -156,6 +156,7 @@ def deep_unsupervised_embedding(
    predictor: float,
    pretrained: str,
    save_checkpoints: bool,
    save_weights: bool,
    variational: bool,
):
    """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
......
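
The body of deep_unsupervised_embedding is collapsed in this diff, so how the
new save_weights flag is consumed is not shown. Judging from the script code
removed above, it plausibly gates the final weight dump, along these lines (a
sketch under that assumption, not the actual diff contents):

    # Hypothetical sketch: gmvaep, run_ID and output_path are assumed to be in
    # scope inside deep_unsupervised_embedding, as they were in the removed
    # script code
    if save_weights:
        gmvaep.save_weights(
            os.path.join(
                output_path, "trained_weights", "{}_final_weights.h5".format(run_ID)
            )
        )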