# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
@tf.function
def far_away_uniform_initialiser(
    shape: tuple, minval: int = 0, maxval: int = 15, iters: int = 100000
) -> tf.Tensor:
    """
    Returns a uniformly initialised matrix in which the column vectors are as far apart as possible

        Parameters:
            - shape (tuple): shape of the object to generate.
            - minval (int): minimum value of the uniform distribution from which to sample
            - maxval (int): maximum value of the uniform distribution from which to sample
            - iters (int): the algorithm generates values at random and keeps those runs that
            are the farthest apart. Increasing this parameter will lead to more accurate
            results while making the function run more slowly.

        Returns:
            - init (tf.Tensor): tensor of the specified shape in which the column vectors
            are as far apart as possible

    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
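

# Usage sketch (not part of the original module): draws a small matrix whose vectors
# are pushed as far apart as possible, e.g. to seed cluster centres. The shape,
# bounds and iteration count below are arbitrary assumptions for illustration.
def _far_away_initialiser_example() -> tf.Tensor:
    return far_away_uniform_initialiser(shape=(4, 2), minval=0, maxval=15, iters=1000)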


def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    """

    Computes the MMD between the two specified vectors using a gaussian kernel.

        Parameters:
            - x (tf.Tensor): left tensor
            - y (tf.Tensor): right tensor

        Returns
            - kernel (tf.Tensor): returns the result of applying the kernel, for
            each training instance

    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    kernel = tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )
    return kernel


@tf.function
def compute_mmd(tensors: tuple) -> tf.Tensor:
    """

        Computes the MMD between the two specified vectors using a gaussian kernel.

            Parameters:
                - tensors (tuple): tuple containing two tf.Tensor objects

            Returns
                - mmd (tf.Tensor): returns the maximum mean discrepancy for each
                training instance

        """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
    return mmd
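

# Usage sketch (not part of the original module): estimates the maximum mean
# discrepancy between two batches of samples. The shapes and the unit shift are
# arbitrary assumptions; identical distributions should yield values close to zero,
# while shifted ones yield larger values.
def _compute_mmd_example() -> tf.Tensor:
    x = tf.random.normal([128, 6])
    y = tf.random.normal([128, 6]) + 1.0
    return compute_mmd((x, y))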


# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
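    """
    One-cycle learning rate scheduler callback: ramps the learning rate linearly
    from start_rate up to max_rate and back down to start_rate over the bulk of
    training, then anneals it down to last_rate over the final last_iterations batches.
    """
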
    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
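

# Usage sketch (not part of the original module): attaches the scheduler to a toy
# regression model. Architecture, data and hyperparameters are arbitrary assumptions;
# `iterations` should roughly match the total number of batches seen during training.
def _one_cycle_scheduler_example() -> tf.keras.callbacks.History:
    x, y = tf.random.normal([256, 4]), tf.random.normal([256, 1])
    epochs, batch_size = 5, 32
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(8, activation="relu"), tf.keras.layers.Dense(1)]
    )
    model.compile(optimizer=tf.keras.optimizers.SGD(), loss="mse")
    scheduler = OneCycleScheduler(
        iterations=epochs * (256 // batch_size), max_rate=0.05
    )
    return model.fit(x, y, epochs=epochs, batch_size=batch_size, callbacks=[scheduler])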


class UncorrelatedFeaturesConstraint(Constraint):
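    """
    Penalty that pushes the features of an encoding towards being mutually
    uncorrelated, by penalising the squared off-diagonal entries of their
    covariance matrix.
    """
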
    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
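

# Usage sketch (not part of the original module): the penalty is callable on a 2D
# batch of activations and returns a scalar, so it can be plugged in as an activity
# regularizer on an encoding layer. Layer size and weightage are arbitrary assumptions.
def _uncorrelated_features_example() -> tf.keras.layers.Dense:
    return tf.keras.layers.Dense(
        16,
        activation="elu",
        activity_regularizer=UncorrelatedFeaturesConstraint(encoding_dim=16, weightage=1.0),
    )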


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
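    """Dropout layer that stays active at prediction time (Monte Carlo dropout)"""
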
    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)
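

# Usage sketch (not part of the original module): because dropout stays active at
# prediction time, repeated forward passes over the same input yield a distribution
# of outputs whose spread gives a rough uncertainty estimate. Model, dropout rate
# and number of passes are arbitrary assumptions.
def _mc_dropout_example(passes: int = 50) -> tf.Tensor:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(8, activation="relu"),
            MCDropout(0.5),
            tf.keras.layers.Dense(1),
        ]
    )
    inputs = tf.random.normal([16, 4])
    predictions = tf.stack([model(inputs) for _ in range(passes)])
    return tf.math.reduce_std(predictions, axis=0)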


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
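    """
    Identity transform layer that adds the KL divergence between the posterior
    distribution it receives and the prior configured in the parent class to the
    final model loss.
    """
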
    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"is_placeholder": self.is_placeholder,}
        )
        return config

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a


class DenseTranspose(Layer):
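    """
    Dense layer that reuses the transposed weights of a given Dense layer,
    allowing weight tying between an encoder and its decoder.
    """
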
    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
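

# Usage sketch (not part of the original module): ties the decoder of a toy
# autoencoder to the transposed weights of its encoder. Input and latent sizes
# are arbitrary assumptions.
def _dense_transpose_example() -> tf.keras.Model:
    encoder_layer = tf.keras.layers.Dense(8, activation="elu")
    inputs = tf.keras.Input(shape=(32,))
    encoded = encoder_layer(inputs)
    decoded = DenseTranspose(encoder_layer, output_dim=32, activation="elu")(encoded)
    return tf.keras.Model(inputs, decoded)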


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds the Maximum Mean Discrepancy (MMD)
    between the latent samples and a given prior to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
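

# Usage sketch (not part of the original module): penalises the discrepancy between
# a Dense latent projection and a standard-normal prior inside a functional model.
# Latent dimensionality, batch size and beta are arbitrary assumptions.
def _mm_discrepancy_example(latent_dim: int = 4, batch_size: int = 64) -> tf.keras.Model:
    prior = tfd.MultivariateNormalDiag(
        loc=tf.zeros(latent_dim), scale_diag=tf.ones(latent_dim)
    )
    inputs = tf.keras.Input(shape=(16,))
    z = tf.keras.layers.Dense(latent_dim)(inputs)
    z = MMDiscrepancyLayer(batch_size=batch_size, prior=prior, beta=K.constant(1.0))(z)
    return tf.keras.Model(inputs, z)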


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # axis=1 yields the (negative) entropy of the cluster assignment for each instance;
        # axis=0 would yield the entropy of each cluster across instances
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds a metric that monitors the entropy of the cluster assignment weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
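

# Usage sketch (not part of the original module): applied to the softmax output of a
# cluster-assignment head, so that the entropy of the assignment weights contributes
# to the model loss. The number of components and the weight are arbitrary assumptions.
def _entropy_regulariser_example(n_components: int = 4) -> tf.keras.Model:
    inputs = tf.keras.Input(shape=(8,))
    weights = tf.keras.layers.Dense(n_components, activation="softmax")(inputs)
    weights = Entropy_regulariser(weight=0.01)(weights)
    return tf.keras.Model(inputs, weights)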