Commit 952893f8 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for preprocess.py

parent b08f926c
...@@ -9,7 +9,7 @@ test: ...@@ -9,7 +9,7 @@ test:
- pip install -r ./deepof/requirements.txt - pip install -r ./deepof/requirements.txt
- pip install -e deepof/ - pip install -e deepof/
- coverage run --source deepof -m pytest - coverage run --source deepof -m pytest
- coverage report -m --omit deepof/setup.py - coverage report -m --include deepof/utils.py,deepof/preprocess.py,deepof/model_utils.py,deepof/visuals.py
- coverage xml -o deepof_cov.xml - coverage xml -o deepof_cov.xml
artifacts: artifacts:
reports: reports:
......
...@@ -202,6 +202,8 @@ def align_trajectories(data: np.array, mode: str = "all") -> np.array: ...@@ -202,6 +202,8 @@ def align_trajectories(data: np.array, mode: str = "all") -> np.array:
Returns: Returns:
- aligned_trajs (2D np.array): aligned positions over time""" - aligned_trajs (2D np.array): aligned positions over time"""
print(data.shape, mode)
angles = np.zeros(data.shape[0]) angles = np.zeros(data.shape[0])
data = deepcopy(data) data = deepcopy(data)
dshape = data.shape dshape = data.shape
......
...@@ -236,7 +236,7 @@ input_dict = { ...@@ -236,7 +236,7 @@ input_dict = {
window_size=13, window_size=13,
window_step=1, window_step=1,
scale="standard", scale="standard",
filter=None, conv_filter=None,
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align="center", align="center",
...@@ -245,7 +245,7 @@ input_dict = { ...@@ -245,7 +245,7 @@ input_dict = {
window_size=13, window_size=13,
window_step=1, window_step=1,
scale="standard", scale="standard",
filter=None, conv_filter=None,
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align="center", align="center",
...@@ -254,7 +254,7 @@ input_dict = { ...@@ -254,7 +254,7 @@ input_dict = {
window_size=13, window_size=13,
window_step=1, window_step=1,
scale="standard", scale="standard",
filter=None, conv_filter=None,
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align="center", align="center",
...@@ -263,7 +263,7 @@ input_dict = { ...@@ -263,7 +263,7 @@ input_dict = {
window_size=13, window_size=13,
window_step=1, window_step=1,
scale="standard", scale="standard",
filter=None, conv_filter=None,
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align="center", align="center",
...@@ -272,7 +272,7 @@ input_dict = { ...@@ -272,7 +272,7 @@ input_dict = {
window_size=13, window_size=13,
window_step=1, window_step=1,
scale="standard", scale="standard",
filter=None, conv_filter=None,
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align="center", align="center",
...@@ -281,7 +281,7 @@ input_dict = { ...@@ -281,7 +281,7 @@ input_dict = {
window_size=13, window_size=13,
window_step=10, window_step=10,
scale="standard", scale="standard",
filter=None, conv_filter=None,
sigma=55, sigma=55,
align="center", align="center",
), ),
...@@ -289,7 +289,7 @@ input_dict = { ...@@ -289,7 +289,7 @@ input_dict = {
window_size=13, window_size=13,
window_step=1, window_step=1,
scale="standard", scale="standard",
filter=None, conv_filter=None,
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align="center", align="center",
......
...@@ -147,18 +147,20 @@ def test_get_table_dicts(nodes, ego, sampler): ...@@ -147,18 +147,20 @@ def test_get_table_dicts(nodes, ego, sampler):
).run(verbose=False) ).run(verbose=False)
algn = sampler.draw(st.one_of(st.just(False), st.just("Nose"))) algn = sampler.draw(st.one_of(st.just(False), st.just("Nose")))
polar = sampler.draw(st.booleans())
speed = sampler.draw(st.integers(min_value=0, max_value=5))
coords = prun.get_coords( coords = prun.get_coords(
center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))), center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))),
polar=sampler.draw(st.booleans()), polar=polar,
length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))),
align=algn align=algn,
) )
speeds = prun.get_coords( speeds = prun.get_coords(
center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))), center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))),
polar=sampler.draw(st.booleans()), polar=sampler.draw(st.booleans()),
length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))),
speed=sampler.draw(st.integers(min_value=0, max_value=5)), speed=speed,
) )
distances = prun.get_distances( distances = prun.get_distances(
length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))), length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))),
...@@ -197,6 +199,13 @@ def test_get_table_dicts(nodes, ego, sampler): ...@@ -197,6 +199,13 @@ def test_get_table_dicts(nodes, ego, sampler):
assert len(tset) == 2 assert len(tset) == 2
assert type(tset[0]) == np.ndarray assert type(tset[0]) == np.ndarray
if table._type == "coords" and algn == "Nose" and polar is False and speed == 0:
align = sampler.draw(
st.one_of(st.just(False), st.just("all"), st.just("center"))
)
else:
align = False
table.preprocess( table.preprocess(
window_size=11, window_size=11,
window_step=1, window_step=1,
...@@ -207,9 +216,5 @@ def test_get_table_dicts(nodes, ego, sampler): ...@@ -207,9 +216,5 @@ def test_get_table_dicts(nodes, ego, sampler):
sigma=sampler.draw(st.floats(min_value=0.5, max_value=5.0)), sigma=sampler.draw(st.floats(min_value=0.5, max_value=5.0)),
shift=sampler.draw(st.floats(min_value=-1.0, max_value=1.0)), shift=sampler.draw(st.floats(min_value=-1.0, max_value=1.0)),
shuffle=sampler.draw(st.booleans()), shuffle=sampler.draw(st.booleans()),
align=( align=align,
sampler.draw(st.one_of(st.just(False), st.just("all"), st.just("center")))
if (table._type == "coords" and algn == "Nose")
else False
),
) )
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment