Commit 2b01ccdd authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented shuffle parameter in preprocessing; shuffled validation data in model_training.py

parent e46d76a7
......@@ -220,6 +220,9 @@ class Gaussian_mixture_overlap(Layer):
intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
)
if self.loss:
self.add_loss(-intercomponent_mmd, inputs=[target])
elif self.metric == "wasserstein":
pass
......@@ -232,9 +235,14 @@ class Latent_space_control(Layer):
to the metrics compiled by the model
"""
def __init__(self, *args, **kwargs):
def __init__(self, loss=False, *args, **kwargs):
self.loss = loss
super(Latent_space_control, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update({"loss": self.loss})
def call(self, z, z_gauss, z_cat, **kwargs):
# Adds metric that monitors dead neurons in the latent space
......@@ -245,7 +253,9 @@ class Latent_space_control(Layer):
# Adds Silhouette score controling overlap between clusters
hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
self.add_metric(silhouette, aggregation="mean", name="silhouette")
if self.loss:
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
return z
......@@ -168,6 +168,7 @@ class SEQ_2_SEQ_GMVAE:
number_of_components=1,
predictor=True,
overlap_metric="mmd",
overlap_loss=False,
):
self.input_shape = input_shape
self.CONV_filters = units_conv
......@@ -185,6 +186,7 @@ class SEQ_2_SEQ_GMVAE:
self.number_of_components = number_of_components
self.predictor = predictor
self.overlap_metric = overlap_metric
self.overlap_loss = overlap_loss
if self.prior == "standard_normal":
self.prior = tfd.mixture.Mixture(
......@@ -301,7 +303,10 @@ class SEQ_2_SEQ_GMVAE:
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z_gauss = Gaussian_mixture_overlap(
self.ENCODING, self.number_of_components, metric=self.overlap_metric
self.ENCODING,
self.number_of_components,
metric=self.overlap_metric,
loss=self.overlap_loss,
)(z_gauss)
z = tfpl.DistributionLambda(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment