Skip to content

Commit 8bbcfa3

Browse files
committed
Merge branch '25.07-devel' into 25.07-devel-add-ngc-release-testing
2 parents cd29eab + 612babc commit 8bbcfa3

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

.github/container/build-jax.sh

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ usage() {
5151
echo " --clean-only Do not build, just cleanup"
5252
echo " --cpu-arch Target CPU architecture, e.g. amd64, arm64, etc."
5353
echo " --debug Build in debug mode"
54+
echo " -r, --release Build for release"
5455
echo " -h, --help Print usage."
5556
echo " --install Install the JAX wheels when build succeeds"
5657
echo " --no-install Do not install the JAX wheels when build succeeds"
@@ -76,10 +77,11 @@ CPU_ARCH="$(dpkg --print-architecture)"
7677
CUDA_COMPUTE_CAPABILITIES="local"
7778
DEBUG=0
7879
INSTALL=1
80+
IS_RELEASE=0
7981
SRC_PATH_JAX="/opt/jax"
8082
SRC_PATH_XLA="/opt/xla"
8183

82-
args=$(getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- "$@")
84+
args=$(getopt -o h,r --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,release,cpu-arch:,debug,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- "$@")
8385
if [[ $? -ne 0 ]]; then
8486
exit 1
8587
fi
@@ -123,6 +125,10 @@ while [ : ]; do
123125
DEBUG=1
124126
shift 1
125127
;;
128+
-r | --release)
129+
IS_RELEASE=1
130+
shift 1
131+
;;
126132
-h | --help)
127133
usage 1
128134
;;
@@ -206,6 +212,7 @@ print_var INSTALL
206212
print_var PYTHON_VERSION
207213
print_var SRC_PATH_JAX
208214
print_var SRC_PATH_XLA
215+
print_var IS_RELEASE
209216

210217
echo "=================================================="
211218

@@ -248,6 +255,12 @@ done
248255
bazel run --config=cuda_libraries_from_stubs --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
249256
popd
250257

258+
259+
if [[ "${IS_RELEASE}" == "1" ]]; then
260+
jaxlib_version=$(pip show jaxlib | grep Version | tr ':' '\n' | tail -1)
261+
sed -i "s| f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',| f'jaxlib>=0.5.0',|" /opt/jax/setup.py
262+
fi
263+
251264
## Install the built packages
252265

253266
if [[ "${INSTALL}" == "1" ]]; then

0 commit comments

Comments
 (0)