diff --git a/deepof/models.py b/deepof/models.py index 3ce3e5b6a62a383cc010f8646050b5e6fd1cd3c8..29e73854e5aa93cccc49912ed48c58cb4388664a 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -624,7 +624,7 @@ class SEQ_2_SEQ_GMVAE: name="encoding_distribution", )([z_cat, z_gauss]) - encode_to_distribution = Model(x, z, name="SEQ_2_SEQ_trained_distribution") + posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution") # Define and control custom loss functions if "ELBO" in self.loss: @@ -679,7 +679,7 @@ class SEQ_2_SEQ_GMVAE: )(generator) # define individual branches as models - encode_to_vector = Model(x, z, name="SEQ_2_SEQ_VEncoder") + encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder") generator = Model(g, x_decoded_mean, name="vae_reconstruction") def log_loss(x_true, p_x_q_given_z): @@ -687,7 +687,7 @@ class SEQ_2_SEQ_GMVAE: the output distribution""" return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true)) - model_outs = [generator(encode_to_vector.outputs)] + model_outs = [generator(encoder.outputs)] model_losses = [log_loss] model_metrics = {"vae_reconstruction": ["mae", "mse"]} loss_weights = [1.0] @@ -736,9 +736,9 @@ class SEQ_2_SEQ_GMVAE: loss_weights.append(self.phenotype_prediction) # define grouper and end-to-end autoencoder model - grouper = Model(encode_to_vector.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering") + grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering") gmvaep = Model( - inputs=encode_to_vector.inputs, + inputs=encoder.inputs, outputs=model_outs, name="SEQ_2_SEQ_GMVAE", ) @@ -754,11 +754,12 @@ class SEQ_2_SEQ_GMVAE: gmvaep.build(input_shape) return ( - encode_to_vector, - encode_to_distribution, + encoder, generator, grouper, gmvaep, + self.prior, + posterior, ) @prior.setter diff --git a/supplementary_notebooks/deepof_model_evaluation.ipynb b/supplementary_notebooks/deepof_model_evaluation.ipynb index de62ca5557c95aa9440697eb9cb3304a7a28dea3..8e1bbe4075a4b6afa2a5190d43aab72a0b75de46 100644 --- a/supplementary_notebooks/deepof_model_evaluation.ipynb +++ b/supplementary_notebooks/deepof_model_evaluation.ipynb @@ -15219,165 +15219,24 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "# Retrieve latent distribution parameters and sample from posterior\n", - "def get_median_params(component, categories, cluster, param):\n", - " # means = [np.median(component.mean().numpy(), axis=0) for component in mix_components]\n", - " # stddevs = [np.median(component.stddev().numpy(), axis=0) for component in mix_components]\n", - " if param == \"mean\":\n", - " component = component.mean().numpy()\n", - " elif param == \"stddev\":\n", - " component = component.stddev().numpy()\n", - "\n", - " cluster_select = np.argmax(categories, axis=1) == cluster\n", - " if np.sum(cluster_select) == 0:\n", - " return None\n", - " component = component[cluster_select]\n", - " return np.median(component, axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "def retrieve_latent_parameters(\n", - " distribution, reduce=False, plot=False, categories=None, filt=0, save=True\n", - "):\n", - " mix_components = distribution.components\n", - "\n", - " # The main problem is here! We need to select only those training instances in which a given cluster was selected.\n", - " # Then compute the median for those only\n", - "\n", - " means = [\n", - " get_median_params(component, categories, i, \"mean\")\n", - " for i, component in enumerate(mix_components)\n", - " ]\n", - " stddevs = [\n", - " get_median_params(component, categories, i, \"stddev\")\n", - " for i, component in enumerate(mix_components)\n", - " ]\n", - " means = [i for i in means if i is not None]\n", - " stddevs = [i for i in stddevs if i is not None]\n", - "\n", - " filts = np.max(categories, axis=0) > filt\n", - " means = [i for i, j in zip(means, filts) if j]\n", - " stddevs = [i for i, j in zip(stddevs, filts) if j]\n", - "\n", - " return means, stddevs\n", - "\n", - "\n", - "def sample_from_posterior(\n", - " decoder, parameters, component, enable_variance=False, samples=1\n", - "):\n", - " means, stddevs = parameters\n", - " sample = np.random.normal(\n", - " size=[samples, len(means[component])],\n", - " loc=means[component],\n", - " scale=(stddevs[component] if enable_variance else 0),\n", - " )\n", - " reconstruction = decoder(sample).mean()\n", - "\n", - " return reconstruction" - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "metadata": {}, - "outputs": [], - "source": [ - "samples = np.random.choice(range(data_prep.shape[0]), 10000)\n", - "latent_distribution = encode_to_distribution(data_prep[samples])" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "metadata": {}, - "outputs": [], - "source": [ - "means, stddevs = retrieve_latent_parameters(\n", - " latent_distribution,\n", - " categories=groupings[samples],\n", - " reduce=False,\n", - " plot=False,\n", - " filt=0.,\n", - " save=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 104, + "execution_count": 153, "metadata": {}, "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fd7e515df218432b88e6fa7a6a75f9de", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "interactive(children=(Dropdown(description='dim_red', options=('PCA', 'LDA', 'umap', 'tSNE'), value='PCA'), In…" - ] - }, - "metadata": {}, - "output_type": "display_data" + "ename": "AttributeError", + "evalue": "'Functional' object has no attribute 'prior'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-153-0d3135aace65>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Get prior distribution\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mgmvaep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprior\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m: 'Functional' object has no attribute 'prior'" + ] } ], "source": [ - "# Plot sampled data in the latent space\n", - "\n", - "\n", - "@interact(dim_red=[\"PCA\", \"LDA\", \"umap\", \"tSNE\"], samples=(5000, 15000))\n", - "def plot_static_latent_space(dim_red, samples):\n", - " if dim_red == \"umap\":\n", - " reducer = umap.UMAP(n_components=2)\n", - " elif dim_red == \"LDA\":\n", - " reducer = LinearDiscriminantAnalysis(n_components=2)\n", - " elif dim_red == \"PCA\":\n", - " reducer = PCA(n_components=2)\n", - " else:\n", - " reducer = TSNE(n_components=2)\n", - "\n", - " categories = latent_distribution.cat.sample(samples).numpy().flatten()\n", - " mixture_sample = np.squeeze(\n", - " np.concatenate(\n", - " [latent_distribution.components[i].sample(1) for i in categories]\n", - " )\n", - " )\n", - "\n", - " print(mixture_sample.shape)\n", - "\n", - " if dim_red != \"LDA\":\n", - " enc = reducer.fit_transform(mixture_sample)\n", - " else:\n", - " enc = reducer.fit_transform(\n", - " mixture_sample,\n", - " np.repeat(range(len(latent_distribution.components)), categories),\n", - " )\n", - "\n", - " plt.figure(figsize=(12, 8))\n", - "\n", - " sns.scatterplot(enc[:, 0], enc[:, 1], hue=categories, palette=\"muted\")\n", - " plt.title(\n", - " \"Mean representation of latent space - K={}/{} - L={}\".format(\n", - " len(means),\n", - " len(latent_distribution.components),\n", - " len(latent_distribution.components),\n", - " )\n", - " )\n", - "\n", - " plt.xlabel(\"{} 1\".format(dim_red))\n", - " plt.ylabel(\"{} 2\".format(dim_red))\n", - " plt.suptitle(\"Static view of trained latent space\")\n", - " plt.show()" + "# Get prior distribution\n", + "gmvaep.prior" ] }, {