diff --git a/source/model_utils.py b/source/model_utils.py index cfbfa0ccda1396f906dcb47fe74d8ea3d237f4ae..53e7e86a29dd5abb12e335d9672dda11345f93bb 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -1,5 +1,6 @@ # @author lucasmiranda42 +from itertools import combinations from keras import backend as K from sklearn.metrics import silhouette_score from tensorflow.keras.constraints import Constraint @@ -22,7 +23,11 @@ def compute_kernel(x, y): ) -def compute_mmd(x, y): +def compute_mmd(tensors): + + x = tensors[0] + y = tensors[1] + x_kernel = compute_kernel(x, x) y_kernel = compute_kernel(y, y) xy_kernel = compute_kernel(x, y) @@ -127,7 +132,8 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss): class MMDiscrepancyLayer(Layer): - """ Identity transform layer that adds MM discrepancy + """ + Identity transform layer that adds MM discrepancy to the final model loss. """ @@ -153,10 +159,78 @@ class MMDiscrepancyLayer(Layer): return z +class Gaussian_mixture_overlap(Layer): + """ + Identity layer that measures the overlap between the components of the latent Gaussian Mixture + using a specified metric (MMD, Wasserstein, Fischer-Rao) + """ + + def __init__( + self, + lat_dims, + n_components, + metric="mmd", + loss=False, + samples=100, + *args, + **kwargs + ): + self.lat_dims = lat_dims + self.n_components = n_components + self.metric = metric + self.loss = loss + self.samples = samples + super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs) + + def get_config(self): + config = super().get_config().copy() + config.update({"lat_dims": self.lat_dims}) + config.update({"n_components": self.n_components}) + config.update({"metric": self.metric}) + config.update({"loss": self.loss}) + config.update({"samples": self.samples}) + return config + + def call(self, target, loss=False): + + dists = [] + for k in range(self.n_components): + locs = (target[..., : self.lat_dims, k],) + scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k]) + + dists.append(tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])) + + print(dists) + dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists] + print(dists) + + if self.metric == "mmd": + + intercomponent_mmd = K.mean( + tf.convert_to_tensor( + [ + tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]]) + for c in combinations(range(len(dists)), 2) + ], + dtype=tf.float32, + ) + ) + print(intercomponent_mmd) + self.add_metric( + intercomponent_mmd, aggregation="mean", name="intercomponent_mmd" + ) + + elif self.metric == "wasserstein": + pass + + return target + + class Latent_space_control(Layer): - """ Identity layer that adds latent space and clustering stats - to the metrics compiled by the model - """ + """ + Identity layer that adds latent space and clustering stats + to the metrics compiled by the model + """ def __init__(self, *args, **kwargs): super(Latent_space_control, self).__init__(*args, **kwargs) diff --git a/source/models.py b/source/models.py index 587c5a953c490f9460fe60d9365465afea8667d9..21fd0e71edbe6811d0da64e85df19068c6a25961 100644 --- a/source/models.py +++ b/source/models.py @@ -167,6 +167,7 @@ class SEQ_2_SEQ_GMVAE: prior="standard_normal", number_of_components=1, predictor=True, + overlap_metric="mmd", ): self.input_shape = input_shape self.CONV_filters = units_conv @@ -183,6 +184,7 @@ class SEQ_2_SEQ_GMVAE: self.mmd_warmup = mmd_warmup_epochs self.number_of_components = number_of_components self.predictor = predictor + self.overlap_metric = overlap_metric if self.prior == "standard_normal": self.prior = tfd.mixture.Mixture( @@ -298,6 +300,10 @@ class SEQ_2_SEQ_GMVAE: )(encoder) z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss) + z_gauss = Gaussian_mixture_overlap( + self.ENCODING, self.number_of_components, metric=self.overlap_metric + )(z_gauss) + z = tfpl.DistributionLambda( lambda gauss: tfd.mixture.Mixture( cat=tfd.categorical.Categorical(probs=gauss[0],), @@ -438,10 +444,10 @@ class SEQ_2_SEQ_GMVAE: # TODO: -# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning +# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning (done!) # - Clustering metrics for model selection and aid training (eg early stopping) -# - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics -# - design clustering-conscious hyperparameter tuing pipeline +# - Silhouette / mMMD / Fischer-Mao / Wasserstein +# - design clustering-conscious hyperparameter tuning pipeline # TODO (in the non-immediate future): # - Try Bayesian nets!