deepOF / Commits

Commit b5d3c717, authored Jul 20, 2020 by lucas_miranda

Added maximum entropy regulariser to the gaussian mixture weight layer

Parent: 9e646727
Changes: 3 files
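Background on the term this commit adds: for a softmax weight vector z over K mixture components, the Shannon entropy H(z) = -Σ_k z_k log z_k is maximal (log K) when the weights are uniform and zero when a single component takes all the mass. The new layer in source/model_utils.py computes the z·log z terms with a tf.where mask so that zero weights contribute 0 rather than 0 · (-inf). A tiny standalone illustration of that quantity (not part of the commit, names hypothetical):

import numpy as np

def weight_entropy(z):
    """Shannon entropy of (a batch of) mixture-weight vectors, masking log(0)."""
    z = np.asarray(z, dtype=float)
    log_z = np.where(z > 0, np.log(z), 0.0)  # mirrors the tf.where(~is_inf(log z)) mask in the diff
    return -np.sum(z * log_z, axis=-1)

print(weight_entropy([0.25, 0.25, 0.25, 0.25]))  # log(4) ≈ 1.386, maximal for K = 4
print(weight_entropy([1.0, 0.0, 0.0, 0.0]))      # -0.0, i.e. zero entropy: one component takes all the mass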
model_training.py
@@ -232,7 +232,7 @@ DLC_social_1_coords = DLC_social_1.run(verbose=True)
 DLC_social_2_coords = DLC_social_2.run(verbose=True)

 # Coordinates for training data
-coords1 = DLC_social_1_coords.get_coords(center="B_Center")
+coords1 = DLC_social_1_coords.get_coords(center="B_Center", polar=True)
 distances1 = DLC_social_1_coords.get_distances()
 angles1 = DLC_social_1_coords.get_angles()
 coords_distances1 = merge_tables(coords1, distances1)
@@ -241,7 +241,7 @@ dists_angles1 = merge_tables(distances1, angles1)
 coords_dist_angles1 = merge_tables(coords1, distances1, angles1)

 # Coordinates for validation data
-coords2 = DLC_social_2_coords.get_coords(center="B_Center")
+coords2 = DLC_social_2_coords.get_coords(center="B_Center", polar=True)
 distances2 = DLC_social_2_coords.get_distances()
 angles2 = DLC_social_2_coords.get_angles()
 coords_distances2 = merge_tables(coords2, distances2)
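Both hunks in this file touch the polar flag of the get_coords call, for the training and the validation datasets respectively. polar=True presumably re-expresses each centred (x, y) body-part position as a (radius, angle) pair; a rough standalone sketch of that transformation, with hypothetical column names and data (not deepOF code):

import numpy as np
import pandas as pd

# Hypothetical centred (x, y) coordinates for two body parts
xy = pd.DataFrame(
    {"Nose_x": [1.0, 0.5], "Nose_y": [0.0, 0.5], "Tail_x": [-1.0, 0.0], "Tail_y": [0.0, -2.0]}
)

def to_polar(df):
    """Convert paired *_x / *_y columns to (rho, phi) columns."""
    parts = sorted({c[:-2] for c in df.columns})
    out = {}
    for part in parts:
        x, y = df[f"{part}_x"], df[f"{part}_y"]
        out[f"{part}_rho"] = np.sqrt(x ** 2 + y ** 2)  # distance from the centring point
        out[f"{part}_phi"] = np.arctan2(y, x)          # angle in radians
    return pd.DataFrame(out)

print(to_polar(xy))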
source/model_utils.py
@@ -268,3 +268,30 @@ class Latent_space_control(Layer):
         self.add_loss(-K.mean(silhouette), inputs=[z, hard_labels])

         return z
+
+
+class Entropy_regulariser(Layer):
+    """
+    Identity layer that adds cluster weight entropy to the loss function
+    """
+
+    def __init__(self, weight=False, *args, **kwargs):
+        self.weight = weight
+        super(Entropy_regulariser, self).__init__(*args, **kwargs)
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({"weight": self.weight})
+
+    def call(self, z, **kwargs):
+        entropy = K.sum(
+            tf.multiply(z, tf.where(~tf.math.is_inf(K.log(z)), K.log(z), 0)), axis=0
+        )
+
+        # Adds metric that monitors dead neurons in the latent space
+        self.add_metric(-entropy, aggregation="mean", name="weight_entropy")
+
+        self.add_loss(-K.mean(entropy), inputs=[z])
+
+        return z
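For readers who want to experiment with this pattern in isolation, below is a minimal, self-contained sketch of an identity layer that registers an entropy-based penalty and a monitoring metric via add_loss/add_metric, in the same spirit as the Entropy_regulariser added above. It assumes TensorFlow 2.x with the Keras 2 API (as used by the repository at the time); the weight scaling, the sign convention and the return statement in get_config are illustrative choices of this sketch, not part of the commit:

import tensorflow as tf
import tensorflow.keras.backend as K


class EntropyPenalty(tf.keras.layers.Layer):
    """Identity layer: passes z through unchanged and attaches an entropy term to the loss.

    Illustrative sketch only; the committed implementation is Entropy_regulariser
    in source/model_utils.py.
    """

    def __init__(self, weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.weight = weight

    def get_config(self):
        config = super().get_config().copy()
        config.update({"weight": self.weight})
        return config  # returning the dict keeps the layer serialisable

    def call(self, z, **kwargs):
        # Mask log(0) = -inf so that empty components contribute zero, as in the diff above
        log_z = K.log(z)
        safe_log = tf.where(tf.math.is_inf(log_z), tf.zeros_like(z), log_z)
        neg_entropy = K.sum(z * safe_log, axis=-1)  # equals -H(z) for each sample
        self.add_metric(-neg_entropy, aggregation="mean", name="weight_entropy")
        self.add_loss(self.weight * K.mean(neg_entropy))  # negative term: rewards spread-out weights
        return z


# Usage sketch: wrap the softmax component-weight head of a toy model
inputs = tf.keras.Input(shape=(8,))
weights_head = tf.keras.layers.Dense(4, activation="softmax")(inputs)
weights_head = EntropyPenalty(weight=0.1)(weights_head)
model = tf.keras.Model(inputs, weights_head)
print(len(model.losses))  # 1: the entropy penalty is registered with the model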
source/models.py
@@ -173,6 +173,7 @@ class SEQ_2_SEQ_GMVAE:
         number_of_components=1,
         predictor=True,
         overlap_loss=False,
+        entropy_reg_weight=1,
     ):
         self.input_shape = input_shape
         self.batch_size = batch_size
@@ -191,6 +192,7 @@ class SEQ_2_SEQ_GMVAE:
         self.number_of_components = number_of_components
         self.predictor = predictor
         self.overlap_loss = overlap_loss
+        self.entropy_reg_weight = entropy_reg_weight

         if self.prior == "standard_normal":
@@ -302,6 +304,7 @@ class SEQ_2_SEQ_GMVAE:
         encoder = BatchNormalization()(encoder)

         z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
+        z_cat = Entropy_regulariser(self.entropy_reg_weight)(z_cat)
         z_gauss = Dense(
             tfpl.IndependentNormal.params_size(self.ENCODING * self.number_of_components
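As a point of comparison only (not how the commit does it), a similar penalty on the component weights can also be expressed in stock Keras as an activity_regularizer on the softmax Dense layer itself. The callable below returns weight * Σ z·log z, which is ≤ 0 and most negative for uniform weights, so adding it to the loss favours high-entropy assignments; flipping the sign would instead favour confident, one-hot-like weights. Names and values here are hypothetical:

import tensorflow as tf
import tensorflow.keras.backend as K

def entropy_activity_regularizer(weight=1.0):
    """Hypothetical alternative to the identity-layer approach used in the commit."""
    def _penalty(z):
        log_z = K.log(z)
        safe_log = tf.where(tf.math.is_inf(log_z), tf.zeros_like(z), log_z)
        # sum(z * log z) <= 0, and is most negative for uniform component weights
        return weight * K.sum(z * safe_log)
    return _penalty

inputs = tf.keras.Input(shape=(16,))
z_cat = tf.keras.layers.Dense(
    4,
    activation="softmax",
    activity_regularizer=entropy_activity_regularizer(0.1),
)(inputs)
model = tf.keras.Model(inputs, z_cat)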