Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
2d21ca43
Commit
2d21ca43
authored
Feb 10, 2021
by
lucas_miranda
Browse files
Added latent space regularization options to GMVAE
parent
331b7014
Pipeline
#93201
passed with stage
in 54 minutes and 28 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
deepof/models.py
View file @
2d21ca43
...
...
@@ -258,6 +258,8 @@ class SEQ_2_SEQ_GMVAE:
overlap_loss
:
float
=
False
,
phenotype_prediction
:
float
=
0.0
,
predictor
:
float
=
0.0
,
reg_cat_clusters
:
bool
=
False
,
reg_cluster_variance
:
bool
=
False
,
):
self
.
hparams
=
self
.
get_hparams
(
architecture_hparams
)
self
.
batch_size
=
batch_size
...
...
@@ -288,6 +290,8 @@ class SEQ_2_SEQ_GMVAE:
self
.
phenotype_prediction
=
phenotype_prediction
self
.
predictor
=
predictor
self
.
prior
=
"standard_normal"
self
.
reg_cat_clusters
=
reg_cat_clusters
self
.
reg_cluster_variance
=
reg_cluster_variance
assert
(
"ELBO"
in
self
.
loss
or
"MMD"
in
self
.
loss
...
...
@@ -564,7 +568,11 @@ class SEQ_2_SEQ_GMVAE:
z_cat
=
Dense
(
self
.
number_of_components
,
activation
=
"softmax"
,
kernel_regularizer
=
tf
.
keras
.
regularizers
.
l1_l2
(
l1
=
0.01
,
l2
=
0.01
),
kernel_regularizer
=
(
tf
.
keras
.
regularizers
.
l1_l2
(
l1
=
0.01
,
l2
=
0.01
)
if
self
.
reg_cat_clusters
else
None
),
)(
encoder
)
if
self
.
entropy_reg_weight
>
0
:
...
...
@@ -572,14 +580,26 @@ class SEQ_2_SEQ_GMVAE:
z_cat
)
z_gauss
=
Dense
(
z_gauss
_mean
=
Dense
(
tfpl
.
IndependentNormal
.
params_size
(
self
.
ENCODING
*
self
.
number_of_components
),
)
//
2
,
activation
=
None
,
)(
encoder
)
z_gauss_var
=
Dense
(
tfpl
.
IndependentNormal
.
params_size
(
self
.
ENCODING
*
self
.
number_of_components
)
//
2
,
activation
=
None
,
)(
encoder
)
# REMOVE BIAS FROM HERE! WHAT's THE POINT?
activity_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
0.01
)
if
self
.
reg_cluster_variance
else
None
),
)(
encoder
)
z_gauss
=
tf
.
keras
.
layers
.
concatenate
([
z_gauss_mean
,
z_gauss_var
],
axis
=
1
)
z_gauss
=
Reshape
([
2
*
self
.
ENCODING
,
self
.
number_of_components
])(
z_gauss
)
...
...
deepof/pose_utils.py
View file @
2d21ca43
...
...
@@ -219,6 +219,42 @@ def huddle(
return
hudd
def
dig
(
pos_dframe
:
pd
.
DataFrame
,
speed_dframe
:
pd
.
DataFrame
,
likelihood_dframe
:
pd
.
DataFrame
,
tol_nose_speed
:
float
,
tol_speed
:
float
,
tol_likelihood
:
float
,
animal_id
:
str
=
""
,
):
pass
def
sniff
(
pos_dframe
:
pd
.
DataFrame
,
speed_dframe
:
pd
.
DataFrame
,
likelihood_dframe
:
pd
.
DataFrame
,
tol_nose_speed
:
float
,
tol_speed
:
float
,
tol_likelihood
:
float
,
animal_id
:
str
=
""
,
):
pass
def
look_around
(
pos_dframe
:
pd
.
DataFrame
,
speed_dframe
:
pd
.
DataFrame
,
likelihood_dframe
:
pd
.
DataFrame
,
tol_nose_speed
:
float
,
tol_speed
:
float
,
tol_likelihood
:
float
,
animal_id
:
str
=
""
,
):
pass
def
following_path
(
distance_dframe
:
pd
.
DataFrame
,
position_dframe
:
pd
.
DataFrame
,
...
...
deepof/visuals.py
View file @
2d21ca43
...
...
@@ -48,9 +48,18 @@ def plot_heatmap(
for
i
,
bpart
in
enumerate
(
bodyparts
):
heatmap
=
dframe
[
bpart
]
if
len
(
bodyparts
)
>
1
:
sns
.
kdeplot
(
data
=
heatmap
.
x
,
data2
=
heatmap
.
y
,
cmap
=
None
,
shade
=
True
,
alpha
=
1
,
ax
=
ax
[
i
])
sns
.
kdeplot
(
data
=
heatmap
.
x
,
data2
=
heatmap
.
y
,
cmap
=
None
,
shade
=
True
,
alpha
=
1
,
ax
=
ax
[
i
],
)
else
:
sns
.
kdeplot
(
data
=
heatmap
.
x
,
data2
=
heatmap
.
y
,
cmap
=
None
,
shade
=
True
,
alpha
=
1
,
ax
=
ax
)
sns
.
kdeplot
(
data
=
heatmap
.
x
,
data2
=
heatmap
.
y
,
cmap
=
None
,
shade
=
True
,
alpha
=
1
,
ax
=
ax
)
ax
=
np
.
array
([
ax
])
[
x
.
set_xlim
(
xlim
)
for
x
in
ax
]
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment