Estimate the probability of Bernoulli draws
We estimate a simple model of $n$ independent Bernoulli draws with probability $α$.
First, we import DynamicHMC and related libraries,
using TransformVariables, LogDensityProblems, LogDensityProblemsAD, DynamicHMC,
TransformedLogDensities
then some packages that help code the log posterior,
using Parameters, Statistics, Random, Distributions, LinearAlgebra
then diagnostic tools,
using MCMCDiagnosticTools, DynamicHMC.Diagnostics
and we use ForwardDiff for AD, since the dimension is small.
import ForwardDiff
Then we define a structure to hold the data. For this model, the number of draws equal to 1 is a sufficient statistic,
so we store only it and the total number of draws.
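To see why $(n, s)$ is all the model needs, the following sketch computes the log likelihood draw by draw from a vector of raw 0/1 outcomes and compares it with the $(n, s)$ formula. The raw data, the value 0.4, and the names `loglik_raw` and `x_shuffled` are hypothetical, chosen only for illustration.

```julia
using Random

# Hypothetical raw data: 20 draws from Bernoulli(0.4); 0.4 is an arbitrary choice
rng = Xoshiro(1)
x = rand(rng, 20) .< 0.4        # Vector{Bool}, true means a draw of 1

# The two numbers the model needs
n = length(x)                    # total number of draws
s = count(x)                     # number of draws equal to 1

# Log likelihood computed draw by draw from the raw data
loglik_raw(x, α) = sum(xi ? log(α) : log(1 - α) for xi in x)

# Reordering the data leaves (n, s), and hence the likelihood, unchanged,
# and the raw-data likelihood equals the (n, s) formula used in the model
x_shuffled = shuffle(rng, x)
loglik_raw(x, 0.3), loglik_raw(x_shuffled, 0.3), s * log(0.3) + (n - s) * log(0.7)
```

All three values agree (up to floating-point rounding), which is what sufficiency means here: the likelihood depends on the data only through $n$ and $s$.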
"""
Toy problem using a Bernoulli distribution.
We model `n` independent draws from a ``Bernoulli(α)`` distribution.
"""
struct BernoulliProblem
"Total number of draws in the data."
n::Int
"Number of draws `==1` in the data"
s::Int
end
Main.BernoulliProblem
Then make the type callable with the parameters as a single argument. We extract the parameter inside the function with `@unpack`, but this could also be done by destructuring in the argument list.
function (problem::BernoulliProblem)(θ)
@unpack α = θ # extract the parameters
@unpack n, s = problem # extract the data
# log likelihood: the constant log(combinations(n, s)) term
# has been dropped since it is irrelevant for posterior sampling.
s * log(α) + (n-s) * log(1-α)
end
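As a sketch of the argument-list alternative, Julia 1.7 and later support property destructuring `(; α)` directly in a function signature. The struct is given the hypothetical name `BernoulliProblemDestructured` so it does not clash with `BernoulliProblem`; the log likelihood is the same as above.

```julia
# Hypothetical variant of the model struct, named differently so it does not
# clash with BernoulliProblem above
struct BernoulliProblemDestructured
    n::Int   # total number of draws
    s::Int   # number of draws == 1
end

# `(; α)` in the argument list pulls the property α out of the NamedTuple
# passed in (requires Julia ≥ 1.7)
function (problem::BernoulliProblemDestructured)((; α))
    (; n, s) = problem             # same syntax works for the struct fields
    s * log(α) + (n - s) * log(1 - α)
end

p2 = BernoulliProblemDestructured(20, 10)
p2((α = 0.5,))
```

Both styles behave the same; the choice is mostly about whether the signature or the body should document which parameters are expected.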
We should test this. This would also be a good place to benchmark and optimize more complicated problems.
p = BernoulliProblem(20, 10)
p((α = 0.5, ))
-13.862943611198906
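One quick test of the dropped-constant comment in the model code: the full Binomial log pmf from Distributions should differ from our unnormalized log likelihood by $\log\binom{n}{s}$, the same amount for every $α$. The helper names `loglik` and `dropped` are hypothetical.

```julia
using Distributions

n, s = 20, 10

# Unnormalized log likelihood, as in the model above
loglik(α) = s * log(α) + (n - s) * log(1 - α)

# Difference between the full Binomial log pmf and the unnormalized version
dropped(α) = logpdf(Binomial(n, α), s) - loglik(α)

# The difference is the same for every α: it is the constant log C(n, s)
dropped(0.2), dropped(0.5), dropped(0.7), log(binomial(n, s))
```

Because the difference does not depend on $α$, dropping it changes neither the shape of the posterior nor its gradient.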
Recall that we need to transform from $ℝ$ to the valid parameter domain $(0,1)$ for more efficient sampling, and calculate the derivatives for this transformed mapping. The helper packages TransformVariables and LogDensityProblems (together with TransformedLogDensities and LogDensityProblemsAD) take care of this. We use a flat prior on $α$.
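The following sketch writes out by hand what the transformed log density amounts to, assuming that `as𝕀` uses the logistic map $σ(x) = 1/(1+e^{-x})$ and that the transformed density adds the log Jacobian of that map. The names `σ`, `logjac`, `loglik` and `ℓ` are hypothetical, not part of TransformVariables.

```julia
# σ maps ℝ to (0, 1); we assume here that as𝕀 uses this logistic map
σ(x) = 1 / (1 + exp(-x))

# log of the Jacobian dα/dx = σ(x)(1 - σ(x)), via log σ(x) = -log1p(exp(-x))
# and log(1 - σ(x)) = -log1p(exp(x))
logjac(x) = -log1p(exp(-x)) - log1p(exp(x))

# Unnormalized log likelihood of α (flat prior), with n = 20 and s = 10
loglik(α; n = 20, s = 10) = s * log(α) + (n - s) * log(1 - α)

# Log density on the unconstrained scale: the Jacobian term keeps the prior
# flat in α rather than in x
ℓ(x) = loglik(σ(x)) + logjac(x)

σ(0.0), logjac(0.0), ℓ(0.0)
```

Without the Jacobian term, a flat density in $x$ would correspond to a non-uniform prior on $α$.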
t = as((α = as𝕀,))
P = TransformedLogDensity(t, p)
∇P = ADgradient(:ForwardDiff, P);
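With the objects above, `LogDensityProblems.logdensity_and_gradient(∇P, x)` returns the log density and its gradient at a vector `x`. As a self-contained sanity check of the AD step, this sketch differentiates the hand-written unconstrained log density (again assuming the logistic map for `as𝕀`) with ForwardDiff and compares it with the analytic derivative. The names `grad_ad` and `grad_exact` are hypothetical.

```julia
import ForwardDiff

σ(x) = 1 / (1 + exp(-x))

# Unconstrained log density: likelihood at α = σ(x) plus the log Jacobian
# log σ(x) + log(1 - σ(x))
ℓ(x; n = 20, s = 10) = (s + 1) * log(σ(x)) + (n - s + 1) * log(1 - σ(x))

# Derivative by forward-mode AD
grad_ad(x) = ForwardDiff.derivative(ℓ, x)

# Hand-derived derivative: d/dx log σ = 1 - σ and d/dx log(1 - σ) = -σ
grad_exact(x; n = 20, s = 10) = (s + 1) * (1 - σ(x)) - (n - s + 1) * σ(x)

grad_ad(0.7), grad_exact(0.7)
```

With $s = 10$ and $n = 20$ the gradient vanishes at $x = 0$, i.e. at $α = 0.5$, the posterior mode.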
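Finally, we sample from the posterior, running five independent chains of 1000 draws each. Each returned value contains the posterior matrix, diagnostic information (the tree statistics), and the tuned sampler parameters (the kinetic energy κ and the step size ϵ), which would allow sampling to be continued.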
results = [mcmc_with_warmup(Random.default_rng(), ∇P, 1000) for _ in 1:5]
5-element Vector{NamedTuple{(:posterior_matrix, :tree_statistics, :κ, :ϵ), Tuple{Matrix{Float64}, Vector{DynamicHMC.TreeStatisticsNUTS}, DynamicHMC.GaussianKineticEnergy{LinearAlgebra.Diagonal{Float64, Vector{Float64}}, LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, Float64}}}:
(posterior_matrix = [0.44118681279422206 0.23551691444796968 … -0.02193842919207828 -0.023577547067679505], tree_statistics = [DynamicHMC.TreeStatisticsNUTS(…), …], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.46808504301534504], ϵ = 0.9412522654911162)
(posterior_matrix = [0.46538571409351054 -0.22176448268025659 … -0.5435933171194245 -0.7451541149362756], tree_statistics = [DynamicHMC.TreeStatisticsNUTS(…), …], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.3934408787742183], ϵ = 1.0255312560269043)
(posterior_matrix = [-0.21537802814382748 0.11486200754446987 … -0.4161992723045449 -0.0645791364781495], tree_statistics = [DynamicHMC.TreeStatisticsNUTS(…), …], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.4351431194774923], ϵ = 1.0146261223731448)
(posterior_matrix = [-0.18657649141505123 -0.03496241147764012 … -0.0870733649363256 0.05801258276415713], tree_statistics = [DynamicHMC.TreeStatisticsNUTS(…), …], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.37585792746486696], ϵ = 1.0918130254022915)
(posterior_matrix = [0.31522771100623675 0.4658084029256925 … -0.213940713705729 0.7283698333295582], tree_statistics = [DynamicHMC.TreeStatisticsNUTS(…), …], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.46103321971273487], ϵ = 0.9605292801919004)
To get the posterior for $α$, we pool the posterior matrices of all chains, take their columns (each column is one draw in the unconstrained space), and transform each column back with t.
posterior = transform.(t, eachcol(pool_posterior_matrices(results)));
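The pooling step can be sketched without DynamicHMC, assuming that pooling concatenates the chains' matrices column-wise (one row per unconstrained parameter, one column per draw) and that `as𝕀` applies the logistic map. The random stand-in matrices are hypothetical and are not samples from this posterior.

```julia
using Random

# Hypothetical stand-ins for five chains' posterior matrices: each is
# 1 × 1000 (one row per unconstrained parameter, one column per draw)
rng = Xoshiro(42)
chains = [randn(rng, 1, 1000) for _ in 1:5]

# Pooling, as assumed here: concatenate the draws of all chains column-wise
pooled = reduce(hcat, chains)            # 1 × 5000

σ(x) = 1 / (1 + exp(-x))
# Map each column (one unconstrained draw) back to α ∈ (0, 1)
α_draws = [σ(col[1]) for col in eachcol(pooled)]
```

Pooling discards which chain a draw came from, which is fine for posterior summaries but not for convergence diagnostics; those use the stacked (draws × chains × parameters) layout instead.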
Extract the parameter.
posterior_α = first.(posterior);
Check the mean.
mean(posterior_α)
0.49883700700811895
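Because a flat prior on $(0,1)$ is a $\mathrm{Beta}(1,1)$ distribution, which is conjugate to the Bernoulli likelihood, this model has an exact posterior, $\mathrm{Beta}(s+1, n-s+1)$, that the sampled mean can be compared against. The sketch below uses Distributions; `posterior_exact` is a hypothetical name.

```julia
using Distributions

n, s = 20, 10

# Flat prior Beta(1, 1) is conjugate to the Bernoulli likelihood,
# so the posterior is Beta(s + 1, n - s + 1)
posterior_exact = Beta(s + 1, n - s + 1)

mean(posterior_exact)   # (s + 1) / (n + 2)
std(posterior_exact)
```

The exact mean is $11/22 = 0.5$ and the standard deviation is about 0.104. With an effective sample size around 1700, the Monte Carlo standard error of the sampled mean is roughly $0.104/\sqrt{1700} \approx 0.0025$, so a sampled mean of 0.4988 is well within the expected error.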
Check the effective sample size and the potential scale reduction factor R̂.
ess, R̂ = ess_rhat(stack_posterior_matrices(results))
(ess = [1736.8088164641554], rhat = [1.0019327641915297])
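To illustrate what R̂ measures, here is the classic (non-split) Gelman-Rubin statistic written by hand. This is a simplified sketch, not what MCMCDiagnosticTools computes (its `ess_rhat` uses a more robust, rank-normalized split variant); `basic_rhat` and the simulated chains are hypothetical.

```julia
using Random, Statistics

# Classic (non-split) Gelman-Rubin R̂ for a draws × chains matrix
function basic_rhat(x::AbstractMatrix)
    N, M = size(x)                       # draws per chain, number of chains
    chain_means = vec(mean(x; dims = 1))
    W = mean(vec(var(x; dims = 1)))      # average within-chain variance
    B = N * var(chain_means)             # between-chain variance
    var_plus = (N - 1) / N * W + B / N   # pooled estimate of the target variance
    sqrt(var_plus / W)
end

rng = Xoshiro(7)
well_mixed = randn(rng, 1000, 5)             # all chains share one target
stuck = randn(rng, 1000, 5) .+ [0 0 0 0 3]   # last chain shifted by 3

basic_rhat(well_mixed), basic_rhat(stuck)
```

For well-mixed chains R̂ is very close to 1, as in the output above; a chain stuck in a different region inflates the between-chain variance and pushes R̂ well above 1.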
NUTS-specific statistics of the first chain
summarize_tree_statistics(results[1].tree_statistics)
Hamiltonian Monte Carlo sample of length 1000
acceptance rate mean: 0.92, 5/25/50/75/95%: 0.67 0.9 0.97 1.0 1.0
termination: divergence => 0%, max_depth => 0%, turning => 100%
depth: 0 => 0%, 1 => 65%, 2 => 35%
This page was generated using Literate.jl.