Commit abc4fd8f authored by lucas_miranda

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent d4374203
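In short, the per-sample neighbourhood entropies are now evaluated with a single tf.map_fn call over the sampled indices (with an explicit output dtype) instead of element by element. A rough sketch of the pattern; the neighbourhood_entropy helper and the random tensors below are illustrative stand-ins, not deepof's actual get_neighbourhood_entropy:

    import tensorflow as tf

    def neighbourhood_entropy(index, tensor, clusters, k):
        # Hypothetical stand-in for deepof's get_neighbourhood_entropy: entropy of the
        # cluster labels among the k nearest latent neighbours of the sample at `index`.
        distances = tf.norm(tensor - tensor[index], axis=1)
        _, knn = tf.math.top_k(-distances, k=k)
        neighbour_clusters = tf.gather(clusters, knn)
        counts = tf.cast(tf.math.bincount(tf.cast(neighbour_clusters, tf.int32)), tf.float32)
        probs = tf.boolean_mask(counts, counts > 0) / tf.cast(k, tf.float32)
        return -tf.reduce_sum(probs * tf.math.log(probs))

    encodings = tf.random.normal([512, 8])                          # toy latent codes
    hard_groups = tf.random.uniform([512], 0, 10, dtype=tf.int32)   # toy hard cluster labels
    random_idxs = tf.random.uniform([256], 0, 512, dtype=tf.int32)  # sampled rows

    # Explicit Python loop, shown only for contrast: one op chain per sampled index
    purity_loop = tf.stack(
        [neighbourhood_entropy(i, encodings, hard_groups, k=100) for i in random_idxs]
    )

    # Vectorised mapping over the index tensor; dtype tells map_fn that the output
    # type (float entropies) differs from the integer input type
    purity_vector = tf.map_fn(
        lambda i: neighbourhood_entropy(i, encodings, hard_groups, k=100),
        random_idxs,
        dtype=tf.float32,
    )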
@@ -893,6 +893,7 @@ class coordinates:
         mmd_warmup: int = 0,
         montecarlo_kl: int = 10,
         n_components: int = 25,
+        overlap_loss: float = 0,
         output_path: str = ".",
         next_sequence_prediction: float = 0,
         phenotype_prediction: float = 0,

@@ -958,6 +959,7 @@ class coordinates:
             mmd_warmup=mmd_warmup,
             montecarlo_kl=montecarlo_kl,
             n_components=n_components,
+            overlap_loss=overlap_loss,
             output_path=output_path,
             next_sequence_prediction=next_sequence_prediction,
             phenotype_prediction=phenotype_prediction,
@@ -295,10 +295,12 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
         @tf.function
         def get_local_neighbourhood_entropy(index):
             return get_neighbourhood_entropy(
-                index, tensor=encodings, clusters=hard_groups, k=self.k
+                index, tensor=encoding, clusters=hard_groups, k=self.k
             )

-        purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs)
+        purity_vector = tf.map_fn(
+            get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32
+        )

         writer = tf.summary.create_file_writer(self.log_dir)
         with writer.as_default():
@@ -558,7 +560,7 @@ class ClusterOverlap(Layer):
         self,
         encoding_dim: int,
         k: int = 100,
-        loss_weight: float = False,
+        loss_weight: float = 0.0,
         samples: int = 512,
         *args,
         **kwargs

@@ -581,19 +583,19 @@ class ClusterOverlap(Layer):
         config.update({"samples": self.samples})
         return config

-    @tf.function
-    def call(self, encodings, categorical, **kwargs):
+    #@tf.function
+    def call(self, inputs, **kwargs):
         """Updates Layer's call method"""

+        encodings, categorical = inputs[0], inputs[1]
+
         hard_groups = tf.math.argmax(categorical, axis=1)
         max_groups = tf.reduce_max(categorical, axis=1)

         # Iterate over samples and compute purity across neighbourhood
-        self.samples = tf.reduce_min([self.samples, encodings.shape[0]])
-        random_idxs = range(encoding.shape[0])
-        random_idxs = tf.random.categorical(
-            tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0), self.samples
-        )
+        self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]])
+        random_idxs = range(encodings.shape[0])
+        random_idxs = np.random.choice(random_idxs, self.samples)

         @tf.function
         def get_local_neighbourhood_entropy(index):

@@ -601,7 +603,9 @@ class ClusterOverlap(Layer):
                 index, tensor=encodings, clusters=hard_groups, k=self.k
             )

-        purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs)
+        purity_vector = tf.map_fn(
+            get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32
+        )

         ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
         neighbourhood_entropy = purity_vector * max_groups[random_idxs]

@@ -623,6 +627,8 @@ class ClusterOverlap(Layer):
         )

         if self.loss_weight:
-            self.add_loss(neighbourhood_entropy, inputs=[target, categorical])
+            self.add_loss(
+                self.loss_weight * neighbourhood_entropy, inputs=[target, categorical]
+            )

         return encodings
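Note that the entropy term is now multiplied by loss_weight before being registered through add_loss, so the new float default of 0.0 disables the penalty while any positive value scales it. A self-contained toy layer showing the same pattern (not deepof's actual ClusterOverlap):

    import tensorflow as tf

    class WeightedPenalty(tf.keras.layers.Layer):
        # Toy layer: adds a weighted regularisation term via add_loss, mirroring
        # how ClusterOverlap now scales its neighbourhood-entropy penalty.
        def __init__(self, loss_weight: float = 0.0, **kwargs):
            super().__init__(**kwargs)
            self.loss_weight = loss_weight

        def call(self, inputs):
            penalty = tf.reduce_mean(tf.square(inputs))  # stand-in for the entropy term
            if self.loss_weight:
                # 0.0 is falsy, so the default adds nothing; positive weights scale the term
                self.add_loss(self.loss_weight * penalty)
            return inputs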
@@ -176,9 +176,9 @@ parser.add_argument(
 parser.add_argument(
     "--overlap-loss",
     "-ol",
-    help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
-    type=deepof.utils.str2bool,
-    default=False,
+    help="If > 0, adds a regularization term controlling for local cluster assignment entropy in the latent space",
+    type=float,
+    default=0,
 )
 parser.add_argument(
     "--next-sequence-prediction",
@@ -263,7 +263,7 @@ mmd_annealing_mode = args.mmd_annealing_mode
 mmd_wu = args.mmd_warmup
 mc_kl = args.montecarlo_kl
 output_path = os.path.join(args.output_path)
-overlap_loss = args.overlap_loss
+overlap_loss = float(args.overlap_loss)
 next_sequence_prediction = float(args.next_sequence_prediction)
 phenotype_prediction = float(args.phenotype_prediction)
 rule_based_prediction = float(args.rule_based_prediction)

@@ -397,6 +397,7 @@ if not tune:
         montecarlo_kl=mc_kl,
         n_components=k,
         output_path=output_path,
+        overlap_loss=overlap_loss,
         next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
         rule_based_prediction=rule_based_prediction,
@@ -291,6 +291,7 @@ def autoencoder_fitting(
     montecarlo_kl: int,
     n_components: int,
     output_path: str,
+    overlap_loss: float,
     next_sequence_prediction: float,
     phenotype_prediction: float,
     rule_based_prediction: float,

@@ -394,7 +395,7 @@ def autoencoder_fitting(
         mmd_warmup_epochs=mmd_warmup,
         montecarlo_kl=montecarlo_kl,
         number_of_components=n_components,
-        overlap_loss=False,
+        overlap_loss=overlap_loss,
         next_sequence_prediction=next_sequence_prediction,
         phenotype_prediction=phenotype_prediction,
         rule_based_prediction=rule_based_prediction,