Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
998f5107
Commit
998f5107
authored
Mar 18, 2021
by
lucas_miranda
Browse files
Reimplemented KL warmup using optimizer iterators; getting rid of the clumsy callback
parent
90ba416a
Changes
3
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
998f5107
...
...
@@ -439,7 +439,6 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
self
.
is_placeholder
=
True
self
.
_iters
=
iters
self
.
_warm_up_iters
=
warm_up_iters
self
.
_regularizer
.
_weight
=
K
.
min
([
self
.
_iters
/
self
.
_warm_up_iters
,
1.0
])
def
get_config
(
self
):
# pragma: no cover
"""Updates Constraint metadata"""
...
...
@@ -451,6 +450,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
def
call
(
self
,
distribution_a
):
"""Updates Layer's call method"""
self
.
_regularizer
.
_weight
=
K
.
min
([
self
.
_iters
/
self
.
_warm_up_iters
,
1.0
])
kl_batch
=
self
.
_regularizer
(
distribution_a
)
self
.
add_loss
(
kl_batch
,
inputs
=
[
distribution_a
])
self
.
add_metric
(
...
...
deepof/models.py
View file @
998f5107
...
...
@@ -626,7 +626,8 @@ class SEQ_2_SEQ_GMVAE:
if
"ELBO"
in
self
.
loss
:
warm_up_iters
=
tf
.
cast
(
self
.
kl_warmup
*
(
input_shape
[
0
]
/
self
.
batch_size
),
tf
.
int64
self
.
kl_warmup
*
(
input_shape
[
0
]
/
self
.
batch_size
),
tf
.
int64
,
)
# noinspection PyCallingNonCallable
...
...
deepof/train_utils.py
View file @
998f5107
...
...
@@ -332,7 +332,6 @@ def autoencoder_fitting(
generator
,
grouper
,
ae
,
kl_warmup_callback
,
mmd_warmup_callback
,
)
=
deepof
.
models
.
SEQ_2_SEQ_GMVAE
(
architecture_hparams
=
({}
if
hparams
is
None
else
hparams
),
...
...
@@ -395,9 +394,6 @@ def autoencoder_fitting(
),
]
if
"ELBO"
in
loss
and
kl_warmup
>
0
:
# noinspection PyUnboundLocalVariable
callbacks_
.
append
(
kl_warmup_callback
)
if
"MMD"
in
loss
and
mmd_warmup
>
0
:
# noinspection PyUnboundLocalVariable
callbacks_
.
append
(
mmd_warmup_callback
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment