Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
ba5967e6
Commit
ba5967e6
authored
Feb 08, 2021
by
lucas_miranda
Browse files
Implemented autoencoder fitting as part of main module in data.py
parent
d79fb2d0
Changes
3
Hide whitespace changes
Inline
Side-by-side
deepof/data.py
View file @
ba5967e6
...
...
@@ -797,6 +797,7 @@ class coordinates:
predictor
:
float
=
0
,
pretrained
:
str
=
False
,
save_checkpoints
:
bool
=
False
,
save_weights
:
bool
=
True
,
variational
:
bool
=
True
,
)
->
Tuple
:
"""
...
...
@@ -852,6 +853,7 @@ class coordinates:
predictor
=
predictor
,
pretrained
=
pretrained
,
save_checkpoints
=
save_checkpoints
,
save_weights
=
save_weights
,
variational
=
variational
,
)
...
...
deepof/train_model.py
View file @
ba5967e6
...
...
@@ -346,254 +346,25 @@ print("Done!")
# as many times as specified by runs
if
not
tune
:
# Training loop
for
run
in
range
(
runs
):
# To avoid stability issues
tf
.
keras
.
backend
.
clear_session
()
run_ID
,
tensorboard_callback
,
onecycle
,
cp_callback
=
get_callbacks
(
X_train
=
X_train
,
batch_size
=
batch_size
,
cp
=
True
,
variational
=
variational
,
phenotype_class
=
pheno_class
,
predictor
=
predictor
,
loss
=
loss
,
logparam
=
logparam
,
outpath
=
output_path
,
)
logparams
=
[
hp
.
HParam
(
"encoding"
,
hp
.
Discrete
([
2
,
4
,
6
,
8
,
12
,
16
]),
display_name
=
"encoding"
,
description
=
"encoding size dimensionality"
,
),
hp
.
HParam
(
"k"
,
hp
.
IntInterval
(
min_value
=
1
,
max_value
=
15
),
display_name
=
"k"
,
description
=
"cluster_number"
,
),
hp
.
HParam
(
"loss"
,
hp
.
Discrete
([
"ELBO"
,
"MMD"
,
"ELBO+MMD"
]),
display_name
=
"loss function"
,
description
=
"loss function"
,
),
hp
.
HParam
(
"run"
,
hp
.
Discrete
([
0
,
1
,
2
]),
display_name
=
"trial run"
,
description
=
"trial run"
,
),
]
rec
=
"reconstruction_"
if
pheno_class
else
""
metrics
=
[
hp
.
Metric
(
"val_{}mae"
.
format
(
rec
),
display_name
=
"val_{}mae"
.
format
(
rec
)),
hp
.
Metric
(
"val_{}mse"
.
format
(
rec
),
display_name
=
"val_{}mse"
.
format
(
rec
)),
]
logparam
[
"run"
]
=
run
if
pheno_class
:
logparams
.
append
(
hp
.
HParam
(
"pheno_weight"
,
hp
.
RealInterval
(
min_value
=
0.0
,
max_value
=
1000.0
),
display_name
=
"pheno weight"
,
description
=
"weight applied to phenotypic classifier from the latent space"
,
)
)
metrics
+=
[
hp
.
Metric
(
"phenotype_prediction_accuracy"
,
display_name
=
"phenotype_prediction_accuracy"
,
),
hp
.
Metric
(
"phenotype_prediction_auc"
,
display_name
=
"phenotype_prediction_auc"
,
),
]
with
tf
.
summary
.
create_file_writer
(
os
.
path
.
join
(
output_path
,
"hparams"
,
run_ID
)
).
as_default
():
hp
.
hparams_config
(
hparams
=
logparams
,
metrics
=
metrics
,
)
if
not
variational
:
encoder
,
decoder
,
ae
=
SEQ_2_SEQ_AE
(
hparams
).
build
(
X_train
.
shape
)
print
(
ae
.
summary
())
ae
.
save_weights
(
os
.
path
.
join
(
output_path
,
"/checkpoints/cp-{epoch:04d}.ckpt"
.
format
(
epoch
=
0
)
)
)
# Fit the specified model to the data
history
=
ae
.
fit
(
x
=
X_train
,
y
=
X_train
,
epochs
=
35
,
batch_size
=
batch_size
,
verbose
=
1
,
validation_data
=
(
X_val
,
X_val
),
callbacks
=
[
tensorboard_callback
,
cp_callback
,
onecycle
,
CustomStopper
(
monitor
=
"val_loss"
,
patience
=
5
,
restore_best_weights
=
True
,
start_epoch
=
max
(
kl_wu
,
mmd_wu
),
),
],
)
ae
.
save_weights
(
"{}_final_weights.h5"
.
format
(
run_ID
))
else
:
(
encoder
,
generator
,
grouper
,
gmvaep
,
kl_warmup_callback
,
mmd_warmup_callback
,
)
=
SEQ_2_SEQ_GMVAE
(
architecture_hparams
=
hparams
,
batch_size
=
batch_size
,
compile_model
=
True
,
encoding
=
encoding_size
,
kl_warmup_epochs
=
kl_wu
,
loss
=
loss
,
mmd_warmup_epochs
=
mmd_wu
,
montecarlo_kl
=
mc_kl
,
neuron_control
=
neuron_control
,
number_of_components
=
k
,
overlap_loss
=
overlap_loss
,
phenotype_prediction
=
pheno_class
,
predictor
=
predictor
,
).
build
(
X_train
.
shape
)
print
(
gmvaep
.
summary
())
callbacks_
=
[
tensorboard_callback
,
# cp_callback,
onecycle
,
CustomStopper
(
monitor
=
"val_loss"
,
patience
=
5
,
restore_best_weights
=
True
,
start_epoch
=
max
(
kl_wu
,
mmd_wu
),
),
]
if
"ELBO"
in
loss
and
kl_wu
>
0
:
callbacks_
.
append
(
kl_warmup_callback
)
if
"MMD"
in
loss
and
mmd_wu
>
0
:
callbacks_
.
append
(
mmd_warmup_callback
)
Xs
,
ys
=
[
X_train
],
[
X_train
]
Xvals
,
yvals
=
[
X_val
],
[
X_val
]
if
predictor
>
0.0
:
Xs
,
ys
=
X_train
[:
-
1
],
[
X_train
[:
-
1
],
X_train
[
1
:]]
Xvals
,
yvals
=
X_val
[:
-
1
],
[
X_val
[:
-
1
],
X_val
[
1
:]]
if
pheno_class
>
0.0
:
ys
+=
[
y_train
]
yvals
+=
[
y_val
]
history
=
gmvaep
.
fit
(
x
=
Xs
,
y
=
ys
,
epochs
=
35
,
batch_size
=
batch_size
,
verbose
=
1
,
validation_data
=
(
Xvals
,
yvals
,
),
callbacks
=
callbacks_
,
)
gmvaep
.
save_weights
(
os
.
path
.
join
(
output_path
,
"trained_weights"
,
"GMVAE_loss={}_encoding={}_k={}_{}{}run_{}_final_weights.h5"
.
format
(
loss
,
encoding_size
,
k
,
(
"pheno={}_"
.
format
(
pheno_class
)
if
pheno_class
else
""
),
(
"predictor={}_"
.
format
(
predictor
)
if
predictor
else
""
),
run
,
),
)
)
# noinspection PyUnboundLocalVariable
def
tensorboard_metric_logging
(
run_dir
:
str
,
hpms
:
Any
):
output
=
gmvaep
.
predict
(
X_val
)
if
pheno_class
or
predictor
:
reconstruction
=
output
[
0
]
prediction
=
output
[
1
]
pheno
=
output
[
-
1
]
else
:
reconstruction
=
output
with
tf
.
summary
.
create_file_writer
(
run_dir
).
as_default
():
hp
.
hparams
(
hpms
)
# record the values used in this trial
val_mae
=
tf
.
reduce_mean
(
tf
.
keras
.
metrics
.
mean_absolute_error
(
X_val
,
reconstruction
)
)
val_mse
=
tf
.
reduce_mean
(
tf
.
keras
.
metrics
.
mean_squared_error
(
X_val
,
reconstruction
)
)
tf
.
summary
.
scalar
(
"val_{}mae"
.
format
(
rec
),
val_mae
,
step
=
1
)
tf
.
summary
.
scalar
(
"val_{}mse"
.
format
(
rec
),
val_mse
,
step
=
1
)
if
predictor
:
pred_mae
=
tf
.
reduce_mean
(
tf
.
keras
.
metrics
.
mean_absolute_error
(
X_val
,
prediction
)
)
pred_mse
=
tf
.
reduce_mean
(
tf
.
keras
.
metrics
.
mean_squared_error
(
X_val
,
prediction
)
)
tf
.
summary
.
scalar
(
"val_prediction_mae"
.
format
(
rec
),
pred_mae
,
step
=
1
)
tf
.
summary
.
scalar
(
"val_prediction_mse"
.
format
(
rec
),
pred_mse
,
step
=
1
)
if
pheno_class
:
pheno_acc
=
tf
.
keras
.
metrics
.
binary_accuracy
(
y_val
,
tf
.
squeeze
(
pheno
)
)
pheno_auc
=
roc_auc_score
(
y_val
,
pheno
)
tf
.
summary
.
scalar
(
"phenotype_prediction_accuracy"
,
pheno_acc
,
step
=
1
)
tf
.
summary
.
scalar
(
"phenotype_prediction_auc"
,
pheno_auc
,
step
=
1
)
# Logparams to tensorboard
tensorboard_metric_logging
(
os
.
path
.
join
(
output_path
,
"hparams"
,
run_ID
),
logparam
,
)
# To avoid stability issues
tf
.
keras
.
backend
.
clear_session
()
trained_models
=
deep_unsupervised_embedding
(
(
X_train
,
y_train
,
X_val
,
y_val
),
batch_size
=
batch_size
,
encoding_size
=
encoding_size
,
hparams
=
hparams
,
kl_warmup
=
kl_wu
,
log_history
=
True
,
log_hparams
=
True
,
loss
=
loss
,
mmd_warmup
=
mmd_wu
,
montecarlo_kl
=
mc_kl
,
n_components
=
k
,
output_path
=
output_path
,
phenotype_class
=
pheno_class
,
predictor
=
predictor
,
save_checkpoints
=
False
,
save_weights
=
True
,
variational
=
variational
,
)
else
:
# Runs hyperparameter tuning with the specified parameters and saves the results
...
...
deepof/train_utils.py
View file @
ba5967e6
...
...
@@ -156,6 +156,7 @@ def deep_unsupervised_embedding(
predictor
:
float
,
pretrained
:
str
,
save_checkpoints
:
bool
,
save_weights
:
bool
,
variational
:
bool
,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment