# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
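
# Example (hypothetical usage sketch; `model`, `X_train` and `y_train` are
# placeholders, not part of this module): run the learning rate range test
# and inspect the resulting curve to pick a maximum rate.
#
#     rates, losses = find_learning_rate(model, X_train, y_train)
#     plot_lr_vs_loss(rates, losses)
#     plt.show()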


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
    """

    Computes the maximum mean discrepancy (MMD) between the two specified
    tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
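
# Example (hypothetical sanity check, not part of the original module): the MMD
# between two samples drawn from the same distribution should be close to zero,
# and should grow as the distributions drift apart.
#
#     same = compute_mmd((tf.random.normal([100, 8]), tf.random.normal([100, 8])))
#     far = compute_mmd((tf.random.normal([100, 8]), tf.random.normal([100, 8]) + 3.0))
#     assert far > same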


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        """Linearly interpolates the learning rate between two schedule points"""
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
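
# Example (hypothetical usage sketch; `model`, `X` and `y` are placeholders):
# schedule a single learning-rate cycle over the whole training run.
#
#     batch_size, epochs = 32, 10
#     onecycle = one_cycle_scheduler(
#         iterations=len(X) // batch_size * epochs, max_rate=1e-3, log_dir="./logs"
#     )
#     model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[onecycle])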


class neighbor_latent_entropy(tf.keras.callbacks.Callback):
    """

    Latent space entropy callback. Computes the entropy of cluster assignment across k nearest neighbors of a subset
    of samples in the latent space.

    """

    def __init__(
        self,
        encoding_dim: int,
        variational: bool = True,
        validation_data: np.ndarray = None,
        k: int = 100,
        samples: int = 10000,
        log_dir: str = ".",
    ):
        super().__init__()
        self.enc = encoding_dim
        self.variational = variational
        self.validation_data = validation_data
        self.k = k
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from the full model
            latent_distribution = [
                layer
                for layer in self.model.layers
                if layer.name == "latent_distribution"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, latent_distribution.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)
            hard_groups = groups.argmax(axis=1)
            max_groups = groups.max(axis=1)

            # Fit a nearest neighbours model on the latent space
            knn = NearestNeighbors().fit(encoding)

            # Iterate over samples and compute purity across neighbourhood
            self.samples = np.min([self.samples, encoding.shape[0]])
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)

            for i, sample in enumerate(random_idxs):
                neighborhood = knn.kneighbors(
                    encoding[sample][np.newaxis, :], self.k, return_distance=False
                ).flatten()

                z = hard_groups[neighborhood]

                # Compute Shannon entropy across samples
                neigh_entropy = entropy(np.bincount(z))

                # Add result to the pre-allocated array
                purity_vector[i] = neigh_entropy

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "average_neighborhood_cluster_entropy",
                    data=np.average(purity_vector, weights=max_groups[random_idxs]),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_confidence_in_selected_cluster",
                    data=np.average(max_groups),
                    step=epoch,
                )
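
# Example (hypothetical usage sketch; `gmvae` and `X_val` are placeholders):
# the callback assumes the model exposes layers named "latent_distribution"
# and "cluster_assignment", as in the deepof GMVAE models.
#
#     entropy_cb = neighbor_latent_entropy(
#         encoding_dim=6, validation_data=X_val, k=100, log_dir="./logs"
#     )
#     gmvae.fit(X_val, X_val, epochs=10, callbacks=[entropy_cb])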


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
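
# Example (hypothetical usage sketch): since the constraint is a callable on a
# layer's output, it can also be attached to a bottleneck layer as an activity
# regularizer to penalise correlated features.
#
#     bottleneck = tf.keras.layers.Dense(
#         16, activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0)
#     )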


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the parent class"""
        return super().call(inputs, training=True)
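
# Example (hypothetical sketch; `mc_model` and `X` are placeholders): because
# dropout stays active at prediction time, repeated forward passes give a
# Monte Carlo estimate of predictive uncertainty.
#
#     preds = np.stack([mc_model.predict(X) for _ in range(100)])
#     mean, std = preds.mean(axis=0), preds.std(axis=0)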


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
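
# Example (hypothetical sketch of weight tying in a small autoencoder; the
# layer sizes are placeholders):
#
#     inputs = tf.keras.Input(shape=(784,))
#     dense_1 = tf.keras.layers.Dense(64, activation="relu")
#     dense_2 = tf.keras.layers.Dense(16, activation="relu")
#     encoded = dense_2(dense_1(inputs))
#     decoded = DenseTranspose(dense_1, output_dim=784, activation="sigmoid")(
#         DenseTranspose(dense_2, output_dim=64, activation="relu")(encoded)
#     )
#     autoencoder = tf.keras.Model(inputs, decoded)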


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, iters, warm_up_iters, annealing_mode="sigmoid", *args, **kwargs):
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self._iters = iters
        self._warm_up_iters = warm_up_iters
        self._annealing_mode = annealing_mode

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        config.update({"_annealing_mode": self._annealing_mode})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        # Define and update KL weight for warmup
        if self._warm_up_iters > 0:
            if self._annealing_mode in ["linear", "sigmoid"]:
                kl_weight = tf.cast(
                    K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
                )
                if self._annealing_mode == "sigmoid":
                    kl_weight = 1.0 / (1.0 + tf.exp(-kl_weight))
            else:
                raise NotImplementedError(
                    "annealing_mode must be one of 'linear' and 'sigmoid'"
                )
        else:
            kl_weight = tf.cast(1.0, tf.float32)

        kl_batch = kl_weight * self._regularizer(distribution_a)

        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(kl_weight, aggregation="mean", name="kl_rate")

        return distribution_a
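
# Example (hypothetical usage sketch; `encoder_output` is a placeholder tensor
# with tfpl.IndependentNormal.params_size(16) units): weighting the KL term
# against a standard normal prior with a linear warm-up schedule.
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(16), scale=1.0), 1)
#     z = tfpl.IndependentNormal(16)(encoder_output)
#     z = KLDivergenceLayer(
#         iters=1, warm_up_iters=1000, annealing_mode="linear", distribution_b=prior
#     )(z)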


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds the maximum mean discrepancy (MMD)
    to the final model loss.
    """

    def __init__(
        self,
        batch_size,
        prior,
        iters,
        warm_up_iters,
        annealing_mode="sigmoid",
        *args,
        **kwargs
    ):
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self.batch_size = batch_size
        self.prior = prior
        self._iters = iters
        self._warm_up_iters = warm_up_iters
        self._annealing_mode = annealing_mode

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        config.update({"prior": self.prior})
        config.update({"_annealing_mode": self._annealing_mode})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)

        # Define and update MMD weight for warmup
        if self._warm_up_iters > 0:
            if self._annealing_mode in ["linear", "sigmoid"]:
                mmd_weight = tf.cast(
                    K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
                )
                if self._annealing_mode == "sigmoid":
                    mmd_weight = 1.0 / (1.0 + tf.exp(-mmd_weight))
            else:
                raise NotImplementedError(
                    "annealing_mode must be one of 'linear' and 'sigmoid'"
                )
        else:
            mmd_weight = tf.cast(1.0, tf.float32)

        mmd_batch = mmd_weight * compute_mmd((true_samples, z))

        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")

        return z


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target


class Dead_neuron_control(Layer):
    """
    Identity layer that monitors the fraction of dead neurons in the latent
    space, adding it to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target