# @author lucasmiranda42

from itertools import combinations
from tensorflow.keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# Helper functions
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=10000000):
    """
    Returns a uniformly initialised matrix in which the columns are as far as possible
    """
    init_dist = 0
    for i in range(iters):
        temp = np.random.uniform(minval, maxval, shape)
        dist = np.abs(np.linalg.norm(np.diff(temp)))

        if dist > init_dist:
            init_dist = dist
            init = temp

    return init.astype(np.float32)
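

# Illustrative sanity check (not part of the original module): draws a small
# matrix with a reduced iteration budget; the shape and iteration count are
# arbitrary assumptions for demonstration purposes.
def _far_away_initialiser_example():
    init = far_away_uniform_initialiser((4, 6), minval=0, maxval=15, iters=1000)
    assert init.shape == (4, 6)
    assert init.dtype == np.float32
    return init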


def compute_kernel(x, y):
    """
    Computes a Gaussian (RBF) kernel matrix between two batches of vectors,
    with the bandwidth scaled by the feature dimensionality
    """
    x_size = K.shape(x)[0]
    y_size = K.shape(y)[0]
    dim = K.shape(x)[1]
    tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
    tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
    return K.exp(
        -tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32)
    )


def compute_mmd(tensors):
    """
    Computes the Maximum Mean Discrepancy (MMD) between two samples,
    passed as a pair of tensors, using the Gaussian kernel above
    """
    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
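

# Minimal usage sketch (illustrative only; sample sizes are assumptions): the
# MMD between two samples from the same distribution should be close to zero,
# while a shifted sample should yield a clearly larger value.
def _compute_mmd_example():
    x = tf.random.normal((100, 2))
    y = tf.random.normal((100, 2))
    shifted = tf.random.normal((100, 2), mean=5.0)
    return compute_mmd([x, y]), compute_mmd([x, shifted])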


# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
    """
    One-cycle learning rate scheduler: ramps the learning rate linearly up to
    max_rate over the first half of training, back down to start_rate over
    the second half, then anneals towards last_rate at the end
    """

    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs=None):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
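

# Usage sketch (illustrative; data shapes and hyperparameters are arbitrary
# assumptions): the scheduler is passed to fit() as a callback, with
# `iterations` set to the total number of batches over all epochs.
def _one_cycle_scheduler_example():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
    model.compile(optimizer="adam", loss="mse")
    X, y = np.random.rand(320, 8), np.random.rand(320, 1)
    scheduler = OneCycleScheduler(iterations=(320 // 32) * 5, max_rate=0.05)
    return model.fit(X, y, batch_size=32, epochs=5, callbacks=[scheduler], verbose=0)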


class UncorrelatedFeaturesConstraint(Constraint):
    """
    Activity penalty that pushes the features of an encoding towards zero
    pairwise covariance, weighted by `weightage`
    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        # Normalise by the number of samples in the batch, not by the number
        # of features, to obtain the empirical covariance of the encoding
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            tf.shape(x)[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, K.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
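

# Usage sketch (illustrative; layer sizes are assumptions): despite
# subclassing Constraint, the penalty acts on activations, so a plausible
# wiring is as an activity regularizer on the bottleneck layer.
def _uncorrelated_features_example():
    return tf.keras.Sequential(
        [
            tf.keras.layers.Dense(
                16,
                input_shape=(32,),
                activation="relu",
                activity_regularizer=UncorrelatedFeaturesConstraint(16, weightage=1.0),
            )
        ]
    )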


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """
    Dropout layer that stays active at inference time, enabling
    Monte Carlo dropout estimates of model uncertainty
    """

    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)
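

# Usage sketch (illustrative; all shapes are assumptions): because dropout
# stays active at prediction time, repeated forward passes yield a spread of
# outputs whose standard deviation approximates model uncertainty.
def _mc_dropout_example():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(64, activation="relu", input_shape=(10,)),
            MCDropout(0.5),
            tf.keras.layers.Dense(1),
        ]
    )
    preds = tf.stack([model(tf.ones((4, 10))) for _ in range(100)])
    return tf.reduce_mean(preds, axis=0), tf.math.reduce_std(preds, axis=0)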


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL divergence
    to the final model loss
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(kl_batch, aggregation="mean", name="kl_divergence")
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
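

# Usage sketch (illustrative; the latent dimensionality, prior, and weight
# are assumptions): the layer is applied to the output of a probabilistic
# layer, penalising divergence from a standard-normal prior. The weight is
# passed as a variable so it can be annealed and logged as a tensor.
def _kl_divergence_layer_example():
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(8), scale=1.0), reinterpreted_batch_ndims=1
    )
    return tf.keras.Sequential(
        [
            tf.keras.layers.Dense(
                tfpl.IndependentNormal.params_size(8), input_shape=(32,)
            ),
            tfpl.IndependentNormal(8),
            KLDivergenceLayer(prior, weight=tf.Variable(1.0, trainable=False)),
        ]
    )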


class DenseTranspose(Layer):
    """
    Dense layer whose kernel is tied to the transpose of another Dense
    layer's kernel, as used in weight-tied autoencoders
    """

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": tf.keras.activations.serialize(self.activation),
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
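

# Usage sketch (illustrative; sizes are assumptions): a tied-weights
# autoencoder in which the decoder reuses the transposed encoder kernel.
# `output_dim` must match the tied layer's input dimensionality.
def _dense_transpose_example():
    inputs = tf.keras.Input(shape=(32,))
    dense = tf.keras.layers.Dense(8, activation="relu")
    encoded = dense(inputs)
    decoded = DenseTranspose(dense, output_dim=32, activation="sigmoid")(encoded)
    return tf.keras.Model(inputs, decoded)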


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
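

# Usage sketch (illustrative; batch size, latent dimensionality and beta are
# assumptions): regularises a deterministic latent code towards a
# standard-normal prior through MMD rather than a KL term. beta is passed as
# a variable so it can be annealed during training and logged as a tensor.
def _mmd_discrepancy_example():
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(8), scale=1.0), reinterpreted_batch_ndims=1
    )
    inputs = tf.keras.Input(shape=(32,))
    z = tf.keras.layers.Dense(8)(inputs)
    z = MMDiscrepancyLayer(
        batch_size=64, prior=prior, beta=tf.Variable(1.0, trainable=False)
    )(z)
    return tf.keras.Model(inputs, z)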


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the
    latent Gaussian mixture using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, **kwargs):

        dists = []
        for k in range(self.n_components):
            # Means live in the first lat_dims rows; unconstrained scales in
            # the rest, mapped through a softplus to ensure positivity
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
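

# Usage sketch (illustrative; all sizes are assumptions): the input tensor is
# expected to stack means and pre-softplus scales along axis -2 and mixture
# components along axis -1.
def _gaussian_mixture_overlap_example():
    lat_dims, n_components = 4, 3
    inputs = tf.keras.Input(shape=(2 * lat_dims, n_components))
    outputs = Gaussian_mixture_overlap(lat_dims, n_components, samples=5)(inputs)
    return tf.keras.Model(inputs, outputs)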


class Latent_space_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, silhouette=False, loss=False, *args, **kwargs):
        self.loss = loss
        self.silhouette = silhouette
        super(Latent_space_control, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"loss": self.loss})
        config.update({"silhouette": self.silhouette})
        return config
    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        # Adds Silhouette score controlling overlap between clusters
        if self.silhouette:
            hard_labels = tf.math.argmax(z_cat, axis=1)
            # silhouette_score returns a float64 scalar; cast so the dtype
            # matches the declared tf.float32 output
            silhouette = tf.numpy_function(
                lambda x, y: np.float32(silhouette_score(x, y)),
                [z, hard_labels],
                tf.float32,
            )
            self.add_metric(silhouette, aggregation="mean", name="silhouette")

            if self.loss:
                self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])

        return z
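

# Usage sketch (illustrative; dimensionalities are assumptions): `z_gauss`
# stands for the Gaussian part of the latent code and `z_cat` for the
# categorical cluster weights; here the monitored code is z_gauss itself.
def _latent_space_control_example():
    inputs = tf.keras.Input(shape=(32,))
    z_gauss = tf.keras.layers.Dense(8)(inputs)
    z_cat = tf.keras.layers.Dense(3, activation="softmax")(inputs)
    z = Latent_space_control()(z_gauss, z_gauss, z_cat)
    return tf.keras.Model(inputs, z)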


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # Negative entropy of the cluster weights; the 1e-5 inside the
        # logarithm guards against log(0)
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds metric that monitors the entropy of the cluster weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
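

# Usage sketch (illustrative; sizes and the weight are assumptions): applied
# to softmax cluster weights so that the loss includes their negative entropy;
# the sign of `weight` controls whether entropy is encouraged or penalised.
def _entropy_regulariser_example():
    inputs = tf.keras.Input(shape=(32,))
    z_cat = tf.keras.layers.Dense(3, activation="softmax")(inputs)
    z_cat = Entropy_regulariser(weight=0.1)(z_cat)
    return tf.keras.Model(inputs, z_cat)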