# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof tensorflow models. See documentation for details

"""

from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import networkx as nx
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
@tf.function
def far_away_uniform_initialiser(
    shape: tuple, minval: int = 0, maxval: int = 15, iters: int = 100000
) -> tf.Tensor:
    """
    Returns a uniformly initialised matrix in which the columns are as far apart as possible

        Parameters:
            - shape (tuple): shape of the object to generate.
            - minval (int): minimum value of the uniform distribution from which to sample
            - maxval (int): maximum value of the uniform distribution from which to sample
            - iters (int): the algorithm generates values at random and keeps the run whose
            columns are farthest apart. Increasing this parameter leads to more accurate
            results, but makes the function run more slowly.

        Returns:
            - init (tf.Tensor): tensor of the specified shape in which the column vectors
            are as far apart as possible

    """
    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init

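# Illustrative usage sketch (assumption, not part of the original module): the initialiser
# can be used to seed a set of vectors, e.g. candidate cluster centres, so that they start
# out spread across the sampling range rather than clumped together.
#
#     cluster_centres = far_away_uniform_initialiser(shape=(6, 8), minval=0, maxval=15, iters=1000)
#     print(cluster_centres.shape)  # (6, 8)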

def compute_kernel(x, y):
    """Returns the Gaussian (RBF) kernel matrix between the samples in x and those in y;
    used by compute_mmd below"""

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    return tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )


@tf.function
def compute_mmd(tensors):
    """Computes the Maximum Mean Discrepancy (MMD) between the two sets of samples
    passed as a tuple of tensors"""

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )

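# Illustrative usage sketch (assumption, not part of the original module): the MMD between
# two batches of samples grows as the underlying distributions diverge, which is what the
# layers below exploit to compare a posterior with a prior.
#
#     a = tf.random.normal([100, 8], mean=0.0)
#     b = tf.random.normal([100, 8], mean=3.0)
#     compute_mmd((a, a))  # close to zero
#     compute_mmd((a, b))  # noticeably larger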

# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
    """Learning-rate callback implementing the one-cycle policy: the rate ramps up from
    start_rate to max_rate, comes back down to start_rate, and finally decays to last_rate"""

    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)

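# Illustrative usage sketch (assumption, not part of the original module): `model`, `x_train`,
# `y_train`, `batch_size` and `epochs` are hypothetical; the number of iterations should match
# the total number of batches seen during training.
#
#     iterations = (len(x_train) // batch_size) * epochs
#     one_cycle = OneCycleScheduler(iterations=iterations, max_rate=1e-3)
#     model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, callbacks=[one_cycle])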

class UncorrelatedFeaturesConstraint(Constraint):
    """Callable penalty that pushes the covariance matrix of the encodings towards a
    diagonal matrix, encouraging uncorrelated latent features"""

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)

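# Illustrative usage sketch (assumption, not part of the original module): despite subclassing
# Constraint, the object is simply a callable returning a scalar penalty, so it can be attached
# as an activity regulariser that decorrelates the features of a bottleneck layer; the layer
# size below is hypothetical.
#
#     bottleneck = tf.keras.layers.Dense(
#         16, activity_regularizer=UncorrelatedFeaturesConstraint(16, weightage=1.0)
#     )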

# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Dropout layer that remains active at inference time, enabling Monte Carlo
    estimates of the model's predictive uncertainty"""

    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)

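# Illustrative usage sketch (assumption, not part of the original module): because dropout stays
# on at prediction time, repeated forward passes over the same input give a Monte Carlo estimate
# of predictive uncertainty; `mc_model` and `x` are hypothetical.
#
#     predictions = tf.stack([mc_model(x) for _ in range(100)])
#     mean = tf.reduce_mean(predictions, axis=0)
#     std = tf.math.reduce_std(predictions, axis=0)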

class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """Identity transform layer that adds the KL divergence with respect to the prior
    to the final model loss"""

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"is_placeholder": self.is_placeholder,}
        )
        return config

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a

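# Illustrative usage sketch (assumption, not part of the original module): adding a weighted KL
# divergence term between a probabilistic encoding and a fixed prior; the latent dimensionality
# and the `encoder_features` tensor (of size tfpl.MultivariateNormalTriL.params_size(8)) are
# hypothetical.
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), reinterpreted_batch_ndims=1)
#     z = tfpl.MultivariateNormalTriL(8)(encoder_features)
#     z = KLDivergenceLayer(prior, weight=1.0)(z)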

class DenseTranspose(Layer):
    """Dense layer that reuses the (transposed) kernel of a given Dense layer,
    allowing weight tying between encoder and decoder"""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim

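# Illustrative usage sketch (assumption, not part of the original module): tying the decoder
# weights of an autoencoder to its encoder layers; the layer sizes and `input_dim` are
# hypothetical.
#
#     dense_1 = tf.keras.layers.Dense(64, activation="relu")
#     dense_2 = tf.keras.layers.Dense(16, activation="relu")
#     decoder_hidden = DenseTranspose(dense_2, output_dim=64, activation="relu")
#     decoder_output = DenseTranspose(dense_1, output_dim=input_dim, activation=None)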

class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds the Maximum Mean Discrepancy (MMD)
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z

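# Illustrative usage sketch (assumption, not part of the original module): pulling the encoded
# batch towards a fixed prior through the MMD term; `z` (a batch of encodings) and the latent
# dimensionality are hypothetical.
#
#     prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), reinterpreted_batch_ndims=1)
#     z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)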

class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = (target[..., : self.lat_dims, k],)
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target

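# Illustrative usage sketch (assumption, not part of the original module): monitoring how much
# the components of a latent Gaussian mixture overlap; `mixture_params` is a hypothetical tensor
# of shape (batch, 2 * lat_dims, n_components) holding concatenated means and pre-softplus scales.
#
#     mixture_params = Gaussian_mixture_overlap(lat_dims=8, n_components=6)(mixture_params)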

class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds metric that monitors the entropy of the cluster assignment weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
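

# Illustrative usage sketch (assumption, not part of the original module): penalising the entropy
# of the softmax cluster-assignment weights of a latent mixture; `z_cat` is a hypothetical
# (batch x n_components) tensor of assignment probabilities.
#
#     z_cat = Entropy_regulariser(weight=0.01)(z_cat)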