# @author lucasmiranda42

from itertools import combinations

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer

tfd = tfp.distributions
tfpl = tfp.layers

# Helper functions
@tf.function
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
    """
    Returns a uniformly initialised matrix in which the columns are as far away from
    each other as possible (the best draw found over `iters` random attempts).
    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
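
# Usage sketch (illustrative only; the (10, 5) shape and the variable name are
# assumptions, not part of this module):
#
#   cluster_init = far_away_uniform_initialiser((10, 5), minval=0, maxval=15)
#   # cluster_init is a (10, 5) float32 tensor; of `iters` uniform draws, the one
#   # maximising the distance criterion above is kept.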


def compute_kernel(x, y):
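    """
    Returns the Gaussian (RBF) kernel matrix between the rows of x and y, with the
    squared pairwise distances scaled by the input dimensionality.
    """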
    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    return tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )


@tf.function
def compute_mmd(tensors):
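    """
    Computes the Maximum Mean Discrepancy (MMD) between the two samples in `tensors`
    (a pair of rank-2 tensors), using the Gaussian kernel from compute_kernel:
    E[k(x, x')] + E[k(y, y')] - 2 * E[k(x, y)].
    """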

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
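
# Usage sketch (shapes and values are illustrative assumptions):
#
#   a = tf.random.normal([64, 8])
#   b = tf.random.normal([64, 8]) + 1.0
#   mmd = compute_mmd((a, b))  # scalar tensor; grows as the two samples diverge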


# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
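    """
    One-cycle learning-rate schedule: the rate rises linearly from start_rate to
    max_rate over the first half of training, falls back to start_rate over the
    second half, and is annealed down to last_rate over the final iterations.
    The rate is updated at the beginning of every batch.
    """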
    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
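
# Usage sketch (hypothetical model and data; `iterations` should match the total
# number of batches, i.e. steps_per_epoch * epochs, of the actual fit call):
#
#   onecycle = OneCycleScheduler(iterations=(len(X_train) // 256) * 10, max_rate=1e-3)
#   model.fit(X_train, y_train, batch_size=256, epochs=10, callbacks=[onecycle])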


class UncorrelatedFeaturesConstraint(Constraint):
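    """
    Adds a penalty on the squared off-diagonal entries of the covariance matrix of an
    encoding, pushing its `encoding_dim` features towards being uncorrelated. The
    strength of the penalty is controlled by `weightage`.
    """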
    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
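
# Usage sketch (an assumption about intended use: __call__ returns a scalar penalty
# rather than a projected weight tensor, so the class behaves like an activity
# regularizer on the bottleneck rather than a true weight constraint):
#
#   encoding = tf.keras.layers.Dense(
#       16, activity_regularizer=UncorrelatedFeaturesConstraint(16, weightage=1.0)
#   )(inputs)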


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
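    """
    Dropout layer that stays active at inference time (every call uses training=True),
    so that repeated predictions over the same input yield a Monte Carlo estimate of
    the model's uncertainty.
    """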
    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
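    """
    Identity transform layer that adds KL divergence to the final model loss and
    exposes the divergence and its current weight as metrics.
    """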
    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"is_placeholder": self.is_placeholder,}
        )
        return config

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a


class DenseTranspose(Layer):
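    """
    Dense layer that reuses the (transposed) kernel of an existing Dense layer while
    learning its own bias; useful for building tied-weights autoencoders.
    """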
    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
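
# Usage sketch (hypothetical tied-weights autoencoder; layer sizes are assumptions):
#
#   dense_1 = tf.keras.layers.Dense(64, activation="relu")
#   dense_2 = tf.keras.layers.Dense(16, activation="relu")
#   decoder_1 = DenseTranspose(dense_2, output_dim=64, activation="relu")
#   decoder_2 = DenseTranspose(dense_1, output_dim=784, activation="sigmoid")
#   # the decoder layers reuse the encoder kernels (transposed), roughly halving
#   # the number of trainable Dense weights.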


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD) with respect
    to a given prior to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
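
# Usage sketch (hypothetical prior and latent dimensionality):
#
#   prior = tfd.Independent(
#       tfd.Normal(loc=tf.zeros(8), scale=1.0), reinterpreted_batch_ndims=1
#   )
#   z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)
#   # z passes through unchanged; the MMD between z and prior samples is added to
#   # the model loss and logged as the "mmd" metric.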


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fischer-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
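
# Note: as written, call() expects `target` to pack means and (pre-softplus) scales
# along its second-to-last axis, i.e. a shape of (batch, 2 * lat_dims, n_components),
# with one slice per mixture component along the last axis.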


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # axis=1 increases the entropy of the assignment for a given instance
        # axis=0 increases the entropy of a cluster across instances
        # the 1e-5 epsilon guards against log(0) for vanishing assignment weights
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds a metric that monitors the entropy of the cluster assignment weights
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
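
# Usage sketch (an assumption about intended use: z holds per-instance cluster
# assignment probabilities, e.g. the output of a softmax over mixture components;
# `n_components` and `encoding` are hypothetical):
#
#   soft_counts = tf.keras.layers.Dense(n_components, activation="softmax")(encoding)
#   soft_counts = Entropy_regulariser(weight=0.01)(soft_counts)
#   # soft_counts passes through unchanged; weight * sum(z * log z) is added to the
#   # loss and the per-instance entropy term is logged as "-weight_entropy".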