# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
@tf.function
def far_away_uniform_initialiser(
    shape: tuple, minval: int = 0, maxval: int = 15, iters: int = 100000
) -> tf.Tensor:
    """
    Returns a uniformly initialised matrix in which the columns are as far as possible
lucas_miranda's avatar
lucas_miranda committed
29
30
31
32
33
34
35
36
37
38
39
40
41

        Parameters:
            - shape (tuple): shape of the object to generate.
            - minval (int): Minimum value of the uniform distribution from which to sample
            - maxval (int): Maximum value of the uniform distribution from which to sample
            - iters (int): the algorithm generates values at random and keeps those runs that
            are the farthest apart. Increasing this parameter will lead to more accurate,
            results while making the function run slowlier.

        Returns:
            - init (tf.Tensor): tensor of the specified shape in which the column vectors
             are as far as possible

42
    """

    init = tf.random.uniform(shape, minval, maxval)
    # Distance between each row and the first one (tf.norm is already non-negative)
    init_dist = tf.norm(tf.math.subtract(init[1:], init[:1]))
    i = 0

    while tf.less(i, iters):
        # Draw a new candidate and keep it if its rows are more spread out
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.norm(tf.math.subtract(temp[1:], temp[:1]))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
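

# Minimal usage sketch for far_away_uniform_initialiser (the shape and iteration
# count below are made-up example values, e.g. 10 cluster centres in 8 dimensions):
#
#     centres = far_away_uniform_initialiser(shape=(10, 8), minval=0, maxval=15, iters=1000)
#     centres.shape  # TensorShape([10, 8])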


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel
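

# Shape sketch for compute_kernel (illustrative values): for x of shape (n, d)
# and y of shape (m, d), the result is an (n, m) matrix of Gaussian kernel values:
#
#     x = tf.random.normal((32, 16))
#     y = tf.random.normal((48, 16))
#     k = compute_kernel(x, y)  # shape (32, 48), entries in (0, 1]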


@tf.function
def compute_mmd(tensors: tuple) -> tf.Tensor:
    """

        Computes the MMD between the two specified vectors using a gaussian kernel.

            Parameters:
                - tensors (tuple): tuple containing two tf.Tensor objects

            Returns
                - mmd (tf.Tensor): returns the maximum mean discrepancy for each
                training instance

        """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
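

# Usage sketch for compute_mmd (illustrative values): samples drawn from the same
# distribution should yield an MMD close to zero, while samples from clearly
# different distributions should yield a larger value:
#
#     a = tf.random.normal((128, 8))
#     b = tf.random.normal((128, 8), mean=5.0)
#     compute_mmd((a, a))  # close to 0
#     compute_mmd((a, b))  # substantially larger than 0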


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(
        self, iter1: int, iter2: int, rate1: float, rate2: float
    ) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs=None):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            # Phase 1: increase the learning rate linearly from start_rate to max_rate
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            # Phase 2: decrease the learning rate linearly back to start_rate
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            # Phase 3: anneal the learning rate down to last_rate
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
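

# Usage sketch for one_cycle_scheduler (the model, data and rates are assumptions
# for the example; iterations should match batches_per_epoch * epochs):
#
#     one_cycle = one_cycle_scheduler(iterations=10000, max_rate=1e-3)
#     model.fit(x_train, y_train, epochs=10, callbacks=[one_cycle])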


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
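

# Usage sketch for uncorrelated_features_constraint: although it subclasses
# Constraint, its __call__ maps layer activations to a scalar penalty, so it can
# be plugged in where Keras expects an activity regularizer (the layer size below
# is an assumption for the example):
#
#     bottleneck = tf.keras.layers.Dense(
#         16,
#         activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0),
#     )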


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Dropout layer that remains active at inference time (Monte Carlo dropout)"""

    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)
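

# MCDropout keeps dropout active at inference time, so repeated stochastic forward
# passes over the same input approximate model uncertainty (a minimal Monte Carlo
# sketch; `model` and `x_test` are assumptions for the example):
#
#     preds = tf.stack([model(x_test) for _ in range(100)])
#     mean = tf.reduce_mean(preds, axis=0)
#     std = tf.math.reduce_std(preds, axis=0)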


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"is_placeholder": self.is_placeholder,}
        )
        return config

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
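

# Usage sketch for KLDivergenceLayer (a minimal VAE-style snippet; the prior,
# latent dimension and `encoder_output` tensor are assumptions for the example;
# encoder_output must have tfpl.IndependentNormal.params_size(8) units):
#
#     prior = tfd.Independent(tfd.Normal(tf.zeros(8), 1.0), 1)
#     z = tfpl.IndependentNormal(8)(encoder_output)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)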


class DenseTranspose(Layer):
    """
    Dense layer that reuses the transposed weights of the given Dense layer
    (tied weights), with its own bias vector.
    """

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
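

# Usage sketch for DenseTranspose: a tied-weights autoencoder in which the decoder
# reuses the transposed kernel of the encoder layer (layer sizes and `inputs` are
# assumptions for the example; inputs are taken to have 64 features):
#
#     encoder = tf.keras.layers.Dense(16, activation="relu")
#     decoder = DenseTranspose(encoder, output_dim=64, activation="sigmoid")
#     reconstruction = decoder(encoder(inputs))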


class MMDiscrepancyLayer(Layer):
    """
293
    Identity transform layer that adds MM Discrepancy
294
295
296
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
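

# Usage sketch for MMDiscrepancyLayer (the prior, batch size and latent tensor `z`
# are assumptions for the example):
#
#     prior = tfd.MultivariateNormalDiag(loc=tf.zeros(8), scale_diag=tf.ones(8))
#     z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)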


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fischer-Rao)
    """

    def __init__(
        self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs
    ):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
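

# Usage sketch for Gaussian_mixture_overlap: judging from the indexing in call
# above, `target` is assumed to stack the concatenated means and (pre-softplus)
# scales of each mixture component along its last axis, i.e. to have shape
# (..., 2 * lat_dims, n_components):
#
#     gauss_params = Gaussian_mixture_overlap(lat_dims=8, n_components=4)(gauss_params)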


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        return z
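

# Usage sketch for Dead_neuron_control (tensor names are assumptions for the
# example; the layer returns z unchanged while logging the fraction of zero
# activations in z_gauss):
#
#     z = Dead_neuron_control()(z, z_gauss, z_cat)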


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z, tf.math.log(z + 1e-5)), axis=1)

        # Adds a metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
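

# Usage sketch for Entropy_regulariser (applied to a softmax output over cluster
# assignments; the tensor name and weight are assumptions for the example):
#
#     soft_assignments = Entropy_regulariser(weight=0.1)(soft_assignments)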