Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
d6ad5d0f
Commit
d6ad5d0f
authored
Jul 01, 2020
by
lucas_miranda
Browse files
Implemented weight saving callback in model_training.py
parent
2cf52af5
Changes
1
Hide whitespace changes
Inline
Side-by-side
model_training.py
View file @
d6ad5d0f
...
...
@@ -311,7 +311,9 @@ if not variational:
validation_data
=
(
input_dict_val
[
input_type
],
input_dict_val
[
input_type
]),
callbacks
=
[
tensorboard_callback
,
tf
.
keras
.
callbacks
.
EarlyStopping
(
"val_mae"
,
patience
=
5
,
restore_best_weights
=
True
),
tf
.
keras
.
callbacks
.
EarlyStopping
(
"val_mae"
,
patience
=
5
,
restore_best_weights
=
True
),
cp_callback
,
],
)
...
...
@@ -322,8 +324,8 @@ else:
generator
,
grouper
,
gmvaep
,
(
mmd_warmup_callback
if
"MMD"
in
loss
else
None
)
,
(
kl_warmup_callback
if
"ELBO"
in
loss
else
None
)
,
mmd_warmup_callback
,
kl_warmup_callback
,
)
=
SEQ_2_SEQ_GMVAE
(
input_dict_train
[
input_type
].
shape
,
loss
=
loss
,
...
...
@@ -337,6 +339,19 @@ else:
print
(
gmvaep
.
summary
())
callbacks_
=
[
tensorboard_callback
,
tf
.
keras
.
callbacks
.
EarlyStopping
(
"val_mae"
,
patience
=
5
,
restore_best_weights
=
True
),
cp_callback
,
]
if
"ELBO"
in
loss
:
callbacks_
.
append
(
kl_warmup_callback
)
if
"MMD"
in
loss
:
callbacks_
.
append
(
mmd_warmup_callback
)
if
not
predictor
:
history
=
gmvaep
.
fit
(
x
=
input_dict_train
[
input_type
],
...
...
@@ -345,13 +360,7 @@ else:
batch_size
=
512
,
verbose
=
1
,
validation_data
=
(
input_dict_val
[
input_type
],
input_dict_val
[
input_type
]),
callbacks
=
[
tensorboard_callback
,
(
mmd_warmup_callback
if
"MMD"
in
loss
else
None
),
(
kl_warmup_callback
if
"ELBO"
in
loss
else
None
),
tf
.
keras
.
callbacks
.
EarlyStopping
(
"val_mae"
,
patience
=
5
,
restore_best_weights
=
True
),
cp_callback
,
],
callbacks
=
callbacks_
,
)
else
:
history
=
gmvaep
.
fit
(
...
...
@@ -364,13 +373,5 @@ else:
input_dict_val
[
input_type
][:
-
1
],
[
input_dict_val
[
input_type
][:
-
1
],
input_dict_val
[
input_type
][
1
:]],
),
callbacks
=
[
tensorboard_callback
,
(
mmd_warmup_callback
if
"MMD"
in
loss
else
None
),
(
kl_warmup_callback
if
"ELBO"
in
loss
else
None
),
tf
.
keras
.
callbacks
.
EarlyStopping
(
"val_mae"
,
patience
=
5
,
restore_best_weights
=
True
),
cp_callback
,
],
callbacks
=
callbacks_
,
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment