diff --git a/main.ipynb b/main.ipynb
index b892655237958e6616d46b4a05b3a696be20ac63..13824c0eda1cbc873579b015ecc93feb615ed3ba 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -303,7 +303,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "NAME = 'Baseline_VAEP_short'\n",
+    "NAME = 'Baseline_VAEP_short_partially_untied'\n",
     "log_dir = os.path.abspath(\n",
     "    \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
     ")\n",
@@ -346,6 +346,15 @@
     "encoder, generator, vaep = SEQ_2_SEQ_VAEP(pttest.shape).build()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#vaep.summary()"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/source/models.py b/source/models.py
index e0de2ffa8fab7ed9d7ec1a4dc018f6f6ada50a86..0d1542d5f5bb33a271503716ad7329d617abdf1b 100644
--- a/source/models.py
+++ b/source/models.py
@@ -129,7 +129,6 @@ class SEQ_2_SEQ_AE:
         decoder.add(Model_D2)
         encoder.add(BatchNormalization())
         decoder.add(Model_D3)
-        decoder.add(BatchNormalization())
         decoder.add(Model_D4)
         encoder.add(BatchNormalization())
         decoder.add(Model_D5)
@@ -222,7 +221,6 @@ class SEQ_2_SEQ_VAE:
         Model_B3 = BatchNormalization()
         Model_B4 = BatchNormalization()
         Model_B5 = BatchNormalization()
-        Model_B6 = BatchNormalization()
         Model_D0 = DenseTranspose(
             Model_E5, activation="relu", output_dim=self.ENCODING,
         )
@@ -280,11 +278,10 @@ class SEQ_2_SEQ_VAE:
         generator = Model_D2(generator)
         generator = Model_B3(generator)
         generator = Model_D3(generator)
-        generator = Model_B4(generator)
         generator = Model_D4(generator)
-        generator = Model_B5(generator)
+        generator = Model_B4(generator)
         generator = Model_D5(generator)
-        generator = Model_B6(generator)
+        generator = Model_B5(generator)
         x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)
 
         # end-to-end autoencoder
@@ -300,11 +297,10 @@ class SEQ_2_SEQ_VAE:
         _generator = Model_D2(_generator)
         _generator = Model_B3(_generator)
         _generator = Model_D3(_generator)
-        _generator = Model_B4(_generator)
         _generator = Model_D4(_generator)
-        _generator = Model_B5(_generator)
+        _generator = Model_B4(_generator)
         _generator = Model_D5(_generator)
-        _generator = Model_B6(_generator)
+        _generator = Model_B5(_generator)
         _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
         generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
 
@@ -398,7 +394,6 @@ class SEQ_2_SEQ_VAEP:
         Model_B3 = BatchNormalization()
         Model_B4 = BatchNormalization()
         Model_B5 = BatchNormalization()
-        Model_B6 = BatchNormalization()
         Model_D0 = DenseTranspose(
             Model_E5, activation="relu", output_dim=self.ENCODING,
         )
@@ -456,22 +451,28 @@
         generator = Model_D2(generator)
         generator = Model_B3(generator)
         generator = Model_D3(generator)
-        generator = Model_B4(generator)
         generator = Model_D4(generator)
-        generator = Model_B5(generator)
+        generator = Model_B4(generator)
         generator = Model_D5(generator)
-        generator = Model_B6(generator)
+        generator = Model_B5(generator)
         x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)
 
-        # Define and instanciate predictor
+        # Define and instantiate predictor
-        predictor = Model_D0(z)
-        predictor = Model_B1(predictor)
-        predictor = Model_D1(predictor)
-        predictor = Model_B2(predictor)
-        predictor = Model_D2(predictor)
-        predictor = Model_B3(predictor)
-        predictor = Model_D3(predictor)
-        predictor = Model_B4(predictor)
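+        # Predictor branch: independent, untied Dense layers (no longer shares the decoder's DenseTranspose weights)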
+        predictor = Dense(
+            self.ENCODING, activation="relu", kernel_initializer=he_uniform()
+        )(z)
+        predictor = BatchNormalization()(predictor)
+        predictor = Dense(
+            self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
+        )(predictor)
+        predictor = BatchNormalization()(predictor)
+        predictor = Dense(
+            self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
+        )(predictor)
+        predictor = BatchNormalization()(predictor)
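+        # Repeat the dense output across time steps so the Bidirectional LSTM below receives a 3D sequence input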
+        predictor = RepeatVector(self.input_shape[1])(predictor)
         predictor = Bidirectional(
             LSTM(
                 self.LSTM_units_1,
@@ -507,11 +506,10 @@ class SEQ_2_SEQ_VAEP:
         _generator = Model_D2(_generator)
         _generator = Model_B3(_generator)
         _generator = Model_D3(_generator)
-        _generator = Model_B4(_generator)
         _generator = Model_D4(_generator)
-        _generator = Model_B5(_generator)
+        _generator = Model_B4(_generator)
         _generator = Model_D5(_generator)
-        _generator = Model_B6(_generator)
+        _generator = Model_B5(_generator)
         _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
         generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")