Commit 70724ab8 authored by lucas_miranda's avatar lucas_miranda
Browse files

Modified LSTMs to work with cuDNN implementation

parent 93fa64b0
...@@ -5,6 +5,7 @@ from keras import backend as K ...@@ -5,6 +5,7 @@ from keras import backend as K
from sklearn.metrics import silhouette_score from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer from tensorflow.keras.layers import Layer
import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow_probability as tfp import tensorflow_probability as tfp
...@@ -12,6 +13,22 @@ tfd = tfp.distributions ...@@ -12,6 +13,22 @@ tfd = tfp.distributions
tfpl = tfp.layers tfpl = tfp.layers
# Helper functions # Helper functions
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=1000000):
"""
Returns a uniformly initialised matrix in which the columns are as far as possible
"""
init_dist = 0
for i in range(iters):
temp = np.random.uniform(minval, maxval, shape)
dist = np.abs(np.linalg.norm(np.diff(temp)))
if dist > init_dist:
init_dist = dist
init = temp
return init.astype(np.float32)
def compute_kernel(x, y): def compute_kernel(x, y):
x_size = K.shape(x)[0] x_size = K.shape(x)[0]
y_size = K.shape(y)[0] y_size = K.shape(y)[0]
......
...@@ -193,21 +193,21 @@ class SEQ_2_SEQ_GMVAE: ...@@ -193,21 +193,21 @@ class SEQ_2_SEQ_GMVAE:
self.overlap_loss = overlap_loss self.overlap_loss = overlap_loss
if self.prior == "standard_normal": if self.prior == "standard_normal":
init_means = far_away_uniform_initialiser(
[self.number_of_components, self.ENCODING], minval=0, maxval=15
)
self.prior = tfd.mixture.Mixture( self.prior = tfd.mixture.Mixture(
cat=tfd.categorical.Categorical( cat=tfd.categorical.Categorical(
probs=tf.ones(self.number_of_components) / self.number_of_components probs=tf.ones(self.number_of_components) / self.number_of_components
), ),
components=[ components=[
tfd.Independent( tfd.Independent(
tfd.Normal( tfd.Normal(loc=init_means[k], scale=1,),
loc=tf.random.uniform(
shape=[self.ENCODING], minval=0, maxval=15
),
scale=1,
),
reinterpreted_batch_ndims=1, reinterpreted_batch_ndims=1,
) )
for _ in range(self.number_of_components) for k in range(self.number_of_components)
], ],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment