# @author lucasmiranda42

from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import networkx as nx
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Connectivity for DLC models
def connect_mouse_topview(animal_id=None) -> nx.Graph:
    """Creates a nx.Graph object with the connectivity of the bodyparts in the
    DLC topview model for a single mouse. Used later for angle computing, among others

        Parameters:
            - animal_id (str): if more than one animal is tagged,
            specify the animal identyfier as a string

        Returns:
            - connectivity (nx.Graph)"""

    connectivity = {
        "Nose": ["Left_ear", "Right_ear", "Spine_1"],
        "Left_ear": ["Right_ear", "Spine_1"],
        "Right_ear": ["Spine_1"],
        "Spine_1": ["Center", "Left_fhip", "Right_fhip"],
        "Center": ["Left_fhip", "Right_fhip", "Spine_2", "Left_bhip", "Right_bhip"],
        "Spine_2": ["Left_bhip", "Right_bhip", "Tail_base"],
        "Tail_base": ["Tail_1", "Left_bhip", "Right_bhip"],
        "Tail_1": ["Tail_2"],
        "Tail_2": ["Tail_tip"],
    }

    connectivity = nx.Graph(connectivity)

    if animal_id:
        mapping = {
            node: "{}_{}".format(animal_id, node) for node in connectivity.nodes()
        }
        nx.relabel_nodes(connectivity, mapping, copy=False)

    return connectivity
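
# A minimal usage sketch (hypothetical tag, not part of the original module):
#
#   graph = connect_mouse_topview(animal_id="B")
#   "B_Nose" in graph.nodes()  # True; nodes are prefixed with the animal_id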


# Helper functions
@tf.function
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
    """
    Returns a uniformly initialised matrix whose rows are as far apart
    from one another as possible, found by random search
    """

    init = tf.random.uniform(shape, minval, maxval)
    init_dist = tf.abs(tf.norm(tf.math.subtract(init[1:], init[:1])))
    i = 0

    # Random search: keep the draw whose rows lie furthest from the first row
    while tf.less(i, iters):
        temp = tf.random.uniform(shape, minval, maxval)
        dist = tf.abs(tf.norm(tf.math.subtract(temp[1:], temp[:1])))

        if dist > init_dist:
            init_dist = dist
            init = temp

        i += 1

    return init
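
# A minimal usage sketch (hypothetical sizes, not part of the original module):
#
#   centers = far_away_uniform_initialiser((10, 6), minval=0, maxval=15, iters=1000)
#   centers.shape  # TensorShape([10, 6])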


def compute_kernel(x, y):
    """Computes a Gaussian kernel matrix between the rows of x and y"""

    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(
        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
    )
    tiled_y = tf.tile(
        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
    )
    return tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
    )


@tf.function
def compute_mmd(tensors):
    """Computes the Maximum Mean Discrepancy (MMD) between two samples"""

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
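
# A minimal usage sketch (not part of the original module): the MMD of a sample
# against itself is close to zero, and grows as the two distributions separate.
#
#   a = tf.random.normal([100, 2])
#   b = tf.random.normal([100, 2], mean=3.0)
#   compute_mmd((a, a))  # ~0
#   compute_mmd((a, b))  # substantially larger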


# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
    """
    One-cycle learning-rate schedule: ramps the rate up to max_rate, back down
    to start_rate, then anneals to last_rate over the final iterations
    """

    def __init__(
        self,
        iterations,
        max_rate,
        start_rate=None,
        last_iterations=None,
        last_rate=None,
    ):
        self.iterations = iterations
        self.max_rate = max_rate
        self.start_rate = start_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.start_rate / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        if self.iteration < self.half_iteration:
            rate = self._interpolate(
                0, self.half_iteration, self.start_rate, self.max_rate
            )
        elif self.iteration < 2 * self.half_iteration:
            rate = self._interpolate(
                self.half_iteration,
                2 * self.half_iteration,
                self.max_rate,
                self.start_rate,
            )
        else:
            rate = self._interpolate(
                2 * self.half_iteration,
                self.iterations,
                self.start_rate,
                self.last_rate,
            )
            rate = max(rate, self.last_rate)
        self.iteration += 1
        K.set_value(self.model.optimizer.lr, rate)
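
# A minimal usage sketch (hypothetical model and sizes, not from this module):
#
#   n_samples, batch_size, n_epochs = 10000, 256, 10
#   onecycle = OneCycleScheduler(
#       iterations=n_samples // batch_size * n_epochs, max_rate=0.01
#   )
#   model.fit(X_train, y_train, epochs=n_epochs, batch_size=batch_size,
#             callbacks=[onecycle])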


class UncorrelatedFeaturesConstraint(Constraint):
    """
    Penalises correlations between the dimensions of an encoding layer
    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty: sum of squared off-diagonal covariance entries
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, tf.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
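
# A minimal usage sketch (hypothetical layer, not from this module): the class
# is callable on activations, so it can serve as an activity regulariser that
# penalises correlated encodings.
#
#   encoding = tf.keras.layers.Dense(
#       8, activity_regularizer=UncorrelatedFeaturesConstraint(8, weightage=1.0)
#   )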


# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
    """Dropout layer that stays active at inference time (Monte Carlo dropout)"""

    def call(self, inputs, **kwargs):
        return super().call(inputs, training=True)
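
# A minimal usage sketch (hypothetical model, not from this module): since the
# layer stays stochastic at inference, repeated forward passes can be averaged
# to estimate predictive uncertainty.
#
#   y_samples = np.stack([model.predict(X_test) for _ in range(100)])
#   y_mean, y_std = y_samples.mean(axis=0), y_samples.std(axis=0)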


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds KL divergence
    to the final model loss
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"is_placeholder": self.is_placeholder})
        return config

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(kl_batch, aggregation="mean", name="kl_divergence")
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
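
# A minimal usage sketch (hypothetical prior and encoder output, not from this
# module): weights the KL term of a latent Gaussian against a unit prior.
#
#   prior = tfd.Independent(tfd.Normal(tf.zeros(6), 1.0), 1)
#   z = tfpl.IndependentNormal(6)(encoder_output)  # encoder_output: 12 units
#   z = KLDivergenceLayer(prior, weight=1.0)(z)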


class DenseTranspose(Layer):
    """
    Dense layer that reuses the transposed kernel of the given dense layer,
    adding its own bias. Useful for building weight-tied autoencoders
    """

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
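
# A minimal usage sketch (hypothetical sizes, not from this module): a tied-weight
# autoencoder in which the decoder reuses the encoder's kernel, transposed.
#
#   enc = tf.keras.layers.Dense(16, activation="relu")
#   dec = DenseTranspose(enc, output_dim=64, activation="sigmoid")
#   inputs = tf.keras.Input(shape=(64,))
#   autoencoder = tf.keras.Model(inputs, dec(enc(inputs)))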


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
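
# A minimal usage sketch (hypothetical prior and sizes, not from this module):
#
#   prior = tfd.MultivariateNormalDiag(loc=tf.zeros(6), scale_diag=tf.ones(6))
#   z = tf.keras.Input(shape=(6,))
#   z = MMDiscrepancyLayer(batch_size=256, prior=prior, beta=1.0)(z)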


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, **kwargs):

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
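
# A minimal usage sketch (hypothetical shapes, not from this module): target
# stacks means and pre-softplus scales per component along the second axis,
# i.e. shape (batch, 2 * lat_dims, n_components).
#
#   overlap = Gaussian_mixture_overlap(lat_dims=6, n_components=10)
#   target = tf.keras.Input(shape=(12, 10))
#   target = overlap(target)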


class Dead_neuron_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, *args, **kwargs):
        super(Dead_neuron_control, self).__init__(*args, **kwargs)

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1.0, *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # axis=1 increases the entropy of a cluster across instances
        # axis=0 increases the entropy of the assignment for a given instance
        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z + 1e-5)), axis=1)

        # Adds a metric that monitors the entropy of the cluster assignments
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
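
# A minimal usage sketch (hypothetical upstream tensor, not from this module):
# z holds soft cluster assignments, e.g. the output of a softmax layer.
#
#   soft_assignments = tf.keras.layers.Dense(10, activation="softmax")(features)
#   soft_assignments = Entropy_regulariser(weight=0.01)(soft_assignments)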