diff --git a/deepof/models.py b/deepof/models.py
index 613a15c30cca50b159acf065c1c45dee5401978e..85206c489a51d7b1411345a14b43b3d9926c2834 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -249,6 +249,7 @@ class SEQ_2_SEQ_GMVAE:
         self.dense_activation = dense_activation
         self.LSTM_units_1 = self.hparams["units_lstm"]
         self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
+        self.LSTM_unroll = True
         self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
         self.DENSE_2 = self.hparams["units_dense2"]
         self.DROPOUT_RATE = self.hparams["dropout_rate"]
@@ -344,6 +345,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=0),
                 use_bias=True,
             )
@@ -354,6 +356,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=False,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=0),
                 use_bias=True,
             )
@@ -397,6 +400,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=1),
                 use_bias=True,
             )
@@ -407,6 +411,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="sigmoid",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=1),
                 use_bias=True,
             )
@@ -425,6 +430,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=1),
                 use_bias=True,
             )
@@ -435,6 +441,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=1),
                 use_bias=True,
             )
diff --git a/deepof/train_model.py b/deepof/train_model.py
index 9b16d7ba6e6391907964fb2757aed06247a54e2c..5703034a4244a0d684f1b0d23cb981a4833eb4b5 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -12,8 +12,8 @@ usage: python -m examples.model_training -h
 from deepof.data import *
 from deepof.models import *
 from deepof.utils import *
+from deepof.train_utils import *
 from tensorflow import keras
-from train_utils import *
 
 parser = argparse.ArgumentParser(
     description="Autoencoder training for DeepOF animal pose recognition"