Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
f851531e
Commit
f851531e
authored
Jul 01, 2020
by
lucas_miranda
Browse files
Implemented weight saving callback in model_training.py
parent
9ce11008
Changes
2
Hide whitespace changes
Inline
Side-by-side
model_training.py
View file @
f851531e
...
...
@@ -288,22 +288,52 @@ input_dict_train = {
input_dict_val
=
{
"coords"
:
coords2
.
preprocess
(
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gauss"
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gaussian"
,
sigma
=
110
,
),
"dists"
:
distances2
.
preprocess
(
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gauss"
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gaussian"
,
sigma
=
110
,
),
"angles"
:
angles2
.
preprocess
(
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gauss"
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gaussian"
,
sigma
=
110
,
),
"coords+dist"
:
coords_distances2
.
preprocess
(
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gauss"
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gaussian"
,
sigma
=
110
,
),
"coords+angle"
:
coords_angles2
.
preprocess
(
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gauss"
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gaussian"
,
sigma
=
110
,
),
"coords+dist+angle"
:
coords_dist_angles2
.
preprocess
(
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gauss"
window_size
=
11
,
window_step
=
1
,
scale
=
True
,
random_state
=
42
,
filter
=
"gaussian"
,
sigma
=
110
,
),
}
...
...
source/models.py
View file @
f851531e
...
...
@@ -297,19 +297,6 @@ class SEQ_2_SEQ_GMVAE:
activation
=
None
,
)(
encoder
)
# Define and control custom loss functions
kl_warmup_callback
=
False
if
"ELBO"
in
self
.
loss
:
kl_beta
=
K
.
variable
(
1.0
,
name
=
"kl_beta"
)
kl_beta
.
_trainable
=
False
if
self
.
kl_warmup
:
kl_warmup_callback
=
LambdaCallback
(
on_epoch_begin
=
lambda
epoch
,
logs
:
K
.
set_value
(
kl_beta
,
K
.
min
([
epoch
/
self
.
kl_warmup
,
1
])
)
)
z_gauss
=
Reshape
([
2
*
self
.
ENCODING
,
self
.
number_of_components
])(
z_gauss
)
z
=
tfpl
.
DistributionLambda
(
lambda
gauss
:
tfd
.
mixture
.
Mixture
(
...
...
@@ -328,7 +315,19 @@ class SEQ_2_SEQ_GMVAE:
activity_regularizer
=
UncorrelatedFeaturesConstraint
(
3
,
weightage
=
1.0
),
)([
z_cat
,
z_gauss
])
# Define and control custom loss functions
kl_warmup_callback
=
False
if
"ELBO"
in
self
.
loss
:
kl_beta
=
K
.
variable
(
1.0
,
name
=
"kl_beta"
)
kl_beta
.
_trainable
=
False
if
self
.
kl_warmup
:
kl_warmup_callback
=
LambdaCallback
(
on_epoch_begin
=
lambda
epoch
,
logs
:
K
.
set_value
(
kl_beta
,
K
.
min
([
epoch
/
self
.
kl_warmup
,
1
])
)
)
z
=
KLDivergenceLayer
(
self
.
prior
,
weight
=
kl_beta
)(
z
)
mmd_warmup_callback
=
False
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment