# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details.

"""

from itertools import combinations
from typing import Any, Tuple
from scipy.stats import entropy
from sklearn.metrics import pairwise_distances
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
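
# Usage sketch for the two helpers above (illustrative only; `my_model`,
# `X_train` and `y_train` are hypothetical stand-ins for a compiled tf.keras
# model and its training data):
#
#     rates, losses = find_learning_rate(my_model, X_train, y_train, batch_size=64)
#     plot_lr_vs_loss(rates, losses)
#     plt.show()
#
# A common heuristic is to pick a learning rate about an order of magnitude
# below the point where the loss starts to diverge.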


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes a Gaussian kernel between each pair of instances in the two specified tensors.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns:
            - kernel (tf.Tensor): the result of applying the Gaussian kernel to
            each pair of instances in x and y

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the Maximum Mean Discrepancy (MMD) between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
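
# Behaviour sketch (illustrative): the MMD of a sample against itself is close
# to zero, while samples from clearly different distributions score higher.
#
#     a = tf.random.normal([100, 8])
#     b = tf.random.normal([100, 8], mean=3.0)
#     compute_mmd((a, a))  # close to zero
#     compute_mmd((a, b))  # substantially larger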


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(
        self, iter1: int, iter2: int, rate1: float, rate2: float
    ) -> float:
        """Linearly interpolates the learning rate between two iteration anchors"""

        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """
168
169
170

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
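
# Usage sketch (illustrative; `my_model`, `X_train` and `y_train` are
# hypothetical). The cycle length should match the total number of training
# iterations, i.e. batches per epoch times number of epochs:
#
#     onecycle = one_cycle_scheduler(
#         iterations=(len(X_train) // 64) * 10, max_rate=0.01, log_dir="./logs"
#     )
#     my_model.fit(X_train, y_train, batch_size=64, epochs=10, callbacks=[onecycle])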


class neighbor_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster entropy callback. Computes assignment local entropy over a neighborhood of radius r in the latent space

    """

    def __init__(
        self,
        encoding_dim: int,
        variational: bool = True,
        validation_data: np.ndarray = None,
        samples: int = 10000,
        log_dir: str = ".",
        min_n: int = 2,
    ):
        super().__init__()
        self.enc = encoding_dim
        self.variational = variational
        self.validation_data = validation_data
        self.samples = samples
        self.log_dir = log_dir
        self.min_n = min_n
        if self.validation_data is not None:
            self.r = (
                -0.14220132706202965 * np.log2(validation_data.shape[0])
                + 0.17189696892334544 * self.enc
                + 1.6940295848037952
            )  # Empirically derived from data. See examples/set_default_entropy_radius.ipynb for details

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from full model
            latent_distribution = [
                layer
                for layer in self.model.layers
                if layer.name == "latent_distribution"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, latent_distribution.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)
            hard_groups = groups.argmax(axis=1)
            max_groups = groups.max(axis=1)

            # Compute pairwise distances on the latent space
            pdist = pairwise_distances(encoding)

            # Iterate over samples and compute purity across neighbourhood
            self.samples = np.min([self.samples, encoding.shape[0]])
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)
            neighbor_number = np.zeros(self.samples)

            for i, sample in enumerate(random_idxs):

                neighborhood = pdist[sample] < self.r
                z = hard_groups[neighborhood]

                # Compute Shannon entropy across samples
                neigh_entropy = entropy(np.bincount(z))

                purity_vector[i] = neigh_entropy
                neighbor_number[i] = np.sum(neighborhood)

            # Compute a mask to keep only examples with a minimum of self.min_n neighbors
            mask = neighbor_number >= self.min_n

            # Filter all relevant vectors using the mask
            purity_vector = purity_vector[mask]
            neighbor_number = neighbor_number[mask]
            max_groups = max_groups[random_idxs][mask]

            # Compute weights by multiplying neighbor number and target confidence
            purity_weights = neighbor_number * max_groups

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "neighborhood_cluster_purity",
                    data=np.average(purity_vector, weights=purity_weights),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_neighbors_in_radius",
                    data=np.average(neighbor_number),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_confidence_in_selected_cluster",
                    data=np.average(max_groups),
                    step=epoch,
                )
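
# Usage sketch (illustrative; `gmvae`, `X_train` and `X_val` are hypothetical).
# The callback assumes the monitored model exposes layers named
# "latent_distribution" and "cluster_assignment", as the deepof variational
# models do:
#
#     purity = neighbor_cluster_purity(
#         encoding_dim=6, validation_data=X_val, log_dir="./logs"
#     )
#     gmvae.fit(X_train, X_train, epochs=10, callbacks=[purity])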


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
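
# Usage sketch (illustrative). Although this subclasses Constraint, its
# __call__ maps activations to a scalar penalty, so it can be plugged in
# wherever Keras accepts a regularizer-like callable, e.g. on a bottleneck:
#
#     bottleneck = tf.keras.layers.Dense(
#         16, activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0)
#     )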


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the parent class"""
        return super().call(inputs, training=True)
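
# Usage sketch (illustrative; `my_model` is a hypothetical model built with
# MCDropout layers): since dropout stays active at prediction time, repeated
# forward passes form a Monte Carlo ensemble whose spread estimates uncertainty.
#
#     mc_preds = np.stack([my_model.predict(X_val) for _ in range(100)])
#     mean, uncertainty = mc_preds.mean(axis=0), mc_preds.std(axis=0)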


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
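
# Usage sketch (illustrative): tie a decoder layer to its encoder counterpart
# so that both share a single weight matrix. `output_dim` must match the input
# dimensionality of the tied Dense layer, which needs to be built (called on
# data) before this layer is:
#
#     encoder_dense = tf.keras.layers.Dense(32, activation="relu")  # maps 64 -> 32
#     decoder_dense = DenseTranspose(encoder_dense, output_dim=64, activation="relu")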


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, iters, warm_up_iters, *args, **kwargs):
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self._iters = iters
        self._warm_up_iters = warm_up_iters

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        # Define and update KL weight for warmup
        if self._warm_up_iters > 0:
            kl_weight = tf.cast(
                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
            )
        else:
            kl_weight = tf.cast(1.0, tf.float32)

        kl_batch = kl_weight * self._regularizer(distribution_a)

        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(kl_weight, aggregation="mean", name="kl_rate")

        return distribution_a
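
# Usage sketch (illustrative; `prior`, `current_iteration` and
# `encoding_distribution` are hypothetical). tfpl.KLDivergenceAddLoss takes the
# reference distribution as `distribution_b`, and the KL weight is annealed
# linearly over `warm_up_iters` batches:
#
#     z = KLDivergenceLayer(
#         iters=current_iteration, warm_up_iters=1000, distribution_b=prior
#     )(encoding_distribution)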


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, iters, warm_up_iters, *args, **kwargs):
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self.batch_size = batch_size
        self.prior = prior
        self._iters = iters
        self._warm_up_iters = warm_up_iters

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"_iters": self._iters})
        config.update({"_warmup_iters": self._warm_up_iters})
498
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)

        # Define and update MMD weight for warmup
        if self._warm_up_iters > 0:
            mmd_weight = tf.cast(
                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
            )
        else:
            mmd_weight = tf.cast(1.0, tf.float32)

        mmd_batch = mmd_weight * compute_mmd((true_samples, z))

        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")

        return z
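
# Usage sketch (illustrative; `z` and `current_iteration` are hypothetical):
# penalizes the discrepancy between the encoded batch and samples drawn from a
# fixed prior, with the same linear warm-up as the KL layer above.
#
#     prior = tfd.MultivariateNormalDiag(loc=tf.zeros(16), scale_diag=tf.ones(16))
#     z = MMDiscrepancyLayer(
#         batch_size=64, prior=prior, iters=current_iteration, warm_up_iters=1000
#     )(z)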


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
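
# Usage sketch (illustrative; `gmm_params` is hypothetical): the layer expects
# `target` to stack means and pre-softplus scales along the penultimate axis,
# one slice per mixture component, i.e. shape (..., 2 * lat_dims, n_components).
#
#     gmm_params = Cluster_overlap(lat_dims=6, n_components=10)(gmm_params)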


class Dead_neuron_control(Layer):
    """
    Identity layer that adds the fraction of dead neurons in the latent space
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target
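
# Usage sketch (illustrative; `encoding` is hypothetical): applied as an
# identity pass-through on the latent encoding, the layer logs the fraction of
# zero activations under the "dead_neurons" metric.
#
#     encoding = Dead_neuron_control()(encoding)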