# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Any, Tuple
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


lucas_miranda's avatar
lucas_miranda committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
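

# The estimator above is the standard biased V-statistic form of the squared MMD:
#     MMD^2(P, Q) = mean(k(x, x')) + mean(k(y, y')) - 2 * mean(k(x, y))
# A quick sanity-check sketch (hypothetical tensors, not part of the module):
#
#     close = compute_mmd((tf.random.normal([64, 8]), tf.random.normal([64, 8])))
#     far = compute_mmd((tf.random.normal([64, 8]), tf.random.normal([64, 8], mean=5.0)))
#     assert far > close  # distributions further apart yield a larger discrepancy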


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(
        self, iter1: int, iter2: int, rate1: float, rate2: float
    ) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )

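# A minimal usage sketch (hedged: `my_model`, `X_train` and `y_train` are
# hypothetical placeholders, not part of this module). `iterations` should span
# the whole schedule, i.e. batches per epoch times the number of epochs:
#
#     onecycle = one_cycle_scheduler(
#         iterations=(len(X_train) // 256) * 10, max_rate=1e-3, log_dir="./logs"
#     )
#     my_model.fit(X_train, y_train, epochs=10, batch_size=256, callbacks=[onecycle])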

class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space

    """

    def __init__(
        self, variational=True, validation_data=None, k=100, samples=10000, log_dir="."
    ):
        super().__init__()
        self.variational = variational
        self.validation_data = validation_data
        self.k = k
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from full model
            cluster_means = [
                layer for layer in self.model.layers if layer.name == "cluster_means"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, cluster_means.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)

            # Multiply encodings by groups, to get a weighted version of the matrix
            encoding = (
                encoding
                * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
            )
            hard_groups = groups.argmax(axis=1)

            # Fit KNN model
            knn = NearestNeighbors().fit(encoding)

            # Iterate over samples and compute purity over k neighbours
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)
            for i, sample in enumerate(random_idxs):
                indexes = knn.kneighbors(
                    encoding[sample][np.newaxis, :], self.k, return_distance=False
                )
                purity_vector[i] = (
                    np.sum(hard_groups[indexes] == hard_groups[sample])
                    / self.k
                    * np.max(groups[sample])
                )

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "knn_cluster_purity",
                    data=purity_vector.mean(),
                    step=epoch,
                )

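# A minimal usage sketch (hedged: assumes a variational model exposing layers
# named "cluster_means" and "cluster_assignment", as the callback requires;
# `my_model`, `X_train` and `X_val` are hypothetical placeholders):
#
#     purity_cb = knn_cluster_purity(
#         variational=True, validation_data=X_val, k=100, samples=1000, log_dir="./logs"
#     )
#     my_model.fit(X_train, X_train, epochs=10, callbacks=[purity_cb])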

class uncorrelated_features_constraint(Constraint):
    """

286
    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
lucas_miranda's avatar
lucas_miranda committed
287
288
289
290
    Useful, among others, for auto encoder bottleneck layers

    """

291
292
293
294
    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

295
    def get_config(self):  # pragma: no cover
296
        """Updates Constraint metadata"""
297
298

        config = super().get_config().copy()
299
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
300
301
302
        return config

    def get_covariance(self, x):
303
304
        """Computes the covariance of the elements of the passed layer"""

305
306
307
        x_centered_list = []

        for i in range(self.encoding_dim):
308
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))
309
310

        x_centered = tf.stack(x_centered_list)
311
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
312
313
314
315
316
317
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
318
    # noinspection PyUnusedLocal
319
    def uncorrelated_feature(self, x):
320
321
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

322
        if self.encoding_dim <= 1:  # pragma: no cover
323
324
            return 0.0
        else:
325
326
            output = K.sum(
                K.square(
327
                    self.covariance
328
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
329
330
331
332
333
334
335
336
337
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)

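# A minimal usage sketch (hedged illustration). Although it subclasses Constraint,
# __call__ returns a scalar penalty, so the instance can be attached wherever Keras
# accepts a callable regularizer, e.g. on an autoencoder bottleneck:
#
#     bottleneck = tf.keras.layers.Dense(
#         16, activity_regularizer=uncorrelated_features_constraint(16, weightage=0.01)
#     )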

# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""

        return super().call(inputs, training=True)

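# A minimal Monte Carlo prediction sketch (hedged: `x_batch` is a hypothetical
# input array). Since dropout stays active at prediction time, repeated forward
# passes yield a distribution over outputs rather than a point estimate:
#
#     mc_model = tf.keras.Sequential(
#         [tf.keras.layers.Dense(64, activation="relu"), MCDropout(0.5), tf.keras.layers.Dense(1)]
#     )
#     preds = np.stack([mc_model.predict(x_batch) for _ in range(100)])
#     mean, spread = preds.mean(axis=0), preds.std(axis=0)  # predictive mean / uncertainty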

class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim

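# A minimal tied-weights autoencoder sketch (hedged illustration; sizes are
# arbitrary). The decoder reuses the encoder's kernel transposed, roughly halving
# the number of trainable weights:
#
#     encoder_layer = tf.keras.layers.Dense(16, activation="relu")
#     inputs = tf.keras.Input(shape=(64,))
#     decoded = DenseTranspose(encoder_layer, output_dim=64, activation="relu")(
#         encoder_layer(inputs)
#     )
#     autoencoder = tf.keras.Model(inputs, decoded)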

class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a

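# A minimal usage sketch inside a variational bottleneck (hedged: `encoder_output`
# and the sizes are hypothetical). The layer is an identity on the posterior
# sample and registers the KL term against the prior as a side effect:
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
#     z = tfpl.IndependentNormal(8)(encoder_output)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)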

class MMDiscrepancyLayer(Layer):
    """

    Identity transform layer that adds Maximum Mean Discrepancy
    to the final model loss.

    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z

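# A minimal usage sketch (hedged: the prior, batch size and `z` are hypothetical).
# The layer returns z unchanged while penalising deviations of the posterior
# samples from the prior distribution:
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
#     z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)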

class Cluster_overlap(Layer):
    """

    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric

    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target

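# A minimal usage sketch (hedged: `gmm_params` is a hypothetical tensor). The
# layer expects per-component means and unconstrained scales stacked along the
# second axis, i.e. shape (batch, 2 * lat_dims, n_components):
#
#     gmm_params = Cluster_overlap(lat_dims=8, n_components=4, loss=False)(gmm_params)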

class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of the assignment for a given instance
        # axis=0 increases the entropy of a cluster across instances
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
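

# A minimal usage sketch (hedged: `soft_assignments` is a hypothetical tensor of
# per-instance cluster probabilities with shape (batch, n_components)). With the
# default axis=1, the penalty rewards high-entropy (soft) assignments per instance:
#
#     soft_assignments = Entropy_regulariser(weight=0.01)(soft_assignments)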