# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof TensorFlow models. See documentation for details.

"""
10

11
from itertools import combinations
lucas_miranda's avatar
lucas_miranda committed
12
from typing import Any, Tuple
lucas_miranda's avatar
lucas_miranda committed
13
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD (maximum mean discrepancy) between the two specified
    tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
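

# Minimal sketch of compute_mmd on two random samples (shapes are illustrative):
#
#   x = tf.random.normal([64, 16])
#   y = tf.random.normal([64, 16])
#   mmd = compute_mmd((x, y))  # scalar; near 0 when both samples share a distribution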


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

        logs["learning_rate"] = self.last_rate
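

# Hypothetical usage sketch (`model`, `X_train`, `y_train` and the sizes are
# illustrative):
#
#   iterations = (len(X_train) // 256) * 30  # batches per epoch * epochs
#   onecycle = one_cycle_scheduler(iterations, max_rate=1e-3)
#   model.fit(X_train, y_train, epochs=30, batch_size=256, callbacks=[onecycle])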


class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space

    """

    def __init__(self, k=5, samples=10000):
        super().__init__()
        self.k = k
        self.samples = samples

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, batch: int, logs):
        """ Passes samples through the encoder and computes cluster purity on the latent embedding """

        if self.validation_data is not None:

            # Get encoder and grouper from the full model
            cluster_means = [
                layer for layer in self.model.layers if layer.name == "cluster_means"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, cluster_means.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)

            # Multiply encodings by groups, to get a weighted version of the matrix
            encoding = (
                encoding
                * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
            )
            hard_groups = groups.argmax(axis=1)

            # Fit KNN model
            knn = NearestNeighbors().fit(encoding)

            # Iterate over samples and compute purity over k neighbours
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)
            for i, sample in enumerate(random_idxs):
                indexes = knn.kneighbors(
                    encoding[sample][np.newaxis, :], self.k, return_distance=False
                )
                purity_vector[i] = (
                    np.sum(hard_groups[indexes] == hard_groups[sample])
                    / self.k
                    * np.max(groups[sample])
                )

            logs["knn_cluster_purity"] = purity_vector.mean()
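

# Hypothetical usage sketch (assumes a deepof-style model exposing
# "cluster_means" and "cluster_assignment" layers; `X_val` is illustrative).
# validation_data is not populated automatically in TF2, so it is set by hand:
#
#   purity = knn_cluster_purity(k=5, samples=1000)
#   purity.validation_data = X_val
#   model.fit(X_train, X_train, epochs=10, callbacks=[purity])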


class uncorrelated_features_constraint(Constraint):
    """

264
    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
lucas_miranda's avatar
lucas_miranda committed
265
266
267
268
    Useful, among others, for auto encoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
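

# Hypothetical usage sketch: although it subclasses Constraint, this penalty is
# typically attached as an activity regularizer (the sizes are illustrative):
#
#   bottleneck = tf.keras.layers.Dense(
#       16,
#       activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0),
#   )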


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""

        return super().call(inputs, training=True)
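

# Hypothetical usage sketch: Monte Carlo predictions with a model built with
# MCDropout layers (`model` and `X` are illustrative):
#
#   preds = np.stack([model.predict(X) for _ in range(100)])
#   mc_mean, mc_std = preds.mean(axis=0), preds.std(axis=0)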


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
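

# Hypothetical usage sketch: tying decoder weights to an encoder Dense layer
# (`inputs` and the sizes are illustrative):
#
#   dense_1 = tf.keras.layers.Dense(16, activation="relu")
#   encoded = dense_1(inputs)
#   decoded = DenseTranspose(dense_1, output_dim=inputs.shape[-1])(encoded)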


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
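

# Hypothetical usage sketch inside a variational model (the prior and sizes are
# illustrative):
#
#   prior = tfd.Independent(tfd.Normal(tf.zeros(16), 1.0), 1)
#   z = tfpl.IndependentNormal(16)(dense_params)  # dense_params: 2 * 16 units
#   z = KLDivergenceLayer(prior, weight=1.0)(z)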


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
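

# Hypothetical usage sketch (the prior and sizes are illustrative):
#
#   prior = tfd.Independent(tfd.Normal(tf.zeros(16), 1.0), 1)
#   z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)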


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
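

# Hypothetical usage sketch: `gmm_params` stacks means and pre-softplus scales
# per component, with shape (..., 2 * lat_dims, n_components); all names and
# sizes are illustrative:
#
#   gmm_params = Cluster_overlap(lat_dims=16, n_components=10)(gmm_params)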


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        # the epsilon inside the log guards against log(0)
        entropy = K.sum(tf.multiply(z, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds a metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
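

# Hypothetical usage sketch on soft cluster assignments (the weight and axis
# are illustrative):
#
#   soft_counts = Entropy_regulariser(weight=0.01, axis=1)(soft_counts)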