diff --git a/source/model_utils.py b/source/model_utils.py index 1181d34d003db493461f1f6441bb92a2b82d27df..2b1443d8ab785452083880027f2424da1841d44d 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -5,7 +5,6 @@ from keras import backend as K from sklearn.metrics import silhouette_score from tensorflow.keras.constraints import Constraint from tensorflow.keras.layers import Layer -import numpy as np import tensorflow as tf import tensorflow_probability as tfp @@ -13,20 +12,27 @@ tfd = tfp.distributions tfpl = tfp.layers # Helper functions -def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=10000000): +@tf.function +def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000): """ Returns a uniformly initialised matrix in which the columns are as far as possible """ - init_dist = 0 - for i in range(iters): - temp = np.random.uniform(minval, maxval, shape) - dist = np.abs(np.linalg.norm(np.diff(temp))) + + init = tf.random.uniform(shape, minval, maxval) + init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1]))) + i = 0 + + while tf.less(i, iters): + temp = tf.random.uniform(shape, minval, maxval) + dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1]))) if dist > init_dist: init_dist = dist init = temp - return init.astype(np.float32) + i += 1 + + return init def compute_kernel(x, y): @@ -40,6 +46,7 @@ def compute_kernel(x, y): ) +@tf.function def compute_mmd(tensors): x = tensors[0] diff --git a/source/models.py b/source/models.py index 023d0fd818923cf98a81b13de06daa23b1ec5b24..db2febdec79b6b7db4f0747cb700e5a6dd1b1984 100644 --- a/source/models.py +++ b/source/models.py @@ -195,7 +195,7 @@ class SEQ_2_SEQ_GMVAE: if self.prior == "standard_normal": init_means = far_away_uniform_initialiser( - [self.number_of_components, self.ENCODING], minval=0, maxval=5 + shape=[self.number_of_components, self.ENCODING], minval=0, maxval=5 ) self.prior = tfd.mixture.Mixture( diff --git a/source/utils.py b/source/utils.py index fc169e4b10b2fd78845722d99525de2d28b9d5a7..7d789844ef88899da4388ff155ee02021ca06d6c 100644 --- a/source/utils.py +++ b/source/utils.py @@ -785,7 +785,9 @@ def cluster_transition_matrix( trans_normed = np.zeros([k, k]) + 1e-5 for t in trans.keys(): trans_normed[int(t[0]), int(t[1])] = np.round( - trans[t] / (sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()) + 1e-5), 3 + trans[t] + / (sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()) + 1e-5), + 3, ) # If specified, returns the transition matrix as an nx.Graph object