Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
de37354f
Commit
de37354f
authored
Jun 04, 2020
by
lucas_miranda
Browse files
Implemented KL and MMD warmup on SEQ2SEQ_VAE in models.py
parent
37c1c84a
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
main.ipynb
View file @
de37354f
This diff is collapsed.
Click to expand it.
source/model_utils.py
View file @
de37354f
...
...
@@ -123,6 +123,11 @@ class KLDivergenceLayer(Layer):
self
.
beta
=
beta
super
(
KLDivergenceLayer
,
self
).
__init__
(
*
args
,
**
kwargs
)
def get_config(self):
    """Return the layer's serialization config, including the warm-up weight.

    Merges the parent ``Layer`` config with this layer's own ``beta`` field so
    the layer can be round-tripped through Keras model (de)serialization.
    """
    base_config = super().get_config()
    return {**base_config, "beta": self.beta}
def
call
(
self
,
inputs
,
**
kwargs
):
mu
,
log_var
=
inputs
...
...
@@ -144,6 +149,11 @@ class MMDiscrepancyLayer(Layer):
self
.
beta
=
beta
super
(
MMDiscrepancyLayer
,
self
).
__init__
(
*
args
,
**
kwargs
)
def get_config(self):
    """Return the layer's serialization config, including the warm-up weight.

    Extends the parent ``Layer`` config with this layer's ``beta`` field so a
    saved model can rebuild the layer with the same MMD weight.
    """
    base_config = super().get_config()
    return {**base_config, "beta": self.beta}
def
call
(
self
,
z
,
**
kwargs
):
true_samples
=
K
.
random_normal
(
K
.
shape
(
z
))
mmd_batch
=
compute_mmd
(
true_samples
,
z
)
...
...
source/models.py
View file @
de37354f
# @author lucasmiranda42
from
tensorflow.keras
import
backend
as
K
from
tensorflow.keras
import
Input
,
Model
,
Sequential
from
tensorflow.keras.callbacks
import
LambdaCallback
from
tensorflow.keras.constraints
import
UnitNorm
from
tensorflow.keras.initializers
import
he_uniform
,
Orthogonal
from
tensorflow.keras.layers
import
BatchNormalization
,
Bidirectional
,
Dense
...
...
@@ -173,6 +175,10 @@ class SEQ_2_SEQ_VAE:
self
.
kl_warmup
=
kl_warmup_epochs
self
.
mmd_warmup
=
mmd_warmup_epochs
assert
(
"ELBO"
in
self
.
loss
or
"MMD"
in
self
.
loss
),
"loss must be one of ELBO, MMD or ELBO+MMD (default)"
def
build
(
self
):
# Encoder Layers
Model_E0
=
tf
.
keras
.
layers
.
Conv1D
(
...
...
@@ -266,13 +272,36 @@ class SEQ_2_SEQ_VAE:
z_mean
=
Dense
(
self
.
ENCODING
)(
encoder
)
z_log_sigma
=
Dense
(
self
.
ENCODING
)(
encoder
)
# --- KL warm-up wiring for the ELBO loss term -------------------------------
# NOTE(review): this span is reassembled from a markerless GitLab diff view;
# the duplicated `KLDivergenceLayer()` application without `beta` read as the
# removed (pre-change) side of the diff and was dropped — confirm against the
# repository history.
kl_wu = False
if "ELBO" in self.loss:
    # BUG FIX: the original initialized `kl_beta = 1` (a plain Python int) and
    # then re-bound `kl_beta = value` *inside* `klwarmup`, which only creates a
    # function-local name — the layer, built once below with the initial
    # value, never saw the updated weight. Using a backend variable lets the
    # callback mutate the value the graph actually reads during training.
    kl_beta = K.variable(1.0, name="kl_beta")

    if self.kl_warmup:

        def klwarmup(epoch):
            # Ramp the KL weight linearly up to 1 over `kl_warmup` epochs.
            value = K.min([epoch / self.kl_warmup, 1])
            print("beta:", value)
            K.set_value(kl_beta, value)

        # Fired at the end of every epoch; returned to the caller alongside
        # the models so it can be passed to `fit(callbacks=...)`.
        kl_wu = LambdaCallback(on_epoch_end=lambda epoch, log: klwarmup(epoch))

    z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

# Reparameterization trick: draw z from N(z_mean, exp(z_log_sigma)).
z = Lambda(sampling)([z_mean, z_log_sigma])
# --- MMD warm-up wiring for the MMD loss term -------------------------------
# NOTE(review): reassembled from a markerless GitLab diff view; the duplicated
# `MMDiscrepancyLayer()` application without `beta` read as the removed side
# of the diff and was dropped — confirm against the repository history.
mmd_wu = False
if "MMD" in self.loss:
    # Backend variable so the callback below can mutate the weight the layer
    # actually reads during training (re-binding a Python local would not).
    mmd_beta = K.variable(1.0, name="mmd_beta")

    # BUG FIX: the warm-up guard tested `self.kl_warmup` (copy-paste from the
    # ELBO branch); the MMD warm-up must be controlled by its own
    # `mmd_warmup` setting — otherwise it silently follows the KL schedule
    # flag while dividing by `self.mmd_warmup` inside the callback.
    if self.mmd_warmup:

        def mmdwarmup(epoch):
            # Ramp the MMD weight linearly up to 1 over `mmd_warmup` epochs.
            value = K.min([epoch / self.mmd_warmup, 1])
            print("mmd_beta:", value)
            K.set_value(mmd_beta, value)

        mmd_wu = LambdaCallback(
            on_epoch_end=lambda epoch, log: mmdwarmup(epoch)
        )

    z = MMDiscrepancyLayer(beta=mmd_beta)(z)
# Define and instantiate generator
generator
=
Model_D0
(
z
)
...
...
@@ -319,7 +348,7 @@ class SEQ_2_SEQ_VAE:
experimental_run_tf_function
=
False
,
)
return
encoder
,
generator
,
vae
return
encoder
,
generator
,
vae
,
kl_wu
,
mmd_wu
class
SEQ_2_SEQ_VAEP
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment