Commit 3243a42f authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated default GMVAE models

parent 00cff2d9
Pipeline #87793 passed with stage
in 19 minutes and 6 seconds
...@@ -122,8 +122,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -122,8 +122,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
default=0.0, default=0.0,
sampling="linear", sampling="linear",
) )
encoding = 16 #hp.Int("encoding", min_value=20, max_value=30, step=5, default=25, ) encoding = (
k = 5 # hp.Int( 16 # hp.Int("encoding", min_value=20, max_value=30, step=5, default=25, )
)
k = 5 # hp.Int(
# "n_components", # "n_components",
# min_value=self.number_of_components - 5, # min_value=self.number_of_components - 5,
# max_value=self.number_of_components + 5, # max_value=self.number_of_components + 5,
......
...@@ -433,7 +433,7 @@ class Gaussian_mixture_overlap(Layer): ...@@ -433,7 +433,7 @@ class Gaussian_mixture_overlap(Layer):
dists = [] dists = []
for k in range(self.n_components): for k in range(self.n_components):
locs = (target[..., : self.lat_dims, k],) locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k]) scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
dists.append( dists.append(
tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1]) tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
......
...@@ -133,7 +133,7 @@ class SEQ_2_SEQ_AE: ...@@ -133,7 +133,7 @@ class SEQ_2_SEQ_AE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
) )
) )
Model_D5 = Bidirectional( Model_D5 = Bidirectional(
...@@ -142,7 +142,7 @@ class SEQ_2_SEQ_AE: ...@@ -142,7 +142,7 @@ class SEQ_2_SEQ_AE:
activation="sigmoid", activation="sigmoid",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
) )
) )
...@@ -231,6 +231,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -231,6 +231,7 @@ class SEQ_2_SEQ_GMVAE:
architecture_hparams: dict = {}, architecture_hparams: dict = {},
batch_size: int = 256, batch_size: int = 256,
compile_model: bool = True, compile_model: bool = True,
dense_activation: str = "elu",
entropy_reg_weight: float = 0.0, entropy_reg_weight: float = 0.0,
huber_delta: float = 1.0, huber_delta: float = 1.0,
initialiser_iters: int = int(1e4), initialiser_iters: int = int(1e4),
...@@ -245,6 +246,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -245,6 +246,7 @@ class SEQ_2_SEQ_GMVAE:
self.hparams = self.get_hparams(architecture_hparams) self.hparams = self.get_hparams(architecture_hparams)
self.batch_size = batch_size self.batch_size = batch_size
self.CONV_filters = self.hparams["units_conv"] self.CONV_filters = self.hparams["units_conv"]
self.dense_activation = dense_activation
self.LSTM_units_1 = self.hparams["units_lstm"] self.LSTM_units_1 = self.hparams["units_lstm"]
self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2) self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
self.DENSE_1 = int(self.hparams["units_lstm"] / 2) self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
...@@ -332,7 +334,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -332,7 +334,7 @@ class SEQ_2_SEQ_GMVAE:
kernel_size=5, kernel_size=5,
strides=1, strides=1,
padding="causal", padding="causal",
activation="elu", activation=self.dense_activation,
kernel_initializer=he_uniform(), kernel_initializer=he_uniform(),
use_bias=True, use_bias=True,
) )
...@@ -342,7 +344,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -342,7 +344,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
kernel_constraint=UnitNorm(axis=0), # kernel_constraint=UnitNorm(axis=0),
use_bias=True, use_bias=True,
) )
) )
...@@ -352,21 +354,21 @@ class SEQ_2_SEQ_GMVAE: ...@@ -352,21 +354,21 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=False, return_sequences=False,
kernel_constraint=UnitNorm(axis=0), # kernel_constraint=UnitNorm(axis=0),
use_bias=True, use_bias=True,
) )
) )
Model_E3 = Dense( Model_E3 = Dense(
self.DENSE_1, self.DENSE_1,
activation="elu", activation=self.dense_activation,
kernel_constraint=UnitNorm(axis=0), # kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(), kernel_initializer=he_uniform(),
use_bias=True, use_bias=True,
) )
Model_E4 = Dense( Model_E4 = Dense(
self.DENSE_2, self.DENSE_2,
activation="elu", activation=self.dense_activation,
kernel_constraint=UnitNorm(axis=0), # kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(), kernel_initializer=he_uniform(),
use_bias=True, use_bias=True,
) )
...@@ -378,13 +380,13 @@ class SEQ_2_SEQ_GMVAE: ...@@ -378,13 +380,13 @@ class SEQ_2_SEQ_GMVAE:
Model_B4 = BatchNormalization() Model_B4 = BatchNormalization()
Model_D1 = Dense( Model_D1 = Dense(
self.DENSE_2, self.DENSE_2,
activation="elu", activation=self.dense_activation,
kernel_initializer=he_uniform(), kernel_initializer=he_uniform(),
use_bias=True, use_bias=True,
) )
Model_D2 = Dense( Model_D2 = Dense(
self.DENSE_1, self.DENSE_1,
activation="elu", activation=self.dense_activation,
kernel_initializer=he_uniform(), kernel_initializer=he_uniform(),
use_bias=True, use_bias=True,
) )
...@@ -395,7 +397,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -395,7 +397,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
use_bias=True, use_bias=True,
) )
) )
...@@ -405,7 +407,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -405,7 +407,7 @@ class SEQ_2_SEQ_GMVAE:
activation="sigmoid", activation="sigmoid",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
use_bias=True, use_bias=True,
) )
) )
...@@ -413,7 +415,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -413,7 +415,7 @@ class SEQ_2_SEQ_GMVAE:
# Predictor layers # Predictor layers
Model_P1 = Dense( Model_P1 = Dense(
self.DENSE_1, self.DENSE_1,
activation="elu", activation=self.dense_activation,
kernel_initializer=he_uniform(), kernel_initializer=he_uniform(),
use_bias=True, use_bias=True,
) )
...@@ -423,7 +425,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -423,7 +425,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
use_bias=True, use_bias=True,
) )
) )
...@@ -433,14 +435,16 @@ class SEQ_2_SEQ_GMVAE: ...@@ -433,14 +435,16 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
use_bias=True, use_bias=True,
) )
) )
# Phenotype classification layers # Phenotype classification layers
Model_PC1 = Dense( Model_PC1 = Dense(
self.number_of_components, activation="elu", kernel_initializer=he_uniform() self.number_of_components,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
) )
return ( return (
...@@ -596,7 +600,9 @@ class SEQ_2_SEQ_GMVAE: ...@@ -596,7 +600,9 @@ class SEQ_2_SEQ_GMVAE:
if self.predictor > 0: if self.predictor > 0:
# Define and instantiate predictor # Define and instantiate predictor
predictor = Dense( predictor = Dense(
self.DENSE_2, activation="elu", kernel_initializer=he_uniform() self.DENSE_2,
activation=self.dense_activation,
kernel_initializer=he_uniform(),
)(z) )(z)
predictor = BatchNormalization()(predictor) predictor = BatchNormalization()(predictor)
predictor = Model_P1(predictor) predictor = Model_P1(predictor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment