train_model.py 12.9 KB
Newer Older
lucas_miranda's avatar
lucas_miranda committed
1
# @author lucasmiranda42
2
3
# encoding: utf-8
# module deepof
lucas_miranda's avatar
lucas_miranda committed
4

5
6
7
8
9
10
11
12
"""

Model training command line tool for the deepof package.
usage: python -m examples.model_training -h

"""

from deepof.data import *
13
from deepof.models import *
14
from deepof.utils import *
lucas_miranda's avatar
lucas_miranda committed
15
from deepof.train_utils import *
lucas_miranda's avatar
lucas_miranda committed
16
17
18
19
20
21
from tensorflow import keras

# Command-line interface for the autoencoder training tool.
# NOTE: argparse comes into scope via the star imports above (deepof.utils
# re-exports it together with str2bool) — TODO confirm against deepof.utils.
parser = argparse.ArgumentParser(
    description="Autoencoder training for DeepOF animal pose recognition"
)

parser.add_argument(
    "--animal-id",
    "-id",
    help="Id of the animal to use. Empty string by default",
    type=str,
    default="",
)
parser.add_argument(
    "--arena-dims",
    "-adim",
    help="diameter in mm of the utilised arena. Used for scaling purposes",
    type=int,
    default=380,
)
parser.add_argument(
    "--batch-size",
    "-bs",
    help="set training batch size. Defaults to 512",
    type=int,
    default=512,
)
parser.add_argument(
    "--components",
    "-k",
    help="set the number of components for the GMVAE(P) model. Defaults to 1",
    type=int,
    default=1,
)
parser.add_argument(
    "--exclude-bodyparts",
    "-exc",
    help="Excludes the indicated bodyparts from all analyses. It should consist of several values separated by commas",
    type=str,
    default="",
)
parser.add_argument(
    "--gaussian-filter",
    "-gf",
    help="Convolves each training instance with a Gaussian filter before feeding it to the autoencoder model",
    type=str2bool,
    default=False,
)
parser.add_argument(
    "--hpt-trials",
    "-n",
    help="sets the number of hyperparameter tuning iterations to run. Default is 25",
    type=int,
    default=25,
)
# Fixed typo: "'bayopt' of 'hyperband'" -> "'bayopt' or 'hyperband'".
# The default is False (no tuning); when given, the string selects the tuner.
parser.add_argument(
    "--hyperparameter-tuning",
    "-tune",
    help="Indicates whether hyperparameters should be tuned either using 'bayopt' or 'hyperband'. "
    "See documentation for details",
    type=str,
    default=False,
)
parser.add_argument(
    "--hyperparameters",
    "-hp",
    help="Path pointing to a pickled dictionary of network hyperparameters. "
    "Thought to be used with the output of hyperparameter tuning",
    type=str,
    default=None,
)
# Help text corrected to match the actual default ("dists", not "coords"),
# and the missing space before "Defaults" was added.
parser.add_argument(
    "--input-type",
    "-d",
    help="Select an input type for the autoencoder hypermodels. "
    "It must be one of coords, dists, angles, coords+dist, coords+angle, dists+angle or coords+dist+angle. "
    "Defaults to dists.",
    type=str,
    default="dists",
)
parser.add_argument(
    "--kl-warmup",
    "-klw",
    help="Number of epochs during which the KL weight increases linearly from zero to 1. Defaults to 10",
    default=10,
    type=int,
)
parser.add_argument(
    "--loss",
    "-l",
    help="Sets the loss function for the variational model. "
    "It has to be one of ELBO+MMD, ELBO or MMD. Defaults to ELBO+MMD",
    default="ELBO+MMD",
    type=str,
)
parser.add_argument(
    "--mmd-warmup",
    "-mmdw",
    help="Number of epochs during which the MMD weight increases linearly from zero to 1. Defaults to 10",
    default=10,
    type=int,
)
parser.add_argument(
    "--overlap-loss",
    "-ol",
    help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
    type=str2bool,
    default=False,
)
parser.add_argument(
    "--phenotype-classifier",
    "-pheno",
    help="Activates the phenotype classification branch with the specified weight. Defaults to 0.0 (inactive)",
    default=0.0,
    type=float,
)
parser.add_argument(
    "--predictor",
    "-pred",
    help="Activates the prediction branch of the variational Seq 2 Seq model with the specified weight. "
    "Defaults to 0.0 (inactive)",
    default=0.0,
    type=float,
)
# Fixed typo: "smooting" -> "smoothing".
parser.add_argument(
    "--smooth-alpha",
    "-sa",
    help="Sets the exponential smoothing factor to apply to the input data. "
    "Float between 0 and 1 (lower means more smoothing)",
    type=float,
    default=0.99,
)
# Help text clarified: the default is 1, not "greater than 1".
parser.add_argument(
    "--stability-check",
    "-s",
    help="Sets the number of times that the model is trained and initialised. "
    "If greater than 1 (default is 1), saves the cluster assignments to a dataframe on disk",
    type=int,
    default=1,
)
parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
# Fixed implicit string concatenation that rendered as "trainingset".
parser.add_argument(
    "--val-num",
    "-vn",
    help="set number of videos of the training set to use for validation",
    type=int,
    default=1,
)
parser.add_argument(
    "--variational",
    "-v",
    help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
    default=True,
    type=str2bool,
)
parser.add_argument(
    "--window-size",
    "-ws",
    help="Sets the sliding window size to be used when building both training and validation sets. Defaults to 15",
    type=int,
    default=15,
)
parser.add_argument(
    "--window-step",
    "-wt",
    help="Sets the sliding window step to be used when building both training and validation sets. Defaults to 5",
    type=int,
    default=5,
)
lucas_miranda's avatar
lucas_miranda committed
185
186

args = parser.parse_args()

# Unpack command-line arguments into module-level settings used by the rest
# of the script.
animal_id = args.animal_id
arena_dims = args.arena_dims
batch_size = args.batch_size
hypertun_trials = args.hpt_trials
exclude_bodyparts = list(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
loss = args.loss
mmd_wu = args.mmd_warmup
overlap_loss = args.overlap_loss
pheno_class = float(args.phenotype_classifier)
predictor = float(args.predictor)
runs = args.stability_check
smooth_alpha = args.smooth_alpha
tune = args.hyperparameter_tuning
val_num = args.val_num
variational = bool(args.variational)
window_size = args.window_size
window_step = args.window_step

# Validate --train-path BEFORE resolving it: os.path.abspath(None) raises an
# opaque TypeError when the flag is missing, masking the intended message.
if not args.train_path:
    raise ValueError("Set a valid data path for the training to run")
train_path = os.path.abspath(args.train_path)

if not val_num:
    raise ValueError(
        "Set a valid data path / validation number for the validation to run"
    )

# Explicit exception instead of `assert`, which is stripped under python -O
# and must not be used for input validation.
if input_type not in [
    "coords",
    "dists",
    "angles",
    "coords+dist",
    "coords+angle",
    "dists+angle",
    "coords+dist+angle",
]:
    raise ValueError("Invalid input type. Type python model_training.py -h for help.")

229
# Loads model hyperparameters and treatment conditions, if available
hparams = load_hparams(hparams)
treatment_dict = load_treatments(train_path)

# Build the deepof project that indexes the raw videos and tracking tables.
# NOTE(review): assumes .h5 tracking tables and .mp4 videos live under
# train_path — confirm against the data directory layout.
# noinspection PyTypeChecker
project_coords = project(
    animal_ids=tuple([animal_id]),
    arena="circular",
    arena_dims=tuple([arena_dims]),
    exclude_bodyparts=exclude_bodyparts,
    exp_conditions=treatment_dict,
    path=train_path,
    smooth_alpha=smooth_alpha,
    table_format=".h5",
    video_format=".mp4",
)

# Restrict the analysis to a single animal when an id was provided
if animal_id:
    project_coords.subset_condition = animal_id

project_coords = project_coords.run(verbose=True)
# Separator used below when prefixing body part names with the animal id
# (empty id -> no underscore prefix)
undercond = "" if animal_id == "" else "_"
lucas_miranda's avatar
lucas_miranda committed
251
252

# Coordinates for training data, centered on "<id>_Center" and aligned to
# "<id>_Spine_1". NOTE(review): body part names assume the deepof labelling
# scheme — confirm they exist when --exclude-bodyparts is used.
coords = project_coords.get_coords(
    center=animal_id + undercond + "Center",
    align=animal_id + undercond + "Spine_1",
    align_inplace=True,
)
# Alternative input representations offered on the command line
distances = project_coords.get_distances()
angles = project_coords.get_angles()
# Pre-merged tables for the combined input types (coords+dist, etc.)
coords_distances = merge_tables(coords, distances)
coords_angles = merge_tables(coords, angles)
dists_angles = merge_tables(distances, angles)
coords_dist_angles = merge_tables(coords, distances, angles)
lucas_miranda's avatar
lucas_miranda committed
264
265


266
267
268
269
270
271
def batch_preprocess(tab_dict):
    """Preprocess the given table_dict into train/validation arrays.

    Windowing, filtering and split settings come from the module-level CLI
    arguments parsed above; scaling is fixed to per-feature standardisation.
    """
    preprocessing_options = {
        "window_size": window_size,
        "window_step": window_step,
        "scale": "standard",
        "conv_filter": gaussian_filter,
        "sigma": 1,
        "test_videos": val_num,
        "shuffle": True,
    }
    return tab_dict.preprocess(**preprocessing_options)


# Maps each CLI --input-type keyword to the corresponding feature table
input_dict_train = {
    "coords": coords,
    "dists": distances,
    "angles": angles,
    "coords+dist": coords_distances,
    "coords+angle": coords_angles,
    "dists+angle": dists_angles,
    "coords+dist+angle": coords_dist_angles,
}

# Get training and validation sets
print("Preprocessing data...")
X_train, y_train, X_val, y_val = batch_preprocess(input_dict_train[input_type])

print("Training set shape:", X_train.shape)
print("Validation set shape:", X_val.shape)
# Labels are only relevant when the phenotype classification branch is active
if pheno_class > 0:
    print("Training set label shape:", y_train.shape)
    print("Validation set label shape:", y_val.shape)

print("Done!")
lucas_miranda's avatar
lucas_miranda committed
301

302
303
304
# Proceed with training mode. Fit autoencoder with the same parameters,
# as many times as specified by runs
if not tune:

    # Training loop: one independent fit per requested stability-check run
    for run in range(runs):

        # To avoid stability issues
        tf.keras.backend.clear_session()

        # Run-specific callbacks: run identifier, TensorBoard logging,
        # one-cycle LR schedule and model checkpointing (the True flag
        # presumably enables checkpointing — confirm in get_callbacks)
        run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
            X_train,
            batch_size,
            True,
            variational,
            predictor,
            loss,
        )

        if not variational:
            # Plain sequence-to-sequence autoencoder branch
            encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape)
            print(ae.summary())

            # Save initial (epoch 0) weights so the checkpoint series is complete
            ae.save_weights("./logs/checkpoints/cp-{epoch:04d}.ckpt".format(epoch=0))
            # Fit the specified model to the data
            history = ae.fit(
                x=X_train,
                y=X_train,
                epochs=25,
                batch_size=batch_size,
                verbose=1,
                validation_data=(X_val, X_val),
                callbacks=[
                    tensorboard_callback,
                    cp_callback,
                    onecycle,
                    tf.keras.callbacks.EarlyStopping(
                        "val_mae", patience=15, restore_best_weights=True
                    ),
                ],
            )

            ae.save_weights("{}_final_weights.h5".format(run_ID))

        else:
            # Variational branch: Gaussian-mixture VAE with optional
            # prediction and phenotype-classification heads
            (
                encoder,
                generator,
                grouper,
                gmvaep,
                kl_warmup_callback,
                mmd_warmup_callback,
            ) = SEQ_2_SEQ_GMVAE(
                architecture_hparams=hparams,
                batch_size=batch_size,
                compile_model=True,
                kl_warmup_epochs=kl_wu,
                loss=loss,
                mmd_warmup_epochs=mmd_wu,
                number_of_components=k,
                overlap_loss=overlap_loss,
                phenotype_prediction=pheno_class,
                predictor=predictor,
            ).build(
                X_train.shape
            )
            print(gmvaep.summary())

            callbacks_ = [
                tensorboard_callback,
                cp_callback,
                onecycle,
                tf.keras.callbacks.EarlyStopping(
                    "val_mae", patience=15, restore_best_weights=True
                ),
            ]

            # Warm-up callbacks only apply when the matching loss term is
            # active and a non-zero warm-up period was requested
            if "ELBO" in loss and kl_wu > 0:
                callbacks_.append(kl_warmup_callback)
            if "MMD" in loss and mmd_wu > 0:
                callbacks_.append(mmd_warmup_callback)

            # Default targets: plain reconstruction (input == output)
            Xs, ys = [X_train], [X_train]
            Xvals, yvals = [X_val], [X_val]

            # With the predictor head, the model also predicts the next
            # window, so inputs/targets are shifted by one position
            if predictor > 0.0:
                Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
                Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]

            # With the phenotype head, the labels become an extra target
            if pheno_class > 0.0:
                ys += [y_train]
                yvals += [y_val]

            history = gmvaep.fit(
                x=Xs,
                y=ys,
                epochs=250,
                batch_size=batch_size,
                verbose=1,
                validation_data=(
                    Xvals,
                    yvals,
                ),
                callbacks=callbacks_,
            )

            gmvaep.save_weights("{}_final_weights.h5".format(run_ID))

        # To avoid stability issues
        tf.keras.backend.clear_session()

else:
    # Runs hyperparameter tuning with the specified parameters and saves the results

    hyp = "S2SGMVAE" if variational else "S2SAE"

    # No checkpoint callback here (third argument False) — tuning trials
    # are not checkpointed
    run_ID, tensorboard_callback, onecycle = get_callbacks(
        X_train, batch_size, False, variational, predictor, loss
    )

    # `tune` holds the tuner name ('bayopt' or 'hyperband') at this point
    best_hyperparameters, best_model = tune_search(
        data=[X_train, y_train, X_val, y_val],
        hypertun_trials=hypertun_trials,
        hpt_type=tune,
        hypermodel=hyp,
        k=k,
        kl_warmup_epochs=kl_wu,
        loss=loss,
        mmd_warmup_epochs=mmd_wu,
        overlap_loss=overlap_loss,
        pheno_class=pheno_class,
        predictor=predictor,
        project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
        callbacks=[
            tensorboard_callback,
            onecycle,
            tf.keras.callbacks.EarlyStopping(
                "val_mae", patience=15, restore_best_weights=True
            ),
        ],
        n_replicas=1,
        n_epochs=30,
    )

    # Saves a compiled, untrained version of the best model
    best_model.build(X_train.shape)
    best_model.save(
        "{}-based_{}_{}.h5".format(input_type, hyp, tune.capitalize()), save_format="tf"
    )

    # Saves the best hyperparameters
    with open(
        "{}-based_{}_{}_params.pickle".format(input_type, hyp, tune.capitalize()), "wb"
    ) as handle:
        pickle.dump(
            best_hyperparameters.values, handle, protocol=pickle.HIGHEST_PROTOCOL
        )
lucas_miranda's avatar
lucas_miranda committed
459
460
461

# TODO:
#    - Investigate how gaussian filters affect reproducibility (in a systematic way)