Skip to content

Commit

Permalink
refactor and improve plots
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Sep 27, 2023
1 parent 112933a commit 2324b33
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 55 deletions.
130 changes: 93 additions & 37 deletions kulprit/plots/plots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from arviz.plots.plot_utils import _scale_fig_size
from arviz.plots import plot_density, plot_forest
import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -38,41 +39,35 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None
xticks_labels[0] = labels[0]
xticks_labels[2::2] = labels[1:]

fig, ax1 = plt.subplots(1, figsize=figsize)
fig, axes = plt.subplots(1, figsize=figsize)

# double axes
ax2 = ax1.twinx()

ax1.errorbar(
axes.errorbar(
y=cmp_df["elpd_loo"][1:],
x=xticks_pos[::2],
yerr=cmp_df.se[1:],
label="Submodel ELPD",
label="Submodels",
color=plot_kwargs.get("color_eldp", "k"),
fmt=plot_kwargs.get("marker_eldp", "o"),
mfc=plot_kwargs.get("marker_fc_elpd", "white"),
mew=linewidth,
lw=linewidth,
markersize=4,
)
ax2.errorbar(
y=cmp_df["elpd_diff"].iloc[1:],
x=xticks_pos[1::2],
yerr=cmp_df.dse[1:],
label="ELPD difference\n(to reference model)",
color=plot_kwargs.get("color_dse", "grey"),
fmt=plot_kwargs.get("marker_dse", "^"),
mew=linewidth,
elinewidth=linewidth,
markersize=4,
)

ax1.axhline(
axes.axhline(
cmp_df["elpd_loo"].iloc[0],
ls=plot_kwargs.get("ls_reference", "--"),
color=plot_kwargs.get("color_ls_reference", "grey"),
lw=linewidth,
label="Reference model ELPD",
label="Reference model",
)

axes.fill_between(
[-2, 1],
cmp_df["elpd_loo"].iloc[0] + cmp_df["se"].iloc[0],
cmp_df["elpd_loo"].iloc[0] - cmp_df["se"].iloc[0],
alpha=0.1,
color=plot_kwargs.get("color_ls_reference", "grey"),
)

if legend:
Expand All @@ -84,7 +79,7 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None
)

if title:
ax1.set_title(
axes.set_title(
"Model comparison",
fontsize=ax_labelsize * 0.6,
)
Expand All @@ -93,23 +88,84 @@ def plot_compare(cmp_df, legend=True, title=True, figsize=None, plot_kwargs=None
xticks_pos, xticks_labels = xticks_pos[::2], xticks_labels[::2]

# set axes
ax1.set_xticks(xticks_pos)
ax1.set_ylabel("ELPD", fontsize=ax_labelsize * 0.6)
ax1.set_xlabel("Submodel size", fontsize=ax_labelsize * 0.6)
ax1.set_xticklabels(xticks_labels)
ax1.set_xlim(-1 + step, 0 - step)
ax1.tick_params(labelsize=xt_labelsize * 0.6)
ax2.set_ylabel("ELPD difference", fontsize=ax_labelsize * 0.6, color="grey")
ax2.set_ylim(ax2.get_ylim()[::-1])
ax2.tick_params(axis="y", colors="grey")
align_yaxis(ax1, cmp_df["elpd_loo"].iloc[0], ax2, 0)

return ax1


def align_yaxis(ax1, v_1, ax2, v_2):
"""adjust ax2 ylimit so that v2 in ax2 is aligned to v1 in ax1"""
_, y_1 = ax1.transData.transform((0, v_1))
axes.set_xticks(xticks_pos)
axes.set_ylabel("ELPD", fontsize=ax_labelsize * 0.6)
axes.set_xlabel("Submodel size", fontsize=ax_labelsize * 0.6)
axes.set_xticklabels(xticks_labels)
axes.set_xlim(-1 + step, 0 - step)
axes.tick_params(labelsize=xt_labelsize * 0.6)

return axes


def plot_densities(
    model,
    path,
    idata,
    var_names=None,
    submodels=None,
    include_reference=True,
    labels="formula",
    kind="density",
    figsize=None,
    plot_kwargs=None,
):
    """Compare the projected posterior densities of the submodels.

    Parameters
    ----------
    model
        Reference model. Its ``response_component.terms`` keys and
        ``response_name`` are used to build the default ``var_names``.
    path
        Mapping from submodel key to submodel. Each submodel must expose
        ``idata``, ``size`` and ``model.formula``.
    idata
        Inference data of the reference model; plotted first, labelled
        "Reference", when ``include_reference`` is true.
    var_names : str or list of str, optional
        Variables to plot. If empty or None, every term of the reference model
        except the response is used. The caller's object is never modified.
    submodels : iterable, optional
        Keys of ``path`` selecting the submodels to plot. If None, all
        submodels in ``path`` are plotted.
    include_reference : bool
        Whether to include the reference model in the plot.
    labels : str
        If "formula", submodels are labelled with their formula; any other
        value labels them with their ``size``.
    kind : str
        "density" to use ArviZ's ``plot_density`` or "forest" to use
        ``plot_forest``.
    figsize : tuple, optional
        Figure size; used only when ``plot_kwargs`` has no "figsize" entry.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to the ArviZ plotting function.
        The caller's dict is never modified.

    Returns
    -------
    The object returned by ``plot_density`` or ``plot_forest``.

    Raises
    ------
    ValueError
        If ``kind`` is neither "density" nor "forest".
    """
    # Validate before doing any work so that bad input has no side effects.
    if kind not in ("density", "forest"):
        raise ValueError("kind must be one of 'density' or 'forest'")

    # Work on a copy: defaults are injected below and must not leak to the caller.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    plot_kwargs.setdefault("figsize", figsize)

    if not var_names:
        # Default to the reference model terms, keeping their original order so
        # the panel layout is reproducible between runs.
        var_names = [
            term for term in model.response_component.terms.keys() if term != model.response_name
        ]
    elif isinstance(var_names, str):
        var_names = [var_names]
    else:
        # Copy: an exclusion pattern may be appended below.
        var_names = list(var_names)

    if include_reference:
        data = [idata]
        l_labels = ["Reference"]
        # "~name" is an exclusion pattern: drop the "<response>_mean" variable.
        var_names.append(f"~{model.response_name}_mean")
    else:
        # Without the reference model, start from empty collections
        # (previously these names were left unbound, raising UnboundLocalError).
        data = []
        l_labels = []

    if submodels is None:
        selected = list(path.values())
    else:
        selected = [path[key] for key in submodels]

    if labels == "formula":
        l_labels.extend(submodel.model.formula for submodel in selected)
    else:
        l_labels.extend(submodel.size for submodel in selected)

    data.extend(submodel.idata for submodel in selected)

    if kind == "density":
        plot_kwargs.setdefault("outline", False)
        plot_kwargs.setdefault("shade", 0.4)
        return plot_density(
            data=data,
            var_names=var_names,
            data_labels=l_labels,
            **plot_kwargs,
        )

    # kind == "forest" (guaranteed by the validation above)
    plot_kwargs.setdefault("combined", True)
    return plot_forest(
        data=data,
        model_names=l_labels,
        var_names=var_names,
        **plot_kwargs,
    )


def align_yaxis(axes, v_1, ax2, v_2):
"""adjust ax2 ylimit so that v2 in ax2 is aligned to v1 in axes"""
_, y_1 = axes.transData.transform((0, v_1))
_, y_2 = ax2.transData.transform((0, v_2))
inv = ax2.transData.inverted()
_, d_y = inv.transform((0, 0)) - inv.transform((0, y_1 - y_2))
Expand Down
56 changes: 40 additions & 16 deletions kulprit/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pandas as pd

from kulprit.data.submodel import SubModel
from kulprit.plots.plots import plot_compare
from kulprit.plots.plots import plot_compare, plot_densities
from kulprit.projection.projector import Projector
from kulprit.search.searcher import Searcher

Expand Down Expand Up @@ -234,26 +234,50 @@ def plot_compare(
def plot_densities(
self,
var_names: Optional[List[str]] = None,
outline: Optional[bool] = False,
shade: Optional[float] = 0.4,
submodels: Optional[List[int]] = None,
include_reference: bool = True,
labels: Literal["formula", "size"] = "formula",
kind: Literal["density", "forest"] = "density",
figsize: Optional[Tuple[int, int]] = None,
plot_kwargs: Optional[dict] = None,
) -> matplotlib.axes.Axes:
"""Compare the projected posterior densities of the submodels"""
"""Compare the projected posterior densities of the submodels
# set default variable names to the reference model terms
if not var_names:
var_names = list(
set(self.model.response_component.terms.keys()) - set([self.model.response_name])
)
Parameters:
-----------
var_names : list of str, optional
List of variables to plot.
submodels : list of int, optional
List of submodels to plot, 0 is intercept-only model and the largest valid integer is
the total number of variables in reference model. If None, all submodels are plotted.
include_reference : bool
Whether to include the reference model in the plot. Defaults to True.
labels : str
If "formula", the labels are the formulas of the submodels. If "size", the number
of covariates in the submodels.
figsize : tuple
Figure size. If None it will be defined automatically.
plot_kwargs : dict
Dictionary passed to ArviZ's ``plot_density`` function (if kind density) or to
``plot_forest`` (if kind forest).
axes = az.plot_density(
data=[submodel.idata for submodel in self.path.values()],
group="posterior",
Returns:
--------
axes : matplotlib_axes or bokeh_figure
"""
return plot_densities(
self.model,
self.path,
self.idata,
var_names=var_names,
outline=outline,
shade=shade,
data_labels=[submodel.model.formula for submodel in self.path.values()],
submodels=submodels,
include_reference=include_reference,
labels=labels,
kind=kind,
figsize=figsize,
plot_kwargs=plot_kwargs,
)
return axes


def check_model_idata_compatability(model, idata):
Expand Down
14 changes: 12 additions & 2 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,17 @@ def test_plot_comparison(self, ref_model):
ref_model_copy.search()
ref_model_copy.plot_compare(plot=True, figsize=(10, 5))

def test_plot_densities(self, ref_model):
@pytest.mark.parametrize(
"kwargs",
[
{},
{"kind": "forest"},
{"kind": "forest", "plot_kwargs": {"combined": False}},
{"submodels": [0, 1], "labels": "size"},
{"figsize": (4, 4)},
],
)
def test_plot_densities(self, ref_model, kwargs):
ref_model_copy = copy.copy(ref_model)
ref_model_copy.search()
ref_model_copy.plot_densities()
ref_model_copy.plot_densities(**kwargs)

0 comments on commit 2324b33

Please sign in to comment.