Commit fbed44fa authored by lucas_miranda

Added a MirroredStrategy to train models on multiple GPUs if they are available

parent 5cceccdb
Pipeline #100310 canceled with stages in 5 minutes and 39 seconds
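
For context: tf.distribute.MirroredStrategy performs synchronous data-parallel training, replicating the model on every visible GPU and aggregating gradients across replicas at each step. A minimal sketch of the usual Keras pattern (the toy model below is illustrative, not part of this commit):

```python
import tensorflow as tf

# MirroredStrategy discovers all visible GPUs; with none available it
# falls back to a single replica on the CPU.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

# Variables (model weights, optimizer slots) must be created inside
# the scope so that they are mirrored across replicas.
with strategy.scope():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer="adam", loss="mse")
```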
@@ -907,6 +907,7 @@ class coordinates:
         entropy_knn: int = 100,
         input_type: str = False,
         run: int = 0,
+        strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(),
     ) -> Tuple:
         """
         Annotates coordinates using an unsupervised autoencoder.
@@ -974,6 +975,7 @@ class coordinates:
             entropy_knn=entropy_knn,
             input_type=input_type,
             run=run,
+            strategy=strategy,
         )
         # returns a list of trained tensorflow models
...
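
One caveat worth noting about the new signature: Python evaluates default arguments once, when the function is defined, so a `tf.distribute.MirroredStrategy()` default is constructed at import time and immediately initializes the visible devices. A common alternative, sketched below under the assumption that lazy construction is acceptable (this is not what the commit does), defaults to None and builds the strategy on first use:

```python
from typing import Optional

import tensorflow as tf


def autoencoder_fitting_sketch(
    run: int = 0,
    strategy: Optional[tf.distribute.Strategy] = None,
):
    # Hypothetical variant: the strategy is created only when the
    # function actually runs, not when the module is imported.
    if strategy is None:
        strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        ...  # build and train the model here
```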
@@ -306,6 +306,7 @@ def autoencoder_fitting(
     entropy_knn: int,
     input_type: str,
     run: int = 0,
+    strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(),
 ):
     """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
@@ -378,38 +379,38 @@ def autoencoder_fitting(
         return_list = (encoder, decoder, ae)

     else:
+        with strategy.scope():
             (
                 encoder,
                 generator,
                 grouper,
                 ae,
                 prior,
                 posterior,
             ) = deepof.models.SEQ_2_SEQ_GMVAE(
                 architecture_hparams=({} if hparams is None else hparams),
-                batch_size=batch_size,
+                batch_size=batch_size * strategy.num_replicas_in_sync,
                 compile_model=True,
                 encoding=encoding_size,
                 kl_annealing_mode=kl_annealing_mode,
                 kl_warmup_epochs=kl_warmup,
                 loss=loss,
                 mmd_annealing_mode=mmd_annealing_mode,
                 mmd_warmup_epochs=mmd_warmup,
                 montecarlo_kl=montecarlo_kl,
                 neuron_control=False,
                 number_of_components=n_components,
                 overlap_loss=False,
                 next_sequence_prediction=next_sequence_prediction,
                 phenotype_prediction=phenotype_prediction,
                 rule_based_prediction=rule_based_prediction,
                 rule_based_features=rule_based_features,
                 reg_cat_clusters=reg_cat_clusters,
                 reg_cluster_variance=reg_cluster_variance,
             ).build(
                 X_train.shape
             )
             return_list = (encoder, generator, grouper, ae)

     if pretrained:
         # If pretrained models are specified, load weights and return
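
The `batch_size * strategy.num_replicas_in_sync` scaling follows the usual tf.distribute convention: the value handed to the model is a global batch size that the strategy splits evenly across replicas, so each GPU still processes the originally configured `batch_size` examples per step. A short worked sketch of the arithmetic (numbers are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
per_replica_batch = 256  # the batch size the user configured

# The global batch is what the input pipeline should produce; the
# strategy slices it so each replica sees per_replica_batch examples.
global_batch = per_replica_batch * strategy.num_replicas_in_sync
print(global_batch)  # e.g. 1024 on a 4-GPU machine, 256 on CPU only
```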
@@ -478,16 +479,29 @@ def autoencoder_fitting(
             ys += [y_train[-Xs.shape[0] :]]
             yvals += [y_val[-Xvals.shape[0] :]]

+        # Convert data to tf.data.Dataset objects
+        options = tf.data.Options()
+        options.experimental_distribute.auto_shard_policy = (
+            tf.data.experimental.AutoShardPolicy.DATA
+        )
+        train_dataset = (
+            tf.data.Dataset.from_tensor_slices((Xs, *ys))
+            .with_options(options)
+            .batch(batch_size)
+        )
+        val_dataset = (
+            tf.data.Dataset.from_tensor_slices((Xvals, *yvals))
+            .with_options(options)
+            .batch(batch_size)
+        )
+
         ae.fit(
-            x=Xs,
-            y=ys,
+            x=train_dataset,
             epochs=epochs,
-            batch_size=batch_size,
+            batch_size=batch_size * strategy.num_replicas_in_sync,
             verbose=1,
-            validation_data=(
-                Xvals,
-                yvals,
-            ),
+            validation_data=val_dataset,
             callbacks=callbacks_,
         )
...
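
The `AutoShardPolicy.DATA` setting in the last hunk controls how tf.distribute shards the input pipeline. The default AUTO policy first tries FILE sharding, which fails for datasets built from in-memory tensors because there are no file boundaries to split on; DATA instead shards element-wise, with each worker reading the dataset and keeping its own subset. A self-contained sketch of the same pattern on toy arrays (shapes and names are illustrative):

```python
import numpy as np
import tensorflow as tf

X = np.random.rand(1024, 10).astype("float32")
y = np.random.rand(1024, 1).astype("float32")

# In-memory tensors have no files to shard by, so shard by element.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)

train_dataset = (
    tf.data.Dataset.from_tensor_slices((X, y))
    .with_options(options)
    .batch(64)  # with a distribution strategy this is the global batch
)
```

One detail to be aware of when reading the final `ae.fit` call: once the input is a `tf.data.Dataset`, batching comes from the dataset itself, and recent tf.keras versions reject an explicit `batch_size` argument for dataset inputs.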