# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Any, Tuple
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
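

# A minimal usage sketch (assumed model and data, for illustration only): run
# the exponential range test, restore the model, and inspect the loss-vs-rate
# curve to pick a learning rate.
def _lr_finder_usage_sketch(model, X, y):  # pragma: no cover
    rates, losses = find_learning_rate(model, X, y, epochs=1, batch_size=32)
    plot_lr_vs_loss(rates, losses)
    plt.show()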


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


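# A minimal shape sketch (illustrative only): for x of shape (n, d) and y of
# shape (m, d), compute_kernel returns an (n, m) matrix of similarities.
def _compute_kernel_shape_sketch():  # pragma: no cover
    x = tf.random.normal([5, 3])
    y = tf.random.normal([7, 3])
    return compute_kernel(x, y)  # shape (5, 7), entries in (0, 1]

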
@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the Maximum Mean Discrepancy (MMD) between the two samples
    contained in the specified tuple, using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns
            - mmd (tf.Tensor): returns the maximum mean discrepancy between
            the two input samples

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
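

# A quick sanity sketch (illustrative only): the MMD between samples from the
# same distribution should be close to zero, and clearly larger for samples
# from well-separated distributions.
def _compute_mmd_sanity_sketch():  # pragma: no cover
    same = compute_mmd((tf.random.normal([100, 2]), tf.random.normal([100, 2])))
    far = compute_mmd((tf.random.normal([100, 2]), tf.random.normal([100, 2]) + 5.0))
    return same, far  # expect same << far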


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        """Linearly interpolates the learning rate between two given iterations"""

        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
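

# A minimal usage sketch (assumed setup, not from the original codebase): size
# the schedule to the total number of training iterations and pass it to fit().
def _one_cycle_usage_sketch(model, X, y):  # pragma: no cover
    epochs, batch_size = 10, 32
    onecycle = one_cycle_scheduler(
        iterations=len(X) // batch_size * epochs, max_rate=1e-3
    )
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[onecycle])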


class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space

    """

    def __init__(self, validation_data=None, k=100, samples=10000, log_dir="."):
        super().__init__()
        self.validation_data = validation_data
        self.k = k
        self.samples = samples
        self.log_dir = log_dir

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch, logs=None):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        if self.validation_data is not None:

            # Get encoder and grouper from full model
            cluster_means = [
                layer for layer in self.model.layers if layer.name == "cluster_means"
            ][0]
            cluster_assignment = [
                layer
                for layer in self.model.layers
                if layer.name == "cluster_assignment"
            ][0]

            encoder = tf.keras.models.Model(
                self.model.layers[0].input, cluster_means.output
            )
            grouper = tf.keras.models.Model(
                self.model.layers[0].input, cluster_assignment.output
            )

            # Use encoder and grouper to predict on validation data
            encoding = encoder.predict(self.validation_data)
            groups = grouper.predict(self.validation_data)

            # Multiply encodings by groups, to get a weighted version of the matrix
            encoding = (
                encoding
                * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
            )
            hard_groups = groups.argmax(axis=1)

            # Fit KNN model
            knn = NearestNeighbors().fit(encoding)

            # Iterate over samples and compute purity over k neighbours
            random_idxs = np.random.choice(
                range(encoding.shape[0]), self.samples, replace=False
            )
            purity_vector = np.zeros(self.samples)
            for i, sample in enumerate(random_idxs):
                indexes = knn.kneighbors(
                    encoding[sample][np.newaxis, :], self.k, return_distance=False
                )
                purity_vector[i] = (
                    np.sum(hard_groups[indexes] == hard_groups[sample])
                    / self.k
                    * np.max(groups[sample])
                )

            writer = tf.summary.create_file_writer(self.log_dir)
            with writer.as_default():
                tf.summary.scalar(
                    "knn_cluster_purity",
                    data=purity_vector.mean(),
                    step=epoch,
                )
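

# A minimal usage sketch (assumptions: the trained model exposes layers named
# "cluster_means" and "cluster_assignment", which this callback looks up):
def _purity_callback_usage_sketch(model, X_train, X_val):  # pragma: no cover
    purity = knn_cluster_purity(validation_data=X_val, k=100, samples=1000)
    model.fit(X_train, X_train, epochs=10, callbacks=[purity])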


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among other things, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
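

# A minimal usage sketch (assumed wiring, not from the original codebase): even
# though this class subclasses Constraint, its __call__ returns a scalar
# penalty, so it can be attached as an activity regularizer on a bottleneck.
def _uncorrelated_bottleneck_sketch():  # pragma: no cover
    encoding_dim = 8
    return tf.keras.layers.Dense(
        encoding_dim,
        activation="elu",
        activity_regularizer=uncorrelated_features_constraint(encoding_dim),
    )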


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""

        return super().call(inputs, training=True)
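

# A minimal sketch (illustrative): repeated stochastic forward passes through a
# model built with MCDropout give a Monte Carlo estimate of the predictive mean
# and its uncertainty. `model` and `X` are assumed to be defined by the caller.
def _mc_dropout_sketch(model, X, n_samples=100):  # pragma: no cover
    preds = np.stack([model.predict(X) for _ in range(n_samples)])
    return preds.mean(axis=0), preds.std(axis=0)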


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
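

# A minimal tied-weights autoencoder sketch (assumed architecture, for
# illustration only): the decoder reuses the encoder's kernel, transposed.
def _tied_autoencoder_sketch():  # pragma: no cover
    inputs = tf.keras.Input(shape=(32,))
    dense = tf.keras.layers.Dense(8, activation="elu")
    encoded = dense(inputs)
    decoded = DenseTranspose(dense, output_dim=32, activation="sigmoid")(encoded)
    return tf.keras.Model(inputs, decoded)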


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
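

# A minimal usage sketch (assumed wiring, not from the original codebase):
# place the layer after a tfp probabilistic output so the KL term against the
# prior is added to the loss and logged as a metric.
def _kl_layer_usage_sketch():  # pragma: no cover
    latent_dim = 4
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0),
        reinterpreted_batch_ndims=1,
    )
    inputs = tf.keras.Input(shape=(16,))
    params = tf.keras.layers.Dense(
        tfpl.IndependentNormal.params_size(latent_dim)
    )(inputs)
    posterior = tfpl.IndependentNormal(latent_dim)(params)
    posterior = KLDivergenceLayer(prior, weight=1.0)(posterior)
    return tf.keras.Model(inputs, posterior)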


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
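

# A minimal usage sketch (assumed prior and sizes, for illustration only):
def _mmd_layer_usage_sketch():  # pragma: no cover
    batch_size, latent_dim = 64, 4
    prior = tfd.MultivariateNormalDiag(
        loc=tf.zeros(latent_dim), scale_diag=tf.ones(latent_dim)
    )
    z = tf.keras.Input(shape=(latent_dim,))
    z_out = MMDiscrepancyLayer(batch_size=batch_size, prior=prior, beta=1.0)(z)
    return tf.keras.Model(z, z_out)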


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
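

# A minimal wiring sketch (assumed tensor layout, for illustration only): the
# layer expects means stacked on top of pre-softplus scales, one slice per
# mixture component.
def _cluster_overlap_sketch():  # pragma: no cover
    lat_dims, n_components = 4, 3
    target = tf.keras.Input(shape=(2 * lat_dims, n_components))
    out = Cluster_overlap(lat_dims, n_components, loss=False, samples=10)(target)
    return tf.keras.Model(target, out)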


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target
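

# Illustrative one-liner (assumed model graph context): wrap any latent tensor
# to track its fraction of zero activations, e.g. `z = Dead_neuron_control()(z)`.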


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
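

# A minimal wiring sketch (assumed input: a tensor of softmax cluster
# assignment weights; names and sizes are illustrative only):
def _entropy_regulariser_sketch():  # pragma: no cover
    soft_assignments = tf.keras.Input(shape=(10,))
    out = Entropy_regulariser(weight=0.1, axis=1)(soft_assignments)
    return tf.keras.Model(soft_assignments, out)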