Commit 7aaa0341 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added encoding size as a CL parameter in train_model.py

parent 375a5b70
Pipeline #88512 passed with stage
in 19 minutes and 52 seconds
......@@ -586,7 +586,7 @@ class SEQ_2_SEQ_GMVAE:
tfd.Independent(
tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=softplus(gauss[1][..., self.ENCODING:, k]),
scale=softplus(gauss[1][..., self.ENCODING :, k]),
),
reinterpreted_batch_ndims=1,
)
......
......@@ -413,7 +413,9 @@ if not tune:
callbacks=callbacks_,
)
gmvaep.save_weights("{}_final_weights.h5".format(run_ID))
gmvaep.save_weights(
"GMVAE_loss={}_encoding={}_final_weights.h5".format(loss, encoding_size)
)
# To avoid stability issues
tf.keras.backend.clear_session()
......
......@@ -239,4 +239,4 @@ def tune_search(
print(tuner.results_summary())
return best_hparams, best_run
\ No newline at end of file
return best_hparams, best_run
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment