Commit 3d50dfbe authored by lucas_miranda
Browse files

Added a MirroredStrategy to train models on multiple GPUs if they are available

parent 0016131f
......@@ -262,6 +262,7 @@ class SEQ_2_SEQ_GMVAE:
rule_based_features: int = 6,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
strategy = tf.distribute.MirroredStrategy()
):
self.hparams = self.get_hparams(architecture_hparams)
self.batch_size = batch_size
......@@ -296,6 +297,7 @@ class SEQ_2_SEQ_GMVAE:
self.prior = "standard_normal"
self.reg_cat_clusters = reg_cat_clusters
self.reg_cluster_variance = reg_cluster_variance
self.strategy = strategy
assert (
"ELBO" in self.loss or "MMD" in self.loss
......@@ -576,243 +578,245 @@ class SEQ_2_SEQ_GMVAE:
Model_RC1,
) = self.get_layers(input_shape)
# Define and instantiate encoder
x = Input(shape=input_shape[1:])
encoder = Model_E0(x)
encoder = BatchNormalization()(encoder)
encoder = Model_E1(encoder)
encoder = BatchNormalization()(encoder)
encoder = Model_E2(encoder)
encoder = BatchNormalization()(encoder)
encoder = Model_E3(encoder)
encoder = BatchNormalization()(encoder)
encoder = Dropout(self.DROPOUT_RATE)(encoder)
encoder = Sequential(Model_E4)(encoder)
# encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
z_cat = Dense(
self.number_of_components,
name="cluster_assignment",
activation="softmax",
activity_regularizer=(
tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
if self.reg_cat_clusters
else None
),
)(encoder)
with self.strategy.scope():
# Define and instantiate encoder
x = Input(shape=input_shape[1:])
encoder = Model_E0(x)
encoder = BatchNormalization()(encoder)
encoder = Model_E1(encoder)
encoder = BatchNormalization()(encoder)
encoder = Model_E2(encoder)
encoder = BatchNormalization()(encoder)
encoder = Model_E3(encoder)
encoder = BatchNormalization()(encoder)
encoder = Dropout(self.DROPOUT_RATE)(encoder)
encoder = Sequential(Model_E4)(encoder)
# encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
z_cat = Dense(
self.number_of_components,
name="cluster_assignment",
activation="softmax",
activity_regularizer=(
tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
if self.reg_cat_clusters
else None
),
)(encoder)
z_gauss_mean = Dense(
tfpl.IndependentNormal.params_size(
self.ENCODING * self.number_of_components
)
// 2,
name="cluster_means",
activation=None,
kernel_initializer=Orthogonal(), # An alternative is a constant initializer with a matrix of values
# computed from the labels, we could also initialize the prior this way, and update it every N epochs
)(encoder)
z_gauss_var = Dense(
tfpl.IndependentNormal.params_size(
self.ENCODING * self.number_of_components
)
// 2,
name="cluster_variances",
activation=None,
activity_regularizer=(
tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
),
)(encoder)
z_gauss_mean = Dense(
tfpl.IndependentNormal.params_size(
self.ENCODING * self.number_of_components
)
// 2,
name="cluster_means",
activation=None,
kernel_initializer=Orthogonal(), # An alternative is a constant initializer with a matrix of values
# computed from the labels, we could also initialize the prior this way, and update it every N epochs
)(encoder)
z_gauss_var = Dense(
tfpl.IndependentNormal.params_size(
self.ENCODING * self.number_of_components
)
// 2,
name="cluster_variances",
activation=None,
activity_regularizer=(
tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
),
)(encoder)
z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
# Identity layer controlling for dead neurons in the Gaussian Mixture posterior
if self.neuron_control:
z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)
# Identity layer controlling for dead neurons in the Gaussian Mixture posterior
if self.neuron_control:
z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)
if self.overlap_loss:
z_gauss = deepof.model_utils.Cluster_overlap(
self.ENCODING,
self.number_of_components,
loss=self.overlap_loss,
)(z_gauss)
if self.overlap_loss:
z_gauss = deepof.model_utils.Cluster_overlap(
self.ENCODING,
self.number_of_components,
loss=self.overlap_loss,
)(z_gauss)
z = tfpl.DistributionLambda(
make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
cat=tfd.categorical.Categorical(
probs=gauss[0],
z = tfpl.DistributionLambda(
make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
cat=tfd.categorical.Categorical(
probs=gauss[0],
),
components=[
tfd.Independent(
tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=1e-3
+ softplus(gauss[1][..., self.ENCODING :, k])
+ 1e-5,
),
reinterpreted_batch_ndims=1,
)
for k in range(self.number_of_components)
],
),
components=[
tfd.Independent(
tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=1e-3
+ softplus(gauss[1][..., self.ENCODING :, k])
+ 1e-5,
),
reinterpreted_batch_ndims=1,
)
for k in range(self.number_of_components)
],
),
convert_to_tensor_fn="sample",
name="encoding_distribution",
)([z_cat, z_gauss])
convert_to_tensor_fn="sample",
name="encoding_distribution",
)([z_cat, z_gauss])
posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
# Define and control custom loss functions
if "ELBO" in self.loss:
kl_warm_up_iters = tf.cast(
self.kl_warmup * (input_shape[0] // self.batch_size + 1),
tf.int64,
)
# Define and control custom loss functions
if "ELBO" in self.loss:
kl_warm_up_iters = tf.cast(
self.kl_warmup * (input_shape[0] // self.batch_size + 1),
tf.int64,
)
# noinspection PyCallingNonCallable
z = deepof.model_utils.KLDivergenceLayer(
distribution_b=self.prior,
test_points_fn=lambda q: q.sample(self.mc_kl),
test_points_reduce_axis=0,
iters=self.optimizer.iterations,
warm_up_iters=kl_warm_up_iters,
annealing_mode=self.kl_annealing_mode,
)(z)
if "MMD" in self.loss:
mmd_warm_up_iters = tf.cast(
self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
tf.int64,
)
# noinspection PyCallingNonCallable
z = deepof.model_utils.KLDivergenceLayer(
distribution_b=self.prior,
test_points_fn=lambda q: q.sample(self.mc_kl),
test_points_reduce_axis=0,
iters=self.optimizer.iterations,
warm_up_iters=kl_warm_up_iters,
annealing_mode=self.kl_annealing_mode,
)(z)
if "MMD" in self.loss:
mmd_warm_up_iters = tf.cast(
self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
tf.int64,
)
z = deepof.model_utils.MMDiscrepancyLayer(
batch_size=self.batch_size,
prior=self.prior,
iters=self.optimizer.iterations,
warm_up_iters=mmd_warm_up_iters,
annealing_mode=self.mmd_annealing_mode,
)(z)
# Dummy layer with no parameters, to retrieve the previous tensor
z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
# Define and instantiate generator
g = Input(shape=self.ENCODING)
generator = Sequential(Model_D1)(g)
generator = Model_D2(generator)
generator = BatchNormalization()(generator)
generator = Model_D3(generator)
generator = Model_D4(generator)
generator = BatchNormalization()(generator)
generator = Model_D5(generator)
generator = BatchNormalization()(generator)
generator = Model_D6(generator)
generator = BatchNormalization()(generator)
x_decoded_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(generator)
x_decoded_var = tf.keras.activations.softplus(
Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
)
x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
x_decoded = tf.keras.layers.concatenate(
[x_decoded_mean, x_decoded_var], axis=-1
)
x_decoded_mean = tfpl.IndependentNormal(
event_shape=input_shape[2:],
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="vae_reconstruction",
)(x_decoded)
# define individual branches as models
encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
generator = Model(g, x_decoded_mean, name="vae_reconstruction")
def log_loss(x_true, p_x_q_given_z):
    """Negative log-likelihood of the observed data under the model output.

    Args:
        x_true: observed values, passed to ``log_prob``.
        p_x_q_given_z: model output exposing ``log_prob`` (presumably a
            TensorFlow Probability distribution; confirm against the
            output layers this loss is paired with).

    Returns:
        The negated sum of every log-probability term.
    """
    log_likelihood = p_x_q_given_z.log_prob(x_true)
    return -tf.reduce_sum(log_likelihood)
model_outs = [generator(encoder.outputs)]
model_losses = [log_loss]
model_metrics = {"vae_reconstruction": ["mae", "mse"]}
loss_weights = [1.0]
if self.next_sequence_prediction > 0:
# Define and instantiate predictor
predictor = Dense(
self.DENSE_2,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)(z)
predictor = BatchNormalization()(predictor)
predictor = Model_P1(predictor)
predictor = BatchNormalization()(predictor)
predictor = RepeatVector(input_shape[1])(predictor)
predictor = Model_P2(predictor)
predictor = BatchNormalization()(predictor)
predictor = Model_P3(predictor)
predictor = BatchNormalization()(predictor)
predictor = Model_P4(predictor)
x_predicted_mean = Dense(
z = deepof.model_utils.MMDiscrepancyLayer(
batch_size=self.batch_size,
prior=self.prior,
iters=self.optimizer.iterations,
warm_up_iters=mmd_warm_up_iters,
annealing_mode=self.mmd_annealing_mode,
)(z)
# Dummy layer with no parameters, to retrieve the previous tensor
z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
# Define and instantiate generator
g = Input(shape=self.ENCODING)
generator = Sequential(Model_D1)(g)
generator = Model_D2(generator)
generator = BatchNormalization()(generator)
generator = Model_D3(generator)
generator = Model_D4(generator)
generator = BatchNormalization()(generator)
generator = Model_D5(generator)
generator = BatchNormalization()(generator)
generator = Model_D6(generator)
generator = BatchNormalization()(generator)
x_decoded_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(predictor)
x_predicted_var = tf.keras.activations.softplus(
Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
predictor
)
)
x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
x_predicted_var
)(generator)
x_decoded_var = tf.keras.activations.softplus(
Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
)
x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
x_decoded = tf.keras.layers.concatenate(
[x_predicted_mean, x_predicted_var], axis=-1
[x_decoded_mean, x_decoded_var], axis=-1
)
x_predicted_mean = tfpl.IndependentNormal(
x_decoded_mean = tfpl.IndependentNormal(
event_shape=input_shape[2:],
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="vae_prediction",
name="vae_reconstruction",
)(x_decoded)
model_outs.append(x_predicted_mean)
model_losses.append(log_loss)
model_metrics["vae_prediction"] = ["mae", "mse"]
loss_weights.append(self.next_sequence_prediction)
if self.phenotype_prediction > 0:
pheno_pred = Model_PC1(z)
pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
pheno_pred = tfpl.IndependentBernoulli(
event_shape=1,
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="phenotype_prediction",
)(pheno_pred)
model_outs.append(pheno_pred)
model_losses.append(log_loss)
model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
loss_weights.append(self.phenotype_prediction)
if self.rule_based_prediction > 0:
rule_pred = Model_RC1(z)
rule_pred = Dense(
tfpl.IndependentBernoulli.params_size(self.rule_based_features)
)(rule_pred)
rule_pred = tfpl.IndependentBernoulli(
event_shape=self.rule_based_features,
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="rule_based_prediction",
)(rule_pred)
model_outs.append(rule_pred)
model_losses.append(log_loss)
model_metrics["rule_based_prediction"] = [
"mae",
"mse",
]
loss_weights.append(self.rule_based_prediction)
# define individual branches as models
encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
generator = Model(g, x_decoded_mean, name="vae_reconstruction")
def log_loss(x_true, p_x_q_given_z):
    """Negative log-likelihood of the observed data under the model output.

    Args:
        x_true: observed values, passed to ``log_prob``.
        p_x_q_given_z: model output exposing ``log_prob`` (presumably a
            TensorFlow Probability distribution; confirm against the
            output layers this loss is paired with).

    Returns:
        The negated sum of every log-probability term.
    """
    log_likelihood = p_x_q_given_z.log_prob(x_true)
    return -tf.reduce_sum(log_likelihood)
model_outs = [generator(encoder.outputs)]
model_losses = [log_loss]
model_metrics = {"vae_reconstruction": ["mae", "mse"]}
loss_weights = [1.0]
if self.next_sequence_prediction > 0:
# Define and instantiate predictor
predictor = Dense(
self.DENSE_2,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)(z)
predictor = BatchNormalization()(predictor)
predictor = Model_P1(predictor)
predictor = BatchNormalization()(predictor)
predictor = RepeatVector(input_shape[1])(predictor)
predictor = Model_P2(predictor)
predictor = BatchNormalization()(predictor)
predictor = Model_P3(predictor)
predictor = BatchNormalization()(predictor)
predictor = Model_P4(predictor)
x_predicted_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(predictor)
x_predicted_var = tf.keras.activations.softplus(
Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
predictor
)
)
x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
x_predicted_var
)
x_decoded = tf.keras.layers.concatenate(
[x_predicted_mean, x_predicted_var], axis=-1
)
x_predicted_mean = tfpl.IndependentNormal(
event_shape=input_shape[2:],
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="vae_prediction",
)(x_decoded)
model_outs.append(x_predicted_mean)
model_losses.append(log_loss)
model_metrics["vae_prediction"] = ["mae", "mse"]
loss_weights.append(self.next_sequence_prediction)
if self.phenotype_prediction > 0:
pheno_pred = Model_PC1(z)
pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
pheno_pred = tfpl.IndependentBernoulli(
event_shape=1,
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="phenotype_prediction",
)(pheno_pred)
model_outs.append(pheno_pred)
model_losses.append(log_loss)
model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
loss_weights.append(self.phenotype_prediction)
if self.rule_based_prediction > 0:
rule_pred = Model_RC1(z)
rule_pred = Dense(
tfpl.IndependentBernoulli.params_size(self.rule_based_features)
)(rule_pred)
rule_pred = tfpl.IndependentBernoulli(
event_shape=self.rule_based_features,
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="rule_based_prediction",
)(rule_pred)
model_outs.append(rule_pred)
model_losses.append(log_loss)
model_metrics["rule_based_prediction"] = [
"mae",
"mse",
]
loss_weights.append(self.rule_based_prediction)
# define grouper and end-to-end autoencoder model
grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment