Commit dd890b89 authored by lucas_miranda

Added experimental phenotype classification branch to GMVAEP

parent 58553d5c
Pipeline #86478 passed with stage in 25 minutes and 11 seconds
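The new branch hangs a small classifier off the latent code: a Dense layer with ELU activation sized to the number of mixture components, followed by a single sigmoid unit trained with binary cross-entropy (see the hunks below). A minimal, self-contained sketch of that wiring; the sizes and the `nadam` optimizer string are illustrative assumptions, not values from this commit:

```python
from tensorflow.keras import Input, Model
from tensorflow.keras.initializers import he_uniform
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import BinaryCrossentropy

encoding, components = 6, 10  # illustrative sizes

# Stand-in for the GMVAE latent code z produced by the encoder
z = Input(shape=(encoding,), name="latent_code")

# Phenotype classification branch, mirroring the diff: an ELU layer
# sized to the number of mixture components, then one sigmoid unit
pheno = Dense(components, activation="elu", kernel_initializer=he_uniform())(z)
pheno = Dense(1, activation="sigmoid", name="phenotype_prediction")(pheno)

branch = Model(z, pheno, name="phenotype_branch_demo")
branch.compile(loss=BinaryCrossentropy(), optimizer="nadam")
```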
@@ -504,6 +504,7 @@ class Entropy_regulariser(Layer):
         # Adds metric that monitors dead neurons in the latent space
         self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

-        self.add_loss(self.weight * K.sum(entropy), inputs=[z])
+        if self.weight > 0:
+            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

         return z
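The change above guards `add_loss` so that a disabled regulariser (weight 0) registers no loss op at all, instead of a zero-valued one. A self-contained sketch of the pattern; the entropy computation shown here is a simplified placeholder, since the layer's full `call` body is not part of this hunk:

```python
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Layer


class Entropy_regulariser(Layer):
    """Penalises low-entropy activations on its input; pass-through otherwise."""

    def __init__(self, weight: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.weight = weight

    def call(self, z, **kwargs):
        # Placeholder entropy term: -sum(p * log p) over softmax(z)
        p = K.softmax(z, axis=-1)
        entropy = K.sum(-p * K.log(p + K.epsilon()), axis=-1)

        # Always expose the metric, even when the loss is inactive
        self.add_metric(entropy, aggregation="mean", name="-weight_entropy")

        # Only register the loss op when the regulariser is enabled
        if self.weight > 0:
            self.add_loss(self.weight * K.sum(entropy), inputs=[z])

        return z
```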
@@ -250,16 +250,17 @@ class SEQ_2_SEQ_GMVAE:
         self,
         architecture_hparams: dict = {},
         batch_size: int = 256,
-        loss: str = "ELBO+MMD",
+        compile: bool = True,
+        entropy_reg_weight: float = 0.0,
+        huber_delta: float = 1.0,
+        initialiser_iters: int = int(1e4),
         kl_warmup_epochs: int = 0,
+        loss: str = "ELBO+MMD",
         mmd_warmup_epochs: int = 0,
         number_of_components: int = 1,
-        predictor: float = 0.0,
         overlap_loss: bool = False,
-        entropy_reg_weight: float = 0.0,
-        initialiser_iters: int = int(1e4),
-        huber_delta: float = 1.0,
+        phenotype_prediction: float = 0.0,
+        predictor: float = 0.0,
     ):
         self.hparams = self.get_hparams(architecture_hparams)
         self.batch_size = batch_size
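With the alphabetised signature above, a call site might now read as follows; every hyperparameter value in this snippet is a placeholder, not a recommendation from the commit:

```python
# Hypothetical usage of the updated constructor
model_factory = SEQ_2_SEQ_GMVAE(
    batch_size=256,
    compile=False,             # new flag: skip the internal compile step
    kl_warmup_epochs=10,
    loss="ELBO+MMD",
    number_of_components=10,
    phenotype_prediction=1.0,  # new flag: weight of the phenotype BCE loss
)
```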
@@ -271,17 +272,18 @@ class SEQ_2_SEQ_GMVAE:
         self.DROPOUT_RATE = self.hparams["dropout_rate"]
         self.ENCODING = self.hparams["encoding"]
         self.learn_rate = self.hparams["learning_rate"]
-        self.loss = loss
-        self.prior = "standard_normal"
+        self.compile = compile
+        self.delta = huber_delta
+        self.entropy_reg_weight = entropy_reg_weight
+        self.initialiser_iters = initialiser_iters
         self.kl_warmup = kl_warmup_epochs
+        self.loss = loss
         self.mmd_warmup = mmd_warmup_epochs
         self.number_of_components = number_of_components
-        self.predictor = predictor
         self.overlap_loss = overlap_loss
-        self.entropy_reg_weight = entropy_reg_weight
-        self.initialiser_iters = initialiser_iters
-        self.delta = huber_delta
+        self.phenotype_prediction = phenotype_prediction
+        self.predictor = predictor
+        self.prior = "standard_normal"

         assert (
             "ELBO" in self.loss or "MMD" in self.loss
@@ -457,8 +459,9 @@ class SEQ_2_SEQ_GMVAE:
         )

         # Phenotype classification layers
-        Model_PC1 = Dense(self.number_of_components, activation="elu", kernel_initializer=he_uniform())
-        Model_PC2 = Dense(1, activation="sigmoid", kernel_initializer=he_uniform())
+        Model_PC1 = Dense(
+            self.number_of_components, activation="elu", kernel_initializer=he_uniform()
+        )

         return (
             Model_E0,
@@ -479,7 +482,6 @@ class SEQ_2_SEQ_GMVAE:
             Model_P2,
             Model_P3,
             Model_PC1,
-            Model_PC2
         )

     def build(self, input_shape: Tuple):
@@ -508,7 +510,6 @@ class SEQ_2_SEQ_GMVAE:
             Model_P2,
             Model_P3,
             Model_PC1,
-            Model_PC2,
         ) = self.get_layers(input_shape)

         # Define and instantiate encoder
@@ -554,7 +555,7 @@ class SEQ_2_SEQ_GMVAE:
                 tfd.Independent(
                     tfd.Normal(
                         loc=gauss[1][..., : self.ENCODING, k],
-                        scale=softplus(gauss[1][..., self.ENCODING :, k]),
+                        scale=softplus(gauss[1][..., self.ENCODING:, k]),
                     ),
                     reinterpreted_batch_ndims=1,
                 )
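For reference, the slice this hunk reformats splits the Gaussian parameter tensor along its channel axis: the first `ENCODING` entries are the component means and the remainder pass through a softplus, keeping the scales strictly positive. A standalone sketch with made-up shapes:

```python
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.activations import softplus

tfd = tfp.distributions

encoding, components = 6, 10                               # illustrative sizes
params = tf.random.normal((32, 2 * encoding, components))  # fake network output

k = 0  # pick the first mixture component
component_dist = tfd.Independent(
    tfd.Normal(
        loc=params[..., :encoding, k],
        scale=softplus(params[..., encoding:, k]),  # softplus keeps scale > 0
    ),
    reinterpreted_batch_ndims=1,  # fold the encoding axis into the event shape
)
print(component_dist.event_shape)  # (6,)
```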
@@ -610,7 +611,7 @@ class SEQ_2_SEQ_GMVAE:
         model_outs = [x_decoded_mean]
         model_losses = [Huber(delta=self.delta, reduction="sum")]
-        loss_weights = [1.]
+        loss_weights = [1.0]

         if self.predictor > 0:
             # Define and instantiate predictor
@@ -635,7 +636,7 @@ class SEQ_2_SEQ_GMVAE:
         if self.phenotype_prediction > 0:
             pheno_pred = Model_PC1(z)
-            pheno_pred = Model_PC2(pheno_pred)
+            pheno_pred = Dense(1, activation="sigmoid", name="phenotype_prediction")(pheno_pred)

             model_outs.append(pheno_pred)
             model_losses.append(BinaryCrossentropy())
@@ -646,11 +647,7 @@ class SEQ_2_SEQ_GMVAE:
         grouper = Model(x, z_cat, name="Deep_Gaussian_Mixture_clustering")

         # noinspection PyUnboundLocalVariable
-        gmvaep = Model(
-            inputs=x,
-            outputs=model_outs,
-            name="SEQ_2_SEQ_GMVAE",
-        )
+        gmvaep = Model(inputs=x, outputs=model_outs, name="SEQ_2_SEQ_GMVAE",)

         # Build generator as a separate entity
         g = Input(shape=self.ENCODING)
@@ -666,12 +663,13 @@ class SEQ_2_SEQ_GMVAE:
         _x_decoded_mean = TimeDistributed(Dense(input_shape[2]))(_generator)
         generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

-        gmvaep.compile(
-            loss=model_losses,
-            optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
-            metrics=["mae"],
-            loss_weights=loss_weights,
-        )
+        if self.compile:
+            gmvaep.compile(
+                loss=model_losses,
+                optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
+                metrics=["mae"],
+                loss_weights=loss_weights,
+            )

         gmvaep.build(input_shape)
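Wrapping `compile` behind the new flag is a standard deferred-compilation pattern: the graph is always built, but the optimizer and losses are attached only on request, so callers can load weights or substitute their own training setup first. A toy sketch of the idea (the one-layer model is illustrative, not the GMVAEP architecture):

```python
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Nadam


def build_model(compile_model: bool = True) -> Model:
    """Builds a toy model and compiles it only when requested."""
    x = Input(shape=(10,))
    y = Dense(1)(x)
    model = Model(x, y)
    if compile_model:
        # Mirrors the diff: Nadam with gradient value clipping
        model.compile(loss=Huber(), optimizer=Nadam(clipvalue=0.5), metrics=["mae"])
    return model


uncompiled = build_model(compile_model=False)  # caller compiles later
```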
@@ -691,7 +689,8 @@ class SEQ_2_SEQ_GMVAE:

 # TODO:
 #      - Check KL weight in the overall loss function! Are we scaling the loss components correctly?
-#      - Check batch and event shapes of all distributions involved. Incorrect shapes (batch >1) could bring problems with the KL.
+#      - Check batch and event shapes of all distributions involved. Incorrect shapes (batch >1) could bring
+#      problems with the KL.
 #      - Check merge mode in LSTM layers. Maybe we can drastically reduce model size!
 #      - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
 #      - Investigate posterior collapse (L1 as kernel/activity regulariser does not work)