Hi @Joshuaalbert, as requested in #210, here is a simple test case that gets OOM-killed on my computer. The run hums along at around 12 GB virtual and 8 GB resident (about 25% memory usage) while sampling. After it hits the termination condition of 5 million samples, I can see the memory spike in top, and the OOM killer reports a total-vm of about 36 GB in /var/log/syslog.
import os

# Enable float64 and expose one XLA host device per CPU core.
os.environ['JAX_ENABLE_X64'] = 'true'
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={os.cpu_count()}'

import jax
from jax import numpy as jnp
from jaxns import Prior, Model, NestedSampler, summary
import tensorflow_probability.substrates.jax.distributions as tfpd


def prior_model():
    # Uniform prior on a 4x4 matrix in [-1, 1]; also track its squared Frobenius norm.
    shape = (4, 4)
    mat = yield Prior(tfpd.Uniform(low=-jnp.ones(shape), high=jnp.ones(shape)))
    frob = yield Prior(jnp.square(mat).sum(), name='loss')
    return frob


model = Model(prior_model=prior_model, log_likelihood=(lambda t: -1e9 * t))
sampler = NestedSampler(model=model, num_live_points=2**19, max_samples=5e6, verbose=True)

reason, state = jax.jit(sampler)(jax.random.key(0))
results = sampler.to_results(termination_reason=reason, state=state)
summary(results)
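
In case it helps to narrow down whether the spike happens during the jitted sampling call or during sampler.to_results, here is a rough instrumentation sketch that could replace the last three lines of the repro above. It assumes psutil is installed (not part of the original repro) and just prints the resident set size at a few points:

import psutil  # extra dependency, only for the diagnostics below

proc = psutil.Process()

def log_rss(tag):
    # Resident set size of this process, in GiB.
    print(f'{tag}: rss = {proc.memory_info().rss / 2**30:.2f} GiB')

log_rss('before sampling')
reason, state = jax.jit(sampler)(jax.random.key(0))
jax.block_until_ready(state)  # make sure the sampling work has actually finished
log_rss('after sampling')

results = sampler.to_results(termination_reason=reason, state=state)
log_rss('after to_results')
summary(results)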
I've seen the same behavior with real-world models, where the memory usage reported by top stayed at or under 10% until a normal termination condition was reached ("small remaining evidence").