Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
707ab97e
Commit
707ab97e
authored
Mar 30, 2021
by
lucas_miranda
Browse files
Ran PyCharm code cleaner
Former-commit-id:
bb286fec
parent
4c71282e
Changes
9
Hide whitespace changes
Inline
Side-by-side
deepof/data.py
View file @
707ab97e
...
...
@@ -14,29 +14,30 @@ Contains methods for generating training and test sets ready for model training.
"""
import
os
import
warnings
from
collections
import
defaultdict
from
joblib
import
delayed
,
Parallel
,
parallel_backend
from
typing
import
Dict
,
List
,
Tuple
,
Union
from
multiprocessing
import
cpu_count
from
typing
import
Dict
,
List
,
Tuple
,
Union
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
pandas
as
pd
import
tensorflow
as
tf
from
joblib
import
delayed
,
Parallel
,
parallel_backend
from
pkg_resources
import
resource_filename
from
sklearn
import
random_projection
from
sklearn.decomposition
import
KernelPCA
from
sklearn.experimental
import
enable_iterative_imputer
from
sklearn.impute
import
IterativeImputer
from
sklearn.manifold
import
TSNE
from
sklearn.preprocessing
import
MinMaxScaler
,
StandardScaler
,
LabelEncoder
from
tqdm
import
tqdm
import
deepof.models
import
deepof.pose_utils
import
deepof.train_utils
import
deepof.utils
import
deepof.visuals
import
deepof.train_utils
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
os
import
pandas
as
pd
import
tensorflow
as
tf
import
warnings
# Remove excessive logging from tensorflow
os
.
environ
[
"TF_CPP_MIN_LOG_LEVEL"
]
=
"2"
...
...
@@ -58,23 +59,23 @@ class project:
"""
def
__init__
(
self
,
animal_ids
:
List
=
tuple
([
""
]),
arena
:
str
=
"circular"
,
arena_detection
:
str
=
"cnn"
,
arena_dims
:
tuple
=
(
1
,),
enable_iterative_imputation
:
bool
=
None
,
exclude_bodyparts
:
List
=
tuple
([
""
]),
exp_conditions
:
dict
=
None
,
interpolate_outliers
:
bool
=
True
,
interpolation_limit
:
int
=
5
,
interpolation_std
:
int
=
5
,
likelihood_tol
:
float
=
0.25
,
model
:
str
=
"mouse_topview"
,
path
:
str
=
deepof
.
utils
.
os
.
path
.
join
(
"."
),
smooth_alpha
:
float
=
0.99
,
table_format
:
str
=
"autodetect"
,
video_format
:
str
=
".mp4"
,
self
,
animal_ids
:
List
=
tuple
([
""
]),
arena
:
str
=
"circular"
,
arena_detection
:
str
=
"cnn"
,
arena_dims
:
tuple
=
(
1
,),
enable_iterative_imputation
:
bool
=
None
,
exclude_bodyparts
:
List
=
tuple
([
""
]),
exp_conditions
:
dict
=
None
,
interpolate_outliers
:
bool
=
True
,
interpolation_limit
:
int
=
5
,
interpolation_std
:
int
=
5
,
likelihood_tol
:
float
=
0.25
,
model
:
str
=
"mouse_topview"
,
path
:
str
=
deepof
.
utils
.
os
.
path
.
join
(
"."
),
smooth_alpha
:
float
=
0.99
,
table_format
:
str
=
"autodetect"
,
video_format
:
str
=
".mp4"
,
):
# Set working paths
...
...
@@ -115,7 +116,6 @@ class project:
self
.
arena_dims
=
arena_dims
self
.
ellipse_detection
=
None
if
arena
==
"circular"
and
arena_detection
==
"cnn"
:
self
.
ellipse_detection
=
tf
.
keras
.
models
.
load_model
(
[
os
.
path
.
join
(
self
.
trained_path
,
i
)
...
...
@@ -286,8 +286,8 @@ class project:
).
T
.
index
.
remove_unused_levels
()
tab
=
value
.
loc
[
:,
[
i
for
i
in
value
.
columns
.
levels
[
0
]
if
i
not
in
lablist
]
]
:,
[
i
for
i
in
value
.
columns
.
levels
[
0
]
if
i
not
in
lablist
]
]
tab
.
columns
=
tabcols
...
...
@@ -361,14 +361,14 @@ class project:
for
key
in
distance_dict
.
keys
():
distance_dict
[
key
]
=
distance_dict
[
key
].
loc
[
:,
[
np
.
all
([
i
in
nodes
for
i
in
j
])
for
j
in
distance_dict
[
key
].
columns
]
]
:,
[
np
.
all
([
i
in
nodes
for
i
in
j
])
for
j
in
distance_dict
[
key
].
columns
]
]
if
self
.
ego
:
for
key
,
val
in
distance_dict
.
items
():
distance_dict
[
key
]
=
val
.
loc
[
:,
[
dist
for
dist
in
val
.
columns
if
self
.
ego
in
dist
]
]
:,
[
dist
for
dist
in
val
.
columns
if
self
.
ego
in
dist
]
]
return
distance_dict
...
...
@@ -471,20 +471,20 @@ class coordinates:
"""
def
__init__
(
self
,
arena
:
str
,
arena_detection
:
str
,
arena_dims
:
np
.
array
,
path
:
str
,
quality
:
dict
,
scales
:
np
.
array
,
tables
:
dict
,
videos
:
list
,
angles
:
dict
=
None
,
animal_ids
:
List
=
tuple
([
""
]),
distances
:
dict
=
None
,
exp_conditions
:
dict
=
None
,
ellipse_detection
:
tf
.
keras
.
models
.
Model
=
None
,
self
,
arena
:
str
,
arena_detection
:
str
,
arena_dims
:
np
.
array
,
path
:
str
,
quality
:
dict
,
scales
:
np
.
array
,
tables
:
dict
,
videos
:
list
,
angles
:
dict
=
None
,
animal_ids
:
List
=
tuple
([
""
]),
distances
:
dict
=
None
,
exp_conditions
:
dict
=
None
,
ellipse_detection
:
tf
.
keras
.
models
.
Model
=
None
,
):
self
.
_animal_ids
=
animal_ids
self
.
_arena
=
arena
...
...
@@ -509,15 +509,15 @@ class coordinates:
return
"deepof analysis of {} videos"
.
format
(
len
(
self
.
_videos
))
def
get_coords
(
self
,
center
:
str
=
"arena"
,
polar
:
bool
=
False
,
speed
:
int
=
0
,
length
:
str
=
None
,
align
:
bool
=
False
,
align_inplace
:
bool
=
False
,
propagate_labels
:
bool
=
False
,
propagate_annotations
:
Dict
=
False
,
self
,
center
:
str
=
"arena"
,
polar
:
bool
=
False
,
speed
:
int
=
0
,
length
:
str
=
None
,
align
:
bool
=
False
,
align_inplace
:
bool
=
False
,
propagate_labels
:
bool
=
False
,
propagate_annotations
:
Dict
=
False
,
)
->
Table_dict
:
"""
Returns a table_dict object with the coordinates of each animal as values.
...
...
@@ -557,23 +557,23 @@ class coordinates:
try
:
value
.
loc
[:,
(
slice
(
"coords"
),
[
"x"
])]
=
(
value
.
loc
[:,
(
slice
(
"coords"
),
[
"x"
])]
-
self
.
_scales
[
i
][
0
]
/
2
value
.
loc
[:,
(
slice
(
"coords"
),
[
"x"
])]
-
self
.
_scales
[
i
][
0
]
/
2
)
value
.
loc
[:,
(
slice
(
"coords"
),
[
"y"
])]
=
(
value
.
loc
[:,
(
slice
(
"coords"
),
[
"y"
])]
-
self
.
_scales
[
i
][
1
]
/
2
value
.
loc
[:,
(
slice
(
"coords"
),
[
"y"
])]
-
self
.
_scales
[
i
][
1
]
/
2
)
except
KeyError
:
value
.
loc
[:,
(
slice
(
"coords"
),
[
"rho"
])]
=
(
value
.
loc
[:,
(
slice
(
"coords"
),
[
"rho"
])]
-
self
.
_scales
[
i
][
0
]
/
2
value
.
loc
[:,
(
slice
(
"coords"
),
[
"rho"
])]
-
self
.
_scales
[
i
][
0
]
/
2
)
value
.
loc
[:,
(
slice
(
"coords"
),
[
"phi"
])]
=
(
value
.
loc
[:,
(
slice
(
"coords"
),
[
"phi"
])]
-
self
.
_scales
[
i
][
1
]
/
2
value
.
loc
[:,
(
slice
(
"coords"
),
[
"phi"
])]
-
self
.
_scales
[
i
][
1
]
/
2
)
elif
isinstance
(
center
,
str
)
and
center
!=
"arena"
:
...
...
@@ -582,24 +582,24 @@ class coordinates:
try
:
value
.
loc
[:,
(
slice
(
"coords"
),
[
"x"
])]
=
value
.
loc
[
:,
(
slice
(
"coords"
),
[
"x"
])
].
subtract
(
value
[
center
][
"x"
],
axis
=
0
)
:,
(
slice
(
"coords"
),
[
"x"
])
].
subtract
(
value
[
center
][
"x"
],
axis
=
0
)
value
.
loc
[:,
(
slice
(
"coords"
),
[
"y"
])]
=
value
.
loc
[
:,
(
slice
(
"coords"
),
[
"y"
])
].
subtract
(
value
[
center
][
"y"
],
axis
=
0
)
:,
(
slice
(
"coords"
),
[
"y"
])
].
subtract
(
value
[
center
][
"y"
],
axis
=
0
)
except
KeyError
:
value
.
loc
[:,
(
slice
(
"coords"
),
[
"rho"
])]
=
value
.
loc
[
:,
(
slice
(
"coords"
),
[
"rho"
])
].
subtract
(
value
[
center
][
"rho"
],
axis
=
0
)
:,
(
slice
(
"coords"
),
[
"rho"
])
].
subtract
(
value
[
center
][
"rho"
],
axis
=
0
)
value
.
loc
[:,
(
slice
(
"coords"
),
[
"phi"
])]
=
value
.
loc
[
:,
(
slice
(
"coords"
),
[
"phi"
])
].
subtract
(
value
[
center
][
"phi"
],
axis
=
0
)
:,
(
slice
(
"coords"
),
[
"phi"
])
].
subtract
(
value
[
center
][
"phi"
],
axis
=
0
)
tabs
[
key
]
=
value
.
loc
[
:,
[
tab
for
tab
in
value
.
columns
if
tab
[
0
]
!=
center
]
]
:,
[
tab
for
tab
in
value
.
columns
if
tab
[
0
]
!=
center
]
]
if
speed
:
for
key
,
tab
in
tabs
.
items
():
...
...
@@ -614,16 +614,16 @@ class coordinates:
if
align
:
assert
(
align
in
list
(
tabs
.
values
())[
0
].
columns
.
levels
[
0
]
align
in
list
(
tabs
.
values
())[
0
].
columns
.
levels
[
0
]
),
"align must be set to the name of a bodypart"
for
key
,
tab
in
tabs
.
items
():
# Bring forward the column to align
columns
=
[
i
for
i
in
tab
.
columns
if
align
not
in
i
]
columns
=
[
(
align
,
(
"phi"
if
polar
else
"x"
)),
(
align
,
(
"rho"
if
polar
else
"y"
)),
]
+
columns
(
align
,
(
"phi"
if
polar
else
"x"
)),
(
align
,
(
"rho"
if
polar
else
"y"
)),
]
+
columns
tab
=
tab
[
columns
]
tabs
[
key
]
=
tab
...
...
@@ -658,11 +658,11 @@ class coordinates:
)
def
get_distances
(
self
,
speed
:
int
=
0
,
length
:
str
=
None
,
propagate_labels
:
bool
=
False
,
propagate_annotations
:
Dict
=
False
,
self
,
speed
:
int
=
0
,
length
:
str
=
None
,
propagate_labels
:
bool
=
False
,
propagate_annotations
:
Dict
=
False
,
)
->
Table_dict
:
"""
Returns a table_dict object with the distances between body parts animal as values.
...
...
@@ -719,12 +719,12 @@ class coordinates:
)
# pragma: no cover
def
get_angles
(
self
,
degrees
:
bool
=
False
,
speed
:
int
=
0
,
length
:
str
=
None
,
propagate_labels
:
bool
=
False
,
propagate_annotations
:
Dict
=
False
,
self
,
degrees
:
bool
=
False
,
speed
:
int
=
0
,
length
:
str
=
None
,
propagate_labels
:
bool
=
False
,
propagate_annotations
:
Dict
=
False
,
)
->
Table_dict
:
"""
Returns a table_dict object with the angles between body parts animal as values.
...
...
@@ -810,11 +810,11 @@ class coordinates:
# noinspection PyDefaultArgument
def
rule_based_annotation
(
self
,
params
:
Dict
=
{},
video_output
:
bool
=
False
,
frame_limit
:
int
=
np
.
inf
,
debug
:
bool
=
False
,
self
,
params
:
Dict
=
{},
video_output
:
bool
=
False
,
frame_limit
:
int
=
np
.
inf
,
debug
:
bool
=
False
,
)
->
Table_dict
:
"""Annotates coordinates using a simple rule-based pipeline"""
...
...
@@ -882,29 +882,29 @@ class coordinates:
@
staticmethod
def
deep_unsupervised_embedding
(
preprocessed_object
:
Tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
batch_size
:
int
=
256
,
encoding_size
:
int
=
4
,
epochs
:
int
=
35
,
hparams
:
dict
=
None
,
kl_warmup
:
int
=
0
,
log_history
:
bool
=
True
,
log_hparams
:
bool
=
False
,
loss
:
str
=
"ELBO"
,
mmd_warmup
:
int
=
0
,
montecarlo_kl
:
int
=
10
,
n_components
:
int
=
25
,
output_path
:
str
=
"."
,
phenotype_class
:
float
=
0
,
predictor
:
float
=
0
,
pretrained
:
str
=
False
,
save_checkpoints
:
bool
=
False
,
save_weights
:
bool
=
True
,
variational
:
bool
=
True
,
reg_cat_clusters
:
bool
=
False
,
reg_cluster_variance
:
bool
=
False
,
entropy_samples
:
int
=
10000
,
entropy_knn
:
int
=
100
,
preprocessed_object
:
Tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
batch_size
:
int
=
256
,
encoding_size
:
int
=
4
,
epochs
:
int
=
35
,
hparams
:
dict
=
None
,
kl_warmup
:
int
=
0
,
log_history
:
bool
=
True
,
log_hparams
:
bool
=
False
,
loss
:
str
=
"ELBO"
,
mmd_warmup
:
int
=
0
,
montecarlo_kl
:
int
=
10
,
n_components
:
int
=
25
,
output_path
:
str
=
"."
,
phenotype_class
:
float
=
0
,
predictor
:
float
=
0
,
pretrained
:
str
=
False
,
save_checkpoints
:
bool
=
False
,
save_weights
:
bool
=
True
,
variational
:
bool
=
True
,
reg_cat_clusters
:
bool
=
False
,
reg_cluster_variance
:
bool
=
False
,
entropy_samples
:
int
=
10000
,
entropy_knn
:
int
=
100
,
)
->
Tuple
:
"""
Annotates coordinates using an unsupervised autoencoder.
...
...
@@ -982,15 +982,15 @@ class table_dict(dict):
"""
def
__init__
(
self
,
tabs
:
Dict
,
typ
:
str
,
arena
:
str
=
None
,
arena_dims
:
np
.
array
=
None
,
center
:
str
=
None
,
polar
:
bool
=
None
,
propagate_labels
:
bool
=
False
,
propagate_annotations
:
Dict
=
False
,
self
,
tabs
:
Dict
,
typ
:
str
,
arena
:
str
=
None
,
arena_dims
:
np
.
array
=
None
,
center
:
str
=
None
,
polar
:
bool
=
None
,
propagate_labels
:
bool
=
False
,
propagate_annotations
:
Dict
=
False
,
):
super
().
__init__
(
tabs
)
self
.
_type
=
typ
...
...
@@ -1013,13 +1013,13 @@ class table_dict(dict):
# noinspection PyTypeChecker
def
plot_heatmaps
(
self
,
bodyparts
:
list
,
xlim
:
float
=
None
,
ylim
:
float
=
None
,
save
:
bool
=
False
,
i
:
int
=
0
,
dpi
:
int
=
100
,
self
,
bodyparts
:
list
,
xlim
:
float
=
None
,
ylim
:
float
=
None
,
save
:
bool
=
False
,
i
:
int
=
0
,
dpi
:
int
=
100
,
)
->
plt
.
figure
:
"""Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)"""
...
...
@@ -1045,9 +1045,9 @@ class table_dict(dict):
return
heatmaps
def
get_training_set
(
self
,
test_videos
:
int
=
0
,
encode_labels
:
bool
=
True
,
self
,
test_videos
:
int
=
0
,
encode_labels
:
bool
=
True
,
)
->
Tuple
[
np
.
ndarray
,
list
,
Union
[
np
.
ndarray
,
list
],
list
]:
"""Generates training and test sets as numpy.array objects for model training"""
...
...
@@ -1106,17 +1106,17 @@ class table_dict(dict):
# noinspection PyTypeChecker,PyGlobalUndefined
def
preprocess
(
self
,
window_size
:
int
=
1
,
window_step
:
int
=
1
,
scale
:
str
=
"standard"
,
test_videos
:
int
=
0
,
verbose
:
bool
=
False
,
conv_filter
:
bool
=
None
,
sigma
:
float
=
1.0
,
shift
:
float
=
0.0
,
shuffle
:
bool
=
False
,
align
:
str
=
False
,
self
,
window_size
:
int
=
1
,
window_step
:
int
=
1
,
scale
:
str
=
"standard"
,
test_videos
:
int
=
0
,
verbose
:
bool
=
False
,
conv_filter
:
bool
=
None
,
sigma
:
float
=
1.0
,
shift
:
float
=
0.0
,
shuffle
:
bool
=
False
,
align
:
str
=
False
,
)
->
np
.
ndarray
:
"""
...
...
@@ -1243,7 +1243,7 @@ class table_dict(dict):
return
X_train
,
y_train
,
np
.
array
(
X_test
),
np
.
array
(
y_test
)
def
random_projection
(
self
,
n_components
:
int
=
None
,
sample
:
int
=
1000
self
,
n_components
:
int
=
None
,
sample
:
int
=
1000
)
->
deepof
.
utils
.
Tuple
[
deepof
.
utils
.
Any
,
deepof
.
utils
.
Any
]:
"""Returns a training set generated from the 2D original data (time x features) and a random projection
to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
...
...
@@ -1265,7 +1265,7 @@ class table_dict(dict):
return
X
,
rproj
def
pca
(
self
,
n_components
:
int
=
None
,
sample
:
int
=
1000
,
kernel
:
str
=
"linear"
self
,
n_components
:
int
=
None
,
sample
:
int
=
1000
,
kernel
:
str
=
"linear"
)
->
deepof
.
utils
.
Tuple
[
deepof
.
utils
.
Any
,
deepof
.
utils
.
Any
]:
"""Returns a training set generated from the 2D original data (time x features) and a PCA projection
to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
...
...
@@ -1287,7 +1287,7 @@ class table_dict(dict):
return
X
,
pca
def
tsne
(
self
,
n_components
:
int
=
None
,
sample
:
int
=
1000
,
perplexity
:
int
=
30
self
,
n_components
:
int
=
None
,
sample
:
int
=
1000
,
perplexity
:
int
=
30
)
->
deepof
.
utils
.
Tuple
[
deepof
.
utils
.
Any
,
deepof
.
utils
.
Any
]:
"""Returns a training set generated from the 2D original data (time x features) and a PCA projection
to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
...
...
@@ -1331,7 +1331,6 @@ def merge_tables(*args):
return
merged_tables
# TODO:
# - Generate ragged training array using a metric (acceleration, maybe?)
# - Use something like Dynamic Time Warping to put all instances in the same length
...
...
deepof/hypermodels.py
View file @
707ab97e
...
...
@@ -8,10 +8,11 @@ keras hypermodels for hyperparameter tuning of deep autoencoders
"""
import
tensorflow_probability
as
tfp
from
kerastuner
import
HyperModel
import
deepof.models
import
deepof.model_utils
import
tensorflow_probability
as
tfp
import
deepof.models
tfd
=
tfp
.
distributions
tfpl
=
tfp
.
layers
...
...
@@ -92,18 +93,18 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
"""Hyperparameter tuning pipeline for deepof.models.SEQ_2_SEQ_GMVAE"""
def
__init__
(
self
,
input_shape
:
tuple
,
encoding
:
int
,
kl_warmup_epochs
:
int
=
0
,
learn_rate
:
float
=
1e-3
,
loss
:
str
=
"ELBO+MMD"
,
mmd_warmup_epochs
:
int
=
0
,
number_of_components
:
int
=
10
,
overlap_loss
:
float
=
False
,
phenotype_predictor
:
float
=
0.0
,
predictor
:
float
=
0.0
,
prior
:
str
=
"standard_normal"
,
self
,
input_shape
:
tuple
,
encoding
:
int
,
kl_warmup_epochs
:
int
=
0
,
learn_rate
:
float
=
1e-3
,
loss
:
str
=
"ELBO+MMD"
,
mmd_warmup_epochs
:
int
=
0
,
number_of_components
:
int
=
10
,
overlap_loss
:
float
=
False
,
phenotype_predictor
:
float
=
0.0
,
predictor
:
float
=
0.0
,
prior
:
str
=
"standard_normal"
,
):
super
().
__init__
()
self
.
input_shape
=
input_shape
...
...
@@ -119,7 +120,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
self
.
prior
=
prior
assert
(
"ELBO"
in
self
.
loss
or
"MMD"
in
self
.
loss
"ELBO"
in
self
.
loss
or
"MMD"
in
self
.
loss
),
"loss must be one of ELBO, MMD or ELBO+MMD (default)"
def
get_hparams
(
self
,
hp
):
...
...
@@ -191,7 +192,6 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
return
gmvaep
# TODO:
# - We can add as many parameters as we want to the hypermodel!
# with this implementation, predictor, warmup, loss and even number of components can be tuned using BayOpt
...
...
deepof/model_utils.py
View file @
707ab97e
...
...
@@ -10,15 +10,16 @@ Functions and general utilities for the deepof tensorflow models. See documentat
from
itertools
import
combinations
from
typing
import
Any
,
Tuple
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow_probability
as
tfp
from
scipy.stats
import
entropy
from
sklearn.neighbors
import
NearestNeighbors
from
tensorflow.keras
import
backend
as
K
from
tensorflow.keras.constraints
import
Constraint
from
tensorflow.keras.layers
import
Layer
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow_probability
as
tfp
tfd
=
tfp
.
distributions
tfpl
=
tfp
.
layers
...
...
@@ -44,7 +45,7 @@ class exponential_learning_rate(tf.keras.callbacks.Callback):
def
find_learning_rate
(
model
,
X
,
y
,
epochs
=
1
,
batch_size
=
32
,
min_rate
=
10
**
-
5
,
max_rate
=
10
model
,
X
,
y
,
epochs
=
1
,
batch_size
=
32
,
min_rate
=
10
**
-
5
,
max_rate
=
10
):
"""Trains the provided model for an epoch with an exponentially increasing learning rate"""
...
...
@@ -123,9 +124,9 @@ def compute_mmd(tensors: Tuple[Any]) -> tf.Tensor:
y_kernel
=
compute_kernel
(
y
,
y
)
xy_kernel
=
compute_kernel
(
x
,
y
)
mmd
=
(
tf
.
reduce_mean
(
x_kernel
)
+
tf
.
reduce_mean
(
y_kernel
)
-
2
*
tf
.
reduce_mean
(
xy_kernel
)
tf
.
reduce_mean
(
x_kernel
)
+
tf
.
reduce_mean
(
y_kernel
)
-
2
*
tf
.
reduce_mean
(
xy_kernel
)
)
return
mmd
...
...
@@ -140,13 +141,13 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
"""
def
__init__
(
self
,
iterations
:
int
,
max_rate
:
float
,
start_rate
:
float
=
None
,
last_iterations
:
int
=
None
,
last_rate
:
float
=
None
,
log_dir
:
str
=
"."
,
self
,
iterations
:
int
,
max_rate
:
float
,
start_rate
:
float
=
None
,
last_iterations
:
int
=
None
,
last_rate
:
float
=
None
,
log_dir
:
str
=
"."
,
):
super
().
__init__
()
self
.
iterations
=
iterations
...
...
@@ -212,13 +213,13 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
"""
def
__init__
(
self
,
encoding_dim
:
int
,