
Spike in memory usage and OOM kill upon reaching termination condition #232

@tavin

Description

Hi @Joshuaalbert, as requested in #210, here is a simple test case that gets OOM-killed on my computer. While sampling, it hums along at around 12 GB virtual and 8 GB resident (about 25% memory usage). Once it hits the termination condition of 5 million samples, I can see the memory spike in top, and the OOM killer reports a total-vm of about 36 GB in /var/log/syslog.

import os

# Enable 64-bit precision and expose one XLA host device per CPU core.
os.environ['JAX_ENABLE_X64'] = 'true'
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={os.cpu_count()}'

import jax
from jax import numpy as jnp
from jaxns import Prior, Model, NestedSampler, summary
import tensorflow_probability.substrates.jax.distributions as tfpd


def prior_model():
    # Uniform prior over a 4x4 matrix; the derived quantity is its squared Frobenius norm.
    shape = (4, 4)
    mat = yield Prior(tfpd.Uniform(low=-jnp.ones(shape), high=jnp.ones(shape)))
    frob = yield Prior(jnp.square(mat).sum(), name='loss')
    return frob


# Likelihood strongly penalises the squared Frobenius norm.
model = Model(prior_model=prior_model, log_likelihood=(lambda t: -1e9 * t))
sampler = NestedSampler(model=model, num_live_points=2**19, max_samples=5e6, verbose=True)
reason, state = jax.jit(sampler)(jax.random.key(0))
results = sampler.to_results(termination_reason=reason, state=state)
summary(results)
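
In case it helps reproduce the measurement, here is a minimal sketch (assuming psutil is installed) that prints the resident memory of the process, the same figure top reports as RES, before and after the jitted call. If the process is OOM-killed at termination, the second print never runs.

import psutil

proc = psutil.Process()
print(f'RSS before sampling: {proc.memory_info().rss / 2**30:.1f} GiB')
reason, state = jax.jit(sampler)(jax.random.key(0))
print(f'RSS after termination: {proc.memory_info().rss / 2**30:.1f} GiB')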

I've seen the same behavior with real-world models, where the memory usage reported by top stayed at or under 10% until a normal termination condition was reached ("small remaining evidence").
