Commit d79fb2d0 authored by lucas_miranda
Browse files

Implemented autoencoder fitting as part of data.py

parent ffd17471
......@@ -786,6 +786,8 @@ class coordinates:
encoding_size: int = 4,
hparams: dict = None,
kl_warmup: int = 0,
log_history: bool = True,
log_hparams: bool = False,
loss: str = "ELBO",
mmd_warmup: int = 0,
montecarlo_kl: int = 10,
......@@ -794,7 +796,7 @@ class coordinates:
phenotype_class: float = 0,
predictor: float = 0,
pretrained: str = False,
save_checkpoints: bool = True,
save_checkpoints: bool = False,
variational: bool = True,
) -> Tuple:
"""
......@@ -834,21 +836,23 @@ class coordinates:
"""
trained_models = deepof.train_utils.deep_unsupervised_embedding(
preprocessed_object,
batch_size,
encoding_size,
hparams,
kl_warmup,
loss,
mmd_warmup,
montecarlo_kl,
n_components,
output_path,
phenotype_class,
predictor,
pretrained,
save_checkpoints,
variational,
preprocessed_object=preprocessed_object,
batch_size=batch_size,
encoding_size=encoding_size,
hparams=hparams,
kl_warmup=kl_warmup,
log_history=log_history,
log_hparams=log_hparams,
loss=loss,
mmd_warmup=mmd_warmup,
montecarlo_kl=montecarlo_kl,
n_components=n_components,
output_path=output_path,
phenotype_class=phenotype_class,
predictor=predictor,
pretrained=pretrained,
save_checkpoints=save_checkpoints,
variational=variational,
)
# returns a list of trained tensorflow models
......
......@@ -8,6 +8,9 @@ Functions and general utilities for rule-based pose estimation. See documentatio
"""
from itertools import combinations
from scipy import stats
from typing import Any, List, NewType
import cv2
import deepof.utils
import matplotlib.pyplot as plt
......@@ -17,9 +20,6 @@ import pandas as pd
import regex as re
import seaborn as sns
import warnings
from itertools import combinations
from scipy import stats
from typing import Any, List, NewType
# Ignore warning with no downstream effect
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
......
......@@ -7,8 +7,8 @@
Simple utility functions used in deepof example scripts. These are not part of the main package
"""
from datetime import date, datetime
from datetime import date, datetime
from kerastuner import BayesianOptimization, Hyperband
from kerastuner import HyperParameters
from kerastuner_tensorboard_logger import TensorBoardLogger
......@@ -21,8 +21,13 @@ import os
import pickle
import tensorflow as tf
# Ignore warning with no downstream effect
tf.get_logger().setLevel("ERROR")
tf.autograph.set_verbosity(0)
class CustomStopper(tf.keras.callbacks.EarlyStopping):
""" Custom callback for """
""" Custom early stopping callback. Prevents the model from stopping before warmup is over """
def __init__(self, start_epoch, *args, **kwargs):
super(CustomStopper, self).__init__(*args, **kwargs)
......@@ -136,20 +141,22 @@ def get_callbacks(
def deep_unsupervised_embedding(
preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
batch_size: int = 256,
encoding_size: int = 4,
hparams: dict = None,
kl_warmup: int = 0,
loss: str = "ELBO",
mmd_warmup: int = 0,
montecarlo_kl: int = 10,
n_components: int = 25,
output_path: str = ".",
phenotype_class: float = 0,
predictor: float = 0,
pretrained: str = False,
save_checkpoints: bool = True,
variational: bool = True,
batch_size: int,
encoding_size: int,
hparams: dict,
kl_warmup: int,
log_history: bool,
log_hparams: bool,
loss: str,
mmd_warmup,
montecarlo_kl,
n_components,
output_path,
phenotype_class,
predictor: float,
pretrained: str,
save_checkpoints: bool,
variational: bool,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
......@@ -160,6 +167,7 @@ def deep_unsupervised_embedding(
tf.keras.backend.clear_session()
# Defines what to log on tensorboard (useful for trying out different models)
logparam = {
"encoding": encoding_size,
"k": n_components,
......@@ -169,7 +177,7 @@ def deep_unsupervised_embedding(
logparam["pheno_weight"] = phenotype_class
# Load callbacks
run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
run_ID, *cbacks = get_callbacks(
X_train=X_train,
batch_size=batch_size,
cp=save_checkpoints,
......@@ -180,61 +188,64 @@ def deep_unsupervised_embedding(
logparam=logparam,
outpath=output_path,
)
if not log_history:
cbacks = cbacks[1:]
# Logs hyperparameters to tensorboard
logparams = [
hp.HParam(
"encoding",
hp.Discrete([2, 4, 6, 8, 12, 16]),
display_name="encoding",
description="encoding size dimensionality",
),
hp.HParam(
"k",
hp.IntInterval(min_value=1, max_value=15),
display_name="k",
description="cluster_number",
),
hp.HParam(
"loss",
hp.Discrete(["ELBO", "MMD", "ELBO+MMD"]),
display_name="loss function",
description="loss function",
),
]
rec = "reconstruction_" if phenotype_class else ""
metrics = [
hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
]
if phenotype_class:
logparams.append(
if log_hparams:
logparams = [
hp.HParam(
"pheno_weight",
hp.RealInterval(min_value=0.0, max_value=1000.0),
display_name="pheno weight",
description="weight applied to phenotypic classifier from the latent space",
)
)
metrics += [
hp.Metric(
"phenotype_prediction_accuracy",
display_name="phenotype_prediction_accuracy",
"encoding",
hp.Discrete([2, 4, 6, 8, 12, 16]),
display_name="encoding",
description="encoding size dimensionality",
),
hp.Metric(
"phenotype_prediction_auc",
display_name="phenotype_prediction_auc",
hp.HParam(
"k",
hp.IntInterval(min_value=1, max_value=25),
display_name="k",
description="cluster_number",
),
hp.HParam(
"loss",
hp.Discrete(["ELBO", "MMD", "ELBO+MMD"]),
display_name="loss function",
description="loss function",
),
]
with tf.summary.create_file_writer(
os.path.join(output_path, "hparams", run_ID)
).as_default():
hp.hparams_config(
hparams=logparams,
metrics=metrics,
)
rec = "reconstruction_" if phenotype_class else ""
metrics = [
hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
]
if phenotype_class:
logparams.append(
hp.HParam(
"pheno_weight",
hp.RealInterval(min_value=0.0, max_value=1000.0),
display_name="pheno weight",
description="weight applied to phenotypic classifier from the latent space",
)
)
metrics += [
hp.Metric(
"phenotype_prediction_accuracy",
display_name="phenotype_prediction_accuracy",
),
hp.Metric(
"phenotype_prediction_auc",
display_name="phenotype_prediction_auc",
),
]
with tf.summary.create_file_writer(
os.path.join(output_path, "hparams", run_ID)
).as_default():
hp.hparams_config(
hparams=logparams,
metrics=metrics,
)
# Build models
if not variational:
......@@ -285,10 +296,8 @@ def deep_unsupervised_embedding(
batch_size=batch_size,
verbose=1,
validation_data=(X_val, X_val),
callbacks=[
tensorboard_callback,
cp_callback,
onecycle,
callbacks=cbacks
+ [
CustomStopper(
monitor="val_loss",
patience=5,
......@@ -300,10 +309,7 @@ def deep_unsupervised_embedding(
else:
callbacks_ = [
tensorboard_callback,
cp_callback,
onecycle,
callbacks_ = cbacks + [
CustomStopper(
monitor="val_loss",
patience=5,
......@@ -343,57 +349,60 @@ def deep_unsupervised_embedding(
callbacks=callbacks_,
)
# noinspection PyUnboundLocalVariable
def tensorboard_metric_logging(run_dir: str, hpms: Any):
output = gmvaep.predict(X_val)
if phenotype_class or predictor:
reconstruction = output[0]
prediction = output[1]
pheno = output[-1]
else:
reconstruction = output
with tf.summary.create_file_writer(run_dir).as_default():
hp.hparams(hpms) # record the values used in this trial
val_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, reconstruction)
)
val_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, reconstruction)
)
tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
if predictor:
pred_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, prediction)
)
pred_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, prediction)
)
tf.summary.scalar(
"val_prediction_mae".format(rec), pred_mae, step=1
)
tf.summary.scalar(
"val_prediction_mse".format(rec), pred_mse, step=1
)
if phenotype_class:
pheno_acc = tf.keras.metrics.binary_accuracy(
y_val, tf.squeeze(pheno)
if log_hparams:
# noinspection PyUnboundLocalVariable
def tensorboard_metric_logging(run_dir: str, hpms: Any):
output = gmvaep.predict(X_val)
if phenotype_class or predictor:
reconstruction = output[0]
prediction = output[1]
pheno = output[-1]
else:
reconstruction = output
with tf.summary.create_file_writer(run_dir).as_default():
hp.hparams(hpms) # record the values used in this trial
val_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, reconstruction)
)
pheno_auc = roc_auc_score(y_val, pheno)
tf.summary.scalar(
"phenotype_prediction_accuracy", pheno_acc, step=1
val_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, reconstruction)
)
tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)
# Logparams to tensorboard
tensorboard_metric_logging(
os.path.join(output_path, "hparams", run_ID),
logparam,
)
tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
if predictor:
pred_mae = tf.reduce_mean(
tf.keras.metrics.mean_absolute_error(X_val, prediction)
)
pred_mse = tf.reduce_mean(
tf.keras.metrics.mean_squared_error(X_val, prediction)
)
tf.summary.scalar(
"val_prediction_mae".format(rec), pred_mae, step=1
)
tf.summary.scalar(
"val_prediction_mse".format(rec), pred_mse, step=1
)
if phenotype_class:
pheno_acc = tf.keras.metrics.binary_accuracy(
y_val, tf.squeeze(pheno)
)
pheno_auc = roc_auc_score(y_val, pheno)
tf.summary.scalar(
"phenotype_prediction_accuracy", pheno_acc, step=1
)
tf.summary.scalar(
"phenotype_prediction_auc", pheno_auc, step=1
)
# Logparams to tensorboard
tensorboard_metric_logging(
os.path.join(output_path, "hparams", run_ID),
logparam,
)
return return_list
......
......@@ -19,8 +19,6 @@ import regex as re
from copy import deepcopy
from itertools import combinations, product
from joblib import Parallel, delayed
# from skimage.transform import hough_ellipse
from sklearn import mixture
from tqdm import tqdm
from typing import Tuple, Any, List, Union, NewType
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment