Commit 7663405d authored by lucas_miranda's avatar lucas_miranda
Browse files

Enhanced performance with tf.function decorators

parent f5130f60
......@@ -5,7 +5,6 @@ from keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
......@@ -13,20 +12,27 @@ tfd = tfp.distributions
tfpl = tfp.layers
# Helper functions
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=10000000):
@tf.function
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
"""
Returns a uniformly initialised matrix in which the columns are as far as possible
"""
init_dist = 0
for i in range(iters):
temp = np.random.uniform(minval, maxval, shape)
dist = np.abs(np.linalg.norm(np.diff(temp)))
init = tf.random.uniform(shape, minval, maxval)
init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
i = 0
while tf.less(i, iters):
temp = tf.random.uniform(shape, minval, maxval)
dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))
if dist > init_dist:
init_dist = dist
init = temp
return init.astype(np.float32)
i += 1
return init
def compute_kernel(x, y):
......@@ -40,6 +46,7 @@ def compute_kernel(x, y):
)
@tf.function
def compute_mmd(tensors):
x = tensors[0]
......
......@@ -195,7 +195,7 @@ class SEQ_2_SEQ_GMVAE:
if self.prior == "standard_normal":
init_means = far_away_uniform_initialiser(
[self.number_of_components, self.ENCODING], minval=0, maxval=5
shape=[self.number_of_components, self.ENCODING], minval=0, maxval=5
)
self.prior = tfd.mixture.Mixture(
......
......@@ -785,7 +785,9 @@ def cluster_transition_matrix(
trans_normed = np.zeros([k, k]) + 1e-5
for t in trans.keys():
trans_normed[int(t[0]), int(t[1])] = np.round(
trans[t] / (sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()) + 1e-5), 3
trans[t]
/ (sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()) + 1e-5),
3,
)
# If specified, returns the transition matrix as an nx.Graph object
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment