@@ -489,12 +489,41 @@ jobs:
489489 # test-equinox.log
490490 # secrets: inherit
491491
# Transformer Engine (TE) JAX unit-test job. Delegates to the reusable
# _test_unit.yaml workflow and runs against the image produced by build-jax.
te-unittests:
  needs: build-jax
  if: inputs.ARCHITECTURE == 'amd64'  # arm64 runners n/a
  uses: ./.github/workflows/_test_unit.yaml
  with:
    TEST_NAME: te
    # Shell snippet that runs the tests. $PWD on the runner is mounted at /log
    # in the container, so only files written under /log are visible to
    # STATISTICS_SCRIPT and ARTIFACTS below.
    EXECUTE: |
      docker run -i --gpus all --shm-size=1g -v $PWD:/log \
        ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
        bash <<"EOF" |& tee test-te.log
        pip install pytest-reportlog pytest-xdist
        # Start MPS daemon
        nvidia-cuda-mps-control -d
        # TE's default is slightly different, without the hyphen
        export TE_PATH=${SRC_PATH_TRANSFORMER_ENGINE}
        # 1 GPU per worker, 6 workers per GPU
        pytest-xdist.sh 1 6 pytest-report-L0-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_unittest/test.sh
        # Publish the report under the name the statistics/artifact steps read.
        # NOTE(review): assumes pytest-xdist.sh writes the report into the
        # current directory - confirm against the script.
        cp pytest-report-L0-unittest.jsonl /log/pytest-report.jsonl
      EOF

    # Derives pass/fail/error counts from the log and the JSONL report and
    # exports them as step outputs.
    STATISTICS_SCRIPT: |
      summary_line=$(tail -n1 test-te.log)
      errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
      passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l)
      failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l)
      total_tests=$((failed_tests + passed_tests))
      echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
      echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
      echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
      echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT

    TIMEOUT_MINUTES: 120
    ARTIFACTS: |
      test-te.log
      pytest-report.jsonl
  # Declared exactly once: a second `secrets` key in the same job is a
  # duplicate mapping key.
  secrets: inherit
499528
500529# test-upstream-t5x:
0 commit comments