mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-05 12:20:17 +08:00
Compare commits
2 Commits
codex/fix-
...
codex/pr-8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
299b7ad56b | ||
|
|
d23011262e |
@@ -266,8 +266,13 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
# "all tools", including runtime computer-use tools.
|
||||
if tools is None:
|
||||
toolset = ToolSet()
|
||||
for registered_tool in llm_tools.func_list:
|
||||
if isinstance(registered_tool, HandoffTool):
|
||||
handoff_names = {
|
||||
tool.name
|
||||
for tool in tool_mgr.func_list
|
||||
if isinstance(tool, HandoffTool)
|
||||
}
|
||||
for registered_tool in tool_mgr.get_full_tool_set():
|
||||
if registered_tool.name in handoff_names:
|
||||
continue
|
||||
if registered_tool.active:
|
||||
toolset.add_tool(registered_tool)
|
||||
|
||||
@@ -456,10 +456,10 @@ async def _ensure_persona_and_skills(
|
||||
cfg: dict,
|
||||
plugin_context: Context,
|
||||
event: AstrMessageEvent,
|
||||
) -> None:
|
||||
) -> set[str] | None:
|
||||
"""Ensure persona and skills are applied to the request's system prompt or user prompt."""
|
||||
if not req.conversation:
|
||||
return
|
||||
return None
|
||||
|
||||
(
|
||||
persona_id,
|
||||
@@ -514,11 +514,13 @@ async def _ensure_persona_and_skills(
|
||||
|
||||
# inject toolset in the persona
|
||||
if (persona and persona.get("tools") is None) or not persona:
|
||||
persona_allowed_tools = None
|
||||
persona_toolset = tmgr.get_full_tool_set()
|
||||
for tool in list(persona_toolset):
|
||||
if not tool.active:
|
||||
persona_toolset.remove_tool(tool.name)
|
||||
else:
|
||||
persona_allowed_tools = {str(tool_name) for tool_name in persona["tools"]}
|
||||
persona_toolset = ToolSet()
|
||||
if persona["tools"]:
|
||||
for tool_name in persona["tools"]:
|
||||
@@ -599,6 +601,7 @@ async def _ensure_persona_and_skills(
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return persona_allowed_tools
|
||||
|
||||
|
||||
async def _request_img_caption(
|
||||
@@ -931,12 +934,13 @@ async def _decorate_llm_request(
|
||||
plugin_context: Context,
|
||||
config: MainAgentBuildConfig,
|
||||
provider: Provider | None = None,
|
||||
) -> None:
|
||||
) -> set[str] | None:
|
||||
cfg = config.provider_settings or plugin_context.get_config(
|
||||
umo=event.unified_msg_origin
|
||||
).get("provider_settings", {})
|
||||
|
||||
_apply_prompt_prefix(req, cfg)
|
||||
persona_allowed_tools = None
|
||||
|
||||
main_provider_supports_image = provider is not None and _provider_supports_modality(
|
||||
provider, "image"
|
||||
@@ -945,7 +949,9 @@ async def _decorate_llm_request(
|
||||
quote_images_already_captioned = False
|
||||
|
||||
if req.conversation:
|
||||
await _ensure_persona_and_skills(req, cfg, plugin_context, event)
|
||||
persona_allowed_tools = await _ensure_persona_and_skills(
|
||||
req, cfg, plugin_context, event
|
||||
)
|
||||
|
||||
if img_cap_prov_id and req.image_urls and not main_provider_supports_image:
|
||||
await _ensure_img_caption(
|
||||
@@ -974,6 +980,7 @@ async def _decorate_llm_request(
|
||||
tz = plugin_context.get_config().get("timezone")
|
||||
_append_system_reminders(event, req, cfg, tz)
|
||||
_apply_workspace_extra_prompt(event, req)
|
||||
return persona_allowed_tools
|
||||
|
||||
|
||||
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
|
||||
@@ -1502,7 +1509,9 @@ async def build_main_agent(
|
||||
else:
|
||||
return None
|
||||
|
||||
await _decorate_llm_request(event, req, plugin_context, config, provider=provider)
|
||||
persona_allowed_tools = await _decorate_llm_request(
|
||||
event, req, plugin_context, config, provider=provider
|
||||
)
|
||||
|
||||
await _apply_kb(event, req, plugin_context, config)
|
||||
|
||||
@@ -1538,6 +1547,11 @@ async def build_main_agent(
|
||||
)
|
||||
)
|
||||
|
||||
if persona_allowed_tools is not None and req.func_tool:
|
||||
req.func_tool.tools = [
|
||||
tool for tool in req.func_tool.tools if tool.name in persona_allowed_tools
|
||||
]
|
||||
|
||||
fallback_providers = _get_fallback_chat_providers(
|
||||
provider, plugin_context, config.provider_settings
|
||||
)
|
||||
|
||||
@@ -3,9 +3,16 @@ from types import SimpleNamespace
|
||||
import mcp
|
||||
import pytest
|
||||
|
||||
from astrbot.core.agent.agent import Agent
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.provider.func_tool_manager import (
|
||||
FunctionToolManager,
|
||||
_PermissionGuardedTool,
|
||||
)
|
||||
|
||||
|
||||
class _DummyEvent:
|
||||
@@ -29,6 +36,32 @@ def _build_run_context(message_components: list[object] | None = None):
|
||||
return ContextWrapper(context=ctx)
|
||||
|
||||
|
||||
def test_build_handoff_toolset_keeps_permission_guards_for_default_tools():
|
||||
mgr = FunctionToolManager()
|
||||
plugin_tool = FunctionTool(
|
||||
name="admin_only_mcp",
|
||||
description="admin tool",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
)
|
||||
handoff = HandoffTool(Agent(name="child"))
|
||||
mgr.func_list = [plugin_tool, handoff]
|
||||
|
||||
event = _DummyEvent()
|
||||
context = SimpleNamespace(
|
||||
get_config=lambda **_kwargs: {
|
||||
"provider_settings": {"computer_use_runtime": "none"}
|
||||
},
|
||||
get_llm_tool_manager=lambda: mgr,
|
||||
)
|
||||
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))
|
||||
|
||||
toolset = FunctionToolExecutor._build_handoff_toolset(run_context, tools=None)
|
||||
|
||||
assert toolset is not None
|
||||
assert isinstance(toolset.get_tool("admin_only_mcp"), _PermissionGuardedTool)
|
||||
assert toolset.get_tool("transfer_to_child") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_handoff_image_urls_normalizes_filters_and_appends_event_image(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -817,6 +817,57 @@ class TestEnsurePersonaAndSkills:
|
||||
|
||||
assert req.func_tool is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persona_empty_tools_filters_late_builtin_tools(
|
||||
self, mock_event, mock_context, mock_provider
|
||||
):
|
||||
module = ama
|
||||
persona = {"name": "locked", "prompt": "No tools.", "tools": []}
|
||||
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
|
||||
return_value=("locked", persona, None, False)
|
||||
)
|
||||
mock_context.get_config.return_value = {
|
||||
"provider_settings": {
|
||||
"web_search": True,
|
||||
"websearch_provider": "baidu_ai_search",
|
||||
}
|
||||
}
|
||||
config = module.MainAgentBuildConfig(
|
||||
tool_call_timeout=60,
|
||||
provider_settings={
|
||||
"web_search": True,
|
||||
"websearch_provider": "baidu_ai_search",
|
||||
},
|
||||
computer_use_runtime="none",
|
||||
)
|
||||
req = ProviderRequest(prompt="hello")
|
||||
req.conversation = MagicMock(persona_id="locked", history="[]")
|
||||
|
||||
with (
|
||||
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
|
||||
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
|
||||
):
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.reset = AsyncMock()
|
||||
mock_runner_cls.return_value = mock_runner
|
||||
|
||||
result = await module.build_main_agent(
|
||||
event=mock_event,
|
||||
plugin_context=mock_context,
|
||||
config=config,
|
||||
provider=mock_provider,
|
||||
req=req,
|
||||
apply_reset=False,
|
||||
)
|
||||
assert result is not None
|
||||
try:
|
||||
assert result.provider_request.func_tool is None or (
|
||||
result.provider_request.func_tool.empty()
|
||||
)
|
||||
finally:
|
||||
if result.reset_coro:
|
||||
result.reset_coro.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_dedupe_uses_default_persona_tools(
|
||||
self, mock_event, mock_context
|
||||
|
||||
Reference in New Issue
Block a user