Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
9f0a9f31
Commit
9f0a9f31
authored
Mar 12, 2021
by
lucas_miranda
Browse files
Prototyped KNN_purity callback
parent
31a78518
Pipeline
#95526
canceled with stages
in 25 minutes and 22 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
9f0a9f31
...
...
@@ -195,7 +195,9 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
with
writer
.
as_default
():
tf
.
summary
.
scalar
(
"learning_rate"
,
data
=
self
.
model
.
optimizer
.
lr
,
step
=
epoch
,
"learning_rate"
,
data
=
self
.
model
.
optimizer
.
lr
,
step
=
epoch
,
)
...
...
@@ -206,11 +208,12 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
def
__init__
(
self
,
validation_data
=
None
,
k
=
100
,
samples
=
10000
):
def
__init__
(
self
,
validation_data
=
None
,
k
=
100
,
samples
=
10000
,
log_dir
=
"."
):
super
().
__init__
()
self
.
validation_data
=
validation_data
self
.
k
=
k
self
.
samples
=
samples
self
.
log_dir
=
log_dir
# noinspection PyMethodOverriding,PyTypeChecker
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
...
...
@@ -266,9 +269,13 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
*
np
.
max
(
groups
[
sample
])
)
tf
.
summary
.
scalar
(
"knn_cluster_purity"
,
data
=
purity_vector
.
mean
(),
step
=
epoch
,
)
writer
=
tf
.
summary
.
create_file_writer
(
self
.
log_dir
)
with
writer
.
as_default
():
tf
.
summary
.
scalar
(
"knn_cluster_purity"
,
data
=
purity_vector
.
mean
(),
step
=
epoch
,
)
class
uncorrelated_features_constraint
(
Constraint
):
...
...
deepof/train_utils.py
View file @
9f0a9f31
...
...
@@ -119,7 +119,7 @@ def get_callbacks(
onecycle
=
deepof
.
model_utils
.
one_cycle_scheduler
(
X_train
.
shape
[
0
]
//
batch_size
*
250
,
max_rate
=
0.005
,
log_dir
=
os
.
path
.
join
(
outpath
,
"metrics"
)
log_dir
=
os
.
path
.
join
(
outpath
,
"metrics"
)
,
)
callbacks
=
[
run_ID
,
tensorboard_callback
,
knn
,
onecycle
]
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment