# @author lucasmiranda42

from itertools import combinations
from tensorflow.keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# Helper functions
@tf.function
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
    """
    Returns a uniformly initialised matrix in which the columns are placed
    as far away from each other as possible
    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
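
# Usage sketch (illustrative, not part of the module's public API): the initialiser can be
# used to seed component/cluster means so that they start out well separated, e.g.
# >>> means = far_away_uniform_initialiser([10, 8], minval=0, maxval=15, iters=1000)
# >>> means.shape  # TensorShape([10, 8])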


def compute_kernel(x, y):
    """Computes a Gaussian kernel matrix between the rows of x and the rows of y"""
    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    return tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )


@tf.function
def compute_mmd(tensors):
    """Computes the Maximum Mean Discrepancy (MMD) between the two input tensors"""

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
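
# Usage sketch (illustrative): compute_mmd expects a list/tuple with two sample batches of
# shape (batch, dim) and returns a scalar that grows as the two distributions diverge, e.g.
# >>> a = tf.random.normal([64, 8])
# >>> b = tf.random.normal([64, 8]) + 3.0
# >>> compute_mmd([a, b])  # markedly larger than compute_mmd([a, a])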


# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
    """
    One-cycle learning rate scheduler: the learning rate rises linearly from start_rate
    to max_rate over the first half of the cycle, falls back to start_rate over the
    second half, and finally decays towards last_rate over the last iterations
    """

    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
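
# Usage sketch (illustrative; `model`, `X_train`, `y_train` and the hyperparameters are
# placeholders): `iterations` should cover the whole run, i.e. batches per epoch * epochs, e.g.
# >>> onecycle = OneCycleScheduler(iterations=(len(X_train) // 256) * 25, max_rate=0.005)
# >>> model.fit(X_train, y_train, batch_size=256, epochs=25, callbacks=[onecycle])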


class UncorrelatedFeaturesConstraint(Constraint):
    """
    Penalty that pushes the dimensions of an encoding towards being uncorrelated,
    computed as the weighted sum of squared off-diagonal entries of the batch covariance matrix
    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
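
# Usage sketch (an assumption, not the library's documented pattern): since Keras accepts any
# callable as an activity regularizer, the penalty can be attached to a bottleneck layer to
# discourage correlated latent dimensions, e.g.
# >>> bottleneck = tf.keras.layers.Dense(
# ...     16, activity_regularizer=UncorrelatedFeaturesConstraint(16, weightage=1.0)
# ... )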


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Dropout layer that stays active at prediction time, enabling Monte Carlo dropout estimates"""

    def call(self, inputs, **kwargs):
        # Force training=True so that dropout is applied during inference as well
        return super().call(inputs, training=True)


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds the KL divergence with respect to the prior
    to the final model loss, and exposes it as a metric
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"is_placeholder": self.is_placeholder,}
        )
        return config

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a


class DenseTranspose(Layer):
    """
    Dense layer whose kernel is the transpose of the kernel of a given Dense layer,
    used to build decoders with weights tied to the corresponding encoder layers
    """

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
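
# Usage sketch (illustrative): DenseTranspose reuses the kernel of an existing Dense layer,
# so encoder and decoder share (tied) weights, e.g.
# >>> enc = tf.keras.layers.Dense(32, activation="relu")
# >>> dec = DenseTranspose(enc, output_dim=64, activation="relu")  # kernel tied to `enc`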


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds the Maximum Mean Discrepancy (MMD) between
    the latent samples and the prior to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
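
# Usage sketch (illustrative; the prior shown here is an assumption): the layer passes the
# sampled latent code through unchanged while registering the MMD loss and metrics, e.g.
# >>> prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
# >>> z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)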


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target


class Latent_space_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, silhouette=False, loss=False, *args, **kwargs):
        self.loss = loss
        self.silhouette = silhouette
        super(Latent_space_control, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"loss": self.loss})
        config.update({"silhouette": self.silhouette})
        return config

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        # Adds a Silhouette score metric to monitor the overlap between clusters
        if self.silhouette:
            hard_labels = tf.math.argmax(z_cat, axis=1)
            silhouette = tf.numpy_function(
                silhouette_score, [z, hard_labels], tf.float32
            )
            self.add_metric(silhouette, aggregation="mean", name="silhouette")

            if self.loss:
                self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds the (sign-flipped) cluster weight entropy as a metric
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z