Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
206a376d
Commit
206a376d
authored
Jul 31, 2020
by
lucas_miranda
Browse files
Minimise entropy to see if overal confidence increases in a reproducible way
parent
d96c8c4c
Changes
2
Hide whitespace changes
Inline
Side-by-side
source/model_utils.py
View file @
206a376d
...
@@ -169,9 +169,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
...
@@ -169,9 +169,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
def
get_config
(
self
):
def
get_config
(
self
):
config
=
super
().
get_config
().
copy
()
config
=
super
().
get_config
().
copy
()
config
.
update
(
config
.
update
(
{
{
"is_placeholder"
:
self
.
is_placeholder
,}
"is_placeholder"
:
self
.
is_placeholder
,
}
)
)
return
config
return
config
...
@@ -357,7 +355,7 @@ class Entropy_regulariser(Layer):
...
@@ -357,7 +355,7 @@ class Entropy_regulariser(Layer):
# axis=1 increases the entropy of a cluster across instances
# axis=1 increases the entropy of a cluster across instances
# axis=0 increases the entropy of the assignment for a given instance
# axis=0 increases the entropy of the assignment for a given instance
entropy
=
-
K
.
sum
(
tf
.
multiply
(
z
+
1e-5
,
tf
.
math
.
log
(
z
)
+
1e-5
),
axis
=
1
)
entropy
=
-
K
.
sum
(
tf
.
multiply
(
z
+
1e-5
,
tf
.
math
.
log
(
z
)
+
1e-5
),
axis
=
1
)
# Adds metric that monitors dead neurons in the latent space
# Adds metric that monitors dead neurons in the latent space
self
.
add_metric
(
entropy
,
aggregation
=
"mean"
,
name
=
"-weight_entropy"
)
self
.
add_metric
(
entropy
,
aggregation
=
"mean"
,
name
=
"-weight_entropy"
)
...
...
source/models.py
View file @
206a376d
...
@@ -171,7 +171,7 @@ class SEQ_2_SEQ_GMVAE:
...
@@ -171,7 +171,7 @@ class SEQ_2_SEQ_GMVAE:
number_of_components
=
1
,
number_of_components
=
1
,
predictor
=
True
,
predictor
=
True
,
overlap_loss
=
False
,
overlap_loss
=
False
,
entropy_reg_weight
=
0.25
,
entropy_reg_weight
=
1.0
,
):
):
self
.
input_shape
=
input_shape
self
.
input_shape
=
input_shape
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
...
@@ -315,7 +315,9 @@ class SEQ_2_SEQ_GMVAE:
...
@@ -315,7 +315,9 @@ class SEQ_2_SEQ_GMVAE:
encoder
=
BatchNormalization
()(
encoder
)
encoder
=
BatchNormalization
()(
encoder
)
encoding_shuffle
=
MCDropout
(
self
.
DROPOUT_RATE
)(
encoder
)
encoding_shuffle
=
MCDropout
(
self
.
DROPOUT_RATE
)(
encoder
)
z_cat
=
Dense
(
self
.
number_of_components
,
activation
=
"softmax"
,)(
encoding_shuffle
)
z_cat
=
Dense
(
self
.
number_of_components
,
activation
=
"softmax"
,)(
encoding_shuffle
)
z_cat
=
Entropy_regulariser
(
self
.
entropy_reg_weight
)(
z_cat
)
z_cat
=
Entropy_regulariser
(
self
.
entropy_reg_weight
)(
z_cat
)
z_gauss
=
Dense
(
z_gauss
=
Dense
(
tfpl
.
IndependentNormal
.
params_size
(
tfpl
.
IndependentNormal
.
params_size
(
...
@@ -468,7 +470,7 @@ class SEQ_2_SEQ_GMVAE:
...
@@ -468,7 +470,7 @@ class SEQ_2_SEQ_GMVAE:
gmvaep
.
compile
(
gmvaep
.
compile
(
loss
=
huber_loss
,
loss
=
huber_loss
,
optimizer
=
Nadam
(
lr
=
self
.
learn_rate
),
optimizer
=
Nadam
(
lr
=
self
.
learn_rate
,
clipvalue
=
0.5
,
),
metrics
=
[
"mae"
],
metrics
=
[
"mae"
],
loss_weights
=
([
1
,
self
.
predictor
]
if
self
.
predictor
>
0
else
[
1
]),
loss_weights
=
([
1
,
self
.
predictor
]
if
self
.
predictor
>
0
else
[
1
]),
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment