Lucas Miranda / deepOF · Commits

Commit 14bcf40b, authored Jun 04, 2020 by lucas_miranda

Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py

parent 0a974291
Changes: 3 files
main.ipynb (diff collapsed)
source/model_utils.py
```diff
@@ -118,7 +118,7 @@ class KLDivergenceLayer(Layer):
     to the final model loss.
     """

-    def __init__(self, beta=1, *args, **kwargs):
+    def __init__(self, beta=1.0, *args, **kwargs):
         self.is_placeholder = True
         self.beta = beta
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)
```
```diff
@@ -131,7 +131,9 @@ class KLDivergenceLayer(Layer):
     def call(self, inputs, **kwargs):
         mu, log_var = inputs
         kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
         self.add_loss(self.beta * K.mean(kL_batch), inputs=inputs)
         self.add_metric(self.beta, aggregation="mean", name="kl_rate")
         return inputs
```
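For reference, the `kL_batch` expression above is the closed-form KL divergence between the encoder's diagonal Gaussian posterior N(mu, sigma^2) and the standard-normal prior, KL = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2), summed over latent dimensions. A quick standalone NumPy check of the formula (input values are illustrative, not from the repository):

```python
import numpy as np

# Illustrative posterior parameters: batch of 2 samples, latent dim 3
mu = np.array([[0.5, -1.0, 0.0], [0.2, 0.3, -0.7]])
log_var = np.array([[0.1, -0.3, 0.0], [0.0, 0.2, -0.1]])

# Same closed-form expression as kL_batch in the layer above
kl_per_sample = -0.5 * np.sum(1 + log_var - np.square(mu) - np.exp(log_var), axis=-1)
print(kl_per_sample)  # the layer adds beta * mean(kl_per_sample) to the model loss
```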
```diff
@@ -141,7 +143,7 @@ class MMDiscrepancyLayer(Layer):
     to the final model loss.
     """

-    def __init__(self, beta=1, *args, **kwargs):
+    def __init__(self, beta=1.0, *args, **kwargs):
         self.is_placeholder = True
         self.beta = beta
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
```
```diff
@@ -156,5 +158,6 @@ class MMDiscrepancyLayer(Layer):
         mmd_batch = compute_mmd(true_samples, z)
         self.add_loss(self.beta * K.mean(mmd_batch), inputs=z)
         self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
         return z
```
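`compute_mmd` is defined elsewhere in model_utils.py and is untouched by this commit, so its body is not shown here. For context, a standard kernel two-sample MMD estimate that such a helper typically computes looks like the sketch below; this is an assumption for illustration, not deepOF's actual implementation, and the RBF bandwidth `sigma` is a made-up parameter:

```python
import tensorflow.keras.backend as K

def rbf_kernel(x, y, sigma=1.0):
    # Pairwise RBF kernel matrix between rows of x (n, d) and rows of y (m, d)
    x_sq = K.sum(K.square(x), axis=1, keepdims=True)  # (n, 1)
    y_sq = K.sum(K.square(y), axis=1, keepdims=True)  # (m, 1)
    sq_dists = x_sq - 2.0 * K.dot(x, K.transpose(y)) + K.transpose(y_sq)
    return K.exp(-sq_dists / (2.0 * sigma ** 2))

def compute_mmd_sketch(true_samples, z):
    # MMD^2(p, q) = E[k(x, x')] + E[k(z, z')] - 2 E[k(x, z)]
    return (
        K.mean(rbf_kernel(true_samples, true_samples))
        + K.mean(rbf_kernel(z, z))
        - 2.0 * K.mean(rbf_kernel(true_samples, z))
    )
```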
source/models.py
```diff
@@ -273,32 +273,32 @@ class SEQ_2_SEQ_VAE:
         z_log_sigma = Dense(self.ENCODING)(encoder)

         # Define and control custom loss functions
-        kl_wu = False
+        kl_warmup_callback = False
         if "ELBO" in self.loss:
-            kl_beta = 1
+            kl_beta = K.variable(1.0, name="kl_beta")
             if self.kl_warmup:
-                def klwarmup(epoch):
-                    kl_beta = K.min([epoch / self.kl_warmup, 1])
-                kl_wu = LambdaCallback(on_epoch_end=lambda epoch, log: klwarmup(epoch))
+                kl_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        kl_beta, K.min([epoch / self.kl_warmup, 1])
+                    )
+                )
             z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

         z = Lambda(sampling)([z_mean, z_log_sigma])

-        mmd_wu = False
+        mmd_warmup_callback = False
         if "MMD" in self.loss:
-            mmd_beta = 1
-            if self.kl_warmup:
+            mmd_beta = K.variable(1.0, name="mmd_beta")
+            if self.mmd_warmup:
-                def mmdwarmup(epoch):
-                    mmd_beta = K.min([epoch / self.mmd_warmup, 1])
-                mmd_wu = LambdaCallback(on_epoch_end=lambda epoch, log: mmdwarmup(epoch))
+                mmd_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
+                    )
+                )
             z = MMDiscrepancyLayer(beta=mmd_beta)(z)
```
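This hunk is the substance of the fix. The removed `klwarmup`/`mmdwarmup` closures only rebound a local `kl_beta`/`mmd_beta` inside the callback, so the plain integer beta baked into `KLDivergenceLayer`/`MMDiscrepancyLayer` at graph-construction time never changed; the removed callbacks also fired `on_epoch_end`, and the MMD branch was gated on `self.kl_warmup` rather than `self.mmd_warmup`. The replacement makes beta a backend variable that the loss graph references directly and mutates it in place with `K.set_value` at the start of each epoch. A self-contained sketch of that mechanism (the toy model and data are illustrative only, not from the repository):

```python
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

kl_warmup_epochs = 10
kl_beta = K.variable(0.0, name="kl_beta")  # starts fully annealed

inputs = Input(shape=(4,))
outputs = Dense(4)(inputs)
model = Model(inputs, outputs)

# Stand-in regularisation term scaled by the mutable beta variable;
# because kl_beta lives in the graph, updating it affects every batch.
model.add_loss(kl_beta * K.mean(K.square(outputs)))
model.compile(optimizer="adam", loss="mse")

# Anneal beta linearly from 0 to 1 over kl_warmup_epochs epochs
kl_warmup_callback = LambdaCallback(
    on_epoch_begin=lambda epoch, logs: K.set_value(
        kl_beta, min(epoch / kl_warmup_epochs, 1.0)
    )
)

x = np.random.rand(32, 4)
model.fit(x, x, epochs=3, callbacks=[kl_warmup_callback], verbose=0)
print(K.get_value(kl_beta))  # 0.2: set at the start of epoch index 2
```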
```diff
@@ -348,7 +348,7 @@ class SEQ_2_SEQ_VAE:
             experimental_run_tf_function=False,
         )

-        return encoder, generator, vae, kl_wu, mmd_wu
+        return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback


 class SEQ_2_SEQ_VAEP:
```
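With the renamed return values, callers receive the actual Keras callbacks (or `False` when the corresponding warmup is disabled) and can forward them to `fit`. A hypothetical usage sketch (the constructor arguments and the `build()` call are assumptions for illustration, not taken from this diff):

```python
# Hypothetical caller; argument names and values are illustrative only
encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(
    input_shape=X_train.shape,
    loss="ELBO+MMD",
    kl_warmup_epochs=10,
    mmd_warmup_epochs=10,
).build()

# Each callback is False when its warmup is off, so filter before fitting
callbacks = [cb for cb in (kl_warmup_callback, mmd_warmup_callback) if cb]
vae.fit(X_train, X_train, epochs=100, callbacks=callbacks)
```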
```diff
@@ -477,34 +477,30 @@ class SEQ_2_SEQ_VAEP:
         z_log_sigma = Dense(self.ENCODING)(encoder)

         # Define and control custom loss functions
-        kl_wu = False
+        kl_warmup_callback = False
         if "ELBO" in self.loss:
-            kl_beta = 1
+            kl_beta = K.variable(1.0, name="kl_beta")
             if self.kl_warmup:
-                def klwarmup(epoch):
-                    kl_beta = K.min([epoch / self.kl_warmup, 1])
-                kl_wu = LambdaCallback(on_epoch_begin=lambda epoch, log: klwarmup(epoch))
+                kl_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        kl_beta, K.min([epoch / self.kl_warmup, 1])
+                    )
+                )
             z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

         z = Lambda(sampling)([z_mean, z_log_sigma])

-        mmd_wu = False
+        mmd_warmup_callback = False
         if "MMD" in self.loss:
-            mmd_beta = 1
-            if self.kl_warmup:
-                def mmdwarmup(epoch):
-                    mmd_beta = K.min([epoch / self.mmd_warmup, 1])
-                mmd_wu = LambdaCallback(on_epoch_begin=lambda epoch, log: mmdwarmup(epoch))
+            mmd_beta = K.variable(1.0, name="mmd_beta")
+            if self.mmd_warmup:
+                mmd_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
+                    )
+                )
             z = MMDiscrepancyLayer(beta=mmd_beta)(z)
```
```diff
@@ -594,7 +590,7 @@ class SEQ_2_SEQ_VAEP:
             experimental_run_tf_function=False,
         )

-        return encoder, generator, vaep, kl_wu, mmd_wu
+        return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback


 class SEQ_2_SEQ_MMVAE:
```