# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Any, Tuple
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions and classes


class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
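

# A hedged usage sketch for the two helpers above, assuming a compiled Keras
# model `model` and training arrays `X_train` / `y_train` (illustrative names,
# not part of this module):
#
#     rates, losses = find_learning_rate(model, X_train, y_train, batch_size=32)
#     plot_lr_vs_loss(rates, losses)
#     plt.show()
#
# A sensible starting learning rate usually sits roughly an order of magnitude
# below the point where the loss curve starts to diverge.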


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any, Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
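

# Hedged sketch of a direct call: compute_mmd expects a tuple of two rank-2
# tensors that share the feature dimension (the tensors below are
# illustrative):
#
#     a = tf.random.normal([128, 16])
#     b = tf.random.normal([128, 16])
#     discrepancy = compute_mmd((a, b))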


# Custom auxiliary classes


class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
    def on_batch_end(self, batch, logs=None):
        """Adds the current learning rate to the training logs, to check whether scheduling is working properly"""

        logs = logs or {}
        logs["lr"] = K.get_value(self.model.optimizer.lr)
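

# Hedged usage sketch for the scheduler above; `model`, `X_train` and `y_train`
# are illustrative placeholders. The total iteration count is the number of
# batches per epoch times the number of epochs:
#
#     batch_size, epochs = 256, 50
#     onecycle = one_cycle_scheduler(
#         iterations=len(X_train) // batch_size * epochs, max_rate=1e-3
#     )
#     model.fit(
#         X_train, y_train, epochs=epochs, batch_size=batch_size,
#         callbacks=[onecycle],
#     )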


class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space

    """

    def __init__(self, k=5, samples=10000):
        super().__init__()
        self.k = k
        self.samples = samples

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch: int, logs):
        """ Passes samples through the encoder and computes cluster purity on the latent embedding """

        # Get encoder and grouper from full model
        cluster_means = [
            layer for layer in self.model.layers if layer.name == "cluster_means"
        ][0]
        cluster_assignment = [
            layer for layer in self.model.layers if layer.name == "cluster_assignment"
        ][0]

        encoder = tf.keras.models.Model(
            self.model.layers[0].input, cluster_means.output
        )
        grouper = tf.keras.models.Model(
            self.model.layers[0].input, cluster_assignment.output
        )

        # Use encoder and grouper to predict on validation data
        encoding = encoder.predict(self.validation_data)
        groups = grouper.predict(self.validation_data)

        # Multiply encodings by groups, to get a weighted version of the matrix
        encoding = (
            encoding
            * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
        )
        hard_groups = groups.argmax(axis=1)

        # Fit KNN model
        knn = NearestNeighbors().fit(encoding)

        # Iterate over samples and compute purity over k neighbours
        random_idxs = np.random.choice(
            range(encoding.shape[0]), self.samples, replace=False
        )
        purity_vector = np.zeros(self.samples)
        for i, sample in enumerate(random_idxs):
            indexes = knn.kneighbors(
                encoding[sample][np.newaxis, :], self.k, return_distance=False
            )
            purity_vector[i] = (
                np.sum(hard_groups[indexes] == hard_groups[sample])
                / self.k
                * np.max(groups[sample])
            )

        return purity_vector.mean()
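

# Note (assumption): recent tf.keras versions no longer populate
# `self.validation_data` on custom callbacks, so it would have to be attached
# manually before training. A hedged sketch, with `X_val` as an illustrative
# placeholder:
#
#     purity = knn_cluster_purity(k=5, samples=1000)
#     purity.validation_data = X_val
#     model.fit(X_train, y_train, callbacks=[purity])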


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.

    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
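

# Hedged usage sketch: although this subclasses Constraint, its __call__
# penalizes activations, so it is typically plugged in as an activity
# regularizer on a bottleneck layer (the 16-unit size is illustrative):
#
#     bottleneck = tf.keras.layers.Dense(
#         16,
#         activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0),
#     )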


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed layer"""

        return super().call(inputs, training=True)
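

# Hedged sketch of Monte Carlo prediction with the layer above: since dropout
# stays active at inference time, repeated forward passes yield a distribution
# over outputs (`model` and `X_test` are illustrative placeholders):
#
#     preds = np.stack([model.predict(X_test) for _ in range(100)])
#     mean, uncertainty = preds.mean(axis=0), preds.std(axis=0)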


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
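

# Hedged sketch of weight tying in an autoencoder using the layer above (all
# dimensions are illustrative):
#
#     dense_1 = tf.keras.layers.Dense(64, activation="relu")
#     dense_2 = tf.keras.layers.Dense(16, activation="relu")
#     decoder = tf.keras.Sequential(
#         [
#             DenseTranspose(dense_2, output_dim=64, activation="relu"),
#             DenseTranspose(dense_1, output_dim=784, activation="sigmoid"),
#         ]
#     )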


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
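

# Hedged usage sketch: the layer is typically placed right after a
# probabilistic latent layer, with the prior handed to the parent
# tfpl.KLDivergenceAddLoss constructor (latent dimension and `encoder_output`
# are illustrative):
#
#     prior = tfd.Independent(
#         tfd.Normal(loc=tf.zeros(8), scale=1.0), reinterpreted_batch_ndims=1
#     )
#     z = tfpl.IndependentNormal(8)(encoder_output)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)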


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
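

# Hedged usage sketch, with batch size, latent dimension and prior as
# illustrative placeholders:
#
#     prior = tfd.Independent(
#         tfd.Normal(loc=tf.zeros(8), scale=1.0), reinterpreted_batch_ndims=1
#     )
#     z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)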


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            # Note: locs must be a tensor, not a single-element tuple
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target


class Dead_neuron_control(Layer):
    """
    Identity layer that adds the proportion of dead neurons in the latent space
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of the assignment for each given instance
        # axis=0 increases the entropy of each cluster across instances
        # The numerical stability term needs to be inside the logarithm
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds metric that monitors the entropy of the cluster weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z