train_model.py 13.6 KB
Newer Older
lucas_miranda's avatar
lucas_miranda committed
1
# @author lucasmiranda42
2
3
# encoding: utf-8
# module deepof
lucas_miranda's avatar
lucas_miranda committed
4

5
6
7
8
9
10
11
"""

Model training command line tool for the deepof package.
usage: python -m examples.model_training -h

"""

12
13
import argparse
import os
import pickle

import deepof.data
import deepof.train_utils
import deepof.utils
lucas_miranda's avatar
lucas_miranda committed
18
19
20
21
22

# Command-line interface for the training tool. Help-string fixes only:
# the --input-type help previously claimed "Defaults to coords." while the
# actual default is "dists"; a few typos and a broken implicit string
# concatenation in --val-num are also corrected.
parser = argparse.ArgumentParser(
    description="Autoencoder training for DeepOF animal pose recognition"
)

parser.add_argument(
    "--animal-id",
    "-id",
    help="Id of the animal to use. Empty string by default",
    type=str,
    default="",
)
parser.add_argument(
    "--arena-dims",
    "-adim",
    help="diameter in mm of the utilised arena. Used for scaling purposes",
    type=int,
    default=380,
)
parser.add_argument(
    "--batch-size",
    "-bs",
    help="set training batch size. Defaults to 512",
    type=int,
    default=512,
)
parser.add_argument(
    "--components",
    "-k",
    help="set the number of components for the GMVAE(P) model. Defaults to 1",
    type=int,
    default=1,
)
parser.add_argument(
    "--encoding-size",
    "-es",
    help="set the number of dimensions of the latent space. 16 by default",
    type=int,
    default=16,
)
parser.add_argument(
    "--exclude-bodyparts",
    "-exc",
    help="Excludes the indicated bodyparts from all analyses. It should consist of several values separated by commas",
    type=str,
    default="",
)
parser.add_argument(
    "--gaussian-filter",
    "-gf",
    help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model",
    type=deepof.utils.str2bool,
    default=False,
)
parser.add_argument(
    "--hpt-trials",
    "-n",
    help="sets the number of hyperparameter tuning iterations to run. Default is 25",
    type=int,
    default=25,
)
parser.add_argument(
    "--hyperparameter-tuning",
    "-tune",
    # typo fix: "of 'hyperband'" -> "or 'hyperband'"
    help="Indicates whether hyperparameters should be tuned either using 'bayopt' or 'hyperband'. "
    "See documentation for details",
    type=str,
    default=False,
)
parser.add_argument(
    "--hyperparameters",
    "-hp",
    help="Path pointing to a pickled dictionary of network hyperparameters. "
    "Thought to be used with the output of hyperparameter tuning",
    type=str,
    default=None,
)
parser.add_argument(
    "--input-type",
    "-d",
    help="Select an input type for the autoencoder hypermodels. "
    "It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle."
    # help fix: the actual default below is "dists", not "coords"
    "Defaults to dists.",
    type=str,
    default="dists",
)
parser.add_argument(
    "--kl-annealing-mode",
    "-klam",
    help="Weight annealing to use for ELBO loss. Can be one of 'linear' and 'sigmoid'",
    default="linear",
    type=str,
)
parser.add_argument(
    "--kl-warmup",
    "-klw",
    help="Number of epochs during which the KL weight increases linearly from zero to 1. Defaults to 10",
    default=10,
    type=int,
)
parser.add_argument(
    "--entropy-knn",
    "-entminn",
    help="number of nearest neighbors to take into account when computing latent space entropy",
    default=100,
    type=int,
)
parser.add_argument(
    "--entropy-samples",
    "-ents",
    help="Samples to use to compute cluster purity",
    default=10000,
    type=int,
)
parser.add_argument(
    "--latent-reg",
    "-lreg",
    help="Sets the strategy to regularize the latent mixture of Gaussians. "
    "It has to be one of none, categorical (an elastic net penalty is applied to the categorical distribution),"
    "variance (l2 penalty to the variance of the clusters) or categorical+variance. Defaults to none.",
    default="none",
    type=str,
)
parser.add_argument(
    "--loss",
    "-l",
    help="Sets the loss function for the variational model. "
    "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
    default="ELBO+MMD",
    type=str,
)
parser.add_argument(
    "--mmd-annealing-mode",
    "-mmdam",
    help="Weight annealing to use for MMD loss. Can be one of 'linear' and 'sigmoid'",
    default="linear",
    type=str,
)
parser.add_argument(
    "--mmd-warmup",
    "-mmdw",
    help="Number of epochs during which the MMD weight increases linearly from zero to 1. Defaults to 10",
    default=10,
    type=int,
)
parser.add_argument(
    "--montecarlo-kl",
    "-mckl",
    help="Number of samples to compute when adding KLDivergence to the loss function",
    default=10,
    type=int,
)
parser.add_argument(
    "--output-path",
    "-o",
    help="Sets the base directory where to output results. Default is the current directory",
    type=str,
    default=".",
)
parser.add_argument(
    "--overlap-loss",
    "-ol",
    help="If > 0, adds a regularization term controlling for local cluster assignment entropy in the latent space",
    type=float,
    default=0,
)
parser.add_argument(
    "--next-sequence-prediction",
    "-nspred",
    help="Activates the next sequence prediction branch of the variational Seq 2 Seq model with the specified weight. "
    "Defaults to 0.0 (inactive)",
    default=0.0,
    type=float,
)
parser.add_argument(
    "--phenotype-prediction",
    "-ppred",
    help="Activates the phenotype classification branch with the specified weight. Defaults to 0.0 (inactive)",
    default=0.0,
    type=float,
)
parser.add_argument(
    "--rule-based-prediction",
    "-rbpred",
    help="Activates the rule-based trait prediction branch of the variational Seq 2 Seq model "
    # punctuation fix: missing period before "Defaults"
    "with the specified weight. Defaults to 0.0 (inactive)",
    default=0.0,
    type=float,
)
parser.add_argument(
    "--smooth-alpha",
    "-sa",
    # typo fix: "smooting" -> "smoothing"
    help="Sets the exponential smoothing factor to apply to the input data. "
    "Float between 0 and 1 (lower is more smoothing)",
    type=float,
    default=0.99,
)
parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
parser.add_argument(
    "--val-num",
    "-vn",
    # fix: the two implicitly-concatenated fragments produced "trainingset"
    help="set number of videos of the training set to use for validation",
    type=int,
    default=1,
)
parser.add_argument(
    "--window-size",
    "-ws",
    help="Sets the sliding window size to be used when building both training and validation sets. Defaults to 15",
    type=int,
    default=15,
)
parser.add_argument(
    "--window-step",
    "-wt",
    help="Sets the sliding window step to be used when building both training and validation sets. Defaults to 5",
    type=int,
    default=5,
)
parser.add_argument(
    "--run",
    "-rid",
    help="Sets the run ID of the experiment (for naming output files only). If 0 (default), uses a timestamp instead",
    type=int,
    default=0,
)
lucas_miranda's avatar
lucas_miranda committed
244
245

args = parser.parse_args()

# Unpack CLI arguments into module-level names used by the rest of the script.
animal_id = args.animal_id
arena_dims = args.arena_dims
batch_size = args.batch_size
hypertun_trials = args.hpt_trials
encoding_size = args.encoding_size
exclude_bodyparts = [i for i in args.exclude_bodyparts.split(",") if i]
gaussian_filter = args.gaussian_filter
# NOTE(review): when --hyperparameters is given, this holds a *path* string,
# not a loaded dict — confirm downstream handling before use.
hparams = args.hyperparameters if args.hyperparameters is not None else {}
input_type = args.input_type
k = args.components
kl_annealing_mode = args.kl_annealing_mode
kl_wu = args.kl_warmup
entropy_knn = args.entropy_knn
entropy_samples = args.entropy_samples
latent_reg = args.latent_reg
loss = args.loss
mmd_annealing_mode = args.mmd_annealing_mode
mmd_wu = args.mmd_warmup
mc_kl = args.montecarlo_kl
output_path = args.output_path  # fix: os.path.join with a single argument was a no-op
overlap_loss = float(args.overlap_loss)
next_sequence_prediction = float(args.next_sequence_prediction)
phenotype_prediction = float(args.phenotype_prediction)
rule_based_prediction = float(args.rule_based_prediction)
smooth_alpha = args.smooth_alpha
# fix: --train-path has no default, so args.train_path may be None and
# os.path.abspath(None) raised TypeError before the intended ValueError below.
train_path = os.path.abspath(args.train_path) if args.train_path else ""
tune = args.hyperparameter_tuning
val_num = args.val_num
window_size = args.window_size
window_step = args.window_step
run = args.run
lucas_miranda's avatar
lucas_miranda committed
278
279
280

