Commit 0b6883c1 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 021a38fb
Pipeline #98266 failed with stages
in 17 minutes and 51 seconds
...@@ -187,7 +187,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -187,7 +187,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
number_of_components=k, number_of_components=k,
overlap_loss=self.overlap_loss, overlap_loss=self.overlap_loss,
phenotype_prediction=self.pheno_class, phenotype_prediction=self.pheno_class,
predictor=self.predictor, next_sequence_prediction=self.predictor,
).build(self.input_shape)[-3] ).build(self.input_shape)[-3]
return gmvaep return gmvaep
...@@ -254,8 +254,10 @@ class SEQ_2_SEQ_GMVAE: ...@@ -254,8 +254,10 @@ class SEQ_2_SEQ_GMVAE:
neuron_control: bool = False, neuron_control: bool = False,
number_of_components: int = 1, number_of_components: int = 1,
overlap_loss: float = 0.0, overlap_loss: float = 0.0,
next_sequence_prediction: float = 0.0,
phenotype_prediction: float = 0.0, phenotype_prediction: float = 0.0,
predictor: float = 0.0, rule_based_prediction: float = 0.0,
rule_based_features: int = 6,
reg_cat_clusters: bool = False, reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False, reg_cluster_variance: bool = False,
): ):
...@@ -283,8 +285,10 @@ class SEQ_2_SEQ_GMVAE: ...@@ -283,8 +285,10 @@ class SEQ_2_SEQ_GMVAE:
self.number_of_components = number_of_components self.number_of_components = number_of_components
self.optimizer = Nadam(lr=self.learn_rate, clipvalue=self.clipvalue) self.optimizer = Nadam(lr=self.learn_rate, clipvalue=self.clipvalue)
self.overlap_loss = overlap_loss self.overlap_loss = overlap_loss
self.next_sequence_prediction = next_sequence_prediction
self.phenotype_prediction = phenotype_prediction self.phenotype_prediction = phenotype_prediction
self.predictor = predictor self.rule_based_prediction = rule_based_prediction
self.rule_based_features = rule_based_features
self.prior = "standard_normal" self.prior = "standard_normal"
self.reg_cat_clusters = reg_cat_clusters self.reg_cat_clusters = reg_cat_clusters
self.reg_cluster_variance = reg_cluster_variance self.reg_cluster_variance = reg_cluster_variance
...@@ -334,7 +338,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -334,7 +338,7 @@ class SEQ_2_SEQ_GMVAE:
"""Sets the default parameters for the model. Overwritable with a dictionary""" """Sets the default parameters for the model. Overwritable with a dictionary"""
defaults = { defaults = {
"bidirectional_merge": "ave", "bidirectional_merge": "concat",
"clipvalue": 1.0, "clipvalue": 1.0,
"dense_activation": "relu", "dense_activation": "relu",
"dense_layers_per_branch": 3, "dense_layers_per_branch": 3,
...@@ -484,13 +488,20 @@ class SEQ_2_SEQ_GMVAE: ...@@ -484,13 +488,20 @@ class SEQ_2_SEQ_GMVAE:
merge_mode=self.bidirectional_merge, merge_mode=self.bidirectional_merge,
) )
# Phenotype classification layers # Phenotype classification layer
Model_PC1 = Dense( Model_PC1 = Dense(
self.number_of_components, self.number_of_components,
activation=self.dense_activation, activation=self.dense_activation,
kernel_initializer=he_uniform(), kernel_initializer=he_uniform(),
) )
# Rule based trait classification layer
Model_RC1 = Dense(
self.number_of_components,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)
return ( return (
Model_E0, Model_E0,
Model_E1, Model_E1,
...@@ -510,6 +521,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -510,6 +521,7 @@ class SEQ_2_SEQ_GMVAE:
Model_P2, Model_P2,
Model_P3, Model_P3,
Model_PC1, Model_PC1,
Model_RC1,
) )
def build(self, input_shape: Tuple): def build(self, input_shape: Tuple):
...@@ -538,6 +550,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -538,6 +550,7 @@ class SEQ_2_SEQ_GMVAE:
Model_P2, Model_P2,
Model_P3, Model_P3,
Model_PC1, Model_PC1,
Model_RC1,
) = self.get_layers(input_shape) ) = self.get_layers(input_shape)
# Define and instantiate encoder # Define and instantiate encoder
...@@ -668,7 +681,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -668,7 +681,7 @@ class SEQ_2_SEQ_GMVAE:
generator = Model_D4(generator) generator = Model_D4(generator)
generator = Model_B3(generator) generator = Model_B3(generator)
generator = Model_D5(generator) generator = Model_D5(generator)
# generator = Model_B4(generator) generator = Model_B4(generator)
generator = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))( generator = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
generator generator
) )
...@@ -692,7 +705,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -692,7 +705,7 @@ class SEQ_2_SEQ_GMVAE:
model_metrics = {"vae_reconstruction": ["mae", "mse"]} model_metrics = {"vae_reconstruction": ["mae", "mse"]}
loss_weights = [1.0] loss_weights = [1.0]
if self.predictor > 0: if self.next_sequence_prediction > 0:
# Define and instantiate predictor # Define and instantiate predictor
predictor = Dense( predictor = Dense(
self.DENSE_2, self.DENSE_2,
...@@ -706,7 +719,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -706,7 +719,7 @@ class SEQ_2_SEQ_GMVAE:
predictor = Model_P2(predictor) predictor = Model_P2(predictor)
predictor = BatchNormalization()(predictor) predictor = BatchNormalization()(predictor)
predictor = Model_P3(predictor) predictor = Model_P3(predictor)
# predictor = BatchNormalization()(predictor) predictor = BatchNormalization()(predictor)
predictor = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))( predictor = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
predictor predictor
) )
...@@ -719,7 +732,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -719,7 +732,7 @@ class SEQ_2_SEQ_GMVAE:
model_outs.append(x_predicted_mean) model_outs.append(x_predicted_mean)
model_losses.append(log_loss) model_losses.append(log_loss)
model_metrics["vae_prediction"] = ["mae", "mse"] model_metrics["vae_prediction"] = ["mae", "mse"]
loss_weights.append(self.predictor) loss_weights.append(self.next_sequence_prediction)
if self.phenotype_prediction > 0: if self.phenotype_prediction > 0:
pheno_pred = Model_PC1(z) pheno_pred = Model_PC1(z)
...@@ -735,6 +748,22 @@ class SEQ_2_SEQ_GMVAE: ...@@ -735,6 +748,22 @@ class SEQ_2_SEQ_GMVAE:
model_metrics["phenotype_prediction"] = ["AUC", "accuracy"] model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
loss_weights.append(self.phenotype_prediction) loss_weights.append(self.phenotype_prediction)
if self.rule_based_prediction > 0:
rule_pred = Model_RC1(z)
rule_pred = Dense(
tfpl.IndependentBernoulli.params_size(self.rule_based_features)
)(rule_pred)
rule_pred = tfpl.IndependentBernoulli(
event_shape=self.rule_based_features,
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="rule_based_prediction",
)(rule_pred)
model_outs.append(rule_pred)
model_losses.append(log_loss)
model_metrics["rule_based_prediction"] = ["AUC", "accuracy"]
loss_weights.append(self.rule_based_prediction)
# define grouper and end-to-end autoencoder model # define grouper and end-to-end autoencoder model
grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering") grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
gmvaep = Model( gmvaep = Model(
......
...@@ -51,7 +51,7 @@ def test_SEQ_2_SEQ_GMVAE_build( ...@@ -51,7 +51,7 @@ def test_SEQ_2_SEQ_GMVAE_build(
mmd_warmup_epochs=mmd_warmup_epochs, mmd_warmup_epochs=mmd_warmup_epochs,
montecarlo_kl=montecarlo_kl, montecarlo_kl=montecarlo_kl,
number_of_components=number_of_components, number_of_components=number_of_components,
predictor=True, next_sequence_prediction=True,
phenotype_prediction=True, phenotype_prediction=True,
overlap_loss=True, overlap_loss=True,
).build( ).build(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment