# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof TensorFlow models. See documentation for details.

"""
from itertools import combinations
from typing import Any, Tuple
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
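
# Example usage (a sketch; `model` is assumed to be a compiled tf.keras model
# and `X_train`, `y_train` its training data):
#
#     rates, losses = find_learning_rate(model, X_train, y_train, batch_size=64)
#     plot_lr_vs_loss(rates, losses)
#     plt.show()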


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
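
# Example (a sketch): two samples from the same distribution should yield a
# much smaller MMD than samples from clearly separated distributions:
#
#     a = tf.random.normal([100, 8])
#     b = tf.random.normal([100, 8])
#     c = tf.random.normal([100, 8], mean=3.0)
#     compute_mmd((a, b))  # small: same underlying distribution
#     compute_mmd((a, c))  # substantially larger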


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        """Linearly interpolates the learning rate between two iteration milestones"""

        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
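
# Example usage (a sketch; `model`, `X`, and `y` are placeholders for a
# compiled tf.keras model and its training data):
#
#     onecycle = one_cycle_scheduler(
#         iterations=(len(X) // 32) * 10, max_rate=1e-3, log_dir="./logs"
#     )
#     model.fit(X, y, batch_size=32, epochs=10, callbacks=[onecycle])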


class neighbor_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster entropy callback. Computes the local Shannon entropy of cluster
    assignments over a neighborhood of the k nearest samples in the latent space

    """

    def __init__(
        self,
        encoding_dim: int,
        variational: bool = True,
        validation_data: np.ndarray = None,
        k: int = 100,
        samples: int = 10000,
        log_dir: str = ".",
    ):
        super().__init__()
        self.enc = encoding_dim
        self.variational = variational
        self.validation_data = validation_data
        self.k = k
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from the full model
            latent_distribution = [
                layer
                for layer in self.model.layers
                if layer.name == "latent_distribution"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, latent_distribution.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)
            hard_groups = groups.argmax(axis=1)
            max_groups = groups.max(axis=1)

            # Fit a nearest-neighbour index on the latent space
            knn = NearestNeighbors().fit(encoding)

            # Iterate over samples and compute purity across neighbourhood
            self.samples = np.min([self.samples, encoding.shape[0]])
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)

            for i, sample in enumerate(random_idxs):

                neighborhood = knn.kneighbors(
                    encoding[sample][np.newaxis, :], self.k, return_distance=False
                ).flatten()

                z = hard_groups[neighborhood]

                # Compute Shannon entropy across samples
                neigh_entropy = entropy(np.bincount(z))

                # Add result to pre-allocated array
                purity_vector[i] = neigh_entropy

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "average_neighborhood_cluster_entropy",
                    data=np.average(purity_vector, weights=max_groups[random_idxs]),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_confidence_in_selected_cluster",
                    data=np.average(max_groups),
                    step=epoch,
                )
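
# Sketch of the entropy measure used above: a pure neighbourhood (all samples
# assigned to the same cluster) scores 0, a maximally mixed one peaks at ln(k):
#
#     entropy(np.bincount(np.array([2, 2, 2, 2])))  # 0.0
#     entropy(np.bincount(np.array([0, 1, 2, 3])))  # ln(4) ~= 1.39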


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
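
# Example usage (a sketch): the constraint can be applied to the activations of
# a bottleneck layer, e.g. as an activity regularizer, penalising off-diagonal
# covariance between the encoded features:
#
#     bottleneck = tf.keras.layers.Dense(
#         16, activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0)
#     )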


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed layer"""

        return super().call(inputs, training=True)
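
# Example (a sketch): averaging repeated stochastic forward passes through a
# model that contains MCDropout layers yields Monte Carlo estimates of the
# predictive mean and uncertainty; `model` and `x` are placeholders:
#
#     preds = np.stack([model(x) for _ in range(100)])
#     mc_mean, mc_std = preds.mean(axis=0), preds.std(axis=0)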


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
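
# Example usage (a sketch): tying decoder weights to an encoder layer in an
# autoencoder where `encoder_layer` maps 32 inputs to 8 latent units:
#
#     encoder_layer = tf.keras.layers.Dense(8, activation="relu")
#     decoder_layer = DenseTranspose(encoder_layer, output_dim=32, activation="sigmoid")
#
# The decoder reuses the transpose of the encoder's kernel and trains only its
# own bias vector.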


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, iters, warm_up_iters, *args, **kwargs):
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self._iters = iters
        self._warm_up_iters = warm_up_iters

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        # Define and update KL weight for warmup
        if self._warm_up_iters > 0:
            kl_weight = tf.cast(
                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
            )
        else:
            kl_weight = tf.cast(1.0, tf.float32)

        # noinspection PyProtectedMember
        kl_batch = kl_weight * self._regularizer(distribution_a)

        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        self.add_metric(kl_weight, aggregation="mean", name="kl_rate")

        return distribution_a
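
# Warm-up sketch: with warm_up_iters=1000, the KL weight ramps linearly from 0
# at iteration 0 through 0.5 at iteration 500 to 1.0 at iteration 1000, and
# saturates there; annealing the KL term this way is a common guard against
# posterior collapse early in training.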


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, iters, warm_up_iters, *args, **kwargs):
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self.batch_size = batch_size
        self.prior = prior
        self._iters = iters
        self._warm_up_iters = warm_up_iters

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)

        # Define and update MMD weight for warmup
        if self._warm_up_iters > 0:
            mmd_weight = tf.cast(
                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
            )
        else:
            mmd_weight = tf.cast(1.0, tf.float32)

        mmd_batch = mmd_weight * compute_mmd((true_samples, z))

        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")

        return z
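
# Example usage (a sketch; all parameters are placeholders):
#
#     prior = tfd.MultivariateNormalDiag(loc=tf.zeros(8), scale_diag=tf.ones(8))
#     z = MMDiscrepancyLayer(batch_size=64, prior=prior, iters=1, warm_up_iters=1000)(z)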


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
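
# Shape sketch (inferred from the indexing above): `target` stacks the means
# and pre-softplus scales of every mixture component, i.e. it has shape
# (batch, 2 * lat_dims, n_components):
#
#     target = tf.random.normal([64, 2 * 8, 15])
#     out = Cluster_overlap(lat_dims=8, n_components=15)(target)  # identity on target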


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target
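
# Example (a sketch): attached to a latent encoding inside a functional model,
# the layer passes its input through unchanged while logging the fraction of
# exactly-zero activations as the `dead_neurons` metric:
#
#     encodings = Dead_neuron_control()(encodings)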