Commit bcc49658 authored by lucas_miranda

Modified GMVAEP - GRUs instead of LSTMs, stricter clipping, less deep, l1 regularization in cluster means, uniform initializer of variances
parent 1ebd5f15
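
The message bundles several architectural tweaks: swapping LSTMs for GRUs, tightening gradient clipping, reducing depth, penalizing cluster means with l1, and initializing variances uniformly. As rough orientation, a hedged sketch of what the first two changes typically look like in Keras (unit counts and the clipping threshold are illustrative assumptions, not values from this commit):

import tensorflow as tf

# GRUs have 3 gates to the LSTM's 4, so the swap changes parameter counts
# per recurrent layer. Unit counts here are illustrative only.
recurrent_old = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(128, return_sequences=True)
)
recurrent_new = tf.keras.layers.Bidirectional(
    tf.keras.layers.GRU(128, return_sequences=True)
)

# "Stricter clipping" plausibly means a tighter gradient-clipping bound on
# the optimizer; the 0.5 threshold below is an assumption.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, clipvalue=0.5)
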
@@ -15,7 +15,7 @@ import tensorflow_probability as tfp
 from tensorflow.keras import Input, Model, Sequential
 from tensorflow.keras.activations import softplus
 from tensorflow.keras.constraints import UnitNorm
-from tensorflow.keras.initializers import he_uniform
+from tensorflow.keras.initializers import he_uniform, random_uniform
 from tensorflow.keras.layers import BatchNormalization, Bidirectional
 from tensorflow.keras.layers import Dense, Dropout, GRU
 from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
@@ -155,7 +155,7 @@ class GMVAE:
 Model_E0 = tf.keras.layers.Conv1D(
 filters=self.CONV_filters,
 kernel_size=5,
 strides=2, # Increased strides to yield shorter sequences
 padding="same",
 activation=self.dense_activation,
 kernel_initializer=he_uniform(),
@@ -398,6 +398,7 @@ class GMVAE:
 // 2,
 name="cluster_means",
 activation=None,
+activity_regularizer=(tf.keras.regularizers.l1(10e-5)),
 kernel_initializer=he_uniform(),
 )(encoder)
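
For reference, an l1 activity regularizer penalizes the layer's outputs rather than its weights, nudging the cluster-mean activations toward sparsity; note that 10e-5 is 1e-4. A minimal, self-contained sketch of the same pattern (the 32-dim input and 90-unit width mirror the model summary below but are otherwise illustrative):

import tensorflow as tf

# l1 *activity* regularization: the penalty is computed on the layer's
# activations and surfaces in model.losses during training.
inputs = tf.keras.Input(shape=(32,))
means = tf.keras.layers.Dense(
    90,
    activation=None,
    activity_regularizer=tf.keras.regularizers.l1(10e-5),  # == 1e-4
)(inputs)
model = tf.keras.Model(inputs, means)

_ = model(tf.random.normal((4, 32)))
print(model.losses)  # the l1 penalty accumulated during the forward pass
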
@@ -411,6 +412,7 @@ class GMVAE:
 activity_regularizer=(
 tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
 ),
+kernel_initializer=random_uniform(),
 )(encoder)
 z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
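
random_uniform() with tf.keras defaults draws weights from U(-0.05, 0.05), so the variance head starts near zero regardless of fan-in, whereas He initialization scales its range as sqrt(6 / fan_in). A quick comparison (shapes are illustrative):

import tensorflow as tf
from tensorflow.keras.initializers import he_uniform, random_uniform

w_he = he_uniform()(shape=(32, 90))           # limit = sqrt(6 / 32) ≈ 0.43
w_uniform = random_uniform()(shape=(32, 90))  # default limit = 0.05

print(float(tf.reduce_max(tf.abs(w_he))))       # close to 0.43
print(float(tf.reduce_max(tf.abs(w_uniform))))  # bounded by 0.05
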
@@ -638,4 +640,4 @@ class GMVAE:
 # - Think about using spectral normalization
 # - REVISIT DROPOUT - CAN HELP WITH TRAINING STABILIZATION
 # - Decrease learning rate!
 # - Implement residual blocks!
\ No newline at end of file
@@ -52,11 +52,11 @@ def load_treatments(train_path):
 to be loaded as metadata in the coordinates class"""
 try:
 with open(
 os.path.join(
 train_path,
 [i for i in os.listdir(train_path) if i.endswith(".json")][0],
 ),
 "r",
 ) as handle:
 treatment_dict = json.load(handle)
 except IndexError:
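
This hunk only re-indents load_treatments, which opens the first .json file found in the training directory and parses it as a dict; the IndexError branch fires when the list comprehension finds no candidates. A standalone sketch of the pattern (the fallback return value is an assumption, since the diff cuts off before the except body):

import json
import os

def load_first_json(train_path: str) -> dict:
    """Parse the first .json file in train_path; fall back if none exists."""
    try:
        json_name = [i for i in os.listdir(train_path) if i.endswith(".json")][0]
        with open(os.path.join(train_path, json_name), "r") as handle:
            return json.load(handle)
    except IndexError:
        # Assumed fallback; the original except body is not shown in the diff.
        return {}
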
@@ -66,25 +66,25 @@ def load_treatments(train_path):
 def get_callbacks(
 X_train: np.array,
 batch_size: int,
 phenotype_prediction: float,
 next_sequence_prediction: float,
 rule_based_prediction: float,
 overlap_loss: float,
 loss: str,
 loss_warmup: int = 0,
 warmup_mode: str = "none",
 X_val: np.array = None,
 input_type: str = False,
 cp: bool = False,
 reg_cat_clusters: bool = False,
 reg_cluster_variance: bool = False,
 entropy_samples: int = 15000,
 entropy_knn: int = 100,
 logparam: dict = None,
 outpath: str = ".",
 run: int = False,
 ) -> List[Union[Any]]:
 """Generates callbacks for model training, including:
 - run_ID: run name, with coarse parameter details;
@@ -202,15 +202,15 @@ def log_hyperparameters(phenotype_class: float, rec: str):
 # noinspection PyUnboundLocalVariable
 def tensorboard_metric_logging(
 run_dir: str,
 hpms: Any,
 ae: Any,
 X_val: np.ndarray,
 y_val: np.ndarray,
 next_sequence_prediction: float,
 phenotype_prediction: float,
 rule_based_prediction: float,
 rec: str,
 ):
 """Autoencoder metric logging in tensorboard"""
@@ -270,35 +270,35 @@ def tensorboard_metric_logging(
 def autoencoder_fitting(
 preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
 batch_size: int,
 encoding_size: int,
 epochs: int,
 hparams: dict,
 kl_annealing_mode: str,
 kl_warmup: int,
 log_history: bool,
 log_hparams: bool,
 loss: str,
 mmd_annealing_mode: str,
 mmd_warmup: int,
 montecarlo_kl: int,
 n_components: int,
 output_path: str,
 overlap_loss: float,
 next_sequence_prediction: float,
 phenotype_prediction: float,
 rule_based_prediction: float,
 pretrained: str,
 save_checkpoints: bool,
 save_weights: bool,
 reg_cat_clusters: bool,
 reg_cluster_variance: bool,
 entropy_samples: int,
 entropy_knn: int,
 input_type: str,
 run: int = 0,
 strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(),
 ):
 """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
@@ -317,8 +317,8 @@ def autoencoder_fitting(
 # Generate validation dataset for callback usage
 X_val_dataset = (
 tf.data.Dataset.from_tensor_slices(X_val)
 .with_options(options)
 .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
 )
 # Defines what to log on tensorboard (useful for trying out different models)
@@ -361,7 +361,7 @@ def autoencoder_fitting(
 logparams, metrics = log_hyperparameters(phenotype_prediction, rec)
 with tf.summary.create_file_writer(
 os.path.join(output_path, "hparams", run_ID)
 ).as_default():
 hp.hparams_config(
 hparams=logparams,
@@ -422,28 +422,28 @@ def autoencoder_fitting(
 Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
 if phenotype_prediction > 0.0:
-ys += [y_train[-Xs.shape[0]:, 0]]
-yvals += [y_val[-Xvals.shape[0]:, 0]]
+ys += [y_train[-Xs.shape[0] :, 0]]
+yvals += [y_val[-Xvals.shape[0] :, 0]]
 # Remove the used column (phenotype) from both y arrays
 y_train = y_train[:, 1:]
 y_val = y_val[:, 1:]
 if rule_based_prediction > 0.0:
-ys += [y_train[-Xs.shape[0]:]]
-yvals += [y_val[-Xvals.shape[0]:]]
+ys += [y_train[-Xs.shape[0] :]]
+yvals += [y_val[-Xvals.shape[0] :]]
 # Convert data to tf.data.Dataset objects
 train_dataset = (
 tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
 .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
 .shuffle(buffer_size=X_train.shape[0])
 .with_options(options)
 )
 val_dataset = (
 tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
 .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
 .with_options(options)
 )
 ae.fit(
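
Throughout these hunks the datasets are batched with the global batch size, batch_size * strategy.num_replicas_in_sync, so each replica under MirroredStrategy receives batch_size samples, and drop_remainder=True guarantees every batch splits evenly across replicas. A minimal sketch of the same pipeline (array shapes are illustrative):

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one replica on CPU-only hosts
per_replica_batch = 64
global_batch = per_replica_batch * strategy.num_replicas_in_sync

X = np.random.rand(1000, 22, 26).astype("float32")  # (samples, window, features)
train = (
    tf.data.Dataset.from_tensor_slices(X)
    .shuffle(buffer_size=len(X))
    .batch(global_batch, drop_remainder=True)  # equal shards per replica
)
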
@@ -484,23 +484,23 @@ def autoencoder_fitting(
 def tune_search(
 data: List[np.array],
 encoding_size: int,
 hypertun_trials: int,
 hpt_type: str,
 k: int,
 kl_warmup_epochs: int,
 loss: str,
 mmd_warmup_epochs: int,
 overlap_loss: float,
 next_sequence_prediction: float,
 phenotype_prediction: float,
 rule_based_prediction: float,
 project_name: str,
 callbacks: List,
 n_epochs: int = 30,
 n_replicas: int = 1,
 outpath: str = ".",
 ) -> Union[bool, Tuple[Any, Any]]:
 """Define the search space using keras-tuner and bayesian optimization
@@ -592,16 +592,16 @@ def tune_search(
 Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
 if phenotype_prediction > 0.0:
-ys += [y_train[-Xs.shape[0]:, 0]]
-yvals += [y_val[-Xvals.shape[0]:, 0]]
+ys += [y_train[-Xs.shape[0] :, 0]]
+yvals += [y_val[-Xvals.shape[0] :, 0]]
 # Remove the used column (phenotype) from both y arrays
 y_train = y_train[:, 1:]
 y_val = y_val[:, 1:]
 if rule_based_prediction > 0.0:
-ys += [y_train[-Xs.shape[0]:]]
-yvals += [y_val[-Xvals.shape[0]:]]
+ys += [y_train[-Xs.shape[0] :]]
+yvals += [y_val[-Xvals.shape[0] :]]
 tuner.search(
 Xs,
...
@@ -2,9 +2,18 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 36,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"The autoreload extension is already loaded. To reload it, use:\n",
+" %reload_ext autoreload\n"
+]
+}
+],
 "source": [
 "%load_ext autoreload\n",
 "%autoreload 2"
@@ -54,7 +63,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 45,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -67,6 +76,7 @@
 "from collections import Counter\n",
 "from sklearn.preprocessing import StandardScaler\n",
 "\n",
+"from datetime import datetime\n",
 "from sklearn.manifold import TSNE\n",
 "from sklearn.decomposition import PCA\n",
 "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
@@ -615,7 +625,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 28,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -630,7 +640,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 26,
+"execution_count": 37,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -647,7 +657,7 @@
 " compile_model=True,\n",
 " batch_size=batch_size,\n",
 " encoding=encoding,\n",
-" next_sequence_prediction=0.1,\n",
+" next_sequence_prediction=NextSeqPred,\n",
 " phenotype_prediction=PhenoPred,\n",
 " rule_based_prediction=RuleBasedPred,\n",
 ").build(\n",
@@ -658,7 +668,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 27,
+"execution_count": 38,
 "metadata": {
 "scrolled": false
 },
@@ -671,84 +681,51 @@
 "__________________________________________________________________________________________________\n",
 "Layer (type) Output Shape Param # Connected to \n",
 "==================================================================================================\n",
-"input_15 (InputLayer) [(None, 22, 26)] 0 \n",
+"input_19 (InputLayer) [(None, 22, 26)] 0 \n",
 "__________________________________________________________________________________________________\n",
-"conv1d_24 (Conv1D) (None, 11, 64) 8384 input_15[0][0] \n",
+"conv1d_33 (Conv1D) (None, 11, 64) 8384 input_19[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"batch_normalization_94 (BatchNo (None, 11, 64) 256 conv1d_24[0][0] \n",
+"batch_normalization_118 (BatchN (None, 11, 64) 256 conv1d_33[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"bidirectional_48 (Bidirectional (None, 11, 256) 148992 batch_normalization_94[0][0] \n",
+"bidirectional_60 (Bidirectional (None, 11, 256) 197632 batch_normalization_118[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"batch_normalization_95 (BatchNo (None, 11, 256) 1024 bidirectional_48[0][0] \n",
+"batch_normalization_119 (BatchN (None, 11, 256) 1024 bidirectional_60[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"bidirectional_49 (Bidirectional (None, 128) 123648 batch_normalization_95[0][0] \n",
+"bidirectional_61 (Bidirectional (None, 128) 164352 batch_normalization_119[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"batch_normalization_96 (BatchNo (None, 128) 512 bidirectional_49[0][0] \n",
+"batch_normalization_120 (BatchN (None, 128) 512 bidirectional_61[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"dense_88 (Dense) (None, 64) 8256 batch_normalization_96[0][0] \n",
+"dense_109 (Dense) (None, 64) 8256 batch_normalization_120[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"batch_normalization_97 (BatchNo (None, 64) 256 dense_88[0][0] \n",
+"batch_normalization_121 (BatchN (None, 64) 256 dense_109[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"dropout_8 (Dropout) (None, 64) 0 batch_normalization_97[0][0] \n",
+"dropout_10 (Dropout) (None, 64) 0 batch_normalization_121[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"sequential_14 (Sequential) (None, 32) 2208 dropout_8[0][0] \n",
+"sequential_18 (Sequential) (None, 32) 2208 dropout_10[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"cluster_means (Dense) (None, 90) 2970 sequential_14[0][0] \n",
+"cluster_means (Dense) (None, 90) 2970 sequential_18[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"cluster_variances (Dense) (None, 90) 2970 sequential_14[0][0] \n",
+"cluster_variances (Dense) (None, 90) 2970 sequential_18[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"concatenate_14 (Concatenate) (None, 180) 0 cluster_means[0][0] \n",
+"concatenate_19 (Concatenate) (None, 180) 0 cluster_means[0][0] \n",
 " cluster_variances[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"cluster_assignment (Dense) (None, 15) 495 sequential_14[0][0] \n",
+"cluster_assignment (Dense) (None, 15) 495 sequential_18[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"reshape_8 (Reshape) (None, 12, 15) 0 concatenate_14[0][0] \n",
+"reshape_10 (Reshape) (None, 12, 15) 0 concatenate_19[0][0] \n",
 "__________________________________________________________________________________________________\n",
 "encoding_distribution (Distribu multiple 0 cluster_assignment[0][0] \n",
-" reshape_8[0][0] \n",
+" reshape_10[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"kl_divergence_layer_6 (KLDiverg multiple 181 encoding_distribution[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"latent_distribution (Lambda) multiple 0 kl_divergence_layer_6[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"dense_97 (Dense) (None, 32) 224 latent_distribution[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"batch_normalization_102 (BatchN (None, 32) 128 dense_97[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"dense_92 (Dense) (None, 64) 2112 batch_normalization_102[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"batch_normalization_103 (BatchN (None, 64) 256 dense_92[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"repeat_vector_9 (RepeatVector) (None, 22, 64) 0 batch_normalization_103[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"bidirectional_52 (Bidirectional (None, 22, 256) 148992 repeat_vector_9[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"batch_normalization_104 (BatchN (None, 22, 256) 1024 bidirectional_52[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"bidirectional_53 (Bidirectional (None, 22, 256) 296448 batch_normalization_104[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"batch_normalization_105 (BatchN (None, 22, 256) 1024 bidirectional_53[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"conv1d_26 (Conv1D) (None, 22, 64) 81984 batch_normalization_105[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"dense_99 (Dense) (None, 22, 26) 1690 conv1d_26[0][0] \n",
+"kl_divergence_layer_8 (KLDiverg multiple 181 encoding_distribution[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"tf.math.softplus_7 (TFOpLambda) (None, 22, 26) 0 dense_99[0][0] \n",
+"latent_distribution (Lambda) multiple 0 kl_divergence_layer_8[0][0] \n",
 "__________________________________________________________________________________________________\n",
-"dense_98 (Dense) (None, 22, 26) 1690 conv1d_26[0][0] \n",
+"vae_reconstruction (Functional) multiple 419092 latent_distribution[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"lambda_7 (Lambda) (None, 22, 26) 0 tf.math.softplus_7[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"concatenate_16 (Concatenate) (None, 22, 52) 0 dense_98[0][0] \n",
-" lambda_7[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"vae_reconstruction (Functional) multiple 337940 latent_distribution[0][0] \n",
-"__________________________________________________________________________________________________\n",
-"vae_prediction (IndependentNorm multiple 0 concatenate_16[0][0] \n",
 "==================================================================================================\n",
-"Total params: 1,173,664\n",
-"Trainable params: 1,170,271\n",
-"Non-trainable params: 3,393\n",
+"Total params: 808,588\n",
+"Trainable params: 806,411\n",
+"Non-trainable params: 2,177\n",
 "__________________________________________________________________________________________________\n"
 ]
 }
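
The Param # columns of the recurrent layers in both versions of this summary can be reproduced from the standard gate-count formulas, assuming tf.keras defaults (GRU with reset_after=True has 3 gates and two bias vectors; LSTM has 4 gates and one):

def gru_params(units: int, input_dim: int) -> int:
    # 3 gates; reset_after=True keeps separate input and recurrent biases
    return 3 * units * (input_dim + units) + 2 * 3 * units

def lstm_params(units: int, input_dim: int) -> int:
    # 4 gates, one bias vector each
    return 4 * units * (input_dim + units + 1)

print(2 * gru_params(128, 64))   # 148992 -> bidirectional_48 (old summary)
print(2 * lstm_params(128, 64))  # 197632 -> bidirectional_60 (new summary)
print(2 * gru_params(64, 256))   # 123648 -> bidirectional_49 (old summary)
print(2 * lstm_params(64, 256))  # 164352 -> bidirectional_61 (new summary)
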
@@ -763,7 +740,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 46,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -787,7 +764,7 @@
 "# plot_model(encoder, \"encoder\")\n",
 "# plot_model(decoder, \"decoder\")\n",