# @author lucasmiranda42

from itertools import combinations
from tensorflow.keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=10000000):
    """
    Returns a uniformly initialised matrix in which the columns are as far as possible
    """
    init_dist = 0
    for i in range(iters):
        temp = np.random.uniform(minval, maxval, shape)
        dist = np.abs(np.linalg.norm(np.diff(temp)))

        if dist > init_dist:
            init_dist = dist
            init = temp

    return init.astype(np.float32)
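

# Hypothetical usage sketch (not part of the original module): draws a small
# initialisation matrix with columns pushed apart; the shape and the reduced
# `iters` value below are arbitrary example choices.
def _sketch_far_away_initialiser():
    init = far_away_uniform_initialiser(shape=(2, 4), minval=0, maxval=15, iters=1000)
    assert init.shape == (2, 4) and init.dtype == np.float32
    return init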


def compute_kernel(x, y):
    """Returns the Gaussian (RBF) kernel matrix between two sets of samples"""
    x_size = K.shape(x)[0]
    y_size = K.shape(y)[0]
    dim = K.shape(x)[1]
    tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
    tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
    return K.exp(
        -tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32)
    )


def compute_mmd(tensors):
    """Computes the Maximum Mean Discrepancy (MMD) between the two sets of samples
    passed in the [x, y] tuple"""

    x = tensors[0]
    y = tensors[1]
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
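

# Hypothetical usage sketch (not part of the original module): MMD between two
# sample sets; values close to zero suggest the samples come from the same
# distribution. Sample counts and dimensions are arbitrary example values.
def _sketch_compute_mmd():
    x = tf.random.normal([100, 8])
    y = tf.random.normal([100, 8], mean=3.0)
    same = compute_mmd([x, x])  # close to zero
    different = compute_mmd([x, y])  # noticeably larger than `same`
    return same, different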


# Custom layers for efficiency/losses
class MCDropout(tf.keras.layers.Dropout):
    """Dropout layer that remains active at inference time (Monte Carlo dropout)"""

    def call(self, inputs, **kwargs):
        # Force training mode so that dropout is also applied when predicting
        return super().call(inputs, training=True)
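

# Hypothetical usage sketch (not part of the original module): with MCDropout,
# repeated forward passes on the same input give different outputs, whose spread
# can be read as a rough uncertainty estimate. Layer sizes are example values.
def _sketch_mc_dropout():
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)), MCDropout(0.5)]
    )
    x = tf.random.normal([4, 8])
    passes = tf.stack([model(x) for _ in range(10)])  # 10 stochastic passes
    return tf.math.reduce_std(passes, axis=0)  # per-unit predictive spread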

class DenseTranspose(Layer):
    """Dense layer whose weights are tied to (the transpose of) a given Dense layer,
    as used in weight-sharing autoencoders"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
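

# Hypothetical usage sketch (not part of the original module): a tied-weights
# autoencoder in which the decoder reuses the transposed kernel of the encoder's
# Dense layer. All dimensions below are arbitrary example values.
def _sketch_tied_autoencoder(input_dim=30, latent_dim=8):
    encoder_dense = tf.keras.layers.Dense(latent_dim, activation="relu")
    inputs = tf.keras.Input(shape=(input_dim,))
    encoded = encoder_dense(inputs)
    decoded = DenseTranspose(encoder_dense, output_dim=input_dim, activation="sigmoid")(
        encoded
    )
    return tf.keras.Model(inputs, decoded)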


class UncorrelatedFeaturesConstraint(Constraint):
    """Penalises correlations between the features of a layer's output,
    encouraging a decorrelated latent representation"""

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, K.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
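

# Hypothetical usage sketch (not part of the original module): since __call__ maps
# a batch of activations to a scalar penalty, the constraint can be attached as an
# activity regularizer on a bottleneck layer. Dimensions are arbitrary example values.
def _sketch_uncorrelated_bottleneck(input_dim=30, encoding_dim=8):
    inputs = tf.keras.Input(shape=(input_dim,))
    encoded = tf.keras.layers.Dense(
        encoding_dim,
        activation="relu",
        activity_regularizer=UncorrelatedFeaturesConstraint(encoding_dim, weightage=1.0),
    )(inputs)
    return tf.keras.Model(inputs, encoded)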


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
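

# Hypothetical usage sketch (not part of the original module): the layer regularises
# the output of a tfp probabilistic layer towards a standard-normal prior. The latent
# dimensionality is an example value, and the weight is passed as a Keras variable so
# that it could be annealed by a warm-up callback.
def _sketch_kl_divergence_layer(latent_dim=8):
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0), reinterpreted_batch_ndims=1
    )
    inputs = tf.keras.Input(shape=(16,))
    params = tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(latent_dim))(inputs)
    z = tfpl.IndependentNormal(latent_dim)(params)
    z = KLDivergenceLayer(prior, weight=K.variable(1.0, name="kl_beta"))(z)
    return tf.keras.Model(inputs, z)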


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
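

# Hypothetical usage sketch (not part of the original module): the layer passes the
# latent sample through unchanged while adding an MMD penalty against a standard-normal
# prior. Batch size and latent dimensionality are example values; beta is passed as a
# Keras variable so that it could be annealed by a warm-up callback.
def _sketch_mmd_discrepancy_layer(batch_size=64, latent_dim=8):
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0), reinterpreted_batch_ndims=1
    )
    inputs = tf.keras.Input(shape=(16,))
    z = tf.keras.layers.Dense(latent_dim)(inputs)
    z = MMDiscrepancyLayer(batch_size=batch_size, prior=prior, beta=K.variable(1.0))(z)
    return tf.keras.Model(inputs, z)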


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            # Split the concatenated parameters of component k into means and
            # softplus-constrained scales
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target


class Latent_space_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, silhouette=False, loss=False, *args, **kwargs):
        self.loss = loss
        self.silhouette = silhouette
        super(Latent_space_control, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"loss": self.loss})
        config.update({"silhouette": self.silhouette})
        return config

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        # Adds a Silhouette score metric that monitors overlap between clusters
        if self.silhouette:
            hard_labels = tf.math.argmax(z_cat, axis=1)
            silhouette = tf.numpy_function(
                silhouette_score, [z, hard_labels], tf.float32
            )
            self.add_metric(silhouette, aggregation="mean", name="silhouette")

            if self.loss:
                self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # Negative entropy of the cluster assignment weights, with the
        # -inf values produced by log(0) masked out
        entropy = K.sum(
            tf.multiply(z, tf.where(~tf.math.is_inf(K.log(z)), K.log(z), 0)), axis=0
        )

        # Adds a metric that tracks the (negative) entropy of the cluster weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
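

# Hypothetical usage sketch (not part of the original module): attaching the regulariser
# to a softmax cluster-assignment head; with a positive weight it pushes the assignment
# weights towards higher entropy. Dimensions and weight are arbitrary example values.
def _sketch_entropy_regulariser(n_components=4):
    inputs = tf.keras.Input(shape=(16,))
    z_cat = tf.keras.layers.Dense(n_components, activation="softmax")(inputs)
    z_cat = Entropy_regulariser(weight=0.1)(z_cat)
    return tf.keras.Model(inputs, z_cat)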