Commit 3d50dfbe authored by lucas_miranda's avatar lucas_miranda
Browse files

Added a MirroredStrategy to train models on multiple GPUs if they are available

parent 0016131f
......@@ -262,6 +262,7 @@ class SEQ_2_SEQ_GMVAE:
rule_based_features: int = 6,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
strategy = tf.distribute.MirroredStrategy()
):
self.hparams = self.get_hparams(architecture_hparams)
self.batch_size = batch_size
......@@ -296,6 +297,7 @@ class SEQ_2_SEQ_GMVAE:
self.prior = "standard_normal"
self.reg_cat_clusters = reg_cat_clusters
self.reg_cluster_variance = reg_cluster_variance
self.strategy = strategy
assert (
"ELBO" in self.loss or "MMD" in self.loss
......@@ -576,6 +578,8 @@ class SEQ_2_SEQ_GMVAE:
Model_RC1,
) = self.get_layers(input_shape)
with self.strategy.scope():
# Define and instantiate encoder
x = Input(shape=input_shape[1:])
encoder = Model_E0(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment