# standard library
from typing import Union, Optional, TYPE_CHECKING

# third party imports
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
from loguru import logger

# local imports
from probeye.subroutines import len_or_one
from probeye.subroutines import add_index_to_tex_prm_name

# imports only needed for type hints
if TYPE_CHECKING:  # pragma: no cover
    from probeye.definition.inverse_problem import InverseProblem
def _reference_values_from_true_values(
    problem: "InverseProblem", true_values: dict
) -> dict:
    """
    Converts user-given true parameter values into arviz' 'reference_values' format.

    Parameters
    ----------
    problem
        The inverse problem the true values refer to.
    true_values
        Keys are parameter names, values are the true parameter values. The values
        of vector-valued parameters must be indexable by component.

    Returns
    -------
    reference_values
        Keys are the (tex-)labels of the parameters or parameter components, values
        are the corresponding true values. The entries are ordered by their position
        in the latent parameter vector (ties are broken by the label).
    """
    entries = []  # (position in the latent parameter vector, label, value)
    for prm_name, value in true_values.items():
        parameter = problem.parameters[prm_name]
        tex = prm_name if parameter.tex is None else parameter.tex  # tex may be None
        if parameter.dim > 1:
            # all the channels in the inference data are 1D, so every component of a
            # vector-valued parameter gets its own indexed label; exactly one sort
            # position is recorded per component (component i sits at index + i)
            for i in range(parameter.dim):
                key = add_index_to_tex_prm_name(tex, i + 1)
                entries.append((parameter.index + i, key, value[i]))
        else:
            entries.append((parameter.index, tex, value))
    # the values are excluded from the sort key since they need not be comparable
    entries.sort(key=lambda entry: (entry[0], entry[1]))
    return {key: value for _, key, value in entries}


def _add_reference_markers(
    axs: np.ndarray,
    n_latent_prms_dim: int,
    reference_values: dict,
    reference_values_kwargs: Optional[dict],
) -> None:
    """
    Adds the reference ('true') values as markers to the marginal subplots of a
    pair-plot. For some reason, arviz.plot_pair does not do this itself when
    'reference_values' is passed.

    Parameters
    ----------
    axs
        The array of subplots returned by arviz.plot_pair.
    n_latent_prms_dim
        The combined dimension of all latent parameters.
    reference_values
        The reference values in the order of the marginal subplots.
    reference_values_kwargs
        Additional keyword arguments for Axes.scatter. May be None. These settings
        take precedence over the defaults set here.
    """
    ref_value_list = [*reference_values.values()]
    # merging into one dict (instead of passing explicit keywords next to
    # **reference_values_kwargs) avoids a 'multiple values for keyword argument'
    # TypeError when the user specifies e.g. an edgecolor
    scatter_kwargs = {
        "label": "true value",
        "zorder": 10,
        "edgecolor": "black",
        **(reference_values_kwargs or {}),
    }
    if n_latent_prms_dim > 2:
        # in this case, the relevant axis is always the horizontal one
        for i, prm_value in enumerate(ref_value_list):
            axs[i, i].scatter(prm_value, 0, **scatter_kwargs)
    else:
        # in this case, the marginal plot on the bottom right is rotated
        axs[0, 0].scatter(ref_value_list[0], 0, **scatter_kwargs)
        axs[1, 1].scatter(0, ref_value_list[1], **scatter_kwargs)


def _posterior_legend_entry(ax) -> tuple:
    """
    Returns the legend (handles, labels) lists for the posterior in a marginal
    subplot. The posterior can only be labeled if it is drawn as a line, and the
    first line of the subplot is taken to be the posterior. Otherwise, e.g. for
    histograms, two empty lists are returned.
    """
    if ax.lines:
        return [ax.lines[0]], ["posterior"]
    return [], []


def _synchronize_pair_plot_axes(axs: np.ndarray, n: int) -> None:
    """
    Adjusts the ranges of the non-marginal subplots of an (n x n) pair-plot grid to
    the horizontal ranges of the marginal subplots on the diagonal.
    """
    for i in range(n):
        # the reference is the plot on the diagonal
        x_min, x_max = axs[i, i].get_xlim()
        # axes in the column below the diagonal plot
        for j in range(i + 1, n):
            axs[j, i].set_xlim((x_min, x_max))
        # axes in the row to the left of the diagonal plot
        for j in range(i):
            axs[i, j].set_ylim((x_min, x_max))


def create_pair_plot(
    inference_data: az.data.inference_data.InferenceData,
    problem: "InverseProblem",
    plot_with: str = "arviz",
    plot_priors: bool = True,
    focus_on_posterior: bool = True,
    kind: str = "kde",
    figsize: Optional[tuple] = None,
    inches_per_row: Union[int, float] = 2.0,
    inches_per_col: Union[int, float] = 2.0,
    textsize: Union[int, float] = 10,
    title_size: Union[int, float] = 14,
    title: Optional[str] = None,
    true_values: Optional[dict] = None,
    show_legends: bool = True,
    show: bool = True,
    **kwargs,
) -> np.ndarray:
    """
    Creates a pair-plot for the given inference data.

    Parameters
    ----------
    inference_data
        Contains the results of the sampling procedure.
    problem
        The inverse problem the inference data refers to.
    plot_with
        Defines the python package the plot will be generated with. Options are:
        {'arviz', 'seaborn', 'matplotlib'}. Only 'arviz' is implemented so far.
    plot_priors
        If True, the prior-distributions are included in the marginal subplots.
        Otherwise the priors are not shown. Priors of vector-valued parameters are
        never plotted.
    focus_on_posterior
        If True, the marginal plots will focus on the posteriors, i.e., the range of
        the horizontal axis will adapt to the posterior. This might result in just
        seeing a fraction of the prior distribution (if they are included). If False,
        the marginal plots will focus on the priors, which will have a broader
        x-range. If plot_priors=False, this argument has no effect on the generated
        plot.
    kind
        Type of plot to display ('scatter', 'kde' and/or 'hexbin').
    figsize
        Defines the size of the generated plot in inches. If None is chosen, the
        figsize will be derived automatically by using inches_per_row and
        inches_per_col.
    inches_per_row
        If figsize is None, this will specify the inches per row in the subplot-grid.
        This argument has no effect if figsize is specified.
    inches_per_col
        If figsize is None, this will specify the inches per column in the
        subplot-grid. This argument has no effect if figsize is specified.
    textsize
        Defines the font size in the default unit.
    title_size
        Defines the font size of the figure's title if 'title' is given.
    title
        The title of the figure.
    true_values
        Used for plotting 'true' parameter values. Keys are the parameter names and
        values are the values that are supposed to be shown in the marginal plots.
    show_legends
        If True, legends are shown in the marginal plots. Otherwise no legends are
        included in the plot.
    show
        When True, the show-method is called after creating the plot. Otherwise, the
        show-method is not called. The latter is useful, when the plot should be
        further processed.
    kwargs
        Additional keyword arguments passed to arviz' pairplot function.

    Returns
    -------
    axs
        The array of subplots of the created plot. An empty array is returned (and a
        warning is logged) if the latent parameters have a combined dimension of one.

    Raises
    ------
    NotImplementedError
        If plot_with is 'seaborn' or 'matplotlib'.
    RuntimeError
        If plot_with is not one of the options listed above.
    """
    # a pairplot can only be generated when there are at least two parameters or
    # parameter components (the latter refers to vector-valued parameters)
    if problem.n_latent_prms_dim == 1:
        logger.warning(
            "The combined dimension of all latent parameters is one. Hence, no "
            "pairplot can be generated in this setup."
        )
        return np.array([])
    if plot_with == "seaborn":
        raise NotImplementedError(
            "The plot-creation with seaborn has not been implemented yet."
        )
    if plot_with == "matplotlib":
        raise NotImplementedError(
            "The plot-creation with matplotlib has not been implemented yet."
        )
    if plot_with != "arviz":
        raise RuntimeError(
            f"Invalid 'plot_with' argument: '{plot_with}'. Available options are "
            f"currently 'arviz', 'seaborn', 'matplotlib'"
        )

    n = problem.n_latent_prms_dim

    # set default value for kde_kwargs if not given in kwargs; note that this
    # default value is mutable, so it should not be given as a default argument in
    # create_pair_plot
    kwargs.setdefault(
        "kde_kwargs",
        {"contourf_kwargs": {"alpha": 0}, "contour_kwargs": {"colors": None}},
    )
    if "backend_kwargs" not in kwargs and n == 2:
        kwargs["backend_kwargs"] = {"constrained_layout": True}
    marginal_kwargs = kwargs.get("marginal_kwargs") or {}
    histograms_on_diagonal = marginal_kwargs.get("kind") == "hist"

    # process true_values if specified
    if true_values is not None:
        kwargs["reference_values"] = _reference_values_from_true_values(
            problem, true_values
        )
        kwargs.setdefault("reference_values_kwargs", {"marker": "o", "color": "red"})

    # call the main plotting routine from arviz
    axs = az.plot_pair(
        inference_data,
        marginals=True,
        kind=kind,
        textsize=textsize,
        show=False,
        **kwargs,
    )

    # adds a reference value in each marginal plot; for some reason this is not done
    # by arviz.pair_plot when passing 'reference_values'
    if "reference_values" in kwargs:
        _add_reference_markers(
            axs,
            n,
            kwargs["reference_values"],
            kwargs.get("reference_values_kwargs"),
        )

    if plot_priors:
        # add the prior-pdfs to the marginal subplots
        prm_names = problem.get_theta_names(tex=False, components=False)
        i = 0  # not included in for-header due to possible dim-jumps
        for prm_name in prm_names:
            parameter = problem.parameters[prm_name]
            # for multivariate priors, no priors are plotted
            if parameter.dim > 1:
                i += parameter.dim
                continue
            # with only two latent parameter components, the marginal plot on the
            # bottom right is rotated
            rotate = n == 2 and i == 1
            x = None
            if focus_on_posterior:
                if rotate:
                    x_min, x_max = axs[i, i].get_ylim()
                else:
                    x_min, x_max = axs[i, i].get_xlim()
                x = np.linspace(x_min, x_max, 200)
            # determined before the prior is added, so that only the posterior line
            # can be picked up here
            posterior_handle, posterior_label = _posterior_legend_entry(axs[i, i])
            parameter.prior.plot(
                axs[i, i],
                problem.parameters,
                x=x,
                rotate=rotate,
                label="prior",
            )
            # don't use the histogram bin ticks when the prior is also plotted
            if histograms_on_diagonal and not focus_on_posterior:
                if rotate:
                    y_min, y_max = axs[i, i].get_ylim()
                    axs[i, i].set_yticks(np.linspace(y_min, y_max, 9).tolist())
                else:
                    x_min, x_max = axs[i, i].get_xlim()
                    axs[i, i].set_xticks(np.linspace(x_min, x_max, 9).tolist())
            # create the legends if requested
            if show_legends:
                prior_handle, prior_label = axs[i, i].get_legend_handles_labels()
                axs[i, i].legend(
                    posterior_handle + prior_handle,
                    posterior_label + prior_label,
                    loc="best",
                )
            i += 1
        # here, the axis of the non-marginal plots are adjusted to the new ranges
        if (not focus_on_posterior) and (n > 2):
            _synchronize_pair_plot_axes(axs, n)
    else:
        # add legends to the marginal plots for the case where no priors are
        # supposed to be plotted
        if show_legends:
            n_components = len(problem.get_theta_names(tex=False, components=True))
            for i in range(n_components):
                ax = axs[i, i]
                existing_handles, existing_labels = ax.get_legend_handles_labels()
                posterior_handle, posterior_label = _posterior_legend_entry(ax)
                ax.legend(
                    posterior_handle + existing_handles,
                    posterior_label + existing_labels,
                    loc="best",
                )
        # synchronize the axes, which is only necessary if there are at least 3
        # latent parameters; with only 2 latent parameters (only one latent
        # parameter is not allowed for a pair plot), a slightly different plot is
        # created where the marginal plot on the right is rotated
        if n > 2:
            _synchronize_pair_plot_axes(axs, n)

    # set the figure size; this is done either automatically if the user did not
    # specify the figsize argument, or it simply sets the requested figsize
    fig = axs.ravel()[0].figure
    if figsize is None:
        if n > 2:
            n_rows, n_cols = axs.shape
            fig.set_size_inches(n_cols * inches_per_col, n_rows * inches_per_row)
        else:
            fig.set_size_inches(6.0, 5.0)
    else:
        fig.set_size_inches(figsize[0], figsize[1])

    # add a title to the plot (if requested)
    if title:
        fig.suptitle(title, fontsize=title_size)

    if n > 2:
        # reduce the otherwise wide margins; for only two parameter (components) the
        # tight_layout()-call only results in a warning without having an effect
        fig.tight_layout()
        # by default, the y-axis of the first and last marginal plot have ticks,
        # tick-labels and axis-labels that are not meaningful; hence, they are
        # removed here (the default plot looks different for two latent parameters)
        for i in [0, -1]:
            axs[i, i].yaxis.set_ticks_position("none")
            axs[i, i].yaxis.set_ticklabels([])
            axs[i, i].yaxis.set_visible(False)
        for i in range(n - 1):
            xlim = axs[-1, i].get_xlim()
            axs[i, i].set_xticks(ticks=axs[-1, i].get_xticks())
            axs[i, i].set_xlim(xlim)
        ylim = axs[-1, 0].get_ylim()
        axs[-1, -1].set_xticks(ticks=axs[-1, 0].get_yticks())
        axs[-1, -1].set_xlim(ylim)

    # when histograms are used to plot the marginals, the tick labels are often
    # rather close together, so that they overlap; here, they are rotated to
    # alleviate this overlap
    if histograms_on_diagonal:
        for i in range(n):
            axs[-1, i].tick_params(axis="x", labelrotation=45)

    # show the plot if requested
    if show:
        plt.show()  # pragma: no cover

    # Note: the returned axs-object can be saved to a file via:
    # fig = axs.ravel()[0].figure
    # fig.savefig(filename, ...)
    return axs
