Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
4fa12580
Commit
4fa12580
authored
Jul 03, 2020
by
lucas_miranda
Browse files
Implemented weight saving callback in model_training.py
parent
f78f765f
Changes
2
Hide whitespace changes
Inline
Side-by-side
model_training.py
View file @
4fa12580
...
...
@@ -384,8 +384,8 @@ else:
generator
,
grouper
,
gmvaep
,
dead_neuron_rate_callback
,
silhouette_callback
,
#
dead_neuron_rate_callback,
#
silhouette_callback,
kl_warmup_callback
,
mmd_warmup_callback
,
)
=
SEQ_2_SEQ_GMVAE
(
...
...
@@ -404,8 +404,8 @@ else:
callbacks_
=
[
tensorboard_callback
,
cp_callback
,
dead_neuron_rate_callback
,
silhouette_callback
,
#
dead_neuron_rate_callback,
#
silhouette_callback,
tf
.
keras
.
callbacks
.
EarlyStopping
(
"val_mae"
,
patience
=
5
,
restore_best_weights
=
True
),
...
...
source/models.py
View file @
4fa12580
...
...
@@ -344,17 +344,17 @@ class SEQ_2_SEQ_GMVAE:
z
=
MMDiscrepancyLayer
(
prior
=
self
.
prior
,
beta
=
mmd_beta
)(
z
)
#
z = Latent_space_control()(z, z_gauss, z_cat)
# Latent space callback to control dead (zero) dimensions in the latent space
dead_neuron_rate_callback
=
LambdaCallback
(
on_epoch_end
=
lambda
epoch
,
logs
:
tf
.
math
.
zero_fraction
(
z_gauss
)
)
# Latent space callback to control the latent silhouette clustering index
silhouette_callback
=
LambdaCallback
(
on_epoch_end
=
tf
.
numpy_function
(
silhouette_score
,
[
z
,
tf
.
math
.
argmax
(
z_cat
,
axis
=
1
)],
tf
.
float32
)
)
z
=
Latent_space_control
()(
z
,
z_gauss
,
z_cat
)
#
#
Latent space callback to control dead (zero) dimensions in the latent space
#
dead_neuron_rate_callback = LambdaCallback(
#
on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
#
)
#
#
#
Latent space callback to control the latent silhouette clustering index
#
silhouette_callback = LambdaCallback(
#
on_epoch_end=tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32)
#
)
# Define and instantiate generator
generator
=
Model_D1
(
z
)
...
...
@@ -441,8 +441,8 @@ class SEQ_2_SEQ_GMVAE:
generator
,
grouper
,
gmvaep
,
dead_neuron_rate_callback
,
silhouette_callback
,
#
dead_neuron_rate_callback,
#
silhouette_callback,
kl_warmup_callback
,
mmd_warmup_callback
,
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment