Commit 9914123a authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated preprocessing alignment; now all frames all treated equally regardless...

Updated preprocessing alignment; now all frames all treated equally regardless their position in the sliding window
parent 696eab02
This diff is collapsed.
......@@ -138,7 +138,7 @@ class UncorrelatedFeaturesConstraint(Constraint):
return covariance
# Constraint penalty
def uncorrelated_feature(self, x):
def uncorrelated_feature(self):
if self.encoding_dim <= 1:
return 0.0
else:
......@@ -301,21 +301,14 @@ class Gaussian_mixture_overlap(Layer):
return target
class Latent_space_control(Layer):
class Dead_neuron_control(Layer):
"""
Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
def __init__(self, silhouette=False, loss=False, *args, **kwargs):
self.loss = loss
self.silhouette = silhouette
super(Latent_space_control, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update({"loss": self.loss})
config.update({"silhouette": self.silhouette})
def __init__(self, *args, **kwargs):
super(Dead_neuron_control, self).__init__(*args, **kwargs)
def call(self, z, z_gauss, z_cat, **kwargs):
......@@ -324,17 +317,6 @@ class Latent_space_control(Layer):
tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
)
# Adds Silhouette score controlling overlap between clusters
if self.silhouette:
hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(
silhouette_score, [z, hard_labels], tf.float32
)
self.add_metric(silhouette, aggregation="mean", name="silhouette")
if self.loss:
self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])
return z
......
......@@ -193,7 +193,6 @@ class SEQ_2_SEQ_GMVAE:
self.entropy_reg_weight = entropy_reg_weight
if self.prior == "standard_normal":
init_means = far_away_uniform_initialiser(
shape=[self.number_of_components, self.ENCODING], minval=0, maxval=5
)
......@@ -315,10 +314,9 @@ class SEQ_2_SEQ_GMVAE:
encoder = BatchNormalization()(encoder)
encoding_shuffle = MCDropout(self.DROPOUT_RATE)(encoder)
z_cat = Dense(
self.number_of_components,
activation="softmax",
)(encoding_shuffle)
z_cat = Dense(self.number_of_components, activation="softmax",)(
encoding_shuffle
)
z_cat = Entropy_regulariser(self.entropy_reg_weight)(z_cat)
z_gauss = Dense(
tfpl.IndependentNormal.params_size(
......@@ -382,7 +380,7 @@ class SEQ_2_SEQ_GMVAE:
)(z)
# Identity layer controlling clustering and latent space statistics
z = Latent_space_control(loss=self.overlap_loss)(z, z_gauss, z_cat)
z = Dead_neuron_control(loss=self.overlap_loss)(z, z_gauss, z_cat)
# Define and instantiate generator
generator = Model_D1(z)
......
......@@ -592,10 +592,13 @@ class table_dict(dict):
if verbose:
print("Done!")
if align == "all":
X_train = align_trajectories(X_train, align)
X_train = rolling_window(X_train, window_size, window_step)
if align:
X_train = align_trajectories(X_train)
if align == "center":
X_train = align_trajectories(X_train, align)
if filter == "gaussian":
r = range(-int(window_size / 2), int(window_size / 2) + 1)
......@@ -612,10 +615,14 @@ class table_dict(dict):
X_train = X_train * g.reshape(1, window_size, 1)
if test_videos:
if align == "all":
X_test = align_trajectories(X_test, align)
X_test = rolling_window(X_test, window_size, window_step)
if align:
X_test = align_trajectories(X_test)
if align == "center":
X_test = align_trajectories(X_test, align)
if filter == "gaussian":
X_test = X_test * g.reshape(1, window_size, 1)
......
......@@ -4,6 +4,7 @@ import cv2
import matplotlib.pyplot as plt
import multiprocessing
import networkx as nx
from numba import njit
import numpy as np
import pandas as pd
import pickle
......@@ -97,20 +98,33 @@ def rotate(p, angles, origin=np.array([0, 0])):
return np.squeeze((R @ (p.T - o.T) + o.T).T)
def align_trajectories(data):
def align_trajectories(data, mode="all"):
"""
mode: all aligns all frames in the data
mode: center aligns only the central frame
"""
data = deepcopy(data)
dshape = data.shape
center_time = (data.shape[1] - 1) // 2
angles = np.arctan2(data[:, center_time, 0], data[:, center_time, 1])
if mode == "center":
center_time = (data.shape[1] - 1) // 2
angles = np.arctan2(data[:, center_time, 0], data[:, center_time, 1])
elif mode == "all":
data = data.reshape(-1, dshape[-1])
angles = np.arctan2(data[:, 0], data[:, 1])
aligned_trajs = np.zeros(data.shape)
for frame in range(data.shape[0]):
aligned_trajs[frame] = rotate(
data[frame].reshape([data.shape[1] * data.shape[2] // 2, 2]), angles[frame],
data[frame].reshape([-1, 2]), angles[frame],
).reshape(data.shape[1:])
if mode == "all":
aligned_trajs = aligned_trajs.reshape(dshape)
return aligned_trajs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment