Skip to content

Commit d474281

Browse files
authored
Fix TE build broken by #1409 (#1415)
This will still not build until openxla/xla#25678 lands to partially revert openxla/xla#25402.
1 parent e20714c commit d474281

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

.github/container/Dockerfile.jax

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,12 @@ RUN build-jax.sh \
5959
## Transformer engine: check out source and build wheel
6060
RUN <<"EOF" bash -ex -o pipefail
6161
pip install ninja && rm -rf ~/.cache/pip
62-
# TransformerEngine now needs JAX at build time
6362
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
6463
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
6564
export NVTE_BUILD_THREADS_PER_JOB=8
65+
export NVTE_FRAMEWORK=jax
66+
# TransformerEngine needs FFI headers from XLA
67+
export XLA_HOME=${SRC_PATH_XLA}
6668
python setup.py bdist_wheel && rm -rf build
6769
ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
6870
EOF
@@ -114,7 +116,6 @@ echo "-e file://${SRC_PATH_FLAX}" >> /opt/pip-tools.d/requirements-flax.in
114116
EOF
115117

116118
# Copy TransformerEngine wheel from the builder stage
117-
ENV NVTE_FRAMEWORK=jax
118119
ENV SRC_PATH_TRANSFORMER_ENGINE=${SRC_PATH_TRANSFORMER_ENGINE}
119120
COPY --from=builder ${SRC_PATH_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
120121
RUN <<"EOF" bash -ex

0 commit comments

Comments
 (0)