Commit 19df12a5 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented outlier interpolation

parent f4b22d1c
......@@ -137,7 +137,23 @@ def get_callbacks(
return callbacks
def deep_unsupervised_embedding():
def deep_unsupervised_embedding(
preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
batch_size: int = 256,
encoding_size: int = 4,
hparams: dict = None,
kl_warmup: int = 0,
loss: str = "ELBO",
mmd_warmup: int = 0,
montecarlo_kl: int = 10,
n_components: int = 25,
outpath: str = ".",
phenotype_class: float = 0,
predictor: float = 0,
pretrained: str = False,
save_checkpoints: bool = True,
variational: bool = True,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
# Load data
......@@ -146,19 +162,89 @@ def deep_unsupervised_embedding():
# To avoid stability issues
tf.keras.backend.clear_session()
# defines what to log on tensorboard (useful for trying out different models)
logparam = {
"encoding": encoding_size,
"k": k,
"loss": loss,
}
if phenotype_class:
logparam["pheno_weight"] = phenotype_class
# Load callbacks
run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
X_train=X_train,
batch_size=batch_size,
cp=save_checkpoints,
variational=variational,
phenotype_class=pheno_class,
phenotype_class=phenotype_class,
predictor=predictor,
loss=loss,
logparam=logparam,
outpath=output_path,
)
logparams = [
hp.HParam(
"encoding",
hp.Discrete([2, 4, 6, 8, 12, 16]),
display_name="encoding",
description="encoding size dimensionality",
),
hp.HParam(
"k",
hp.IntInterval(min_value=1, max_value=15),
display_name="k",
description="cluster_number",
),
hp.HParam(
"loss",
hp.Discrete(["ELBO", "MMD", "ELBO+MMD"]),
display_name="loss function",
description="loss function",
),
hp.HParam(
"run",
hp.Discrete([0, 1, 2]),
display_name="trial run",
description="trial run",
),
]
rec = "reconstruction_" if phenotype_class else ""
metrics = [
hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
]
logparam["run"] = run
if phenotype_class:
logparams.append(
hp.HParam(
"pheno_weight",
hp.RealInterval(min_value=0.0, max_value=1000.0),
display_name="pheno weight",
description="weight applied to phenotypic classifier from the latent space",
)
)
metrics += [
hp.Metric(
"phenotype_prediction_accuracy",
display_name="phenotype_prediction_accuracy",
),
hp.Metric(
"phenotype_prediction_auc",
display_name="phenotype_prediction_auc",
),
]
with tf.summary.create_file_writer(
os.path.join(output_path, "hparams", run_ID)
).as_default():
hp.hparams_config(
hparams=logparams,
metrics=metrics,
)
# Build models
if not variational:
encoder, decoder, ae = deepof.models.SEQ_2_SEQ_AE(
......@@ -249,7 +335,7 @@ def deep_unsupervised_embedding():
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
if pheno_class > 0.0:
if phenotype_class > 0.0:
ys += [y_train]
yvals += [y_val]
......@@ -280,7 +366,7 @@ def tune_search(
loss: str,
mmd_warmup_epochs: int,
overlap_loss: float,
pheno_class: float,
phenotype_class: float,
predictor: float,
project_name: str,
callbacks: List,
......@@ -300,7 +386,7 @@ def tune_search(
- loss (str): one of [ELBO, MMD, ELBO+MMD]
- overlap_loss (float): assigns as weight to an extra loss term which
penalizes overlap between GM components
- pheno_class (float): adds an extra regularizing neural network to the model,
- phenotype_class (float): adds an extra regularizing neural network to the model,
which tries to predict the phenotype of the animal from which the sequence comes
- predictor (float): adds an extra regularizing neural network to the model,
which tries to predict the next frame from the current one
......@@ -324,7 +410,7 @@ def tune_search(
if hypermodel == "S2SAE": # pragma: no cover
assert (
predictor == 0.0 and pheno_class == 0.0
predictor == 0.0 and phenotype_class == 0.0
), "Prediction branches are only available for variational models. See documentation for more details"
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
......@@ -337,7 +423,7 @@ def tune_search(
mmd_warmup_epochs=mmd_warmup_epochs,
number_of_components=k,
overlap_loss=overlap_loss,
phenotype_predictor=pheno_class,
phenotype_predictor=phenotype_class,
predictor=predictor,
)
......@@ -378,7 +464,7 @@ def tune_search(
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
if pheno_class > 0.0:
if phenotype_class > 0.0:
ys += [y_train]
yvals += [y_val]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment