# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""
from itertools import combinations
from typing import Any, Tuple

from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

# Helper functions and classes


class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
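
# Usage sketch for the two helpers above (illustrative only; `model`, `X_train`
# and `y_train` are hypothetical placeholders for a compiled tf.keras model and
# its training data):
#
#   rates, losses = find_learning_rate(model, X_train, y_train, batch_size=64)
#   plot_lr_vs_loss(rates, losses)
#   plt.show()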


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes a Gaussian kernel between the two specified tensors.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
    """

    Computes the MMD between the two specified tensors using a Gaussian kernel.

        Parameters:
            - tensors (tuple): tuple containing two tf.Tensor objects

        Returns
            - mmd (tf.Tensor): returns the maximum mean discrepancy for each
            training instance

    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
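
# Usage sketch (illustrative only): the MMD of two samples drawn from the same
# distribution should be close to zero, and it grows as the distributions drift apart.
#
#   a = tf.random.normal([100, 8])
#   b = tf.random.normal([100, 8], mean=2.0)
#   print(compute_mmd((a, a)))  # exactly 0 up to float error
#   print(compute_mmd((a, b)))  # noticeably larger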


# Custom auxiliary classes


class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """Defines computations to perform for each batch"""

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
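
# Usage sketch (illustrative only; `model`, `X_train` and `y_train` are
# hypothetical placeholders):
#
#   batch_size, epochs = 32, 10
#   onecycle = one_cycle_scheduler(
#       iterations=len(X_train) // batch_size * epochs, max_rate=0.05
#   )
#   model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, callbacks=[onecycle])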


class uncorrelated_features_constraint(Constraint):
    """

193
    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
lucas_miranda's avatar
lucas_miranda committed
194
195
196
197
    Useful, among others, for auto encoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
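
# Usage sketch (illustrative only): although this subclasses Constraint, its
# __call__ returns a scalar penalty, so it plugs into Keras as an activity
# regularizer, e.g. on a 16-unit bottleneck:
#
#   bottleneck = tf.keras.layers.Dense(
#       16,
#       activation="relu",
#       activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0),
#   )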


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""

        return super().call(inputs, training=True)
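
# Usage sketch (illustrative only; `mc_model` is a hypothetical model built with
# MCDropout layers): since dropout stays active at prediction time, repeated
# forward passes sample the predictive distribution.
#
#   y_samples = tf.stack([mc_model(X_test) for _ in range(100)])
#   y_mean = tf.reduce_mean(y_samples, axis=0)
#   y_std = tf.math.reduce_std(y_samples, axis=0)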


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias",
            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
            initializer="zeros",
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
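
# Usage sketch (illustrative only; `inputs` and `input_dim` are hypothetical):
# tying a decoder layer to its encoder counterpart halves the trainable weights.
#
#   dense_1 = tf.keras.layers.Dense(16, activation="relu")
#   encoded = dense_1(inputs)
#   decoded = DenseTranspose(dense_1, output_dim=input_dim, activation="relu")(encoded)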


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """

    Identity transform layer that adds KL Divergence
    to the final model loss.

    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch,
            aggregation="mean",
            name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
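
# Usage sketch (illustrative only; `prior` is a hypothetical tfd.Distribution
# matching the encoder's posterior, `z` the output of a tfp distribution layer
# such as tfpl.IndependentNormal, and `kl_warmup_weight` a hypothetical
# annealing coefficient):
#
#   z = KLDivergenceLayer(prior, weight=kl_warmup_weight)(z)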


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        # noinspection PyTypeChecker
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
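
# Usage sketch (illustrative only; `prior` is a hypothetical tfd.Distribution
# over the latent space and `z` a batch of sampled latent codes):
#
#   z = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(z)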


class Cluster_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using the average inter-cluster MMD as a metric
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Cluster_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
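
# Usage sketch (illustrative only; `gmm_params` is a hypothetical tensor of
# shape [batch, 2 * lat_dims, n_components] stacking component means and
# unconstrained scales, as implied by the slicing in call above):
#
#   gmm_params = Cluster_overlap(lat_dims=8, n_components=5)(gmm_params)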


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
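
# Usage sketch (illustrative only; `soft_counts` is a hypothetical tensor of
# per-instance cluster assignment probabilities, e.g. a softmax output):
#
#   soft_counts = Entropy_regulariser(weight=0.1)(soft_counts)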