diff --git a/source/model_utils.py b/source/model_utils.py
index 1181d34d003db493461f1f6441bb92a2b82d27df..2b1443d8ab785452083880027f2424da1841d44d 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -5,7 +5,6 @@ from keras import backend as K
 from sklearn.metrics import silhouette_score
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
-import numpy as np
 import tensorflow as tf
 import tensorflow_probability as tfp
 
@@ -13,20 +12,27 @@ tfd = tfp.distributions
 tfpl = tfp.layers
 
 # Helper functions
-def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=10000000):
+@tf.function
+def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
     """
     Returns a uniformly initialised matrix in which the columns are as far as possible
     """
-    init_dist = 0
-    for i in range(iters):
-        temp = np.random.uniform(minval, maxval, shape)
-        dist = np.abs(np.linalg.norm(np.diff(temp)))
+
+    init = tf.random.uniform(shape, minval, maxval)
+    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
+    i = 0
+
+    while tf.less(i, iters):
+        temp = tf.random.uniform(shape, minval, maxval)
+        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))
 
         if dist > init_dist:
             init_dist = dist
             init = temp
 
-    return init.astype(np.float32)
+        i += 1
+
+    return init
 
 
 def compute_kernel(x, y):
@@ -40,6 +46,7 @@ def compute_kernel(x, y):
     )
 
 
+@tf.function
 def compute_mmd(tensors):
 
     x = tensors[0]
diff --git a/source/models.py b/source/models.py
index 023d0fd818923cf98a81b13de06daa23b1ec5b24..db2febdec79b6b7db4f0747cb700e5a6dd1b1984 100644
--- a/source/models.py
+++ b/source/models.py
@@ -195,7 +195,7 @@ class SEQ_2_SEQ_GMVAE:
         if self.prior == "standard_normal":
 
             init_means = far_away_uniform_initialiser(
-                [self.number_of_components, self.ENCODING], minval=0, maxval=5
+                shape=[self.number_of_components, self.ENCODING], minval=0, maxval=5
             )
 
             self.prior = tfd.mixture.Mixture(
diff --git a/source/utils.py b/source/utils.py
index fc169e4b10b2fd78845722d99525de2d28b9d5a7..7d789844ef88899da4388ff155ee02021ca06d6c 100644
--- a/source/utils.py
+++ b/source/utils.py
@@ -785,7 +785,9 @@ def cluster_transition_matrix(
     trans_normed = np.zeros([k, k]) + 1e-5
     for t in trans.keys():
         trans_normed[int(t[0]), int(t[1])] = np.round(
-            trans[t] / (sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()) + 1e-5), 3
+            trans[t]
+            / (sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()) + 1e-5),
+            3,
         )
 
     # If specified, returns the transition matrix as an nx.Graph object