Skip to content
Snippets Groups Projects
Commit 3d50dfbe authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Added a MirroredStrategy to train models on multiple GPUs if they are available

parent 0016131f
No related branches found
No related tags found
No related merge requests found
...@@ -262,6 +262,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -262,6 +262,7 @@ class SEQ_2_SEQ_GMVAE:
rule_based_features: int = 6, rule_based_features: int = 6,
reg_cat_clusters: bool = False, reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False, reg_cluster_variance: bool = False,
strategy = tf.distribute.MirroredStrategy()
): ):
self.hparams = self.get_hparams(architecture_hparams) self.hparams = self.get_hparams(architecture_hparams)
self.batch_size = batch_size self.batch_size = batch_size
...@@ -296,6 +297,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -296,6 +297,7 @@ class SEQ_2_SEQ_GMVAE:
self.prior = "standard_normal" self.prior = "standard_normal"
self.reg_cat_clusters = reg_cat_clusters self.reg_cat_clusters = reg_cat_clusters
self.reg_cluster_variance = reg_cluster_variance self.reg_cluster_variance = reg_cluster_variance
self.strategy = strategy
assert ( assert (
"ELBO" in self.loss or "MMD" in self.loss "ELBO" in self.loss or "MMD" in self.loss
...@@ -576,6 +578,8 @@ class SEQ_2_SEQ_GMVAE: ...@@ -576,6 +578,8 @@ class SEQ_2_SEQ_GMVAE:
Model_RC1, Model_RC1,
) = self.get_layers(input_shape) ) = self.get_layers(input_shape)
with self.strategy.scope():
# Define and instantiate encoder # Define and instantiate encoder
x = Input(shape=input_shape[1:]) x = Input(shape=input_shape[1:])
encoder = Model_E0(x) encoder = Model_E0(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment