Skip to content
Snippets Groups Projects
Commit 3243a42f authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Updated default GMVAE models

parent 00cff2d9
Branches
Tags
No related merge requests found
Pipeline #87793 passed
......@@ -122,8 +122,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
default=0.0,
sampling="linear",
)
encoding = 16 #hp.Int("encoding", min_value=20, max_value=30, step=5, default=25, )
k = 5 # hp.Int(
encoding = (
16 # hp.Int("encoding", min_value=20, max_value=30, step=5, default=25, )
)
k = 5 # hp.Int(
# "n_components",
# min_value=self.number_of_components - 5,
# max_value=self.number_of_components + 5,
......
......@@ -433,7 +433,7 @@ class Gaussian_mixture_overlap(Layer):
dists = []
for k in range(self.n_components):
locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
dists.append(
tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
......
......@@ -133,7 +133,7 @@ class SEQ_2_SEQ_AE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
)
)
Model_D5 = Bidirectional(
......@@ -142,7 +142,7 @@ class SEQ_2_SEQ_AE:
activation="sigmoid",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
)
)
......@@ -231,6 +231,7 @@ class SEQ_2_SEQ_GMVAE:
architecture_hparams: dict = {},
batch_size: int = 256,
compile_model: bool = True,
dense_activation: str = "elu",
entropy_reg_weight: float = 0.0,
huber_delta: float = 1.0,
initialiser_iters: int = int(1e4),
......@@ -245,6 +246,7 @@ class SEQ_2_SEQ_GMVAE:
self.hparams = self.get_hparams(architecture_hparams)
self.batch_size = batch_size
self.CONV_filters = self.hparams["units_conv"]
self.dense_activation = dense_activation
self.LSTM_units_1 = self.hparams["units_lstm"]
self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
......@@ -332,7 +334,7 @@ class SEQ_2_SEQ_GMVAE:
kernel_size=5,
strides=1,
padding="causal",
activation="elu",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
......@@ -342,7 +344,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=0),
# kernel_constraint=UnitNorm(axis=0),
use_bias=True,
)
)
......@@ -352,21 +354,21 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=False,
kernel_constraint=UnitNorm(axis=0),
# kernel_constraint=UnitNorm(axis=0),
use_bias=True,
)
)
Model_E3 = Dense(
self.DENSE_1,
activation="elu",
kernel_constraint=UnitNorm(axis=0),
activation=self.dense_activation,
# kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
use_bias=True,
)
Model_E4 = Dense(
self.DENSE_2,
activation="elu",
kernel_constraint=UnitNorm(axis=0),
activation=self.dense_activation,
# kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
use_bias=True,
)
......@@ -378,13 +380,13 @@ class SEQ_2_SEQ_GMVAE:
Model_B4 = BatchNormalization()
Model_D1 = Dense(
self.DENSE_2,
activation="elu",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
Model_D2 = Dense(
self.DENSE_1,
activation="elu",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
......@@ -395,7 +397,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
use_bias=True,
)
)
......@@ -405,7 +407,7 @@ class SEQ_2_SEQ_GMVAE:
activation="sigmoid",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
use_bias=True,
)
)
......@@ -413,7 +415,7 @@ class SEQ_2_SEQ_GMVAE:
# Predictor layers
Model_P1 = Dense(
self.DENSE_1,
activation="elu",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
......@@ -423,7 +425,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
use_bias=True,
)
)
......@@ -433,14 +435,16 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
use_bias=True,
)
)
# Phenotype classification layers
Model_PC1 = Dense(
self.number_of_components, activation="elu", kernel_initializer=he_uniform()
self.number_of_components,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)
return (
......@@ -596,7 +600,9 @@ class SEQ_2_SEQ_GMVAE:
if self.predictor > 0:
# Define and instantiate predictor
predictor = Dense(
self.DENSE_2, activation="elu", kernel_initializer=he_uniform()
self.DENSE_2,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)(z)
predictor = BatchNormalization()(predictor)
predictor = Model_P1(predictor)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment