Skip to content

Commit 9ff388c

Browse files
authored
Add a torchax container (#1730)
1 parent 721eb20 commit 9ff388c

File tree

4 files changed

+72
-1
lines changed

4 files changed

+72
-1
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# syntax=docker/dockerfile:1-labs
2+
3+
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
4+
ARG URLREF_TORCHAX=https://github.com/google/torchax.git#main
5+
ARG SRC_PATH_TORCHAX=/opt/torchax
6+
7+
###############################################################################
8+
## Download source and add auxiliary scripts
9+
###############################################################################
10+
11+
FROM ${BASE_IMAGE} as mealkit
12+
ARG URLREF_TORCHAX
13+
ARG SRC_PATH_TORCHAX
14+
15+
# Specify installation targets
16+
RUN <<"EOF" bash -ex
17+
git-clone.sh ${URLREF_TORCHAX} ${SRC_PATH_TORCHAX}
18+
echo "-e file://${SRC_PATH_TORCHAX}" >> /opt/pip-tools.d/requirements-torchax.in
19+
echo "torch" >> /opt/pip-tools.d/requirements-torchax.in
20+
echo "torchvision" >> /opt/pip-tools.d/requirements-torchax.in
21+
# tensorflow is only needed to suppress an absl warning when importing torchax
22+
# that tensorflow.io.gfile will not support GCS paths such as gs://...
23+
# comment it out if you want to save ~300 MB in image size
24+
echo "tensorflow" >> /opt/pip-tools.d/requirements-torchax.in
25+
EOF
26+
27+
###############################################################################
28+
## Install accumulated packages from the base image and the previous stage
29+
###############################################################################
30+
31+
FROM mealkit as final
32+
33+
RUN <<"EOF" bash -ex -o pipefail
34+
PIP_INDEX_URL=https://download.pytorch.org/whl/cpu \
35+
PIP_EXTRA_INDEX_URL=https://pypi.org/simple \
36+
pip-finalize.sh
37+
EOF

.github/container/manifest.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,8 @@ tunix:
112112
tracking_ref: main
113113
latest_verified_commit: d799a45d48027e27b6a08aaf7cb15e6a4f495c01
114114
mode: git-clone
115+
torchax:
116+
url: https://github.com/google/torchax.git
117+
tracking_ref: main
118+
latest_verified_commit: f41e3de8526f9d4e8410bfb84660faaaf0b3ba4a
119+
mode: git-clone

.github/workflows/_ci.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,34 @@ jobs:
148148
EXTRA_BUILD_ARGS: |
149149
URLREF_MAXTEXT=${{ fromJson(inputs.SOURCE_URLREFS).MAXTEXT }}
150150
151+
build-torchax:
152+
needs: build-jax
153+
runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"]
154+
outputs:
155+
DOCKER_TAG_MEALKIT: ${{ steps.build-torchax.outputs.DOCKER_TAG_MEALKIT }}
156+
DOCKER_TAG_FINAL: ${{ steps.build-torchax.outputs.DOCKER_TAG_FINAL }}
157+
steps:
158+
- name: Checkout repository
159+
uses: actions/checkout@v4
160+
- name: Build TorchAX container
161+
id: build-torchax
162+
uses: ./.github/actions/build-container
163+
with:
164+
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
165+
ARTIFACT_NAME: artifact-torchax-build
166+
BADGE_FILENAME: badge-torchax-build
167+
BUILD_DATE: ${{ inputs.BUILD_DATE }}
168+
BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
169+
CONTAINER_NAME: torchax
170+
DOCKERFILE: .github/container/Dockerfile.torchax
171+
RUNNER_SIZE: small
172+
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
173+
ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }}
174+
github-token: ${{ secrets.GITHUB_TOKEN }}
175+
bazel-remote-cache-url: ${{ vars.BAZEL_REMOTE_CACHE_URL }}
176+
EXTRA_BUILD_ARGS: |
177+
URLREF_TORCHAX=${{ fromJson(inputs.SOURCE_URLREFS).TORCHAX }}
178+
151179
build-upstream-t5x:
152180
needs: build-jax
153181
runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", "small"]

.github/workflows/ci.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ on:
3838
type: string
3939
description: |
4040
A comma-separated PACKAGE=URL#REF list to override sources used by build.
41-
PACKAGE∊{JAX,XLA,Flax,transformer-engine,airio,axlearn,equinox,T5X,maxtext} (case-insensitive)
41+
PACKAGE∊{JAX,XLA,Flax,transformer-engine,airio,axlearn,equinox,T5X,maxtext,TorchAX} (case-insensitive)
4242
default: ''
4343
required: false
4444
MODE:
@@ -361,6 +361,7 @@ jobs:
361361
upstream-t5x
362362
t5x
363363
axlearn
364+
torchax
364365
)
365366
declare -a STAGES=(
366367
mealkit

0 commit comments

Comments
 (0)