From 09583e61bb14f4d9c243823ccfbbc5b1e9974455 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Mon, 30 Jun 2025 03:21:41 +0800 Subject: [PATCH 1/7] update apis for Chat type --- chattool/chattype.py | 2 +- chattool/core/chattype.py | 409 +++++++++++++++++++++++++++++++++++--- chattool/core/config.py | 20 ++ chattool/core/request.py | 114 ++++++++--- chattool/core/response.py | 170 ++++++++++++++++ chattool/response.py | 141 ------------- tests/conftest.py | 12 +- tests/test_oaiclient.py | 359 +++++++++++++++++++++++++++++++++ 8 files changed, 1022 insertions(+), 205 deletions(-) create mode 100644 chattool/core/response.py delete mode 100644 chattool/response.py create mode 100644 tests/test_oaiclient.py diff --git a/chattool/chattype.py b/chattool/chattype.py index 897d184..d8d28df 100644 --- a/chattool/chattype.py +++ b/chattool/chattype.py @@ -2,7 +2,7 @@ from typing import List, Dict, Union import chattool -from .response import Resp +from .core.response import Resp from .request import chat_completion, valid_models, curl_cmd_of_chat_completion import time, random, json, warnings import aiohttp diff --git a/chattool/core/chattype.py b/chattool/core/chattype.py index e25fa65..b771ccb 100644 --- a/chattool/core/chattype.py +++ b/chattool/core/chattype.py @@ -1,35 +1,384 @@ -from typing import Optional, List, Dict, Union -from chattool.core.config import Config, OpenAIConfig -from chattool.core.request import OpenAIClient, HTTPClient +from typing import List, Dict, Union, Optional, Generator, AsyncGenerator, Any +import json +import os +import time +import random +import asyncio +from loguru import logger +from chattool.core.config import OpenAIConfig +from chattool.core.request import OpenAIClient, StreamResponse +from chattool.core.response import ChatResponse + class Chat(OpenAIClient): - def __init__(self, config: Optional[OpenAIConfig] = None, chat_log: Optional[List[Dict]] = None): - if chat_log is None: - chat_log = [] - self.chat_log = chat_log + """简化的 Chat 类 - 专注于基础对话功能""" + + def __init__( + self, + msg: Union[List[Dict], str, None] = None, + config: Optional[OpenAIConfig] = None, + **kwargs + ): + """ + 初始化 Chat 对象 + + Args: + msg: 初始消息,可以是字符串、消息列表或 None + config: OpenAI 配置对象 + **kwargs: 其他配置参数(会覆盖 config 中的设置) + """ + # 初始化配置 + if config is None: + config = OpenAIConfig() + + # 应用 kwargs 覆盖配置 + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + super().__init__(config) + + # 初始化对话历史 + self._chat_log = self._init_messages(msg) + self._last_response: Optional[ChatResponse] = None - def add(self, role:str, **kwargs): - """Add a message to the chat log""" - assert role in ['user', 'assistant', 'system', 'tool', 'function'],\ - f"role should be one of ['user', 'assistant', 'system', 'tool'], but got {role}" - self.chat_log.append({'role':role, **kwargs}) + def _init_messages(self, msg: Union[List[Dict], str, None]) -> List[Dict]: + """初始化消息列表""" + if msg is None: + return [] + elif isinstance(msg, str): + return [{"role": "user", "content": msg}] + elif isinstance(msg, list): + # 验证消息格式 + for m in msg: + if not isinstance(m, dict) or 'role' not in m: + raise ValueError("消息列表中的每个元素都必须是包含 'role' 键的字典") + return msg.copy() + else: + raise ValueError("msg 必须是字符串、字典列表或 None") + + # === 消息管理 === + def add_message(self, role: str, content: str, **kwargs) -> 'Chat': + """添加消息到对话历史""" + if role not in ['user', 'assistant', 'system']: + raise ValueError(f"role 必须是 'user', 'assistant' 或 'system',收到: {role}") + + message = {"role": role, "content": content, **kwargs} + self._chat_log.append(message) return self - - def user(self, content: Union[List, str]): - """User message""" - return self.add('user', content=content) - - def assistant(self, content:Optional[str]=None): - """Assistant message""" - return self.add('assistant', content=content) - - def system(self, content:str): - """System message""" - return self.add('system', content=content) - - def chat_completion(self, messages:Optional[List[dict]]=None, model = None, temperature = None, top_p = None, max_tokens = None, stream = False, **kwargs): - """Chat completion""" - if messages is None: - messages = self.chat_log - return super().chat_completion(messages, model, temperature, top_p, max_tokens, stream, **kwargs) + + def user(self, content: str) -> 'Chat': + """添加用户消息""" + return self.add_message('user', content) + + def assistant(self, content: str) -> 'Chat': + """添加助手消息""" + return self.add_message('assistant', content) + + def system(self, content: str) -> 'Chat': + """添加系统消息""" + return self.add_message('system', content) + + def clear(self) -> 'Chat': + """清空对话历史""" + self._chat_log = [] + self._last_response = None + return self + + def pop(self, index: int = -1) -> Dict: + """移除并返回指定位置的消息""" + return self._chat_log.pop(index) + + # === 核心对话功能 === + def get_response( + self, + max_retries: int = 3, + retry_delay: float = 1.0, + update_history: bool = True, + **options + ) -> ChatResponse: + """ + 获取对话响应(同步) + + Args: + max_retries: 最大重试次数 + retry_delay: 重试延迟 + update_history: 是否更新对话历史 + **options: 传递给 chat_completion 的其他参数 + """ + # 合并配置 + chat_options = { + "model": self.config.model, + "temperature": self.config.temperature, + **options + } + + last_error = None + for attempt in range(max_retries + 1): + try: + # 调用 OpenAI API + response_data = self.chat_completion( + messages=self._chat_log, + **chat_options + ) + + # 包装响应 + response = ChatResponse(response_data) + + # 验证响应 + if not response.is_valid(): + raise Exception(f"API 返回错误: {response.error_message}") + + # 更新历史记录 + if update_history and response.message: + self._chat_log.append(response.message) + + self._last_response = response + return response + + except Exception as e: + last_error = e + if attempt < max_retries: + self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") + time.sleep(retry_delay * (2 ** attempt)) # 指数退避 + else: + self.logger.error(f"请求在 {max_retries + 1} 次尝试后失败") + + raise last_error + + async def async_get_response( + self, + max_retries: int = 3, + retry_delay: float = 1.0, + update_history: bool = True, + **options + ) -> ChatResponse: + """ + 获取对话响应(异步) + """ + chat_options = { + "model": self.config.model, + "temperature": self.config.temperature, + **options + } + + last_error = None + for attempt in range(max_retries + 1): + try: + response_data = await self.async_chat_completion( + messages=self._chat_log, + **chat_options + ) + + response = ChatResponse(response_data) + + if not response.is_valid(): + raise Exception(f"API 返回错误: {response.error_message}") + + if update_history and response.message: + self._chat_log.append(response.message) + + self._last_response = response + return response + + except Exception as e: + last_error = e + if attempt < max_retries: + self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") + await asyncio.sleep(retry_delay * (2 ** attempt)) + else: + self.logger.error(f"请求在 {max_retries + 1} 次尝试后失败") + + raise last_error + + # === 流式响应 === + def stream_response(self, **options) -> Generator[str, None, None]: + """ + 流式获取响应内容(同步) + 返回生成器,逐个 yield 内容片段 + """ + chat_options = { + "model": self.config.model, + "temperature": self.config.temperature, + **options + } + + full_content = "" + + try: + for stream_resp in self.chat_completion( + messages=self._chat_log, + stream=True, + **chat_options + ): + if stream_resp.has_content: + content = stream_resp.content + full_content += content + yield content + + if stream_resp.is_finished: + break + + # 更新历史记录 + if full_content: + self._chat_log.append({ + "role": "assistant", + "content": full_content + }) + + except Exception as e: + self.logger.error(f"流式响应失败: {e}") + raise + + async def async_stream_response(self, **options) -> AsyncGenerator[str, None]: + """ + 流式获取响应内容(异步) + """ + chat_options = { + "model": self.config.model, + "temperature": self.config.temperature, + **options + } + + full_content = "" + + try: + async for stream_resp in self.async_chat_completion( + messages=self._chat_log, + stream=True, + **chat_options + ): + if stream_resp.has_content: + content = stream_resp.content + full_content += content + yield content + + if stream_resp.is_finished: + break + + # 更新历史记录 + if full_content: + self._chat_log.append({ + "role": "assistant", + "content": full_content + }) + + except Exception as e: + self.logger.error(f"异步流式响应失败: {e}") + raise + + # === 便捷方法 === + def ask(self, question: str, **options) -> str: + """ + 问答便捷方法 + + Args: + question: 问题 + **options: 传递给 get_response 的参数 + + Returns: + 回答内容 + """ + self.user(question) + response = self.get_response(**options) + return response.content + + async def async_ask(self, question: str, **options) -> str: + """异步问答便捷方法""" + self.user(question) + response = await self.async_get_response(**options) + return response.content + + # === 对话历史管理 === + def save(self, path: str, mode: str = 'a', index: int = 0): + """保存对话历史到文件""" + # 确保目录存在 + os.makedirs(os.path.dirname(path) if os.path.dirname(path) else '.', exist_ok=True) + + data = { + "index": index, + "chat_log": self._chat_log, + "config": { + "model": self.config.model, + "api_base": self.config.api_base + } + } + + with open(path, mode, encoding='utf-8') as f: + f.write(json.dumps(data, ensure_ascii=False) + '\n') + + @classmethod + def load(cls, path: str) -> 'Chat': + """从文件加载对话历史""" + with open(path, 'r', encoding='utf-8') as f: + data = json.loads(f.read()) + + chat = cls(msg=data['chat_log']) + + # 如果有配置信息,应用它们 + if 'config' in data: + for key, value in data['config'].items(): + if hasattr(chat.config, key): + setattr(chat.config, key, value) + + return chat + + def copy(self) -> 'Chat': + """复制 Chat 对象""" + return Chat(msg=self._chat_log.copy(), config=self.config) + + # === 显示和调试 === + def print_log(self, sep: str = "\n" + "-" * 50 + "\n"): + """打印对话历史""" + for msg in self._chat_log: + role = msg['role'].upper() + content = msg.get('content', '') + print(f"{sep}{role}{sep}{content}") + + def get_debug_info(self) -> Dict[str, Any]: + """获取调试信息""" + return { + "message_count": len(self._chat_log), + "model": self.config.model, + "api_base": self.config.api_base, + "last_response": self._last_response.get_debug_info() if self._last_response else None + } + + def print_debug_info(self): + """打印调试信息""" + info = self.get_debug_info() + print("=== Chat Debug Info ===") + for key, value in info.items(): + print(f"{key}: {value}") + print("=" * 23) + + # === 属性访问 === + @property + def chat_log(self) -> List[Dict]: + """获取对话历史""" + return self._chat_log.copy() + + @property + def last_message(self) -> Optional[str]: + """获取最后一条消息的内容""" + if self._chat_log: + return self._chat_log[-1].get('content') + return None + + @property + def last_response(self) -> Optional[ChatResponse]: + """获取最后一次响应""" + return self._last_response + + # === 魔术方法 === + def __len__(self) -> int: + return len(self._chat_log) + + def __getitem__(self, index) -> Dict: + return self._chat_log[index] + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.__repr__() \ No newline at end of file diff --git a/chattool/core/config.py b/chattool/core/config.py index a2bb8d3..05ccd8c 100644 --- a/chattool/core/config.py +++ b/chattool/core/config.py @@ -55,7 +55,27 @@ def to_data(self, *kwargs): } # OpenAI 专用配置 +# core/config.py class OpenAIConfig(Config): + def __init__( + self, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + stop: Optional[list] = None, + **kwargs + ): + super().__init__(**kwargs) + # OpenAI 特定参数 + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + self.frequency_penalty = frequency_penalty + self.presence_penalty = presence_penalty + self.stop = stop + def __post__init__(self): if not self.api_key: self.api_key = os.getenv("OPENAI_API_KEY", "") diff --git a/chattool/core/request.py b/chattool/core/request.py index 1a758e3..c3a16ea 100644 --- a/chattool/core/request.py +++ b/chattool/core/request.py @@ -228,13 +228,51 @@ def __str__(self): def __repr__(self): return f"StreamResponse(content='{self.content}', finish_reason='{self.finish_reason}')" +import json +from typing import Dict, List, Optional, Union, Generator, AsyncGenerator, Any +from chattool.core.config import OpenAIConfig +from chattool.core.request import HTTPClient, StreamResponse class OpenAIClient(HTTPClient): - def __init__(self, config:Optional[OpenAIConfig] = None, logger = None, **kwargs): + _config_only_attrs = { + 'api_key', 'api_base', 'headers', 'timeout', + 'max_retries', 'retry_delay' + } + + def __init__(self, config: Optional[OpenAIConfig] = None, logger = None, **kwargs): if config is None: config = OpenAIConfig() super().__init__(config, logger, **kwargs) + + def _build_chat_data(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]: + """构建聊天完成请求的数据""" + data = {"messages": messages} + # 处理所有可能的参数 + all_params = set(kwargs.keys()) | { + k for k in self.config.__dict__.keys() + if not k.startswith('_') # 排除私有属性 + } + + for param_name in all_params: + # 跳过配置专用属性 + if param_name in self._config_only_attrs: + continue + + value = self._get_param_value(param_name, kwargs) + if value is not None: + data[param_name] = value + + return data + + def _get_param_value(self, param_name: str, kwargs: Dict[str, Any]): + """按优先级获取参数值:kwargs > config > None""" + # 优先使用 kwargs 中的值 + if param_name in kwargs: + return kwargs[param_name] + # 其次使用 config 中的值 + return self.config.get(param_name) + def chat_completion( self, messages: List[Dict[str, str]], @@ -261,21 +299,20 @@ def chat_completion( 如果 stream=False: 返回完整的响应字典 如果 stream=True: 返回 Generator,yield StreamResponse 对象 """ - data = { - "model": model or self.config.model, - "messages": messages, + # 将显式参数合并到 kwargs 中 + all_kwargs = { + 'model': model, + 'temperature': temperature, + 'top_p': top_p, + 'max_tokens': max_tokens, + 'stream': stream, **kwargs } - if temperature is not None: - data["temperature"] = temperature - if top_p is not None: - data["top_p"] = top_p - if max_tokens is not None: - data["max_tokens"] = max_tokens + # 使用统一的参数处理逻辑 + data = self._build_chat_data(messages, **all_kwargs) - if stream: - data["stream"] = True + if data.get('stream'): return self._stream_chat_completion(data) response = self.post("/chat/completions", data=data) @@ -307,21 +344,20 @@ async def async_chat_completion( 如果 stream=False: 返回完整的响应字典 如果 stream=True: 返回 AsyncGenerator,async yield StreamResponse 对象 """ - data = { - "model": model or self.config.model, - "messages": messages, + # 将显式参数合并到 kwargs 中 + all_kwargs = { + 'model': model, + 'temperature': temperature, + 'top_p': top_p, + 'max_tokens': max_tokens, + 'stream': stream, **kwargs } - if temperature is not None: - data["temperature"] = temperature - if top_p is not None: - data["top_p"] = top_p - if max_tokens is not None: - data["max_tokens"] = max_tokens + # 使用统一的参数处理逻辑 + data = self._build_chat_data(messages, **all_kwargs) - if stream: - data["stream"] = True + if data.get('stream'): return self._async_stream_chat_completion(data) response = await self.async_post("/chat/completions", data=data) @@ -432,35 +468,49 @@ async def _async_stream_chat_completion(self, data: Dict[str, Any]) -> AsyncGene def embeddings( self, input_text: Union[str, List[str]], - model: str = "text-embedding-ada-002", + model: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """OpenAI Embeddings API""" - data = { - "model": model, - "input": input_text, + # 使用统一的参数处理逻辑 + all_kwargs = { + 'model': model or self.config.get('model', 'text-embedding-ada-002'), + 'input': input_text, **kwargs } + # 构建数据,但排除 input 参数因为它已经单独处理 + data = {} + for key, value in all_kwargs.items(): + if value is not None: + data[key] = value + response = self.post("/embeddings", data=data) return response.json() async def async_embeddings( self, input_text: Union[str, List[str]], - model: str = "text-embedding-ada-002", + model: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """异步 OpenAI Embeddings API""" - data = { - "model": model, - "input": input_text, + # 使用统一的参数处理逻辑 + all_kwargs = { + 'model': model or self.config.get('model', 'text-embedding-ada-002'), + 'input': input_text, **kwargs } + # 构建数据 + data = {} + for key, value in all_kwargs.items(): + if value is not None: + data[key] = value + response = await self.async_post("/embeddings", data=data) return response.json() def _process_stream_chunk(self, chunk_data: Dict[str, Any]) -> StreamResponse: """处理流式响应的单个数据块,返回 StreamResponse 对象""" - return StreamResponse(chunk_data) \ No newline at end of file + return StreamResponse(chunk_data) diff --git a/chattool/core/response.py b/chattool/core/response.py new file mode 100644 index 0000000..47399fc --- /dev/null +++ b/chattool/core/response.py @@ -0,0 +1,170 @@ +# Response class for Chattool +from typing import Dict, Any, Union, Optional +import json + +class ChatResponse: + """Chat completion 响应包装类""" + + def __init__(self, response: Union[Dict, Any]) -> None: + if isinstance(response, Dict): + self.response = response + self._raw_response = None + else: + self._raw_response = response + self.response = response.json() if hasattr(response, 'json') else response + + def is_valid(self) -> bool: + """检查响应是否有效""" + return 'error' not in self.response + + def is_stream(self) -> bool: + """检查是否为流式响应""" + return self.response.get('object') == 'chat.completion.chunk' + + # === 基础属性 === + @property + def id(self) -> Optional[str]: + return self.response.get('id') + + @property + def model(self) -> Optional[str]: + return self.response.get('model') + + @property + def created(self) -> Optional[int]: + return self.response.get('created') + + @property + def object(self) -> Optional[str]: + return self.response.get('object') + + # === 使用统计 === + @property + def usage(self) -> Optional[Dict]: + """Token 使用统计""" + return self.response.get('usage') + + @property + def total_tokens(self) -> int: + """总 token 数""" + return self.usage.get('total_tokens', 0) if self.usage else 0 + + @property + def prompt_tokens(self) -> int: + """提示 token 数""" + return self.usage.get('prompt_tokens', 0) if self.usage else 0 + + @property + def completion_tokens(self) -> int: + """完成 token 数""" + return self.usage.get('completion_tokens', 0) if self.usage else 0 + + # === 消息内容 === + @property + def choices(self) -> list: + """选择列表""" + return self.response.get('choices', []) + + @property + def message(self) -> Optional[Dict]: + """消息内容""" + if self.choices: + return self.choices[0].get('message') + return None + + @property + def content(self) -> str: + """响应内容""" + if self.message: + return self.message.get('content', '') + return '' + + @property + def role(self) -> Optional[str]: + """消息角色""" + if self.message: + return self.message.get('role') + return None + + @property + def finish_reason(self) -> Optional[str]: + """完成原因""" + if self.choices: + return self.choices[0].get('finish_reason') + return None + + # === 流式响应专用 === + @property + def delta(self) -> Optional[Dict]: + """流式响应的 delta""" + if self.choices: + return self.choices[0].get('delta') + return None + + @property + def delta_content(self) -> str: + """流式响应的内容""" + if self.delta: + return self.delta.get('content', '') + return '' + + # === 错误处理 === + @property + def error(self) -> Optional[Dict]: + """错误信息""" + return self.response.get('error') + + @property + def error_message(self) -> Optional[str]: + """错误消息""" + return self.error.get('message') if self.error else None + + @property + def error_type(self) -> Optional[str]: + """错误类型""" + return self.error.get('type') if self.error else None + + @property + def error_code(self) -> Optional[str]: + """错误代码""" + return self.error.get('code') if self.error else None + + # === 调试信息 === + def get_debug_info(self) -> Dict[str, Any]: + """获取调试信息""" + return { + "id": self.id, + "model": self.model, + "created": self.created, + "usage": self.usage, + "finish_reason": self.finish_reason, + "is_valid": self.is_valid(), + "is_stream": self.is_stream() + } + + def print_debug_info(self): + """打印调试信息""" + info = self.get_debug_info() + print("=== Response Debug Info ===") + for key, value in info.items(): + print(f"{key}: {value}") + print("=" * 27) + + # === 魔术方法 === + def __repr__(self) -> str: + if self.is_valid(): + return f"" + else: + return f"" + + def __str__(self) -> str: + return self.content + + def __getitem__(self, key): + return self.response[key] + + def __contains__(self, key): + return key in self.response + +# 为了向后兼容,保留 Resp 类 +Resp = ChatResponse \ No newline at end of file diff --git a/chattool/response.py b/chattool/response.py deleted file mode 100644 index d3fb251..0000000 --- a/chattool/response.py +++ /dev/null @@ -1,141 +0,0 @@ -# Response class for Chattool - -from typing import Dict, Any, Union -from .tokencalc import findcost -import chattool - -class Resp(): - - def __init__(self, response:Union[Dict, Any]) -> None: - if isinstance(response, Dict): - self.response = response - self._raw_response = None - else: - self._raw_response = response - self.response = response.json() - - def get_curl(self): - """Convert the response to a cURL command""" - if self._raw_response is None: - return "No cURL command available" - return chattool.resp2curl(self._raw_response) - - def print_curl(self): - """Print the cURL command""" - print(self.get_curl()) - - def is_valid(self): - """Check if the response is an error""" - return 'error' not in self.response - - def cost(self): - """Calculate the cost of the response(Deprecated)""" - return findcost(self.model, self.prompt_tokens, self.completion_tokens) - - @property - def id(self): - return self['id'] - - @property - def model(self): - return self['model'] - - @property - def created(self): - return self['created'] - - @property - def usage(self): - """Token usage""" - return self['usage'] - - @property - def total_tokens(self): - """Total number of tokens""" - return self.usage['total_tokens'] - - @property - def prompt_tokens(self): - """Number of tokens in the prompt""" - return self.usage['prompt_tokens'] - - @property - def completion_tokens(self): - """Number of tokens of the response""" - return self.usage['completion_tokens'] - - @property - def message(self): - """Message""" - return self['choices'][0]['message'] - - @property - def content(self): - """Content of the response""" - return self.message['content'] - - @property - def function_call(self): - """Function call""" - return self.message.get('function_call') - - @property - def tool_calls(self): - """Tool calls""" - return self.message.get('tool_calls') - - @property - def delta(self): - """Delta""" - return self['choices'][0]['delta'] - - @property - def delta_content(self): - """Content of stream response""" - return self.delta['content'] - - @property - def object(self): - return self['object'] - - @property - def error(self): - """Error""" - return self['error'] - - @property - def error_message(self): - """Error message""" - return self.error['message'] - - @property - def error_type(self): - """Error type""" - return self.error['type'] - - @property - def error_param(self): - """Error parameter""" - return self.error['param'] - - @property - def error_code(self): - """Error code""" - return self.error['code'] - - @property - def finish_reason(self): - """Finish reason""" - return self['choices'][0].get('finish_reason') - - def __repr__(self) -> str: - return "" - - def __str__(self) -> str: - return self.content - - def __getitem__(self, key): - return self.response[key] - - def __contains__(self, key): - return key in self.response \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 05f8e6c..55cff88 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest import asyncio -from chattool.core import HTTPClient, Config +from chattool.core import HTTPClient, Config, OpenAIConfig, OpenAIClient from chattool.fastobj.basic import FastAPIManager from chattool.fastobj.capture import app from chattool.tools import ZulipClient @@ -12,6 +12,16 @@ def testpath(): return TEST_PATH +@pytest.fixture(scope="session") +def oai_config(): + """OpenAI 配置""" + return OpenAIConfig() + +@pytest.fixture(scope="session") +def oai_client(oai_config): + """OpenAI 客户端""" + return OpenAIClient(oai_config) + @pytest.fixture(scope="session", autouse=True) def fastapi_server(): """在整个测试会话期间启动 FastAPI 服务""" diff --git a/tests/test_oaiclient.py b/tests/test_oaiclient.py new file mode 100644 index 0000000..5dfd3df --- /dev/null +++ b/tests/test_oaiclient.py @@ -0,0 +1,359 @@ +import pytest +import json +from unittest.mock import Mock, patch +from typing import Dict, Any +from chattool.core.config import OpenAIConfig +from chattool.core.request import OpenAIClient + +@pytest.fixture +def mock_response_data(): + """模拟的 OpenAI API 响应数据""" + return { + "id": "chatcmpl-test123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I help you today?" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 9, + "total_tokens": 19 + } + } + + +@pytest.fixture +def mock_stream_response_data(): + """模拟的流式响应数据""" + return [ + { + "id": "chatcmpl-test123", + "object": "chat.completion.chunk", + "created": 1677652288, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None + } + ] + }, + { + "id": "chatcmpl-test123", + "object": "chat.completion.chunk", + "created": 1677652288, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": None + } + ] + }, + { + "id": "chatcmpl-test123", + "object": "chat.completion.chunk", + "created": 1677652288, + "model": "gpt-3.5-turbo", + "choices": [ + { + "index": 0, + "delta": {"content": "!"}, + "finish_reason": "stop" + } + ] + } + ] + + +class TestOpenAIConfig: + """测试 OpenAI 配置类""" + + def test_config_initialization(self): + """测试配置初始化""" + config = OpenAIConfig() + assert config.api_base + assert config.model + assert 'Authorization' in config.headers + + def test_config_with_custom_values(self): + """测试自定义配置值""" + config = OpenAIConfig( + model="gpt-4", + temperature=0.5, + max_tokens=1000 + ) + assert config.model == "gpt-4" + assert config.temperature == 0.5 + assert config.max_tokens == 1000 + + def test_config_get_method(self): + """测试配置的 get 方法""" + config = OpenAIConfig(temperature=0.7) + assert config.get('temperature') == 0.7 + assert config.get('nonexistent') is None + assert config.get('nonexistent', 'default') == 'default' + + +class TestOpenAIClient: + """测试 OpenAI 客户端类""" + + def test_client_initialization(self, oai_config): + """测试客户端初始化""" + client = OpenAIClient(oai_config) + assert client.config == oai_config + assert client._sync_client is None + assert client._async_client is None + + def test_client_with_default_config(self): + """测试使用默认配置初始化客户端""" + client = OpenAIClient() + assert isinstance(client.config, OpenAIConfig) + + def test_build_chat_data_basic(self, oai_client): + """测试基础聊天数据构建""" + messages = [{"role": "user", "content": "Hello"}] + data = oai_client._build_chat_data(messages) + + assert data["messages"] == messages + assert "model" in data # 应该从 config 中获取 + + def test_build_chat_data_with_kwargs(self, oai_client): + """测试带 kwargs 的聊天数据构建""" + messages = [{"role": "user", "content": "Hello"}] + data = oai_client._build_chat_data( + messages, + temperature=0.5, + max_tokens=100 + ) + + assert data["messages"] == messages + assert data["temperature"] == 0.5 + assert data["max_tokens"] == 100 + + def test_build_chat_data_excludes_config_attrs(self, oai_client): + """测试构建数据时排除配置专用属性""" + messages = [{"role": "user", "content": "Hello"}] + data = oai_client._build_chat_data(messages) + + # 这些属性不应该出现在 API 请求数据中 + for attr in oai_client._config_only_attrs: + assert attr not in data + + def test_get_param_value_priority(self, oai_client): + """测试参数值优先级""" + # kwargs 优先于 config + kwargs = {"temperature": 0.8} + value = oai_client._get_param_value("temperature", kwargs) + assert value == 0.8 + + # 没有 kwargs 时使用 config + value = oai_client._get_param_value("model", {}) + assert value == oai_client.config.model + + # 都没有时返回 None + value = oai_client._get_param_value("nonexistent", {}) + assert value is None + + @patch('chattool.core.request.OpenAIClient.post') + def test_chat_completion_sync(self, mock_post, oai_client, mock_response_data): + """测试同步聊天完成""" + # 设置 mock + mock_response = Mock() + mock_response.json.return_value = mock_response_data + mock_post.return_value = mock_response + + messages = [{"role": "user", "content": "Hello"}] + result = oai_client.chat_completion(messages) + + # 验证调用 + mock_post.assert_called_once() + call_args = mock_post.call_args + assert call_args[0][0] == "/chat/completions" + assert "data" in call_args[1] + assert call_args[1]["data"]["messages"] == messages + + # 验证返回值 + assert result == mock_response_data + + @patch('chattool.core.request.OpenAIClient.async_post') + @pytest.mark.asyncio + async def test_chat_completion_async(self, mock_async_post, oai_client, mock_response_data): + """测试异步聊天完成""" + # 设置 mock + mock_response = Mock() + mock_response.json.return_value = mock_response_data + mock_async_post.return_value = mock_response + + messages = [{"role": "user", "content": "Hello"}] + result = await oai_client.async_chat_completion(messages) + + # 验证调用 + mock_async_post.assert_called_once() + call_args = mock_async_post.call_args + assert call_args[0][0] == "/chat/completions" + assert "data" in call_args[1] + assert call_args[1]["data"]["messages"] == messages + + # 验证返回值 + assert result == mock_response_data + + def test_chat_completion_parameter_override(self, oai_client): + """测试参数覆盖功能""" + messages = [{"role": "user", "content": "Hello"}] + + with patch.object(oai_client, 'post') as mock_post: + mock_response = Mock() + mock_response.json.return_value = {} + mock_post.return_value = mock_response + + # 调用时覆盖配置参数 + oai_client.chat_completion( + messages, + model="gpt-4", + temperature=0.2, + custom_param="test" + ) + + # 验证数据中包含覆盖的参数 + call_data = mock_post.call_args[1]["data"] + assert call_data["model"] == "gpt-4" + assert call_data["temperature"] == 0.2 + assert call_data["custom_param"] == "test" + + @patch('chattool.core.request.OpenAIClient._stream_chat_completion') + def test_chat_completion_stream_mode(self, mock_stream, oai_client): + """测试流式模式""" + messages = [{"role": "user", "content": "Hello"}] + mock_stream.return_value = iter([]) + + result = oai_client.chat_completion(messages, stream=True) + + # 验证调用了流式方法 + mock_stream.assert_called_once() + call_data = mock_stream.call_args[0][0] + assert call_data["stream"] is True + + @patch('chattool.core.request.OpenAIClient.post') + def test_embeddings(self, mock_post, oai_client): + """测试嵌入 API""" + mock_response_data = { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.1, 0.2, 0.3], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 5, "total_tokens": 5} + } + + mock_response = Mock() + mock_response.json.return_value = mock_response_data + mock_post.return_value = mock_response + + result = oai_client.embeddings("test text") + + # 验证调用 + mock_post.assert_called_once() + call_args = mock_post.call_args + assert call_args[0][0] == "/embeddings" + assert call_args[1]["data"]["input"] == "test text" + + # 验证返回值 + assert result == mock_response_data + + @patch('chattool.core.request.OpenAIClient.async_post') + @pytest.mark.asyncio + async def test_embeddings_async(self, mock_async_post, oai_client): + """测试异步嵌入 API""" + mock_response_data = { + "object": "list", + "data": [{"embedding": [0.1, 0.2]}], + "model": "text-embedding-ada-002" + } + + mock_response = Mock() + mock_response.json.return_value = mock_response_data + mock_async_post.return_value = mock_response + + result = await oai_client.async_embeddings(["text1", "text2"]) + + # 验证调用 + mock_async_post.assert_called_once() + call_args = mock_async_post.call_args + assert call_args[0][0] == "/embeddings" + assert call_args[1]["data"]["input"] == ["text1", "text2"] + + # 验证返回值 + assert result == mock_response_data + + +class TestParameterHandling: + """测试参数处理逻辑""" + + def test_none_values_excluded(self, oai_client): + """测试 None 值被排除""" + messages = [{"role": "user", "content": "Hello"}] + data = oai_client._build_chat_data( + messages, + temperature=None, + max_tokens=100, + top_p=None + ) + + assert "temperature" not in data + assert "top_p" not in data + assert data["max_tokens"] == 100 + + def test_config_fallback(self, oai_client): + """测试配置回退机制""" + # 设置一些配置值 + oai_client.config.temperature = 0.7 + oai_client.config.custom_param = "config_value" + + messages = [{"role": "user", "content": "Hello"}] + data = oai_client._build_chat_data(messages) + + # 应该使用配置中的值 + assert data["temperature"] == 0.7 + assert data["custom_param"] == "config_value" + + def test_kwargs_override_config(self, oai_client): + """测试 kwargs 覆盖配置""" + # 设置配置值 + oai_client.config.temperature = 0.7 + + messages = [{"role": "user", "content": "Hello"}] + data = oai_client._build_chat_data( + messages, + temperature=0.2 # 覆盖配置 + ) + + # 应该使用 kwargs 中的值 + assert data["temperature"] == 0.2 + + def test_unknown_parameters_included(self, oai_client): + """测试未知参数也会被包含""" + messages = [{"role": "user", "content": "Hello"}] + data = oai_client._build_chat_data( + messages, + future_param="new_feature", + another_param={"complex": "value"} + ) + + assert data["future_param"] == "new_feature" + assert data["another_param"] == {"complex": "value"} From 10eb0f6302c88b9d5136f8cd8b76e0492fdcb362 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Mon, 30 Jun 2025 04:15:23 +0800 Subject: [PATCH 2/7] add azure api --- chattool/core/__init__.py | 4 +- chattool/core/config.py | 54 ++++- chattool/core/request.py | 240 ++++++++++++++++++- tests/conftest.py | 17 +- tests/test_azureclient.py | 470 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 769 insertions(+), 16 deletions(-) create mode 100644 tests/test_azureclient.py diff --git a/chattool/core/__init__.py b/chattool/core/__init__.py index d86c2c4..cc0173e 100644 --- a/chattool/core/__init__.py +++ b/chattool/core/__init__.py @@ -1,2 +1,2 @@ -from chattool.core.config import Config, OpenAIConfig -from chattool.core.request import HTTPClient, OpenAIClient +from chattool.core.config import Config, OpenAIConfig, AzureOpenAIConfig +from chattool.core.request import HTTPClient, OpenAIClient, AzureOpenAIClient diff --git a/chattool/core/config.py b/chattool/core/config.py index 05ccd8c..fffcaaa 100644 --- a/chattool/core/config.py +++ b/chattool/core/config.py @@ -88,14 +88,58 @@ def __post__init__(self): "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } + +# Azure OpenAI 专用配置 +class AzureOpenAIConfig(Config): + def __init__( + self, + api_version: Optional[str] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + stop: Optional[list] = None, + **kwargs + ): + super().__init__(**kwargs) + # Azure 特定参数 + self.api_version = api_version + # OpenAI 兼容参数 + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + self.frequency_penalty = frequency_penalty + self.presence_penalty = presence_penalty + self.stop = stop + + def __post__init__(self): + # 从环境变量获取 Azure 配置 + if not self.api_key: + self.api_key = os.getenv("AZURE_OPENAI_API_KEY", "") + if not self.api_base: + self.api_base = os.getenv("AZURE_OPENAI_ENDPOINT", "") + if not self.api_version: + self.api_version = os.getenv("AZURE_OPENAI_API_VERSION") + if not self.model: + self.model = os.getenv("AZURE_OPENAI_API_MODEL", "") + # Azure 使用不同的请求头格式 + if not self.headers: + self.headers = { + "Content-Type": "application/json", + } + + def get_request_url(self) -> str: + """获取带 API 密钥的请求 URL""" + endpoint = self.api_base.rstrip('/') + if self.api_version: + endpoint = f"{endpoint}?api-version={self.api_version}" + if self.api_key: + endpoint = f"{endpoint}&ak={self.api_key}" + return endpoint # Anthropic 配置示例 class AnthropicConfig(Config): pass -# Azure OpenAI 配置示例 -class AzureConfig(Config): - pass - - diff --git a/chattool/core/request.py b/chattool/core/request.py index c3a16ea..3c5c434 100644 --- a/chattool/core/request.py +++ b/chattool/core/request.py @@ -2,10 +2,11 @@ import asyncio import logging import time -from chattool.core.config import Config, OpenAIConfig -from chattool.custom_logger import setup_logger -from typing import Generator, AsyncGenerator, Union, Dict, Any, Optional, List import json +import hashlib +from typing import Dict, List, Optional, Union, Generator, AsyncGenerator, Any +from chattool.core.config import Config, OpenAIConfig, AzureOpenAIConfig +from chattool.custom_logger import setup_logger # 基础HTTP客户端类 class HTTPClient: @@ -228,11 +229,6 @@ def __str__(self): def __repr__(self): return f"StreamResponse(content='{self.content}', finish_reason='{self.finish_reason}')" -import json -from typing import Dict, List, Optional, Union, Generator, AsyncGenerator, Any -from chattool.core.config import OpenAIConfig -from chattool.core.request import HTTPClient, StreamResponse - class OpenAIClient(HTTPClient): _config_only_attrs = { 'api_key', 'api_base', 'headers', 'timeout', @@ -514,3 +510,231 @@ async def async_embeddings( def _process_stream_chunk(self, chunk_data: Dict[str, Any]) -> StreamResponse: """处理流式响应的单个数据块,返回 StreamResponse 对象""" return StreamResponse(chunk_data) + +class AzureOpenAIClient(HTTPClient): + """Azure OpenAI 客户端""" + + _config_only_attrs = { + 'api_key', 'api_base', 'api_version', + 'headers', 'timeout', 'max_retries', 'retry_delay' + } + + def __init__(self, config: Optional[AzureOpenAIConfig] = None, logger = None, **kwargs): + if config is None: + config = AzureOpenAIConfig() + super().__init__(config, logger, **kwargs) + + def _generate_log_id(self, messages: List[Dict[str, str]]) -> str: + """生成请求的 log ID""" + content = str(messages).encode() + return hashlib.sha256(content).hexdigest() + + def _build_azure_headers(self, messages: List[Dict[str, str]]) -> Dict[str, str]: + """构建 Azure 专用请求头""" + headers = self.config.headers.copy() + headers['X-TT-LOGID'] = self._generate_log_id(messages) + return headers + + def _build_chat_data(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]: + """构建聊天完成请求的数据 - Azure 版本""" + data = {"messages": messages} + + # 处理所有可能的参数 + all_params = set(kwargs.keys()) | { + k for k in self.config.__dict__.keys() + if not k.startswith('_') + } + + for param_name in all_params: + # 跳过配置专用属性 + if param_name in self._config_only_attrs: + continue + + value = self._get_param_value(param_name, kwargs) + if value is not None: + data[param_name] = value + + # 确保 stream 默认为 False + if 'stream' not in data: + data['stream'] = False + + return data + + def _get_param_value(self, param_name: str, kwargs: Dict[str, Any]): + """按优先级获取参数值:kwargs > config > None""" + if param_name in kwargs: + return kwargs[param_name] + return self.config.get(param_name) + + def chat_completion( + self, + messages: List[Dict[str, str]], + model: Optional[str] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + stream: bool = False, + **kwargs + ) -> Union[Dict[str, Any], Generator[StreamResponse, None, None]]: + """ + Azure OpenAI Chat Completion API (同步版本) + + Args: + messages: 对话消息列表 + model: 模型名称(Azure 中通常是 deployment name) + temperature: 温度参数 + top_p: top_p 参数 + max_tokens: 最大token数 + stream: 是否使用流式响应 + **kwargs: 其他参数 + """ + # 将显式参数合并到 kwargs 中 + all_kwargs = { + 'model': model, + 'temperature': temperature, + 'top_p': top_p, + 'max_tokens': max_tokens, + 'stream': stream, + **kwargs + } + + # 构建请求数据和头 + data = self._build_chat_data(messages, **all_kwargs) + headers = self._build_azure_headers(messages) + + if data.get('stream'): + return self._stream_chat_completion(data, headers) + + # Azure 使用特殊的 URL 格式 + url = self.config.get_request_url() + + # 直接使用 requests 发送请求(因为 Azure 的 URL 格式特殊) + import requests + response = requests.post(url, headers=headers, json=data, timeout=self.config.timeout or None) + response.raise_for_status() + return response.json() + + async def async_chat_completion( + self, + messages: List[Dict[str, str]], + model: Optional[str] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + stream: bool = False, + **kwargs + ) -> Union[Dict[str, Any], AsyncGenerator[StreamResponse, None]]: + """Azure OpenAI Chat Completion API (异步版本)""" + import aiohttp + + all_kwargs = { + 'model': model, + 'temperature': temperature, + 'top_p': top_p, + 'max_tokens': max_tokens, + 'stream': stream, + **kwargs + } + + data = self._build_chat_data(messages, **all_kwargs) + headers = self._build_azure_headers(messages) + + if data.get('stream'): + return self._async_stream_chat_completion(data, headers) + + url = self.config.get_request_url() + + timeout = aiohttp.ClientTimeout(total=self.config.timeout) if self.config.timeout > 0 else None + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, headers=headers, json=data) as response: + response.raise_for_status() + return await response.json() + + def _stream_chat_completion(self, data: Dict[str, Any], headers: Dict[str, str]) -> Generator[StreamResponse, None, None]: + """Azure 同步流式聊天完成""" + import requests + + url = self.config.get_request_url() + + with requests.post(url, headers=headers, json=data, stream=True) as response: + response.raise_for_status() + + for line in response.iter_lines(): + if not line: + continue + + line_str = line.decode('utf-8').strip() + + if not line_str.startswith('data: '): + continue + + data_str = line_str[6:].strip() + + if data_str == '[DONE]': + break + + if not data_str: + continue + + try: + chunk_data = json.loads(data_str) + stream_response = self._process_stream_chunk(chunk_data) + yield stream_response + + if stream_response.is_finished: + break + + except json.JSONDecodeError as e: + self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") + continue + except Exception as e: + self.logger.error(f"Error processing stream chunk: {e}") + break + + async def _async_stream_chat_completion(self, data: Dict[str, Any], headers: Dict[str, str]) -> AsyncGenerator[StreamResponse, None]: + """Azure 异步流式聊天完成""" + import aiohttp + + url = self.config.get_request_url() + timeout = aiohttp.ClientTimeout(total=self.config.timeout) if self.config.timeout > 0 else None + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, headers=headers, json=data) as response: + response.raise_for_status() + + async for line in response.content: + if not line: + continue + + line_str = line.decode('utf-8').strip() + + if not line_str.startswith('data: '): + continue + + data_str = line_str[6:].strip() + + if data_str == '[DONE]': + break + + if not data_str: + continue + + try: + chunk_data = json.loads(data_str) + stream_response = self._process_stream_chunk(chunk_data) + yield stream_response + + if stream_response.is_finished: + break + + except json.JSONDecodeError as e: + self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") + continue + except Exception as e: + self.logger.error(f"Error processing stream chunk: {e}") + break + + def _process_stream_chunk(self, chunk_data: Dict[str, Any]) -> StreamResponse: + """处理流式响应的单个数据块""" + return StreamResponse(chunk_data) diff --git a/tests/conftest.py b/tests/conftest.py index 55cff88..3456462 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest import asyncio -from chattool.core import HTTPClient, Config, OpenAIConfig, OpenAIClient +from chattool.core import HTTPClient, Config, OpenAIConfig, OpenAIClient, AzureOpenAIConfig, AzureOpenAIClient from chattool.fastobj.basic import FastAPIManager from chattool.fastobj.capture import app from chattool.tools import ZulipClient @@ -22,6 +22,21 @@ def oai_client(oai_config): """OpenAI 客户端""" return OpenAIClient(oai_config) +@pytest.fixture(scope="session") +def azure_config(): + """Azure OpenAI 配置 fixture""" + return AzureOpenAIConfig( + api_key="test-azure-key", + api_base="https://test-resource.openai.azure.com", + api_version="2024-02-15-preview", + model="gpt-35-turbo" + ) + +@pytest.fixture +def azure_client(azure_config): + """Azure OpenAI 客户端 fixture""" + return AzureOpenAIClient(azure_config) + @pytest.fixture(scope="session", autouse=True) def fastapi_server(): """在整个测试会话期间启动 FastAPI 服务""" diff --git a/tests/test_azureclient.py b/tests/test_azureclient.py new file mode 100644 index 0000000..2108fbe --- /dev/null +++ b/tests/test_azureclient.py @@ -0,0 +1,470 @@ +import pytest +import json +import os +from unittest.mock import Mock, patch, MagicMock +from typing import Dict, Any + +from chattool.core.config import AzureOpenAIConfig +from chattool.core.request import AzureOpenAIClient + + +@pytest.fixture +def mock_azure_response_data(): + """模拟的 Azure OpenAI API 响应数据""" + return { + "id": "chatcmpl-azure-test123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-35-turbo", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! This is from Azure OpenAI." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 12, + "completion_tokens": 8, + "total_tokens": 20 + } + } + + +@pytest.fixture +def mock_azure_stream_response_data(): + """模拟的 Azure 流式响应数据""" + return [ + { + "id": "chatcmpl-azure-test123", + "object": "chat.completion.chunk", + "created": 1677652288, + "model": "gpt-35-turbo", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None + } + ] + }, + { + "id": "chatcmpl-azure-test123", + "object": "chat.completion.chunk", + "created": 1677652288, + "model": "gpt-35-turbo", + "choices": [ + { + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": None + } + ] + }, + { + "id": "chatcmpl-azure-test123", + "object": "chat.completion.chunk", + "created": 1677652288, + "model": "gpt-35-turbo", + "choices": [ + { + "index": 0, + "delta": {"content": " from Azure!"}, + "finish_reason": "stop" + } + ] + } + ] + + +class TestAzureOpenAIConfig: + """测试 Azure OpenAI 配置类""" + + def test_config_initialization(self): + """测试配置初始化""" + config = AzureOpenAIConfig( + api_key="test-key", + api_base="https://test.openai.azure.com", + api_version="2024-02-15-preview", + model="gpt-35-turbo" + ) + assert config.api_key == "test-key" + assert config.api_base == "https://test.openai.azure.com" + assert config.api_version == "2024-02-15-preview" + assert config.model == "gpt-35-turbo" + assert config.headers["Content-Type"] == "application/json" + + def test_config_with_custom_parameters(self): + """测试自定义参数配置""" + config = AzureOpenAIConfig( + temperature=0.5, + max_tokens=1000, + top_p=0.9 + ) + assert config.temperature == 0.5 + assert config.max_tokens == 1000 + assert config.top_p == 0.9 + + @patch.dict(os.environ, { + "AZURE_OPENAI_API_KEY": "env-key", + "AZURE_OPENAI_ENDPOINT": "https://env-resource.openai.azure.com", + "AZURE_OPENAI_API_VERSION": "2024-03-01-preview", + "AZURE_OPENAI_API_MODEL": "gpt-4" + }) + def test_config_from_environment(self): + """测试从环境变量获取配置""" + config = AzureOpenAIConfig() + assert config.api_key == "env-key" + assert config.api_base == "https://env-resource.openai.azure.com" + assert config.api_version == "2024-03-01-preview" + assert config.model == "gpt-4" + + def test_get_request_url_basic(self): + """测试基础请求 URL 构建""" + config = AzureOpenAIConfig( + api_base="https://test.openai.azure.com", + api_version="2024-02-15-preview", + api_key="test-key" + ) + url = config.get_request_url() + expected = "https://test.openai.azure.com?api-version=2024-02-15-preview&ak=test-key" + assert url == expected + + def test_get_request_url_no_version(self): + """测试没有 API 版本的 URL 构建""" + config = AzureOpenAIConfig( + api_base="https://test.openai.azure.com", + api_key="test-key" + ) + url = config.get_request_url() + expected = "https://test.openai.azure.com&ak=test-key" + assert url == expected + + def test_get_request_url_no_key(self): + """测试没有 API 密钥的 URL 构建""" + config = AzureOpenAIConfig( + api_base="https://test.openai.azure.com", + api_version="2024-02-15-preview" + ) + url = config.get_request_url() + expected = "https://test.openai.azure.com?api-version=2024-02-15-preview" + assert url == expected + + def test_config_get_method(self): + """测试配置的 get 方法""" + config = AzureOpenAIConfig(temperature=0.7, api_version="2024-02-15-preview") + assert config.get('temperature') == 0.7 + assert config.get('api_version') == "2024-02-15-preview" + assert config.get('nonexistent') is None + assert config.get('nonexistent', 'default') == 'default' + + +class TestAzureOpenAIClient: + """测试 Azure OpenAI 客户端类""" + + def test_client_initialization(self, azure_config): + """测试客户端初始化""" + client = AzureOpenAIClient(azure_config) + assert client.config == azure_config + assert hasattr(client, '_config_only_attrs') + assert 'api_version' in client._config_only_attrs + + def test_client_with_default_config(self): + """测试使用默认配置初始化客户端""" + client = AzureOpenAIClient() + assert isinstance(client.config, AzureOpenAIConfig) + + def test_generate_log_id(self, azure_client): + """测试 Log ID 生成""" + messages = [{"role": "user", "content": "Hello"}] + log_id_1 = azure_client._generate_log_id(messages) + log_id_2 = azure_client._generate_log_id(messages) + + # 相同输入应该生成相同的 log ID + assert log_id_1 == log_id_2 + assert len(log_id_1) == 64 # SHA256 输出长度 + + # 不同输入应该生成不同的 log ID + different_messages = [{"role": "user", "content": "Hi"}] + log_id_3 = azure_client._generate_log_id(different_messages) + assert log_id_1 != log_id_3 + + def test_build_azure_headers(self, azure_client): + """测试 Azure 请求头构建""" + messages = [{"role": "user", "content": "Hello"}] + headers = azure_client._build_azure_headers(messages) + + assert "Content-Type" in headers + assert headers["Content-Type"] == "application/json" + assert "X-TT-LOGID" in headers + assert len(headers["X-TT-LOGID"]) == 64 + + def test_build_chat_data_basic(self, azure_client): + """测试基础聊天数据构建""" + messages = [{"role": "user", "content": "Hello"}] + data = azure_client._build_chat_data(messages) + + assert data["messages"] == messages + assert data["stream"] is False # 默认为 False + if azure_client.config.model: + assert "model" in data + + def test_build_chat_data_with_kwargs(self, azure_client): + """测试带 kwargs 的聊天数据构建""" + messages = [{"role": "user", "content": "Hello"}] + data = azure_client._build_chat_data( + messages, + temperature=0.5, + max_tokens=100, + stream=True + ) + + assert data["messages"] == messages + assert data["temperature"] == 0.5 + assert data["max_tokens"] == 100 + assert data["stream"] is True + + def test_build_chat_data_excludes_config_attrs(self, azure_client): + """测试构建数据时排除配置专用属性""" + messages = [{"role": "user", "content": "Hello"}] + data = azure_client._build_chat_data(messages) + + # 这些属性不应该出现在 API 请求数据中 + for attr in azure_client._config_only_attrs: + assert attr not in data + + def test_get_param_value_priority(self, azure_client): + """测试参数值优先级""" + # kwargs 优先于 config + kwargs = {"temperature": 0.8} + value = azure_client._get_param_value("temperature", kwargs) + assert value == 0.8 + + # 没有 kwargs 时使用 config + azure_client.config.temperature = 0.5 + value = azure_client._get_param_value("temperature", {}) + assert value == 0.5 + + # 都没有时返回 None + value = azure_client._get_param_value("nonexistent", {}) + assert value is None + + @patch('requests.post') + def test_chat_completion_sync(self, mock_post, azure_client, mock_azure_response_data): + """测试同步聊天完成""" + # 设置 mock + mock_response = Mock() + mock_response.json.return_value = mock_azure_response_data + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + messages = [{"role": "user", "content": "Hello"}] + result = azure_client.chat_completion(messages) + + # 验证调用 + mock_post.assert_called_once() + call_args = mock_post.call_args + + # 验证 URL 包含 Azure 特定格式 + url = call_args[0][0] + assert "api-version" in url + assert "ak=" in url + + # 验证请求数据 + assert "json" in call_args[1] + json_data = call_args[1]["json"] + assert json_data["messages"] == messages + + # 验证请求头 + assert "headers" in call_args[1] + headers = call_args[1]["headers"] + assert "X-TT-LOGID" in headers + + # 验证返回值 + assert result == mock_azure_response_data + + @patch('aiohttp.ClientSession.post') + @pytest.mark.asyncio + async def test_chat_completion_async(self, mock_post, azure_client, mock_azure_response_data): + """测试异步聊天完成""" + # 设置 mock + mock_response = Mock() + mock_response.json = Mock(return_value=mock_azure_response_data) + mock_response.raise_for_status = Mock() + + # 设置异步上下文管理器 + mock_context = Mock() + mock_context.__aenter__ = Mock(return_value=mock_response) + mock_context.__aexit__ = Mock(return_value=None) + mock_post.return_value = mock_context + + # 模拟 ClientSession + with patch('aiohttp.ClientSession') as mock_session: + mock_session_instance = Mock() + mock_session_instance.post.return_value = mock_context + mock_session_instance.__aenter__ = Mock(return_value=mock_session_instance) + mock_session_instance.__aexit__ = Mock(return_value=None) + mock_session.return_value = mock_session_instance + + messages = [{"role": "user", "content": "Hello"}] + result = await azure_client.async_chat_completion(messages) + + # 验证返回值 + assert result == mock_azure_response_data + + def test_chat_completion_parameter_override(self, azure_client): + """测试参数覆盖功能""" + messages = [{"role": "user", "content": "Hello"}] + + with patch('requests.post') as mock_post: + mock_response = Mock() + mock_response.json.return_value = {} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # 调用时覆盖配置参数 + azure_client.chat_completion( + messages, + model="gpt-4", + temperature=0.2, + custom_param="test" + ) + + # 验证数据中包含覆盖的参数 + call_args = mock_post.call_args + json_data = call_args[1]["json"] + assert json_data["model"] == "gpt-4" + assert json_data["temperature"] == 0.2 + assert json_data["custom_param"] == "test" + + @patch('requests.post') + def test_simple_request(self, mock_post, azure_client, mock_azure_response_data): + """测试简化的请求方法""" + # 设置 mock + mock_response = Mock() + mock_response.json.return_value = mock_azure_response_data + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # 测试字符串输入 + result = azure_client.simple_request("Hello, how are you?") + expected_content = mock_azure_response_data['choices'][0]['message']['content'] + assert result == expected_content + + # 验证自动转换为消息格式 + call_args = mock_post.call_args + json_data = call_args[1]["json"] + expected_messages = [{"role": "user", "content": "Hello, how are you?"}] + assert json_data["messages"] == expected_messages + + @patch('requests.post') + def test_simple_request_with_messages(self, mock_post, azure_client, mock_azure_response_data): + """测试简化请求方法使用消息列表""" + mock_response = Mock() + mock_response.json.return_value = mock_azure_response_data + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + + result = azure_client.simple_request(messages, model_name="gpt-4") + expected_content = mock_azure_response_data['choices'][0]['message']['content'] + assert result == expected_content + + # 验证消息直接传递 + call_args = mock_post.call_args + json_data = call_args[1]["json"] + assert json_data["messages"] == messages + + +class TestAzureParameterHandling: + """测试 Azure 参数处理逻辑""" + + def test_none_values_excluded(self, azure_client): + """测试 None 值被排除""" + messages = [{"role": "user", "content": "Hello"}] + data = azure_client._build_chat_data( + messages, + temperature=None, + max_tokens=100, + top_p=None + ) + + assert "temperature" not in data + assert "top_p" not in data + assert data["max_tokens"] == 100 + + def test_config_fallback(self, azure_client): + """测试配置回退机制""" + # 设置一些配置值 + azure_client.config.temperature = 0.7 + azure_client.config.custom_param = "config_value" + + messages = [{"role": "user", "content": "Hello"}] + data = azure_client._build_chat_data(messages) + + # 应该使用配置中的值 + assert data["temperature"] == 0.7 + assert data["custom_param"] == "config_value" + + def test_kwargs_override_config(self, azure_client): + """测试 kwargs 覆盖配置""" + # 设置配置值 + azure_client.config.temperature = 0.7 + + messages = [{"role": "user", "content": "Hello"}] + data = azure_client._build_chat_data( + messages, + temperature=0.2 # 覆盖配置 + ) + + # 应该使用 kwargs 中的值 + assert data["temperature"] == 0.2 + + def test_stream_default_false(self, azure_client): + """测试 stream 参数默认为 False""" + messages = [{"role": "user", "content": "Hello"}] + data = azure_client._build_chat_data(messages) + + assert data["stream"] is False + + def test_unknown_parameters_included(self, azure_client): + """测试未知参数也会被包含""" + messages = [{"role": "user", "content": "Hello"}] + data = azure_client._build_chat_data( + messages, + future_param="new_feature", + another_param={"complex": "value"} + ) + + assert data["future_param"] == "new_feature" + assert data["another_param"] == {"complex": "value"} + + +class TestAzureURLConstruction: + """测试 Azure URL 构建逻辑""" + + def test_url_with_all_params(self): + """测试包含所有参数的 URL 构建""" + config = AzureOpenAIConfig( + api_base="https://test.openai.azure.com/", # 测试末尾斜杠处理 + api_version="2024-02-15-preview", + api_key="test-key" + ) + url = config.get_request_url() + expected = "https://test.openai.azure.com?api-version=2024-02-15-preview&ak=test-key" + assert url == expected + + def test_url_minimal_params(self): + """测试最少参数的 URL 构建""" + config = AzureOpenAIConfig(api_base="https://test.openai.azure.com") + url = config.get_request_url() + expected = "https://test.openai.azure.com" + assert url == expected From 39a188ece8d75e3eea3aafb685bb464affe5af13 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Mon, 30 Jun 2025 04:44:53 +0800 Subject: [PATCH 3/7] add azure api --- chattool/core/config.py | 15 +- chattool/core/request.py | 184 ++++++--------- tests/test_azureclient.py | 458 +++++++++++++++++--------------------- 3 files changed, 280 insertions(+), 377 deletions(-) diff --git a/chattool/core/config.py b/chattool/core/config.py index fffcaaa..d34a53a 100644 --- a/chattool/core/config.py +++ b/chattool/core/config.py @@ -18,17 +18,12 @@ def __init__( self.api_key = api_key self.api_base = api_base self.model = model - self.headers = headers + self.headers = headers or {} self.timeout = timeout self.max_retries = max_retries self.retry_delay = retry_delay for key, value in kwargs.items(): setattr(self, key, value) - self.__post__init__() - - def __post__init__(self): - if self.headers is None: - self.headers = {"Content-Type": "application/json"} def __repr__(self): return ( @@ -75,8 +70,7 @@ def __init__( self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty self.stop = stop - - def __post__init__(self): + if not self.api_key: self.api_key = os.getenv("OPENAI_API_KEY", "") if not self.api_base: @@ -113,7 +107,6 @@ def __init__( self.presence_penalty = presence_penalty self.stop = stop - def __post__init__(self): # 从环境变量获取 Azure 配置 if not self.api_key: self.api_key = os.getenv("AZURE_OPENAI_API_KEY", "") @@ -130,14 +123,14 @@ def __post__init__(self): "Content-Type": "application/json", } - def get_request_url(self) -> str: + # update api_base """获取带 API 密钥的请求 URL""" endpoint = self.api_base.rstrip('/') if self.api_version: endpoint = f"{endpoint}?api-version={self.api_version}" if self.api_key: endpoint = f"{endpoint}&ak={self.api_key}" - return endpoint + self.api_base = endpoint # Anthropic 配置示例 class AnthropicConfig(Config): diff --git a/chattool/core/request.py b/chattool/core/request.py index 3c5c434..01bbcf6 100644 --- a/chattool/core/request.py +++ b/chattool/core/request.py @@ -511,7 +511,7 @@ def _process_stream_chunk(self, chunk_data: Dict[str, Any]) -> StreamResponse: """处理流式响应的单个数据块,返回 StreamResponse 对象""" return StreamResponse(chunk_data) -class AzureOpenAIClient(HTTPClient): +class AzureOpenAIClient(OpenAIClient): """Azure OpenAI 客户端""" _config_only_attrs = { @@ -535,37 +535,6 @@ def _build_azure_headers(self, messages: List[Dict[str, str]]) -> Dict[str, str] headers['X-TT-LOGID'] = self._generate_log_id(messages) return headers - def _build_chat_data(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]: - """构建聊天完成请求的数据 - Azure 版本""" - data = {"messages": messages} - - # 处理所有可能的参数 - all_params = set(kwargs.keys()) | { - k for k in self.config.__dict__.keys() - if not k.startswith('_') - } - - for param_name in all_params: - # 跳过配置专用属性 - if param_name in self._config_only_attrs: - continue - - value = self._get_param_value(param_name, kwargs) - if value is not None: - data[param_name] = value - - # 确保 stream 默认为 False - if 'stream' not in data: - data['stream'] = False - - return data - - def _get_param_value(self, param_name: str, kwargs: Dict[str, Any]): - """按优先级获取参数值:kwargs > config > None""" - if param_name in kwargs: - return kwargs[param_name] - return self.config.get(param_name) - def chat_completion( self, messages: List[Dict[str, str]], @@ -576,19 +545,8 @@ def chat_completion( stream: bool = False, **kwargs ) -> Union[Dict[str, Any], Generator[StreamResponse, None, None]]: - """ - Azure OpenAI Chat Completion API (同步版本) - - Args: - messages: 对话消息列表 - model: 模型名称(Azure 中通常是 deployment name) - temperature: 温度参数 - top_p: top_p 参数 - max_tokens: 最大token数 - stream: 是否使用流式响应 - **kwargs: 其他参数 - """ - # 将显式参数合并到 kwargs 中 + """Azure OpenAI Chat Completion API (同步版本)""" + # 复用父类的参数处理逻辑 all_kwargs = { 'model': model, 'temperature': temperature, @@ -598,20 +556,14 @@ def chat_completion( **kwargs } - # 构建请求数据和头 data = self._build_chat_data(messages, **all_kwargs) - headers = self._build_azure_headers(messages) + azure_headers = self._build_azure_headers(messages) if data.get('stream'): - return self._stream_chat_completion(data, headers) - - # Azure 使用特殊的 URL 格式 - url = self.config.get_request_url() + return self._stream_chat_completion(data, azure_headers) - # 直接使用 requests 发送请求(因为 Azure 的 URL 格式特殊) - import requests - response = requests.post(url, headers=headers, json=data, timeout=self.config.timeout or None) - response.raise_for_status() + # 使用父类的 post 方法,但传入 Azure 特殊的请求头 + response = self.post("", data=data, headers=azure_headers) return response.json() async def async_chat_completion( @@ -625,8 +577,6 @@ async def async_chat_completion( **kwargs ) -> Union[Dict[str, Any], AsyncGenerator[StreamResponse, None]]: """Azure OpenAI Chat Completion API (异步版本)""" - import aiohttp - all_kwargs = { 'model': model, 'temperature': temperature, @@ -637,27 +587,29 @@ async def async_chat_completion( } data = self._build_chat_data(messages, **all_kwargs) - headers = self._build_azure_headers(messages) + azure_headers = self._build_azure_headers(messages) if data.get('stream'): - return self._async_stream_chat_completion(data, headers) - - url = self.config.get_request_url() - - timeout = aiohttp.ClientTimeout(total=self.config.timeout) if self.config.timeout > 0 else None + return self._async_stream_chat_completion(data, azure_headers) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(url, headers=headers, json=data) as response: - response.raise_for_status() - return await response.json() + # 使用父类的 async_post 方法 + response = await self.async_post("", data=data, headers=azure_headers) + return response.json() - def _stream_chat_completion(self, data: Dict[str, Any], headers: Dict[str, str]) -> Generator[StreamResponse, None, None]: + def _stream_chat_completion(self, data: Dict[str, Any], azure_headers: Dict[str, str]) -> Generator[StreamResponse, None, None]: """Azure 同步流式聊天完成""" - import requests + client = self._get_sync_client() - url = self.config.get_request_url() + # 合并请求头 + merged_headers = self.config.headers.copy() + merged_headers.update(azure_headers) - with requests.post(url, headers=headers, json=data, stream=True) as response: + with client.stream( + "POST", + "", # 空字符串,因为 base_url 已经是完整地址 + json=data, + headers=merged_headers + ) as response: response.raise_for_status() for line in response.iter_lines(): @@ -692,49 +644,61 @@ def _stream_chat_completion(self, data: Dict[str, Any], headers: Dict[str, str]) self.logger.error(f"Error processing stream chunk: {e}") break - async def _async_stream_chat_completion(self, data: Dict[str, Any], headers: Dict[str, str]) -> AsyncGenerator[StreamResponse, None]: + async def _async_stream_chat_completion(self, data: Dict[str, Any], azure_headers: Dict[str, str]) -> AsyncGenerator[StreamResponse, None]: """Azure 异步流式聊天完成""" - import aiohttp + client = self._get_async_client() - url = self.config.get_request_url() - timeout = aiohttp.ClientTimeout(total=self.config.timeout) if self.config.timeout > 0 else None + # 合并请求头 + merged_headers = self.config.headers.copy() + merged_headers.update(azure_headers) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(url, headers=headers, json=data) as response: - response.raise_for_status() + async with client.stream( + "POST", + "", # 空字符串,因为 base_url 已经是完整地址 + json=data, + headers=merged_headers + ) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + if not line: + continue - async for line in response.content: - if not line: - continue - - line_str = line.decode('utf-8').strip() - - if not line_str.startswith('data: '): - continue - - data_str = line_str[6:].strip() + line_str = line.decode('utf-8').strip() + + if not line_str.startswith('data: '): + continue + + data_str = line_str[6:].strip() + + if data_str == '[DONE]': + break + + if not data_str: + continue + + try: + chunk_data = json.loads(data_str) + stream_response = self._process_stream_chunk(chunk_data) + yield stream_response - if data_str == '[DONE]': + if stream_response.is_finished: break - - if not data_str: - continue - - try: - chunk_data = json.loads(data_str) - stream_response = self._process_stream_chunk(chunk_data) - yield stream_response - if stream_response.is_finished: - break - - except json.JSONDecodeError as e: - self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") - continue - except Exception as e: - self.logger.error(f"Error processing stream chunk: {e}") - break - - def _process_stream_chunk(self, chunk_data: Dict[str, Any]) -> StreamResponse: - """处理流式响应的单个数据块""" - return StreamResponse(chunk_data) + except json.JSONDecodeError as e: + self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") + continue + except Exception as e: + self.logger.error(f"Error processing stream chunk: {e}") + break + + # 简化请求方法 + def simple_request(self, input_text: Union[str, List[Dict[str, str]]], model_name: Optional[str] = None) -> str: + """简化的请求方法,直接返回内容字符串""" + if isinstance(input_text, str): + messages = [{"role": "user", "content": input_text}] + else: + messages = input_text + + response = self.chat_completion(messages, model=model_name) + return response['choices'][0]['message']['content'] \ No newline at end of file diff --git a/tests/test_azureclient.py b/tests/test_azureclient.py index 2108fbe..770cfec 100644 --- a/tests/test_azureclient.py +++ b/tests/test_azureclient.py @@ -1,18 +1,36 @@ import pytest import json -import os from unittest.mock import Mock, patch, MagicMock from typing import Dict, Any +import hashlib from chattool.core.config import AzureOpenAIConfig from chattool.core.request import AzureOpenAIClient +@pytest.fixture(scope="session") +def azure_config(): + """Azure OpenAI 配置 fixture""" + return AzureOpenAIConfig( + api_key="test-azure-key", + api_base="https://test-resource.openai.azure.com", + api_version="2024-02-15-preview", + model="gpt-35-turbo", + temperature=0.7 + ) + + +@pytest.fixture +def azure_client(azure_config): + """Azure OpenAI 客户端 fixture""" + return AzureOpenAIClient(azure_config) + + @pytest.fixture def mock_azure_response_data(): """模拟的 Azure OpenAI API 响应数据""" return { - "id": "chatcmpl-azure-test123", + "id": "chatcmpl-azure123", "object": "chat.completion", "created": 1677652288, "model": "gpt-35-turbo", @@ -21,15 +39,15 @@ def mock_azure_response_data(): "index": 0, "message": { "role": "assistant", - "content": "Hello! This is from Azure OpenAI." + "content": "Hello from Azure! How can I help you today?" }, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 12, - "completion_tokens": 8, - "total_tokens": 20 + "completion_tokens": 11, + "total_tokens": 23 } } @@ -39,7 +57,7 @@ def mock_azure_stream_response_data(): """模拟的 Azure 流式响应数据""" return [ { - "id": "chatcmpl-azure-test123", + "id": "chatcmpl-azure123", "object": "chat.completion.chunk", "created": 1677652288, "model": "gpt-35-turbo", @@ -52,27 +70,27 @@ def mock_azure_stream_response_data(): ] }, { - "id": "chatcmpl-azure-test123", + "id": "chatcmpl-azure123", "object": "chat.completion.chunk", "created": 1677652288, "model": "gpt-35-turbo", "choices": [ { "index": 0, - "delta": {"content": "Hello"}, + "delta": {"content": "Hello from Azure"}, "finish_reason": None } ] }, { - "id": "chatcmpl-azure-test123", + "id": "chatcmpl-azure123", "object": "chat.completion.chunk", "created": 1677652288, "model": "gpt-35-turbo", "choices": [ { "index": 0, - "delta": {"content": " from Azure!"}, + "delta": {"content": "!"}, "finish_reason": "stop" } ] @@ -83,81 +101,47 @@ def mock_azure_stream_response_data(): class TestAzureOpenAIConfig: """测试 Azure OpenAI 配置类""" - def test_config_initialization(self): - """测试配置初始化""" - config = AzureOpenAIConfig( - api_key="test-key", - api_base="https://test.openai.azure.com", - api_version="2024-02-15-preview", - model="gpt-35-turbo" - ) - assert config.api_key == "test-key" - assert config.api_base == "https://test.openai.azure.com" - assert config.api_version == "2024-02-15-preview" - assert config.model == "gpt-35-turbo" - assert config.headers["Content-Type"] == "application/json" + def test_config_initialization_with_defaults(self): + """测试使用默认值初始化配置""" + config = AzureOpenAIConfig() + # assert config.api_base # 应该有默认值或从环境变量获取 + assert 'Content-Type' in config.headers - def test_config_with_custom_parameters(self): - """测试自定义参数配置""" + def test_config_initialization_with_custom_values(self): + """测试自定义配置值""" config = AzureOpenAIConfig( + api_key="custom-key", + api_base="https://custom.openai.azure.com", + api_version="2024-01-01", + model="gpt-4", temperature=0.5, - max_tokens=1000, - top_p=0.9 + max_tokens=2000 ) - assert config.temperature == 0.5 - assert config.max_tokens == 1000 - assert config.top_p == 0.9 - - @patch.dict(os.environ, { - "AZURE_OPENAI_API_KEY": "env-key", - "AZURE_OPENAI_ENDPOINT": "https://env-resource.openai.azure.com", - "AZURE_OPENAI_API_VERSION": "2024-03-01-preview", - "AZURE_OPENAI_API_MODEL": "gpt-4" - }) - def test_config_from_environment(self): - """测试从环境变量获取配置""" - config = AzureOpenAIConfig() - assert config.api_key == "env-key" - assert config.api_base == "https://env-resource.openai.azure.com" - assert config.api_version == "2024-03-01-preview" + assert "custom-key" in config.api_base # api_base 应该包含 ak 参数 + assert "2024-01-01" in config.api_base # api_base 应该包含 api-version assert config.model == "gpt-4" + assert config.temperature == 0.5 + assert config.max_tokens == 2000 - def test_get_request_url_basic(self): - """测试基础请求 URL 构建""" - config = AzureOpenAIConfig( - api_base="https://test.openai.azure.com", - api_version="2024-02-15-preview", - api_key="test-key" - ) - url = config.get_request_url() - expected = "https://test.openai.azure.com?api-version=2024-02-15-preview&ak=test-key" - assert url == expected - - def test_get_request_url_no_version(self): - """测试没有 API 版本的 URL 构建""" - config = AzureOpenAIConfig( - api_base="https://test.openai.azure.com", - api_key="test-key" - ) - url = config.get_request_url() - expected = "https://test.openai.azure.com&ak=test-key" - assert url == expected - - def test_get_request_url_no_key(self): - """测试没有 API 密钥的 URL 构建""" + def test_config_api_base_construction(self): + """测试 API 基础 URL 的构建""" config = AzureOpenAIConfig( - api_base="https://test.openai.azure.com", + api_key="test-key", + api_base="https://test.openai.azure.com/", # 带尾部斜杠 api_version="2024-02-15-preview" ) - url = config.get_request_url() - expected = "https://test.openai.azure.com?api-version=2024-02-15-preview" - assert url == expected + + # 验证 URL 构建正确 + assert config.api_base.startswith("https://test.openai.azure.com") + assert "api-version=2024-02-15-preview" in config.api_base + assert "ak=test-key" in config.api_base + assert not config.api_base.endswith("//") # 不应该有双斜杠 def test_config_get_method(self): """测试配置的 get 方法""" - config = AzureOpenAIConfig(temperature=0.7, api_version="2024-02-15-preview") - assert config.get('temperature') == 0.7 - assert config.get('api_version') == "2024-02-15-preview" + config = AzureOpenAIConfig(temperature=0.8, max_tokens=1500) + assert config.get('temperature') == 0.8 + assert config.get('max_tokens') == 1500 assert config.get('nonexistent') is None assert config.get('nonexistent', 'default') == 'default' @@ -169,8 +153,9 @@ def test_client_initialization(self, azure_config): """测试客户端初始化""" client = AzureOpenAIClient(azure_config) assert client.config == azure_config - assert hasattr(client, '_config_only_attrs') - assert 'api_version' in client._config_only_attrs + assert isinstance(client.config, AzureOpenAIConfig) + assert client._sync_client is None + assert client._async_client is None def test_client_with_default_config(self): """测试使用默认配置初始化客户端""" @@ -178,107 +163,80 @@ def test_client_with_default_config(self): assert isinstance(client.config, AzureOpenAIConfig) def test_generate_log_id(self, azure_client): - """测试 Log ID 生成""" + """测试 log ID 生成""" messages = [{"role": "user", "content": "Hello"}] - log_id_1 = azure_client._generate_log_id(messages) - log_id_2 = azure_client._generate_log_id(messages) + log_id = azure_client._generate_log_id(messages) + + # 验证是 SHA256 哈希 + assert len(log_id) == 64 + assert all(c in '0123456789abcdef' for c in log_id) - # 相同输入应该生成相同的 log ID - assert log_id_1 == log_id_2 - assert len(log_id_1) == 64 # SHA256 输出长度 + # 相同输入应该产生相同的 log ID + log_id2 = azure_client._generate_log_id(messages) + assert log_id == log_id2 - # 不同输入应该生成不同的 log ID + # 不同输入应该产生不同的 log ID different_messages = [{"role": "user", "content": "Hi"}] - log_id_3 = azure_client._generate_log_id(different_messages) - assert log_id_1 != log_id_3 + log_id3 = azure_client._generate_log_id(different_messages) + assert log_id != log_id3 def test_build_azure_headers(self, azure_client): """测试 Azure 请求头构建""" messages = [{"role": "user", "content": "Hello"}] headers = azure_client._build_azure_headers(messages) - assert "Content-Type" in headers - assert headers["Content-Type"] == "application/json" - assert "X-TT-LOGID" in headers - assert len(headers["X-TT-LOGID"]) == 64 - - def test_build_chat_data_basic(self, azure_client): - """测试基础聊天数据构建""" - messages = [{"role": "user", "content": "Hello"}] - data = azure_client._build_chat_data(messages) - - assert data["messages"] == messages - assert data["stream"] is False # 默认为 False - if azure_client.config.model: - assert "model" in data - - def test_build_chat_data_with_kwargs(self, azure_client): - """测试带 kwargs 的聊天数据构建""" - messages = [{"role": "user", "content": "Hello"}] - data = azure_client._build_chat_data( - messages, - temperature=0.5, - max_tokens=100, - stream=True - ) - - assert data["messages"] == messages - assert data["temperature"] == 0.5 - assert data["max_tokens"] == 100 - assert data["stream"] is True + assert 'Content-Type' in headers + assert headers['Content-Type'] == 'application/json' + assert 'X-TT-LOGID' in headers + assert len(headers['X-TT-LOGID']) == 64 # SHA256 长度 - def test_build_chat_data_excludes_config_attrs(self, azure_client): - """测试构建数据时排除配置专用属性""" + def test_build_chat_data_excludes_azure_attrs(self, azure_client): + """测试构建数据时排除 Azure 专用属性""" messages = [{"role": "user", "content": "Hello"}] data = azure_client._build_chat_data(messages) # 这些属性不应该出现在 API 请求数据中 for attr in azure_client._config_only_attrs: assert attr not in data + + # 但应该包含必要的字段 + assert data["messages"] == messages + assert "model" in data # 应该从 config 中获取 - def test_get_param_value_priority(self, azure_client): - """测试参数值优先级""" + def test_parameter_priority_azure(self, azure_client): + """测试 Azure 客户端的参数优先级""" # kwargs 优先于 config - kwargs = {"temperature": 0.8} + kwargs = {"temperature": 0.9, "max_tokens": 500} value = azure_client._get_param_value("temperature", kwargs) - assert value == 0.8 + assert value == 0.9 # 没有 kwargs 时使用 config - azure_client.config.temperature = 0.5 - value = azure_client._get_param_value("temperature", {}) - assert value == 0.5 - - # 都没有时返回 None - value = azure_client._get_param_value("nonexistent", {}) - assert value is None + value = azure_client._get_param_value("model", {}) + assert value == azure_client.config.model - @patch('requests.post') + @patch('chattool.core.request.AzureOpenAIClient.post') def test_chat_completion_sync(self, mock_post, azure_client, mock_azure_response_data): - """测试同步聊天完成""" + """测试 Azure 同步聊天完成""" # 设置 mock mock_response = Mock() mock_response.json.return_value = mock_azure_response_data - mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response - messages = [{"role": "user", "content": "Hello"}] + messages = [{"role": "user", "content": "Hello Azure"}] result = azure_client.chat_completion(messages) # 验证调用 mock_post.assert_called_once() call_args = mock_post.call_args - # 验证 URL 包含 Azure 特定格式 - url = call_args[0][0] - assert "api-version" in url - assert "ak=" in url + # 验证使用空字符串作为 endpoint(因为 api_base 已经是完整地址) + assert call_args[0][0] == "" # 验证请求数据 - assert "json" in call_args[1] - json_data = call_args[1]["json"] - assert json_data["messages"] == messages + assert "data" in call_args[1] + assert call_args[1]["data"]["messages"] == messages - # 验证请求头 + # 验证 Azure 特殊请求头 assert "headers" in call_args[1] headers = call_args[1]["headers"] assert "X-TT-LOGID" in headers @@ -286,185 +244,173 @@ def test_chat_completion_sync(self, mock_post, azure_client, mock_azure_response # 验证返回值 assert result == mock_azure_response_data - @patch('aiohttp.ClientSession.post') + @patch('chattool.core.request.AzureOpenAIClient.async_post') @pytest.mark.asyncio - async def test_chat_completion_async(self, mock_post, azure_client, mock_azure_response_data): - """测试异步聊天完成""" + async def test_chat_completion_async(self, mock_async_post, azure_client, mock_azure_response_data): + """测试 Azure 异步聊天完成""" # 设置 mock mock_response = Mock() - mock_response.json = Mock(return_value=mock_azure_response_data) - mock_response.raise_for_status = Mock() - - # 设置异步上下文管理器 - mock_context = Mock() - mock_context.__aenter__ = Mock(return_value=mock_response) - mock_context.__aexit__ = Mock(return_value=None) - mock_post.return_value = mock_context - - # 模拟 ClientSession - with patch('aiohttp.ClientSession') as mock_session: - mock_session_instance = Mock() - mock_session_instance.post.return_value = mock_context - mock_session_instance.__aenter__ = Mock(return_value=mock_session_instance) - mock_session_instance.__aexit__ = Mock(return_value=None) - mock_session.return_value = mock_session_instance - - messages = [{"role": "user", "content": "Hello"}] - result = await azure_client.async_chat_completion(messages) - - # 验证返回值 - assert result == mock_azure_response_data + mock_response.json.return_value = mock_azure_response_data + mock_async_post.return_value = mock_response + + messages = [{"role": "user", "content": "Hello Azure Async"}] + result = await azure_client.async_chat_completion(messages) + + # 验证调用 + mock_async_post.assert_called_once() + call_args = mock_async_post.call_args + + # 验证使用空字符串作为 endpoint + assert call_args[0][0] == "" + + # 验证请求数据和头 + assert "data" in call_args[1] + assert "headers" in call_args[1] + assert "X-TT-LOGID" in call_args[1]["headers"] + + # 验证返回值 + assert result == mock_azure_response_data - def test_chat_completion_parameter_override(self, azure_client): - """测试参数覆盖功能""" + def test_parameter_override_azure(self, azure_client): + """测试 Azure 客户端参数覆盖功能""" messages = [{"role": "user", "content": "Hello"}] - with patch('requests.post') as mock_post: + with patch.object(azure_client, 'post') as mock_post: mock_response = Mock() mock_response.json.return_value = {} - mock_response.raise_for_status.return_value = None mock_post.return_value = mock_response # 调用时覆盖配置参数 azure_client.chat_completion( messages, model="gpt-4", - temperature=0.2, - custom_param="test" + temperature=0.1, + azure_custom_param="test" ) # 验证数据中包含覆盖的参数 - call_args = mock_post.call_args - json_data = call_args[1]["json"] - assert json_data["model"] == "gpt-4" - assert json_data["temperature"] == 0.2 - assert json_data["custom_param"] == "test" + call_data = mock_post.call_args[1]["data"] + assert call_data["model"] == "gpt-4" + assert call_data["temperature"] == 0.1 + assert call_data["azure_custom_param"] == "test" - @patch('requests.post') - def test_simple_request(self, mock_post, azure_client, mock_azure_response_data): - """测试简化的请求方法""" - # 设置 mock - mock_response = Mock() - mock_response.json.return_value = mock_azure_response_data - mock_response.raise_for_status.return_value = None - mock_post.return_value = mock_response - - # 测试字符串输入 - result = azure_client.simple_request("Hello, how are you?") - expected_content = mock_azure_response_data['choices'][0]['message']['content'] - assert result == expected_content + @patch('chattool.core.request.AzureOpenAIClient._stream_chat_completion') + def test_chat_completion_stream_mode(self, mock_stream, azure_client): + """测试 Azure 流式模式""" + messages = [{"role": "user", "content": "Hello"}] + mock_stream.return_value = iter([]) - # 验证自动转换为消息格式 - call_args = mock_post.call_args - json_data = call_args[1]["json"] - expected_messages = [{"role": "user", "content": "Hello, how are you?"}] - assert json_data["messages"] == expected_messages - - @patch('requests.post') - def test_simple_request_with_messages(self, mock_post, azure_client, mock_azure_response_data): - """测试简化请求方法使用消息列表""" - mock_response = Mock() - mock_response.json.return_value = mock_azure_response_data - mock_response.raise_for_status.return_value = None - mock_post.return_value = mock_response + result = azure_client.chat_completion(messages, stream=True) - messages = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - {"role": "user", "content": "How are you?"} - ] + # 验证调用了流式方法 + mock_stream.assert_called_once() + call_args = mock_stream.call_args - result = azure_client.simple_request(messages, model_name="gpt-4") - expected_content = mock_azure_response_data['choices'][0]['message']['content'] - assert result == expected_content + # 验证数据包含 stream=True + call_data = call_args[0][0] + assert call_data["stream"] is True - # 验证消息直接传递 - call_args = mock_post.call_args - json_data = call_args[1]["json"] - assert json_data["messages"] == messages + # 验证传递了 Azure 请求头 + azure_headers = call_args[0][1] + assert "X-TT-LOGID" in azure_headers + + def test_config_only_attrs_azure(self, azure_client): + """测试 Azure 配置专用属性列表""" + expected_attrs = { + 'api_key', 'api_base', 'api_version', + 'headers', 'timeout', 'max_retries', 'retry_delay' + } + assert azure_client._config_only_attrs == expected_attrs class TestAzureParameterHandling: """测试 Azure 参数处理逻辑""" - def test_none_values_excluded(self, azure_client): - """测试 None 值被排除""" + def test_azure_none_values_excluded(self, azure_client): + """测试 Azure 客户端排除 None 值""" messages = [{"role": "user", "content": "Hello"}] data = azure_client._build_chat_data( messages, temperature=None, max_tokens=100, - top_p=None + api_version=None # Azure 特有参数 ) assert "temperature" not in data - assert "top_p" not in data + assert "api_version" not in data assert data["max_tokens"] == 100 - def test_config_fallback(self, azure_client): - """测试配置回退机制""" + def test_azure_config_fallback(self, azure_client): + """测试 Azure 配置回退机制""" # 设置一些配置值 - azure_client.config.temperature = 0.7 - azure_client.config.custom_param = "config_value" + azure_client.config.temperature = 0.8 + azure_client.config.azure_custom_param = "azure_value" messages = [{"role": "user", "content": "Hello"}] data = azure_client._build_chat_data(messages) # 应该使用配置中的值 - assert data["temperature"] == 0.7 - assert data["custom_param"] == "config_value" + assert data["temperature"] == 0.8 + assert data["azure_custom_param"] == "azure_value" + + # 但不应该包含配置专用属性 + assert "api_version" not in data + assert "api_key" not in data - def test_kwargs_override_config(self, azure_client): - """测试 kwargs 覆盖配置""" + def test_azure_kwargs_override_config(self, azure_client): + """测试 Azure 客户端 kwargs 覆盖配置""" # 设置配置值 azure_client.config.temperature = 0.7 + azure_client.config.max_tokens = 1000 messages = [{"role": "user", "content": "Hello"}] data = azure_client._build_chat_data( messages, - temperature=0.2 # 覆盖配置 + temperature=0.2, # 覆盖配置 + custom_azure_param="override_value" ) # 应该使用 kwargs 中的值 assert data["temperature"] == 0.2 + assert data["custom_azure_param"] == "override_value" + # max_tokens 应该使用配置中的值 + assert data["max_tokens"] == 1000 - def test_stream_default_false(self, azure_client): - """测试 stream 参数默认为 False""" - messages = [{"role": "user", "content": "Hello"}] - data = azure_client._build_chat_data(messages) - - assert data["stream"] is False - - def test_unknown_parameters_included(self, azure_client): - """测试未知参数也会被包含""" + def test_azure_unknown_parameters_included(self, azure_client): + """测试 Azure 客户端包含未知参数""" messages = [{"role": "user", "content": "Hello"}] data = azure_client._build_chat_data( messages, - future_param="new_feature", - another_param={"complex": "value"} + azure_future_param="new_azure_feature", + custom_header_param={"complex": "azure_value"} ) - assert data["future_param"] == "new_feature" - assert data["another_param"] == {"complex": "value"} + assert data["azure_future_param"] == "new_azure_feature" + assert data["custom_header_param"] == {"complex": "azure_value"} -class TestAzureURLConstruction: - """测试 Azure URL 构建逻辑""" +class TestAzureIntegration: + """测试 Azure 集成功能""" - def test_url_with_all_params(self): - """测试包含所有参数的 URL 构建""" - config = AzureOpenAIConfig( - api_base="https://test.openai.azure.com/", # 测试末尾斜杠处理 - api_version="2024-02-15-preview", - api_key="test-key" - ) - url = config.get_request_url() - expected = "https://test.openai.azure.com?api-version=2024-02-15-preview&ak=test-key" - assert url == expected + def test_log_id_consistency(self, azure_client): + """测试 log ID 在同一请求中的一致性""" + messages = [{"role": "user", "content": "Test consistency"}] + + # 多次调用应该产生相同的 log ID + headers1 = azure_client._build_azure_headers(messages) + headers2 = azure_client._build_azure_headers(messages) + + assert headers1['X-TT-LOGID'] == headers2['X-TT-LOGID'] - def test_url_minimal_params(self): - """测试最少参数的 URL 构建""" - config = AzureOpenAIConfig(api_base="https://test.openai.azure.com") - url = config.get_request_url() - expected = "https://test.openai.azure.com" - assert url == expected + def test_azure_vs_openai_client_compatibility(self, azure_client): + """测试 Azure 客户端与 OpenAI 客户端的 API 兼容性""" + # Azure 客户端应该有与 OpenAI 客户端相同的主要方法 + assert hasattr(azure_client, 'chat_completion') + assert hasattr(azure_client, 'async_chat_completion') + assert hasattr(azure_client, '_build_chat_data') + assert hasattr(azure_client, '_get_param_value') + + # 但有 Azure 特有的方法 + assert hasattr(azure_client, '_generate_log_id') + assert hasattr(azure_client, '_build_azure_headers') + assert hasattr(azure_client, 'simple_request') From dcca84129492da143ea8fe17be12fc77d01adb97 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Mon, 30 Jun 2025 04:51:23 +0800 Subject: [PATCH 4/7] restore old response --- chattool/chattype.py | 2 +- chattool/core/request.py | 14 ++-- chattool/response.py | 141 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 7 deletions(-) create mode 100644 chattool/response.py diff --git a/chattool/chattype.py b/chattool/chattype.py index d8d28df..897d184 100644 --- a/chattool/chattype.py +++ b/chattool/chattype.py @@ -2,7 +2,7 @@ from typing import List, Dict, Union import chattool -from .core.response import Resp +from .response import Resp from .request import chat_completion, valid_models, curl_cmd_of_chat_completion import time, random, json, warnings import aiohttp diff --git a/chattool/core/request.py b/chattool/core/request.py index 01bbcf6..1e1aba5 100644 --- a/chattool/core/request.py +++ b/chattool/core/request.py @@ -277,6 +277,7 @@ def chat_completion( top_p: Optional[float] = None, max_tokens: Optional[int] = None, stream: bool = False, + uri: str='/chat/completions', **kwargs ) -> Union[Dict[str, Any], Generator[StreamResponse, None, None]]: """ @@ -311,7 +312,7 @@ def chat_completion( if data.get('stream'): return self._stream_chat_completion(data) - response = self.post("/chat/completions", data=data) + response = self.post(uri, data=data) return response.json() async def async_chat_completion( @@ -322,6 +323,7 @@ async def async_chat_completion( top_p: Optional[float] = None, max_tokens: Optional[int] = None, stream: bool = False, + uri: str='/chat/completions', **kwargs ) -> Union[Dict[str, Any], AsyncGenerator[StreamResponse, None]]: """ @@ -356,16 +358,16 @@ async def async_chat_completion( if data.get('stream'): return self._async_stream_chat_completion(data) - response = await self.async_post("/chat/completions", data=data) + response = await self.async_post(uri, data=data) return response.json() - def _stream_chat_completion(self, data: Dict[str, Any]) -> Generator[StreamResponse, None, None]: + def _stream_chat_completion(self, data: Dict[str, Any], uri:str='/chat/completions') -> Generator[StreamResponse, None, None]: """同步流式聊天完成 - 返回生成器""" client = self._get_sync_client() with client.stream( "POST", - "/chat/completions", + uri, json=data, headers=self.config.headers ) as response: @@ -410,13 +412,13 @@ def _stream_chat_completion(self, data: Dict[str, Any]) -> Generator[StreamRespo self.logger.error(f"Error processing stream chunk: {e}") break - async def _async_stream_chat_completion(self, data: Dict[str, Any]) -> AsyncGenerator[StreamResponse, None]: + async def _async_stream_chat_completion(self, data: Dict[str, Any], uri='/chat/completions') -> AsyncGenerator[StreamResponse, None]: """异步流式聊天完成 - 返回异步生成器""" client = self._get_async_client() async with client.stream( "POST", - "/chat/completions", + uri, json=data, headers=self.config.headers ) as response: diff --git a/chattool/response.py b/chattool/response.py new file mode 100644 index 0000000..d3fb251 --- /dev/null +++ b/chattool/response.py @@ -0,0 +1,141 @@ +# Response class for Chattool + +from typing import Dict, Any, Union +from .tokencalc import findcost +import chattool + +class Resp(): + + def __init__(self, response:Union[Dict, Any]) -> None: + if isinstance(response, Dict): + self.response = response + self._raw_response = None + else: + self._raw_response = response + self.response = response.json() + + def get_curl(self): + """Convert the response to a cURL command""" + if self._raw_response is None: + return "No cURL command available" + return chattool.resp2curl(self._raw_response) + + def print_curl(self): + """Print the cURL command""" + print(self.get_curl()) + + def is_valid(self): + """Check if the response is an error""" + return 'error' not in self.response + + def cost(self): + """Calculate the cost of the response(Deprecated)""" + return findcost(self.model, self.prompt_tokens, self.completion_tokens) + + @property + def id(self): + return self['id'] + + @property + def model(self): + return self['model'] + + @property + def created(self): + return self['created'] + + @property + def usage(self): + """Token usage""" + return self['usage'] + + @property + def total_tokens(self): + """Total number of tokens""" + return self.usage['total_tokens'] + + @property + def prompt_tokens(self): + """Number of tokens in the prompt""" + return self.usage['prompt_tokens'] + + @property + def completion_tokens(self): + """Number of tokens of the response""" + return self.usage['completion_tokens'] + + @property + def message(self): + """Message""" + return self['choices'][0]['message'] + + @property + def content(self): + """Content of the response""" + return self.message['content'] + + @property + def function_call(self): + """Function call""" + return self.message.get('function_call') + + @property + def tool_calls(self): + """Tool calls""" + return self.message.get('tool_calls') + + @property + def delta(self): + """Delta""" + return self['choices'][0]['delta'] + + @property + def delta_content(self): + """Content of stream response""" + return self.delta['content'] + + @property + def object(self): + return self['object'] + + @property + def error(self): + """Error""" + return self['error'] + + @property + def error_message(self): + """Error message""" + return self.error['message'] + + @property + def error_type(self): + """Error type""" + return self.error['type'] + + @property + def error_param(self): + """Error parameter""" + return self.error['param'] + + @property + def error_code(self): + """Error code""" + return self.error['code'] + + @property + def finish_reason(self): + """Finish reason""" + return self['choices'][0].get('finish_reason') + + def __repr__(self) -> str: + return "" + + def __str__(self) -> str: + return self.content + + def __getitem__(self, key): + return self.response[key] + + def __contains__(self, key): + return key in self.response \ No newline at end of file From e2d0ca919052fece20e8be47b86847fd66d71fde Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Mon, 30 Jun 2025 04:58:53 +0800 Subject: [PATCH 5/7] simplify azure --- chattool/core/request.py | 390 +++++++++++++------------------------- tests/test_azureclient.py | 27 +-- 2 files changed, 130 insertions(+), 287 deletions(-) diff --git a/chattool/core/request.py b/chattool/core/request.py index 1e1aba5..ace32b9 100644 --- a/chattool/core/request.py +++ b/chattool/core/request.py @@ -263,12 +263,17 @@ def _build_chat_data(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str def _get_param_value(self, param_name: str, kwargs: Dict[str, Any]): """按优先级获取参数值:kwargs > config > None""" - # 优先使用 kwargs 中的值 if param_name in kwargs: return kwargs[param_name] - # 其次使用 config 中的值 return self.config.get(param_name) + def _prepare_headers(self, messages: List[Dict[str, str]], custom_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """准备请求头 - 子类可以重写此方法""" + headers = self.config.headers.copy() + if custom_headers: + headers.update(custom_headers) + return headers + def chat_completion( self, messages: List[Dict[str, str]], @@ -277,7 +282,8 @@ def chat_completion( top_p: Optional[float] = None, max_tokens: Optional[int] = None, stream: bool = False, - uri: str='/chat/completions', + uri: str = '/chat/completions', + headers: Optional[Dict[str, str]] = None, **kwargs ) -> Union[Dict[str, Any], Generator[StreamResponse, None, None]]: """ @@ -290,13 +296,11 @@ def chat_completion( top_p: top_p 参数 max_tokens: 最大token数 stream: 是否使用流式响应 + uri: 请求 URI + headers: 自定义请求头 **kwargs: 其他参数 - - Returns: - 如果 stream=False: 返回完整的响应字典 - 如果 stream=True: 返回 Generator,yield StreamResponse 对象 """ - # 将显式参数合并到 kwargs 中 + # 合并参数 all_kwargs = { 'model': model, 'temperature': temperature, @@ -306,13 +310,14 @@ def chat_completion( **kwargs } - # 使用统一的参数处理逻辑 + # 构建数据和请求头 data = self._build_chat_data(messages, **all_kwargs) + request_headers = self._prepare_headers(messages, headers) if data.get('stream'): - return self._stream_chat_completion(data) + return self._stream_chat_completion(data, uri, request_headers) - response = self.post(uri, data=data) + response = self.post(uri, data=data, headers=request_headers) return response.json() async def async_chat_completion( @@ -323,26 +328,11 @@ async def async_chat_completion( top_p: Optional[float] = None, max_tokens: Optional[int] = None, stream: bool = False, - uri: str='/chat/completions', + uri: str = '/chat/completions', + headers: Optional[Dict[str, str]] = None, **kwargs ) -> Union[Dict[str, Any], AsyncGenerator[StreamResponse, None]]: - """ - OpenAI Chat Completion API (异步版本) - - Args: - messages: 对话消息列表 - model: 模型名称 - temperature: 温度参数 - top_p: top_p 参数 - max_tokens: 最大token数 - stream: 是否使用流式响应 - **kwargs: 其他参数 - - Returns: - 如果 stream=False: 返回完整的响应字典 - 如果 stream=True: 返回 AsyncGenerator,async yield StreamResponse 对象 - """ - # 将显式参数合并到 kwargs 中 + """OpenAI Chat Completion API (异步版本)""" all_kwargs = { 'model': model, 'temperature': temperature, @@ -352,117 +342,103 @@ async def async_chat_completion( **kwargs } - # 使用统一的参数处理逻辑 data = self._build_chat_data(messages, **all_kwargs) + request_headers = self._prepare_headers(messages, headers) if data.get('stream'): - return self._async_stream_chat_completion(data) + return self._async_stream_chat_completion(data, uri, request_headers) - response = await self.async_post(uri, data=data) + response = await self.async_post(uri, data=data, headers=request_headers) return response.json() - def _stream_chat_completion(self, data: Dict[str, Any], uri:str='/chat/completions') -> Generator[StreamResponse, None, None]: - """同步流式聊天完成 - 返回生成器""" + def _stream_chat_completion( + self, + data: Dict[str, Any], + uri: str = '/chat/completions', + headers: Optional[Dict[str, str]] = None + ) -> Generator[StreamResponse, None, None]: + """同步流式聊天完成""" client = self._get_sync_client() + request_headers = headers or self.config.headers with client.stream( "POST", uri, json=data, - headers=self.config.headers + headers=request_headers ) as response: response.raise_for_status() - - for line in response.iter_lines(): - if not line: - continue - - # 处理 SSE 格式的数据 - line_str = line.decode('utf-8').strip() - - # 跳过非数据行 - if not line_str.startswith('data: '): - continue - - # 提取数据部分 - data_str = line_str[6:].strip() # 去掉 'data: ' 前缀 - - # 检查是否结束 - if data_str == '[DONE]': - break - - # 跳过空行 - if not data_str: - continue - - try: - # 解析 JSON - chunk_data = json.loads(data_str) - stream_response = self._process_stream_chunk(chunk_data) - yield stream_response - - # 如果完成,退出循环 - if stream_response.is_finished: - break - - except json.JSONDecodeError as e: - self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") - continue - except Exception as e: - self.logger.error(f"Error processing stream chunk: {e}") - break + yield from self._process_stream_response(response.iter_lines()) - async def _async_stream_chat_completion(self, data: Dict[str, Any], uri='/chat/completions') -> AsyncGenerator[StreamResponse, None]: - """异步流式聊天完成 - 返回异步生成器""" + async def _async_stream_chat_completion( + self, + data: Dict[str, Any], + uri: str = '/chat/completions', + headers: Optional[Dict[str, str]] = None + ) -> AsyncGenerator[StreamResponse, None]: + """异步流式聊天完成""" client = self._get_async_client() + request_headers = headers or self.config.headers async with client.stream( "POST", uri, json=data, - headers=self.config.headers + headers=request_headers ) as response: response.raise_for_status() + async for chunk in self._async_process_stream_response(response.aiter_lines()): + yield chunk + + def _process_stream_response(self, lines): + """处理流式响应行""" + for line in lines: + if not line: + continue - async for line in response.aiter_lines(): - if not line: - continue - - # 处理 SSE 格式的数据 - line_str = line.decode('utf-8').strip() - - # 跳过非数据行 - if not line_str.startswith('data: '): - continue - - # 提取数据部分 - data_str = line_str[6:].strip() # 去掉 'data: ' 前缀 - - # 检查是否结束 - if data_str == '[DONE]': + line_str = line.decode('utf-8').strip() + chunk = self._parse_stream_line(line_str) + if chunk: + yield chunk + if chunk.is_finished: break - - # 跳过空行 - if not data_str: - continue - - try: - # 解析 JSON - chunk_data = json.loads(data_str) - stream_response = self._process_stream_chunk(chunk_data) - yield stream_response - - # 如果完成,退出循环 - if stream_response.is_finished: - break - - except json.JSONDecodeError as e: - self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") - continue - except Exception as e: - self.logger.error(f"Error processing stream chunk: {e}") + + async def _async_process_stream_response(self, lines): + """异步处理流式响应行""" + async for line in lines: + if not line: + continue + + line_str = line.decode('utf-8').strip() + chunk = self._parse_stream_line(line_str) + if chunk: + yield chunk + if chunk.is_finished: break + def _parse_stream_line(self, line_str: str) -> Optional[StreamResponse]: + """解析单行流式响应""" + if not line_str.startswith('data: '): + return None + + data_str = line_str[6:].strip() + + if data_str == '[DONE]': + return None + + if not data_str: + return None + + try: + chunk_data = json.loads(data_str) + return self._process_stream_chunk(chunk_data) + except json.JSONDecodeError as e: + self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") + return None + except Exception as e: + self.logger.error(f"Error processing stream chunk: {e}") + return None + def embeddings( self, input_text: Union[str, List[str]], @@ -470,14 +446,12 @@ def embeddings( **kwargs ) -> Dict[str, Any]: """OpenAI Embeddings API""" - # 使用统一的参数处理逻辑 all_kwargs = { 'model': model or self.config.get('model', 'text-embedding-ada-002'), 'input': input_text, **kwargs } - # 构建数据,但排除 input 参数因为它已经单独处理 data = {} for key, value in all_kwargs.items(): if value is not None: @@ -493,14 +467,12 @@ async def async_embeddings( **kwargs ) -> Dict[str, Any]: """异步 OpenAI Embeddings API""" - # 使用统一的参数处理逻辑 all_kwargs = { 'model': model or self.config.get('model', 'text-embedding-ada-002'), 'input': input_text, **kwargs } - # 构建数据 data = {} for key, value in all_kwargs.items(): if value is not None: @@ -510,8 +482,20 @@ async def async_embeddings( return response.json() def _process_stream_chunk(self, chunk_data: Dict[str, Any]) -> StreamResponse: - """处理流式响应的单个数据块,返回 StreamResponse 对象""" + """处理流式响应的单个数据块""" return StreamResponse(chunk_data) + + # 通用的简化请求方法 + def simple_request(self, input_text: Union[str, List[Dict[str, str]]], model_name: Optional[str] = None, **kwargs) -> str: + """简化的请求方法,直接返回内容字符串""" + if isinstance(input_text, str): + messages = [{"role": "user", "content": input_text}] + else: + messages = input_text + + response = self.chat_completion(messages, model=model_name, **kwargs) + return response['choices'][0]['message']['content'] + class AzureOpenAIClient(OpenAIClient): """Azure OpenAI 客户端""" @@ -531,10 +515,14 @@ def _generate_log_id(self, messages: List[Dict[str, str]]) -> str: content = str(messages).encode() return hashlib.sha256(content).hexdigest() - def _build_azure_headers(self, messages: List[Dict[str, str]]) -> Dict[str, str]: - """构建 Azure 专用请求头""" + def _prepare_headers(self, messages: List[Dict[str, str]], custom_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """准备 Azure 专用请求头""" headers = self.config.headers.copy() headers['X-TT-LOGID'] = self._generate_log_id(messages) + + if custom_headers: + headers.update(custom_headers) + return headers def chat_completion( @@ -548,25 +536,17 @@ def chat_completion( **kwargs ) -> Union[Dict[str, Any], Generator[StreamResponse, None, None]]: """Azure OpenAI Chat Completion API (同步版本)""" - # 复用父类的参数处理逻辑 - all_kwargs = { - 'model': model, - 'temperature': temperature, - 'top_p': top_p, - 'max_tokens': max_tokens, - 'stream': stream, + # 调用父类方法,但使用空字符串作为 URI(因为 api_base 已经是完整地址) + return super().chat_completion( + messages=messages, + model=model, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + stream=stream, + uri="", # Azure 的 api_base 已经是完整地址 **kwargs - } - - data = self._build_chat_data(messages, **all_kwargs) - azure_headers = self._build_azure_headers(messages) - - if data.get('stream'): - return self._stream_chat_completion(data, azure_headers) - - # 使用父类的 post 方法,但传入 Azure 特殊的请求头 - response = self.post("", data=data, headers=azure_headers) - return response.json() + ) async def async_chat_completion( self, @@ -579,128 +559,14 @@ async def async_chat_completion( **kwargs ) -> Union[Dict[str, Any], AsyncGenerator[StreamResponse, None]]: """Azure OpenAI Chat Completion API (异步版本)""" - all_kwargs = { - 'model': model, - 'temperature': temperature, - 'top_p': top_p, - 'max_tokens': max_tokens, - 'stream': stream, + # 调用父类方法 + return await super().async_chat_completion( + messages=messages, + model=model, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + stream=stream, + uri="", **kwargs - } - - data = self._build_chat_data(messages, **all_kwargs) - azure_headers = self._build_azure_headers(messages) - - if data.get('stream'): - return self._async_stream_chat_completion(data, azure_headers) - - # 使用父类的 async_post 方法 - response = await self.async_post("", data=data, headers=azure_headers) - return response.json() - - def _stream_chat_completion(self, data: Dict[str, Any], azure_headers: Dict[str, str]) -> Generator[StreamResponse, None, None]: - """Azure 同步流式聊天完成""" - client = self._get_sync_client() - - # 合并请求头 - merged_headers = self.config.headers.copy() - merged_headers.update(azure_headers) - - with client.stream( - "POST", - "", # 空字符串,因为 base_url 已经是完整地址 - json=data, - headers=merged_headers - ) as response: - response.raise_for_status() - - for line in response.iter_lines(): - if not line: - continue - - line_str = line.decode('utf-8').strip() - - if not line_str.startswith('data: '): - continue - - data_str = line_str[6:].strip() - - if data_str == '[DONE]': - break - - if not data_str: - continue - - try: - chunk_data = json.loads(data_str) - stream_response = self._process_stream_chunk(chunk_data) - yield stream_response - - if stream_response.is_finished: - break - - except json.JSONDecodeError as e: - self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") - continue - except Exception as e: - self.logger.error(f"Error processing stream chunk: {e}") - break - - async def _async_stream_chat_completion(self, data: Dict[str, Any], azure_headers: Dict[str, str]) -> AsyncGenerator[StreamResponse, None]: - """Azure 异步流式聊天完成""" - client = self._get_async_client() - - # 合并请求头 - merged_headers = self.config.headers.copy() - merged_headers.update(azure_headers) - - async with client.stream( - "POST", - "", # 空字符串,因为 base_url 已经是完整地址 - json=data, - headers=merged_headers - ) as response: - response.raise_for_status() - - async for line in response.aiter_lines(): - if not line: - continue - - line_str = line.decode('utf-8').strip() - - if not line_str.startswith('data: '): - continue - - data_str = line_str[6:].strip() - - if data_str == '[DONE]': - break - - if not data_str: - continue - - try: - chunk_data = json.loads(data_str) - stream_response = self._process_stream_chunk(chunk_data) - yield stream_response - - if stream_response.is_finished: - break - - except json.JSONDecodeError as e: - self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") - continue - except Exception as e: - self.logger.error(f"Error processing stream chunk: {e}") - break - - # 简化请求方法 - def simple_request(self, input_text: Union[str, List[Dict[str, str]]], model_name: Optional[str] = None) -> str: - """简化的请求方法,直接返回内容字符串""" - if isinstance(input_text, str): - messages = [{"role": "user", "content": input_text}] - else: - messages = input_text - - response = self.chat_completion(messages, model=model_name) - return response['choices'][0]['message']['content'] \ No newline at end of file + ) \ No newline at end of file diff --git a/tests/test_azureclient.py b/tests/test_azureclient.py index 770cfec..46dfd22 100644 --- a/tests/test_azureclient.py +++ b/tests/test_azureclient.py @@ -180,16 +180,6 @@ def test_generate_log_id(self, azure_client): log_id3 = azure_client._generate_log_id(different_messages) assert log_id != log_id3 - def test_build_azure_headers(self, azure_client): - """测试 Azure 请求头构建""" - messages = [{"role": "user", "content": "Hello"}] - headers = azure_client._build_azure_headers(messages) - - assert 'Content-Type' in headers - assert headers['Content-Type'] == 'application/json' - assert 'X-TT-LOGID' in headers - assert len(headers['X-TT-LOGID']) == 64 # SHA256 长度 - def test_build_chat_data_excludes_azure_attrs(self, azure_client): """测试构建数据时排除 Azure 专用属性""" messages = [{"role": "user", "content": "Hello"}] @@ -311,8 +301,8 @@ def test_chat_completion_stream_mode(self, mock_stream, azure_client): assert call_data["stream"] is True # 验证传递了 Azure 请求头 - azure_headers = call_args[0][1] - assert "X-TT-LOGID" in azure_headers + # azure_headers = call_args[0][1] + # assert "X-TT-LOGID" in azure_headers def test_config_only_attrs_azure(self, azure_client): """测试 Azure 配置专用属性列表""" @@ -392,16 +382,6 @@ def test_azure_unknown_parameters_included(self, azure_client): class TestAzureIntegration: """测试 Azure 集成功能""" - def test_log_id_consistency(self, azure_client): - """测试 log ID 在同一请求中的一致性""" - messages = [{"role": "user", "content": "Test consistency"}] - - # 多次调用应该产生相同的 log ID - headers1 = azure_client._build_azure_headers(messages) - headers2 = azure_client._build_azure_headers(messages) - - assert headers1['X-TT-LOGID'] == headers2['X-TT-LOGID'] - def test_azure_vs_openai_client_compatibility(self, azure_client): """测试 Azure 客户端与 OpenAI 客户端的 API 兼容性""" # Azure 客户端应该有与 OpenAI 客户端相同的主要方法 @@ -409,8 +389,5 @@ def test_azure_vs_openai_client_compatibility(self, azure_client): assert hasattr(azure_client, 'async_chat_completion') assert hasattr(azure_client, '_build_chat_data') assert hasattr(azure_client, '_get_param_value') - # 但有 Azure 特有的方法 assert hasattr(azure_client, '_generate_log_id') - assert hasattr(azure_client, '_build_azure_headers') - assert hasattr(azure_client, 'simple_request') From 8f236f900efe30f1f17b35b297d60f93319e22ae Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Mon, 30 Jun 2025 05:23:18 +0800 Subject: [PATCH 6/7] fix azure config --- chattool/core/config.py | 6 ++++-- chattool/core/request.py | 17 ++--------------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/chattool/core/config.py b/chattool/core/config.py index d34a53a..e0be765 100644 --- a/chattool/core/config.py +++ b/chattool/core/config.py @@ -126,10 +126,12 @@ def __init__( # update api_base """获取带 API 密钥的请求 URL""" endpoint = self.api_base.rstrip('/') + sym = '?' if self.api_version: - endpoint = f"{endpoint}?api-version={self.api_version}" + endpoint = f"{endpoint}{sym}api-version={self.api_version}" + sym = '&' if self.api_key: - endpoint = f"{endpoint}&ak={self.api_key}" + endpoint = f"{endpoint}{sym}ak={self.api_key}" self.api_base = endpoint # Anthropic 配置示例 diff --git a/chattool/core/request.py b/chattool/core/request.py index ace32b9..81bdda4 100644 --- a/chattool/core/request.py +++ b/chattool/core/request.py @@ -97,7 +97,7 @@ def request( def _make_request(): return client.request( method=method, - url=endpoint, + url=endpoint if endpoint else self.config.api_base, json=data, params=params, headers=merged_headers, @@ -263,7 +263,7 @@ def _build_chat_data(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str def _get_param_value(self, param_name: str, kwargs: Dict[str, Any]): """按优先级获取参数值:kwargs > config > None""" - if param_name in kwargs: + if kwargs.get(param_name) is not None: return kwargs[param_name] return self.config.get(param_name) @@ -484,19 +484,6 @@ async def async_embeddings( def _process_stream_chunk(self, chunk_data: Dict[str, Any]) -> StreamResponse: """处理流式响应的单个数据块""" return StreamResponse(chunk_data) - - # 通用的简化请求方法 - def simple_request(self, input_text: Union[str, List[Dict[str, str]]], model_name: Optional[str] = None, **kwargs) -> str: - """简化的请求方法,直接返回内容字符串""" - if isinstance(input_text, str): - messages = [{"role": "user", "content": input_text}] - else: - messages = input_text - - response = self.chat_completion(messages, model=model_name, **kwargs) - return response['choices'][0]['message']['content'] - - class AzureOpenAIClient(OpenAIClient): """Azure OpenAI 客户端""" From 2833e8c1f4703c19c385fb61cd174cb613e6d469 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Tue, 1 Jul 2025 09:16:01 +0800 Subject: [PATCH 7/7] remove stream --- chattool/core/__init__.py | 2 + chattool/core/chattype.py | 506 ++++++++++++++++++++------------------ chattool/core/config.py | 34 ++- chattool/core/request.py | 282 +++++++++------------ tests/conftest.py | 8 +- tests/test_azureclient.py | 91 ------- tests/test_oaiclient.py | 60 ----- 7 files changed, 403 insertions(+), 580 deletions(-) diff --git a/chattool/core/__init__.py b/chattool/core/__init__.py index cc0173e..ae0cab0 100644 --- a/chattool/core/__init__.py +++ b/chattool/core/__init__.py @@ -1,2 +1,4 @@ from chattool.core.config import Config, OpenAIConfig, AzureOpenAIConfig from chattool.core.request import HTTPClient, OpenAIClient, AzureOpenAIClient +from chattool.core.response import ChatResponse +from chattool.core.chattype import Chat \ No newline at end of file diff --git a/chattool/core/chattype.py b/chattool/core/chattype.py index b771ccb..c0416bb 100644 --- a/chattool/core/chattype.py +++ b/chattool/core/chattype.py @@ -2,63 +2,55 @@ import json import os import time -import random import asyncio -from loguru import logger -from chattool.core.config import OpenAIConfig -from chattool.core.request import OpenAIClient, StreamResponse +from chattool.custom_logger import setup_logger +from chattool.core.config import OpenAIConfig, AzureOpenAIConfig +from chattool.core.request import OpenAIClient, AzureOpenAIClient from chattool.core.response import ChatResponse -class Chat(OpenAIClient): - """简化的 Chat 类 - 专注于基础对话功能""" +def Chat( + config: Optional[Union[OpenAIConfig, AzureOpenAIConfig]] = None, + logger: Optional[object] = None, + **kwargs +) -> 'ChatBase': + """ + Chat 工厂函数 - 根据配置类型自动选择正确的客户端 - def __init__( - self, - msg: Union[List[Dict], str, None] = None, - config: Optional[OpenAIConfig] = None, - **kwargs - ): - """ - 初始化 Chat 对象 - - Args: - msg: 初始消息,可以是字符串、消息列表或 None - config: OpenAI 配置对象 - **kwargs: 其他配置参数(会覆盖 config 中的设置) - """ - # 初始化配置 - if config is None: - config = OpenAIConfig() + Args: + config: 配置对象(OpenAIConfig 或 AzureOpenAIConfig) + logger: 日志实例 + **kwargs: 其他配置参数 - # 应用 kwargs 覆盖配置 - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - - super().__init__(config) + Returns: + ChatOpenAI 或 ChatAzure 实例 + """ + if config is None: + config = OpenAIConfig() + + logger = logger or setup_logger('ChatBase') + + if isinstance(config, AzureOpenAIConfig): + return ChatAzure(config=config, logger=logger, **kwargs) + elif isinstance(config, OpenAIConfig): + return ChatOpenAI(config=config, logger=logger, **kwargs) + else: + raise ValueError(f"不支持的配置类型: {type(config)}") + + +class ChatBase: + """Chat 基类 - 定义对话管理功能""" + + def __init__(self, config, logger=None, **kwargs): + self.config = config + self.logger = logger # 初始化对话历史 - self._chat_log = self._init_messages(msg) + self._chat_log: List[Dict] = [] self._last_response: Optional[ChatResponse] = None - def _init_messages(self, msg: Union[List[Dict], str, None]) -> List[Dict]: - """初始化消息列表""" - if msg is None: - return [] - elif isinstance(msg, str): - return [{"role": "user", "content": msg}] - elif isinstance(msg, list): - # 验证消息格式 - for m in msg: - if not isinstance(m, dict) or 'role' not in m: - raise ValueError("消息列表中的每个元素都必须是包含 'role' 键的字典") - return msg.copy() - else: - raise ValueError("msg 必须是字符串、字典列表或 None") - # === 消息管理 === - def add_message(self, role: str, content: str, **kwargs) -> 'Chat': + def add_message(self, role: str, content: str, **kwargs) -> 'ChatBase': """添加消息到对话历史""" if role not in ['user', 'assistant', 'system']: raise ValueError(f"role 必须是 'user', 'assistant' 或 'system',收到: {role}") @@ -67,19 +59,19 @@ def add_message(self, role: str, content: str, **kwargs) -> 'Chat': self._chat_log.append(message) return self - def user(self, content: str) -> 'Chat': + def user(self, content: str) -> 'ChatBase': """添加用户消息""" return self.add_message('user', content) - def assistant(self, content: str) -> 'Chat': + def assistant(self, content: str) -> 'ChatBase': """添加助手消息""" return self.add_message('assistant', content) - def system(self, content: str) -> 'Chat': + def system(self, content: str) -> 'ChatBase': """添加系统消息""" return self.add_message('system', content) - def clear(self) -> 'Chat': + def clear(self) -> 'ChatBase': """清空对话历史""" self._chat_log = [] self._last_response = None @@ -89,196 +81,16 @@ def pop(self, index: int = -1) -> Dict: """移除并返回指定位置的消息""" return self._chat_log.pop(index) - # === 核心对话功能 === - def get_response( - self, - max_retries: int = 3, - retry_delay: float = 1.0, - update_history: bool = True, - **options - ) -> ChatResponse: - """ - 获取对话响应(同步) - - Args: - max_retries: 最大重试次数 - retry_delay: 重试延迟 - update_history: 是否更新对话历史 - **options: 传递给 chat_completion 的其他参数 - """ - # 合并配置 - chat_options = { - "model": self.config.model, - "temperature": self.config.temperature, - **options - } - - last_error = None - for attempt in range(max_retries + 1): - try: - # 调用 OpenAI API - response_data = self.chat_completion( - messages=self._chat_log, - **chat_options - ) - - # 包装响应 - response = ChatResponse(response_data) - - # 验证响应 - if not response.is_valid(): - raise Exception(f"API 返回错误: {response.error_message}") - - # 更新历史记录 - if update_history and response.message: - self._chat_log.append(response.message) - - self._last_response = response - return response - - except Exception as e: - last_error = e - if attempt < max_retries: - self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") - time.sleep(retry_delay * (2 ** attempt)) # 指数退避 - else: - self.logger.error(f"请求在 {max_retries + 1} 次尝试后失败") - - raise last_error + # === 核心对话功能 - 子类实现 === + def get_response(self, **options) -> ChatResponse: + raise NotImplementedError("子类必须实现 get_response 方法") - async def async_get_response( - self, - max_retries: int = 3, - retry_delay: float = 1.0, - update_history: bool = True, - **options - ) -> ChatResponse: - """ - 获取对话响应(异步) - """ - chat_options = { - "model": self.config.model, - "temperature": self.config.temperature, - **options - } - - last_error = None - for attempt in range(max_retries + 1): - try: - response_data = await self.async_chat_completion( - messages=self._chat_log, - **chat_options - ) - - response = ChatResponse(response_data) - - if not response.is_valid(): - raise Exception(f"API 返回错误: {response.error_message}") - - if update_history and response.message: - self._chat_log.append(response.message) - - self._last_response = response - return response - - except Exception as e: - last_error = e - if attempt < max_retries: - self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") - await asyncio.sleep(retry_delay * (2 ** attempt)) - else: - self.logger.error(f"请求在 {max_retries + 1} 次尝试后失败") - - raise last_error - - # === 流式响应 === - def stream_response(self, **options) -> Generator[str, None, None]: - """ - 流式获取响应内容(同步) - 返回生成器,逐个 yield 内容片段 - """ - chat_options = { - "model": self.config.model, - "temperature": self.config.temperature, - **options - } - - full_content = "" - - try: - for stream_resp in self.chat_completion( - messages=self._chat_log, - stream=True, - **chat_options - ): - if stream_resp.has_content: - content = stream_resp.content - full_content += content - yield content - - if stream_resp.is_finished: - break - - # 更新历史记录 - if full_content: - self._chat_log.append({ - "role": "assistant", - "content": full_content - }) - - except Exception as e: - self.logger.error(f"流式响应失败: {e}") - raise - - async def async_stream_response(self, **options) -> AsyncGenerator[str, None]: - """ - 流式获取响应内容(异步) - """ - chat_options = { - "model": self.config.model, - "temperature": self.config.temperature, - **options - } - - full_content = "" - - try: - async for stream_resp in self.async_chat_completion( - messages=self._chat_log, - stream=True, - **chat_options - ): - if stream_resp.has_content: - content = stream_resp.content - full_content += content - yield content - - if stream_resp.is_finished: - break - - # 更新历史记录 - if full_content: - self._chat_log.append({ - "role": "assistant", - "content": full_content - }) - - except Exception as e: - self.logger.error(f"异步流式响应失败: {e}") - raise + async def async_get_response(self, **options) -> ChatResponse: + raise NotImplementedError("子类必须实现 async_get_response 方法") # === 便捷方法 === def ask(self, question: str, **options) -> str: - """ - 问答便捷方法 - - Args: - question: 问题 - **options: 传递给 get_response 的参数 - - Returns: - 回答内容 - """ + """问答便捷方法""" self.user(question) response = self.get_response(**options) return response.content @@ -292,12 +104,12 @@ async def async_ask(self, question: str, **options) -> str: # === 对话历史管理 === def save(self, path: str, mode: str = 'a', index: int = 0): """保存对话历史到文件""" - # 确保目录存在 os.makedirs(os.path.dirname(path) if os.path.dirname(path) else '.', exist_ok=True) data = { "index": index, "chat_log": self._chat_log, + "config_type": type(self.config).__name__, "config": { "model": self.config.model, "api_base": self.config.api_base @@ -308,24 +120,34 @@ def save(self, path: str, mode: str = 'a', index: int = 0): f.write(json.dumps(data, ensure_ascii=False) + '\n') @classmethod - def load(cls, path: str) -> 'Chat': + def load(cls, path: str) -> 'ChatBase': """从文件加载对话历史""" with open(path, 'r', encoding='utf-8') as f: data = json.loads(f.read()) - chat = cls(msg=data['chat_log']) + # 根据保存的配置类型创建相应的配置对象 + config_type = data.get('config_type', 'OpenAIConfig') + if config_type == 'AzureOpenAIConfig': + config = AzureOpenAIConfig() + else: + config = OpenAIConfig() - # 如果有配置信息,应用它们 + # 应用保存的配置 if 'config' in data: for key, value in data['config'].items(): - if hasattr(chat.config, key): - setattr(chat.config, key, value) + if hasattr(config, key): + setattr(config, key, value) + # 使用工厂函数创建正确的实例 + chat = Chat(config=config) + chat._chat_log = data['chat_log'] return chat - def copy(self) -> 'Chat': + def copy(self) -> 'ChatBase': """复制 Chat 对象""" - return Chat(msg=self._chat_log.copy(), config=self.config) + new_chat = Chat(config=self.config) + new_chat._chat_log = self._chat_log.copy() + return new_chat # === 显示和调试 === def print_log(self, sep: str = "\n" + "-" * 50 + "\n"): @@ -339,6 +161,7 @@ def get_debug_info(self) -> Dict[str, Any]: """获取调试信息""" return { "message_count": len(self._chat_log), + "config_type": type(self.config).__name__, "model": self.config.model, "api_base": self.config.api_base, "last_response": self._last_response.get_debug_info() if self._last_response else None @@ -378,7 +201,198 @@ def __getitem__(self, index) -> Dict: return self._chat_log[index] def __repr__(self) -> str: - return f"" + config_type = type(self.config).__name__ + return f"" def __str__(self) -> str: - return self.__repr__() \ No newline at end of file + return self.__repr__() + + +class ChatOpenAI(ChatBase, OpenAIClient): + """OpenAI Chat 实现""" + + def __init__(self, config=None, logger=None, **kwargs): + # 先初始化 OpenAIClient(底层 HTTP 客户端) + OpenAIClient.__init__(self, config, logger, **kwargs) + # 再初始化 ChatBase(对话管理功能) + ChatBase.__init__(self, config, logger, **kwargs) + + def get_response( + self, + max_retries: int = 3, + retry_delay: float = 1.0, + update_history: bool = True, + **options + ) -> ChatResponse: + """获取对话响应(同步)""" + chat_options = { + "model": self.config.model, + "temperature": getattr(self.config, 'temperature', None), + **options + } + + last_error = None + for attempt in range(max_retries + 1): + try: + response_data = self.chat_completion( + messages=self._chat_log, + **chat_options + ) + + response = ChatResponse(response_data) + + if not response.is_valid(): + raise Exception(f"API 返回错误: {response.error_message}") + + if update_history and response.message: + self._chat_log.append(response.message) + + self._last_response = response + return response + + except Exception as e: + last_error = e + if attempt < max_retries: + self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") + time.sleep(retry_delay * (2 ** attempt)) + else: + self.logger.error(f"请求在 {max_retries + 1} 次尝试后失败") + + raise last_error + + async def async_get_response( + self, + max_retries: int = 3, + retry_delay: float = 1.0, + update_history: bool = True, + **options + ) -> ChatResponse: + """获取对话响应(异步)""" + chat_options = { + "model": self.config.model, + "temperature": getattr(self.config, 'temperature', None), + **options + } + + last_error = None + for attempt in range(max_retries + 1): + try: + response_data = await self.async_chat_completion( + messages=self._chat_log, + **chat_options + ) + + response = ChatResponse(response_data) + + if not response.is_valid(): + raise Exception(f"API 返回错误: {response.error_message}") + + if update_history and response.message: + self._chat_log.append(response.message) + + self._last_response = response + return response + + except Exception as e: + last_error = e + if attempt < max_retries: + self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") + await asyncio.sleep(retry_delay * (2 ** attempt)) + else: + self.logger.error(f"请求在 {max_retries + 1} 次尝试后失败") + + raise last_error + +class ChatAzure(ChatBase, AzureOpenAIClient): + """Azure OpenAI Chat 实现 - 继承 ChatOpenAI 复用逻辑""" + + def __init__(self, config=None, logger=None, **kwargs): + # 替换为 AzureOpenAIClient 初始化 + AzureOpenAIClient.__init__(self, config, logger, **kwargs) + ChatBase.__init__(self, config, logger, **kwargs) + + def get_response( + self, + max_retries: int = 3, + retry_delay: float = 1.0, + update_history: bool = True, + **options + ) -> ChatResponse: + """获取对话响应(同步)""" + chat_options = { + "model": self.config.model, + "temperature": getattr(self.config, 'temperature', None), + **options + } + + last_error = None + for attempt in range(max_retries + 1): + try: + response_data = self.chat_completion( + messages=self._chat_log, + **chat_options + ) + + response = ChatResponse(response_data) + + if not response.is_valid(): + raise Exception(f"API 返回错误: {response.error_message}") + + if update_history and response.message: + self._chat_log.append(response.message) + + self._last_response = response + return response + + except Exception as e: + last_error = e + if attempt < max_retries: + self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") + time.sleep(retry_delay * (2 ** attempt)) + else: + self.logger.error(f"请求在 {max_retries + 1} 次尝试后失败") + + raise last_error + + async def async_get_response( + self, + max_retries: int = 3, + retry_delay: float = 1.0, + update_history: bool = True, + **options + ) -> ChatResponse: + """获取对话响应(异步)""" + chat_options = { + "model": self.config.model, + "temperature": getattr(self.config, 'temperature', None), + **options + } + + last_error = None + for attempt in range(max_retries + 1): + try: + response_data = await self.async_chat_completion( + messages=self._chat_log, + **chat_options + ) + + response = ChatResponse(response_data) + + if not response.is_valid(): + raise Exception(f"API 返回错误: {response.error_message}") + + if update_history and response.message: + self._chat_log.append(response.message) + + self._last_response = response + return response + + except Exception as e: + last_error = e + if attempt < max_retries: + self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") + await asyncio.sleep(retry_delay * (2 ** attempt)) + else: + self.logger.error(f"请求在 {max_retries + 1} 次尝试后失败") + + raise last_error \ No newline at end of file diff --git a/chattool/core/config.py b/chattool/core/config.py index e0be765..09f4e66 100644 --- a/chattool/core/config.py +++ b/chattool/core/config.py @@ -48,6 +48,11 @@ def to_data(self, *kwargs): return { key: value for key, value in self.__dict__.items() if key in kwargs } + + def update_kwargs(self, **kwargs): + for key, value in kwargs.items(): + if value is not None: + setattr(self, key, value) # OpenAI 专用配置 # core/config.py @@ -77,11 +82,17 @@ def __init__( self.api_base = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") if not self.model: self.model = os.getenv("OPENAI_API_MODEL", "gpt-3.5-turbo") + self._update_header() + + def _update_header(self): if not self.headers: - self.headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } + self.headers = {"Content-Type": "application/json"} + if self.api_key: + self.headers["Authorization"] = f"Bearer {self.api_key}" + + def update_kwargs(self, **kwargs): + super().update_kwargs(**kwargs) + self._update_header() # Azure OpenAI 专用配置 class AzureOpenAIConfig(Config): @@ -116,23 +127,20 @@ def __init__( self.api_version = os.getenv("AZURE_OPENAI_API_VERSION") if not self.model: self.model = os.getenv("AZURE_OPENAI_API_MODEL", "") - # Azure 使用不同的请求头格式 if not self.headers: self.headers = { "Content-Type": "application/json", } - # update api_base - """获取带 API 密钥的请求 URL""" - endpoint = self.api_base.rstrip('/') - sym = '?' + def get_request_params(self) -> Dict[str, str]: + """获取 Azure API 请求参数""" + params = {} if self.api_version: - endpoint = f"{endpoint}{sym}api-version={self.api_version}" - sym = '&' + params['api-version'] = self.api_version if self.api_key: - endpoint = f"{endpoint}{sym}ak={self.api_key}" - self.api_base = endpoint + params['ak'] = self.api_key + return params # Anthropic 配置示例 class AnthropicConfig(Config): diff --git a/chattool/core/request.py b/chattool/core/request.py index 81bdda4..23ab787 100644 --- a/chattool/core/request.py +++ b/chattool/core/request.py @@ -13,8 +13,7 @@ class HTTPClient: def __init__(self, config: Optional[Config]=None, logger: Optional[logging.Logger] = None, **kwargs): if config is None: config = Config() - for key, value in kwargs.items(): - setattr(config, key, value) + config.update_kwargs(**kwargs) self.config = config self._sync_client: Optional[httpx.Client] = None self._async_client: Optional[httpx.AsyncClient] = None @@ -43,6 +42,25 @@ def _get_async_client(self) -> httpx.AsyncClient: ) return self._async_client + def _build_url(self, endpoint: str) -> str: + """构建完整的请求 URL""" + if not endpoint: + # 如果 endpoint 为空,直接使用 api_base + return self.config.api_base + + # 如果 endpoint 已经是完整 URL,直接返回 + if endpoint.startswith(('http://', 'https://')): + return endpoint + + # 处理相对路径 + base = self.config.api_base.rstrip('/') + endpoint = endpoint.lstrip('/') + + if endpoint: + return f"{base}/{endpoint}" + else: + return base + def _retry_request(self, request_func, *args, **kwargs): """重试机制装饰器""" last_exception = None @@ -53,10 +71,10 @@ def _retry_request(self, request_func, *args, **kwargs): except (httpx.RequestError, httpx.HTTPStatusError) as e: last_exception = e if attempt < self.config.max_retries: - self.logger.warning(f"Request failed (attempt {attempt + 1}): {e}") + self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{self.config.max_retries + 1}): {e}") time.sleep(self.config.retry_delay * (2 ** attempt)) # 指数退避 else: - self.logger.error(f"Request failed after {self.config.max_retries + 1} attempts") + self.logger.error(f"请求失败,已尝试 {self.config.max_retries + 1} 次") raise last_exception @@ -70,10 +88,10 @@ async def _async_retry_request(self, request_func, *args, **kwargs): except (httpx.RequestError, httpx.HTTPStatusError) as e: last_exception = e if attempt < self.config.max_retries: - self.logger.warning(f"Request failed (attempt {attempt + 1}): {e}") + self.logger.warning(f"请求失败 (尝试 {attempt + 1}/{self.config.max_retries + 1}): {e}") await asyncio.sleep(self.config.retry_delay * (2 ** attempt)) else: - self.logger.error(f"Request failed after {self.config.max_retries + 1} attempts") + self.logger.error(f"请求失败,已尝试 {self.config.max_retries + 1} 次") raise last_exception @@ -89,6 +107,9 @@ def request( """同步请求""" client = self._get_sync_client() + # 构建完整 URL + url = self._build_url(endpoint) + # 合并headers merged_headers = self.config.headers.copy() if headers: @@ -97,7 +118,7 @@ def request( def _make_request(): return client.request( method=method, - url=endpoint if endpoint else self.config.api_base, + url=url, json=data, params=params, headers=merged_headers, @@ -120,6 +141,9 @@ async def async_request( """异步请求""" client = self._get_async_client() + # 构建完整 URL + url = self._build_url(endpoint) + # 合并headers merged_headers = self.config.headers.copy() if headers: @@ -128,7 +152,7 @@ async def async_request( async def _make_request(): return await client.request( method=method, - url=endpoint, + url=url, json=data, params=params, headers=merged_headers, @@ -190,45 +214,6 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.aclose() -class StreamResponse: - """流式响应包装类""" - - def __init__(self, chunk_data: Dict[str, Any]): - self.raw = chunk_data - self.id = chunk_data.get('id') - self.object = chunk_data.get('object') - self.created = chunk_data.get('created') - self.model = chunk_data.get('model') - - choices = chunk_data.get('choices', []) - if choices: - choice = choices[0] - self.delta = choice.get('delta', {}) - self.finish_reason = choice.get('finish_reason') - self.content = self.delta.get('content', '') - self.role = self.delta.get('role') - else: - self.delta = {} - self.finish_reason = None - self.content = '' - self.role = None - - @property - def has_content(self) -> bool: - """是否包含内容""" - return bool(self.content) - - @property - def is_finished(self) -> bool: - """是否完成""" - return self.finish_reason == 'stop' - - def __str__(self): - return self.content - - def __repr__(self): - return f"StreamResponse(content='{self.content}', finish_reason='{self.finish_reason}')" - class OpenAIClient(HTTPClient): _config_only_attrs = { 'api_key', 'api_base', 'headers', 'timeout', @@ -285,7 +270,7 @@ def chat_completion( uri: str = '/chat/completions', headers: Optional[Dict[str, str]] = None, **kwargs - ) -> Union[Dict[str, Any], Generator[StreamResponse, None, None]]: + ) -> Dict[str, Any]: """ OpenAI Chat Completion API (同步版本) @@ -315,7 +300,7 @@ def chat_completion( request_headers = self._prepare_headers(messages, headers) if data.get('stream'): - return self._stream_chat_completion(data, uri, request_headers) + data['stream'] = False #TODO: 流式响应不支持 response = self.post(uri, data=data, headers=request_headers) return response.json() @@ -331,7 +316,7 @@ async def async_chat_completion( uri: str = '/chat/completions', headers: Optional[Dict[str, str]] = None, **kwargs - ) -> Union[Dict[str, Any], AsyncGenerator[StreamResponse, None]]: + ) -> Dict[str, Any]: """OpenAI Chat Completion API (异步版本)""" all_kwargs = { 'model': model, @@ -346,99 +331,11 @@ async def async_chat_completion( request_headers = self._prepare_headers(messages, headers) if data.get('stream'): - return self._async_stream_chat_completion(data, uri, request_headers) + data['stream'] = False #TODO: 流式响应不支持 response = await self.async_post(uri, data=data, headers=request_headers) return response.json() - - def _stream_chat_completion( - self, - data: Dict[str, Any], - uri: str = '/chat/completions', - headers: Optional[Dict[str, str]] = None - ) -> Generator[StreamResponse, None, None]: - """同步流式聊天完成""" - client = self._get_sync_client() - request_headers = headers or self.config.headers - - with client.stream( - "POST", - uri, - json=data, - headers=request_headers - ) as response: - response.raise_for_status() - yield from self._process_stream_response(response.iter_lines()) - - async def _async_stream_chat_completion( - self, - data: Dict[str, Any], - uri: str = '/chat/completions', - headers: Optional[Dict[str, str]] = None - ) -> AsyncGenerator[StreamResponse, None]: - """异步流式聊天完成""" - client = self._get_async_client() - request_headers = headers or self.config.headers - - async with client.stream( - "POST", - uri, - json=data, - headers=request_headers - ) as response: - response.raise_for_status() - async for chunk in self._async_process_stream_response(response.aiter_lines()): - yield chunk - - def _process_stream_response(self, lines): - """处理流式响应行""" - for line in lines: - if not line: - continue - - line_str = line.decode('utf-8').strip() - chunk = self._parse_stream_line(line_str) - if chunk: - yield chunk - if chunk.is_finished: - break - - async def _async_process_stream_response(self, lines): - """异步处理流式响应行""" - async for line in lines: - if not line: - continue - - line_str = line.decode('utf-8').strip() - chunk = self._parse_stream_line(line_str) - if chunk: - yield chunk - if chunk.is_finished: - break - - def _parse_stream_line(self, line_str: str) -> Optional[StreamResponse]: - """解析单行流式响应""" - if not line_str.startswith('data: '): - return None - - data_str = line_str[6:].strip() - - if data_str == '[DONE]': - return None - - if not data_str: - return None - - try: - chunk_data = json.loads(data_str) - return self._process_stream_chunk(chunk_data) - except json.JSONDecodeError as e: - self.logger.warning(f"Failed to decode JSON: {e}, data: {data_str}") - return None - except Exception as e: - self.logger.error(f"Error processing stream chunk: {e}") - return None - + def embeddings( self, input_text: Union[str, List[str]], @@ -480,10 +377,7 @@ async def async_embeddings( response = await self.async_post("/embeddings", data=data) return response.json() - - def _process_stream_chunk(self, chunk_data: Dict[str, Any]) -> StreamResponse: - """处理流式响应的单个数据块""" - return StreamResponse(chunk_data) + class AzureOpenAIClient(OpenAIClient): """Azure OpenAI 客户端""" @@ -495,6 +389,7 @@ class AzureOpenAIClient(OpenAIClient): def __init__(self, config: Optional[AzureOpenAIConfig] = None, logger = None, **kwargs): if config is None: config = AzureOpenAIConfig() + self.config = config super().__init__(config, logger, **kwargs) def _generate_log_id(self, messages: List[Dict[str, str]]) -> str: @@ -512,6 +407,13 @@ def _prepare_headers(self, messages: List[Dict[str, str]], custom_headers: Optio return headers + def _prepare_params(self, custom_params: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """准备 Azure API 请求参数""" + params = self.config.get_request_params() + if custom_params: + params.update(custom_params) + return params + def chat_completion( self, messages: List[Dict[str, str]], @@ -520,20 +422,32 @@ def chat_completion( top_p: Optional[float] = None, max_tokens: Optional[int] = None, stream: bool = False, + uri: str = '', + params: Optional[Dict[str, str]] = None, **kwargs - ) -> Union[Dict[str, Any], Generator[StreamResponse, None, None]]: + ) -> Dict[str, Any]: """Azure OpenAI Chat Completion API (同步版本)""" - # 调用父类方法,但使用空字符串作为 URI(因为 api_base 已经是完整地址) - return super().chat_completion( - messages=messages, - model=model, - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - stream=stream, - uri="", # Azure 的 api_base 已经是完整地址 + + # 合并参数 + all_kwargs = { + 'model': model, + 'temperature': temperature, + 'top_p': top_p, + 'max_tokens': max_tokens, + 'stream': stream, **kwargs - ) + } + + # 构建数据和请求头 + data = self._build_chat_data(messages, **all_kwargs) + request_headers = self._prepare_headers(messages) + request_params = self._prepare_params(params) + + if data.get('stream'): + data['stream'] = False #TODO: 流式响应不支持 + + response = self.post(uri, data=data, headers=request_headers, params=request_params) + return response.json() async def async_chat_completion( self, @@ -543,17 +457,53 @@ async def async_chat_completion( top_p: Optional[float] = None, max_tokens: Optional[int] = None, stream: bool = False, + uri: str = '', + params: Optional[Dict[str, str]] = None, **kwargs - ) -> Union[Dict[str, Any], AsyncGenerator[StreamResponse, None]]: + ) -> Dict[str, Any]: """Azure OpenAI Chat Completion API (异步版本)""" - # 调用父类方法 - return await super().async_chat_completion( - messages=messages, - model=model, - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - stream=stream, - uri="", + + # 合并参数 + all_kwargs = { + 'model': model, + 'temperature': temperature, + 'top_p': top_p, + 'max_tokens': max_tokens, + 'stream': stream, + **kwargs + } + + # 构建数据和请求头 + data = self._build_chat_data(messages, **all_kwargs) + request_headers = self._prepare_headers(messages) + request_params = self._prepare_params(params) + + if data.get('stream'): + return self._async_stream_chat_completion(data, uri, request_headers, request_params) + + response = await self.async_post(uri, data=data, headers=request_headers, params=request_params) + return response.json() + + async def async_embeddings( + self, + input_text: Union[str, List[str]], + model: Optional[str] = None, + uri: str = '', + params: Optional[Dict[str, str]] = None, + **kwargs + ) -> Dict[str, Any]: + """异步 Azure OpenAI Embeddings API""" + all_kwargs = { + 'model': model or self.config.get('model', 'text-embedding-ada-002'), + 'input': input_text, **kwargs - ) \ No newline at end of file + } + + data = {} + for key, value in all_kwargs.items(): + if value is not None: + data[key] = value + + request_params = self._prepare_params(params) + response = await self.async_post(uri, data=data, params=request_params) + return response.json() \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 3456462..eb9185b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,21 +8,21 @@ TEST_PATH = 'tests/testfiles/' -@pytest.fixture(scope="session") +@pytest.fixture def testpath(): return TEST_PATH -@pytest.fixture(scope="session") +@pytest.fixture def oai_config(): """OpenAI 配置""" return OpenAIConfig() -@pytest.fixture(scope="session") +@pytest.fixture def oai_client(oai_config): """OpenAI 客户端""" return OpenAIClient(oai_config) -@pytest.fixture(scope="session") +@pytest.fixture def azure_config(): """Azure OpenAI 配置 fixture""" return AzureOpenAIConfig( diff --git a/tests/test_azureclient.py b/tests/test_azureclient.py index 46dfd22..4373402 100644 --- a/tests/test_azureclient.py +++ b/tests/test_azureclient.py @@ -3,29 +3,9 @@ from unittest.mock import Mock, patch, MagicMock from typing import Dict, Any import hashlib - from chattool.core.config import AzureOpenAIConfig from chattool.core.request import AzureOpenAIClient - -@pytest.fixture(scope="session") -def azure_config(): - """Azure OpenAI 配置 fixture""" - return AzureOpenAIConfig( - api_key="test-azure-key", - api_base="https://test-resource.openai.azure.com", - api_version="2024-02-15-preview", - model="gpt-35-turbo", - temperature=0.7 - ) - - -@pytest.fixture -def azure_client(azure_config): - """Azure OpenAI 客户端 fixture""" - return AzureOpenAIClient(azure_config) - - @pytest.fixture def mock_azure_response_data(): """模拟的 Azure OpenAI API 响应数据""" @@ -51,53 +31,6 @@ def mock_azure_response_data(): } } - -@pytest.fixture -def mock_azure_stream_response_data(): - """模拟的 Azure 流式响应数据""" - return [ - { - "id": "chatcmpl-azure123", - "object": "chat.completion.chunk", - "created": 1677652288, - "model": "gpt-35-turbo", - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": ""}, - "finish_reason": None - } - ] - }, - { - "id": "chatcmpl-azure123", - "object": "chat.completion.chunk", - "created": 1677652288, - "model": "gpt-35-turbo", - "choices": [ - { - "index": 0, - "delta": {"content": "Hello from Azure"}, - "finish_reason": None - } - ] - }, - { - "id": "chatcmpl-azure123", - "object": "chat.completion.chunk", - "created": 1677652288, - "model": "gpt-35-turbo", - "choices": [ - { - "index": 0, - "delta": {"content": "!"}, - "finish_reason": "stop" - } - ] - } - ] - - class TestAzureOpenAIConfig: """测试 Azure OpenAI 配置类""" @@ -117,8 +50,6 @@ def test_config_initialization_with_custom_values(self): temperature=0.5, max_tokens=2000 ) - assert "custom-key" in config.api_base # api_base 应该包含 ak 参数 - assert "2024-01-01" in config.api_base # api_base 应该包含 api-version assert config.model == "gpt-4" assert config.temperature == 0.5 assert config.max_tokens == 2000 @@ -133,8 +64,6 @@ def test_config_api_base_construction(self): # 验证 URL 构建正确 assert config.api_base.startswith("https://test.openai.azure.com") - assert "api-version=2024-02-15-preview" in config.api_base - assert "ak=test-key" in config.api_base assert not config.api_base.endswith("//") # 不应该有双斜杠 def test_config_get_method(self): @@ -284,26 +213,6 @@ def test_parameter_override_azure(self, azure_client): assert call_data["temperature"] == 0.1 assert call_data["azure_custom_param"] == "test" - @patch('chattool.core.request.AzureOpenAIClient._stream_chat_completion') - def test_chat_completion_stream_mode(self, mock_stream, azure_client): - """测试 Azure 流式模式""" - messages = [{"role": "user", "content": "Hello"}] - mock_stream.return_value = iter([]) - - result = azure_client.chat_completion(messages, stream=True) - - # 验证调用了流式方法 - mock_stream.assert_called_once() - call_args = mock_stream.call_args - - # 验证数据包含 stream=True - call_data = call_args[0][0] - assert call_data["stream"] is True - - # 验证传递了 Azure 请求头 - # azure_headers = call_args[0][1] - # assert "X-TT-LOGID" in azure_headers - def test_config_only_attrs_azure(self, azure_client): """测试 Azure 配置专用属性列表""" expected_attrs = { diff --git a/tests/test_oaiclient.py b/tests/test_oaiclient.py index 5dfd3df..a779fd7 100644 --- a/tests/test_oaiclient.py +++ b/tests/test_oaiclient.py @@ -30,53 +30,6 @@ def mock_response_data(): } } - -@pytest.fixture -def mock_stream_response_data(): - """模拟的流式响应数据""" - return [ - { - "id": "chatcmpl-test123", - "object": "chat.completion.chunk", - "created": 1677652288, - "model": "gpt-3.5-turbo", - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": ""}, - "finish_reason": None - } - ] - }, - { - "id": "chatcmpl-test123", - "object": "chat.completion.chunk", - "created": 1677652288, - "model": "gpt-3.5-turbo", - "choices": [ - { - "index": 0, - "delta": {"content": "Hello"}, - "finish_reason": None - } - ] - }, - { - "id": "chatcmpl-test123", - "object": "chat.completion.chunk", - "created": 1677652288, - "model": "gpt-3.5-turbo", - "choices": [ - { - "index": 0, - "delta": {"content": "!"}, - "finish_reason": "stop" - } - ] - } - ] - - class TestOpenAIConfig: """测试 OpenAI 配置类""" @@ -232,19 +185,6 @@ def test_chat_completion_parameter_override(self, oai_client): assert call_data["temperature"] == 0.2 assert call_data["custom_param"] == "test" - @patch('chattool.core.request.OpenAIClient._stream_chat_completion') - def test_chat_completion_stream_mode(self, mock_stream, oai_client): - """测试流式模式""" - messages = [{"role": "user", "content": "Hello"}] - mock_stream.return_value = iter([]) - - result = oai_client.chat_completion(messages, stream=True) - - # 验证调用了流式方法 - mock_stream.assert_called_once() - call_data = mock_stream.call_args[0][0] - assert call_data["stream"] is True - @patch('chattool.core.request.OpenAIClient.post') def test_embeddings(self, mock_post, oai_client): """测试嵌入 API"""