Commit 78db4095 authored by lucas_miranda

Changed back latent space variance activation

parent 3cbf030a
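
The diff below swaps relu for softplus when mapping the encoder's raw outputs to the scale of the latent Gaussians (it also restores tanh as the LSTM activation, the default combination that lets TensorFlow dispatch to its fused cuDNN kernel). A minimal sketch of why softplus is the safer choice for a scale parameter; the values are illustrative, not taken from the repo:

# Sketch (not repo code): softplus vs. relu as a scale activation.
import tensorflow as tf
from tensorflow.keras.activations import relu, softplus

raw = tf.constant([-3.0, -0.1, 0.0, 0.5, 2.0])  # unconstrained encoder outputs

# relu collapses every negative input to exactly 0, so only the 1e-5
# floor keeps the Normal scale valid, and gradients vanish for raw < 0.
scale_relu = relu(raw) + 1e-5

# softplus(x) = log(1 + exp(x)) is strictly positive and smooth, so the
# variance can shrink toward 0 without killing the gradient.
scale_softplus = softplus(raw) + 1e-5

print(scale_relu.numpy())
print(scale_softplus.numpy())
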
@@ -13,7 +13,7 @@ from typing import Any, Dict, Tuple
 import tensorflow as tf
 import tensorflow_probability as tfp
 from tensorflow.keras import Input, Model, Sequential
-from tensorflow.keras.activations import relu
+from tensorflow.keras.activations import softplus
 from tensorflow.keras.constraints import UnitNorm
 from tensorflow.keras.initializers import he_uniform, Orthogonal
 from tensorflow.keras.layers import BatchNormalization, Bidirectional
@@ -442,7 +442,7 @@ class SEQ_2_SEQ_GMVAE:
         Model_D5 = Bidirectional(
             LSTM(
                 self.LSTM_units_1,
-                activation="sigmoid",
+                activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
                 unroll=self.lstm_unroll,
@@ -605,7 +605,7 @@ class SEQ_2_SEQ_GMVAE:
         )(z_gauss)

         z = tfpl.DistributionLambda(
-            lambda gauss: tfd.mixture.Mixture(
+            make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
                 cat=tfd.categorical.Categorical(
                     probs=gauss[0],
                 ),
@@ -613,7 +613,7 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
-                            scale=relu(gauss[1][..., self.ENCODING :, k]) + 1e-5,
+                            scale=softplus(gauss[1][..., self.ENCODING :, k]) + 1e-5,
                         ),
                         reinterpreted_batch_ndims=1,
                     )
@@ -621,8 +621,11 @@ class SEQ_2_SEQ_GMVAE:
                 ],
             ),
             convert_to_tensor_fn="sample",
+            name="encoding_distribution",
         )([z_cat, z_gauss])

+        encode_to_distribution = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
+
         # Define and control custom loss functions
         if "ELBO" in self.loss:
             kl_warm_up_iters = tf.cast(
@@ -676,7 +679,7 @@ class SEQ_2_SEQ_GMVAE:
         )(generator)

         # define individual branches as models
-        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
+        encode_to_vector = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         generator = Model(g, x_decoded_mean, name="vae_reconstruction")

         def log_loss(x_true, p_x_q_given_z):
@@ -684,7 +687,7 @@ class SEQ_2_SEQ_GMVAE:
             the output distribution"""
             return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))

-        model_outs = [generator(encoder.outputs)]
+        model_outs = [generator(encode_to_vector.outputs)]
         model_losses = [log_loss]
         model_metrics = {"vae_reconstruction": ["mae", "mse"]}
         loss_weights = [1.0]
@@ -733,9 +736,9 @@ class SEQ_2_SEQ_GMVAE:
             loss_weights.append(self.phenotype_prediction)

         # define grouper and end-to-end autoencoder model
-        grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
+        grouper = Model(encode_to_vector.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
         gmvaep = Model(
-            inputs=encoder.inputs,
+            inputs=encode_to_vector.inputs,
             outputs=model_outs,
             name="SEQ_2_SEQ_GMVAE",
         )
@@ -751,7 +754,8 @@ class SEQ_2_SEQ_GMVAE:
         gmvaep.build(input_shape)

         return (
-            encoder,
+            encode_to_vector,
+            encode_to_distribution,
             generator,
             grouper,
             gmvaep,
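
For context, the DistributionLambda change names the lambda argument explicitly (make_distribution_fn=) and gives the mixture layer a name so it can be exposed through a second encoder head (encode_to_distribution). Below is a standalone sketch of the same pattern; the shapes, names, and input layout are illustrative assumptions, not the repo's actual values:

# Sketch (not repo code): a Gaussian-mixture latent space wrapped in a
# DistributionLambda, with softplus keeping each component's scale positive.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

ENCODING, K = 6, 4  # assumed latent dimension and number of mixture components

z_cat = tf.keras.Input(shape=(K,))                 # component probabilities (assumed normalized)
z_gauss = tf.keras.Input(shape=(2 * ENCODING, K))  # per-component loc and raw scale

z = tfpl.DistributionLambda(
    make_distribution_fn=lambda gauss: tfd.Mixture(
        cat=tfd.Categorical(probs=gauss[0]),
        components=[
            tfd.Independent(
                tfd.Normal(
                    loc=gauss[1][..., :ENCODING, k],
                    # softplus + epsilon keeps the scale strictly positive
                    scale=tf.math.softplus(gauss[1][..., ENCODING:, k]) + 1e-5,
                ),
                reinterpreted_batch_ndims=1,
            )
            for k in range(K)
        ],
    ),
    convert_to_tensor_fn="sample",  # downstream layers receive a sample
    name="encoding_distribution",
)([z_cat, z_gauss])

# Expose the distribution itself as a model output, mirroring the new
# encode_to_distribution head added in this commit.
encode_to_distribution = tf.keras.Model(
    [z_cat, z_gauss], z, name="SEQ_2_SEQ_trained_distribution"
)
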