mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-03 19:20:16 +08:00
Compare commits
10 Commits
copilot/fi
...
v4.25.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d21f218844 | ||
|
|
88c9a90ae6 | ||
|
|
3226cc6f0e | ||
|
|
90ea91884c | ||
|
|
598a739bab | ||
|
|
af70151ff8 | ||
|
|
66ec415e56 | ||
|
|
8f5178d265 | ||
|
|
05c137eb29 | ||
|
|
1a04998787 |
25
.github/workflows/release.yml
vendored
25
.github/workflows/release.yml
vendored
@@ -71,6 +71,15 @@ jobs:
|
||||
echo "${{ steps.tag.outputs.tag }}" > dist/assets/version
|
||||
zip -r "AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip" dist
|
||||
|
||||
- name: Build core package
|
||||
shell: bash
|
||||
run: |
|
||||
git archive \
|
||||
--format=zip \
|
||||
--prefix="AstrBot-${{ steps.tag.outputs.tag }}/" \
|
||||
--output="AstrBot-${{ steps.tag.outputs.tag }}-core.zip" \
|
||||
HEAD
|
||||
|
||||
- name: Upload dashboard artifact
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
@@ -78,11 +87,12 @@ jobs:
|
||||
if-no-files-found: error
|
||||
path: dashboard/AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip
|
||||
|
||||
- name: Upload dashboard package to Cloudflare R2
|
||||
- name: Upload release packages to Cloudflare R2
|
||||
if: ${{ env.R2_ACCOUNT_ID != '' && env.R2_ACCESS_KEY_ID != '' && env.R2_SECRET_ACCESS_KEY != '' }}
|
||||
env:
|
||||
R2_BUCKET_NAME: "astrbot"
|
||||
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
|
||||
DASHBOARD_LATEST_OBJECT_NAME: "astrbot-webui-latest.zip"
|
||||
CORE_LATEST_OBJECT_NAME: "astrbot-core-latest.zip"
|
||||
VERSION_TAG: ${{ steps.tag.outputs.tag }}
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -98,11 +108,18 @@ jobs:
|
||||
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
|
||||
EOF
|
||||
|
||||
cp "dashboard/AstrBot-${VERSION_TAG}-dashboard.zip" "dashboard/${R2_OBJECT_NAME}"
|
||||
rclone copy "dashboard/${R2_OBJECT_NAME}" "r2:${R2_BUCKET_NAME}" --progress
|
||||
cp "dashboard/AstrBot-${VERSION_TAG}-dashboard.zip" "dashboard/${DASHBOARD_LATEST_OBJECT_NAME}"
|
||||
rclone copy "dashboard/${DASHBOARD_LATEST_OBJECT_NAME}" "r2:${R2_BUCKET_NAME}" --progress
|
||||
cp "dashboard/AstrBot-${VERSION_TAG}-dashboard.zip" "dashboard/astrbot-webui-${VERSION_TAG}.zip"
|
||||
rclone copy "dashboard/astrbot-webui-${VERSION_TAG}.zip" "r2:${R2_BUCKET_NAME}" --progress
|
||||
|
||||
cp "AstrBot-${VERSION_TAG}-core.zip" "${CORE_LATEST_OBJECT_NAME}"
|
||||
rclone copy "${CORE_LATEST_OBJECT_NAME}" "r2:${R2_BUCKET_NAME}" --progress
|
||||
cp "AstrBot-${VERSION_TAG}-core.zip" "astrbot-core-${VERSION_TAG}.zip"
|
||||
rclone copy "astrbot-core-${VERSION_TAG}.zip" "r2:${R2_BUCKET_NAME}" --progress
|
||||
rclone copyto "AstrBot-${VERSION_TAG}-core.zip" "r2:${R2_BUCKET_NAME}/astrbot-core/${VERSION_TAG}/source.zip" --progress
|
||||
rclone copyto "AstrBot-${VERSION_TAG}-core.zip" "r2:${R2_BUCKET_NAME}/download/astrbot-core/${VERSION_TAG}/source.zip" --progress
|
||||
|
||||
publish-release:
|
||||
name: Publish GitHub Release
|
||||
if: github.repository == 'AstrBotDevs/AstrBot'
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.25.3"
|
||||
__version__ = "4.25.6"
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from astrbot.core.computer.booters.cua_defaults import CUA_DEFAULT_CONFIG
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.25.4"
|
||||
VERSION = "4.25.6"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
PERSONAL_WECHAT_CONFIG_METADATA = {
|
||||
"weixin_oc_base_url": {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from binascii import Error as BinasciiError
|
||||
from typing import cast
|
||||
|
||||
import quart
|
||||
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -13,6 +16,57 @@ from astrbot.api import logger
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
_SIGNATURE_HEADER = "X-Signature-Ed25519"
|
||||
_SIGNATURE_TIMESTAMP_HEADER = "X-Signature-Timestamp"
|
||||
_ED25519_SEED_SIZE = 32
|
||||
_ED25519_SIGNATURE_SIZE = 64
|
||||
|
||||
|
||||
def _build_ed25519_seed(secret: str) -> bytes:
|
||||
if not secret:
|
||||
raise ValueError("QQ official bot secret is empty.")
|
||||
|
||||
seed = secret.encode("utf-8")
|
||||
while len(seed) < _ED25519_SEED_SIZE:
|
||||
seed *= 2
|
||||
return seed[:_ED25519_SEED_SIZE]
|
||||
|
||||
|
||||
def _sign_qq_webhook_payload(secret: str, timestamp: str, payload: bytes) -> str:
|
||||
seed = _build_ed25519_seed(secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
return private_key.sign(timestamp.encode("utf-8") + payload).hex()
|
||||
|
||||
|
||||
def _verify_qq_webhook_signature(
|
||||
secret: str,
|
||||
timestamp: str | None,
|
||||
signature: str | None,
|
||||
body: bytes,
|
||||
) -> bool:
|
||||
if not timestamp or not signature:
|
||||
return False
|
||||
|
||||
try:
|
||||
signature_buffer = bytes.fromhex(signature)
|
||||
except (BinasciiError, ValueError):
|
||||
return False
|
||||
|
||||
if (
|
||||
len(signature_buffer) != _ED25519_SIGNATURE_SIZE
|
||||
or signature_buffer[63] & 224 != 0
|
||||
):
|
||||
return False
|
||||
|
||||
try:
|
||||
seed = _build_ed25519_seed(secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
public_key = private_key.public_key()
|
||||
public_key.verify(signature_buffer, timestamp.encode("utf-8") + body)
|
||||
except (InvalidSignature, ValueError):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class QQOfficialWebhook:
|
||||
def __init__(
|
||||
@@ -27,7 +81,12 @@ class QQOfficialWebhook:
|
||||
if isinstance(self.port, str):
|
||||
self.port = int(self.port)
|
||||
|
||||
self.http: BotHttp = BotHttp(timeout=300, is_sandbox=self.is_sandbox)
|
||||
self.http: BotHttp = BotHttp(
|
||||
timeout=300,
|
||||
is_sandbox=self.is_sandbox,
|
||||
app_id=self.appid,
|
||||
secret=self.secret,
|
||||
)
|
||||
self.api: BotAPI = BotAPI(http=self.http)
|
||||
self.token = Token(self.appid, self.secret)
|
||||
|
||||
@@ -40,6 +99,7 @@ class QQOfficialWebhook:
|
||||
self.client = botpy_client
|
||||
self.event_queue = event_queue
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self._connection: ConnectionSession | None = None
|
||||
|
||||
# Cache for extra fields extracted from raw webhook payloads, keyed by message id
|
||||
self._extra_data_cache: dict[str, dict] = {}
|
||||
@@ -55,6 +115,13 @@ class QQOfficialWebhook:
|
||||
# 直接注入到 botpy 的 Client,移花接木!
|
||||
self.client.api = self.api
|
||||
self.client.http = self.http
|
||||
self._setup_connection()
|
||||
|
||||
def _setup_connection(self) -> None:
|
||||
if self._connection is not None:
|
||||
return
|
||||
self.client.api = self.api
|
||||
self.client.http = self.http
|
||||
|
||||
async def bot_connect() -> None:
|
||||
pass
|
||||
@@ -105,7 +172,24 @@ class QQOfficialWebhook:
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
msg: dict = await request.json
|
||||
body = await request.get_data()
|
||||
if not _verify_qq_webhook_signature(
|
||||
self.secret,
|
||||
request.headers.get(_SIGNATURE_TIMESTAMP_HEADER),
|
||||
request.headers.get(_SIGNATURE_HEADER),
|
||||
body,
|
||||
):
|
||||
logger.warning("qq_official_webhook signature verification failed.")
|
||||
return {"error": "Invalid signature"}, 401
|
||||
|
||||
try:
|
||||
msg = json.loads(body.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("qq_official_webhook callback body is not valid JSON.")
|
||||
return {"error": "Invalid JSON"}, 400
|
||||
if not isinstance(msg, dict):
|
||||
return {"error": "Invalid JSON"}, 400
|
||||
|
||||
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
|
||||
|
||||
event = msg.get("t")
|
||||
@@ -136,6 +220,13 @@ class QQOfficialWebhook:
|
||||
|
||||
if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
|
||||
event = msg["t"].lower()
|
||||
if self._connection is None:
|
||||
logger.warning(
|
||||
"qq_official_webhook botpy connection is not initialized; "
|
||||
"creating parser connection lazily.",
|
||||
)
|
||||
self._setup_connection()
|
||||
|
||||
# Extract extra fields from raw payload before botpy parses and discards them
|
||||
if data:
|
||||
msg_id = data.get("id")
|
||||
|
||||
@@ -302,12 +302,14 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
def _extract_usage(self, usage: Usage) -> TokenUsage:
|
||||
def _extract_usage(self, usage: Usage | None) -> TokenUsage:
|
||||
if usage is None:
|
||||
return TokenUsage()
|
||||
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
|
||||
return TokenUsage(
|
||||
input_other=usage.input_tokens or 0,
|
||||
input_cached=usage.cache_read_input_tokens or 0,
|
||||
output=usage.output_tokens,
|
||||
output=usage.output_tokens or 0,
|
||||
)
|
||||
|
||||
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
@@ -9,6 +9,9 @@ from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
||||
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
@@ -17,11 +20,9 @@ class Star(CommandParserMixin, PluginKVStoreMixin):
|
||||
|
||||
author: str
|
||||
name: str
|
||||
context: Context
|
||||
|
||||
class _ContextLike(Protocol):
|
||||
def get_config(self, umo: str | None = None) -> Any: ...
|
||||
|
||||
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
|
||||
def __init__(self, context: Context, config: dict | None = None) -> None:
|
||||
self.context = context
|
||||
|
||||
def _get_context_config(self) -> Any:
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import shlex
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
@@ -14,9 +15,55 @@ from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.computer.computer_client import get_booter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.tools.computer_tools.util import check_admin_permission
|
||||
from astrbot.core.tools.computer_tools.util import (
|
||||
check_admin_permission,
|
||||
is_local_runtime,
|
||||
workspace_root,
|
||||
)
|
||||
from astrbot.core.tools.registry import builtin_tool
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_system_tmp_path,
|
||||
get_astrbot_temp_path,
|
||||
)
|
||||
|
||||
|
||||
def _file_send_allowed_roots(umo: str | None) -> tuple[Path, ...]:
|
||||
roots = []
|
||||
if umo:
|
||||
roots.append(workspace_root(umo))
|
||||
roots.extend(
|
||||
[
|
||||
Path(get_astrbot_temp_path()).resolve(strict=False),
|
||||
Path(get_astrbot_system_tmp_path()).resolve(strict=False),
|
||||
]
|
||||
)
|
||||
return tuple(roots)
|
||||
|
||||
|
||||
def _is_path_within(path: Path, roots: tuple[Path, ...]) -> bool:
|
||||
return any(path == root or path.is_relative_to(root) for root in roots)
|
||||
|
||||
|
||||
def _is_restricted_local_env(context: ContextWrapper[AstrAgentContext]) -> bool:
|
||||
if not is_local_runtime(context):
|
||||
return False
|
||||
cfg = context.context.context.get_config(
|
||||
umo=context.context.event.unified_msg_origin
|
||||
)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
require_admin = provider_settings.get("computer_use_require_admin", True)
|
||||
return require_admin and context.context.event.role != "admin"
|
||||
|
||||
|
||||
def _can_send_local_file(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
local_path: Path,
|
||||
) -> bool:
|
||||
umo = context.context.event.unified_msg_origin
|
||||
allowed_roots = _file_send_allowed_roots(umo)
|
||||
if _is_path_within(local_path, allowed_roots):
|
||||
return True
|
||||
return is_local_runtime(context) and not _is_restricted_local_env(context)
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@@ -85,23 +132,38 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
||||
*,
|
||||
component_type: str = "file",
|
||||
) -> tuple[str, bool]:
|
||||
path = str(path)
|
||||
# if the path is relative, check if the file exists in user's local workspace
|
||||
path = str(path).strip()
|
||||
if not path:
|
||||
raise FileNotFoundError(f"{component_type} path is empty")
|
||||
|
||||
# Relative host paths are resolved only inside the user's workspace.
|
||||
if not os.path.isabs(path):
|
||||
unified_msg_origin = context.context.event.unified_msg_origin
|
||||
if unified_msg_origin:
|
||||
from astrbot.core.tools.computer_tools.util import workspace_root
|
||||
|
||||
try:
|
||||
ws_path = workspace_root(unified_msg_origin)
|
||||
ws_candidate = (ws_path / path).resolve()
|
||||
ws_candidate = (ws_path / path).resolve(strict=False)
|
||||
if ws_candidate.is_file() and ws_candidate.is_relative_to(ws_path):
|
||||
return str(ws_candidate), False
|
||||
except Exception:
|
||||
pass
|
||||
# check if the file exists in local environment (only allow absolute paths to prevent traversal)
|
||||
elif os.path.isfile(path):
|
||||
return path, False
|
||||
else:
|
||||
local_candidate = Path(path).expanduser().resolve(strict=False)
|
||||
if local_candidate.is_file():
|
||||
if _can_send_local_file(context, local_candidate):
|
||||
return str(local_candidate), False
|
||||
if is_local_runtime(context):
|
||||
allowed = ", ".join(
|
||||
str(root)
|
||||
for root in _file_send_allowed_roots(
|
||||
context.context.event.unified_msg_origin
|
||||
)
|
||||
)
|
||||
raise PermissionError(
|
||||
"Local file send is restricted for this user. "
|
||||
f"Allowed directories: {allowed}. "
|
||||
f"Blocked path: {local_candidate}."
|
||||
)
|
||||
|
||||
try:
|
||||
sb = await get_booter(
|
||||
@@ -221,6 +283,8 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
return f"error: {exc}"
|
||||
except PermissionError as exc:
|
||||
return f"error: {exc}"
|
||||
except Exception as exc:
|
||||
return f"error: failed to build messages[{idx}] component: {exc}"
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import psutil
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.utils.io import ensure_dir
|
||||
|
||||
from .zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
|
||||
@@ -21,6 +24,30 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
super().__init__(repo_mirror, verify=verify)
|
||||
self.MAIN_PATH = get_astrbot_path()
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
self.CORE_PACKAGE_BASE_URL = (
|
||||
"https://astrbot-registry.soulter.top/download/astrbot-core"
|
||||
)
|
||||
|
||||
def _build_core_package_url(self, version: str | None) -> str | None:
|
||||
"""Build the hosted core package URL for a release tag.
|
||||
|
||||
Args:
|
||||
version: Release tag, such as ``v4.25.6``.
|
||||
|
||||
Returns:
|
||||
Public package URL, or None when hosted package download is disabled.
|
||||
"""
|
||||
|
||||
if not version or not str(version).startswith("v"):
|
||||
return None
|
||||
|
||||
base_url = os.environ.get(
|
||||
"ASTRBOT_CORE_PACKAGE_BASE_URL",
|
||||
self.CORE_PACKAGE_BASE_URL,
|
||||
).strip()
|
||||
if not base_url:
|
||||
return None
|
||||
return f"{base_url.rstrip('/')}/{version}/source.zip"
|
||||
|
||||
def terminate_child_processes(self) -> None:
|
||||
"""终止当前进程的所有子进程
|
||||
@@ -151,6 +178,41 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
proxy="",
|
||||
progress_callback=None,
|
||||
) -> None:
|
||||
zip_path = await self.download_update_package(
|
||||
latest=latest,
|
||||
version=version,
|
||||
proxy=proxy,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
self.apply_update_package(zip_path)
|
||||
|
||||
if reboot:
|
||||
self._reboot()
|
||||
|
||||
async def download_update_package(
|
||||
self,
|
||||
latest=True,
|
||||
version=None,
|
||||
proxy="",
|
||||
path: str | Path = "temp.zip",
|
||||
progress_callback=None,
|
||||
) -> Path:
|
||||
"""Download an AstrBot core update package without applying it.
|
||||
|
||||
Args:
|
||||
latest: Whether to download the latest release.
|
||||
version: Specific release tag or commit hash to download.
|
||||
proxy: Optional GitHub proxy prefix.
|
||||
path: Destination zip path.
|
||||
progress_callback: Optional callback for download progress payloads.
|
||||
|
||||
Returns:
|
||||
Path to the downloaded update package.
|
||||
|
||||
Raises:
|
||||
Exception: If update metadata cannot resolve a package URL.
|
||||
"""
|
||||
|
||||
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
|
||||
file_url = None
|
||||
|
||||
@@ -159,15 +221,18 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
"Error: You are running AstrBot via CLI, please use `pip` or `uv tool upgrade` to update AstrBot."
|
||||
) # 避免版本管理混乱
|
||||
|
||||
target_version = None
|
||||
if latest:
|
||||
latest_version = update_data[0]["tag_name"]
|
||||
if self.compare_version(VERSION, latest_version) >= 0:
|
||||
raise Exception("当前已经是最新版本。")
|
||||
target_version = latest_version
|
||||
file_url = update_data[0]["zipball_url"]
|
||||
elif str(version).startswith("v"):
|
||||
# 更新到指定版本
|
||||
for data in update_data:
|
||||
if data["tag_name"] == version:
|
||||
target_version = data["tag_name"]
|
||||
file_url = data["zipball_url"]
|
||||
if not file_url:
|
||||
raise Exception(f"未找到版本号为 {version} 的更新文件。")
|
||||
@@ -181,16 +246,49 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
proxy = proxy.removesuffix("/")
|
||||
file_url = f"{proxy}/{file_url}"
|
||||
|
||||
try:
|
||||
await self._download_file(
|
||||
file_url,
|
||||
"temp.zip",
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
logger.info("下载 AstrBot Core 更新文件完成,正在执行解压...")
|
||||
self.unzip_file("temp.zip", self.MAIN_PATH)
|
||||
except BaseException as e:
|
||||
raise e
|
||||
zip_path = Path(path)
|
||||
ensure_dir(zip_path.parent)
|
||||
hosted_package_url = self._build_core_package_url(target_version)
|
||||
if hosted_package_url:
|
||||
try:
|
||||
logger.info(
|
||||
f"优先从托管存储下载 AstrBot Core 更新包: {hosted_package_url}"
|
||||
)
|
||||
await self._download_file(
|
||||
hosted_package_url,
|
||||
str(zip_path),
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
if not zipfile.is_zipfile(zip_path):
|
||||
raise RuntimeError(
|
||||
"Downloaded hosted package is not a valid ZIP file"
|
||||
)
|
||||
return zip_path
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"从托管存储下载 AstrBot Core 更新包失败: {exc},"
|
||||
"将回退到当前更新源。"
|
||||
)
|
||||
|
||||
if reboot:
|
||||
self._reboot()
|
||||
await self._download_file(
|
||||
file_url,
|
||||
str(zip_path),
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
return zip_path
|
||||
|
||||
def apply_update_package(self, zip_path: str | Path) -> None:
|
||||
"""Apply a previously downloaded AstrBot core update package.
|
||||
|
||||
Args:
|
||||
zip_path: Core update zip archive path.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
|
||||
Raises:
|
||||
Exception: If the archive cannot be extracted or applied.
|
||||
"""
|
||||
|
||||
logger.info("下载 AstrBot Core 更新文件完成,正在执行解压...")
|
||||
self.unzip_file(str(zip_path), self.MAIN_PATH)
|
||||
|
||||
@@ -398,12 +398,27 @@ async def download_dashboard(
|
||||
version: str | None = None,
|
||||
proxy: str | None = None,
|
||||
progress_callback=None,
|
||||
extract: bool = True,
|
||||
) -> None:
|
||||
"""下载管理面板文件"""
|
||||
"""Download dashboard assets and optionally extract them.
|
||||
|
||||
Args:
|
||||
path: Destination zip path. Defaults to the AstrBot data directory.
|
||||
extract_path: Directory where assets should be extracted.
|
||||
latest: Whether to download the latest dashboard build.
|
||||
version: Specific release tag or commit hash to download.
|
||||
proxy: Optional download proxy prefix.
|
||||
progress_callback: Optional callback for download progress payloads.
|
||||
extract: Whether to extract the archive after download.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
if path is None:
|
||||
zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip"
|
||||
else:
|
||||
zip_path = Path(path).absolute()
|
||||
ensure_dir(zip_path.parent)
|
||||
|
||||
if latest or len(str(version)) != 40:
|
||||
ver_name = "latest" if latest else version
|
||||
@@ -456,5 +471,28 @@ async def download_dashboard(
|
||||
show_progress=True,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
if extract:
|
||||
extract_dashboard(zip_path, extract_path)
|
||||
|
||||
|
||||
def extract_dashboard(zip_path: str | Path, extract_path: str | Path = "data") -> None:
|
||||
"""Extract a downloaded dashboard archive.
|
||||
|
||||
Args:
|
||||
zip_path: Dashboard zip archive path.
|
||||
extract_path: Directory where the archive contents should be extracted.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
|
||||
extract_root = Path(extract_path).resolve()
|
||||
ensure_dir(extract_root)
|
||||
with zipfile.ZipFile(zip_path, "r") as z:
|
||||
z.extractall(extract_path)
|
||||
for member in z.infolist():
|
||||
target_path = (extract_root / member.filename).resolve()
|
||||
if not target_path.is_relative_to(extract_root):
|
||||
raise ValueError(
|
||||
f"Unsafe dashboard archive path: {member.filename}",
|
||||
)
|
||||
z.extract(member, extract_root)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import uuid
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
from quart import request
|
||||
|
||||
@@ -8,7 +11,15 @@ from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db.migration.helper import check_migration_needed_v4, do_migration_v4
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_system_tmp_path,
|
||||
)
|
||||
from astrbot.core.utils.io import (
|
||||
download_dashboard,
|
||||
extract_dashboard,
|
||||
get_dashboard_version,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -35,6 +46,7 @@ class UpdateRoute(Route):
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.update_progress: dict[str, dict] = {}
|
||||
self._update_tasks: dict[str, asyncio.Task] = {}
|
||||
self.register_routes()
|
||||
|
||||
def _init_update_progress(self, progress_id: str, version: str) -> None:
|
||||
@@ -198,7 +210,62 @@ class UpdateRoute(Route):
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
|
||||
existing_task = self._update_tasks.get(progress_id)
|
||||
if existing_task and not existing_task.done():
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{"id": progress_id, "status": "running"},
|
||||
"更新任务正在进行中。",
|
||||
)
|
||||
.__dict__,
|
||||
200,
|
||||
CLEAR_SITE_DATA_HEADERS,
|
||||
)
|
||||
|
||||
self._init_update_progress(progress_id, version)
|
||||
task = asyncio.create_task(
|
||||
self._run_update_project(progress_id, version, latest, reboot, proxy),
|
||||
)
|
||||
self._update_tasks[progress_id] = task
|
||||
task.add_done_callback(lambda _task: self._update_tasks.pop(progress_id, None))
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{"id": progress_id, "status": "running"},
|
||||
"更新任务已开始。",
|
||||
)
|
||||
.__dict__,
|
||||
200,
|
||||
CLEAR_SITE_DATA_HEADERS,
|
||||
)
|
||||
|
||||
async def _run_update_project(
|
||||
self,
|
||||
progress_id: str,
|
||||
version: str,
|
||||
latest: bool,
|
||||
reboot: bool,
|
||||
proxy: str | None,
|
||||
) -> None:
|
||||
"""Run an update task outside the request lifecycle.
|
||||
|
||||
Args:
|
||||
progress_id: Progress record id reported to the frontend.
|
||||
version: Target version without the latest sentinel.
|
||||
latest: Whether to install the latest release.
|
||||
reboot: Whether to restart AstrBot after applying files.
|
||||
proxy: Optional GitHub proxy URL.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
|
||||
update_temp_dir = Path(get_astrbot_system_tmp_path()) / "updates"
|
||||
update_temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
update_token = uuid.uuid4().hex
|
||||
dashboard_zip_path = update_temp_dir / f"{update_token}-dashboard.zip"
|
||||
core_zip_path = update_temp_dir / f"{update_token}-core.zip"
|
||||
try:
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
@@ -208,15 +275,17 @@ class UpdateRoute(Route):
|
||||
0,
|
||||
)
|
||||
await download_dashboard(
|
||||
path=str(dashboard_zip_path),
|
||||
latest=latest,
|
||||
version=version,
|
||||
proxy=proxy,
|
||||
proxy=proxy or "",
|
||||
progress_callback=self._make_progress_callback(
|
||||
progress_id,
|
||||
"dashboard",
|
||||
0,
|
||||
45,
|
||||
),
|
||||
extract=False,
|
||||
)
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
@@ -233,16 +302,19 @@ class UpdateRoute(Route):
|
||||
"正在下载 AstrBot 项目代码...",
|
||||
45,
|
||||
)
|
||||
await self.astrbot_updator.update(
|
||||
latest=latest,
|
||||
version=version,
|
||||
proxy=proxy,
|
||||
progress_callback=self._make_progress_callback(
|
||||
progress_id,
|
||||
"core",
|
||||
45,
|
||||
45,
|
||||
),
|
||||
core_zip_path = Path(
|
||||
await self.astrbot_updator.download_update_package(
|
||||
latest=latest,
|
||||
version=version,
|
||||
proxy=proxy or "",
|
||||
path=core_zip_path,
|
||||
progress_callback=self._make_progress_callback(
|
||||
progress_id,
|
||||
"core",
|
||||
45,
|
||||
45,
|
||||
),
|
||||
)
|
||||
)
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
@@ -252,6 +324,50 @@ class UpdateRoute(Route):
|
||||
90,
|
||||
)
|
||||
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"verify",
|
||||
"running",
|
||||
"下载完成,正在校验更新包...",
|
||||
90,
|
||||
)
|
||||
for zip_path in (dashboard_zip_path, core_zip_path):
|
||||
with zipfile.ZipFile(zip_path, "r") as archive:
|
||||
corrupt_member = archive.testzip()
|
||||
if corrupt_member:
|
||||
raise RuntimeError(f"更新包校验失败: {corrupt_member}")
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"verify",
|
||||
"done",
|
||||
"更新包校验完成。",
|
||||
91,
|
||||
)
|
||||
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"apply",
|
||||
"running",
|
||||
"下载完成,正在应用更新...",
|
||||
91,
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self.astrbot_updator.apply_update_package,
|
||||
core_zip_path,
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
extract_dashboard,
|
||||
dashboard_zip_path,
|
||||
Path(get_astrbot_data_path()),
|
||||
)
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"apply",
|
||||
"done",
|
||||
"更新文件应用完成。",
|
||||
92,
|
||||
)
|
||||
|
||||
# pip 更新依赖
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
@@ -290,12 +406,7 @@ class UpdateRoute(Route):
|
||||
"overall_percent": 100,
|
||||
},
|
||||
)
|
||||
ret = (
|
||||
Response()
|
||||
.ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。")
|
||||
.__dict__
|
||||
)
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
return
|
||||
self.update_progress[progress_id].update(
|
||||
{
|
||||
"status": "success",
|
||||
@@ -304,12 +415,14 @@ class UpdateRoute(Route):
|
||||
"overall_percent": 100,
|
||||
},
|
||||
)
|
||||
ret = (
|
||||
Response()
|
||||
.ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。")
|
||||
.__dict__
|
||||
except asyncio.CancelledError:
|
||||
self.update_progress[progress_id].update(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "更新任务已取消。",
|
||||
},
|
||||
)
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
logger.warning(f"Update task was cancelled: {progress_id}")
|
||||
except Exception as e:
|
||||
self.update_progress[progress_id].update(
|
||||
{
|
||||
@@ -318,7 +431,13 @@ class UpdateRoute(Route):
|
||||
},
|
||||
)
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
finally:
|
||||
for zip_path in (dashboard_zip_path, core_zip_path):
|
||||
try:
|
||||
if zip_path.exists():
|
||||
zip_path.unlink()
|
||||
except Exception as cleanup_exc:
|
||||
logger.warning(f"清理更新临时文件失败: {zip_path}, {cleanup_exc}")
|
||||
|
||||
async def update_dashboard(self):
|
||||
try:
|
||||
|
||||
30
changelogs/v4.25.5.md
Normal file
30
changelogs/v4.25.5.md
Normal file
@@ -0,0 +1,30 @@
|
||||
- [更新日志(简体中文)](#chinese)
|
||||
- [Changelog(English)](#english)
|
||||
|
||||
<a id="chinese"></a>
|
||||
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
|
||||
- 收紧消息工具对本地文件路径的处理边界,减少非预期路径被用于消息附件的情况。([#8660](https://github.com/AstrBotDevs/AstrBot/pull/8660))
|
||||
- 修复 Star Context 类型定义,恢复相关 SDK 类型提示与运行兼容性。([#8659](https://github.com/AstrBotDevs/AstrBot/pull/8659))
|
||||
- 修复 QQ 官方 Webhook 模式无法正常重启的问题。
|
||||
|
||||
### 优化
|
||||
|
||||
- 改进 Anthropic 在内容过滤响应中缺失 `usage` 字段时的处理,避免相关请求结果解析异常。([#8647](https://github.com/AstrBotDevs/AstrBot/pull/8647))
|
||||
|
||||
<a id="english"></a>
|
||||
|
||||
## What's Changed (EN)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Tightened local file path handling in the message tool to avoid unintended attachment path usage. ([#8660](https://github.com/AstrBotDevs/AstrBot/pull/8660))
|
||||
- Fixed Star Context typing to restore related SDK type hints and runtime compatibility. ([#8659](https://github.com/AstrBotDevs/AstrBot/pull/8659))
|
||||
- Fixed QQ Official Webhook mode not restarting correctly.
|
||||
|
||||
### Improvements
|
||||
|
||||
- Improved Anthropic response parsing when content-filtered responses omit the `usage` field. ([#8647](https://github.com/AstrBotDevs/AstrBot/pull/8647))
|
||||
34
changelogs/v4.25.6.md
Normal file
34
changelogs/v4.25.6.md
Normal file
@@ -0,0 +1,34 @@
|
||||
- [更新日志(简体中文)](#chinese)
|
||||
- [Changelog(English)](#english)
|
||||
|
||||
<a id="chinese"></a>
|
||||
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
|
||||
- 将 WebUI 项目升级改为后台任务执行,避免关闭或刷新前端页面导致升级请求被取消。
|
||||
- 调整升级流程为先下载并校验 WebUI 与 Core 两个更新包,再统一应用,降低下载失败导致文件半更新的风险。
|
||||
- Core 更新包优先从 AstrBot Registry 下载,失败时回退到 GitHub zipball。
|
||||
- WebUI 在升级重启期间会轮询启动时间,并在重启完成后使用 cache-buster 进行全量刷新,减少旧前端缓存残留。
|
||||
|
||||
### 构建
|
||||
|
||||
- 调整 Hatch artifact 配置,确保 sdist 与 wheel 都包含打包后的 Dashboard 资源。
|
||||
- Release workflow 增加 Core 更新包构建与 Registry 上传。
|
||||
|
||||
<a id="english"></a>
|
||||
|
||||
## What's Changed (EN)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Changed WebUI project updates to run as backend tasks, preventing updates from being cancelled when the frontend is closed or refreshed.
|
||||
- Made the update flow download and verify both WebUI and Core packages before applying files, reducing partial-update risk after download failures.
|
||||
- Prefer AstrBot Registry for Core update packages, with GitHub zipball fallback.
|
||||
- During update restarts, WebUI now polls backend start time and performs a cache-busted full refresh after restart completion to avoid stale frontend assets.
|
||||
|
||||
### Build
|
||||
|
||||
- Updated Hatch artifact configuration so both sdist and wheel include bundled Dashboard assets.
|
||||
- Added Core package build and Registry upload steps to the release workflow.
|
||||
@@ -55,6 +55,7 @@ let showAdvancedUpdateSettings = ref(false);
|
||||
let restartWaiting = ref(false);
|
||||
let restartStartTime = ref<number | string | null>(null);
|
||||
let restartPollTimer: ReturnType<typeof setInterval> | null = null;
|
||||
const RESTART_START_TIME_POLL_INTERVAL_MS = 2000;
|
||||
type DownloadStageStatus = "pending" | "running" | "done" | "error";
|
||||
type DownloadStage = {
|
||||
status: DownloadStageStatus;
|
||||
@@ -569,20 +570,31 @@ async function fetchAstrBotStartTime() {
|
||||
return startTime;
|
||||
}
|
||||
|
||||
function waitForAstrBotRestart(initialStartTime: number | string | null) {
|
||||
if (restartWaiting.value) {
|
||||
function reloadWithCacheBuster() {
|
||||
const url = new URL(window.location.href);
|
||||
url.searchParams.set("_r", Date.now().toString());
|
||||
window.location.replace(url.toString());
|
||||
}
|
||||
|
||||
function waitForAstrBotRestart(
|
||||
initialStartTime: number | string | null,
|
||||
showWaiting = true,
|
||||
) {
|
||||
if (showWaiting && !restartWaiting.value) {
|
||||
restartWaiting.value = true;
|
||||
updateProgress.value = {
|
||||
...updateProgress.value,
|
||||
stage: "restart",
|
||||
status: "success",
|
||||
message: t("core.header.updateDialog.progress.restarting"),
|
||||
overall_percent: 100,
|
||||
};
|
||||
}
|
||||
if (restartPollTimer) {
|
||||
return;
|
||||
}
|
||||
stopRestartPolling();
|
||||
restartWaiting.value = true;
|
||||
|
||||
restartStartTime.value = initialStartTime;
|
||||
updateProgress.value = {
|
||||
...updateProgress.value,
|
||||
stage: "restart",
|
||||
status: "success",
|
||||
message: t("core.header.updateDialog.progress.restarting"),
|
||||
overall_percent: 100,
|
||||
};
|
||||
|
||||
const poll = async () => {
|
||||
try {
|
||||
@@ -594,16 +606,17 @@ function waitForAstrBotRestart(initialStartTime: number | string | null) {
|
||||
) {
|
||||
stopRestartPolling();
|
||||
restartWaiting.value = false;
|
||||
window.location.reload();
|
||||
reloadWithCacheBuster();
|
||||
}
|
||||
} catch (_error) {
|
||||
// Backend may be unavailable while the process is restarting.
|
||||
}
|
||||
};
|
||||
|
||||
void poll();
|
||||
restartPollTimer = setInterval(() => {
|
||||
void poll();
|
||||
}, 1000);
|
||||
}, RESTART_START_TIME_POLL_INTERVAL_MS);
|
||||
}
|
||||
|
||||
function applyUpdateProgress(payload: UpdateProgress) {
|
||||
@@ -616,8 +629,15 @@ function applyUpdateProgress(payload: UpdateProgress) {
|
||||
},
|
||||
};
|
||||
if (payload.status === "success" || payload.status === "error") {
|
||||
installLoading.value = false;
|
||||
stopUpdateProgressPolling();
|
||||
}
|
||||
if (payload.status === "error") {
|
||||
stopRestartPolling();
|
||||
}
|
||||
if (payload.stage === "restart") {
|
||||
waitForAstrBotRestart(restartStartTime.value);
|
||||
}
|
||||
if (payload.status === "success") {
|
||||
waitForAstrBotRestart(restartStartTime.value);
|
||||
}
|
||||
@@ -663,6 +683,7 @@ async function switchVersion(targetVersion: string) {
|
||||
initialStartTime = commonStore.getStartTime();
|
||||
}
|
||||
restartStartTime.value = initialStartTime;
|
||||
waitForAstrBotRestart(initialStartTime, false);
|
||||
startUpdateProgressPolling(progressId);
|
||||
|
||||
axios
|
||||
@@ -673,20 +694,27 @@ async function switchVersion(targetVersion: string) {
|
||||
})
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
updateProgress.value = {
|
||||
...updateProgress.value,
|
||||
status:
|
||||
res.data.status === "ok" ? "success" : updateProgress.value.status,
|
||||
message: res.data.message,
|
||||
overall_percent:
|
||||
res.data.status === "ok" ? 100 : updateProgress.value.overall_percent,
|
||||
};
|
||||
if (res.data.status == "ok") {
|
||||
waitForAstrBotRestart(initialStartTime);
|
||||
if (res.data.status === "error") {
|
||||
stopUpdateProgressPolling();
|
||||
installLoading.value = false;
|
||||
updateProgress.value = {
|
||||
...updateProgress.value,
|
||||
status: "error",
|
||||
message:
|
||||
res.data.message || t("core.header.updateDialog.progress.failed"),
|
||||
};
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
stopUpdateProgressPolling();
|
||||
if (!err?.response && restartPollTimer) {
|
||||
waitForAstrBotRestart(restartStartTime.value);
|
||||
updateStatus.value = t("core.header.updateDialog.progress.restarting");
|
||||
return;
|
||||
}
|
||||
stopRestartPolling();
|
||||
installLoading.value = false;
|
||||
updateStatus.value = err;
|
||||
updateProgress.value = {
|
||||
...updateProgress.value,
|
||||
@@ -696,10 +724,6 @@ async function switchVersion(targetVersion: string) {
|
||||
err?.message ||
|
||||
t("core.header.updateDialog.progress.failed"),
|
||||
};
|
||||
})
|
||||
.finally(() => {
|
||||
installLoading.value = false;
|
||||
stopUpdateProgressPolling();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -712,7 +736,7 @@ function updateDashboard() {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == "ok") {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
reloadWithCacheBuster();
|
||||
}, 1000);
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.25.4"
|
||||
version = "4.25.6"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
license = { text = "AGPL-3.0-or-later" }
|
||||
@@ -120,7 +120,7 @@ exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
|
||||
allow-direct-references = true
|
||||
|
||||
# Include bundled dashboard dist even though it is not tracked by VCS.
|
||||
[tool.hatch.build.targets.wheel]
|
||||
[tool.hatch.build]
|
||||
artifacts = ["astrbot/dashboard/dist/**"]
|
||||
|
||||
# Custom build hook: builds the Vue dashboard and copies dist into the package.
|
||||
|
||||
@@ -483,6 +483,40 @@ def _setup_provider_with_mock_client(monkeypatch) -> anthropic_source.ProviderAn
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_handles_none_usage_when_content_filtered(monkeypatch):
|
||||
provider = _setup_provider_with_mock_client(monkeypatch)
|
||||
content_filter_message = (
|
||||
"The request was rejected because it was considered high risk"
|
||||
)
|
||||
|
||||
class _FakeMessageBlock:
|
||||
def __init__(self, text: str):
|
||||
self.type = "text"
|
||||
self.text = text
|
||||
|
||||
class _FakeMessage:
|
||||
def __init__(self):
|
||||
self.id = "msg_content_filter"
|
||||
self.content = [_FakeMessageBlock(content_filter_message)]
|
||||
self.stop_reason = "content_filter"
|
||||
self.usage = None
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
return _FakeMessage()
|
||||
|
||||
monkeypatch.setattr(anthropic_source, "Message", _FakeMessage)
|
||||
provider.client.messages.create = fake_create
|
||||
|
||||
llm_response = await provider.text_chat(prompt="test")
|
||||
|
||||
assert llm_response.completion_text == content_filter_message
|
||||
assert llm_response.usage is not None
|
||||
assert llm_response.usage.input_other == 0
|
||||
assert llm_response.usage.input_cached == 0
|
||||
assert llm_response.usage.output == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_choice_auto_converts_to_dict(monkeypatch):
|
||||
"""tool_choice='auto' 应转换为 {'type': 'auto'}"""
|
||||
|
||||
@@ -59,6 +59,34 @@ def _strip_query(url: str) -> str:
|
||||
return urlunsplit(("", "", parsed.path, "", parsed.fragment))
|
||||
|
||||
|
||||
async def _wait_for_update_progress(
|
||||
test_client,
|
||||
authenticated_header: dict,
|
||||
progress_id: str,
|
||||
) -> dict:
|
||||
"""Wait until an update task reaches a terminal status.
|
||||
|
||||
Args:
|
||||
test_client: Quart test client.
|
||||
authenticated_header: Headers for authenticated dashboard requests.
|
||||
progress_id: Update progress id to poll.
|
||||
|
||||
Returns:
|
||||
The update progress response payload.
|
||||
"""
|
||||
|
||||
for _ in range(100):
|
||||
response = await test_client.get(
|
||||
f"/api/update/progress?id={progress_id}",
|
||||
headers=authenticated_header,
|
||||
)
|
||||
data = await response.get_json()
|
||||
if data["data"].get("status") in {"success", "error"}:
|
||||
return data
|
||||
await asyncio.sleep(0.01)
|
||||
pytest.fail(f"Update task did not finish: {progress_id}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registered_plugin_page(core_lifecycle_td: AstrBotCoreLifecycle, monkeypatch):
|
||||
plugin_root = (
|
||||
@@ -2463,40 +2491,77 @@ async def test_do_update(
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
tmp_path_factory,
|
||||
tmp_path,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
# Use a temporary path for the mock update to avoid side effects
|
||||
temp_release_dir = tmp_path_factory.mktemp("release")
|
||||
release_path = temp_release_dir / "astrbot"
|
||||
calls = []
|
||||
release_path = tmp_path / "astrbot"
|
||||
calls: list[str] = []
|
||||
|
||||
async def mock_update(*args, **kwargs):
|
||||
"""Mocks the update process by creating a directory in the temp path."""
|
||||
calls.append("core")
|
||||
async def mock_download_update_package(*args, **kwargs):
|
||||
"""Mock the core package download by writing a valid ZIP archive."""
|
||||
calls.append("download-core")
|
||||
callback = kwargs.get("progress_callback")
|
||||
if callback:
|
||||
callback({"downloaded": 10, "total": 10, "percent": 1, "speed": 1})
|
||||
zip_path = Path(kwargs["path"])
|
||||
with zipfile.ZipFile(zip_path, "w") as archive:
|
||||
archive.writestr("AstrBot-v3.4.0/README.md", "core")
|
||||
return zip_path
|
||||
|
||||
def mock_apply_update_package(zip_path):
|
||||
"""Mock applying the core package."""
|
||||
calls.append("apply-core")
|
||||
assert zipfile.is_zipfile(zip_path)
|
||||
os.makedirs(release_path, exist_ok=True)
|
||||
|
||||
async def mock_download_dashboard(*args, **kwargs):
|
||||
"""Mocks the dashboard download to prevent network access."""
|
||||
calls.append("dashboard")
|
||||
"""Mock the dashboard download by writing a valid ZIP archive."""
|
||||
calls.append("download-dashboard")
|
||||
callback = kwargs.get("progress_callback")
|
||||
if callback:
|
||||
callback({"downloaded": 10, "total": 10, "percent": 1, "speed": 1})
|
||||
return
|
||||
zip_path = Path(kwargs["path"])
|
||||
with zipfile.ZipFile(zip_path, "w") as archive:
|
||||
archive.writestr("dist/index.html", "dashboard")
|
||||
|
||||
def mock_extract_dashboard(zip_path, extract_path):
|
||||
"""Mock applying the dashboard package."""
|
||||
calls.append("apply-dashboard")
|
||||
assert zipfile.is_zipfile(zip_path)
|
||||
assert Path(extract_path) == tmp_path / "data"
|
||||
|
||||
async def mock_pip_install(*args, **kwargs):
|
||||
"""Mocks pip install to prevent actual installation."""
|
||||
calls.append("pip")
|
||||
return
|
||||
|
||||
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"download_update_package",
|
||||
mock_download_update_package,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"apply_update_package",
|
||||
mock_apply_update_package,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.download_dashboard",
|
||||
mock_download_dashboard,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.extract_dashboard",
|
||||
mock_extract_dashboard,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.get_astrbot_system_tmp_path",
|
||||
lambda: str(tmp_path / "tmp"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.get_astrbot_data_path",
|
||||
lambda: str(tmp_path / "data"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.pip_installer.install",
|
||||
mock_pip_install,
|
||||
@@ -2510,19 +2575,101 @@ async def test_do_update(
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert os.path.exists(release_path)
|
||||
assert calls[:2] == ["dashboard", "core"]
|
||||
assert data["data"]["id"] == "test-progress"
|
||||
assert data["data"]["status"] == "running"
|
||||
|
||||
progress_response = await test_client.get(
|
||||
"/api/update/progress?id=test-progress",
|
||||
headers=authenticated_header,
|
||||
progress_data = await _wait_for_update_progress(
|
||||
test_client,
|
||||
authenticated_header,
|
||||
"test-progress",
|
||||
)
|
||||
progress_data = await progress_response.get_json()
|
||||
assert os.path.exists(release_path)
|
||||
assert calls == [
|
||||
"download-dashboard",
|
||||
"download-core",
|
||||
"apply-core",
|
||||
"apply-dashboard",
|
||||
"pip",
|
||||
]
|
||||
assert progress_data["status"] == "ok"
|
||||
assert progress_data["data"]["status"] == "success"
|
||||
assert progress_data["data"]["overall_percent"] == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update_does_not_apply_files_when_core_download_fails(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
calls: list[str] = []
|
||||
|
||||
async def mock_download_dashboard(*args, **kwargs):
|
||||
"""Mock the dashboard download by writing a valid ZIP archive."""
|
||||
calls.append("download-dashboard")
|
||||
zip_path = Path(kwargs["path"])
|
||||
with zipfile.ZipFile(zip_path, "w") as archive:
|
||||
archive.writestr("dist/index.html", "dashboard")
|
||||
|
||||
async def mock_download_update_package(*args, **kwargs):
|
||||
"""Mock a core package download failure."""
|
||||
calls.append("download-core")
|
||||
raise RuntimeError("core download failed")
|
||||
|
||||
def fail_apply_update_package(*args, **kwargs):
|
||||
"""Ensure core files are not applied after a download failure."""
|
||||
calls.append("apply-core")
|
||||
raise AssertionError("core package should not be applied")
|
||||
|
||||
def fail_extract_dashboard(*args, **kwargs):
|
||||
"""Ensure dashboard files are not applied after a download failure."""
|
||||
calls.append("apply-dashboard")
|
||||
raise AssertionError("dashboard package should not be applied")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.get_astrbot_system_tmp_path",
|
||||
lambda: str(tmp_path / "tmp"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.download_dashboard",
|
||||
mock_download_dashboard,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.extract_dashboard",
|
||||
fail_extract_dashboard,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"download_update_package",
|
||||
mock_download_update_package,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"apply_update_package",
|
||||
fail_apply_update_package,
|
||||
)
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/update/do",
|
||||
headers=authenticated_header,
|
||||
json={"version": "v3.4.0", "reboot": False, "progress_id": "atomic-fail"},
|
||||
)
|
||||
data = await response.get_json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data["status"] == "ok"
|
||||
progress_data = await _wait_for_update_progress(
|
||||
test_client,
|
||||
authenticated_header,
|
||||
"atomic-fail",
|
||||
)
|
||||
assert progress_data["data"]["status"] == "error"
|
||||
assert calls == ["download-dashboard", "download-core"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_pip_package_returns_pip_install_error_message(
|
||||
app: Quart,
|
||||
|
||||
124
tests/test_qqofficial_webhook_signature.py
Normal file
124
tests/test_qqofficial_webhook_signature.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.platform.sources.qqofficial_webhook.qo_webhook_server import (
|
||||
_SIGNATURE_HEADER,
|
||||
_SIGNATURE_TIMESTAMP_HEADER,
|
||||
QQOfficialWebhook,
|
||||
_sign_qq_webhook_payload,
|
||||
_verify_qq_webhook_signature,
|
||||
)
|
||||
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, body: bytes, headers: dict[str, str] | None = None) -> None:
|
||||
self._body = body
|
||||
self.headers = headers or {}
|
||||
|
||||
async def get_data(self) -> bytes:
|
||||
return self._body
|
||||
|
||||
|
||||
class FakeBotpyClient:
|
||||
api = None
|
||||
http = None
|
||||
|
||||
def ws_dispatch(self, *_args, **_kwargs) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_qq_webhook_signature_verification_accepts_valid_signature():
|
||||
secret = "test-secret"
|
||||
timestamp = "1710000000"
|
||||
body = b'{"op":12,"d":0}'
|
||||
signature = _sign_qq_webhook_payload(secret, timestamp, body)
|
||||
|
||||
assert _verify_qq_webhook_signature(secret, timestamp, signature, body)
|
||||
|
||||
|
||||
def test_qq_webhook_signature_verification_rejects_tampered_body():
|
||||
secret = "test-secret"
|
||||
timestamp = "1710000000"
|
||||
body = b'{"op":12,"d":0}'
|
||||
signature = _sign_qq_webhook_payload(secret, timestamp, body)
|
||||
|
||||
assert not _verify_qq_webhook_signature(
|
||||
secret,
|
||||
timestamp,
|
||||
signature,
|
||||
b'{"op":12,"d":1}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qq_webhook_callback_rejects_missing_signature():
|
||||
webhook = object.__new__(QQOfficialWebhook)
|
||||
webhook.secret = "test-secret"
|
||||
|
||||
result = await webhook.handle_callback(FakeRequest(b'{"op":12,"d":0}'))
|
||||
|
||||
assert result == ({"error": "Invalid signature"}, 401)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qq_webhook_callback_accepts_signed_validation():
|
||||
secret = "test-secret"
|
||||
event_ts = "1710000000"
|
||||
plain_token = "plain-token"
|
||||
body = json.dumps(
|
||||
{"op": 13, "d": {"event_ts": event_ts, "plain_token": plain_token}},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
signature = _sign_qq_webhook_payload(secret, event_ts, body)
|
||||
webhook = object.__new__(QQOfficialWebhook)
|
||||
webhook.secret = secret
|
||||
|
||||
result = await webhook.handle_callback(
|
||||
FakeRequest(
|
||||
body,
|
||||
{
|
||||
_SIGNATURE_TIMESTAMP_HEADER: event_ts,
|
||||
_SIGNATURE_HEADER: signature,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"plain_token": plain_token,
|
||||
"signature": _sign_qq_webhook_payload(secret, event_ts, plain_token.encode()),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qq_webhook_callback_lazily_creates_botpy_connection():
|
||||
secret = "test-secret"
|
||||
timestamp = "1710000000"
|
||||
body = json.dumps(
|
||||
{"op": 0, "t": "UNKNOWN_EVENT", "id": "event-id", "d": {"id": "message-id"}},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
signature = _sign_qq_webhook_payload(secret, timestamp, body)
|
||||
webhook = QQOfficialWebhook(
|
||||
{"appid": "123", "secret": secret},
|
||||
asyncio.Queue(),
|
||||
FakeBotpyClient(),
|
||||
)
|
||||
|
||||
result = await webhook.handle_callback(
|
||||
FakeRequest(
|
||||
body,
|
||||
{
|
||||
_SIGNATURE_TIMESTAMP_HEADER: timestamp,
|
||||
_SIGNATURE_HEADER: signature,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert result == {"opcode": 12}
|
||||
assert webhook._connection is not None
|
||||
assert webhook.http._token is not None
|
||||
assert webhook.http._token.app_id == "123"
|
||||
assert webhook.client.api is webhook.api
|
||||
assert webhook.client.http is webhook.http
|
||||
@@ -1,14 +1,17 @@
|
||||
import ntpath
|
||||
import posixpath
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import certifi
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from astrbot.core.star.updator import PluginUpdator
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.zip_updator import RepoZipUpdator
|
||||
|
||||
|
||||
@@ -286,6 +289,144 @@ async def test_plugin_updator_install_prefers_download_url(
|
||||
assert calls["unzip"] == (str(expected_path) + ".zip", str(expected_path))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_astrbot_updator_prefers_hosted_core_package(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
monkeypatch.delenv("ASTRBOT_CLI", raising=False)
|
||||
monkeypatch.delenv("ASTRBOT_LAUNCHER", raising=False)
|
||||
monkeypatch.setenv("ASTRBOT_CORE_PACKAGE_BASE_URL", "https://cdn.example/core")
|
||||
|
||||
updator = AstrBotUpdator()
|
||||
calls: list[str] = []
|
||||
|
||||
async def fake_fetch_release_info(url: str, latest: bool = True): # noqa: ARG001
|
||||
return [
|
||||
{
|
||||
"version": "AstrBot v99.0.0",
|
||||
"published_at": "2026-06-19T00:00:00Z",
|
||||
"body": "hosted core package",
|
||||
"tag_name": "v99.0.0",
|
||||
"zipball_url": "https://github.example/archive.zip",
|
||||
}
|
||||
]
|
||||
|
||||
async def fake_download_file(url: str, path: str, progress_callback=None): # noqa: ARG001
|
||||
calls.append(url)
|
||||
with zipfile.ZipFile(path, "w") as archive:
|
||||
archive.writestr("AstrBot-v99.0.0/README.md", "hosted-core")
|
||||
|
||||
monkeypatch.setattr(updator, "fetch_release_info", fake_fetch_release_info)
|
||||
monkeypatch.setattr(updator, "_download_file", fake_download_file)
|
||||
|
||||
zip_path = await updator.download_update_package(
|
||||
latest=False,
|
||||
version="v99.0.0",
|
||||
path=tmp_path / "core.zip",
|
||||
)
|
||||
|
||||
assert zip_path == tmp_path / "core.zip"
|
||||
assert zipfile.is_zipfile(zip_path)
|
||||
assert calls == ["https://cdn.example/core/v99.0.0/source.zip"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_astrbot_updator_falls_back_when_hosted_core_package_fails(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
monkeypatch.delenv("ASTRBOT_CLI", raising=False)
|
||||
monkeypatch.delenv("ASTRBOT_LAUNCHER", raising=False)
|
||||
monkeypatch.setenv("ASTRBOT_CORE_PACKAGE_BASE_URL", "https://cdn.example/core")
|
||||
|
||||
updator = AstrBotUpdator()
|
||||
calls: list[str] = []
|
||||
|
||||
async def fake_fetch_release_info(url: str, latest: bool = True): # noqa: ARG001
|
||||
return [
|
||||
{
|
||||
"version": "AstrBot v99.0.0",
|
||||
"published_at": "2026-06-19T00:00:00Z",
|
||||
"body": "hosted core package",
|
||||
"tag_name": "v99.0.0",
|
||||
"zipball_url": "https://github.example/archive.zip",
|
||||
}
|
||||
]
|
||||
|
||||
async def fake_download_file(url: str, path: str, progress_callback=None): # noqa: ARG001
|
||||
calls.append(url)
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme == "https" and parsed.hostname == "cdn.example":
|
||||
raise RuntimeError("404")
|
||||
Path(path).write_bytes(b"github-core")
|
||||
|
||||
monkeypatch.setattr(updator, "fetch_release_info", fake_fetch_release_info)
|
||||
monkeypatch.setattr(updator, "_download_file", fake_download_file)
|
||||
|
||||
zip_path = await updator.download_update_package(
|
||||
latest=False,
|
||||
version="v99.0.0",
|
||||
path=tmp_path / "core.zip",
|
||||
)
|
||||
|
||||
assert zip_path == tmp_path / "core.zip"
|
||||
assert zip_path.read_bytes() == b"github-core"
|
||||
assert calls == [
|
||||
"https://cdn.example/core/v99.0.0/source.zip",
|
||||
"https://github.example/archive.zip",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_astrbot_updator_falls_back_when_hosted_core_package_is_not_zip(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
monkeypatch.delenv("ASTRBOT_CLI", raising=False)
|
||||
monkeypatch.delenv("ASTRBOT_LAUNCHER", raising=False)
|
||||
monkeypatch.setenv("ASTRBOT_CORE_PACKAGE_BASE_URL", "https://cdn.example/core")
|
||||
|
||||
updator = AstrBotUpdator()
|
||||
calls: list[str] = []
|
||||
|
||||
async def fake_fetch_release_info(url: str, latest: bool = True): # noqa: ARG001
|
||||
return [
|
||||
{
|
||||
"version": "AstrBot v99.0.0",
|
||||
"published_at": "2026-06-19T00:00:00Z",
|
||||
"body": "hosted core package",
|
||||
"tag_name": "v99.0.0",
|
||||
"zipball_url": "https://github.example/archive.zip",
|
||||
}
|
||||
]
|
||||
|
||||
async def fake_download_file(url: str, path: str, progress_callback=None): # noqa: ARG001
|
||||
calls.append(url)
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme == "https" and parsed.hostname == "cdn.example":
|
||||
Path(path).write_bytes(b"not a zip")
|
||||
return
|
||||
with zipfile.ZipFile(path, "w") as archive:
|
||||
archive.writestr("AstrBot-v99.0.0/README.md", "github-core")
|
||||
|
||||
monkeypatch.setattr(updator, "fetch_release_info", fake_fetch_release_info)
|
||||
monkeypatch.setattr(updator, "_download_file", fake_download_file)
|
||||
|
||||
zip_path = await updator.download_update_package(
|
||||
latest=False,
|
||||
version="v99.0.0",
|
||||
path=tmp_path / "core.zip",
|
||||
)
|
||||
|
||||
assert zip_path == tmp_path / "core.zip"
|
||||
assert zipfile.is_zipfile(zip_path)
|
||||
assert calls == [
|
||||
"https://cdn.example/core/v99.0.0/source.zip",
|
||||
"https://github.example/archive.zip",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_release_info_uses_httpx_client_with_env_proxy_support(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -12,9 +12,15 @@ def _make_context(
|
||||
current_session="feishu:GroupMessage:oc_xxx",
|
||||
role="admin",
|
||||
require_admin=True,
|
||||
runtime="local",
|
||||
):
|
||||
"""Build a minimal ContextWrapper for SendMessageToUserTool."""
|
||||
cfg = {"provider_settings": {"computer_use_require_admin": require_admin}}
|
||||
cfg = {
|
||||
"provider_settings": {
|
||||
"computer_use_require_admin": require_admin,
|
||||
"computer_use_runtime": runtime,
|
||||
}
|
||||
}
|
||||
return SimpleNamespace(
|
||||
context=SimpleNamespace(
|
||||
event=SimpleNamespace(
|
||||
@@ -161,3 +167,71 @@ async def test_send_message_missing_image_path_stops_before_send(tmp_path, monke
|
||||
|
||||
assert "error: failed to build messages[1] component: sandbox unavailable" in result
|
||||
ctx.context.context.send_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_cannot_send_arbitrary_local_absolute_file(tmp_path):
|
||||
"""Non-admin users cannot send host files outside the allowed local roots."""
|
||||
tool = SendMessageToUserTool()
|
||||
ctx = _make_context(role="member", require_admin=True)
|
||||
secret_path = tmp_path / "secret.txt"
|
||||
secret_path.write_text("secret", encoding="utf-8")
|
||||
|
||||
result = await tool.call(
|
||||
ctx,
|
||||
messages=[{"type": "file", "path": str(secret_path)}],
|
||||
)
|
||||
|
||||
assert "error: Local file send is restricted for this user" in result
|
||||
assert str(secret_path) in result
|
||||
ctx.context.context.send_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_can_send_workspace_file(tmp_path, monkeypatch):
|
||||
"""Non-admin users can send files inside their per-session workspace."""
|
||||
tool = SendMessageToUserTool()
|
||||
ctx = _make_context(
|
||||
current_session="feishu:GroupMessage:oc_workspace",
|
||||
role="member",
|
||||
require_admin=True,
|
||||
)
|
||||
workspace_root = tmp_path / "workspaces"
|
||||
workspace_file = workspace_root / "feishu_GroupMessage_oc_workspace" / "result.txt"
|
||||
workspace_file.parent.mkdir(parents=True)
|
||||
workspace_file.write_text("result", encoding="utf-8")
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.tools.computer_tools.util.get_astrbot_workspaces_path",
|
||||
lambda: str(workspace_root),
|
||||
)
|
||||
|
||||
result = await tool.call(
|
||||
ctx,
|
||||
messages=[{"type": "file", "path": "result.txt"}],
|
||||
)
|
||||
|
||||
assert "Message sent to session" in result
|
||||
ctx.context.context.send_message.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_can_send_temp_file(tmp_path, monkeypatch):
|
||||
"""Non-admin users can send generated files under AstrBot temp."""
|
||||
tool = SendMessageToUserTool()
|
||||
ctx = _make_context(role="member", require_admin=True)
|
||||
temp_root = tmp_path / "temp"
|
||||
temp_root.mkdir()
|
||||
output_path = temp_root / "output.txt"
|
||||
output_path.write_text("output", encoding="utf-8")
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.tools.message_tools.get_astrbot_temp_path",
|
||||
lambda: str(temp_root),
|
||||
)
|
||||
|
||||
result = await tool.call(
|
||||
ctx,
|
||||
messages=[{"type": "file", "path": str(output_path)}],
|
||||
)
|
||||
|
||||
assert "Message sent to session" in result
|
||||
ctx.context.context.send_message.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user