From 70724ab835ba40b04ed60942c7fdc523f3cac433 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Mon, 13 Jul 2020 11:33:26 +0200
Subject: [PATCH] Modified LSTMs to work with cuDNN implementation

---
 source/model_utils.py | 17 +++++++++++++++++
 source/models.py      | 14 +++++++-------
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/source/model_utils.py b/source/model_utils.py
index ecbdf8fb..fd718679 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -5,6 +5,7 @@ from keras import backend as K
 from sklearn.metrics import silhouette_score
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
+import numpy as np
 import tensorflow as tf
 import tensorflow_probability as tfp
 
@@ -12,6 +13,22 @@ tfd = tfp.distributions
 tfpl = tfp.layers
 
 # Helper functions
+def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=1000000):
+    """Return the best of `iters` uniform [minval, maxval) draws of `shape`.
+
+    "Best" maximises the norm of np.diff along the last axis (adjacent columns);
+    the first draw is always kept, so a result exists even if no score exceeds 0.
+    """
+    init = np.random.uniform(minval, maxval, shape)  # baseline: never unbound
+    init_dist = np.linalg.norm(np.diff(init))
+    for _ in range(iters - 1):
+        temp = np.random.uniform(minval, maxval, shape)
+        dist = np.linalg.norm(np.diff(temp))
+        if dist > init_dist:
+            init_dist, init = dist, temp
+    return init.astype(np.float32)
+
+
 def compute_kernel(x, y):
     x_size = K.shape(x)[0]
     y_size = K.shape(y)[0]
diff --git a/source/models.py b/source/models.py
index 95f5a403..ce692b69 100644
--- a/source/models.py
+++ b/source/models.py
@@ -193,21 +193,21 @@ class SEQ_2_SEQ_GMVAE:
         self.overlap_loss = overlap_loss
 
         if self.prior == "standard_normal":
+
+            init_means = far_away_uniform_initialiser(
+                [self.number_of_components, self.ENCODING], minval=0, maxval=15
+            )
+
             self.prior = tfd.mixture.Mixture(
                 cat=tfd.categorical.Categorical(
                     probs=tf.ones(self.number_of_components) / self.number_of_components
                 ),
                 components=[
                     tfd.Independent(
-                        tfd.Normal(
-                            loc=tf.random.uniform(
-                                shape=[self.ENCODING], minval=0, maxval=15
-                            ),
-                            scale=1,
-                        ),
+                        tfd.Normal(loc=init_means[k], scale=1,),
                         reinterpreted_batch_ndims=1,
                     )
-                    for _ in range(self.number_of_components)
+                    for k in range(self.number_of_components)
                 ],
             )
 
-- 
GitLab