# @author lucasmiranda42 from keras import backend as K from tensorflow.keras.constraints import Constraint from tensorflow.keras.layers import Layer import tensorflow as tf # Helper functions def sampling(args, epsilon_std=1.0): z_mean, z_log_sigma = args epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std) return z_mean + K.exp(z_log_sigma) * epsilon def compute_kernel(x, y): x_size = K.shape(x)[0] y_size = K.shape(y)[0] dim = K.shape(x)[1] tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1])) tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1])) return K.exp( -tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32) ) def compute_mmd(x, y): x_kernel = compute_kernel(x, x) y_kernel = compute_kernel(y, y) xy_kernel = compute_kernel(x, y) return ( tf.reduce_mean(x_kernel) + tf.reduce_mean(y_kernel) - 2 * tf.reduce_mean(xy_kernel) ) # Custom layers for efficiency/losses class DenseTranspose(Layer): def __init__(self, dense, output_dim, activation=None, **kwargs): self.dense = dense self.output_dim = output_dim self.activation = tf.keras.activations.get(activation) super().__init__(**kwargs) def get_config(self): config = super().get_config().copy() config.update( { "dense": self.dense, "output_dim": self.output_dim, "activation": self.activation, } ) return config def build(self, batch_input_shape): self.biases = self.add_weight( name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros" ) super().build(batch_input_shape) def call(self, inputs, **kwargs): z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True) return self.activation(z + self.biases) def compute_output_shape(self, input_shape): return input_shape[0], self.output_dim class UncorrelatedFeaturesConstraint(Constraint): def __init__(self, encoding_dim, weightage=1.0): self.encoding_dim = encoding_dim self.weightage = weightage def get_config(self): config = super().get_config().copy() config.update( {"encoding_dim": self.encoding_dim, "weightage": self.weightage,} ) return config def get_covariance(self, x): x_centered_list = [] for i in range(self.encoding_dim): x_centered_list.append(x[:, i] - K.mean(x[:, i])) x_centered = tf.stack(x_centered_list) covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast( x_centered.get_shape()[0], tf.float32 ) return covariance # Constraint penalty def uncorrelated_feature(self, x): if self.encoding_dim <= 1: return 0.0 else: output = K.sum( K.square( self.covariance - tf.math.multiply(self.covariance, K.eye(self.encoding_dim)) ) ) return output def __call__(self, x): self.covariance = self.get_covariance(x) return self.weightage * self.uncorrelated_feature(x) class KLDivergenceLayer(Layer): """ Identity transform layer that adds KL divergence to the final model loss. """ def __init__(self, *args, **kwargs): self.is_placeholder = True super(KLDivergenceLayer, self).__init__(*args, **kwargs) def call(self, inputs, **kwargs): mu, log_var = inputs kl_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1) self.add_loss(K.mean(kl_batch), inputs=inputs) return inputs class MMDiscrepancyLayer(Layer): """ Identity transform layer that adds MM discrepancy to the final model loss. """ def __init__(self, *args, **kwargs): self.is_placeholder = True super(MMDiscrepancyLayer, self).__init__(*args, **kwargs) def call(self, z, **kwargs): true_samples = K.random_normal( K.shape(z), mean=0.0, stddev=2.0 / K.cast_to_floatx(K.shape(z)[1]) ) mmd_batch = compute_mmd(z, true_samples) self.add_loss(K.mean(mmd_batch), inputs=z) return z