Compare commits

...

11 Commits

Author SHA1 Message Date
Soulter
47ea036a81 fix: add reasoning_content field for DeepSeek v4 models in assistant messages 2026-04-27 11:45:42 +08:00
Soulter
643a8b177e fix: update reasoning_content handling to support empty string values 2026-04-27 11:43:49 +08:00
Weilong Liao
07b37b98de fix: handle empty reasoning content for DeepSeek v4 models (#7823)
Co-authored-by: Copilot <copilot@github.com>
2026-04-27 02:19:40 +08:00
bugkeep
bbda1e678f fix(core): downscale oversized images (#7807)
* fix(core): downscale oversized images

* refactor: share image max-size check helper

* Delete tests/unit/test_media_utils_compress_image.py

---------

Co-authored-by: Weilong Liao <37870767+Soulter@users.noreply.github.com>
2026-04-26 23:10:58 +08:00
EnemyWind
3c1d0cd2c2 [fix] 将Minimax TTS默认输出格式改为wav以解决RIFF错误 (#7797)
## 问题
在 QQ 官方平台插件中,处理来自 Minimax TTS 的语音时,会抛出错误:`处理语音时出错: file does not start with RIFF id`。
## 原因
Minimax TTS 提供商 (`minimax_tts_api_source.py`) 默认配置的音频输出格式为 `mp3`,而 `qqofficial_message_event.py` 中的 `wav_to_tencent_silk` 函数要求输入为 WAV 格式(具有 RIFF 文件头)。
## 解决方案
将 `minimax_tts_api_source.py` 文件中 `ProviderMiniMaxTTSAPI` 类的 `audio_setting` 字典的 `format` 键值,从 `"mp3"` 修改为 `"wav"`。
## 结果
修改后,Minimax TTS 生成的音频文件将直接为 WAV 格式,从而被下游函数正确识别和处理,修复上述错误。
2026-04-26 23:06:54 +08:00
Weilong Liao
d16ed4e552 fix: revise reasoning_key attribute to OpenRouter (#7821) 2026-04-26 22:21:57 +08:00
Yufeng He
55c1558686 fix(openai): apply empty-assistant filter to streaming path (fixes #7721) (#7758)
PR #7202 added empty-assistant filtering in `_query` so strict
providers (Moonshot, etc.) wouldn't 400 on history with blank
assistant entries. The streaming sibling `_query_stream` was
never updated, so DeepSeek Reasoner — which returns reasoning only
during tool calls, leaving serialized content as `""` — blew up with
`Invalid assistant message: content or tool_calls must be set` on
the next turn.

Hoisted the filter into a `_sanitize_assistant_messages` helper and
called it from both paths. Also widened the empty check to cover
`content == []`, which the original filter missed and which shows up
with providers that emit content as a list of parts.
2026-04-26 13:10:47 +08:00
wjiajian
17aea1aa2c feat: add Firecrawl web search tools (#7764)
* feat: add Firecrawl web search and extract tools, update configuration and tests

* feat: implement Firecrawl API integration and error handling in web search tools

* feat: enhance Firecrawl web search with session management and payload validation

* feat: revert Firecrawl web search to use aiohttp.ClientSession directly for improved session management, as it was before

* feat: update Firecrawl search to handle grouped web data response and add corresponding tests

* feat: refactor Firecrawl web search to use aiohttp.ClientSession for improved error handling and session management

* feat: remove unused coercion function and update Firecrawl search to use default limit in payload
2026-04-26 13:07:27 +08:00
Rhonin Wang
d4cdeeae72 fix(computer): send sandbox image downloads as images (#7785) 2026-04-25 16:44:08 +08:00
lingyun14
5ce02da6df fix: use certifi ssl context on Windows (#7778)
* fix: use certifi ssl context on Windows

* docs: update docstring to reflect hybrid SSL context

* chore: ruff

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-04-25 16:34:50 +08:00
Soulter
5d79c99938 feat: add deduplication for WeChat kefu text messages within 15 seconds (#7788) 2026-04-25 16:26:30 +08:00
22 changed files with 913 additions and 62 deletions

View File

@@ -183,10 +183,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.stats.end_time = time.time()
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
if llm_resp.reasoning_content is not None or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
think=llm_resp.reasoning_content or "",
encrypted=llm_resp.reasoning_signature,
)
)
@@ -876,10 +876,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 将结果添加到上下文中
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
if llm_resp.reasoning_content is not None or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
think=llm_resp.reasoning_content or "",
encrypted=llm_resp.reasoning_signature,
)
)
@@ -1361,10 +1361,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.stats.end_time = time.time()
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
if llm_resp.reasoning_content is not None or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
think=llm_resp.reasoning_content or "",
encrypted=llm_resp.reasoning_signature,
)
)

View File

@@ -77,6 +77,8 @@ from astrbot.core.tools.web_search_tools import (
BaiduWebSearchTool,
BochaWebSearchTool,
BraveWebSearchTool,
FirecrawlExtractWebPageTool,
FirecrawlWebSearchTool,
TavilyExtractWebPageTool,
TavilyWebSearchTool,
normalize_legacy_web_search_config,
@@ -1047,6 +1049,9 @@ async def _apply_web_search_tools(
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BochaWebSearchTool))
elif provider == "brave":
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BraveWebSearchTool))
elif provider == "firecrawl":
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlWebSearchTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlExtractWebPageTool))
elif provider == "baidu_ai_search":
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool))

View File

@@ -3202,6 +3202,7 @@ CONFIG_METADATA_3 = {
"baidu_ai_search",
"bocha",
"brave",
"firecrawl",
],
"condition": {
"provider_settings.web_search": True,
@@ -3237,6 +3238,16 @@ CONFIG_METADATA_3 = {
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_firecrawl_key": {
"description": "Firecrawl API Key",
"type": "list",
"items": {"type": "string"},
"hint": "可添加多个 Key 进行轮询。",
"condition": {
"provider_settings.websearch_provider": "firecrawl",
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_baidu_app_builder_key": {
"description": "百度千帆智能云 APP Builder API Key",
"type": "string",

View File

@@ -1,6 +1,7 @@
import asyncio
import os
import sys
import time
import uuid
from collections.abc import Awaitable, Callable
from typing import Any, cast
@@ -140,6 +141,8 @@ class WecomServer:
@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
class WecomPlatformAdapter(Platform):
WECHAT_KF_TEXT_CONTENT_DEDUP_TTL_SECONDS = 15
def __init__(
self,
platform_config: dict,
@@ -166,6 +169,7 @@ class WecomPlatformAdapter(Platform):
self.server = WecomServer(self._event_queue, self.config)
self.agent_id: str | None = None
self._wechat_kf_seen_text_messages: dict[str, float] = {}
self.client = WeChatClient(
self.config["corpid"].strip(),
@@ -210,6 +214,28 @@ class WecomPlatformAdapter(Platform):
self.server.callback = callback
def _is_duplicate_wechat_kf_text_message(self, session_id: str, text: str) -> bool:
normalized_text = text.strip()
if not normalized_text:
return False
now = time.monotonic()
expired_keys = [
key
for key, expires_at in self._wechat_kf_seen_text_messages.items()
if expires_at <= now
]
for key in expired_keys:
self._wechat_kf_seen_text_messages.pop(key, None)
dedup_key = f"{session_id}:{normalized_text}"
if dedup_key in self._wechat_kf_seen_text_messages:
return True
self._wechat_kf_seen_text_messages[dedup_key] = (
now + self.WECHAT_KF_TEXT_CONTENT_DEDUP_TTL_SECONDS
)
return False
@override
async def send_by_session(
self,
@@ -390,6 +416,13 @@ class WecomPlatformAdapter(Platform):
abm.message_str = ""
if msgtype == "text":
text = msg.get("text", {}).get("content", "").strip()
if self._is_duplicate_wechat_kf_text_message(abm.session_id, text):
logger.debug(
"忽略 15 秒内重复微信客服文本消息 session_id=%s text=%s",
abm.session_id,
text,
)
return None
abm.message = [Plain(text=text)]
abm.message_str = text
elif msgtype == "image":

View File

@@ -353,7 +353,7 @@ class LLMResponse:
"""Tool call IDs."""
tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Tool call extra content. tool_call_id -> extra_content dict"""
reasoning_content: str = ""
reasoning_content: str | None = None
"""The reasoning content extracted from the LLM, if any."""
reasoning_signature: str | None = None
"""The signature of the reasoning content, if any."""
@@ -404,8 +404,6 @@ class LLMResponse:
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
"""
if reasoning_content is None:
reasoning_content = ""
if tools_call_args is None:
tools_call_args = []
if tools_call_name is None:

View File

@@ -39,7 +39,7 @@ class ProviderAnthropic(Provider):
stop_reason: str | None = None,
) -> None:
has_text_output = bool((llm_response.completion_text or "").strip())
has_reasoning_output = bool(llm_response.reasoning_content.strip())
has_reasoning_output = bool((llm_response.reasoning_content or "").strip())
has_tool_output = bool(llm_response.tools_call_args)
if has_text_output or has_reasoning_output or has_tool_output:
return

View File

@@ -462,7 +462,7 @@ class ProviderGoogleGenAI(Provider):
finish_reason: str | None = None,
) -> None:
has_text_output = bool((llm_response.completion_text or "").strip())
has_reasoning_output = bool(llm_response.reasoning_content.strip())
has_reasoning_output = bool((llm_response.reasoning_content or "").strip())
has_tool_output = bool(llm_response.tools_call_args)
if has_text_output or has_reasoning_output or has_tool_output:
return

View File

@@ -65,7 +65,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
self.audio_setting: dict = {
"sample_rate": 32000,
"bitrate": 128000,
"format": "mp3",
"format": "wav",
}
self.concat_base_url: str = f"{self.api_base}?GroupId={self.group_id}"
@@ -147,7 +147,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.wav")
try:
# 直接将异步生成器传递给 _audio_play 方法

View File

@@ -519,6 +519,42 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
@staticmethod
def _sanitize_assistant_messages(payloads: dict) -> None:
"""在请求发送前过滤/规范化空的 assistant 消息。
严格 APIMoonshot、DeepSeek Reasoner 等)会在 assistant 消息同时缺少
``content`` 和 ``tool_calls`` 时返回 400。把 ``""`` / ``None`` / ``[]``
都视作空内容:无 tool_calls 时整条过滤掉;有 tool_calls 时将 content
设为 ``None`` 以符合 OpenAI 规范。就地修改 ``payloads["messages"]``。
"""
messages = payloads.get("messages")
if not isinstance(messages, list):
return
def _is_empty(content: Any) -> bool:
return content is None or content == "" or content == []
cleaned: list[Any] = []
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or msg.get("role") != "assistant":
cleaned.append(msg)
continue
content = msg.get("content")
tool_calls = msg.get("tool_calls")
if _is_empty(content) and not tool_calls:
logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
continue
if _is_empty(content) and tool_calls:
msg["content"] = None
cleaned.append(msg)
payloads["messages"] = cleaned
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
@@ -548,26 +584,7 @@ class ProviderOpenAIOfficial(Provider):
model = payloads.get("model", "").lower()
if "messages" in payloads and isinstance(payloads["messages"], list):
cleaned_messages = []
for idx, msg in enumerate(payloads["messages"]):
# 过滤空的 assistant 消息,防止严格 API(如 Moonshot)返回 400 错误
if msg.get("role") == "assistant":
content = msg.get("content")
tool_calls = msg.get("tool_calls")
# 情况1: 空/null content 且无 tool_calls -> 过滤掉
if not tool_calls and (content == "" or content is None):
logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
continue
# 情况2: 空 content 但有 tool_calls -> 设为 None (符合 OpenAI 规范)
if content == "" and tool_calls:
msg["content"] = None
cleaned_messages.append(msg)
payloads["messages"] = cleaned_messages
self._sanitize_assistant_messages(payloads)
completion = await self.client.chat.completions.create(
**payloads,
@@ -619,6 +636,8 @@ class ProviderOpenAIOfficial(Provider):
del payloads[key]
self._apply_provider_specific_extra_body_overrides(extra_body)
self._sanitize_assistant_messages(payloads)
stream = await self.client.chat.completions.create(
**payloads,
stream=True,
@@ -652,9 +671,9 @@ class ProviderOpenAIOfficial(Provider):
reasoning = self._extract_reasoning_content(chunk)
_y = False
llm_response.id = chunk.id
llm_response.reasoning_content = ""
llm_response.reasoning_content = None
llm_response.completion_text = ""
if reasoning:
if reasoning is not None:
llm_response.reasoning_content = reasoning
_y = True
if delta and delta.content:
@@ -682,22 +701,28 @@ class ProviderOpenAIOfficial(Provider):
def _extract_reasoning_content(
self,
completion: ChatCompletion | ChatCompletionChunk,
) -> str:
) -> str | None:
"""Extract reasoning content from OpenAI ChatCompletion if available."""
reasoning_text = ""
def _get_reasoning_attr(obj: Any) -> str | None:
fields_set = getattr(obj, "model_fields_set", None)
if isinstance(fields_set, set) and self.reasoning_key in fields_set:
attr = getattr(obj, self.reasoning_key, "")
return "" if attr is None else str(attr)
attr = getattr(obj, self.reasoning_key, None)
return None if attr is None else str(attr)
if not completion.choices:
return reasoning_text
return None
if isinstance(completion, ChatCompletion):
choice = completion.choices[0]
reasoning_attr = getattr(choice.message, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
reasoning_attr = _get_reasoning_attr(choice.message)
elif isinstance(completion, ChatCompletionChunk):
delta = completion.choices[0].delta
reasoning_attr = getattr(delta, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
return reasoning_text
reasoning_attr = _get_reasoning_attr(delta)
else:
return None
return reasoning_attr
def _extract_usage(self, usage: CompletionUsage | dict) -> TokenUsage:
ptd = getattr(usage, "prompt_tokens_details", None)
@@ -840,7 +865,9 @@ class ProviderOpenAIOfficial(Provider):
# parse the reasoning content if any
# the priority is higher than the <think> tag extraction
llm_response.reasoning_content = self._extract_reasoning_content(completion)
reasoning_content = self._extract_reasoning_content(completion)
if reasoning_content is not None:
llm_response.reasoning_content = reasoning_content
# parse tool calls if any
if choice.message.tool_calls and tools is not None:
@@ -887,7 +914,7 @@ class ProviderOpenAIOfficial(Provider):
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
)
has_text_output = bool((llm_response.completion_text or "").strip())
has_reasoning_output = bool(llm_response.reasoning_content.strip())
has_reasoning_output = bool((llm_response.reasoning_content or "").strip())
if (
not has_text_output
and not has_reasoning_output
@@ -963,24 +990,39 @@ class ProviderOpenAIOfficial(Provider):
"""Finally convert the payload. Such as think part conversion, tool inject."""
model = payloads.get("model", "").lower()
is_gemini = "gemini" in model
deepseek_reasoning_models = {"deepseek-v4-pro", "deepseek-v4-flash"}
is_deepseek_v4_reasoning = (
model in deepseek_reasoning_models
or "api.deepseek.com" in self.client.base_url.host
)
for message in payloads.get("messages", []):
if message.get("role") == "assistant" and isinstance(
message.get("content"), list
):
reasoning_content = ""
reasoning_content_present = False
new_content = [] # not including think part
for part in message["content"]:
if part.get("type") == "think":
reasoning_content_present = True
reasoning_content += str(part.get("think"))
else:
new_content.append(part)
# Some providers (Grok, etc.) reject empty content lists.
# When all parts were think blocks, fall back to None.
message["content"] = new_content or None
if reasoning_content:
if reasoning_content_present:
message["reasoning_content"] = reasoning_content
if (
message.get("role") == "assistant"
and is_deepseek_v4_reasoning
and "reasoning_content" not in message
):
# DeepSeek v4 reasoning models require the field on assistant
# history messages, even when the reasoning content is empty.
message["reasoning_content"] = ""
# Gemini 的 function_response 要求 google.protobuf.Struct(即 JSON 对象),
# 纯文本会触发 400 Invalid argument,需要包一层 JSON。
if is_gemini and message.get("role") == "tool":

View File

@@ -20,3 +20,4 @@ class ProviderOpenRouter(ProviderOpenAIOfficial):
self.client._custom_headers["X-OpenRouter-Categories"] = (
"general-chat,personal-agent" # type: ignore
)
self.reasoning_key = "reasoning"

View File

@@ -43,7 +43,7 @@ from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.computer.file_read_utils import read_file_tool_result
from astrbot.core.message.components import File
from astrbot.core.message.components import File, Image
from astrbot.core.utils.astrbot_path import (
get_astrbot_skills_path,
get_astrbot_system_tmp_path,
@@ -64,6 +64,7 @@ _COMPUTER_RUNTIME_TOOL_CONFIG = {
_SANDBOX_RUNTIME_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "sandbox",
}
_IMAGE_FILE_SUFFIXES = {".bmp", ".gif", ".jpeg", ".jpg", ".png", ".webp"}
def _restricted_env_path_labels(umo: str) -> list[str]:
@@ -729,11 +730,21 @@ class FileDownloadTool(FunctionTool):
if also_send_to_user:
try:
name = os.path.basename(local_path)
if Path(local_path).suffix.lower() in _IMAGE_FILE_SUFFIXES:
message_component = Image.fromFileSystem(local_path)
sent_as = "image"
else:
message_component = File(name=name, file=local_path)
sent_as = "file"
await context.context.event.send(
MessageChain(chain=[File(name=name, file=local_path)])
MessageChain(chain=[message_component])
)
except Exception as e:
logger.error(f"Error sending file message: {e}")
return (
f"File downloaded successfully to {local_path} "
f"but sending to user failed: {e}"
)
# remove
# try:
@@ -741,7 +752,10 @@ class FileDownloadTool(FunctionTool):
# except Exception as e:
# logger.error(f"Error removing temp file {local_path}: {e}")
return f"File downloaded successfully to {local_path} and sent to user."
return (
f"File downloaded successfully to {local_path} "
f"and sent to user as {sent_as}."
)
return f"File downloaded successfully to {local_path}"
except Exception as e:

View File

@@ -19,6 +19,8 @@ WEB_SEARCH_TOOL_NAMES = [
"tavily_extract_web_page",
"web_search_bocha",
"web_search_brave",
"web_search_firecrawl",
"firecrawl_extract_web_page",
]
_TAVILY_WEB_SEARCH_TOOL_CONFIG = {
"provider_settings.web_search": True,
@@ -32,6 +34,10 @@ _BRAVE_WEB_SEARCH_TOOL_CONFIG = {
"provider_settings.web_search": True,
"provider_settings.websearch_provider": "brave",
}
_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG = {
"provider_settings.web_search": True,
"provider_settings.websearch_provider": "firecrawl",
}
_BAIDU_WEB_SEARCH_TOOL_CONFIG = {
"provider_settings.web_search": True,
"provider_settings.websearch_provider": "baidu_ai_search",
@@ -69,6 +75,7 @@ class _KeyRotator:
_TAVILY_KEY_ROTATOR = _KeyRotator("websearch_tavily_key", "Tavily")
_BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha")
_BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave")
_FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl")
def normalize_legacy_web_search_config(cfg) -> None:
@@ -91,6 +98,7 @@ def normalize_legacy_web_search_config(cfg) -> None:
"websearch_tavily_key",
"websearch_bocha_key",
"websearch_brave_key",
"websearch_firecrawl_key",
):
value = provider_settings.get(setting_name)
if isinstance(value, str):
@@ -258,6 +266,72 @@ async def _brave_search(
]
async def _firecrawl_search(
    provider_settings: dict,
    payload: dict,
) -> list[SearchResult]:
    """Run a Firecrawl search and convert the hits into ``SearchResult`` items.

    Args:
        provider_settings: Settings handed to the Firecrawl key rotator to
            obtain the API key used for this request.
        payload: JSON body posted to the Firecrawl search endpoint.

    Returns:
        One ``SearchResult`` per returned row that carries a URL. The list is
        empty when the response holds no usable rows.

    Raises:
        Exception: If the endpoint answers with a non-200 status.
    """
    firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings)
    header = {
        "Authorization": f"Bearer {firecrawl_key}",
        "Content-Type": "application/json",
    }
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.post(
            "https://api.firecrawl.dev/v2/search",
            json=payload,
            headers=header,
        ) as response:
            if response.status != 200:
                reason = await response.text()
                raise Exception(
                    f"Firecrawl web search failed: {reason}, status: {response.status}",
                )
            data = await response.json()

    # ``data`` is either a flat list of rows or a dict grouped by source
    # (only "web" is read here). A JSON ``null`` at either level must not
    # reach the loop below, so fall back to an empty list.
    rows = data.get("data") if isinstance(data, dict) else None
    if isinstance(rows, dict):
        rows = rows.get("web")
    if not isinstance(rows, list):
        rows = []

    return [
        SearchResult(
            title=item.get("title") or "",
            url=item["url"],
            snippet=(
                item.get("description")
                or item.get("snippet")
                or item.get("markdown")
                or ""
            ),
        )
        for item in rows
        # Skip malformed rows instead of failing the whole search.
        if isinstance(item, dict) and item.get("url")
    ]
async def _firecrawl_scrape(provider_settings: dict, payload: dict) -> dict:
    """Post a scrape request to Firecrawl and return the ``data`` object.

    Args:
        provider_settings: Settings handed to the Firecrawl key rotator to
            obtain the API key used for this request.
        payload: JSON body posted to the Firecrawl scrape endpoint.

    Returns:
        The non-empty value stored under ``data`` in the JSON response.

    Raises:
        Exception: If the endpoint answers with a non-200 status.
        ValueError: If the response carries no (or empty) ``data``.
    """
    api_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings)
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    async with aiohttp.ClientSession(trust_env=True) as http:
        async with http.post(
            "https://api.firecrawl.dev/v2/scrape",
            json=payload,
            headers=request_headers,
        ) as response:
            if response.status == 200:
                body = await response.json()
            else:
                reason = await response.text()
                raise Exception(
                    f"Firecrawl web scraper failed: {reason}, status: {response.status}",
                )

    scraped = body.get("data", {})
    if scraped:
        return scraped
    raise ValueError("Error: Firecrawl web scraper does not return any results.")
async def _baidu_search(
provider_settings: dict,
payload: dict,
@@ -548,6 +622,124 @@ class BraveWebSearchTool(FunctionTool[AstrAgentContext]):
return _search_result_payload(results)
@builtin_tool(config=_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG)
@pydantic_dataclass
class FirecrawlWebSearchTool(FunctionTool[AstrAgentContext]):
    # Function tool that forwards a query to ``_firecrawl_search`` and returns
    # the hits formatted by ``_search_result_payload``.
    name: str = "web_search_firecrawl"
    description: str = (
        "A web search tool based on Firecrawl Search API, "
        "used to retrieve web pages related to the user's query."
    )
    # JSON-schema style description of the arguments accepted by ``call``.
    parameters: dict = Field(
        default_factory=lambda: {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Required. Search query.",
                },
                "limit": {
                    "type": "integer",
                    "description": (
                        "Optional. Number of results to return. "
                        "Range: 1-100. Default is 5."
                    ),
                },
                "location": {
                    "type": "string",
                    "description": (
                        "Optional. Geographic location for search results."
                    ),
                },
                "country": {
                    "type": "string",
                    "description": (
                        "Optional. Country code for search results, "
                        'for example "US" or "CN".'
                    ),
                },
                "timeout": {
                    "type": "integer",
                    "description": "Optional. Request timeout in milliseconds.",
                },
            },
            "required": ["query"],
        }
    )

    async def call(self, context, **kwargs) -> ToolExecResult:
        """Search the web through Firecrawl.

        Reads ``query`` (required) plus the optional ``limit``, ``location``,
        ``country`` and ``timeout`` keyword arguments. Optional values are
        forwarded only when truthy; ``limit`` defaults to 5.

        Returns:
            The formatted search results, or an error string when no API key
            is configured or the search yields nothing.
        """
        _, settings, _ = _get_runtime(context)
        if not settings.get("websearch_firecrawl_key", []):
            return "Error: Firecrawl API key is not configured in AstrBot."

        optional_fields = {
            field_name: kwargs[field_name]
            for field_name in ("location", "country", "timeout")
            if kwargs.get(field_name)
        }
        request_body = {
            "query": kwargs["query"],
            "limit": kwargs.get("limit", 5),
            "sources": ["web"],
            **optional_fields,
        }

        hits = await _firecrawl_search(settings, request_body)
        if hits:
            return _search_result_payload(hits)
        return "Error: Firecrawl web searcher does not return any results."
@builtin_tool(config=_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG)
@pydantic_dataclass
class FirecrawlExtractWebPageTool(FunctionTool[AstrAgentContext]):
    # Function tool that asks ``_firecrawl_scrape`` for a single page and
    # returns its content in the requested format.
    name: str = "firecrawl_extract_web_page"
    description: str = "Extract the content of a web page using Firecrawl."
    # JSON-schema style description of the arguments accepted by ``call``.
    parameters: dict = Field(
        default_factory=lambda: {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "Required. A URL to extract content from.",
                },
                "format": {
                    "type": "string",
                    "description": 'Optional. Output format, one of "markdown", "html", "rawHtml", "summary". Default is "markdown".',
                },
                "only_main_content": {
                    "type": "boolean",
                    "description": "Optional. Whether to extract only the main page content. Default is true.",
                },
                "timeout": {
                    "type": "integer",
                    "description": "Optional. Request timeout in milliseconds.",
                },
                "max_age": {
                    "type": "integer",
                    "description": "Optional. Maximum cache age in milliseconds.",
                },
            },
            "required": ["url"],
        }
    )

    async def call(self, context, **kwargs) -> ToolExecResult:
        """Extract one web page through Firecrawl.

        Reads ``url`` (required) plus the optional ``format``,
        ``only_main_content``, ``timeout`` and ``max_age`` keyword arguments.
        An unknown ``format`` falls back to ``"markdown"``; ``timeout`` and
        ``max_age`` are forwarded only when truthy.

        Returns:
            ``"URL: ...\\nContent: ..."`` on success, or an error string when
            no API key is configured, the URL is missing/blank, or the scrape
            result has no content for the requested format.
        """
        _, provider_settings, _ = _get_runtime(context)
        if not provider_settings.get("websearch_firecrawl_key", []):
            return "Error: Firecrawl API key is not configured in AstrBot."

        # An explicit ``None`` must count as missing: ``str(None)`` would
        # yield the non-empty text "None" and be sent to the API as a URL.
        raw_url = kwargs.get("url")
        url = "" if raw_url is None else str(raw_url).strip()
        if not url:
            return "Error: url must be a non-empty string."

        output_format = kwargs.get("format", "markdown")
        if output_format not in ["markdown", "html", "rawHtml", "summary"]:
            output_format = "markdown"

        payload = {
            "url": url,
            "formats": [output_format],
            "onlyMainContent": kwargs.get("only_main_content", True),
        }
        if kwargs.get("timeout"):
            payload["timeout"] = kwargs["timeout"]
        if kwargs.get("max_age"):
            payload["maxAge"] = kwargs["max_age"]

        result = await _firecrawl_scrape(provider_settings, payload)
        content = result.get(output_format, "")
        if not content:
            return "Error: Firecrawl web scraper does not return any results."
        # Prefer the URL reported by Firecrawl, else echo the requested one.
        result_url = result.get("url") or url
        return f"URL: {result_url}\nContent: {content}"
@builtin_tool(config=_BAIDU_WEB_SEARCH_TOOL_CONFIG)
@pydantic_dataclass
class BaiduWebSearchTool(FunctionTool[AstrAgentContext]):

View File

@@ -436,19 +436,30 @@ async def compress_image(
optimize = IMAGE_COMPRESS_DEFAULT_OPTIMIZE
min_file_size_bytes = int(IMAGE_COMPRESS_DEFAULT_MIN_FILE_SIZE_MB * 1024 * 1024)
data = None
def _exceeds_max_size(source: bytes | Path) -> bool:
    """Tell whether the image's longest side is larger than ``max_size``.

    ``source`` is either raw image bytes or a path to an image file.
    ``max_size`` is taken from the enclosing scope. Any failure while
    opening or measuring the image is treated as "does not exceed".
    """
    try:
        handle = source
        if isinstance(source, bytes):
            # Pillow needs a file-like object for in-memory data.
            handle = io.BytesIO(source)
        with PILImage.open(handle) as probe:
            longest_side = max(probe.size)
            return longest_side > max_size
    except Exception:  # noqa: BLE001
        return False
# Skip compression for remote images and return the original value.
if url_or_path.startswith("http"):
return url_or_path
elif url_or_path.startswith("data:image"):
_header, encoded = url_or_path.split(",", 1)
data = base64.b64decode(encoded)
if len(data) < min_file_size_bytes:
if len(data) < min_file_size_bytes and not _exceeds_max_size(data):
return url_or_path
else:
local_path = Path(url_or_path)
if not local_path.exists():
return url_or_path
if local_path.stat().st_size < min_file_size_bytes:
if local_path.stat().st_size < min_file_size_bytes and not _exceeds_max_size(
local_path
):
return url_or_path
with local_path.open("rb") as f:
data = f.read()

View File

@@ -5,8 +5,9 @@ import ssl
import httpx
from astrbot import logger
from astrbot.utils.http_ssl_common import build_ssl_context_with_certifi
_SYSTEM_SSL_CTX = ssl.create_default_context()
_SYSTEM_SSL_CTX = build_ssl_context_with_certifi()
def is_connection_error(exc: BaseException) -> bool:
@@ -92,9 +93,9 @@ def create_proxy_client(
) -> httpx.AsyncClient:
"""Create an httpx AsyncClient with proxy configuration if provided.
Uses the system SSL certificate store instead of certifi, which avoids
SSL verification failures for endpoints whose CA chain is not in certifi
but is trusted by the operating system.
Uses a hybrid SSL context that combines the system SSL certificate store
with certifi as a fallback, ensuring compatibility across different
environments including Windows where the system store may be incomplete.
Note: The caller is responsible for closing the client when done.
Consider using the client as a context manager or calling aclose() explicitly.
@@ -103,11 +104,11 @@ def create_proxy_client(
provider_label: The provider name for log prefix (e.g., "OpenAI", "Gemini")
proxy: The proxy address (e.g., "http://127.0.0.1:7890"), or None/empty
headers: Optional custom headers to include in every request
verify: Optional override for TLS verification. Defaults to the shared
system SSL context when not provided.
verify: Optional override for TLS verification. Defaults to the hybrid
SSL context (system store + certifi) when not provided.
Returns:
An httpx.AsyncClient created with the shared system SSL context; the proxy is applied only if one is provided.
An httpx.AsyncClient created with the hybrid SSL context (system store + certifi); the proxy is applied only if one is provided.
"""
resolved_verify = _SYSTEM_SSL_CTX if verify is None else verify
if proxy:

View File

@@ -303,7 +303,7 @@ export default {
part.tool_calls.forEach(toolCall => {
// 检查是否是支持引用解析的 web_search 工具调用
if (
!['web_search_baidu', 'web_search_tavily', 'web_search_bocha', 'web_search_brave'].includes(toolCall.name) ||
!['web_search_baidu', 'web_search_tavily', 'web_search_bocha', 'web_search_brave', 'web_search_firecrawl'].includes(toolCall.name) ||
!toolCall.result
) {
return;

View File

@@ -125,6 +125,10 @@
"description": "Brave Search API Key",
"hint": "Multiple keys can be added for rotation."
},
"websearch_firecrawl_key": {
"description": "Firecrawl API Key",
"hint": "Multiple keys can be added for rotation."
},
"websearch_baidu_app_builder_key": {
"description": "Baidu Qianfan Smart Cloud APP Builder API Key",
"hint": "Reference: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"

View File

@@ -125,6 +125,10 @@
"description": "API-ключ Brave Search",
"hint": "Можно добавить несколько ключей для ротации."
},
"websearch_firecrawl_key": {
"description": "API-ключ Firecrawl",
"hint": "Можно добавить несколько ключей для ротации."
},
"websearch_baidu_app_builder_key": {
"description": "API-ключ Baidu Qianfan APP Builder",
"hint": "Ссылка: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"

View File

@@ -127,6 +127,10 @@
"description": "Brave Search API Key",
"hint": "可添加多个 Key 进行轮询。"
},
"websearch_firecrawl_key": {
"description": "Firecrawl API Key",
"hint": "可添加多个 Key 进行轮询。"
},
"websearch_baidu_app_builder_key": {
"description": "百度千帆智能云 APP Builder API Key",
"hint": "参考:[https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"

View File

@@ -1618,3 +1618,109 @@ async def test_query_does_not_filter_user_or_system_messages(monkeypatch):
assert messages[2] == {"role": "user", "content": "hello"}
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_query_stream_filters_empty_assistant_message(monkeypatch):
    """Regression for #7721: the streaming path drops empty assistant messages.

    ``_query_stream`` is driven with a history containing a blank assistant
    entry; the messages forwarded to the (patched) ``create`` call must no
    longer include it.
    """
    chunk_data = {
        "id": "chatcmpl-stream",
        "object": "chat.completion.chunk",
        "created": 0,
        "model": "deepseek-reasoner",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "ok"},
                "finish_reason": "stop",
            }
        ],
    }
    request = {
        "model": "deepseek-reasoner",
        "messages": [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": ""},  # expected to be dropped
            {"role": "user", "content": "world"},
        ],
    }
    expected_messages = [
        {"role": "user", "content": "hello"},
        {"role": "user", "content": "world"},
    ]
    recorded_call = {}

    async def fake_stream():
        yield ChatCompletionChunk.model_validate(chunk_data)

    async def fake_create(**kwargs):
        # Record what the provider actually sends to the client.
        recorded_call.update(kwargs)
        return fake_stream()

    provider = _make_provider()
    try:
        monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)

        async for _ in provider._query_stream(payloads=request, tools=None):
            pass

        sent = recorded_call["messages"]
        assert len(sent) == len(expected_messages)
        for actual, wanted in zip(sent, expected_messages):
            assert actual == wanted
    finally:
        await provider.terminate()
@pytest.mark.asyncio
async def test_query_filters_empty_list_content_assistant_message(monkeypatch):
    """An assistant message whose content is ``[]`` is filtered like ``""`` / ``None``."""
    completion_data = {
        "id": "chatcmpl-test",
        "object": "chat.completion",
        "created": 0,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "ok"},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 1,
            "completion_tokens": 1,
            "total_tokens": 2,
        },
    }
    request = {
        "model": "gpt-4o-mini",
        "messages": [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": []},  # expected to be dropped
            {"role": "user", "content": "again"},
        ],
    }
    expected_messages = [
        {"role": "user", "content": "hi"},
        {"role": "user", "content": "again"},
    ]
    recorded_call = {}

    async def fake_create(**kwargs):
        # Record what the provider actually sends to the client.
        recorded_call.update(kwargs)
        return ChatCompletion.model_validate(completion_data)

    provider = _make_provider()
    try:
        monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)

        await provider._query(payloads=request, tools=None)

        sent = recorded_call["messages"]
        assert len(sent) == len(expected_messages)
        for actual, wanted in zip(sent, expected_messages):
            assert actual == wanted
    finally:
        await provider.terminate()

View File

@@ -398,6 +398,37 @@ class TestBuiltinToolInjection:
assert req.func_tool is not None
assert req.func_tool.get_tool("web_search_baidu") is builtin_tool
@pytest.mark.asyncio
async def test_apply_web_search_tools_adds_firecrawl_search_and_extract_tools(
    self, mock_event, mock_context
):
    """Firecrawl web search injects both the search and the extract tool.

    The tool manager mock hands out the two fake tools in lookup order; the
    test checks which tool classes were requested and that both fakes end up
    in ``req.func_tool``.
    """
    module = ama

    # Two distinct fakes, returned by the tool manager one per lookup.
    search_tool = MagicMock(spec=FunctionTool)
    search_tool.name = "web_search_firecrawl"
    extract_tool = MagicMock(spec=FunctionTool)
    extract_tool.name = "firecrawl_extract_web_page"

    tool_mgr = MagicMock()
    tool_mgr.get_builtin_tool.side_effect = [search_tool, extract_tool]
    mock_context.get_llm_tool_manager.return_value = tool_mgr
    mock_context.get_config.return_value = {
        "provider_settings": {
            "web_search": True,
            "websearch_provider": "firecrawl",
        }
    }

    req = ProviderRequest()
    await module._apply_web_search_tools(mock_event, req, mock_context)

    expected_lookups = [
        ((module.FirecrawlWebSearchTool,),),
        ((module.FirecrawlExtractWebPageTool,),),
    ]
    assert tool_mgr.get_builtin_tool.call_args_list == expected_lookups
    assert req.func_tool is not None
    assert req.func_tool.get_tool("web_search_firecrawl") is search_tool
    assert req.func_tool.get_tool("firecrawl_extract_web_page") is extract_tool
def test_proactive_cron_job_tools_uses_builtin_tool_manager(self, mock_context):
"""Test cron tool injection through the builtin tool manager."""
module = ama

View File

@@ -2,6 +2,8 @@ from astrbot.core import sp
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from astrbot.core.tools.computer_tools.shell import ExecuteShellTool
from astrbot.core.tools.message_tools import SendMessageToUserTool
from astrbot.core.tools.web_search_tools import FirecrawlExtractWebPageTool
from astrbot.core.tools.web_search_tools import FirecrawlWebSearchTool
def test_get_builtin_tool_by_class_returns_cached_instance():
@@ -38,3 +40,15 @@ def test_computer_tools_are_registered_as_builtin_tools():
assert tool.name == "astrbot_execute_shell"
assert manager.is_builtin_tool("astrbot_execute_shell") is True
def test_firecrawl_tools_are_registered_as_builtin_tools():
    """Both Firecrawl tool classes resolve to builtin tools with the expected names."""
    manager = FunctionToolManager()

    # Resolve the tools in the same order: search first, then extract.
    search_tool, extract_tool = (
        manager.get_builtin_tool(tool_cls)
        for tool_cls in (FirecrawlWebSearchTool, FirecrawlExtractWebPageTool)
    )

    assert search_tool.name == "web_search_firecrawl"
    assert extract_tool.name == "firecrawl_extract_web_page"
    for tool_name in ("web_search_firecrawl", "firecrawl_extract_web_page"):
        assert manager.is_builtin_tool(tool_name) is True

View File

@@ -0,0 +1,380 @@
import json
from types import SimpleNamespace
import pytest
from astrbot.core.tools import web_search_tools as tools
class _FakeConfig(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.saved = False
def save_config(self):
self.saved = True
def test_normalize_legacy_web_search_config_migrates_firecrawl_key():
    """A legacy string Firecrawl key becomes a one-item list and the config is saved."""
    legacy_settings = {"websearch_firecrawl_key": "firecrawl-key"}
    config = _FakeConfig({"provider_settings": legacy_settings})

    tools.normalize_legacy_web_search_config(config)

    migrated_key = config["provider_settings"]["websearch_firecrawl_key"]
    assert migrated_key == ["firecrawl-key"]
    assert config.saved is True
@pytest.mark.asyncio
async def test_firecrawl_search_maps_web_results(monkeypatch):
    """The search tool forwards the expected payload and serializes the results.

    ``_firecrawl_search`` is replaced by a fake that validates the settings and
    payload it receives and returns one ``SearchResult``; the tool's JSON
    output must contain that result plus an ``index`` field.
    """
    expected_payload = {
        "query": "AstrBot",
        "limit": 3,
        "sources": ["web"],
        "country": "US",
    }

    async def fake_firecrawl_search(provider_settings, payload):
        assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"]
        assert payload == expected_payload
        hit = tools.SearchResult(
            title="AstrBot",
            url="https://example.com",
            snippet="Search result",
        )
        return [hit]

    monkeypatch.setattr(tools, "_firecrawl_search", fake_firecrawl_search)

    tool = tools.FirecrawlWebSearchTool()
    context = _context_with_provider_settings(
        {"websearch_firecrawl_key": ["firecrawl-key"]}
    )
    result = await tool.call(context, query="AstrBot", limit=3, country="US")

    # Parse once; the "index" value is taken from the output itself, so only
    # its presence is checked, not its value.
    rendered = json.loads(result)["results"]
    assert rendered == [
        {
            "title": "AstrBot",
            "url": "https://example.com",
            "snippet": "Search result",
            "index": rendered[0]["index"],
        }
    ]
@pytest.mark.asyncio
async def test_firecrawl_search_maps_v2_data_list(monkeypatch):
    """A response whose ``data`` is a flat list is mapped to ``SearchResult`` items.

    Also checks the URL, JSON body and headers passed to ``session.post``.
    """
    response = _FakeFirecrawlResponse(
        status=200,
        json_data={
            "success": True,
            "data": [
                {
                    "title": "AstrBot",
                    "url": "https://example.com",
                    "description": "Search result",
                }
            ],
        },
    )
    session = _FakeFirecrawlSession(response)

    def fake_client_session(*, trust_env):
        # Remember the flag so it stays inspectable on the shared fake session.
        session.trust_env = trust_env
        return session

    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)

    provider_settings = {"websearch_firecrawl_key": ["firecrawl-key"]}
    search_payload = {"query": "AstrBot", "limit": 5, "sources": ["web"]}
    results = await tools._firecrawl_search(provider_settings, search_payload)

    expected_request = {
        "url": "https://api.firecrawl.dev/v2/search",
        "json": {"query": "AstrBot", "limit": 5, "sources": ["web"]},
        "headers": {
            "Authorization": "Bearer firecrawl-key",
            "Content-Type": "application/json",
        },
    }
    assert session.posted == expected_request

    expected_hit = tools.SearchResult(
        title="AstrBot", url="https://example.com", snippet="Search result"
    )
    assert results == [expected_hit]
@pytest.mark.asyncio
async def test_firecrawl_search_maps_v2_grouped_web_data(monkeypatch):
    """A response whose ``data`` groups hits under ``"web"`` is mapped to ``SearchResult`` items."""
    response = _FakeFirecrawlResponse(
        status=200,
        json_data={
            "success": True,
            "data": {
                "web": [
                    {
                        "title": "AstrBot",
                        "url": "https://example.com",
                        "description": "Search result",
                    }
                ]
            },
        },
    )
    session = _FakeFirecrawlSession(response)

    def fake_client_session(*, trust_env):
        # Remember the flag so it stays inspectable on the shared fake session.
        session.trust_env = trust_env
        return session

    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)

    provider_settings = {"websearch_firecrawl_key": ["firecrawl-key"]}
    search_payload = {"query": "AstrBot", "limit": 5, "sources": ["web"]}
    results = await tools._firecrawl_search(provider_settings, search_payload)

    expected_hit = tools.SearchResult(
        title="AstrBot", url="https://example.com", snippet="Search result"
    )
    assert results == [expected_hit]
@pytest.mark.asyncio
async def test_firecrawl_search_payload_omits_tbs_and_uses_default_limit(monkeypatch):
    """``tbs`` is not forwarded and ``limit`` is 5 when the caller omits it.

    The tool is called with a ``tbs`` argument and no ``limit``; the fake
    search asserts the payload it receives, and the tool's parameter schema
    must not advertise ``tbs``.
    """
    expected_payload = {
        "query": "AstrBot",
        "limit": 5,
        "sources": ["web"],
        "country": "US",
    }

    async def fake_firecrawl_search(provider_settings, payload):
        assert payload == expected_payload
        hit = tools.SearchResult(
            title="AstrBot",
            url="https://example.com",
            snippet="Search result",
        )
        return [hit]

    monkeypatch.setattr(tools, "_firecrawl_search", fake_firecrawl_search)

    tool = tools.FirecrawlWebSearchTool()
    context = _context_with_provider_settings(
        {"websearch_firecrawl_key": ["firecrawl-key"]}
    )
    call_kwargs = {"query": "AstrBot", "tbs": "qdr:d", "country": "US"}
    result = await tool.call(context, **call_kwargs)

    first_hit = json.loads(result)["results"][0]
    assert first_hit["url"] == "https://example.com"
    assert "tbs" not in tool.parameters["properties"]
@pytest.mark.asyncio
async def test_firecrawl_extract_returns_scraped_markdown(monkeypatch):
    """The extract tool requests markdown for the URL and formats it as text.

    ``_firecrawl_scrape`` is replaced by a fake that validates the settings and
    payload it receives and returns a canned scrape result.
    """
    expected_payload = {
        "url": "https://example.com",
        "formats": ["markdown"],
        "onlyMainContent": True,
    }

    async def fake_firecrawl_scrape(provider_settings, payload):
        assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"]
        assert payload == expected_payload
        return {"url": "https://example.com", "markdown": "# Example"}

    monkeypatch.setattr(tools, "_firecrawl_scrape", fake_firecrawl_scrape)

    tool = tools.FirecrawlExtractWebPageTool()
    context = _context_with_provider_settings(
        {"websearch_firecrawl_key": ["firecrawl-key"]}
    )
    result = await tool.call(context, url="https://example.com")

    assert result == "URL: https://example.com\nContent: # Example"
@pytest.mark.asyncio
async def test_firecrawl_search_uses_session_context(monkeypatch):
    """``_firecrawl_search`` creates the session with ``trust_env`` and enters/exits it.

    Also checks the URL, JSON body and headers passed to ``session.post``.
    """
    response = _FakeFirecrawlResponse(
        status=200,
        json_data={
            "success": True,
            "data": [
                {
                    "title": "AstrBot",
                    "url": "https://example.com",
                    "description": "Search result",
                }
            ],
        },
    )
    session = _FakeFirecrawlSession(response)

    def fake_client_session(*, trust_env):
        # Remember the flag so it stays inspectable on the shared fake session.
        session.trust_env = trust_env
        return session

    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)

    provider_settings = {"websearch_firecrawl_key": ["firecrawl-key"]}
    await tools._firecrawl_search(provider_settings, {"query": "AstrBot"})

    assert session.trust_env is True
    assert session.entered is True
    assert session.exited is True

    expected_request = {
        "url": "https://api.firecrawl.dev/v2/search",
        "json": {"query": "AstrBot"},
        "headers": {
            "Authorization": "Bearer firecrawl-key",
            "Content-Type": "application/json",
        },
    }
    assert session.posted == expected_request
@pytest.mark.asyncio
async def test_firecrawl_search_raises_error_for_http_errors(monkeypatch):
    """A 401 response makes ``_firecrawl_search`` raise with the body and status in the message."""
    response = _FakeFirecrawlResponse(status=401, text_data="Unauthorized")
    session = _FakeFirecrawlSession(response)

    def fake_client_session(*, trust_env):
        # Remember the flag so it stays inspectable on the shared fake session.
        session.trust_env = trust_env
        return session

    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)

    provider_settings = {"websearch_firecrawl_key": ["firecrawl-key"]}
    expected_message = "Firecrawl web search failed: Unauthorized, status: 401"
    with pytest.raises(Exception, match=expected_message):
        await tools._firecrawl_search(provider_settings, {"query": "AstrBot"})

    # Checked after the ``raises`` block: anything following the failing await
    # inside it would never execute.
    assert session.trust_env is True
    assert session.entered is True
    assert session.exited is True
@pytest.mark.asyncio
async def test_firecrawl_scrape_uses_request_setup(monkeypatch):
    """``_firecrawl_scrape`` returns the response's ``data`` and posts to the scrape endpoint.

    Also checks the session was created with ``trust_env`` and was entered and
    exited as a context manager.
    """
    scraped_page = {"url": "https://example.com", "markdown": "# Example"}
    response = _FakeFirecrawlResponse(
        status=200,
        json_data={"success": True, "data": scraped_page},
    )
    session = _FakeFirecrawlSession(response)

    def fake_client_session(*, trust_env):
        # Remember the flag so it stays inspectable on the shared fake session.
        session.trust_env = trust_env
        return session

    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)

    provider_settings = {"websearch_firecrawl_key": ["firecrawl-key"]}
    scrape_payload = {"url": "https://example.com", "formats": ["markdown"]}
    result = await tools._firecrawl_scrape(provider_settings, scrape_payload)

    assert result == {"url": "https://example.com", "markdown": "# Example"}
    assert session.trust_env is True
    assert session.entered is True
    assert session.exited is True

    expected_request = {
        "url": "https://api.firecrawl.dev/v2/scrape",
        "json": {"url": "https://example.com", "formats": ["markdown"]},
        "headers": {
            "Authorization": "Bearer firecrawl-key",
            "Content-Type": "application/json",
        },
    }
    assert session.posted == expected_request
@pytest.mark.asyncio
async def test_firecrawl_scrape_raises_error_for_http_errors(monkeypatch):
    """A 401 response makes ``_firecrawl_scrape`` raise with the body and status in the message."""
    response = _FakeFirecrawlResponse(status=401, text_data="Unauthorized")
    session = _FakeFirecrawlSession(response)

    def fake_client_session(*, trust_env):
        # Remember the flag so it stays inspectable on the shared fake session.
        session.trust_env = trust_env
        return session

    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)

    provider_settings = {"websearch_firecrawl_key": ["firecrawl-key"]}
    scrape_payload = {"url": "https://example.com", "formats": ["markdown"]}
    expected_message = "Firecrawl web scraper failed: Unauthorized, status: 401"
    with pytest.raises(Exception, match=expected_message):
        await tools._firecrawl_scrape(provider_settings, scrape_payload)

    # Checked after the ``raises`` block: anything following the failing await
    # inside it would never execute.
    assert session.trust_env is True
    assert session.entered is True
    assert session.exited is True
class _FakeFirecrawlResponse:
def __init__(self, status=200, json_data=None, text_data=""):
self.status = status
self.json_data = json_data or {}
self.text_data = text_data
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def json(self):
return self.json_data
async def text(self):
return self.text_data
class _FakeFirecrawlSession:
def __init__(self, response):
self.response = response
self.trust_env = None
self.entered = False
self.exited = False
self.posted = None
async def __aenter__(self):
self.entered = True
return self
async def __aexit__(self, exc_type, exc, tb):
self.exited = True
return None
def post(self, url, json, headers):
self.posted = {"url": url, "json": json, "headers": headers}
return self.response
def _context_with_provider_settings(provider_settings):
config = {"provider_settings": provider_settings}
agent_context = SimpleNamespace(
context=SimpleNamespace(get_config=lambda umo: config),
event=SimpleNamespace(unified_msg_origin="test:private:session"),
)
return SimpleNamespace(context=agent_context)