Commit 2d21ca43 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added latent space regularization options to GMVAE

parent 331b7014
Pipeline #93201 passed with stage
in 54 minutes and 28 seconds
......@@ -258,6 +258,8 @@ class SEQ_2_SEQ_GMVAE:
overlap_loss: float = False,
phenotype_prediction: float = 0.0,
predictor: float = 0.0,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
):
self.hparams = self.get_hparams(architecture_hparams)
self.batch_size = batch_size
......@@ -288,6 +290,8 @@ class SEQ_2_SEQ_GMVAE:
self.phenotype_prediction = phenotype_prediction
self.predictor = predictor
self.prior = "standard_normal"
self.reg_cat_clusters = reg_cat_clusters
self.reg_cluster_variance = reg_cluster_variance
assert (
"ELBO" in self.loss or "MMD" in self.loss
......@@ -564,7 +568,11 @@ class SEQ_2_SEQ_GMVAE:
z_cat = Dense(
self.number_of_components,
activation="softmax",
kernel_regularizer=tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01),
kernel_regularizer=(
tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
if self.reg_cat_clusters
else None
),
)(encoder)
if self.entropy_reg_weight > 0:
......@@ -572,14 +580,26 @@ class SEQ_2_SEQ_GMVAE:
z_cat
)
z_gauss = Dense(
z_gauss_mean = Dense(
tfpl.IndependentNormal.params_size(
self.ENCODING * self.number_of_components
),
)
// 2,
activation=None,
)(encoder)
z_gauss_var = Dense(
tfpl.IndependentNormal.params_size(
self.ENCODING * self.number_of_components
)
// 2,
activation=None,
)(
encoder
) # REMOVE BIAS FROM HERE! WHAT's THE POINT?
activity_regularizer=(
tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
),
)(encoder)
z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
......
......@@ -219,6 +219,42 @@ def huddle(
return hudd
def dig(
pos_dframe: pd.DataFrame,
speed_dframe: pd.DataFrame,
likelihood_dframe: pd.DataFrame,
tol_nose_speed: float,
tol_speed: float,
tol_likelihood: float,
animal_id: str = "",
):
pass
def sniff(
pos_dframe: pd.DataFrame,
speed_dframe: pd.DataFrame,
likelihood_dframe: pd.DataFrame,
tol_nose_speed: float,
tol_speed: float,
tol_likelihood: float,
animal_id: str = "",
):
pass
def look_around(
pos_dframe: pd.DataFrame,
speed_dframe: pd.DataFrame,
likelihood_dframe: pd.DataFrame,
tol_nose_speed: float,
tol_speed: float,
tol_likelihood: float,
animal_id: str = "",
):
pass
def following_path(
distance_dframe: pd.DataFrame,
position_dframe: pd.DataFrame,
......
......@@ -48,9 +48,18 @@ def plot_heatmap(
for i, bpart in enumerate(bodyparts):
heatmap = dframe[bpart]
if len(bodyparts) > 1:
sns.kdeplot(data=heatmap.x, data2=heatmap.y, cmap=None, shade=True, alpha=1, ax=ax[i])
sns.kdeplot(
data=heatmap.x,
data2=heatmap.y,
cmap=None,
shade=True,
alpha=1,
ax=ax[i],
)
else:
sns.kdeplot(data=heatmap.x, data2=heatmap.y, cmap=None, shade=True, alpha=1, ax=ax)
sns.kdeplot(
data=heatmap.x, data2=heatmap.y, cmap=None, shade=True, alpha=1, ax=ax
)
ax = np.array([ax])
[x.set_xlim(xlim) for x in ax]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment