Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
9219111b
Commit
9219111b
authored
Mar 12, 2021
by
lucas_miranda
Browse files
Prototyped KNN_purity callback
parent
56de58c2
Changes
1
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
9219111b
...
...
@@ -10,7 +10,7 @@ Functions and general utilities for the deepof tensorflow models. See documentat
from
itertools
import
combinations
from
typing
import
Any
,
Tuple
from
sklearn.neighbors
import
NearestNeighbors
from
tensorflow.keras
import
backend
as
K
from
tensorflow.keras.constraints
import
Constraint
from
tensorflow.keras.layers
import
Layer
...
...
@@ -203,7 +203,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
def
__init__
(
self
,
k
=
5
,
samples
=
1000
):
def
__init__
(
self
,
k
=
5
,
samples
=
1000
0
):
super
().
__init__
()
self
.
k
=
k
self
.
samples
=
samples
...
...
@@ -227,16 +227,40 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
self
.
model
.
layers
[
0
].
input
,
cluster_assignment
.
output
)
trial_idxs
=
np
.
random
.
choice
(
range
(
self
.
validation_data
.
shape
[
0
]),
self
.
samples
# Use encoder and grouper to predict on validation data
encoding
=
encoder
.
predict
(
self
.
validation_data
)
groups
=
grouper
.
predict
(
self
.
validation_data
)
# Multiply encodings by groups, to get a weighted version of the matrix
encoding
=
(
encoding
*
tf
.
tile
(
groups
,
[
1
,
encoding
.
shape
[
1
]
//
groups
.
shape
[
1
]]).
numpy
()
)
trial_data
=
self
.
validation_data
[
trial_idxs
]
hard_groups
=
groups
.
argmax
(
axis
=
1
)
# Use encoder and grouper to predict on validation data
encoding
=
encoder
.
predict
(
trial_data
)
groups
=
grouper
.
predict
(
trial_data
)
# Fit KNN model
knn
=
NearestNeighbors
().
fit
(
encoding
)
#
# Iterate over samples and compute purity over k neighbours
random_idxs
=
np
.
random
.
choice
(
range
(
encoding
.
shape
[
0
]),
self
.
samples
,
replace
=
False
)
purity_vector
=
np
.
zeros
(
self
.
samples
)
for
i
,
sample
in
enumerate
(
random_idxs
):
indexes
=
knn
.
kneighbors
(
encoding
[
sample
][
np
.
newaxis
,
:],
self
.
k
,
return_distance
=
False
)
purity_vector
[
i
]
=
(
np
.
sum
(
hard_groups
[
indexes
]
==
hard_groups
[
sample
])
/
self
.
k
*
np
.
max
(
groups
[
sample
])
)
self
.
add_metric
(
self
.
purity_vector
,
aggregation
=
"mean"
,
name
=
"knn_cluster_purity"
,
)
class
uncorrelated_features_constraint
(
Constraint
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment