Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
78d8baef
Commit
78d8baef
authored
Mar 12, 2021
by
lucas_miranda
Browse files
Prototyped KNN_purity callback
parent
a13aac38
Changes
1
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
78d8baef
...
...
@@ -186,10 +186,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self
.
iteration
+=
1
K
.
set_value
(
self
.
model
.
optimizer
.
lr
,
rate
)
def
on_batch_end
(
self
,
epoch
,
logs
=
None
):
"""Add current learning rate as a metric, to check whether scheduling is working properly"""
return
self
.
last_rate
logs
[
"learning_rate"
]
=
self
.
last_rate
class
knn_cluster_purity
(
tf
.
keras
.
callbacks
.
Callback
):
...
...
@@ -208,51 +205,57 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
def
on_epoch_end
(
self
,
batch
:
int
,
logs
):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
# Get encoer and grouper from full model
cluster_means
=
[
layer
for
layer
in
self
.
model
.
layers
if
layer
.
name
==
"cluster_means"
][
0
]
cluster_assignment
=
[
layer
for
layer
in
self
.
model
.
layers
if
layer
.
name
==
"cluster_assignment"
][
0
]
if
self
.
validation_data
is
not
None
:
encoder
=
tf
.
keras
.
models
.
Model
(
self
.
model
.
layers
[
0
].
input
,
cluster_means
.
output
)
grouper
=
tf
.
keras
.
models
.
Model
(
self
.
model
.
layers
[
0
].
input
,
cluster_assignment
.
output
)
# Get encoer and grouper from full model
cluster_means
=
[
layer
for
layer
in
self
.
model
.
layers
if
layer
.
name
==
"cluster_means"
][
0
]
cluster_assignment
=
[
layer
for
layer
in
self
.
model
.
layers
if
layer
.
name
==
"cluster_assignment"
][
0
]
# Use encoder and grouper to predict on validation data
encoding
=
encoder
.
predict
(
self
.
validation_data
)
groups
=
grouper
.
predict
(
self
.
validation_data
)
encoder
=
tf
.
keras
.
models
.
Model
(
self
.
model
.
layers
[
0
].
input
,
cluster_means
.
output
)
grouper
=
tf
.
keras
.
models
.
Model
(
self
.
model
.
layers
[
0
].
input
,
cluster_assignment
.
output
)
# Multiply encodings by groups, to get a weighted version of the matrix
encoding
=
(
encoding
*
tf
.
tile
(
groups
,
[
1
,
encoding
.
shape
[
1
]
//
groups
.
shape
[
1
]]).
numpy
()
)
hard_groups
=
groups
.
argmax
(
axis
=
1
)
print
(
self
.
validation_data
)
# Fit KNN model
knn
=
NearestNeighbors
().
fit
(
encoding
)
# Use encoder and grouper to predict on validation data
encoding
=
encoder
.
predict
(
self
.
validation_data
)
groups
=
grouper
.
predict
(
self
.
validation_data
)
# Iterate over samples and compute purity over k neighbours
random_idxs
=
np
.
random
.
choice
(
range
(
encoding
.
shape
[
0
]),
self
.
samples
,
replace
=
False
)
purity_vector
=
np
.
zeros
(
self
.
samples
)
for
i
,
sample
in
enumerate
(
random_idxs
):
indexes
=
knn
.
kneighbors
(
encoding
[
sample
][
np
.
newaxis
,
:],
self
.
k
,
return_distance
=
False
# Multiply encodings by groups, to get a weighted version of the matrix
encoding
=
(
encoding
*
tf
.
tile
(
groups
,
[
1
,
encoding
.
shape
[
1
]
//
groups
.
shape
[
1
]]).
numpy
()
)
purity_vector
[
i
]
=
(
np
.
sum
(
hard_groups
[
indexes
]
==
hard_groups
[
sample
])
/
self
.
k
*
np
.
max
(
groups
[
sample
])
hard_groups
=
groups
.
argmax
(
axis
=
1
)
# Fit KNN model
knn
=
NearestNeighbors
().
fit
(
encoding
)
# Iterate over samples and compute purity over k neighbours
random_idxs
=
np
.
random
.
choice
(
range
(
encoding
.
shape
[
0
]),
self
.
samples
,
replace
=
False
)
purity_vector
=
np
.
zeros
(
self
.
samples
)
for
i
,
sample
in
enumerate
(
random_idxs
):
indexes
=
knn
.
kneighbors
(
encoding
[
sample
][
np
.
newaxis
,
:],
self
.
k
,
return_distance
=
False
)
purity_vector
[
i
]
=
(
np
.
sum
(
hard_groups
[
indexes
]
==
hard_groups
[
sample
])
/
self
.
k
*
np
.
max
(
groups
[
sample
])
)
return
purity_vector
.
mean
()
logs
[
"knn_cluster_purity"
]
=
purity_vector
.
mean
()
class
uncorrelated_features_constraint
(
Constraint
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment