Skip to content

Commit c6cce39

Browse files
authored
[TRTLLM-9053][feat] Support accuracy test and install from wheel (#9038)
Signed-off-by: Zero Zeng <[email protected]>
1 parent 84483a2 commit c6cce39

File tree

6 files changed

+162
-19
lines changed

6 files changed

+162
-19
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#!/bin/bash
#
# Run an lm_eval accuracy evaluation against a server described by
# <full_logdir>/server_config.yaml.
#
# Usage:
#   accuracy_eval.sh <full_logdir> <accuracy_model> <accuracy_tasks> \
#                    <model_path> [model_args_extra]
#
# Arguments:
#   $1 full_logdir       Directory expected to contain server_config.yaml.
#   $2 accuracy_model    Value passed to `lm_eval --model`.
#   $3 accuracy_tasks    Value passed to `lm_eval --tasks`.
#   $4 model_path        Used as `model=` inside `lm_eval --model_args`.
#   $5 model_args_extra  Optional extra comma-separated `--model_args` entries.
#
# Exit status: 0 on success, non-zero if the config file or a healthy server
# does not show up in time, if hostname/port cannot be read, or if pip or
# lm_eval fails.
set -euo pipefail

# Print an error message to stderr and abort.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

# Print the value of the first "<key>: <value>" line found in a file.
# The key comparison is case-insensitive and must match the whole first field,
# so commented-out lines and keys that merely contain the name (for example
# "report:") are ignored. Surrounding quotes are stripped from the value.
# Prints an empty line when the key is absent.
read_config_value() {
  local key=$1
  local file=$2
  local value
  value=$(awk -v key="${key}:" 'tolower($1) == key { print $2; exit }' "${file}")
  value=${value#[\"\']}
  value=${value%[\"\']}
  printf '%s\n' "${value}"
}

# Parse arguments
full_logdir=${1:?Usage: $0 <full_logdir> <accuracy_model> <accuracy_tasks> <model_path> [model_args_extra]}
accuracy_model=${2:?accuracy_model (argument 2) is required}
accuracy_tasks=${3:?accuracy_tasks (argument 3) is required}
model_path=${4:?model_path (argument 4) is required}
model_args_extra=${5:-}

echo "Starting accuracy evaluation..."
echo "Log directory: ${full_logdir}"

# Hostname and port are read from server_config.yaml
config_file="${full_logdir}/server_config.yaml"

# Wait for server_config.yaml to be created. The loop condition is the file
# itself, so a file that appears on the very last poll is still accepted.
max_wait=1800
waited=0
until [[ -f "${config_file}" ]]; do
  if (( waited >= max_wait )); then
    die "Error: server_config.yaml not found after ${max_wait} seconds"
  fi
  echo "Waiting for server_config.yaml to be created..."
  sleep 1
  waited=$((waited + 1))
done

# Extract the host and port from the config file
hostname=$(read_config_value hostname "${config_file}")
port=$(read_config_value port "${config_file}")

if [[ -z "${hostname}" || -z "${port}" ]]; then
  die "Error: Failed to extract hostname or port from config file"
fi

echo "Hostname: ${hostname}, Port: ${port}"
base_url="http://${hostname}:${port}/v1/completions"
health_url="http://${hostname}:${port}/health"
echo "Using base_url: ${base_url}"

# Poll the health endpoint every 10 seconds, for up to 1800 seconds.
# -f makes curl exit non-zero on HTTP errors (>= 400); without it any HTTP
# response, including a 5xx while the server is still starting, would end the
# wait. --max-time keeps one hung request from stalling the loop.
timeout=1800
start_time=$(date +%s)
last_report=-30
until curl -sf --max-time 10 -o /dev/null "${health_url}"; do
  elapsed=$(( $(date +%s) - start_time ))
  if (( elapsed >= timeout )); then
    die "Error: Server is not healthy after ${timeout} seconds"
  fi
  # Report roughly every 30 seconds; comparing against the last report time
  # still works when request latency shifts the polling cadence.
  if (( elapsed - last_report >= 30 )); then
    echo "Waiting for server to be healthy... (${elapsed}s elapsed)"
    last_report=${elapsed}
  fi
  sleep 10
done

# Install lm_eval and run evaluation
echo "Installing lm_eval[api] and running evaluation..."
# Quoted so the shell never treats [api] as a glob pattern.
pip install 'lm_eval[api]==0.4.8'

# Append the extra arguments only when present to avoid a trailing comma.
model_args="model=${model_path},base_url=${base_url}"
if [[ -n "${model_args_extra}" ]]; then
  model_args+=",${model_args_extra}"
fi

echo "Running lm_eval with tasks: ${accuracy_tasks}..."
lm_eval --model "${accuracy_model}" \
  --tasks "${accuracy_tasks}" \
  --model_args "${model_args}" \
  --trust_remote_code

echo "Accuracy evaluation completed successfully"

examples/disaggregated/slurm/benchmark/config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,21 @@ environment:
3434
model_path: "<model_path>"
3535
trtllm_repo: "<trtllm_repo>"
3636
build_wheel: false # Don't build the wheel when launching multiple jobs
37+
trtllm_wheel_path: "" # Path to a pre-built TensorRT-LLM wheel. If provided, install from this wheel instead of building from trtllm_repo
3738
dataset_file: "<dataset_file>"
3839
work_dir: "<full_path_to_work_dir>"
3940

4041
# Profiling Configuration
4142
profiling:
4243
nsys_on: false # Set to true to enable profiling
4344

45+
# Accuracy Configuration
46+
accuracy:
47+
enable_accuracy_test: false # Set to true to enable accuracy evaluation
48+
model: "local-completions" # Model type for lm_eval
49+
tasks: "gsm8k" # Evaluation tasks (comma-separated)
50+
model_args_extra: "num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096" # Extra model arguments for lm_eval
51+
4452
worker_config:
4553
gen:
4654
tensor_parallel_size: 8

examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,16 @@ full_logdir=${24}
3636
container_mount=${25}
3737
container_image=${26}
3838
build_wheel=${27}
39+
trtllm_wheel_path=${28}
3940

4041
# Profiling
41-
nsys_on=${28}
42+
nsys_on=${29}
43+
44+
# Accuracy evaluation
45+
enable_accuracy_test=${30}
46+
accuracy_model=${31}
47+
accuracy_tasks=${32}
48+
model_args_extra=${33}
4249

4350
# Print all parsed arguments
4451
echo "Parsed arguments:"
@@ -74,12 +81,18 @@ echo " container_image: ${container_image}"
7481
echo " model_path: ${model_path}"
7582
echo " trtllm_repo: ${trtllm_repo}"
7683
echo " build_wheel: ${build_wheel}"
84+
echo " trtllm_wheel_path: ${trtllm_wheel_path}"
7785
echo " work_dir: ${work_dir}"
7886
echo " nsys_on: ${nsys_on}"
87+
echo
88+
echo "Accuracy Configuration:"
89+
echo " enable_accuracy_test: ${enable_accuracy_test}"
90+
echo " accuracy_model: ${accuracy_model}"
91+
echo " accuracy_tasks: ${accuracy_tasks}"
92+
echo " model_args_extra: ${model_args_extra}"
7993

8094
container_name="disaggr-test"
8195

82-
# Log directory is now passed directly
8396
echo "Log directory: ${full_logdir}"
8497

8598
# Function to cleanup on failure
@@ -102,8 +115,20 @@ if ! srun -l --container-image=${container_image} \
102115
cleanup_on_failure "Failed to start container. Check ${full_logdir}/container_launch.log"
103116
fi
104117

105-
# Build TensorRT-LLM if needed
106-
if [ -d "${trtllm_repo}" ]; then
118+
# Install TensorRT-LLM
119+
if [ -n "${trtllm_wheel_path}" ]; then
120+
# Install from pre-built wheel if path is provided
121+
echo "Installing TensorRT-LLM from wheel: ${trtllm_wheel_path}..."
122+
if ! srun --container-name=${container_name} \
123+
--container-mounts=${container_mount} --no-container-mount-home \
124+
--mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \
125+
bash -c "pip install ${trtllm_wheel_path}" \
126+
&> ${full_logdir}/install.log; then
127+
cleanup_on_failure "TensorRT-LLM wheel installation failed. Check ${full_logdir}/install.log for details"
128+
fi
129+
echo "TensorRT-LLM wheel installation completed successfully"
130+
elif [ -d "${trtllm_repo}" ]; then
131+
# Build and install from repository if no wheel path provided
107132
echo "Installing TensorRT-LLM from ${trtllm_repo}..."
108133
TRT_LLM_GIT_COMMIT=$(git -C ${trtllm_repo} rev-parse --short HEAD 2>/dev/null || echo "unknown")
109134
echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}"
@@ -226,6 +251,22 @@ else
226251
fi
227252
fi
228253
echo "Benchmark completed successfully"
254+
255+
# Run accuracy evaluation if enabled
256+
if [ "${enable_accuracy_test}" = "true" ]; then
257+
echo "Starting accuracy evaluation..."
258+
if ! srun -l --container-name=${container_name} \
259+
--container-mounts=${container_mount} \
260+
--mpi=pmix --overlap -N 1 -n 1 \
261+
bash ${work_dir}/accuracy_eval.sh \
262+
"${full_logdir}" "${accuracy_model}" "${accuracy_tasks}" "${model_path}" \
263+
"${model_args_extra}" \
264+
&> ${full_logdir}/accuracy_eval.log; then
265+
cleanup_on_failure "Accuracy evaluation failed. Check ${full_logdir}/accuracy_eval.log for details"
266+
fi
267+
echo "Accuracy evaluation completed successfully"
268+
fi
269+
229270
echo "Total runtime: $SECONDS seconds"
230271

231272
# try to kill the server and workers

examples/disaggregated/slurm/benchmark/run_benchmark.sh

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,3 @@ job_id=${SLURM_JOB_ID}
127127
if [ -n "${job_id}" ]; then
128128
echo "${SLURM_JOB_NODELIST}" > ${log_path}/job_${job_id}.txt
129129
fi
130-
131-
echo "Benchmark done, gracefully shutting down server and workers..."
132-
kill -9 $(ps aux | grep '[s]tart_server.sh' | awk '{print $2}') >/dev/null 2>&1 || true
133-
kill -9 $(ps aux | grep '[s]tart_worker.sh' | awk '{print $2}') >/dev/null 2>&1 || true
134-
kill -9 $(ps aux | grep '[t]rtllm-serve' | awk '{print $2}') >/dev/null 2>&1 || true
135-
sleep 20 # Give processes some time to clean up
136-
137-
# Check if there are any remaining processes
138-
if pgrep -f "trtllm-serve"; then
139-
echo "Warning: Some processes may still be running"
140-
else
141-
echo "All processes successfully terminated"
142-
fi

examples/disaggregated/slurm/benchmark/submit.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import argparse
44
import glob
55
import os
6+
import shutil
67
import subprocess
78
import sys
89

@@ -50,6 +51,24 @@ def submit_job(config):
5051
hw_config = config['hardware']
5152
env_config = config['environment']
5253

54+
# Set default accuracy configuration for backward compatibility
55+
if 'accuracy' not in config:
56+
config['accuracy'] = {
57+
'enable_accuracy_test':
58+
False,
59+
'model':
60+
'local-completions',
61+
'tasks':
62+
'gsm8k',
63+
'model_args_extra':
64+
'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096'
65+
}
66+
67+
# Set default environment configuration for backward compatibility
68+
env_config.setdefault('trtllm_repo', '')
69+
env_config.setdefault('build_wheel', False)
70+
env_config.setdefault('trtllm_wheel_path', '')
71+
5372
# Get number of servers from config
5473
ctx_num = hw_config['num_ctx_servers']
5574
gen_num = hw_config['num_gen_servers']
@@ -94,7 +113,10 @@ def submit_job(config):
94113

95114
# Create full log directory path
96115
log_dir = os.path.join(log_base, dir_suffix)
97-
os.makedirs(log_dir, exist_ok=True)
116+
# Remove existing directory if it exists
117+
if os.path.exists(log_dir):
118+
shutil.rmtree(log_dir)
119+
os.makedirs(log_dir)
98120

99121
# Setup config file paths and save worker configs
100122
ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
@@ -150,9 +172,16 @@ def submit_job(config):
150172
env_config['container_mount'],
151173
env_config['container_image'],
152174
str(env_config['build_wheel']).lower(),
175+
env_config['trtllm_wheel_path'],
153176

154177
# Profiling
155-
str(config['profiling']['nsys_on']).lower()
178+
str(config['profiling']['nsys_on']).lower(),
179+
180+
# Accuracy evaluation
181+
str(config['accuracy']['enable_accuracy_test']).lower(),
182+
config['accuracy']['model'],
183+
config['accuracy']['tasks'],
184+
config['accuracy']['model_args_extra']
156185
]
157186

158187
# Submit the job

examples/wide_ep/slurm_scripts/config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,21 @@ environment:
3434
model_path: "<model_path>"
3535
trtllm_repo: "<trtllm_repo>"
3636
build_wheel: false # Don't build the wheel when launching multiple jobs
37+
trtllm_wheel_path: "" # Path to pre-built TensorRT-LLM wheel. If provided, install from this wheel instead
3738
dataset_file: "<dataset_file>"
3839
work_dir: "<full_path_to_work_dir>"
3940

4041
# Profiling Configuration
4142
profiling:
4243
nsys_on: false # Set to true to enable profiling
4344

45+
# Accuracy Configuration
46+
accuracy:
47+
enable_accuracy_test: false # Set to true to enable accuracy evaluation
48+
model: "local-completions" # Model type for lm_eval
49+
tasks: "gsm8k" # Evaluation tasks (comma-separated)
50+
model_args_extra: "num_concurrent=512,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=256,max_length=512" # Extra model arguments for lm_eval
51+
4452
# Worker Configuration
4553
worker_config:
4654
gen:

0 commit comments

Comments
 (0)