# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for the deepof TensorFlow models. See the documentation for details.

"""

from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import networkx as nx
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
@tf.function
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
    """
    Returns a uniformly initialised matrix whose rows lie as far away from each other as
    possible, selected by random search over a number of uniformly sampled candidates
    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
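
# Usage sketch (illustrative comment only; the shape and iteration count below are
# arbitrary assumptions, not values used elsewhere in deepof):
#
#   init = far_away_uniform_initialiser([10, 6], minval=0, maxval=15, iters=1000)
#   # init.shape == (10, 6); rows are spread out within the [0, 15] range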


def compute_kernel(x, y):
    """
    Returns the pairwise Gaussian (RBF) kernel matrix between the rows of x and the rows
    of y, with the bandwidth scaled by the feature dimension
    """

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    return tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )


@tf.function
def compute_mmd(tensors):
    """
    Computes the Maximum Mean Discrepancy (MMD) between the two tensors in the input
    tuple, using the Gaussian kernel defined in compute_kernel
    """

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
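
# Usage sketch (illustrative comment only; shapes and values are arbitrary assumptions):
#
#   x = tf.random.normal([64, 8])
#   y = tf.random.normal([64, 8]) + 3.0
#   mmd = compute_mmd((x, y))  # scalar tensor; grows as the two samples drift apart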


# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
    """
    One-cycle learning rate scheduler: ramps the learning rate up to max_rate over the
    first half of the cycle, back down to start_rate over the second half, and anneals
    it towards last_rate during the final iterations
    """

    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
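
# Usage sketch (illustrative comment only; the model, data and counts are assumptions):
#
#   one_cycle = OneCycleScheduler(iterations=epochs * steps_per_epoch, max_rate=1e-3)
#   model.fit(x_train, y_train, epochs=epochs, callbacks=[one_cycle])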


class UncorrelatedFeaturesConstraint(Constraint):
    """
    Penalises correlations between the dimensions of an encoding by summing the squared
    off-diagonal entries of its covariance matrix, scaled by weightage
    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
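
# Usage sketch (illustrative comment only; the Dense layer and its size are assumptions).
# __call__ returns a scalar penalty on a batch of activations, so the constraint can be
# plugged in wherever Keras accepts a callable on layer outputs, e.g.:
#
#   encoded = tf.keras.layers.Dense(
#       16, activity_regularizer=UncorrelatedFeaturesConstraint(16, weightage=1.0)
#   )(x)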


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """
    Dropout layer that stays active at inference time (Monte Carlo dropout)
    """

    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)
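
# Usage sketch (illustrative comment only; the rate and placement are assumptions):
#
#   x = MCDropout(0.25)(x)  # dropout remains active at predict time, so repeated forward
#                           # passes yield a distribution over outputs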


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds the KL divergence with respect to the prior to the
    final model loss, and exposes it as a metric
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
180
            {"is_placeholder": self.is_placeholder,}
        )
        return config

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
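
# Usage sketch (illustrative comment only; the prior and the upstream probabilistic layer
# are assumptions). Applied to the output of a distribution layer, the KL term is added to
# the loss and tracked as a metric while the distribution passes through unchanged:
#
#   prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
#   z = tfpl.IndependentNormal(8)(encoding_params)
#   z = KLDivergenceLayer(prior, weight=1.0)(z)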


class DenseTranspose(Layer):
    """
    Dense layer that reuses the transposed weights of a given Dense layer, allowing
    weight tying between an encoder and its mirrored decoder
    """

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
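
# Usage sketch (illustrative comment only; layer sizes are assumptions). Typical use is
# tying the decoder weights to those of an already-built encoder layer:
#
#   enc = tf.keras.layers.Dense(32, activation="relu")
#   dec = DenseTranspose(enc, output_dim=64, activation="relu")  # reuses enc's kernel, transposed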


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds the Maximum Mean Discrepancy (MMD) between the
    encoded batch and samples from a given prior to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
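
# Usage sketch (illustrative comment only; prior, latent size and batch size are assumptions):
#
#   prior = tfd.Independent(tfd.Normal(loc=tf.zeros(8), scale=1.0), 1)
#   z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)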


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent
    Gaussian mixture, using the Maximum Mean Discrepancy (MMD) between sampled components
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = (target[..., : self.lat_dims, k],)
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
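
# Usage sketch (illustrative comment only; the dimensions are assumptions). The layer
# expects the concatenated means and pre-softplus scales of every mixture component,
# stacked along the last axis:
#
#   # gmm_params shape: [batch, 2 * lat_dims, n_components]
#   gmm_params = Gaussian_mixture_overlap(lat_dims=8, n_components=15)(gmm_params)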


class Dead_neuron_control(Layer):
    """
    Identity layer that monitors the fraction of dead neurons in the latent space and
    adds it to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        # the small constant inside the logarithm avoids evaluating log(0)
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds a metric that monitors the entropy of the cluster assignment weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
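
# Usage sketch (illustrative comment only; the softmax cluster-assignment layer is an
# assumption). The regulariser acts on a tensor of per-instance cluster weights:
#
#   z_cat = tf.keras.layers.Dense(n_components, activation="softmax")(encoding)
#   z_cat = Entropy_regulariser(weight=0.01)(z_cat)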