Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
2b01ccdd
Commit
2b01ccdd
authored
Jul 06, 2020
by
lucas_miranda
Browse files
Implemented shuffle parameter in preprocessing; shuffled validation data in model_training.py
parent
e46d76a7
Changes
2
Hide whitespace changes
Inline
Side-by-side
source/model_utils.py
View file @
2b01ccdd
...
...
@@ -220,6 +220,9 @@ class Gaussian_mixture_overlap(Layer):
intercomponent_mmd
,
aggregation
=
"mean"
,
name
=
"intercomponent_mmd"
)
if
self
.
loss
:
self
.
add_loss
(
-
intercomponent_mmd
,
inputs
=
[
target
])
elif
self
.
metric
==
"wasserstein"
:
pass
...
...
@@ -232,9 +235,14 @@ class Latent_space_control(Layer):
to the metrics compiled by the model
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
loss
=
False
,
*
args
,
**
kwargs
):
self
.
loss
=
loss
super
(
Latent_space_control
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
get_config
(
self
):
config
=
super
().
get_config
().
copy
()
config
.
update
({
"loss"
:
self
.
loss
})
def
call
(
self
,
z
,
z_gauss
,
z_cat
,
**
kwargs
):
# Adds metric that monitors dead neurons in the latent space
...
...
@@ -245,7 +253,9 @@ class Latent_space_control(Layer):
# Adds Silhouette score controling overlap between clusters
hard_labels
=
tf
.
math
.
argmax
(
z_cat
,
axis
=
1
)
silhouette
=
tf
.
numpy_function
(
silhouette_score
,
[
z
,
hard_labels
],
tf
.
float32
)
self
.
add_loss
(
-
K
.
mean
(
silhouette
),
inputs
=
[
z
,
hard_labels
])
self
.
add_metric
(
silhouette
,
aggregation
=
"mean"
,
name
=
"silhouette"
)
if
self
.
loss
:
self
.
add_loss
(
-
K
.
mean
(
silhouette
),
inputs
=
[
z
,
hard_labels
])
return
z
source/models.py
View file @
2b01ccdd
...
...
@@ -168,6 +168,7 @@ class SEQ_2_SEQ_GMVAE:
number_of_components
=
1
,
predictor
=
True
,
overlap_metric
=
"mmd"
,
overlap_loss
=
False
,
):
self
.
input_shape
=
input_shape
self
.
CONV_filters
=
units_conv
...
...
@@ -185,6 +186,7 @@ class SEQ_2_SEQ_GMVAE:
self
.
number_of_components
=
number_of_components
self
.
predictor
=
predictor
self
.
overlap_metric
=
overlap_metric
self
.
overlap_loss
=
overlap_loss
if
self
.
prior
==
"standard_normal"
:
self
.
prior
=
tfd
.
mixture
.
Mixture
(
...
...
@@ -301,7 +303,10 @@ class SEQ_2_SEQ_GMVAE:
z_gauss
=
Reshape
([
2
*
self
.
ENCODING
,
self
.
number_of_components
])(
z_gauss
)
z_gauss
=
Gaussian_mixture_overlap
(
self
.
ENCODING
,
self
.
number_of_components
,
metric
=
self
.
overlap_metric
self
.
ENCODING
,
self
.
number_of_components
,
metric
=
self
.
overlap_metric
,
loss
=
self
.
overlap_loss
,
)(
z_gauss
)
z
=
tfpl
.
DistributionLambda
(
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment