diff --git a/deepof/models.py b/deepof/models.py index 613a15c30cca50b159acf065c1c45dee5401978e..85206c489a51d7b1411345a14b43b3d9926c2834 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -249,6 +249,7 @@ class SEQ_2_SEQ_GMVAE: self.dense_activation = dense_activation self.LSTM_units_1 = self.hparams["units_lstm"] self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2) + self.LSTM_unroll = True self.DENSE_1 = int(self.hparams["units_lstm"] / 2) self.DENSE_2 = self.hparams["units_dense2"] self.DROPOUT_RATE = self.hparams["dropout_rate"] @@ -344,6 +345,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=0), use_bias=True, ) @@ -354,6 +356,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=False, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=0), use_bias=True, ) @@ -397,6 +400,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=1), use_bias=True, ) @@ -407,6 +411,7 @@ class SEQ_2_SEQ_GMVAE: activation="sigmoid", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=1), use_bias=True, ) @@ -425,6 +430,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=1), use_bias=True, ) @@ -435,6 +441,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=1), use_bias=True, ) diff --git a/deepof/train_model.py b/deepof/train_model.py index 9b16d7ba6e6391907964fb2757aed06247a54e2c..5703034a4244a0d684f1b0d23cb981a4833eb4b5 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -12,8 +12,8 @@ usage: python -m examples.model_training -h from deepof.data import * from deepof.models import * from deepof.utils import * +from deepof.train_utils import * from tensorflow import keras -from train_utils import * parser = argparse.ArgumentParser( description="Autoencoder training for DeepOF animal pose recognition"