Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
acc06bfb
Commit
acc06bfb
authored
May 18, 2021
by
lucas_miranda
Browse files
Replaced for loop with vectorised mapping on ClusterOverlap regularization layer
parent
76a818fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
acc06bfb
...
...
@@ -15,6 +15,7 @@ import matplotlib.pyplot as plt
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow_probability
as
tfp
from
functools
import
partial
from
tensorflow.keras
import
backend
as
K
from
tensorflow.keras.constraints
import
Constraint
from
tensorflow.keras.layers
import
Layer
...
...
@@ -46,7 +47,7 @@ def get_k_nearest_neighbors(tensor, k, index):
@
tf
.
function
def
get_neighbourhood_entropy
(
tensor
,
clusters
,
k
,
index
):
def
get_neighbourhood_entropy
(
index
,
tensor
,
clusters
,
k
):
neighborhood
=
get_k_nearest_neighbors
(
tensor
,
k
,
index
)
cluster_z
=
tf
.
gather
(
clusters
,
neighborhood
)
neigh_entropy
=
compute_shannon_entropy
(
cluster_z
)
...
...
@@ -291,13 +292,14 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
range
(
encoding
.
shape
[
0
]),
self
.
samples
,
replace
=
False
)
# Add result to pre allocated array
purity_vector
=
np
.
zeros
(
self
.
samples
)
for
i
,
sample
in
enumerate
(
random_idxs
):
purity_vector
[
i
]
=
get_neighbourhood_entropy
(
encodings
,
hard_groups
,
self
.
k
,
sample
)
get_local_neighbourhood_entropy
=
partial
(
get_neighbourhood_entropy
,
tensor
=
encodings
,
clusters
=
hard_groups
,
k
=
self
.
k
,
dtype
=
tf
.
dtypes
.
float32
,
)
purity_vector
=
tf
.
map_fn
(
get_local_neighbourhood_entropy
,
random_idxs
)
writer
=
tf
.
summary
.
create_file_writer
(
self
.
log_dir
)
with
writer
.
as_default
():
...
...
@@ -594,13 +596,16 @@ class ClusterOverlap(Layer):
tf
.
expand_dims
(
random_idxs
/
tf
.
reduce_sum
(
random_idxs
),
0
),
self
.
samples
)
purity_vector
=
tf
.
map_fn
(
get_neighbourhood_entropy
,
random_idxs
)
for
i
,
sample
in
enumerate
(
random_idxs
):
purity_vector
[
i
]
=
get_neighbourhood_entropy
(
encodings
,
hard_groups
,
self
.
k
,
sample
)
get_local_neighbourhood_entropy
=
partial
(
get_neighbourhood_entropy
,
tensor
=
encodings
,
clusters
=
hard_groups
,
k
=
self
.
k
,
dtype
=
tf
.
dtypes
.
float32
,
)
purity_vector
=
tf
.
map_fn
(
get_local_neighbourhood_entropy
,
random_idxs
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy
=
purity_vector
*
max_groups
[
random_idxs
]
self
.
add_metric
(
...
...
deepof/models.py
View file @
acc06bfb
...
...
@@ -426,9 +426,7 @@ class GMVAE:
tfd
.
Independent
(
tfd
.
Normal
(
loc
=
gauss
[
1
][...,
:
self
.
ENCODING
,
k
],
scale
=
1e-3
+
softplus
(
gauss
[
1
][...,
self
.
ENCODING
:,
k
])
+
1e-5
,
scale
=
1e-3
+
softplus
(
gauss
[
1
][...,
self
.
ENCODING
:,
k
]),
),
reinterpreted_batch_ndims
=
1
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment