Commit 3d50dfbe authored by Lucas Miranda

Added a MirroredStrategy to train models on multiple GPUs if they are available

parent 0016131f
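For readers unfamiliar with tf.distribute, this is the generic Keras pattern the commit applies: variables created inside strategy.scope() are mirrored across all visible GPUs, and each training batch is split between the replicas. A minimal, self-contained sketch (illustrative layer sizes and placeholder data, not deepof code):

import tensorflow as tf

# Variables created under the scope are replicated on every visible GPU.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

with strategy.scope():
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(10,)),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer="adam", loss="mse")

# fit() can be called outside the scope; the global batch size is divided
# across the replicas automatically.
# model.fit(x_train, y_train, batch_size=256, epochs=10)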
@@ -262,6 +262,7 @@ class SEQ_2_SEQ_GMVAE:
        rule_based_features: int = 6,
        reg_cat_clusters: bool = False,
        reg_cluster_variance: bool = False,
        strategy = tf.distribute.MirroredStrategy(),
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.batch_size = batch_size
@@ -296,6 +297,7 @@ class SEQ_2_SEQ_GMVAE:
        self.prior = "standard_normal"
        self.reg_cat_clusters = reg_cat_clusters
        self.reg_cluster_variance = reg_cluster_variance
        self.strategy = strategy

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
@@ -576,243 +578,245 @@ class SEQ_2_SEQ_GMVAE:
            Model_RC1,
        ) = self.get_layers(input_shape)

        with self.strategy.scope():

            # Define and instantiate encoder
            x = Input(shape=input_shape[1:])
            encoder = Model_E0(x)
            encoder = BatchNormalization()(encoder)
            encoder = Model_E1(encoder)
            encoder = BatchNormalization()(encoder)
            encoder = Model_E2(encoder)
            encoder = BatchNormalization()(encoder)
            encoder = Model_E3(encoder)
            encoder = BatchNormalization()(encoder)
            encoder = Dropout(self.DROPOUT_RATE)(encoder)
            encoder = Sequential(Model_E4)(encoder)

            # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)

            z_cat = Dense(
                self.number_of_components,
                name="cluster_assignment",
                activation="softmax",
                activity_regularizer=(
                    tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
                    if self.reg_cat_clusters
                    else None
                ),
            )(encoder)

            z_gauss_mean = Dense(
                tfpl.IndependentNormal.params_size(
                    self.ENCODING * self.number_of_components
                )
                // 2,
                name="cluster_means",
                activation=None,
                kernel_initializer=Orthogonal(),  # An alternative is a constant initializer with a matrix of values
                # computed from the labels, we could also initialize the prior this way, and update it every N epochs
            )(encoder)

            z_gauss_var = Dense(
                tfpl.IndependentNormal.params_size(
                    self.ENCODING * self.number_of_components
                )
                // 2,
                name="cluster_variances",
                activation=None,
                activity_regularizer=(
                    tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
                ),
            )(encoder)

            z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
            z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)

            # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
            if self.neuron_control:
                z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)

            if self.overlap_loss:
                z_gauss = deepof.model_utils.Cluster_overlap(
                    self.ENCODING,
                    self.number_of_components,
                    loss=self.overlap_loss,
                )(z_gauss)

            z = tfpl.DistributionLambda(
                make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
                    cat=tfd.categorical.Categorical(
                        probs=gauss[0],
                    ),
                    components=[
                        tfd.Independent(
                            tfd.Normal(
                                loc=gauss[1][..., : self.ENCODING, k],
                                scale=1e-3
                                + softplus(gauss[1][..., self.ENCODING :, k])
                                + 1e-5,
                            ),
                            reinterpreted_batch_ndims=1,
                        )
                        for k in range(self.number_of_components)
                    ],
                ),
                convert_to_tensor_fn="sample",
                name="encoding_distribution",
            )([z_cat, z_gauss])

            posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")

            # Define and control custom loss functions
            if "ELBO" in self.loss:
                kl_warm_up_iters = tf.cast(
                    self.kl_warmup * (input_shape[0] // self.batch_size + 1),
                    tf.int64,
                )

                # noinspection PyCallingNonCallable
                z = deepof.model_utils.KLDivergenceLayer(
                    distribution_b=self.prior,
                    test_points_fn=lambda q: q.sample(self.mc_kl),
                    test_points_reduce_axis=0,
                    iters=self.optimizer.iterations,
                    warm_up_iters=kl_warm_up_iters,
                    annealing_mode=self.kl_annealing_mode,
                )(z)

            if "MMD" in self.loss:
                mmd_warm_up_iters = tf.cast(
                    self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
                    tf.int64,
                )

                z = deepof.model_utils.MMDiscrepancyLayer(
                    batch_size=self.batch_size,
                    prior=self.prior,
                    iters=self.optimizer.iterations,
                    warm_up_iters=mmd_warm_up_iters,
                    annealing_mode=self.mmd_annealing_mode,
                )(z)

            # Dummy layer with no parameters, to retrieve the previous tensor
            z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)

            # Define and instantiate generator
            g = Input(shape=self.ENCODING)
            generator = Sequential(Model_D1)(g)
            generator = Model_D2(generator)
            generator = BatchNormalization()(generator)
            generator = Model_D3(generator)
            generator = Model_D4(generator)
            generator = BatchNormalization()(generator)
            generator = Model_D5(generator)
            generator = BatchNormalization()(generator)
            generator = Model_D6(generator)
            generator = BatchNormalization()(generator)

            x_decoded_mean = Dense(
                tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
            )(generator)
            x_decoded_var = tf.keras.activations.softplus(
                Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
            )
            x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
            x_decoded = tf.keras.layers.concatenate(
                [x_decoded_mean, x_decoded_var], axis=-1
            )
            x_decoded_mean = tfpl.IndependentNormal(
                event_shape=input_shape[2:],
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="vae_reconstruction",
            )(x_decoded)

            # define individual branches as models
            encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
            generator = Model(g, x_decoded_mean, name="vae_reconstruction")

            def log_loss(x_true, p_x_q_given_z):
                """Computes the negative log likelihood of the data given
                the output distribution"""
                return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))

            model_outs = [generator(encoder.outputs)]
            model_losses = [log_loss]
            model_metrics = {"vae_reconstruction": ["mae", "mse"]}
            loss_weights = [1.0]

            if self.next_sequence_prediction > 0:
                # Define and instantiate predictor
                predictor = Dense(
                    self.DENSE_2,
                    activation=self.dense_activation,
                    kernel_initializer=he_uniform(),
                )(z)
                predictor = BatchNormalization()(predictor)
                predictor = Model_P1(predictor)
                predictor = BatchNormalization()(predictor)
                predictor = RepeatVector(input_shape[1])(predictor)
                predictor = Model_P2(predictor)
                predictor = BatchNormalization()(predictor)
                predictor = Model_P3(predictor)
                predictor = BatchNormalization()(predictor)
                predictor = Model_P4(predictor)
                x_predicted_mean = Dense(
                    tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
                )(predictor)
                x_predicted_var = tf.keras.activations.softplus(
                    Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
                        predictor
                    )
                )
                x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
                    x_predicted_var
                )
                x_decoded = tf.keras.layers.concatenate(
                    [x_predicted_mean, x_predicted_var], axis=-1
                )
                x_predicted_mean = tfpl.IndependentNormal(
                    event_shape=input_shape[2:],
                    convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                    name="vae_prediction",
                )(x_decoded)

                model_outs.append(x_predicted_mean)
                model_losses.append(log_loss)
                model_metrics["vae_prediction"] = ["mae", "mse"]
                loss_weights.append(self.next_sequence_prediction)

            if self.phenotype_prediction > 0:
                pheno_pred = Model_PC1(z)
                pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
                pheno_pred = tfpl.IndependentBernoulli(
                    event_shape=1,
                    convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                    name="phenotype_prediction",
                )(pheno_pred)

                model_outs.append(pheno_pred)
                model_losses.append(log_loss)
                model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
                loss_weights.append(self.phenotype_prediction)

            if self.rule_based_prediction > 0:
                rule_pred = Model_RC1(z)

                rule_pred = Dense(
                    tfpl.IndependentBernoulli.params_size(self.rule_based_features)
                )(rule_pred)
                rule_pred = tfpl.IndependentBernoulli(
                    event_shape=self.rule_based_features,
                    convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                    name="rule_based_prediction",
                )(rule_pred)

                model_outs.append(rule_pred)
                model_losses.append(log_loss)
                model_metrics["rule_based_prediction"] = [
                    "mae",
                    "mse",
                ]
                loss_weights.append(self.rule_based_prediction)

        # define grouper and end-to-end autoencoder model
        grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
...
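A side note on the loss wiring in the hunk above: each output head is a TensorFlow Probability distribution layer, and log_loss scores the observed data under the distribution the head returns. A stripped-down sketch of that pattern (illustrative shapes, and a simplified variance parameterization compared to the deepof code):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

event_shape = (4,)  # illustrative output dimensionality

# A Dense layer emits the distribution parameters; IndependentNormal turns them
# into a distribution object that can still be used as a Keras output tensor.
inputs = tf.keras.Input(shape=(16,))
params = tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(event_shape))(inputs)
outputs = tfpl.IndependentNormal(
    event_shape=event_shape,
    convert_to_tensor_fn=tfd.Distribution.mean,
)(params)
model = tf.keras.Model(inputs, outputs)


def log_loss(x_true, p_x_q_given_z):
    """Negative log likelihood of the data under the output distribution."""
    return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))


model.compile(optimizer="adam", loss=log_loss)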