Usage

Installation

Install dive with pip. Note that the PyPI package name is dive-EPR:

pip install dive-EPR

To use dive, import it as usual:

import dive

You will also want to import commonly used supporting modules, such as:

import pymc as pm
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
import deerlab as dl
from scipy.io import loadmat

Reading Data

If you already have experimental traces that you wish to analyze, load them in as usual. Two example data files, 3992_good.dat and 3992_bad.dat, are also included with dive in the /data folder.

For example,

# example data in /data
loaded_data = np.genfromtxt("data/3992_good.dat", skip_header=1, delimiter=',')
t = loaded_data[:,0]     # time axis
Vexp = loaded_data[:,1]  # experimental signal

If you would instead like to create synthetic data, follow this tutorial from DeerLab or use the built-in functions at Reading & Creating Data.
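For example, a minimal sketch of simulating a noisy DEER trace with DeerLab (the distribution parameters, modulation depth, concentration, and noise level below are arbitrary choices for illustration, and exact call signatures may vary between DeerLab versions):

t = np.linspace(-0.1, 3, 250)                # time axis (µs)
r = np.linspace(1.5, 6.5, 51)                # distance axis (nm)
P = dl.dd_gauss(r, 4.0, 0.4)                 # Gaussian distance distribution
lam = 0.5                                    # modulation depth
B = dl.bg_hom3d(t, 50, lam)                  # homogeneous 3D background, 50 µM
K = dl.dipolarkernel(t, r, mod=lam, bg=B)    # kernel including depth and background
Vexp = K @ P + dl.whitegaussnoise(t, 0.02, seed=0)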

Modeling

Next, create a model for PyMC to run numerical sampling on, using dive.model. Detailed information can be found at Modeling.

# a shorter rmax is recommended for better and faster sampling
r = np.linspace(1.5, 6.5, 51)
# non-parametric (Tikhonov regularization):
model_tikh = dive.model(t, Vexp, method="regularization", r=r)
# parametric (Gaussian mixture model):
model_gauss = dive.model(t, Vexp, method="gaussian", n_gauss=2, r=r)

Sampling

You can now sample the model with PyMC through dive.sample. See Sampling for more information.

# non-parametric
trace_tikh = dive.sample(model_tikh, draws=2000, tune=1000, chains=4, random_seed=101)
# parametric
trace_gauss = dive.sample(model_gauss, draws=2000, tune=1000, chains=4, random_seed=101)

The output of dive.sample is an ArviZ InferenceData object.

Saving

Traces can be saved with dive.save_trace and loaded back with dive.load_trace. More information can be found at Saving.

dive.save_trace(trace_tikh, "data/example_trace_tikh")
dive.save_trace(trace_gauss, "data/example_trace_gauss")
trace_tikh, model_tikh = dive.load_trace("data/example_trace_tikh.nc")
trace_gauss, model_gauss = dive.load_trace("data/example_trace_gauss.nc")

Validating

Before moving on to data analysis, it is important to make sure that the chains are stable and converged. Below are a few functions that are helpful in assessing chain convergence and quality:

dive.print_summary(trace): This prints a data table containing information about the marginalized parameter distributions, including their mean, standard deviation, 94% highest density interval, effective sample size, etc. For validating our results, however, we are most interested in the rightmost column, titled r_hat.

r_hat compares the variance between chains to the variance within chains. If it is close to 1, the two are similar, meaning that the chains have converged to similar regions.

r_hat values below 1.05 suggest good convergence; the r_hat for every variable in dive.print_summary should be below this value before you continue to analysis. (If you are just testing things out, it is okay, though not ideal, to have higher values of r_hat.)
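To check this programmatically, you can compute r_hat directly with ArviZ (a minimal sketch; the 1.05 threshold follows the guideline above):

rhats = az.rhat(trace_tikh)  # xarray Dataset with one r_hat per variable
worst = max(float(rhats[v].max()) for v in rhats.data_vars)
print(f"largest r_hat: {worst:.3f}; converged: {worst < 1.05}")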

Documentation can be found in Plotting.

dive.print_summary(trace_tikh)
dive.print_summary(trace_gauss)
                    mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
$λ$                0.533  0.007   0.521    0.546      0.000    0.000    1040.0    1656.0   1.01
$V_0$              1.007  0.010   0.990    1.026      0.000    0.000    1819.0    3126.0   1.00
$σ$                0.021  0.001   0.019    0.023      0.000    0.000    4447.0    3908.0   1.00
$\mathrm{lg}(α)$  -0.462  0.074  -0.597   -0.322      0.002    0.001    2225.0    3735.0   1.00
$B_\mathrm{end}$   0.897  0.015   0.872    0.929      0.000    0.000    1390.0     970.0   1.01

                    mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
$λ$                0.502  0.005   0.495    0.512      0.001    0.001      16.0      70.0   1.19
$V_0$              0.977  0.007   0.963    0.991      0.001    0.001      33.0     191.0   1.14
$σ$                0.021  0.001   0.019    0.022      0.000    0.000      12.0      21.0   1.39
$r_{0,1}$          3.337  0.571   2.156    3.849      0.185    0.142       7.0     123.0   1.56
$r_{0,2}$          3.989  0.076   3.911    4.123      0.037    0.028       5.0      16.0   2.03
$w_1$              1.157  0.908   0.348    2.998      0.265    0.192      16.0     128.0   1.32
$w_2$              0.494  0.108   0.282    0.603      0.052    0.040       6.0      14.0   1.77
$a_1$              0.264  0.259   0.001    0.708      0.126    0.097       6.0      51.0   1.91
$a_2$              0.736  0.259   0.292    0.999      0.126    0.096       6.0      49.0   1.91
$B_\mathrm{end}$   0.850  0.009   0.839    0.866      0.002    0.002      14.0     134.0   1.21

az.plot_trace(trace): This helpful function from the ArviZ library makes two plots for each parameter. On the left, it plots the marginalized posterior of the parameter for each chain (if you set combined to False). This is very helpful for visualizing convergence: if one (or more) chains shows a significantly different posterior, it is likely unconverged. On the right, it plots the value of the parameter over the course of sampling for each chain. If a chain gets 'stuck' (showing the same value for many draws in a row), it may be sampling poorly.

See the documentation for az.plot_trace.

# non-parametric trace is converged:
az.plot_trace(trace_tikh, var_names=["lamb","Bend","V0","sigma","lg_alpha"], combined=False)
# parametric trace is not converged:
az.plot_trace(trace_gauss, var_names=["lamb","Bend","V0","sigma","r0","w","a"], combined=False)
[Figure: arviz-plot-trace.png]

It can be seen that the non-parametric trace is well-converged, while the parametric trace is not. This is probably because the sampler does not know where to put the second Gaussian in the parametric model, as shown by the large uncertainties in its mean (r0), width (w), and amplitude (a) in the ArviZ plot. Changing the model to a 1-Gaussian model would likely help with convergence.
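A sketch of that fix, reusing the calls from above (whether the refit converges will depend on your data):

model_gauss1 = dive.model(t, Vexp, method="gaussian", n_gauss=1, r=r)
trace_gauss1 = dive.sample(model_gauss1, draws=2000, tune=1000, chains=4, random_seed=101)
dive.print_summary(trace_gauss1)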

Question: My chains aren’t converged! What should I do?

Answer: Try the following steps:

1. Increase the number of draws. Most often, you just need to run your trace for longer to attain convergence.

2. Increase the number of tuning steps. Proper tuning is important for good sampling, and increasing the number of tuning steps may help, though going beyond 2000 tuning steps per chain is unlikely to help much.

3. Adjust your model. For example, shortening the maximum value of the r-axis will lead to faster sampling and convergence. You may also want to consider changing the number of Gaussians (if using a parametric model) or changing the resolution of the r-axis.

4. Prune chain(s). If one (or more) chains are particularly unconverged while the rest are decent, you can remove the faulty chain(s). dive.prune_chains(trace) does this automatically: it drops chains one at a time, removing whichever chain reduces r_hat the most, and repeats until the remaining chains are converged. You can also remove chains manually with trace.sel(chain=desired_chain_numbers); see the sketch after this list. Note that pruning chains will not be very effective if all of your chains differ strongly from each other.

5. Set alpha to a fixed value, as sketched below. This can be done by passing an alpha parameter to dive.model(). Removing a parameter may make for better sampling. However, fixing a parameter also limits the validity of the uncertainty analysis, so do this with caution.
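A sketch of steps 4 and 5 (the chain indices and the fixed alpha value below are arbitrary placeholders):

# step 4: drop faulty chains, automatically or by hand
trace_pruned = dive.prune_chains(trace_gauss)
trace_manual = trace_gauss.sel(chain=[0, 2, 3])  # keep chains 0, 2, 3; drop chain 1

# step 5: rebuild the model with a fixed regularization parameter
model_fixed = dive.model(t, Vexp, method="regularization", r=r, alpha=0.3)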

Accessing Results

The output of dive.sample is an ArviZ InferenceData object that consists of three xarray Datasets: posterior, sample_stats, and observed_data.

We are primarily interested in trace.posterior, an xarray Dataset containing all of our parameters as data variables, the chains and draws as coordinates, and some supplementary information as attributes.

print(trace_tikh.posterior)
<xarray.Dataset> Size: 7MB
Dimensions:            (chain: 4, draw: 2000, P_dim_0: 51, P_Dirichlet_dim_0: 51)
Coordinates:
  * chain              (chain) int64 32B 0 1 2 3
  * draw               (draw) int64 16kB 0 1 2 3 4 ... 1995 1996 1997 1998 1999
  * P_Dirichlet_dim_0  (P_Dirichlet_dim_0) int64 408B 0 1 2 3 4 ... 47 48 49 50
  * P_dim_0            (P_dim_0) float64 408B 1.5 1.6 1.7 1.8 ... 6.3 6.4 6.5
Data variables:
    Bend               (chain, draw) float64 64kB 0.8891 0.8917 ... 0.8683
    P                  (chain, draw, P_dim_0) float64 3MB ...
    P_Dirichlet        (chain, draw, P_Dirichlet_dim_0) float64 3MB ...
    V0                 (chain, draw) float64 64kB 1.003 0.9967 ... 1.003 1.01
    delta              (chain, draw) float64 64kB ...
    k                  (chain, draw) float64 64kB ...
    lamb               (chain, draw) float64 64kB 0.5271 0.5239 ... 0.5251
    lg_alpha           (chain, draw) float64 64kB -0.4515 -0.4756 ... -0.6419
    lg_delta           (chain, draw) float64 64kB ...
    sigma              (chain, draw) float64 64kB 0.01939 0.02025 ... 0.02026
    tau                (chain, draw) float64 64kB ...
Attributes: (12/14)
    created_at:                 2024-08-06T20:39:37.362330+00:00
    arviz_version:              0.18.0
    inference_library:          pymc
    inference_library_version:  5.10.4
    sampling_time:              519.4337060451508
    tuning_steps:               1000
    ...                         ...
    include_background:         1
    include_mod_depth:          1
    include_amplitude:          1
    delta_prior:                [1.e+00 1.e-06]
    tau_prior:                  [1.e+00 1.e-04]
    random_seed:                101

The array stores the values for each parameter in the order they were generated, so if we wanted to find the value of lamb for the first draw of the first chain, we would print:

print(trace_tikh.posterior.lamb[0][0])
<xarray.DataArray 'lamb' ()> Size: 8B
array(0.527104)
Coordinates:
    chain    int64 8B 0
    draw     int64 8B 0

Oftentimes, however, this is not a very convenient way to access the posterior distributions. Instead, we can use the function az.extract, which extracts the parameters we want and stacks the chain and draw dimensions into a single sample dimension.

import arviz as az
lamb_array = az.extract(trace_tikh, var_names=["lamb"]) # extract all values of lambda across chains and stack them
print(lamb_array.values) # .values converts the xarray DataArray into a NumPy array
[0.52710417 0.52392079 0.53517045 ... 0.53273982 0.53126788 0.52509982]

With an xarray/NumPy array of your parameters, you can perform your data analysis as usual. ArviZ also provides useful functions for data analysis, including az.hdi, which calculates the highest density interval (the Bayesian analog of a confidence interval) for a parameter.

lamb_hdi = az.hdi(trace_tikh, hdi_prob=0.95, var_names=["lamb"]) # calculates the 95% highest density interval for lambda
print(lamb_hdi)
<xarray.Dataset> Size: 64B
Dimensions:  (hdi: 2)
Coordinates:
  * hdi      (hdi) <U6 48B 'lower' 'higher'
Data variables:
    lamb     (hdi) float64 16B 0.5202 0.5457
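
The same pattern works for vector-valued variables such as the distance distribution P (a brief illustration using the non-parametric trace from above):

P_samples = az.extract(trace_tikh, var_names=["P"]).values  # shape (len(r), n_samples)
P_mean = P_samples.mean(axis=1)  # posterior mean distance distribution
P_std = P_samples.std(axis=1)    # pointwise posterior standard deviation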

More information about InferenceData objects can be found in the ArviZ documentation.

Plotting

dive provides several useful plots for assessing fit and investigating marginalized posterior distributions. Please see Plotting for full documentation. A few functions are outlined below:

dive.plot_V overlays an ensemble of modeled signals on the experimental signal, along with residuals and a corresponding ensemble of background fits.

dive.plot_V(trace_tikh)
dive.plot_V(trace_gauss,show_avg=True,hdi=0.95)
[Figures: plotV-1.png, plotV-2.png]

dive.plot_P plots an ensemble of distance distributions to visualize the uncertainty in P.

dive.plot_P(trace_tikh)
dive.plot_P(trace_gauss,show_avg=True,hdi=0.95,alpha=0.5)
[Figures: plotP-1.png, plotP-2.png]

For the marginal posteriors of the other parameters, you can call dive.plot_marginals for 1D marginalized distributions and dive.plot_correlations for 2D marginalized distributions.

dive.plot_marginals(trace_tikh, var_names=["lamb","Bend","V0","sigma","lg_alpha"], ground_truth={"lamb":0.5,"Bend":np.exp(-0.05*3.2),"V0":1,"sigma":0.02})
dive.plot_marginals(trace_gauss, var_names=["lamb","Bend","V0","sigma","r0","w","a"]) # spiky/uneven plots due to poor convergence
[Figures: marginals-1.png, marginals-2.png]
dive.plot_correlations(trace_tikh)
dive.plot_correlations(trace_gauss,marginals=False)
[Figures: correlations-1.png, correlations-2.png]