Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
3393dcd1
Commit
3393dcd1
authored
Mar 12, 2021
by
lucas_miranda
Browse files
Prototyped KNN_purity callback
parent
9219111b
Changes
4
Hide whitespace changes
Inline
Side-by-side
deepof/data.py
View file @
3393dcd1
...
...
@@ -870,6 +870,8 @@ class coordinates:
variational
:
bool
=
True
,
reg_cat_clusters
:
bool
=
False
,
reg_cluster_variance
:
bool
=
False
,
knn_neighbors
:
int
=
100
,
knn_samples
:
int
=
10000
,
)
->
Tuple
:
"""
Annotates coordinates using an unsupervised autoencoder.
...
...
@@ -930,6 +932,8 @@ class coordinates:
variational
=
variational
,
reg_cat_clusters
=
reg_cat_clusters
,
reg_cluster_variance
=
reg_cluster_variance
,
knn_neighbors
=
knn_neighbors
,
knn_samples
=
knn_samples
,
)
# returns a list of trained tensorflow models
...
...
deepof/train_model.py
View file @
3393dcd1
...
...
@@ -108,6 +108,20 @@ parser.add_argument(
default
=
10
,
type
=
int
,
)
parser
.
add_argument
(
"--knn-neighbors"
,
"-knn"
,
help
=
"Neighbors to take into account to compute KNN cluster purity"
,
default
=
100
,
type
=
int
,
)
parser
.
add_argument
(
"--knn-samples"
,
"-knns"
,
help
=
"Samples to use to compute KNN cluster purity"
,
default
=
10000
,
type
=
int
,
)
parser
.
add_argument
(
"--latent-reg"
,
"-lreg"
,
...
...
@@ -226,6 +240,8 @@ hparams = args.hyperparameters if args.hyperparameters is not None else {}
input_type
=
args
.
input_type
k
=
args
.
components
kl_wu
=
args
.
kl_warmup
knn_neighbors
=
args
.
knn_neighbors
knn_samples
=
args
.
knn_samples
latent_reg
=
args
.
latent_reg
loss
=
args
.
loss
mmd_wu
=
args
.
mmd_warmup
...
...
@@ -367,6 +383,8 @@ if not tune:
variational
=
variational
,
reg_cat_clusters
=
(
"categorical"
in
latent_reg
),
reg_cluster_variance
=
(
"variance"
in
latent_reg
),
knn_neighbors
=
knn_neighbors
,
knn_samples
=
knn_samples
,
)
else
:
...
...
@@ -374,11 +392,13 @@ else:
hyp
=
"S2SGMVAE"
if
variational
else
"S2SAE"
run_ID
,
tensorboard_callback
,
onecycle
=
get_callbacks
(
run_ID
,
tensorboard_callback
,
knn
,
onecycle
=
get_callbacks
(
X_train
=
X_train
,
batch_size
=
batch_size
,
cp
=
False
,
variational
=
variational
,
knn_samples
=
knn_samples
,
knn_neighbors
=
knn_neighbors
,
phenotype_class
=
pheno_class
,
predictor
=
predictor
,
loss
=
loss
,
...
...
@@ -403,6 +423,7 @@ else:
callbacks
=
[
tensorboard_callback
,
onecycle
,
knn
,
CustomStopper
(
monitor
=
"val_loss"
,
patience
=
5
,
...
...
deepof/train_utils.py
View file @
3393dcd1
...
...
@@ -74,6 +74,8 @@ def get_callbacks(
cp
:
bool
=
False
,
reg_cat_clusters
:
bool
=
False
,
reg_cluster_variance
:
bool
=
False
,
knn_samples
:
int
=
10000
,
knn_neighbors
:
int
=
100
,
logparam
:
dict
=
None
,
outpath
:
str
=
"."
,
)
->
List
[
Union
[
Any
]]:
...
...
@@ -109,12 +111,17 @@ def get_callbacks(
profile_batch
=
2
,
)
knn
=
deepof
.
model_utils
.
knn_cluster_purity
(
k
=
knn_neighbors
,
samples
=
knn_samples
,
)
onecycle
=
deepof
.
model_utils
.
one_cycle_scheduler
(
X_train
.
shape
[
0
]
//
batch_size
*
250
,
max_rate
=
0.005
,
)
callbacks
=
[
run_ID
,
tensorboard_callback
,
onecycle
]
callbacks
=
[
run_ID
,
tensorboard_callback
,
knn
,
onecycle
]
if
cp
:
cp_callback
=
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
...
...
@@ -252,6 +259,8 @@ def autoencoder_fitting(
variational
:
bool
,
reg_cat_clusters
:
bool
,
reg_cluster_variance
:
bool
,
knn_neighbors
:
int
,
knn_samples
:
int
,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
...
...
@@ -279,6 +288,8 @@ def autoencoder_fitting(
phenotype_class
=
phenotype_class
,
predictor
=
predictor
,
loss
=
loss
,
knn_neighbors
=
knn_neighbors
,
knn_samples
=
knn_samples
,
reg_cat_clusters
=
reg_cat_clusters
,
reg_cluster_variance
=
reg_cluster_variance
,
logparam
=
logparam
,
...
...
tests/test_train_utils.py
View file @
3393dcd1
...
...
@@ -56,7 +56,7 @@ def test_get_callbacks(
pheno_class
,
loss
,
):
runID
,
tbc
,
cycle1c
,
cpc
=
deepof
.
train_utils
.
get_callbacks
(
runID
,
tbc
,
knn
,
cycle1c
,
cpc
=
deepof
.
train_utils
.
get_callbacks
(
X_train
,
batch_size
,
variational
,
...
...
@@ -71,6 +71,7 @@ def test_get_callbacks(
assert
type
(
runID
)
==
str
assert
type
(
tbc
)
==
tf
.
keras
.
callbacks
.
TensorBoard
assert
type
(
cpc
)
==
tf
.
keras
.
callbacks
.
ModelCheckpoint
assert
type
(
knn
)
==
deepof
.
model_utils
.
knn_cluster_purity
assert
type
(
cycle1c
)
==
deepof
.
model_utils
.
one_cycle_scheduler
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment