From 23ff6d44855afed1de68e80d78b402c43a510f81 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Mon, 26 Apr 2021 16:56:08 +0200 Subject: [PATCH] Added a MirroredStrategy to train models on multiple GPUs if they are available --- deepof/data.py | 2 + deepof/models.py | 440 +++++++++++++++++++++--------------------- deepof/train_utils.py | 67 +++---- 3 files changed, 254 insertions(+), 255 deletions(-) diff --git a/deepof/data.py b/deepof/data.py index 5fd22106..e617de44 100644 --- a/deepof/data.py +++ b/deepof/data.py @@ -907,6 +907,7 @@ class coordinates: entropy_knn: int = 100, input_type: str = False, run: int = 0, + strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(), ) -> Tuple: """ Annotates coordinates using an unsupervised autoencoder. @@ -974,6 +975,7 @@ class coordinates: entropy_knn=entropy_knn, input_type=input_type, run=run, + strategy=strategy, ) # returns a list of trained tensorflow models diff --git a/deepof/models.py b/deepof/models.py index 0b408f22..cb6abf02 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -262,7 +262,6 @@ class SEQ_2_SEQ_GMVAE: rule_based_features: int = 6, reg_cat_clusters: bool = False, reg_cluster_variance: bool = False, - strategy = tf.distribute.MirroredStrategy() ): self.hparams = self.get_hparams(architecture_hparams) self.batch_size = batch_size @@ -297,7 +296,6 @@ class SEQ_2_SEQ_GMVAE: self.prior = "standard_normal" self.reg_cat_clusters = reg_cat_clusters self.reg_cluster_variance = reg_cluster_variance - self.strategy = strategy assert ( "ELBO" in self.loss or "MMD" in self.loss @@ -578,245 +576,243 @@ class SEQ_2_SEQ_GMVAE: Model_RC1, ) = self.get_layers(input_shape) - with self.strategy.scope(): - - # Define and instantiate encoder - x = Input(shape=input_shape[1:]) - encoder = Model_E0(x) - encoder = BatchNormalization()(encoder) - encoder = Model_E1(encoder) - encoder = BatchNormalization()(encoder) - encoder = Model_E2(encoder) - encoder = BatchNormalization()(encoder) - encoder = Model_E3(encoder) - encoder = BatchNormalization()(encoder) - encoder = Dropout(self.DROPOUT_RATE)(encoder) - encoder = Sequential(Model_E4)(encoder) - - # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder) - z_cat = Dense( - self.number_of_components, - name="cluster_assignment", - activation="softmax", - activity_regularizer=( - tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01) - if self.reg_cat_clusters - else None - ), - )(encoder) + # Define and instantiate encoder + x = Input(shape=input_shape[1:]) + encoder = Model_E0(x) + encoder = BatchNormalization()(encoder) + encoder = Model_E1(encoder) + encoder = BatchNormalization()(encoder) + encoder = Model_E2(encoder) + encoder = BatchNormalization()(encoder) + encoder = Model_E3(encoder) + encoder = BatchNormalization()(encoder) + encoder = Dropout(self.DROPOUT_RATE)(encoder) + encoder = Sequential(Model_E4)(encoder) + + # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder) + z_cat = Dense( + self.number_of_components, + name="cluster_assignment", + activation="softmax", + activity_regularizer=( + tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01) + if self.reg_cat_clusters + else None + ), + )(encoder) - z_gauss_mean = Dense( - tfpl.IndependentNormal.params_size( - self.ENCODING * self.number_of_components - ) - // 2, - name="cluster_means", - activation=None, - kernel_initializer=Orthogonal(), # An alternative is a constant initializer with a matrix of values - # computed from the labels, we could also initialize the 
prior this way, and update it every N epochs - )(encoder) - - z_gauss_var = Dense( - tfpl.IndependentNormal.params_size( - self.ENCODING * self.number_of_components - ) - // 2, - name="cluster_variances", - activation=None, - activity_regularizer=( - tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None - ), - )(encoder) + z_gauss_mean = Dense( + tfpl.IndependentNormal.params_size( + self.ENCODING * self.number_of_components + ) + // 2, + name="cluster_means", + activation=None, + kernel_initializer=Orthogonal(), # An alternative is a constant initializer with a matrix of values + # computed from the labels, we could also initialize the prior this way, and update it every N epochs + )(encoder) + + z_gauss_var = Dense( + tfpl.IndependentNormal.params_size( + self.ENCODING * self.number_of_components + ) + // 2, + name="cluster_variances", + activation=None, + activity_regularizer=( + tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None + ), + )(encoder) - z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1) + z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1) - z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss) + z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss) - # Identity layer controlling for dead neurons in the Gaussian Mixture posterior - if self.neuron_control: - z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss) + # Identity layer controlling for dead neurons in the Gaussian Mixture posterior + if self.neuron_control: + z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss) - if self.overlap_loss: - z_gauss = deepof.model_utils.Cluster_overlap( - self.ENCODING, - self.number_of_components, - loss=self.overlap_loss, - )(z_gauss) + if self.overlap_loss: + z_gauss = deepof.model_utils.Cluster_overlap( + self.ENCODING, + self.number_of_components, + loss=self.overlap_loss, + )(z_gauss) - z = tfpl.DistributionLambda( - make_distribution_fn=lambda gauss: tfd.mixture.Mixture( - cat=tfd.categorical.Categorical( - probs=gauss[0], - ), - components=[ - tfd.Independent( - tfd.Normal( - loc=gauss[1][..., : self.ENCODING, k], - scale=1e-3 - + softplus(gauss[1][..., self.ENCODING :, k]) - + 1e-5, - ), - reinterpreted_batch_ndims=1, - ) - for k in range(self.number_of_components) - ], + z = tfpl.DistributionLambda( + make_distribution_fn=lambda gauss: tfd.mixture.Mixture( + cat=tfd.categorical.Categorical( + probs=gauss[0], ), - convert_to_tensor_fn="sample", - name="encoding_distribution", - )([z_cat, z_gauss]) + components=[ + tfd.Independent( + tfd.Normal( + loc=gauss[1][..., : self.ENCODING, k], + scale=1e-3 + + softplus(gauss[1][..., self.ENCODING :, k]) + + 1e-5, + ), + reinterpreted_batch_ndims=1, + ) + for k in range(self.number_of_components) + ], + ), + convert_to_tensor_fn="sample", + name="encoding_distribution", + )([z_cat, z_gauss]) - posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution") + posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution") - # Define and control custom loss functions - if "ELBO" in self.loss: - kl_warm_up_iters = tf.cast( - self.kl_warmup * (input_shape[0] // self.batch_size + 1), - tf.int64, - ) + # Define and control custom loss functions + if "ELBO" in self.loss: + kl_warm_up_iters = tf.cast( + self.kl_warmup * (input_shape[0] // self.batch_size + 1), + tf.int64, + ) - # noinspection PyCallingNonCallable - z = deepof.model_utils.KLDivergenceLayer( - distribution_b=self.prior, - test_points_fn=lambda q: 
q.sample(self.mc_kl), - test_points_reduce_axis=0, - iters=self.optimizer.iterations, - warm_up_iters=kl_warm_up_iters, - annealing_mode=self.kl_annealing_mode, - )(z) - - if "MMD" in self.loss: - mmd_warm_up_iters = tf.cast( - self.mmd_warmup * (input_shape[0] // self.batch_size + 1), - tf.int64, - ) + # noinspection PyCallingNonCallable + z = deepof.model_utils.KLDivergenceLayer( + distribution_b=self.prior, + test_points_fn=lambda q: q.sample(self.mc_kl), + test_points_reduce_axis=0, + iters=self.optimizer.iterations, + warm_up_iters=kl_warm_up_iters, + annealing_mode=self.kl_annealing_mode, + )(z) + + if "MMD" in self.loss: + mmd_warm_up_iters = tf.cast( + self.mmd_warmup * (input_shape[0] // self.batch_size + 1), + tf.int64, + ) - z = deepof.model_utils.MMDiscrepancyLayer( - batch_size=self.batch_size, - prior=self.prior, - iters=self.optimizer.iterations, - warm_up_iters=mmd_warm_up_iters, - annealing_mode=self.mmd_annealing_mode, - )(z) - - # Dummy layer with no parameters, to retrieve the previous tensor - z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z) - - # Define and instantiate generator - g = Input(shape=self.ENCODING) - generator = Sequential(Model_D1)(g) - generator = Model_D2(generator) - generator = BatchNormalization()(generator) - generator = Model_D3(generator) - generator = Model_D4(generator) - generator = BatchNormalization()(generator) - generator = Model_D5(generator) - generator = BatchNormalization()(generator) - generator = Model_D6(generator) - generator = BatchNormalization()(generator) - x_decoded_mean = Dense( + z = deepof.model_utils.MMDiscrepancyLayer( + batch_size=self.batch_size, + prior=self.prior, + iters=self.optimizer.iterations, + warm_up_iters=mmd_warm_up_iters, + annealing_mode=self.mmd_annealing_mode, + )(z) + + # Dummy layer with no parameters, to retrieve the previous tensor + z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z) + + # Define and instantiate generator + g = Input(shape=self.ENCODING) + generator = Sequential(Model_D1)(g) + generator = Model_D2(generator) + generator = BatchNormalization()(generator) + generator = Model_D3(generator) + generator = Model_D4(generator) + generator = BatchNormalization()(generator) + generator = Model_D5(generator) + generator = BatchNormalization()(generator) + generator = Model_D6(generator) + generator = BatchNormalization()(generator) + x_decoded_mean = Dense( + tfpl.IndependentNormal.params_size(input_shape[2:]) // 2 + )(generator) + x_decoded_var = tf.keras.activations.softplus( + Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator) + ) + x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var) + x_decoded = tf.keras.layers.concatenate( + [x_decoded_mean, x_decoded_var], axis=-1 + ) + x_decoded_mean = tfpl.IndependentNormal( + event_shape=input_shape[2:], + convert_to_tensor_fn=tfp.distributions.Distribution.mean, + name="vae_reconstruction", + )(x_decoded) + + # define individual branches as models + encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder") + generator = Model(g, x_decoded_mean, name="vae_reconstruction") + + def log_loss(x_true, p_x_q_given_z): + """Computes the negative log likelihood of the data given + the output distribution""" + return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true)) + + model_outs = [generator(encoder.outputs)] + model_losses = [log_loss] + model_metrics = {"vae_reconstruction": ["mae", "mse"]} + loss_weights = [1.0] + + if self.next_sequence_prediction > 0: + # Define and instantiate 
predictor + predictor = Dense( + self.DENSE_2, + activation=self.dense_activation, + kernel_initializer=he_uniform(), + )(z) + predictor = BatchNormalization()(predictor) + predictor = Model_P1(predictor) + predictor = BatchNormalization()(predictor) + predictor = RepeatVector(input_shape[1])(predictor) + predictor = Model_P2(predictor) + predictor = BatchNormalization()(predictor) + predictor = Model_P3(predictor) + predictor = BatchNormalization()(predictor) + predictor = Model_P4(predictor) + x_predicted_mean = Dense( tfpl.IndependentNormal.params_size(input_shape[2:]) // 2 - )(generator) - x_decoded_var = tf.keras.activations.softplus( - Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator) + )(predictor) + x_predicted_var = tf.keras.activations.softplus( + Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)( + predictor + ) + ) + x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)( + x_predicted_var ) - x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var) x_decoded = tf.keras.layers.concatenate( - [x_decoded_mean, x_decoded_var], axis=-1 + [x_predicted_mean, x_predicted_var], axis=-1 ) - x_decoded_mean = tfpl.IndependentNormal( + x_predicted_mean = tfpl.IndependentNormal( event_shape=input_shape[2:], convert_to_tensor_fn=tfp.distributions.Distribution.mean, - name="vae_reconstruction", + name="vae_prediction", )(x_decoded) - # define individual branches as models - encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder") - generator = Model(g, x_decoded_mean, name="vae_reconstruction") - - def log_loss(x_true, p_x_q_given_z): - """Computes the negative log likelihood of the data given - the output distribution""" - return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true)) - - model_outs = [generator(encoder.outputs)] - model_losses = [log_loss] - model_metrics = {"vae_reconstruction": ["mae", "mse"]} - loss_weights = [1.0] - - if self.next_sequence_prediction > 0: - # Define and instantiate predictor - predictor = Dense( - self.DENSE_2, - activation=self.dense_activation, - kernel_initializer=he_uniform(), - )(z) - predictor = BatchNormalization()(predictor) - predictor = Model_P1(predictor) - predictor = BatchNormalization()(predictor) - predictor = RepeatVector(input_shape[1])(predictor) - predictor = Model_P2(predictor) - predictor = BatchNormalization()(predictor) - predictor = Model_P3(predictor) - predictor = BatchNormalization()(predictor) - predictor = Model_P4(predictor) - x_predicted_mean = Dense( - tfpl.IndependentNormal.params_size(input_shape[2:]) // 2 - )(predictor) - x_predicted_var = tf.keras.activations.softplus( - Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)( - predictor - ) - ) - x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)( - x_predicted_var - ) - x_decoded = tf.keras.layers.concatenate( - [x_predicted_mean, x_predicted_var], axis=-1 - ) - x_predicted_mean = tfpl.IndependentNormal( - event_shape=input_shape[2:], - convert_to_tensor_fn=tfp.distributions.Distribution.mean, - name="vae_prediction", - )(x_decoded) - - model_outs.append(x_predicted_mean) - model_losses.append(log_loss) - model_metrics["vae_prediction"] = ["mae", "mse"] - loss_weights.append(self.next_sequence_prediction) - - if self.phenotype_prediction > 0: - pheno_pred = Model_PC1(z) - pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred) - pheno_pred = tfpl.IndependentBernoulli( - event_shape=1, - convert_to_tensor_fn=tfp.distributions.Distribution.mean, - name="phenotype_prediction", - 
)(pheno_pred) - - model_outs.append(pheno_pred) - model_losses.append(log_loss) - model_metrics["phenotype_prediction"] = ["AUC", "accuracy"] - loss_weights.append(self.phenotype_prediction) - - if self.rule_based_prediction > 0: - rule_pred = Model_RC1(z) - - rule_pred = Dense( - tfpl.IndependentBernoulli.params_size(self.rule_based_features) - )(rule_pred) - rule_pred = tfpl.IndependentBernoulli( - event_shape=self.rule_based_features, - convert_to_tensor_fn=tfp.distributions.Distribution.mean, - name="rule_based_prediction", - )(rule_pred) - - model_outs.append(rule_pred) - model_losses.append(log_loss) - model_metrics["rule_based_prediction"] = [ - "mae", - "mse", - ] - loss_weights.append(self.rule_based_prediction) + model_outs.append(x_predicted_mean) + model_losses.append(log_loss) + model_metrics["vae_prediction"] = ["mae", "mse"] + loss_weights.append(self.next_sequence_prediction) + + if self.phenotype_prediction > 0: + pheno_pred = Model_PC1(z) + pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred) + pheno_pred = tfpl.IndependentBernoulli( + event_shape=1, + convert_to_tensor_fn=tfp.distributions.Distribution.mean, + name="phenotype_prediction", + )(pheno_pred) + + model_outs.append(pheno_pred) + model_losses.append(log_loss) + model_metrics["phenotype_prediction"] = ["AUC", "accuracy"] + loss_weights.append(self.phenotype_prediction) + + if self.rule_based_prediction > 0: + rule_pred = Model_RC1(z) + + rule_pred = Dense( + tfpl.IndependentBernoulli.params_size(self.rule_based_features) + )(rule_pred) + rule_pred = tfpl.IndependentBernoulli( + event_shape=self.rule_based_features, + convert_to_tensor_fn=tfp.distributions.Distribution.mean, + name="rule_based_prediction", + )(rule_pred) + + model_outs.append(rule_pred) + model_losses.append(log_loss) + model_metrics["rule_based_prediction"] = [ + "mae", + "mse", + ] + loss_weights.append(self.rule_based_prediction) # define grouper and end-to-end autoencoder model grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering") diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 9bd7be02..6829c508 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -306,6 +306,7 @@ def autoencoder_fitting( entropy_knn: int, input_type: str, run: int = 0, + strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(), ): """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding""" @@ -378,37 +379,38 @@ def autoencoder_fitting( return_list = (encoder, decoder, ae) else: - ( - encoder, - generator, - grouper, - ae, - prior, - posterior, - ) = deepof.models.SEQ_2_SEQ_GMVAE( - architecture_hparams=({} if hparams is None else hparams), - batch_size=batch_size, - compile_model=True, - encoding=encoding_size, - kl_annealing_mode=kl_annealing_mode, - kl_warmup_epochs=kl_warmup, - loss=loss, - mmd_annealing_mode=mmd_annealing_mode, - mmd_warmup_epochs=mmd_warmup, - montecarlo_kl=montecarlo_kl, - neuron_control=False, - number_of_components=n_components, - overlap_loss=False, - next_sequence_prediction=next_sequence_prediction, - phenotype_prediction=phenotype_prediction, - rule_based_prediction=rule_based_prediction, - rule_based_features=rule_based_features, - reg_cat_clusters=reg_cat_clusters, - reg_cluster_variance=reg_cluster_variance, - ).build( - X_train.shape - ) - return_list = (encoder, generator, grouper, ae) + with strategy.scope(): + ( + encoder, + generator, + grouper, + ae, + prior, + posterior, + ) = deepof.models.SEQ_2_SEQ_GMVAE( + 
architecture_hparams=({} if hparams is None else hparams), + batch_size=batch_size * strategy.num_replicas_in_sync, + compile_model=True, + encoding=encoding_size, + kl_annealing_mode=kl_annealing_mode, + kl_warmup_epochs=kl_warmup, + loss=loss, + mmd_annealing_mode=mmd_annealing_mode, + mmd_warmup_epochs=mmd_warmup, + montecarlo_kl=montecarlo_kl, + neuron_control=False, + number_of_components=n_components, + overlap_loss=False, + next_sequence_prediction=next_sequence_prediction, + phenotype_prediction=phenotype_prediction, + rule_based_prediction=rule_based_prediction, + rule_based_features=rule_based_features, + reg_cat_clusters=reg_cat_clusters, + reg_cluster_variance=reg_cluster_variance, + ).build( + X_train.shape + ) + return_list = (encoder, generator, grouper, ae) if pretrained: # If pretrained models are specified, load weights and return @@ -422,7 +424,6 @@ def autoencoder_fitting( x=X_train, y=X_train, epochs=epochs, - batch_size=batch_size, verbose=1, validation_data=(X_val, X_val), callbacks=cbacks @@ -482,7 +483,7 @@ def autoencoder_fitting( x=Xs, y=ys, epochs=epochs, - batch_size=batch_size, + batch_size=batch_size * strategy.num_replicas_in_sync, verbose=1, validation_data=( Xvals, -- GitLab
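
For reference, below is a minimal, self-contained sketch of the tf.distribute.MirroredStrategy pattern this patch applies in deepof/train_utils.py: the model is built and compiled inside strategy.scope(), and the global batch size is multiplied by strategy.num_replicas_in_sync so each replica still sees the original per-replica batch. The toy model, shapes, and data here are illustrative placeholders only, not deepof's SEQ_2_SEQ_GMVAE.

import numpy as np
import tensorflow as tf

# One replica per visible GPU; falls back to a single replica when no GPU is available.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

# Model creation and compilation happen under the strategy scope,
# mirroring what autoencoder_fitting() now does around SEQ_2_SEQ_GMVAE().build().
with strategy.scope():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(64, activation="relu", input_shape=(16,)),
            tf.keras.layers.Dense(16),
        ]
    )
    model.compile(optimizer="adam", loss="mse")

# Scale the global batch size by the replica count, as the patch does with
# batch_size * strategy.num_replicas_in_sync.
per_replica_batch = 64
global_batch = per_replica_batch * strategy.num_replicas_in_sync

X = np.random.rand(1024, 16).astype("float32")
model.fit(X, X, epochs=2, batch_size=global_batch, verbose=1)

Scaling the batch size this way keeps the effective per-GPU workload constant: MirroredStrategy splits each global batch evenly across replicas, so without the multiplication each GPU would train on a smaller batch than the single-device baseline.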