Commit 8d3cb4c3 authored by lucas_miranda

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent baabe478
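
For context, a minimal sketch (not taken from the repository) of the pattern the commit title describes: computing a per-sample score over sampled indices with a single tf.map_fn call instead of a Python for loop. per_sample_entropy is a stand-in for the get_neighbourhood_entropy helper that appears in the diff below.

    import tensorflow as tf

    # Stand-in for get_neighbourhood_entropy: any scalar score per sampled index.
    def per_sample_entropy(index, tensor):
        # Toy score: mean absolute activation of the selected encoding.
        return tf.reduce_mean(tf.abs(tf.gather(tensor, index)))

    encodings = tf.random.normal([512, 16])  # batch of latent codes
    random_idxs = tf.random.uniform([64], minval=0, maxval=512, dtype=tf.int32)

    # Python loop: one call per sampled index, executed sequentially.
    loop_scores = tf.stack([per_sample_entropy(i, encodings) for i in random_idxs])

    # Vectorised mapping: a single tf.map_fn over the sampled indices.
    # dtype is needed because the outputs (float32) differ from the inputs (int32).
    map_scores = tf.map_fn(
        lambda i: per_sample_entropy(i, encodings),
        random_idxs,
        dtype=tf.float32,
    )
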
@@ -893,6 +893,7 @@ class coordinates:
         mmd_warmup: int = 0,
         montecarlo_kl: int = 10,
         n_components: int = 25,
+        overlap_loss: float = 0,
         output_path: str = ".",
         next_sequence_prediction: float = 0,
         phenotype_prediction: float = 0,
@@ -958,6 +959,7 @@ class coordinates:
             mmd_warmup=mmd_warmup,
             montecarlo_kl=montecarlo_kl,
             n_components=n_components,
+            overlap_loss=overlap_loss,
             output_path=output_path,
             next_sequence_prediction=next_sequence_prediction,
             phenotype_prediction=phenotype_prediction,
@@ -295,10 +295,12 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
             @tf.function
             def get_local_neighbourhood_entropy(index):
                 return get_neighbourhood_entropy(
-                    index, tensor=encodings, clusters=hard_groups, k=self.k
+                    index, tensor=encoding, clusters=hard_groups, k=self.k
                 )

-            purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs)
+            purity_vector = tf.map_fn(
+                get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
+            )

             writer = tf.summary.create_file_writer(self.log_dir)
             with writer.as_default():
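
The added dtype argument matters because tf.map_fn assumes the mapped function returns the same dtype as its input elements; mapping integer indices to float entropies without it raises an error. A self-contained illustration with a toy function (not the repository's helper):

    import tensorflow as tf

    idxs = tf.constant([0, 1, 2, 3])

    # Without dtype, map_fn would expect int32 outputs because idxs is int32, and fail.
    halves = tf.map_fn(lambda i: tf.cast(i, tf.float32) / 2.0, idxs, dtype=tf.float32)
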
@@ -558,7 +560,7 @@ class ClusterOverlap(Layer):
         self,
         encoding_dim: int,
         k: int = 100,
-        loss_weight: float = False,
+        loss_weight: float = 0.0,
         samples: int = 512,
         *args,
         **kwargs
@@ -581,19 +583,19 @@ class ClusterOverlap(Layer):
         config.update({"samples": self.samples})
         return config

-    @tf.function
-    def call(self, encodings, categorical, **kwargs):
+    #@tf.function
+    def call(self, inputs, **kwargs):
         """Updates Layer's call method"""

+        encodings, categorical = inputs[0], inputs[1]
+
         hard_groups = tf.math.argmax(categorical, axis=1)
         max_groups = tf.reduce_max(categorical, axis=1)

         # Iterate over samples and compute purity across neighbourhood
-        self.samples = tf.reduce_min([self.samples, encodings.shape[0]])
-        random_idxs = range(encoding.shape[0])
-        random_idxs = tf.random.categorical(
-            tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0), self.samples
-        )
+        self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]])
+        random_idxs = range(encodings.shape[0])
+        random_idxs = np.random.choice(random_idxs, self.samples)

         @tf.function
         def get_local_neighbourhood_entropy(index):
@@ -601,7 +603,9 @@ class ClusterOverlap(Layer):
                 index, tensor=encodings, clusters=hard_groups, k=self.k
             )

-        purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs)
+        purity_vector = tf.map_fn(
+            get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32
+        )

         ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
         neighbourhood_entropy = purity_vector * max_groups[random_idxs]
@@ -623,6 +627,8 @@ class ClusterOverlap(Layer):
         )

         if self.loss_weight:
-            self.add_loss(neighbourhood_entropy, inputs=[target, categorical])
+            self.add_loss(
+                self.loss_weight * neighbourhood_entropy, inputs=[target, categorical]
+            )

         return encodings
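
Two behavioural notes on the layer after this change: it now receives a single list of tensors, unpacked as inputs[0] and inputs[1], and loss_weight acts as a float multiplier, so the default of 0.0 is falsy and skips the add_loss call entirely. A hedged call-site sketch; the import path, dimensions and tensor names are assumptions, not taken from this diff:

    from deepof.model_utils import ClusterOverlap  # assumed module path

    overlap = ClusterOverlap(encoding_dim=16, k=100, loss_weight=0.1, samples=512)

    # old call signature: overlap(encodings, categorical)
    # new call signature: a single list, unpacked inside call()
    encodings = overlap([encodings, categorical])
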
@@ -176,9 +176,9 @@ parser.add_argument(
 parser.add_argument(
     "--overlap-loss",
     "-ol",
-    help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
-    type=deepof.utils.str2bool,
-    default=False,
+    help="If > 0, adds a regularization term controlling for local cluster assignment entropy in the latent space",
+    type=float,
+    default=0,
 )
 parser.add_argument(
     "--next-sequence-prediction",
@@ -263,7 +263,7 @@ mmd_annealing_mode = args.mmd_annealing_mode
 mmd_wu = args.mmd_warmup
 mc_kl = args.montecarlo_kl
 output_path = os.path.join(args.output_path)
-overlap_loss = args.overlap_loss
+overlap_loss = float(args.overlap_loss)
 next_sequence_prediction = float(args.next_sequence_prediction)
 phenotype_prediction = float(args.phenotype_prediction)
 rule_based_prediction = float(args.rule_based_prediction)
@@ -397,6 +397,7 @@ if not tune:
         montecarlo_kl=mc_kl,
         n_components=k,
         output_path=output_path,
+        overlap_loss=overlap_loss,
         next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
         rule_based_prediction=rule_based_prediction,
@@ -291,6 +291,7 @@ def autoencoder_fitting(
     montecarlo_kl: int,
     n_components: int,
     output_path: str,
+    overlap_loss: float,
     next_sequence_prediction: float,
     phenotype_prediction: float,
     rule_based_prediction: float,
@@ -394,7 +395,7 @@ def autoencoder_fitting(
         mmd_warmup_epochs=mmd_warmup,
         montecarlo_kl=montecarlo_kl,
         number_of_components=n_components,
-        overlap_loss=False,
+        overlap_loss=overlap_loss,
         next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
         rule_based_prediction=rule_based_prediction,
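
Threading the float through autoencoder_fitting instead of hard-coding overlap_loss=False preserves the default behaviour, because both False and 0.0 are falsy in the if self.loss_weight: gate shown above; the penalty is only added, and scaled, when a positive weight is passed down from the command line. A two-line illustration:

    for loss_weight in (False, 0, 0.0, 0.1):
        print(loss_weight, "adds weighted penalty" if loss_weight else "skipped")
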