Commit 2b01ccdd authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented shuffle parameter in preprocessing; shuffled validation data in model_training.py

parent e46d76a7
...@@ -220,6 +220,9 @@ class Gaussian_mixture_overlap(Layer): ...@@ -220,6 +220,9 @@ class Gaussian_mixture_overlap(Layer):
intercomponent_mmd, aggregation="mean", name="intercomponent_mmd" intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
) )
if self.loss:
self.add_loss(-intercomponent_mmd, inputs=[target])
elif self.metric == "wasserstein": elif self.metric == "wasserstein":
pass pass
...@@ -232,9 +235,14 @@ class Latent_space_control(Layer): ...@@ -232,9 +235,14 @@ class Latent_space_control(Layer):
to the metrics compiled by the model to the metrics compiled by the model
""" """
def __init__(self, *args, **kwargs): def __init__(self, loss=False, *args, **kwargs):
self.loss = loss
super(Latent_space_control, self).__init__(*args, **kwargs) super(Latent_space_control, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update({"loss": self.loss})
def call(self, z, z_gauss, z_cat, **kwargs): def call(self, z, z_gauss, z_cat, **kwargs):
# Adds metric that monitors dead neurons in the latent space # Adds metric that monitors dead neurons in the latent space
...@@ -245,7 +253,9 @@ class Latent_space_control(Layer): ...@@ -245,7 +253,9 @@ class Latent_space_control(Layer):
# Adds Silhouette score controling overlap between clusters # Adds Silhouette score controling overlap between clusters
hard_labels = tf.math.argmax(z_cat, axis=1) hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32) silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
self.add_metric(silhouette, aggregation="mean", name="silhouette") self.add_metric(silhouette, aggregation="mean", name="silhouette")
if self.loss:
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
return z return z
...@@ -168,6 +168,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -168,6 +168,7 @@ class SEQ_2_SEQ_GMVAE:
number_of_components=1, number_of_components=1,
predictor=True, predictor=True,
overlap_metric="mmd", overlap_metric="mmd",
overlap_loss=False,
): ):
self.input_shape = input_shape self.input_shape = input_shape
self.CONV_filters = units_conv self.CONV_filters = units_conv
...@@ -185,6 +186,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -185,6 +186,7 @@ class SEQ_2_SEQ_GMVAE:
self.number_of_components = number_of_components self.number_of_components = number_of_components
self.predictor = predictor self.predictor = predictor
self.overlap_metric = overlap_metric self.overlap_metric = overlap_metric
self.overlap_loss = overlap_loss
if self.prior == "standard_normal": if self.prior == "standard_normal":
self.prior = tfd.mixture.Mixture( self.prior = tfd.mixture.Mixture(
...@@ -301,7 +303,10 @@ class SEQ_2_SEQ_GMVAE: ...@@ -301,7 +303,10 @@ class SEQ_2_SEQ_GMVAE:
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss) z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z_gauss = Gaussian_mixture_overlap( z_gauss = Gaussian_mixture_overlap(
self.ENCODING, self.number_of_components, metric=self.overlap_metric self.ENCODING,
self.number_of_components,
metric=self.overlap_metric,
loss=self.overlap_loss,
)(z_gauss) )(z_gauss)
z = tfpl.DistributionLambda( z = tfpl.DistributionLambda(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment