Commit 436e9625 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented mmd between Gaussian mixture components to track cluster...

Implemented mmd between Gaussian mixture components to track cluster overlapping. Adding a term to the loss function is an option
parent 49c396c5
......@@ -240,9 +240,9 @@ class Latent_space_control(Layer):
)
# Adds Silhouette score controlling overlap between clusters
hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
self.add_metric(silhouette, aggregation="mean", name="silhouette")
# hard_labels = tf.math.argmax(z_cat, axis=1)
# silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
# self.add_metric(silhouette, aggregation="mean", name="silhouette")
if self.loss:
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
......
......@@ -355,7 +355,6 @@ for i in range(len(checkpoints) + 1):
if variational
else np.zeros(samples)
)
print(clusts)
encs.append(reducer.fit_transform(predictions[i], clusts))
else:
encs.append(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment