# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""
10

11
from itertools import combinations
lucas_miranda's avatar
lucas_miranda committed
12
from typing import Any, Tuple
lucas_miranda's avatar
lucas_miranda committed
13
from sklearn.neighbors import NearestNeighbors
14
from tensorflow.keras import backend as K
15
16
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
17
import matplotlib.pyplot as plt
lucas_miranda's avatar
lucas_miranda committed
18
import numpy as np
19
import tensorflow as tf
20
import tensorflow_probability as tfp
21

22
tfd = tfp.distributions
23
tfpl = tfp.layers
24

lucas_miranda's avatar
lucas_miranda committed
25

26
# Helper functions and classes
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
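
# A minimal usage sketch (illustrative, not part of the module): assuming a
# compiled Keras model `model` and training arrays `X_train`, `y_train`, the
# two helpers above can be combined to scan for a good learning rate:
#
#   rates, losses = find_learning_rate(model, X_train, y_train, batch_size=64)
#   plot_lr_vs_loss(rates, losses)
#   plt.show()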


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
    """

    Computes the MMD between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
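
# A quick sanity check (illustrative only): the MMD of a sample against itself
# is exactly zero, and grows for samples drawn from different distributions:
#
#   a = tf.random.normal([100, 8])
#   b = tf.random.normal([100, 8], mean=3.0)
#   compute_mmd((a, a))  # == 0.0
#   compute_mmd((a, b))  # clearly greater than zero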


# Custom auxiliary classes

class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
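
# Illustrative usage (hypothetical names): scheduling one cycle over a full
# training run; per-epoch learning rates are logged to `log_dir`:
#
#   steps = (len(X_train) // batch_size) * n_epochs
#   onecycle = one_cycle_scheduler(iterations=steps, max_rate=1e-3, log_dir="./logs")
#   model.fit(X_train, y_train, epochs=n_epochs, callbacks=[onecycle])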


class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space

    """

    def __init__(self, validation_data=None, k=100, samples=10000, log_dir="."):
        super().__init__()
        self.validation_data = validation_data
        self.k = k
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None:

            # Get encoder and grouper from the full model
            cluster_means = [
                layer for layer in self.model.layers if layer.name == "cluster_means"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, cluster_means.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)

            # Multiply encodings by groups, to get a weighted version of the matrix
            encoding = (
                encoding
                * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
            )
            hard_groups = groups.argmax(axis=1)

            # Fit KNN model
            knn = NearestNeighbors().fit(encoding)

            # Iterate over samples and compute purity over k neighbours
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)
            for i, sample in enumerate(random_idxs):
                indexes = knn.kneighbors(
                    encoding[sample][np.newaxis, :], self.k, return_distance=False
                )
                purity_vector[i] = (
                    np.sum(hard_groups[indexes] == hard_groups[sample])
                    / self.k
                    * np.max(groups[sample])
                )

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "knn_cluster_purity",
                    data=purity_vector.mean(),
                    step=epoch,
                )
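
# Illustrative usage (hypothetical data): note that this callback assumes the
# model contains layers named "cluster_means" and "cluster_assignment" (see
# the layer lookups above):
#
#   purity_cb = knn_cluster_purity(validation_data=X_val, k=50, samples=1000)
#   model.fit(X_train, y_train, epochs=10, callbacks=[purity_cb])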


class uncorrelated_features_constraint(Constraint):
    """

283
    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
lucas_miranda's avatar
lucas_miranda committed
284
285
286
287
    Useful, among others, for auto encoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
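
# Illustrative usage (a judgment call based on the code above): since __call__
# returns a scalar penalty on its input rather than projected weights, the
# object is wired as an activity regularizer on the bottleneck layer:
#
#   bottleneck = tf.keras.layers.Dense(
#       16, activity_regularizer=uncorrelated_features_constraint(encoding_dim=16)
#   )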


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""
        return super().call(inputs, training=True)
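
# Illustrative sketch (hypothetical names): with MCDropout layers in a fitted
# model, repeated stochastic forward passes yield a Monte Carlo estimate of the
# predictive mean and its uncertainty:
#
#   preds = np.stack([model.predict(X_val) for _ in range(100)])
#   mc_mean, mc_std = preds.mean(axis=0), preds.std(axis=0)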


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
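
# Illustrative sketch (hypothetical sizes): tying a decoder layer to its encoder
# counterpart; the dense layer must be called (and thus built) before the
# transposed mirror is used:
#
#   enc = tf.keras.layers.Dense(32, activation="relu")
#   hidden = enc(inputs)                                  # inputs: (batch, 64)
#   outputs = DenseTranspose(enc, output_dim=64, activation="sigmoid")(hidden)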


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
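
# Illustrative sketch (hypothetical shapes): penalising an 8-dimensional latent
# posterior against a standard-normal prior inside a functional model:
#
#   prior = tfd.Independent(tfd.Normal(tf.zeros(8), 1.0), reinterpreted_batch_ndims=1)
#   params = tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(8))(x)
#   posterior = tfpl.IndependentNormal(8)(params)
#   posterior = KLDivergenceLayer(prior, weight=1.0)(posterior)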


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
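
# Illustrative sketch (hypothetical shapes): pushing a latent code towards a
# standard-normal prior via the MMD; batch_size must match the training batch:
#
#   prior = tfd.Independent(tfd.Normal(tf.zeros(8), 1.0), reinterpreted_batch_ndims=1)
#   z = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(z)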


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
463
    using the average inter-cluster MMD as a metric
464
465
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
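
# Shape note (derived from the code above): `target` is expected to stack means
# and pre-softplus scales for each mixture component, i.e. to have shape
# (batch, 2 * lat_dims, n_components):
#
#   gmm_params = Cluster_overlap(lat_dims=8, n_components=5)(gmm_params)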


class Dead_neuron_control(Layer):
    """
    Identity layer that monitors the fraction of dead neurons in the latent space,
    adding it to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target
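
# Illustrative usage: pass a latent tensor through the layer unchanged to start
# tracking the fraction of dead neurons:
#
#   latent = Dead_neuron_control()(latent)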


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds a metric that monitors the entropy of the cluster assignment weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
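
# Illustrative usage: regularise soft cluster assignments towards higher
# entropy (see the axis semantics in the comments above):
#
#   assignments = Entropy_regulariser(weight=0.01, axis=1)(assignments)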