Skip to content
Snippets Groups Projects
Commit c25f8d92 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Changed epochs default for model training

parent c87ee026
Branches
Tags
No related merge requests found
......@@ -341,7 +341,7 @@ class SEQ_2_SEQ_GMVAE:
"bidirectional_merge": "concat",
"clipvalue": 1.0,
"dense_activation": "relu",
"dense_layers_per_branch": 3,
"dense_layers_per_branch": 1,
"dropout_rate": 0.05,
"learning_rate": 1e-3,
"units_conv": 64,
......@@ -399,7 +399,7 @@ class SEQ_2_SEQ_GMVAE:
use_bias=True,
)
Model_E4 = [
seq_E = [
Dense(
self.DENSE_2,
activation=self.dense_activation,
......@@ -409,13 +409,16 @@ class SEQ_2_SEQ_GMVAE:
)
for _ in range(self.dense_layers_per_branch)
]
Model_E4 = []
for l in seq_E:
Model_E4.append(l)
Model_E4.append(BatchNormalization())
# Decoder layers
Model_B1 = BatchNormalization()
Model_B2 = BatchNormalization()
Model_B3 = BatchNormalization()
Model_B4 = BatchNormalization()
Model_D1 = [
seq_D = [
Dense(
self.DENSE_2,
activation=self.dense_activation,
......@@ -424,6 +427,11 @@ class SEQ_2_SEQ_GMVAE:
)
for _ in range(self.dense_layers_per_branch)
]
Model_D1 = []
for l in seq_D:
Model_D1.append(l)
Model_D1.append(BatchNormalization())
Model_D2 = Dense(
self.DENSE_1,
activation=self.dense_activation,
......@@ -511,7 +519,6 @@ class SEQ_2_SEQ_GMVAE:
Model_B1,
Model_B2,
Model_B3,
Model_B4,
Model_D1,
Model_D2,
Model_D3,
......@@ -540,7 +547,6 @@ class SEQ_2_SEQ_GMVAE:
Model_B1,
Model_B2,
Model_B3,
Model_B4,
Model_D1,
Model_D2,
Model_D3,
......@@ -565,7 +571,6 @@ class SEQ_2_SEQ_GMVAE:
encoder = BatchNormalization()(encoder)
encoder = Dropout(self.DROPOUT_RATE)(encoder)
encoder = Sequential(Model_E4)(encoder)
# encoder = BatchNormalization()(encoder)
# encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
z_cat = Dense(
......@@ -626,7 +631,9 @@ class SEQ_2_SEQ_GMVAE:
tfd.Independent(
tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=softplus(gauss[1][..., self.ENCODING :, k]) + 1e-5,
scale=1e-3
+ softplus(gauss[1][..., self.ENCODING :, k])
+ 1e-5,
),
reinterpreted_batch_ndims=1,
)
......@@ -674,22 +681,28 @@ class SEQ_2_SEQ_GMVAE:
# Define and instantiate generator
g = Input(shape=self.ENCODING)
generator = Sequential(Model_D1)(g)
generator = Model_B1(generator)
generator = Model_D2(generator)
generator = Model_B2(generator)
generator = Model_B1(generator)
generator = Model_D3(generator)
generator = Model_D4(generator)
generator = Model_B3(generator)
generator = Model_B2(generator)
generator = Model_D5(generator)
generator = Model_B4(generator)
generator = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
generator
generator = Model_B3(generator)
x_decoded_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(generator)
x_decoded_var = tf.keras.activations.softplus(
Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
)
x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
x_decoded = tf.keras.layers.concatenate(
[x_decoded_mean, x_decoded_var], axis=-1
)
x_decoded_mean = tfpl.IndependentNormal(
event_shape=input_shape[2:],
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="vae_reconstruction",
)(generator)
)(x_decoded)
# define individual branches as models
encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
......@@ -720,14 +733,25 @@ class SEQ_2_SEQ_GMVAE:
predictor = BatchNormalization()(predictor)
predictor = Model_P3(predictor)
predictor = BatchNormalization()(predictor)
predictor = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
predictor
x_predicted_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(predictor)
x_predicted_var = tf.keras.activations.softplus(
Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
predictor
)
)
x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
x_predicted_var
)
x_decoded = tf.keras.layers.concatenate(
[x_predicted_mean, x_predicted_var], axis=-1
)
x_predicted_mean = tfpl.IndependentNormal(
event_shape=input_shape[2:],
convert_to_tensor_fn=tfp.distributions.Distribution.mean,
name="vae_prediction",
)(predictor)
)(x_decoded)
model_outs.append(x_predicted_mean)
model_losses.append(log_loss)
......
......@@ -633,16 +633,16 @@ def tune_search(
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
if phenotype_prediction > 0.0:
ys += [y_train[-Xs.shape[0]:, 0]]
yvals += [y_val[-Xvals.shape[0]:, 0]]
ys += [y_train[-Xs.shape[0] :, 0]]
yvals += [y_val[-Xvals.shape[0] :, 0]]
# Remove the used column (phenotype) from both y arrays
y_train = y_train[:, 1:]
y_val = y_val[:, 1:]
if rule_based_prediction > 0.0:
ys += [y_train[-Xs.shape[0]:]]
yvals += [y_val[-Xvals.shape[0]:]]
ys += [y_train[-Xs.shape[0] :]]
yvals += [y_val[-Xvals.shape[0] :]]
tuner.search(
Xs,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment