Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
060349e8
Commit
060349e8
authored
Oct 12, 2020
by
lucas_miranda
Browse files
Added support for tensorboard HParams while tuning hyperparameters
parent
0361a1bb
Pipeline
#84057
passed with stage
in 29 minutes and 30 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
deepof/hypermodels.py
View file @
060349e8
...
...
@@ -24,10 +24,10 @@ class SEQ_2_SEQ_AE(HyperModel):
super
().
__init__
()
self
.
input_shape
=
input_shape
def
build
(
self
,
hp
):
"""Overrides Hypermodel's build method"""
@
staticmethod
def
get_hparams
(
hp
):
"""Retrieve hyperparameters to tune"""
# HYPERPARAMETERS TO TUNE
conv_filters
=
hp
.
Int
(
"units_conv"
,
min_value
=
32
,
max_value
=
256
,
step
=
32
,
default
=
256
)
...
...
@@ -42,6 +42,16 @@ class SEQ_2_SEQ_AE(HyperModel):
)
encoding
=
hp
.
Int
(
"encoding"
,
min_value
=
16
,
max_value
=
64
,
step
=
8
,
default
=
24
)
return
conv_filters
,
lstm_units_1
,
dense_2
,
dropout_rate
,
encoding
def
build
(
self
,
hp
):
"""Overrides Hypermodel's build method"""
# HYPERPARAMETERS TO TUNE
conv_filters
,
lstm_units_1
,
dense_2
,
dropout_rate
,
encoding
=
self
.
get_hparams
(
hp
)
# INSTANCIATED MODEL
model
=
deepof
.
models
.
SEQ_2_SEQ_AE
(
architecture_hparams
=
{
...
...
@@ -92,10 +102,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
"ELBO"
in
self
.
loss
or
"MMD"
in
self
.
loss
),
"loss must be one of ELBO, MMD or ELBO+MMD (default)"
def
build
(
self
,
hp
):
"""Overrides Hypermodel's build method"""
@
staticmethod
def
get_hparams
(
hp
):
"""Retrieve hyperparameters to tune"""
# Hyperparameters to tune
conv_filters
=
hp
.
Int
(
"units_conv"
,
min_value
=
32
,
max_value
=
256
,
step
=
32
,
default
=
256
)
...
...
@@ -110,6 +120,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
)
encoding
=
hp
.
Int
(
"encoding"
,
min_value
=
16
,
max_value
=
64
,
step
=
8
,
default
=
24
)
return
conv_filters
,
lstm_units_1
,
dense_2
,
dropout_rate
,
encoding
def
build
(
self
,
hp
):
"""Overrides Hypermodel's build method"""
# Hyperparameters to tune
conv_filters
,
lstm_units_1
,
dense_2
,
dropout_rate
,
encoding
=
self
.
get_hparams
(
hp
)
gmvaep
,
kl_warmup_callback
,
mmd_warmup_callback
=
deepof
.
models
.
SEQ_2_SEQ_GMVAE
(
architecture_hparams
=
{
"units_conv"
:
conv_filters
,
...
...
deepof/requirements.txt
View file @
060349e8
...
...
@@ -2,6 +2,7 @@ dash~=1.11.0
hypothesis~=5.29.0
joblib~=0.16.0
keras-tuner~=1.0.1
kerastuner-tensorboard-logger~=0.2.3
matplotlib~=3.1.3
networkx~=2.4
numpy~=1.18.1
...
...
deepof/train_model.py
View file @
060349e8
...
...
@@ -13,7 +13,6 @@ from deepof.data import *
from
deepof.models
import
*
from
deepof.utils
import
*
from
train_utils
import
*
from
tensorboard.plugins.hparams
import
api
as
hp
from
tensorflow
import
keras
parser
=
argparse
.
ArgumentParser
(
...
...
deepof/train_utils.py
View file @
060349e8
...
...
@@ -10,6 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t
from
datetime
import
datetime
from
kerastuner
import
BayesianOptimization
from
kerastuner_tensorboard_logger
import
TensorBoardLogger
from
typing
import
Tuple
,
Union
,
Any
,
List
import
deepof.hypermodels
import
deepof.model_utils
...
...
@@ -176,6 +177,7 @@ def tune_search(
seed
=
42
,
directory
=
"BayesianOptx"
,
project_name
=
project_name
,
logger
=
TensorBoardLogger
(
metrics
=
[
"val_mae"
],
logdir
=
"./logs/hparams"
),
)
print
(
tuner
.
search_space_summary
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment