Commit bc2989c1 authored by lucas_miranda's avatar lucas_miranda
Browse files

GMVAEP now returns prior and posterior distributions

parent 78db4095
......@@ -624,7 +624,7 @@ class SEQ_2_SEQ_GMVAE:
name="encoding_distribution",
)([z_cat, z_gauss])
encode_to_distribution = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
# Define and control custom loss functions
if "ELBO" in self.loss:
......@@ -679,7 +679,7 @@ class SEQ_2_SEQ_GMVAE:
)(generator)
# define individual branches as models
encode_to_vector = Model(x, z, name="SEQ_2_SEQ_VEncoder")
encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
generator = Model(g, x_decoded_mean, name="vae_reconstruction")
def log_loss(x_true, p_x_q_given_z):
......@@ -687,7 +687,7 @@ class SEQ_2_SEQ_GMVAE:
the output distribution"""
return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))
model_outs = [generator(encode_to_vector.outputs)]
model_outs = [generator(encoder.outputs)]
model_losses = [log_loss]
model_metrics = {"vae_reconstruction": ["mae", "mse"]}
loss_weights = [1.0]
......@@ -736,9 +736,9 @@ class SEQ_2_SEQ_GMVAE:
loss_weights.append(self.phenotype_prediction)
# define grouper and end-to-end autoencoder model
grouper = Model(encode_to_vector.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
gmvaep = Model(
inputs=encode_to_vector.inputs,
inputs=encoder.inputs,
outputs=model_outs,
name="SEQ_2_SEQ_GMVAE",
)
......@@ -754,11 +754,12 @@ class SEQ_2_SEQ_GMVAE:
gmvaep.build(input_shape)
return (
encode_to_vector,
encode_to_distribution,
encoder,
generator,
grouper,
gmvaep,
self.prior,
posterior,
)
@prior.setter
......
%% Cell type:code id: tags:
 
``` python
%load_ext autoreload
%autoreload 2
```
 
%% Cell type:code id: tags:
 
``` python
import warnings
 
warnings.filterwarnings("ignore")
```
 
%% Cell type:markdown id: tags:
 
# deepOF model evaluation
 
%% Cell type:markdown id: tags:
 
Given a dataset and a trained model, this notebook allows the user to
 
* Load and inspect the different models (encoder, decoder, grouper, gmvaep)
* Visualize reconstruction quality for a given model
* Visualize a static latent space
* Visualize trajectories on the latent space for a given video
* Sample from the latent space distributions and generate video clips showcasing generated data
 
%% Cell type:code id: tags:
 
``` python
import os
 
os.chdir(os.path.dirname("../"))
```
 
%% Cell type:code id: tags:
 
``` python
import deepof.data
import deepof.utils
import numpy as np
import pandas as pd
import re
import tensorflow as tf
from collections import Counter
from sklearn.preprocessing import StandardScaler
 
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import umap
 
from ipywidgets import interactive, interact, HBox, Layout, VBox
from IPython import display
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import seaborn as sns
 
from ipywidgets import interact
```
 
%% Cell type:markdown id: tags:
 
### 1. Define and run project
 
%% Cell type:code id: tags:
 
``` python
path = os.path.join("..", "..", "Desktop", "deepoftesttemp")
trained_network = os.path.join("..", "..", "Desktop", "trained_weights")
exclude_bodyparts = tuple([""])
window_size = 24
```
 
%% Cell type:code id: tags:
 
``` python
%%time
proj = deepof.data.project(
path=path, smooth_alpha=0.999, exclude_bodyparts=exclude_bodyparts, arena_dims=[380],
)
```
 
%% Output
 
CPU times: user 298 ms, sys: 24 ms, total: 322 ms
Wall time: 276 ms
 
%% Cell type:code id: tags:
 
``` python
%%time
proj = proj.run(verbose=True)
print(proj)
```
 
%% Output
 
Loading trajectories...
Smoothing trajectories...
Interpolating outliers...
Iterative imputation of ocluded bodyparts...
Computing distances...
Computing angles...
Done!
deepof analysis of 2 videos
CPU times: user 50.3 s, sys: 429 ms, total: 50.8 s
Wall time: 7.09 s
 
%% Cell type:markdown id: tags:
 
### 2. Load pretrained deepof model
 
%% Cell type:code id: tags:
 
``` python
coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)
data_prep = coords.preprocess(test_videos=0, window_step=1, window_size=window_size, shuffle=True)[
0
]
```
 
%% Cell type:code id: tags:
 
``` python
[i for i in os.listdir(trained_network) if i.endswith("h5")]
```
 
%% Output
 
['GMVAE_loss=MMD_encoding=10_k=25_latreg=categorical+variance_entknn=100_final_weights.h5',
'GMVAE_loss=ELBO_encoding=12_k=25_latreg=variance_entknn=50_final_weights.h5',
'GMVAE_loss=ELBO+MMD_encoding=10_k=25_latreg=variance_entknn=50_final_weights.h5',
'GMVAE_loss=ELBO_encoding=8_k=25_latreg=none_entknn=80_final_weights.h5',
'GMVAE_loss=ELBO_encoding=6_k=25_latreg=categorical+variance_entknn=20_final_weights.h5',
'GMVAE_loss=ELBO+MMD_encoding=2_k=25_latreg=none_entknn=80_final_weights.h5',
'GMVAE_loss=MMD_encoding=12_k=25_latreg=categorical+variance_entknn=20_final_weights.h5',
'GMVAE_loss=ELBO+MMD_encoding=16_k=25_latreg=categorical_entknn=80_final_weights.h5',
'GMVAE_loss=ELBO_encoding=8_k=25_latreg=variance_entknn=80_final_weights.h5',
'GMVAE_loss=MMD_encoding=6_k=25_latreg=none_entknn=50_final_weights.h5',
'GMVAE_loss=ELBO+MMD_encoding=16_k=25_latreg=categorical+variance_entknn=80_final_weights.h5',
'GMVAE_loss=ELBO+MMD_encoding=4_k=25_latreg=categorical_entknn=80_final_weights.h5',
'GMVAE_loss=ELBO+MMD_encoding=14_k=25_latreg=categorical+variance_entknn=80_final_weights.h5',
'GMVAE_loss=MMD_encoding=12_k=25_latreg=categorical_entknn=20_final_weights.h5',
'GMVAE_loss=ELBO+MMD_encoding=6_k=25_latreg=categorical+variance_entknn=100_final_weights.h5']
 
%% Cell type:code id: tags:
 
``` python
deepof_weights = [i for i in os.listdir(trained_network) if i.endswith("h5")][-1]
deepof_weights
```
 
%% Output
 
'GMVAE_loss=ELBO+MMD_encoding=6_k=25_latreg=categorical+variance_entknn=100_final_weights.h5'
 
%% Cell type:code id: tags:
 
``` python
# Set model parameters
encoding = int(re.findall("encoding=(\d+)_", deepof_weights)[0])
k = int(re.findall("k=(\d+)_", deepof_weights)[0])
loss = re.findall("loss=(.+?)_", deepof_weights)[0]
pheno = 0
predictor = 0
```
 
%% Cell type:code id: tags:
 
``` python
# Build the GMVAE and load pretrained weights.
# NOTE(review): the commit this notebook accompanies changes
# SEQ_2_SEQ_GMVAE.build() to return
# (encoder, generator, grouper, gmvaep, prior, posterior) — six values —
# so this five-name unpacking (with the old names) may now raise a
# ValueError. Confirm against deepof.models before running.
encode_to_vector, encode_to_distribution, decoder, grouper, gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(
    loss=loss,
    number_of_components=k,
    compile_model=True,
    encoding=encoding,
    predictor=predictor,
    phenotype_prediction=pheno,
).build(data_prep.shape)

gmvaep.load_weights(os.path.join(trained_network, deepof_weights))
 
%% Cell type:code id: tags:
 
``` python
# Uncomment to see model summaries
# encoder.summary()
# decoder.summary()
# grouper.summary()
# gmvaep.summary()
```
 
%% Cell type:code id: tags:
 
``` python
# Uncomment to plot model structure
def plot_model(model, name):
    """Render a keras model's architecture to a timestamped PNG under `path`.

    Parameters
    ----------
    model : tf.keras.Model
        Model whose layer graph should be plotted.
    name : str
        Label embedded in the output file name.
    """
    # BUGFIX: `datetime` is never imported in the notebook's import cell,
    # so referencing it below raised a NameError. Import it locally.
    from datetime import datetime

    tf.keras.utils.plot_model(
        model,
        to_file=os.path.join(
            path,
            "deepof_{}_{}.png".format(name, datetime.now().strftime("%Y%m%d-%H%M%S")),
        ),
        show_shapes=True,
        show_dtype=False,
        show_layer_names=True,
        rankdir="TB",
        expand_nested=True,
        dpi=200,
    )
 
 
# plot_model(encoder, "encoder")
# plot_model(decoder, "decoder")
# plot_model(grouper, "grouper")
# plot_model(gmvaep, "gmvaep")
```
 
%% Cell type:markdown id: tags:
 
### 4. Evaluate reconstruction (to be incorporated into deepof.evaluate)
 
%% Cell type:code id: tags:
 
``` python
# Auxiliary animation functions
 
 
def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):
    """Generates a graph plot of the mouse"""
    # Draw one original (blue) and one reconstructed (red) segment per body
    # edge, keeping the artists so the animation loop can update them later.
    plots = []
    rec_plots = []
    for node_a, node_b in edges:
        orig_segment = ax.plot(
            [float(instant_x[node_a]), float(instant_x[node_b])],
            [float(instant_y[node_a]), float(instant_y[node_b])],
            color="#006699",
            linewidth=2.0,
        )[0]
        rec_segment = ax.plot(
            [float(instant_rec_x[node_a]), float(instant_rec_x[node_b])],
            [float(instant_rec_y[node_a]), float(instant_rec_y[node_b])],
            color="red",
            linewidth=2.0,
        )[0]
        plots.append(orig_segment)
        rec_plots.append(rec_segment)
    return plots, rec_plots
 
 
def update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges):
    """Updates the graph plot to enable animation"""

    # Refresh the original artists first, then the reconstruction artists,
    # rewriting each edge's endpoints in place for the current frame.
    for artists, xs, ys in ((plots, x, y), (rec_plots, rec_x, rec_y)):
        for artist, (node_a, node_b) in zip(artists, edges):
            artist.set_data(
                [float(xs[node_a]), float(xs[node_b])],
                [float(ys[node_a]), float(ys[node_b])],
            )
```
 
%% Cell type:code id: tags:
 
``` python
# Display a video with the original data superimposed with the reconstructions
 
coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)
random_exp = np.random.choice(list(coords.keys()), 1)[0]
print(random_exp)
 
 
def animate_mice_across_time(random_exp):
    """Render an HTML5 video overlaying the original tracked mouse (blue)
    with the autoencoder reconstruction (red) for a single experiment.

    Parameters
    ----------
    random_exp : str
        Key of the experiment/video to animate (one of ``coords.keys()``).

    Uses the notebook globals ``coords``, ``gmvaep``, ``exclude_bodyparts``
    and ``window_size``.
    """

    # Define canvas
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    # Retrieve body graph
    edges = deepof.utils.connect_mouse_topview()

    for bpart in exclude_bodyparts:
        if bpart:
            edges.remove_node(bpart)

    # Drop hip-limb edges that are not part of the displayed skeleton
    for limb in ["Left_fhip", "Right_fhip", "Left_bhip", "Right_bhip"]:
        edges.remove_edge("Center", limb)
        if ("Tail_base", limb) in edges.edges():
            edges.remove_edge("Tail_base", limb)

    edges = edges.edges()

    # Compute observed and predicted data to plot
    data = coords[random_exp]
    coords_rec = coords.filter_videos([random_exp])
    data_prep = coords_rec.preprocess(
        test_videos=0, window_step=1, window_size=window_size, shuffle=False
    )[0]

    # Reconstruct with the autoencoder, slice one frame per sliding window
    # and undo the training-time scaling.
    # NOTE(review): frame index 6 presumably targets a fixed offset inside
    # the 24-frame window — confirm why 6 rather than window_size // 2.
    data_rec = gmvaep.predict(data_prep)
    data_rec = pd.DataFrame(coords_rec._scaler.inverse_transform(data_rec[:, 6, :]))
    data_rec.columns = data.columns
    data = pd.DataFrame(coords_rec._scaler.inverse_transform(data_prep[:, 6, :]))
    data.columns = data_rec.columns

    # Add Central coordinate, lost during alignment
    data["Center", "x"] = 0
    data["Center", "y"] = 0
    data_rec["Center", "x"] = 0
    data_rec["Center", "y"] = 0

    # Plot!
    init_x = data.xs("x", level=1, axis=1, drop_level=False).iloc[0, :]
    init_y = data.xs("y", level=1, axis=1, drop_level=False).iloc[0, :]
    init_rec_x = data_rec.xs("x", level=1, axis=1, drop_level=False).iloc[0, :]
    init_rec_y = data_rec.xs("y", level=1, axis=1, drop_level=False).iloc[0, :]

    plots, rec_plots = plot_mouse_graph(
        init_x, init_y, init_rec_x, init_rec_y, ax, edges
    )
    scatter = ax.scatter(
        x=np.array(init_x), y=np.array(init_y), color="#006699", label="Original"
    )
    rec_scatter = ax.scatter(
        x=np.array(init_rec_x),
        y=np.array(init_rec_y),
        color="red",
        label="Reconstruction",
    )

    # Update data in main plot
    def animation_frame(i):
        # Update scatter plot
        x = data.xs("x", level=1, axis=1, drop_level=False).iloc[i, :]
        y = data.xs("y", level=1, axis=1, drop_level=False).iloc[i, :]
        rec_x = data_rec.xs("x", level=1, axis=1, drop_level=False).iloc[i, :]
        rec_y = data_rec.xs("y", level=1, axis=1, drop_level=False).iloc[i, :]

        scatter.set_offsets(np.c_[np.array(x), np.array(y)])
        rec_scatter.set_offsets(np.c_[np.array(rec_x), np.array(rec_y)])
        update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges)

        return scatter

    animation = FuncAnimation(fig, func=animation_frame, frames=250, interval=75,)

    ax.set_title("Original versus reconstructed data")
    ax.set_ylim(-100, 60)
    ax.set_xlim(-60, 60)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    plt.legend()

    # Embed the rendered animation as an HTML5 video in the notebook output
    video = animation.to_html5_video()
    html = display.HTML(video)
    display.display(html)
    plt.close()
 
 
animate_mice_across_time(random_exp)
```
 
%% Output
 
Test 1_s12
 
 
%% Cell type:markdown id: tags:
 
### 5. Evaluate latent space (to be incorporated into deepof.evaluate)
 
%% Cell type:code id: tags:
 
``` python
# Get encodings and groupings for the same random video as above
data_prep = coords.preprocess(
test_videos=0, window_step=1, window_size=window_size, shuffle=False
)[0]
 
encodings = encode_to_vector.predict(data_prep)
groupings = grouper.predict(data_prep)
hard_groups = np.argmax(groupings, axis=1)
```
 
%% Cell type:code id: tags:
 
``` python
@interact(minimum_confidence=(0.0, 1.0, 0.01))
def plot_cluster_population(minimum_confidence):
    """Bar-plot how many training instances fall into each cluster, keeping
    only assignments whose confidence exceeds the slider value."""
    plt.figure(figsize=(12, 8))

    confident = np.max(groupings, axis=1) > minimum_confidence
    cluster_ids = hard_groups[confident].flatten()
    # Append one instance of every cluster id so all 25 bars always appear.
    cluster_ids = np.concatenate([cluster_ids, np.arange(25)])
    sns.countplot(cluster_ids)
    plt.xlabel("Cluster")
    plt.title("Training instances per cluster")
    plt.show()
```
 
%% Output
 
 
%% Cell type:markdown id: tags:
 
The slider above sets the minimum confidence with which the model must assign a training instance to a cluster for that instance to be included in the plot.
 
%% Cell type:code id: tags:
 
``` python
# Plot real data in the latent space
 
 
@interact(
    samples=(1000, 10000, 500),
    minimum_confidence=(0.0, 0.99, 0.01),
    dim_red=["PCA", "LDA", "umap", "tSNE"],
)
def plot_static_latent_space(samples, minimum_confidence, dim_red):
    """Scatter a random subsample of the trained latent space in 2D.

    Instances are kept only when their highest cluster confidence exceeds
    `minimum_confidence`; hue encodes the hard cluster assignment and point
    size the assignment confidence.
    """
    if dim_red == "LDA":
        reducer = LinearDiscriminantAnalysis(n_components=2)
    elif dim_red == "PCA":
        reducer = PCA(n_components=2)
    elif dim_red == "umap":
        reducer = umap.UMAP(n_components=2)
    else:
        reducer = TSNE(n_components=2)

    confident = np.max(groupings, axis=1) > minimum_confidence
    encods = encodings[confident]
    groups = groupings[confident]
    hgroups = hard_groups[confident].flatten()

    picked = np.random.choice(range(encods.shape[0]), samples)
    sample_enc = encods[picked, :]
    sample_grp = groups[picked, :]
    sample_hgr = hgroups[picked]

    # LDA is supervised and takes the hard labels as a second argument
    if dim_red == "LDA":
        enc = reducer.fit_transform(sample_enc, sample_hgr)
    else:
        enc = reducer.fit_transform(sample_enc)

    plt.figure(figsize=(12, 8))

    sns.scatterplot(
        x=enc[:, 0],
        y=enc[:, 1],
        hue=sample_hgr,
        size=np.max(sample_grp, axis=1),
        sizes=(1, 100),
        palette="muted",
    )
    plt.xlabel("{} 1".format(dim_red))
    plt.ylabel("{} 2".format(dim_red))
    plt.suptitle("Static view of trained latent space")
    plt.show()
```
 
%% Output
 
 
%% Cell type:code id: tags:
 
``` python
def plot_mouse_graph(instant_x, instant_y, ax, edges):
    """Generates a graph plot of the mouse"""
    # One line artist per skeleton edge; returned so the caller can animate.
    segments = []
    for node_a, node_b in edges:
        segment = ax.plot(
            [float(instant_x[node_a]), float(instant_x[node_b])],
            [float(instant_y[node_a]), float(instant_y[node_b])],
            color="#006699",
            linewidth=2.0,
        )[0]
        segments.append(segment)
    return segments
 
 
def update_mouse_graph(x, y, plots, edges):
    """Updates the graph plot to enable animation"""

    # Rewrite each edge artist's endpoints in place with the current frame.
    for artist, (node_a, node_b) in zip(plots, edges):
        artist.set_data(
            [float(x[node_a]), float(x[node_b])],
            [float(y[node_a]), float(y[node_b])],
        )
```
 
%% Cell type:code id: tags:
 
``` python
# Plot trajectory of a video in latent space
 
 
@interact(
    samples=(1000, 10000, 500),
    trajectory=(100, 500),
    trace=False,
    dim_red=["PCA", "LDA", "umap", "tSNE"],
)
def plot_dynamic_latent_pace(samples, trajectory, trace, dim_red):
    """Animate one video's trajectory through the 2D-reduced latent space.

    Left panel: a static scatter of `samples` random encodings (hue = hard
    cluster, size = cluster confidence) with a red marker tracing the first
    `trajectory` frames. Right panel: the mouse skeleton at each frame.

    Parameters
    ----------
    samples : int
        Number of random encodings shown as background scatter.
    trajectory : int
        Number of frames to animate.
    trace : bool
        If True, draw a line following the trajectory marker.
    dim_red : str
        Dimensionality-reduction method ("PCA", "LDA", "umap" or "tSNE").
    """
    if dim_red == "umap":
        reducer = umap.UMAP(n_components=2)
    elif dim_red == "LDA":
        reducer = LinearDiscriminantAnalysis(n_components=2)
    elif dim_red == "PCA":
        reducer = PCA(n_components=2)
    else:
        reducer = TSNE(n_components=2)

    if dim_red != "LDA":
        enc = reducer.fit_transform(encodings)
    else:
        enc = reducer.fit_transform(encodings, hard_groups)

    # First `samples` frames in temporal order drive the animated marker.
    # (The old traj_grp/traj_hgr locals, sliced from `enc`, were dead code.)
    traj_enc = enc[:samples, :]

    samples = np.random.choice(range(enc.shape[0]), samples)
    sample_enc = enc[samples, :]
    # BUGFIX: point sizes must come from the cluster confidences
    # (`groupings`, as in the static latent-space plot above), not from the
    # 2D embedding coordinates themselves.
    sample_grp = groupings[samples, :]
    sample_hgr = hard_groups[samples]

    # Define two figures arranged horizontally
    fig, (ax, ax2) = plt.subplots(
        1, 2, figsize=(12, 8), gridspec_kw={"width_ratios": [3, 1.5]}
    )

    # Plot the animated embedding trajectory on the left
    sns.scatterplot(
        x=sample_enc[:, 0],
        y=sample_enc[:, 1],
        hue=sample_hgr,
        size=np.max(sample_grp, axis=1),
        sizes=(1, 100),
        palette="muted",
        ax=ax,
    )

    traj_init = traj_enc[0, :]
    scatter = ax.scatter(
        x=[traj_init[0]], y=[traj_init[1]], s=100, color="red", edgecolor="black"
    )
    (lineplt,) = ax.plot([traj_init[0]], [traj_init[1]], color="red", linewidth=2.0)
    tracking_line_x = []
    tracking_line_y = []

    # Plot the initial data (before feeding it to the encoder) on the right
    edges = deepof.utils.connect_mouse_topview()

    for bpart in exclude_bodyparts:
        if bpart:
            edges.remove_node(bpart)

    for limb in ["Left_fhip", "Right_fhip", "Left_bhip", "Right_bhip"]:
        edges.remove_edge("Center", limb)
        if ("Tail_base", limb) in list(edges.edges()):
            edges.remove_edge("Tail_base", limb)

    edges = edges.edges()

    inv_coords = coords._scaler.inverse_transform(data_prep)[:, window_size // 2, :]
    data = pd.DataFrame(inv_coords, columns=coords[random_exp].columns)

    # Add Central coordinate, lost during alignment
    data["Center", "x"] = 0
    data["Center", "y"] = 0

    init_x = data.xs("x", level=1, axis=1, drop_level=False).iloc[0, :]
    init_y = data.xs("y", level=1, axis=1, drop_level=False).iloc[0, :]

    plots = plot_mouse_graph(init_x, init_y, ax2, edges)
    track = ax2.scatter(x=np.array(init_x), y=np.array(init_y), color="#006699",)

    # Update data in both plots
    def animation_frame(i):
        # Update scatter plot
        offset = traj_enc[i, :]

        prev_t = scatter.get_offsets()[0]

        if trace:
            tracking_line_x.append([prev_t[0], offset[0]])
            tracking_line_y.append([prev_t[1], offset[1]])
            lineplt.set_xdata(tracking_line_x)
            lineplt.set_ydata(tracking_line_y)

        scatter.set_offsets(np.c_[np.array(offset[0]), np.array(offset[1])])

        x = data.xs("x", level=1, axis=1, drop_level=False).iloc[i, :]
        y = data.xs("y", level=1, axis=1, drop_level=False).iloc[i, :]
        track.set_offsets(np.c_[np.array(x), np.array(y)])
        update_mouse_graph(x, y, plots, edges)

        return scatter

    animation = FuncAnimation(
        fig, func=animation_frame, frames=trajectory, interval=75,
    )

    ax.set_xlabel("{} 1".format(dim_red))
    ax.set_ylabel("{} 2".format(dim_red))

    ax2.set_xlabel("x")
    # BUGFIX: the y label was previously set with set_xlabel, clobbering "x"
    ax2.set_ylabel("y")
    ax2.set_ylim(-90, 60)
    ax2.set_xlim(-60, 60)

    plt.tight_layout()

    video = animation.to_html5_video()
    html = display.HTML(video)
    display.display(html)
    plt.close()
```
 
%% Output
 
 
%% Cell type:markdown id: tags:
 
### 6. Sample from latent space (to be incorporated into deepof.evaluate)
 
%% Cell type:code id: tags:
 
``` python
# Retrieve latent distribution parameters and sample from posterior
def get_median_params(component, categories, cluster, param):
    """Return the per-dimension median of a mixture component's parameter,
    computed only over the training instances assigned to `cluster`.

    Parameters
    ----------
    component : distribution-like
        Mixture component exposing .mean() / .stddev() with a .numpy()
        conversion (assumes one row per training instance — TODO confirm
        against the encoder output shape).
    categories : np.ndarray
        (instances, clusters) soft cluster assignment matrix.
    cluster : int
        Index of the cluster whose instances are selected (via argmax).
    param : str
        Which parameter to summarize: "mean" or "stddev".

    Returns
    -------
    np.ndarray or None
        Per-dimension median, or None when no instance selects the cluster.

    Raises
    ------
    ValueError
        If `param` is neither "mean" nor "stddev". (Previously an unknown
        value fell through and failed later with a confusing indexing error.)
    """
    if param == "mean":
        values = component.mean().numpy()
    elif param == "stddev":
        values = component.stddev().numpy()
    else:
        raise ValueError("param must be 'mean' or 'stddev', got {!r}".format(param))
    cluster_select = np.argmax(categories, axis=1) == cluster
    if np.sum(cluster_select) == 0:
        return None
    return np.median(values[cluster_select], axis=0)
```
%% Cell type:code id: tags:
``` python
def retrieve_latent_parameters(
    distribution, reduce=False, plot=False, categories=None, filt=0, save=True
):
    """Extract per-cluster median means and stddevs from a mixture posterior.

    Only the training instances assigned to each cluster contribute to that
    cluster's medians. Clusters with no assigned instances, or whose maximum
    assignment confidence never exceeds `filt`, are dropped from the output.

    NOTE(review): `reduce`, `plot` and `save` are currently unused; kept for
    interface compatibility with existing callers.

    Returns
    -------
    (list of np.ndarray, list of np.ndarray)
        Median means and stddevs for the surviving clusters, index-aligned.
    """
    mix_components = distribution.components
    means = [
        get_median_params(component, categories, i, "mean")
        for i, component in enumerate(mix_components)
    ]
    stddevs = [
        get_median_params(component, categories, i, "stddev")
        for i, component in enumerate(mix_components)
    ]
    # BUGFIX: the confidence mask has one entry per mixture component, so it
    # must be combined with the component-aligned lists BEFORE the None
    # entries (empty clusters) are removed. The previous version dropped
    # None entries first, misaligning the mask whenever any cluster had no
    # assigned instances.
    filts = np.max(categories, axis=0) > filt
    keep = [m is not None and bool(f) for m, f in zip(means, filts)]
    means = [m for m, k in zip(means, keep) if k]
    stddevs = [s for s, k in zip(stddevs, keep) if k]
    return means, stddevs
def sample_from_posterior(
    decoder, parameters, component, enable_variance=False, samples=1
):
    """Draw latent samples around one cluster's median parameters and decode
    them back to data space.

    With enable_variance=False the gaussian collapses onto its mean, so the
    decoder is effectively fed the cluster centroid itself.
    """
    means, stddevs = parameters
    centroid = means[component]
    spread = stddevs[component] if enable_variance else 0
    latent = np.random.normal(size=[samples, len(centroid)], loc=centroid, scale=spread)
    return decoder(latent).mean()
```
%% Cell type:code id: tags:
``` python
samples = np.random.choice(range(data_prep.shape[0]), 10000)
latent_distribution = encode_to_distribution(data_prep[samples])
```
%% Cell type:code id: tags:
``` python
means, stddevs = retrieve_latent_parameters(
latent_distribution,
categories=groupings[samples],
reduce=False,
plot=False,
filt=0.,
save=False,
)
```
%% Cell type:code id: tags:
``` python
# Plot sampled data in the latent space
@interact(dim_red=["PCA", "LDA", "umap", "tSNE"], samples=(5000, 15000))
def plot_static_latent_space(dim_red, samples):
    """Sample `samples` points from the fitted mixture posterior (category
    first, then the matching gaussian component) and scatter them in 2D.

    Uses the notebook global `latent_distribution` (and `means` for the
    title). Hue encodes the sampled mixture component.
    """
    if dim_red == "umap":
        reducer = umap.UMAP(n_components=2)
    elif dim_red == "LDA":
        reducer = LinearDiscriminantAnalysis(n_components=2)
    elif dim_red == "PCA":
        reducer = PCA(n_components=2)
    else:
        reducer = TSNE(n_components=2)
    # Ancestral sampling: pick a component id per sample, then draw one
    # point from that component.
    categories = latent_distribution.cat.sample(samples).numpy().flatten()
    mixture_sample = np.squeeze(
        np.concatenate(
            [latent_distribution.components[i].sample(1) for i in categories]
        )
    )
    print(mixture_sample.shape)
    if dim_red != "LDA":
        enc = reducer.fit_transform(mixture_sample)
    else:
        # NOTE(review): using the sampled component ids as *repeat counts*
        # here looks suspect — `categories` itself would seem to be the
        # label vector LDA needs. Confirm intent before relying on LDA.
        enc = reducer.fit_transform(
            mixture_sample,
            np.repeat(range(len(latent_distribution.components)), categories),
        )
    plt.figure(figsize=(12, 8))
    sns.scatterplot(enc[:, 0], enc[:, 1], hue=categories, palette="muted")
    plt.title(
        "Mean representation of latent space - K={}/{} - L={}".format(
            len(means),
            len(latent_distribution.components),
            len(latent_distribution.components),
        )
    )
    plt.xlabel("{} 1".format(dim_red))
    plt.ylabel("{} 2".format(dim_red))
    plt.suptitle("Static view of trained latent space")
    plt.show()
# Get prior distribution
gmvaep.prior
```
 
%% Output
 
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-153-0d3135aace65> in <module>
1 # Get prior distribution
----> 2 gmvaep.prior
AttributeError: 'Functional' object has no attribute 'prior'
 
%% Cell type:code id: tags:
 
``` python
```
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment