Skip to content

Commit dfdf80f

Browse files
authored
Merge branch 'main' into aybchan/xpk-private
2 parents 6ca4190 + a77491d commit dfdf80f

File tree

3 files changed

+17
-96
lines changed

3 files changed

+17
-96
lines changed

.github/container/Dockerfile.maxtext

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,19 @@ EOF
1818

1919
RUN <<"EOF" bash -ex -o pipefail
2020
echo "-e file://${SRC_PATH_MAXTEXT}" >> /opt/pip-tools.d/requirements-maxtext.in
21-
echo "-r ${SRC_PATH_MAXTEXT}/src/install_maxtext_extra_deps/extra_deps_from_github.txt" >> /opt/pip-tools.d/requirements-maxtext.in
21+
echo "-r ${SRC_PATH_MAXTEXT}/base_requirements/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
2222
EOF
2323

2424
# add version constraints to avoid eternal dependency resolution
2525
RUN <<"EOF" bash -ex -o pipefail
2626
for pattern in \
27-
"s|tensorflow>=2.19.1|tensorflow==2.18.1|g" \
28-
"s|tensorboard>=2.19.0|tensorboard>=2.18,<2.19|g" \
29-
"s|tensorflow-text>=2.19.0|tensorflow-text==2.18.1|g" \
30-
"/tunix/d" \
27+
"s|^tensorflow$|tensorflow==2.18.1|g" \
28+
"s|^tensorflow-text$|tensorflow-text==2.18.1|g" \
29+
"s|^jax!=.*|jax|g" \
30+
"s|^jaxlib!=.*|jaxlib|g" \
3131
; do
3232
# tensorflow-cpu,tensorboard,tensorflow-text>=2.19.0 is incompatible with tensorflow==2.18.1
33-
sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/pyproject.toml
33+
sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/base_requirements/requirements.txt
3434
done
3535
EOF
3636

.github/container/build-jax.sh

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,10 @@ clean() {
177177
popd
178178
}
179179

180-
CUDA_MAJOR_VERSION="${CUDA_VERSION:0:2}"
180+
# Derive CUDA_MAJOR_VERSION from CUDA_VERSION if it isn't set
181+
if [ -z ${CUDA_MAJOR_VERSION+x} ]; then
182+
CUDA_MAJOR_VERSION="${CUDA_VERSION:0:2}"
183+
fi
181184
PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))')
182185
if [[ "$CUDA_COMPUTE_CAPABILITIES" == "all" ]]; then
183186
CUDA_COMPUTE_CAPABILITIES=$(supported_compute_capabilities ${CPU_ARCH})
@@ -263,17 +266,21 @@ time python "${SRC_PATH_JAX}/build/build.py" build \
263266

264267
# Make sure that JAX depends on the local jaxlib installation
265268
# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
269+
old_hash=($(md5sum build/requirements.in))
266270
for component in jaxlib "jax-cuda${CUDA_MAJOR_VERSION}-pjrt" "jax-cuda${CUDA_MAJOR_VERSION}-plugin"; do
267-
# Note that this also drops the [with-cuda] extra from the committed
268-
# version, so nvidia-*-cu12 wheels disappear from the lock file
269271
sed -i "s|^${component}.*$|${component} @ file://${BUILD_PATH_JAXLIB}/${component//-/_}|" build/requirements.in
270272
done
273+
new_hash=($(md5sum build/requirements.in))
271274
# Bazel args to avoid cache invalidation
272275
BAZEL_ARGS=(
273276
--config=cuda_libraries_from_stubs
274277
--repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
275278
)
276-
bazel run "${BAZEL_ARGS[@]}" --verbose_failures=true //build:requirements.update
279+
# //build:requirements.update can be quite slow; only run it if we actually
280+
# modified requirements.in just above.
281+
if [[ "${old_hash}" != "${new_hash}" ]]; then
282+
bazel run "${BAZEL_ARGS[@]}" --verbose_failures=true //build:requirements.update
283+
fi
277284
if (( "${#EXTRA_TARGETS[@]}" > 0 )); then
278285
bazel build "${BAZEL_ARGS[@]}" --verbose_failures=true "${EXTRA_TARGETS[@]}"
279286
if [[ -n "${EXTRA_TARGET_DEST}" ]]; then

.github/container/cutlass_dsl_jax/tests/test_streams.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

0 commit comments

Comments
 (0)