Commit 1ebd5f15 authored by lucas_miranda's avatar lucas_miranda
Browse files

Increased default dimensionality of latent space

parent 69196db8
......@@ -133,9 +133,9 @@ class GMVAE:
defaults = {
"bidirectional_merge": "concat",
"clipvalue": 1.0,
"clipvalue": 0.75,
"dense_activation": "relu",
"dense_layers_per_branch": 3,
"dense_layers_per_branch": 1,
"dropout_rate": 0.1,
"learning_rate": 1e-4,
"units_conv": 64,
......
......@@ -19,11 +19,11 @@ warmup_epochs = [15]
warmup_mode = ["sigmoid"]
losses = ["ELBO"] # , "MMD", "ELBO+MMD"]
overlap_loss = [0.1, 0.2, 0.5, 0.75, 1.]
encodings = [32] # [2, 4, 6, 8, 10, 12, 14, 16]
encodings = [16] # [2, 4, 6, 8, 10, 12, 14, 16]
cluster_numbers = [15] # [1, 5, 10, 15, 20, 25]
latent_reg = ["variance"] # ["none", "categorical", "variance", "categorical+variance"]
entropy_knn = [10]
next_sequence_pred_weights = [0.15]
next_sequence_pred_weights = [0.0]
phenotype_pred_weights = [0.0]
rule_based_pred_weights = [0.0]
window_lengths = [22] # range(11,56,11)
......
......@@ -630,7 +630,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
......@@ -647,7 +647,7 @@
" compile_model=True,\n",
" batch_size=batch_size,\n",
" encoding=encoding,\n",
" next_sequence_prediction=NextSeqPred,\n",
" next_sequence_prediction=0.1,\n",
" phenotype_prediction=PhenoPred,\n",
" rule_based_prediction=RuleBasedPred,\n",
").build(\n",
......@@ -658,11 +658,101 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"SEQ_2_SEQ_GMVAE\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_15 (InputLayer) [(None, 22, 26)] 0 \n",
"__________________________________________________________________________________________________\n",
"conv1d_24 (Conv1D) (None, 11, 64) 8384 input_15[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_94 (BatchNo (None, 11, 64) 256 conv1d_24[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_48 (Bidirectional (None, 11, 256) 148992 batch_normalization_94[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_95 (BatchNo (None, 11, 256) 1024 bidirectional_48[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_49 (Bidirectional (None, 128) 123648 batch_normalization_95[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_96 (BatchNo (None, 128) 512 bidirectional_49[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_88 (Dense) (None, 64) 8256 batch_normalization_96[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_97 (BatchNo (None, 64) 256 dense_88[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_8 (Dropout) (None, 64) 0 batch_normalization_97[0][0] \n",
"__________________________________________________________________________________________________\n",
"sequential_14 (Sequential) (None, 32) 2208 dropout_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"cluster_means (Dense) (None, 90) 2970 sequential_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"cluster_variances (Dense) (None, 90) 2970 sequential_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_14 (Concatenate) (None, 180) 0 cluster_means[0][0] \n",
" cluster_variances[0][0] \n",
"__________________________________________________________________________________________________\n",
"cluster_assignment (Dense) (None, 15) 495 sequential_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"reshape_8 (Reshape) (None, 12, 15) 0 concatenate_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"encoding_distribution (Distribu multiple 0 cluster_assignment[0][0] \n",
" reshape_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"kl_divergence_layer_6 (KLDiverg multiple 181 encoding_distribution[0][0] \n",
"__________________________________________________________________________________________________\n",
"latent_distribution (Lambda) multiple 0 kl_divergence_layer_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_97 (Dense) (None, 32) 224 latent_distribution[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_102 (BatchN (None, 32) 128 dense_97[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_92 (Dense) (None, 64) 2112 batch_normalization_102[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_103 (BatchN (None, 64) 256 dense_92[0][0] \n",
"__________________________________________________________________________________________________\n",
"repeat_vector_9 (RepeatVector) (None, 22, 64) 0 batch_normalization_103[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_52 (Bidirectional (None, 22, 256) 148992 repeat_vector_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_104 (BatchN (None, 22, 256) 1024 bidirectional_52[0][0] \n",
"__________________________________________________________________________________________________\n",
"bidirectional_53 (Bidirectional (None, 22, 256) 296448 batch_normalization_104[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_105 (BatchN (None, 22, 256) 1024 bidirectional_53[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv1d_26 (Conv1D) (None, 22, 64) 81984 batch_normalization_105[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_99 (Dense) (None, 22, 26) 1690 conv1d_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf.math.softplus_7 (TFOpLambda) (None, 22, 26) 0 dense_99[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_98 (Dense) (None, 22, 26) 1690 conv1d_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"lambda_7 (Lambda) (None, 22, 26) 0 tf.math.softplus_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_16 (Concatenate) (None, 22, 52) 0 dense_98[0][0] \n",
" lambda_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"vae_reconstruction (Functional) multiple 337940 latent_distribution[0][0] \n",
"__________________________________________________________________________________________________\n",
"vae_prediction (IndependentNorm multiple 0 concatenate_16[0][0] \n",
"==================================================================================================\n",
"Total params: 1,173,664\n",
"Trainable params: 1,170,271\n",
"Non-trainable params: 3,393\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"# Uncomment to see model summaries\n",
"# encoder.summary()\n",
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment