Commit f5130f60 authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed bug in transition matrix obtention

parent 26a21aa0
......@@ -8,7 +8,7 @@ import numpy as np
import pandas as pd
import pickle
import pims
import re
import regex as re
import scipy
import seaborn as sns
from copy import deepcopy
......@@ -761,14 +761,16 @@ def GMM_Model_Selection(
def cluster_transition_matrix(
cluster_sequence, autocorrelation=True, return_graph=False
cluster_sequence, nclusts, autocorrelation=True, return_graph=False
):
"""
Computes the transition matrix between clusters and the autocorrelation in the sequence.
"""
# Stores all possible transitions between clusters
clusters = set(cluster_sequence)
clusters = [str(i) for i in range(nclusts)]
cluster_sequence = cluster_sequence.astype(str)
trans = {t: 0 for t in product(clusters, clusters)}
k = len(clusters)
......@@ -776,14 +778,14 @@ def cluster_transition_matrix(
transtr = "".join(list(cluster_sequence))
# Assigns to each transition the number of times it occurs in the sequence
for t in trans:
for t in trans.keys():
trans[t] = len(re.findall("".join(t), transtr, overlapped=True))
# Normalizes the counts to add up to 1 for each departing cluster
trans_normed = np.zeros([k, k])
for t in trans:
trans_normed = np.zeros([k, k]) + 1e-5
for t in trans.keys():
trans_normed[int(t[0]), int(t[1])] = np.round(
trans[t] / sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()), 3
trans[t] / (sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()) + 1e-5), 3
)
# If specified, returns the transition matrix as an nx.Graph object
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment