Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
f78f765f
Commit
f78f765f
authored
Jul 03, 2020
by
lucas_miranda
Browse files
Implemented weight saving callback in model_training.py
parent
b834f5e5
Changes
3
Hide whitespace changes
Inline
Side-by-side
model_training.py
View file @
f78f765f
...
...
@@ -371,10 +371,10 @@ if not variational:
validation_data
=
(
input_dict_val
[
input_type
],
input_dict_val
[
input_type
]),
callbacks
=
[
tensorboard_callback
,
cp_callback
,
tf
.
keras
.
callbacks
.
EarlyStopping
(
"val_mae"
,
patience
=
5
,
restore_best_weights
=
True
),
cp_callback
,
],
)
...
...
@@ -384,6 +384,8 @@ else:
generator
,
grouper
,
gmvaep
,
dead_neuron_rate_callback
,
silhouette_callback
,
kl_warmup_callback
,
mmd_warmup_callback
,
)
=
SEQ_2_SEQ_GMVAE
(
...
...
@@ -401,15 +403,17 @@ else:
callbacks_
=
[
tensorboard_callback
,
cp_callback
,
dead_neuron_rate_callback
,
silhouette_callback
,
tf
.
keras
.
callbacks
.
EarlyStopping
(
"val_mae"
,
patience
=
5
,
restore_best_weights
=
True
),
cp_callback
,
]
if
"ELBO"
in
loss
:
if
"ELBO"
in
loss
and
kl_wu
>
0
:
callbacks_
.
append
(
kl_warmup_callback
)
if
"MMD"
in
loss
:
if
"MMD"
in
loss
and
mmd_wu
>
0
:
callbacks_
.
append
(
mmd_warmup_callback
)
if
not
predictor
:
...
...
source/model_utils.py
View file @
f78f765f
# @author lucasmiranda42
from
keras
import
backend
as
K
from
sklearn.metrics
import
silhouette_score
from
tensorflow.keras.constraints
import
Constraint
from
tensorflow.keras.layers
import
Layer
import
tensorflow
as
tf
...
...
@@ -150,3 +151,26 @@ class MMDiscrepancyLayer(Layer):
self
.
add_metric
(
self
.
beta
,
aggregation
=
"mean"
,
name
=
"mmd_rate"
)
return
z
class
Latent_space_control
(
Layer
):
""" Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
Latent_space_control
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
call
(
self
,
z
,
z_gauss
,
z_cat
,
**
kwargs
):
# Adds metric that monitors dead neurons in the latent space
self
.
add_metric
(
tf
.
math
.
zero_fraction
(
z_gauss
),
aggregation
=
"mean"
,
name
=
"dead_neurons"
)
# Adds Silhouette score controling overlap between clusters
hard_labels
=
tf
.
math
.
argmax
(
z_cat
,
axis
=
1
)
silhouette
=
tf
.
numpy_function
(
silhouette_score
,
[
z
,
hard_labels
],
tf
.
float32
)
self
.
add_metric
(
silhouette
,
aggregation
=
"mean"
,
name
=
"silhouette"
)
return
z
source/models.py
View file @
f78f765f
...
...
@@ -344,6 +344,18 @@ class SEQ_2_SEQ_GMVAE:
z
=
MMDiscrepancyLayer
(
prior
=
self
.
prior
,
beta
=
mmd_beta
)(
z
)
# z = Latent_space_control()(z, z_gauss, z_cat)
# Latent space callback to control dead (zero) dimensions in the latent space
dead_neuron_rate_callback
=
LambdaCallback
(
on_epoch_end
=
lambda
epoch
,
logs
:
tf
.
math
.
zero_fraction
(
z_gauss
)
)
# Latent space callback to control the latent silhouette clustering index
silhouette_callback
=
LambdaCallback
(
on_epoch_end
=
tf
.
numpy_function
(
silhouette_score
,
[
z
,
tf
.
math
.
argmax
(
z_cat
,
axis
=
1
)],
tf
.
float32
)
)
# Define and instantiate generator
generator
=
Model_D1
(
z
)
generator
=
Model_B1
(
generator
)
...
...
@@ -429,6 +441,8 @@ class SEQ_2_SEQ_GMVAE:
generator
,
grouper
,
gmvaep
,
dead_neuron_rate_callback
,
silhouette_callback
,
kl_warmup_callback
,
mmd_warmup_callback
,
)
...
...
@@ -437,7 +451,7 @@ class SEQ_2_SEQ_GMVAE:
# TODO:
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
# - Clustering metrics for model selection and aid training (eg early stopping)
# - Silhouette / likelihood / classifier accuracy metrics
# - Silhouette / likelihood
(AIC / BIC)
/ classifier accuracy metrics
# - design clustering-conscious hyperparameter tuing pipeline
# TODO (in the non-immediate future):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment