Commit f0902fce authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent b35d840c
Pipeline #99821 failed with stages
in 17 minutes and 25 seconds
......@@ -341,11 +341,11 @@ class SEQ_2_SEQ_GMVAE:
"bidirectional_merge": "concat",
"clipvalue": 1.0,
"dense_activation": "relu",
"dense_layers_per_branch": 1,
"dense_layers_per_branch": 3,
"dropout_rate": 0.05,
"learning_rate": 1e-3,
"units_conv": 64,
"units_dense2": 64,
"units_dense2": 32,
"units_lstm": 128,
}
......@@ -399,23 +399,15 @@ class SEQ_2_SEQ_GMVAE:
use_bias=True,
)
# Nested list comprehension adding a series of dense and batch norm layers
Model_E4 = [
i
for s in [
[
Dense(
self.DENSE_2,
activation=self.dense_activation,
# kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
use_bias=True,
),
BatchNormalization(),
]
for _ in range(self.dense_layers_per_branch)
]
for i in s
Dense(
self.DENSE_2,
activation=self.dense_activation,
# kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
use_bias=True,
)
for _ in range(self.dense_layers_per_branch)
]
# Decoder layers
......
......@@ -103,6 +103,7 @@ def get_callbacks(
run_ID = "{}{}{}{}{}{}{}{}{}{}{}{}".format(
("GMVAE" if variational else "AE"),
("_input_type={}".format(input_type) if input_type else "coords"),
("_window_size={}", format(X_train.shape[1])),
("_NextSeqPred={}".format(next_sequence_prediction) if variational else ""),
("_PhenoPred={}".format(phenotype_prediction) if variational else ""),
("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""),
......
......@@ -23,6 +23,7 @@ entropy_knn = [100]
next_sequence_pred_weights = [0.0, 0.15]
phenotype_pred_weights = [0.0]
rule_based_pred_weights = [0.0, 0.15]
window_lengths = range(11,56,11)
input_types = ["coords"]
run = list(range(1, 11))
......@@ -47,6 +48,7 @@ rule deepof_experiments:
expand(
outpath + "train_models/trained_weights/"
"GMVAE_input_type={input_type}_"
"window_size={window_size}_"
"NextSeqPred={nspredweight}_"
"PhenoPred={phenpredweight}_"
"RuleBasedPred={rulesweight}_"
......@@ -58,6 +60,7 @@ rule deepof_experiments:
"run={run}_"
"final_weights.h5",
input_type=input_types,
window_size=window_lengths,
loss=losses,
encs=encodings,
k=cluster_numbers,
......@@ -121,6 +124,7 @@ rule train_models:
output:
trained_models=outpath + "train_models/trained_weights/"
"GMVAE_input_type={input_type}_"
"window_size={window_size}_"
"NextSeqPred={nspredweight}_"
"PhenoPred={phenpredweight}_"
"RuleBasedPred={rulesweight}_"
......@@ -149,7 +153,7 @@ rule train_models:
"--encoding-size {wildcards.encs} "
"--entropy-knn {wildcards.entknn} "
"--batch-size 256 "
"--window-size 24 "
"--window-step 12 "
"--window-size {window_size} "
"--window-step 11 "
"--run {wildcards.run} "
"--output-path {outpath}train_models"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment