model_utils.py
# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof TensorFlow models. See documentation for details.

"""

from itertools import combinations
from typing import Any, Tuple

from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns:
            - mmd (tf.Tensor): returns the maximum mean discrepancy between
            the two input distributions

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
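
# Usage sketch (hypothetical shapes; samples from two well-separated
# Gaussians should yield a much larger discrepancy than two draws from
# the same distribution):
#     x = tf.random.normal([128, 2], mean=0.0)
#     y = tf.random.normal([128, 2], mean=3.0)
#     compute_mmd((x, y))  # clearly positive; near zero for compute_mmd((x, x))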


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

    def on_epoch_end(self, epoch, logs=None):
        """Adds the current learning rate to the logged metrics, to check whether scheduling is working properly"""

        logs = logs or {}
        logs["learning_rate"] = K.get_value(self.model.optimizer.lr)
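
# Usage sketch (hypothetical numbers; iterations should match
# steps_per_epoch * epochs so the schedule spans the whole run):
#     onecycle = one_cycle_scheduler(iterations=100 * 10, max_rate=1e-3)
#     model.fit(X_train, y_train, epochs=10, callbacks=[onecycle])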


class knn_cluster_purity(tf.keras.callbacks.Callback):
    """

    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space

    """

    def __init__(self, trial_data, k=5):
        super().__init__()
        self.trial_data = trial_data
        self.k = k

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_epoch_end(self, epoch: int, logs):
        """Passes samples through the encoder and computes cluster purity on the latent embedding"""

        # Get encoder and grouper from the full model (assumption: both are
        # exposed as named sub-models; adjust the names to the actual architecture)
        encoder = self.model.get_layer("encoder")
        grouper = self.model.get_layer("grouper")

        # Use encoder and grouper to predict on trial data
        encoding = encoder.predict(self.trial_data)
        groups = grouper.predict(self.trial_data)

        # Hypothetical completion of the stub left here in the original:
        # the fraction of each sample's k nearest latent neighbours that
        # share its hard cluster assignment
        hard_groups = tf.math.argmax(groups, axis=1)
        distances = tf.norm(encoding[:, None, :] - encoding[None, :, :], axis=-1)
        _, knn_idx = tf.math.top_k(-distances, k=self.k + 1)  # k + 1 includes self
        neighbour_groups = tf.gather(hard_groups, knn_idx[:, 1:])
        purity = tf.reduce_mean(
            tf.cast(tf.equal(neighbour_groups, hard_groups[:, None]), tf.float32)
        )
        print(" - knn_cluster_purity: {:.4f}".format(float(purity)))

class uncorrelated_features_constraint(Constraint):
    """

225
    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
lucas_miranda's avatar
lucas_miranda committed
226
227
228
229
    Useful, among others, for auto encoder bottleneck layers

    """

230
231
232
233
    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
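
# Usage sketch (hypothetical bottleneck; since the constraint is a callable
# on activations, it can also be plugged in as an activity regularizer):
#     bottleneck = tf.keras.layers.Dense(
#         16, activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0)
#     )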


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the parent class, keeping dropout active at prediction time"""

        return super().call(inputs, training=True)
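
# Usage sketch (hypothetical model; averaging several stochastic forward
# passes yields an approximate predictive mean and uncertainty):
#     preds = tf.stack([mc_model(X_test) for _ in range(100)])
#     mean, std = tf.reduce_mean(preds, 0), tf.math.reduce_std(preds, 0)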


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
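
# Usage sketch (hypothetical tied-weights autoencoder; the decoder reuses
# the encoder's kernel transposed, halving the parameters to train):
#     encoder_layer = tf.keras.layers.Dense(8, activation="relu")
#     decoder_layer = DenseTranspose(encoder_layer, output_dim=32, activation="sigmoid")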


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
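
# Usage sketch (hypothetical 16-dimensional latent space; the layer is an
# identity on the posterior that registers the KL term as loss and metric):
#     prior = tfd.Independent(tfd.Normal(tf.zeros(16), 1.0), 1)
#     z = tfpl.IndependentNormal(16)(encoder_output)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)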


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
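
# Usage sketch (hypothetical prior over a 16-dimensional latent space;
# the layer passes z through unchanged while registering the MMD loss):
#     prior = tfd.MultivariateNormalDiag(loc=tf.zeros(16))
#     z = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(z)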


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = (target[..., : self.lat_dims, k],)
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
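
# Usage sketch (hypothetical mixture-parameter tensor of shape
# [batch, 2 * lat_dims, n_components], stacking means and raw scales):
#     gmm_params = Cluster_overlap(lat_dims=16, n_components=10)(gmm_params)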


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds a metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds a metric that monitors the entropy of the cluster weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
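
# Usage sketch (hypothetical soft cluster assignment tensor summing to one
# per instance; the penalty pushes assignments towards higher entropy):
#     soft_groups = Entropy_regulariser(weight=0.1)(soft_groups)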