11apiVersion : batch/v1
22kind : Job
33metadata :
4- name : PLACEHOLDER
4+ name : {{ JOB_NAME }}
55 labels :
66 kueue.x-k8s.io/queue-name : p5-queue
77spec :
@@ -13,54 +13,32 @@ spec:
1313 spec :
1414 restartPolicy : Never
1515 containers :
16- - name : axlearn-fuji-model
17- image : PLACEHOLDER
16+ - name : transformer-engine
17+ image : {{ IMAGE_URI }}
1818 command :
1919 - bash
2020 - -xo
2121 - pipefail
2222 - -c
2323 - |
24- BASEDIR="/opt/axlearn"
25- CONFIG="fuji-3B-v3-flash-single-host"
26- HLO_DUMP=0
27- POSTFIX=""
24+ pip install pytest-reportlog pytest-xdist
25+ # Start MPS daemon
26+ nvidia-cuda-mps-control -d
27+ # TE's default is slightly different, without the hyphen
28+ export TE_PATH=${SRC_PATH_TRANSFORMER_ENGINE}
29+ # 1 GPU per worker, 6 workers per GPU
30+ pytest-xdist.sh 1 6 pytest-report-L0-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_unittest/test.sh
2831
29- AR_THRESHOLD=1073741824
30- AG_THRESHOLD=8589934592
31- RS_THRESHOLD=8589934592
32- BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
33- --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
34- --xla_gpu_all_gather_combine_threshold_bytes=1073741824
35- --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
36- --xla_gpu_enable_pipelined_all_gather=true
37- --xla_gpu_enable_pipelined_reduce_scatter=true
38- --xla_gpu_enable_pipelined_all_reduce=true
39- --xla_gpu_enable_while_loop_double_buffering=true
40- --xla_disable_hlo_passes=rematerialization}
41-
42- export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"
43- export TF_GPU_ALLOCATOR=cuda_malloc_async
44-
45- LOG_DIR=${BASEDIR}/logs
46- TRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir
47- mkdir -p ${TRAINER_DIR}
48-
49-
50- python3 -m axlearn.common.launch_trainer_main \
51- --module=text.gpt.c4_trainer \
52- --config=${CONFIG} \
53- --trainer_dir=${TRAINER_DIR} \
54- --data_dir=gs://axlearn-public/tensorflow_datasets \
55- --jax_backend=gpu
5632 resources :
5733 limits :
5834 nvidia.com/gpu : 8
35+ requests :
36+ nvidia.com/gpu : 1
5937 volumeMounts :
6038 - name : output
6139 mountPath : /opt/output
6240 imagePullSecrets :
63- - name : PLACEHOLDER
41+ - name : {{ IMAGE_PULL_SECRET }}
6442 volumes :
6543 - name : output
6644 emptyDir : {}
0 commit comments