# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Any, Tuple
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
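
# Example (hedged sketch, not part of the original module): running the learning
# rate range test on an already compiled model. `model`, `X_train` and `y_train`
# are hypothetical placeholders for a compiled tf.keras model and its data.
#
#   rates, losses = find_learning_rate(model, X_train, y_train, batch_size=32)
#   plot_lr_vs_loss(rates, losses)
#   plt.show()  # pick a rate slightly below the point of minimum loss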


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel
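
# Example (hedged sketch, not part of the original module): for x of shape (n, d)
# and y of shape (m, d), compute_kernel returns an (n, m) matrix of Gaussian
# kernel similarities, one per pair of instances.
#
#   x = tf.random.normal([4, 8])
#   y = tf.random.normal([6, 8])
#   k = compute_kernel(x, y)  # shape (4, 6), values in (0, 1]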


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
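
# Example (hedged sketch, not part of the original module): MMD is ~0 for samples
# drawn from the same distribution and grows as the two distributions diverge.
#
#   a = tf.random.normal([100, 2])
#   b = tf.random.normal([100, 2], mean=3.0)
#   compute_mmd((a, a))  # ~0
#   compute_mmd((a, b))  # noticeably larger than 0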


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        """Linearly interpolates the learning rate between two iteration milestones"""

        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
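
# Example (hedged sketch, not part of the original module): scheduling a full
# training run. `model`, `X` and `y` are hypothetical placeholders; iterations
# should cover all batches seen during training.
#
#   onecycle = one_cycle_scheduler(
#       iterations=(len(X) // 32) * 10, max_rate=1e-3, log_dir="./logs"
#   )
#   model.fit(X, y, epochs=10, batch_size=32, callbacks=[onecycle])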


class neighbor_latent_entropy(tf.keras.callbacks.Callback):
    """

    Latent space entropy callback. Computes the entropy of cluster assignment across k nearest neighbors of a subset
    of samples in the latent space.

    """

    def __init__(
        self,
        encoding_dim: int,
        variational: bool = True,
        validation_data: np.ndarray = None,
        k: int = 100,
        samples: int = 10000,
        log_dir: str = ".",
    ):
        super().__init__()
        self.enc = encoding_dim
        self.variational = variational
        self.validation_data = validation_data
        self.k = k
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from the full model
            latent_distribution = [
                layer
                for layer in self.model.layers
                if layer.name == "latent_distribution"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, latent_distribution.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)
            hard_groups = groups.argmax(axis=1)
            max_groups = groups.max(axis=1)

            # Compute pairwise distances on the latent space
            knn = NearestNeighbors().fit(encoding)

            # Iterate over samples and compute purity across neighbourhood
            self.samples = np.min([self.samples, encoding.shape[0]])
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)

            for i, sample in enumerate(random_idxs):

                neighborhood = knn.kneighbors(
                    encoding[sample][np.newaxis, :], self.k, return_distance=False
                ).flatten()

                z = hard_groups[neighborhood]

                # Compute Shannon entropy of cluster assignments across the neighborhood
                neigh_entropy = entropy(np.bincount(z))

                # Add result to the pre-allocated array
                purity_vector[i] = neigh_entropy

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "average_neighborhood_cluster_entropy",
                    data=np.average(purity_vector, weights=max_groups[random_idxs]),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_confidence_in_selected_cluster",
                    data=np.average(max_groups),
                    step=epoch,
                )
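
# Example (hedged sketch, not part of the original module): monitoring latent
# neighborhood entropy while training a variational model. `gmvae`, `X_train`
# and `X_val` are hypothetical placeholders for the model and its data.
#
#   entropy_cb = neighbor_latent_entropy(
#       encoding_dim=6, validation_data=X_val, k=100, log_dir="./logs"
#   )
#   gmvae.fit(X_train, X_train, epochs=10, callbacks=[entropy_cb])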


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
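
# Example (hedged sketch, not part of the original module): although this class
# subclasses Constraint, its __call__ returns a scalar penalty, so it can be
# attached wherever Keras accepts a callable regularizer, e.g. on a bottleneck:
#
#   bottleneck = tf.keras.layers.Dense(
#       16,
#       activity_regularizer=uncorrelated_features_constraint(
#           encoding_dim=16, weightage=1.0
#       ),
#   )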


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the parent class, keeping dropout active at prediction time"""
        return super().call(inputs, training=True)
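
# Example (hedged sketch, not part of the original module): since dropout stays
# active at prediction time, repeated forward passes give a Monte Carlo estimate
# of predictive uncertainty. `mc_model` and `X` are hypothetical placeholders.
#
#   preds = np.stack([mc_model.predict(X) for _ in range(100)])
#   y_mean, y_std = preds.mean(axis=0), preds.std(axis=0)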


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
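
# Example (hedged sketch, not part of the original module): tying the weights of
# a symmetric autoencoder, so the decoder reuses the encoder's kernel transposed.
#
#   enc = tf.keras.layers.Dense(32, activation="relu")
#   inputs = tf.keras.Input(shape=(784,))
#   codes = enc(inputs)
#   outputs = DenseTranspose(enc, output_dim=784, activation="sigmoid")(codes)
#   autoencoder = tf.keras.Model(inputs, outputs)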


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, iters, warm_up_iters, *args, **kwargs):
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self._iters = iters
        self._warm_up_iters = warm_up_iters

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        # Define and update KL weight for warmup
        if self._warm_up_iters > 0:
            kl_weight = tf.cast(
                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
            )
        else:
            kl_weight = tf.cast(1.0, tf.float32)

        kl_batch = kl_weight * self._regularizer(distribution_a)

        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(kl_weight, aggregation="mean", name="kl_rate")

        return distribution_a
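
# Example (hedged sketch, not part of the original module): adding a warm-up
# aware KL term to a variational bottleneck. `prior` and `z_dist` are
# hypothetical placeholders for a tfd prior and the output of a probabilistic
# encoder layer; `iters` would track the optimizer's iteration counter.
#
#   z_dist = KLDivergenceLayer(
#       iters=optimizer.iterations, warm_up_iters=1000, distribution_b=prior
#   )(z_dist)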


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, iters, warm_up_iters, *args, **kwargs):
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self.batch_size = batch_size
        self.prior = prior
        self._iters = iters
        self._warm_up_iters = warm_up_iters

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)

        # Define and update MMD weight for warmup
        if self._warm_up_iters > 0:
            mmd_weight = tf.cast(
                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
            )
        else:
            mmd_weight = tf.cast(1.0, tf.float32)

        mmd_batch = mmd_weight * compute_mmd((true_samples, z))

        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")

        return z
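
# Example (hedged sketch, not part of the original module): penalizing mismatch
# between latent samples and a standard-normal prior, as in MMD-regularized
# autoencoders. `z` is a hypothetical latent tensor inside a model graph.
#
#   prior = tfd.Independent(
#       tfd.Normal(loc=tf.zeros(16), scale=1.0), reinterpreted_batch_ndims=1
#   )
#   z = MMDiscrepancyLayer(
#       batch_size=64, prior=prior, iters=optimizer.iterations, warm_up_iters=1000
#   )(z)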


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = (target[..., : self.lat_dims, k],)
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
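
# Example (hedged sketch, not part of the original module): monitoring component
# overlap inside a GMVAE graph. `latent_params` is a hypothetical tensor of shape
# (batch, 2 * lat_dims, n_components) stacking means and pre-softplus scales.
#
#   latent_params = Cluster_overlap(lat_dims=6, n_components=10)(latent_params)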


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target