Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
545d089a
Commit
545d089a
authored
Mar 29, 2021
by
lucas_miranda
Browse files
Changed all type() checks for isinstance() to take inheritance into account
parent
48ef8ea8
Changes
10
Show whitespace changes
Inline
Side-by-side
deepof/data.py
View file @
545d089a
...
...
@@ -40,8 +40,8 @@ import warnings
os
.
environ
[
"TF_CPP_MIN_LOG_LEVEL"
]
=
"2"
# DEFINE CUSTOM ANNOTATED TYPES #
Coordinates
=
deepof
.
utils
.
New
Typ
e
(
"Coordinates"
,
deepof
.
utils
.
Any
)
Table_dict
=
deepof
.
utils
.
New
Typ
e
(
"Table_dict"
,
deepof
.
utils
.
Any
)
Coordinates
=
deepof
.
utils
.
New
isinstanc
e
(
"Coordinates"
,
deepof
.
utils
.
Any
)
Table_dict
=
deepof
.
utils
.
New
isinstanc
e
(
"Table_dict"
,
deepof
.
utils
.
Any
)
# CLASSES FOR PREPROCESSING AND DATA WRANGLING
...
...
@@ -549,7 +549,7 @@ class coordinates:
-
self
.
_scales
[
i
][
1
]
/
2
)
elif
typ
e
(
center
)
==
str
and
center
!=
"arena"
:
elif
isinstanc
e
(
center
,
str
)
and
center
!=
"arena"
:
for
i
,
(
key
,
value
)
in
enumerate
(
tabs
.
items
()):
...
...
@@ -583,7 +583,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
as
typ
e
(
"timedelta64[s]"
)
).
as
isinstanc
e
(
"timedelta64[s]"
)
if
align
:
assert
(
...
...
@@ -667,7 +667,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
as
typ
e
(
"timedelta64[s]"
)
).
as
isinstanc
e
(
"timedelta64[s]"
)
if
propagate_labels
:
for
key
,
tab
in
tabs
.
items
():
...
...
@@ -732,7 +732,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
as
typ
e
(
"timedelta64[s]"
)
).
as
isinstanc
e
(
"timedelta64[s]"
)
if
propagate_labels
:
for
key
,
tab
in
tabs
.
items
():
...
...
@@ -833,7 +833,7 @@ class coordinates:
)
pbar
.
update
(
1
)
if
typ
e
(
video_output
)
==
list
:
if
isinstanc
e
(
video_output
,
list
)
:
vid_idxs
=
video_output
elif
video_output
==
"all"
:
vid_idxs
=
list
(
self
.
_tables
.
keys
())
...
...
deepof/pose_utils.py
View file @
545d089a
...
...
@@ -25,7 +25,7 @@ import warnings
warnings
.
filterwarnings
(
"ignore"
,
message
=
"All-NaN slice encountered"
)
# Create custom string type
Coordinates
=
New
Typ
e
(
"Coordinates"
,
Any
)
Coordinates
=
New
isinstanc
e
(
"Coordinates"
,
Any
)
def
close_single_contact
(
...
...
@@ -53,12 +53,12 @@ def close_single_contact(
close_contact
=
None
if
typ
e
(
right
)
==
str
:
if
isinstanc
e
(
right
,
str
)
:
close_contact
=
(
np
.
linalg
.
norm
(
pos_dframe
[
left
]
-
pos_dframe
[
right
],
axis
=
1
)
*
arena_abs
)
/
arena_rel
<
tol
elif
typ
e
(
right
)
==
list
:
elif
isinstanc
e
(
right
,
list
)
:
close_contact
=
np
.
any
(
[
(
np
.
linalg
.
norm
(
pos_dframe
[
left
]
-
pos_dframe
[
r
],
axis
=
1
)
*
arena_abs
)
...
...
@@ -528,7 +528,7 @@ def max_behaviour(
speeds
=
[
col
for
col
in
behaviour_dframe
.
columns
if
"speed"
in
col
.
lower
()]
behaviour_dframe
=
behaviour_dframe
.
drop
(
speeds
,
axis
=
1
).
as
typ
e
(
"float"
)
behaviour_dframe
=
behaviour_dframe
.
drop
(
speeds
,
axis
=
1
).
as
isinstanc
e
(
"float"
)
win_array
=
behaviour_dframe
.
rolling
(
window_size
,
center
=
True
).
sum
()
if
stepped
:
win_array
=
win_array
[::
window_size
]
...
...
@@ -678,8 +678,8 @@ def rule_based_tagging(
return
deepof
.
utils
.
smooth_boolean_array
(
close_single_contact
(
coords
,
(
left
if
typ
e
(
left
)
!=
list
else
right
),
(
right
if
typ
e
(
left
)
!=
list
else
left
),
(
left
if
not
isinstanc
e
(
left
,
list
)
else
right
),
(
right
if
not
isinstanc
e
(
left
,
list
)
else
left
),
params
[
"close_contact_tol"
],
arena_abs
,
arena
[
1
][
1
],
...
...
deepof/utils.py
View file @
545d089a
...
...
@@ -26,7 +26,7 @@ from typing import Tuple, Any, List, Union, NewType
# DEFINE CUSTOM ANNOTATED TYPES #
Coordinates
=
New
Typ
e
(
"Coordinates"
,
Any
)
Coordinates
=
New
isinstanc
e
(
"Coordinates"
,
Any
)
# CONNECTIVITY FOR DLC MODELS
...
...
@@ -750,7 +750,7 @@ def cluster_transition_matrix(
# Stores all possible transitions between clusters
clusters
=
[
str
(
i
)
for
i
in
range
(
nclusts
)]
cluster_sequence
=
cluster_sequence
.
as
typ
e
(
str
)
cluster_sequence
=
cluster_sequence
.
as
isinstanc
e
(
str
)
trans
=
{
t
:
0
for
t
in
product
(
clusters
,
clusters
)}
k
=
len
(
clusters
)
...
...
supplementary_notebooks/main.ipynb
View file @
545d089a
%% Cell type:code id: tags:
```
python
%
load_ext
autoreload
%
autoreload
2
```
%% Cell type:code id: tags:
```
python
import
os
os
.
chdir
(
os
.
path
.
dirname
(
"../"
))
```
%% Cell type:code id: tags:
```
python
import
cv2
import
deepof.data
import
deepof.models
import
matplotlib.pyplot
as
plt
from
mpl_toolkits.mplot3d
import
Axes3D
import
numpy
as
np
import
pandas
as
pd
import
re
import
seaborn
as
sns
from
sklearn.preprocessing
import
StandardScaler
,
MinMaxScaler
import
tensorflow
as
tf
import
tqdm.notebook
as
tqdm
from
ipywidgets
import
interact
```
%% Cell type:code id: tags:
```
python
from
sklearn.manifold
import
TSNE
from
sklearn.decomposition
import
PCA
from
sklearn.discriminant_analysis
import
LinearDiscriminantAnalysis
```
%% Cell type:code id: tags:
```
python
import
umap
```
%% Cell type:markdown id: tags:
# Retrieve phenotypes
%% Cell type:code id: tags:
```
python
flatten
=
lambda
t
:
[
item
for
sublist
in
t
for
item
in
sublist
]
```
%% Cell type:code id: tags:
```
python
# Load first batch
dset11
=
pd
.
ExcelFile
(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_1/DLC_single_CDR1_1/1.Openfield_data-part1/JB05.1-OF-SI-part1.xlsx"
)
dset12
=
pd
.
ExcelFile
(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_1/DLC_single_CDR1_1/2.Openfielddata-part2/AnimalID's-JB05.1-part2.xlsx"
)
dset11
=
pd
.
read_excel
(
dset11
,
"Tabelle2"
)
dset12
=
pd
.
read_excel
(
dset12
,
"Tabelle2"
)
dset11
.
Test
=
dset11
.
Test
.
apply
(
lambda
x
:
"Test {}_s11"
.
format
(
x
))
dset12
.
Test
=
dset12
.
Test
.
apply
(
lambda
x
:
"Test {}_s12"
.
format
(
x
))
dset1
=
{
"CSDS"
:
list
(
dset11
.
loc
[
dset11
.
Treatment
.
isin
([
"CTR+CSDS"
,
"NatCre+CSDS"
]),
"Test"
])
+
list
(
dset12
.
loc
[
dset12
.
Treatment
.
isin
([
"CTR+CSDS"
,
"NatCre+CSDS"
]),
"Test"
]),
"NS"
:
list
(
dset11
.
loc
[
dset11
.
Treatment
.
isin
([
"CTR+nonstressed"
,
"NatCre+nonstressed"
]),
"Test"
])
+
list
(
dset12
.
loc
[
dset12
.
Treatment
.
isin
([
"CTR+nonstressed"
,
"NatCre+nonstressed"
]),
"Test"
]),}
dset1inv
=
{}
for
i
in
flatten
(
list
(
dset1
.
values
())):
if
i
in
dset1
[
"CSDS"
]:
dset1inv
[
i
]
=
"CSDS"
else
:
dset1inv
[
i
]
=
"NS"
assert
len
(
dset1inv
)
==
dset11
.
shape
[
0
]
+
dset12
.
shape
[
0
],
"You missed some labels!"
```
%% Cell type:code id: tags:
```
python
# Load second batch
dset21
=
pd
.
read_excel
(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part1/2_Single/stressproject22.04.2020genotypes-openfieldday1.xlsx"
)
dset22
=
pd
.
read_excel
(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_2/Part2/2_Single/OpenFieldvideos-part2.xlsx"
)
dset21
.
Test
=
dset21
.
Test
.
apply
(
lambda
x
:
"Test {}_s21"
.
format
(
x
))
dset22
.
Test
=
dset22
.
Test
.
apply
(
lambda
x
:
"Test {}_s22"
.
format
(
x
))
dset2
=
{
"CSDS"
:
list
(
dset21
.
loc
[
dset21
.
Treatment
==
"Stress"
,
"Test"
])
+
list
(
dset22
.
loc
[
dset22
.
Treatment
==
"Stressed"
,
"Test"
]),
"NS"
:
list
(
dset21
.
loc
[
dset21
.
Treatment
==
"Nonstressed"
,
"Test"
])
+
list
(
dset22
.
loc
[
dset22
.
Treatment
==
"Nonstressed"
,
"Test"
])}
dset2inv
=
{}
for
i
in
flatten
(
list
(
dset2
.
values
())):
if
i
in
dset2
[
"CSDS"
]:
dset2inv
[
i
]
=
"CSDS"
else
:
dset2inv
[
i
]
=
"NS"
assert
len
(
dset2inv
)
==
dset21
.
shape
[
0
]
+
dset22
.
shape
[
0
],
"You missed some labels!"
```
%% Cell type:code id: tags:
```
python
# Load third batch
dset31
=
pd
.
read_excel
(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/1.Day2OF-SIpart1/JB05 2Female-ELS-OF-SIpart1.xlsx"
,
sheet_name
=
1
)
dset32
=
pd
.
read_excel
(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_3/2.Day3OF-SIpart2/JB05 2FEMALE-ELS-OF-SIpart2.xlsx"
,
sheet_name
=
1
)
dset31
.
Test
=
dset31
.
Test
.
apply
(
lambda
x
:
"Test {}_s31"
.
format
(
x
))
dset32
.
Test
=
dset32
.
Test
.
apply
(
lambda
x
:
"Test {}_s32"
.
format
(
x
))
dset3
=
{
"CSDS"
:[],
"NS"
:
list
(
dset31
.
loc
[:,
"Test"
])
+
list
(
dset32
.
loc
[:,
"Test"
])}
dset3inv
=
{}
for
i
in
flatten
(
list
(
dset3
.
values
())):
if
i
in
dset3
[
"CSDS"
]:
dset3inv
[
i
]
=
"CSDS"
else
:
dset3inv
[
i
]
=
"NS"
assert
len
(
dset3inv
)
==
dset31
.
shape
[
0
]
+
dset32
.
shape
[
0
],
"You missed some labels!"
```
%% Cell type:code id: tags:
```
python
# Load fourth batch
dset41
=
os
.
listdir
(
"../../Desktop/deepof-data/tagged_videos/Individual_datasets/DLC_batch_4/JB05.4-OpenFieldvideos/"
)
# Remove empty video!
dset41
=
[
vid
for
vid
in
dset41
if
"52"
not
in
vid
]
dset4
=
{
"CSDS"
:[],
"NS"
:
[
i
[:
-
4
]
+
"_s41"
for
i
in
dset41
]}
dset4inv
=
{}
for
i
in
flatten
(
list
(
dset4
.
values
())):
if
i
in
dset4
[
"CSDS"
]:
dset4inv
[
i
]
=
"CSDS"
else
:
dset4inv
[
i
]
=
"NS"
assert
len
(
dset4inv
)
==
len
(
dset41
),
"You missed some labels!"
```
%% Cell type:code id: tags:
```
python
# Merge phenotype dicts and serialise!
aggregated_dset
=
{
**
dset1inv
,
**
dset2inv
,
**
dset3inv
,
**
dset4inv
}
```
%% Cell type:code id: tags:
```
python
from
collections
import
Counter
print
(
Counter
(
aggregated_dset
.
values
()))
print
(
115
+
52
)
```
%%%% Output: stream
Counter({'NS': 115, 'CSDS': 52})
167
%% Cell type:code id: tags:
```
python
# Save aggregated dataset to disk
import
pickle
with
open
(
"../../Desktop/deepof-data/deepof_single_topview/deepof_exp_conditions.pkl"
,
"wb"
)
as
handle
:
pickle
.
dump
(
aggregated_dset
,
handle
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
```
%% Cell type:markdown id: tags:
# Define and run project
%% Cell type:code id: tags:
```
python
%%
time
deepof_main
=
deepof
.
data
.
project
(
path
=
os
.
path
.
join
(
".."
,
".."
,
"Desktop"
,
"deepoftesttemp"
),
smooth_alpha
=
0.99
,
arena_dims
=
[
380
],
exclude_bodyparts
=
[
"Tail_1"
,
"Tail_2"
,
"Tail_tip"
,
"Tail_base"
],
exp_conditions
=
aggregated_dset
)
```
%%%% Output: stream
CPU times: user 111 ms, sys: 14 ms, total: 125 ms
Wall time: 123 ms
%% Cell type:code id: tags:
```
python
%%
time
deepof_main
=
deepof_main
.
run
(
verbose
=
True
)
print
(
deepof_main
)
```
%%%% Output: stream
Loading trajectories...
Smoothing trajectories...
Interpolating outliers...
Iterative imputation of ocluded bodyparts...
Computing distances...
Computing angles...
Done!
Coordinates of 2 videos across 2 conditions
CPU times: user 4.8 s, sys: 806 ms, total: 5.61 s
Wall time: 4.32 s
%% Cell type:code id: tags:
```
python
all_quality
=
pd
.
concat
([
tab
for
tab
in
deepof_main
.
get_quality
().
values
()])
```
%% Cell type:code id: tags:
```
python
all_quality
.
boxplot
(
rot
=
45
)
plt
.
ylim
(
0.99985
,
1.00001
)
plt
.
show
()
```
%%%% Output: display_data

%% Cell type:code id: tags:
```
python
@
interact
(
quality_top
=
(
0.
,
1.
,
0.01
))
def
low_quality_tags
(
quality_top
):
pd
.
DataFrame
(
pd
.
melt
(
all_quality
).
groupby
(
"bodyparts"
).
value
.
apply
(
lambda
y
:
sum
(
y
<
quality_top
)
/
len
(
y
)
*
100
)
).
sort_values
(
by
=
"value"
,
ascending
=
False
).
plot
.
bar
(
rot
=
45
)
plt
.
xlabel
(
"body part"
)
plt
.
ylabel
(
"Tags with quality under {} (%)"
.
format
(
quality_top
*
100
))
plt
.
tight_layout
()
plt
.
legend
([])
plt
.
show
()
```
%%%% Output: display_data
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
```
python
%%
time
deepof_coords
=
deepof_main
.
get_coords
(
center
=
"Center"
,
polar
=
False
,
speed
=
0
,
align
=
"Spine_1"
,
align_inplace
=
True
,
propagate_labels
=
False
)
#deepof_dists = deepof_main.get_distances(propagate_labels=False)
#deepof_angles = deepof_main.get_angles(propagate_labels=False)
```
%%%% Output: stream
CPU times: user 624 ms, sys: 27 ms, total: 651 ms
Wall time: 662 ms
%% Cell type:markdown id: tags:
# Visualization
%% Cell type:code id: tags:
```
python
%%
time
tf
.
keras
.
backend
.
clear_session
()
print
(
"Preprocessing training set..."
)
deepof_train
=
deepof_coords
.
preprocess
(
window_size
=
24
,
window_step
=
24
,
conv_filter
=
None
,
scale
=
"standard"
,
shuffle
=
False
,
test_videos
=
0
,
)[
0
]
# print("Loading pre-trained model...")
# encoder, decoder, grouper, gmvaep, = deepof.models.SEQ_2_SEQ_GMVAE(
# loss="ELBO",
# number_of_components=20,
# compile_model=True,
# kl_warmup_epochs=20,
# montecarlo_kl=10,
# encoding=6,
# mmd_warmup_epochs=20,
# predictor=0,
# phenotype_prediction=0,
# ).build(deepof_train.shape)[:4]
```
%%%% Output: stream
Preprocessing training set...
CPU times: user 18.1 ms, sys: 13 ms, total: 31.1 ms
Wall time: 37.4 ms
%% Cell type:code id: tags:
```
python
weights
=
[
"./latreg_trained_weights/"
+
i
for
i
in
os
.
listdir
(
"./latreg_trained_weights/"
)
if
"encoding=8"
in
i
]
weights
```
%%%% Output: execute_result
['./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=none_20210212-021944_final_weights.h5',
'./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=categorical_20210212-031749_final_weights.h5',
'./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=categorical+variance_20210212-022008_final_weights.h5',
'./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=variance_20210212-023839_final_weights.h5']
%% Cell type:code id: tags:
```
python
trained_network
=
weights
[
2
]
print
(
trained_network
)
l
=
int
(
re
.
findall
(
"encoding=(\d+)_"
,
trained_network
)[
0
])
k
=
int
(
re
.
findall
(
"k=(\d+)_"
,
trained_network
)[
0
])
pheno
=
0
encoder
,
decoder
,
grouper
,
gmvaep
,
=
deepof
.
models
.
SEQ_2_SEQ_GMVAE
(
loss
=
"ELBO"
,
number_of_components
=
k
,
compile_model
=
True
,
kl_warmup_epochs
=
20
,
montecarlo_kl
=
10
,
encoding
=
l
,
mmd_warmup_epochs
=
20
,
predictor
=
0
,
phenotype_prediction
=
pheno
,
reg_cat_clusters
=
(
"categorical"
in
trained_network
),
reg_cluster_variance
=
(
"variance"
in
trained_network
),
).
build
(
deepof_train
.
shape
)[:
4
]
gmvaep
.
load_weights
(
trained_network
)
```
%%%% Output: stream
./latreg_trained_weights/GMVAE_loss=ELBO_encoding=8_k=25_latreg=categorical+variance_20210212-022008_final_weights.h5
%% Cell type:code id: tags:
```
python
# Get data to pass through the models
trained_distribution
=
encoder
(
deepof_train
)
categories
=
tf
.
keras
.
models
.
Model
(
encoder
.
input
,
encoder
.
layers
[
15
].
output
)(
deepof_train
).
numpy
()
# Fit a scaler to unscale the reconstructions later on
video_key
=
np
.
random
.
choice
(
list
(
deepof_coords
.
keys
()),
1
)[
0
]
scaler
=
StandardScaler
()
scaler
.
fit
(
np
.
array
(
pd
.
concat
(
list
(
deepof_coords
.
values
()))))
```
%%%% Output: execute_result
StandardScaler()
%% Cell type:code id: tags:
```
python
# Retrieve latent distribution parameters and sample from posterior
def
get_median_params
(
component
,
categories
,
cluster
,
param
):
# means = [np.median(component.mean().numpy(), axis=0) for component in mix_components]
# stddevs = [np.median(component.stddev().numpy(), axis=0) for component in mix_components]
if
param
==
"mean"
:
component
=
component
.
mean
().
numpy
()
elif
param
==
"stddev"
:
component
=
component
.
stddev
().
numpy
()
cluster_select
=
np
.
argmax
(
categories
,
axis
=
1
)
==
cluster
if
np
.
sum
(
cluster_select
)
==
0
:
return
None
component
=
component
[
cluster_select
]
return
np
.
median
(
component
,
axis
=
0
)
```
%% Cell type:code id: tags:
```
python
def
retrieve_latent_parameters
(
distribution
,
reduce
=
False
,
plot
=
False
,
categories
=
None
,
filt
=
0
,
save
=
True
):
mix_components
=
distribution
.
components
# The main problem is here! We need to select only those training instances in which a given cluster was selected.
# Then compute the median for those only
means
=
[
get_median_params
(
component
,
categories
,
i
,
"mean"
)
for
i
,
component
in
enumerate
(
mix_components
)]
stddevs
=
[
get_median_params
(
component
,
categories
,
i
,
"stddev"
)
for
i
,
component
in
enumerate
(
mix_components
)]
means
=
[
i
for
i
in
means
if
i
is
not
None
]
stddevs
=
[
i
for
i
in
stddevs
if
i
is
not
None
]
if
filter
:
filts
=
np
.
max
(
categories
,
axis
=
0
)
>
filt
means
=
[
i
for
i
,
j
in
zip
(
means
,
filts
)
if
j
]
stddevs
=
[
i
for
i
,
j
in
zip
(
stddevs
,
filts
)
if
j
]
if
reduce
:
data
=
[
np
.
random
.
normal
(
size
=
[
1000
,
len
(
means
[
0
])],
loc
=
meanvec
,
scale
=
stddevvec
)[:,
np
.
newaxis
]
for
meanvec
,
stddevvec
in
zip
(
means
,
stddevs
)]
data
=
np
.
concatenate
(
data
,
axis
=
1
).
reshape
([
1000
*
len
(
means
),
len
(
means
[
0
])])
reducer
=
PCA
(
n_components
=
3
)
data
=
reducer
.
fit_transform
(
data
)
data
=
data
.
reshape
([
1000
,
len
(
means
),
3
])
if
plot
==
2
:
for
i
in
range
(
len
(
means
)):
plt
.
scatter
(
data
[:,
i
,
0
],
data
[:,
i
,
1
],
label
=
i
)
plt
.
title
(
"Mean representation of latent space - K={}/{} - L={} - filt={}"
.
format
(
len
(
means
),
len
(
mix_components
),
len
(
means
[
0
]),
filt
))
plt
.
xlabel
(
"PCA 1"
)
plt
.
ylabel
(
"PCA 2"
)
#plt.legend()
if
save
:
plt
.
savefig
(
"Mean representation of latent space - K={}.{} - L={} - filt={}.png"
.
format
(
len
(
means
),
len
(
mix_components
),
len
(
means
[
0
]),
filt
).
replace
(
" "
,
"_"
))
plt
.
show
()
elif
plot
==
3
:
fig
=
plt
.
figure
()
ax
=
fig
.
add_subplot
(
111
,
projection
=
'3d'
)
for
i
in
range
(
len
(
means
)):
ax
.
scatter
(
data
[:,
i
,
0
],
data
[:,
i
,
1
],
data
[:,
i
,
2
],
label
=
i
)
plt
.
title
(
"Mean representation of latent space - K={}/{} - L={} - filt={}"
.
format
(
len
(
means
),
len
(
mix_components
),
len
(
means
[
0
]),
filt
))
ax
.
set_xlabel
(
"PCA 1"
)
ax
.
set_ylabel
(
"PCA 2"
)
ax
.
set_zlabel
(
"PCA 3"
)
#plt.legend()
if
save
:
plt
.
savefig
(
"Mean representation of latent space - K={}.{} - L={} - filt={}.png"
.
format
(
len
(
means
),
len
(
mix_components
),
len
(
means
[
0
]),
filt
).
replace
(
" "
,
"_"
))
plt
.
show
()
elif
plot
>
3
:
raise
ValueError
(
"Can't plot in more than 3 dimensions!"
)
return
means
,
stddevs
def
sample_from_posterior
(
decoder
,
parameters
,
component
,
enable_variance
=
False
,
video_output
=
False
,
samples
=
1
):
means
,
stddevs
=
parameters
sample
=
np
.
random
.
normal
(
size
=
[
samples
,
len
(
means
[
component
])],
loc
=
means
[
component
],
scale
=
(
stddevs
[
component
]
if
enable_variance
else
0
))