Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions fastdeploy/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,12 @@ class CompletionRequest(BaseModel):
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
max_tokens: Optional[int] = None
# remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field(
default=None,
deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
)
max_completion_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
Expand Down Expand Up @@ -568,6 +573,11 @@ def to_dict_for_infer(self, request_id=None, prompt=None):
if value is not None:
req_dict[key] = value

# max_completion_tokens takes priority over deprecated max_tokens
req_dict["max_tokens"] = (
self.max_completion_tokens if self.max_completion_tokens is not None else self.max_tokens
)

if request_id is not None:
req_dict["request_id"] = request_id
if prompt is not None:
Expand Down Expand Up @@ -745,7 +755,9 @@ def to_dict_for_infer(self, request_id=None):
"""
req_dict = {}

req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
req_dict["max_tokens"] = (
self.max_completion_tokens if self.max_completion_tokens is not None else self.max_tokens
)
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
req_dict["prompt_logprobs"] = self.prompt_logprobs
req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
Expand Down
189 changes: 189 additions & 0 deletions tests/entrypoints/test_completion_max_completion_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for CompletionRequest max_completion_tokens support (issue #2697).

This file tests two things:

1. **Logic correctness** (TestCompletionRequest* classes): Uses a minimal
local CompletionRequest to verify the max_completion_tokens field and
to_dict_for_infer priority logic. This mirrors the production code
exactly and tests the same pydantic + dict logic.

2. **Source code verification** (TestSourceCodeVerification classes):
Reads the actual protocol.py source to confirm the production code
contains the expected field definition and priority logic. This catches
regressions if someone modifies the real class.
"""

import os
from typing import List, Optional, Union

from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Minimal local CompletionRequest — mirrors production protocol.py
# ---------------------------------------------------------------------------


class CompletionRequest(BaseModel):
    """Minimal reproduction of CompletionRequest for testing max_completion_tokens.

    This mirrors the production class in fastdeploy/entrypoints/openai/protocol.py.
    The to_dict_for_infer logic is identical to production.
    """

    model: Optional[str] = "default"
    prompt: Union[List[int], List[List[int]], str, List[str]] = "test"
    # Deprecated alias kept for OpenAI API backward compatibility; pydantic's
    # ``deprecated`` marker surfaces the warning without rejecting the field.
    max_tokens: Optional[int] = Field(
        default=None,
        deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
    )
    max_completion_tokens: Optional[int] = None
    suffix: Optional[dict] = None

    def to_dict_for_infer(self, request_id=None, prompt=None):
        """Convert the request into a plain dict for the inference engine.

        Args:
            request_id: Optional id copied into the result under "request_id".
            prompt: Unused here; kept to mirror the production signature.

        Returns:
            dict containing all non-None fields plus a "max_tokens" entry where
            max_completion_tokens takes priority over the deprecated max_tokens.
        """
        req_dict = {}
        req_dict["metrics"] = {}

        if self.suffix is not None:
            for key, value in self.suffix.items():
                req_dict[key] = value
        for key, value in self.dict().items():
            if value is not None:
                req_dict[key] = value

        # max_completion_tokens takes priority over deprecated max_tokens.
        # "is not None" (rather than truthiness) ensures an explicit 0 is
        # honored instead of falling back to max_tokens.
        req_dict["max_tokens"] = (
            self.max_completion_tokens if self.max_completion_tokens is not None else self.max_tokens
        )

        if request_id is not None:
            req_dict["request_id"] = request_id
        return req_dict


# ---------------------------------------------------------------------------
# Tests: Field behavior
# ---------------------------------------------------------------------------


class TestCompletionRequestMaxCompletionTokens:
    """Field-level behavior of max_completion_tokens on CompletionRequest."""

    def test_field_exists(self):
        """max_completion_tokens field should be accepted."""
        request = CompletionRequest(prompt="hello", max_completion_tokens=100)
        assert request.max_completion_tokens == 100

    def test_field_defaults_to_none(self):
        """max_completion_tokens should default to None."""
        request = CompletionRequest(prompt="hello")
        assert request.max_completion_tokens is None

    def test_max_tokens_field_exists(self):
        """The deprecated max_tokens field must remain accepted."""
        request = CompletionRequest(prompt="hello", max_tokens=50)
        assert request.max_tokens == 50

    def test_both_fields_accepted(self):
        """Supplying max_tokens alongside max_completion_tokens must not error."""
        request = CompletionRequest(
            prompt="hello",
            max_tokens=50,
            max_completion_tokens=100,
        )
        assert request.max_completion_tokens == 100
        assert request.max_tokens == 50


# ---------------------------------------------------------------------------
# Tests: to_dict_for_infer logic
# ---------------------------------------------------------------------------


class TestCompletionRequestToDictForInfer:
    """Behavior of CompletionRequest.to_dict_for_infer for max_tokens resolution."""

    def test_only_max_completion_tokens(self):
        """With only max_completion_tokens set, it should map onto max_tokens."""
        request = CompletionRequest(prompt="hello", max_completion_tokens=200)
        assert request.to_dict_for_infer()["max_tokens"] == 200

    def test_only_max_tokens(self):
        """With only the deprecated max_tokens set, that value should be used."""
        request = CompletionRequest(prompt="hello", max_tokens=100)
        assert request.to_dict_for_infer()["max_tokens"] == 100

    def test_max_completion_tokens_takes_priority(self):
        """When both fields are set, max_completion_tokens wins."""
        request = CompletionRequest(
            prompt="hello",
            max_tokens=100,
            max_completion_tokens=200,
        )
        assert request.to_dict_for_infer()["max_tokens"] == 200

    def test_neither_set(self):
        """With neither field set, the resolved max_tokens is None."""
        request = CompletionRequest(prompt="hello")
        assert request.to_dict_for_infer()["max_tokens"] is None

    def test_max_completion_tokens_zero_is_valid(self):
        """An explicit max_completion_tokens=0 must not fall back to max_tokens."""
        request = CompletionRequest(
            prompt="hello",
            max_tokens=50,
            max_completion_tokens=0,
        )
        assert request.to_dict_for_infer()["max_tokens"] == 0

    def test_includes_request_id(self):
        """A provided request_id should appear in the resulting dict."""
        request = CompletionRequest(prompt="hello", max_completion_tokens=50)
        payload = request.to_dict_for_infer(request_id="req-123")
        assert payload["max_tokens"] == 50
        assert payload["request_id"] == "req-123"


# ---------------------------------------------------------------------------
# Tests: Source code verification (production code)
# ---------------------------------------------------------------------------


class TestSourceCodeVerification:
    """Verify the actual protocol.py source has the expected changes."""

    @staticmethod
    def _read_protocol_source():
        """Load the production protocol.py source file as a string."""
        here = os.path.dirname(__file__)
        protocol_file = os.path.join(
            here, "..", "..", "fastdeploy", "entrypoints", "openai", "protocol.py"
        )
        with open(protocol_file) as handle:
            return handle.read()

    def test_protocol_has_max_completion_tokens(self):
        """CompletionRequest in protocol.py should define max_completion_tokens."""
        source = self._read_protocol_source()

        start = source.find("class CompletionRequest(")
        assert start != -1, "CompletionRequest class not found"

        # Limit the search to the class body: everything up to the next
        # top-level class definition (or end of file).
        end = source.find("\nclass ", start + 1)
        body = source[start:] if end == -1 else source[start:end]

        assert (
            "max_completion_tokens: Optional[int] = None" in body
        ), "max_completion_tokens field not found in CompletionRequest"

    def test_to_dict_for_infer_has_priority_logic(self):
        """to_dict_for_infer should have max_completion_tokens priority logic."""
        source = self._read_protocol_source()

        fragments = ("self.max_completion_tokens", "self.max_tokens", "is not None")
        assert all(
            fragment in source for fragment in fragments
        ), "max_completion_tokens priority logic not found in protocol.py"
Loading