@@ -177,7 +177,10 @@ clean() {
177177 popd
178178}
179179
180- CUDA_MAJOR_VERSION=" ${CUDA_VERSION: 0: 2} "
180+ # Derive CUDA_MAJOR_VERSION from CUDA_VERSION if it isn't set
181+ if [ -z ${CUDA_MAJOR_VERSION+x} ]; then
182+ CUDA_MAJOR_VERSION=" ${CUDA_VERSION: 0: 2} "
183+ fi
181184PYTHON_VERSION=$( python -c ' import sys; print("{}.{}".format(*sys.version_info[:2]))' )
182185if [[ " $CUDA_COMPUTE_CAPABILITIES " == " all" ]]; then
183186 CUDA_COMPUTE_CAPABILITIES=$( supported_compute_capabilities ${CPU_ARCH} )
@@ -263,17 +266,21 @@ time python "${SRC_PATH_JAX}/build/build.py" build \
263266
264267# Make sure that JAX depends on the local jaxlib installation
265268# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
269+ old_hash=($( md5sum build/requirements.in) )
266270for component in jaxlib " jax-cuda${CUDA_MAJOR_VERSION} -pjrt" " jax-cuda${CUDA_MAJOR_VERSION} -plugin" ; do
267- # Note that this also drops the [with-cuda] extra from the committed
268- # version, so nvidia-*-cu12 wheels disappear from the lock file
269271 sed -i " s|^${component} .*$|${component} @ file://${BUILD_PATH_JAXLIB} /${component// -/ _} |" build/requirements.in
270272done
273+ new_hash=($( md5sum build/requirements.in) )
271274# Bazel args to avoid cache invalidation
272275BAZEL_ARGS=(
273276 --config=cuda_libraries_from_stubs
274277 --repo_env=HERMETIC_PYTHON_VERSION=" ${PYTHON_VERSION} "
275278)
276- bazel run " ${BAZEL_ARGS[@]} " --verbose_failures=true //build:requirements.update
279+ # //build:requirements.update can be quite slow; only run it if we actually
280+ # modified requirements.in just above.
281+ if [[ " ${old_hash} " != " ${new_hash} " ]]; then
282+ bazel run " ${BAZEL_ARGS[@]} " --verbose_failures=true //build:requirements.update
283+ fi
277284if (( "${# EXTRA_TARGETS[@]} " > 0 )) ; then
278285 bazel build " ${BAZEL_ARGS[@]} " --verbose_failures=true " ${EXTRA_TARGETS[@]} "
279286 if [[ -n " ${EXTRA_TARGET_DEST} " ]]; then
0 commit comments