diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ed7c3c14e09984f82b3680b3a9dd326f17a2f05d..f64882113bf56aefaafb0e79848f2d6c965038b0 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -9,7 +9,7 @@ test: - pip install -r ./deepof/requirements.txt - pip install -e deepof/ - coverage run --source deepof -m pytest - - coverage report -m --omit deepof/setup.py + - coverage report -m --include deepof/utils.py,deepof/preprocess.py,deepof/model_utils.py,deepof/visuals.py - coverage xml -o deepof_cov.xml artifacts: reports: diff --git a/deepof/utils.py b/deepof/utils.py index 3f7ec33302965a5e6aaa84437cbebccb1b62d5cb..26186ac9ce41a1892dfc6217ef36cfd1dd51c9ae 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -202,6 +202,8 @@ def align_trajectories(data: np.array, mode: str = "all") -> np.array: Returns: - aligned_trajs (2D np.array): aligned positions over time""" + print(data.shape, mode) + angles = np.zeros(data.shape[0]) data = deepcopy(data) dshape = data.shape diff --git a/examples/visualizations/train_viz_data_generator.py b/examples/visualizations/train_viz_data_generator.py index 05b296986ade6555510f0fbace9c56e726c22101..9cf3f3c2f4347be2cea0b8de3df683562c13c72c 100644 --- a/examples/visualizations/train_viz_data_generator.py +++ b/examples/visualizations/train_viz_data_generator.py @@ -236,7 +236,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -245,7 +245,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -254,7 +254,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -263,7 +263,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -272,7 +272,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", @@ -281,7 +281,7 @@ input_dict = { window_size=13, window_step=10, scale="standard", - filter=None, + conv_filter=None, sigma=55, align="center", ), @@ -289,7 +289,7 @@ input_dict = { window_size=13, window_step=1, scale="standard", - filter=None, + conv_filter=None, sigma=55, shuffle=True, align="center", diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 5d790a9615b54187c7b42f0ff292ad217621d646..388ccb0bd107b65b28439bca56dc247c31647c33 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -147,18 +147,20 @@ def test_get_table_dicts(nodes, ego, sampler): ).run(verbose=False) algn = sampler.draw(st.one_of(st.just(False), st.just("Nose"))) + polar = sampler.draw(st.booleans()) + speed = sampler.draw(st.integers(min_value=0, max_value=5)) coords = prun.get_coords( center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))), - polar=sampler.draw(st.booleans()), + polar=polar, length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), - align=algn + align=algn, ) speeds = prun.get_coords( center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))), polar=sampler.draw(st.booleans()), length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), - speed=sampler.draw(st.integers(min_value=0, max_value=5)), + speed=speed, ) distances = prun.get_distances( length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), @@ -197,6 +199,13 @@ def test_get_table_dicts(nodes, ego, sampler): assert len(tset) == 2 assert type(tset[0]) == np.ndarray + if table._type == "coords" and algn == "Nose" and polar is False and speed == 0: + align = sampler.draw( + st.one_of(st.just(False), st.just("all"), st.just("center")) + ) + else: + align = False + table.preprocess( window_size=11, window_step=1, @@ -207,9 +216,5 @@ def test_get_table_dicts(nodes, ego, sampler): sigma=sampler.draw(st.floats(min_value=0.5, max_value=5.0)), shift=sampler.draw(st.floats(min_value=-1.0, max_value=1.0)), shuffle=sampler.draw(st.booleans()), - align=( - sampler.draw(st.one_of(st.just(False), st.just("all"), st.just("center"))) - if (table._type == "coords" and algn == "Nose") - else False - ), + align=align, )