Commit 08963b26 authored by lucas_miranda

Added extra branch to main autoencoder for rule_based prediction

parent 77d9eba8
Pipeline #98419 failed with stages in 36 seconds
@@ -903,6 +903,7 @@ class coordinates:
reg_cluster_variance: bool = False,
entropy_samples: int = 10000,
entropy_knn: int = 100,
run: int = 0,
) -> Tuple:
"""
Annotates coordinates using an unsupervised autoencoder.
@@ -966,6 +967,7 @@ class coordinates:
reg_cluster_variance=reg_cluster_variance,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
run=run,
)
# returns a list of trained tensorflow models
......
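The hunks above thread the new run argument from the public coordinates.deep_unsupervised_embedding method down into the internal fitting call. Below is a minimal, self-contained sketch of that forwarding pattern; the names and signatures are simplified stand-ins, not deepof's actual API, and only the run keyword is taken from the diff:

```python
# Sketch of the forwarding pattern used above: a public method accepts a `run`
# identifier and passes it straight through to the internal fitting routine.
# Class and function names here are simplified placeholders, not deepof's API.
from typing import Tuple


def autoencoder_fitting_stub(entropy_knn: int, run: int = 0) -> Tuple[str, int]:
    # Stand-in for the internal fitting function that receives the run ID
    return ("trained-model", run)


class CoordinatesStub:
    def deep_unsupervised_embedding(self, entropy_knn: int = 100, run: int = 0) -> Tuple:
        # The new keyword is forwarded unchanged, as in the hunk above
        return autoencoder_fitting_stub(entropy_knn=entropy_knn, run=run)


print(CoordinatesStub().deep_unsupervised_embedding(run=1))  # ('trained-model', 1)
```

Keeping the default at 0 preserves the previous timestamp-based naming for callers that never pass run.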
@@ -231,6 +231,13 @@ parser.add_argument(
type=int,
default=5,
)
parser.add_argument(
"--run",
"-rid",
help="Sets the run ID of the experiment (for naming output files only). If 0 (default), uses a timestamp instead",
type=int,
default=0,
)
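A standalone sketch of the new flag, assuming it behaves exactly as declared above (the real training script's other arguments are omitted):

```python
# Standalone sketch of the new --run / -rid flag; the real script's other
# arguments are omitted here.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    "-rid",
    help=(
        "Sets the run ID of the experiment (for naming output files only). "
        "If 0 (default), uses a timestamp instead"
    ),
    type=int,
    default=0,
)

args = parser.parse_args(["--run", "3"])
print(args.run)  # 3
args = parser.parse_args([])
print(args.run)  # 0 -> downstream code falls back to a timestamp
```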
args = parser.parse_args()
@@ -264,6 +271,7 @@ val_num = args.val_num
variational = bool(args.variational)
window_size = args.window_size
window_step = args.window_step
run = args.run
if not train_path:
raise ValueError("Set a valid data path for the training to run")
@@ -395,6 +403,7 @@ if not tune:
reg_cluster_variance=("variance" in latent_reg),
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
run=run,
)
else:
@@ -416,6 +425,7 @@ else:
loss=loss,
logparam=logparam,
outpath=output_path,
run=run,
)
best_hyperparameters, best_model = tune_search(
......
@@ -99,28 +99,16 @@ def get_callbacks(
run_ID = "{}{}{}{}{}{}{}_{}".format(
("GMVAE" if variational else "AE"),
(
"_NextSeqPred={}".format(next_sequence_prediction)
if next_sequence_prediction > 0 and variational
else ""
),
(
"_PhenoPred={}".format(phenotype_prediction)
if phenotype_prediction > 0
else ""
),
(
"_RuleBasedPred={}".format(rule_based_prediction)
if rule_based_prediction > 0
else ""
),
("_NextSeqPred={}".format(next_sequence_prediction) if variational else ""),
("_PhenoPred={}".format(phenotype_prediction) if variational else ""),
("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""),
("_loss={}".format(loss) if variational else ""),
("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
("_k={}".format(logparam["k"]) if logparam is not None else ""),
("_latreg={}".format(latreg)),
("_entknn={}".format(entropy_knn)),
("_run={}".format(run) if run else ""),
(datetime.now().strftime("%Y%m%d-%H%M%S")),
("_{}".format(datetime.now().strftime("%Y%m%d-%H%M%S")) if not run else ""),
)
log_dir = os.path.abspath(os.path.join(outpath, "fit", run_ID))
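The net effect of the reworked run_ID is that a non-zero run replaces the trailing timestamp in the output-directory name, giving deterministic names across reruns. A simplified, self-contained reconstruction of just that fallback (the real name carries more fields, as shown above):

```python
# Simplified reconstruction of the run-ID naming fallback introduced above:
# a non-zero `run` is embedded in the directory name, otherwise a timestamp is.
# Only the fallback is reproduced; the full deepof run_ID has more components.
import os
from datetime import datetime


def make_run_id(variational: bool, k: int, run: int = 0) -> str:
    return "{}{}{}{}".format(
        "GMVAE" if variational else "AE",
        "_k={}".format(k),
        "_run={}".format(run) if run else "",
        "_{}".format(datetime.now().strftime("%Y%m%d-%H%M%S")) if not run else "",
    )


print(make_run_id(True, k=10, run=2))  # GMVAE_k=10_run=2 (stable across reruns)
print(make_run_id(True, k=10, run=0))  # e.g. GMVAE_k=10_20210301-120000
log_dir = os.path.abspath(os.path.join("output", "fit", make_run_id(True, 10, 2)))
```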
@@ -305,6 +293,7 @@ def autoencoder_fitting(
reg_cluster_variance: bool,
entropy_samples: int,
entropy_knn: int,
run: int = 0,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
@@ -340,6 +329,7 @@ reg_cluster_variance=reg_cluster_variance,
reg_cluster_variance=reg_cluster_variance,
logparam=logparam,
outpath=output_path,
run=run,
)
if not log_history:
cbacks = cbacks[1:]
@@ -459,8 +449,8 @@ def autoencoder_fitting(
y_val = y_val[:, 1:]
if rule_based_prediction > 0.0:
ys += [y_train[-Xs.shape[0]:]]
yvals += [y_val[-Xs.shape[0]:]]
ys += [y_train[-Xs.shape[0] :]]
yvals += [y_val[-Xs.shape[0] :]]
ae.fit(
x=Xs,
......
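The last hunk only reformats the slice to Black's spacing; functionally, y_train[-Xs.shape[0]:] keeps the trailing labels that line up with the training windows. A small NumPy illustration (the shapes are made up):

```python
# Small NumPy illustration of the slicing in the hunk above: keep only the
# last Xs.shape[0] label rows so targets line up with the training windows.
# Shapes are made up for illustration.
import numpy as np

Xs = np.zeros((5, 10, 4))        # 5 training windows
y_train = np.arange(8)           # 8 labels; only the last 5 match the windows
ys = [y_train[-Xs.shape[0] :]]   # -> array([3, 4, 5, 6, 7]); spacing as Black formats it
print(ys[0])
```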