Commit 9219111b authored by lucas_miranda
Browse files

Prototyped KNN_purity callback

parent 56de58c2
...@@ -10,7 +10,7 @@ Functions and general utilities for the deepof tensorflow models. See documentat ...@@ -10,7 +10,7 @@ Functions and general utilities for the deepof tensorflow models. See documentat
from itertools import combinations from itertools import combinations
from typing import Any, Tuple from typing import Any, Tuple
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer from tensorflow.keras.layers import Layer
...@@ -203,7 +203,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -203,7 +203,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
""" """
def __init__(self, k=5, samples=1000): def __init__(self, k=5, samples=10000):
super().__init__() super().__init__()
self.k = k self.k = k
self.samples = samples self.samples = samples
...@@ -227,16 +227,40 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -227,16 +227,40 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
self.model.layers[0].input, cluster_assignment.output self.model.layers[0].input, cluster_assignment.output
) )
trial_idxs = np.random.choice( # Use encoder and grouper to predict on validation data
range(self.validation_data.shape[0]), self.samples encoding = encoder.predict(self.validation_data)
groups = grouper.predict(self.validation_data)
# Multiply encodings by groups, to get a weighted version of the matrix
encoding = (
encoding
* tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
) )
trial_data = self.validation_data[trial_idxs] hard_groups = groups.argmax(axis=1)
# Use encoder and grouper to predict on validation data # Fit KNN model
encoding = encoder.predict(trial_data) knn = NearestNeighbors().fit(encoding)
groups = grouper.predict(trial_data)
# # Iterate over samples and compute purity over k neighbours
random_idxs = np.random.choice(
range(encoding.shape[0]), self.samples, replace=False
)
purity_vector = np.zeros(self.samples)
for i, sample in enumerate(random_idxs):
indexes = knn.kneighbors(
encoding[sample][np.newaxis, :], self.k, return_distance=False
)
purity_vector[i] = (
np.sum(hard_groups[indexes] == hard_groups[sample])
/ self.k
* np.max(groups[sample])
)
self.add_metric(
self.purity_vector,
aggregation="mean",
name="knn_cluster_purity",
)
class uncorrelated_features_constraint(Constraint): class uncorrelated_features_constraint(Constraint):
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment