Commit 5411ea58 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented shuffle parameter in preprocessing; shuffled validation data in model_training.py

parent bc513ec5
# @author lucasmiranda42
from itertools import combinations
from keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint
......@@ -22,7 +23,11 @@ def compute_kernel(x, y):
)
def compute_mmd(x, y):
def compute_mmd(tensors):
x = tensors[0]
y = tensors[1]
x_kernel = compute_kernel(x, x)
y_kernel = compute_kernel(y, y)
xy_kernel = compute_kernel(x, y)
......@@ -127,7 +132,8 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
class MMDiscrepancyLayer(Layer):
""" Identity transform layer that adds MM discrepancy
"""
Identity transform layer that adds MM discrepancy
to the final model loss.
"""
......@@ -153,8 +159,76 @@ class MMDiscrepancyLayer(Layer):
return z
class Gaussian_mixture_overlap(Layer):
"""
Identity layer that measures the overlap between the components of the latent Gaussian Mixture
using a specified metric (MMD, Wasserstein, Fischer-Rao)
"""
def __init__(
self,
lat_dims,
n_components,
metric="mmd",
loss=False,
samples=100,
*args,
**kwargs
):
self.lat_dims = lat_dims
self.n_components = n_components
self.metric = metric
self.loss = loss
self.samples = samples
super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update({"lat_dims": self.lat_dims})
config.update({"n_components": self.n_components})
config.update({"metric": self.metric})
config.update({"loss": self.loss})
config.update({"samples": self.samples})
return config
def call(self, target, loss=False):
dists = []
for k in range(self.n_components):
locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
dists.append(tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1]))
print(dists)
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
print(dists)
if self.metric == "mmd":
intercomponent_mmd = K.mean(
tf.convert_to_tensor(
[
tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
for c in combinations(range(len(dists)), 2)
],
dtype=tf.float32,
)
)
print(intercomponent_mmd)
self.add_metric(
intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
)
elif self.metric == "wasserstein":
pass
return target
class Latent_space_control(Layer):
""" Identity layer that adds latent space and clustering stats
"""
Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
......
......@@ -167,6 +167,7 @@ class SEQ_2_SEQ_GMVAE:
prior="standard_normal",
number_of_components=1,
predictor=True,
overlap_metric="mmd",
):
self.input_shape = input_shape
self.CONV_filters = units_conv
......@@ -183,6 +184,7 @@ class SEQ_2_SEQ_GMVAE:
self.mmd_warmup = mmd_warmup_epochs
self.number_of_components = number_of_components
self.predictor = predictor
self.overlap_metric = overlap_metric
if self.prior == "standard_normal":
self.prior = tfd.mixture.Mixture(
......@@ -298,6 +300,10 @@ class SEQ_2_SEQ_GMVAE:
)(encoder)
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z_gauss = Gaussian_mixture_overlap(
self.ENCODING, self.number_of_components, metric=self.overlap_metric
)(z_gauss)
z = tfpl.DistributionLambda(
lambda gauss: tfd.mixture.Mixture(
cat=tfd.categorical.Categorical(probs=gauss[0],),
......@@ -438,10 +444,10 @@ class SEQ_2_SEQ_GMVAE:
# TODO:
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning (done!)
# - Clustering metrics for model selection and aid training (eg early stopping)
# - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics
# - design clustering-conscious hyperparameter tuing pipeline
# - Silhouette / mMMD / Fischer-Mao / Wasserstein
# - design clustering-conscious hyperparameter tuning pipeline
# TODO (in the non-immediate future):
# - Try Bayesian nets!
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment