Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
9e6df3ca
Commit
9e6df3ca
authored
Apr 13, 2021
by
lucas_miranda
Browse files
Added extra branch to main autoencoder for rule_based prediction
parent
1e5ce83e
Changes
5
Hide whitespace changes
Inline
Side-by-side
deepof/hypermodels.py
View file @
9e6df3ca
...
...
@@ -102,8 +102,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
mmd_warmup_epochs
:
int
=
0
,
number_of_components
:
int
=
10
,
overlap_loss
:
float
=
False
,
phenotype_predictor
:
float
=
0.0
,
predictor
:
float
=
0.0
,
next_sequence_prediction
:
float
=
0.0
,
phenotype_prediction
:
float
=
0.0
,
rule_based_prediction
:
float
=
0.0
,
prior
:
str
=
"standard_normal"
,
):
super
().
__init__
()
...
...
@@ -115,8 +116,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
self
.
mmd_warmup_epochs
=
mmd_warmup_epochs
self
.
number_of_components
=
number_of_components
self
.
overlap_loss
=
overlap_loss
self
.
pheno_class
=
phenotype_predictor
self
.
predictor
=
predictor
self
.
next_sequence_prediction
=
next_sequence_prediction
self
.
phenotype_prediction
=
phenotype_prediction
self
.
rule_based_prediction
=
rule_based_prediction
self
.
prior
=
prior
assert
(
...
...
@@ -186,8 +188,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
mmd_warmup_epochs
=
self
.
mmd_warmup_epochs
,
number_of_components
=
k
,
overlap_loss
=
self
.
overlap_loss
,
phenotype_prediction
=
self
.
pheno_class
,
next_sequence_prediction
=
self
.
predictor
,
next_sequence_prediction
=
self
.
next_sequence_prediction
,
phenotype_prediction
=
self
.
phenotype_prediction
,
rule_based_prediction
=
self
.
rule_based_prediction
,
).
build
(
self
.
input_shape
)[
-
3
]
return
gmvaep
deepof/train_model.py
View file @
9e6df3ca
...
...
@@ -415,8 +415,8 @@ else:
loss
=
loss
,
mmd_warmup_epochs
=
mmd_wu
,
overlap_loss
=
overlap_loss
,
phenotype_
class
=
pheno_class
,
predict
or
=
predictor
,
phenotype_
prediction
=
pheno_class
,
next_sequence_
predict
ion
=
predictor
,
project_name
=
"{}-based_{}_{}"
.
format
(
input_type
,
hyp
,
tune
.
capitalize
()),
callbacks
=
[
tensorboard_callback
,
...
...
deepof/train_utils.py
View file @
9e6df3ca
...
...
@@ -472,8 +472,9 @@ def tune_search(
loss
:
str
,
mmd_warmup_epochs
:
int
,
overlap_loss
:
float
,
phenotype_class
:
float
,
predictor
:
float
,
next_sequence_prediction
:
float
,
phenotype_prediction
:
float
,
rule_based_prediction
:
float
,
project_name
:
str
,
callbacks
:
List
,
n_epochs
:
int
=
30
,
...
...
@@ -517,7 +518,7 @@ def tune_search(
if
hypermodel
==
"S2SAE"
:
# pragma: no cover
assert
(
predict
or
==
0.0
and
phenotype_
class
==
0.0
next_sequence_
predict
ion
==
0.0
and
phenotype_
prediction
==
0.0
),
"Prediction branches are only available for variational models. See documentation for more details"
batch_size
=
1
hypermodel
=
deepof
.
hypermodels
.
SEQ_2_SEQ_AE
(
input_shape
=
X_train
.
shape
)
...
...
@@ -532,8 +533,9 @@ def tune_search(
mmd_warmup_epochs
=
mmd_warmup_epochs
,
number_of_components
=
k
,
overlap_loss
=
overlap_loss
,
phenotype_predictor
=
phenotype_class
,
predictor
=
predictor
,
next_sequence_prediction
=
next_sequence_prediction
,
phenotype_prediction
=
phenotype_prediction
,
rule_based_prediction
=
rule_based_prediction
,
)
else
:
...
...
@@ -574,11 +576,19 @@ def tune_search(
Xs
,
ys
=
[
X_train
],
[
X_train
]
Xvals
,
yvals
=
[
X_val
],
[
X_val
]
if
predict
or
>
0.0
:
if
next_sequence_
predict
ion
>
0.0
:
Xs
,
ys
=
X_train
[:
-
1
],
[
X_train
[:
-
1
],
X_train
[
1
:]]
Xvals
,
yvals
=
X_val
[:
-
1
],
[
X_val
[:
-
1
],
X_val
[
1
:]]
if
phenotype_class
>
0.0
:
if
phenotype_prediction
>
0.0
:
ys
+=
[
y_train
[:,
0
]]
yvals
+=
[
y_val
[:,
0
]]
# Remove the used column (phenotype) from both y arrays
y_train
=
y_train
[:,
1
:]
y_val
=
y_val
[:,
1
:]
if
rule_based_prediction
>
0.0
:
ys
+=
[
y_train
]
yvals
+=
[
y_val
]
...
...
tests/test_build_hypermodels.py
View file @
9e6df3ca
...
...
@@ -51,5 +51,5 @@ def test_SEQ_2_SEQ_GMVAE_hypermodel_build(
),
loss
=
loss
,
number_of_components
=
number_of_components
,
predict
or
=
True
,
next_sequence_
predict
ion
=
True
,
).
build
(
hp
=
HyperParameters
())
tests/test_train_utils.py
View file @
9e6df3ca
...
...
@@ -193,8 +193,8 @@ def test_tune_search(
loss
=
loss
,
mmd_warmup_epochs
=
0
,
overlap_loss
=
overlap_loss
,
phenotype_
class
=
pheno_class
,
predict
or
=
predictor
,
phenotype_
prediction
=
pheno_class
,
next_sequence_
predict
ion
=
predictor
,
project_name
=
"test_run"
,
callbacks
=
callbacks
,
n_epochs
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment