# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Testing module for deepof.train_utils

"""

import os

import numpy as np
import tensorflow as tf
from hypothesis import HealthCheck
from hypothesis import given
from hypothesis import settings
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays

import deepof.data
import deepof.model_utils
import deepof.train_utils


def test_load_treatments():
    """Check ``deepof.train_utils.load_treatments`` on two directories.

    The bare ``tests`` directory is expected to yield ``None``, while the
    ``Others`` folder of the single top-view example project is expected to
    yield a ``dict``.
    """
    assert deepof.train_utils.load_treatments("tests") is None

    treatments_dir = os.path.join(
        "tests", "test_examples", "test_single_topview", "Others"
    )
    assert isinstance(deepof.train_utils.load_treatments(treatments_dir), dict)


# get_callbacks builds TensorFlow/Keras objects, which can easily exceed
# Hypothesis' default per-example deadline; disable it (and the too_slow health
# check) as the other property tests in this module do, to avoid flaky failures.
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
    X_train=arrays(
        shape=st.tuples(st.integers(min_value=1, max_value=1000), st.just(24)),
        dtype=float,
        elements=st.floats(min_value=0.0, max_value=1),
    ),
    batch_size=st.integers(min_value=128, max_value=512),
    loss=st.sampled_from(["test_A", "test_B"]),
    next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
    phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
    rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
    overlap_loss=st.floats(min_value=0.0, max_value=1.0),
)
def test_get_callbacks(
    X_train,
    batch_size,
    next_sequence_prediction,
    phenotype_prediction,
    rule_based_prediction,
    overlap_loss,
    loss,
):
    """Property-test ``deepof.train_utils.get_callbacks`` over random inputs.

    For random 2D training data (24 columns), batch sizes, loss names and loss
    weights, the returned sequence must contain at least one string, at least
    one ``tf.keras.callbacks.ModelCheckpoint`` (presumably enabled by
    ``cp=True``) and at least one ``deepof.model_utils.one_cycle_scheduler``.
    """
    callbacks = deepof.train_utils.get_callbacks(
        X_train=X_train,
        batch_size=batch_size,
        phenotype_prediction=phenotype_prediction,
        next_sequence_prediction=next_sequence_prediction,
        rule_based_prediction=rule_based_prediction,
        overlap_loss=overlap_loss,
        loss=loss,
        X_val=X_train,
        input_type=False,
        cp=True,
        reg_cat_clusters=False,
        reg_cluster_variance=False,
        logparam={"encoding": 2, "k": 15},
    )

    assert any(isinstance(item, str) for item in callbacks)
    assert any(
        isinstance(item, tf.keras.callbacks.ModelCheckpoint) for item in callbacks
    )
    assert any(
        isinstance(item, deepof.model_utils.one_cycle_scheduler) for item in callbacks
    )


@settings(max_examples=15, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
    loss=st.sampled_from(["ELBO", "MMD", "ELBO+MMD"]),
    next_sequence_prediction=st.sampled_from([0.0, 1.0]),
    phenotype_prediction=st.sampled_from([0.0, 1.0]),
    rule_based_prediction=st.sampled_from([0.0, 1.0]),
)
def test_autoencoder_fitting(
    loss,
    next_sequence_prediction,
    phenotype_prediction,
    rule_based_prediction,
):
    """Smoke-test ``deep_unsupervised_embedding`` on tiny random data.

    Runs one training epoch of the object returned by
    ``deepof.data.project(...).run()`` for every loss and every on/off
    combination of the auxiliary prediction weights. The test only checks that
    training completes without raising.
    """
    n_samples = 20
    X_train = np.random.uniform(-1, 1, [n_samples, 5, 6])
    y_train = np.round(np.random.uniform(0, 1, [n_samples, 1]))

    if rule_based_prediction:
        # Append six extra label columns (rounded to one decimal) when the
        # rule-based prediction branch is active.
        y_train = np.concatenate(
            [y_train, np.round(np.random.uniform(0, 1, [n_samples, 6]), 1)], axis=1
        )

    if next_sequence_prediction:
        # Drop the first label row. NOTE(review): X_train keeps all rows, so
        # this presumably relies on the model shifting the sequence itself —
        # confirm against deep_unsupervised_embedding.
        y_train = y_train[1:]

    # (train inputs, train labels, validation inputs, validation labels)
    preprocessed_data = (X_train, y_train, X_train, y_train)

    prun = deepof.data.project(
        path=os.path.join(".", "tests", "test_examples", "test_single_topview"),
        arena="circular",
        arena_dims=(380,),
        video_format=".mp4",
    ).run()

    prun.deep_unsupervised_embedding(
        preprocessed_data,
        batch_size=10,
        encoding_size=2,
        epochs=1,
        kl_warmup=1,
        log_history=True,
        log_hparams=True,
        mmd_warmup=1,
        n_components=2,
        loss=loss,
        next_sequence_prediction=next_sequence_prediction,
        phenotype_prediction=phenotype_prediction,
        rule_based_prediction=rule_based_prediction,
        entropy_samples=10,
        entropy_knn=5,
    )


@settings(max_examples=5, deadline=None)
@given(
    X_train=arrays(
        dtype=float,
        shape=st.tuples(
            st.integers(min_value=128, max_value=512),
            st.integers(min_value=128, max_value=512),
            st.integers(min_value=2, max_value=10),
        ),
        elements=st.floats(min_value=0.0, max_value=1),
    ),
    y_train=st.data(),
    batch_size=st.just(128),
    encoding_size=st.integers(min_value=1, max_value=16),
    hpt_type=st.sampled_from(["bayopt", "hyperband"]),
    k=st.integers(min_value=1, max_value=10),
    loss=st.sampled_from(["ELBO", "MMD"]),
    overlap_loss=st.floats(min_value=0.0, max_value=1.0),
    next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
    phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
    rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
)
def test_tune_search(
    X_train,
    y_train,
    batch_size,
    encoding_size,
    hpt_type,
    k,
    loss,
    next_sequence_prediction,
    phenotype_prediction,
    rule_based_prediction,
    overlap_loss,
):
    """Run a single-trial ``deepof.train_utils.tune_search`` on random data.

    Covers both tuning backends (``bayopt`` / ``hyperband``), both losses and
    random loss weights, using callbacks built by
    ``deepof.train_utils.get_callbacks``. Only checks that the search runs
    without raising.

    ``y_train`` arrives as a Hypothesis ``st.data()`` handle and is replaced
    by a drawn label array.
    """
    # Round the loss weights once so get_callbacks and tune_search receive
    # exactly the same values.
    next_sequence_prediction = np.round(next_sequence_prediction, 2)
    phenotype_prediction = np.round(phenotype_prediction, 2)
    rule_based_prediction = np.round(rule_based_prediction, 2)

    # Drop the first returned element; presumably the run identifier string
    # rather than a Keras callback — confirm against get_callbacks.
    callbacks = list(
        deepof.train_utils.get_callbacks(
            X_train=X_train,
            batch_size=batch_size,
            phenotype_prediction=phenotype_prediction,
            next_sequence_prediction=next_sequence_prediction,
            rule_based_prediction=rule_based_prediction,
            loss=loss,
            X_val=X_train,
            input_type=False,
            cp=False,
            reg_cat_clusters=True,
            reg_cluster_variance=True,
            overlap_loss=overlap_loss,
            entropy_samples=10,
            entropy_knn=5,
            logparam={"encoding": 2, "k": 15},
        )
    )[1:]

    # One label per sample: the row count must match X_train.shape[0] (the
    # sample axis), as in test_autoencoder_fitting, not shape[1] (the window
    # length), which only coincides for the minimal 128x128 example.
    y_train = y_train.draw(
        arrays(
            dtype=np.float32,
            elements=st.floats(min_value=0.0, max_value=1.0, width=32),
            shape=(X_train.shape[0], 1),
        )
    )

    deepof.train_utils.tune_search(
        data=[X_train, y_train, X_train, y_train],
        encoding_size=encoding_size,
        hpt_type=hpt_type,
        hypertun_trials=1,
        k=k,
        kl_warmup_epochs=0,
        loss=loss,
        mmd_warmup_epochs=0,
        overlap_loss=overlap_loss,
        next_sequence_prediction=next_sequence_prediction,
        phenotype_prediction=phenotype_prediction,
        rule_based_prediction=rule_based_prediction,
        project_name="test_run",
        callbacks=callbacks,
        n_epochs=1,
    )