# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from functools import partial
from itertools import combinations
from typing import Any, Tuple

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes
@tf.function
def compute_shannon_entropy(tensor):
    """Computes Shannon entropy for a given tensor"""
    tensor = tf.cast(tensor, tf.dtypes.int32)
    bins = (
        tf.math.bincount(tensor, dtype=tf.dtypes.float32)
        / tf.cast(tf.shape(tensor), tf.dtypes.float32)[0]
    )
    return -tf.reduce_sum(bins * tf.math.log(bins + 1e-5))
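

# A minimal sketch of this helper's behaviour (hypothetical values): a hard
# assignment vector spread uniformly over three clusters yields an entropy
# close to log(3).
# >>> compute_shannon_entropy(tf.constant([0, 0, 1, 1, 2, 2]))  # ~= 1.0986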


@tf.function
def get_k_nearest_neighbors(tensor, k, index):
    """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
    query = tf.gather(tensor, index, batch_dims=0)
    distances = tf.norm(tensor - query, axis=1)
    max_distance = tf.sort(distances)[k]
    neighbourhood_mask = distances < max_distance
    return tf.squeeze(tf.where(neighbourhood_mask))
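

# Usage sketch (hypothetical values): the query point itself (distance 0) is
# part of the returned neighbourhood, and ties at the k-th distance are
# excluded because the mask is strict (<).
# >>> points = tf.constant([[0.0], [0.1], [5.0], [5.1]])
# >>> get_k_nearest_neighbors(points, 2, 0)  # -> indices [0, 1]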


@tf.function
def get_neighbourhood_entropy(index, tensor, clusters, k):
    """Computes the Shannon entropy of the cluster assignments in the k-nearest neighbourhood of the given index"""
    neighborhood = get_k_nearest_neighbors(tensor, k, index)
    cluster_z = tf.gather(clusters, neighborhood, batch_dims=0)
    neigh_entropy = compute_shannon_entropy(cluster_z)
    return neigh_entropy
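

# Sketch (hypothetical embeddings and labels): low entropy means the latent
# neighbourhood of point 0 is dominated by a single cluster.
# >>> emb = tf.random.normal([64, 2])
# >>> labels = tf.random.uniform([64], maxval=4, dtype=tf.int32)
# >>> get_neighbourhood_entropy(0, emb, labels, 10)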


class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses
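

# Hedged usage sketch (hypothetical compiled model and data): scan learning
# rates over one epoch, then inspect the curve for the steepest descent region.
# >>> rates, losses = find_learning_rate(model, X_train, y_train, batch_size=64)
# >>> plot_lr_vs_loss(rates, losses)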


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes a Gaussian kernel between all pairs of instances in the two specified tensors.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns:
            - kernel (tf.Tensor): result of applying the kernel to each
            pair of instances in x and y

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel
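

# Shape sketch (hypothetical batches): the kernel pairs every row of x with
# every row of y, so batches of 8 and 16 instances yield an (8, 16) matrix.
# >>> compute_kernel(tf.random.normal([8, 4]), tf.random.normal([16, 4])).shape
# TensorShape([8, 16])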


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the maximum mean discrepancy (MMD) between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): maximum mean discrepancy between the two
            input batches

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
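

# Sketch (hypothetical samples): MMD estimates E[k(x, x')] + E[k(y, y')]
# - 2 * E[k(x, y)], so two batches drawn from the same distribution should
# yield a value close to zero.
# >>> a, b = tf.random.normal([100, 4]), tf.random.normal([100, 4])
# >>> compute_mmd((a, b))  # small scalar close to 0.0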


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
        log_dir: str = ".",
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}
        self.log_dir = log_dir

    def _interpolate(
        self, iter1: int, iter2: int, rate1: float, rate2: float
    ) -> float:
        """Linearly interpolates the learning rate between two iteration milestones"""
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Logs the learning rate to tensorboard"""

        writer = tf.summary.create_file_writer(self.log_dir)

        with writer.as_default():
            tf.summary.scalar(
                "learning_rate",
                data=self.model.optimizer.lr,
                step=epoch,
            )
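

# Hedged usage sketch (hypothetical training setup): the iteration budget is
# the number of batches per epoch times the number of epochs.
# >>> onecycle = one_cycle_scheduler(
# ...     iterations=(len(X_train) // 64) * 10, max_rate=1e-3, log_dir="./logs"
# ... )
# >>> model.fit(X_train, y_train, epochs=10, batch_size=64, callbacks=[onecycle])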


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
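

# Hedged usage sketch (hypothetical bottleneck): attaching the constraint as an
# activity regularizer on an autoencoder bottleneck is one possible wiring,
# not the only one; encoding_dim must match the layer's number of units.
# >>> bottleneck = tf.keras.layers.Dense(
# ...     16, activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0)
# ... )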


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""

        return super().call(inputs, training=True)
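

# Sketch (hypothetical model built with MCDropout layers): dropout stays active
# at prediction time, so repeated forward passes sample different subnetworks;
# their mean and spread give a Monte Carlo uncertainty estimate.
# >>> mc_preds = np.stack([model.predict(X_test) for _ in range(100)])
# >>> mc_mean, mc_std = mc_preds.mean(axis=0), mc_preds.std(axis=0)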


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
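

# Hedged usage sketch (hypothetical tied-weight autoencoder): the decoder layer
# reuses the encoder kernel transposed, roughly halving the dense parameters.
# >>> encoder_dense = tf.keras.layers.Dense(8, activation="relu")  # e.g. 32 -> 8
# >>> decoder_dense = DenseTranspose(encoder_dense, output_dim=32, activation="sigmoid")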


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, iters, warm_up_iters, annealing_mode, *args, **kwargs):
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self._iters = iters
        self._warm_up_iters = warm_up_iters
        self._annealing_mode = annealing_mode

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        config.update({"_annealing_mode": self._annealing_mode})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        # Define and update KL weight for warmup
        if self._warm_up_iters > 0:
            if self._annealing_mode in ["linear", "sigmoid"]:
                kl_weight = tf.cast(
                    K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
                )
                if self._annealing_mode == "sigmoid":
                    kl_weight = tf.math.sigmoid(
                        (2 * kl_weight - 1) / (kl_weight - kl_weight ** 2)
                    )
            else:
                raise NotImplementedError(
                    "annealing_mode must be one of 'linear' and 'sigmoid'"
                )
        else:
            kl_weight = tf.cast(1.0, tf.float32)

        kl_batch = kl_weight * self._regularizer(distribution_a)

        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(kl_weight, aggregation="mean", name="kl_rate")

        return distribution_a
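

# Sketch (hypothetical posterior and prior; names are placeholders): extra
# keyword arguments are forwarded to tfpl.KLDivergenceAddLoss, so the prior
# can be passed as distribution_b; `step` tracks the current training iteration.
# >>> z = KLDivergenceLayer(
# ...     iters=step, warm_up_iters=1000, annealing_mode="linear", distribution_b=prior
# ... )(posterior)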


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(
        self, batch_size, prior, iters, warm_up_iters, annealing_mode, *args, **kwargs
    ):
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
        self.is_placeholder = True
        self.batch_size = batch_size
        self.prior = prior
        self._iters = iters
        self._warm_up_iters = warm_up_iters
        self._annealing_mode = annealing_mode

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"_iters": self._iters})
        config.update({"_warm_up_iters": self._warm_up_iters})
        config.update({"prior": self.prior})
        config.update({"_annealing_mode": self._annealing_mode})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)

        # Define and update MMD weight for warmup
        if self._warm_up_iters > 0:
            if self._annealing_mode in ["linear", "sigmoid"]:
                mmd_weight = tf.cast(
                    K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
                )
                if self._annealing_mode == "sigmoid":
                    mmd_weight = tf.math.sigmoid(
                        (2 * mmd_weight - 1) / (mmd_weight - mmd_weight ** 2)
                    )
            else:
                raise NotImplementedError(
                    "annealing_mode must be one of 'linear' and 'sigmoid'"
                )
        else:
            mmd_weight = tf.cast(1.0, tf.float32)

        mmd_batch = mmd_weight * compute_mmd((true_samples, z))

        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")

        return z
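

# Sketch (hypothetical latent sample and prior; names are placeholders): pulls
# the aggregate posterior towards the prior through MMD, annealed like the KL.
# >>> z = MMDiscrepancyLayer(
# ...     batch_size=256, prior=prior, iters=step, warm_up_iters=1000,
# ...     annealing_mode="sigmoid",
# ... )(z)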


class ClusterOverlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the entropy of the nearest neighbourhood. If self.loss_weight > 0, it adds a regularization
    penalty to the loss function
    """

    def __init__(
        self,
        batch_size: int,
        encoding_dim: int,
        k: int = 25,
        loss_weight: float = 0.0,
        *args,
        **kwargs
    ):
        self.batch_size = batch_size
        self.enc = encoding_dim
        self.k = k
        self.loss_weight = loss_weight
        self.min_confidence = 0.25
        super(ClusterOverlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"enc": self.enc})
        config.update({"k": self.k})
        config.update({"loss_weight": self.loss_weight})
        config.update({"min_confidence": self.min_confidence})
        return config

    def call(self, inputs, training=None, **kwargs):
        """Updates Layer's call method"""

        encodings, categorical = inputs[0], inputs[1]

        if training:

            hard_groups = tf.math.argmax(categorical, axis=1)
            max_groups = tf.reduce_max(categorical, axis=1)

            get_local_neighbourhood_entropy = partial(
                get_neighbourhood_entropy,
                tensor=encodings,
                clusters=hard_groups,
                k=self.k,
            )

            purity_vector = tf.map_fn(
                get_local_neighbourhood_entropy,
                tf.constant(list(range(self.batch_size))),
                dtype=tf.dtypes.float32,
            )

            ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
            neighbourhood_entropy = purity_vector * max_groups

            number_of_clusters = tf.cast(
                tf.shape(
                    tf.unique(
                        tf.reshape(
                            tf.gather(
                                tf.cast(hard_groups, tf.dtypes.float32),
                                tf.where(max_groups >= self.min_confidence),
                                batch_dims=0,
                            ),
                            [-1],
                        ),
                    )[0],
                )[0],
                tf.dtypes.float32,
            )

            self.add_metric(
                number_of_clusters,
                name="number_of_populated_clusters",
            )

            self.add_metric(
                max_groups,
                aggregation="mean",
                name="average_confidence_in_selected_cluster",
            )

            self.add_metric(
                neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
            )

            if self.loss_weight:
                # minimize local entropy
                self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))
                # maximize number of clusters
                # self.add_loss(-self.loss_weight * tf.reduce_mean(number_of_clusters))

        return encodings
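

# Hedged usage sketch (hypothetical encoder outputs): at training time the
# layer logs cluster-overlap metrics for the batch and, if loss_weight > 0,
# penalises entangled neighbourhoods; batch_size must match the actual batch.
# >>> encodings = ClusterOverlap(batch_size=256, encoding_dim=16, k=25)(
# ...     [encodings, categorical_assignments]
# ... )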