Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
9a805f04
Commit
9a805f04
authored
May 18, 2021
by
lucas_miranda
Browse files
Replaced for loop with vectorised mapping on ClusterOverlap regularization layer
parent
5b68125c
Changes
2
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
9a805f04
...
...
@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@
tf
.
function
def
get_k_nearest_neighbors
(
tensor
,
k
,
index
):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query
=
tensor
[
index
]
query
=
tf
.
gather
(
tensor
,
index
)
distances
=
tf
.
norm
(
tensor
-
query
,
axis
=
1
)
max_distance
=
tf
.
sort
(
distances
)[
k
]
neighbourhood_mask
=
distances
<
max_distance
...
...
@@ -558,13 +558,15 @@ class ClusterOverlap(Layer):
def
__init__
(
self
,
batch_size
:
int
,
encoding_dim
:
int
,
k
:
int
=
100
,
loss_weight
:
float
=
0.0
,
samples
:
int
=
5
12
,
samples
:
int
=
5
0
,
*
args
,
**
kwargs
):
self
.
batch_size
=
batch_size
self
.
enc
=
encoding_dim
self
.
k
=
k
self
.
loss_weight
=
loss_weight
...
...
@@ -576,6 +578,7 @@ class ClusterOverlap(Layer):
"""Updates Constraint metadata"""
config
=
super
().
get_config
().
copy
()
config
.
update
({
"batch_size"
:
self
.
batch_size
})
config
.
update
({
"enc"
:
self
.
enc
})
config
.
update
({
"k"
:
self
.
k
})
config
.
update
({
"loss_weight"
:
self
.
loss_weight
})
...
...
@@ -583,7 +586,6 @@ class ClusterOverlap(Layer):
config
.
update
({
"samples"
:
self
.
samples
})
return
config
@
tf
.
function
def
call
(
self
,
inputs
,
**
kwargs
):
"""Updates Layer's call method"""
...
...
@@ -593,46 +595,42 @@ class ClusterOverlap(Layer):
max_groups
=
tf
.
reduce_max
(
categorical
,
axis
=
1
)
# Iterate over samples and compute purity across neighbourhood
self
.
samples
=
tf
.
reduce_min
([
self
.
samples
,
tf
.
shape
(
encodings
)[
0
]])
random_idxs
=
tf
.
range
(
tf
.
shape
(
encodings
)[
0
])
random_idxs
=
tf
.
random
.
categorical
(
tf
.
expand_dims
(
random_idxs
/
tf
.
reduce_sum
(
random_idxs
),
0
),
self
.
samples
,
dtype
=
tf
.
dtypes
.
int32
,
)
self
.
samples
=
np
.
min
([
self
.
samples
,
self
.
batch_size
])
# convert to numpy
random_idxs
=
range
(
self
.
batch_size
)
# convert to batch size
random_idxs
=
np
.
random
.
choice
(
random_idxs
,
self
.
samples
)
@
tf
.
function
def
get_local_neighbourhood_entropy
(
index
):
return
get_neighbourhood_entropy
(
index
,
tensor
=
encodings
,
clusters
=
hard_groups
,
k
=
self
.
k
)
get_local_neighbourhood_entropy
=
partial
(
get_neighbourhood_entropy
,
tensor
=
encodings
,
clusters
=
hard_groups
,
k
=
self
.
k
)
purity_vector
=
tf
.
map_fn
(
get_local_neighbourhood_entropy
,
random_idxs
,
dtype
=
tf
.
dtypes
.
float32
get_local_neighbourhood_entropy
,
tf
.
constant
(
random_idxs
),
dtype
=
tf
.
dtypes
.
float32
,
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy
=
purity_vector
*
max_groups
[
random_idxs
]
self
.
add_metric
(
len
(
set
(
hard_groups
[
max_groups
>=
self
.
min_confidence
])),
aggregation
=
"mean"
,
name
=
"number_of_populated_clusters"
,
)
self
.
add_metric
(
max_groups
,
aggregation
=
"mean"
,
name
=
"average_confidence_in_selected_cluster"
,
)
neighbourhood_entropy
=
purity_vector
#
* max_groups[random_idxs]
#
self.add_metric(
#
len(set(hard_groups[max_groups >= self.min_confidence])),
#
aggregation="mean",
#
name="number_of_populated_clusters",
#
)
#
#
self.add_metric(
#
max_groups,
#
aggregation="mean",
#
name="average_confidence_in_selected_cluster",
#
)
self
.
add_metric
(
neighbourhood_entropy
,
aggregation
=
"mean"
,
name
=
"neighbourhood_entropy"
)
if
self
.
loss_weight
:
self
.
add_loss
(
self
.
loss_weight
*
neighbourhood_entropy
,
inputs
=
[
target
,
categorical
]
)
#
if self.loss_weight:
#
self.add_loss(
#
self.loss_weight * neighbourhood_entropy, inputs=
inputs
#
)
return
encodings
deepof/models.py
View file @
9a805f04
...
...
@@ -475,6 +475,7 @@ class GMVAE:
if
self
.
overlap_loss
:
z
=
deepof
.
model_utils
.
ClusterOverlap
(
self
.
batch_size
,
self
.
ENCODING
,
self
.
number_of_components
,
loss_weight
=
self
.
overlap_loss
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment