Compare commits

...

19 Commits

Author SHA1 Message Date
Soulter
1fe0ed1f23 chore: remove lxml dependencies from project 2026-04-10 17:49:32 +08:00
Soulter
1745e9c4fb feat: implement handling for large tool results with overflow file writing and read tool integration 2026-04-10 17:26:17 +08:00
Soulter
3acda6f77a feat: update converted text notice to suggest using grep for large files 2026-04-10 16:46:13 +08:00
Soulter
cff148860a feat: enhance file reading capabilities to support PDF and DOCX parsing, including workspace storage for long documents 2026-04-10 16:31:53 +08:00
Soulter
013ecacee9 Merge remote-tracking branch 'origin/master' into feat/fs-grep-read-edit 2026-04-10 15:47:18 +08:00
Soulter
7bf1d19332 perf: shell executed in workspace dir in local env 2026-04-10 15:38:40 +08:00
Soulter
5f049f2bb5 feat: add ripgrep installation to Dockerfile 2026-04-10 11:35:59 +08:00
Soulter
add5db6748 feat: add workspace extra prompt handling in message processing 2026-04-10 00:01:50 +08:00
Soulter
5ca2483a43 feat: add tooltip for disabled builtin tools and update localization strings 2026-04-09 23:56:28 +08:00
Soulter
adc01e0c9d feat: supports to display enabled builtin tools in configs 2026-04-09 23:49:22 +08:00
Soulter
efc93a37b1 refactor: remove unused plugin_context parameter from _apply_sandbox_tools 2026-04-08 15:34:18 +08:00
Soulter
56a099bf90 refactor: move computer tools to builtin tools registry 2026-04-08 15:29:00 +08:00
Soulter
006aedbd24 Merge remote-tracking branch 'origin/master' into feat/fs-grep-read-edit 2026-04-08 14:52:06 +08:00
Soulter
86ac40d944 feat: add file read utilities and integrate with filesystem tools 2026-04-08 00:34:38 +08:00
Soulter
20fed8ab62 feat: implement file read tool with support for text and image files, including validation for large files 2026-04-07 23:52:23 +08:00
Soulter
a539deec91 feat: remove redundant safe path tests from security restrictions 2026-04-07 21:44:34 +08:00
Soulter
11282c769f feat: enhance tool prompt formatting and add escaped text decoding for file editing 2026-04-07 21:32:11 +08:00
Soulter
8e7d995fec feat: add file write tool and enhance file read functionality 2026-04-07 20:45:40 +08:00
Soulter
fcf1b08455 feat: filesystem grep, read, edit file 2026-04-07 01:20:26 +08:00
49 changed files with 3444 additions and 632 deletions

View File

@@ -16,6 +16,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
gnupg \
git \
ripgrep \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y --no-install-recommends nodejs \
&& apt-get clean \

View File

@@ -4,9 +4,11 @@ import sys
import time
import traceback
import typing as T
import uuid
from collections.abc import AsyncIterator
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from mcp.types import (
BlobResourceContents,
@@ -25,7 +27,7 @@ from tenacity import (
from astrbot import logger
from astrbot.core.agent.message import ImageURLPart, TextPart, ThinkPart
from astrbot.core.agent.tool import ToolSet
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_image_cache import tool_image_cache
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.message.components import Json
@@ -45,7 +47,7 @@ from astrbot.core.provider.provider import Provider
from ..context.compressor import ContextCompressor
from ..context.config import ContextConfig
from ..context.manager import ContextManager
from ..context.token_counter import TokenCounter
from ..context.token_counter import EstimateTokenCounter, TokenCounter
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
@@ -97,6 +99,8 @@ ToolExecutorResultT = T.TypeVar("ToolExecutorResultT")
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
TOOL_RESULT_MAX_ESTIMATED_TOKENS = 27_500
TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS = 7000
EMPTY_OUTPUT_RETRY_ATTEMPTS = 3
EMPTY_OUTPUT_RETRY_WAIT_MIN_S = 1
EMPTY_OUTPUT_RETRY_WAIT_MAX_S = 4
@@ -151,6 +155,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
"Otherwise, change strategy, adjust arguments, or explain the limitation "
"to the user."
)
TOOL_RESULT_OVERFLOW_NOTICE_TEMPLATE = (
"Truncated tool output preview shown above. "
"The tool output was too large to include directly and was written to "
"`{overflow_path}`. Use {read_tool_hint} with a narrower window to inspect it."
)
def _get_persona_custom_error_message(self) -> str | None:
"""Read persona-level custom error message from event extras when available."""
@@ -206,6 +215,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
custom_compressor: ContextCompressor | None = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
tool_result_overflow_dir: str | None = None,
read_tool: FunctionTool | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
@@ -217,6 +228,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
self.tool_result_overflow_dir = tool_result_overflow_dir
self.read_tool = read_tool
self._tool_result_token_counter = EstimateTokenCounter()
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
@@ -298,6 +312,103 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.stats = AgentStats()
self.stats.start_time = time.time()
def _read_tool_hint(self) -> str:
if self.read_tool is not None:
return f"`{self.read_tool.name}`"
return "the available file-read tool"
async def _write_tool_result_overflow_file(
self,
*,
tool_call_id: str,
content: str,
) -> str:
if self.tool_result_overflow_dir is None:
raise ValueError("tool_result_overflow_dir is not configured")
overflow_dir = Path(self.tool_result_overflow_dir).resolve(strict=False)
safe_tool_call_id = (
"".join(
ch if ch.isalnum() or ch in {"-", "_", "."} else "_"
for ch in tool_call_id
).strip("._")
or "tool_call"
)
file_name = f"{safe_tool_call_id}_{uuid.uuid4().hex[:8]}.txt"
overflow_path = overflow_dir / file_name
def _run() -> str:
overflow_dir.mkdir(parents=True, exist_ok=True)
overflow_path.write_text(content, encoding="utf-8")
return str(overflow_path)
return await asyncio.to_thread(_run)
async def _materialize_large_tool_result(
self,
*,
tool_call_id: str,
content: str,
) -> str:
if self.tool_result_overflow_dir is None or self.read_tool is None:
return content
estimated_tokens = self._tool_result_token_counter.count_tokens(
[Message(role="tool", content=content, tool_call_id=tool_call_id)]
)
if estimated_tokens <= self.TOOL_RESULT_MAX_ESTIMATED_TOKENS:
return content
preview = self._truncate_tool_result_preview(content, tool_call_id=tool_call_id)
try:
overflow_path = await self._write_tool_result_overflow_file(
tool_call_id=tool_call_id,
content=content,
)
except Exception as exc:
logger.warning(
"Failed to spill oversized tool result for %s: %s",
tool_call_id,
exc,
exc_info=True,
)
error_notice = (
"Tool output exceeded the inline result limit "
f"({estimated_tokens} estimated tokens > "
f"{self.TOOL_RESULT_MAX_ESTIMATED_TOKENS}) and could not be written "
f"to `{self.tool_result_overflow_dir}`: {exc}"
)
if not preview:
return error_notice
return f"{preview}\n\n{error_notice}"
notice = self.TOOL_RESULT_OVERFLOW_NOTICE_TEMPLATE.format(
overflow_path=overflow_path,
read_tool_hint=self._read_tool_hint(),
)
if not preview:
return notice
return f"{preview}\n\n{notice}"
def _truncate_tool_result_preview(
self,
content: str,
*,
tool_call_id: str,
) -> str:
preview = content
while preview:
estimated_tokens = self._tool_result_token_counter.count_tokens(
[Message(role="tool", content=preview, tool_call_id=tool_call_id)]
)
if estimated_tokens <= self.TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS:
return preview
next_len = len(preview) // 2
if next_len <= 0:
break
preview = preview[:next_len]
return preview
async def _iter_llm_responses(
self, *, include_model: bool = True
) -> T.AsyncGenerator[LLMResponse, None]:
@@ -933,9 +1044,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
"The tool has returned a data type that is not supported."
)
if result_parts:
inline_result = "\n\n".join(result_parts)
inline_result = await self._materialize_large_tool_result(
tool_call_id=func_tool_id,
content=inline_result,
)
_append_tool_call_result(
func_tool_id,
"\n\n".join(result_parts)
inline_result
+ self._build_repeated_tool_call_guidance(
func_tool_name, tool_call_streak
),

View File

@@ -19,12 +19,6 @@ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.astr_main_agent_resources import (
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PYTHON_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
from astrbot.core.message.components import Image
@@ -36,6 +30,17 @@ from astrbot.core.message.message_event_result import (
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.tools.computer_tools import (
ExecuteShellTool,
FileDownloadTool,
FileEditTool,
FileReadTool,
FileUploadTool,
FileWriteTool,
GrepTool,
LocalPythonTool,
PythonTool,
)
from astrbot.core.tools.message_tools import SendMessageToUserTool
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.history_saver import persist_agent_history
@@ -177,18 +182,44 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
return
@classmethod
def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
def _get_runtime_computer_tools(
cls,
runtime: str,
tool_mgr,
) -> dict[str, FunctionTool]:
if runtime == "sandbox":
shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool)
python_tool = tool_mgr.get_builtin_tool(PythonTool)
upload_tool = tool_mgr.get_builtin_tool(FileUploadTool)
download_tool = tool_mgr.get_builtin_tool(FileDownloadTool)
read_tool = tool_mgr.get_builtin_tool(FileReadTool)
write_tool = tool_mgr.get_builtin_tool(FileWriteTool)
edit_tool = tool_mgr.get_builtin_tool(FileEditTool)
grep_tool = tool_mgr.get_builtin_tool(GrepTool)
return {
EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL,
PYTHON_TOOL.name: PYTHON_TOOL,
FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL,
FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL,
shell_tool.name: shell_tool,
python_tool.name: python_tool,
upload_tool.name: upload_tool,
download_tool.name: download_tool,
read_tool.name: read_tool,
write_tool.name: write_tool,
edit_tool.name: edit_tool,
grep_tool.name: grep_tool,
}
if runtime == "local":
shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool)
python_tool = tool_mgr.get_builtin_tool(LocalPythonTool)
read_tool = tool_mgr.get_builtin_tool(FileReadTool)
write_tool = tool_mgr.get_builtin_tool(FileWriteTool)
edit_tool = tool_mgr.get_builtin_tool(FileEditTool)
grep_tool = tool_mgr.get_builtin_tool(GrepTool)
return {
LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL,
shell_tool.name: shell_tool,
python_tool.name: python_tool,
read_tool.name: read_tool,
write_tool.name: write_tool,
edit_tool.name: edit_tool,
grep_tool.name: grep_tool,
}
return {}
@@ -203,7 +234,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
cfg = ctx.get_config(umo=event.unified_msg_origin)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
tool_mgr = (
ctx.get_llm_tool_manager()
if hasattr(ctx, "get_llm_tool_manager")
else llm_tools
)
runtime_computer_tools = cls._get_runtime_computer_tools(
runtime,
tool_mgr,
)
# Keep persona semantics aligned with the main agent: tools=None means
# "all tools", including runtime computer-use tools.

View File

@@ -9,6 +9,7 @@ import platform
import zoneinfo
from collections.abc import Coroutine
from dataclasses import dataclass, field
from pathlib import Path
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
@@ -20,30 +21,10 @@ from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.astr_agent_run_util import AgentRunner
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.astr_main_agent_resources import (
ANNOTATE_EXECUTION_TOOL,
BROWSER_BATCH_EXEC_TOOL,
BROWSER_EXEC_TOOL,
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
CREATE_SKILL_CANDIDATE_TOOL,
CREATE_SKILL_PAYLOAD_TOOL,
EVALUATE_SKILL_CANDIDATE_TOOL,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
GET_EXECUTION_HISTORY_TOOL,
GET_SKILL_PAYLOAD_TOOL,
LIST_SKILL_CANDIDATES_TOOL,
LIST_SKILL_RELEASES_TOOL,
LIVE_MODE_SYSTEM_PROMPT,
LLM_SAFETY_MODE_SYSTEM_PROMPT,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PROMOTE_SKILL_CANDIDATE_TOOL,
PYTHON_TOOL,
ROLLBACK_SKILL_RELEASE_TOOL,
RUN_BROWSER_SKILL_TOOL,
SANDBOX_MODE_PROMPT,
SYNC_SKILL_RELEASE_TOOL,
TOOL_CALL_PROMPT,
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
)
@@ -56,9 +37,36 @@ from astrbot.core.persona_error_reply import (
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import star_map
from astrbot.core.tools.computer_tools import (
AnnotateExecutionTool,
BrowserBatchExecTool,
BrowserExecTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
ExecuteShellTool,
FileDownloadTool,
FileEditTool,
FileReadTool,
FileUploadTool,
FileWriteTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
GrepTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
LocalPythonTool,
PromoteSkillCandidateTool,
PythonTool,
RollbackSkillReleaseTool,
RunBrowserSkillTool,
SyncSkillReleaseTool,
normalize_umo_for_workspace,
)
from astrbot.core.tools.cron_tools import FutureTaskTool
from astrbot.core.tools.knowledge_base_tools import (
KnowledgeBaseQueryTool,
@@ -73,6 +81,10 @@ from astrbot.core.tools.web_search_tools import (
TavilyWebSearchTool,
normalize_legacy_web_search_config,
)
from astrbot.core.utils.astrbot_path import (
get_astrbot_system_tmp_path,
get_astrbot_workspaces_path,
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.media_utils import (
@@ -290,11 +302,54 @@ def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None:
req.prompt = f"{prefix}{req.prompt}"
def _apply_local_env_tools(req: ProviderRequest) -> None:
def _get_workspace_path_for_umo(umo: str) -> Path:
normalized_umo = normalize_umo_for_workspace(umo)
return Path(get_astrbot_workspaces_path()) / normalized_umo
def _apply_workspace_extra_prompt(
event: AstrMessageEvent,
req: ProviderRequest,
) -> None:
extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / (
"EXTRA_PROMPT.md"
)
if not extra_prompt_path.is_file():
return
try:
extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip()
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to read workspace extra prompt for umo=%s from %s: %s",
event.unified_msg_origin,
extra_prompt_path,
exc,
)
return
if not extra_prompt:
return
req.system_prompt = (
f"{req.system_prompt or ''}\n"
"[Workspace Extra Prompt]\n"
"The following instructions are loaded from the current workspace "
"`EXTRA_PROMPT.md` file.\n"
f"{extra_prompt}\n"
)
def _apply_local_env_tools(req: ProviderRequest, plugin_context: Context) -> None:
if req.func_tool is None:
req.func_tool = ToolSet()
req.func_tool.add_tool(LOCAL_EXECUTE_SHELL_TOOL)
req.func_tool.add_tool(LOCAL_PYTHON_TOOL)
tool_mgr = plugin_context.get_llm_tool_manager()
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(LocalPythonTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool))
req.system_prompt = f"{req.system_prompt or ''}\n{_build_local_mode_prompt()}\n"
@@ -765,6 +820,7 @@ async def _decorate_llm_request(
if tz is None:
tz = plugin_context.get_config().get("timezone")
_append_system_reminders(event, req, cfg, tz)
_apply_workspace_extra_prompt(event, req)
def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
@@ -981,7 +1037,9 @@ def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) -
def _apply_sandbox_tools(
config: MainAgentBuildConfig, req: ProviderRequest, session_id: str
config: MainAgentBuildConfig,
req: ProviderRequest,
session_id: str,
) -> None:
if req.func_tool is None:
req.func_tool = ToolSet()
@@ -997,10 +1055,15 @@ def _apply_sandbox_tools(
os.environ["SHIPYARD_ENDPOINT"] = ep
os.environ["SHIPYARD_ACCESS_TOKEN"] = at
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
req.func_tool.add_tool(PYTHON_TOOL)
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
tool_mgr = llm_tools
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(PythonTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileUploadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileDownloadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool))
if booter == "shipyard_neo":
# Neo-specific path rule: filesystem tools operate relative to sandbox
# workspace root. Do not prepend "/workspace".
@@ -1036,22 +1099,22 @@ def _apply_sandbox_tools(
# Browser tools: only register if profile supports browser
# (or if capabilities are unknown because sandbox hasn't booted yet)
if sandbox_capabilities is None or "browser" in sandbox_capabilities:
req.func_tool.add_tool(BROWSER_EXEC_TOOL)
req.func_tool.add_tool(BROWSER_BATCH_EXEC_TOOL)
req.func_tool.add_tool(RUN_BROWSER_SKILL_TOOL)
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserExecTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserBatchExecTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(RunBrowserSkillTool))
# Neo-specific tools (always available for shipyard_neo)
req.func_tool.add_tool(GET_EXECUTION_HISTORY_TOOL)
req.func_tool.add_tool(ANNOTATE_EXECUTION_TOOL)
req.func_tool.add_tool(CREATE_SKILL_PAYLOAD_TOOL)
req.func_tool.add_tool(GET_SKILL_PAYLOAD_TOOL)
req.func_tool.add_tool(CREATE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(LIST_SKILL_CANDIDATES_TOOL)
req.func_tool.add_tool(EVALUATE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(PROMOTE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(LIST_SKILL_RELEASES_TOOL)
req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL)
req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL)
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetExecutionHistoryTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(AnnotateExecutionTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillPayloadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetSkillPayloadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillCandidateTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillCandidatesTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(EvaluateSkillCandidateTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(PromoteSkillCandidateTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillReleasesTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(RollbackSkillReleaseTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(SyncSkillReleaseTool))
req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n"
@@ -1341,7 +1404,7 @@ async def build_main_agent(
if config.computer_use_runtime == "sandbox":
_apply_sandbox_tools(config, req, req.session_id)
elif config.computer_use_runtime == "local":
_apply_local_env_tools(req)
_apply_local_env_tools(req, plugin_context)
agent_runner = AgentRunner()
astr_agent_ctx = AstrAgentContext(
@@ -1377,6 +1440,15 @@ async def build_main_agent(
if config.tool_schema_mode == "full"
else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE
)
if config.computer_use_runtime == "local":
tool_prompt += (
f"\nCurrent workspace you can use: "
f"`{_get_workspace_path_for_umo(event.unified_msg_origin)}`\n"
"Unless the user explicitly specifies a different directory, "
"perform all file-related operations in this workspace.\n"
)
req.system_prompt += f"\n{tool_prompt}\n"
action_type = event.get_extra("action_type")
@@ -1402,6 +1474,14 @@ async def build_main_agent(
fallback_providers=_get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
),
tool_result_overflow_dir=(
get_astrbot_system_tmp_path()
if req.func_tool and req.func_tool.get_tool("astrbot_file_read_tool")
else None
),
read_tool=(
req.func_tool.get_tool("astrbot_file_read_tool") if req.func_tool else None
),
)
if apply_reset:

View File

@@ -1,27 +1,5 @@
import base64
from astrbot.core.computer.tools import (
AnnotateExecutionTool,
BrowserBatchExecTool,
BrowserExecTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
ExecuteShellTool,
FileDownloadTool,
FileUploadTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
LocalPythonTool,
PromoteSkillCandidateTool,
PythonTool,
RollbackSkillReleaseTool,
RunBrowserSkillTool,
SyncSkillReleaseTool,
)
LLM_SAFETY_MODE_SYSTEM_PROMPT = """You are running in Safe Mode.
Rules:
@@ -130,28 +108,6 @@ BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT = (
"{background_task_result}"
)
EXECUTE_SHELL_TOOL = ExecuteShellTool()
LOCAL_EXECUTE_SHELL_TOOL = ExecuteShellTool(is_local=True)
PYTHON_TOOL = PythonTool()
LOCAL_PYTHON_TOOL = LocalPythonTool()
FILE_UPLOAD_TOOL = FileUploadTool()
FILE_DOWNLOAD_TOOL = FileDownloadTool()
BROWSER_EXEC_TOOL = BrowserExecTool()
BROWSER_BATCH_EXEC_TOOL = BrowserBatchExecTool()
RUN_BROWSER_SKILL_TOOL = RunBrowserSkillTool()
GET_EXECUTION_HISTORY_TOOL = GetExecutionHistoryTool()
ANNOTATE_EXECUTION_TOOL = AnnotateExecutionTool()
CREATE_SKILL_PAYLOAD_TOOL = CreateSkillPayloadTool()
GET_SKILL_PAYLOAD_TOOL = GetSkillPayloadTool()
CREATE_SKILL_CANDIDATE_TOOL = CreateSkillCandidateTool()
LIST_SKILL_CANDIDATES_TOOL = ListSkillCandidatesTool()
EVALUATE_SKILL_CANDIDATE_TOOL = EvaluateSkillCandidateTool()
PROMOTE_SKILL_CANDIDATE_TOOL = PromoteSkillCandidateTool()
LIST_SKILL_RELEASES_TOOL = ListSkillReleasesTool()
ROLLBACK_SKILL_RELEASE_TOOL = RollbackSkillReleaseTool()
SYNC_SKILL_RELEASE_TOOL = SyncSkillReleaseTool()
# we prevent astrbot from connecting to known malicious hosts
# these hosts are base64 encoded
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}

View File

@@ -4,7 +4,7 @@ from typing import Any
import aiohttp
import boxlite
from shipyard.filesystem import FileSystemComponent as ShipyardFileSystemComponent
from shipyard import FileSystemComponent as ShipyardFileSystemComponent
from shipyard.python import PythonComponent as ShipyardPythonComponent
from shipyard.shell import ShellComponent as ShipyardShellComponent
@@ -12,6 +12,7 @@ from astrbot.api import logger
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .shipyard import ShipyardFileSystemWrapper
class MockShipyardSandboxClient:
@@ -150,11 +151,6 @@ class BoxliteBooter(ComputerBooter):
self.mocked = MockShipyardSandboxClient(
sb_url=f"http://127.0.0.1:{random_port}"
)
self._fs = ShipyardFileSystemComponent(
client=self.mocked, # type: ignore
ship_id=self.box.id,
session_id=session_id,
)
self._python = ShipyardPythonComponent(
client=self.mocked, # type: ignore
ship_id=self.box.id,
@@ -165,6 +161,14 @@ class BoxliteBooter(ComputerBooter):
ship_id=self.box.id,
session_id=session_id,
)
self._ship_fs = ShipyardFileSystemComponent(
client=self.mocked, # type: ignore
ship_id=self.box.id,
session_id=session_id,
)
self._fs = ShipyardFileSystemWrapper(
_shipyard_fs=self._ship_fs, _shipyard_shell=self._shell
)
await self.mocked.wait_healthy(self.box.id, session_id)

View File

@@ -9,15 +9,18 @@ import sys
from dataclasses import dataclass
from typing import Any
from python_ripgrep import search
from astrbot.api import logger
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_root,
get_astrbot_temp_path,
from astrbot.core.computer.file_read_utils import (
detect_text_encoding,
read_local_text_range_sync,
)
from astrbot.core.utils.astrbot_path import get_astrbot_root
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .shipyard_search_file_util import _truncate_long_lines
_BLOCKED_COMMAND_PATTERNS = [
" rm -rf ",
@@ -41,18 +44,6 @@ def _is_safe_command(command: str) -> bool:
return not any(pat in cmd for pat in _BLOCKED_COMMAND_PATTERNS)
def _ensure_safe_path(path: str) -> str:
abs_path = os.path.abspath(path)
allowed_roots = [
os.path.abspath(get_astrbot_root()),
os.path.abspath(get_astrbot_data_path()),
os.path.abspath(get_astrbot_temp_path()),
]
if not any(abs_path.startswith(root) for root in allowed_roots):
raise PermissionError("Path is outside the allowed computer roots.")
return abs_path
def _decode_bytes_with_fallback(
output: bytes | None,
*,
@@ -110,7 +101,7 @@ class LocalShellComponent(ShellComponent):
run_env = os.environ.copy()
if env:
run_env.update({str(k): str(v) for k, v in env.items()})
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
working_dir = os.path.abspath(cwd) if cwd else get_astrbot_root()
if background:
# `command` is intentionally executed through the current shell so
# local computer-use behavior matches existing tool semantics.
@@ -186,7 +177,7 @@ class LocalFileSystemComponent(FileSystemComponent):
self, path: str, content: str = "", mode: int = 0o644
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, "w", encoding="utf-8") as f:
f.write(content)
@@ -195,16 +186,85 @@ class LocalFileSystemComponent(FileSystemComponent):
return await asyncio.to_thread(_run)
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
with open(abs_path, "rb") as f:
raw_content = f.read()
content = _decode_bytes_with_fallback(
raw_content,
preferred_encoding=encoding,
abs_path = os.path.abspath(path)
detected_encoding = encoding
if encoding == "utf-8":
with open(abs_path, "rb") as f:
raw_sample = f.read(8192)
detected_encoding = detect_text_encoding(raw_sample) or encoding
return {
"success": True,
"content": read_local_text_range_sync(
abs_path,
encoding=detected_encoding,
offset=offset,
limit=limit,
),
}
return await asyncio.to_thread(_run)
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
results = search(
patterns=[pattern],
paths=[path] if path else None,
globs=[glob] if glob else None,
after_context=after_context,
before_context=before_context,
line_number=True,
)
return {"success": True, "content": content}
return {"success": True, "content": _truncate_long_lines("".join(results))}
return await asyncio.to_thread(_run)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = os.path.abspath(path)
with open(abs_path, encoding=encoding) as f:
content = f.read()
occurrences = content.count(old_string)
if occurrences == 0:
return {
"success": False,
"error": "old string not found in file",
"replacements": 0,
}
if replace_all:
updated = content.replace(old_string, new_string)
replacements = occurrences
else:
updated = content.replace(old_string, new_string, 1)
replacements = 1
with open(abs_path, "w", encoding=encoding) as f:
f.write(updated)
return {
"success": True,
"path": abs_path,
"replacements": replacements,
}
return await asyncio.to_thread(_run)
@@ -212,7 +272,7 @@ class LocalFileSystemComponent(FileSystemComponent):
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, mode, encoding=encoding) as f:
f.write(content)
@@ -222,7 +282,7 @@ class LocalFileSystemComponent(FileSystemComponent):
async def delete_file(self, path: str) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
if os.path.isdir(abs_path):
shutil.rmtree(abs_path)
else:
@@ -235,7 +295,7 @@ class LocalFileSystemComponent(FileSystemComponent):
self, path: str = ".", show_hidden: bool = False
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
entries = os.listdir(abs_path)
if not show_hidden:
entries = [e for e in entries if not e.startswith(".")]

View File

@@ -1,9 +1,87 @@
from __future__ import annotations
from typing import Any
from shipyard import FileSystemComponent as ShipyardFileSystemComponent
from shipyard import ShipyardClient, Spec
from astrbot.api import logger
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .shipyard_search_file_util import search_files_via_shell
class ShipyardFileSystemWrapper:
def __init__(
self, _shipyard_fs: ShipyardFileSystemComponent, _shipyard_shell: ShellComponent
):
self._fs = _shipyard_fs
self._shell = _shipyard_shell
async def create_file(
self, path: str, content: str = "", mode: int = 420
) -> dict[str, Any]:
return await self._fs.create_file(path=path, content=content, mode=mode)
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
return await self._fs.read_file(
path=path, encoding=encoding, offset=offset, limit=limit
)
async def write_file(
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
) -> dict[str, Any]:
return await self._fs.write_file(
path=path, content=content, mode=mode, encoding=encoding
)
async def list_dir(
self, path: str = ".", show_hidden: bool = False
) -> dict[str, Any]:
return await self._fs.list_dir(path=path, show_hidden=show_hidden)
async def delete_file(self, path: str) -> dict[str, Any]:
return await self._fs.delete_file(path=path)
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
return await search_files_via_shell(
self._shell,
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
return await self._fs.edit_file(
path=path,
old_string=old_string,
new_string=new_string,
replace_all=replace_all,
encoding=encoding,
)
class ShipyardBooter(ComputerBooter):
@@ -29,13 +107,14 @@ class ShipyardBooter(ComputerBooter):
)
logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}")
self._ship = ship
self._fs = ShipyardFileSystemWrapper(self._ship.fs, self._ship.shell)
async def shutdown(self) -> None:
logger.info("[Computer] Shipyard booter shutdown.")
@property
def fs(self) -> FileSystemComponent:
return self._ship.fs
return self._fs
@property
def python(self) -> PythonComponent:

View File

@@ -13,6 +13,15 @@ from ..olayer import (
ShellComponent,
)
from .base import ComputerBooter
from .shipyard_search_file_util import search_files_via_shell
try:
from shipyard_neo import BayClient
from shipyard_neo.sandbox import Sandbox
except ImportError:
logger.warning(
"shipyard_neo_sdk is not installed. ShipyardNeoBooter will not work without it."
)
def _maybe_model_dump(value: Any) -> dict[str, Any]:
@@ -25,8 +34,20 @@ def _maybe_model_dump(value: Any) -> dict[str, Any]:
return {}
def _slice_content_by_lines(
content: str,
*,
offset: int | None = None,
limit: int | None = None,
) -> str:
lines = content.splitlines(keepends=True)
start = 0 if offset is None else offset
selected = lines[start:] if limit is None else lines[start : start + limit]
return "".join(selected)
class NeoPythonComponent(PythonComponent):
def __init__(self, sandbox: Any) -> None:
def __init__(self, sandbox: Sandbox) -> None:
self._sandbox = sandbox
async def exec(
@@ -67,7 +88,7 @@ class NeoPythonComponent(PythonComponent):
class NeoShellComponent(ShellComponent):
def __init__(self, sandbox: Any) -> None:
def __init__(self, sandbox: Sandbox) -> None:
self._sandbox = sandbox
async def exec(
@@ -136,8 +157,9 @@ class NeoShellComponent(ShellComponent):
class NeoFileSystemComponent(FileSystemComponent):
def __init__(self, sandbox: Any) -> None:
def __init__(self, sandbox: Sandbox, shell: ShellComponent) -> None:
self._sandbox = sandbox
self._shell = shell
async def create_file(
self,
@@ -149,10 +171,71 @@ class NeoFileSystemComponent(FileSystemComponent):
await self._sandbox.filesystem.write_file(path, content)
return {"success": True, "path": path}
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
_ = encoding
content = await self._sandbox.filesystem.read_file(path)
return {"success": True, "path": path, "content": content}
return {
"success": True,
"path": path,
"content": _slice_content_by_lines(
content,
offset=offset,
limit=limit,
),
}
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
return await search_files_via_shell(
self._shell,
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
_ = encoding
content = await self._sandbox.filesystem.read_file(path)
occurrences = content.count(old_string)
if occurrences == 0:
return {
"success": False,
"error": "old string not found in file",
"replacements": 0,
}
if replace_all:
updated = content.replace(old_string, new_string)
replacements = occurrences
else:
updated = content.replace(old_string, new_string, 1)
replacements = 1
await self._sandbox.filesystem.write_file(path, updated)
return {
"success": True,
"path": path,
"replacements": replacements,
}
async def write_file(
self,
@@ -186,7 +269,7 @@ class NeoFileSystemComponent(FileSystemComponent):
class NeoBrowserComponent(BrowserComponent):
def __init__(self, sandbox: Any) -> None:
def __init__(self, sandbox: Sandbox) -> None:
self._sandbox = sandbox
async def exec(
@@ -271,8 +354,8 @@ class ShipyardNeoBooter(ComputerBooter):
self._access_token = access_token
self._profile = profile
self._ttl = ttl
self._client: Any = None
self._sandbox: Any = None
self._client: BayClient | None = None
self._sandbox: Sandbox | None = None
self._bay_manager: Any = None # BayContainerManager when auto-started
self._fs: FileSystemComponent | None = None
self._python: PythonComponent | None = None
@@ -336,8 +419,6 @@ class ShipyardNeoBooter(ComputerBooter):
"or ensure Bay's credentials.json is accessible for auto-discovery."
)
from shipyard_neo import BayClient
self._client = BayClient(
endpoint_url=self._endpoint_url,
access_token=self._access_token,
@@ -352,9 +433,9 @@ class ShipyardNeoBooter(ComputerBooter):
ttl=self._ttl,
)
self._fs = NeoFileSystemComponent(self._sandbox)
self._python = NeoPythonComponent(self._sandbox)
self._shell = NeoShellComponent(self._sandbox)
self._fs = NeoFileSystemComponent(self._sandbox, self._shell)
self._python = NeoPythonComponent(self._sandbox)
caps = self.capabilities or ()
self._browser = (

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
import shlex
from typing import Any
from ..olayer import ShellComponent
_MAX_SEARCH_LINE_COLUMNS = 1000
def _truncate_long_lines(text: str) -> str:
output_lines: list[str] = []
for line in text.splitlines(keepends=True):
line_ending = ""
line_body = line
if line.endswith("\r\n"):
line_body = line[:-2]
line_ending = "\r\n"
elif line.endswith("\n") or line.endswith("\r"):
line_body = line[:-1]
line_ending = line[-1]
if len(line_body) > _MAX_SEARCH_LINE_COLUMNS:
line_body = line_body[:_MAX_SEARCH_LINE_COLUMNS]
output_lines.append(f"{line_body}{line_ending}")
return "".join(output_lines)
def _build_rg_command(
*,
pattern: str,
path: str,
glob: str | None,
after_context: int | None,
before_context: int | None,
) -> list[str]:
command = [
"rg",
"--color=never",
"-n",
"--max-columns",
str(_MAX_SEARCH_LINE_COLUMNS),
"-e",
pattern,
]
if glob:
command.extend(["-g", glob])
if after_context is not None:
command.extend(["-A", str(after_context)])
if before_context is not None:
command.extend(["-B", str(before_context)])
command.extend(["--", path])
return command
def _build_grep_command(
*,
pattern: str,
path: str,
glob: str | None,
after_context: int | None,
before_context: int | None,
) -> list[str]:
command = ["grep", "-R", "-H", "-n", "-e", pattern]
if glob:
command.append(f"--include={glob}")
if after_context is not None:
command.extend(["-A", str(after_context)])
if before_context is not None:
command.extend(["-B", str(before_context)])
command.extend(["--", path])
return command
def _quote_command(command: list[str]) -> str:
return " ".join(shlex.quote(part) for part in command)
def build_search_command(
*,
pattern: str,
path: str,
glob: str | None,
after_context: int | None,
before_context: int | None,
) -> str:
rg_command = _quote_command(
_build_rg_command(
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
)
grep_command = _quote_command(
_build_grep_command(
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
)
return (
"if command -v rg >/dev/null 2>&1; then "
f"{rg_command}; "
"elif command -v grep >/dev/null 2>&1; then "
f"{grep_command}; "
"else "
"echo 'Neither rg nor grep is available in the sandbox.' >&2; "
"exit 127; "
"fi"
)
async def search_files_via_shell(
shell: ShellComponent,
*,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
timeout: int = 30,
) -> dict[str, Any]:
command = build_search_command(
pattern=pattern,
path=path or ".",
glob=glob,
after_context=after_context,
before_context=before_context,
)
result = await shell.exec(command, timeout=timeout)
stdout = _truncate_long_lines(str(result.get("stdout", "") or ""))
stderr = str(result.get("stderr", "") or "")
exit_code = result.get("exit_code")
if exit_code in (0, None):
return {"success": True, "content": stdout}
if exit_code == 1:
return {"success": True, "content": ""}
return {
"success": False,
"content": "",
"error": stderr or f"command exited with code {exit_code}",
"exit_code": exit_code,
}

View File

@@ -0,0 +1,707 @@
from __future__ import annotations
import base64
import hashlib
import io
import json
import zipfile
from asyncio import to_thread
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import mcp
from astrbot.core.agent.context.token_counter import EstimateTokenCounter
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import (
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
IMAGE_COMPRESS_DEFAULT_OPTIMIZE,
IMAGE_COMPRESS_DEFAULT_QUALITY,
_compress_image_sync,
)
from .booters.base import ComputerBooter
_MAX_FILE_READ_BYTES = 128 * 1024
_MAX_FILE_READ_TOKENS = 25_000
_MAX_TEXT_FILE_FULL_READ_BYTES = 256 * 1024
_FILE_SNIFF_BYTES = 512
_TOKEN_COUNTER = EstimateTokenCounter()
_TEXT_ENCODINGS = (
"utf-8-sig",
"utf-8",
"gb18030",
"utf-16",
"utf-16-le",
"utf-16-be",
"utf-32",
"utf-32-le",
"utf-32-be",
)
_UTF_BOMS = (
b"\xef\xbb\xbf",
b"\xff\xfe",
b"\xfe\xff",
b"\xff\xfe\x00\x00",
b"\x00\x00\xfe\xff",
)
_ZIP_MAGIC_PREFIXES = (
b"PK\x03\x04",
b"PK\x05\x06",
b"PK\x07\x08",
)
_BINARY_MAGIC_PREFIXES = (
b"%PDF-",
b"\x1f\x8b",
b"7z\xbc\xaf\x27\x1c",
b"Rar!\x1a\x07",
b"\x7fELF",
b"MZ",
)
@dataclass(frozen=True)
class FileProbe:
kind: Literal["text", "image", "binary"]
encoding: str | None
mime_type: str | None
size_bytes: int
@dataclass(frozen=True)
class ParsedDocument:
kind: Literal["docx", "pdf"]
file_bytes: bytes
text: str
def _build_probe_script(path: str) -> str:
return f"""
import base64
import json
from pathlib import Path
path = Path({path!r})
with path.open("rb") as file_obj:
sample = file_obj.read({_FILE_SNIFF_BYTES})
print(
json.dumps(
{{
"size_bytes": path.stat().st_size,
"sample_b64": base64.b64encode(sample).decode("ascii"),
}}
)
)
""".strip()
def _build_text_read_script(
path: str,
*,
encoding: str,
offset: int | None,
limit: int | None,
) -> str:
start_expr = "0" if offset is None else str(offset)
limit_expr = "None" if limit is None else str(limit)
return f"""
import json
from pathlib import Path
path = Path({path!r})
start = {start_expr}
limit = {limit_expr}
end = None if limit is None else start + limit
lines = []
with path.open("r", encoding={encoding!r}, newline="") as file_obj:
for index, line in enumerate(file_obj):
if index < start:
continue
if end is not None and index >= end:
break
lines.append(line)
content = "".join(lines)
print(json.dumps({{"content": content}}, ensure_ascii=False))
""".strip()
def _build_image_read_script(path: str) -> str:
return f"""
import base64
import json
from pathlib import Path
path = Path({path!r})
data = path.read_bytes()
print(
json.dumps(
{{
"size_bytes": len(data),
"base64": base64.b64encode(data).decode("ascii"),
}}
)
)
""".strip()
def _looks_like_text(decoded: str) -> bool:
if not decoded:
return True
disallowed = 0
printable = 0
for char in decoded:
if char in "\n\r\t\f\b":
printable += 1
continue
if char.isprintable():
printable += 1
code = ord(char)
if (0 <= code < 32) or (127 <= code < 160):
disallowed += 1
total = max(len(decoded), 1)
return disallowed / total <= 0.02 and printable / total >= 0.85
def detect_text_encoding(sample: bytes) -> str | None:
if not sample:
return "utf-8"
if b"\x00" in sample and not sample.startswith(_UTF_BOMS):
odd_bytes = sample[1::2]
even_bytes = sample[0::2]
odd_zero_ratio = odd_bytes.count(0) / max(len(odd_bytes), 1)
even_zero_ratio = even_bytes.count(0) / max(len(even_bytes), 1)
if odd_zero_ratio < 0.8 and even_zero_ratio < 0.8:
return None
for encoding in _TEXT_ENCODINGS:
try:
decoded = sample.decode(encoding)
except UnicodeDecodeError as exc:
# Probe samples can end in the middle of a multibyte sequence.
# When the decode failure only happens at the sample tail, trim a few
# bytes and retry so UTF-8 text is not misclassified as binary.
if exc.start >= len(sample) - 4:
decoded = ""
for trim_bytes in range(1, min(4, len(sample)) + 1):
try:
decoded = sample[:-trim_bytes].decode(encoding)
break
except UnicodeDecodeError:
continue
if not decoded:
continue
else:
continue
if _looks_like_text(decoded):
return encoding
return None
def read_local_text_range_sync(
path: str,
*,
encoding: str,
offset: int | None,
limit: int | None,
) -> str:
lines: list[str] = []
start = 0 if offset is None else offset
end = None if limit is None else start + limit
with open(path, encoding=encoding, newline="") as file_obj:
for index, line in enumerate(file_obj):
if index < start:
continue
if end is not None and index >= end:
break
lines.append(line)
return "".join(lines)
async def read_local_text_range(
path: str,
*,
encoding: str,
offset: int | None,
limit: int | None,
) -> str:
return await to_thread(
read_local_text_range_sync,
path,
encoding=encoding,
offset=offset,
limit=limit,
)
async def _exec_python_json(
booter: ComputerBooter,
script: str,
*,
action: str,
) -> dict:
result = await booter.python.exec(script)
data = result.get("data") if isinstance(result.get("data"), dict) else {}
if not isinstance(data, dict):
raise RuntimeError(f"{action} failed: invalid result format")
output = data.get("output") if isinstance(data.get("output"), dict) else {}
if not isinstance(output, dict):
raise RuntimeError(f"{action} failed: invalid output format")
error_text = str(data.get("error", "") or result.get("error", "") or "").strip()
if error_text:
raise RuntimeError(f"{action} failed: {error_text}")
text = str(output.get("text", "") or "").strip()
if not text:
raise RuntimeError(f"{action} failed: empty output")
try:
payload = json.loads(text)
except json.JSONDecodeError as exc:
raise RuntimeError(f"{action} failed: invalid JSON output") from exc
if not isinstance(payload, dict):
raise RuntimeError(f"{action} failed: invalid JSON payload")
return payload
async def _probe_local_file(path: str) -> dict[str, str | int]:
def _run() -> dict[str, str | int]:
file_path = Path(path)
with file_path.open("rb") as file_obj:
sample = file_obj.read(_FILE_SNIFF_BYTES)
return {
"size_bytes": file_path.stat().st_size,
"sample_b64": base64.b64encode(sample).decode("ascii"),
}
return await to_thread(_run)
async def _read_local_image_base64(path: str) -> dict[str, str | int]:
def _run() -> dict[str, str | int]:
data = Path(path).read_bytes()
return {
"size_bytes": len(data),
"base64": base64.b64encode(data).decode("ascii"),
}
return await to_thread(_run)
async def _read_local_file_bytes(path: str) -> bytes:
return await to_thread(Path(path).read_bytes)
async def _compress_image_bytes_to_base64(data: bytes) -> dict[str, str | int]:
def _run() -> dict[str, str | int]:
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
compressed_path = Path(
_compress_image_sync(
data,
temp_dir,
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
IMAGE_COMPRESS_DEFAULT_QUALITY,
IMAGE_COMPRESS_DEFAULT_OPTIMIZE,
)
)
try:
compressed_bytes = compressed_path.read_bytes()
finally:
compressed_path.unlink(missing_ok=True)
return {
"size_bytes": len(compressed_bytes),
"base64": base64.b64encode(compressed_bytes).decode("ascii"),
"mime_type": "image/jpeg",
}
return await to_thread(_run)
def _detect_image_mime(sample: bytes) -> str | None:
if sample.startswith(b"\x89PNG\r\n\x1a\n"):
return "image/png"
if sample.startswith(b"\xff\xd8\xff"):
return "image/jpeg"
if sample.startswith((b"GIF87a", b"GIF89a")):
return "image/gif"
if sample.startswith(b"BM"):
return "image/bmp"
if sample.startswith((b"II*\x00", b"MM\x00*")):
return "image/tiff"
if sample.startswith(b"\x00\x00\x01\x00"):
return "image/x-icon"
if len(sample) >= 12 and sample[:4] == b"RIFF" and sample[8:12] == b"WEBP":
return "image/webp"
if len(sample) >= 12 and sample[4:12] in (b"ftypavif", b"ftypavis"):
return "image/avif"
return None
def _looks_like_known_binary(sample: bytes) -> bool:
return any(sample.startswith(prefix) for prefix in _BINARY_MAGIC_PREFIXES)
def _looks_like_pdf(path: str, sample: bytes) -> bool:
return Path(path).suffix.lower() == ".pdf" or sample.startswith(b"%PDF-")
def _looks_like_zip_container(sample: bytes) -> bool:
return any(sample.startswith(prefix) for prefix in _ZIP_MAGIC_PREFIXES)
def _is_docx_bytes(file_bytes: bytes) -> bool:
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as archive:
names = set(archive.namelist())
except (OSError, zipfile.BadZipFile):
return False
if "[Content_Types].xml" not in names:
return False
return any(name.startswith("word/") for name in names)
async def _parse_local_docx_text(file_bytes: bytes, file_name: str) -> str:
from astrbot.core.knowledge_base.parsers.markitdown_parser import (
MarkitdownParser,
)
result = await MarkitdownParser().parse(file_bytes, file_name)
return result.text
async def _parse_local_pdf_text(file_bytes: bytes, file_name: str) -> str:
from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser
result = await PDFParser().parse(file_bytes, file_name)
return result.text
async def _parse_local_supported_document(
path: str,
sample: bytes,
) -> ParsedDocument | None:
file_name = Path(path).name
if _looks_like_pdf(path, sample):
file_bytes = await _read_local_file_bytes(path)
text = await _parse_local_pdf_text(file_bytes, file_name)
return ParsedDocument(kind="pdf", file_bytes=file_bytes, text=text)
if Path(path).suffix.lower() == ".docx" or _looks_like_zip_container(sample):
file_bytes = await _read_local_file_bytes(path)
if not _is_docx_bytes(file_bytes):
return None
text = await _parse_local_docx_text(file_bytes, file_name)
return ParsedDocument(kind="docx", file_bytes=file_bytes, text=text)
return None
def _probe_file(sample: bytes, *, size_bytes: int) -> FileProbe:
if image_mime := _detect_image_mime(sample):
return FileProbe(
kind="image",
encoding=None,
mime_type=image_mime,
size_bytes=size_bytes,
)
if _looks_like_known_binary(sample):
return FileProbe(
kind="binary",
encoding=None,
mime_type=None,
size_bytes=size_bytes,
)
if encoding := detect_text_encoding(sample):
return FileProbe(
kind="text",
encoding=encoding,
mime_type="text/plain",
size_bytes=size_bytes,
)
return FileProbe(
kind="binary",
encoding=None,
mime_type=None,
size_bytes=size_bytes,
)
def _validate_text_output(content: str) -> str | None:
content_bytes = len(content.encode("utf-8"))
if content_bytes > _MAX_FILE_READ_BYTES:
return (
"Error reading file: "
f"output exceeds {_MAX_FILE_READ_BYTES} bytes "
f"({content_bytes} bytes). Use `offset`, `limit` to narrow the read window."
)
content_tokens = _TOKEN_COUNTER.count_tokens(
[Message(role="user", content=content)]
)
if content_tokens > _MAX_FILE_READ_TOKENS:
return (
"Error reading file: "
f"output exceeds {_MAX_FILE_READ_TOKENS} tokens "
f"({content_tokens} tokens). Use `offset`, `limit` to narrow the read window."
)
return None
def _text_exceeds_read_thresholds(content: str) -> bool:
return _validate_text_output(content) is not None
def _validate_full_text_read_request(probe: FileProbe) -> str | None:
if probe.size_bytes > _MAX_TEXT_FILE_FULL_READ_BYTES:
return (
"Error reading file: "
f"text file exceeds {_MAX_TEXT_FILE_FULL_READ_BYTES} bytes "
f"({probe.size_bytes} bytes). Use `offset` and `limit` to narrow the read window."
)
return None
def _slice_text_by_lines(
content: str,
*,
offset: int | None,
limit: int | None,
) -> str:
if offset is None and limit is None:
return content
lines = content.splitlines(keepends=True)
start = 0 if offset is None else offset
end = None if limit is None else start + limit
return "".join(lines[start:end])
async def _store_converted_text_for_workspace(
*,
workspace_dir: str,
original_path: str,
original_bytes: bytes,
content: str,
) -> str:
def _run() -> str:
original_name = Path(original_path).name
digest_suffix = hashlib.md5(original_bytes).hexdigest()[-6:]
target_dir = (
Path(workspace_dir) / "converted_files" / f"{original_name}_{digest_suffix}"
)
target_dir.mkdir(parents=True, exist_ok=True)
target_path = target_dir / "text.txt"
target_path.write_text(content, encoding="utf-8")
return str(target_path)
return await to_thread(_run)
def _build_converted_text_notice(
converted_text_path: str,
*,
selection_returned: bool,
selection_too_large: bool = False,
) -> str:
if selection_too_large:
return (
"Converted text was saved to "
f"`{converted_text_path}`. The requested output is still too large to "
"return directly. Read or grep that file with a narrower window."
)
if selection_returned:
return (
"Full converted text is also available at "
f"`{converted_text_path}`. Read or grep that file with a narrow "
"window for additional reads."
)
return (
"Converted text was saved to "
f"`{converted_text_path}` because the parsed document is too large to "
"return directly. Read or grep that file with a narrow window."
)
async def _read_local_supported_document_result(
*,
path: str,
parsed_document: ParsedDocument,
workspace_dir: str | None,
offset: int | None,
limit: int | None,
) -> ToolExecResult:
content = parsed_document.text
if not content:
return "No content found at the requested line offset."
if not _text_exceeds_read_thresholds(content):
selected_content = _slice_text_by_lines(content, offset=offset, limit=limit)
if not selected_content:
return "No content found at the requested line offset."
if validation_error := _validate_text_output(selected_content):
return validation_error
return selected_content
if not workspace_dir:
return (
"Error reading file: parsed document exceeds the read output limit and "
"no workspace is available for storing converted text."
)
converted_text_path = await _store_converted_text_for_workspace(
workspace_dir=workspace_dir,
original_path=path,
original_bytes=parsed_document.file_bytes,
content=content,
)
if offset is None and limit is None:
return _build_converted_text_notice(
converted_text_path,
selection_returned=False,
)
selected_content = _slice_text_by_lines(content, offset=offset, limit=limit)
if not selected_content:
return (
"No content found at the requested line offset. "
+ _build_converted_text_notice(
converted_text_path,
selection_returned=False,
)
)
notice = _build_converted_text_notice(
converted_text_path,
selection_returned=True,
)
combined_output = f"{selected_content}\n\n[{notice}]"
if _validate_text_output(combined_output):
if _validate_text_output(selected_content):
return _build_converted_text_notice(
converted_text_path,
selection_returned=False,
selection_too_large=True,
)
return selected_content
return combined_output
async def read_file_tool_result(
booter: ComputerBooter,
*,
local_mode: bool,
path: str,
offset: int | None,
limit: int | None,
workspace_dir: str | None = None,
) -> ToolExecResult:
if local_mode:
probe_payload = await _probe_local_file(path)
else:
probe_payload = await _exec_python_json(
booter,
_build_probe_script(path),
action="file probe",
)
sample_b64 = str(probe_payload.get("sample_b64", "") or "")
sample = base64.b64decode(sample_b64) if sample_b64 else b""
size_bytes = int(probe_payload.get("size_bytes", 0) or 0)
probe = _probe_file(sample, size_bytes=size_bytes)
if local_mode:
try:
parsed_document = await _parse_local_supported_document(path, sample)
except Exception as exc:
return f"Error reading file: failed to parse document: {exc}"
if parsed_document is not None:
return await _read_local_supported_document_result(
path=path,
parsed_document=parsed_document,
workspace_dir=workspace_dir,
offset=offset,
limit=limit,
)
if probe.kind == "binary":
return "Error reading file: binary files are not supported by this tool."
if probe.kind == "image":
if local_mode:
image_payload = await _read_local_image_base64(path)
else:
image_payload = await _exec_python_json(
booter,
_build_image_read_script(path),
action="image read",
)
raw_base64_data = str(image_payload.get("base64", "") or "")
if not raw_base64_data:
return "Error reading file: image payload is empty."
raw_bytes = base64.b64decode(raw_base64_data)
compressed_payload = await _compress_image_bytes_to_base64(raw_bytes)
base64_data = str(compressed_payload.get("base64", "") or "")
if not base64_data:
return "Error reading file: compressed image payload is empty."
return mcp.types.CallToolResult(
content=[
mcp.types.ImageContent(
type="image",
data=base64_data,
mimeType=str(
compressed_payload.get("mime_type", "") or "image/jpeg"
),
)
]
)
if offset is None and limit is None:
if validation_error := _validate_full_text_read_request(probe):
return validation_error
if local_mode:
content = await read_local_text_range(
path,
encoding=probe.encoding or "utf-8",
offset=offset,
limit=limit,
)
else:
text_payload = await _exec_python_json(
booter,
_build_text_read_script(
path,
encoding=probe.encoding or "utf-8",
offset=offset,
limit=limit,
),
action="text read",
)
content = str(text_payload.get("content", "") or "")
if not content:
return "No content found at the requested line offset."
if validation_error := _validate_text_output(content):
return validation_error
return content

View File

@@ -12,8 +12,36 @@ class FileSystemComponent(Protocol):
"""Create a file with the specified content"""
...
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
"""Read file content"""
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Read file content by line window"""
...
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
"""Search file contents"""
...
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
"""Edit file content by string replacement"""
...
async def write_file(

View File

@@ -1,213 +0,0 @@
import os
import uuid
from dataclasses import dataclass, field
from astrbot.api import FunctionTool, logger
from astrbot.api.event import MessageChain
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.components import File
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..computer_client import get_booter
from .permissions import check_admin_permission
# @dataclass
# class CreateFileTool(FunctionTool):
# name: str = "astrbot_create_file"
# description: str = "Create a new file in the sandbox."
# parameters: dict = field(
# default_factory=lambda: {
# "type": "object",
# "properties": {
# "path": {
# "path": "string",
# "description": "The path where the file should be created, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
# },
# "content": {
# "type": "string",
# "description": "The content to write into the file.",
# },
# },
# "required": ["path", "content"],
# }
# )
# async def call(
# self, context: ContextWrapper[AstrAgentContext], path: str, content: str
# ) -> ToolExecResult:
# sb = await get_booter(
# context.context.context,
# context.context.event.unified_msg_origin,
# )
# try:
# result = await sb.fs.create_file(path, content)
# return json.dumps(result)
# except Exception as e:
# return f"Error creating file: {str(e)}"
# @dataclass
# class ReadFileTool(FunctionTool):
# name: str = "astrbot_read_file"
# description: str = "Read the content of a file in the sandbox."
# parameters: dict = field(
# default_factory=lambda: {
# "type": "object",
# "properties": {
# "path": {
# "type": "string",
# "description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
# },
# },
# "required": ["path"],
# }
# )
# async def call(self, context: ContextWrapper[AstrAgentContext], path: str):
# sb = await get_booter(
# context.context.context,
# context.context.event.unified_msg_origin,
# )
# try:
# result = await sb.fs.read_file(path)
# return result
# except Exception as e:
# return f"Error reading file: {str(e)}"
@dataclass
class FileUploadTool(FunctionTool):
name: str = "astrbot_upload_file"
description: str = (
"Transfer a file FROM the host machine INTO the sandbox so that sandbox "
"code can access it. Use this when the user sends/attaches a file and you "
"need to process it inside the sandbox. The local_path must point to an "
"existing file on the host filesystem."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"local_path": {
"type": "string",
"description": "Absolute path to the file on the host filesystem that will be copied into the sandbox.",
},
# "remote_path": {
# "type": "string",
# "description": "The filename to use in the sandbox. If not provided, file will be saved to the working directory with the same name as the local file.",
# },
},
"required": ["local_path"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
local_path: str,
) -> str | None:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
# Check if file exists
if not os.path.exists(local_path):
return f"Error: File does not exist: {local_path}"
if not os.path.isfile(local_path):
return f"Error: Path is not a file: {local_path}"
# Use basename if sandbox_filename is not provided
remote_path = os.path.basename(local_path)
# Upload file to sandbox
result = await sb.upload_file(local_path, remote_path)
logger.debug(f"Upload result: {result}")
success = result.get("success", False)
if not success:
return f"Error uploading file: {result.get('message', 'Unknown error')}"
file_path = result.get("file_path", "")
logger.info(f"File {local_path} uploaded to sandbox at {file_path}")
return f"File uploaded successfully to {file_path}"
except Exception as e:
logger.error(f"Error uploading file {local_path}: {e}")
return f"Error uploading file: {str(e)}"
@dataclass
class FileDownloadTool(FunctionTool):
name: str = "astrbot_download_file"
description: str = (
"Transfer a file FROM the sandbox OUT to the host and optionally send it "
"to the user. Use this ONLY when the user asks to retrieve/export a file "
"that was created or modified inside the sandbox."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"remote_path": {
"type": "string",
"description": "Path of the file inside the sandbox to copy out to the host.",
},
"also_send_to_user": {
"type": "boolean",
"description": "Whether to also send the downloaded file to the user via message. Defaults to true.",
},
},
"required": ["remote_path"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
remote_path: str,
also_send_to_user: bool = True,
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
name = os.path.basename(remote_path)
local_path = os.path.join(
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
)
# Download file from sandbox
await sb.download_file(remote_path, local_path)
logger.info(f"File {remote_path} downloaded from sandbox to {local_path}")
if also_send_to_user:
try:
name = os.path.basename(local_path)
await context.context.event.send(
MessageChain(chain=[File(name=name, file=local_path)])
)
except Exception as e:
logger.error(f"Error sending file message: {e}")
# remove
# try:
# os.remove(local_path)
# except Exception as e:
# logger.error(f"Error removing temp file {local_path}: {e}")
return f"File downloaded successfully to {local_path} and sent to user."
return f"File downloaded successfully to {local_path}"
except Exception as e:
logger.error(f"Error downloading file {remote_path}: {e}")
return f"Error downloading file: {str(e)}"

View File

@@ -36,6 +36,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
)
from astrbot.core.subagent_orchestrator import SubAgentOrchestrator
from astrbot.core.utils.astrbot_path import get_astrbot_system_tmp_path
from ..exceptions import ProviderNotFoundError
from .filter.command import CommandFilter
@@ -232,6 +233,13 @@ class Context:
for k, v in kwargs.items()
if k not in ["stream", "agent_hooks", "agent_context"]
}
if request.func_tool and request.func_tool.get_tool("astrbot_file_read_tool"):
other_kwargs.setdefault(
"tool_result_overflow_dir", get_astrbot_system_tmp_path()
)
other_kwargs.setdefault(
"read_tool", request.func_tool.get_tool("astrbot_file_read_tool")
)
await agent_runner.reset(
provider=prov,

View File

@@ -0,0 +1,55 @@
from .fs import (
FileDownloadTool,
FileEditTool,
FileReadTool,
FileUploadTool,
FileWriteTool,
GrepTool,
)
from .python import LocalPythonTool, PythonTool
from .shell import ExecuteShellTool
from .shipyard_neo import (
AnnotateExecutionTool,
BrowserBatchExecTool,
BrowserExecTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
PromoteSkillCandidateTool,
RollbackSkillReleaseTool,
RunBrowserSkillTool,
SyncSkillReleaseTool,
)
from .util import check_admin_permission, normalize_umo_for_workspace
__all__ = [
"AnnotateExecutionTool",
"BrowserBatchExecTool",
"BrowserExecTool",
"CreateSkillCandidateTool",
"CreateSkillPayloadTool",
"EvaluateSkillCandidateTool",
"ExecuteShellTool",
"FileDownloadTool",
"FileEditTool",
"FileReadTool",
"FileUploadTool",
"FileWriteTool",
"GetExecutionHistoryTool",
"GetSkillPayloadTool",
"GrepTool",
"ListSkillCandidatesTool",
"ListSkillReleasesTool",
"LocalPythonTool",
"PromoteSkillCandidateTool",
"PythonTool",
"RollbackSkillReleaseTool",
"RunBrowserSkillTool",
"SyncSkillReleaseTool",
"normalize_umo_for_workspace",
"check_admin_permission",
]

View File

@@ -0,0 +1,747 @@
"""Filesystem tool audit.
Tool exposure from the main agent:
- Local runtime exposes `astrbot_read_file_tool`, `astrbot_file_write_tool`,
`astrbot_file_edit_tool`, and `astrbot_grep_tool`.
- Sandbox runtime exposes `astrbot_upload_file`, `astrbot_download_file`,
`astrbot_read_file_tool`, `astrbot_file_write_tool`,
`astrbot_file_edit_tool`, and `astrbot_grep_tool`.
Behavior when `provider_settings.computer_use_require_admin=True`:
- Admin + local: read/write/edit/grep are not path-restricted by this module;
access depends on the local runtime implementation and host OS permissions.
Upload and download tools are defined here, but `LocalBooter` does not
implement them and the main agent does not expose them in local mode.
- Member + local: read/write/edit/grep are restricted to `data/skills`,
`data/workspaces/{normalized_umo}`, and `/tmp/.astrbot`. Upload/download are
denied by `check_admin_permission` if invoked.
- Admin + sandbox: read/write/edit/grep are not path-restricted by this
module;
sandbox filesystem boundaries are enforced by the sandbox runtime. Upload and
download are allowed.
- Member + sandbox: read/write/edit/grep are also not path-restricted by this
module. Upload/download are denied by `check_admin_permission` if invoked.
When `computer_use_require_admin=False`, member behavior in this module matches
admin behavior.
Local path resolution rule:
- In local runtime, relative paths are resolved under
`data/workspaces/{normalized_umo}`.
- In sandbox runtime, relative paths are passed through unchanged.
"""
import os
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from astrbot.api import FunctionTool, logger
from astrbot.api.event import MessageChain
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.computer.file_read_utils import read_file_tool_result
from astrbot.core.message.components import File
from astrbot.core.utils.astrbot_path import (
get_astrbot_skills_path,
get_astrbot_system_tmp_path,
get_astrbot_temp_path,
)
from ..registry import builtin_tool
from . import util as computer_util
from .util import (
check_admin_permission,
is_local_runtime,
normalize_umo_for_workspace,
)
_COMPUTER_RUNTIME_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": ("local", "sandbox"),
}
_SANDBOX_RUNTIME_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "sandbox",
}
def _restricted_env_path_labels(umo: str) -> list[str]:
"""Labels for the allowed directories in a local(not sandbox) and restricted(not admin) environment"""
normalized_umo = normalize_umo_for_workspace(umo)
return [
"data/skills",
f"data/workspaces/{normalized_umo}",
get_astrbot_system_tmp_path(),
]
def get_astrbot_workspaces_path() -> str:
"""Compatibility wrapper for tests and older module-level monkeypatches."""
return computer_util.get_astrbot_workspaces_path()
def _workspace_root(umo: str) -> Path:
"""Workspace root that follows both util-level and fs-level getter monkeypatches."""
normalized_umo = normalize_umo_for_workspace(umo)
return (Path(get_astrbot_workspaces_path()) / normalized_umo).resolve(strict=False)
def _read_allowed_roots(umo: str) -> tuple[Path, ...]:
"""Non-admin users can only read files within these directories (and their subdirectories)"""
return (
Path(get_astrbot_skills_path()).resolve(strict=False),
_workspace_root(umo),
Path(get_astrbot_system_tmp_path()).resolve(strict=False),
)
def _is_restricted_env(context: ContextWrapper[AstrAgentContext]) -> bool:
if not is_local_runtime(context):
return False
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
require_admin = provider_settings.get("computer_use_require_admin", True)
return require_admin and context.context.event.role != "admin"
def _resolve_tool_path(path: str, *, local_env: bool, umo: str) -> str:
normalized_path = path.strip()
if not normalized_path:
return normalized_path
candidate = Path(normalized_path).expanduser()
if candidate.is_absolute():
return str(candidate.resolve(strict=False))
if local_env:
return str((_workspace_root(umo) / candidate).resolve(strict=False))
return normalized_path
def _resolve_user_path(path: str, *, local_env: bool, umo: str) -> Path:
candidate = Path(path).expanduser()
if candidate.is_absolute():
return candidate.resolve(strict=False)
if local_env:
return (_workspace_root(umo) / candidate).resolve(strict=False)
return (Path.cwd() / candidate).resolve(strict=False)
def _is_path_within_allowed_roots(path: str, umo: str) -> bool:
resolved = _resolve_user_path(path, local_env=True, umo=umo)
return any(
resolved == allowed_root or resolved.is_relative_to(allowed_root)
for allowed_root in _read_allowed_roots(umo)
)
def _normalize_rw_path(
path: str,
*,
restricted: bool,
local_env: bool,
umo: str,
) -> str:
normalized_path = _resolve_tool_path(path, local_env=local_env, umo=umo)
if not normalized_path:
raise ValueError("`path` must be a non-empty string.")
if restricted and not _is_path_within_allowed_roots(normalized_path, umo):
allowed = ", ".join(_restricted_env_path_labels(umo))
raise PermissionError(
"Read access is restricted for this user. "
f"Allowed directories: {allowed}. Blocked path: {normalized_path}."
)
return normalized_path
def _decode_escaped_text(value: str) -> str:
"""Decode common escaped control sequences used in tool arguments."""
return (
value.replace("\\r\\n", "\n")
.replace("\\n", "\n")
.replace("\\r", "\r")
.replace("\\t", "\t")
)
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
@dataclass
class FileReadTool(FunctionTool):
name: str = "astrbot_file_read_tool"
description: str = "read file content."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path of the file to read. If relative, will be in workspace root.",
},
"offset": {
"type": "integer",
"description": "Optional line offset to start reading from. 0-based index.",
"minimum": 0,
},
"limit": {
"type": "integer",
"description": "Optional maximum number of lines to read.",
"minimum": 1,
},
},
"required": ["path"],
}
)
def _validate_read_window(
self,
offset: int | None,
limit: int | None,
) -> tuple[int | None, int | None]:
if offset is not None and offset < 0:
raise ValueError("`offset` must be greater than or equal to 0.")
if limit is not None and limit < 1:
raise ValueError("`limit` must be greater than or equal to 1.")
return offset, limit
async def call(
self,
context: ContextWrapper[AstrAgentContext],
path: str,
offset: int | None = None,
limit: int | None = None,
) -> ToolExecResult:
local_env = is_local_runtime(context)
restricted = _is_restricted_env(context)
try:
normalized_path = (
_normalize_rw_path(
path,
restricted=restricted,
local_env=local_env,
umo=context.context.event.unified_msg_origin,
)
if local_env
else path.strip()
)
if not normalized_path:
raise ValueError("`path` must be a non-empty string.")
offset, limit = self._validate_read_window(offset, limit)
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
return await read_file_tool_result(
sb,
local_mode=local_env,
path=normalized_path,
offset=offset,
limit=limit,
workspace_dir=(
str(_workspace_root(context.context.event.unified_msg_origin))
if local_env
else None
),
)
except PermissionError as exc:
return f"Error: {exc}"
except Exception as exc:
logger.error(f"Error reading file: {exc}")
return f"Error reading file: {exc}"
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
@dataclass
class FileWriteTool(FunctionTool):
name: str = "astrbot_file_write_tool"
description: str = "Write UTF-8 text content to a file."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path of the file to write. If relative, will be in workspace root.",
},
"content": {
"type": "string",
"description": "The content to write to the file",
},
},
"required": ["path", "content"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
path: str,
content: str,
) -> ToolExecResult:
local_env = is_local_runtime(context)
restricted = _is_restricted_env(context)
try:
normalized_path = (
_normalize_rw_path(
path,
restricted=restricted,
local_env=local_env,
umo=context.context.event.unified_msg_origin,
)
if local_env
else path.strip()
)
if not normalized_path:
raise ValueError("`path` must be a non-empty string.")
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
result = await sb.fs.write_file(
path=normalized_path,
content=content,
mode="w",
encoding="utf-8",
)
if not result.get("success", False):
error_detail = str(result.get("error", "") or "").strip()
return (
"Error writing file: "
f"{error_detail or 'unknown filesystem write error'}"
)
return f"File written successfully: {normalized_path}"
except PermissionError as exc:
return f"Error: {exc}"
except Exception as exc:
logger.error(f"Error writing file: {exc}")
return f"Error writing file: {exc}"
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
@dataclass
class FileEditTool(FunctionTool):
name: str = "astrbot_file_edit_tool"
description: str = "Editing files."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path of the file to edit. If relative, will be in workspace root.",
},
"old": {
"type": "string",
"description": "The exact old text to replace.",
},
"new": {
"type": "string",
"description": "The replacement text.",
},
"replace_all": {
"type": "boolean",
"description": "Whether to replace all matches. Defaults to false.",
},
},
"required": ["path", "old", "new"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
path: str,
old: str,
new: str,
replace_all: bool = False,
) -> ToolExecResult:
umo = str(context.context.event.unified_msg_origin)
local_env = is_local_runtime(context)
restricted = _is_restricted_env(context)
try:
normalized_path = (
_normalize_rw_path(
path,
restricted=restricted,
local_env=local_env,
umo=umo,
)
if local_env
else path.strip()
)
if not normalized_path:
raise ValueError("`path` must be a non-empty string.")
normalized_old = _decode_escaped_text(old)
normalized_new = _decode_escaped_text(new)
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
result = await sb.fs.edit_file(
path=normalized_path,
old_string=normalized_old,
new_string=normalized_new,
replace_all=replace_all,
encoding="utf-8",
)
if not result.get("success", False):
error_detail = str(result.get("error", "") or "").strip()
return (
"Error editing file: "
f"{error_detail or 'unknown filesystem edit error'}"
)
replacements = int(result.get("replacements", 0) or 0)
mode_text = "all matches" if replace_all else "first match"
return (
f"Edited {normalized_path}. "
f"Replaced {replacements} occurrence(s) using {mode_text} mode."
)
except PermissionError as exc:
return f"Error: {exc}"
except Exception as exc:
logger.error(f"Error editing file: {exc}")
return f"Error editing file: {exc}"
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
@dataclass
class GrepTool(FunctionTool):
name: str = "astrbot_grep_tool"
description: str = "Search and read file contents using ripgrep."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "The expression pattern to search for in file contents.",
},
"path": {
"type": "string",
"description": "File or directory to search in (rg PATH). If relative, will be in workspace root.",
},
"glob": {
"type": "string",
"description": "Optional glob filter such as `*.py`, `*.{ts,tsx}`.",
},
"-A": {
"type": "integer",
"description": "Number of trailing context lines to include after each match.",
"minimum": 0,
},
"-B": {
"type": "integer",
"description": "Number of leading context lines to include before each match.",
"minimum": 0,
},
"-C": {
"type": "integer",
"description": "Number of leading and trailing context lines to include around each match.",
"minimum": 0,
},
"result_limit": {
"type": "integer",
"description": "Maximum number of result groups returned by the tool. Defaults to 100.",
"minimum": 1,
},
},
"required": ["pattern"],
}
)
def _resolve_context_options(
self,
after_context: int | None,
before_context: int | None,
context: int | None,
) -> tuple[int | None, int | None]:
if context is not None and context < 0:
raise ValueError("`-C` must be greater than or equal to 0.")
if after_context is not None and after_context < 0:
raise ValueError("`-A` must be greater than or equal to 0.")
if before_context is not None and before_context < 0:
raise ValueError("`-B` must be greater than or equal to 0.")
resolved_after = context if after_context is None else after_context
resolved_before = context if before_context is None else before_context
return resolved_after, resolved_before
def _split_output_groups(self, output: str, *, has_context: bool) -> list[str]:
if not output.strip():
return []
if not has_context:
return [f"{line}\n" for line in output.splitlines() if line.strip()]
groups: list[str] = []
current: list[str] = []
for line in output.splitlines(keepends=True):
if line.strip() == "--":
if current:
groups.append("".join(current))
current = []
continue
if not line.strip():
continue
current.append(line)
if current:
groups.append("".join(current))
return groups
def _apply_result_limit(
self,
output: str,
*,
result_limit: int,
has_context: bool,
) -> str:
if result_limit < 1:
raise ValueError("`result_limit` must be greater than or equal to 1.")
groups = self._split_output_groups(output, has_context=has_context)
if len(groups) <= result_limit:
return output if output.strip() else "No matches found."
limited_output = "".join(groups[:result_limit]).rstrip()
return f"{limited_output}\n\n[Truncated to first {result_limit} result groups.]"
def _normalize_search_paths(
self,
path: str | None,
*,
restricted: bool,
local_env: bool,
umo: str,
) -> list[str]:
normalized = (
[_resolve_tool_path(path, local_env=local_env, umo=umo)] if path else []
)
if not normalized:
if restricted:
return [str(root) for root in _read_allowed_roots(umo)]
if local_env:
return [str(_workspace_root(umo))]
return ["."]
if restricted:
disallowed = [
path
for path in normalized
if not _is_path_within_allowed_roots(path, umo)
]
if disallowed:
allowed = ", ".join(_restricted_env_path_labels(umo))
blocked = ", ".join(disallowed)
raise PermissionError(
"Read access is restricted for this user. "
f"Allowed directories: {allowed}. Blocked paths: {blocked}."
)
return normalized
async def call(
self,
context: ContextWrapper[AstrAgentContext],
pattern: str,
path: str | None = None,
glob: str | None = None,
result_limit: int = 100,
**kwargs,
) -> ToolExecResult:
normalized_pattern = pattern.strip()
if not normalized_pattern:
return "Error: `pattern` must be a non-empty string."
local_env = is_local_runtime(context)
restricted = _is_restricted_env(context)
try:
search_paths = (
self._normalize_search_paths(
path,
restricted=restricted,
local_env=local_env,
umo=context.context.event.unified_msg_origin,
)
if local_env
else ([path.strip()] if path and path.strip() else ["."])
)
after_context, before_context = self._resolve_context_options(
kwargs.get("-A"),
kwargs.get("-B"),
kwargs.get("-C"),
)
has_context = (after_context or 0) > 0 or (before_context or 0) > 0
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
contents: list[str] = []
for search_path in search_paths:
result = await sb.fs.search_files(
pattern=normalized_pattern,
path=search_path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
if not result.get("success", False):
error_detail = str(result.get("error", "") or "").strip()
logger.error("GrepTool search failed: %s", error_detail)
return (
"Error searching files: "
f"{error_detail or 'unknown filesystem search error'}"
)
content = str(result.get("content", "") or "")
if content:
contents.append(content)
return self._apply_result_limit(
"".join(contents),
result_limit=result_limit,
has_context=has_context,
)
except PermissionError as exc:
return f"Error: {exc}"
except Exception as exc:
logger.error(f"Error searching files: {exc}")
return f"Error searching files: {exc}"
@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG)
@dataclass
class FileUploadTool(FunctionTool):
name: str = "astrbot_upload_file"
description: str = (
"Transfer a file FROM the host machine INTO the sandbox so that sandbox "
"code can access it. Use this when the user sends/attaches a file and you "
"need to process it inside the sandbox. The local_path must point to an "
"existing file on the host filesystem."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"local_path": {
"type": "string",
"description": "Absolute path to the file on the host filesystem that will be copied into the sandbox.",
},
# "remote_path": {
# "type": "string",
# "description": "The filename to use in the sandbox. If not provided, file will be saved to the working directory with the same name as the local file.",
# },
},
"required": ["local_path"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
local_path: str,
) -> str | None:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
# Check if file exists
if not os.path.exists(local_path):
return f"Error: File does not exist: {local_path}"
if not os.path.isfile(local_path):
return f"Error: Path is not a file: {local_path}"
# Use basename if sandbox_filename is not provided
remote_path = os.path.basename(local_path)
# Upload file to sandbox
result = await sb.upload_file(local_path, remote_path)
logger.debug(f"Upload result: {result}")
success = result.get("success", False)
if not success:
return f"Error uploading file: {result.get('message', 'Unknown error')}"
file_path = result.get("file_path", "")
logger.info(f"File {local_path} uploaded to sandbox at {file_path}")
return f"File uploaded successfully to {file_path}"
except Exception as e:
logger.error(f"Error uploading file {local_path}: {e}")
return f"Error uploading file: {str(e)}"
@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG)
@dataclass
class FileDownloadTool(FunctionTool):
name: str = "astrbot_download_file"
description: str = (
"Transfer a file FROM the sandbox OUT to the host and optionally send it "
"to the user. Use this ONLY when the user asks to retrieve/export a file "
"that was created or modified inside the sandbox."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"remote_path": {
"type": "string",
"description": "Path of the file inside the sandbox to copy out to the host.",
},
"also_send_to_user": {
"type": "boolean",
"description": "Whether to also send the downloaded file to the user via message. Defaults to true.",
},
},
"required": ["remote_path"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
remote_path: str,
also_send_to_user: bool = True,
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
name = os.path.basename(remote_path)
local_path = os.path.join(
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
)
# Download file from sandbox
await sb.download_file(remote_path, local_path)
logger.info(f"File {remote_path} downloaded from sandbox to {local_path}")
if also_send_to_user:
try:
name = os.path.basename(local_path)
await context.context.event.send(
MessageChain(chain=[File(name=name, file=local_path)])
)
except Exception as e:
logger.error(f"Error sending file message: {e}")
# remove
# try:
# os.remove(local_path)
# except Exception as e:
# logger.error(f"Error removing temp file {local_path}: {e}")
return f"File downloaded successfully to {local_path} and sent to user."
return f"File downloaded successfully to {local_path}"
except Exception as e:
logger.error(f"Error downloading file {remote_path}: {e}")
return f"Error downloading file: {str(e)}"

View File

@@ -8,10 +8,18 @@ from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent
from astrbot.core.computer.computer_client import get_booter, get_local_booter
from astrbot.core.computer.tools.permissions import check_admin_permission
from astrbot.core.message.message_event_result import MessageChain
from ..registry import builtin_tool
from .util import check_admin_permission
_OS_NAME = platform.system()
_SANDBOX_PYTHON_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "sandbox",
}
_LOCAL_PYTHON_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "local",
}
param_schema = {
"type": "object",
@@ -61,6 +69,7 @@ async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult
return resp
@builtin_tool(config=_SANDBOX_PYTHON_TOOL_CONFIG)
@dataclass
class PythonTool(FunctionTool):
name: str = "astrbot_execute_ipython"
@@ -83,6 +92,7 @@ class PythonTool(FunctionTool):
return f"Error executing code: {str(e)}"
@builtin_tool(config=_LOCAL_PYTHON_TOOL_CONFIG)
@dataclass
class LocalPythonTool(FunctionTool):
name: str = "astrbot_execute_python"

View File

@@ -5,11 +5,17 @@ from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from ..computer_client import get_booter, get_local_booter
from .permissions import check_admin_permission
from ..registry import builtin_tool
from .util import check_admin_permission, is_local_runtime, workspace_root
_COMPUTER_RUNTIME_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": ("local", "sandbox"),
}
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
@dataclass
class ExecuteShellTool(FunctionTool):
name: str = "astrbot_execute_shell"
@@ -38,8 +44,6 @@ class ExecuteShellTool(FunctionTool):
}
)
is_local: bool = False
async def call(
self,
context: ContextWrapper[AstrAgentContext],
@@ -50,15 +54,25 @@ class ExecuteShellTool(FunctionTool):
if permission_error := check_admin_permission(context, "Shell execution"):
return permission_error
if self.is_local:
sb = get_local_booter()
else:
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
result = await sb.shell.exec(command, background=background, env=env)
cwd: str | None = None
if is_local_runtime(context):
current_workspace_root = workspace_root(
context.context.event.unified_msg_origin
)
current_workspace_root.mkdir(parents=True, exist_ok=True)
cwd = str(current_workspace_root)
result = await sb.shell.exec(
command,
cwd=cwd,
background=background,
env=env,
)
return json.dumps(result)
except Exception as e:
return f"Error executing command: {str(e)}"

View File

@@ -1,5 +1,4 @@
from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool
from .fs import FileDownloadTool, FileUploadTool
from .neo_skills import (
AnnotateExecutionTool,
CreateSkillCandidateTool,
@@ -13,27 +12,20 @@ from .neo_skills import (
RollbackSkillReleaseTool,
SyncSkillReleaseTool,
)
from .python import LocalPythonTool, PythonTool
from .shell import ExecuteShellTool
__all__ = [
"BrowserExecTool",
"BrowserBatchExecTool",
"RunBrowserSkillTool",
"GetExecutionHistoryTool",
"AnnotateExecutionTool",
"CreateSkillPayloadTool",
"GetSkillPayloadTool",
"BrowserBatchExecTool",
"BrowserExecTool",
"CreateSkillCandidateTool",
"ListSkillCandidatesTool",
"CreateSkillPayloadTool",
"EvaluateSkillCandidateTool",
"PromoteSkillCandidateTool",
"GetExecutionHistoryTool",
"GetSkillPayloadTool",
"ListSkillCandidatesTool",
"ListSkillReleasesTool",
"PromoteSkillCandidateTool",
"RollbackSkillReleaseTool",
"RunBrowserSkillTool",
"SyncSkillReleaseTool",
"FileUploadTool",
"PythonTool",
"LocalPythonTool",
"ExecuteShellTool",
"FileDownloadTool",
]

View File

@@ -6,9 +6,14 @@ from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.tools.computer_tools.util import check_admin_permission
from astrbot.core.tools.registry import builtin_tool
from ..computer_client import get_booter
from .permissions import check_admin_permission
_SHIPYARD_NEO_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
}
def _to_json(data: Any) -> str:
@@ -29,6 +34,7 @@ async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> A
return browser
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class BrowserExecTool(FunctionTool):
name: str = "astrbot_execute_browser"
@@ -86,6 +92,7 @@ class BrowserExecTool(FunctionTool):
return f"Error executing browser command: {str(e)}"
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class BrowserBatchExecTool(FunctionTool):
name: str = "astrbot_execute_browser_batch"
@@ -150,6 +157,7 @@ class BrowserBatchExecTool(FunctionTool):
return f"Error executing browser batch command: {str(e)}"
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class RunBrowserSkillTool(FunctionTool):
name: str = "astrbot_run_browser_skill"

View File

@@ -7,10 +7,15 @@ from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
from astrbot.core.tools.computer_tools.util import check_admin_permission
from astrbot.core.tools.registry import builtin_tool
from ..computer_client import get_booter
from .permissions import check_admin_permission
_SHIPYARD_NEO_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
}
def _to_jsonable(model_like: Any) -> Any:
@@ -64,6 +69,7 @@ class NeoSkillToolBase(FunctionTool):
return f"{self.error_prefix} {error_action}: {str(e)}"
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class GetExecutionHistoryTool(NeoSkillToolBase):
name: str = "astrbot_get_execution_history"
@@ -110,6 +116,7 @@ class GetExecutionHistoryTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class AnnotateExecutionTool(NeoSkillToolBase):
name: str = "astrbot_annotate_execution"
@@ -147,6 +154,7 @@ class AnnotateExecutionTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class CreateSkillPayloadTool(NeoSkillToolBase):
name: str = "astrbot_create_skill_payload"
@@ -194,6 +202,7 @@ class CreateSkillPayloadTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class GetSkillPayloadTool(NeoSkillToolBase):
name: str = "astrbot_get_skill_payload"
@@ -220,6 +229,7 @@ class GetSkillPayloadTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class CreateSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_create_skill_candidate"
@@ -273,6 +283,7 @@ class CreateSkillCandidateTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class ListSkillCandidatesTool(NeoSkillToolBase):
name: str = "astrbot_list_skill_candidates"
@@ -310,6 +321,7 @@ class ListSkillCandidatesTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class EvaluateSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_evaluate_skill_candidate"
@@ -350,6 +362,7 @@ class EvaluateSkillCandidateTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class PromoteSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_promote_skill_candidate"
@@ -420,6 +433,7 @@ class PromoteSkillCandidateTool(NeoSkillToolBase):
return f"Error promoting skill candidate: {str(e)}"
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class ListSkillReleasesTool(NeoSkillToolBase):
name: str = "astrbot_list_skill_releases"
@@ -460,6 +474,7 @@ class ListSkillReleasesTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class RollbackSkillReleaseTool(NeoSkillToolBase):
name: str = "astrbot_rollback_skill_release"
@@ -486,6 +501,7 @@ class RollbackSkillReleaseTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class SyncSkillReleaseTool(NeoSkillToolBase):
name: str = "astrbot_sync_skill_release"

View File

@@ -1,5 +1,29 @@
import re
from pathlib import Path
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.utils.astrbot_path import get_astrbot_workspaces_path
def normalize_umo_for_workspace(umo: str) -> str:
normalized = re.sub(r"[^A-Za-z0-9._-]+", "_", umo.strip())
return normalized or "unknown"
def workspace_root(umo: str) -> Path:
"""Root directory for relative paths in local runtime"""
normalized_umo = normalize_umo_for_workspace(umo)
return (Path(get_astrbot_workspaces_path()) / normalized_umo).resolve(strict=False)
def is_local_runtime(context: ContextWrapper[AstrAgentContext]) -> bool:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
return runtime == "local"
def check_admin_permission(

View File

@@ -9,6 +9,10 @@ from astrbot.core.agent.tool import FunctionTool, ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.tools.registry import builtin_tool
_CRON_TOOL_CONFIG = {
"provider_settings.proactive_capability.add_cron_tools": True,
}
def _extract_job_session(job: Any) -> str | None:
payload = getattr(job, "payload", None)
@@ -24,7 +28,7 @@ def _parse_run_at(run_at: Any) -> datetime | None:
return datetime.fromisoformat(str(run_at))
@builtin_tool
@builtin_tool(config=_CRON_TOOL_CONFIG)
@dataclass
class FutureTaskTool(FunctionTool[AstrAgentContext]):
name: str = "future_task"

View File

@@ -9,6 +9,10 @@ from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.star.context import Context
from astrbot.core.tools.registry import builtin_tool
_KNOWLEDGE_BASE_TOOL_CONFIG = {
"kb_agentic_mode": True,
}
def check_all_kb(kb_list: list[KBHelper | None]) -> bool:
"""检查是否所有的知识库都为空"""
@@ -83,7 +87,7 @@ async def retrieve_knowledge_base(
return None
@builtin_tool
@builtin_tool(config=_KNOWLEDGE_BASE_TOOL_CONFIG)
@dataclass
class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
name: str = "astr_kb_search"

View File

@@ -1,13 +1,16 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from importlib import import_module
from typing import TypeVar
from typing import Any, TypeVar
from astrbot.core.agent.tool import FunctionTool
TFunctionTool = TypeVar("TFunctionTool", bound=type[FunctionTool])
_BUILTIN_TOOL_MODULES = (
"astrbot.core.tools.computer_tools",
"astrbot.core.tools.cron_tools",
"astrbot.core.tools.knowledge_base_tools",
"astrbot.core.tools.message_tools",
@@ -17,6 +20,182 @@ _BUILTIN_TOOL_MODULES = (
_builtin_tool_classes_by_name: dict[str, type[FunctionTool]] = {}
_builtin_tool_names_by_class: dict[type[FunctionTool], str] = {}
_builtin_tools_loaded = False
_MISSING = object()
@dataclass(frozen=True)
class BuiltinToolConfigCondition:
key: str
operator: str
expected: Any = None
message: str | None = None
def evaluate(self, config: dict[str, Any]) -> dict[str, Any]:
actual = _get_config_value(config, self.key)
if self.operator == "equals":
matched = actual == self.expected
elif self.operator == "in":
expected_values = tuple(self.expected or ())
matched = actual in expected_values
elif self.operator == "truthy":
matched = bool(actual)
elif self.operator == "custom":
matched = bool(self.expected)
else:
raise ValueError(
f"Unsupported builtin tool config operator: {self.operator}"
)
return {
"key": self.key,
"operator": self.operator,
"expected": _json_safe(self.expected),
"actual": _json_safe(None if actual is _MISSING else actual),
"matched": matched,
"message": self.message,
}
@dataclass(frozen=True)
class BuiltinToolConfigRule:
conditions: tuple[BuiltinToolConfigCondition, ...] = ()
evaluator: Callable[[dict[str, Any]], list[dict[str, Any]]] | None = None
def evaluate(self, config: dict[str, Any]) -> list[dict[str, Any]]:
if self.evaluator is not None:
return self.evaluator(config)
return [condition.evaluate(config) for condition in self.conditions]
def _get_config_value(config: dict[str, Any], key_path: str) -> Any:
current: Any = config
for segment in key_path.split("."):
if not isinstance(current, dict) or segment not in current:
return _MISSING
current = current[segment]
return current
def _json_safe(value: Any) -> Any:
if isinstance(value, tuple):
return [_json_safe(item) for item in value]
if isinstance(value, list):
return [_json_safe(item) for item in value]
if isinstance(value, dict):
return {key: _json_safe(val) for key, val in value.items()}
return value
def _equals(key: str, expected: Any) -> BuiltinToolConfigCondition:
return BuiltinToolConfigCondition(key=key, operator="equals", expected=expected)
def _in(key: str, expected: tuple[Any, ...]) -> BuiltinToolConfigCondition:
return BuiltinToolConfigCondition(key=key, operator="in", expected=expected)
def _custom_condition(key: str, *, matched: bool, message: str) -> dict[str, Any]:
return {
"key": key,
"operator": "custom",
"expected": None,
"actual": None,
"matched": matched,
"message": message,
}
def _build_rule_from_config_map(
config_map: dict[str, Any],
) -> BuiltinToolConfigRule:
conditions: list[BuiltinToolConfigCondition] = []
for key, expected in config_map.items():
if isinstance(expected, tuple):
conditions.append(_in(key, expected))
else:
conditions.append(_equals(key, expected))
return BuiltinToolConfigRule(conditions=tuple(conditions))
def _evaluate_send_message_tool(config: dict[str, Any]) -> list[dict[str, Any]]:
platform_configs = config.get("platform", [])
if not isinstance(platform_configs, list):
return [
_custom_condition(
"platform",
matched=False,
message="No enabled platform in this config supports proactive messaging.",
)
]
for platform_cfg in platform_configs:
if not isinstance(platform_cfg, dict):
continue
if platform_cfg.get("enable", False) is False:
continue
platform_type = str(platform_cfg.get("type", "")).strip()
platform_id = str(platform_cfg.get("id", "")).strip() or platform_type
if not platform_type:
continue
if platform_type in {"wecom", "weixin_official_account"}:
continue
if platform_type == "wecom_ai_bot":
webhook = str(platform_cfg.get("msg_push_webhook_url", "")).strip()
if not webhook:
continue
return [
_custom_condition(
"platform[].type",
matched=True,
message=(
f"Enabled platform `{platform_id}` uses `wecom_ai_bot`, which supports proactive messaging "
"when `platform[].msg_push_webhook_url` is configured."
),
),
BuiltinToolConfigCondition(
key="platform[].msg_push_webhook_url",
operator="truthy",
).evaluate({"platform[]": {"msg_push_webhook_url": webhook}}),
]
return [
_custom_condition(
"platform[].type",
matched=True,
message=(
f"Enabled platform `{platform_id}` (`{platform_type}`) supports proactive messaging."
),
)
]
return [
_custom_condition(
"platform",
matched=False,
message="No enabled platform in this config supports proactive messaging.",
)
]
_BUILTIN_TOOL_CONFIG_RULES: dict[str, BuiltinToolConfigRule] = {}
def _register_builtin_tool_config_rule(
tool_names: tuple[str, ...],
rule: BuiltinToolConfigRule,
) -> None:
for tool_name in tool_names:
_BUILTIN_TOOL_CONFIG_RULES[tool_name] = rule
_register_builtin_tool_config_rule(
("send_message_to_user",),
BuiltinToolConfigRule(evaluator=_evaluate_send_message_tool),
)
def _resolve_builtin_tool_name(tool_cls: type[FunctionTool]) -> str:
@@ -34,18 +213,29 @@ def _resolve_builtin_tool_name(tool_cls: type[FunctionTool]) -> str:
)
def builtin_tool(tool_cls: TFunctionTool) -> TFunctionTool:
tool_name = _resolve_builtin_tool_name(tool_cls)
existing = _builtin_tool_classes_by_name.get(tool_name)
if existing is not None and existing is not tool_cls:
raise ValueError(
f"Builtin tool name conflict detected: {tool_name} is already registered by "
f"{existing.__module__}.{existing.__name__}.",
)
def builtin_tool(
tool_cls: TFunctionTool | None = None,
*,
config: dict[str, Any] | None = None,
) -> TFunctionTool | Callable[[TFunctionTool], TFunctionTool]:
def _register(cls: TFunctionTool) -> TFunctionTool:
tool_name = _resolve_builtin_tool_name(cls)
existing = _builtin_tool_classes_by_name.get(tool_name)
if existing is not None and existing is not cls:
raise ValueError(
f"Builtin tool name conflict detected: {tool_name} is already registered by "
f"{existing.__module__}.{existing.__name__}.",
)
_builtin_tool_classes_by_name[tool_name] = tool_cls
_builtin_tool_names_by_class[tool_cls] = tool_name
return tool_cls
_builtin_tool_classes_by_name[tool_name] = cls
_builtin_tool_names_by_class[cls] = tool_name
if config is not None:
_BUILTIN_TOOL_CONFIG_RULES[tool_name] = _build_rule_from_config_map(config)
return cls
if tool_cls is None:
return _register
return _register(tool_cls)
def ensure_builtin_tools_loaded() -> None:
@@ -74,9 +264,64 @@ def iter_builtin_tool_classes() -> tuple[type[FunctionTool], ...]:
return tuple(_builtin_tool_classes_by_name.values())
def get_builtin_tool_config_rule(name: str) -> BuiltinToolConfigRule | None:
ensure_builtin_tools_loaded()
return _BUILTIN_TOOL_CONFIG_RULES.get(name)
def get_builtin_tool_config_statuses(
tool_name: str,
config_entries: list[dict[str, Any]],
) -> list[dict[str, Any]]:
rule = get_builtin_tool_config_rule(tool_name)
if rule is None:
return []
statuses: list[dict[str, Any]] = []
for entry in config_entries:
config = entry.get("config")
if not isinstance(config, dict):
continue
conditions = rule.evaluate(config)
enabled = bool(conditions) and all(
bool(condition.get("matched")) for condition in conditions
)
statuses.append(
{
"conf_id": entry.get("conf_id"),
"conf_name": entry.get("conf_name"),
"enabled": enabled,
"matched_conditions": [
condition for condition in conditions if condition.get("matched")
],
"failed_conditions": [
condition
for condition in conditions
if not condition.get("matched")
],
}
)
return statuses
def get_builtin_tool_config_tags(
tool_name: str,
config_entries: list[dict[str, Any]],
) -> list[dict[str, Any]]:
return [
status
for status in get_builtin_tool_config_statuses(tool_name, config_entries)
if status["enabled"]
]
__all__ = [
"builtin_tool",
"ensure_builtin_tools_loaded",
"get_builtin_tool_config_rule",
"get_builtin_tool_config_statuses",
"get_builtin_tool_config_tags",
"get_builtin_tool_class",
"get_builtin_tool_name",
"iter_builtin_tool_classes",

View File

@@ -20,6 +20,22 @@ WEB_SEARCH_TOOL_NAMES = [
"web_search_bocha",
"web_search_brave",
]
_TAVILY_WEB_SEARCH_TOOL_CONFIG = {
"provider_settings.web_search": True,
"provider_settings.websearch_provider": "tavily",
}
_BOCHA_WEB_SEARCH_TOOL_CONFIG = {
"provider_settings.web_search": True,
"provider_settings.websearch_provider": "bocha",
}
_BRAVE_WEB_SEARCH_TOOL_CONFIG = {
"provider_settings.web_search": True,
"provider_settings.websearch_provider": "brave",
}
_BAIDU_WEB_SEARCH_TOOL_CONFIG = {
"provider_settings.web_search": True,
"provider_settings.websearch_provider": "baidu_ai_search",
}
@std_dataclass
@@ -276,7 +292,7 @@ async def _baidu_search(
]
@builtin_tool
@builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG)
@pydantic_dataclass
class TavilyWebSearchTool(FunctionTool[AstrAgentContext]):
name: str = "web_search_tavily"
@@ -359,7 +375,7 @@ class TavilyWebSearchTool(FunctionTool[AstrAgentContext]):
return _search_result_payload(results)
@builtin_tool
@builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG)
@pydantic_dataclass
class TavilyExtractWebPageTool(FunctionTool[AstrAgentContext]):
name: str = "tavily_extract_web_page"
@@ -406,7 +422,7 @@ class TavilyExtractWebPageTool(FunctionTool[AstrAgentContext]):
return ret or "Error: Tavily web searcher does not return any results."
@builtin_tool
@builtin_tool(config=_BOCHA_WEB_SEARCH_TOOL_CONFIG)
@pydantic_dataclass
class BochaWebSearchTool(FunctionTool[AstrAgentContext]):
name: str = "web_search_bocha"
@@ -470,7 +486,7 @@ class BochaWebSearchTool(FunctionTool[AstrAgentContext]):
return _search_result_payload(results)
@builtin_tool
@builtin_tool(config=_BRAVE_WEB_SEARCH_TOOL_CONFIG)
@pydantic_dataclass
class BraveWebSearchTool(FunctionTool[AstrAgentContext]):
name: str = "web_search_brave"
@@ -528,7 +544,7 @@ class BraveWebSearchTool(FunctionTool[AstrAgentContext]):
return _search_result_payload(results)
@builtin_tool
@builtin_tool(config=_BAIDU_WEB_SEARCH_TOOL_CONFIG)
@pydantic_dataclass
class BaiduWebSearchTool(FunctionTool[AstrAgentContext]):
name: str = "web_search_baidu"

View File

@@ -1,32 +1,33 @@
"""Astrbot统一路径获取
"""Centralized AstrBot path helpers.
项目路径:固定为源码所在路径
根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定
数据目录路径:固定为根目录下的 data 目录
配置文件路径:固定为数据目录下的 config 目录
插件目录路径:固定为数据目录下的 plugins 目录
插件数据目录路径:固定为数据目录下的 plugin_data 目录
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
临时文件目录路径:固定为数据目录下的 temp 目录
Skills 目录路径:固定为数据目录下的 skills 目录
第三方依赖目录路径:固定为数据目录下的 site-packages 目录
Project path:
- Fixed to the source tree location.
Root path:
- Defaults to the current working directory.
- Can be overridden with the ``ASTRBOT_ROOT`` environment variable.
Data subdirectories:
- Most runtime data lives under ``<root>/data``.
- A few tool-runtime files intentionally live under the system temporary
directory as ``.astrbot``.
"""
import os
import tempfile
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
def get_astrbot_path() -> str:
"""获取Astrbot项目路径"""
"""Return the AstrBot project source path."""
return os.path.realpath(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../"),
)
def get_astrbot_root() -> str:
"""获取Astrbot根目录路径"""
"""Return the AstrBot root directory."""
if path := os.environ.get("ASTRBOT_ROOT"):
return os.path.realpath(path)
if is_packaged_desktop_runtime():
@@ -35,55 +36,65 @@ def get_astrbot_root() -> str:
def get_astrbot_data_path() -> str:
"""获取Astrbot数据目录路径"""
"""Return the AstrBot data directory path."""
return os.path.realpath(os.path.join(get_astrbot_root(), "data"))
def get_astrbot_config_path() -> str:
"""获取Astrbot配置文件路径"""
"""Return the AstrBot config directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "config"))
def get_astrbot_plugin_path() -> str:
"""获取Astrbot插件目录路径"""
"""Return the AstrBot plugin directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
def get_astrbot_plugin_data_path() -> str:
"""获取Astrbot插件数据目录路径"""
"""Return the AstrBot plugin data directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
def get_astrbot_t2i_templates_path() -> str:
"""获取Astrbot T2I 模板目录路径"""
"""Return the AstrBot T2I templates directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
def get_astrbot_webchat_path() -> str:
"""获取Astrbot WebChat 数据目录路径"""
"""Return the AstrBot WebChat data directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
def get_astrbot_temp_path() -> str:
"""获取Astrbot临时文件目录路径"""
"""Return the AstrBot temporary data directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
def get_astrbot_skills_path() -> str:
"""获取Astrbot Skills 目录路径"""
"""Return the AstrBot skills directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "skills"))
def get_astrbot_workspaces_path() -> str:
"""Return the AstrBot workspaces directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "workspaces"))
def get_astrbot_system_tmp_path() -> str:
"""Return the shared system temporary directory used by local tools."""
return os.path.realpath(os.path.join(tempfile.gettempdir(), ".astrbot"))
def get_astrbot_site_packages_path() -> str:
"""获取Astrbot第三方依赖目录路径"""
"""Return the AstrBot third-party site-packages directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "site-packages"))
def get_astrbot_knowledge_base_path() -> str:
"""获取Astrbot知识库根目录路径"""
"""Return the AstrBot knowledge base root path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
def get_astrbot_backups_path() -> str:
"""获取Astrbot备份目录路径"""
"""Return the AstrBot backups directory path."""
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))

View File

@@ -6,6 +6,7 @@ from astrbot.core import logger
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star import star_map
from astrbot.core.tools.registry import get_builtin_tool_config_statuses
from .route import Response, Route, RouteContext
@@ -434,13 +435,36 @@ class ToolsRoute(Route):
if tool.name not in existing_names:
tools.append(tool)
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
conf_name_map = {conf["id"]: conf["name"] for conf in conf_list}
config_entries = []
for conf_id, conf in self.core_lifecycle.astrbot_config_mgr.confs.items():
config_entries.append(
{
"conf_id": conf_id,
"conf_name": conf_name_map.get(conf_id, conf_id),
"config": conf,
}
)
tools_dict = []
for tool in tools:
readonly = False
builtin_config_statuses = []
builtin_config_tags = []
if self.tool_mgr.is_builtin_tool(tool.name):
origin = "builtin"
origin_name = "AstrBot Core"
readonly = True
builtin_config_statuses = get_builtin_tool_config_statuses(
tool.name,
config_entries,
)
builtin_config_tags = [
status
for status in builtin_config_statuses
if status["enabled"]
]
elif isinstance(tool, MCPTool):
origin = "mcp"
origin_name = tool.mcp_server_name
@@ -462,6 +486,8 @@ class ToolsRoute(Route):
"origin": origin,
"origin_name": origin_name,
"readonly": readonly,
"builtin_config_statuses": builtin_config_statuses,
"builtin_config_tags": builtin_config_tags,
}
tools_dict.append(tool_info)
return Response().ok(data=tools_dict).__dict__

View File

@@ -1,7 +1,7 @@
<script setup lang="ts">
import { computed } from 'vue';
import { useModuleI18n } from '@/i18n/composables';
import type { ToolItem } from '../types';
import type { BuiltinToolConfigTag, ToolConfigCondition, ToolItem } from '../types';
const { tm: tmTool } = useModuleI18n('features/tooluse');
@@ -15,7 +15,7 @@ const emit = defineEmits<{
}>();
const toolHeaders = computed(() => [
{ title: tmTool('functionTools.title'), key: 'name', minWidth: '240px' },
{ title: tmTool('functionTools.title'), key: 'name', minWidth: '320px' },
{ title: tmTool('functionTools.description'), key: 'description' },
{ title: tmTool('functionTools.table.origin'), key: 'origin', sortable: false, width: '120px' },
{ title: tmTool('functionTools.table.originName'), key: 'origin_name', sortable: false, width: '160px' },
@@ -23,6 +23,52 @@ const toolHeaders = computed(() => [
]);
const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.properties || {});
const formatConfigValue = (value: unknown) => {
if (Array.isArray(value)) {
return value.map(item => String(item)).join(', ');
}
if (typeof value === 'boolean') {
return value ? 'true' : 'false';
}
if (value === null || value === undefined || value === '') {
return '-';
}
return String(value);
};
const formatCondition = (condition: ToolConfigCondition) => {
if (condition.message) {
return condition.message;
}
switch (condition.operator) {
case 'truthy':
return tmTool('functionTools.configTags.conditions.truthy', {
key: condition.key
});
case 'equals':
return tmTool('functionTools.configTags.conditions.equals', {
key: condition.key,
expected: formatConfigValue(condition.expected)
});
case 'in':
return tmTool('functionTools.configTags.conditions.in', {
key: condition.key,
expected: formatConfigValue(condition.expected)
});
default:
return tmTool('functionTools.configTags.conditions.fallback', {
key: condition.key,
actual: formatConfigValue(condition.actual)
});
}
};
const enabledConfigTags = (tool: ToolItem): BuiltinToolConfigTag[] => {
if (tool.origin !== 'builtin') return [];
return (tool.builtin_config_tags || []).filter(tag => tag.enabled);
};
</script>
<template>
@@ -38,7 +84,39 @@ const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.pro
>
<template #item.name="{ item }">
<div class="py-2">
<div class="tool-name text-body-2 font-weight-medium">{{ item.name }}</div>
<div class="d-flex flex-wrap align-center ga-1">
<div class="tool-name text-body-2 font-weight-medium">{{ item.name }}</div>
<v-tooltip
v-for="tag in enabledConfigTags(item)"
:key="`${item.name}-${tag.conf_id}`"
location="top"
>
<template #activator="{ props: tooltipProps }">
<v-chip
v-bind="tooltipProps"
size="x-small"
variant="tonal"
color="secondary"
class="text-caption font-weight-medium"
>
{{ tag.conf_name }}
</v-chip>
</template>
<div class="tool-config-tooltip">
<div class="text-body-2 font-weight-medium mb-2">
{{ tmTool('functionTools.configTags.tooltipTitle', { config: tag.conf_name }) }}
</div>
<div
v-for="(condition, index) in tag.matched_conditions"
:key="`${tag.conf_id}-${index}-${condition.key}`"
class="text-body-2 text-medium-emphasis mb-1"
>
{{ formatCondition(condition) }}
</div>
</div>
</v-tooltip>
</div>
</div>
</template>
@@ -135,4 +213,16 @@ const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.pro
font-size: 0.9rem;
line-height: 1.35;
}
.tool-config-tooltip {
max-width: 360px;
padding: 4px 0;
color: rgba(255, 255, 255, 0.92);
}
.tool-config-tooltip :deep(.text-body-2),
.tool-config-tooltip :deep(.text-medium-emphasis),
.tool-config-tooltip :deep(.font-weight-medium) {
color: inherit !important;
}
</style>

View File

@@ -89,6 +89,23 @@ export interface ToolParameter {
description?: string;
}
export interface ToolConfigCondition {
key: string;
operator: 'truthy' | 'equals' | 'in' | 'custom' | string;
expected?: unknown;
actual?: unknown;
matched: boolean;
message?: string | null;
}
export interface BuiltinToolConfigTag {
conf_id: string;
conf_name: string;
enabled: boolean;
matched_conditions: ToolConfigCondition[];
failed_conditions: ToolConfigCondition[];
}
/** MCP/函数工具对象 */
export interface ToolItem {
name: string;
@@ -100,4 +117,6 @@ export interface ToolItem {
};
origin?: string;
origin_name?: string;
builtin_config_statuses?: BuiltinToolConfigTag[];
builtin_config_tags?: BuiltinToolConfigTag[];
}

View File

@@ -90,31 +90,52 @@
<div v-if="filteredTools.length > 0" class="tools-selection">
<v-virtual-scroll :items="filteredTools" height="300" item-height="72">
<template v-slot:default="{ item }">
<v-list-item :key="item.name" density="comfortable"
@click="toggleTool(item.name)">
<template v-slot:prepend>
<v-checkbox-btn :model-value="isToolSelected(item.name)"
@click.stop="toggleTool(item.name)" />
<v-tooltip
:disabled="!isBuiltinTool(item)"
location="top"
>
<template v-slot:activator="{ props: tooltipProps }">
<div v-bind="tooltipProps">
<v-list-item
:key="item.name"
density="comfortable"
:disabled="isBuiltinTool(item)"
@click="toggleTool(item.name)"
>
<template v-slot:prepend>
<v-checkbox-btn
v-if="!isBuiltinTool(item)"
:model-value="isToolSelected(item.name)"
@click.stop="toggleTool(item.name)"
/>
<div
v-else
class="builtin-tool-checkbox-placeholder"
/>
</template>
<v-list-item-title>
{{ item.name }}
<v-chip v-if="item.origin" size="x-small" color="info" class="mr-2"
variant="tonal">
{{ item.origin }}
</v-chip>
<v-chip v-if="item.origin_name" size="x-small" color="info"
variant="outlined">
{{ item.origin_name }}
</v-chip>
</v-list-item-title>
<v-list-item-subtitle v-if="item.description">
{{ truncateText(item.description, 100) }}
</v-list-item-subtitle>
</v-list-item>
</div>
</template>
<v-list-item-title>
{{ item.name }}
<v-chip v-if="item.origin" size="x-small" color="info" class="mr-2"
variant="tonal">
{{ item.origin }}
</v-chip>
<v-chip v-if="item.origin_name" size="x-small" color="info"
variant="outlined">
{{ item.origin_name }}
</v-chip>
</v-list-item-title>
<v-list-item-subtitle v-if="item.description">
{{ truncateText(item.description, 100) }}
</v-list-item-subtitle>
</v-list-item>
<span>{{ tm('form.builtinToolDisabledHint') }}</span>
</v-tooltip>
</template>
</v-virtual-scroll>
</div>
@@ -155,11 +176,26 @@
</h4>
<div v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
class="d-flex flex-wrap ga-1" style="max-height: 100px; overflow-y: auto;">
<v-chip v-for="toolName in personaForm.tools" :key="toolName" size="small"
color="primary" variant="tonal" closable
@click:close="removeTool(toolName)">
{{ toolName }}
</v-chip>
<v-tooltip
v-for="toolName in personaForm.tools"
:key="toolName"
:disabled="!isBuiltinToolName(toolName)"
location="top"
>
<template v-slot:activator="{ props: tooltipProps }">
<v-chip
v-bind="tooltipProps"
size="small"
color="primary"
variant="tonal"
:closable="!isBuiltinToolName(toolName)"
@click:close="removeTool(toolName)"
>
{{ toolName }}
</v-chip>
</template>
<span>{{ tm('form.builtinToolDisabledHint') }}</span>
</v-tooltip>
</div>
<div v-else class="text-body-2 text-medium-emphasis">
{{ tm('form.noToolsSelected') }}
@@ -712,6 +748,9 @@ export default {
},
toggleTool(toolName) {
if (this.isBuiltinToolName(toolName)) {
return;
}
// 如果当前是全选状态,需要先转换为具体的工具列表
if (this.personaForm.tools === null) {
// 如果是全选状态,点击某个工具表示要取消选择该工具
@@ -735,6 +774,9 @@ export default {
},
removeTool(toolName) {
if (this.isBuiltinToolName(toolName)) {
return;
}
// 如果当前是全选状态,需要先转换为具体的工具列表
if (this.personaForm.tools === null) {
// 创建一个包含所有工具的数组,然后移除指定工具
@@ -784,6 +826,14 @@ export default {
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
},
isBuiltinTool(tool) {
return tool?.origin === 'builtin' || tool?.readonly === true;
},
isBuiltinToolName(toolName) {
return this.availableTools.some(tool => tool.name === toolName && this.isBuiltinTool(tool));
},
getDialogRules(index) {
const dialogType = index % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
return [
@@ -859,6 +909,12 @@ export default {
overflow-y: auto;
}
.builtin-tool-checkbox-placeholder {
width: 40px;
height: 40px;
flex: 0 0 40px;
}
.skills-selection {
max-height: 300px;
overflow-y: auto;

View File

@@ -117,7 +117,9 @@ const defaultPersonaData = {
const normalizedTools = computed(() => (Array.isArray(personaData.value?.tools) ? personaData.value.tools : []))
const normalizedSkills = computed(() => (Array.isArray(personaData.value?.skills) ? personaData.value.skills : []))
const allToolsCount = computed(() => Object.keys(toolMetaMap.value).length)
const allToolsCount = computed(() =>
Object.values(toolMetaMap.value).filter((tool) => tool.origin !== 'builtin').length
)
const allSkillsCount = computed(() => availableSkills.value.length)
const resolvedTools = computed(() =>
normalizedTools.value.map((toolName) => {

View File

@@ -35,6 +35,7 @@
"mcpServersQuickSelect": "MCP Servers Quick Select",
"searchTools": "Search Tools",
"selectedTools": "Selected Tools",
"builtinToolDisabledHint": "Builtin tools cannot be enabled or disabled here yet. Please enable or disable the corresponding config items in the config file.",
"noToolsAvailable": "No tools available",
"noToolsFound": "No matching tools found",
"loadingTools": "Loading tools...",

View File

@@ -47,6 +47,15 @@
"originName": "Origin Name",
"readonly": "Read-only",
"actions": "Actions"
},
"configTags": {
"tooltipTitle": "This tool is enabled in config file {config} because:",
"conditions": {
"truthy": "{key} is enabled",
"equals": "{key} = {expected}",
"in": "{key} matched {expected}",
"fallback": "{key} is currently {actual}"
}
}
},
"marketplace": {

View File

@@ -35,6 +35,7 @@
"mcpServersQuickSelect": "Быстрый выбор MCP серверов",
"searchTools": "Поиск инструментов",
"selectedTools": "Выбранные инструменты",
"builtinToolDisabledHint": "Встроенные инструменты пока нельзя включать или выключать здесь. Измените соответствующие параметры в файле конфигурации.",
"noToolsAvailable": "Нет доступных инструментов",
"noToolsFound": "Инструменты не найдены",
"loadingTools": "Загрузка инструментов...",
@@ -143,4 +144,4 @@
"success": "Объект перемещен",
"error": "Ошибка перемещения"
}
}
}

View File

@@ -47,6 +47,15 @@
"originName": "Имя источника",
"readonly": "Только чтение",
"actions": "Действия"
},
"configTags": {
"tooltipTitle": "Этот инструмент включен в файле конфигурации {config}, потому что:",
"conditions": {
"truthy": "{key} включен",
"equals": "{key} = {expected}",
"in": "{key} соответствует {expected}",
"fallback": "Текущее значение {key}: {actual}"
}
}
},
"marketplace": {

View File

@@ -35,6 +35,7 @@
"mcpServersQuickSelect": "MCP 服务器快速选择",
"searchTools": "搜索工具",
"selectedTools": "已选择的工具",
"builtinToolDisabledHint": "暂不支持在这里启用和停用内置工具,请在配置文件中启用和停用工具对应的配置项。",
"noToolsAvailable": "暂无可用工具",
"noToolsFound": "未找到匹配的工具",
"loadingTools": "正在加载工具...",

View File

@@ -47,6 +47,15 @@
"originName": "来源名称",
"readonly": "只读",
"actions": "操作"
},
"configTags": {
"tooltipTitle": "该工具在配置文件 {config} 中启用,因为:",
"conditions": {
"truthy": "启用了 {key}",
"equals": "{key} = {expected}",
"in": "{key} 命中了 {expected}",
"fallback": "{key} 当前值为 {actual}"
}
}
},
"marketplace": {

View File

@@ -15,13 +15,11 @@ dependencies = [
"aiosqlite>=0.21.0",
"anthropic>=0.51.0",
"apscheduler>=3.11.0",
"beautifulsoup4>=4.13.4",
"certifi>=2025.4.26",
"chardet~=5.1.0",
"loguru>=0.7.2",
"cryptography>=44.0.3",
"dashscope>=1.23.2",
"defusedxml>=0.7.1",
"deprecated>=1.2.18",
"dingtalk-stream>=0.22.1",
"docstring-parser>=0.16",
@@ -30,7 +28,6 @@ dependencies = [
"google-genai>=1.56.0",
"httpx[socks]>=0.28.1",
"lark-oapi>=1.4.15",
"lxml-html-clean>=0.4.2",
"mcp>=1.8.0",
"openai>=1.78.0",
"ormsgpack>=1.9.1",
@@ -45,7 +42,6 @@ dependencies = [
"python-telegram-bot>=22.6",
"qq-botpy>=1.2.1",
"quart>=0.20.0",
"readability-lxml>=0.8.4.1",
"silk-python>=0.2.6",
"slack-sdk>=3.35.0",
"sqlalchemy[asyncio]>=2.0.41",
@@ -68,6 +64,7 @@ dependencies = [
"python-socks>=2.8.0",
"pysocks>=1.7.1",
"packaging>=24.2",
"python-ripgrep==0.0.9",
]
[dependency-groups]

View File

@@ -4,13 +4,11 @@ aiohttp>=3.11.18
aiosqlite>=0.21.0
anthropic>=0.51.0
apscheduler>=3.11.0
beautifulsoup4>=4.13.4
certifi>=2025.4.26
chardet~=5.1.0
loguru>=0.7.2
cryptography>=44.0.3
dashscope>=1.23.2
defusedxml>=0.7.1
deprecated>=1.2.18
dingtalk-stream>=0.22.1
docstring-parser>=0.16
@@ -19,7 +17,6 @@ filelock>=3.18.0
google-genai>=1.56.0
httpx[socks]>=0.28.1
lark-oapi>=1.4.15
lxml-html-clean>=0.4.2
mcp>=1.8.0
openai>=1.78.0
ormsgpack>=1.9.1
@@ -34,7 +31,6 @@ python-telegram-bot>=22.6
qq-botpy>=1.2.1
python-socks>=2.8.0
quart>=0.20.0
readability-lxml>=0.8.4.1
silk-python>=0.2.6
slack-sdk>=3.35.0
sqlalchemy[asyncio]>=2.0.41
@@ -56,4 +52,5 @@ tenacity>=9.1.2
shipyard-python-sdk>=0.2.4
shipyard-neo-sdk>=0.2.0
packaging>=24.2
qrcode>=8.2
qrcode>=8.2
python-ripgrep==0.0.9

View File

@@ -0,0 +1,283 @@
from __future__ import annotations
import base64
import io
import zipfile
from types import SimpleNamespace
from typing import Any
import pytest
from mcp.types import CallToolResult, ImageContent
from PIL import Image
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.computer import file_read_utils
from astrbot.core.computer.booters.local import LocalBooter
from astrbot.core.tools.computer_tools import fs as fs_tools
from astrbot.core.tools.computer_tools import util as computer_util
def _make_context(
*,
require_admin: bool = True,
role: str = "admin",
runtime: str = "local",
umo: str = "qq:friend:user-1",
) -> ContextWrapper:
config_holder = SimpleNamespace(
get_config=lambda umo=None: {
"provider_settings": {
"computer_use_require_admin": require_admin,
"computer_use_runtime": runtime,
}
}
)
event = SimpleNamespace(
role=role,
unified_msg_origin=umo,
get_sender_id=lambda: "user-1",
)
astr_ctx = SimpleNamespace(context=config_holder, event=event)
return ContextWrapper(context=astr_ctx)
def _setup_local_fs_tools(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
*,
umo: str = "qq:friend:user-1",
) -> Any:
workspaces_root = tmp_path / "workspaces"
skills_root = tmp_path / "skills"
temp_root = tmp_path / "temp"
workspaces_root.mkdir()
skills_root.mkdir()
temp_root.mkdir()
monkeypatch.setattr(
computer_util,
"get_astrbot_workspaces_path",
lambda: str(workspaces_root),
)
monkeypatch.setattr(
fs_tools,
"get_astrbot_skills_path",
lambda: str(skills_root),
)
monkeypatch.setattr(
fs_tools,
"get_astrbot_temp_path",
lambda: str(temp_root),
)
monkeypatch.setattr(
file_read_utils,
"get_astrbot_temp_path",
lambda: str(temp_root),
)
booter = LocalBooter()
async def _fake_get_booter(_ctx, _umo):
return booter
monkeypatch.setattr(fs_tools, "get_booter", _fake_get_booter)
normalized_umo = computer_util.normalize_umo_for_workspace(umo)
workspace = workspaces_root / normalized_umo
workspace.mkdir(parents=True, exist_ok=True)
return workspace
def _make_large_text() -> str:
return "".join(f"line-{index:05d}-{'x' * 48}\n" for index in range(6000))
def test_detect_text_encoding_allows_utf8_probe_cut_mid_character():
sample = '{"results": ["中文内容"]}'.encode()[:-1]
assert file_read_utils.detect_text_encoding(sample) in {"utf-8", "utf-8-sig"}
@pytest.mark.asyncio
async def test_file_read_tool_rejects_large_full_text_read_before_local_stream_read(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
large_file = workspace / "large.txt"
large_file.write_text(_make_large_text(), encoding="utf-8")
async def _unexpected_read(*args, **kwargs):
raise AssertionError("full file read should be rejected before streaming")
monkeypatch.setattr(file_read_utils, "read_local_text_range", _unexpected_read)
result = await fs_tools.FileReadTool().call(
_make_context(),
path="large.txt",
)
assert "text file exceeds 262144 bytes" in result
assert "Use `offset` and `limit`" in result
@pytest.mark.asyncio
async def test_file_read_tool_allows_partial_read_for_large_text_file(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
large_file = workspace / "large.txt"
lines = [f"line-{index:05d}\n" for index in range(50000)]
large_file.write_text("".join(lines), encoding="utf-8")
result = await fs_tools.FileReadTool().call(
_make_context(),
path="large.txt",
offset=1000,
limit=3,
)
assert result == "".join(lines[1000:1003])
@pytest.mark.asyncio
async def test_file_read_tool_returns_image_call_tool_result_for_images(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
image_path = workspace / "sample.png"
Image.new("RGB", (32, 16), color=(255, 0, 0)).save(image_path, format="PNG")
result = await fs_tools.FileReadTool().call(
_make_context(),
path="sample.png",
)
assert isinstance(result, CallToolResult)
assert len(result.content) == 1
assert isinstance(result.content[0], ImageContent)
assert result.content[0].mimeType == "image/jpeg"
assert base64.b64decode(result.content[0].data).startswith(b"\xff\xd8\xff")
@pytest.mark.asyncio
async def test_file_read_tool_treats_svg_as_text(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
svg_path = workspace / "shape.svg"
svg_text = (
"<svg xmlns='http://www.w3.org/2000/svg'><rect width='10' height='10'/></svg>"
)
svg_path.write_text(svg_text, encoding="utf-8")
result = await fs_tools.FileReadTool().call(
_make_context(),
path="shape.svg",
)
assert result == svg_text
@pytest.mark.asyncio
async def test_file_read_tool_reads_pdf_via_parser(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
pdf_path = workspace / "doc.pdf"
pdf_path.write_bytes(b"%PDF-1.7\n%\xe2\xe3\xcf\xd3\n1 0 obj\n<<>>\nendobj\n")
async def _fake_parse_pdf(_file_bytes: bytes, _file_name: str) -> str:
return "page-1\npage-2\n"
monkeypatch.setattr(file_read_utils, "_parse_local_pdf_text", _fake_parse_pdf)
result = await fs_tools.FileReadTool().call(
_make_context(),
path="doc.pdf",
)
assert result == "page-1\npage-2\n"
@pytest.mark.asyncio
async def test_file_read_tool_reads_docx_via_parser_and_magic(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
docx_path = workspace / "report.bin"
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, mode="w") as archive:
archive.writestr("[Content_Types].xml", "<Types/>")
archive.writestr("word/document.xml", "<w:document/>")
docx_path.write_bytes(buffer.getvalue())
async def _fake_parse_docx(_file_bytes: bytes, _file_name: str) -> str:
return "doc-line-1\ndoc-line-2\n"
monkeypatch.setattr(file_read_utils, "_parse_local_docx_text", _fake_parse_docx)
result = await fs_tools.FileReadTool().call(
_make_context(),
path="report.bin",
)
assert result == "doc-line-1\ndoc-line-2\n"
@pytest.mark.asyncio
async def test_file_read_tool_stores_long_converted_document_in_workspace(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
pdf_path = workspace / "manual.pdf"
pdf_path.write_bytes(b"%PDF-1.7\nfake\n")
long_text = _make_large_text()
async def _fake_parse_pdf(_file_bytes: bytes, _file_name: str) -> str:
return long_text
monkeypatch.setattr(file_read_utils, "_parse_local_pdf_text", _fake_parse_pdf)
result = await fs_tools.FileReadTool().call(
_make_context(),
path="manual.pdf",
)
converted_root = workspace / "converted_files"
converted_files = list(converted_root.glob("manual.pdf_*/text.txt"))
assert len(converted_files) == 1
assert converted_files[0].read_text(encoding="utf-8") == long_text
assert str(converted_files[0]) in result
assert "Read or grep that file with a narrow window." in result
@pytest.mark.asyncio
async def test_grep_tool_applies_result_limit(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
text_path = workspace / "grep.txt"
text_path.write_text(
"match-1\nmatch-2\nmatch-3\nmatch-4\n",
encoding="utf-8",
)
result = await fs_tools.GrepTool().call(
_make_context(),
pattern="match",
path="grep.txt",
result_limit=2,
)
assert "match-1" in result
assert "match-2" in result
assert "match-3" not in result
assert "[Truncated to first 2 result groups.]" in result

View File

@@ -4,8 +4,10 @@ from types import SimpleNamespace
import pytest
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.computer.tools.browser import BrowserExecTool
from astrbot.core.computer.tools.neo_skills import GetExecutionHistoryTool
from astrbot.core.tools.computer_tools.shipyard_neo.browser import BrowserExecTool
from astrbot.core.tools.computer_tools.shipyard_neo.neo_skills import (
GetExecutionHistoryTool,
)
class _FakeBrowser:
@@ -49,7 +51,7 @@ async def test_browser_tool_allows_non_admin_when_admin_requirement_disabled(
return SimpleNamespace(browser=_FakeBrowser())
monkeypatch.setattr(
"astrbot.core.computer.tools.browser.get_booter",
"astrbot.core.tools.computer_tools.shipyard_neo.browser.get_booter",
_fake_get_booter,
)
@@ -72,7 +74,7 @@ async def test_neo_skill_tool_allows_non_admin_when_admin_requirement_disabled(
)
monkeypatch.setattr(
"astrbot.core.computer.tools.neo_skills.get_booter",
"astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.get_booter",
_fake_get_booter,
)

View File

@@ -9,8 +9,6 @@ from astrbot.core.computer.booters.local import LocalFileSystemComponent
def _allow_tmp_root(monkeypatch, tmp_path: Path) -> None:
monkeypatch.setattr(local_booter, "get_astrbot_root", lambda: str(tmp_path))
monkeypatch.setattr(local_booter, "get_astrbot_data_path", lambda: str(tmp_path))
monkeypatch.setattr(local_booter, "get_astrbot_temp_path", lambda: str(tmp_path))
def test_local_file_system_component_prefers_utf8_before_windows_locale(
@@ -27,7 +25,7 @@ def test_local_file_system_component_prefers_utf8_before_windows_locale(
skill_path = tmp_path / "skills" / "demo.txt"
skill_path.parent.mkdir(parents=True, exist_ok=True)
skill_path.write_bytes("技能内容".encode("utf-8"))
skill_path.write_bytes("技能内容".encode())
result = asyncio.run(LocalFileSystemComponent().read_file(str(skill_path)))

View File

@@ -4,7 +4,9 @@ import asyncio
from types import SimpleNamespace
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.computer.tools.neo_skills import PromoteSkillCandidateTool
from astrbot.core.tools.computer_tools.shipyard_neo.neo_skills import (
PromoteSkillCandidateTool,
)
class _FakeSkills:
@@ -46,11 +48,11 @@ def test_promote_stable_sync_failure_auto_rolls_back(monkeypatch):
raise ValueError("sync failed")
monkeypatch.setattr(
"astrbot.core.computer.tools.neo_skills.get_booter",
"astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.get_booter",
_fake_get_booter,
)
monkeypatch.setattr(
"astrbot.core.computer.tools.neo_skills.NeoSkillSyncManager.sync_release",
"astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.NeoSkillSyncManager.sync_release",
_fake_sync_release,
)

View File

@@ -1,6 +1,7 @@
import asyncio
import os
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import AsyncMock
@@ -97,6 +98,26 @@ class MockToolExecutor:
return generator()
class LargeTextToolExecutor:
"""模拟返回超长文本的工具执行器"""
def __init__(self, text: str):
self.text = text
@classmethod
def from_text(cls, text: str) -> "LargeTextToolExecutor":
return cls(text)
def execute(self, tool, run_context, **tool_args):
async def generator():
from mcp.types import CallToolResult, TextContent
result = CallToolResult(content=[TextContent(type="text", text=self.text)])
yield result
return generator()
class MockMixedContentToolExecutor:
"""模拟返回图片 + 文本的工具执行器"""
@@ -193,6 +214,32 @@ class MockToolCallProvider(MockProvider):
)
class SingleToolThenFinalProvider(MockProvider):
def __init__(self, tool_name: str, tool_args: dict[str, str] | None = None):
super().__init__()
self.tool_name = tool_name
self.tool_args = tool_args or {}
async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
func_tool = kwargs.get("func_tool")
if func_tool is None or self.call_count > 1:
return LLMResponse(
role="assistant",
completion_text="最终回复",
usage=TokenUsage(input_other=10, output=5),
)
return LLMResponse(
role="assistant",
completion_text="",
tools_call_name=[self.tool_name],
tools_call_args=[self.tool_args],
tools_call_ids=["call_large_result"],
usage=TokenUsage(input_other=10, output=5),
)
class SequentialToolProvider(MockProvider):
def __init__(self, tool_sequence: list[str]):
super().__init__()
@@ -334,6 +381,10 @@ def runner():
return ToolLoopAgentRunner()
def _make_large_tool_result_text() -> str:
return "x" * 100000
@pytest.mark.asyncio
async def test_max_step_limit_functionality(
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
@@ -1124,18 +1175,116 @@ async def test_follow_up_accepted_when_active_and_not_stopping(
streaming=False,
)
# Runner is active (not done) and stop is not requested
assert not runner.done()
assert runner._is_stop_requested() is False
ticket = runner.follow_up(message_text="valid follow-up message")
assert ticket is not None, (
"Follow-up should be accepted when runner is active and not stopping"
@pytest.mark.asyncio
async def test_large_tool_result_is_spilled_to_file_and_replaced_with_read_notice(
tmp_path,
):
tool = FunctionTool(
name="test_tool",
description="测试工具",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
assert ticket.text == "valid follow-up message"
assert ticket.consumed is False
assert ticket in runner._pending_follow_ups
read_tool = FunctionTool(
name="astrbot_file_read_tool",
description="read file",
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
handler=AsyncMock(),
)
tool_set = ToolSet(tools=[tool, read_tool])
provider = SingleToolThenFinalProvider(tool.name, {"query": "large"})
request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[])
runner = ToolLoopAgentRunner()
await runner.reset(
provider=provider,
request=request,
run_context=ContextWrapper(context=None),
tool_executor=cast(
Any,
LargeTextToolExecutor.from_text(_make_large_tool_result_text()),
),
agent_hooks=MockHooks(),
streaming=False,
tool_result_overflow_dir=str(tmp_path),
read_tool=read_tool,
)
responses = []
async for response in runner.step_until_done(3):
responses.append(response)
tool_messages = [m for m in runner.run_context.messages if m.role == "tool"]
assert len(tool_messages) == 1
tool_message_content = str(tool_messages[0].content)
assert "xxxxxxxxxx" in tool_message_content
assert "Truncated tool output preview shown above." in tool_message_content
assert "The tool output was too large to include directly" in tool_message_content
assert "`astrbot_file_read_tool`" in tool_message_content
assert "Use `astrbot_file_read_tool` to inspect it." in tool_message_content
overflow_files = list(Path(tmp_path).glob("call_large_result_*.txt"))
assert len(overflow_files) == 1
assert (
overflow_files[0].read_text(encoding="utf-8") == _make_large_tool_result_text()
)
assert str(overflow_files[0]) in tool_message_content
llm_results = [resp for resp in responses if resp.type == "llm_result"]
assert llm_results
@pytest.mark.asyncio
async def test_large_tool_result_keeps_preview_when_spill_fails(
tmp_path,
monkeypatch: pytest.MonkeyPatch,
):
tool = FunctionTool(
name="test_tool",
description="测试工具",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
read_tool = FunctionTool(
name="astrbot_file_read_tool",
description="read file",
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
handler=AsyncMock(),
)
tool_set = ToolSet(tools=[tool, read_tool])
provider = SingleToolThenFinalProvider(tool.name, {"query": "large"})
request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[])
runner = ToolLoopAgentRunner()
async def _raise_spill_error(*, tool_call_id: str, content: str) -> str:
raise OSError("disk full")
monkeypatch.setattr(runner, "_write_tool_result_overflow_file", _raise_spill_error)
await runner.reset(
provider=provider,
request=request,
run_context=ContextWrapper(context=None),
tool_executor=cast(
Any,
LargeTextToolExecutor.from_text(_make_large_tool_result_text()),
),
agent_hooks=MockHooks(),
streaming=False,
tool_result_overflow_dir=str(tmp_path),
read_tool=read_tool,
)
async for _ in runner.step_until_done(3):
pass
tool_messages = [m for m in runner.run_context.messages if m.role == "tool"]
assert len(tool_messages) == 1
tool_message_content = str(tool_messages[0].content)
assert "xxxxxxxxxx" in tool_message_content
assert "Tool output exceeded the inline result limit" in tool_message_content
assert "disk full" in tool_message_content
@pytest.mark.asyncio

View File

@@ -40,7 +40,9 @@ def mock_context():
return_value=(None, None, None, False)
)
ctx.persona_manager.get_persona_v3_by_id = MagicMock(return_value=None)
ctx.get_llm_tool_manager.return_value = MagicMock()
tool_mgr = MagicMock()
tool_mgr.get_builtin_tool.side_effect = lambda cls, **kwargs: cls(**kwargs)
ctx.get_llm_tool_manager.return_value = tool_mgr
ctx.subagent_orchestrator = None
return ctx
@@ -1479,7 +1481,7 @@ class TestApplyLlmSafetyMode:
class TestApplySandboxTools:
"""Tests for _apply_sandbox_tools function."""
def test_apply_sandbox_tools_creates_toolset_if_none(self):
def test_apply_sandbox_tools_creates_toolset_if_none(self, mock_context):
"""Test that ToolSet is created when func_tool is None."""
module = ama
config = module.MainAgentBuildConfig(
@@ -1494,7 +1496,7 @@ class TestApplySandboxTools:
assert req.func_tool is not None
assert isinstance(req.func_tool, ToolSet)
def test_apply_sandbox_tools_adds_required_tools(self):
def test_apply_sandbox_tools_adds_required_tools(self, mock_context):
"""Test that all required sandbox tools are added."""
module = ama
config = module.MainAgentBuildConfig(
@@ -1512,7 +1514,7 @@ class TestApplySandboxTools:
assert "astrbot_upload_file" in tool_names
assert "astrbot_download_file" in tool_names
def test_apply_sandbox_tools_adds_sandbox_prompt(self):
def test_apply_sandbox_tools_adds_sandbox_prompt(self, mock_context):
"""Test that sandbox mode prompt is added to system_prompt."""
module = ama
config = module.MainAgentBuildConfig(
@@ -1526,7 +1528,7 @@ class TestApplySandboxTools:
assert "sandboxed environment" in req.system_prompt
def test_apply_sandbox_tools_with_shipyard_booter(self, monkeypatch):
def test_apply_sandbox_tools_with_shipyard_booter(self, monkeypatch, mock_context):
"""Test sandbox tools with shipyard booter configuration."""
module = ama
config = module.MainAgentBuildConfig(
@@ -1548,7 +1550,7 @@ class TestApplySandboxTools:
assert os.environ.get("SHIPYARD_ENDPOINT") == "https://shipyard.example.com"
assert os.environ.get("SHIPYARD_ACCESS_TOKEN") == "test-token"
def test_apply_sandbox_tools_shipyard_missing_endpoint(self):
def test_apply_sandbox_tools_shipyard_missing_endpoint(self, mock_context):
"""Test that shipyard config is skipped when endpoint is missing."""
module = ama
config = module.MainAgentBuildConfig(
@@ -1571,7 +1573,7 @@ class TestApplySandboxTools:
in mock_logger.error.call_args[0][0]
)
def test_apply_sandbox_tools_shipyard_missing_access_token(self):
def test_apply_sandbox_tools_shipyard_missing_access_token(self, mock_context):
"""Test that shipyard config is skipped when access token is missing."""
module = ama
config = module.MainAgentBuildConfig(
@@ -1590,7 +1592,7 @@ class TestApplySandboxTools:
mock_logger.error.assert_called_once()
def test_apply_sandbox_tools_preserves_existing_toolset(self):
def test_apply_sandbox_tools_preserves_existing_toolset(self, mock_context):
"""Test that existing tools are preserved when adding sandbox tools."""
module = ama
config = module.MainAgentBuildConfig(
@@ -1609,7 +1611,7 @@ class TestApplySandboxTools:
assert "existing_tool" in req.func_tool.names()
assert "astrbot_execute_shell" in req.func_tool.names()
def test_apply_sandbox_tools_appends_to_existing_system_prompt(self):
def test_apply_sandbox_tools_appends_to_existing_system_prompt(self, mock_context):
"""Test that sandbox prompt is appended to existing system prompt."""
module = ama
config = module.MainAgentBuildConfig(
@@ -1624,7 +1626,7 @@ class TestApplySandboxTools:
assert req.system_prompt.startswith("Base prompt")
assert "sandboxed environment" in req.system_prompt
def test_apply_sandbox_tools_with_none_system_prompt(self):
def test_apply_sandbox_tools_with_none_system_prompt(self, mock_context):
"""Test that sandbox prompt is applied when system_prompt is None."""
module = ama
config = module.MainAgentBuildConfig(

View File

@@ -15,7 +15,6 @@ from astrbot.core.computer.booters.local import (
LocalFileSystemComponent,
LocalPythonComponent,
LocalShellComponent,
_ensure_safe_path,
_is_safe_command,
)
@@ -126,51 +125,6 @@ class TestSecurityRestrictions:
for cmd in blocked_commands:
assert _is_safe_command(cmd) is False, f"Command '{cmd}' should be blocked"
def test_ensure_safe_path_allowed(self, tmp_path):
"""Test paths within allowed roots are accepted."""
# Create a test directory structure
test_file = tmp_path / "test.txt"
test_file.write_text("test")
# Mock get_astrbot_root, get_astrbot_data_path, get_astrbot_temp_path
with (
patch(
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
result = _ensure_safe_path(str(test_file))
assert result == str(test_file)
def test_ensure_safe_path_blocked(self, tmp_path):
"""Test paths outside allowed roots raise PermissionError."""
with (
patch(
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
# Try to access a path outside the allowed roots
with pytest.raises(PermissionError) as exc_info:
_ensure_safe_path("/etc/passwd")
assert "Path is outside the allowed computer roots" in str(exc_info.value)
class TestLocalShellComponent:
"""Tests for LocalShellComponent."""
@@ -212,14 +166,6 @@ class TestLocalShellComponent:
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
# Use python to read file to avoid Windows vs Unix command differences
result = await shell.exec(
@@ -294,14 +240,6 @@ class TestLocalFileSystemComponent:
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
result = await fs.create_file(str(test_path), "test content")
assert result["success"] is True
@@ -320,14 +258,6 @@ class TestLocalFileSystemComponent:
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
result = await fs.read_file(str(test_path))
assert result["success"] is True
@@ -344,14 +274,6 @@ class TestLocalFileSystemComponent:
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
result = await fs.write_file(str(test_path), "new content")
assert result["success"] is True
@@ -369,14 +291,6 @@ class TestLocalFileSystemComponent:
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
result = await fs.delete_file(str(test_path))
assert result["success"] is True
@@ -395,14 +309,6 @@ class TestLocalFileSystemComponent:
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
result = await fs.delete_file(str(test_dir))
assert result["success"] is True
@@ -422,14 +328,6 @@ class TestLocalFileSystemComponent:
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
# Without hidden files
result = await fs.list_dir(str(tmp_path), show_hidden=False)
@@ -452,14 +350,6 @@ class TestLocalFileSystemComponent:
"astrbot.core.computer.booters.local.get_astrbot_root",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_data_path",
return_value=str(tmp_path),
),
patch(
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
return_value=str(tmp_path),
),
):
# Should raise FileNotFoundError
with pytest.raises(FileNotFoundError):
@@ -633,7 +523,7 @@ class TestComputerClient:
"shipyard_access_token": "test_token",
"shipyard_ttl": 3600,
"shipyard_max_sessions": 10,
}
},
}
}.get(key, default)
mock_context.get_config = MagicMock(return_value=mock_config)
@@ -681,7 +571,7 @@ class TestComputerClient:
"computer_use_runtime": "sandbox",
"sandbox": {
"booter": "unknown_type",
}
},
}
}.get(key, default)
mock_context.get_config = MagicMock(return_value=mock_config)
@@ -707,7 +597,7 @@ class TestComputerClient:
"booter": "shipyard",
"shipyard_endpoint": "http://localhost:8080",
"shipyard_access_token": "test_token",
}
},
}
}.get(key, default)
mock_context.get_config = MagicMock(return_value=mock_config)
@@ -752,7 +642,7 @@ class TestComputerClient:
"booter": "shipyard",
"shipyard_endpoint": "http://localhost:8080",
"shipyard_access_token": "test_token",
}
},
}
}.get(key, default)
mock_context.get_config = MagicMock(return_value=mock_config)

View File

@@ -1,5 +1,6 @@
from astrbot.core import sp
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from astrbot.core.tools.computer_tools.shell import ExecuteShellTool
from astrbot.core.tools.message_tools import SendMessageToUserTool
@@ -12,3 +13,28 @@ def test_get_builtin_tool_by_class_returns_cached_instance():
assert tool_by_class is tool_by_name
assert manager.get_func("send_message_to_user") is tool_by_class
assert tool_by_class.name == "send_message_to_user"
def test_builtin_tool_ignores_inactivated_llm_tools():
manager = FunctionToolManager()
sp.put(
"inactivated_llm_tools",
["send_message_to_user"],
scope="global",
scope_id="global",
)
try:
tool = manager.get_builtin_tool(SendMessageToUserTool)
assert tool.active is True
finally:
sp.put("inactivated_llm_tools", [], scope="global", scope_id="global")
def test_computer_tools_are_registered_as_builtin_tools():
manager = FunctionToolManager()
tool = manager.get_builtin_tool(ExecuteShellTool)
assert tool.name == "astrbot_execute_shell"
assert manager.is_builtin_tool("astrbot_execute_shell") is True

View File

@@ -1,5 +1,7 @@
import platform
from astrbot.core.computer.tools.python import PythonTool, LocalPythonTool
from astrbot.core.tools.computer_tools.python import LocalPythonTool, PythonTool
def test_python_tool_description_contains_os():
"""测试 PythonTool 的描述中是否包含当前操作系统信息"""
@@ -8,6 +10,7 @@ def test_python_tool_description_contains_os():
assert current_os in tool.description
assert "IPython" in tool.description
def test_local_python_tool_description_contains_os():
"""测试 LocalPythonTool 的描述中是否包含当前操作系统信息和兼容性提示"""
tool = LocalPythonTool()