Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
3b845054
Commit
3b845054
authored
Nov 18, 2020
by
lucas_miranda
Browse files
Changed default hyperparameter values
parent
e240953e
Pipeline
#87073
failed with stage
in 13 minutes and 55 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
deepof/train_model.py
View file @
3b845054
...
...
@@ -308,8 +308,8 @@ if not tune:
# To avoid stability issues
tf
.
keras
.
backend
.
clear_session
()
run_ID
,
tensorboard_callback
,
cp_callback
,
onecycle
=
get_callbacks
(
X_train
,
batch_size
,
variational
,
predictor
,
loss
,
run_ID
,
tensorboard_callback
,
onecycle
,
cp_callback
=
get_callbacks
(
X_train
,
batch_size
,
True
,
variational
,
predictor
,
loss
,
)
if
not
variational
:
...
...
@@ -406,8 +406,8 @@ else:
hyp
=
"S2SGMVAE"
if
variational
else
"S2SAE"
run_ID
,
tensorboard_callback
,
cp_callback
,
onecycle
=
get_callbacks
(
X_train
,
batch_size
,
variational
,
predictor
,
loss
run_ID
,
tensorboard_callback
,
onecycle
=
get_callbacks
(
X_train
,
batch_size
,
False
,
variational
,
predictor
,
loss
)
best_hyperparameters
,
best_model
=
tune_search
(
...
...
deepof/train_utils.py
View file @
3b845054
...
...
@@ -61,7 +61,7 @@ def load_treatments(train_path):
def
get_callbacks
(
X_train
:
np
.
array
,
batch_size
:
int
,
variational
:
bool
,
predictor
:
float
,
loss
:
str
,
X_train
:
np
.
array
,
batch_size
:
int
,
cp
:
bool
,
variational
:
bool
,
predictor
:
float
,
loss
:
str
,
)
->
Tuple
:
"""Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details;
...
...
@@ -81,19 +81,23 @@ def get_callbacks(
log_dir
=
log_dir
,
histogram_freq
=
1
,
profile_batch
=
2
,
)
cp_callback
=
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
"./logs/checkpoints/"
+
run_ID
+
"/cp-{epoch:04d}.ckpt"
,
verbose
=
1
,
save_best_only
=
False
,
save_weights_only
=
True
,
save_freq
=
"epoch"
,
)
onecycle
=
deepof
.
model_utils
.
one_cycle_scheduler
(
X_train
.
shape
[
0
]
//
batch_size
*
250
,
max_rate
=
0.005
,
)
return
run_ID
,
tensorboard_callback
,
cp_callback
,
onecycle
callbacks
=
[
run_ID
,
tensorboard_callback
,
onecycle
]
if
cp
:
cp_callback
=
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
"./logs/checkpoints/"
+
run_ID
+
"/cp-{epoch:04d}.ckpt"
,
verbose
=
1
,
save_best_only
=
False
,
save_weights_only
=
True
,
save_freq
=
"epoch"
,
)
callbacks
.
append
(
cp_callback
)
return
callbacks
def
tune_search
(
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment