# @author lucasmiranda42
# encoding: utf-8
# module deepof
"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Any, Tuple

from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes


class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
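

# Usage sketch (hypothetical names): running the learning-rate range test on a
# compiled Keras model and inspecting the resulting curve. Not executed on import.
#
#     model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse")
#     rates, losses = find_learning_rate(model, X_train, y_train, batch_size=64)
#     plot_lr_vs_loss(rates, losses)
#     plt.show()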


@tf.function
def far_away_uniform_initialiser(
    shape: tuple, minval: int = 0, maxval: int = 15, iters: int = 100000
) -> tf.Tensor:
    """
    Returns a uniformly initialised matrix in which the rows are as far apart as possible

        Parameters:
            - shape (tuple): shape of the object to generate.
            - minval (int): minimum value of the uniform distribution from which to sample
            - maxval (int): maximum value of the uniform distribution from which to sample
            - iters (int): the algorithm generates values at random and keeps those runs that
            are the farthest apart. Increasing this parameter yields more accurate results,
            at the cost of a longer run time.

        Returns:
            - init (tf.Tensor): tensor of the specified shape in which the row vectors
            are as far apart as possible

    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
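

# Usage sketch: drawing a set of well-separated vectors, e.g. to initialise
# cluster centres. Shape and bounds below are illustrative.
#
#     centres = far_away_uniform_initialiser((6, 4), minval=0, maxval=15, iters=1000)
#     print(centres.shape)  # (6, 4)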


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the maximum mean discrepancy (MMD) between the two specified tensors
    using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
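

# Usage sketch: the MMD between samples from two distinct distributions should
# clearly exceed the MMD between samples from the same one. Values illustrative.
#
#     a = tf.random.normal([64, 8])
#     b = tf.random.normal([64, 8], mean=2.0)
#     print(compute_mmd((a, b)), compute_mmd((a, a)))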


# Custom auxiliary classes

class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
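

# Usage sketch (hypothetical names): scheduling the learning rate over a whole
# training run; iterations should match steps_per_epoch * epochs.
#
#     steps = (len(X_train) // 32) * 10
#     onecycle = one_cycle_scheduler(iterations=steps, max_rate=1e-3)
#     model.fit(X_train, y_train, epochs=10, batch_size=32, callbacks=[onecycle])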


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among other things, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
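

# Usage sketch: although this subclasses Constraint, __call__ returns a scalar
# penalty, so it is applied like an activity regulariser on a bottleneck layer.
# Sizes are illustrative.
#
#     bottleneck = tf.keras.layers.Dense(
#         16,
#         activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0),
#     )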


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""

        return super().call(inputs, training=True)
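

# Usage sketch (hypothetical names): Monte Carlo predictions with a model that
# uses MCDropout; averaging stochastic passes yields a mean and an uncertainty.
#
#     preds = tf.stack([model(X_test) for _ in range(100)])
#     mean, std = tf.reduce_mean(preds, axis=0), tf.math.reduce_std(preds, axis=0)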


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
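

# Usage sketch: a tied-weights autoencoder in which each decoder layer mirrors
# an encoder Dense layer. Layer sizes are illustrative.
#
#     dense_1 = tf.keras.layers.Dense(64, activation="relu")   # 128 -> 64
#     dense_2 = tf.keras.layers.Dense(16, activation="relu")   # 64 -> 16
#     decoder_2 = DenseTranspose(dense_2, output_dim=64, activation="relu")
#     decoder_1 = DenseTranspose(dense_1, output_dim=128, activation="sigmoid")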


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
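

# Usage sketch: weighting the KL term of a variational bottleneck. The prior and
# sizes are illustrative; tfpl.KLDivergenceAddLoss supplies the machinery.
#
#     prior = tfd.Independent(tfd.Normal(tf.zeros(16), 1.0), 1)
#     params = tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(16))(h)
#     z = tfpl.IndependentNormal(16)(params)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)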


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
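

# Usage sketch: matching the aggregate posterior to a prior through MMD instead
# of (or alongside) a KL term. Prior and sizes are illustrative.
#
#     prior = tfd.MultivariateNormalDiag(loc=tf.zeros(16))
#     z = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(z)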


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
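

# Usage sketch: monitoring component overlap of a Gaussian-mixture latent head
# whose output stacks means and pre-softplus scales per component.
#
#     # gmm_params has shape (batch, 2 * lat_dims, n_components)
#     gmm_params = Gaussian_mixture_overlap(lat_dims=8, n_components=4)(gmm_params)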


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        # the small offset is added inside the log to avoid log(0)
        entropy = K.sum(tf.multiply(z, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
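

# Usage sketch: regularising soft cluster assignments (e.g. a softmax output)
# so that they do not collapse onto a single component. Weight is illustrative.
#
#     soft_assignments = Entropy_regulariser(weight=0.1, axis=1)(soft_assignments)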