Commit bc2989c1 authored by lucas_miranda

GMVAEP now returns prior and posterior distributions

parent 78db4095
@@ -624,7 +624,7 @@ class SEQ_2_SEQ_GMVAE:
             name="encoding_distribution",
         )([z_cat, z_gauss])

-        encode_to_distribution = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
+        posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")

         # Define and control custom loss functions
         if "ELBO" in self.loss:
@@ -679,7 +679,7 @@ class SEQ_2_SEQ_GMVAE:
         )(generator)

         # define individual branches as models
-        encode_to_vector = Model(x, z, name="SEQ_2_SEQ_VEncoder")
+        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         generator = Model(g, x_decoded_mean, name="vae_reconstruction")

         def log_loss(x_true, p_x_q_given_z):
@@ -687,7 +687,7 @@ class SEQ_2_SEQ_GMVAE:
             the output distribution"""
             return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))

-        model_outs = [generator(encode_to_vector.outputs)]
+        model_outs = [generator(encoder.outputs)]
         model_losses = [log_loss]
         model_metrics = {"vae_reconstruction": ["mae", "mse"]}
         loss_weights = [1.0]
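
For context, `log_loss` above is the negative log-likelihood of the input under the decoder's output distribution. A minimal, self-contained sketch of the same loss with tensorflow-probability (toy values, not part of the commit):

``` python
import tensorflow as tf
import tensorflow_probability as tfp

def log_loss(x_true, p_x_q_given_z):
    """Negative log-likelihood of the true input under
    the output distribution"""
    return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))

# toy check: three observations under a standard normal output distribution
p_x = tfp.distributions.Normal(loc=[0.0, 0.0, 0.0], scale=1.0)
x_true = tf.constant([0.1, -0.2, 0.3])
print(log_loss(x_true, p_x).numpy())  # a single scalar loss value
```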
@@ -736,9 +736,9 @@ class SEQ_2_SEQ_GMVAE:
             loss_weights.append(self.phenotype_prediction)

         # define grouper and end-to-end autoencoder model
-        grouper = Model(encode_to_vector.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
+        grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
         gmvaep = Model(
-            inputs=encode_to_vector.inputs,
+            inputs=encoder.inputs,
             outputs=model_outs,
             name="SEQ_2_SEQ_GMVAE",
         )
@@ -754,11 +754,12 @@ class SEQ_2_SEQ_GMVAE:
         gmvaep.build(input_shape)

         return (
-            encode_to_vector,
-            encode_to_distribution,
+            encoder,
             generator,
             grouper,
             gmvaep,
+            self.prior,
+            posterior,
         )

     @prior.setter
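
Downstream code that consumed the old return value now receives two extra elements. A minimal sketch of the new unpacking order (the builder call is hypothetical; only the tuple order comes from the diff above):

``` python
# hypothetical builder call; the tuple order matches the return statement above
(
    encoder,    # sequence -> latent vector (SEQ_2_SEQ_VEncoder)
    generator,  # latent vector -> reconstruction distribution
    grouper,    # sequence -> cluster assignment (Deep_Gaussian_Mixture_clustering)
    gmvaep,     # end-to-end autoencoder (SEQ_2_SEQ_GMVAE)
    prior,      # mixture prior over the latent space
    posterior,  # sequence -> latent posterior distribution
) = SEQ_2_SEQ_GMVAE().build(input_shape)
```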
@@ -606,142 +606,23 @@
### 6. Sample from latent space (to be incorporated into deepof.evaluate)
 
%% Cell type:code id: tags:
 
``` python
# Retrieve latent distribution parameters and sample from posterior
def get_median_params(component, categories, cluster, param):
    """Returns the median mean or stddev of a mixture component,
    computed only over the instances assigned to the given cluster"""

    if param == "mean":
        component = component.mean().numpy()
    elif param == "stddev":
        component = component.stddev().numpy()

    # Keep only the instances whose hard assignment matches this cluster
    cluster_select = np.argmax(categories, axis=1) == cluster
    if np.sum(cluster_select) == 0:
        return None

    component = component[cluster_select]
    return np.median(component, axis=0)


# Get prior distribution
gmvaep.prior
```
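
`get_median_params` assumes `categories` is an (n_samples, n_components) matrix of soft cluster assignments, hardened with `argmax`. A toy example with made-up values:

``` python
import numpy as np

# made-up soft assignments: 4 samples over 3 mixture components
categories = np.array(
    [[0.9, 0.1, 0.0],
     [0.2, 0.7, 0.1],
     [0.8, 0.1, 0.1],
     [0.1, 0.2, 0.7]]
)

# hard assignment mask for cluster 0, as used inside get_median_params
print(np.argmax(categories, axis=1) == 0)  # [ True False  True False]
```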
 
%% Cell type:code id: tags:
``` python
def retrieve_latent_parameters(
    distribution, reduce=False, plot=False, categories=None, filt=0, save=True
):
    mix_components = distribution.components

    # The main problem is here! We need to select only those training instances
    # in which a given cluster was selected, then compute the median for those only
    means = [
        get_median_params(component, categories, i, "mean")
        for i, component in enumerate(mix_components)
    ]
    stddevs = [
        get_median_params(component, categories, i, "stddev")
        for i, component in enumerate(mix_components)
    ]

    # Keep only components that were selected at least once and whose maximum
    # assignment probability exceeds the filter; applying both conditions in a
    # single pass keeps means, stddevs and filts aligned by component index
    filts = np.max(categories, axis=0) > filt
    means = [i for i, j in zip(means, filts) if j and i is not None]
    stddevs = [i for i, j in zip(stddevs, filts) if j and i is not None]

    return means, stddevs


def sample_from_posterior(
    decoder, parameters, component, enable_variance=False, samples=1
):
    means, stddevs = parameters

    # Sample around the selected component's median parameters; with
    # enable_variance=False the samples collapse onto the component mean
    sample = np.random.normal(
        size=[samples, len(means[component])],
        loc=means[component],
        scale=(stddevs[component] if enable_variance else 0),
    )

    reconstruction = decoder(sample).mean()
    return reconstruction
```
%% Cell type:code id: tags:
``` python
samples = np.random.choice(range(data_prep.shape[0]), 10000)
latent_distribution = encode_to_distribution(data_prep[samples])
```
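
Assuming the encoding distribution is a tensorflow-probability mixture (as the cells below suggest), its mixing categorical and per-component Gaussians are exposed through `.cat` and `.components`:

``` python
# exploratory sketch, not part of the original notebook
print(len(latent_distribution.components))        # number of mixture components
print(latent_distribution.cat.sample(5).numpy())  # five sampled component indices
```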
%% Cell type:code id: tags:
``` python
means, stddevs = retrieve_latent_parameters(
latent_distribution,
categories=groupings[samples],
reduce=False,
plot=False,
filt=0.,
save=False,
)
```
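
With the medians computed above, latent samples can be decoded back into input space. A hedged usage sketch, assuming `generator` is the decoder returned by the model builder:

``` python
# decode 100 latent samples drawn around component 0; with
# enable_variance=True the component's stddev is used as the sampling scale
reconstructions = sample_from_posterior(
    generator, (means, stddevs), component=0, enable_variance=True, samples=100
)
print(reconstructions.shape)
```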
%% Cell type:code id: tags:
``` python
# Plot sampled data in the latent space
@interact(dim_red=["PCA", "LDA", "umap", "tSNE"], samples=(5000, 15000))
def plot_static_latent_space(dim_red, samples):
    if dim_red == "umap":
        reducer = umap.UMAP(n_components=2)
    elif dim_red == "LDA":
        reducer = LinearDiscriminantAnalysis(n_components=2)
    elif dim_red == "PCA":
        reducer = PCA(n_components=2)
    else:
        reducer = TSNE(n_components=2)

    # Sample component indices from the mixing distribution, then draw one
    # point from the corresponding Gaussian component for each index
    categories = latent_distribution.cat.sample(samples).numpy().flatten()
    mixture_sample = np.squeeze(
        np.concatenate(
            [latent_distribution.components[i].sample(1) for i in categories]
        )
    )

    if dim_red != "LDA":
        enc = reducer.fit_transform(mixture_sample)
    else:
        # LDA is supervised: pass the sampled component indices as class labels
        enc = reducer.fit_transform(mixture_sample, categories)

    plt.figure(figsize=(12, 8))
    sns.scatterplot(x=enc[:, 0], y=enc[:, 1], hue=categories, palette="muted")
    plt.title(
        "Mean representation of latent space - K={}/{}".format(
            len(means),
            len(latent_distribution.components),
        )
    )
    plt.xlabel("{} 1".format(dim_red))
    plt.ylabel("{} 2".format(dim_red))
    plt.suptitle("Static view of trained latent space")
    plt.show()
```
%%%% Output: error

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-153-0d3135aace65> in <module>
      1 # Get prior distribution
----> 2 gmvaep.prior

AttributeError: 'Functional' object has no attribute 'prior'
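
The traceback above is the motivation for this commit: a compiled Keras `Functional` model does not carry custom Python attributes such as `prior`, so the prior (and the posterior model) are now returned explicitly by the builder instead of being looked up on `gmvaep`. A sketch of the corrected access pattern, following the return tuple in the diff (`model_tuple` is a hypothetical name for the builder's return value):

``` python
# unpack the prior from the builder's return value instead of gmvaep.prior
encoder, generator, grouper, gmvaep, prior, posterior = model_tuple
latent_prior_samples = prior.sample(1000)  # sample from the mixture prior
```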
 
%% Cell type:code id: tags:
 
``` python

```