Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
31a78518
Commit
31a78518
authored
Mar 12, 2021
by
lucas_miranda
Browse files
Prototyped KNN_purity callback
parent
78d8baef
Changes
2
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
31a78518
...
...
@@ -144,6 +144,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
start_rate
:
float
=
None
,
last_iterations
:
int
=
None
,
last_rate
:
float
=
None
,
log_dir
:
str
=
"."
,
):
super
().
__init__
()
self
.
iterations
=
iterations
...
...
@@ -154,6 +155,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self
.
last_rate
=
last_rate
or
self
.
start_rate
/
1000
self
.
iteration
=
0
self
.
history
=
{}
self
.
log_dir
=
log_dir
def
_interpolate
(
self
,
iter1
:
int
,
iter2
:
int
,
rate1
:
float
,
rate2
:
float
)
->
float
:
return
(
rate2
-
rate1
)
*
(
self
.
iteration
-
iter1
)
/
(
iter2
-
iter1
)
+
rate1
...
...
@@ -186,7 +188,15 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self
.
iteration
+=
1
K
.
set_value
(
self
.
model
.
optimizer
.
lr
,
rate
)
logs
[
"learning_rate"
]
=
self
.
last_rate
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
"""Logs the learning rate to tensorboard"""
writer
=
tf
.
summary
.
create_file_writer
(
self
.
log_dir
)
with
writer
.
as_default
():
tf
.
summary
.
scalar
(
"learning_rate"
,
data
=
self
.
model
.
optimizer
.
lr
,
step
=
epoch
,
)
class
knn_cluster_purity
(
tf
.
keras
.
callbacks
.
Callback
):
...
...
@@ -196,13 +206,14 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
def
__init__
(
self
,
k
=
5
,
samples
=
10000
):
def
__init__
(
self
,
validation_data
=
None
,
k
=
100
,
samples
=
10000
):
super
().
__init__
()
self
.
validation_data
=
validation_data
self
.
k
=
k
self
.
samples
=
samples
# noinspection PyMethodOverriding,PyTypeChecker
def
on_epoch_end
(
self
,
batch
:
int
,
logs
):
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
if
self
.
validation_data
is
not
None
:
...
...
@@ -255,7 +266,9 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
*
np
.
max
(
groups
[
sample
])
)
logs
[
"knn_cluster_purity"
]
=
purity_vector
.
mean
()
tf
.
summary
.
scalar
(
"knn_cluster_purity"
,
data
=
purity_vector
.
mean
(),
step
=
epoch
,
)
class
uncorrelated_features_constraint
(
Constraint
):
...
...
deepof/train_utils.py
View file @
31a78518
...
...
@@ -119,6 +119,7 @@ def get_callbacks(
onecycle
=
deepof
.
model_utils
.
one_cycle_scheduler
(
X_train
.
shape
[
0
]
//
batch_size
*
250
,
max_rate
=
0.005
,
log_dir
=
os
.
path
.
join
(
outpath
,
"metrics"
)
)
callbacks
=
[
run_ID
,
tensorboard_callback
,
knn
,
onecycle
]
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment