fix: keep system tools with persona tool lists

This commit is contained in:
Soulter
2026-06-20 00:57:43 +08:00
parent a7533aacda
commit 319b2570be
2 changed files with 53 additions and 22 deletions

View File

@@ -457,10 +457,10 @@ async def _ensure_persona_and_skills(
cfg: dict,
plugin_context: Context,
event: AstrMessageEvent,
) -> set[str] | None:
) -> None:
"""Ensure persona and skills are applied to the request's system prompt or user prompt."""
if not req.conversation:
return None
return
(
persona_id,
@@ -527,13 +527,11 @@ async def _ensure_persona_and_skills(
# inject toolset in the persona
if (persona and persona.get("tools") is None) or not persona:
persona_allowed_tools = None
persona_toolset = tmgr.get_full_tool_set()
for tool in list(persona_toolset):
if not tool.active:
persona_toolset.remove_tool(tool.name)
else:
persona_allowed_tools = {str(tool_name) for tool_name in persona["tools"]}
persona_toolset = ToolSet()
if persona["tools"]:
for tool_name in persona["tools"]:
@@ -614,7 +612,6 @@ async def _ensure_persona_and_skills(
)
except Exception:
pass
return persona_allowed_tools
async def _request_img_caption(
@@ -947,13 +944,12 @@ async def _decorate_llm_request(
plugin_context: Context,
config: MainAgentBuildConfig,
provider: Provider | None = None,
) -> set[str] | None:
) -> None:
cfg = config.provider_settings or plugin_context.get_config(
umo=event.unified_msg_origin
).get("provider_settings", {})
_apply_prompt_prefix(req, cfg)
persona_allowed_tools = None
main_provider_supports_image = provider is not None and _provider_supports_modality(
provider, "image"
@@ -962,9 +958,7 @@ async def _decorate_llm_request(
quote_images_already_captioned = False
if req.conversation:
persona_allowed_tools = await _ensure_persona_and_skills(
req, cfg, plugin_context, event
)
await _ensure_persona_and_skills(req, cfg, plugin_context, event)
if img_cap_prov_id and req.image_urls and not main_provider_supports_image:
await _ensure_img_caption(
@@ -993,7 +987,6 @@ async def _decorate_llm_request(
tz = plugin_context.get_config().get("timezone")
_append_system_reminders(event, req, cfg, tz)
_apply_workspace_extra_prompt(event, req)
return persona_allowed_tools
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
@@ -1522,9 +1515,7 @@ async def build_main_agent(
else:
return None
persona_allowed_tools = await _decorate_llm_request(
event, req, plugin_context, config, provider=provider
)
await _decorate_llm_request(event, req, plugin_context, config, provider=provider)
await _apply_kb(event, req, plugin_context, config)
@@ -1560,11 +1551,6 @@ async def build_main_agent(
)
)
if persona_allowed_tools is not None and req.func_tool:
req.func_tool.tools = [
tool for tool in req.func_tool.tools if tool.name in persona_allowed_tools
]
fallback_providers = _get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
)

View File

@@ -1009,7 +1009,7 @@ class TestEnsurePersonaAndSkills:
assert req.func_tool is not None
@pytest.mark.asyncio
async def test_persona_empty_tools_filters_late_builtin_tools(
async def test_persona_empty_tools_keeps_late_builtin_tools(
self, mock_event, mock_context, mock_provider
):
module = ama
@@ -1017,6 +1017,7 @@ class TestEnsurePersonaAndSkills:
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=("locked", persona, None, False)
)
mock_event.platform_meta.support_proactive_message = False
mock_context.get_config.return_value = {
"provider_settings": {
"web_search": True,
@@ -1030,6 +1031,7 @@ class TestEnsurePersonaAndSkills:
"websearch_provider": "baidu_ai_search",
},
computer_use_runtime="none",
add_cron_tools=False,
)
req = ProviderRequest(prompt="hello")
req.conversation = MagicMock(persona_id="locked", history="[]")
@@ -1052,9 +1054,52 @@ class TestEnsurePersonaAndSkills:
)
assert result is not None
try:
assert result.provider_request.func_tool is None or (
result.provider_request.func_tool.empty()
assert result.provider_request.func_tool is not None
assert result.provider_request.func_tool.names() == ["web_search_baidu"]
finally:
if result.reset_coro:
result.reset_coro.close()
@pytest.mark.asyncio
async def test_persona_empty_tools_keeps_local_runtime_builtin_tools(
self, mock_event, mock_context, mock_provider
):
module = ama
persona = {"name": "locked", "prompt": "No tools.", "tools": []}
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=("locked", persona, None, False)
)
mock_event.platform_meta.support_proactive_message = False
config = module.MainAgentBuildConfig(
tool_call_timeout=60,
computer_use_runtime="local",
add_cron_tools=False,
)
req = ProviderRequest(prompt="hello")
req.conversation = MagicMock(persona_id="locked", history="[]")
with (
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner
result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=config,
provider=mock_provider,
req=req,
apply_reset=False,
)
assert result is not None
try:
assert result.provider_request.func_tool is not None
tool_names = result.provider_request.func_tool.names()
assert "astrbot_execute_shell" in tool_names
assert "astrbot_execute_python" in tool_names
finally:
if result.reset_coro:
result.reset_coro.close()