@@ -51,6 +51,7 @@ usage() {
5151 echo " --clean-only Do not build, just cleanup"
5252 echo " --cpu-arch Target CPU architecture, e.g. amd64, arm64, etc."
5353 echo " --debug Build in debug mode"
54+ echo " -r, --release Build for release"
5455 echo " -h, --help Print usage."
5556 echo " --install Install the JAX wheels when build succeeds"
5657 echo " --no-install Do not install the JAX wheels when build succeeds"
@@ -76,10 +77,11 @@ CPU_ARCH="$(dpkg --print-architecture)"
7677CUDA_COMPUTE_CAPABILITIES=" local"
7778DEBUG=0
7879INSTALL=1
80+ IS_RELEASE=0
7981SRC_PATH_JAX=" /opt/jax"
8082SRC_PATH_XLA=" /opt/xla"
8183
82- args=$( getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- " $@ " )
84+ args=$( getopt -o h,r --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,release ,cpu-arch:,debug,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- " $@ " )
8385if [[ $? -ne 0 ]]; then
8486 exit 1
8587fi
@@ -123,6 +125,10 @@ while [ : ]; do
123125 DEBUG=1
124126 shift 1
125127 ;;
128+ -r | --release)
129+ IS_RELEASE=1
130+ shift 1
131+ ;;
126132 -h | --help)
127133 usage 1
128134 ;;
@@ -206,6 +212,7 @@ print_var INSTALL
206212print_var PYTHON_VERSION
207213print_var SRC_PATH_JAX
208214print_var SRC_PATH_XLA
215+ print_var IS_RELEASE
209216
210217echo " =================================================="
211218
@@ -248,6 +255,12 @@ done
248255bazel run --config=cuda_libraries_from_stubs --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=" ${PYTHON_VERSION} "
249256popd
250257
258+
259+ if [[ " ${IS_RELEASE} " == " 1" ]]; then
260+ jaxlib_version=$( pip show jaxlib | grep Version | tr ' :' ' \n' | tail -1)
261+ sed -i " s| f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',| f'jaxlib>=0.5.0',|" /opt/jax/setup.py
262+ fi
263+
251264# # Install the built packages
252265
253266if [[ " ${INSTALL} " == " 1" ]]; then
0 commit comments