Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
cae7ee68
Commit
cae7ee68
authored
Sep 18, 2020
by
lucas_miranda
Browse files
Added tests for model_utils.py
parent
27b68380
Changes
3
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
cae7ee68
...
...
@@ -351,7 +351,7 @@ class MMDiscrepancyLayer(Layer):
return
z
class
Gaussian_mixture_overlap
(
Layer
):
class
Gaussian_mixture_overlap
(
Layer
):
# pragma: no cover
"""
Identity layer that measures the overlap between the components of the latent Gaussian Mixture
using a specified metric (MMD, Wasserstein, Fischer-Rao)
...
...
@@ -365,6 +365,8 @@ class Gaussian_mixture_overlap(Layer):
super
(
Gaussian_mixture_overlap
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
get_config
(
self
):
"""Updates Constraint metadata"""
config
=
super
().
get_config
().
copy
()
config
.
update
({
"lat_dims"
:
self
.
lat_dims
})
config
.
update
({
"n_components"
:
self
.
n_components
})
...
...
@@ -372,12 +374,14 @@ class Gaussian_mixture_overlap(Layer):
config
.
update
({
"samples"
:
self
.
samples
})
return
config
def
call
(
self
,
target
,
loss
=
False
):
@
tf
.
function
def
call
(
self
,
target
,
**
kwargs
):
"""Updates Layer's call method"""
dists
=
[]
for
k
in
range
(
self
.
n_components
):
locs
=
(
target
[...,
:
self
.
lat_dims
,
k
],)
scales
=
tf
.
keras
.
activations
.
softplus
(
target
[...,
self
.
lat_dims
:,
k
])
scales
=
tf
.
keras
.
activations
.
softplus
(
target
[...,
self
.
lat_dims
:,
k
])
dists
.
append
(
tfd
.
BatchReshape
(
tfd
.
MultivariateNormalDiag
(
locs
,
scales
),
[
-
1
])
...
...
@@ -385,7 +389,7 @@ class Gaussian_mixture_overlap(Layer):
dists
=
[
tf
.
transpose
(
gauss
.
sample
(
self
.
samples
),
[
1
,
0
,
2
])
for
gauss
in
dists
]
#
##
MMD-based overlap #
##
# MMD-based overlap #
intercomponent_mmd
=
K
.
mean
(
tf
.
convert_to_tensor
(
[
...
...
@@ -415,13 +419,15 @@ class Dead_neuron_control(Layer):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
Dead_neuron_control
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
call
(
self
,
z
,
z_gauss
,
z_cat
,
**
kwargs
):
# noinspection PyMethodOverriding
def
call
(
self
,
target
,
**
kwargs
):
"""Updates Layer's call method"""
# Adds metric that monitors dead neurons in the latent space
self
.
add_metric
(
tf
.
math
.
zero_fraction
(
z_gauss
),
aggregation
=
"mean"
,
name
=
"dead_neurons"
tf
.
math
.
zero_fraction
(
target
),
aggregation
=
"mean"
,
name
=
"dead_neurons"
)
return
z
return
target
class
Entropy_regulariser
(
Layer
):
...
...
@@ -429,18 +435,24 @@ class Entropy_regulariser(Layer):
Identity layer that adds cluster weight entropy to the loss function
"""
def
__init__
(
self
,
weight
=
1.0
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
weight
=
1.0
,
axis
=
1
,
*
args
,
**
kwargs
):
self
.
weight
=
weight
self
.
axis
=
axis
super
(
Entropy_regulariser
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
get_config
(
self
):
"""Updates Constraint metadata"""
config
=
super
().
get_config
().
copy
()
config
.
update
({
"weight"
:
self
.
weight
})
config
.
update
({
"axis"
:
self
.
axis
})
def
call
(
self
,
z
,
**
kwargs
):
"""Updates Layer's call method"""
# axis=1 increases the entropy of a cluster across instances
# axis=0 increases the entropy of the assignment for a given instance
entropy
=
K
.
sum
(
tf
.
multiply
(
z
+
1e-5
,
tf
.
math
.
log
(
z
)
+
1e-5
),
axis
=
1
)
entropy
=
K
.
sum
(
tf
.
multiply
(
z
+
1e-5
,
tf
.
math
.
log
(
z
)
+
1e-5
),
axis
=
self
.
axis
)
# Adds metric that monitors dead neurons in the latent space
self
.
add_metric
(
entropy
,
aggregation
=
"mean"
,
name
=
"-weight_entropy"
)
...
...
deepof/models.py
View file @
cae7ee68
...
...
@@ -335,6 +335,9 @@ class SEQ_2_SEQ_GMVAE:
z_gauss
=
Reshape
([
2
*
self
.
ENCODING
,
self
.
number_of_components
])(
z_gauss
)
# Identity layer controlling for dead neurons in the Gaussian Mixture posterior
z_gauss
=
Dead_neuron_control
()(
z_gauss
)
if
self
.
overlap_loss
:
z_gauss
=
Gaussian_mixture_overlap
(
self
.
ENCODING
,
self
.
number_of_components
,
loss
=
self
.
overlap_loss
,
...
...
@@ -387,9 +390,6 @@ class SEQ_2_SEQ_GMVAE:
batch_size
=
self
.
batch_size
,
prior
=
self
.
prior
,
beta
=
mmd_beta
)(
z
)
# Identity layer controlling clustering and latent space statistics
z
=
Dead_neuron_control
()(
z
,
z_gauss
,
z_cat
)
# Define and instantiate generator
generator
=
Model_D1
(
z
)
generator
=
Model_B1
(
generator
)
...
...
tests/test_model_utils.py
View file @
cae7ee68
...
...
@@ -188,7 +188,7 @@ def test_MMDiscrepancyLayer():
y
=
np
.
random
.
randint
(
0
,
2
,
[
1500
,
1
])
prior
=
tfd
.
Independent
(
tfd
.
Normal
(
loc
=
tf
.
zeros
(
10
),
scale
=
1
,
),
reinterpreted_batch_ndims
=
1
,
tfd
.
Normal
(
loc
=
tf
.
zeros
(
10
),
scale
=
1
,),
reinterpreted_batch_ndims
=
1
,
)
dense_1
=
tf
.
keras
.
layers
.
Dense
(
10
)
...
...
@@ -197,9 +197,10 @@ def test_MMDiscrepancyLayer():
d
=
dense_1
(
i
)
x
=
tfpl
.
DistributionLambda
(
lambda
dense
:
tfd
.
Independent
(
tfd
.
Normal
(
loc
=
dense
,
scale
=
1
,
),
reinterpreted_batch_ndims
=
1
,
tfd
.
Normal
(
loc
=
dense
,
scale
=
1
,),
reinterpreted_batch_ndims
=
1
,
)
)(
d
)
x
=
deepof
.
model_utils
.
MMDiscrepancyLayer
(
100
,
prior
,
beta
=
tf
.
keras
.
backend
.
variable
(
1.0
,
name
=
"kl_beta"
)
)(
x
)
...
...
@@ -213,21 +214,21 @@ def test_MMDiscrepancyLayer():
assert
type
(
fit
)
==
tf
.
python
.
keras
.
callbacks
.
History
#
#
# @settings(deadline=None
)
# @given()
# def
test_
gaussian_mixture_overlap
()
:
#
pass
#
#
# @settings(deadline=None)
# @given
()
# def test_dead_neuron_control():
# pass
#
#
# @settings(deadline=None)
# @given()
def
test_dead_neuron_control
():
X
=
np
.
random
.
uniform
(
0
,
10
,
[
1500
,
5
])
y
=
np
.
random
.
randint
(
0
,
2
,
[
1500
,
1
]
)
test_
model
=
tf
.
keras
.
Sequential
()
test_model
.
add
(
tf
.
keras
.
layers
.
Dense
(
1
))
test_model
.
add
(
deepof
.
model_utils
.
Dead_neuron_control
())
test_model
.
compile
(
loss
=
tf
.
keras
.
losses
.
binary_crossentropy
,
optimizer
=
tf
.
keras
.
optimizers
.
SGD
()
,
)
fit
=
test_model
.
fit
(
X
,
y
,
epochs
=
10
,
batch_size
=
100
)
assert
type
(
fit
)
==
tf
.
python
.
keras
.
callbacks
.
History
# def test_entropy_regulariser():
# pass
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment