Commit b08f926c authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for preprocess.py

parent 62b5debb
......@@ -556,9 +556,9 @@ class table_dict(dict):
scale="standard",
test_videos=0,
verbose=False,
filter=None,
sigma=None,
shift=0,
conv_filter=None,
sigma=1.0,
shift=0.0,
shuffle=False,
align=False,
):
......@@ -586,7 +586,7 @@ class table_dict(dict):
if scale == "standard":
assert np.allclose(np.mean(X_train), 0)
assert np.allclose(np.std(X_train, ddof=1), 1)
assert np.allclose(np.std(X_train), 1)
if test_videos:
X_test = scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
......@@ -604,7 +604,7 @@ class table_dict(dict):
if align == "center":
X_train = align_trajectories(X_train, align)
if filter == "gaussian":
if conv_filter == "gaussian":
r = range(-int(window_size / 2), int(window_size / 2) + 1)
r = [i - shift for i in r]
g = np.array(
......@@ -628,7 +628,7 @@ class table_dict(dict):
if align == "center":
X_test = align_trajectories(X_test, align)
if filter == "gaussian":
if conv_filter == "gaussian":
X_test = X_test * g.reshape(1, window_size, 1)
if shuffle:
......
......@@ -170,43 +170,43 @@ coords_dist_angles2 = merge_tables(coords2, distances2, angles2)
input_dict_train = {
"coords": coords1.preprocess(
window_size=11, window_step=10, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=10, scale=True, random_state=42, conv_filter="gauss"
),
"dists": distances1.preprocess(
window_size=11, window_step=10, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=10, scale=True, random_state=42, conv_filter="gauss"
),
"angles": angles1.preprocess(
window_size=11, window_step=10, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=10, scale=True, random_state=42, conv_filter="gauss"
),
"coords+dist": coords_distances1.preprocess(
window_size=11, window_step=10, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=10, scale=True, random_state=42, conv_filter="gauss"
),
"coords+angle": coords_angles1.preprocess(
window_size=11, window_step=10, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=10, scale=True, random_state=42, conv_filter="gauss"
),
"coords+dist+angle": coords_dist_angles1.preprocess(
window_size=11, window_step=10, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=10, scale=True, random_state=42, conv_filter="gauss"
),
}
input_dict_val = {
"coords": coords2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=1, scale=True, random_state=42, conv_filter="gauss"
),
"dists": distances2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=1, scale=True, random_state=42, conv_filter="gauss"
),
"angles": angles2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=1, scale=True, random_state=42, conv_filter="gauss"
),
"coords+dist": coords_distances2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=1, scale=True, random_state=42, conv_filter="gauss"
),
"coords+angle": coords_angles2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=1, scale=True, random_state=42, conv_filter="gauss"
),
"coords+dist+angle": coords_dist_angles2.preprocess(
window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
window_size=11, window_step=1, scale=True, random_state=42, conv_filter="gauss"
),
}
......
......@@ -202,7 +202,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=False,
align="center",
......@@ -213,7 +213,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=False,
align="center",
......@@ -224,7 +224,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=False,
align="center",
......@@ -235,7 +235,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=False,
align="center",
......@@ -246,7 +246,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=False,
align="center",
......@@ -257,7 +257,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=False,
align="center",
......@@ -268,7 +268,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=False,
align="center",
......
......@@ -144,13 +144,15 @@ def test_get_table_dicts(nodes, ego, sampler):
table_format=".h5",
distances=nodes,
ego=ego,
).run(verbose=True)
).run(verbose=False)
algn = sampler.draw(st.one_of(st.just(False), st.just("Nose")))
coords = prun.get_coords(
center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))),
polar=sampler.draw(st.booleans()),
length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))),
align=sampler.draw(st.one_of(st.just(False), st.just("Nose"))),
align=algn
)
speeds = prun.get_coords(
center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))),
......@@ -182,7 +184,32 @@ def test_get_table_dicts(nodes, ego, sampler):
# deepof.table_dict testing
table = sampler.draw(
st.one_of(st.just(coords), st.just(speeds), st.just(distances), st.just(angles))
st.one_of(
st.just(coords), st.just(speeds), st.just(distances), st.just(angles)
),
st.just(deepof.preprocess.merge_tables(coords, speeds, distances, angles)),
)
#table.filter()
\ No newline at end of file
assert table.filter(["test"]) == table
tset = table.get_training_set(
test_videos=sampler.draw(st.integers(min_value=0, max_value=len(table) - 1))
)
assert len(tset) == 2
assert type(tset[0]) == np.ndarray
table.preprocess(
window_size=11,
window_step=1,
scale=sampler.draw(st.one_of(st.just("standard"), st.just("minmax"))),
test_videos=sampler.draw(st.integers(min_value=0, max_value=len(table) - 1)),
verbose=True,
conv_filter=sampler.draw(st.one_of(st.just(None), st.just("gaussian"))),
sigma=sampler.draw(st.floats(min_value=0.5, max_value=5.0)),
shift=sampler.draw(st.floats(min_value=-1.0, max_value=1.0)),
shuffle=sampler.draw(st.booleans()),
align=(
sampler.draw(st.one_of(st.just(False), st.just("all"), st.just("center")))
if (table._type == "coords" and algn == "Nose")
else False
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment