Plotting

dive.plotting.draw_posterior_samples(trace: InferenceData, num_samples: int = 100, rng: int | None = None) tuple[ndarray, ndarray]

Draws random samples from a trace and calculates the corresponding fitted signals (Vs) and backgrounds (Bs).

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • num_samples (int, default=100) – The number of random samples to draw from the trace.

  • rng (int, optional) – The seed for the random number generator for drawing random samples.

Returns:

Vs, Bs – The fitted signals and backgrounds from the random samples.

Return type:

tuple of np.ndarray

See also

plot_V
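
The random-selection step can be sketched as follows. This is a minimal sketch, not dive's implementation: it assumes the trace has been reduced to a (chain, draw) grid, it assumes draws are picked without replacement, and the helper name draw_sample_indices is hypothetical. Passing an integer as rng makes the selection reproducible.

```python
import numpy as np


def draw_sample_indices(n_chains, n_draws, num_samples=100, rng=None):
    """Hypothetical helper: pick random (chain, draw) pairs from a pooled trace."""
    generator = np.random.default_rng(rng)  # an int seed gives a reproducible selection
    total = n_chains * n_draws
    # sample flat indices without replacement (an assumption) so no draw is reused
    flat = generator.choice(total, size=min(num_samples, total), replace=False)
    # convert flat indices back into chain and draw coordinates
    chains, draws = np.unravel_index(flat, (n_chains, n_draws))
    return chains, draws
```

The returned coordinates would then be used to look up the parameters of each selected draw before computing its V and B.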

dive.plotting.pairplot_chain(trace: InferenceData, var1: str, var2: str, ax: Axes | None = None, plot_inits: bool = False, gauss_id: int = 1, colors: list[ColorType] | ColorType = ['r', 'g', 'b', 'y', 'm', 'c'], alpha: float = 0.1, **kwargs) Axes

Plots a scatter plot of two parameters for each chain.

Each chain will be plotted in a different color. Optionally, the initial point of each chain may be plotted.

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • var1 (str) – The first variable to be plotted.

  • var2 (str) – The second variable to be plotted.

  • ax (plt.Axes, optional) – The Matplotlib axes to plot on. If not provided, axes will be created automatically.

  • plot_inits (bool, default=False) – Whether or not to plot the initial points of each chain.

  • gauss_id (int, default=1) – For the parameters associated with a Gaussian fit (r0, w, a), which Gaussian to plot. E.g., if var1 is set to w and gauss_id is set to 2, the widths of the second Gaussian will be plotted. Counting starts at 1.

  • colors (ColorType or list of ColorType, default=["r","g","b","y","m","c"]) – The color(s) in which to plot the chains. If a single color (e.g., a str) is provided, all chains will be plotted in that color. If a list of colors is provided, the chains will follow those colors in order. If the number of colors is less than the number of chains, the colors will be cycled.

  • alpha (float, default=0.1) – The transparency of the plotted points.

  • **kwargs (dict, optional) – Keyword arguments to be passed to plt.plot.

Returns:

ax

Return type:

plt.Axes

See also

plt.plot
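
The color rule described for colors can be sketched as follows. The helper chain_colors is hypothetical; treating any str or tuple as a single color follows matplotlib's ColorType, but how dive itself detects a single color is an assumption.

```python
from itertools import cycle, islice


def chain_colors(colors, n_chains):
    """Hypothetical helper: assign one color to each chain."""
    # a single color (a str or a tuple such as an RGB triple) applies to every chain
    if isinstance(colors, (str, tuple)):
        return [colors] * n_chains
    # a list of colors is repeated in order when there are more chains than colors
    return list(islice(cycle(colors), n_chains))
```

For example, two colors and five chains give the sequence r, g, r, g, r.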

dive.plotting.pairplot_condition(trace: InferenceData, var1: str, var2: str, ax: Axes | None = None, gauss_id: int = 1, criterion: str | None = None, threshold: float | None = None, color_greater: ColorType = 'dodgerblue', color_lesser: ColorType = 'hotpink', alpha_greater: float = 0.2, alpha_lesser: float = 0.2, **kwargs) Axes

Plots two parameters in two groups based on a criterion.

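Valid criteria are the keys of trace.sample_stats, which include the tree depth and step size. Samples whose criterion value lies above the threshold form one group, and the remaining samples form the other.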

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • var1 (str) – The first variable to be plotted.

  • var2 (str) – The second variable to be plotted.

  • ax (plt.Axes, optional) – The Matplotlib axes to plot on. If not provided, axes will be created automatically.

  • gauss_id (int, default=1) – For the parameters associated with a Gaussian fit (r0, w, a), which Gaussian to plot. E.g., if var1 is set to w and gauss_id is set to 2, the widths of the second Gaussian will be plotted. Counting starts at 1.

  • criterion (str, optional) – The key of trace.sample_stats used to split the samples into two groups.

  • threshold (float, optional) – The value of the criterion that separates the two groups.

  • color_greater (ColorType, default="dodgerblue") – The color of the samples whose criterion value is above the threshold.

  • color_lesser (ColorType, default="hotpink") – The color of the samples whose criterion value is below the threshold.

  • alpha_greater (float, default=0.2) – The transparency of the samples above the threshold.

  • alpha_lesser (float, default=0.2) – The transparency of the samples below the threshold.

  • **kwargs (dict, optional) – Keyword arguments to pass to plt.plot.

Returns:

ax

Return type:

plt.Axes

See also

plt.plot
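
The grouping can be sketched with a boolean mask. This is a minimal sketch, assuming the two variables and the chosen sample statistic are arrays with the same (chain, draw) shape; split_by_criterion is a hypothetical name, and placing values equal to the threshold in the lesser group is an arbitrary choice made here.

```python
import numpy as np


def split_by_criterion(x, y, stat, threshold):
    """Hypothetical helper: split paired samples by a per-draw sampler statistic."""
    # pool chains and draws so every sample is one point in the scatter plot
    x, y, stat = (np.ravel(np.asarray(a)) for a in (x, y, stat))
    above = stat > threshold  # ties fall into the lesser group in this sketch
    greater = (x[above], y[above])
    lesser = (x[~above], y[~above])
    return greater, lesser
```

Each group can then be drawn with its own color and alpha, matching the color_greater/alpha_greater and color_lesser/alpha_lesser pairs.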

dive.plotting.pairplot_divergence(trace: InferenceData, var1: str, var2: str, ax: Axes | None = None, gauss_id: int = 1, color: ColorType = 'C2', divergence_color: ColorType = 'C3', alpha: float = 0.2, divergence_alpha: float = 0.4, **kwargs) Axes

Plots a scatter plot of two parameters, highlighting divergences.

Divergent samples are plotted in a different color. Divergences occur when the potential-energy landscape is too steep for the sampler to explore effectively, which can cause the sampler to stall in place. Many divergences can indicate that the posterior was poorly sampled.

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • var1 (str) – The first variable to be plotted.

  • var2 (str) – The second variable to be plotted.

  • ax (plt.Axes, optional) – The Matplotlib axes to plot on. If not provided, axes will be created automatically.

  • gauss_id (int, default=1) – For the parameters associated with a Gaussian fit (r0, w, a), which Gaussian to plot. E.g., if var1 is set to w and gauss_id is set to 2, the widths of the second Gaussian will be plotted. Counting starts at 1.

  • color (ColorType, default="C2") – The color to plot the non-divergent points.

  • divergence_color (ColorType, default="C3") – The color to plot the divergent points.

  • alpha (float, default=0.2) – The transparency of the non-divergent points.

  • divergence_alpha (float, default=0.4) – The transparency of the divergent points.

  • **kwargs (dict, optional) – Keyword arguments to pass to plt.plot.

Returns:

ax

Return type:

plt.Axes

See also

plt.plot
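
Before reading the scatter plot, it can help to count how many draws diverged. A minimal sketch, assuming the per-draw divergence flags are available as a (chain, draw) boolean array; with PyMC's NUTS sampler these flags are typically stored under the diverging key of trace.sample_stats, but treat that key name, and the helper name divergence_counts, as assumptions.

```python
import numpy as np


def divergence_counts(diverging):
    """Hypothetical helper: count divergent draws per chain and overall.

    diverging: boolean array of shape (chain, draw), one flag per draw.
    """
    diverging = np.asarray(diverging, dtype=bool)
    per_chain = diverging.sum(axis=1)  # number of divergent draws in each chain
    fraction = float(diverging.mean())  # share of all draws that diverged
    return per_chain, fraction
```

A chain with many more divergences than the others is a natural place to look first in the highlighted scatter plot.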

dive.plotting.plot_P(trace: InferenceData, ax: Axes | None = None, num_samples: int = 100, show_avg: bool = False, hdi: float | None = None, rng: int | None = None, Pref: ArrayLike | None = None, rref: ArrayLike | None = None, alpha: float = 0.2, color: ColorType = '#4A5899', **kwargs) Axes

Plots an ensemble of distance distributions.

A set of random distance distributions is drawn from the posterior of P. Averages and highest density intervals can optionally be plotted as well. If both rref and Pref are provided, the ground-truth distance distribution is plotted in black.

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • ax (plt.Axes, optional) – The Matplotlib axes to plot on. If not provided, axes will be created automatically.

  • num_samples (int, default=100) – The number of random samples to draw from the trace.

  • show_avg (bool, default=False) – Whether or not to plot the average distance distribution.

  • hdi (float, optional) – If a float is provided, the corresponding highest density interval will be plotted instead of the ensembles.

  • rng (int, optional) – The seed for the random number generator for drawing random samples.

  • Pref (ArrayLike, optional) – The ground-truth distance distribution.

  • rref (ArrayLike, optional) – The distance axis of the ground-truth distance distribution.

  • alpha (float, default=0.2) – The transparency of the ensemble plots.

  • color (ColorType, default="#4A5899") – The color of the ensemble plots.

  • **kwargs (dict, optional) – Keyword arguments to be passed to plt.plot or plt.fill_between.

Returns:

ax

Return type:

plt.Axes

See also

az.extract, summary, plt.plot, plt.fill_between
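
When hdi is set, a band replaces the ensemble. One way to build such a band is to compute, at each distance, the narrowest interval holding the requested fraction of the sampled P values (the sort-and-slide method that az.hdi applies to unimodal samples). This is a minimal sketch: the helper names are hypothetical, samples are assumed to be stored as a (num_samples, len(r)) array, and dive may compute its bands differently.

```python
import numpy as np


def hdi_1d(samples, hdi_prob=0.95):
    """Narrowest interval containing hdi_prob of the samples (unimodal case)."""
    if not 0 < hdi_prob < 1:
        raise ValueError("hdi_prob must lie strictly between 0 and 1")
    s = np.sort(np.asarray(samples))
    n = len(s)
    k = int(np.floor(hdi_prob * n))  # number of sorted steps the interval spans
    widths = s[k:] - s[: n - k]  # width of every candidate interval
    i = int(np.argmin(widths))  # the narrowest one is the HDI
    return s[i], s[i + k]


def pointwise_hdi(P_samples, hdi_prob=0.95):
    """P_samples: array of shape (num_samples, len(r)); returns lower and upper bands."""
    bounds = np.array([hdi_1d(column, hdi_prob) for column in np.asarray(P_samples).T])
    return bounds[:, 0], bounds[:, 1]
```

The resulting lower and upper arrays can be passed to plt.fill_between(r, lower, upper) to draw the band.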

dive.plotting.plot_V(trace: InferenceData, ax: Axes | None = None, num_samples: int = 100, show_avg: bool = False, hdi: float | None = None, rng: int | None = None, residuals_offset: float = 0, V_kwargs: dict = {}, B_kwargs: dict = {}, res_kwargs={}, **kwargs) Axes

Plots an ensemble of fitted signals and backgrounds with residuals.

A set of random samples of V and B is drawn from the full trace. Averages and highest density intervals can optionally be plotted as well.

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • ax (plt.Axes, optional) – The Matplotlib axes to plot on. If not provided, axes will be created automatically.

  • num_samples (int, default=100) – The number of random samples to draw from the trace.

  • show_avg (bool, default=False) – Whether or not to plot the average signal and background.

  • hdi (float, optional) – If a float is provided, the corresponding highest density interval will be plotted instead of the ensembles.

  • rng (int, optional) – The seed for the random number generator for drawing random samples.

  • residuals_offset (float, default=0) – The vertical offset added to the residual plot, which can make the figure more compact.

  • V_kwargs (dict, default={}) – Keyword arguments to be passed to plt.plot or plt.fill_between for V.

  • B_kwargs (dict, default={}) – Keyword arguments to be passed to plt.plot or plt.fill_between for B.

  • res_kwargs (dict, default={}) – Keyword arguments to be passed to plt.plot or plt.fill_between for the residuals.

  • **kwargs (dict, optional) – Keyword arguments to be passed to plt.plot or plt.fill_between for all plots.

Returns:

ax

Return type:

plt.Axes

See also

draw_posterior_samples, summary, plt.plot, plt.fill_between
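
The residual panel can be sketched as follows, assuming the residuals are the measured signal minus each sampled fit. The sign convention and the names Vexp and offset_residuals are assumptions for illustration. residuals_offset is simply added to every residual curve, so a suitable offset can place the residuals just below the fitted signals in the same axes.

```python
import numpy as np


def offset_residuals(Vexp, Vs, residuals_offset=0.0):
    """Hypothetical helper: residuals of each sampled fit, shifted for plotting.

    Vexp: measured signal, shape (len(t),).
    Vs: fitted signals, shape (num_samples, len(t)).
    """
    Vexp = np.asarray(Vexp, dtype=float)
    Vs = np.atleast_2d(np.asarray(Vs, dtype=float))
    residuals = Vexp - Vs  # broadcast the data against every sampled fit
    return residuals + residuals_offset  # raise (or lower, if negative) every curve
```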

dive.plotting.plot_correlations(trace: InferenceData, axs: ndarray[Axes] | None = None, var_names: list[str] | None = None, marginals: bool = True, **kwargs) Axes

Plots 2D marginalized posteriors of selected variables.

Shows the pairwise correlations between model parameters.

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • axs (np.ndarray of plt.Axes, optional) – A 2D NumPy array of Matplotlib axes to plot on. If not provided, axes will be created automatically. If provided, the array must have the correct shape for the selected variables.

  • var_names (list of str, optional) – The variables to be plotted. If not provided, a default set of important variables will be selected automatically.

  • marginals (bool, default=True) – Whether or not to also include the 1D marginalized posteriors for each variable.

  • **kwargs (dict, optional) – Keyword arguments to be passed to az.plot_pair.

Returns:

axs

Return type:

plt.Axes

See also

az.plot_pair, summary
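
The 2D marginals show correlations visually; a numeric counterpart can be computed from the same samples. A minimal sketch, assuming each variable is a scalar parameter available as a (chain, draw) array; posterior_correlations is a hypothetical helper, not part of dive's API.

```python
import numpy as np


def posterior_correlations(samples):
    """Pearson correlations between scalar parameters.

    samples: dict mapping a variable name to an array of shape (chain, draw).
    """
    names = list(samples)
    # pool chains and draws so each variable becomes one row of observations
    rows = np.vstack([np.ravel(samples[name]) for name in names])
    return names, np.corrcoef(rows)
```

Entries near +1 or -1 correspond to the narrow, tilted clouds that the pair plots display.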

dive.plotting.plot_hist(trace: InferenceData, var: str, ax: Axes | None = None, combine: bool = False, gauss_id: int = 1, **kwargs) Axes

Plots a histogram of a parameter’s values.

Parameters:
  • trace (az.InferenceData) – The trace to read.

  • var (str) – The variable to plot.

  • ax (plt.Axes, optional) – The Matplotlib axes to plot on. If not provided, axes will be created automatically.

  • combine (bool, default=False) – For the parameters associated with a Gaussian fit (r0, w, a), whether or not to combine the values of all Gaussians into a single histogram. E.g., if var is set to w and combine is True, the widths of all Gaussians will be combined.

  • gauss_id (int, default=1) – For the parameters associated with a Gaussian fit (r0, w, a), which Gaussian to plot. E.g., if var is set to w and gauss_id is set to 2, the widths of the second Gaussian will be plotted. Counting starts at 1. Only used if combine is False.

  • **kwargs (dict, optional) – Keyword arguments to pass to plt.hist.

Returns:

ax

Return type:

plt.Axes

See also

plt.hist
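
The interplay of combine and gauss_id can be sketched as follows. This assumes the values of r0, w, or a are stored as an array of shape (chain, draw, n_gauss); that layout and the helper name select_gaussian_values are assumptions, not dive's documented internals.

```python
import numpy as np


def select_gaussian_values(values, combine=False, gauss_id=1):
    """Hypothetical helper: pick the samples to histogram for a Gaussian parameter.

    values: array of shape (chain, draw, n_gauss) for r0, w, or a.
    """
    values = np.asarray(values)
    if combine:
        return values.reshape(-1)  # pool every Gaussian into one histogram
    if not 1 <= gauss_id <= values.shape[-1]:
        raise ValueError("gauss_id counts from 1 up to the number of Gaussians")
    return values[..., gauss_id - 1].reshape(-1)  # 1-based id -> 0-based index
```

The returned 1D array is what would be handed to plt.hist.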

dive.plotting.plot_marginals(trace: InferenceData, axs: ndarray[Axes] | None = None, var_names: list[str] | None = None, ground_truth: dict[str, float] | None = None, point_estimate: str | None = None, hdi_prob: float = 'hide', **kwargs) Axes

Plots 1D marginalized posteriors of selected variables.

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • axs (np.ndarray of plt.Axes, optional) – A 1D NumPy array of Matplotlib axes to plot on. If not provided, axes will be created automatically. If provided, the array must have the correct size for the selected variables.

  • var_names (list of str, optional) – The variables to be plotted. If not provided, a default set of important variables will be selected automatically.

  • ground_truth (dict of str to float, optional) – A dictionary of ground-truth variable values, which will be plotted as vertical gray lines. Keys should be the variable names and values should be the corresponding ground-truth values.

  • point_estimate (str, optional) – If “mean”, “median”, or “mode” is passed, that point estimate will be plotted. See az.plot_posterior.

  • hdi_prob (float or str, default="hide") – The probability mass of the highest density interval to plot. If set to the default “hide”, no interval will be plotted.

  • **kwargs (dict, optional) – Keyword arguments to be passed to az.plot_posterior.

Returns:

axs

Return type:

plt.Axes

See also

az.plot_posterior, summary

dive.plotting.print_summary(trace: InferenceData, var_names: list[str] | None = None, **kwargs)

Prints a table with summary statistics of important parameters.

The table includes their means, standard deviations, effective sample sizes, Monte Carlo standard errors, and R-hat diagnostic values.

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • var_names (list of str, optional) – The variables to be summarized. If not provided, a default set of important variables will be selected automatically.

  • **kwargs (dict, optional) – Keyword arguments to be passed to az.summary.

See also

az.summary, summary
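
The R-hat column compares between-chain and within-chain variance. As an illustration, here is the classic split R-hat, a refinement of the Gelman-Rubin diagnostic. Recent versions of az.summary report a rank-normalized variant by default, so values from this sketch will not match the table exactly; split_rhat is a hypothetical helper.

```python
import numpy as np


def split_rhat(draws):
    """Classic split R-hat for one scalar parameter.

    draws: array of shape (chain, draw). Each chain is cut in half so that
    trends within a chain also inflate R-hat.
    """
    draws = np.asarray(draws, dtype=float)
    n_half = draws.shape[1] // 2
    # drop a trailing odd draw, then treat each half-chain as its own chain
    halves = np.concatenate([draws[:, :n_half], draws[:, n_half:2 * n_half]], axis=0)
    n = halves.shape[1]
    W = halves.var(axis=1, ddof=1).mean()  # mean within-chain variance
    B = n * halves.mean(axis=1).var(ddof=1)  # between-chain variance
    var_hat = (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
    return float(np.sqrt(var_hat / W))
```

Values close to 1 suggest the chains agree; chains stuck in different regions push R-hat well above 1.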

dive.plotting.summary(trace: InferenceData, var_names: list[str] | None = None, num_samples: int = 100, rng: int | None = None)

Prints a summary table and draws several summary plots.

Calls print_summary, plot_marginals, plot_correlations, plot_V, and plot_P.

Parameters:
  • trace (az.InferenceData) – The trace to be read.

  • var_names (list of str, optional) – The variables to be plotted. If not provided, a default set of important variables will be selected automatically.

  • num_samples (int, default=100) – The number of random samples to draw from the trace.

  • rng (int, optional) – The seed for the random number generator for drawing random samples.