Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
dd62bd2e
Commit
dd62bd2e
authored
Mar 12, 2021
by
lucas_miranda
Browse files
Implemented KNN_purity callback
parent
1cd9d1b6
Pipeline
#95528
canceled with stages
in 7 minutes and 35 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
dd62bd2e
...
...
@@ -209,8 +209,11 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
def
__init__
(
self
,
validation_data
=
None
,
k
=
100
,
samples
=
10000
,
log_dir
=
"."
):
def
__init__
(
self
,
variational
=
True
,
validation_data
=
None
,
k
=
100
,
samples
=
10000
,
log_dir
=
"."
):
super
().
__init__
()
self
.
variational
=
variational
self
.
validation_data
=
validation_data
self
.
k
=
k
self
.
samples
=
samples
...
...
@@ -220,7 +223,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
if
self
.
validation_data
is
not
None
:
if
self
.
validation_data
is
not
None
and
self
.
variational
:
# Get encoer and grouper from full model
cluster_means
=
[
...
...
deepof/models.py
View file @
dd62bd2e
...
...
@@ -253,7 +253,7 @@ class SEQ_2_SEQ_GMVAE:
montecarlo_kl
:
int
=
1
,
neuron_control
:
bool
=
False
,
number_of_components
:
int
=
1
,
overlap_loss
:
float
=
-
1.
,
overlap_loss
:
float
=
-
1.
0
,
phenotype_prediction
:
float
=
0.0
,
predictor
:
float
=
0.0
,
reg_cat_clusters
:
bool
=
False
,
...
...
@@ -606,7 +606,7 @@ class SEQ_2_SEQ_GMVAE:
z_gauss
=
deepof
.
model_utils
.
Cluster_overlap
(
self
.
ENCODING
,
self
.
number_of_components
,
loss
=
tf
.
maximum
(
0.
,
self
.
overlap_loss
).
numpy
(),
loss
=
tf
.
maximum
(
0.
0
,
self
.
overlap_loss
).
numpy
(),
)(
z_gauss
)
z
=
tfpl
.
DistributionLambda
(
...
...
deepof/train_utils.py
View file @
dd62bd2e
...
...
@@ -117,6 +117,7 @@ def get_callbacks(
samples
=
knn_samples
,
validation_data
=
X_val
,
log_dir
=
os
.
path
.
join
(
outpath
,
"metrics"
),
variational
=
variational
,
)
onecycle
=
deepof
.
model_utils
.
one_cycle_scheduler
(
...
...
tests/test_train_utils.py
View file @
dd62bd2e
...
...
@@ -117,6 +117,8 @@ def test_autoencoder_fitting(
phenotype_class
=
pheno_class
,
predictor
=
predictor
,
variational
=
variational
,
knn_neighbors
=
10
,
knn_samples
=
10
,
)
...
...
@@ -168,6 +170,8 @@ def test_tune_search(
True
,
True
,
None
,
knn_neighbors
=
10
,
knn_samples
=
10
,
)
)[
1
:]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment