Commit c25f8d92 authored by lucas_miranda

Changed epochs default for model training

parent c87ee026
@@ -341,7 +341,7 @@ class SEQ_2_SEQ_GMVAE:
             "bidirectional_merge": "concat",
             "clipvalue": 1.0,
             "dense_activation": "relu",
-            "dense_layers_per_branch": 3,
+            "dense_layers_per_branch": 1,
             "dropout_rate": 0.05,
             "learning_rate": 1e-3,
             "units_conv": 64,
@@ -399,7 +399,7 @@ class SEQ_2_SEQ_GMVAE:
             use_bias=True,
         )

-        Model_E4 = [
+        seq_E = [
             Dense(
                 self.DENSE_2,
                 activation=self.dense_activation,
@@ -409,13 +409,16 @@ class SEQ_2_SEQ_GMVAE:
             )
             for _ in range(self.dense_layers_per_branch)
         ]
+        Model_E4 = []
+        for l in seq_E:
+            Model_E4.append(l)
+            Model_E4.append(BatchNormalization())

         # Decoder layers
         Model_B1 = BatchNormalization()
         Model_B2 = BatchNormalization()
         Model_B3 = BatchNormalization()
-        Model_B4 = BatchNormalization()
-        Model_D1 = [
+        seq_D = [
             Dense(
                 self.DENSE_2,
                 activation=self.dense_activation,
@@ -424,6 +427,11 @@ class SEQ_2_SEQ_GMVAE:
             )
             for _ in range(self.dense_layers_per_branch)
         ]
+        Model_D1 = []
+        for l in seq_D:
+            Model_D1.append(l)
+            Model_D1.append(BatchNormalization())
+
         Model_D2 = Dense(
             self.DENSE_1,
             activation=self.dense_activation,
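
The four hunks above restructure the dense branches: the list comprehension now builds its Dense layers into a temporary list (`seq_E` for the encoder, `seq_D` for the decoder), and a loop interleaves a fresh `BatchNormalization` instance after each one. A minimal, self-contained sketch of the pattern (the layer width of 64 is illustrative, not taken from the diff):

```python
from tensorflow.keras.layers import BatchNormalization, Dense

dense_layers_per_branch = 1  # new default set in the first hunk
seq_E = [Dense(64, activation="relu") for _ in range(dense_layers_per_branch)]

# Interleave a BatchNormalization layer after every Dense layer
Model_E4 = []
for layer in seq_E:
    Model_E4.append(layer)
    Model_E4.append(BatchNormalization())
# Model_E4 == [Dense, BatchNormalization] with the new default of 1
```

With batch normalisation folded into the branches themselves, the standalone `Model_B4` layer and the commented-out `BatchNormalization` call after `Sequential(Model_E4)` become redundant, which is what the next three hunks remove.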
......@@ -511,7 +519,6 @@ class SEQ_2_SEQ_GMVAE:
Model_B1,
Model_B2,
Model_B3,
Model_B4,
Model_D1,
Model_D2,
Model_D3,
@@ -540,7 +547,6 @@ class SEQ_2_SEQ_GMVAE:
             Model_B1,
             Model_B2,
             Model_B3,
-            Model_B4,
             Model_D1,
             Model_D2,
             Model_D3,
@@ -565,7 +571,6 @@ class SEQ_2_SEQ_GMVAE:
         encoder = BatchNormalization()(encoder)
         encoder = Dropout(self.DROPOUT_RATE)(encoder)
         encoder = Sequential(Model_E4)(encoder)
-        # encoder = BatchNormalization()(encoder)

         # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
         z_cat = Dense(
@@ -626,7 +631,9 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
-                            scale=softplus(gauss[1][..., self.ENCODING :, k]) + 1e-5,
+                            scale=1e-3
+                            + softplus(gauss[1][..., self.ENCODING :, k])
+                            + 1e-5,
                         ),
                         reinterpreted_batch_ndims=1,
                     )
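
This hunk raises the floor of the mixture components' scale parameter: `softplus(x) + 1e-5` approaches 1e-5 for strongly negative activations, while the extra 1e-3 term keeps every component's standard deviation above roughly 1e-3, presumably to stabilise the Gaussian-mixture likelihood. A small numeric sketch of the difference:

```python
import tensorflow as tf

raw = tf.constant([-20.0, 0.0, 5.0])
old_scale = tf.math.softplus(raw) + 1e-5         # ~[1.0e-05, 0.693, 5.007]
new_scale = 1e-3 + tf.math.softplus(raw) + 1e-5  # ~[1.0e-03, 0.694, 5.008]
```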
@@ -674,22 +681,28 @@ class SEQ_2_SEQ_GMVAE:
         # Define and instantiate generator
         g = Input(shape=self.ENCODING)
         generator = Sequential(Model_D1)(g)
-        generator = Model_B1(generator)
         generator = Model_D2(generator)
-        generator = Model_B2(generator)
+        generator = Model_B1(generator)
         generator = Model_D3(generator)
         generator = Model_D4(generator)
-        generator = Model_B3(generator)
+        generator = Model_B2(generator)
         generator = Model_D5(generator)
-        generator = Model_B4(generator)
-        generator = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
-            generator
-        )
+        generator = Model_B3(generator)
+        x_decoded_mean = Dense(
+            tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
+        )(generator)
+        x_decoded_var = tf.keras.activations.softplus(
+            Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
+        )
+        x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
+        x_decoded = tf.keras.layers.concatenate(
+            [x_decoded_mean, x_decoded_var], axis=-1
+        )
         x_decoded_mean = tfpl.IndependentNormal(
             event_shape=input_shape[2:],
             convert_to_tensor_fn=tfp.distributions.Distribution.mean,
             name="vae_reconstruction",
-        )(generator)
+        )(x_decoded)

         # define individual branches as models
         encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
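
Instead of a single Dense layer emitting all IndependentNormal parameters, the generator now computes mean and variance through separate heads, floors the variance at 1e-3 (softplus plus a Lambda shift) and concatenates the two halves into the parameter vector `tfpl.IndependentNormal` consumes. A self-contained sketch of the construction, with an assumed event shape of (10,) and latent width of 32 (neither is specified in this hunk):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfpl = tfp.layers
event_shape = (10,)                                       # assumed for illustration
params = tfpl.IndependentNormal.params_size(event_shape)  # 2 * 10 parameters

g = tf.keras.Input(shape=(32,))                           # assumed latent width
mean = tf.keras.layers.Dense(params // 2)(g)
var = tf.keras.activations.softplus(tf.keras.layers.Dense(params // 2)(g))
var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(var)     # variance floor
out = tfpl.IndependentNormal(
    event_shape=event_shape,
    convert_to_tensor_fn=tfp.distributions.Distribution.mean,
)(tf.keras.layers.concatenate([mean, var], axis=-1))
generator = tf.keras.Model(g, out)
```

The predictor branch in the next hunk receives the same two-head treatment (note that it reuses the name `x_decoded` for the concatenated parameters).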
@@ -720,14 +733,25 @@ class SEQ_2_SEQ_GMVAE:
             predictor = BatchNormalization()(predictor)
             predictor = Model_P3(predictor)
             predictor = BatchNormalization()(predictor)
-            predictor = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
-                predictor
-            )
+            x_predicted_mean = Dense(
+                tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
+            )(predictor)
+            x_predicted_var = tf.keras.activations.softplus(
+                Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
+                    predictor
+                )
+            )
+            x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
+                x_predicted_var
+            )
+            x_decoded = tf.keras.layers.concatenate(
+                [x_predicted_mean, x_predicted_var], axis=-1
+            )
             x_predicted_mean = tfpl.IndependentNormal(
                 event_shape=input_shape[2:],
                 convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                 name="vae_prediction",
-            )(predictor)
+            )(x_decoded)
             model_outs.append(x_predicted_mean)
             model_losses.append(log_loss)
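
`log_loss`, appended here for the prediction output, is not defined in this diff; with distribution-valued outputs like the IndependentNormal layers above, it is typically the negative log-likelihood of the observed sequence under the predicted distribution. A hedged sketch of what such a loss looks like:

```python
def log_loss(y_true, y_pred_distribution):
    """Negative log-likelihood of y_true under the predicted distribution.

    y_pred_distribution is the tfp.distributions.Distribution instance
    emitted by a tfpl.IndependentNormal output layer.
    """
    return -y_pred_distribution.log_prob(y_true)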
@@ -633,16 +633,16 @@ def tune_search(
         Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]

     if phenotype_prediction > 0.0:
-        ys += [y_train[-Xs.shape[0]:, 0]]
-        yvals += [y_val[-Xvals.shape[0]:, 0]]
+        ys += [y_train[-Xs.shape[0] :, 0]]
+        yvals += [y_val[-Xvals.shape[0] :, 0]]

         # Remove the used column (phenotype) from both y arrays
         y_train = y_train[:, 1:]
         y_val = y_val[:, 1:]

     if rule_based_prediction > 0.0:
-        ys += [y_train[-Xs.shape[0]:]]
-        yvals += [y_val[-Xvals.shape[0]:]]
+        ys += [y_train[-Xs.shape[0] :]]
+        yvals += [y_val[-Xvals.shape[0] :]]

     tuner.search(
         Xs,
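
This final hunk is purely cosmetic: black, following PEP 8, treats the slice colon as a binary operator, so a bound that is a complex expression such as `-Xs.shape[0]` gets a space on each side, while simple bounds stay unspaced. A runnable illustration:

```python
import numpy as np

Xs = np.zeros((5, 3))
y_train = np.arange(20).reshape(10, 2)

ys = y_train[-Xs.shape[0] :, 0]  # complex bound: black puts spaces around ':'
col = y_train[:, 0]              # simple bound: no spaces
```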