From 4316e20995d40a9a96ef3db4aa6f6dfce112914c Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 25 Nov 2020 16:24:47 +0100
Subject: [PATCH] Updated default GMVAE models

---
 deepof/models.py      | 7 +++++++
 deepof/train_model.py | 2 +-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/deepof/models.py b/deepof/models.py
index 613a15c3..85206c48 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -249,6 +249,7 @@ class SEQ_2_SEQ_GMVAE:
         self.dense_activation = dense_activation
         self.LSTM_units_1 = self.hparams["units_lstm"]
         self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
+        self.LSTM_unroll = True
         self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
         self.DENSE_2 = self.hparams["units_dense2"]
         self.DROPOUT_RATE = self.hparams["dropout_rate"]
@@ -344,6 +345,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=0),
                 use_bias=True,
             )
@@ -354,6 +356,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=False,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=0),
                 use_bias=True,
             )
@@ -397,6 +400,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=1),
                 use_bias=True,
             )
@@ -407,6 +411,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="sigmoid",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=1),
                 use_bias=True,
             )
@@ -425,6 +430,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=1),
                 use_bias=True,
             )
@@ -435,6 +441,7 @@ class SEQ_2_SEQ_GMVAE:
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 return_sequences=True,
+                unroll=self.lstm_unroll,
                 # kernel_constraint=UnitNorm(axis=1),
                 use_bias=True,
             )
diff --git a/deepof/train_model.py b/deepof/train_model.py
index 9b16d7ba..5703034a 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -12,8 +12,8 @@ usage: python -m examples.model_training -h
 from deepof.data import *
 from deepof.models import *
 from deepof.utils import *
+from deepof.train_utils import *
 from tensorflow import keras
-from train_utils import *
 
 parser = argparse.ArgumentParser(
     description="Autoencoder training for DeepOF animal pose recognition"
-- 
GitLab