From dd62bd2e44fc3c34fc3ded04bef4cc414426305d Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Fri, 12 Mar 2021 04:08:04 +0100 Subject: [PATCH] Implemented KNN_purity callback --- deepof/model_utils.py | 7 +++++-- deepof/models.py | 4 ++-- deepof/train_utils.py | 1 + tests/test_train_utils.py | 4 ++++ 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/deepof/model_utils.py b/deepof/model_utils.py index ad94d9db..209eb2d4 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -209,8 +209,11 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): """ - def __init__(self, validation_data=None, k=100, samples=10000, log_dir="."): + def __init__( + self, variational=True, validation_data=None, k=100, samples=10000, log_dir="." + ): super().__init__() + self.variational = variational self.validation_data = validation_data self.k = k self.samples = samples @@ -220,7 +223,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): def on_epoch_end(self, epoch, logs=None): """ Passes samples through the encoder and computes cluster purity on the latent embedding """ - if self.validation_data is not None: + if self.validation_data is not None and self.variational: # Get encoer and grouper from full model cluster_means = [ diff --git a/deepof/models.py b/deepof/models.py index cd4b2b42..d34f909d 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -253,7 +253,7 @@ class SEQ_2_SEQ_GMVAE: montecarlo_kl: int = 1, neuron_control: bool = False, number_of_components: int = 1, - overlap_loss: float = -1., + overlap_loss: float = -1.0, phenotype_prediction: float = 0.0, predictor: float = 0.0, reg_cat_clusters: bool = False, @@ -606,7 +606,7 @@ class SEQ_2_SEQ_GMVAE: z_gauss = deepof.model_utils.Cluster_overlap( self.ENCODING, self.number_of_components, - loss=tf.maximum(0., self.overlap_loss).numpy(), + loss=tf.maximum(0.0, self.overlap_loss).numpy(), )(z_gauss) z = tfpl.DistributionLambda( diff --git a/deepof/train_utils.py b/deepof/train_utils.py index f9e09bdf..77da7c98 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -117,6 +117,7 @@ def get_callbacks( samples=knn_samples, validation_data=X_val, log_dir=os.path.join(outpath, "metrics"), + variational=variational, ) onecycle = deepof.model_utils.one_cycle_scheduler( diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index b35a1978..f4899f71 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -117,6 +117,8 @@ def test_autoencoder_fitting( phenotype_class=pheno_class, predictor=predictor, variational=variational, + knn_neighbors=10, + knn_samples=10, ) @@ -168,6 +170,8 @@ def test_tune_search( True, True, None, + knn_neighbors=10, + knn_samples=10, ) )[1:] -- GitLab