From 4316e20995d40a9a96ef3db4aa6f6dfce112914c Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Wed, 25 Nov 2020 16:24:47 +0100 Subject: [PATCH] Updated default GMVAE models --- deepof/models.py | 7 +++++++ deepof/train_model.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/deepof/models.py b/deepof/models.py index 613a15c3..85206c48 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -249,6 +249,7 @@ class SEQ_2_SEQ_GMVAE: self.dense_activation = dense_activation self.LSTM_units_1 = self.hparams["units_lstm"] self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2) + self.LSTM_unroll = True self.DENSE_1 = int(self.hparams["units_lstm"] / 2) self.DENSE_2 = self.hparams["units_dense2"] self.DROPOUT_RATE = self.hparams["dropout_rate"] @@ -344,6 +345,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=0), use_bias=True, ) @@ -354,6 +356,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=False, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=0), use_bias=True, ) @@ -397,6 +400,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=1), use_bias=True, ) @@ -407,6 +411,7 @@ class SEQ_2_SEQ_GMVAE: activation="sigmoid", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=1), use_bias=True, ) @@ -425,6 +430,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=1), use_bias=True, ) @@ -435,6 +441,7 @@ class SEQ_2_SEQ_GMVAE: activation="tanh", recurrent_activation="sigmoid", return_sequences=True, + unroll=self.lstm_unroll, # kernel_constraint=UnitNorm(axis=1), use_bias=True, ) diff --git a/deepof/train_model.py b/deepof/train_model.py index 9b16d7ba..5703034a 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -12,8 +12,8 @@ usage: python -m examples.model_training -h from deepof.data import * from deepof.models import * from deepof.utils import * +from deepof.train_utils import * from tensorflow import keras -from train_utils import * parser = argparse.ArgumentParser( description="Autoencoder training for DeepOF animal pose recognition" -- GitLab