diff --git a/deepof/utils.py b/deepof/utils.py index 617445b1fbb9a25668e921cb213e4cf335e03b74..04f2f06ffa2d682245d3e69d2b0341d90510a527 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -1021,7 +1021,7 @@ def gmm_compute(x: np.array, n_components: int, cv_type: str) -> list: def gmm_model_selection( - x: np.array, + x: pd.DataFrame, n_components_range: range, part_size: int, n_runs: int = 100, @@ -1032,7 +1032,7 @@ def gmm_model_selection( a vector with the median BICs and an object with the overall best model Parameters: - - x (numpy.array): data matrix to train the models + - x (pandas.DataFrame): data matrix to train the models - n_components_range (range): generator with numbers of components to evaluate - n_runs (int): number of bootstraps for each model - part_size (int): size of bootstrap samples for each model @@ -1083,10 +1083,22 @@ def gmm_model_selection( def cluster_transition_matrix( - cluster_sequence, nclusts, autocorrelation=True, return_graph=False -): - """ - Computes the transition matrix between clusters and the autocorrelation in the sequence. + cluster_sequence: np.array, + nclusts: int, + autocorrelation: bool = True, + return_graph: bool = False, +) -> Tuple[Union[nx.Graph, Any], np.ndarray]: + """Computes the transition matrix between clusters and the autocorrelation in the sequence. 
+ + Parameters: + - cluster_sequence (numpy.array): sequence of cluster labels; cast to int before the autocorrelation is computed + - nclusts (int): number of clusters + - autocorrelation (bool): if True, the autocorrelation of the sequence is returned alongside the transitions + - return_graph (bool): if True, the transitions are returned as a networkx.Graph instead of a numpy.array + + Returns: + - trans_normed (numpy.array / networkx.Graph): transition matrix between clusters, or the corresponding graph if return_graph is True + - autocorr (numpy.array): np.corrcoef of the sequence against itself shifted by one step; only returned if autocorrelation is True """ # Stores all possible transitions between clusters @@ -1118,7 +1130,8 @@ def cluster_transition_matrix( if autocorrelation: cluster_sequence = list(map(int, cluster_sequence)) - return trans_normed, np.corrcoef(cluster_sequence[:-1], cluster_sequence[1:]) + autocorr = np.corrcoef(cluster_sequence[:-1], cluster_sequence[1:]) + return trans_normed, autocorr return trans_normed diff --git a/tests/test_utils.py b/tests/test_utils.py index bad1b8bf91b84416c4eb962be8cba09f9e754dd5..34739cf872454e9ea10f4b4b6818bd9d88c17988 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -696,3 +696,37 @@ def test_gmm_model_selection(x, sampler): ) == 3 ) + + +@settings(deadline=None) +@given(sampler=st.data(), autocorrelation=st.booleans(), return_graph=st.booleans()) +def test_cluster_transition_matrix(sampler, autocorrelation, return_graph): + + nclusts = sampler.draw(st.integers(min_value=1, max_value=10)) + cluster_sequence = sampler.draw( + arrays( + dtype=int, + shape=st.tuples(st.integers(min_value=10, max_value=1000)), + elements=st.integers(min_value=1, max_value=nclusts), + ).filter(lambda x: len(set(x)) != 1) + ) + + trans = cluster_transition_matrix( + cluster_sequence, nclusts, autocorrelation, return_graph + ) + + if autocorrelation: + assert len(trans) == 2 + + if return_graph: + assert type(trans[0]) == nx.Graph + else: + assert type(trans[0]) == np.ndarray + + assert type(trans[1]) == np.ndarray + + else: + if return_graph: + assert type(trans) == nx.Graph + else: + assert type(trans) == np.ndarray