Commit c74ddedc authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented Markov transition matrices and autocorrelation calculations in...

Implemented Markov transition matrices and autocorrelation calculations in utils.py; EarlyStopping now works on val_intercomponent_mmd instead of val_mae in model_training.py
parent 12de8cde
......@@ -95,14 +95,14 @@ parser.add_argument(
"-ol",
help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
default=False,
type=str2bool
type=str2bool,
)
parser.add_argument(
"--batch-size",
"-bs",
help="set training batch size. Defaults to 512",
type=int,
default=512
default=512,
)
args = parser.parse_args()
......@@ -395,7 +395,7 @@ if not variational:
tensorboard_callback,
cp_callback,
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
"val_intercomponent_mmd", patience=5, restore_best_weights=True
),
],
)
......
......@@ -167,9 +167,7 @@ class Gaussian_mixture_overlap(Layer):
using a specified metric (MMD, Wasserstein, Fischer-Rao)
"""
def __init__(
self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs
):
def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
self.lat_dims = lat_dims
self.n_components = n_components
self.loss = loss
......@@ -224,13 +222,15 @@ class Latent_space_control(Layer):
to the metrics compiled by the model
"""
def __init__(self, loss=False, *args, **kwargs):
def __init__(self, silhouette=False, loss=False, *args, **kwargs):
self.loss = loss
self.silhouette = silhouette
super(Latent_space_control, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update({"loss": self.loss})
config.update({"silhouette": silhouette})
def call(self, z, z_gauss, z_cat, **kwargs):
......@@ -240,11 +240,14 @@ class Latent_space_control(Layer):
)
# Adds Silhouette score controlling overlap between clusters
# hard_labels = tf.math.argmax(z_cat, axis=1)
# silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
# self.add_metric(silhouette, aggregation="mean", name="silhouette")
if self.silhouette:
hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(
silhouette_score, [z, hard_labels], tf.float32
)
self.add_metric(silhouette, aggregation="mean", name="silhouette")
if self.loss:
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
if self.loss:
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
return z
......@@ -196,9 +196,7 @@ class SEQ_2_SEQ_GMVAE:
[
tfd.Independent(
tfd.Normal(
loc=tf.random.normal(
shape=[self.ENCODING], stddev=10
),
loc=tf.random.normal(shape=[self.ENCODING], stddev=10),
scale=1,
),
reinterpreted_batch_ndims=1,
......
......@@ -3,6 +3,7 @@
import cv2
import matplotlib.pyplot as plt
import multiprocessing
import networkx as nx
import numpy as np
import pandas as pd
import pickle
......@@ -10,7 +11,7 @@ import pims
import re
import scipy
import seaborn as sns
from itertools import cycle, combinations
from itertools import cycle, combinations, product
from joblib import Parallel, delayed
from numba import jit
from numpy.core.umath_tests import inner1d
......@@ -729,6 +730,45 @@ def GMM_Model_Selection(
return bic, m_bic, best_bic_gmm
##### RESULT ANALYSIS FUNCTIONS #####
def cluster_transition_matrix(
cluster_sequence, autocorrelation=True, return_graph=False
):
"""
Computes the transition matrix between clusters and the autocorrelation in the sequence.
"""
# Stores all possible transitions between clusters
clusters = set(cluster_sequence)
trans = {t: 0 for t in product(clusters, clusters)}
k = len(clusters)
# Stores the cluster sequence as a string
transtr = "".join(list(cluster_sequence))
# Assigns to each transition the number of times it occurs in the sequence
for t in trans:
trans[t] = len(re.findall("".join(t), transtr, overlapped=True))
# Normalizes the counts to add up to 1 for each departing cluster
trans_normed = np.zeros([k, k])
for t in trans:
trans_normed[int(t[0]), int(t[1])] = np.round(
trans[t] / sum({i: j for i, j in trans.items() if i[0] == t[0]}.values()), 3
)
# If specified, returns the transition matrix as an nx.Graph object
if return_graph:
trans_normed = nx.Graph(trans_normed)
if autocorrelation:
cluster_sequence = list(map(int, cluster_sequence))
return trans_normed, np.corrcoef(cluster_sequence[:-1], cluster_sequence[1:])
return trans_normed
##### PLOTTING FUNCTIONS #####
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment