From 952893f8357cee1d71ea670ec7df7d7446498f84 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Wed, 16 Sep 2020 18:51:33 +0200 Subject: [PATCH] Added tests for preprocess.py --- .gitlab-ci.yml | 2 +- deepof/utils.py | 2 ++ .../train_viz_data_generator.py | 14 ++++++------- tests/test_preprocess.py | 21 ++++++++++++------- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ed7c3c14..f6488211 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -9,7 +9,7 @@ test: - pip install -r ./deepof/requirements.txt - pip install -e deepof/ - coverage run --source deepof -m pytest - - coverage report -m --omit deepof/setup.py + - coverage report -m --include deepof/utils.py,deepof/preprocess.py,deepof/model_utils.py,deepof/visuals.py - coverage xml -o deepof_cov.xml artifacts: reports: diff --git a/deepof/utils.py b/deepof/utils.py index 3f7ec333..26186ac9 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -202,6 +202,8 @@ def align_trajectories(data: np.array, mode: str = "all") -> np.array: Returns: - aligned_trajs (2D np.array): aligned positions over time""" + print(data.shape, mode) + angles = np.zeros(data.shape[0]) data = deepcopy(data) dshape = data.shape diff --git a/examples/visualizations/train_viz_data_generator.py b/examples/visualizations/train_viz_data_generator.py index 05b29698..9cf3f3c2 100644 --- a/examples/visualizations/train_viz_data_generator.py +++ b/examples/visualizations/train_viz_data_generator.py @@ -236,7 +236,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -245,7 +245,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -254,7 +254,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -263,7 +263,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -272,7 +272,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -281,7 +281,7 @@ input_dict = { window_size=13, window_step=10, scale="standard", - filter=None, + conv_filter=None, sigma=55, align="center", ), @@ -289,7 +289,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 5d790a96..388ccb0b 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -147,18 +147,20 @@ def test_get_table_dicts(nodes, ego, sampler): ).run(verbose=False) algn = sampler.draw(st.one_of(st.just(False), st.just("Nose"))) + polar = sampler.draw(st.booleans()) + speed = sampler.draw(st.integers(min_value=0, max_value=5)) coords = prun.get_coords( center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))), - polar=sampler.draw(st.booleans()), + polar=polar, length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), - align=algn + align=algn, ) speeds = prun.get_coords( center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))), polar=sampler.draw(st.booleans()), length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), - speed=sampler.draw(st.integers(min_value=0, max_value=5)), + speed=speed, ) distances = prun.get_distances( length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), @@ -197,6 +199,13 @@ def test_get_table_dicts(nodes, ego, sampler): assert len(tset) == 2 assert type(tset[0]) == np.ndarray + if table._type == "coords" and algn == "Nose" and polar is False and speed == 0: + align = sampler.draw( + st.one_of(st.just(False), st.just("all"), st.just("center")) + ) + else: + align = False + table.preprocess( window_size=11, window_step=1, @@ -207,9 +216,5 @@ def test_get_table_dicts(nodes, ego, sampler): sigma=sampler.draw(st.floats(min_value=0.5, max_value=5.0)), shift=sampler.draw(st.floats(min_value=-1.0, max_value=1.0)), shuffle=sampler.draw(st.booleans()), - align=( - sampler.draw(st.one_of(st.just(False), st.just("all"), st.just("center"))) - if (table._type == "coords" and algn == "Nose") - else False - ), + align=align, ) -- GitLab