# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from typing import Any, Tuple

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer

tfd = tfp.distributions
tfpl = tfp.layers

# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
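

# Worked sketch of compute_mmd on two toy samples. The estimator above follows
# MMD^2(P, Q) = E[k(x, x')] + E[k(y, y')] - 2 * E[k(x, y)]; for two clearly
# distinct distributions the value should be substantially positive:
#
#     x = tf.random.normal([100, 2], mean=0.0)
#     y = tf.random.normal([100, 2], mean=3.0)
#     compute_mmd((x, y))  # much larger than compute_mmd((x, x)), which is ~0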


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_batch_end(self, batch, logs=None):
        """Reports the current learning rate, to check whether scheduling is working properly"""

        # Callbacks have no add_metric; report the rate through the logs dict,
        # where it is picked up by progress reporting and other callbacks
        if logs is not None:
            logs["learning_rate"] = K.get_value(self.model.optimizer.lr)
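
# A minimal usage sketch (model and data assumed to exist; total iterations
# should match steps_per_epoch times epochs for the schedule to complete):
#
#     onecycle = one_cycle_scheduler(
#         iterations=len(X_train) // 64 * 10, max_rate=1e-3
#     )
#     model.fit(X_train, y_train, epochs=10, batch_size=64, callbacks=[onecycle])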


class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space

    """

    def __init__(self, k=5, samples=1000):
        super().__init__()
        self.k = k
        self.samples = samples

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch: int, logs):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        # Get the encoder and grouper from the full model
        cluster_means = [
            layer for layer in self.model.layers if layer.name == "cluster_means"
        ][0]
        cluster_assignment = [
            layer for layer in self.model.layers if layer.name == "cluster_assignment"
        ][0]

        encoder = tf.keras.models.Model(
            self.model.layers[0].input, cluster_means.output
        )
        grouper = tf.keras.models.Model(
            self.model.layers[0].input, cluster_assignment.output
        )

        trial_idxs = np.random.choice(
            range(self.validation_data.shape[0]), self.samples
        )
        trial_data = self.validation_data[trial_idxs]

        # Use encoder and grouper to predict on validation data
        encoding = encoder.predict(trial_data)
        groups = grouper.predict(trial_data)
        # Compute KNN purity: the fraction of each sampled point's k nearest
        # neighbours in the latent space that share its hard cluster assignment
        # (sketch completing the truncated original body)
        hard_groups = np.argmax(groups, axis=1)
        distances = np.linalg.norm(
            encoding[:, None, :] - encoding[None, :, :], axis=-1
        )
        # Exclude each point itself before taking the k nearest neighbours
        neighbors = np.argsort(distances, axis=1)[:, 1 : self.k + 1]
        purity = np.mean(hard_groups[neighbors] == hard_groups[:, None])
        logs["knn_cluster_purity"] = purity

class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.

    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
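
# Although it subclasses Constraint, __call__ returns a scalar penalty, so in
# practice the class behaves like a regularizer. A minimal sketch attaching it
# to a bottleneck layer (layer sizes assumed):
#
#     bottleneck = tf.keras.layers.Dense(
#         16, activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0)
#     )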


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the parent Dropout layer, keeping training mode on"""

        return super().call(inputs, training=True)
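

# A minimal sketch of Monte Carlo predictions with MCDropout (`mc_model` and
# `X_test` are assumed): since dropout stays active at inference, repeated
# forward passes yield a predictive distribution rather than a point estimate.
#
#     preds = np.stack([mc_model.predict(X_test) for _ in range(100)])
#     pred_mean, pred_std = preds.mean(axis=0), preds.std(axis=0)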


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
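

# A minimal weight-tying sketch (input_dim assumed): the decoder layer reuses
# the encoder's kernel transposed, roughly halving the parameters to train.
#
#     encoder_dense = tf.keras.layers.Dense(64, activation="relu")
#     encoded = encoder_dense(inputs)
#     decoded = DenseTranspose(
#         encoder_dense, output_dim=input_dim, activation="sigmoid"
#     )(encoded)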


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
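

# A minimal wiring sketch (latent_dim and the upstream `params` tensor are
# assumed): the layer is an identity on the posterior distribution, adding the
# KL term against the prior to the model loss as a side effect.
#
#     prior = tfd.Independent(
#         tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0),
#         reinterpreted_batch_ndims=1,
#     )
#     z = tfpl.MultivariateNormalTriL(latent_dim)(params)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)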


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
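
# A minimal wiring sketch (batch_size, latent_dim and the latent tensor `z`
# are assumed): samples from the prior are compared against the posterior
# samples via MMD, which is added to the loss.
#
#     prior = tfd.Independent(
#         tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0),
#         reinterpreted_batch_ndims=1,
#     )
#     z = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(z)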


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
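
# A minimal wiring sketch (shapes assumed): the layer expects, per mixture
# component, the means stacked on top of the raw (pre-softplus) scales, i.e.
# a tensor of shape (batch, 2 * lat_dims, n_components).
#
#     gmm_params = Cluster_overlap(lat_dims, n_components, loss=False)(gmm_params)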


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds a metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
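

# A minimal usage sketch (`soft_assignments` assumed to be the softmax output
# of a clustering head, shaped (batch, n_components)):
#
#     soft_assignments = Entropy_regulariser(weight=0.01, axis=1)(soft_assignments)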