mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 18:20:16 +08:00
Compare commits
3 Commits
fix/remove
...
draft/file
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2745ab8e6 | ||
|
|
d9def46bae | ||
|
|
14ebde9348 |
@@ -7,7 +7,7 @@ import typing as T
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
|
||||
from mcp.types import (
|
||||
@@ -42,6 +42,10 @@ from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
ToolCallsResult,
|
||||
)
|
||||
from astrbot.core.provider.modalities import (
|
||||
log_context_sanitize_stats,
|
||||
sanitize_contexts_by_modalities,
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.compressor import ContextCompressor
|
||||
@@ -300,8 +304,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
if isinstance(msg, dict) and msg.get("_no_save"):
|
||||
m._no_save = True
|
||||
messages.append(m)
|
||||
if request.prompt is not None:
|
||||
m = await request.assemble_context()
|
||||
if (
|
||||
request.prompt is not None
|
||||
or request.image_urls
|
||||
or request.audio_urls
|
||||
or request.extra_user_content_parts
|
||||
):
|
||||
m = await self._assemble_request_context_for_provider(request)
|
||||
messages.append(Message.model_validate(m))
|
||||
if request.system_prompt:
|
||||
messages.insert(
|
||||
@@ -318,6 +327,42 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
return f"`{self.read_tool.name}`"
|
||||
return "the available file-read tool"
|
||||
|
||||
async def _assemble_request_context_for_provider(
|
||||
self,
|
||||
request: ProviderRequest,
|
||||
) -> dict[str, T.Any]:
|
||||
modalities = self.provider.provider_config.get("modalities", None)
|
||||
if not isinstance(modalities, list):
|
||||
return await request.assemble_context()
|
||||
|
||||
supports_image = "image" in modalities
|
||||
supports_audio = "audio" in modalities
|
||||
if supports_image and supports_audio:
|
||||
return await request.assemble_context()
|
||||
|
||||
adjusted_request = replace(
|
||||
request,
|
||||
image_urls=request.image_urls if supports_image else [],
|
||||
audio_urls=request.audio_urls if supports_audio else [],
|
||||
)
|
||||
context = await adjusted_request.assemble_context()
|
||||
content = context.get("content")
|
||||
if isinstance(content, str):
|
||||
content_blocks: list[dict[str, T.Any]] = [{"type": "text", "text": content}]
|
||||
elif isinstance(content, list):
|
||||
content_blocks = content
|
||||
else:
|
||||
content_blocks = []
|
||||
|
||||
if not supports_image:
|
||||
for _ in request.image_urls:
|
||||
content_blocks.append({"type": "text", "text": "[Image]"})
|
||||
if not supports_audio:
|
||||
for _ in request.audio_urls:
|
||||
content_blocks.append({"type": "text", "text": "[Audio]"})
|
||||
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def _write_tool_result_overflow_file(
|
||||
self,
|
||||
*,
|
||||
@@ -415,8 +460,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"contexts": self._sanitize_contexts_for_provider(self.run_context.messages),
|
||||
"func_tool": self._func_tool_for_provider(),
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
"abort_signal": self._abort_signal,
|
||||
@@ -532,6 +577,35 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
completion_text="All available chat models are unavailable.",
|
||||
)
|
||||
|
||||
def _sanitize_contexts_for_provider(
|
||||
self,
|
||||
contexts: list[Message] | list[dict[str, T.Any]],
|
||||
) -> list[Message] | list[dict[str, T.Any]]:
|
||||
if not self._should_fix_modalities_for_provider():
|
||||
return contexts
|
||||
sanitized_contexts, stats = sanitize_contexts_by_modalities(
|
||||
contexts,
|
||||
self.provider.provider_config.get("modalities", None),
|
||||
)
|
||||
log_context_sanitize_stats(stats)
|
||||
return sanitized_contexts
|
||||
|
||||
def _should_fix_modalities_for_provider(self) -> bool:
|
||||
modalities = self.provider.provider_config.get("modalities", None)
|
||||
return isinstance(modalities, list)
|
||||
|
||||
def _func_tool_for_provider(self) -> ToolSet | None:
|
||||
if not self.req.func_tool:
|
||||
return None
|
||||
modalities = self.provider.provider_config.get("modalities", None)
|
||||
if isinstance(modalities, list) and "tool_use" not in modalities:
|
||||
logger.debug(
|
||||
"Provider %s does not support tool_use, clearing tools for request.",
|
||||
self.provider,
|
||||
)
|
||||
return None
|
||||
return self.req.func_tool
|
||||
|
||||
def _simple_print_message_role(self, tag: str = ""):
|
||||
roles = []
|
||||
for message in self.run_context.messages:
|
||||
@@ -1194,7 +1268,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
if param_subset.tools and tool_names:
|
||||
contexts = self._build_tool_requery_context(tool_names)
|
||||
requery_resp = await self.provider.text_chat(
|
||||
contexts=contexts,
|
||||
contexts=self._sanitize_contexts_for_provider(contexts),
|
||||
func_tool=param_subset,
|
||||
model=self.req.model,
|
||||
session_id=self.req.session_id,
|
||||
@@ -1220,7 +1294,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
extra_instruction=self.SKILLS_LIKE_REQUERY_REPAIR_INSTRUCTION,
|
||||
)
|
||||
repair_resp = await self.provider.text_chat(
|
||||
contexts=repair_contexts,
|
||||
contexts=self._sanitize_contexts_for_provider(repair_contexts),
|
||||
func_tool=param_subset,
|
||||
model=self.req.model,
|
||||
session_id=self.req.session_id,
|
||||
|
||||
@@ -823,136 +823,6 @@ async def _decorate_llm_request(
|
||||
_apply_workspace_extra_prompt(event, req)
|
||||
|
||||
|
||||
def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
logger.debug(
|
||||
"Provider %s does not support image, using placeholder.", provider
|
||||
)
|
||||
image_count = len(req.image_urls)
|
||||
placeholder = " ".join(["[Image]"] * image_count)
|
||||
if req.prompt:
|
||||
req.prompt = f"{placeholder} {req.prompt}"
|
||||
else:
|
||||
req.prompt = placeholder
|
||||
req.image_urls = []
|
||||
if req.audio_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["audio"])
|
||||
if "audio" not in provider_cfg:
|
||||
logger.debug(
|
||||
"Provider %s does not support audio, using placeholder.", provider
|
||||
)
|
||||
audio_count = len(req.audio_urls)
|
||||
placeholder = " ".join(["[Audio]"] * audio_count)
|
||||
if req.prompt:
|
||||
req.prompt = f"{placeholder} {req.prompt}"
|
||||
else:
|
||||
req.prompt = placeholder
|
||||
req.audio_urls = []
|
||||
if req.func_tool:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||
if "tool_use" not in provider_cfg:
|
||||
logger.debug(
|
||||
"Provider %s does not support tool_use, clearing tools.", provider
|
||||
)
|
||||
req.func_tool = None
|
||||
|
||||
|
||||
def _sanitize_context_by_modalities(
|
||||
config: MainAgentBuildConfig,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
) -> None:
|
||||
if not config.sanitize_context_by_modalities:
|
||||
return
|
||||
if not isinstance(req.contexts, list) or not req.contexts:
|
||||
return
|
||||
modalities = provider.provider_config.get("modalities", None)
|
||||
if not modalities or not isinstance(modalities, list):
|
||||
return
|
||||
supports_image = bool("image" in modalities)
|
||||
supports_audio = bool("audio" in modalities)
|
||||
supports_tool_use = bool("tool_use" in modalities)
|
||||
if supports_image and supports_audio and supports_tool_use:
|
||||
return
|
||||
|
||||
sanitized_contexts: list[dict] = []
|
||||
removed_image_blocks = 0
|
||||
removed_audio_blocks = 0
|
||||
removed_tool_messages = 0
|
||||
removed_tool_calls = 0
|
||||
|
||||
for msg in req.contexts:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
role = msg.get("role")
|
||||
if not role:
|
||||
continue
|
||||
|
||||
new_msg = msg
|
||||
if not supports_tool_use:
|
||||
if role == "tool":
|
||||
removed_tool_messages += 1
|
||||
continue
|
||||
if role == "assistant" and "tool_calls" in new_msg:
|
||||
if "tool_calls" in new_msg:
|
||||
removed_tool_calls += 1
|
||||
new_msg.pop("tool_calls", None)
|
||||
new_msg.pop("tool_call_id", None)
|
||||
|
||||
if not supports_image or not supports_audio:
|
||||
content = new_msg.get("content")
|
||||
if isinstance(content, list):
|
||||
filtered_parts: list = []
|
||||
removed_any_multimodal = False
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part_type = str(part.get("type", "")).lower()
|
||||
if not supports_image and part_type in {"image_url", "image"}:
|
||||
removed_any_multimodal = True
|
||||
removed_image_blocks += 1
|
||||
continue
|
||||
if not supports_audio and part_type in {
|
||||
"audio_url",
|
||||
"input_audio",
|
||||
}:
|
||||
removed_any_multimodal = True
|
||||
removed_audio_blocks += 1
|
||||
continue
|
||||
filtered_parts.append(part)
|
||||
if removed_any_multimodal:
|
||||
new_msg["content"] = filtered_parts
|
||||
|
||||
if role == "assistant":
|
||||
content = new_msg.get("content")
|
||||
has_tool_calls = bool(new_msg.get("tool_calls"))
|
||||
if not has_tool_calls:
|
||||
if not content:
|
||||
continue
|
||||
if isinstance(content, str) and not content.strip():
|
||||
continue
|
||||
|
||||
sanitized_contexts.append(new_msg)
|
||||
|
||||
if (
|
||||
removed_image_blocks
|
||||
or removed_audio_blocks
|
||||
or removed_tool_messages
|
||||
or removed_tool_calls
|
||||
):
|
||||
logger.debug(
|
||||
"sanitize_context_by_modalities applied: "
|
||||
"removed_image_blocks=%s, removed_audio_blocks=%s, "
|
||||
"removed_tool_messages=%s, removed_tool_calls=%s",
|
||||
removed_image_blocks,
|
||||
removed_audio_blocks,
|
||||
removed_tool_messages,
|
||||
removed_tool_calls,
|
||||
)
|
||||
req.contexts = sanitized_contexts
|
||||
|
||||
|
||||
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
|
||||
"""根据事件中的插件设置,过滤请求中的工具列表。
|
||||
|
||||
@@ -1393,10 +1263,8 @@ async def build_main_agent(
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
_modalities_fix(provider, req)
|
||||
_plugin_tool_fix(event, req)
|
||||
await _apply_web_search_tools(event, req, plugin_context)
|
||||
_sanitize_context_by_modalities(config, provider, req)
|
||||
|
||||
if config.llm_safety_mode:
|
||||
_apply_llm_safety_mode(config, req)
|
||||
|
||||
@@ -91,7 +91,7 @@ print(
|
||||
json.dumps(
|
||||
{{
|
||||
"size_bytes": path.stat().st_size,
|
||||
"sample_b64": base64.b64encode(sample).decode("ascii"),
|
||||
"sample_b64": base64.b64encode(sample).decode("utf-8"),
|
||||
}}
|
||||
)
|
||||
)
|
||||
@@ -140,7 +140,7 @@ print(
|
||||
json.dumps(
|
||||
{{
|
||||
"size_bytes": len(data),
|
||||
"base64": base64.b64encode(data).decode("ascii"),
|
||||
"base64": base64.b64encode(data).decode("utf-8"),
|
||||
}}
|
||||
)
|
||||
)
|
||||
@@ -278,7 +278,7 @@ async def _probe_local_file(path: str) -> dict[str, str | int]:
|
||||
sample = file_obj.read(_FILE_SNIFF_BYTES)
|
||||
return {
|
||||
"size_bytes": file_path.stat().st_size,
|
||||
"sample_b64": base64.b64encode(sample).decode("ascii"),
|
||||
"sample_b64": base64.b64encode(sample).decode("utf-8"),
|
||||
}
|
||||
|
||||
return await to_thread(_run)
|
||||
@@ -289,7 +289,7 @@ async def _read_local_image_base64(path: str) -> dict[str, str | int]:
|
||||
data = Path(path).read_bytes()
|
||||
return {
|
||||
"size_bytes": len(data),
|
||||
"base64": base64.b64encode(data).decode("ascii"),
|
||||
"base64": base64.b64encode(data).decode("utf-8"),
|
||||
}
|
||||
|
||||
return await to_thread(_run)
|
||||
@@ -319,7 +319,7 @@ async def _compress_image_bytes_to_base64(data: bytes) -> dict[str, str | int]:
|
||||
|
||||
return {
|
||||
"size_bytes": len(compressed_bytes),
|
||||
"base64": base64.b64encode(compressed_bytes).decode("ascii"),
|
||||
"base64": base64.b64encode(compressed_bytes).decode("utf-8"),
|
||||
"mime_type": "image/jpeg",
|
||||
}
|
||||
|
||||
@@ -659,14 +659,14 @@ async def read_file_tool_result(
|
||||
return "Error reading file: image payload is empty."
|
||||
raw_bytes = base64.b64decode(raw_base64_data)
|
||||
compressed_payload = await _compress_image_bytes_to_base64(raw_bytes)
|
||||
base64_data = str(compressed_payload.get("base64", "") or "")
|
||||
if not base64_data:
|
||||
compressed_base64_data = str(compressed_payload.get("base64", "") or "")
|
||||
if not compressed_base64_data:
|
||||
return "Error reading file: compressed image payload is empty."
|
||||
return mcp.types.CallToolResult(
|
||||
content=[
|
||||
mcp.types.ImageContent(
|
||||
type="image",
|
||||
data=base64_data,
|
||||
data=compressed_base64_data,
|
||||
mimeType=str(
|
||||
compressed_payload.get("mime_type", "") or "image/jpeg"
|
||||
),
|
||||
|
||||
158
astrbot/core/provider/modalities.py
Normal file
158
astrbot/core/provider/modalities.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ContextSanitizeStats:
|
||||
fixed_image_blocks: int = 0
|
||||
fixed_audio_blocks: int = 0
|
||||
fixed_tool_messages: int = 0
|
||||
removed_tool_calls: int = 0
|
||||
|
||||
@property
|
||||
def changed(self) -> bool:
|
||||
return bool(
|
||||
self.fixed_image_blocks
|
||||
or self.fixed_audio_blocks
|
||||
or self.fixed_tool_messages
|
||||
or self.removed_tool_calls
|
||||
)
|
||||
|
||||
|
||||
def _message_to_dict(message: dict[str, Any] | Message) -> dict[str, Any] | None:
    """Convert one context entry into a plain dict the caller may mutate.

    ``Message`` instances are dumped via ``model_dump()``; dicts are
    deep-copied so the caller's data is never modified. Any other type
    yields ``None``.
    """
    if isinstance(message, Message):
        dumped = message.model_dump()
        return dict(dumped)
    if not isinstance(message, dict):
        return None
    duplicate = copy.deepcopy(message)
    return dict(duplicate)
|
||||
|
||||
|
||||
def _copy_contexts(
    contexts: Sequence[dict[str, Any] | Message],
) -> list[dict[str, Any]]:
    """Return independent dict copies of ``contexts``, dropping unusable entries."""
    copies = (_message_to_dict(item) for item in contexts)
    return [item for item in copies if item]


def sanitize_contexts_by_modalities(
    contexts: Sequence[dict[str, Any] | Message],
    modalities: list[str] | None,
) -> tuple[list[dict[str, Any]], ContextSanitizeStats]:
    """Rewrite chat contexts so they only use what ``modalities`` allows.

    The input is never mutated; every returned message is a fresh dict.

    * Without ``"tool_use"``: tool-role messages become user-role messages
      holding a ``[Tool result]`` placeholder, and ``tool_calls`` /
      ``tool_call_id`` are stripped from assistant messages.
    * Without ``"image"`` / ``"audio"``: matching content parts are replaced
      by ``[Image]`` / ``[Audio]`` text parts.
    * Assistant messages left with neither tool calls nor non-blank content
      are dropped, as are entries without a role.

    Args:
        contexts: Messages as dicts or ``Message`` objects.
        modalities: Capabilities declared for the provider. A falsy or
            non-list value means "no restriction": contexts are only copied.

    Returns:
        ``(sanitized_contexts, stats)`` where ``stats`` counts the rewrites.
    """
    if not contexts:
        return [], ContextSanitizeStats()

    supports_image = supports_audio = supports_tool_use = True
    if modalities and isinstance(modalities, list):
        supports_image = "image" in modalities
        supports_audio = "audio" in modalities
        supports_tool_use = "tool_use" in modalities
    if supports_image and supports_audio and supports_tool_use:
        # Nothing to rewrite; still hand back copies so callers can mutate freely.
        return _copy_contexts(contexts), ContextSanitizeStats()

    sanitized_contexts: list[dict[str, Any]] = []
    stats = ContextSanitizeStats()

    for raw_msg in contexts:
        msg = _message_to_dict(raw_msg)
        if not msg:
            continue
        role = msg.get("role")
        if not role:
            continue

        if not supports_tool_use:
            if role == "tool":
                stats.fixed_tool_messages += 1
                msg = {
                    "role": "user",
                    "content": _tool_result_placeholder(msg.get("content")),
                }
            elif role == "assistant" and "tool_calls" in msg:
                # Count only real tool calls: a key that is present but
                # None/empty must not mark the context as changed.
                if msg["tool_calls"]:
                    stats.removed_tool_calls += 1
                msg.pop("tool_calls", None)
                msg.pop("tool_call_id", None)

        if not (supports_image and supports_audio):
            content = msg.get("content")
            if isinstance(content, list):
                filtered_parts: list[Any] = []
                replaced_any = False
                for part in content:
                    if isinstance(part, dict):
                        part_type = str(part.get("type", "")).lower()
                        if not supports_image and part_type in {"image_url", "image"}:
                            replaced_any = True
                            stats.fixed_image_blocks += 1
                            filtered_parts.append({"type": "text", "text": "[Image]"})
                            continue
                        if not supports_audio and part_type in {
                            "audio_url",
                            "input_audio",
                        }:
                            replaced_any = True
                            stats.fixed_audio_blocks += 1
                            filtered_parts.append({"type": "text", "text": "[Audio]"})
                            continue
                    filtered_parts.append(part)
                if replaced_any:
                    msg["content"] = filtered_parts

        # An assistant turn with no tool calls and no usable content carries
        # no information, so it is left out of the result.
        if role == "assistant" and not msg.get("tool_calls"):
            content = msg.get("content")
            if not content:
                continue
            if isinstance(content, str) and not content.strip():
                continue

        sanitized_contexts.append(msg)

    return sanitized_contexts, stats
|
||||
|
||||
|
||||
def _tool_result_placeholder(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
content_text = content.strip()
|
||||
elif isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part_type = str(part.get("type", "")).lower()
|
||||
if part_type == "text":
|
||||
text_parts.append(str(part.get("text", "")))
|
||||
elif part_type in {"image_url", "image"}:
|
||||
text_parts.append("[Image]")
|
||||
elif part_type in {"audio_url", "input_audio"}:
|
||||
text_parts.append("[Audio]")
|
||||
content_text = "\n".join(part for part in text_parts if part).strip()
|
||||
else:
|
||||
content_text = ""
|
||||
if not content_text:
|
||||
return "[Tool result]"
|
||||
return f"[Tool result]\n{content_text}"
|
||||
|
||||
|
||||
def log_context_sanitize_stats(stats: ContextSanitizeStats) -> None:
    """Write the sanitization counters to the debug log.

    Nothing is logged when ``stats.changed`` is false.
    """
    if stats.changed:
        logger.debug(
            "context modality fix applied: "
            "fixed_image_blocks=%s, fixed_audio_blocks=%s, "
            "fixed_tool_messages=%s, removed_tool_calls=%s",
            stats.fixed_image_blocks,
            stats.fixed_audio_blocks,
            stats.fixed_tool_messages,
            stats.removed_tool_calls,
        )
|
||||
@@ -171,7 +171,7 @@ def _decode_escaped_text(value: str) -> str:
|
||||
@dataclass
|
||||
class FileReadTool(FunctionTool):
|
||||
name: str = "astrbot_file_read_tool"
|
||||
description: str = "read file content."
|
||||
description: str = "read file content. Supports text, image, and PDF (text extraction), docx and epub files."
|
||||
parameters: dict = field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
|
||||
@@ -14,6 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")
|
||||
from astrbot.core.agent.agent import Agent
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.message import ImageURLPart, Message, TextPart
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
@@ -156,6 +157,25 @@ class MockErrProvider(MockProvider):
|
||||
)
|
||||
|
||||
|
||||
class CapturingProvider(MockProvider):
    """Mock provider that records what each ``text_chat`` call receives.

    The ``contexts`` and ``func_tool`` keyword arguments of every call are
    appended to ``received_contexts`` / ``received_func_tools``; the reply
    is always a fixed assistant response.
    """

    def __init__(self, modalities: list[str]):
        super().__init__()
        self.should_call_tools = False
        self.received_func_tools = []
        self.received_contexts = []
        self.provider_config["modalities"] = modalities

    async def text_chat(self, **kwargs) -> LLMResponse:
        self.call_count += 1
        captures = (
            (self.received_contexts, "contexts"),
            (self.received_func_tools, "func_tool"),
        )
        for sink, key in captures:
            sink.append(kwargs.get(key))
        usage = TokenUsage(input_other=10, output=5)
        return LLMResponse(
            role="assistant",
            completion_text="final",
            usage=usage,
        )
|
||||
|
||||
|
||||
class MockEmptyOutputThenSuccessProvider(MockProvider):
|
||||
def __init__(self, failures_before_success: int = 1):
|
||||
super().__init__()
|
||||
@@ -615,6 +635,99 @@ async def test_tool_result_includes_all_calltoolresult_content(
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_replaces_runtime_image_context_before_provider_call(
    runner, provider_request, mock_hooks
):
    """An image part added at runtime reaches an image-less provider as ``[Image]``."""
    provider = CapturingProvider(modalities=["tool_use"])
    await runner.reset(
        provider=provider,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=MockToolExecutor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    image_part = ImageURLPart(
        image_url=ImageURLPart.ImageURL(url="data:image/png;base64,dGVzdA==")
    )
    user_message = Message(
        role="user",
        content=[TextPart(text="Review this image"), image_part],
    )
    runner.run_context.messages.append(user_message)

    async for _ in runner.step_until_done(1):
        pass

    assert provider.received_contexts
    first_call_contexts = provider.received_contexts[0]
    expected_content = [
        {"type": "text", "text": "Review this image"},
        {"type": "text", "text": "[Image]"},
    ]
    assert first_call_contexts[-1]["content"] == expected_content
    # The runner's own history must keep both original parts.
    # NOTE(review): index -2 presumably skips the reply appended during the
    # step — confirm against the runner.
    assert len(runner.run_context.messages[-2].content) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_builds_placeholder_for_unsupported_request_image(
    runner, mock_hooks, tool_set
):
    """A request image sent to an image-less provider becomes an ``[Image]`` block."""
    provider = CapturingProvider(modalities=["tool_use"])
    image_request = ProviderRequest(
        prompt="Describe it",
        image_urls=["/path/that/should/not/be/read.jpg"],
        func_tool=tool_set,
        contexts=[],
    )
    await runner.reset(
        provider=provider,
        request=image_request,
        run_context=ContextWrapper(context=None),
        tool_executor=MockToolExecutor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    async for _ in runner.step_until_done(1):
        pass

    first_call_contexts = provider.received_contexts[0]
    expected_content = [
        {"type": "text", "text": "Describe it"},
        {"type": "text", "text": "[Image]"},
    ]
    assert first_call_contexts[-1]["content"] == expected_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_clears_tools_for_provider_without_tool_use(
    runner, provider_request, mock_hooks, mock_tool_executor
):
    """A provider whose modalities lack ``tool_use`` is called with ``func_tool=None``."""
    text_only_provider = CapturingProvider(modalities=["text"])
    await runner.reset(
        provider=text_only_provider,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    async for _ in runner.step_until_done(1):
        pass

    # Exactly one provider call, and it carried no tools.
    assert text_only_provider.received_func_tools == [None]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_tool_consecutive_results_include_escalating_guidance(
|
||||
runner, mock_tool_executor, mock_hooks
|
||||
|
||||
@@ -713,144 +713,6 @@ class TestDecorateLlmRequest:
|
||||
assert req.prompt == "Hello"
|
||||
|
||||
|
||||
class TestModalitiesFix:
|
||||
"""Tests for _modalities_fix function."""
|
||||
|
||||
def test_modalities_fix_image_not_supported(self, mock_provider):
|
||||
"""Test modality fix when image is not supported."""
|
||||
module = ama
|
||||
mock_provider.provider_config = {"modalities": ["text"]}
|
||||
req = ProviderRequest(prompt="Hello", image_urls=["/path/to/image.jpg"])
|
||||
|
||||
module._modalities_fix(mock_provider, req)
|
||||
|
||||
assert "[Image]" in req.prompt
|
||||
assert req.image_urls == []
|
||||
|
||||
def test_modalities_fix_tool_not_supported(self, mock_provider):
|
||||
"""Test modality fix when tool is not supported."""
|
||||
module = ama
|
||||
mock_provider.provider_config = {"modalities": ["text", "image"]}
|
||||
req = ProviderRequest(prompt="Hello")
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(
|
||||
FunctionTool(
|
||||
name="dummy_tool",
|
||||
description="dummy",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
)
|
||||
)
|
||||
|
||||
module._modalities_fix(mock_provider, req)
|
||||
|
||||
assert req.func_tool is None
|
||||
|
||||
def test_modalities_fix_all_supported(self, mock_provider):
|
||||
"""Test modality fix when all features are supported."""
|
||||
module = ama
|
||||
mock_provider.provider_config = {"modalities": ["image", "tool_use"]}
|
||||
tool_set = ToolSet()
|
||||
tool_set.add_tool(
|
||||
FunctionTool(
|
||||
name="dummy_tool",
|
||||
description="dummy",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
)
|
||||
)
|
||||
req = ProviderRequest(
|
||||
prompt="Hello",
|
||||
image_urls=["/path/to/image.jpg"],
|
||||
func_tool=tool_set,
|
||||
)
|
||||
|
||||
module._modalities_fix(mock_provider, req)
|
||||
|
||||
assert req.prompt == "Hello"
|
||||
assert len(req.image_urls) == 1
|
||||
assert req.func_tool is not None
|
||||
|
||||
|
||||
class TestSanitizeContextByModalities:
|
||||
"""Tests for _sanitize_context_by_modalities function."""
|
||||
|
||||
def test_sanitize_no_op(self, mock_provider):
|
||||
"""Test sanitize when disabled or modalities support everything."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
tool_call_timeout=60, sanitize_context_by_modalities=False
|
||||
)
|
||||
mock_provider.provider_config = {"modalities": ["image", "tool_use"]}
|
||||
req = ProviderRequest(contexts=[{"role": "user", "content": "Hello"}])
|
||||
|
||||
module._sanitize_context_by_modalities(config, mock_provider, req)
|
||||
|
||||
assert len(req.contexts) == 1
|
||||
|
||||
def test_sanitize_removes_tool_messages(self, mock_provider):
|
||||
"""Test sanitize removes tool messages when tool_use not supported."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
tool_call_timeout=60, sanitize_context_by_modalities=True
|
||||
)
|
||||
mock_provider.provider_config = {"modalities": ["image"]}
|
||||
req = ProviderRequest(
|
||||
contexts=[
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "tool", "content": "Tool result"},
|
||||
]
|
||||
)
|
||||
|
||||
module._sanitize_context_by_modalities(config, mock_provider, req)
|
||||
|
||||
assert len(req.contexts) == 1
|
||||
assert req.contexts[0]["role"] == "user"
|
||||
|
||||
def test_sanitize_removes_tool_calls(self, mock_provider):
|
||||
"""Test sanitize removes tool_calls from assistant messages."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
tool_call_timeout=60, sanitize_context_by_modalities=True
|
||||
)
|
||||
mock_provider.provider_config = {"modalities": ["image"]}
|
||||
req = ProviderRequest(
|
||||
contexts=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Response",
|
||||
"tool_calls": [{"name": "tool"}],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
module._sanitize_context_by_modalities(config, mock_provider, req)
|
||||
|
||||
assert "tool_calls" not in req.contexts[0]
|
||||
|
||||
def test_sanitize_removes_image_blocks(self, mock_provider):
|
||||
"""Test sanitize removes image blocks when image not supported."""
|
||||
module = ama
|
||||
config = module.MainAgentBuildConfig(
|
||||
tool_call_timeout=60, sanitize_context_by_modalities=True
|
||||
)
|
||||
mock_provider.provider_config = {"modalities": ["tool_use"]}
|
||||
req = ProviderRequest(
|
||||
contexts=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "image_url", "url": "image.jpg"},
|
||||
],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
module._sanitize_context_by_modalities(config, mock_provider, req)
|
||||
|
||||
assert len(req.contexts[0]["content"]) == 1
|
||||
assert req.contexts[0]["content"][0]["type"] == "text"
|
||||
|
||||
|
||||
class TestPluginToolFix:
|
||||
"""Tests for _plugin_tool_fix function."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user