diff --git a/deepof/models.py b/deepof/models.py index ef71032152c4d0bfec8dd9649358b46185cfc0e6..6ad1b509eccab3540fa92574423a460b538c7f1f 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -133,9 +133,9 @@ class GMVAE: defaults = { "bidirectional_merge": "concat", - "clipvalue": 1.0, + "clipvalue": 0.75, "dense_activation": "relu", - "dense_layers_per_branch": 3, + "dense_layers_per_branch": 1, "dropout_rate": 0.1, "learning_rate": 1e-4, "units_conv": 64, diff --git a/deepof_experiments.smk b/deepof_experiments.smk index 3d33234df686f41a1fc08d43b1ead9ee6f8fcfa1..e7c701c3e0905c219ea0f4afaecdd839740b56d1 100644 --- a/deepof_experiments.smk +++ b/deepof_experiments.smk @@ -19,11 +19,11 @@ warmup_epochs = [15] warmup_mode = ["sigmoid"] losses = ["ELBO"] # , "MMD", "ELBO+MMD"] overlap_loss = [0.1, 0.2, 0.5, 0.75, 1.] -encodings = [32] # [2, 4, 6, 8, 10, 12, 14, 16] +encodings = [16] # [2, 4, 6, 8, 10, 12, 14, 16] cluster_numbers = [15] # [1, 5, 10, 15, 20, 25] latent_reg = ["variance"] # ["none", "categorical", "variance", "categorical+variance"] entropy_knn = [10] -next_sequence_pred_weights = [0.15] +next_sequence_pred_weights = [0.0] phenotype_pred_weights = [0.0] rule_based_pred_weights = [0.0] window_lengths = [22] # range(11,56,11) diff --git a/supplementary_notebooks/deepof_model_evaluation.ipynb b/supplementary_notebooks/deepof_model_evaluation.ipynb index b0cbd17fea2e6e14234689b7d3f6525b41b731aa..84679e4e9bfa8f8eaa726285eb0a391eb72270ff 100644 --- a/supplementary_notebooks/deepof_model_evaluation.ipynb +++ b/supplementary_notebooks/deepof_model_evaluation.ipynb @@ -630,7 +630,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -647,7 +647,7 @@ " compile_model=True,\n", " batch_size=batch_size,\n", " encoding=encoding,\n", - " next_sequence_prediction=NextSeqPred,\n", + " next_sequence_prediction=0.1,\n", " phenotype_prediction=PhenoPred,\n", " rule_based_prediction=RuleBasedPred,\n", ").build(\n", @@ -658,11 +658,101 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": { - "scrolled": true + "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"SEQ_2_SEQ_GMVAE\"\n", + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "input_15 (InputLayer) [(None, 22, 26)] 0 \n", + "__________________________________________________________________________________________________\n", + "conv1d_24 (Conv1D) (None, 11, 64) 8384 input_15[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_94 (BatchNo (None, 11, 64) 256 conv1d_24[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_48 (Bidirectional (None, 11, 256) 148992 batch_normalization_94[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_95 (BatchNo (None, 11, 256) 1024 bidirectional_48[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_49 (Bidirectional (None, 128) 123648 batch_normalization_95[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_96 (BatchNo (None, 128) 512 bidirectional_49[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_88 (Dense) (None, 64) 8256 batch_normalization_96[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_97 (BatchNo (None, 64) 256 dense_88[0][0] \n", + "__________________________________________________________________________________________________\n", + "dropout_8 (Dropout) (None, 64) 0 batch_normalization_97[0][0] \n", + "__________________________________________________________________________________________________\n", + "sequential_14 (Sequential) (None, 32) 2208 dropout_8[0][0] \n", + "__________________________________________________________________________________________________\n", + "cluster_means (Dense) (None, 90) 2970 sequential_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "cluster_variances (Dense) (None, 90) 2970 sequential_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_14 (Concatenate) (None, 180) 0 cluster_means[0][0] \n", + " cluster_variances[0][0] \n", + "__________________________________________________________________________________________________\n", + "cluster_assignment (Dense) (None, 15) 495 sequential_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "reshape_8 (Reshape) (None, 12, 15) 0 concatenate_14[0][0] \n", + "__________________________________________________________________________________________________\n", + "encoding_distribution (Distribu multiple 0 cluster_assignment[0][0] \n", + " reshape_8[0][0] \n", + "__________________________________________________________________________________________________\n", + "kl_divergence_layer_6 (KLDiverg multiple 181 encoding_distribution[0][0] \n", + "__________________________________________________________________________________________________\n", + "latent_distribution (Lambda) multiple 0 kl_divergence_layer_6[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_97 (Dense) (None, 32) 224 latent_distribution[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_102 (BatchN (None, 32) 128 dense_97[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_92 (Dense) (None, 64) 2112 batch_normalization_102[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_103 (BatchN (None, 64) 256 dense_92[0][0] \n", + "__________________________________________________________________________________________________\n", + "repeat_vector_9 (RepeatVector) (None, 22, 64) 0 batch_normalization_103[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_52 (Bidirectional (None, 22, 256) 148992 repeat_vector_9[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_104 (BatchN (None, 22, 256) 1024 bidirectional_52[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_53 (Bidirectional (None, 22, 256) 296448 batch_normalization_104[0][0] \n", + "__________________________________________________________________________________________________\n", + "batch_normalization_105 (BatchN (None, 22, 256) 1024 bidirectional_53[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv1d_26 (Conv1D) (None, 22, 64) 81984 batch_normalization_105[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_99 (Dense) (None, 22, 26) 1690 conv1d_26[0][0] \n", + "__________________________________________________________________________________________________\n", + "tf.math.softplus_7 (TFOpLambda) (None, 22, 26) 0 dense_99[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_98 (Dense) (None, 22, 26) 1690 conv1d_26[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_7 (Lambda) (None, 22, 26) 0 tf.math.softplus_7[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_16 (Concatenate) (None, 22, 52) 0 dense_98[0][0] \n", + " lambda_7[0][0] \n", + "__________________________________________________________________________________________________\n", + "vae_reconstruction (Functional) multiple 337940 latent_distribution[0][0] \n", + "__________________________________________________________________________________________________\n", + "vae_prediction (IndependentNorm multiple 0 concatenate_16[0][0] \n", + "==================================================================================================\n", + "Total params: 1,173,664\n", + "Trainable params: 1,170,271\n", + "Non-trainable params: 3,393\n", + "__________________________________________________________________________________________________\n" + ] + } + ], "source": [ "# Uncomment to see model summaries\n", "# encoder.summary()\n",