From 23ff6d44855afed1de68e80d78b402c43a510f81 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Mon, 26 Apr 2021 16:56:08 +0200
Subject: [PATCH] Add a MirroredStrategy to train models on multiple GPUs when
 they are available

---
 deepof/data.py        |   2 +
 deepof/models.py      | 440 +++++++++++++++++++++---------------------
 deepof/train_utils.py |  67 +++----
 3 files changed, 254 insertions(+), 255 deletions(-)
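
Note: the pattern this patch applies is the standard tf.distribute recipe:
build and compile the model inside strategy.scope(), and treat the batch size
handed to fit() as the global batch size. A minimal, self-contained sketch of
that recipe (the toy model and sizes are hypothetical, not deepof code):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # one replica per visible GPU
    per_replica_batch = 64
    global_batch = per_replica_batch * strategy.num_replicas_in_sync

    with strategy.scope():
        # Variables created inside the scope are mirrored across replicas
        model = tf.keras.Sequential([
            tf.keras.Input(shape=(8,)),
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(1),
        ])
        model.compile(optimizer="adam", loss="mse")

    # model.fit(x, y, batch_size=global_batch) then shards each batch across replicas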

diff --git a/deepof/data.py b/deepof/data.py
index 5fd22106..e617de44 100644
--- a/deepof/data.py
+++ b/deepof/data.py
@@ -907,6 +907,7 @@ class coordinates:
         entropy_knn: int = 100,
         input_type: str = False,
         run: int = 0,
+        strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(),
     ) -> Tuple:
         """
         Annotates coordinates using an unsupervised autoencoder.
@@ -974,6 +975,7 @@ class coordinates:
             entropy_knn=entropy_knn,
             input_type=input_type,
             run=run,
+            strategy=strategy,
         )
 
         # returns a list of trained tensorflow models
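
Note on the new data.py argument: a default of tf.distribute.MirroredStrategy()
is constructed once, when the function definition is evaluated, so callers who
need single-device behaviour should pass an explicit strategy instead. A
hypothetical caller-side sketch ("coords" stands for an existing
deepof.data.coordinates instance; unrelated arguments are elided):

    import tensorflow as tf

    # Pin training to a single device ("/gpu:0" on a GPU machine):
    one_device = tf.distribute.OneDeviceStrategy(device="/cpu:0")

    # trained_models = coords.deep_unsupervised_embedding(..., strategy=one_device)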
diff --git a/deepof/models.py b/deepof/models.py
index 0b408f22..cb6abf02 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -262,7 +262,6 @@ class SEQ_2_SEQ_GMVAE:
         rule_based_features: int = 6,
         reg_cat_clusters: bool = False,
         reg_cluster_variance: bool = False,
-        strategy = tf.distribute.MirroredStrategy()
     ):
         self.hparams = self.get_hparams(architecture_hparams)
         self.batch_size = batch_size
@@ -297,7 +296,6 @@ class SEQ_2_SEQ_GMVAE:
         self.prior = "standard_normal"
         self.reg_cat_clusters = reg_cat_clusters
         self.reg_cluster_variance = reg_cluster_variance
-        self.strategy = strategy
 
         assert (
             "ELBO" in self.loss or "MMD" in self.loss
@@ -578,245 +576,243 @@ class SEQ_2_SEQ_GMVAE:
             Model_RC1,
         ) = self.get_layers(input_shape)
 
-        with self.strategy.scope():
-
-            # Define and instantiate encoder
-            x = Input(shape=input_shape[1:])
-            encoder = Model_E0(x)
-            encoder = BatchNormalization()(encoder)
-            encoder = Model_E1(encoder)
-            encoder = BatchNormalization()(encoder)
-            encoder = Model_E2(encoder)
-            encoder = BatchNormalization()(encoder)
-            encoder = Model_E3(encoder)
-            encoder = BatchNormalization()(encoder)
-            encoder = Dropout(self.DROPOUT_RATE)(encoder)
-            encoder = Sequential(Model_E4)(encoder)
-
-            # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
-            z_cat = Dense(
-                self.number_of_components,
-                name="cluster_assignment",
-                activation="softmax",
-                activity_regularizer=(
-                    tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
-                    if self.reg_cat_clusters
-                    else None
-                ),
-            )(encoder)
+        # Define and instantiate encoder
+        x = Input(shape=input_shape[1:])
+        encoder = Model_E0(x)
+        encoder = BatchNormalization()(encoder)
+        encoder = Model_E1(encoder)
+        encoder = BatchNormalization()(encoder)
+        encoder = Model_E2(encoder)
+        encoder = BatchNormalization()(encoder)
+        encoder = Model_E3(encoder)
+        encoder = BatchNormalization()(encoder)
+        encoder = Dropout(self.DROPOUT_RATE)(encoder)
+        encoder = Sequential(Model_E4)(encoder)
+
+        # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
+        z_cat = Dense(
+            self.number_of_components,
+            name="cluster_assignment",
+            activation="softmax",
+            activity_regularizer=(
+                tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
+                if self.reg_cat_clusters
+                else None
+            ),
+        )(encoder)
 
-            z_gauss_mean = Dense(
-                tfpl.IndependentNormal.params_size(
-                    self.ENCODING * self.number_of_components
-                )
-                // 2,
-                name="cluster_means",
-                activation=None,
-                kernel_initializer=Orthogonal(),  # An alternative is a constant initializer with a matrix of values
-                # computed from the labels, we could also initialize the prior this way, and update it every N epochs
-            )(encoder)
-
-            z_gauss_var = Dense(
-                tfpl.IndependentNormal.params_size(
-                    self.ENCODING * self.number_of_components
-                )
-                // 2,
-                name="cluster_variances",
-                activation=None,
-                activity_regularizer=(
-                    tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
-                ),
-            )(encoder)
+        z_gauss_mean = Dense(
+            tfpl.IndependentNormal.params_size(
+                self.ENCODING * self.number_of_components
+            )
+            // 2,
+            name="cluster_means",
+            activation=None,
+            kernel_initializer=Orthogonal(),  # An alternative is a constant initializer with a matrix of values
+            # computed from the labels; we could also initialize the prior this way and update it every N epochs
+        )(encoder)
+
+        z_gauss_var = Dense(
+            tfpl.IndependentNormal.params_size(
+                self.ENCODING * self.number_of_components
+            )
+            // 2,
+            name="cluster_variances",
+            activation=None,
+            activity_regularizer=(
+                tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
+            ),
+        )(encoder)
 
-            z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
+        z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
 
-            z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
+        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
 
-            # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
-            if self.neuron_control:
-                z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)
+        # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
+        if self.neuron_control:
+            z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)
 
-            if self.overlap_loss:
-                z_gauss = deepof.model_utils.Cluster_overlap(
-                    self.ENCODING,
-                    self.number_of_components,
-                    loss=self.overlap_loss,
-                )(z_gauss)
+        if self.overlap_loss:
+            z_gauss = deepof.model_utils.Cluster_overlap(
+                self.ENCODING,
+                self.number_of_components,
+                loss=self.overlap_loss,
+            )(z_gauss)
 
-            z = tfpl.DistributionLambda(
-                make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
-                    cat=tfd.categorical.Categorical(
-                        probs=gauss[0],
-                    ),
-                    components=[
-                        tfd.Independent(
-                            tfd.Normal(
-                                loc=gauss[1][..., : self.ENCODING, k],
-                                scale=1e-3
-                                + softplus(gauss[1][..., self.ENCODING :, k])
-                                + 1e-5,
-                            ),
-                            reinterpreted_batch_ndims=1,
-                        )
-                        for k in range(self.number_of_components)
-                    ],
+        z = tfpl.DistributionLambda(
+            make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
+                cat=tfd.categorical.Categorical(
+                    probs=gauss[0],
                 ),
-                convert_to_tensor_fn="sample",
-                name="encoding_distribution",
-            )([z_cat, z_gauss])
+                components=[
+                    tfd.Independent(
+                        tfd.Normal(
+                            loc=gauss[1][..., : self.ENCODING, k],
+                            scale=1e-3
+                            + softplus(gauss[1][..., self.ENCODING :, k])
+                            + 1e-5,
+                        ),
+                        reinterpreted_batch_ndims=1,
+                    )
+                    for k in range(self.number_of_components)
+                ],
+            ),
+            convert_to_tensor_fn="sample",
+            name="encoding_distribution",
+        )([z_cat, z_gauss])
 
-            posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
+        posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
 
-            # Define and control custom loss functions
-            if "ELBO" in self.loss:
-                kl_warm_up_iters = tf.cast(
-                    self.kl_warmup * (input_shape[0] // self.batch_size + 1),
-                    tf.int64,
-                )
+        # Define and control custom loss functions
+        if "ELBO" in self.loss:
+            kl_warm_up_iters = tf.cast(
+                self.kl_warmup * (input_shape[0] // self.batch_size + 1),
+                tf.int64,
+            )
 
-                # noinspection PyCallingNonCallable
-                z = deepof.model_utils.KLDivergenceLayer(
-                    distribution_b=self.prior,
-                    test_points_fn=lambda q: q.sample(self.mc_kl),
-                    test_points_reduce_axis=0,
-                    iters=self.optimizer.iterations,
-                    warm_up_iters=kl_warm_up_iters,
-                    annealing_mode=self.kl_annealing_mode,
-                )(z)
-
-            if "MMD" in self.loss:
-                mmd_warm_up_iters = tf.cast(
-                    self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
-                    tf.int64,
-                )
+            # noinspection PyCallingNonCallable
+            z = deepof.model_utils.KLDivergenceLayer(
+                distribution_b=self.prior,
+                test_points_fn=lambda q: q.sample(self.mc_kl),
+                test_points_reduce_axis=0,
+                iters=self.optimizer.iterations,
+                warm_up_iters=kl_warm_up_iters,
+                annealing_mode=self.kl_annealing_mode,
+            )(z)
+
+        if "MMD" in self.loss:
+            mmd_warm_up_iters = tf.cast(
+                self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
+                tf.int64,
+            )
 
-                z = deepof.model_utils.MMDiscrepancyLayer(
-                    batch_size=self.batch_size,
-                    prior=self.prior,
-                    iters=self.optimizer.iterations,
-                    warm_up_iters=mmd_warm_up_iters,
-                    annealing_mode=self.mmd_annealing_mode,
-                )(z)
-
-            # Dummy layer with no parameters, to retrieve the previous tensor
-            z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
-
-            # Define and instantiate generator
-            g = Input(shape=self.ENCODING)
-            generator = Sequential(Model_D1)(g)
-            generator = Model_D2(generator)
-            generator = BatchNormalization()(generator)
-            generator = Model_D3(generator)
-            generator = Model_D4(generator)
-            generator = BatchNormalization()(generator)
-            generator = Model_D5(generator)
-            generator = BatchNormalization()(generator)
-            generator = Model_D6(generator)
-            generator = BatchNormalization()(generator)
-            x_decoded_mean = Dense(
+            z = deepof.model_utils.MMDiscrepancyLayer(
+                batch_size=self.batch_size,
+                prior=self.prior,
+                iters=self.optimizer.iterations,
+                warm_up_iters=mmd_warm_up_iters,
+                annealing_mode=self.mmd_annealing_mode,
+            )(z)
+
+        # Dummy layer with no parameters, to retrieve the previous tensor
+        z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
+
+        # Define and instantiate generator
+        g = Input(shape=self.ENCODING)
+        generator = Sequential(Model_D1)(g)
+        generator = Model_D2(generator)
+        generator = BatchNormalization()(generator)
+        generator = Model_D3(generator)
+        generator = Model_D4(generator)
+        generator = BatchNormalization()(generator)
+        generator = Model_D5(generator)
+        generator = BatchNormalization()(generator)
+        generator = Model_D6(generator)
+        generator = BatchNormalization()(generator)
+        x_decoded_mean = Dense(
+            tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
+        )(generator)
+        x_decoded_var = tf.keras.activations.softplus(
+            Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
+        )
+        x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
+        x_decoded = tf.keras.layers.concatenate(
+            [x_decoded_mean, x_decoded_var], axis=-1
+        )
+        x_decoded_mean = tfpl.IndependentNormal(
+            event_shape=input_shape[2:],
+            convert_to_tensor_fn=tfp.distributions.Distribution.mean,
+            name="vae_reconstruction",
+        )(x_decoded)
+
+        # define individual branches as models
+        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
+        generator = Model(g, x_decoded_mean, name="vae_reconstruction")
+
+        def log_loss(x_true, p_x_q_given_z):
+            """Computes the negative log likelihood of the data given
+            the output distribution"""
+            return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))
+
+        model_outs = [generator(encoder.outputs)]
+        model_losses = [log_loss]
+        model_metrics = {"vae_reconstruction": ["mae", "mse"]}
+        loss_weights = [1.0]
+
+        if self.next_sequence_prediction > 0:
+            # Define and instantiate predictor
+            predictor = Dense(
+                self.DENSE_2,
+                activation=self.dense_activation,
+                kernel_initializer=he_uniform(),
+            )(z)
+            predictor = BatchNormalization()(predictor)
+            predictor = Model_P1(predictor)
+            predictor = BatchNormalization()(predictor)
+            predictor = RepeatVector(input_shape[1])(predictor)
+            predictor = Model_P2(predictor)
+            predictor = BatchNormalization()(predictor)
+            predictor = Model_P3(predictor)
+            predictor = BatchNormalization()(predictor)
+            predictor = Model_P4(predictor)
+            x_predicted_mean = Dense(
                 tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
-            )(generator)
-            x_decoded_var = tf.keras.activations.softplus(
-                Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
+            )(predictor)
+            x_predicted_var = tf.keras.activations.softplus(
+                Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
+                    predictor
+                )
+            )
+            x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
+                x_predicted_var
             )
-            x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
             x_decoded = tf.keras.layers.concatenate(
-                [x_decoded_mean, x_decoded_var], axis=-1
+                [x_predicted_mean, x_predicted_var], axis=-1
             )
-            x_decoded_mean = tfpl.IndependentNormal(
+            x_predicted_mean = tfpl.IndependentNormal(
                 event_shape=input_shape[2:],
                 convert_to_tensor_fn=tfp.distributions.Distribution.mean,
-                name="vae_reconstruction",
+                name="vae_prediction",
             )(x_decoded)
 
-            # define individual branches as models
-            encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
-            generator = Model(g, x_decoded_mean, name="vae_reconstruction")
-
-            def log_loss(x_true, p_x_q_given_z):
-                """Computes the negative log likelihood of the data given
-                the output distribution"""
-                return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))
-
-            model_outs = [generator(encoder.outputs)]
-            model_losses = [log_loss]
-            model_metrics = {"vae_reconstruction": ["mae", "mse"]}
-            loss_weights = [1.0]
-
-            if self.next_sequence_prediction > 0:
-                # Define and instantiate predictor
-                predictor = Dense(
-                    self.DENSE_2,
-                    activation=self.dense_activation,
-                    kernel_initializer=he_uniform(),
-                )(z)
-                predictor = BatchNormalization()(predictor)
-                predictor = Model_P1(predictor)
-                predictor = BatchNormalization()(predictor)
-                predictor = RepeatVector(input_shape[1])(predictor)
-                predictor = Model_P2(predictor)
-                predictor = BatchNormalization()(predictor)
-                predictor = Model_P3(predictor)
-                predictor = BatchNormalization()(predictor)
-                predictor = Model_P4(predictor)
-                x_predicted_mean = Dense(
-                    tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
-                )(predictor)
-                x_predicted_var = tf.keras.activations.softplus(
-                    Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
-                        predictor
-                    )
-                )
-                x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
-                    x_predicted_var
-                )
-                x_decoded = tf.keras.layers.concatenate(
-                    [x_predicted_mean, x_predicted_var], axis=-1
-                )
-                x_predicted_mean = tfpl.IndependentNormal(
-                    event_shape=input_shape[2:],
-                    convert_to_tensor_fn=tfp.distributions.Distribution.mean,
-                    name="vae_prediction",
-                )(x_decoded)
-
-                model_outs.append(x_predicted_mean)
-                model_losses.append(log_loss)
-                model_metrics["vae_prediction"] = ["mae", "mse"]
-                loss_weights.append(self.next_sequence_prediction)
-
-            if self.phenotype_prediction > 0:
-                pheno_pred = Model_PC1(z)
-                pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
-                pheno_pred = tfpl.IndependentBernoulli(
-                    event_shape=1,
-                    convert_to_tensor_fn=tfp.distributions.Distribution.mean,
-                    name="phenotype_prediction",
-                )(pheno_pred)
-
-                model_outs.append(pheno_pred)
-                model_losses.append(log_loss)
-                model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
-                loss_weights.append(self.phenotype_prediction)
-
-            if self.rule_based_prediction > 0:
-                rule_pred = Model_RC1(z)
-
-                rule_pred = Dense(
-                    tfpl.IndependentBernoulli.params_size(self.rule_based_features)
-                )(rule_pred)
-                rule_pred = tfpl.IndependentBernoulli(
-                    event_shape=self.rule_based_features,
-                    convert_to_tensor_fn=tfp.distributions.Distribution.mean,
-                    name="rule_based_prediction",
-                )(rule_pred)
-
-                model_outs.append(rule_pred)
-                model_losses.append(log_loss)
-                model_metrics["rule_based_prediction"] = [
-                    "mae",
-                    "mse",
-                ]
-                loss_weights.append(self.rule_based_prediction)
+            model_outs.append(x_predicted_mean)
+            model_losses.append(log_loss)
+            model_metrics["vae_prediction"] = ["mae", "mse"]
+            loss_weights.append(self.next_sequence_prediction)
+
+        if self.phenotype_prediction > 0:
+            pheno_pred = Model_PC1(z)
+            pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
+            pheno_pred = tfpl.IndependentBernoulli(
+                event_shape=1,
+                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
+                name="phenotype_prediction",
+            )(pheno_pred)
+
+            model_outs.append(pheno_pred)
+            model_losses.append(log_loss)
+            model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
+            loss_weights.append(self.phenotype_prediction)
+
+        if self.rule_based_prediction > 0:
+            rule_pred = Model_RC1(z)
+
+            rule_pred = Dense(
+                tfpl.IndependentBernoulli.params_size(self.rule_based_features)
+            )(rule_pred)
+            rule_pred = tfpl.IndependentBernoulli(
+                event_shape=self.rule_based_features,
+                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
+                name="rule_based_prediction",
+            )(rule_pred)
+
+            model_outs.append(rule_pred)
+            model_losses.append(log_loss)
+            model_metrics["rule_based_prediction"] = [
+                "mae",
+                "mse",
+            ]
+            loss_weights.append(self.rule_based_prediction)
 
         # define grouper and end-to-end autoencoder model
         grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
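
For reference, the mixture posterior that moves out of the strategy scope above
combines a categorical assignment head with K Gaussian components through a
DistributionLambda layer. A condensed, self-contained sketch of that
construction (the input width, ENCODING and K are hypothetical values;
tf.math.softplus stands in for the softplus imported in models.py):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd, tfpl = tfp.distributions, tfp.layers
    ENCODING, K = 6, 3

    inputs = tf.keras.Input(shape=(32,))
    z_cat = tf.keras.layers.Dense(K, activation="softmax")(inputs)
    z_gauss = tf.keras.layers.Dense(2 * ENCODING * K)(inputs)
    z_gauss = tf.keras.layers.Reshape([2 * ENCODING, K])(z_gauss)

    z = tfpl.DistributionLambda(
        lambda g: tfd.Mixture(
            cat=tfd.Categorical(probs=g[0]),
            components=[
                tfd.Independent(
                    tfd.Normal(
                        # first ENCODING rows parameterize the mean,
                        # the rest the (softplus-transformed) scale
                        loc=g[1][..., :ENCODING, k],
                        scale=1e-3 + tf.math.softplus(g[1][..., ENCODING:, k]),
                    ),
                    reinterpreted_batch_ndims=1,
                )
                for k in range(K)
            ],
        ),
        convert_to_tensor_fn="sample",
    )([z_cat, z_gauss])

    posterior = tf.keras.Model(inputs, z)  # samples a latent code per input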
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 9bd7be02..6829c508 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -306,6 +306,7 @@ def autoencoder_fitting(
     entropy_knn: int,
     input_type: str,
     run: int = 0,
+    strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(),
 ):
     """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
 
@@ -378,37 +379,38 @@ def autoencoder_fitting(
         return_list = (encoder, decoder, ae)
 
     else:
-        (
-            encoder,
-            generator,
-            grouper,
-            ae,
-            prior,
-            posterior,
-        ) = deepof.models.SEQ_2_SEQ_GMVAE(
-            architecture_hparams=({} if hparams is None else hparams),
-            batch_size=batch_size,
-            compile_model=True,
-            encoding=encoding_size,
-            kl_annealing_mode=kl_annealing_mode,
-            kl_warmup_epochs=kl_warmup,
-            loss=loss,
-            mmd_annealing_mode=mmd_annealing_mode,
-            mmd_warmup_epochs=mmd_warmup,
-            montecarlo_kl=montecarlo_kl,
-            neuron_control=False,
-            number_of_components=n_components,
-            overlap_loss=False,
-            next_sequence_prediction=next_sequence_prediction,
-            phenotype_prediction=phenotype_prediction,
-            rule_based_prediction=rule_based_prediction,
-            rule_based_features=rule_based_features,
-            reg_cat_clusters=reg_cat_clusters,
-            reg_cluster_variance=reg_cluster_variance,
-        ).build(
-            X_train.shape
-        )
-        return_list = (encoder, generator, grouper, ae)
+        with strategy.scope():
+            (
+                encoder,
+                generator,
+                grouper,
+                ae,
+                prior,
+                posterior,
+            ) = deepof.models.SEQ_2_SEQ_GMVAE(
+                architecture_hparams=({} if hparams is None else hparams),
+                batch_size=batch_size * strategy.num_replicas_in_sync,
+                compile_model=True,
+                encoding=encoding_size,
+                kl_annealing_mode=kl_annealing_mode,
+                kl_warmup_epochs=kl_warmup,
+                loss=loss,
+                mmd_annealing_mode=mmd_annealing_mode,
+                mmd_warmup_epochs=mmd_warmup,
+                montecarlo_kl=montecarlo_kl,
+                neuron_control=False,
+                number_of_components=n_components,
+                overlap_loss=False,
+                next_sequence_prediction=next_sequence_prediction,
+                phenotype_prediction=phenotype_prediction,
+                rule_based_prediction=rule_based_prediction,
+                rule_based_features=rule_based_features,
+                reg_cat_clusters=reg_cat_clusters,
+                reg_cluster_variance=reg_cluster_variance,
+            ).build(
+                X_train.shape
+            )
+            return_list = (encoder, generator, grouper, ae)
 
     if pretrained:
         # If pretrained models are specified, load weights and return
@@ -422,7 +424,6 @@ def autoencoder_fitting(
                 x=X_train,
                 y=X_train,
                 epochs=epochs,
-                batch_size=batch_size,
                 verbose=1,
                 validation_data=(X_val, X_val),
                 callbacks=cbacks
@@ -482,7 +483,7 @@ def autoencoder_fitting(
                 x=Xs,
                 y=ys,
                 epochs=epochs,
-                batch_size=batch_size,
+                batch_size=batch_size * strategy.num_replicas_in_sync,
                 verbose=1,
                 validation_data=(
                     Xvals,
-- 
GitLab
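
With MirroredStrategy, Keras treats the batch_size passed to fit() as the
global batch size and splits each batch evenly across replicas, so the
batch_size * strategy.num_replicas_in_sync scaling above keeps the per-GPU
batch constant as devices are added. A worked sketch with hypothetical numbers:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # e.g. 4 replicas on 4 GPUs
    batch_size = 64                              # intended per-replica batch
    global_batch = batch_size * strategy.num_replicas_in_sync  # 256 on 4 GPUs

    # Each replica then receives global_batch / num_replicas_in_sync = 64
    # examples per step, so per-device memory use does not grow with GPU count.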