model_utils.py

# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions

class exponential_learning_rate(tf.keras.callbacks.Callback):
    """Simple class that allows to grow learning rate exponentially during training"""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    # noinspection PyMethodOverriding
    def on_batch_end(self, batch, logs):
        """This callback acts after processing each batch"""

        self.rates.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)


def find_learning_rate(
    model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
    """Trains the provided model for an epoch with an exponentially increasing learning rate"""

    init_weights = model.get_weights()
    iterations = len(X) // batch_size * epochs
    factor = K.exp(K.log(max_rate / min_rate) / iterations)
    init_lr = K.get_value(model.optimizer.lr)
    K.set_value(model.optimizer.lr, min_rate)
    exp_lr = exponential_learning_rate(factor)
    model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
    K.set_value(model.optimizer.lr, init_lr)
    model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses


def plot_lr_vs_loss(rates, losses):  # pragma: no cover
    """Plots learing rate versus the loss function of the model"""

    plt.plot(rates, losses)
    plt.gca().set_xscale("log")
    plt.hlines(min(losses), min(rates), max(rates))
    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")


@tf.function
def far_away_uniform_initialiser(
    shape: tuple, minval: int = 0, maxval: int = 15, iters: int = 100000
) -> tf.Tensor:
    """
    Returns a uniformly initialised matrix in which the columns are as far as possible
lucas_miranda's avatar
lucas_miranda committed
76
77
78
79
80
81
82
83
84
85
86
87
88

        Parameters:
            - shape (tuple): shape of the object to generate.
            - minval (int): Minimum value of the uniform distribution from which to sample
            - maxval (int): Maximum value of the uniform distribution from which to sample
            - iters (int): the algorithm generates values at random and keeps those runs that
            are the farthest apart. Increasing this parameter will lead to more accurate,
            results while making the function run slowlier.

        Returns:
            - init (tf.Tensor): tensor of the specified shape in which the column vectors
             are as far as possible

89
    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
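
# Example usage (a sketch; the shape and number of iterations are illustrative):
#
#     centres = far_away_uniform_initialiser((10, 6), minval=0, maxval=15, iters=1000)
#     # `centres` is a (10, 6) tf.Tensor whose vectors start as spread out as the
#     # random search could get them in 1000 draws.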


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: tuple) -> tf.Tensor:
    """

        Computes the MMD between the two specified vectors using a gaussian kernel.

            Parameters:
                - tensors (tuple): tuple containing two tf.Tensor objects

            Returns
                - mmd (tf.Tensor): returns the maximum mean discrepancy for each
                training instance

        """
    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
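
# Example usage (a sketch with synthetic samples): the estimate is close to zero
# for samples drawn from the same distribution and grows as they diverge.
#
#     x = tf.random.normal([100, 8])
#     y = tf.random.normal([100, 8], mean=2.0)
#     compute_mmd((x, x))  # ~0
#     compute_mmd((x, y))  # noticeably larger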


# Custom auxiliary classes

class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
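
# Example usage (a sketch; `model`, `X_train` and `y_train` stand for a compiled
# Keras model and its training arrays, not defined in this module):
#
#     onecycle = one_cycle_scheduler(iterations=(len(X_train) // 32) * 10, max_rate=1e-3)
#     model.fit(X_train, y_train, epochs=10, batch_size=32, callbacks=[onecycle])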


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among others, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):  # pragma: no cover
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:  # pragma: no cover
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
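
# Example usage (a sketch): although the class subclasses Constraint, __call__
# returns a scalar penalty, so the object can be attached wherever Keras accepts
# a callable regulariser. The 16-unit bottleneck below is purely illustrative:
#
#     bottleneck = tf.keras.layers.Dense(
#         16,
#         activity_regularizer=uncorrelated_features_constraint(16, weightage=1.0),
#     )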


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Monte Carlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed layer"""

        return super().call(inputs, training=True)
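
# Example usage (a sketch; `mc_model` stands for any Keras model built with
# MCDropout layers and `x_batch` for an input batch, neither defined here):
#
#     preds = tf.stack([mc_model(x_batch) for _ in range(100)])
#     mc_mean = tf.reduce_mean(preds, axis=0)     # Monte Carlo prediction
#     mc_std = tf.math.reduce_std(preds, axis=0)  # per-output uncertainty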


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):  # pragma: no cover
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
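
# Example usage (a sketch): a weight-tied autoencoder in which the decoder reuses
# the transposed kernel of the encoder's Dense layer (the 30-dimensional input and
# 16-dimensional code are illustrative):
#
#     encoder_dense = tf.keras.layers.Dense(16, activation="relu")
#     inputs = tf.keras.Input(shape=(30,))
#     code = encoder_dense(inputs)
#     reconstruction = DenseTranspose(encoder_dense, output_dim=30, activation="relu")(code)
#     tied_autoencoder = tf.keras.Model(inputs, reconstruction)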


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL Divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
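
# Example usage (a sketch; the 8-dimensional latent space, batch size and standard
# normal prior are illustrative and not taken from the deepof models):
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
#     z = tf.keras.Input(shape=(8,))
#     z_penalised = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(z)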


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = (target[..., : self.lat_dims, k],)
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):  # pragma: no cover
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        # epsilon inside the log avoids NaNs when a cluster weight is exactly zero
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds a metric that monitors the entropy of the cluster assignment weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z