Compare commits

..

2 Commits

Author SHA1 Message Date
Soulter
51bb487346 fix: address release workflow review feedback 2026-06-19 15:29:48 +08:00
Soulter
a2567a202e chore: add release preparation workflow 2026-06-19 15:19:32 +08:00
30 changed files with 172 additions and 1197 deletions

View File

@@ -16,11 +16,8 @@ venv*/
ENV/
.conda/
dashboard/
!astrbot/dashboard/
!astrbot/dashboard/dist/
!astrbot/dashboard/dist/**
data/
tests/
.ruff_cache/
.astrbot
astrbot.lock
astrbot.lock

View File

@@ -46,21 +46,14 @@ jobs:
- name: Build Dashboard
run: |
dashboard_version=$(python3 - <<'PY'
import tomllib
with open("pyproject.toml", "rb") as f:
print("v" + tomllib.load(f)["project"]["version"])
PY
)
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo "$dashboard_version" > dist/assets/version
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p astrbot/dashboard
rm -rf astrbot/dashboard/dist
cp -r dashboard/dist astrbot/dashboard/dist
mkdir -p data
cp -r dashboard/dist data/
- name: Determine test image tags
id: test-meta
@@ -164,11 +157,10 @@ jobs:
npm install
npm run build
mkdir -p dist/assets
echo "${{ steps.release-meta.outputs.version }}" > dist/assets/version
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p astrbot/dashboard
rm -rf astrbot/dashboard/dist
cp -r dashboard/dist astrbot/dashboard/dist
mkdir -p data
cp -r dashboard/dist data/
- name: Set QEMU
uses: docker/setup-qemu-action@v4.1.0

View File

@@ -9,7 +9,6 @@ from datetime import timedelta
from pathlib import Path, PureWindowsPath
from typing import Any, Generic
import httpx
from tenacity import (
before_sleep_log,
retry,
@@ -103,22 +102,12 @@ except (ModuleNotFoundError, ImportError):
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
)
streamable_http_client_legacy = None
streamable_http_client = None
try:
from mcp.client.streamable_http import (
streamablehttp_client as streamable_http_client_legacy,
)
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
try:
from mcp.client.streamable_http import (
streamable_http_client as streamable_http_client,
)
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
)
logger.warning(
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
)
def _prepare_config(config: dict) -> dict:
@@ -470,38 +459,17 @@ class MCPClient:
),
)
else:
timeout_seconds = cfg.get("timeout", 30)
sse_read_timeout_seconds = cfg.get("sse_read_timeout", 60 * 5)
if streamable_http_client_legacy:
timeout = timedelta(seconds=timeout_seconds)
sse_read_timeout = timedelta(seconds=sse_read_timeout_seconds)
self._streams_context = streamable_http_client_legacy(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
elif streamable_http_client:
http_client = await self.exit_stack.enter_async_context(
httpx.AsyncClient(
headers=cfg.get("headers", {}),
timeout=httpx.Timeout(
timeout_seconds,
read=sse_read_timeout_seconds,
),
follow_redirects=True,
),
)
self._streams_context = streamable_http_client(
url=cfg["url"],
http_client=http_client,
terminate_on_close=cfg.get("terminate_on_close", True),
)
else:
raise RuntimeError(
"Streamable HTTP transport is not available in the installed MCP library version."
)
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5),
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context,
)

View File

@@ -224,7 +224,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
custom_compressor: ContextCompressor | None = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
request_max_retries: int | None = None,
tool_result_overflow_dir: str | None = None,
read_tool: FunctionTool | None = None,
**kwargs: T.Any,
@@ -238,7 +237,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
self.request_max_retries = request_max_retries
self.tool_result_overflow_dir = tool_result_overflow_dir
self.read_tool = read_tool
self._tool_result_token_counter = EstimateTokenCounter()
@@ -465,7 +463,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
"abort_signal": self._abort_signal,
"request_max_retries": self.request_max_retries,
}
if include_model:
# For primary provider we keep explicit model selection if provided.
@@ -1308,7 +1305,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
extra_user_content_parts=self.req.extra_user_content_parts,
# tool_choice="required",
abort_signal=self._abort_signal,
request_max_retries=self.request_max_retries,
)
if requery_resp:
llm_resp = requery_resp
@@ -1335,7 +1331,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
extra_user_content_parts=self.req.extra_user_content_parts,
# tool_choice="required",
abort_signal=self._abort_signal,
request_max_retries=self.request_max_retries,
)
if repair_resp:
llm_resp = repair_resp

View File

@@ -278,11 +278,10 @@ async def _apply_kb(
)
if not kb_result:
return
req.extra_user_content_parts.append(
TextPart(
text=f"[Related Knowledge Base Results]:\n{kb_result}",
).mark_as_temp()
)
if req.system_prompt is not None:
req.system_prompt += (
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
)
except Exception as exc: # noqa: BLE001
logger.error("Error occurred while retrieving knowledge base: %s", exc)
else:
@@ -457,10 +456,10 @@ async def _ensure_persona_and_skills(
cfg: dict,
plugin_context: Context,
event: AstrMessageEvent,
) -> None:
) -> set[str] | None:
"""Ensure persona and skills are applied to the request's system prompt or user prompt."""
if not req.conversation:
return
return None
(
persona_id,
@@ -527,11 +526,13 @@ async def _ensure_persona_and_skills(
# inject toolset in the persona
if (persona and persona.get("tools") is None) or not persona:
persona_allowed_tools = None
persona_toolset = tmgr.get_full_tool_set()
for tool in list(persona_toolset):
if not tool.active:
persona_toolset.remove_tool(tool.name)
else:
persona_allowed_tools = {str(tool_name) for tool_name in persona["tools"]}
persona_toolset = ToolSet()
if persona["tools"]:
for tool_name in persona["tools"]:
@@ -612,6 +613,7 @@ async def _ensure_persona_and_skills(
)
except Exception:
pass
return persona_allowed_tools
async def _request_img_caption(
@@ -944,12 +946,13 @@ async def _decorate_llm_request(
plugin_context: Context,
config: MainAgentBuildConfig,
provider: Provider | None = None,
) -> None:
) -> set[str] | None:
cfg = config.provider_settings or plugin_context.get_config(
umo=event.unified_msg_origin
).get("provider_settings", {})
_apply_prompt_prefix(req, cfg)
persona_allowed_tools = None
main_provider_supports_image = provider is not None and _provider_supports_modality(
provider, "image"
@@ -958,7 +961,9 @@ async def _decorate_llm_request(
quote_images_already_captioned = False
if req.conversation:
await _ensure_persona_and_skills(req, cfg, plugin_context, event)
persona_allowed_tools = await _ensure_persona_and_skills(
req, cfg, plugin_context, event
)
if img_cap_prov_id and req.image_urls and not main_provider_supports_image:
await _ensure_img_caption(
@@ -987,6 +992,7 @@ async def _decorate_llm_request(
tz = plugin_context.get_config().get("timezone")
_append_system_reminders(event, req, cfg, tz)
_apply_workspace_extra_prompt(event, req)
return persona_allowed_tools
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
@@ -1515,7 +1521,9 @@ async def build_main_agent(
else:
return None
await _decorate_llm_request(event, req, plugin_context, config, provider=provider)
persona_allowed_tools = await _decorate_llm_request(
event, req, plugin_context, config, provider=provider
)
await _apply_kb(event, req, plugin_context, config)
@@ -1551,6 +1559,11 @@ async def build_main_agent(
)
)
if persona_allowed_tools is not None and req.func_tool:
req.func_tool.tools = [
tool for tool in req.func_tool.tools if tool.name in persona_allowed_tools
]
fallback_providers = _get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
)
@@ -1616,7 +1629,6 @@ async def build_main_agent(
enforce_max_turns=config.max_context_length,
tool_schema_mode=config.tool_schema_mode,
fallback_providers=fallback_providers,
request_max_retries=config.provider_settings.get("request_max_retries", 5),
tool_result_overflow_dir=(
get_astrbot_system_tmp_path()
if req.func_tool and req.func_tool.get_tool("astrbot_file_read_tool")

View File

@@ -129,7 +129,6 @@ DEFAULT_CONFIG = {
"enable": True,
"default_provider_id": "",
"fallback_chat_models": [],
"request_max_retries": 5,
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
@@ -2837,9 +2836,6 @@ CONFIG_METADATA_2 = {
"type": "list",
"items": {"type": "string"},
},
"request_max_retries": {
"type": "int",
},
"wake_prefix": {
"type": "string",
},
@@ -3199,11 +3195,6 @@ CONFIG_METADATA_3 = {
"_special": "select_providers",
"hint": "主聊天模型请求失败时,按顺序切换到这些模型。",
},
"provider_settings.request_max_retries": {
"description": "请求最大重试次数",
"type": "int",
"hint": "单次模型请求遇到可重试错误时的最大尝试次数。",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",

View File

@@ -106,7 +106,6 @@ class Provider(AbstractProvider):
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -121,7 +120,6 @@ class Provider(AbstractProvider):
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
request_max_retries: 可重试请求错误的最大尝试次数,包含首次请求。
kwargs: 其他参数
Notes:
@@ -144,7 +142,6 @@ class Provider(AbstractProvider):
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
@@ -158,7 +155,6 @@ class Provider(AbstractProvider):
tool_choice: 工具调用策略,`auto` 表示由模型自行决定,`required` 表示要求模型必须调用工具
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
request_max_retries: 可重试请求错误的最大尝试次数,包含首次请求。
kwargs: 其他参数
Notes:

View File

@@ -27,7 +27,6 @@ from astrbot.core.utils.network_utils import (
)
from ..register import register_provider_adapter
from .request_retry import retry_provider_request, retry_provider_request_context
@register_provider_adapter(
@@ -354,13 +353,7 @@ class ProviderAnthropic(Provider):
logger.warning(f"未知的 tool_choice 值: {tool_choice},已回退为 'auto'")
return {"type": "auto"}
async def _query(
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
@@ -375,12 +368,8 @@ class ProviderAnthropic(Provider):
self._apply_thinking_config(payloads)
try:
completion = await retry_provider_request(
"Anthropic",
lambda: self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
),
max_attempts=request_max_retries,
completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
)
except httpx.RequestError as e:
proxy = self.provider_config.get("proxy", "")
@@ -449,8 +438,6 @@ class ProviderAnthropic(Provider):
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> AsyncGenerator[LLMResponse, None]:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
@@ -474,10 +461,8 @@ class ProviderAnthropic(Provider):
payloads["max_tokens"] = 65536
self._apply_thinking_config(payloads)
async with retry_provider_request_context(
"Anthropic",
lambda: self.client.messages.stream(**payloads, extra_body=extra_body),
max_attempts=request_max_retries,
async with self.client.messages.stream(
**payloads, extra_body=extra_body
) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
@@ -616,7 +601,6 @@ class ProviderAnthropic(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -666,11 +650,7 @@ class ProviderAnthropic(Provider):
llm_response = None
try:
llm_response = await self._query(
payloads,
func_tool,
request_max_retries=request_max_retries,
)
llm_response = await self._query(payloads, func_tool)
except Exception as e:
raise e
@@ -689,7 +669,6 @@ class ProviderAnthropic(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto",
request_max_retries: int | None = None,
**kwargs,
):
if contexts is None:
@@ -736,11 +715,7 @@ class ProviderAnthropic(Provider):
else system_prompt
)
async for llm_response in self._query_stream(
payloads,
func_tool,
request_max_retries=request_max_retries,
):
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
def _detect_image_mime_type(self, data: bytes) -> str:
@@ -852,10 +827,7 @@ class ProviderAnthropic(Provider):
async def get_models(self) -> list[str]:
models_str = []
models = await retry_provider_request(
"Anthropic",
lambda: self.client.models.list(),
)
models = await self.client.models.list()
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)

View File

@@ -26,7 +26,6 @@ from astrbot.core.utils.media_utils import (
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
from ..register import register_provider_adapter
from .request_retry import retry_provider_request
class SuppressNonTextPartsWarning(logging.Filter):
@@ -578,13 +577,7 @@ class ProviderGoogleGenAI(Provider):
)
return chain_result
async def _query(
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
"""非流式请求 Gemini API"""
system_instruction = next(
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
@@ -611,14 +604,10 @@ class ProviderGoogleGenAI(Provider):
modalities,
temperature,
)
result = await retry_provider_request(
"Gemini",
lambda: self.client.models.generate_content(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
),
max_attempts=request_max_retries,
result = await self.client.models.generate_content(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
)
logger.debug(f"genai result: {result}")
@@ -683,8 +672,6 @@ class ProviderGoogleGenAI(Provider):
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式请求 Gemini API"""
system_instruction = next(
@@ -703,14 +690,10 @@ class ProviderGoogleGenAI(Provider):
payloads.get("tool_choice", "auto"),
system_instruction,
)
result = await retry_provider_request(
"Gemini",
lambda: self.client.models.generate_content_stream(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
),
max_attempts=request_max_retries,
result = await self.client.models.generate_content_stream(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
)
break
except APIError as e:
@@ -826,7 +809,6 @@ class ProviderGoogleGenAI(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -868,11 +850,7 @@ class ProviderGoogleGenAI(Provider):
for _ in range(retry):
try:
return await self._query(
payloads,
func_tool,
request_max_retries=request_max_retries,
)
return await self._query(payloads, func_tool)
except APIError as e:
if await self._handle_api_error(e, keys):
continue
@@ -893,7 +871,6 @@ class ProviderGoogleGenAI(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
@@ -935,11 +912,7 @@ class ProviderGoogleGenAI(Provider):
for _ in range(retry):
try:
async for response in self._query_stream(
payloads,
func_tool,
request_max_retries=request_max_retries,
):
async for response in self._query_stream(payloads, func_tool):
yield response
break
except APIError as e:
@@ -949,10 +922,7 @@ class ProviderGoogleGenAI(Provider):
async def get_models(self):
try:
models = await retry_provider_request(
"Gemini",
lambda: self.client.models.list(),
)
models = await self.client.models.list()
return [
m.name.replace("models/", "")
for m in models

View File

@@ -41,7 +41,6 @@ from astrbot.core.utils.network_utils import (
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from ..register import register_provider_adapter
from .request_retry import retry_provider_request
@register_provider_adapter(
@@ -421,10 +420,7 @@ class ProviderOpenAIOfficial(Provider):
async def get_models(self):
try:
models_str = []
models = await retry_provider_request(
"OpenAI",
lambda: self.client.models.list(),
)
models = await self.client.models.list()
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
@@ -469,13 +465,7 @@ class ProviderOpenAIOfficial(Provider):
payloads["messages"] = cleaned
async def _query(
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
@@ -506,14 +496,10 @@ class ProviderOpenAIOfficial(Provider):
self._sanitize_assistant_messages(payloads)
completion = await retry_provider_request(
"OpenAI",
lambda: self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
),
max_attempts=request_max_retries,
completion = await self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
)
if not isinstance(completion, ChatCompletion):
@@ -531,8 +517,6 @@ class ProviderOpenAIOfficial(Provider):
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API逐步返回结果"""
if tools:
@@ -564,15 +548,11 @@ class ProviderOpenAIOfficial(Provider):
self._sanitize_assistant_messages(payloads)
stream = await retry_provider_request(
"OpenAI",
lambda: self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
),
max_attempts=request_max_retries,
stream = await self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
)
llm_response = LLMResponse("assistant", is_chunk=True)
@@ -1124,7 +1104,6 @@ class ProviderOpenAIOfficial(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
@@ -1152,11 +1131,7 @@ class ProviderOpenAIOfficial(Provider):
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
llm_response = await self._query(
payloads,
func_tool,
request_max_retries=request_max_retries,
)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
last_exception = e
@@ -1201,7 +1176,6 @@ class ProviderOpenAIOfficial(Provider):
tool_calls_result=None,
model=None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
@@ -1228,11 +1202,7 @@ class ProviderOpenAIOfficial(Provider):
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
async for response in self._query_stream(
payloads,
func_tool,
request_max_retries=request_max_retries,
):
async for response in self._query_stream(payloads, func_tool):
yield response
break
except Exception as e:

View File

@@ -1,163 +0,0 @@
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import TypeVar
from tenacity import (
AsyncRetrying,
RetryCallState,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
from astrbot import logger
from astrbot.core.utils.config_number import coerce_int_config
from astrbot.core.utils.network_utils import is_connection_error
T = TypeVar("T")
REQUEST_RETRY_ATTEMPTS = 5 # default value
REQUEST_RETRY_WAIT_MIN_S = 0.2
REQUEST_RETRY_WAIT_MAX_S = 30
REQUEST_RETRY_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504, 529}
def _get_status_code(error: BaseException) -> int | None:
for attr in ("status_code", "status", "code"):
value = getattr(error, attr, None)
if isinstance(value, int):
return value
response = getattr(error, "response", None)
if response is not None:
status_code = getattr(response, "status_code", None)
if isinstance(status_code, int):
return status_code
return None
def _is_retryable_provider_request_error(
error: BaseException,
*,
retry_rate_limits: bool,
) -> bool:
if is_connection_error(error):
return True
error_type_name = type(error).__name__
if error_type_name in {"APIConnectionError", "APITimeoutError"}:
return True
status_code = _get_status_code(error)
if status_code is None:
return False
if status_code == 429 and not retry_rate_limits:
return False
return status_code in REQUEST_RETRY_STATUS_CODES or 500 <= status_code <= 599
def _log_retry(
provider_label: str,
retry_state: RetryCallState,
max_attempts: int,
) -> None:
error = retry_state.outcome.exception() if retry_state.outcome else None
logger.warning(
f"[{provider_label}] Request failed with retryable error; "
f"retrying ({retry_state.attempt_number + 1}/{max_attempts}): "
f"{error}"
)
def _build_retrying(
provider_label: str,
*,
retry_rate_limits: bool,
max_attempts: int | None = None,
) -> AsyncRetrying:
max_attempts = coerce_int_config(
max_attempts if max_attempts is not None else REQUEST_RETRY_ATTEMPTS,
default=REQUEST_RETRY_ATTEMPTS,
min_value=1,
field_name="request_max_retries",
source=provider_label,
)
return AsyncRetrying(
retry=retry_if_exception(
lambda error: _is_retryable_provider_request_error(
error,
retry_rate_limits=retry_rate_limits,
)
),
stop=stop_after_attempt(max_attempts),
wait=wait_exponential(
multiplier=1,
min=REQUEST_RETRY_WAIT_MIN_S,
max=REQUEST_RETRY_WAIT_MAX_S,
),
before_sleep=lambda retry_state: _log_retry(
provider_label,
retry_state,
max_attempts,
),
reraise=True,
)
async def retry_provider_request(
provider_label: str,
request_factory: Callable[[], Awaitable[T]],
*,
retry_rate_limits: bool = True,
max_attempts: int | None = None,
) -> T:
retrying = _build_retrying(
provider_label,
retry_rate_limits=retry_rate_limits,
max_attempts=max_attempts,
)
async for attempt in retrying:
with attempt:
return await request_factory()
raise RuntimeError("Provider request retry loop exited unexpectedly.")
@asynccontextmanager
async def retry_provider_request_context(
provider_label: str,
context_manager_factory: Callable[[], AbstractAsyncContextManager[T]],
*,
retry_rate_limits: bool = True,
max_attempts: int | None = None,
) -> AsyncIterator[T]:
manager: AbstractAsyncContextManager[T] | None = None
async def _enter_context() -> T:
nonlocal manager
manager = context_manager_factory()
return await manager.__aenter__()
value = await retry_provider_request(
provider_label,
_enter_context,
retry_rate_limits=retry_rate_limits,
max_attempts=max_attempts,
)
if manager is None:
raise RuntimeError("Provider request context was not created.")
try:
yield value
except BaseException as error:
if await manager.__aexit__(type(error), error, error.__traceback__):
return
raise
else:
await manager.__aexit__(None, None, None)

View File

@@ -183,22 +183,8 @@ async def download_file(
path: str,
show_progress: bool = False,
progress_callback=None,
allow_insecure_ssl_fallback: bool = True,
) -> None:
"""Download a remote file to a local path.
Args:
url: Remote URL to download.
path: Local destination path.
show_progress: Whether to print progress to stdout.
progress_callback: Optional callback for progress payloads.
allow_insecure_ssl_fallback: Whether certificate failures may retry with
TLS certificate verification disabled.
Returns:
None.
"""
"""从指定 url 下载文件到指定路径 path"""
try:
ssl_context = ssl.create_default_context(
cafile=certifi.where(),
@@ -273,8 +259,6 @@ async def download_file(
},
)
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
if not allow_insecure_ssl_fallback:
raise
# 关闭SSL验证仅在证书验证失败时作为fallback
logger.warning(
f"SSL certificate verification failed for {_safe_url_for_log(url)}. "
@@ -371,22 +355,10 @@ def get_local_ip_addresses():
return network_ips
def get_dashboard_dist_version(dist_dir: str | Path) -> str | None:
"""Read the WebUI version from a dashboard dist directory.
Args:
dist_dir: Dashboard dist directory path.
Returns:
The version string from assets/version, or None when unavailable.
"""
def _read_dashboard_dist_version(dist_dir: str | Path) -> str | None:
version_file = Path(dist_dir) / "assets" / "version"
try:
if version_file.exists():
return version_file.read_text(encoding="utf-8").strip()
except (OSError, UnicodeDecodeError) as exc:
logger.warning("Failed to read WebUI version from %s: %s", version_file, exc)
if version_file.exists():
return version_file.read_text(encoding="utf-8").strip()
return None
@@ -408,106 +380,42 @@ def _normalize_dashboard_version(version: str) -> str:
return version
def is_dashboard_version_compatible(
dashboard_version: str | None, current_version: str
def should_use_bundled_dashboard_dist(
user_dist: str | Path, current_version: str
) -> bool:
"""Check whether a WebUI version matches the current core version.
Args:
dashboard_version: Version read from the WebUI assets/version file.
current_version: Current AstrBot core version.
Returns:
True when both versions are valid SemVer values and compare equal.
"""
if dashboard_version is None:
user_version = _read_dashboard_dist_version(user_dist)
bundled_dist = get_bundled_dashboard_dist_path()
if user_version is None or not bundled_dist.exists():
return False
try:
return (
VersionComparator.compare_version(
_normalize_dashboard_version(dashboard_version),
_normalize_dashboard_version(current_version),
_normalize_dashboard_version(user_version),
)
== 0
> 0
)
except (TypeError, ValueError):
return False
def is_dashboard_dist_compatible(dist_dir: str | Path, current_version: str) -> bool:
"""Check whether a WebUI dist is complete and matches the core version.
Args:
dist_dir: Dashboard dist directory path.
current_version: Current AstrBot core version.
Returns:
True when the dist has an index file and a compatible assets/version.
"""
dist_path = Path(dist_dir)
return (dist_path / "index.html").is_file() and is_dashboard_version_compatible(
get_dashboard_dist_version(dist_path),
current_version,
)
def should_use_bundled_dashboard_dist(
user_dist: str | Path, current_version: str
) -> bool:
"""Decide whether bundled WebUI should replace a user data dist.
Args:
user_dist: Runtime dashboard dist directory under data/.
current_version: Current AstrBot core version.
Returns:
True when user_dist exists but is missing or mismatched against the
current core version, and bundled WebUI matches the current core version.
"""
user_dist = Path(user_dist)
user_version = get_dashboard_dist_version(user_dist)
bundled_dist = get_bundled_dashboard_dist_path()
if not user_dist.exists() or not is_dashboard_dist_compatible(
bundled_dist,
current_version,
):
return False
if user_version is None or not (user_dist / "index.html").is_file():
return True
try:
return not is_dashboard_version_compatible(user_version, current_version)
except (TypeError, ValueError):
return False
async def get_dashboard_version():
"""Return the effective WebUI version for the current runtime.
Returns:
The matching data/dist version, matching bundled version, or the raw
data/dist version when no compatible bundled WebUI is available.
"""
from astrbot.core.config.default import VERSION
# First check user data directory (manually updated / downloaded dashboard).
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
if os.path.exists(dist_dir):
user_version = get_dashboard_dist_version(dist_dir)
if is_dashboard_dist_compatible(dist_dir, VERSION):
return user_version
from astrbot.core.config.default import VERSION
bundled = get_bundled_dashboard_dist_path()
if is_dashboard_dist_compatible(bundled, VERSION):
return get_dashboard_dist_version(bundled)
return user_version
if should_use_bundled_dashboard_dist(dist_dir, VERSION):
bundled_version = _read_dashboard_dist_version(
get_bundled_dashboard_dist_path()
)
if bundled_version is not None:
return bundled_version
return _read_dashboard_dist_version(dist_dir)
bundled = get_bundled_dashboard_dist_path()
if is_dashboard_dist_compatible(bundled, VERSION):
return get_dashboard_dist_version(bundled)
if bundled.exists():
return _read_dashboard_dist_version(bundled)
return None
@@ -519,7 +427,6 @@ async def download_dashboard(
proxy: str | None = None,
progress_callback=None,
extract: bool = True,
allow_insecure_ssl_fallback: bool = True,
) -> None:
"""Download dashboard assets and optionally extract them.
@@ -531,8 +438,6 @@ async def download_dashboard(
proxy: Optional download proxy prefix.
progress_callback: Optional callback for download progress payloads.
extract: Whether to extract the archive after download.
allow_insecure_ssl_fallback: Whether certificate failures may retry with
TLS certificate verification disabled.
Returns:
None.
@@ -555,7 +460,6 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError(
@@ -587,7 +491,6 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError(
@@ -603,7 +506,6 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError("Downloaded dashboard package is not a valid ZIP file")

View File

@@ -22,9 +22,7 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import (
get_bundled_dashboard_dist_path,
get_dashboard_dist_version,
get_local_ip_addresses,
is_dashboard_dist_compatible,
should_use_bundled_dashboard_dist,
)
from astrbot.dashboard.asgi_runtime import (
@@ -184,41 +182,21 @@ class AstrBotDashboard:
# Path priority:
# 1. Explicit webui_dir argument
# 2. data/dist/ when it matches the core version
# 3. astrbot/dashboard/dist/ when it matches the core version
# 2. data/dist/ (user-installed / manually updated dashboard)
# 3. astrbot/dashboard/dist/ (bundled with the wheel)
if webui_dir and os.path.exists(webui_dir):
self.data_path = os.path.abspath(webui_dir)
else:
user_dist = os.path.join(get_astrbot_data_path(), "dist")
bundled_dist = get_bundled_dashboard_dist_path()
user_version = get_dashboard_dist_version(user_dist)
if os.path.exists(user_dist) and is_dashboard_dist_compatible(
if os.path.exists(user_dist) and not should_use_bundled_dashboard_dist(
user_dist,
VERSION,
):
self.data_path = os.path.abspath(user_dist)
elif should_use_bundled_dashboard_dist(
user_dist,
VERSION,
) or is_dashboard_dist_compatible(bundled_dist, VERSION):
elif bundled_dist.exists():
self.data_path = str(bundled_dist)
logger.info("Using bundled dashboard dist: %s", self.data_path)
elif (
os.path.exists(user_dist) and (Path(user_dist) / "index.html").is_file()
):
logger.warning(
"Using existing data/dist as a fallback even though WebUI version mismatches core: %s, expected v%s. "
"Some dashboard features may not work until the matching WebUI is available.",
user_version,
VERSION,
)
self.data_path = os.path.abspath(user_dist)
elif os.path.exists(user_dist):
logger.warning(
"Ignoring data/dist because WebUI files are incomplete for core v%s.",
VERSION,
)
self.data_path = None
else:
# Fall back to expected user path (will fail gracefully later)
self.data_path = os.path.abspath(user_dist)
@@ -567,7 +545,7 @@ class AstrBotDashboard:
raise Exception(f"端口 {port} 已被占用")
if self.data_path and (Path(self.data_path) / "index.html").is_file():
if (Path(self.data_path) / "index.html").is_file():
webui_status = "WebUI is ready"
else:
webui_status = (

View File

@@ -1,22 +0,0 @@
- [更新日志(简体中文)](#chinese)
- [Changelog(English)](#english)
<a id="chinese"></a>
## What's Changed
### 修复
- 恢复 WebUI 在接口返回 401 时跳转登录页,避免会话失效后停留在异常状态。([#8903](https://github.com/AstrBotDevs/AstrBot/pull/8903))
- 保持 Core 版本与 WebUI 静态资源版本同步,修复打包或升级后可能加载旧 dist、资源版本错配的问题。([#8901](https://github.com/AstrBotDevs/AstrBot/pull/8901))
- 将知识库上下文作为临时 user 内容注入,修复模型请求中知识库上下文角色不准确的问题。([#8904](https://github.com/AstrBotDevs/AstrBot/pull/8904))
<a id="english"></a>
## What's Changed (EN)
### Bug Fixes
- Restored the WebUI login redirect when API requests return 401, preventing expired sessions from staying in a broken state. ([#8903](https://github.com/AstrBotDevs/AstrBot/pull/8903))
- Kept Core and WebUI static asset versions in sync, fixing stale dist loading and asset version mismatches after packaging or upgrades. ([#8901](https://github.com/AstrBotDevs/AstrBot/pull/8901))
- Injected knowledge base context as temporary user content, fixing the role used for knowledge context in model requests. ([#8904](https://github.com/AstrBotDevs/AstrBot/pull/8904))

View File

@@ -1,54 +0,0 @@
- [更新日志(简体中文)](#chinese)
- [Changelog(English)](#english)
<a id="chinese"></a>
## What's Changed
### 重点更新
- 为 OpenAI、Gemini、Anthropic 等模型请求加入可配置的重试机制,并新增请求最大重试次数配置,提升临时网络错误与 5xx 服务端错误下的稳定性。([#8893](https://github.com/AstrBotDevs/AstrBot/pull/8893))
- 新增托管 Core 包下载能力,并加强 Core 与 Dashboard 包下载归档校验。([#8888](https://github.com/AstrBotDevs/AstrBot/pull/8888))
- 支持在请求中加载 workspace skills并加固 workspace skill 发现流程。([#8884](https://github.com/AstrBotDevs/AstrBot/pull/8884))
### 修复
- 修复 OpenAPI 文件上传能力,恢复 `/api/v1/file` OpenAPI 暴露、文件范围 API Key 与相关文档/客户端产物。
- 修复新版 MCP 中 Streamable HTTP client 重命名导致的兼容问题,并保持 `mcp` 依赖小于 2。
- 加固人格工具边界,确保人格限定的工具范围在主 Agent 请求中正确生效。([#8786](https://github.com/AstrBotDevs/AstrBot/pull/8786))
- 加强 Future Task 所有者校验,避免越权访问定时任务。([#8881](https://github.com/AstrBotDevs/AstrBot/pull/8881))
- 在受限本地文件系统工具中拒绝 hardlink 文件,避免通过工作区 hardlink 别名读写允许目录外的文件。
### 发布流程
- 新增 `scripts/prepare_release.py`,统一 release 分支、版本号、changelog 与校验流程。([#8891](https://github.com/AstrBotDevs/AstrBot/pull/8891))
### 文档
- 明确 OpenAPI Chat 中 `username` 字段的身份含义。([#8880](https://github.com/AstrBotDevs/AstrBot/pull/8880))
<a id="english"></a>
## What's Changed (EN)
### Highlights
- Added configurable retry handling for OpenAI, Gemini, Anthropic, and related provider requests, including a maximum request retry setting to improve stability for transient network failures and 5xx server errors. ([#8893](https://github.com/AstrBotDevs/AstrBot/pull/8893))
- Added hosted Core package downloads and strengthened archive validation for hosted Core and Dashboard packages. ([#8888](https://github.com/AstrBotDevs/AstrBot/pull/8888))
- Added workspace skills support in requests and hardened workspace skill discovery. ([#8884](https://github.com/AstrBotDevs/AstrBot/pull/8884))
### Bug Fixes
- Restored OpenAPI file uploads by exposing `/api/v1/file`, enabling file-scoped API keys, and regenerating docs/client artifacts.
- Fixed compatibility with the renamed MCP Streamable HTTP client while keeping the `mcp` dependency below 2.
- Hardened persona tool boundaries so persona-restricted tool scopes are enforced correctly in main Agent requests. ([#8786](https://github.com/AstrBotDevs/AstrBot/pull/8786))
- Enforced Future Task owner checks to prevent unauthorized scheduled-task access. ([#8881](https://github.com/AstrBotDevs/AstrBot/pull/8881))
- Rejected hardlinked files in restricted local filesystem tools to prevent workspace hardlink aliases from reading or overwriting files outside allowed directories.
### Release Process
- Added `scripts/prepare_release.py` to standardize release branches, version bumps, changelog generation, and validation. ([#8891](https://github.com/AstrBotDevs/AstrBot/pull/8891))
### Docs
- Clarified the identity semantics of the `username` field in OpenAPI Chat. ([#8880](https://github.com/AstrBotDevs/AstrBot/pull/8880))

View File

@@ -48,55 +48,6 @@ function attachAxiosHeaders(config: InternalAxiosRequestConfig) {
}
function normalizeAxiosError(error: AxiosError) {
if (error.response?.status === 401) {
let requestPath = '';
try {
const url = error.config?.url || '';
const baseURL = error.config?.baseURL;
const resolvedUrl =
url && baseURL && !/^([a-z][a-z\d+\-.]*:)?\/\//i.test(url)
? `${baseURL.replace(/\/+$/, '')}/${url.replace(/^\/+/, '')}`
: url;
const requestUrl = new URL(resolvedUrl || '/', window.location.origin);
if (requestUrl.origin === window.location.origin) {
requestPath = requestUrl.pathname;
}
} catch {
requestPath = '';
}
const isAuthChallenge =
[
'/api/auth/login',
'/api/auth/setup',
'/api/auth/setup-status',
'/api/v1/auth/login',
'/api/v1/auth/setup',
'/api/v1/auth/setup-status',
].includes(requestPath) ||
Boolean(
(
error.response.data as
| { data?: { totp_required?: boolean } }
| undefined
)?.data?.totp_required,
);
if (requestPath.startsWith('/api/') && !isAuthChallenge) {
[
'user',
'token',
'change_pwd_hint',
'md5_pwd_hint',
'password_upgrade_required',
].forEach((key) => localStorage.removeItem(key));
if (!window.location.hash.startsWith('#/auth/login')) {
window.location.hash = '/auth/login';
}
}
}
if (error.response?.status === 429) {
const data = error.response.data as { message?: string } | undefined;
if (data?.message) {

View File

@@ -45,10 +45,6 @@
"description": "Fallback chat model IDs",
"hint": "When the primary chat model request fails, fallback to these chat models in order."
},
"request_max_retries": {
"description": "Request Max Retries",
"hint": "Maximum attempts for a single model request when retryable errors occur."
},
"default_image_caption_provider_id": {
"description": "Default Image Caption Model",
"hint": "Leave empty to disable; useful for non-multimodal models"

View File

@@ -45,10 +45,6 @@
"description": "Резервные модели чата (ID)",
"hint": "Если текущая модель недоступна, запрос будет перенаправлен на эти модели по порядку."
},
"request_max_retries": {
"description": "Максимум повторов запроса",
"hint": "Максимальное число попыток для одного запроса модели при повторяемых ошибках."
},
"default_image_caption_provider_id": {
"description": "Модель описания изображений",
"hint": "Оставьте пустым для отключения; полезно для моделей без поддержки мультимодальности"

View File

@@ -45,10 +45,6 @@
"description": "回退对话模型列表",
"hint": "主对话模型请求失败时,按顺序切换到这些对话模型。"
},
"request_max_retries": {
"description": "请求最大重试次数",
"hint": "单次模型请求遇到可重试错误时的最大尝试次数。"
},
"default_image_caption_provider_id": {
"description": "默认图片转述模型",
"hint": "留空代表不使用,可用于非多模态模型"

97
main.py
View File

@@ -2,7 +2,6 @@ import argparse
import asyncio
import mimetypes
import os
import shutil
import sys
from pathlib import Path
@@ -47,10 +46,7 @@ from astrbot.core.utils.astrbot_path import ( # noqa: E402
from astrbot.core.utils.io import ( # noqa: E402
download_dashboard,
get_bundled_dashboard_dist_path,
get_dashboard_dist_version,
is_dashboard_dist_compatible,
is_dashboard_version_compatible,
remove_dir,
get_dashboard_version,
should_use_bundled_dashboard_dist,
)
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime # noqa: E402
@@ -95,15 +91,7 @@ def check_env() -> None:
async def check_dashboard_files(webui_dir: str | None = None):
"""Resolve and repair dashboard static files for startup.
Args:
webui_dir: Optional explicit WebUI directory path from CLI.
Returns:
The directory path to serve, or None when no usable WebUI can be prepared.
"""
"""下载管理面板文件"""
# 指定webui目录
if webui_dir:
if os.path.exists(webui_dir):
@@ -111,89 +99,40 @@ async def check_dashboard_files(webui_dir: str | None = None):
return webui_dir
logger.warning("WebUI directory not found: %s. Using default.", webui_dir)
data_dist_path = Path(get_astrbot_data_path()) / "dist"
bundled_dist = get_bundled_dashboard_dist_path()
if data_dist_path.exists():
v = get_dashboard_dist_version(data_dist_path)
if is_dashboard_dist_compatible(data_dist_path, VERSION):
logger.info("WebUI is up to date.")
return str(data_dist_path)
data_dist_path = os.path.join(get_astrbot_data_path(), "dist")
if os.path.exists(data_dist_path):
v = await get_dashboard_version()
if should_use_bundled_dashboard_dist(data_dist_path, VERSION):
bundled_dist = get_bundled_dashboard_dist_path()
logger.info(
"Replacing data/dist with bundled WebUI because its version does not match core version v%s.",
"Using bundled WebUI because data/dist is older than core version v%s.",
VERSION,
)
try:
remove_dir(str(data_dist_path))
shutil.copytree(bundled_dist, data_dist_path)
return str(data_dist_path)
except Exception as e:
return str(bundled_dist)
if v is not None:
# 存在文件
if v == f"v{VERSION}":
logger.info("WebUI is up to date.")
else:
logger.warning(
"Failed to replace data/dist with bundled WebUI: %s. Using bundled WebUI directly.",
e,
)
return str(bundled_dist)
if is_dashboard_version_compatible(v, VERSION):
logger.warning(
"WebUI files are incomplete for v%s. Re-downloading WebUI.",
VERSION,
)
elif v is not None:
logger.warning(
"WebUI version mismatch: %s, expected v%s. Re-downloading WebUI.",
v,
VERSION,
)
else:
logger.warning(
"WebUI version file is missing. Re-downloading WebUI v%s.",
VERSION,
)
try:
await download_dashboard(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
if (data_dist_path / "index.html").is_file():
logger.warning(
"Falling back to existing data/dist WebUI %s even though core expects v%s. "
"Some dashboard features may not work until the matching WebUI is available.",
v or "unknown",
"WebUI version mismatch: %s, expected v%s.",
v,
VERSION,
)
return str(data_dist_path)
return None
logger.info("管理面板下载完成。")
return str(data_dist_path)
if is_dashboard_dist_compatible(bundled_dist, VERSION):
logger.info(
"Using bundled WebUI v%s.", get_dashboard_dist_version(bundled_dist)
)
return str(bundled_dist)
return data_dist_path
logger.info(
"Downloading WebUI. If it fails, download dist.zip from https://github.com/AstrBotDevs/AstrBot/releases/latest and extract dist to data/.",
)
try:
await download_dashboard(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
await download_dashboard(version=f"v{VERSION}", latest=False)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成。")
return str(data_dist_path)
return data_dist_path
async def main_async(webui_dir_arg: str | None) -> None:

View File

@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.26.0-beta.10"
version = "4.26.0-beta.8"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
license = { text = "AGPL-3.0-or-later" }
@@ -29,7 +29,7 @@ dependencies = [
"google-genai>=1.56.0",
"httpx[socks]>=0.28.1",
"lark-oapi>=1.4.15",
"mcp>=1.8.0,<2",
"mcp>=1.8.0",
"openai>=1.78.0",
"ormsgpack>=1.9.1",
"pillow>=11.2.1",

View File

@@ -18,7 +18,7 @@ filelock>=3.18.0
google-genai>=1.56.0
httpx[socks]>=0.28.1
lark-oapi>=1.4.15
mcp>=1.8.0,<2
mcp>=1.8.0
openai>=1.78.0
ormsgpack>=1.9.1
pillow>=11.2.1

View File

@@ -1,12 +1,9 @@
import builtins
from types import SimpleNamespace
import httpx
import pytest
import astrbot.core.provider.sources.anthropic_source as anthropic_source
import astrbot.core.provider.sources.kimi_code_source as kimi_code_source
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse
@@ -174,36 +171,6 @@ def test_create_http_client_falls_back_to_global_httpx_module(monkeypatch):
assert captured["httpx_module"] is anthropic_source.httpx
@pytest.mark.asyncio
async def test_anthropic_get_models_retries_transient_request_error(monkeypatch):
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0)
class FakeModels:
def __init__(self):
self.calls = 0
async def list(self):
self.calls += 1
if self.calls == 1:
raise httpx.ConnectError("temporary connection failure")
return SimpleNamespace(
data=[
SimpleNamespace(id="claude-b"),
SimpleNamespace(id="claude-a"),
]
)
models = FakeModels()
provider = anthropic_source.ProviderAnthropic.__new__(
anthropic_source.ProviderAnthropic
)
provider.client = SimpleNamespace(models=models)
assert await provider.get_models() == ["claude-a", "claude-b"]
assert models.calls == 2
@pytest.mark.asyncio
async def test_text_chat_wraps_string_system_prompt_as_list(monkeypatch):
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)
@@ -220,7 +187,7 @@ async def test_text_chat_wraps_string_system_prompt_as_list(monkeypatch):
captured_payloads: dict[str, object] = {}
async def fake_query(payloads, tools, *, request_max_retries=None):
async def fake_query(payloads, tools):
captured_payloads.update(payloads)
return LLMResponse(role="assistant", completion_text="ok")
@@ -247,7 +214,7 @@ async def test_text_chat_passes_through_list_system_prompt(monkeypatch):
captured_payloads: dict[str, object] = {}
async def fake_query(payloads, tools, *, request_max_retries=None):
async def fake_query(payloads, tools):
captured_payloads.update(payloads)
return LLMResponse(role="assistant", completion_text="ok")

View File

@@ -273,7 +273,6 @@ def test_dashboard_uses_bundled_dist_when_data_dist_is_stale(
bundled_dist = tmp_path / "bundled-dist"
user_dist.mkdir(parents=True)
bundled_dist.mkdir()
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
monkeypatch.setattr(
"astrbot.dashboard.server.get_astrbot_data_path",
@@ -294,59 +293,6 @@ def test_dashboard_uses_bundled_dist_when_data_dist_is_stale(
assert server.data_path == str(bundled_dist)
def test_dashboard_falls_back_to_mismatched_data_dist_without_bundled(
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
tmp_path,
):
data_dir = tmp_path / "data"
user_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
(user_dist / "index.html").write_text("stale", encoding="utf-8")
monkeypatch.setattr(
"astrbot.dashboard.server.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.dashboard.server.get_bundled_dashboard_dist_path",
lambda: bundled_dist,
)
shutdown_event = asyncio.Event()
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
assert server.data_path == str(user_dist)
def test_dashboard_ignores_incomplete_mismatched_data_dist_without_bundled(
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
tmp_path,
):
data_dir = tmp_path / "data"
user_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
monkeypatch.setattr(
"astrbot.dashboard.server.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.dashboard.server.get_bundled_dashboard_dist_path",
lambda: bundled_dist,
)
shutdown_event = asyncio.Event()
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
assert server.data_path is None
async def _set_dashboard_password_change_required(
core_lifecycle_td: AstrBotCoreLifecycle,
required: bool,

View File

@@ -1,10 +1,6 @@
from types import SimpleNamespace
import httpx
import pytest
from astrbot.core.exceptions import EmptyModelOutputError
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
@@ -31,35 +27,3 @@ def test_gemini_reasoning_only_output_is_allowed():
response_id="resp_reasoning",
finish_reason="STOP",
)
@pytest.mark.asyncio
async def test_gemini_get_models_retries_transient_request_error(monkeypatch):
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0)
class FakeModels:
def __init__(self):
self.calls = 0
async def list(self):
self.calls += 1
if self.calls == 1:
raise httpx.ConnectError("temporary connection failure")
return [
SimpleNamespace(
name="models/gemini-a",
supported_actions=["generateContent"],
),
SimpleNamespace(
name="models/gemini-b",
supported_actions=["embedContent"],
),
]
models = FakeModels()
provider = ProviderGoogleGenAI.__new__(ProviderGoogleGenAI)
provider.client = SimpleNamespace(models=models)
assert await provider.get_models() == ["gemini-a"]
assert models.calls == 2

View File

@@ -9,7 +9,7 @@ from unittest import mock
import pytest
from astrbot.core.utils.io import get_dashboard_version, should_use_bundled_dashboard_dist
from astrbot.core.utils.io import should_use_bundled_dashboard_dist
from main import (
DASHBOARD_RESET_PASSWORD_ENV,
_apply_startup_env_flags,
@@ -173,146 +173,49 @@ def test_version_info_comparisons():
@pytest.mark.asyncio
async def test_check_dashboard_files_not_exists(tmp_path):
async def test_check_dashboard_files_not_exists(monkeypatch):
"""Tests dashboard download when files do not exist."""
data_dir = tmp_path / "data"
bundled_dist = tmp_path / "bundled-dist"
monkeypatch.setattr(os.path, "exists", lambda x: False)
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
from main import VERSION
assert result == str(data_dir / "dist")
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
mock_download.assert_called_once()
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_and_version_match(tmp_path):
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
"""Tests that dashboard is not downloaded when it exists and version matches."""
from main import VERSION
# Mock os.path.exists to return True
monkeypatch.setattr(os.path, "exists", lambda x: True)
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(data_dist / "index.html").write_text("user", encoding="utf-8")
# Mock get_dashboard_version to return the current version
with mock.patch("main.get_dashboard_version") as mock_get_version:
# We need to import VERSION from main's context
from main import VERSION
mock_get_version.return_value = f"v{VERSION}"
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
assert result == str(data_dist)
await check_dashboard_files()
# Assert that download_dashboard was NOT called
mock_download.assert_not_called()
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_but_version_mismatch_downloads(tmp_path):
"""Tests that a mismatched dashboard is downloaded on startup."""
from main import VERSION
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
"""Tests that a warning is logged when dashboard version mismatches."""
monkeypatch.setattr(os.path, "exists", lambda x: True)
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.logger.warning") as mock_logger_warning:
result = await check_dashboard_files()
assert result == str(data_dist)
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
with mock.patch(
"main.get_dashboard_version", mock.AsyncMock(return_value="v0.0.1")
):
with mock.patch("main.logger.warning") as mock_logger_warning:
await check_dashboard_files()
mock_logger_warning.assert_called_once()
call_args, _ = mock_logger_warning.call_args
assert "WebUI version mismatch" in call_args[0]
@pytest.mark.asyncio
async def test_check_dashboard_files_falls_back_to_stale_dist_when_download_fails(
tmp_path,
):
"""Tests stale dashboard fallback when the matching WebUI cannot be downloaded."""
from main import VERSION
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
(data_dist / "index.html").write_text("stale", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch(
"main.download_dashboard",
side_effect=RuntimeError("missing dashboard asset"),
) as mock_download:
with mock.patch("main.logger.warning") as mock_logger_warning:
result = await check_dashboard_files()
assert result == str(data_dist)
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
assert any(
"Falling back to existing data/dist WebUI" in call.args[0]
for call in mock_logger_warning.call_args_list
)
@pytest.mark.asyncio
async def test_check_dashboard_files_downloads_when_matching_dist_is_incomplete(
tmp_path,
):
"""Tests that a version match alone is not enough to serve WebUI."""
from main import VERSION
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
assert result == str(data_dist)
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
@@ -320,7 +223,6 @@ def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
(bundled_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("v4.24.2", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
@@ -329,94 +231,46 @@ def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
assert should_use_bundled_dashboard_dist(user_dist, "v4.24.4") is True
def test_should_use_bundled_dashboard_dist_when_version_file_is_malformed(tmp_path):
def test_should_keep_data_dist_when_version_file_is_malformed(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("not-a-version", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is True
def test_should_use_bundled_dashboard_dist_when_data_version_file_is_missing(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is True
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is False
@pytest.mark.asyncio
async def test_get_dashboard_version_uses_bundled_dist_when_data_dist_is_missing(
async def test_check_dashboard_files_uses_bundled_dist_when_data_dist_is_stale(
tmp_path,
):
"""Tests bundled WebUI version lookup when data/dist is absent."""
from main import VERSION
data_dir = tmp_path / "data"
bundled_dist = tmp_path / "bundled-dist"
(bundled_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_astrbot_data_path",
return_value=str(data_dir),
):
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert await get_dashboard_version() == f"v{VERSION}"
@pytest.mark.asyncio
async def test_check_dashboard_files_replaces_stale_data_dist_with_bundled_dist(
tmp_path,
):
"""Tests that a stale data/dist is repaired from bundled dashboard assets."""
from main import VERSION
"""Tests that a stale data/dist does not override bundled dashboard assets."""
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
(data_dist / "old.txt").write_text("old", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
data_dist.mkdir(parents=True)
bundled_dist.mkdir()
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
"main.get_dashboard_version", mock.AsyncMock(return_value="v0.0.1")
):
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
"main.should_use_bundled_dashboard_dist", return_value=True
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
assert result == str(data_dist)
assert (data_dist / "assets" / "version").read_text(encoding="utf-8") == f"v{VERSION}"
assert (data_dist / "index.html").read_text(encoding="utf-8") == "bundled"
assert not (data_dist / "old.txt").exists()
assert result == str(bundled_dist)
mock_download.assert_not_called()
@@ -427,7 +281,7 @@ async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.get_dashboard_dist_version") as mock_get_version:
with mock.patch("main.get_dashboard_version") as mock_get_version:
result = await check_dashboard_files(webui_dir=valid_dir)
assert result == valid_dir
mock_download.assert_not_called()

View File

@@ -3,16 +3,13 @@ import builtins
from io import BytesIO
from types import SimpleNamespace
import httpx
import pytest
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from PIL import Image as PILImage
import astrbot.core.provider.sources.openai_source as openai_source_module
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.groq_source import ProviderGroq
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
from astrbot.core.utils.media_utils import ResolvedMediaData, file_uri_to_path
@@ -120,57 +117,6 @@ def test_create_http_client_falls_back_to_global_httpx_module(monkeypatch):
assert captured["httpx_module"] is openai_source_module.httpx
@pytest.mark.asyncio
async def test_get_models_retries_transient_request_error(monkeypatch):
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0)
class FakeModels:
def __init__(self):
self.calls = 0
async def list(self):
self.calls += 1
if self.calls == 1:
raise httpx.ConnectError("temporary connection failure")
return SimpleNamespace(
data=[
SimpleNamespace(id="gpt-b"),
SimpleNamespace(id="gpt-a"),
]
)
models = FakeModels()
provider = ProviderOpenAIOfficial.__new__(ProviderOpenAIOfficial)
provider.client = SimpleNamespace(models=models)
assert await provider.get_models() == ["gpt-a", "gpt-b"]
assert models.calls == 2
@pytest.mark.asyncio
async def test_text_chat_passes_request_max_retries_to_query():
captured: dict[str, object] = {}
provider = ProviderOpenAIOfficial.__new__(ProviderOpenAIOfficial)
provider.api_keys = ["test-key"]
provider.client = SimpleNamespace(api_key=None)
async def fake_prepare_chat_payload(*args, **kwargs):
return {"messages": [], "model": "gpt-4o-mini"}, []
async def fake_query(payloads, func_tool, *, request_max_retries=None):
captured["request_max_retries"] = request_max_retries
return LLMResponse(role="assistant", completion_text="ok")
provider._prepare_chat_payload = fake_prepare_chat_payload
provider._query = fake_query
await provider.text_chat(prompt="hello", request_max_retries=2)
assert captured["request_max_retries"] == 2
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_removes_images():
provider = _make_provider(

View File

@@ -1,27 +0,0 @@
import httpx
import pytest
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.provider.sources.request_retry import retry_provider_request
@pytest.mark.asyncio
async def test_retry_provider_request_uses_configured_max_retries(monkeypatch):
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0)
calls = 0
async def request():
nonlocal calls
calls += 1
raise httpx.ConnectError("temporary connection failure")
with pytest.raises(httpx.ConnectError):
await retry_provider_request(
"Test",
request,
max_attempts=2,
)
assert calls == 2

View File

@@ -440,7 +440,6 @@ async def test_download_dashboard_falls_back_when_hosted_package_is_not_zip(
path: str,
show_progress: bool = False, # noqa: ARG001
progress_callback=None, # noqa: ARG001
allow_insecure_ssl_fallback: bool = True, # noqa: ARG001
) -> None:
calls.append(url)
parsed = urlparse(url)

View File

@@ -8,7 +8,6 @@ import pytest
from astrbot.core import astr_main_agent as ama
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.message import Message, dump_messages_with_checkpoints
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Plain, Reply, Video
@@ -378,18 +377,8 @@ class TestApplyKb:
):
await module._apply_kb(mock_event, req, mock_context, config)
assert req.system_prompt == "System prompt"
assert len(req.extra_user_content_parts) == 1
kb_part = req.extra_user_content_parts[0]
assert kb_part.text == "[Related Knowledge Base Results]:\nKB result"
message = Message.model_validate(await req.assemble_context())
assert isinstance(message.content, list)
assert message.content[0].text == "test question"
assert message.content[1].text == "[Related Knowledge Base Results]:\nKB result"
assert dump_messages_with_checkpoints([message]) == [
{"role": "user", "content": [{"type": "text", "text": "test question"}]}
]
assert "[Related Knowledge Base Results]:" in req.system_prompt
assert "KB result" in req.system_prompt
@pytest.mark.asyncio
async def test_apply_kb_with_agentic_mode(self, mock_event, mock_context):
@@ -1009,7 +998,7 @@ class TestEnsurePersonaAndSkills:
assert req.func_tool is not None
@pytest.mark.asyncio
async def test_persona_empty_tools_keeps_late_builtin_tools(
async def test_persona_empty_tools_filters_late_builtin_tools(
self, mock_event, mock_context, mock_provider
):
module = ama
@@ -1017,7 +1006,6 @@ class TestEnsurePersonaAndSkills:
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=("locked", persona, None, False)
)
mock_event.platform_meta.support_proactive_message = False
mock_context.get_config.return_value = {
"provider_settings": {
"web_search": True,
@@ -1031,7 +1019,6 @@ class TestEnsurePersonaAndSkills:
"websearch_provider": "baidu_ai_search",
},
computer_use_runtime="none",
add_cron_tools=False,
)
req = ProviderRequest(prompt="hello")
req.conversation = MagicMock(persona_id="locked", history="[]")
@@ -1054,52 +1041,9 @@ class TestEnsurePersonaAndSkills:
)
assert result is not None
try:
assert result.provider_request.func_tool is not None
assert result.provider_request.func_tool.names() == ["web_search_baidu"]
finally:
if result.reset_coro:
result.reset_coro.close()
@pytest.mark.asyncio
async def test_persona_empty_tools_keeps_local_runtime_builtin_tools(
self, mock_event, mock_context, mock_provider
):
module = ama
persona = {"name": "locked", "prompt": "No tools.", "tools": []}
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=("locked", persona, None, False)
)
mock_event.platform_meta.support_proactive_message = False
config = module.MainAgentBuildConfig(
tool_call_timeout=60,
computer_use_runtime="local",
add_cron_tools=False,
)
req = ProviderRequest(prompt="hello")
req.conversation = MagicMock(persona_id="locked", history="[]")
with (
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner
result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=config,
provider=mock_provider,
req=req,
apply_reset=False,
assert result.provider_request.func_tool is None or (
result.provider_request.func_tool.empty()
)
assert result is not None
try:
assert result.provider_request.func_tool is not None
tool_names = result.provider_request.func_tool.names()
assert "astrbot_execute_shell" in tool_names
assert "astrbot_execute_python" in tool_names
finally:
if result.reset_coro:
result.reset_coro.close()