Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
abc4fd8f
Commit
abc4fd8f
authored
May 18, 2021
by
lucas_miranda
Browse files
Replaced for loop with vectorised mapping on ClusterOverlap regularization layer
parent
d4374203
Changes
4
Hide whitespace changes
Inline
Side-by-side
deepof/data.py
View file @
abc4fd8f
...
@@ -893,6 +893,7 @@ class coordinates:
...
@@ -893,6 +893,7 @@ class coordinates:
mmd_warmup
:
int
=
0
,
mmd_warmup
:
int
=
0
,
montecarlo_kl
:
int
=
10
,
montecarlo_kl
:
int
=
10
,
n_components
:
int
=
25
,
n_components
:
int
=
25
,
overlap_loss
:
float
=
0
,
output_path
:
str
=
"."
,
output_path
:
str
=
"."
,
next_sequence_prediction
:
float
=
0
,
next_sequence_prediction
:
float
=
0
,
phenotype_prediction
:
float
=
0
,
phenotype_prediction
:
float
=
0
,
...
@@ -958,6 +959,7 @@ class coordinates:
...
@@ -958,6 +959,7 @@ class coordinates:
mmd_warmup
=
mmd_warmup
,
mmd_warmup
=
mmd_warmup
,
montecarlo_kl
=
montecarlo_kl
,
montecarlo_kl
=
montecarlo_kl
,
n_components
=
n_components
,
n_components
=
n_components
,
overlap_loss
=
overlap_loss
,
output_path
=
output_path
,
output_path
=
output_path
,
next_sequence_prediction
=
next_sequence_prediction
,
next_sequence_prediction
=
next_sequence_prediction
,
phenotype_prediction
=
phenotype_prediction
,
phenotype_prediction
=
phenotype_prediction
,
...
...
deepof/model_utils.py
View file @
abc4fd8f
...
@@ -295,10 +295,12 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
...
@@ -295,10 +295,12 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
@
tf
.
function
@
tf
.
function
def
get_local_neighbourhood_entropy
(
index
):
def
get_local_neighbourhood_entropy
(
index
):
return
get_neighbourhood_entropy
(
return
get_neighbourhood_entropy
(
index
,
tensor
=
encoding
s
,
clusters
=
hard_groups
,
k
=
self
.
k
index
,
tensor
=
encoding
,
clusters
=
hard_groups
,
k
=
self
.
k
)
)
purity_vector
=
tf
.
map_fn
(
get_local_neighbourhood_entropy
,
random_idxs
)
purity_vector
=
tf
.
map_fn
(
get_local_neighbourhood_entropy
,
random_idxs
,
dtype
=
tf
.
dtypes
.
float32
)
writer
=
tf
.
summary
.
create_file_writer
(
self
.
log_dir
)
writer
=
tf
.
summary
.
create_file_writer
(
self
.
log_dir
)
with
writer
.
as_default
():
with
writer
.
as_default
():
...
@@ -558,7 +560,7 @@ class ClusterOverlap(Layer):
...
@@ -558,7 +560,7 @@ class ClusterOverlap(Layer):
self
,
self
,
encoding_dim
:
int
,
encoding_dim
:
int
,
k
:
int
=
100
,
k
:
int
=
100
,
loss_weight
:
float
=
False
,
loss_weight
:
float
=
0.0
,
samples
:
int
=
512
,
samples
:
int
=
512
,
*
args
,
*
args
,
**
kwargs
**
kwargs
...
@@ -581,19 +583,19 @@ class ClusterOverlap(Layer):
...
@@ -581,19 +583,19 @@ class ClusterOverlap(Layer):
config
.
update
({
"samples"
:
self
.
samples
})
config
.
update
({
"samples"
:
self
.
samples
})
return
config
return
config
@
tf
.
function
#
@tf.function
def
call
(
self
,
encodings
,
categorical
,
**
kwargs
):
def
call
(
self
,
inputs
,
**
kwargs
):
"""Updates Layer's call method"""
"""Updates Layer's call method"""
encodings
,
categorical
=
inputs
[
0
],
inputs
[
1
]
hard_groups
=
tf
.
math
.
argmax
(
categorical
,
axis
=
1
)
hard_groups
=
tf
.
math
.
argmax
(
categorical
,
axis
=
1
)
max_groups
=
tf
.
reduce_max
(
categorical
,
axis
=
1
)
max_groups
=
tf
.
reduce_max
(
categorical
,
axis
=
1
)
# Iterate over samples and compute purity across neighbourhood
# Iterate over samples and compute purity across neighbourhood
self
.
samples
=
tf
.
reduce_min
([
self
.
samples
,
encodings
.
shape
[
0
]])
self
.
samples
=
tf
.
reduce_min
([
self
.
samples
,
tf
.
shape
(
encodings
)[
0
]])
random_idxs
=
range
(
encoding
.
shape
[
0
])
random_idxs
=
range
(
encodings
.
shape
[
0
])
random_idxs
=
tf
.
random
.
categorical
(
random_idxs
=
np
.
random
.
choice
(
random_idxs
,
self
.
samples
)
tf
.
expand_dims
(
random_idxs
/
tf
.
reduce_sum
(
random_idxs
),
0
),
self
.
samples
)
@
tf
.
function
@
tf
.
function
def
get_local_neighbourhood_entropy
(
index
):
def
get_local_neighbourhood_entropy
(
index
):
...
@@ -601,7 +603,9 @@ class ClusterOverlap(Layer):
...
@@ -601,7 +603,9 @@ class ClusterOverlap(Layer):
index
,
tensor
=
encodings
,
clusters
=
hard_groups
,
k
=
self
.
k
index
,
tensor
=
encodings
,
clusters
=
hard_groups
,
k
=
self
.
k
)
)
purity_vector
=
tf
.
map_fn
(
get_local_neighbourhood_entropy
,
random_idxs
)
purity_vector
=
tf
.
map_fn
(
get_local_neighbourhood_entropy
,
tf
.
constant
(
random_idxs
),
dtype
=
tf
.
dtypes
.
float32
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy
=
purity_vector
*
max_groups
[
random_idxs
]
neighbourhood_entropy
=
purity_vector
*
max_groups
[
random_idxs
]
...
@@ -623,6 +627,8 @@ class ClusterOverlap(Layer):
...
@@ -623,6 +627,8 @@ class ClusterOverlap(Layer):
)
)
if
self
.
loss_weight
:
if
self
.
loss_weight
:
self
.
add_loss
(
neighbourhood_entropy
,
inputs
=
[
target
,
categorical
])
self
.
add_loss
(
self
.
loss_weight
*
neighbourhood_entropy
,
inputs
=
[
target
,
categorical
]
)
return
encodings
return
encodings
deepof/train_model.py
View file @
abc4fd8f
...
@@ -176,9 +176,9 @@ parser.add_argument(
...
@@ -176,9 +176,9 @@ parser.add_argument(
parser
.
add_argument
(
parser
.
add_argument
(
"--overlap-loss"
,
"--overlap-loss"
,
"-ol"
,
"-ol"
,
help
=
"If
True
, adds
the negative MMD between all components of the latent Gaussian mixture to the loss function
"
,
help
=
"If
> 0
, adds
a regularization term controlling for local cluster assignment entropy in the latent space
"
,
type
=
deepof
.
utils
.
str2bool
,
type
=
float
,
default
=
False
,
default
=
0
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--next-sequence-prediction"
,
"--next-sequence-prediction"
,
...
@@ -263,7 +263,7 @@ mmd_annealing_mode = args.mmd_annealing_mode
...
@@ -263,7 +263,7 @@ mmd_annealing_mode = args.mmd_annealing_mode
mmd_wu
=
args
.
mmd_warmup
mmd_wu
=
args
.
mmd_warmup
mc_kl
=
args
.
montecarlo_kl
mc_kl
=
args
.
montecarlo_kl
output_path
=
os
.
path
.
join
(
args
.
output_path
)
output_path
=
os
.
path
.
join
(
args
.
output_path
)
overlap_loss
=
args
.
overlap_loss
overlap_loss
=
float
(
args
.
overlap_loss
)
next_sequence_prediction
=
float
(
args
.
next_sequence_prediction
)
next_sequence_prediction
=
float
(
args
.
next_sequence_prediction
)
phenotype_prediction
=
float
(
args
.
phenotype_prediction
)
phenotype_prediction
=
float
(
args
.
phenotype_prediction
)
rule_based_prediction
=
float
(
args
.
rule_based_prediction
)
rule_based_prediction
=
float
(
args
.
rule_based_prediction
)
...
@@ -397,6 +397,7 @@ if not tune:
...
@@ -397,6 +397,7 @@ if not tune:
montecarlo_kl
=
mc_kl
,
montecarlo_kl
=
mc_kl
,
n_components
=
k
,
n_components
=
k
,
output_path
=
output_path
,
output_path
=
output_path
,
overlap_loss
=
overlap_loss
,
next_sequence_prediction
=
next_sequence_prediction
,
next_sequence_prediction
=
next_sequence_prediction
,
phenotype_prediction
=
phenotype_prediction
,
phenotype_prediction
=
phenotype_prediction
,
rule_based_prediction
=
rule_based_prediction
,
rule_based_prediction
=
rule_based_prediction
,
...
...
deepof/train_utils.py
View file @
abc4fd8f
...
@@ -291,6 +291,7 @@ def autoencoder_fitting(
...
@@ -291,6 +291,7 @@ def autoencoder_fitting(
montecarlo_kl
:
int
,
montecarlo_kl
:
int
,
n_components
:
int
,
n_components
:
int
,
output_path
:
str
,
output_path
:
str
,
overlap_loss
:
float
,
next_sequence_prediction
:
float
,
next_sequence_prediction
:
float
,
phenotype_prediction
:
float
,
phenotype_prediction
:
float
,
rule_based_prediction
:
float
,
rule_based_prediction
:
float
,
...
@@ -394,7 +395,7 @@ def autoencoder_fitting(
...
@@ -394,7 +395,7 @@ def autoencoder_fitting(
mmd_warmup_epochs
=
mmd_warmup
,
mmd_warmup_epochs
=
mmd_warmup
,
montecarlo_kl
=
montecarlo_kl
,
montecarlo_kl
=
montecarlo_kl
,
number_of_components
=
n_components
,
number_of_components
=
n_components
,
overlap_loss
=
False
,
overlap_loss
=
overlap_loss
,
next_sequence_prediction
=
next_sequence_prediction
,
next_sequence_prediction
=
next_sequence_prediction
,
phenotype_prediction
=
phenotype_prediction
,
phenotype_prediction
=
phenotype_prediction
,
rule_based_prediction
=
rule_based_prediction
,
rule_based_prediction
=
rule_based_prediction
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment