# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""
from itertools import combinations
from typing import Any, Tuple
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the maximum mean discrepancy (MMD) between the two specified tensors
    using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
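

# A quick sanity check (hypothetical toy data, not part of this module): the
# MMD between samples of the same distribution should be close to zero, and
# clearly positive for samples of different distributions:
def _example_mmd_sanity_check():  # pragma: no cover
    """Illustrates compute_mmd on toy Gaussian samples."""
    x = tf.random.normal([100, 8])
    y = tf.random.normal([100, 8], mean=5.0)
    print(compute_mmd((x, x)).numpy())  # close to 0
    print(compute_mmd((x, y)).numpy())  # substantially larger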


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
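

# A minimal usage sketch (assumption: "model", "X_train" and "y_train" are
# placeholder names, not part of this module). iterations should equal
# batches_per_epoch * epochs so the schedule spans the whole run:
def _example_one_cycle_fit(model, X_train, y_train, epochs=10, batch_size=32):  # pragma: no cover
    """Hypothetical helper illustrating the one_cycle_scheduler callback."""
    onecycle = one_cycle_scheduler(
        iterations=len(X_train) // batch_size * epochs, max_rate=1e-3
    )
    return model.fit(
        X_train, y_train, epochs=epochs, batch_size=batch_size, callbacks=[onecycle]
    )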


class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space

    """

    def __init__(
        self, variational=True, validation_data=None, k=100, samples=10000, log_dir="."
    ):
        super().__init__()
        self.variational = variational
        self.validation_data = validation_data
        self.k = k
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from full model
            cluster_means = [
                layer for layer in self.model.layers if layer.name == "cluster_means"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, cluster_means.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)

            # Multiply encodings by groups, to get a weighted version of the matrix
            encoding = (
                encoding
                * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
            )
            hard_groups = groups.argmax(axis=1)

            # Fit KNN model
            knn = NearestNeighbors().fit(encoding)

            # Iterate over samples and compute purity over k neighbours
            self.samples = np.min([self.samples, encoding.shape[0]])
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)
            for i, sample in enumerate(random_idxs):
                indexes = knn.kneighbors(
                    encoding[sample][np.newaxis, :], self.k, return_distance=False
                )
                purity_vector[i] = (
                    np.sum(hard_groups[indexes] == hard_groups[sample])
                    / self.k
                    * np.max(groups[sample])
                )

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "knn_cluster_purity",
                    data=purity_vector.mean(),
                    step=epoch,
                )
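

# Hypothetical usage sketch: the callback assumes the monitored model exposes
# layers named "cluster_means" and "cluster_assignment" (as in the deepof
# variational models); "model", "X_train" and "X_val" are placeholder names:
def _example_purity_monitoring(model, X_train, X_val):  # pragma: no cover
    """Tracks KNN cluster purity of the latent space during training."""
    purity = knn_cluster_purity(validation_data=X_val, k=100, log_dir="./logs")
    return model.fit(X_train, X_train, epochs=10, callbacks=[purity])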


class uncorrelated_features_constraint(Constraint):
    """

287
    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
lucas_miranda's avatar
lucas_miranda committed
288
289
290
291
    Useful, among others, for auto encoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
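

# A minimal wiring sketch (an assumption, not taken from this module): although
# this class subclasses Constraint, its __call__ returns a scalar penalty, so it
# can be attached to the bottleneck activations as an activity regularizer:
def _example_uncorrelated_bottleneck(encoding_dim=16):  # pragma: no cover
    """Hypothetical bottleneck layer penalised for correlated features."""
    return tf.keras.layers.Dense(
        encoding_dim,
        activation="elu",
        activity_regularizer=uncorrelated_features_constraint(encoding_dim),
    )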


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Calls the parent Dropout layer with training mode always enabled"""

        return super().call(inputs, training=True)
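

# Hypothetical sketch of the intended use: since dropout stays active at
# inference, repeated predict() calls on a model built with MCDropout yield a
# Monte Carlo estimate of the predictive mean and uncertainty ("model" and "X"
# are placeholder names, not part of this module):
def _example_mc_predict(model, X, n_samples=100):  # pragma: no cover
    """Averages stochastic forward passes to approximate the predictive distribution."""
    preds = np.stack([model.predict(X) for _ in range(n_samples)])
    return preds.mean(axis=0), preds.std(axis=0)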


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
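

# A minimal sketch (illustrative names, not from this module) of tying encoder
# and decoder weights with DenseTranspose in a small autoencoder:
def _example_tied_autoencoder(input_dim=100, latent_dim=16):  # pragma: no cover
    """Builds a toy autoencoder whose decoder reuses the encoder's weights."""
    inputs = tf.keras.Input(shape=(input_dim,))
    encoder = tf.keras.layers.Dense(latent_dim, activation="elu")
    latent = encoder(inputs)
    outputs = DenseTranspose(encoder, output_dim=input_dim, activation="sigmoid")(latent)
    return tf.keras.Model(inputs, outputs)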


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
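

# A minimal sketch (hypothetical wiring, not from this module) of regularising
# a tfp posterior head with the layer above; the layer passes the distribution
# through unchanged while registering the KL term against the given prior:
def _example_kl_regularised_encoder(inputs, latent_dim=8):  # pragma: no cover
    """Adds a KL(posterior || prior) penalty to a probabilistic encoder output."""
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0), reinterpreted_batch_ndims=1
    )
    params = tf.keras.layers.Dense(
        tfpl.MultivariateNormalTriL.params_size(latent_dim)
    )(inputs)
    posterior = tfpl.MultivariateNormalTriL(latent_dim)(params)
    return KLDivergenceLayer(prior, weight=1.0)(posterior)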


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
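

# A minimal sketch (hypothetical names) of pulling a latent code towards a
# standard normal prior with the MMD layer above; z is assumed to be the
# (batch, latent_dim) output of an encoder:
def _example_mmd_regularised_latent(z, latent_dim=8, batch_size=256):  # pragma: no cover
    """Wires MMDiscrepancyLayer between an encoder output and the decoder."""
    prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_dim))
    return MMDiscrepancyLayer(batch_size=batch_size, prior=prior, beta=1.0)(z)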


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
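

# A minimal sketch (hypothetical shapes, not from this module): gmm_params is
# assumed to stack, for each mixture component, lat_dims means followed by
# lat_dims pre-softplus scales, i.e. shape (batch, 2 * lat_dims, n_components):
def _example_cluster_overlap(gmm_params, lat_dims=8, n_components=10):  # pragma: no cover
    """Monitors the average inter-component MMD of a latent Gaussian mixture."""
    return Cluster_overlap(lat_dims, n_components, loss=False)(gmm_params)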


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds a metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
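

# A minimal sketch (hypothetical wiring, not from this module): the regulariser
# is applied as an identity layer on a softmax cluster-assignment tensor,
# adding the entropy term to the model loss when weight > 0:
def _example_entropy_on_assignments(assignments):  # pragma: no cover
    """Applies Entropy_regulariser to a (batch, n_components) softmax output."""
    return Entropy_regulariser(weight=0.1, axis=1)(assignments)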