Commit 5d154b87 authored by lucas_miranda

Enhanced performance with tf.function decorators

parent b5f14108
```diff
 # @author lucasmiranda42
 from itertools import combinations
 from sklearn.metrics import silhouette_score
+from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
 import tensorflow as tf
```
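The hunks below swap raw TensorFlow reductions for their Keras backend counterparts. As a sanity check (my own sketch, not part of the commit), the two APIs agree numerically, including under a `tf.function` trace:

```python
import tensorflow as tf
from tensorflow.keras import backend as K

@tf.function  # traces to the same graph ops either way
def backend_stats(x):
    return K.mean(x), K.sum(x)

x = tf.random.normal((8, 4))
m, s = backend_stats(x)
tf.debugging.assert_near(m, tf.reduce_mean(x))  # K.mean == tf.reduce_mean
tf.debugging.assert_near(s, tf.reduce_sum(x))   # K.sum  == tf.reduce_sum
```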
```diff
@@ -128,10 +128,10 @@ class UncorrelatedFeaturesConstraint(Constraint):
         x_centered_list = []
         for i in range(self.encoding_dim):
-            x_centered_list.append(x[:, i] - tf.reduce_mean(x[:, i]))
+            x_centered_list.append(x[:, i] - K.mean(x[:, i]))
         x_centered = tf.stack(x_centered_list)
-        covariance = tf.tensordot(x_centered, tf.transpose(x_centered), axes=1) / tf.cast(
+        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
             x_centered.get_shape()[0], tf.float32
         )
```
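For reference, this is what the changed lines compute once assembled: each latent dimension is centered, and the Gram matrix of the centered features gives the covariance the constraint later penalises. A minimal standalone sketch (my own, with a hypothetical `feature_covariance` helper):

```python
import tensorflow as tf
from tensorflow.keras import backend as K

def feature_covariance(x, encoding_dim):
    # x: (batch, encoding_dim) encoder activations
    x_centered_list = [x[:, i] - K.mean(x[:, i]) for i in range(encoding_dim)]
    x_centered = tf.stack(x_centered_list)  # (encoding_dim, batch)
    # Gram matrix of the centered features, normalised by the leading
    # dimension exactly as in the hunk above
    return K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
        x_centered.get_shape()[0], tf.float32
    )
```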
```diff
@@ -142,8 +142,8 @@ class UncorrelatedFeaturesConstraint(Constraint):
         if self.encoding_dim <= 1:
             return 0.0
         else:
-            output = tf.reduce_sum(
-                tf.square(
+            output = K.sum(
+                K.square(
                     self.covariance
                     - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                 )
```
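The penalty itself zeroes the diagonal (each feature's variance with itself) and sums the squared off-diagonal covariances, so only correlation between distinct features is punished. A hedged sketch of that term in isolation:

```python
import tensorflow as tf
from tensorflow.keras import backend as K

def uncorrelated_feature_penalty(covariance, encoding_dim):
    if encoding_dim <= 1:
        return 0.0  # a single feature cannot correlate with another
    # keep only the off-diagonal entries, then sum their squares
    off_diagonal = covariance - tf.math.multiply(covariance, tf.eye(encoding_dim))
    return K.sum(K.square(off_diagonal))
```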
```diff
@@ -232,7 +232,7 @@ class MMDiscrepancyLayer(Layer):
     def call(self, z, **kwargs):
         true_samples = self.prior.sample(self.batch_size)
         mmd_batch = self.beta * compute_mmd([true_samples, z])
-        self.add_loss(tf.reduce_mean(mmd_batch), inputs=z)
+        self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
         self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
```
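`compute_mmd` is defined elsewhere in this module and not shown in the diff. For readers, a common Gaussian-kernel MMD estimator has the shape sketched below; this is an illustrative guess, not the repository's exact implementation:

```python
import tensorflow as tf

def gaussian_kernel(x, y):
    # pairwise squared distances via broadcasting: (n, 1, d) vs (1, m, d)
    sq_dist = tf.reduce_sum(tf.square(x[:, None, :] - y[None, :, :]), axis=2)
    dim = tf.cast(tf.shape(x)[1], tf.float32)
    return tf.exp(-sq_dist / (2.0 * dim))

def compute_mmd_sketch(tensors):
    x, y = tensors
    # MMD^2 = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)]
    return (
        tf.reduce_mean(gaussian_kernel(x, x))
        + tf.reduce_mean(gaussian_kernel(y, y))
        - 2.0 * tf.reduce_mean(gaussian_kernel(x, y))
    )
```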
```diff
@@ -274,7 +274,7 @@ class Gaussian_mixture_overlap(Layer):
         dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

         ### MMD-based overlap ###
-        intercomponent_mmd = tf.reduce_mean(
+        intercomponent_mmd = K.mean(
             tf.convert_to_tensor(
                 [
                     tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
```
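The hunk is truncated after the `tf.vectorized_map` line; given the `itertools.combinations` import at the top of the file, the elided comprehension presumably enumerates every unordered pair of mixture components. A hypothetical reconstruction of that loop, for orientation only:

```python
from itertools import combinations

import tensorflow as tf

def pairwise_component_mmd(dists, n_components, compute_mmd):
    # dists[c]: (batch, samples, dim) draws from mixture component c
    return [
        tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
        for c in combinations(range(n_components), 2)
    ]
```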
```diff
@@ -326,7 +326,7 @@ class Latent_space_control(Layer):
         self.add_metric(silhouette, aggregation="mean", name="silhouette")

         if self.loss:
-            self.add_loss(-tf.reduce_mean(silhouette), inputs=[z, hard_labels])
+            self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])

         return z
```
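`silhouette` is produced upstream of this hunk from sklearn's metric, which runs on numpy arrays. A minimal sketch of how such a score can be brought into the graph (an assumption on my part, using `tf.numpy_function`; not code from this commit):

```python
import numpy as np
import tensorflow as tf
from sklearn.metrics import silhouette_score

def tf_silhouette(z, hard_labels):
    # wraps the numpy-based sklearn metric so it can be registered as a
    # loss/metric value; note tf.numpy_function ops provide no gradient
    def _score(z_np, labels_np):
        return np.float32(silhouette_score(z_np, labels_np))
    return tf.numpy_function(_score, [z, hard_labels], tf.float32)
```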
```diff
@@ -348,11 +348,11 @@ class Entropy_regulariser(Layer):
         # axis=1 increases the entropy of a cluster across instances
         # axis=0 increases the entropy of the assignment for a given instance
-        entropy = tf.reduce_sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)
+        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)

         # Adds metric that monitors dead neurons in the latent space
         self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

-        self.add_loss(self.weight * tf.reduce_sum(entropy), inputs=[z])
+        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

         return z
```
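For context, the regularised term as a self-contained function. Note that the hunk adds the `1e-5` guard outside `tf.math.log`; `tf.math.log(z + 1e-5)` would be the more conventional way to keep the logarithm finite at z == 0, but the expression is kept as-is here to match the diff:

```python
import tensorflow as tf
from tensorflow.keras import backend as K

def soft_assignment_entropy(z):
    # z: (batch, n_components) soft cluster assignments
    # axis=1 collapses components, yielding one (negative) entropy per instance
    return K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)
```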