Compare commits

...

2 Commits

Author SHA1 Message Date
Soulter
299b7ad56b refactor: filter persona tools in one pass 2026-06-18 23:10:18 +08:00
Yufeng He
d23011262e fix(core): enforce persona tool boundaries 2026-06-15 19:33:19 +08:00
4 changed files with 110 additions and 7 deletions

View File

@@ -266,8 +266,13 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
# "all tools", including runtime computer-use tools.
if tools is None:
toolset = ToolSet()
for registered_tool in llm_tools.func_list:
if isinstance(registered_tool, HandoffTool):
handoff_names = {
tool.name
for tool in tool_mgr.func_list
if isinstance(tool, HandoffTool)
}
for registered_tool in tool_mgr.get_full_tool_set():
if registered_tool.name in handoff_names:
continue
if registered_tool.active:
toolset.add_tool(registered_tool)

View File

@@ -456,10 +456,10 @@ async def _ensure_persona_and_skills(
cfg: dict,
plugin_context: Context,
event: AstrMessageEvent,
) -> None:
) -> set[str] | None:
"""Ensure persona and skills are applied to the request's system prompt or user prompt."""
if not req.conversation:
return
return None
(
persona_id,
@@ -514,11 +514,13 @@ async def _ensure_persona_and_skills(
# inject toolset in the persona
if (persona and persona.get("tools") is None) or not persona:
persona_allowed_tools = None
persona_toolset = tmgr.get_full_tool_set()
for tool in list(persona_toolset):
if not tool.active:
persona_toolset.remove_tool(tool.name)
else:
persona_allowed_tools = {str(tool_name) for tool_name in persona["tools"]}
persona_toolset = ToolSet()
if persona["tools"]:
for tool_name in persona["tools"]:
@@ -599,6 +601,7 @@ async def _ensure_persona_and_skills(
)
except Exception:
pass
return persona_allowed_tools
async def _request_img_caption(
@@ -931,12 +934,13 @@ async def _decorate_llm_request(
plugin_context: Context,
config: MainAgentBuildConfig,
provider: Provider | None = None,
) -> None:
) -> set[str] | None:
cfg = config.provider_settings or plugin_context.get_config(
umo=event.unified_msg_origin
).get("provider_settings", {})
_apply_prompt_prefix(req, cfg)
persona_allowed_tools = None
main_provider_supports_image = provider is not None and _provider_supports_modality(
provider, "image"
@@ -945,7 +949,9 @@ async def _decorate_llm_request(
quote_images_already_captioned = False
if req.conversation:
await _ensure_persona_and_skills(req, cfg, plugin_context, event)
persona_allowed_tools = await _ensure_persona_and_skills(
req, cfg, plugin_context, event
)
if img_cap_prov_id and req.image_urls and not main_provider_supports_image:
await _ensure_img_caption(
@@ -974,6 +980,7 @@ async def _decorate_llm_request(
tz = plugin_context.get_config().get("timezone")
_append_system_reminders(event, req, cfg, tz)
_apply_workspace_extra_prompt(event, req)
return persona_allowed_tools
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
@@ -1502,7 +1509,9 @@ async def build_main_agent(
else:
return None
await _decorate_llm_request(event, req, plugin_context, config, provider=provider)
persona_allowed_tools = await _decorate_llm_request(
event, req, plugin_context, config, provider=provider
)
await _apply_kb(event, req, plugin_context, config)
@@ -1538,6 +1547,11 @@ async def build_main_agent(
)
)
if persona_allowed_tools is not None and req.func_tool:
req.func_tool.tools = [
tool for tool in req.func_tool.tools if tool.name in persona_allowed_tools
]
fallback_providers = _get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
)

View File

@@ -3,9 +3,16 @@ from types import SimpleNamespace
import mcp
import pytest
from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.message.components import Image
from astrbot.core.provider.func_tool_manager import (
FunctionToolManager,
_PermissionGuardedTool,
)
class _DummyEvent:
@@ -29,6 +36,32 @@ def _build_run_context(message_components: list[object] | None = None):
return ContextWrapper(context=ctx)
def test_build_handoff_toolset_keeps_permission_guards_for_default_tools():
mgr = FunctionToolManager()
plugin_tool = FunctionTool(
name="admin_only_mcp",
description="admin tool",
parameters={"type": "object", "properties": {}},
)
handoff = HandoffTool(Agent(name="child"))
mgr.func_list = [plugin_tool, handoff]
event = _DummyEvent()
context = SimpleNamespace(
get_config=lambda **_kwargs: {
"provider_settings": {"computer_use_runtime": "none"}
},
get_llm_tool_manager=lambda: mgr,
)
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))
toolset = FunctionToolExecutor._build_handoff_toolset(run_context, tools=None)
assert toolset is not None
assert isinstance(toolset.get_tool("admin_only_mcp"), _PermissionGuardedTool)
assert toolset.get_tool("transfer_to_child") is None
@pytest.mark.asyncio
async def test_collect_handoff_image_urls_normalizes_filters_and_appends_event_image(
monkeypatch: pytest.MonkeyPatch,

View File

@@ -817,6 +817,57 @@ class TestEnsurePersonaAndSkills:
assert req.func_tool is not None
@pytest.mark.asyncio
async def test_persona_empty_tools_filters_late_builtin_tools(
self, mock_event, mock_context, mock_provider
):
module = ama
persona = {"name": "locked", "prompt": "No tools.", "tools": []}
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=("locked", persona, None, False)
)
mock_context.get_config.return_value = {
"provider_settings": {
"web_search": True,
"websearch_provider": "baidu_ai_search",
}
}
config = module.MainAgentBuildConfig(
tool_call_timeout=60,
provider_settings={
"web_search": True,
"websearch_provider": "baidu_ai_search",
},
computer_use_runtime="none",
)
req = ProviderRequest(prompt="hello")
req.conversation = MagicMock(persona_id="locked", history="[]")
with (
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner
result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=config,
provider=mock_provider,
req=req,
apply_reset=False,
)
assert result is not None
try:
assert result.provider_request.func_tool is None or (
result.provider_request.func_tool.empty()
)
finally:
if result.reset_coro:
result.reset_coro.close()
@pytest.mark.asyncio
async def test_subagent_dedupe_uses_default_persona_tools(
self, mock_event, mock_context