Commit 44e98b2c authored by lucas_miranda

Added a MirroredStrategy to train models on multiple GPUs if they are available

parent 0d565545
Pipeline #100313 canceled with stages in 21 minutes and 14 seconds
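For context, the strategy object referenced in the hunks below is not created in this excerpt. A tf.distribute.MirroredStrategy is typically instantiated once before the model is built, the model is compiled inside its scope, and the per-replica batch size is multiplied by strategy.num_replicas_in_sync to obtain the global batch size. The following is a minimal, self-contained sketch of that pattern with illustrative values and a toy stand-in model; it is not the code added by this commit.

import tensorflow as tf

# Illustrative value; the real batch size comes from the function arguments.
batch_size = 64

# Create the strategy once, then build and compile the model inside its scope so
# that variables are mirrored across all visible GPUs (falls back to CPU if none).
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

with strategy.scope():
    # Hypothetical stand-in for the real autoencoder built elsewhere in the module.
    ae = tf.keras.Sequential([tf.keras.layers.Dense(16), tf.keras.layers.Dense(8)])
    ae.compile(optimizer="adam", loss="mse")

# Each replica consumes batch_size samples per step, so datasets are batched with
# the scaled (global) batch size that appears later in this diff.
global_batch_size = batch_size * strategy.num_replicas_in_sync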
@@ -316,6 +316,19 @@ def autoencoder_fitting(
    # To avoid stability issues
    tf.keras.backend.clear_session()

    # Set options for tf.data.Datasets
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA
    )

    # Generate validation dataset for callback usage
    X_val_dataset = (
        tf.data.Dataset.from_tensor_slices(X_val)
        .with_options(options)
        .batch(batch_size)
    )

    # Defines what to log on tensorboard (useful for trying out different models)
    logparam = {
        "encoding": encoding_size,
@@ -337,7 +350,7 @@ def autoencoder_fitting(
        loss_warmup=kl_warmup,
        warmup_mode=kl_annealing_mode,
        input_type=input_type,
        X_val=(X_val if X_val.shape != (0,) else None),
        X_val=(X_val_dataset if X_val.shape != (0,) else None),
        cp=save_checkpoints,
        reg_cat_clusters=reg_cat_clusters,
        reg_cluster_variance=reg_cluster_variance,
@@ -480,26 +493,21 @@ def autoencoder_fitting(
    yvals += [y_val[-Xvals.shape[0] :]]

    # Convert data to tf.data.Dataset objects
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA
    )
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((Xs, *ys))
        tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
        .batch(batch_size * strategy.num_replicas_in_sync)
        .shuffle(buffer_size=X_train.shape[0])
        .with_options(options)
        .batch(batch_size)
    )
    val_dataset = (
        tf.data.Dataset.from_tensor_slices((Xvals, *yvals))
        tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
        .batch(batch_size * strategy.num_replicas_in_sync)
        .with_options(options)
        .batch(batch_size)
    )
    ae.fit(
        x=train_dataset,
        epochs=epochs,
        batch_size=batch_size * strategy.num_replicas_in_sync,
        verbose=1,
        validation_data=val_dataset,
        callbacks=callbacks_,
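Putting the pieces together, here is a condensed, self-contained sketch of the multi-GPU training flow this commit moves toward. The dummy data, toy model, and hyperparameter values are illustrative assumptions, not the project's actual code; note that when fit receives an already-batched tf.data.Dataset, Keras expects no separate batch_size argument.

import numpy as np
import tensorflow as tf

# Dummy data and hyperparameters (illustrative assumptions).
X_train = np.random.rand(256, 24).astype("float32")
X_val = np.random.rand(64, 24).astype("float32")
batch_size, epochs = 32, 2

strategy = tf.distribute.MirroredStrategy()
global_batch = batch_size * strategy.num_replicas_in_sync

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)

# Autoencoder targets are the inputs themselves in this toy example.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, X_train))
    .batch(global_batch)
    .with_options(options)
)
val_dataset = (
    tf.data.Dataset.from_tensor_slices((X_val, X_val))
    .batch(global_batch)
    .with_options(options)
)

with strategy.scope():
    # Hypothetical stand-in for the real autoencoder.
    ae = tf.keras.Sequential(
        [tf.keras.layers.Dense(8, activation="relu"), tf.keras.layers.Dense(24)]
    )
    ae.compile(optimizer="adam", loss="mse")

# The dataset is already batched with the global batch size, so fit is called
# without a batch_size argument.
ae.fit(train_dataset, epochs=epochs, validation_data=val_dataset, verbose=1)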