Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
d207cd1d
Commit
d207cd1d
authored
Nov 16, 2020
by
lucas_miranda
Browse files
Updated train_model.py to be compatible with phenotype classification
parent
81cd01a5
Changes
2
Hide whitespace changes
Inline
Side-by-side
deepof/train_model.py
View file @
d207cd1d
...
...
@@ -411,8 +411,7 @@ else:
)
best_hyperparameters
,
best_model
=
tune_search
(
X_train
,
X_val
,
data
=
[
X_train
,
y_train
,
X_val
,
y_val
],
bayopt_trials
=
bayopt_trials
,
hypermodel
=
hyp
,
k
=
k
,
...
...
deepof/train_utils.py
View file @
d207cd1d
...
...
@@ -97,8 +97,7 @@ def get_callbacks(
def
tune_search
(
train
:
np
.
array
,
test
:
np
.
array
,
data
:
List
[
np
.
array
],
bayopt_trials
:
int
,
hypermodel
:
str
,
k
:
int
,
...
...
@@ -139,12 +138,14 @@ def tune_search(
"""
X_train
,
y_train
,
X_val
,
y_val
=
data
if
hypermodel
==
"S2SAE"
:
# pragma: no cover
hypermodel
=
deepof
.
hypermodels
.
SEQ_2_SEQ_AE
(
input_shape
=
train
.
shape
)
hypermodel
=
deepof
.
hypermodels
.
SEQ_2_SEQ_AE
(
input_shape
=
X_
train
.
shape
)
elif
hypermodel
==
"S2SGMVAE"
:
hypermodel
=
deepof
.
hypermodels
.
SEQ_2_SEQ_GMVAE
(
input_shape
=
train
.
shape
,
input_shape
=
X_
train
.
shape
,
loss
=
loss
,
number_of_components
=
k
,
overlap_loss
=
overlap_loss
,
...
...
@@ -168,13 +169,22 @@ def tune_search(
print
(
tuner
.
search_space_summary
())
Xs
,
ys
=
[
X_train
],
[
X_train
]
Xvals
,
yvals
=
[
X_val
],
[
X_val
]
if
predictor
>
0.0
:
Xs
,
ys
=
X_train
[:
-
1
],
[
X_train
[:
-
1
],
X_train
[
1
:]]
Xvals
,
yvals
=
X_val
[:
-
1
],
[
X_val
[:
-
1
],
X_val
[
1
:]]
if
pheno_class
>
0.0
:
ys
+=
[
y_train
]
yvals
+=
[
y_val
]
tuner
.
search
(
train
if
predictor
==
0
else
[
train
[:
-
1
]]
,
train
if
predictor
==
0
else
[
train
[:
-
1
],
train
[
1
:]]
,
Xs
,
ys
,
epochs
=
n_epochs
,
validation_data
=
(
(
test
,
test
)
if
predictor
==
0
else
(
test
[:
-
1
],
[
test
[:
-
1
],
test
[
1
:]])
),
validation_data
=
(
Xvals
,
yvals
),
verbose
=
1
,
batch_size
=
256
,
callbacks
=
callbacks
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment