Skip to content

Commit 493c78b

Browse files
Update pytest for processed logprobs.
Signed-off-by: Wangshanshan <[email protected]>
1 parent c252b76 commit 493c78b

File tree

2 files changed

+82
-320
lines changed

2 files changed

+82
-320
lines changed

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 82 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -916,25 +916,31 @@ def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs,
916916

917917

918918
@skip_ray
919-
def test_llm_logprobs_modes():
919+
@pytest.mark.parametrize("temperature", [0.0, 0.8])
920+
@pytest.mark.parametrize("top_k", [None, 50])
921+
# temperature: 0.0 is greedy sampling
922+
# top_k: None means all logits
923+
def test_llm_logprobs_modes_basic(temperature, top_k):
920924
"""
921-
Test that processed_logprobs mode works correctly in PyTorch backend.
925+
Test processed_logprobs mode works correctly in PyTorch backend.
922926
Validates that:
923927
- processed_logprobs returns non-positive values (log probabilities)
924-
- all values are valid logprobs
925928
"""
926929
llm = LLM(
927930
llama_model_path,
928-
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
931+
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7),
929932
)
930933

931934
prompts = ["The future of AI is"]
932935
sampling_params = SamplingParams(
933936
max_tokens=5,
934937
logprobs=3,
935-
temperature=0.8,
936-
top_k=50,
938+
temperature=temperature,
939+
top_k=top_k,
937940
logprobs_mode="processed_logprobs",
941+
seed=42,
942+
return_context_logits=True,
943+
return_generation_logits=True,
938944
)
939945

940946
outputs = list(llm.generate(prompts, sampling_params))
@@ -953,20 +959,22 @@ def test_llm_logprobs_modes():
953959
for logprob_obj in token_logprobs.values():
954960
all_values.append(logprob_obj.logprob)
955961

956-
# Validate that processed_logprobs returns non-positive values
962+
# Validate that processed_logprobs returns non-positive values (log probabilities)
957963
for val in all_values:
958964
assert val <= 0.0, f"processed_logprobs should have non-positive values, got {val}"
959965

966+
del llm
967+
960968

961969
@skip_ray
962970
@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
963-
def test_llm_processed_logprobs_with_temperature(temperature: float):
971+
def test_llm_processed_logprobs_with_temperature(temperature):
964972
"""
965973
Test that processed_logprobs correctly applies temperature scaling.
966974
"""
967975
llm = LLM(
968976
llama_model_path,
969-
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
977+
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7),
970978
)
971979

972980
prompt = ["The capital of France is"]
@@ -992,34 +1000,69 @@ def test_llm_processed_logprobs_with_temperature(temperature: float):
9921000
f"processed_logprobs should have non-positive values, got {logprob_obj.logprob}"
9931001
)
9941002

1003+
del llm
1004+
1005+
1006+
@skip_ray
1007+
def test_llm_processed_logprobs_with_greedy_sampling():
1008+
llm = LLM(
1009+
llama_model_path,
1010+
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7),
1011+
)
1012+
1013+
prompt = ["Once upon a time"]
1014+
1015+
sampling_params = SamplingParams(
1016+
max_tokens=10,
1017+
logprobs=3,
1018+
temperature=0.0, # Greedy sampling
1019+
logprobs_mode="processed_logprobs",
1020+
)
1021+
1022+
outputs = llm.generate(prompt, sampling_params=sampling_params)
1023+
1024+
assert len(outputs) == 1
1025+
assert len(outputs[0].outputs[0].logprobs) > 0, (
1026+
"processed_logprobs should return logprobs even with greedy sampling")
1027+
1028+
# Check value ranges - all should be non-positive (log probabilities)
1029+
logprob_vals = [
1030+
logprob_obj.logprob for token_logprobs in outputs[0].outputs[0].logprobs
1031+
for logprob_obj in token_logprobs.values()
1032+
]
1033+
1034+
assert all(
1035+
v <= 0.0 for v in
1036+
logprob_vals), "processed_logprobs should have non-positive values"
1037+
1038+
del llm
1039+
9951040

9961041
@skip_ray
9971042
def test_llm_logprobs_mode_backward_compatibility():
9981043
"""
999-
Test that default behavior uses processed_logprobs.
1044+
Test the default behavior when logprobs_mode is not specified.
10001045
"""
10011046
llm = LLM(
10021047
llama_model_path,
1003-
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
1048+
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7),
10041049
)
10051050

1006-
prompt = ["Hello world"]
1051+
prompt = ["once upon a time"]
10071052

10081053
# Explicit processed_logprobs
10091054
explicit_params = SamplingParams(
1010-
max_tokens=3,
1055+
max_tokens=10,
10111056
logprobs=2,
1012-
temperature=0.8,
10131057
logprobs_mode="processed_logprobs",
10141058
seed=123,
10151059
)
10161060
explicit_outputs = list(llm.generate(prompt, explicit_params))
10171061

10181062
# Default (should be processed_logprobs)
10191063
default_params = SamplingParams(
1020-
max_tokens=3,
1064+
max_tokens=10,
10211065
logprobs=2,
1022-
temperature=0.8,
10231066
seed=123,
10241067
)
10251068
default_outputs = list(llm.generate(prompt, default_params))
@@ -1028,50 +1071,51 @@ def test_llm_logprobs_mode_backward_compatibility():
10281071
explicit_tokens = explicit_outputs[0].outputs[0].token_ids
10291072
default_tokens = default_outputs[0].outputs[0].token_ids
10301073

1031-
assert explicit_tokens == default_tokens, \
1032-
"Default should match explicit processed_logprobs"
1074+
assert explicit_tokens == default_tokens, (
1075+
"Default should match explicit processed_logprobs")
1076+
1077+
del llm
10331078

10341079

10351080
@skip_ray
1036-
def test_llm_processed_logprobs_with_top_k_top_p():
1081+
@pytest.mark.parametrize("top_p", [0.5, 1.0])
1082+
def test_llm_processed_logprobs_with_top_p(top_p):
10371083
"""
10381084
Test that processed_logprobs correctly applies top-p filtering.
10391085
This verifies the fix for the processed_logprobs implementation.
10401086
"""
10411087
llm = LLM(
10421088
llama_model_path,
1043-
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
1089+
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7),
10441090
)
10451091

10461092
prompt = ["The future of technology"]
10471093

10481094
# Test with top_k and top_p to ensure processed_logprobs applies filtering
10491095
params = SamplingParams(
1050-
max_tokens=2,
1051-
logprobs=15, # Request more logprobs than top_k to see filtering
1096+
max_tokens=5,
1097+
logprobs=3,
10521098
temperature=1.0,
1053-
top_k=5, # Only keep top 5 tokens
1054-
top_p=0.9, # Restrict to top 90% probability mass
1099+
top_p=top_p,
10551100
logprobs_mode="processed_logprobs",
1101+
seed=42,
1102+
return_context_logits=True,
1103+
return_generation_logits=True,
10561104
)
10571105

10581106
outputs = list(llm.generate(prompt, params))
10591107
assert len(outputs) == 1
10601108

1061-
# Check that logprobs were returned
1062-
logprobs_list = outputs[0].outputs[0].logprobs
1063-
assert logprobs_list is not None
1064-
assert len(logprobs_list) > 0
1065-
1066-
# Check first token logprobs
1067-
first_token_logprobs = logprobs_list[0]
1068-
logprob_values = [obj.logprob for obj in first_token_logprobs.values()]
1069-
1070-
# Should have some -inf values (masked by top-k/top-p)
1071-
assert any(val == float("-inf") for val in logprob_values), (
1072-
"processed_logprobs should have -inf values for tokens masked by top-k/top-p"
1073-
)
1074-
1109+
# Check that some logprobs are -inf (masked by top-p) across all generated tokens
1110+
# Note: With top_p, not every token position will have -inf values in the top-k logprobs
1111+
# We need to check across all tokens.
1112+
all_logprobs = outputs[0].outputs[0].logprobs
1113+
for token_idx, token_logprobs in enumerate(all_logprobs):
1114+
logprob_values = [obj.logprob for obj in token_logprobs.values()]
1115+
if token_idx == 0:
1116+
print(f"First token processed_logprobs values: {logprob_values}")
1117+
if any(val == float("-inf") for val in logprob_values):
1118+
break
10751119
# All non-inf values should be non-positive (log probabilities)
10761120
non_inf_values = [v for v in logprob_values if v != float("-inf")]
10771121
if non_inf_values:

0 commit comments

Comments
 (0)