Compare commits

...

1 Commits

3 changed files with 73 additions and 32 deletions

View File

@@ -3,7 +3,6 @@ from typing import Any
from mcp.types import CallToolResult
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import Message
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
@@ -70,37 +69,6 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
tool_result,
)
# special handle web_search_tavily
platform_name = run_context.context.event.get_platform_name()
if (
platform_name == "webchat"
and tool.name
in [
"web_search_baidu",
"web_search_tavily",
"web_search_bocha",
"web_search_brave",
]
and len(run_context.messages) > 0
and tool_result
and len(tool_result.content)
):
# inject system prompt
first_part = run_context.messages[0]
if (
isinstance(first_part, Message)
and first_part.role == "system"
and first_part.content
and isinstance(first_part.content, str)
):
# we assume system part is str
first_part.content += (
"Always cite web search results you rely on. "
"Index is a unique identifier for each search result. "
"Use the exact citation format <ref>index</ref> (e.g. <ref>abcd.3</ref>) "
"after the sentence that uses the information. Do not invent citations."
)
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
pass

View File

@@ -115,6 +115,20 @@ from astrbot.core.utils.quoted_message_parser import (
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
LLM_ERROR_MESSAGE_EXTRA_KEY = "_llm_error_message"
WEB_SEARCH_CITATION_TOOL_NAMES = frozenset(
{
"web_search_baidu",
"web_search_tavily",
"web_search_bocha",
"web_search_brave",
}
)
WEB_SEARCH_CITATION_PROMPT = (
"Always cite web search results you rely on. "
"Index is a unique identifier for each search result. "
"Use the exact citation format <ref>index</ref> (e.g. <ref>abcd.3</ref>) "
"after the sentence that uses the information. Do not invent citations."
)
@dataclass(slots=True)
@@ -1149,6 +1163,23 @@ async def _apply_web_search_tools(
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool))
def _apply_web_search_citation_prompt(
event: AstrMessageEvent,
req: ProviderRequest,
) -> None:
if event.get_platform_name() != "webchat" or not req.func_tool:
return
if not any(req.func_tool.get_tool(name) for name in WEB_SEARCH_CITATION_TOOL_NAMES):
return
system_prompt = req.system_prompt or ""
if WEB_SEARCH_CITATION_PROMPT in system_prompt:
return
req.system_prompt = f"{system_prompt}\n{WEB_SEARCH_CITATION_PROMPT}\n"
def _get_compress_provider(
config: MainAgentBuildConfig,
plugin_context: Context,
@@ -1520,6 +1551,8 @@ async def build_main_agent(
if action_type == "live":
req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n"
_apply_web_search_citation_prompt(event, req)
reset_coro = agent_runner.reset(
provider=provider,
request=req,

View File

@@ -476,6 +476,46 @@ class TestBuiltinToolInjection:
assert req.func_tool.get_tool("web_search_firecrawl") is search_tool
assert req.func_tool.get_tool("firecrawl_extract_web_page") is extract_tool
def test_apply_web_search_citation_prompt_for_webchat(self, mock_event):
module = ama
req = ProviderRequest(system_prompt="base")
search_tool = MagicMock(spec=FunctionTool)
search_tool.name = "web_search_tavily"
req.func_tool = ToolSet()
req.func_tool.add_tool(search_tool)
mock_event.get_platform_name.return_value = "webchat"
module._apply_web_search_citation_prompt(mock_event, req)
assert module.WEB_SEARCH_CITATION_PROMPT in req.system_prompt
def test_apply_web_search_citation_prompt_is_idempotent(self, mock_event):
module = ama
req = ProviderRequest(system_prompt="")
search_tool = MagicMock(spec=FunctionTool)
search_tool.name = "web_search_tavily"
req.func_tool = ToolSet()
req.func_tool.add_tool(search_tool)
mock_event.get_platform_name.return_value = "webchat"
module._apply_web_search_citation_prompt(mock_event, req)
module._apply_web_search_citation_prompt(mock_event, req)
assert req.system_prompt.count(module.WEB_SEARCH_CITATION_PROMPT) == 1
def test_apply_web_search_citation_prompt_requires_webchat(self, mock_event):
module = ama
req = ProviderRequest(system_prompt="")
search_tool = MagicMock(spec=FunctionTool)
search_tool.name = "web_search_tavily"
req.func_tool = ToolSet()
req.func_tool.add_tool(search_tool)
mock_event.get_platform_name.return_value = "test_platform"
module._apply_web_search_citation_prompt(mock_event, req)
assert module.WEB_SEARCH_CITATION_PROMPT not in req.system_prompt
def test_proactive_cron_job_tools_uses_builtin_tool_manager(self, mock_context):
"""Test cron tool injection through the builtin tool manager."""
module = ama