Commit 486fe222 authored by lucas_miranda

Added encoding size as a CL parameter in train_model.py

parent 3b8b588f
Pipeline #88510 failed with stage in 19 minutes and 56 seconds
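In practice, the change wires a new --encoding-size command-line flag through to the model constructors, replacing the encoding value that was previously hard-coded to 16. A minimal sketch of that flow, using the names that appear in the diff below and assuming the remaining SEQ_2_SEQ_GMVAE arguments keep their defaults:

# Sketch only: how the new --encoding-size flag reaches the model.
# Names are taken from the diff below; all other constructor arguments are
# assumed to keep their defaults.
import argparse

import deepof.models

parser = argparse.ArgumentParser()
parser.add_argument(
    "--encoding-size",
    "-es",
    help="set the number of dimensions of the latent space. 16 by default",
    type=int,
    default=16,
)
args = parser.parse_args()

# The parsed value replaces the encoding = 16 that was hard-coded in the hypermodel
# and is forwarded through the new `encoding` keyword argument.
gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(encoding=args.encoding_size)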
@@ -11,7 +11,6 @@ keras hypermodels for hyperparameter tuning of deep autoencoders
from kerastuner import HyperModel
import deepof.models
import deepof.model_utils
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
@@ -95,6 +94,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
def __init__(
self,
input_shape: tuple,
encoding: int,
entropy_reg_weight: float = 0.0,
huber_delta: float = 1.0,
kl_warmup_epochs: int = 0,
@@ -109,6 +109,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
):
super().__init__()
self.input_shape = input_shape
self.encoding = encoding
self.entropy_reg_weight = entropy_reg_weight
self.huber_delta = huber_delta
self.kl_warmup_epochs = kl_warmup_epochs
@@ -136,7 +137,6 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
dense_activation = "relu"
dense_layers_per_branch = 1
dropout_rate = 1e-3
encoding = 16
k = self.number_of_components
lstm_units_1 = 300
@@ -148,7 +148,6 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
dense_activation,
dense_layers_per_branch,
dropout_rate,
encoding,
k,
lstm_units_1,
)
@@ -165,7 +164,6 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
dense_activation,
dense_layers_per_branch,
dropout_rate,
encoding,
k,
lstm_units_1,
) = self.get_hparams(hp)
@@ -177,11 +175,11 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
"dense_activation": dense_activation,
"dense_layers_per_branch": dense_layers_per_branch,
"dropout_rate": dropout_rate,
"encoding": encoding,
"units_conv": conv_filters,
"units_dense_2": dense_2,
"units_lstm": lstm_units_1,
},
encoding=self.encoding,
entropy_reg_weight=self.entropy_reg_weight,
huber_delta=self.huber_delta,
kl_warmup_epochs=self.kl_warmup_epochs,
@@ -9,7 +9,6 @@ deep autoencoder models for unsupervised pose detection
"""
from typing import Any, Dict, Tuple
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
@@ -245,6 +244,7 @@ class SEQ_2_SEQ_GMVAE:
architecture_hparams: dict = {},
batch_size: int = 256,
compile_model: bool = True,
encoding: int = 16,
entropy_reg_weight: float = 0.0,
huber_delta: float = 1.0,
initialiser_iters: int = int(1),
@@ -263,7 +263,7 @@ class SEQ_2_SEQ_GMVAE:
self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
self.DENSE_2 = self.hparams["units_dense2"]
self.DROPOUT_RATE = self.hparams["dropout_rate"]
self.ENCODING = self.hparams["encoding"]
self.ENCODING = encoding
self.LSTM_units_1 = self.hparams["units_lstm"]
self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
self.clipvalue = self.hparams["clipvalue"]
@@ -337,7 +337,6 @@ class SEQ_2_SEQ_GMVAE:
"dense_activation": "relu",
"dense_layers_per_branch": 1,
"dropout_rate": 1e-3,
"encoding": 16,
"learning_rate": 1e-3,
"units_conv": 160,
"units_dense2": 120,
@@ -587,7 +586,7 @@ class SEQ_2_SEQ_GMVAE:
tfd.Independent(
tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=softplus(gauss[1][..., self.ENCODING :, k]),
scale=softplus(gauss[1][..., self.ENCODING:, k]),
),
reinterpreted_batch_ndims=1,
)
@@ -610,6 +609,7 @@ class SEQ_2_SEQ_GMVAE:
)
)
# noinspection PyCallingNonCallable
z = deepof.model_utils.KLDivergenceLayer(self.prior, weight=kl_beta)(z)
mmd_warmup_callback = False
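The ENCODING value set above is also what splits each Gaussian component's parameter vector into means and scales in the mixture posterior. A small illustrative sketch with tensorflow_probability, using assumed shapes (a batch of 8 parameter vectors of length 2 * encoding) rather than the model's actual layer outputs:

# Sketch with assumed shapes: a vector of length 2 * encoding is split into
# `encoding` means and `encoding` softplus-constrained standard deviations.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

encoding = 16                                 # latent dimensionality (the new CLI parameter)
params = tf.random.normal([8, 2 * encoding])  # hypothetical batch of parameter vectors

component = tfd.Independent(
    tfd.Normal(
        loc=params[..., :encoding],                      # first half: component means
        scale=tf.math.softplus(params[..., encoding:]),  # second half: positive scales
    ),
    reinterpreted_batch_ndims=1,  # treat the encoding axis as event dimensions
)
print(component.event_shape)  # (16,)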
@@ -46,6 +46,13 @@ parser.add_argument(
type=int,
default=1,
)
parser.add_argument(
"--encoding-size",
"-es",
help="set the number of dimensions of the latent space. 16 by default",
type=int,
default=16,
)
parser.add_argument(
"--exclude-bodyparts",
"-exc",
@@ -188,6 +195,7 @@ animal_id = args.animal_id
arena_dims = args.arena_dims
batch_size = args.batch_size
hypertun_trials = args.hpt_trials
encoding_size = args.encoding_size
exclude_bodyparts = list(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters
@@ -354,6 +362,7 @@ if not tune:
architecture_hparams=hparams,
batch_size=batch_size,
compile_model=True,
encoding=encoding_size,
kl_warmup_epochs=kl_wu,
loss=loss,
mmd_warmup_epochs=mmd_wu,
@@ -420,6 +429,7 @@ else:
best_hyperparameters, best_model = tune_search(
data=[X_train, y_train, X_val, y_val],
encoding_size=encoding_size,
hypertun_trials=hypertun_trials,
hpt_type=tune,
hypermodel=hyp,
@@ -110,6 +110,7 @@ def get_callbacks(
def tune_search(
data: List[np.array],
encoding_size: int,
hypertun_trials: int,
hpt_type: str,
hypermodel: str,
@@ -169,6 +170,7 @@ def tune_search(
elif hypermodel == "S2SGMVAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
input_shape=X_train.shape,
encoding=encoding_size,
kl_warmup_epochs=kl_warmup_epochs,
loss=loss,
mmd_warmup_epochs=mmd_warmup_epochs,