# @author lucasmiranda42

from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


# Helper functions
def sampling(args, epsilon_std=1.0, number_of_components=1, categorical=None):
    """Reparameterisation trick: draws a latent sample from the Gaussian
    distribution parameterised by the encoder outputs (z_mean, z_log_sigma)."""

    z_mean, z_log_sigma = args

    if number_of_components == 1:
        epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
        return z_mean + K.exp(z_log_sigma) * epsilon

    else:
        # Mixture of Gaussians encoding and sampling is not implemented yet;
        # fail loudly instead of silently returning None.
        raise NotImplementedError(
            "Mixture of Gaussians sampling is not yet implemented"
        )
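

# A minimal sketch of what the missing mixture-of-Gaussians branch could look
# like, assuming z_mean and z_log_sigma have shape
# (batch, number_of_components, latent_dim) and `categorical` holds per-sample
# mixture weights of shape (batch, number_of_components). The helper name
# `mixture_sampling` is hypothetical and only illustrates the idea.
def mixture_sampling(z_mean, z_log_sigma, categorical):
    mixture = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=categorical),
        components_distribution=tfd.MultivariateNormalDiag(
            loc=z_mean, scale_diag=K.exp(z_log_sigma)
        ),
    )
    # One latent sample per batch element, with shape (batch, latent_dim)
    return mixture.sample()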


def compute_kernel(x, y):
    """Pairwise Gaussian (RBF) kernel between the rows of x and y, with the
    bandwidth scaled by the feature dimension. Returns a (x_size, y_size) matrix."""
    x_size = K.shape(x)[0]
    y_size = K.shape(y)[0]
    dim = K.shape(x)[1]
    tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
    tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
    return K.exp(
        -tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32)
    )


def compute_mmd(x, y):
    """Empirical Maximum Mean Discrepancy (MMD) between the sample sets x and y,
    computed with the Gaussian kernel defined above."""
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )
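

# Hedged usage example (illustration only, not part of the original module):
# MMD stays close to zero when both batches come from the same distribution
# and grows as they diverge. `_mmd_demo` is a hypothetical helper.
def _mmd_demo():
    same = compute_mmd(tf.random.normal((256, 8)), tf.random.normal((256, 8)))
    shifted = compute_mmd(tf.random.normal((256, 8)), tf.random.normal((256, 8)) + 3.0)
    tf.print("MMD, same distribution:", same, "- shifted distribution:", shifted)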


# Custom layers for efficiency/losses
class DenseTranspose(Layer):
    """Dense layer that reuses (ties) the transposed weights of a given Dense
    layer, as in the decoder of a weight-tied autoencoder."""

    def __init__(self, dense, output_dim, activation=None, **kwargs):
        self.dense = dense
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)
        super().__init__(**kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "dense": self.dense,
                "output_dim": self.output_dim,
                "activation": self.activation,
            }
        )
        return config

    def build(self, batch_input_shape):
        self.biases = self.add_weight(
            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs, **kwargs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.biases)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.output_dim
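

# Hedged usage sketch (illustration only, not part of the original module):
# tying the decoder weights to the encoder with DenseTranspose. The helper
# name and layer sizes are arbitrary assumptions.
def _build_tied_autoencoder(input_dim=64, latent_dim=8):
    inputs = tf.keras.Input(shape=(input_dim,))
    encoder_dense = tf.keras.layers.Dense(latent_dim, activation="relu")
    encoded = encoder_dense(inputs)
    decoded = DenseTranspose(
        encoder_dense, output_dim=input_dim, activation="sigmoid"
    )(encoded)
    return tf.keras.Model(inputs, decoded)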


class UncorrelatedFeaturesConstraint(Constraint):
    """Penalises the off-diagonal covariance of the encoded features, pushing
    the latent dimensions towards being mutually uncorrelated."""

    def __init__(self, encoding_dim, weightage=1.0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {"encoding_dim": self.encoding_dim, "weightage": self.weightage}
        )
        return config

    def get_covariance(self, x):
        x_centered_list = []

        for i in range(self.encoding_dim):
            x_centered_list.append(x[:, i] - K.mean(x[:, i]))

        x_centered = tf.stack(x_centered_list)
        covariance = K.dot(x_centered, K.transpose(x_centered)) / tf.cast(
            x_centered.get_shape()[0], tf.float32
        )

        return covariance

    # Constraint penalty
    def uncorrelated_feature(self, x):
        if self.encoding_dim <= 1:
            return 0.0
        else:
            output = K.sum(
                K.square(
                    self.covariance
                    - tf.math.multiply(self.covariance, K.eye(self.encoding_dim))
                )
            )
            return output

    def __call__(self, x):
        self.covariance = self.get_covariance(x)
        return self.weightage * self.uncorrelated_feature(x)
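

# Hedged usage sketch (illustration only, not part of the original module):
# although it subclasses Constraint, the penalty above is a function of the
# layer's activations, so here it is assumed to be attached as an activity
# regularizer on the latent encoding layer. Names and sizes are arbitrary.
def _build_uncorrelated_encoder(input_dim=64, latent_dim=8):
    inputs = tf.keras.Input(shape=(input_dim,))
    encoding = tf.keras.layers.Dense(
        latent_dim,
        activation="relu",
        activity_regularizer=UncorrelatedFeaturesConstraint(latent_dim, weightage=1.0),
    )(inputs)
    return tf.keras.Model(inputs, encoding)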


class KLDivergenceLayer(Layer):
    """Identity transform layer that adds KL divergence
    to the final model loss."""

    def __init__(self, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        # beta scales the KL term (e.g. for beta-VAE style weighting/annealing)
        self.beta = beta
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"beta": self.beta})
        return config

    def call(self, inputs, **kwargs):
        mu, log_var = inputs

        # KL divergence between N(mu, exp(log_var)) and the standard normal prior
        KL_batch = (
            -0.5
            * self.beta
            * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
        )

        self.add_loss(K.mean(KL_batch), inputs=inputs)
        self.add_metric(KL_batch, aggregation="mean", name="kl_divergence")
        self.add_metric(self.beta, aggregation="mean", name="kl_rate")

        return inputs
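

# Hedged usage sketch (illustration only, not part of the original module):
# wiring KLDivergenceLayer and sampling() into a minimal VAE-style encoder.
# Layer sizes and names are arbitrary assumptions; beta is passed as a tensor
# so the layer can also log it as a metric.
def _build_kl_encoder(input_dim=64, latent_dim=8):
    inputs = tf.keras.Input(shape=(input_dim,))
    hidden = tf.keras.layers.Dense(32, activation="relu")(inputs)
    z_mean = tf.keras.layers.Dense(latent_dim)(hidden)
    z_log_sigma = tf.keras.layers.Dense(latent_dim)(hidden)
    # The layer returns its inputs unchanged and registers the KL loss/metrics
    z_mean, z_log_sigma = KLDivergenceLayer(beta=K.constant(1.0))(
        [z_mean, z_log_sigma]
    )
    z = tf.keras.layers.Lambda(sampling)([z_mean, z_log_sigma])
    return tf.keras.Model(inputs, z)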


class MMDiscrepancyLayer(Layer):
    """Identity transform layer that adds Maximum Mean Discrepancy (MMD)
    to the final model loss."""

    def __init__(self, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        # beta scales the MMD term in the total loss
        self.beta = beta
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"beta": self.beta})
        return config

    def call(self, z, **kwargs):
        # Compare the encoded samples z against samples from the standard normal prior
        true_samples = K.random_normal(K.shape(z))
        mmd_batch = self.beta * compute_mmd(true_samples, z)

        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

        return z
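

# Hedged usage sketch (illustration only, not part of the original module):
# attaching the MMD penalty to the latent code of an encoder, in the spirit of
# MMD-regularised autoencoders. Layer sizes and names are arbitrary assumptions.
def _build_mmd_encoder(input_dim=64, latent_dim=8):
    inputs = tf.keras.Input(shape=(input_dim,))
    hidden = tf.keras.layers.Dense(32, activation="relu")(inputs)
    z = tf.keras.layers.Dense(latent_dim)(hidden)
    z = MMDiscrepancyLayer(beta=K.constant(1.0))(z)
    return tf.keras.Model(inputs, z)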