Commit a3d522a2 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented Montecarlo Dropout custom layer

parent 7d1ad0b5
......@@ -55,42 +55,49 @@ def compute_mmd(tensors):
)
# Custom layers for efficiency/losses
class MCDropout(tf.keras.layers.Dropout):
def call(self, inputs, **kwargs):
return super().call(inputs, training=True)
class DenseTranspose(Layer):
def __init__(self, dense, output_dim, activation=None, **kwargs):
self.dense = dense
self.output_dim = output_dim
self.activation = tf.keras.activations.get(activation)
super().__init__(**kwargs)
def get_config(self):
config = super().get_config().copy()
config.update(
{
"dense": self.dense,
"output_dim": self.output_dim,
"activation": self.activation,
}
)
return config
def build(self, batch_input_shape):
self.biases = self.add_weight(
name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
)
super().build(batch_input_shape)
def call(self, inputs, **kwargs):
z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
return self.activation(z + self.biases)
def compute_output_shape(self, input_shape):
return input_shape[0], self.output_dim
# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
def __init__(
self,
iterations,
max_rate,
start_rate=None,
last_iterations=None,
last_rate=None,
):
self.iterations = iterations
self.max_rate = max_rate
self.start_rate = start_rate or max_rate / 10
self.last_iterations = last_iterations or iterations // 10 + 1
self.half_iteration = (iterations - self.last_iterations) // 2
self.last_rate = last_rate or self.start_rate / 1000
self.iteration = 0
def _interpolate(self, iter1, iter2, rate1, rate2):
return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1
def on_batch_begin(self, batch, logs):
if self.iteration < self.half_iteration:
rate = self._interpolate(
0, self.half_iteration, self.start_rate, self.max_rate
)
elif self.iteration < 2 * self.half_iteration:
rate = self._interpolate(
self.half_iteration,
2 * self.half_iteration,
self.max_rate,
self.start_rate,
)
else:
rate = self._interpolate(
2 * self.half_iteration,
self.iterations,
self.start_rate,
self.last_rate,
)
rate = max(rate, self.last_rate)
self.iteration += 1
K.set_value(self.model.optimizer.lr, rate)
class UncorrelatedFeaturesConstraint(Constraint):
......@@ -137,6 +144,12 @@ class UncorrelatedFeaturesConstraint(Constraint):
return self.weightage * self.uncorrelated_feature(x)
# Custom Layers
class MCDropout(tf.keras.layers.Dropout):
def call(self, inputs, **kwargs):
return super().call(inputs, training=True)
class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
def __init__(self, *args, **kwargs):
self.is_placeholder = True
......@@ -153,6 +166,38 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
return distribution_a
class DenseTranspose(Layer):
def __init__(self, dense, output_dim, activation=None, **kwargs):
self.dense = dense
self.output_dim = output_dim
self.activation = tf.keras.activations.get(activation)
super().__init__(**kwargs)
def get_config(self):
config = super().get_config().copy()
config.update(
{
"dense": self.dense,
"output_dim": self.output_dim,
"activation": self.activation,
}
)
return config
def build(self, batch_input_shape):
self.biases = self.add_weight(
name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
)
super().build(batch_input_shape)
def call(self, inputs, **kwargs):
z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
return self.activation(z + self.biases)
def compute_output_shape(self, input_shape):
return input_shape[0], self.output_dim
class MMDiscrepancyLayer(Layer):
"""
Identity transform layer that adds MM Discrepancy
......
......@@ -89,9 +89,7 @@ class SEQ_2_SEQ_AE:
)
# Decoder layers
Model_D0 = DenseTranspose(
Model_E5, activation="elu", output_dim=self.ENCODING,
)
Model_D0 = DenseTranspose(Model_E5, activation="elu", output_dim=self.ENCODING,)
Model_D1 = DenseTranspose(Model_E4, activation="elu", output_dim=self.DENSE_2,)
Model_D2 = DenseTranspose(Model_E3, activation="elu", output_dim=self.DENSE_1,)
Model_D3 = RepeatVector(self.input_shape[1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment