Commit dc539f02 authored by lucas_miranda

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent 4cf9fb3f
Pipeline #102382 canceled with stages in 10 minutes and 53 seconds
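The ClusterOverlap layer named in the commit message does not appear in the hunks below, but the same pattern shows up in the `neighbor_latent_entropy` callback: a per-sample Python loop is replaced with a single `tf.map_fn` call that maps a traced function over the leading dimension. A minimal sketch of that pattern, with `purity_of_sample` and all shapes as illustrative assumptions rather than deepof's actual code:

```python
import tensorflow as tf

encodings = tf.random.normal((1000, 8))  # hypothetical latent embeddings

def purity_of_sample(index):
    # Placeholder per-sample metric; stands in for a neighbourhood-entropy call
    return tf.reduce_mean(encodings[index])

# Before: eager Python loop, builds one op chain per sample
purity_loop = tf.stack([purity_of_sample(i) for i in range(1000)])

# After: the function is traced once and mapped across all indices
purity_mapped = tf.map_fn(
    purity_of_sample,
    tf.range(1000),
    dtype=tf.dtypes.float32,  # output dtype differs from the int32 elems
)
```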
@@ -232,95 +232,6 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
         )


-class neighbor_latent_entropy(tf.keras.callbacks.Callback):
-    """
-    Latent space entropy callback. Computes the entropy of cluster assignment across
-    the k nearest neighbors of a subset of samples in the latent space.
-    """
-
-    def __init__(
-        self,
-        encoding_dim: int,
-        validation_data: np.ndarray = None,
-        k: int = 100,
-        samples: int = 10000,
-        log_dir: str = ".",
-    ):
-        super().__init__()
-        self.enc = encoding_dim
-        self.validation_data = validation_data
-        self.k = k
-        self.samples = samples
-        self.log_dir = log_dir
-
-    # noinspection PyMethodOverriding,PyTypeChecker
-    def on_epoch_end(self, epoch, logs=None):
-        """Passes samples through the encoder and computes cluster purity on the latent embedding"""
-
-        if self.validation_data is not None:
-            # Get encoder and grouper from the full model
-            latent_distribution = [
-                layer
-                for layer in self.model.layers
-                if layer.name == "latent_distribution"
-            ][0]
-            cluster_assignment = [
-                layer
-                for layer in self.model.layers
-                if layer.name == "cluster_assignment"
-            ][0]
-
-            encoder = tf.keras.models.Model(
-                self.model.layers[0].input, latent_distribution.output
-            )
-            grouper = tf.keras.models.Model(
-                self.model.layers[0].input, cluster_assignment.output
-            )
-
-            # Use encoder and grouper to predict on validation data
-            encoding = encoder.predict(self.validation_data)
-            groups = grouper.predict(self.validation_data)
-            hard_groups = groups.argmax(axis=1)
-            max_groups = groups.max(axis=1)
-
-            # Iterate over samples and compute purity across the neighbourhood
-            self.samples = np.min([self.samples, encoding.shape[0]])
-            random_idxs = np.random.choice(
-                range(encoding.shape[0]), self.samples, replace=False
-            )
-
-            @tf.function
-            def get_local_neighbourhood_entropy(index):
-                return get_neighbourhood_entropy(
-                    index, tensor=encoding, clusters=hard_groups, k=self.k
-                )
-
-            purity_vector = tf.map_fn(
-                get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
-            )
-
-            writer = tf.summary.create_file_writer(self.log_dir)
-            with writer.as_default():
-                tf.summary.scalar(
-                    "number_of_populated_clusters",
-                    data=len(set(hard_groups[max_groups >= 0.25])),
-                    step=epoch,
-                )
-                tf.summary.scalar(
-                    "average_neighborhood_cluster_entropy",
-                    data=np.average(purity_vector, weights=max_groups[random_idxs]),
-                    step=epoch,
-                )
-                tf.summary.scalar(
-                    "average_confidence_in_selected_cluster",
-                    data=np.average(max_groups),
-                    step=epoch,
-                )


 class uncorrelated_features_constraint(Constraint):
     """
...
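`get_neighbourhood_entropy`, called inside `on_epoch_end`, is defined elsewhere in `deepof.model_utils` and is not part of this hunk. A hedged sketch of what such a helper could look like, assuming Euclidean k-nearest neighbours and Shannon entropy over the neighbours' cluster labels (both assumptions, not the library's verified implementation):

```python
import tensorflow as tf

def get_neighbourhood_entropy(index, tensor, clusters, k):
    """Shannon entropy of cluster labels among the k nearest neighbours of one sample.

    Sketch only: Euclidean distance and dense integer labels are assumptions.
    """
    query = tensor[index]
    dists = tf.norm(tensor - query, axis=1)
    # k + 1 because the sample is its own nearest neighbour; drop it below
    _, knn_idxs = tf.math.top_k(-dists, k=k + 1)
    neighbour_labels = tf.cast(tf.gather(clusters, knn_idxs[1:]), tf.int32)
    counts = tf.math.bincount(neighbour_labels)
    probs = tf.cast(counts, tf.float32) / tf.cast(k, tf.float32)
    probs = tf.boolean_mask(probs, probs > 0)  # avoid log(0)
    return -tf.reduce_sum(probs * tf.math.log(probs))
```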
@@ -128,21 +128,13 @@ def get_callbacks(
         profile_batch=2,
     )

-    entropy = deepof.model_utils.neighbor_latent_entropy(
-        encoding_dim=logparam["encoding"],
-        k=entropy_knn,
-        samples=entropy_samples,
-        validation_data=X_val,
-        log_dir=os.path.join(outpath, "metrics", run_ID),
-    )
-
     onecycle = deepof.model_utils.one_cycle_scheduler(
         X_train.shape[0] // batch_size * 250,
         max_rate=0.005,
         log_dir=os.path.join(outpath, "metrics", run_ID),
     )

-    callbacks = [run_ID, tensorboard_callback, entropy, onecycle]
+    callbacks = [run_ID, tensorboard_callback, onecycle]

     if cp:
         cp_callback = tf.keras.callbacks.ModelCheckpoint(
...
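The first positional argument to `one_cycle_scheduler` is the total number of optimisation steps the learning-rate cycle is stretched over: steps per epoch (`X_train.shape[0] // batch_size`) times what is presumably a 250-epoch training run. A quick worked example, with the array shape and batch size as illustrative assumptions:

```python
import numpy as np

X_train = np.zeros((51200, 24))  # hypothetical training set
batch_size = 256

steps_per_epoch = X_train.shape[0] // batch_size  # 51200 // 256 = 200
total_cycle_steps = steps_per_epoch * 250         # 200 * 250 = 50000
```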