Commit 547af774 authored by lucas_miranda

Reimplemented KL warmup using optimizer iterations; getting rid of the clumsy epoch-based callback

parent fdfa1da2
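
To make the change concrete, here is a minimal sketch (not part of the commit; all sizes are made up for illustration) contrasting the two schedules. The old `LambdaCallback` bumped the KL weight once per epoch, while the new layer derives it from the optimizer's gradient-step counter, so the ramp advances smoothly within epochs:

```python
# Hypothetical sizes, purely illustrative.
batches_per_epoch = 100
warm_up_epochs = 10
warm_up_iters = warm_up_epochs * batches_per_epoch

for step in (0, 50, 100, 500, 1000, 2000):
    epoch = step // batches_per_epoch
    w_epoch = min(epoch / warm_up_epochs, 1)  # old: piecewise-constant per epoch
    w_iter = min(step / warm_up_iters, 1)     # new: linear in optimizer iterations
    print(f"step {step:4d}  epoch-based {w_epoch:.2f}  iteration-based {w_iter:.2f}")
```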
@@ -871,7 +871,7 @@ class coordinates:
         reg_cat_clusters: bool = False,
         reg_cluster_variance: bool = False,
         entropy_samples: int = 10000,
-        entropy_min_n:int = 5,
+        entropy_min_n: int = 5,
     ) -> Tuple:
         """
         Annotates coordinates using an unsupervised autoencoder.
@@ -433,9 +433,12 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
     to the final model loss.
     """

-    def __init__(self, *args, **kwargs):
-        self.is_placeholder = True
+    def __init__(self, iters, warm_up_iters, *args, **kwargs):
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)
+        self.is_placeholder = True
+        self._iters = iters
+        self._warm_up_iters = warm_up_iters
+        self._regularizer._weight = K.min([self._iters / self._warm_up_iters, 1])

     def get_config(self):  # pragma: no cover
         """Updates Constraint metadata"""
@@ -280,6 +280,7 @@ class SEQ_2_SEQ_GMVAE:
         self.mmd_warmup = mmd_warmup_epochs
         self.neuron_control = neuron_control
         self.number_of_components = number_of_components
+        self.optimizer = Nadam(lr=self.learn_rate, clipvalue=self.clipvalue)
         self.overlap_loss = overlap_loss
         self.phenotype_prediction = phenotype_prediction
         self.predictor = predictor
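
Hoisting the `Nadam` into `__init__` is what makes this wiring possible: `optimizer.iterations` is a `tf.Variable` owned by that particular instance and incremented once per `apply_gradients` call, so `build` can hand it to the KL layer and `compile` must receive the same object. A quick standalone check (note that `lr` is the legacy alias of `learning_rate` in tf.keras):

```python
import tensorflow as tf

opt = tf.keras.optimizers.Nadam(lr=1e-3, clipvalue=0.5)
print(int(opt.iterations))  # 0, no gradient step applied yet

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
    loss = v ** 2
opt.apply_gradients([(tape.gradient(loss, v), v)])
print(int(opt.iterations))  # 1, incremented by apply_gradients
```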
@@ -625,21 +626,13 @@ class SEQ_2_SEQ_GMVAE:
         kl_warmup_callback = False
         if "ELBO" in self.loss:

-            kl_beta = deepof.model_utils.K.variable(1.0, name="kl_beta")
-            kl_beta._trainable = False
-            if self.kl_warmup:
-                kl_warmup_callback = LambdaCallback(
-                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
-                        kl_beta, deepof.model_utils.K.min([epoch / self.kl_warmup, 1])
-                    )
-                )
-
             # noinspection PyCallingNonCallable
             z = deepof.model_utils.KLDivergenceLayer(
                 self.prior,
                 test_points_fn=lambda q: q.sample(self.mc_kl),
                 test_points_reduce_axis=0,
-                weight=kl_beta,
+                iters=self.optimizer.iterations,
+                warm_up_iters=self.kl_warmup,
             )(z)

         mmd_warmup_callback = False
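
With the weight tied to `optimizer.iterations`, no scheduling callback is needed at all. A monitoring-only sketch (model, names, and numbers hypothetical, not deepof code) shows the counter, and hence the implied KL weight, advancing during an ordinary `fit`:

```python
import numpy as np
import tensorflow as tf

warm_up_iters = 20
opt = tf.keras.optimizers.Nadam()

inputs = tf.keras.Input(shape=(4,))
model = tf.keras.Model(inputs, tf.keras.layers.Dense(1)(inputs))
model.compile(loss="mse", optimizer=opt)

# Inspection only: the ramp itself no longer needs a callback.
monitor = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: print(
        f"epoch {epoch}: iterations={int(opt.iterations)}, "
        f"kl_weight={min(int(opt.iterations) / warm_up_iters, 1):.2f}"
    )
)

x = np.random.rand(80, 4).astype("float32")
y = np.random.rand(80, 1).astype("float32")
model.fit(x, y, batch_size=8, epochs=3, callbacks=[monitor], verbose=0)
# 10 steps per epoch: the weight reaches 0.50 after epoch 0, 1.00 from epoch 1 on
```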
@@ -659,7 +652,7 @@
             )(z)

         # Dummy layer with no parameters, to retrieve the previous tensor
-        z = tf.keras.layers.Lambda(lambda x: x, name="latent_distribution")(z)
+        z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)

         # Define and instantiate generator
         g = Input(shape=self.ENCODING)
@@ -749,10 +742,7 @@
         if self.compile:
             gmvaep.compile(
                 loss=model_losses,
-                optimizer=Nadam(
-                    lr=self.learn_rate,
-                    clipvalue=self.clipvalue,
-                ),
+                optimizer=self.optimizer,
                 metrics=model_metrics,
                 loss_weights=loss_weights,
             )
@@ -110,8 +110,9 @@ rule latent_regularization_experiments:
         "--batch-size 256 "
         "--window-size 24 "
         "--window-step 12 "
-        # "--exclude-bodyparts Tail_base,Tail_1,Tail_2,Tail_tip "
         "--output-path {outpath}latent_regularization_experiments"
+        # "--exclude-bodyparts Tail_base,Tail_1,Tail_2,Tail_tip "

rule explore_phenotype_classification: