Skip to content
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
8e99e73
fix axlearn tests, and remove not necessary ones
Steboss May 9, 2025
f84b130
fix axlearn test
Steboss May 9, 2025
a7b5e79
fix summary write up and fix exclude patterns
Steboss May 12, 2025
e785509
fix axlearn tests
Steboss May 12, 2025
f82bda4
add a workflow dispatch for running selective jobs + try to run axlae…
Steboss May 13, 2025
fc34220
fix error
Steboss May 13, 2025
7294422
Merge branch 'main' into sbosisio/fix_axlearn_tests
Steboss May 13, 2025
7f357ff
wrong variable
Steboss May 13, 2025
b6e5121
Merge branch 'sbosisio/fix_axlearn_tests' of github.com:NVIDIA/JAX-To…
Steboss May 13, 2025
b4c5831
Fix output directory
Steboss May 13, 2025
fb1666c
fix the copy from s3
Steboss May 14, 2025
cb3dadc
fix the aws cp command
Steboss May 14, 2025
d105d81
Fake test to run ci
Steboss May 14, 2025
9fc1724
try to revert this action in order to detect failures and successes
Steboss May 14, 2025
833ba61
revert changes to k8s checker
Steboss May 14, 2025
f13deb8
add the xla flag
Steboss May 14, 2025
e6e6e52
Merge branch 'main' into sbosisio/fix_axlearn_tests
Steboss May 15, 2025
589e06a
try with new branch for runnings tests
Steboss May 15, 2025
94579a7
Merge branch 'sbosisio/fix_axlearn_tests' of github.com:NVIDIA/JAX-To…
Steboss May 15, 2025
0febad3
back to the origins
Steboss May 15, 2025
d852bc9
fix error
Steboss May 15, 2025
3e3ed35
exclude the host_array_test and try to run everything on gpus
Steboss May 15, 2025
d871ad2
remove unnecessary tests
Steboss May 15, 2025
b971e9e
exclude unnecessary tests
Steboss May 16, 2025
2489910
test with cuda as platform
Steboss May 16, 2025
33b400f
fix tests
Steboss May 16, 2025
e6e204c
Fix tests for GPUs and devices
Steboss May 16, 2025
440fc2d
try to check what gpus capabilities we see
Steboss May 19, 2025
6bdd345
Merge branch 'main' into sbosisio/fix_axlearn_tests
Steboss May 19, 2025
8a46fce
Update .github/eks-workflow-files/axlearn/axlearn-job.yml
Steboss May 19, 2025
acce96a
run the for 8 devices test
Steboss May 19, 2025
ca165c9
Merge branch 'sbosisio/fix_axlearn_tests' of github.com:NVIDIA/JAX-To…
Steboss May 19, 2025
404c799
fix script for jobs
Steboss May 19, 2025
8b7c570
fix error in test variable
Steboss May 19, 2025
8443391
remove unnecessary cuda
Steboss May 19, 2025
d93ca0e
Merge branch 'main' into sbosisio/fix_axlearn_tests
Steboss May 20, 2025
dff5456
reset CI to standard
Steboss May 20, 2025
b53408f
Merge branch 'sbosisio/fix_axlearn_tests' of github.com:NVIDIA/JAX-To…
Steboss May 20, 2025
c6f8342
test on tests
Steboss May 20, 2025
1388e6b
Merge branch 'main' into sbosisio/fix_axlearn_tests
Steboss May 20, 2025
be90585
fix test and run 8_devices
Steboss May 20, 2025
3171727
Merge branch 'sbosisio/fix_axlearn_tests' of github.com:NVIDIA/JAX-To…
Steboss May 20, 2025
f815fa3
install missing packages
Steboss May 20, 2025
dd22aaf
reset ci axlearn
Steboss May 20, 2025
25d1c77
fix @olupton comments
Steboss May 21, 2025
8a20be6
fix @olupton comments
Steboss May 21, 2025
66e374d
reset ci
Steboss May 21, 2025
7ff0d6f
Fix whitespace
Steboss May 21, 2025
5398b7c
Merge branch 'main' into sbosisio/fix_axlearn_tests
Steboss May 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/submit-delete-k8s-job/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ runs:
echo "Submit K8s job"
kubectl apply -f "${{ inputs.job-config-file }}"
kubectl get event | grep ${{ inputs.job-name }}

# Wait for job to be created
kubectl wait --for=create job/${{ inputs.job-name }} --timeout=$TIMEOUT_JOB_CREATION

Expand Down
2 changes: 2 additions & 0 deletions .github/container/Dockerfile.axlearn
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ seqio==0.0.18
protobuf==3.20.3
pytest>=7.4.3
tensorflow==2.18.1
pytest-xdist
pytest-reportlog
REQUIREMENTS
EOF

Expand Down
188 changes: 130 additions & 58 deletions .github/container/test-axlearn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

set -uo pipefail

# HELPER FUNCTIONS
usage() {
# Function to handle all the inputs
echo "Run tests in axlearn with specified options."
echo ""
echo "Usage: $0 [OPTIONS]"
Expand All @@ -18,12 +20,30 @@ usage() {
exit 1
}

# Default values
DIR='axlearn/axlearn/common'
run_tests() {
# Function to run tests for AXLearn
local env_spec=$1
local marker=$2
local suffix=$3
shift 3
local -a test_files=("$@")

local junit="log_${suffix}.xml"
local log="log_${suffix}.log"

cmd="${env_spec:+${env_spec} }pytest -m \"${marker}\" ${test_files[@]}\
--capture=tee-sys -v \
--junit-xml=${LOG_DIRECTORY}/${junit} | tee ${LOG_DIRECTORY}/${log}"
echo "Running command ${cmd}"
eval "${cmd}"
}

# DEFAULT VALUES
DIR='/opt/axlearn/axlearn/common'
TEST_FILES=()
OUTPUT_DIRECTORY=''

# Parse args manually
# INPUT PARSING
while [[ $# -gt 0 ]]; do
key="$1"
case $key in
Expand Down Expand Up @@ -66,19 +86,15 @@ while [[ $# -gt 0 ]]; do
;;
esac
done


cd "$DIR"
if [ -z "$OUTPUT_DIRECTORY" ]; then
timestamp=$(date +%Y%m%d_%H%M%S)
OUTPUT_DIRECTORY="test_runs/${timestamp}"
OUTPUT_DIRECTORY="output/${timestamp}"
fi
LOG_DIRECTORY="${OUTPUT_DIRECTORY}/logs"

mkdir -p "${LOG_DIRECTORY}"

# Print out config for sanity check
echo "Configuration:"
echo " Directory: $DIR"
if [ "${#TEST_FILES[@]}" -gt 0 ]; then
echo " Test Files:"
for f in "${TEST_FILES[@]}"; do
Expand All @@ -87,16 +103,18 @@ if [ "${#TEST_FILES[@]}" -gt 0 ]; then
else
echo " Test Files Pattern: '*_test.py' (default)"
fi
echo " Output Directory: $OUTPUT_DIRECTORY"

cd "$DIR" || exit 1

echo "Running tests..."

# DEPENDENCIES
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install timm transformers scikit-learn


pip install timm transformers scikit-learn grain evaluate prefixed wandb
echo "Downloading input data..."
mkdir -p /opt/axlearn/axlearn/data/tokenizers/sentencepiece
mkdir -p /opt/axlearn/axlearn/data/tokenizers/bpe
curl https://huggingface.co/t5-base/resolve/main/spiece.model -o /opt/axlearn/axlearn/data/tokenizers/sentencepiece/t5-base
curl https://huggingface.co/FacebookAI/roberta-base/raw/main/merges.txt -o /opt/axlearn/axlearn/data/tokenizers/bpe/roberta-base-merges.txt
curl https://huggingface.co/FacebookAI/roberta-base/raw/main/vocab.json -o /opt/axlearn/axlearn/data/tokenizers/bpe/roberta-base-vocab.json

# RETRIEVE TEST FILES
if [ "${#TEST_FILES[@]}" -eq 0 ]; then
TEST_FILES=("*_test.py")
fi
Expand All @@ -117,53 +135,107 @@ if [ "${#expanded_test_files[@]}" -eq 0 ]; then
exit 1
fi

# in case we have the exclusion list file
EXCLUDE_LIST_FILE="$DIR/exclusion_list.txt"
EXCLUDE_PATTERNS=()

if [ -f "$EXCLUDE_LIST_FILE" ]; then
echo "Reading exclusion list from '$EXCLUDE_LIST_FILE'"
mapfile -t EXCLUDE_PATTERNS < "$EXCLUDE_LIST_FILE"
else
echo "Exclusion list file not found at '$EXCLUDE_LIST_FILE'"
fi

EXCLUDE_PATTERNS=("array_serialization_test.py"
"t5_test.py" # tensorflow bug
"loss_test.py"
"input_t5_test.py"
"layers_test.py" # tensorflow bug
"checkpointer_orbax_test.py"
"checkpointer_orbax_emergency_test.py"
"checkpointer_test.py"
"input_glue_test.py"
"deberta_test.py"
"orbax_checkpointer"
"loss_test.py" # optax bug
"quantizer_test.py"
"test_utils_test.py"
"update_transformation_test.py"
"env_test.py"
"causal_lm_test.py"
"gradient_accumulation_test.py"
"file_system_test.py"
"compiler_options_test.py" # tpu only
"metrics_correlation_test.py" # manual only
"metrics_glue_test.py"
"ssm_test.py" # test on ssm
"summary_test.py" # wandb test
"param_converter_test.py"
"attention_test.py" # assertion errors to fix
# run these as part of the for_8_devices:
"gda_test.py"
"input_base_test.py"
"input_dispatch_test.py"
"trainer_test.py"
"utils_test.py"
)
final_test_files=()

for test_file in "${expanded_test_files[@]}"; do
exclude=false
for pattern in "${EXCLUDE_PATTERNS[@]}"; do
for test_file in "${expanded_test_files[@]}"; do
exclude=false
for pattern in "${EXCLUDE_PATTERNS[@]}"; do
if [[ "$(basename "$test_file")" == "$(basename "$pattern")" ]]; then
exclude=true
break
fi
done
if [ "$exclude" = false ]; then
exclude=true
break
fi
done
if [ "$exclude" = false ]; then
final_test_files+=("$test_file")
fi
fi
done

# Initialize counters for test
failures=0
passed=0
SUMMARY_FILE="${OUTPUT_DIRECTORY}/summary.txt"


for test_file in "${final_test_files[@]}"; do
echo "Running: ${test_file}"
log_file_name=$(echo "${test_file%.py}" | sed 's/\//__/g').log
log_file="${LOG_DIRECTORY}/${log_file_name}"
# run the tests and save them as *.log
pytest "${test_file}" --capture=tee-sys | tee "${log_file}"
exit_code=${PIPESTATUS[0]}
echo $exit_code
# write number of tests passed and failed
if [ $exit_code -eq 0 ]; then
echo "${test_file}: PASSED" >> "${SUMMARY_FILE}"
((passed++))

# RUN TESTS
TEST_8_DEVICES_FILES=("gda_test.py"
"input_base_test.py"
"input_dispatch_test.py"
"trainer_test.py"
"utils_test.py"
)
TEST_8_DEVICES_WITH_PATHS=()
for file in "${TEST_8_DEVICES_FILES[@]}"; do
found_files=$(find . -name "$file" -type f 2>/dev/null)
if [[ -n "$found_files" ]]; then
while IFS= read -r found_file; do
TEST_8_DEVICES_WITH_PATHS+=("$found_file")
done <<< "$found_files"
else
echo "${test_file}: FAILED (Exit code: $exit_code)" >> "${SUMMARY_FILE}"
((failures++))
echo "Warning: Test file $file not found in current directory structure"
fi
echo ""
done

run_tests "" "for_8_devices" "8_dev" "${TEST_8_DEVICES_WITH_PATHS[@]}"
# All the other tests
runs=(
"|not (gs_login or tpu or high_cpu or fp64 or for_8_devices)|base"
"JAX_ENABLE_X64=1|fp64|fp64"
)
for spec in "${runs[@]}"; do
IFS='|' read -r env_spec marker suffix <<< "${spec}"
echo "Running tests with ${env_spec}, ${marker}, ${suffix}"
run_tests "${env_spec}" "${marker}" "${suffix}" "${final_test_files[@]}"
echo "Test run"
done

# SUMMARY STATUS
passed=0
failed=0
skipped=0
for log in ${LOG_DIRECTORY}/log_*.log; do
count_pass=$(grep -Eo '[0-9]+ passed' "${log}" | awk '{print $1}' || true)
count_fail=$(grep -Eo '[0-9]+ failed' "${log}" | awk '{print $1}' || true)
count_skipped=$(grep -Eo '[0-9]+ skipped' "${log}" | awk '{print $1}' || true)
# in case of None
count_pass=${count_pass:-0}
count_fail=${count_fail:-0}
count_skipped=${count_skipped:-0}
# count all the tests
(( passed += count_pass ))
(( failed += count_fail ))
(( skipped += count_skipped ))
done

echo "Total number of passed tests ${passed}"
echo "Total number of failed tests ${failed}"
echo "Total number of skipped tests ${skipped}"
# add those to summary.txt and we're using it for extracting values
echo "PASSED: ${passed} FAILED: ${failed} SKIPPED: ${skipped}" >> ${LOG_DIRECTORY}/summary.txt
8 changes: 2 additions & 6 deletions .github/eks-workflow-files/axlearn/axlearn-job.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,8 @@ spec:

LOG_DIR="/output/${RUN_ID}"
mkdir -p ${LOG_DIR}

# Start MPS daemon
nvidia-cuda-mps-control -d
# Run tests
pytest-xdist.sh 1 6 ${LOG_DIR}/axlearn-unittests.jsonl test-axlearn.sh --directory "." --output ${LOG_DIR} --test-files "/opt/axlearn/axlearn/common/*_test.py" | tee -a ${LOG_DIR}/pytest_stdout.log

# test on JAX, make sure 8 devices are visible
pytest-xdist.sh 8 8 ${LOG_DIR}/axlearn-unittests.jsonl test-axlearn.sh --directory "." --output ${LOG_DIR} --test-files "/opt/axlearn/axlearn/common/*_test.py"
env:
- name: RUN_ID
value: PLACEHOLDER
Expand Down
Loading