Commit 0c3e9955 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent a462771c
......@@ -892,8 +892,9 @@ class coordinates:
montecarlo_kl: int = 10,
n_components: int = 25,
output_path: str = ".",
phenotype_class: float = 0,
predictor: float = 0,
next_sequence_prediction: float = 0,
phenotype_prediction: float = 0,
rule_based_prediction: float = 0,
pretrained: str = False,
save_checkpoints: bool = False,
save_weights: bool = True,
......@@ -954,8 +955,9 @@ class coordinates:
montecarlo_kl=montecarlo_kl,
n_components=n_components,
output_path=output_path,
phenotype_class=phenotype_class,
predictor=predictor,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
pretrained=pretrained,
save_checkpoints=save_checkpoints,
save_weights=save_weights,
......
......@@ -373,8 +373,8 @@ if not tune:
montecarlo_kl=mc_kl,
n_components=k,
output_path=output_path,
phenotype_class=pheno_class,
predictor=predictor,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
save_checkpoints=False,
save_weights=True,
variational=variational,
......
......@@ -259,8 +259,9 @@ def autoencoder_fitting(
montecarlo_kl: int,
n_components: int,
output_path: str,
phenotype_class: float,
predictor: float,
next_sequence_prediction: float,
phenotype_prediction: float,
rule_based_prediction: float,
pretrained: str,
save_checkpoints: bool,
save_weights: bool,
......@@ -284,8 +285,8 @@ def autoencoder_fitting(
"k": n_components,
"loss": loss,
}
if phenotype_class:
logparam["pheno_weight"] = phenotype_class
if phenotype_prediction:
logparam["pheno_weight"] = phenotype_prediction
# Load callbacks
run_ID, *cbacks = get_callbacks(
......@@ -294,8 +295,8 @@ def autoencoder_fitting(
batch_size=batch_size,
cp=save_checkpoints,
variational=variational,
phenotype_class=phenotype_class,
predictor=predictor,
phenotype_class=phenotype_prediction,
predictor=next_sequence_prediction,
loss=loss,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
......@@ -308,9 +309,9 @@ def autoencoder_fitting(
cbacks = cbacks[1:]
# Logs hyperparameters to tensorboard
rec = "reconstruction_" if phenotype_class else ""
rec = "reconstruction_" if phenotype_prediction else ""
if log_hparams:
logparams, metrics = log_hyperparameters(phenotype_class, rec)
logparams, metrics = log_hyperparameters(phenotype_prediction, rec)
with tf.summary.create_file_writer(
os.path.join(output_path, "hparams", run_ID)
......@@ -347,8 +348,12 @@ def autoencoder_fitting(
neuron_control=False,
number_of_components=n_components,
overlap_loss=False,
phenotype_prediction=phenotype_class,
predictor=predictor,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
rule_based_features=(
y_train.shape[1] if not phenotype_prediction else y_train.shape[1] - 1
),
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
).build(
......@@ -399,11 +404,19 @@ def autoencoder_fitting(
Xs, ys = [X_train], [X_train]
Xvals, yvals = [X_val], [X_val]
if predictor > 0.0:
if next_sequence_prediction > 0.0:
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
if phenotype_class > 0.0:
if phenotype_prediction > 0.0:
ys += [y_train[:, 0]]
yvals += [y_val[:, 0]]
# Remove the used column (phenotype) from both y arrays
y_train = y_train[:, 1:]
y_val = y_val[:, 1:]
if rule_based_prediction > 0.0:
ys += [y_train]
yvals += [y_val]
......@@ -440,8 +453,8 @@ def autoencoder_fitting(
ae,
Xvals,
yvals[-1],
phenotype_class,
predictor,
phenotype_prediction,
next_sequence_prediction,
rec,
)
......
......@@ -118,8 +118,8 @@ def test_autoencoder_fitting(
mmd_warmup=1,
n_components=2,
loss=loss,
phenotype_class=pheno_class,
predictor=predictor,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
variational=variational,
entropy_samples=10,
entropy_knn=5,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment