Commit 5cceccdb authored by lucas_miranda

Added a MirroredStrategy to train models on multiple GPUs if they are available

parent 3559e66c
Pipeline #100307 passed with stages in 19 minutes and 16 seconds
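
For context, the multi-GPU pattern at stake in this commit is TensorFlow's standard MirroredStrategy idiom: build and compile the model inside strategy.scope(), and scale the batch size handed to fit() by strategy.num_replicas_in_sync, because the global batch is split evenly across replicas at every step. A minimal sketch of that idiom, with a toy Keras model standing in for deepof.models.SEQ_2_SEQ_GMVAE (the model and numbers are illustrative assumptions, not code from this repository):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # replicates variables on all visible GPUs

    with strategy.scope():
        # Variables created inside the scope are mirrored across replicas
        model = tf.keras.Sequential(
            [tf.keras.layers.Dense(64, activation="relu"), tf.keras.layers.Dense(1)]
        )
        model.compile(optimizer="adam", loss="mse")

    # fit() treats batch_size as the global batch and splits it across replicas,
    # so multiplying by num_replicas_in_sync keeps the per-GPU batch constant
    global_batch = 64 * strategy.num_replicas_in_sync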
@@ -907,7 +907,6 @@ class coordinates:
         entropy_knn: int = 100,
         input_type: str = False,
         run: int = 0,
-        strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(),
     ) -> Tuple:
         """
         Annotates coordinates using an unsupervised autoencoder.
@@ -975,7 +974,6 @@ class coordinates:
             entropy_knn=entropy_knn,
             input_type=input_type,
             run=run,
-            strategy=strategy,
         )
         # returns a list of trained tensorflow models
@@ -306,7 +306,6 @@ def autoencoder_fitting(
     entropy_knn: int,
     input_type: str,
     run: int = 0,
-    strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(),
 ):
     """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
@@ -379,38 +378,38 @@
         return_list = (encoder, decoder, ae)
     else:
-        with strategy.scope():
-            (
-                encoder,
-                generator,
-                grouper,
-                ae,
-                prior,
-                posterior,
-            ) = deepof.models.SEQ_2_SEQ_GMVAE(
-                architecture_hparams=({} if hparams is None else hparams),
-                batch_size=batch_size * strategy.num_replicas_in_sync,
-                compile_model=True,
-                encoding=encoding_size,
-                kl_annealing_mode=kl_annealing_mode,
-                kl_warmup_epochs=kl_warmup,
-                loss=loss,
-                mmd_annealing_mode=mmd_annealing_mode,
-                mmd_warmup_epochs=mmd_warmup,
-                montecarlo_kl=montecarlo_kl,
-                neuron_control=False,
-                number_of_components=n_components,
-                overlap_loss=False,
-                next_sequence_prediction=next_sequence_prediction,
-                phenotype_prediction=phenotype_prediction,
-                rule_based_prediction=rule_based_prediction,
-                rule_based_features=rule_based_features,
-                reg_cat_clusters=reg_cat_clusters,
-                reg_cluster_variance=reg_cluster_variance,
-            ).build(
-                X_train.shape
-            )
-            return_list = (encoder, generator, grouper, ae)
+        (
+            encoder,
+            generator,
+            grouper,
+            ae,
+            prior,
+            posterior,
+        ) = deepof.models.SEQ_2_SEQ_GMVAE(
+            architecture_hparams=({} if hparams is None else hparams),
+            batch_size=batch_size,
+            compile_model=True,
+            encoding=encoding_size,
+            kl_annealing_mode=kl_annealing_mode,
+            kl_warmup_epochs=kl_warmup,
+            loss=loss,
+            mmd_annealing_mode=mmd_annealing_mode,
+            mmd_warmup_epochs=mmd_warmup,
+            montecarlo_kl=montecarlo_kl,
+            neuron_control=False,
+            number_of_components=n_components,
+            overlap_loss=False,
+            next_sequence_prediction=next_sequence_prediction,
+            phenotype_prediction=phenotype_prediction,
+            rule_based_prediction=rule_based_prediction,
+            rule_based_features=rule_based_features,
+            reg_cat_clusters=reg_cat_clusters,
+            reg_cluster_variance=reg_cluster_variance,
+        ).build(
+            X_train.shape
+        )
+        return_list = (encoder, generator, grouper, ae)
     if pretrained:
         # If pretrained models are specified, load weights and return
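
The batch-size arithmetic is the easy part of this hunk to get wrong: under MirroredStrategy the value passed to the model is the global batch, divided evenly among replicas at each step. A worked example, assuming a hypothetical 4-GPU machine (the numbers are illustrative):

    per_replica_batch = 64                       # samples each GPU sees per step
    replicas = 4                                 # strategy.num_replicas_in_sync on 4 GPUs
    global_batch = per_replica_batch * replicas  # 256 is what batch_size becomes
    assert global_batch // replicas == per_replica_batch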
@@ -483,7 +482,7 @@ def autoencoder_fitting(
         x=Xs,
         y=ys,
         epochs=epochs,
-        batch_size=batch_size * strategy.num_replicas_in_sync,
+        batch_size=batch_size,
         verbose=1,
         validation_data=(
             Xvals,
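
Since the commit message conditions the behaviour on GPUs being available, a defensive way to choose the strategy is to inspect the visible devices and fall back to the default strategy on CPU-only machines. A possible sketch, using standard tf.distribute APIs and not taken from this repository:

    import tensorflow as tf

    gpus = tf.config.list_physical_devices("GPU")
    if len(gpus) > 1:
        strategy = tf.distribute.MirroredStrategy()           # mirror across all GPUs
    elif len(gpus) == 1:
        strategy = tf.distribute.OneDeviceStrategy("/gpu:0")  # single GPU, no mirroring
    else:
        strategy = tf.distribute.get_strategy()               # default (no distribution)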