if not train_path:
    raise ValueError("Set a valid data path for the training to run")

# Explicit check instead of `assert`, which is silently stripped under `python -O`.
_VALID_INPUT_TYPES = (
    "coords",
    "dists",
    "angles",
    "coords+dist",
    "coords+angle",
    "dists+angle",
    "coords+dist+angle",
)
if input_type not in _VALID_INPUT_TYPES:
    raise ValueError("Invalid input type. Type python model_training.py -h for help.")

292
# Loads model hyperparameters and treatment conditions, if available
treatment_dict = deepof.train_utils.load_treatments(train_path)

# Logs hyperparameters  if specified on the --logparam CLI argument
logparam = {
    "encoding": encoding_size,
    "k": k,
    "loss": loss,
}
# Record each prediction-branch weight only when the branch is active (non-zero).
for _param_name, _weight in (
    ("next_sequence_prediction_weight", next_sequence_prediction),
    ("phenotype_prediction_weight", phenotype_prediction),
    ("rule_based_prediction_weight", rule_based_prediction),
):
    if _weight:
        logparam[_param_name] = _weight
307

308
# Build the deepof project wrapping the raw tracking data for training.
# noinspection PyTypeChecker
project_coords = deepof.data.project(
    animal_ids=(animal_id,),
    arena="circular",
    arena_dims=(arena_dims,),
    enable_iterative_imputation=True,
    exclude_bodyparts=exclude_bodyparts,
    exp_conditions=treatment_dict,
    path=train_path,
    smooth_alpha=smooth_alpha,
    table_format=".h5",
    video_format=".mp4",
)

if animal_id:
    project_coords.subset_condition = animal_id

project_coords = project_coords.run(verbose=True)
# Separator used when prefixing body-part names with the animal id.
undercond = "_" if animal_id else ""
lucas_miranda's avatar
lucas_miranda committed
327
328

# Coordinates for training data
_bp_prefix = animal_id + undercond
coords = project_coords.get_coords(
    center=_bp_prefix + "Center",
    align=_bp_prefix + "Spine_1",
    align_inplace=True,
    propagate_labels=(phenotype_prediction > 0),
    # Rule-based annotations are only computed when that branch carries weight.
    propagate_annotations=(
        project_coords.rule_based_annotation() if rule_based_prediction else False
    ),
)
distances = project_coords.get_distances()
angles = project_coords.get_angles()

# Pre-merged feature combinations for the composite input types.
coords_distances = deepof.data.merge_tables(coords, distances)
coords_angles = deepof.data.merge_tables(coords, angles)
dists_angles = deepof.data.merge_tables(distances, angles)
coords_dist_angles = deepof.data.merge_tables(coords, distances, angles)
lucas_miranda's avatar
lucas_miranda committed
344
345


346
347
348
349
350
351
def batch_preprocess(tab_dict):
    """Return a preprocessed instance of the input table_dict object.

    Uses the CLI-derived module-level settings (window size/step, Gaussian
    filter flag, validation video count) to build train/validation arrays.
    """
    preprocessing_params = dict(
        window_size=window_size,
        window_step=window_step,
        scale="standard",
        conv_filter=gaussian_filter,
        sigma=1,
        test_videos=val_num,
        shuffle=True,
    )
    return tab_dict.preprocess(**preprocessing_params)


# Map each CLI input-type keyword to the corresponding feature table.
input_dict_train = {
    "coords": coords,
    "dists": distances,
    "angles": angles,
    "coords+dist": coords_distances,
    "coords+angle": coords_angles,
    "dists+angle": dists_angles,
    "coords+dist+angle": coords_dist_angles,
}

# Get training and validation sets
print("Preprocessing data...")
X_train, y_train, X_val, y_val = batch_preprocess(input_dict_train[input_type])

print("Training set shape:", X_train.shape)
print("Validation set shape:", X_val.shape)
if phenotype_prediction or rule_based_prediction:
    print("Training set label shape:", y_train.shape)
    print("Validation set label shape:", y_val.shape)

print("Done!")
lucas_miranda's avatar
lucas_miranda committed
381

382
383
384
# Proceed with training mode. Fit autoencoder with the same parameters,
# as many times as specified by runs
if not tune:

    trained_models = project_coords.deep_unsupervised_embedding(
        (X_train, y_train, X_val, y_val),
        batch_size=batch_size,
        encoding_size=encoding_size,
        # NOTE(review): the --hyperparameters CLI value (`hparams`) is never
        # forwarded here; an empty dict is hard-coded — confirm intent.
        hparams={},
        kl_annealing_mode=kl_annealing_mode,
        kl_warmup=kl_wu,
        log_history=True,
        log_hparams=True,
        loss=loss,
        mmd_annealing_mode=mmd_annealing_mode,
        mmd_warmup=mmd_wu,
        montecarlo_kl=mc_kl,
        n_components=k,
        output_path=output_path,
        overlap_loss=overlap_loss,
        next_sequence_prediction=next_sequence_prediction,
        phenotype_prediction=phenotype_prediction,
        rule_based_prediction=rule_based_prediction,
        save_checkpoints=False,
        save_weights=True,
        reg_cat_clusters=("categorical" in latent_reg),
        reg_cluster_variance=("variance" in latent_reg),
        entropy_samples=entropy_samples,
        entropy_knn=entropy_knn,
        input_type=input_type,
        run=run,
    )

else:
    # Runs hyperparameter tuning with the specified parameters and saves the results
    run_ID, tensorboard_callback, entropy, onecycle = deepof.train_utils.get_callbacks(
        X_train=X_train,
        batch_size=batch_size,
        phenotype_prediction=phenotype_prediction,
        next_sequence_prediction=next_sequence_prediction,
        # fix: was `rule_base_prediction`, an undefined name (NameError)
        rule_based_prediction=rule_based_prediction,
        loss=loss,
        loss_warmup=kl_wu,
        warmup_mode=kl_annealing_mode,
        input_type=input_type,
        cp=False,
        entropy_knn=entropy_knn,
        logparam=logparam,
        outpath=output_path,
        overlap_loss=overlap_loss,
        run=run,
    )

    best_hyperparameters, best_model = deepof.train_utils.tune_search(
        data=[X_train, y_train, X_val, y_val],
        encoding_size=encoding_size,
        hypertun_trials=hypertun_trials,
        hpt_type=tune,
        k=k,
        kl_warmup_epochs=kl_wu,
        loss=loss,
        mmd_warmup_epochs=mmd_wu,
        overlap_loss=overlap_loss,
        next_sequence_prediction=next_sequence_prediction,
        phenotype_prediction=phenotype_prediction,
        # fix: was `rule_base_prediction`, an undefined name (NameError)
        rule_based_prediction=rule_based_prediction,
        project_name="{}-based_GMVAE_{}".format(input_type, tune.capitalize()),
        callbacks=[
            tensorboard_callback,
            onecycle,
            entropy,
            deepof.train_utils.CustomStopper(
                monitor="val_loss",
                patience=5,
                restore_best_weights=True,
                start_epoch=max(kl_wu, mmd_wu),
            ),
        ],
        n_replicas=1,
        n_epochs=30,
        outpath=output_path,
    )

    # Saves the best hyperparameters (requires the `pickle` import at file top)
    with open(
        os.path.join(
            output_path,
            "{}-based_GMVAE_{}_params.pickle".format(input_type, tune.capitalize()),
        ),
        "wb",
    ) as handle:
        pickle.dump(
            best_hyperparameters.values, handle, protocol=pickle.HIGHEST_PROTOCOL
        )