
Commit d5712be

Adapt axlearn job for transformer engine
1 parent c2f6bcc commit d5712be

1 file changed (+13 −35 lines)

1 file changed

+13
-35
lines changed
Lines changed: 13 additions & 35 deletions
@@ -1,7 +1,7 @@
 apiVersion: batch/v1
 kind: Job
 metadata:
-  name: PLACEHOLDER
+  name: {{ JOB_NAME }}
   labels:
     kueue.x-k8s.io/queue-name: p5-queue
 spec:
@@ -13,54 +13,32 @@ spec:
     spec:
       restartPolicy: Never
       containers:
-        - name: axlearn-fuji-model
-          image: PLACEHOLDER
+        - name: transformer-engine
+          image: {{ IMAGE_URI }}
          command:
            - bash
            - -xo
            - pipefail
            - -c
            - |
-              BASEDIR="/opt/axlearn"
-              CONFIG="fuji-3B-v3-flash-single-host"
-              HLO_DUMP=0
-              POSTFIX=""
+              pip install pytest-reportlog pytest-xdist
+              # Start MPS daemon
+              nvidia-cuda-mps-control -d
+              # TE's default is slightly different, without the hyphen
+              export TE_PATH=${SRC_PATH_TRANSFORMER_ENGINE}
+              # 1 GPU per worker, 6 workers per GPU
+              pytest-xdist.sh 1 6 pytest-report-L0-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_unittest/test.sh
 
-              AR_THRESHOLD=1073741824
-              AG_THRESHOLD=8589934592
-              RS_THRESHOLD=8589934592
-              BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
-                  --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
-                  --xla_gpu_all_gather_combine_threshold_bytes=1073741824
-                  --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
-                  --xla_gpu_enable_pipelined_all_gather=true
-                  --xla_gpu_enable_pipelined_reduce_scatter=true
-                  --xla_gpu_enable_pipelined_all_reduce=true
-                  --xla_gpu_enable_while_loop_double_buffering=true
-                  --xla_disable_hlo_passes=rematerialization}
-
-              export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"
-              export TF_GPU_ALLOCATOR=cuda_malloc_async
-
-              LOG_DIR=${BASEDIR}/logs
-              TRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir
-              mkdir -p ${TRAINER_DIR}
-
-
-              python3 -m axlearn.common.launch_trainer_main \
-                --module=text.gpt.c4_trainer \
-                --config=${CONFIG} \
-                --trainer_dir=${TRAINER_DIR} \
-                --data_dir=gs://axlearn-public/tensorflow_datasets \
-                --jax_backend=gpu
          resources:
            limits:
              nvidia.com/gpu: 8
+           requests:
+             nvidia.com/gpu: 1
          volumeMounts:
          - name: output
            mountPath: /opt/output
      imagePullSecrets:
-      - name: PLACEHOLDER
+      - name: {{ IMAGE_PULL_SECRET }}
      volumes:
      - name: output
        emptyDir: {}
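
For context, below is a minimal sketch of how a templated Job manifest like this could be rendered and submitted. The substitution step, the variable values, and the manifest file name (transformer-engine-job.yml) are assumptions for illustration only; the commit does not show how the surrounding CI workflow fills the {{ JOB_NAME }}, {{ IMAGE_URI }}, and {{ IMAGE_PULL_SECRET }} placeholders.

  # Sketch only: fill the {{ ... }} placeholders with sed and submit the Job.
  # JOB_NAME, IMAGE_URI, IMAGE_PULL_SECRET values and the file name are hypothetical.
  JOB_NAME=te-l0-unittest
  IMAGE_URI=ghcr.io/example/jax-te:latest
  IMAGE_PULL_SECRET=registry-pull-secret

  sed -e "s|{{ JOB_NAME }}|${JOB_NAME}|g" \
      -e "s|{{ IMAGE_URI }}|${IMAGE_URI}|g" \
      -e "s|{{ IMAGE_PULL_SECRET }}|${IMAGE_PULL_SECRET}|g" \
      transformer-engine-job.yml | kubectl apply -f -

  # Block until the Job finishes, then pull the logs containing the pytest report.
  kubectl wait --for=condition=complete --timeout=2h job/${JOB_NAME}
  kubectl logs job/${JOB_NAME}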
