Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
4362b142
Commit
4362b142
authored
Apr 14, 2021
by
lucas_miranda
Browse files
Added extra branch to main autoencoder for rule_based prediction
parent
870f15a2
Changes
2
Hide whitespace changes
Inline
Side-by-side
deepof/train_model.py
View file @
4362b142
...
...
@@ -172,17 +172,25 @@ parser.add_argument(
default
=
False
,
)
parser
.
add_argument
(
"--phenotype-classifier"
,
"-pheno"
,
"--next-sequence-prediction"
,
"-nspred"
,
help
=
"Activates the next sequence prediction branch of the variational Seq 2 Seq model with the specified weight. "
"Defaults to 0.0 (inactive)"
,
default
=
0.0
,
type
=
float
,
)
parser
.
add_argument
(
"--phenotype-prediction"
,
"-ppred"
,
help
=
"Activates the phenotype classification branch with the specified weight. Defaults to 0.0 (inactive)"
,
default
=
0.0
,
type
=
float
,
)
parser
.
add_argument
(
"--predict
or
"
,
"-pred"
,
help
=
"Activates the prediction branch of the variational Seq 2 Seq model
with the specified weight.
"
"
Defaults to 0.0 (inactive)"
,
"--
rule-based-
predict
ion
"
,
"-
rb
pred"
,
help
=
"Activates the
rule-based trait
prediction branch of the variational Seq 2 Seq model "
"with the specified weight
Defaults to 0.0 (inactive)"
,
default
=
0.0
,
type
=
float
,
)
...
...
@@ -246,8 +254,9 @@ mc_kl = args.montecarlo_kl
neuron_control
=
args
.
neuron_control
output_path
=
os
.
path
.
join
(
args
.
output_path
)
overlap_loss
=
args
.
overlap_loss
pheno_class
=
float
(
args
.
phenotype_classifier
)
predictor
=
float
(
args
.
predictor
)
next_sequence_prediction
=
float
(
args
.
next_sequence_prediction
)
phenotype_prediction
=
float
(
args
.
phenotype_prediction
)
rule_based_prediction
=
float
(
args
.
rule_based_prediction
)
smooth_alpha
=
args
.
smooth_alpha
train_path
=
os
.
path
.
abspath
(
args
.
train_path
)
tune
=
args
.
hyperparameter_tuning
...
...
@@ -282,8 +291,8 @@ logparam = {
"k"
:
k
,
"loss"
:
loss
,
}
if
pheno
_class
:
logparam
[
"pheno_weight"
]
=
pheno
_class
if
pheno
type_prediction
:
logparam
[
"pheno_weight"
]
=
pheno
type_prediction
# noinspection PyTypeChecker
project_coords
=
project
(
...
...
@@ -310,7 +319,7 @@ coords = project_coords.get_coords(
center
=
animal_id
+
undercond
+
"Center"
,
align
=
animal_id
+
undercond
+
"Spine_1"
,
align_inplace
=
True
,
propagate_labels
=
(
pheno
_class
>
0
),
propagate_labels
=
(
pheno
type_prediction
>
0
),
)
distances
=
project_coords
.
get_distances
()
angles
=
project_coords
.
get_angles
()
...
...
@@ -350,7 +359,7 @@ X_train, y_train, X_val, y_val = batch_preprocess(input_dict_train[input_type])
print
(
"Training set shape:"
,
X_train
.
shape
)
print
(
"Validation set shape:"
,
X_val
.
shape
)
if
pheno
_class
>
0
:
if
pheno
type_prediction
>
0
:
print
(
"Training set label shape:"
,
y_train
.
shape
)
print
(
"Validation set label shape:"
,
y_val
.
shape
)
...
...
@@ -373,8 +382,8 @@ if not tune:
montecarlo_kl
=
mc_kl
,
n_components
=
k
,
output_path
=
output_path
,
phenotype_prediction
=
pheno
_class
,
next_sequence_prediction
=
predict
or
,
phenotype_prediction
=
pheno
type_prediction
,
next_sequence_prediction
=
next_sequence_
predict
ion
,
save_checkpoints
=
False
,
save_weights
=
True
,
variational
=
variational
,
...
...
@@ -397,8 +406,8 @@ else:
variational
=
variational
,
entropy_samples
=
entropy_samples
,
entropy_knn
=
entropy_knn
,
phenotype_prediction
=
pheno
_class
,
next_sequence_prediction
=
predict
or
,
phenotype_prediction
=
pheno
type_prediction
,
next_sequence_prediction
=
next_sequence_
predict
ion
,
loss
=
loss
,
logparam
=
logparam
,
outpath
=
output_path
,
...
...
@@ -415,8 +424,8 @@ else:
loss
=
loss
,
mmd_warmup_epochs
=
mmd_wu
,
overlap_loss
=
overlap_loss
,
phenotype_prediction
=
pheno
_class
,
next_sequence_prediction
=
predict
or
,
phenotype_prediction
=
pheno
type_prediction
,
next_sequence_prediction
=
next_sequence_
predict
ion
,
project_name
=
"{}-based_{}_{}"
.
format
(
input_type
,
hyp
,
tune
.
capitalize
()),
callbacks
=
[
tensorboard_callback
,
...
...
deepof/train_utils.py
View file @
4362b142
...
...
@@ -256,7 +256,9 @@ def tensorboard_metric_logging(
if
phenotype_prediction
:
idx
=
next
(
idx_generator
)
pheno_acc
=
tf
.
keras
.
metrics
.
binary_accuracy
(
y_val
[
idx
],
tf
.
squeeze
(
outputs
[
idx
]))
pheno_acc
=
tf
.
keras
.
metrics
.
binary_accuracy
(
y_val
[
idx
],
tf
.
squeeze
(
outputs
[
idx
])
)
pheno_auc
=
tf
.
keras
.
metrics
.
AUC
()
pheno_auc
.
update_state
(
y_val
[
idx
],
outputs
[
idx
])
pheno_auc
=
pheno_auc
.
result
().
numpy
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment