Commit 5411ea58 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented shuffle parameter in preprocessing; shuffled validation data in model_training.py

parent bc513ec5
# @author lucasmiranda42
from itertools import combinations
from keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint
......@@ -22,7 +23,11 @@ def compute_kernel(x, y):
)
def compute_mmd(x, y):
def compute_mmd(tensors):
x = tensors[0]
y = tensors[1]
x_kernel = compute_kernel(x, x)
y_kernel = compute_kernel(y, y)
xy_kernel = compute_kernel(x, y)
......@@ -127,7 +132,8 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
class MMDiscrepancyLayer(Layer):
""" Identity transform layer that adds MM discrepancy
"""
Identity transform layer that adds MM discrepancy
to the final model loss.
"""
......@@ -153,10 +159,78 @@ class MMDiscrepancyLayer(Layer):
return z
class Gaussian_mixture_overlap(Layer):
"""
Identity layer that measures the overlap between the components of the latent Gaussian Mixture
using a specified metric (MMD, Wasserstein, Fischer-Rao)
"""
def __init__(
self,
lat_dims,
n_components,
metric="mmd",
loss=False,
samples=100,
*args,
**kwargs
):
self.lat_dims = lat_dims
self.n_components = n_components
self.metric = metric
self.loss = loss
self.samples = samples
super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update({"lat_dims": self.lat_dims})
config.update({"n_components": self.n_components})
config.update({"metric": self.metric})
config.update({"loss": self.loss})
config.update({"samples": self.samples})
return config
def call(self, target, loss=False):
dists = []
for k in range(self.n_components):
locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
dists.append(tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1]))
print(dists)
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
print(dists)
if self.metric == "mmd":
intercomponent_mmd = K.mean(
tf.convert_to_tensor(
[
tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
for c in combinations(range(len(dists)), 2)
],
dtype=tf.float32,
)
)
print(intercomponent_mmd)
self.add_metric(
intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
)
elif self.metric == "wasserstein":
pass
return target
class Latent_space_control(Layer):
""" Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
"""
Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
def __init__(self, *args, **kwargs):
super(Latent_space_control, self).__init__(*args, **kwargs)
......
......@@ -167,6 +167,7 @@ class SEQ_2_SEQ_GMVAE:
prior="standard_normal",
number_of_components=1,
predictor=True,
overlap_metric="mmd",
):
self.input_shape = input_shape
self.CONV_filters = units_conv
......@@ -183,6 +184,7 @@ class SEQ_2_SEQ_GMVAE:
self.mmd_warmup = mmd_warmup_epochs
self.number_of_components = number_of_components
self.predictor = predictor
self.overlap_metric = overlap_metric
if self.prior == "standard_normal":
self.prior = tfd.mixture.Mixture(
......@@ -298,6 +300,10 @@ class SEQ_2_SEQ_GMVAE:
)(encoder)
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z_gauss = Gaussian_mixture_overlap(
self.ENCODING, self.number_of_components, metric=self.overlap_metric
)(z_gauss)
z = tfpl.DistributionLambda(
lambda gauss: tfd.mixture.Mixture(
cat=tfd.categorical.Categorical(probs=gauss[0],),
......@@ -438,10 +444,10 @@ class SEQ_2_SEQ_GMVAE:
# TODO:
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning (done!)
# - Clustering metrics for model selection and aid training (eg early stopping)
# - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics
# - design clustering-conscious hyperparameter tuing pipeline
# - Silhouette / mMMD / Fischer-Mao / Wasserstein
# - design clustering-conscious hyperparameter tuning pipeline
# TODO (in the non-immediate future):
# - Try Bayesian nets!
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment