Commit f851531e authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent 9ce11008
......@@ -288,22 +288,52 @@ input_dict_train = {
input_dict_val = {
"coords": coords2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11,
window_step=1,
scale=True,
random_state=42,
filter="gaussian",
sigma=110,
),
"dists": distances2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11,
window_step=1,
scale=True,
random_state=42,
filter="gaussian",
sigma=110,
),
"angles": angles2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11,
window_step=1,
scale=True,
random_state=42,
filter="gaussian",
sigma=110,
),
"coords+dist": coords_distances2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11,
window_step=1,
scale=True,
random_state=42,
filter="gaussian",
sigma=110,
),
"coords+angle": coords_angles2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11,
window_step=1,
scale=True,
random_state=42,
filter="gaussian",
sigma=110,
),
"coords+dist+angle": coords_dist_angles2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11,
window_step=1,
scale=True,
random_state=42,
filter="gaussian",
sigma=110,
),
}
......
......@@ -297,19 +297,6 @@ class SEQ_2_SEQ_GMVAE:
activation=None,
)(encoder)
# Define and control custom loss functions
kl_warmup_callback = False
if "ELBO" in self.loss:
kl_beta = K.variable(1.0, name="kl_beta")
kl_beta._trainable = False
if self.kl_warmup:
kl_warmup_callback = LambdaCallback(
on_epoch_begin=lambda epoch, logs: K.set_value(
kl_beta, K.min([epoch / self.kl_warmup, 1])
)
)
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z = tfpl.DistributionLambda(
lambda gauss: tfd.mixture.Mixture(
......@@ -328,7 +315,19 @@ class SEQ_2_SEQ_GMVAE:
activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
)([z_cat, z_gauss])
# Define and control custom loss functions
kl_warmup_callback = False
if "ELBO" in self.loss:
kl_beta = K.variable(1.0, name="kl_beta")
kl_beta._trainable = False
if self.kl_warmup:
kl_warmup_callback = LambdaCallback(
on_epoch_begin=lambda epoch, logs: K.set_value(
kl_beta, K.min([epoch / self.kl_warmup, 1])
)
)
z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
mmd_warmup_callback = False
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment