Commit 0f9d9ac7 authored by lucas_miranda's avatar lucas_miranda
Browse files

encoder, generator and full vae now instantiable from models.py

parent eff66fa2
......@@ -60,7 +60,9 @@ assert hyp in [
"S2SVAE-MMD",
], "Invalid hypermodel. Type python hyperparameter_tuning.py -h for help."
log_dir = os.path.abspath("logs/fit/{}_{}".format(hyp, datetime.now().strftime("%Y%m%d-%H%M%S")))
log_dir = os.path.abspath(
"logs/fit/{}_{}".format(hyp, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
with open(
......@@ -221,9 +223,11 @@ best_hyperparameters, best_model = tune_search(
hyp=hyp,
)
#Saves a compiled, untrained version of the best model
# Saves a compiled, untrained version of the best model
best_model.save("{}-based_{}_BAYESIAN_OPT.h5".format(input_type, hyp), save_format="tf")
#Saves the best hyperparameters
with open("{}-based_{}_BAYESIAN_OPT_params.pickle".format(input_type, hyp), "wb") as handle:
pickle.dump(best_hyperparameters, handle, protocol=pickle.HIGHEST_PROTOCOL)
\ No newline at end of file
# Saves the best hyperparameters
with open(
"{}-based_{}_BAYESIAN_OPT_params.pickle".format(input_type, hyp), "wb"
) as handle:
pickle.dump(best_hyperparameters, handle, protocol=pickle.HIGHEST_PROTOCOL)
This diff is collapsed.
This diff is collapsed.
......@@ -244,8 +244,16 @@ class SEQ_2_SEQ_VAE:
encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
vae = Model(x, x_decoded_mean, name="SEQ_2_SEQ_VAE")
#g = Input(shape=self.ENCODING)
#generator = Model(g, x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
# Build generator as a separate entity
g = Input(shape=self.ENCODING)
_generator = Model_D0(g)
_generator = Model_D1(_generator)
_generator = Model_D2(_generator)
_generator = Model_D3(_generator)
_generator = Model_D4(_generator)
_generator = Model_D5(_generator)
_x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
def huber_loss(x, x_decoded_mean):
huber_loss = Huber(reduction="sum", delta=100.0)
......@@ -258,7 +266,7 @@ class SEQ_2_SEQ_VAE:
experimental_run_tf_function=False,
)
return encoder, vae
return encoder, generator, vae
class SEQ_2_SEQ_MVAE:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment