Commit 2158e4b7 authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored rule_based_tagging and fixed type annotation problems

parent 1e83611f
......@@ -12,6 +12,7 @@ import warnings
import networkx as nx
from deepof.utils import *
from deepof.visuals import *
class project:
......
This diff is collapsed.
......@@ -11,7 +11,9 @@ from typing import List, Dict
# PLOTTING FUNCTIONS #
def plot_speed(behaviour_dict: dict, treatments: Dict[List]) -> plt.figure:
def plot_speed(
behaviour_dict: Dict[str, pd.DataFrame], treatments: Dict[str, List]
) -> plt.figure:
"""Plots a histogram with the speed of the specified mouse.
Treatments is expected to be a list of lists with mice keys per treatment"""
......
# @author lucasmiranda42
from hypothesis import given
from hypothesis import HealthCheck
from hypothesis import settings
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays
from hypothesis.extra.pandas import range_indexes, columns, data_frames
from scipy.spatial import distance
from deepof.utils import *
import deepof.preprocess
import pytest
......@@ -11,6 +11,7 @@ from deepof.utils import *
import deepof.preprocess
import pytest
# AUXILIARY FUNCTIONS #
......@@ -343,7 +344,7 @@ def test_close_single_contact(pos_dframe, tol):
[["bpart1", "bpart2"], ["X", "y"]], names=["bodyparts", "coords"],
)
pos_dframe.columns = idx
close_contact = close_single_contact(pos_dframe, "bpart1", "bpart2", tol)
close_contact = close_single_contact(pos_dframe, "bpart1", "bpart2", tol, 1, 1)
assert close_contact.dtype == bool
assert np.array(close_contact).shape[0] <= pos_dframe.shape[0]
......@@ -375,7 +376,7 @@ def test_close_double_contact(pos_dframe, tol, rev):
)
pos_dframe.columns = idx
close_contact = close_double_contact(
pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, rev
pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, rev, 1, 1
)
assert close_contact.dtype == bool
assert np.array(close_contact).shape[0] <= pos_dframe.shape[0]
......@@ -730,3 +731,9 @@ def test_cluster_transition_matrix(sampler, autocorrelation, return_graph):
assert type(trans) == nx.Graph
else:
assert type(trans) == np.ndarray
@settings(deadline=None)
@given()
def test_rule_based_tagging():
pass
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment