Commit 3243a42f authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated default GMVAE models

parent 00cff2d9
Pipeline #87793 passed with stage
in 19 minutes and 6 seconds
......@@ -122,8 +122,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
default=0.0,
sampling="linear",
)
encoding = 16 #hp.Int("encoding", min_value=20, max_value=30, step=5, default=25, )
k = 5 # hp.Int(
encoding = (
16 # hp.Int("encoding", min_value=20, max_value=30, step=5, default=25, )
)
k = 5 # hp.Int(
# "n_components",
# min_value=self.number_of_components - 5,
# max_value=self.number_of_components + 5,
......
......@@ -433,7 +433,7 @@ class Gaussian_mixture_overlap(Layer):
dists = []
for k in range(self.n_components):
locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
dists.append(
tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
......
......@@ -133,7 +133,7 @@ class SEQ_2_SEQ_AE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
)
)
Model_D5 = Bidirectional(
......@@ -142,7 +142,7 @@ class SEQ_2_SEQ_AE:
activation="sigmoid",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
)
)
......@@ -231,6 +231,7 @@ class SEQ_2_SEQ_GMVAE:
architecture_hparams: dict = {},
batch_size: int = 256,
compile_model: bool = True,
dense_activation: str = "elu",
entropy_reg_weight: float = 0.0,
huber_delta: float = 1.0,
initialiser_iters: int = int(1e4),
......@@ -245,6 +246,7 @@ class SEQ_2_SEQ_GMVAE:
self.hparams = self.get_hparams(architecture_hparams)
self.batch_size = batch_size
self.CONV_filters = self.hparams["units_conv"]
self.dense_activation = dense_activation
self.LSTM_units_1 = self.hparams["units_lstm"]
self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
......@@ -332,7 +334,7 @@ class SEQ_2_SEQ_GMVAE:
kernel_size=5,
strides=1,
padding="causal",
activation="elu",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
......@@ -342,7 +344,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=0),
# kernel_constraint=UnitNorm(axis=0),
use_bias=True,
)
)
......@@ -352,21 +354,21 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=False,
kernel_constraint=UnitNorm(axis=0),
# kernel_constraint=UnitNorm(axis=0),
use_bias=True,
)
)
Model_E3 = Dense(
self.DENSE_1,
activation="elu",
kernel_constraint=UnitNorm(axis=0),
activation=self.dense_activation,
# kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
use_bias=True,
)
Model_E4 = Dense(
self.DENSE_2,
activation="elu",
kernel_constraint=UnitNorm(axis=0),
activation=self.dense_activation,
# kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
use_bias=True,
)
......@@ -378,13 +380,13 @@ class SEQ_2_SEQ_GMVAE:
Model_B4 = BatchNormalization()
Model_D1 = Dense(
self.DENSE_2,
activation="elu",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
Model_D2 = Dense(
self.DENSE_1,
activation="elu",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
......@@ -395,7 +397,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
use_bias=True,
)
)
......@@ -405,7 +407,7 @@ class SEQ_2_SEQ_GMVAE:
activation="sigmoid",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
use_bias=True,
)
)
......@@ -413,7 +415,7 @@ class SEQ_2_SEQ_GMVAE:
# Predictor layers
Model_P1 = Dense(
self.DENSE_1,
activation="elu",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
......@@ -423,7 +425,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
use_bias=True,
)
)
......@@ -433,14 +435,16 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh",
recurrent_activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
# kernel_constraint=UnitNorm(axis=1),
use_bias=True,
)
)
# Phenotype classification layers
Model_PC1 = Dense(
self.number_of_components, activation="elu", kernel_initializer=he_uniform()
self.number_of_components,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)
return (
......@@ -596,7 +600,9 @@ class SEQ_2_SEQ_GMVAE:
if self.predictor > 0:
# Define and instantiate predictor
predictor = Dense(
self.DENSE_2, activation="elu", kernel_initializer=he_uniform()
self.DENSE_2,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)(z)
predictor = BatchNormalization()(predictor)
predictor = Model_P1(predictor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment