Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
fbed44fa
Commit
fbed44fa
authored
Apr 27, 2021
by
lucas_miranda
Browse files
Added a MirroredStrategy to train models on multiple GPUs if they are available
parent
5cceccdb
Pipeline
#100310
canceled with stages
in 5 minutes and 39 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
deepof/data.py
View file @
fbed44fa
...
...
@@ -907,6 +907,7 @@ class coordinates:
entropy_knn
:
int
=
100
,
input_type
:
str
=
False
,
run
:
int
=
0
,
strategy
:
tf
.
distribute
.
Strategy
=
tf
.
distribute
.
MirroredStrategy
(),
)
->
Tuple
:
"""
Annotates coordinates using an unsupervised autoencoder.
...
...
@@ -974,6 +975,7 @@ class coordinates:
entropy_knn
=
entropy_knn
,
input_type
=
input_type
,
run
=
run
,
strategy
=
strategy
,
)
# returns a list of trained tensorflow models
...
...
deepof/train_utils.py
View file @
fbed44fa
...
...
@@ -306,6 +306,7 @@ def autoencoder_fitting(
entropy_knn
:
int
,
input_type
:
str
,
run
:
int
=
0
,
strategy
:
tf
.
distribute
.
Strategy
=
tf
.
distribute
.
MirroredStrategy
(),
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
...
...
@@ -378,38 +379,38 @@ def autoencoder_fitting(
return_list
=
(
encoder
,
decoder
,
ae
)
else
:
(
encoder
,
generator
,
grouper
,
ae
,
prior
,
posterior
,
)
=
deepof
.
models
.
SEQ_2_SEQ_GMVAE
(
architecture_hparams
=
({}
if
hparams
is
None
else
hparams
),
batch_size
=
batch_size
,
compile_model
=
True
,
encoding
=
encoding_size
,
kl_annealing_mode
=
kl_annealing_mode
,
kl_warmup_epochs
=
kl_warmup
,
loss
=
loss
,
mmd_annealing_mode
=
mmd_annealing_mode
,
mmd_warmup_epochs
=
mmd_warmup
,
montecarlo_kl
=
montecarlo_kl
,
neuron_control
=
False
,
number_of_components
=
n_components
,
overlap_loss
=
False
,
next_sequence_prediction
=
next_sequence_prediction
,
phenotype_prediction
=
phenotype_prediction
,
rule_based_prediction
=
rule_based_prediction
,
rule_based_features
=
rule_based_features
,
reg_cat_clusters
=
reg_cat_clusters
,
reg_cluster_variance
=
reg_cluster_variance
,
).
build
(
X_train
.
shape
)
return_list
=
(
encoder
,
generator
,
grouper
,
ae
)
with
strategy
.
scope
():
(
encoder
,
generator
,
grouper
,
ae
,
prior
,
posterior
,
)
=
deepof
.
models
.
SEQ_2_SEQ_GMVAE
(
architecture_hparams
=
({}
if
hparams
is
None
else
hparams
),
batch_size
=
batch_size
*
strategy
.
num_replicas_in_sync
,
compile_model
=
True
,
encoding
=
encoding_size
,
kl_annealing_mode
=
kl_annealing_mode
,
kl_warmup_epochs
=
kl_warmup
,
loss
=
loss
,
mmd_annealing_mode
=
mmd_annealing_mode
,
mmd_warmup_epochs
=
mmd_warmup
,
montecarlo_kl
=
montecarlo_kl
,
neuron_control
=
False
,
number_of_components
=
n_components
,
overlap_loss
=
False
,
next_sequence_prediction
=
next_sequence_prediction
,
phenotype_prediction
=
phenotype_prediction
,
rule_based_prediction
=
rule_based_prediction
,
rule_based_features
=
rule_based_features
,
reg_cat_clusters
=
reg_cat_clusters
,
reg_cluster_variance
=
reg_cluster_variance
,
).
build
(
X_train
.
shape
)
return_list
=
(
encoder
,
generator
,
grouper
,
ae
)
if
pretrained
:
# If pretrained models are specified, load weights and return
...
...
@@ -478,16 +479,29 @@ def autoencoder_fitting(
ys
+=
[
y_train
[
-
Xs
.
shape
[
0
]
:]]
yvals
+=
[
y_val
[
-
Xvals
.
shape
[
0
]
:]]
# Convert data to tf.data.Dataset objects
options
=
tf
.
data
.
Options
()
options
.
experimental_distribute
.
auto_shard_policy
=
(
tf
.
data
.
experimental
.
AutoShardPolicy
.
DATA
)
train_dataset
=
(
tf
.
data
.
Dataset
.
from_tensor_slices
((
Xs
,
*
ys
))
.
with_options
(
options
)
.
batch
(
batch_size
)
)
val_dataset
=
(
tf
.
data
.
Dataset
.
from_tensor_slices
((
Xvals
,
*
yvals
))
.
with_options
(
options
)
.
batch
(
batch_size
)
)
ae
.
fit
(
x
=
Xs
,
y
=
ys
,
x
=
train_dataset
,
epochs
=
epochs
,
batch_size
=
batch_size
,
batch_size
=
batch_size
*
strategy
.
num_replicas_in_sync
,
verbose
=
1
,
validation_data
=
(
Xvals
,
yvals
,
),
validation_data
=
val_dataset
,
callbacks
=
callbacks_
,
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment