# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""
10

11
from itertools import combinations
from typing import Any, Tuple
from scipy.stats import entropy
from sklearn.metrics import pairwise_distances
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes


class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes a Gaussian RBF kernel between the two specified tensors.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns:
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the Maximum Mean Discrepancy (MMD) between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd


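# A quick illustrative call of `compute_mmd` on two hypothetical batches of
# embeddings (random tensors stand in for real encodings):
#
#     x = tf.random.normal([64, 8])
#     y = tf.random.normal([64, 8])
#     discrepancy = compute_mmd((x, y))  # scalar tf.Tensor

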
# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )


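# Illustrative sketch of attaching the scheduler; `model`, `X_train` and
# `y_train` are hypothetical, and `iterations` should equal batches per epoch
# times the number of epochs:
#
#     steps = (len(X_train) // 32) * 10
#     one_cycle = one_cycle_scheduler(iterations=steps, max_rate=1e-3)
#     model.fit(X_train, y_train, epochs=10, batch_size=32, callbacks=[one_cycle])

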
class neighbor_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster entropy callback. Computes assignment local entropy over a neighborhood of radius r in the latent space

    """

    def __init__(
        self,
        encoding_dim,
        variational=True,
        validation_data=None,
        samples=10000,
        log_dir=".",
    ):
        super().__init__()
        self.enc = encoding_dim
        self.r = (
            -0.14220132706202965 * np.log2(validation_data.shape[0])
            + 0.17189696892334544 * self.enc
            + 1.6940295848037952
        )  # Empirically derived from data. See examples/set_default_entropy_radius.ipynb for details
        self.variational = variational
        self.validation_data = validation_data
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None and self.variational:

            # Get encoder and grouper from full model
            latent_distribution = [
                layer
                for layer in self.model.layers
                if layer.name == "latent_distribution"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, latent_distribution.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)
            hard_groups = groups.argmax(axis=1)
            max_groups = groups.max(axis=1)

            # Compute pairwise distances on latent space
            pdist = pairwise_distances(encoding)

            # Iterate over samples and compute purity across neighborhood
            self.samples = np.min([self.samples, encoding.shape[0]])
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)
            neighbor_number = np.zeros(self.samples)

            for i, sample in enumerate(random_idxs):

                neighborhood = pdist[sample] < self.r
                z = hard_groups[neighborhood]

                # Compute Shannon entropy across samples
                neigh_entropy = entropy(z)

                purity_vector[i] = neigh_entropy
                neighbor_number[i] = np.sum(neighborhood)

            # Compute weights multiplying neighbor number and target confidence
            purity_weights = neighbor_number * max_groups

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "neighborhood_cluster_purity",
                    data=np.average(purity_vector, weights=purity_weights),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_neighbors_in_radius",
                    data=np.average(neighbor_number),
                    step=epoch,
                )
                tf.summary.scalar(
                    "average_confidence_in_selected_cluster",
                    data=np.average(max_groups),
                    step=epoch,
                )


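# Hedged sketch of wiring the purity callback into training; `gmvae`, `X_train`
# and `X_val` are hypothetical stand-ins for a deepof model and its data:
#
#     purity_cb = neighbor_cluster_purity(
#         encoding_dim=6, validation_data=X_val, log_dir="./logs"
#     )
#     gmvae.fit(X_train, X_train, epochs=20, callbacks=[purity_cb])

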
class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)


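# Although it subclasses Constraint, `__call__` returns a scalar penalty, so it
# can be plugged in wherever Keras accepts an activity regularizer. A minimal
# sketch (the 16-unit bottleneck layer is hypothetical):
#
#     bottleneck = tf.keras.layers.Dense(
#         16,
#         activation="elu",
#         activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0),
#     )

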
# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the parent class"""

        return super().call(inputs, training=True)


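# Sketch of Monte Carlo uncertainty estimation with MCDropout: because dropout
# stays active at inference time, repeated predictions differ, and their spread
# estimates model uncertainty (`model` and `X` are hypothetical):
#
#     preds = np.stack([model.predict(X) for _ in range(100)])
#     mean, std = preds.mean(axis=0), preds.std(axis=0)

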
class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim


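# A minimal tied-weights autoencoder sketch using DenseTranspose; the layer
# sizes are illustrative:
#
#     inputs = tf.keras.Input(shape=(64,))
#     enc_layer = tf.keras.layers.Dense(16, activation="elu")
#     encoded = enc_layer(inputs)
#     decoded = DenseTranspose(enc_layer, output_dim=64, activation="elu")(encoded)
#     autoencoder = tf.keras.Model(inputs, decoded)

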
class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a


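# Hedged sketch of inserting the layer into a variational bottleneck; the prior,
# the 6-dimensional encoding and the suitably sized `encoder_output` tensor are
# illustrative assumptions:
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(6), scale=1.0), 1)
#     z = tfpl.IndependentNormal(6)(encoder_output)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)

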
class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z


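# Illustrative sketch of adding the MMD penalty to a latent tensor; `latent` is
# hypothetical and the prior matches its dimensionality:
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(6), scale=1.0), 1)
#     latent = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(latent)

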
class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target


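# Sketch of monitoring component overlap; `mixture_params` is a hypothetical
# tensor of shape (batch, 2 * lat_dims, n_components) stacking means and raw
# scales for each mixture component:
#
#     mixture_params = Cluster_overlap(lat_dims=6, n_components=10)(mixture_params)

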
class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target
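

# Illustrative one-liner: route any latent tensor through the layer to log the
# fraction of inactive units, e.g. `latent = Dead_neuron_control()(latent)`.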