Skip to content
Snippets Groups Projects
Commit ebe7200e authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Added phenotype classification to train_model.py

parent 00b6a59b
Branches
Tags
No related merge requests found
Pipeline #89302 failed
......@@ -514,7 +514,7 @@ if not tune:
history = gmvaep.fit(
x=Xs,
y=ys,
epochs=35,
epochs=1,
batch_size=batch_size,
verbose=1,
validation_data=(
......@@ -534,16 +534,54 @@ if not tune:
)
)
# noinspection PyUnboundLocalVariable
def tensorboard_metric_logging(run_dir: str, hpms: Any):
output = gmvaep.predict(X_val)
if pheno_class or predictor:
reconstruction = output[0]
prediction = output[1]
pheno = output[-1]
else:
reconstruction = output
with tf.summary.create_file_writer(run_dir).as_default():
hp.hparams(hpms) # record the values used in this trial
val_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, reconstruction)
)
val_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, reconstruction)
)
tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
if predictor:
pred_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, prediction)
)
pred_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, prediction)
)
tf.summary.scalar(
"val_prediction_mae".format(rec), pred_mae, step=1
)
tf.summary.scalar(
"val_prediction_mse".format(rec), pred_mse, step=1
)
if pheno_class:
pheno_acc = tf.keras.metrics.Accuracy(y_val, pheno)
pheno_auc = tf.keras.metrics.AUC(y_val, pheno)
tf.summary.scalar(
"phenotype_prediction_accuracy", pheno_acc, step=1
)
tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)
# Logparams to tensorboard
tensorboard_metric_logging(
os.path.join(output_path, "hparams", run_ID),
logparam,
X_val,
y_val,
gmvaep,
predictor,
pheno_class,
rec,
)
# To avoid stability issues
......
......@@ -137,52 +137,6 @@ def get_callbacks(
return callbacks
# noinspection PyUnboundLocalVariable
def tensorboard_metric_logging(
run_dir: str,
hpms: Any,
X: np.ndarray,
y: np.ndarray,
model: tf.python.keras.engine.functional.Functional,
predictor: float,
pheno_class: float,
rec: str,
):
output = model.predict(X)
if pheno_class or predictor:
reconstruction = output[0]
prediction = output[1]
pheno = output[-1]
else:
reconstruction = output
with tf.summary.create_file_writer(run_dir).as_default():
hp.hparams(hpms) # record the values used in this trial
val_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X, reconstruction)
)
val_mse = tf.reduce_mean(tf.keras.metrics.mean_squared_error(X, reconstruction))
tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
if predictor:
pred_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X, prediction)
)
pred_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X, prediction)
)
tf.summary.scalar("val_prediction_mae".format(rec), pred_mae, step=1)
tf.summary.scalar("val_prediction_mse".format(rec), pred_mse, step=1)
if pheno_class:
pheno_acc = tf.keras.metrics.Accuracy(y, pheno)
pheno_auc = tf.keras.metrics.AUC(y, pheno)
tf.summary.scalar("phenotype_prediction_accuracy", pheno_acc, step=1)
tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)
def tune_search(
data: List[np.array],
encoding_size: int,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment