Commit ad44f3fa authored by Lucas Miranda
Merged model_training.py and hyperparameter_tuning.py into one script. Code is now more concise and easier to use
parent d06f7d41
Pipeline #83665 passed
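For context, the merged script is driven from the command line. An illustrative invocation (paths and flag values are hypothetical) might look like:

python -m examples.model_training --train_path ./my_project --input-type coords --components 5 --hyperparameter-tuning True --bayopt 25 --hypermodel S2SGMVAE

With --hyperparameter-tuning False (the default), the same script runs the regular training loop instead of the Bayesian search.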
# @author lucasmiranda42

from datetime import datetime
from deepof.data import *
from deepof.hypermodels import *
from .example_utils import *
from kerastuner import BayesianOptimization
from tensorflow import keras
import argparse
import os
import pickle
import tensorflow as tf  # needed for the tf.keras callbacks used below
parser = argparse.ArgumentParser(
description="hyperparameter tuning for DeepOF autoencoder models"
)
parser.add_argument("--train_path", "-tp", help="set training set path", type=str)
parser.add_argument(
"--components",
"-k",
help="set the number of components for the MMVAE(P) model. Defaults to 1",
type=int,
default=1,
)
parser.add_argument(
"--input-type",
"-d",
help="Select an input type for the autoencoder hypermodels. \
It must be one of coords, dists, angles, coords+dist, coords+angle or coords+dist+angle",
type=str,
default="coords",
)
parser.add_argument(
"--bayopt",
"-n",
help="sets the number of Bayesian optimization iterations to run. Default is 25",
default=25,
type=int,
)
parser.add_argument(
"--hypermodel",
"-m",
help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, "
"S2SVAEP, S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
default="S2SVAE",
type=str,
)
args = parser.parse_args()

# Validate the paths before resolving them (os.path.abspath(None) raises a TypeError)
if not args.train_path:
    raise ValueError("Set a valid data path for the training to run")
if not args.val_path:
    raise ValueError("Set a valid data path for the validation to run")

train_path = os.path.abspath(args.train_path)
val_path = os.path.abspath(args.val_path)
input_type = args.input_type
bayopt_trials = args.bayopt
hyp = args.hypermodel
k = args.components
assert input_type in [
"coords",
"dists",
"angles",
"coords+dist",
"coords+angle",
"coords+dist+angle",
], "Invalid input type. Type python hyperparameter_tuning.py -h for help."
assert hyp in [
"S2SAE",
"S2SGMVAE",
], "Invalid hypermodel. Type python hyperparameter_tuning.py -h for help."
log_dir = os.path.abspath(
"logs/fit/{}_{}".format(hyp, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
treatment_dict = load_treatments(train_path)
project_coords = project(
path=train_path, # Path where to find the required files
smooth_alpha=0.85, # Alpha value for exponentially weighted smoothing
arena="circular", # Type of arena used in the experiments
arena_dims=tuple([380]), # Dimensions of the arena. Just one if it's circular
video_format=".mp4",
table_format=".h5",
exp_conditions=treatment_dict,
).run(verbose=True)
# Coordinates for training data
coords = project_coords.get_coords(center="Center", align="Spine_1", align_inplace=True)
distances = project_coords.get_distances()
angles = project_coords.get_angles()
coords_distances = merge_tables(coords, distances)
coords_angles = merge_tables(coords, angles)
dists_angles = merge_tables(distances, angles)
coords_dist_angles = merge_tables(coords, distances, angles)
def batch_preprocess(tab_dict):
    """Returns a preprocessed instance of the input table_dict object.

    Relies on module-level preprocessing settings (window_size, window_step,
    gaussian_filter and val_num); these are not exposed as CLI flags in this
    script and must be defined before this function is called."""
    return tab_dict.preprocess(
        window_size=window_size,
        window_step=window_step,
        scale="standard",
        conv_filter=gaussian_filter,
        sigma=1,
        test_videos=val_num,
        shuffle=True,
    )
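# Illustrative note (not part of the original script): with test_videos set to a
# positive val_num, preprocess() yields a (train, validation) pair of arrays, so a
# hypothetical call could be unpacked as
#
#     example_train, example_val = batch_preprocess(coords)
#
# which is the same split indexed as [0] and [1] further below.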
input_dict_train = {
"coords": coords,
"dists": distances,
"angles": angles,
"coords+dist": coords_distances,
"coords+angle": coords_angles,
"dists+angle": dists_angles,
"coords+dist+angle": coords_dist_angles,
}
print("Preprocessing data...")
for key, value in input_dict_train.items():
input_dict_train[key] = batch_preprocess(value)
print("Done!")
def tune_search(train, test, project_name, hyp):
    """Define the search space using keras-tuner and Bayesian optimization"""

    if hyp == "S2SAE":
        hypermodel = SEQ_2_SEQ_AE(input_shape=train.shape)
    elif hyp == "S2SGMVAE":
        # The tuner below expects the hypermodel object itself, so build() is not called here
        hypermodel = SEQ_2_SEQ_GMVAE(
            input_shape=train.shape,
            loss="ELBO+MMD",
            predictor=False,
            number_of_components=k,
        )
    else:
        return False
tuner = BayesianOptimization(
hypermodel,
max_trials=bayopt_trials,
executions_per_trial=1,
objective="val_mae",
seed=42,
directory="BayesianOptx",
project_name=project_name,
)
print(tuner.search_space_summary())
tuner.search(
train,
train,
epochs=30,
validation_data=(test, test),
verbose=1,
batch_size=256,
callbacks=[
tensorboard_callback,
tf.keras.callbacks.EarlyStopping("val_mae", patience=5),
],
)
print(tuner.results_summary())
best_hyperparameters = tuner.get_best_hyperparameters(num_trials=1)[0]
best_model = tuner.hypermodel.build(best_hyperparameters)
return best_hyperparameters, best_model
# Runs hyperparameter tuning with the specified parameters and saves the results.
# preprocess() returns a (train, validation) pair, indexed below as [0] and [1]
best_hyperparameters, best_model = tune_search(
    input_dict_train[input_type][0],
    input_dict_train[input_type][1],
    "{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
    hyp=hyp,
)

# Saves a compiled, untrained version of the best model
best_model.build(input_dict_train[input_type][0].shape)
best_model.save("{}-based_{}_BAYESIAN_OPT.h5".format(input_type, hyp), save_format="tf")
# Saves the best hyperparameters
with open(
"{}-based_{}_BAYESIAN_OPT_params.pickle".format(input_type, hyp), "wb"
) as handle:
pickle.dump(best_hyperparameters.values, handle, protocol=pickle.HIGHEST_PROTOCOL)
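# Illustrative follow-up (not part of the original script): the pickled dictionary
# saved above can later be reloaded and handed to model training through its
# --hyperparameters flag (the file name below is hypothetical):
#
#     with open("coords-based_S2SGMVAE_BAYESIAN_OPT_params.pickle", "rb") as handle:
#         tuned_hparams = pickle.load(handle)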
@@ -9,11 +9,10 @@ usage: python -m examples.model_training -h
"""
from datetime import datetime
from deepof.data import *
from deepof.models import *
from deepof.utils import *
from .example_utils import *
from .train_utils import *
from tensorflow import keras
parser = argparse.ArgumentParser(
@@ -84,7 +83,7 @@ parser.add_argument(
"--hyperparameters",
"-hp",
help="Path pointing to a pickled dictionary of network hyperparameters. "
"Thought to be used with the output of hyperparameter_tuning.py",
"Thought to be used with the output of hyperparameter tuning",
)
parser.add_argument(
"--encoding-size",
@@ -159,15 +158,39 @@ parser.add_argument(
type=int,
default=380,
)
parser.add_argument(
"--hyperparameter-tuning",
"-tune",
help="If True, hyperparameter tuning is performed. See documentation for details",
type=str2bool,
default=False,
)
parser.add_argument(
"--bayopt",
"-n",
help="sets the number of Bayesian optimization iterations to run. Default is 25",
default=25,
type=int,
)
parser.add_argument(
"--hypermodel",
"-m",
help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, "
"S2SVAEP, S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
default="S2SVAE",
type=str,
)
args = parser.parse_args()
arena_dims = args.arena_dims
batch_size = args.batch_size
bayopt_trials = args.bayopt
encoding = args.encoding_size
exclude_bodyparts = tuple(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters
hyp = args.hypermodel
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
@@ -178,6 +201,7 @@ predictor = float(args.predictor)
runs = args.stability_check
smooth_alpha = args.smooth_alpha
train_path = os.path.abspath(args.train_path)
tune = args.hyperparameter_tuning
val_num = args.val_num
variational = bool(args.variational)
window_size = args.window_size
@@ -204,6 +228,7 @@ assert input_type in [
hparams = load_hparams(hparams, encoding)
treatment_dict = load_treatments(train_path)
# noinspection PyTypeChecker
project_coords = project(
arena="circular", # Type of arena used in the experiments
arena_dims=tuple(
@@ -262,126 +287,142 @@ X_train = input_dict_train[input_type][0]
X_val = input_dict_train[input_type][1]
print("Done!")
# Training loop
for run in range(runs):
# To avoid stability issues
tf.keras.backend.clear_session()
# Proceed with training mode. Fit autoencoder with the same parameters,
# as many times as specified by runs
if not tune:
run_ID = "{}{}{}{}{}{}_{}".format(
("GMVAE" if variational else "AE"),
("P" if predictor > 0 and variational else ""),
("_components={}".format(k) if variational else ""),
("_loss={}".format(loss) if variational else ""),
("_kl_warmup={}".format(kl_wu) if variational else ""),
("_mmd_warmup={}".format(mmd_wu) if variational else ""),
datetime.now().strftime("%Y%m%d-%H%M%S"),
)
log_dir = os.path.abspath("logs/fit/{}".format(run_ID))
tensorboard_callback = keras.callbacks.TensorBoard(
log_dir=log_dir, histogram_freq=1, profile_batch=2,
)
cp_callback = (
tf.keras.callbacks.ModelCheckpoint(
"./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
verbose=1,
save_best_only=False,
save_weights_only=True,
save_freq="epoch",
),
)
# Training loop
for run in range(runs):
onecycle = deepof.model_utils.one_cycle_scheduler(
X_train.shape[0] // batch_size * 250, max_rate=0.005,
)
# To avoid stability issues
tf.keras.backend.clear_session()
if not variational:
encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape)
print(ae.summary())
ae.save_weights("./logs/checkpoints/cp-{epoch:04d}.ckpt".format(epoch=0))
# Fit the specified model to the data
history = ae.fit(
x=X_train,
y=X_train,
epochs=25,
batch_size=batch_size,
verbose=1,
validation_data=(X_val, X_val),
callbacks=[
tensorboard_callback,
cp_callback,
onecycle,
tf.keras.callbacks.EarlyStopping(
"val_loss", patience=10, restore_best_weights=True
),
],
run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
)
ae.save_weights("{}_final_weights.h5".format(run_ID))
else:
(
encoder,
generator,
grouper,
gmvaep,
kl_warmup_callback,
mmd_warmup_callback,
) = SEQ_2_SEQ_GMVAE(
loss=loss,
number_of_components=k,
kl_warmup_epochs=kl_wu,
mmd_warmup_epochs=mmd_wu,
predictor=predictor,
overlap_loss=overlap_loss,
architecture_hparams=hparams,
).build(
X_train.shape
)
print(gmvaep.summary())
callbacks_ = [
tensorboard_callback,
cp_callback,
onecycle,
tf.keras.callbacks.EarlyStopping(
"val_loss", patience=10, restore_best_weights=True
),
]
if "ELBO" in loss and kl_wu > 0:
callbacks_.append(kl_warmup_callback)
if "MMD" in loss and mmd_wu > 0:
callbacks_.append(mmd_warmup_callback)
if predictor == 0:
history = gmvaep.fit(
if not variational:
encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape)
print(ae.summary())
ae.save_weights("./logs/checkpoints/cp-{epoch:04d}.ckpt".format(epoch=0))
# Fit the specified model to the data
history = ae.fit(
x=X_train,
y=X_train,
epochs=250,
epochs=25,
batch_size=batch_size,
verbose=1,
validation_data=(X_val, X_val,),
callbacks=callbacks_,
validation_data=(X_val, X_val),
callbacks=[
tensorboard_callback,
cp_callback,
onecycle,
tf.keras.callbacks.EarlyStopping(
"val_loss", patience=10, restore_best_weights=True
),
],
)
ae.save_weights("{}_final_weights.h5".format(run_ID))
else:
history = gmvaep.fit(
x=X_train[:-1],
y=[X_train[:-1], X_train[1:]],
epochs=250,
batch_size=batch_size,
verbose=1,
validation_data=(X_val[:-1], [X_val[:-1], X_val[1:]],),
callbacks=callbacks_,
(
encoder,
generator,
grouper,
gmvaep,
kl_warmup_callback,
mmd_warmup_callback,
) = SEQ_2_SEQ_GMVAE(
loss=loss,
number_of_components=k,
kl_warmup_epochs=kl_wu,
mmd_warmup_epochs=mmd_wu,
predictor=predictor,
overlap_loss=overlap_loss,
architecture_hparams=hparams,
).build(
X_train.shape
)
print(gmvaep.summary())
callbacks_ = [
tensorboard_callback,
cp_callback,
onecycle,
tf.keras.callbacks.EarlyStopping(
"val_loss", patience=10, restore_best_weights=True
),
]
if "ELBO" in loss and kl_wu > 0:
callbacks_.append(kl_warmup_callback)
if "MMD" in loss and mmd_wu > 0:
callbacks_.append(mmd_warmup_callback)
if predictor == 0:
history = gmvaep.fit(
x=X_train,
y=X_train,
epochs=250,
batch_size=batch_size,
verbose=1,
validation_data=(X_val, X_val,),
callbacks=callbacks_,
)
else:
history = gmvaep.fit(
x=X_train[:-1],
y=[X_train[:-1], X_train[1:]],
epochs=250,
batch_size=batch_size,
verbose=1,
validation_data=(X_val[:-1], [X_val[:-1], X_val[1:]],),
callbacks=callbacks_,
)
gmvaep.save_weights("{}_final_weights.h5".format(run_ID))
# To avoid stability issues
tf.keras.backend.clear_session()
else:
# Runs hyperparameter tuning with the specified parameters and saves the results
run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
)
gmvaep.save_weights("{}_final_weights.h5".format(run_ID))
best_hyperparameters, best_model = tune_search(
X_train,
X_val,
bayopt_trials=bayopt_trials,
hypermodel=hyp,
k=k,
kl_wu=kl_wu,
loss=loss,
mmd_wu=mmd_wu,
overlap_loss=overlap_loss,
predictor=predictor,
project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
tensorboard_callback=tensorboard_callback,
)
# To avoid stability issues
tf.keras.backend.clear_session()
# Saves a compiled, untrained version of the best model
    best_model.build(X_train.shape)
best_model.save(
"{}-based_{}_BAYESIAN_OPT.h5".format(input_type, hyp), save_format="tf"
)
# Saves the best hyperparameters
with open(
"{}-based_{}_BAYESIAN_OPT_params.pickle".format(input_type, hyp), "wb"
) as handle:
pickle.dump(
best_hyperparameters.values, handle, protocol=pickle.HIGHEST_PROTOCOL
)
# TODO:
# - Investigate how gaussian filters affect reproducibility (in a systematic way)
@@ -10,7 +10,7 @@ import argparse
import cv2
import os
from deepof.models import *
from .example_utils import *
from .train_utils import *
parser = argparse.ArgumentParser(
@@ -7,9 +7,17 @@
Simple utility functions used in deepof example scripts. These are not part of the main package
"""
from datetime import datetime
from kerastuner import BayesianOptimization
from typing import Tuple, Union, Any
import deepof.hypermodels
import deepof.model_utils
import keras
import numpy as np
import os
import pickle
import tensorflow as tf
def load_hparams(hparams, encoding):
@@ -50,6 +58,118 @@ def load_treatments(train_path):
return treatment_dict
def get_callbacks(
X_train: np.array,
batch_size: int,
variational: bool,
predictor: float,
k: int,
loss: str,
kl_wu: int,
mmd_wu: int,
) -> Tuple:
"""Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details;
- tensorboard_callback: for real-time visualization;
- cp_callback: for checkpoint saving,
- onecycle: for learning rate scheduling"""
run_ID = "{}{}{}{}{}{}_{}".format(
("GMVAE" if variational else "AE"),
("P" if predictor > 0 and variational else ""),
("_components={}".format(k) if variational else ""),
("_loss={}".format(loss) if variational else ""),
("_kl_warmup={}".format(kl_wu) if variational else ""),
("_mmd_warmup={}".format(mmd_wu) if variational else ""),
datetime.now().strftime("%Y%m%d-%H%M%S"),
)
log_dir = os.path.abspath("logs/fit/{}".format(run_ID))
tensorboard_callback = keras.callbacks.TensorBoard(
log_dir=log_dir, histogram_freq=1, profile_batch=2,
)
    # Return the ModelCheckpoint callback directly (not wrapped in a tuple), so it
    # can be passed straight into the callbacks list of model.fit
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        "./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
        verbose=1,
        save_best_only=False,
        save_weights_only=True,
        save_freq="epoch",
    )
onecycle = deepof.model_utils.one_cycle_scheduler(
X_train.shape[0] // batch_size * 250, max_rate=0.005,
)
return run_ID, tensorboard_callback, cp_callback, onecycle
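# Illustrative example (values are hypothetical): calling
#
#     get_callbacks(X_train, batch_size=256, variational=True, predictor=0,
#                   k=5, loss="ELBO+MMD", kl_wu=10, mmd_wu=10)
#
# would yield a run_ID such as
# "GMVAE_components=5_loss=ELBO+MMD_kl_warmup=10_mmd_warmup=10_20200812-103000",
# plus the TensorBoard, checkpoint and one-cycle learning rate callbacks.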
def tune_search(
train: np.array,
test: np.array,
bayopt_trials: int,
hypermodel: str,
k: int,
kl_wu: int,
loss: str,
mmd_wu: int,
overlap_loss: float,
predictor: float,
project_name: str,
tensorboard_callback: tf.keras.callbacks,
) -> Union[bool, Tuple[Any, Any]]:
"""Define the search space using keras-tuner and bayesian optimization"""
if hypermodel == "S2SAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
elif hypermodel == "S2SGMVAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
input_shape=train.shape,
loss=loss,
number_of_components=k,
kl_warmup_epochs=kl_wu,
mmd_warmup_epochs=mmd_wu,
predictor=predictor,
overlap_loss=overlap_loss,
)
else:
return False
tuner = BayesianOptimization(
hypermodel,
max_trials=bayopt_trials,
executions_per_trial=1,
objective="val_mae",
seed=42,
directory="BayesianOptx",
project_name=project_name,
)
print(tuner.search_space_summary())
tuner.search(
train,
train,
epochs=30,
validation_data=(test, test),
verbose=1,
batch_size=256,
callbacks=[
tensorboard_callback,
tf.keras.callbacks.EarlyStopping("val_mae", patience=5),
],
)
print(tuner.results_summary())
best_hparams = tuner.get_best_hyperparameters(num_trials=1)[0]
best_run = tuner.hypermodel.build(best_hparams)
return best_hparams, best_run
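# Note: model_training.py above passes type=str2bool to argparse for the
# --hyperparameter-tuning flag. The real converter is defined elsewhere in the
# package and is not shown in this diff; a minimal sketch of such a helper,
# under that assumption, could be:
def str2bool(v) -> bool:
    """Converts a command-line string such as 'true'/'false' into a boolean (sketch only)."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise ValueError("Boolean value expected.")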
# TODO:
# - load_treatments should be part of the main data module. If available in the main directory,
#   a table (preferably in CSV format) should be loaded automatically as metadata of the coordinates.
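# A minimal sketch of what the TODO above could look like (the CSV layout and the
# orientation of the returned dictionary are assumptions, not part of the current package):
import pandas as pd


def load_treatments_from_csv(path: str) -> dict:
    """Reads a CSV with one row per experiment and a 'treatment' column,
    and groups experiment IDs by treatment (sketch only)."""
    table = pd.read_csv(path, index_col=0)
    return {
        treatment: table.index[table["treatment"] == treatment].tolist()
        for treatment in table["treatment"].unique()
    }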