Commit 0b6883c1 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 021a38fb
Pipeline #98266 failed with stages
in 17 minutes and 51 seconds
......@@ -187,7 +187,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
number_of_components=k,
overlap_loss=self.overlap_loss,
phenotype_prediction=self.pheno_class,
predictor=self.predictor,
next_sequence_prediction=self.predictor,
).build(self.input_shape)[-3]
return gmvaep
......@@ -254,8 +254,10 @@ class SEQ_2_SEQ_GMVAE:
neuron_control: bool = False,
number_of_components: int = 1,
overlap_loss: float = 0.0,
next_sequence_prediction: float = 0.0,
phenotype_prediction: float = 0.0,
predictor: float = 0.0,
rule_based_prediction: float = 0.0,
rule_based_features: int = 6,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
):
......@@ -283,8 +285,10 @@ class SEQ_2_SEQ_GMVAE:
self.number_of_components = number_of_components
self.optimizer = Nadam(lr=self.learn_rate, clipvalue=self.clipvalue)
self.overlap_loss = overlap_loss
self.next_sequence_prediction = next_sequence_prediction
self.phenotype_prediction = phenotype_prediction
self.predictor = predictor
self.rule_based_prediction = rule_based_prediction
self.rule_based_features = rule_based_features
self.prior = "standard_normal"
self.reg_cat_clusters = reg_cat_clusters
self.reg_cluster_variance = reg_cluster_variance
......@@ -334,7 +338,7 @@ class SEQ_2_SEQ_GMVAE:
"""Sets the default parameters for the model. Overwritable with a dictionary"""
defaults = {
"bidirectional_merge": "ave",
"bidirectional_merge": "concat",
"clipvalue": 1.0,
"dense_activation": "relu",
"dense_layers_per_branch": 3,
......@@ -484,13 +488,20 @@ class SEQ_2_SEQ_GMVAE:
merge_mode=self.bidirectional_merge,
)
# Phenotype classification layers
# Phenotype classification layer
Model_PC1 = Dense(
self.number_of_components,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)
# Rule based trait classification layer
Model_RC1 = Dense(
self.number_of_components,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)
return (
Model_E0,
Model_E1,
......@@ -510,6 +521,7 @@ class SEQ_2_SEQ_GMVAE:
Model_P2,
Model_P3,
Model_PC1,
Model_RC1,
)
def build(self, input_shape: Tuple):
......@@ -538,6 +550,7 @@ class SEQ_2_SEQ_GMVAE:
Model_P2,
Model_P3,
Model_PC1,
Model_RC1,
) = self.get_layers(input_shape)
# Define and instantiate encoder
......@@ -668,7 +681,7 @@ class SEQ_2_SEQ_GMVAE:
generator = Model_D4(generator)
generator = Model_B3(generator)
generator = Model_D5(generator)
# generator = Model_B4(generator)
generator = Model_B4(generator)
generator = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
generator
)
......@@ -692,7 +705,7 @@ class SEQ_2_SEQ_GMVAE:
model_metrics = {"vae_reconstruction": ["mae", "mse"]}
loss_weights = [1.0]
if self.predictor > 0:
if self.next_sequence_prediction > 0:
# Define and instantiate predictor
predictor = Dense(
self.DENSE_2,
......@@ -706,7 +719,7 @@ class SEQ_2_SEQ_GMVAE:
predictor = Model_P2(predictor)
predictor = BatchNormalization()(predictor)
predictor = Model_P3(predictor)
# predictor = BatchNormalization()(predictor)
predictor = BatchNormalization()(predictor)
predictor = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
predictor
)
......@@ -719,7 +732,7 @@ class SEQ_2_SEQ_GMVAE:
model_outs.append(x_predicted_mean)
model_losses.append(log_loss)
model_metrics["vae_prediction"] = ["mae", "mse"]
loss_weights.append(self.predictor)
loss_weights.append(self.next_sequence_prediction)
if self.phenotype_prediction > 0:
pheno_pred = Model_PC1(z)
......@@ -735,6 +748,22 @@ class SEQ_2_SEQ_GMVAE:
model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
loss_weights.append(self.phenotype_prediction)
if self.rule_based_prediction > 0:
rule_pred = Model_RC1(z)
rule_pred = Dense(
tfpl.IndependentBernoulli.params_size(self.rule_based_features)
)(rule_pred)
rule_pred = tfpl.IndependentBernoulli(
event_shape=self.rule_based_features,
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="rule_based_prediction",
)(rule_pred)
model_outs.append(rule_pred)
model_losses.append(log_loss)
model_metrics["rule_based_prediction"] = ["AUC", "accuracy"]
loss_weights.append(self.rule_based_prediction)
# define grouper and end-to-end autoencoder model
grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
gmvaep = Model(
......
......@@ -51,7 +51,7 @@ def test_SEQ_2_SEQ_GMVAE_build(
mmd_warmup_epochs=mmd_warmup_epochs,
montecarlo_kl=montecarlo_kl,
number_of_components=number_of_components,
predictor=True,
next_sequence_prediction=True,
phenotype_prediction=True,
overlap_loss=True,
).build(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment