# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Any, Tuple
from sklearn.metrics import pairwise_distances
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses
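
# Illustrative usage sketch, not part of the original module: `model` is any
# compiled Keras model, and X_train / y_train are hypothetical numpy arrays.
#
#     rates, losses = find_learning_rate(model, X_train, y_train, batch_size=64)
#     plot_lr_vs_loss(rates, losses)  # pick a rate slightly below the divergence point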


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel
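
# Shape sketch (assumption, derived from the code above): for x of shape (n, d)
# and y of shape (m, d), compute_kernel returns an (n, m) matrix with one
# Gaussian similarity per pair of instances.
#
#     k = compute_kernel(tf.random.normal([8, 4]), tf.random.normal([5, 4]))
#     assert k.shape == (8, 5)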


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
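
# Illustrative sketch, not part of the original module: the MMD between two
# samples drawn from the same distribution should be close to zero.
#
#     a = tf.random.normal([100, 8])
#     b = tf.random.normal([100, 8])
#     mmd = compute_mmd((a, b))  # scalar tensor, near 0 for matching samples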


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        """Linearly interpolates the learning rate between two iteration milestones"""

        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
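
# Illustrative usage sketch with hypothetical values: one cycle spanning 10
# epochs of 100 batches each, logged to an assumed ./logs directory.
#
#     onecycle = one_cycle_scheduler(iterations=1000, max_rate=1e-3, log_dir="./logs")
#     model.fit(X_train, y_train, epochs=10, callbacks=[onecycle])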


class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over a neighborhood of radius r in the latent space

    """

    def __init__(
        self, variational=True, validation_data=None, r=100, samples=10000, log_dir="."
    ):
        super().__init__()
        self.variational = variational
        self.validation_data = validation_data
        self.r = r
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from full model
            cluster_means = [
                layer for layer in self.model.layers if layer.name == "cluster_means"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, cluster_means.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)

            # Multiply encodings by groups, to get a weighted version of the matrix
            encoding = (
                encoding
                * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
            )
            hard_groups = groups.argmax(axis=1)

            # Compute pairwise distances on the latent space
            pdist = pairwise_distances(encoding)

            # Iterate over samples and compute purity across neighbourhood
            self.samples = np.min([self.samples, encoding.shape[0]])
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)

            for i, sample in enumerate(random_idxs):

                neighborhood = pdist[sample] < self.r

                # Fraction of neighbours sharing the sample's hard cluster,
                # weighted by the sample's soft assignment confidence
                purity_vector[i] = (
                    np.sum(hard_groups[neighborhood] == hard_groups[sample])
                    / np.sum(neighborhood)
                    * np.max(groups[sample])
                )

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "neighborhood_cluster_purity",
                    data=purity_vector.mean(),
                    step=epoch,
                )
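
# Illustrative usage sketch: X_val is a hypothetical validation array, and the
# monitored model is assumed to expose layers named "cluster_means" and
# "cluster_assignment", as queried in on_epoch_end above.
#
#     purity = knn_cluster_purity(validation_data=X_val, r=100, log_dir="./logs")
#     model.fit(X_train, epochs=10, callbacks=[purity])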


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
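
# Illustrative sketch, not from the original module: the penalty can be
# evaluated directly on a batch of activations, and is zero when all
# off-diagonal covariance entries vanish.
#
#     constraint = uncorrelated_features_constraint(encoding_dim=4, weightage=1.0)
#     penalty = constraint(tf.random.normal([32, 4]))  # scalar penalty value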


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""

        return super().call(inputs, training=True)
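
# Illustrative Monte Carlo prediction sketch with a hypothetical model: since
# dropout stays active at inference time, averaging repeated stochastic
# forward passes approximates the predictive mean and uncertainty.
#
#     preds = np.stack([model.predict(X_val) for _ in range(100)])
#     mean, std = preds.mean(axis=0), preds.std(axis=0)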


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        # Multiply by the transpose of the tied dense layer's kernel
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
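
# Illustrative tied-weights autoencoder sketch with hypothetical dimensions:
# the decoder reuses the encoder's kernel (transposed), so that kernel is not
# trained twice.
#
#     dense_1 = tf.keras.layers.Dense(16, activation="relu")
#     encoded = dense_1(inputs)  # inputs assumed to have shape (batch, 64)
#     decoded = DenseTranspose(dense_1, output_dim=64, activation="relu")(encoded)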


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
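
# Illustrative sketch with a hypothetical 8-dimensional latent space: the layer
# wraps a probabilistic encoder output, adding the KL term against a fixed
# prior to the model loss; encoder_params is an assumed upstream tensor.
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
#     z = tfpl.MultivariateNormalTriL(8)(encoder_params)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)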


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
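
# Illustrative sketch with hypothetical values: penalises the discrepancy
# between posterior samples z and samples drawn from a fixed prior.
#
#     prior = tfd.MultivariateNormalDiag(loc=tf.zeros(8), scale_diag=tf.ones(8))
#     z = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(z)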


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture,
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Build one Gaussian per mixture component from the concatenated
        # means and (softplus-transformed) scales in the target tensor
        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
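
# Input convention sketch, read off the call method above: target is expected
# to concatenate means and raw scales per mixture component, with shape
# (batch, 2 * lat_dims, n_components).
#
#     target = Cluster_overlap(lat_dims=8, n_components=4)(target)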


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds a metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
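
# Illustrative sketch with hypothetical values: applied to soft cluster
# assignments, the added loss term pushes the weights towards higher entropy
# along the chosen axis, discouraging cluster collapse.
#
#     soft_groups = tf.keras.layers.Dense(4, activation="softmax")(latent)
#     soft_groups = Entropy_regulariser(weight=0.1)(soft_groups)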