mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 18:20:16 +08:00
Compare commits
13 Commits
refactor/b
...
v4.25.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af70151ff8 | ||
|
|
66ec415e56 | ||
|
|
8f5178d265 | ||
|
|
05c137eb29 | ||
|
|
1a04998787 | ||
|
|
c4251e8210 | ||
|
|
66a10c08b2 | ||
|
|
c7e9d5b481 | ||
|
|
0db7fc9b39 | ||
|
|
556903c135 | ||
|
|
bdc32bb78c | ||
|
|
c70a1924fe | ||
|
|
6ae103a24f |
@@ -3,7 +3,6 @@ from typing import Any
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
@@ -70,37 +69,6 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
tool_result,
|
||||
)
|
||||
|
||||
# special handle web_search_tavily
|
||||
platform_name = run_context.context.event.get_platform_name()
|
||||
if (
|
||||
platform_name == "webchat"
|
||||
and tool.name
|
||||
in [
|
||||
"web_search_baidu",
|
||||
"web_search_tavily",
|
||||
"web_search_bocha",
|
||||
"web_search_brave",
|
||||
]
|
||||
and len(run_context.messages) > 0
|
||||
and tool_result
|
||||
and len(tool_result.content)
|
||||
):
|
||||
# inject system prompt
|
||||
first_part = run_context.messages[0]
|
||||
if (
|
||||
isinstance(first_part, Message)
|
||||
and first_part.role == "system"
|
||||
and first_part.content
|
||||
and isinstance(first_part.content, str)
|
||||
):
|
||||
# we assume system part is str
|
||||
first_part.content += (
|
||||
"Always cite web search results you rely on. "
|
||||
"Index is a unique identifier for each search result. "
|
||||
"Use the exact citation format <ref>index</ref> (e.g. <ref>abcd.3</ref>) "
|
||||
"after the sentence that uses the information. Do not invent citations."
|
||||
)
|
||||
|
||||
|
||||
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
pass
|
||||
|
||||
@@ -115,6 +115,20 @@ from astrbot.core.utils.quoted_message_parser import (
|
||||
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
|
||||
|
||||
LLM_ERROR_MESSAGE_EXTRA_KEY = "_llm_error_message"
|
||||
WEB_SEARCH_CITATION_TOOL_NAMES = frozenset(
|
||||
{
|
||||
"web_search_baidu",
|
||||
"web_search_tavily",
|
||||
"web_search_bocha",
|
||||
"web_search_brave",
|
||||
}
|
||||
)
|
||||
WEB_SEARCH_CITATION_PROMPT = (
|
||||
"Always cite web search results you rely on. "
|
||||
"Index is a unique identifier for each search result. "
|
||||
"Use the exact citation format <ref>index</ref> (e.g. <ref>abcd.3</ref>) "
|
||||
"after the sentence that uses the information. Do not invent citations."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -1149,6 +1163,23 @@ async def _apply_web_search_tools(
|
||||
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool))
|
||||
|
||||
|
||||
def _apply_web_search_citation_prompt(
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
) -> None:
|
||||
if event.get_platform_name() != "webchat" or not req.func_tool:
|
||||
return
|
||||
|
||||
if not any(req.func_tool.get_tool(name) for name in WEB_SEARCH_CITATION_TOOL_NAMES):
|
||||
return
|
||||
|
||||
system_prompt = req.system_prompt or ""
|
||||
if WEB_SEARCH_CITATION_PROMPT in system_prompt:
|
||||
return
|
||||
|
||||
req.system_prompt = f"{system_prompt}\n{WEB_SEARCH_CITATION_PROMPT}\n"
|
||||
|
||||
|
||||
def _get_compress_provider(
|
||||
config: MainAgentBuildConfig,
|
||||
plugin_context: Context,
|
||||
@@ -1520,6 +1551,8 @@ async def build_main_agent(
|
||||
if action_type == "live":
|
||||
req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n"
|
||||
|
||||
_apply_web_search_citation_prompt(event, req)
|
||||
|
||||
reset_coro = agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from astrbot.core.computer.booters.cua_defaults import CUA_DEFAULT_CONFIG
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.25.3"
|
||||
VERSION = "4.25.5"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
PERSONAL_WECHAT_CONFIG_METADATA = {
|
||||
"weixin_oc_base_url": {
|
||||
@@ -417,7 +417,7 @@ CONFIG_METADATA_2 = {
|
||||
"weixin_oc_bot_type": "3",
|
||||
"weixin_oc_qr_poll_interval": 1,
|
||||
"weixin_oc_long_poll_timeout_ms": 35_000,
|
||||
"weixin_oc_api_timeout_ms": 15_000,
|
||||
"weixin_oc_api_timeout_ms": 120_000,
|
||||
},
|
||||
"飞书(Lark)": {
|
||||
"id": "lark",
|
||||
|
||||
@@ -5,10 +5,7 @@ from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
ApiKey,
|
||||
@@ -32,19 +29,6 @@ from astrbot.core.db.po import (
|
||||
)
|
||||
|
||||
|
||||
def _configure_sqlite_connection(dbapi_connection, connection_record) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA cache_size=20000")
|
||||
cursor.execute("PRAGMA temp_store=MEMORY")
|
||||
cursor.execute("PRAGMA mmap_size=134217728")
|
||||
cursor.execute("PRAGMA optimize")
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
"""数据库基类"""
|
||||
@@ -57,29 +41,14 @@ class BaseDatabase(abc.ABC):
|
||||
# second write is attempted. Setting timeout=30 tells SQLite to
|
||||
# wait up to 30 s for the lock, which is enough to ride out brief
|
||||
# write bursts from concurrent agent/metrics/session operations.
|
||||
db_url = make_url(self.DATABASE_URL)
|
||||
is_sqlite = db_url.get_backend_name() == "sqlite"
|
||||
is_sqlite = "sqlite" in self.DATABASE_URL
|
||||
connect_args = {"timeout": 30} if is_sqlite else {}
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
"connect_args": connect_args,
|
||||
}
|
||||
if is_sqlite:
|
||||
# Keep SQLite async engines off SQLAlchemy's default async queue
|
||||
# pool so packaged runtimes don't depend on dialect-specific pool
|
||||
# event support.
|
||||
engine_kwargs["poolclass"] = NullPool
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
**engine_kwargs,
|
||||
echo=False,
|
||||
future=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
if is_sqlite:
|
||||
event.listen(
|
||||
self.engine.sync_engine,
|
||||
"connect",
|
||||
_configure_sqlite_connection,
|
||||
)
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
|
||||
@@ -53,6 +53,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA busy_timeout=30000"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
|
||||
@@ -5,11 +5,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import Column, Text, bindparam
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.schema import CreateTable
|
||||
from sqlmodel import Field, MetaData, SQLModel, col, func, select, text
|
||||
|
||||
from astrbot.core import logger
|
||||
@@ -63,7 +60,8 @@ class DocumentStorage:
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
await self.connect()
|
||||
async with self.engine.begin() as conn: # type: ignore
|
||||
await self._ensure_documents_table(conn)
|
||||
# Create tables using SQLModel
|
||||
await conn.run_sync(BaseDocModel.metadata.create_all)
|
||||
|
||||
try:
|
||||
await conn.execute(
|
||||
@@ -93,59 +91,15 @@ class DocumentStorage:
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_documents_doc_id_unique ON documents(doc_id)",
|
||||
),
|
||||
)
|
||||
|
||||
await self._initialize_fts5(conn)
|
||||
await conn.commit()
|
||||
|
||||
async def _ensure_documents_table(self, executor) -> None:
|
||||
"""Create the document table from the SQLModel definition."""
|
||||
result = await executor.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT 1
|
||||
FROM sqlite_master
|
||||
WHERE type='table' AND name=:table_name
|
||||
LIMIT 1
|
||||
""",
|
||||
),
|
||||
{"table_name": Document.__tablename__},
|
||||
)
|
||||
if result.scalar_one_or_none() is not None:
|
||||
await self._ensure_doc_id_unique_index(executor)
|
||||
return
|
||||
|
||||
create_table = CreateTable(Document.__table__, if_not_exists=True) # type: ignore[attr-defined]
|
||||
|
||||
await executor.execute(
|
||||
text(str(create_table.compile(dialect=sqlite.dialect())))
|
||||
)
|
||||
await self._ensure_doc_id_unique_index(executor)
|
||||
|
||||
async def _ensure_doc_id_unique_index(self, executor) -> None:
|
||||
duplicate_result = await executor.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT doc_id
|
||||
FROM documents
|
||||
GROUP BY doc_id
|
||||
HAVING COUNT(*) > 1
|
||||
LIMIT 1
|
||||
""",
|
||||
),
|
||||
)
|
||||
if duplicate_result.scalar_one_or_none() is not None:
|
||||
logger.warning(
|
||||
"Skipping documents.doc_id unique index migration because duplicate "
|
||||
f"doc_id values already exist in {self.db_path}.",
|
||||
)
|
||||
return
|
||||
|
||||
await executor.execute(
|
||||
text(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS "
|
||||
"idx_documents_doc_id_unique ON documents(doc_id)",
|
||||
),
|
||||
)
|
||||
|
||||
async def _initialize_fts5(self, executor) -> None:
|
||||
try:
|
||||
await self._create_fts5_table(executor, if_not_exists=True)
|
||||
@@ -249,7 +203,6 @@ class DocumentStorage:
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
poolclass=NullPool,
|
||||
)
|
||||
self.async_session_maker = sessionmaker(
|
||||
self.engine, # type: ignore
|
||||
|
||||
@@ -33,6 +33,8 @@ class EventBus:
|
||||
# abconf uuid -> scheduler
|
||||
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
# 持有正在执行的 pipeline 任务的强引用, 防止 task 在 pending 状态被 GC 回收
|
||||
self._pending_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
while True:
|
||||
@@ -47,7 +49,18 @@ class EventBus:
|
||||
f"PipelineScheduler not found for id: {conf_id}, event ignored."
|
||||
)
|
||||
continue
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
task = asyncio.create_task(scheduler.execute(event))
|
||||
self._pending_tasks.add(task)
|
||||
task.add_done_callback(self._on_task_done)
|
||||
|
||||
def _on_task_done(self, task: asyncio.Task) -> None:
|
||||
"""pipeline 任务结束回调: 移除强引用并暴露未捕获的异常"""
|
||||
self._pending_tasks.discard(task)
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
logger.error("pipeline 任务执行异常", exc_info=exc)
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None:
|
||||
"""用于记录事件信息
|
||||
|
||||
@@ -2,9 +2,8 @@ from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import delete, event, func, select, text, update
|
||||
from sqlalchemy import delete, func, select, text, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlmodel import col, desc
|
||||
|
||||
from astrbot.core import logger
|
||||
@@ -20,19 +19,6 @@ if TYPE_CHECKING:
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
|
||||
def _configure_sqlite_connection(dbapi_connection, connection_record) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA cache_size=20000")
|
||||
cursor.execute("PRAGMA temp_store=MEMORY")
|
||||
cursor.execute("PRAGMA mmap_size=134217728")
|
||||
cursor.execute("PRAGMA optimize")
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
||||
class KBSQLiteDatabase:
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
"""初始化知识库数据库
|
||||
@@ -54,12 +40,8 @@ class KBSQLiteDatabase:
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
poolclass=NullPool,
|
||||
)
|
||||
event.listen(
|
||||
self.engine.sync_engine,
|
||||
"connect",
|
||||
_configure_sqlite_connection,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
|
||||
@@ -5,8 +5,6 @@ import base64
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import replace
|
||||
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from astrbot.core import db_helper, logger
|
||||
from astrbot.core.agent.message import (
|
||||
CheckpointData,
|
||||
@@ -521,15 +519,6 @@ class InternalAgentSubStage(Stage):
|
||||
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}
|
||||
decoded_blocked = [base64.b64decode(b).decode("utf-8") for b in BLOCKED]
|
||||
|
||||
PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS = 3
|
||||
PROVIDER_STATS_SQLITE_LOCK_RETRY_BASE_DELAY = 0.2
|
||||
|
||||
|
||||
def _is_sqlite_database_locked_error(exc: OperationalError) -> bool:
|
||||
raw = getattr(exc, "orig", exc)
|
||||
message = str(raw).lower()
|
||||
return "database" in message and "locked" in message
|
||||
|
||||
|
||||
async def _record_internal_agent_stats(
|
||||
event: AstrMessageEvent,
|
||||
@@ -560,35 +549,15 @@ async def _record_internal_agent_stats(
|
||||
status = "error"
|
||||
else:
|
||||
status = "completed"
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
await db_helper.insert_provider_stat(
|
||||
umo=event.unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
provider_id=provider_config.get("id", "") or provider.meta().id,
|
||||
provider_model=provider.get_model(),
|
||||
status=status,
|
||||
stats=stats.to_dict(),
|
||||
agent_type="internal",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Persist provider stats failed: %s", e, exc_info=True)
|
||||
return
|
||||
|
||||
for attempt in range(PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS):
|
||||
last_attempt = attempt == PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS - 1
|
||||
try:
|
||||
await db_helper.insert_provider_stat(
|
||||
umo=event.unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
provider_id=provider_config.get("id", "") or provider.meta().id,
|
||||
provider_model=provider.get_model(),
|
||||
status=status,
|
||||
stats=stats.to_dict(),
|
||||
agent_type="internal",
|
||||
)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except OperationalError as e:
|
||||
if _is_sqlite_database_locked_error(e) and not last_attempt:
|
||||
await asyncio.sleep(
|
||||
PROVIDER_STATS_SQLITE_LOCK_RETRY_BASE_DELAY * (2**attempt)
|
||||
)
|
||||
continue
|
||||
logger.warning("Persist provider stats failed: %s", e, exc_info=True)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Persist provider stats failed: %s", e, exc_info=True)
|
||||
break
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from binascii import Error as BinasciiError
|
||||
from typing import cast
|
||||
|
||||
import quart
|
||||
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -13,6 +16,57 @@ from astrbot.api import logger
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
_SIGNATURE_HEADER = "X-Signature-Ed25519"
|
||||
_SIGNATURE_TIMESTAMP_HEADER = "X-Signature-Timestamp"
|
||||
_ED25519_SEED_SIZE = 32
|
||||
_ED25519_SIGNATURE_SIZE = 64
|
||||
|
||||
|
||||
def _build_ed25519_seed(secret: str) -> bytes:
|
||||
if not secret:
|
||||
raise ValueError("QQ official bot secret is empty.")
|
||||
|
||||
seed = secret.encode("utf-8")
|
||||
while len(seed) < _ED25519_SEED_SIZE:
|
||||
seed *= 2
|
||||
return seed[:_ED25519_SEED_SIZE]
|
||||
|
||||
|
||||
def _sign_qq_webhook_payload(secret: str, timestamp: str, payload: bytes) -> str:
|
||||
seed = _build_ed25519_seed(secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
return private_key.sign(timestamp.encode("utf-8") + payload).hex()
|
||||
|
||||
|
||||
def _verify_qq_webhook_signature(
|
||||
secret: str,
|
||||
timestamp: str | None,
|
||||
signature: str | None,
|
||||
body: bytes,
|
||||
) -> bool:
|
||||
if not timestamp or not signature:
|
||||
return False
|
||||
|
||||
try:
|
||||
signature_buffer = bytes.fromhex(signature)
|
||||
except (BinasciiError, ValueError):
|
||||
return False
|
||||
|
||||
if (
|
||||
len(signature_buffer) != _ED25519_SIGNATURE_SIZE
|
||||
or signature_buffer[63] & 224 != 0
|
||||
):
|
||||
return False
|
||||
|
||||
try:
|
||||
seed = _build_ed25519_seed(secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
public_key = private_key.public_key()
|
||||
public_key.verify(signature_buffer, timestamp.encode("utf-8") + body)
|
||||
except (InvalidSignature, ValueError):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class QQOfficialWebhook:
|
||||
def __init__(
|
||||
@@ -27,7 +81,12 @@ class QQOfficialWebhook:
|
||||
if isinstance(self.port, str):
|
||||
self.port = int(self.port)
|
||||
|
||||
self.http: BotHttp = BotHttp(timeout=300, is_sandbox=self.is_sandbox)
|
||||
self.http: BotHttp = BotHttp(
|
||||
timeout=300,
|
||||
is_sandbox=self.is_sandbox,
|
||||
app_id=self.appid,
|
||||
secret=self.secret,
|
||||
)
|
||||
self.api: BotAPI = BotAPI(http=self.http)
|
||||
self.token = Token(self.appid, self.secret)
|
||||
|
||||
@@ -40,6 +99,7 @@ class QQOfficialWebhook:
|
||||
self.client = botpy_client
|
||||
self.event_queue = event_queue
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self._connection: ConnectionSession | None = None
|
||||
|
||||
# Cache for extra fields extracted from raw webhook payloads, keyed by message id
|
||||
self._extra_data_cache: dict[str, dict] = {}
|
||||
@@ -55,6 +115,13 @@ class QQOfficialWebhook:
|
||||
# 直接注入到 botpy 的 Client,移花接木!
|
||||
self.client.api = self.api
|
||||
self.client.http = self.http
|
||||
self._setup_connection()
|
||||
|
||||
def _setup_connection(self) -> None:
|
||||
if self._connection is not None:
|
||||
return
|
||||
self.client.api = self.api
|
||||
self.client.http = self.http
|
||||
|
||||
async def bot_connect() -> None:
|
||||
pass
|
||||
@@ -105,7 +172,24 @@ class QQOfficialWebhook:
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
msg: dict = await request.json
|
||||
body = await request.get_data()
|
||||
if not _verify_qq_webhook_signature(
|
||||
self.secret,
|
||||
request.headers.get(_SIGNATURE_TIMESTAMP_HEADER),
|
||||
request.headers.get(_SIGNATURE_HEADER),
|
||||
body,
|
||||
):
|
||||
logger.warning("qq_official_webhook signature verification failed.")
|
||||
return {"error": "Invalid signature"}, 401
|
||||
|
||||
try:
|
||||
msg = json.loads(body.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("qq_official_webhook callback body is not valid JSON.")
|
||||
return {"error": "Invalid JSON"}, 400
|
||||
if not isinstance(msg, dict):
|
||||
return {"error": "Invalid JSON"}, 400
|
||||
|
||||
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
|
||||
|
||||
event = msg.get("t")
|
||||
@@ -136,6 +220,13 @@ class QQOfficialWebhook:
|
||||
|
||||
if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
|
||||
event = msg["t"].lower()
|
||||
if self._connection is None:
|
||||
logger.warning(
|
||||
"qq_official_webhook botpy connection is not initialized; "
|
||||
"creating parser connection lazily.",
|
||||
)
|
||||
self._setup_connection()
|
||||
|
||||
# Extract extra fields from raw payload before botpy parses and discards them
|
||||
if data:
|
||||
msg_id = data.get("id")
|
||||
|
||||
@@ -130,7 +130,7 @@ class WeixinOCAdapter(Platform):
|
||||
platform_config.get("weixin_oc_long_poll_timeout_ms", 35_000),
|
||||
)
|
||||
self.api_timeout_ms = int(
|
||||
platform_config.get("weixin_oc_api_timeout_ms", 15_000),
|
||||
platform_config.get("weixin_oc_api_timeout_ms", 120_000),
|
||||
)
|
||||
self.cdn_base_url = str(
|
||||
platform_config.get(
|
||||
|
||||
@@ -302,12 +302,14 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
def _extract_usage(self, usage: Usage) -> TokenUsage:
|
||||
def _extract_usage(self, usage: Usage | None) -> TokenUsage:
|
||||
if usage is None:
|
||||
return TokenUsage()
|
||||
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
|
||||
return TokenUsage(
|
||||
input_other=usage.input_tokens or 0,
|
||||
input_cached=usage.cache_read_input_tokens or 0,
|
||||
output=usage.output_tokens,
|
||||
output=usage.output_tokens or 0,
|
||||
)
|
||||
|
||||
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
@@ -9,6 +9,9 @@ from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
||||
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
@@ -17,11 +20,9 @@ class Star(CommandParserMixin, PluginKVStoreMixin):
|
||||
|
||||
author: str
|
||||
name: str
|
||||
context: Context
|
||||
|
||||
class _ContextLike(Protocol):
|
||||
def get_config(self, umo: str | None = None) -> Any: ...
|
||||
|
||||
def __init__(self, context: _ContextLike, config: dict | None = None) -> None:
|
||||
def __init__(self, context: Context, config: dict | None = None) -> None:
|
||||
self.context = context
|
||||
|
||||
def _get_context_config(self) -> Any:
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import shlex
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
@@ -14,9 +15,55 @@ from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.computer.computer_client import get_booter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.tools.computer_tools.util import check_admin_permission
|
||||
from astrbot.core.tools.computer_tools.util import (
|
||||
check_admin_permission,
|
||||
is_local_runtime,
|
||||
workspace_root,
|
||||
)
|
||||
from astrbot.core.tools.registry import builtin_tool
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_system_tmp_path,
|
||||
get_astrbot_temp_path,
|
||||
)
|
||||
|
||||
|
||||
def _file_send_allowed_roots(umo: str | None) -> tuple[Path, ...]:
|
||||
roots = []
|
||||
if umo:
|
||||
roots.append(workspace_root(umo))
|
||||
roots.extend(
|
||||
[
|
||||
Path(get_astrbot_temp_path()).resolve(strict=False),
|
||||
Path(get_astrbot_system_tmp_path()).resolve(strict=False),
|
||||
]
|
||||
)
|
||||
return tuple(roots)
|
||||
|
||||
|
||||
def _is_path_within(path: Path, roots: tuple[Path, ...]) -> bool:
|
||||
return any(path == root or path.is_relative_to(root) for root in roots)
|
||||
|
||||
|
||||
def _is_restricted_local_env(context: ContextWrapper[AstrAgentContext]) -> bool:
|
||||
if not is_local_runtime(context):
|
||||
return False
|
||||
cfg = context.context.context.get_config(
|
||||
umo=context.context.event.unified_msg_origin
|
||||
)
|
||||
provider_settings = cfg.get("provider_settings", {})
|
||||
require_admin = provider_settings.get("computer_use_require_admin", True)
|
||||
return require_admin and context.context.event.role != "admin"
|
||||
|
||||
|
||||
def _can_send_local_file(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
local_path: Path,
|
||||
) -> bool:
|
||||
umo = context.context.event.unified_msg_origin
|
||||
allowed_roots = _file_send_allowed_roots(umo)
|
||||
if _is_path_within(local_path, allowed_roots):
|
||||
return True
|
||||
return is_local_runtime(context) and not _is_restricted_local_env(context)
|
||||
|
||||
|
||||
@builtin_tool
|
||||
@@ -85,23 +132,38 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
||||
*,
|
||||
component_type: str = "file",
|
||||
) -> tuple[str, bool]:
|
||||
path = str(path)
|
||||
# if the path is relative, check if the file exists in user's local workspace
|
||||
path = str(path).strip()
|
||||
if not path:
|
||||
raise FileNotFoundError(f"{component_type} path is empty")
|
||||
|
||||
# Relative host paths are resolved only inside the user's workspace.
|
||||
if not os.path.isabs(path):
|
||||
unified_msg_origin = context.context.event.unified_msg_origin
|
||||
if unified_msg_origin:
|
||||
from astrbot.core.tools.computer_tools.util import workspace_root
|
||||
|
||||
try:
|
||||
ws_path = workspace_root(unified_msg_origin)
|
||||
ws_candidate = (ws_path / path).resolve()
|
||||
ws_candidate = (ws_path / path).resolve(strict=False)
|
||||
if ws_candidate.is_file() and ws_candidate.is_relative_to(ws_path):
|
||||
return str(ws_candidate), False
|
||||
except Exception:
|
||||
pass
|
||||
# check if the file exists in local environment (only allow absolute paths to prevent traversal)
|
||||
elif os.path.isfile(path):
|
||||
return path, False
|
||||
else:
|
||||
local_candidate = Path(path).expanduser().resolve(strict=False)
|
||||
if local_candidate.is_file():
|
||||
if _can_send_local_file(context, local_candidate):
|
||||
return str(local_candidate), False
|
||||
if is_local_runtime(context):
|
||||
allowed = ", ".join(
|
||||
str(root)
|
||||
for root in _file_send_allowed_roots(
|
||||
context.context.event.unified_msg_origin
|
||||
)
|
||||
)
|
||||
raise PermissionError(
|
||||
"Local file send is restricted for this user. "
|
||||
f"Allowed directories: {allowed}. "
|
||||
f"Blocked path: {local_candidate}."
|
||||
)
|
||||
|
||||
try:
|
||||
sb = await get_booter(
|
||||
@@ -221,6 +283,8 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
return f"error: {exc}"
|
||||
except PermissionError as exc:
|
||||
return f"error: {exc}"
|
||||
except Exception as exc:
|
||||
return f"error: failed to build messages[{idx}] component: {exc}"
|
||||
|
||||
|
||||
35
changelogs/v4.25.4.md
Normal file
35
changelogs/v4.25.4.md
Normal file
@@ -0,0 +1,35 @@
|
||||
- [更新日志(简体中文)](#chinese)
|
||||
- [Changelog(English)](#english)
|
||||
|
||||
<a id="chinese"></a>
|
||||
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
|
||||
- 回滚部分改动,修复偶现的 `Database is locked` 的问题。([#8639](https://github.com/AstrBotDevs/AstrBot/pull/8639))
|
||||
- 修复 Pipeline 异步任务可能因缺少强引用而被垃圾回收的问题,提升事件处理稳定性。([#8618](https://github.com/AstrBotDevs/AstrBot/pull/8618))
|
||||
- 修复 WebChat 使用 Web 搜索工具时,引用提示词在同一轮对话多次工具调用后被重复追加到系统消息的问题,避免破坏上下文缓存。([#8642](https://github.com/AstrBotDevs/AstrBot/pull/8642))
|
||||
- 同步 Dashboard `pnpm-lock.yaml` 中的 overrides 配置,修复锁文件与工作区配置不一致的问题。([#8637](https://github.com/AstrBotDevs/AstrBot/pull/8637))
|
||||
|
||||
### 优化
|
||||
|
||||
- 将微信公众号 HTTP API 请求超时时间从 15 秒提升到 120 秒,降低较慢网络或接口响应下下载文件超时失败概率。([#8643](https://github.com/AstrBotDevs/AstrBot/pull/8643))
|
||||
- Dashboard 登录表单启用完整凭据自动填充,改善浏览器密码管理器的使用体验。([#8631](https://github.com/AstrBotDevs/AstrBot/pull/8631))
|
||||
|
||||
<a id="english"></a>
|
||||
|
||||
## What's Changed (EN)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fixed repeated Web search citation prompt appends in WebChat after multiple tool calls within the same interaction, preventing context cache invalidation. ([#8642](https://github.com/AstrBotDevs/AstrBot/pull/8642))
|
||||
- Fixed Pipeline async tasks potentially being garbage-collected due to missing strong references, improving event processing stability. ([#8618](https://github.com/AstrBotDevs/AstrBot/pull/8618))
|
||||
- Synced Dashboard `pnpm-lock.yaml` overrides with the workspace configuration. ([#8637](https://github.com/AstrBotDevs/AstrBot/pull/8637))
|
||||
- Reverted the Provider stats SQLite lock retry change to avoid related regressions. ([#8639](https://github.com/AstrBotDevs/AstrBot/pull/8639))
|
||||
- Reverted the macOS SQLAlchemy compatibility changes to avoid regressions in database initialization and vector storage paths. ([#8638](https://github.com/AstrBotDevs/AstrBot/pull/8638))
|
||||
|
||||
### Improvements
|
||||
|
||||
- Increased the WeChat Official Account HTTP API request timeout from 15 seconds to 120 seconds, reducing timeout failures on slower networks or API responses. ([#8643](https://github.com/AstrBotDevs/AstrBot/pull/8643))
|
||||
- Enabled full credential autofill on the Dashboard login form for better browser password manager support. ([#8631](https://github.com/AstrBotDevs/AstrBot/pull/8631))
|
||||
30
changelogs/v4.25.5.md
Normal file
30
changelogs/v4.25.5.md
Normal file
@@ -0,0 +1,30 @@
|
||||
- [更新日志(简体中文)](#chinese)
|
||||
- [Changelog(English)](#english)
|
||||
|
||||
<a id="chinese"></a>
|
||||
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
|
||||
- 收紧消息工具对本地文件路径的处理边界,减少非预期路径被用于消息附件的情况。([#8660](https://github.com/AstrBotDevs/AstrBot/pull/8660))
|
||||
- 修复 Star Context 类型定义,恢复相关 SDK 类型提示与运行兼容性。([#8659](https://github.com/AstrBotDevs/AstrBot/pull/8659))
|
||||
- 修复 QQ 官方 Webhook 模式无法正常重启的问题。
|
||||
|
||||
### 优化
|
||||
|
||||
- 改进 Anthropic 在内容过滤响应中缺失 `usage` 字段时的处理,避免相关请求结果解析异常。([#8647](https://github.com/AstrBotDevs/AstrBot/pull/8647))
|
||||
|
||||
<a id="english"></a>
|
||||
|
||||
## What's Changed (EN)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Tightened local file path handling in the message tool to avoid unintended attachment path usage. ([#8660](https://github.com/AstrBotDevs/AstrBot/pull/8660))
|
||||
- Fixed Star Context typing to restore related SDK type hints and runtime compatibility. ([#8659](https://github.com/AstrBotDevs/AstrBot/pull/8659))
|
||||
- Fixed QQ Official Webhook mode not restarting correctly.
|
||||
|
||||
### Improvements
|
||||
|
||||
- Improved Anthropic response parsing when content-filtered responses omit the `usage` field. ([#8647](https://github.com/AstrBotDevs/AstrBot/pull/8647))
|
||||
1
dashboard/pnpm-lock.yaml
generated
1
dashboard/pnpm-lock.yaml
generated
@@ -998,6 +998,7 @@ packages:
|
||||
|
||||
'@ungap/structured-clone@1.3.0':
|
||||
resolution: {integrity: sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==}
|
||||
deprecated: Potential CWE-502 - Update to 1.3.1 or higher
|
||||
|
||||
'@vitejs/plugin-vue@5.2.4':
|
||||
resolution: {integrity: sha512-7Yx/SXSOcQq5HiiV3orevHUFn+pmMB4cgbEkDYgnkUWb0WfeQ/wa2yFv6D5ICiCQOVpjA7vYDXrC7AGO8yjDHA==}
|
||||
|
||||
3
dashboard/pnpm-workspace.yaml
Normal file
3
dashboard/pnpm-workspace.yaml
Normal file
@@ -0,0 +1,3 @@
|
||||
allowBuilds:
|
||||
esbuild: true
|
||||
vue-demi: true
|
||||
@@ -27,6 +27,7 @@ function onSubmit() {
|
||||
<v-text-field
|
||||
:model-value="props.username"
|
||||
:label="t('username')"
|
||||
autocomplete="username"
|
||||
class="mb-6 input-field"
|
||||
required
|
||||
hide-details="auto"
|
||||
@@ -40,6 +41,7 @@ function onSubmit() {
|
||||
<v-text-field
|
||||
:model-value="props.password"
|
||||
:label="t('password')"
|
||||
autocomplete="current-password"
|
||||
required
|
||||
variant="outlined"
|
||||
hide-details="auto"
|
||||
|
||||
3
main.py
3
main.py
@@ -54,9 +54,6 @@ def check_env() -> None:
|
||||
|
||||
site_packages_path = get_astrbot_site_packages_path()
|
||||
if not is_packaged_desktop_runtime() and site_packages_path not in sys.path:
|
||||
# Packaged desktop runtime keeps shared plugin dependencies out of the
|
||||
# global import path so bundled core libraries don't mix with user-
|
||||
# installed wheels from ~/.astrbot/data/site-packages.
|
||||
sys.path.append(site_packages_path)
|
||||
|
||||
os.makedirs(get_astrbot_config_path(), exist_ok=True)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.25.3"
|
||||
version = "4.25.5"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
license = { text = "AGPL-3.0-or-later" }
|
||||
|
||||
@@ -483,6 +483,40 @@ def _setup_provider_with_mock_client(monkeypatch) -> anthropic_source.ProviderAn
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_handles_none_usage_when_content_filtered(monkeypatch):
|
||||
provider = _setup_provider_with_mock_client(monkeypatch)
|
||||
content_filter_message = (
|
||||
"The request was rejected because it was considered high risk"
|
||||
)
|
||||
|
||||
class _FakeMessageBlock:
|
||||
def __init__(self, text: str):
|
||||
self.type = "text"
|
||||
self.text = text
|
||||
|
||||
class _FakeMessage:
|
||||
def __init__(self):
|
||||
self.id = "msg_content_filter"
|
||||
self.content = [_FakeMessageBlock(content_filter_message)]
|
||||
self.stop_reason = "content_filter"
|
||||
self.usage = None
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
return _FakeMessage()
|
||||
|
||||
monkeypatch.setattr(anthropic_source, "Message", _FakeMessage)
|
||||
provider.client.messages.create = fake_create
|
||||
|
||||
llm_response = await provider.text_chat(prompt="test")
|
||||
|
||||
assert llm_response.completion_text == content_filter_message
|
||||
assert llm_response.usage is not None
|
||||
assert llm_response.usage.input_other == 0
|
||||
assert llm_response.usage.input_cached == 0
|
||||
assert llm_response.usage.output == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_choice_auto_converts_to_dict(monkeypatch):
|
||||
"""tool_choice='auto' 应转换为 {'type': 'auto'}"""
|
||||
|
||||
124
tests/test_qqofficial_webhook_signature.py
Normal file
124
tests/test_qqofficial_webhook_signature.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.platform.sources.qqofficial_webhook.qo_webhook_server import (
|
||||
_SIGNATURE_HEADER,
|
||||
_SIGNATURE_TIMESTAMP_HEADER,
|
||||
QQOfficialWebhook,
|
||||
_sign_qq_webhook_payload,
|
||||
_verify_qq_webhook_signature,
|
||||
)
|
||||
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, body: bytes, headers: dict[str, str] | None = None) -> None:
|
||||
self._body = body
|
||||
self.headers = headers or {}
|
||||
|
||||
async def get_data(self) -> bytes:
|
||||
return self._body
|
||||
|
||||
|
||||
class FakeBotpyClient:
|
||||
api = None
|
||||
http = None
|
||||
|
||||
def ws_dispatch(self, *_args, **_kwargs) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_qq_webhook_signature_verification_accepts_valid_signature():
|
||||
secret = "test-secret"
|
||||
timestamp = "1710000000"
|
||||
body = b'{"op":12,"d":0}'
|
||||
signature = _sign_qq_webhook_payload(secret, timestamp, body)
|
||||
|
||||
assert _verify_qq_webhook_signature(secret, timestamp, signature, body)
|
||||
|
||||
|
||||
def test_qq_webhook_signature_verification_rejects_tampered_body():
|
||||
secret = "test-secret"
|
||||
timestamp = "1710000000"
|
||||
body = b'{"op":12,"d":0}'
|
||||
signature = _sign_qq_webhook_payload(secret, timestamp, body)
|
||||
|
||||
assert not _verify_qq_webhook_signature(
|
||||
secret,
|
||||
timestamp,
|
||||
signature,
|
||||
b'{"op":12,"d":1}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qq_webhook_callback_rejects_missing_signature():
|
||||
webhook = object.__new__(QQOfficialWebhook)
|
||||
webhook.secret = "test-secret"
|
||||
|
||||
result = await webhook.handle_callback(FakeRequest(b'{"op":12,"d":0}'))
|
||||
|
||||
assert result == ({"error": "Invalid signature"}, 401)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qq_webhook_callback_accepts_signed_validation():
|
||||
secret = "test-secret"
|
||||
event_ts = "1710000000"
|
||||
plain_token = "plain-token"
|
||||
body = json.dumps(
|
||||
{"op": 13, "d": {"event_ts": event_ts, "plain_token": plain_token}},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
signature = _sign_qq_webhook_payload(secret, event_ts, body)
|
||||
webhook = object.__new__(QQOfficialWebhook)
|
||||
webhook.secret = secret
|
||||
|
||||
result = await webhook.handle_callback(
|
||||
FakeRequest(
|
||||
body,
|
||||
{
|
||||
_SIGNATURE_TIMESTAMP_HEADER: event_ts,
|
||||
_SIGNATURE_HEADER: signature,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"plain_token": plain_token,
|
||||
"signature": _sign_qq_webhook_payload(secret, event_ts, plain_token.encode()),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qq_webhook_callback_lazily_creates_botpy_connection():
|
||||
secret = "test-secret"
|
||||
timestamp = "1710000000"
|
||||
body = json.dumps(
|
||||
{"op": 0, "t": "UNKNOWN_EVENT", "id": "event-id", "d": {"id": "message-id"}},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
signature = _sign_qq_webhook_payload(secret, timestamp, body)
|
||||
webhook = QQOfficialWebhook(
|
||||
{"appid": "123", "secret": secret},
|
||||
asyncio.Queue(),
|
||||
FakeBotpyClient(),
|
||||
)
|
||||
|
||||
result = await webhook.handle_callback(
|
||||
FakeRequest(
|
||||
body,
|
||||
{
|
||||
_SIGNATURE_TIMESTAMP_HEADER: timestamp,
|
||||
_SIGNATURE_HEADER: signature,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert result == {"opcode": 12}
|
||||
assert webhook._connection is not None
|
||||
assert webhook.http._token is not None
|
||||
assert webhook.http._token.app_id == "123"
|
||||
assert webhook.client.api is webhook.api
|
||||
assert webhook.client.http is webhook.http
|
||||
@@ -476,6 +476,46 @@ class TestBuiltinToolInjection:
|
||||
assert req.func_tool.get_tool("web_search_firecrawl") is search_tool
|
||||
assert req.func_tool.get_tool("firecrawl_extract_web_page") is extract_tool
|
||||
|
||||
def test_apply_web_search_citation_prompt_for_webchat(self, mock_event):
|
||||
module = ama
|
||||
req = ProviderRequest(system_prompt="base")
|
||||
search_tool = MagicMock(spec=FunctionTool)
|
||||
search_tool.name = "web_search_tavily"
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(search_tool)
|
||||
mock_event.get_platform_name.return_value = "webchat"
|
||||
|
||||
module._apply_web_search_citation_prompt(mock_event, req)
|
||||
|
||||
assert module.WEB_SEARCH_CITATION_PROMPT in req.system_prompt
|
||||
|
||||
def test_apply_web_search_citation_prompt_is_idempotent(self, mock_event):
|
||||
module = ama
|
||||
req = ProviderRequest(system_prompt="")
|
||||
search_tool = MagicMock(spec=FunctionTool)
|
||||
search_tool.name = "web_search_tavily"
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(search_tool)
|
||||
mock_event.get_platform_name.return_value = "webchat"
|
||||
|
||||
module._apply_web_search_citation_prompt(mock_event, req)
|
||||
module._apply_web_search_citation_prompt(mock_event, req)
|
||||
|
||||
assert req.system_prompt.count(module.WEB_SEARCH_CITATION_PROMPT) == 1
|
||||
|
||||
def test_apply_web_search_citation_prompt_requires_webchat(self, mock_event):
|
||||
module = ama
|
||||
req = ProviderRequest(system_prompt="")
|
||||
search_tool = MagicMock(spec=FunctionTool)
|
||||
search_tool.name = "web_search_tavily"
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(search_tool)
|
||||
mock_event.get_platform_name.return_value = "test_platform"
|
||||
|
||||
module._apply_web_search_citation_prompt(mock_event, req)
|
||||
|
||||
assert module.WEB_SEARCH_CITATION_PROMPT not in req.system_prompt
|
||||
|
||||
def test_proactive_cron_job_tools_uses_builtin_tool_manager(self, mock_context):
|
||||
"""Test cron tool injection through the builtin tool manager."""
|
||||
module = ama
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import sqlite3
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
|
||||
|
||||
@@ -102,38 +101,3 @@ async def test_document_storage_fts_recovers_from_legacy_non_fts_table(tmp_path)
|
||||
assert [result["doc_id"] for result in results] == ["legacy-fix"]
|
||||
|
||||
await storage.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_storage_adds_unique_doc_id_index_to_existing_table(tmp_path):
|
||||
db_path = tmp_path / "doc.db"
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
doc_id VARCHAR NOT NULL,
|
||||
text VARCHAR NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at DATETIME,
|
||||
updated_at DATETIME
|
||||
)
|
||||
""",
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO documents (doc_id, text) VALUES ('legacy-chunk', 'legacy text')"
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
storage = DocumentStorage(str(db_path))
|
||||
await storage.initialize()
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
await storage.insert_document(
|
||||
doc_id="legacy-chunk",
|
||||
text="duplicate text",
|
||||
metadata={},
|
||||
)
|
||||
|
||||
await storage.close()
|
||||
|
||||
@@ -12,9 +12,15 @@ def _make_context(
|
||||
current_session="feishu:GroupMessage:oc_xxx",
|
||||
role="admin",
|
||||
require_admin=True,
|
||||
runtime="local",
|
||||
):
|
||||
"""Build a minimal ContextWrapper for SendMessageToUserTool."""
|
||||
cfg = {"provider_settings": {"computer_use_require_admin": require_admin}}
|
||||
cfg = {
|
||||
"provider_settings": {
|
||||
"computer_use_require_admin": require_admin,
|
||||
"computer_use_runtime": runtime,
|
||||
}
|
||||
}
|
||||
return SimpleNamespace(
|
||||
context=SimpleNamespace(
|
||||
event=SimpleNamespace(
|
||||
@@ -161,3 +167,71 @@ async def test_send_message_missing_image_path_stops_before_send(tmp_path, monke
|
||||
|
||||
assert "error: failed to build messages[1] component: sandbox unavailable" in result
|
||||
ctx.context.context.send_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_cannot_send_arbitrary_local_absolute_file(tmp_path):
|
||||
"""Non-admin users cannot send host files outside the allowed local roots."""
|
||||
tool = SendMessageToUserTool()
|
||||
ctx = _make_context(role="member", require_admin=True)
|
||||
secret_path = tmp_path / "secret.txt"
|
||||
secret_path.write_text("secret", encoding="utf-8")
|
||||
|
||||
result = await tool.call(
|
||||
ctx,
|
||||
messages=[{"type": "file", "path": str(secret_path)}],
|
||||
)
|
||||
|
||||
assert "error: Local file send is restricted for this user" in result
|
||||
assert str(secret_path) in result
|
||||
ctx.context.context.send_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_can_send_workspace_file(tmp_path, monkeypatch):
|
||||
"""Non-admin users can send files inside their per-session workspace."""
|
||||
tool = SendMessageToUserTool()
|
||||
ctx = _make_context(
|
||||
current_session="feishu:GroupMessage:oc_workspace",
|
||||
role="member",
|
||||
require_admin=True,
|
||||
)
|
||||
workspace_root = tmp_path / "workspaces"
|
||||
workspace_file = workspace_root / "feishu_GroupMessage_oc_workspace" / "result.txt"
|
||||
workspace_file.parent.mkdir(parents=True)
|
||||
workspace_file.write_text("result", encoding="utf-8")
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.tools.computer_tools.util.get_astrbot_workspaces_path",
|
||||
lambda: str(workspace_root),
|
||||
)
|
||||
|
||||
result = await tool.call(
|
||||
ctx,
|
||||
messages=[{"type": "file", "path": "result.txt"}],
|
||||
)
|
||||
|
||||
assert "Message sent to session" in result
|
||||
ctx.context.context.send_message.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_can_send_temp_file(tmp_path, monkeypatch):
|
||||
"""Non-admin users can send generated files under AstrBot temp."""
|
||||
tool = SendMessageToUserTool()
|
||||
ctx = _make_context(role="member", require_admin=True)
|
||||
temp_root = tmp_path / "temp"
|
||||
temp_root.mkdir()
|
||||
output_path = temp_root / "output.txt"
|
||||
output_path.write_text("output", encoding="utf-8")
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.tools.message_tools.get_astrbot_temp_path",
|
||||
lambda: str(temp_root),
|
||||
)
|
||||
|
||||
result = await tool.call(
|
||||
ctx,
|
||||
messages=[{"type": "file", "path": str(output_path)}],
|
||||
)
|
||||
|
||||
assert "Message sent to session" in result
|
||||
ctx.context.context.send_message.assert_called_once()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlmodel import select
|
||||
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
@@ -65,143 +63,3 @@ async def test_record_internal_agent_stats_persists_provider_stat(
|
||||
assert record.start_time == 100.0
|
||||
assert record.end_time == 108.5
|
||||
assert record.time_to_first_token == 0.6
|
||||
|
||||
|
||||
def _provider_stats_recording_args():
|
||||
event = SimpleNamespace(unified_msg_origin="webchat:FriendMessage:session-42")
|
||||
req = ProviderRequest(conversation=SimpleNamespace(cid="conv-123"))
|
||||
provider = SimpleNamespace(
|
||||
provider_config={"id": "provider-1"},
|
||||
meta=lambda: SimpleNamespace(id="provider-1", type="openai"),
|
||||
get_model=lambda: "gpt-4.1",
|
||||
)
|
||||
agent_runner = SimpleNamespace(
|
||||
provider=provider,
|
||||
stats=AgentStats(),
|
||||
was_aborted=lambda: False,
|
||||
)
|
||||
return event, req, agent_runner, SimpleNamespace(role="assistant")
|
||||
|
||||
|
||||
def _provider_stats_operational_error(message: str) -> OperationalError:
|
||||
return OperationalError("insert into provider_stats", {}, Exception(message))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"lock_message",
|
||||
["database is locked", "database table is locked"],
|
||||
)
|
||||
async def test_record_internal_agent_stats_retries_transient_database_locks(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
lock_message: str,
|
||||
):
|
||||
attempts = 0
|
||||
|
||||
class LockedOnceDb:
|
||||
async def insert_provider_stat(self, **kwargs):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise _provider_stats_operational_error(lock_message)
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
monkeypatch.setattr(internal, "db_helper", LockedOnceDb())
|
||||
|
||||
async def no_sleep(delay: float) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(internal.asyncio, "sleep", no_sleep)
|
||||
|
||||
await internal._record_internal_agent_stats(
|
||||
*_provider_stats_recording_args(),
|
||||
)
|
||||
|
||||
assert attempts == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_internal_agent_stats_logs_after_exhausting_database_lock_retries(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
attempts = 0
|
||||
sleep_delays = []
|
||||
warnings = []
|
||||
|
||||
class AlwaysLockedDb:
|
||||
async def insert_provider_stat(self, **kwargs):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
raise _provider_stats_operational_error("database is locked")
|
||||
|
||||
monkeypatch.setattr(internal, "db_helper", AlwaysLockedDb())
|
||||
|
||||
async def record_sleep(delay: float) -> None:
|
||||
sleep_delays.append(delay)
|
||||
|
||||
monkeypatch.setattr(internal.asyncio, "sleep", record_sleep)
|
||||
monkeypatch.setattr(
|
||||
internal.logger,
|
||||
"warning",
|
||||
lambda *args, **kwargs: warnings.append((args, kwargs)),
|
||||
)
|
||||
|
||||
await internal._record_internal_agent_stats(*_provider_stats_recording_args())
|
||||
|
||||
assert attempts == internal.PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS
|
||||
base_delay = internal.PROVIDER_STATS_SQLITE_LOCK_RETRY_BASE_DELAY
|
||||
expected_sleep_delays = [
|
||||
base_delay * (2**attempt)
|
||||
for attempt in range(internal.PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS - 1)
|
||||
]
|
||||
assert sleep_delays == expected_sleep_delays
|
||||
assert len(warnings) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_internal_agent_stats_does_not_retry_other_operational_errors(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
attempts = 0
|
||||
warnings = []
|
||||
|
||||
class FailingDb:
|
||||
async def insert_provider_stat(self, **kwargs):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
raise _provider_stats_operational_error("no such table: provider_stats")
|
||||
|
||||
monkeypatch.setattr(internal, "db_helper", FailingDb())
|
||||
monkeypatch.setattr(
|
||||
internal.logger,
|
||||
"warning",
|
||||
lambda *args, **kwargs: warnings.append((args, kwargs)),
|
||||
)
|
||||
|
||||
await internal._record_internal_agent_stats(*_provider_stats_recording_args())
|
||||
|
||||
assert attempts == 1
|
||||
assert len(warnings) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_internal_agent_stats_propagates_cancelled_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
warnings = []
|
||||
|
||||
class CancellingDb:
|
||||
async def insert_provider_stat(self, **kwargs):
|
||||
raise asyncio.CancelledError
|
||||
|
||||
monkeypatch.setattr(internal, "db_helper", CancellingDb())
|
||||
monkeypatch.setattr(
|
||||
internal.logger,
|
||||
"warning",
|
||||
lambda *args, **kwargs: warnings.append((args, kwargs)),
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await internal._record_internal_agent_stats(*_provider_stats_recording_args())
|
||||
|
||||
assert warnings == []
|
||||
|
||||
Reference in New Issue
Block a user