Commit b5f14108 authored by lucas_miranda's avatar lucas_miranda
Browse files

Enhanced performance with tf.function decorators

parent fb49fcc8
This diff is collapsed.
# @author lucasmiranda42
from itertools import combinations
from keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import tensorflow as tf
......@@ -36,13 +36,17 @@ def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
def compute_kernel(x, y):
x_size = K.shape(x)[0]
y_size = K.shape(y)[0]
dim = K.shape(x)[1]
tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
return K.exp(
-tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32)
x_size = tf.shape(x)[0]
y_size = tf.shape(y)[0]
dim = tf.shape(x)[1]
tiled_x = tf.tile(
tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
)
tiled_y = tf.tile(
tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
)
return tf.exp(
-tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
)
......@@ -124,10 +128,10 @@ class UncorrelatedFeaturesConstraint(Constraint):
x_centered_list = []
for i in range(self.encoding_dim):
x_centered_list.append(x[:, i] - K.mean(x[:, i]))
x_centered_list.append(x[:, i] - tf.reduce_mean(x[:, i]))
x_centered = tf.stack(x_centered_list)
covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
covariance = tf.tensordot(x_centered, tf.transpose(x_centered)) / tf.cast(
x_centered.get_shape()[0], tf.float32
)
......@@ -138,10 +142,10 @@ class UncorrelatedFeaturesConstraint(Constraint):
if self.encoding_dim <= 1:
return 0.0
else:
output = K.sum(
K.square(
output = tf.reduce_sum.sum(
tf.square(
self.covariance
- tf.math.multiply(self.covariance, K.eye(self.encoding_dim))
- tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
)
)
return output
......@@ -228,7 +232,7 @@ class MMDiscrepancyLayer(Layer):
def call(self, z, **kwargs):
true_samples = self.prior.sample(self.batch_size)
mmd_batch = self.beta * compute_mmd([true_samples, z])
self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_loss(tf.reduce_mean(mmd_batch), inputs=z)
self.add_metric(mmd_batch, aggregation="mean", name="mmd")
self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
......@@ -270,7 +274,7 @@ class Gaussian_mixture_overlap(Layer):
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
### MMD-based overlap ###
intercomponent_mmd = K.mean(
intercomponent_mmd = tf.reduce_mean(
tf.convert_to_tensor(
[
tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
......@@ -322,7 +326,7 @@ class Latent_space_control(Layer):
self.add_metric(silhouette, aggregation="mean", name="silhouette")
if self.loss:
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
self.add_loss(-tf.reduce_mean(silhouette), inputs=[z, hard_labels])
return z
......@@ -344,11 +348,11 @@ class Entropy_regulariser(Layer):
# axis=1 increases the entropy of a cluster across instances
# axis=0 increases the entropy of the assignment for a given instance
entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)
entropy = tf.reduce_sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)
# Adds metric that monitors dead neurons in the latent space
self.add_metric(entropy, aggregation="mean", name="-weight_entropy")
self.add_loss(self.weight * K.sum(entropy), inputs=[z])
self.add_loss(self.weight * tf.reduce_sum(entropy), inputs=[z])
return z
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment