Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ RUN <<"EOF" bash -ex
for component in $(ls ${BUILD_PATH_JAXLIB}); do
echo "-e file://${BUILD_PATH_JAXLIB}/${component}" >> /opt/pip-tools.d/requirements-jax.in;
done
echo "-e file://${BUILD_PATH_JAXLIB}/jax[k8s]" >> /opt/pip-tools.d/requirements-jax.in
echo "-e file://${SRC_PATH_JAX}[k8s]" >> /opt/pip-tools.d/requirements-jax.in
EOF

## Flax
Expand Down
8 changes: 4 additions & 4 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ time python "${SRC_PATH_JAX}/build/build.py" build \
--editable \
--use_clang \
--use_new_wheel_build_rule \
--wheels=jax,jaxlib,jax-cuda-plugin,jax-cuda-pjrt \
--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt \
--cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \
--bazel_options=--linkopt=-fuse-ld=lld \
--local_xla_path=$SRC_PATH_XLA \
Expand All @@ -320,10 +320,10 @@ popd

# Make sure that JAX depends on the local jaxlib installation
# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
line="jax @ file://${BUILD_PATH_JAXLIB}/jax"
if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then
local_jax_whl="jax @ file://${SRC_PATH_JAX}"
if ! grep -xF "${local_jax_whl}" "${SRC_PATH_JAX}/build/requirements.in"; then
pushd "${SRC_PATH_JAX}"
echo "${line}" >> build/requirements.in
echo "${local_jax_whl}" >> build/requirements.in
echo "jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_cuda${TF_CUDA_MAJOR_VERSION}_pjrt" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_cuda${TF_CUDA_MAJOR_VERSION}_plugin" >> build/requirements.in
Expand Down
4 changes: 2 additions & 2 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ENABLE_X64=-1

query_tests() {
cd ${SRC_PATH_JAX}
python build/build.py build --use_new_wheel_build_rule --wheels=jax,jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only
python build/build.py build --use_new_wheel_build_rule --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only
bazel query tests/... 2>&1 | grep -F '//tests:'
exit
}
Expand Down Expand Up @@ -196,5 +196,5 @@ pip install matplotlib
## Run tests

cd ${SRC_PATH_JAX}
python build/build.py build --use_new_wheel_build_rule --wheels=jax,jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only
python build/build.py build --use_new_wheel_build_rule --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only
bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}
Loading