mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 01:10:21 +08:00
fix: keep system tools with persona tool lists
This commit is contained in:
@@ -457,10 +457,10 @@ async def _ensure_persona_and_skills(
|
||||
cfg: dict,
|
||||
plugin_context: Context,
|
||||
event: AstrMessageEvent,
|
||||
) -> set[str] | None:
|
||||
) -> None:
|
||||
"""Ensure persona and skills are applied to the request's system prompt or user prompt."""
|
||||
if not req.conversation:
|
||||
return None
|
||||
return
|
||||
|
||||
(
|
||||
persona_id,
|
||||
@@ -527,13 +527,11 @@ async def _ensure_persona_and_skills(
|
||||
|
||||
# inject toolset in the persona
|
||||
if (persona and persona.get("tools") is None) or not persona:
|
||||
persona_allowed_tools = None
|
||||
persona_toolset = tmgr.get_full_tool_set()
|
||||
for tool in list(persona_toolset):
|
||||
if not tool.active:
|
||||
persona_toolset.remove_tool(tool.name)
|
||||
else:
|
||||
persona_allowed_tools = {str(tool_name) for tool_name in persona["tools"]}
|
||||
persona_toolset = ToolSet()
|
||||
if persona["tools"]:
|
||||
for tool_name in persona["tools"]:
|
||||
@@ -614,7 +612,6 @@ async def _ensure_persona_and_skills(
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return persona_allowed_tools
|
||||
|
||||
|
||||
async def _request_img_caption(
|
||||
@@ -947,13 +944,12 @@ async def _decorate_llm_request(
|
||||
plugin_context: Context,
|
||||
config: MainAgentBuildConfig,
|
||||
provider: Provider | None = None,
|
||||
) -> set[str] | None:
|
||||
) -> None:
|
||||
cfg = config.provider_settings or plugin_context.get_config(
|
||||
umo=event.unified_msg_origin
|
||||
).get("provider_settings", {})
|
||||
|
||||
_apply_prompt_prefix(req, cfg)
|
||||
persona_allowed_tools = None
|
||||
|
||||
main_provider_supports_image = provider is not None and _provider_supports_modality(
|
||||
provider, "image"
|
||||
@@ -962,9 +958,7 @@ async def _decorate_llm_request(
|
||||
quote_images_already_captioned = False
|
||||
|
||||
if req.conversation:
|
||||
persona_allowed_tools = await _ensure_persona_and_skills(
|
||||
req, cfg, plugin_context, event
|
||||
)
|
||||
await _ensure_persona_and_skills(req, cfg, plugin_context, event)
|
||||
|
||||
if img_cap_prov_id and req.image_urls and not main_provider_supports_image:
|
||||
await _ensure_img_caption(
|
||||
@@ -993,7 +987,6 @@ async def _decorate_llm_request(
|
||||
tz = plugin_context.get_config().get("timezone")
|
||||
_append_system_reminders(event, req, cfg, tz)
|
||||
_apply_workspace_extra_prompt(event, req)
|
||||
return persona_allowed_tools
|
||||
|
||||
|
||||
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
|
||||
@@ -1522,9 +1515,7 @@ async def build_main_agent(
|
||||
else:
|
||||
return None
|
||||
|
||||
persona_allowed_tools = await _decorate_llm_request(
|
||||
event, req, plugin_context, config, provider=provider
|
||||
)
|
||||
await _decorate_llm_request(event, req, plugin_context, config, provider=provider)
|
||||
|
||||
await _apply_kb(event, req, plugin_context, config)
|
||||
|
||||
@@ -1560,11 +1551,6 @@ async def build_main_agent(
|
||||
)
|
||||
)
|
||||
|
||||
if persona_allowed_tools is not None and req.func_tool:
|
||||
req.func_tool.tools = [
|
||||
tool for tool in req.func_tool.tools if tool.name in persona_allowed_tools
|
||||
]
|
||||
|
||||
fallback_providers = _get_fallback_chat_providers(
|
||||
provider, plugin_context, config.provider_settings
|
||||
)
|
||||
|
||||
@@ -1009,7 +1009,7 @@ class TestEnsurePersonaAndSkills:
|
||||
assert req.func_tool is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persona_empty_tools_filters_late_builtin_tools(
|
||||
async def test_persona_empty_tools_keeps_late_builtin_tools(
|
||||
self, mock_event, mock_context, mock_provider
|
||||
):
|
||||
module = ama
|
||||
@@ -1017,6 +1017,7 @@ class TestEnsurePersonaAndSkills:
|
||||
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
|
||||
return_value=("locked", persona, None, False)
|
||||
)
|
||||
mock_event.platform_meta.support_proactive_message = False
|
||||
mock_context.get_config.return_value = {
|
||||
"provider_settings": {
|
||||
"web_search": True,
|
||||
@@ -1030,6 +1031,7 @@ class TestEnsurePersonaAndSkills:
|
||||
"websearch_provider": "baidu_ai_search",
|
||||
},
|
||||
computer_use_runtime="none",
|
||||
add_cron_tools=False,
|
||||
)
|
||||
req = ProviderRequest(prompt="hello")
|
||||
req.conversation = MagicMock(persona_id="locked", history="[]")
|
||||
@@ -1052,9 +1054,52 @@ class TestEnsurePersonaAndSkills:
|
||||
)
|
||||
assert result is not None
|
||||
try:
|
||||
assert result.provider_request.func_tool is None or (
|
||||
result.provider_request.func_tool.empty()
|
||||
assert result.provider_request.func_tool is not None
|
||||
assert result.provider_request.func_tool.names() == ["web_search_baidu"]
|
||||
finally:
|
||||
if result.reset_coro:
|
||||
result.reset_coro.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persona_empty_tools_keeps_local_runtime_builtin_tools(
|
||||
self, mock_event, mock_context, mock_provider
|
||||
):
|
||||
module = ama
|
||||
persona = {"name": "locked", "prompt": "No tools.", "tools": []}
|
||||
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
|
||||
return_value=("locked", persona, None, False)
|
||||
)
|
||||
mock_event.platform_meta.support_proactive_message = False
|
||||
config = module.MainAgentBuildConfig(
|
||||
tool_call_timeout=60,
|
||||
computer_use_runtime="local",
|
||||
add_cron_tools=False,
|
||||
)
|
||||
req = ProviderRequest(prompt="hello")
|
||||
req.conversation = MagicMock(persona_id="locked", history="[]")
|
||||
|
||||
with (
|
||||
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
|
||||
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
|
||||
):
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.reset = AsyncMock()
|
||||
mock_runner_cls.return_value = mock_runner
|
||||
|
||||
result = await module.build_main_agent(
|
||||
event=mock_event,
|
||||
plugin_context=mock_context,
|
||||
config=config,
|
||||
provider=mock_provider,
|
||||
req=req,
|
||||
apply_reset=False,
|
||||
)
|
||||
assert result is not None
|
||||
try:
|
||||
assert result.provider_request.func_tool is not None
|
||||
tool_names = result.provider_request.func_tool.names()
|
||||
assert "astrbot_execute_shell" in tool_names
|
||||
assert "astrbot_execute_python" in tool_names
|
||||
finally:
|
||||
if result.reset_coro:
|
||||
result.reset_coro.close()
|
||||
|
||||
Reference in New Issue
Block a user