Skip to content

Conversation

@davecwright3
Copy link
Collaborator

  • Adds checkpointing to NumPyro sampler
  • Allows users to resume sampling from a saved sampler state

also allows users to resume sampling
@davecwright3 davecwright3 requested a review from vallis September 5, 2025 21:20
@davecwright3
Copy link
Collaborator Author

davecwright3 commented Sep 5, 2025

I did try to try to reformat the NumPyro sampler state such that it could be stored as JSON, but there are fundamental incompatibilities between the sampler state namedtuple and JSON such as dictionaries with tuples as keys.

Here is my attempt at the bi-directional conversion if anyone is interested in fixing the issues mentioned above:

from numpyro.infer.hmc import HMCState
from numpy.infer.hmc_util import HMCAdaptState

def _namedtuple_to_dict(input_ntuple):
    """Recursively converts a nested namedtuple into a dictionary that is JSON serializable."""
    if hasattr(input_ntuple, "_asdict"):
        return {key: _namedtuple_to_dict(value) for key, value in input_ntuple._asdict().items()}
    elif isinstance(input_ntuple, dict):
        return {key: _namedtuple_to_dict(value) for key, value in input_ntuple.items()}
    elif isinstance(input_ntuple, list):
        return [_namedtuple_to_dict(item) for item in input_ntuple]
    elif isinstance(input_ntuple, tuple):
        return tuple((_namedtuple_to_dict(item) for item in input_ntuple))
    elif hasattr(input_ntuple, "__array__"):
        try:
            return input_ntuple.tolist()
        except AttributeError:
            return jax.random.key_data(input_ntuple).tolist()
    else:
        return input_ntuple


def _dict_to_namedtuple_preprocess(input_dict):
    if isinstance(input_dict, list) or isinstance(input_dict, int) or isinstance(input_dict, float):
        try:
            array = jnp.array(input_dict)
        # for prng keys
        except OverflowError:
            array = jnp.array(input_dict, jnp.uint32)
        return array
    if isinstance(input_dict, tuple):
        return tuple(jnp.array(item) for item in input_dict)
    if input_dict is None:
        return None

    return {key: _dict_to_namedtuple_preprocess(value) for key, value in input_dict.items()}


def _dict_to_namedtuple(input_dict):
    """Recursively converts a possibly nested dictionary into a namedtuple."""
    processed_dict = _dict_to_namedtuple_preprocess(input_dict)
    # Check if the dictionary's keys match the fields of the NumPyro namedtuples.
    numpyro_hmc_namedtuple_fields = frozenset(HMCState._fields)
    keys = frozenset(processed_dict.keys())
    if keys == numpyro_hmc_namedtuple_fields:
        processed_dict["adapt_state"] = HMCAdaptState(**processed_dict["adapt_state"])
        return HMCState(**processed_dict)
    msg = (
        "The dictionary keys do not match the expected NumPyro keys."
        f"The given keys were {keys}."
        f"The expected keys were {numpyro_hmc_namedtuple_fields}."
        f"The difference is {numpyro_hmc_namedtuple_fields - keys}"
    )
    raise KeyError(msg)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant