# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Any, Tuple
from scipy.stats import entropy
from sklearn.metrics import pairwise_distances
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
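

# Usage sketch for the two helpers above (hypothetical names: assumes a compiled
# Keras `model` and training arrays `X_train`, `y_train` are in scope):
#
#     rates, losses = find_learning_rate(model, X_train, y_train, batch_size=64)
#     plot_lr_vs_loss(rates, losses)
#     plt.show()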


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
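

# Usage sketch for compute_mmd (a minimal eager-mode check; both samples come
# from the same distribution, so the estimate should be close to zero):
#
#     a = tf.random.normal([100, 8])
#     b = tf.random.normal([100, 8])
#     print(compute_mmd((a, b)).numpy())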


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
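

# Usage sketch for the scheduler above (hypothetical `model`, `X_train`, `y_train`;
# one-cycle scheduling needs the total number of training iterations up front):
#
#     steps = (len(X_train) // 256) * 10  # batches per epoch * epochs
#     onecycle = one_cycle_scheduler(iterations=steps, max_rate=1e-3, log_dir="./logs")
#     model.fit(X_train, y_train, epochs=10, batch_size=256, callbacks=[onecycle])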


class neighbor_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster entropy callback. Computes assignment local entropy over a neighborhood of radius r in the latent space

    """

    def __init__(
        self,
        encoding_dim: int,
        variational: bool = True,
        validation_data: np.ndarray = None,
        samples: int = 10000,
        log_dir: str = ".",
        min_n: int = 2,
    ):
        super().__init__()
        self.enc = encoding_dim
        if validation_data is not None:
            self.r = (
                -0.14220132706202965 * np.log2(validation_data.shape[0])
                + 0.17189696892334544 * self.enc
                + 1.6940295848037952
            )  # Empirically derived from data. See examples/set_default_entropy_radius.ipynb for details
        else:
            self.r = None
        self.variational = variational
        self.validation_data = validation_data
        self.samples = samples
        self.log_dir = log_dir
        self.min_n = min_n

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from full model
            latent_distribution = [
                layer
                for layer in self.model.layers
                if layer.name == "latent_distribution"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, latent_distribution.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)
            hard_groups = groups.argmax(axis=1)
            max_groups = groups.max(axis=1)

            # Compute pairwise distances on the latent space
            pdist = pairwise_distances(encoding)

            # Iterate over samples and compute purity across neighbourhood
            self.samples = np.min([self.samples, encoding.shape[0]])
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)
            neighbor_number = np.zeros(self.samples)

            for i, sample in enumerate(random_idxs):

                neighborhood = pdist[sample] < self.r
                z = hard_groups[neighborhood]

                # Compute Shannon entropy across samples
                neigh_entropy = entropy(np.bincount(z))

                purity_vector[i] = neigh_entropy
                neighbor_number[i] = np.sum(neighborhood)

            # Compute a mask to keep only examples with a minimum of self.min_n neighbors
            mask = neighbor_number >= self.min_n

            # Filter all relevant vectors using the mask
            purity_vector = purity_vector[mask]
            neighbor_number = neighbor_number[mask]
            max_groups = max_groups[random_idxs][mask]

            # Compute weights multiplying neighbor number and target confidence
            purity_weights = neighbor_number * max_groups

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "neighborhood_cluster_purity",
                    data=np.average(purity_vector, weights=purity_weights),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_neighbors_in_radius",
                    data=np.average(neighbor_number),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_confidence_in_selected_cluster",
                    data=np.average(max_groups),
                    step=epoch,
                )
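

# Usage sketch for the callback above (hypothetical `model` and `X_val`; the model
# must contain layers named "latent_distribution" and "cluster_assignment"):
#
#     purity = neighbor_cluster_purity(encoding_dim=8, validation_data=X_val, log_dir="./logs")
#     model.fit(X_train, y_train, epochs=10, callbacks=[purity])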


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
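

# Usage sketch for the constraint above (a hypothetical wiring: applied as a
# kernel constraint on a Dense bottleneck whose width matches encoding_dim):
#
#     bottleneck = tf.keras.layers.Dense(
#         8, kernel_constraint=uncorrelated_features_constraint(8, weightage=1.0)
#     )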


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""

        return super().call(inputs, training=True)
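

# Usage sketch for MCDropout (hypothetical `model` containing MCDropout layers and
# input batch `x`; repeated forward passes differ, yielding an uncertainty estimate):
#
#     preds = np.stack([model.predict(x) for _ in range(100)])
#     mean, uncertainty = preds.mean(axis=0), preds.std(axis=0)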


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
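

# Usage sketch for DenseTranspose (hypothetical tied-weights autoencoder; the
# decoder layer reuses the encoder kernel transposed):
#
#     enc = tf.keras.layers.Dense(8, activation="relu")
#     dec = DenseTranspose(enc, output_dim=32, activation="sigmoid")
#     inputs = tf.keras.Input(shape=(32,))
#     autoencoder = tf.keras.Model(inputs, dec(enc(inputs)))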


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
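

# Usage sketch for KLDivergenceLayer (hypothetical snippet; it wraps the output of
# a tfp distribution layer, with a prior matching the posterior's event shape):
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
#     posterior = tfpl.IndependentNormal(8)(encoder_output)
#     posterior = KLDivergenceLayer(prior, weight=1.0)(posterior)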


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
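

# Usage sketch for MMDiscrepancyLayer (hypothetical snippet; `z` is a latent batch
# and the prior a standard Gaussian of the same dimensionality):
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
#     z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)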


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
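

# Usage sketch for Cluster_overlap (hypothetical shapes; `target` stacks means and
# pre-softplus scales along its second axis, one slice per mixture component):
#
#     target = tf.random.normal([256, 2 * 8, 15])  # (batch, 2 * lat_dims, components)
#     target = Cluster_overlap(lat_dims=8, n_components=15)(target)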


class Dead_neuron_control(Layer):
    """
    Identity layer that monitors dead neurons in the latent space,
    adding the stats to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target