Skip to content
Snippets Groups Projects

Update maps to be based on the linear models

Merged Thomas Purcell requested to merge rework_2d_map into pre_gpu_changes
2 files
+ 164
38
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -20,38 +20,62 @@ plot_2d_map: Plot a 2D map of a model (2 selected features)
import pandas as pd
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sissopp import ModelClassifier, Unit
from sissopp.postprocess.load_models import load_model
from sissopp.postprocess.plot.utils import setup_plot_ax
from sissopp.postprocess.plot.utils import (
setup_plot_ax,
latexify,
remove_E_notation,
exp_to_e_notaiton,
)
from sissopp.py_interface.import_dataframe import strip_units
import seaborn as sns
def plot_2d_map(
model,
df,
feats,
sample_df=None,
feat_bounds=None,
index=None,
default_data=None,
task_ind=0,
data_filename=None,
filename=None,
fig_settings=None,
n_points=1001,
levels=32,
cmap="PuBu",
cmap="YlOrBr",
vmin=None,
vmax=None,
contour_lc=None,
fig=None,
ax=None,
colorbar=True,
x_label=None,
y_label=None,
cbar_label=None,
samp_style_label=None,
samp_sizes_label=None,
hue_label=None,
samp_edgecolors=None,
samp_markers=None,
samp_pallete=None,
samp_size=None,
):
"""Plot a 2D map of a model (2 selected features)
Args:
model (Model): Model to plot the map of
df (pd.DataFrame): Dataframe from data.csv file
feats (list of str): List of 2 features to plot
index (str): row index to take default values for features not being plotted (from df, default is avg)
data_filename (str): Filename to store the data in
feats (list of int): List of feature indexes to plot against
sample_df (pd.DataFrame): Dataframe used to add sample points
indexes (list of str): Row indexes of samples to plot on the map
default_data (list of int): Default data point for all features not listed in feats
task_ind (int): The task index to pull the coefficients from
data_filename (str): Filename to store the data that generated the features in
filename (str): Filename for the figure
fig_settings (dict): Settings used to augment the plot
n_points (int): Number of points to plot
@@ -63,6 +87,16 @@ def plot_2d_map(
fig (matplotlib.pyplot.Figure): The matplotlib Figure object
ax (matplotlib.pyplot.Axis): The matplotlib axis for the plot
colorbar (bool): If True add a colorbar
x_label (str): The label for the x-axis
y_label (str): The label for the y-axis
cbar_label (str): The label for the colorbar-axis
samp_style_label (str): Column label for the styles column for seaborn
samp_sizes_label (str): Column label for the size column for seaborn
hue_label (str): Column label for the color column for seaborn
samp_edgecolors (list of str): edgecolors for each sample
samp_markers (dict): Dictionary for describing which marker to use
samp_pallete (dict): Dictionary for describing what color each sample should be
samp_size (dict): Dictionary for describing what size each sample should be
Returns:
tuple: A tuple containing:
- fig (matplotlib.pyplot.Figure): The pyplot Figure for the plot
@@ -72,59 +106,86 @@ def plot_2d_map(
if len(feats) != 2:
raise ValueError("feats must be of length 2")
df = strip_units(df)
if index and index not in df.index:
raise ValueError("Requested index not in the passed DataFrame")
if any([feat not in df.columns for feat in feats]):
raise ValueError("One of the requested features is not in the df column list")
if isinstance(model, str):
model = load_model(model)
if isinstance(model, ModelClassifier):
raise ValueError("2D maps are designed for regression type plots")
if default_data is None:
default_data = [feat.value.mean() for feat in model.feats]
if any([(feat >= len(model.feats)) or (feat < 0) for feat in feats]):
raise ValueError(
f"Requested feature outside of possible range of 0 to {len(model.feats)-1}."
)
if fig is None:
fig_config, fig, ax = setup_plot_ax(fig_settings)
fig_config, fig, ax = setup_plot_ax(fig_settings, True, False)
fig.subplots_adjust(right=0.85)
elif ax is None:
ax = fig.get_axes()[0]
ax.set_xlabel(feats[0].replace("_", " "))
ax.set_ylabel(feats[1].replace("_", " "))
if x_label is None:
x_label = remove_E_notation(model.feats[feats[0]].latex_expr)
x_label = exp_to_e_notaiton(x_label)
x_label = f"${x_label[7:-8]}$"
if model.feats[feats[0]].unit != Unit():
x_label += (
f" ({latexify(str(model.feats[feats[0]].unit).replace('*', ''))})"
)
if y_label is None:
y_label = remove_E_notation(model.feats[feats[1]].latex_expr)
y_label = exp_to_e_notaiton(y_label)
y_label = f"${y_label[7:-8]}$"
if model.feats[feats[1]].unit != Unit():
y_label += (
f" ({latexify(str(model.feats[feats[1]].unit).replace('*', ''))})"
)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
if not feat_bounds:
x = np.linspace(df[feats[0]].values.min(), df[feats[0]].values.max(), n_points)
y = np.linspace(df[feats[1]].values.min(), df[feats[1]].values.max(), n_points)
x = np.linspace(
model.feats[feats[0]].value.min(),
model.feats[feats[0]].value.max(),
n_points,
)
y = np.linspace(
model.feats[feats[1]].value.min(),
model.feats[feats[1]].value.max(),
n_points,
)
else:
x = np.linspace(feat_bounds[0][0], feat_bounds[0][-1], n_points)
y = np.linspace(feat_bounds[1][0], feat_bounds[1][-1], n_points)
x = np.linspace(feat_bounds[0][0], feat_bounds[0][1], n_points)
y = np.linspace(feat_bounds[1][0], feat_bounds[1][1], n_points)
xx, yy = np.meshgrid(x, y)
xx = xx.flatten()
yy = yy.flatten()
data_dict = {}
for feat in np.unique(
[x_in_expr for feat in model.feats for x_in_expr in feat.x_in_expr_list]
):
if index is None:
data_dict[feat] = np.ones(n_points ** 2) * df[feat].values.mean()
else:
data_dict[feat] = np.ones(n_points ** 2) * df.loc[index, feat]
feat_data = [np.ones(len(xx)) * default_data[ff] for ff in range(len(model.feats))]
feat_data[feats[0]] = xx
feat_data[feats[1]] = yy
zz = (
np.zeros(len(xx))
if model.fix_intercept
else np.ones(len(xx)) * model.coefs[task_ind][-1]
)
data_dict[feats[0]] = xx
data_dict[feats[1]] = yy
zz = model.eval_many(data_dict)
for dd, dat in enumerate(feat_data):
zz += dat * model.coefs[task_ind][dd]
if data_filename:
np.savetxt(data_filename, np.column_stack((xx, yy, zz)))
if vmin and vmax:
levels = np.linspace(vmin, vmax, levels)
cnt = ax.contourf(
x,
y,
@@ -151,7 +212,40 @@ def plot_2d_map(
if colorbar:
cax = divider.append_axes("right", size="5%", pad=0.15)
cbar = fig.colorbar(cnt, cax=cax, orientation="vertical")
cbar.ax.tick_params(axis="y", direction="in", right=True)
cbar.ax.tick_params(axis="y", direction="in", left=True, right=True)
if cbar_label is None:
if model.prop_unit != Unit():
cbar_label = f"{latexify(model.prop_label)} ({latexify(str(model.prop_unit).replace('*', ''))})"
else:
cbar_label = f"{latexify(model.prop_label)}"
cbar.ax.set_ylabel(cbar_label)
if sample_df is not None:
data_df = pd.DataFrame(
index=model.sample_ids_train,
data=np.column_stack(
(model.feats[feats[0]].value, model.feats[feats[1]].value)
),
columns=["x", "y"],
)
sample_df[[x_label, y_label]] = data_df.loc[sample_df.index, :]
sns.scatterplot(
x=x_label,
y=y_label,
hue=hue_label,
style=samp_style_label,
size=samp_sizes_label,
data=sample_df,
palette=samp_pallete,
markers=samp_markers,
sizes=samp_size,
legend=False,
ax=ax,
zorder=10000,
edgecolor=samp_edgecolors,
alpha=1.0,
)
if filename:
fig.savefig(filename)
Loading