Description
I am seeing a very large number of threads spawned when running training with kauldron with multiprocessing enabled.
With the mnist_autoencoder example, I get ~1500 threads, with spikes to ~2000 during eval.
In a real training scenario, I get ~2500 threads, with spikes to ~4000 during eval.
This is a problem for me: with a few training jobs running in parallel, I quickly reach the per-user thread limit.
From the telemetry I see from my colleagues, equivalent torch code with the torch dataloader barely reaches 200 threads per job, without any specific tuning.
Is the high thread count a design choice, or is it an issue with running grain+orbax+jax+kauldron together?
Below are the details to reproduce the problem:
dockerfile:
ARG BASE_IMAGE=nvcr.io/nvidia/driver:570.148.08-ubuntu22.04
FROM ${BASE_IMAGE}
RUN echo "deb http://archive.ubuntu.com/ubuntu jammy main universe restricted multiverse" > /etc/apt/sources.list && \
echo "deb http://archive.ubuntu.com/ubuntu jammy-updates main universe restricted multiverse" >> /etc/apt/sources.list && \
echo "deb http://archive.ubuntu.com/ubuntu jammy-backports main universe restricted multiverse" >> /etc/apt/sources.list && \
echo "deb http://security.ubuntu.com/ubuntu jammy-security main universe restricted multiverse" >> /etc/apt/sources.list
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
cmake \
curl \
git \
wget \
gnupg \
libjpeg-dev \
software-properties-common \
libpng-dev && \
rm -rf /var/lib/apt/lists/*
RUN /usr/sbin/update-ccache-symlinks
RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache
ENV PATH=/opt/conda/bin:$PATH
#
ARG PYTHON_VERSION=3.12
# Automatically set by buildx
ARG TARGETPLATFORM
# translating Docker's TARGETPLATFORM into miniconda arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MINICONDA_ARCH=aarch64 ;; \
*) MINICONDA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/miniconda.sh -O "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh"
RUN chmod +x ~/miniconda.sh && \
bash ~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
/opt/conda/bin/conda clean -ya
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ca-certificates \
libjpeg-dev \
libpng-dev \
&& rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade pip
RUN pip install "jax[cuda12]" kauldron==1.2.2
Running as:
python -m kauldron.main --cfg=/workdir/mnist_autoencoder.py --cfg.workdir=/workdir/model/
With num_workers=0:
ps -eo pid,comm,nlwp --sort=-nlwp
PID COMMAND NLWP
1 python3 157
92 sh 1
98 sh 1
99 bash 1
319 ps 1
With num_workers=16:
ps -eo pid,comm,nlwp --sort=-nlwp
PID COMMAND NLWP
1 python3 212
1737 python3 82
1738 python3 82
1739 python3 82
1741 python3 82
1742 python3 82
1743 python3 82
1745 python3 82
1747 python3 82
1748 python3 82
1749 python3 82
1752 python3 82
1753 python3 82
1754 python3 82
1755 python3 82
1756 python3 82
1757 python3 82
197 python3 9
3107 sh 1
3113 sh 1
3114 bash 1
3128 ps 1
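
For reference, the per-process NLWP values in the num_workers=16 listing add up to 212 + 16 × 82 + 9 + 4 = 1537 threads, which matches the ~1500 figure above. A minimal sketch of how such a listing can be totalled per command is shown below. The helper name `sum_threads` and the shortened sample are only illustrative, and the parser assumes the three-column `pid,comm,nlwp` format used above.

```python
from collections import defaultdict


def sum_threads(ps_output: str) -> dict[str, int]:
    """Sum the NLWP column per command from `ps -eo pid,comm,nlwp` output."""
    totals: dict[str, int] = defaultdict(int)
    for line in ps_output.strip().splitlines():
        fields = line.split()
        # Skip the header row and any line that is not "PID COMMAND NLWP".
        if len(fields) != 3 or not fields[0].isdigit():
            continue
        _pid, comm, nlwp = fields
        totals[comm] += int(nlwp)
    return dict(totals)


# Shortened sample taken from the num_workers=16 listing (not the full output).
sample = """\
PID COMMAND NLWP
1 python3 212
1737 python3 82
1738 python3 82
197 python3 9
3107 sh 1
"""

per_command = sum_threads(sample)
print(per_command)
print("total threads:", sum(per_command.values()))
```

Piping the full `ps` output into the same function gives the per-job total, which can then be compared against the per-user limit.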