diff --git a/astrbot/builtin_stars/web_searcher/engines/__init__.py b/astrbot/builtin_stars/web_searcher/engines/__init__.py index 699438602..82def138f 100644 --- a/astrbot/builtin_stars/web_searcher/engines/__init__.py +++ b/astrbot/builtin_stars/web_searcher/engines/__init__.py @@ -32,6 +32,7 @@ class SearchResult: title: str url: str snippet: str + favicon: str | None = None def __str__(self) -> str: return f"{self.title} - {self.url}\n{self.snippet}" diff --git a/astrbot/builtin_stars/web_searcher/main.py b/astrbot/builtin_stars/web_searcher/main.py index 4745cd0c0..e8388c816 100644 --- a/astrbot/builtin_stars/web_searcher/main.py +++ b/astrbot/builtin_stars/web_searcher/main.py @@ -1,11 +1,13 @@ import asyncio +import json import random +import uuid import aiohttp from bs4 import BeautifulSoup from readability import Document -from astrbot.api import AstrBotConfig, llm_tool, logger, star +from astrbot.api import AstrBotConfig, llm_tool, logger, sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter from astrbot.api.provider import ProviderRequest from astrbot.core.provider.func_tool_manager import FunctionToolManager @@ -151,6 +153,7 @@ async def _web_search_tavily( title=item.get("title"), url=item.get("url"), snippet=item.get("content"), + favicon=item.get("favicon"), ) results.append(result) return results @@ -272,7 +275,7 @@ async def search_from_tavily( self, event: AstrMessageEvent, query: str, - max_results: int = 5, + max_results: int = 7, search_depth: str = "basic", topic: str = "general", days: int = 3, @@ -285,7 +288,7 @@ async def search_from_tavily( Args: query(string): Required. Search query. - max_results(number): Optional. The maximum number of results to return. Default is 5. Range is 5-20. + max_results(number): Optional. The maximum number of results to return. Default is 7. Range is 5-20. search_depth(string): Optional. The depth of the search, must be one of 'basic', 'advanced'. Default is "basic". topic(string): Optional. The topic of the search, must be one of 'general', 'news'. Default is "general". days(number): Optional. The number of days back from the current date to include in the search results. Please note that this feature is only available when using the 'news' search topic. @@ -296,15 +299,12 @@ async def search_from_tavily( """ logger.info(f"web_searcher - search_from_tavily: {query}") cfg = self.context.get_config(umo=event.unified_msg_origin) - websearch_link = cfg["provider_settings"].get("web_search_link", False) + # websearch_link = cfg["provider_settings"].get("web_search_link", False) if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []): raise ValueError("Error: Tavily API key is not configured in AstrBot.") # build payload - payload = { - "query": query, - "max_results": max_results, - } + payload = {"query": query, "max_results": max_results, "include_favicon": True} if search_depth not in ["basic", "advanced"]: search_depth = "basic" payload["search_depth"] = search_depth @@ -328,14 +328,22 @@ async def search_from_tavily( return "Error: Tavily web searcher does not return any results." ret_ls = [] - for result in results: - ret_ls.append(f"\nTitle: {result.title}") - ret_ls.append(f"URL: {result.url}") - ret_ls.append(f"Content: {result.snippet}") - ret = "\n".join(ret_ls) - - if websearch_link: - ret += "\n\n针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。" + ref_uuid = str(uuid.uuid4())[:4] + for idx, result in enumerate(results, 1): + index = f"{ref_uuid}.{idx}" + ret_ls.append( + { + "title": f"{result.title}", + "url": f"{result.url}", + "snippet": f"{result.snippet}", + # TODO: do not need ref for non-webchat platform adapter + "index": index, + } + ) + if result.favicon: + sp.temorary_cache["_ws_favicon"][result.url] = result.favicon + # ret = "\n".join(ret_ls) + ret = json.dumps({"results": ret_ls}, ensure_ascii=False) return ret @llm_tool("tavily_extract_web_page") diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py index 9d85de0cc..4aa1533b6 100644 --- a/astrbot/core/astr_agent_hooks.py +++ b/astrbot/core/astr_agent_hooks.py @@ -3,6 +3,7 @@ from mcp.types import CallToolResult from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.message import Message from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool from astrbot.core.astr_agent_context import AstrAgentContext @@ -34,6 +35,29 @@ async def on_tool_end( ): run_context.context.event.clear_result() + # special handle web_search_tavily + if ( + tool.name == "web_search_tavily" + and len(run_context.messages) > 0 + and tool_result + and len(tool_result.content) + ): + # inject system prompt + first_part = run_context.messages[0] + if ( + isinstance(first_part, Message) + and first_part.role == "system" + and first_part.content + and isinstance(first_part.content, str) + ): + # we assume system part is str + first_part.content += ( + "Always cite web search results you rely on. " + "Index is a unique identifier for each search result. " + "Use the exact citation format index (e.g. abcd.3) " + "after the sentence that uses the information. Do not invent citations." + ) + class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]): pass diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index ccd394ee4..765045513 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,8 +1,11 @@ import asyncio import os import threading +from collections import defaultdict from typing import Any, TypeVar, overload +from apscheduler.schedulers.background import BackgroundScheduler + from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Preference @@ -20,11 +23,22 @@ def __init__(self, db_helper: BaseDatabase, json_storage_path=None): ) self.path = json_storage_path self.db_helper = db_helper + self.temorary_cache: dict[str, dict[str, Any]] = defaultdict(dict) + """automatically clear per 24 hours. Might be helpful in some cases XD""" self._sync_loop = asyncio.new_event_loop() t = threading.Thread(target=self._sync_loop.run_forever, daemon=True) t.start() + self._scheduler = BackgroundScheduler() + self._scheduler.add_job( + self._clear_temporary_cache, "interval", hours=24, id="clear_sp_temp_cache" + ) + self._scheduler.start() + + def _clear_temporary_cache(self): + self.temorary_cache.clear() + async def get_async( self, scope: str, diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index de12daab9..92ff4c3fe 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -2,6 +2,7 @@ import json import mimetypes import os +import re import uuid from contextlib import asynccontextmanager from typing import cast @@ -9,7 +10,7 @@ from quart import Response as QuartResponse from quart import g, make_response, request, send_file -from astrbot.core import logger +from astrbot.core import logger, sp from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr @@ -225,6 +226,64 @@ async def _create_attachment_from_file( "filename": os.path.basename(file_path), } + def _extract_web_search_refs( + self, accumulated_text: str, accumulated_parts: list + ) -> dict: + """从消息中提取 web_search_tavily 的引用 + + Args: + accumulated_text: 累积的文本内容 + accumulated_parts: 累积的消息部分列表 + + Returns: + 包含 used 列表的字典,记录被引用的搜索结果 + """ + # 从 accumulated_parts 中找到所有 web_search_tavily 的工具调用结果 + web_search_results = {} + tool_call_parts = [ + p + for p in accumulated_parts + if p.get("type") == "tool_call" and p.get("tool_calls") + ] + + for part in tool_call_parts: + for tool_call in part["tool_calls"]: + if tool_call.get("name") != "web_search_tavily" or not tool_call.get( + "result" + ): + continue + try: + result_data = json.loads(tool_call["result"]) + for item in result_data.get("results", []): + if idx := item.get("index"): + web_search_results[idx] = { + "url": item.get("url"), + "title": item.get("title"), + "snippet": item.get("snippet"), + } + except (json.JSONDecodeError, KeyError): + pass + + if not web_search_results: + return {} + + # 从文本中提取所有 xxx 标签并去重 + ref_indices = { + m.strip() for m in re.findall(r"(.*?)", accumulated_text) + } + + # 构建被引用的结果列表 + used_refs = [] + for ref_index in ref_indices: + if ref_index not in web_search_results: + continue + payload = {"index": ref_index, **web_search_results[ref_index]} + if favicon := sp.temorary_cache.get("_ws_favicon", {}).get(payload["url"]): + payload["favicon"] = favicon + used_refs.append(payload) + + return {"used": used_refs} if used_refs else {} + async def _save_bot_message( self, webchat_conv_id: str, @@ -232,6 +291,7 @@ async def _save_bot_message( media_parts: list, reasoning: str, agent_stats: dict, + refs: dict, ): """保存 bot 消息到历史记录,返回保存的记录""" bot_message_parts = [] @@ -244,6 +304,8 @@ async def _save_bot_message( new_his["reasoning"] = reasoning if agent_stats: new_his["agent_stats"] = agent_stats + if refs: + new_his["refs"] = refs record = await self.platform_history_mgr.insert( platform_id="webchat", @@ -305,6 +367,7 @@ async def stream(): accumulated_reasoning = "" tool_calls = {} agent_stats = {} + refs = {} try: async with track_conversation(self.running_convs, webchat_conv_id): while True: @@ -426,12 +489,26 @@ async def stream(): or chain_type == "tool_call_result" ): continue + + # 提取 web_search_tavily 引用 + try: + refs = self._extract_web_search_refs( + accumulated_text, + accumulated_parts, + ) + except Exception as e: + logger.exception( + f"Failed to extract web search refs: {e}", + exc_info=True, + ) + saved_record = await self._save_bot_message( webchat_conv_id, accumulated_text, accumulated_parts, accumulated_reasoning, agent_stats, + refs, ) # 发送保存的消息信息给前端 if saved_record and not client_disconnected: @@ -451,6 +528,7 @@ async def stream(): accumulated_reasoning = "" # tool_calls = {} agent_stats = {} + refs = {} except BaseException as e: logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 4141f232e..a2c85b946 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -55,6 +55,7 @@ @openImagePreview="openImagePreview" @replyMessage="handleReplyMessage" @replyWithText="handleReplyWithText" + @openRefs="handleOpenRefs" ref="messageList" />
@@ -146,6 +147,8 @@ /> + +