File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -59,10 +59,12 @@ RUN build-jax.sh \
5959## Transformer engine: check out source and build wheel
6060RUN <<"EOF" bash -ex -o pipefail
6161pip install ninja && rm -rf ~/.cache/pip
62- # TransformerEngine now needs JAX at build time
6362git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
6463pushd ${SRC_PATH_TRANSFORMER_ENGINE}
6564export NVTE_BUILD_THREADS_PER_JOB=8
65+ export NVTE_FRAMEWORK=jax
66+ # TransformerEngine needs FFI headers from XLA
67+ export XLA_HOME=${SRC_PATH_XLA}
6668python setup.py bdist_wheel && rm -rf build
6769ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
6870EOF
@@ -114,7 +116,6 @@ echo "-e file://${SRC_PATH_FLAX}" >> /opt/pip-tools.d/requirements-flax.in
114116EOF
115117
116118# Copy TransformerEngine wheel from the builder stage
117- ENV NVTE_FRAMEWORK=jax
118119ENV SRC_PATH_TRANSFORMER_ENGINE=${SRC_PATH_TRANSFORMER_ENGINE}
119120COPY --from=builder ${SRC_PATH_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
120121RUN <<"EOF" bash -ex
You can’t perform that action at this time.
0 commit comments