Commit 32ebe84e authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent b9a11499
......@@ -43,7 +43,7 @@ class GMVAE:
loss: str = "ELBO",
mmd_annealing_mode: str = "sigmoid",
mmd_warmup_epochs: int = 20,
montecarlo_kl: int = 1,
montecarlo_kl: int = 10,
number_of_components: int = 1,
overlap_loss: float = 0.0,
next_sequence_prediction: float = 0.0,
......@@ -473,12 +473,13 @@ class GMVAE:
# Dummy layer with no parameters, to retrieve the previous tensor
z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
z = deepof.model_utils.ClusterOverlap(
self.batch_size,
self.ENCODING,
self.number_of_components,
loss_weight=self.overlap_loss,
)([z, z_cat])
if self.number_of_components > 1:
z = deepof.model_utils.ClusterOverlap(
self.batch_size,
self.ENCODING,
self.number_of_components,
loss_weight=self.overlap_loss,
)([z, z_cat])
# Define and instantiate generator
g = Input(shape=self.ENCODING)
......
......@@ -71,6 +71,7 @@ def get_callbacks(
phenotype_prediction: float,
next_sequence_prediction: float,
rule_based_prediction: float,
overlap_loss: float,
loss: str,
loss_warmup: int = 0,
warmup_mode: str = "none",
......@@ -101,14 +102,15 @@ def get_callbacks(
elif reg_cat_clusters and reg_cluster_variance:
latreg = "categorical+variance"
run_ID = "{}{}{}{}{}{}{}{}{}{}{}{}{}{}{}".format(
("deepof_GMVAE"),
run_ID = "{}{}{}{}{}{}{}{}{}{}{}{}{}{}{}{}".format(
"deepof_GMVAE",
("_input_type={}".format(input_type) if input_type else "coords"),
("_window_size={}".format(X_train.shape[1])),
("_NextSeqPred={}".format(next_sequence_prediction)),
("_PhenoPred={}".format(phenotype_prediction)),
("_RuleBasedPred={}".format(rule_based_prediction)),
("_NSPred={}".format(next_sequence_prediction)),
("_PPred={}".format(phenotype_prediction)),
("_RBPred={}".format(rule_based_prediction)),
("_loss={}".format(loss)),
("_overlap_loss={}".format(overlap_loss)),
("_loss_warmup={}".format(loss_warmup)),
("_warmup_mode={}".format(warmup_mode)),
("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
......@@ -345,6 +347,7 @@ def autoencoder_fitting(
rule_based_prediction=rule_based_prediction,
loss=loss,
loss_warmup=kl_warmup,
overlap_loss=overlap_loss,
warmup_mode=kl_annealing_mode,
input_type=input_type,
X_val=(X_val_dataset if X_val.shape != (0,) else None),
......
......@@ -266,15 +266,19 @@ def test_find_learning_rate():
def test_neighbor_latent_entropy():
X = np.random.normal(0, 1, [1500, 25, 6])
test_model = deepof.models.GMVAE()
train_dataset = tf.data.Dataset.from_tensor_slices((X, X))
train_dataset = train_dataset.batch(256, drop_remainder=True)
test_model = deepof.models.GMVAE(number_of_components=10)
gmvaep = test_model.build(X.shape)[3]
gmvaep.fit(
X,
X,
train_dataset,
epochs=1,
callbacks=deepof.model_utils.neighbor_latent_entropy(
k=10,
encoding_dim=6,
samples=100,
validation_data=X,
),
)
......@@ -45,6 +45,7 @@ def test_load_treatments():
next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
overlap_loss=st.floats(min_value=0.0, max_value=1.0),
)
def test_get_callbacks(
X_train,
......@@ -52,6 +53,7 @@ def test_get_callbacks(
next_sequence_prediction,
phenotype_prediction,
rule_based_prediction,
overlap_loss,
loss,
):
callbacks = deepof.train_utils.get_callbacks(
......@@ -60,6 +62,7 @@ def test_get_callbacks(
phenotype_prediction=phenotype_prediction,
next_sequence_prediction=next_sequence_prediction,
rule_based_prediction=rule_based_prediction,
overlap_loss=overlap_loss,
loss=loss,
X_val=X_train,
input_type=False,
......@@ -148,7 +151,6 @@ def test_autoencoder_fitting(
batch_size=st.integers(min_value=128, max_value=512),
encoding_size=st.integers(min_value=1, max_value=16),
hpt_type=st.one_of(st.just("bayopt"), st.just("hyperband")),
hypermodel=st.just("S2SGMVAE"),
k=st.integers(min_value=1, max_value=10),
loss=st.one_of(st.just("ELBO"), st.just("MMD")),
overlap_loss=st.floats(min_value=0.0, max_value=1.0),
......@@ -161,13 +163,12 @@ def test_tune_search(
batch_size,
encoding_size,
hpt_type,
hypermodel,
k,
loss,
overlap_loss,
next_sequence_prediction,
phenotype_prediction,
rule_based_prediction,
overlap_loss,
):
callbacks = list(
deepof.train_utils.get_callbacks(
......@@ -182,6 +183,7 @@ def test_tune_search(
cp=False,
reg_cat_clusters=True,
reg_cluster_variance=True,
overlap_loss=overlap_loss,
entropy_samples=10,
entropy_knn=5,
logparam={"encoding": 2, "k": 15},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment