Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
78d8baef
Commit
78d8baef
authored
Mar 12, 2021
by
lucas_miranda
Browse files
Prototyped KNN_purity callback
parent
a13aac38
Changes
1
Show whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
78d8baef
...
@@ -186,10 +186,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
...
@@ -186,10 +186,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self
.
iteration
+=
1
self
.
iteration
+=
1
K
.
set_value
(
self
.
model
.
optimizer
.
lr
,
rate
)
K
.
set_value
(
self
.
model
.
optimizer
.
lr
,
rate
)
def
on_batch_end
(
self
,
epoch
,
logs
=
None
):
logs
[
"learning_rate"
]
=
self
.
last_rate
"""Add current learning rate as a metric, to check whether scheduling is working properly"""
return
self
.
last_rate
class
knn_cluster_purity
(
tf
.
keras
.
callbacks
.
Callback
):
class
knn_cluster_purity
(
tf
.
keras
.
callbacks
.
Callback
):
...
@@ -208,12 +205,16 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
...
@@ -208,12 +205,16 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
def
on_epoch_end
(
self
,
batch
:
int
,
logs
):
def
on_epoch_end
(
self
,
batch
:
int
,
logs
):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
if
self
.
validation_data
is
not
None
:
# Get encoer and grouper from full model
# Get encoer and grouper from full model
cluster_means
=
[
cluster_means
=
[
layer
for
layer
in
self
.
model
.
layers
if
layer
.
name
==
"cluster_means"
layer
for
layer
in
self
.
model
.
layers
if
layer
.
name
==
"cluster_means"
][
0
]
][
0
]
cluster_assignment
=
[
cluster_assignment
=
[
layer
for
layer
in
self
.
model
.
layers
if
layer
.
name
==
"cluster_assignment"
layer
for
layer
in
self
.
model
.
layers
if
layer
.
name
==
"cluster_assignment"
][
0
]
][
0
]
encoder
=
tf
.
keras
.
models
.
Model
(
encoder
=
tf
.
keras
.
models
.
Model
(
...
@@ -223,6 +224,8 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
...
@@ -223,6 +224,8 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
self
.
model
.
layers
[
0
].
input
,
cluster_assignment
.
output
self
.
model
.
layers
[
0
].
input
,
cluster_assignment
.
output
)
)
print
(
self
.
validation_data
)
# Use encoder and grouper to predict on validation data
# Use encoder and grouper to predict on validation data
encoding
=
encoder
.
predict
(
self
.
validation_data
)
encoding
=
encoder
.
predict
(
self
.
validation_data
)
groups
=
grouper
.
predict
(
self
.
validation_data
)
groups
=
grouper
.
predict
(
self
.
validation_data
)
...
@@ -252,7 +255,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
...
@@ -252,7 +255,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
*
np
.
max
(
groups
[
sample
])
*
np
.
max
(
groups
[
sample
])
)
)
return
purity_vector
.
mean
()
logs
[
"knn_cluster_purity"
]
=
purity_vector
.
mean
()
class
uncorrelated_features_constraint
(
Constraint
):
class
uncorrelated_features_constraint
(
Constraint
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment