Commit 9219111b authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent 56de58c2
......@@ -10,7 +10,7 @@ Functions and general utilities for the deepof tensorflow models. See documentat
from itertools import combinations
from typing import Any, Tuple
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
......@@ -203,7 +203,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
def __init__(self, k=5, samples=1000):
def __init__(self, k=5, samples=10000):
super().__init__()
self.k = k
self.samples = samples
......@@ -227,16 +227,40 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
self.model.layers[0].input, cluster_assignment.output
)
trial_idxs = np.random.choice(
range(self.validation_data.shape[0]), self.samples
# Use encoder and grouper to predict on validation data
encoding = encoder.predict(self.validation_data)
groups = grouper.predict(self.validation_data)
# Multiply encodings by groups, to get a weighted version of the matrix
encoding = (
encoding
* tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
)
trial_data = self.validation_data[trial_idxs]
hard_groups = groups.argmax(axis=1)
# Use encoder and grouper to predict on validation data
encoding = encoder.predict(trial_data)
groups = grouper.predict(trial_data)
# Fit KNN model
knn = NearestNeighbors().fit(encoding)
#
# Iterate over samples and compute purity over k neighbours
random_idxs = np.random.choice(
range(encoding.shape[0]), self.samples, replace=False
)
purity_vector = np.zeros(self.samples)
for i, sample in enumerate(random_idxs):
indexes = knn.kneighbors(
encoding[sample][np.newaxis, :], self.k, return_distance=False
)
purity_vector[i] = (
np.sum(hard_groups[indexes] == hard_groups[sample])
/ self.k
* np.max(groups[sample])
)
self.add_metric(
self.purity_vector,
aggregation="mean",
name="knn_cluster_purity",
)
class uncorrelated_features_constraint(Constraint):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment