# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
@tf.function
def far_away_uniform_initialiser(
    shape: tuple, minval: int = 0, maxval: int = 15, iters: int = 100000
) -> tf.Tensor:
    """
    Returns a uniformly initialised matrix in which the columns are as far as possible
lucas_miranda's avatar
lucas_miranda committed
29
30
31
32
33
34
35
36
37
38
39
40
41

        Parameters:
            - shape (tuple): shape of the object to generate.
            - minval (int): Minimum value of the uniform distribution from which to sample
            - maxval (int): Maximum value of the uniform distribution from which to sample
            - iters (int): the algorithm generates values at random and keeps those runs that
            are the farthest apart. Increasing this parameter will lead to more accurate,
            results while making the function run slowlier.

        Returns:
            - init (tf.Tensor): tensor of the specified shape in which the column vectors
             are as far as possible

42
    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
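
# A quick usage sketch (illustrative; the shape and iteration budget below are
# arbitrary): draws a (4, 2) matrix whose rows the rejection loop above pushes apart.
if __name__ == "__main__":
    example_init = far_away_uniform_initialiser((4, 2), minval=0, maxval=15, iters=100)
    print(example_init.shape)  # (4, 2)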


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel
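
# A quick shape check (illustrative): for batches of n and m instances the Gaussian
# kernel matrix is (n, m), with values near 1 for pairs of rows that lie close together.
if __name__ == "__main__":
    left = tf.random.normal([5, 3])
    right = tf.random.normal([7, 3])
    print(compute_kernel(left, right).shape)  # (5, 7)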


@tf.function
def compute_mmd(tensors: tuple) -> tf.Tensor:
    """

        Computes the MMD between the two specified vectors using a gaussian kernel.

            Parameters:
                - tensors (tuple): tuple containing two tf.Tensor objects

            Returns
                - mmd (tf.Tensor): returns the maximum mean discrepancy for each
                training instance

        """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
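
# A small sanity check (illustrative): the MMD between two samples from the same
# distribution should be near zero, and clearly larger once one sample is shifted.
if __name__ == "__main__":
    same = compute_mmd((tf.random.normal([100, 2]), tf.random.normal([100, 2])))
    apart = compute_mmd((tf.random.normal([100, 2]), tf.random.normal([100, 2]) + 3.0))
    print(float(same), float(apart))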


# Custom auxiliary classes
class one_cycle_scheduler(tf.keras.callbacks.Callback):
    """

    One cycle learning rate scheduler.
    Based on https://arxiv.org/pdf/1506.01186.pdf

    """

    def __init__(
        self,
        iterations: int,
        max_rate: float,
        start_rate: float = None,
        last_iterations: int = None,
        last_rate: float = None,
    ):
        super().__init__()
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0
        self.history = {}

    def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    # noinspection PyMethodOverriding,PyTypeChecker
    def on_batch_begin(self, batch: int, logs):
        """ Defines computations to perform for each batch """

        self.history.setdefault("lr", []).append(K.get_value(self.model.optimizer.lr))

        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
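
# A minimal wiring sketch (hypothetical model and data; all values are arbitrary):
# iterations should equal batches per epoch times the number of epochs.
if __name__ == "__main__":
    toy_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
    toy_model.compile(optimizer="adam", loss="mse")
    x_toy, y_toy = tf.random.normal([64, 8]), tf.random.normal([64, 1])
    cycle = one_cycle_scheduler(iterations=(64 // 16) * 5, max_rate=1e-2)
    toy_model.fit(x_toy, y_toy, batch_size=16, epochs=5, callbacks=[cycle], verbose=0)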


class uncorrelated_features_constraint(Constraint):
    """

    tf.keras.constraints.Constraint subclass that forces a layer to have uncorrelated features.
    Useful, among other things, for autoencoder bottleneck layers

    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"encoding_dim": self.encoding_dim, "weightage": self.weightage})
        return config

    def get_covariance(self, x):
        """Computes the covariance of the elements of the passed layer"""

        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    # noinspection PyUnusedLocal
    def uncorrelated_feature(self, x):
        """Adds a penalty on feature correlation, forcing more independent sets of weights"""

        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
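
# A usage sketch (note, and this is my reading of the code above: although this
# subclasses Constraint, the penalty is computed on activations, so it is applied
# here as an activity regularizer on a bottleneck layer; sizes are hypothetical).
if __name__ == "__main__":
    bottleneck = tf.keras.layers.Dense(
        8,
        activity_regularizer=uncorrelated_features_constraint(encoding_dim=8, weightage=1.0),
    )
    _ = bottleneck(tf.random.normal([32, 16]))  # the penalty lands in bottleneck.losses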


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Equivalent to tf.keras.layers.Dropout, but with training mode enabled at prediction time.
    Useful for Montecarlo predictions"""

    def call(self, inputs, **kwargs):
        """Overrides the call method of the subclassed function"""
        return super().call(inputs, training=True)
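
# A usage sketch (illustrative): dropout stays active at prediction time, so repeated
# forward passes over the same input differ, enabling Monte Carlo uncertainty estimates.
if __name__ == "__main__":
    mc_drop = MCDropout(rate=0.5)
    same_input = tf.ones([1, 10])
    print(mc_drop(same_input).numpy())  # changes between calls
    print(mc_drop(same_input).numpy())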


class DenseTranspose(Layer):
    """Mirrors a tf.keras.layers.Dense instance with transposed weights.
    Useful for decoder layers in autoencoders, to force structure and
    decrease the effective number of parameters to train"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    # noinspection PyAttributeOutsideInit
    def build(self, batch_input_shape):
        """Updates Layer's build method"""

        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        """Updates Layer's call method"""

        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        """Outputs the transposed shape"""

        return input_shape[0], self.output_dim
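
# A tied-weights autoencoder sketch (hypothetical sizes): the decoder reuses the
# encoder's kernel transposed, roughly halving the number of trainable parameters.
if __name__ == "__main__":
    encoder_dense = tf.keras.layers.Dense(4, activation="relu")
    ae_in = tf.keras.Input(shape=(16,))
    encoded = encoder_dense(ae_in)
    decoded = DenseTranspose(encoder_dense, output_dim=16, activation="sigmoid")(encoded)
    autoencoder = tf.keras.Model(ae_in, decoded)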


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
        Identity transform layer that adds KL Divergence
        to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        """Updates Layer's call method"""

        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        # noinspection PyProtectedMember
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
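
# A wiring sketch (hypothetical latent size; assumes the upstream layer emits a
# tfd.Distribution, e.g. via tfpl.MultivariateNormalTriL): the layer is an identity
# on the posterior while registering its KL to the prior as a loss and metric.
if __name__ == "__main__":
    latent_dim = 2
    latent_prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0), reinterpreted_batch_ndims=1
    )
    params_in = tf.keras.Input(shape=(tfpl.MultivariateNormalTriL.params_size(latent_dim),))
    posterior = tfpl.MultivariateNormalTriL(latent_dim)(params_in)
    z = KLDivergenceLayer(latent_prior, weight=1.0)(posterior)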


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds MM Discrepancy
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        """Updates Layer metadata"""

        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
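
# A wiring sketch (hypothetical sizes): penalises the discrepancy between latent
# codes and samples drawn from the prior, scaled by beta. A fixed batch size is
# needed because the prior is sampled once per batch.
if __name__ == "__main__":
    batch_size, latent_dim = 32, 2
    mmd_prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_dim))
    latent_in = tf.keras.Input(batch_size=batch_size, shape=(latent_dim,))
    z_mmd = MMDiscrepancyLayer(batch_size=batch_size, prior=mmd_prior, beta=1.0)(latent_in)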


class Gaussian_mixture_overlap(Layer):  # pragma: no cover
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fischer-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    @tf.function
    def call(self, target, **kwargs):
        """Updates Layer's call method"""

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        # MMD-based overlap #
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
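
# A shape sketch (an assumption read off the code above: the penultimate axis packs
# 2 * lat_dims parameters per component -- means first, then unconstrained scales).
if __name__ == "__main__":
    lat_dims, n_comp = 4, 3
    mixture_params = tf.random.normal([16, 2 * lat_dims, n_comp])
    _ = Gaussian_mixture_overlap(lat_dims, n_comp, samples=5)(mixture_params)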


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    # noinspection PyMethodOverriding
    def call(self, target, **kwargs):
        """Updates Layer's call method"""
        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
        )

        return target
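
# A usage sketch (illustrative): the layer passes activations through unchanged and
# reports the fraction of exact zeros, a proxy for dead units in the latent space.
if __name__ == "__main__":
    activations = tf.constant([[0.0, 1.5], [0.0, 0.3]])
    _ = Dead_neuron_control()(activations)  # dead_neurons metric would read 0.5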


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
        self.weight = weight
        self.axis = axis
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        """Updates Constraint metadata"""

        config = super().get_config().copy()
        config.update({"weight": self.weight})
        config.update({"axis": self.axis})
        return config

    def call(self, z, **kwargs):
        """Updates Layer's call method"""

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        # the small epsilon goes inside the log to avoid NaNs for zero weights
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=self.axis)

        # Adds metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
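
# A usage sketch (illustrative): z holds soft cluster assignments (rows summing to
# one, e.g. softmax outputs); the layer adds their weighted entropy to the loss.
if __name__ == "__main__":
    assignments = tf.constant([[0.9, 0.05, 0.05], [0.34, 0.33, 0.33]])
    _ = Entropy_regulariser(weight=0.1)(assignments)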