Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
358b099b
Commit
358b099b
authored
Mar 11, 2021
by
lucas_miranda
Browse files
Added nose2body to rule_based_annotation()
parent
fb591fe6
Changes
5
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
358b099b
...
...
@@ -367,10 +367,10 @@ class MMDiscrepancyLayer(Layer):
return
z
class
Gaussian_mixture
_overlap
(
Layer
):
class
Cluster
_overlap
(
Layer
):
"""
Identity layer that measures the overlap between the components of the latent Gaussian Mixture
using a specified metric (
MMD, Wasserstein, Fischer-Rao
)
using a specified metric (
KNN-purity, MMD
)
"""
def
__init__
(
self
,
lat_dims
,
n_components
,
loss
=
False
,
samples
=
10
,
*
args
,
**
kwargs
):
...
...
@@ -378,7 +378,7 @@ class Gaussian_mixture_overlap(Layer):
self
.
n_components
=
n_components
self
.
loss
=
loss
self
.
samples
=
samples
super
(
Gaussian_mixture
_overlap
,
self
).
__init__
(
*
args
,
**
kwargs
)
super
(
Cluster
_overlap
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
get_config
(
self
):
# pragma: no cover
"""Updates Constraint metadata"""
...
...
deepof/models.py
View file @
358b099b
...
...
@@ -334,15 +334,15 @@ class SEQ_2_SEQ_GMVAE:
"""Sets the default parameters for the model. Overwritable with a dictionary"""
defaults
=
{
"bidirectional_merge"
:
"
concat
"
,
"bidirectional_merge"
:
"
ave
"
,
"clipvalue"
:
1.0
,
"dense_activation"
:
"relu"
,
"dense_layers_per_branch"
:
5
,
"dropout_rate"
:
1e-3
,
"learning_rate"
:
1e-
3
,
"units_conv"
:
32
,
"dense_layers_per_branch"
:
1
,
"dropout_rate"
:
0.15
,
"learning_rate"
:
1e-
4
,
"units_conv"
:
64
,
"units_dense2"
:
32
,
"units_lstm"
:
300
,
"units_lstm"
:
128
,
}
for
k
,
v
in
params
.
items
():
...
...
@@ -599,7 +599,7 @@ class SEQ_2_SEQ_GMVAE:
z_gauss
=
deepof
.
model_utils
.
Dead_neuron_control
()(
z_gauss
)
if
self
.
overlap_loss
:
z_gauss
=
deepof
.
model_utils
.
Gaussian_mixture
_overlap
(
z_gauss
=
deepof
.
model_utils
.
Cluster
_overlap
(
self
.
ENCODING
,
self
.
number_of_components
,
loss
=
self
.
overlap_loss
,
...
...
deepof/train_model.py
View file @
358b099b
...
...
@@ -222,7 +222,7 @@ hypertun_trials = args.hpt_trials
encoding_size
=
args
.
encoding_size
exclude_bodyparts
=
[
i
for
i
in
args
.
exclude_bodyparts
.
split
(
","
)
if
i
]
gaussian_filter
=
args
.
gaussian_filter
hparams
=
args
.
hyperparameters
hparams
=
args
.
hyperparameters
if
args
.
hyperparameters
is
not
None
else
{}
input_type
=
args
.
input_type
k
=
args
.
components
kl_wu
=
args
.
kl_warmup
...
...
@@ -261,7 +261,6 @@ assert input_type in [
],
"Invalid input type. Type python model_training.py -h for help."
# Loads model hyperparameters and treatment conditions, if available
hparams
=
load_hparams
(
hparams
)
treatment_dict
=
load_treatments
(
train_path
)
# Logs hyperparameters if specified on the --logparam CLI argument
...
...
@@ -352,7 +351,7 @@ if not tune:
(
X_train
,
y_train
,
X_val
,
y_val
),
batch_size
=
batch_size
,
encoding_size
=
encoding_size
,
hparams
=
hparams
,
hparams
=
{}
,
kl_warmup
=
kl_wu
,
log_history
=
True
,
log_hparams
=
True
,
...
...
deepof/train_utils.py
View file @
358b099b
...
...
@@ -46,28 +46,6 @@ class CustomStopper(tf.keras.callbacks.EarlyStopping):
super
().
on_epoch_end
(
epoch
,
logs
)
def
load_hparams
(
hparams
):
"""Loads hyperparameters from a custom dictionary pickled on disc.
Thought to be used with the output of hyperparameter_tuning.py"""
if
hparams
is
not
None
:
with
open
(
hparams
,
"rb"
)
as
handle
:
hparams
=
pickle
.
load
(
handle
)
else
:
hparams
=
{
"bidirectional_merge"
:
"ave"
,
"clipvalue"
:
1.0
,
"dense_activation"
:
"relu"
,
"dense_layers_per_branch"
:
1
,
"dropout_rate"
:
1e-3
,
"learning_rate"
:
1e-3
,
"units_conv"
:
160
,
"units_dense2"
:
120
,
"units_lstm"
:
300
,
}
return
hparams
def
load_treatments
(
train_path
):
"""Loads a dictionary containing the treatments per individual,
to be loaded as metadata in the coordinates class"""
...
...
tests/test_train_utils.py
View file @
358b099b
...
...
@@ -21,24 +21,6 @@ import os
import
tensorflow
as
tf
def
test_load_hparams
():
assert
type
(
deepof
.
train_utils
.
load_hparams
(
None
))
==
dict
assert
(
type
(
deepof
.
train_utils
.
load_hparams
(
os
.
path
.
join
(
"tests"
,
"test_examples"
,
"test_single_topview"
,
"Others"
,
"test_hparams.pkl"
,
)
)
)
==
dict
)
def
test_load_treatments
():
assert
deepof
.
train_utils
.
load_treatments
(
"."
)
is
None
assert
(
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment