Skip to content

Commit e8b7dde

Browse files
authored
✨ 提供更多配置项 (#9)
1 parent bdc7571 commit e8b7dde

File tree

6 files changed

+309
-10
lines changed

6 files changed

+309
-10
lines changed

README.md

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,52 @@ plugins = ["nonebot_plugin_deepseek"]
9494
> `enable_models` 为 [`CustomModel`](https://github.com/KomoriDev/nonebot-plugin-deepseek/blob/ee9f0b0f0568eedb3eb87423e6c1bf271787ab76/nonebot_plugin_deepseek/config.py#L34) 结构的字典,若无接入本地模型的需要则无需修改
9595
> 若要接入本地模型,请参见:👉 [文档](./tutorial.md)
9696
97-
| 配置项 | 必填 | 默认值 |
98-
| :---------------------------: | :--: | :---------------------------: |
99-
| deepseek__api_key |||
100-
| deepseek__enable_models || [{ "name": "deepseek-chat" }, { "name": "deepseek-reasoner" }] |
101-
| deepseek__prompt |||
102-
| deepseek__md_to_pic || False |
103-
| deepseek__enable_send_thinking || False |
97+
| 配置项 | 必填 | 默认值 | 说明 |
| :---------------------------: | :--: | :---------------------------: | :-----------: |
| deepseek__api_key | 是 | 无 | API Key |
| deepseek__enable_models | 否 | [{ "name": "deepseek-chat" }, { "name": "deepseek-reasoner" }] | 启用的模型 [配置说明](#enable_models-配置说明) |
| deepseek__prompt | 否 | 无 | 模型预设 |
| deepseek__md_to_pic | 否 | False | 是否启用 Markdown 转图片 |
| deepseek__enable_send_thinking | 否 | False | 是否发送思维链 |
104+
105+
### `enable_models` 配置说明
106+
107+
`enable_models` 为 [`CustomModel`](https://github.com/KomoriDev/nonebot-plugin-deepseek/blob/ee9f0b0f0568eedb3eb87423e6c1bf271787ab76/nonebot_plugin_deepseek/config.py#L34) 结构的字典,用于控制不同模型的配置,包含的字段有:
108+
109+
> [!TIP]
110+
> 以下字段均在 [DeepSeek API 文档](https://api-docs.deepseek.com/zh-cn/) 有更详细的介绍
111+
> `deepseek-reasoner` 模型不支持 `logprobs`、`top_logprobs` 参数
112+
113+
- `name`: 模型名称(必填)
114+
- `base_url`: 接口地址(默认为:<https://api.deepseek.com>)(自建模型必填)
115+
- `max_tokens`: 模型生成补全的最大 token 数
116+
- `deepseek-chat`: 介于 1 到 8192 间的整数,默认使用 4096
117+
- `deepseek-reasoner`: 默认为 4K,最大为 8K
118+
- `frequency_penalty`: 用于阻止模型在生成的文本中过于频繁地重复相同的单词或短语
119+
- `presence_penalty`: 用于鼓励模型在生成的文本中包含各种标记
120+
- `stop`: 遇到这些词时停止生成token
121+
- `temperature`: 采样温度,不建议和 `top_p` 一起修改
122+
- `top_p`: 采样温度的替代方案。不建议和 `temperature` 一起修改
123+
- `logprobs`: 是否返回所输出 token 的对数概率
124+
- `top_logprobs`: 指定在每个 token 位置返回最有可能的 tokens
125+
126+
配置示例:
127+
128+
```bash
129+
deepseek__enable_models='
130+
[
131+
{
132+
"name": "deepseek-chat",
133+
"max_tokens": 2048,
134+
"top_p": 0.5
135+
},
136+
{
137+
"name": "deepseek-reasoner",
138+
"max_tokens": 8000
139+
}
140+
]
141+
'
142+
```
104143

105144
## 🎉 使用
106145

nonebot_plugin_deepseek/_types.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
@author: openai
3+
@website: https://github.com/openai/openai-python/blob/main/src/openai/_types.py
4+
"""
5+
6+
from typing_extensions import override
7+
from typing import Union, Literal, TypeVar
8+
9+
_T = TypeVar("_T")


class NotGiven:
    """Sentinel type marking keyword arguments that were omitted entirely.

    Distinguishes "the caller did not pass this argument at all" from an
    explicit ``None``, which may carry different semantics::

        def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...

        get(timeout=1)     # 1s timeout
        get(timeout=None)  # no timeout
        get()              # default behavior, not statically known here
    """

    def __bool__(self) -> Literal[False]:
        # Always falsy, so `if not value:` treats an omitted argument
        # the same way as any other "unset" marker.
        return False

    @override
    def __repr__(self) -> str:
        return "NOT_GIVEN"


# A value of type _T, or the NOT_GIVEN sentinel when omitted.
NotGivenOr = Union[_T, NotGiven]
# Shared module-level singleton; compare by truthiness or isinstance.
NOT_GIVEN = NotGiven()

nonebot_plugin_deepseek/apis/request.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,18 @@ class API:
1717
@classmethod
1818
async def chat(cls, message: list[dict[str, str]], model: str = "deepseek-chat") -> ChatCompletions:
1919
"""普通对话"""
20+
model_config = config.get_model_config(model)
2021
json = {
2122
"messages": [{"content": config.prompt, "role": "user"}] + message if config.prompt else message,
2223
"model": model,
24+
**model_config.to_dict(),
2325
}
24-
logger.debug(f"使用模型 {model}")
26+
logger.debug(f"使用模型 {model},配置:{json}")
2527
# if model == "deepseek-chat":
2628
# json.update({"tools": registry.to_json()})
2729
async with httpx.AsyncClient() as client:
2830
response = await client.post(
29-
f"{config.get_model_url(model)}/chat/completions",
31+
f"{model_config.base_url}/chat/completions",
3032
headers={**cls._headers, "Content-Type": "application/json"},
3133
json=json,
3234
timeout=50,

nonebot_plugin_deepseek/compat.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Literal, overload

from nonebot.compat import PYDANTIC_V2

# Public surface of this compatibility shim.
__all__ = ("model_validator",)


if PYDANTIC_V2:
    # Pydantic v2 ships `model_validator` natively; re-export it unchanged.
    from pydantic import model_validator as model_validator
else:
    # Pydantic v1: emulate the v2 decorator on top of `root_validator`.
    from pydantic import root_validator

    @overload
    def model_validator(*, mode: Literal["before"]): ...

    @overload
    def model_validator(*, mode: Literal["after"]): ...

    def model_validator(*, mode: Literal["before", "after"]):
        # v2 mode="before" maps to v1 pre=True (runs before field
        # validation); mode="after" maps to pre=False.
        # NOTE(review): unlike v2's mode="after", the v1 fallback passes the
        # raw values dict (not a model instance) to the validator function —
        # confirm callers only rely on mode="before" semantics.
        return root_validator(pre=mode == "before", allow_reuse=True)  # type: ignore

nonebot_plugin_deepseek/config.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import json
22
from pathlib import Path
3+
from typing import Any, Union, Optional
34

4-
from pydantic import Field, BaseModel
5+
from nonebot.compat import PYDANTIC_V2
56
import nonebot_plugin_localstore as store
67
from nonebot import logger, get_plugin_config
8+
from pydantic import Field, BaseModel, ConfigDict
9+
10+
from .compat import model_validator
11+
from ._types import NOT_GIVEN, NotGivenOr
712

813

914
class ModelConfig:
@@ -39,6 +44,86 @@ class CustomModel(BaseModel):
3944
"""Model Name"""
4045
base_url: str = "https://api.deepseek.com"
4146
"""Custom base URL for this model (optional)"""
47+
max_tokens: int = Field(default=4090, gt=1, lt=8192)
48+
"""
49+
限制一次请求中模型生成 completion 的最大 token 数
50+
- `deepseek-chat`: Integer between 1 and 8192. Default is 4090.
51+
- `deepseek-reasoner`: Default is 4K, maximum is 8K.
52+
"""
53+
frequency_penalty: Union[int, float] = Field(default=0, ge=-2, le=2)
54+
"""
55+
Discourage the model from repeating the same words or phrases too frequently within the generated text
56+
"""
57+
presence_penalty: Union[int, float] = Field(default=0, ge=-2, le=2)
58+
"""Encourage the model to include a diverse range of tokens in the generated text"""
59+
stop: Optional[Union[str, list[str]]] = Field(default=None)
60+
"""
61+
Stop generating tokens when encounter these words.
62+
Note that the list contains a maximum of 16 string.
63+
"""
64+
temperature: Union[int, float] = Field(default=1, ge=0, le=2)
65+
"""Sampling temperature. It is not recommended to used it with top_p"""
66+
top_p: Union[int, float] = Field(default=1, ge=0, le=1)
67+
"""Alternatives to sampling temperature. It is not recommended to used it with temperature"""
68+
logprobs: NotGivenOr[Union[bool, None]] = Field(default=NOT_GIVEN)
69+
"""Whether to return the log probability of the output token."""
70+
top_logprobs: NotGivenOr[int] = Field(default=NOT_GIVEN, le=20)
71+
"""Specifies that the most likely token be returned at each token position."""
72+
73+
if PYDANTIC_V2:
74+
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
75+
else:
76+
77+
class Config:
78+
extra = "allow"
79+
arbitrary_types_allowed = True
80+
81+
@model_validator(mode="before")
@classmethod
def check_max_token(cls, data: Any) -> Any:
    """Pre-validation hook: fill model-specific defaults and enforce
    cross-field rules before pydantic validates individual fields.

    Runs on the raw input; non-dict input is passed through untouched.
    """
    if isinstance(data, dict):
        name = data.get("name")

        # Fill a per-model default for max_tokens when the user omitted it
        # (reasoner defaults to 4000, every other model to 4090).
        if "max_tokens" not in data:
            if name == "deepseek-reasoner":
                data["max_tokens"] = 4000
            else:
                data["max_tokens"] = 4090

        # NOTE(review): `>= 16` rejects a list of exactly 16 entries even
        # though the message says "up to 16" — confirm the intended bound.
        stop = data.get("stop")
        if isinstance(stop, list) and len(stop) >= 16:
            raise ValueError("字段 `stop` 最多允许设置 16 个字符")

        if name == "deepseek-chat":
            # Warn (not reject) when both sampling knobs are supplied.
            # NOTE(review): truthiness check skips the warning when either
            # value is 0 — confirm that is acceptable.
            temperature = data.get("temperature")
            top_p = data.get("top_p")
            if temperature and top_p:
                logger.warning("不建议同时修改 `temperature` 和 `top_p` 字段")

            # top_logprobs only makes sense when logprobs is enabled; an
            # explicit logprobs=False alongside top_logprobs is an error.
            top_logprobs = data.get("top_logprobs")
            logprobs = data.get("logprobs")
            if top_logprobs and logprobs is False:
                raise ValueError("指定 `top_logprobs` 参数时,`logprobs` 必须为 True")

        elif name == "deepseek-reasoner":
            # Over-limit max_tokens is only warned about (field bound lt=8192
            # still applies separately); it is not clamped here.
            max_tokens = data.get("max_tokens")
            if max_tokens and max_tokens > 8000:
                logger.warning(f"模型 {name} `max_tokens` 字段最大为 8000")

            # Sampling parameters the reasoner model ignores: warn only.
            unsupported_params = ["temperature", "top_p", "presence_penalty", "frequency_penalty"]
            params_present = [param for param in unsupported_params if param in data]
            if params_present:
                logger.warning(f"模型 {name} 不支持设置 {', '.join(params_present)}")

            # logprobs/top_logprobs are hard errors for the reasoner model.
            logprobs = data.get("logprobs")
            top_logprobs = data.get("top_logprobs")
            if logprobs or top_logprobs:
                raise ValueError(f"模型 {name} 不支持设置 logprobs、top_logprobs")

    return data
124+
125+
def to_dict(self) -> dict:
    """Serialize user-supplied generation parameters for the API payload.

    Excludes `name`/`base_url` (routing info, not API parameters), fields
    the user never set, and explicit None values, so only meaningful
    options are forwarded to the chat-completions request body.
    """
    # BaseModel.model_dump only exists on pydantic v2; fall back to the v1
    # `.dict()` API so this works under both, mirroring the PYDANTIC_V2
    # branching used elsewhere in this module.
    kwargs = {"exclude_unset": True, "exclude_none": True, "exclude": {"name", "base_url"}}
    if PYDANTIC_V2:
        return self.model_dump(**kwargs)
    return self.dict(**kwargs)
42127

43128

44129
class ScopedConfig(BaseModel):
@@ -66,6 +151,13 @@ def get_model_url(self, model_name: str) -> str:
66151
return model.base_url
67152
raise ValueError(f"Model {model_name} not enabled")
68153

154+
def get_model_config(self, model_name: str) -> CustomModel:
    """Return the enabled model entry whose name matches `model_name`.

    Raises ValueError when the model is not listed in `enable_models`.
    """
    matched = next(
        (cfg for cfg in self.enable_models if cfg.name == model_name),
        None,
    )
    if matched is None:
        raise ValueError(f"Model {model_name} not enabled")
    return matched
160+
69161

70162
class Config(BaseModel):
71163
deepseek: ScopedConfig = Field(default_factory=ScopedConfig)

tests/test_custom_model.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Tests for `CustomModel` validation rules.

NOTE(review): in the original file every test below was a *nested* function
inside a single `test_custom_model()` wrapper and was never invoked, so
pytest collected one empty test and none of these assertions ever ran.
They are promoted to module level so pytest discovers and executes them.

`CustomModel` is imported inside each test (as the original did) so that any
NoneBot initialisation performed by conftest fixtures happens before the
plugin's config module is first imported.
"""

import pytest
from pydantic import ValidationError


# ---------- basic field validation & defaults ----------


def test_default_values():
    from nonebot_plugin_deepseek.config import CustomModel

    model = CustomModel(name="deepseek-chat")
    assert model.max_tokens == 4090
    assert model.base_url == "https://api.deepseek.com"
    assert model.temperature == 1


def test_reasoner_default_max_tokens():
    from nonebot_plugin_deepseek.config import CustomModel

    # The mode="before" validator fills a reasoner-specific default.
    model = CustomModel(name="deepseek-reasoner")
    assert model.max_tokens == 4000


def test_invalid_max_tokens_range():
    from nonebot_plugin_deepseek.config import CustomModel

    with pytest.raises(ValidationError):
        CustomModel(name="test", max_tokens=0)  # must be > 1
    with pytest.raises(ValidationError):
        CustomModel(name="test", max_tokens=8192)  # must be < 8192


def test_field_ranges():
    from nonebot_plugin_deepseek.config import CustomModel

    with pytest.raises(ValidationError):
        CustomModel(name="test", frequency_penalty=3)  # allowed range [-2, 2]
    with pytest.raises(ValidationError):
        CustomModel(name="test", top_p=2)  # allowed range [0, 1]


# ---------- `stop` field ----------


def test_valid_stop_values():
    from nonebot_plugin_deepseek.config import CustomModel

    # Plain string is accepted as-is.
    model = CustomModel(name="test", stop="stop_word")
    assert model.stop == "stop_word"

    # Lists with fewer than 16 entries are accepted.
    model = CustomModel(name="test", stop=["stop1", "stop2"])
    assert model.stop == ["stop1", "stop2"]


def test_stop_list_too_long():
    from nonebot_plugin_deepseek.config import CustomModel

    with pytest.raises(ValueError, match="最多允许设置 16 个字符"):
        CustomModel(name="test", stop=[f"word{i}" for i in range(17)])


# ---------- model-specific rules ----------


def test_deepseek_chat_temperature_with_top_p():
    from nonebot_plugin_deepseek.config import CustomModel

    # Supplying both only triggers a loguru warning; construction succeeds.
    # TODO(review): asserting on the warning text requires a loguru->logging
    # propagation fixture — pytest's caplog does not capture loguru records
    # by default, so the original caplog-based assertion could never pass.
    model = CustomModel(name="deepseek-chat", temperature=0.5, top_p=0.5)
    assert model.temperature == 0.5
    assert model.top_p == 0.5


def test_deepseek_reasoner_constraints():
    from nonebot_plugin_deepseek.config import CustomModel

    # logprobs is rejected outright for the reasoner model.
    with pytest.raises(ValueError, match="不支持设置 logprobs"):
        CustomModel(name="deepseek-reasoner", logprobs=True)

    # Unsupported sampling params are merely warned about (via loguru),
    # not rejected — the model still constructs.
    model = CustomModel(name="deepseek-reasoner", temperature=0.5, presence_penalty=1)
    assert model.temperature == 0.5


def test_top_logprobs_requires_logprobs():
    from nonebot_plugin_deepseek.config import CustomModel

    # Valid: logprobs enabled alongside top_logprobs.
    CustomModel(name="deepseek-chat", logprobs=True, top_logprobs=5)

    # Explicitly disabling logprobs while requesting top_logprobs is an error.
    with pytest.raises(ValueError, match="logprobs 必须为 True"):
        CustomModel(name="deepseek-chat", logprobs=False, top_logprobs=5)

    # NOTE(review): omitting `logprobs` entirely while passing `top_logprobs`
    # is currently *accepted* — check_max_token only raises when logprobs is
    # literally False. The upstream API requires logprobs=True in that case,
    # so this may be a validation gap worth closing.
    CustomModel(name="deepseek-chat", top_logprobs=5)


def test_logprobs_combinations():
    from nonebot_plugin_deepseek.config import CustomModel

    model = CustomModel(name="deepseek-chat", logprobs=True)
    assert model.logprobs is True
    # Unset optional fields hold the falsy NOT_GIVEN sentinel, not None
    # (the original asserted `is None`, which would have failed had it run).
    assert not model.top_logprobs

    model = CustomModel(name="deepseek-chat", logprobs=True, top_logprobs=10)
    assert model.top_logprobs == 10


def test_reasoner_max_tokens_over_limit():
    from nonebot_plugin_deepseek.config import CustomModel

    # Values above 8000 only produce a loguru warning; they are neither
    # clamped nor rejected (the field bound is lt=8192).
    model = CustomModel(name="deepseek-reasoner", max_tokens=8001)
    assert model.max_tokens == 8001


# ---------- extra fields & boundary values ----------


def test_extra_fields_allowed():
    from nonebot_plugin_deepseek.config import CustomModel

    # CustomModel is configured with extra="allow": unknown fields are kept.
    model = CustomModel(name="test", extra_field="value")  # type: ignore
    assert hasattr(model, "extra_field")


def test_temperature_top_p_boundaries():
    from nonebot_plugin_deepseek.config import CustomModel

    CustomModel(name="test", temperature=0)  # minimum allowed
    CustomModel(name="test", top_p=0)  # minimum allowed
    CustomModel(name="test", temperature=2, top_p=1)  # maximum allowed


def test_presence_penalty_boundary():
    from nonebot_plugin_deepseek.config import CustomModel

    CustomModel(name="test", presence_penalty=-2)  # minimum
    CustomModel(name="test", presence_penalty=2)  # maximum
    with pytest.raises(ValidationError):
        CustomModel(name="test", presence_penalty=-3)

0 commit comments

Comments
 (0)