Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions astrbot/builtin_stars/web_searcher/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class SearchResult:
title: str
url: str
snippet: str
favicon: str | None = None

def __str__(self) -> str:
return f"{self.title} - {self.url}\n{self.snippet}"
Expand Down
40 changes: 24 additions & 16 deletions astrbot/builtin_stars/web_searcher/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
import json
import random
import uuid

import aiohttp
from bs4 import BeautifulSoup
from readability import Document

from astrbot.api import AstrBotConfig, llm_tool, logger, star
from astrbot.api import AstrBotConfig, llm_tool, logger, sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
from astrbot.api.provider import ProviderRequest
from astrbot.core.provider.func_tool_manager import FunctionToolManager
Expand Down Expand Up @@ -151,6 +153,7 @@ async def _web_search_tavily(
title=item.get("title"),
url=item.get("url"),
snippet=item.get("content"),
favicon=item.get("favicon"),
)
results.append(result)
return results
Expand Down Expand Up @@ -272,7 +275,7 @@ async def search_from_tavily(
self,
event: AstrMessageEvent,
query: str,
max_results: int = 5,
max_results: int = 7,
search_depth: str = "basic",
topic: str = "general",
days: int = 3,
Expand All @@ -285,7 +288,7 @@ async def search_from_tavily(

Args:
query(string): Required. Search query.
max_results(number): Optional. The maximum number of results to return. Default is 5. Range is 5-20.
max_results(number): Optional. The maximum number of results to return. Default is 7. Range is 5-20.
search_depth(string): Optional. The depth of the search, must be one of 'basic', 'advanced'. Default is "basic".
topic(string): Optional. The topic of the search, must be one of 'general', 'news'. Default is "general".
days(number): Optional. The number of days back from the current date to include in the search results. Please note that this feature is only available when using the 'news' search topic.
Expand All @@ -296,15 +299,12 @@ async def search_from_tavily(
"""
logger.info(f"web_searcher - search_from_tavily: {query}")
cfg = self.context.get_config(umo=event.unified_msg_origin)
websearch_link = cfg["provider_settings"].get("web_search_link", False)
# websearch_link = cfg["provider_settings"].get("web_search_link", False)
if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []):
raise ValueError("Error: Tavily API key is not configured in AstrBot.")

# build payload
payload = {
"query": query,
"max_results": max_results,
}
payload = {"query": query, "max_results": max_results, "include_favicon": True}
if search_depth not in ["basic", "advanced"]:
search_depth = "basic"
payload["search_depth"] = search_depth
Expand All @@ -328,14 +328,22 @@ async def search_from_tavily(
return "Error: Tavily web searcher does not return any results."

ret_ls = []
for result in results:
ret_ls.append(f"\nTitle: {result.title}")
ret_ls.append(f"URL: {result.url}")
ret_ls.append(f"Content: {result.snippet}")
ret = "\n".join(ret_ls)

if websearch_link:
ret += "\n\n针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
ref_uuid = str(uuid.uuid4())[:4]
for idx, result in enumerate(results, 1):
index = f"{ref_uuid}.{idx}"
ret_ls.append(
{
"title": f"{result.title}",
"url": f"{result.url}",
"snippet": f"{result.snippet}",
# TODO: do not need ref for non-webchat platform adapter
"index": index,
}
)
if result.favicon:
sp.temorary_cache["_ws_favicon"][result.url] = result.favicon
# ret = "\n".join(ret_ls)
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
return ret

@llm_tool("tavily_extract_web_page")
Expand Down
24 changes: 24 additions & 0 deletions astrbot/core/astr_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mcp.types import CallToolResult

from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import Message
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
Expand Down Expand Up @@ -34,6 +35,29 @@ async def on_tool_end(
):
run_context.context.event.clear_result()

# special handle web_search_tavily
if (
tool.name == "web_search_tavily"
and len(run_context.messages) > 0
and tool_result
and len(tool_result.content)
):
# inject system prompt
first_part = run_context.messages[0]
if (
isinstance(first_part, Message)
and first_part.role == "system"
and first_part.content
and isinstance(first_part.content, str)
):
# we assume system part is str
first_part.content += (
"Always cite web search results you rely on. "
"Index is a unique identifier for each search result. "
"Use the exact citation format <ref>index</ref> (e.g. <ref>abcd.3</ref>) "
"after the sentence that uses the information. Do not invent citations."
)


class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
    """Agent run hooks that define nothing beyond ``BaseAgentRunHooks``."""
Expand Down
14 changes: 14 additions & 0 deletions astrbot/core/utils/shared_preferences.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import asyncio
import os
import threading
from collections import defaultdict
from typing import Any, TypeVar, overload

from apscheduler.schedulers.background import BackgroundScheduler

from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Preference

Expand All @@ -20,11 +23,22 @@ def __init__(self, db_helper: BaseDatabase, json_storage_path=None):
)
self.path = json_storage_path
self.db_helper = db_helper
self.temorary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
"""automatically clear per 24 hours. Might be helpful in some cases XD"""

self._sync_loop = asyncio.new_event_loop()
t = threading.Thread(target=self._sync_loop.run_forever, daemon=True)
t.start()

self._scheduler = BackgroundScheduler()
self._scheduler.add_job(
self._clear_temporary_cache, "interval", hours=24, id="clear_sp_temp_cache"
)
self._scheduler.start()

def _clear_temporary_cache(self):
self.temorary_cache.clear()

async def get_async(
self,
scope: str,
Expand Down
80 changes: 79 additions & 1 deletion astrbot/dashboard/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import json
import mimetypes
import os
import re
import uuid
from contextlib import asynccontextmanager
from typing import cast

from quart import Response as QuartResponse
from quart import g, make_response, request, send_file

from astrbot.core import logger
from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
Expand Down Expand Up @@ -225,13 +226,72 @@ async def _create_attachment_from_file(
"filename": os.path.basename(file_path),
}

def _extract_web_search_refs(
self, accumulated_text: str, accumulated_parts: list
) -> dict:
"""从消息中提取 web_search_tavily 的引用

Args:
accumulated_text: 累积的文本内容
accumulated_parts: 累积的消息部分列表

Returns:
包含 used 列表的字典,记录被引用的搜索结果
"""
# 从 accumulated_parts 中找到所有 web_search_tavily 的工具调用结果
web_search_results = {}
tool_call_parts = [
p
for p in accumulated_parts
if p.get("type") == "tool_call" and p.get("tool_calls")
]

for part in tool_call_parts:
for tool_call in part["tool_calls"]:
if tool_call.get("name") != "web_search_tavily" or not tool_call.get(
"result"
):
continue
try:
result_data = json.loads(tool_call["result"])
for item in result_data.get("results", []):
if idx := item.get("index"):
web_search_results[idx] = {
"url": item.get("url"),
"title": item.get("title"),
"snippet": item.get("snippet"),
}
except (json.JSONDecodeError, KeyError):
pass

if not web_search_results:
return {}

# 从文本中提取所有 <ref>xxx</ref> 标签并去重
ref_indices = {
m.strip() for m in re.findall(r"<ref>(.*?)</ref>", accumulated_text)
}

# 构建被引用的结果列表
used_refs = []
for ref_index in ref_indices:
if ref_index not in web_search_results:
continue
payload = {"index": ref_index, **web_search_results[ref_index]}
if favicon := sp.temorary_cache.get("_ws_favicon", {}).get(payload["url"]):
payload["favicon"] = favicon
used_refs.append(payload)

return {"used": used_refs} if used_refs else {}

async def _save_bot_message(
self,
webchat_conv_id: str,
text: str,
media_parts: list,
reasoning: str,
agent_stats: dict,
refs: dict,
):
"""保存 bot 消息到历史记录,返回保存的记录"""
bot_message_parts = []
Expand All @@ -244,6 +304,8 @@ async def _save_bot_message(
new_his["reasoning"] = reasoning
if agent_stats:
new_his["agent_stats"] = agent_stats
if refs:
new_his["refs"] = refs

record = await self.platform_history_mgr.insert(
platform_id="webchat",
Expand Down Expand Up @@ -305,6 +367,7 @@ async def stream():
accumulated_reasoning = ""
tool_calls = {}
agent_stats = {}
refs = {}
try:
async with track_conversation(self.running_convs, webchat_conv_id):
while True:
Expand Down Expand Up @@ -426,12 +489,26 @@ async def stream():
or chain_type == "tool_call_result"
):
continue

# 提取 web_search_tavily 引用
try:
refs = self._extract_web_search_refs(
accumulated_text,
accumulated_parts,
)
except Exception as e:
logger.exception(
f"Failed to extract web search refs: {e}",
exc_info=True,
)

saved_record = await self._save_bot_message(
webchat_conv_id,
accumulated_text,
accumulated_parts,
accumulated_reasoning,
agent_stats,
refs,
)
# 发送保存的消息信息给前端
if saved_record and not client_disconnected:
Expand All @@ -451,6 +528,7 @@ async def stream():
accumulated_reasoning = ""
# tool_calls = {}
agent_stats = {}
refs = {}
except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)

Expand Down
19 changes: 19 additions & 0 deletions dashboard/src/components/chat/Chat.vue
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
@openImagePreview="openImagePreview"
@replyMessage="handleReplyMessage"
@replyWithText="handleReplyWithText"
@openRefs="handleOpenRefs"
ref="messageList" />
<div class="message-list-fade" :class="{ 'fade-dark': isDark }"></div>
</div>
Expand Down Expand Up @@ -146,6 +147,8 @@
/>
</div>

<!-- Refs Sidebar -->
<RefsSidebar v-model="refsSidebarOpen" :refs="refsSidebarRefs" />
</div>
</v-card-text>
</v-card>
Expand Down Expand Up @@ -198,6 +201,7 @@ import ChatInput from '@/components/chat/ChatInput.vue';
import ProjectDialog from '@/components/chat/ProjectDialog.vue';
import ProjectView from '@/components/chat/ProjectView.vue';
import WelcomeView from '@/components/chat/WelcomeView.vue';
import RefsSidebar from '@/components/chat/message_list_comps/RefsSidebar.vue';
import type { ProjectFormData } from '@/components/chat/ProjectDialog.vue';
import { useSessions } from '@/composables/useSessions';
import { useMessages } from '@/composables/useMessages';
Expand Down Expand Up @@ -406,6 +410,21 @@ function handleReplyWithText(replyData: any) {
};
}

// Refs sidebar state
// refsSidebarOpen: bound to <RefsSidebar> through v-model.
const refsSidebarOpen = ref(false);
// refsSidebarRefs: passed to <RefsSidebar> as its `refs` prop; null until refs are opened.
const refsSidebarRefs = ref<any>(null);

function handleOpenRefs(refs: any) {
    // Clicking the refs that are already on display toggles the sidebar shut.
    const sameRefsAlreadyShown =
        refsSidebarOpen.value && refsSidebarRefs.value === refs;
    if (sameRefsAlreadyShown) {
        refsSidebarOpen.value = false;
        return;
    }

    // Any other click swaps in the selected refs and shows the sidebar.
    refsSidebarRefs.value = refs;
    refsSidebarOpen.value = true;
}

async function handleSelectConversation(sessionIds: string[]) {
if (!sessionIds[0]) return;

Expand Down
1 change: 0 additions & 1 deletion dashboard/src/components/chat/ConversationSidebar.vue
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ function handleDeleteConversation(session: Session) {
display: flex;
flex-direction: column;
padding: 0;
border-right: 1px solid rgba(0, 0, 0, 0.04);
height: 100%;
max-height: 100%;
position: relative;
Expand Down
Loading