Lucas Miranda / deepOF · Commits

Commit 61bb9c12, authored Feb 12, 2021 by lucas_miranda

Added latent regularization control to deepof.data.coordinates.deep_unsupervised_embedding()

Parent: c7bb409c
Pipeline #93312: failed, in 45 minutes and 11 seconds
Changes: 2 · Pipelines: 1
deepof/train_model.py

@@ -287,7 +287,7 @@ project_coords = project(
     animal_ids=tuple([animal_id]),
     arena="circular",
     arena_dims=tuple([arena_dims]),
-    enable_iterative_imputation=True,
+    enable_iterative_imputation=False,
     exclude_bodyparts=exclude_bodyparts,
     exp_conditions=treatment_dict,
     path=train_path,
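For context, enable_iterative_imputation toggles model-based filling of missing body-part coordinates before any embedding is trained. The deepOF internals are not shown in this diff, so the following is only a minimal sketch of the general technique, assuming an implementation along the lines of scikit-learn's IterativeImputer:

import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa: F401  (activates the estimator)
from sklearn.impute import IterativeImputer

# Hypothetical tracking matrix: frames x (body parts * 2D coordinates), with gaps as NaN.
coords = np.array(
    [
        [10.0, 12.0, 30.0, 31.0],
        [11.0, np.nan, 30.5, 31.2],
        [12.0, 13.1, np.nan, 31.5],
        [13.0, 13.6, 31.4, 31.9],
    ]
)

# Each missing value is regressed on the remaining columns over several rounds.
imputed = IterativeImputer(max_iter=10, random_state=0).fit_transform(coords)
print(imputed.shape)  # (4, 4), with no NaNs remaining

With the flag switched to False in this commit, such gap filling is skipped and the raw (possibly incomplete) coordinates are used downstream.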
@@ -359,6 +359,7 @@ if not tune:
     trained_models = project_coords.deep_unsupervised_embedding(
         (X_train, y_train, X_val, y_val),
         epochs=1,
         batch_size=batch_size,
         encoding_size=encoding_size,
         hparams=hparams,
...
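The one-line addition to this call is collapsed in the view above. Judging by the commit message and the parameters threaded through train_utils.py below, the call presumably gains latent-regularization keyword arguments; a hedged sketch of what such a call could look like (the reg_cat_clusters / reg_cluster_variance keywords on deep_unsupervised_embedding are an assumption here, not shown in the visible hunk):

# Sketch only: the two keyword names mirror the new get_callbacks() parameters and are
# assumed to be forwarded by deep_unsupervised_embedding(); not confirmed by this hunk.
trained_models = project_coords.deep_unsupervised_embedding(
    (X_train, y_train, X_val, y_val),
    epochs=1,
    batch_size=batch_size,
    encoding_size=encoding_size,
    hparams=hparams,
    reg_cat_clusters=True,       # turn on regularization of the categorical cluster assignments
    reg_cluster_variance=False,  # leave cluster-variance regularization off
)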
deepof/train_utils.py

@@ -94,6 +94,8 @@ def get_callbacks(
     phenotype_class: float,
     predictor: float,
     loss: str,
+    reg_cat_clusters: bool,
+    reg_cluster_variance: bool,
     logparam: dict = None,
     outpath: str = ".",
 ) -> List[Union[Any]]:
@@ -103,6 +105,14 @@ def get_callbacks(
     - cp_callback: for checkpoint saving,
     - onecycle: for learning rate scheduling"""
 
+    latreg = "none"
+    if reg_cat_clusters and not reg_cluster_variance:
+        latreg = "categorical"
+    elif reg_cluster_variance and not reg_cat_clusters:
+        latreg = "variance"
+    elif reg_cat_clusters and reg_cluster_variance:
+        latreg = "categorical+variance"
+
     run_ID = "{}{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
         ("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
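The two new booleans are collapsed into a single latreg label used for run identification. A standalone sketch of the same selection logic, showing the four possible outcomes:

def latent_reg_label(reg_cat_clusters: bool, reg_cluster_variance: bool) -> str:
    # Mirrors the branch added to get_callbacks(): the label defaults to "none"
    # and records which latent-space regularizers are active.
    latreg = "none"
    if reg_cat_clusters and not reg_cluster_variance:
        latreg = "categorical"
    elif reg_cluster_variance and not reg_cat_clusters:
        latreg = "variance"
    elif reg_cat_clusters and reg_cluster_variance:
        latreg = "categorical+variance"
    return latreg


assert latent_reg_label(False, False) == "none"
assert latent_reg_label(True, False) == "categorical"
assert latent_reg_label(False, True) == "variance"
assert latent_reg_label(True, True) == "categorical+variance"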
@@ -110,6 +120,7 @@ def get_callbacks(
         ("_loss={}".format(loss) if variational else ""),
         ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
         ("_k={}".format(logparam["k"]) if logparam is not None else ""),
+        ("_latreg={}".format(latreg)),
         (datetime.now().strftime("%Y%m%d-%H%M%S")),
     )
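With the _latreg component added, the run identifier now records the regularization label. A quick illustration of the resulting string, using hypothetical values for the other slots (only the format itself is taken from the diff):

from datetime import datetime

# Hypothetical parameter values, chosen for illustration.
variational, predictor, loss = True, 0.0, "ELBO"
logparam = {"encoding": 6, "k": 25}
latreg = "categorical+variance"

run_ID = "{}{}{}{}{}{}_{}".format(
    ("GMVAE" if variational else "AE"),
    ("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
    ("_loss={}".format(loss) if variational else ""),
    ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
    ("_k={}".format(logparam["k"]) if logparam is not None else ""),
    ("_latreg={}".format(latreg)),
    (datetime.now().strftime("%Y%m%d-%H%M%S")),
)
print(run_ID)  # e.g. GMVAE_loss=ELBO_encoding=6_k=25_latreg=categorical+variance_20210212-120000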
@@ -251,11 +262,11 @@ def autoencoder_fitting(
     log_history: bool,
     log_hparams: bool,
     loss: str,
-    mmd_warmup,
-    montecarlo_kl,
-    n_components,
-    output_path,
-    phenotype_class,
+    mmd_warmup: int,
+    montecarlo_kl: int,
+    n_components: int,
+    output_path: str,
+    phenotype_class: float,
     predictor: float,
     pretrained: str,
     save_checkpoints: bool,
@@ -290,6 +301,8 @@ def autoencoder_fitting(
         phenotype_class=phenotype_class,
         predictor=predictor,
         loss=loss,
+        reg_cat_clusters=reg_cluster_variance,
+        reg_cluster_variance=reg_cluster_variance,
         logparam=logparam,
         outpath=output_path,
     )