Commit 952893f8 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for preprocess.py

parent b08f926c
......@@ -9,7 +9,7 @@ test:
- pip install -r ./deepof/requirements.txt
- pip install -e deepof/
- coverage run --source deepof -m pytest
- coverage report -m --omit deepof/setup.py
- coverage report -m --include deepof/utils.py,deepof/preprocess.py,deepof/model_utils.py,deepof/visuals.py
- coverage xml -o deepof_cov.xml
artifacts:
reports:
......
......@@ -202,6 +202,8 @@ def align_trajectories(data: np.array, mode: str = "all") -> np.array:
Returns:
- aligned_trajs (2D np.array): aligned positions over time"""
print(data.shape, mode)
angles = np.zeros(data.shape[0])
data = deepcopy(data)
dshape = data.shape
......
......@@ -236,7 +236,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
......@@ -245,7 +245,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
......@@ -254,7 +254,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
......@@ -263,7 +263,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
......@@ -272,7 +272,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
......@@ -281,7 +281,7 @@ input_dict = {
window_size=13,
window_step=10,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
align="center",
),
......@@ -289,7 +289,7 @@ input_dict = {
window_size=13,
window_step=1,
scale="standard",
filter=None,
conv_filter=None,
sigma=55,
shuffle=True,
align="center",
......
......@@ -147,18 +147,20 @@ def test_get_table_dicts(nodes, ego, sampler):
).run(verbose=False)
algn = sampler.draw(st.one_of(st.just(False), st.just("Nose")))
polar = sampler.draw(st.booleans())
speed = sampler.draw(st.integers(min_value=0, max_value=5))
coords = prun.get_coords(
center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))),
polar=sampler.draw(st.booleans()),
polar=polar,
length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))),
align=algn
align=algn,
)
speeds = prun.get_coords(
center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))),
polar=sampler.draw(st.booleans()),
length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))),
speed=sampler.draw(st.integers(min_value=0, max_value=5)),
speed=speed,
)
distances = prun.get_distances(
length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))),
......@@ -197,6 +199,13 @@ def test_get_table_dicts(nodes, ego, sampler):
assert len(tset) == 2
assert type(tset[0]) == np.ndarray
if table._type == "coords" and algn == "Nose" and polar is False and speed == 0:
align = sampler.draw(
st.one_of(st.just(False), st.just("all"), st.just("center"))
)
else:
align = False
table.preprocess(
window_size=11,
window_step=1,
......@@ -207,9 +216,5 @@ def test_get_table_dicts(nodes, ego, sampler):
sigma=sampler.draw(st.floats(min_value=0.5, max_value=5.0)),
shift=sampler.draw(st.floats(min_value=-1.0, max_value=1.0)),
shuffle=sampler.draw(st.booleans()),
align=(
sampler.draw(st.one_of(st.just(False), st.just("all"), st.just("center")))
if (table._type == "coords" and algn == "Nose")
else False
),
align=align,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment