# model_utils.py
# @author lucasmiranda42

from itertools import combinations
from sklearn.metrics import silhouette_score
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions


@tf.function
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
    """
    Returns a uniformly initialised matrix whose rows lie as far away from
    the first row as possible, keeping the best of iters random draws
    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        # Keep the candidate with the largest aggregate distance seen so far
        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
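
# A minimal usage sketch, assuming hypothetical shapes and ranges (not part of
# the original module): drawing a far-apart initialisation for, e.g., the
# cluster means of a 4-component mixture in a 6-dimensional latent space.
#
#   >>> init = far_away_uniform_initialiser((4, 6), minval=0, maxval=15, iters=1000)
#   >>> init.shape
#   TensorShape([4, 6])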


def compute_kernel(x, y):
    """
    Gaussian (RBF) kernel matrix between all pairs of rows in x and y
    """

    x_size = K.shape(x)[0]
    y_size = K.shape(y)[0]
    dim = K.shape(x)[1]
    # Tile both inputs so that every pair of rows can be compared element-wise
    tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
    tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
    return K.exp(
        -tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32)
    )
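
# A minimal usage sketch with hypothetical inputs (not part of the original
# module): the result holds one kernel value per pair of rows.
#
#   >>> k = compute_kernel(tf.random.normal((32, 8)), tf.random.normal((64, 8)))
#   >>> k.shape
#   TensorShape([32, 64])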


@tf.function
def compute_mmd(tensors):
    """
    Maximum Mean Discrepancy between the two samples in the input tuple,
    computed as the distance between their kernel mean embeddings
    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
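
# A minimal usage sketch with hypothetical samples (not part of the original
# module): MMD grows as the two empirical distributions drift apart.
#
#   >>> a = tf.random.normal((64, 16))
#   >>> b = tf.random.normal((64, 16)) + 3.0
#   >>> compute_mmd((a, a)) < compute_mmd((a, b))  # identical samples give near-zero MMD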


# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
    """
    One-cycle learning rate scheduler: ramps the learning rate linearly from
    start_rate up to max_rate over the first half of training, back down over
    the second half, and finally anneals it towards last_rate
    """

    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
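
# A minimal usage sketch with a hypothetical model and dataset (not part of
# the original module): iterations should equal the total number of batches
# processed over the whole of training.
#
#   >>> onecycle = OneCycleScheduler(iterations=500, max_rate=1e-2)
#   >>> model.fit(X_train, y_train, epochs=5, batch_size=256, callbacks=[onecycle])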


class UncorrelatedFeaturesConstraint(Constraint):
    """
    Penalises correlations between the dimensions of a batch of encodings
    by punishing the off-diagonal entries of their covariance matrix
    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, K.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
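
# A minimal usage sketch (not part of the original module): since the penalty
# is computed on batch activations rather than on weights, one plausible way
# to attach it is as an activity regularizer on the bottleneck layer.
#
#   >>> bottleneck = tf.keras.layers.Dense(
#   ...     16, activity_regularizer=UncorrelatedFeaturesConstraint(16, weightage=1.0)
#   ... )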


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Dropout layer that stays active at inference time (Monte Carlo dropout)"""

    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)
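
# A minimal usage sketch with a hypothetical trained model containing MCDropout
# layers (not part of the original module): because dropout stays on, repeated
# predictions differ, and their spread estimates the model's uncertainty.
#
#   >>> preds = np.stack([model.predict(X_test) for _ in range(100)])
#   >>> preds.mean(axis=0), preds.std(axis=0)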


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds the KL divergence to the prior
    to the final model loss, and exposes it as a metric
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
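
# A minimal usage sketch with a hypothetical prior and posterior (not part of
# the original module): the layer is an identity on the posterior distribution,
# registering the KL term as a loss and a metric along the way. Note that
# tfpl.IndependentNormal(16) expects an input of size IndependentNormal.params_size(16).
#
#   >>> prior = tfd.Independent(tfd.Normal(tf.zeros(16), 1.0), 1)
#   >>> posterior = tfpl.IndependentNormal(16)(encoding)
#   >>> posterior = KLDivergenceLayer(prior, weight=1.0)(posterior)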


class DenseTranspose(Layer):
    """
    Dense layer that reuses the (transposed) weights of a reference Dense
    layer, for building autoencoders with tied weights
    """

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": tf.keras.activations.serialize(self.activation),
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
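
# A minimal usage sketch of a tied-weights autoencoder with hypothetical
# dimensions (not part of the original module):
#
#   >>> inputs = tf.keras.Input(shape=(128,))
#   >>> encoder = tf.keras.layers.Dense(16, activation="relu")
#   >>> encoded = encoder(inputs)
#   >>> decoded = DenseTranspose(encoder, output_dim=128, activation="sigmoid")(encoded)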


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
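
# A minimal usage sketch with a hypothetical prior (not part of the original
# module): the layer passes z through unchanged while penalising the MMD
# between z and samples drawn from the prior.
#
#   >>> prior = tfd.Independent(tfd.Normal(tf.zeros(16), 1.0), 1)
#   >>> z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)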


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, **kwargs):

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        # Draw samples from each component and put the batch axis first
        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
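
# A minimal usage sketch (not part of the original module). Judging from the
# indexing in call, the input is assumed to stack means and pre-softplus scales
# per mixture component, with shape (batch, 2 * lat_dims, n_components):
#
#   >>> gm_params = Gaussian_mixture_overlap(lat_dims=6, n_components=4)(gm_params)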


class Latent_space_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, silhouette=False, loss=False, *args, **kwargs):
        self.loss = loss
        self.silhouette = silhouette
        super(Latent_space_control, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"loss": self.loss})
        config.update({"silhouette": self.silhouette})
        return config

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        # Adds Silhouette score controlling overlap between clusters
        if self.silhouette:
            hard_labels = tf.math.argmax(z_cat, axis=1)
            # sklearn returns float64; cast to match the declared output dtype
            silhouette = tf.numpy_function(
                lambda x, y: silhouette_score(x, y).astype(np.float32),
                [z, hard_labels],
                tf.float32,
            )
            self.add_metric(silhouette, aggregation="mean", name="silhouette")

            if self.loss:
                self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])

        return z
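
# A minimal usage sketch with hypothetical latent tensors (not part of the
# original module): z is the full latent code, z_gauss its Gaussian part,
# and z_cat the soft cluster assignments.
#
#   >>> z = Latent_space_control(silhouette=True)(z, z_gauss, z_cat)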


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # Summing over axis=1 yields the (negative) entropy of the cluster
        # assignment for each instance; summing over axis=0 would instead
        # yield the entropy of each cluster across instances
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds a metric that monitors the negative entropy of the cluster weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
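
# A minimal usage sketch with hypothetical soft cluster assignments (not part
# of the original module): with a positive weight, minimising the loss pushes
# the per-instance assignment entropies upward.
#
#   >>> z_cat = Entropy_regulariser(weight=0.1)(z_cat)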