deepOF · Commit 64469448

Commit 64469448 authored Mar 18, 2021 by lucas_miranda

Reimplemented MMD warmup using optimizer iterators; getting rid of the clumsy callback

parent d2c9404c
Showing 3 changed files
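What changed, in brief: the KL and MMD warm-ups were previously driven by Keras callbacks that mutated layer state between epochs; they are now driven by layers that read the optimizer's iteration counter directly, so the annealing weight updates on every batch with no callback plumbing. A toy sketch of the pattern (a hypothetical layer, not one of deepof's classes; tf.minimum stands in for the K.min idiom used in the diffs below):

import tensorflow as tf

class WarmupWeightedLoss(tf.keras.layers.Layer):
    """Toy layer: the loss weight is derived from the optimizer's step
    counter, so no callback has to mutate layer state between epochs."""

    def __init__(self, iters, warm_up_iters, **kwargs):
        super().__init__(**kwargs)
        self._iters = iters                  # e.g. optimizer.iterations
        self._warm_up_iters = warm_up_iters  # steps to anneal over

    def call(self, inputs):
        # Linear ramp from 0 to 1 over the first _warm_up_iters steps
        weight = tf.minimum(
            tf.cast(self._iters, tf.float32)
            / tf.cast(self._warm_up_iters, tf.float32),
            1.0,
        )
        self.add_loss(weight * tf.reduce_mean(inputs))
        return inputs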
deepof/hypermodels.py

@@ -168,7 +168,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             lstm_units_1,
         ) = self.get_hparams(hp)

-        gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE(
+        gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(
             architecture_hparams={
                 "bidirectional_merge": "ave",
                 "clipvalue": clipvalue,
@@ -187,7 +187,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             overlap_loss=self.overlap_loss,
             phenotype_prediction=self.pheno_class,
             predictor=self.predictor,
-        ).build(self.input_shape)[3:]
+        ).build(self.input_shape)[-1]

         return gmvaep
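With the callbacks gone, build hands keras-tuner a single compiled model: the last element of deepof.models.SEQ_2_SEQ_GMVAE.build's output replaces the earlier slice that also bundled the warm-up callbacks. A hypothetical sketch of consuming the hypermodel (data shapes and constructor arguments are illustrative, not deepof's exact signature):

import numpy as np
import kerastuner
import deepof.hypermodels

# Illustrative data: 64 sequences of 50 frames with 24 coordinates each
X_train = np.random.normal(size=(64, 50, 24)).astype("float32")
X_val = np.random.normal(size=(16, 50, 24)).astype("float32")

tuner = kerastuner.BayesianOptimization(
    deepof.hypermodels.SEQ_2_SEQ_GMVAE(input_shape=X_train.shape, loss="ELBO+MMD"),
    objective="val_loss",
    max_trials=10,
    directory="tuner_logs",
)
# No warm-up callbacks need to be threaded into the search anymore:
tuner.search(X_train, X_train, validation_data=(X_val, X_val))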
deepof/model_utils.py

@@ -445,13 +445,23 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
         config = super().get_config().copy()
         config.update({"is_placeholder": self.is_placeholder})
         config.update({"_iters": self._iters})
         config.update({"_warm_up_iters": self._warm_up_iters})
         return config

     def call(self, distribution_a):
         """Updates Layer's call method"""
-        self._regularizer._weight = K.min([self._iters / self._warm_up_iters, 1.0])
-        kl_batch = self._regularizer(distribution_a)
+        # Define and update KL weight for warmup
+        if self._warm_up_iters > 0:
+            kl_weight = tf.cast(
+                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
+            )
+        else:
+            kl_weight = tf.cast(1.0, tf.float32)
+
+        kl_batch = kl_weight * self._regularizer(distribution_a)

         self.add_loss(kl_batch, inputs=[distribution_a])
         self.add_metric(
             kl_batch,
@@ -459,7 +469,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
             name="kl_divergence",
         )
-        # noinspection PyProtectedMember
-        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
+        self.add_metric(kl_weight, aggregation="mean", name="kl_rate")

         return distribution_a
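Both layers now share one schedule: weight = min(iters / warm_up_iters, 1.0), a linear ramp that saturates at 1.0, with the weight pinned to 1.0 from the first step whenever warm_up_iters is zero. A standalone sketch of the ramp against a real optimizer counter (hypothetical step values; tf.minimum replaces the K.min-over-a-list idiom):

import tensorflow as tf

def annealing_weight(iters, warm_up_iters):
    # Mirrors the layers' rule: linear ramp, then constant at 1.0;
    # skip annealing entirely when no warm-up is requested.
    if warm_up_iters > 0:
        return tf.minimum(tf.cast(iters, tf.float32) / float(warm_up_iters), 1.0)
    return tf.constant(1.0)

opt = tf.keras.optimizers.Adam()
for step in (0, 200, 400, 800):
    opt.iterations.assign(step)  # normally advanced by apply_gradients
    print(step, float(annealing_weight(opt.iterations, 400)))
# 0 -> 0.0, 200 -> 0.5, 400 -> 1.0, 800 -> 1.0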
@@ -483,8 +493,8 @@ class MMDiscrepancyLayer(Layer):
         config = super().get_config().copy()
         config.update({"batch_size": self.batch_size})
-        config.update({"iters": self._iters})
-        config.update({"warmup_iters": self._warm_up_iters})
+        config.update({"_iters": self._iters})
+        config.update({"_warmup_iters": self._warm_up_iters})
         config.update({"prior": self.prior})
         return config

@@ -492,12 +502,17 @@ class MMDiscrepancyLayer(Layer):
         """Updates Layer's call method"""
         true_samples = self.prior.sample(self.batch_size)

-        mmd_weight = tf.cast(
-            K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
-        )
+        # noinspection PyTypeChecker
+        # Define and update MMD weight for warmup
+        if self._warm_up_iters > 0:
+            mmd_weight = tf.cast(
+                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
+            )
+        else:
+            mmd_weight = tf.cast(1.0, tf.float32)

         mmd_batch = mmd_weight * compute_mmd((true_samples, z))
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
         self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")
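As with the KL term, the annealed weight scales the raw discrepancy, here compute_mmd((true_samples, z)), which compares a batch sampled from the prior with the encoded batch. deepof ships its own compute_mmd; for reference, a minimal single-bandwidth Gaussian-kernel MMD estimate looks like this (a sketch only, deepof's kernel and bandwidth choices may differ):

import tensorflow as tf

def gaussian_kernel(x, y, sigma=1.0):
    # Pairwise squared distances between rows of x and rows of y
    sq_dists = (
        tf.reduce_sum(x ** 2, axis=1)[:, None]
        - 2.0 * tf.matmul(x, y, transpose_b=True)
        + tf.reduce_sum(y ** 2, axis=1)[None, :]
    )
    return tf.exp(-sq_dists / (2.0 * sigma ** 2))

def mmd(true_samples, z):
    # MMD^2 estimate: E[k(p, p)] + E[k(q, q)] - 2 * E[k(q, p)]
    return (
        tf.reduce_mean(gaussian_kernel(true_samples, true_samples))
        + tf.reduce_mean(gaussian_kernel(z, z))
        - 2.0 * tf.reduce_mean(gaussian_kernel(z, true_samples))
    )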
deepof/models.py

@@ -625,7 +625,7 @@ class SEQ_2_SEQ_GMVAE:
         # Define and control custom loss functions
         if "ELBO" in self.loss:

-            warm_up_iters = tf.cast(
+            kl_warm_up_iters = tf.cast(
                 self.kl_warmup * (input_shape[0] // self.batch_size + 1),
                 tf.int64,
             )
@@ -636,12 +636,12 @@ class SEQ_2_SEQ_GMVAE:
                 test_points_fn=lambda q: q.sample(self.mc_kl),
                 test_points_reduce_axis=0,
                 iters=self.optimizer.iterations,
-                warm_up_iters=warm_up_iters,
+                warm_up_iters=kl_warm_up_iters,
             )(z)

         if "MMD" in self.loss:

-            warm_up_iters = tf.cast(
+            mmd_warm_up_iters = tf.cast(
                 self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
                 tf.int64,
             )
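The warm-up horizon is expressed in optimizer steps: the requested number of warm-up epochs is multiplied by the number of batches per epoch, input_shape[0] // self.batch_size + 1. Renaming the two counters (kl_warm_up_iters vs mmd_warm_up_iters) also keeps the MMD span from silently overwriting the KL one when both losses are active. With hypothetical numbers, 10,240 training instances, a batch size of 256, and kl_warmup = 10 epochs:

kl_warm_up_iters = 10 * (10240 // 256 + 1)  # 10 epochs * 41 steps = 410 steps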
@@ -650,7 +650,7 @@ class SEQ_2_SEQ_GMVAE:
                 batch_size=self.batch_size,
                 prior=self.prior,
                 iters=self.optimizer.iterations,
-                warm_up_iters=warm_up_iters,
+                warm_up_iters=mmd_warm_up_iters,
             )(z)

         # Dummy layer with no parameters, to retrieve the previous tensor

@@ -767,4 +767,3 @@ class SEQ_2_SEQ_GMVAE:
     # - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
     # - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
     # - Explore expanding the event dims of the final reconstruction layer
-    # - Gaussian Mixture as output layer? One component per bodypart (makes sense?)