From 02b17a2203f8176d27ff1cc98fbd0773119dd3c3 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Wed, 9 Dec 2020 21:01:54 +0100 Subject: [PATCH] Added phenotype classification rule to snakemake pipeline --- deepof/train_model.py | 62 ++++++++++++++++++++++++++++--------------- deepof/train_utils.py | 47 ++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 22 deletions(-) diff --git a/deepof/train_model.py b/deepof/train_model.py index 0b233b97..10b2a382 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -268,6 +268,8 @@ logparam = { "k": k, "loss": loss, } +if pheno_class: + logparam["pheno_weight"] = pheno_class # noinspection PyTypeChecker project_coords = project( @@ -293,7 +295,7 @@ coords = project_coords.get_coords( center=animal_id + undercond + "Center", align=animal_id + undercond + "Spine_1", align_inplace=True, - propagate_labels=(pheno_class > 0) + propagate_labels=(pheno_class > 0), ) distances = project_coords.get_distances() angles = project_coords.get_angles() @@ -388,17 +390,40 @@ if not tune: ), ] + rec = "reconstruction_" if (pheno_class) else "" + metrics = [ + hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)), + hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)), + ] logparam["run"] = run + if pheno_class: + logparams.append( + hp.HParam( + "pheno_weight", + hp.RealInterval(min_value=0, max_value=1000), + display_name="pheno weight", + description="weight applied to phenotypic classifier from the latent space", + ) + ) + metrics.append( + [ + hp.Metric( + "phenotype_prediction_accuracy", + display_name="phenotype_prediction_accuracy", + ), + hp.Metric( + "phenotype_prediction_auc", + display_name="phenotype_prediction_auc", + ), + ] + ) with tf.summary.create_file_writer( os.path.join(output_path, "hparams", run_ID) ).as_default(): hp.hparams_config( hparams=logparams, - metrics=[ - hp.Metric("val_mae", display_name="val_mae"), - hp.Metric("val_mse", display_name="val_mse"), - ], + metrics=metrics, ) if not variational: @@ -512,23 +537,16 @@ if not tune: ) # Logparams to tensorboard - def run(run_dir, hpms): - with tf.summary.create_file_writer(run_dir).as_default(): - hp.hparams(hpms) # record the values used in this trial - val_mae = tf.reduce_mean( - tf.keras.metrics.mean_absolute_error( - X_val, gmvaep.predict(X_val) - ) - ) - val_mse = tf.reduce_mean( - tf.keras.metrics.mean_squared_error( - X_val, gmvaep.predict(X_val) - ) - ) - tf.summary.scalar("val_mae", val_mae, step=1) - tf.summary.scalar("val_mse", val_mse, step=1) - - run(os.path.join(output_path, "hparams", run_ID), logparam) + tensorboard_metric_logging( + os.path.join(output_path, "hparams", run_ID), + logparam, + X_val, + y_val, + gmvaep, + predictor, + pheno_class, + rec, + ) # To avoid stability issues tf.keras.backend.clear_session() diff --git a/deepof/train_utils.py b/deepof/train_utils.py index a8ac8d9d..6f227b1d 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -12,6 +12,7 @@ from datetime import date, datetime from kerastuner import BayesianOptimization, Hyperband from kerastuner import HyperParameters from kerastuner_tensorboard_logger import TensorBoardLogger +from tensorboard.plugins.hparams import api as hp from typing import Tuple, Union, Any, List import deepof.hypermodels import deepof.model_utils @@ -136,6 +137,52 @@ def get_callbacks( return callbacks +# noinspection PyUnboundLocalVariable +def tensorboard_metric_logging( + run_dir: str, + hpms: Any, + X: np.ndarray, + y: np.ndarray, + model: tensorflow.python.keras.engine.functional.Functional, + predictor: float, + pheno_class: float, + rec: str, +): + output = model.predict(X) + if pheno_class or predictor: + reconstruction = output[0] + prediction = output[1] + pheno = output[-1] + else: + reconstruction = output + + with tf.summary.create_file_writer(run_dir).as_default(): + hp.hparams(hpms) # record the values used in this trial + val_mae = tf.reduce_mean( + tf.keras.metrics.mean_absolute_error(X, reconstruction) + ) + val_mse = tf.reduce_mean(tf.keras.metrics.mean_squared_error(X, reconstruction)) + tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1) + tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1) + + if predictor: + pred_mae = tf.reduce_mean( + tf.keras.metrics.mean_absolute_error(X, prediction) + ) + pred_mse = tf.reduce_mean( + tf.keras.metrics.mean_squared_error(X, prediction) + ) + tf.summary.scalar("val_prediction_mae".format(rec), pred_mae, step=1) + tf.summary.scalar("val_prediction_mse".format(rec), pred_mse, step=1) + + if pheno_class: + pheno_acc = tf.keras.metrics.Accuracy(y, pheno) + pheno_auc = tf.keras.metrics.AUC(y, pheno) + + tf.summary.scalar("phenotype_prediction_accuracy", pheno_acc, step=1) + tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1) + + def tune_search( data: List[np.array], encoding_size: int, -- GitLab