diff --git a/pymc/step_methods/__init__.py b/pymc/step_methods/__init__.py index 733eed5ed6..194536e470 100644 --- a/pymc/step_methods/__init__.py +++ b/pymc/step_methods/__init__.py @@ -15,7 +15,7 @@ """Step methods.""" from pymc.step_methods.compound import BlockedStep, CompoundStep -from pymc.step_methods.hmc import NUTS, HamiltonianMC +from pymc.step_methods.hmc import NUTS, WALNUTS, HamiltonianMC from pymc.step_methods.metropolis import ( BinaryGibbsMetropolis, BinaryMetropolis, @@ -35,6 +35,7 @@ # Other step methods can be added by appending to this list STEP_METHODS: list[type[BlockedStep]] = [ NUTS, + WALNUTS, HamiltonianMC, Metropolis, BinaryMetropolis, diff --git a/pymc/step_methods/hmc/__init__.py b/pymc/step_methods/hmc/__init__.py index e51cef7784..432aa4ea5b 100644 --- a/pymc/step_methods/hmc/__init__.py +++ b/pymc/step_methods/hmc/__init__.py @@ -16,3 +16,4 @@ from pymc.step_methods.hmc.hmc import HamiltonianMC from pymc.step_methods.hmc.nuts import NUTS +from pymc.step_methods.hmc.walnuts import WALNUTS diff --git a/pymc/step_methods/hmc/adaptive_integrators.py b/pymc/step_methods/hmc/adaptive_integrators.py new file mode 100644 index 0000000000..5a638d99f5 --- /dev/null +++ b/pymc/step_methods/hmc/adaptive_integrators.py @@ -0,0 +1,625 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adaptive integrators for WALNUTS sampler. + +Based on adaptiveIntegrators.py from WALNUTSpy by Tore Selland Kleppe. 
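+
+All integrators share the calling convention ``integrator(q, v, g, Ham0, h, xi,
+lpFun, delta, auxPar)``: position ``q``, velocity ``v``, log-density gradient
+``g``, Hamiltonian ``Ham0`` at the starting point, macro step size ``h``,
+integration direction ``xi`` (+1 forward, -1 backward), a callable ``lpFun``
+returning the log density and its gradient, the local energy-error tolerance
+``delta``, and auxiliary tuning parameters ``auxPar``. Each returns an
+``integratorReturn`` instance summarizing both the forward pass and the
+backward (reversibility-check) pass.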
+""" + +import sys + +import numpy as np + +from pymc.step_methods.hmc.walnuts_constants import __logZero + + +class integratorReturn: + """Common return object for integrators.""" + + def __init__(self, q1, v1, lp1, grad1, nEvalF, nEvalB, If, Ib, c, lwt, igrConst): + self.q = q1 + self.v = v1 + self.lp = lp1 + self.grad = grad1 + self.nEvalF = nEvalF + self.nEvalB = nEvalB + self.If = If + self.Ib = Ib + self.c = c + self.lwt = lwt + self.igrConst = igrConst + + def __str__(self): + return ( + "dim: " + + str(len(self.q)) + + " If: " + + str(self.If) + + " Ib: " + + str(self.Ib) + + " c: " + + str(self.c) + ) + + +class integratorAuxPar: + """Class for passing assorted tuning parameters to the integrators.""" + + def __init__( + self, + minC=0, + maxC=10, + R2Pprob0=2.0 / 3.0, + maxFPiter=30, + FPtol=1.0e-8, + FPNewton=False, + rescaledGradThresh=5.0, + ): + self.minC = minC + self.maxC = maxC + self.R2Pprob0 = R2Pprob0 + self.maxFPiter = maxFPiter + self.FPtol = FPtol + self.FPNewton = FPNewton + self.rescaledGradThresh = rescaledGradThresh + + +def fixedLeapFrog(q, v, g, Ham0, h, xi, lpFun, delta, auxPar): + """Implement basic leapfrog integration step.""" + vh = xi * v + 0.5 * h * g + qq = q + h * vh + fnew, gnew = lpFun(qq) + vv = vh + 0.5 * h * gnew + + H1 = -fnew + 0.5 * sum(vv * vv) + + return integratorReturn( + qq, + xi * vv, + fnew, + gnew, + 1, + 0, + 0, + 0, + 0, + 0.0, + h * (max(1.0e-10, abs(Ham0 - H1)) ** (-1.0 / 3.0)), + ) + + +def adaptLeapFrogD(q, v, g, Ham0, h, xi, lpFun, delta, auxPar): + """Adaptive Leap Frog with Deterministic choice of precision parameter.""" + nEvalF = 0 + If = auxPar.maxC + for c in range(auxPar.minC, auxPar.maxC + 1): + nstep = 2**c + hh = h / nstep + qq = q + vv = xi * v + gg = g + Hams = np.zeros(nstep + 1) + Hams[0] = Ham0 + + for i in range(1, nstep + 1): + vh = vv + 0.5 * hh * gg + qq = qq + hh * vh + fnew, gg = lpFun(qq) + nEvalF += 1 + vv = vh + 0.5 * hh * gg + Hams[i] = -fnew + 0.5 * sum(vv * vv) + + maxErr = abs(Hams[0] - Hams[-1]) + + if all(np.isfinite(Hams)) and maxErr < delta: + If = c + break + + qOut = qq + vOut = vv + fOut = fnew + gOut = gg + + igrConst = hh * (np.max(np.abs(np.diff(Hams))) ** (-1.0 / 3.0)) + + Ib = If + nEvalB = 0 + + if If > auxPar.minC: + H0b = Hams[-1] + for c in range(auxPar.minC, If): + nstep = 2**c + hh = h / nstep + qq = qOut + vv = -vOut + gg = gOut + Hams = np.zeros(nstep + 1) + Hams[0] = H0b + + for i in range(1, nstep + 1): + vh = vv + 0.5 * hh * gg + qq = qq + hh * vh + fnew, gg = lpFun(qq) + nEvalB += 1 + vv = vh + 0.5 * hh * gg + Hams[i] = -fnew + 0.5 * sum(vv * vv) + + maxErr = abs(Hams[0] - Hams[-1]) + if all(np.isfinite(Hams)) and maxErr < delta: + Ib = c + break + + return integratorReturn( + qOut, xi * vOut, fOut, gOut, nEvalF, nEvalB, If, Ib, If, (If != Ib) * __logZero, igrConst + ) + + +def adaptLeapFrogFlowD(q, v, g, Ham0, h, xi, lpFun, delta, auxPar): + """Adaptive Leap Frog with Deterministic choice of precision parameter and flow-based error criterion.""" + nEvalF = 0 + If = auxPar.maxC + for c in range(0, auxPar.maxC + 1): + nstep = 2**c + hh = h / nstep + qq = q + vv = xi * v + gg = g + Hams = np.zeros(nstep + 1) + Errs = np.zeros(nstep) + Hams[0] = Ham0 + + for i in range(1, nstep + 1): + vh = vv + 0.5 * hh * gg + qold = qq + gold = gg + vold = vv + qq = qq + hh * vh + fnew, gg = lpFun(qq) + nEvalF += 1 + vv = vh + 0.5 * hh * gg + + qMid = 0.5 * (qq + qold) + (hh / 8.0) * (vold - vv) + fMid, gMid = lpFun(qMid) + nEvalF += 1 + + qf = qold + hh * vold + hh * hh * ((1.0 / 6.0) * gold + 
(1.0 / 3.0) * gMid) + err = np.max(np.abs(qf - qq)) + + vf = vold + (hh / 6.0) * (gold + gg + 4.0 * gMid) + err = max(err, np.max(np.abs(vf - vv))) + + qb = qq - hh * vv + hh * hh * ((1.0 / 6.0) * gg + (1.0 / 3.0) * gMid) + err = max(err, np.max(np.abs(qb - qold))) + + vb = -(-(vv + (hh / 6.0) * (gold + gg + 4.0 * gMid))) + + err = max(err, np.max(np.abs(vb - vold))) + Errs[i - 1] = err + Hams[i] = -fnew + 0.5 * sum(vv * vv) + + maxErr = np.max(Errs) + + if all(np.isfinite(Hams)) and maxErr < delta: + If = c + break + + qOut = qq + vOut = vv + fOut = fnew + gOut = gg + + igrConst = hh * (np.max(np.abs(np.diff(Hams))) ** (-1.0 / 3.0)) + + Ib = If + nEvalB = 0 + + if If > 0: + H0b = Hams[-1] + for c in range(0, If): + nstep = 2**c + hh = h / nstep + qq = qOut + vv = -vOut + gg = gOut + Hams = np.zeros(nstep + 1) + Hams[0] = H0b + Errs = np.zeros(nstep) + + for i in range(1, nstep + 1): + vh = vv + 0.5 * hh * gg + qold = qq + gold = gg + vold = vv + qq = qq + hh * vh + fnew, gg = lpFun(qq) + nEvalB += 1 + vv = vh + 0.5 * hh * gg + + qMid = 0.5 * (qq + qold) + (hh / 8.0) * (vold - vv) + fMid, gMid = lpFun(qMid) + nEvalB += 1 + qf = qold + hh * vold + hh * hh * ((1.0 / 6.0) * gold + (1.0 / 3.0) * gMid) + err = np.max(np.abs(qf - qq)) + vf = vold + (hh / 6.0) * (gold + gg + 4.0 * gMid) + err = max(err, np.max(np.abs(vf - vv))) + qb = qq - hh * vv + hh * hh * ((1.0 / 6.0) * gg + (1.0 / 3.0) * gMid) + err = max(err, np.max(np.abs(qb - qold))) + vb = -(-(vv + (hh / 6.0) * (gold + gg + 4.0 * gMid))) + + err = max(err, np.max(np.abs(vb - vold))) + Errs[i - 1] = err + Hams[i] = -fnew + 0.5 * sum(vv * vv) + + maxErr = np.max(Errs) + if all(np.isfinite(Hams)) and maxErr < delta: + Ib = c + break + + return integratorReturn( + qOut, xi * vOut, fOut, gOut, nEvalF, nEvalB, If, Ib, If, (If != Ib) * __logZero, igrConst + ) + + +def adaptLeapFrogR2P(q, v, g, Ham0, h, xi, lpFun, delta, auxPar): + """Adaptive Leap Frog with Randomized-to-Probabilistic choice.""" + nEvalF = 0 + If = auxPar.maxC + for c in range(auxPar.minC, auxPar.maxC + 1): + nstep = 2**c + hh = h / nstep + qq = q + vv = xi * v + gg = g + Hams = np.zeros(nstep + 1) + Hams[0] = Ham0 + + for i in range(1, nstep + 1): + vh = vv + 0.5 * hh * gg + qq = qq + hh * vh + fnew, gg = lpFun(qq) + nEvalF += 1 + vv = vh + 0.5 * hh * gg + Hams[i] = -fnew + 0.5 * sum(vv * vv) + + maxErr = abs(Hams[0] - Hams[-1]) + + if all(np.isfinite(Hams)) and maxErr < delta: + If = c + break + + if np.random.uniform() < auxPar.R2Pprob0: + # simulation occur at minimal accepted precision + qOut = qq + vOut = vv + fOut = fnew + gOut = gg + cSim = If + igrConst = hh * (np.max(np.abs(np.diff(Hams))) ** (-1.0 / 3.0)) + else: + # simulation occur at minimal + 1 + c = If + 1 + nstep = 2**c + hh = h / nstep + qq = q + vv = xi * v + gg = g + Hams = np.zeros(nstep + 1) + Hams[0] = Ham0 + + for i in range(1, nstep + 1): + vh = vv + 0.5 * hh * gg + qq = qq + hh * vh + fnew, gg = lpFun(qq) + nEvalF += 1 + vv = vh + 0.5 * hh * gg + Hams[i] = -fnew + 0.5 * sum(vv * vv) + + qOut = qq + vOut = vv + fOut = fnew + gOut = gg + cSim = If + 1 + igrConst = hh * (np.max(np.abs(np.diff(Hams))) ** (-1.0 / 3.0)) + + # done forward simulation pass, now do backward simulations + nEvalB = 0 + + if cSim == If: + maxTry = If - 1 + Ib = If + lwtf = np.log(auxPar.R2Pprob0) + else: + maxTry = auxPar.maxC + Ib = auxPar.maxC + lwtf = np.log(1.0 - auxPar.R2Pprob0) + + if maxTry >= auxPar.minC: + H0b = Hams[-1] + for c in range(auxPar.minC, maxTry + 1): + nstep = 2**c + hh = h / nstep + qq = qOut + vv = -vOut + 
gg = gOut + Hams = np.zeros(nstep + 1) + Hams[0] = H0b + + for i in range(1, nstep + 1): + vh = vv + 0.5 * hh * gg + qq = qq + hh * vh + fnew, gg = lpFun(qq) + nEvalB += 1 + vv = vh + 0.5 * hh * gg + Hams[i] = -fnew + 0.5 * sum(vv * vv) + + maxErr = abs(Hams[0] - Hams[-1]) + if all(np.isfinite(Hams)) and maxErr < delta: + Ib = c + break + + # done backward simulation pass, now work out backward probability + lwtb = __logZero + if cSim == Ib: + lwtb = np.log(auxPar.R2Pprob0) + elif cSim == Ib + 1: + lwtb = np.log(1.0 - auxPar.R2Pprob0) + + return integratorReturn( + qOut, xi * vOut, fOut, gOut, nEvalF, nEvalB, If, Ib, cSim, lwtb - lwtf, igrConst + ) + + +def adaptImplicitMidpointD(q, v, g, Ham0, h, xi, lpFun, delta, auxPar): + """Adaptive Implicit Midpoint integrator.""" + nEvalF = 0 + If = auxPar.maxC + for c in range(0, auxPar.maxC + 1): + nstep = 2**c + hh = h / nstep + qq = q + vv = xi * v + gg = g + Hams = np.zeros(nstep + 1) + Hams[0] = Ham0 + numCompleted = 0 + for i in range(1, nstep + 1): + # initial guess based on leap frog + qt = qq + hh * (vv + 0.5 * hh * gg) + + # controls for fixed point iterations + converged = False + oldMaxErr = 1.0e100 + + for iter in range(1, auxPar.maxFPiter + 1): + mpq = 0.5 * (qt + qq) + if auxPar.FPNewton: + fmp, gmp, Hmp = lpFun(mpq, True) + HH = 0.25 * hh * hh * Hmp - np.identity(len(gmp)) + qtNew = qt - np.linalg.solve(HH, qq + hh * vv + (0.5 * hh * hh) * gmp - qt) + else: + fmp, gmp = lpFun(mpq) + qtNew = qq + hh * vv + (0.5 * hh * hh) * gmp + + nEvalF += 1 + maxErr = np.max(np.abs(qtNew - qt)) + qt = qtNew + if maxErr < auxPar.FPtol: + converged = True + break + + if maxErr > 1.1 * oldMaxErr: + break + oldMaxErr = maxErr + + if not converged: + break + + # step used and evaluation at mesh times + mpq = 0.5 * (qt + qq) + + fmp, gmp = lpFun(mpq) + nEvalF += 1 + qq = qq + hh * vv + (0.5 * hh * hh) * gmp + vv = vv + hh * gmp + + fnew, gg = lpFun(qq) + nEvalF += 1 + Hams[i] = -fnew + 0.5 * sum(vv * vv) + numCompleted += 1 + + maxHErr = abs(Hams[0] - Hams[-1]) + if all(np.isfinite(Hams)) and maxHErr < delta and numCompleted == nstep: + If = c + break + + if not converged: + import warnings + + warnings.warn("Numerical problems in adaptImplicitMidpoint, consider increasing maxC") + sys.exit() + + qOut = qq + vOut = vv + fOut = fnew + gOut = gg + + igrConst = hh * (np.max(np.abs(np.diff(Hams))) ** (-1.0 / 3.0)) + + Ib = auxPar.maxC + nEvalB = 0 + + Hb0 = -fOut + 0.5 * sum(vOut * vOut) + for c in range(0, auxPar.maxC + 1): + nstep = 2**c + hh = h / nstep + qq = qOut + vv = -vOut + gg = gOut + Hams = np.zeros(nstep + 1) + Hams[0] = Hb0 + numCompleted = 0 + for i in range(1, nstep + 1): + # initial guess based on leap frog + qt = qq + hh * (vv + 0.5 * hh * gg) + + # controls for fixed point iterations + converged = False + oldMaxErr = 1.0e100 + + for iter in range(1, auxPar.maxFPiter + 1): + mpq = 0.5 * (qt + qq) + if auxPar.FPNewton: + fmp, gmp, Hmp = lpFun(mpq, True) + HH = 0.25 * hh * hh * Hmp - np.identity(len(gmp)) + qtNew = qt - np.linalg.solve(HH, qq + hh * vv + (0.5 * hh * hh) * gmp - qt) + else: + fmp, gmp = lpFun(mpq) + qtNew = qq + hh * vv + (0.5 * hh * hh) * gmp + + nEvalB += 1 + maxErr = np.max(np.abs(qtNew - qt)) + qt = qtNew + + if maxErr < auxPar.FPtol: + converged = True + break + + if maxErr > 1.1 * oldMaxErr: + break + oldMaxErr = maxErr + + if not converged: + break + + # step used + mpq = 0.5 * (qt + qq) + fmp, gmp = lpFun(mpq) + nEvalB += 1 + qq = qq + hh * vv + (0.5 * hh * hh) * gmp + vv = vv + hh * gmp + + fnew, gg = lpFun(qq) + 
nEvalB += 1 + Hams[i] = -fnew + 0.5 * sum(vv * vv) + numCompleted += 1 + + maxHErr = abs(Hams[0] - Hams[-1]) + if all(np.isfinite(Hams)) and maxHErr < delta and numCompleted == nstep: + Ib = c + break + + return integratorReturn( + qOut, xi * vOut, fOut, gOut, nEvalF, nEvalB, If, Ib, If, (If != Ib) * __logZero, igrConst + ) + + +def rescaledLeapFrogTest(q, v, g, Ham0, h, xi, lpFun, delta, auxPar): + """Test rescaled leapfrog integrator.""" + d = len(q) + Sd = np.exp(np.random.normal(scale=0.5, size=d)) + + vv = xi * v + qb = q / Sd + gb = Sd * g + vh = vv + 0.5 * h * gb + qbn = qb + h * vh + fout, g = lpFun(qbn * Sd) + vout = vh + 0.5 * h * Sd * g + return integratorReturn(qbn * Sd, xi * vout, fout, g, 1, 0, 0, 0, 0, 0, 1.0) + + +def adaptRescaledLeapFrogD(q, v, g, Ham0, h, xi, lpFun, delta, auxPar): + """Adaptive Rescaled Leap Frog with Deterministic choice.""" + gradThresh = auxPar.rescaledGradThresh + d = len(q) + Sd = np.ones(d) + Sred = np.zeros(d, dtype=np.int64) + vv = xi * v + If = auxPar.maxC + nEvalF = 0 + nEvalB = 0 + for c in range(0, auxPar.maxC + 1): + qb = q / Sd + gb = Sd * g + vh = vv + 0.5 * h * gb + qbn = qb + h * vh + q1 = qbn * Sd + ff, gnew = lpFun(q1) + nEvalF += 1 + gb1 = Sd * gnew + v1 = vh + 0.5 * h * gb1 + gbmean = 0.5 * (np.abs(gb) + np.abs(gb1)) + Ham1 = -ff + 0.5 * sum(v1 * v1) + + grTooBig = gbmean > gradThresh + + if not np.isfinite(Ham1): + Sred += 1 + elif np.any(grTooBig): + Sred[grTooBig] += 1 + elif np.abs(Ham0 - Ham1) > delta: + Sred += 1 + else: + If = c + break + + Sd = 2.0 ** (-Sred) + + qOut = q1 + vOut = v1 + fOut = ff + gOut = gnew + SredForw = Sred + + Hb0 = Ham1 + + Sd = np.ones(d) + Sred = np.zeros(d, dtype=np.int64) + vv = -vOut + Ib = If + + if If > 0: + for c in range(0, auxPar.maxC + 1): + qb = qOut / Sd + gb = Sd * gOut + vh = vv + 0.5 * h * gb + qbn = qb + h * vh + q1 = qbn * Sd + ff, gnew = lpFun(q1) + nEvalB += 1 + gb1 = Sd * gnew + v1 = vh + 0.5 * h * gb1 + gbmean = 0.5 * (np.abs(gb) + np.abs(gb1)) + Ham1 = -ff + 0.5 * sum(v1 * v1) + + grTooBig = gbmean > gradThresh + + if not np.isfinite(Ham1): + Sred += 1 + elif np.any(grTooBig): + Sred[grTooBig] += 1 + elif np.abs(Hb0 - Ham1) > delta: + Sred += 1 + else: + Ib = c + break + + if np.all(SredForw == Sred): + Ib = c + 1 + break + Sd = 2.0 ** (-Sred) + + lpw = __logZero * (not np.all(Sred == SredForw)) + + return integratorReturn(qOut, xi * vOut, fOut, gOut, nEvalF, nEvalB, If, Ib, If, lpw, 1.0) diff --git a/pymc/step_methods/hmc/p2_quantile.py b/pymc/step_methods/hmc/p2_quantile.py new file mode 100644 index 0000000000..3870b08bc2 --- /dev/null +++ b/pymc/step_methods/hmc/p2_quantile.py @@ -0,0 +1,104 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Online estimation/approximation of quantiles using the P-squared algorithm. + +Based on P2quantile.py from WALNUTSpy by Tore Selland Kleppe. +Implementation of Jain and Chlamtac, Communications of the ACM, 28(10) 1985. 
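+
+The estimator maintains five markers tracking the minimum, the p/2, p, and
+(1 + p)/2 quantiles, and the maximum of the data stream. As observations
+arrive, marker positions are advanced and marker heights are adjusted by
+piecewise-parabolic (hence "P-squared") interpolation, so a running quantile
+estimate is available in O(1) memory without storing the observations.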
+""" + +import numpy as np + + +class P2quantile: + def __init__(self, prob=0.5): + self.npush = 0 + self.p = prob + self.x = np.zeros(5) + self.q = np.zeros(5) + self.n = np.r_[1:6] + self.npp = np.array([1.0, 1.0 + 2.0 * prob, 1.0 + 4.0 * prob, 3.0 + 2.0 * prob, 5.0]) + + def quantile(self): + return self.q[2] + + def findInterval(self, xi): + if xi < self.q[0]: + return 0 + elif xi > self.q[4]: + return 5 + for i in range(0, 4): + if xi < self.q[i + 1]: + return i + 1 + + def pushCore(self, xi): + self.npush += 1 + if self.npush <= 5: + self.x[self.npush - 1] = xi + + if self.npush == 5: + self.x = np.sort(self.x) + self.q = self.x + elif self.npush > 5: + k = self.findInterval(xi) + + if k == 0: + self.q[0] = xi + k = 1 + elif k == 5: + self.q[4] = xi + k = 4 + + self.n[k:5] += 1 + + nn = self.npush + pp = self.p + self.npp = np.array( + [ + 1.0, + 0.5 * (nn - 1) * pp + 1.0, + (nn - 1) * pp + 1.0, + (nn - 1) * (1 + pp) / 2.0 + 1, + nn, + ] + ) + + for i in range(2, 5): + ni = self.n[i - 1] + nip = self.n[i] + nim = self.n[i - 2] + di = self.npp[i - 1] - ni + + if (di >= 1.0 and nip - ni > 1) or (di <= -1.0 and nim - ni < -1): + di = np.sign(di).astype(np.int64) + + qi = self.q[i - 1] + qip = qi + (di / (nip - nim)) * ( + (ni - nim + di) * (self.q[i] - qi) / (nip - ni) + + (nip - ni - di) * (qi - self.q[i - 2]) / (ni - nim) + ) + if self.q[i - 2] < qip and qip < self.q[i]: + self.q[i - 1] = qip + else: + self.q[i - 1] = qi + di * (self.q[i + di - 1] - qi) / ( + self.n[i + di - 1] - self.n[i - 1] + ) + self.n[i - 1] += di + + def pushVec(self, x): + for val in x: + self.pushCore(val) + + def push(self, x): + self.pushCore(x) diff --git a/pymc/step_methods/hmc/walnuts.py b/pymc/step_methods/hmc/walnuts.py new file mode 100644 index 0000000000..87654d089f --- /dev/null +++ b/pymc/step_methods/hmc/walnuts.py @@ -0,0 +1,845 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +WALNUTS (Within-orbit Adaptive Step-length No-U-Turn Sampler) implementation. + +Based on WALNUTSpy by Tore Selland Kleppe. +Reference: Bou-Rabee, N., Carpenter, B., Kleppe, T. S., & Liu, S. (2025). +The Within-Orbit Adaptive Leapfrog No-U-Turn Sampler. +arXiv preprint arXiv:2506.18746. 
+""" + +from __future__ import annotations + +import math +import sys + +import numpy as np + +from pymc.stats.convergence import SamplerWarning +from pymc.step_methods.compound import Competence +from pymc.step_methods.hmc.adaptive_integrators import fixedLeapFrog, integratorAuxPar +from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData +from pymc.step_methods.hmc.p2_quantile import P2quantile +from pymc.step_methods.hmc.walnuts_constants import __logZero as _logZero +from pymc.step_methods.hmc.walnuts_constants import __wtSumThresh as _wtSumThresh +from pymc.vartypes import continuous_types + +__all__ = ["WALNUTS"] + + +# Sequence of stop-checks during NUTS iteration (precomputed) +checks = None +k = 0 + + +def subTreePlan(nleaf): + """Create sequence of U-turn checks for tree building.""" + global checks + global k + checks = np.zeros((nleaf - 1, 2), dtype=int) + k = 0 + + def Uturn(a, b): + global k + global checks + checks[k, 0] = a + checks[k, 1] = b + k += 1 + + def subUturn(a, b): + if a != b: + m = math.floor((a + b) / 2) + subUturn(a, m) + subUturn(m + 1, b) + Uturn(a, b) + + subUturn(1, nleaf) + return checks + + +class stateStore: + """Utilities for storing states (avoiding storing the complete orbit).""" + + def __init__(self, d, n): + self.__stateStack = np.zeros((d, n)) + self.__stateStackId = np.zeros(n, dtype=int) + self.__stateStackUsed = np.full(n, False, dtype=bool) + + def statePush(self, id_, state): + full = True + for i in range(len(self.__stateStackId)): + if not self.__stateStackUsed[i]: + self.__stateStack[:, i] = state + self.__stateStackId[i] = id_ + self.__stateStackUsed[i] = True + full = False + break + if full: + sys.exit("stack full") + + def stateRead(self, id_): + inds = np.where(self.__stateStackId == id_)[0] + if len(inds > 0): + for i in range(len(inds)): + if self.__stateStackUsed[inds[i]]: + return self.__stateStack[:, inds[i]] + sys.exit("element not found for read in stack") + + def stateDeleteRange(self, from_, to_): + self.__stateStackUsed[ + np.logical_and(self.__stateStackId >= from_, self.__stateStackId <= to_) + ] = False + + def stateReset(self): + self.__stateStackUsed = np.full(len(self.__stateStackUsed), False, dtype=bool) + + def dump(self): + """Dump state stack for debugging.""" + return { + "values": self.__stateStack, + "ids": self.__stateStackId, + "used": self.__stateStackUsed, + } + + +def stopCondition(qm, vm, qp, vp): + """NUT stop condition.""" + tmp = qp - qm + return sum(vp * tmp) < 0.0 or sum(vm * tmp) < 0.0 + + +class WALNUTS(BaseHMC): + """Within-orbit Adaptive Step-length No-U-Turn Sampler. + + WALNUTS (Bou-Rabee et al., 2025) extends NUTS by adapting the integration step size within + each trajectory. This can improve numerical stability in models with varying curvature. + + Parameters + ---------- + vars : list, optional + Variables to sample. If None, all continuous variables in the model. + H0 : float, default=0.2 + Initial big step size / fixed big step size if adaptH=False. + stepSizeRandScale : float, default=0.2 + Step size randomization scale. + delta0 : float, default=0.05 + Initial integrator tolerance. + M : int, default=10 + Number of NUTS iterations (maximum tree depth). + integrator : function, default=fixedLeapFrog + Integrator function to use. + igrAux : integratorAuxPar, optional + Auxiliary parameters for integrators. + adaptH : bool, default=True + Whether to adapt the big step size H. 
+ adaptHtarget : float, default=0.8 + Desired fraction of steps where crudest step size is accepted. + adaptDelta : bool, default=True + Whether to adapt the integrator tolerance delta. + adaptDeltaTarget : float, default=0.6 + Target for delta adaptation. + adaptDeltaQuantile : float, default=0.9 + Quantile for delta adaptation. + recordOrbitStats : bool, default=False + Whether to record orbit statistics. + **kwargs + Additional arguments passed to BaseHMC. + + References + ---------- + .. [1] Bou-Rabee, N., Carpenter, B., Kleppe, T. S., & Liu, S. (2025). + The Within-Orbit Adaptive Leapfrog No-U-Turn Sampler. + arXiv preprint arXiv:2506.18746. + https://arxiv.org/abs/2506.18746v1 + """ + + name = "walnuts" + + default_blocked = True + + stats_dtypes_shapes = { + "L_": (np.int64, []), + "NdoublSampled_": (np.int64, []), + "orbitLen_": (np.float64, []), + "orbitLenSam_": (np.float64, []), + "maxFint": (np.int64, []), + "maxBint": (np.int64, []), + "nevalF": (np.int64, []), + "nevalB": (np.int64, []), + "min_Ifs": (np.int64, []), + "max_Ifs": (np.int64, []), + "min_lwts": (np.float64, []), + "max_lwts": (np.float64, []), + "bothEndsPassive": (bool, []), + "oneEndPassive": (bool, []), + "mean_IfsNeqIbs": (np.float64, []), + "H": (np.float64, []), + "mean_IfsEq0": (np.float64, []), + "orbitEnergyError": (np.float64, []), + "delta": (np.float64, []), + "stopCode": (np.int64, []), + "NdoublComputed_": (np.int64, []), + "min_cs": (np.int64, []), + "max_cs": (np.int64, []), + "indexStat_": (np.float64, []), + # PyMC compatibility fields + "depth": (np.int64, []), + "step_size": (np.float64, []), + "tune": (bool, []), + "mean_tree_accept": (np.float64, []), + "step_size_bar": (np.float64, []), + "tree_size": (np.float64, []), + "diverging": (bool, []), + "energy_error": (np.float64, []), + "energy": (np.float64, []), + "max_energy_error": (np.float64, []), + "model_logp": (np.float64, []), + "process_time_diff": (np.float64, []), + "perf_counter_diff": (np.float64, []), + "perf_counter_start": (np.float64, []), + "largest_eigval": (np.float64, []), + "smallest_eigval": (np.float64, []), + "index_in_trajectory": (np.int64, []), + "reached_max_treedepth": (bool, []), + "warning": (SamplerWarning, None), + "n_steps_total": (np.int64, []), + "avg_steps_per_proposal": (np.float64, []), + } + + def __init__( + self, + vars=None, + H0=0.2, + stepSizeRandScale=0.2, + delta0=0.05, + M=10, + integrator=None, + igrAux=None, + adaptH=True, + adaptHtarget=0.8, + adaptDelta=True, + adaptDeltaTarget=0.6, + adaptDeltaQuantile=0.9, + recordOrbitStats=False, + max_error=None, + max_treedepth=None, + **kwargs, + ): + """Initialize WALNUTS sampler.""" + # WALNUTSpy parameters + self.H = H0 + self.stepSizeRandScale = stepSizeRandScale + self.delta = delta0 + # Allow max_treedepth to override M for PyMC compatibility + self.M = max_treedepth if max_treedepth is not None else M + self.walnuts_integrator = integrator if integrator is not None else fixedLeapFrog + self.igrAux = igrAux if igrAux is not None else integratorAuxPar() + self.adaptH = adaptH + self.adaptHtarget = adaptHtarget + self.adaptDelta = adaptDelta + self.adaptDeltaTarget = adaptDeltaTarget + self.adaptDeltaQuantile = adaptDeltaQuantile + self.recordOrbitStats = recordOrbitStats + + # For PyMC compatibility + self.max_treedepth = self.M + self.early_max_treedepth = min(8, self.M) + self.max_error = max_error if max_error is not None else delta0 + + # Adaptation setup + if self.adaptH: + if self.adaptHtarget < 0.0 or self.adaptHtarget > 1.0: + 
sys.exit("bad adaptHtarget") + self.igrConstQ = P2quantile(1.0 - self.adaptHtarget) + + if self.adaptDelta: + if self.adaptDeltaTarget < 0.0: + sys.exit("bad adaptDeltaTarget") + + # Make tables for all of the sub-uturn checks + self.plans = [] + for i in range(0, self.M): + self.plans.append(subTreePlan(2**i)) + + # Initialize parent class with remaining kwargs + super().__init__(vars, **kwargs) + + # Track iteration count for warmup + self.iterN = 0 + self.warmupIter = 1000 # Default warmup iterations + + # Energy error tracking for adaptation + if self.adaptDelta: + self.energyErrorInfFacs = np.zeros(self.warmupIter) + + def _hamiltonian_step(self, start, p0, step_size): + """Perform a single WALNUTS iteration.""" + # Use PyMC's step size if provided, otherwise use WALNUTS H + H = step_size if step_size is not None else self.H + + # Extract position and gradient from PyMC State + qc = start.q.data + grad0 = start.q_grad + d = len(qc) + + # Create lpFun wrapper for PyMC compatibility + def lpFun(q): + # Create a new State with updated position + new_q = start.q._replace(data=q) + new_state = self.integrator.compute_state(new_q, p0) + return new_state.model_logp, new_state.q_grad + + # Track iteration + self.iterN += 1 + warmup = self.tune and self.iterN <= self.warmupIter + + # Per iteration diagnostics info + nevalF = 0 + nevalB = 0 + Lold_ = 0 + L_ = 0 + orbitLen_ = 0.0 + orbitLenSam_ = 0.0 + NdoublSampled_ = 0 + NdoublComputed_ = 0 + indexStat_ = 0.0 + indexStatOld_ = 0.0 + timeLenF_ = 0.0 + timeLenB_ = 0.0 + + # Integration directions: 1=backward, 0=forward + B = np.floor(self.rng.uniform(low=0.0, high=2.0, size=self.M)).astype(int) + + # How many backward steps could there possibly be + nleft = sum(B * (2 ** np.arange(0, self.M))) + + # Allocate memory for intermediate states + states = stateStore(3 * d, 2 * (self.M + 1) + 1) + + # Memory for quantities stored for all states in orbit + Hs = np.zeros(2**self.M) + Ifs = np.zeros(2**self.M, dtype=int) + Ibs = np.zeros(2**self.M, dtype=int) + cs = np.zeros(2**self.M, dtype=int) + lwts = np.zeros(2**self.M) + + # I0 is the index of the zeroth state + I0 = nleft + Ifs[I0] = 0 + Ibs[I0] = 0 + cs[I0] = 0 + lwts[I0] = 0.0 + + # Endpoints of accepted and proposed trajectory + a = 0 + b = 0 + maxFint = 0 + maxBint = 0 + + # Full momentum refresh + v = self.rng.normal(size=d) + + # Endpoints of current orbit (p=forward, m=backward) + qp = qc + qm = qc + vp = v + vm = v + + # Current proposal + qProp = qc + qPropLast = qc + + # Evaluate at current state + f0, grad0 = lpFun(qc) + + # Gradients at either endpoint + gp = grad0 + gm = grad0 + + # Hamiltonian at initial point + Hs[I0] = -f0 + 0.5 * sum(v**2) + + # Index selection-related quantities + multinomialLscale = Hs[I0] + WoldSum = 1.0 + + lwtSumb = 0.0 + lwtSumf = 0.0 + + # Reject orbit if numerical problems occur + forcedReject = False + + # Stop if both multinomial bias at both ends are zero + bothEndsPassive = False + stopCode = 0 + + # NUT iteration loop + for i in range(self.M): + # Integration direction + xi = (-1) ** B[i] + # Proposed new endpoints + at = a + xi * (2**i) + bt = b + xi * (2**i) + + # More bookkeeping + expandFurther = True + qPropLast = qProp + Lold_ = L_ + indexStatOld_ = indexStat_ + + if i == 0: # Single first integration step required + HLoc = self.rng.uniform( + low=H * (1 - self.stepSizeRandScale), + high=H * (1 + self.stepSizeRandScale), + size=1, + )[0] + orbitLen_ += HLoc + + if xi == 1: # Forward integration + intOut = self.walnuts_integrator( + qp, vp, gp, 
Hs[I0], HLoc, xi, lpFun, self.delta, self.igrAux + ) + qp = intOut.q + vp = intOut.v + gp = intOut.grad + nevalF += intOut.nEvalF + nevalB += intOut.nEvalB + Hs[I0 + 1] = -intOut.lp + 0.5 * sum(vp * vp) + Ifs[I0 + 1] = intOut.If + Ibs[I0 + 1] = intOut.Ib + cs[I0 + 1] = intOut.c + lwts[I0 + 1] = intOut.lwt + if warmup and self.adaptH: + self.igrConstQ.push(np.log(intOut.igrConst)) + maxFint = 1 + timeLenF_ = HLoc + if not np.isfinite(Hs[I0 + 1]): + forcedReject = True + stopCode = 999 + break + + lwtSumf = lwts[I0 + 1] + Wnew = np.exp(-Hs[I0 + 1] + multinomialLscale + lwtSumf) + + qProp = qp + L_ = 1 + indexStat_ = timeLenF_ + + else: # Backward integration + intOut = self.walnuts_integrator( + qm, vm, gm, Hs[I0], HLoc, xi, lpFun, self.delta, self.igrAux + ) + qm = intOut.q + vm = intOut.v + gm = intOut.grad + nevalF += intOut.nEvalF + nevalB += intOut.nEvalB + Hs[I0 - 1] = -intOut.lp + 0.5 * sum(vm * vm) + Ifs[I0 - 1] = intOut.If + Ibs[I0 - 1] = intOut.Ib + cs[I0 - 1] = intOut.c + lwts[I0 - 1] = intOut.lwt + if warmup and self.adaptH: + self.igrConstQ.push(np.log(intOut.igrConst)) + maxBint = -1 + timeLenB_ = HLoc + if not np.isfinite(Hs[I0 - 1]): + forcedReject = True + stopCode = 999 + break + lwtSumb = lwts[I0 - 1] + Wnew = np.exp(-Hs[I0 - 1] + multinomialLscale + lwtSumb) + + qProp = qm + L_ = -1 + indexStat_ = -timeLenB_ + + WoldSum = 1.0 + WnewSum = Wnew + + else: # More than a single integration step, these require sub-u-turn checks + # Work out which sub-u-turn-checks we are doing + plan = 0 + if xi == 1: + plan = b + self.plans[i] + else: + plan = a - self.plans[i] + + WnewSum = 0.0 + + for j in range(len(plan)): # Loop over U-turn-checks + if abs(plan[j, 0] - plan[j, 1]) == 1: # New integration steps needed + HLoc1 = self.rng.uniform( + low=H * (1 - self.stepSizeRandScale), + high=H * (1 + self.stepSizeRandScale), + size=2, + ) + + if xi == -1: # Backward integration + i1 = plan[j, 0] + intOut = self.walnuts_integrator( + qm, + vm, + gm, + Hs[I0 + i1 + 1], + HLoc1[0], + xi, + lpFun, + self.delta, + self.igrAux, + ) + qm = intOut.q + vm = intOut.v + gm = intOut.grad + nevalF += intOut.nEvalF + nevalB += intOut.nEvalB + Hs[I0 + i1] = -intOut.lp + 0.5 * sum(vm * vm) + Ifs[I0 + i1] = intOut.If + Ibs[I0 + i1] = intOut.Ib + cs[I0 + i1] = intOut.c + lwts[I0 + i1] = intOut.lwt + if warmup and self.adaptH: + self.igrConstQ.push(np.log(intOut.igrConst)) + maxBint = i1 + timeLenB_ += HLoc1[0] + if not np.isfinite(Hs[I0 + i1]): + forcedReject = True + stopCode = 999 + break + + lwtSumb += lwts[I0 + i1] + + Wnew = np.exp(-Hs[I0 + i1] + multinomialLscale + lwtSumb) + WnewSum += Wnew + + # Online categorical sampling + if WnewSum > _wtSumThresh and self.rng.uniform() < Wnew / WnewSum: + qProp = qm + L_ = i1 + indexStat_ = -timeLenB_ + + states.statePush(i1, np.concatenate([qm, vm, gm])) + orbitLen_ += HLoc1[0] + + qtmp = qm + vtmp = vm + + # Second integration step + i2 = plan[j, 1] + intOut = self.walnuts_integrator( + qm, + vm, + gm, + Hs[I0 + i2 + 1], + HLoc1[1], + xi, + lpFun, + self.delta, + self.igrAux, + ) + qm = intOut.q + vm = intOut.v + gm = intOut.grad + nevalF += intOut.nEvalF + nevalB += intOut.nEvalB + Hs[I0 + i2] = -intOut.lp + 0.5 * sum(vm * vm) + Ifs[I0 + i2] = intOut.If + Ibs[I0 + i2] = intOut.Ib + cs[I0 + i2] = intOut.c + lwts[I0 + i2] = intOut.lwt + if warmup and self.adaptH: + self.igrConstQ.push(np.log(intOut.igrConst)) + maxBint = i2 + timeLenB_ += HLoc1[1] + if not np.isfinite(Hs[I0 + i2]): + forcedReject = True + break + + # Online categorical sampling + Wnew = 
np.exp(-Hs[I0 + i2] + multinomialLscale + lwtSumb) + WnewSum += Wnew + if WnewSum > _wtSumThresh and self.rng.uniform() < Wnew / WnewSum: + qProp = qm + L_ = i2 + indexStat_ = -timeLenB_ + + # Store state for future u-turn-checking + states.statePush(i2, np.concatenate([qm, vm, gm])) + orbitLen_ += HLoc1[1] + + # Uturn check + if stopCondition(qm, vm, qtmp, vtmp): + expandFurther = False + break + + else: # Forward integration + i1 = plan[j, 0] + intOut = self.walnuts_integrator( + qp, + vp, + gp, + Hs[I0 + i1 - 1], + HLoc1[0], + xi, + lpFun, + self.delta, + self.igrAux, + ) + qp = intOut.q + vp = intOut.v + gp = intOut.grad + nevalF += intOut.nEvalF + nevalB += intOut.nEvalB + Hs[I0 + i1] = -intOut.lp + 0.5 * sum(vp * vp) + Ifs[I0 + i1] = intOut.If + Ibs[I0 + i1] = intOut.Ib + cs[I0 + i1] = intOut.c + lwts[I0 + i1] = intOut.lwt + if warmup and self.adaptH: + self.igrConstQ.push(np.log(intOut.igrConst)) + maxFint = i1 + timeLenF_ += HLoc1[0] + if not np.isfinite(Hs[I0 + i1]): + forcedReject = True + stopCode = 999 + break + + lwtSumf += lwts[I0 + i1] + + # Online categorical sampling + Wnew = np.exp(-Hs[I0 + i1] + multinomialLscale + lwtSumf) + WnewSum += Wnew + if WnewSum > _wtSumThresh and self.rng.uniform() < Wnew / WnewSum: + qProp = qp + L_ = i1 + indexStat_ = timeLenF_ + + # Store state for future u-turn-checking + states.statePush(i1, np.concatenate([qp, vp, gp])) + orbitLen_ += HLoc1[0] + + qtmp = qp + vtmp = vp + + # Second integration step + i2 = plan[j, 1] + intOut = self.walnuts_integrator( + qp, + vp, + gp, + Hs[I0 + i2 - 1], + HLoc1[1], + xi, + lpFun, + self.delta, + self.igrAux, + ) + qp = intOut.q + vp = intOut.v + gp = intOut.grad + nevalF += intOut.nEvalF + nevalB += intOut.nEvalB + Hs[I0 + i2] = -intOut.lp + 0.5 * sum(vp * vp) + Ifs[I0 + i2] = intOut.If + Ibs[I0 + i2] = intOut.Ib + cs[I0 + i2] = intOut.c + lwts[I0 + i2] = intOut.lwt + if warmup and self.adaptH: + self.igrConstQ.push(np.log(intOut.igrConst)) + maxFint = i2 + timeLenF_ += HLoc1[1] + if not np.isfinite(Hs[I0 + i2]): + forcedReject = True + break + + # Multinomial/progressive sampling + lwtSumf += lwts[I0 + i2] + + Wnew = np.exp(-Hs[I0 + i2] + multinomialLscale + lwtSumf) + WnewSum += Wnew + if WnewSum > _wtSumThresh and self.rng.uniform() < Wnew / WnewSum: + qProp = qp + L_ = i2 + indexStat_ = timeLenF_ + + # Store state for future u-turn-checking + states.statePush(i2, np.concatenate([qp, vp, gp])) + orbitLen_ += HLoc1[1] + + if stopCondition(qtmp, vtmp, qp, vp): + expandFurther = False + break + # Done forward integration + else: # No new integration steps needed, only U-turn checks + # Delete states not needed further + im = min(plan[j, :]) + ip = max(plan[j, :]) + states.stateDeleteRange(im + 1, ip - 1) + statep = states.stateRead(ip) + statem = states.stateRead(im) + + if stopCondition( + statem[0:d], statem[d : (2 * d)], statep[0:d], statep[d : (2 * d)] + ): + expandFurther = False + break + # Done loop over j + + if forcedReject: + break + + indexStat_ = indexStat_ / (timeLenF_ + timeLenB_) + + if not expandFurther: + # Proposed subOrbit had a sub-U-turn + qProp = qPropLast + L_ = Lold_ + indexStat_ = indexStatOld_ + NdoublSampled_ = i + NdoublComputed_ = i + 1 + stopCode = 5 + break + else: + # The proposed sub-orbit was found to be free of u-turns + # Now check if proposed state should be from old or new sub-orbit + if not (self.rng.uniform() < WnewSum / WoldSum): + L_ = Lold_ + indexStat_ = indexStatOld_ + qProp = qPropLast + + # Proposed suborbit free of U-turns + # Final U-turn check + 
joinedCrit = stopCondition(qm, vm, qp, vp) + # Stop simulation if multinomial weights at either end are effectively zero + bothEndsPassive = lwtSumb < _logZero + 1.0 and lwtSumf < _logZero + 1.0 + if joinedCrit or bothEndsPassive: + if joinedCrit: + stopCode = 4 + else: + stopCode = -4 + + NdoublSampled_ = i + 1 + NdoublComputed_ = i + 1 + orbitLenSam_ = orbitLen_ + break + + # From now on, it is clear that a new doubling will be attempted + WoldSum += WnewSum + + orbitLenSam_ = orbitLen_ + NdoublSampled_ = i + 1 + NdoublComputed_ = i + 1 + a = min(a, at) + b = max(b, bt) + states.stateReset() + + # Done NUTS loop + + # Store samples and diagnostics info + if maxBint < 0 and maxFint > 0: + usedSteps = np.r_[maxBint:0, 1 : (maxFint + 1)] + elif maxBint < 0: + usedSteps = np.r_[maxBint:0] + else: + usedSteps = np.r_[1 : (maxFint + 1)] + + enUsedSteps = np.r_[0, usedSteps] + orbitEnergyError = np.max(Hs[I0 + enUsedSteps]) - np.min(Hs[I0 + enUsedSteps]) + + # Tuning parameter adaptation + if warmup: + # Tuning of local error threshold delta + if self.adaptDelta: + self.energyErrorInfFacs[self.iterN - 1] = orbitEnergyError / self.delta + if self.adaptDelta and self.iterN > 10: + self.delta = self.adaptDeltaTarget / np.quantile( + self.energyErrorInfFacs[0 : self.iterN], self.adaptDeltaQuantile + ) + + # Tuning of big step size H + if self.adaptH and self.igrConstQ.npush > 10: + self.H = ((self.delta) ** (1.0 / 3.0)) * np.exp(self.igrConstQ.quantile()) + + # Create final state for PyMC + qc = qProp + new_q = start.q._replace(data=qc) + final_state = self.integrator.compute_state(new_q, p0) + + # Prepare statistics + divergence_info = None + if forcedReject: + divergence_info = DivergenceInfo( + "Numerical problems in WALNUTS", + None, + final_state, + None, + ) + + stats = { + # WALNUTSpy statistics + "L_": L_, + "NdoublSampled_": NdoublSampled_, + "orbitLen_": orbitLen_, + "orbitLenSam_": orbitLenSam_, + "maxFint": maxFint, + "maxBint": maxBint, + "nevalF": nevalF, + "nevalB": nevalB, + "min_Ifs": np.min(Ifs[I0 + usedSteps]) if len(usedSteps) > 0 else 0, + "max_Ifs": np.max(Ifs[I0 + usedSteps]) if len(usedSteps) > 0 else 0, + "min_lwts": np.min(lwts[I0 + usedSteps]) if len(usedSteps) > 0 else 0.0, + "max_lwts": np.max(lwts[I0 + usedSteps]) if len(usedSteps) > 0 else 0.0, + "bothEndsPassive": bothEndsPassive, + "oneEndPassive": lwtSumb < _logZero + 1.0 or lwtSumf < _logZero + 1.0, + "mean_IfsNeqIbs": ( + np.mean(Ifs[I0 + usedSteps] != Ibs[I0 + usedSteps]) if len(usedSteps) > 0 else 0.0 + ), + "H": self.H, + "mean_IfsEq0": np.mean(Ifs[I0 + usedSteps] == 0) if len(usedSteps) > 0 else 0.0, + "orbitEnergyError": orbitEnergyError, + "delta": self.delta, + "stopCode": stopCode, + "NdoublComputed_": NdoublComputed_, + "min_cs": np.min(cs[I0 + usedSteps]) if len(usedSteps) > 0 else 0, + "max_cs": np.max(cs[I0 + usedSteps]) if len(usedSteps) > 0 else 0, + "indexStat_": indexStat_, + # PyMC compatibility statistics + "depth": NdoublSampled_, + "step_size": self.H, + "tune": self.tune, + "mean_tree_accept": np.exp(-orbitEnergyError), # Approximation + "step_size_bar": self.H, + "tree_size": nevalF + nevalB, + "diverging": forcedReject, + "energy_error": orbitEnergyError, + "energy": final_state.energy, + "max_energy_error": orbitEnergyError, + "model_logp": final_state.model_logp, + "index_in_trajectory": final_state.index_in_trajectory, + "reached_max_treedepth": NdoublSampled_ >= self.M, + "n_steps_total": nevalF + nevalB, + "avg_steps_per_proposal": (nevalF + nevalB) / max(1, NdoublSampled_), + 
"largest_eigval": np.nan, + "smallest_eigval": np.nan, + } + + return HMCStepData(final_state, 1, divergence_info, stats) + + @staticmethod + def competence(var, has_grad): + """Check if WALNUTS can sample this variable.""" + if var.dtype in continuous_types and has_grad: + return Competence.COMPATIBLE + return Competence.INCOMPATIBLE diff --git a/pymc/step_methods/hmc/walnuts_constants.py b/pymc/step_methods/hmc/walnuts_constants.py new file mode 100644 index 0000000000..748ff6d05e --- /dev/null +++ b/pymc/step_methods/hmc/walnuts_constants.py @@ -0,0 +1,24 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Constants for WALNUTS sampler. + +Based on constants.py from WALNUTSpy by Tore Selland Kleppe. +""" + +import numpy as np + +# Numerical constants +__logZero = -700.0 # Used for communicating that a probability weight is zero +__wtSumThresh = np.exp(__logZero + 1.0) diff --git a/tests/step_methods/hmc/test_walnuts.py b/tests/step_methods/hmc/test_walnuts.py new file mode 100644 index 0000000000..21d66c64ff --- /dev/null +++ b/tests/step_methods/hmc/test_walnuts.py @@ -0,0 +1,186 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings + +import numpy as np +import numpy.testing as npt +import pytest + +import pymc as pm + +from pymc.exceptions import SamplingError +from pymc.step_methods.hmc import WALNUTS +from tests import sampler_fixtures as sf +from tests.helpers import RVsAssignmentStepsTester, StepMethodTester + + +class WalnutsFixture(sf.BaseSampler): + @classmethod + def make_step(cls): + args = {} + if hasattr(cls, "step_args"): + args.update(cls.step_args) + if "scaling" not in args: + _, step = pm.sampling.mcmc.init_nuts(n_init=10000, **args) + # Replace the NUTS step with WALNUTS but keep the same mass matrix + step = pm.WALNUTS(potential=step.potential, target_accept=step.target_accept, **args) + else: + step = pm.WALNUTS(**args) + return step + + def test_target_accept(self): + accept = self.trace[self.burn :]["mean_tree_accept"] + npt.assert_allclose(accept.mean(), self.step.target_accept, 1) + + +# Basic distribution tests - these are relevant for WALNUTS since it's a general HMC sampler +class TestWALNUTSUniform(WalnutsFixture, sf.UniformFixture): + n_samples = 5000 # Reduced for faster testing + tune = 500 + burn = 500 + chains = 2 + min_n_eff = 2000 + rtol = 0.1 + atol = 0.05 + step_args = {"random_seed": 202010} + + +class TestWALNUTSNormal(WalnutsFixture, sf.NormalFixture): + n_samples = 5000 # Reduced for faster testing + tune = 500 + burn = 0 + chains = 2 + min_n_eff = 4000 + rtol = 0.1 + atol = 0.05 + step_args = {"random_seed": 123456} + + +# WALNUTS-specific functionality tests +class TestWalnutsSpecific: + def test_walnuts_specific_stats(self): + """Test that WALNUTS produces its specific statistics.""" + with pm.Model(): + pm.Normal("x", mu=0, sigma=1) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) + trace = pm.sample( + draws=10, tune=5, chains=1, return_inferencedata=False, step=pm.WALNUTS() + ) + + # Check WALNUTS-specific stats are present + walnuts_stats = ["n_steps_total", "avg_steps_per_proposal"] + for stat in walnuts_stats: + assert stat in trace.stat_names, f"WALNUTS-specific stat '{stat}' missing" + stats_values = trace.get_sampler_stats(stat) + assert stats_values.shape == (10,), f"Wrong shape for {stat}" + assert np.all(stats_values >= 0), f"{stat} should be non-negative" + + # Check that n_steps_total makes sense relative to tree_size + n_steps = trace.get_sampler_stats("n_steps_total") + tree_size = trace.get_sampler_stats("tree_size") + # n_steps_total should generally be >= tree_size (adaptive steps might use more steps) + assert np.all(n_steps >= tree_size), "n_steps_total should be >= tree_size" + + def test_walnuts_parameters(self): + """Test WALNUTS-specific parameters.""" + with pm.Model(): + pm.Normal("x", mu=0, sigma=1) + + # Test custom max_error parameter + step = pm.WALNUTS(max_error=0.5, max_treedepth=8) + assert step.max_error == 0.5 + assert step.max_treedepth == 8 + + # Test early_max_treedepth + assert hasattr(step, "early_max_treedepth") + + def test_bad_init_handling(self): + """Test that WALNUTS handles bad initialization properly.""" + with pm.Model(): + pm.HalfNormal("a", sigma=1, initval=-1, default_transform=None) + with pytest.raises(SamplingError) as error: + pm.sample(chains=1, random_seed=1, step=pm.WALNUTS()) + error.match("Bad initial energy") + + def test_competence_method(self): + """Test WALNUTS competence for different variable types.""" + from pymc.step_methods.compound import Competence + + # Mock continuous variable with gradient + class MockVar: + dtype = 
"float64" # continuous_types contains strings, not dtype objects + + var = MockVar() + assert WALNUTS.competence(var, has_grad=True) == Competence.COMPATIBLE + assert WALNUTS.competence(var, has_grad=False) == Competence.INCOMPATIBLE + + def test_required_attributes(self): + """Test that WALNUTS has all required attributes.""" + with pm.Model(): + pm.Normal("x", mu=0, sigma=1) + step = pm.WALNUTS() + + # Check required attributes + assert hasattr(step, "name") + assert step.name == "walnuts" + assert hasattr(step, "default_blocked") + assert step.default_blocked is True + assert hasattr(step, "stats_dtypes_shapes") + + # Check WALNUTS-specific stats are defined + required_stats = ["n_steps_total", "avg_steps_per_proposal"] + for stat in required_stats: + assert stat in step.stats_dtypes_shapes + + +# Test step method functionality +class TestStepWALNUTS(StepMethodTester): + @pytest.mark.parametrize( + "step_fn, draws", + [ + (lambda C, _: WALNUTS(scaling=C, is_cov=True, blocked=False), 1000), + (lambda C, _: WALNUTS(scaling=C, is_cov=True), 1000), + ], + ) + def test_step_continuous(self, step_fn, draws): + self.step_continuous(step_fn, draws) + + +class TestRVsAssignmentWALNUTS(RVsAssignmentStepsTester): + @pytest.mark.parametrize("step, step_kwargs", [(WALNUTS, {})]) + def test_continuous_steps(self, step, step_kwargs): + self.continuous_steps(step, step_kwargs) + + +def test_walnuts_step_legacy_value_grad_function(): + """Test WALNUTS with legacy value grad function (compatibility test).""" + with pm.Model() as m: + x = pm.Normal("x", shape=(2,)) + y = pm.Normal("y", x, shape=(3, 2)) + + legacy_value_grad_fn = m.logp_dlogp_function(ravel_inputs=False, mode="FAST_COMPILE") + legacy_value_grad_fn.set_extra_values({}) + walnuts = WALNUTS(model=m, logp_dlogp_func=legacy_value_grad_fn) + + # Confirm it is a function of multiple variables + logp, dlogp = walnuts._logp_dlogp_func([np.zeros((2,)), np.zeros((3, 2))]) + np.testing.assert_allclose(dlogp, np.zeros(8)) + + # Confirm we can perform a WALNUTS step + ip = m.initial_point() + new_ip, _ = walnuts.step(ip) + assert np.all(new_ip["x"] != ip["x"]) + assert np.all(new_ip["y"] != ip["y"])