Commit b5d3c717 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added maximum entropy regulariser to the gaussian mixture weight layer

parent 9e646727
......@@ -232,7 +232,7 @@ DLC_social_1_coords = DLC_social_1.run(verbose=True)
DLC_social_2_coords = DLC_social_2.run(verbose=True)
# Coordinates for training data
coords1 = DLC_social_1_coords.get_coords(center="B_Center")
coords1 = DLC_social_1_coords.get_coords(center="B_Center", polar=True)
distances1 = DLC_social_1_coords.get_distances()
angles1 = DLC_social_1_coords.get_angles()
coords_distances1 = merge_tables(coords1, distances1)
......@@ -241,7 +241,7 @@ dists_angles1 = merge_tables(distances1, angles1)
coords_dist_angles1 = merge_tables(coords1, distances1, angles1)
# Coordinates for validation data
coords2 = DLC_social_2_coords.get_coords(center="B_Center")
coords2 = DLC_social_2_coords.get_coords(center="B_Center", polar=True)
distances2 = DLC_social_2_coords.get_distances()
angles2 = DLC_social_2_coords.get_angles()
coords_distances2 = merge_tables(coords2, distances2)
......
......@@ -268,3 +268,30 @@ class Latent_space_control(Layer):
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
return z
class Entropy_regulariser(Layer):
"""
Identity layer that adds cluster weight entropy to the loss function
"""
def __init__(self, weight=False, *args, **kwargs):
self.weight = weight
super(Entropy_regulariser, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update({"weight": self.weight})
def call(self, z, **kwargs):
entropy = K.sum(
tf.multiply(z, tf.where(~tf.math.is_inf(K.log(z)), K.log(z), 0)), axis=0
)
# Adds metric that monitors dead neurons in the latent space
self.add_metric(-entropy, aggregation="mean", name="weight_entropy")
self.add_loss(-K.mean(entropy), inputs=[z])
return z
......@@ -173,6 +173,7 @@ class SEQ_2_SEQ_GMVAE:
number_of_components=1,
predictor=True,
overlap_loss=False,
entropy_reg_weight=1,
):
self.input_shape = input_shape
self.batch_size = batch_size
......@@ -191,6 +192,7 @@ class SEQ_2_SEQ_GMVAE:
self.number_of_components = number_of_components
self.predictor = predictor
self.overlap_loss = overlap_loss
self.entropy_reg_weight = entropy_reg_weight
if self.prior == "standard_normal":
......@@ -302,6 +304,7 @@ class SEQ_2_SEQ_GMVAE:
encoder = BatchNormalization()(encoder)
z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
z_cat = Entropy_regulariser(self.entropy_reg_weight)(z_cat)
z_gauss = Dense(
tfpl.IndependentNormal.params_size(
self.ENCODING * self.number_of_components
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment