Commit 4362b142 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 870f15a2
...@@ -172,17 +172,25 @@ parser.add_argument( ...@@ -172,17 +172,25 @@ parser.add_argument(
default=False, default=False,
) )
parser.add_argument( parser.add_argument(
"--phenotype-classifier", "--next-sequence-prediction",
"-pheno", "-nspred",
help="Activates the next sequence prediction branch of the variational Seq 2 Seq model with the specified weight. "
"Defaults to 0.0 (inactive)",
default=0.0,
type=float,
)
parser.add_argument(
"--phenotype-prediction",
"-ppred",
help="Activates the phenotype classification branch with the specified weight. Defaults to 0.0 (inactive)", help="Activates the phenotype classification branch with the specified weight. Defaults to 0.0 (inactive)",
default=0.0, default=0.0,
type=float, type=float,
) )
parser.add_argument( parser.add_argument(
"--predictor", "--rule-based-prediction",
"-pred", "-rbpred",
help="Activates the prediction branch of the variational Seq 2 Seq model with the specified weight. " help="Activates the rule-based trait prediction branch of the variational Seq 2 Seq model "
"Defaults to 0.0 (inactive)", "with the specified weight Defaults to 0.0 (inactive)",
default=0.0, default=0.0,
type=float, type=float,
) )
...@@ -246,8 +254,9 @@ mc_kl = args.montecarlo_kl ...@@ -246,8 +254,9 @@ mc_kl = args.montecarlo_kl
neuron_control = args.neuron_control neuron_control = args.neuron_control
output_path = os.path.join(args.output_path) output_path = os.path.join(args.output_path)
overlap_loss = args.overlap_loss overlap_loss = args.overlap_loss
pheno_class = float(args.phenotype_classifier) next_sequence_prediction = float(args.next_sequence_prediction)
predictor = float(args.predictor) phenotype_prediction = float(args.phenotype_prediction)
rule_based_prediction = float(args.rule_based_prediction)
smooth_alpha = args.smooth_alpha smooth_alpha = args.smooth_alpha
train_path = os.path.abspath(args.train_path) train_path = os.path.abspath(args.train_path)
tune = args.hyperparameter_tuning tune = args.hyperparameter_tuning
...@@ -282,8 +291,8 @@ logparam = { ...@@ -282,8 +291,8 @@ logparam = {
"k": k, "k": k,
"loss": loss, "loss": loss,
} }
if pheno_class: if phenotype_prediction:
logparam["pheno_weight"] = pheno_class logparam["pheno_weight"] = phenotype_prediction
# noinspection PyTypeChecker # noinspection PyTypeChecker
project_coords = project( project_coords = project(
...@@ -310,7 +319,7 @@ coords = project_coords.get_coords( ...@@ -310,7 +319,7 @@ coords = project_coords.get_coords(
center=animal_id + undercond + "Center", center=animal_id + undercond + "Center",
align=animal_id + undercond + "Spine_1", align=animal_id + undercond + "Spine_1",
align_inplace=True, align_inplace=True,
propagate_labels=(pheno_class > 0), propagate_labels=(phenotype_prediction > 0),
) )
distances = project_coords.get_distances() distances = project_coords.get_distances()
angles = project_coords.get_angles() angles = project_coords.get_angles()
...@@ -350,7 +359,7 @@ X_train, y_train, X_val, y_val = batch_preprocess(input_dict_train[input_type]) ...@@ -350,7 +359,7 @@ X_train, y_train, X_val, y_val = batch_preprocess(input_dict_train[input_type])
print("Training set shape:", X_train.shape) print("Training set shape:", X_train.shape)
print("Validation set shape:", X_val.shape) print("Validation set shape:", X_val.shape)
if pheno_class > 0: if phenotype_prediction > 0:
print("Training set label shape:", y_train.shape) print("Training set label shape:", y_train.shape)
print("Validation set label shape:", y_val.shape) print("Validation set label shape:", y_val.shape)
...@@ -373,8 +382,8 @@ if not tune: ...@@ -373,8 +382,8 @@ if not tune:
montecarlo_kl=mc_kl, montecarlo_kl=mc_kl,
n_components=k, n_components=k,
output_path=output_path, output_path=output_path,
phenotype_prediction=pheno_class, phenotype_prediction=phenotype_prediction,
next_sequence_prediction=predictor, next_sequence_prediction=next_sequence_prediction,
save_checkpoints=False, save_checkpoints=False,
save_weights=True, save_weights=True,
variational=variational, variational=variational,
...@@ -397,8 +406,8 @@ else: ...@@ -397,8 +406,8 @@ else:
variational=variational, variational=variational,
entropy_samples=entropy_samples, entropy_samples=entropy_samples,
entropy_knn=entropy_knn, entropy_knn=entropy_knn,
phenotype_prediction=pheno_class, phenotype_prediction=phenotype_prediction,
next_sequence_prediction=predictor, next_sequence_prediction=next_sequence_prediction,
loss=loss, loss=loss,
logparam=logparam, logparam=logparam,
outpath=output_path, outpath=output_path,
...@@ -415,8 +424,8 @@ else: ...@@ -415,8 +424,8 @@ else:
loss=loss, loss=loss,
mmd_warmup_epochs=mmd_wu, mmd_warmup_epochs=mmd_wu,
overlap_loss=overlap_loss, overlap_loss=overlap_loss,
phenotype_prediction=pheno_class, phenotype_prediction=phenotype_prediction,
next_sequence_prediction=predictor, next_sequence_prediction=next_sequence_prediction,
project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()), project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
callbacks=[ callbacks=[
tensorboard_callback, tensorboard_callback,
......
...@@ -256,7 +256,9 @@ def tensorboard_metric_logging( ...@@ -256,7 +256,9 @@ def tensorboard_metric_logging(
if phenotype_prediction: if phenotype_prediction:
idx = next(idx_generator) idx = next(idx_generator)
pheno_acc = tf.keras.metrics.binary_accuracy(y_val[idx], tf.squeeze(outputs[idx])) pheno_acc = tf.keras.metrics.binary_accuracy(
y_val[idx], tf.squeeze(outputs[idx])
)
pheno_auc = tf.keras.metrics.AUC() pheno_auc = tf.keras.metrics.AUC()
pheno_auc.update_state(y_val[idx], outputs[idx]) pheno_auc.update_state(y_val[idx], outputs[idx])
pheno_auc = pheno_auc.result().numpy() pheno_auc = pheno_auc.result().numpy()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment