Skip to content

Commit 589cea1

Browse files
authored
Set NVTE_CUDA_ARCHS from CUDA_ARCH_LIST (#1427)
Extends #1406 to the TransformerEngine build.
1 parent 4abaae5 commit 589cea1

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

.github/container/Dockerfile.jax

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,11 @@ pip install ninja && rm -rf ~/.cache/pip
6262
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
6363
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
6464
export NVTE_BUILD_THREADS_PER_JOB=8
65+
# Print the NVTE_CUDA_ARCHS value for a whitespace-separated list of CUDA
# compute capabilities.
# "1.2 3.4 5.6" -> "12-real;34-real;56", i.e. SASS plus PTX for the last one
# Arguments: $1 - capability list (may be empty or padded with whitespace)
# Outputs:   the ';'-separated architecture list on stdout; nothing when the
#            list holds no entries
nvte_cuda_archs_from_list() {
  local -a caps=()
  local cap joined=""
  # Word-split instead of substituting single spaces, so that leading,
  # trailing or repeated whitespace cannot produce empty list entries or put
  # a "-real" suffix on the last architecture.
  IFS=$' \t\n' read -r -a caps <<< "${1:-}"
  for cap in "${caps[@]}"; do
    # Every entry but the last gets "-real"; dots are dropped ("8.0" -> "80").
    joined+="${joined:+-real;}${cap//./}"
  done
  printf '%s' "${joined}"
}

# Only export when CUDA_ARCH_LIST actually names an architecture; otherwise
# any NVTE_CUDA_ARCHS already present in the environment is left untouched.
_nvte_cuda_archs="$(nvte_cuda_archs_from_list "${CUDA_ARCH_LIST:-}")"
if [[ -n "${_nvte_cuda_archs}" ]]; then
  export NVTE_CUDA_ARCHS="${_nvte_cuda_archs}"
fi
unset _nvte_cuda_archs
6570
export NVTE_FRAMEWORK=jax
6671
# TransformerEngine needs FFI headers from XLA
6772
export XLA_HOME=${SRC_PATH_XLA}

0 commit comments

Comments
 (0)