Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
5411ea58
Commit
5411ea58
authored
Jul 06, 2020
by
lucas_miranda
Browse files
Implemented shuffle parameter in preprocessing; shuffled validation data in model_training.py
parent
bc513ec5
Changes
2
Show whitespace changes
Inline
Side-by-side
source/model_utils.py
View file @
5411ea58
# @author lucasmiranda42
from
itertools
import
combinations
from
keras
import
backend
as
K
from
sklearn.metrics
import
silhouette_score
from
tensorflow.keras.constraints
import
Constraint
...
...
@@ -22,7 +23,11 @@ def compute_kernel(x, y):
)
def
compute_mmd
(
x
,
y
):
def
compute_mmd
(
tensors
):
x
=
tensors
[
0
]
y
=
tensors
[
1
]
x_kernel
=
compute_kernel
(
x
,
x
)
y_kernel
=
compute_kernel
(
y
,
y
)
xy_kernel
=
compute_kernel
(
x
,
y
)
...
...
@@ -127,7 +132,8 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
class
MMDiscrepancyLayer
(
Layer
):
""" Identity transform layer that adds MM discrepancy
"""
Identity transform layer that adds MM discrepancy
to the final model loss.
"""
...
...
@@ -153,8 +159,76 @@ class MMDiscrepancyLayer(Layer):
return
z
class
Gaussian_mixture_overlap
(
Layer
):
"""
Identity layer that measures the overlap between the components of the latent Gaussian Mixture
using a specified metric (MMD, Wasserstein, Fischer-Rao)
"""
def
__init__
(
self
,
lat_dims
,
n_components
,
metric
=
"mmd"
,
loss
=
False
,
samples
=
100
,
*
args
,
**
kwargs
):
self
.
lat_dims
=
lat_dims
self
.
n_components
=
n_components
self
.
metric
=
metric
self
.
loss
=
loss
self
.
samples
=
samples
super
(
Gaussian_mixture_overlap
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
get_config
(
self
):
config
=
super
().
get_config
().
copy
()
config
.
update
({
"lat_dims"
:
self
.
lat_dims
})
config
.
update
({
"n_components"
:
self
.
n_components
})
config
.
update
({
"metric"
:
self
.
metric
})
config
.
update
({
"loss"
:
self
.
loss
})
config
.
update
({
"samples"
:
self
.
samples
})
return
config
def
call
(
self
,
target
,
loss
=
False
):
dists
=
[]
for
k
in
range
(
self
.
n_components
):
locs
=
(
target
[...,
:
self
.
lat_dims
,
k
],)
scales
=
tf
.
keras
.
activations
.
softplus
(
target
[...,
self
.
lat_dims
:,
k
])
dists
.
append
(
tfd
.
BatchReshape
(
tfd
.
MultivariateNormalDiag
(
locs
,
scales
),
[
-
1
]))
print
(
dists
)
dists
=
[
tf
.
transpose
(
gauss
.
sample
(
self
.
samples
),
[
1
,
0
,
2
])
for
gauss
in
dists
]
print
(
dists
)
if
self
.
metric
==
"mmd"
:
intercomponent_mmd
=
K
.
mean
(
tf
.
convert_to_tensor
(
[
tf
.
vectorized_map
(
compute_mmd
,
[
dists
[
c
[
0
]],
dists
[
c
[
1
]]])
for
c
in
combinations
(
range
(
len
(
dists
)),
2
)
],
dtype
=
tf
.
float32
,
)
)
print
(
intercomponent_mmd
)
self
.
add_metric
(
intercomponent_mmd
,
aggregation
=
"mean"
,
name
=
"intercomponent_mmd"
)
elif
self
.
metric
==
"wasserstein"
:
pass
return
target
class
Latent_space_control
(
Layer
):
""" Identity layer that adds latent space and clustering stats
"""
Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
...
...
source/models.py
View file @
5411ea58
...
...
@@ -167,6 +167,7 @@ class SEQ_2_SEQ_GMVAE:
prior
=
"standard_normal"
,
number_of_components
=
1
,
predictor
=
True
,
overlap_metric
=
"mmd"
,
):
self
.
input_shape
=
input_shape
self
.
CONV_filters
=
units_conv
...
...
@@ -183,6 +184,7 @@ class SEQ_2_SEQ_GMVAE:
self
.
mmd_warmup
=
mmd_warmup_epochs
self
.
number_of_components
=
number_of_components
self
.
predictor
=
predictor
self
.
overlap_metric
=
overlap_metric
if
self
.
prior
==
"standard_normal"
:
self
.
prior
=
tfd
.
mixture
.
Mixture
(
...
...
@@ -298,6 +300,10 @@ class SEQ_2_SEQ_GMVAE:
)(
encoder
)
z_gauss
=
Reshape
([
2
*
self
.
ENCODING
,
self
.
number_of_components
])(
z_gauss
)
z_gauss
=
Gaussian_mixture_overlap
(
self
.
ENCODING
,
self
.
number_of_components
,
metric
=
self
.
overlap_metric
)(
z_gauss
)
z
=
tfpl
.
DistributionLambda
(
lambda
gauss
:
tfd
.
mixture
.
Mixture
(
cat
=
tfd
.
categorical
.
Categorical
(
probs
=
gauss
[
0
],),
...
...
@@ -438,10 +444,10 @@ class SEQ_2_SEQ_GMVAE:
# TODO:
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
(done!)
# - Clustering metrics for model selection and aid training (eg early stopping)
# - Silhouette /
likelihood (AIC / BIC)
/
cl
ass
ifier accuracy metrics
# - design clustering-conscious hyperparameter tuing pipeline
# - Silhouette /
mMMD / Fischer-Mao
/
W
ass
erstein
# - design clustering-conscious hyperparameter tu
n
ing pipeline
# TODO (in the non-immediate future):
# - Try Bayesian nets!
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment