mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-03 19:20:16 +08:00
Compare commits
11 Commits
feat/multi
...
fix/reason
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
47ea036a81 | ||
|
|
643a8b177e | ||
|
|
07b37b98de | ||
|
|
bbda1e678f | ||
|
|
3c1d0cd2c2 | ||
|
|
d16ed4e552 | ||
|
|
55c1558686 | ||
|
|
17aea1aa2c | ||
|
|
d4cdeeae72 | ||
|
|
5ce02da6df | ||
|
|
5d79c99938 |
@@ -183,10 +183,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.stats.end_time = time.time()
|
||||
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
if llm_resp.reasoning_content is not None or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
think=llm_resp.reasoning_content or "",
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
@@ -876,10 +876,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
# 将结果添加到上下文中
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
if llm_resp.reasoning_content is not None or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
think=llm_resp.reasoning_content or "",
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
@@ -1361,10 +1361,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.stats.end_time = time.time()
|
||||
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
if llm_resp.reasoning_content is not None or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
think=llm_resp.reasoning_content or "",
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -77,6 +77,8 @@ from astrbot.core.tools.web_search_tools import (
|
||||
BaiduWebSearchTool,
|
||||
BochaWebSearchTool,
|
||||
BraveWebSearchTool,
|
||||
FirecrawlExtractWebPageTool,
|
||||
FirecrawlWebSearchTool,
|
||||
TavilyExtractWebPageTool,
|
||||
TavilyWebSearchTool,
|
||||
normalize_legacy_web_search_config,
|
||||
@@ -1047,6 +1049,9 @@ async def _apply_web_search_tools(
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BochaWebSearchTool))
|
||||
elif provider == "brave":
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BraveWebSearchTool))
|
||||
elif provider == "firecrawl":
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlWebSearchTool))
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlExtractWebPageTool))
|
||||
elif provider == "baidu_ai_search":
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool))
|
||||
|
||||
|
||||
@@ -3202,6 +3202,7 @@ CONFIG_METADATA_3 = {
|
||||
"baidu_ai_search",
|
||||
"bocha",
|
||||
"brave",
|
||||
"firecrawl",
|
||||
],
|
||||
"condition": {
|
||||
"provider_settings.web_search": True,
|
||||
@@ -3237,6 +3238,16 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_firecrawl_key": {
|
||||
"description": "Firecrawl API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可添加多个 Key 进行轮询。",
|
||||
"condition": {
|
||||
"provider_settings.websearch_provider": "firecrawl",
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_baidu_app_builder_key": {
|
||||
"description": "百度千帆智能云 APP Builder API Key",
|
||||
"type": "string",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
@@ -140,6 +141,8 @@ class WecomServer:
|
||||
|
||||
@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
|
||||
class WecomPlatformAdapter(Platform):
|
||||
WECHAT_KF_TEXT_CONTENT_DEDUP_TTL_SECONDS = 15
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
platform_config: dict,
|
||||
@@ -166,6 +169,7 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
self.server = WecomServer(self._event_queue, self.config)
|
||||
self.agent_id: str | None = None
|
||||
self._wechat_kf_seen_text_messages: dict[str, float] = {}
|
||||
|
||||
self.client = WeChatClient(
|
||||
self.config["corpid"].strip(),
|
||||
@@ -210,6 +214,28 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
self.server.callback = callback
|
||||
|
||||
def _is_duplicate_wechat_kf_text_message(self, session_id: str, text: str) -> bool:
|
||||
normalized_text = text.strip()
|
||||
if not normalized_text:
|
||||
return False
|
||||
|
||||
now = time.monotonic()
|
||||
expired_keys = [
|
||||
key
|
||||
for key, expires_at in self._wechat_kf_seen_text_messages.items()
|
||||
if expires_at <= now
|
||||
]
|
||||
for key in expired_keys:
|
||||
self._wechat_kf_seen_text_messages.pop(key, None)
|
||||
|
||||
dedup_key = f"{session_id}:{normalized_text}"
|
||||
if dedup_key in self._wechat_kf_seen_text_messages:
|
||||
return True
|
||||
self._wechat_kf_seen_text_messages[dedup_key] = (
|
||||
now + self.WECHAT_KF_TEXT_CONTENT_DEDUP_TTL_SECONDS
|
||||
)
|
||||
return False
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self,
|
||||
@@ -390,6 +416,13 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.message_str = ""
|
||||
if msgtype == "text":
|
||||
text = msg.get("text", {}).get("content", "").strip()
|
||||
if self._is_duplicate_wechat_kf_text_message(abm.session_id, text):
|
||||
logger.debug(
|
||||
"忽略 15 秒内重复微信客服文本消息 session_id=%s text=%s",
|
||||
abm.session_id,
|
||||
text,
|
||||
)
|
||||
return None
|
||||
abm.message = [Plain(text=text)]
|
||||
abm.message_str = text
|
||||
elif msgtype == "image":
|
||||
|
||||
@@ -353,7 +353,7 @@ class LLMResponse:
|
||||
"""Tool call IDs."""
|
||||
tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
"""Tool call extra content. tool_call_id -> extra_content dict"""
|
||||
reasoning_content: str = ""
|
||||
reasoning_content: str | None = None
|
||||
"""The reasoning content extracted from the LLM, if any."""
|
||||
reasoning_signature: str | None = None
|
||||
"""The signature of the reasoning content, if any."""
|
||||
@@ -404,8 +404,6 @@ class LLMResponse:
|
||||
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||
|
||||
"""
|
||||
if reasoning_content is None:
|
||||
reasoning_content = ""
|
||||
if tools_call_args is None:
|
||||
tools_call_args = []
|
||||
if tools_call_name is None:
|
||||
|
||||
@@ -39,7 +39,7 @@ class ProviderAnthropic(Provider):
|
||||
stop_reason: str | None = None,
|
||||
) -> None:
|
||||
has_text_output = bool((llm_response.completion_text or "").strip())
|
||||
has_reasoning_output = bool(llm_response.reasoning_content.strip())
|
||||
has_reasoning_output = bool((llm_response.reasoning_content or "").strip())
|
||||
has_tool_output = bool(llm_response.tools_call_args)
|
||||
if has_text_output or has_reasoning_output or has_tool_output:
|
||||
return
|
||||
|
||||
@@ -462,7 +462,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
finish_reason: str | None = None,
|
||||
) -> None:
|
||||
has_text_output = bool((llm_response.completion_text or "").strip())
|
||||
has_reasoning_output = bool(llm_response.reasoning_content.strip())
|
||||
has_reasoning_output = bool((llm_response.reasoning_content or "").strip())
|
||||
has_tool_output = bool(llm_response.tools_call_args)
|
||||
if has_text_output or has_reasoning_output or has_tool_output:
|
||||
return
|
||||
|
||||
@@ -65,7 +65,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
self.audio_setting: dict = {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": "mp3",
|
||||
"format": "wav",
|
||||
}
|
||||
|
||||
self.concat_base_url: str = f"{self.api_base}?GroupId={self.group_id}"
|
||||
@@ -147,7 +147,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
|
||||
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.wav")
|
||||
|
||||
try:
|
||||
# 直接将异步生成器传递给 _audio_play 方法
|
||||
|
||||
@@ -519,6 +519,42 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_assistant_messages(payloads: dict) -> None:
|
||||
"""在请求发送前过滤/规范化空的 assistant 消息。
|
||||
|
||||
严格 API(Moonshot、DeepSeek Reasoner 等)会在 assistant 消息同时缺少
|
||||
``content`` 和 ``tool_calls`` 时返回 400。把 ``""`` / ``None`` / ``[]``
|
||||
都视作空内容:无 tool_calls 时整条过滤掉;有 tool_calls 时将 content
|
||||
设为 ``None`` 以符合 OpenAI 规范。就地修改 ``payloads["messages"]``。
|
||||
"""
|
||||
messages = payloads.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return
|
||||
|
||||
def _is_empty(content: Any) -> bool:
|
||||
return content is None or content == "" or content == []
|
||||
|
||||
cleaned: list[Any] = []
|
||||
for idx, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict) or msg.get("role") != "assistant":
|
||||
cleaned.append(msg)
|
||||
continue
|
||||
|
||||
content = msg.get("content")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
|
||||
if _is_empty(content) and not tool_calls:
|
||||
logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
|
||||
continue
|
||||
|
||||
if _is_empty(content) and tool_calls:
|
||||
msg["content"] = None
|
||||
|
||||
cleaned.append(msg)
|
||||
|
||||
payloads["messages"] = cleaned
|
||||
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
model = payloads.get("model", "").lower()
|
||||
@@ -548,26 +584,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
model = payloads.get("model", "").lower()
|
||||
|
||||
if "messages" in payloads and isinstance(payloads["messages"], list):
|
||||
cleaned_messages = []
|
||||
for idx, msg in enumerate(payloads["messages"]):
|
||||
# 过滤空的 assistant 消息,防止严格 API(如 Moonshot)返回 400 错误
|
||||
if msg.get("role") == "assistant":
|
||||
content = msg.get("content")
|
||||
tool_calls = msg.get("tool_calls")
|
||||
|
||||
# 情况1: 空/null content 且无 tool_calls -> 过滤掉
|
||||
if not tool_calls and (content == "" or content is None):
|
||||
logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
|
||||
continue
|
||||
|
||||
# 情况2: 空 content 但有 tool_calls -> 设为 None (符合 OpenAI 规范)
|
||||
if content == "" and tool_calls:
|
||||
msg["content"] = None
|
||||
|
||||
cleaned_messages.append(msg)
|
||||
|
||||
payloads["messages"] = cleaned_messages
|
||||
self._sanitize_assistant_messages(payloads)
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
@@ -619,6 +636,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
del payloads[key]
|
||||
self._apply_provider_specific_extra_body_overrides(extra_body)
|
||||
|
||||
self._sanitize_assistant_messages(payloads)
|
||||
|
||||
stream = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=True,
|
||||
@@ -652,9 +671,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
reasoning = self._extract_reasoning_content(chunk)
|
||||
_y = False
|
||||
llm_response.id = chunk.id
|
||||
llm_response.reasoning_content = ""
|
||||
llm_response.reasoning_content = None
|
||||
llm_response.completion_text = ""
|
||||
if reasoning:
|
||||
if reasoning is not None:
|
||||
llm_response.reasoning_content = reasoning
|
||||
_y = True
|
||||
if delta and delta.content:
|
||||
@@ -682,22 +701,28 @@ class ProviderOpenAIOfficial(Provider):
|
||||
def _extract_reasoning_content(
|
||||
self,
|
||||
completion: ChatCompletion | ChatCompletionChunk,
|
||||
) -> str:
|
||||
) -> str | None:
|
||||
"""Extract reasoning content from OpenAI ChatCompletion if available."""
|
||||
reasoning_text = ""
|
||||
|
||||
def _get_reasoning_attr(obj: Any) -> str | None:
|
||||
fields_set = getattr(obj, "model_fields_set", None)
|
||||
if isinstance(fields_set, set) and self.reasoning_key in fields_set:
|
||||
attr = getattr(obj, self.reasoning_key, "")
|
||||
return "" if attr is None else str(attr)
|
||||
attr = getattr(obj, self.reasoning_key, None)
|
||||
return None if attr is None else str(attr)
|
||||
|
||||
if not completion.choices:
|
||||
return reasoning_text
|
||||
return None
|
||||
if isinstance(completion, ChatCompletion):
|
||||
choice = completion.choices[0]
|
||||
reasoning_attr = getattr(choice.message, self.reasoning_key, None)
|
||||
if reasoning_attr:
|
||||
reasoning_text = str(reasoning_attr)
|
||||
reasoning_attr = _get_reasoning_attr(choice.message)
|
||||
elif isinstance(completion, ChatCompletionChunk):
|
||||
delta = completion.choices[0].delta
|
||||
reasoning_attr = getattr(delta, self.reasoning_key, None)
|
||||
if reasoning_attr:
|
||||
reasoning_text = str(reasoning_attr)
|
||||
return reasoning_text
|
||||
reasoning_attr = _get_reasoning_attr(delta)
|
||||
else:
|
||||
return None
|
||||
return reasoning_attr
|
||||
|
||||
def _extract_usage(self, usage: CompletionUsage | dict) -> TokenUsage:
|
||||
ptd = getattr(usage, "prompt_tokens_details", None)
|
||||
@@ -840,7 +865,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
# parse the reasoning content if any
|
||||
# the priority is higher than the <think> tag extraction
|
||||
llm_response.reasoning_content = self._extract_reasoning_content(completion)
|
||||
reasoning_content = self._extract_reasoning_content(completion)
|
||||
if reasoning_content is not None:
|
||||
llm_response.reasoning_content = reasoning_content
|
||||
|
||||
# parse tool calls if any
|
||||
if choice.message.tool_calls and tools is not None:
|
||||
@@ -887,7 +914,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
|
||||
)
|
||||
has_text_output = bool((llm_response.completion_text or "").strip())
|
||||
has_reasoning_output = bool(llm_response.reasoning_content.strip())
|
||||
has_reasoning_output = bool((llm_response.reasoning_content or "").strip())
|
||||
if (
|
||||
not has_text_output
|
||||
and not has_reasoning_output
|
||||
@@ -963,24 +990,39 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"""Finally convert the payload. Such as think part conversion, tool inject."""
|
||||
model = payloads.get("model", "").lower()
|
||||
is_gemini = "gemini" in model
|
||||
|
||||
deepseek_reasoning_models = {"deepseek-v4-pro", "deepseek-v4-flash"}
|
||||
is_deepseek_v4_reasoning = (
|
||||
model in deepseek_reasoning_models
|
||||
or "api.deepseek.com" in self.client.base_url.host
|
||||
)
|
||||
for message in payloads.get("messages", []):
|
||||
if message.get("role") == "assistant" and isinstance(
|
||||
message.get("content"), list
|
||||
):
|
||||
reasoning_content = ""
|
||||
reasoning_content_present = False
|
||||
new_content = [] # not including think part
|
||||
for part in message["content"]:
|
||||
if part.get("type") == "think":
|
||||
reasoning_content_present = True
|
||||
reasoning_content += str(part.get("think"))
|
||||
else:
|
||||
new_content.append(part)
|
||||
# Some providers (Grok, etc.) reject empty content lists.
|
||||
# When all parts were think blocks, fall back to None.
|
||||
message["content"] = new_content or None
|
||||
if reasoning_content:
|
||||
if reasoning_content_present:
|
||||
message["reasoning_content"] = reasoning_content
|
||||
|
||||
if (
|
||||
message.get("role") == "assistant"
|
||||
and is_deepseek_v4_reasoning
|
||||
and "reasoning_content" not in message
|
||||
):
|
||||
# DeepSeek v4 reasoning models require the field on assistant
|
||||
# history messages, even when the reasoning content is empty.
|
||||
message["reasoning_content"] = ""
|
||||
|
||||
# Gemini 的 function_response 要求 google.protobuf.Struct(即 JSON 对象),
|
||||
# 纯文本会触发 400 Invalid argument,需要包一层 JSON。
|
||||
if is_gemini and message.get("role") == "tool":
|
||||
|
||||
@@ -20,3 +20,4 @@ class ProviderOpenRouter(ProviderOpenAIOfficial):
|
||||
self.client._custom_headers["X-OpenRouter-Categories"] = (
|
||||
"general-chat,personal-agent" # type: ignore
|
||||
)
|
||||
self.reasoning_key = "reasoning"
|
||||
|
||||
@@ -43,7 +43,7 @@ from astrbot.core.agent.tool import ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.computer.computer_client import get_booter
|
||||
from astrbot.core.computer.file_read_utils import read_file_tool_result
|
||||
from astrbot.core.message.components import File
|
||||
from astrbot.core.message.components import File, Image
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_skills_path,
|
||||
get_astrbot_system_tmp_path,
|
||||
@@ -64,6 +64,7 @@ _COMPUTER_RUNTIME_TOOL_CONFIG = {
|
||||
_SANDBOX_RUNTIME_TOOL_CONFIG = {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
}
|
||||
_IMAGE_FILE_SUFFIXES = {".bmp", ".gif", ".jpeg", ".jpg", ".png", ".webp"}
|
||||
|
||||
|
||||
def _restricted_env_path_labels(umo: str) -> list[str]:
|
||||
@@ -729,11 +730,21 @@ class FileDownloadTool(FunctionTool):
|
||||
if also_send_to_user:
|
||||
try:
|
||||
name = os.path.basename(local_path)
|
||||
if Path(local_path).suffix.lower() in _IMAGE_FILE_SUFFIXES:
|
||||
message_component = Image.fromFileSystem(local_path)
|
||||
sent_as = "image"
|
||||
else:
|
||||
message_component = File(name=name, file=local_path)
|
||||
sent_as = "file"
|
||||
await context.context.event.send(
|
||||
MessageChain(chain=[File(name=name, file=local_path)])
|
||||
MessageChain(chain=[message_component])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending file message: {e}")
|
||||
return (
|
||||
f"File downloaded successfully to {local_path} "
|
||||
f"but sending to user failed: {e}"
|
||||
)
|
||||
|
||||
# remove
|
||||
# try:
|
||||
@@ -741,7 +752,10 @@ class FileDownloadTool(FunctionTool):
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error removing temp file {local_path}: {e}")
|
||||
|
||||
return f"File downloaded successfully to {local_path} and sent to user."
|
||||
return (
|
||||
f"File downloaded successfully to {local_path} "
|
||||
f"and sent to user as {sent_as}."
|
||||
)
|
||||
|
||||
return f"File downloaded successfully to {local_path}"
|
||||
except Exception as e:
|
||||
|
||||
@@ -19,6 +19,8 @@ WEB_SEARCH_TOOL_NAMES = [
|
||||
"tavily_extract_web_page",
|
||||
"web_search_bocha",
|
||||
"web_search_brave",
|
||||
"web_search_firecrawl",
|
||||
"firecrawl_extract_web_page",
|
||||
]
|
||||
_TAVILY_WEB_SEARCH_TOOL_CONFIG = {
|
||||
"provider_settings.web_search": True,
|
||||
@@ -32,6 +34,10 @@ _BRAVE_WEB_SEARCH_TOOL_CONFIG = {
|
||||
"provider_settings.web_search": True,
|
||||
"provider_settings.websearch_provider": "brave",
|
||||
}
|
||||
_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG = {
|
||||
"provider_settings.web_search": True,
|
||||
"provider_settings.websearch_provider": "firecrawl",
|
||||
}
|
||||
_BAIDU_WEB_SEARCH_TOOL_CONFIG = {
|
||||
"provider_settings.web_search": True,
|
||||
"provider_settings.websearch_provider": "baidu_ai_search",
|
||||
@@ -69,6 +75,7 @@ class _KeyRotator:
|
||||
_TAVILY_KEY_ROTATOR = _KeyRotator("websearch_tavily_key", "Tavily")
|
||||
_BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha")
|
||||
_BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave")
|
||||
_FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl")
|
||||
|
||||
|
||||
def normalize_legacy_web_search_config(cfg) -> None:
|
||||
@@ -91,6 +98,7 @@ def normalize_legacy_web_search_config(cfg) -> None:
|
||||
"websearch_tavily_key",
|
||||
"websearch_bocha_key",
|
||||
"websearch_brave_key",
|
||||
"websearch_firecrawl_key",
|
||||
):
|
||||
value = provider_settings.get(setting_name)
|
||||
if isinstance(value, str):
|
||||
@@ -258,6 +266,72 @@ async def _brave_search(
|
||||
]
|
||||
|
||||
|
||||
async def _firecrawl_search(
|
||||
provider_settings: dict,
|
||||
payload: dict,
|
||||
) -> list[SearchResult]:
|
||||
firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings)
|
||||
header = {
|
||||
"Authorization": f"Bearer {firecrawl_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(
|
||||
"https://api.firecrawl.dev/v2/search",
|
||||
json=payload,
|
||||
headers=header,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
reason = await response.text()
|
||||
raise Exception(
|
||||
f"Firecrawl web search failed: {reason}, status: {response.status}",
|
||||
)
|
||||
data = await response.json()
|
||||
rows = data.get("data", [])
|
||||
if isinstance(rows, dict):
|
||||
rows = rows.get("web", [])
|
||||
return [
|
||||
SearchResult(
|
||||
title=item.get("title", ""),
|
||||
url=item.get("url", ""),
|
||||
snippet=(
|
||||
item.get("description")
|
||||
or item.get("snippet")
|
||||
or item.get("markdown")
|
||||
or ""
|
||||
),
|
||||
)
|
||||
for item in rows
|
||||
if item.get("url")
|
||||
]
|
||||
|
||||
|
||||
async def _firecrawl_scrape(provider_settings: dict, payload: dict) -> dict:
|
||||
firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings)
|
||||
header = {
|
||||
"Authorization": f"Bearer {firecrawl_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(
|
||||
"https://api.firecrawl.dev/v2/scrape",
|
||||
json=payload,
|
||||
headers=header,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
reason = await response.text()
|
||||
raise Exception(
|
||||
f"Firecrawl web scraper failed: {reason}, status: {response.status}",
|
||||
)
|
||||
data = await response.json()
|
||||
result = data.get("data", {})
|
||||
if not result:
|
||||
raise ValueError(
|
||||
"Error: Firecrawl web scraper does not return any results."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _baidu_search(
|
||||
provider_settings: dict,
|
||||
payload: dict,
|
||||
@@ -548,6 +622,124 @@ class BraveWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
return _search_result_payload(results)
|
||||
|
||||
|
||||
@builtin_tool(config=_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG)
|
||||
@pydantic_dataclass
|
||||
class FirecrawlWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "web_search_firecrawl"
|
||||
description: str = (
|
||||
"A web search tool based on Firecrawl Search API, used to retrieve web "
|
||||
"pages related to the user's query."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Required. Search query."},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Optional. Number of results to return. Range: 1-100. Default is 5.",
|
||||
},
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "Optional. Geographic location for search results.",
|
||||
},
|
||||
"country": {
|
||||
"type": "string",
|
||||
"description": 'Optional. Country code for search results, for example "US" or "CN".',
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Optional. Request timeout in milliseconds.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(self, context, **kwargs) -> ToolExecResult:
|
||||
_, provider_settings, _ = _get_runtime(context)
|
||||
if not provider_settings.get("websearch_firecrawl_key", []):
|
||||
return "Error: Firecrawl API key is not configured in AstrBot."
|
||||
|
||||
payload = {
|
||||
"query": kwargs["query"],
|
||||
"limit": kwargs.get("limit", 5),
|
||||
"sources": ["web"],
|
||||
}
|
||||
for key in ("location", "country", "timeout"):
|
||||
if kwargs.get(key):
|
||||
payload[key] = kwargs[key]
|
||||
|
||||
results = await _firecrawl_search(provider_settings, payload)
|
||||
if not results:
|
||||
return "Error: Firecrawl web searcher does not return any results."
|
||||
return _search_result_payload(results)
|
||||
|
||||
|
||||
@builtin_tool(config=_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG)
|
||||
@pydantic_dataclass
|
||||
class FirecrawlExtractWebPageTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "firecrawl_extract_web_page"
|
||||
description: str = "Extract the content of a web page using Firecrawl."
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "Required. A URL to extract content from.",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"description": 'Optional. Output format, one of "markdown", "html", "rawHtml", "summary". Default is "markdown".',
|
||||
},
|
||||
"only_main_content": {
|
||||
"type": "boolean",
|
||||
"description": "Optional. Whether to extract only the main page content. Default is true.",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Optional. Request timeout in milliseconds.",
|
||||
},
|
||||
"max_age": {
|
||||
"type": "integer",
|
||||
"description": "Optional. Maximum cache age in milliseconds.",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(self, context, **kwargs) -> ToolExecResult:
|
||||
_, provider_settings, _ = _get_runtime(context)
|
||||
if not provider_settings.get("websearch_firecrawl_key", []):
|
||||
return "Error: Firecrawl API key is not configured in AstrBot."
|
||||
|
||||
url = str(kwargs.get("url", "")).strip()
|
||||
if not url:
|
||||
return "Error: url must be a non-empty string."
|
||||
|
||||
output_format = kwargs.get("format", "markdown")
|
||||
if output_format not in ["markdown", "html", "rawHtml", "summary"]:
|
||||
output_format = "markdown"
|
||||
|
||||
payload = {
|
||||
"url": url,
|
||||
"formats": [output_format],
|
||||
"onlyMainContent": kwargs.get("only_main_content", True),
|
||||
}
|
||||
if kwargs.get("timeout"):
|
||||
payload["timeout"] = kwargs["timeout"]
|
||||
if kwargs.get("max_age"):
|
||||
payload["maxAge"] = kwargs["max_age"]
|
||||
|
||||
result = await _firecrawl_scrape(provider_settings, payload)
|
||||
content = result.get(output_format, "")
|
||||
result_url = result.get("url") or url
|
||||
ret = f"URL: {result_url}\nContent: {content}" if content else ""
|
||||
return ret or "Error: Firecrawl web scraper does not return any results."
|
||||
|
||||
|
||||
@builtin_tool(config=_BAIDU_WEB_SEARCH_TOOL_CONFIG)
|
||||
@pydantic_dataclass
|
||||
class BaiduWebSearchTool(FunctionTool[AstrAgentContext]):
|
||||
|
||||
@@ -436,19 +436,30 @@ async def compress_image(
|
||||
optimize = IMAGE_COMPRESS_DEFAULT_OPTIMIZE
|
||||
min_file_size_bytes = int(IMAGE_COMPRESS_DEFAULT_MIN_FILE_SIZE_MB * 1024 * 1024)
|
||||
data = None
|
||||
|
||||
def _exceeds_max_size(source: bytes | Path) -> bool:
|
||||
try:
|
||||
fp = io.BytesIO(source) if isinstance(source, bytes) else source
|
||||
with PILImage.open(fp) as opened_img:
|
||||
return max(opened_img.size) > max_size
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
# Skip compression for remote images and return the original value.
|
||||
if url_or_path.startswith("http"):
|
||||
return url_or_path
|
||||
elif url_or_path.startswith("data:image"):
|
||||
_header, encoded = url_or_path.split(",", 1)
|
||||
data = base64.b64decode(encoded)
|
||||
if len(data) < min_file_size_bytes:
|
||||
if len(data) < min_file_size_bytes and not _exceeds_max_size(data):
|
||||
return url_or_path
|
||||
else:
|
||||
local_path = Path(url_or_path)
|
||||
if not local_path.exists():
|
||||
return url_or_path
|
||||
if local_path.stat().st_size < min_file_size_bytes:
|
||||
if local_path.stat().st_size < min_file_size_bytes and not _exceeds_max_size(
|
||||
local_path
|
||||
):
|
||||
return url_or_path
|
||||
with local_path.open("rb") as f:
|
||||
data = f.read()
|
||||
|
||||
@@ -5,8 +5,9 @@ import ssl
|
||||
import httpx
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.utils.http_ssl_common import build_ssl_context_with_certifi
|
||||
|
||||
_SYSTEM_SSL_CTX = ssl.create_default_context()
|
||||
_SYSTEM_SSL_CTX = build_ssl_context_with_certifi()
|
||||
|
||||
|
||||
def is_connection_error(exc: BaseException) -> bool:
|
||||
@@ -92,9 +93,9 @@ def create_proxy_client(
|
||||
) -> httpx.AsyncClient:
|
||||
"""Create an httpx AsyncClient with proxy configuration if provided.
|
||||
|
||||
Uses the system SSL certificate store instead of certifi, which avoids
|
||||
SSL verification failures for endpoints whose CA chain is not in certifi
|
||||
but is trusted by the operating system.
|
||||
Uses a hybrid SSL context that combines the system SSL certificate store
|
||||
with certifi as a fallback, ensuring compatibility across different
|
||||
environments including Windows where the system store may be incomplete.
|
||||
|
||||
Note: The caller is responsible for closing the client when done.
|
||||
Consider using the client as a context manager or calling aclose() explicitly.
|
||||
@@ -103,11 +104,11 @@ def create_proxy_client(
|
||||
provider_label: The provider name for log prefix (e.g., "OpenAI", "Gemini")
|
||||
proxy: The proxy address (e.g., "http://127.0.0.1:7890"), or None/empty
|
||||
headers: Optional custom headers to include in every request
|
||||
verify: Optional override for TLS verification. Defaults to the shared
|
||||
system SSL context when not provided.
|
||||
verify: Optional override for TLS verification. Defaults to the hybrid
|
||||
SSL context (system store + certifi) when not provided.
|
||||
|
||||
Returns:
|
||||
An httpx.AsyncClient created with the shared system SSL context; the proxy is applied only if one is provided.
|
||||
An httpx.AsyncClient created with the hybrid SSL context (system store + certifi); the proxy is applied only if one is provided.
|
||||
"""
|
||||
resolved_verify = _SYSTEM_SSL_CTX if verify is None else verify
|
||||
if proxy:
|
||||
|
||||
@@ -303,7 +303,7 @@ export default {
|
||||
part.tool_calls.forEach(toolCall => {
|
||||
// 检查是否是支持引用解析的 web_search 工具调用
|
||||
if (
|
||||
!['web_search_baidu', 'web_search_tavily', 'web_search_bocha', 'web_search_brave'].includes(toolCall.name) ||
|
||||
!['web_search_baidu', 'web_search_tavily', 'web_search_bocha', 'web_search_brave', 'web_search_firecrawl'].includes(toolCall.name) ||
|
||||
!toolCall.result
|
||||
) {
|
||||
return;
|
||||
|
||||
@@ -125,6 +125,10 @@
|
||||
"description": "Brave Search API Key",
|
||||
"hint": "Multiple keys can be added for rotation."
|
||||
},
|
||||
"websearch_firecrawl_key": {
|
||||
"description": "Firecrawl API Key",
|
||||
"hint": "Multiple keys can be added for rotation."
|
||||
},
|
||||
"websearch_baidu_app_builder_key": {
|
||||
"description": "Baidu Qianfan Smart Cloud APP Builder API Key",
|
||||
"hint": "Reference: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"
|
||||
|
||||
@@ -125,6 +125,10 @@
|
||||
"description": "API-ключ Brave Search",
|
||||
"hint": "Можно добавить несколько ключей для ротации."
|
||||
},
|
||||
"websearch_firecrawl_key": {
|
||||
"description": "API-ключ Firecrawl",
|
||||
"hint": "Можно добавить несколько ключей для ротации."
|
||||
},
|
||||
"websearch_baidu_app_builder_key": {
|
||||
"description": "API-ключ Baidu Qianfan APP Builder",
|
||||
"hint": "Ссылка: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"
|
||||
|
||||
@@ -127,6 +127,10 @@
|
||||
"description": "Brave Search API Key",
|
||||
"hint": "可添加多个 Key 进行轮询。"
|
||||
},
|
||||
"websearch_firecrawl_key": {
|
||||
"description": "Firecrawl API Key",
|
||||
"hint": "可添加多个 Key 进行轮询。"
|
||||
},
|
||||
"websearch_baidu_app_builder_key": {
|
||||
"description": "百度千帆智能云 APP Builder API Key",
|
||||
"hint": "参考:[https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"
|
||||
|
||||
@@ -1618,3 +1618,109 @@ async def test_query_does_not_filter_user_or_system_messages(monkeypatch):
|
||||
assert messages[2] == {"role": "user", "content": "hello"}
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_stream_filters_empty_assistant_message(monkeypatch):
|
||||
"""Regression for #7721: streaming path must also filter empty assistant messages.
|
||||
|
||||
Previously only ``_query`` sanitized the payload; ``_query_stream`` forwarded
|
||||
the raw history and strict providers (e.g. DeepSeek Reasoner) returned 400 on
|
||||
the next turn after a tool call whose assistant entry had reasoning only.
|
||||
"""
|
||||
provider = _make_provider()
|
||||
try:
|
||||
captured_kwargs = {}
|
||||
|
||||
async def fake_stream():
|
||||
yield ChatCompletionChunk.model_validate(
|
||||
{
|
||||
"id": "chatcmpl-stream",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 0,
|
||||
"model": "deepseek-reasoner",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return fake_stream()
|
||||
|
||||
monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)
|
||||
|
||||
payloads = {
|
||||
"model": "deepseek-reasoner",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": ""}, # should be filtered
|
||||
{"role": "user", "content": "world"},
|
||||
],
|
||||
}
|
||||
|
||||
async for _ in provider._query_stream(payloads=payloads, tools=None):
|
||||
pass
|
||||
|
||||
messages = captured_kwargs["messages"]
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "hello"}
|
||||
assert messages[1] == {"role": "user", "content": "world"}
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_filters_empty_list_content_assistant_message(monkeypatch):
|
||||
"""Empty-list content (``content == []``) must also be filtered, not just ``""`` / ``None``."""
|
||||
provider = _make_provider()
|
||||
try:
|
||||
captured_kwargs = {}
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return ChatCompletion.model_validate(
|
||||
{
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"created": 0,
|
||||
"model": "gpt-4o-mini",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)
|
||||
|
||||
payloads = {
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": []}, # should be filtered
|
||||
{"role": "user", "content": "again"},
|
||||
],
|
||||
}
|
||||
|
||||
await provider._query(payloads=payloads, tools=None)
|
||||
|
||||
messages = captured_kwargs["messages"]
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "hi"}
|
||||
assert messages[1] == {"role": "user", "content": "again"}
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
@@ -398,6 +398,37 @@ class TestBuiltinToolInjection:
|
||||
assert req.func_tool is not None
|
||||
assert req.func_tool.get_tool("web_search_baidu") is builtin_tool
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_web_search_tools_adds_firecrawl_search_and_extract_tools(
|
||||
self, mock_event, mock_context
|
||||
):
|
||||
"""Test Firecrawl web search injects search and extract tools."""
|
||||
module = ama
|
||||
req = ProviderRequest()
|
||||
mock_context.get_config.return_value = {
|
||||
"provider_settings": {
|
||||
"web_search": True,
|
||||
"websearch_provider": "firecrawl",
|
||||
}
|
||||
}
|
||||
search_tool = MagicMock(spec=FunctionTool)
|
||||
search_tool.name = "web_search_firecrawl"
|
||||
extract_tool = MagicMock(spec=FunctionTool)
|
||||
extract_tool.name = "firecrawl_extract_web_page"
|
||||
tool_mgr = MagicMock()
|
||||
tool_mgr.get_builtin_tool.side_effect = [search_tool, extract_tool]
|
||||
mock_context.get_llm_tool_manager.return_value = tool_mgr
|
||||
|
||||
await module._apply_web_search_tools(mock_event, req, mock_context)
|
||||
|
||||
assert tool_mgr.get_builtin_tool.call_args_list == [
|
||||
((module.FirecrawlWebSearchTool,),),
|
||||
((module.FirecrawlExtractWebPageTool,),),
|
||||
]
|
||||
assert req.func_tool is not None
|
||||
assert req.func_tool.get_tool("web_search_firecrawl") is search_tool
|
||||
assert req.func_tool.get_tool("firecrawl_extract_web_page") is extract_tool
|
||||
|
||||
def test_proactive_cron_job_tools_uses_builtin_tool_manager(self, mock_context):
|
||||
"""Test cron tool injection through the builtin tool manager."""
|
||||
module = ama
|
||||
|
||||
@@ -2,6 +2,8 @@ from astrbot.core import sp
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
from astrbot.core.tools.computer_tools.shell import ExecuteShellTool
|
||||
from astrbot.core.tools.message_tools import SendMessageToUserTool
|
||||
from astrbot.core.tools.web_search_tools import FirecrawlExtractWebPageTool
|
||||
from astrbot.core.tools.web_search_tools import FirecrawlWebSearchTool
|
||||
|
||||
|
||||
def test_get_builtin_tool_by_class_returns_cached_instance():
|
||||
@@ -38,3 +40,15 @@ def test_computer_tools_are_registered_as_builtin_tools():
|
||||
|
||||
assert tool.name == "astrbot_execute_shell"
|
||||
assert manager.is_builtin_tool("astrbot_execute_shell") is True
|
||||
|
||||
|
||||
def test_firecrawl_tools_are_registered_as_builtin_tools():
|
||||
manager = FunctionToolManager()
|
||||
|
||||
search_tool = manager.get_builtin_tool(FirecrawlWebSearchTool)
|
||||
extract_tool = manager.get_builtin_tool(FirecrawlExtractWebPageTool)
|
||||
|
||||
assert search_tool.name == "web_search_firecrawl"
|
||||
assert extract_tool.name == "firecrawl_extract_web_page"
|
||||
assert manager.is_builtin_tool("web_search_firecrawl") is True
|
||||
assert manager.is_builtin_tool("firecrawl_extract_web_page") is True
|
||||
|
||||
380
tests/unit/test_web_search_tools.py
Normal file
380
tests/unit/test_web_search_tools.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.tools import web_search_tools as tools
|
||||
|
||||
|
||||
class _FakeConfig(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.saved = False
|
||||
|
||||
def save_config(self):
|
||||
self.saved = True
|
||||
|
||||
|
||||
def test_normalize_legacy_web_search_config_migrates_firecrawl_key():
|
||||
config = _FakeConfig(
|
||||
{"provider_settings": {"websearch_firecrawl_key": "firecrawl-key"}}
|
||||
)
|
||||
|
||||
tools.normalize_legacy_web_search_config(config)
|
||||
|
||||
assert config["provider_settings"]["websearch_firecrawl_key"] == ["firecrawl-key"]
|
||||
assert config.saved is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_search_maps_web_results(monkeypatch):
|
||||
async def fake_firecrawl_search(provider_settings, payload):
|
||||
assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"]
|
||||
assert payload == {
|
||||
"query": "AstrBot",
|
||||
"limit": 3,
|
||||
"sources": ["web"],
|
||||
"country": "US",
|
||||
}
|
||||
return [
|
||||
tools.SearchResult(
|
||||
title="AstrBot",
|
||||
url="https://example.com",
|
||||
snippet="Search result",
|
||||
)
|
||||
]
|
||||
|
||||
monkeypatch.setattr(tools, "_firecrawl_search", fake_firecrawl_search)
|
||||
tool = tools.FirecrawlWebSearchTool()
|
||||
context = _context_with_provider_settings(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]}
|
||||
)
|
||||
|
||||
result = await tool.call(context, query="AstrBot", limit=3, country="US")
|
||||
|
||||
assert json.loads(result)["results"] == [
|
||||
{
|
||||
"title": "AstrBot",
|
||||
"url": "https://example.com",
|
||||
"snippet": "Search result",
|
||||
"index": json.loads(result)["results"][0]["index"],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_search_maps_v2_data_list(monkeypatch):
|
||||
session = _FakeFirecrawlSession(
|
||||
_FakeFirecrawlResponse(
|
||||
status=200,
|
||||
json_data={
|
||||
"success": True,
|
||||
"data": [
|
||||
{
|
||||
"title": "AstrBot",
|
||||
"url": "https://example.com",
|
||||
"description": "Search result",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def fake_client_session(*, trust_env):
|
||||
session.trust_env = trust_env
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
|
||||
|
||||
results = await tools._firecrawl_search(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]},
|
||||
{"query": "AstrBot", "limit": 5, "sources": ["web"]},
|
||||
)
|
||||
|
||||
assert session.posted == {
|
||||
"url": "https://api.firecrawl.dev/v2/search",
|
||||
"json": {"query": "AstrBot", "limit": 5, "sources": ["web"]},
|
||||
"headers": {
|
||||
"Authorization": "Bearer firecrawl-key",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
assert results == [
|
||||
tools.SearchResult(
|
||||
title="AstrBot", url="https://example.com", snippet="Search result"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_search_maps_v2_grouped_web_data(monkeypatch):
|
||||
session = _FakeFirecrawlSession(
|
||||
_FakeFirecrawlResponse(
|
||||
status=200,
|
||||
json_data={
|
||||
"success": True,
|
||||
"data": {
|
||||
"web": [
|
||||
{
|
||||
"title": "AstrBot",
|
||||
"url": "https://example.com",
|
||||
"description": "Search result",
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def fake_client_session(*, trust_env):
|
||||
session.trust_env = trust_env
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
|
||||
|
||||
results = await tools._firecrawl_search(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]},
|
||||
{"query": "AstrBot", "limit": 5, "sources": ["web"]},
|
||||
)
|
||||
|
||||
assert results == [
|
||||
tools.SearchResult(
|
||||
title="AstrBot", url="https://example.com", snippet="Search result"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_search_payload_omits_tbs_and_uses_default_limit(monkeypatch):
|
||||
async def fake_firecrawl_search(provider_settings, payload):
|
||||
assert payload == {
|
||||
"query": "AstrBot",
|
||||
"limit": 5,
|
||||
"sources": ["web"],
|
||||
"country": "US",
|
||||
}
|
||||
return [
|
||||
tools.SearchResult(
|
||||
title="AstrBot",
|
||||
url="https://example.com",
|
||||
snippet="Search result",
|
||||
)
|
||||
]
|
||||
|
||||
monkeypatch.setattr(tools, "_firecrawl_search", fake_firecrawl_search)
|
||||
tool = tools.FirecrawlWebSearchTool()
|
||||
context = _context_with_provider_settings(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]}
|
||||
)
|
||||
|
||||
result = await tool.call(
|
||||
context,
|
||||
query="AstrBot",
|
||||
tbs="qdr:d",
|
||||
country="US",
|
||||
)
|
||||
|
||||
assert json.loads(result)["results"][0]["url"] == "https://example.com"
|
||||
assert "tbs" not in tool.parameters["properties"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_extract_returns_scraped_markdown(monkeypatch):
|
||||
async def fake_firecrawl_scrape(provider_settings, payload):
|
||||
assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"]
|
||||
assert payload == {
|
||||
"url": "https://example.com",
|
||||
"formats": ["markdown"],
|
||||
"onlyMainContent": True,
|
||||
}
|
||||
return {"url": "https://example.com", "markdown": "# Example"}
|
||||
|
||||
monkeypatch.setattr(tools, "_firecrawl_scrape", fake_firecrawl_scrape)
|
||||
tool = tools.FirecrawlExtractWebPageTool()
|
||||
context = _context_with_provider_settings(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]}
|
||||
)
|
||||
|
||||
result = await tool.call(context, url="https://example.com")
|
||||
|
||||
assert result == "URL: https://example.com\nContent: # Example"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_search_uses_session_context(monkeypatch):
|
||||
session = _FakeFirecrawlSession(
|
||||
_FakeFirecrawlResponse(
|
||||
status=200,
|
||||
json_data={
|
||||
"success": True,
|
||||
"data": [
|
||||
{
|
||||
"title": "AstrBot",
|
||||
"url": "https://example.com",
|
||||
"description": "Search result",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def fake_client_session(*, trust_env):
|
||||
session.trust_env = trust_env
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
|
||||
|
||||
await tools._firecrawl_search(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]},
|
||||
{"query": "AstrBot"},
|
||||
)
|
||||
|
||||
assert session.trust_env is True
|
||||
assert session.entered is True
|
||||
assert session.exited is True
|
||||
assert session.posted == {
|
||||
"url": "https://api.firecrawl.dev/v2/search",
|
||||
"json": {"query": "AstrBot"},
|
||||
"headers": {
|
||||
"Authorization": "Bearer firecrawl-key",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_search_raises_error_for_http_errors(monkeypatch):
|
||||
session = _FakeFirecrawlSession(
|
||||
_FakeFirecrawlResponse(status=401, text_data="Unauthorized")
|
||||
)
|
||||
|
||||
def fake_client_session(*, trust_env):
|
||||
session.trust_env = trust_env
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
|
||||
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match="Firecrawl web search failed: Unauthorized, status: 401",
|
||||
):
|
||||
await tools._firecrawl_search(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]},
|
||||
{"query": "AstrBot"},
|
||||
)
|
||||
|
||||
assert session.trust_env is True
|
||||
assert session.entered is True
|
||||
assert session.exited is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_scrape_uses_request_setup(monkeypatch):
|
||||
session = _FakeFirecrawlSession(
|
||||
_FakeFirecrawlResponse(
|
||||
status=200,
|
||||
json_data={
|
||||
"success": True,
|
||||
"data": {"url": "https://example.com", "markdown": "# Example"},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def fake_client_session(*, trust_env):
|
||||
session.trust_env = trust_env
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
|
||||
|
||||
result = await tools._firecrawl_scrape(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]},
|
||||
{"url": "https://example.com", "formats": ["markdown"]},
|
||||
)
|
||||
|
||||
assert result == {"url": "https://example.com", "markdown": "# Example"}
|
||||
assert session.trust_env is True
|
||||
assert session.entered is True
|
||||
assert session.exited is True
|
||||
assert session.posted == {
|
||||
"url": "https://api.firecrawl.dev/v2/scrape",
|
||||
"json": {"url": "https://example.com", "formats": ["markdown"]},
|
||||
"headers": {
|
||||
"Authorization": "Bearer firecrawl-key",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_firecrawl_scrape_raises_error_for_http_errors(monkeypatch):
|
||||
session = _FakeFirecrawlSession(
|
||||
_FakeFirecrawlResponse(status=401, text_data="Unauthorized")
|
||||
)
|
||||
|
||||
def fake_client_session(*, trust_env):
|
||||
session.trust_env = trust_env
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
|
||||
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match="Firecrawl web scraper failed: Unauthorized, status: 401",
|
||||
):
|
||||
await tools._firecrawl_scrape(
|
||||
{"websearch_firecrawl_key": ["firecrawl-key"]},
|
||||
{"url": "https://example.com", "formats": ["markdown"]},
|
||||
)
|
||||
|
||||
assert session.trust_env is True
|
||||
assert session.entered is True
|
||||
assert session.exited is True
|
||||
|
||||
|
||||
class _FakeFirecrawlResponse:
|
||||
def __init__(self, status=200, json_data=None, text_data=""):
|
||||
self.status = status
|
||||
self.json_data = json_data or {}
|
||||
self.text_data = text_data
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def json(self):
|
||||
return self.json_data
|
||||
|
||||
async def text(self):
|
||||
return self.text_data
|
||||
|
||||
|
||||
class _FakeFirecrawlSession:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.trust_env = None
|
||||
self.entered = False
|
||||
self.exited = False
|
||||
self.posted = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.entered = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
self.exited = True
|
||||
return None
|
||||
|
||||
def post(self, url, json, headers):
|
||||
self.posted = {"url": url, "json": json, "headers": headers}
|
||||
return self.response
|
||||
|
||||
|
||||
def _context_with_provider_settings(provider_settings):
|
||||
config = {"provider_settings": provider_settings}
|
||||
agent_context = SimpleNamespace(
|
||||
context=SimpleNamespace(get_config=lambda umo: config),
|
||||
event=SimpleNamespace(unified_msg_origin="test:private:session"),
|
||||
)
|
||||
return SimpleNamespace(context=agent_context)
|
||||
Reference in New Issue
Block a user