mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-02 02:30:16 +08:00
Compare commits
19 Commits
codex/rest
...
chore/remo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1fe0ed1f23 | ||
|
|
1745e9c4fb | ||
|
|
3acda6f77a | ||
|
|
cff148860a | ||
|
|
013ecacee9 | ||
|
|
7bf1d19332 | ||
|
|
5f049f2bb5 | ||
|
|
add5db6748 | ||
|
|
5ca2483a43 | ||
|
|
adc01e0c9d | ||
|
|
efc93a37b1 | ||
|
|
56a099bf90 | ||
|
|
006aedbd24 | ||
|
|
86ac40d944 | ||
|
|
20fed8ab62 | ||
|
|
a539deec91 | ||
|
|
11282c769f | ||
|
|
8e7d995fec | ||
|
|
fcf1b08455 |
@@ -16,6 +16,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
gnupg \
|
||||
git \
|
||||
ripgrep \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
|
||||
&& apt-get install -y --no-install-recommends nodejs \
|
||||
&& apt-get clean \
|
||||
|
||||
@@ -4,9 +4,11 @@ import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from mcp.types import (
|
||||
BlobResourceContents,
|
||||
@@ -25,7 +27,7 @@ from tenacity import (
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import ImageURLPart, TextPart, ThinkPart
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.agent.tool_image_cache import tool_image_cache
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
from astrbot.core.message.components import Json
|
||||
@@ -45,7 +47,7 @@ from astrbot.core.provider.provider import Provider
|
||||
from ..context.compressor import ContextCompressor
|
||||
from ..context.config import ContextConfig
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.token_counter import TokenCounter
|
||||
from ..context.token_counter import EstimateTokenCounter, TokenCounter
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
@@ -97,6 +99,8 @@ ToolExecutorResultT = T.TypeVar("ToolExecutorResultT")
|
||||
|
||||
|
||||
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
TOOL_RESULT_MAX_ESTIMATED_TOKENS = 27_500
|
||||
TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS = 7000
|
||||
EMPTY_OUTPUT_RETRY_ATTEMPTS = 3
|
||||
EMPTY_OUTPUT_RETRY_WAIT_MIN_S = 1
|
||||
EMPTY_OUTPUT_RETRY_WAIT_MAX_S = 4
|
||||
@@ -151,6 +155,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
"Otherwise, change strategy, adjust arguments, or explain the limitation "
|
||||
"to the user."
|
||||
)
|
||||
TOOL_RESULT_OVERFLOW_NOTICE_TEMPLATE = (
|
||||
"Truncated tool output preview shown above. "
|
||||
"The tool output was too large to include directly and was written to "
|
||||
"`{overflow_path}`. Use {read_tool_hint} with a narrower window to inspect it."
|
||||
)
|
||||
|
||||
def _get_persona_custom_error_message(self) -> str | None:
|
||||
"""Read persona-level custom error message from event extras when available."""
|
||||
@@ -206,6 +215,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
tool_schema_mode: str | None = "full",
|
||||
fallback_providers: list[Provider] | None = None,
|
||||
tool_result_overflow_dir: str | None = None,
|
||||
read_tool: FunctionTool | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
@@ -217,6 +228,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.truncate_turns = truncate_turns
|
||||
self.custom_token_counter = custom_token_counter
|
||||
self.custom_compressor = custom_compressor
|
||||
self.tool_result_overflow_dir = tool_result_overflow_dir
|
||||
self.read_tool = read_tool
|
||||
self._tool_result_token_counter = EstimateTokenCounter()
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
@@ -298,6 +312,103 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.stats = AgentStats()
|
||||
self.stats.start_time = time.time()
|
||||
|
||||
def _read_tool_hint(self) -> str:
|
||||
if self.read_tool is not None:
|
||||
return f"`{self.read_tool.name}`"
|
||||
return "the available file-read tool"
|
||||
|
||||
async def _write_tool_result_overflow_file(
|
||||
self,
|
||||
*,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
) -> str:
|
||||
if self.tool_result_overflow_dir is None:
|
||||
raise ValueError("tool_result_overflow_dir is not configured")
|
||||
|
||||
overflow_dir = Path(self.tool_result_overflow_dir).resolve(strict=False)
|
||||
safe_tool_call_id = (
|
||||
"".join(
|
||||
ch if ch.isalnum() or ch in {"-", "_", "."} else "_"
|
||||
for ch in tool_call_id
|
||||
).strip("._")
|
||||
or "tool_call"
|
||||
)
|
||||
file_name = f"{safe_tool_call_id}_{uuid.uuid4().hex[:8]}.txt"
|
||||
overflow_path = overflow_dir / file_name
|
||||
|
||||
def _run() -> str:
|
||||
overflow_dir.mkdir(parents=True, exist_ok=True)
|
||||
overflow_path.write_text(content, encoding="utf-8")
|
||||
return str(overflow_path)
|
||||
|
||||
return await asyncio.to_thread(_run)
|
||||
|
||||
async def _materialize_large_tool_result(
|
||||
self,
|
||||
*,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
) -> str:
|
||||
if self.tool_result_overflow_dir is None or self.read_tool is None:
|
||||
return content
|
||||
|
||||
estimated_tokens = self._tool_result_token_counter.count_tokens(
|
||||
[Message(role="tool", content=content, tool_call_id=tool_call_id)]
|
||||
)
|
||||
if estimated_tokens <= self.TOOL_RESULT_MAX_ESTIMATED_TOKENS:
|
||||
return content
|
||||
|
||||
preview = self._truncate_tool_result_preview(content, tool_call_id=tool_call_id)
|
||||
try:
|
||||
overflow_path = await self._write_tool_result_overflow_file(
|
||||
tool_call_id=tool_call_id,
|
||||
content=content,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to spill oversized tool result for %s: %s",
|
||||
tool_call_id,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
error_notice = (
|
||||
"Tool output exceeded the inline result limit "
|
||||
f"({estimated_tokens} estimated tokens > "
|
||||
f"{self.TOOL_RESULT_MAX_ESTIMATED_TOKENS}) and could not be written "
|
||||
f"to `{self.tool_result_overflow_dir}`: {exc}"
|
||||
)
|
||||
if not preview:
|
||||
return error_notice
|
||||
return f"{preview}\n\n{error_notice}"
|
||||
|
||||
notice = self.TOOL_RESULT_OVERFLOW_NOTICE_TEMPLATE.format(
|
||||
overflow_path=overflow_path,
|
||||
read_tool_hint=self._read_tool_hint(),
|
||||
)
|
||||
if not preview:
|
||||
return notice
|
||||
return f"{preview}\n\n{notice}"
|
||||
|
||||
def _truncate_tool_result_preview(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
tool_call_id: str,
|
||||
) -> str:
|
||||
preview = content
|
||||
while preview:
|
||||
estimated_tokens = self._tool_result_token_counter.count_tokens(
|
||||
[Message(role="tool", content=preview, tool_call_id=tool_call_id)]
|
||||
)
|
||||
if estimated_tokens <= self.TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS:
|
||||
return preview
|
||||
next_len = len(preview) // 2
|
||||
if next_len <= 0:
|
||||
break
|
||||
preview = preview[:next_len]
|
||||
return preview
|
||||
|
||||
async def _iter_llm_responses(
|
||||
self, *, include_model: bool = True
|
||||
) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
@@ -933,9 +1044,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
"The tool has returned a data type that is not supported."
|
||||
)
|
||||
if result_parts:
|
||||
inline_result = "\n\n".join(result_parts)
|
||||
inline_result = await self._materialize_large_tool_result(
|
||||
tool_call_id=func_tool_id,
|
||||
content=inline_result,
|
||||
)
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
"\n\n".join(result_parts)
|
||||
inline_result
|
||||
+ self._build_repeated_tool_call_guidance(
|
||||
func_tool_name, tool_call_streak
|
||||
),
|
||||
|
||||
@@ -19,12 +19,6 @@ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
FILE_UPLOAD_TOOL,
|
||||
LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL,
|
||||
PYTHON_TOOL,
|
||||
)
|
||||
from astrbot.core.cron.events import CronMessageEvent
|
||||
from astrbot.core.message.components import Image
|
||||
@@ -36,6 +30,17 @@ from astrbot.core.message.message_event_result import (
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.tools.computer_tools import (
|
||||
ExecuteShellTool,
|
||||
FileDownloadTool,
|
||||
FileEditTool,
|
||||
FileReadTool,
|
||||
FileUploadTool,
|
||||
FileWriteTool,
|
||||
GrepTool,
|
||||
LocalPythonTool,
|
||||
PythonTool,
|
||||
)
|
||||
from astrbot.core.tools.message_tools import SendMessageToUserTool
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.history_saver import persist_agent_history
|
||||
@@ -177,18 +182,44 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
|
||||
def _get_runtime_computer_tools(
|
||||
cls,
|
||||
runtime: str,
|
||||
tool_mgr,
|
||||
) -> dict[str, FunctionTool]:
|
||||
if runtime == "sandbox":
|
||||
shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool)
|
||||
python_tool = tool_mgr.get_builtin_tool(PythonTool)
|
||||
upload_tool = tool_mgr.get_builtin_tool(FileUploadTool)
|
||||
download_tool = tool_mgr.get_builtin_tool(FileDownloadTool)
|
||||
read_tool = tool_mgr.get_builtin_tool(FileReadTool)
|
||||
write_tool = tool_mgr.get_builtin_tool(FileWriteTool)
|
||||
edit_tool = tool_mgr.get_builtin_tool(FileEditTool)
|
||||
grep_tool = tool_mgr.get_builtin_tool(GrepTool)
|
||||
return {
|
||||
EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL,
|
||||
PYTHON_TOOL.name: PYTHON_TOOL,
|
||||
FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL,
|
||||
FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL,
|
||||
shell_tool.name: shell_tool,
|
||||
python_tool.name: python_tool,
|
||||
upload_tool.name: upload_tool,
|
||||
download_tool.name: download_tool,
|
||||
read_tool.name: read_tool,
|
||||
write_tool.name: write_tool,
|
||||
edit_tool.name: edit_tool,
|
||||
grep_tool.name: grep_tool,
|
||||
}
|
||||
if runtime == "local":
|
||||
shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool)
|
||||
python_tool = tool_mgr.get_builtin_tool(LocalPythonTool)
|
||||
read_tool = tool_mgr.get_builtin_tool(FileReadTool)
|
||||
write_tool = tool_mgr.get_builtin_tool(FileWriteTool)
|
||||
edit_tool = tool_mgr.get_builtin_tool(FileEditTool)
|
||||
grep_tool = tool_mgr.get_builtin_tool(GrepTool)
|
||||
return {
|
||||
LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL,
|
||||
shell_tool.name: shell_tool,
|
||||
python_tool.name: python_tool,
|
||||
read_tool.name: read_tool,
|
||||
write_tool.name: write_tool,
|
||||
edit_tool.name: edit_tool,
|
||||
grep_tool.name: grep_tool,
|
||||
}
|
||||
return {}
|
||||
|
||||
@@ -203,7 +234,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
cfg = ctx.get_config(umo=event.unified_msg_origin)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
runtime = str(provider_settings.get("computer_use_runtime", "local"))
|
||||
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
|
||||
tool_mgr = (
|
||||
ctx.get_llm_tool_manager()
|
||||
if hasattr(ctx, "get_llm_tool_manager")
|
||||
else llm_tools
|
||||
)
|
||||
runtime_computer_tools = cls._get_runtime_computer_tools(
|
||||
runtime,
|
||||
tool_mgr,
|
||||
)
|
||||
|
||||
# Keep persona semantics aligned with the main agent: tools=None means
|
||||
# "all tools", including runtime computer-use tools.
|
||||
|
||||
@@ -9,6 +9,7 @@ import platform
|
||||
import zoneinfo
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
@@ -20,30 +21,10 @@ from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from astrbot.core.astr_agent_run_util import AgentRunner
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
ANNOTATE_EXECUTION_TOOL,
|
||||
BROWSER_BATCH_EXEC_TOOL,
|
||||
BROWSER_EXEC_TOOL,
|
||||
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
|
||||
CREATE_SKILL_CANDIDATE_TOOL,
|
||||
CREATE_SKILL_PAYLOAD_TOOL,
|
||||
EVALUATE_SKILL_CANDIDATE_TOOL,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
FILE_UPLOAD_TOOL,
|
||||
GET_EXECUTION_HISTORY_TOOL,
|
||||
GET_SKILL_PAYLOAD_TOOL,
|
||||
LIST_SKILL_CANDIDATES_TOOL,
|
||||
LIST_SKILL_RELEASES_TOOL,
|
||||
LIVE_MODE_SYSTEM_PROMPT,
|
||||
LLM_SAFETY_MODE_SYSTEM_PROMPT,
|
||||
LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL,
|
||||
PROMOTE_SKILL_CANDIDATE_TOOL,
|
||||
PYTHON_TOOL,
|
||||
ROLLBACK_SKILL_RELEASE_TOOL,
|
||||
RUN_BROWSER_SKILL_TOOL,
|
||||
SANDBOX_MODE_PROMPT,
|
||||
SYNC_SKILL_RELEASE_TOOL,
|
||||
TOOL_CALL_PROMPT,
|
||||
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
|
||||
)
|
||||
@@ -56,9 +37,36 @@ from astrbot.core.persona_error_reply import (
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.tools.computer_tools import (
|
||||
AnnotateExecutionTool,
|
||||
BrowserBatchExecTool,
|
||||
BrowserExecTool,
|
||||
CreateSkillCandidateTool,
|
||||
CreateSkillPayloadTool,
|
||||
EvaluateSkillCandidateTool,
|
||||
ExecuteShellTool,
|
||||
FileDownloadTool,
|
||||
FileEditTool,
|
||||
FileReadTool,
|
||||
FileUploadTool,
|
||||
FileWriteTool,
|
||||
GetExecutionHistoryTool,
|
||||
GetSkillPayloadTool,
|
||||
GrepTool,
|
||||
ListSkillCandidatesTool,
|
||||
ListSkillReleasesTool,
|
||||
LocalPythonTool,
|
||||
PromoteSkillCandidateTool,
|
||||
PythonTool,
|
||||
RollbackSkillReleaseTool,
|
||||
RunBrowserSkillTool,
|
||||
SyncSkillReleaseTool,
|
||||
normalize_umo_for_workspace,
|
||||
)
|
||||
from astrbot.core.tools.cron_tools import FutureTaskTool
|
||||
from astrbot.core.tools.knowledge_base_tools import (
|
||||
KnowledgeBaseQueryTool,
|
||||
@@ -73,6 +81,10 @@ from astrbot.core.tools.web_search_tools import (
|
||||
TavilyWebSearchTool,
|
||||
normalize_legacy_web_search_config,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_system_tmp_path,
|
||||
get_astrbot_workspaces_path,
|
||||
)
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.media_utils import (
|
||||
@@ -290,11 +302,54 @@ def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None:
|
||||
req.prompt = f"{prefix}{req.prompt}"
|
||||
|
||||
|
||||
def _apply_local_env_tools(req: ProviderRequest) -> None:
|
||||
def _get_workspace_path_for_umo(umo: str) -> Path:
|
||||
normalized_umo = normalize_umo_for_workspace(umo)
|
||||
return Path(get_astrbot_workspaces_path()) / normalized_umo
|
||||
|
||||
|
||||
def _apply_workspace_extra_prompt(
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
) -> None:
|
||||
extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / (
|
||||
"EXTRA_PROMPT.md"
|
||||
)
|
||||
if not extra_prompt_path.is_file():
|
||||
return
|
||||
|
||||
try:
|
||||
extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Failed to read workspace extra prompt for umo=%s from %s: %s",
|
||||
event.unified_msg_origin,
|
||||
extra_prompt_path,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
if not extra_prompt:
|
||||
return
|
||||
|
||||
req.system_prompt = (
|
||||
f"{req.system_prompt or ''}\n"
|
||||
"[Workspace Extra Prompt]\n"
|
||||
"The following instructions are loaded from the current workspace "
|
||||
"`EXTRA_PROMPT.md` file.\n"
|
||||
f"{extra_prompt}\n"
|
||||
)
|
||||
|
||||
|
||||
def _apply_local_env_tools(req: ProviderRequest, plugin_context: Context) -> None:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(LOCAL_EXECUTE_SHELL_TOOL)
|
||||
req.func_tool.add_tool(LOCAL_PYTHON_TOOL)
|
||||
tool_mgr = plugin_context.get_llm_tool_manager()
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(LocalPythonTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool))
|
||||
req.system_prompt = f"{req.system_prompt or ''}\n{_build_local_mode_prompt()}\n"
|
||||
|
||||
|
||||
@@ -765,6 +820,7 @@ async def _decorate_llm_request(
|
||||
if tz is None:
|
||||
tz = plugin_context.get_config().get("timezone")
|
||||
_append_system_reminders(event, req, cfg, tz)
|
||||
_apply_workspace_extra_prompt(event, req)
|
||||
|
||||
|
||||
def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
|
||||
@@ -981,7 +1037,9 @@ def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) -
|
||||
|
||||
|
||||
def _apply_sandbox_tools(
|
||||
config: MainAgentBuildConfig, req: ProviderRequest, session_id: str
|
||||
config: MainAgentBuildConfig,
|
||||
req: ProviderRequest,
|
||||
session_id: str,
|
||||
) -> None:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
@@ -997,10 +1055,15 @@ def _apply_sandbox_tools(
|
||||
os.environ["SHIPYARD_ENDPOINT"] = ep
|
||||
os.environ["SHIPYARD_ACCESS_TOKEN"] = at
|
||||
|
||||
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
|
||||
req.func_tool.add_tool(PYTHON_TOOL)
|
||||
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
|
||||
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
|
||||
tool_mgr = llm_tools
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(PythonTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileUploadTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileDownloadTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool))
|
||||
if booter == "shipyard_neo":
|
||||
# Neo-specific path rule: filesystem tools operate relative to sandbox
|
||||
# workspace root. Do not prepend "/workspace".
|
||||
@@ -1036,22 +1099,22 @@ def _apply_sandbox_tools(
|
||||
# Browser tools: only register if profile supports browser
|
||||
# (or if capabilities are unknown because sandbox hasn't booted yet)
|
||||
if sandbox_capabilities is None or "browser" in sandbox_capabilities:
|
||||
req.func_tool.add_tool(BROWSER_EXEC_TOOL)
|
||||
req.func_tool.add_tool(BROWSER_BATCH_EXEC_TOOL)
|
||||
req.func_tool.add_tool(RUN_BROWSER_SKILL_TOOL)
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserExecTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserBatchExecTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(RunBrowserSkillTool))
|
||||
|
||||
# Neo-specific tools (always available for shipyard_neo)
|
||||
req.func_tool.add_tool(GET_EXECUTION_HISTORY_TOOL)
|
||||
req.func_tool.add_tool(ANNOTATE_EXECUTION_TOOL)
|
||||
req.func_tool.add_tool(CREATE_SKILL_PAYLOAD_TOOL)
|
||||
req.func_tool.add_tool(GET_SKILL_PAYLOAD_TOOL)
|
||||
req.func_tool.add_tool(CREATE_SKILL_CANDIDATE_TOOL)
|
||||
req.func_tool.add_tool(LIST_SKILL_CANDIDATES_TOOL)
|
||||
req.func_tool.add_tool(EVALUATE_SKILL_CANDIDATE_TOOL)
|
||||
req.func_tool.add_tool(PROMOTE_SKILL_CANDIDATE_TOOL)
|
||||
req.func_tool.add_tool(LIST_SKILL_RELEASES_TOOL)
|
||||
req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL)
|
||||
req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL)
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetExecutionHistoryTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(AnnotateExecutionTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillPayloadTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetSkillPayloadTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillCandidateTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillCandidatesTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(EvaluateSkillCandidateTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(PromoteSkillCandidateTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillReleasesTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(RollbackSkillReleaseTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(SyncSkillReleaseTool))
|
||||
|
||||
req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n"
|
||||
|
||||
@@ -1341,7 +1404,7 @@ async def build_main_agent(
|
||||
if config.computer_use_runtime == "sandbox":
|
||||
_apply_sandbox_tools(config, req, req.session_id)
|
||||
elif config.computer_use_runtime == "local":
|
||||
_apply_local_env_tools(req)
|
||||
_apply_local_env_tools(req, plugin_context)
|
||||
|
||||
agent_runner = AgentRunner()
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
@@ -1377,6 +1440,15 @@ async def build_main_agent(
|
||||
if config.tool_schema_mode == "full"
|
||||
else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE
|
||||
)
|
||||
|
||||
if config.computer_use_runtime == "local":
|
||||
tool_prompt += (
|
||||
f"\nCurrent workspace you can use: "
|
||||
f"`{_get_workspace_path_for_umo(event.unified_msg_origin)}`\n"
|
||||
"Unless the user explicitly specifies a different directory, "
|
||||
"perform all file-related operations in this workspace.\n"
|
||||
)
|
||||
|
||||
req.system_prompt += f"\n{tool_prompt}\n"
|
||||
|
||||
action_type = event.get_extra("action_type")
|
||||
@@ -1402,6 +1474,14 @@ async def build_main_agent(
|
||||
fallback_providers=_get_fallback_chat_providers(
|
||||
provider, plugin_context, config.provider_settings
|
||||
),
|
||||
tool_result_overflow_dir=(
|
||||
get_astrbot_system_tmp_path()
|
||||
if req.func_tool and req.func_tool.get_tool("astrbot_file_read_tool")
|
||||
else None
|
||||
),
|
||||
read_tool=(
|
||||
req.func_tool.get_tool("astrbot_file_read_tool") if req.func_tool else None
|
||||
),
|
||||
)
|
||||
|
||||
if apply_reset:
|
||||
|
||||
@@ -1,27 +1,5 @@
|
||||
import base64
|
||||
|
||||
from astrbot.core.computer.tools import (
|
||||
AnnotateExecutionTool,
|
||||
BrowserBatchExecTool,
|
||||
BrowserExecTool,
|
||||
CreateSkillCandidateTool,
|
||||
CreateSkillPayloadTool,
|
||||
EvaluateSkillCandidateTool,
|
||||
ExecuteShellTool,
|
||||
FileDownloadTool,
|
||||
FileUploadTool,
|
||||
GetExecutionHistoryTool,
|
||||
GetSkillPayloadTool,
|
||||
ListSkillCandidatesTool,
|
||||
ListSkillReleasesTool,
|
||||
LocalPythonTool,
|
||||
PromoteSkillCandidateTool,
|
||||
PythonTool,
|
||||
RollbackSkillReleaseTool,
|
||||
RunBrowserSkillTool,
|
||||
SyncSkillReleaseTool,
|
||||
)
|
||||
|
||||
LLM_SAFETY_MODE_SYSTEM_PROMPT = """You are running in Safe Mode.
|
||||
|
||||
Rules:
|
||||
@@ -130,28 +108,6 @@ BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT = (
|
||||
"{background_task_result}"
|
||||
)
|
||||
|
||||
|
||||
EXECUTE_SHELL_TOOL = ExecuteShellTool()
|
||||
LOCAL_EXECUTE_SHELL_TOOL = ExecuteShellTool(is_local=True)
|
||||
PYTHON_TOOL = PythonTool()
|
||||
LOCAL_PYTHON_TOOL = LocalPythonTool()
|
||||
FILE_UPLOAD_TOOL = FileUploadTool()
|
||||
FILE_DOWNLOAD_TOOL = FileDownloadTool()
|
||||
BROWSER_EXEC_TOOL = BrowserExecTool()
|
||||
BROWSER_BATCH_EXEC_TOOL = BrowserBatchExecTool()
|
||||
RUN_BROWSER_SKILL_TOOL = RunBrowserSkillTool()
|
||||
GET_EXECUTION_HISTORY_TOOL = GetExecutionHistoryTool()
|
||||
ANNOTATE_EXECUTION_TOOL = AnnotateExecutionTool()
|
||||
CREATE_SKILL_PAYLOAD_TOOL = CreateSkillPayloadTool()
|
||||
GET_SKILL_PAYLOAD_TOOL = GetSkillPayloadTool()
|
||||
CREATE_SKILL_CANDIDATE_TOOL = CreateSkillCandidateTool()
|
||||
LIST_SKILL_CANDIDATES_TOOL = ListSkillCandidatesTool()
|
||||
EVALUATE_SKILL_CANDIDATE_TOOL = EvaluateSkillCandidateTool()
|
||||
PROMOTE_SKILL_CANDIDATE_TOOL = PromoteSkillCandidateTool()
|
||||
LIST_SKILL_RELEASES_TOOL = ListSkillReleasesTool()
|
||||
ROLLBACK_SKILL_RELEASE_TOOL = RollbackSkillReleaseTool()
|
||||
SYNC_SKILL_RELEASE_TOOL = SyncSkillReleaseTool()
|
||||
|
||||
# we prevent astrbot from connecting to known malicious hosts
|
||||
# these hosts are base64 encoded
|
||||
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import boxlite
|
||||
from shipyard.filesystem import FileSystemComponent as ShipyardFileSystemComponent
|
||||
from shipyard import FileSystemComponent as ShipyardFileSystemComponent
|
||||
from shipyard.python import PythonComponent as ShipyardPythonComponent
|
||||
from shipyard.shell import ShellComponent as ShipyardShellComponent
|
||||
|
||||
@@ -12,6 +12,7 @@ from astrbot.api import logger
|
||||
|
||||
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
|
||||
from .base import ComputerBooter
|
||||
from .shipyard import ShipyardFileSystemWrapper
|
||||
|
||||
|
||||
class MockShipyardSandboxClient:
|
||||
@@ -150,11 +151,6 @@ class BoxliteBooter(ComputerBooter):
|
||||
self.mocked = MockShipyardSandboxClient(
|
||||
sb_url=f"http://127.0.0.1:{random_port}"
|
||||
)
|
||||
self._fs = ShipyardFileSystemComponent(
|
||||
client=self.mocked, # type: ignore
|
||||
ship_id=self.box.id,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._python = ShipyardPythonComponent(
|
||||
client=self.mocked, # type: ignore
|
||||
ship_id=self.box.id,
|
||||
@@ -165,6 +161,14 @@ class BoxliteBooter(ComputerBooter):
|
||||
ship_id=self.box.id,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._ship_fs = ShipyardFileSystemComponent(
|
||||
client=self.mocked, # type: ignore
|
||||
ship_id=self.box.id,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._fs = ShipyardFileSystemWrapper(
|
||||
_shipyard_fs=self._ship_fs, _shipyard_shell=self._shell
|
||||
)
|
||||
|
||||
await self.mocked.wait_healthy(self.box.id, session_id)
|
||||
|
||||
|
||||
@@ -9,15 +9,18 @@ import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from python_ripgrep import search
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_root,
|
||||
get_astrbot_temp_path,
|
||||
from astrbot.core.computer.file_read_utils import (
|
||||
detect_text_encoding,
|
||||
read_local_text_range_sync,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_root
|
||||
|
||||
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
|
||||
from .base import ComputerBooter
|
||||
from .shipyard_search_file_util import _truncate_long_lines
|
||||
|
||||
_BLOCKED_COMMAND_PATTERNS = [
|
||||
" rm -rf ",
|
||||
@@ -41,18 +44,6 @@ def _is_safe_command(command: str) -> bool:
|
||||
return not any(pat in cmd for pat in _BLOCKED_COMMAND_PATTERNS)
|
||||
|
||||
|
||||
def _ensure_safe_path(path: str) -> str:
|
||||
abs_path = os.path.abspath(path)
|
||||
allowed_roots = [
|
||||
os.path.abspath(get_astrbot_root()),
|
||||
os.path.abspath(get_astrbot_data_path()),
|
||||
os.path.abspath(get_astrbot_temp_path()),
|
||||
]
|
||||
if not any(abs_path.startswith(root) for root in allowed_roots):
|
||||
raise PermissionError("Path is outside the allowed computer roots.")
|
||||
return abs_path
|
||||
|
||||
|
||||
def _decode_bytes_with_fallback(
|
||||
output: bytes | None,
|
||||
*,
|
||||
@@ -110,7 +101,7 @@ class LocalShellComponent(ShellComponent):
|
||||
run_env = os.environ.copy()
|
||||
if env:
|
||||
run_env.update({str(k): str(v) for k, v in env.items()})
|
||||
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
|
||||
working_dir = os.path.abspath(cwd) if cwd else get_astrbot_root()
|
||||
if background:
|
||||
# `command` is intentionally executed through the current shell so
|
||||
# local computer-use behavior matches existing tool semantics.
|
||||
@@ -186,7 +177,7 @@ class LocalFileSystemComponent(FileSystemComponent):
|
||||
self, path: str, content: str = "", mode: int = 0o644
|
||||
) -> dict[str, Any]:
|
||||
def _run() -> dict[str, Any]:
|
||||
abs_path = _ensure_safe_path(path)
|
||||
abs_path = os.path.abspath(path)
|
||||
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
|
||||
with open(abs_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
@@ -195,16 +186,85 @@ class LocalFileSystemComponent(FileSystemComponent):
|
||||
|
||||
return await asyncio.to_thread(_run)
|
||||
|
||||
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
|
||||
async def read_file(
|
||||
self,
|
||||
path: str,
|
||||
encoding: str = "utf-8",
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
def _run() -> dict[str, Any]:
|
||||
abs_path = _ensure_safe_path(path)
|
||||
with open(abs_path, "rb") as f:
|
||||
raw_content = f.read()
|
||||
content = _decode_bytes_with_fallback(
|
||||
raw_content,
|
||||
preferred_encoding=encoding,
|
||||
abs_path = os.path.abspath(path)
|
||||
detected_encoding = encoding
|
||||
if encoding == "utf-8":
|
||||
with open(abs_path, "rb") as f:
|
||||
raw_sample = f.read(8192)
|
||||
detected_encoding = detect_text_encoding(raw_sample) or encoding
|
||||
return {
|
||||
"success": True,
|
||||
"content": read_local_text_range_sync(
|
||||
abs_path,
|
||||
encoding=detected_encoding,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
),
|
||||
}
|
||||
|
||||
return await asyncio.to_thread(_run)
|
||||
|
||||
async def search_files(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str | None = None,
|
||||
glob: str | None = None,
|
||||
after_context: int | None = None,
|
||||
before_context: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
def _run() -> dict[str, Any]:
|
||||
results = search(
|
||||
patterns=[pattern],
|
||||
paths=[path] if path else None,
|
||||
globs=[glob] if glob else None,
|
||||
after_context=after_context,
|
||||
before_context=before_context,
|
||||
line_number=True,
|
||||
)
|
||||
return {"success": True, "content": content}
|
||||
return {"success": True, "content": _truncate_long_lines("".join(results))}
|
||||
|
||||
return await asyncio.to_thread(_run)
|
||||
|
||||
async def edit_file(
|
||||
self,
|
||||
path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
encoding: str = "utf-8",
|
||||
) -> dict[str, Any]:
|
||||
def _run() -> dict[str, Any]:
|
||||
abs_path = os.path.abspath(path)
|
||||
with open(abs_path, encoding=encoding) as f:
|
||||
content = f.read()
|
||||
occurrences = content.count(old_string)
|
||||
if occurrences == 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "old string not found in file",
|
||||
"replacements": 0,
|
||||
}
|
||||
if replace_all:
|
||||
updated = content.replace(old_string, new_string)
|
||||
replacements = occurrences
|
||||
else:
|
||||
updated = content.replace(old_string, new_string, 1)
|
||||
replacements = 1
|
||||
with open(abs_path, "w", encoding=encoding) as f:
|
||||
f.write(updated)
|
||||
return {
|
||||
"success": True,
|
||||
"path": abs_path,
|
||||
"replacements": replacements,
|
||||
}
|
||||
|
||||
return await asyncio.to_thread(_run)
|
||||
|
||||
@@ -212,7 +272,7 @@ class LocalFileSystemComponent(FileSystemComponent):
|
||||
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
|
||||
) -> dict[str, Any]:
|
||||
def _run() -> dict[str, Any]:
|
||||
abs_path = _ensure_safe_path(path)
|
||||
abs_path = os.path.abspath(path)
|
||||
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
|
||||
with open(abs_path, mode, encoding=encoding) as f:
|
||||
f.write(content)
|
||||
@@ -222,7 +282,7 @@ class LocalFileSystemComponent(FileSystemComponent):
|
||||
|
||||
async def delete_file(self, path: str) -> dict[str, Any]:
|
||||
def _run() -> dict[str, Any]:
|
||||
abs_path = _ensure_safe_path(path)
|
||||
abs_path = os.path.abspath(path)
|
||||
if os.path.isdir(abs_path):
|
||||
shutil.rmtree(abs_path)
|
||||
else:
|
||||
@@ -235,7 +295,7 @@ class LocalFileSystemComponent(FileSystemComponent):
|
||||
self, path: str = ".", show_hidden: bool = False
|
||||
) -> dict[str, Any]:
|
||||
def _run() -> dict[str, Any]:
|
||||
abs_path = _ensure_safe_path(path)
|
||||
abs_path = os.path.abspath(path)
|
||||
entries = os.listdir(abs_path)
|
||||
if not show_hidden:
|
||||
entries = [e for e in entries if not e.startswith(".")]
|
||||
|
||||
@@ -1,9 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from shipyard import FileSystemComponent as ShipyardFileSystemComponent
|
||||
from shipyard import ShipyardClient, Spec
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
|
||||
from .base import ComputerBooter
|
||||
from .shipyard_search_file_util import search_files_via_shell
|
||||
|
||||
|
||||
class ShipyardFileSystemWrapper:
|
||||
def __init__(
|
||||
self, _shipyard_fs: ShipyardFileSystemComponent, _shipyard_shell: ShellComponent
|
||||
):
|
||||
self._fs = _shipyard_fs
|
||||
self._shell = _shipyard_shell
|
||||
|
||||
async def create_file(
|
||||
self, path: str, content: str = "", mode: int = 420
|
||||
) -> dict[str, Any]:
|
||||
return await self._fs.create_file(path=path, content=content, mode=mode)
|
||||
|
||||
async def read_file(
|
||||
self,
|
||||
path: str,
|
||||
encoding: str = "utf-8",
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return await self._fs.read_file(
|
||||
path=path, encoding=encoding, offset=offset, limit=limit
|
||||
)
|
||||
|
||||
async def write_file(
|
||||
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
|
||||
) -> dict[str, Any]:
|
||||
return await self._fs.write_file(
|
||||
path=path, content=content, mode=mode, encoding=encoding
|
||||
)
|
||||
|
||||
async def list_dir(
|
||||
self, path: str = ".", show_hidden: bool = False
|
||||
) -> dict[str, Any]:
|
||||
return await self._fs.list_dir(path=path, show_hidden=show_hidden)
|
||||
|
||||
async def delete_file(self, path: str) -> dict[str, Any]:
|
||||
return await self._fs.delete_file(path=path)
|
||||
|
||||
async def search_files(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str | None = None,
|
||||
glob: str | None = None,
|
||||
after_context: int | None = None,
|
||||
before_context: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return await search_files_via_shell(
|
||||
self._shell,
|
||||
pattern=pattern,
|
||||
path=path,
|
||||
glob=glob,
|
||||
after_context=after_context,
|
||||
before_context=before_context,
|
||||
)
|
||||
|
||||
async def edit_file(
|
||||
self,
|
||||
path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
encoding: str = "utf-8",
|
||||
) -> dict[str, Any]:
|
||||
return await self._fs.edit_file(
|
||||
path=path,
|
||||
old_string=old_string,
|
||||
new_string=new_string,
|
||||
replace_all=replace_all,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
|
||||
class ShipyardBooter(ComputerBooter):
|
||||
@@ -29,13 +107,14 @@ class ShipyardBooter(ComputerBooter):
|
||||
)
|
||||
logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}")
|
||||
self._ship = ship
|
||||
self._fs = ShipyardFileSystemWrapper(self._ship.fs, self._ship.shell)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.info("[Computer] Shipyard booter shutdown.")
|
||||
|
||||
@property
|
||||
def fs(self) -> FileSystemComponent:
|
||||
return self._ship.fs
|
||||
return self._fs
|
||||
|
||||
@property
|
||||
def python(self) -> PythonComponent:
|
||||
|
||||
@@ -13,6 +13,15 @@ from ..olayer import (
|
||||
ShellComponent,
|
||||
)
|
||||
from .base import ComputerBooter
|
||||
from .shipyard_search_file_util import search_files_via_shell
|
||||
|
||||
try:
|
||||
from shipyard_neo import BayClient
|
||||
from shipyard_neo.sandbox import Sandbox
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"shipyard_neo_sdk is not installed. ShipyardNeoBooter will not work without it."
|
||||
)
|
||||
|
||||
|
||||
def _maybe_model_dump(value: Any) -> dict[str, Any]:
|
||||
@@ -25,8 +34,20 @@ def _maybe_model_dump(value: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def _slice_content_by_lines(
|
||||
content: str,
|
||||
*,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
lines = content.splitlines(keepends=True)
|
||||
start = 0 if offset is None else offset
|
||||
selected = lines[start:] if limit is None else lines[start : start + limit]
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
class NeoPythonComponent(PythonComponent):
|
||||
def __init__(self, sandbox: Any) -> None:
|
||||
def __init__(self, sandbox: Sandbox) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
async def exec(
|
||||
@@ -67,7 +88,7 @@ class NeoPythonComponent(PythonComponent):
|
||||
|
||||
|
||||
class NeoShellComponent(ShellComponent):
|
||||
def __init__(self, sandbox: Any) -> None:
|
||||
def __init__(self, sandbox: Sandbox) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
async def exec(
|
||||
@@ -136,8 +157,9 @@ class NeoShellComponent(ShellComponent):
|
||||
|
||||
|
||||
class NeoFileSystemComponent(FileSystemComponent):
|
||||
def __init__(self, sandbox: Any) -> None:
|
||||
def __init__(self, sandbox: Sandbox, shell: ShellComponent) -> None:
|
||||
self._sandbox = sandbox
|
||||
self._shell = shell
|
||||
|
||||
async def create_file(
|
||||
self,
|
||||
@@ -149,10 +171,71 @@ class NeoFileSystemComponent(FileSystemComponent):
|
||||
await self._sandbox.filesystem.write_file(path, content)
|
||||
return {"success": True, "path": path}
|
||||
|
||||
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
|
||||
async def read_file(
|
||||
self,
|
||||
path: str,
|
||||
encoding: str = "utf-8",
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
_ = encoding
|
||||
content = await self._sandbox.filesystem.read_file(path)
|
||||
return {"success": True, "path": path, "content": content}
|
||||
return {
|
||||
"success": True,
|
||||
"path": path,
|
||||
"content": _slice_content_by_lines(
|
||||
content,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
),
|
||||
}
|
||||
|
||||
async def search_files(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str | None = None,
|
||||
glob: str | None = None,
|
||||
after_context: int | None = None,
|
||||
before_context: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return await search_files_via_shell(
|
||||
self._shell,
|
||||
pattern=pattern,
|
||||
path=path,
|
||||
glob=glob,
|
||||
after_context=after_context,
|
||||
before_context=before_context,
|
||||
)
|
||||
|
||||
async def edit_file(
|
||||
self,
|
||||
path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
encoding: str = "utf-8",
|
||||
) -> dict[str, Any]:
|
||||
_ = encoding
|
||||
content = await self._sandbox.filesystem.read_file(path)
|
||||
occurrences = content.count(old_string)
|
||||
if occurrences == 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "old string not found in file",
|
||||
"replacements": 0,
|
||||
}
|
||||
if replace_all:
|
||||
updated = content.replace(old_string, new_string)
|
||||
replacements = occurrences
|
||||
else:
|
||||
updated = content.replace(old_string, new_string, 1)
|
||||
replacements = 1
|
||||
await self._sandbox.filesystem.write_file(path, updated)
|
||||
return {
|
||||
"success": True,
|
||||
"path": path,
|
||||
"replacements": replacements,
|
||||
}
|
||||
|
||||
async def write_file(
|
||||
self,
|
||||
@@ -186,7 +269,7 @@ class NeoFileSystemComponent(FileSystemComponent):
|
||||
|
||||
|
||||
class NeoBrowserComponent(BrowserComponent):
|
||||
def __init__(self, sandbox: Any) -> None:
|
||||
def __init__(self, sandbox: Sandbox) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
async def exec(
|
||||
@@ -271,8 +354,8 @@ class ShipyardNeoBooter(ComputerBooter):
|
||||
self._access_token = access_token
|
||||
self._profile = profile
|
||||
self._ttl = ttl
|
||||
self._client: Any = None
|
||||
self._sandbox: Any = None
|
||||
self._client: BayClient | None = None
|
||||
self._sandbox: Sandbox | None = None
|
||||
self._bay_manager: Any = None # BayContainerManager when auto-started
|
||||
self._fs: FileSystemComponent | None = None
|
||||
self._python: PythonComponent | None = None
|
||||
@@ -336,8 +419,6 @@ class ShipyardNeoBooter(ComputerBooter):
|
||||
"or ensure Bay's credentials.json is accessible for auto-discovery."
|
||||
)
|
||||
|
||||
from shipyard_neo import BayClient
|
||||
|
||||
self._client = BayClient(
|
||||
endpoint_url=self._endpoint_url,
|
||||
access_token=self._access_token,
|
||||
@@ -352,9 +433,9 @@ class ShipyardNeoBooter(ComputerBooter):
|
||||
ttl=self._ttl,
|
||||
)
|
||||
|
||||
self._fs = NeoFileSystemComponent(self._sandbox)
|
||||
self._python = NeoPythonComponent(self._sandbox)
|
||||
self._shell = NeoShellComponent(self._sandbox)
|
||||
self._fs = NeoFileSystemComponent(self._sandbox, self._shell)
|
||||
self._python = NeoPythonComponent(self._sandbox)
|
||||
|
||||
caps = self.capabilities or ()
|
||||
self._browser = (
|
||||
|
||||
148
astrbot/core/computer/booters/shipyard_search_file_util.py
Normal file
148
astrbot/core/computer/booters/shipyard_search_file_util.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from ..olayer import ShellComponent
|
||||
|
||||
_MAX_SEARCH_LINE_COLUMNS = 1000
|
||||
|
||||
|
||||
def _truncate_long_lines(text: str) -> str:
|
||||
output_lines: list[str] = []
|
||||
for line in text.splitlines(keepends=True):
|
||||
line_ending = ""
|
||||
line_body = line
|
||||
if line.endswith("\r\n"):
|
||||
line_body = line[:-2]
|
||||
line_ending = "\r\n"
|
||||
elif line.endswith("\n") or line.endswith("\r"):
|
||||
line_body = line[:-1]
|
||||
line_ending = line[-1]
|
||||
|
||||
if len(line_body) > _MAX_SEARCH_LINE_COLUMNS:
|
||||
line_body = line_body[:_MAX_SEARCH_LINE_COLUMNS]
|
||||
|
||||
output_lines.append(f"{line_body}{line_ending}")
|
||||
return "".join(output_lines)
|
||||
|
||||
|
||||
def _build_rg_command(
|
||||
*,
|
||||
pattern: str,
|
||||
path: str,
|
||||
glob: str | None,
|
||||
after_context: int | None,
|
||||
before_context: int | None,
|
||||
) -> list[str]:
|
||||
command = [
|
||||
"rg",
|
||||
"--color=never",
|
||||
"-n",
|
||||
"--max-columns",
|
||||
str(_MAX_SEARCH_LINE_COLUMNS),
|
||||
"-e",
|
||||
pattern,
|
||||
]
|
||||
if glob:
|
||||
command.extend(["-g", glob])
|
||||
if after_context is not None:
|
||||
command.extend(["-A", str(after_context)])
|
||||
if before_context is not None:
|
||||
command.extend(["-B", str(before_context)])
|
||||
command.extend(["--", path])
|
||||
return command
|
||||
|
||||
|
||||
def _build_grep_command(
|
||||
*,
|
||||
pattern: str,
|
||||
path: str,
|
||||
glob: str | None,
|
||||
after_context: int | None,
|
||||
before_context: int | None,
|
||||
) -> list[str]:
|
||||
command = ["grep", "-R", "-H", "-n", "-e", pattern]
|
||||
if glob:
|
||||
command.append(f"--include={glob}")
|
||||
if after_context is not None:
|
||||
command.extend(["-A", str(after_context)])
|
||||
if before_context is not None:
|
||||
command.extend(["-B", str(before_context)])
|
||||
command.extend(["--", path])
|
||||
return command
|
||||
|
||||
|
||||
def _quote_command(command: list[str]) -> str:
|
||||
return " ".join(shlex.quote(part) for part in command)
|
||||
|
||||
|
||||
def build_search_command(
|
||||
*,
|
||||
pattern: str,
|
||||
path: str,
|
||||
glob: str | None,
|
||||
after_context: int | None,
|
||||
before_context: int | None,
|
||||
) -> str:
|
||||
rg_command = _quote_command(
|
||||
_build_rg_command(
|
||||
pattern=pattern,
|
||||
path=path,
|
||||
glob=glob,
|
||||
after_context=after_context,
|
||||
before_context=before_context,
|
||||
)
|
||||
)
|
||||
grep_command = _quote_command(
|
||||
_build_grep_command(
|
||||
pattern=pattern,
|
||||
path=path,
|
||||
glob=glob,
|
||||
after_context=after_context,
|
||||
before_context=before_context,
|
||||
)
|
||||
)
|
||||
return (
|
||||
"if command -v rg >/dev/null 2>&1; then "
|
||||
f"{rg_command}; "
|
||||
"elif command -v grep >/dev/null 2>&1; then "
|
||||
f"{grep_command}; "
|
||||
"else "
|
||||
"echo 'Neither rg nor grep is available in the sandbox.' >&2; "
|
||||
"exit 127; "
|
||||
"fi"
|
||||
)
|
||||
|
||||
|
||||
async def search_files_via_shell(
|
||||
shell: ShellComponent,
|
||||
*,
|
||||
pattern: str,
|
||||
path: str | None = None,
|
||||
glob: str | None = None,
|
||||
after_context: int | None = None,
|
||||
before_context: int | None = None,
|
||||
timeout: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
command = build_search_command(
|
||||
pattern=pattern,
|
||||
path=path or ".",
|
||||
glob=glob,
|
||||
after_context=after_context,
|
||||
before_context=before_context,
|
||||
)
|
||||
result = await shell.exec(command, timeout=timeout)
|
||||
stdout = _truncate_long_lines(str(result.get("stdout", "") or ""))
|
||||
stderr = str(result.get("stderr", "") or "")
|
||||
exit_code = result.get("exit_code")
|
||||
if exit_code in (0, None):
|
||||
return {"success": True, "content": stdout}
|
||||
if exit_code == 1:
|
||||
return {"success": True, "content": ""}
|
||||
return {
|
||||
"success": False,
|
||||
"content": "",
|
||||
"error": stderr or f"command exited with code {exit_code}",
|
||||
"exit_code": exit_code,
|
||||
}
|
||||
707
astrbot/core/computer/file_read_utils.py
Normal file
707
astrbot/core/computer/file_read_utils.py
Normal file
@@ -0,0 +1,707 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import zipfile
|
||||
from asyncio import to_thread
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import mcp
|
||||
|
||||
from astrbot.core.agent.context.token_counter import EstimateTokenCounter
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.media_utils import (
|
||||
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
|
||||
IMAGE_COMPRESS_DEFAULT_OPTIMIZE,
|
||||
IMAGE_COMPRESS_DEFAULT_QUALITY,
|
||||
_compress_image_sync,
|
||||
)
|
||||
|
||||
from .booters.base import ComputerBooter
|
||||
|
||||
_MAX_FILE_READ_BYTES = 128 * 1024
|
||||
_MAX_FILE_READ_TOKENS = 25_000
|
||||
_MAX_TEXT_FILE_FULL_READ_BYTES = 256 * 1024
|
||||
_FILE_SNIFF_BYTES = 512
|
||||
_TOKEN_COUNTER = EstimateTokenCounter()
|
||||
_TEXT_ENCODINGS = (
|
||||
"utf-8-sig",
|
||||
"utf-8",
|
||||
"gb18030",
|
||||
"utf-16",
|
||||
"utf-16-le",
|
||||
"utf-16-be",
|
||||
"utf-32",
|
||||
"utf-32-le",
|
||||
"utf-32-be",
|
||||
)
|
||||
_UTF_BOMS = (
|
||||
b"\xef\xbb\xbf",
|
||||
b"\xff\xfe",
|
||||
b"\xfe\xff",
|
||||
b"\xff\xfe\x00\x00",
|
||||
b"\x00\x00\xfe\xff",
|
||||
)
|
||||
_ZIP_MAGIC_PREFIXES = (
|
||||
b"PK\x03\x04",
|
||||
b"PK\x05\x06",
|
||||
b"PK\x07\x08",
|
||||
)
|
||||
_BINARY_MAGIC_PREFIXES = (
|
||||
b"%PDF-",
|
||||
b"\x1f\x8b",
|
||||
b"7z\xbc\xaf\x27\x1c",
|
||||
b"Rar!\x1a\x07",
|
||||
b"\x7fELF",
|
||||
b"MZ",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FileProbe:
|
||||
kind: Literal["text", "image", "binary"]
|
||||
encoding: str | None
|
||||
mime_type: str | None
|
||||
size_bytes: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParsedDocument:
|
||||
kind: Literal["docx", "pdf"]
|
||||
file_bytes: bytes
|
||||
text: str
|
||||
|
||||
|
||||
def _build_probe_script(path: str) -> str:
|
||||
return f"""
|
||||
import base64
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
path = Path({path!r})
|
||||
with path.open("rb") as file_obj:
|
||||
sample = file_obj.read({_FILE_SNIFF_BYTES})
|
||||
print(
|
||||
json.dumps(
|
||||
{{
|
||||
"size_bytes": path.stat().st_size,
|
||||
"sample_b64": base64.b64encode(sample).decode("ascii"),
|
||||
}}
|
||||
)
|
||||
)
|
||||
""".strip()
|
||||
|
||||
|
||||
def _build_text_read_script(
|
||||
path: str,
|
||||
*,
|
||||
encoding: str,
|
||||
offset: int | None,
|
||||
limit: int | None,
|
||||
) -> str:
|
||||
start_expr = "0" if offset is None else str(offset)
|
||||
limit_expr = "None" if limit is None else str(limit)
|
||||
return f"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
path = Path({path!r})
|
||||
start = {start_expr}
|
||||
limit = {limit_expr}
|
||||
end = None if limit is None else start + limit
|
||||
lines = []
|
||||
with path.open("r", encoding={encoding!r}, newline="") as file_obj:
|
||||
for index, line in enumerate(file_obj):
|
||||
if index < start:
|
||||
continue
|
||||
if end is not None and index >= end:
|
||||
break
|
||||
lines.append(line)
|
||||
content = "".join(lines)
|
||||
print(json.dumps({{"content": content}}, ensure_ascii=False))
|
||||
""".strip()
|
||||
|
||||
|
||||
def _build_image_read_script(path: str) -> str:
|
||||
return f"""
|
||||
import base64
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
path = Path({path!r})
|
||||
data = path.read_bytes()
|
||||
print(
|
||||
json.dumps(
|
||||
{{
|
||||
"size_bytes": len(data),
|
||||
"base64": base64.b64encode(data).decode("ascii"),
|
||||
}}
|
||||
)
|
||||
)
|
||||
""".strip()
|
||||
|
||||
|
||||
def _looks_like_text(decoded: str) -> bool:
|
||||
if not decoded:
|
||||
return True
|
||||
|
||||
disallowed = 0
|
||||
printable = 0
|
||||
for char in decoded:
|
||||
if char in "\n\r\t\f\b":
|
||||
printable += 1
|
||||
continue
|
||||
if char.isprintable():
|
||||
printable += 1
|
||||
code = ord(char)
|
||||
if (0 <= code < 32) or (127 <= code < 160):
|
||||
disallowed += 1
|
||||
|
||||
total = max(len(decoded), 1)
|
||||
return disallowed / total <= 0.02 and printable / total >= 0.85
|
||||
|
||||
|
||||
def detect_text_encoding(sample: bytes) -> str | None:
|
||||
if not sample:
|
||||
return "utf-8"
|
||||
|
||||
if b"\x00" in sample and not sample.startswith(_UTF_BOMS):
|
||||
odd_bytes = sample[1::2]
|
||||
even_bytes = sample[0::2]
|
||||
odd_zero_ratio = odd_bytes.count(0) / max(len(odd_bytes), 1)
|
||||
even_zero_ratio = even_bytes.count(0) / max(len(even_bytes), 1)
|
||||
if odd_zero_ratio < 0.8 and even_zero_ratio < 0.8:
|
||||
return None
|
||||
|
||||
for encoding in _TEXT_ENCODINGS:
|
||||
try:
|
||||
decoded = sample.decode(encoding)
|
||||
except UnicodeDecodeError as exc:
|
||||
# Probe samples can end in the middle of a multibyte sequence.
|
||||
# When the decode failure only happens at the sample tail, trim a few
|
||||
# bytes and retry so UTF-8 text is not misclassified as binary.
|
||||
if exc.start >= len(sample) - 4:
|
||||
decoded = ""
|
||||
for trim_bytes in range(1, min(4, len(sample)) + 1):
|
||||
try:
|
||||
decoded = sample[:-trim_bytes].decode(encoding)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
if not decoded:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
if _looks_like_text(decoded):
|
||||
return encoding
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def read_local_text_range_sync(
|
||||
path: str,
|
||||
*,
|
||||
encoding: str,
|
||||
offset: int | None,
|
||||
limit: int | None,
|
||||
) -> str:
|
||||
lines: list[str] = []
|
||||
start = 0 if offset is None else offset
|
||||
end = None if limit is None else start + limit
|
||||
with open(path, encoding=encoding, newline="") as file_obj:
|
||||
for index, line in enumerate(file_obj):
|
||||
if index < start:
|
||||
continue
|
||||
if end is not None and index >= end:
|
||||
break
|
||||
lines.append(line)
|
||||
return "".join(lines)
|
||||
|
||||
|
||||
async def read_local_text_range(
|
||||
path: str,
|
||||
*,
|
||||
encoding: str,
|
||||
offset: int | None,
|
||||
limit: int | None,
|
||||
) -> str:
|
||||
return await to_thread(
|
||||
read_local_text_range_sync,
|
||||
path,
|
||||
encoding=encoding,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
async def _exec_python_json(
|
||||
booter: ComputerBooter,
|
||||
script: str,
|
||||
*,
|
||||
action: str,
|
||||
) -> dict:
|
||||
result = await booter.python.exec(script)
|
||||
data = result.get("data") if isinstance(result.get("data"), dict) else {}
|
||||
if not isinstance(data, dict):
|
||||
raise RuntimeError(f"{action} failed: invalid result format")
|
||||
output = data.get("output") if isinstance(data.get("output"), dict) else {}
|
||||
if not isinstance(output, dict):
|
||||
raise RuntimeError(f"{action} failed: invalid output format")
|
||||
error_text = str(data.get("error", "") or result.get("error", "") or "").strip()
|
||||
if error_text:
|
||||
raise RuntimeError(f"{action} failed: {error_text}")
|
||||
|
||||
text = str(output.get("text", "") or "").strip()
|
||||
if not text:
|
||||
raise RuntimeError(f"{action} failed: empty output")
|
||||
|
||||
try:
|
||||
payload = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"{action} failed: invalid JSON output") from exc
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError(f"{action} failed: invalid JSON payload")
|
||||
return payload
|
||||
|
||||
|
||||
async def _probe_local_file(path: str) -> dict[str, str | int]:
|
||||
def _run() -> dict[str, str | int]:
|
||||
file_path = Path(path)
|
||||
with file_path.open("rb") as file_obj:
|
||||
sample = file_obj.read(_FILE_SNIFF_BYTES)
|
||||
return {
|
||||
"size_bytes": file_path.stat().st_size,
|
||||
"sample_b64": base64.b64encode(sample).decode("ascii"),
|
||||
}
|
||||
|
||||
return await to_thread(_run)
|
||||
|
||||
|
||||
async def _read_local_image_base64(path: str) -> dict[str, str | int]:
|
||||
def _run() -> dict[str, str | int]:
|
||||
data = Path(path).read_bytes()
|
||||
return {
|
||||
"size_bytes": len(data),
|
||||
"base64": base64.b64encode(data).decode("ascii"),
|
||||
}
|
||||
|
||||
return await to_thread(_run)
|
||||
|
||||
|
||||
async def _read_local_file_bytes(path: str) -> bytes:
|
||||
return await to_thread(Path(path).read_bytes)
|
||||
|
||||
|
||||
async def _compress_image_bytes_to_base64(data: bytes) -> dict[str, str | int]:
|
||||
def _run() -> dict[str, str | int]:
|
||||
temp_dir = Path(get_astrbot_temp_path())
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
compressed_path = Path(
|
||||
_compress_image_sync(
|
||||
data,
|
||||
temp_dir,
|
||||
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
|
||||
IMAGE_COMPRESS_DEFAULT_QUALITY,
|
||||
IMAGE_COMPRESS_DEFAULT_OPTIMIZE,
|
||||
)
|
||||
)
|
||||
try:
|
||||
compressed_bytes = compressed_path.read_bytes()
|
||||
finally:
|
||||
compressed_path.unlink(missing_ok=True)
|
||||
|
||||
return {
|
||||
"size_bytes": len(compressed_bytes),
|
||||
"base64": base64.b64encode(compressed_bytes).decode("ascii"),
|
||||
"mime_type": "image/jpeg",
|
||||
}
|
||||
|
||||
return await to_thread(_run)
|
||||
|
||||
|
||||
def _detect_image_mime(sample: bytes) -> str | None:
|
||||
if sample.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
if sample.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
if sample.startswith((b"GIF87a", b"GIF89a")):
|
||||
return "image/gif"
|
||||
if sample.startswith(b"BM"):
|
||||
return "image/bmp"
|
||||
if sample.startswith((b"II*\x00", b"MM\x00*")):
|
||||
return "image/tiff"
|
||||
if sample.startswith(b"\x00\x00\x01\x00"):
|
||||
return "image/x-icon"
|
||||
if len(sample) >= 12 and sample[:4] == b"RIFF" and sample[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
if len(sample) >= 12 and sample[4:12] in (b"ftypavif", b"ftypavis"):
|
||||
return "image/avif"
|
||||
return None
|
||||
|
||||
|
||||
def _looks_like_known_binary(sample: bytes) -> bool:
|
||||
return any(sample.startswith(prefix) for prefix in _BINARY_MAGIC_PREFIXES)
|
||||
|
||||
|
||||
def _looks_like_pdf(path: str, sample: bytes) -> bool:
|
||||
return Path(path).suffix.lower() == ".pdf" or sample.startswith(b"%PDF-")
|
||||
|
||||
|
||||
def _looks_like_zip_container(sample: bytes) -> bool:
|
||||
return any(sample.startswith(prefix) for prefix in _ZIP_MAGIC_PREFIXES)
|
||||
|
||||
|
||||
def _is_docx_bytes(file_bytes: bytes) -> bool:
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_bytes)) as archive:
|
||||
names = set(archive.namelist())
|
||||
except (OSError, zipfile.BadZipFile):
|
||||
return False
|
||||
|
||||
if "[Content_Types].xml" not in names:
|
||||
return False
|
||||
|
||||
return any(name.startswith("word/") for name in names)
|
||||
|
||||
|
||||
async def _parse_local_docx_text(file_bytes: bytes, file_name: str) -> str:
|
||||
from astrbot.core.knowledge_base.parsers.markitdown_parser import (
|
||||
MarkitdownParser,
|
||||
)
|
||||
|
||||
result = await MarkitdownParser().parse(file_bytes, file_name)
|
||||
return result.text
|
||||
|
||||
|
||||
async def _parse_local_pdf_text(file_bytes: bytes, file_name: str) -> str:
|
||||
from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser
|
||||
|
||||
result = await PDFParser().parse(file_bytes, file_name)
|
||||
return result.text
|
||||
|
||||
|
||||
async def _parse_local_supported_document(
|
||||
path: str,
|
||||
sample: bytes,
|
||||
) -> ParsedDocument | None:
|
||||
file_name = Path(path).name
|
||||
if _looks_like_pdf(path, sample):
|
||||
file_bytes = await _read_local_file_bytes(path)
|
||||
text = await _parse_local_pdf_text(file_bytes, file_name)
|
||||
return ParsedDocument(kind="pdf", file_bytes=file_bytes, text=text)
|
||||
|
||||
if Path(path).suffix.lower() == ".docx" or _looks_like_zip_container(sample):
|
||||
file_bytes = await _read_local_file_bytes(path)
|
||||
if not _is_docx_bytes(file_bytes):
|
||||
return None
|
||||
text = await _parse_local_docx_text(file_bytes, file_name)
|
||||
return ParsedDocument(kind="docx", file_bytes=file_bytes, text=text)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _probe_file(sample: bytes, *, size_bytes: int) -> FileProbe:
|
||||
if image_mime := _detect_image_mime(sample):
|
||||
return FileProbe(
|
||||
kind="image",
|
||||
encoding=None,
|
||||
mime_type=image_mime,
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
|
||||
if _looks_like_known_binary(sample):
|
||||
return FileProbe(
|
||||
kind="binary",
|
||||
encoding=None,
|
||||
mime_type=None,
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
|
||||
if encoding := detect_text_encoding(sample):
|
||||
return FileProbe(
|
||||
kind="text",
|
||||
encoding=encoding,
|
||||
mime_type="text/plain",
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
|
||||
return FileProbe(
|
||||
kind="binary",
|
||||
encoding=None,
|
||||
mime_type=None,
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
|
||||
|
||||
def _validate_text_output(content: str) -> str | None:
|
||||
content_bytes = len(content.encode("utf-8"))
|
||||
if content_bytes > _MAX_FILE_READ_BYTES:
|
||||
return (
|
||||
"Error reading file: "
|
||||
f"output exceeds {_MAX_FILE_READ_BYTES} bytes "
|
||||
f"({content_bytes} bytes). Use `offset`, `limit` to narrow the read window."
|
||||
)
|
||||
|
||||
content_tokens = _TOKEN_COUNTER.count_tokens(
|
||||
[Message(role="user", content=content)]
|
||||
)
|
||||
if content_tokens > _MAX_FILE_READ_TOKENS:
|
||||
return (
|
||||
"Error reading file: "
|
||||
f"output exceeds {_MAX_FILE_READ_TOKENS} tokens "
|
||||
f"({content_tokens} tokens). Use `offset`, `limit` to narrow the read window."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _text_exceeds_read_thresholds(content: str) -> bool:
|
||||
return _validate_text_output(content) is not None
|
||||
|
||||
|
||||
def _validate_full_text_read_request(probe: FileProbe) -> str | None:
|
||||
if probe.size_bytes > _MAX_TEXT_FILE_FULL_READ_BYTES:
|
||||
return (
|
||||
"Error reading file: "
|
||||
f"text file exceeds {_MAX_TEXT_FILE_FULL_READ_BYTES} bytes "
|
||||
f"({probe.size_bytes} bytes). Use `offset` and `limit` to narrow the read window."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _slice_text_by_lines(
|
||||
content: str,
|
||||
*,
|
||||
offset: int | None,
|
||||
limit: int | None,
|
||||
) -> str:
|
||||
if offset is None and limit is None:
|
||||
return content
|
||||
|
||||
lines = content.splitlines(keepends=True)
|
||||
start = 0 if offset is None else offset
|
||||
end = None if limit is None else start + limit
|
||||
return "".join(lines[start:end])
|
||||
|
||||
|
||||
async def _store_converted_text_for_workspace(
|
||||
*,
|
||||
workspace_dir: str,
|
||||
original_path: str,
|
||||
original_bytes: bytes,
|
||||
content: str,
|
||||
) -> str:
|
||||
def _run() -> str:
|
||||
original_name = Path(original_path).name
|
||||
digest_suffix = hashlib.md5(original_bytes).hexdigest()[-6:]
|
||||
target_dir = (
|
||||
Path(workspace_dir) / "converted_files" / f"{original_name}_{digest_suffix}"
|
||||
)
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
target_path = target_dir / "text.txt"
|
||||
target_path.write_text(content, encoding="utf-8")
|
||||
return str(target_path)
|
||||
|
||||
return await to_thread(_run)
|
||||
|
||||
|
||||
def _build_converted_text_notice(
|
||||
converted_text_path: str,
|
||||
*,
|
||||
selection_returned: bool,
|
||||
selection_too_large: bool = False,
|
||||
) -> str:
|
||||
if selection_too_large:
|
||||
return (
|
||||
"Converted text was saved to "
|
||||
f"`{converted_text_path}`. The requested output is still too large to "
|
||||
"return directly. Read or grep that file with a narrower window."
|
||||
)
|
||||
|
||||
if selection_returned:
|
||||
return (
|
||||
"Full converted text is also available at "
|
||||
f"`{converted_text_path}`. Read or grep that file with a narrow "
|
||||
"window for additional reads."
|
||||
)
|
||||
|
||||
return (
|
||||
"Converted text was saved to "
|
||||
f"`{converted_text_path}` because the parsed document is too large to "
|
||||
"return directly. Read or grep that file with a narrow window."
|
||||
)
|
||||
|
||||
|
||||
async def _read_local_supported_document_result(
|
||||
*,
|
||||
path: str,
|
||||
parsed_document: ParsedDocument,
|
||||
workspace_dir: str | None,
|
||||
offset: int | None,
|
||||
limit: int | None,
|
||||
) -> ToolExecResult:
|
||||
content = parsed_document.text
|
||||
if not content:
|
||||
return "No content found at the requested line offset."
|
||||
|
||||
if not _text_exceeds_read_thresholds(content):
|
||||
selected_content = _slice_text_by_lines(content, offset=offset, limit=limit)
|
||||
if not selected_content:
|
||||
return "No content found at the requested line offset."
|
||||
if validation_error := _validate_text_output(selected_content):
|
||||
return validation_error
|
||||
return selected_content
|
||||
|
||||
if not workspace_dir:
|
||||
return (
|
||||
"Error reading file: parsed document exceeds the read output limit and "
|
||||
"no workspace is available for storing converted text."
|
||||
)
|
||||
|
||||
converted_text_path = await _store_converted_text_for_workspace(
|
||||
workspace_dir=workspace_dir,
|
||||
original_path=path,
|
||||
original_bytes=parsed_document.file_bytes,
|
||||
content=content,
|
||||
)
|
||||
|
||||
if offset is None and limit is None:
|
||||
return _build_converted_text_notice(
|
||||
converted_text_path,
|
||||
selection_returned=False,
|
||||
)
|
||||
|
||||
selected_content = _slice_text_by_lines(content, offset=offset, limit=limit)
|
||||
if not selected_content:
|
||||
return (
|
||||
"No content found at the requested line offset. "
|
||||
+ _build_converted_text_notice(
|
||||
converted_text_path,
|
||||
selection_returned=False,
|
||||
)
|
||||
)
|
||||
|
||||
notice = _build_converted_text_notice(
|
||||
converted_text_path,
|
||||
selection_returned=True,
|
||||
)
|
||||
combined_output = f"{selected_content}\n\n[{notice}]"
|
||||
if _validate_text_output(combined_output):
|
||||
if _validate_text_output(selected_content):
|
||||
return _build_converted_text_notice(
|
||||
converted_text_path,
|
||||
selection_returned=False,
|
||||
selection_too_large=True,
|
||||
)
|
||||
return selected_content
|
||||
|
||||
return combined_output
|
||||
|
||||
|
||||
async def read_file_tool_result(
|
||||
booter: ComputerBooter,
|
||||
*,
|
||||
local_mode: bool,
|
||||
path: str,
|
||||
offset: int | None,
|
||||
limit: int | None,
|
||||
workspace_dir: str | None = None,
|
||||
) -> ToolExecResult:
|
||||
if local_mode:
|
||||
probe_payload = await _probe_local_file(path)
|
||||
else:
|
||||
probe_payload = await _exec_python_json(
|
||||
booter,
|
||||
_build_probe_script(path),
|
||||
action="file probe",
|
||||
)
|
||||
sample_b64 = str(probe_payload.get("sample_b64", "") or "")
|
||||
sample = base64.b64decode(sample_b64) if sample_b64 else b""
|
||||
size_bytes = int(probe_payload.get("size_bytes", 0) or 0)
|
||||
probe = _probe_file(sample, size_bytes=size_bytes)
|
||||
|
||||
if local_mode:
|
||||
try:
|
||||
parsed_document = await _parse_local_supported_document(path, sample)
|
||||
except Exception as exc:
|
||||
return f"Error reading file: failed to parse document: {exc}"
|
||||
|
||||
if parsed_document is not None:
|
||||
return await _read_local_supported_document_result(
|
||||
path=path,
|
||||
parsed_document=parsed_document,
|
||||
workspace_dir=workspace_dir,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if probe.kind == "binary":
|
||||
return "Error reading file: binary files are not supported by this tool."
|
||||
|
||||
if probe.kind == "image":
|
||||
if local_mode:
|
||||
image_payload = await _read_local_image_base64(path)
|
||||
else:
|
||||
image_payload = await _exec_python_json(
|
||||
booter,
|
||||
_build_image_read_script(path),
|
||||
action="image read",
|
||||
)
|
||||
raw_base64_data = str(image_payload.get("base64", "") or "")
|
||||
if not raw_base64_data:
|
||||
return "Error reading file: image payload is empty."
|
||||
raw_bytes = base64.b64decode(raw_base64_data)
|
||||
compressed_payload = await _compress_image_bytes_to_base64(raw_bytes)
|
||||
base64_data = str(compressed_payload.get("base64", "") or "")
|
||||
if not base64_data:
|
||||
return "Error reading file: compressed image payload is empty."
|
||||
return mcp.types.CallToolResult(
|
||||
content=[
|
||||
mcp.types.ImageContent(
|
||||
type="image",
|
||||
data=base64_data,
|
||||
mimeType=str(
|
||||
compressed_payload.get("mime_type", "") or "image/jpeg"
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if offset is None and limit is None:
|
||||
if validation_error := _validate_full_text_read_request(probe):
|
||||
return validation_error
|
||||
|
||||
if local_mode:
|
||||
content = await read_local_text_range(
|
||||
path,
|
||||
encoding=probe.encoding or "utf-8",
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
else:
|
||||
text_payload = await _exec_python_json(
|
||||
booter,
|
||||
_build_text_read_script(
|
||||
path,
|
||||
encoding=probe.encoding or "utf-8",
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
),
|
||||
action="text read",
|
||||
)
|
||||
content = str(text_payload.get("content", "") or "")
|
||||
|
||||
if not content:
|
||||
return "No content found at the requested line offset."
|
||||
|
||||
if validation_error := _validate_text_output(content):
|
||||
return validation_error
|
||||
|
||||
return content
|
||||
@@ -12,8 +12,36 @@ class FileSystemComponent(Protocol):
|
||||
"""Create a file with the specified content"""
|
||||
...
|
||||
|
||||
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
|
||||
"""Read file content"""
|
||||
async def read_file(
|
||||
self,
|
||||
path: str,
|
||||
encoding: str = "utf-8",
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Read file content by line window"""
|
||||
...
|
||||
|
||||
async def search_files(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str | None = None,
|
||||
glob: str | None = None,
|
||||
after_context: int | None = None,
|
||||
before_context: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Search file contents"""
|
||||
...
|
||||
|
||||
async def edit_file(
|
||||
self,
|
||||
path: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
replace_all: bool = False,
|
||||
encoding: str = "utf-8",
|
||||
) -> dict[str, Any]:
|
||||
"""Edit file content by string replacement"""
|
||||
...
|
||||
|
||||
async def write_file(
|
||||
|
||||
@@ -1,213 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.api import FunctionTool, logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.components import File
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
from ..computer_client import get_booter
|
||||
from .permissions import check_admin_permission
|
||||
|
||||
# @dataclass
|
||||
# class CreateFileTool(FunctionTool):
|
||||
# name: str = "astrbot_create_file"
|
||||
# description: str = "Create a new file in the sandbox."
|
||||
# parameters: dict = field(
|
||||
# default_factory=lambda: {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "path": {
|
||||
# "path": "string",
|
||||
# "description": "The path where the file should be created, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
|
||||
# },
|
||||
# "content": {
|
||||
# "type": "string",
|
||||
# "description": "The content to write into the file.",
|
||||
# },
|
||||
# },
|
||||
# "required": ["path", "content"],
|
||||
# }
|
||||
# )
|
||||
|
||||
# async def call(
|
||||
# self, context: ContextWrapper[AstrAgentContext], path: str, content: str
|
||||
# ) -> ToolExecResult:
|
||||
# sb = await get_booter(
|
||||
# context.context.context,
|
||||
# context.context.event.unified_msg_origin,
|
||||
# )
|
||||
# try:
|
||||
# result = await sb.fs.create_file(path, content)
|
||||
# return json.dumps(result)
|
||||
# except Exception as e:
|
||||
# return f"Error creating file: {str(e)}"
|
||||
|
||||
|
||||
# @dataclass
|
||||
# class ReadFileTool(FunctionTool):
|
||||
# name: str = "astrbot_read_file"
|
||||
# description: str = "Read the content of a file in the sandbox."
|
||||
# parameters: dict = field(
|
||||
# default_factory=lambda: {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "path": {
|
||||
# "type": "string",
|
||||
# "description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
|
||||
# },
|
||||
# },
|
||||
# "required": ["path"],
|
||||
# }
|
||||
# )
|
||||
|
||||
# async def call(self, context: ContextWrapper[AstrAgentContext], path: str):
|
||||
# sb = await get_booter(
|
||||
# context.context.context,
|
||||
# context.context.event.unified_msg_origin,
|
||||
# )
|
||||
# try:
|
||||
# result = await sb.fs.read_file(path)
|
||||
# return result
|
||||
# except Exception as e:
|
||||
# return f"Error reading file: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileUploadTool(FunctionTool):
|
||||
name: str = "astrbot_upload_file"
|
||||
description: str = (
|
||||
"Transfer a file FROM the host machine INTO the sandbox so that sandbox "
|
||||
"code can access it. Use this when the user sends/attaches a file and you "
|
||||
"need to process it inside the sandbox. The local_path must point to an "
|
||||
"existing file on the host filesystem."
|
||||
)
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"local_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file on the host filesystem that will be copied into the sandbox.",
|
||||
},
|
||||
# "remote_path": {
|
||||
# "type": "string",
|
||||
# "description": "The filename to use in the sandbox. If not provided, file will be saved to the working directory with the same name as the local file.",
|
||||
# },
|
||||
},
|
||||
"required": ["local_path"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
local_path: str,
|
||||
) -> str | None:
|
||||
if permission_error := check_admin_permission(context, "File upload/download"):
|
||||
return permission_error
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
try:
|
||||
# Check if file exists
|
||||
if not os.path.exists(local_path):
|
||||
return f"Error: File does not exist: {local_path}"
|
||||
|
||||
if not os.path.isfile(local_path):
|
||||
return f"Error: Path is not a file: {local_path}"
|
||||
|
||||
# Use basename if sandbox_filename is not provided
|
||||
remote_path = os.path.basename(local_path)
|
||||
|
||||
# Upload file to sandbox
|
||||
result = await sb.upload_file(local_path, remote_path)
|
||||
logger.debug(f"Upload result: {result}")
|
||||
success = result.get("success", False)
|
||||
|
||||
if not success:
|
||||
return f"Error uploading file: {result.get('message', 'Unknown error')}"
|
||||
|
||||
file_path = result.get("file_path", "")
|
||||
logger.info(f"File {local_path} uploaded to sandbox at {file_path}")
|
||||
|
||||
return f"File uploaded successfully to {file_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading file {local_path}: {e}")
|
||||
return f"Error uploading file: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileDownloadTool(FunctionTool):
|
||||
name: str = "astrbot_download_file"
|
||||
description: str = (
|
||||
"Transfer a file FROM the sandbox OUT to the host and optionally send it "
|
||||
"to the user. Use this ONLY when the user asks to retrieve/export a file "
|
||||
"that was created or modified inside the sandbox."
|
||||
)
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"remote_path": {
|
||||
"type": "string",
|
||||
"description": "Path of the file inside the sandbox to copy out to the host.",
|
||||
},
|
||||
"also_send_to_user": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to also send the downloaded file to the user via message. Defaults to true.",
|
||||
},
|
||||
},
|
||||
"required": ["remote_path"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
remote_path: str,
|
||||
also_send_to_user: bool = True,
|
||||
) -> ToolExecResult:
|
||||
if permission_error := check_admin_permission(context, "File upload/download"):
|
||||
return permission_error
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
try:
|
||||
name = os.path.basename(remote_path)
|
||||
|
||||
local_path = os.path.join(
|
||||
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
|
||||
)
|
||||
|
||||
# Download file from sandbox
|
||||
await sb.download_file(remote_path, local_path)
|
||||
logger.info(f"File {remote_path} downloaded from sandbox to {local_path}")
|
||||
|
||||
if also_send_to_user:
|
||||
try:
|
||||
name = os.path.basename(local_path)
|
||||
await context.context.event.send(
|
||||
MessageChain(chain=[File(name=name, file=local_path)])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending file message: {e}")
|
||||
|
||||
# remove
|
||||
# try:
|
||||
# os.remove(local_path)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error removing temp file {local_path}: {e}")
|
||||
|
||||
return f"File downloaded successfully to {local_path} and sent to user."
|
||||
|
||||
return f"File downloaded successfully to {local_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading file {remote_path}: {e}")
|
||||
return f"Error downloading file: {str(e)}"
|
||||
@@ -36,6 +36,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
)
|
||||
from astrbot.core.subagent_orchestrator import SubAgentOrchestrator
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_system_tmp_path
|
||||
|
||||
from ..exceptions import ProviderNotFoundError
|
||||
from .filter.command import CommandFilter
|
||||
@@ -232,6 +233,13 @@ class Context:
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["stream", "agent_hooks", "agent_context"]
|
||||
}
|
||||
if request.func_tool and request.func_tool.get_tool("astrbot_file_read_tool"):
|
||||
other_kwargs.setdefault(
|
||||
"tool_result_overflow_dir", get_astrbot_system_tmp_path()
|
||||
)
|
||||
other_kwargs.setdefault(
|
||||
"read_tool", request.func_tool.get_tool("astrbot_file_read_tool")
|
||||
)
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
|
||||
55
astrbot/core/tools/computer_tools/__init__.py
Normal file
55
astrbot/core/tools/computer_tools/__init__.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from .fs import (
|
||||
FileDownloadTool,
|
||||
FileEditTool,
|
||||
FileReadTool,
|
||||
FileUploadTool,
|
||||
FileWriteTool,
|
||||
GrepTool,
|
||||
)
|
||||
from .python import LocalPythonTool, PythonTool
|
||||
from .shell import ExecuteShellTool
|
||||
from .shipyard_neo import (
|
||||
AnnotateExecutionTool,
|
||||
BrowserBatchExecTool,
|
||||
BrowserExecTool,
|
||||
CreateSkillCandidateTool,
|
||||
CreateSkillPayloadTool,
|
||||
EvaluateSkillCandidateTool,
|
||||
GetExecutionHistoryTool,
|
||||
GetSkillPayloadTool,
|
||||
ListSkillCandidatesTool,
|
||||
ListSkillReleasesTool,
|
||||
PromoteSkillCandidateTool,
|
||||
RollbackSkillReleaseTool,
|
||||
RunBrowserSkillTool,
|
||||
SyncSkillReleaseTool,
|
||||
)
|
||||
from .util import check_admin_permission, normalize_umo_for_workspace
|
||||
|
||||
__all__ = [
|
||||
"AnnotateExecutionTool",
|
||||
"BrowserBatchExecTool",
|
||||
"BrowserExecTool",
|
||||
"CreateSkillCandidateTool",
|
||||
"CreateSkillPayloadTool",
|
||||
"EvaluateSkillCandidateTool",
|
||||
"ExecuteShellTool",
|
||||
"FileDownloadTool",
|
||||
"FileEditTool",
|
||||
"FileReadTool",
|
||||
"FileUploadTool",
|
||||
"FileWriteTool",
|
||||
"GetExecutionHistoryTool",
|
||||
"GetSkillPayloadTool",
|
||||
"GrepTool",
|
||||
"ListSkillCandidatesTool",
|
||||
"ListSkillReleasesTool",
|
||||
"LocalPythonTool",
|
||||
"PromoteSkillCandidateTool",
|
||||
"PythonTool",
|
||||
"RollbackSkillReleaseTool",
|
||||
"RunBrowserSkillTool",
|
||||
"SyncSkillReleaseTool",
|
||||
"normalize_umo_for_workspace",
|
||||
"check_admin_permission",
|
||||
]
|
||||
747
astrbot/core/tools/computer_tools/fs.py
Normal file
747
astrbot/core/tools/computer_tools/fs.py
Normal file
@@ -0,0 +1,747 @@
|
||||
"""Filesystem tool audit.
|
||||
|
||||
Tool exposure from the main agent:
|
||||
- Local runtime exposes `astrbot_read_file_tool`, `astrbot_file_write_tool`,
|
||||
`astrbot_file_edit_tool`, and `astrbot_grep_tool`.
|
||||
- Sandbox runtime exposes `astrbot_upload_file`, `astrbot_download_file`,
|
||||
`astrbot_read_file_tool`, `astrbot_file_write_tool`,
|
||||
`astrbot_file_edit_tool`, and `astrbot_grep_tool`.
|
||||
|
||||
Behavior when `provider_settings.computer_use_require_admin=True`:
|
||||
- Admin + local: read/write/edit/grep are not path-restricted by this module;
|
||||
access depends on the local runtime implementation and host OS permissions.
|
||||
Upload and download tools are defined here, but `LocalBooter` does not
|
||||
implement them and the main agent does not expose them in local mode.
|
||||
- Member + local: read/write/edit/grep are restricted to `data/skills`,
|
||||
`data/workspaces/{normalized_umo}`, and `/tmp/.astrbot`. Upload/download are
|
||||
denied by `check_admin_permission` if invoked.
|
||||
- Admin + sandbox: read/write/edit/grep are not path-restricted by this
|
||||
module;
|
||||
sandbox filesystem boundaries are enforced by the sandbox runtime. Upload and
|
||||
download are allowed.
|
||||
- Member + sandbox: read/write/edit/grep are also not path-restricted by this
|
||||
module. Upload/download are denied by `check_admin_permission` if invoked.
|
||||
|
||||
When `computer_use_require_admin=False`, member behavior in this module matches
|
||||
admin behavior.
|
||||
|
||||
Local path resolution rule:
|
||||
- In local runtime, relative paths are resolved under
|
||||
`data/workspaces/{normalized_umo}`.
|
||||
- In sandbox runtime, relative paths are passed through unchanged.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.api import FunctionTool, logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.computer.computer_client import get_booter
|
||||
from astrbot.core.computer.file_read_utils import read_file_tool_result
|
||||
from astrbot.core.message.components import File
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_skills_path,
|
||||
get_astrbot_system_tmp_path,
|
||||
get_astrbot_temp_path,
|
||||
)
|
||||
|
||||
from ..registry import builtin_tool
|
||||
from . import util as computer_util
|
||||
from .util import (
|
||||
check_admin_permission,
|
||||
is_local_runtime,
|
||||
normalize_umo_for_workspace,
|
||||
)
|
||||
|
||||
_COMPUTER_RUNTIME_TOOL_CONFIG = {
|
||||
"provider_settings.computer_use_runtime": ("local", "sandbox"),
|
||||
}
|
||||
_SANDBOX_RUNTIME_TOOL_CONFIG = {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
}
|
||||
|
||||
|
||||
def _restricted_env_path_labels(umo: str) -> list[str]:
|
||||
"""Labels for the allowed directories in a local(not sandbox) and restricted(not admin) environment"""
|
||||
normalized_umo = normalize_umo_for_workspace(umo)
|
||||
return [
|
||||
"data/skills",
|
||||
f"data/workspaces/{normalized_umo}",
|
||||
get_astrbot_system_tmp_path(),
|
||||
]
|
||||
|
||||
|
||||
def get_astrbot_workspaces_path() -> str:
|
||||
"""Compatibility wrapper for tests and older module-level monkeypatches."""
|
||||
return computer_util.get_astrbot_workspaces_path()
|
||||
|
||||
|
||||
def _workspace_root(umo: str) -> Path:
|
||||
"""Workspace root that follows both util-level and fs-level getter monkeypatches."""
|
||||
normalized_umo = normalize_umo_for_workspace(umo)
|
||||
return (Path(get_astrbot_workspaces_path()) / normalized_umo).resolve(strict=False)
|
||||
|
||||
|
||||
def _read_allowed_roots(umo: str) -> tuple[Path, ...]:
|
||||
"""Non-admin users can only read files within these directories (and their subdirectories)"""
|
||||
return (
|
||||
Path(get_astrbot_skills_path()).resolve(strict=False),
|
||||
_workspace_root(umo),
|
||||
Path(get_astrbot_system_tmp_path()).resolve(strict=False),
|
||||
)
|
||||
|
||||
|
||||
def _is_restricted_env(context: ContextWrapper[AstrAgentContext]) -> bool:
|
||||
if not is_local_runtime(context):
|
||||
return False
|
||||
cfg = context.context.context.get_config(
|
||||
umo=context.context.event.unified_msg_origin
|
||||
)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
require_admin = provider_settings.get("computer_use_require_admin", True)
|
||||
return require_admin and context.context.event.role != "admin"
|
||||
|
||||
|
||||
def _resolve_tool_path(path: str, *, local_env: bool, umo: str) -> str:
|
||||
normalized_path = path.strip()
|
||||
if not normalized_path:
|
||||
return normalized_path
|
||||
candidate = Path(normalized_path).expanduser()
|
||||
if candidate.is_absolute():
|
||||
return str(candidate.resolve(strict=False))
|
||||
if local_env:
|
||||
return str((_workspace_root(umo) / candidate).resolve(strict=False))
|
||||
return normalized_path
|
||||
|
||||
|
||||
def _resolve_user_path(path: str, *, local_env: bool, umo: str) -> Path:
|
||||
candidate = Path(path).expanduser()
|
||||
if candidate.is_absolute():
|
||||
return candidate.resolve(strict=False)
|
||||
if local_env:
|
||||
return (_workspace_root(umo) / candidate).resolve(strict=False)
|
||||
return (Path.cwd() / candidate).resolve(strict=False)
|
||||
|
||||
|
||||
def _is_path_within_allowed_roots(path: str, umo: str) -> bool:
|
||||
resolved = _resolve_user_path(path, local_env=True, umo=umo)
|
||||
return any(
|
||||
resolved == allowed_root or resolved.is_relative_to(allowed_root)
|
||||
for allowed_root in _read_allowed_roots(umo)
|
||||
)
|
||||
|
||||
|
||||
def _normalize_rw_path(
|
||||
path: str,
|
||||
*,
|
||||
restricted: bool,
|
||||
local_env: bool,
|
||||
umo: str,
|
||||
) -> str:
|
||||
normalized_path = _resolve_tool_path(path, local_env=local_env, umo=umo)
|
||||
if not normalized_path:
|
||||
raise ValueError("`path` must be a non-empty string.")
|
||||
if restricted and not _is_path_within_allowed_roots(normalized_path, umo):
|
||||
allowed = ", ".join(_restricted_env_path_labels(umo))
|
||||
raise PermissionError(
|
||||
"Read access is restricted for this user. "
|
||||
f"Allowed directories: {allowed}. Blocked path: {normalized_path}."
|
||||
)
|
||||
return normalized_path
|
||||
|
||||
|
||||
def _decode_escaped_text(value: str) -> str:
|
||||
"""Decode common escaped control sequences used in tool arguments."""
|
||||
return (
|
||||
value.replace("\\r\\n", "\n")
|
||||
.replace("\\n", "\n")
|
||||
.replace("\\r", "\r")
|
||||
.replace("\\t", "\t")
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class FileReadTool(FunctionTool):
|
||||
name: str = "astrbot_file_read_tool"
|
||||
description: str = "read file content."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path of the file to read. If relative, will be in workspace root.",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Optional line offset to start reading from. 0-based index.",
|
||||
"minimum": 0,
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Optional maximum number of lines to read.",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
)
|
||||
|
||||
def _validate_read_window(
|
||||
self,
|
||||
offset: int | None,
|
||||
limit: int | None,
|
||||
) -> tuple[int | None, int | None]:
|
||||
if offset is not None and offset < 0:
|
||||
raise ValueError("`offset` must be greater than or equal to 0.")
|
||||
if limit is not None and limit < 1:
|
||||
raise ValueError("`limit` must be greater than or equal to 1.")
|
||||
return offset, limit
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
path: str,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> ToolExecResult:
|
||||
local_env = is_local_runtime(context)
|
||||
restricted = _is_restricted_env(context)
|
||||
try:
|
||||
normalized_path = (
|
||||
_normalize_rw_path(
|
||||
path,
|
||||
restricted=restricted,
|
||||
local_env=local_env,
|
||||
umo=context.context.event.unified_msg_origin,
|
||||
)
|
||||
if local_env
|
||||
else path.strip()
|
||||
)
|
||||
if not normalized_path:
|
||||
raise ValueError("`path` must be a non-empty string.")
|
||||
offset, limit = self._validate_read_window(offset, limit)
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
return await read_file_tool_result(
|
||||
sb,
|
||||
local_mode=local_env,
|
||||
path=normalized_path,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
workspace_dir=(
|
||||
str(_workspace_root(context.context.event.unified_msg_origin))
|
||||
if local_env
|
||||
else None
|
||||
),
|
||||
)
|
||||
except PermissionError as exc:
|
||||
return f"Error: {exc}"
|
||||
except Exception as exc:
|
||||
logger.error(f"Error reading file: {exc}")
|
||||
return f"Error reading file: {exc}"
|
||||
|
||||
|
||||
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class FileWriteTool(FunctionTool):
|
||||
name: str = "astrbot_file_write_tool"
|
||||
description: str = "Write UTF-8 text content to a file."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path of the file to write. If relative, will be in workspace root.",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file",
|
||||
},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
path: str,
|
||||
content: str,
|
||||
) -> ToolExecResult:
|
||||
local_env = is_local_runtime(context)
|
||||
restricted = _is_restricted_env(context)
|
||||
try:
|
||||
normalized_path = (
|
||||
_normalize_rw_path(
|
||||
path,
|
||||
restricted=restricted,
|
||||
local_env=local_env,
|
||||
umo=context.context.event.unified_msg_origin,
|
||||
)
|
||||
if local_env
|
||||
else path.strip()
|
||||
)
|
||||
if not normalized_path:
|
||||
raise ValueError("`path` must be a non-empty string.")
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
result = await sb.fs.write_file(
|
||||
path=normalized_path,
|
||||
content=content,
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
)
|
||||
if not result.get("success", False):
|
||||
error_detail = str(result.get("error", "") or "").strip()
|
||||
return (
|
||||
"Error writing file: "
|
||||
f"{error_detail or 'unknown filesystem write error'}"
|
||||
)
|
||||
return f"File written successfully: {normalized_path}"
|
||||
except PermissionError as exc:
|
||||
return f"Error: {exc}"
|
||||
except Exception as exc:
|
||||
logger.error(f"Error writing file: {exc}")
|
||||
return f"Error writing file: {exc}"
|
||||
|
||||
|
||||
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class FileEditTool(FunctionTool):
|
||||
name: str = "astrbot_file_edit_tool"
|
||||
description: str = "Editing files."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path of the file to edit. If relative, will be in workspace root.",
|
||||
},
|
||||
"old": {
|
||||
"type": "string",
|
||||
"description": "The exact old text to replace.",
|
||||
},
|
||||
"new": {
|
||||
"type": "string",
|
||||
"description": "The replacement text.",
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to replace all matches. Defaults to false.",
|
||||
},
|
||||
},
|
||||
"required": ["path", "old", "new"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
path: str,
|
||||
old: str,
|
||||
new: str,
|
||||
replace_all: bool = False,
|
||||
) -> ToolExecResult:
|
||||
umo = str(context.context.event.unified_msg_origin)
|
||||
local_env = is_local_runtime(context)
|
||||
restricted = _is_restricted_env(context)
|
||||
try:
|
||||
normalized_path = (
|
||||
_normalize_rw_path(
|
||||
path,
|
||||
restricted=restricted,
|
||||
local_env=local_env,
|
||||
umo=umo,
|
||||
)
|
||||
if local_env
|
||||
else path.strip()
|
||||
)
|
||||
if not normalized_path:
|
||||
raise ValueError("`path` must be a non-empty string.")
|
||||
normalized_old = _decode_escaped_text(old)
|
||||
normalized_new = _decode_escaped_text(new)
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
result = await sb.fs.edit_file(
|
||||
path=normalized_path,
|
||||
old_string=normalized_old,
|
||||
new_string=normalized_new,
|
||||
replace_all=replace_all,
|
||||
encoding="utf-8",
|
||||
)
|
||||
if not result.get("success", False):
|
||||
error_detail = str(result.get("error", "") or "").strip()
|
||||
return (
|
||||
"Error editing file: "
|
||||
f"{error_detail or 'unknown filesystem edit error'}"
|
||||
)
|
||||
replacements = int(result.get("replacements", 0) or 0)
|
||||
mode_text = "all matches" if replace_all else "first match"
|
||||
return (
|
||||
f"Edited {normalized_path}. "
|
||||
f"Replaced {replacements} occurrence(s) using {mode_text} mode."
|
||||
)
|
||||
except PermissionError as exc:
|
||||
return f"Error: {exc}"
|
||||
except Exception as exc:
|
||||
logger.error(f"Error editing file: {exc}")
|
||||
return f"Error editing file: {exc}"
|
||||
|
||||
|
||||
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class GrepTool(FunctionTool):
|
||||
name: str = "astrbot_grep_tool"
|
||||
description: str = "Search and read file contents using ripgrep."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The expression pattern to search for in file contents.",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File or directory to search in (rg PATH). If relative, will be in workspace root.",
|
||||
},
|
||||
"glob": {
|
||||
"type": "string",
|
||||
"description": "Optional glob filter such as `*.py`, `*.{ts,tsx}`.",
|
||||
},
|
||||
"-A": {
|
||||
"type": "integer",
|
||||
"description": "Number of trailing context lines to include after each match.",
|
||||
"minimum": 0,
|
||||
},
|
||||
"-B": {
|
||||
"type": "integer",
|
||||
"description": "Number of leading context lines to include before each match.",
|
||||
"minimum": 0,
|
||||
},
|
||||
"-C": {
|
||||
"type": "integer",
|
||||
"description": "Number of leading and trailing context lines to include around each match.",
|
||||
"minimum": 0,
|
||||
},
|
||||
"result_limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of result groups returned by the tool. Defaults to 100.",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_context_options(
|
||||
self,
|
||||
after_context: int | None,
|
||||
before_context: int | None,
|
||||
context: int | None,
|
||||
) -> tuple[int | None, int | None]:
|
||||
if context is not None and context < 0:
|
||||
raise ValueError("`-C` must be greater than or equal to 0.")
|
||||
if after_context is not None and after_context < 0:
|
||||
raise ValueError("`-A` must be greater than or equal to 0.")
|
||||
if before_context is not None and before_context < 0:
|
||||
raise ValueError("`-B` must be greater than or equal to 0.")
|
||||
|
||||
resolved_after = context if after_context is None else after_context
|
||||
resolved_before = context if before_context is None else before_context
|
||||
return resolved_after, resolved_before
|
||||
|
||||
def _split_output_groups(self, output: str, *, has_context: bool) -> list[str]:
|
||||
if not output.strip():
|
||||
return []
|
||||
|
||||
if not has_context:
|
||||
return [f"{line}\n" for line in output.splitlines() if line.strip()]
|
||||
|
||||
groups: list[str] = []
|
||||
current: list[str] = []
|
||||
|
||||
for line in output.splitlines(keepends=True):
|
||||
if line.strip() == "--":
|
||||
if current:
|
||||
groups.append("".join(current))
|
||||
current = []
|
||||
continue
|
||||
if not line.strip():
|
||||
continue
|
||||
current.append(line)
|
||||
|
||||
if current:
|
||||
groups.append("".join(current))
|
||||
return groups
|
||||
|
||||
def _apply_result_limit(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
result_limit: int,
|
||||
has_context: bool,
|
||||
) -> str:
|
||||
if result_limit < 1:
|
||||
raise ValueError("`result_limit` must be greater than or equal to 1.")
|
||||
|
||||
groups = self._split_output_groups(output, has_context=has_context)
|
||||
if len(groups) <= result_limit:
|
||||
return output if output.strip() else "No matches found."
|
||||
|
||||
limited_output = "".join(groups[:result_limit]).rstrip()
|
||||
return f"{limited_output}\n\n[Truncated to first {result_limit} result groups.]"
|
||||
|
||||
def _normalize_search_paths(
|
||||
self,
|
||||
path: str | None,
|
||||
*,
|
||||
restricted: bool,
|
||||
local_env: bool,
|
||||
umo: str,
|
||||
) -> list[str]:
|
||||
normalized = (
|
||||
[_resolve_tool_path(path, local_env=local_env, umo=umo)] if path else []
|
||||
)
|
||||
if not normalized:
|
||||
if restricted:
|
||||
return [str(root) for root in _read_allowed_roots(umo)]
|
||||
if local_env:
|
||||
return [str(_workspace_root(umo))]
|
||||
return ["."]
|
||||
|
||||
if restricted:
|
||||
disallowed = [
|
||||
path
|
||||
for path in normalized
|
||||
if not _is_path_within_allowed_roots(path, umo)
|
||||
]
|
||||
if disallowed:
|
||||
allowed = ", ".join(_restricted_env_path_labels(umo))
|
||||
blocked = ", ".join(disallowed)
|
||||
raise PermissionError(
|
||||
"Read access is restricted for this user. "
|
||||
f"Allowed directories: {allowed}. Blocked paths: {blocked}."
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
pattern: str,
|
||||
path: str | None = None,
|
||||
glob: str | None = None,
|
||||
result_limit: int = 100,
|
||||
**kwargs,
|
||||
) -> ToolExecResult:
|
||||
normalized_pattern = pattern.strip()
|
||||
if not normalized_pattern:
|
||||
return "Error: `pattern` must be a non-empty string."
|
||||
|
||||
local_env = is_local_runtime(context)
|
||||
restricted = _is_restricted_env(context)
|
||||
try:
|
||||
search_paths = (
|
||||
self._normalize_search_paths(
|
||||
path,
|
||||
restricted=restricted,
|
||||
local_env=local_env,
|
||||
umo=context.context.event.unified_msg_origin,
|
||||
)
|
||||
if local_env
|
||||
else ([path.strip()] if path and path.strip() else ["."])
|
||||
)
|
||||
after_context, before_context = self._resolve_context_options(
|
||||
kwargs.get("-A"),
|
||||
kwargs.get("-B"),
|
||||
kwargs.get("-C"),
|
||||
)
|
||||
has_context = (after_context or 0) > 0 or (before_context or 0) > 0
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
contents: list[str] = []
|
||||
for search_path in search_paths:
|
||||
result = await sb.fs.search_files(
|
||||
pattern=normalized_pattern,
|
||||
path=search_path,
|
||||
glob=glob,
|
||||
after_context=after_context,
|
||||
before_context=before_context,
|
||||
)
|
||||
if not result.get("success", False):
|
||||
error_detail = str(result.get("error", "") or "").strip()
|
||||
logger.error("GrepTool search failed: %s", error_detail)
|
||||
return (
|
||||
"Error searching files: "
|
||||
f"{error_detail or 'unknown filesystem search error'}"
|
||||
)
|
||||
content = str(result.get("content", "") or "")
|
||||
if content:
|
||||
contents.append(content)
|
||||
|
||||
return self._apply_result_limit(
|
||||
"".join(contents),
|
||||
result_limit=result_limit,
|
||||
has_context=has_context,
|
||||
)
|
||||
except PermissionError as exc:
|
||||
return f"Error: {exc}"
|
||||
except Exception as exc:
|
||||
logger.error(f"Error searching files: {exc}")
|
||||
return f"Error searching files: {exc}"
|
||||
|
||||
|
||||
@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class FileUploadTool(FunctionTool):
|
||||
name: str = "astrbot_upload_file"
|
||||
description: str = (
|
||||
"Transfer a file FROM the host machine INTO the sandbox so that sandbox "
|
||||
"code can access it. Use this when the user sends/attaches a file and you "
|
||||
"need to process it inside the sandbox. The local_path must point to an "
|
||||
"existing file on the host filesystem."
|
||||
)
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"local_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file on the host filesystem that will be copied into the sandbox.",
|
||||
},
|
||||
# "remote_path": {
|
||||
# "type": "string",
|
||||
# "description": "The filename to use in the sandbox. If not provided, file will be saved to the working directory with the same name as the local file.",
|
||||
# },
|
||||
},
|
||||
"required": ["local_path"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
local_path: str,
|
||||
) -> str | None:
|
||||
if permission_error := check_admin_permission(context, "File upload/download"):
|
||||
return permission_error
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
try:
|
||||
# Check if file exists
|
||||
if not os.path.exists(local_path):
|
||||
return f"Error: File does not exist: {local_path}"
|
||||
|
||||
if not os.path.isfile(local_path):
|
||||
return f"Error: Path is not a file: {local_path}"
|
||||
|
||||
# Use basename if sandbox_filename is not provided
|
||||
remote_path = os.path.basename(local_path)
|
||||
|
||||
# Upload file to sandbox
|
||||
result = await sb.upload_file(local_path, remote_path)
|
||||
logger.debug(f"Upload result: {result}")
|
||||
success = result.get("success", False)
|
||||
|
||||
if not success:
|
||||
return f"Error uploading file: {result.get('message', 'Unknown error')}"
|
||||
|
||||
file_path = result.get("file_path", "")
|
||||
logger.info(f"File {local_path} uploaded to sandbox at {file_path}")
|
||||
|
||||
return f"File uploaded successfully to {file_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading file {local_path}: {e}")
|
||||
return f"Error uploading file: {str(e)}"
|
||||
|
||||
|
||||
@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class FileDownloadTool(FunctionTool):
|
||||
name: str = "astrbot_download_file"
|
||||
description: str = (
|
||||
"Transfer a file FROM the sandbox OUT to the host and optionally send it "
|
||||
"to the user. Use this ONLY when the user asks to retrieve/export a file "
|
||||
"that was created or modified inside the sandbox."
|
||||
)
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"remote_path": {
|
||||
"type": "string",
|
||||
"description": "Path of the file inside the sandbox to copy out to the host.",
|
||||
},
|
||||
"also_send_to_user": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to also send the downloaded file to the user via message. Defaults to true.",
|
||||
},
|
||||
},
|
||||
"required": ["remote_path"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
remote_path: str,
|
||||
also_send_to_user: bool = True,
|
||||
) -> ToolExecResult:
|
||||
if permission_error := check_admin_permission(context, "File upload/download"):
|
||||
return permission_error
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
try:
|
||||
name = os.path.basename(remote_path)
|
||||
|
||||
local_path = os.path.join(
|
||||
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
|
||||
)
|
||||
|
||||
# Download file from sandbox
|
||||
await sb.download_file(remote_path, local_path)
|
||||
logger.info(f"File {remote_path} downloaded from sandbox to {local_path}")
|
||||
|
||||
if also_send_to_user:
|
||||
try:
|
||||
name = os.path.basename(local_path)
|
||||
await context.context.event.send(
|
||||
MessageChain(chain=[File(name=name, file=local_path)])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending file message: {e}")
|
||||
|
||||
# remove
|
||||
# try:
|
||||
# os.remove(local_path)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error removing temp file {local_path}: {e}")
|
||||
|
||||
return f"File downloaded successfully to {local_path} and sent to user."
|
||||
|
||||
return f"File downloaded successfully to {local_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading file {remote_path}: {e}")
|
||||
return f"Error downloading file: {str(e)}"
|
||||
@@ -8,10 +8,18 @@ from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent
|
||||
from astrbot.core.computer.computer_client import get_booter, get_local_booter
|
||||
from astrbot.core.computer.tools.permissions import check_admin_permission
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from ..registry import builtin_tool
|
||||
from .util import check_admin_permission
|
||||
|
||||
_OS_NAME = platform.system()
|
||||
_SANDBOX_PYTHON_TOOL_CONFIG = {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
}
|
||||
_LOCAL_PYTHON_TOOL_CONFIG = {
|
||||
"provider_settings.computer_use_runtime": "local",
|
||||
}
|
||||
|
||||
param_schema = {
|
||||
"type": "object",
|
||||
@@ -61,6 +69,7 @@ async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult
|
||||
return resp
|
||||
|
||||
|
||||
@builtin_tool(config=_SANDBOX_PYTHON_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class PythonTool(FunctionTool):
|
||||
name: str = "astrbot_execute_ipython"
|
||||
@@ -83,6 +92,7 @@ class PythonTool(FunctionTool):
|
||||
return f"Error executing code: {str(e)}"
|
||||
|
||||
|
||||
@builtin_tool(config=_LOCAL_PYTHON_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class LocalPythonTool(FunctionTool):
|
||||
name: str = "astrbot_execute_python"
|
||||
@@ -5,11 +5,17 @@ from astrbot.api import FunctionTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.computer.computer_client import get_booter
|
||||
|
||||
from ..computer_client import get_booter, get_local_booter
|
||||
from .permissions import check_admin_permission
|
||||
from ..registry import builtin_tool
|
||||
from .util import check_admin_permission, is_local_runtime, workspace_root
|
||||
|
||||
_COMPUTER_RUNTIME_TOOL_CONFIG = {
|
||||
"provider_settings.computer_use_runtime": ("local", "sandbox"),
|
||||
}
|
||||
|
||||
|
||||
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class ExecuteShellTool(FunctionTool):
|
||||
name: str = "astrbot_execute_shell"
|
||||
@@ -38,8 +44,6 @@ class ExecuteShellTool(FunctionTool):
|
||||
}
|
||||
)
|
||||
|
||||
is_local: bool = False
|
||||
|
||||
async def call(
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
@@ -50,15 +54,25 @@ class ExecuteShellTool(FunctionTool):
|
||||
if permission_error := check_admin_permission(context, "Shell execution"):
|
||||
return permission_error
|
||||
|
||||
if self.is_local:
|
||||
sb = get_local_booter()
|
||||
else:
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
)
|
||||
try:
|
||||
result = await sb.shell.exec(command, background=background, env=env)
|
||||
cwd: str | None = None
|
||||
if is_local_runtime(context):
|
||||
current_workspace_root = workspace_root(
|
||||
context.context.event.unified_msg_origin
|
||||
)
|
||||
current_workspace_root.mkdir(parents=True, exist_ok=True)
|
||||
cwd = str(current_workspace_root)
|
||||
|
||||
result = await sb.shell.exec(
|
||||
command,
|
||||
cwd=cwd,
|
||||
background=background,
|
||||
env=env,
|
||||
)
|
||||
return json.dumps(result)
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
@@ -1,5 +1,4 @@
|
||||
from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool
|
||||
from .fs import FileDownloadTool, FileUploadTool
|
||||
from .neo_skills import (
|
||||
AnnotateExecutionTool,
|
||||
CreateSkillCandidateTool,
|
||||
@@ -13,27 +12,20 @@ from .neo_skills import (
|
||||
RollbackSkillReleaseTool,
|
||||
SyncSkillReleaseTool,
|
||||
)
|
||||
from .python import LocalPythonTool, PythonTool
|
||||
from .shell import ExecuteShellTool
|
||||
|
||||
__all__ = [
|
||||
"BrowserExecTool",
|
||||
"BrowserBatchExecTool",
|
||||
"RunBrowserSkillTool",
|
||||
"GetExecutionHistoryTool",
|
||||
"AnnotateExecutionTool",
|
||||
"CreateSkillPayloadTool",
|
||||
"GetSkillPayloadTool",
|
||||
"BrowserBatchExecTool",
|
||||
"BrowserExecTool",
|
||||
"CreateSkillCandidateTool",
|
||||
"ListSkillCandidatesTool",
|
||||
"CreateSkillPayloadTool",
|
||||
"EvaluateSkillCandidateTool",
|
||||
"PromoteSkillCandidateTool",
|
||||
"GetExecutionHistoryTool",
|
||||
"GetSkillPayloadTool",
|
||||
"ListSkillCandidatesTool",
|
||||
"ListSkillReleasesTool",
|
||||
"PromoteSkillCandidateTool",
|
||||
"RollbackSkillReleaseTool",
|
||||
"RunBrowserSkillTool",
|
||||
"SyncSkillReleaseTool",
|
||||
"FileUploadTool",
|
||||
"PythonTool",
|
||||
"LocalPythonTool",
|
||||
"ExecuteShellTool",
|
||||
"FileDownloadTool",
|
||||
]
|
||||
@@ -6,9 +6,14 @@ from astrbot.api import FunctionTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.computer.computer_client import get_booter
|
||||
from astrbot.core.tools.computer_tools.util import check_admin_permission
|
||||
from astrbot.core.tools.registry import builtin_tool
|
||||
|
||||
from ..computer_client import get_booter
|
||||
from .permissions import check_admin_permission
|
||||
_SHIPYARD_NEO_TOOL_CONFIG = {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard_neo",
|
||||
}
|
||||
|
||||
|
||||
def _to_json(data: Any) -> str:
|
||||
@@ -29,6 +34,7 @@ async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> A
|
||||
return browser
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class BrowserExecTool(FunctionTool):
|
||||
name: str = "astrbot_execute_browser"
|
||||
@@ -86,6 +92,7 @@ class BrowserExecTool(FunctionTool):
|
||||
return f"Error executing browser command: {str(e)}"
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class BrowserBatchExecTool(FunctionTool):
|
||||
name: str = "astrbot_execute_browser_batch"
|
||||
@@ -150,6 +157,7 @@ class BrowserBatchExecTool(FunctionTool):
|
||||
return f"Error executing browser batch command: {str(e)}"
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class RunBrowserSkillTool(FunctionTool):
|
||||
name: str = "astrbot_run_browser_skill"
|
||||
@@ -7,10 +7,15 @@ from astrbot.api import FunctionTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.computer.computer_client import get_booter
|
||||
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
|
||||
from astrbot.core.tools.computer_tools.util import check_admin_permission
|
||||
from astrbot.core.tools.registry import builtin_tool
|
||||
|
||||
from ..computer_client import get_booter
|
||||
from .permissions import check_admin_permission
|
||||
_SHIPYARD_NEO_TOOL_CONFIG = {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard_neo",
|
||||
}
|
||||
|
||||
|
||||
def _to_jsonable(model_like: Any) -> Any:
|
||||
@@ -64,6 +69,7 @@ class NeoSkillToolBase(FunctionTool):
|
||||
return f"{self.error_prefix} {error_action}: {str(e)}"
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class GetExecutionHistoryTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_get_execution_history"
|
||||
@@ -110,6 +116,7 @@ class GetExecutionHistoryTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class AnnotateExecutionTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_annotate_execution"
|
||||
@@ -147,6 +154,7 @@ class AnnotateExecutionTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class CreateSkillPayloadTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_create_skill_payload"
|
||||
@@ -194,6 +202,7 @@ class CreateSkillPayloadTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class GetSkillPayloadTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_get_skill_payload"
|
||||
@@ -220,6 +229,7 @@ class GetSkillPayloadTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class CreateSkillCandidateTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_create_skill_candidate"
|
||||
@@ -273,6 +283,7 @@ class CreateSkillCandidateTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class ListSkillCandidatesTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_list_skill_candidates"
|
||||
@@ -310,6 +321,7 @@ class ListSkillCandidatesTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class EvaluateSkillCandidateTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_evaluate_skill_candidate"
|
||||
@@ -350,6 +362,7 @@ class EvaluateSkillCandidateTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class PromoteSkillCandidateTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_promote_skill_candidate"
|
||||
@@ -420,6 +433,7 @@ class PromoteSkillCandidateTool(NeoSkillToolBase):
|
||||
return f"Error promoting skill candidate: {str(e)}"
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class ListSkillReleasesTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_list_skill_releases"
|
||||
@@ -460,6 +474,7 @@ class ListSkillReleasesTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class RollbackSkillReleaseTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_rollback_skill_release"
|
||||
@@ -486,6 +501,7 @@ class RollbackSkillReleaseTool(NeoSkillToolBase):
|
||||
)
|
||||
|
||||
|
||||
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class SyncSkillReleaseTool(NeoSkillToolBase):
|
||||
name: str = "astrbot_sync_skill_release"
|
||||
@@ -1,5 +1,29 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_workspaces_path
|
||||
|
||||
|
||||
def normalize_umo_for_workspace(umo: str) -> str:
|
||||
normalized = re.sub(r"[^A-Za-z0-9._-]+", "_", umo.strip())
|
||||
return normalized or "unknown"
|
||||
|
||||
|
||||
def workspace_root(umo: str) -> Path:
|
||||
"""Root directory for relative paths in local runtime"""
|
||||
normalized_umo = normalize_umo_for_workspace(umo)
|
||||
return (Path(get_astrbot_workspaces_path()) / normalized_umo).resolve(strict=False)
|
||||
|
||||
|
||||
def is_local_runtime(context: ContextWrapper[AstrAgentContext]) -> bool:
|
||||
cfg = context.context.context.get_config(
|
||||
umo=context.context.event.unified_msg_origin
|
||||
)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
runtime = str(provider_settings.get("computer_use_runtime", "local"))
|
||||
return runtime == "local"
|
||||
|
||||
|
||||
def check_admin_permission(
|
||||
@@ -9,6 +9,10 @@ from astrbot.core.agent.tool import FunctionTool, ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.tools.registry import builtin_tool
|
||||
|
||||
_CRON_TOOL_CONFIG = {
|
||||
"provider_settings.proactive_capability.add_cron_tools": True,
|
||||
}
|
||||
|
||||
|
||||
def _extract_job_session(job: Any) -> str | None:
|
||||
payload = getattr(job, "payload", None)
|
||||
@@ -24,7 +28,7 @@ def _parse_run_at(run_at: Any) -> datetime | None:
|
||||
return datetime.fromisoformat(str(run_at))
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@builtin_tool(config=_CRON_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class FutureTaskTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "future_task"
|
||||
|
||||
@@ -9,6 +9,10 @@ from astrbot.core.knowledge_base.kb_helper import KBHelper
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.tools.registry import builtin_tool
|
||||
|
||||
_KNOWLEDGE_BASE_TOOL_CONFIG = {
|
||||
"kb_agentic_mode": True,
|
||||
}
|
||||
|
||||
|
||||
def check_all_kb(kb_list: list[KBHelper | None]) -> bool:
|
||||
"""检查是否所有的知识库都为空"""
|
||||
@@ -83,7 +87,7 @@ async def retrieve_knowledge_base(
|
||||
return None
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@builtin_tool(config=_KNOWLEDGE_BASE_TOOL_CONFIG)
|
||||
@dataclass
|
||||
class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "astr_kb_search"
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from importlib import import_module
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
|
||||
TFunctionTool = TypeVar("TFunctionTool", bound=type[FunctionTool])
|
||||
|
||||
_BUILTIN_TOOL_MODULES = (
|
||||
"astrbot.core.tools.computer_tools",
|
||||
"astrbot.core.tools.cron_tools",
|
||||
"astrbot.core.tools.knowledge_base_tools",
|
||||
"astrbot.core.tools.message_tools",
|
||||
@@ -17,6 +20,182 @@ _BUILTIN_TOOL_MODULES = (
|
||||
_builtin_tool_classes_by_name: dict[str, type[FunctionTool]] = {}
|
||||
_builtin_tool_names_by_class: dict[type[FunctionTool], str] = {}
|
||||
_builtin_tools_loaded = False
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BuiltinToolConfigCondition:
|
||||
key: str
|
||||
operator: str
|
||||
expected: Any = None
|
||||
message: str | None = None
|
||||
|
||||
def evaluate(self, config: dict[str, Any]) -> dict[str, Any]:
|
||||
actual = _get_config_value(config, self.key)
|
||||
|
||||
if self.operator == "equals":
|
||||
matched = actual == self.expected
|
||||
elif self.operator == "in":
|
||||
expected_values = tuple(self.expected or ())
|
||||
matched = actual in expected_values
|
||||
elif self.operator == "truthy":
|
||||
matched = bool(actual)
|
||||
elif self.operator == "custom":
|
||||
matched = bool(self.expected)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported builtin tool config operator: {self.operator}"
|
||||
)
|
||||
|
||||
return {
|
||||
"key": self.key,
|
||||
"operator": self.operator,
|
||||
"expected": _json_safe(self.expected),
|
||||
"actual": _json_safe(None if actual is _MISSING else actual),
|
||||
"matched": matched,
|
||||
"message": self.message,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BuiltinToolConfigRule:
|
||||
conditions: tuple[BuiltinToolConfigCondition, ...] = ()
|
||||
evaluator: Callable[[dict[str, Any]], list[dict[str, Any]]] | None = None
|
||||
|
||||
def evaluate(self, config: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
if self.evaluator is not None:
|
||||
return self.evaluator(config)
|
||||
return [condition.evaluate(config) for condition in self.conditions]
|
||||
|
||||
|
||||
def _get_config_value(config: dict[str, Any], key_path: str) -> Any:
|
||||
current: Any = config
|
||||
for segment in key_path.split("."):
|
||||
if not isinstance(current, dict) or segment not in current:
|
||||
return _MISSING
|
||||
current = current[segment]
|
||||
return current
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
|
||||
if isinstance(value, tuple):
|
||||
return [_json_safe(item) for item in value]
|
||||
if isinstance(value, list):
|
||||
return [_json_safe(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: _json_safe(val) for key, val in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def _equals(key: str, expected: Any) -> BuiltinToolConfigCondition:
|
||||
return BuiltinToolConfigCondition(key=key, operator="equals", expected=expected)
|
||||
|
||||
|
||||
def _in(key: str, expected: tuple[Any, ...]) -> BuiltinToolConfigCondition:
|
||||
return BuiltinToolConfigCondition(key=key, operator="in", expected=expected)
|
||||
|
||||
|
||||
def _custom_condition(key: str, *, matched: bool, message: str) -> dict[str, Any]:
|
||||
return {
|
||||
"key": key,
|
||||
"operator": "custom",
|
||||
"expected": None,
|
||||
"actual": None,
|
||||
"matched": matched,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def _build_rule_from_config_map(
|
||||
config_map: dict[str, Any],
|
||||
) -> BuiltinToolConfigRule:
|
||||
conditions: list[BuiltinToolConfigCondition] = []
|
||||
for key, expected in config_map.items():
|
||||
if isinstance(expected, tuple):
|
||||
conditions.append(_in(key, expected))
|
||||
else:
|
||||
conditions.append(_equals(key, expected))
|
||||
return BuiltinToolConfigRule(conditions=tuple(conditions))
|
||||
|
||||
|
||||
def _evaluate_send_message_tool(config: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
platform_configs = config.get("platform", [])
|
||||
if not isinstance(platform_configs, list):
|
||||
return [
|
||||
_custom_condition(
|
||||
"platform",
|
||||
matched=False,
|
||||
message="No enabled platform in this config supports proactive messaging.",
|
||||
)
|
||||
]
|
||||
|
||||
for platform_cfg in platform_configs:
|
||||
if not isinstance(platform_cfg, dict):
|
||||
continue
|
||||
if platform_cfg.get("enable", False) is False:
|
||||
continue
|
||||
|
||||
platform_type = str(platform_cfg.get("type", "")).strip()
|
||||
platform_id = str(platform_cfg.get("id", "")).strip() or platform_type
|
||||
if not platform_type:
|
||||
continue
|
||||
|
||||
if platform_type in {"wecom", "weixin_official_account"}:
|
||||
continue
|
||||
|
||||
if platform_type == "wecom_ai_bot":
|
||||
webhook = str(platform_cfg.get("msg_push_webhook_url", "")).strip()
|
||||
if not webhook:
|
||||
continue
|
||||
return [
|
||||
_custom_condition(
|
||||
"platform[].type",
|
||||
matched=True,
|
||||
message=(
|
||||
f"Enabled platform `{platform_id}` uses `wecom_ai_bot`, which supports proactive messaging "
|
||||
"when `platform[].msg_push_webhook_url` is configured."
|
||||
),
|
||||
),
|
||||
BuiltinToolConfigCondition(
|
||||
key="platform[].msg_push_webhook_url",
|
||||
operator="truthy",
|
||||
).evaluate({"platform[]": {"msg_push_webhook_url": webhook}}),
|
||||
]
|
||||
|
||||
return [
|
||||
_custom_condition(
|
||||
"platform[].type",
|
||||
matched=True,
|
||||
message=(
|
||||
f"Enabled platform `{platform_id}` (`{platform_type}`) supports proactive messaging."
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
return [
|
||||
_custom_condition(
|
||||
"platform",
|
||||
matched=False,
|
||||
message="No enabled platform in this config supports proactive messaging.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
_BUILTIN_TOOL_CONFIG_RULES: dict[str, BuiltinToolConfigRule] = {}
|
||||
|
||||
|
||||
def _register_builtin_tool_config_rule(
|
||||
tool_names: tuple[str, ...],
|
||||
rule: BuiltinToolConfigRule,
|
||||
) -> None:
|
||||
for tool_name in tool_names:
|
||||
_BUILTIN_TOOL_CONFIG_RULES[tool_name] = rule
|
||||
|
||||
|
||||
_register_builtin_tool_config_rule(
|
||||
("send_message_to_user",),
|
||||
BuiltinToolConfigRule(evaluator=_evaluate_send_message_tool),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_builtin_tool_name(tool_cls: type[FunctionTool]) -> str:
|
||||
@@ -34,18 +213,29 @@ def _resolve_builtin_tool_name(tool_cls: type[FunctionTool]) -> str:
|
||||
)
|
||||
|
||||
|
||||
def builtin_tool(tool_cls: TFunctionTool) -> TFunctionTool:
|
||||
tool_name = _resolve_builtin_tool_name(tool_cls)
|
||||
existing = _builtin_tool_classes_by_name.get(tool_name)
|
||||
if existing is not None and existing is not tool_cls:
|
||||
raise ValueError(
|
||||
f"Builtin tool name conflict detected: {tool_name} is already registered by "
|
||||
f"{existing.__module__}.{existing.__name__}.",
|
||||
)
|
||||
def builtin_tool(
|
||||
tool_cls: TFunctionTool | None = None,
|
||||
*,
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> TFunctionTool | Callable[[TFunctionTool], TFunctionTool]:
|
||||
def _register(cls: TFunctionTool) -> TFunctionTool:
|
||||
tool_name = _resolve_builtin_tool_name(cls)
|
||||
existing = _builtin_tool_classes_by_name.get(tool_name)
|
||||
if existing is not None and existing is not cls:
|
||||
raise ValueError(
|
||||
f"Builtin tool name conflict detected: {tool_name} is already registered by "
|
||||
f"{existing.__module__}.{existing.__name__}.",
|
||||
)
|
||||
|
||||
_builtin_tool_classes_by_name[tool_name] = tool_cls
|
||||
_builtin_tool_names_by_class[tool_cls] = tool_name
|
||||
return tool_cls
|
||||
_builtin_tool_classes_by_name[tool_name] = cls
|
||||
_builtin_tool_names_by_class[cls] = tool_name
|
||||
if config is not None:
|
||||
_BUILTIN_TOOL_CONFIG_RULES[tool_name] = _build_rule_from_config_map(config)
|
||||
return cls
|
||||
|
||||
if tool_cls is None:
|
||||
return _register
|
||||
return _register(tool_cls)
|
||||
|
||||
|
||||
def ensure_builtin_tools_loaded() -> None:
|
||||
@@ -74,9 +264,64 @@ def iter_builtin_tool_classes() -> tuple[type[FunctionTool], ...]:
|
||||
return tuple(_builtin_tool_classes_by_name.values())
|
||||
|
||||
|
||||
def get_builtin_tool_config_rule(name: str) -> BuiltinToolConfigRule | None:
|
||||
ensure_builtin_tools_loaded()
|
||||
return _BUILTIN_TOOL_CONFIG_RULES.get(name)
|
||||
|
||||
|
||||
def get_builtin_tool_config_statuses(
|
||||
tool_name: str,
|
||||
config_entries: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
rule = get_builtin_tool_config_rule(tool_name)
|
||||
if rule is None:
|
||||
return []
|
||||
|
||||
statuses: list[dict[str, Any]] = []
|
||||
for entry in config_entries:
|
||||
config = entry.get("config")
|
||||
if not isinstance(config, dict):
|
||||
continue
|
||||
|
||||
conditions = rule.evaluate(config)
|
||||
enabled = bool(conditions) and all(
|
||||
bool(condition.get("matched")) for condition in conditions
|
||||
)
|
||||
statuses.append(
|
||||
{
|
||||
"conf_id": entry.get("conf_id"),
|
||||
"conf_name": entry.get("conf_name"),
|
||||
"enabled": enabled,
|
||||
"matched_conditions": [
|
||||
condition for condition in conditions if condition.get("matched")
|
||||
],
|
||||
"failed_conditions": [
|
||||
condition
|
||||
for condition in conditions
|
||||
if not condition.get("matched")
|
||||
],
|
||||
}
|
||||
)
|
||||
return statuses
|
||||
|
||||
|
||||
def get_builtin_tool_config_tags(
|
||||
tool_name: str,
|
||||
config_entries: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
return [
|
||||
status
|
||||
for status in get_builtin_tool_config_statuses(tool_name, config_entries)
|
||||
if status["enabled"]
|
||||
]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"builtin_tool",
|
||||
"ensure_builtin_tools_loaded",
|
||||
"get_builtin_tool_config_rule",
|
||||
"get_builtin_tool_config_statuses",
|
||||
"get_builtin_tool_config_tags",
|
||||
"get_builtin_tool_class",
|
||||
"get_builtin_tool_name",
|
||||
"iter_builtin_tool_classes",
|
||||
|
||||
@@ -20,6 +20,22 @@ WEB_SEARCH_TOOL_NAMES = [
|
||||
"web_search_bocha",
|
||||
"web_search_brave",
|
||||
]
|
||||
_TAVILY_WEB_SEARCH_TOOL_CONFIG = {
|
||||
"provider_settings.web_search": True,
|
||||
"provider_settings.websearch_provider": "tavily",
|
||||
}
|
||||
_BOCHA_WEB_SEARCH_TOOL_CONFIG = {
|
||||
"provider_settings.web_search": True,
|
||||
"provider_settings.websearch_provider": "bocha",
|
||||
}
|
||||
_BRAVE_WEB_SEARCH_TOOL_CONFIG = {
|
||||
"provider_settings.web_search": True,
|
||||
"provider_settings.websearch_provider": "brave",
|
||||
}
|
||||
_BAIDU_WEB_SEARCH_TOOL_CONFIG = {
|
||||
"provider_settings.web_search": True,
|
||||
"provider_settings.websearch_provider": "baidu_ai_search",
|
||||
}
|
||||
|
||||
|
||||
@std_dataclass
|
||||
@@ -276,7 +292,7 @@ async def _baidu_search(
|
||||
]
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG)
|
||||
@pydantic_dataclass
|
||||
class TavilyWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "web_search_tavily"
|
||||
@@ -359,7 +375,7 @@ class TavilyWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
return _search_result_payload(results)
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG)
|
||||
@pydantic_dataclass
|
||||
class TavilyExtractWebPageTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "tavily_extract_web_page"
|
||||
@@ -406,7 +422,7 @@ class TavilyExtractWebPageTool(FunctionTool[AstrAgentContext]):
|
||||
return ret or "Error: Tavily web searcher does not return any results."
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@builtin_tool(config=_BOCHA_WEB_SEARCH_TOOL_CONFIG)
|
||||
@pydantic_dataclass
|
||||
class BochaWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "web_search_bocha"
|
||||
@@ -470,7 +486,7 @@ class BochaWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
return _search_result_payload(results)
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@builtin_tool(config=_BRAVE_WEB_SEARCH_TOOL_CONFIG)
|
||||
@pydantic_dataclass
|
||||
class BraveWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "web_search_brave"
|
||||
@@ -528,7 +544,7 @@ class BraveWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
return _search_result_payload(results)
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@builtin_tool(config=_BAIDU_WEB_SEARCH_TOOL_CONFIG)
|
||||
@pydantic_dataclass
|
||||
class BaiduWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "web_search_baidu"
|
||||
|
||||
@@ -1,32 +1,33 @@
|
||||
"""Astrbot统一路径获取
|
||||
"""Centralized AstrBot path helpers.
|
||||
|
||||
项目路径:固定为源码所在路径
|
||||
根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
插件数据目录路径:固定为数据目录下的 plugin_data 目录
|
||||
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
|
||||
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
|
||||
临时文件目录路径:固定为数据目录下的 temp 目录
|
||||
Skills 目录路径:固定为数据目录下的 skills 目录
|
||||
第三方依赖目录路径:固定为数据目录下的 site-packages 目录
|
||||
Project path:
|
||||
- Fixed to the source tree location.
|
||||
|
||||
Root path:
|
||||
- Defaults to the current working directory.
|
||||
- Can be overridden with the ``ASTRBOT_ROOT`` environment variable.
|
||||
|
||||
Data subdirectories:
|
||||
- Most runtime data lives under ``<root>/data``.
|
||||
- A few tool-runtime files intentionally live under the system temporary
|
||||
directory as ``.astrbot``.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
|
||||
|
||||
|
||||
def get_astrbot_path() -> str:
|
||||
"""获取Astrbot项目路径"""
|
||||
"""Return the AstrBot project source path."""
|
||||
return os.path.realpath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../"),
|
||||
)
|
||||
|
||||
|
||||
def get_astrbot_root() -> str:
|
||||
"""获取Astrbot根目录路径"""
|
||||
"""Return the AstrBot root directory."""
|
||||
if path := os.environ.get("ASTRBOT_ROOT"):
|
||||
return os.path.realpath(path)
|
||||
if is_packaged_desktop_runtime():
|
||||
@@ -35,55 +36,65 @@ def get_astrbot_root() -> str:
|
||||
|
||||
|
||||
def get_astrbot_data_path() -> str:
|
||||
"""获取Astrbot数据目录路径"""
|
||||
"""Return the AstrBot data directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_root(), "data"))
|
||||
|
||||
|
||||
def get_astrbot_config_path() -> str:
|
||||
"""获取Astrbot配置文件路径"""
|
||||
"""Return the AstrBot config directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "config"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
"""Return the AstrBot plugin directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_data_path() -> str:
|
||||
"""获取Astrbot插件数据目录路径"""
|
||||
"""Return the AstrBot plugin data directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
|
||||
|
||||
|
||||
def get_astrbot_t2i_templates_path() -> str:
|
||||
"""获取Astrbot T2I 模板目录路径"""
|
||||
"""Return the AstrBot T2I templates directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
|
||||
|
||||
|
||||
def get_astrbot_webchat_path() -> str:
|
||||
"""获取Astrbot WebChat 数据目录路径"""
|
||||
"""Return the AstrBot WebChat data directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
|
||||
|
||||
|
||||
def get_astrbot_temp_path() -> str:
|
||||
"""获取Astrbot临时文件目录路径"""
|
||||
"""Return the AstrBot temporary data directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
|
||||
|
||||
|
||||
def get_astrbot_skills_path() -> str:
|
||||
"""获取Astrbot Skills 目录路径"""
|
||||
"""Return the AstrBot skills directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "skills"))
|
||||
|
||||
|
||||
def get_astrbot_workspaces_path() -> str:
|
||||
"""Return the AstrBot workspaces directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "workspaces"))
|
||||
|
||||
|
||||
def get_astrbot_system_tmp_path() -> str:
|
||||
"""Return the shared system temporary directory used by local tools."""
|
||||
return os.path.realpath(os.path.join(tempfile.gettempdir(), ".astrbot"))
|
||||
|
||||
|
||||
def get_astrbot_site_packages_path() -> str:
|
||||
"""获取Astrbot第三方依赖目录路径"""
|
||||
"""Return the AstrBot third-party site-packages directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "site-packages"))
|
||||
|
||||
|
||||
def get_astrbot_knowledge_base_path() -> str:
|
||||
"""获取Astrbot知识库根目录路径"""
|
||||
"""Return the AstrBot knowledge base root path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
|
||||
|
||||
|
||||
def get_astrbot_backups_path() -> str:
|
||||
"""获取Astrbot备份目录路径"""
|
||||
"""Return the AstrBot backups directory path."""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))
|
||||
|
||||
@@ -6,6 +6,7 @@ from astrbot.core import logger
|
||||
from astrbot.core.agent.mcp_client import MCPTool
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.star import star_map
|
||||
from astrbot.core.tools.registry import get_builtin_tool_config_statuses
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -434,13 +435,36 @@ class ToolsRoute(Route):
|
||||
if tool.name not in existing_names:
|
||||
tools.append(tool)
|
||||
|
||||
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
|
||||
conf_name_map = {conf["id"]: conf["name"] for conf in conf_list}
|
||||
config_entries = []
|
||||
for conf_id, conf in self.core_lifecycle.astrbot_config_mgr.confs.items():
|
||||
config_entries.append(
|
||||
{
|
||||
"conf_id": conf_id,
|
||||
"conf_name": conf_name_map.get(conf_id, conf_id),
|
||||
"config": conf,
|
||||
}
|
||||
)
|
||||
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
readonly = False
|
||||
builtin_config_statuses = []
|
||||
builtin_config_tags = []
|
||||
if self.tool_mgr.is_builtin_tool(tool.name):
|
||||
origin = "builtin"
|
||||
origin_name = "AstrBot Core"
|
||||
readonly = True
|
||||
builtin_config_statuses = get_builtin_tool_config_statuses(
|
||||
tool.name,
|
||||
config_entries,
|
||||
)
|
||||
builtin_config_tags = [
|
||||
status
|
||||
for status in builtin_config_statuses
|
||||
if status["enabled"]
|
||||
]
|
||||
elif isinstance(tool, MCPTool):
|
||||
origin = "mcp"
|
||||
origin_name = tool.mcp_server_name
|
||||
@@ -462,6 +486,8 @@ class ToolsRoute(Route):
|
||||
"origin": origin,
|
||||
"origin_name": origin_name,
|
||||
"readonly": readonly,
|
||||
"builtin_config_statuses": builtin_config_statuses,
|
||||
"builtin_config_tags": builtin_config_tags,
|
||||
}
|
||||
tools_dict.append(tool_info)
|
||||
return Response().ok(data=tools_dict).__dict__
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import type { ToolItem } from '../types';
|
||||
import type { BuiltinToolConfigTag, ToolConfigCondition, ToolItem } from '../types';
|
||||
|
||||
const { tm: tmTool } = useModuleI18n('features/tooluse');
|
||||
|
||||
@@ -15,7 +15,7 @@ const emit = defineEmits<{
|
||||
}>();
|
||||
|
||||
const toolHeaders = computed(() => [
|
||||
{ title: tmTool('functionTools.title'), key: 'name', minWidth: '240px' },
|
||||
{ title: tmTool('functionTools.title'), key: 'name', minWidth: '320px' },
|
||||
{ title: tmTool('functionTools.description'), key: 'description' },
|
||||
{ title: tmTool('functionTools.table.origin'), key: 'origin', sortable: false, width: '120px' },
|
||||
{ title: tmTool('functionTools.table.originName'), key: 'origin_name', sortable: false, width: '160px' },
|
||||
@@ -23,6 +23,52 @@ const toolHeaders = computed(() => [
|
||||
]);
|
||||
|
||||
const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.properties || {});
|
||||
|
||||
const formatConfigValue = (value: unknown) => {
|
||||
if (Array.isArray(value)) {
|
||||
return value.map(item => String(item)).join(', ');
|
||||
}
|
||||
if (typeof value === 'boolean') {
|
||||
return value ? 'true' : 'false';
|
||||
}
|
||||
if (value === null || value === undefined || value === '') {
|
||||
return '-';
|
||||
}
|
||||
return String(value);
|
||||
};
|
||||
|
||||
const formatCondition = (condition: ToolConfigCondition) => {
|
||||
if (condition.message) {
|
||||
return condition.message;
|
||||
}
|
||||
|
||||
switch (condition.operator) {
|
||||
case 'truthy':
|
||||
return tmTool('functionTools.configTags.conditions.truthy', {
|
||||
key: condition.key
|
||||
});
|
||||
case 'equals':
|
||||
return tmTool('functionTools.configTags.conditions.equals', {
|
||||
key: condition.key,
|
||||
expected: formatConfigValue(condition.expected)
|
||||
});
|
||||
case 'in':
|
||||
return tmTool('functionTools.configTags.conditions.in', {
|
||||
key: condition.key,
|
||||
expected: formatConfigValue(condition.expected)
|
||||
});
|
||||
default:
|
||||
return tmTool('functionTools.configTags.conditions.fallback', {
|
||||
key: condition.key,
|
||||
actual: formatConfigValue(condition.actual)
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const enabledConfigTags = (tool: ToolItem): BuiltinToolConfigTag[] => {
|
||||
if (tool.origin !== 'builtin') return [];
|
||||
return (tool.builtin_config_tags || []).filter(tag => tag.enabled);
|
||||
};
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -38,7 +84,39 @@ const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.pro
|
||||
>
|
||||
<template #item.name="{ item }">
|
||||
<div class="py-2">
|
||||
<div class="tool-name text-body-2 font-weight-medium">{{ item.name }}</div>
|
||||
<div class="d-flex flex-wrap align-center ga-1">
|
||||
<div class="tool-name text-body-2 font-weight-medium">{{ item.name }}</div>
|
||||
<v-tooltip
|
||||
v-for="tag in enabledConfigTags(item)"
|
||||
:key="`${item.name}-${tag.conf_id}`"
|
||||
location="top"
|
||||
>
|
||||
<template #activator="{ props: tooltipProps }">
|
||||
<v-chip
|
||||
v-bind="tooltipProps"
|
||||
size="x-small"
|
||||
variant="tonal"
|
||||
color="secondary"
|
||||
class="text-caption font-weight-medium"
|
||||
>
|
||||
{{ tag.conf_name }}
|
||||
</v-chip>
|
||||
</template>
|
||||
|
||||
<div class="tool-config-tooltip">
|
||||
<div class="text-body-2 font-weight-medium mb-2">
|
||||
{{ tmTool('functionTools.configTags.tooltipTitle', { config: tag.conf_name }) }}
|
||||
</div>
|
||||
<div
|
||||
v-for="(condition, index) in tag.matched_conditions"
|
||||
:key="`${tag.conf_id}-${index}-${condition.key}`"
|
||||
class="text-body-2 text-medium-emphasis mb-1"
|
||||
>
|
||||
{{ formatCondition(condition) }}
|
||||
</div>
|
||||
</div>
|
||||
</v-tooltip>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -135,4 +213,16 @@ const parameterEntries = (tool: ToolItem) => Object.entries(tool.parameters?.pro
|
||||
font-size: 0.9rem;
|
||||
line-height: 1.35;
|
||||
}
|
||||
|
||||
.tool-config-tooltip {
|
||||
max-width: 360px;
|
||||
padding: 4px 0;
|
||||
color: rgba(255, 255, 255, 0.92);
|
||||
}
|
||||
|
||||
.tool-config-tooltip :deep(.text-body-2),
|
||||
.tool-config-tooltip :deep(.text-medium-emphasis),
|
||||
.tool-config-tooltip :deep(.font-weight-medium) {
|
||||
color: inherit !important;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -89,6 +89,23 @@ export interface ToolParameter {
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export interface ToolConfigCondition {
|
||||
key: string;
|
||||
operator: 'truthy' | 'equals' | 'in' | 'custom' | string;
|
||||
expected?: unknown;
|
||||
actual?: unknown;
|
||||
matched: boolean;
|
||||
message?: string | null;
|
||||
}
|
||||
|
||||
export interface BuiltinToolConfigTag {
|
||||
conf_id: string;
|
||||
conf_name: string;
|
||||
enabled: boolean;
|
||||
matched_conditions: ToolConfigCondition[];
|
||||
failed_conditions: ToolConfigCondition[];
|
||||
}
|
||||
|
||||
/** MCP/函数工具对象 */
|
||||
export interface ToolItem {
|
||||
name: string;
|
||||
@@ -100,4 +117,6 @@ export interface ToolItem {
|
||||
};
|
||||
origin?: string;
|
||||
origin_name?: string;
|
||||
builtin_config_statuses?: BuiltinToolConfigTag[];
|
||||
builtin_config_tags?: BuiltinToolConfigTag[];
|
||||
}
|
||||
|
||||
@@ -90,31 +90,52 @@
|
||||
<div v-if="filteredTools.length > 0" class="tools-selection">
|
||||
<v-virtual-scroll :items="filteredTools" height="300" item-height="72">
|
||||
<template v-slot:default="{ item }">
|
||||
<v-list-item :key="item.name" density="comfortable"
|
||||
@click="toggleTool(item.name)">
|
||||
<template v-slot:prepend>
|
||||
<v-checkbox-btn :model-value="isToolSelected(item.name)"
|
||||
@click.stop="toggleTool(item.name)" />
|
||||
<v-tooltip
|
||||
:disabled="!isBuiltinTool(item)"
|
||||
location="top"
|
||||
>
|
||||
<template v-slot:activator="{ props: tooltipProps }">
|
||||
<div v-bind="tooltipProps">
|
||||
<v-list-item
|
||||
:key="item.name"
|
||||
density="comfortable"
|
||||
:disabled="isBuiltinTool(item)"
|
||||
@click="toggleTool(item.name)"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-checkbox-btn
|
||||
v-if="!isBuiltinTool(item)"
|
||||
:model-value="isToolSelected(item.name)"
|
||||
@click.stop="toggleTool(item.name)"
|
||||
/>
|
||||
<div
|
||||
v-else
|
||||
class="builtin-tool-checkbox-placeholder"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<v-list-item-title>
|
||||
{{ item.name }}
|
||||
|
||||
<v-chip v-if="item.origin" size="x-small" color="info" class="mr-2"
|
||||
variant="tonal">
|
||||
{{ item.origin }}
|
||||
</v-chip>
|
||||
<v-chip v-if="item.origin_name" size="x-small" color="info"
|
||||
variant="outlined">
|
||||
{{ item.origin_name }}
|
||||
</v-chip>
|
||||
|
||||
</v-list-item-title>
|
||||
|
||||
<v-list-item-subtitle v-if="item.description">
|
||||
{{ truncateText(item.description, 100) }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<v-list-item-title>
|
||||
{{ item.name }}
|
||||
|
||||
<v-chip v-if="item.origin" size="x-small" color="info" class="mr-2"
|
||||
variant="tonal">
|
||||
{{ item.origin }}
|
||||
</v-chip>
|
||||
<v-chip v-if="item.origin_name" size="x-small" color="info"
|
||||
variant="outlined">
|
||||
{{ item.origin_name }}
|
||||
</v-chip>
|
||||
|
||||
</v-list-item-title>
|
||||
|
||||
<v-list-item-subtitle v-if="item.description">
|
||||
{{ truncateText(item.description, 100) }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
<span>{{ tm('form.builtinToolDisabledHint') }}</span>
|
||||
</v-tooltip>
|
||||
</template>
|
||||
</v-virtual-scroll>
|
||||
</div>
|
||||
@@ -155,11 +176,26 @@
|
||||
</h4>
|
||||
<div v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
|
||||
class="d-flex flex-wrap ga-1" style="max-height: 100px; overflow-y: auto;">
|
||||
<v-chip v-for="toolName in personaForm.tools" :key="toolName" size="small"
|
||||
color="primary" variant="tonal" closable
|
||||
@click:close="removeTool(toolName)">
|
||||
{{ toolName }}
|
||||
</v-chip>
|
||||
<v-tooltip
|
||||
v-for="toolName in personaForm.tools"
|
||||
:key="toolName"
|
||||
:disabled="!isBuiltinToolName(toolName)"
|
||||
location="top"
|
||||
>
|
||||
<template v-slot:activator="{ props: tooltipProps }">
|
||||
<v-chip
|
||||
v-bind="tooltipProps"
|
||||
size="small"
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
:closable="!isBuiltinToolName(toolName)"
|
||||
@click:close="removeTool(toolName)"
|
||||
>
|
||||
{{ toolName }}
|
||||
</v-chip>
|
||||
</template>
|
||||
<span>{{ tm('form.builtinToolDisabledHint') }}</span>
|
||||
</v-tooltip>
|
||||
</div>
|
||||
<div v-else class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.noToolsSelected') }}
|
||||
@@ -712,6 +748,9 @@ export default {
|
||||
},
|
||||
|
||||
toggleTool(toolName) {
|
||||
if (this.isBuiltinToolName(toolName)) {
|
||||
return;
|
||||
}
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 如果是全选状态,点击某个工具表示要取消选择该工具
|
||||
@@ -735,6 +774,9 @@ export default {
|
||||
},
|
||||
|
||||
removeTool(toolName) {
|
||||
if (this.isBuiltinToolName(toolName)) {
|
||||
return;
|
||||
}
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 创建一个包含所有工具的数组,然后移除指定工具
|
||||
@@ -784,6 +826,14 @@ export default {
|
||||
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
|
||||
},
|
||||
|
||||
isBuiltinTool(tool) {
|
||||
return tool?.origin === 'builtin' || tool?.readonly === true;
|
||||
},
|
||||
|
||||
isBuiltinToolName(toolName) {
|
||||
return this.availableTools.some(tool => tool.name === toolName && this.isBuiltinTool(tool));
|
||||
},
|
||||
|
||||
getDialogRules(index) {
|
||||
const dialogType = index % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
|
||||
return [
|
||||
@@ -859,6 +909,12 @@ export default {
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.builtin-tool-checkbox-placeholder {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
flex: 0 0 40px;
|
||||
}
|
||||
|
||||
.skills-selection {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
|
||||
@@ -117,7 +117,9 @@ const defaultPersonaData = {
|
||||
|
||||
const normalizedTools = computed(() => (Array.isArray(personaData.value?.tools) ? personaData.value.tools : []))
|
||||
const normalizedSkills = computed(() => (Array.isArray(personaData.value?.skills) ? personaData.value.skills : []))
|
||||
const allToolsCount = computed(() => Object.keys(toolMetaMap.value).length)
|
||||
const allToolsCount = computed(() =>
|
||||
Object.values(toolMetaMap.value).filter((tool) => tool.origin !== 'builtin').length
|
||||
)
|
||||
const allSkillsCount = computed(() => availableSkills.value.length)
|
||||
const resolvedTools = computed(() =>
|
||||
normalizedTools.value.map((toolName) => {
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
"mcpServersQuickSelect": "MCP Servers Quick Select",
|
||||
"searchTools": "Search Tools",
|
||||
"selectedTools": "Selected Tools",
|
||||
"builtinToolDisabledHint": "Builtin tools cannot be enabled or disabled here yet. Please enable or disable the corresponding config items in the config file.",
|
||||
"noToolsAvailable": "No tools available",
|
||||
"noToolsFound": "No matching tools found",
|
||||
"loadingTools": "Loading tools...",
|
||||
|
||||
@@ -47,6 +47,15 @@
|
||||
"originName": "Origin Name",
|
||||
"readonly": "Read-only",
|
||||
"actions": "Actions"
|
||||
},
|
||||
"configTags": {
|
||||
"tooltipTitle": "This tool is enabled in config file {config} because:",
|
||||
"conditions": {
|
||||
"truthy": "{key} is enabled",
|
||||
"equals": "{key} = {expected}",
|
||||
"in": "{key} matched {expected}",
|
||||
"fallback": "{key} is currently {actual}"
|
||||
}
|
||||
}
|
||||
},
|
||||
"marketplace": {
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
"mcpServersQuickSelect": "Быстрый выбор MCP серверов",
|
||||
"searchTools": "Поиск инструментов",
|
||||
"selectedTools": "Выбранные инструменты",
|
||||
"builtinToolDisabledHint": "Встроенные инструменты пока нельзя включать или выключать здесь. Измените соответствующие параметры в файле конфигурации.",
|
||||
"noToolsAvailable": "Нет доступных инструментов",
|
||||
"noToolsFound": "Инструменты не найдены",
|
||||
"loadingTools": "Загрузка инструментов...",
|
||||
@@ -143,4 +144,4 @@
|
||||
"success": "Объект перемещен",
|
||||
"error": "Ошибка перемещения"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,15 @@
|
||||
"originName": "Имя источника",
|
||||
"readonly": "Только чтение",
|
||||
"actions": "Действия"
|
||||
},
|
||||
"configTags": {
|
||||
"tooltipTitle": "Этот инструмент включен в файле конфигурации {config}, потому что:",
|
||||
"conditions": {
|
||||
"truthy": "{key} включен",
|
||||
"equals": "{key} = {expected}",
|
||||
"in": "{key} соответствует {expected}",
|
||||
"fallback": "Текущее значение {key}: {actual}"
|
||||
}
|
||||
}
|
||||
},
|
||||
"marketplace": {
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
"mcpServersQuickSelect": "MCP 服务器快速选择",
|
||||
"searchTools": "搜索工具",
|
||||
"selectedTools": "已选择的工具",
|
||||
"builtinToolDisabledHint": "暂不支持在这里启用和停用内置工具,请在配置文件中启用和停用工具对应的配置项。",
|
||||
"noToolsAvailable": "暂无可用工具",
|
||||
"noToolsFound": "未找到匹配的工具",
|
||||
"loadingTools": "正在加载工具...",
|
||||
|
||||
@@ -47,6 +47,15 @@
|
||||
"originName": "来源名称",
|
||||
"readonly": "只读",
|
||||
"actions": "操作"
|
||||
},
|
||||
"configTags": {
|
||||
"tooltipTitle": "该工具在配置文件 {config} 中启用,因为:",
|
||||
"conditions": {
|
||||
"truthy": "启用了 {key}",
|
||||
"equals": "{key} = {expected}",
|
||||
"in": "{key} 命中了 {expected}",
|
||||
"fallback": "{key} 当前值为 {actual}"
|
||||
}
|
||||
}
|
||||
},
|
||||
"marketplace": {
|
||||
|
||||
@@ -15,13 +15,11 @@ dependencies = [
|
||||
"aiosqlite>=0.21.0",
|
||||
"anthropic>=0.51.0",
|
||||
"apscheduler>=3.11.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
"certifi>=2025.4.26",
|
||||
"chardet~=5.1.0",
|
||||
"loguru>=0.7.2",
|
||||
"cryptography>=44.0.3",
|
||||
"dashscope>=1.23.2",
|
||||
"defusedxml>=0.7.1",
|
||||
"deprecated>=1.2.18",
|
||||
"dingtalk-stream>=0.22.1",
|
||||
"docstring-parser>=0.16",
|
||||
@@ -30,7 +28,6 @@ dependencies = [
|
||||
"google-genai>=1.56.0",
|
||||
"httpx[socks]>=0.28.1",
|
||||
"lark-oapi>=1.4.15",
|
||||
"lxml-html-clean>=0.4.2",
|
||||
"mcp>=1.8.0",
|
||||
"openai>=1.78.0",
|
||||
"ormsgpack>=1.9.1",
|
||||
@@ -45,7 +42,6 @@ dependencies = [
|
||||
"python-telegram-bot>=22.6",
|
||||
"qq-botpy>=1.2.1",
|
||||
"quart>=0.20.0",
|
||||
"readability-lxml>=0.8.4.1",
|
||||
"silk-python>=0.2.6",
|
||||
"slack-sdk>=3.35.0",
|
||||
"sqlalchemy[asyncio]>=2.0.41",
|
||||
@@ -68,6 +64,7 @@ dependencies = [
|
||||
"python-socks>=2.8.0",
|
||||
"pysocks>=1.7.1",
|
||||
"packaging>=24.2",
|
||||
"python-ripgrep==0.0.9",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@@ -4,13 +4,11 @@ aiohttp>=3.11.18
|
||||
aiosqlite>=0.21.0
|
||||
anthropic>=0.51.0
|
||||
apscheduler>=3.11.0
|
||||
beautifulsoup4>=4.13.4
|
||||
certifi>=2025.4.26
|
||||
chardet~=5.1.0
|
||||
loguru>=0.7.2
|
||||
cryptography>=44.0.3
|
||||
dashscope>=1.23.2
|
||||
defusedxml>=0.7.1
|
||||
deprecated>=1.2.18
|
||||
dingtalk-stream>=0.22.1
|
||||
docstring-parser>=0.16
|
||||
@@ -19,7 +17,6 @@ filelock>=3.18.0
|
||||
google-genai>=1.56.0
|
||||
httpx[socks]>=0.28.1
|
||||
lark-oapi>=1.4.15
|
||||
lxml-html-clean>=0.4.2
|
||||
mcp>=1.8.0
|
||||
openai>=1.78.0
|
||||
ormsgpack>=1.9.1
|
||||
@@ -34,7 +31,6 @@ python-telegram-bot>=22.6
|
||||
qq-botpy>=1.2.1
|
||||
python-socks>=2.8.0
|
||||
quart>=0.20.0
|
||||
readability-lxml>=0.8.4.1
|
||||
silk-python>=0.2.6
|
||||
slack-sdk>=3.35.0
|
||||
sqlalchemy[asyncio]>=2.0.41
|
||||
@@ -56,4 +52,5 @@ tenacity>=9.1.2
|
||||
shipyard-python-sdk>=0.2.4
|
||||
shipyard-neo-sdk>=0.2.0
|
||||
packaging>=24.2
|
||||
qrcode>=8.2
|
||||
qrcode>=8.2
|
||||
python-ripgrep==0.0.9
|
||||
283
tests/test_computer_fs_tools.py
Normal file
283
tests/test_computer_fs_tools.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import zipfile
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from mcp.types import CallToolResult, ImageContent
|
||||
from PIL import Image
|
||||
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.computer import file_read_utils
|
||||
from astrbot.core.computer.booters.local import LocalBooter
|
||||
from astrbot.core.tools.computer_tools import fs as fs_tools
|
||||
from astrbot.core.tools.computer_tools import util as computer_util
|
||||
|
||||
|
||||
def _make_context(
|
||||
*,
|
||||
require_admin: bool = True,
|
||||
role: str = "admin",
|
||||
runtime: str = "local",
|
||||
umo: str = "qq:friend:user-1",
|
||||
) -> ContextWrapper:
|
||||
config_holder = SimpleNamespace(
|
||||
get_config=lambda umo=None: {
|
||||
"provider_settings": {
|
||||
"computer_use_require_admin": require_admin,
|
||||
"computer_use_runtime": runtime,
|
||||
}
|
||||
}
|
||||
)
|
||||
event = SimpleNamespace(
|
||||
role=role,
|
||||
unified_msg_origin=umo,
|
||||
get_sender_id=lambda: "user-1",
|
||||
)
|
||||
astr_ctx = SimpleNamespace(context=config_holder, event=event)
|
||||
return ContextWrapper(context=astr_ctx)
|
||||
|
||||
|
||||
def _setup_local_fs_tools(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
*,
|
||||
umo: str = "qq:friend:user-1",
|
||||
) -> Any:
|
||||
workspaces_root = tmp_path / "workspaces"
|
||||
skills_root = tmp_path / "skills"
|
||||
temp_root = tmp_path / "temp"
|
||||
workspaces_root.mkdir()
|
||||
skills_root.mkdir()
|
||||
temp_root.mkdir()
|
||||
|
||||
monkeypatch.setattr(
|
||||
computer_util,
|
||||
"get_astrbot_workspaces_path",
|
||||
lambda: str(workspaces_root),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
fs_tools,
|
||||
"get_astrbot_skills_path",
|
||||
lambda: str(skills_root),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
fs_tools,
|
||||
"get_astrbot_temp_path",
|
||||
lambda: str(temp_root),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
file_read_utils,
|
||||
"get_astrbot_temp_path",
|
||||
lambda: str(temp_root),
|
||||
)
|
||||
|
||||
booter = LocalBooter()
|
||||
|
||||
async def _fake_get_booter(_ctx, _umo):
|
||||
return booter
|
||||
|
||||
monkeypatch.setattr(fs_tools, "get_booter", _fake_get_booter)
|
||||
|
||||
normalized_umo = computer_util.normalize_umo_for_workspace(umo)
|
||||
workspace = workspaces_root / normalized_umo
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
return workspace
|
||||
|
||||
|
||||
def _make_large_text() -> str:
|
||||
return "".join(f"line-{index:05d}-{'x' * 48}\n" for index in range(6000))
|
||||
|
||||
|
||||
def test_detect_text_encoding_allows_utf8_probe_cut_mid_character():
|
||||
sample = '{"results": ["中文内容"]}'.encode()[:-1]
|
||||
|
||||
assert file_read_utils.detect_text_encoding(sample) in {"utf-8", "utf-8-sig"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_rejects_large_full_text_read_before_local_stream_read(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
|
||||
large_file = workspace / "large.txt"
|
||||
large_file.write_text(_make_large_text(), encoding="utf-8")
|
||||
|
||||
async def _unexpected_read(*args, **kwargs):
|
||||
raise AssertionError("full file read should be rejected before streaming")
|
||||
|
||||
monkeypatch.setattr(file_read_utils, "read_local_text_range", _unexpected_read)
|
||||
|
||||
result = await fs_tools.FileReadTool().call(
|
||||
_make_context(),
|
||||
path="large.txt",
|
||||
)
|
||||
|
||||
assert "text file exceeds 262144 bytes" in result
|
||||
assert "Use `offset` and `limit`" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_allows_partial_read_for_large_text_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
|
||||
large_file = workspace / "large.txt"
|
||||
lines = [f"line-{index:05d}\n" for index in range(50000)]
|
||||
large_file.write_text("".join(lines), encoding="utf-8")
|
||||
|
||||
result = await fs_tools.FileReadTool().call(
|
||||
_make_context(),
|
||||
path="large.txt",
|
||||
offset=1000,
|
||||
limit=3,
|
||||
)
|
||||
|
||||
assert result == "".join(lines[1000:1003])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_returns_image_call_tool_result_for_images(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
|
||||
image_path = workspace / "sample.png"
|
||||
Image.new("RGB", (32, 16), color=(255, 0, 0)).save(image_path, format="PNG")
|
||||
|
||||
result = await fs_tools.FileReadTool().call(
|
||||
_make_context(),
|
||||
path="sample.png",
|
||||
)
|
||||
|
||||
assert isinstance(result, CallToolResult)
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], ImageContent)
|
||||
assert result.content[0].mimeType == "image/jpeg"
|
||||
assert base64.b64decode(result.content[0].data).startswith(b"\xff\xd8\xff")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_treats_svg_as_text(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
|
||||
svg_path = workspace / "shape.svg"
|
||||
svg_text = (
|
||||
"<svg xmlns='http://www.w3.org/2000/svg'><rect width='10' height='10'/></svg>"
|
||||
)
|
||||
svg_path.write_text(svg_text, encoding="utf-8")
|
||||
|
||||
result = await fs_tools.FileReadTool().call(
|
||||
_make_context(),
|
||||
path="shape.svg",
|
||||
)
|
||||
|
||||
assert result == svg_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_reads_pdf_via_parser(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
|
||||
pdf_path = workspace / "doc.pdf"
|
||||
pdf_path.write_bytes(b"%PDF-1.7\n%\xe2\xe3\xcf\xd3\n1 0 obj\n<<>>\nendobj\n")
|
||||
|
||||
async def _fake_parse_pdf(_file_bytes: bytes, _file_name: str) -> str:
|
||||
return "page-1\npage-2\n"
|
||||
|
||||
monkeypatch.setattr(file_read_utils, "_parse_local_pdf_text", _fake_parse_pdf)
|
||||
|
||||
result = await fs_tools.FileReadTool().call(
|
||||
_make_context(),
|
||||
path="doc.pdf",
|
||||
)
|
||||
|
||||
assert result == "page-1\npage-2\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_reads_docx_via_parser_and_magic(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
|
||||
docx_path = workspace / "report.bin"
|
||||
buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(buffer, mode="w") as archive:
|
||||
archive.writestr("[Content_Types].xml", "<Types/>")
|
||||
archive.writestr("word/document.xml", "<w:document/>")
|
||||
docx_path.write_bytes(buffer.getvalue())
|
||||
|
||||
async def _fake_parse_docx(_file_bytes: bytes, _file_name: str) -> str:
|
||||
return "doc-line-1\ndoc-line-2\n"
|
||||
|
||||
monkeypatch.setattr(file_read_utils, "_parse_local_docx_text", _fake_parse_docx)
|
||||
|
||||
result = await fs_tools.FileReadTool().call(
|
||||
_make_context(),
|
||||
path="report.bin",
|
||||
)
|
||||
|
||||
assert result == "doc-line-1\ndoc-line-2\n"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_read_tool_stores_long_converted_document_in_workspace(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
|
||||
pdf_path = workspace / "manual.pdf"
|
||||
pdf_path.write_bytes(b"%PDF-1.7\nfake\n")
|
||||
long_text = _make_large_text()
|
||||
|
||||
async def _fake_parse_pdf(_file_bytes: bytes, _file_name: str) -> str:
|
||||
return long_text
|
||||
|
||||
monkeypatch.setattr(file_read_utils, "_parse_local_pdf_text", _fake_parse_pdf)
|
||||
|
||||
result = await fs_tools.FileReadTool().call(
|
||||
_make_context(),
|
||||
path="manual.pdf",
|
||||
)
|
||||
|
||||
converted_root = workspace / "converted_files"
|
||||
converted_files = list(converted_root.glob("manual.pdf_*/text.txt"))
|
||||
assert len(converted_files) == 1
|
||||
assert converted_files[0].read_text(encoding="utf-8") == long_text
|
||||
assert str(converted_files[0]) in result
|
||||
assert "Read or grep that file with a narrow window." in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grep_tool_applies_result_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
workspace = _setup_local_fs_tools(monkeypatch, tmp_path)
|
||||
text_path = workspace / "grep.txt"
|
||||
text_path.write_text(
|
||||
"match-1\nmatch-2\nmatch-3\nmatch-4\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
result = await fs_tools.GrepTool().call(
|
||||
_make_context(),
|
||||
pattern="match",
|
||||
path="grep.txt",
|
||||
result_limit=2,
|
||||
)
|
||||
|
||||
assert "match-1" in result
|
||||
assert "match-2" in result
|
||||
assert "match-3" not in result
|
||||
assert "[Truncated to first 2 result groups.]" in result
|
||||
@@ -4,8 +4,10 @@ from types import SimpleNamespace
|
||||
import pytest
|
||||
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.computer.tools.browser import BrowserExecTool
|
||||
from astrbot.core.computer.tools.neo_skills import GetExecutionHistoryTool
|
||||
from astrbot.core.tools.computer_tools.shipyard_neo.browser import BrowserExecTool
|
||||
from astrbot.core.tools.computer_tools.shipyard_neo.neo_skills import (
|
||||
GetExecutionHistoryTool,
|
||||
)
|
||||
|
||||
|
||||
class _FakeBrowser:
|
||||
@@ -49,7 +51,7 @@ async def test_browser_tool_allows_non_admin_when_admin_requirement_disabled(
|
||||
return SimpleNamespace(browser=_FakeBrowser())
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.tools.browser.get_booter",
|
||||
"astrbot.core.tools.computer_tools.shipyard_neo.browser.get_booter",
|
||||
_fake_get_booter,
|
||||
)
|
||||
|
||||
@@ -72,7 +74,7 @@ async def test_neo_skill_tool_allows_non_admin_when_admin_requirement_disabled(
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.tools.neo_skills.get_booter",
|
||||
"astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.get_booter",
|
||||
_fake_get_booter,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ from astrbot.core.computer.booters.local import LocalFileSystemComponent
|
||||
|
||||
def _allow_tmp_root(monkeypatch, tmp_path: Path) -> None:
|
||||
monkeypatch.setattr(local_booter, "get_astrbot_root", lambda: str(tmp_path))
|
||||
monkeypatch.setattr(local_booter, "get_astrbot_data_path", lambda: str(tmp_path))
|
||||
monkeypatch.setattr(local_booter, "get_astrbot_temp_path", lambda: str(tmp_path))
|
||||
|
||||
|
||||
def test_local_file_system_component_prefers_utf8_before_windows_locale(
|
||||
@@ -27,7 +25,7 @@ def test_local_file_system_component_prefers_utf8_before_windows_locale(
|
||||
|
||||
skill_path = tmp_path / "skills" / "demo.txt"
|
||||
skill_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
skill_path.write_bytes("技能内容".encode("utf-8"))
|
||||
skill_path.write_bytes("技能内容".encode())
|
||||
|
||||
result = asyncio.run(LocalFileSystemComponent().read_file(str(skill_path)))
|
||||
|
||||
|
||||
@@ -4,7 +4,9 @@ import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.computer.tools.neo_skills import PromoteSkillCandidateTool
|
||||
from astrbot.core.tools.computer_tools.shipyard_neo.neo_skills import (
|
||||
PromoteSkillCandidateTool,
|
||||
)
|
||||
|
||||
|
||||
class _FakeSkills:
|
||||
@@ -46,11 +48,11 @@ def test_promote_stable_sync_failure_auto_rolls_back(monkeypatch):
|
||||
raise ValueError("sync failed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.tools.neo_skills.get_booter",
|
||||
"astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.get_booter",
|
||||
_fake_get_booter,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.computer.tools.neo_skills.NeoSkillSyncManager.sync_release",
|
||||
"astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.NeoSkillSyncManager.sync_release",
|
||||
_fake_sync_release,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import AsyncMock
|
||||
@@ -97,6 +98,26 @@ class MockToolExecutor:
|
||||
return generator()
|
||||
|
||||
|
||||
class LargeTextToolExecutor:
|
||||
"""模拟返回超长文本的工具执行器"""
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, text: str) -> "LargeTextToolExecutor":
|
||||
return cls(text)
|
||||
|
||||
def execute(self, tool, run_context, **tool_args):
|
||||
async def generator():
|
||||
from mcp.types import CallToolResult, TextContent
|
||||
|
||||
result = CallToolResult(content=[TextContent(type="text", text=self.text)])
|
||||
yield result
|
||||
|
||||
return generator()
|
||||
|
||||
|
||||
class MockMixedContentToolExecutor:
|
||||
"""模拟返回图片 + 文本的工具执行器"""
|
||||
|
||||
@@ -193,6 +214,32 @@ class MockToolCallProvider(MockProvider):
|
||||
)
|
||||
|
||||
|
||||
class SingleToolThenFinalProvider(MockProvider):
|
||||
def __init__(self, tool_name: str, tool_args: dict[str, str] | None = None):
|
||||
super().__init__()
|
||||
self.tool_name = tool_name
|
||||
self.tool_args = tool_args or {}
|
||||
|
||||
async def text_chat(self, **kwargs) -> LLMResponse:
|
||||
self.call_count += 1
|
||||
func_tool = kwargs.get("func_tool")
|
||||
if func_tool is None or self.call_count > 1:
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="最终回复",
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="",
|
||||
tools_call_name=[self.tool_name],
|
||||
tools_call_args=[self.tool_args],
|
||||
tools_call_ids=["call_large_result"],
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
|
||||
|
||||
class SequentialToolProvider(MockProvider):
|
||||
def __init__(self, tool_sequence: list[str]):
|
||||
super().__init__()
|
||||
@@ -334,6 +381,10 @@ def runner():
|
||||
return ToolLoopAgentRunner()
|
||||
|
||||
|
||||
def _make_large_tool_result_text() -> str:
|
||||
return "x" * 100000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_step_limit_functionality(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
@@ -1124,18 +1175,116 @@ async def test_follow_up_accepted_when_active_and_not_stopping(
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# Runner is active (not done) and stop is not requested
|
||||
assert not runner.done()
|
||||
assert runner._is_stop_requested() is False
|
||||
|
||||
ticket = runner.follow_up(message_text="valid follow-up message")
|
||||
|
||||
assert ticket is not None, (
|
||||
"Follow-up should be accepted when runner is active and not stopping"
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_tool_result_is_spilled_to_file_and_replaced_with_read_notice(
|
||||
tmp_path,
|
||||
):
|
||||
tool = FunctionTool(
|
||||
name="test_tool",
|
||||
description="测试工具",
|
||||
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
|
||||
handler=AsyncMock(),
|
||||
)
|
||||
assert ticket.text == "valid follow-up message"
|
||||
assert ticket.consumed is False
|
||||
assert ticket in runner._pending_follow_ups
|
||||
read_tool = FunctionTool(
|
||||
name="astrbot_file_read_tool",
|
||||
description="read file",
|
||||
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
|
||||
handler=AsyncMock(),
|
||||
)
|
||||
tool_set = ToolSet(tools=[tool, read_tool])
|
||||
provider = SingleToolThenFinalProvider(tool.name, {"query": "large"})
|
||||
request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[])
|
||||
runner = ToolLoopAgentRunner()
|
||||
|
||||
await runner.reset(
|
||||
provider=provider,
|
||||
request=request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=cast(
|
||||
Any,
|
||||
LargeTextToolExecutor.from_text(_make_large_tool_result_text()),
|
||||
),
|
||||
agent_hooks=MockHooks(),
|
||||
streaming=False,
|
||||
tool_result_overflow_dir=str(tmp_path),
|
||||
read_tool=read_tool,
|
||||
)
|
||||
|
||||
responses = []
|
||||
async for response in runner.step_until_done(3):
|
||||
responses.append(response)
|
||||
|
||||
tool_messages = [m for m in runner.run_context.messages if m.role == "tool"]
|
||||
assert len(tool_messages) == 1
|
||||
tool_message_content = str(tool_messages[0].content)
|
||||
assert "xxxxxxxxxx" in tool_message_content
|
||||
assert "Truncated tool output preview shown above." in tool_message_content
|
||||
assert "The tool output was too large to include directly" in tool_message_content
|
||||
assert "`astrbot_file_read_tool`" in tool_message_content
|
||||
assert "Use `astrbot_file_read_tool` to inspect it." in tool_message_content
|
||||
|
||||
overflow_files = list(Path(tmp_path).glob("call_large_result_*.txt"))
|
||||
assert len(overflow_files) == 1
|
||||
assert (
|
||||
overflow_files[0].read_text(encoding="utf-8") == _make_large_tool_result_text()
|
||||
)
|
||||
assert str(overflow_files[0]) in tool_message_content
|
||||
|
||||
llm_results = [resp for resp in responses if resp.type == "llm_result"]
|
||||
assert llm_results
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_tool_result_keeps_preview_when_spill_fails(
|
||||
tmp_path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
tool = FunctionTool(
|
||||
name="test_tool",
|
||||
description="测试工具",
|
||||
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
|
||||
handler=AsyncMock(),
|
||||
)
|
||||
read_tool = FunctionTool(
|
||||
name="astrbot_file_read_tool",
|
||||
description="read file",
|
||||
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
|
||||
handler=AsyncMock(),
|
||||
)
|
||||
tool_set = ToolSet(tools=[tool, read_tool])
|
||||
provider = SingleToolThenFinalProvider(tool.name, {"query": "large"})
|
||||
request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[])
|
||||
runner = ToolLoopAgentRunner()
|
||||
|
||||
async def _raise_spill_error(*, tool_call_id: str, content: str) -> str:
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr(runner, "_write_tool_result_overflow_file", _raise_spill_error)
|
||||
|
||||
await runner.reset(
|
||||
provider=provider,
|
||||
request=request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=cast(
|
||||
Any,
|
||||
LargeTextToolExecutor.from_text(_make_large_tool_result_text()),
|
||||
),
|
||||
agent_hooks=MockHooks(),
|
||||
streaming=False,
|
||||
tool_result_overflow_dir=str(tmp_path),
|
||||
read_tool=read_tool,
|
||||
)
|
||||
|
||||
async for _ in runner.step_until_done(3):
|
||||
pass
|
||||
|
||||
tool_messages = [m for m in runner.run_context.messages if m.role == "tool"]
|
||||
assert len(tool_messages) == 1
|
||||
tool_message_content = str(tool_messages[0].content)
|
||||
assert "xxxxxxxxxx" in tool_message_content
|
||||
assert "Tool output exceeded the inline result limit" in tool_message_content
|
||||
assert "disk full" in tool_message_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -40,7 +40,9 @@ def mock_context():
|
||||
return_value=(None, None, None, False)
|
||||
)
|
||||
ctx.persona_manager.get_persona_v3_by_id = MagicMock(return_value=None)
|
||||
ctx.get_llm_tool_manager.return_value = MagicMock()
|
||||
tool_mgr = MagicMock()
|
||||
tool_mgr.get_builtin_tool.side_effect = lambda cls, **kwargs: cls(**kwargs)
|
||||
ctx.get_llm_tool_manager.return_value = tool_mgr
|
||||
ctx.subagent_orchestrator = None
|
||||
return ctx
|
||||
|
||||
@@ -1479,7 +1481,7 @@ class TestApplyLlmSafetyMode:
|
||||
class TestApplySandboxTools:
|
||||
"""Tests for _apply_sandbox_tools function."""
|
||||
|
||||
def test_apply_sandbox_tools_creates_toolset_if_none(self):
|
||||
def test_apply_sandbox_tools_creates_toolset_if_none(self, mock_context):
|
||||
"""Test that ToolSet is created when func_tool is None."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
@@ -1494,7 +1496,7 @@ class TestApplySandboxTools:
|
||||
assert req.func_tool is not None
|
||||
assert isinstance(req.func_tool, ToolSet)
|
||||
|
||||
def test_apply_sandbox_tools_adds_required_tools(self):
|
||||
def test_apply_sandbox_tools_adds_required_tools(self, mock_context):
|
||||
"""Test that all required sandbox tools are added."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
@@ -1512,7 +1514,7 @@ class TestApplySandboxTools:
|
||||
assert "astrbot_upload_file" in tool_names
|
||||
assert "astrbot_download_file" in tool_names
|
||||
|
||||
def test_apply_sandbox_tools_adds_sandbox_prompt(self):
|
||||
def test_apply_sandbox_tools_adds_sandbox_prompt(self, mock_context):
|
||||
"""Test that sandbox mode prompt is added to system_prompt."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
@@ -1526,7 +1528,7 @@ class TestApplySandboxTools:
|
||||
|
||||
assert "sandboxed environment" in req.system_prompt
|
||||
|
||||
def test_apply_sandbox_tools_with_shipyard_booter(self, monkeypatch):
|
||||
def test_apply_sandbox_tools_with_shipyard_booter(self, monkeypatch, mock_context):
|
||||
"""Test sandbox tools with shipyard booter configuration."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
@@ -1548,7 +1550,7 @@ class TestApplySandboxTools:
|
||||
assert os.environ.get("SHIPYARD_ENDPOINT") == "https://shipyard.example.com"
|
||||
assert os.environ.get("SHIPYARD_ACCESS_TOKEN") == "test-token"
|
||||
|
||||
def test_apply_sandbox_tools_shipyard_missing_endpoint(self):
|
||||
def test_apply_sandbox_tools_shipyard_missing_endpoint(self, mock_context):
|
||||
"""Test that shipyard config is skipped when endpoint is missing."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
@@ -1571,7 +1573,7 @@ class TestApplySandboxTools:
|
||||
in mock_logger.error.call_args[0][0]
|
||||
)
|
||||
|
||||
def test_apply_sandbox_tools_shipyard_missing_access_token(self):
|
||||
def test_apply_sandbox_tools_shipyard_missing_access_token(self, mock_context):
|
||||
"""Test that shipyard config is skipped when access token is missing."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
@@ -1590,7 +1592,7 @@ class TestApplySandboxTools:
|
||||
|
||||
mock_logger.error.assert_called_once()
|
||||
|
||||
def test_apply_sandbox_tools_preserves_existing_toolset(self):
|
||||
def test_apply_sandbox_tools_preserves_existing_toolset(self, mock_context):
|
||||
"""Test that existing tools are preserved when adding sandbox tools."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
@@ -1609,7 +1611,7 @@ class TestApplySandboxTools:
|
||||
assert "existing_tool" in req.func_tool.names()
|
||||
assert "astrbot_execute_shell" in req.func_tool.names()
|
||||
|
||||
def test_apply_sandbox_tools_appends_to_existing_system_prompt(self):
|
||||
def test_apply_sandbox_tools_appends_to_existing_system_prompt(self, mock_context):
|
||||
"""Test that sandbox prompt is appended to existing system prompt."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
@@ -1624,7 +1626,7 @@ class TestApplySandboxTools:
|
||||
assert req.system_prompt.startswith("Base prompt")
|
||||
assert "sandboxed environment" in req.system_prompt
|
||||
|
||||
def test_apply_sandbox_tools_with_none_system_prompt(self):
|
||||
def test_apply_sandbox_tools_with_none_system_prompt(self, mock_context):
|
||||
"""Test that sandbox prompt is applied when system_prompt is None."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
|
||||
@@ -15,7 +15,6 @@ from astrbot.core.computer.booters.local import (
|
||||
LocalFileSystemComponent,
|
||||
LocalPythonComponent,
|
||||
LocalShellComponent,
|
||||
_ensure_safe_path,
|
||||
_is_safe_command,
|
||||
)
|
||||
|
||||
@@ -126,51 +125,6 @@ class TestSecurityRestrictions:
|
||||
for cmd in blocked_commands:
|
||||
assert _is_safe_command(cmd) is False, f"Command '{cmd}' should be blocked"
|
||||
|
||||
def test_ensure_safe_path_allowed(self, tmp_path):
|
||||
"""Test paths within allowed roots are accepted."""
|
||||
# Create a test directory structure
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test")
|
||||
|
||||
# Mock get_astrbot_root, get_astrbot_data_path, get_astrbot_temp_path
|
||||
with (
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
result = _ensure_safe_path(str(test_file))
|
||||
assert result == str(test_file)
|
||||
|
||||
def test_ensure_safe_path_blocked(self, tmp_path):
|
||||
"""Test paths outside allowed roots raise PermissionError."""
|
||||
with (
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
# Try to access a path outside the allowed roots
|
||||
with pytest.raises(PermissionError) as exc_info:
|
||||
_ensure_safe_path("/etc/passwd")
|
||||
assert "Path is outside the allowed computer roots" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestLocalShellComponent:
|
||||
"""Tests for LocalShellComponent."""
|
||||
@@ -212,14 +166,6 @@ class TestLocalShellComponent:
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
# Use python to read file to avoid Windows vs Unix command differences
|
||||
result = await shell.exec(
|
||||
@@ -294,14 +240,6 @@ class TestLocalFileSystemComponent:
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
result = await fs.create_file(str(test_path), "test content")
|
||||
assert result["success"] is True
|
||||
@@ -320,14 +258,6 @@ class TestLocalFileSystemComponent:
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
result = await fs.read_file(str(test_path))
|
||||
assert result["success"] is True
|
||||
@@ -344,14 +274,6 @@ class TestLocalFileSystemComponent:
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
result = await fs.write_file(str(test_path), "new content")
|
||||
assert result["success"] is True
|
||||
@@ -369,14 +291,6 @@ class TestLocalFileSystemComponent:
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
result = await fs.delete_file(str(test_path))
|
||||
assert result["success"] is True
|
||||
@@ -395,14 +309,6 @@ class TestLocalFileSystemComponent:
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
result = await fs.delete_file(str(test_dir))
|
||||
assert result["success"] is True
|
||||
@@ -422,14 +328,6 @@ class TestLocalFileSystemComponent:
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
# Without hidden files
|
||||
result = await fs.list_dir(str(tmp_path), show_hidden=False)
|
||||
@@ -452,14 +350,6 @@ class TestLocalFileSystemComponent:
|
||||
"astrbot.core.computer.booters.local.get_astrbot_root",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_data_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.computer.booters.local.get_astrbot_temp_path",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
):
|
||||
# Should raise FileNotFoundError
|
||||
with pytest.raises(FileNotFoundError):
|
||||
@@ -633,7 +523,7 @@ class TestComputerClient:
|
||||
"shipyard_access_token": "test_token",
|
||||
"shipyard_ttl": 3600,
|
||||
"shipyard_max_sessions": 10,
|
||||
}
|
||||
},
|
||||
}
|
||||
}.get(key, default)
|
||||
mock_context.get_config = MagicMock(return_value=mock_config)
|
||||
@@ -681,7 +571,7 @@ class TestComputerClient:
|
||||
"computer_use_runtime": "sandbox",
|
||||
"sandbox": {
|
||||
"booter": "unknown_type",
|
||||
}
|
||||
},
|
||||
}
|
||||
}.get(key, default)
|
||||
mock_context.get_config = MagicMock(return_value=mock_config)
|
||||
@@ -707,7 +597,7 @@ class TestComputerClient:
|
||||
"booter": "shipyard",
|
||||
"shipyard_endpoint": "http://localhost:8080",
|
||||
"shipyard_access_token": "test_token",
|
||||
}
|
||||
},
|
||||
}
|
||||
}.get(key, default)
|
||||
mock_context.get_config = MagicMock(return_value=mock_config)
|
||||
@@ -752,7 +642,7 @@ class TestComputerClient:
|
||||
"booter": "shipyard",
|
||||
"shipyard_endpoint": "http://localhost:8080",
|
||||
"shipyard_access_token": "test_token",
|
||||
}
|
||||
},
|
||||
}
|
||||
}.get(key, default)
|
||||
mock_context.get_config = MagicMock(return_value=mock_config)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
from astrbot.core.tools.computer_tools.shell import ExecuteShellTool
|
||||
from astrbot.core.tools.message_tools import SendMessageToUserTool
|
||||
|
||||
|
||||
@@ -12,3 +13,28 @@ def test_get_builtin_tool_by_class_returns_cached_instance():
|
||||
assert tool_by_class is tool_by_name
|
||||
assert manager.get_func("send_message_to_user") is tool_by_class
|
||||
assert tool_by_class.name == "send_message_to_user"
|
||||
|
||||
|
||||
def test_builtin_tool_ignores_inactivated_llm_tools():
|
||||
manager = FunctionToolManager()
|
||||
sp.put(
|
||||
"inactivated_llm_tools",
|
||||
["send_message_to_user"],
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
try:
|
||||
tool = manager.get_builtin_tool(SendMessageToUserTool)
|
||||
assert tool.active is True
|
||||
finally:
|
||||
sp.put("inactivated_llm_tools", [], scope="global", scope_id="global")
|
||||
|
||||
|
||||
def test_computer_tools_are_registered_as_builtin_tools():
|
||||
manager = FunctionToolManager()
|
||||
|
||||
tool = manager.get_builtin_tool(ExecuteShellTool)
|
||||
|
||||
assert tool.name == "astrbot_execute_shell"
|
||||
assert manager.is_builtin_tool("astrbot_execute_shell") is True
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import platform
|
||||
from astrbot.core.computer.tools.python import PythonTool, LocalPythonTool
|
||||
|
||||
from astrbot.core.tools.computer_tools.python import LocalPythonTool, PythonTool
|
||||
|
||||
|
||||
def test_python_tool_description_contains_os():
|
||||
"""测试 PythonTool 的描述中是否包含当前操作系统信息"""
|
||||
@@ -8,6 +10,7 @@ def test_python_tool_description_contains_os():
|
||||
assert current_os in tool.description
|
||||
assert "IPython" in tool.description
|
||||
|
||||
|
||||
def test_local_python_tool_description_contains_os():
|
||||
"""测试 LocalPythonTool 的描述中是否包含当前操作系统信息和兼容性提示"""
|
||||
tool = LocalPythonTool()
|
||||
|
||||
Reference in New Issue
Block a user