# @author lucasmiranda42

from itertools import combinations
from sklearn.metrics import silhouette_score
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
@tf.function
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
    """
    Returns a uniformly initialised matrix in which the rows are kept as far apart from one another as possible.
    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
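# Illustrative usage sketch (not part of the original module): the shape (4, 8) below
# stands in for a hypothetical (n_components, latent_dim) pair. The helper can be used
# to seed a set of cluster centres whose rows start out far apart.
def _example_far_away_initialiser():
    centres = far_away_uniform_initialiser((4, 8), minval=0, maxval=15, iters=1000)
    return centres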


def compute_kernel(x, y):
    """Computes a Gaussian (RBF) kernel matrix between the rows of x and y."""
    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    return tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )


@tf.function
def compute_mmd(tensors):
    """Computes the Maximum Mean Discrepancy (MMD) between two samples x and y, passed as a tuple."""

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
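# Illustrative sketch (assumed example, not part of the original module): the MMD between
# samples drawn from the same distribution should be close to zero, while samples from
# clearly different distributions should yield a noticeably larger value.
def _example_compute_mmd():
    a = tf.random.normal((100, 4))
    b = tf.random.normal((100, 4))
    c = tf.random.normal((100, 4), mean=5.0)
    same = compute_mmd((a, b))       # expected to be small
    different = compute_mmd((a, c))  # expected to be much larger
    return same, different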


# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
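# Illustrative usage sketch (the toy Sequential model and random data below are
# assumptions made for the example only): the callback updates the optimizer's learning
# rate once per batch following the 1-cycle schedule implemented above.
def _example_one_cycle_scheduler(epochs=2, batch_size=32):
    X = tf.random.normal((256, 10))
    y = tf.random.normal((256, 1))
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    n_iterations = (256 // batch_size) * epochs
    onecycle = OneCycleScheduler(n_iterations, max_rate=0.05)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[onecycle])
    return model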


class UncorrelatedFeaturesConstraint(Constraint):
    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - tf.reduce_mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = tf.tensordot(
            x_centered, tf.transpose(x_centered), axes=1
        ) / tf.cast(x_centered.get_shape()[0], tf.float32)

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = tf.math.reduce_sum(
                tf.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
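# Illustrative sketch (assumed usage, not taken from the original code base): since the
# penalty is computed from layer activations, one plausible way to apply it is as an
# activity regularizer on a bottleneck layer. Layer sizes below are hypothetical.
def _example_uncorrelated_features(encoding_dim=8):
    inputs = tf.keras.Input(shape=(32,))
    encoded = tf.keras.layers.Dense(
        encoding_dim,
        activity_regularizer=UncorrelatedFeaturesConstraint(encoding_dim, weightage=1.0),
    )(inputs)
    return tf.keras.Model(inputs, encoded)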


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)
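# Illustrative sketch (hypothetical model, not part of the original module): because
# MCDropout stays active at inference time, repeated forward passes over the same input
# yield a distribution of predictions that can serve as a simple uncertainty proxy.
def _example_mc_dropout(n_passes=50):
    inputs = tf.keras.Input(shape=(16,))
    x = tf.keras.layers.Dense(32, activation="relu")(inputs)
    x = MCDropout(0.5)(x)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)
    sample = tf.random.normal((1, 16))
    preds = tf.stack([model(sample) for _ in range(n_passes)])
    return tf.math.reduce_mean(preds), tf.math.reduce_std(preds)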


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a


class DenseTranspose(Layer):
    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
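# Illustrative sketch (hypothetical shapes, not from the original module): DenseTranspose
# reuses the transposed kernel of an existing Dense layer, the usual way to build an
# autoencoder with tied encoder/decoder weights.
def _example_dense_transpose():
    dense_1 = tf.keras.layers.Dense(16, activation="relu")
    inputs = tf.keras.Input(shape=(32,))
    encoded = dense_1(inputs)
    decoded = DenseTranspose(dense_1, output_dim=32, activation="sigmoid")(encoded)
    return tf.keras.Model(inputs, decoded)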


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(tf.reduce_mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
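# Illustrative sketch (assumed wiring; shapes and names are hypothetical, and beta is
# passed as a Keras backend variable on the assumption that a warm-up schedule would
# update it during training): the layer returns the latent codes unchanged while
# registering an MMD penalty against samples drawn from a fixed prior.
def _example_mmd_discrepancy_layer(batch_size=64, latent_dim=8):
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0),
        reinterpreted_batch_ndims=1,
    )
    inputs = tf.keras.Input(shape=(latent_dim,))
    z = MMDiscrepancyLayer(
        batch_size=batch_size, prior=prior, beta=K.variable(1.0, name="mmd_beta")
    )(inputs)
    return tf.keras.Model(inputs, z)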


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = (target[..., : self.lat_dims, k],)
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = tf.reduce_mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target


class Latent_space_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, silhouette=False, loss=False, *args, **kwargs):
        self.loss = loss
        self.silhouette = silhouette
        super(Latent_space_control, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"loss": self.loss})
        config.update({"silhouette": self.silhouette})
        return config

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        # Adds Silhouette score controlling overlap between clusters
        if self.silhouette:
            hard_labels = tf.math.argmax(z_cat, axis=1)
            silhouette = tf.numpy_function(
                silhouette_score, [z, hard_labels], tf.float32
            )
            self.add_metric(silhouette, aggregation="mean", name="silhouette")

            if self.loss:
                self.add_loss(-tf.reduce_mean(silhouette), inputs=[z, hard_labels])

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = tf.reduce_sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds a metric that tracks the entropy of the cluster assignment weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * tf.reduce_sum(entropy), inputs=[z])

        return z
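# Illustrative sketch (hypothetical model head, not from the original code base): the
# regulariser would typically be attached to the categorical (cluster-weight) output of
# an encoder, approximated here by a softmax Dense head.
def _example_entropy_regulariser(n_components=4):
    inputs = tf.keras.Input(shape=(16,))
    z_cat = tf.keras.layers.Dense(n_components, activation="softmax")(inputs)
    z_cat = Entropy_regulariser(weight=0.1)(z_cat)
    return tf.keras.Model(inputs, z_cat)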