# model_utils.py
# @author lucasmiranda42

from itertools import combinations
from sklearn.metrics import silhouette_score
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=10000000):
    """
    Returns a uniformly initialised matrix in which the columns are as far as possible
    """
    init_dist = 0
    for i in range(iters):
        temp = np.random.uniform(minval, maxval, shape)
        dist = np.abs(np.linalg.norm(np.diff(temp)))

        if dist > init_dist:
            init_dist = dist
            init = temp

    return init.astype(np.float32)
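

# Illustrative usage sketch (not called anywhere in this module; the shape and the
# reduced `iters` value are assumptions chosen purely for demonstration):
def _far_away_initialiser_example():
    # Draw a (4, 8) matrix whose columns are pushed apart, keeping the best of 10 draws
    centroids = far_away_uniform_initialiser(shape=(4, 8), minval=0, maxval=15, iters=10)
    return centroids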


def compute_kernel(x, y):
    """
    Computes the Gaussian (RBF) kernel matrix between two samples x and y,
    used below to obtain a kernel-based MMD estimate.
    """
    x_size = K.shape(x)[0]
    y_size = K.shape(y)[0]
    dim = K.shape(x)[1]
    tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
    tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
    return K.exp(
        -tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32)
    )


def compute_mmd(tensors):
    """
    Computes the Maximum Mean Discrepancy (MMD) between two samples, passed as a
    tuple/list of two tensors of shape (n_samples, dim).
    """
    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
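

# Illustrative usage sketch (array sizes are assumptions for demonstration only):
# the MMD estimate should be close to zero for two samples drawn from the same
# distribution, and clearly positive when the distributions differ.
def _compute_mmd_example():
    same_a = tf.random.normal([100, 4])
    same_b = tf.random.normal([100, 4])
    shifted = tf.random.normal([100, 4], mean=5.0)
    return compute_mmd([same_a, same_b]), compute_mmd([same_a, shifted])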


# Custom layers for efficiency/losses
class DenseTranspose(Layer):
    """
    Mirror of a given Dense layer that reuses its (transposed) kernel, which allows
    weight tying between an encoder layer and the corresponding decoder layer.
    """

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
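

# Illustrative usage sketch (layer sizes and activations are assumptions): a
# weight-tied autoencoder in which the decoder reuses the transposed kernel of the
# encoder through the DenseTranspose layer defined above.
def _dense_transpose_example(input_dim=32, latent_dim=8):
    inputs = tf.keras.Input(shape=(input_dim,))
    encoder = tf.keras.layers.Dense(latent_dim, activation="relu")
    encoded = encoder(inputs)
    decoded = DenseTranspose(encoder, output_dim=input_dim, activation="sigmoid")(encoded)
    return tf.keras.Model(inputs, decoded)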


class UncorrelatedFeaturesConstraint(Constraint):
    """
    Activity-based penalty that pushes the features of an encoding towards being
    uncorrelated, by penalising the off-diagonal entries of their covariance matrix.
    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, K.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
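

# Illustrative usage sketch (the layer sizes are assumptions, and attaching the
# penalty as an activity regularizer is one possible wiring; despite subclassing
# Constraint, the penalty above is computed from activations):
def _uncorrelated_features_example(input_dim=32, encoding_dim=8):
    inputs = tf.keras.Input(shape=(input_dim,))
    encoded = tf.keras.layers.Dense(
        encoding_dim,
        activation="relu",
        activity_regularizer=UncorrelatedFeaturesConstraint(encoding_dim, weightage=1.0),
    )(inputs)
    return tf.keras.Model(inputs, encoded)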


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds the KL divergence between the input
    distribution and a given prior to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
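

# Illustrative usage sketch (the prior, sizes and the probabilistic head are
# assumptions for demonstration): the layer is applied to a distribution-valued
# tensor, such as the output of a tfpl.IndependentNormal layer, and registers the
# KL term against a fixed prior as an extra loss and metric.
def _kl_divergence_layer_example(input_dim=32, latent_dim=8):
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0), reinterpreted_batch_ndims=1
    )
    kl_weight = K.variable(1.0, name="kl_weight")
    inputs = tf.keras.Input(shape=(input_dim,))
    params = tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(latent_dim))(inputs)
    z = tfpl.IndependentNormal(latent_dim)(params)
    z = KLDivergenceLayer(prior, weight=kl_weight)(z)
    return tf.keras.Model(inputs, z)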


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds the Maximum Mean Discrepancy (MMD)
    between the latent codes and a given prior to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
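

# Illustrative usage sketch (prior, batch size and dimensions are assumptions): the
# layer is inserted after the latent encoding and pulls the aggregate posterior
# towards a standard-normal prior through the MMD estimate defined above.
def _mmd_discrepancy_layer_example(input_dim=32, latent_dim=8, batch_size=64):
    prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0), reinterpreted_batch_ndims=1
    )
    mmd_beta = K.variable(1.0, name="mmd_beta")
    inputs = tf.keras.Input(shape=(input_dim,))
    z = tf.keras.layers.Dense(latent_dim)(inputs)
    z = MMDiscrepancyLayer(batch_size=batch_size, prior=prior, beta=mmd_beta)(z)
    return tf.keras.Model(inputs, z)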


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent
    Gaussian Mixture using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            -intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
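

# Illustrative usage sketch (shapes are assumptions inferred from the slicing in
# `call`): the layer expects the concatenated means and pre-softplus scales of every
# mixture component, i.e. a tensor of shape (batch, 2 * lat_dims, n_components).
def _gaussian_mixture_overlap_example(lat_dims=4, n_components=3):
    inputs = tf.keras.Input(shape=(2 * lat_dims, n_components))
    monitored = Gaussian_mixture_overlap(lat_dims, n_components)(inputs)
    return tf.keras.Model(inputs, monitored)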


class Latent_space_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, silhouette=False, loss=False, *args, **kwargs):
        self.loss = loss
        self.silhouette = silhouette
        super(Latent_space_control, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"loss": self.loss})
        config.update({"silhouette": self.silhouette})
        return config

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        # Adds Silhouette score controlling overlap between clusters
        if self.silhouette:
            hard_labels = tf.math.argmax(z_cat, axis=1)
            # silhouette_score returns a float64 scalar; cast it to match the declared Tout
            silhouette = tf.numpy_function(
                lambda x, y: np.float32(silhouette_score(x, y)),
                [z, hard_labels],
                tf.float32,
            )
            self.add_metric(silhouette, aggregation="mean", name="silhouette")

            if self.loss:
                self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])

        return z


class Entropy_regulariser(Layer):
    """
    Identity layer that adds cluster weight entropy to the loss function
    """

    def __init__(self, weight=1., *args, **kwargs):
        self.weight = weight
        super(Entropy_regulariser, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config

    def call(self, z, **kwargs):

        # Mask the -inf values that arise from log(0) before weighting by z
        entropy = K.sum(
            tf.multiply(z, tf.where(~tf.math.is_inf(K.log(z)), K.log(z), tf.zeros_like(z))),
            axis=0,
        )

        # Adds a metric that monitors the entropy of the cluster assignment weights
        self.add_metric(-entropy, aggregation="mean", name="weight_entropy")

        self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
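

# Illustrative usage sketch (shapes and the softmax head are assumptions): the layer
# acts on the categorical cluster-assignment weights (rows summing to one) and adds
# their entropy, scaled by `weight`, to the model loss.
def _entropy_regulariser_example(n_components=3):
    inputs = tf.keras.Input(shape=(n_components,))
    weights = tf.keras.layers.Activation("softmax")(inputs)
    regularised = Entropy_regulariser(weight=0.1)(weights)
    return tf.keras.Model(inputs, regularised)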