diff --git a/deepof/train_model.py b/deepof/train_model.py index 4975c33ee8e5a9020ee1aa70d338359f6fd6c99a..c1ff5acad38760d00e325008397798e9c61c70b0 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -490,7 +490,7 @@ if not tune: history = gmvaep.fit( x=Xs, y=ys, - epochs=35, + epochs=1, batch_size=batch_size, verbose=1, validation_data=( @@ -510,25 +510,24 @@ if not tune: ) ) - if logparam is not None: - # Logparams to tensorboard - def run(run_dir, hpms): - with tf.summary.create_file_writer(run_dir).as_default(): - hp.hparams(hpms) # record the values used in this trial - val_mae = tf.reduce_mean( - tf.keras.metrics.mean_absolute_error( - X_val, gmvaep.predict(X_val) - ) + # Logparams to tensorboard + def run(run_dir, hpms): + with tf.summary.create_file_writer(run_dir).as_default(): + hp.hparams(hpms) # record the values used in this trial + val_mae = tf.reduce_mean( + tf.keras.metrics.mean_absolute_error( + X_val, gmvaep.predict(X_val) ) - val_mse = tf.reduce_mean( - tf.keras.metrics.mean_squared_error( - X_val, gmvaep.predict(X_val) - ) + ) + val_mse = tf.reduce_mean( + tf.keras.metrics.mean_squared_error( + X_val, gmvaep.predict(X_val) ) - tf.summary.scalar("val_mae", val_mae, step=1) - tf.summary.scalar("val_mse", val_mse, step=1) + ) + tf.summary.scalar("val_mae", val_mae, step=1) + tf.summary.scalar("val_mse", val_mse, step=1) - run(os.path.join(output_path, "hparams", run_ID), logparam) + run(os.path.join(output_path, "hparams", run_ID), logparam) # To avoid stability issues tf.keras.backend.clear_session() diff --git a/deepof_experiments.smk b/deepof_experiments.smk index 80c5a744de3b55ca2246d8e878aa34dfe27f7651..a6aca1538ff76ab6cd9bf77cc93a6174d575aa28 100644 --- a/deepof_experiments.smk +++ b/deepof_experiments.smk @@ -73,8 +73,8 @@ rule explore_encoding_dimension_and_loss_function: "--predictor 0 " "--variational True " "--loss {wildcards.loss} " - "--kl-warmup 2 " - "--mmd-warmup 2 " + "--kl-warmup 20 " + "--mmd-warmup 20 " "--montecarlo-kl 10 " "--encoding-size {wildcards.encs} " "--batch-size 256 "