Commit 7d89b57d authored by lucas_miranda

Added negative log likelihood of a Bernoulli distribution as reconstruction loss

parent b9d0905e
@@ -22,7 +22,7 @@ tfd = tfp.distributions
 tfpl = tfp.layers

-# Helper functions
+# Helper functions and classes

 class exponential_learning_rate(tf.keras.callbacks.Callback):
     """Simple class that allows to grow learning rate exponentially during training"""
@@ -517,6 +517,7 @@ class SEQ_2_SEQ_GMVAE:
     def build(self, input_shape: Tuple):
         """Builds the tf.keras model"""
+        print(input_shape)

         # Instantiate prior
         self.get_prior()
@@ -623,6 +624,7 @@ class SEQ_2_SEQ_GMVAE:
         z = deepof.model_utils.KLDivergenceLayer(
             self.prior,
             test_points_fn=lambda q: q.sample(self.mc_kl),
+            test_points_reduce_axis=0,
             weight=kl_beta,
         )(z)
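Note on the new keyword: deepof.model_utils.KLDivergenceLayer is called with the same arguments as tfp.layers.KLDivergenceAddLoss, so test_points_reduce_axis=0 plausibly averages the layer's Monte Carlo KL estimate over the sample axis created by q.sample(self.mc_kl). Below is a minimal sketch of that pattern using the stock TFP layer; the latent size, sample count, and standard-normal prior are illustrative stand-ins, not values taken from this repository:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

latent_dim = 8   # stand-in for the model's latent dimensionality
mc_kl = 10       # stand-in for self.mc_kl

# Standard-normal prior over the latent space (illustrative).
prior = tfd.Independent(
    tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0),
    reinterpreted_batch_ndims=1,
)

inputs = tf.keras.Input(shape=(16,))
params = tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(latent_dim))(inputs)
z = tfpl.IndependentNormal(latent_dim)(params)

# Penalise KL(q(z|x) || prior), estimated from mc_kl posterior samples.
# test_points_reduce_axis=0 averages the per-sample estimates over axis 0
# of the tensor returned by q.sample(mc_kl).
z = tfpl.KLDivergenceAddLoss(
    prior,
    test_points_fn=lambda q: q.sample(mc_kl),
    test_points_reduce_axis=0,
    weight=1.0,
)(z)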
@@ -652,13 +654,23 @@ class SEQ_2_SEQ_GMVAE:
         generator = Model_B3(generator)
         generator = Model_D5(generator)
         generator = Model_B4(generator)
-        x_decoded_mean = TimeDistributed(
-            Dense(input_shape[2]), name="vaep_reconstruction"
-        )(generator)
+        generator = TimeDistributed(Dense(input_shape[2]))(
+            generator
+        )
+        x_decoded_mean = tfpl.IndependentBernoulli(
+            event_shape=input_shape[2:],
+            convert_to_tensor_fn=tfp.distributions.Distribution.mean,
+            name="vae_reconstruction",
+        )(generator)
+
+        def log_loss(x_true, p_x_q_given_z):
+            """Computes the negative log likelihood of the data given
+            the output distribution"""
+            return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))

         model_outs = [x_decoded_mean]
-        model_losses = [Huber(delta=self.delta, reduction="mean")]
-        model_metrics = {"vaep_reconstruction": ["mae", "mse"]}
+        model_losses = [log_loss]
+        model_metrics = {"vae_reconstruction": ["mae", "mse"]}
         loss_weights = [1.0]

         if self.predictor > 0:
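What the reconstruction change above amounts to: the decoder's final TimeDistributed(Dense) layer now parameterises a tfpl.IndependentBernoulli output distribution (the Dense activations are read as logits), and training minimises the negative log likelihood of the data under that distribution instead of a Huber loss on point estimates. A self-contained sketch of the same head; the shapes and toy data are illustrative, and Bernoulli likelihoods assume targets scaled to [0, 1]:

import tensorflow as tf
import tensorflow_probability as tfp

tfpl = tfp.layers

timesteps, features = 10, 4  # stand-ins for input_shape[1] and input_shape[2]

def log_loss(x_true, p_x_q_given_z):
    """Negative log likelihood of the data under the output distribution."""
    return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))

inputs = tf.keras.Input(shape=(timesteps, features))
logits = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(features))(inputs)
outputs = tfpl.IndependentBernoulli(
    event_shape=(features,),
    convert_to_tensor_fn=tfp.distributions.Distribution.mean,
    name="vae_reconstruction",
)(logits)

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss=log_loss, metrics=["mae", "mse"])

# Toy targets in [0, 1]; in an autoencoder the input doubles as the target.
x = tf.random.uniform((32, timesteps, features))
model.fit(x, x, epochs=1, verbose=0)

Note that tf.reduce_sum aggregates the log likelihood over batch, time, and features, so the loss value scales with batch size; tf.reduce_mean would give a batch-size-independent figure.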
@@ -677,12 +689,14 @@ class SEQ_2_SEQ_GMVAE:
             predictor = Model_P3(predictor)
             predictor = BatchNormalization()(predictor)
             x_predicted_mean = TimeDistributed(
-                Dense(input_shape[2]), name="vaep_prediction"
+                Dense(input_shape[2]), name="vae_prediction"
             )(predictor)
             model_outs.append(x_predicted_mean)
-            model_losses.append(Huber(delta=self.delta, reduction="mean"))
-            model_metrics["vaep_prediction"] = ["mae", "mse"]
+            model_losses.append(
+                Huber(delta=self.delta, reduction="sum_over_batch_size")
+            )
+            model_metrics["vae_prediction"] = ["mae", "mse"]
             loss_weights.append(self.predictor)

         if self.phenotype_prediction > 0:
@@ -692,7 +706,7 @@ class SEQ_2_SEQ_GMVAE:
             )
             model_outs.append(pheno_pred)
-            model_losses.append(BinaryCrossentropy(reduction="mean"))
+            model_losses.append(BinaryCrossentropy(reduction="sum_over_batch_size"))
             model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
             loss_weights.append(self.phenotype_prediction)
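On the reduction="sum_over_batch_size" changes in the last two hunks: that string is one of the values tf.keras.losses.Reduction defines ("auto", "none", "sum", "sum_over_batch_size"), and it sums the per-sample losses and divides by their count, i.e. the usual batch mean; "mean" is not a member of that enum in the TF 2.x releases I am aware of, which presumably motivated the swap. A quick sanity check of the equivalence on toy tensors:

import tensorflow as tf

y_true = tf.constant([[0.0, 1.0], [1.0, 3.0]])
y_pred = tf.constant([[0.5, 1.0], [2.0, 0.0]])

# Per-sample Huber losses (no reduction), then the built-in reduction.
per_sample = tf.keras.losses.Huber(delta=1.0, reduction="none")(y_true, y_pred)
reduced = tf.keras.losses.Huber(
    delta=1.0, reduction="sum_over_batch_size"
)(y_true, y_pred)

# "sum_over_batch_size" equals the mean of the per-sample losses here.
tf.debugging.assert_near(reduced, tf.reduce_mean(per_sample))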