Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
Lucas Miranda
deepOF
Commits
4ad26649
Commit
4ad26649
authored
Apr 13, 2021
by
lucas_miranda
Browse files
Added extra branch to main autoencoder for rule_based prediction
parent
1c00ad9d
Changes
3
Hide whitespace changes
Inline
Side-by-side
deepof/hypermodels.py
View file @
4ad26649
...
...
@@ -193,7 +193,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
next_sequence_prediction
=
self
.
next_sequence_prediction
,
phenotype_prediction
=
self
.
phenotype_prediction
,
rule_based_prediction
=
self
.
rule_based_prediction
,
rule_based_features
=
self
.
rule_based_features
rule_based_features
=
self
.
rule_based_features
,
).
build
(
self
.
input_shape
)[
-
3
]
return
gmvaep
deepof/train_utils.py
View file @
4ad26649
...
...
@@ -518,7 +518,7 @@ def tune_search(
if
hypermodel
==
"S2SAE"
:
# pragma: no cover
assert
(
next_sequence_prediction
==
0.0
and
phenotype_prediction
==
0.0
next_sequence_prediction
==
0.0
and
phenotype_prediction
==
0.0
),
"Prediction branches are only available for variational models. See documentation for more details"
batch_size
=
1
hypermodel
=
deepof
.
hypermodels
.
SEQ_2_SEQ_AE
(
input_shape
=
X_train
.
shape
)
...
...
tests/test_train_utils.py
View file @
4ad26649
...
...
@@ -82,20 +82,24 @@ def test_get_callbacks(
@
settings
(
max_examples
=
10
,
deadline
=
None
,
suppress_health_check
=
[
HealthCheck
.
too_slow
])
@
given
(
loss
=
st
.
one_of
(
st
.
just
(
"ELBO"
),
st
.
just
(
"MMD"
),
st
.
just
(
"ELBO+MMD"
)),
pheno_class
=
st
.
one_of
(
st
.
just
(
1.0
),
st
.
just
(
0.0
)),
predictor
=
st
.
one_of
(
st
.
just
(
1.0
),
st
.
just
(
0.0
)),
next_sequence_prediction
=
st
.
one_of
(
st
.
just
(
1.0
),
st
.
just
(
0.0
)),
phenotype_prediction
=
st
.
one_of
(
st
.
just
(
1.0
),
st
.
just
(
0.0
)),
rule_based_prediction
=
st
.
one_of
(
st
.
just
(
1.0
),
st
.
just
(
0.0
)),
variational
=
st
.
one_of
(
st
.
just
(
True
),
st
.
just
(
False
)),
)
def
test_autoencoder_fitting
(
loss
,
pheno_class
,
predictor
,
next_sequence_prediction
,
phenotype_prediction
,
rule_based_prediction
,
variational
,
):
X_train
=
np
.
random
.
uniform
(
-
1
,
1
,
[
20
,
5
,
6
])
y_train
=
np
.
round
(
np
.
random
.
uniform
(
0
,
1
,
20
))
y_train
=
np
.
round
(
np
.
random
.
uniform
(
0
,
1
,
[
20
,
1
]))
if
rule_based_prediction
:
y_train
=
np
.
concatenate
([
y_train
,
np
.
random
.
uniform
(
0
,
1
,
[
20
,
6
])],
axis
=
1
)
if
predict
or
:
if
next_sequence_
predict
ion
:
y_train
=
y_train
[
1
:]
preprocessed_data
=
(
X_train
,
y_train
,
X_train
,
y_train
)
...
...
@@ -118,8 +122,9 @@ def test_autoencoder_fitting(
mmd_warmup
=
1
,
n_components
=
2
,
loss
=
loss
,
phenotype_prediction
=
pheno_class
,
next_sequence_prediction
=
predictor
,
next_sequence_prediction
=
next_sequence_prediction
,
phenotype_prediction
=
phenotype_prediction
,
rule_based_prediction
=
rule_based_prediction
,
variational
=
variational
,
entropy_samples
=
10
,
entropy_knn
=
5
,
...
...
@@ -147,8 +152,9 @@ def test_autoencoder_fitting(
k
=
st
.
integers
(
min_value
=
1
,
max_value
=
10
),
loss
=
st
.
one_of
(
st
.
just
(
"ELBO"
),
st
.
just
(
"MMD"
)),
overlap_loss
=
st
.
floats
(
min_value
=
0.0
,
max_value
=
1.0
),
pheno_class
=
st
.
floats
(
min_value
=
0.0
,
max_value
=
1.0
),
predictor
=
st
.
floats
(
min_value
=
0.0
,
max_value
=
1.0
),
next_sequence_prediction
=
st
.
floats
(
min_value
=
0.0
,
max_value
=
1.0
),
phenotype_prediction
=
st
.
floats
(
min_value
=
0.0
,
max_value
=
1.0
),
rule_based_prediction
=
st
.
floats
(
min_value
=
0.0
,
max_value
=
1.0
),
)
def
test_tune_search
(
X_train
,
...
...
@@ -159,8 +165,9 @@ def test_tune_search(
k
,
loss
,
overlap_loss
,
pheno_class
,
predictor
,
next_sequence_prediction
,
phenotype_prediction
,
rule_based_prediction
,
):
callbacks
=
list
(
deepof
.
train_utils
.
get_callbacks
(
...
...
@@ -193,8 +200,9 @@ def test_tune_search(
loss
=
loss
,
mmd_warmup_epochs
=
0
,
overlap_loss
=
overlap_loss
,
phenotype_prediction
=
pheno_class
,
next_sequence_prediction
=
predictor
,
next_sequence_prediction
=
next_sequence_prediction
,
phenotype_prediction
=
phenotype_prediction
,
rule_based_prediction
=
rule_based_prediction
,
project_name
=
"test_run"
,
callbacks
=
callbacks
,
n_epochs
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment