Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
d2c9404c
Commit
d2c9404c
authored
Mar 18, 2021
by
lucas_miranda
Browse files
Reimplemented MMD warmup using optimizer iterators; getting rid of the clumsy callback
parent
24b9c3ea
Changes
3
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
d2c9404c
...
...
@@ -470,19 +470,21 @@ class MMDiscrepancyLayer(Layer):
to the final model loss.
"""
def
__init__
(
self
,
batch_size
,
prior
,
beta
=
1.0
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
batch_size
,
prior
,
iters
,
warm_up_iters
,
*
args
,
**
kwargs
):
super
(
MMDiscrepancyLayer
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
is_placeholder
=
True
self
.
batch_size
=
batch_size
self
.
beta
=
beta
self
.
prior
=
prior
super
(
MMDiscrepancyLayer
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
_iters
=
iters
self
.
_warm_up_iters
=
warm_up_iters
def
get_config
(
self
):
# pragma: no cover
"""Updates Constraint metadata"""
config
=
super
().
get_config
().
copy
()
config
.
update
({
"batch_size"
:
self
.
batch_size
})
config
.
update
({
"beta"
:
self
.
beta
})
config
.
update
({
"iters"
:
self
.
_iters
})
config
.
update
({
"warmup_iters"
:
self
.
_warm_up_iters
})
config
.
update
({
"prior"
:
self
.
prior
})
return
config
...
...
@@ -490,11 +492,15 @@ class MMDiscrepancyLayer(Layer):
"""Updates Layer's call method"""
true_samples
=
self
.
prior
.
sample
(
self
.
batch_size
)
mmd_weight
=
tf
.
cast
(
K
.
min
([
self
.
_iters
/
self
.
_warm_up_iters
,
1.0
]),
tf
.
float32
)
# noinspection PyTypeChecker
mmd_batch
=
self
.
beta
*
compute_mmd
((
true_samples
,
z
))
mmd_batch
=
mmd_weight
*
compute_mmd
((
true_samples
,
z
))
self
.
add_loss
(
K
.
mean
(
mmd_batch
),
inputs
=
z
)
self
.
add_metric
(
mmd_batch
,
aggregation
=
"mean"
,
name
=
"mmd"
)
self
.
add_metric
(
self
.
beta
,
aggregation
=
"mean"
,
name
=
"mmd_rate"
)
self
.
add_metric
(
mmd_weight
,
aggregation
=
"mean"
,
name
=
"mmd_rate"
)
return
z
...
...
deepof/models.py
View file @
d2c9404c
...
...
@@ -639,20 +639,18 @@ class SEQ_2_SEQ_GMVAE:
warm_up_iters
=
warm_up_iters
,
)(
z
)
mmd_warmup_callback
=
False
if
"MMD"
in
self
.
loss
:
mmd_beta
=
deepof
.
model_utils
.
K
.
variable
(
1.0
,
name
=
"mmd_beta"
)
mmd_beta
.
_trainable
=
False
if
self
.
mmd_warmup
:
mmd_warmup_callback
=
LambdaCallback
(
on_epoch_begin
=
lambda
epoch
,
logs
:
deepof
.
model_utils
.
K
.
set_value
(
mmd_beta
,
deepof
.
model_utils
.
K
.
min
([
epoch
/
self
.
mmd_warmup
,
1
])
)
)
warm_up_iters
=
tf
.
cast
(
self
.
mmd_warmup
*
(
input_shape
[
0
]
//
self
.
batch_size
+
1
),
tf
.
int64
,
)
z
=
deepof
.
model_utils
.
MMDiscrepancyLayer
(
batch_size
=
self
.
batch_size
,
prior
=
self
.
prior
,
beta
=
mmd_beta
batch_size
=
self
.
batch_size
,
prior
=
self
.
prior
,
iters
=
self
.
optimizer
.
iterations
,
warm_up_iters
=
warm_up_iters
,
)(
z
)
# Dummy layer with no parameters, to retrieve the previous tensor
...
...
@@ -758,7 +756,6 @@ class SEQ_2_SEQ_GMVAE:
generator
,
grouper
,
gmvaep
,
mmd_warmup_callback
,
)
@
prior
.
setter
...
...
deepof/train_utils.py
View file @
d2c9404c
...
...
@@ -327,13 +327,7 @@ def autoencoder_fitting(
return_list
=
(
encoder
,
decoder
,
ae
)
else
:
(
encoder
,
generator
,
grouper
,
ae
,
mmd_warmup_callback
,
)
=
deepof
.
models
.
SEQ_2_SEQ_GMVAE
(
(
encoder
,
generator
,
grouper
,
ae
,)
=
deepof
.
models
.
SEQ_2_SEQ_GMVAE
(
architecture_hparams
=
({}
if
hparams
is
None
else
hparams
),
batch_size
=
batch_size
,
compile_model
=
True
,
...
...
@@ -349,9 +343,7 @@ def autoencoder_fitting(
predictor
=
predictor
,
reg_cat_clusters
=
reg_cat_clusters
,
reg_cluster_variance
=
reg_cluster_variance
,
).
build
(
X_train
.
shape
)
).
build
(
X_train
.
shape
)
return_list
=
(
encoder
,
generator
,
grouper
,
ae
)
if
pretrained
:
...
...
@@ -394,10 +386,6 @@ def autoencoder_fitting(
),
]
if
"MMD"
in
loss
and
mmd_warmup
>
0
:
# noinspection PyUnboundLocalVariable
callbacks_
.
append
(
mmd_warmup_callback
)
Xs
,
ys
=
[
X_train
],
[
X_train
]
Xvals
,
yvals
=
[
X_val
],
[
X_val
]
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment