# @author lucasmiranda42

from itertools import combinations

import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.metrics import silhouette_score
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer

tfd = tfp.distributions
tfpl = tfp.layers


# Helper functions
def compute_kernel(x, y):
    """Returns the Gaussian (RBF) kernel matrix between two batches of
    vectors, with bandwidth scaled by the feature dimensionality."""
    x_size = K.shape(x)[0]
    y_size = K.shape(y)[0]
    dim = K.shape(x)[1]
    tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
    tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
    return K.exp(
        -tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32)
    )


def compute_mmd(tensors):
    """Computes the Maximum Mean Discrepancy (MMD) between two batches of
    samples, using the Gaussian kernel defined above."""

    x = tensors[0]
    y = tensors[1]

    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
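
# Example (an illustrative sketch, not part of the original module): the MMD
# between two batches drawn from the same distribution should be close to zero.
#
# x = tf.random.normal([64, 8])
# y = tf.random.normal([64, 8])
# mmd = compute_mmd([x, y])  # scalar tf.Tensor, near 0 here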


# Custom layers for efficiency/losses
class DenseTranspose(Layer):
    """
    Mirrors a given Dense layer, reusing its transposed kernel
    to tie encoder and decoder weights.
    """

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
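
# Example (an illustrative sketch, not part of the original module): tying a
# decoder layer to an encoder layer so both share the same kernel.
#
# encoder = tf.keras.layers.Dense(16, activation="relu")
# decoder = DenseTranspose(encoder, output_dim=32, activation="relu")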


class UncorrelatedFeaturesConstraint(Constraint):
    """
    Penalizes correlations between the features of a layer's output,
    encouraging an uncorrelated latent representation.
    """

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):

        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage,}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, K.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
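
# Example (an illustrative sketch; the hookup below is an assumption, not taken
# from this module): penalizing correlated activations in a latent layer.
#
# latent = tf.keras.layers.Dense(
#     8, activity_regularizer=UncorrelatedFeaturesConstraint(8, weightage=1.0)
# )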


class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    """
    Identity transform layer that adds the KL divergence
    to the prior to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(kl_batch, aggregation="mean", name="kl_divergence")
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")

        return distribution_a
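
# Example (an illustrative sketch, not part of the original module):
# regularizing an encoder that outputs a tfd.Distribution towards a
# standard-normal prior.
#
# prior = tfd.Independent(tfd.Normal(tf.zeros(8), 1.0), 1)
# z_distribution = KLDivergenceLayer(prior, weight=1.0)(z_distribution)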


class MMDiscrepancyLayer(Layer):
    """
    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss.
    """

    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.batch_size = batch_size
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"batch_size": self.batch_size})
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config

    def call(self, z, **kwargs):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd([true_samples, z])
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
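
# Example (an illustrative sketch, not part of the original module): pushing
# sampled latent codes towards a standard-normal prior via MMD.
#
# prior = tfd.Independent(tfd.Normal(tf.zeros(8), 1.0), 1)
# z = MMDiscrepancyLayer(batch_size=64, prior=prior, beta=1.0)(z)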


class Gaussian_mixture_overlap(Layer):
    """
    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
    using a specified metric (MMD, Wasserstein, Fisher-Rao)
    """

    def __init__(
        self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs
    ):
        self.lat_dims = lat_dims
        self.n_components = n_components
        self.loss = loss
        self.samples = samples
        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"lat_dims": self.lat_dims})
        config.update({"n_components": self.n_components})
        config.update({"loss": self.loss})
        config.update({"samples": self.samples})
        return config

    def call(self, target, loss=False):

        dists = []
        for k in range(self.n_components):
            locs = target[..., : self.lat_dims, k]
            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])

            dists.append(
                tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
            )

        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]

        ### MMD-based overlap ###
        intercomponent_mmd = K.mean(
            tf.convert_to_tensor(
                [
                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                    for c in combinations(range(len(dists)), 2)
                ],
                dtype=tf.float32,
            )
        )

        self.add_metric(
            intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
        )

        if self.loss:
            self.add_loss(-intercomponent_mmd, inputs=[target])

        return target
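
# Example (an illustrative sketch, not part of the original module): monitoring
# the overlap between the components of a mixture head whose output stacks
# means and pre-softplus scales along the last-but-one axis.
#
# overlap = Gaussian_mixture_overlap(lat_dims=8, n_components=5)
# mixture_params = overlap(mixture_params)  # shape (..., 2 * lat_dims, n_components)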


class Latent_space_control(Layer):
    """
    Identity layer that adds latent space and clustering stats
    to the metrics compiled by the model
    """

    def __init__(self, loss=False, *args, **kwargs):
        self.loss = loss
        super(Latent_space_control, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"loss": self.loss})
        return config

    def call(self, z, z_gauss, z_cat, **kwargs):

        # Adds metric that monitors dead neurons in the latent space
        self.add_metric(
            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
        )

        # Adds a Silhouette score metric, monitoring the overlap between clusters
        hard_labels = tf.math.argmax(z_cat, axis=1)
        silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
        self.add_metric(silhouette, aggregation="mean", name="silhouette")

        if self.loss:
            self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])

        return z
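
# Example (an illustrative sketch, not part of the original module): monitoring
# latent space statistics given the sampled code z, its Gaussian component, and
# the categorical cluster assignments.
#
# z = Latent_space_control()(z, z_gauss, z_cat)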