Skip to content
Snippets Groups Projects
Commit 4316e209 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Updated default GMVAE models

parent 8e546668
No related branches found
No related tags found
No related merge requests found
Pipeline #87811 failed
...@@ -249,6 +249,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -249,6 +249,7 @@ class SEQ_2_SEQ_GMVAE:
self.dense_activation = dense_activation self.dense_activation = dense_activation
self.LSTM_units_1 = self.hparams["units_lstm"] self.LSTM_units_1 = self.hparams["units_lstm"]
self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2) self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
self.LSTM_unroll = True
self.DENSE_1 = int(self.hparams["units_lstm"] / 2) self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
self.DENSE_2 = self.hparams["units_dense2"] self.DENSE_2 = self.hparams["units_dense2"]
self.DROPOUT_RATE = self.hparams["dropout_rate"] self.DROPOUT_RATE = self.hparams["dropout_rate"]
...@@ -344,6 +345,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -344,6 +345,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
unroll=self.lstm_unroll,
# kernel_constraint=UnitNorm(axis=0), # kernel_constraint=UnitNorm(axis=0),
use_bias=True, use_bias=True,
) )
...@@ -354,6 +356,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -354,6 +356,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=False, return_sequences=False,
unroll=self.lstm_unroll,
# kernel_constraint=UnitNorm(axis=0), # kernel_constraint=UnitNorm(axis=0),
use_bias=True, use_bias=True,
) )
...@@ -397,6 +400,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -397,6 +400,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
unroll=self.lstm_unroll,
# kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
use_bias=True, use_bias=True,
) )
...@@ -407,6 +411,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -407,6 +411,7 @@ class SEQ_2_SEQ_GMVAE:
activation="sigmoid", activation="sigmoid",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
unroll=self.lstm_unroll,
# kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
use_bias=True, use_bias=True,
) )
...@@ -425,6 +430,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -425,6 +430,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
unroll=self.lstm_unroll,
# kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
use_bias=True, use_bias=True,
) )
...@@ -435,6 +441,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -435,6 +441,7 @@ class SEQ_2_SEQ_GMVAE:
activation="tanh", activation="tanh",
recurrent_activation="sigmoid", recurrent_activation="sigmoid",
return_sequences=True, return_sequences=True,
unroll=self.lstm_unroll,
# kernel_constraint=UnitNorm(axis=1), # kernel_constraint=UnitNorm(axis=1),
use_bias=True, use_bias=True,
) )
......
...@@ -12,8 +12,8 @@ usage: python -m examples.model_training -h ...@@ -12,8 +12,8 @@ usage: python -m examples.model_training -h
from deepof.data import * from deepof.data import *
from deepof.models import * from deepof.models import *
from deepof.utils import * from deepof.utils import *
from deepof.train_utils import *
from tensorflow import keras from tensorflow import keras
from train_utils import *
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Autoencoder training for DeepOF animal pose recognition" description="Autoencoder training for DeepOF animal pose recognition"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment