Plotting
- dive.plotting.draw_posterior_samples(trace: InferenceData, num_samples: int = 100, rng: int | None = None) → tuple[ndarray, ndarray]
Draws random samples from a trace and calculates the corresponding fitted signals (Vs) and backgrounds (Bs).
- Parameters:
trace (az.InferenceData) – The trace to be read.
num_samples (int, default=100) – The number of random samples to draw from the trace.
rng (int, optional) – The seed for the random number generator for drawing random samples.
- Returns:
Vs, Bs – The fitted signals and backgrounds from the random samples.
- Return type:
tuple of np.ndarray
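The seeded draw described above can be sketched with NumPy alone. This is a minimal illustration, not dive's implementation: the helper name `draw_indices`, the stand-in array `V_posterior`, its shape, and the choice to sample without replacement are all assumptions made for the example.

```python
import numpy as np

def draw_indices(n_chains, n_draws, num_samples=100, rng=None):
    """Pick `num_samples` distinct (chain, draw) pairs from a posterior.

    Hypothetical helper for illustration only; not part of dive.
    """
    # An integer seed makes the draw reproducible; None gives a fresh generator.
    generator = np.random.default_rng(rng)
    # Assumption: samples are drawn without replacement from the pooled chains.
    flat = generator.choice(n_chains * n_draws, size=num_samples, replace=False)
    # Convert flat positions back into (chain, draw) coordinates.
    return np.unravel_index(flat, (n_chains, n_draws))

# Stand-in posterior: 4 chains, 500 draws, 64 time points per signal.
V_posterior = np.random.default_rng(0).normal(size=(4, 500, 64))

chains, draws = draw_indices(4, 500, num_samples=100, rng=42)
Vs = V_posterior[chains, draws]  # one row per selected posterior sample
```

Passing the same integer as `rng` selects the same samples on every call, which is what makes plots built on these draws repeatable.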
- dive.plotting.pairplot_chain(trace: InferenceData, var1: str, var2: str, ax: Axes | None = None, plot_inits: bool = False, gauss_id: int = 1, colors: list[ColorType] | ColorType = ['r', 'g', 'b', 'y', 'm', 'c'], alpha: float = 0.1, **kwargs) → Axes
Plots a scatter plot of two parameters for each chain.
Each chain will be plotted in a different color. Optionally, the initial point of each chain may be plotted.
- Parameters:
trace (az.InferenceData) – The trace to be read.
var1 (str) – The first variable to be plotted.
var2 (str) – The second variable to be plotted.
ax (plt.Axes, optional) – The Matplotlib axes to plot on. If none is provided, axes will be created automatically.
plot_inits (bool, default=False) – Whether or not to plot the initial points of each chain.
gauss_id (int, default=1) – For the parameters associated with a Gaussian fit (r0, w, a), which Gaussian to plot. E.g., if var1 is set to w and gauss_id is set to 2, the widths of the second Gaussian will be plotted. Counting starts at 1.
colors (ColorType or list of ColorType, default=["r","g","b","y","m","c"]) – The color(s) used to plot the chains. If a single color is provided, all chains will be plotted in that color. If a list of colors is provided, the chains will be assigned those colors in order. If there are fewer colors than chains, the colors will be cycled.
alpha (float, default=0.1) – The transparency of the plotted points.
**kwargs (dict, optional) – Keyword arguments to be passed to plt.plot.
- Returns:
ax
- Return type:
plt.Axes
See also
plt.plot
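The color cycling described for colors can be sketched with the standard library. The helper name `colors_for_chains` is hypothetical and only mirrors the documented behavior; it is not taken from dive.

```python
from itertools import cycle, islice

def colors_for_chains(colors, n_chains):
    """Return one color per chain, cycling when there are too few colors.

    Hypothetical helper for illustration only. It handles color names and
    lists of color names; a bare RGB(A) tuple would need an extra check so
    that it is not mistaken for a list of colors.
    """
    if isinstance(colors, str):
        colors = [colors]  # a single color name applies to every chain
    return list(islice(cycle(colors), n_chains))

print(colors_for_chains(["r", "g", "b"], 5))  # ['r', 'g', 'b', 'r', 'g']
print(colors_for_chains("k", 3))              # ['k', 'k', 'k']
```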
- dive.plotting.pairplot_condition(trace: InferenceData, var1: str, var2: str, ax: Axes | None = None, gauss_id: int = 1, criterion: str | None = None, threshold: float | None = None, color_greater: ColorType = 'dodgerblue', color_lesser: ColorType = 'hotpink', alpha_greater: float = 0.2, alpha_lesser: float = 0.2, **kwargs) → Axes
Plots a scatter plot of two parameters, split into two groups based on a criterion.
The available criteria are the keys of trace.sample_stats, such as tree depth and step size.
- Parameters:
trace (az.InferenceData) – The trace to be read.
var1 (str) – The first variable to be plotted.
var2 (str) – The second variable to be plotted.
ax (plt.Axes, optional) – The Matplotlib axes to plot on. If none is provided, axes will be created automatically.
gauss_id (int, default=1) – For the parameters associated with a Gaussian fit (r0, w, a), which Gaussian to plot. E.g., if var1 is set to w and gauss_id is set to 2, the widths of the second Gaussian will be plotted. Counting starts at 1.
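criterion (str, optional) – The key of trace.sample_stats used to split the samples into two groups.
threshold (float, optional) – The value of the criterion that separates the two groups.
color_greater (ColorType, default="dodgerblue") – The color used to plot the samples whose criterion value is greater than the threshold.
color_lesser (ColorType, default="hotpink") – The color used to plot the samples whose criterion value is less than the threshold.
alpha_greater (float, default=0.2) – The transparency of the samples greater than the threshold.
alpha_lesser (float, default=0.2) – The transparency of the samples less than the threshold.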
**kwargs (dict, optional) – Keyword arguments to pass to plt.plot.
- Returns:
ax
- Return type:
plt.Axes
See also
plt.plot
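Splitting samples into two groups by a criterion and a threshold can be sketched as a boolean mask. This is an illustration under stated assumptions, not dive's code: the helper name `split_by_threshold` and the sample values are made up, and placing values equal to the threshold in the "lesser" group is a choice made only for this example.

```python
import numpy as np

def split_by_threshold(x, y, stat, threshold):
    """Split paired samples into those with stat > threshold and the rest.

    Hypothetical helper for illustration only; not part of dive.
    """
    x, y, stat = map(np.asarray, (x, y, stat))
    greater = stat > threshold  # boolean mask, one entry per sample
    return (x[greater], y[greater]), (x[~greater], y[~greater])

# Made-up values standing in for two parameters and a per-draw tree depth.
r0 = np.array([3.0, 3.1, 2.9, 3.4, 3.2])
w = np.array([0.40, 0.42, 0.38, 0.55, 0.47])
tree_depth = np.array([3, 5, 2, 6, 4])

(gx, gy), (lx, ly) = split_by_threshold(r0, w, tree_depth, threshold=4)
# gx, gy hold the samples with tree depth above 4; lx, ly hold the rest.
```

Each group can then be drawn with its own color and transparency, which is the role of the color_greater/color_lesser and alpha_greater/alpha_lesser arguments.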
- dive.plotting.pairplot_divergence(trace: InferenceData, var1: str, var2: str, ax: Axes | None = None, gauss_id: int = 1, color: ColorType = 'C2', divergence_color: ColorType = 'C3', alpha: float = 0.2, divergence_alpha: float = 0.4, **kwargs) → Axes
Plots a scatter plot of two parameters, highlighting divergences.
Divergences are plotted in a different color. Divergences occur when the potential-energy landscape is too steep to effectively sample, causing the sampler to remain in place, potentially indicating poor sampling.
- Parameters:
trace (az.InferenceData) – The trace to be read.
var1 (str) – The first variable to be plotted.
var2 (str) – The second variable to be plotted.
ax (plt.Axes, optional) – The Matplotlib axes to plot on. If none is provided, axes will be created automatically.
gauss_id (int, default=1) – For the parameters associated with a Gaussian fit (r0, w, a), which Gaussian to plot. E.g., if var1 is set to w and gauss_id is set to 2, the widths of the second Gaussian will be plotted. Counting starts at 1.
color (ColorType, default="C2") – The color to plot the non-divergent points.
divergence_color (ColorType, default="C3") – The color to plot the divergent points.
alpha (float, default=0.2) – The transparency of the non-divergent points.
divergence_alpha (float, default=0.4) – The transparency of the divergent points.
**kwargs (dict, optional) – Keyword arguments to pass to plt.plot.
- Returns:
ax
- Return type:
plt.Axes
See also
plt.plot
- dive.plotting.plot_P(trace: InferenceData, ax: Axes | None = None, num_samples: int = 100, show_avg: bool = False, hdi: float | None = None, rng: int | None = None, Pref: ArrayLike | None = None, rref: ArrayLike | None = None, alpha: float = 0.2, color: ColorType = '#4A5899', **kwargs) → Axes
Plots an ensemble of distance distributions.
A set of random distance distributions is drawn from the posterior of P. Averages and highest density intervals may also be optionally plotted. A ground-truth distance distribution will be plotted in black if provided through rref and Pref.
- Parameters:
trace (az.InferenceData) – The trace to be read.
ax (plt.Axes, optional) – The Matplotlib axes to plot on. If none is provided, axes will be created automatically.
num_samples (int, default=100) – The number of random samples to draw from the trace.
show_avg (bool, default=False) – Whether or not to plot the average distance distribution.
hdi (float, optional) – If a float is provided, the corresponding highest density interval will be plotted instead of the ensembles.
rng (int, optional) – The seed for the random number generator for drawing random samples.
Pref (ArrayLike, optional) – The ground-truth distance distribution.
rref (ArrayLike, optional) – The distance axis of the ground-truth distance distribution.
alpha (float, default=0.2) – The transparency of the ensemble plots.
color (ColorType, default="#4A5899") – The color of the ensemble plots.
**kwargs (dict, optional) – Keyword arguments to be passed to plt.plot or plt.fill_between.
- Returns:
ax
- Return type:
plt.Axes
See also
az.extract, summary, plt.plot, plt.fill_between
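The highest density interval selected through hdi is the narrowest interval that contains the requested fraction of the posterior samples. The NumPy-only sketch below shows one common way to estimate it from samples and to turn it into a band along the distance axis. The helper names `hdi_1d` and `hdi_band` and the layout of `P_samples` (one row per sample, one column per distance point) are assumptions for this example, not dive's implementation.

```python
import numpy as np

def hdi_1d(samples, prob=0.95):
    """Narrowest interval containing a fraction `prob` of the samples.

    Hypothetical helper for illustration only; assumes 0 < prob <= 1.
    """
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    k = int(np.ceil(prob * n))            # number of samples inside the interval
    widths = x[k - 1:] - x[: n - k + 1]   # width of every window of k sorted samples
    start = int(np.argmin(widths))        # the narrowest window wins
    return x[start], x[start + k - 1]

def hdi_band(P_samples, prob=0.95):
    """Lower and upper HDI bounds at each point of the distance axis."""
    columns = np.asarray(P_samples, dtype=float).T
    bounds = np.array([hdi_1d(col, prob) for col in columns])
    return bounds[:, 0], bounds[:, 1]

lower, upper = hdi_band([[0, 1], [1, 1], [2, 1], [10, 5]], prob=0.75)
```

The two returned arrays are the kind of lower and upper curves that `plt.fill_between` accepts to shade a band.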
- dive.plotting.plot_V(trace: InferenceData, ax: Axes | None = None, num_samples: int = 100, show_avg: bool = False, hdi: float | None = None, rng: int | None = None, residuals_offset: float = 0, V_kwargs: dict = {}, B_kwargs: dict = {}, res_kwargs={}, **kwargs) → Axes
Plots an ensemble of fitted signals and backgrounds with residuals.
A set of random samples from the full trace is selected for V and B. Averages and highest density intervals may also be optionally plotted.
- Parameters:
trace (az.InferenceData) – The trace to be read.
ax (plt.Axes, optional) – The Matplotlib axes to plot on. If none is provided, axes will be created automatically.
num_samples (int, default=100) – The number of random samples to draw from the trace.
show_avg (bool, default=False) – Whether or not to plot the average signal and background.
hdi (float, optional) – If a float is provided, the corresponding highest density interval will be plotted instead of the ensembles.
rng (int, optional) – The seed for the random number generator for drawing random samples.
residuals_offset (float, default=0) – The amount to raise the residual plot by (for a more compact plot).
V_kwargs (dict, default={}) – Keyword arguments to be passed to plt.plot or plt.fill_between for V.
B_kwargs (dict, default={}) – Keyword arguments to be passed to plt.plot or plt.fill_between for B.
res_kwargs (dict, default={}) – Keyword arguments to be passed to plt.plot or plt.fill_between for the residuals.
**kwargs (dict, optional) – Keyword arguments to be passed to plt.plot or plt.fill_between for all plots.
- Returns:
ax
- Return type:
plt.Axes
See also
draw_posterior_samples, summary, plt.plot, plt.fill_between
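The split between **kwargs (applied to all plots) and V_kwargs, B_kwargs, and res_kwargs (applied to one curve each) can be pictured as a dictionary merge. The sketch below is an illustration only: the helper name `merge_plot_kwargs` is hypothetical, and letting the curve-specific values override the shared ones on conflict is an assumption, since the documentation above does not state the precedence.

```python
def merge_plot_kwargs(shared, specific):
    """Combine shared keyword arguments with curve-specific ones.

    Hypothetical helper for illustration only.
    Assumption for this sketch: curve-specific values win on conflict.
    """
    merged = dict(shared)    # copy so the shared dict is not modified
    merged.update(specific)
    return merged

shared = {"alpha": 0.2, "linewidth": 1.0}   # stands in for **kwargs
V_kwargs = {"color": "tab:blue"}
B_kwargs = {"color": "tab:orange", "alpha": 0.5}

print(merge_plot_kwargs(shared, V_kwargs))
print(merge_plot_kwargs(shared, B_kwargs))
```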
- dive.plotting.plot_correlations(trace: InferenceData, axs: ndarray[Axes] | None = None, var_names: list[str] | None = None, marginals: bool = True, **kwargs) → Axes
Plots 2D marginalized posteriors of selected variables.
Illustrates pairwise correlation plots between model parameters.
- Parameters:
trace (az.InferenceData) – The trace to be read.
axs (np.ndarray of plt.Axes, optional) – 2D NumPy array of the Matplotlib axes to be plotted on. If not provided, axes will be created automatically. The array must have the correct shape.
var_names (list of str, optional) – The variables to be plotted. If not provided, a list of relevant variables will be selected automatically.
marginals (bool, default=True) – Whether or not to also include the 1D marginalized posteriors for each variable.
**kwargs (dict, optional) – Keyword arguments to be passed to az.plot_pair.
- Returns:
axs
- Return type:
plt.Axes
See also
az.plot_pair, summary
- dive.plotting.plot_hist(trace: InferenceData, var: str, ax: Axes | None = None, combine: bool = False, gauss_id: int = 1, **kwargs) → Axes
Plots a histogram of a parameter’s values.
- Parameters:
trace (az.InferenceData) – The trace to read.
var (str) – The variable to plot.
ax (plt.Axes, optional) – The Matplotlib axes to plot on. If none is provided, axes will be created automatically.
combine (bool, default=False) – For the parameters associated with a Gaussian fit (r0, w, a), whether or not to combine all Gaussians. E.g., if set to True and var is set to w, the widths of all Gaussians will be combined into one histogram.
gauss_id (int, default=1) – For the parameters associated with a Gaussian fit (r0, w, a), which Gaussian to plot. E.g., if var is set to w and gauss_id is set to 2, the widths of the second Gaussian will be plotted. Counting starts at 1. Only used if combine is False.
**kwargs (dict, optional) – Keyword arguments to pass to plt.hist.
- Returns:
ax
- Return type:
plt.Axes
See also
plt.hist
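The interplay between combine and the 1-based gauss_id can be sketched as an array selection. The helper name `select_gaussian` and the assumed layout of `values` (one row per posterior sample, one column per Gaussian) are hypothetical and chosen only for this example; they do not describe how dive stores its parameters.

```python
import numpy as np

def select_gaussian(values, gauss_id=1, combine=False):
    """Select the samples of one Gaussian component, or pool all components.

    Hypothetical helper for illustration only.
    `values` is assumed to have shape (n_samples, n_gaussians).
    """
    values = np.asarray(values)
    if combine:
        return values.ravel()  # pool the samples of every Gaussian
    if not 1 <= gauss_id <= values.shape[1]:
        raise ValueError(f"gauss_id must be between 1 and {values.shape[1]}")
    return values[:, gauss_id - 1]  # gauss_id is 1-based, NumPy is 0-based

# Made-up widths for a two-Gaussian model, three posterior samples.
w = np.array([[0.3, 0.8], [0.4, 0.9], [0.5, 1.0]])

second = select_gaussian(w, gauss_id=2)    # widths of the second Gaussian
pooled = select_gaussian(w, combine=True)  # widths of both Gaussians together
```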
- dive.plotting.plot_marginals(trace: InferenceData, axs: ndarray[Axes] | None = None, var_names: list[str] | None = None, ground_truth: dict[str, float] | None = None, point_estimate: str | None = None, hdi_prob: float = 'hide', **kwargs) → Axes
Plots 1D marginalized posteriors of selected variables.
- Parameters:
trace (az.InferenceData) – The trace to be read.
axs (np.ndarray of plt.Axes, optional) – 1D NumPy array of the Matplotlib axes to be plotted on. If not provided, axes will be created automatically. The array must have the correct size.
var_names (list of str, optional) – The variables to be plotted. If not provided, a list of relevant variables will be selected automatically.
ground_truth (dict of str, float, optional) – A dictionary of ground-truth variable values, which will be plotted as vertical gray lines. Keys should be the variable names and values should be their ground-truth values.
point_estimate (str, optional) – If “mean”, “median”, or “mode” is passed, that point estimate will be plotted. See az.plot_posterior.
hdi_prob (float or str, default="hide") – The highest density interval to plot. If set to the default “hide”, none will be plotted.
**kwargs (dict, optional) – Keyword arguments to be passed to az.plot_posterior.
- Returns:
axs
- Return type:
plt.Axes
See also
az.plot_posterior, summary
- dive.plotting.print_summary(trace: InferenceData, var_names: list[str] | None = None, **kwargs)
Prints a table with summary statistics of important parameters.
The table includes their means, standard deviations, effective sample sizes, Monte Carlo standard errors, and R-hat diagnostic values.
- Parameters:
trace (az.InferenceData) – The trace to be read.
var_names (list of str, optional) – The variables to be summarized. If not provided, a list of relevant variables will be selected automatically.
**kwargs (dict, optional) – Keyword arguments to be passed to az.summary.
See also
az.summary, summary
- dive.plotting.summary(trace: InferenceData, var_names: list[str] | None = None, num_samples: int = 100, rng: int | None = None)
Convenience function that prints a summary table and produces several plots.
Calls print_summary, plot_marginals, plot_correlations, plot_V, and plot_P.
- Parameters:
trace (az.InferenceData) – The trace to be read.
var_names (list of str, optional) – The variables to be plotted. If not provided, a list of relevant variables will be selected automatically.
num_samples (int, default=100) – The number of random samples to draw from the trace.
rng (int, optional) – The seed for the random number generator for drawing random samples.
See also
print_summary, plot_marginals, plot_correlations, plot_V, plot_P