Commit 7d89b57d authored by lucas_miranda's avatar lucas_miranda
Browse files

Added negative log likelihood of a Bernoulli distribution as reconstruction loss

parent b9d0905e
......@@ -22,7 +22,7 @@ tfd = tfp.distributions
tfpl = tfp.layers
# Helper functions
# Helper functions and classes
class exponential_learning_rate(tf.keras.callbacks.Callback):
"""Simple class that allows to grow learning rate exponentially during training"""
......
......@@ -517,6 +517,7 @@ class SEQ_2_SEQ_GMVAE:
def build(self, input_shape: Tuple):
"""Builds the tf.keras model"""
print(input_shape)
# Instanciate prior
self.get_prior()
......@@ -623,6 +624,7 @@ class SEQ_2_SEQ_GMVAE:
z = deepof.model_utils.KLDivergenceLayer(
self.prior,
test_points_fn=lambda q: q.sample(self.mc_kl),
test_points_reduce_axis=0,
weight=kl_beta,
)(z)
......@@ -652,13 +654,23 @@ class SEQ_2_SEQ_GMVAE:
generator = Model_B3(generator)
generator = Model_D5(generator)
generator = Model_B4(generator)
x_decoded_mean = TimeDistributed(
Dense(input_shape[2]), name="vaep_reconstruction"
generator = TimeDistributed(Dense(input_shape[2]))(
generator
)
x_decoded_mean = tfpl.IndependentBernoulli(
event_shape=input_shape[2:],
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="vae_reconstruction",
)(generator)
def log_loss(x_true, p_x_q_given_z):
"""Computes the negative log likelihood of the data given
the output distribution"""
return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))
model_outs = [x_decoded_mean]
model_losses = [Huber(delta=self.delta, reduction="mean")]
model_metrics = {"vaep_reconstruction": ["mae", "mse"]}
model_losses = [log_loss]
model_metrics = {"vae_reconstruction": ["mae", "mse"]}
loss_weights = [1.0]
if self.predictor > 0:
......@@ -677,12 +689,14 @@ class SEQ_2_SEQ_GMVAE:
predictor = Model_P3(predictor)
predictor = BatchNormalization()(predictor)
x_predicted_mean = TimeDistributed(
Dense(input_shape[2]), name="vaep_prediction"
Dense(input_shape[2]), name="vae_prediction"
)(predictor)
model_outs.append(x_predicted_mean)
model_losses.append(Huber(delta=self.delta, reduction="mean"))
model_metrics["vaep_prediction"] = ["mae", "mse"]
model_losses.append(
Huber(delta=self.delta, reduction="sum_over_batch_size")
)
model_metrics["vae_prediction"] = ["mae", "mse"]
loss_weights.append(self.predictor)
if self.phenotype_prediction > 0:
......@@ -692,7 +706,7 @@ class SEQ_2_SEQ_GMVAE:
)
model_outs.append(pheno_pred)
model_losses.append(BinaryCrossentropy(reduction="mean"))
model_losses.append(BinaryCrossentropy(reduction="sum_over_batch_size"))
model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
loss_weights.append(self.phenotype_prediction)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment