diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py index a4809fa4a4c6894c079b9dfc67969502a3dc2787..ba413ef03d67faa9347d6926b08f56de7b390a07 100644 --- a/deepof/hypermodels.py +++ b/deepof/hypermodels.py @@ -107,11 +107,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): # Architectural hyperparameters clipvalue = hp.Float( - "clipvalue", - min_value=0.0, - max_value=1.0, - default=0.5, - sampling="Linear" + "clipvalue", min_value=0.0, max_value=1.0, default=0.5, sampling="Linear" ) conv_filters = hp.Int( "units_conv", min_value=128, max_value=160, step=16, default=128, @@ -120,6 +116,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): "units_dense2", min_value=120, max_value=180, step=10, default=150, ) dense_activation = hp.Choice("dense_activation", values=["elu", "relu"]) + dense_layers_per_branch = hp.Int("dense_layers_per_branch", min_value=1, max_value=3, default=1) dropout_rate = hp.Float( "dropout_rate", min_value=0.0, @@ -138,6 +135,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): conv_filters, dense_2, dense_activation, + dense_layers_per_branch, dropout_rate, encoding, k, @@ -153,6 +151,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): conv_filters, dense_2, dense_activation, + dense_layers_per_branch, dropout_rate, encoding, k, @@ -164,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): "clipvalue": clipvalue, "dense_activation": dense_activation, "dropout_rate": dropout_rate, + "dense_layers_per_branch": dense_layers_per_branch, "encoding": encoding, "units_conv": conv_filters, "units_dense_2": dense_2, diff --git a/deepof/models.py b/deepof/models.py index 5aae169bb940bb5ef24e5a5d9c34b8ea7dbee4ca..d00cc8243b1fae9f2042c84194d1f3addd81cd6e 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -253,6 +253,7 @@ class SEQ_2_SEQ_GMVAE: self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2) self.clipvalue = self.hparams["clipvalue"] self.dense_activation = self.hparams["dense_activation"] + self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"] self.learn_rate = 
self.hparams["learning_rate"] self.lstm_unroll = True self.compile = compile_model @@ -313,14 +314,15 @@ class SEQ_2_SEQ_GMVAE: """Sets the default parameters for the model. Overwritable with a dictionary""" defaults = { - "units_conv": 256, - "units_lstm": 256, - "units_dense2": 64, + "clipvalue": 0.5, + "dense_activation": "elu", + "dense_layers_per_branch": 1, "dropout_rate": 0.15, "encoding": 16, "learning_rate": 1e-3, - "clipvalue": 0.5, - "dense_activation": "elu", + "units_conv": 256, + "units_dense2": 64, + "units_lstm": 256, } for k, v in params.items(): @@ -370,25 +372,32 @@ class SEQ_2_SEQ_GMVAE: kernel_initializer=he_uniform(), use_bias=True, ) - Model_E4 = Dense( - self.DENSE_2, - activation=self.dense_activation, - # kernel_constraint=UnitNorm(axis=0), - kernel_initializer=he_uniform(), - use_bias=True, - ) + + Model_E4 = [ + Dense( + self.DENSE_2, + activation=self.dense_activation, + # kernel_constraint=UnitNorm(axis=0), + kernel_initializer=he_uniform(), + use_bias=True, + ) + for _ in range(self.dense_layers_per_branch) + ] # Decoder layers Model_B1 = BatchNormalization() Model_B2 = BatchNormalization() Model_B3 = BatchNormalization() Model_B4 = BatchNormalization() - Model_D1 = Dense( - self.DENSE_2, - activation=self.dense_activation, - kernel_initializer=he_uniform(), - use_bias=True, - ) + Model_D1 = [ + Dense( + self.DENSE_2, + activation=self.dense_activation, + kernel_initializer=he_uniform(), + use_bias=True, + ) + for _ in range(self.dense_layers_per_branch) + ] Model_D2 = Dense( self.DENSE_1, activation=self.dense_activation, @@ -516,7 +525,7 @@ class SEQ_2_SEQ_GMVAE: encoder = Model_E3(encoder) encoder = BatchNormalization()(encoder) encoder = Dropout(self.DROPOUT_RATE)(encoder) - encoder = Model_E4(encoder) + encoder = Sequential(Model_E4)(encoder) encoder = BatchNormalization()(encoder) # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder) @@ -546,7 +555,7 @@ class SEQ_2_SEQ_GMVAE: tfd.Independent( tfd.Normal( 
loc=gauss[1][..., : self.ENCODING, k], - scale=softplus(gauss[1][..., self.ENCODING :, k]), + scale=softplus(gauss[1][..., self.ENCODING:, k]), ), reinterpreted_batch_ndims=1, ) @@ -588,7 +597,7 @@ class SEQ_2_SEQ_GMVAE: )(z) # Define and instantiate generator - generator = Model_D1(z) + generator = Sequential(Model_D1)(z) generator = Model_B1(generator) generator = Model_D2(generator) generator = Model_B2(generator) @@ -650,7 +659,7 @@ class SEQ_2_SEQ_GMVAE: # Build generator as a separate entity g = Input(shape=self.ENCODING) - _generator = Model_D1(g) + _generator = Sequential(Model_D1)(g) _generator = Model_B1(_generator) _generator = Model_D2(_generator) _generator = Model_B2(_generator)