Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter
8ea47c87e5 refactor: migrate to fastapi 2026-06-08 10:23:17 +08:00
199 changed files with 36821 additions and 15808 deletions

View File

@@ -157,7 +157,7 @@ class Platform(abc.ABC):
当 Dashboard 收到 /api/platform/webhook/{uuid} 请求时,会调用此方法。
Args:
request: Quart 请求对象
request: webhook 请求对象
Returns:
响应内容,格式取决于具体平台的要求

View File

@@ -132,7 +132,7 @@ class LarkWebhookServer:
"""处理 webhook 回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
request: webhook 请求对象
Returns:
响应数据

View File

@@ -3,11 +3,11 @@ import logging
import time
from typing import cast
import quart
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
from cryptography.hazmat.primitives.asymmetric import ed25519
from astrbot.api import logger
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
# remove logger handler
for handler in logging.root.handlers[:]:
@@ -31,7 +31,7 @@ class QQOfficialWebhook:
self.api: BotAPI = BotAPI(http=self.http)
self.token = Token(self.appid, self.secret)
self.server = quart.Quart(__name__)
self.server = FastAPIWebhookServer("qq-official-webhook")
self.server.add_url_rule(
"/astrbot-qo-webhook/callback",
view_func=self.callback,
@@ -92,15 +92,15 @@ class QQOfficialWebhook:
"""Pop and return extra fields cached from the raw webhook payload for a given message ID."""
return self._extra_data_cache.pop(message_id, {})
async def callback(self):
async def callback(self, request):
"""内部服务器的回调入口"""
return await self.handle_callback(quart.request)
return await self.handle_callback(request)
async def handle_callback(self, request) -> dict:
"""处理 webhook 回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
request: FastAPI webhook request 对象
Returns:
响应数据

View File

@@ -2,11 +2,10 @@ import asyncio
import hashlib
import hmac
import json
import logging
from collections.abc import Callable
from typing import cast
from quart import Quart, Response, request
from fastapi.responses import Response
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
@@ -14,10 +13,11 @@ from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.web.async_client import AsyncWebClient
from astrbot.api import logger
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
class SlackWebhookClient:
"""Slack Webhook 模式客户端,使用 Quart 作为 Web 服务器"""
"""Slack Webhook 模式客户端,使用 FastAPI 作为 Web 服务器"""
def __init__(
self,
@@ -35,20 +35,16 @@ class SlackWebhookClient:
self.path = path
self.event_handler = event_handler
self.app = Quart(__name__)
self.app = FastAPIWebhookServer("slack-webhook")
self._setup_routes()
# 禁用 Quart 的默认日志输出
logging.getLogger("quart.app").setLevel(logging.WARNING)
logging.getLogger("quart.serving").setLevel(logging.WARNING)
self.shutdown_event = asyncio.Event()
def _setup_routes(self) -> None:
"""设置路由"""
@self.app.route(self.path, methods=["POST"])
async def slack_events():
async def slack_events(request):
"""内部服务器的 POST 回调入口"""
return await self.handle_callback(request)
@@ -61,7 +57,7 @@ class SlackWebhookClient:
"""处理 Slack 回调请求,可被统一 webhook 入口复用
Args:
req: Quart 请求对象
req: webhook 请求对象
Returns:
Response 对象或字典
@@ -75,7 +71,7 @@ class SlackWebhookClient:
timestamp = req.headers.get("X-Slack-Request-Timestamp")
signature = req.headers.get("X-Slack-Signature")
if not timestamp or not signature:
return Response("Missing headers", status=400)
return Response("Missing headers", status_code=400)
# Calculate the HMAC signature
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
my_signature = (
@@ -89,7 +85,7 @@ class SlackWebhookClient:
# Verify the signature
if not hmac.compare_digest(my_signature, signature):
logger.warning("Slack request signature verification failed")
return Response("Invalid signature", status=400)
return Response("Invalid signature", status_code=400)
logger.info(f"Received Slack event: {event_data}")
# 处理 URL 验证事件
@@ -99,11 +95,11 @@ class SlackWebhookClient:
if self.event_handler and event_data.get("type") == "event_callback":
await self.event_handler(event_data)
return Response("", status=200)
return Response("", status_code=200)
except Exception as e:
logger.error(f"处理 Slack 事件时出错: {e}")
return Response("Internal Server Error", status=500)
return Response("Internal Server Error", status_code=500)
async def start(self) -> None:
"""启动 Webhook 服务器"""

View File

@@ -8,7 +8,6 @@ from pathlib import Path
from typing import Any, cast
from urllib.parse import unquote
import quart
from requests import Response
from wechatpy.enterprise import WeChatClient, parse_message
from wechatpy.enterprise.crypto import WeChatCrypto
@@ -28,6 +27,7 @@ from astrbot.api.platform import (
)
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.webhook_utils import log_webhook_info
@@ -65,7 +65,7 @@ def _extract_wecom_media_filename(disposition: str | None) -> str | None:
class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict) -> None:
self.server = quart.Quart(__name__)
self.server = FastAPIWebhookServer("wecom-webhook")
self.port = int(cast(str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule(
@@ -89,15 +89,15 @@ class WecomServer:
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
async def verify(self, request):
"""内部服务器的 GET 验证入口"""
return await self.handle_verify(quart.request)
return await self.handle_verify(request)
async def handle_verify(self, request) -> str:
"""处理验证请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
request: FastAPI webhook request 对象
Returns:
验证响应
@@ -117,15 +117,15 @@ class WecomServer:
logger.error("验证请求有效性失败,签名异常,请检查配置。")
raise
async def callback_command(self):
async def callback_command(self, request):
"""内部服务器的 POST 回调入口"""
return await self.handle_callback(quart.request)
return await self.handle_callback(request)
async def handle_callback(self, request) -> str:
"""处理回调请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
request: FastAPI webhook request 对象
Returns:
响应内容

View File

@@ -6,9 +6,8 @@ import asyncio
from collections.abc import Callable
from typing import Any
import quart
from astrbot.api import logger
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_utils import WecomAIBotConstants
@@ -38,14 +37,13 @@ class WecomAIBotServer:
self.api_client = api_client
self.message_handler = message_handler
self.app = quart.Quart(__name__)
self.app = FastAPIWebhookServer("wecom-ai-bot-webhook")
self._setup_routes()
self.shutdown_event = asyncio.Event()
def _setup_routes(self) -> None:
"""设置 Quart 路由"""
# 使用 Quart 的 add_url_rule 方法添加路由
"""设置 FastAPI 路由"""
self.app.add_url_rule(
"/webhook/wecom-ai-bot",
view_func=self.verify_url,
@@ -58,15 +56,15 @@ class WecomAIBotServer:
methods=["POST"],
)
async def verify_url(self):
async def verify_url(self, request):
"""内部服务器的 GET 验证入口"""
return await self.handle_verify(quart.request)
return await self.handle_verify(request)
async def handle_verify(self, request):
"""处理 URL 验证请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
request: FastAPI webhook request 对象
Returns:
验证响应元组 (content, status_code, headers)
@@ -91,15 +89,15 @@ class WecomAIBotServer:
result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
return result, 200, {"Content-Type": "text/plain"}
async def handle_message(self):
async def handle_message(self, request):
"""内部服务器的 POST 消息回调入口"""
return await self.handle_callback(quart.request)
return await self.handle_callback(request)
async def handle_callback(self, request):
"""处理消息回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
request: FastAPI webhook request 对象
Returns:
响应元组 (content, status_code, headers)
@@ -186,5 +184,5 @@ class WecomAIBotServer:
self.shutdown_event.set()
def get_app(self):
"""获取 Quart 应用实例"""
return self.app
"""获取 FastAPI 应用实例"""
return self.app.app

View File

@@ -5,7 +5,6 @@ import time
from collections.abc import Callable, Coroutine
from typing import Any, cast
import quart
from requests import Response
from wechatpy import WeChatClient, create_reply, parse_message
from wechatpy.crypto import WeChatCrypto
@@ -25,6 +24,7 @@ from astrbot.api.platform import (
)
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.webhook_utils import log_webhook_info
@@ -44,7 +44,7 @@ class WeixinOfficialAccountServer:
config: dict,
user_buffer: dict[Any, dict[str, Any]],
) -> None:
self.server = quart.Quart(__name__)
self.server = FastAPIWebhookServer("weixin-official-account-webhook")
self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token")
@@ -73,15 +73,15 @@ class WeixinOfficialAccountServer:
self.user_buffer: dict[str, dict[str, Any]] = user_buffer # from_user -> state
self.active_send_mode = False # 是否启用主动发送模式,启用后 callback 将直接返回回复内容,无需等待微信回调
async def verify(self):
async def verify(self, request):
"""内部服务器的 GET 验证入口"""
return await self.handle_verify(quart.request)
return await self.handle_verify(request)
async def handle_verify(self, request) -> str:
"""处理验证请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
request: FastAPI webhook request 对象
Returns:
验证响应
@@ -105,9 +105,9 @@ class WeixinOfficialAccountServer:
logger.error("验证请求有效性失败,签名异常,请检查配置。")
return "err"
async def callback_command(self):
async def callback_command(self, request):
"""内部服务器的 POST 回调入口"""
return await self.handle_callback(quart.request)
return await self.handle_callback(request)
def _maybe_encrypt(self, xml: str, nonce: str | None, timestamp: str | None) -> str:
if xml and "<Encrypt>" not in xml and nonce and timestamp:
@@ -129,7 +129,7 @@ class WeixinOfficialAccountServer:
"""处理回调请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
request: FastAPI webhook request 对象
Returns:
响应内容

View File

@@ -0,0 +1,107 @@
from __future__ import annotations
import inspect
from collections.abc import Callable
from typing import Any
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from hypercorn.asyncio import serve
from hypercorn.config import Config as HyperConfig
class WebhookRequest:
def __init__(self, request: Request) -> None:
self._request = request
self.args = request.query_params
self.headers = request.headers
self.method = request.method
@property
def json(self):
return self._request.json()
async def get_data(self) -> bytes:
return await self._request.body()
async def get_json(self, *, force: bool = False, silent: bool = False):
try:
return await self._request.json()
except Exception:
if silent:
return None
raise
def _response_from_result(result: Any):
if isinstance(result, Response):
return result
if isinstance(result, tuple):
content = result[0] if result else ""
status_code = (
result[1] if len(result) > 1 and isinstance(result[1], int) else 200
)
headers = result[2] if len(result) > 2 and isinstance(result[2], dict) else None
if isinstance(content, dict | list):
return JSONResponse(content, status_code=status_code, headers=headers)
return Response(
content=content,
status_code=status_code,
headers=headers,
media_type=headers.get("Content-Type") if headers else None,
)
if isinstance(result, dict | list):
return JSONResponse(result)
return result
class FastAPIWebhookServer:
def __init__(self, name: str) -> None:
self.app = FastAPI(title=name, docs_url=None, redoc_url=None, openapi_url=None)
def add_url_rule(
self,
path: str,
view_func: Callable,
methods: list[str] | None = None,
) -> None:
async def endpoint(request: Request):
if inspect.signature(view_func).parameters:
result = view_func(WebhookRequest(request))
else:
result = view_func()
if inspect.isawaitable(result):
result = await result
return _response_from_result(result)
self.app.add_api_route(
path,
endpoint,
methods=methods or ["GET"],
include_in_schema=False,
)
def route(self, path: str, methods: list[str] | None = None):
def decorator(view_func: Callable):
self.add_url_rule(path, view_func, methods)
return view_func
return decorator
async def run_task(
self,
*,
host: str,
port: int,
shutdown_trigger: Callable | None = None,
**_kwargs,
) -> None:
config = HyperConfig()
config.bind = [f"{host}:{port}"]
await serve(self.app, config, shutdown_trigger=shutdown_trigger)
async def shutdown(self) -> None:
return None

View File

@@ -0,0 +1,598 @@
from __future__ import annotations
import contextvars
import inspect
import re
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from pathlib import Path
from typing import Any
import httpx
from fastapi import FastAPI, HTTPException, Request, WebSocket
from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse, JSONResponse, Response
from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.responses import StreamingResponse
_request_var: contextvars.ContextVar[CompatRequest] = contextvars.ContextVar(
"dashboard_request"
)
_websocket_var: contextvars.ContextVar[CompatWebSocket] = contextvars.ContextVar(
"dashboard_websocket"
)
_g_var: contextvars.ContextVar[CompatG] = contextvars.ContextVar("dashboard_g")
_app_var: contextvars.ContextVar[FastAPIAppAdapter] = contextvars.ContextVar(
"dashboard_app"
)
class CompatArgs:
def __init__(self, values) -> None:
self._values = values
def get(self, key: str, default: Any = None, type: Callable | None = None):
value = self._values.get(key, default)
if value is default or type is None:
return value
try:
return type(value)
except (TypeError, ValueError):
return default
class CompatMultiDict:
def __init__(self, pairs: list[tuple[str, Any]]) -> None:
self._pairs = pairs
def get(self, key: str, default: Any = None, type: Callable | None = None):
for item_key, item_value in reversed(self._pairs):
if item_key != key:
continue
if type is None:
return item_value
try:
return type(item_value)
except (TypeError, ValueError):
return default
return default
def getlist(self, key: str) -> list[Any]:
return [item_value for item_key, item_value in self._pairs if item_key == key]
def keys(self):
return dict.fromkeys(item_key for item_key, _ in self._pairs).keys()
def values(self):
return [self[key] for key in self.keys()]
def items(self):
return [(key, self[key]) for key in self.keys()]
def __contains__(self, key: str) -> bool:
return any(item_key == key for item_key, _ in self._pairs)
def __getitem__(self, key: str):
value = self.get(key)
if value is None and key not in self:
raise KeyError(key)
return value
def __bool__(self) -> bool:
return bool(self._pairs)
class CompatUploadFile:
def __init__(self, upload_file: StarletteUploadFile) -> None:
self._upload_file = upload_file
self.filename = upload_file.filename
self.content_type = upload_file.content_type
self.headers = upload_file.headers
self.content_length = self._resolve_content_length()
def _resolve_content_length(self) -> int | None:
try:
raw = self.headers.get("content-length")
return int(raw) if raw else None
except (TypeError, ValueError):
return None
async def save(self, destination: str | Path) -> None:
path = Path(destination)
try:
await self._upload_file.seek(0)
except Exception:
pass
with path.open("wb") as output:
while True:
chunk = await self._upload_file.read(1024 * 1024)
if not chunk:
break
output.write(chunk)
def __getattr__(self, key: str):
return getattr(self._upload_file, key)
class CompatG:
def __init__(self) -> None:
self._values: dict[str, Any] = {}
def get(self, key: str, default: Any = None):
return self._values.get(key, default)
def __getattr__(self, key: str):
try:
return self._values[key]
except KeyError as exc:
raise AttributeError(key) from exc
def __setattr__(self, key: str, value: Any) -> None:
if key == "_values":
super().__setattr__(key, value)
return
self._values[key] = value
class CompatRequest:
def __init__(self, request: Request) -> None:
self._request = request
self.args = CompatArgs(request.query_params)
self.headers = request.headers
self.cookies = request.cookies
self.method = request.method
self.path = request.url.path
self.content_type = request.headers.get("content-type")
self.remote_addr = request.client.host if request.client else None
self._form_cache: CompatMultiDict | None = None
self._files_cache: CompatMultiDict | None = None
@property
def json(self):
return self.get_json()
@property
def files(self):
return self._load_files()
@property
def form(self):
return self._load_form()
async def get_json(self, silent: bool = False):
try:
return await self._request.json()
except Exception:
if silent:
return None
raise
async def _load_form_parts(self) -> None:
if self._form_cache is not None and self._files_cache is not None:
return
form = await self._request.form()
form_pairs: list[tuple[str, Any]] = []
file_pairs: list[tuple[str, Any]] = []
for key, value in form.multi_items():
if isinstance(value, StarletteUploadFile):
file_pairs.append((key, CompatUploadFile(value)))
else:
form_pairs.append((key, value))
self._form_cache = CompatMultiDict(form_pairs)
self._files_cache = CompatMultiDict(file_pairs)
async def _load_form(self) -> CompatMultiDict:
await self._load_form_parts()
assert self._form_cache is not None
return self._form_cache
async def _load_files(self) -> CompatMultiDict:
await self._load_form_parts()
assert self._files_cache is not None
return self._files_cache
class CompatWebSocket:
def __init__(self, websocket: WebSocket) -> None:
self._websocket = websocket
self.args = CompatArgs(websocket.query_params)
self.headers = websocket.headers
async def accept(self) -> None:
await self._websocket.accept()
async def receive_json(self):
return await self._websocket.receive_json()
async def send_json(self, payload: Any) -> None:
await self._websocket.send_json(payload)
async def close(self, code: int = 1000, reason: str | None = None) -> None:
await self._websocket.close(code=code, reason=reason or "")
class CompatTestHeaders:
def __init__(self, headers: httpx.Headers) -> None:
self._headers = headers
def getlist(self, key: str) -> list[str]:
values = self._headers.get_list(key)
if key.lower() == "set-cookie":
return [value.replace('=""', "=") for value in values]
return values
def get(self, key: str, default: Any = None):
value = self._headers.get(key, default)
if isinstance(value, str) and key.lower() == "set-cookie":
return value.replace('=""', "=")
return value
def __getitem__(self, key: str):
return self._headers[key]
def __contains__(self, key: str) -> bool:
return key in self._headers
class CompatTestResponse:
def __init__(self, response: httpx.Response) -> None:
self._response = response
self.status_code = response.status_code
self.headers = CompatTestHeaders(response.headers)
self.data = response.content
self.content = response.content
self.text = response.text
async def get_json(self):
return self._response.json()
async def get_data(self):
return self._response.content
class CompatTestClient:
def __init__(self, app: FastAPI) -> None:
self._client = httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://testserver",
)
@staticmethod
def _is_file_storage(value: Any) -> bool:
return hasattr(value, "stream") and hasattr(value, "filename")
@classmethod
def _file_tuple(cls, value: Any):
stream = value.stream
if hasattr(stream, "seek"):
stream.seek(0)
content = stream.read()
filename = getattr(value, "filename", "upload.bin")
content_type = getattr(value, "content_type", None)
return filename, content, content_type
@classmethod
def _normalize_data(cls, data: Any):
if not isinstance(data, dict):
return data, None
form: dict[str, Any] = {}
files: list[tuple[str, tuple]] = []
for key, value in data.items():
if cls._is_file_storage(value):
files.append((key, cls._file_tuple(value)))
continue
if isinstance(value, Iterable) and not isinstance(
value, str | bytes | dict
):
values = list(value)
if values and all(cls._is_file_storage(item) for item in values):
files.extend((key, cls._file_tuple(item)) for item in values)
continue
form[key] = value
return form, files or None
@classmethod
def _normalize_files(cls, files: Any):
if isinstance(files, dict):
items = files.items()
elif isinstance(files, Iterable) and not isinstance(files, str | bytes):
items = files
else:
return files
normalized_files: list[tuple[str, Any]] = []
for key, value in items:
if cls._is_file_storage(value):
normalized_files.append((key, cls._file_tuple(value)))
continue
if isinstance(value, Iterable) and not isinstance(
value, str | bytes | dict
):
values = list(value)
if values and all(cls._is_file_storage(item) for item in values):
normalized_files.extend(
(key, cls._file_tuple(item)) for item in values
)
continue
normalized_files.append((key, value))
return normalized_files
async def request(self, method: str, url: str, **kwargs):
data = kwargs.pop("data", None)
if data is not None and "files" not in kwargs:
normalized_data, files = self._normalize_data(data)
kwargs["data"] = normalized_data
if files:
kwargs["files"] = files
elif data is not None:
kwargs["data"] = data
if "files" in kwargs:
kwargs["files"] = self._normalize_files(kwargs["files"])
response = await self._client.request(method, url, **kwargs)
return CompatTestResponse(response)
async def get(self, url: str, **kwargs):
return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs):
return await self.request("POST", url, **kwargs)
async def put(self, url: str, **kwargs):
return await self.request("PUT", url, **kwargs)
async def patch(self, url: str, **kwargs):
return await self.request("PATCH", url, **kwargs)
async def delete(self, url: str, **kwargs):
return await self.request("DELETE", url, **kwargs)
class _ContextProxy:
def __init__(self, var) -> None:
self._var = var
def __getattr__(self, key: str):
return getattr(self._var.get(), key)
def __setattr__(self, key: str, value: Any) -> None:
if key == "_var":
super().__setattr__(key, value)
return
setattr(self._var.get(), key, value)
request = _ContextProxy(_request_var)
websocket = _ContextProxy(_websocket_var)
g = _ContextProxy(_g_var)
current_app = _ContextProxy(_app_var)
@contextmanager
def bind_request_context(
request_: Request,
app: FastAPIAppAdapter,
g_obj: CompatG | None = None,
):
token_request = _request_var.set(CompatRequest(request_))
token_g = _g_var.set(g_obj or getattr(request_.state, "dashboard_g", CompatG()))
token_app = _app_var.set(app)
try:
yield _g_var.get()
finally:
_app_var.reset(token_app)
_g_var.reset(token_g)
_request_var.reset(token_request)
@contextmanager
def bind_websocket_context(
websocket_: WebSocket,
app: FastAPIAppAdapter,
g_obj: CompatG | None = None,
):
token_websocket = _websocket_var.set(CompatWebSocket(websocket_))
token_g = _g_var.set(g_obj or getattr(websocket_.state, "dashboard_g", CompatG()))
token_app = _app_var.set(app)
try:
yield
finally:
_app_var.reset(token_app)
_g_var.reset(token_g)
_websocket_var.reset(token_websocket)
def jsonify(payload: Any = None):
return JSONResponse(payload if payload is not None else {})
async def make_response(*args):
if not args:
return Response()
content = args[0]
status_code = args[1] if len(args) > 1 and isinstance(args[1], int) else None
headers = args[1] if len(args) > 1 and isinstance(args[1], dict) else None
if len(args) > 2 and isinstance(args[2], dict):
headers = args[2]
if isinstance(content, Response):
if status_code is not None:
content.status_code = status_code
if headers:
content.headers.update(headers)
return content
if hasattr(content, "__aiter__"):
return StreamingResponse(
content,
status_code=status_code or 200,
headers=headers,
)
return Response(
content=content,
status_code=status_code or 200,
headers=headers,
)
async def send_file(path: str | Path, mimetype: str | None = None, **kwargs):
filename = kwargs.get("attachment_filename") or kwargs.get("download_name")
as_attachment = bool(kwargs.get("as_attachment"))
return FileResponse(
path,
media_type=mimetype,
filename=filename if as_attachment else None,
)
def abort(status_code: int):
raise HTTPException(status_code=status_code)
def _convert_rule(path: str) -> str:
converted = re.sub(r"<path:([A-Za-z_][A-Za-z0-9_]*)>", r"{\1:path}", path)
converted = re.sub(r"<([A-Za-z_][A-Za-z0-9_]*)>", r"{\1}", converted)
return converted
async def _call_view(view_func: Callable, path_params: dict[str, Any]):
result = view_func(**path_params)
if inspect.isawaitable(result):
result = await result
return _coerce_view_result(result)
def _coerce_view_result(result: Any):
if isinstance(result, Response):
return result
if isinstance(result, tuple):
content = result[0] if result else None
status_code = next((item for item in result[1:] if isinstance(item, int)), 200)
headers = next(
(item for item in result[1:] if isinstance(item, dict)),
None,
)
if isinstance(content, Response):
content.status_code = status_code
if headers:
content.headers.update(headers)
return content
return _response_from_content(content, status_code=status_code, headers=headers)
if isinstance(result, dict | list):
return JSONResponse(jsonable_encoder(result))
return result
def _response_from_content(
content: Any,
*,
status_code: int,
headers: dict[str, str] | None = None,
):
if isinstance(content, dict | list):
return JSONResponse(
jsonable_encoder(content),
status_code=status_code,
headers=headers,
)
return Response(
content=content,
status_code=status_code,
headers=headers,
)
async def call_request_view(
request_: Request,
app: FastAPIAppAdapter,
view_func: Callable,
path_params: dict[str, Any] | None = None,
g_obj: CompatG | None = None,
):
with bind_request_context(request_, app, g_obj):
return await _call_view(view_func, path_params or {})
async def call_websocket_view(
websocket_: WebSocket,
app: FastAPIAppAdapter,
view_func: Callable,
path_params: dict[str, Any] | None = None,
*,
accept: bool = True,
):
if accept:
await websocket_.accept()
with bind_websocket_context(websocket_, app):
return await _call_view(view_func, path_params or {})
class FastAPIAppAdapter:
def __init__(self, app: FastAPI, static_folder: str | None = None) -> None:
self._app = app
self.static_folder = static_folder
self.config: dict[str, Any] = {}
self.debug = False
self.testing = False
self.name = "dashboard"
def add_url_rule(
self,
path: str,
view_func: Callable,
methods: list[str] | None = None,
endpoint: str | None = None,
) -> None:
route_path = _convert_rule(path)
methods = methods or ["GET"]
async def endpoint_func(request_: Request):
with bind_request_context(request_, self):
return await _call_view(view_func, dict(request_.path_params))
self._app.add_api_route(
route_path,
endpoint_func,
methods=methods,
name=endpoint,
include_in_schema=False,
)
def websocket(self, path: str):
route_path = _convert_rule(path)
def decorator(view_func: Callable):
async def endpoint_func(websocket_: WebSocket):
return await call_websocket_view(
websocket_,
self,
view_func,
dict(websocket_.path_params),
)
self._app.add_api_websocket_route(
route_path,
endpoint_func,
name=getattr(view_func, "__name__", None),
)
return view_func
return decorator
def errorhandler(self, _status_code: int):
def decorator(func: Callable):
return func
return decorator
async def send_static_file(self, filename: str):
if not self.static_folder:
raise HTTPException(status_code=404)
return FileResponse(Path(self.static_folder) / filename)
def test_client(self):
self.testing = True
return CompatTestClient(self._app)
CompatResponse = Response

View File

@@ -1,6 +1,6 @@
from urllib.parse import unquote
from quart import request
from astrbot.dashboard.fastapi_compat import request
PLUGIN_PAGE_CONTENT_PREFIX = "/api/plugin/page/content/"
PLUGIN_PAGE_BRIDGE_PATH = "/api/plugin/page/bridge-sdk.js"

View File

@@ -1,21 +1,17 @@
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from quart import g, request
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.datetime_utils import normalize_datetime_utc
from astrbot.dashboard.fastapi_compat import g, request
from astrbot.dashboard.services.api_key_service import (
ApiKeyService,
ApiKeyServiceError,
)
from .route import Response, Route, RouteContext
ALL_OPEN_API_SCOPES = ("chat", "config", "file", "im")
class ApiKeyRoute(Route):
def __init__(self, context: RouteContext, db: BaseDatabase) -> None:
super().__init__(context)
self.db = db
self.service = ApiKeyService(db)
self.routes = {
"/apikey/list": ("GET", self.list_api_keys),
"/apikey/create": ("POST", self.create_api_key),
@@ -25,119 +21,39 @@ class ApiKeyRoute(Route):
self.register_routes()
@staticmethod
def _normalize_utc(dt: datetime | None) -> datetime | None:
return normalize_datetime_utc(dt)
@classmethod
def _serialize_datetime(cls, dt: datetime | None) -> str | None:
normalized = cls._normalize_utc(dt)
if normalized is None:
return None
return normalized.astimezone().isoformat()
def _ok(data=None):
return Response().ok(data=data).__dict__
@staticmethod
def _hash_key(raw_key: str) -> str:
return hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
b"astrbot_api_key",
100_000,
).hex()
def _error(message: str):
return Response().error(message).__dict__
@staticmethod
def _serialize_api_key(key) -> dict:
expires_at = ApiKeyRoute._normalize_utc(key.expires_at)
return {
"key_id": key.key_id,
"name": key.name,
"key_prefix": key.key_prefix,
"scopes": key.scopes or [],
"created_by": key.created_by,
"created_at": ApiKeyRoute._serialize_datetime(key.created_at),
"updated_at": ApiKeyRoute._serialize_datetime(key.updated_at),
"last_used_at": ApiKeyRoute._serialize_datetime(key.last_used_at),
"expires_at": ApiKeyRoute._serialize_datetime(key.expires_at),
"revoked_at": ApiKeyRoute._serialize_datetime(key.revoked_at),
"is_revoked": key.revoked_at is not None,
"is_expired": bool(expires_at and expires_at < datetime.now(timezone.utc)),
}
async def _json_body(self):
return await request.json or {}
async def _run(self, operation):
try:
return self._ok(await operation())
except ApiKeyServiceError as exc:
return self._error(str(exc))
async def _run_json(self, operation):
payload = await self._json_body()
return await self._run(lambda: operation(payload))
async def list_api_keys(self):
keys = await self.db.list_api_keys()
return (
Response().ok(data=[self._serialize_api_key(key) for key in keys]).__dict__
)
return await self._run(self.service.list_api_keys)
async def create_api_key(self):
post_data = await request.json or {}
name = str(post_data.get("name", "")).strip() or "Untitled API Key"
scopes = post_data.get("scopes")
if scopes is None:
normalized_scopes = list(ALL_OPEN_API_SCOPES)
elif isinstance(scopes, list):
normalized_scopes = [
scope
for scope in scopes
if isinstance(scope, str) and scope in ALL_OPEN_API_SCOPES
]
normalized_scopes = list(dict.fromkeys(normalized_scopes))
if not normalized_scopes:
return Response().error("At least one valid scope is required").__dict__
else:
return Response().error("Invalid scopes").__dict__
expires_at = None
expires_in_days = post_data.get("expires_in_days")
if expires_in_days is not None:
try:
expires_in_days_int = int(expires_in_days)
except (TypeError, ValueError):
return Response().error("expires_in_days must be an integer").__dict__
if expires_in_days_int <= 0:
return (
Response().error("expires_in_days must be greater than 0").__dict__
)
expires_at = datetime.now(timezone.utc) + timedelta(
days=expires_in_days_int
return await self._run_json(
lambda payload: self.service.create_api_key_from_legacy_payload(
payload,
created_by=g.get("username", "unknown"),
)
raw_key = f"abk_{secrets.token_urlsafe(32)}"
key_hash = self._hash_key(raw_key)
key_prefix = raw_key[:12]
created_by = g.get("username", "unknown")
api_key = await self.db.create_api_key(
name=name,
key_hash=key_hash,
key_prefix=key_prefix,
scopes=normalized_scopes, # type: ignore
created_by=created_by,
expires_at=expires_at,
)
payload = self._serialize_api_key(api_key)
payload["api_key"] = raw_key
return Response().ok(data=payload).__dict__
async def revoke_api_key(self):
post_data = await request.json or {}
key_id = post_data.get("key_id")
if not key_id:
return Response().error("Missing key: key_id").__dict__
success = await self.db.revoke_api_key(key_id)
if not success:
return Response().error("API key not found").__dict__
return Response().ok().__dict__
return await self._run_json(self.service.revoke_api_key_from_legacy_payload)
async def delete_api_key(self):
post_data = await request.json or {}
key_id = post_data.get("key_id")
if not key_id:
return Response().error("Missing key: key_id").__dict__
success = await self.db.delete_api_key(key_id)
if not success:
return Response().error("API key not found").__dict__
return Response().ok().__dict__
return await self._run_json(self.service.delete_api_key_from_legacy_payload)

View File

@@ -1,73 +1,27 @@
import asyncio
import datetime
import os
import jwt
import pyotp
from quart import current_app, g, jsonify, make_response, request
from astrbot import logger
from astrbot.core import DEMO_MODE
from astrbot.core.utils.auth_password import (
is_default_dashboard_password,
is_legacy_dashboard_password,
validate_dashboard_password,
verify_dashboard_password,
from astrbot.dashboard.fastapi_compat import (
current_app,
g,
jsonify,
make_response,
request,
)
from astrbot.core.utils.totp import (
from astrbot.dashboard.services.auth_service import (
DASHBOARD_JWT_COOKIE_MAX_AGE,
DASHBOARD_JWT_COOKIE_NAME,
TOTP_TRUSTED_DEVICE_COOKIE_NAME,
TOTP_TRUSTED_DEVICE_MAX_AGE,
TwoFactorCodeType,
consume_configured_totp_code,
consume_rotation_verified,
consume_totp_code,
generate_recovery_code,
is_totp_enabled,
is_totp_trusted_device_valid,
issue_totp_trusted_device,
revoke_user_trusted_devices,
set_pending_totp_secret,
set_rotation_verified,
verify_configured_2fa_code,
)
from astrbot.dashboard.password_state import (
get_dashboard_password_hash,
is_password_change_required,
is_password_storage_upgraded,
set_dashboard_password_hashes,
set_password_change_required,
set_password_storage_upgraded,
AuthService,
AuthServiceResult,
)
from .route import Response, Route, RouteContext
DASHBOARD_JWT_COOKIE_NAME = "astrbot_dashboard_jwt"
DASHBOARD_JWT_COOKIE_MAX_AGE = 7 * 24 * 60 * 60
SKIP_DEFAULT_PASSWORD_AUTH_ENV = "ASTRBOT_DASHBOARD_SKIP_DEFAULT_PASSWORD_AUTH"
SKIP_DEFAULT_PASSWORD_AUTH_ENV_LEGACY = "DASHBOARD_SKIP_DEFAULT_PASSWORD_AUTH"
LOCAL_DASHBOARD_HOSTS = {"127.0.0.1", "localhost", "::1"}
DEFAULT_PASSWORD_LOGIN_FAILURE_MESSAGE = (
"Login failed. If this is your first time using AstrBot, the old default "
"astrbot password has been replaced by a random strong password printed in "
"the startup logs. Check the initial password in the logs and try again. "
"Learn more: https://docs.astrbot.app/en/faq.html\n\n"
"登录失败。如果您是初次使用,旧版默认 astrbot 密码已改为启动日志中输出的"
"随机强密码。请使用日志中提供的的初始密码来登录。了解更多:"
"https://docs.astrbot.app/faq.html"
)
LEGACY_PASSWORD_LOGIN_FAILURE_MESSAGE = (
"Incorrect username or password. If you cannot log in after upgrading "
"AstrBot even though the password is correct, see "
"https://docs.astrbot.app/en/faq.html\n\n"
"用户名或密码错误。如果你在升级 AstrBot 后遇到了密码正确但无法登录的情况,"
"请参考 https://docs.astrbot.app/faq.html"
)
__all__ = ("AuthRoute",)
class AuthRoute(Route):
def __init__(self, context: RouteContext, db) -> None:
super().__init__(context)
self.db = db
self.routes = {
"/auth/login": ("POST", self.login),
"/auth/logout": ("POST", self.logout),
@@ -78,261 +32,43 @@ class AuthRoute(Route):
"/auth/totp/recovery": ("POST", self.totp_recovery),
"/auth/account/edit": ("POST", self.edit_account),
}
self.service = AuthService(db, self.config)
self.register_routes()
async def setup_status(self):
return (
Response()
.ok(
{
"setup_required": await self._is_setup_required(),
"skip_default_password_auth": self._can_skip_default_password_auth(),
"password_upgrade_required": not await is_password_storage_upgraded(
self.db,
self.config,
),
}
)
.__dict__
async def _json_body(self):
return await request.json
async def _service_json_response(self, operation, *args, **kwargs):
return await self._service_response(
await operation(await self._json_body(), *args, **kwargs)
)
async def setup_status(self):
return await self._service_response(await self.service.setup_status())
async def totp_setup(self):
post_data = await request.json
if isinstance(post_data, dict) and post_data.get("secret"):
secret = post_data["secret"]
code = post_data.get("code")
if not isinstance(secret, str) or not secret.strip():
return Response().error("Invalid request payload").__dict__
if not isinstance(code, str) or not code.strip():
return Response().error("TOTP 验证码是必需的").__dict__
if not await consume_totp_code(secret, code):
return Response().error("TOTP 验证码无效").__dict__
if is_totp_enabled(self.config) and not consume_rotation_verified():
return Response().error("需要先验证当前 TOTP").__dict__
set_pending_totp_secret(secret)
recovery_code, recovery_code_hash = generate_recovery_code()
return (
Response()
.ok(
{
"recovery_code": recovery_code,
"recovery_code_hash": recovery_code_hash,
},
"TOTP verified",
)
.__dict__
)
if is_totp_enabled(self.config):
if not isinstance(post_data, dict):
return Response().error("Invalid request payload").__dict__
set_rotation_verified(False)
code = post_data.get("code")
if isinstance(code, str) and code.strip():
if await consume_configured_totp_code(self.config, code):
set_rotation_verified(True)
return Response().ok({"secret": pyotp.random_base32()}).__dict__
return Response().error("当前 TOTP 验证码无效").__dict__
return Response().error("需要提供 TOTP 验证码或新密钥").__dict__
return Response().ok({"secret": pyotp.random_base32()}).__dict__
return await self._service_json_response(self.service.totp_setup)
async def totp_recovery(self):
# This endpoint MUST NOT persist the recovery code.
recovery_code, recovery_code_hash = generate_recovery_code()
return (
Response()
.ok(
{
"recovery_code": recovery_code,
"recovery_code_hash": recovery_code_hash,
}
)
.__dict__
)
return await self._service_response(await self.service.totp_recovery())
async def setup(self):
if not self._can_skip_default_password_auth():
return Response().error("Setup without password is not enabled").__dict__
if not await self._is_setup_required():
return Response().error("Setup is not required").__dict__
return await self._complete_setup()
return await self._service_json_response(self.service.setup)
async def setup_authenticated(self):
if not await self._is_setup_required():
return Response().error("Setup is not required").__dict__
if not isinstance(getattr(g, "username", None), str):
return Response().error("未授权").__dict__
return await self._complete_setup()
async def _complete_setup(self):
post_data = await request.json
if not isinstance(post_data, dict):
return Response().error("Invalid request payload").__dict__
new_username = post_data.get("username")
new_password = post_data.get("password")
confirm_password = post_data.get("confirm_password")
if not isinstance(new_username, str) or len(new_username.strip()) < 3:
return Response().error("用户名长度至少3位").__dict__
if not isinstance(new_password, str):
return Response().error("新密码无效").__dict__
if not isinstance(confirm_password, str) or confirm_password != new_password:
return Response().error("两次输入的新密码不一致").__dict__
try:
validate_dashboard_password(new_password)
except ValueError as e:
return Response().error(str(e)).__dict__
username = new_username.strip()
self.config["dashboard"]["username"] = username
set_dashboard_password_hashes(self.config, new_password)
await set_password_storage_upgraded(self.db, self.config, True)
await set_password_change_required(self.db, self.config, False)
self.config.save_config()
token = self.generate_jwt(username)
payload = Response().ok(
{
"token": token,
"username": username,
"change_pwd_hint": False,
"legacy_pwd_hint": False,
"password_upgrade_required": False,
},
"Setup completed successfully",
return await self._service_json_response(
self.service.setup_authenticated,
getattr(g, "username", None),
)
response = await make_response(jsonify(payload.__dict__))
self._set_dashboard_jwt_cookie(response, token)
return response
async def login(self):
username = self.config["dashboard"]["username"]
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
password = get_dashboard_password_hash(self.config, upgraded=storage_upgraded)
post_data = await request.json
req_username = (
post_data.get("username") if isinstance(post_data, dict) else None
return await self._service_json_response(
self.service.login,
trusted_device_cookie_token=request.cookies.get(
TOTP_TRUSTED_DEVICE_COOKIE_NAME,
"",
).strip(),
)
req_password = (
post_data.get("password") if isinstance(post_data, dict) else None
)
totp_code = post_data.get("code") if isinstance(post_data, dict) else None
trust_device_flag = (
post_data.get("trust_device_flag") is True
if isinstance(post_data, dict)
else False
)
if not isinstance(req_username, str) or not isinstance(req_password, str):
return Response().error("Invalid request payload").__dict__
login_verified = req_username == username and verify_dashboard_password(
password, req_password
)
if not login_verified:
await asyncio.sleep(3)
if req_password == "astrbot":
return Response().error(DEFAULT_PASSWORD_LOGIN_FAILURE_MESSAGE).__dict__
if is_legacy_dashboard_password(password):
return Response().error(LEGACY_PASSWORD_LOGIN_FAILURE_MESSAGE).__dict__
return await self._error_response(
"用户名或密码错误",
401,
)
totp_verified = False
if is_totp_enabled(self.config):
cookie_token = request.cookies.get(
TOTP_TRUSTED_DEVICE_COOKIE_NAME, ""
).strip()
if not await is_totp_trusted_device_valid(
self.config, self.db, cookie_token
):
if not isinstance(totp_code, str) or not totp_code.strip():
response = await make_response(
jsonify(
{
"status": "error",
"message": "需要 TOTP 验证",
"data": {"totp_required": True},
}
)
)
response.status_code = 401
return response
verified_type = await verify_configured_2fa_code(
self.config, totp_code, allow_recovery=True
)
if verified_type is TwoFactorCodeType.TOTP:
totp_verified = True
elif verified_type is TwoFactorCodeType.RECOVERY:
self.config["dashboard"]["totp"] = {
"enable": False,
"secret": "",
"recovery_code_hash": "",
}
await revoke_user_trusted_devices(self.db)
self.config.save_config()
elif len(totp_code) == 6 and totp_code.isdigit():
return await self._error_response("TOTP 验证码无效", 401)
else:
return await self._error_response("恢复码无效", 401)
change_pwd_hint = False
legacy_pwd_hint = is_legacy_dashboard_password(password)
password_change_required = await is_password_change_required(
self.db,
self.config,
)
if (
storage_upgraded
and username == "astrbot"
and is_default_dashboard_password(password)
and not DEMO_MODE
):
change_pwd_hint = True
legacy_pwd_hint = True
logger.warning("为了保证安全,请尽快修改默认密码。")
if password_change_required and not DEMO_MODE:
change_pwd_hint = True
token = self.generate_jwt(username)
login_data = {
"token": token,
"username": username,
"change_pwd_hint": change_pwd_hint,
"legacy_pwd_hint": legacy_pwd_hint,
"password_upgrade_required": not storage_upgraded,
}
payload = Response().ok(login_data)
response = await make_response(jsonify(payload.__dict__))
self._set_dashboard_jwt_cookie(response, token)
if totp_verified and trust_device_flag:
raw_token = await issue_totp_trusted_device(self.config, self.db)
if raw_token:
response.set_cookie(
TOTP_TRUSTED_DEVICE_COOKIE_NAME,
raw_token,
max_age=TOTP_TRUSTED_DEVICE_MAX_AGE,
httponly=True,
samesite="Strict",
secure=AuthRoute._use_secure_dashboard_jwt_cookie(),
path="/api/auth",
)
return response
async def logout(self):
response = await make_response(
@@ -342,117 +78,38 @@ class AuthRoute(Route):
return response
async def edit_account(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
password = get_dashboard_password_hash(self.config, upgraded=storage_upgraded)
post_data = await request.json
if not isinstance(post_data, dict):
return Response().error("Invalid request payload").__dict__
req_password = post_data.get("password")
if not isinstance(req_password, str):
return Response().error("Invalid request payload").__dict__
if not verify_dashboard_password(password, req_password):
return Response().error("原密码错误").__dict__
new_pwd = post_data.get("new_password", None)
new_username = post_data.get("new_username", None)
password_change_required = await is_password_change_required(
self.db,
self.config,
)
if (not storage_upgraded or password_change_required) and not new_pwd:
return Response().error("请设置新密码以完成安全升级").__dict__
if not new_pwd and not new_username:
return Response().error("新用户名和新密码不能同时为空").__dict__
# Verify password confirmation
if new_pwd:
if not isinstance(new_pwd, str):
return Response().error("新密码无效").__dict__
confirm_pwd = post_data.get("confirm_password", None)
if not isinstance(confirm_pwd, str) or confirm_pwd != new_pwd:
return Response().error("两次输入的新密码不一致").__dict__
try:
validate_dashboard_password(new_pwd)
except ValueError as e:
return Response().error(str(e)).__dict__
set_dashboard_password_hashes(self.config, new_pwd)
await set_password_storage_upgraded(self.db, self.config, True)
await set_password_change_required(self.db, self.config, False)
if is_totp_enabled(self.config):
await revoke_user_trusted_devices(self.db)
if new_username:
self.config["dashboard"]["username"] = new_username
self.config.save_config()
return Response().ok(None, "Updated account successfully").__dict__
return await self._service_json_response(self.service.edit_account)
def generate_jwt(self, username):
payload = {
"username": username,
"exp": datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(days=7),
}
jwt_token = self.config["dashboard"].get("jwt_secret", None)
if not jwt_token:
raise ValueError("JWT secret is not set in the cmd_config.")
token = jwt.encode(payload, jwt_token, algorithm="HS256")
return token
return self.service.generate_jwt(username)
async def _is_setup_required(self) -> bool:
if DEMO_MODE:
return False
dashboard_config = self.config["dashboard"]
password_change_required = await is_password_change_required(
self.db,
self.config,
async def _service_response(self, result: AuthServiceResult):
payload = (
Response().error(result.message or "")
if result.status == "error"
else Response().ok(result.data, result.message)
)
if password_change_required:
return True
if result.status == "error" and result.data is not None:
payload.data = result.data
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
if not storage_upgraded:
return False
response = await make_response(jsonify(payload.__dict__))
response.status_code = result.status_code
return dashboard_config.get(
"username"
) == "astrbot" and is_default_dashboard_password(
dashboard_config.get("pbkdf2_password", "")
)
if result.jwt_token:
self._set_dashboard_jwt_cookie(response, result.jwt_token)
@staticmethod
async def _error_response(message: str, status_code: int):
response = await make_response(jsonify(Response().error(message).__dict__))
response.status_code = status_code
if result.trusted_device_token:
response.set_cookie(
TOTP_TRUSTED_DEVICE_COOKIE_NAME,
result.trusted_device_token,
max_age=TOTP_TRUSTED_DEVICE_MAX_AGE,
httponly=True,
samesite="Strict",
secure=AuthRoute._use_secure_dashboard_jwt_cookie(),
path="/api/auth",
)
return response
def _can_skip_default_password_auth(self) -> bool:
if not self._env_flag_enabled(SKIP_DEFAULT_PASSWORD_AUTH_ENV):
return False
host = (
os.environ.get("DASHBOARD_HOST")
or os.environ.get("ASTRBOT_DASHBOARD_HOST")
or self.config["dashboard"].get("host", "")
)
return str(host).strip().lower() in LOCAL_DASHBOARD_HOSTS
@staticmethod
def _env_flag_enabled(name: str) -> bool:
value = os.environ.get(name)
if value is None and name == SKIP_DEFAULT_PASSWORD_AUTH_ENV:
value = os.environ.get(SKIP_DEFAULT_PASSWORD_AUTH_ENV_LEGACY)
return str(value or "").strip().lower() in {"1", "true", "yes", "on"}
@staticmethod
def _use_secure_dashboard_jwt_cookie() -> bool:
return bool(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,9 @@
from quart import g, request
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.datetime_utils import to_utc_isoformat
from astrbot.dashboard.fastapi_compat import g, request
from astrbot.dashboard.services.chatui_project_service import (
ChatUIProjectService,
ChatUIProjectServiceError,
)
from .route import Response, Route, RouteContext
@@ -22,225 +24,96 @@ class ChatUIProjectRoute(Route):
),
"/chatui_project/get_sessions": ("GET", self.get_project_sessions),
}
self.db = db
self.service = ChatUIProjectService(db)
self.register_routes()
@staticmethod
def _username() -> str:
return g.get("username", "guest")
@staticmethod
def _service_error(exc: ChatUIProjectServiceError):
return Response().error(str(exc)).__dict__
@staticmethod
def _ok(data=None):
return Response().ok(data=data).__dict__
@staticmethod
async def _json_body() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(self, operation):
try:
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
return self._ok(result)
except ChatUIProjectServiceError as exc:
return self._service_error(exc)
async def _run_json(self, operation):
async def invoke():
data = await self._json_body()
return operation(data)
return await self._run(invoke)
async def create_project(self):
"""Create a new ChatUI project."""
username = g.get("username", "guest")
post_data = await request.json
title = post_data.get("title")
emoji = post_data.get("emoji", "📁")
description = post_data.get("description")
if not title:
return Response().error("Missing key: title").__dict__
project = await self.db.create_chatui_project(
creator=username,
title=title,
emoji=emoji,
description=description,
)
return (
Response()
.ok(
data={
"project_id": project.project_id,
"title": project.title,
"emoji": project.emoji,
"description": project.description,
"created_at": to_utc_isoformat(project.created_at),
"updated_at": to_utc_isoformat(project.updated_at),
}
)
.__dict__
return await self._run_json(
lambda data: self.service.create_project(self._username(), data)
)
async def list_projects(self):
"""Get all ChatUI projects for the current user."""
username = g.get("username", "guest")
projects = await self.db.get_chatui_projects_by_creator(creator=username)
projects_data = [
{
"project_id": project.project_id,
"title": project.title,
"emoji": project.emoji,
"description": project.description,
"created_at": to_utc_isoformat(project.created_at),
"updated_at": to_utc_isoformat(project.updated_at),
}
for project in projects
]
return Response().ok(data=projects_data).__dict__
return await self._run(lambda: self.service.list_projects(self._username()))
async def get_project(self):
"""Get a specific ChatUI project."""
project_id = request.args.get("project_id")
if not project_id:
return Response().error("Missing key: project_id").__dict__
username = g.get("username", "guest")
project = await self.db.get_chatui_project_by_id(project_id)
if not project:
return Response().error(f"Project {project_id} not found").__dict__
# Verify ownership
if project.creator != username:
return Response().error("Permission denied").__dict__
return (
Response()
.ok(
data={
"project_id": project.project_id,
"title": project.title,
"emoji": project.emoji,
"description": project.description,
"created_at": to_utc_isoformat(project.created_at),
"updated_at": to_utc_isoformat(project.updated_at),
}
return await self._run(
lambda: self.service.get_project_from_legacy_query(
self._username(),
request.args.get("project_id"),
)
.__dict__
)
async def update_chatui_project(self):
"""Update a ChatUI project."""
post_data = await request.json
project_id = post_data.get("project_id")
title = post_data.get("title")
emoji = post_data.get("emoji")
description = post_data.get("description")
if not project_id:
return Response().error("Missing key: project_id").__dict__
username = g.get("username", "guest")
# Verify ownership
project = await self.db.get_chatui_project_by_id(project_id)
if not project:
return Response().error(f"Project {project_id} not found").__dict__
if project.creator != username:
return Response().error("Permission denied").__dict__
await self.db.update_chatui_project(
project_id=project_id,
title=title,
emoji=emoji,
description=description,
return await self._run_json(
lambda data: self.service.update_project(self._username(), data)
)
return Response().ok().__dict__
async def delete_project(self):
"""Delete a ChatUI project."""
project_id = request.args.get("project_id")
if not project_id:
return Response().error("Missing key: project_id").__dict__
username = g.get("username", "guest")
# Verify ownership
project = await self.db.get_chatui_project_by_id(project_id)
if not project:
return Response().error(f"Project {project_id} not found").__dict__
if project.creator != username:
return Response().error("Permission denied").__dict__
await self.db.delete_chatui_project(project_id)
return Response().ok().__dict__
return await self._run(
lambda: self.service.delete_project_from_legacy_query(
self._username(),
request.args.get("project_id"),
)
)
async def add_session_to_project(self):
"""Add a session to a project."""
post_data = await request.json
session_id = post_data.get("session_id")
project_id = post_data.get("project_id")
if not session_id:
return Response().error("Missing key: session_id").__dict__
if not project_id:
return Response().error("Missing key: project_id").__dict__
username = g.get("username", "guest")
# Verify project ownership
project = await self.db.get_chatui_project_by_id(project_id)
if not project:
return Response().error(f"Project {project_id} not found").__dict__
if project.creator != username:
return Response().error("Permission denied").__dict__
# Verify session ownership
session = await self.db.get_platform_session_by_id(session_id)
if not session:
return Response().error(f"Session {session_id} not found").__dict__
if session.creator != username:
return Response().error("Permission denied").__dict__
await self.db.add_session_to_project(session_id, project_id)
return Response().ok().__dict__
return await self._run_json(
lambda data: self.service.add_session_to_project(self._username(), data)
)
async def remove_session_from_project(self):
"""Remove a session from its project."""
post_data = await request.json
session_id = post_data.get("session_id")
if not session_id:
return Response().error("Missing key: session_id").__dict__
username = g.get("username", "guest")
# Verify session ownership
session = await self.db.get_platform_session_by_id(session_id)
if not session:
return Response().error(f"Session {session_id} not found").__dict__
if session.creator != username:
return Response().error("Permission denied").__dict__
await self.db.remove_session_from_project(session_id)
return Response().ok().__dict__
return await self._run_json(
lambda data: self.service.remove_session_from_project(
self._username(),
data,
)
)
async def get_project_sessions(self):
"""Get all sessions in a project."""
project_id = request.args.get("project_id")
if not project_id:
return Response().error("Missing key: project_id").__dict__
username = g.get("username", "guest")
# Verify project ownership
project = await self.db.get_chatui_project_by_id(project_id)
if not project:
return Response().error(f"Project {project_id} not found").__dict__
if project.creator != username:
return Response().error("Permission denied").__dict__
sessions = await self.db.get_project_sessions(project_id)
sessions_data = [
{
"session_id": session.session_id,
"platform_id": session.platform_id,
"creator": session.creator,
"display_name": session.display_name,
"is_group": session.is_group,
"created_at": to_utc_isoformat(session.created_at),
"updated_at": to_utc_isoformat(session.updated_at),
}
for session in sessions
]
return Response().ok(data=sessions_data).__dict__
return await self._run(
lambda: self.service.get_project_sessions_from_legacy_query(
self._username(),
request.args.get("project_id"),
)
)

View File

@@ -1,17 +1,7 @@
from quart import request
from astrbot.core.star.command_management import (
list_command_conflicts,
list_commands,
)
from astrbot.core.star.command_management import (
rename_command as rename_command_service,
)
from astrbot.core.star.command_management import (
toggle_command as toggle_command_service,
)
from astrbot.core.star.command_management import (
update_command_permission as update_command_permission_service,
from astrbot.dashboard.fastapi_compat import request
from astrbot.dashboard.services.command_service import (
CommandService,
CommandServiceError,
)
from .route import Response, Route, RouteContext
@@ -20,7 +10,7 @@ from .route import Response, Route, RouteContext
class CommandRoute(Route):
def __init__(self, context: RouteContext, core_lifecycle=None) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.service = CommandService(self.config, core_lifecycle)
self.routes = {
"/commands": ("GET", self.get_commands),
"/commands/conflicts": ("GET", self.get_conflicts),
@@ -30,88 +20,50 @@ class CommandRoute(Route):
}
self.register_routes()
@staticmethod
def _ok(data=None):
return Response().ok(data).__dict__
@staticmethod
def _error(message: str):
return Response().error(message).__dict__
@staticmethod
async def _json_body() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(self, operation):
try:
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
return self._ok(result)
except CommandServiceError as exc:
return self._error(str(exc))
async def _run_json(self, operation):
async def invoke():
data = await self._json_body()
return operation(data)
return await self._run(invoke)
async def get_commands(self):
commands = await list_commands()
summary = {
"total": len(commands),
"disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
"conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
}
# 优先从指定 config_id 的配置中读取唤醒词,否则使用默认配置
config_id = request.args.get("config_id", "").strip()
wake_prefix = self.config.get("wake_prefix", ["/"])
if config_id and self.core_lifecycle:
acm = getattr(self.core_lifecycle, "astrbot_config_mgr", None)
if acm and config_id in acm.confs:
wake_prefix = acm.confs[config_id].get("wake_prefix", wake_prefix)
return (
Response()
.ok({"items": commands, "summary": summary, "wake_prefix": wake_prefix})
.__dict__
return await self._run(
self.service.list_commands_from_legacy_query(
request.args.get("config_id", "")
)
)
async def get_conflicts(self):
conflicts = await list_command_conflicts()
return Response().ok(conflicts).__dict__
return await self._run(self.service.list_conflicts())
async def toggle_command(self):
data = await request.get_json()
handler_full_name = data.get("handler_full_name")
enabled = data.get("enabled")
if handler_full_name is None or enabled is None:
return Response().error("handler_full_name 与 enabled 均为必填。").__dict__
if isinstance(enabled, str):
enabled = enabled.lower() in ("1", "true", "yes", "on")
try:
await toggle_command_service(handler_full_name, bool(enabled))
except ValueError as exc:
return Response().error(str(exc)).__dict__
payload = await _get_command_payload(handler_full_name)
return Response().ok(payload).__dict__
return await self._run_json(self.service.toggle_command_from_legacy_payload)
async def rename_command(self):
data = await request.get_json()
handler_full_name = data.get("handler_full_name")
new_name = data.get("new_name")
aliases = data.get("aliases")
if not handler_full_name or not new_name:
return Response().error("handler_full_name 与 new_name 均为必填。").__dict__
try:
await rename_command_service(handler_full_name, new_name, aliases=aliases)
except ValueError as exc:
return Response().error(str(exc)).__dict__
payload = await _get_command_payload(handler_full_name)
return Response().ok(payload).__dict__
return await self._run_json(self.service.rename_command_from_legacy_payload)
async def update_permission(self):
data = await request.get_json()
handler_full_name = data.get("handler_full_name")
permission = data.get("permission")
if not handler_full_name or not permission:
return (
Response().error("handler_full_name 与 permission 均为必填。").__dict__
)
try:
await update_command_permission_service(handler_full_name, permission)
except ValueError as exc:
return Response().error(str(exc)).__dict__
payload = await _get_command_payload(handler_full_name)
return Response().ok(payload).__dict__
async def _get_command_payload(handler_full_name: str):
commands = await list_commands()
for cmd in commands:
if cmd["handler_full_name"] == handler_full_name:
return cmd
return {}
return await self._run_json(self.service.update_permission_from_legacy_payload)

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +1,11 @@
import json
import traceback
from dataclasses import asdict
from datetime import datetime
from io import BytesIO
from quart import request, send_file
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.umo_alias import build_umo_alias_map, parse_umo, serialize_umo_alias
from astrbot.dashboard.fastapi_compat import request, send_file
from astrbot.dashboard.services.conversation_service import (
ConversationService,
ConversationServiceError,
)
from .route import Response, Route, RouteContext
@@ -24,379 +20,102 @@ class ConversationRoute(Route):
super().__init__(context)
self.routes = {
"/conversation/list": ("GET", self.list_conversations),
"/conversation/detail": (
"POST",
self.get_conv_detail,
),
"/conversation/detail": ("POST", self.get_conv_detail),
"/conversation/update": ("POST", self.upd_conv),
"/conversation/delete": ("POST", self.del_conv),
"/conversation/update_history": (
"POST",
self.update_history,
),
"/conversation/update_history": ("POST", self.update_history),
"/conversation/export": ("POST", self.export_conversations),
}
self.db_helper = db_helper
self.conv_mgr = core_lifecycle.conversation_manager
self.core_lifecycle = core_lifecycle
self.service = ConversationService(db_helper, core_lifecycle)
self.register_routes()
def _build_umo_info(self, umo: str | None, alias_map: dict) -> dict:
umo_str = umo or ""
return {
"umo": umo_str,
**parse_umo(umo_str),
**serialize_umo_alias(alias_map.get(umo_str), umo_str),
}
@staticmethod
def _error(message: str):
return Response().error(message).__dict__
def _serialize_conversation(self, conversation, alias_map: dict) -> dict:
return {
**asdict(conversation),
"umo_info": self._build_umo_info(conversation.user_id, alias_map),
}
@staticmethod
def _ok(data=None):
return Response().ok(data).__dict__
@staticmethod
async def _json_body() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(self, operation, *, label: str):
try:
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
return self._ok(result)
except ConversationServiceError as exc:
return self._error(str(exc))
except Exception as exc:
logger.error("%s: %s", label, exc, exc_info=True)
return self._error(f"{label}: {exc!s}")
async def _run_json(self, operation, *, label: str):
async def invoke():
data = await self._json_body()
return operation(data)
return await self._run(invoke, label=label)
async def list_conversations(self):
"""获取对话列表,支持分页、排序和筛选"""
try:
# 获取分页参数
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
# 获取筛选参数
platforms = request.args.get("platforms", "")
message_types = request.args.get("message_types", "")
search_query = request.args.get("search", "")
exclude_ids = request.args.get("exclude_ids", "")
exclude_platforms = request.args.get("exclude_platforms", "")
# 转换为列表
platform_list = platforms.split(",") if platforms else []
message_type_list = message_types.split(",") if message_types else []
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
exclude_platform_list = (
exclude_platforms.split(",") if exclude_platforms else []
)
page = max(page, 1)
if page_size < 1:
page_size = 20
page_size = min(page_size, 100)
try:
(
conversations,
total_count,
) = await self.conv_mgr.get_filtered_conversations(
page=page,
page_size=page_size,
platforms=platform_list,
message_types=message_type_list,
search_query=search_query,
exclude_ids=exclude_id_list,
exclude_platforms=exclude_platform_list,
)
except Exception as e:
logger.error(f"数据库查询出错: {e!s}\n{traceback.format_exc()}")
return Response().error(f"数据库查询出错: {e!s}").__dict__
# 计算总页数
total_pages = (
(total_count + page_size - 1) // page_size if total_count > 0 else 1
)
umos = sorted({conv.user_id for conv in conversations if conv.user_id})
alias_map = build_umo_alias_map(await self.db_helper.get_umo_aliases(umos))
result = {
"conversations": [
self._serialize_conversation(conversation, alias_map)
for conversation in conversations
],
"pagination": {
"page": page,
"page_size": page_size,
"total": total_count,
"total_pages": total_pages,
},
}
return Response().ok(result).__dict__
except Exception as e:
error_msg = f"获取对话列表失败: {e!s}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"获取对话列表失败: {e!s}").__dict__
return await self._run(
self.service.list_conversations_from_legacy_query(
page=request.args.get("page", 1),
page_size=request.args.get("page_size", 20),
platforms=request.args.get("platforms", ""),
message_types=request.args.get("message_types", ""),
search_query=request.args.get("search", ""),
exclude_ids=request.args.get("exclude_ids", ""),
exclude_platforms=request.args.get("exclude_platforms", ""),
),
label="获取对话列表失败",
)
async def get_conv_detail(self):
"""获取指定对话详情通过POST请求"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
return Response().error("对话不存在").__dict__
alias_map = build_umo_alias_map(
await self.db_helper.get_umo_aliases([user_id])
)
return (
Response()
.ok(
{
"user_id": user_id,
"cid": cid,
"title": conversation.title,
"persona_id": conversation.persona_id,
"history": conversation.history,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
"umo_info": self._build_umo_info(user_id, alias_map),
},
)
.__dict__
)
except Exception as e:
logger.error(f"获取对话详情失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"获取对话详情失败: {e!s}").__dict__
return await self._run_json(
self.service.get_conversation_detail,
label="获取对话详情失败",
)
async def upd_conv(self):
"""更新对话信息(标题和角色ID)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
title = data.get("title")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
return Response().error("对话不存在").__dict__
persona_id = data.get("persona_id", conversation.persona_id)
if title is not None or persona_id is not None:
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
title=title,
persona_id=persona_id,
)
return Response().ok({"message": "对话信息更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话信息失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新对话信息失败: {e!s}").__dict__
return await self._run_json(
self.service.update_conversation,
label="更新对话信息失败",
)
async def del_conv(self):
"""删除对话"""
try:
data = await request.get_json()
# 检查是否是批量删除
if "conversations" in data:
# 批量删除
conversations = data.get("conversations", [])
if not conversations:
return (
Response().error("批量删除时conversations参数不能为空").__dict__
)
deleted_count = 0
failed_items = []
for conv in conversations:
user_id = conv.get("user_id")
cid = conv.get("cid")
if not user_id or not cid:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
)
continue
try:
await self.core_lifecycle.conversation_manager.delete_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
deleted_count += 1
except Exception as e:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
message = f"成功删除 {deleted_count} 个对话"
if failed_items:
message += f",失败 {len(failed_items)}"
return (
Response()
.ok(
{
"message": message,
"deleted_count": deleted_count,
"failed_count": len(failed_items),
"failed_items": failed_items,
},
)
.__dict__
)
# 单个删除
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
await self.core_lifecycle.conversation_manager.delete_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
return Response().ok({"message": "对话删除成功"}).__dict__
except Exception as e:
logger.error(f"删除对话失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"删除对话失败: {e!s}").__dict__
return await self._run_json(
self.service.delete_conversation,
label="删除对话失败",
)
async def update_history(self):
"""更新对话历史内容"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
history = data.get("history")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
if history is None:
return Response().error("缺少必要参数: history").__dict__
# 历史记录必须是合法的 JSON 字符串
try:
if isinstance(history, list):
history = json.dumps(history)
else:
# 验证是否为有效的 JSON 字符串
json.loads(history)
except json.JSONDecodeError:
return (
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
)
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
return Response().error("对话不存在").__dict__
history = json.loads(history) if isinstance(history, str) else history
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
history=history,
)
return Response().ok({"message": "对话历史更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {e!s}").__dict__
return await self._run_json(
self.service.update_history,
label="更新对话历史失败",
)
async def export_conversations(self):
"""批量导出对话为 JSONL 格式"""
try:
data = await request.get_json()
conversations_to_export = data.get("conversations", [])
if not conversations_to_export:
return Response().error("导出列表不能为空").__dict__
# 收集所有对话的内容
jsonl_lines = []
exported_count = 0
failed_items = []
for conv_info in conversations_to_export:
user_id = conv_info.get("user_id")
cid = conv_info.get("cid")
if not user_id or not cid:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
)
continue
try:
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 对话不存在"
)
continue
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
content = json.loads(conversation.history)
# 创建导出记录
export_record = {
"cid": cid,
"user_id": user_id,
"platform_id": conversation.platform_id,
"title": conversation.title,
"persona_id": conversation.persona_id,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
"content": content,
}
# 将记录转换为 JSON 字符串并添加到 JSONL
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
exported_count += 1
except Exception as e:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
logger.error(
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
)
if exported_count == 0:
return Response().error("没有成功导出任何对话").__dict__
# 创建 JSONL 内容
jsonl_content = "\n".join(jsonl_lines)
# 创建一个内存文件对象
file_obj = BytesIO(jsonl_content.encode("utf-8"))
file_obj.seek(0)
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
# 返回文件流
export = await self.service.export_conversations(await self._json_body())
return await send_file(
file_obj,
mimetype="application/jsonl",
export.file_obj,
mimetype=export.mimetype,
as_attachment=True,
attachment_filename=filename,
attachment_filename=export.filename,
)
except Exception as e:
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"批量导出对话失败: {e!s}").__dict__
except ConversationServiceError as exc:
return self._error(str(exc))
except Exception as exc:
logger.error("批量导出对话失败: %s", exc, exc_info=True)
return self._error(f"批量导出对话失败: {exc!s}")

View File

@@ -1,11 +1,6 @@
import asyncio
import traceback
from datetime import datetime, timezone
from quart import jsonify, request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.dashboard.fastapi_compat import jsonify, request
from astrbot.dashboard.services.cron_service import CronService, CronServiceError
from .route import Response, Route, RouteContext
@@ -15,8 +10,7 @@ class CronRoute(Route):
self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self._background_tasks: set[asyncio.Task] = set()
self.service = CronService(core_lifecycle)
self.routes = [
("/cron/jobs", ("GET", self.list_jobs)),
("/cron/jobs", ("POST", self.create_job)),
@@ -26,276 +20,50 @@ class CronRoute(Route):
]
self.register_routes()
def _serialize_job(self, job) -> dict:
data = job.model_dump() if hasattr(job, "model_dump") else job.__dict__
for k in ["created_at", "updated_at", "last_run_at", "next_run_time"]:
v = data.get(k)
if isinstance(v, datetime):
# Attach UTC
if v.tzinfo is None:
v = v.replace(tzinfo=timezone.utc)
data[k] = v.isoformat()
# expose note explicitly for UI (prefer payload.note then description)
payload = data.get("payload") or {}
data["note"] = payload.get("note") or data.get("description") or ""
data["run_at"] = payload.get("run_at")
data["run_once"] = data.get("run_once", False)
# status is internal; hide to avoid implying one-time completion for recurring jobs
data.pop("status", None)
return data
@staticmethod
def _ok(data=None, message: str | None = None):
return jsonify(Response().ok(data=data, message=message).__dict__)
@staticmethod
def _error(message: str):
return jsonify(Response().error(message).__dict__)
@staticmethod
async def _json_body() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(self, operation, *, message: str | None = None):
try:
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
return self._ok(result, message)
except CronServiceError as exc:
return self._error(str(exc))
async def _run_json(self, operation, *, message: str | None = None):
async def invoke():
data = await self._json_body()
return operation(data)
return await self._run(invoke, message=message)
async def list_jobs(self):
try:
cron_mgr = self.core_lifecycle.cron_manager
if cron_mgr is None:
return jsonify(
Response().error("Cron manager not initialized").__dict__
)
job_type = request.args.get("type")
jobs = await cron_mgr.list_jobs(job_type)
data = [self._serialize_job(j) for j in jobs]
return jsonify(Response().ok(data=data).__dict__)
except Exception as e: # noqa: BLE001
logger.error(traceback.format_exc())
return jsonify(Response().error(f"Failed to list jobs: {e!s}").__dict__)
return await self._run(
self.service.list_jobs_from_legacy_query(request.args.get("type"))
)
async def create_job(self):
try:
cron_mgr = self.core_lifecycle.cron_manager
if cron_mgr is None:
return jsonify(
Response().error("Cron manager not initialized").__dict__
)
payload = await request.json
if not isinstance(payload, dict):
return jsonify(Response().error("Invalid payload").__dict__)
name = payload.get("name") or "active_agent_task"
cron_expression = payload.get("cron_expression")
note = payload.get("note") or payload.get("description") or name
session = str(payload.get("session") or "").strip()
persona_id = payload.get("persona_id")
provider_id = payload.get("provider_id")
timezone = payload.get("timezone")
enabled = bool(payload.get("enabled", True))
run_once = bool(payload.get("run_once", False))
run_at = payload.get("run_at")
if run_once and not run_at:
return jsonify(
Response().error("run_at is required when run_once=true").__dict__
)
if (not run_once) and not cron_expression:
return jsonify(
Response()
.error("cron_expression is required when run_once=false")
.__dict__
)
if run_once and cron_expression:
cron_expression = None # ignore cron when run_once specified
run_at_dt = None
if run_at:
try:
run_at_dt = datetime.fromisoformat(str(run_at))
except Exception:
return jsonify(
Response().error("run_at must be ISO datetime").__dict__
)
job_payload = {
"session": session,
"note": note,
"persona_id": persona_id,
"provider_id": provider_id,
"run_at": run_at,
"origin": "api",
}
job = await cron_mgr.add_active_job(
name=name,
cron_expression=cron_expression,
payload=job_payload,
description=note,
timezone=timezone,
enabled=enabled,
run_once=run_once,
run_at=run_at_dt,
)
return jsonify(Response().ok(data=self._serialize_job(job)).__dict__)
except Exception as e: # noqa: BLE001
logger.error(traceback.format_exc())
return jsonify(Response().error(f"Failed to create job: {e!s}").__dict__)
return await self._run_json(self.service.create_job)
async def update_job(self, job_id: str):
try:
cron_mgr = self.core_lifecycle.cron_manager
if cron_mgr is None:
return jsonify(
Response().error("Cron manager not initialized").__dict__
)
payload = await request.json
if not isinstance(payload, dict):
return jsonify(Response().error("Invalid payload").__dict__)
job = await cron_mgr.db.get_cron_job(job_id)
if not job:
return jsonify(Response().error("Job not found").__dict__)
updates = {}
if "name" in payload:
name = str(payload.get("name") or "").strip()
if not name:
return jsonify(Response().error("name cannot be empty").__dict__)
updates["name"] = name
if "enabled" in payload:
updates["enabled"] = bool(payload.get("enabled"))
if "timezone" in payload:
timezone = payload.get("timezone")
updates["timezone"] = str(timezone).strip() or None
next_run_once = (
bool(payload.get("run_once"))
if "run_once" in payload
else bool(job.run_once)
)
if job.job_type == "active_agent":
merged_payload = (
dict(job.payload) if isinstance(job.payload, dict) else {}
)
if "payload" in payload and isinstance(payload.get("payload"), dict):
merged_payload.update(payload["payload"])
if "session" in payload:
session = str(payload.get("session") or "").strip()
if session:
merged_payload["session"] = session
else:
merged_payload.pop("session", None)
note_updated = False
if "note" in payload:
note = str(payload.get("note") or "").strip()
if not note:
return jsonify(
Response().error("note cannot be empty").__dict__
)
merged_payload["note"] = note
updates["description"] = note
note_updated = True
elif "description" in payload:
description = str(payload.get("description") or "").strip()
if not description:
return jsonify(
Response().error("description cannot be empty").__dict__
)
updates["description"] = description
merged_payload["note"] = description
note_updated = True
if not note_updated and updates.get("description") is None:
existing_note = str(
merged_payload.get("note") or job.description or ""
).strip()
if existing_note:
merged_payload["note"] = existing_note
next_cron_expression = (
payload.get("cron_expression")
if "cron_expression" in payload
else job.cron_expression
)
if next_cron_expression is not None:
next_cron_expression = str(next_cron_expression).strip() or None
run_at_raw = (
payload.get("run_at")
if "run_at" in payload
else merged_payload.get("run_at")
)
run_at_iso = None
if run_at_raw:
try:
run_at_iso = datetime.fromisoformat(str(run_at_raw)).isoformat()
except Exception:
return jsonify(
Response().error("run_at must be ISO datetime").__dict__
)
if next_run_once:
if not run_at_iso:
return jsonify(
Response()
.error("run_at is required when run_once=true")
.__dict__
)
next_cron_expression = None
merged_payload["run_at"] = run_at_iso
else:
if not next_cron_expression:
return jsonify(
Response()
.error("cron_expression is required when run_once=false")
.__dict__
)
merged_payload.pop("run_at", None)
updates["run_once"] = next_run_once
updates["cron_expression"] = next_cron_expression
updates["payload"] = merged_payload
else:
if "cron_expression" in payload:
cron_expression = str(payload.get("cron_expression") or "").strip()
if not cron_expression:
return jsonify(
Response().error("cron_expression cannot be empty").__dict__
)
updates["cron_expression"] = cron_expression
if "description" in payload:
description = str(payload.get("description") or "").strip()
updates["description"] = description or None
job = await cron_mgr.update_job(job_id, **updates)
if not job:
return jsonify(Response().error("Job not found").__dict__)
return jsonify(Response().ok(data=self._serialize_job(job)).__dict__)
except Exception as e: # noqa: BLE001
logger.error(traceback.format_exc())
return jsonify(Response().error(f"Failed to update job: {e!s}").__dict__)
return await self._run_json(
lambda payload: self.service.update_job(job_id, payload)
)
async def delete_job(self, job_id: str):
try:
cron_mgr = self.core_lifecycle.cron_manager
if cron_mgr is None:
return jsonify(
Response().error("Cron manager not initialized").__dict__
)
await cron_mgr.delete_job(job_id)
return jsonify(Response().ok(message="deleted").__dict__)
except Exception as e: # noqa: BLE001
logger.error(traceback.format_exc())
return jsonify(Response().error(f"Failed to delete job: {e!s}").__dict__)
return await self._run(self.service.delete_job(job_id), message="deleted")
async def run_job_now(self, job_id: str):
try:
cron_mgr = self.core_lifecycle.cron_manager
if cron_mgr is None:
return jsonify(
Response().error("Cron manager not initialized").__dict__
)
job = await cron_mgr.db.get_cron_job(job_id)
if not job:
return jsonify(Response().error("Job not found").__dict__)
task = asyncio.create_task(cron_mgr.run_job_now(job_id))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return jsonify(Response().ok(message="started").__dict__)
except Exception as e: # noqa: BLE001
logger.error(traceback.format_exc())
return jsonify(Response().error(f"Failed to run job: {e!s}").__dict__)
return await self._run(self.service.run_job_now(job_id), message="started")

View File

@@ -1,6 +1,5 @@
from quart import abort, send_file
from astrbot.core import file_token_service
from astrbot.dashboard.fastapi_compat import abort, send_file
from astrbot.dashboard.services.file_service import FileService, FileServiceError
from .route import Route, RouteContext
@@ -11,6 +10,7 @@ class FileRoute(Route):
context: RouteContext,
) -> None:
super().__init__(context)
self.service = FileService()
self.routes = {
"/file/<file_token>": ("GET", self.serve_file),
}
@@ -18,7 +18,7 @@ class FileRoute(Route):
async def serve_file(self, file_token: str):
try:
file_path = await file_token_service.handle_file(file_token)
file_path = await self.service.resolve_token_file(file_token)
return await send_file(file_path)
except (FileNotFoundError, KeyError):
except FileServiceError:
return abort(404)

File diff suppressed because it is too large Load Diff

View File

@@ -1,115 +1,12 @@
import asyncio
import json
import os
import re
import time
import uuid
import wave
from typing import Any
import jwt
from quart import websocket
from astrbot import logger
from astrbot.core import sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_webchat_message_parts,
create_attachment_part_from_existing_file,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from astrbot.core.utils.datetime_utils import to_utc_isoformat
from astrbot.dashboard.fastapi_compat import websocket
from astrbot.dashboard.services.live_chat_service import LiveChatService
from .chat import (
BotMessageAccumulator,
build_bot_history_content,
collect_plain_text_from_message_parts,
)
from .route import Route, RouteContext
class LiveChatSession:
"""Live Chat 会话管理器"""
def __init__(self, session_id: str, username: str) -> None:
self.session_id = session_id
self.username = username
self.conversation_id = str(uuid.uuid4())
self.is_speaking = False
self.is_processing = False
self.should_interrupt = False
self.audio_frames: list[bytes] = []
self.current_stamp: str | None = None
self.temp_audio_path: str | None = None
self.chat_subscriptions: dict[str, str] = {}
self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
self.ws_send_lock = asyncio.Lock()
def start_speaking(self, stamp: str) -> None:
"""开始说话"""
self.is_speaking = True
self.current_stamp = stamp
self.audio_frames = []
logger.debug(f"[Live Chat] {self.username} 开始说话 stamp={stamp}")
def add_audio_frame(self, data: bytes) -> None:
"""添加音频帧"""
if self.is_speaking:
self.audio_frames.append(data)
async def end_speaking(self, stamp: str) -> tuple[str | None, float]:
"""结束说话,返回组装的 WAV 文件路径和耗时"""
start_time = time.time()
if not self.is_speaking or stamp != self.current_stamp:
logger.warning(
f"[Live Chat] stamp 不匹配或未在说话状态: {stamp} vs {self.current_stamp}"
)
return None, 0.0
self.is_speaking = False
if not self.audio_frames:
logger.warning("[Live Chat] 没有音频帧数据")
return None, 0.0
# 组装 WAV 文件
try:
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav")
# 假设前端发送的是 PCM 数据,采样率 16000Hz单声道16位
with wave.open(audio_path, "wb") as wav_file:
wav_file.setnchannels(1) # 单声道
wav_file.setsampwidth(2) # 16位 = 2字节
wav_file.setframerate(16000) # 采样率 16000Hz
for frame in self.audio_frames:
wav_file.writeframes(frame)
self.temp_audio_path = audio_path
logger.info(
f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes"
)
return audio_path, time.time() - start_time
except Exception as e:
logger.error(f"[Live Chat] 组装 WAV 文件失败: {e}", exc_info=True)
return None, 0.0
def cleanup(self) -> None:
"""清理临时文件"""
if self.temp_audio_path and os.path.exists(self.temp_audio_path):
try:
os.remove(self.temp_audio_path)
logger.debug(f"[Live Chat] 已删除临时文件: {self.temp_audio_path}")
except Exception as e:
logger.warning(f"[Live Chat] 删除临时文件失败: {e}")
self.temp_audio_path = None
class LiveChatRoute(Route):
"""Live Chat WebSocket 路由"""
@@ -120,16 +17,9 @@ class LiveChatRoute(Route):
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.db = db
self.plugin_manager = core_lifecycle.plugin_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.sessions: dict[str, LiveChatSession] = {}
self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
os.makedirs(self.attachments_dir, exist_ok=True)
self.service = LiveChatService(db, core_lifecycle)
self.sessions = self.service.sessions
# 注册 WebSocket 路由
self.app.websocket("/api/live_chat/ws")(self.live_chat_ws)
self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws)
@@ -142,819 +32,13 @@ class LiveChatRoute(Route):
await self._unified_ws_loop(force_ct=None)
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
"""统一 WebSocket 循环"""
# WebSocket 不能通过 header 传递 token需要从 query 参数获取
# 注意WebSocket 上下文使用 websocket.args 而不是 request.args
token = websocket.args.get("token")
if not token:
await websocket.close(1008, "Missing authentication token")
return
try:
jwt_secret = self.config["dashboard"].get("jwt_secret")
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
username = payload["username"]
except jwt.ExpiredSignatureError:
await websocket.close(1008, "Token expired")
return
except jwt.InvalidTokenError:
await websocket.close(1008, "Invalid token")
return
session_id = f"webchat_live!{username}!{uuid.uuid4()}"
live_session = LiveChatSession(session_id, username)
self.sessions[session_id] = live_session
logger.info(f"[Live Chat] WebSocket 连接建立: {username}")
try:
while True:
message = await websocket.receive_json()
ct = force_ct or message.get("ct", "live")
if ct == "chat":
await self._handle_chat_message(live_session, message)
else:
await self._handle_message(live_session, message)
except Exception as e:
logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True)
finally:
# 清理会话
if session_id in self.sessions:
await self._cleanup_chat_subscriptions(live_session)
live_session.cleanup()
del self.sessions[session_id]
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
async def _create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
"""从本地文件创建 attachment 并返回消息部分。"""
return await create_attachment_part_from_existing_file(
filename,
attach_type=attach_type,
insert_attachment=self.db.insert_attachment,
attachments_dir=self.attachments_dir,
fallback_dirs=[self.legacy_img_dir],
await self.service.run_websocket_session(
token=websocket.args.get("token"),
force_ct=force_ct,
receive_json=websocket.receive_json,
send_json=websocket.send_json,
close=websocket.close,
)
def _extract_web_search_refs(
self, accumulated_text: str, accumulated_parts: list
) -> dict:
"""从消息中提取 web_search 引用。"""
supported = [
"web_search_baidu",
"web_search_tavily",
"web_search_bocha",
"web_search_brave",
]
web_search_results = {}
tool_call_parts = [
p
for p in accumulated_parts
if p.get("type") == "tool_call" and p.get("tool_calls")
]
for part in tool_call_parts:
for tool_call in part["tool_calls"]:
if tool_call.get("name") not in supported or not tool_call.get(
"result"
):
continue
try:
result_data = json.loads(tool_call["result"])
for item in result_data.get("results", []):
if idx := item.get("index"):
web_search_results[idx] = {
"url": item.get("url"),
"title": item.get("title"),
"snippet": item.get("snippet"),
}
except (json.JSONDecodeError, KeyError):
pass
if not web_search_results:
return {}
ref_indices = {
m.strip() for m in re.findall(r"<ref>(.*?)</ref>", accumulated_text)
}
used_refs = []
for ref_index in ref_indices:
if ref_index not in web_search_results:
continue
payload = {"index": ref_index, **web_search_results[ref_index]}
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
payload["favicon"] = favicon
used_refs.append(payload)
return {"used": used_refs} if used_refs else {}
async def _save_bot_message(
self,
webchat_conv_id: str,
message_parts: list[dict],
agent_stats: dict,
refs: dict,
llm_checkpoint_id: str | None = None,
):
"""保存 bot 消息到历史记录。"""
new_his = build_bot_history_content(
message_parts,
agent_stats=agent_stats,
refs=refs,
)
return await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
llm_checkpoint_id=llm_checkpoint_id,
)
async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None:
async with session.ws_send_lock:
await websocket.send_json(payload)
async def _forward_chat_subscription(
self,
session: LiveChatSession,
chat_session_id: str,
request_id: str,
) -> None:
back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id, chat_session_id
)
try:
while True:
result = await back_queue.get()
if not result:
continue
await self._send_chat_payload(session, {"ct": "chat", **result})
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(
f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}",
exc_info=True,
)
finally:
webchat_queue_mgr.remove_back_queue(request_id)
if session.chat_subscriptions.get(chat_session_id) == request_id:
session.chat_subscriptions.pop(chat_session_id, None)
session.chat_subscription_tasks.pop(chat_session_id, None)
async def _ensure_chat_subscription(
self,
session: LiveChatSession,
chat_session_id: str,
) -> str:
existing_request_id = session.chat_subscriptions.get(chat_session_id)
existing_task = session.chat_subscription_tasks.get(chat_session_id)
if existing_request_id and existing_task and not existing_task.done():
return existing_request_id
request_id = f"ws_sub_{uuid.uuid4().hex}"
session.chat_subscriptions[chat_session_id] = request_id
task = asyncio.create_task(
self._forward_chat_subscription(session, chat_session_id, request_id),
name=f"chat_ws_sub_{chat_session_id}",
)
session.chat_subscription_tasks[chat_session_id] = task
return request_id
async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
tasks = list(session.chat_subscription_tasks.values())
for task in tasks:
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
for request_id in list(session.chat_subscriptions.values()):
webchat_queue_mgr.remove_back_queue(request_id)
session.chat_subscriptions.clear()
session.chat_subscription_tasks.clear()
async def _handle_chat_message(
self, session: LiveChatSession, message: dict
) -> None:
"""处理 Chat Mode 消息ct=chat"""
msg_type = message.get("t")
if msg_type == "bind":
chat_session_id = message.get("session_id")
if not isinstance(chat_session_id, str) or not chat_session_id:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "session_id is required",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
request_id = await self._ensure_chat_subscription(session, chat_session_id)
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "session_bound",
"session_id": chat_session_id,
"message_id": request_id,
},
)
return
if msg_type == "interrupt":
session.should_interrupt = True
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "INTERRUPTED",
"code": "INTERRUPTED",
},
)
return
if msg_type != "send":
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": f"Unsupported message type: {msg_type}",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
if session.is_processing:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "Session is busy",
"code": "PROCESSING_ERROR",
},
)
return
payload = message.get("message")
session_id = message.get("session_id") or session.session_id
message_id = message.get("message_id") or str(uuid.uuid4())
selected_provider = message.get("selected_provider")
selected_model = message.get("selected_model")
selected_stt_provider = message.get("selected_stt_provider")
selected_tts_provider = message.get("selected_tts_provider")
persona_prompt = message.get("persona_prompt")
show_reasoning = message.get("show_reasoning")
enable_streaming = message.get("enable_streaming", True)
if not isinstance(payload, list):
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "message must be list",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
message_parts = await self._build_chat_message_parts(payload)
has_content = webchat_message_parts_have_content(message_parts)
if not has_content:
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "Message content is empty",
"code": "INVALID_MESSAGE_FORMAT",
},
)
return
await self._ensure_chat_subscription(session, session_id)
session.is_processing = True
session.should_interrupt = False
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
llm_checkpoint_id = str(uuid.uuid4())
try:
pending_bot_message_flusher = None
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
session.username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"selected_stt_provider": selected_stt_provider,
"selected_tts_provider": selected_tts_provider,
"persona_prompt": persona_prompt,
"show_reasoning": show_reasoning,
"enable_streaming": enable_streaming,
"message_id": message_id,
"llm_checkpoint_id": llm_checkpoint_id,
},
),
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
saved_user_record = await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts_for_storage},
sender_id=session.username,
sender_name=session.username,
llm_checkpoint_id=llm_checkpoint_id,
)
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "user_message_saved",
"data": {
"id": saved_user_record.id,
"created_at": to_utc_isoformat(saved_user_record.created_at),
"llm_checkpoint_id": llm_checkpoint_id,
},
},
)
message_accumulator = BotMessageAccumulator()
agent_stats = {}
refs = {}
async def flush_pending_bot_message():
nonlocal message_accumulator, agent_stats, refs
if not (message_accumulator.has_content() or refs or agent_stats):
return None
message_parts_to_save = message_accumulator.build_message_parts(
include_pending_tool_calls=True
)
plain_text = collect_plain_text_from_message_parts(
message_parts_to_save
)
try:
extracted_refs = self._extract_web_search_refs(
plain_text,
message_parts_to_save,
)
except Exception as e:
logger.exception(
f"[Live Chat] Failed to extract web search refs: {e}",
exc_info=True,
)
extracted_refs = refs
saved_record = await self._save_bot_message(
session_id,
message_parts_to_save,
agent_stats,
extracted_refs,
llm_checkpoint_id,
)
message_accumulator = BotMessageAccumulator()
agent_stats = {}
refs = {}
return saved_record
pending_bot_message_flusher = flush_pending_bot_message
async def send_attachment_saved_event(part: dict | None) -> None:
if not part or not part.get("attachment_id") or not part.get("type"):
return
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "attachment_saved",
"data": {
"id": part["attachment_id"],
"type": part["type"],
},
},
)
while True:
if session.should_interrupt:
session.should_interrupt = False
await flush_pending_bot_message()
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if result.get("message_id") and result.get("message_id") != message_id:
continue
result_text = result.get("data", "")
msg_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
parsed_agent_stats = json.loads(result_text)
agent_stats = parsed_agent_stats
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "agent_stats",
"data": parsed_agent_stats,
},
)
except Exception:
pass
continue
outgoing = {"ct": "chat", **result}
await self._send_chat_payload(session, outgoing)
if msg_type == "plain":
message_accumulator.add_plain(
result_text,
chain_type=chain_type,
streaming=streaming,
)
elif msg_type == "image":
filename = str(result_text).replace("[IMAGE]", "")
part = await self._create_attachment_from_file(filename, "image")
message_accumulator.add_attachment(part)
await send_attachment_saved_event(part)
elif msg_type == "record":
filename = str(result_text).replace("[RECORD]", "")
part = await self._create_attachment_from_file(filename, "record")
message_accumulator.add_attachment(part)
await send_attachment_saved_event(part)
elif msg_type == "file":
filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
part = await self._create_attachment_from_file(filename, "file")
message_accumulator.add_attachment(part)
await send_attachment_saved_event(part)
elif msg_type == "video":
filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
part = await self._create_attachment_from_file(filename, "video")
message_accumulator.add_attachment(part)
await send_attachment_saved_event(part)
should_save = False
if msg_type == "end":
should_save = bool(
message_accumulator.has_content() or refs or agent_stats
)
elif (streaming and msg_type == "complete") or not streaming:
if chain_type not in (
"tool_call",
"tool_call_result",
"agent_stats",
):
should_save = True
if should_save:
saved_record = await flush_pending_bot_message()
if saved_record:
await self._send_chat_payload(
session,
{
"ct": "chat",
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": to_utc_isoformat(
saved_record.created_at
),
"llm_checkpoint_id": llm_checkpoint_id,
},
},
)
if msg_type == "end":
break
except Exception as e:
logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True)
await self._send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": f"处理失败: {str(e)}",
"code": "PROCESSING_ERROR",
},
)
finally:
try:
if pending_bot_message_flusher is not None:
await pending_bot_message_flusher()
except Exception as e:
logger.exception(
f"[Live Chat] Failed to persist pending chat message: {e}",
exc_info=True,
)
session.is_processing = False
webchat_queue_mgr.remove_back_queue(message_id)
async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]:
"""构建 chat websocket 用户消息段(复用 webchat 逻辑)"""
return await build_webchat_message_parts(
message,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=False,
)
async def _handle_message(self, session: LiveChatSession, message: dict) -> None:
"""处理 WebSocket 消息"""
msg_type = message.get("t") # 使用 t 代替 type
if msg_type == "start_speaking":
# 开始说话
stamp = message.get("stamp")
if not stamp:
logger.warning("[Live Chat] start_speaking 缺少 stamp")
return
session.start_speaking(stamp)
elif msg_type == "speaking_part":
# 音频片段
audio_data_b64 = message.get("data")
if not audio_data_b64:
return
# 解码 base64
import base64
try:
audio_data = base64.b64decode(audio_data_b64)
session.add_audio_frame(audio_data)
except Exception as e:
logger.error(f"[Live Chat] 解码音频数据失败: {e}")
elif msg_type == "end_speaking":
# 结束说话
stamp = message.get("stamp")
if not stamp:
logger.warning("[Live Chat] end_speaking 缺少 stamp")
return
audio_path, assemble_duration = await session.end_speaking(stamp)
if not audio_path:
await websocket.send_json({"t": "error", "data": "音频组装失败"})
return
# 处理音频STT -> LLM -> TTS
await self._process_audio(session, audio_path, assemble_duration)
elif msg_type == "interrupt":
# 用户打断
session.should_interrupt = True
logger.info(f"[Live Chat] 用户打断: {session.username}")
async def _process_audio(
self, session: LiveChatSession, audio_path: str, assemble_duration: float
) -> None:
"""处理音频STT -> LLM -> 流式 TTS"""
try:
# 发送 WAV 组装耗时
await websocket.send_json(
{"t": "metrics", "data": {"wav_assemble_time": assemble_duration}}
)
wav_assembly_finish_time = time.time()
session.is_processing = True
session.should_interrupt = False
# 1. STT - 语音转文字
ctx = self.plugin_manager.context
stt_provider = ctx.provider_manager.stt_provider_insts[0]
if not stt_provider:
logger.error("[Live Chat] STT Provider 未配置")
await websocket.send_json({"t": "error", "data": "语音识别服务未配置"})
return
await websocket.send_json(
{"t": "metrics", "data": {"stt": stt_provider.meta().type}}
)
user_text = await stt_provider.get_text(audio_path)
if not user_text:
logger.warning("[Live Chat] STT 识别结果为空")
return
logger.info(f"[Live Chat] STT 结果: {user_text}")
await websocket.send_json(
{
"t": "user_msg",
"data": {"text": user_text, "ts": int(time.time() * 1000)},
}
)
# 2. 构造消息事件并发送到 pipeline
# 使用 webchat queue 机制
cid = session.conversation_id
queue = webchat_queue_mgr.get_or_create_queue(cid)
message_id = str(uuid.uuid4())
payload = {
"message_id": message_id,
"message": [{"type": "plain", "text": user_text}], # 直接发送文本
"action_type": "live", # 标记为 live mode
}
# 将消息放入队列
await queue.put((session.username, cid, payload))
# 3. 等待响应并流式发送 TTS 音频
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, cid)
bot_text = ""
audio_playing = False
try:
while True:
if session.should_interrupt:
# 用户打断,停止处理
logger.info("[Live Chat] 检测到用户打断")
await websocket.send_json({"t": "stop_play"})
# 保存消息并标记为被打断
await self._save_interrupted_message(
session, user_text, bot_text
)
# 清空队列中未处理的消息
while not back_queue.empty():
try:
back_queue.get_nowait()
except asyncio.QueueEmpty:
break
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
except asyncio.TimeoutError:
continue
if not result:
continue
result_message_id = result.get("message_id")
if result_message_id != message_id:
logger.warning(
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
)
continue
result_type = result.get("type")
result_chain_type = result.get("chain_type")
data = result.get("data", "")
if result_chain_type == "agent_stats":
try:
stats = json.loads(data)
await websocket.send_json(
{
"t": "metrics",
"data": {
"llm_ttft": stats.get("time_to_first_token", 0),
"llm_total_time": stats.get("end_time", 0)
- stats.get("start_time", 0),
},
}
)
except Exception as e:
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
continue
if result_chain_type == "tts_stats":
try:
stats = json.loads(data)
await websocket.send_json(
{
"t": "metrics",
"data": stats,
}
)
except Exception as e:
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
continue
if result_type == "plain":
# 普通文本消息
bot_text += data
elif result_type == "audio_chunk":
# 流式音频数据
if not audio_playing:
audio_playing = True
logger.debug("[Live Chat] 开始播放音频流")
# Calculate latency from wav assembly finish to first audio chunk
speak_to_first_frame_latency = (
time.time() - wav_assembly_finish_time
)
await websocket.send_json(
{
"t": "metrics",
"data": {
"speak_to_first_frame": speak_to_first_frame_latency
},
}
)
text = result.get("text")
if text:
await websocket.send_json(
{
"t": "bot_text_chunk",
"data": {"text": text},
}
)
# 发送音频数据给前端
await websocket.send_json(
{
"t": "response",
"data": data, # base64 编码的音频数据
}
)
elif result_type in ["complete", "end"]:
# 处理完成
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
# 如果没有音频流,发送 bot 消息文本
if not audio_playing:
await websocket.send_json(
{
"t": "bot_msg",
"data": {
"text": bot_text,
"ts": int(time.time() * 1000),
},
}
)
# 发送结束标记
await websocket.send_json({"t": "end"})
# 发送总耗时
wav_to_tts_duration = time.time() - wav_assembly_finish_time
await websocket.send_json(
{
"t": "metrics",
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
}
)
break
finally:
webchat_queue_mgr.remove_back_queue(message_id)
except Exception as e:
logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
await websocket.send_json({"t": "error", "data": f"处理失败: {str(e)}"})
finally:
session.is_processing = False
session.should_interrupt = False
async def _save_interrupted_message(
self, session: LiveChatSession, user_text: str, bot_text: str
) -> None:
"""保存被打断的消息"""
interrupted_text = bot_text + " [用户打断]"
logger.info(f"[Live Chat] 保存打断消息: {interrupted_text}")
# 简单记录到日志,实际保存逻辑可以后续完善
try:
timestamp = int(time.time() * 1000)
logger.info(
f"[Live Chat] 用户消息: {user_text} (session: {session.session_id}, ts: {timestamp})"
)
if bot_text:
logger.info(
f"[Live Chat] Bot 消息(打断): {interrupted_text} (session: {session.session_id}, ts: {timestamp})"
)
except Exception as e:
logger.error(f"[Live Chat] 记录消息失败: {e}", exc_info=True)
__all__ = ["LiveChatRoute"]

View File

@@ -1,30 +1,18 @@
import asyncio
import json
import time
from collections.abc import AsyncGenerator
from typing import cast
from quart import Response as QuartResponse
from quart import make_response, request
from astrbot.core import LogBroker, logger
from astrbot.core import LogBroker
from astrbot.dashboard.fastapi_compat import Response as CompatResponse
from astrbot.dashboard.fastapi_compat import make_response, request
from astrbot.dashboard.services.log_service import LogService, LogServiceError
from .route import Response, Route, RouteContext
def _format_log_sse(log: dict, ts: float) -> str:
"""辅助函数:格式化 SSE 消息"""
payload = {
"type": "log",
**log,
}
return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
class LogRoute(Route):
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
super().__init__(context)
self.log_broker = log_broker
self.service = LogService(log_broker, self.config)
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
self.app.add_url_rule(
"/api/log-history",
@@ -42,51 +30,46 @@ class LogRoute(Route):
methods=["POST"],
)
async def _replay_cached_logs(
self, last_event_id: str
) -> AsyncGenerator[str, None]:
"""辅助生成器:重放缓存的日志"""
@staticmethod
def _ok(data=None, message: str | None = None):
return Response().ok(data=data, message=message).__dict__
@staticmethod
def _error(message: str):
return Response().error(message).__dict__
@staticmethod
async def _json_body() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(self, operation, *, result_as_message: bool = False):
try:
last_ts = float(last_event_id)
cached_logs = list(self.log_broker.log_cache)
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
if result_as_message:
return self._ok(message=str(result))
return self._ok(result)
except LogServiceError as exc:
return self._error(str(exc))
for log_item in cached_logs:
log_ts = float(log_item.get("time", 0))
async def _run_json(self, operation, *, result_as_message: bool = False):
async def invoke():
data = await self._json_body()
return operation(data)
if log_ts > last_ts:
yield _format_log_sse(log_item, log_ts)
return await self._run(invoke, result_as_message=result_as_message)
except ValueError:
pass
except Exception as e:
logger.error(f"Log SSE 补发历史错误: {e}")
async def log(self) -> QuartResponse:
async def log(self) -> CompatResponse:
last_event_id = request.headers.get("Last-Event-ID")
async def stream():
queue = None
try:
if last_event_id:
async for event in self._replay_cached_logs(last_event_id):
yield event
queue = self.log_broker.register()
while True:
message = await queue.get()
current_ts = message.get("time", time.time())
yield _format_log_sse(message, current_ts)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"Log SSE 连接错误: {e}")
finally:
if queue:
self.log_broker.unregister(queue)
async for event in self.service.stream_log_events(last_event_id):
yield event
response = cast(
QuartResponse,
CompatResponse,
await make_response(
stream(),
{
@@ -102,43 +85,15 @@ class LogRoute(Route):
async def log_history(self):
"""获取日志历史"""
try:
logs = list(self.log_broker.log_cache)
return (
Response()
.ok(
data={
"logs": logs,
},
)
.__dict__
)
except Exception as e:
logger.error(f"获取日志历史失败: {e}")
return Response().error(f"获取日志历史失败: {e}").__dict__
return await self._run(self.service.get_log_history)
async def get_trace_settings(self):
"""获取 Trace 设置"""
try:
trace_enable = self.config.get("trace_enable", True)
return Response().ok(data={"trace_enable": trace_enable}).__dict__
except Exception as e:
logger.error(f"获取 Trace 设置失败: {e}")
return Response().error(f"获取 Trace 设置失败: {e}").__dict__
return await self._run(self.service.get_trace_settings)
async def update_trace_settings(self):
"""更新 Trace 设置"""
try:
data = await request.json
if data is None:
return Response().error("请求数据为空").__dict__
trace_enable = data.get("trace_enable")
if trace_enable is not None:
self.config["trace_enable"] = bool(trace_enable)
self.config.save_config()
return Response().ok(message="Trace 设置已更新").__dict__
except Exception as e:
logger.error(f"更新 Trace 设置失败: {e}")
return Response().error(f"更新 Trace 设置失败: {e}").__dict__
return await self._run_json(
self.service.update_trace_settings_from_legacy_payload,
result_as_message=True,
)

View File

@@ -1,28 +1,27 @@
import asyncio
import hashlib
import json
from uuid import uuid4
from typing import cast
from quart import g, request, websocket
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_session import MessageSesion
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_message_chain_from_payload,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
from astrbot.dashboard.fastapi_compat import (
Response as CompatResponse,
)
from astrbot.dashboard.fastapi_compat import (
make_response,
request,
send_file,
websocket,
)
from astrbot.dashboard.services.chat_service import (
ChatService,
ChatServiceError,
extract_web_search_refs,
)
from astrbot.dashboard.services.open_api_service import (
OpenApiService,
OpenApiServiceError,
OpenApiWebSocketChatBridge,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.datetime_utils import to_utc_isoformat
from .api_key import ALL_OPEN_API_SCOPES
from .chat import (
BotMessageAccumulator,
ChatRoute,
collect_plain_text_from_message_parts,
)
from .route import Response, Route, RouteContext
@@ -32,13 +31,14 @@ class OpenApiRoute(Route):
context: RouteContext,
db: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
chat_route: ChatRoute,
chat_service: ChatService,
*,
register_routes: bool = True,
) -> None:
super().__init__(context)
self.db = db
self.core_lifecycle = core_lifecycle
self.platform_manager = core_lifecycle.platform_manager
self.chat_route = chat_route
self.chat_service = chat_service
self.service = OpenApiService(db, core_lifecycle)
self.routes = {
"/v1/chat": ("POST", self.chat_send),
@@ -51,151 +51,85 @@ class OpenApiRoute(Route):
"/v1/im/message": ("POST", self.send_message),
"/v1/im/bots": ("GET", self.get_bots),
}
self.register_routes()
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
if register_routes:
self.register_routes()
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
@staticmethod
def _resolve_open_username(
raw_username: str | None,
) -> tuple[str | None, str | None]:
if raw_username is None:
return None, "Missing key: username"
username = str(raw_username).strip()
if not username:
return None, "username is empty"
return username, None
def _ok(data=None):
return Response().ok(data=data).__dict__
@staticmethod
def _error(message: str):
return Response().error(message).__dict__
@staticmethod
async def _json_body() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(self, operation):
try:
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
return self._ok(result)
except (OpenApiServiceError, ChatServiceError) as exc:
return self._error(str(exc))
async def _run_json(self, operation):
async def invoke():
data = await self._json_body()
return operation(data)
return await self._run(invoke)
def _get_chat_config_list(self) -> list[dict]:
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
result = []
for conf_info in conf_list:
conf_id = str(conf_info.get("id", "")).strip()
result.append(
{
"id": conf_id,
"name": str(conf_info.get("name", "")).strip(),
"path": str(conf_info.get("path", "")).strip(),
"is_default": conf_id == "default",
}
)
return result
def _resolve_chat_config_id(self, post_data: dict) -> tuple[str | None, str | None]:
raw_config_id = post_data.get("config_id")
raw_config_name = post_data.get("config_name")
config_id = str(raw_config_id).strip() if raw_config_id is not None else ""
config_name = (
str(raw_config_name).strip() if raw_config_name is not None else ""
)
if not config_id and not config_name:
return None, None
conf_list = self._get_chat_config_list()
conf_map = {item["id"]: item for item in conf_list}
if config_id:
if config_id not in conf_map:
return None, f"config_id not found: {config_id}"
return config_id, None
if not config_name:
return None, "config_name is empty"
matched = [item for item in conf_list if item["name"] == config_name]
if not matched:
return None, f"config_name not found: {config_name}"
if len(matched) > 1:
return (
None,
f"config_name is ambiguous, please use config_id: {config_name}",
)
return matched[0]["id"], None
async def _ensure_chat_session(
self,
username: str,
session_id: str,
) -> str | None:
session = await self.db.get_platform_session_by_id(session_id)
if session:
if session.creator != username:
return "session_id belongs to another username"
return None
try:
await self.db.create_platform_session(
creator=username,
platform_id="webchat",
session_id=session_id,
is_group=0,
)
except Exception as e:
# Handle rare race when same session_id is created concurrently.
existing = await self.db.get_platform_session_by_id(session_id)
if existing and existing.creator == username:
return None
logger.error("Failed to create chat session %s: %s", session_id, e)
return f"Failed to create session: {e}"
return None
return self.service.get_chat_config_list()
async def chat_send(self):
post_data = await request.get_json(silent=True) or {}
effective_username, username_err = self._resolve_open_username(
post_data.get("username")
)
if username_err:
return Response().error(username_err).__dict__
if not effective_username:
return Response().error("Invalid username").__dict__
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
if not session_id:
session_id = str(uuid4())
post_data["session_id"] = session_id
ensure_session_err = await self._ensure_chat_session(
effective_username,
session_id,
)
if ensure_session_err:
return Response().error(ensure_session_err).__dict__
config_id, resolve_err = self._resolve_chat_config_id(post_data)
if resolve_err:
return Response().error(resolve_err).__dict__
original_username = g.get("username", "guest")
g.username = effective_username
if config_id:
umo = f"webchat:FriendMessage:webchat!{effective_username}!{session_id}"
try:
if config_id == "default":
await self.core_lifecycle.umop_config_router.delete_route(umo)
else:
await self.core_lifecycle.umop_config_router.update_route(
umo, config_id
)
except Exception as e:
logger.error(
"Failed to update chat config route for %s with %s: %s",
umo,
config_id,
e,
exc_info=True,
)
return (
Response()
.error(f"Failed to update chat config route: {e}")
.__dict__
)
try:
return await self.chat_route.chat(post_data=post_data)
finally:
g.username = original_username
(
effective_username,
session_id,
config_id,
) = await self.service.prepare_chat_send(
post_data,
self._get_chat_config_list(),
)
except OpenApiServiceError as exc:
return self._error(str(exc))
config_err = await self.service.update_session_config_route(
username=effective_username,
session_id=session_id,
config_id=config_id,
)
if config_err:
return self._error(config_err)
return await self._chat_response(effective_username, post_data)
async def _chat_response(self, username: str, post_data: dict):
try:
stream = await self.chat_service.build_chat_stream(username, post_data)
except ChatServiceError as exc:
return self._error(str(exc))
response = cast(
CompatResponse,
await make_response(
stream,
{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
},
),
)
response.timeout = None
return response
@staticmethod
def _extract_ws_api_key() -> str | None:
@@ -213,451 +147,71 @@ class OpenApiRoute(Route):
return auth_header.removeprefix("ApiKey ").strip()
return None
async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]:
raw_key = self._extract_ws_api_key()
if not raw_key:
return False, "Missing API key"
key_hash = hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
b"astrbot_api_key",
100_000,
).hex()
api_key = await self.db.get_active_api_key_by_hash(key_hash)
if not api_key:
return False, "Invalid API key"
if isinstance(api_key.scopes, list):
scopes = api_key.scopes
else:
scopes = list(ALL_OPEN_API_SCOPES)
if "*" not in scopes and "chat" not in scopes:
return False, "Insufficient API key scope"
await self.db.touch_api_key(api_key.key_id)
return True, None
async def _send_chat_ws_error(self, message: str, code: str) -> None:
await websocket.send_json(
{
"type": "error",
"code": code,
"data": message,
}
)
async def _update_session_config_route(
async def _insert_webchat_user_message(
self,
*,
username: str,
session_id: str,
config_id: str | None,
) -> str | None:
if not config_id:
return None
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
try:
if config_id == "default":
await self.core_lifecycle.umop_config_router.delete_route(umo)
else:
await self.core_lifecycle.umop_config_router.update_route(
umo, config_id
)
except Exception as e:
logger.error(
"Failed to update chat config route for %s with %s: %s",
umo,
config_id,
e,
exc_info=True,
)
return f"Failed to update chat config route: {e}"
return None
async def _handle_chat_ws_send(self, post_data: dict) -> None:
effective_username, username_err = self._resolve_open_username(
post_data.get("username")
)
if username_err or not effective_username:
await self._send_chat_ws_error(
username_err or "Invalid username", "BAD_USER"
)
return
message = post_data.get("message")
if message is None:
await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE")
return
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
if not session_id:
session_id = str(uuid4())
ensure_session_err = await self._ensure_chat_session(
effective_username,
session_id,
)
if ensure_session_err:
await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR")
return
config_id, resolve_err = self._resolve_chat_config_id(post_data)
if resolve_err:
await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR")
return
config_err = await self._update_session_config_route(
username=effective_username,
effective_username: str,
message_parts: list,
) -> None:
await self.service.insert_webchat_user_message(
session_id=session_id,
config_id=config_id,
effective_username=effective_username,
message_parts=message_parts,
)
if config_err:
await self._send_chat_ws_error(config_err, "CONFIG_ERROR")
return
message_parts = await self.chat_route._build_user_message_parts(message)
if not webchat_message_parts_have_content(message_parts):
await self._send_chat_ws_error(
"Message content is empty (reply only is not allowed)",
"INVALID_MESSAGE",
)
return
message_id = str(post_data.get("message_id") or uuid4())
selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
try:
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
effective_username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"enable_streaming": enable_streaming,
"message_id": message_id,
},
)
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.chat_route.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts_for_storage},
sender_id=effective_username,
sender_name=effective_username,
)
await websocket.send_json(
{
"type": "session_id",
"data": None,
"session_id": session_id,
"message_id": message_id,
}
)
message_accumulator = BotMessageAccumulator()
agent_stats = {}
refs = {}
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if "message_id" in result and result["message_id"] != message_id:
logger.warning("openapi ws stream message_id mismatch")
continue
result_text = result.get("data", "")
msg_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
stats_info = {
"type": "agent_stats",
"data": json.loads(result_text),
}
await websocket.send_json(stats_info)
agent_stats = stats_info["data"]
except Exception:
pass
continue
await websocket.send_json(result)
if msg_type == "plain":
message_accumulator.add_plain(
result_text,
chain_type=chain_type,
streaming=streaming,
)
elif msg_type == "image":
filename = str(result_text).replace("[IMAGE]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "image"
)
message_accumulator.add_attachment(part)
elif msg_type == "record":
filename = str(result_text).replace("[RECORD]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "record"
)
message_accumulator.add_attachment(part)
elif msg_type == "file":
filename = str(result_text).replace("[FILE]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "file"
)
message_accumulator.add_attachment(part)
elif msg_type == "video":
filename = str(result_text).replace("[VIDEO]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "video"
)
message_accumulator.add_attachment(part)
should_save = False
if msg_type == "end":
should_save = bool(
message_accumulator.has_content() or refs or agent_stats
)
elif (streaming and msg_type == "complete") or not streaming:
if chain_type not in ("tool_call", "tool_call_result"):
should_save = True
if should_save:
message_parts_to_save = message_accumulator.build_message_parts(
include_pending_tool_calls=True
)
plain_text = collect_plain_text_from_message_parts(
message_parts_to_save
)
try:
refs = self.chat_route._extract_web_search_refs(
plain_text,
message_parts_to_save,
)
except Exception as e:
logger.exception(
f"Open API WS failed to extract web search refs: {e}",
exc_info=True,
)
saved_record = await self.chat_route._save_bot_message(
session_id,
message_parts_to_save,
agent_stats,
refs,
)
if saved_record:
await websocket.send_json(
{
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": to_utc_isoformat(
saved_record.created_at
),
},
"session_id": session_id,
}
)
message_accumulator = BotMessageAccumulator()
agent_stats = {}
refs = {}
if msg_type == "end":
break
except Exception as e:
logger.exception(f"Open API WS chat failed: {e}", exc_info=True)
await self._send_chat_ws_error(
f"Failed to process message: {e}", "PROCESSING_ERROR"
)
finally:
webchat_queue_mgr.remove_back_queue(message_id)
def _build_chat_ws_bridge(self) -> OpenApiWebSocketChatBridge:
return OpenApiWebSocketChatBridge(
build_user_message_parts=self.chat_service.build_user_message_parts,
create_attachment_from_file=self.chat_service.create_attachment_from_file,
extract_web_search_refs=extract_web_search_refs,
insert_user_message=self._insert_webchat_user_message,
save_bot_message=self.chat_service.save_bot_message,
)
async def chat_ws(self) -> None:
authed, auth_err = await self._authenticate_chat_ws_api_key()
if not authed:
await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
await websocket.close(1008, auth_err or "Unauthorized")
return
try:
while True:
message = await websocket.receive_json()
if not isinstance(message, dict):
await self._send_chat_ws_error(
"message must be an object",
"INVALID_MESSAGE",
)
continue
msg_type = message.get("t", "send")
if msg_type == "ping":
await websocket.send_json({"type": "pong"})
continue
if msg_type != "send":
await self._send_chat_ws_error(
f"Unsupported message type: {msg_type}",
"INVALID_MESSAGE",
)
continue
await self._handle_chat_ws_send(message)
except Exception as e:
logger.debug("Open API WS connection closed: %s", e)
await self.service.run_chat_websocket(
raw_api_key=self._extract_ws_api_key(),
receive_json=websocket.receive_json,
send_json=websocket.send_json,
close=websocket.close,
conf_list=self._get_chat_config_list(),
chat_bridge=self._build_chat_ws_bridge(),
)
async def openapi_upload_file(self):
return await self.chat_route.post_file()
return await self._run(
self.chat_service.save_uploaded_file_from_legacy_files(await request.files)
)
async def openapi_get_file(self):
return await self.chat_route.get_attachment()
try:
(
file_path,
mimetype,
) = await self.chat_service.resolve_attachment_file_from_legacy_query(
request.args.get("attachment_id")
)
return await send_file(file_path, mimetype=mimetype)
except ChatServiceError as exc:
return self._error(str(exc))
except (FileNotFoundError, OSError):
return self._error("File access error")
async def get_chat_sessions(self):
username, username_err = self._resolve_open_username(
request.args.get("username")
)
if username_err:
return Response().error(username_err).__dict__
assert username is not None # for type checker
try:
page = int(request.args.get("page", 1))
page_size = int(request.args.get("page_size", 20))
except ValueError:
return Response().error("page and page_size must be integers").__dict__
if page < 1:
page = 1
if page_size < 1:
page_size = 1
if page_size > 100:
page_size = 100
platform_id = request.args.get("platform_id")
(
paginated_sessions,
total,
) = await self.db.get_platform_sessions_by_creator_paginated(
creator=username,
platform_id=platform_id,
page=page,
page_size=page_size,
exclude_project_sessions=True,
)
sessions_data = []
for item in paginated_sessions:
session = item["session"]
sessions_data.append(
{
"session_id": session.session_id,
"platform_id": session.platform_id,
"creator": session.creator,
"display_name": session.display_name,
"is_group": session.is_group,
"created_at": to_utc_isoformat(session.created_at),
"updated_at": to_utc_isoformat(session.updated_at),
}
return await self._run(
self.service.get_chat_sessions_from_legacy_query(
username=request.args.get("username"),
page=request.args.get("page", 1),
page_size=request.args.get("page_size", 20),
platform_id=request.args.get("platform_id"),
)
return (
Response()
.ok(
data={
"sessions": sessions_data,
"page": page,
"page_size": page_size,
"total": total,
}
)
.__dict__
)
async def get_chat_configs(self):
conf_list = self._get_chat_config_list()
return Response().ok(data={"configs": conf_list}).__dict__
async def _build_message_chain_from_payload(
self,
message_payload: str | list,
):
return await build_message_chain_from_payload(
message_payload,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=True,
)
return self._ok({"configs": self._get_chat_config_list()})
async def send_message(self):
post_data = await request.json or {}
message_payload = post_data.get("message", {})
umo = post_data.get("umo")
if message_payload is None:
return Response().error("Missing key: message").__dict__
if not umo:
return Response().error("Missing key: umo").__dict__
try:
session = MessageSesion.from_str(str(umo))
except Exception as e:
return Response().error(f"Invalid umo: {e}").__dict__
platform_id = session.platform_name
platform_inst = next(
(
inst
for inst in self.platform_manager.platform_insts
if inst.meta().id == platform_id
),
None,
)
if not platform_inst:
return (
Response()
.error(f"Bot not found or not running for platform: {platform_id}")
.__dict__
)
try:
message_chain = await self._build_message_chain_from_payload(
message_payload
)
await platform_inst.send_by_session(session, message_chain)
return Response().ok().__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"Open API send_message failed: {e}", exc_info=True)
return Response().error(f"Failed to send message: {e}").__dict__
return await self._run_json(self.service.send_message)
async def get_bots(self):
bot_ids = []
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
platform_id = platform.get("id") if isinstance(platform, dict) else None
if (
isinstance(platform_id, str)
and platform_id
and platform_id not in bot_ids
):
bot_ids.append(platform_id)
return Response().ok(data={"bot_ids": bot_ids}).__dict__
return await self._run(self.service.get_bots)

View File

@@ -1,11 +1,10 @@
import traceback
from quart import request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.sentinels import NOT_GIVEN
from astrbot.dashboard.fastapi_compat import request
from astrbot.dashboard.services.persona_service import (
PersonaService,
PersonaServiceError,
)
from .route import Response, Route, RouteContext
@@ -14,7 +13,7 @@ class PersonaRoute(Route):
def __init__(
self,
context: RouteContext,
db_helper: BaseDatabase,
db_helper,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
@@ -26,7 +25,6 @@ class PersonaRoute(Route):
"/persona/delete": ("POST", self.delete_persona),
"/persona/move": ("POST", self.move_persona),
"/persona/reorder": ("POST", self.reorder_items),
# Folder routes
"/persona/folder/list": ("GET", self.list_folders),
"/persona/folder/tree": ("GET", self.get_folder_tree),
"/persona/folder/detail": ("POST", self.get_folder_detail),
@@ -34,464 +32,129 @@ class PersonaRoute(Route):
"/persona/folder/update": ("POST", self.update_folder),
"/persona/folder/delete": ("POST", self.delete_folder),
}
self.db_helper = db_helper
self.persona_mgr = core_lifecycle.persona_mgr
self.service = PersonaService(core_lifecycle)
self.register_routes()
@staticmethod
def _ok(data):
return Response().ok(data).__dict__
@staticmethod
def _error(message: str):
return Response().error(message).__dict__
@staticmethod
async def _run(self, operation, *, label: str):
try:
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
return self._ok(result)
except (PersonaServiceError, ValueError) as exc:
return self._error(str(exc))
except Exception as exc:
logger.error("%s: %s", label, exc, exc_info=True)
return self._error(f"{label}: {exc!s}")
async def list_personas(self):
"""获取所有人格列表"""
try:
# 支持按文件夹筛选
folder_id = request.args.get("folder_id")
if folder_id is not None:
personas = await self.persona_mgr.get_personas_by_folder(
folder_id if folder_id else None
)
else:
personas = await self.persona_mgr.get_all_personas()
return (
Response()
.ok(
[
{
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools,
"skills": persona.skills,
"custom_error_message": persona.custom_error_message,
"folder_id": persona.folder_id,
"sort_order": persona.sort_order,
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
}
for persona in personas
],
)
.__dict__
)
except Exception as e:
logger.error(f"获取人格列表失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"获取人格列表失败: {e!s}").__dict__
return await self._run(
lambda: self.service.list_personas_from_legacy_query(
folder_id=request.args.get("folder_id"),
has_folder_id="folder_id" in request.args,
),
label="获取人格列表失败",
)
async def get_persona_detail(self):
"""获取指定人格的详细信息"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
persona = await self.persona_mgr.get_persona(persona_id)
if not persona:
return Response().error("人格不存在").__dict__
return (
Response()
.ok(
{
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools,
"skills": persona.skills,
"custom_error_message": persona.custom_error_message,
"folder_id": persona.folder_id,
"sort_order": persona.sort_order,
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
},
)
.__dict__
)
except Exception as e:
logger.error(f"获取人格详情失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"获取人格详情失败: {e!s}").__dict__
data = await request.get_json()
return await self._run(
self.service.get_persona_detail(data),
label="获取人格详情失败",
)
async def create_persona(self):
"""创建新人格"""
try:
data = await request.get_json()
persona_id = data.get("persona_id", "").strip()
system_prompt = data.get("system_prompt", "").strip()
begin_dialogs = data.get("begin_dialogs", [])
tools = data.get("tools")
skills = data.get("skills")
custom_error_message = data.get("custom_error_message")
folder_id = data.get("folder_id") # None 表示根目录
sort_order = data.get("sort_order", 0)
if not persona_id:
return Response().error("人格ID不能为空").__dict__
if not system_prompt:
return Response().error("系统提示词不能为空").__dict__
if custom_error_message is not None:
if not isinstance(custom_error_message, str):
return Response().error("自定义报错回复信息必须是字符串").__dict__
custom_error_message = custom_error_message.strip() or None
# 验证 begin_dialogs 格式
if begin_dialogs and len(begin_dialogs) % 2 != 0:
return (
Response()
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
.__dict__
)
persona = await self.persona_mgr.create_persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs if begin_dialogs else None,
tools=tools if tools else None,
skills=skills if skills else None,
custom_error_message=custom_error_message,
folder_id=folder_id,
sort_order=sort_order,
)
return (
Response()
.ok(
{
"message": "人格创建成功",
"persona": {
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools or [],
"skills": persona.skills or [],
"custom_error_message": persona.custom_error_message,
"folder_id": persona.folder_id,
"sort_order": persona.sort_order,
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
},
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"创建人格失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"创建人格失败: {e!s}").__dict__
data = await request.get_json()
return await self._run(
self.service.create_persona(data),
label="创建人格失败",
)
async def update_persona(self):
"""更新人格信息"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
system_prompt = data.get("system_prompt")
begin_dialogs = data.get("begin_dialogs")
has_tools = "tools" in data
tools = data.get("tools")
has_skills = "skills" in data
skills = data.get("skills")
has_custom_error_message = "custom_error_message" in data
custom_error_message = data.get("custom_error_message")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
if has_custom_error_message:
if custom_error_message is not None and not isinstance(
custom_error_message, str
):
return Response().error("自定义报错回复信息必须是字符串").__dict__
if isinstance(custom_error_message, str):
custom_error_message = custom_error_message.strip() or None
# 验证 begin_dialogs 格式
if begin_dialogs is not None and len(begin_dialogs) % 2 != 0:
return (
Response()
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
.__dict__
)
update_kwargs = {
"persona_id": persona_id,
"system_prompt": system_prompt,
"begin_dialogs": begin_dialogs,
}
if has_tools:
update_kwargs["tools"] = tools
if has_skills:
update_kwargs["skills"] = skills
if has_custom_error_message:
update_kwargs["custom_error_message"] = custom_error_message
await self.persona_mgr.update_persona(**update_kwargs)
return Response().ok({"message": "人格更新成功"}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"更新人格失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新人格失败: {e!s}").__dict__
data = await request.get_json()
return await self._run(
self.service.update_persona(data),
label="更新人格失败",
)
async def delete_persona(self):
"""删除人格"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
await self.persona_mgr.delete_persona(persona_id)
return Response().ok({"message": "人格删除成功"}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"删除人格失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"删除人格失败: {e!s}").__dict__
data = await request.get_json()
return await self._run(
self.service.delete_persona(data),
label="删除人格失败",
)
async def move_persona(self):
"""移动人格到指定文件夹"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
folder_id = data.get("folder_id") # None 表示移动到根目录
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
await self.persona_mgr.move_persona_to_folder(persona_id, folder_id)
return Response().ok({"message": "人格移动成功"}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"移动人格失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"移动人格失败: {e!s}").__dict__
# ====
# Folder Routes
# ====
data = await request.get_json()
return await self._run(
self.service.move_persona(data),
label="移动人格失败",
)
async def list_folders(self):
"""获取文件夹列表"""
try:
parent_id = request.args.get("parent_id")
# 空字符串视为 None根目录
if parent_id == "":
parent_id = None
folders = await self.persona_mgr.get_folders(parent_id)
return (
Response()
.ok(
[
{
"folder_id": folder.folder_id,
"name": folder.name,
"parent_id": folder.parent_id,
"description": folder.description,
"sort_order": folder.sort_order,
"created_at": folder.created_at.isoformat()
if folder.created_at
else None,
"updated_at": folder.updated_at.isoformat()
if folder.updated_at
else None,
}
for folder in folders
],
)
.__dict__
)
except Exception as e:
logger.error(f"获取文件夹列表失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"获取文件夹列表失败: {e!s}").__dict__
return await self._run(
lambda: self.service.list_folders_from_legacy_query(
request.args.get("parent_id")
),
label="获取文件夹列表失败",
)
async def get_folder_tree(self):
"""获取文件夹树形结构"""
try:
tree = await self.persona_mgr.get_folder_tree()
return Response().ok(tree).__dict__
except Exception as e:
logger.error(f"获取文件夹树失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"获取文件夹树失败: {e!s}").__dict__
return await self._run(self.service.get_folder_tree, label="获取文件夹树失败")
async def get_folder_detail(self):
"""获取指定文件夹的详细信息"""
try:
data = await request.get_json()
folder_id = data.get("folder_id")
if not folder_id:
return Response().error("缺少必要参数: folder_id").__dict__
folder = await self.persona_mgr.get_folder(folder_id)
if not folder:
return Response().error("文件夹不存在").__dict__
return (
Response()
.ok(
{
"folder_id": folder.folder_id,
"name": folder.name,
"parent_id": folder.parent_id,
"description": folder.description,
"sort_order": folder.sort_order,
"created_at": folder.created_at.isoformat()
if folder.created_at
else None,
"updated_at": folder.updated_at.isoformat()
if folder.updated_at
else None,
},
)
.__dict__
)
except Exception as e:
logger.error(f"获取文件夹详情失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"获取文件夹详情失败: {e!s}").__dict__
data = await request.get_json()
return await self._run(
self.service.get_folder_detail(data),
label="获取文件夹详情失败",
)
async def create_folder(self):
"""创建文件夹"""
try:
data = await request.get_json()
name = data.get("name", "").strip()
parent_id = data.get("parent_id")
description = data.get("description")
sort_order = data.get("sort_order", 0)
if not name:
return Response().error("文件夹名称不能为空").__dict__
folder = await self.persona_mgr.create_folder(
name=name,
parent_id=parent_id,
description=description,
sort_order=sort_order,
)
return (
Response()
.ok(
{
"message": "文件夹创建成功",
"folder": {
"folder_id": folder.folder_id,
"name": folder.name,
"parent_id": folder.parent_id,
"description": folder.description,
"sort_order": folder.sort_order,
"created_at": folder.created_at.isoformat()
if folder.created_at
else None,
"updated_at": folder.updated_at.isoformat()
if folder.updated_at
else None,
},
},
)
.__dict__
)
except Exception as e:
logger.error(f"创建文件夹失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"创建文件夹失败: {e!s}").__dict__
data = await request.get_json()
return await self._run(
self.service.create_folder(data),
label="创建文件夹失败",
)
async def update_folder(self):
"""更新文件夹信息"""
try:
data = await request.get_json()
folder_id = data.get("folder_id")
name = data.get("name")
parent_id = data.get("parent_id") if "parent_id" in data else NOT_GIVEN
description = (
data.get("description") if "description" in data else NOT_GIVEN
)
sort_order = data.get("sort_order")
if not folder_id:
return Response().error("缺少必要参数: folder_id").__dict__
await self.persona_mgr.update_folder(
folder_id=folder_id,
name=name,
parent_id=parent_id,
description=description,
sort_order=sort_order,
)
return Response().ok({"message": "文件夹更新成功"}).__dict__
except Exception as e:
logger.error(f"更新文件夹失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新文件夹失败: {e!s}").__dict__
data = await request.get_json()
return await self._run(
self.service.update_folder(data),
label="更新文件夹失败",
)
async def delete_folder(self):
"""删除文件夹"""
try:
data = await request.get_json()
folder_id = data.get("folder_id")
if not folder_id:
return Response().error("缺少必要参数: folder_id").__dict__
await self.persona_mgr.delete_folder(folder_id)
return Response().ok({"message": "文件夹删除成功"}).__dict__
except Exception as e:
logger.error(f"删除文件夹失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"删除文件夹失败: {e!s}").__dict__
data = await request.get_json()
return await self._run(
self.service.delete_folder(data),
label="删除文件夹失败",
)
async def reorder_items(self):
"""批量更新排序顺序
请求体格式:
{
"items": [
{"id": "persona_id_1", "type": "persona", "sort_order": 0},
{"id": "persona_id_2", "type": "persona", "sort_order": 1},
{"id": "folder_id_1", "type": "folder", "sort_order": 0},
...
]
}
"""
try:
data = await request.get_json()
items = data.get("items", [])
if not items:
return Response().error("items 不能为空").__dict__
# 验证每个 item 的格式
for item in items:
if not all(k in item for k in ("id", "type", "sort_order")):
return (
Response()
.error("每个 item 必须包含 id, type, sort_order 字段")
.__dict__
)
if item["type"] not in ("persona", "folder"):
return (
Response()
.error("type 字段必须是 'persona''folder'")
.__dict__
)
await self.persona_mgr.batch_update_sort_order(items)
return Response().ok({"message": "排序更新成功"}).__dict__
except Exception as e:
logger.error(f"更新排序失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新排序失败: {e!s}").__dict__
"""批量更新排序顺序"""
data = await request.get_json()
return await self._run(
self.service.reorder_items(data),
label="更新排序失败",
)

View File

@@ -1,39 +1,20 @@
"""统一 Webhook 路由
"""Unified webhook routes.
提供统一的 webhook 回调入口,支持多个平台使用同一端口接收回调。
Provides a unified webhook callback entrypoint for multiple platforms.
"""
import secrets
import string
from quart import request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform import Platform
from astrbot.core.platform.sources.dingtalk.app_registration import (
poll_dingtalk_app_registration_once,
request_dingtalk_app_registration,
)
from astrbot.core.platform.sources.lark.app_registration import (
poll_app_registration_once,
request_app_registration,
)
from astrbot.core.platform.sources.lark.bot_info import request_lark_bot_info
from astrbot.core.platform.sources.weixin_oc.login_registration import (
poll_weixin_oc_login_once,
request_weixin_oc_login_qr,
from astrbot.dashboard.fastapi_compat import request
from astrbot.dashboard.services.platform_service import (
PlatformService,
PlatformServiceError,
)
from .route import Response, Route, RouteContext
def _random_platform_id_suffix() -> str:
return "_" + "".join(secrets.choice(string.ascii_lowercase) for _ in range(4))
class PlatformRoute(Route):
"""统一 Webhook 路由"""
"""Unified webhook route."""
def __init__(
self,
@@ -41,21 +22,17 @@ class PlatformRoute(Route):
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.platform_manager = core_lifecycle.platform_manager
self.service = PlatformService(core_lifecycle)
self._register_webhook_routes()
def _register_webhook_routes(self) -> None:
"""注册 webhook 路由"""
# 统一 webhook 入口,支持 GET 和 POST
self.app.add_url_rule(
"/api/platform/webhook/<webhook_uuid>",
view_func=self.unified_webhook_callback,
methods=["GET", "POST"],
)
# 平台统计信息接口
self.app.add_url_rule(
"/api/platform/stats",
view_func=self.get_platform_stats,
@@ -68,218 +45,37 @@ class PlatformRoute(Route):
methods=["POST"],
)
async def unified_webhook_callback(self, webhook_uuid: str):
"""统一 webhook 回调入口
@staticmethod
def _ok(data=None):
return Response().ok(data).__dict__
Args:
webhook_uuid: 平台配置中的 webhook_uuid
@staticmethod
def _error(exc: PlatformServiceError):
return Response().error(str(exc)).__dict__, exc.status_code
Returns:
根据平台适配器返回相应的响应
"""
# 根据 webhook_uuid 查找对应的平台
platform_adapter = self._find_platform_by_uuid(webhook_uuid)
if not platform_adapter:
logger.warning(f"未找到 webhook_uuid 为 {webhook_uuid} 的平台")
return Response().error("未找到对应平台").__dict__, 404
# 调用平台适配器的 webhook_callback 方法
async def _run(self, operation):
try:
result = await platform_adapter.webhook_callback(request)
return result
except NotImplementedError:
logger.error(
f"平台 {platform_adapter.meta().name} 未实现 webhook_callback 方法"
)
return Response().error("平台未支持统一 Webhook 模式").__dict__, 500
except Exception as e:
logger.error(f"处理 webhook 回调时发生错误: {e}", exc_info=True)
return Response().error("处理回调失败").__dict__, 500
return self._ok(await operation())
except PlatformServiceError as exc:
return self._error(exc)
def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None:
"""根据 webhook_uuid 查找对应的平台适配器
async def _run_sync(self, operation):
try:
return self._ok(operation())
except PlatformServiceError as exc:
return self._error(exc)
Args:
webhook_uuid: webhook UUID
Returns:
平台适配器实例,未找到则返回 None
"""
for platform in self.platform_manager.platform_insts:
if platform.config.get("webhook_uuid") == webhook_uuid:
if platform.unified_webhook():
return platform
return None
async def unified_webhook_callback(self, webhook_uuid: str):
return await self._run(
lambda: self.service.handle_webhook_callback(webhook_uuid, request)
)
async def get_platform_stats(self):
"""获取所有平台的统计信息
Returns:
包含平台统计信息的响应
"""
try:
stats = self.platform_manager.get_all_stats()
return Response().ok(stats).__dict__
except Exception as e:
logger.error(f"获取平台统计信息失败: {e}", exc_info=True)
return Response().error(f"获取统计信息失败: {e}").__dict__, 500
return await self._run_sync(self.service.get_platform_stats)
async def handle_platform_registration(self, platform_type: str):
"""Handle dashboard one-click platform registration actions."""
try:
payload = await request.get_json(silent=True) or {}
action = str(payload.get("action", "")).strip().lower()
if not action:
return Response().error("Missing action").__dict__, 400
platform_config = payload.get("platform_config")
if not isinstance(platform_config, dict):
platform_config = {}
if platform_type == "lark":
return await self._handle_lark_registration(
action,
payload,
platform_config,
)
if platform_type == "weixin_oc":
return await self._handle_weixin_oc_registration(
action,
payload,
platform_config,
)
if platform_type == "dingtalk":
return await self._handle_dingtalk_registration(action, payload)
return Response().error(
f"Unsupported platform registration: {platform_type}"
).__dict__, 404
except Exception as e:
logger.error(f"处理平台一键创建请求失败: {e}", exc_info=True)
return Response().error(str(e)).__dict__, 500
async def _handle_lark_registration(
self,
action: str,
payload: dict,
platform_config: dict,
):
domain = str(platform_config.get("domain") or "").strip()
if action == "start":
registration = await request_app_registration(domain)
return (
Response()
.ok(
{
"status": "pending",
"device_code": registration.device_code,
"registration_code": registration.device_code,
"user_code": registration.user_code,
"verification_uri": registration.verification_uri,
"verification_uri_complete": registration.verification_uri_complete,
"expires_in": registration.expires_in,
"interval": registration.interval,
}
)
.__dict__
)
if action == "poll":
device_code = str(
payload.get("device_code") or payload.get("registration_code") or ""
).strip()
if not device_code:
return Response().error("Missing device_code").__dict__, 400
result = await poll_app_registration_once(
domain=domain,
device_code=device_code,
)
if result.get("status") == "created":
try:
bot_info = await request_lark_bot_info(
domain=str(result.get("domain") or domain),
app_id=str(result.get("app_id") or ""),
app_secret=str(result.get("app_secret") or ""),
)
if bot_info.app_name:
result["bot_name"] = bot_info.app_name
if bot_info.open_id:
result["bot_open_id"] = bot_info.open_id
except Exception as e:
logger.error(f"获取飞书机器人信息失败: {e}", exc_info=True)
return Response().ok(result).__dict__
return Response().error(f"Unsupported action: {action}").__dict__, 400
async def _handle_dingtalk_registration(self, action: str, payload: dict):
if action == "start":
registration = await request_dingtalk_app_registration()
return (
Response()
.ok(
{
"status": "pending",
"device_code": registration.device_code,
"registration_code": registration.device_code,
"user_code": registration.user_code,
"verification_uri": registration.verification_uri,
"verification_uri_complete": registration.verification_uri_complete,
"expires_in": registration.expires_in,
"interval": registration.interval,
}
)
.__dict__
)
if action == "poll":
device_code = str(
payload.get("device_code") or payload.get("registration_code") or ""
).strip()
if not device_code:
return Response().error("Missing device_code").__dict__, 400
result = await poll_dingtalk_app_registration_once(device_code)
if result.get("status") == "created":
result["platform_id_suffix"] = _random_platform_id_suffix()
return Response().ok(result).__dict__
return Response().error(f"Unsupported action: {action}").__dict__, 400
async def _handle_weixin_oc_registration(
self,
action: str,
payload: dict,
platform_config: dict,
):
if action == "start":
registration = await request_weixin_oc_login_qr(platform_config)
return (
Response()
.ok(
{
"status": "pending",
"registration_code": registration.qrcode,
"qrcode": registration.qrcode,
"qrcode_img_content": registration.qrcode_img_content,
"interval": registration.interval,
}
)
.__dict__
)
if action == "poll":
qrcode = str(
payload.get("qrcode") or payload.get("registration_code") or ""
).strip()
if not qrcode:
return Response().error("Missing qrcode").__dict__, 400
result = await poll_weixin_oc_login_once(
platform_config=platform_config,
qrcode=qrcode,
)
if result.get("status") == "created":
result["platform_id_suffix"] = _random_platform_id_suffix()
return Response().ok(result).__dict__
return Response().error(f"Unsupported action: {action}").__dict__, 400
payload = await request.get_json(silent=True) or {}
return await self._run(
lambda: self.service.handle_platform_registration(platform_type, payload)
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,13 @@
from dataclasses import dataclass
from quart import Quart
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.dashboard.fastapi_compat import FastAPIAppAdapter
@dataclass
class RouteContext:
config: AstrBotConfig
app: Quart
app: FastAPIAppAdapter
class Route:

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,38 +1,7 @@
import asyncio
import os
import re
import threading
import time
import traceback
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from functools import cmp_to_key
from pathlib import Path
import aiohttp
import psutil
from quart import request
from sqlmodel import col, select
from astrbot.core import DEMO_MODE, logger
from astrbot.core.config import VERSION
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.helper import check_migration_needed_v4
from astrbot.core.db.po import ProviderStat
from astrbot.core.utils.astrbot_path import get_astrbot_path
from astrbot.core.utils.auth_password import (
is_default_dashboard_password,
is_legacy_dashboard_password,
)
from astrbot.core.utils.io import get_dashboard_version
from astrbot.core.utils.storage_cleaner import StorageCleaner
from astrbot.core.utils.version_comparator import VersionComparator
from astrbot.dashboard.password_state import (
get_dashboard_password_hash,
is_password_change_required,
is_password_storage_upgraded,
)
from astrbot.dashboard.fastapi_compat import request
from astrbot.dashboard.services.stat_service import StatService, StatServiceError
from .route import Response, Route, RouteContext
@@ -58,533 +27,79 @@ class StatRoute(Route):
"/stat/storage": ("GET", self.get_storage_status),
"/stat/storage/cleanup": ("POST", self.cleanup_storage),
}
self.db_helper = db_helper
self.service = StatService(db_helper, core_lifecycle, self.config)
self.register_routes()
self.core_lifecycle = core_lifecycle
self.storage_cleaner = StorageCleaner(self.config)
async def restart_core(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
await self.core_lifecycle.restart()
return Response().ok().__dict__
def _get_running_time_components(self, total_seconds: int):
"""将总秒数转换为时分秒组件"""
minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(minutes, 60)
return {"hours": hours, "minutes": minutes, "seconds": seconds}
async def is_default_cred(self):
password_change_required = await is_password_change_required(
self.db_helper,
self.config,
)
if password_change_required:
return not DEMO_MODE
storage_upgraded = await is_password_storage_upgraded(
self.db_helper,
self.config,
)
if not storage_upgraded:
return False
username = self.config["dashboard"]["username"]
password = get_dashboard_password_hash(self.config, upgraded=True)
return (
username == "astrbot" and is_default_dashboard_password(password)
) and not DEMO_MODE
async def get_version(self):
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
storage_upgraded = await is_password_storage_upgraded(
self.db_helper,
self.config,
)
password = get_dashboard_password_hash(
self.config,
upgraded=storage_upgraded,
)
return (
Response()
.ok(
{
"version": VERSION,
"dashboard_version": await get_dashboard_version(),
"change_pwd_hint": await self.is_default_cred(),
"legacy_pwd_hint": is_legacy_dashboard_password(password),
"password_upgrade_required": not storage_upgraded,
"need_migration": need_migration,
},
)
.__dict__
)
async def get_start_time(self):
return Response().ok({"start_time": self.core_lifecycle.start_time}).__dict__
async def get_storage_status(self):
try:
status = await asyncio.to_thread(self.storage_cleaner.get_status)
return Response().ok(status).__dict__
except Exception:
logger.error("获取存储占用失败", exc_info=True)
return (
Response().error("获取存储占用失败,请查看后端日志了解详情。").__dict__
)
async def cleanup_storage(self):
try:
data = await request.get_json(silent=True)
target = "all"
if isinstance(data, dict):
target = str(data.get("target", "all"))
result = await asyncio.to_thread(self.storage_cleaner.cleanup, target)
return Response().ok(result).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception:
logger.error("清理存储失败", exc_info=True)
return Response().error("清理存储失败,请查看后端日志了解详情。").__dict__
async def get_stat(self):
offset_sec = request.args.get("offset_sec", 86400)
offset_sec = int(offset_sec)
try:
stat = self.db_helper.get_base_stats(offset_sec)
now = int(time.time())
start_time = now - offset_sec
message_time_based_stats = []
idx = 0
for bucket_end in range(start_time, now, 3600):
cnt = 0
while (
idx < len(stat.platform)
and stat.platform[idx].timestamp < bucket_end
):
cnt += stat.platform[idx].count
idx += 1
message_time_based_stats.append([bucket_end, cnt])
stat_dict = stat.__dict__
cpu_percent = psutil.cpu_percent(interval=0.5)
thread_count = threading.active_count()
# 获取插件信息
plugins = self.core_lifecycle.star_context.get_all_stars()
plugin_info = []
for plugin in plugins:
info = {
"name": getattr(plugin, "name", plugin.__class__.__name__),
"version": getattr(plugin, "version", "1.0.0"),
"is_enabled": True,
}
plugin_info.append(info)
# 计算运行时长组件
running_time = self._get_running_time_components(
int(time.time()) - self.core_lifecycle.start_time,
)
stat_dict.update(
{
"platform": self.db_helper.get_grouped_base_stats(
offset_sec,
).platform,
"message_count": self.db_helper.get_total_message_count() or 0,
"platform_count": len(
self.core_lifecycle.platform_manager.get_insts(),
),
"plugin_count": len(plugins),
"plugins": plugin_info,
"message_time_series": message_time_based_stats,
"running": running_time, # 现在返回时间组件而不是格式化的字符串
"memory": {
"process": psutil.Process().memory_info().rss >> 20,
"system": psutil.virtual_memory().total >> 20,
},
"cpu_percent": round(cpu_percent, 1),
"thread_count": thread_count,
"start_time": self.core_lifecycle.start_time,
},
)
return Response().ok(stat_dict).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(e.__str__()).__dict__
@staticmethod
def _ensure_aware_utc(value: datetime) -> datetime:
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def _ok(data=None):
return Response().ok(data).__dict__
@staticmethod
def _error(message: str):
return Response().error(message).__dict__
async def _run(self, operation):
try:
return self._ok(await operation())
except StatServiceError as exc:
return self._error(str(exc))
async def _run_sync(self, operation):
try:
return self._ok(operation())
except StatServiceError as exc:
return self._error(str(exc))
async def _run_json(self, operation, *, silent: bool = False):
payload = await request.get_json(silent=silent)
return await self._run(lambda: operation(payload))
async def restart_core(self):
return await self._run(self.service.restart_core)
async def get_version(self):
return await self._run(self.service.get_version)
async def get_start_time(self):
return await self._run_sync(self.service.get_start_time)
async def get_storage_status(self):
return await self._run(self.service.get_storage_status)
async def cleanup_storage(self):
return await self._run_json(
self.service.cleanup_storage_from_legacy_payload,
silent=True,
)
async def get_stat(self):
return await self._run(
lambda: self.service.get_stat_from_legacy_query(
request.args.get("offset_sec", 86400)
)
)
async def get_provider_token_stats(self):
try:
try:
days = int(request.args.get("days", 1))
except (TypeError, ValueError):
days = 1
if days not in (1, 3, 7):
days = 1
local_tz = datetime.now().astimezone().tzinfo or timezone.utc
now_local = datetime.now(local_tz)
range_start_local = (now_local - timedelta(days=days)).replace(
minute=0, second=0, microsecond=0
return await self._run(
lambda: self.service.get_provider_token_stats_from_legacy_query(
request.args.get("days", 1)
)
today_start_local = now_local.replace(
hour=0, minute=0, second=0, microsecond=0
)
query_start_local = min(range_start_local, today_start_local)
query_start_utc = query_start_local.astimezone(timezone.utc)
async with self.db_helper.get_db() as session:
result = await session.execute(
select(ProviderStat)
.where(
ProviderStat.agent_type == "internal",
ProviderStat.created_at >= query_start_utc,
)
.order_by(col(ProviderStat.created_at).asc())
)
records = result.scalars().all()
bucket_timestamps: list[int] = []
bucket_cursor = range_start_local
while bucket_cursor <= now_local:
bucket_timestamps.append(int(bucket_cursor.timestamp() * 1000))
bucket_cursor += timedelta(hours=1)
trend_by_provider: dict[str, dict[int, int]] = defaultdict(
lambda: defaultdict(int)
)
total_by_provider: dict[str, int] = defaultdict(int)
total_by_umo: dict[str, int] = defaultdict(int)
total_by_bucket: dict[int, int] = defaultdict(int)
range_total_tokens = 0
range_total_output_tokens = 0
range_total_calls = 0
range_success_calls = 0
range_ttft_total_ms = 0.0
range_ttft_samples = 0
range_duration_total_ms = 0.0
range_duration_samples = 0
today_by_model: dict[str, int] = defaultdict(int)
today_by_provider: dict[str, int] = defaultdict(int)
today_total_tokens = 0
today_total_calls = 0
for record in records:
created_at_utc = self._ensure_aware_utc(record.created_at)
created_at_local = created_at_utc.astimezone(local_tz)
token_total = (
record.token_input_other
+ record.token_input_cached
+ record.token_output
)
provider_id = record.provider_id or "unknown"
provider_model = record.provider_model or "Unknown"
if created_at_local >= range_start_local:
bucket_local = created_at_local.replace(
minute=0, second=0, microsecond=0
)
bucket_ts = int(bucket_local.timestamp() * 1000)
trend_by_provider[provider_id][bucket_ts] += token_total
total_by_provider[provider_id] += token_total
total_by_umo[record.umo or "unknown"] += token_total
total_by_bucket[bucket_ts] += token_total
range_total_tokens += token_total
range_total_calls += 1
if record.status != "error":
range_success_calls += 1
if record.time_to_first_token > 0:
range_ttft_total_ms += record.time_to_first_token * 1000
range_ttft_samples += 1
if record.end_time > record.start_time:
range_duration_total_ms += (
record.end_time - record.start_time
) * 1000
range_duration_samples += 1
range_total_output_tokens += record.token_output
if created_at_local >= today_start_local:
today_total_calls += 1
today_total_tokens += token_total
today_by_model[provider_model] += token_total
today_by_provider[provider_id] += token_total
sorted_provider_ids = sorted(
total_by_provider.keys(),
key=lambda item: total_by_provider[item],
reverse=True,
)
series = [
{
"name": provider_id,
"data": [
[bucket_ts, trend_by_provider[provider_id].get(bucket_ts, 0)]
for bucket_ts in bucket_timestamps
],
"total_tokens": total_by_provider[provider_id],
}
for provider_id in sorted_provider_ids
]
total_series = [
[bucket_ts, total_by_bucket.get(bucket_ts, 0)]
for bucket_ts in bucket_timestamps
]
today_by_model_data = [
{"provider_model": model_name, "tokens": tokens}
for model_name, tokens in sorted(
today_by_model.items(),
key=lambda item: item[1],
reverse=True,
)
]
today_by_provider_data = [
{"provider_id": provider_id, "tokens": tokens}
for provider_id, tokens in sorted(
today_by_provider.items(),
key=lambda item: item[1],
reverse=True,
)
]
range_by_provider_data = [
{"provider_id": provider_id, "tokens": tokens}
for provider_id, tokens in sorted(
total_by_provider.items(),
key=lambda item: item[1],
reverse=True,
)
]
range_by_umo_data = [
{"umo": umo, "tokens": tokens}
for umo, tokens in sorted(
total_by_umo.items(),
key=lambda item: item[1],
reverse=True,
)
]
return (
Response()
.ok(
{
"days": days,
"trend": {
"series": series,
"total_series": total_series,
},
"range_total_tokens": range_total_tokens,
"range_total_calls": range_total_calls,
"range_avg_ttft_ms": (
range_ttft_total_ms / range_ttft_samples
if range_ttft_samples
else 0
),
"range_avg_duration_ms": (
range_duration_total_ms / range_duration_samples
if range_duration_samples
else 0
),
"range_avg_tpm": (
range_total_output_tokens
/ (range_duration_total_ms / 1000 / 60)
if range_duration_total_ms > 0
else 0
),
"range_success_rate": (
range_success_calls / range_total_calls
if range_total_calls
else 0
),
"range_by_provider": range_by_provider_data,
"range_by_umo": range_by_umo_data,
"today_total_tokens": today_total_tokens,
"today_total_calls": today_total_calls,
"today_by_model": today_by_model_data,
"today_by_provider": today_by_provider_data,
}
)
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Error: {e!s}").__dict__
)
async def test_ghproxy_connection(self):
"""测试 GitHub 代理连接是否可用。"""
try:
data = await request.get_json()
proxy_url: str = data.get("proxy_url")
if not proxy_url:
return Response().error("proxy_url is required").__dict__
proxy_url = proxy_url.rstrip("/")
test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
start_time = time.time()
async with (
aiohttp.ClientSession() as session,
session.get(
test_url,
timeout=aiohttp.ClientTimeout(total=10),
) as response,
):
if response.status == 200:
end_time = time.time()
_ = await response.text()
ret = {
"latency": round((end_time - start_time) * 1000, 2),
}
return Response().ok(data=ret).__dict__
return (
Response().error(f"Failed. Status code: {response.status}").__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Error: {e!s}").__dict__
return await self._run_json(
self.service.test_ghproxy_connection_from_legacy_payload
)
async def get_changelog(self):
"""获取指定版本的更新日志"""
try:
version = request.args.get("version")
if not version:
return Response().error("version parameter is required").__dict__
version = version.lstrip("v")
# 防止路径遍历攻击
if not re.match(r"^[a-zA-Z0-9._-]+$", version):
return Response().error("Invalid version format").__dict__
if ".." in version or "/" in version or "\\" in version:
return Response().error("Invalid version format").__dict__
filename = f"v{version}.md"
project_path = get_astrbot_path()
changelogs_dir = os.path.join(project_path, "changelogs")
changelog_path = os.path.join(changelogs_dir, filename)
# 规范化路径,防止符号链接攻击
changelog_path = os.path.realpath(changelog_path)
changelogs_dir = os.path.realpath(changelogs_dir)
# 验证最终路径在预期的 changelogs 目录内(防止路径遍历)
# 确保规范化后的路径以 changelogs_dir 开头,且是目录内的文件
changelog_path_normalized = os.path.normpath(changelog_path)
changelogs_dir_normalized = os.path.normpath(changelogs_dir)
# 检查路径是否在预期目录内(必须是目录的子文件,不能是目录本身)
expected_prefix = changelogs_dir_normalized + os.sep
if not changelog_path_normalized.startswith(expected_prefix):
logger.warning(
f"Path traversal attempt detected: {version} -> {changelog_path}",
)
return Response().error("Invalid version format").__dict__
if not os.path.exists(changelog_path):
return (
Response()
.error(f"Changelog for version {version} not found")
.__dict__
)
if not os.path.isfile(changelog_path):
return (
Response()
.error(f"Changelog for version {version} not found")
.__dict__
)
with open(changelog_path, encoding="utf-8") as f:
content = f.read()
return Response().ok({"content": content, "version": version}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Error: {e!s}").__dict__
return await self._run_sync(
lambda: self.service.get_changelog(request.args.get("version"))
)
async def list_changelog_versions(self):
"""获取所有可用的更新日志版本列表"""
try:
project_path = get_astrbot_path()
changelogs_dir = os.path.join(project_path, "changelogs")
if not os.path.exists(changelogs_dir):
return Response().ok({"versions": []}).__dict__
versions = []
for filename in os.listdir(changelogs_dir):
if filename.endswith(".md") and filename.startswith("v"):
# 提取版本号(去除 v 前缀和 .md 后缀)
version = filename[1:-3] # 去掉 "v" 和 ".md"
# 验证版本号格式
if re.match(r"^[a-zA-Z0-9._-]+$", version):
versions.append(version)
# 按版本号排序(降序,最新的在前)
# 使用项目中的 VersionComparator 进行语义化版本号排序
versions.sort(
key=cmp_to_key(
lambda v1, v2: VersionComparator.compare_version(v2, v1),
),
)
return Response().ok({"versions": versions}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Error: {e!s}").__dict__
return await self._run_sync(self.service.list_changelog_versions)
async def get_first_notice(self):
"""读取项目根目录 FIRST_NOTICE.md 内容。"""
try:
locale = (request.args.get("locale") or "").strip()
if not re.match(r"^[A-Za-z0-9_-]*$", locale):
locale = ""
base_path = Path(get_astrbot_path())
candidates: list[Path] = []
if locale:
candidates.append(base_path / f"FIRST_NOTICE.{locale}.md")
if locale.lower().startswith("zh"):
candidates.append(base_path / "FIRST_NOTICE.md")
candidates.append(base_path / "FIRST_NOTICE.zh-CN.md")
elif locale.lower().startswith("en"):
candidates.append(base_path / "FIRST_NOTICE.en-US.md")
candidates.extend(
[
base_path / "FIRST_NOTICE.md",
base_path / "FIRST_NOTICE.en-US.md",
],
)
for notice_path in candidates:
if not notice_path.is_file():
continue
content = notice_path.read_text(encoding="utf-8")
if content.strip():
return Response().ok({"content": content}).__dict__
return Response().ok({"content": None}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Error: {e!s}").__dict__
return await self._run_sync(
lambda: self.service.get_first_notice(request.args.get("locale"))
)

View File

@@ -1,37 +1,19 @@
from astrbot.dashboard.services.static_file_service import StaticFileService
from .route import Route, RouteContext
class StaticFileRoute(Route):
def __init__(self, context: RouteContext) -> None:
super().__init__(context)
self.service = StaticFileService()
index_ = [
"/",
"/auth/login",
"/config",
"/logs",
"/extension",
"/dashboard/default",
"/alkaid",
"/alkaid/knowledge-base",
"/alkaid/long-term-memory",
"/alkaid/other",
"/console",
"/chat",
"/settings",
"/platforms",
"/providers",
"/about",
"/extension-marketplace",
"/conversation",
"/tool-use",
]
for i in index_:
for i in self.service.list_index_routes():
self.app.add_url_rule(i, view_func=self.index)
@self.app.errorhandler(404)
async def page_not_found(e) -> str:
return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://docs.astrbot.app/faq.html。如果你正在测试回调地址可达性显示这段文字说明测试成功了。"
return self.service.get_not_found_message()
async def index(self):
return await self.app.send_static_file("index.html")

View File

@@ -1,10 +1,9 @@
import traceback
from quart import jsonify, request
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.dashboard.fastapi_compat import jsonify, request
from astrbot.dashboard.services.subagent_service import (
SubAgentService,
SubAgentServiceError,
)
from .route import Response, Route, RouteContext
@@ -16,7 +15,7 @@ class SubAgentRoute(Route):
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.service = SubAgentService(core_lifecycle)
# NOTE: dict cannot hold duplicate keys; use list form to register multiple
# methods for the same path.
self.routes = [
@@ -26,92 +25,35 @@ class SubAgentRoute(Route):
]
self.register_routes()
async def get_config(self):
@staticmethod
def _response(data=None, message: str | None = None):
return jsonify(Response().ok(data=data, message=message).__dict__)
@staticmethod
def _error(message: str):
return jsonify(Response().error(message).__dict__)
async def _run(self, operation, *, message: str | None = None):
try:
cfg = self.core_lifecycle.astrbot_config
data = cfg.get("subagent_orchestrator")
return self._response(await operation(), message)
except SubAgentServiceError as exc:
return self._error(str(exc))
# First-time access: return a sane default instead of erroring.
if not isinstance(data, dict):
data = {
"main_enable": False,
"remove_main_duplicate_tools": False,
"agents": [],
}
async def _run_sync(self, operation):
try:
return self._response(operation())
except SubAgentServiceError as exc:
return self._error(str(exc))
# Backward compatibility: older config used `enable`.
if (
isinstance(data, dict)
and "main_enable" not in data
and "enable" in data
):
data["main_enable"] = bool(data.get("enable", False))
async def _run_json(self, operation, *, message: str | None = None):
data = await request.json
return await self._run(lambda: operation(data), message=message)
# Ensure required keys exist.
data.setdefault("main_enable", False)
data.setdefault("remove_main_duplicate_tools", False)
data.setdefault("agents", [])
# Backward/forward compatibility: ensure each agent contains provider_id.
# None means follow global/default provider settings.
if isinstance(data.get("agents"), list):
for a in data["agents"]:
if isinstance(a, dict):
a.setdefault("provider_id", None)
a.setdefault("persona_id", None)
return jsonify(Response().ok(data=data).__dict__)
except Exception as e:
logger.error(traceback.format_exc())
return jsonify(Response().error(f"获取 subagent 配置失败: {e!s}").__dict__)
async def get_config(self):
return await self._run_sync(self.service.get_config)
async def update_config(self):
try:
data = await request.json
if not isinstance(data, dict):
return jsonify(Response().error("配置必须为 JSON 对象").__dict__)
cfg = self.core_lifecycle.astrbot_config
cfg["subagent_orchestrator"] = data
# Persist to cmd_config.json
# AstrBotConfigManager does not expose a `save()` method; persist via AstrBotConfig.
cfg.save_config()
# Reload dynamic handoff tools if orchestrator exists
orch = getattr(self.core_lifecycle, "subagent_orchestrator", None)
if orch is not None:
await orch.reload_from_config(data)
return jsonify(Response().ok(message="保存成功").__dict__)
except Exception as e:
logger.error(traceback.format_exc())
return jsonify(Response().error(f"保存 subagent 配置失败: {e!s}").__dict__)
return await self._run_json(self.service.update_config, message="保存成功")
async def get_available_tools(self):
"""Return all registered tools (name/description/parameters/active/origin).
UI can use this to build a multi-select list for subagent tool assignment.
"""
try:
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools_dict = []
for tool in tool_mgr.func_list:
# Prevent recursive routing: subagents should not be able to select
# the handoff (transfer_to_*) tools as their own mounted tools.
if isinstance(tool, HandoffTool):
continue
if tool.handler_module_path == "core.subagent_orchestrator":
continue
tools_dict.append(
{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
"active": tool.active,
"handler_module_path": tool.handler_module_path,
}
)
return jsonify(Response().ok(data=tools_dict).__dict__)
except Exception as e:
logger.error(traceback.format_exc())
return jsonify(Response().error(f"获取可用工具失败: {e!s}").__dict__)
return await self._run_sync(self.service.get_available_tools)

View File

@@ -2,11 +2,9 @@
from dataclasses import asdict
from quart import jsonify, request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.utils.t2i.template_manager import TemplateManager
from astrbot.dashboard.fastapi_compat import jsonify, request
from astrbot.dashboard.services.t2i_service import T2iService, T2iServiceError
from .route import Response, Route, RouteContext
@@ -16,9 +14,7 @@ class T2iRoute(Route):
self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.manager = TemplateManager()
self.service = T2iService(core_lifecycle)
# 使用列表保证路由注册顺序,避免 /<name> 路由优先匹配 /reset_default
self.routes = [
("/t2i/templates", ("GET", self.list_templates)),
@@ -28,7 +24,7 @@ class T2iRoute(Route):
("/t2i/templates/set_active", ("POST", self.set_active_template)),
# 动态路由应该在静态路由之后注册
(
"/t2i/templates/<name>",
"/t2i/templates/<path:name>",
[
("GET", self.get_template),
("PUT", self.update_template),
@@ -38,200 +34,94 @@ class T2iRoute(Route):
]
self.register_routes()
async def _reload_all_pipeline_schedulers(self) -> None:
"""热重载所有配置对应的 pipeline scheduler。"""
for conf_id in self.core_lifecycle.astrbot_config_mgr.confs:
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
@staticmethod
def _ok(data=None, message: str | None = None, status_code: int = 200):
response = jsonify(asdict(Response().ok(data=data, message=message)))
response.status_code = status_code
return response
async def _sync_active_template_to_all_configs(self, name: str) -> None:
"""同步当前激活模板到所有配置文件,并热重载对应流水线。"""
for config in self.core_lifecycle.astrbot_config_mgr.confs.values():
config["t2i_active_template"] = name
config.save_config()
await self._reload_all_pipeline_schedulers()
@staticmethod
def _service_error(exc: T2iServiceError):
response = jsonify(asdict(Response().error(str(exc))))
response.status_code = exc.status_code
return response
@staticmethod
async def _request_data() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(
self,
operation,
*,
message: str | None = None,
status_code: int = 200,
result_as_message: bool = False,
):
try:
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
if isinstance(result, tuple):
payload, result_message = result
return self._ok(data=payload, message=result_message)
if result_as_message:
return self._ok(message=str(result), status_code=status_code)
return self._ok(data=result, message=message, status_code=status_code)
except T2iServiceError as exc:
return self._service_error(exc)
async def _run_json(self, operation, **kwargs):
async def invoke():
data = await self._request_data()
return operation(data)
return await self._run(invoke, **kwargs)
async def list_templates(self):
"""获取所有T2I模板列表"""
try:
templates = self.manager.list_templates()
return jsonify(asdict(Response().ok(data=templates)))
except Exception as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 500
return response
return await self._run(self.service.list_templates)
async def get_active_template(self):
"""获取当前激活的T2I模板"""
try:
active_template = self.config.get("t2i_active_template", "base")
return jsonify(
asdict(Response().ok(data={"active_template": active_template})),
)
except Exception as e:
logger.error("Error in get_active_template", exc_info=True)
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 500
return response
return await self._run(self.service.get_active_template)
async def get_template(self, name: str):
"""获取指定名称的T2I模板内容"""
try:
content = self.manager.get_template(name)
return jsonify(
asdict(Response().ok(data={"name": name, "content": content})),
)
except FileNotFoundError:
response = jsonify(asdict(Response().error("Template not found")))
response.status_code = 404
return response
except Exception as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 500
return response
return await self._run(lambda: self.service.get_template(name))
async def create_template(self):
"""创建一个新的T2I模板"""
try:
data = await request.json
name = data.get("name")
content = data.get("content")
if not name or not content:
response = jsonify(
asdict(Response().error("Name and content are required.")),
)
response.status_code = 400
return response
name = name.strip()
self.manager.create_template(name, content)
response = jsonify(
asdict(
Response().ok(
data={"name": name},
message="Template created successfully.",
),
),
)
response.status_code = 201
return response
except FileExistsError:
response = jsonify(
asdict(Response().error("Template with this name already exists.")),
)
response.status_code = 409
return response
except ValueError as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 400
return response
except Exception as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 500
return response
return await self._run_json(
self.service.create_template_from_legacy_payload,
message="Template created successfully.",
status_code=201,
)
async def update_template(self, name: str):
"""更新一个已存在的T2I模板"""
try:
name = name.strip()
data = await request.json
content = data.get("content")
if content is None:
response = jsonify(asdict(Response().error("Content is required.")))
response.status_code = 400
return response
self.manager.update_template(name, content)
# 检查更新的是否为当前激活的模板,如果是,则热重载
active_template = self.config.get("t2i_active_template", "base")
if name == active_template:
await self._reload_all_pipeline_schedulers()
message = f"模板 '{name}' 已更新并重新加载。"
else:
message = f"模板 '{name}' 已更新。"
return jsonify(asdict(Response().ok(data={"name": name}, message=message)))
except ValueError as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 400
return response
except Exception as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 500
return response
return await self._run_json(
lambda data: self.service.update_template_from_legacy_payload(name, data)
)
async def delete_template(self, name: str):
"""删除一个T2I模板"""
try:
name = name.strip()
self.manager.delete_template(name)
return jsonify(
asdict(Response().ok(message="Template deleted successfully.")),
)
except FileNotFoundError:
response = jsonify(asdict(Response().error("Template not found.")))
response.status_code = 404
return response
except ValueError as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 400
return response
except Exception as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 500
return response
return await self._run(
lambda: self.service.delete_template(name),
message="Template deleted successfully.",
)
async def set_active_template(self):
"""设置当前活动的T2I模板"""
try:
data = await request.json
name = data.get("name")
if not name:
response = jsonify(asdict(Response().error("模板名称(name)不能为空。")))
response.status_code = 400
return response
# 验证模板文件是否存在
self.manager.get_template(name)
# 更新所有配置并热重载以应用更改
await self._sync_active_template_to_all_configs(name)
return jsonify(asdict(Response().ok(message=f"模板 '{name}' 已成功应用。")))
except FileNotFoundError:
response = jsonify(
asdict(Response().error(f"模板 '{name}' 不存在,无法应用。")),
)
response.status_code = 404
return response
except Exception as e:
logger.error("Error in set_active_template", exc_info=True)
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 500
return response
return await self._run_json(
self.service.set_active_template_from_legacy_payload,
result_as_message=True,
)
async def reset_default_template(self):
"""重置默认的'base'模板"""
try:
self.manager.reset_default_template()
# 更新所有配置,将激活模板也重置为'base'
await self._sync_active_template_to_all_configs("base")
return jsonify(
asdict(
Response().ok(
message="Default template has been reset and activated.",
),
),
)
except FileNotFoundError as e:
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 404
return response
except Exception as e:
logger.error("Error in reset_default_template", exc_info=True)
response = jsonify(asdict(Response().error(str(e))))
response.status_code = 500
return response
return await self._run(
self.service.reset_default_template(),
result_as_message=True,
)

View File

@@ -1,43 +1,9 @@
import traceback
from quart import request
from astrbot.core import logger
from astrbot.core.agent.mcp_client import MCPTool, validate_mcp_stdio_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star import star_map
from astrbot.core.tools.registry import get_builtin_tool_config_statuses
from astrbot.dashboard.fastapi_compat import request
from astrbot.dashboard.services.tools_service import ToolsService, ToolsServiceError
from .route import Response, Route, RouteContext
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
class EmptyMcpServersError(ValueError):
"""Raised when mcpServers is empty."""
pass
def _extract_mcp_server_config(mcp_servers_value: object) -> dict:
"""Extract server configuration from user-submitted mcpServers field.
Raises:
ValueError: Invalid configuration
"""
if not isinstance(mcp_servers_value, dict):
raise ValueError("mcpServers must be a JSON object")
if not mcp_servers_value:
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
key_0 = next(iter(mcp_servers_value))
extracted = mcp_servers_value[key_0]
if not isinstance(extracted, dict):
raise ValueError(
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
"and each value is an object containing fields like command/url."
)
return extracted
class ToolsRoute(Route):
def __init__(
@@ -46,7 +12,7 @@ class ToolsRoute(Route):
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.service = ToolsService(core_lifecycle)
self.routes = {
"/tools/mcp/servers": ("GET", self.get_mcp_servers),
"/tools/mcp/add": ("POST", self.add_mcp_server),
@@ -58,514 +24,80 @@ class ToolsRoute(Route):
"/tools/mcp/sync-provider": ("POST", self.sync_provider),
}
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
def _rollback_mcp_server(self, name: str) -> bool:
@staticmethod
def _ok(data: dict | list | None = None, message: str | None = None) -> dict:
return Response().ok(data, message).__dict__
@staticmethod
def _error(message: str) -> dict:
return Response().error(message).__dict__
@staticmethod
async def _json_body() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(self, operation, *, message: str | None = None) -> dict:
try:
rollback_config = self.tool_mgr.load_mcp_config()
if name in rollback_config["mcpServers"]:
rollback_config["mcpServers"].pop(name)
return self.tool_mgr.save_mcp_config(rollback_config)
return True
except Exception:
logger.error(traceback.format_exc())
return False
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
return self._ok(result, message)
except ToolsServiceError as exc:
return self._error(str(exc))
async def _run_json(
self,
operation,
*,
message: str | None = None,
result_as_message: bool = False,
) -> dict:
async def invoke():
data = await self._json_body()
return operation(data)
result = await self._run(invoke)
if result_as_message and result.get("status") == "ok":
return self._ok(None, result["data"])
if message and result.get("status") == "ok":
return self._ok(result.get("data"), message)
return result
async def get_mcp_servers(self):
try:
config = self.tool_mgr.load_mcp_config()
servers = []
mcp_servers = config.get("mcpServers", {})
if not isinstance(mcp_servers, dict):
logger.warning(
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
)
mcp_servers = {}
# 获取所有服务器并添加它们的工具列表
for name, server_config in mcp_servers.items():
if not isinstance(server_config, dict):
logger.warning(
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
)
continue
server_info = {
"name": name,
"active": server_config.get("active", True),
}
# 复制所有配置字段
for key, value in server_config.items():
if key != "active": # active 已经处理
server_info[key] = value
# 如果MCP客户端已初始化从客户端获取工具名称
for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items():
if name_key == name:
mcp_client = runtime.client
server_info["tools"] = [tool.name for tool in mcp_client.tools]
server_info["errlogs"] = mcp_client.server_errlogs
break
else:
server_info["tools"] = []
servers.append(server_info)
return Response().ok(servers).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to get MCP server list: {e!s}").__dict__
return await self._run(self.service.get_mcp_servers)
async def add_mcp_server(self):
try:
server_data = await request.json
name = server_data.get("name", "")
# 检查必填字段
if not name:
return Response().error("Server name cannot be empty").__dict__
# 移除特殊字段并检查配置是否有效
has_valid_config = False
server_config = {"active": server_data.get("active", True)}
# 复制所有配置字段
for key, value in server_data.items():
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
if key == "mcpServers":
try:
server_config = _extract_mcp_server_config(
server_data["mcpServers"]
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
else:
server_config[key] = value
has_valid_config = True
if not has_valid_config:
return (
Response()
.error("A valid server configuration is required")
.__dict__
)
try:
validate_mcp_stdio_config(server_config)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
config = self.tool_mgr.load_mcp_config()
if name in config["mcpServers"]:
return Response().error(f"Server {name} already exists").__dict__
try:
await self.tool_mgr.test_mcp_server_connection(server_config)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"MCP connection test failed: {e!s}").__dict__
config["mcpServers"][name] = server_config
if self.tool_mgr.save_mcp_config(config):
try:
await self.tool_mgr.enable_mcp_server(
name,
server_config,
timeout=30,
)
except TimeoutError:
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"Timed out while enabling MCP server {name}."
if not rollback_ok:
err_msg += " Configuration rollback failed. Please check the config manually."
return Response().error(err_msg).__dict__
except Exception as e:
logger.error(traceback.format_exc())
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"Failed to enable MCP server {name}: {e!s}"
if not rollback_ok:
err_msg += " Configuration rollback failed. Please check the config manually."
return Response().error(err_msg).__dict__
return (
Response()
.ok(None, f"Successfully added MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to add MCP server: {e!s}").__dict__
return await self._run_json(self.service.add_mcp_server, result_as_message=True)
async def update_mcp_server(self):
try:
server_data = await request.json
name = server_data.get("name", "")
old_name = server_data.get("oldName") or name
if not name:
return Response().error("Server name cannot be empty").__dict__
config = self.tool_mgr.load_mcp_config()
if old_name not in config["mcpServers"]:
return Response().error(f"Server {old_name} does not exist").__dict__
is_rename = name != old_name
if name in config["mcpServers"] and is_rename:
return Response().error(f"Server {name} already exists").__dict__
# 获取活动状态
old_config = config["mcpServers"][old_name]
if isinstance(old_config, dict):
old_active = old_config.get("active", True)
else:
old_active = True
active = server_data.get("active", old_active)
# 创建新的配置对象
server_config = {"active": active}
# 仅更新活动状态的特殊处理
only_update_active = True
# 复制所有配置字段
for key, value in server_data.items():
if key not in [
"name",
"active",
"tools",
"errlogs",
"oldName",
]: # 排除特殊字段
if key == "mcpServers":
try:
server_config = _extract_mcp_server_config(
server_data["mcpServers"]
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
else:
server_config[key] = value
only_update_active = False
# 如果只更新活动状态,保留原始配置
if only_update_active and isinstance(old_config, dict):
for key, value in old_config.items():
if key != "active": # 除了active之外的所有字段都保留
server_config[key] = value
try:
validate_mcp_stdio_config(server_config)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
# config["mcpServers"][name] = server_config
if is_rename:
config["mcpServers"].pop(old_name)
config["mcpServers"][name] = server_config
else:
config["mcpServers"][name] = server_config
if self.tool_mgr.save_mcp_config(config):
# 处理MCP客户端状态变化
if active:
if (
old_name in self.tool_mgr.mcp_server_runtime_view
or not only_update_active
or is_rename
):
try:
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
except TimeoutError as e:
return (
Response()
.error(
f"Timed out while disabling MCP server {old_name} before enabling: {e!s}"
)
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(
f"Failed to disable MCP server {old_name} before enabling: {e!s}"
)
.__dict__
)
try:
await self.tool_mgr.enable_mcp_server(
name,
config["mcpServers"][name],
timeout=30,
)
except TimeoutError:
return (
Response()
.error(f"Timed out while enabling MCP server {name}.")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"Failed to enable MCP server {name}: {e!s}")
.__dict__
)
# 如果要停用服务器
elif old_name in self.tool_mgr.mcp_server_runtime_view:
try:
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
except TimeoutError:
return (
Response()
.error(f"Timed out while disabling MCP server {old_name}.")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"Failed to disable MCP server {old_name}: {e!s}")
.__dict__
)
return (
Response()
.ok(None, f"Successfully updated MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to update MCP server: {e!s}").__dict__
return await self._run_json(
self.service.update_mcp_server,
result_as_message=True,
)
async def delete_mcp_server(self):
try:
server_data = await request.json
name = server_data.get("name", "")
if not name:
return Response().error("Server name cannot be empty").__dict__
config = self.tool_mgr.load_mcp_config()
if name not in config["mcpServers"]:
return Response().error(f"Server {name} does not exist").__dict__
del config["mcpServers"][name]
if self.tool_mgr.save_mcp_config(config):
if name in self.tool_mgr.mcp_server_runtime_view:
try:
await self.tool_mgr.disable_mcp_server(name, timeout=10)
except TimeoutError:
return (
Response()
.error(f"Timed out while disabling MCP server {name}.")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"Failed to disable MCP server {name}: {e!s}")
.__dict__
)
return (
Response()
.ok(None, f"Successfully deleted MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to delete MCP server: {e!s}").__dict__
return await self._run_json(
self.service.delete_mcp_server,
result_as_message=True,
)
async def test_mcp_connection(self):
"""Test MCP server connection."""
try:
server_data = await request.json
config = server_data.get("mcp_server_config", None)
if not isinstance(config, dict) or not config:
return Response().error("Invalid MCP server configuration").__dict__
if "mcpServers" in config:
mcp_servers = config["mcpServers"]
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
return (
Response()
.error(
"Only one MCP server configuration can be tested at a time"
)
.__dict__
)
try:
config = _extract_mcp_server_config(mcp_servers)
except EmptyMcpServersError:
return (
Response()
.error("MCP server configuration cannot be empty")
.__dict__
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
elif not config:
return (
Response()
.error("MCP server configuration cannot be empty")
.__dict__
)
try:
validate_mcp_stdio_config(config)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
return (
Response()
.ok(data=tools_name, message="🎉 MCP server is available!")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to test MCP connection: {e!s}").__dict__
return await self._run_json(
self.service.test_mcp_connection,
message="🎉 MCP server is available!",
)
async def get_tool_list(self):
"""Get all registered tools."""
try:
tools = list(self.tool_mgr.func_list)
existing_names = {tool.name for tool in tools}
for tool in self.tool_mgr.iter_builtin_tools():
if tool.name not in existing_names:
tools.append(tool)
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
conf_name_map = {conf["id"]: conf["name"] for conf in conf_list}
config_entries = []
for conf_id, conf in self.core_lifecycle.astrbot_config_mgr.confs.items():
config_entries.append(
{
"conf_id": conf_id,
"conf_name": conf_name_map.get(conf_id, conf_id),
"config": conf,
}
)
tools_dict = []
for tool in tools:
readonly = False
builtin_config_statuses = []
builtin_config_tags = []
if self.tool_mgr.is_builtin_tool(tool.name):
origin = "builtin"
origin_name = "AstrBot Core"
readonly = True
builtin_config_statuses = get_builtin_tool_config_statuses(
tool.name,
config_entries,
)
builtin_config_tags = [
status
for status in builtin_config_statuses
if status["enabled"]
]
elif isinstance(tool, MCPTool):
origin = "mcp"
origin_name = tool.mcp_server_name
elif tool.handler_module_path and star_map.get(
tool.handler_module_path
):
star = star_map[tool.handler_module_path]
origin = "plugin"
origin_name = star.name
else:
origin = "unknown"
origin_name = "unknown"
tool_info = {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
"active": tool.active,
"origin": origin,
"origin_name": origin_name,
"readonly": readonly,
"builtin_config_statuses": builtin_config_statuses,
"builtin_config_tags": builtin_config_tags,
}
tools_dict.append(tool_info)
return Response().ok(data=tools_dict).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to get tool list: {e!s}").__dict__
return await self._run(self.service.get_tool_list)
async def toggle_tool(self):
"""Activate or deactivate a specified tool."""
try:
data = await request.json
tool_name = data.get("name")
action = data.get("activate") # True or False
if not tool_name or action is None:
return (
Response()
.error("Missing required parameters: name or activate")
.__dict__
)
if self.tool_mgr.is_builtin_tool(tool_name):
return (
Response()
.error("Builtin tools are read-only and cannot be toggled.")
.__dict__
)
if action:
try:
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
except ValueError as e:
return Response().error(f"Failed to activate tool: {e!s}").__dict__
else:
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
if ok:
return Response().ok(None, "Operation successful.").__dict__
return (
Response()
.error(f"Tool {tool_name} does not exist or the operation failed.")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Failed to operate tool: {e!s}").__dict__
return await self._run_json(self.service.toggle_tool, result_as_message=True)
async def sync_provider(self):
"""Sync MCP provider configuration."""
try:
data = await request.json
provider_name = data.get("name") # modelscope, or others
match provider_name:
case "modelscope":
access_token = data.get("access_token", "")
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
case _:
return (
Response().error(f"Unknown provider: {provider_name}").__dict__
)
return Response().ok(message="Sync completed").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Sync failed: {e!s}").__dict__
return await self._run_json(self.service.sync_provider, result_as_message=True)

View File

@@ -1,17 +1,26 @@
import traceback
import uuid
from __future__ import annotations
from quart import request
from typing import TYPE_CHECKING
from astrbot.core import DEMO_MODE, logger, pip_installer
from astrbot.core.config.default import VERSION
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.migration.helper import check_migration_needed_v4, do_migration_v4
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.dashboard.fastapi_compat import request
from astrbot.dashboard.services.update_service import (
DEMO_MODE,
UpdateService,
UpdateServiceError,
UpdateServiceResult,
call_check_migration_needed_v4,
call_do_migration_v4,
call_download_dashboard,
call_get_dashboard_version,
call_pip_install,
)
from .route import Response, Route, RouteContext
if TYPE_CHECKING:
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.updator import AstrBotUpdator
CLEAR_SITE_DATA_HEADERS = {"Clear-Site-Data": '"cache"'}
@@ -32,323 +41,82 @@ class UpdateRoute(Route):
"/update/pip-install": ("POST", self.install_pip_package),
"/update/migration": ("POST", self.do_migration),
}
self.astrbot_updator = astrbot_updator
self.core_lifecycle = core_lifecycle
self.update_progress: dict[str, dict] = {}
self.service = UpdateService(
astrbot_updator,
core_lifecycle,
download_dashboard_func=call_download_dashboard,
get_dashboard_version_func=call_get_dashboard_version,
pip_install_func=call_pip_install,
check_migration_needed_func=call_check_migration_needed_v4,
do_migration_func=call_do_migration_v4,
demo_mode=DEMO_MODE,
clear_site_data_headers=CLEAR_SITE_DATA_HEADERS,
)
self.register_routes()
def _init_update_progress(self, progress_id: str, version: str) -> None:
self.update_progress[progress_id] = {
"id": progress_id,
"status": "running",
"stage": "preparing",
"version": version or "latest",
"message": "正在准备更新...",
"overall_percent": 0,
"stages": {
"dashboard": self._empty_stage("pending"),
"core": self._empty_stage("pending"),
},
}
@staticmethod
def _service_response(result: UpdateServiceResult):
if result.status == "success":
payload = Response(
status="success",
message=result.message,
data=result.data,
).__dict__
else:
payload = Response().ok(result.data, result.message).__dict__
if result.headers:
return payload, 200, result.headers
return payload
@staticmethod
def _empty_stage(status: str = "pending") -> dict:
return {
"status": status,
"downloaded": 0,
"total": 0,
"percent": 0,
"speed": 0,
}
def _set_update_stage(
self,
progress_id: str,
stage: str,
status: str,
message: str,
overall_percent: int | None = None,
) -> None:
progress = self.update_progress.get(progress_id)
if not progress:
return
progress["stage"] = stage
progress["message"] = message
progress["stages"].setdefault(stage, self._empty_stage())
progress["stages"][stage]["status"] = status
if overall_percent is not None:
progress["overall_percent"] = overall_percent
def _service_error(exc: UpdateServiceError):
return Response().error(str(exc)).__dict__
@staticmethod
def _normalize_percent(value) -> int:
async def _json_body() -> dict:
data = await request.get_json()
return data if isinstance(data, dict) else {}
async def _run(self, operation):
try:
percent = float(value or 0)
except (TypeError, ValueError):
return 0
if percent <= 1:
percent *= 100
return max(0, min(100, int(percent)))
result = operation() if callable(operation) else operation
while hasattr(result, "__await__"):
result = await result
return self._service_response(result)
except UpdateServiceError as exc:
return self._service_error(exc)
def _make_progress_callback(
self,
progress_id: str,
stage: str,
stage_start: int,
stage_weight: int,
):
def _callback(payload: dict) -> None:
progress = self.update_progress.get(progress_id)
if not progress:
return
stage_percent = self._normalize_percent(payload.get("percent"))
progress["stage"] = stage
progress["stages"][stage] = {
"status": "running" if stage_percent < 100 else "done",
"downloaded": payload.get("downloaded", 0),
"total": payload.get("total", 0),
"percent": stage_percent,
"speed": payload.get("speed", 0),
}
progress["overall_percent"] = min(
99,
stage_start + int(stage_percent * stage_weight / 100),
)
async def _run_json(self, operation):
async def invoke():
data = await self._json_body()
return operation(data)
return _callback
return await self._run(invoke)
async def get_update_progress(self):
progress_id = request.args.get("id", "")
if not progress_id:
return Response().error("缺少参数 id").__dict__
progress = self.update_progress.get(progress_id)
if not progress:
return (
Response()
.ok(
{"id": progress_id, "status": "idle"},
"没有正在进行的更新。",
)
.__dict__
return await self._run(
lambda: self.service.get_update_progress_from_legacy_query(
request.args.get("id")
)
return Response().ok(progress).__dict__
)
async def do_migration(self):
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
if not need_migration:
return Response().ok(None, "不需要进行迁移。").__dict__
try:
data = await request.json
pim = data.get("platform_id_map", {})
await do_migration_v4(
self.core_lifecycle.db,
pim,
self.core_lifecycle.astrbot_config,
)
return Response().ok(None, "迁移成功。").__dict__
except Exception as e:
logger.error(f"迁移失败: {traceback.format_exc()}")
return Response().error(f"迁移失败: {e!s}").__dict__
return await self._run_json(self.service.do_migration_v4)
async def check_update(self):
type_ = request.args.get("type", None)
try:
dv = await get_dashboard_version()
if type_ == "dashboard":
return (
Response()
.ok({"has_new_version": dv != f"v{VERSION}", "current_version": dv})
.__dict__
)
ret = await self.astrbot_updator.check_update(None, None, False)
return Response(
status="success",
message=str(ret) if ret is not None else "已经是最新版本了。",
data={
"version": f"v{VERSION}",
"has_new_version": ret is not None,
"dashboard_version": dv,
"dashboard_has_new_version": bool(dv and dv != f"v{VERSION}"),
},
).__dict__
except Exception as e:
logger.warning(f"检查更新失败: {e!s} (不影响除项目更新外的正常使用)")
return Response().error(e.__str__()).__dict__
return await self._run(
self.service.check_update_from_legacy_query(request.args.get("type"))
)
async def get_releases(self):
try:
ret = await self.astrbot_updator.get_releases()
return Response().ok(ret).__dict__
except Exception as e:
logger.error(f"/api/update/releases: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
return await self._run(self.service.get_releases())
async def update_project(self):
data = await request.json
version = data.get("version", "")
reboot = data.get("reboot", True)
progress_id = data.get("progress_id") or uuid.uuid4().hex
if version == "" or version == "latest":
latest = True
version = ""
else:
latest = False
proxy: str = data.get("proxy", None)
if proxy:
proxy = proxy.removesuffix("/")
self._init_update_progress(progress_id, version)
try:
self._set_update_stage(
progress_id,
"dashboard",
"running",
"正在下载 WebUI...",
0,
)
await download_dashboard(
latest=latest,
version=version,
proxy=proxy,
progress_callback=self._make_progress_callback(
progress_id,
"dashboard",
0,
45,
),
)
self._set_update_stage(
progress_id,
"dashboard",
"done",
"WebUI 下载完成。",
45,
)
self._set_update_stage(
progress_id,
"core",
"running",
"正在下载 AstrBot 项目代码...",
45,
)
await self.astrbot_updator.update(
latest=latest,
version=version,
proxy=proxy,
progress_callback=self._make_progress_callback(
progress_id,
"core",
45,
45,
),
)
self._set_update_stage(
progress_id,
"core",
"done",
"项目代码下载完成。",
90,
)
# pip 更新依赖
self._set_update_stage(
progress_id,
"dependencies",
"running",
"正在更新依赖...",
92,
)
logger.info("更新依赖中...")
try:
await pip_installer.install(requirements_path="requirements.txt")
except Exception as e:
logger.error(f"更新依赖失败: {e}")
self._set_update_stage(
progress_id,
"dependencies",
"done",
"依赖更新完成。",
96,
)
if reboot:
self._set_update_stage(
progress_id,
"restart",
"running",
"更新成功,正在准备重启...",
98,
)
await self.core_lifecycle.restart()
self.update_progress[progress_id].update(
{
"status": "success",
"stage": "done",
"message": "更新成功AstrBot 将在 2 秒内全量重启以应用新的代码。",
"overall_percent": 100,
},
)
ret = (
Response()
.ok(None, "更新成功AstrBot 将在 2 秒内全量重启以应用新的代码。")
.__dict__
)
return ret, 200, CLEAR_SITE_DATA_HEADERS
self.update_progress[progress_id].update(
{
"status": "success",
"stage": "done",
"message": "更新成功AstrBot 将在下次启动时应用新的代码。",
"overall_percent": 100,
},
)
ret = (
Response()
.ok(None, "更新成功AstrBot 将在下次启动时应用新的代码。")
.__dict__
)
return ret, 200, CLEAR_SITE_DATA_HEADERS
except Exception as e:
self.update_progress[progress_id].update(
{
"status": "error",
"message": e.__str__(),
},
)
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
return await self._run_json(self.service.update_project)
async def update_dashboard(self):
try:
try:
await download_dashboard(version=f"v{VERSION}", latest=False)
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
return Response().error(f"下载管理面板文件失败: {e}").__dict__
ret = Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
return ret, 200, CLEAR_SITE_DATA_HEADERS
except Exception as e:
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
return await self._run(self.service.update_dashboard())
async def install_pip_package(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
data = await request.json
package = data.get("package", "")
mirror = data.get("mirror", None)
if not package:
return Response().error("缺少参数 package 或不合法。").__dict__
try:
await pip_installer.install(package, mirror=mirror)
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(f"/api/update_pip: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
return await self._run_json(self.service.install_pip_package)

View File

@@ -1,117 +0,0 @@
"""Dashboard 路由工具集。
这里放一些 dashboard routes 可复用的小工具函数。
目前主要用于「配置文件上传file 类型配置项)」功能:
- 清洗/规范化用户可控的文件名与相对路径
- 将配置 key 映射到配置项独立子目录
"""
import os
def get_schema_item(schema: dict | None, key_path: str) -> dict | None:
"""按 dot-path 获取 schema 的节点。
同时支持:
- 扁平 schema直接 key 命中)
- 嵌套 object schema{type: "object", items: {...}}
- template_list schema<field>.templates.<template>.items
"""
if not isinstance(schema, dict) or not key_path:
return None
if key_path in schema:
return schema.get(key_path)
parts = key_path.split(".")
current = schema
idx = 0
while idx < len(parts):
part = parts[idx]
if part not in current:
return None
meta = current.get(part)
if idx == len(parts) - 1:
return meta
if not isinstance(meta, dict) or meta.get("type") != "object":
if not isinstance(meta, dict) or meta.get("type") != "template_list":
return None
if idx + 2 >= len(parts) or parts[idx + 1] != "templates":
return None
template_meta = meta.get("templates", {}).get(parts[idx + 2])
if not isinstance(template_meta, dict):
return None
if idx + 2 == len(parts) - 1:
return template_meta
current = template_meta.get("items", {})
idx += 3
continue
current = meta.get("items", {})
idx += 1
return None
def sanitize_filename(name: str) -> str:
"""清洗上传文件名,避免路径穿越与非法名称。
- 丢弃目录部分,仅保留 basename
- 将路径分隔符替换为下划线
- 拒绝空字符串 / "." / ".."
"""
cleaned = os.path.basename(name).strip()
if not cleaned or cleaned in {".", ".."}:
return ""
for sep in (os.sep, os.altsep):
if sep:
cleaned = cleaned.replace(sep, "_")
return cleaned
def sanitize_path_segment(segment: str) -> str:
"""清洗目录片段URL/path 安全,避免穿越)。
仅保留 [A-Za-z0-9_-],其余替换为 "_"
"""
cleaned = []
for ch in segment:
if (
("a" <= ch <= "z")
or ("A" <= ch <= "Z")
or ch.isdigit()
or ch
in {
"-",
"_",
}
):
cleaned.append(ch)
else:
cleaned.append("_")
result = "".join(cleaned).strip("_")
return result or "_"
def config_key_to_folder(key_path: str) -> str:
"""将 dot-path 的配置 key 转成稳定的文件夹路径。"""
parts = [sanitize_path_segment(p) for p in key_path.split(".") if p]
return "/".join(parts) if parts else "_"
def normalize_rel_path(rel_path: str | None) -> str | None:
"""规范化用户传入的相对路径,并阻止路径穿越。"""
if not isinstance(rel_path, str):
return None
rel = rel_path.replace("\\", "/").lstrip("/")
if not rel:
return None
parts = [p for p in rel.split("/") if p]
if any(part in {".", ".."} for part in parts):
return None
if rel.startswith("../") or "/../" in rel:
return None
return "/".join(parts)

View File

@@ -1,49 +1,68 @@
import asyncio
import hashlib
import ipaddress
import logging
import os
import re
import socket
import time
from datetime import datetime
from pathlib import Path
from typing import Protocol, cast
import jwt
import psutil
from flask.json.provider import DefaultJSONProvider
from hypercorn.asyncio import serve
from hypercorn.config import Config as HyperConfig
from hypercorn.logging import AccessLogAtoms
from hypercorn.logging import Logger as HypercornLogger
from quart import Quart, g, jsonify, request
from quart.logging import default_handler
from werkzeug.exceptions import MethodNotAllowed, NotFound
from werkzeug.routing import Map, Rule
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.datetime_utils import to_utc_isoformat
from astrbot.core.utils.io import (
get_bundled_dashboard_dist_path,
get_local_ip_addresses,
should_use_bundled_dashboard_dist,
)
from astrbot.dashboard.fastapi_compat import (
CompatG,
FastAPIAppAdapter,
bind_request_context,
g,
jsonify,
request,
)
from .plugin_page_auth import PluginPageAuth
from .routes import *
from .routes.api_key import ALL_OPEN_API_SCOPES
from .routes.auth import DASHBOARD_JWT_COOKIE_NAME
from .routes.api_key import ApiKeyRoute
from .routes.auth import AuthRoute
from .routes.backup import BackupRoute
from .routes.chat import ChatRoute
from .routes.chatui_project import ChatUIProjectRoute
from .routes.command import CommandRoute
from .routes.config import ConfigRoute
from .routes.conversation import ConversationRoute
from .routes.cron import CronRoute
from .routes.file import FileRoute
from .routes.knowledge_base import KnowledgeBaseRoute
from .routes.live_chat import LiveChatRoute
from .routes.log import LogRoute
from .routes.open_api import OpenApiRoute
from .routes.persona import PersonaRoute
from .routes.platform import PlatformRoute
from .routes.plugin import PluginRoute
from .routes.route import Response, RouteContext
from .routes.session_management import SessionManagementRoute
from .routes.skills import SkillsRoute
from .routes.stat import StatRoute
from .routes.static_file import StaticFileRoute
from .routes.subagent import SubAgentRoute
from .routes.t2i import T2iRoute
from .routes.tools import ToolsRoute
from .routes.update import UpdateRoute
from .services.auth_service import DASHBOARD_JWT_COOKIE_NAME
from .services.chat_service import ChatService
from .v1.app import create_v1_asgi_app
_RATE_LIMITED_ENDPOINTS: frozenset = frozenset(
{
@@ -120,7 +139,7 @@ class _AddrWithPort(Protocol):
port: int
APP: Quart
APP: FastAPIAppAdapter | None = None
def _normalize_plugin_api_route(route: str) -> str:
@@ -139,27 +158,28 @@ def _match_registered_web_api(registered_web_apis, subpath: str, method: str):
if request_method not in allowed_methods:
continue
url_map = Map(
[
Rule(
_normalize_plugin_api_route(route),
endpoint="plugin_api",
methods=allowed_methods,
),
]
)
try:
_, path_values = url_map.bind("").match(
request_path,
method=request_method,
)
except (MethodNotAllowed, NotFound):
pattern = _plugin_api_route_pattern(route)
matched = re.fullmatch(pattern, request_path)
if not matched:
continue
return view_handler, path_values
return view_handler, matched.groupdict()
return None
def _plugin_api_route_pattern(route: str) -> str:
normalized = _normalize_plugin_api_route(route)
chunks = []
pos = 0
for match in re.finditer(r"<(?:(path):)?([A-Za-z_][A-Za-z0-9_]*)>", normalized):
chunks.append(re.escape(normalized[pos : match.start()]))
name = match.group(2)
chunks.append(f"(?P<{name}>.*)" if match.group(1) else f"(?P<{name}>[^/]+)")
pos = match.end()
chunks.append(re.escape(normalized[pos:]))
return "".join(chunks)
def _parse_env_bool(value: str | None, default: bool) -> bool:
if value is None:
return default
@@ -213,13 +233,6 @@ class _ProxyAwareHypercornLogger(HypercornLogger):
return atoms
class AstrBotJSONProvider(DefaultJSONProvider):
def default(self, obj):
if isinstance(obj, datetime):
return to_utc_isoformat(obj)
return super().default(obj)
class AstrBotDashboard:
def __init__(
self,
@@ -254,16 +267,30 @@ class AstrBotDashboard:
self.data_path = os.path.abspath(user_dist)
self._rate_limiter_registry = _RateLimiterRegistry()
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
APP = self.app # noqa
self._init_jwt_secret()
self.asgi_app = create_v1_asgi_app(
core_lifecycle=core_lifecycle,
db=db,
jwt_secret=self._jwt_secret,
)
self.app = FastAPIAppAdapter(self.asgi_app, static_folder=self.data_path)
self.asgi_app.state.dashboard_app_adapter = self.app
self.app._dashboard_server = self
global APP
APP = self.app
self.app.config["MAX_CONTENT_LENGTH"] = (
128 * 1024 * 1024
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
self.app.json = AstrBotJSONProvider(self.app)
self.app.json.sort_keys = False
self.app.before_request(self.auth_middleware)
# token 用于验证请求
logging.getLogger(self.app.name).removeHandler(default_handler)
@self.asgi_app.middleware("http")
async def dashboard_auth_middleware(request_, call_next):
request_.state.dashboard_g = CompatG()
with bind_request_context(request_, self.app, request_.state.dashboard_g):
auth_response = await self.auth_middleware()
if auth_response is not None:
return auth_response
return await call_next(request_)
self.context = RouteContext(self.config, self.app)
self.ur = UpdateRoute(
self.context,
@@ -282,13 +309,21 @@ class AstrBotDashboard:
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context, db)
self.api_key_route = ApiKeyRoute(self.context, db)
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
self.chat_service = ChatService(db, core_lifecycle)
self.chat_route = ChatRoute(
self.context,
db,
core_lifecycle,
service=self.chat_service,
)
self.open_api_route = OpenApiRoute(
self.context,
db,
core_lifecycle,
self.chat_route,
self.chat_service,
register_routes=False,
)
self.asgi_app.state.open_api_route = self.open_api_route
self.chatui_project_route = ChatUIProjectRoute(self.context, db)
self.tools_root = ToolsRoute(self.context, core_lifecycle)
self.subagent_route = SubAgentRoute(self.context, core_lifecycle)
@@ -307,6 +342,7 @@ class AstrBotDashboard:
self.platform_route = PlatformRoute(self.context, core_lifecycle)
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
self.live_chat_route = LiveChatRoute(self.context, db, core_lifecycle)
self.asgi_app.state.live_chat_route = self.live_chat_route
self.app.add_url_rule(
"/api/plug/<path:subpath>",
@@ -316,8 +352,6 @@ class AstrBotDashboard:
self.shutdown_event = shutdown_event
self._init_jwt_secret()
async def srv_plug_route(self, subpath, *args, **kwargs):
"""插件路由"""
registered_web_apis = self.core_lifecycle.star_context.registered_web_apis
@@ -335,37 +369,6 @@ class AstrBotDashboard:
if not request.path.startswith("/api"):
return None
if request.path.startswith("/api/v1"):
raw_key = self._extract_raw_api_key()
if not raw_key:
r = jsonify(Response().error("Missing API key").__dict__)
r.status_code = 401
return r
key_hash = hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
b"astrbot_api_key",
100_000,
).hex()
api_key = await self.db.get_active_api_key_by_hash(key_hash)
if not api_key:
r = jsonify(Response().error("Invalid API key").__dict__)
r.status_code = 401
return r
if isinstance(api_key.scopes, list):
scopes = api_key.scopes
else:
scopes = list(ALL_OPEN_API_SCOPES)
required_scope = self._get_required_open_api_scope(request.path)
if required_scope and "*" not in scopes and required_scope not in scopes:
r = jsonify(Response().error("Insufficient API key scope").__dict__)
r.status_code = 403
return r
g.api_key_id = api_key.key_id
g.api_key_scopes = scopes
g.username = f"api_key:{api_key.key_id}"
await self.db.touch_api_key(api_key.key_id)
return None
if (
@@ -403,6 +406,7 @@ class AstrBotDashboard:
}
allowed_endpoint_prefixes = [
"/api/file",
"/api/v1/files/tokens",
"/api/platform/webhook",
"/api/stat/start-time",
"/api/backup/download", # 备份下载使用 URL 参数传递 token
@@ -486,34 +490,6 @@ class AstrBotDashboard:
return cookie_token
return None
@staticmethod
def _extract_raw_api_key() -> str | None:
if key := request.args.get("api_key"):
return key.strip()
if key := request.args.get("key"):
return key.strip()
if key := request.headers.get("X-API-Key"):
return key.strip()
auth_header = request.headers.get("Authorization", "").strip()
if auth_header.startswith("Bearer "):
return auth_header.removeprefix("Bearer ").strip()
if auth_header.startswith("ApiKey "):
return auth_header.removeprefix("ApiKey ").strip()
return None
@staticmethod
def _get_required_open_api_scope(path: str) -> str | None:
scope_map = {
"/api/v1/chat": "chat",
"/api/v1/chat/ws": "chat",
"/api/v1/chat/sessions": "chat",
"/api/v1/configs": "config",
"/api/v1/file": "file",
"/api/v1/im/message": "im",
"/api/v1/im/bots": "im",
}
return scope_map.get(path)
def check_port_in_use(self, port: int) -> bool:
"""跨平台检测端口是否被占用"""
try:
@@ -725,7 +701,7 @@ class AstrBotDashboard:
config.accesslog = "-"
config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s"
return serve(self.app, config, shutdown_trigger=self.shutdown_trigger)
return serve(self.asgi_app, config, shutdown_trigger=self.shutdown_trigger)
async def shutdown_trigger(self) -> None:
await self.shutdown_event.wait()

View File

@@ -0,0 +1 @@
"""Application services for dashboard HTTP APIs."""

View File

@@ -0,0 +1,139 @@
from __future__ import annotations
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.datetime_utils import normalize_datetime_utc
from .auth_service import ALL_OPEN_API_SCOPES
class ApiKeyServiceError(Exception):
pass
class ApiKeyService:
def __init__(self, db: BaseDatabase) -> None:
self.db = db
@staticmethod
def hash_key(raw_key: str) -> str:
return hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
b"astrbot_api_key",
100_000,
).hex()
@staticmethod
def _normalize_utc(dt: datetime | None) -> datetime | None:
return normalize_datetime_utc(dt)
@classmethod
def _serialize_datetime(cls, dt: datetime | None) -> str | None:
normalized = cls._normalize_utc(dt)
if normalized is None:
return None
return normalized.astimezone().isoformat()
@classmethod
def serialize_api_key(cls, key) -> dict:
expires_at = cls._normalize_utc(key.expires_at)
return {
"key_id": key.key_id,
"name": key.name,
"key_prefix": key.key_prefix,
"scopes": key.scopes or [],
"created_by": key.created_by,
"created_at": cls._serialize_datetime(key.created_at),
"updated_at": cls._serialize_datetime(key.updated_at),
"last_used_at": cls._serialize_datetime(key.last_used_at),
"expires_at": cls._serialize_datetime(key.expires_at),
"revoked_at": cls._serialize_datetime(key.revoked_at),
"is_revoked": key.revoked_at is not None,
"is_expired": bool(expires_at and expires_at < datetime.now(timezone.utc)),
}
@staticmethod
def _normalize_scopes(raw_scopes: Any) -> list[str]:
if raw_scopes is None:
return list(ALL_OPEN_API_SCOPES)
if not isinstance(raw_scopes, list):
raise ApiKeyServiceError("Invalid scopes")
scopes = [
scope
for scope in raw_scopes
if isinstance(scope, str) and scope in ALL_OPEN_API_SCOPES
]
normalized_scopes = list(dict.fromkeys(scopes))
if not normalized_scopes:
raise ApiKeyServiceError("At least one valid scope is required")
return normalized_scopes
@staticmethod
def _resolve_expires_at(expires_in_days: Any) -> datetime | None:
if expires_in_days is None:
return None
try:
expires_in_days_int = int(expires_in_days)
except (TypeError, ValueError) as exc:
raise ApiKeyServiceError("expires_in_days must be an integer") from exc
if expires_in_days_int <= 0:
raise ApiKeyServiceError("expires_in_days must be greater than 0")
return datetime.now(timezone.utc) + timedelta(days=expires_in_days_int)
async def list_api_keys(self) -> list[dict]:
keys = await self.db.list_api_keys()
return [self.serialize_api_key(key) for key in keys]
async def create_api_key(self, payload: dict, *, created_by: str) -> dict:
name = str(payload.get("name", "")).strip() or "Untitled API Key"
scopes = self._normalize_scopes(payload.get("scopes"))
expires_at = self._resolve_expires_at(payload.get("expires_in_days"))
raw_key = f"abk_{secrets.token_urlsafe(32)}"
api_key = await self.db.create_api_key(
name=name,
key_hash=self.hash_key(raw_key),
key_prefix=raw_key[:12],
scopes=scopes, # type: ignore
created_by=created_by,
expires_at=expires_at,
)
result = self.serialize_api_key(api_key)
result["api_key"] = raw_key
return result
async def create_api_key_from_legacy_payload(
self,
payload: object,
*,
created_by: str,
) -> dict:
data = payload if isinstance(payload, dict) else {}
return await self.create_api_key(data, created_by=created_by)
async def revoke_api_key(self, key_id: str | None) -> bool:
if not key_id:
raise ApiKeyServiceError("Missing key: key_id")
return await self.db.revoke_api_key(key_id)
async def revoke_api_key_from_legacy_payload(self, payload: object) -> None:
data = payload if isinstance(payload, dict) else {}
if not await self.revoke_api_key(data.get("key_id")):
raise ApiKeyServiceError("API key not found")
async def delete_api_key(self, key_id: str | None) -> bool:
if not key_id:
raise ApiKeyServiceError("Missing key: key_id")
return await self.db.delete_api_key(key_id)
async def delete_api_key_from_legacy_payload(self, payload: object) -> None:
data = payload if isinstance(payload, dict) else {}
if not await self.delete_api_key(data.get("key_id")):
raise ApiKeyServiceError("API key not found")

View File

@@ -0,0 +1,452 @@
from __future__ import annotations
import asyncio
import datetime
import os
from dataclasses import dataclass
import jwt
import pyotp
from astrbot import logger
from astrbot.core import DEMO_MODE
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.auth_password import (
is_default_dashboard_password,
is_legacy_dashboard_password,
validate_dashboard_password,
verify_dashboard_password,
)
from astrbot.core.utils.totp import (
TOTP_TRUSTED_DEVICE_COOKIE_NAME as _TOTP_TRUSTED_DEVICE_COOKIE_NAME,
)
from astrbot.core.utils.totp import (
TOTP_TRUSTED_DEVICE_MAX_AGE as _TOTP_TRUSTED_DEVICE_MAX_AGE,
)
from astrbot.core.utils.totp import (
TwoFactorCodeType,
consume_configured_totp_code,
consume_rotation_verified,
consume_totp_code,
generate_recovery_code,
is_totp_enabled,
is_totp_trusted_device_valid,
issue_totp_trusted_device,
revoke_user_trusted_devices,
set_pending_totp_secret,
set_rotation_verified,
verify_configured_2fa_code,
)
from astrbot.dashboard.password_state import (
get_dashboard_password_hash,
is_password_change_required,
is_password_storage_upgraded,
set_dashboard_password_hashes,
set_password_change_required,
set_password_storage_upgraded,
)
ALL_OPEN_API_SCOPES = (
"chat",
"config",
"file",
"im",
"plugin",
"tool",
"skill",
"kb",
"persona",
"data",
"system",
)
DASHBOARD_JWT_COOKIE_NAME = "astrbot_dashboard_jwt"
DASHBOARD_JWT_COOKIE_MAX_AGE = 7 * 24 * 60 * 60
SKIP_DEFAULT_PASSWORD_AUTH_ENV = "ASTRBOT_DASHBOARD_SKIP_DEFAULT_PASSWORD_AUTH"
SKIP_DEFAULT_PASSWORD_AUTH_ENV_LEGACY = "DASHBOARD_SKIP_DEFAULT_PASSWORD_AUTH"
LOCAL_DASHBOARD_HOSTS = {"127.0.0.1", "localhost", "::1"}
DEFAULT_PASSWORD_LOGIN_FAILURE_MESSAGE = (
"Login failed. If this is your first time using AstrBot, the old default "
"astrbot password has been replaced by a random strong password printed in "
"the startup logs. Check the initial password in the logs and try again. "
"Learn more: https://docs.astrbot.app/en/faq.html\n\n"
"登录失败。如果您是初次使用,旧版默认 astrbot 密码已改为启动日志中输出的"
"随机强密码。请使用日志中提供的的初始密码来登录。了解更多:"
"https://docs.astrbot.app/faq.html"
)
LEGACY_PASSWORD_LOGIN_FAILURE_MESSAGE = (
"Incorrect username or password. If you cannot log in after upgrading "
"AstrBot even though the password is correct, see "
"https://docs.astrbot.app/en/faq.html\n\n"
"用户名或密码错误。如果你在升级 AstrBot 后遇到了密码正确但无法登录的情况,"
"请参考 https://docs.astrbot.app/faq.html"
)
TOTP_TRUSTED_DEVICE_COOKIE_NAME = _TOTP_TRUSTED_DEVICE_COOKIE_NAME
TOTP_TRUSTED_DEVICE_MAX_AGE = _TOTP_TRUSTED_DEVICE_MAX_AGE
@dataclass
class AuthServiceResult:
status: str = "ok"
data: dict | None = None
message: str | None = None
status_code: int = 200
jwt_token: str | None = None
trusted_device_token: str | None = None
class AuthService:
def __init__(
self,
db: BaseDatabase,
config: AstrBotConfig,
*,
demo_mode: bool = DEMO_MODE,
) -> None:
self.db = db
self.config = config
self.demo_mode = demo_mode
async def setup_status(self) -> AuthServiceResult:
return AuthServiceResult(
data={
"setup_required": await self.is_setup_required(),
"skip_default_password_auth": self.can_skip_default_password_auth(),
"password_upgrade_required": not await is_password_storage_upgraded(
self.db,
self.config,
),
}
)
async def totp_setup(self, post_data: object) -> AuthServiceResult:
if isinstance(post_data, dict) and post_data.get("secret"):
secret = post_data["secret"]
code = post_data.get("code")
if not isinstance(secret, str) or not secret.strip():
return self.error("Invalid request payload")
if not isinstance(code, str) or not code.strip():
return self.error("TOTP 验证码是必需的")
if not await consume_totp_code(secret, code):
return self.error("TOTP 验证码无效")
if is_totp_enabled(self.config) and not consume_rotation_verified():
return self.error("需要先验证当前 TOTP")
set_pending_totp_secret(secret)
recovery_code, recovery_code_hash = generate_recovery_code()
return AuthServiceResult(
data={
"recovery_code": recovery_code,
"recovery_code_hash": recovery_code_hash,
},
message="TOTP verified",
)
if is_totp_enabled(self.config):
if not isinstance(post_data, dict):
return self.error("Invalid request payload")
set_rotation_verified(False)
code = post_data.get("code")
if isinstance(code, str) and code.strip():
if await consume_configured_totp_code(self.config, code):
set_rotation_verified(True)
return AuthServiceResult(data={"secret": pyotp.random_base32()})
return self.error("当前 TOTP 验证码无效")
return self.error("需要提供 TOTP 验证码或新密钥")
return AuthServiceResult(data={"secret": pyotp.random_base32()})
async def totp_recovery(self) -> AuthServiceResult:
recovery_code, recovery_code_hash = generate_recovery_code()
return AuthServiceResult(
data={
"recovery_code": recovery_code,
"recovery_code_hash": recovery_code_hash,
}
)
async def setup(self, post_data: object) -> AuthServiceResult:
if not self.can_skip_default_password_auth():
return self.error("Setup without password is not enabled")
if not await self.is_setup_required():
return self.error("Setup is not required")
return await self.complete_setup(post_data)
async def setup_authenticated(
self,
post_data: object,
authenticated_username,
) -> AuthServiceResult:
if not await self.is_setup_required():
return self.error("Setup is not required")
if not isinstance(authenticated_username, str):
return self.error("未授权")
return await self.complete_setup(post_data)
async def complete_setup(self, post_data: object) -> AuthServiceResult:
if not isinstance(post_data, dict):
return self.error("Invalid request payload")
new_username = post_data.get("username")
new_password = post_data.get("password")
confirm_password = post_data.get("confirm_password")
if not isinstance(new_username, str) or len(new_username.strip()) < 3:
return self.error("用户名长度至少3位")
if not isinstance(new_password, str):
return self.error("新密码无效")
if not isinstance(confirm_password, str) or confirm_password != new_password:
return self.error("两次输入的新密码不一致")
try:
validate_dashboard_password(new_password)
except ValueError as exc:
return self.error(str(exc))
username = new_username.strip()
self.config["dashboard"]["username"] = username
set_dashboard_password_hashes(self.config, new_password)
await set_password_storage_upgraded(self.db, self.config, True)
await set_password_change_required(self.db, self.config, False)
self.config.save_config()
token = self.generate_jwt(username)
return AuthServiceResult(
data={
"token": token,
"username": username,
"change_pwd_hint": False,
"legacy_pwd_hint": False,
"password_upgrade_required": False,
},
message="Setup completed successfully",
jwt_token=token,
)
async def login(
self,
post_data: object,
*,
trusted_device_cookie_token: str,
) -> AuthServiceResult:
username = self.config["dashboard"]["username"]
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
password = get_dashboard_password_hash(self.config, upgraded=storage_upgraded)
req_username = (
post_data.get("username") if isinstance(post_data, dict) else None
)
req_password = (
post_data.get("password") if isinstance(post_data, dict) else None
)
totp_code = post_data.get("code") if isinstance(post_data, dict) else None
trust_device_flag = (
post_data.get("trust_device_flag") is True
if isinstance(post_data, dict)
else False
)
if not isinstance(req_username, str) or not isinstance(req_password, str):
return self.error("Invalid request payload")
login_verified = req_username == username and verify_dashboard_password(
password,
req_password,
)
if not login_verified:
await asyncio.sleep(3)
if req_password == "astrbot":
return self.error(DEFAULT_PASSWORD_LOGIN_FAILURE_MESSAGE)
if is_legacy_dashboard_password(password):
return self.error(LEGACY_PASSWORD_LOGIN_FAILURE_MESSAGE)
return self.error("用户名或密码错误", status_code=401)
totp_verified = False
if is_totp_enabled(self.config):
if not await is_totp_trusted_device_valid(
self.config,
self.db,
trusted_device_cookie_token,
):
if not isinstance(totp_code, str) or not totp_code.strip():
return self.error(
"需要 TOTP 验证",
data={"totp_required": True},
status_code=401,
)
verified_type = await verify_configured_2fa_code(
self.config,
totp_code,
allow_recovery=True,
)
if verified_type is TwoFactorCodeType.TOTP:
totp_verified = True
elif verified_type is TwoFactorCodeType.RECOVERY:
self.config["dashboard"]["totp"] = {
"enable": False,
"secret": "",
"recovery_code_hash": "",
}
await revoke_user_trusted_devices(self.db)
self.config.save_config()
elif len(totp_code) == 6 and totp_code.isdigit():
return self.error("TOTP 验证码无效", status_code=401)
else:
return self.error("恢复码无效", status_code=401)
change_pwd_hint = False
legacy_pwd_hint = is_legacy_dashboard_password(password)
password_change_required = await is_password_change_required(
self.db,
self.config,
)
if (
storage_upgraded
and username == "astrbot"
and is_default_dashboard_password(password)
and not self.demo_mode
):
change_pwd_hint = True
legacy_pwd_hint = True
logger.warning("为了保证安全,请尽快修改默认密码。")
if password_change_required and not self.demo_mode:
change_pwd_hint = True
token = self.generate_jwt(username)
result = AuthServiceResult(
data={
"token": token,
"username": username,
"change_pwd_hint": change_pwd_hint,
"legacy_pwd_hint": legacy_pwd_hint,
"password_upgrade_required": not storage_upgraded,
},
jwt_token=token,
)
if totp_verified and trust_device_flag:
result.trusted_device_token = await issue_totp_trusted_device(
self.config,
self.db,
)
return result
async def edit_account(self, post_data: object) -> AuthServiceResult:
if self.demo_mode:
return self.error("You are not permitted to do this operation in demo mode")
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
password = get_dashboard_password_hash(self.config, upgraded=storage_upgraded)
if not isinstance(post_data, dict):
return self.error("Invalid request payload")
req_password = post_data.get("password")
if not isinstance(req_password, str):
return self.error("Invalid request payload")
if not verify_dashboard_password(password, req_password):
return self.error("原密码错误")
new_pwd = post_data.get("new_password", None)
new_username = post_data.get("new_username", None)
password_change_required = await is_password_change_required(
self.db,
self.config,
)
if (not storage_upgraded or password_change_required) and not new_pwd:
return self.error("请设置新密码以完成安全升级")
if not new_pwd and not new_username:
return self.error("新用户名和新密码不能同时为空")
if new_pwd:
if not isinstance(new_pwd, str):
return self.error("新密码无效")
confirm_pwd = post_data.get("confirm_password", None)
if not isinstance(confirm_pwd, str) or confirm_pwd != new_pwd:
return self.error("两次输入的新密码不一致")
try:
validate_dashboard_password(new_pwd)
except ValueError as exc:
return self.error(str(exc))
set_dashboard_password_hashes(self.config, new_pwd)
await set_password_storage_upgraded(self.db, self.config, True)
await set_password_change_required(self.db, self.config, False)
if is_totp_enabled(self.config):
await revoke_user_trusted_devices(self.db)
if new_username:
self.config["dashboard"]["username"] = new_username
self.config.save_config()
return AuthServiceResult(message="Updated account successfully")
def generate_jwt(self, username: str):
payload = {
"username": username,
"exp": datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(days=7),
}
jwt_token = self.config["dashboard"].get("jwt_secret", None)
if not jwt_token:
raise ValueError("JWT secret is not set in the cmd_config.")
return jwt.encode(payload, jwt_token, algorithm="HS256")
async def is_setup_required(self) -> bool:
if self.demo_mode:
return False
dashboard_config = self.config["dashboard"]
password_change_required = await is_password_change_required(
self.db,
self.config,
)
if password_change_required:
return True
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
if not storage_upgraded:
return False
return dashboard_config.get(
"username"
) == "astrbot" and is_default_dashboard_password(
dashboard_config.get("pbkdf2_password", "")
)
def can_skip_default_password_auth(self) -> bool:
if not self.env_flag_enabled(SKIP_DEFAULT_PASSWORD_AUTH_ENV):
return False
host = (
os.environ.get("DASHBOARD_HOST")
or os.environ.get("ASTRBOT_DASHBOARD_HOST")
or self.config["dashboard"].get("host", "")
)
return str(host).strip().lower() in LOCAL_DASHBOARD_HOSTS
@staticmethod
def env_flag_enabled(name: str) -> bool:
value = os.environ.get(name)
if value is None and name == SKIP_DEFAULT_PASSWORD_AUTH_ENV:
value = os.environ.get(SKIP_DEFAULT_PASSWORD_AUTH_ENV_LEGACY)
return str(value or "").strip().lower() in {"1", "true", "yes", "on"}
@staticmethod
def error(
message: str,
*,
data: dict | None = None,
status_code: int = 200,
) -> AuthServiceResult:
return AuthServiceResult(
status="error",
data=data,
message=message,
status_code=status_code,
)

View File

@@ -0,0 +1,694 @@
from __future__ import annotations
import asyncio
import json
import math
import os
import re
import shutil
import time
import traceback
import uuid
import zipfile
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
import jwt
from astrbot.core import logger
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.backup.importer import AstrBotImporter
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_backups_path,
get_astrbot_data_path,
)
CHUNK_SIZE = 1024 * 1024
UPLOAD_EXPIRE_SECONDS = 3600
class BackupServiceError(Exception):
pass
@dataclass
class BackupDownload:
path: str
filename: str
def secure_filename(filename: str) -> str:
filename = filename.replace("\\", "/")
filename = os.path.basename(filename)
filename = filename.replace("..", "_")
filename = re.sub(r"[^\w\-.]", "_", filename)
filename = filename.strip(".")
if not filename or filename.replace("_", "") == "":
filename = "backup"
return filename
def generate_unique_filename(original_filename: str) -> str:
name, ext = os.path.splitext(original_filename)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return f"{name}_{timestamp}{ext}"
class BackupService:
def __init__(
self,
db: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
self.db = db
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.backup_dir = get_astrbot_backups_path()
self.data_dir = get_astrbot_data_path()
self.chunks_dir = os.path.join(self.backup_dir, ".chunks")
self.backup_tasks: dict[str, dict] = {}
self.backup_progress: dict[str, dict] = {}
self.upload_sessions: dict[str, dict] = {}
self._cleanup_task: asyncio.Task | None = None
@staticmethod
def _payload(data: object) -> dict[str, Any]:
return data if isinstance(data, dict) else {}
@staticmethod
async def _save_upload(file: Any, target_path: str) -> None:
if hasattr(file, "save"):
result = file.save(target_path)
if hasattr(result, "__await__"):
await result
return
if hasattr(file, "read"):
data = file.read()
if hasattr(data, "__await__"):
data = await data
Path(target_path).write_bytes(data)
return
raise BackupServiceError("无效的上传文件")
@staticmethod
def _validate_backup_filename(filename: str | None, *, missing: str) -> str:
if not filename:
raise BackupServiceError(missing)
if ".." in filename or "/" in filename or "\\" in filename:
raise BackupServiceError("无效的文件名")
return filename
def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None:
self.backup_tasks[task_id] = {
"type": task_type,
"status": status,
"result": None,
"error": None,
}
self.backup_progress[task_id] = {
"status": status,
"stage": "waiting",
"current": 0,
"total": 100,
"message": "",
}
def _set_task_result(
self,
task_id: str,
status: str,
result: dict | None = None,
error: str | None = None,
) -> None:
if task_id in self.backup_tasks:
self.backup_tasks[task_id]["status"] = status
self.backup_tasks[task_id]["result"] = result
self.backup_tasks[task_id]["error"] = error
if task_id in self.backup_progress:
self.backup_progress[task_id]["status"] = status
def _update_progress(
self,
task_id: str,
*,
status: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
message: str | None = None,
) -> None:
if task_id not in self.backup_progress:
return
progress = self.backup_progress[task_id]
if status is not None:
progress["status"] = status
if stage is not None:
progress["stage"] = stage
if current is not None:
progress["current"] = current
if total is not None:
progress["total"] = total
if message is not None:
progress["message"] = message
def _make_progress_callback(self, task_id: str):
async def _callback(
stage: str,
current: int,
total: int,
message: str = "",
) -> None:
self._update_progress(
task_id,
status="processing",
stage=stage,
current=current,
total=total,
message=message,
)
return _callback
def ensure_cleanup_task_started(self) -> None:
if self._cleanup_task is None or self._cleanup_task.done():
try:
self._cleanup_task = asyncio.create_task(
self._cleanup_expired_uploads()
)
except RuntimeError:
pass
async def _cleanup_expired_uploads(self) -> None:
while True:
try:
await asyncio.sleep(300)
current_time = time.time()
expired_sessions = []
for upload_id, session in self.upload_sessions.items():
last_activity = session.get("last_activity", session["created_at"])
if current_time - last_activity > UPLOAD_EXPIRE_SECONDS:
expired_sessions.append(upload_id)
for upload_id in expired_sessions:
await self.cleanup_upload_session(upload_id)
logger.info(f"清理过期的上传会话: {upload_id}")
except asyncio.CancelledError:
break
except Exception as exc:
logger.error(f"清理过期上传会话失败: {exc}")
async def cleanup_upload_session(self, upload_id: str) -> None:
if upload_id in self.upload_sessions:
session = self.upload_sessions[upload_id]
chunk_dir = session.get("chunk_dir")
if chunk_dir and os.path.exists(chunk_dir):
try:
shutil.rmtree(chunk_dir)
except Exception as exc:
logger.warning(f"清理分片目录失败: {exc}")
del self.upload_sessions[upload_id]
def get_backup_manifest(self, zip_path: str) -> dict | None:
try:
with zipfile.ZipFile(zip_path, "r") as zf:
if "manifest.json" in zf.namelist():
manifest_data = zf.read("manifest.json")
return json.loads(manifest_data.decode("utf-8"))
return None
except Exception as exc:
logger.debug(f"读取备份 manifest 失败: {exc}")
return None
def list_backups(self, *, page: int, page_size: int) -> dict:
self.ensure_cleanup_task_started()
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
backup_files = []
for filename in os.listdir(self.backup_dir):
if not filename.endswith(".zip") or filename.startswith("."):
continue
file_path = os.path.join(self.backup_dir, filename)
if not os.path.isfile(file_path):
continue
manifest = self.get_backup_manifest(file_path)
if manifest is None:
logger.debug(f"跳过无效备份文件: {filename}")
continue
stat = os.stat(file_path)
backup_files.append(
{
"filename": filename,
"size": stat.st_size,
"created_at": stat.st_mtime,
"type": manifest.get("origin", "exported"),
"astrbot_version": manifest.get("astrbot_version", "未知"),
"exported_at": manifest.get("exported_at"),
}
)
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
start = (page - 1) * page_size
end = start + page_size
return {
"items": backup_files[start:end],
"total": len(backup_files),
"page": page,
"page_size": page_size,
}
def list_backups_from_legacy_query(self, *, page, page_size) -> dict:
return self.list_backups(
page=self._to_int(page, 1),
page_size=self._to_int(page_size, 20),
)
def export_backup(self) -> dict:
task_id = str(uuid.uuid4())
self._init_task(task_id, "export", "pending")
asyncio.create_task(self.background_export_task(task_id))
return {
"task_id": task_id,
"message": "export task created, processing in background",
}
async def background_export_task(self, task_id: str) -> None:
try:
self._update_progress(task_id, status="processing", message="正在初始化...")
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
exporter = AstrBotExporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
zip_path = await exporter.export_all(
output_dir=self.backup_dir,
progress_callback=self._make_progress_callback(task_id),
)
self._set_task_result(
task_id,
"completed",
result={
"filename": os.path.basename(zip_path),
"path": zip_path,
"size": os.path.getsize(zip_path),
},
)
except Exception as exc:
logger.error(f"后台导出任务 {task_id} 失败: {exc}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(exc))
async def upload_backup(self, file: Any | None) -> dict:
if not file:
raise BackupServiceError("缺少备份文件")
if not file.filename or not file.filename.endswith(".zip"):
raise BackupServiceError("请上传 ZIP 格式的备份文件")
safe_filename = secure_filename(file.filename)
unique_filename = generate_unique_filename(safe_filename)
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
zip_path = os.path.join(self.backup_dir, unique_filename)
await self._save_upload(file, zip_path)
logger.info(
f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})"
)
return {
"filename": unique_filename,
"original_filename": file.filename,
"size": os.path.getsize(zip_path),
}
def upload_init(self, data: object) -> dict:
payload = self._payload(data)
filename = payload.get("filename")
total_size = payload.get("total_size", 0)
if not filename:
raise BackupServiceError("缺少 filename 参数")
if not filename.endswith(".zip"):
raise BackupServiceError("请上传 ZIP 格式的备份文件")
if total_size <= 0:
raise BackupServiceError("无效的文件大小")
total_chunks = math.ceil(total_size / CHUNK_SIZE)
upload_id = str(uuid.uuid4())
chunk_dir = os.path.join(self.chunks_dir, upload_id)
Path(chunk_dir).mkdir(parents=True, exist_ok=True)
safe_filename = secure_filename(filename)
unique_filename = generate_unique_filename(safe_filename)
current_time = time.time()
self.upload_sessions[upload_id] = {
"filename": unique_filename,
"original_filename": filename,
"total_size": total_size,
"total_chunks": total_chunks,
"received_chunks": set(),
"created_at": current_time,
"last_activity": current_time,
"chunk_dir": chunk_dir,
}
logger.info(
f"初始化分片上传: upload_id={upload_id}, "
f"filename={unique_filename}, total_chunks={total_chunks}"
)
return {
"upload_id": upload_id,
"chunk_size": CHUNK_SIZE,
"total_chunks": total_chunks,
"filename": unique_filename,
}
async def upload_chunk(
self,
*,
upload_id: str | None,
chunk_index_str: str | None,
chunk_file: Any | None,
) -> dict:
if not upload_id or chunk_index_str is None:
raise BackupServiceError("缺少必要参数")
try:
chunk_index = int(chunk_index_str)
except ValueError as exc:
raise BackupServiceError("无效的分片索引") from exc
if not chunk_file:
raise BackupServiceError("缺少分片数据")
if upload_id not in self.upload_sessions:
raise BackupServiceError("上传会话不存在或已过期")
session = self.upload_sessions[upload_id]
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
raise BackupServiceError("分片索引超出范围")
chunk_path = os.path.join(session["chunk_dir"], f"{chunk_index}.part")
await self._save_upload(chunk_file, chunk_path)
session["received_chunks"].add(chunk_index)
session["last_activity"] = time.time()
received_count = len(session["received_chunks"])
total_chunks = session["total_chunks"]
logger.debug(
f"接收分片: upload_id={upload_id}, chunk={chunk_index + 1}/{total_chunks}"
)
return {
"received": received_count,
"total": total_chunks,
"chunk_index": chunk_index,
}
def mark_backup_as_uploaded(self, zip_path: str) -> None:
try:
manifest = {"origin": "uploaded", "uploaded_at": datetime.now().isoformat()}
with zipfile.ZipFile(zip_path, "r") as zf:
if "manifest.json" in zf.namelist():
manifest_data = zf.read("manifest.json")
manifest = json.loads(manifest_data.decode("utf-8"))
manifest["origin"] = "uploaded"
manifest["uploaded_at"] = datetime.now().isoformat()
with zipfile.ZipFile(zip_path, "a") as zf:
new_manifest = json.dumps(manifest, ensure_ascii=False, indent=2)
zf.writestr("manifest.json", new_manifest)
logger.debug(f"已标记备份为上传来源: {zip_path}")
except Exception as exc:
logger.warning(f"标记备份来源失败: {exc}")
async def upload_complete(self, data: object) -> dict:
payload = self._payload(data)
upload_id = payload.get("upload_id")
if not upload_id:
raise BackupServiceError("缺少 upload_id 参数")
if upload_id not in self.upload_sessions:
raise BackupServiceError("上传会话不存在或已过期")
session = self.upload_sessions[upload_id]
received = session["received_chunks"]
total = session["total_chunks"]
if len(received) != total:
missing = set(range(total)) - received
raise BackupServiceError(f"分片不完整,缺少: {sorted(missing)[:10]}...")
chunk_dir = session["chunk_dir"]
filename = session["filename"]
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
output_path = os.path.join(self.backup_dir, filename)
try:
with open(output_path, "wb") as outfile:
for i in range(total):
chunk_path = os.path.join(chunk_dir, f"{i}.part")
with open(chunk_path, "rb") as chunk_file:
while True:
data_block = chunk_file.read(8192)
if not data_block:
break
outfile.write(data_block)
file_size = os.path.getsize(output_path)
self.mark_backup_as_uploaded(output_path)
logger.info(f"分片上传完成: {filename}, size={file_size}, chunks={total}")
await self.cleanup_upload_session(upload_id)
return {
"filename": filename,
"original_filename": session["original_filename"],
"size": file_size,
}
except Exception:
if os.path.exists(output_path):
os.remove(output_path)
raise
async def upload_abort(self, data: object) -> tuple[dict | None, str | None]:
payload = self._payload(data)
upload_id = payload.get("upload_id")
if not upload_id:
raise BackupServiceError("缺少 upload_id 参数")
if upload_id in self.upload_sessions:
await self.cleanup_upload_session(upload_id)
logger.info(f"取消分片上传: {upload_id}")
return None, "上传已取消"
def check_backup(self, data: object) -> dict:
payload = self._payload(data)
filename = self._validate_backup_filename(
payload.get("filename"),
missing="缺少 filename 参数",
)
zip_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(zip_path):
raise BackupServiceError(f"备份文件不存在: {filename}")
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
importer = AstrBotImporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
return importer.pre_check(zip_path).to_dict()
def import_backup(self, data: object) -> dict:
payload = self._payload(data)
filename = self._validate_backup_filename(
payload.get("filename"),
missing="缺少 filename 参数",
)
confirmed = payload.get("confirmed", False)
if not confirmed:
raise BackupServiceError(
"请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。"
)
zip_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(zip_path):
raise BackupServiceError(f"备份文件不存在: {filename}")
task_id = str(uuid.uuid4())
self._init_task(task_id, "import", "pending")
asyncio.create_task(self.background_import_task(task_id, zip_path))
return {
"task_id": task_id,
"message": "import task created, processing in background",
}
async def background_import_task(self, task_id: str, zip_path: str) -> None:
try:
self._update_progress(task_id, status="processing", message="正在初始化...")
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
importer = AstrBotImporter(
main_db=self.db,
kb_manager=kb_manager,
config_path=os.path.join(self.data_dir, "cmd_config.json"),
)
result = await importer.import_all(
zip_path=zip_path,
mode="replace",
progress_callback=self._make_progress_callback(task_id),
)
if result.success:
self._set_task_result(task_id, "completed", result=result.to_dict())
else:
self._set_task_result(
task_id,
"failed",
error="; ".join(result.errors),
)
except Exception as exc:
logger.error(f"后台导入任务 {task_id} 失败: {exc}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(exc))
def get_progress(self, task_id: str | None) -> dict:
if not task_id:
raise BackupServiceError("缺少参数 task_id")
if task_id not in self.backup_tasks:
raise BackupServiceError("找不到该任务")
task_info = self.backup_tasks[task_id]
status = task_info["status"]
response_data = {
"task_id": task_id,
"type": task_info["type"],
"status": status,
}
if status == "processing" and task_id in self.backup_progress:
response_data["progress"] = self.backup_progress[task_id]
if status == "completed":
response_data["result"] = task_info["result"]
if status == "failed":
response_data["error"] = task_info["error"]
return response_data
def get_progress_from_legacy_query(self, task_id: str | None) -> dict:
return self.get_progress(task_id)
def prepare_download(
self,
*,
filename: str | None,
token: str | None,
jwt_secret: str | None,
) -> BackupDownload:
if not filename:
raise BackupServiceError("缺少参数 filename")
if not token:
raise BackupServiceError("缺少参数 token")
if not jwt_secret:
raise BackupServiceError("服务器配置错误")
try:
jwt.decode(
token,
jwt_secret,
algorithms=["HS256"],
options={
"require": ["exp"],
"verify_signature": True,
"verify_exp": True,
},
)
except jwt.ExpiredSignatureError as exc:
raise BackupServiceError("Token 已过期,请刷新页面后重试") from exc
except jwt.InvalidTokenError as exc:
raise BackupServiceError("Token 无效") from exc
filename = self._validate_backup_filename(filename, missing="缺少参数 filename")
file_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(file_path):
raise BackupServiceError("备份文件不存在")
return BackupDownload(path=file_path, filename=filename)
def prepare_download_from_legacy_query(
self,
*,
filename: str | None,
token: str | None,
jwt_secret: str | None = None,
) -> BackupDownload:
return self.prepare_download(
filename=filename,
token=token,
jwt_secret=jwt_secret or self.config.get("dashboard", {}).get("jwt_secret"),
)
def delete_backup(self, data: object) -> tuple[dict | None, str | None]:
payload = self._payload(data)
filename = self._validate_backup_filename(
payload.get("filename"),
missing="缺少参数 filename",
)
file_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(file_path):
raise BackupServiceError("备份文件不存在")
os.remove(file_path)
return None, "删除备份成功"
def rename_backup(self, data: object) -> dict:
payload = self._payload(data)
filename = self._validate_backup_filename(
payload.get("filename"),
missing="缺少参数 filename",
)
new_name = payload.get("new_name")
if not new_name:
raise BackupServiceError("缺少参数 new_name")
new_name = secure_filename(new_name)
if new_name.endswith(".zip"):
new_name = new_name[:-4]
if not new_name or new_name.replace("_", "") == "":
raise BackupServiceError("新文件名无效")
new_filename = f"{new_name}.zip"
old_path = os.path.join(self.backup_dir, filename)
if not os.path.exists(old_path):
raise BackupServiceError("备份文件不存在")
new_path = os.path.join(self.backup_dir, new_filename)
if os.path.exists(new_path):
raise BackupServiceError(f"文件名 '{new_filename}' 已存在")
os.rename(old_path, new_path)
logger.info(f"备份文件重命名: {filename} -> {new_filename}")
return {
"old_filename": filename,
"new_filename": new_filename,
}
@staticmethod
def _to_int(value, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,162 @@
from __future__ import annotations
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.datetime_utils import to_utc_isoformat
class ChatUIProjectServiceError(Exception):
pass
class ChatUIProjectService:
def __init__(self, db: BaseDatabase) -> None:
self.db = db
async def create_project(self, username: str, data: object) -> dict:
payload = self._as_payload(data)
title = payload.get("title")
emoji = payload.get("emoji", "📁")
description = payload.get("description")
if not title:
raise ChatUIProjectServiceError("Missing key: title")
project = await self.db.create_chatui_project(
creator=username,
title=title,
emoji=emoji,
description=description,
)
return self._serialize_project(project)
async def list_projects(self, username: str) -> list[dict]:
projects = await self.db.get_chatui_projects_by_creator(creator=username)
return [self._serialize_project(project) for project in projects]
async def get_project(self, username: str, project_id: str | None) -> dict:
if not project_id:
raise ChatUIProjectServiceError("Missing key: project_id")
project = await self._get_owned_project(username, project_id)
return self._serialize_project(project)
async def get_project_from_legacy_query(
self,
username: str,
project_id: str | None,
) -> dict:
return await self.get_project(username, project_id)
async def update_project(self, username: str, data: object) -> None:
payload = self._as_payload(data)
project_id = payload.get("project_id")
if not project_id:
raise ChatUIProjectServiceError("Missing key: project_id")
await self._get_owned_project(username, project_id)
await self.db.update_chatui_project(
project_id=project_id,
title=payload.get("title"),
emoji=payload.get("emoji"),
description=payload.get("description"),
)
async def delete_project(self, username: str, project_id: str | None) -> None:
if not project_id:
raise ChatUIProjectServiceError("Missing key: project_id")
await self._get_owned_project(username, project_id)
await self.db.delete_chatui_project(project_id)
async def delete_project_from_legacy_query(
self,
username: str,
project_id: str | None,
) -> None:
await self.delete_project(username, project_id)
async def add_session_to_project(self, username: str, data: object) -> None:
payload = self._as_payload(data)
session_id = payload.get("session_id")
project_id = payload.get("project_id")
if not session_id:
raise ChatUIProjectServiceError("Missing key: session_id")
if not project_id:
raise ChatUIProjectServiceError("Missing key: project_id")
await self._get_owned_project(username, project_id)
await self._get_owned_session(username, session_id)
await self.db.add_session_to_project(session_id, project_id)
async def remove_session_from_project(self, username: str, data: object) -> None:
payload = self._as_payload(data)
session_id = payload.get("session_id")
if not session_id:
raise ChatUIProjectServiceError("Missing key: session_id")
await self._get_owned_session(username, session_id)
await self.db.remove_session_from_project(session_id)
async def get_project_sessions(
self,
username: str,
project_id: str | None,
) -> list[dict]:
if not project_id:
raise ChatUIProjectServiceError("Missing key: project_id")
await self._get_owned_project(username, project_id)
sessions = await self.db.get_project_sessions(project_id)
return [self._serialize_session(session) for session in sessions]
async def get_project_sessions_from_legacy_query(
self,
username: str,
project_id: str | None,
) -> list[dict]:
return await self.get_project_sessions(username, project_id)
async def _get_owned_project(self, username: str, project_id: str):
project = await self.db.get_chatui_project_by_id(project_id)
if not project:
raise ChatUIProjectServiceError(f"Project {project_id} not found")
if project.creator != username:
raise ChatUIProjectServiceError("Permission denied")
return project
async def _get_owned_session(self, username: str, session_id: str):
session = await self.db.get_platform_session_by_id(session_id)
if not session:
raise ChatUIProjectServiceError(f"Session {session_id} not found")
if session.creator != username:
raise ChatUIProjectServiceError("Permission denied")
return session
@staticmethod
def _serialize_project(project) -> dict:
return {
"project_id": project.project_id,
"title": project.title,
"emoji": project.emoji,
"description": project.description,
"created_at": to_utc_isoformat(project.created_at),
"updated_at": to_utc_isoformat(project.updated_at),
}
@staticmethod
def _serialize_session(session) -> dict:
return {
"session_id": session.session_id,
"platform_id": session.platform_id,
"creator": session.creator,
"display_name": session.display_name,
"is_group": session.is_group,
"created_at": to_utc_isoformat(session.created_at),
"updated_at": to_utc_isoformat(session.updated_at),
}
@staticmethod
def _as_payload(data: object) -> dict:
return data if isinstance(data, dict) else {}

View File

@@ -0,0 +1,133 @@
from __future__ import annotations
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star.command_management import (
list_command_conflicts,
list_commands,
rename_command,
toggle_command,
update_command_permission,
)
class CommandServiceError(Exception):
pass
class CommandService:
def __init__(
self,
config: AstrBotConfig,
core_lifecycle: AstrBotCoreLifecycle | None = None,
) -> None:
self.config = config
self.core_lifecycle = core_lifecycle
async def list_commands(self, config_id: str = "") -> dict:
commands = await list_commands()
summary = {
"total": len(commands),
"disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
"conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
}
wake_prefix = self._get_wake_prefix(config_id)
return {
"items": commands,
"summary": summary,
"wake_prefix": wake_prefix,
}
async def list_commands_from_legacy_query(self, config_id: str | None) -> dict:
return await self.list_commands(config_id or "")
async def list_conflicts(self):
return await list_command_conflicts()
async def toggle_command(self, handler_full_name: str | None, enabled) -> dict:
if handler_full_name is None or enabled is None:
raise CommandServiceError("handler_full_name 与 enabled 均为必填。")
if isinstance(enabled, str):
enabled = enabled.lower() in ("1", "true", "yes", "on")
try:
await toggle_command(handler_full_name, bool(enabled))
except ValueError as exc:
raise CommandServiceError(str(exc)) from exc
return await self._get_command_payload(handler_full_name)
async def toggle_command_from_legacy_payload(self, data: object) -> dict:
payload = self._payload(data)
return await self.toggle_command(
payload.get("handler_full_name"),
payload.get("enabled"),
)
async def rename_command(
self,
handler_full_name: str | None,
new_name: str | None,
aliases=None,
) -> dict:
if not handler_full_name or not new_name:
raise CommandServiceError("handler_full_name 与 new_name 均为必填。")
try:
await rename_command(handler_full_name, new_name, aliases=aliases)
except ValueError as exc:
raise CommandServiceError(str(exc)) from exc
return await self._get_command_payload(handler_full_name)
async def rename_command_from_legacy_payload(self, data: object) -> dict:
payload = self._payload(data)
return await self.rename_command(
payload.get("handler_full_name"),
payload.get("new_name"),
aliases=payload.get("aliases"),
)
async def update_permission(
self,
handler_full_name: str | None,
permission: str | None,
) -> dict:
if not handler_full_name or not permission:
raise CommandServiceError("handler_full_name 与 permission 均为必填。")
try:
await update_command_permission(handler_full_name, permission)
except ValueError as exc:
raise CommandServiceError(str(exc)) from exc
return await self._get_command_payload(handler_full_name)
async def update_permission_from_legacy_payload(self, data: object) -> dict:
payload = self._payload(data)
return await self.update_permission(
payload.get("handler_full_name"),
payload.get("permission"),
)
def _get_wake_prefix(self, config_id: str) -> list:
wake_prefix = self.config.get("wake_prefix", ["/"])
config_id = config_id.strip()
if config_id and self.core_lifecycle:
config_mgr = getattr(self.core_lifecycle, "astrbot_config_mgr", None)
if config_mgr and config_id in config_mgr.confs:
return config_mgr.confs[config_id].get("wake_prefix", wake_prefix)
return wake_prefix
@staticmethod
async def _get_command_payload(handler_full_name: str) -> dict:
commands = await list_commands()
for cmd in commands:
if cmd["handler_full_name"] == handler_full_name:
return cmd
return {}
@staticmethod
def _payload(data: object) -> dict:
return data if isinstance(data, dict) else {}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,337 @@
from __future__ import annotations
import json
import traceback
from dataclasses import asdict, dataclass
from datetime import datetime
from io import BytesIO
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.umo_alias import build_umo_alias_map, parse_umo, serialize_umo_alias
class ConversationServiceError(Exception):
pass
@dataclass
class ConversationExport:
file_obj: BytesIO
filename: str
mimetype: str = "application/jsonl"
class ConversationService:
def __init__(
self,
db_helper: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
self.db_helper = db_helper
self.conv_mgr = core_lifecycle.conversation_manager
async def list_conversations(
self,
*,
page: int,
page_size: int,
platforms: str,
message_types: str,
search_query: str,
exclude_ids: str,
exclude_platforms: str,
) -> dict:
platform_list = platforms.split(",") if platforms else []
message_type_list = message_types.split(",") if message_types else []
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
exclude_platform_list = (
exclude_platforms.split(",") if exclude_platforms else []
)
page = max(page, 1)
if page_size < 1:
page_size = 20
page_size = min(page_size, 100)
try:
conversations, total_count = await self.conv_mgr.get_filtered_conversations(
page=page,
page_size=page_size,
platforms=platform_list,
message_types=message_type_list,
search_query=search_query,
exclude_ids=exclude_id_list,
exclude_platforms=exclude_platform_list,
)
except Exception as exc:
logger.error(f"数据库查询出错: {exc!s}\n{traceback.format_exc()}")
raise ConversationServiceError(f"数据库查询出错: {exc!s}") from exc
total_pages = (
(total_count + page_size - 1) // page_size if total_count > 0 else 1
)
umos = sorted({conv.user_id for conv in conversations if conv.user_id})
alias_map = build_umo_alias_map(await self.db_helper.get_umo_aliases(umos))
return {
"conversations": [
self._serialize_conversation(conversation, alias_map)
for conversation in conversations
],
"pagination": {
"page": page,
"page_size": page_size,
"total": total_count,
"total_pages": total_pages,
},
}
async def list_conversations_from_legacy_query(
self,
*,
page,
page_size,
platforms: str | None,
message_types: str | None,
search_query: str | None,
exclude_ids: str | None,
exclude_platforms: str | None,
) -> dict:
return await self.list_conversations(
page=self._to_int(page, 1),
page_size=self._to_int(page_size, 20),
platforms=platforms or "",
message_types=message_types or "",
search_query=search_query or "",
exclude_ids=exclude_ids or "",
exclude_platforms=exclude_platforms or "",
)
async def get_conversation_detail(self, data: object) -> dict:
payload = self._payload(data)
user_id, cid = self._require_user_and_cid(payload)
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
raise ConversationServiceError("对话不存在")
alias_map = build_umo_alias_map(await self.db_helper.get_umo_aliases([user_id]))
return {
"user_id": user_id,
"cid": cid,
"title": conversation.title,
"persona_id": conversation.persona_id,
"history": conversation.history,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
"umo_info": self._build_umo_info(user_id, alias_map),
}
async def update_conversation(self, data: object) -> dict:
payload = self._payload(data)
user_id, cid = self._require_user_and_cid(payload)
title = payload.get("title")
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
raise ConversationServiceError("对话不存在")
persona_id = payload.get("persona_id", conversation.persona_id)
if title is not None or persona_id is not None:
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
title=title,
persona_id=persona_id,
)
return {"message": "对话信息更新成功"}
async def delete_conversation(self, data: object) -> dict:
payload = self._payload(data)
if "conversations" in payload:
return await self._delete_conversations(payload.get("conversations", []))
user_id, cid = self._require_user_and_cid(payload)
await self.conv_mgr.delete_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
return {"message": "对话删除成功"}
async def update_history(self, data: object) -> dict:
payload = self._payload(data)
user_id, cid = self._require_user_and_cid(payload)
history = payload.get("history")
if history is None:
raise ConversationServiceError("缺少必要参数: history")
history = self._normalize_history(history)
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
raise ConversationServiceError("对话不存在")
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
history=history,
)
return {"message": "对话历史更新成功"}
async def export_conversations(self, data: object) -> ConversationExport:
payload = self._payload(data)
conversations_to_export = payload.get("conversations", [])
if not conversations_to_export:
raise ConversationServiceError("导出列表不能为空")
jsonl_lines = []
exported_count = 0
failed_items = []
for conv_info in conversations_to_export:
user_id = conv_info.get("user_id")
cid = conv_info.get("cid")
if not user_id or not cid:
failed_items.append(f"user_id:{user_id}, cid:{cid} - 缺少必要参数")
continue
try:
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
failed_items.append(f"user_id:{user_id}, cid:{cid} - 对话不存在")
continue
content = json.loads(conversation.history)
export_record = {
"cid": cid,
"user_id": user_id,
"platform_id": conversation.platform_id,
"title": conversation.title,
"persona_id": conversation.persona_id,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
"content": content,
}
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
exported_count += 1
except Exception as exc:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {exc!s}")
logger.error(
f"导出对话失败: user_id={user_id}, cid={cid}, error={exc!s}"
)
if exported_count == 0:
raise ConversationServiceError("没有成功导出任何对话")
jsonl_content = "\n".join(jsonl_lines)
file_obj = BytesIO(jsonl_content.encode("utf-8"))
file_obj.seek(0)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return ConversationExport(
file_obj=file_obj,
filename=f"astrbot_conversations_export_{timestamp}.jsonl",
)
async def _delete_conversations(self, conversations: object) -> dict:
if not conversations:
raise ConversationServiceError("批量删除时conversations参数不能为空")
deleted_count = 0
failed_items = []
for conv in conversations:
user_id = conv.get("user_id")
cid = conv.get("cid")
if not user_id or not cid:
failed_items.append(f"user_id:{user_id}, cid:{cid} - 缺少必要参数")
continue
try:
await self.conv_mgr.delete_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
deleted_count += 1
except Exception as exc:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {exc!s}")
message = f"成功删除 {deleted_count} 个对话"
if failed_items:
message += f",失败 {len(failed_items)}"
return {
"message": message,
"deleted_count": deleted_count,
"failed_count": len(failed_items),
"failed_items": failed_items,
}
def _serialize_conversation(self, conversation, alias_map: dict) -> dict:
return {
**asdict(conversation),
"umo_info": self._build_umo_info(conversation.user_id, alias_map),
}
@staticmethod
def _build_umo_info(umo: str | None, alias_map: dict) -> dict:
umo_str = umo or ""
return {
"umo": umo_str,
**parse_umo(umo_str),
**serialize_umo_alias(alias_map.get(umo_str), umo_str),
}
@staticmethod
def _require_user_and_cid(payload: dict) -> tuple[str, str]:
user_id = payload.get("user_id")
cid = payload.get("cid")
if not user_id or not cid:
raise ConversationServiceError("缺少必要参数: user_id 和 cid")
return user_id, cid
@staticmethod
def _normalize_history(history):
try:
if isinstance(history, list):
history = json.dumps(history)
else:
json.loads(history)
except json.JSONDecodeError as exc:
raise ConversationServiceError(
"history 必须是有效的 JSON 字符串或数组"
) from exc
return json.loads(history) if isinstance(history, str) else history
@staticmethod
def _payload(data: object) -> dict:
return data if isinstance(data, dict) else {}
@staticmethod
def _to_int(value, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default

View File

@@ -0,0 +1,285 @@
from __future__ import annotations
import asyncio
import traceback
from datetime import datetime, timezone
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class CronServiceError(Exception):
pass
class CronService:
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
self.core_lifecycle = core_lifecycle
self._background_tasks: set[asyncio.Task] = set()
def _get_cron_manager(self):
cron_mgr = self.core_lifecycle.cron_manager
if cron_mgr is None:
raise CronServiceError("Cron manager not initialized")
return cron_mgr
@staticmethod
def serialize_job(job) -> dict:
data = job.model_dump() if hasattr(job, "model_dump") else job.__dict__
for key in ["created_at", "updated_at", "last_run_at", "next_run_time"]:
value = data.get(key)
if isinstance(value, datetime):
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
data[key] = value.isoformat()
payload = data.get("payload") or {}
data["note"] = payload.get("note") or data.get("description") or ""
data["run_at"] = payload.get("run_at")
data["run_once"] = data.get("run_once", False)
data.pop("status", None)
return data
async def list_jobs(self, job_type: str | None = None) -> list[dict]:
try:
cron_mgr = self._get_cron_manager()
jobs = await cron_mgr.list_jobs(job_type)
return [self.serialize_job(job) for job in jobs]
except CronServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise CronServiceError(f"Failed to list jobs: {exc!s}") from exc
async def list_jobs_from_legacy_query(
self,
job_type: str | None,
) -> list[dict]:
return await self.list_jobs(job_type)
async def create_job(self, payload: object) -> dict:
try:
cron_mgr = self._get_cron_manager()
if not isinstance(payload, dict):
raise CronServiceError("Invalid payload")
name = payload.get("name") or "active_agent_task"
cron_expression = payload.get("cron_expression")
note = payload.get("note") or payload.get("description") or name
session = str(payload.get("session") or "").strip()
persona_id = payload.get("persona_id")
provider_id = payload.get("provider_id")
timezone_name = payload.get("timezone")
enabled = bool(payload.get("enabled", True))
run_once = bool(payload.get("run_once", False))
run_at = payload.get("run_at")
if run_once and not run_at:
raise CronServiceError("run_at is required when run_once=true")
if (not run_once) and not cron_expression:
raise CronServiceError(
"cron_expression is required when run_once=false"
)
if run_once and cron_expression:
cron_expression = None
run_at_dt = self._parse_optional_run_at(run_at)
job_payload = {
"session": session,
"note": note,
"persona_id": persona_id,
"provider_id": provider_id,
"run_at": run_at,
"origin": "api",
}
job = await cron_mgr.add_active_job(
name=name,
cron_expression=cron_expression,
payload=job_payload,
description=note,
timezone=timezone_name,
enabled=enabled,
run_once=run_once,
run_at=run_at_dt,
)
return self.serialize_job(job)
except CronServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise CronServiceError(f"Failed to create job: {exc!s}") from exc
async def update_job(self, job_id: str, payload: object) -> dict:
try:
cron_mgr = self._get_cron_manager()
if not isinstance(payload, dict):
raise CronServiceError("Invalid payload")
job = await cron_mgr.db.get_cron_job(job_id)
if not job:
raise CronServiceError("Job not found")
updates = {}
if "name" in payload:
name = str(payload.get("name") or "").strip()
if not name:
raise CronServiceError("name cannot be empty")
updates["name"] = name
if "enabled" in payload:
updates["enabled"] = bool(payload.get("enabled"))
if "timezone" in payload:
timezone_name = payload.get("timezone")
updates["timezone"] = str(timezone_name).strip() or None
if job.job_type == "active_agent":
self._merge_active_agent_updates(job, payload, updates)
else:
self._merge_generic_updates(payload, updates)
updated_job = await cron_mgr.update_job(job_id, **updates)
if not updated_job:
raise CronServiceError("Job not found")
return self.serialize_job(updated_job)
except CronServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise CronServiceError(f"Failed to update job: {exc!s}") from exc
async def delete_job(self, job_id: str) -> None:
try:
cron_mgr = self._get_cron_manager()
await cron_mgr.delete_job(job_id)
except CronServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise CronServiceError(f"Failed to delete job: {exc!s}") from exc
async def run_job_now(self, job_id: str) -> None:
try:
cron_mgr = self._get_cron_manager()
job = await cron_mgr.db.get_cron_job(job_id)
if not job:
raise CronServiceError("Job not found")
task = asyncio.create_task(cron_mgr.run_job_now(job_id))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
except CronServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise CronServiceError(f"Failed to run job: {exc!s}") from exc
@staticmethod
def _parse_optional_run_at(run_at: object) -> datetime | None:
if not run_at:
return None
try:
return datetime.fromisoformat(str(run_at))
except Exception as exc:
raise CronServiceError("run_at must be ISO datetime") from exc
@staticmethod
def _normalize_run_at_iso(run_at: object) -> str | None:
if not run_at:
return None
try:
return datetime.fromisoformat(str(run_at)).isoformat()
except Exception as exc:
raise CronServiceError("run_at must be ISO datetime") from exc
def _merge_active_agent_updates(self, job, payload: dict, updates: dict) -> None:
merged_payload = dict(job.payload) if isinstance(job.payload, dict) else {}
if "payload" in payload and isinstance(payload.get("payload"), dict):
merged_payload.update(payload["payload"])
if "session" in payload:
session = str(payload.get("session") or "").strip()
if session:
merged_payload["session"] = session
else:
merged_payload.pop("session", None)
self._merge_note(payload, job, merged_payload, updates)
next_run_once = (
bool(payload.get("run_once"))
if "run_once" in payload
else bool(job.run_once)
)
next_cron_expression = (
payload.get("cron_expression")
if "cron_expression" in payload
else job.cron_expression
)
if next_cron_expression is not None:
next_cron_expression = str(next_cron_expression).strip() or None
run_at_raw = (
payload.get("run_at")
if "run_at" in payload
else merged_payload.get("run_at")
)
run_at_iso = self._normalize_run_at_iso(run_at_raw)
if next_run_once:
if not run_at_iso:
raise CronServiceError("run_at is required when run_once=true")
next_cron_expression = None
merged_payload["run_at"] = run_at_iso
else:
if not next_cron_expression:
raise CronServiceError(
"cron_expression is required when run_once=false"
)
merged_payload.pop("run_at", None)
updates["run_once"] = next_run_once
updates["cron_expression"] = next_cron_expression
updates["payload"] = merged_payload
@staticmethod
def _merge_note(
payload: dict,
job,
merged_payload: dict,
updates: dict,
) -> None:
note_updated = False
if "note" in payload:
note = str(payload.get("note") or "").strip()
if not note:
raise CronServiceError("note cannot be empty")
merged_payload["note"] = note
updates["description"] = note
note_updated = True
elif "description" in payload:
description = str(payload.get("description") or "").strip()
if not description:
raise CronServiceError("description cannot be empty")
updates["description"] = description
merged_payload["note"] = description
note_updated = True
if not note_updated and updates.get("description") is None:
existing_note = str(
merged_payload.get("note") or job.description or ""
).strip()
if existing_note:
merged_payload["note"] = existing_note
@staticmethod
def _merge_generic_updates(payload: dict, updates: dict) -> None:
if "cron_expression" in payload:
cron_expression = str(payload.get("cron_expression") or "").strip()
if not cron_expression:
raise CronServiceError("cron_expression cannot be empty")
updates["cron_expression"] = cron_expression
if "description" in payload:
description = str(payload.get("description") or "").strip()
updates["description"] = description or None

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from astrbot.core import file_token_service
class FileServiceError(Exception):
pass
class FileService:
async def resolve_token_file(self, file_token: str) -> str:
try:
return await file_token_service.handle_file(file_token)
except (FileNotFoundError, KeyError) as exc:
raise FileServiceError(str(exc)) from exc

View File

@@ -0,0 +1,863 @@
from __future__ import annotations
import asyncio
import traceback
import uuid
from pathlib import Path
from typing import Any
import aiofiles
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.dashboard.utils import generate_tsne_visualization
class KnowledgeBaseServiceError(Exception):
pass
class KnowledgeBaseService:
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
self.core_lifecycle = core_lifecycle
self.upload_progress: dict[str, dict[str, Any]] = {}
self.upload_tasks: dict[str, dict[str, Any]] = {}
@staticmethod
def _payload(data: object) -> dict[str, Any]:
return data if isinstance(data, dict) else {}
def get_kb_manager(self):
return self.core_lifecycle.kb_manager
def init_task(self, task_id: str, status: str = "pending") -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": None,
"error": None,
}
def set_task_result(
self,
task_id: str,
status: str,
result: Any = None,
error: str | None = None,
) -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": result,
"error": error,
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = status
def update_progress(
self,
task_id: str,
*,
status: str | None = None,
file_index: int | None = None,
file_name: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
) -> None:
if task_id not in self.upload_progress:
return
progress = self.upload_progress[task_id]
if status is not None:
progress["status"] = status
if file_index is not None:
progress["file_index"] = file_index
if file_name is not None:
progress["file_name"] = file_name
if stage is not None:
progress["stage"] = stage
if current is not None:
progress["current"] = current
if total is not None:
progress["total"] = total
def make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
async def _callback(stage: str, current: int, total: int) -> None:
self.update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage=stage,
current=current,
total=total,
)
return _callback
@staticmethod
def format_failed_doc_error(file_name: str, error: Exception) -> str:
message = str(error).strip() or "上传失败:发生未知错误。"
if message.startswith(file_name):
return message
return f"{file_name}: {message}"
async def background_upload_task(
self,
task_id: str,
kb_helper,
files_to_upload: list[dict[str, Any]],
chunk_size: int,
chunk_overlap: int,
batch_size: int,
tasks_limit: int,
max_retries: int,
) -> None:
try:
self.init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": len(files_to_upload),
"stage": "waiting",
"current": 0,
"total": 100,
}
uploaded_docs = []
failed_docs = []
for file_idx, file_info in enumerate(files_to_upload):
try:
self.update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_info["file_name"],
stage="parsing",
current=0,
total=100,
)
progress_callback = self.make_progress_callback(
task_id, file_idx, file_info["file_name"]
)
doc = await kb_helper.upload_document(
file_name=file_info["file_name"],
file_content=file_info["file_content"],
file_type=file_info["file_type"],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)
uploaded_docs.append(doc.model_dump())
except Exception as exc:
logger.error(f"上传文档 {file_info['file_name']} 失败: {exc}")
failed_docs.append(
{
"file_name": file_info["file_name"],
"error": self.format_failed_doc_error(
file_info["file_name"], exc
),
},
)
self.set_task_result(
task_id,
"completed",
result={
"task_id": task_id,
"uploaded": uploaded_docs,
"failed": failed_docs,
"total": len(files_to_upload),
"success_count": len(uploaded_docs),
"failed_count": len(failed_docs),
},
)
except Exception as exc:
logger.error(f"后台上传任务 {task_id} 失败: {exc}")
logger.error(traceback.format_exc())
self.set_task_result(task_id, "failed", error=str(exc))
async def background_import_task(
self,
task_id: str,
kb_helper,
documents: list[dict[str, Any]],
batch_size: int,
tasks_limit: int,
max_retries: int,
) -> None:
try:
self.init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": len(documents),
"stage": "waiting",
"current": 0,
"total": 100,
}
uploaded_docs = []
failed_docs = []
for file_idx, doc_info in enumerate(documents):
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
chunks = doc_info.get("chunks", [])
try:
self.update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage="importing",
current=0,
total=100,
)
progress_callback = self.make_progress_callback(
task_id, file_idx, file_name
)
doc = await kb_helper.upload_document(
file_name=file_name,
file_content=None,
file_type=doc_info.get("file_type")
or (
file_name.rsplit(".", 1)[-1].lower()
if "." in file_name
else "txt"
),
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
pre_chunked_text=chunks,
)
uploaded_docs.append(doc.model_dump())
except Exception as exc:
logger.error(f"导入文档 {file_name} 失败: {exc}")
failed_docs.append(
{
"file_name": file_name,
"error": self.format_failed_doc_error(file_name, exc),
},
)
self.set_task_result(
task_id,
"completed",
result={
"task_id": task_id,
"uploaded": uploaded_docs,
"failed": failed_docs,
"total": len(documents),
"success_count": len(uploaded_docs),
"failed_count": len(failed_docs),
},
)
except Exception as exc:
logger.error(f"后台导入任务 {task_id} 失败: {exc}")
logger.error(traceback.format_exc())
self.set_task_result(task_id, "failed", error=str(exc))
async def list_kbs(self, *, page: int, page_size: int) -> dict[str, Any]:
kb_manager = self.get_kb_manager()
kbs = await kb_manager.list_kbs()
kb_list = []
for kb in kbs:
kb_dict = kb.model_dump()
kb_helper = await kb_manager.get_kb(kb.kb_id)
if kb_helper and kb_helper.init_error:
kb_dict["init_error"] = kb_helper.init_error
kb_list.append(kb_dict)
return {"items": kb_list, "page": page, "page_size": page_size}
async def list_kbs_from_legacy_query(self, *, page, page_size) -> dict[str, Any]:
return await self.list_kbs(
page=self._to_int(page, 1),
page_size=self._to_int(page_size, 20),
)
async def create_kb(self, data: object) -> tuple[dict[str, Any], str]:
kb_manager = self.get_kb_manager()
payload = self._payload(data)
kb_name = payload.get("kb_name")
if not kb_name:
raise KnowledgeBaseServiceError("知识库名称不能为空")
embedding_provider_id = payload.get("embedding_provider_id")
rerank_provider_id = payload.get("rerank_provider_id")
if not embedding_provider_id:
raise KnowledgeBaseServiceError("缺少参数 embedding_provider_id")
provider = await kb_manager.provider_manager.get_provider_by_id(
embedding_provider_id,
)
if not provider or not isinstance(provider, EmbeddingProvider):
raise KnowledgeBaseServiceError(
f"嵌入模型不存在或类型错误({type(provider)})"
)
try:
vec = await provider.get_embedding("astrbot")
if len(vec) != provider.get_dim():
raise ValueError(
f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {provider.get_dim()}",
)
except Exception as exc:
raise KnowledgeBaseServiceError(f"测试嵌入模型失败: {exc!s}") from exc
if rerank_provider_id:
rerank_provider: RerankProvider = (
await kb_manager.provider_manager.get_provider_by_id(
rerank_provider_id,
)
)
if not rerank_provider:
raise KnowledgeBaseServiceError("重排序模型不存在")
try:
result = await rerank_provider.rerank(
query="astrbot",
documents=["astrbot knowledge base"],
)
if not result:
raise ValueError("重排序模型返回结果异常")
except Exception as exc:
raise KnowledgeBaseServiceError(
f"测试重排序模型失败: {exc!s},请检查平台日志输出。"
) from exc
kb_helper = await kb_manager.create_kb(
kb_name=kb_name,
description=payload.get("description"),
emoji=payload.get("emoji"),
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=payload.get("chunk_size"),
chunk_overlap=payload.get("chunk_overlap"),
top_k_dense=payload.get("top_k_dense"),
top_k_sparse=payload.get("top_k_sparse"),
top_m_final=payload.get("top_m_final"),
)
return kb_helper.kb.model_dump(), "创建知识库成功"
async def get_kb(self, kb_id: str | None) -> dict[str, Any]:
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
return kb_helper.kb.model_dump()
async def get_kb_from_legacy_query(self, kb_id: str | None) -> dict[str, Any]:
return await self.get_kb(kb_id)
async def update_kb(self, data: object) -> tuple[dict[str, Any], str]:
payload = self._payload(data)
kb_id = payload.get("kb_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
update_keys = [
"kb_name",
"description",
"emoji",
"embedding_provider_id",
"rerank_provider_id",
"chunk_size",
"chunk_overlap",
"top_k_dense",
"top_k_sparse",
"top_m_final",
]
if all(payload.get(key) is None for key in update_keys):
raise KnowledgeBaseServiceError("至少需要提供一个更新字段")
kb_helper = await self.get_kb_manager().update_kb(
kb_id=kb_id,
kb_name=payload.get("kb_name"),
description=payload.get("description"),
emoji=payload.get("emoji"),
embedding_provider_id=payload.get("embedding_provider_id"),
rerank_provider_id=payload.get("rerank_provider_id"),
chunk_size=payload.get("chunk_size"),
chunk_overlap=payload.get("chunk_overlap"),
top_k_dense=payload.get("top_k_dense"),
top_k_sparse=payload.get("top_k_sparse"),
top_m_final=payload.get("top_m_final"),
)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
return kb_helper.kb.model_dump(), "更新知识库成功"
async def delete_kb(self, data: object) -> tuple[None, str]:
payload = self._payload(data)
kb_id = payload.get("kb_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
success = await self.get_kb_manager().delete_kb(kb_id)
if not success:
raise KnowledgeBaseServiceError("知识库不存在")
return None, "删除知识库成功"
async def get_kb_stats(self, kb_id: str | None) -> dict[str, Any]:
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
kb = kb_helper.kb
return {
"kb_id": kb.kb_id,
"kb_name": kb.kb_name,
"doc_count": kb.doc_count,
"chunk_count": kb.chunk_count,
"created_at": kb.created_at.isoformat(),
"updated_at": kb.updated_at.isoformat(),
}
async def get_kb_stats_from_legacy_query(
self,
kb_id: str | None,
) -> dict[str, Any]:
return await self.get_kb_stats(kb_id)
async def list_documents(
self,
*,
kb_id: str | None,
page: int,
page_size: int,
) -> dict[str, Any]:
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
offset = (page - 1) * page_size
doc_list = await kb_helper.list_documents(offset=offset, limit=page_size)
return {
"items": [doc.model_dump() for doc in doc_list],
"page": page,
"page_size": page_size,
}
async def list_documents_from_legacy_query(
self,
*,
kb_id: str | None,
page,
page_size,
) -> dict[str, Any]:
return await self.list_documents(
kb_id=kb_id,
page=self._to_int(page, 1),
page_size=self._to_int(page_size, 100),
)
async def upload_document(
self,
*,
content_type: str | None,
form_data,
files,
) -> dict[str, Any]:
if content_type and "multipart/form-data" not in content_type:
raise KnowledgeBaseServiceError("Content-Type 须为 multipart/form-data")
kb_id = form_data.get("kb_id")
chunk_size = int(form_data.get("chunk_size", 512))
chunk_overlap = int(form_data.get("chunk_overlap", 50))
batch_size = int(form_data.get("batch_size", 32))
tasks_limit = int(form_data.get("tasks_limit", 3))
max_retries = int(form_data.get("max_retries", 3))
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
file_list = []
for key in files.keys():
if key == "file" or key.startswith("file") or key == "files[]":
file_list.extend(files.getlist(key))
if not file_list:
raise KnowledgeBaseServiceError("缺少文件")
if len(file_list) > 10:
raise KnowledgeBaseServiceError("最多只能上传10个文件")
files_to_upload = []
for file in file_list:
file_name = file.filename
temp_file_path = (
Path(get_astrbot_temp_path()) / f"kb_upload_{uuid.uuid4()}_{file_name}"
)
await file.save(temp_file_path)
try:
async with aiofiles.open(temp_file_path, "rb") as file_obj:
file_content = await file_obj.read()
file_type = (
file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
)
files_to_upload.append(
{
"file_name": file_name,
"file_content": file_content,
"file_type": file_type,
},
)
finally:
temp_file_path.unlink(missing_ok=True)
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
task_id = str(uuid.uuid4())
self.init_task(task_id, status="pending")
asyncio.create_task(
self.background_upload_task(
task_id=task_id,
kb_helper=kb_helper,
files_to_upload=files_to_upload,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
),
)
return {
"task_id": task_id,
"file_count": len(files_to_upload),
"message": "task created, processing in background",
}
@staticmethod
def validate_import_request(data: dict[str, Any]):
kb_id = data.get("kb_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
documents = data.get("documents")
if not documents or not isinstance(documents, list):
raise KnowledgeBaseServiceError("缺少参数 documents 或格式错误")
for doc in documents:
if (
not isinstance(doc, dict)
or "file_name" not in doc
or "chunks" not in doc
):
raise KnowledgeBaseServiceError(
"文档格式错误,必须包含 file_name 和 chunks"
)
if not isinstance(doc["chunks"], list):
raise KnowledgeBaseServiceError("chunks 必须是列表")
if not all(
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
):
raise KnowledgeBaseServiceError("chunks 必须是非空字符串列表")
return (
kb_id,
documents,
data.get("batch_size", 32),
data.get("tasks_limit", 3),
data.get("max_retries", 3),
)
async def import_documents(self, data: object) -> dict[str, Any]:
payload = self._payload(data)
kb_id, documents, batch_size, tasks_limit, max_retries = (
self.validate_import_request(payload)
)
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
task_id = str(uuid.uuid4())
self.init_task(task_id, status="pending")
asyncio.create_task(
self.background_import_task(
task_id=task_id,
kb_helper=kb_helper,
documents=documents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
),
)
return {
"task_id": task_id,
"doc_count": len(documents),
"message": "import task created, processing in background",
}
def get_upload_progress(self, task_id: str | None) -> dict[str, Any]:
if not task_id:
raise KnowledgeBaseServiceError("缺少参数 task_id")
if task_id not in self.upload_tasks:
raise KnowledgeBaseServiceError("找不到该任务")
task_info = self.upload_tasks[task_id]
status = task_info["status"]
response_data = {
"task_id": task_id,
"status": status,
}
if status == "processing" and task_id in self.upload_progress:
response_data["progress"] = self.upload_progress[task_id]
if status == "completed":
response_data["result"] = task_info["result"]
if status == "failed":
response_data["error"] = task_info["error"]
return response_data
def get_upload_progress_from_legacy_query(
self,
task_id: str | None,
) -> dict[str, Any]:
return self.get_upload_progress(task_id)
async def get_document(
self,
*,
kb_id: str | None,
doc_id: str | None,
) -> dict[str, Any]:
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
if not doc_id:
raise KnowledgeBaseServiceError("缺少参数 doc_id")
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
doc = await kb_helper.get_document(doc_id)
if not doc:
raise KnowledgeBaseServiceError("文档不存在")
return doc.model_dump()
async def get_document_from_legacy_query(
self,
*,
kb_id: str | None,
doc_id: str | None,
) -> dict[str, Any]:
return await self.get_document(kb_id=kb_id, doc_id=doc_id)
async def delete_document(self, data: object) -> tuple[None, str]:
payload = self._payload(data)
kb_id = payload.get("kb_id")
doc_id = payload.get("doc_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
if not doc_id:
raise KnowledgeBaseServiceError("缺少参数 doc_id")
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
await kb_helper.delete_document(doc_id)
return None, "删除文档成功"
async def delete_chunk(self, data: object) -> tuple[None, str]:
payload = self._payload(data)
kb_id = payload.get("kb_id")
chunk_id = payload.get("chunk_id")
doc_id = payload.get("doc_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
if not chunk_id:
raise KnowledgeBaseServiceError("缺少参数 chunk_id")
if not doc_id:
raise KnowledgeBaseServiceError("缺少参数 doc_id")
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
await kb_helper.delete_chunk(chunk_id, doc_id)
return None, "删除文本块成功"
async def list_chunks(
self,
*,
kb_id: str | None,
doc_id: str | None,
page: int,
page_size: int,
) -> dict[str, Any]:
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
if not doc_id:
raise KnowledgeBaseServiceError("缺少参数 doc_id")
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
offset = (page - 1) * page_size
return {
"items": await kb_helper.get_chunks_by_doc_id(
doc_id=doc_id,
offset=offset,
limit=page_size,
),
"page": page,
"page_size": page_size,
"total": await kb_helper.get_chunk_count_by_doc_id(doc_id),
}
async def list_chunks_from_legacy_query(
self,
*,
kb_id: str | None,
doc_id: str | None,
page,
page_size,
) -> dict[str, Any]:
return await self.list_chunks(
kb_id=kb_id,
doc_id=doc_id,
page=self._to_int(page, 1),
page_size=self._to_int(page_size, 100),
)
async def retrieve(self, data: object) -> dict[str, Any]:
payload = self._payload(data)
query = payload.get("query")
kb_names = payload.get("kb_names")
debug = payload.get("debug", False)
if not query:
raise KnowledgeBaseServiceError("缺少参数 query")
if not kb_names or not isinstance(kb_names, list):
raise KnowledgeBaseServiceError("缺少参数 kb_names 或格式错误")
top_k = payload.get("top_k", 5)
kb_manager = self.get_kb_manager()
results = await kb_manager.retrieve(
query=query,
kb_names=kb_names,
top_m_final=top_k,
)
result_list = results["results"] if results else []
response_data = {
"results": result_list,
"total": len(result_list),
"query": query,
}
if debug:
try:
img_base64 = await generate_tsne_visualization(
query,
kb_names,
kb_manager,
)
if img_base64:
response_data["visualization"] = img_base64
except Exception as exc:
logger.error(f"生成 t-SNE 可视化失败: {exc}")
logger.error(traceback.format_exc())
response_data["visualization_error"] = str(exc)
return response_data
async def upload_document_from_url(self, data: object) -> dict[str, Any]:
payload = self._payload(data)
kb_id = payload.get("kb_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
url = payload.get("url")
if not url:
raise KnowledgeBaseServiceError("缺少参数 url")
kb_helper = await self.get_kb_manager().get_kb(kb_id)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
task_id = str(uuid.uuid4())
self.init_task(task_id, status="pending")
asyncio.create_task(
self.background_upload_from_url_task(
task_id=task_id,
kb_helper=kb_helper,
url=url,
chunk_size=payload.get("chunk_size", 512),
chunk_overlap=payload.get("chunk_overlap", 50),
batch_size=payload.get("batch_size", 32),
tasks_limit=payload.get("tasks_limit", 3),
max_retries=payload.get("max_retries", 3),
enable_cleaning=payload.get("enable_cleaning", False),
cleaning_provider_id=payload.get("cleaning_provider_id"),
),
)
return {
"task_id": task_id,
"url": url,
"message": "URL upload task created, processing in background",
}
async def background_upload_from_url_task(
self,
task_id: str,
kb_helper,
url: str,
chunk_size: int,
chunk_overlap: int,
batch_size: int,
tasks_limit: int,
max_retries: int,
enable_cleaning: bool,
cleaning_provider_id: str | None,
) -> None:
try:
self.init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": 1,
"file_name": f"URL: {url}",
"stage": "extracting",
"current": 0,
"total": 100,
}
progress_callback = self.make_progress_callback(task_id, 0, f"URL: {url}")
doc = await kb_helper.upload_from_url(
url=url,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
enable_cleaning=enable_cleaning,
cleaning_provider_id=cleaning_provider_id,
)
self.set_task_result(
task_id,
"completed",
result={
"task_id": task_id,
"uploaded": [doc.model_dump()],
"failed": [],
"total": 1,
"success_count": 1,
"failed_count": 0,
},
)
except Exception as exc:
logger.error(f"后台上传URL任务 {task_id} 失败: {exc}")
logger.error(traceback.format_exc())
self.set_task_result(task_id, "failed", error=str(exc))
@staticmethod
def _to_int(value, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
__all__ = ["KnowledgeBaseService", "KnowledgeBaseServiceError"]

View File

@@ -0,0 +1,982 @@
from __future__ import annotations
import asyncio
import base64
import json
import os
import re
import time
import uuid
import wave
from collections.abc import Awaitable, Callable
from typing import Any
import jwt
from astrbot import logger
from astrbot.core import sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_webchat_message_parts,
create_attachment_part_from_existing_file,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from astrbot.core.utils.datetime_utils import to_utc_isoformat
from astrbot.dashboard.services.chat_service import (
BotMessageAccumulator,
build_bot_history_content,
collect_plain_text_from_message_parts,
)
SendJson = Callable[[dict], Awaitable[None]]
ReceiveJson = Callable[[], Awaitable[dict]]
CloseWebSocket = Callable[[int, str], Awaitable[None]]
class LiveChatAuthError(Exception):
pass
class LiveChatSession:
"""Live Chat 会话管理器"""
def __init__(self, session_id: str, username: str) -> None:
self.session_id = session_id
self.username = username
self.conversation_id = str(uuid.uuid4())
self.is_speaking = False
self.is_processing = False
self.should_interrupt = False
self.audio_frames: list[bytes] = []
self.current_stamp: str | None = None
self.temp_audio_path: str | None = None
self.chat_subscriptions: dict[str, str] = {}
self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
self.ws_send_lock = asyncio.Lock()
def start_speaking(self, stamp: str) -> None:
self.is_speaking = True
self.current_stamp = stamp
self.audio_frames = []
logger.debug(f"[Live Chat] {self.username} 开始说话 stamp={stamp}")
def add_audio_frame(self, data: bytes) -> None:
if self.is_speaking:
self.audio_frames.append(data)
async def end_speaking(self, stamp: str) -> tuple[str | None, float]:
start_time = time.time()
if not self.is_speaking or stamp != self.current_stamp:
logger.warning(
f"[Live Chat] stamp 不匹配或未在说话状态: {stamp} vs {self.current_stamp}"
)
return None, 0.0
self.is_speaking = False
if not self.audio_frames:
logger.warning("[Live Chat] 没有音频帧数据")
return None, 0.0
try:
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav")
with wave.open(audio_path, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(16000)
for frame in self.audio_frames:
wav_file.writeframes(frame)
self.temp_audio_path = audio_path
logger.info(
f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes"
)
return audio_path, time.time() - start_time
except Exception as exc:
logger.error(f"[Live Chat] 组装 WAV 文件失败: {exc}", exc_info=True)
return None, 0.0
def cleanup(self) -> None:
if self.temp_audio_path and os.path.exists(self.temp_audio_path):
try:
os.remove(self.temp_audio_path)
logger.debug(f"[Live Chat] 已删除临时文件: {self.temp_audio_path}")
except Exception as exc:
logger.warning(f"[Live Chat] 删除临时文件失败: {exc}")
self.temp_audio_path = None
class LiveChatService:
def __init__(
self,
db: Any,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
self.db = db
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.plugin_manager = core_lifecycle.plugin_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.sessions: dict[str, LiveChatSession] = {}
self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
os.makedirs(self.attachments_dir, exist_ok=True)
def authenticate_token(
self,
token: str | None,
jwt_secret: str | None = None,
) -> str:
if not token:
raise LiveChatAuthError("Missing authentication token")
jwt_secret = jwt_secret or self.config["dashboard"].get("jwt_secret")
try:
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
return payload["username"]
except jwt.ExpiredSignatureError as exc:
raise LiveChatAuthError("Token expired") from exc
except jwt.InvalidTokenError as exc:
raise LiveChatAuthError("Invalid token") from exc
def create_session(self, username: str) -> LiveChatSession:
session_id = f"webchat_live!{username}!{uuid.uuid4()}"
session = LiveChatSession(session_id, username)
self.sessions[session_id] = session
return session
async def cleanup_session(self, session: LiveChatSession) -> None:
if session.session_id in self.sessions:
await self.cleanup_chat_subscriptions(session)
session.cleanup()
del self.sessions[session.session_id]
async def run_websocket_session(
self,
*,
token: str | None,
force_ct: str | None,
receive_json: ReceiveJson,
send_json: SendJson,
close: CloseWebSocket,
) -> None:
try:
username = self.authenticate_token(token)
except LiveChatAuthError as exc:
await close(1008, str(exc))
return
live_session = self.create_session(username)
logger.info(f"[Live Chat] WebSocket 连接建立: {username}")
try:
while True:
message = await receive_json()
ct = force_ct or message.get("ct", "live")
if ct == "chat":
await self.handle_chat_message(
live_session,
message,
send_json,
)
else:
await self.handle_live_message(
live_session,
message,
send_json,
)
except Exception as exc:
logger.error(f"[Live Chat] WebSocket 错误: {exc}", exc_info=True)
finally:
await self.cleanup_session(live_session)
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
async def create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
return await create_attachment_part_from_existing_file(
filename,
attach_type=attach_type,
insert_attachment=self.db.insert_attachment,
attachments_dir=self.attachments_dir,
fallback_dirs=[self.legacy_img_dir],
)
@staticmethod
def extract_web_search_refs(accumulated_text: str, accumulated_parts: list) -> dict:
supported = [
"web_search_baidu",
"web_search_tavily",
"web_search_bocha",
"web_search_brave",
]
web_search_results = {}
tool_call_parts = [
p
for p in accumulated_parts
if p.get("type") == "tool_call" and p.get("tool_calls")
]
for part in tool_call_parts:
for tool_call in part["tool_calls"]:
if tool_call.get("name") not in supported or not tool_call.get(
"result"
):
continue
try:
result_data = json.loads(tool_call["result"])
for item in result_data.get("results", []):
if idx := item.get("index"):
web_search_results[idx] = {
"url": item.get("url"),
"title": item.get("title"),
"snippet": item.get("snippet"),
}
except (json.JSONDecodeError, KeyError):
pass
if not web_search_results:
return {}
ref_indices = {
match.strip() for match in re.findall(r"<ref>(.*?)</ref>", accumulated_text)
}
used_refs = []
for ref_index in ref_indices:
if ref_index not in web_search_results:
continue
payload = {"index": ref_index, **web_search_results[ref_index]}
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
payload["favicon"] = favicon
used_refs.append(payload)
return {"used": used_refs} if used_refs else {}
async def save_bot_message(
self,
webchat_conv_id: str,
message_parts: list[dict],
agent_stats: dict,
refs: dict,
llm_checkpoint_id: str | None = None,
):
new_his = build_bot_history_content(
message_parts,
agent_stats=agent_stats,
refs=refs,
)
return await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
llm_checkpoint_id=llm_checkpoint_id,
)
async def send_chat_payload(
self,
session: LiveChatSession,
payload: dict,
send_json: SendJson,
) -> None:
async with session.ws_send_lock:
await send_json(payload)
async def forward_chat_subscription(
self,
session: LiveChatSession,
chat_session_id: str,
request_id: str,
send_json: SendJson,
) -> None:
back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id, chat_session_id
)
try:
while True:
result = await back_queue.get()
if not result:
continue
await self.send_chat_payload(
session, {"ct": "chat", **result}, send_json
)
except asyncio.CancelledError:
pass
except Exception as exc:
logger.error(
f"[Live Chat] chat subscription forward failed ({chat_session_id}): {exc}",
exc_info=True,
)
finally:
webchat_queue_mgr.remove_back_queue(request_id)
if session.chat_subscriptions.get(chat_session_id) == request_id:
session.chat_subscriptions.pop(chat_session_id, None)
session.chat_subscription_tasks.pop(chat_session_id, None)
async def ensure_chat_subscription(
self,
session: LiveChatSession,
chat_session_id: str,
send_json: SendJson,
) -> str:
existing_request_id = session.chat_subscriptions.get(chat_session_id)
existing_task = session.chat_subscription_tasks.get(chat_session_id)
if existing_request_id and existing_task and not existing_task.done():
return existing_request_id
request_id = f"ws_sub_{uuid.uuid4().hex}"
session.chat_subscriptions[chat_session_id] = request_id
task = asyncio.create_task(
self.forward_chat_subscription(
session,
chat_session_id,
request_id,
send_json,
),
name=f"chat_ws_sub_{chat_session_id}",
)
session.chat_subscription_tasks[chat_session_id] = task
return request_id
async def cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
tasks = list(session.chat_subscription_tasks.values())
for task in tasks:
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
for request_id in list(session.chat_subscriptions.values()):
webchat_queue_mgr.remove_back_queue(request_id)
session.chat_subscriptions.clear()
session.chat_subscription_tasks.clear()
async def handle_chat_message(
self,
session: LiveChatSession,
message: dict,
send_json: SendJson,
) -> None:
msg_type = message.get("t")
if msg_type == "bind":
chat_session_id = message.get("session_id")
if not isinstance(chat_session_id, str) or not chat_session_id:
await self.send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "session_id is required",
"code": "INVALID_MESSAGE_FORMAT",
},
send_json,
)
return
request_id = await self.ensure_chat_subscription(
session, chat_session_id, send_json
)
await self.send_chat_payload(
session,
{
"ct": "chat",
"type": "session_bound",
"session_id": chat_session_id,
"message_id": request_id,
},
send_json,
)
return
if msg_type == "interrupt":
session.should_interrupt = True
await self.send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "INTERRUPTED",
"code": "INTERRUPTED",
},
send_json,
)
return
if msg_type != "send":
await self.send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": f"Unsupported message type: {msg_type}",
"code": "INVALID_MESSAGE_FORMAT",
},
send_json,
)
return
if session.is_processing:
await self.send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "Session is busy",
"code": "PROCESSING_ERROR",
},
send_json,
)
return
payload = message.get("message")
session_id = message.get("session_id") or session.session_id
message_id = message.get("message_id") or str(uuid.uuid4())
selected_provider = message.get("selected_provider")
selected_model = message.get("selected_model")
selected_stt_provider = message.get("selected_stt_provider")
selected_tts_provider = message.get("selected_tts_provider")
persona_prompt = message.get("persona_prompt")
show_reasoning = message.get("show_reasoning")
enable_streaming = message.get("enable_streaming", True)
if not isinstance(payload, list):
await self.send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "message must be list",
"code": "INVALID_MESSAGE_FORMAT",
},
send_json,
)
return
message_parts = await self.build_chat_message_parts(payload)
has_content = webchat_message_parts_have_content(message_parts)
if not has_content:
await self.send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": "Message content is empty",
"code": "INVALID_MESSAGE_FORMAT",
},
send_json,
)
return
await self.ensure_chat_subscription(session, session_id, send_json)
session.is_processing = True
session.should_interrupt = False
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
llm_checkpoint_id = str(uuid.uuid4())
try:
pending_bot_message_flusher = None
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
session.username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"selected_stt_provider": selected_stt_provider,
"selected_tts_provider": selected_tts_provider,
"persona_prompt": persona_prompt,
"show_reasoning": show_reasoning,
"enable_streaming": enable_streaming,
"message_id": message_id,
"llm_checkpoint_id": llm_checkpoint_id,
},
),
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
saved_user_record = await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts_for_storage},
sender_id=session.username,
sender_name=session.username,
llm_checkpoint_id=llm_checkpoint_id,
)
await self.send_chat_payload(
session,
{
"ct": "chat",
"type": "user_message_saved",
"data": {
"id": saved_user_record.id,
"created_at": to_utc_isoformat(saved_user_record.created_at),
"llm_checkpoint_id": llm_checkpoint_id,
},
},
send_json,
)
message_accumulator = BotMessageAccumulator()
agent_stats = {}
refs = {}
async def flush_pending_bot_message():
nonlocal message_accumulator, agent_stats, refs
if not (message_accumulator.has_content() or refs or agent_stats):
return None
message_parts_to_save = message_accumulator.build_message_parts(
include_pending_tool_calls=True
)
plain_text = collect_plain_text_from_message_parts(
message_parts_to_save
)
try:
extracted_refs = self.extract_web_search_refs(
plain_text,
message_parts_to_save,
)
except Exception as exc:
logger.exception(
f"[Live Chat] Failed to extract web search refs: {exc}",
exc_info=True,
)
extracted_refs = refs
saved_record = await self.save_bot_message(
session_id,
message_parts_to_save,
agent_stats,
extracted_refs,
llm_checkpoint_id,
)
message_accumulator = BotMessageAccumulator()
agent_stats = {}
refs = {}
return saved_record
pending_bot_message_flusher = flush_pending_bot_message
async def send_attachment_saved_event(part: dict | None) -> None:
if not part or not part.get("attachment_id") or not part.get("type"):
return
await self.send_chat_payload(
session,
{
"ct": "chat",
"type": "attachment_saved",
"data": {
"id": part["attachment_id"],
"type": part["type"],
},
},
send_json,
)
while True:
if session.should_interrupt:
session.should_interrupt = False
await flush_pending_bot_message()
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if result.get("message_id") and result.get("message_id") != message_id:
continue
result_text = result.get("data", "")
result_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
parsed_agent_stats = json.loads(result_text)
agent_stats = parsed_agent_stats
await self.send_chat_payload(
session,
{
"ct": "chat",
"type": "agent_stats",
"data": parsed_agent_stats,
},
send_json,
)
except Exception:
pass
continue
outgoing = {"ct": "chat", **result}
await self.send_chat_payload(session, outgoing, send_json)
if result_type == "plain":
message_accumulator.add_plain(
result_text,
chain_type=chain_type,
streaming=streaming,
)
elif result_type == "image":
filename = str(result_text).replace("[IMAGE]", "")
part = await self.create_attachment_from_file(filename, "image")
message_accumulator.add_attachment(part)
await send_attachment_saved_event(part)
elif result_type == "record":
filename = str(result_text).replace("[RECORD]", "")
part = await self.create_attachment_from_file(filename, "record")
message_accumulator.add_attachment(part)
await send_attachment_saved_event(part)
elif result_type == "file":
filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
part = await self.create_attachment_from_file(filename, "file")
message_accumulator.add_attachment(part)
await send_attachment_saved_event(part)
elif result_type == "video":
filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
part = await self.create_attachment_from_file(filename, "video")
message_accumulator.add_attachment(part)
await send_attachment_saved_event(part)
should_save = False
if result_type == "end":
should_save = bool(
message_accumulator.has_content() or refs or agent_stats
)
elif (streaming and result_type == "complete") or not streaming:
if chain_type not in (
"tool_call",
"tool_call_result",
"agent_stats",
):
should_save = True
if should_save:
saved_record = await flush_pending_bot_message()
if saved_record:
await self.send_chat_payload(
session,
{
"ct": "chat",
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": to_utc_isoformat(
saved_record.created_at
),
"llm_checkpoint_id": llm_checkpoint_id,
},
},
send_json,
)
if result_type == "end":
break
except Exception as exc:
logger.error(f"[Live Chat] 处理 chat 消息失败: {exc}", exc_info=True)
await self.send_chat_payload(
session,
{
"ct": "chat",
"t": "error",
"data": f"处理失败: {str(exc)}",
"code": "PROCESSING_ERROR",
},
send_json,
)
finally:
try:
if pending_bot_message_flusher is not None:
await pending_bot_message_flusher()
except Exception as exc:
logger.exception(
f"[Live Chat] Failed to persist pending chat message: {exc}",
exc_info=True,
)
session.is_processing = False
webchat_queue_mgr.remove_back_queue(message_id)
async def build_chat_message_parts(self, message: list[dict]) -> list[dict]:
return await build_webchat_message_parts(
message,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=False,
)
async def handle_live_message(
self,
session: LiveChatSession,
message: dict,
send_json: SendJson,
) -> None:
msg_type = message.get("t")
if msg_type == "start_speaking":
stamp = message.get("stamp")
if not stamp:
logger.warning("[Live Chat] start_speaking 缺少 stamp")
return
session.start_speaking(stamp)
return
if msg_type == "speaking_part":
audio_data_b64 = message.get("data")
if not audio_data_b64:
return
try:
audio_data = base64.b64decode(audio_data_b64)
session.add_audio_frame(audio_data)
except Exception as exc:
logger.error(f"[Live Chat] 解码音频数据失败: {exc}")
return
if msg_type == "end_speaking":
stamp = message.get("stamp")
if not stamp:
logger.warning("[Live Chat] end_speaking 缺少 stamp")
return
audio_path, assemble_duration = await session.end_speaking(stamp)
if not audio_path:
await send_json({"t": "error", "data": "音频组装失败"})
return
await self.process_audio(session, audio_path, assemble_duration, send_json)
return
if msg_type == "interrupt":
session.should_interrupt = True
logger.info(f"[Live Chat] 用户打断: {session.username}")
async def process_audio(
self,
session: LiveChatSession,
audio_path: str,
assemble_duration: float,
send_json: SendJson,
) -> None:
try:
await send_json(
{"t": "metrics", "data": {"wav_assemble_time": assemble_duration}}
)
wav_assembly_finish_time = time.time()
session.is_processing = True
session.should_interrupt = False
ctx = self.plugin_manager.context
stt_provider = ctx.provider_manager.stt_provider_insts[0]
if not stt_provider:
logger.error("[Live Chat] STT Provider 未配置")
await send_json({"t": "error", "data": "语音识别服务未配置"})
return
await send_json({"t": "metrics", "data": {"stt": stt_provider.meta().type}})
user_text = await stt_provider.get_text(audio_path)
if not user_text:
logger.warning("[Live Chat] STT 识别结果为空")
return
logger.info(f"[Live Chat] STT 结果: {user_text}")
await send_json(
{
"t": "user_msg",
"data": {"text": user_text, "ts": int(time.time() * 1000)},
}
)
conversation_id = session.conversation_id
queue = webchat_queue_mgr.get_or_create_queue(conversation_id)
message_id = str(uuid.uuid4())
payload = {
"message_id": message_id,
"message": [{"type": "plain", "text": user_text}],
"action_type": "live",
}
await queue.put((session.username, conversation_id, payload))
back_queue = webchat_queue_mgr.get_or_create_back_queue(
message_id, conversation_id
)
bot_text = ""
audio_playing = False
try:
while True:
if session.should_interrupt:
logger.info("[Live Chat] 检测到用户打断")
await send_json({"t": "stop_play"})
await self.save_interrupted_message(
session, user_text, bot_text
)
while not back_queue.empty():
try:
back_queue.get_nowait()
except asyncio.QueueEmpty:
break
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
except asyncio.TimeoutError:
continue
if not result:
continue
result_message_id = result.get("message_id")
if result_message_id != message_id:
logger.warning(
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
)
continue
result_type = result.get("type")
result_chain_type = result.get("chain_type")
data = result.get("data", "")
if result_chain_type == "agent_stats":
try:
stats = json.loads(data)
await send_json(
{
"t": "metrics",
"data": {
"llm_ttft": stats.get("time_to_first_token", 0),
"llm_total_time": stats.get("end_time", 0)
- stats.get("start_time", 0),
},
}
)
except Exception as exc:
logger.error(f"[Live Chat] 解析 AgentStats 失败: {exc}")
continue
if result_chain_type == "tts_stats":
try:
stats = json.loads(data)
await send_json(
{
"t": "metrics",
"data": stats,
}
)
except Exception as exc:
logger.error(f"[Live Chat] 解析 TTSStats 失败: {exc}")
continue
if result_type == "plain":
bot_text += data
elif result_type == "audio_chunk":
if not audio_playing:
audio_playing = True
logger.debug("[Live Chat] 开始播放音频流")
speak_to_first_frame_latency = (
time.time() - wav_assembly_finish_time
)
await send_json(
{
"t": "metrics",
"data": {
"speak_to_first_frame": speak_to_first_frame_latency
},
}
)
text = result.get("text")
if text:
await send_json(
{
"t": "bot_text_chunk",
"data": {"text": text},
}
)
await send_json(
{
"t": "response",
"data": data,
}
)
elif result_type in ["complete", "end"]:
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
if not audio_playing:
await send_json(
{
"t": "bot_msg",
"data": {
"text": bot_text,
"ts": int(time.time() * 1000),
},
}
)
await send_json({"t": "end"})
wav_to_tts_duration = time.time() - wav_assembly_finish_time
await send_json(
{
"t": "metrics",
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
}
)
break
finally:
webchat_queue_mgr.remove_back_queue(message_id)
except Exception as exc:
logger.error(f"[Live Chat] 处理音频失败: {exc}", exc_info=True)
await send_json({"t": "error", "data": f"处理失败: {str(exc)}"})
finally:
session.is_processing = False
session.should_interrupt = False
@staticmethod
async def save_interrupted_message(
session: LiveChatSession, user_text: str, bot_text: str
) -> None:
interrupted_text = bot_text + " [用户打断]"
logger.info(f"[Live Chat] 保存打断消息: {interrupted_text}")
try:
timestamp = int(time.time() * 1000)
logger.info(
f"[Live Chat] 用户消息: {user_text} (session: {session.session_id}, ts: {timestamp})"
)
if bot_text:
logger.info(
f"[Live Chat] Bot 消息(打断): {interrupted_text} (session: {session.session_id}, ts: {timestamp})"
)
except Exception as exc:
logger.error(f"[Live Chat] 记录消息失败: {exc}", exc_info=True)
__all__ = ["LiveChatAuthError", "LiveChatService", "LiveChatSession"]

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
import asyncio
import json
import time
from collections.abc import AsyncGenerator
from astrbot.core import LogBroker, logger
from astrbot.core.config.astrbot_config import AstrBotConfig
class LogServiceError(Exception):
pass
class LogService:
def __init__(self, log_broker: LogBroker, config: AstrBotConfig) -> None:
self.log_broker = log_broker
self.config = config
@staticmethod
def format_log_sse(log: dict, ts: float) -> str:
payload = {
"type": "log",
**log,
}
return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
async def replay_cached_logs(self, last_event_id: str) -> AsyncGenerator[str, None]:
try:
last_ts = float(last_event_id)
cached_logs = list(self.log_broker.log_cache)
for log_item in cached_logs:
log_ts = float(log_item.get("time", 0))
if log_ts > last_ts:
yield self.format_log_sse(log_item, log_ts)
except ValueError:
pass
except Exception as exc:
logger.error(f"Log SSE 补发历史错误: {exc}")
async def stream_log_events(
self, last_event_id: str | None
) -> AsyncGenerator[str, None]:
queue = None
try:
if last_event_id:
async for event in self.replay_cached_logs(last_event_id):
yield event
queue = self.log_broker.register()
while True:
message = await queue.get()
current_ts = message.get("time", time.time())
yield self.format_log_sse(message, current_ts)
except asyncio.CancelledError:
pass
except Exception as exc:
logger.error(f"Log SSE 连接错误: {exc}")
finally:
if queue:
self.log_broker.unregister(queue)
def get_log_history(self) -> dict:
try:
return {"logs": list(self.log_broker.log_cache)}
except Exception as exc:
logger.error(f"获取日志历史失败: {exc}")
raise LogServiceError(f"获取日志历史失败: {exc}") from exc
def get_trace_settings(self) -> dict:
try:
return {"trace_enable": self.config.get("trace_enable", True)}
except Exception as exc:
logger.error(f"获取 Trace 设置失败: {exc}")
raise LogServiceError(f"获取 Trace 设置失败: {exc}") from exc
def update_trace_settings(self, payload: dict | None) -> str:
try:
if payload is None:
raise LogServiceError("请求数据为空")
trace_enable = payload.get("trace_enable")
if trace_enable is not None:
self.config["trace_enable"] = bool(trace_enable)
self.config.save_config()
return "Trace 设置已更新"
except LogServiceError:
raise
except Exception as exc:
logger.error(f"更新 Trace 设置失败: {exc}")
raise LogServiceError(f"更新 Trace 设置失败: {exc}") from exc
def update_trace_settings_from_legacy_payload(self, data: object) -> str:
return self.update_trace_settings(data if isinstance(data, dict) else None)

View File

@@ -0,0 +1,650 @@
from __future__ import annotations
import asyncio
import json
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_session import MessageSesion
from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_message_chain_from_payload,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.datetime_utils import to_utc_isoformat
from astrbot.dashboard.services.api_key_service import ApiKeyService
from astrbot.dashboard.services.auth_service import ALL_OPEN_API_SCOPES
from astrbot.dashboard.services.chat_service import (
BotMessageAccumulator,
collect_plain_text_from_message_parts,
)
SendJson = Callable[[dict], Awaitable[None]]
ReceiveJson = Callable[[], Awaitable[Any]]
CloseWebSocket = Callable[[int, str], Awaitable[None]]
class OpenApiServiceError(Exception):
pass
@dataclass
class OpenApiWebSocketChatBridge:
build_user_message_parts: Callable[[object], Awaitable[list]]
create_attachment_from_file: Callable[[str, str], Awaitable[Any]]
extract_web_search_refs: Callable[[str, list], dict]
insert_user_message: Callable[[str, str, list], Awaitable[None]]
save_bot_message: Callable[[str, list, dict, dict], Awaitable[Any]]
class OpenApiService:
def __init__(
self,
db: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
self.db = db
self.core_lifecycle = core_lifecycle
self.platform_manager = core_lifecycle.platform_manager
self.platform_history_mgr = getattr(
core_lifecycle,
"platform_message_history_manager",
None,
)
@staticmethod
def resolve_open_username(
raw_username: str | None,
) -> tuple[str | None, str | None]:
if raw_username is None:
return None, "Missing key: username"
username = str(raw_username).strip()
if not username:
return None, "username is empty"
return username, None
def get_chat_config_list(self) -> list[dict]:
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
result = []
for conf_info in conf_list:
conf_id = str(conf_info.get("id", "")).strip()
result.append(
{
"id": conf_id,
"name": str(conf_info.get("name", "")).strip(),
"path": str(conf_info.get("path", "")).strip(),
"is_default": conf_id == "default",
}
)
return result
@staticmethod
def resolve_chat_config_id(
post_data: dict,
conf_list: list[dict],
) -> tuple[str | None, str | None]:
raw_config_id = post_data.get("config_id")
raw_config_name = post_data.get("config_name")
config_id = str(raw_config_id).strip() if raw_config_id is not None else ""
config_name = (
str(raw_config_name).strip() if raw_config_name is not None else ""
)
if not config_id and not config_name:
return None, None
conf_map = {item["id"]: item for item in conf_list}
if config_id:
if config_id not in conf_map:
return None, f"config_id not found: {config_id}"
return config_id, None
if not config_name:
return None, "config_name is empty"
matched = [item for item in conf_list if item["name"] == config_name]
if not matched:
return None, f"config_name not found: {config_name}"
if len(matched) > 1:
return (
None,
f"config_name is ambiguous, please use config_id: {config_name}",
)
return matched[0]["id"], None
async def prepare_chat_send(
self,
post_data: dict,
conf_list: list[dict],
) -> tuple[str, str, str | None]:
effective_username, username_err = self.resolve_open_username(
post_data.get("username")
)
if username_err:
raise OpenApiServiceError(username_err)
if not effective_username:
raise OpenApiServiceError("Invalid username")
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
if not session_id:
session_id = str(uuid4())
post_data["session_id"] = session_id
ensure_session_err = await self.ensure_chat_session(
effective_username,
session_id,
)
if ensure_session_err:
raise OpenApiServiceError(ensure_session_err)
config_id, resolve_err = self.resolve_chat_config_id(post_data, conf_list)
if resolve_err:
raise OpenApiServiceError(resolve_err)
return effective_username, session_id, config_id
async def ensure_chat_session(
self,
username: str,
session_id: str,
) -> str | None:
session = await self.db.get_platform_session_by_id(session_id)
if session:
if session.creator != username:
return "session_id belongs to another username"
return None
try:
await self.db.create_platform_session(
creator=username,
platform_id="webchat",
session_id=session_id,
is_group=0,
)
except Exception as exc:
existing = await self.db.get_platform_session_by_id(session_id)
if existing and existing.creator == username:
return None
logger.error("Failed to create chat session %s: %s", session_id, exc)
return f"Failed to create session: {exc}"
return None
async def authenticate_api_key(
self, raw_key: str | None
) -> tuple[bool, str | None]:
if not raw_key:
return False, "Missing API key"
key_hash = ApiKeyService.hash_key(raw_key)
api_key = await self.db.get_active_api_key_by_hash(key_hash)
if not api_key:
return False, "Invalid API key"
if isinstance(api_key.scopes, list):
scopes = api_key.scopes
else:
scopes = list(ALL_OPEN_API_SCOPES)
if "*" not in scopes and "chat" not in scopes:
return False, "Insufficient API key scope"
await self.db.touch_api_key(api_key.key_id)
return True, None
@staticmethod
async def send_chat_ws_error(
send_json: SendJson,
message: str,
code: str,
) -> None:
await send_json(
{
"type": "error",
"code": code,
"data": message,
}
)
async def run_chat_websocket(
self,
*,
raw_api_key: str | None,
receive_json: ReceiveJson,
send_json: SendJson,
close: CloseWebSocket,
conf_list: list[dict],
chat_bridge: OpenApiWebSocketChatBridge,
) -> None:
authed, auth_err = await self.authenticate_api_key(raw_api_key)
if not authed:
message = auth_err or "Unauthorized"
await self.send_chat_ws_error(send_json, message, "UNAUTHORIZED")
await close(1008, message)
return
async def send_error(message: str, code: str) -> None:
await self.send_chat_ws_error(send_json, message, code)
try:
while True:
message = await receive_json()
if not isinstance(message, dict):
await send_error(
"message must be an object",
"INVALID_MESSAGE",
)
continue
msg_type = message.get("t", "send")
if msg_type == "ping":
await send_json({"type": "pong"})
continue
if msg_type != "send":
await send_error(
f"Unsupported message type: {msg_type}",
"INVALID_MESSAGE",
)
continue
await self.handle_chat_ws_send(
post_data=message,
conf_list=conf_list,
chat_bridge=chat_bridge,
send_json=send_json,
send_error=send_error,
)
except Exception as exc:
logger.debug("Open API WS connection closed: %s", exc)
async def update_session_config_route(
self,
*,
username: str,
session_id: str,
config_id: str | None,
) -> str | None:
if not config_id:
return None
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
try:
if config_id == "default":
await self.core_lifecycle.umop_config_router.delete_route(umo)
else:
await self.core_lifecycle.umop_config_router.update_route(
umo, config_id
)
except Exception as exc:
logger.error(
"Failed to update chat config route for %s with %s: %s",
umo,
config_id,
exc,
exc_info=True,
)
return f"Failed to update chat config route: {exc}"
return None
async def insert_webchat_user_message(
self,
*,
session_id: str,
effective_username: str,
message_parts: list,
) -> None:
if self.platform_history_mgr is None:
raise OpenApiServiceError("Platform message history manager is unavailable")
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts},
sender_id=effective_username,
sender_name=effective_username,
)
@staticmethod
def get_chat_send_error_code(message: str) -> str:
if message in ("Missing key: username", "username is empty"):
return "BAD_USER"
if message.startswith("config_"):
return "CONFIG_ERROR"
if "session" in message:
return "SESSION_ERROR"
return "INVALID_MESSAGE"
async def handle_chat_ws_send(
self,
*,
post_data: dict,
conf_list: list[dict],
chat_bridge: OpenApiWebSocketChatBridge,
send_json: SendJson,
send_error: Callable[[str, str], Awaitable[None]],
) -> None:
message = post_data.get("message")
if message is None:
await send_error("Missing key: message", "INVALID_MESSAGE")
return
try:
(
effective_username,
session_id,
config_id,
) = await self.prepare_chat_send(
post_data,
conf_list,
)
except OpenApiServiceError as exc:
message = str(exc)
await send_error(message, self.get_chat_send_error_code(message))
return
config_err = await self.update_session_config_route(
username=effective_username,
session_id=session_id,
config_id=config_id,
)
if config_err:
await send_error(config_err, "CONFIG_ERROR")
return
message_parts = await chat_bridge.build_user_message_parts(message)
if not webchat_message_parts_have_content(message_parts):
await send_error(
"Message content is empty (reply only is not allowed)",
"INVALID_MESSAGE",
)
return
message_id = str(post_data.get("message_id") or uuid4())
selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
try:
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
effective_username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"enable_streaming": enable_streaming,
"message_id": message_id,
},
)
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await chat_bridge.insert_user_message(
session_id,
effective_username,
message_parts_for_storage,
)
await send_json(
{
"type": "session_id",
"data": None,
"session_id": session_id,
"message_id": message_id,
}
)
message_accumulator = BotMessageAccumulator()
agent_stats = {}
refs = {}
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if "message_id" in result and result["message_id"] != message_id:
logger.warning("openapi ws stream message_id mismatch")
continue
result_text = result.get("data", "")
msg_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
stats_info = {
"type": "agent_stats",
"data": json.loads(result_text),
}
await send_json(stats_info)
agent_stats = stats_info["data"]
except Exception:
pass
continue
await send_json(result)
if msg_type == "plain":
message_accumulator.add_plain(
result_text,
chain_type=chain_type,
streaming=streaming,
)
elif msg_type in {"image", "record", "file", "video"}:
filename = str(result_text).replace(f"[{msg_type.upper()}]", "")
part = await chat_bridge.create_attachment_from_file(
filename,
msg_type,
)
message_accumulator.add_attachment(part)
should_save = False
if msg_type == "end":
should_save = bool(
message_accumulator.has_content() or refs or agent_stats
)
elif (streaming and msg_type == "complete") or not streaming:
if chain_type not in ("tool_call", "tool_call_result"):
should_save = True
if should_save:
message_parts_to_save = message_accumulator.build_message_parts(
include_pending_tool_calls=True
)
plain_text = collect_plain_text_from_message_parts(
message_parts_to_save
)
try:
refs = chat_bridge.extract_web_search_refs(
plain_text,
message_parts_to_save,
)
except Exception as exc:
logger.exception(
f"Open API WS failed to extract web search refs: {exc}",
exc_info=True,
)
saved_record = await chat_bridge.save_bot_message(
session_id,
message_parts_to_save,
agent_stats,
refs,
)
if saved_record:
await send_json(
{
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": to_utc_isoformat(
saved_record.created_at
),
},
"session_id": session_id,
}
)
message_accumulator = BotMessageAccumulator()
agent_stats = {}
refs = {}
if msg_type == "end":
break
except Exception as exc:
logger.exception(f"Open API WS chat failed: {exc}", exc_info=True)
await send_error(f"Failed to process message: {exc}", "PROCESSING_ERROR")
finally:
webchat_queue_mgr.remove_back_queue(message_id)
async def get_chat_sessions(
self,
*,
username: str,
page_raw,
page_size_raw,
platform_id: str | None,
) -> dict:
try:
page = int(page_raw)
page_size = int(page_size_raw)
except ValueError as exc:
raise OpenApiServiceError("page and page_size must be integers") from exc
if page < 1:
page = 1
if page_size < 1:
page_size = 1
if page_size > 100:
page_size = 100
(
paginated_sessions,
total,
) = await self.db.get_platform_sessions_by_creator_paginated(
creator=username,
platform_id=platform_id,
page=page,
page_size=page_size,
exclude_project_sessions=True,
)
sessions_data = []
for item in paginated_sessions:
session = item["session"]
sessions_data.append(
{
"session_id": session.session_id,
"platform_id": session.platform_id,
"creator": session.creator,
"display_name": session.display_name,
"is_group": session.is_group,
"created_at": to_utc_isoformat(session.created_at),
"updated_at": to_utc_isoformat(session.updated_at),
}
)
return {
"sessions": sessions_data,
"page": page,
"page_size": page_size,
"total": total,
}
async def get_chat_sessions_from_legacy_query(
self,
*,
username: str | None,
page,
page_size,
platform_id: str | None,
) -> dict:
resolved_username, username_err = self.resolve_open_username(username)
if username_err:
raise OpenApiServiceError(username_err)
if not resolved_username:
raise OpenApiServiceError("Invalid username")
return await self.get_chat_sessions(
username=resolved_username,
page_raw=page,
page_size_raw=page_size,
platform_id=platform_id,
)
def get_chat_configs(self) -> dict:
return {"configs": self.get_chat_config_list()}
async def build_message_chain_from_payload(self, message_payload: str | list):
return await build_message_chain_from_payload(
message_payload,
get_attachment_by_id=self.db.get_attachment_by_id,
strict=True,
)
async def send_message(self, post_data: object) -> None:
payload = post_data if isinstance(post_data, dict) else {}
message_payload = payload.get("message", {})
umo = payload.get("umo")
if message_payload is None:
raise OpenApiServiceError("Missing key: message")
if not umo:
raise OpenApiServiceError("Missing key: umo")
try:
session = MessageSesion.from_str(str(umo))
except Exception as exc:
raise OpenApiServiceError(f"Invalid umo: {exc}") from exc
platform_id = session.platform_name
platform_inst = next(
(
inst
for inst in self.platform_manager.platform_insts
if inst.meta().id == platform_id
),
None,
)
if not platform_inst:
raise OpenApiServiceError(
f"Bot not found or not running for platform: {platform_id}"
)
try:
message_chain = await self.build_message_chain_from_payload(message_payload)
await platform_inst.send_by_session(session, message_chain)
except OpenApiServiceError:
raise
except ValueError as exc:
raise OpenApiServiceError(str(exc)) from exc
except Exception as exc:
logger.error(f"Open API send_message failed: {exc}", exc_info=True)
raise OpenApiServiceError(f"Failed to send message: {exc}") from exc
def get_bots(self) -> dict:
bot_ids = []
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
platform_id = platform.get("id") if isinstance(platform, dict) else None
if (
isinstance(platform_id, str)
and platform_id
and platform_id not in bot_ids
):
bot_ids.append(platform_id)
return {"bot_ids": bot_ids}

View File

@@ -0,0 +1,293 @@
from __future__ import annotations
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.sentinels import NOT_GIVEN
class PersonaServiceError(Exception):
pass
class PersonaService:
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
self.persona_mgr = core_lifecycle.persona_mgr
async def list_personas(
self,
folder_id: str | None,
filter_by_folder: bool,
) -> list[dict]:
if filter_by_folder:
personas = await self.persona_mgr.get_personas_by_folder(
folder_id if folder_id else None
)
else:
personas = await self.persona_mgr.get_all_personas()
return [self.serialize_persona(persona) for persona in personas]
async def list_personas_from_legacy_query(
self,
*,
folder_id: str | None,
has_folder_id: bool,
) -> list[dict]:
return await self.list_personas(folder_id, has_folder_id)
async def get_persona_detail(self, data: object) -> dict:
payload = self._payload(data)
persona_id = payload.get("persona_id")
if not persona_id:
raise PersonaServiceError("缺少必要参数: persona_id")
persona = await self.persona_mgr.get_persona(persona_id)
if not persona:
raise PersonaServiceError("人格不存在")
return self.serialize_persona(persona)
async def create_persona(self, data: object) -> dict:
payload = self._payload(data)
persona_id = str(payload.get("persona_id", "")).strip()
system_prompt = str(payload.get("system_prompt", "")).strip()
begin_dialogs = payload.get("begin_dialogs", [])
tools = payload.get("tools")
skills = payload.get("skills")
custom_error_message = self._normalize_custom_error_message(
payload.get("custom_error_message")
)
folder_id = payload.get("folder_id")
sort_order = payload.get("sort_order", 0)
if not persona_id:
raise PersonaServiceError("人格ID不能为空")
if not system_prompt:
raise PersonaServiceError("系统提示词不能为空")
self._validate_begin_dialogs(begin_dialogs)
persona = await self.persona_mgr.create_persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs if begin_dialogs else None,
tools=tools if tools else None,
skills=skills if skills else None,
custom_error_message=custom_error_message,
folder_id=folder_id,
sort_order=sort_order,
)
return {
"message": "人格创建成功",
"persona": self.serialize_persona(persona, empty_lists_for_tools=True),
}
async def update_persona(self, data: object) -> dict:
payload = self._payload(data)
persona_id = payload.get("persona_id")
system_prompt = payload.get("system_prompt")
begin_dialogs = payload.get("begin_dialogs")
has_tools = "tools" in payload
tools = payload.get("tools")
has_skills = "skills" in payload
skills = payload.get("skills")
has_custom_error_message = "custom_error_message" in payload
custom_error_message = payload.get("custom_error_message")
if not persona_id:
raise PersonaServiceError("缺少必要参数: persona_id")
if has_custom_error_message:
custom_error_message = self._normalize_custom_error_message(
custom_error_message
)
if begin_dialogs is not None:
self._validate_begin_dialogs(begin_dialogs)
update_kwargs = {
"persona_id": persona_id,
"system_prompt": system_prompt,
"begin_dialogs": begin_dialogs,
}
if has_tools:
update_kwargs["tools"] = tools
if has_skills:
update_kwargs["skills"] = skills
if has_custom_error_message:
update_kwargs["custom_error_message"] = custom_error_message
await self.persona_mgr.update_persona(**update_kwargs)
return {"message": "人格更新成功"}
async def delete_persona(self, data: object) -> dict:
payload = self._payload(data)
persona_id = payload.get("persona_id")
if not persona_id:
raise PersonaServiceError("缺少必要参数: persona_id")
await self.persona_mgr.delete_persona(persona_id)
return {"message": "人格删除成功"}
async def move_persona(self, data: object) -> dict:
payload = self._payload(data)
persona_id = payload.get("persona_id")
folder_id = payload.get("folder_id")
if not persona_id:
raise PersonaServiceError("缺少必要参数: persona_id")
await self.persona_mgr.move_persona_to_folder(persona_id, folder_id)
return {"message": "人格移动成功"}
async def list_folders(self, parent_id: str | None) -> list[dict]:
folders = await self.persona_mgr.get_folders(parent_id)
return [self.serialize_folder(folder) for folder in folders]
async def list_folders_from_legacy_query(
self,
parent_id: str | None,
) -> list[dict]:
if parent_id == "":
parent_id = None
return await self.list_folders(parent_id)
async def get_folder_tree(self):
return await self.persona_mgr.get_folder_tree()
async def get_folder_detail(self, data: object) -> dict:
payload = self._payload(data)
folder_id = payload.get("folder_id")
if not folder_id:
raise PersonaServiceError("缺少必要参数: folder_id")
folder = await self.persona_mgr.get_folder(folder_id)
if not folder:
raise PersonaServiceError("文件夹不存在")
return self.serialize_folder(folder)
async def create_folder(self, data: object) -> dict:
payload = self._payload(data)
name = str(payload.get("name", "")).strip()
parent_id = payload.get("parent_id")
description = payload.get("description")
sort_order = payload.get("sort_order", 0)
if not name:
raise PersonaServiceError("文件夹名称不能为空")
folder = await self.persona_mgr.create_folder(
name=name,
parent_id=parent_id,
description=description,
sort_order=sort_order,
)
return {
"message": "文件夹创建成功",
"folder": self.serialize_folder(folder),
}
async def update_folder(self, data: object) -> dict:
payload = self._payload(data)
folder_id = payload.get("folder_id")
name = payload.get("name")
parent_id = payload.get("parent_id") if "parent_id" in payload else NOT_GIVEN
description = (
payload.get("description") if "description" in payload else NOT_GIVEN
)
sort_order = payload.get("sort_order")
if not folder_id:
raise PersonaServiceError("缺少必要参数: folder_id")
await self.persona_mgr.update_folder(
folder_id=folder_id,
name=name,
parent_id=parent_id,
description=description,
sort_order=sort_order,
)
return {"message": "文件夹更新成功"}
async def delete_folder(self, data: object) -> dict:
payload = self._payload(data)
folder_id = payload.get("folder_id")
if not folder_id:
raise PersonaServiceError("缺少必要参数: folder_id")
await self.persona_mgr.delete_folder(folder_id)
return {"message": "文件夹删除成功"}
async def reorder_items(self, data: object) -> dict:
payload = self._payload(data)
items = payload.get("items", [])
if not items:
raise PersonaServiceError("items 不能为空")
for item in items:
if not all(key in item for key in ("id", "type", "sort_order")):
raise PersonaServiceError(
"每个 item 必须包含 id, type, sort_order 字段"
)
if item["type"] not in ("persona", "folder"):
raise PersonaServiceError("type 字段必须是 'persona''folder'")
await self.persona_mgr.batch_update_sort_order(items)
return {"message": "排序更新成功"}
@staticmethod
def serialize_persona(persona, empty_lists_for_tools: bool = False) -> dict:
return {
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": (persona.tools or []) if empty_lists_for_tools else persona.tools,
"skills": (persona.skills or [])
if empty_lists_for_tools
else persona.skills,
"custom_error_message": persona.custom_error_message,
"folder_id": persona.folder_id,
"sort_order": persona.sort_order,
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
}
@staticmethod
def serialize_folder(folder) -> dict:
return {
"folder_id": folder.folder_id,
"name": folder.name,
"parent_id": folder.parent_id,
"description": folder.description,
"sort_order": folder.sort_order,
"created_at": folder.created_at.isoformat() if folder.created_at else None,
"updated_at": folder.updated_at.isoformat() if folder.updated_at else None,
}
@staticmethod
def _normalize_custom_error_message(value):
if value is not None:
if not isinstance(value, str):
raise PersonaServiceError("自定义报错回复信息必须是字符串")
return value.strip() or None
return None
@staticmethod
def _validate_begin_dialogs(begin_dialogs) -> None:
if begin_dialogs and len(begin_dialogs) % 2 != 0:
raise PersonaServiceError("预设对话数量必须为偶数(用户和助手轮流对话)")
@staticmethod
def _payload(data: object) -> dict:
return data if isinstance(data, dict) else {}

View File

@@ -0,0 +1,218 @@
from __future__ import annotations
import secrets
import string
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform import Platform
from astrbot.core.platform.sources.dingtalk.app_registration import (
poll_dingtalk_app_registration_once,
request_dingtalk_app_registration,
)
from astrbot.core.platform.sources.lark.app_registration import (
poll_app_registration_once,
request_app_registration,
)
from astrbot.core.platform.sources.lark.bot_info import request_lark_bot_info
from astrbot.core.platform.sources.weixin_oc.login_registration import (
poll_weixin_oc_login_once,
request_weixin_oc_login_qr,
)
class PlatformServiceError(Exception):
def __init__(self, message: str, status_code: int = 500) -> None:
super().__init__(message)
self.status_code = status_code
def random_platform_id_suffix() -> str:
return "_" + "".join(secrets.choice(string.ascii_lowercase) for _ in range(4))
class PlatformService:
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
self.platform_manager = core_lifecycle.platform_manager
async def handle_webhook_callback(self, webhook_uuid: str, request_obj):
platform_adapter = self.find_platform_by_uuid(webhook_uuid)
if not platform_adapter:
logger.warning(f"未找到 webhook_uuid 为 {webhook_uuid} 的平台")
raise PlatformServiceError("未找到对应平台", 404)
try:
return await platform_adapter.webhook_callback(request_obj)
except NotImplementedError as exc:
logger.error(
f"平台 {platform_adapter.meta().name} 未实现 webhook_callback 方法"
)
raise PlatformServiceError("平台未支持统一 Webhook 模式", 500) from exc
except Exception as exc:
logger.error(f"处理 webhook 回调时发生错误: {exc}", exc_info=True)
raise PlatformServiceError("处理回调失败", 500) from exc
def find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None:
for platform in self.platform_manager.platform_insts:
if platform.config.get("webhook_uuid") == webhook_uuid:
if platform.unified_webhook():
return platform
return None
def get_platform_stats(self):
try:
return self.platform_manager.get_all_stats()
except Exception as exc:
logger.error(f"获取平台统计信息失败: {exc}", exc_info=True)
raise PlatformServiceError(f"获取统计信息失败: {exc}", 500) from exc
async def handle_platform_registration(
self,
platform_type: str,
payload: dict,
) -> dict:
try:
action = str(payload.get("action", "")).strip().lower()
if not action:
raise PlatformServiceError("Missing action", 400)
platform_config = payload.get("platform_config")
if not isinstance(platform_config, dict):
platform_config = {}
if platform_type == "lark":
return await self._handle_lark_registration(
action,
payload,
platform_config,
)
if platform_type == "weixin_oc":
return await self._handle_weixin_oc_registration(
action,
payload,
platform_config,
)
if platform_type == "dingtalk":
return await self._handle_dingtalk_registration(action, payload)
raise PlatformServiceError(
f"Unsupported platform registration: {platform_type}",
404,
)
except PlatformServiceError:
raise
except Exception as exc:
logger.error(f"处理平台一键创建请求失败: {exc}", exc_info=True)
raise PlatformServiceError(str(exc), 500) from exc
async def _handle_lark_registration(
self,
action: str,
payload: dict,
platform_config: dict,
) -> dict:
domain = str(platform_config.get("domain") or "").strip()
if action == "start":
registration = await request_app_registration(domain)
return {
"status": "pending",
"device_code": registration.device_code,
"registration_code": registration.device_code,
"user_code": registration.user_code,
"verification_uri": registration.verification_uri,
"verification_uri_complete": registration.verification_uri_complete,
"expires_in": registration.expires_in,
"interval": registration.interval,
}
if action == "poll":
device_code = str(
payload.get("device_code") or payload.get("registration_code") or ""
).strip()
if not device_code:
raise PlatformServiceError("Missing device_code", 400)
result = await poll_app_registration_once(
domain=domain,
device_code=device_code,
)
if result.get("status") == "created":
try:
bot_info = await request_lark_bot_info(
domain=str(result.get("domain") or domain),
app_id=str(result.get("app_id") or ""),
app_secret=str(result.get("app_secret") or ""),
)
if bot_info.app_name:
result["bot_name"] = bot_info.app_name
if bot_info.open_id:
result["bot_open_id"] = bot_info.open_id
except Exception as exc:
logger.error(f"获取飞书机器人信息失败: {exc}", exc_info=True)
return result
raise PlatformServiceError(f"Unsupported action: {action}", 400)
async def _handle_dingtalk_registration(
self,
action: str,
payload: dict,
) -> dict:
if action == "start":
registration = await request_dingtalk_app_registration()
return {
"status": "pending",
"device_code": registration.device_code,
"registration_code": registration.device_code,
"user_code": registration.user_code,
"verification_uri": registration.verification_uri,
"verification_uri_complete": registration.verification_uri_complete,
"expires_in": registration.expires_in,
"interval": registration.interval,
}
if action == "poll":
device_code = str(
payload.get("device_code") or payload.get("registration_code") or ""
).strip()
if not device_code:
raise PlatformServiceError("Missing device_code", 400)
result = await poll_dingtalk_app_registration_once(device_code)
if result.get("status") == "created":
result["platform_id_suffix"] = random_platform_id_suffix()
return result
raise PlatformServiceError(f"Unsupported action: {action}", 400)
async def _handle_weixin_oc_registration(
self,
action: str,
payload: dict,
platform_config: dict,
) -> dict:
if action == "start":
registration = await request_weixin_oc_login_qr(platform_config)
return {
"status": "pending",
"registration_code": registration.qrcode,
"qrcode": registration.qrcode,
"qrcode_img_content": registration.qrcode_img_content,
"interval": registration.interval,
}
if action == "poll":
qrcode = str(
payload.get("qrcode") or payload.get("registration_code") or ""
).strip()
if not qrcode:
raise PlatformServiceError("Missing qrcode", 400)
result = await poll_weixin_oc_login_once(
platform_config=platform_config,
qrcode=qrcode,
)
if result.get("status") == "created":
result["platform_id_suffix"] = random_platform_id_suffix()
return result
raise PlatformServiceError(f"Unsupported action: {action}", 400)

View File

@@ -0,0 +1,923 @@
from __future__ import annotations
import json
import mimetypes
import os
import posixpath
import re
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import cast
from urllib.parse import parse_qsl, quote, urlencode, urlsplit, urlunsplit
import aiofiles
import jwt
from aiofiles import ospath as aio_ospath
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star.star import StarMetadata
from astrbot.core.star.star_manager import PluginManager
PLUGIN_PAGE_ASSET_TOKEN_TYPE = "plugin_page_asset"
PLUGIN_PAGE_ASSET_TOKEN_TTL_SECONDS = 60
PLUGIN_PAGE_ROOT_DIR_NAME = "pages"
PLUGIN_PAGE_ENTRY_FILE_NAME = "index.html"
PLUGIN_PAGE_BRIDGE_FILE = (
Path(__file__).resolve().parent.parent / "plugin_page_bridge.js"
)
_HTML_ASSET_ATTR_RE = re.compile(
r"(?P<attr>src|href)=(?P<quote>[\"\'])(?P<url>.*?)(?P=quote)",
re.IGNORECASE,
)
_CSS_URL_RE = re.compile(
r"url\(\s*(?P<quote>[\"\']?)(?P<url>.*?)(?P=quote)\s*\)",
re.IGNORECASE,
)
_JS_DYNAMIC_IMPORT_RE = re.compile(
r"(?P<prefix>\bimport\s*\(\s*)(?P<quote>[\"\'])(?P<url>.*?)(?P=quote)(?P<suffix>\s*\))",
re.IGNORECASE,
)
_JS_MODULE_FROM_RE = re.compile(
r"(?P<prefix>\b(?:import|export)\s+(?:[^;]*?\s+from\s+))(?P<quote>[\"\'])(?P<url>.*?)(?P=quote)",
re.IGNORECASE | re.DOTALL,
)
_JS_SIDE_EFFECT_IMPORT_RE = re.compile(
r"(?P<prefix>\bimport\s+)(?P<quote>[\"\'])(?P<url>[^\"'\r\n]+)(?P=quote)",
re.IGNORECASE,
)
@dataclass
class PluginPage:
name: str
title: str
entry_file: str = PLUGIN_PAGE_ENTRY_FILE_NAME
@dataclass
class PluginPageContentPayload:
content: str | bytes
content_type: str
class PluginPageServiceError(Exception):
def __init__(self, message: str, status_code: int = 400) -> None:
super().__init__(message)
self.status_code = status_code
class PluginPageService:
def __init__(
self,
plugin_manager: PluginManager,
core_lifecycle: AstrBotCoreLifecycle | None = None,
config: AstrBotConfig | None = None,
) -> None:
self.plugin_manager = plugin_manager
self.config = config or (
core_lifecycle.astrbot_config if core_lifecycle is not None else None
)
self.bridge_file = PLUGIN_PAGE_BRIDGE_FILE
def _jwt_secret(self) -> str | None:
if self.config is None:
return None
return self.config.get("dashboard", {}).get("jwt_secret")
def get_plugin_metadata_by_name(self, plugin_name: str) -> StarMetadata | None:
for plugin in self.plugin_manager.context.get_all_stars():
if plugin.name == plugin_name:
return plugin
return None
@staticmethod
def get_by_path(source: dict | None, key: str):
if not isinstance(source, dict) or not key:
return None
current = source
for part in key.split("."):
if not isinstance(current, dict) or part not in current:
return None
current = current[part]
return current
@staticmethod
def apply_theme_to_html(html: str, theme: str) -> str:
def _replace_html_tag(m: re.Match) -> str:
attrs = m.group(1) or ""
attrs = re.sub(
r'\s+data-theme\s*=\s*["\'][^"\']*["\']',
"",
attrs,
flags=re.IGNORECASE,
)
return f'<html{attrs} data-theme="{theme}">'
html = re.sub(
r"<html(\b[^>]*)>",
_replace_html_tag,
html,
count=1,
flags=re.IGNORECASE,
)
meta_tag = f'<meta name="color-scheme" content="{theme}">'
html = re.sub(
r'<meta\s[^>]*name\s*=\s*["\']color-scheme["\'][^>]*>',
"",
html,
flags=re.IGNORECASE,
)
head_match = re.search(r"<head\b[^>]*>", html, re.IGNORECASE)
if head_match:
html = html.replace(
head_match.group(0), f"{head_match.group(0)}{meta_tag}", 1
)
else:
html = re.sub(
r"(<html\b[^>]*>)",
rf"\1<head>{meta_tag}</head>",
html,
count=1,
flags=re.IGNORECASE,
)
return html
def build_initial_context(
self,
*,
asset_token: str,
jwt_secret: str | None = None,
locale: str,
theme: str | None,
) -> dict | None:
if not asset_token:
return None
jwt_secret = jwt_secret or self._jwt_secret()
if not isinstance(jwt_secret, str) or not jwt_secret.strip():
return None
try:
payload = jwt.decode(asset_token, jwt_secret, algorithms=["HS256"])
except jwt.InvalidTokenError:
return None
if payload.get("token_type") != PLUGIN_PAGE_ASSET_TOKEN_TYPE:
return None
plugin_name = payload.get("plugin_name")
page_name = payload.get("page_name")
if not isinstance(plugin_name, str) or not isinstance(page_name, str):
return None
plugin = self.get_plugin_metadata_by_name(plugin_name)
if not plugin:
return None
resolved_locale = (
payload.get("locale") if isinstance(payload.get("locale"), str) else locale
)
plugin_i18n = plugin.i18n or {}
try:
plugin_root = self.get_plugin_root_dir(plugin)
fresh_i18n = PluginManager._load_plugin_i18n(str(plugin_root))
if fresh_i18n:
plugin_i18n = fresh_i18n
except (OSError, ValueError):
pass
locale_data = plugin_i18n.get(resolved_locale)
display_name = (
self.get_by_path(locale_data, "metadata.display_name")
or plugin.display_name
or plugin.name
)
page_title = (
self.get_by_path(locale_data, f"pages.{page_name}.title") or page_name
)
return {
"pluginName": plugin.name,
"displayName": display_name,
"pageName": page_name,
"pageTitle": page_title,
"locale": resolved_locale,
"i18n": plugin_i18n,
"isDark": theme == "dark",
}
async def get_plugin_page_entry_config(
self,
*,
plugin_name: str | None,
page_name: str | None,
jwt_secret: str | None = None,
username: str | None,
locale: str,
) -> dict:
if not plugin_name:
raise PluginPageServiceError("缺少插件名")
if not page_name:
raise PluginPageServiceError("缺少 Page 名称")
plugin = self.get_plugin_metadata_by_name(plugin_name)
if not plugin:
raise PluginPageServiceError("插件不存在")
if not plugin.activated:
raise PluginPageServiceError("插件未启用")
page = await self.serialize_plugin_page_for_request(
plugin,
page_name,
include_content_path=True,
jwt_secret=jwt_secret,
username=username,
locale=locale,
)
if not page:
raise PluginPageServiceError("插件 Page 不存在")
return page
async def serialize_plugin_page_for_request(
self,
plugin: StarMetadata,
page_name: str,
*,
include_content_path: bool = False,
jwt_secret: str | None = None,
username: str | None,
locale: str,
) -> dict | None:
asset_token = ""
if include_content_path:
plugin_name = plugin.name.strip() if isinstance(plugin.name, str) else ""
asset_token = (
self.issue_plugin_page_asset_token(
plugin_name=plugin_name,
page_name=page_name,
jwt_secret=jwt_secret or self._jwt_secret(),
username=username,
locale=locale,
)
or ""
)
return await self.serialize_plugin_page(
plugin,
page_name,
include_content_path=include_content_path,
asset_token=asset_token,
)
def prepare_plugin_page_query_params(
self,
plugin_name: str,
page_name: str,
*,
asset_token: str,
jwt_secret: str | None = None,
username: str | None,
locale: str,
theme: str | None,
) -> dict[str, str] | None:
if not asset_token:
asset_token = (
self.issue_plugin_page_asset_token(
plugin_name=plugin_name,
page_name=page_name,
jwt_secret=jwt_secret or self._jwt_secret(),
username=username,
locale=locale,
)
or ""
)
if not asset_token and not theme:
return None
params: dict[str, str] = {}
if asset_token:
params["asset_token"] = asset_token
if theme:
params["theme"] = theme
return params
async def serve_bridge_sdk(
self,
*,
asset_token: str,
jwt_secret: str | None = None,
locale: str,
theme: str | None,
) -> PluginPageContentPayload:
if not self.bridge_file.is_file():
raise PluginPageServiceError(
"Plugin Page bridge SDK not found",
status_code=404,
)
bridge_js = await self.read_plugin_page_text(self.bridge_file)
initial_context = self.build_initial_context(
asset_token=asset_token,
jwt_secret=jwt_secret,
locale=locale,
theme=theme,
)
if initial_context:
context_json = json.dumps(initial_context, ensure_ascii=False)
bridge_js += (
f"\n;window.AstrBotPluginPage?.__setInitialContext({context_json});\n"
)
return PluginPageContentPayload(
content=bridge_js,
content_type="application/javascript; charset=utf-8",
)
async def serve_page_content(
self,
*,
plugin_name: str,
page_name: str,
asset_path: str,
asset_token: str,
jwt_secret: str | None = None,
username: str | None,
locale: str,
theme: str | None,
) -> PluginPageContentPayload:
plugin = self.get_plugin_metadata_by_name(plugin_name)
if not plugin:
raise PluginPageServiceError("Plugin not found", status_code=404)
if not plugin.activated:
raise PluginPageServiceError("Plugin is disabled", status_code=403)
try:
page = await self.get_plugin_page(plugin, page_name)
file_path = await self.resolve_plugin_page_file(
plugin,
page.name,
asset_path,
)
except (FileNotFoundError, ValueError) as exc:
raise PluginPageServiceError(
"Plugin Page asset not found",
status_code=404,
) from exc
extra_query_params = self.prepare_plugin_page_query_params(
plugin_name,
page.name,
asset_token=asset_token,
jwt_secret=jwt_secret,
username=username,
locale=locale,
theme=theme,
)
served_asset_path = asset_path or page.entry_file
suffix = file_path.suffix.lower()
if suffix == ".html":
html_text = await self.read_plugin_page_text(file_path)
return PluginPageContentPayload(
content=self.rewrite_plugin_page_html(
html_text,
plugin_name,
page.name,
served_asset_path,
theme=theme,
extra_query_params=extra_query_params,
),
content_type="text/html; charset=utf-8",
)
if suffix == ".css":
css_text = await self.read_plugin_page_text(file_path)
return PluginPageContentPayload(
content=self.rewrite_plugin_page_css(
css_text,
plugin_name,
page.name,
served_asset_path,
extra_query_params=extra_query_params,
),
content_type="text/css; charset=utf-8",
)
if suffix in {".js", ".mjs"}:
js_text = await self.read_plugin_page_text(file_path)
return PluginPageContentPayload(
content=self.rewrite_plugin_page_js(
js_text,
plugin_name,
page.name,
served_asset_path,
extra_query_params=extra_query_params,
),
content_type="application/javascript; charset=utf-8",
)
return PluginPageContentPayload(
content=await self.read_plugin_page_binary(file_path),
content_type=self.guess_plugin_page_mime_type(file_path),
)
@staticmethod
def build_security_headers() -> dict[str, str]:
headers = {
"Cache-Control": "no-store",
"Referrer-Policy": "no-referrer",
"X-Content-Type-Options": "nosniff",
"Cross-Origin-Resource-Policy": "cross-origin",
"Access-Control-Allow-Origin": "*",
}
csp = "object-src 'none'; base-uri 'self'"
if os.environ.get("ASTRBOT_LAUNCHER") not in ("1", "true"):
headers["X-Frame-Options"] = "SAMEORIGIN"
csp = f"frame-ancestors 'self'; {csp}"
headers["Content-Security-Policy"] = csp
return headers
@staticmethod
def normalize_plugin_page_path(
raw_path: str,
*,
base_dir: str | None = None,
allow_empty: bool = False,
) -> str:
path = raw_path.replace("\\", "/").strip()
if base_dir:
path = posixpath.join(base_dir, path)
normalized = posixpath.normpath(path)
if normalized in {"", "."}:
if allow_empty:
return ""
raise ValueError("Invalid plugin Page asset path")
if (
normalized.startswith("../")
or normalized == ".."
or normalized.startswith("/")
):
raise ValueError("Invalid plugin Page asset path")
return normalized
@staticmethod
def normalize_plugin_page_name(raw_name: str) -> str:
page_name = raw_name.strip()
if not page_name:
raise ValueError("Invalid plugin Page name")
normalized = posixpath.normpath(page_name.replace("\\", "/"))
if (
normalized != page_name
or normalized in {".", ".."}
or normalized.startswith(".")
or "/" in page_name
or "\\" in page_name
):
raise ValueError("Invalid plugin Page name")
return page_name
def get_plugin_root_dir(self, plugin: StarMetadata) -> Path:
if not plugin.root_dir_name:
raise FileNotFoundError("Plugin directory metadata is missing")
base_dir = Path(
self.plugin_manager.reserved_plugin_path
if plugin.reserved
else self.plugin_manager.plugin_store_path
).resolve(strict=False)
plugin_root = (base_dir / plugin.root_dir_name).resolve(strict=False)
plugin_root.relative_to(base_dir)
return plugin_root
async def resolve_plugin_pages_root(self, plugin: StarMetadata) -> Path:
plugin_root = self.get_plugin_root_dir(plugin)
pages_root = (plugin_root / PLUGIN_PAGE_ROOT_DIR_NAME).resolve(strict=False)
pages_root.relative_to(plugin_root)
if pages_root == plugin_root:
raise FileNotFoundError("Plugin Pages root directory is invalid")
if not await aio_ospath.isdir(str(pages_root)):
raise FileNotFoundError("Plugin Pages root directory does not exist")
return pages_root
async def discover_plugin_pages(self, plugin: StarMetadata) -> list[PluginPage]:
try:
pages_root = await self.resolve_plugin_pages_root(plugin)
except (FileNotFoundError, ValueError):
return []
pages: list[PluginPage] = []
try:
page_dirs = sorted(
(item for item in pages_root.iterdir() if item.is_dir()),
key=lambda item: item.name.lower(),
)
except OSError:
return []
for page_dir in page_dirs:
try:
page_name = self.normalize_plugin_page_name(page_dir.name)
except ValueError:
continue
entry_path = page_dir / PLUGIN_PAGE_ENTRY_FILE_NAME
if not await aio_ospath.isfile(str(entry_path)):
continue
pages.append(
PluginPage(
name=page_name,
title=page_name,
entry_file=PLUGIN_PAGE_ENTRY_FILE_NAME,
)
)
return pages
async def get_plugin_page(
self,
plugin: StarMetadata,
page_name: str,
) -> PluginPage:
normalized_name = self.normalize_plugin_page_name(page_name)
for page in await self.discover_plugin_pages(plugin):
if page.name == normalized_name:
return page
raise FileNotFoundError("Plugin Page entry not found")
async def resolve_plugin_page_root(
self,
plugin: StarMetadata,
page_name: str,
) -> Path:
normalized_name = self.normalize_plugin_page_name(page_name)
pages_root = await self.resolve_plugin_pages_root(plugin)
page_root = (pages_root / normalized_name).resolve(strict=False)
page_root.relative_to(pages_root)
if not await aio_ospath.isdir(str(page_root)):
raise FileNotFoundError("Plugin Page root directory does not exist")
return page_root
async def resolve_plugin_page_file(
self,
plugin: StarMetadata,
page_name: str,
asset_path: str,
) -> Path:
page = await self.get_plugin_page(plugin, page_name)
page_root = await self.resolve_plugin_page_root(plugin, page.name)
target_name = (
self.normalize_plugin_page_path(asset_path, allow_empty=True)
or page.entry_file
)
target_path = (page_root / target_name).resolve(strict=False)
target_path.relative_to(page_root)
if not await aio_ospath.isfile(str(target_path)):
raise FileNotFoundError("Plugin Page asset not found")
return target_path
@staticmethod
def is_rewritable_asset_url(raw_url: str) -> bool:
value = raw_url.strip()
lower = value.lower()
if not value:
return False
if value.startswith(("#", "/#")):
return False
if lower.startswith(
(
"http://",
"https://",
"//",
"data:",
"javascript:",
"mailto:",
"tel:",
"blob:",
)
):
return False
return True
@staticmethod
def resolve_referenced_asset_path(
base_asset_path: str,
referenced_url: str,
) -> str:
parts = urlsplit(referenced_url)
referenced_path = parts.path.strip()
if not referenced_path:
raise ValueError("Plugin Page referenced asset path is empty")
base_dir = posixpath.dirname(base_asset_path) if base_asset_path else ""
normalized = PluginPageService.normalize_plugin_page_path(
referenced_path,
base_dir=base_dir,
)
if not normalized:
raise ValueError("Plugin Page referenced asset path is invalid")
return normalized
def build_plugin_page_asset_url(
self,
plugin_name: str,
page_name: str,
asset_path: str,
original_query: str = "",
original_fragment: str = "",
extra_query_params: dict[str, str] | None = None,
) -> str:
path = self.build_plugin_page_content_path(plugin_name, page_name, asset_path)
query_dict = dict(parse_qsl(original_query, keep_blank_values=True))
if extra_query_params:
for key, value in extra_query_params.items():
if value:
query_dict[key] = value
query = urlencode(query_dict)
return urlunsplit(("", "", path, query, original_fragment))
@staticmethod
def build_plugin_page_content_path(
plugin_name: str,
page_name: str,
asset_path: str = "",
) -> str:
encoded_plugin_name = quote(plugin_name, safe="")
encoded_page_name = quote(
PluginPageService.normalize_plugin_page_name(page_name),
safe="",
)
if not asset_path:
return (
f"/api/plugin/page/content/{encoded_plugin_name}/{encoded_page_name}/"
)
safe_asset_path = PluginPageService.normalize_plugin_page_path(
asset_path,
allow_empty=True,
)
encoded_path = "/".join(
quote(part, safe="") for part in safe_asset_path.split("/")
)
return (
f"/api/plugin/page/content/{encoded_plugin_name}/"
f"{encoded_page_name}/{encoded_path}"
)
@staticmethod
def get_plugin_page_bridge_sdk_url(
extra_query_params: dict[str, str] | None = None,
) -> str:
query = urlencode(extra_query_params or {})
return urlunsplit(("", "", "/api/plugin/page/bridge-sdk.js", query, ""))
@staticmethod
def is_js_relative_module_specifier(raw_url: str) -> bool:
value = raw_url.strip()
return value.startswith(("./", "../", "/"))
def rewrite_relative_asset_url(
self,
raw_url: str,
base_asset_path: str,
plugin_name: str,
page_name: str,
extra_query_params: dict[str, str] | None = None,
) -> str | None:
candidate = raw_url.strip()
if not self.is_rewritable_asset_url(candidate):
return None
parts = urlsplit(candidate)
asset_path = self.resolve_referenced_asset_path(base_asset_path, candidate)
return self.build_plugin_page_asset_url(
plugin_name,
page_name,
asset_path,
original_query=parts.query,
original_fragment=parts.fragment,
extra_query_params=extra_query_params,
)
def rewrite_plugin_page_html(
self,
html_text: str,
plugin_name: str,
page_name: str,
entry_asset_path: str,
*,
theme: str | None,
extra_query_params: dict[str, str] | None = None,
) -> str:
def replace_attr(match: re.Match[str]) -> str:
raw_url = match.group("url")
attr = match.group("attr")
quote_char = match.group("quote")
if raw_url.strip() == "/api/plugin/page/bridge-sdk.js":
url = self.get_plugin_page_bridge_sdk_url(extra_query_params)
return f"{attr}={quote_char}{url}{quote_char}"
if not self.is_rewritable_asset_url(raw_url):
return match.group(0)
try:
rewritten_url = self.rewrite_relative_asset_url(
raw_url,
entry_asset_path,
plugin_name,
page_name,
extra_query_params=extra_query_params,
)
if not rewritten_url:
return match.group(0)
return f"{attr}={quote_char}{rewritten_url}{quote_char}"
except ValueError:
return match.group(0)
rewritten_html = _HTML_ASSET_ATTR_RE.sub(replace_attr, html_text)
if theme:
rewritten_html = self.apply_theme_to_html(rewritten_html, theme)
if "/api/plugin/page/bridge-sdk.js" not in rewritten_html:
bridge_tag = f'<script src="{self.get_plugin_page_bridge_sdk_url(extra_query_params)}"></script>'
if "</body>" in rewritten_html:
rewritten_html = rewritten_html.replace(
"</body>", f"{bridge_tag}</body>", 1
)
else:
rewritten_html += bridge_tag
return rewritten_html
def rewrite_plugin_page_css(
self,
css_text: str,
plugin_name: str,
page_name: str,
css_asset_path: str,
extra_query_params: dict[str, str] | None = None,
) -> str:
def replace_url(match: re.Match[str]) -> str:
raw_url = match.group("url").strip()
quote_char = match.group("quote") or ""
try:
rewritten_url = self.rewrite_relative_asset_url(
raw_url,
css_asset_path,
plugin_name,
page_name,
extra_query_params=extra_query_params,
)
if not rewritten_url:
return match.group(0)
return f"url({quote_char}{rewritten_url}{quote_char})"
except ValueError:
return match.group(0)
return _CSS_URL_RE.sub(replace_url, css_text)
def rewrite_plugin_page_js(
self,
js_text: str,
plugin_name: str,
page_name: str,
js_asset_path: str,
extra_query_params: dict[str, str] | None = None,
) -> str:
def rewrite_specifier(raw_url: str) -> str:
if not self.is_js_relative_module_specifier(raw_url):
return raw_url
if not self.is_rewritable_asset_url(raw_url):
return raw_url
rewritten = self.rewrite_relative_asset_url(
raw_url,
js_asset_path,
plugin_name,
page_name,
extra_query_params=extra_query_params,
)
return rewritten or raw_url
def replace_dynamic(match: re.Match[str]) -> str:
raw_url = match.group("url")
try:
rewritten = rewrite_specifier(raw_url)
except ValueError:
return match.group(0)
return (
f"{match.group('prefix')}{match.group('quote')}{rewritten}"
f"{match.group('quote')}{match.group('suffix')}"
)
def replace_from(match: re.Match[str]) -> str:
raw_url = match.group("url")
try:
rewritten = rewrite_specifier(raw_url)
except ValueError:
return match.group(0)
return (
f"{match.group('prefix')}{match.group('quote')}"
f"{rewritten}{match.group('quote')}"
)
rewritten_js = _JS_DYNAMIC_IMPORT_RE.sub(replace_dynamic, js_text)
rewritten_js = _JS_MODULE_FROM_RE.sub(replace_from, rewritten_js)
def replace_side_effect(match: re.Match[str]) -> str:
raw_url = match.group("url")
if raw_url.startswith(("{", "*")):
return match.group(0)
try:
rewritten = rewrite_specifier(raw_url)
except ValueError:
return match.group(0)
return (
f"{match.group('prefix')}{match.group('quote')}"
f"{rewritten}{match.group('quote')}"
)
return _JS_SIDE_EFFECT_IMPORT_RE.sub(replace_side_effect, rewritten_js)
@staticmethod
async def read_plugin_page_text(file_path: Path) -> str:
async with aiofiles.open(file_path, encoding="utf-8") as file:
return await file.read()
@staticmethod
async def read_plugin_page_binary(file_path: Path) -> bytes:
async with aiofiles.open(file_path, mode="rb") as file:
return await file.read()
@staticmethod
def guess_plugin_page_mime_type(file_path: Path) -> str:
return mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
async def serialize_plugin_page(
self,
plugin: StarMetadata,
page_name: str,
*,
include_content_path: bool = False,
asset_token: str = "",
) -> dict | None:
plugin_name = plugin.name.strip() if isinstance(plugin.name, str) else ""
if not plugin_name:
return None
try:
page = await self.get_plugin_page(plugin, page_name)
await self.resolve_plugin_page_file(plugin, page.name, "")
except (FileNotFoundError, ValueError):
return None
page_data = {
"name": page.name,
"title": page.title,
"i18n_key": f"pages.{page.name}",
}
if include_content_path:
extra_query_params = {"asset_token": asset_token} if asset_token else None
page_data["content_path"] = self.build_plugin_page_asset_url(
plugin_name,
page.name,
"",
extra_query_params=extra_query_params,
)
return page_data
async def serialize_plugin_pages(self, plugin: StarMetadata) -> list[dict]:
pages = []
for page in await self.discover_plugin_pages(plugin):
page_data = await self.serialize_plugin_page(plugin, page.name)
if page_data:
pages.append(page_data)
return pages
def issue_plugin_page_asset_token(
self,
*,
plugin_name: str,
page_name: str,
jwt_secret: str | None = None,
username: str | None,
locale: str,
) -> str | None:
jwt_secret = jwt_secret or self._jwt_secret()
if not isinstance(jwt_secret, str) or not jwt_secret.strip():
return None
if not isinstance(username, str) or not username.strip():
return None
now = datetime.now(timezone.utc)
payload = {
"username": username,
"token_type": PLUGIN_PAGE_ASSET_TOKEN_TYPE,
"plugin_name": plugin_name,
"page_name": page_name,
"locale": locale,
"iat": now,
"exp": now + timedelta(seconds=PLUGIN_PAGE_ASSET_TOKEN_TTL_SECONDS),
}
return cast(str, jwt.encode(payload, jwt_secret, algorithm="HS256"))
__all__ = [
"PLUGIN_PAGE_ASSET_TOKEN_TYPE",
"PLUGIN_PAGE_BRIDGE_FILE",
"PLUGIN_PAGE_ENTRY_FILE_NAME",
"PLUGIN_PAGE_ROOT_DIR_NAME",
"PluginPage",
"PluginPageContentPayload",
"PluginPageService",
"PluginPageServiceError",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,119 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
import httpx
import jwt
from fastapi import Request, Response
from astrbot.dashboard.v1.auth import AuthContext
_BODY_NOT_SET = object()
_SKIP_RESPONSE_HEADERS = {
"content-length",
"transfer-encoding",
"connection",
"keep-alive",
"server",
}
class DashboardRouteBridgeService:
"""Forward a v1 request to a dashboard compatibility route after v1 auth."""
def __init__(self, route_app, jwt_secret: str) -> None:
self.route_app = route_app
self.jwt_secret = jwt_secret
def _build_headers(
self, request: Request, auth: AuthContext | None
) -> dict[str, str]:
headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in {"authorization", "host"}
}
if auth is None:
return headers
token = jwt.encode(
{"username": auth.username},
self.jwt_secret,
algorithm="HS256",
)
headers["Authorization"] = f"Bearer {token}"
return headers
@staticmethod
def _merge_query(
request: Request,
query: Mapping[str, Any] | None,
*,
drop: set[str] | None = None,
) -> list[tuple[str, str]]:
drop = drop or set()
pairs = [
(key, value)
for key, value in request.query_params.multi_items()
if key not in drop
]
if query:
for key, value in query.items():
if value is None:
continue
if isinstance(value, list | tuple):
pairs.extend((key, str(item)) for item in value)
else:
pairs.append((key, str(value)))
return pairs
@staticmethod
def _as_response(response: httpx.Response) -> Response:
headers = {
key: value
for key, value in response.headers.items()
if key.lower() not in _SKIP_RESPONSE_HEADERS
}
return Response(
content=response.content,
status_code=response.status_code,
headers=headers,
media_type=headers.get("content-type"),
)
async def forward(
self,
request: Request,
auth: AuthContext | None,
*,
method: str,
target_path: str,
query: Mapping[str, Any] | None = None,
drop_query: set[str] | None = None,
json_body: Any = _BODY_NOT_SET,
) -> Response:
headers = self._build_headers(request, auth)
params = self._merge_query(request, query, drop=drop_query)
transport = httpx.ASGITransport(app=self.route_app)
async with httpx.AsyncClient(
transport=transport,
base_url="http://dashboard-routes",
) as client:
if json_body is _BODY_NOT_SET:
response = await client.request(
method,
target_path,
params=params,
content=await request.body(),
headers=headers,
)
else:
response = await client.request(
method,
target_path,
params=params,
json=json_body,
headers=headers,
)
return self._as_response(response)

View File

@@ -0,0 +1,715 @@
from __future__ import annotations
import uuid
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, select
from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import ConversationV2, Preference
from astrbot.core.provider.entities import ProviderType
from astrbot.core.umo_alias import build_umo_alias_map, parse_umo, serialize_umo_alias
AVAILABLE_SESSION_RULE_KEYS = [
"session_service_config",
"session_plugin_config",
"kb_config",
f"provider_perf_{ProviderType.CHAT_COMPLETION.value}",
f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}",
f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}",
]
class SessionManagementServiceError(Exception):
pass
class SessionManagementService:
def __init__(
self,
core_lifecycle: AstrBotCoreLifecycle,
db_helper: BaseDatabase,
) -> None:
self.core_lifecycle = core_lifecycle
self.db_helper = db_helper
@staticmethod
def _payload(data: object) -> dict[str, Any]:
return data if isinstance(data, dict) else {}
@staticmethod
def _is_group_umo(umo: str) -> bool:
umo_lower = umo.lower()
return ":group:" in umo_lower or ":groupmessage:" in umo_lower
@staticmethod
def _is_private_umo(umo: str) -> bool:
umo_lower = umo.lower()
return (
":private:" in umo_lower
or ":friend:" in umo_lower
or ":friendmessage:" in umo_lower
)
async def list_known_umos(self) -> list[str]:
async with self.db_helper.get_db() as session:
session: AsyncSession
result = await session.execute(select(ConversationV2.user_id).distinct())
umos = {str(row[0]) for row in result.fetchall() if row[0]}
aliases = await self.db_helper.get_umo_aliases()
umos.update(str(alias.umo) for alias in aliases if alias.umo)
return sorted(umos)
async def get_umo_alias_map(self, umos: list[str]) -> dict:
return build_umo_alias_map(await self.db_helper.get_umo_aliases(umos))
def build_umo_info(self, umo: str | None, alias_map: dict) -> dict:
umo_str = umo or ""
return {
"umo": umo_str,
**parse_umo(umo_str),
**serialize_umo_alias(alias_map.get(umo_str), umo_str),
}
async def list_active_umos(self) -> dict:
umos = await self.list_known_umos()
alias_map = await self.get_umo_alias_map(umos)
return {
"umos": umos,
"umo_infos": [self.build_umo_info(umo, alias_map) for umo in umos],
}
async def get_umos_by_scope(
self,
scope: str,
group_id: str = "",
) -> list[str]:
if scope == "custom_group":
if not group_id:
raise SessionManagementServiceError("请指定分组 ID")
groups = self.get_groups()
if group_id not in groups:
raise SessionManagementServiceError(f"分组 '{group_id}' 不存在")
return groups[group_id].get("umos", [])
all_umos = await self.list_known_umos()
if scope == "group":
return [umo for umo in all_umos if self._is_group_umo(umo)]
if scope == "private":
return [umo for umo in all_umos if self._is_private_umo(umo)]
if scope == "all":
return all_umos
return []
async def get_umo_rules(
self,
page: int = 1,
page_size: int = 10,
search: str = "",
) -> tuple[dict, int]:
umo_rules = {}
async with self.db_helper.get_db() as session:
session: AsyncSession
result = await session.execute(
select(Preference).where(
col(Preference.scope) == "umo",
col(Preference.key).in_(AVAILABLE_SESSION_RULE_KEYS),
)
)
prefs = result.scalars().all()
for pref in prefs:
umo_id = pref.scope_id
if umo_id not in umo_rules:
umo_rules[umo_id] = {}
if pref.key == "session_plugin_config" and umo_id in pref.value["val"]:
umo_rules[umo_id][pref.key] = pref.value["val"][umo_id]
else:
umo_rules[umo_id][pref.key] = pref.value["val"]
alias_map = await self.get_umo_alias_map(list(umo_rules.keys()))
if search:
search_lower = search.lower()
filtered_rules = {}
for umo_id, rules in umo_rules.items():
if search_lower in umo_id.lower():
filtered_rules[umo_id] = rules
continue
svc_config = rules.get("session_service_config", {})
custom_name = svc_config.get("custom_name", "") if svc_config else ""
if custom_name and search_lower in custom_name.lower():
filtered_rules[umo_id] = rules
continue
alias_info = serialize_umo_alias(alias_map.get(umo_id), umo_id)
if any(
search_lower in alias_info[key].lower()
for key in ("auto_name", "user_alias", "display_name")
if alias_info.get(key)
):
filtered_rules[umo_id] = rules
umo_rules = filtered_rules
total = len(umo_rules)
all_umo_ids = list(umo_rules.keys())
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
paginated_umo_ids = all_umo_ids[start_idx:end_idx]
return {umo_id: umo_rules[umo_id] for umo_id in paginated_umo_ids}, total
async def list_session_rules(
self,
*,
page: int,
page_size: int,
search: str,
) -> dict:
page, page_size = self._normalize_page(page, page_size, default_page_size=10)
umo_rules, total = await self.get_umo_rules(
page=page,
page_size=page_size,
search=search,
)
alias_map = await self.get_umo_alias_map(list(umo_rules.keys()))
rules_list = [
{
"rules": rules,
**self.build_umo_info(umo, alias_map),
}
for umo, rules in umo_rules.items()
]
provider_manager = self.core_lifecycle.provider_manager
persona_mgr = getattr(self.core_lifecycle, "persona_mgr", None)
plugin_manager = getattr(self.core_lifecycle, "plugin_manager", None)
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
available_personas = [
{"name": p["name"], "prompt": p.get("prompt", "")}
for p in getattr(persona_mgr, "personas_v3", [])
]
available_plugins = []
if plugin_manager and getattr(plugin_manager, "context", None):
available_plugins = [
{
"name": p.name,
"display_name": p.display_name or p.name,
"desc": p.desc,
}
for p in plugin_manager.context.get_all_stars()
if not p.reserved and p.name
]
available_kbs = []
if kb_manager:
try:
kbs = await kb_manager.list_kbs()
available_kbs = [
{
"kb_id": kb.kb_id,
"kb_name": kb.kb_name,
"emoji": kb.emoji,
}
for kb in kbs
]
except Exception as exc:
logger.warning(f"获取知识库列表失败: {exc!s}")
return {
"rules": rules_list,
"total": total,
"page": page,
"page_size": page_size,
"available_personas": available_personas,
"available_chat_providers": self._serialize_provider_insts(
getattr(provider_manager, "provider_insts", [])
),
"available_stt_providers": self._serialize_provider_insts(
getattr(provider_manager, "stt_provider_insts", [])
),
"available_tts_providers": self._serialize_provider_insts(
getattr(provider_manager, "tts_provider_insts", [])
),
"available_plugins": available_plugins,
"available_kbs": available_kbs,
"available_rule_keys": AVAILABLE_SESSION_RULE_KEYS,
}
async def list_session_rules_from_legacy_query(
self,
*,
page,
page_size,
search,
) -> dict:
return await self.list_session_rules(
page=self._to_int(page, 1),
page_size=self._to_int(page_size, 10),
search=str(search or "").strip(),
)
async def update_session_rule(self, data: object) -> dict:
payload = self._payload(data)
umo = payload.get("umo")
rule_key = payload.get("rule_key")
rule_value = payload.get("rule_value")
if not umo:
raise SessionManagementServiceError("缺少必要参数: umo")
if not rule_key:
raise SessionManagementServiceError("缺少必要参数: rule_key")
if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
raise SessionManagementServiceError(f"不支持的规则键: {rule_key}")
if rule_key == "session_plugin_config":
rule_value = {umo: rule_value}
await sp.session_put(umo, rule_key, rule_value)
return {"message": f"规则 {rule_key} 已更新", "umo": umo}
async def delete_session_rule(self, data: object) -> dict:
payload = self._payload(data)
umo = payload.get("umo")
rule_key = payload.get("rule_key")
if not umo:
raise SessionManagementServiceError("缺少必要参数: umo")
if rule_key:
if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
raise SessionManagementServiceError(f"不支持的规则键: {rule_key}")
await sp.session_remove(umo, rule_key)
return {"message": f"规则 {rule_key} 已删除", "umo": umo}
await sp.clear_async("umo", umo)
return {"message": "所有规则已删除", "umo": umo}
async def delete_session_rules(self, data: object) -> dict:
payload = self._payload(data)
if payload.get("umo") and not payload.get("umos") and not payload.get("scope"):
return await self.delete_session_rule(payload)
return await self.batch_delete_session_rule(payload)
async def batch_delete_session_rule(self, data: object) -> dict:
payload = self._payload(data)
umos = payload.get("umos", [])
scope = payload.get("scope", "")
group_id = payload.get("group_id", "")
rule_key = payload.get("rule_key")
if scope and not umos:
umos = await self.get_umos_by_scope(scope, group_id)
if not umos:
raise SessionManagementServiceError("缺少必要参数: umos 或有效的 scope")
if not isinstance(umos, list):
raise SessionManagementServiceError("参数 umos 必须是数组")
if rule_key and rule_key not in AVAILABLE_SESSION_RULE_KEYS:
raise SessionManagementServiceError(f"不支持的规则键: {rule_key}")
success_count = 0
failed_umos = []
for umo in umos:
try:
if rule_key:
await sp.session_remove(umo, rule_key)
else:
await sp.clear_async("umo", umo)
success_count += 1
except Exception as exc:
logger.error(f"删除 umo {umo} 的规则失败: {exc!s}")
failed_umos.append(umo)
message = f"已删除 {success_count} 条规则"
if rule_key:
message = f"已删除 {success_count}{rule_key} 规则"
result = {
"message": message,
"success_count": success_count,
}
if failed_umos:
result.update(
{
"message": f"{message}{len(failed_umos)} 条删除失败",
"failed_umos": failed_umos,
}
)
return result
async def list_all_umos_with_status(
self,
*,
page: int,
page_size: int,
search: str,
message_type: str,
platform: str,
) -> dict:
page, page_size = self._normalize_page(page, page_size, default_page_size=20)
all_umos = await self.list_known_umos()
alias_map = await self.get_umo_alias_map(all_umos)
umo_rules, _ = await self.get_umo_rules(page=1, page_size=99999, search="")
umos_with_status = []
for umo in all_umos:
umo_info = self.build_umo_info(umo, alias_map)
umo_platform = umo_info["platform"]
umo_message_type = umo_info["message_type"]
if message_type != "all":
if message_type == "group" and umo_message_type not in [
"group",
"GroupMessage",
]:
continue
if message_type == "private" and umo_message_type not in [
"private",
"FriendMessage",
"friend",
]:
continue
if platform and umo_platform != platform:
continue
rules = umo_rules.get(umo, {})
svc_config = rules.get("session_service_config", {})
custom_name = svc_config.get("custom_name", "") if svc_config else ""
session_enabled = (
svc_config.get("session_enabled", True) if svc_config else True
)
llm_enabled = svc_config.get("llm_enabled", True) if svc_config else True
tts_enabled = svc_config.get("tts_enabled", True) if svc_config else True
if search:
search_lower = search.lower()
search_targets = [
umo,
custom_name,
umo_info["auto_name"],
umo_info["user_alias"],
umo_info["display_name"],
]
if not any(
search_lower in target.lower()
for target in search_targets
if target
):
continue
chat_provider_key = f"provider_perf_{ProviderType.CHAT_COMPLETION.value}"
tts_provider_key = f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}"
stt_provider_key = f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}"
umos_with_status.append(
{
**umo_info,
"custom_name": custom_name,
"session_enabled": session_enabled,
"llm_enabled": llm_enabled,
"tts_enabled": tts_enabled,
"has_rules": umo in umo_rules,
"chat_provider": rules.get(chat_provider_key),
"tts_provider": rules.get(tts_provider_key),
"stt_provider": rules.get(stt_provider_key),
}
)
total = len(umos_with_status)
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
paginated = umos_with_status[start_idx:end_idx]
platforms = list({u["platform"] for u in umos_with_status})
provider_manager = self.core_lifecycle.provider_manager
return {
"sessions": paginated,
"total": total,
"page": page,
"page_size": page_size,
"platforms": platforms,
"available_chat_providers": self._serialize_provider_insts(
getattr(provider_manager, "provider_insts", [])
),
"available_tts_providers": self._serialize_provider_insts(
getattr(provider_manager, "tts_provider_insts", [])
),
"available_stt_providers": self._serialize_provider_insts(
getattr(provider_manager, "stt_provider_insts", [])
),
}
async def list_all_umos_with_status_from_legacy_query(
self,
*,
page,
page_size,
search,
message_type,
platform,
) -> dict:
return await self.list_all_umos_with_status(
page=self._to_int(page, 1),
page_size=self._to_int(page_size, 20),
search=str(search or "").strip(),
message_type=str(message_type or "all"),
platform=str(platform or ""),
)
async def batch_update_service(self, data: object) -> dict:
payload = self._payload(data)
umos = payload.get("umos", [])
scope = payload.get("scope", "")
group_id = payload.get("group_id", "")
llm_enabled = payload.get("llm_enabled")
tts_enabled = payload.get("tts_enabled")
session_enabled = payload.get("session_enabled")
if llm_enabled is None and tts_enabled is None and session_enabled is None:
raise SessionManagementServiceError("至少需要指定一个要修改的状态")
if scope and not umos:
umos = await self.get_umos_by_scope(scope, group_id)
if not umos:
raise SessionManagementServiceError("没有找到符合条件的会话")
success_count = 0
failed_umos = []
for umo in umos:
try:
session_config = (
sp.get("session_service_config", {}, scope="umo", scope_id=umo)
or {}
)
if llm_enabled is not None:
session_config["llm_enabled"] = llm_enabled
if tts_enabled is not None:
session_config["tts_enabled"] = tts_enabled
if session_enabled is not None:
session_config["session_enabled"] = session_enabled
sp.put(
"session_service_config",
session_config,
scope="umo",
scope_id=umo,
)
success_count += 1
except Exception as exc:
logger.error(f"更新 {umo} 服务状态失败: {exc!s}")
failed_umos.append(umo)
status_changes = []
if llm_enabled is not None:
status_changes.append(f"LLM={'启用' if llm_enabled else '禁用'}")
if tts_enabled is not None:
status_changes.append(f"TTS={'启用' if tts_enabled else '禁用'}")
if session_enabled is not None:
status_changes.append(f"会话={'启用' if session_enabled else '禁用'}")
return {
"message": f"已更新 {success_count} 个会话 ({', '.join(status_changes)})",
"success_count": success_count,
"failed_count": len(failed_umos),
"failed_umos": failed_umos,
}
async def batch_update_provider(self, data: object) -> dict:
payload = self._payload(data)
umos = payload.get("umos", [])
scope = payload.get("scope", "")
provider_type = payload.get("provider_type")
provider_id = payload.get("provider_id")
if not provider_type or not provider_id:
raise SessionManagementServiceError(
"缺少必要参数: provider_type, provider_id"
)
provider_type_map = {
"chat_completion": ProviderType.CHAT_COMPLETION,
"text_to_speech": ProviderType.TEXT_TO_SPEECH,
"speech_to_text": ProviderType.SPEECH_TO_TEXT,
}
if provider_type not in provider_type_map:
raise SessionManagementServiceError(
f"不支持的 provider_type: {provider_type}"
)
group_id = payload.get("group_id", "")
if scope and not umos:
umos = await self.get_umos_by_scope(scope, group_id)
if not umos:
raise SessionManagementServiceError("没有找到符合条件的会话")
success_count = 0
failed_umos = []
provider_manager = self.core_lifecycle.provider_manager
for umo in umos:
try:
await provider_manager.set_provider(
provider_id=provider_id,
provider_type=provider_type_map[provider_type],
umo=umo,
)
success_count += 1
except Exception as exc:
logger.error(f"更新 {umo} Provider 失败: {exc!s}")
failed_umos.append(umo)
return {
"message": f"已更新 {success_count} 个会话的 {provider_type}{provider_id}",
"success_count": success_count,
"failed_count": len(failed_umos),
"failed_umos": failed_umos,
}
def get_groups(self) -> dict:
return sp.get("session_groups", {})
def save_groups(self, groups: dict) -> None:
sp.put("session_groups", groups)
def list_groups(self) -> dict:
groups = self.get_groups()
return {
"groups": [
{
"id": group_id,
"name": group_data.get("name", ""),
"umos": group_data.get("umos", []),
"umo_count": len(group_data.get("umos", [])),
}
for group_id, group_data in groups.items()
]
}
def create_group(self, data: object) -> dict:
payload = self._payload(data)
name = str(payload.get("name", "")).strip()
umos = payload.get("umos", [])
if not name:
raise SessionManagementServiceError("分组名称不能为空")
groups = self.get_groups()
group_id = str(uuid.uuid4())[:8]
groups[group_id] = {
"name": name,
"umos": umos,
}
self.save_groups(groups)
return {
"message": f"分组 '{name}' 创建成功",
"group": {
"id": group_id,
"name": name,
"umos": umos,
"umo_count": len(umos),
},
}
def update_group(self, data: object) -> dict:
payload = self._payload(data)
group_id = payload.get("id") or payload.get("group_id")
name = payload.get("name")
umos = payload.get("umos")
add_umos = payload.get("add_umos", [])
remove_umos = payload.get("remove_umos", [])
if not group_id:
raise SessionManagementServiceError("分组 ID 不能为空")
groups = self.get_groups()
if group_id not in groups:
raise SessionManagementServiceError(f"分组 '{group_id}' 不存在")
group = groups[group_id]
if name is not None:
group["name"] = name.strip()
if umos is not None:
group["umos"] = umos
else:
current_umos = set(group.get("umos", []))
if add_umos:
current_umos.update(add_umos)
if remove_umos:
current_umos.difference_update(remove_umos)
group["umos"] = list(current_umos)
self.save_groups(groups)
return {
"message": f"分组 '{group['name']}' 更新成功",
"group": {
"id": group_id,
"name": group["name"],
"umos": group["umos"],
"umo_count": len(group["umos"]),
},
}
def delete_group(self, data: object) -> dict:
payload = self._payload(data)
group_id = payload.get("id") or payload.get("group_id")
if not group_id:
raise SessionManagementServiceError("分组 ID 不能为空")
groups = self.get_groups()
if group_id not in groups:
raise SessionManagementServiceError(f"分组 '{group_id}' 不存在")
group_name = groups[group_id].get("name", group_id)
del groups[group_id]
self.save_groups(groups)
return {"message": f"分组 '{group_name}' 已删除"}
@staticmethod
def _normalize_page(
page: int,
page_size: int,
*,
default_page_size: int,
) -> tuple[int, int]:
if page < 1:
page = 1
if page_size < 1:
page_size = default_page_size
if page_size > 100:
page_size = 100
return page, page_size
@staticmethod
def _to_int(value, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
@staticmethod
def _serialize_provider_insts(provider_insts: list) -> list[dict]:
return [
{
"id": provider.meta().id,
"name": provider.meta().id,
"model": provider.meta().model,
}
for provider in provider_insts
]

View File

@@ -0,0 +1,877 @@
from __future__ import annotations
import os
import re
import shutil
import traceback
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from astrbot.core import DEMO_MODE, logger
from astrbot.core.computer.computer_client import (
_discover_bay_credentials,
sync_skills_to_active_sandboxes,
)
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
from astrbot.core.skills.skill_manager import SkillManager
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
_SKILL_FILE_MAX_BYTES = 512 * 1024
_EDITABLE_SKILL_FILE_SUFFIXES = {
".css",
".html",
".ini",
".js",
".json",
".md",
".py",
".sh",
".toml",
".ts",
".txt",
".yaml",
".yml",
}
_EDITABLE_SKILL_FILENAMES = {"Dockerfile", "Makefile"}
class SkillsServiceError(Exception):
pass
@dataclass
class SkillsOperationResult:
ok: bool = True
data: dict | list | None = None
message: str | None = None
@dataclass
class SkillArchive:
path: Path
filename: str
def _to_jsonable(value: Any) -> Any:
if isinstance(value, dict):
return {key: _to_jsonable(item) for key, item in value.items()}
if isinstance(value, list):
return [_to_jsonable(item) for item in value]
if hasattr(value, "model_dump"):
return _to_jsonable(value.model_dump())
return value
def _to_bool(value: Any, default: bool = False) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
return bool(value)
def _next_available_temp_path(temp_dir: str, filename: str) -> str:
stem = Path(filename).stem
suffix = Path(filename).suffix
candidate = filename
index = 1
while os.path.exists(os.path.join(temp_dir, candidate)):
candidate = f"{stem}_{index}{suffix}"
index += 1
return os.path.join(temp_dir, candidate)
class SkillsService:
def __init__(self, core_lifecycle) -> None:
self.core_lifecycle = core_lifecycle
@staticmethod
def _payload(data: object) -> dict[str, Any]:
return data if isinstance(data, dict) else {}
@staticmethod
def _ensure_mutation_allowed() -> None:
if DEMO_MODE:
raise SkillsServiceError(
"You are not permitted to do this operation in demo mode"
)
@staticmethod
async def _save_upload(file: Any, target_path: str) -> None:
if hasattr(file, "save"):
maybe_awaitable = file.save(target_path)
if hasattr(maybe_awaitable, "__await__"):
await maybe_awaitable
return
if hasattr(file, "read"):
data = file.read()
if hasattr(data, "__await__"):
data = await data
Path(target_path).write_bytes(data)
return
raise SkillsServiceError("Invalid upload file")
def resolve_local_skill_dir(self, name: str) -> Path:
skill_name = str(name or "").strip()
if not skill_name:
raise ValueError("Missing skill name")
if not _SKILL_NAME_RE.match(skill_name):
raise ValueError("Invalid skill name")
skill_mgr = SkillManager()
if skill_mgr.is_sandbox_only_skill(skill_name):
raise PermissionError(
"Sandbox preset skill cannot be opened from local skill files."
)
plugin_skill_dir = skill_mgr._get_plugin_skill_dir(skill_name)
if plugin_skill_dir is not None:
return plugin_skill_dir.resolve(strict=True)
skills_root = Path(skill_mgr.skills_root).resolve(strict=True)
skill_dir = (skills_root / skill_name).resolve(strict=True)
if not skill_dir.is_relative_to(skills_root):
raise PermissionError("Invalid skill path")
if not skill_dir.is_dir() or not (skill_dir / "SKILL.md").exists():
raise FileNotFoundError("Local skill not found")
return skill_dir
@staticmethod
def resolve_skill_relative_path(
skill_dir: Path,
relative_path: str | None,
*,
expect_file: bool,
) -> Path:
raw_path = str(relative_path or ".").strip() or "."
normalized = Path(raw_path.replace("\\", "/"))
if normalized.is_absolute() or ".." in normalized.parts:
raise ValueError("Invalid relative path")
target = (skill_dir / normalized).resolve(strict=True)
if not target.is_relative_to(skill_dir):
raise PermissionError("Path escapes skill directory")
if expect_file and not target.is_file():
raise FileNotFoundError("Skill file not found")
if not expect_file and not target.is_dir():
raise FileNotFoundError("Skill directory not found")
return target
@staticmethod
def skill_relative_path(skill_dir: Path, target: Path) -> str:
rel = target.relative_to(skill_dir).as_posix()
return "" if rel == "." else rel
@staticmethod
def is_editable_skill_file(path: Path) -> bool:
return (
path.name in _EDITABLE_SKILL_FILENAMES
or path.suffix.lower() in _EDITABLE_SKILL_FILE_SUFFIXES
)
def serialize_skill_file_entry(
self,
skill_dir: Path,
path: Path,
*,
readonly: bool = False,
) -> dict:
stat = path.stat()
is_dir = path.is_dir()
return {
"name": path.name,
"path": self.skill_relative_path(skill_dir, path),
"type": "directory" if is_dir else "file",
"size": 0 if is_dir else stat.st_size,
"editable": (
not readonly
and (not is_dir)
and self.is_editable_skill_file(path)
and stat.st_size <= _SKILL_FILE_MAX_BYTES
),
}
def get_neo_client_config(self) -> tuple[str, str]:
provider_settings = self.core_lifecycle.astrbot_config.get(
"provider_settings",
{},
)
sandbox = provider_settings.get("sandbox", {})
endpoint = sandbox.get("shipyard_neo_endpoint", "")
access_token = sandbox.get("shipyard_neo_access_token", "")
if not access_token and endpoint:
access_token = _discover_bay_credentials(endpoint)
if not endpoint or not access_token:
raise ValueError(
"Shipyard Neo endpoint or access token not configured. "
"Set them in Dashboard or ensure Bay's credentials.json is accessible."
)
return endpoint, access_token
async def with_neo_client(
self,
operation: Callable[[Any], Awaitable[Any]],
) -> SkillsOperationResult:
try:
endpoint, access_token = self.get_neo_client_config()
from shipyard_neo import BayClient
async with BayClient(
endpoint_url=endpoint,
access_token=access_token,
) as client:
result = await operation(client)
if isinstance(result, SkillsOperationResult):
return result
return SkillsOperationResult(data=_to_jsonable(result))
except ValueError as exc:
logger.debug("[Neo] %s", exc)
return SkillsOperationResult(ok=False, message=str(exc))
except Exception as exc:
logger.error(traceback.format_exc())
return SkillsOperationResult(ok=False, message=str(exc))
def get_skills(self) -> dict:
provider_settings = self.core_lifecycle.astrbot_config.get(
"provider_settings", {}
)
runtime = provider_settings.get("computer_use_runtime", "local")
skill_mgr = SkillManager()
skills = skill_mgr.list_skills(
active_only=False,
runtime=runtime,
show_sandbox_path=False,
)
return {
"skills": [skill.__dict__ for skill in skills],
"runtime": runtime,
"sandbox_cache": skill_mgr.get_sandbox_skills_cache_status(),
}
async def upload_skill(self, file: Any | None) -> SkillsOperationResult:
self._ensure_mutation_allowed()
temp_path = None
if not file:
raise SkillsServiceError("Missing file")
filename = os.path.basename(file.filename or "skill.zip")
if not filename.lower().endswith(".zip"):
raise SkillsServiceError("Only .zip files are supported")
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
skill_mgr = SkillManager()
temp_path = _next_available_temp_path(temp_dir, filename)
try:
await self._save_upload(file, temp_path)
try:
skill_name = skill_mgr.install_skill_from_zip(
temp_path,
overwrite=False,
skill_name_hint=Path(filename).stem,
)
except TypeError:
skill_name = skill_mgr.install_skill_from_zip(
temp_path,
overwrite=False,
)
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync uploaded skills to active sandboxes.")
return SkillsOperationResult(
data={"name": skill_name},
message="Skill uploaded successfully.",
)
finally:
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
except Exception:
logger.warning(f"Failed to remove temp skill file: {temp_path}")
async def batch_upload_skills(self, file_list: list[Any]) -> SkillsOperationResult:
self._ensure_mutation_allowed()
if not file_list:
raise SkillsServiceError("No files provided")
succeeded = []
failed = []
skipped = []
skill_mgr = SkillManager()
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
for file in file_list:
filename = os.path.basename(file.filename or "unknown.zip")
temp_path = None
try:
if not filename.lower().endswith(".zip"):
failed.append(
{
"filename": filename,
"error": "Only .zip files are supported",
}
)
continue
temp_path = _next_available_temp_path(temp_dir, filename)
await self._save_upload(file, temp_path)
try:
skill_name = skill_mgr.install_skill_from_zip(
temp_path,
overwrite=False,
skill_name_hint=Path(filename).stem,
)
except TypeError:
try:
skill_name = skill_mgr.install_skill_from_zip(
temp_path,
overwrite=False,
)
except FileExistsError:
skipped.append(
{
"filename": filename,
"name": Path(filename).stem,
"error": "Skill already exists.",
}
)
skill_name = None
except FileExistsError:
skipped.append(
{
"filename": filename,
"name": Path(filename).stem,
"error": "Skill already exists.",
}
)
skill_name = None
if skill_name is None:
continue
succeeded.append({"filename": filename, "name": skill_name})
except Exception as exc:
failed.append({"filename": filename, "error": str(exc)})
finally:
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
except Exception:
pass
if succeeded:
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync uploaded skills to active sandboxes.")
total = len(file_list)
success_count = len(succeeded)
skipped_count = len(skipped)
failed_count = len(failed)
data = {
"total": total,
"succeeded": succeeded,
"failed": failed,
"skipped": skipped,
}
if failed_count == 0 and success_count == total:
return SkillsOperationResult(
data=data,
message=f"All {total} skill(s) uploaded successfully.",
)
if failed_count == 0 and success_count == 0:
return SkillsOperationResult(
data=data,
message=f"All {total} file(s) were skipped.",
)
if success_count == 0 and skipped_count == 0:
return SkillsOperationResult(
ok=False,
data=data,
message=f"Upload failed for all {total} file(s).",
)
return SkillsOperationResult(
data=data,
message=f"Partial success: {success_count}/{total} skill(s) uploaded.",
)
def prepare_skill_archive(self, name: str) -> SkillArchive:
skill_name = str(name or "").strip()
if not skill_name:
raise SkillsServiceError("Missing skill name")
if not _SKILL_NAME_RE.match(skill_name):
raise SkillsServiceError("Invalid skill name")
skill_mgr = SkillManager()
if skill_mgr.is_sandbox_only_skill(skill_name):
raise SkillsServiceError(
"Sandbox preset skill cannot be downloaded from local skill files."
)
if skill_mgr.is_plugin_skill(skill_name):
raise SkillsServiceError(
"Plugin-provided skill cannot be downloaded from local skill files."
)
skill_dir = Path(skill_mgr.skills_root) / skill_name
skill_md = skill_dir / "SKILL.md"
if not skill_dir.is_dir() or not skill_md.exists():
raise SkillsServiceError("Local skill not found")
export_dir = Path(get_astrbot_temp_path()) / "skill_exports"
export_dir.mkdir(parents=True, exist_ok=True)
zip_base = export_dir / skill_name
zip_path = zip_base.with_suffix(".zip")
if zip_path.exists():
zip_path.unlink()
shutil.make_archive(
str(zip_base),
"zip",
root_dir=str(skill_mgr.skills_root),
base_dir=skill_name,
)
return SkillArchive(path=zip_path, filename=f"{skill_name}.zip")
def prepare_skill_archive_from_legacy_query(self, name: str | None) -> SkillArchive:
return self.prepare_skill_archive(name or "")
def list_skill_files(self, name: str, relative_path: str | None = "") -> dict:
skill_name = str(name or "").strip()
readonly = SkillManager().is_plugin_skill(skill_name)
skill_dir = self.resolve_local_skill_dir(skill_name)
target_dir = self.resolve_skill_relative_path(
skill_dir,
relative_path,
expect_file=False,
)
entries = []
for entry in sorted(
target_dir.iterdir(),
key=lambda item: (not item.is_dir(), item.name.lower()),
):
try:
resolved = entry.resolve(strict=True)
except OSError:
continue
if not resolved.is_relative_to(skill_dir):
continue
if not resolved.is_dir() and not resolved.is_file():
continue
entries.append(
self.serialize_skill_file_entry(
skill_dir,
resolved,
readonly=readonly,
)
)
return {
"name": skill_name,
"path": self.skill_relative_path(skill_dir, target_dir),
"entries": entries,
}
def list_skill_files_from_legacy_query(
self,
*,
name: str | None,
relative_path: str | None,
) -> dict:
return self.list_skill_files(name or "", relative_path or "")
def get_skill_file(self, name: str, relative_path: str | None = "SKILL.md") -> dict:
skill_name = str(name or "").strip()
skill_dir = self.resolve_local_skill_dir(skill_name)
target_file = self.resolve_skill_relative_path(
skill_dir,
relative_path,
expect_file=True,
)
if not self.is_editable_skill_file(target_file):
raise SkillsServiceError("Unsupported file type")
size = target_file.stat().st_size
if size > _SKILL_FILE_MAX_BYTES:
raise SkillsServiceError("File is too large")
try:
content = target_file.read_text(encoding="utf-8")
except UnicodeDecodeError as exc:
raise SkillsServiceError("File is not valid UTF-8 text") from exc
return {
"name": skill_name,
"path": self.skill_relative_path(skill_dir, target_file),
"content": content,
"size": size,
"editable": not SkillManager().is_plugin_skill(skill_name),
}
def get_skill_file_from_legacy_query(
self,
*,
name: str | None,
relative_path: str | None,
) -> dict:
return self.get_skill_file(name or "", relative_path or "SKILL.md")
async def update_skill_file(self, data: object) -> dict:
self._ensure_mutation_allowed()
payload = self._payload(data)
skill_name = str(payload.get("name") or "").strip()
relative_path = payload.get("path", "SKILL.md")
content = payload.get("content")
if not isinstance(content, str):
raise SkillsServiceError("Missing file content")
encoded = content.encode("utf-8")
if len(encoded) > _SKILL_FILE_MAX_BYTES:
raise SkillsServiceError("File content is too large")
skill_dir = self.resolve_local_skill_dir(skill_name)
if SkillManager().is_plugin_skill(skill_name):
raise SkillsServiceError("Plugin-provided skill is read-only.")
target_file = self.resolve_skill_relative_path(
skill_dir,
relative_path,
expect_file=True,
)
if not self.is_editable_skill_file(target_file):
raise SkillsServiceError("Unsupported file type")
target_file.write_text(content, encoding="utf-8")
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync edited skills to active sandboxes.")
return {
"name": skill_name,
"path": self.skill_relative_path(skill_dir, target_file),
"size": len(encoded),
}
def update_skill(self, data: object) -> dict:
self._ensure_mutation_allowed()
payload = self._payload(data)
name = payload.get("name")
active = payload.get("active", True)
if not name:
raise SkillsServiceError("Missing skill name")
SkillManager().set_skill_active(name, bool(active))
return {"name": name, "active": bool(active)}
async def delete_skill(self, data: object) -> dict:
self._ensure_mutation_allowed()
payload = self._payload(data)
name = payload.get("name")
if not name:
raise SkillsServiceError("Missing skill name")
SkillManager().delete_skill(name)
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync deleted skills to active sandboxes.")
return {"name": name}
async def get_neo_candidates(self, query: dict[str, Any]) -> SkillsOperationResult:
logger.info("[Neo] GET /skills/neo/candidates requested.")
status = query.get("status")
skill_key = query.get("skill_key")
limit = int(query.get("limit", 100))
offset = int(query.get("offset", 0))
async def _do(client):
candidates = await client.skills.list_candidates(
status=status,
skill_key=skill_key,
limit=limit,
offset=offset,
)
result = _to_jsonable(candidates)
total = result.get("total", "?") if isinstance(result, dict) else "?"
logger.info(f"[Neo] Candidates fetched: total={total}")
return result
return await self.with_neo_client(_do)
async def get_neo_candidates_from_legacy_query(
self,
*,
status: str | None,
skill_key: str | None,
limit: str | None,
offset: str | None,
) -> SkillsOperationResult:
return await self.get_neo_candidates(
self._legacy_query(
status=status,
skill_key=skill_key,
limit=limit,
offset=offset,
)
)
async def get_neo_releases(self, query: dict[str, Any]) -> SkillsOperationResult:
logger.info("[Neo] GET /skills/neo/releases requested.")
skill_key = query.get("skill_key")
stage = query.get("stage")
active_only = _to_bool(query.get("active_only"), False)
limit = int(query.get("limit", 100))
offset = int(query.get("offset", 0))
async def _do(client):
releases = await client.skills.list_releases(
skill_key=skill_key,
active_only=active_only,
stage=stage,
limit=limit,
offset=offset,
)
result = _to_jsonable(releases)
total = result.get("total", "?") if isinstance(result, dict) else "?"
logger.info(f"[Neo] Releases fetched: total={total}")
return result
return await self.with_neo_client(_do)
async def get_neo_releases_from_legacy_query(
self,
*,
skill_key: str | None,
stage: str | None,
active_only: str | None,
limit: str | None,
offset: str | None,
) -> SkillsOperationResult:
return await self.get_neo_releases(
self._legacy_query(
skill_key=skill_key,
stage=stage,
active_only=active_only,
limit=limit,
offset=offset,
)
)
async def get_neo_payload(self, query: dict[str, Any]) -> SkillsOperationResult:
logger.info("[Neo] GET /skills/neo/payload requested.")
payload_ref = query.get("payload_ref", "")
if not payload_ref:
return SkillsOperationResult(ok=False, message="Missing payload_ref")
async def _do(client):
payload = await client.skills.get_payload(payload_ref)
logger.info(f"[Neo] Payload fetched: ref={payload_ref}")
return payload
return await self.with_neo_client(_do)
async def get_neo_payload_from_legacy_query(
self,
payload_ref: str | None,
) -> SkillsOperationResult:
return await self.get_neo_payload(self._legacy_query(payload_ref=payload_ref))
async def evaluate_neo_candidate(
self,
data: object,
) -> SkillsOperationResult:
self._ensure_mutation_allowed()
logger.info("[Neo] POST /skills/neo/evaluate requested.")
payload = self._payload(data)
candidate_id = payload.get("candidate_id")
passed_value = payload.get("passed")
if not candidate_id or passed_value is None:
return SkillsOperationResult(
ok=False,
message="Missing candidate_id or passed",
)
passed = _to_bool(passed_value, False)
async def _do(client):
result = await client.skills.evaluate_candidate(
candidate_id,
passed=passed,
score=payload.get("score"),
benchmark_id=payload.get("benchmark_id"),
report=payload.get("report"),
)
logger.info(
f"[Neo] Candidate evaluated: id={candidate_id}, passed={passed}"
)
return result
return await self.with_neo_client(_do)
async def promote_neo_candidate(self, data: object) -> SkillsOperationResult:
self._ensure_mutation_allowed()
logger.info("[Neo] POST /skills/neo/promote requested.")
payload = self._payload(data)
candidate_id = payload.get("candidate_id")
stage = payload.get("stage", "canary")
sync_to_local = _to_bool(payload.get("sync_to_local"), True)
if not candidate_id:
return SkillsOperationResult(ok=False, message="Missing candidate_id")
if stage not in {"canary", "stable"}:
return SkillsOperationResult(
ok=False,
message="Invalid stage, must be canary/stable",
)
async def _do(client):
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.promote_with_optional_sync(
client,
candidate_id=candidate_id,
stage=stage,
sync_to_local=sync_to_local,
)
release_json = result.get("release")
logger.info(f"[Neo] Candidate promoted: id={candidate_id}, stage={stage}")
sync_json = result.get("sync")
did_sync_to_local = bool(sync_json)
if did_sync_to_local:
logger.info(
"[Neo] Stable release synced to local: "
f"skill={sync_json.get('local_skill_name', '')}"
)
if result.get("sync_error"):
return SkillsOperationResult(
ok=False,
message=(
"Stable promote synced failed and has been rolled back. "
f"sync_error={result['sync_error']}"
),
data={
"release": release_json,
"rollback": result.get("rollback"),
},
)
if not did_sync_to_local:
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync skills to active sandboxes.")
return {"release": release_json, "sync": sync_json}
return await self.with_neo_client(_do)
async def rollback_neo_release(self, data: object) -> SkillsOperationResult:
self._ensure_mutation_allowed()
logger.info("[Neo] POST /skills/neo/rollback requested.")
payload = self._payload(data)
release_id = payload.get("release_id")
if not release_id:
return SkillsOperationResult(ok=False, message="Missing release_id")
async def _do(client):
result = await client.skills.rollback_release(release_id)
logger.info(f"[Neo] Release rolled back: id={release_id}")
return result
return await self.with_neo_client(_do)
async def sync_neo_release(self, data: object) -> SkillsOperationResult:
self._ensure_mutation_allowed()
logger.info("[Neo] POST /skills/neo/sync requested.")
payload = self._payload(data)
release_id = payload.get("release_id")
skill_key = payload.get("skill_key")
require_stable = _to_bool(payload.get("require_stable"), True)
if not release_id and not skill_key:
return SkillsOperationResult(
ok=False,
message="Missing release_id or skill_key",
)
async def _do(client):
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.sync_release(
client,
release_id=release_id,
skill_key=skill_key,
require_stable=require_stable,
)
logger.info(
f"[Neo] Release synced to local: skill={result.local_skill_name}, "
f"release_id={result.release_id}"
)
return {
"skill_key": result.skill_key,
"local_skill_name": result.local_skill_name,
"release_id": result.release_id,
"candidate_id": result.candidate_id,
"payload_ref": result.payload_ref,
"map_path": result.map_path,
"synced_at": result.synced_at,
}
return await self.with_neo_client(_do)
async def delete_neo_candidate(self, data: object) -> SkillsOperationResult:
self._ensure_mutation_allowed()
logger.info("[Neo] POST /skills/neo/delete-candidate requested.")
payload = self._payload(data)
candidate_id = payload.get("candidate_id")
reason = payload.get("reason")
if not candidate_id:
return SkillsOperationResult(ok=False, message="Missing candidate_id")
async def _do(client):
result = await client.skills.delete_candidate(candidate_id, reason=reason)
logger.info(f"[Neo] Candidate deleted: id={candidate_id}")
return result
return await self.with_neo_client(_do)
async def delete_neo_release(self, data: object) -> SkillsOperationResult:
self._ensure_mutation_allowed()
logger.info("[Neo] POST /skills/neo/delete-release requested.")
payload = self._payload(data)
release_id = payload.get("release_id")
reason = payload.get("reason")
if not release_id:
return SkillsOperationResult(ok=False, message="Missing release_id")
async def _do(client):
result = await client.skills.delete_release(release_id, reason=reason)
logger.info(f"[Neo] Release deleted: id={release_id}")
return result
return await self.with_neo_client(_do)
@staticmethod
def _legacy_query(**values: Any) -> dict[str, Any]:
return {
key: value
for key, value in values.items()
if value is not None and value != ""
}

View File

@@ -0,0 +1,532 @@
from __future__ import annotations
import asyncio
import re
import threading
import time
import traceback
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from functools import cmp_to_key
from pathlib import Path
import aiohttp
import psutil
from sqlmodel import col, select
from astrbot.core import DEMO_MODE, logger
from astrbot.core.config import VERSION
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.helper import check_migration_needed_v4
from astrbot.core.db.po import ProviderStat
from astrbot.core.utils.astrbot_path import get_astrbot_path
from astrbot.core.utils.auth_password import (
is_default_dashboard_password,
is_legacy_dashboard_password,
)
from astrbot.core.utils.io import get_dashboard_version
from astrbot.core.utils.storage_cleaner import StorageCleaner
from astrbot.core.utils.version_comparator import VersionComparator
from astrbot.dashboard.password_state import (
get_dashboard_password_hash,
is_password_change_required,
is_password_storage_upgraded,
)
class StatServiceError(Exception):
pass
class StatService:
def __init__(
self,
db_helper: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
config: AstrBotConfig,
) -> None:
self.db_helper = db_helper
self.core_lifecycle = core_lifecycle
self.config = config
self.storage_cleaner = StorageCleaner(config)
async def restart_core(self) -> None:
if DEMO_MODE:
raise StatServiceError(
"You are not permitted to do this operation in demo mode"
)
await self.core_lifecycle.restart()
@staticmethod
def get_running_time_components(total_seconds: int):
minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(minutes, 60)
return {"hours": hours, "minutes": minutes, "seconds": seconds}
async def is_default_cred(self):
password_change_required = await is_password_change_required(
self.db_helper,
self.config,
)
if password_change_required:
return not DEMO_MODE
storage_upgraded = await is_password_storage_upgraded(
self.db_helper,
self.config,
)
if not storage_upgraded:
return False
username = self.config["dashboard"]["username"]
password = get_dashboard_password_hash(self.config, upgraded=True)
return (
username == "astrbot" and is_default_dashboard_password(password)
) and not DEMO_MODE
async def get_version(self) -> dict:
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
storage_upgraded = await is_password_storage_upgraded(
self.db_helper,
self.config,
)
password = get_dashboard_password_hash(
self.config,
upgraded=storage_upgraded,
)
return {
"version": VERSION,
"dashboard_version": await get_dashboard_version(),
"change_pwd_hint": await self.is_default_cred(),
"legacy_pwd_hint": is_legacy_dashboard_password(password),
"password_upgrade_required": not storage_upgraded,
"need_migration": need_migration,
}
def get_start_time(self) -> dict:
return {"start_time": self.core_lifecycle.start_time}
async def get_storage_status(self) -> dict:
try:
return await asyncio.to_thread(self.storage_cleaner.get_status)
except Exception as exc:
logger.error("获取存储占用失败", exc_info=True)
raise StatServiceError(
"获取存储占用失败,请查看后端日志了解详情。"
) from exc
async def cleanup_storage(self, target: str) -> dict:
try:
return await asyncio.to_thread(self.storage_cleaner.cleanup, target)
except ValueError as exc:
raise StatServiceError(str(exc)) from exc
except Exception as exc:
logger.error("清理存储失败", exc_info=True)
raise StatServiceError("清理存储失败,请查看后端日志了解详情。") from exc
async def cleanup_storage_from_legacy_payload(self, payload: object) -> dict:
target = "all"
if isinstance(payload, dict):
target = str(payload.get("target", "all"))
return await self.cleanup_storage(target)
async def get_stat(self, offset_sec: int) -> dict:
try:
stat = self.db_helper.get_base_stats(offset_sec)
now = int(time.time())
start_time = now - offset_sec
message_time_based_stats = []
idx = 0
for bucket_end in range(start_time, now, 3600):
cnt = 0
while (
idx < len(stat.platform)
and stat.platform[idx].timestamp < bucket_end
):
cnt += stat.platform[idx].count
idx += 1
message_time_based_stats.append([bucket_end, cnt])
stat_dict = stat.__dict__
cpu_percent = psutil.cpu_percent(interval=0.5)
thread_count = threading.active_count()
plugins = self.core_lifecycle.star_context.get_all_stars()
plugin_info = []
for plugin in plugins:
info = {
"name": getattr(plugin, "name", plugin.__class__.__name__),
"version": getattr(plugin, "version", "1.0.0"),
"is_enabled": True,
}
plugin_info.append(info)
running_time = self.get_running_time_components(
int(time.time()) - self.core_lifecycle.start_time,
)
stat_dict.update(
{
"platform": self.db_helper.get_grouped_base_stats(
offset_sec,
).platform,
"message_count": self.db_helper.get_total_message_count() or 0,
"platform_count": len(
self.core_lifecycle.platform_manager.get_insts(),
),
"plugin_count": len(plugins),
"plugins": plugin_info,
"message_time_series": message_time_based_stats,
"running": running_time,
"memory": {
"process": psutil.Process().memory_info().rss >> 20,
"system": psutil.virtual_memory().total >> 20,
},
"cpu_percent": round(cpu_percent, 1),
"thread_count": thread_count,
"start_time": self.core_lifecycle.start_time,
},
)
return stat_dict
except Exception as exc:
logger.error(traceback.format_exc())
raise StatServiceError(str(exc)) from exc
async def get_stat_from_legacy_query(self, offset_sec_raw) -> dict:
try:
offset_sec = int(offset_sec_raw if offset_sec_raw is not None else 86400)
except (TypeError, ValueError) as exc:
raise StatServiceError("offset_sec must be an integer") from exc
return await self.get_stat(offset_sec)
@staticmethod
def _ensure_aware_utc(value: datetime) -> datetime:
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
async def get_provider_token_stats(self, days: int) -> dict:
try:
if days not in (1, 3, 7):
days = 1
local_tz = datetime.now().astimezone().tzinfo or timezone.utc
now_local = datetime.now(local_tz)
range_start_local = (now_local - timedelta(days=days)).replace(
minute=0, second=0, microsecond=0
)
today_start_local = now_local.replace(
hour=0, minute=0, second=0, microsecond=0
)
query_start_local = min(range_start_local, today_start_local)
query_start_utc = query_start_local.astimezone(timezone.utc)
async with self.db_helper.get_db() as session:
result = await session.execute(
select(ProviderStat)
.where(
ProviderStat.agent_type == "internal",
ProviderStat.created_at >= query_start_utc,
)
.order_by(col(ProviderStat.created_at).asc())
)
records = result.scalars().all()
bucket_timestamps: list[int] = []
bucket_cursor = range_start_local
while bucket_cursor <= now_local:
bucket_timestamps.append(int(bucket_cursor.timestamp() * 1000))
bucket_cursor += timedelta(hours=1)
trend_by_provider: dict[str, dict[int, int]] = defaultdict(
lambda: defaultdict(int)
)
total_by_provider: dict[str, int] = defaultdict(int)
total_by_umo: dict[str, int] = defaultdict(int)
total_by_bucket: dict[int, int] = defaultdict(int)
range_total_tokens = 0
range_total_output_tokens = 0
range_total_calls = 0
range_success_calls = 0
range_ttft_total_ms = 0.0
range_ttft_samples = 0
range_duration_total_ms = 0.0
range_duration_samples = 0
today_by_model: dict[str, int] = defaultdict(int)
today_by_provider: dict[str, int] = defaultdict(int)
today_total_tokens = 0
today_total_calls = 0
for record in records:
created_at_utc = self._ensure_aware_utc(record.created_at)
created_at_local = created_at_utc.astimezone(local_tz)
token_total = (
record.token_input_other
+ record.token_input_cached
+ record.token_output
)
provider_id = record.provider_id or "unknown"
provider_model = record.provider_model or "Unknown"
if created_at_local >= range_start_local:
bucket_local = created_at_local.replace(
minute=0, second=0, microsecond=0
)
bucket_ts = int(bucket_local.timestamp() * 1000)
trend_by_provider[provider_id][bucket_ts] += token_total
total_by_provider[provider_id] += token_total
total_by_umo[record.umo or "unknown"] += token_total
total_by_bucket[bucket_ts] += token_total
range_total_tokens += token_total
range_total_calls += 1
if record.status != "error":
range_success_calls += 1
if record.time_to_first_token > 0:
range_ttft_total_ms += record.time_to_first_token * 1000
range_ttft_samples += 1
if record.end_time > record.start_time:
range_duration_total_ms += (
record.end_time - record.start_time
) * 1000
range_duration_samples += 1
range_total_output_tokens += record.token_output
if created_at_local >= today_start_local:
today_total_calls += 1
today_total_tokens += token_total
today_by_model[provider_model] += token_total
today_by_provider[provider_id] += token_total
sorted_provider_ids = sorted(
total_by_provider.keys(),
key=lambda item: total_by_provider[item],
reverse=True,
)
series = [
{
"name": provider_id,
"data": [
[bucket_ts, trend_by_provider[provider_id].get(bucket_ts, 0)]
for bucket_ts in bucket_timestamps
],
"total_tokens": total_by_provider[provider_id],
}
for provider_id in sorted_provider_ids
]
total_series = [
[bucket_ts, total_by_bucket.get(bucket_ts, 0)]
for bucket_ts in bucket_timestamps
]
today_by_model_data = [
{"provider_model": model_name, "tokens": tokens}
for model_name, tokens in sorted(
today_by_model.items(),
key=lambda item: item[1],
reverse=True,
)
]
today_by_provider_data = [
{"provider_id": provider_id, "tokens": tokens}
for provider_id, tokens in sorted(
today_by_provider.items(),
key=lambda item: item[1],
reverse=True,
)
]
range_by_provider_data = [
{"provider_id": provider_id, "tokens": tokens}
for provider_id, tokens in sorted(
total_by_provider.items(),
key=lambda item: item[1],
reverse=True,
)
]
range_by_umo_data = [
{"umo": umo, "tokens": tokens}
for umo, tokens in sorted(
total_by_umo.items(),
key=lambda item: item[1],
reverse=True,
)
]
return {
"days": days,
"trend": {
"series": series,
"total_series": total_series,
},
"range_total_tokens": range_total_tokens,
"range_total_calls": range_total_calls,
"range_avg_ttft_ms": (
range_ttft_total_ms / range_ttft_samples
if range_ttft_samples
else 0
),
"range_avg_duration_ms": (
range_duration_total_ms / range_duration_samples
if range_duration_samples
else 0
),
"range_avg_tpm": (
range_total_output_tokens / (range_duration_total_ms / 1000 / 60)
if range_duration_total_ms > 0
else 0
),
"range_success_rate": (
range_success_calls / range_total_calls if range_total_calls else 0
),
"range_by_provider": range_by_provider_data,
"range_by_umo": range_by_umo_data,
"today_total_tokens": today_total_tokens,
"today_total_calls": today_total_calls,
"today_by_model": today_by_model_data,
"today_by_provider": today_by_provider_data,
}
except Exception as exc:
logger.error(traceback.format_exc())
raise StatServiceError(f"Error: {exc!s}") from exc
async def get_provider_token_stats_from_legacy_query(self, days_raw) -> dict:
try:
days = int(days_raw if days_raw is not None else 1)
except (TypeError, ValueError):
days = 1
return await self.get_provider_token_stats(days)
async def test_ghproxy_connection(self, proxy_url: str | None) -> dict:
try:
if not proxy_url:
raise StatServiceError("proxy_url is required")
proxy_url = proxy_url.rstrip("/")
test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
start_time = time.time()
async with (
aiohttp.ClientSession() as session,
session.get(
test_url,
timeout=aiohttp.ClientTimeout(total=10),
) as response,
):
if response.status == 200:
end_time = time.time()
_ = await response.text()
return {
"latency": round((end_time - start_time) * 1000, 2),
}
raise StatServiceError(f"Failed. Status code: {response.status}")
except StatServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise StatServiceError(f"Error: {exc!s}") from exc
async def test_ghproxy_connection_from_legacy_payload(
self,
payload: object,
) -> dict:
proxy_url = payload.get("proxy_url") if isinstance(payload, dict) else None
return await self.test_ghproxy_connection(proxy_url)
def get_changelog(self, version: str | None) -> dict:
try:
if not version:
raise StatServiceError("version parameter is required")
version = version.lstrip("v")
if not re.match(r"^[a-zA-Z0-9._-]+$", version):
raise StatServiceError("Invalid version format")
if ".." in version or "/" in version or "\\" in version:
raise StatServiceError("Invalid version format")
changelogs_dir = (Path(get_astrbot_path()) / "changelogs").resolve()
changelog_path = (changelogs_dir / f"v{version}.md").resolve(strict=False)
if not changelog_path.is_relative_to(changelogs_dir):
logger.warning(
"Path traversal attempt detected: %s -> %s",
version,
changelog_path,
)
raise StatServiceError("Invalid version format")
if not changelog_path.is_file():
raise StatServiceError(f"Changelog for version {version} not found")
content = changelog_path.read_text(encoding="utf-8")
return {"content": content, "version": version}
except StatServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise StatServiceError(f"Error: {exc!s}") from exc
def list_changelog_versions(self) -> dict:
try:
changelogs_dir = Path(get_astrbot_path()) / "changelogs"
if not changelogs_dir.exists():
return {"versions": []}
versions = []
for path in changelogs_dir.iterdir():
filename = path.name
if filename.endswith(".md") and filename.startswith("v"):
version = filename[1:-3]
if re.match(r"^[a-zA-Z0-9._-]+$", version):
versions.append(version)
versions.sort(
key=cmp_to_key(
lambda v1, v2: VersionComparator.compare_version(v2, v1),
),
)
return {"versions": versions}
except Exception as exc:
logger.error(traceback.format_exc())
raise StatServiceError(f"Error: {exc!s}") from exc
def get_first_notice(self, locale: str | None) -> dict:
try:
locale = (locale or "").strip()
if not re.match(r"^[A-Za-z0-9_-]*$", locale):
locale = ""
base_path = Path(get_astrbot_path())
candidates: list[Path] = []
if locale:
candidates.append(base_path / f"FIRST_NOTICE.{locale}.md")
if locale.lower().startswith("zh"):
candidates.append(base_path / "FIRST_NOTICE.md")
candidates.append(base_path / "FIRST_NOTICE.zh-CN.md")
elif locale.lower().startswith("en"):
candidates.append(base_path / "FIRST_NOTICE.en-US.md")
candidates.extend(
[
base_path / "FIRST_NOTICE.md",
base_path / "FIRST_NOTICE.en-US.md",
],
)
for notice_path in candidates:
if not notice_path.is_file():
continue
content = notice_path.read_text(encoding="utf-8")
if content.strip():
return {"content": content}
return {"content": None}
except Exception as exc:
logger.error(traceback.format_exc())
raise StatServiceError(f"Error: {exc!s}") from exc

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
class StaticFileService:
INDEX_ROUTES = (
"/",
"/auth/login",
"/config",
"/logs",
"/extension",
"/dashboard/default",
"/alkaid",
"/alkaid/knowledge-base",
"/alkaid/long-term-memory",
"/alkaid/other",
"/console",
"/chat",
"/settings",
"/platforms",
"/providers",
"/about",
"/extension-marketplace",
"/conversation",
"/tool-use",
)
NOT_FOUND_MESSAGE = (
"404 Not found。如果你初次使用打开面板发现 404, 请参考文档: "
"https://docs.astrbot.app/faq.html。如果你正在测试回调地址可达性"
"显示这段文字说明测试成功了。"
)
def list_index_routes(self) -> tuple[str, ...]:
return self.INDEX_ROUTES
def get_not_found_message(self) -> str:
return self.NOT_FOUND_MESSAGE

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import traceback
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class SubAgentServiceError(Exception):
pass
class SubAgentService:
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
self.core_lifecycle = core_lifecycle
def get_config(self) -> dict:
try:
config_data = self.core_lifecycle.astrbot_config.get(
"subagent_orchestrator"
)
return self._normalize_config(config_data)
except Exception as exc:
logger.error(traceback.format_exc())
raise SubAgentServiceError(f"获取 subagent 配置失败: {exc!s}") from exc
async def update_config(self, data: object) -> None:
try:
if not isinstance(data, dict):
raise SubAgentServiceError("配置必须为 JSON 对象")
config = self.core_lifecycle.astrbot_config
config["subagent_orchestrator"] = data
config.save_config()
orchestrator = getattr(self.core_lifecycle, "subagent_orchestrator", None)
if orchestrator is not None:
await orchestrator.reload_from_config(data)
except SubAgentServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise SubAgentServiceError(f"保存 subagent 配置失败: {exc!s}") from exc
def get_available_tools(self) -> list[dict]:
try:
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools = []
for tool in tool_mgr.func_list:
if self._is_subagent_internal_tool(tool):
continue
tools.append(
{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
"active": tool.active,
"handler_module_path": tool.handler_module_path,
}
)
return tools
except Exception as exc:
logger.error(traceback.format_exc())
raise SubAgentServiceError(f"获取可用工具失败: {exc!s}") from exc
@staticmethod
def _normalize_config(data: object) -> dict:
if not isinstance(data, dict):
data = {
"main_enable": False,
"remove_main_duplicate_tools": False,
"agents": [],
}
if "main_enable" not in data and "enable" in data:
data["main_enable"] = bool(data.get("enable", False))
data.setdefault("main_enable", False)
data.setdefault("remove_main_duplicate_tools", False)
data.setdefault("agents", [])
if isinstance(data.get("agents"), list):
for agent in data["agents"]:
if isinstance(agent, dict):
agent.setdefault("provider_id", None)
agent.setdefault("persona_id", None)
return data
@staticmethod
def _is_subagent_internal_tool(tool) -> bool:
return (
isinstance(tool, HandoffTool)
or tool.handler_module_path == "core.subagent_orchestrator"
)

View File

@@ -0,0 +1,150 @@
from __future__ import annotations
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.utils.t2i.template_manager import TemplateManager
class T2iServiceError(Exception):
def __init__(self, message: str, status_code: int = 500) -> None:
super().__init__(message)
self.status_code = status_code
class T2iService:
def __init__(
self,
core_lifecycle: AstrBotCoreLifecycle,
manager: TemplateManager | None = None,
) -> None:
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.manager = manager or TemplateManager()
async def reload_all_pipeline_schedulers(self) -> None:
for conf_id in self.core_lifecycle.astrbot_config_mgr.confs:
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
async def sync_active_template_to_all_configs(self, name: str) -> None:
for config in self.core_lifecycle.astrbot_config_mgr.confs.values():
config["t2i_active_template"] = name
config.save_config()
await self.reload_all_pipeline_schedulers()
def list_templates(self):
try:
return self.manager.list_templates()
except Exception as exc:
raise T2iServiceError(str(exc)) from exc
def get_active_template(self) -> dict:
try:
return {"active_template": self.config.get("t2i_active_template", "base")}
except Exception as exc:
logger.error("Error in get_active_template", exc_info=True)
raise T2iServiceError(str(exc)) from exc
def get_template(self, name: str) -> dict:
try:
return {"name": name, "content": self.manager.get_template(name)}
except FileNotFoundError as exc:
raise T2iServiceError("Template not found", 404) from exc
except Exception as exc:
raise T2iServiceError(str(exc)) from exc
def create_template(self, name: str | None, content: str | None) -> dict:
if not name or not content:
raise T2iServiceError("Name and content are required.", 400)
name = name.strip()
try:
self.manager.create_template(name, content)
except FileExistsError as exc:
raise T2iServiceError(
"Template with this name already exists.",
409,
) from exc
except ValueError as exc:
raise T2iServiceError(str(exc), 400) from exc
except Exception as exc:
raise T2iServiceError(str(exc)) from exc
return {"name": name}
def create_template_from_legacy_payload(self, data: object) -> dict:
payload = self._payload(data)
return self.create_template(payload.get("name"), payload.get("content"))
async def update_template(self, name: str, content: str | None) -> tuple[dict, str]:
name = name.strip()
if content is None:
raise T2iServiceError("Content is required.", 400)
try:
self.manager.update_template(name, content)
active_template = self.config.get("t2i_active_template", "base")
if name == active_template:
await self.reload_all_pipeline_schedulers()
message = f"模板 '{name}' 已更新并重新加载。"
else:
message = f"模板 '{name}' 已更新。"
except ValueError as exc:
raise T2iServiceError(str(exc), 400) from exc
except Exception as exc:
raise T2iServiceError(str(exc)) from exc
return {"name": name}, message
async def update_template_from_legacy_payload(
self,
name: str,
data: object,
) -> tuple[dict, str]:
payload = self._payload(data)
return await self.update_template(name, payload.get("content"))
def delete_template(self, name: str) -> None:
name = name.strip()
try:
self.manager.delete_template(name)
except FileNotFoundError as exc:
raise T2iServiceError("Template not found.", 404) from exc
except ValueError as exc:
raise T2iServiceError(str(exc), 400) from exc
except Exception as exc:
raise T2iServiceError(str(exc)) from exc
async def set_active_template(self, name: str | None) -> str:
if not name:
raise T2iServiceError("模板名称(name)不能为空。", 400)
try:
self.manager.get_template(name)
await self.sync_active_template_to_all_configs(name)
except FileNotFoundError as exc:
raise T2iServiceError(f"模板 '{name}' 不存在,无法应用。", 404) from exc
except Exception as exc:
logger.error("Error in set_active_template", exc_info=True)
raise T2iServiceError(str(exc)) from exc
return f"模板 '{name}' 已成功应用。"
async def set_active_template_from_legacy_payload(self, data: object) -> str:
payload = self._payload(data)
return await self.set_active_template(payload.get("name"))
async def reset_default_template(self) -> str:
try:
self.manager.reset_default_template()
await self.sync_active_template_to_all_configs("base")
except FileNotFoundError as exc:
raise T2iServiceError(str(exc), 404) from exc
except Exception as exc:
logger.error("Error in reset_default_template", exc_info=True)
raise T2iServiceError(str(exc)) from exc
return "Default template has been reset and activated."
@staticmethod
def _payload(data: object) -> dict:
return data if isinstance(data, dict) else {}

View File

@@ -0,0 +1,492 @@
from __future__ import annotations
import traceback
from typing import Any
from astrbot.core import logger
from astrbot.core.agent.mcp_client import MCPTool, validate_mcp_stdio_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star import star_map
from astrbot.core.tools.registry import get_builtin_tool_config_statuses
class ToolsServiceError(Exception):
pass
class EmptyMcpServersError(ValueError):
pass
def extract_mcp_server_config(mcp_servers_value: object) -> dict:
if not isinstance(mcp_servers_value, dict):
raise ValueError("mcpServers must be a JSON object")
if not mcp_servers_value:
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
key_0 = next(iter(mcp_servers_value))
extracted = mcp_servers_value[key_0]
if not isinstance(extracted, dict):
raise ValueError(
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
"and each value is an object containing fields like command/url."
)
return extracted
class ToolsService:
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
self.core_lifecycle = core_lifecycle
self.tool_mgr = core_lifecycle.provider_manager.llm_tools
def rollback_mcp_server(self, name: str) -> bool:
try:
rollback_config = self.tool_mgr.load_mcp_config()
if name in rollback_config["mcpServers"]:
rollback_config["mcpServers"].pop(name)
return self.tool_mgr.save_mcp_config(rollback_config)
return True
except Exception:
logger.error(traceback.format_exc())
return False
def get_mcp_servers(self) -> list[dict]:
try:
config = self.tool_mgr.load_mcp_config()
servers = []
mcp_servers = config.get("mcpServers", {})
if not isinstance(mcp_servers, dict):
logger.warning(
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
)
mcp_servers = {}
for name, server_config in mcp_servers.items():
if not isinstance(server_config, dict):
logger.warning(
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
)
continue
server_info = {
"name": name,
"active": server_config.get("active", True),
}
for key, value in server_config.items():
if key != "active":
server_info[key] = value
for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items():
if name_key == name:
mcp_client = runtime.client
server_info["tools"] = [tool.name for tool in mcp_client.tools]
server_info["errlogs"] = mcp_client.server_errlogs
break
else:
server_info["tools"] = []
servers.append(server_info)
return servers
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"Failed to get MCP server list: {exc!s}") from exc
async def add_mcp_server(self, server_data: Any) -> str:
try:
name = server_data.get("name", "")
if not name:
raise ToolsServiceError("Server name cannot be empty")
has_valid_config, server_config = self._build_server_config(server_data)
if not has_valid_config:
raise ToolsServiceError("A valid server configuration is required")
self._validate_server_config(server_config)
config = self.tool_mgr.load_mcp_config()
if name in config["mcpServers"]:
raise ToolsServiceError(f"Server {name} already exists")
try:
await self.tool_mgr.test_mcp_server_connection(server_config)
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"MCP connection test failed: {exc!s}") from exc
config["mcpServers"][name] = server_config
if self.tool_mgr.save_mcp_config(config):
await self._enable_added_server(name, server_config)
return f"Successfully added MCP server {name}"
raise ToolsServiceError("Failed to save configuration")
except ToolsServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"Failed to add MCP server: {exc!s}") from exc
async def update_mcp_server(self, server_data: Any) -> str:
try:
name = server_data.get("name", "")
old_name = server_data.get("oldName") or name
if not name:
raise ToolsServiceError("Server name cannot be empty")
config = self.tool_mgr.load_mcp_config()
if old_name not in config["mcpServers"]:
raise ToolsServiceError(f"Server {old_name} does not exist")
is_rename = name != old_name
if name in config["mcpServers"] and is_rename:
raise ToolsServiceError(f"Server {name} already exists")
old_config = config["mcpServers"][old_name]
old_active = (
old_config.get("active", True) if isinstance(old_config, dict) else True
)
active = server_data.get("active", old_active)
only_update_active, server_config = self._build_updated_server_config(
server_data,
old_config,
active,
)
self._validate_server_config(server_config)
if is_rename:
config["mcpServers"].pop(old_name)
config["mcpServers"][name] = server_config
else:
config["mcpServers"][name] = server_config
if self.tool_mgr.save_mcp_config(config):
await self._sync_updated_server_runtime(
name=name,
old_name=old_name,
active=active,
is_rename=is_rename,
only_update_active=only_update_active,
server_config=config["mcpServers"][name],
)
return f"Successfully updated MCP server {name}"
raise ToolsServiceError("Failed to save configuration")
except ToolsServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"Failed to update MCP server: {exc!s}") from exc
async def delete_mcp_server(self, server_data: Any) -> str:
try:
name = server_data.get("name", "")
if not name:
raise ToolsServiceError("Server name cannot be empty")
config = self.tool_mgr.load_mcp_config()
if name not in config["mcpServers"]:
raise ToolsServiceError(f"Server {name} does not exist")
del config["mcpServers"][name]
if self.tool_mgr.save_mcp_config(config):
if name in self.tool_mgr.mcp_server_runtime_view:
await self._disable_server(name)
return f"Successfully deleted MCP server {name}"
raise ToolsServiceError("Failed to save configuration")
except ToolsServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"Failed to delete MCP server: {exc!s}") from exc
async def test_mcp_connection(self, server_data: Any) -> list:
try:
config = server_data.get("mcp_server_config", None)
if not isinstance(config, dict) or not config:
raise ToolsServiceError("Invalid MCP server configuration")
if "mcpServers" in config:
mcp_servers = config["mcpServers"]
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
raise ToolsServiceError(
"Only one MCP server configuration can be tested at a time"
)
try:
config = extract_mcp_server_config(mcp_servers)
except EmptyMcpServersError as exc:
raise ToolsServiceError(
"MCP server configuration cannot be empty"
) from exc
except ValueError as exc:
raise ToolsServiceError(f"{exc!s}") from exc
elif not config:
raise ToolsServiceError("MCP server configuration cannot be empty")
self._validate_server_config(config)
return await self.tool_mgr.test_mcp_server_connection(config)
except ToolsServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"Failed to test MCP connection: {exc!s}") from exc
def get_tool_list(self) -> list[dict]:
try:
tools = list(self.tool_mgr.func_list)
existing_names = {tool.name for tool in tools}
for tool in self.tool_mgr.iter_builtin_tools():
if tool.name not in existing_names:
tools.append(tool)
config_entries = self._get_config_entries()
tools_dict = []
for tool in tools:
tools_dict.append(self._serialize_tool(tool, config_entries))
return tools_dict
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"Failed to get tool list: {exc!s}") from exc
def toggle_tool(self, data: Any) -> str:
try:
tool_name = data.get("name")
action = data.get("activate")
if not tool_name or action is None:
raise ToolsServiceError("Missing required parameters: name or activate")
if self.tool_mgr.is_builtin_tool(tool_name):
raise ToolsServiceError(
"Builtin tools are read-only and cannot be toggled."
)
if action:
try:
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
except ValueError as exc:
raise ToolsServiceError(
f"Failed to activate tool: {exc!s}"
) from exc
else:
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
if ok:
return "Operation successful."
raise ToolsServiceError(
f"Tool {tool_name} does not exist or the operation failed."
)
except ToolsServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"Failed to operate tool: {exc!s}") from exc
async def sync_provider(self, data: Any) -> str:
try:
provider_name = data.get("name")
match provider_name:
case "modelscope":
access_token = data.get("access_token", "")
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
case _:
raise ToolsServiceError(f"Unknown provider: {provider_name}")
return "Sync completed"
except ToolsServiceError:
raise
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(f"Sync failed: {exc!s}") from exc
@staticmethod
def _build_server_config(server_data: dict) -> tuple[bool, dict]:
has_valid_config = False
server_config = {"active": server_data.get("active", True)}
for key, value in server_data.items():
if key in ["name", "active", "tools", "errlogs"]:
continue
if key == "mcpServers":
try:
server_config = extract_mcp_server_config(server_data["mcpServers"])
except ValueError as exc:
raise ToolsServiceError(f"{exc!s}") from exc
else:
server_config[key] = value
has_valid_config = True
return has_valid_config, server_config
@staticmethod
def _build_updated_server_config(
server_data: dict,
old_config: object,
active: bool,
) -> tuple[bool, dict]:
server_config = {"active": active}
only_update_active = True
for key, value in server_data.items():
if key in ["name", "active", "tools", "errlogs", "oldName"]:
continue
if key == "mcpServers":
try:
server_config = extract_mcp_server_config(server_data["mcpServers"])
except ValueError as exc:
raise ToolsServiceError(f"{exc!s}") from exc
else:
server_config[key] = value
only_update_active = False
if only_update_active and isinstance(old_config, dict):
for key, value in old_config.items():
if key != "active":
server_config[key] = value
return only_update_active, server_config
@staticmethod
def _validate_server_config(server_config: dict) -> None:
try:
validate_mcp_stdio_config(server_config)
except ValueError as exc:
raise ToolsServiceError(f"{exc!s}") from exc
async def _enable_added_server(self, name: str, server_config: dict) -> None:
try:
await self.tool_mgr.enable_mcp_server(name, server_config, timeout=30)
except TimeoutError as exc:
rollback_ok = self.rollback_mcp_server(name)
err_msg = f"Timed out while enabling MCP server {name}."
if not rollback_ok:
err_msg += (
" Configuration rollback failed. Please check the config manually."
)
raise ToolsServiceError(err_msg) from exc
except Exception as exc:
logger.error(traceback.format_exc())
rollback_ok = self.rollback_mcp_server(name)
err_msg = f"Failed to enable MCP server {name}: {exc!s}"
if not rollback_ok:
err_msg += (
" Configuration rollback failed. Please check the config manually."
)
raise ToolsServiceError(err_msg) from exc
async def _sync_updated_server_runtime(
self,
*,
name: str,
old_name: str,
active: bool,
is_rename: bool,
only_update_active: bool,
server_config: dict,
) -> None:
if active:
if (
old_name in self.tool_mgr.mcp_server_runtime_view
or not only_update_active
or is_rename
):
await self._disable_server_before_enable(old_name)
await self._enable_updated_server(name, server_config)
elif old_name in self.tool_mgr.mcp_server_runtime_view:
await self._disable_server(old_name)
async def _disable_server_before_enable(self, old_name: str) -> None:
try:
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
except TimeoutError as exc:
raise ToolsServiceError(
f"Timed out while disabling MCP server {old_name} before enabling: {exc!s}"
) from exc
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(
f"Failed to disable MCP server {old_name} before enabling: {exc!s}"
) from exc
async def _enable_updated_server(self, name: str, server_config: dict) -> None:
try:
await self.tool_mgr.enable_mcp_server(name, server_config, timeout=30)
except TimeoutError as exc:
raise ToolsServiceError(
f"Timed out while enabling MCP server {name}."
) from exc
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(
f"Failed to enable MCP server {name}: {exc!s}"
) from exc
async def _disable_server(self, name: str) -> None:
try:
await self.tool_mgr.disable_mcp_server(name, timeout=10)
except TimeoutError as exc:
raise ToolsServiceError(
f"Timed out while disabling MCP server {name}."
) from exc
except Exception as exc:
logger.error(traceback.format_exc())
raise ToolsServiceError(
f"Failed to disable MCP server {name}: {exc!s}"
) from exc
def _get_config_entries(self) -> list[dict]:
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
conf_name_map = {conf["id"]: conf["name"] for conf in conf_list}
config_entries = []
for conf_id, conf in self.core_lifecycle.astrbot_config_mgr.confs.items():
config_entries.append(
{
"conf_id": conf_id,
"conf_name": conf_name_map.get(conf_id, conf_id),
"config": conf,
}
)
return config_entries
def _serialize_tool(self, tool, config_entries: list[dict]) -> dict:
readonly = False
builtin_config_statuses = []
builtin_config_tags = []
if self.tool_mgr.is_builtin_tool(tool.name):
origin = "builtin"
origin_name = "AstrBot Core"
readonly = True
builtin_config_statuses = get_builtin_tool_config_statuses(
tool.name,
config_entries,
)
builtin_config_tags = [
status for status in builtin_config_statuses if status["enabled"]
]
elif isinstance(tool, MCPTool):
origin = "mcp"
origin_name = tool.mcp_server_name
elif tool.handler_module_path and star_map.get(tool.handler_module_path):
star = star_map[tool.handler_module_path]
origin = "plugin"
origin_name = star.name
else:
origin = "unknown"
origin_name = "unknown"
return {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
"active": tool.active,
"origin": origin,
"origin_name": origin_name,
"readonly": readonly,
"builtin_config_statuses": builtin_config_statuses,
"builtin_config_tags": builtin_config_tags,
}

View File

@@ -0,0 +1,412 @@
from __future__ import annotations
import traceback
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from astrbot.core import (
DEMO_MODE as _DEMO_MODE,
)
from astrbot.core import (
logger,
)
from astrbot.core import (
pip_installer as _pip_installer,
)
from astrbot.core.config.default import VERSION
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.migration.helper import (
check_migration_needed_v4 as _check_migration_needed_v4,
)
from astrbot.core.db.migration.helper import (
do_migration_v4 as _do_migration_v4,
)
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.io import (
download_dashboard as _download_dashboard,
)
from astrbot.core.utils.io import (
get_dashboard_version as _get_dashboard_version,
)
DEMO_MODE = _DEMO_MODE
pip_installer = _pip_installer
download_dashboard = _download_dashboard
get_dashboard_version = _get_dashboard_version
default_check_migration_needed_v4 = _check_migration_needed_v4
default_do_migration_v4 = _do_migration_v4
async def call_download_dashboard(*args, **kwargs):
return await download_dashboard(*args, **kwargs)
async def call_get_dashboard_version(*args, **kwargs):
return await get_dashboard_version(*args, **kwargs)
async def call_pip_install(*args, **kwargs):
return await pip_installer.install(*args, **kwargs)
async def call_check_migration_needed_v4(*args, **kwargs):
return await default_check_migration_needed_v4(*args, **kwargs)
async def call_do_migration_v4(*args, **kwargs):
return await default_do_migration_v4(*args, **kwargs)
@dataclass
class UpdateServiceResult:
data: Any = None
message: str | None = None
status: str = "ok"
headers: dict | None = None
class UpdateServiceError(Exception):
pass
class UpdateService:
def __init__(
self,
astrbot_updator: AstrBotUpdator,
core_lifecycle: AstrBotCoreLifecycle,
*,
download_dashboard_func: Callable[..., Awaitable[Any]],
get_dashboard_version_func: Callable[..., Awaitable[str]],
pip_install_func: Callable[..., Awaitable[Any]],
check_migration_needed_func: Callable[..., Awaitable[bool]],
do_migration_func: Callable[..., Awaitable[Any]],
demo_mode: bool,
clear_site_data_headers: dict,
) -> None:
self.astrbot_updator = astrbot_updator
self.core_lifecycle = core_lifecycle
self.download_dashboard = download_dashboard_func
self.get_dashboard_version = get_dashboard_version_func
self.pip_install = pip_install_func
self.check_migration_needed = check_migration_needed_func
self.do_migration = do_migration_func
self.demo_mode = demo_mode
self.clear_site_data_headers = clear_site_data_headers
self.update_progress: dict[str, dict] = {}
def get_update_progress(self, progress_id: str) -> UpdateServiceResult:
if not progress_id:
raise UpdateServiceError("缺少参数 id。")
progress = self.update_progress.get(progress_id)
if not progress:
return UpdateServiceResult(
data={"id": progress_id, "status": "idle"},
message="没有正在进行的更新。",
)
return UpdateServiceResult(data=progress)
def get_update_progress_from_legacy_query(
self,
progress_id: str | None,
) -> UpdateServiceResult:
return self.get_update_progress(progress_id or "")
async def do_migration_v4(self, data: object) -> UpdateServiceResult:
need_migration = await self.check_migration_needed(self.core_lifecycle.db)
if not need_migration:
return UpdateServiceResult(message="不需要进行迁移。")
try:
payload = data if isinstance(data, dict) else {}
platform_id_map = payload.get("platform_id_map", {})
await self.do_migration(
self.core_lifecycle.db,
platform_id_map,
self.core_lifecycle.astrbot_config,
)
return UpdateServiceResult(message="迁移成功。")
except Exception as exc:
logger.error(f"迁移失败: {traceback.format_exc()}")
raise UpdateServiceError(f"迁移失败: {exc!s}") from exc
async def check_update(self, update_type: str | None) -> UpdateServiceResult:
try:
dashboard_version = await self.get_dashboard_version()
if update_type == "dashboard":
return UpdateServiceResult(
data={
"has_new_version": dashboard_version != f"v{VERSION}",
"current_version": dashboard_version,
}
)
update_result = await self.astrbot_updator.check_update(None, None, False)
return UpdateServiceResult(
status="success",
message=str(update_result)
if update_result is not None
else "已经是最新版本了。",
data={
"version": f"v{VERSION}",
"has_new_version": update_result is not None,
"dashboard_version": dashboard_version,
"dashboard_has_new_version": bool(
dashboard_version and dashboard_version != f"v{VERSION}"
),
},
)
except Exception as exc:
logger.warning(f"检查更新失败: {exc!s} (不影响除项目更新外的正常使用)")
raise UpdateServiceError(exc.__str__()) from exc
async def check_update_from_legacy_query(
self,
update_type: str | None,
) -> UpdateServiceResult:
return await self.check_update(update_type)
async def get_releases(self) -> UpdateServiceResult:
try:
releases = await self.astrbot_updator.get_releases()
return UpdateServiceResult(data=releases)
except Exception as exc:
logger.error(f"/api/update/releases: {traceback.format_exc()}")
raise UpdateServiceError(exc.__str__()) from exc
async def update_project(self, data: object) -> UpdateServiceResult:
payload = data if isinstance(data, dict) else {}
version = payload.get("version", "")
reboot = payload.get("reboot", True)
progress_id = payload.get("progress_id") or uuid.uuid4().hex
if version == "" or version == "latest":
latest = True
version = ""
else:
latest = False
proxy: str = payload.get("proxy", None)
if proxy:
proxy = proxy.removesuffix("/")
self._init_update_progress(progress_id, version)
try:
self._set_update_stage(
progress_id,
"dashboard",
"running",
"正在下载 WebUI...",
0,
)
await self.download_dashboard(
latest=latest,
version=version,
proxy=proxy,
progress_callback=self._make_progress_callback(
progress_id,
"dashboard",
0,
45,
),
)
self._set_update_stage(
progress_id,
"dashboard",
"done",
"WebUI 下载完成。",
45,
)
self._set_update_stage(
progress_id,
"core",
"running",
"正在下载 AstrBot 项目代码...",
45,
)
await self.astrbot_updator.update(
latest=latest,
version=version,
proxy=proxy,
progress_callback=self._make_progress_callback(
progress_id,
"core",
45,
45,
),
)
self._set_update_stage(
progress_id,
"core",
"done",
"项目代码下载完成。",
90,
)
self._set_update_stage(
progress_id,
"dependencies",
"running",
"正在更新依赖...",
92,
)
logger.info("更新依赖中...")
try:
await self.pip_install(requirements_path="requirements.txt")
except Exception as exc:
logger.error(f"更新依赖失败: {exc}")
self._set_update_stage(
progress_id,
"dependencies",
"done",
"依赖更新完成。",
96,
)
if reboot:
self._set_update_stage(
progress_id,
"restart",
"running",
"更新成功,正在准备重启...",
98,
)
await self.core_lifecycle.restart()
message = "更新成功AstrBot 将在 2 秒内全量重启以应用新的代码。"
else:
message = "更新成功AstrBot 将在下次启动时应用新的代码。"
self.update_progress[progress_id].update(
{
"status": "success",
"stage": "done",
"message": message,
"overall_percent": 100,
},
)
return UpdateServiceResult(
message=message,
headers=self.clear_site_data_headers,
)
except Exception as exc:
self.update_progress[progress_id].update(
{
"status": "error",
"message": exc.__str__(),
},
)
logger.error(f"/api/update_project: {traceback.format_exc()}")
raise UpdateServiceError(exc.__str__()) from exc
async def update_dashboard(self) -> UpdateServiceResult:
try:
try:
await self.download_dashboard(version=f"v{VERSION}", latest=False)
except Exception as exc:
logger.error(f"下载管理面板文件失败: {exc}")
raise UpdateServiceError(f"下载管理面板文件失败: {exc}") from exc
return UpdateServiceResult(
message="更新成功。刷新页面即可应用新版本面板。",
headers=self.clear_site_data_headers,
)
except UpdateServiceError:
raise
except Exception as exc:
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
raise UpdateServiceError(exc.__str__()) from exc
async def install_pip_package(self, data: object) -> UpdateServiceResult:
if self.demo_mode:
raise UpdateServiceError(
"You are not permitted to do this operation in demo mode"
)
payload = data if isinstance(data, dict) else {}
package = payload.get("package", "")
mirror = payload.get("mirror", None)
if not package:
raise UpdateServiceError("缺少参数 package 或不合法。")
try:
await self.pip_install(package, mirror=mirror)
return UpdateServiceResult(message="安装成功。")
except Exception as exc:
logger.error(f"/api/update_pip: {traceback.format_exc()}")
raise UpdateServiceError(exc.__str__()) from exc
def _init_update_progress(self, progress_id: str, version: str) -> None:
self.update_progress[progress_id] = {
"id": progress_id,
"status": "running",
"stage": "preparing",
"version": version or "latest",
"message": "正在准备更新...",
"overall_percent": 0,
"stages": {
"dashboard": self._empty_stage("pending"),
"core": self._empty_stage("pending"),
},
}
@staticmethod
def _empty_stage(status: str = "pending") -> dict:
return {
"status": status,
"downloaded": 0,
"total": 0,
"percent": 0,
"speed": 0,
}
def _set_update_stage(
self,
progress_id: str,
stage: str,
status: str,
message: str,
overall_percent: int | None = None,
) -> None:
progress = self.update_progress.get(progress_id)
if not progress:
return
progress["stage"] = stage
progress["message"] = message
progress["stages"].setdefault(stage, self._empty_stage())
progress["stages"][stage]["status"] = status
if overall_percent is not None:
progress["overall_percent"] = overall_percent
@staticmethod
def _normalize_percent(value) -> int:
try:
percent = float(value or 0)
except (TypeError, ValueError):
return 0
if percent <= 1:
percent *= 100
return max(0, min(100, int(percent)))
def _make_progress_callback(
self,
progress_id: str,
stage: str,
stage_start: int,
stage_weight: int,
):
def _callback(payload: dict) -> None:
progress = self.update_progress.get(progress_id)
if not progress:
return
stage_percent = self._normalize_percent(payload.get("percent"))
progress["stage"] = stage
progress["stages"][stage] = {
"status": "running" if stage_percent < 100 else "done",
"downloaded": payload.get("downloaded", 0),
"total": payload.get("total", 0),
"percent": stage_percent,
"speed": payload.get("speed", 0),
}
progress["overall_percent"] = min(
99,
stage_start + int(stage_percent * stage_weight / 100),
)
return _callback

View File

@@ -0,0 +1 @@
"""FastAPI based OpenAPI v1 surface for AstrBot dashboard."""

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from types import SimpleNamespace
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.dashboard.services.config_service import (
BotConfigService,
ConfigProfileService,
ConfigRoutingService,
ProviderConfigService,
)
from astrbot.dashboard.services.open_api_service import OpenApiService
from astrbot.dashboard.services.plugin_service import PluginService
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from astrbot.dashboard.services.session_management_service import (
SessionManagementService,
)
from .compat_aliases import router as compat_alias_router
from .responses import ApiError, error
from .routers import build_v1_router
def create_v1_asgi_app(
*,
core_lifecycle: AstrBotCoreLifecycle,
db: BaseDatabase,
jwt_secret: str,
) -> FastAPI:
app = FastAPI(
title="AstrBot OpenAPI",
version="1.0.0",
openapi_url="/api/v1/openapi.json",
docs_url="/api/v1/docs",
redoc_url="/api/v1/redoc",
)
app.state.core_lifecycle = core_lifecycle
app.state.db = db
app.state.jwt_secret = jwt_secret
app.state.services = SimpleNamespace(
config_profiles=ConfigProfileService(core_lifecycle, db),
config_routes=ConfigRoutingService(core_lifecycle),
bots=BotConfigService(core_lifecycle),
providers=ProviderConfigService(core_lifecycle),
plugins=PluginService(core_lifecycle, core_lifecycle.plugin_manager),
open_api=OpenApiService(db, core_lifecycle),
sessions=SessionManagementService(core_lifecycle, db),
route_bridge=None,
)
app.state.services.route_bridge = DashboardRouteBridgeService(app, jwt_secret)
@app.exception_handler(ApiError)
async def api_error_handler(_request: Request, exc: ApiError):
return JSONResponse(
error(exc.message, exc.data),
status_code=exc.status_code,
)
@app.exception_handler(ValueError)
async def value_error_handler(_request: Request, exc: ValueError):
return JSONResponse(error(str(exc)), status_code=400)
@app.exception_handler(HTTPException)
async def http_error_handler(_request: Request, exc: HTTPException):
detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
return JSONResponse(error(detail), status_code=exc.status_code)
app.include_router(compat_alias_router)
app.include_router(build_v1_router())
return app

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
from dataclasses import dataclass
import jwt
from fastapi import Request
from astrbot.dashboard.services.api_key_service import ApiKeyService
from astrbot.dashboard.services.auth_service import ALL_OPEN_API_SCOPES
from .responses import ApiError
@dataclass(frozen=True)
class AuthContext:
username: str
scopes: list[str]
api_key_id: str | None = None
via: str = "jwt"
def _extract_raw_api_key(request: Request) -> str | None:
if key := request.query_params.get("api_key"):
return key.strip()
if key := request.query_params.get("key"):
return key.strip()
if key := request.headers.get("X-API-Key"):
return key.strip()
auth_header = request.headers.get("Authorization", "").strip()
if auth_header.startswith("ApiKey "):
return auth_header.removeprefix("ApiKey ").strip()
return None
async def _require_api_key_scope(
request: Request,
raw_key: str,
scope: str,
) -> AuthContext:
key_hash = ApiKeyService.hash_key(raw_key)
api_key = await request.app.state.db.get_active_api_key_by_hash(key_hash)
if not api_key:
raise ApiError("Invalid API key", status_code=401)
scopes = (
api_key.scopes
if isinstance(api_key.scopes, list)
else list(ALL_OPEN_API_SCOPES)
)
if "*" not in scopes and scope not in scopes:
raise ApiError("Insufficient API key scope", status_code=403)
await request.app.state.db.touch_api_key(api_key.key_id)
return AuthContext(
username=f"api_key:{api_key.key_id}",
scopes=scopes,
api_key_id=api_key.key_id,
via="api_key",
)
async def require_scope(request: Request, scope: str) -> AuthContext:
raw_key = _extract_raw_api_key(request)
if raw_key:
return await _require_api_key_scope(request, raw_key, scope)
auth_header = request.headers.get("Authorization", "").strip()
if not auth_header.startswith("Bearer "):
raise ApiError("Missing API key", status_code=401)
token = auth_header.removeprefix("Bearer ").strip()
try:
payload = jwt.decode(
token,
request.app.state.jwt_secret,
algorithms=["HS256"],
)
except jwt.ExpiredSignatureError as exc:
raise ApiError("Token expired", status_code=401) from exc
except jwt.InvalidTokenError as exc:
try:
return await _require_api_key_scope(request, token, scope)
except ApiError as api_key_exc:
raise api_key_exc from exc
username = payload.get("username")
if not isinstance(username, str) or not username.strip():
raise ApiError("Invalid token", status_code=401)
return AuthContext(username=username, scopes=["*"], via="jwt")

View File

@@ -0,0 +1,424 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from astrbot.dashboard.services.config_service import (
BotConfigService,
ConfigProfileService,
ConfigRoutingService,
ProviderConfigService,
)
from .auth import AuthContext, require_scope
from .responses import error, ok
from .schemas import BotConfigRequest, ProviderConfigRequest, ProviderSourceRequest
router = APIRouter(
prefix="/api",
tags=["Compatibility Aliases"],
include_in_schema=False,
)
async def require_config_scope(request: Request) -> AuthContext:
return await require_scope(request, "config")
def get_config_profile_service(request: Request) -> ConfigProfileService:
return request.app.state.services.config_profiles
def get_config_routing_service(request: Request) -> ConfigRoutingService:
return request.app.state.services.config_routes
def get_bot_service(request: Request) -> BotConfigService:
return request.app.state.services.bots
def get_provider_service(request: Request) -> ProviderConfigService:
return request.app.state.services.providers
async def _json_or_empty(request: Request) -> dict:
try:
data = await request.json()
except Exception:
return {}
return data if isinstance(data, dict) else {}
def _compat_error(message: str):
return error(message)
@router.get("/config/default")
async def get_legacy_default_config(
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_config_profile_service),
):
return ok(service.get_profile_schema())
@router.get("/config/abconfs")
async def list_legacy_config_profiles(
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_config_profile_service),
):
return ok(service.list_profiles())
@router.post("/config/abconf/new")
async def create_legacy_config_profile(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_config_profile_service),
):
body = await _json_or_empty(request)
try:
return ok(
await service.create_profile(
body.get("name"),
body.get("config"),
),
"创建成功",
)
except ValueError as exc:
return _compat_error(str(exc))
@router.get("/config/abconf")
async def get_legacy_config_profile(
id: str | None = Query(default=None),
system_config: str = Query(default="0"),
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_config_profile_service),
):
if system_config.lower() == "1":
return ok(service.get_system_schema())
if not id:
return _compat_error("缺少配置文件 ID")
try:
return ok(service.get_profile(id))
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/abconf/delete")
async def delete_legacy_config_profile(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_config_profile_service),
):
body = await _json_or_empty(request)
config_id = body.get("id")
if not config_id:
return _compat_error("缺少配置文件 ID")
try:
service.delete_profile(str(config_id))
return ok(message="删除成功")
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/abconf/update")
async def rename_legacy_config_profile(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_config_profile_service),
):
body = await _json_or_empty(request)
config_id = body.get("id")
if not config_id:
return _compat_error("缺少配置文件 ID")
try:
service.rename_profile(str(config_id), body.get("name"))
return ok(message="更新成功")
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/astrbot/update")
async def update_legacy_astrbot_config(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_config_profile_service),
):
body = await _json_or_empty(request)
config = body.get("config")
config_id = body.get("conf_id")
if not isinstance(config, dict):
return _compat_error("Invalid config payload")
if not config_id:
return _compat_error("Config file None does not exist")
try:
message = await service.update_profile(
str(config_id),
config,
two_factor_code=request.headers.get("X-2FA-Code"),
)
return ok(message=message or "保存成功~")
except ValueError as exc:
return _compat_error(str(exc))
@router.get("/config/umo_abconf_routes")
async def get_legacy_config_routes(
_auth: AuthContext = Depends(require_config_scope),
service: ConfigRoutingService = Depends(get_config_routing_service),
):
return ok(service.list_routes())
@router.post("/config/umo_abconf_route/update_all")
async def update_legacy_config_routes(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigRoutingService = Depends(get_config_routing_service),
):
body = await _json_or_empty(request)
try:
await service.replace_routes(body)
except ValueError:
return _compat_error("缺少或错误的路由表数据")
return ok(message="更新成功")
@router.post("/config/umo_abconf_route/update")
async def upsert_legacy_config_route(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigRoutingService = Depends(get_config_routing_service),
):
body = await _json_or_empty(request)
try:
await service.upsert_route(body)
except ValueError:
return _compat_error("缺少 UMO 或配置文件 ID")
return ok(message="更新成功")
@router.post("/config/umo_abconf_route/delete")
async def delete_legacy_config_route(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigRoutingService = Depends(get_config_routing_service),
):
body = await _json_or_empty(request)
try:
await service.delete_route(body)
except ValueError:
return _compat_error("缺少 UMO")
return ok(message="删除成功")
@router.get("/config/platform/list")
async def list_legacy_platforms(
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_bot_service),
):
return ok({"platforms": service.list_bots()["bots"]})
@router.post("/config/platform/new")
async def create_legacy_platform(
payload: BotConfigRequest,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_bot_service),
):
try:
await service.create_bot(payload.to_legacy_config())
return ok(message="新增平台配置成功~")
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/platform/update")
async def update_legacy_platform(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_bot_service),
):
body = await _json_or_empty(request)
bot_id = body.get("id")
config = body.get("config")
if not bot_id or not isinstance(config, dict):
return _compat_error("参数错误")
try:
await service.update_bot(
str(bot_id),
BotConfigRequest(config=config).to_legacy_config(fallback_id=str(bot_id)),
)
return ok(message="更新平台配置成功~")
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/platform/delete")
async def delete_legacy_platform(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_bot_service),
):
body = await _json_or_empty(request)
bot_id = body.get("id")
if not bot_id:
return _compat_error("缺少参数 id")
try:
await service.delete_bot(str(bot_id))
return ok(message="删除平台配置成功~")
except ValueError as exc:
return _compat_error(str(exc))
@router.get("/config/provider/template")
async def get_legacy_provider_template(
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
return ok(service.get_provider_schema())
@router.get("/config/provider/list")
async def list_legacy_providers(
provider_type: str | None = Query(default=None),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
if not provider_type:
return _compat_error("缺少参数 provider_type")
providers = []
seen_ids = set()
for item in provider_type.split(","):
for provider in service.list_providers(capability=item)["providers"]:
provider_id = provider.get("id")
if provider_id in seen_ids:
continue
seen_ids.add(provider_id)
providers.append(provider)
return ok(providers)
@router.post("/config/provider/new")
async def create_legacy_provider(
payload: ProviderConfigRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
try:
await service.create_provider(payload.to_legacy_config())
return ok(message="新增服务提供商配置成功")
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/provider/update")
async def update_legacy_provider(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
body = await _json_or_empty(request)
provider_id = body.get("id")
config = body.get("config")
if not provider_id or not isinstance(config, dict):
return _compat_error("参数错误")
try:
await service.update_provider(
str(provider_id),
ProviderConfigRequest(config=config).to_legacy_config(
fallback_id=str(provider_id),
),
)
return ok(message="更新成功,已经实时生效~")
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/provider/delete")
async def delete_legacy_provider(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
body = await _json_or_empty(request)
provider_id = body.get("id")
if not provider_id:
return _compat_error("缺少参数 id")
try:
await service.delete_provider(str(provider_id))
return ok(message="删除成功,已经实时生效。")
except ValueError as exc:
return _compat_error(str(exc))
@router.get("/config/provider/check_one")
async def check_legacy_provider(
id: str | None = Query(default=None),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
if not id:
return _compat_error("Missing provider_id parameter")
try:
return ok(await service.test_provider(id))
except ValueError as exc:
return _compat_error(str(exc))
@router.get("/config/provider_sources/models")
async def list_legacy_provider_source_models(
source_id: str | None = Query(default=None),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
if not source_id:
return _compat_error("缺少参数 source_id")
try:
data = await service.list_provider_source_models(source_id)
data.pop("provider_source_id", None)
return ok(data)
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/provider_sources/update")
async def update_legacy_provider_source(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
body = await _json_or_empty(request)
source_id = body.get("original_id")
config = body.get("config") or body
if not source_id:
return _compat_error("缺少 original_id")
if not isinstance(config, dict):
return _compat_error("缺少或错误的配置数据")
try:
await service.upsert_provider_source(
str(source_id),
ProviderSourceRequest(config=config).to_legacy_config(
fallback_id=str(source_id),
),
)
return ok(message="更新 provider source 成功")
except ValueError as exc:
return _compat_error(str(exc))
@router.post("/config/provider_sources/delete")
async def delete_legacy_provider_source(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_provider_service),
):
body = await _json_or_empty(request)
source_id = body.get("id")
if not source_id:
return _compat_error("缺少 provider_source_id")
try:
await service.delete_provider_source(str(source_id))
return ok(message="删除 provider source 成功")
except ValueError as exc:
return _compat_error(str(exc))

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass
class ApiError(Exception):
message: str
status_code: int = 400
data: Any = None
def ok(data: Any = None, message: str | None = None) -> dict[str, Any]:
return {"status": "ok", "message": message, "data": {} if data is None else data}
def error(message: str, data: Any = None) -> dict[str, Any]:
payload: dict[str, Any] = {"status": "error", "message": message}
if data is not None:
payload["data"] = data
return payload

View File

@@ -0,0 +1,22 @@
from fastapi import APIRouter
from .bots import router as bots_router
from .compat import compat_routers
from .config_profiles import router as config_profiles_router
from .extensions import router as extensions_router
from .open_api_compat import router as open_api_compat_router
from .plugins import router as plugins_router
from .providers import router as providers_router
def build_v1_router() -> APIRouter:
router = APIRouter(prefix="/api/v1")
router.include_router(config_profiles_router)
router.include_router(bots_router)
router.include_router(providers_router)
router.include_router(plugins_router)
router.include_router(extensions_router)
router.include_router(open_api_compat_router)
for compat_router in compat_routers:
router.include_router(compat_router)
return router

View File

@@ -0,0 +1,188 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from astrbot.dashboard.services.config_service import BotConfigService
from ..auth import AuthContext, require_scope
from ..responses import ok
from ..schemas import BotConfigRequest, EnabledPatch
router = APIRouter(tags=["Bots"])
async def require_config_scope(request: Request) -> AuthContext:
return await require_scope(request, "config")
def get_service(request: Request) -> BotConfigService:
return request.app.state.services.bots
async def _json_or_empty(request: Request) -> dict:
try:
data = await request.json()
except Exception:
return {}
return data if isinstance(data, dict) else {}
def _required_text(value: object, name: str) -> str:
text = str(value or "").strip()
if not text:
raise ValueError(f"Missing key: {name}")
return text
def _config_from_body(body: dict) -> dict:
config = body.get("config")
if isinstance(config, dict):
return config
return {
key: value
for key, value in body.items()
if key not in {"bot_id", "config", "enabled"}
}
@router.get("/bot-types")
async def list_bot_types(
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
return ok(service.list_bot_types())
@router.get("/bots")
async def list_bots(
enabled: bool | None = Query(default=None),
type_: str | None = Query(default=None, alias="type"),
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
return ok(service.list_bots(enabled=enabled, type_=type_))
@router.post("/bots")
async def create_bot(
payload: BotConfigRequest,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
await service.create_bot(payload.to_legacy_config())
return ok(message="新增平台配置成功~")
@router.get("/bots/stats")
async def list_bot_stats(
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
return ok(service.get_bot_stats())
@router.get("/bots/by-id")
async def get_bot_by_id(
bot_id: str = Query(...),
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
return ok(service.get_bot(bot_id))
@router.put("/bots/by-id")
async def update_bot_by_id(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
body = await _json_or_empty(request)
bot_id = _required_text(body.get("bot_id"), "bot_id")
await service.update_bot(
bot_id,
BotConfigRequest(config=_config_from_body(body)).to_legacy_config(
fallback_id=bot_id,
),
)
return ok(message="更新平台配置成功~")
@router.delete("/bots/by-id")
async def delete_bot_by_id(
bot_id: str = Query(...),
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
await service.delete_bot(bot_id)
return ok(message="删除平台配置成功~")
@router.patch("/bots/enabled")
async def set_bot_enabled_by_id(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
body = await _json_or_empty(request)
bot_id = _required_text(body.get("bot_id"), "bot_id")
await service.set_bot_enabled(bot_id, bool(body.get("enabled")))
return ok(message="更新平台配置成功~")
@router.post("/bots/test")
async def test_bot_by_id(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
):
body = await _json_or_empty(request)
bot_id = _required_text(body.get("bot_id"), "bot_id")
return ok({"id": bot_id, "status": "unsupported"})
@router.patch("/bots/{bot_id:path}/enabled")
async def set_bot_enabled(
bot_id: str,
payload: EnabledPatch,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
await service.set_bot_enabled(bot_id, payload.enabled)
return ok(message="更新平台配置成功~")
@router.post("/bots/{bot_id:path}/test")
async def test_bot(
bot_id: str,
_auth: AuthContext = Depends(require_config_scope),
):
return ok({"id": bot_id, "status": "unsupported"})
@router.get("/bots/{bot_id:path}")
async def get_bot(
bot_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
return ok(service.get_bot(bot_id))
@router.put("/bots/{bot_id:path}")
async def update_bot(
bot_id: str,
payload: BotConfigRequest,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
await service.update_bot(bot_id, payload.to_legacy_config(fallback_id=bot_id))
return ok(message="更新平台配置成功~")
@router.delete("/bots/{bot_id:path}")
async def delete_bot(
bot_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: BotConfigService = Depends(get_service),
):
await service.delete_bot(bot_id)
return ok(message="删除平台配置成功~")

View File

@@ -0,0 +1,19 @@
from .auth import router as auth_router
from .chat import router as chat_router
from .conversations import router as conversations_router
from .files import router as files_router
from .knowledge_bases import router as knowledge_bases_router
from .personas import router as personas_router
from .sessions import router as sessions_router
from .system import router as system_router
compat_routers = [
auth_router,
chat_router,
files_router,
conversations_router,
system_router,
sessions_router,
personas_router,
knowledge_bases_router,
]

View File

@@ -0,0 +1,152 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ...auth import AuthContext
from .common import (
auth_optional,
get_bridge,
require_config_scope,
require_system_scope,
)
router = APIRouter(tags=["Auth"])
@router.post("/auth/login")
async def login(
request: Request,
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, None, method="POST", target_path="/api/auth/login"
)
@router.post("/auth/logout")
async def logout(
request: Request,
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, None, method="POST", target_path="/api/auth/logout"
)
@router.get("/auth/setup-status")
async def setup_status(
request: Request,
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, None, method="GET", target_path="/api/auth/setup-status"
)
@router.post("/auth/setup")
async def setup(
request: Request,
auth: AuthContext | None = Depends(auth_optional),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
target_path = "/api/auth/setup-authenticated" if auth else "/api/auth/setup"
return await bridge.forward(request, auth, method="POST", target_path=target_path)
@router.post("/auth/totp/setup")
async def totp_setup(
request: Request,
auth: AuthContext | None = Depends(auth_optional),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/auth/totp/setup"
)
@router.post("/auth/totp/recovery")
async def totp_recovery(
request: Request,
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, None, method="POST", target_path="/api/auth/totp/recovery"
)
@router.get("/system-config/runtime")
async def get_system_config_runtime(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/config/get"
)
@router.patch("/auth/account")
async def update_account(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/auth/account/edit"
)
@router.get("/api-keys")
async def list_api_keys(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/apikey/list"
)
@router.post("/api-keys")
async def create_api_key(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/apikey/create"
)
@router.post("/api-keys/{key_id}/revoke")
async def revoke_api_key(
key_id: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/apikey/revoke",
json_body={"key_id": key_id},
)
@router.delete("/api-keys/{key_id}")
async def delete_api_key(
key_id: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/apikey/delete",
json_body={"key_id": key_id},
)

View File

@@ -0,0 +1,342 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ...auth import AuthContext
from .common import get_bridge, require_chat_scope, require_config_scope
from .common import (
json_or_empty as _json_or_empty,
)
router = APIRouter(tags=["Chat"])
@router.post("/bot-types/{bot_type}/registration")
async def register_bot_type(
bot_type: str,
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path=f"/api/platform/registration/{bot_type}",
)
@router.get("/chat/sessions/new")
async def create_chat_session(
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/chat/new_session"
)
@router.post("/chat/sessions/batch-delete")
async def batch_delete_chat_sessions(
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/chat/batch_delete_sessions"
)
@router.get("/chat/sessions/{session_id}")
async def get_chat_session(
session_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/chat/get_session",
query={"session_id": session_id},
)
@router.patch("/chat/sessions/{session_id}")
async def update_chat_session(
session_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chat/update_session_display_name",
json_body={"session_id": session_id, **body},
)
@router.delete("/chat/sessions/{session_id}")
async def delete_chat_session(
session_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/chat/delete_session",
query={"session_id": session_id},
)
@router.post("/chat/sessions/{session_id}/stop")
async def stop_chat_session(
session_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chat/stop",
json_body={"session_id": session_id},
)
@router.patch("/chat/sessions/{session_id}/messages/{message_id}")
async def update_chat_message(
session_id: str,
message_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chat/message/edit",
json_body={"session_id": session_id, "message_id": message_id, **body},
)
@router.post("/chat/sessions/{session_id}/messages/{message_id}/regenerate")
async def regenerate_chat_message(
session_id: str,
message_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chat/message/regenerate",
json_body={"session_id": session_id, "message_id": message_id, **body},
)
@router.get("/chat/configs")
async def chat_configs(
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/config/abconfs"
)
@router.post("/chat/threads")
async def create_chat_thread(
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/chat/thread/create"
)
@router.get("/chat/threads/{thread_id}")
async def get_chat_thread(
thread_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/chat/thread/get",
query={"thread_id": thread_id},
)
@router.delete("/chat/threads/{thread_id}")
async def delete_chat_thread(
thread_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chat/thread/delete",
json_body={"thread_id": thread_id},
)
@router.post("/chat/threads/{thread_id}/messages")
async def send_chat_thread_message(
thread_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chat/thread/send",
json_body={"thread_id": thread_id, **body},
)
@router.get("/chat/projects")
async def list_chat_projects(
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/chatui_project/list"
)
@router.post("/chat/projects")
async def create_chat_project(
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/chatui_project/create"
)
@router.get("/chat/projects/{project_id}")
async def get_chat_project(
project_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/chatui_project/get",
query={"project_id": project_id},
)
@router.patch("/chat/projects/{project_id}")
async def update_chat_project(
project_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chatui_project/update",
json_body={"project_id": project_id, **body},
)
@router.delete("/chat/projects/{project_id}")
async def delete_chat_project(
project_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/chatui_project/delete",
query={"project_id": project_id},
)
@router.get("/chat/projects/{project_id}/sessions")
async def list_chat_project_sessions(
project_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/chatui_project/get_sessions",
query={"project_id": project_id},
)
@router.post("/chat/projects/{project_id}/sessions/{session_id}")
async def add_chat_project_session(
project_id: str,
session_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chatui_project/add_session",
json_body={"project_id": project_id, "session_id": session_id},
)
@router.delete("/chat/projects/sessions/{session_id}")
async def remove_chat_project_session(
session_id: str,
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/chatui_project/remove_session",
json_body={"session_id": session_id},
)

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from fastapi import Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ...auth import AuthContext, require_scope
def get_bridge(request: Request) -> DashboardRouteBridgeService:
return request.app.state.services.route_bridge
async def json_or_empty(request: Request) -> dict:
try:
data = await request.json()
except Exception:
return {}
return data if isinstance(data, dict) else {}
async def auth_optional(request: Request) -> AuthContext | None:
auth_header = request.headers.get("Authorization", "")
has_api_key = bool(
request.query_params.get("api_key")
or request.query_params.get("key")
or request.headers.get("X-API-Key")
)
if auth_header.startswith(("Bearer ", "ApiKey ")) or has_api_key:
return await require_scope(request, "system")
return None
async def require_chat_scope(request: Request) -> AuthContext:
return await require_scope(request, "chat")
async def require_config_scope(request: Request) -> AuthContext:
return await require_scope(request, "config")
async def require_data_scope(request: Request) -> AuthContext:
return await require_scope(request, "data")
async def require_file_scope(request: Request) -> AuthContext:
return await require_scope(request, "file")
async def require_kb_scope(request: Request) -> AuthContext:
return await require_scope(request, "kb")
async def require_persona_scope(request: Request) -> AuthContext:
return await require_scope(request, "persona")
async def require_system_scope(request: Request) -> AuthContext:
return await require_scope(request, "system")

View File

@@ -0,0 +1,116 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ...auth import AuthContext
from .common import get_bridge, require_data_scope
from .common import json_or_empty as _json_or_empty
router = APIRouter(tags=["Conversations"])
@router.get("/conversations")
async def list_conversations(
request: Request,
auth: AuthContext = Depends(require_data_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/conversation/list"
)
@router.post("/conversations/export")
async def export_conversations(
request: Request,
auth: AuthContext = Depends(require_data_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/conversation/export"
)
@router.post("/conversations/batch-delete")
async def batch_delete_conversations(
request: Request,
auth: AuthContext = Depends(require_data_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/conversation/delete"
)
@router.get("/conversations/{conversation_id}")
async def get_conversation(
conversation_id: str,
request: Request,
auth: AuthContext = Depends(require_data_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
user_id = request.query_params.get("user_id")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/conversation/detail",
json_body={"user_id": user_id, "cid": conversation_id},
)
@router.patch("/conversations/{conversation_id}")
async def update_conversation(
conversation_id: str,
request: Request,
auth: AuthContext = Depends(require_data_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
user_id = body.pop("user_id", None) or request.query_params.get("user_id")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/conversation/update",
json_body={"user_id": user_id, "cid": conversation_id, **body},
)
@router.delete("/conversations/{conversation_id}")
async def delete_conversation(
conversation_id: str,
request: Request,
auth: AuthContext = Depends(require_data_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
user_id = request.query_params.get("user_id")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/conversation/delete",
json_body={"user_id": user_id, "cid": conversation_id},
)
@router.put("/conversations/{conversation_id}/messages")
async def update_conversation_messages(
conversation_id: str,
request: Request,
auth: AuthContext = Depends(require_data_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
user_id = body.pop("user_id", None) or request.query_params.get("user_id")
if "messages" in body and "history" not in body:
body["history"] = body.pop("messages")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/conversation/update_history",
json_body={"user_id": user_id, "cid": conversation_id, **body},
)

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ...auth import AuthContext
from .common import get_bridge, require_file_scope
router = APIRouter(tags=["Files"])
@router.post("/files")
async def upload_file(
request: Request,
auth: AuthContext = Depends(require_file_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/chat/post_file"
)
@router.get("/files/content")
async def get_file_by_name(
request: Request,
auth: AuthContext = Depends(require_file_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/chat/get_file"
)
@router.get("/files/tokens/{file_token}")
async def get_token_file(
file_token: str,
request: Request,
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
None,
method="GET",
target_path=f"/api/file/{file_token}",
)
@router.get("/files/{attachment_id}")
@router.get("/files/{attachment_id}/content")
async def get_file(
attachment_id: str,
request: Request,
auth: AuthContext = Depends(require_file_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/chat/get_attachment",
query={"attachment_id": attachment_id},
)
@router.delete("/files/{attachment_id}")
async def delete_file(
attachment_id: str,
_request: Request,
_auth: AuthContext = Depends(require_file_scope),
):
return {"status": "ok", "data": {"attachment_id": attachment_id}}

View File

@@ -0,0 +1,267 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ...auth import AuthContext
from .common import get_bridge, require_kb_scope
from .common import json_or_empty as _json_or_empty
router = APIRouter(tags=["Knowledge Bases"])
@router.get("/knowledge-bases")
async def list_knowledge_bases(
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(request, auth, method="GET", target_path="/api/kb/list")
@router.post("/knowledge-bases")
async def create_knowledge_base(
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/kb/create"
)
@router.get("/knowledge-bases/tasks/{task_id}")
async def get_knowledge_base_task(
task_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/kb/document/upload/progress",
query={"task_id": task_id},
)
@router.get("/knowledge-bases/{kb_id}")
async def get_knowledge_base(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/kb/get", query={"kb_id": kb_id}
)
@router.put("/knowledge-bases/{kb_id}")
async def update_knowledge_base(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/kb/update",
json_body={"kb_id": kb_id, **body},
)
@router.delete("/knowledge-bases/{kb_id}")
async def delete_knowledge_base(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/kb/delete",
json_body={"kb_id": kb_id},
)
@router.get("/knowledge-bases/{kb_id}/stats")
async def get_knowledge_base_stats(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/kb/stats",
query={"kb_id": kb_id},
)
@router.get("/knowledge-bases/{kb_id}/documents")
async def list_knowledge_base_documents(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
query = dict(request.query_params)
query["kb_id"] = kb_id
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/kb/document/list",
query=query,
)
@router.post("/knowledge-bases/{kb_id}/documents")
async def upload_knowledge_base_document(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/kb/document/upload",
query={"kb_id": kb_id},
)
@router.post("/knowledge-bases/{kb_id}/documents/import")
async def import_knowledge_base_documents(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/kb/document/import",
json_body={"kb_id": kb_id, **body},
)
@router.post("/knowledge-bases/{kb_id}/documents/import-url")
async def import_knowledge_base_document_url(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/kb/document/upload/url",
json_body={"kb_id": kb_id, **body},
)
@router.get("/knowledge-bases/{kb_id}/documents/{document_id}")
async def get_knowledge_base_document(
kb_id: str,
document_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/kb/document/get",
query={"kb_id": kb_id, "doc_id": document_id},
)
@router.delete("/knowledge-bases/{kb_id}/documents/{document_id}")
async def delete_knowledge_base_document(
kb_id: str,
document_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/kb/document/delete",
json_body={"kb_id": kb_id, "doc_id": document_id},
)
@router.get("/knowledge-bases/{kb_id}/chunks")
async def list_knowledge_base_chunks(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
query = dict(request.query_params)
query["kb_id"] = kb_id
if "document_id" in query and "doc_id" not in query:
query["doc_id"] = query.pop("document_id")
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/kb/chunk/list",
query=query,
)
@router.delete("/knowledge-bases/{kb_id}/chunks/{chunk_id}")
async def delete_knowledge_base_chunk(
kb_id: str,
chunk_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
document_id = request.query_params.get("document_id") or request.query_params.get(
"doc_id"
)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/kb/chunk/delete",
json_body={"kb_id": kb_id, "chunk_id": chunk_id, "doc_id": document_id},
)
@router.post("/knowledge-bases/{kb_id}/retrieve")
async def retrieve_knowledge_base(
kb_id: str,
request: Request,
auth: AuthContext = Depends(require_kb_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/kb/retrieve",
json_body={"kb_id": kb_id, **body},
)

View File

@@ -0,0 +1,224 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ...auth import AuthContext
from .common import get_bridge, require_persona_scope
from .common import json_or_empty as _json_or_empty
router = APIRouter(tags=["Personas"])
@router.get("/personas/tree")
async def persona_tree(
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/persona/folder/tree"
)
@router.get("/personas")
async def list_personas(
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/persona/list"
)
@router.post("/personas")
async def create_persona(
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/persona/create"
)
@router.get("/personas/by-id")
async def get_persona_by_id(
request: Request,
persona_id: str = Query(...),
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/persona/detail",
json_body={"persona_id": persona_id},
)
@router.put("/personas/by-id")
async def update_persona_by_id(
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
persona_id = str(body.get("persona_id") or "").strip()
if not persona_id:
raise ValueError("Missing key: persona_id")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/persona/update",
json_body={
"persona_id": persona_id,
**{key: value for key, value in body.items() if key != "persona_id"},
},
)
@router.delete("/personas/by-id")
async def delete_persona_by_id(
request: Request,
persona_id: str = Query(...),
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/persona/delete",
json_body={"persona_id": persona_id},
)
@router.get("/personas/{persona_id:path}")
async def get_persona(
persona_id: str,
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/persona/detail",
json_body={"persona_id": persona_id},
)
@router.put("/personas/{persona_id:path}")
async def update_persona(
persona_id: str,
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/persona/update",
json_body={"persona_id": persona_id, **body},
)
@router.delete("/personas/{persona_id:path}")
async def delete_persona(
persona_id: str,
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/persona/delete",
json_body={"persona_id": persona_id},
)
@router.post("/personas/move")
async def move_persona(
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/persona/move"
)
@router.post("/personas/reorder")
async def reorder_personas(
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/persona/reorder"
)
@router.get("/persona-folders")
async def list_persona_folders(
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/persona/folder/list"
)
@router.post("/persona-folders")
async def create_persona_folder(
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/persona/folder/create"
)
@router.put("/persona-folders/{folder_id}")
async def update_persona_folder(
folder_id: str,
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/persona/folder/update",
json_body={"folder_id": folder_id, **body},
)
@router.delete("/persona-folders/{folder_id}")
async def delete_persona_folder(
folder_id: str,
request: Request,
auth: AuthContext = Depends(require_persona_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/persona/folder/delete",
json_body={"folder_id": folder_id},
)

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from astrbot.core import logger
from astrbot.dashboard.services.session_management_service import (
SessionManagementService,
SessionManagementServiceError,
)
from ...auth import AuthContext
from ...responses import error, ok
from ...schemas import (
BatchSessionProviderRequest,
BatchSessionServiceRequest,
SessionGroupRequest,
SessionRuleRequest,
UmoListRequest,
)
from .common import require_data_scope
router = APIRouter(tags=["Sessions"])
def get_service(request: Request) -> SessionManagementService:
return request.app.state.services.sessions
def _service_error(exc: SessionManagementServiceError) -> dict:
return error(str(exc))
def _unexpected_error(prefix: str, exc: Exception) -> dict:
logger.error(f"{prefix}: {exc!s}")
return error(f"{prefix}: {exc!s}")
@router.get("/sessions")
async def list_sessions(
page: int = Query(1),
page_size: int = Query(20),
search: str = Query(""),
message_type: str = Query("all"),
platform: str = Query(""),
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(
await service.list_all_umos_with_status(
page=page,
page_size=page_size,
search=search.strip(),
message_type=message_type,
platform=platform,
)
)
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("获取会话状态列表失败", exc)
@router.get("/sessions/active-umos")
async def list_active_umos(
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(await service.list_active_umos())
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("获取 UMO 列表失败", exc)
@router.get("/sessions/rules")
async def list_session_rules(
page: int = Query(1),
page_size: int = Query(10),
search: str = Query(""),
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(
await service.list_session_rules(
page=page,
page_size=page_size,
search=search.strip(),
)
)
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("获取规则列表失败", exc)
@router.post("/sessions/rules")
async def update_session_rule(
payload: SessionRuleRequest,
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(
await service.update_session_rule(payload.model_dump(exclude_none=True))
)
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("更新会话规则失败", exc)
@router.post("/sessions/rules/delete")
async def delete_session_rule(
payload: UmoListRequest,
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(
await service.delete_session_rules(payload.model_dump(exclude_none=True))
)
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("删除会话规则失败", exc)
@router.patch("/sessions/provider")
async def update_session_provider(
payload: BatchSessionProviderRequest,
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(
await service.batch_update_provider(payload.model_dump(exclude_none=True))
)
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("批量更新 Provider 失败", exc)
@router.patch("/sessions/service")
async def update_session_service(
payload: BatchSessionServiceRequest,
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(
await service.batch_update_service(payload.model_dump(exclude_none=True))
)
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("批量更新服务状态失败", exc)
@router.get("/session-groups")
async def list_session_groups(
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(service.list_groups())
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("获取分组列表失败", exc)
@router.post("/session-groups")
async def create_session_group(
payload: SessionGroupRequest,
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(service.create_group(payload.model_dump(exclude_none=True)))
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("创建分组失败", exc)
@router.put("/session-groups/{group_id}")
async def update_session_group(
group_id: str,
payload: SessionGroupRequest,
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
body = payload.model_dump(exclude_none=True)
return ok(service.update_group({"group_id": group_id, **body}))
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("更新分组失败", exc)
@router.delete("/session-groups/{group_id}")
async def delete_session_group(
group_id: str,
_auth: AuthContext = Depends(require_data_scope),
service: SessionManagementService = Depends(get_service),
):
try:
return ok(service.delete_group({"group_id": group_id}))
except SessionManagementServiceError as exc:
return _service_error(exc)
except Exception as exc:
return _unexpected_error("删除分组失败", exc)

View File

@@ -0,0 +1,647 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ...auth import AuthContext
from .common import (
get_bridge,
require_config_scope,
require_system_scope,
)
from .common import (
json_or_empty as _json_or_empty,
)
router = APIRouter(tags=["System"])
@router.get("/stats")
async def get_stats(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/stat/get"
)
@router.get("/stats/provider-tokens")
async def provider_tokens(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/stat/provider-tokens"
)
@router.get("/stats/version")
async def version(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/stat/version"
)
@router.get("/stats/first-notice")
async def first_notice(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/stat/first-notice"
)
@router.post("/stats/ghproxy/test")
async def test_ghproxy_connection(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/stat/test-ghproxy-connection",
)
@router.get("/changelogs")
async def list_changelog_versions(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/stat/changelog/list"
)
@router.get("/changelogs/{version}")
async def get_changelog(
version: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/stat/changelog",
query={"version": version},
)
@router.get("/stats/start-time")
async def start_time(
request: Request,
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, None, method="GET", target_path="/api/stat/start-time"
)
@router.get("/stats/storage")
async def storage_status(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/stat/storage"
)
@router.post("/stats/storage/cleanup")
async def cleanup_storage(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/stat/storage/cleanup"
)
@router.post("/system/restart")
async def restart_system(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/stat/restart-core"
)
@router.get("/logs/history")
async def log_history(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/log-history"
)
@router.get("/logs/live")
async def live_logs(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/live-log"
)
@router.get("/cron/jobs")
async def list_cron_jobs(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/cron/jobs"
)
@router.post("/cron/jobs")
async def create_cron_job(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/cron/jobs"
)
@router.patch("/cron/jobs/{job_id}")
async def update_cron_job(
job_id: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="PATCH",
target_path=f"/api/cron/jobs/{job_id}",
)
@router.delete("/cron/jobs/{job_id}")
async def delete_cron_job(
job_id: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="DELETE",
target_path=f"/api/cron/jobs/{job_id}",
)
@router.post("/cron/jobs/{job_id}/run")
async def run_cron_job(
job_id: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path=f"/api/cron/jobs/{job_id}/run",
)
@router.get("/trace/settings")
async def get_trace_settings(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/trace/settings"
)
@router.put("/trace/settings")
async def update_trace_settings(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/trace/settings"
)
@router.get("/updates/check")
async def check_updates(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/update/check"
)
@router.get("/updates/releases")
async def update_releases(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/update/releases"
)
@router.get("/updates/progress/{task_id}")
async def update_progress(
task_id: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/update/progress",
query={"id": task_id},
)
@router.post("/updates/core")
async def update_core(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/update/do"
)
@router.post("/updates/dashboard")
async def update_dashboard(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/update/dashboard"
)
@router.post("/pip/install")
async def install_pip_package(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/update/pip-install"
)
@router.post("/migrations")
async def run_migration(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/update/migration"
)
@router.get("/subagents/config")
async def get_subagent_config(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/subagent/config"
)
@router.put("/subagents/config")
async def update_subagent_config(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/subagent/config"
)
@router.get("/subagents/available-tools")
async def get_subagent_tools(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/subagent/available-tools"
)
@router.get("/backups")
async def list_backups(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/backup/list"
)
@router.post("/backups")
async def export_backup(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/backup/export"
)
@router.post("/backups/upload")
async def upload_backup(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/backup/upload"
)
@router.post("/backups/upload/init")
async def init_backup_upload(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/backup/upload/init"
)
@router.post("/backups/upload/chunk")
async def upload_backup_chunk(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/backup/upload/chunk"
)
@router.post("/backups/upload/complete")
async def complete_backup_upload(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/backup/upload/complete"
)
@router.post("/backups/upload/abort")
async def abort_backup_upload(
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/backup/upload/abort"
)
@router.get("/backups/tasks/{task_id}")
async def get_backup_progress(
task_id: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/backup/progress",
query={"task_id": task_id},
)
@router.get("/backups/{filename}")
async def download_backup(
filename: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/backup/download",
query={"filename": filename},
)
@router.patch("/backups/{filename}")
async def rename_backup(
filename: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/backup/rename",
json_body={"filename": filename, **body},
)
@router.delete("/backups/{filename}")
async def delete_backup(
filename: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/backup/delete",
json_body={"filename": filename},
)
@router.post("/backups/{filename}/check")
async def check_backup(
filename: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/backup/check",
json_body={"filename": filename, **body},
)
@router.post("/backups/{filename}/import")
async def import_backup(
filename: str,
request: Request,
auth: AuthContext = Depends(require_system_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/backup/import",
json_body={"filename": filename, **body},
)
@router.get("/t2i/templates")
async def list_t2i_templates(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/t2i/templates"
)
@router.post("/t2i/templates")
async def create_t2i_template(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/t2i/templates/create"
)
@router.get("/t2i/templates/active")
async def get_active_t2i_template(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="GET", target_path="/api/t2i/templates/active"
)
@router.put("/t2i/templates/active")
async def set_active_t2i_template(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/t2i/templates/set_active"
)
@router.post("/t2i/templates/default/reset")
async def reset_default_t2i_template(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request, auth, method="POST", target_path="/api/t2i/templates/reset_default"
)
@router.get("/t2i/templates/{name}")
async def get_t2i_template(
name: str,
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path=f"/api/t2i/templates/{name}",
)
@router.put("/t2i/templates/{name}")
async def update_t2i_template(
name: str,
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="PUT",
target_path=f"/api/t2i/templates/{name}",
)
@router.delete("/t2i/templates/{name}")
async def delete_t2i_template(
name: str,
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="DELETE",
target_path=f"/api/t2i/templates/{name}",
)

View File

@@ -0,0 +1,174 @@
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends, Request
from astrbot.dashboard.services.config_service import (
ConfigProfileService,
ConfigRoutingService,
)
from ..auth import AuthContext, require_scope
from ..responses import ok
from ..schemas import (
ConfigProfileCreateRequest,
ConfigRoutesReplaceRequest,
ConfigRouteUpsertRequest,
RenameRequest,
)
router = APIRouter(tags=["Config Profiles"])
async def require_config_scope(request: Request) -> AuthContext:
return await require_scope(request, "config")
def get_service(request: Request) -> ConfigProfileService:
return request.app.state.services.config_profiles
def get_routing_service(request: Request) -> ConfigRoutingService:
return request.app.state.services.config_routes
@router.get("/config-profiles/schema")
async def get_config_profile_schema(
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
return ok(service.get_profile_schema())
@router.get("/config-profiles")
async def list_config_profiles(
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
return ok(service.list_profiles())
@router.post("/config-profiles")
async def create_config_profile(
payload: ConfigProfileCreateRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
return ok(await service.create_profile(payload.name, payload.config), "创建成功")
@router.get("/config-profiles/{config_id}")
async def get_config_profile(
config_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
return ok(service.get_profile(config_id))
@router.put("/config-profiles/{config_id}")
async def update_config_profile(
config_id: str,
payload: dict[str, Any],
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
message = await service.update_profile(
config_id,
payload,
two_factor_code=request.headers.get("X-2FA-Code"),
)
return ok(message=message or "保存成功")
@router.patch("/config-profiles/{config_id}")
async def rename_config_profile(
config_id: str,
payload: RenameRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
service.rename_profile(config_id, payload.name)
return ok(message="更新成功")
@router.delete("/config-profiles/{config_id}")
async def delete_config_profile(
config_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
service.delete_profile(config_id)
return ok(message="删除成功")
@router.get("/system-config/schema")
async def get_system_config_schema(
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
return ok(service.get_system_schema())
@router.get("/system-config")
async def get_system_config(
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
return ok(service.get_profile("default"))
@router.put("/system-config")
async def update_system_config(
payload: dict[str, Any],
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigProfileService = Depends(get_service),
):
message = await service.update_profile(
"default",
payload,
two_factor_code=request.headers.get("X-2FA-Code"),
)
return ok(message=message or "保存成功")
@router.get("/config-routes")
async def list_config_routes(
_auth: AuthContext = Depends(require_config_scope),
service: ConfigRoutingService = Depends(get_routing_service),
):
return ok(service.list_routes())
@router.put("/config-routes")
async def replace_config_routes(
payload: ConfigRoutesReplaceRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigRoutingService = Depends(get_routing_service),
):
await service.replace_route_mapping(payload.routing)
return ok(message="更新成功")
@router.put("/config-routes/{umo}")
async def upsert_config_route(
umo: str,
payload: ConfigRouteUpsertRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigRoutingService = Depends(get_routing_service),
):
await service.set_route(umo, payload.config_id)
return ok(message="更新成功")
@router.delete("/config-routes/{umo}")
async def delete_config_route(
umo: str,
_auth: AuthContext = Depends(require_config_scope),
service: ConfigRoutingService = Depends(get_routing_service),
):
await service.delete_route_by_umo(umo)
return ok(message="删除成功")

View File

@@ -0,0 +1,732 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Request
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ..auth import AuthContext, require_scope
router = APIRouter(tags=["Extension Components"])
def get_bridge(request: Request) -> DashboardRouteBridgeService:
return request.app.state.services.route_bridge
async def require_tool_scope(request: Request) -> AuthContext:
return await require_scope(request, "tool")
async def require_skill_scope(request: Request) -> AuthContext:
return await require_scope(request, "skill")
async def _json_or_empty(request: Request) -> dict:
try:
data = await request.json()
except Exception:
return {}
return data if isinstance(data, dict) else {}
def _required_text(value: object, name: str) -> str:
text = str(value or "").strip()
if not text:
raise ValueError(f"Missing key: {name}")
return text
def _config_from_body(body: dict, id_key: str) -> dict:
config = body.get("config")
if isinstance(config, dict):
return dict(config)
return {
key: value
for key, value in body.items()
if key not in {id_key, "config", "enabled"}
}
@router.get("/commands")
async def list_commands(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/commands",
)
@router.get("/commands/conflicts")
async def list_command_conflicts(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/commands/conflicts",
)
@router.patch("/commands/{command_id}")
async def update_command(
command_id: str,
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
if "enabled" in body:
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/commands/toggle",
json_body={
"handler_full_name": command_id,
"enabled": body["enabled"],
},
)
if "alias" in body:
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/commands/rename",
json_body={
"handler_full_name": command_id,
"new_name": body["alias"],
"aliases": body.get("aliases"),
},
)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/commands/permission",
json_body={
"handler_full_name": command_id,
"permission": body.get("permission_group"),
},
)
@router.get("/tools")
async def list_tools(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/tools/list",
)
@router.patch("/tools/{tool_id}/enabled")
async def set_tool_enabled(
tool_id: str,
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/toggle-tool",
json_body={"name": tool_id, "activate": body.get("enabled")},
)
@router.get("/mcp/servers")
async def list_mcp_servers(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/tools/mcp/servers",
)
@router.post("/mcp/servers")
async def create_mcp_server(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
if "enabled" in body and "active" not in body:
body["active"] = body.pop("enabled")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/add",
json_body=body,
)
@router.put("/mcp/servers/by-name")
async def update_mcp_server_by_name(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
server_name = _required_text(body.get("server_name"), "server_name")
config = _config_from_body(body, "server_name")
if "enabled" in body and "active" not in config:
config["active"] = body["enabled"]
config.setdefault("name", server_name)
config.setdefault("oldName", server_name)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/update",
json_body=config,
)
@router.delete("/mcp/servers/by-name")
async def delete_mcp_server_by_name(
server_name: str,
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/delete",
json_body={"name": server_name},
)
@router.patch("/mcp/servers/enabled")
async def set_mcp_server_enabled_by_name(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
server_name = _required_text(body.get("server_name"), "server_name")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/update",
json_body={
"name": server_name,
"oldName": server_name,
"active": body.get("enabled"),
},
)
@router.post("/mcp/servers/test")
async def test_mcp_server_by_name(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
server_name = _required_text(body.get("server_name"), "server_name")
config = body.get("mcp_server_config") or body.get("config")
config = dict(config) if isinstance(config, dict) else {"name": server_name}
config.setdefault("name", server_name)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/test",
json_body={"mcp_server_config": config},
)
@router.patch("/mcp/servers/{server_name:path}/enabled")
async def set_mcp_server_enabled(
server_name: str,
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/update",
json_body={
"name": server_name,
"oldName": server_name,
"active": body.get("enabled"),
},
)
@router.post("/mcp/servers/{server_name:path}/test")
async def test_mcp_server(
server_name: str,
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
config = body.get("mcp_server_config") or body or {"name": server_name}
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/test",
json_body={"mcp_server_config": config},
)
@router.put("/mcp/servers/{server_name:path}")
async def update_mcp_server(
server_name: str,
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
if "enabled" in body and "active" not in body:
body["active"] = body.pop("enabled")
body.setdefault("name", server_name)
body.setdefault("oldName", server_name)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/update",
json_body=body,
)
@router.delete("/mcp/servers/{server_name:path}")
async def delete_mcp_server(
server_name: str,
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/delete",
json_body={"name": server_name},
)
@router.post("/mcp/providers/modelscope/sync")
async def sync_modelscope_mcp_servers(
request: Request,
auth: AuthContext = Depends(require_tool_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/tools/mcp/sync-provider",
json_body={
"name": "modelscope",
"access_token": body.get("access_token", ""),
},
)
@router.get("/skills")
async def list_skills(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills",
)
@router.post("/skills")
async def upload_skill(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/upload",
)
@router.post("/skills/batch")
async def upload_skills_batch(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/batch-upload",
)
@router.patch("/skills/by-name")
async def update_skill_by_name(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
skill_name = _required_text(body.get("skill_name"), "skill_name")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/update",
json_body={
"name": skill_name,
"active": body.get("enabled", body.get("active", True)),
},
)
@router.delete("/skills/by-name")
async def delete_skill_by_name(
skill_name: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/delete",
json_body={"name": skill_name},
)
@router.get("/skills/archive")
async def download_skill_by_name(
skill_name: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/download",
query={"name": skill_name},
)
@router.get("/skills/files")
async def list_skill_files_by_name(
skill_name: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
path = request.query_params.get("path", "")
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/files",
query={"name": skill_name, "path": path},
)
@router.get("/skills/file")
async def get_skill_file_by_name(
skill_name: str,
path: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/file",
query={"name": skill_name, "path": path},
)
@router.put("/skills/file")
async def update_skill_file_by_name(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
skill_name = _required_text(body.get("skill_name"), "skill_name")
path = _required_text(body.get("path"), "path")
content = str(body.get("content", ""))
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/file",
json_body={"name": skill_name, "path": path, "content": content},
)
@router.get("/skills/{skill_name:path}/archive")
async def download_skill(
skill_name: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/download",
query={"name": skill_name},
)
@router.get("/skills/{skill_name:path}/files")
async def list_skill_files(
skill_name: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
path = request.query_params.get("path", "")
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/files",
query={"name": skill_name, "path": path},
)
@router.get("/skills/{skill_name:path}/files/{file_path:path}")
async def get_skill_file(
skill_name: str,
file_path: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/file",
query={"name": skill_name, "path": file_path},
)
@router.put("/skills/{skill_name:path}/files/{file_path:path}")
async def update_skill_file(
skill_name: str,
file_path: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
content = (await request.body()).decode("utf-8")
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/file",
json_body={"name": skill_name, "path": file_path, "content": content},
)
@router.patch("/skills/{skill_name:path}")
async def update_skill(
skill_name: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/update",
json_body={
"name": skill_name,
"active": body.get("enabled", body.get("active", True)),
},
)
@router.delete("/skills/{skill_name:path}")
async def delete_skill(
skill_name: str,
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/delete",
json_body={"name": skill_name},
)
@router.get("/skills/neo/candidates")
async def list_neo_skill_candidates(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/neo/candidates",
)
@router.get("/skills/neo/releases")
async def list_neo_skill_releases(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/neo/releases",
)
@router.get("/skills/neo/payload")
async def get_neo_skill_payload(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/skills/neo/payload",
)
@router.post("/skills/neo/evaluate")
async def evaluate_neo_skill_candidate(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/neo/evaluate",
)
@router.post("/skills/neo/promote")
async def promote_neo_skill_candidate(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/neo/promote",
)
@router.post("/skills/neo/rollback")
async def rollback_neo_skill_release(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/neo/rollback",
)
@router.post("/skills/neo/sync")
async def sync_neo_skill_release(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/neo/sync",
)
@router.post("/skills/neo/candidates/delete")
async def delete_neo_skill_candidate(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/neo/delete-candidate",
)
@router.post("/skills/neo/releases/delete")
async def delete_neo_skill_release(
request: Request,
auth: AuthContext = Depends(require_skill_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/skills/neo/delete-release",
)

View File

@@ -0,0 +1,223 @@
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends, Request, WebSocket
from astrbot.dashboard.fastapi_compat import (
CompatG,
call_request_view,
call_websocket_view,
)
from astrbot.dashboard.services.open_api_service import (
OpenApiService,
OpenApiServiceError,
)
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ..auth import AuthContext, require_scope
from ..responses import ApiError, ok
router = APIRouter(tags=["Open API Compatibility"])
async def require_im_scope(request: Request) -> AuthContext:
return await require_scope(request, "im")
async def require_chat_scope(request: Request) -> AuthContext:
return await require_scope(request, "chat")
async def require_config_scope(request: Request) -> AuthContext:
return await require_scope(request, "config")
async def require_file_scope(request: Request) -> AuthContext:
return await require_scope(request, "file")
def get_bridge(request: Request) -> DashboardRouteBridgeService:
return request.app.state.services.route_bridge
def get_service(request: Request) -> OpenApiService:
return request.app.state.services.open_api
async def _json_or_empty(request: Request) -> dict[str, Any]:
try:
data = await request.json()
except Exception:
return {}
return data if isinstance(data, dict) else {}
def _compat_g(auth: AuthContext) -> CompatG:
g_obj = CompatG()
g_obj.username = auth.username
return g_obj
async def _call_open_api_route(
request: Request,
auth: AuthContext,
handler_name: str,
):
route = getattr(request.app.state, "open_api_route", None)
app_adapter = getattr(request.app.state, "dashboard_app_adapter", None)
if route is None or app_adapter is None:
raise ApiError("OpenAPI compatibility route is unavailable", status_code=503)
return await call_request_view(
request,
app_adapter,
getattr(route, handler_name),
g_obj=_compat_g(auth),
)
@router.post("/chat")
async def chat(
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
if auth.via == "api_key":
return await _call_open_api_route(request, auth, "chat_send")
return await bridge.forward(
request, auth, method="POST", target_path="/api/chat/send"
)
@router.get("/chat/sessions")
async def chat_sessions(
request: Request,
auth: AuthContext = Depends(require_chat_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
if auth.via == "api_key":
return await _call_open_api_route(request, auth, "get_chat_sessions")
return await bridge.forward(
request, auth, method="GET", target_path="/api/chat/sessions"
)
@router.get("/configs", include_in_schema=False)
async def get_chat_configs(
request: Request,
auth: AuthContext = Depends(require_config_scope),
):
return await _call_open_api_route(request, auth, "get_chat_configs")
@router.post("/file", include_in_schema=False)
async def upload_open_api_file(
request: Request,
auth: AuthContext = Depends(require_file_scope),
):
return await _call_open_api_route(request, auth, "openapi_upload_file")
@router.get("/file", include_in_schema=False)
async def get_open_api_file(
request: Request,
auth: AuthContext = Depends(require_file_scope),
):
return await _call_open_api_route(request, auth, "openapi_get_file")
@router.websocket("/chat/ws")
async def chat_ws(websocket: WebSocket) -> None:
route = getattr(websocket.app.state, "open_api_route", None)
app_adapter = getattr(websocket.app.state, "dashboard_app_adapter", None)
if route is not None and app_adapter is not None:
await call_websocket_view(websocket, app_adapter, route.chat_ws)
return
await websocket.accept()
await websocket.close(1011, "OpenAPI chat websocket route is unavailable")
async def _forward_route_websocket(websocket: WebSocket, target_path: str) -> None:
route_app = websocket.app.state.services.route_bridge.route_app
receive = getattr(websocket, "_receive")
send = getattr(websocket, "_send")
scope = {
**websocket.scope,
"path": target_path,
"raw_path": target_path.encode(),
}
await route_app(scope, receive, send)
@router.websocket("/live-chat/ws")
async def live_chat_ws(websocket: WebSocket) -> None:
await _forward_route_websocket(websocket, "/api/live_chat/ws")
@router.websocket("/unified-chat/ws")
async def unified_chat_ws(websocket: WebSocket) -> None:
await _forward_route_websocket(websocket, "/api/unified_chat/ws")
@router.post("/im/messages")
async def send_im_message(
request: Request,
_auth: AuthContext = Depends(require_im_scope),
service: OpenApiService = Depends(get_service),
):
body = await _json_or_empty(request)
try:
await service.send_message(body)
except OpenApiServiceError as exc:
raise ApiError(str(exc)) from exc
return ok()
@router.post("/im/message", include_in_schema=False)
async def send_legacy_im_message(
request: Request,
auth: AuthContext = Depends(require_im_scope),
):
return await _call_open_api_route(request, auth, "send_message")
@router.get("/im/bots")
async def list_im_bots(
request: Request,
_auth: AuthContext = Depends(require_im_scope),
service: OpenApiService = Depends(get_service),
):
return ok(service.get_bots())
async def _forward_platform_webhook(
webhook_uuid: str,
request: Request,
bridge: DashboardRouteBridgeService,
):
return await bridge.forward(
request,
None,
method=request.method,
target_path=f"/api/platform/webhook/{webhook_uuid}",
)
@router.get("/webhooks/platforms/{webhook_uuid}")
async def verify_platform_webhook(
webhook_uuid: str,
request: Request,
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await _forward_platform_webhook(webhook_uuid, request, bridge)
@router.post("/webhooks/platforms/{webhook_uuid}")
async def receive_platform_webhook(
webhook_uuid: str,
request: Request,
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await _forward_platform_webhook(webhook_uuid, request, bridge)

View File

@@ -0,0 +1,894 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from astrbot.dashboard.services.plugin_service import PluginService, PluginServiceError
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ..auth import AuthContext, require_scope
from ..responses import ApiError, ok
router = APIRouter(tags=["Plugins"])
async def require_plugin_scope(request: Request) -> AuthContext:
return await require_scope(request, "plugin")
def get_bridge(request: Request) -> DashboardRouteBridgeService:
return request.app.state.services.route_bridge
def get_service(request: Request) -> PluginService:
return request.app.state.services.plugins
async def _proxy_plugin_extension(
plugin_path: str,
request: Request,
auth: AuthContext,
bridge: DashboardRouteBridgeService,
):
return await bridge.forward(
request,
auth,
method=request.method,
target_path=f"/api/plug/{plugin_path.lstrip('/')}",
)
@router.get("/plugins/extensions/{plugin_path:path}")
async def get_plugin_extension_route(
plugin_path: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
@router.post("/plugins/extensions/{plugin_path:path}")
async def post_plugin_extension_route(
plugin_path: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
@router.put("/plugins/extensions/{plugin_path:path}")
async def put_plugin_extension_route(
plugin_path: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
@router.patch("/plugins/extensions/{plugin_path:path}")
async def patch_plugin_extension_route(
plugin_path: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
@router.delete("/plugins/extensions/{plugin_path:path}")
async def delete_plugin_extension_route(
plugin_path: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
async def _json_or_empty(request: Request) -> dict:
try:
data = await request.json()
except Exception:
return {}
return data if isinstance(data, dict) else {}
def _required_text(value: object, name: str) -> str:
text = str(value or "").strip()
if not text:
raise ValueError(f"Missing key: {name}")
return text
def _plugin_id_from_body(body: dict) -> str:
return _required_text(body.get("plugin_id"), "plugin_id")
@router.get("/plugins/failed")
async def list_failed_plugins(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/source/get-failed-plugins",
)
@router.post("/plugins/update")
async def update_plugins(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
if body.get("plugin_id"):
plugin_id = _plugin_id_from_body(body)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/update",
json_body={
"name": plugin_id,
**{key: value for key, value in body.items() if key != "plugin_id"},
},
)
legacy_body = {
**body,
"names": body.get("names") or body.get("plugin_ids") or [],
}
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/update-all",
json_body=legacy_body,
)
@router.post("/plugins/compatibility/check")
async def check_plugin_compatibility(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/check-compat",
)
@router.post("/plugins/install/github")
async def install_plugin_from_github(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
repository = str(body.get("repository") or body.get("url") or "").strip()
if repository and not repository.startswith(("http://", "https://")):
repository = f"https://github.com/{repository}"
legacy_body = {
"url": repository,
"proxy": body.get("proxy"),
"ignore_version_check": body.get("ignore_version_check", False),
}
if body.get("download_url"):
legacy_body["download_url"] = body["download_url"]
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/install",
json_body=legacy_body,
)
@router.post("/plugins/install/url")
async def install_plugin_from_url(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
url = str(body.get("url") or "").strip()
legacy_body = {
"url": body.get("repository") or url,
"download_url": url,
"proxy": body.get("proxy"),
"ignore_version_check": body.get("ignore_version_check", False),
}
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/install",
json_body=legacy_body,
)
@router.post("/plugins/install/upload")
async def install_plugin_from_upload(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/install-upload",
)
@router.get("/plugins/market")
async def list_plugin_market(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/market_list",
)
@router.get("/plugins/market/categories")
async def list_plugin_market_categories(
_request: Request,
_auth: AuthContext = Depends(require_plugin_scope),
):
return ok({"categories": []})
@router.get("/plugin-sources")
async def list_plugin_sources(
_request: Request,
_auth: AuthContext = Depends(require_plugin_scope),
service: PluginService = Depends(get_service),
):
return ok({"sources": await service.get_custom_sources()})
@router.post("/plugin-sources")
async def create_plugin_source(
request: Request,
_auth: AuthContext = Depends(require_plugin_scope),
service: PluginService = Depends(get_service),
):
body = await _json_or_empty(request)
try:
sources = await service.create_custom_source(body)
except PluginServiceError as exc:
raise ApiError(str(exc)) from exc
return ok({"sources": sources}, message="保存成功")
@router.put("/plugin-sources")
async def replace_plugin_sources(
request: Request,
_auth: AuthContext = Depends(require_plugin_scope),
service: PluginService = Depends(get_service),
):
body = await _json_or_empty(request)
try:
sources = await service.replace_custom_sources(body)
except PluginServiceError as exc:
raise ApiError(str(exc)) from exc
return ok({"sources": sources}, message="保存成功")
@router.delete("/plugin-sources/by-id")
async def delete_plugin_source_by_id(
source_id: str = Query(...),
_auth: AuthContext = Depends(require_plugin_scope),
service: PluginService = Depends(get_service),
):
return ok(
{"sources": await service.delete_custom_source(source_id)},
message="保存成功",
)
@router.delete("/plugin-sources/{source_id}")
async def delete_plugin_source(
source_id: str,
_request: Request,
_auth: AuthContext = Depends(require_plugin_scope),
service: PluginService = Depends(get_service),
):
return ok(
{"sources": await service.delete_custom_source(source_id)},
message="保存成功",
)
@router.get("/plugins/page-bridge-sdk.js")
async def get_plugin_page_bridge_sdk(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/page/bridge-sdk.js",
)
@router.get("/plugins")
async def list_plugins(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/get",
)
@router.get("/plugins/by-id")
async def get_plugin_by_id(
request: Request,
plugin_id: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/detail",
query={"name": plugin_id},
)
@router.delete("/plugins/by-id")
async def uninstall_plugin_by_id(
request: Request,
plugin_id: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/uninstall",
json_body={"name": plugin_id, **body},
)
@router.get("/plugins/config")
async def get_plugin_config_by_id(
request: Request,
plugin_id: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/config/get",
query={"plugin_name": plugin_id},
)
@router.put("/plugins/config")
async def update_plugin_config_by_id(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
plugin_id = _plugin_id_from_body(body)
config = body.get("config")
config = config if isinstance(config, dict) else {}
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/config/plugin/update",
query={"plugin_name": plugin_id},
json_body=config,
)
@router.get("/plugins/config/schema")
async def get_plugin_config_schema_by_id(
request: Request,
plugin_id: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/config/get",
query={"plugin_name": plugin_id},
)
@router.get("/plugins/config-files")
async def list_plugin_config_files_by_id(
request: Request,
plugin_id: str = Query(...),
config_key: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/config/file/get",
query={"scope": "plugin", "name": plugin_id, "key": config_key},
)
@router.post("/plugins/config-files")
async def upload_plugin_config_files_by_id(
request: Request,
plugin_id: str = Query(...),
config_key: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/config/file/upload",
query={"scope": "plugin", "name": plugin_id, "key": config_key},
)
@router.delete("/plugins/config-files")
async def delete_plugin_config_file_by_id(
request: Request,
plugin_id: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/config/file/delete",
query={"scope": "plugin", "name": plugin_id},
json_body=body,
)
@router.get("/plugins/readme")
async def get_plugin_readme_by_id(
request: Request,
plugin_id: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/readme",
query={"name": plugin_id},
)
@router.get("/plugins/changelog")
async def get_plugin_changelog_by_id(
request: Request,
plugin_id: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/changelog",
query={"name": plugin_id},
)
@router.post("/plugins/reload")
async def reload_plugin_by_id(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
plugin_id = _plugin_id_from_body(body)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/reload",
json_body={"name": plugin_id},
)
@router.patch("/plugins/enabled")
async def set_plugin_enabled_by_id(
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
plugin_id = _plugin_id_from_body(body)
target_path = "/api/plugin/on" if body.get("enabled") else "/api/plugin/off"
return await bridge.forward(
request,
auth,
method="POST",
target_path=target_path,
json_body={"name": plugin_id},
)
@router.get("/plugins/pages")
async def list_plugin_pages_by_id(
request: Request,
plugin_id: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/detail",
query={"name": plugin_id},
)
@router.get("/plugins/page")
async def get_plugin_page_by_id(
request: Request,
plugin_id: str = Query(...),
page_name: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/page/entry",
query={"name": plugin_id, "page": page_name},
)
@router.get("/plugins/page/assets")
async def get_plugin_page_asset_by_id(
request: Request,
plugin_id: str = Query(...),
page_name: str = Query(...),
asset_path: str = Query(...),
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path=f"/api/plugin/page/content/{plugin_id}/{page_name}/{asset_path}",
)
@router.get("/plugins/{plugin_id}")
async def get_plugin(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/detail",
query={"name": plugin_id},
)
@router.delete("/plugins/{plugin_id}")
async def uninstall_plugin(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/uninstall",
json_body={"name": plugin_id, **body},
)
@router.delete("/plugins/failed/{plugin_id}")
async def uninstall_failed_plugin(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/uninstall-failed",
json_body={"dir_name": plugin_id, **body},
)
@router.post("/plugins/failed/{plugin_id}/reload")
async def reload_failed_plugin(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/reload-failed",
json_body={"dir_name": plugin_id},
)
@router.get("/plugins/{plugin_id}/config")
async def get_plugin_config(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/config/get",
query={"plugin_name": plugin_id},
)
@router.put("/plugins/{plugin_id}/config")
async def update_plugin_config(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/config/plugin/update",
query={"plugin_name": plugin_id},
)
@router.get("/plugins/{plugin_id}/config/schema")
async def get_plugin_config_schema(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/config/get",
query={"plugin_name": plugin_id},
)
@router.get("/plugins/{plugin_id}/config-files/{config_key:path}")
async def list_plugin_config_files(
plugin_id: str,
config_key: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/config/file/get",
query={"scope": "plugin", "name": plugin_id, "key": config_key},
)
@router.post("/plugins/{plugin_id}/config-files/{config_key:path}")
async def upload_plugin_config_files(
plugin_id: str,
config_key: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/config/file/upload",
query={"scope": "plugin", "name": plugin_id, "key": config_key},
)
@router.delete("/plugins/{plugin_id}/config-files")
async def delete_plugin_config_file(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/config/file/delete",
query={"scope": "plugin", "name": plugin_id},
json_body=body,
)
@router.get("/plugins/{plugin_id}/readme")
async def get_plugin_readme(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/readme",
query={"name": plugin_id},
)
@router.get("/plugins/{plugin_id}/changelog")
async def get_plugin_changelog(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/changelog",
query={"name": plugin_id},
)
@router.post("/plugins/{plugin_id}/reload")
async def reload_plugin(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/reload",
json_body={"name": plugin_id},
)
@router.patch("/plugins/{plugin_id}/enabled")
async def set_plugin_enabled(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
target_path = "/api/plugin/on" if body.get("enabled") else "/api/plugin/off"
return await bridge.forward(
request,
auth,
method="POST",
target_path=target_path,
json_body={"name": plugin_id},
)
@router.post("/plugins/{plugin_id}/update")
async def update_plugin(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/plugin/update",
json_body={"name": plugin_id, **body},
)
@router.get("/plugins/{plugin_id}/pages")
async def list_plugin_pages(
plugin_id: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/detail",
query={"name": plugin_id},
)
@router.get("/plugins/{plugin_id}/pages/{page_name}")
async def get_plugin_page(
plugin_id: str,
page_name: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path="/api/plugin/page/entry",
query={"name": plugin_id, "page": page_name},
)
@router.get("/plugins/{plugin_id}/pages/{page_name}/assets/{asset_path:path}")
async def get_plugin_page_asset(
plugin_id: str,
page_name: str,
asset_path: str,
request: Request,
auth: AuthContext = Depends(require_plugin_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
return await bridge.forward(
request,
auth,
method="GET",
target_path=f"/api/plugin/page/content/{plugin_id}/{page_name}/{asset_path}",
)

View File

@@ -0,0 +1,403 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, Query, Request
from astrbot.dashboard.services.config_service import ProviderConfigService
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
from ..auth import AuthContext, require_scope
from ..responses import ok
from ..schemas import EnabledPatch, ProviderConfigRequest, ProviderSourceRequest
router = APIRouter(tags=["Providers"])
async def require_config_scope(request: Request) -> AuthContext:
return await require_scope(request, "config")
def get_service(request: Request) -> ProviderConfigService:
return request.app.state.services.providers
def get_bridge(request: Request) -> DashboardRouteBridgeService:
return request.app.state.services.route_bridge
async def _json_or_empty(request: Request) -> dict:
try:
data = await request.json()
except Exception:
return {}
return data if isinstance(data, dict) else {}
def _required_text(value: object, name: str) -> str:
text = str(value or "").strip()
if not text:
raise ValueError(f"Missing key: {name}")
return text
def _config_from_body(body: dict) -> dict:
config = body.get("config")
if isinstance(config, dict):
return config
return {
key: value
for key, value in body.items()
if key
not in {
"provider_id",
"source_id",
"config",
"enabled",
"provider_config",
}
}
@router.get("/providers/schema")
async def get_provider_schema(
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(service.get_provider_schema())
@router.get("/provider-sources")
async def list_provider_sources(
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(service.list_provider_sources())
@router.post("/provider-sources")
async def create_provider_source(
payload: ProviderSourceRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
config = payload.to_legacy_config()
source_id = config.get("id")
if not source_id:
raise ValueError("Provider source config must have an 'id' field")
await service.upsert_provider_source(source_id, config)
return ok(message="更新 provider source 成功")
@router.get("/provider-sources/by-id")
async def get_provider_source_by_id(
source_id: str = Query(...),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(service.get_provider_source(source_id))
@router.put("/provider-sources/by-id")
async def upsert_provider_source_by_id(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
body = await _json_or_empty(request)
source_id = _required_text(body.get("source_id"), "source_id")
await service.upsert_provider_source(
source_id,
ProviderSourceRequest(config=_config_from_body(body)).to_legacy_config(
fallback_id=source_id,
),
)
return ok(message="更新 provider source 成功")
@router.delete("/provider-sources/by-id")
async def delete_provider_source_by_id(
source_id: str = Query(...),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.delete_provider_source(source_id)
return ok(message="删除 provider source 成功")
@router.get("/provider-sources/models")
async def list_provider_source_models_by_id(
source_id: str = Query(...),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(await service.list_provider_source_models(source_id))
@router.get("/provider-sources/providers")
async def list_providers_by_source_id(
source_id: str = Query(...),
capability: str | None = Query(default=None),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(service.list_providers(capability=capability, source_id=source_id))
@router.post("/provider-sources/providers")
async def create_provider_in_source_by_id(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
body = await _json_or_empty(request)
source_id = _required_text(body.get("source_id"), "source_id")
await service.create_provider(
ProviderConfigRequest(config=_config_from_body(body)).to_legacy_config(
source_id=source_id,
),
source_id,
)
return ok(message="新增服务提供商配置成功")
@router.get("/provider-sources/{source_id:path}/models")
async def list_provider_source_models(
source_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(await service.list_provider_source_models(source_id))
@router.get("/provider-sources/{source_id:path}/providers")
async def list_providers_by_source(
source_id: str,
capability: str | None = Query(default=None),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(service.list_providers(capability=capability, source_id=source_id))
@router.post("/provider-sources/{source_id:path}/providers")
async def create_provider_in_source(
source_id: str,
payload: ProviderConfigRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.create_provider(
payload.to_legacy_config(source_id=source_id), source_id
)
return ok(message="新增服务提供商配置成功")
@router.get("/provider-sources/{source_id:path}")
async def get_provider_source(
source_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(service.get_provider_source(source_id))
@router.put("/provider-sources/{source_id:path}")
async def upsert_provider_source(
source_id: str,
payload: ProviderSourceRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.upsert_provider_source(
source_id,
payload.to_legacy_config(),
)
return ok(message="更新 provider source 成功")
@router.delete("/provider-sources/{source_id:path}")
async def delete_provider_source(
source_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.delete_provider_source(source_id)
return ok(message="删除 provider source 成功")
@router.get("/providers")
async def list_providers(
capability: str | None = Query(default=None),
source_id: str | None = Query(default=None),
enabled: bool | None = Query(default=None),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(
service.list_providers(
capability=capability,
source_id=source_id,
enabled=enabled,
)
)
@router.post("/providers")
async def create_provider(
payload: ProviderConfigRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.create_provider(payload.to_legacy_config())
return ok(message="新增服务提供商配置成功")
@router.get("/providers/by-id")
async def get_provider_by_id(
provider_id: str = Query(...),
merged: bool = Query(default=False),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(service.get_provider(provider_id, merged=merged))
@router.put("/providers/by-id")
async def update_provider_by_id(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
body = await _json_or_empty(request)
provider_id = _required_text(body.get("provider_id"), "provider_id")
await service.update_provider(
provider_id,
ProviderConfigRequest(config=_config_from_body(body)).to_legacy_config(
fallback_id=provider_id,
),
)
return ok(message="更新成功,已经实时生效~")
@router.delete("/providers/by-id")
async def delete_provider_by_id(
provider_id: str = Query(...),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.delete_provider(provider_id)
return ok(message="删除成功,已经实时生效。")
@router.patch("/providers/enabled")
async def set_provider_enabled_by_id(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
body = await _json_or_empty(request)
provider_id = _required_text(body.get("provider_id"), "provider_id")
await service.set_provider_enabled(provider_id, bool(body.get("enabled")))
return ok(message="更新成功,已经实时生效~")
@router.post("/providers/test")
async def test_provider_by_id(
request: Request,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
body = await _json_or_empty(request)
provider_id = _required_text(body.get("provider_id"), "provider_id")
return ok(await service.test_provider(provider_id))
@router.post("/providers/embedding-dimension")
async def get_embedding_dimension_by_id(
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
provider_id = _required_text(body.get("provider_id"), "provider_id")
provider_config = body.get("provider_config")
legacy_body = {"provider_id": provider_id}
if isinstance(provider_config, dict):
legacy_body["provider_config"] = provider_config
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/config/provider/get_embedding_dim",
json_body=legacy_body,
)
@router.patch("/providers/{provider_id:path}/enabled")
async def set_provider_enabled(
provider_id: str,
payload: EnabledPatch,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.set_provider_enabled(provider_id, payload.enabled)
return ok(message="更新成功,已经实时生效~")
@router.post("/providers/{provider_id:path}/test")
async def test_provider(
provider_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(await service.test_provider(provider_id))
@router.post("/providers/{provider_id:path}/embedding-dimension")
async def get_embedding_dimension(
provider_id: str,
request: Request,
auth: AuthContext = Depends(require_config_scope),
bridge: DashboardRouteBridgeService = Depends(get_bridge),
):
body = await _json_or_empty(request)
return await bridge.forward(
request,
auth,
method="POST",
target_path="/api/config/provider/get_embedding_dim",
json_body={"provider_id": provider_id, **body},
)
@router.get("/providers/{provider_id:path}")
async def get_provider(
provider_id: str,
merged: bool = Query(default=False),
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
return ok(service.get_provider(provider_id, merged=merged))
@router.put("/providers/{provider_id:path}")
async def update_provider(
provider_id: str,
payload: ProviderConfigRequest,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.update_provider(
provider_id,
payload.to_legacy_config(fallback_id=provider_id),
)
return ok(message="更新成功,已经实时生效~")
@router.delete("/providers/{provider_id:path}")
async def delete_provider(
provider_id: str,
_auth: AuthContext = Depends(require_config_scope),
service: ProviderConfigService = Depends(get_service),
):
await service.delete_provider(provider_id)
return ok(message="删除成功,已经实时生效。")

View File

@@ -0,0 +1,172 @@
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class OpenModel(BaseModel):
model_config = ConfigDict(extra="allow")
class ConfigProfileCreateRequest(BaseModel):
name: str | None = None
config: dict[str, Any] | None = None
class RenameRequest(BaseModel):
name: str | None = None
class EnabledPatch(BaseModel):
enabled: bool
class BotConfigRequest(OpenModel):
id: str | None = None
name: str | None = None
type: str | None = None
enabled: bool | None = None
enable: bool | None = None
config: dict[str, Any] | None = None
def to_legacy_config(self, *, fallback_id: str | None = None) -> dict[str, Any]:
config = dict(
self.config
or self.model_dump(
exclude={"config", "enabled"},
exclude_none=True,
)
)
if fallback_id and "id" not in config:
config["id"] = fallback_id
if self.type and "type" not in config:
config["type"] = self.type
if self.id and "id" not in config:
config["id"] = self.id
if self.enabled is not None:
config["enable"] = self.enabled
elif self.enable is not None:
config["enable"] = self.enable
elif "enable" not in config:
config["enable"] = True
return config
class ProviderSourceRequest(OpenModel):
id: str | None = None
config: dict[str, Any] | None = None
def to_legacy_config(self, *, fallback_id: str | None = None) -> dict[str, Any]:
config = dict(
self.config or self.model_dump(exclude={"config"}, exclude_none=True)
)
if fallback_id:
config["id"] = fallback_id
elif self.id and "id" not in config:
config["id"] = self.id
return config
class ProviderConfigRequest(OpenModel):
id: str | None = None
provider_source_id: str | None = None
capability: str | None = None
enabled: bool | None = None
enable: bool | None = None
config: dict[str, Any] | None = None
def to_legacy_config(
self,
*,
fallback_id: str | None = None,
source_id: str | None = None,
) -> dict[str, Any]:
config = dict(
self.config
or self.model_dump(
exclude={"config", "capability", "enabled"},
exclude_none=True,
)
)
if fallback_id and "id" not in config:
config["id"] = fallback_id
if self.id and "id" not in config:
config["id"] = self.id
if source_id:
config["provider_source_id"] = source_id
elif self.provider_source_id and "provider_source_id" not in config:
config["provider_source_id"] = self.provider_source_id
if self.enabled is not None:
config["enable"] = self.enabled
elif self.enable is not None:
config["enable"] = self.enable
elif "enable" not in config:
config["enable"] = True
if self.capability and "provider_type" not in config:
capability_map = {
"chat": "chat_completion",
"agent": "agent_runner",
"stt": "speech_to_text",
"tts": "text_to_speech",
"embedding": "embedding",
"rerank": "rerank",
}
config["provider_type"] = capability_map.get(
self.capability, self.capability
)
return config
class ProviderListQuery(BaseModel):
capability: str | None = None
source_id: str | None = None
enabled: bool | None = None
class ConfigRoutesReplaceRequest(BaseModel):
routing: dict[str, str]
class ConfigRouteUpsertRequest(BaseModel):
config_id: str = Field(..., min_length=1)
class SessionRuleRequest(OpenModel):
umo: str | None = None
rule_key: str | None = None
rule_value: Any = None
class UmoListRequest(OpenModel):
umo: str | None = None
umos: list[str] | None = None
scope: Literal["all", "group", "private", "custom_group"] | None = None
group_id: str | None = None
rule_key: str | None = None
class BatchSessionProviderRequest(UmoListRequest):
provider_id: str | None = None
provider_type: (
Literal[
"chat_completion",
"speech_to_text",
"text_to_speech",
]
| None
) = None
class BatchSessionServiceRequest(UmoListRequest):
session_enabled: bool | None = None
llm_enabled: bool | None = None
tts_enabled: bool | None = None
class SessionGroupRequest(OpenModel):
id: str | None = None
name: str | None = None
umos: list[str] | None = None
add_umos: list[str] | None = None
remove_umos: list[str] | None = None

View File

@@ -10,6 +10,7 @@
"build-stage": "node scripts/subset-mdi-font.mjs && vue-tsc --noEmit && vite build --base=/vue/free/stage/",
"build-prod": "node scripts/subset-mdi-font.mjs && vue-tsc --noEmit && vite build --base=/vue/free/",
"preview": "vite preview --port 5050",
"generate:api": "uv run python scripts/generate_openapi_client.py",
"typecheck": "vue-tsc --noEmit",
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
},

View File

@@ -0,0 +1,364 @@
from __future__ import annotations
import argparse
import json
import re
import urllib.request
from pathlib import Path
from typing import Any
import yaml
ROOT_DIR = Path(__file__).resolve().parents[2]
DASHBOARD_DIR = Path(__file__).resolve().parents[1]
DEFAULT_SPEC = ROOT_DIR / "openspec" / "openapi-v1.yaml"
DEFAULT_OUTPUT = DASHBOARD_DIR / "src" / "api" / "generated" / "openapi-v1.ts"
HTTP_METHODS = {"get", "post", "put", "patch", "delete", "head", "options"}
def load_spec(source: str) -> dict[str, Any]:
if source.startswith(("http://", "https://")):
with urllib.request.urlopen(source, timeout=10) as response:
return json.loads(response.read().decode("utf-8"))
spec_path = Path(source)
if not spec_path.is_absolute():
spec_path = (ROOT_DIR / spec_path).resolve()
text = spec_path.read_text(encoding="utf-8")
if spec_path.suffix.lower() == ".json":
return json.loads(text)
return yaml.safe_load(text)
def pascal_case(value: str) -> str:
words = re.split(r"[^a-zA-Z0-9]+", value)
return "".join(word[:1].upper() + word[1:] for word in words if word)
def camel_case(value: str) -> str:
pascal = pascal_case(value)
return pascal[:1].lower() + pascal[1:]
def quote(value: str) -> str:
return json.dumps(value, ensure_ascii=True)
def property_name(name: str) -> str:
if re.fullmatch(r"[A-Za-z_$][A-Za-z0-9_$]*", name):
return name
return quote(name)
def ref_name(ref: str) -> str:
return ref.rsplit("/", 1)[-1]
class TypeScriptGenerator:
def __init__(self, spec: dict[str, Any]) -> None:
self.spec = spec
self.components = spec.get("components", {})
def resolve_ref(self, obj: dict[str, Any]) -> dict[str, Any]:
ref = obj.get("$ref")
if not ref:
return obj
if not ref.startswith("#/"):
raise ValueError(f"Unsupported external ref: {ref}")
current: Any = self.spec
for part in ref.removeprefix("#/").split("/"):
current = current[part]
return current
def schema_to_ts(self, schema: dict[str, Any] | None) -> str:
if not schema:
return "unknown"
if "$ref" in schema:
return ref_name(schema["$ref"])
if "allOf" in schema:
parts = [self.schema_to_ts(item) for item in schema["allOf"]]
return " & ".join(parts) or "unknown"
if "oneOf" in schema:
parts = [self.schema_to_ts(item) for item in schema["oneOf"]]
return " | ".join(parts) or "unknown"
if "anyOf" in schema:
parts = [self.schema_to_ts(item) for item in schema["anyOf"]]
return " | ".join(parts) or "unknown"
if "const" in schema:
return quote(str(schema["const"]))
if "enum" in schema:
values = schema.get("enum") or []
return " | ".join(quote(str(value)) for value in values) or "string"
schema_type = schema.get("type")
if isinstance(schema_type, list):
return " | ".join(
self.schema_to_ts({**schema, "type": item}) for item in schema_type
)
if schema_type == "string":
if schema.get("format") == "binary":
return "Blob | File"
return "string"
if schema_type in {"integer", "number"}:
return "number"
if schema_type == "boolean":
return "boolean"
if schema_type == "array":
return f"{self.schema_to_ts(schema.get('items'))}[]"
if schema_type == "object" or "properties" in schema:
properties = schema.get("properties") or {}
additional = schema.get("additionalProperties")
if not properties:
if isinstance(additional, dict):
return f"Record<string, {self.schema_to_ts(additional)}>"
return "Record<string, unknown>"
required = set(schema.get("required") or [])
fields = []
for name, prop_schema in properties.items():
optional = "" if name in required else "?"
fields.append(
f"{property_name(name)}{optional}: {self.schema_to_ts(prop_schema)};"
)
if additional is True:
fields.append("[key: string]: unknown;")
elif isinstance(additional, dict):
fields.append(f"[key: string]: {self.schema_to_ts(additional)};")
return "{ " + " ".join(fields) + " }"
return "unknown"
def component_declarations(self) -> list[str]:
declarations = []
schemas = self.components.get("schemas") or {}
for name, schema in schemas.items():
if (
schema.get("type") == "object"
and "properties" in schema
and "allOf" not in schema
and "oneOf" not in schema
and "anyOf" not in schema
):
declarations.append(self.object_interface(name, schema))
else:
declarations.append(
f"export type {name} = {self.schema_to_ts(schema)};"
)
return declarations
def object_interface(self, name: str, schema: dict[str, Any]) -> str:
required = set(schema.get("required") or [])
lines = [f"export interface {name} {{"]
for prop_name, prop_schema in (schema.get("properties") or {}).items():
optional = "" if prop_name in required else "?"
lines.append(
f" {property_name(prop_name)}{optional}: "
f"{self.schema_to_ts(prop_schema)};"
)
additional = schema.get("additionalProperties")
if additional is True:
lines.append(" [key: string]: unknown;")
elif isinstance(additional, dict):
lines.append(f" [key: string]: {self.schema_to_ts(additional)};")
lines.append("}")
return "\n".join(lines)
def resolve_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
if "$ref" not in parameter:
return parameter
name = ref_name(parameter["$ref"])
return self.components["parameters"][name]
def request_body_type(self, request_body: dict[str, Any] | None) -> str | None:
if not request_body:
return None
if "$ref" in request_body:
request_body = self.resolve_ref(request_body)
content = request_body.get("content") or {}
if "multipart/form-data" in content:
return "FormData"
if "application/octet-stream" in content:
return "Blob | ArrayBuffer | string"
media = content.get("application/json") or next(iter(content.values()), None)
if not media:
return "unknown"
return self.schema_to_ts(media.get("schema"))
def response_type(self, operation: dict[str, Any]) -> str:
responses = operation.get("responses") or {}
response = responses.get("200") or responses.get("201") or responses.get("101")
if not response:
return "unknown"
if "$ref" in response:
response = self.resolve_ref(response)
content = response.get("content") or {}
if "application/json" in content:
return self.schema_to_ts(content["application/json"].get("schema"))
if "text/plain" in content or "text/html" in content:
return "string"
return "unknown"
def operation_parameters(
self,
operation: dict[str, Any],
path_item: dict[str, Any],
operation_id: str,
) -> tuple[list[dict[str, Any]], list[str]]:
path_params: list[dict[str, Any]] = []
query_params: list[dict[str, Any]] = []
declarations: list[str] = []
parameters = [
*(path_item.get("parameters") or []),
*(operation.get("parameters") or []),
]
for raw_parameter in parameters:
parameter = self.resolve_parameter(raw_parameter)
target = path_params if parameter.get("in") == "path" else query_params
if parameter.get("in") in {"path", "query"}:
target.append(parameter)
def emit_params(name_suffix: str, params: list[dict[str, Any]]) -> str | None:
if not params:
return None
type_name = f"{pascal_case(operation_id)}{name_suffix}"
lines = [f"export interface {type_name} {{"]
for param in params:
required = bool(param.get("required"))
optional = "" if required else "?"
lines.append(
f" {property_name(param['name'])}{optional}: "
f"{self.schema_to_ts(param.get('schema'))};"
)
lines.append("}")
declarations.append("\n".join(lines))
return type_name
path_type = emit_params("Path", path_params)
query_type = emit_params("Query", query_params)
return declarations, [path_type or "undefined", query_type or "undefined"]
def operation_declaration(
self,
path: str,
method: str,
path_item: dict[str, Any],
operation: dict[str, Any],
) -> tuple[list[str], str]:
operation_id = operation.get("operationId") or camel_case(f"{method}_{path}")
operation_name = camel_case(operation_id)
declarations, [path_type, query_type] = self.operation_parameters(
operation,
path_item,
operation_id,
)
body_type = self.request_body_type(operation.get("requestBody")) or "undefined"
response_type = self.response_type(operation)
args_type_name = f"{pascal_case(operation_id)}Args"
members: list[str] = []
if path_type != "undefined":
members.append(f"path: {path_type};")
if query_type != "undefined":
members.append(f"query?: {query_type};")
if body_type != "undefined":
required = bool((operation.get("requestBody") or {}).get("required"))
optional = "" if required else "?"
members.append(f"body{optional}: {body_type};")
if members:
declarations.append(
"export interface "
+ args_type_name
+ " {\n "
+ "\n ".join(members)
+ "\n}"
)
args_signature = f"args: {args_type_name}"
args_value = "args"
else:
args_signature = "args?: undefined"
args_value = "args"
function = (
f" {operation_name}({args_signature}, config?: AxiosRequestConfig) {{\n"
f" return request<{response_type}>("
f"{quote(method.upper())}, {quote(path)}, {args_value}, config"
f");\n"
f" }}"
)
return declarations, function
def generate(self) -> str:
declarations = self.component_declarations()
operation_functions = []
for path, path_item in sorted((self.spec.get("paths") or {}).items()):
for method, operation in path_item.items():
if method not in HTTP_METHODS:
continue
operation_declarations, operation_function = self.operation_declaration(
path,
method,
path_item,
operation,
)
declarations.extend(operation_declarations)
operation_functions.append(operation_function)
return (
"\n\n".join(
[
"/* eslint-disable */",
"// This file is auto-generated by dashboard/scripts/generate_openapi_client.py.",
"// Do not edit it manually; update openspec/openapi-v1.yaml and regenerate instead.",
"import type { AxiosRequestConfig, AxiosResponse } from 'axios';",
"import { apiV1Client } from '../http';",
"type RequestArgs = { path?: object; query?: object; body?: unknown } | undefined;",
"function encodePathValue(value: unknown): string {\n return encodeURIComponent(String(value));\n}",
"function applyPathParams(path: string, params?: object): string {\n if (!params) return path;\n const values = params as Record<string, unknown>;\n return path.replace(/\\{([^}:]+)(?::path)?\\}/g, (_match, key) => encodePathValue(values[key]));\n}",
"function request<T>(method: string, path: string, args?: RequestArgs, config?: AxiosRequestConfig): Promise<AxiosResponse<T>> {\n return apiV1Client.request<T>({\n ...config,\n method,\n url: applyPathParams(path, args?.path),\n params: args?.query,\n data: args?.body,\n });\n}",
*declarations,
"export const openApiV1 = {\n"
+ ",\n".join(operation_functions)
+ "\n};",
]
)
+ "\n"
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Generate the dashboard OpenAPI v1 client."
)
parser.add_argument(
"--spec",
default=str(DEFAULT_SPEC),
help="OpenAPI source URL or file path. Defaults to openspec/openapi-v1.yaml.",
)
parser.add_argument(
"--out",
default=str(DEFAULT_OUTPUT),
help="Generated TypeScript output path.",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
spec = load_spec(args.spec)
output = Path(args.out)
if not output.is_absolute():
output = (DASHBOARD_DIR / output).resolve()
output.parent.mkdir(parents=True, exist_ok=True)
output.write_text(TypeScriptGenerator(spec).generate(), encoding="utf-8")
print(f"Generated {output.relative_to(ROOT_DIR)}")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

104
dashboard/src/api/http.ts Normal file
View File

@@ -0,0 +1,104 @@
import axios, {
type AxiosError,
type AxiosInstance,
type InternalAxiosRequestConfig,
} from 'axios';
const AUTH_HEADER = 'Authorization';
const LOCALE_HEADER = 'Accept-Language';
let configured = false;
let originalFetch: typeof window.fetch | null = null;
export const httpClient = axios;
export const legacyApiClient = axios.create({ baseURL: '/api' });
export const apiV1Client = axios.create({ baseURL: '/api/v1' });
function getToken(): string | null {
return localStorage.getItem('token');
}
function getLocale(): string | null {
return localStorage.getItem('astrbot-locale');
}
function setAxiosHeader(
headers: InternalAxiosRequestConfig['headers'],
key: string,
value: string,
) {
if (typeof headers.set === 'function') {
headers.set(key, value);
return;
}
headers[key] = value;
}
function attachAxiosHeaders(config: InternalAxiosRequestConfig) {
const token = getToken();
if (token) {
setAxiosHeader(config.headers, AUTH_HEADER, `Bearer ${token}`);
}
const locale = getLocale();
if (locale) {
setAxiosHeader(config.headers, LOCALE_HEADER, locale);
}
return config;
}
function normalizeAxiosError(error: AxiosError) {
if (error.response?.status === 429) {
const data = error.response.data as { message?: string } | undefined;
if (data?.message) {
return Promise.reject(data.message);
}
}
return Promise.reject(error);
}
function installAxiosInterceptors(instance: AxiosInstance) {
instance.interceptors.request.use(attachAxiosHeaders);
instance.interceptors.response.use((response) => response, normalizeAxiosError);
}
export function fetchWithAuth(input: RequestInfo | URL, init?: RequestInit) {
const fetchImpl = originalFetch ?? window.fetch.bind(window);
const token = getToken();
const locale = getLocale();
if (!token && !locale) {
return fetchImpl(input, init);
}
const requestHeaders =
typeof input !== 'string' && 'headers' in input
? (input as Request).headers
: undefined;
const headers = new Headers(init?.headers || requestHeaders);
if (token && !headers.has(AUTH_HEADER)) {
headers.set(AUTH_HEADER, `Bearer ${token}`);
}
if (locale && !headers.has(LOCALE_HEADER)) {
headers.set(LOCALE_HEADER, locale);
}
return fetchImpl(input, { ...init, headers });
}
export function setupHttpClient() {
if (configured) {
return;
}
installAxiosInterceptors(axios);
installAxiosInterceptors(legacyApiClient);
installAxiosInterceptors(apiV1Client);
originalFetch = window.fetch.bind(window);
window.fetch = fetchWithAuth;
configured = true;
}

1552
dashboard/src/api/v1.ts Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -503,7 +503,8 @@ import {
} from "vue";
import { useRoute, useRouter } from "vue-router";
import { useDisplay } from "vuetify";
import axios from "axios";
import { isAxiosError } from "axios";
import { chatApi } from "@/api/v1";
import StyledMenu from "@/components/shared/StyledMenu.vue";
import ProjectDialog, {
type ProjectFormData,
@@ -903,8 +904,7 @@ async function saveSessionTitleDialog() {
try {
const sessionId = editingSessionTitleId.value;
const displayName = sessionTitleDraft.value.trim();
await axios.post("/api/chat/update_session_display_name", {
session_id: sessionId,
await chatApi.updateSession(sessionId, {
display_name: displayName,
});
updateSessionTitle(sessionId, displayName);
@@ -1202,7 +1202,7 @@ async function createThreadFromSelection() {
const message = threadSelection.message;
if (!currSessionId.value || !message?.id || !threadSelection.selectedText) return;
try {
const response = await axios.post("/api/chat/thread/create", {
const response = await chatApi.createThread({
session_id: currSessionId.value,
parent_message_id: message.id,
selected_text: threadSelection.selectedText,
@@ -1224,7 +1224,7 @@ async function createThreadFromSelection() {
window.getSelection()?.removeAllRanges();
} catch (error) {
toast.error(
axios.isAxiosError(error)
isAxiosError(error)
? error.response?.data?.message || error.message
: tm("thread.createFailed"),
);
@@ -1269,9 +1269,7 @@ async function deleteThread(thread: ChatThread) {
if (!(await askForConfirmation(tm("thread.confirmDelete"), confirmDialog))) return;
deletingThread.value = true;
try {
await axios.post("/api/chat/thread/delete", {
thread_id: thread.thread_id,
});
await chatApi.deleteThread(thread.thread_id);
removeThreadFromMessages(thread.thread_id);
if (activeThread.value?.thread_id === thread.thread_id) {
threadPanelOpen.value = false;

View File

@@ -321,7 +321,7 @@ import { useDisplay } from "vuetify";
import { useModuleI18n } from "@/i18n/composables";
import { useCustomizerStore } from "@/stores/customizer";
import { isComposingEnter } from "@/utils/imeInput.mjs";
import axios from "axios";
import { commandApi } from "@/api/v1";
import type { CommandItem } from "@/components/extension/componentPanel/types";
import ConfigSelector from "./ConfigSelector.vue";
import ProviderModelMenu from "./ProviderModelMenu.vue";
@@ -742,16 +742,12 @@ async function fetchCommands() {
if (commandSuggestionLoading.value) return;
commandSuggestionLoading.value = true;
try {
const params: Record<string, string> = {};
const cid = currentConfigId.value;
if (cid && cid !== "default") {
params.config_id = cid;
}
const res = await axios.get("/api/commands", { params });
const res = await commandApi.list(cid && cid !== "default" ? cid : undefined);
if (res.data.status === "ok") {
allCommands.value = res.data.data.items || [];
// 读取当前配置的唤醒词列表,用于指令候选的触发前缀
const prefixes: string[] = res.data.data.wake_prefix;
const prefixes: string[] = res.data.data.wake_prefix || [];
if (prefixes && prefixes.length > 0) {
wakePrefixes.value = prefixes;
}

View File

@@ -389,6 +389,7 @@
<script setup lang="ts">
import { computed, nextTick, reactive, ref } from "vue";
import axios from "axios";
import { fileApi } from "@/api/v1";
import { setCustomComponents } from "markstream-vue";
import "markstream-vue/index.css";
import RegenerateMenu, {
@@ -693,12 +694,10 @@ function partUrl(part: MessagePart) {
if (part.embedded_url) return part.embedded_url;
if (part.embedded_file?.url) return part.embedded_file.url;
if (part.attachment_id) {
return `/api/chat/get_attachment?attachment_id=${encodeURIComponent(
part.attachment_id,
)}`;
return fileApi.contentUrl(part.attachment_id);
}
if (part.filename) {
return `/api/chat/get_file?filename=${encodeURIComponent(part.filename)}`;
return fileApi.byNameUrl(part.filename);
}
return "";
}

View File

@@ -74,7 +74,7 @@
<script setup lang="ts">
import { computed, onMounted, ref, watch } from 'vue';
import axios from 'axios';
import { configProfileApi, configRouteApi } from '@/api/v1';
import { useToast } from '@/utils/toast';
import { useModuleI18n } from '@/i18n/composables';
import {
@@ -164,8 +164,11 @@ function closeDialog() {
async function fetchConfigList() {
loadingConfigs.value = true;
try {
const res = await axios.get('/api/config/abconfs');
configOptions.value = res.data.data?.info_list || [];
const res = await configProfileApi.list();
configOptions.value = (res.data.data?.info_list || []).map((item: any) => ({
id: String(item.id || ''),
name: String(item.name || item.id || 'default')
}));
} catch (error) {
console.error('加载配置文件列表失败', error);
configOptions.value = [];
@@ -176,7 +179,7 @@ async function fetchConfigList() {
async function fetchRoutingEntries() {
try {
const res = await axios.get('/api/config/umo_abconf_routes');
const res = await configRouteApi.list();
const routing = res.data.data?.routing || {};
routingEntries.value = Object.entries(routing).map(([pattern, confId]) => ({
pattern,
@@ -214,10 +217,9 @@ async function getAgentRunnerType(confId: string): Promise<string> {
return configCache.value[confId];
}
try {
const res = await axios.get('/api/config/abconf', {
params: { id: confId }
});
const type = res.data.data?.config?.provider_settings?.agent_runner_type || 'local';
const res = await configProfileApi.get(confId);
const config = ((res.data.data as any).config || {}) as any;
const type = config?.provider_settings?.agent_runner_type || 'local';
configCache.value[confId] = type;
return type;
} catch (error) {
@@ -244,12 +246,11 @@ async function applySelectionToBackend(confId: string): Promise<boolean> {
}
saving.value = true;
try {
await axios.post('/api/config/umo_abconf_route/update', {
umo: targetUmo.value,
conf_id: confId
});
await configRouteApi.upsert(targetUmo.value, { config_id: confId });
const filtered = routingEntries.value.filter((entry) => entry.pattern !== targetUmo.value);
filtered.push({ pattern: targetUmo.value, confId });
if (confId !== 'default') {
filtered.push({ pattern: targetUmo.value, confId });
}
routingEntries.value = filtered;
return true;
} catch (error) {

View File

@@ -55,6 +55,7 @@
<script setup lang="ts">
import { ref, computed, onBeforeUnmount, watch } from 'vue';
import { useTheme } from 'vuetify';
import { chatApi } from '@/api/v1';
import { useVADRecording } from '@/composables/useVADRecording';
import SiriOrb from './LiveOrb.vue';
@@ -280,8 +281,7 @@ function connectWebSocket(): Promise<void> {
return;
}
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//localhost:6185/api/live_chat/ws?token=${encodeURIComponent(token)}`;
const wsUrl = chatApi.liveWebSocketUrl(token);
ws = new WebSocket(wsUrl);

View File

@@ -256,6 +256,7 @@
<script setup lang="ts">
import { computed, nextTick, reactive, ref } from "vue";
import axios from "axios";
import { fileApi } from "@/api/v1";
import { setCustomComponents } from "markstream-vue";
import "markstream-vue/index.css";
import IPythonToolBlock from "@/components/chat/message_list_comps/IPythonToolBlock.vue";
@@ -347,12 +348,10 @@ function partUrl(part: MessagePart) {
if (part.embedded_url) return part.embedded_url;
if (part.embedded_file?.url) return part.embedded_file.url;
if (part.attachment_id) {
return `/api/chat/get_attachment?attachment_id=${encodeURIComponent(
part.attachment_id,
)}`;
return fileApi.contentUrl(part.attachment_id);
}
if (part.filename) {
return `/api/chat/get_file?filename=${encodeURIComponent(part.filename)}`;
return fileApi.byNameUrl(part.filename);
}
return "";
}

Some files were not shown because too many files have changed in this diff Show More