def create_posterior_plot(
    inference_data: az.data.inference_data.InferenceData,
    problem: "InverseProblem",
    plot_with: str = "arviz",
    kind: str = "hist",
    figsize: Optional[tuple] = None,
    inches_per_row: Union[int, float] = 3.0,
    inches_per_col: Union[int, float] = 2.5,
    textsize: Union[int, float] = 10,
    title_size: Union[int, float] = 14,
    title: Optional[str] = None,
    hdi_prob: float = 0.95,
    true_values: Optional[dict] = None,
    show: bool = True,
    **kwargs,
) -> np.ndarray:
    """
    Creates a posterior-plot for the given inference data.

    Parameters
    ----------
    inference_data
        Contains the results of the sampling procedure.
    problem
        The inverse problem the inference data refers to.
    plot_with
        Defines the python package the plot will be generated with. Options are:
        {'arviz', 'seaborn', 'matplotlib'}. Only 'arviz' is implemented so far.
    kind
        Type of plot to display ('kde' or 'hist').
    figsize
        Defines the size of the generated plot in inches. If None is chosen, the
        figsize will be derived automatically by using inches_per_row and
        inches_per_col.
    inches_per_row
        If figsize is None, this will specify the inches per row in the subplot-grid.
        This argument has no effect if figsize is specified.
    inches_per_col
        If figsize is None, this will specify the inches per column in the
        subplot-grid. This argument has no effect if figsize is specified.
    textsize
        Defines the font size in the default unit.
    title_size
        Defines the font size of the figure's title if 'title' is given.
    title
        The title of the figure.
    hdi_prob
        Defines the highest density interval. Must be a number between 0 and 1.
    true_values
        Used for plotting 'true' parameter values. Keys are the parameter names and
        values are the values that are supposed to be shown in the plots. An entry
        is required for every name returned by problem.get_theta_names(tex=False).
        Vector-valued parameters are given as sequences and are split into their
        components.
    show
        When True, the show-method is called after creating the plot. Otherwise, the
        show-method is not called. The latter is useful, when the plot should be
        further processed.
    kwargs
        Additional keyword arguments passed to arviz' plot_posterior function.

    Returns
    -------
    axs
        The array of subplots of the created plot (or a single axes object if
        arviz creates only one subplot).

    Raises
    ------
    NotImplementedError
        If plot_with is 'seaborn' or 'matplotlib'.
    RuntimeError
        If plot_with is not one of the options listed above.
    """
    if plot_with == "seaborn":
        raise NotImplementedError(
            "The plot-creation with seaborn has not been implemented yet."
        )
    if plot_with == "matplotlib":
        raise NotImplementedError(
            "The plot-creation with matplotlib has not been implemented yet."
        )
    if plot_with != "arviz":
        raise RuntimeError(
            f"Invalid 'plot_with' argument: '{plot_with}'. Available options are "
            f"currently 'arviz', 'seaborn', 'matplotlib'"
        )

    # process true_values if specified; arviz expects one reference value per
    # plotted (1D) channel, so vector-valued parameters are split into their
    # components; np.ravel also unwraps one-element sequences (e.g. [3.0]) given for
    # scalar parameters, which would otherwise end up as nested lists in ref_val
    if true_values is not None:
        kwargs["ref_val"] = [
            component
            for prm_name in problem.get_theta_names(tex=False)
            for component in np.ravel(true_values[prm_name]).tolist()
        ]

    # call the main plotting routine from arviz
    axs = az.plot_posterior(
        inference_data,
        kind=kind,
        textsize=textsize,
        hdi_prob=hdi_prob,
        show=False,
        **kwargs,
    )

    # arviz returns a single axes object if only one subplot is created
    if isinstance(axs, np.ndarray):
        fig = axs.ravel()[0].figure
        if axs.ndim == 1:
            n_rows, n_cols = 1, axs.size
        else:
            n_rows, n_cols = axs.shape  # pragma: no cover
    else:
        fig = axs.figure
        n_rows, n_cols = 1, 1

    # set the figure size; this is done either automatically if the user did not
    # specify the figsize argument, or it simply sets the requested figsize
    if figsize is None:
        fig.set_size_inches(n_cols * inches_per_col, n_rows * inches_per_row)
    else:
        fig.set_size_inches(figsize[0], figsize[1])

    # add a title to the plot (if requested) and apply a tight layout
    if title:
        fig.suptitle(title, fontsize=title_size)
    fig.tight_layout()

    # show the plot if requested
    if show:
        plt.show()  # pragma: no cover

    # Note: the returned axs-object can be saved to a file via:
    # fig = axs.ravel()[0].figure
    # fig.savefig(filename, ...)
    return axs
