Compare commits

..

8 Commits

Author SHA1 Message Date
Soulter
eea74cf909 fix: prevent cli init from creating cwd data 2026-06-19 23:57:03 +08:00
lxfight
2d98d38078 fix: inject knowledge base context as temporary user content (#8904) 2026-06-19 22:48:23 +08:00
Weilong Liao
1b0f5cb0d3 fix: keep WebUI assets in sync with core version (#8901)
* fix: keep WebUI assets in sync with core version

* fix: import dashboard version before bundled fallback

* fix: remove stale WebUI dist robustly
2026-06-19 22:46:38 +08:00
Weilong Liao
cdfb0bdf91 fix: restore webui 401 login redirect (#8903) 2026-06-19 22:43:21 +08:00
Weilong Liao
3760abb39b chore: bump version to 4.26.0-beta.9 (#8895) 2026-06-19 17:47:17 +08:00
Weilong Liao
272242e407 chore: add release preparation workflow (#8891)
* chore: add release preparation workflow

* fix: address release workflow review feedback
2026-06-19 17:41:13 +08:00
Weilong Liao
dd36979eca feat: implement request retry mechanism for provider requests (#8893)
* feat: implement request retry mechanism for provider requests

* feat: add request max retries configuration and implement retry logic for provider requests

* feat: update fake_query function to accept request_max_retries parameter

* feat: remove retry_rate_limits from provider request calls
2026-06-19 17:13:40 +08:00
Weilong Liao
143f846b92 fix: support renamed MCP streamable HTTP client
Support both MCP streamable HTTP client names and keep mcp dependency below 2.
2026-06-19 15:54:47 +08:00
36 changed files with 1220 additions and 175 deletions

View File

@@ -16,8 +16,11 @@ venv*/
ENV/
.conda/
dashboard/
!astrbot/dashboard/
!astrbot/dashboard/dist/
!astrbot/dashboard/dist/**
data/
tests/
.ruff_cache/
.astrbot
astrbot.lock
astrbot.lock

View File

@@ -46,14 +46,21 @@ jobs:
- name: Build Dashboard
run: |
dashboard_version=$(python3 - <<'PY'
import tomllib
with open("pyproject.toml", "rb") as f:
print("v" + tomllib.load(f)["project"]["version"])
PY
)
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
echo "$dashboard_version" > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
mkdir -p astrbot/dashboard
rm -rf astrbot/dashboard/dist
cp -r dashboard/dist astrbot/dashboard/dist
- name: Determine test image tags
id: test-meta
@@ -157,10 +164,11 @@ jobs:
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
echo "${{ steps.release-meta.outputs.version }}" > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
mkdir -p astrbot/dashboard
rm -rf astrbot/dashboard/dist
cp -r dashboard/dist astrbot/dashboard/dist
- name: Set QEMU
uses: docker/setup-qemu-action@v4.1.0

View File

@@ -1,3 +1,3 @@
from .core.log import LogManager
import logging
logger = LogManager.GetLogger(log_name="astrbot")
logger = logging.getLogger("astrbot")

View File

@@ -1,3 +1,32 @@
from astrbot.core.config.default import VERSION
import re
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as package_version
from pathlib import Path
__version__ = VERSION
try:
import tomllib
except ModuleNotFoundError:
tomllib = None
try:
__version__ = package_version("astrbot")
except PackageNotFoundError:
pyproject_path = Path(__file__).resolve().parents[2] / "pyproject.toml"
try:
if tomllib is None:
match = re.search(
r"(?m)^version\s*=\s*[\"']([^\"']+)[\"']",
pyproject_path.read_text(encoding="utf-8"),
)
__version__ = match.group(1) if match else "0.0.0"
else:
with pyproject_path.open("rb") as f:
__version__ = tomllib.load(f)["project"]["version"]
except (FileNotFoundError, IndexError, KeyError, TypeError, ValueError):
__version__ = "0.0.0"
match = re.match(r"^(\d+(?:\.\d+)*)(a|b|rc)(\d+)$", __version__)
if match:
release, prerelease, number = match.groups()
prerelease = {"a": "alpha", "b": "beta", "rc": "rc"}[prerelease]
__version__ = f"{release}-{prerelease}.{number}"

View File

@@ -1,16 +1,11 @@
import json
import os
import zoneinfo
from collections.abc import Callable
from typing import Any
import click
from astrbot.core.utils.auth_password import (
hash_dashboard_password,
hash_md5_dashboard_password,
validate_dashboard_password,
)
from ..utils import check_astrbot_root, get_astrbot_root
@@ -44,6 +39,8 @@ def _validate_dashboard_username(value: str) -> str:
def _validate_dashboard_password(value: str) -> str:
"""Validate Dashboard password"""
from astrbot.core.utils.auth_password import validate_dashboard_password
try:
validate_dashboard_password(value)
except ValueError as e:
@@ -89,6 +86,7 @@ def _load_config() -> dict[str, Any]:
raise click.ClickException(
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
os.environ["ASTRBOT_ROOT"] = str(root)
config_path = root / "data" / "cmd_config.json"
if not config_path.exists():
@@ -107,7 +105,8 @@ def _load_config() -> dict[str, Any]:
def _save_config(config: dict[str, Any]) -> None:
"""Save config file"""
config_path = get_astrbot_root() / "data" / "cmd_config.json"
root = get_astrbot_root()
config_path = root / "data" / "cmd_config.json"
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2),
@@ -139,6 +138,11 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
def _set_dashboard_password(config: dict[str, Any], raw_password: str) -> None:
"""Set dashboard password hashes and clear password migration flags."""
from astrbot.core.utils.auth_password import (
hash_dashboard_password,
hash_md5_dashboard_password,
)
_set_nested_item(
config,
"dashboard.pbkdf2_password",

View File

@@ -21,17 +21,16 @@ def _initialize_config_from_env(astrbot_root: Path) -> None:
async def initialize_astrbot(astrbot_root: Path) -> None:
"""Execute AstrBot initialization logic"""
"""Execute AstrBot initialization logic.
Args:
astrbot_root: Runtime root directory to initialize.
"""
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
if click.confirm(
f"Install AstrBot to this directory? {astrbot_root}",
default=True,
abort=True,
):
dot_astrbot.touch()
click.echo(f"Created {dot_astrbot}")
dot_astrbot.touch()
click.echo(f"Created {dot_astrbot}")
paths = {
"data": astrbot_root / "data",
@@ -41,8 +40,9 @@ async def initialize_astrbot(astrbot_root: Path) -> None:
}
for name, path in paths.items():
path_exists = path.exists()
path.mkdir(parents=True, exist_ok=True)
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
click.echo(f"{'Directory exists' if path_exists else 'Created'}: {path}")
_initialize_config_from_env(astrbot_root)
@@ -53,7 +53,25 @@ async def initialize_astrbot(astrbot_root: Path) -> None:
def init() -> None:
"""Initialize AstrBot"""
click.echo("Initializing AstrBot...")
astrbot_root = get_astrbot_root()
if os.environ.get("ASTRBOT_ROOT"):
astrbot_root = get_astrbot_root()
click.echo(f"Using ASTRBOT_ROOT: {astrbot_root}")
else:
user_root = (Path.home() / ".astrbot").resolve()
current_root = Path.cwd().resolve()
click.echo("Choose AstrBot runtime directory:")
click.echo(f"1. {user_root} (recommended)")
click.echo(f"2. Current directory: {current_root}")
choice = click.prompt(
"Select",
type=click.Choice(["1", "2"]),
default="1",
show_choices=False,
)
astrbot_root = user_root if choice == "1" else current_root
astrbot_root.mkdir(parents=True, exist_ok=True)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
@@ -65,6 +83,8 @@ def init() -> None:
raise click.ClickException(
"Cannot acquire lock file. Please check if another instance is running"
)
except click.Abort:
raise
except Exception as e:
raise click.ClickException(f"Initialization failed: {e!s}")

View File

@@ -1,3 +1,4 @@
import os
from pathlib import Path
import click
@@ -7,7 +8,14 @@ _BUNDLED_DIST = Path(__file__).parent.parent.parent / "dashboard" / "dist"
def check_astrbot_root(path: str | Path) -> bool:
"""Check if the path is an AstrBot root directory"""
"""Check whether a path is an AstrBot root directory.
Args:
path: Directory path to inspect.
Returns:
Whether the directory contains the AstrBot root marker.
"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
@@ -18,8 +26,24 @@ def check_astrbot_root(path: str | Path) -> bool:
def get_astrbot_root() -> Path:
"""Get the AstrBot root directory path"""
return Path.cwd()
"""Get the AstrBot root directory path.
Returns:
The explicit root, current local root, default user root, or current
directory when no initialized root exists.
"""
if root := os.environ.get("ASTRBOT_ROOT"):
return Path(root).expanduser().resolve()
current_root = Path.cwd().resolve()
if check_astrbot_root(current_root):
return current_root
user_root = (Path.home() / ".astrbot").resolve()
if check_astrbot_root(user_root):
return user_root
return current_root
async def check_dashboard(astrbot_root: Path) -> None:

View File

@@ -9,6 +9,7 @@ from datetime import timedelta
from pathlib import Path, PureWindowsPath
from typing import Any, Generic
import httpx
from tenacity import (
before_sleep_log,
retry,
@@ -102,12 +103,22 @@ except (ModuleNotFoundError, ImportError):
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
)
streamable_http_client_legacy = None
streamable_http_client = None
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
from mcp.client.streamable_http import (
streamablehttp_client as streamable_http_client_legacy,
)
except (ModuleNotFoundError, ImportError):
try:
from mcp.client.streamable_http import (
streamable_http_client as streamable_http_client,
)
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
)
def _prepare_config(config: dict) -> dict:
@@ -459,17 +470,38 @@ class MCPClient:
),
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5),
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
timeout_seconds = cfg.get("timeout", 30)
sse_read_timeout_seconds = cfg.get("sse_read_timeout", 60 * 5)
if streamable_http_client_legacy:
timeout = timedelta(seconds=timeout_seconds)
sse_read_timeout = timedelta(seconds=sse_read_timeout_seconds)
self._streams_context = streamable_http_client_legacy(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
elif streamable_http_client:
http_client = await self.exit_stack.enter_async_context(
httpx.AsyncClient(
headers=cfg.get("headers", {}),
timeout=httpx.Timeout(
timeout_seconds,
read=sse_read_timeout_seconds,
),
follow_redirects=True,
),
)
self._streams_context = streamable_http_client(
url=cfg["url"],
http_client=http_client,
terminate_on_close=cfg.get("terminate_on_close", True),
)
else:
raise RuntimeError(
"Streamable HTTP transport is not available in the installed MCP library version."
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context,
)

View File

@@ -224,6 +224,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
custom_compressor: ContextCompressor | None = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
request_max_retries: int | None = None,
tool_result_overflow_dir: str | None = None,
read_tool: FunctionTool | None = None,
**kwargs: T.Any,
@@ -237,6 +238,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
self.request_max_retries = request_max_retries
self.tool_result_overflow_dir = tool_result_overflow_dir
self.read_tool = read_tool
self._tool_result_token_counter = EstimateTokenCounter()
@@ -463,6 +465,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
"abort_signal": self._abort_signal,
"request_max_retries": self.request_max_retries,
}
if include_model:
# For primary provider we keep explicit model selection if provided.
@@ -1305,6 +1308,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
extra_user_content_parts=self.req.extra_user_content_parts,
# tool_choice="required",
abort_signal=self._abort_signal,
request_max_retries=self.request_max_retries,
)
if requery_resp:
llm_resp = requery_resp
@@ -1331,6 +1335,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
extra_user_content_parts=self.req.extra_user_content_parts,
# tool_choice="required",
abort_signal=self._abort_signal,
request_max_retries=self.request_max_retries,
)
if repair_resp:
llm_resp = repair_resp

View File

@@ -278,10 +278,11 @@ async def _apply_kb(
)
if not kb_result:
return
if req.system_prompt is not None:
req.system_prompt += (
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
)
req.extra_user_content_parts.append(
TextPart(
text=f"[Related Knowledge Base Results]:\n{kb_result}",
).mark_as_temp()
)
except Exception as exc: # noqa: BLE001
logger.error("Error occurred while retrieving knowledge base: %s", exc)
else:
@@ -1629,6 +1630,7 @@ async def build_main_agent(
enforce_max_turns=config.max_context_length,
tool_schema_mode=config.tool_schema_mode,
fallback_providers=fallback_providers,
request_max_retries=config.provider_settings.get("request_max_retries", 5),
tool_result_overflow_dir=(
get_astrbot_system_tmp_path()
if req.func_tool and req.func_tool.get_tool("astrbot_file_read_tool")

View File

@@ -129,6 +129,7 @@ DEFAULT_CONFIG = {
"enable": True,
"default_provider_id": "",
"fallback_chat_models": [],
"request_max_retries": 5,
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
@@ -2836,6 +2837,9 @@ CONFIG_METADATA_2 = {
"type": "list",
"items": {"type": "string"},
},
"request_max_retries": {
"type": "int",
},
"wake_prefix": {
"type": "string",
},
@@ -3195,6 +3199,11 @@ CONFIG_METADATA_3 = {
"_special": "select_providers",
"hint": "主聊天模型请求失败时,按顺序切换到这些模型。",
},
"provider_settings.request_max_retries": {
"description": "请求最大重试次数",
"type": "int",
"hint": "单次模型请求遇到可重试错误时的最大尝试次数。",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",

View File

@@ -106,6 +106,7 @@ class Provider(AbstractProvider):
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -120,6 +121,7 @@ class Provider(AbstractProvider):
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
request_max_retries: 可重试请求错误的最大尝试次数,包含首次请求。
kwargs: 其他参数
Notes:
@@ -142,6 +144,7 @@ class Provider(AbstractProvider):
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
@@ -155,6 +158,7 @@ class Provider(AbstractProvider):
tool_choice: 工具调用策略,`auto` 表示由模型自行决定,`required` 表示要求模型必须调用工具
contexts: 上下文,和 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
request_max_retries: 可重试请求错误的最大尝试次数,包含首次请求。
kwargs: 其他参数
Notes:

View File

@@ -27,6 +27,7 @@ from astrbot.core.utils.network_utils import (
)
from ..register import register_provider_adapter
from .request_retry import retry_provider_request, retry_provider_request_context
@register_provider_adapter(
@@ -353,7 +354,13 @@ class ProviderAnthropic(Provider):
logger.warning(f"未知的 tool_choice 值: {tool_choice},已回退为 'auto'")
return {"type": "auto"}
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
async def _query(
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
@@ -368,8 +375,12 @@ class ProviderAnthropic(Provider):
self._apply_thinking_config(payloads)
try:
completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
completion = await retry_provider_request(
"Anthropic",
lambda: self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
),
max_attempts=request_max_retries,
)
except httpx.RequestError as e:
proxy = self.provider_config.get("proxy", "")
@@ -438,6 +449,8 @@ class ProviderAnthropic(Provider):
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> AsyncGenerator[LLMResponse, None]:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
@@ -461,8 +474,10 @@ class ProviderAnthropic(Provider):
payloads["max_tokens"] = 65536
self._apply_thinking_config(payloads)
async with self.client.messages.stream(
**payloads, extra_body=extra_body
async with retry_provider_request_context(
"Anthropic",
lambda: self.client.messages.stream(**payloads, extra_body=extra_body),
max_attempts=request_max_retries,
) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
@@ -601,6 +616,7 @@ class ProviderAnthropic(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -650,7 +666,11 @@ class ProviderAnthropic(Provider):
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
llm_response = await self._query(
payloads,
func_tool,
request_max_retries=request_max_retries,
)
except Exception as e:
raise e
@@ -669,6 +689,7 @@ class ProviderAnthropic(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto",
request_max_retries: int | None = None,
**kwargs,
):
if contexts is None:
@@ -715,7 +736,11 @@ class ProviderAnthropic(Provider):
else system_prompt
)
async for llm_response in self._query_stream(payloads, func_tool):
async for llm_response in self._query_stream(
payloads,
func_tool,
request_max_retries=request_max_retries,
):
yield llm_response
def _detect_image_mime_type(self, data: bytes) -> str:
@@ -827,7 +852,10 @@ class ProviderAnthropic(Provider):
async def get_models(self) -> list[str]:
models_str = []
models = await self.client.models.list()
models = await retry_provider_request(
"Anthropic",
lambda: self.client.models.list(),
)
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)

View File

@@ -26,6 +26,7 @@ from astrbot.core.utils.media_utils import (
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
from ..register import register_provider_adapter
from .request_retry import retry_provider_request
class SuppressNonTextPartsWarning(logging.Filter):
@@ -577,7 +578,13 @@ class ProviderGoogleGenAI(Provider):
)
return chain_result
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
async def _query(
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> LLMResponse:
"""非流式请求 Gemini API"""
system_instruction = next(
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
@@ -604,10 +611,14 @@ class ProviderGoogleGenAI(Provider):
modalities,
temperature,
)
result = await self.client.models.generate_content(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
result = await retry_provider_request(
"Gemini",
lambda: self.client.models.generate_content(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
),
max_attempts=request_max_retries,
)
logger.debug(f"genai result: {result}")
@@ -672,6 +683,8 @@ class ProviderGoogleGenAI(Provider):
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式请求 Gemini API"""
system_instruction = next(
@@ -690,10 +703,14 @@ class ProviderGoogleGenAI(Provider):
payloads.get("tool_choice", "auto"),
system_instruction,
)
result = await self.client.models.generate_content_stream(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
result = await retry_provider_request(
"Gemini",
lambda: self.client.models.generate_content_stream(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
),
max_attempts=request_max_retries,
)
break
except APIError as e:
@@ -809,6 +826,7 @@ class ProviderGoogleGenAI(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -850,7 +868,11 @@ class ProviderGoogleGenAI(Provider):
for _ in range(retry):
try:
return await self._query(payloads, func_tool)
return await self._query(
payloads,
func_tool,
request_max_retries=request_max_retries,
)
except APIError as e:
if await self._handle_api_error(e, keys):
continue
@@ -871,6 +893,7 @@ class ProviderGoogleGenAI(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
@@ -912,7 +935,11 @@ class ProviderGoogleGenAI(Provider):
for _ in range(retry):
try:
async for response in self._query_stream(payloads, func_tool):
async for response in self._query_stream(
payloads,
func_tool,
request_max_retries=request_max_retries,
):
yield response
break
except APIError as e:
@@ -922,7 +949,10 @@ class ProviderGoogleGenAI(Provider):
async def get_models(self):
try:
models = await self.client.models.list()
models = await retry_provider_request(
"Gemini",
lambda: self.client.models.list(),
)
return [
m.name.replace("models/", "")
for m in models

View File

@@ -41,6 +41,7 @@ from astrbot.core.utils.network_utils import (
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from ..register import register_provider_adapter
from .request_retry import retry_provider_request
@register_provider_adapter(
@@ -420,7 +421,10 @@ class ProviderOpenAIOfficial(Provider):
async def get_models(self):
try:
models_str = []
models = await self.client.models.list()
models = await retry_provider_request(
"OpenAI",
lambda: self.client.models.list(),
)
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
@@ -465,7 +469,13 @@ class ProviderOpenAIOfficial(Provider):
payloads["messages"] = cleaned
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
async def _query(
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
@@ -496,10 +506,14 @@ class ProviderOpenAIOfficial(Provider):
self._sanitize_assistant_messages(payloads)
completion = await self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
completion = await retry_provider_request(
"OpenAI",
lambda: self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
),
max_attempts=request_max_retries,
)
if not isinstance(completion, ChatCompletion):
@@ -517,6 +531,8 @@ class ProviderOpenAIOfficial(Provider):
self,
payloads: dict,
tools: ToolSet | None,
*,
request_max_retries: int | None = None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API逐步返回结果"""
if tools:
@@ -548,11 +564,15 @@ class ProviderOpenAIOfficial(Provider):
self._sanitize_assistant_messages(payloads)
stream = await self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
stream = await retry_provider_request(
"OpenAI",
lambda: self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
),
max_attempts=request_max_retries,
)
llm_response = LLMResponse("assistant", is_chunk=True)
@@ -1104,6 +1124,7 @@ class ProviderOpenAIOfficial(Provider):
model=None,
extra_user_content_parts=None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
@@ -1131,7 +1152,11 @@ class ProviderOpenAIOfficial(Provider):
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
llm_response = await self._query(
payloads,
func_tool,
request_max_retries=request_max_retries,
)
break
except Exception as e:
last_exception = e
@@ -1176,6 +1201,7 @@ class ProviderOpenAIOfficial(Provider):
tool_calls_result=None,
model=None,
tool_choice: Literal["auto", "required"] = "auto",
request_max_retries: int | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
@@ -1202,7 +1228,11 @@ class ProviderOpenAIOfficial(Provider):
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
async for response in self._query_stream(payloads, func_tool):
async for response in self._query_stream(
payloads,
func_tool,
request_max_retries=request_max_retries,
):
yield response
break
except Exception as e:

View File

@@ -0,0 +1,163 @@
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import TypeVar
from tenacity import (
AsyncRetrying,
RetryCallState,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
from astrbot import logger
from astrbot.core.utils.config_number import coerce_int_config
from astrbot.core.utils.network_utils import is_connection_error
T = TypeVar("T")
REQUEST_RETRY_ATTEMPTS = 5 # default value
REQUEST_RETRY_WAIT_MIN_S = 0.2
REQUEST_RETRY_WAIT_MAX_S = 30
REQUEST_RETRY_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504, 529}
def _get_status_code(error: BaseException) -> int | None:
for attr in ("status_code", "status", "code"):
value = getattr(error, attr, None)
if isinstance(value, int):
return value
response = getattr(error, "response", None)
if response is not None:
status_code = getattr(response, "status_code", None)
if isinstance(status_code, int):
return status_code
return None
def _is_retryable_provider_request_error(
error: BaseException,
*,
retry_rate_limits: bool,
) -> bool:
if is_connection_error(error):
return True
error_type_name = type(error).__name__
if error_type_name in {"APIConnectionError", "APITimeoutError"}:
return True
status_code = _get_status_code(error)
if status_code is None:
return False
if status_code == 429 and not retry_rate_limits:
return False
return status_code in REQUEST_RETRY_STATUS_CODES or 500 <= status_code <= 599
def _log_retry(
provider_label: str,
retry_state: RetryCallState,
max_attempts: int,
) -> None:
error = retry_state.outcome.exception() if retry_state.outcome else None
logger.warning(
f"[{provider_label}] Request failed with retryable error; "
f"retrying ({retry_state.attempt_number + 1}/{max_attempts}): "
f"{error}"
)
def _build_retrying(
provider_label: str,
*,
retry_rate_limits: bool,
max_attempts: int | None = None,
) -> AsyncRetrying:
max_attempts = coerce_int_config(
max_attempts if max_attempts is not None else REQUEST_RETRY_ATTEMPTS,
default=REQUEST_RETRY_ATTEMPTS,
min_value=1,
field_name="request_max_retries",
source=provider_label,
)
return AsyncRetrying(
retry=retry_if_exception(
lambda error: _is_retryable_provider_request_error(
error,
retry_rate_limits=retry_rate_limits,
)
),
stop=stop_after_attempt(max_attempts),
wait=wait_exponential(
multiplier=1,
min=REQUEST_RETRY_WAIT_MIN_S,
max=REQUEST_RETRY_WAIT_MAX_S,
),
before_sleep=lambda retry_state: _log_retry(
provider_label,
retry_state,
max_attempts,
),
reraise=True,
)
async def retry_provider_request(
provider_label: str,
request_factory: Callable[[], Awaitable[T]],
*,
retry_rate_limits: bool = True,
max_attempts: int | None = None,
) -> T:
retrying = _build_retrying(
provider_label,
retry_rate_limits=retry_rate_limits,
max_attempts=max_attempts,
)
async for attempt in retrying:
with attempt:
return await request_factory()
raise RuntimeError("Provider request retry loop exited unexpectedly.")
@asynccontextmanager
async def retry_provider_request_context(
provider_label: str,
context_manager_factory: Callable[[], AbstractAsyncContextManager[T]],
*,
retry_rate_limits: bool = True,
max_attempts: int | None = None,
) -> AsyncIterator[T]:
manager: AbstractAsyncContextManager[T] | None = None
async def _enter_context() -> T:
nonlocal manager
manager = context_manager_factory()
return await manager.__aenter__()
value = await retry_provider_request(
provider_label,
_enter_context,
retry_rate_limits=retry_rate_limits,
max_attempts=max_attempts,
)
if manager is None:
raise RuntimeError("Provider request context was not created.")
try:
yield value
except BaseException as error:
if await manager.__aexit__(type(error), error, error.__traceback__):
return
raise
else:
await manager.__aexit__(None, None, None)

View File

@@ -183,8 +183,22 @@ async def download_file(
path: str,
show_progress: bool = False,
progress_callback=None,
allow_insecure_ssl_fallback: bool = True,
) -> None:
"""从指定 url 下载文件到指定路径 path"""
"""Download a remote file to a local path.
Args:
url: Remote URL to download.
path: Local destination path.
show_progress: Whether to print progress to stdout.
progress_callback: Optional callback for progress payloads.
allow_insecure_ssl_fallback: Whether certificate failures may retry with
TLS certificate verification disabled.
Returns:
None.
"""
try:
ssl_context = ssl.create_default_context(
cafile=certifi.where(),
@@ -259,6 +273,8 @@ async def download_file(
},
)
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
if not allow_insecure_ssl_fallback:
raise
# 关闭SSL验证仅在证书验证失败时作为fallback
logger.warning(
f"SSL certificate verification failed for {_safe_url_for_log(url)}. "
@@ -355,10 +371,22 @@ def get_local_ip_addresses():
return network_ips
def _read_dashboard_dist_version(dist_dir: str | Path) -> str | None:
def get_dashboard_dist_version(dist_dir: str | Path) -> str | None:
"""Read the WebUI version from a dashboard dist directory.
Args:
dist_dir: Dashboard dist directory path.
Returns:
The version string from assets/version, or None when unavailable.
"""
version_file = Path(dist_dir) / "assets" / "version"
if version_file.exists():
return version_file.read_text(encoding="utf-8").strip()
try:
if version_file.exists():
return version_file.read_text(encoding="utf-8").strip()
except (OSError, UnicodeDecodeError) as exc:
logger.warning("Failed to read WebUI version from %s: %s", version_file, exc)
return None
@@ -380,42 +408,106 @@ def _normalize_dashboard_version(version: str) -> str:
return version
def should_use_bundled_dashboard_dist(
user_dist: str | Path, current_version: str
def is_dashboard_version_compatible(
dashboard_version: str | None, current_version: str
) -> bool:
user_version = _read_dashboard_dist_version(user_dist)
bundled_dist = get_bundled_dashboard_dist_path()
if user_version is None or not bundled_dist.exists():
"""Check whether a WebUI version matches the current core version.
Args:
dashboard_version: Version read from the WebUI assets/version file.
current_version: Current AstrBot core version.
Returns:
True when both versions are valid SemVer values and compare equal.
"""
if dashboard_version is None:
return False
try:
return (
VersionComparator.compare_version(
_normalize_dashboard_version(dashboard_version),
_normalize_dashboard_version(current_version),
_normalize_dashboard_version(user_version),
)
> 0
== 0
)
except (TypeError, ValueError):
return False
def is_dashboard_dist_compatible(dist_dir: str | Path, current_version: str) -> bool:
"""Check whether a WebUI dist is complete and matches the core version.
Args:
dist_dir: Dashboard dist directory path.
current_version: Current AstrBot core version.
Returns:
True when the dist has an index file and a compatible assets/version.
"""
dist_path = Path(dist_dir)
return (dist_path / "index.html").is_file() and is_dashboard_version_compatible(
get_dashboard_dist_version(dist_path),
current_version,
)
def should_use_bundled_dashboard_dist(
user_dist: str | Path, current_version: str
) -> bool:
"""Decide whether bundled WebUI should replace a user data dist.
Args:
user_dist: Runtime dashboard dist directory under data/.
current_version: Current AstrBot core version.
Returns:
True when user_dist exists but is missing or mismatched against the
current core version, and bundled WebUI matches the current core version.
"""
user_dist = Path(user_dist)
user_version = get_dashboard_dist_version(user_dist)
bundled_dist = get_bundled_dashboard_dist_path()
if not user_dist.exists() or not is_dashboard_dist_compatible(
bundled_dist,
current_version,
):
return False
if user_version is None or not (user_dist / "index.html").is_file():
return True
try:
return not is_dashboard_version_compatible(user_version, current_version)
except (TypeError, ValueError):
return False
async def get_dashboard_version():
"""Return the effective WebUI version for the current runtime.
Returns:
The matching data/dist version, matching bundled version, or the raw
data/dist version when no compatible bundled WebUI is available.
"""
from astrbot.core.config.default import VERSION
# First check user data directory (manually updated / downloaded dashboard).
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
if os.path.exists(dist_dir):
from astrbot.core.config.default import VERSION
user_version = get_dashboard_dist_version(dist_dir)
if is_dashboard_dist_compatible(dist_dir, VERSION):
return user_version
if should_use_bundled_dashboard_dist(dist_dir, VERSION):
bundled_version = _read_dashboard_dist_version(
get_bundled_dashboard_dist_path()
)
if bundled_version is not None:
return bundled_version
return _read_dashboard_dist_version(dist_dir)
bundled = get_bundled_dashboard_dist_path()
if is_dashboard_dist_compatible(bundled, VERSION):
return get_dashboard_dist_version(bundled)
return user_version
bundled = get_bundled_dashboard_dist_path()
if bundled.exists():
return _read_dashboard_dist_version(bundled)
if is_dashboard_dist_compatible(bundled, VERSION):
return get_dashboard_dist_version(bundled)
return None
@@ -427,6 +519,7 @@ async def download_dashboard(
proxy: str | None = None,
progress_callback=None,
extract: bool = True,
allow_insecure_ssl_fallback: bool = True,
) -> None:
"""Download dashboard assets and optionally extract them.
@@ -438,6 +531,8 @@ async def download_dashboard(
proxy: Optional download proxy prefix.
progress_callback: Optional callback for download progress payloads.
extract: Whether to extract the archive after download.
allow_insecure_ssl_fallback: Whether certificate failures may retry with
TLS certificate verification disabled.
Returns:
None.
@@ -460,6 +555,7 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError(
@@ -491,6 +587,7 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError(
@@ -506,6 +603,7 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError("Downloaded dashboard package is not a valid ZIP file")

View File

@@ -22,7 +22,9 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import (
get_bundled_dashboard_dist_path,
get_dashboard_dist_version,
get_local_ip_addresses,
is_dashboard_dist_compatible,
should_use_bundled_dashboard_dist,
)
from astrbot.dashboard.asgi_runtime import (
@@ -182,21 +184,32 @@ class AstrBotDashboard:
# Path priority:
# 1. Explicit webui_dir argument
# 2. data/dist/ (user-installed / manually updated dashboard)
# 3. astrbot/dashboard/dist/ (bundled with the wheel)
# 2. data/dist/ when it matches the core version
# 3. astrbot/dashboard/dist/ when it matches the core version
if webui_dir and os.path.exists(webui_dir):
self.data_path = os.path.abspath(webui_dir)
else:
user_dist = os.path.join(get_astrbot_data_path(), "dist")
bundled_dist = get_bundled_dashboard_dist_path()
if os.path.exists(user_dist) and not should_use_bundled_dashboard_dist(
user_version = get_dashboard_dist_version(user_dist)
if os.path.exists(user_dist) and is_dashboard_dist_compatible(
user_dist,
VERSION,
):
self.data_path = os.path.abspath(user_dist)
elif bundled_dist.exists():
elif should_use_bundled_dashboard_dist(
user_dist,
VERSION,
) or is_dashboard_dist_compatible(bundled_dist, VERSION):
self.data_path = str(bundled_dist)
logger.info("Using bundled dashboard dist: %s", self.data_path)
elif os.path.exists(user_dist):
logger.warning(
"Ignoring data/dist because WebUI version mismatches core: %s, expected v%s.",
user_version,
VERSION,
)
self.data_path = None
else:
# Fall back to expected user path (will fail gracefully later)
self.data_path = os.path.abspath(user_dist)
@@ -545,7 +558,7 @@ class AstrBotDashboard:
raise Exception(f"端口 {port} 已被占用")
if (Path(self.data_path) / "index.html").is_file():
if self.data_path and (Path(self.data_path) / "index.html").is_file():
webui_status = "WebUI is ready"
else:
webui_status = (

View File

@@ -0,0 +1,54 @@
- [更新日志(简体中文)](#chinese)
- [Changelog(English)](#english)
<a id="chinese"></a>
## What's Changed
### 重点更新
- 为 OpenAI、Gemini、Anthropic 等模型请求加入可配置的重试机制,并新增请求最大重试次数配置,提升临时网络错误与 5xx 服务端错误下的稳定性。([#8893](https://github.com/AstrBotDevs/AstrBot/pull/8893))
- 新增托管 Core 包下载能力,并加强 Core 与 Dashboard 包下载归档校验。([#8888](https://github.com/AstrBotDevs/AstrBot/pull/8888))
- 支持在请求中加载 workspace skills并加固 workspace skill 发现流程。([#8884](https://github.com/AstrBotDevs/AstrBot/pull/8884))
### 修复
- 修复 OpenAPI 文件上传能力,恢复 `/api/v1/file` OpenAPI 暴露、文件范围 API Key 与相关文档/客户端产物。
- 修复新版 MCP 中 Streamable HTTP client 重命名导致的兼容问题,并保持 `mcp` 依赖小于 2。
- 加固人格工具边界,确保人格限定的工具范围在主 Agent 请求中正确生效。([#8786](https://github.com/AstrBotDevs/AstrBot/pull/8786))
- 加强 Future Task 所有者校验,避免越权访问定时任务。([#8881](https://github.com/AstrBotDevs/AstrBot/pull/8881))
- 在受限本地文件系统工具中拒绝 hardlink 文件,避免通过工作区 hardlink 别名读写允许目录外的文件。
### 发布流程
- 新增 `scripts/prepare_release.py`,统一 release 分支、版本号、changelog 与校验流程。([#8891](https://github.com/AstrBotDevs/AstrBot/pull/8891))
### 文档
- 明确 OpenAPI Chat 中 `username` 字段的身份含义。([#8880](https://github.com/AstrBotDevs/AstrBot/pull/8880))
<a id="english"></a>
## What's Changed (EN)
### Highlights
- Added configurable retry handling for OpenAI, Gemini, Anthropic, and related provider requests, including a maximum request retry setting to improve stability for transient network failures and 5xx server errors. ([#8893](https://github.com/AstrBotDevs/AstrBot/pull/8893))
- Added hosted Core package downloads and strengthened archive validation for hosted Core and Dashboard packages. ([#8888](https://github.com/AstrBotDevs/AstrBot/pull/8888))
- Added workspace skills support in requests and hardened workspace skill discovery. ([#8884](https://github.com/AstrBotDevs/AstrBot/pull/8884))
### Bug Fixes
- Restored OpenAPI file uploads by exposing `/api/v1/file`, enabling file-scoped API keys, and regenerating docs/client artifacts.
- Fixed compatibility with the renamed MCP Streamable HTTP client while keeping the `mcp` dependency below 2.
- Hardened persona tool boundaries so persona-restricted tool scopes are enforced correctly in main Agent requests. ([#8786](https://github.com/AstrBotDevs/AstrBot/pull/8786))
- Enforced Future Task owner checks to prevent unauthorized scheduled-task access. ([#8881](https://github.com/AstrBotDevs/AstrBot/pull/8881))
- Rejected hardlinked files in restricted local filesystem tools to prevent workspace hardlink aliases from reading or overwriting files outside allowed directories.
### Release Process
- Added `scripts/prepare_release.py` to standardize release branches, version bumps, changelog generation, and validation. ([#8891](https://github.com/AstrBotDevs/AstrBot/pull/8891))
### Docs
- Clarified the identity semantics of the `username` field in OpenAPI Chat. ([#8880](https://github.com/AstrBotDevs/AstrBot/pull/8880))

View File

@@ -48,6 +48,55 @@ function attachAxiosHeaders(config: InternalAxiosRequestConfig) {
}
function normalizeAxiosError(error: AxiosError) {
if (error.response?.status === 401) {
let requestPath = '';
try {
const url = error.config?.url || '';
const baseURL = error.config?.baseURL;
const resolvedUrl =
url && baseURL && !/^([a-z][a-z\d+\-.]*:)?\/\//i.test(url)
? `${baseURL.replace(/\/+$/, '')}/${url.replace(/^\/+/, '')}`
: url;
const requestUrl = new URL(resolvedUrl || '/', window.location.origin);
if (requestUrl.origin === window.location.origin) {
requestPath = requestUrl.pathname;
}
} catch {
requestPath = '';
}
const isAuthChallenge =
[
'/api/auth/login',
'/api/auth/setup',
'/api/auth/setup-status',
'/api/v1/auth/login',
'/api/v1/auth/setup',
'/api/v1/auth/setup-status',
].includes(requestPath) ||
Boolean(
(
error.response.data as
| { data?: { totp_required?: boolean } }
| undefined
)?.data?.totp_required,
);
if (requestPath.startsWith('/api/') && !isAuthChallenge) {
[
'user',
'token',
'change_pwd_hint',
'md5_pwd_hint',
'password_upgrade_required',
].forEach((key) => localStorage.removeItem(key));
if (!window.location.hash.startsWith('#/auth/login')) {
window.location.hash = '/auth/login';
}
}
}
if (error.response?.status === 429) {
const data = error.response.data as { message?: string } | undefined;
if (data?.message) {

View File

@@ -45,6 +45,10 @@
"description": "Fallback chat model IDs",
"hint": "When the primary chat model request fails, fallback to these chat models in order."
},
"request_max_retries": {
"description": "Request Max Retries",
"hint": "Maximum attempts for a single model request when retryable errors occur."
},
"default_image_caption_provider_id": {
"description": "Default Image Caption Model",
"hint": "Leave empty to disable; useful for non-multimodal models"

View File

@@ -45,6 +45,10 @@
"description": "Резервные модели чата (ID)",
"hint": "Если текущая модель недоступна, запрос будет перенаправлен на эти модели по порядку."
},
"request_max_retries": {
"description": "Максимум повторов запроса",
"hint": "Максимальное число попыток для одного запроса модели при повторяемых ошибках."
},
"default_image_caption_provider_id": {
"description": "Модель описания изображений",
"hint": "Оставьте пустым для отключения; полезно для моделей без поддержки мультимодальности"

View File

@@ -45,6 +45,10 @@
"description": "回退对话模型列表",
"hint": "主对话模型请求失败时,按顺序切换到这些对话模型。"
},
"request_max_retries": {
"description": "请求最大重试次数",
"hint": "单次模型请求遇到可重试错误时的最大尝试次数。"
},
"default_image_caption_provider_id": {
"description": "默认图片转述模型",
"hint": "留空代表不使用,可用于非多模态模型"

91
main.py
View File

@@ -2,6 +2,7 @@ import argparse
import asyncio
import mimetypes
import os
import shutil
import sys
from pathlib import Path
@@ -46,7 +47,10 @@ from astrbot.core.utils.astrbot_path import ( # noqa: E402
from astrbot.core.utils.io import ( # noqa: E402
download_dashboard,
get_bundled_dashboard_dist_path,
get_dashboard_version,
get_dashboard_dist_version,
is_dashboard_dist_compatible,
is_dashboard_version_compatible,
remove_dir,
should_use_bundled_dashboard_dist,
)
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime # noqa: E402
@@ -91,7 +95,15 @@ def check_env() -> None:
async def check_dashboard_files(webui_dir: str | None = None):
"""下载管理面板文件"""
"""Resolve and repair dashboard static files for startup.
Args:
webui_dir: Optional explicit WebUI directory path from CLI.
Returns:
The directory path to serve, or None when no usable WebUI can be prepared.
"""
# 指定webui目录
if webui_dir:
if os.path.exists(webui_dir):
@@ -99,40 +111,81 @@ async def check_dashboard_files(webui_dir: str | None = None):
return webui_dir
logger.warning("WebUI directory not found: %s. Using default.", webui_dir)
data_dist_path = os.path.join(get_astrbot_data_path(), "dist")
if os.path.exists(data_dist_path):
v = await get_dashboard_version()
data_dist_path = Path(get_astrbot_data_path()) / "dist"
bundled_dist = get_bundled_dashboard_dist_path()
if data_dist_path.exists():
v = get_dashboard_dist_version(data_dist_path)
if is_dashboard_dist_compatible(data_dist_path, VERSION):
logger.info("WebUI is up to date.")
return str(data_dist_path)
if should_use_bundled_dashboard_dist(data_dist_path, VERSION):
bundled_dist = get_bundled_dashboard_dist_path()
logger.info(
"Using bundled WebUI because data/dist is older than core version v%s.",
"Replacing data/dist with bundled WebUI because its version does not match core version v%s.",
VERSION,
)
return str(bundled_dist)
if v is not None:
# 存在文件
if v == f"v{VERSION}":
logger.info("WebUI is up to date.")
else:
try:
remove_dir(str(data_dist_path))
shutil.copytree(bundled_dist, data_dist_path)
return str(data_dist_path)
except Exception as e:
logger.warning(
"WebUI version mismatch: %s, expected v%s.",
v,
VERSION,
"Failed to replace data/dist with bundled WebUI: %s. Using bundled WebUI directly.",
e,
)
return data_dist_path
return str(bundled_dist)
if is_dashboard_version_compatible(v, VERSION):
logger.warning(
"WebUI files are incomplete for v%s. Re-downloading WebUI.",
VERSION,
)
elif v is not None:
logger.warning(
"WebUI version mismatch: %s, expected v%s. Re-downloading WebUI.",
v,
VERSION,
)
else:
logger.warning(
"WebUI version file is missing. Re-downloading WebUI v%s.",
VERSION,
)
try:
await download_dashboard(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成。")
return str(data_dist_path)
if is_dashboard_dist_compatible(bundled_dist, VERSION):
logger.info(
"Using bundled WebUI v%s.", get_dashboard_dist_version(bundled_dist)
)
return str(bundled_dist)
logger.info(
"Downloading WebUI. If it fails, download dist.zip from https://github.com/AstrBotDevs/AstrBot/releases/latest and extract dist to data/.",
)
try:
await download_dashboard(version=f"v{VERSION}", latest=False)
await download_dashboard(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成。")
return data_dist_path
return str(data_dist_path)
async def main_async(webui_dir_arg: str | None) -> None:

View File

@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.26.0-beta.8"
version = "4.26.0-beta.9"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
license = { text = "AGPL-3.0-or-later" }
@@ -29,7 +29,7 @@ dependencies = [
"google-genai>=1.56.0",
"httpx[socks]>=0.28.1",
"lark-oapi>=1.4.15",
"mcp>=1.8.0",
"mcp>=1.8.0,<2",
"openai>=1.78.0",
"ormsgpack>=1.9.1",
"pillow>=11.2.1",

View File

@@ -18,7 +18,7 @@ filelock>=3.18.0
google-genai>=1.56.0
httpx[socks]>=0.28.1
lark-oapi>=1.4.15
mcp>=1.8.0
mcp>=1.8.0,<2
openai>=1.78.0
ormsgpack>=1.9.1
pillow>=11.2.1

View File

@@ -1,9 +1,12 @@
import builtins
from types import SimpleNamespace
import httpx
import pytest
import astrbot.core.provider.sources.anthropic_source as anthropic_source
import astrbot.core.provider.sources.kimi_code_source as kimi_code_source
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse
@@ -171,6 +174,36 @@ def test_create_http_client_falls_back_to_global_httpx_module(monkeypatch):
assert captured["httpx_module"] is anthropic_source.httpx
@pytest.mark.asyncio
async def test_anthropic_get_models_retries_transient_request_error(monkeypatch):
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0)
class FakeModels:
def __init__(self):
self.calls = 0
async def list(self):
self.calls += 1
if self.calls == 1:
raise httpx.ConnectError("temporary connection failure")
return SimpleNamespace(
data=[
SimpleNamespace(id="claude-b"),
SimpleNamespace(id="claude-a"),
]
)
models = FakeModels()
provider = anthropic_source.ProviderAnthropic.__new__(
anthropic_source.ProviderAnthropic
)
provider.client = SimpleNamespace(models=models)
assert await provider.get_models() == ["claude-a", "claude-b"]
assert models.calls == 2
@pytest.mark.asyncio
async def test_text_chat_wraps_string_system_prompt_as_list(monkeypatch):
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)
@@ -187,7 +220,7 @@ async def test_text_chat_wraps_string_system_prompt_as_list(monkeypatch):
captured_payloads: dict[str, object] = {}
async def fake_query(payloads, tools):
async def fake_query(payloads, tools, *, request_max_retries=None):
captured_payloads.update(payloads)
return LLMResponse(role="assistant", completion_text="ok")
@@ -214,7 +247,7 @@ async def test_text_chat_passes_through_list_system_prompt(monkeypatch):
captured_payloads: dict[str, object] = {}
async def fake_query(payloads, tools):
async def fake_query(payloads, tools, *, request_max_retries=None):
captured_payloads.update(payloads)
return LLMResponse(role="assistant", completion_text="ok")

View File

@@ -1,6 +1,11 @@
import json
import os
import subprocess
import sys
from pathlib import Path
import pytest
from click.testing import CliRunner
from astrbot.cli.commands import cmd_init
from astrbot.core.utils.auth_password import verify_dashboard_password
@@ -14,6 +19,7 @@ async def test_init_without_initial_password_env_does_not_create_config(
async def fake_check_dashboard(_data_path):
return None
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.delenv(cmd_init.DASHBOARD_INITIAL_PASSWORD_ENV, raising=False)
monkeypatch.setattr(cmd_init, "check_dashboard", fake_check_dashboard)
(tmp_path / ".astrbot").touch()
@@ -32,6 +38,7 @@ async def test_init_uses_initial_password_env_to_create_config(
return None
initial_password = "AstrBotInitialPassword123"
monkeypatch.setenv("ASTRBOT_ROOT", str(tmp_path))
monkeypatch.setenv(cmd_init.DASHBOARD_INITIAL_PASSWORD_ENV, initial_password)
monkeypatch.setattr(cmd_init, "check_dashboard", fake_check_dashboard)
(tmp_path / ".astrbot").touch()
@@ -52,3 +59,71 @@ async def test_init_uses_initial_password_env_to_create_config(
)
assert dashboard_config["password_change_required"] is True
assert dashboard_config["password_storage_upgraded"] is True
def test_cli_main_import_does_not_create_cwd_data(tmp_path):
repo_root = Path(__file__).resolve().parents[1]
env = os.environ.copy()
env.pop("ASTRBOT_ROOT", None)
env["HOME"] = str(tmp_path / "home")
env["PYTHONPATH"] = (
str(repo_root)
if not env.get("PYTHONPATH")
else f"{repo_root}{os.pathsep}{env['PYTHONPATH']}"
)
result = subprocess.run(
[sys.executable, "-c", "import astrbot.cli.__main__"],
cwd=tmp_path,
env=env,
capture_output=True,
text=True,
check=False,
)
assert result.returncode == 0, result.stderr
assert not (tmp_path / "data").exists()
def test_init_defaults_to_user_runtime(monkeypatch, tmp_path):
async def fake_check_dashboard(_data_path):
return None
home = tmp_path / "home"
workdir = tmp_path / "workdir"
home.mkdir()
workdir.mkdir()
monkeypatch.setenv("HOME", str(home))
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(workdir)
monkeypatch.setattr(cmd_init, "check_dashboard", fake_check_dashboard)
result = CliRunner().invoke(cmd_init.init, input="\n", env={"ASTRBOT_ROOT": ""})
assert result.exit_code == 0, result.output
assert (home / ".astrbot" / ".astrbot").exists()
assert (home / ".astrbot" / "data" / "config").is_dir()
assert not (workdir / "data").exists()
def test_init_can_install_to_current_directory(monkeypatch, tmp_path):
async def fake_check_dashboard(_data_path):
return None
home = tmp_path / "home"
workdir = tmp_path / "workdir"
home.mkdir()
workdir.mkdir()
monkeypatch.setenv("HOME", str(home))
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(workdir)
monkeypatch.setattr(cmd_init, "check_dashboard", fake_check_dashboard)
result = CliRunner().invoke(cmd_init.init, input="2\n", env={"ASTRBOT_ROOT": ""})
assert result.exit_code == 0, result.output
assert (workdir / ".astrbot").exists()
assert (workdir / "data" / "config").is_dir()
assert not (home / ".astrbot").exists()

View File

@@ -30,6 +30,7 @@ def _read_config(config_path):
def test_password_command_changes_dashboard_password(monkeypatch, tmp_path):
config_path = _write_config(tmp_path)
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(tmp_path)
runner = CliRunner()
@@ -55,6 +56,7 @@ def test_password_command_changes_dashboard_password(monkeypatch, tmp_path):
def test_password_command_can_update_dashboard_username(monkeypatch, tmp_path):
config_path = _write_config(tmp_path)
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(tmp_path)
runner = CliRunner()
@@ -71,6 +73,7 @@ def test_password_command_can_update_dashboard_username(monkeypatch, tmp_path):
def test_conf_set_dashboard_password_updates_password_state(monkeypatch, tmp_path):
config_path = _write_config(tmp_path)
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(tmp_path)
runner = CliRunner()

View File

@@ -273,6 +273,7 @@ def test_dashboard_uses_bundled_dist_when_data_dist_is_stale(
bundled_dist = tmp_path / "bundled-dist"
user_dist.mkdir(parents=True)
bundled_dist.mkdir()
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
monkeypatch.setattr(
"astrbot.dashboard.server.get_astrbot_data_path",
@@ -293,6 +294,32 @@ def test_dashboard_uses_bundled_dist_when_data_dist_is_stale(
assert server.data_path == str(bundled_dist)
def test_dashboard_ignores_mismatched_data_dist_without_bundled(
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
tmp_path,
):
data_dir = tmp_path / "data"
user_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
monkeypatch.setattr(
"astrbot.dashboard.server.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.dashboard.server.get_bundled_dashboard_dist_path",
lambda: bundled_dist,
)
shutdown_event = asyncio.Event()
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
assert server.data_path is None
async def _set_dashboard_password_change_required(
core_lifecycle_td: AstrBotCoreLifecycle,
required: bool,

View File

@@ -1,6 +1,10 @@
from types import SimpleNamespace
import httpx
import pytest
from astrbot.core.exceptions import EmptyModelOutputError
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
@@ -27,3 +31,35 @@ def test_gemini_reasoning_only_output_is_allowed():
response_id="resp_reasoning",
finish_reason="STOP",
)
@pytest.mark.asyncio
async def test_gemini_get_models_retries_transient_request_error(monkeypatch):
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0)
class FakeModels:
def __init__(self):
self.calls = 0
async def list(self):
self.calls += 1
if self.calls == 1:
raise httpx.ConnectError("temporary connection failure")
return [
SimpleNamespace(
name="models/gemini-a",
supported_actions=["generateContent"],
),
SimpleNamespace(
name="models/gemini-b",
supported_actions=["embedContent"],
),
]
models = FakeModels()
provider = ProviderGoogleGenAI.__new__(ProviderGoogleGenAI)
provider.client = SimpleNamespace(models=models)
assert await provider.get_models() == ["gemini-a"]
assert models.calls == 2

View File

@@ -9,7 +9,7 @@ from unittest import mock
import pytest
from astrbot.core.utils.io import should_use_bundled_dashboard_dist
from astrbot.core.utils.io import get_dashboard_version, should_use_bundled_dashboard_dist
from main import (
DASHBOARD_RESET_PASSWORD_ENV,
_apply_startup_env_flags,
@@ -173,49 +173,108 @@ def test_version_info_comparisons():
@pytest.mark.asyncio
async def test_check_dashboard_files_not_exists(monkeypatch):
async def test_check_dashboard_files_not_exists(tmp_path):
"""Tests dashboard download when files do not exist."""
monkeypatch.setattr(os.path, "exists", lambda x: False)
data_dir = tmp_path / "data"
bundled_dist = tmp_path / "bundled-dist"
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
from main import VERSION
assert result == str(data_dir / "dist")
mock_download.assert_called_once()
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
async def test_check_dashboard_files_exists_and_version_match(tmp_path):
"""Tests that dashboard is not downloaded when it exists and version matches."""
# Mock os.path.exists to return True
monkeypatch.setattr(os.path, "exists", lambda x: True)
from main import VERSION
# Mock get_dashboard_version to return the current version
with mock.patch("main.get_dashboard_version") as mock_get_version:
# We need to import VERSION from main's context
from main import VERSION
mock_get_version.return_value = f"v{VERSION}"
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(data_dist / "index.html").write_text("user", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
# Assert that download_dashboard was NOT called
result = await check_dashboard_files()
assert result == str(data_dist)
mock_download.assert_not_called()
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
"""Tests that a warning is logged when dashboard version mismatches."""
monkeypatch.setattr(os.path, "exists", lambda x: True)
async def test_check_dashboard_files_exists_but_version_mismatch_downloads(tmp_path):
"""Tests that a mismatched dashboard is downloaded on startup."""
from main import VERSION
with mock.patch(
"main.get_dashboard_version", mock.AsyncMock(return_value="v0.0.1")
):
with mock.patch("main.logger.warning") as mock_logger_warning:
await check_dashboard_files()
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.logger.warning") as mock_logger_warning:
result = await check_dashboard_files()
assert result == str(data_dist)
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
mock_logger_warning.assert_called_once()
call_args, _ = mock_logger_warning.call_args
assert "WebUI version mismatch" in call_args[0]
@pytest.mark.asyncio
async def test_check_dashboard_files_downloads_when_matching_dist_is_incomplete(
tmp_path,
):
"""Tests that a version match alone is not enough to serve WebUI."""
from main import VERSION
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
assert result == str(data_dist)
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
@@ -223,6 +282,7 @@ def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
(bundled_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("v4.24.2", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
@@ -231,46 +291,94 @@ def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
assert should_use_bundled_dashboard_dist(user_dist, "v4.24.4") is True
def test_should_keep_data_dist_when_version_file_is_malformed(tmp_path):
def test_should_use_bundled_dashboard_dist_when_version_file_is_malformed(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("not-a-version", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is False
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is True
def test_should_use_bundled_dashboard_dist_when_data_version_file_is_missing(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is True
@pytest.mark.asyncio
async def test_check_dashboard_files_uses_bundled_dist_when_data_dist_is_stale(
async def test_get_dashboard_version_uses_bundled_dist_when_data_dist_is_missing(
tmp_path,
):
"""Tests that a stale data/dist does not override bundled dashboard assets."""
"""Tests bundled WebUI version lookup when data/dist is absent."""
from main import VERSION
data_dir = tmp_path / "data"
bundled_dist = tmp_path / "bundled-dist"
(bundled_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_astrbot_data_path",
return_value=str(data_dir),
):
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert await get_dashboard_version() == f"v{VERSION}"
@pytest.mark.asyncio
async def test_check_dashboard_files_replaces_stale_data_dist_with_bundled_dist(
tmp_path,
):
"""Tests that a stale data/dist is repaired from bundled dashboard assets."""
from main import VERSION
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
data_dist.mkdir(parents=True)
bundled_dist.mkdir()
(data_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
(data_dist / "old.txt").write_text("old", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_dashboard_version", mock.AsyncMock(return_value="v0.0.1")
"main.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
):
with mock.patch(
"main.should_use_bundled_dashboard_dist", return_value=True
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
assert result == str(bundled_dist)
assert result == str(data_dist)
assert (data_dist / "assets" / "version").read_text(encoding="utf-8") == f"v{VERSION}"
assert (data_dist / "index.html").read_text(encoding="utf-8") == "bundled"
assert not (data_dist / "old.txt").exists()
mock_download.assert_not_called()
@@ -281,7 +389,7 @@ async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.get_dashboard_version") as mock_get_version:
with mock.patch("main.get_dashboard_dist_version") as mock_get_version:
result = await check_dashboard_files(webui_dir=valid_dir)
assert result == valid_dir
mock_download.assert_not_called()

View File

@@ -3,13 +3,16 @@ import builtins
from io import BytesIO
from types import SimpleNamespace
import httpx
import pytest
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from PIL import Image as PILImage
import astrbot.core.provider.sources.openai_source as openai_source_module
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.groq_source import ProviderGroq
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
from astrbot.core.utils.media_utils import ResolvedMediaData, file_uri_to_path
@@ -117,6 +120,57 @@ def test_create_http_client_falls_back_to_global_httpx_module(monkeypatch):
assert captured["httpx_module"] is openai_source_module.httpx
@pytest.mark.asyncio
async def test_get_models_retries_transient_request_error(monkeypatch):
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0)
class FakeModels:
def __init__(self):
self.calls = 0
async def list(self):
self.calls += 1
if self.calls == 1:
raise httpx.ConnectError("temporary connection failure")
return SimpleNamespace(
data=[
SimpleNamespace(id="gpt-b"),
SimpleNamespace(id="gpt-a"),
]
)
models = FakeModels()
provider = ProviderOpenAIOfficial.__new__(ProviderOpenAIOfficial)
provider.client = SimpleNamespace(models=models)
assert await provider.get_models() == ["gpt-a", "gpt-b"]
assert models.calls == 2
@pytest.mark.asyncio
async def test_text_chat_passes_request_max_retries_to_query():
captured: dict[str, object] = {}
provider = ProviderOpenAIOfficial.__new__(ProviderOpenAIOfficial)
provider.api_keys = ["test-key"]
provider.client = SimpleNamespace(api_key=None)
async def fake_prepare_chat_payload(*args, **kwargs):
return {"messages": [], "model": "gpt-4o-mini"}, []
async def fake_query(payloads, func_tool, *, request_max_retries=None):
captured["request_max_retries"] = request_max_retries
return LLMResponse(role="assistant", completion_text="ok")
provider._prepare_chat_payload = fake_prepare_chat_payload
provider._query = fake_query
await provider.text_chat(prompt="hello", request_max_retries=2)
assert captured["request_max_retries"] == 2
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_removes_images():
provider = _make_provider(

View File

@@ -0,0 +1,27 @@
import httpx
import pytest
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.provider.sources.request_retry import retry_provider_request
@pytest.mark.asyncio
async def test_retry_provider_request_uses_configured_max_retries(monkeypatch):
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)
monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MAX_S", 0)
calls = 0
async def request():
nonlocal calls
calls += 1
raise httpx.ConnectError("temporary connection failure")
with pytest.raises(httpx.ConnectError):
await retry_provider_request(
"Test",
request,
max_attempts=2,
)
assert calls == 2

View File

@@ -440,6 +440,7 @@ async def test_download_dashboard_falls_back_when_hosted_package_is_not_zip(
path: str,
show_progress: bool = False, # noqa: ARG001
progress_callback=None, # noqa: ARG001
allow_insecure_ssl_fallback: bool = True, # noqa: ARG001
) -> None:
calls.append(url)
parsed = urlparse(url)

View File

@@ -8,6 +8,7 @@ import pytest
from astrbot.core import astr_main_agent as ama
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.message import Message, dump_messages_with_checkpoints
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Plain, Reply, Video
@@ -377,8 +378,18 @@ class TestApplyKb:
):
await module._apply_kb(mock_event, req, mock_context, config)
assert "[Related Knowledge Base Results]:" in req.system_prompt
assert "KB result" in req.system_prompt
assert req.system_prompt == "System prompt"
assert len(req.extra_user_content_parts) == 1
kb_part = req.extra_user_content_parts[0]
assert kb_part.text == "[Related Knowledge Base Results]:\nKB result"
message = Message.model_validate(await req.assemble_context())
assert isinstance(message.content, list)
assert message.content[0].text == "test question"
assert message.content[1].text == "[Related Knowledge Base Results]:\nKB result"
assert dump_messages_with_checkpoints([message]) == [
{"role": "user", "content": [{"type": "text", "text": "test question"}]}
]
@pytest.mark.asyncio
async def test_apply_kb_with_agentic_mode(self, mock_event, mock_context):