Commit d3bc34c8 authored by lucas_miranda

Refactored a bit the old visualization pipeline

parent 7ef86305
Pipeline #87190 passed with stage in 19 minutes and 31 seconds
@@ -46,36 +46,85 @@ parser = argparse.ArgumentParser(
description="Autoencoder training for DeepOF animal pose recognition"
)
parser.add_argument("--data-path", "-vp", help="set validation set path", type=str)
parser.add_argument(
"--animal-id",
"-id",
help="Id of the animal to use. Empty string by default",
type=str,
default="",
)
parser.add_argument(
"--arena-dims",
"-adim",
help="diameter in mm of the utilised arena. Used for scaling purposes",
type=int,
default=380,
)
parser.add_argument(
"--batch-size",
"-bs",
help="set training batch size. Defaults to 512",
type=int,
default=512,
)
parser.add_argument(
"--bayopt",
"-n",
help="sets the number of Bayesian optimization iterations to run. Default is 25",
type=int,
default=25,
)
parser.add_argument(
"--components",
"-k",
help="set the number of components for the MMVAE(P) model. Defaults to 1",
help="set the number of components for the GMVAE(P) model. Defaults to 1",
type=int,
default=1,
)
parser.add_argument(
"--input-type",
"-d",
help="Select an input type for the autoencoder hypermodels. \
It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle. \
Defaults to coords.",
"--exclude-bodyparts",
"-exc",
help="Excludes the indicated bodyparts from all analyses. It should consist of several values separated by commas",
type=str,
default="coords",
default="",
)
parser.add_argument(
"--predictor",
"-pred",
help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True",
default=True,
"--gaussian-filter",
"-gf",
help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model",
type=str2bool,
default=False,
)
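Several of the flags in this parser (for example --gaussian-filter and --hyperparameter-tuning) are parsed with a str2bool converter that is defined earlier in the script and not shown in this diff. As a hedged sketch of what such a helper typically looks like (the actual implementation may differ):

import argparse  # already imported at the top of the script

def str2bool(v):
    """Parse common textual truth values into a bool for argparse flags."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")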
parser.add_argument(
"--variational",
"-v",
help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
default=True,
"--hyperparameter-tuning",
"-tune",
help="If True, hyperparameter tuning is performed. See documentation for details",
type=str2bool,
default=False,
)
parser.add_argument(
"--hyperparameters",
"-hp",
help="Path pointing to a pickled dictionary of network hyperparameters. "
"Thought to be used with the output of hyperparameter tuning",
type=str,
default=None,
)
parser.add_argument(
"--input-type",
"-d",
help="Select an input type for the autoencoder hypermodels. "
"It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
"Defaults to coords.",
type=str,
default="dists",
)
parser.add_argument(
"--kl-warmup",
"-klw",
help="Number of epochs during which the KL weight increases linearly from zero to 1. Defaults to 10",
default=10,
type=int,
)
parser.add_argument(
"--loss",
@@ -86,17 +135,78 @@ parser.add_argument(
type=str,
)
parser.add_argument(
"--hyperparameters",
"-hp",
help="Path pointing to a pickled dictionary of network hyperparameters. "
"Thought to be used with the output of hyperparameter_tuning.py",
"--mmd-warmup",
"-mmdw",
help="Number of epochs during which the MMD weight increases linearly from zero to 1. Defaults to 10",
default=10,
type=int,
)
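The --kl-warmup and --mmd-warmup flags suggest that the KL and MMD loss terms are annealed linearly from 0 to 1 over the first few epochs. A minimal sketch of such a schedule as a Keras callback, assuming the model exposes each loss weight as a tf.Variable (the variable names used here are illustrative, not the actual deepof internals):

import tensorflow as tf

class LinearWarmup(tf.keras.callbacks.Callback):
    """Linearly ramps a loss weight from 0 to 1 over the first warmup_epochs epochs."""

    def __init__(self, weight, warmup_epochs):
        super().__init__()
        self.weight = weight              # tf.Variable scaling the KL or MMD term
        self.warmup_epochs = warmup_epochs

    def on_epoch_begin(self, epoch, logs=None):
        # Reaches 1.0 at the end of the warmup window and stays there afterwards
        value = min(1.0, (epoch + 1) / self.warmup_epochs)
        tf.keras.backend.set_value(self.weight, value)

# e.g. callbacks=[LinearWarmup(kl_beta, kl_wu), LinearWarmup(mmd_beta, mmd_wu)]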
parser.add_argument(
"--overlap-loss",
"-ol",
help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
type=str2bool,
default=False,
)
parser.add_argument(
"--phenotype-classifier",
"-pheno",
help="Activates the phenotype classification branch with the specified weight. Defaults to 0.0 (inactive)",
default=0.0,
type=float,
)
parser.add_argument(
"--predictor",
"-pred",
help="Activates the prediction branch of the variational Seq 2 Seq model with the specified weight. "
"Defaults to 0.0 (inactive)",
default=0.0,
type=float,
)
parser.add_argument(
"--smooth-alpha",
"-sa",
help="Sets the exponential smoothing factor to apply to the input data. "
"Float between 0 and 1 (lower is more smooting)",
type=float,
default=0.99,
)
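With exponentially weighted smoothing, each smoothed sample is s[t] = alpha * x[t] + (1 - alpha) * s[t - 1], which is why values of --smooth-alpha closer to 0 smooth more aggressively. A small illustrative numpy sketch of that recursion (deepof applies the smoothing internally when the project is created; this is only for orientation):

import numpy as np

def exponential_smoothing(x, alpha=0.99):
    """Return an exponentially smoothed copy of a 1D trajectory."""
    s = np.empty(len(x), dtype=float)
    s[0] = x[0]
    for t in range(1, len(x)):
        s[t] = alpha * x[t] + (1 - alpha) * s[t - 1]
    return s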
parser.add_argument(
"--stability-check",
"-s",
help="Sets the number of times that the model is trained and initialised. "
"If greater than 1 (the default), saves the cluster assignments to a dataframe on disk",
type=int,
default=1,
)
parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
parser.add_argument(
"--val-num",
"-vn",
help="set number of videos of the training" "set to use for validation",
type=int,
default=1,
)
parser.add_argument(
"--variational",
"-v",
help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
default=True,
type=str2bool,
)
parser.add_argument(
"--encoding-size",
"-e",
help="Sets the dimensionality of the latent space. Defaults to 16.",
default=16,
"--window-size",
"-ws",
help="Sets the sliding window size to be used when building both training and validation sets. Defaults to 15",
type=int,
default=15,
)
parser.add_argument(
"--window-step",
"-wt",
help="Sets the sliding window step to be used when building both training and validation sets. Defaults to 5",
type=int,
default=5,
)
parser.add_argument(
"--checkpoint-path",
@@ -120,20 +230,41 @@
)
args = parser.parse_args()
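For orientation, a hedged example of how the script might be invoked once all of the flags above are defined (the script name is taken from the assert message further down; paths and values are purely illustrative):

python model_training.py --train-path ./train_tables --data-path ./val_tables --val-num 2 --input-type dists --components 10 --variational True --window-size 15 --window-step 5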
data_path = os.path.abspath(args.data_path)
samples = args.samples
red = args.reducer
animal_id = args.animal_id
arena_dims = args.arena_dims
batch_size = args.batch_size
bayopt_trials = args.bayopt
checkpoints = args.checkpoint_path
exclude_bodyparts = tuple(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters
input_type = args.input_type
k = args.components
predictor = args.predictor
variational = bool(args.variational)
kl_wu = args.kl_warmup
loss = args.loss
hparams = args.hyperparameters
encoding = args.encoding_size
checkpoints = args.checkpoint_path
samples = args.samples
red = args.reducer
mmd_wu = args.mmd_warmup
overlap_loss = args.overlap_loss
pheno_class = float(args.phenotype_classifier)
predictor = float(args.predictor)
runs = args.stability_check
smooth_alpha = args.smooth_alpha
train_path = os.path.abspath(args.train_path)
tune = args.hyperparameter_tuning
val_num = args.val_num
variational = bool(args.variational)
window_size = args.window_size
window_step = args.window_step
if not train_path:
raise ValueError("Set a valid data path for the training to run")
if not val_num:
raise ValueError(
"Set a valid data path / validation number for the validation to run"
)
if not data_path:
raise ValueError("Set a valid data path for the data to be loaded")
assert input_type in [
"coords",
"dists",
@@ -142,123 +273,81 @@ assert input_type in [
"coords+angle",
"dists+angle",
"coords+dist+angle",
], "Invalid input type. Type python train_viz_app.py -h for help."
# Loads hyperparameters, most likely obtained from hyperparameter_tuning.py
hparams = load_hparams(hparams, encoding)
# with open(
# os.path.join(
# data_path, [i for i in os.listdir(data_path) if i.endswith(".pickle")][0]
# ),
# "rb",
# ) as handle:
# Treatment_dict = pickle.load(handle)
# Which angles to compute?
bp_dict = {
"B_Nose": ["B_Left_ear", "B_Right_ear"],
"B_Left_ear": ["B_Nose", "B_Right_ear", "B_Center", "B_Left_flank"],
"B_Right_ear": ["B_Nose", "B_Left_ear", "B_Center", "B_Right_flank"],
"B_Center": [
"B_Left_ear",
"B_Right_ear",
"B_Left_flank",
"B_Right_flank",
"B_Tail_base",
],
"B_Left_flank": ["B_Left_ear", "B_Center", "B_Tail_base"],
"B_Right_flank": ["B_Right_ear", "B_Center", "B_Tail_base"],
"B_Tail_base": ["B_Center", "B_Left_flank", "B_Right_flank"],
}
DLC_social = project(
path=data_path, # Path where to find the required files
smooth_alpha=0.50, # Alpha value for exponentially weighted smoothing
arena="circular", # Type of arena used in the experiments
arena_dims=tuple([380]), # Dimensions of the arena. Just one if it's circular
video_format=".mp4",
], "Invalid input type. Type python model_training.py -h for help."
# Loads model hyperparameters and treatment conditions, if available
hparams = load_hparams(hparams)
treatment_dict = load_treatments(train_path)
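load_hparams and load_treatments are deepof helpers that are not part of this diff. As an assumption-labelled sketch of their apparent roles: the first returns a hyperparameter dictionary (falling back to defaults when no pickled file is supplied) and the second looks for a pickled treatment/condition dictionary next to the training data, much like the commented-out pickle loading in the removed code above. Something along these lines; the actual signatures live in the deepof source and may differ:

import os
import pickle

def load_hparams_sketch(path, defaults=None):
    """Return a pickled hyperparameter dict, or the defaults when no path is given."""
    if path is None:
        return defaults or {}
    with open(path, "rb") as handle:
        return pickle.load(handle)

def load_treatments_sketch(data_path):
    """Return the first pickled treatment dict found in data_path, or None."""
    pickles = [f for f in os.listdir(data_path) if f.endswith(".pickle")]
    if not pickles:
        return None
    with open(os.path.join(data_path, pickles[0]), "rb") as handle:
        return pickle.load(handle)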
# noinspection PyTypeChecker
project_coords = project(
animal_ids=tuple([animal_id]),
arena="circular",
arena_dims=tuple([arena_dims]),
exclude_bodyparts=exclude_bodyparts,
exp_conditions=treatment_dict,
path=train_path,
smooth_alpha=smooth_alpha,
table_format=".h5",
# exp_conditions=Treatment_dict,
video_format=".mp4",
)
if animal_id:
project_coords.subset_condition = animal_id
DLC_social_coords = DLC_social.run(verbose=True)
project_coords = project_coords.run(verbose=True)
undercond = "" if animal_id == "" else "_"
# Coordinates for training data
coords1 = DLC_social_coords.get_coords(center="B_Center", align="B_Nose")
distances1 = DLC_social_coords.get_distances()
angles1 = DLC_social_coords.get_angles()
coords_distances1 = merge_tables(coords1, distances1)
coords_angles1 = merge_tables(coords1, angles1)
dists_angles1 = merge_tables(distances1, angles1)
coords_dist_angles1 = merge_tables(coords1, distances1, angles1)
input_dict = {
"coords": coords1.preprocess(
window_size=13,
window_step=1,
scale="standard",
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
),
"dists": distances1.preprocess(
window_size=13,
window_step=1,
scale="standard",
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
),
"angles": angles1.preprocess(
window_size=13,
window_step=1,
scale="standard",
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
),
"coords+dist": coords_distances1.preprocess(
window_size=13,
window_step=1,
scale="standard",
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
),
"coords+angle": coords_angles1.preprocess(
window_size=13,
window_step=1,
scale="standard",
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
),
"dists+angle": dists_angles1.preprocess(
window_size=13,
window_step=10,
scale="standard",
conv_filter=None,
sigma=55,
align="center",
),
"coords+dist+angle": coords_dist_angles1.preprocess(
window_size=13,
window_step=1,
coords = project_coords.get_coords(
center=animal_id + undercond + "Center",
align=animal_id + undercond + "Spine_1",
align_inplace=True,
)
distances = project_coords.get_distances()
angles = project_coords.get_angles()
coords_distances = merge_tables(coords, distances)
coords_angles = merge_tables(coords, angles)
dists_angles = merge_tables(distances, angles)
coords_dist_angles = merge_tables(coords, distances, angles)
def batch_preprocess(tab_dict):
"""Returns a preprocessed instance of the input table_dict object"""
return tab_dict.preprocess(
window_size=window_size,
window_step=window_step,
scale="standard",
conv_filter=None,
sigma=55,
conv_filter=gaussian_filter,
sigma=1,
test_videos=val_num,
shuffle=True,
align="center",
),
)
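The preprocess call above builds fixed-length training instances by sliding a window of window_size frames with stride window_step over each scaled trajectory. A hedged numpy sketch of the windowing step alone, with illustrative names; the real deepof preprocessing also handles scaling, optional Gaussian convolution, the train/validation split and shuffling:

import numpy as np

def sliding_windows(array, window_size, window_step):
    """Cut a (frames, features) array into overlapping (window_size, features) chunks."""
    windows = [
        array[start : start + window_size]
        for start in range(0, array.shape[0] - window_size + 1, window_step)
    ]
    return np.stack(windows)

# e.g. a 1000-frame, 26-feature table with window_size=15 and window_step=5
# yields an array of shape (198, 15, 26).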
input_dict_train = {
"coords": coords,
"dists": distances,
"angles": angles,
"coords+dist": coords_distances,
"coords+angle": coords_angles,
"dists+angle": dists_angles,
"coords+dist+angle": coords_dist_angles,
}
print("Preprocessing data...")
X_train, y_train, X_val, y_val = batch_preprocess(input_dict_train[input_type])
# Get training and validation sets
print("Training set shape:", X_train.shape)
print("Validation set shape:", X_val.shape)
if pheno_class > 0:
print("Training set label shape:", y_train.shape)
print("Validation set label shape:", y_val.shape)
print("Done!")
# Load checkpoints and build dataframe with predictions
path = checkpoints
@@ -266,7 +355,7 @@ checkpoints = sorted(
list(
set(
[
path + re.findall("(.*\.ckpt).data", i)[0]
path + re.findall('(.*\.ckpt).data', i)[0]
for i in os.listdir(path)
if "ckpt.data" in i
]
@@ -274,58 +363,42 @@
)
)
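TensorFlow writes each checkpoint as several files (an index file plus one or more ".data-XXXXX-of-XXXXX" shards). The regular expression above strips the shard suffix and deduplicates, leaving one loadable ".ckpt" prefix per checkpoint. A small self-contained illustration with made-up file names:

import re

shards = [
    "run1.ckpt.data-00000-of-00001", "run1.ckpt.index",
    "run2.ckpt.data-00000-of-00001", "run2.ckpt.index",
]
prefixes = sorted({re.findall(r"(.*\.ckpt).data", f)[0] for f in shards if "ckpt.data" in f})
print(prefixes)  # ['run1.ckpt', 'run2.ckpt']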
pttest_idx = np.random.choice(list(range(input_dict[input_type].shape[0])), samples)
pttest = input_dict[input_type][pttest_idx]
pttest_idx = np.random.choice(list(range(X_train.shape[0])), samples)
pttest = X_val[pttest_idx]
# Instantiate all models
clusters = []
predictions = []
reconstructions = []
if not variational:
encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape, **hparams).build()
ae.build(pttest.shape)
(encoder, generator, grouper, gmvaep, _, _,) = SEQ_2_SEQ_GMVAE(
loss=loss,
number_of_components=k,
predictor=predictor,
kl_warmup_epochs=10,
mmd_warmup_epochs=10,
**hparams
).build(pttest.shape)
gmvaep.build(pttest.shape)
predictions.append(encoder.predict(pttest))
if predictor:
reconstructions.append(gmvaep.predict(pttest)[0])
else:
reconstructions.append(gmvaep.predict(pttest))
predictions.append(encoder.predict(pttest))
reconstructions.append(ae.predict(pttest))
print("Building predictions from pretrained models...")
else:
(encoder, generator, grouper, gmvaep, _, _,) = SEQ_2_SEQ_GMVAE(
loss=loss,
number_of_components=k,
predictor=predictor,
kl_warmup_epochs=10,
mmd_warmup_epochs=10,
**hparams
).build(pttest.shape)
gmvaep.build(pttest.shape)
for checkpoint in tqdm(checkpoints):
gmvaep.load_weights(checkpoint)
clusters.append(grouper.predict(pttest))
predictions.append(encoder.predict(pttest))
if predictor:
reconstructions.append(gmvaep.predict(pttest)[0])
else:
reconstructions.append(gmvaep.predict(pttest))
print("Building predictions from pretrained models...")
for checkpoint in tqdm(checkpoints):
if variational:
gmvaep.load_weights(checkpoint)
clusters.append(grouper.predict(pttest))
predictions.append(encoder.predict(pttest))
if predictor:
reconstructions.append(gmvaep.predict(pttest)[0])
else:
reconstructions.append(gmvaep.predict(pttest))
else:
ae.load_weights(checkpoint)
clusters.append(np.zeros(samples))
predictions.append(encoder.predict(pttest))
reconstructions.append(ae.predict(pttest))
print("Done!")
print("Reducing latent space to 2 dimensions for dataviz...")
......