Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
3234508f
Commit
3234508f
authored
Nov 30, 2020
by
lucas_miranda
Browse files
Reformatted files using last version of Black; fixed seaborn update issues
parent
9db84390
Pipeline
#88235
passed with stage
in 24 minutes and 22 seconds
Changes
18
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
deepof/data.py
View file @
3234508f
...
...
@@ -138,7 +138,7 @@ class project:
@
property
def
angles
(
self
):
"""Bool. Toggles angle computation. True by default. If turned off,
enhances performance for big datasets"""
enhances performance for big datasets"""
return
self
.
_angles
@
property
...
...
@@ -274,7 +274,7 @@ class project:
def
get_distances
(
self
,
tab_dict
:
dict
,
verbose
:
bool
=
False
)
->
dict
:
"""Computes the distances between all selected body parts over time.
If ego is provided, it only returns distances to a specified bodypart"""
If ego is provided, it only returns distances to a specified bodypart"""
if
verbose
:
print
(
"Computing distances..."
)
...
...
@@ -290,7 +290,11 @@ class project:
scales
=
self
.
scales
[:,
2
:]
distance_dict
=
{
key
:
deepof
.
utils
.
bpart_distance
(
tab
,
scales
[
i
,
1
],
scales
[
i
,
0
],)
key
:
deepof
.
utils
.
bpart_distance
(
tab
,
scales
[
i
,
1
],
scales
[
i
,
0
],
)
for
i
,
(
key
,
tab
)
in
enumerate
(
tab_dict
.
items
())
}
...
...
@@ -782,7 +786,7 @@ class table_dict(dict):
def
filter_videos
(
self
,
keys
:
list
)
->
Table_dict
:
"""Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
for selecting data coming from videos of a specified condition."""
for selecting data coming from videos of a specified condition."""
assert
np
.
all
([
k
in
self
.
keys
()
for
k
in
keys
]),
"Invalid keys selected"
...
...
@@ -825,7 +829,9 @@ class table_dict(dict):
return
heatmaps
def
get_training_set
(
self
,
test_videos
:
int
=
0
,
encode_labels
:
bool
=
True
,
self
,
test_videos
:
int
=
0
,
encode_labels
:
bool
=
True
,
)
->
Tuple
[
np
.
ndarray
,
list
,
Union
[
np
.
ndarray
,
list
],
list
]:
"""Generates training and test sets as numpy.array objects for model training"""
...
...
deepof/hypermodels.py
View file @
3234508f
...
...
@@ -30,18 +30,40 @@ class SEQ_2_SEQ_AE(HyperModel):
"""Retrieve hyperparameters to tune"""
conv_filters
=
hp
.
Int
(
"units_conv"
,
min_value
=
32
,
max_value
=
256
,
step
=
32
,
default
=
256
,
"units_conv"
,
min_value
=
32
,
max_value
=
256
,
step
=
32
,
default
=
256
,
)
lstm_units_1
=
hp
.
Int
(
"units_lstm"
,
min_value
=
128
,
max_value
=
512
,
step
=
32
,
default
=
256
,
"units_lstm"
,
min_value
=
128
,
max_value
=
512
,
step
=
32
,
default
=
256
,
)
dense_2
=
hp
.
Int
(
"units_dense2"
,
min_value
=
32
,
max_value
=
256
,
step
=
32
,
default
=
64
,
"units_dense2"
,
min_value
=
32
,
max_value
=
256
,
step
=
32
,
default
=
64
,
)
dropout_rate
=
hp
.
Float
(
"dropout_rate"
,
min_value
=
0.0
,
max_value
=
0.5
,
default
=
0.25
,
step
=
0.05
,
"dropout_rate"
,
min_value
=
0.0
,
max_value
=
0.5
,
default
=
0.25
,
step
=
0.05
,
)
encoding
=
hp
.
Int
(
"encoding"
,
min_value
=
16
,
max_value
=
64
,
step
=
8
,
default
=
24
,
)
encoding
=
hp
.
Int
(
"encoding"
,
min_value
=
16
,
max_value
=
64
,
step
=
8
,
default
=
24
,)
return
conv_filters
,
lstm_units_1
,
dense_2
,
dropout_rate
,
encoding
...
...
deepof/model_utils.py
View file @
3234508f
...
...
@@ -141,16 +141,16 @@ def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
def
compute_mmd
(
tensors
:
Tuple
[
Any
,
Any
])
->
tf
.
Tensor
:
"""
Computes the MMD between the two specified vectors using a gaussian kernel.
Computes the MMD between the two specified vectors using a gaussian kernel.
Parameters:
- tensors (tuple): tuple containing two tf.Tensor objects
Parameters:
- tensors (tuple): tuple containing two tf.Tensor objects
Returns
- mmd (tf.Tensor): returns the maximum mean discrepancy for each
training instance
Returns
- mmd (tf.Tensor): returns the maximum mean discrepancy for each
training instance
"""
"""
x
=
tensors
[
0
]
y
=
tensors
[
1
]
...
...
@@ -339,8 +339,8 @@ class DenseTranspose(Layer):
class
KLDivergenceLayer
(
tfpl
.
KLDivergenceAddLoss
):
"""
Identity transform layer that adds KL Divergence
to the final model loss.
Identity transform layer that adds KL Divergence
to the final model loss.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -360,7 +360,9 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
kl_batch
=
self
.
_regularizer
(
distribution_a
)
self
.
add_loss
(
kl_batch
,
inputs
=
[
distribution_a
])
self
.
add_metric
(
kl_batch
,
aggregation
=
"mean"
,
name
=
"kl_divergence"
,
kl_batch
,
aggregation
=
"mean"
,
name
=
"kl_divergence"
,
)
# noinspection PyProtectedMember
self
.
add_metric
(
self
.
_regularizer
.
_weight
,
aggregation
=
"mean"
,
name
=
"kl_rate"
)
...
...
deepof/models.py
View file @
3234508f
...
...
@@ -33,7 +33,9 @@ class SEQ_2_SEQ_AE:
""" Simple sequence to sequence autoencoder implemented with tf.keras """
def
__init__
(
self
,
architecture_hparams
:
Dict
=
{},
huber_delta
:
float
=
1.0
,
self
,
architecture_hparams
:
Dict
=
{},
huber_delta
:
float
=
1.0
,
):
self
.
hparams
=
self
.
get_hparams
(
architecture_hparams
)
self
.
CONV_filters
=
self
.
hparams
[
"units_conv"
]
...
...
@@ -118,13 +120,19 @@ class SEQ_2_SEQ_AE:
# Decoder layers
Model_D0
=
deepof
.
model_utils
.
DenseTranspose
(
Model_E5
,
activation
=
"elu"
,
output_dim
=
self
.
ENCODING
,
Model_E5
,
activation
=
"elu"
,
output_dim
=
self
.
ENCODING
,
)
Model_D1
=
deepof
.
model_utils
.
DenseTranspose
(
Model_E4
,
activation
=
"elu"
,
output_dim
=
self
.
DENSE_2
,
Model_E4
,
activation
=
"elu"
,
output_dim
=
self
.
DENSE_2
,
)
Model_D2
=
deepof
.
model_utils
.
DenseTranspose
(
Model_E3
,
activation
=
"elu"
,
output_dim
=
self
.
DENSE_1
,
Model_E3
,
activation
=
"elu"
,
output_dim
=
self
.
DENSE_1
,
)
Model_D3
=
RepeatVector
(
input_shape
[
1
])
Model_D4
=
Bidirectional
(
...
...
@@ -161,7 +169,10 @@ class SEQ_2_SEQ_AE:
Model_D5
,
)
def
build
(
self
,
input_shape
:
tuple
,)
->
Tuple
[
Any
,
Any
,
Any
]:
def
build
(
self
,
input_shape
:
tuple
,
)
->
Tuple
[
Any
,
Any
,
Any
]:
"""Builds the tf.keras model"""
(
...
...
@@ -213,7 +224,10 @@ class SEQ_2_SEQ_AE:
model
.
compile
(
loss
=
Huber
(
delta
=
self
.
delta
),
optimizer
=
Nadam
(
lr
=
self
.
learn_rate
,
clipvalue
=
0.5
,),
optimizer
=
Nadam
(
lr
=
self
.
learn_rate
,
clipvalue
=
0.5
,
),
metrics
=
[
"mae"
],
)
...
...
@@ -298,7 +312,10 @@ class SEQ_2_SEQ_GMVAE:
),
components
=
[
tfd
.
Independent
(
tfd
.
Normal
(
loc
=
init_means
[
k
],
scale
=
1
,),
tfd
.
Normal
(
loc
=
init_means
[
k
],
scale
=
1
,
),
reinterpreted_batch_ndims
=
1
,
)
for
k
in
range
(
self
.
number_of_components
)
...
...
@@ -537,7 +554,10 @@ class SEQ_2_SEQ_GMVAE:
encoder
=
BatchNormalization
()(
encoder
)
# encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
z_cat
=
Dense
(
self
.
number_of_components
,
activation
=
"softmax"
,)(
encoder
)
z_cat
=
Dense
(
self
.
number_of_components
,
activation
=
"softmax"
,
)(
encoder
)
z_cat
=
deepof
.
model_utils
.
Entropy_regulariser
(
self
.
entropy_reg_weight
)(
z_cat
)
z_gauss
=
Dense
(
deepof
.
model_utils
.
tfpl
.
IndependentNormal
.
params_size
(
...
...
@@ -553,12 +573,16 @@ class SEQ_2_SEQ_GMVAE:
if
self
.
overlap_loss
:
z_gauss
=
deepof
.
model_utils
.
Gaussian_mixture_overlap
(
self
.
ENCODING
,
self
.
number_of_components
,
loss
=
self
.
overlap_loss
,
self
.
ENCODING
,
self
.
number_of_components
,
loss
=
self
.
overlap_loss
,
)(
z_gauss
)
z
=
deepof
.
model_utils
.
tfpl
.
DistributionLambda
(
lambda
gauss
:
tfd
.
mixture
.
Mixture
(
cat
=
tfd
.
categorical
.
Categorical
(
probs
=
gauss
[
0
],),
cat
=
tfd
.
categorical
.
Categorical
(
probs
=
gauss
[
0
],
),
components
=
[
tfd
.
Independent
(
tfd
.
Normal
(
...
...
@@ -663,7 +687,11 @@ class SEQ_2_SEQ_GMVAE:
grouper
=
Model
(
x
,
z_cat
,
name
=
"Deep_Gaussian_Mixture_clustering"
)
# noinspection PyUnboundLocalVariable
gmvaep
=
Model
(
inputs
=
x
,
outputs
=
model_outs
,
name
=
"SEQ_2_SEQ_GMVAE"
,)
gmvaep
=
Model
(
inputs
=
x
,
outputs
=
model_outs
,
name
=
"SEQ_2_SEQ_GMVAE"
,
)
# Build generator as a separate entity
g
=
Input
(
shape
=
self
.
ENCODING
)
...
...
@@ -682,7 +710,10 @@ class SEQ_2_SEQ_GMVAE:
if
self
.
compile
:
gmvaep
.
compile
(
loss
=
model_losses
,
optimizer
=
Nadam
(
lr
=
self
.
learn_rate
,
clipvalue
=
self
.
clipvalue
,),
optimizer
=
Nadam
(
lr
=
self
.
learn_rate
,
clipvalue
=
self
.
clipvalue
,
),
metrics
=
model_metrics
,
loss_weights
=
loss_weights
,
)
...
...
deepof/pose_utils.py
View file @
3234508f
...
...
@@ -33,18 +33,18 @@ def close_single_contact(
)
->
np
.
array
:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
Parameters:
- pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
to two-animal experiments.
- left (string): First member of the potential contact
- right (string): Second member of the potential contact
- tol (float): maximum distance for which a contact is reported
- arena_abs (int): length in mm of the diameter of the real arena
- arena_rel (int): length in pixels of the diameter of the arena in the video
Parameters:
- pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
to two-animal experiments.
- left (string): First member of the potential contact
- right (string): Second member of the potential contact
- tol (float): maximum distance for which a contact is reported
- arena_abs (int): length in mm of the diameter of the real arena
- arena_rel (int): length in pixels of the diameter of the arena in the video
Returns:
- contact_array (np.array): True if the distance between the two specified points
is less than tol, False otherwise"""
Returns:
- contact_array (np.array): True if the distance between the two specified points
is less than tol, False otherwise"""
close_contact
=
(
np
.
linalg
.
norm
(
pos_dframe
[
left
]
-
pos_dframe
[
right
],
axis
=
1
)
*
arena_abs
...
...
@@ -66,21 +66,21 @@ def close_double_contact(
)
->
np
.
array
:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
Parameters:
- pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
to two-animal experiments.
- left1 (string): First contact point of animal 1
- left2 (string): Second contact point of animal 1
- right1 (string): First contact point of animal 2
- right2 (string): Second contact point of animal 2
- tol (float): maximum distance for which a contact is reported
- arena_abs (int): length in mm of the diameter of the real arena
- arena_rel (int): length in pixels of the diameter of the arena in the video
- rev (bool): reverses the default behaviour (nose2tail contact for both mice)
Parameters:
- pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
to two-animal experiments.
- left1 (string): First contact point of animal 1
- left2 (string): Second contact point of animal 1
- right1 (string): First contact point of animal 2
- right2 (string): Second contact point of animal 2
- tol (float): maximum distance for which a contact is reported
- arena_abs (int): length in mm of the diameter of the real arena
- arena_rel (int): length in pixels of the diameter of the arena in the video
- rev (bool): reverses the default behaviour (nose2tail contact for both mice)
Returns:
- double_contact (np.array): True if the distance between the two specified points
is less than tol, False otherwise"""
Returns:
- double_contact (np.array): True if the distance between the two specified points
is less than tol, False otherwise"""
if
rev
:
double_contact
=
(
...
...
@@ -117,19 +117,19 @@ def climb_wall(
)
->
np
.
array
:
"""Returns True if the specified mouse is climbing the wall
Parameters:
- arena_type (str): arena type; must be one of ['circular']
- arena (np.array): contains arena location and shape details
- pos_dict (table_dict): position over time for all videos in a project
- tol (float): minimum tolerance to report a hit
- nose (str): indicates the name of the body part representing the nose of
the selected animal
- arena_dims (int): indicates radius of the real arena in mm
- centered_data (bool): indicates whether the input data is centered
Parameters:
- arena_type (str): arena type; must be one of ['circular']
- arena (np.array): contains arena location and shape details
- pos_dict (table_dict): position over time for all videos in a project
- tol (float): minimum tolerance to report a hit
- nose (str): indicates the name of the body part representing the nose of
the selected animal
- arena_dims (int): indicates radius of the real arena in mm
- centered_data (bool): indicates whether the input data is centered
Returns:
- climbing (np.array): boolean array. True if selected animal
is climbing the walls of the arena"""
Returns:
- climbing (np.array): boolean array. True if selected animal
is climbing the walls of the arena"""
nose
=
pos_dict
[
nose
]
...
...
@@ -166,7 +166,7 @@ def huddle(
Returns:
hudd (np.array): True if the animal is huddling, False otherwise
"""
"""
if
animal_id
!=
""
:
animal_id
+=
"_"
...
...
@@ -254,7 +254,8 @@ def following_path(
)
follow
=
np
.
all
(
np
.
array
([(
dist_df
.
min
(
axis
=
1
)
<
tol
),
right_orient1
,
right_orient2
]),
axis
=
0
,
np
.
array
([(
dist_df
.
min
(
axis
=
1
)
<
tol
),
right_orient1
,
right_orient2
]),
axis
=
0
,
)
return
follow
...
...
@@ -270,28 +271,27 @@ def single_behaviour_analysis(
ylim
:
float
=
None
,
)
->
list
:
"""Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
with the actual tags, outputs a box plot and a series of significance tests amongst the groups
with the actual tags, outputs a box plot and a series of significance tests amongst the groups
Parameters:
- behaviour_name (str): name of the behavioural trait to analyze
- treatment_dict (dict): dictionary containing video names as keys and experimental conditions as values
- behavioural_dict (dict): tagged dictionary containing video names as keys and annotations as values
- plot (int): Silent if 0; otherwise, indicates the dpi of the figure to plot
- stat_tests (bool): performs FDR corrected Mann-U non-parametric tests among all groups if True
- save (str): Saves the produced figure to the specified file
- ylim (float): y-limit for the boxplot. Ignored if plot == False
Parameters:
- behaviour_name (str): name of the behavioural trait to analyze
- treatment_dict (dict): dictionary containing video names as keys and experimental conditions as values
- behavioural_dict (dict): tagged dictionary containing video names as keys and annotations as values
- plot (int): Silent if 0; otherwise, indicates the dpi of the figure to plot
- stat_tests (bool): performs FDR corrected Mann-U non-parametric tests among all groups if True
- save (str): Saves the produced figure to the specified file
- ylim (float): y-limit for the boxplot. Ignored if plot == False
Returns:
- beh_dict (dict): dictionary containing experimental conditions as keys and video names as values
- stat_dict (dict): dictionary containing condition pairs as keys and stat results as values"""
Returns:
- beh_dict (dict): dictionary containing experimental conditions as keys and video names as values
- stat_dict (dict): dictionary containing condition pairs as keys and stat results as values"""
beh_dict
=
{
condition
:
[]
for
condition
in
treatment_dict
.
keys
()}
for
condition
in
beh_dict
.
keys
():
for
ind
in
treatment_dict
[
condition
]:
beh_dict
[
condition
].
append
(
np
.
sum
(
behavioural_dict
[
ind
][
behaviour_name
])
/
len
(
behavioural_dict
[
ind
][
behaviour_name
])
beh_dict
[
condition
]
+=
np
.
sum
(
behavioural_dict
[
ind
][
behaviour_name
])
/
len
(
behavioural_dict
[
ind
][
behaviour_name
]
)
return_list
=
[
beh_dict
]
...
...
@@ -301,7 +301,10 @@ def single_behaviour_analysis(
fig
,
ax
=
plt
.
subplots
(
dpi
=
plot
)
sns
.
boxplot
(
list
(
beh_dict
.
keys
()),
list
(
beh_dict
.
values
()),
orient
=
"vertical"
,
ax
=
ax
x
=
list
(
beh_dict
.
keys
()),
y
=
list
(
beh_dict
.
values
()),
orient
=
"vertical"
,
ax
=
ax
,
)
ax
.
set_title
(
"{} across groups"
.
format
(
behaviour_name
))
...
...
@@ -343,16 +346,16 @@ def max_behaviour(
)
->
np
.
array
:
"""Returns the most frequent behaviour in a window of window_size frames
Parameters:
- behaviour_dframe (pd.DataFrame): boolean matrix containing occurrence
of tagged behaviours per frame in the video
- window_size (int): size of the window to use when computing
the maximum behaviour per time slot
- stepped (bool): sliding windows don't overlap if True. False by default
Parameters:
- behaviour_dframe (pd.DataFrame): boolean matrix containing occurrence
of tagged behaviours per frame in the video
- window_size (int): size of the window to use when computing
the maximum behaviour per time slot
- stepped (bool): sliding windows don't overlap if True. False by default
Returns:
- max_array (np.array): string array with the most common behaviour per instance
of the sliding window"""
Returns:
- max_array (np.array): string array with the most common behaviour per instance
of the sliding window"""
speeds
=
[
col
for
col
in
behaviour_dframe
.
columns
if
"speed"
in
col
.
lower
()]
...
...
@@ -369,12 +372,12 @@ def max_behaviour(
def
get_hparameters
(
hparams
:
dict
=
{})
->
dict
:
"""Returns the default hyperparameters, overwriting those specified in hparams
Parameters:
- hparams (dict): dictionary containing hyperparameters to overwrite
Parameters:
- hparams (dict): dictionary containing hyperparameters to overwrite
Returns:
- defaults (dict): dictionary with overwritten parameters. Those not
specified in the input retain their default values"""
Returns:
- defaults (dict): dictionary with overwritten parameters. Those not
specified in the input retain their default values"""
defaults
=
{
"speed_pause"
:
3
,
...
...
@@ -398,14 +401,14 @@ def get_hparameters(hparams: dict = {}) -> dict:
def
frame_corners
(
w
,
h
,
corners
:
dict
=
{}):
"""Returns a dictionary with the corner positions of the video frame
Parameters:
- w (int): width of the frame in pixels
- h (int): height of the frame in pixels
- corners (dict): dictionary containing corners to overwrite
Parameters:
- w (int): width of the frame in pixels
- h (int): height of the frame in pixels
- corners (dict): dictionary containing corners to overwrite
Returns:
- defaults (dict): dictionary with overwritten parameters. Those not
specified in the input retain their default values"""
Returns:
- defaults (dict): dictionary with overwritten parameters. Those not
specified in the input retain their default values"""
defaults
=
{
"downleft"
:
(
int
(
w
*
0.3
/
10
),
int
(
h
/
1.05
)),
...
...
@@ -614,11 +617,13 @@ def tag_rulebased_frames(
write_on_frame
(
"Nose-Tail"
,
corners
[
"downright"
])
if
tag_dict
[
"sidebyside"
][
fnum
]:
write_on_frame
(
"Side-side"
,
conditional_pos
(),
"Side-side"
,
conditional_pos
(),
)
if
tag_dict
[
"sidereside"
][
fnum
]:
write_on_frame
(
"Side-Rside"
,
conditional_pos
(),
"Side-Rside"
,
conditional_pos
(),
)
for
_id
,
down_pos
,
up_pos
in
zipped_pos
:
if
(
...
...
@@ -626,7 +631,9 @@ def tag_rulebased_frames(
and
not
tag_dict
[
_id
+
"_climbing"
][
fnum
]
):
write_on_frame
(
"*f"
,
(
int
(
w
*
0.3
/
10
),
int
(
h
/
10
)),
conditional_col
(),
"*f"
,
(
int
(
w
*
0.3
/
10
),
int
(
h
/
10
)),
conditional_col
(),
)
for
_id
,
down_pos
,
up_pos
in
zipped_pos
:
...
...
deepof/train_model.py
View file @
3234508f
...
...
@@ -310,7 +310,12 @@ if not tune:
tf
.
keras
.
backend
.
clear_session
()
run_ID
,
tensorboard_callback
,
onecycle
,
cp_callback
=
get_callbacks
(
X_train
,
batch_size
,
True
,
variational
,
predictor
,
loss
,
X_train
,
batch_size
,
True
,
variational
,
predictor
,
loss
,
)
if
not
variational
:
...
...
@@ -393,7 +398,10 @@ if not tune:
epochs
=
250
,
batch_size
=
batch_size
,
verbose
=
1
,
validation_data
=
(
Xvals
,
yvals
,),
validation_data
=
(
Xvals
,
yvals
,
),
callbacks
=
callbacks_
,
)
...
...
deepof/train_utils.py
View file @
3234508f
...
...
@@ -69,10 +69,10 @@ def get_callbacks(
loss
:
str
,
)
->
List
[
Union
[
Any
]]:
"""Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details;
- tensorboard_callback: for real-time visualization;
- cp_callback: for checkpoint saving,
- onecycle: for learning rate scheduling"""
- run_ID: run name, with coarse parameter details;
- tensorboard_callback: for real-time visualization;
- cp_callback: for checkpoint saving,
- onecycle: for learning rate scheduling"""
run_ID
=
"{}{}{}_{}"
.
format
(
(
"GMVAE"
if
variational
else
"AE"
),
...
...
@@ -83,11 +83,14 @@ def get_callbacks(
log_dir
=
os
.
path
.
abspath
(
"logs/fit/{}"
.
format
(
run_ID
))
tensorboard_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
log_dir
=
log_dir
,
histogram_freq
=
1
,
profile_batch
=
2
,
log_dir
=
log_dir
,
histogram_freq
=
1
,
profile_batch
=
2
,
)
onecycle
=
deepof
.
model_utils
.
one_cycle_scheduler
(
X_train
.
shape
[
0
]
//
batch_size
*
250
,
max_rate
=
0.005
,
X_train
.
shape
[
0
]
//
batch_size
*
250
,
max_rate
=
0.005
,
)
callbacks
=
[
run_ID
,
tensorboard_callback
,
onecycle
]
...
...
@@ -124,30 +127,30 @@ def tune_search(
)
->
Union
[
bool
,
Tuple
[
Any
,
Any
]]:
"""Define the search space using keras-tuner and bayesian optimization
Parameters:
- train (np.array): dataset to train the model on
- test (np.array): dataset to validate the model on
- hypertun_trials (int): number of Bayesian optimization iterations to run
- hpt_type (str): specify one of Bayesian Optimization (bayopt) and Hyperband (hyperband)
- hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder)
or S2SGMVAE (Gaussian Mixture Variational autoencoder).
- k (int): number of components of the Gaussian Mixture
- loss (str): one of [ELBO, MMD, ELBO+MMD]
- overlap_loss (float): assigns a weight to an extra loss term which
penalizes overlap between GM components
- pheno_class (float): adds an extra regularizing neural network to the model,
which tries to predict the phenotype of the animal from which the sequence comes
- predictor (float): adds an extra regularizing neural network to the model,
which tries to predict the next frame from the current one
- project_name (str): ID of the current run
- callbacks (list): list of callbacks for the training loop
- n_epochs (int): optional. Number of epochs to train each run for
- n_replicas (int): optional. Number of replicas per parameter set. Higher values
will yield more robust results, but will affect performance severely
Returns:
- best_hparams (dict): dictionary with the best retrieved hyperparameters
- best_run (tf.keras.Model): trained instance of the best model found
Parameters:
- train (np.array): dataset to train the model on
- test (np.array): dataset to validate the model on
- hypertun_trials (int): number of Bayesian optimization iterations to run
- hpt_type (str): specify one of Bayesian Optimization (bayopt) and Hyperband (hyperband)
- hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder)
or S2SGMVAE (Gaussian Mixture Variational autoencoder).
- k (int): number of components of the Gaussian Mixture
- loss (str): one of [ELBO, MMD, ELBO+MMD]
- overlap_loss (float): assigns a weight to an extra loss term which
penalizes overlap between GM components
- pheno_class (float): adds an extra regularizing neural network to the model,
which tries to predict the phenotype of the animal from which the sequence comes
- predictor (float): adds an extra regularizing neural network to the model,
which tries to predict the next frame from the current one
- project_name (str): ID of the current run