Commit c1177f8b authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 3cef046e
Pipeline #98386 passed with stages
in 18 minutes and 42 seconds
......@@ -223,37 +223,29 @@ def tensorboard_metric_logging(
):
"""Autoencoder metric logging in tensorboard"""
output = ae.predict(X_val)
if any([next_sequence_prediction, phenotype_prediction, rule_based_prediction]):
reconstruction_pred, reconstruction_true = output[0], y_val[0]
nextseq_pred, nextseq_true = output[1], y_val[1]
pheno_pred, pheno_true = output[2], y_val[2]
rules_pred, rules_true = output[3], y_val[3]
else:
reconstruction = output
outputs = ae.predict(X_val)
idx_generator = (idx for idx in range(len(outputs)))
with tf.summary.create_file_writer(run_dir).as_default():
hp.hparams(hpms) # record the values used in this trial
idx = next(idx_generator)
val_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(
reconstruction_pred, reconstruction_true
)
tf.keras.metrics.mean_absolute_error(y_val[idx], outputs[idx])
)
val_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(
reconstruction_pred, reconstruction_true
)
tf.keras.metrics.mean_squared_error(y_val[idx], outputs[idx])
)
tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
if next_sequence_prediction:
idx = next(idx_generator)
pred_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(nextseq_pred, nextseq_true)
tf.keras.metrics.mean_absolute_error(y_val[idx], outputs[idx])
)
pred_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(nextseq_pred, nextseq_true)
tf.keras.metrics.mean_squared_error(y_val[idx], outputs[idx])
)
tf.summary.scalar(
"val_next_sequence_prediction_mae".format(rec), pred_mae, step=1
......@@ -263,22 +255,22 @@ def tensorboard_metric_logging(
)
if phenotype_prediction:
pheno_acc = tf.keras.metrics.binary_accuracy(
pheno_true, tf.squeeze(pheno_pred)
)
idx = next(idx_generator)
pheno_acc = tf.keras.metrics.binary_accuracy(y_val[idx], outputs[idx])
pheno_auc = tf.keras.metrics.AUC()
pheno_auc.update_state(pheno_true, pheno_pred)
pheno_auc.update_state(y_val[idx], outputs[idx])
pheno_auc = pheno_auc.result().numpy()
tf.summary.scalar("phenotype_prediction_accuracy", pheno_acc, step=1)
tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)
if rule_based_prediction:
idx = next(idx_generator)
rules_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(rules_pred, rules_true)
tf.keras.metrics.mean_absolute_error(y_val[idx], outputs[idx])
)
rules_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(rules_pred, rules_true)
tf.keras.metrics.mean_squared_error(y_val[idx], outputs[idx])
)
tf.summary.scalar("val_prediction_mae".format(rec), rules_mae, step=1)
tf.summary.scalar("val_prediction_mse".format(rec), rules_mse, step=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment