mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 10:10:15 +08:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dcb714ec8c |
@@ -3,7 +3,6 @@ from typing import Any
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
@@ -70,37 +69,6 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
tool_result,
|
||||
)
|
||||
|
||||
# special handle web_search_tavily
|
||||
platform_name = run_context.context.event.get_platform_name()
|
||||
if (
|
||||
platform_name == "webchat"
|
||||
and tool.name
|
||||
in [
|
||||
"web_search_baidu",
|
||||
"web_search_tavily",
|
||||
"web_search_bocha",
|
||||
"web_search_brave",
|
||||
]
|
||||
and len(run_context.messages) > 0
|
||||
and tool_result
|
||||
and len(tool_result.content)
|
||||
):
|
||||
# inject system prompt
|
||||
first_part = run_context.messages[0]
|
||||
if (
|
||||
isinstance(first_part, Message)
|
||||
and first_part.role == "system"
|
||||
and first_part.content
|
||||
and isinstance(first_part.content, str)
|
||||
):
|
||||
# we assume system part is str
|
||||
first_part.content += (
|
||||
"Always cite web search results you rely on. "
|
||||
"Index is a unique identifier for each search result. "
|
||||
"Use the exact citation format <ref>index</ref> (e.g. <ref>abcd.3</ref>) "
|
||||
"after the sentence that uses the information. Do not invent citations."
|
||||
)
|
||||
|
||||
|
||||
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
pass
|
||||
|
||||
@@ -115,6 +115,20 @@ from astrbot.core.utils.quoted_message_parser import (
|
||||
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
|
||||
|
||||
LLM_ERROR_MESSAGE_EXTRA_KEY = "_llm_error_message"
|
||||
WEB_SEARCH_CITATION_TOOL_NAMES = frozenset(
|
||||
{
|
||||
"web_search_baidu",
|
||||
"web_search_tavily",
|
||||
"web_search_bocha",
|
||||
"web_search_brave",
|
||||
}
|
||||
)
|
||||
WEB_SEARCH_CITATION_PROMPT = (
|
||||
"Always cite web search results you rely on. "
|
||||
"Index is a unique identifier for each search result. "
|
||||
"Use the exact citation format <ref>index</ref> (e.g. <ref>abcd.3</ref>) "
|
||||
"after the sentence that uses the information. Do not invent citations."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -1149,6 +1163,23 @@ async def _apply_web_search_tools(
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool))
|
||||
|
||||
|
||||
def _apply_web_search_citation_prompt(
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
) -> None:
|
||||
if event.get_platform_name() != "webchat" or not req.func_tool:
|
||||
return
|
||||
|
||||
if not any(req.func_tool.get_tool(name) for name in WEB_SEARCH_CITATION_TOOL_NAMES):
|
||||
return
|
||||
|
||||
system_prompt = req.system_prompt or ""
|
||||
if WEB_SEARCH_CITATION_PROMPT in system_prompt:
|
||||
return
|
||||
|
||||
req.system_prompt = f"{system_prompt}\n{WEB_SEARCH_CITATION_PROMPT}\n"
|
||||
|
||||
|
||||
def _get_compress_provider(
|
||||
config: MainAgentBuildConfig,
|
||||
plugin_context: Context,
|
||||
@@ -1520,6 +1551,8 @@ async def build_main_agent(
|
||||
if action_type == "live":
|
||||
req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n"
|
||||
|
||||
_apply_web_search_citation_prompt(event, req)
|
||||
|
||||
reset_coro = agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
|
||||
@@ -476,6 +476,46 @@ class TestBuiltinToolInjection:
|
||||
assert req.func_tool.get_tool("web_search_firecrawl") is search_tool
|
||||
assert req.func_tool.get_tool("firecrawl_extract_web_page") is extract_tool
|
||||
|
||||
def test_apply_web_search_citation_prompt_for_webchat(self, mock_event):
|
||||
module = ama
|
||||
req = ProviderRequest(system_prompt="base")
|
||||
search_tool = MagicMock(spec=FunctionTool)
|
||||
search_tool.name = "web_search_tavily"
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(search_tool)
|
||||
mock_event.get_platform_name.return_value = "webchat"
|
||||
|
||||
module._apply_web_search_citation_prompt(mock_event, req)
|
||||
|
||||
assert module.WEB_SEARCH_CITATION_PROMPT in req.system_prompt
|
||||
|
||||
def test_apply_web_search_citation_prompt_is_idempotent(self, mock_event):
|
||||
module = ama
|
||||
req = ProviderRequest(system_prompt="")
|
||||
search_tool = MagicMock(spec=FunctionTool)
|
||||
search_tool.name = "web_search_tavily"
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(search_tool)
|
||||
mock_event.get_platform_name.return_value = "webchat"
|
||||
|
||||
module._apply_web_search_citation_prompt(mock_event, req)
|
||||
module._apply_web_search_citation_prompt(mock_event, req)
|
||||
|
||||
assert req.system_prompt.count(module.WEB_SEARCH_CITATION_PROMPT) == 1
|
||||
|
||||
def test_apply_web_search_citation_prompt_requires_webchat(self, mock_event):
|
||||
module = ama
|
||||
req = ProviderRequest(system_prompt="")
|
||||
search_tool = MagicMock(spec=FunctionTool)
|
||||
search_tool.name = "web_search_tavily"
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(search_tool)
|
||||
mock_event.get_platform_name.return_value = "test_platform"
|
||||
|
||||
module._apply_web_search_citation_prompt(mock_event, req)
|
||||
|
||||
assert module.WEB_SEARCH_CITATION_PROMPT not in req.system_prompt
|
||||
|
||||
def test_proactive_cron_job_tools_uses_builtin_tool_manager(self, mock_context):
|
||||
"""Test cron tool injection through the builtin tool manager."""
|
||||
module = ama
|
||||
|
||||
Reference in New Issue
Block a user