Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/RunTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ jobs:
with:
device_type: gpu
device_name: a100-40gb-4
build_mode: pinned
base_image: gcr.io/tpu-prod-env-multipod/maxtext_gpu_jax_pinned:latest
build_mode: stable_stack
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:latest

tpu_unit_tests:
needs: tpu_image
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,6 @@ jobs:
- name: build image with stable stack nightly jax
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_stable_stack_nightly_jax MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack_nightly BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
- name: build image with jax stable stack release candidate image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_stable_stack_candidate_gpu MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_stable_stack_candidate_gpu BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
2 changes: 1 addition & 1 deletion .github/workflows/run_tests_internal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ jobs:
${{ inputs.container_resource_option }} \
gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ inputs.device_type }} \
bash -c "
python3 -m pip install -e . &&
python3 -m pip install -e . --no-dependencies &&
python3 -m pytest --pyargs MaxText.tests -m '${{ inputs.pytest_marker }}' --durations=0
"
2 changes: 1 addition & 1 deletion maxtext_jax_stable_stack.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ RUN if [ "$DEVICE" = "tpu" ] && ([ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pk
python3 -m pip install --no-cache-dir --upgrade jax[tpu]; fi

# Install Maxtext requirements with Jax Stable Stack
RUN apt-get update && apt-get install --yes google-cloud-cli && apt-get install --yes dnsutils
RUN apt-get update && apt-get install --yes && apt-get install --yes dnsutils

# Install requirements file generated with pipreqs for JSS 0.5.2.
# Othewise use general requirements_with_jax_stable_stack.txt
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ mlperf-logging@git+https://github.com/mlperf/logging.git
google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
jsonlines
pathwaysutils@git+https://github.com/AI-Hypercomputer/pathways-utils.git
omegaconf
omegaconf
9 changes: 2 additions & 7 deletions requirements_with_jax_stable_stack.txt
Original file line number Diff line number Diff line change
@@ -1,29 +1,24 @@
# Requirements for Building the MaxText Docker Image
# These requirements are additional to the dependencies present in the JAX SS base image.
absl-py
aqtp==0.8.2
datasets
grain[parquet]>=0.2.6
ml-goodput-measurement==0.0.10
orbax-checkpoint>=0.10.3
pylint
pytest
pyink
pre-commit
protobuf==3.20.3
pytype
pillow>=11.1.0
sentencepiece==0.1.97
tensorflow-text>=2.13.0
tensorflow-datasets
tensorboardx>=2.6.2.2
tiktoken
transformers
mlperf-logging@git+https://github.com/mlperf/logging.git
google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
jsonlines
pathwaysutils@git+https://github.com/AI-Hypercomputer/pathways-utils.git
google-cloud-monitoring
google-api-core
google-api-python-client
omegaconf
jaxtyping
jaxtyping
Loading