def create_trace_plot(
    inference_data: az.data.inference_data.InferenceData,
    problem: "InverseProblem",  # for consistent interface
    plot_with: str = "arviz",
    kind: str = "trace",
    figsize: Optional[tuple] = None,
    inches_per_row: Union[int, float] = 2.0,
    inches_per_col: Union[int, float] = 3.0,
    textsize: Union[int, float] = 10,
    title_size: Union[int, float] = 14,
    title: Optional[str] = None,
    show: bool = True,
    **kwargs,
) -> np.ndarray:
    """
    Creates a trace-plot for the given inference data.

    Parameters
    ----------
    inference_data
        Contains the results of the sampling procedure.
    problem
        The inverse problem the inference data refers to. It is not used here and
        only accepted to keep the interface consistent with the other plot routines.
    plot_with
        Defines the python package the plot will be generated with. Options are:
        {'arviz', 'seaborn', 'matplotlib'}. Only 'arviz' is implemented so far.
    kind
        Allows to choose between plotting sampled values per iteration ("trace") and
        rank plots ("rank_bar", "rank_vlines").
    figsize
        Defines the size of the generated plot in inches. If None is chosen, the
        figsize will be derived automatically by using inches_per_row and
        inches_per_col.
    inches_per_row
        If figsize is None, this will specify the inches per row in the subplot-grid.
        This argument has no effect if figsize is specified.
    inches_per_col
        If figsize is None, this will specify the inches per column in the
        subplot-grid. This argument has no effect if figsize is specified.
    textsize
        Defines the font size in the default unit. It is forwarded via plot_kwargs,
        and therefore has no effect if plot_kwargs is given in kwargs.
    title_size
        Defines the font size of the figure's title if 'title' is given.
    title
        The title of the figure.
    show
        When True, the show-method is called after creating the plot. Otherwise, the
        show-method is not called. The latter is useful, when the plot should be
        further processed.
    kwargs
        Additional keyword arguments passed to arviz' plot_trace function.

    Returns
    -------
    axs
        The array of subplots of the created plot.

    Raises
    ------
    NotImplementedError
        If plot_with is 'seaborn' or 'matplotlib'.
    RuntimeError
        If plot_with is not one of the options listed above.
    """
    # reject the backends that are not (yet) supported before doing any work
    if plot_with == "seaborn":
        raise NotImplementedError(
            "The plot-creation with seaborn has not been implemented yet."
        )
    if plot_with == "matplotlib":
        raise NotImplementedError(
            "The plot-creation with matplotlib has not been implemented yet."
        )
    if plot_with != "arviz":
        raise RuntimeError(
            f"Invalid 'plot_with' argument: '{plot_with}'. Available options are "
            f"currently 'arviz', 'seaborn', 'matplotlib'"
        )

    # the plot_kwargs default is a mutable dict, so it is created per call here
    # rather than being placed in the signature
    kwargs.setdefault("plot_kwargs", {"textsize": textsize})
    subplot_grid = az.plot_trace(inference_data, kind=kind, show=False, **kwargs)

    figure = subplot_grid.ravel()[0].figure
    if figsize is not None:
        # the user-requested size wins over the automatic one
        figure.set_size_inches(figsize[0], figsize[1])
    else:
        rows, cols = subplot_grid.shape
        figure.set_size_inches(cols * inches_per_col, rows * inches_per_row)

    if title:
        figure.suptitle(title, fontsize=title_size)
    figure.tight_layout(h_pad=1.75)

    if show:
        plt.show()  # pragma: no cover

    # the returned array can be saved via subplot_grid.ravel()[0].figure.savefig(...)
    return subplot_grid