def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=1000000):
    """Return the most spread-out of several uniformly sampled arrays.

    Draws ``iters`` candidate arrays from ``U(minval, maxval)`` and keeps the
    one whose first-order differences along the last axis (``np.diff``) have
    the largest norm. On ties the earliest candidate is kept.

    Args:
        shape: Shape of the array to generate (anything accepted by
            ``np.random.uniform`` as ``size``).
        minval: Lower bound of the uniform distribution.
        maxval: Upper bound of the uniform distribution.
        iters: Number of random candidates to evaluate; must be at least 1.

    Returns:
        The selected candidate as a ``float32`` NumPy array of shape ``shape``.

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1, got {}".format(iters))

    # NOTE(review): np.diff works along the LAST axis, i.e. it compares
    # adjacent columns of a 2-D array. If the rows are the vectors that should
    # be far apart, the axis may need changing -- confirm against the caller.
    best = None
    best_dist = -1.0  # below any norm, so the first candidate is always kept

    for _ in range(iters):
        candidate = np.random.uniform(minval, maxval, shape)
        # A norm is already non-negative, so no extra abs() is required.
        dist = np.linalg.norm(np.diff(candidate))

        # Strict comparison keeps the earliest candidate among equal scores.
        if dist > best_dist:
            best_dist = dist
            best = candidate

    return best.astype(np.float32)