Commit b9d0905e authored by lucas_miranda's avatar lucas_miranda
Browse files

Scaled loss functions with batch size

parent 3fd15b6c
......@@ -657,7 +657,7 @@ class SEQ_2_SEQ_GMVAE:
)(generator)
model_outs = [x_decoded_mean]
model_losses = [Huber(delta=self.delta, reduction="sum")]
model_losses = [Huber(delta=self.delta, reduction="mean")]
model_metrics = {"vaep_reconstruction": ["mae", "mse"]}
loss_weights = [1.0]
......@@ -681,7 +681,7 @@ class SEQ_2_SEQ_GMVAE:
)(predictor)
model_outs.append(x_predicted_mean)
model_losses.append(Huber(delta=self.delta, reduction="sum"))
model_losses.append(Huber(delta=self.delta, reduction="mean"))
model_metrics["vaep_prediction"] = ["mae", "mse"]
loss_weights.append(self.predictor)
......@@ -692,7 +692,7 @@ class SEQ_2_SEQ_GMVAE:
)
model_outs.append(pheno_pred)
model_losses.append(BinaryCrossentropy())
model_losses.append(BinaryCrossentropy(reduction="mean"))
model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
loss_weights.append(self.phenotype_prediction)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment