Commit 4362b142 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 870f15a2
......@@ -172,17 +172,25 @@ parser.add_argument(
default=False,
)
parser.add_argument(
"--phenotype-classifier",
"-pheno",
"--next-sequence-prediction",
"-nspred",
help="Activates the next sequence prediction branch of the variational Seq 2 Seq model with the specified weight. "
"Defaults to 0.0 (inactive)",
default=0.0,
type=float,
)
parser.add_argument(
"--phenotype-prediction",
"-ppred",
help="Activates the phenotype classification branch with the specified weight. Defaults to 0.0 (inactive)",
default=0.0,
type=float,
)
parser.add_argument(
"--predictor",
"-pred",
help="Activates the prediction branch of the variational Seq 2 Seq model with the specified weight. "
"Defaults to 0.0 (inactive)",
"--rule-based-prediction",
"-rbpred",
help="Activates the rule-based trait prediction branch of the variational Seq 2 Seq model "
"with the specified weight Defaults to 0.0 (inactive)",
default=0.0,
type=float,
)
......@@ -246,8 +254,9 @@ mc_kl = args.montecarlo_kl
neuron_control = args.neuron_control
output_path = os.path.join(args.output_path)
overlap_loss = args.overlap_loss
pheno_class = float(args.phenotype_classifier)
predictor = float(args.predictor)
next_sequence_prediction = float(args.next_sequence_prediction)
phenotype_prediction = float(args.phenotype_prediction)
rule_based_prediction = float(args.rule_based_prediction)
smooth_alpha = args.smooth_alpha
train_path = os.path.abspath(args.train_path)
tune = args.hyperparameter_tuning
......@@ -282,8 +291,8 @@ logparam = {
"k": k,
"loss": loss,
}
if pheno_class:
logparam["pheno_weight"] = pheno_class
if phenotype_prediction:
logparam["pheno_weight"] = phenotype_prediction
# noinspection PyTypeChecker
project_coords = project(
......@@ -310,7 +319,7 @@ coords = project_coords.get_coords(
center=animal_id + undercond + "Center",
align=animal_id + undercond + "Spine_1",
align_inplace=True,
propagate_labels=(pheno_class > 0),
propagate_labels=(phenotype_prediction > 0),
)
distances = project_coords.get_distances()
angles = project_coords.get_angles()
......@@ -350,7 +359,7 @@ X_train, y_train, X_val, y_val = batch_preprocess(input_dict_train[input_type])
print("Training set shape:", X_train.shape)
print("Validation set shape:", X_val.shape)
if pheno_class > 0:
if phenotype_prediction > 0:
print("Training set label shape:", y_train.shape)
print("Validation set label shape:", y_val.shape)
......@@ -373,8 +382,8 @@ if not tune:
montecarlo_kl=mc_kl,
n_components=k,
output_path=output_path,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
phenotype_prediction=phenotype_prediction,
next_sequence_prediction=next_sequence_prediction,
save_checkpoints=False,
save_weights=True,
variational=variational,
......@@ -397,8 +406,8 @@ else:
variational=variational,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
phenotype_prediction=phenotype_prediction,
next_sequence_prediction=next_sequence_prediction,
loss=loss,
logparam=logparam,
outpath=output_path,
......@@ -415,8 +424,8 @@ else:
loss=loss,
mmd_warmup_epochs=mmd_wu,
overlap_loss=overlap_loss,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
phenotype_prediction=phenotype_prediction,
next_sequence_prediction=next_sequence_prediction,
project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
callbacks=[
tensorboard_callback,
......
......@@ -256,7 +256,9 @@ def tensorboard_metric_logging(
if phenotype_prediction:
idx = next(idx_generator)
pheno_acc = tf.keras.metrics.binary_accuracy(y_val[idx], tf.squeeze(outputs[idx]))
pheno_acc = tf.keras.metrics.binary_accuracy(
y_val[idx], tf.squeeze(outputs[idx])
)
pheno_auc = tf.keras.metrics.AUC()
pheno_auc.update_state(y_val[idx], outputs[idx])
pheno_auc = pheno_auc.result().numpy()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment