Compare commits

...

3 Commits

Author SHA1 Message Date
Soulter
e2745ab8e6 feat: enhance ToolLoopAgentRunner to support image and audio modalities; add context sanitization logic 2026-04-22 11:36:50 +08:00
Soulter
d9def46bae feat: update FileReadTool description to include support for docx and epub files; change base64 decoding to utf-8 2026-04-22 10:38:26 +08:00
Soulter
14ebde9348 feat: update FileReadTool description to mention image and PDF support
Add explicit mention of image (OCR) and PDF (text extraction) support
to the FileReadTool description for better discoverability.
2026-04-13 16:29:20 +08:00
7 changed files with 361 additions and 286 deletions

View File

@@ -7,7 +7,7 @@ import typing as T
import uuid
from collections.abc import AsyncIterator
from contextlib import suppress
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from pathlib import Path
from mcp.types import (
@@ -42,6 +42,10 @@ from astrbot.core.provider.entities import (
ProviderRequest,
ToolCallsResult,
)
from astrbot.core.provider.modalities import (
log_context_sanitize_stats,
sanitize_contexts_by_modalities,
)
from astrbot.core.provider.provider import Provider
from ..context.compressor import ContextCompressor
@@ -300,8 +304,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if isinstance(msg, dict) and msg.get("_no_save"):
m._no_save = True
messages.append(m)
if request.prompt is not None:
m = await request.assemble_context()
if (
request.prompt is not None
or request.image_urls
or request.audio_urls
or request.extra_user_content_parts
):
m = await self._assemble_request_context_for_provider(request)
messages.append(Message.model_validate(m))
if request.system_prompt:
messages.insert(
@@ -318,6 +327,42 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
return f"`{self.read_tool.name}`"
return "the available file-read tool"
async def _assemble_request_context_for_provider(
self,
request: ProviderRequest,
) -> dict[str, T.Any]:
modalities = self.provider.provider_config.get("modalities", None)
if not isinstance(modalities, list):
return await request.assemble_context()
supports_image = "image" in modalities
supports_audio = "audio" in modalities
if supports_image and supports_audio:
return await request.assemble_context()
adjusted_request = replace(
request,
image_urls=request.image_urls if supports_image else [],
audio_urls=request.audio_urls if supports_audio else [],
)
context = await adjusted_request.assemble_context()
content = context.get("content")
if isinstance(content, str):
content_blocks: list[dict[str, T.Any]] = [{"type": "text", "text": content}]
elif isinstance(content, list):
content_blocks = content
else:
content_blocks = []
if not supports_image:
for _ in request.image_urls:
content_blocks.append({"type": "text", "text": "[Image]"})
if not supports_audio:
for _ in request.audio_urls:
content_blocks.append({"type": "text", "text": "[Audio]"})
return {"role": "user", "content": content_blocks}
async def _write_tool_result_overflow_file(
self,
*,
@@ -415,8 +460,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"contexts": self._sanitize_contexts_for_provider(self.run_context.messages),
"func_tool": self._func_tool_for_provider(),
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
"abort_signal": self._abort_signal,
@@ -532,6 +577,35 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
completion_text="All available chat models are unavailable.",
)
def _sanitize_contexts_for_provider(
self,
contexts: list[Message] | list[dict[str, T.Any]],
) -> list[Message] | list[dict[str, T.Any]]:
if not self._should_fix_modalities_for_provider():
return contexts
sanitized_contexts, stats = sanitize_contexts_by_modalities(
contexts,
self.provider.provider_config.get("modalities", None),
)
log_context_sanitize_stats(stats)
return sanitized_contexts
def _should_fix_modalities_for_provider(self) -> bool:
modalities = self.provider.provider_config.get("modalities", None)
return isinstance(modalities, list)
def _func_tool_for_provider(self) -> ToolSet | None:
if not self.req.func_tool:
return None
modalities = self.provider.provider_config.get("modalities", None)
if isinstance(modalities, list) and "tool_use" not in modalities:
logger.debug(
"Provider %s does not support tool_use, clearing tools for request.",
self.provider,
)
return None
return self.req.func_tool
def _simple_print_message_role(self, tag: str = ""):
roles = []
for message in self.run_context.messages:
@@ -1194,7 +1268,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if param_subset.tools and tool_names:
contexts = self._build_tool_requery_context(tool_names)
requery_resp = await self.provider.text_chat(
contexts=contexts,
contexts=self._sanitize_contexts_for_provider(contexts),
func_tool=param_subset,
model=self.req.model,
session_id=self.req.session_id,
@@ -1220,7 +1294,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
extra_instruction=self.SKILLS_LIKE_REQUERY_REPAIR_INSTRUCTION,
)
repair_resp = await self.provider.text_chat(
contexts=repair_contexts,
contexts=self._sanitize_contexts_for_provider(repair_contexts),
func_tool=param_subset,
model=self.req.model,
session_id=self.req.session_id,

View File

@@ -823,136 +823,6 @@ async def _decorate_llm_request(
_apply_workspace_extra_prompt(event, req)
def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
logger.debug(
"Provider %s does not support image, using placeholder.", provider
)
image_count = len(req.image_urls)
placeholder = " ".join(["[Image]"] * image_count)
if req.prompt:
req.prompt = f"{placeholder} {req.prompt}"
else:
req.prompt = placeholder
req.image_urls = []
if req.audio_urls:
provider_cfg = provider.provider_config.get("modalities", ["audio"])
if "audio" not in provider_cfg:
logger.debug(
"Provider %s does not support audio, using placeholder.", provider
)
audio_count = len(req.audio_urls)
placeholder = " ".join(["[Audio]"] * audio_count)
if req.prompt:
req.prompt = f"{placeholder} {req.prompt}"
else:
req.prompt = placeholder
req.audio_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
if "tool_use" not in provider_cfg:
logger.debug(
"Provider %s does not support tool_use, clearing tools.", provider
)
req.func_tool = None
def _sanitize_context_by_modalities(
config: MainAgentBuildConfig,
provider: Provider,
req: ProviderRequest,
) -> None:
if not config.sanitize_context_by_modalities:
return
if not isinstance(req.contexts, list) or not req.contexts:
return
modalities = provider.provider_config.get("modalities", None)
if not modalities or not isinstance(modalities, list):
return
supports_image = bool("image" in modalities)
supports_audio = bool("audio" in modalities)
supports_tool_use = bool("tool_use" in modalities)
if supports_image and supports_audio and supports_tool_use:
return
sanitized_contexts: list[dict] = []
removed_image_blocks = 0
removed_audio_blocks = 0
removed_tool_messages = 0
removed_tool_calls = 0
for msg in req.contexts:
if not isinstance(msg, dict):
continue
role = msg.get("role")
if not role:
continue
new_msg = msg
if not supports_tool_use:
if role == "tool":
removed_tool_messages += 1
continue
if role == "assistant" and "tool_calls" in new_msg:
if "tool_calls" in new_msg:
removed_tool_calls += 1
new_msg.pop("tool_calls", None)
new_msg.pop("tool_call_id", None)
if not supports_image or not supports_audio:
content = new_msg.get("content")
if isinstance(content, list):
filtered_parts: list = []
removed_any_multimodal = False
for part in content:
if isinstance(part, dict):
part_type = str(part.get("type", "")).lower()
if not supports_image and part_type in {"image_url", "image"}:
removed_any_multimodal = True
removed_image_blocks += 1
continue
if not supports_audio and part_type in {
"audio_url",
"input_audio",
}:
removed_any_multimodal = True
removed_audio_blocks += 1
continue
filtered_parts.append(part)
if removed_any_multimodal:
new_msg["content"] = filtered_parts
if role == "assistant":
content = new_msg.get("content")
has_tool_calls = bool(new_msg.get("tool_calls"))
if not has_tool_calls:
if not content:
continue
if isinstance(content, str) and not content.strip():
continue
sanitized_contexts.append(new_msg)
if (
removed_image_blocks
or removed_audio_blocks
or removed_tool_messages
or removed_tool_calls
):
logger.debug(
"sanitize_context_by_modalities applied: "
"removed_image_blocks=%s, removed_audio_blocks=%s, "
"removed_tool_messages=%s, removed_tool_calls=%s",
removed_image_blocks,
removed_audio_blocks,
removed_tool_messages,
removed_tool_calls,
)
req.contexts = sanitized_contexts
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
"""根据事件中的插件设置,过滤请求中的工具列表。
@@ -1393,10 +1263,8 @@ async def build_main_agent(
if not req.session_id:
req.session_id = event.unified_msg_origin
_modalities_fix(provider, req)
_plugin_tool_fix(event, req)
await _apply_web_search_tools(event, req, plugin_context)
_sanitize_context_by_modalities(config, provider, req)
if config.llm_safety_mode:
_apply_llm_safety_mode(config, req)

View File

@@ -91,7 +91,7 @@ print(
json.dumps(
{{
"size_bytes": path.stat().st_size,
"sample_b64": base64.b64encode(sample).decode("ascii"),
"sample_b64": base64.b64encode(sample).decode("utf-8"),
}}
)
)
@@ -140,7 +140,7 @@ print(
json.dumps(
{{
"size_bytes": len(data),
"base64": base64.b64encode(data).decode("ascii"),
"base64": base64.b64encode(data).decode("utf-8"),
}}
)
)
@@ -278,7 +278,7 @@ async def _probe_local_file(path: str) -> dict[str, str | int]:
sample = file_obj.read(_FILE_SNIFF_BYTES)
return {
"size_bytes": file_path.stat().st_size,
"sample_b64": base64.b64encode(sample).decode("ascii"),
"sample_b64": base64.b64encode(sample).decode("utf-8"),
}
return await to_thread(_run)
@@ -289,7 +289,7 @@ async def _read_local_image_base64(path: str) -> dict[str, str | int]:
data = Path(path).read_bytes()
return {
"size_bytes": len(data),
"base64": base64.b64encode(data).decode("ascii"),
"base64": base64.b64encode(data).decode("utf-8"),
}
return await to_thread(_run)
@@ -319,7 +319,7 @@ async def _compress_image_bytes_to_base64(data: bytes) -> dict[str, str | int]:
return {
"size_bytes": len(compressed_bytes),
"base64": base64.b64encode(compressed_bytes).decode("ascii"),
"base64": base64.b64encode(compressed_bytes).decode("utf-8"),
"mime_type": "image/jpeg",
}
@@ -659,14 +659,14 @@ async def read_file_tool_result(
return "Error reading file: image payload is empty."
raw_bytes = base64.b64decode(raw_base64_data)
compressed_payload = await _compress_image_bytes_to_base64(raw_bytes)
base64_data = str(compressed_payload.get("base64", "") or "")
if not base64_data:
compressed_base64_data = str(compressed_payload.get("base64", "") or "")
if not compressed_base64_data:
return "Error reading file: compressed image payload is empty."
return mcp.types.CallToolResult(
content=[
mcp.types.ImageContent(
type="image",
data=base64_data,
data=compressed_base64_data,
mimeType=str(
compressed_payload.get("mime_type", "") or "image/jpeg"
),

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
import copy
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from astrbot import logger
from astrbot.core.agent.message import Message
@dataclass(slots=True)
class ContextSanitizeStats:
fixed_image_blocks: int = 0
fixed_audio_blocks: int = 0
fixed_tool_messages: int = 0
removed_tool_calls: int = 0
@property
def changed(self) -> bool:
return bool(
self.fixed_image_blocks
or self.fixed_audio_blocks
or self.fixed_tool_messages
or self.removed_tool_calls
)
def _message_to_dict(message: dict[str, Any] | Message) -> dict[str, Any] | None:
if isinstance(message, Message):
return dict(message.model_dump())
if isinstance(message, dict):
return dict(copy.deepcopy(message))
return None
def sanitize_contexts_by_modalities(
contexts: Sequence[dict[str, Any] | Message],
modalities: list[str] | None,
) -> tuple[list[dict[str, Any]], ContextSanitizeStats]:
if not contexts:
return [], ContextSanitizeStats()
if not modalities or not isinstance(modalities, list):
copied_contexts = []
for msg in contexts:
copied_msg = _message_to_dict(msg)
if copied_msg:
copied_contexts.append(copied_msg)
return copied_contexts, ContextSanitizeStats()
supports_image = "image" in modalities
supports_audio = "audio" in modalities
supports_tool_use = "tool_use" in modalities
if supports_image and supports_audio and supports_tool_use:
copied_contexts = []
for msg in contexts:
copied_msg = _message_to_dict(msg)
if copied_msg:
copied_contexts.append(copied_msg)
return copied_contexts, ContextSanitizeStats()
sanitized_contexts: list[dict[str, Any]] = []
stats = ContextSanitizeStats()
for raw_msg in contexts:
msg = _message_to_dict(raw_msg)
if not msg:
continue
role = msg.get("role")
if not role:
continue
if not supports_tool_use:
if role == "tool":
stats.fixed_tool_messages += 1
fixed_msg: dict[str, Any] = {
"role": "user",
"content": _tool_result_placeholder(msg.get("content")),
}
msg = fixed_msg
if role == "assistant" and "tool_calls" in msg:
stats.removed_tool_calls += 1
msg.pop("tool_calls", None)
msg.pop("tool_call_id", None)
if not supports_image or not supports_audio:
content = msg.get("content")
if isinstance(content, list):
filtered_parts: list[Any] = []
removed_any_multimodal = False
for part in content:
if isinstance(part, dict):
part_type = str(part.get("type", "")).lower()
if not supports_image and part_type in {"image_url", "image"}:
removed_any_multimodal = True
stats.fixed_image_blocks += 1
filtered_parts.append({"type": "text", "text": "[Image]"})
continue
if not supports_audio and part_type in {
"audio_url",
"input_audio",
}:
removed_any_multimodal = True
stats.fixed_audio_blocks += 1
filtered_parts.append({"type": "text", "text": "[Audio]"})
continue
filtered_parts.append(part)
if removed_any_multimodal:
msg["content"] = filtered_parts
if role == "assistant":
content = msg.get("content")
has_tool_calls = bool(msg.get("tool_calls"))
if not has_tool_calls:
if not content:
continue
if isinstance(content, str) and not content.strip():
continue
sanitized_contexts.append(msg)
return sanitized_contexts, stats
def _tool_result_placeholder(content: Any) -> str:
if isinstance(content, str):
content_text = content.strip()
elif isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, dict):
part_type = str(part.get("type", "")).lower()
if part_type == "text":
text_parts.append(str(part.get("text", "")))
elif part_type in {"image_url", "image"}:
text_parts.append("[Image]")
elif part_type in {"audio_url", "input_audio"}:
text_parts.append("[Audio]")
content_text = "\n".join(part for part in text_parts if part).strip()
else:
content_text = ""
if not content_text:
return "[Tool result]"
return f"[Tool result]\n{content_text}"
def log_context_sanitize_stats(stats: ContextSanitizeStats) -> None:
if not stats.changed:
return
logger.debug(
"context modality fix applied: "
"fixed_image_blocks=%s, fixed_audio_blocks=%s, "
"fixed_tool_messages=%s, removed_tool_calls=%s",
stats.fixed_image_blocks,
stats.fixed_audio_blocks,
stats.fixed_tool_messages,
stats.removed_tool_calls,
)

View File

@@ -171,7 +171,7 @@ def _decode_escaped_text(value: str) -> str:
@dataclass
class FileReadTool(FunctionTool):
name: str = "astrbot_file_read_tool"
description: str = "read file content."
description: str = "read file content. Supports text, image, and PDF (text extraction), docx and epub files."
parameters: dict = field(
default_factory=lambda: {
"type": "object",

View File

@@ -14,6 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")
from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import ImageURLPart, Message, TextPart
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.tool import FunctionTool, ToolSet
@@ -156,6 +157,25 @@ class MockErrProvider(MockProvider):
)
class CapturingProvider(MockProvider):
def __init__(self, modalities: list[str]):
super().__init__()
self.provider_config["modalities"] = modalities
self.received_contexts = []
self.received_func_tools = []
self.should_call_tools = False
async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
self.received_contexts.append(kwargs.get("contexts"))
self.received_func_tools.append(kwargs.get("func_tool"))
return LLMResponse(
role="assistant",
completion_text="final",
usage=TokenUsage(input_other=10, output=5),
)
class MockEmptyOutputThenSuccessProvider(MockProvider):
def __init__(self, failures_before_success: int = 1):
super().__init__()
@@ -615,6 +635,99 @@ async def test_tool_result_includes_all_calltoolresult_content(
]
@pytest.mark.asyncio
async def test_runner_replaces_runtime_image_context_before_provider_call(
runner, provider_request, mock_hooks
):
provider = CapturingProvider(modalities=["tool_use"])
await runner.reset(
provider=provider,
request=provider_request,
run_context=ContextWrapper(context=None),
tool_executor=MockToolExecutor,
agent_hooks=mock_hooks,
streaming=False,
)
runner.run_context.messages.append(
Message(
role="user",
content=[
TextPart(text="Review this image"),
ImageURLPart(
image_url=ImageURLPart.ImageURL(
url="data:image/png;base64,dGVzdA=="
)
),
],
)
)
async for _ in runner.step_until_done(1):
pass
assert provider.received_contexts
sent_context = provider.received_contexts[0]
assert sent_context[-1]["content"] == [
{"type": "text", "text": "Review this image"},
{"type": "text", "text": "[Image]"},
]
assert len(runner.run_context.messages[-2].content) == 2
@pytest.mark.asyncio
async def test_runner_builds_placeholder_for_unsupported_request_image(
runner, mock_hooks, tool_set
):
provider = CapturingProvider(modalities=["tool_use"])
request = ProviderRequest(
prompt="Describe it",
image_urls=["/path/that/should/not/be/read.jpg"],
func_tool=tool_set,
contexts=[],
)
await runner.reset(
provider=provider,
request=request,
run_context=ContextWrapper(context=None),
tool_executor=MockToolExecutor,
agent_hooks=mock_hooks,
streaming=False,
)
async for _ in runner.step_until_done(1):
pass
sent_context = provider.received_contexts[0]
assert sent_context[-1]["content"] == [
{"type": "text", "text": "Describe it"},
{"type": "text", "text": "[Image]"},
]
@pytest.mark.asyncio
async def test_runner_clears_tools_for_provider_without_tool_use(
runner, provider_request, mock_hooks, mock_tool_executor
):
provider = CapturingProvider(modalities=["text"])
await runner.reset(
provider=provider,
request=provider_request,
run_context=ContextWrapper(context=None),
tool_executor=mock_tool_executor,
agent_hooks=mock_hooks,
streaming=False,
)
async for _ in runner.step_until_done(1):
pass
assert provider.received_func_tools == [None]
@pytest.mark.asyncio
async def test_same_tool_consecutive_results_include_escalating_guidance(
runner, mock_tool_executor, mock_hooks

View File

@@ -713,144 +713,6 @@ class TestDecorateLlmRequest:
assert req.prompt == "Hello"
class TestModalitiesFix:
"""Tests for _modalities_fix function."""
def test_modalities_fix_image_not_supported(self, mock_provider):
"""Test modality fix when image is not supported."""
module = ama
mock_provider.provider_config = {"modalities": ["text"]}
req = ProviderRequest(prompt="Hello", image_urls=["/path/to/image.jpg"])
module._modalities_fix(mock_provider, req)
assert "[Image]" in req.prompt
assert req.image_urls == []
def test_modalities_fix_tool_not_supported(self, mock_provider):
"""Test modality fix when tool is not supported."""
module = ama
mock_provider.provider_config = {"modalities": ["text", "image"]}
req = ProviderRequest(prompt="Hello")
req.func_tool = ToolSet()
req.func_tool.add_tool(
FunctionTool(
name="dummy_tool",
description="dummy",
parameters={"type": "object", "properties": {}},
)
)
module._modalities_fix(mock_provider, req)
assert req.func_tool is None
def test_modalities_fix_all_supported(self, mock_provider):
"""Test modality fix when all features are supported."""
module = ama
mock_provider.provider_config = {"modalities": ["image", "tool_use"]}
tool_set = ToolSet()
tool_set.add_tool(
FunctionTool(
name="dummy_tool",
description="dummy",
parameters={"type": "object", "properties": {}},
)
)
req = ProviderRequest(
prompt="Hello",
image_urls=["/path/to/image.jpg"],
func_tool=tool_set,
)
module._modalities_fix(mock_provider, req)
assert req.prompt == "Hello"
assert len(req.image_urls) == 1
assert req.func_tool is not None
class TestSanitizeContextByModalities:
"""Tests for _sanitize_context_by_modalities function."""
def test_sanitize_no_op(self, mock_provider):
"""Test sanitize when disabled or modalities support everything."""
module = ama
config = module.MainAgentBuildConfig(
tool_call_timeout=60, sanitize_context_by_modalities=False
)
mock_provider.provider_config = {"modalities": ["image", "tool_use"]}
req = ProviderRequest(contexts=[{"role": "user", "content": "Hello"}])
module._sanitize_context_by_modalities(config, mock_provider, req)
assert len(req.contexts) == 1
def test_sanitize_removes_tool_messages(self, mock_provider):
"""Test sanitize removes tool messages when tool_use not supported."""
module = ama
config = module.MainAgentBuildConfig(
tool_call_timeout=60, sanitize_context_by_modalities=True
)
mock_provider.provider_config = {"modalities": ["image"]}
req = ProviderRequest(
contexts=[
{"role": "user", "content": "Hello"},
{"role": "tool", "content": "Tool result"},
]
)
module._sanitize_context_by_modalities(config, mock_provider, req)
assert len(req.contexts) == 1
assert req.contexts[0]["role"] == "user"
def test_sanitize_removes_tool_calls(self, mock_provider):
"""Test sanitize removes tool_calls from assistant messages."""
module = ama
config = module.MainAgentBuildConfig(
tool_call_timeout=60, sanitize_context_by_modalities=True
)
mock_provider.provider_config = {"modalities": ["image"]}
req = ProviderRequest(
contexts=[
{
"role": "assistant",
"content": "Response",
"tool_calls": [{"name": "tool"}],
}
]
)
module._sanitize_context_by_modalities(config, mock_provider, req)
assert "tool_calls" not in req.contexts[0]
def test_sanitize_removes_image_blocks(self, mock_provider):
"""Test sanitize removes image blocks when image not supported."""
module = ama
config = module.MainAgentBuildConfig(
tool_call_timeout=60, sanitize_context_by_modalities=True
)
mock_provider.provider_config = {"modalities": ["tool_use"]}
req = ProviderRequest(
contexts=[
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{"type": "image_url", "url": "image.jpg"},
],
}
]
)
module._sanitize_context_by_modalities(config, mock_provider, req)
assert len(req.contexts[0]["content"]) == 1
assert req.contexts[0]["content"][0]["type"] == "text"
class TestPluginToolFix:
"""Tests for _plugin_tool_fix function."""