Commit 03c57ea9 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented outlier interpolation

parent ba5967e6
Pipeline #93040 passed with stage
in 25 minutes and 48 seconds
......@@ -836,7 +836,7 @@ class coordinates:
"""
trained_models = deepof.train_utils.deep_unsupervised_embedding(
trained_models = deepof.train_utils.autoencoder_fitting(
preprocessed_object=preprocessed_object,
batch_size=batch_size,
encoding_size=encoding_size,
......
......@@ -139,7 +139,7 @@ def get_callbacks(
return callbacks
def deep_unsupervised_embedding(
def autoencoder_fitting(
preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
batch_size: int,
encoding_size: int,
......@@ -308,6 +308,9 @@ def deep_unsupervised_embedding(
],
)
if save_weights:
ae.save_weights("{}_final_weights.h5".format(run_ID))
else:
callbacks_ = cbacks + [
......@@ -340,7 +343,7 @@ def deep_unsupervised_embedding(
ae.fit(
x=Xs,
y=ys,
epochs=35,
epochs=2,
batch_size=batch_size,
verbose=1,
validation_data=(
......@@ -350,6 +353,9 @@ def deep_unsupervised_embedding(
callbacks=callbacks_,
)
if save_weights:
ae.save_weights("{}_final_weights.h5".format(run_ID))
if log_hparams:
# noinspection PyUnboundLocalVariable
def tensorboard_metric_logging(run_dir: str, hpms: Any):
......
This diff is collapsed.
......@@ -142,7 +142,7 @@ def test_tune_search(
loss=loss,
mmd_warmup_epochs=0,
overlap_loss=overlap_loss,
pheno_class=pheno_class,
phenotype_class=pheno_class,
predictor=predictor,
project_name="test_run",
callbacks=callbacks,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment