# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Simple utility functions used in deepof example scripts. These are not part of the main package.

"""

import json
import os
from datetime import date, datetime
from typing import Tuple, Union, Any, List

import numpy as np
import tensorflow as tf
from kerastuner import BayesianOptimization, Hyperband
from kerastuner_tensorboard_logger import TensorBoardLogger
from sklearn.metrics import roc_auc_score
from tensorboard.plugins.hparams import api as hp

import deepof.hypermodels
import deepof.model_utils
import deepof.models

# Ignore warnings with no downstream effect
tf.get_logger().setLevel("ERROR")
tf.autograph.set_verbosity(0)


class CustomStopper(tf.keras.callbacks.EarlyStopping):
    """Custom early stopping callback. Prevents the model from stopping before warmup is over"""

    def __init__(self, start_epoch, *args, **kwargs):
        super(CustomStopper, self).__init__(*args, **kwargs)
        self.start_epoch = start_epoch

    def get_config(self):  # pragma: no cover
        """Updates callback metadata"""

        config = super().get_config().copy()
        config.update({"start_epoch": self.start_epoch})
        return config

    def on_epoch_end(self, epoch, logs=None):
        if epoch > self.start_epoch:
            super().on_epoch_end(epoch, logs)
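
# Example usage (a hedged sketch; `model`, `X` and `y` are hypothetical
# placeholders, and the remaining kwargs are standard tf.keras EarlyStopping
# arguments forwarded by CustomStopper):
#
#     stopper = CustomStopper(
#         start_epoch=15,  # ignore the stopping criterion during warmup
#         monitor="val_loss",
#         patience=5,
#         restore_best_weights=True,
#     )
#     model.fit(X, y, epochs=100, validation_split=0.1, callbacks=[stopper])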


def load_treatments(train_path):
    """Loads a dictionary containing the treatments per individual,
    to be loaded as metadata in the coordinates class"""
    try:
        with open(
            os.path.join(
                train_path,
                [i for i in os.listdir(train_path) if i.endswith(".json")][0],
            ),
            "r",
        ) as handle:
            treatment_dict = json.load(handle)

    except IndexError:
        treatment_dict = None

    return treatment_dict
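
# Example usage (hedged sketch; assumes `train_path` contains at least one
# JSON file mapping experiment IDs to treatments):
#
#     treatments = load_treatments("./my_experiment/")
#     if treatments is None:
#         print("No treatment metadata found")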


def get_callbacks(
    X_train: np.ndarray,
    batch_size: int,
    variational: bool,
    phenotype_prediction: float,
    next_sequence_prediction: float,
    rule_based_prediction: float,
    loss: str,
    X_val: np.ndarray = None,
    cp: bool = False,
    reg_cat_clusters: bool = False,
    reg_cluster_variance: bool = False,
    entropy_samples: int = 15000,
    entropy_knn: int = 100,
    logparam: dict = None,
    outpath: str = ".",
    run: int = 0,
) -> List[Any]:
    """Generates callbacks for model training, including:
    - run_ID: run name, with coarse parameter details;
    - tensorboard_callback: for real-time visualization;
    - entropy: for monitoring latent-space entropy on the validation data;
    - onecycle: for learning rate scheduling;
    - cp_callback (if cp is True): for checkpoint saving"""

    latreg = "none"
    if reg_cat_clusters and not reg_cluster_variance:
        latreg = "categorical"
    elif reg_cluster_variance and not reg_cat_clusters:
        latreg = "variance"
    elif reg_cat_clusters and reg_cluster_variance:
        latreg = "categorical+variance"

    run_ID = "{}{}{}{}{}{}{}{}{}{}_{}".format(
        ("GMVAE" if variational else "AE"),
        (
            "_NextSeqPred={}".format(next_sequence_prediction)
            if next_sequence_prediction > 0 and variational
            else ""
        ),
        (
            "_PhenoPred={}".format(phenotype_prediction)
            if phenotype_prediction > 0
            else ""
        ),
        (
            "_RuleBasedPred={}".format(rule_based_prediction)
            if rule_based_prediction > 0
            else ""
        ),
        ("_loss={}".format(loss) if variational else ""),
        ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
        ("_k={}".format(logparam["k"]) if logparam is not None else ""),
        ("_latreg={}".format(latreg)),
        ("_entknn={}".format(entropy_knn)),
        ("_run={}".format(run) if run else ""),
        (datetime.now().strftime("%Y%m%d-%H%M%S")),
    )

    log_dir = os.path.abspath(os.path.join(outpath, "fit", run_ID))
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        profile_batch=2,
    )

    entropy = deepof.model_utils.neighbor_latent_entropy(
        encoding_dim=logparam["encoding"],
        k=entropy_knn,
        samples=entropy_samples,
        validation_data=X_val,
        log_dir=os.path.join(outpath, "metrics", run_ID),
        variational=variational,
    )

    onecycle = deepof.model_utils.one_cycle_scheduler(
        X_train.shape[0] // batch_size * 250,
        max_rate=0.005,
        log_dir=os.path.join(outpath, "metrics", run_ID),
    )

    callbacks = [run_ID, tensorboard_callback, entropy, onecycle]

    if cp:
        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(outpath, "checkpoints", run_ID + "/cp-{epoch:04d}.ckpt"),
            verbose=1,
            save_best_only=False,
            save_weights_only=True,
            save_freq="epoch",
        )
        callbacks.append(cp_callback)

    return callbacks
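
# Example usage (hedged sketch; shapes, weights and hyperparameters are
# illustrative only):
#
#     run_ID, tensorboard_cb, entropy_cb, onecycle_cb = get_callbacks(
#         X_train=np.random.normal(size=(1000, 15, 28)),
#         batch_size=256,
#         variational=True,
#         phenotype_prediction=0.0,
#         next_sequence_prediction=0.0,
#         rule_based_prediction=0.0,
#         loss="ELBO",
#         logparam={"encoding": 8, "k": 15},
#     )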


def log_hyperparameters(phenotype_class: float, rec: str):
    """Blueprint for hyperparameter and metric logging in tensorboard during hyperparameter tuning"""

    logparams = [
        hp.HParam(
            "encoding",
            hp.Discrete([2, 4, 6, 8, 12, 16]),
            display_name="encoding",
            description="encoding size dimensionality",
        ),
        hp.HParam(
            "k",
            hp.IntInterval(min_value=1, max_value=25),
            display_name="k",
            description="cluster_number",
        ),
        hp.HParam(
            "loss",
            hp.Discrete(["ELBO", "MMD", "ELBO+MMD"]),
            display_name="loss function",
            description="loss function",
        ),
    ]

    metrics = [
        hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
        hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
    ]
    if phenotype_class:
        logparams.append(
            hp.HParam(
                "pheno_weight",
                hp.RealInterval(min_value=0.0, max_value=1000.0),
                display_name="pheno weight",
                description="weight applied to phenotypic classifier from the latent space",
            )
        )
        metrics += [
            hp.Metric(
                "phenotype_prediction_accuracy",
                display_name="phenotype_prediction_accuracy",
            ),
            hp.Metric(
                "phenotype_prediction_auc",
                display_name="phenotype_prediction_auc",
            ),
        ]

    return logparams, metrics
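
# Example usage (hedged sketch): register the search space and metrics with
# the TensorBoard HParams plugin before a tuning run:
#
#     logparams, metrics = log_hyperparameters(phenotype_class=1.0, rec="")
#     with tf.summary.create_file_writer("./hparams_demo").as_default():
#         hp.hparams_config(hparams=logparams, metrics=metrics)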


# noinspection PyUnboundLocalVariable
def tensorboard_metric_logging(
    run_dir: str,
    hpms: Any,
    ae: Any,
    X_val: np.ndarray,
    y_val: np.ndarray,
    next_sequence_prediction: float,
    phenotype_prediction: float,
    rule_based_prediction: float,
    rec: str,
):
    """Autoencoder metric logging in tensorboard"""

    outputs = ae.predict(X_val)
    idx_generator = (idx for idx in range(len(outputs)))

    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(hpms)  # record the values used in this trial
        idx = next(idx_generator)

        val_mae = tf.reduce_mean(
            tf.keras.metrics.mean_absolute_error(y_val[idx], outputs[idx])
        )
        val_mse = tf.reduce_mean(
            tf.keras.metrics.mean_squared_error(y_val[idx], outputs[idx])
        )
        tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
        tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)

        if next_sequence_prediction:
            idx = next(idx_generator)
            pred_mae = tf.reduce_mean(
                tf.keras.metrics.mean_absolute_error(y_val[idx], outputs[idx])
            )
            pred_mse = tf.reduce_mean(
                tf.keras.metrics.mean_squared_error(y_val[idx], outputs[idx])
            )
            tf.summary.scalar("val_next_sequence_prediction_mae", pred_mae, step=1)
            tf.summary.scalar("val_next_sequence_prediction_mse", pred_mse, step=1)

        if phenotype_prediction:
            idx = next(idx_generator)
            pheno_acc = tf.keras.metrics.binary_accuracy(
                y_val[idx], tf.squeeze(outputs[idx])
            )
            pheno_auc = tf.keras.metrics.AUC()
            pheno_auc.update_state(y_val[idx], outputs[idx])
            pheno_auc = pheno_auc.result().numpy()

            tf.summary.scalar("phenotype_prediction_accuracy", pheno_acc, step=1)
            tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)

        if rule_based_prediction:
            idx = next(idx_generator)
            rules_mae = tf.reduce_mean(
                tf.keras.metrics.mean_absolute_error(y_val[idx], outputs[idx])
            )
            rules_mse = tf.reduce_mean(
                tf.keras.metrics.mean_squared_error(y_val[idx], outputs[idx])
            )
            tf.summary.scalar("val_prediction_mae", rules_mae, step=1)
            tf.summary.scalar("val_prediction_mse", rules_mse, step=1)
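
# Example usage (hedged sketch; a trained autoencoder `ae` with a single
# reconstruction head and a matching validation array `X_val` are assumed,
# mirroring the call made from autoencoder_fitting below):
#
#     tensorboard_metric_logging(
#         run_dir="./hparams/run_0",
#         hpms={"encoding": 8, "k": 15, "loss": "ELBO"},
#         ae=ae,
#         X_val=X_val,
#         y_val=[X_val],
#         next_sequence_prediction=0.0,
#         phenotype_prediction=0.0,
#         rule_based_prediction=0.0,
#         rec="",
#     )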


def autoencoder_fitting(
    preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    batch_size: int,
    encoding_size: int,
    epochs: int,
    hparams: dict,
    kl_warmup: int,
    log_history: bool,
    log_hparams: bool,
    loss: str,
    mmd_warmup: int,
    montecarlo_kl: int,
    n_components: int,
    output_path: str,
    next_sequence_prediction: float,
    phenotype_prediction: float,
    rule_based_prediction: float,
    pretrained: str,
    save_checkpoints: bool,
    save_weights: bool,
    variational: bool,
    reg_cat_clusters: bool,
    reg_cluster_variance: bool,
    entropy_samples: int,
    entropy_knn: int,
):
    """Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""

    # Load data
    X_train, y_train, X_val, y_val = preprocessed_object

    # To avoid stability issues
    tf.keras.backend.clear_session()

    # Defines what to log on tensorboard (useful for trying out different models)
    logparam = {
        "encoding": encoding_size,
        "k": n_components,
        "loss": loss,
    }
    if phenotype_prediction:
        logparam["pheno_weight"] = phenotype_prediction

    # Load callbacks
    run_ID, *cbacks = get_callbacks(
        X_train=X_train,
        X_val=(X_val if X_val.shape != (0,) else None),
        batch_size=batch_size,
        cp=save_checkpoints,
        variational=variational,
        next_sequence_prediction=next_sequence_prediction,
        phenotype_prediction=phenotype_prediction,
        rule_based_prediction=rule_based_prediction,
        loss=loss,
        entropy_samples=entropy_samples,
        entropy_knn=entropy_knn,
        reg_cat_clusters=reg_cat_clusters,
        reg_cluster_variance=reg_cluster_variance,
        logparam=logparam,
        outpath=output_path,
    )
    if not log_history:
        cbacks = cbacks[1:]

    # Logs hyperparameters to tensorboard
    rec = "reconstruction_" if phenotype_prediction else ""
    if log_hparams:
        logparams, metrics = log_hyperparameters(phenotype_prediction, rec)

        with tf.summary.create_file_writer(
            os.path.join(output_path, "hparams", run_ID)
        ).as_default():
            hp.hparams_config(
                hparams=logparams,
                metrics=metrics,
            )

    # Gets the number of rule-based features
    try:
        rule_based_features = (
            y_train.shape[1] if not phenotype_prediction else y_train.shape[1] - 1
        )
    except IndexError:
        rule_based_features = 0

    # Build models
    if not variational:
        encoder, decoder, ae = deepof.models.SEQ_2_SEQ_AE(
            ({} if hparams is None else hparams)
        ).build(X_train.shape)
        return_list = (encoder, decoder, ae)

    else:
        (
            encoder,
            generator,
            grouper,
            ae,
            prior,
            posterior,
        ) = deepof.models.SEQ_2_SEQ_GMVAE(
            architecture_hparams=({} if hparams is None else hparams),
            batch_size=batch_size,
            compile_model=True,
            encoding=encoding_size,
            kl_warmup_epochs=kl_warmup,
            loss=loss,
            mmd_warmup_epochs=mmd_warmup,
            montecarlo_kl=montecarlo_kl,
            neuron_control=False,
            number_of_components=n_components,
            overlap_loss=False,
            next_sequence_prediction=next_sequence_prediction,
            phenotype_prediction=phenotype_prediction,
            rule_based_prediction=rule_based_prediction,
            rule_based_features=rule_based_features,
            reg_cat_clusters=reg_cat_clusters,
            reg_cluster_variance=reg_cluster_variance,
        ).build(X_train.shape)
        return_list = (encoder, generator, grouper, ae)

    if pretrained:
        # If pretrained models are specified, load weights and return
        ae.load_weights(pretrained)
        return return_list

    else:
        if not variational:

            ae.fit(
                x=X_train,
                y=X_train,
                epochs=epochs,
                batch_size=batch_size,
                verbose=1,
                validation_data=(X_val, X_val),
                callbacks=cbacks
                + [
                    CustomStopper(
                        monitor="val_loss",
                        patience=5,
                        restore_best_weights=True,
                        start_epoch=max(kl_warmup, mmd_warmup),
                    ),
                ],
            )

            if save_weights:
                ae.save_weights("{}_final_weights.h5".format(run_ID))

        else:

            callbacks_ = cbacks + [
                CustomStopper(
                    monitor="val_loss",
                    patience=5,
                    restore_best_weights=True,
                    start_epoch=max(kl_warmup, mmd_warmup),
                ),
            ]

            Xs, ys = X_train, [X_train]
            Xvals, yvals = X_val, [X_val]

            if next_sequence_prediction > 0.0:
                Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
                Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]

            if phenotype_prediction > 0.0:
                ys += [y_train[-Xs.shape[0] :, 0]]
                yvals += [y_val[-Xvals.shape[0] :, 0]]

                # Remove the used column (phenotype) from both y arrays
                y_train = y_train[:, 1:]
                y_val = y_val[:, 1:]

            if rule_based_prediction > 0.0:
                ys += [y_train[-Xs.shape[0] :]]
                yvals += [y_val[-Xvals.shape[0] :]]

            ae.fit(
                x=Xs,
                y=ys,
                epochs=epochs,
                batch_size=batch_size,
                verbose=1,
                validation_data=(
                    Xvals,
                    yvals,
                ),
                callbacks=callbacks_,
            )

            if not os.path.exists(os.path.join(output_path, "trained_weights")):
                os.makedirs(os.path.join(output_path, "trained_weights"))

            if save_weights:
                ae.save_weights(
                    os.path.join(
                        output_path,
                        "trained_weights",
                        "{}_final_weights.h5".format(run_ID),
                    )
                )

            if log_hparams:
                # Logparams to tensorboard
                tensorboard_metric_logging(
                    run_dir=os.path.join(output_path, "hparams", run_ID),
                    hpms=logparam,
                    ae=ae,
                    X_val=Xvals,
                    y_val=yvals,
                    next_sequence_prediction=next_sequence_prediction,
                    phenotype_prediction=phenotype_prediction,
                    rule_based_prediction=rule_based_prediction,
                    rec=rec,
                )

    return return_list
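
# Example usage (hedged sketch; a preprocessed dataset of shape
# (samples, window, features) is assumed, and most flags are left at
# conservative values):
#
#     encoder, generator, grouper, ae = autoencoder_fitting(
#         preprocessed_object=(X_train, y_train, X_val, y_val),
#         batch_size=256,
#         encoding_size=8,
#         epochs=25,
#         hparams=None,
#         kl_warmup=10,
#         log_history=True,
#         log_hparams=False,
#         loss="ELBO",
#         mmd_warmup=10,
#         montecarlo_kl=10,
#         n_components=15,
#         output_path=".",
#         next_sequence_prediction=0.0,
#         phenotype_prediction=0.0,
#         rule_based_prediction=0.0,
#         pretrained="",
#         save_checkpoints=False,
#         save_weights=False,
#         variational=True,
#         reg_cat_clusters=False,
#         reg_cluster_variance=False,
#         entropy_samples=10000,
#         entropy_knn=100,
#     )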


def tune_search(
    data: List[np.ndarray],
    encoding_size: int,
    hypertun_trials: int,
    hpt_type: str,
    hypermodel: str,
    k: int,
    kl_warmup_epochs: int,
    loss: str,
    mmd_warmup_epochs: int,
    overlap_loss: float,
    next_sequence_prediction: float,
    phenotype_prediction: float,
    rule_based_prediction: float,
    project_name: str,
    callbacks: List,
    n_epochs: int = 30,
    n_replicas: int = 1,
    outpath: str = ".",
) -> Union[bool, Tuple[Any, Any]]:
    """Defines the search space using keras-tuner and either Bayesian Optimization or Hyperband

    Parameters:
        - data (List[np.ndarray]): train and validation data, packed as
        [X_train, y_train, X_val, y_val]
        - encoding_size (int): dimensionality of the latent space
        - hypertun_trials (int): number of hyperparameter tuning iterations to run
        - hpt_type (str): specify one of Bayesian Optimization (bayopt) and Hyperband (hyperband)
        - hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder)
        or S2SGMVAE (Gaussian Mixture Variational autoencoder)
        - k (int): number of components of the Gaussian Mixture
        - kl_warmup_epochs (int): number of epochs over which to anneal the KL divergence term
        - loss (str): one of [ELBO, MMD, ELBO+MMD]
        - mmd_warmup_epochs (int): number of epochs over which to anneal the MMD term
        - overlap_loss (float): weight of an extra loss term which
        penalizes overlap between GM components
        - next_sequence_prediction (float): adds an extra regularizing neural network to the model,
        which tries to predict the next frame from the current one
        - phenotype_prediction (float): adds an extra regularizing neural network to the model,
        which tries to predict the phenotype of the animal from which the sequence comes
        - rule_based_prediction (float): adds an extra regularizing neural network to the model,
        which tries to predict a set of rule-based annotations
        - project_name (str): ID of the current run
        - callbacks (list): list of callbacks for the training loop
        - n_epochs (int): optional. Number of epochs to train each run for
        - n_replicas (int): optional. Number of replicas per parameter set. Higher values
        will yield more robust results, but will affect performance severely
        - outpath (str): optional. Path in which to store tuning artifacts

    Returns:
        - best_hparams (dict): dictionary with the best retrieved hyperparameters
        - best_run (tf.keras.Model): instance of the best model, built with the best
        retrieved hyperparameters

    """

    X_train, y_train, X_val, y_val = data

    assert hpt_type in ["bayopt", "hyperband"], (
        "Invalid hyperparameter tuning framework. Select one of bayopt and hyperband"
    )

    if hypermodel == "S2SAE":  # pragma: no cover
        assert (
            next_sequence_prediction == 0.0 and phenotype_prediction == 0.0
        ), "Prediction branches are only available for variational models. See documentation for more details"
        batch_size = 1
        hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)

    elif hypermodel == "S2SGMVAE":
        batch_size = 64
        hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
            input_shape=X_train.shape,
            encoding=encoding_size,
            kl_warmup_epochs=kl_warmup_epochs,
            loss=loss,
            mmd_warmup_epochs=mmd_warmup_epochs,
            number_of_components=k,
            overlap_loss=overlap_loss,
            next_sequence_prediction=next_sequence_prediction,
            phenotype_prediction=phenotype_prediction,
            rule_based_prediction=rule_based_prediction,
            rule_based_features=(
                y_train.shape[1] if not phenotype_prediction else y_train.shape[1] - 1
            ),
        )

    else:
        return False

    hpt_params = {
        "hypermodel": hypermodel,
        "executions_per_trial": n_replicas,
        "logger": TensorBoardLogger(
            metrics=["val_mae"], logdir=os.path.join(outpath, "logged_hparams")
        ),
        "objective": "val_mae",
        "project_name": project_name,
        "tune_new_entries": True,
    }

    if hpt_type == "hyperband":
        tuner = Hyperband(
            directory=os.path.join(
                outpath, "HyperBandx_{}_{}".format(loss, str(date.today()))
            ),
            max_epochs=35,
            hyperband_iterations=hypertun_trials,
            factor=3,
            **hpt_params
        )
    else:
        tuner = BayesianOptimization(
            directory=os.path.join(
                outpath, "BayOpt_{}_{}".format(loss, str(date.today()))
            ),
            max_trials=hypertun_trials,
            **hpt_params
        )

    tuner.search_space_summary()

    Xs, ys = X_train, [X_train]
    Xvals, yvals = X_val, [X_val]

    if next_sequence_prediction > 0.0:
        Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
        Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]

    if phenotype_prediction > 0.0:
        ys += [y_train[-Xs.shape[0] :, 0]]
        yvals += [y_val[-Xvals.shape[0] :, 0]]

        # Remove the used column (phenotype) from both y arrays
        y_train = y_train[:, 1:]
        y_val = y_val[:, 1:]

    if rule_based_prediction > 0.0:
        ys += [y_train[-Xs.shape[0] :]]
        yvals += [y_val[-Xvals.shape[0] :]]

    tuner.search(
        Xs,
        ys,
        epochs=n_epochs,
        validation_data=(Xvals, yvals),
        verbose=1,
        batch_size=batch_size,
        callbacks=callbacks,
    )

    best_hparams = tuner.get_best_hyperparameters(num_trials=1)[0]
    best_run = tuner.hypermodel.build(best_hparams)

    tuner.results_summary()

    return best_hparams, best_run
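
# Example usage (hedged sketch; the callback list can come from get_callbacks
# above, dropping the leading run_ID string):
#
#     best_hparams, best_run = tune_search(
#         data=[X_train, y_train, X_val, y_val],
#         encoding_size=8,
#         hypertun_trials=10,
#         hpt_type="bayopt",
#         hypermodel="S2SGMVAE",
#         k=15,
#         kl_warmup_epochs=10,
#         loss="ELBO",
#         mmd_warmup_epochs=10,
#         overlap_loss=0.0,
#         next_sequence_prediction=0.0,
#         phenotype_prediction=0.0,
#         rule_based_prediction=0.0,
#         project_name="deepof_tune_demo",
#         callbacks=[tensorboard_cb, onecycle_cb],  # e.g. from get_callbacks
#     )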