mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-05 20:30:14 +08:00
Compare commits
1 Commits
codex/umo-
...
refactor/b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8ea47c87e5 |
@@ -157,7 +157,7 @@ class Platform(abc.ABC):
|
||||
当 Dashboard 收到 /api/platform/webhook/{uuid} 请求时,会调用此方法。
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: webhook 请求对象
|
||||
|
||||
Returns:
|
||||
响应内容,格式取决于具体平台的要求
|
||||
|
||||
@@ -132,7 +132,7 @@ class LarkWebhookServer:
|
||||
"""处理 webhook 回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: webhook 请求对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
|
||||
@@ -3,11 +3,11 @@ import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
import quart
|
||||
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
|
||||
|
||||
# remove logger handler
|
||||
for handler in logging.root.handlers[:]:
|
||||
@@ -31,7 +31,7 @@ class QQOfficialWebhook:
|
||||
self.api: BotAPI = BotAPI(http=self.http)
|
||||
self.token = Token(self.appid, self.secret)
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server = FastAPIWebhookServer("qq-official-webhook")
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-qo-webhook/callback",
|
||||
view_func=self.callback,
|
||||
@@ -92,15 +92,15 @@ class QQOfficialWebhook:
|
||||
"""Pop and return extra fields cached from the raw webhook payload for a given message ID."""
|
||||
return self._extra_data_cache.pop(message_id, {})
|
||||
|
||||
async def callback(self):
|
||||
async def callback(self, request):
|
||||
"""内部服务器的回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
return await self.handle_callback(request)
|
||||
|
||||
async def handle_callback(self, request) -> dict:
|
||||
"""处理 webhook 回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: FastAPI webhook request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
|
||||
@@ -2,11 +2,10 @@ import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from quart import Quart, Response, request
|
||||
from fastapi.responses import Response
|
||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
@@ -14,10 +13,11 @@ from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
|
||||
|
||||
|
||||
class SlackWebhookClient:
|
||||
"""Slack Webhook 模式客户端,使用 Quart 作为 Web 服务器"""
|
||||
"""Slack Webhook 模式客户端,使用 FastAPI 作为 Web 服务器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -35,20 +35,16 @@ class SlackWebhookClient:
|
||||
self.path = path
|
||||
self.event_handler = event_handler
|
||||
|
||||
self.app = Quart(__name__)
|
||||
self.app = FastAPIWebhookServer("slack-webhook")
|
||||
self._setup_routes()
|
||||
|
||||
# 禁用 Quart 的默认日志输出
|
||||
logging.getLogger("quart.app").setLevel(logging.WARNING)
|
||||
logging.getLogger("quart.serving").setLevel(logging.WARNING)
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
"""设置路由"""
|
||||
|
||||
@self.app.route(self.path, methods=["POST"])
|
||||
async def slack_events():
|
||||
async def slack_events(request):
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(request)
|
||||
|
||||
@@ -61,7 +57,7 @@ class SlackWebhookClient:
|
||||
"""处理 Slack 回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
req: Quart 请求对象
|
||||
req: webhook 请求对象
|
||||
|
||||
Returns:
|
||||
Response 对象或字典
|
||||
@@ -75,7 +71,7 @@ class SlackWebhookClient:
|
||||
timestamp = req.headers.get("X-Slack-Request-Timestamp")
|
||||
signature = req.headers.get("X-Slack-Signature")
|
||||
if not timestamp or not signature:
|
||||
return Response("Missing headers", status=400)
|
||||
return Response("Missing headers", status_code=400)
|
||||
# Calculate the HMAC signature
|
||||
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
|
||||
my_signature = (
|
||||
@@ -89,7 +85,7 @@ class SlackWebhookClient:
|
||||
# Verify the signature
|
||||
if not hmac.compare_digest(my_signature, signature):
|
||||
logger.warning("Slack request signature verification failed")
|
||||
return Response("Invalid signature", status=400)
|
||||
return Response("Invalid signature", status_code=400)
|
||||
logger.info(f"Received Slack event: {event_data}")
|
||||
|
||||
# 处理 URL 验证事件
|
||||
@@ -99,11 +95,11 @@ class SlackWebhookClient:
|
||||
if self.event_handler and event_data.get("type") == "event_callback":
|
||||
await self.event_handler(event_data)
|
||||
|
||||
return Response("", status=200)
|
||||
return Response("", status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Slack 事件时出错: {e}")
|
||||
return Response("Internal Server Error", status=500)
|
||||
return Response("Internal Server Error", status_code=500)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 Webhook 服务器"""
|
||||
|
||||
@@ -8,7 +8,6 @@ from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
from wechatpy.enterprise import WeChatClient, parse_message
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
@@ -28,6 +27,7 @@ from astrbot.api.platform import (
|
||||
)
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.media_utils import convert_audio_to_wav
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
@@ -65,7 +65,7 @@ def _extract_wecom_media_filename(disposition: str | None) -> str | None:
|
||||
|
||||
class WecomServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict) -> None:
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server = FastAPIWebhookServer("wecom-webhook")
|
||||
self.port = int(cast(str, config.get("port")))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.server.add_url_rule(
|
||||
@@ -89,15 +89,15 @@ class WecomServer:
|
||||
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
async def verify(self, request):
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
return await self.handle_verify(request)
|
||||
|
||||
async def handle_verify(self, request) -> str:
|
||||
"""处理验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: FastAPI webhook request 对象
|
||||
|
||||
Returns:
|
||||
验证响应
|
||||
@@ -117,15 +117,15 @@ class WecomServer:
|
||||
logger.error("验证请求有效性失败,签名异常,请检查配置。")
|
||||
raise
|
||||
|
||||
async def callback_command(self):
|
||||
async def callback_command(self, request):
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
return await self.handle_callback(request)
|
||||
|
||||
async def handle_callback(self, request) -> str:
|
||||
"""处理回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: FastAPI webhook request 对象
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
|
||||
@@ -6,9 +6,8 @@ import asyncio
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import quart
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
|
||||
|
||||
from .wecomai_api import WecomAIBotAPIClient
|
||||
from .wecomai_utils import WecomAIBotConstants
|
||||
@@ -38,14 +37,13 @@ class WecomAIBotServer:
|
||||
self.api_client = api_client
|
||||
self.message_handler = message_handler
|
||||
|
||||
self.app = quart.Quart(__name__)
|
||||
self.app = FastAPIWebhookServer("wecom-ai-bot-webhook")
|
||||
self._setup_routes()
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
"""设置 Quart 路由"""
|
||||
# 使用 Quart 的 add_url_rule 方法添加路由
|
||||
"""设置 FastAPI 路由"""
|
||||
self.app.add_url_rule(
|
||||
"/webhook/wecom-ai-bot",
|
||||
view_func=self.verify_url,
|
||||
@@ -58,15 +56,15 @@ class WecomAIBotServer:
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
async def verify_url(self):
|
||||
async def verify_url(self, request):
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
return await self.handle_verify(request)
|
||||
|
||||
async def handle_verify(self, request):
|
||||
"""处理 URL 验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: FastAPI webhook request 对象
|
||||
|
||||
Returns:
|
||||
验证响应元组 (content, status_code, headers)
|
||||
@@ -91,15 +89,15 @@ class WecomAIBotServer:
|
||||
result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
|
||||
return result, 200, {"Content-Type": "text/plain"}
|
||||
|
||||
async def handle_message(self):
|
||||
async def handle_message(self, request):
|
||||
"""内部服务器的 POST 消息回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
return await self.handle_callback(request)
|
||||
|
||||
async def handle_callback(self, request):
|
||||
"""处理消息回调,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: FastAPI webhook request 对象
|
||||
|
||||
Returns:
|
||||
响应元组 (content, status_code, headers)
|
||||
@@ -186,5 +184,5 @@ class WecomAIBotServer:
|
||||
self.shutdown_event.set()
|
||||
|
||||
def get_app(self):
|
||||
"""获取 Quart 应用实例"""
|
||||
return self.app
|
||||
"""获取 FastAPI 应用实例"""
|
||||
return self.app.app
|
||||
|
||||
@@ -5,7 +5,6 @@ import time
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
from wechatpy import WeChatClient, create_reply, parse_message
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
@@ -25,6 +24,7 @@ from astrbot.api.platform import (
|
||||
)
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.webhook_server import FastAPIWebhookServer
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.media_utils import convert_audio_to_wav
|
||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||
@@ -44,7 +44,7 @@ class WeixinOfficialAccountServer:
|
||||
config: dict,
|
||||
user_buffer: dict[Any, dict[str, Any]],
|
||||
) -> None:
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server = FastAPIWebhookServer("weixin-official-account-webhook")
|
||||
self.port = int(cast(int | str, config.get("port")))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.token = config.get("token")
|
||||
@@ -73,15 +73,15 @@ class WeixinOfficialAccountServer:
|
||||
self.user_buffer: dict[str, dict[str, Any]] = user_buffer # from_user -> state
|
||||
self.active_send_mode = False # 是否启用主动发送模式,启用后 callback 将直接返回回复内容,无需等待微信回调
|
||||
|
||||
async def verify(self):
|
||||
async def verify(self, request):
|
||||
"""内部服务器的 GET 验证入口"""
|
||||
return await self.handle_verify(quart.request)
|
||||
return await self.handle_verify(request)
|
||||
|
||||
async def handle_verify(self, request) -> str:
|
||||
"""处理验证请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: FastAPI webhook request 对象
|
||||
|
||||
Returns:
|
||||
验证响应
|
||||
@@ -105,9 +105,9 @@ class WeixinOfficialAccountServer:
|
||||
logger.error("验证请求有效性失败,签名异常,请检查配置。")
|
||||
return "err"
|
||||
|
||||
async def callback_command(self):
|
||||
async def callback_command(self, request):
|
||||
"""内部服务器的 POST 回调入口"""
|
||||
return await self.handle_callback(quart.request)
|
||||
return await self.handle_callback(request)
|
||||
|
||||
def _maybe_encrypt(self, xml: str, nonce: str | None, timestamp: str | None) -> str:
|
||||
if xml and "<Encrypt>" not in xml and nonce and timestamp:
|
||||
@@ -129,7 +129,7 @@ class WeixinOfficialAccountServer:
|
||||
"""处理回调请求,可被统一 webhook 入口复用
|
||||
|
||||
Args:
|
||||
request: Quart 请求对象
|
||||
request: FastAPI webhook request 对象
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
|
||||
107
astrbot/core/platform/webhook_server.py
Normal file
107
astrbot/core/platform/webhook_server.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from hypercorn.asyncio import serve
|
||||
from hypercorn.config import Config as HyperConfig
|
||||
|
||||
|
||||
class WebhookRequest:
|
||||
def __init__(self, request: Request) -> None:
|
||||
self._request = request
|
||||
self.args = request.query_params
|
||||
self.headers = request.headers
|
||||
self.method = request.method
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
return self._request.json()
|
||||
|
||||
async def get_data(self) -> bytes:
|
||||
return await self._request.body()
|
||||
|
||||
async def get_json(self, *, force: bool = False, silent: bool = False):
|
||||
try:
|
||||
return await self._request.json()
|
||||
except Exception:
|
||||
if silent:
|
||||
return None
|
||||
raise
|
||||
|
||||
|
||||
def _response_from_result(result: Any):
|
||||
if isinstance(result, Response):
|
||||
return result
|
||||
|
||||
if isinstance(result, tuple):
|
||||
content = result[0] if result else ""
|
||||
status_code = (
|
||||
result[1] if len(result) > 1 and isinstance(result[1], int) else 200
|
||||
)
|
||||
headers = result[2] if len(result) > 2 and isinstance(result[2], dict) else None
|
||||
if isinstance(content, dict | list):
|
||||
return JSONResponse(content, status_code=status_code, headers=headers)
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
media_type=headers.get("Content-Type") if headers else None,
|
||||
)
|
||||
|
||||
if isinstance(result, dict | list):
|
||||
return JSONResponse(result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class FastAPIWebhookServer:
|
||||
def __init__(self, name: str) -> None:
|
||||
self.app = FastAPI(title=name, docs_url=None, redoc_url=None, openapi_url=None)
|
||||
|
||||
def add_url_rule(
|
||||
self,
|
||||
path: str,
|
||||
view_func: Callable,
|
||||
methods: list[str] | None = None,
|
||||
) -> None:
|
||||
async def endpoint(request: Request):
|
||||
if inspect.signature(view_func).parameters:
|
||||
result = view_func(WebhookRequest(request))
|
||||
else:
|
||||
result = view_func()
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return _response_from_result(result)
|
||||
|
||||
self.app.add_api_route(
|
||||
path,
|
||||
endpoint,
|
||||
methods=methods or ["GET"],
|
||||
include_in_schema=False,
|
||||
)
|
||||
|
||||
def route(self, path: str, methods: list[str] | None = None):
|
||||
def decorator(view_func: Callable):
|
||||
self.add_url_rule(path, view_func, methods)
|
||||
return view_func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
*,
|
||||
host: str,
|
||||
port: int,
|
||||
shutdown_trigger: Callable | None = None,
|
||||
**_kwargs,
|
||||
) -> None:
|
||||
config = HyperConfig()
|
||||
config.bind = [f"{host}:{port}"]
|
||||
await serve(self.app, config, shutdown_trigger=shutdown_trigger)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
return None
|
||||
598
astrbot/dashboard/fastapi_compat.py
Normal file
598
astrbot/dashboard/fastapi_compat.py
Normal file
@@ -0,0 +1,598 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import inspect
|
||||
import re
|
||||
from collections.abc import Callable, Iterable
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, HTTPException, Request, WebSocket
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import FileResponse, JSONResponse, Response
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
_request_var: contextvars.ContextVar[CompatRequest] = contextvars.ContextVar(
|
||||
"dashboard_request"
|
||||
)
|
||||
_websocket_var: contextvars.ContextVar[CompatWebSocket] = contextvars.ContextVar(
|
||||
"dashboard_websocket"
|
||||
)
|
||||
_g_var: contextvars.ContextVar[CompatG] = contextvars.ContextVar("dashboard_g")
|
||||
_app_var: contextvars.ContextVar[FastAPIAppAdapter] = contextvars.ContextVar(
|
||||
"dashboard_app"
|
||||
)
|
||||
|
||||
|
||||
class CompatArgs:
|
||||
def __init__(self, values) -> None:
|
||||
self._values = values
|
||||
|
||||
def get(self, key: str, default: Any = None, type: Callable | None = None):
|
||||
value = self._values.get(key, default)
|
||||
if value is default or type is None:
|
||||
return value
|
||||
try:
|
||||
return type(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
class CompatMultiDict:
|
||||
def __init__(self, pairs: list[tuple[str, Any]]) -> None:
|
||||
self._pairs = pairs
|
||||
|
||||
def get(self, key: str, default: Any = None, type: Callable | None = None):
|
||||
for item_key, item_value in reversed(self._pairs):
|
||||
if item_key != key:
|
||||
continue
|
||||
if type is None:
|
||||
return item_value
|
||||
try:
|
||||
return type(item_value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return default
|
||||
|
||||
def getlist(self, key: str) -> list[Any]:
|
||||
return [item_value for item_key, item_value in self._pairs if item_key == key]
|
||||
|
||||
def keys(self):
|
||||
return dict.fromkeys(item_key for item_key, _ in self._pairs).keys()
|
||||
|
||||
def values(self):
|
||||
return [self[key] for key in self.keys()]
|
||||
|
||||
def items(self):
|
||||
return [(key, self[key]) for key in self.keys()]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return any(item_key == key for item_key, _ in self._pairs)
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
value = self.get(key)
|
||||
if value is None and key not in self:
|
||||
raise KeyError(key)
|
||||
return value
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._pairs)
|
||||
|
||||
|
||||
class CompatUploadFile:
|
||||
def __init__(self, upload_file: StarletteUploadFile) -> None:
|
||||
self._upload_file = upload_file
|
||||
self.filename = upload_file.filename
|
||||
self.content_type = upload_file.content_type
|
||||
self.headers = upload_file.headers
|
||||
self.content_length = self._resolve_content_length()
|
||||
|
||||
def _resolve_content_length(self) -> int | None:
|
||||
try:
|
||||
raw = self.headers.get("content-length")
|
||||
return int(raw) if raw else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
async def save(self, destination: str | Path) -> None:
|
||||
path = Path(destination)
|
||||
try:
|
||||
await self._upload_file.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
with path.open("wb") as output:
|
||||
while True:
|
||||
chunk = await self._upload_file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
output.write(chunk)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
return getattr(self._upload_file, key)
|
||||
|
||||
|
||||
class CompatG:
|
||||
def __init__(self) -> None:
|
||||
self._values: dict[str, Any] = {}
|
||||
|
||||
def get(self, key: str, default: Any = None):
|
||||
return self._values.get(key, default)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
try:
|
||||
return self._values[key]
|
||||
except KeyError as exc:
|
||||
raise AttributeError(key) from exc
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
if key == "_values":
|
||||
super().__setattr__(key, value)
|
||||
return
|
||||
self._values[key] = value
|
||||
|
||||
|
||||
class CompatRequest:
|
||||
def __init__(self, request: Request) -> None:
|
||||
self._request = request
|
||||
self.args = CompatArgs(request.query_params)
|
||||
self.headers = request.headers
|
||||
self.cookies = request.cookies
|
||||
self.method = request.method
|
||||
self.path = request.url.path
|
||||
self.content_type = request.headers.get("content-type")
|
||||
self.remote_addr = request.client.host if request.client else None
|
||||
self._form_cache: CompatMultiDict | None = None
|
||||
self._files_cache: CompatMultiDict | None = None
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
return self.get_json()
|
||||
|
||||
@property
|
||||
def files(self):
|
||||
return self._load_files()
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
return self._load_form()
|
||||
|
||||
async def get_json(self, silent: bool = False):
|
||||
try:
|
||||
return await self._request.json()
|
||||
except Exception:
|
||||
if silent:
|
||||
return None
|
||||
raise
|
||||
|
||||
async def _load_form_parts(self) -> None:
|
||||
if self._form_cache is not None and self._files_cache is not None:
|
||||
return
|
||||
form = await self._request.form()
|
||||
form_pairs: list[tuple[str, Any]] = []
|
||||
file_pairs: list[tuple[str, Any]] = []
|
||||
for key, value in form.multi_items():
|
||||
if isinstance(value, StarletteUploadFile):
|
||||
file_pairs.append((key, CompatUploadFile(value)))
|
||||
else:
|
||||
form_pairs.append((key, value))
|
||||
self._form_cache = CompatMultiDict(form_pairs)
|
||||
self._files_cache = CompatMultiDict(file_pairs)
|
||||
|
||||
async def _load_form(self) -> CompatMultiDict:
|
||||
await self._load_form_parts()
|
||||
assert self._form_cache is not None
|
||||
return self._form_cache
|
||||
|
||||
async def _load_files(self) -> CompatMultiDict:
|
||||
await self._load_form_parts()
|
||||
assert self._files_cache is not None
|
||||
return self._files_cache
|
||||
|
||||
|
||||
class CompatWebSocket:
|
||||
def __init__(self, websocket: WebSocket) -> None:
|
||||
self._websocket = websocket
|
||||
self.args = CompatArgs(websocket.query_params)
|
||||
self.headers = websocket.headers
|
||||
|
||||
async def accept(self) -> None:
|
||||
await self._websocket.accept()
|
||||
|
||||
async def receive_json(self):
|
||||
return await self._websocket.receive_json()
|
||||
|
||||
async def send_json(self, payload: Any) -> None:
|
||||
await self._websocket.send_json(payload)
|
||||
|
||||
async def close(self, code: int = 1000, reason: str | None = None) -> None:
|
||||
await self._websocket.close(code=code, reason=reason or "")
|
||||
|
||||
|
||||
class CompatTestHeaders:
|
||||
def __init__(self, headers: httpx.Headers) -> None:
|
||||
self._headers = headers
|
||||
|
||||
def getlist(self, key: str) -> list[str]:
|
||||
values = self._headers.get_list(key)
|
||||
if key.lower() == "set-cookie":
|
||||
return [value.replace('=""', "=") for value in values]
|
||||
return values
|
||||
|
||||
def get(self, key: str, default: Any = None):
|
||||
value = self._headers.get(key, default)
|
||||
if isinstance(value, str) and key.lower() == "set-cookie":
|
||||
return value.replace('=""', "=")
|
||||
return value
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return self._headers[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._headers
|
||||
|
||||
|
||||
class CompatTestResponse:
|
||||
def __init__(self, response: httpx.Response) -> None:
|
||||
self._response = response
|
||||
self.status_code = response.status_code
|
||||
self.headers = CompatTestHeaders(response.headers)
|
||||
self.data = response.content
|
||||
self.content = response.content
|
||||
self.text = response.text
|
||||
|
||||
async def get_json(self):
|
||||
return self._response.json()
|
||||
|
||||
async def get_data(self):
|
||||
return self._response.content
|
||||
|
||||
|
||||
class CompatTestClient:
|
||||
def __init__(self, app: FastAPI) -> None:
|
||||
self._client = httpx.AsyncClient(
|
||||
transport=httpx.ASGITransport(app=app),
|
||||
base_url="http://testserver",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_file_storage(value: Any) -> bool:
|
||||
return hasattr(value, "stream") and hasattr(value, "filename")
|
||||
|
||||
@classmethod
|
||||
def _file_tuple(cls, value: Any):
|
||||
stream = value.stream
|
||||
if hasattr(stream, "seek"):
|
||||
stream.seek(0)
|
||||
content = stream.read()
|
||||
filename = getattr(value, "filename", "upload.bin")
|
||||
content_type = getattr(value, "content_type", None)
|
||||
return filename, content, content_type
|
||||
|
||||
@classmethod
|
||||
def _normalize_data(cls, data: Any):
|
||||
if not isinstance(data, dict):
|
||||
return data, None
|
||||
|
||||
form: dict[str, Any] = {}
|
||||
files: list[tuple[str, tuple]] = []
|
||||
for key, value in data.items():
|
||||
if cls._is_file_storage(value):
|
||||
files.append((key, cls._file_tuple(value)))
|
||||
continue
|
||||
if isinstance(value, Iterable) and not isinstance(
|
||||
value, str | bytes | dict
|
||||
):
|
||||
values = list(value)
|
||||
if values and all(cls._is_file_storage(item) for item in values):
|
||||
files.extend((key, cls._file_tuple(item)) for item in values)
|
||||
continue
|
||||
form[key] = value
|
||||
return form, files or None
|
||||
|
||||
@classmethod
|
||||
def _normalize_files(cls, files: Any):
|
||||
if isinstance(files, dict):
|
||||
items = files.items()
|
||||
elif isinstance(files, Iterable) and not isinstance(files, str | bytes):
|
||||
items = files
|
||||
else:
|
||||
return files
|
||||
|
||||
normalized_files: list[tuple[str, Any]] = []
|
||||
for key, value in items:
|
||||
if cls._is_file_storage(value):
|
||||
normalized_files.append((key, cls._file_tuple(value)))
|
||||
continue
|
||||
if isinstance(value, Iterable) and not isinstance(
|
||||
value, str | bytes | dict
|
||||
):
|
||||
values = list(value)
|
||||
if values and all(cls._is_file_storage(item) for item in values):
|
||||
normalized_files.extend(
|
||||
(key, cls._file_tuple(item)) for item in values
|
||||
)
|
||||
continue
|
||||
normalized_files.append((key, value))
|
||||
return normalized_files
|
||||
|
||||
async def request(self, method: str, url: str, **kwargs):
|
||||
data = kwargs.pop("data", None)
|
||||
if data is not None and "files" not in kwargs:
|
||||
normalized_data, files = self._normalize_data(data)
|
||||
kwargs["data"] = normalized_data
|
||||
if files:
|
||||
kwargs["files"] = files
|
||||
elif data is not None:
|
||||
kwargs["data"] = data
|
||||
if "files" in kwargs:
|
||||
kwargs["files"] = self._normalize_files(kwargs["files"])
|
||||
response = await self._client.request(method, url, **kwargs)
|
||||
return CompatTestResponse(response)
|
||||
|
||||
async def get(self, url: str, **kwargs):
|
||||
return await self.request("GET", url, **kwargs)
|
||||
|
||||
async def post(self, url: str, **kwargs):
|
||||
return await self.request("POST", url, **kwargs)
|
||||
|
||||
async def put(self, url: str, **kwargs):
|
||||
return await self.request("PUT", url, **kwargs)
|
||||
|
||||
async def patch(self, url: str, **kwargs):
|
||||
return await self.request("PATCH", url, **kwargs)
|
||||
|
||||
async def delete(self, url: str, **kwargs):
|
||||
return await self.request("DELETE", url, **kwargs)
|
||||
|
||||
|
||||
class _ContextProxy:
|
||||
def __init__(self, var) -> None:
|
||||
self._var = var
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
return getattr(self._var.get(), key)
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
if key == "_var":
|
||||
super().__setattr__(key, value)
|
||||
return
|
||||
setattr(self._var.get(), key, value)
|
||||
|
||||
|
||||
request = _ContextProxy(_request_var)
|
||||
websocket = _ContextProxy(_websocket_var)
|
||||
g = _ContextProxy(_g_var)
|
||||
current_app = _ContextProxy(_app_var)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_request_context(
|
||||
request_: Request,
|
||||
app: FastAPIAppAdapter,
|
||||
g_obj: CompatG | None = None,
|
||||
):
|
||||
token_request = _request_var.set(CompatRequest(request_))
|
||||
token_g = _g_var.set(g_obj or getattr(request_.state, "dashboard_g", CompatG()))
|
||||
token_app = _app_var.set(app)
|
||||
try:
|
||||
yield _g_var.get()
|
||||
finally:
|
||||
_app_var.reset(token_app)
|
||||
_g_var.reset(token_g)
|
||||
_request_var.reset(token_request)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_websocket_context(
|
||||
websocket_: WebSocket,
|
||||
app: FastAPIAppAdapter,
|
||||
g_obj: CompatG | None = None,
|
||||
):
|
||||
token_websocket = _websocket_var.set(CompatWebSocket(websocket_))
|
||||
token_g = _g_var.set(g_obj or getattr(websocket_.state, "dashboard_g", CompatG()))
|
||||
token_app = _app_var.set(app)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_app_var.reset(token_app)
|
||||
_g_var.reset(token_g)
|
||||
_websocket_var.reset(token_websocket)
|
||||
|
||||
|
||||
def jsonify(payload: Any = None):
|
||||
return JSONResponse(payload if payload is not None else {})
|
||||
|
||||
|
||||
async def make_response(*args):
|
||||
if not args:
|
||||
return Response()
|
||||
content = args[0]
|
||||
status_code = args[1] if len(args) > 1 and isinstance(args[1], int) else None
|
||||
headers = args[1] if len(args) > 1 and isinstance(args[1], dict) else None
|
||||
if len(args) > 2 and isinstance(args[2], dict):
|
||||
headers = args[2]
|
||||
if isinstance(content, Response):
|
||||
if status_code is not None:
|
||||
content.status_code = status_code
|
||||
if headers:
|
||||
content.headers.update(headers)
|
||||
return content
|
||||
if hasattr(content, "__aiter__"):
|
||||
return StreamingResponse(
|
||||
content,
|
||||
status_code=status_code or 200,
|
||||
headers=headers,
|
||||
)
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=status_code or 200,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
async def send_file(path: str | Path, mimetype: str | None = None, **kwargs):
|
||||
filename = kwargs.get("attachment_filename") or kwargs.get("download_name")
|
||||
as_attachment = bool(kwargs.get("as_attachment"))
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type=mimetype,
|
||||
filename=filename if as_attachment else None,
|
||||
)
|
||||
|
||||
|
||||
def abort(status_code: int):
|
||||
raise HTTPException(status_code=status_code)
|
||||
|
||||
|
||||
def _convert_rule(path: str) -> str:
|
||||
converted = re.sub(r"<path:([A-Za-z_][A-Za-z0-9_]*)>", r"{\1:path}", path)
|
||||
converted = re.sub(r"<([A-Za-z_][A-Za-z0-9_]*)>", r"{\1}", converted)
|
||||
return converted
|
||||
|
||||
|
||||
async def _call_view(view_func: Callable, path_params: dict[str, Any]):
|
||||
result = view_func(**path_params)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return _coerce_view_result(result)
|
||||
|
||||
|
||||
def _coerce_view_result(result: Any):
|
||||
if isinstance(result, Response):
|
||||
return result
|
||||
|
||||
if isinstance(result, tuple):
|
||||
content = result[0] if result else None
|
||||
status_code = next((item for item in result[1:] if isinstance(item, int)), 200)
|
||||
headers = next(
|
||||
(item for item in result[1:] if isinstance(item, dict)),
|
||||
None,
|
||||
)
|
||||
if isinstance(content, Response):
|
||||
content.status_code = status_code
|
||||
if headers:
|
||||
content.headers.update(headers)
|
||||
return content
|
||||
return _response_from_content(content, status_code=status_code, headers=headers)
|
||||
|
||||
if isinstance(result, dict | list):
|
||||
return JSONResponse(jsonable_encoder(result))
|
||||
return result
|
||||
|
||||
|
||||
def _response_from_content(
|
||||
content: Any,
|
||||
*,
|
||||
status_code: int,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
if isinstance(content, dict | list):
|
||||
return JSONResponse(
|
||||
jsonable_encoder(content),
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
)
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
async def call_request_view(
|
||||
request_: Request,
|
||||
app: FastAPIAppAdapter,
|
||||
view_func: Callable,
|
||||
path_params: dict[str, Any] | None = None,
|
||||
g_obj: CompatG | None = None,
|
||||
):
|
||||
with bind_request_context(request_, app, g_obj):
|
||||
return await _call_view(view_func, path_params or {})
|
||||
|
||||
|
||||
async def call_websocket_view(
|
||||
websocket_: WebSocket,
|
||||
app: FastAPIAppAdapter,
|
||||
view_func: Callable,
|
||||
path_params: dict[str, Any] | None = None,
|
||||
*,
|
||||
accept: bool = True,
|
||||
):
|
||||
if accept:
|
||||
await websocket_.accept()
|
||||
with bind_websocket_context(websocket_, app):
|
||||
return await _call_view(view_func, path_params or {})
|
||||
|
||||
|
||||
class FastAPIAppAdapter:
|
||||
def __init__(self, app: FastAPI, static_folder: str | None = None) -> None:
|
||||
self._app = app
|
||||
self.static_folder = static_folder
|
||||
self.config: dict[str, Any] = {}
|
||||
self.debug = False
|
||||
self.testing = False
|
||||
self.name = "dashboard"
|
||||
|
||||
def add_url_rule(
|
||||
self,
|
||||
path: str,
|
||||
view_func: Callable,
|
||||
methods: list[str] | None = None,
|
||||
endpoint: str | None = None,
|
||||
) -> None:
|
||||
route_path = _convert_rule(path)
|
||||
methods = methods or ["GET"]
|
||||
|
||||
async def endpoint_func(request_: Request):
|
||||
with bind_request_context(request_, self):
|
||||
return await _call_view(view_func, dict(request_.path_params))
|
||||
|
||||
self._app.add_api_route(
|
||||
route_path,
|
||||
endpoint_func,
|
||||
methods=methods,
|
||||
name=endpoint,
|
||||
include_in_schema=False,
|
||||
)
|
||||
|
||||
def websocket(self, path: str):
|
||||
route_path = _convert_rule(path)
|
||||
|
||||
def decorator(view_func: Callable):
|
||||
async def endpoint_func(websocket_: WebSocket):
|
||||
return await call_websocket_view(
|
||||
websocket_,
|
||||
self,
|
||||
view_func,
|
||||
dict(websocket_.path_params),
|
||||
)
|
||||
|
||||
self._app.add_api_websocket_route(
|
||||
route_path,
|
||||
endpoint_func,
|
||||
name=getattr(view_func, "__name__", None),
|
||||
)
|
||||
return view_func
|
||||
|
||||
return decorator
|
||||
|
||||
def errorhandler(self, _status_code: int):
|
||||
def decorator(func: Callable):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def send_static_file(self, filename: str):
|
||||
if not self.static_folder:
|
||||
raise HTTPException(status_code=404)
|
||||
return FileResponse(Path(self.static_folder) / filename)
|
||||
|
||||
def test_client(self):
|
||||
self.testing = True
|
||||
return CompatTestClient(self._app)
|
||||
|
||||
|
||||
CompatResponse = Response
|
||||
@@ -1,6 +1,6 @@
|
||||
from urllib.parse import unquote
|
||||
|
||||
from quart import request
|
||||
from astrbot.dashboard.fastapi_compat import request
|
||||
|
||||
PLUGIN_PAGE_CONTENT_PREFIX = "/api/plugin/page/content/"
|
||||
PLUGIN_PAGE_BRIDGE_PATH = "/api/plugin/page/bridge-sdk.js"
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from quart import g, request
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.datetime_utils import normalize_datetime_utc
|
||||
from astrbot.dashboard.fastapi_compat import g, request
|
||||
from astrbot.dashboard.services.api_key_service import (
|
||||
ApiKeyService,
|
||||
ApiKeyServiceError,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
ALL_OPEN_API_SCOPES = ("chat", "config", "file", "im")
|
||||
|
||||
|
||||
class ApiKeyRoute(Route):
|
||||
def __init__(self, context: RouteContext, db: BaseDatabase) -> None:
|
||||
super().__init__(context)
|
||||
self.db = db
|
||||
self.service = ApiKeyService(db)
|
||||
self.routes = {
|
||||
"/apikey/list": ("GET", self.list_api_keys),
|
||||
"/apikey/create": ("POST", self.create_api_key),
|
||||
@@ -25,119 +21,39 @@ class ApiKeyRoute(Route):
|
||||
self.register_routes()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_utc(dt: datetime | None) -> datetime | None:
|
||||
return normalize_datetime_utc(dt)
|
||||
|
||||
@classmethod
|
||||
def _serialize_datetime(cls, dt: datetime | None) -> str | None:
|
||||
normalized = cls._normalize_utc(dt)
|
||||
if normalized is None:
|
||||
return None
|
||||
return normalized.astimezone().isoformat()
|
||||
def _ok(data=None):
|
||||
return Response().ok(data=data).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _hash_key(raw_key: str) -> str:
|
||||
return hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
raw_key.encode("utf-8"),
|
||||
b"astrbot_api_key",
|
||||
100_000,
|
||||
).hex()
|
||||
def _error(message: str):
|
||||
return Response().error(message).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _serialize_api_key(key) -> dict:
|
||||
expires_at = ApiKeyRoute._normalize_utc(key.expires_at)
|
||||
return {
|
||||
"key_id": key.key_id,
|
||||
"name": key.name,
|
||||
"key_prefix": key.key_prefix,
|
||||
"scopes": key.scopes or [],
|
||||
"created_by": key.created_by,
|
||||
"created_at": ApiKeyRoute._serialize_datetime(key.created_at),
|
||||
"updated_at": ApiKeyRoute._serialize_datetime(key.updated_at),
|
||||
"last_used_at": ApiKeyRoute._serialize_datetime(key.last_used_at),
|
||||
"expires_at": ApiKeyRoute._serialize_datetime(key.expires_at),
|
||||
"revoked_at": ApiKeyRoute._serialize_datetime(key.revoked_at),
|
||||
"is_revoked": key.revoked_at is not None,
|
||||
"is_expired": bool(expires_at and expires_at < datetime.now(timezone.utc)),
|
||||
}
|
||||
async def _json_body(self):
|
||||
return await request.json or {}
|
||||
|
||||
async def _run(self, operation):
|
||||
try:
|
||||
return self._ok(await operation())
|
||||
except ApiKeyServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
async def _run_json(self, operation):
|
||||
payload = await self._json_body()
|
||||
return await self._run(lambda: operation(payload))
|
||||
|
||||
async def list_api_keys(self):
|
||||
keys = await self.db.list_api_keys()
|
||||
return (
|
||||
Response().ok(data=[self._serialize_api_key(key) for key in keys]).__dict__
|
||||
)
|
||||
return await self._run(self.service.list_api_keys)
|
||||
|
||||
async def create_api_key(self):
|
||||
post_data = await request.json or {}
|
||||
|
||||
name = str(post_data.get("name", "")).strip() or "Untitled API Key"
|
||||
scopes = post_data.get("scopes")
|
||||
if scopes is None:
|
||||
normalized_scopes = list(ALL_OPEN_API_SCOPES)
|
||||
elif isinstance(scopes, list):
|
||||
normalized_scopes = [
|
||||
scope
|
||||
for scope in scopes
|
||||
if isinstance(scope, str) and scope in ALL_OPEN_API_SCOPES
|
||||
]
|
||||
normalized_scopes = list(dict.fromkeys(normalized_scopes))
|
||||
if not normalized_scopes:
|
||||
return Response().error("At least one valid scope is required").__dict__
|
||||
else:
|
||||
return Response().error("Invalid scopes").__dict__
|
||||
|
||||
expires_at = None
|
||||
expires_in_days = post_data.get("expires_in_days")
|
||||
if expires_in_days is not None:
|
||||
try:
|
||||
expires_in_days_int = int(expires_in_days)
|
||||
except (TypeError, ValueError):
|
||||
return Response().error("expires_in_days must be an integer").__dict__
|
||||
if expires_in_days_int <= 0:
|
||||
return (
|
||||
Response().error("expires_in_days must be greater than 0").__dict__
|
||||
)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=expires_in_days_int
|
||||
return await self._run_json(
|
||||
lambda payload: self.service.create_api_key_from_legacy_payload(
|
||||
payload,
|
||||
created_by=g.get("username", "unknown"),
|
||||
)
|
||||
|
||||
raw_key = f"abk_{secrets.token_urlsafe(32)}"
|
||||
key_hash = self._hash_key(raw_key)
|
||||
key_prefix = raw_key[:12]
|
||||
created_by = g.get("username", "unknown")
|
||||
|
||||
api_key = await self.db.create_api_key(
|
||||
name=name,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
scopes=normalized_scopes, # type: ignore
|
||||
created_by=created_by,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
payload = self._serialize_api_key(api_key)
|
||||
payload["api_key"] = raw_key
|
||||
return Response().ok(data=payload).__dict__
|
||||
|
||||
async def revoke_api_key(self):
|
||||
post_data = await request.json or {}
|
||||
key_id = post_data.get("key_id")
|
||||
if not key_id:
|
||||
return Response().error("Missing key: key_id").__dict__
|
||||
|
||||
success = await self.db.revoke_api_key(key_id)
|
||||
if not success:
|
||||
return Response().error("API key not found").__dict__
|
||||
return Response().ok().__dict__
|
||||
return await self._run_json(self.service.revoke_api_key_from_legacy_payload)
|
||||
|
||||
async def delete_api_key(self):
|
||||
post_data = await request.json or {}
|
||||
key_id = post_data.get("key_id")
|
||||
if not key_id:
|
||||
return Response().error("Missing key: key_id").__dict__
|
||||
|
||||
success = await self.db.delete_api_key(key_id)
|
||||
if not success:
|
||||
return Response().error("API key not found").__dict__
|
||||
return Response().ok().__dict__
|
||||
return await self._run_json(self.service.delete_api_key_from_legacy_payload)
|
||||
|
||||
@@ -1,73 +1,27 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import os
|
||||
|
||||
import jwt
|
||||
import pyotp
|
||||
from quart import current_app, g, jsonify, make_response, request
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core import DEMO_MODE
|
||||
from astrbot.core.utils.auth_password import (
|
||||
is_default_dashboard_password,
|
||||
is_legacy_dashboard_password,
|
||||
validate_dashboard_password,
|
||||
verify_dashboard_password,
|
||||
from astrbot.dashboard.fastapi_compat import (
|
||||
current_app,
|
||||
g,
|
||||
jsonify,
|
||||
make_response,
|
||||
request,
|
||||
)
|
||||
from astrbot.core.utils.totp import (
|
||||
from astrbot.dashboard.services.auth_service import (
|
||||
DASHBOARD_JWT_COOKIE_MAX_AGE,
|
||||
DASHBOARD_JWT_COOKIE_NAME,
|
||||
TOTP_TRUSTED_DEVICE_COOKIE_NAME,
|
||||
TOTP_TRUSTED_DEVICE_MAX_AGE,
|
||||
TwoFactorCodeType,
|
||||
consume_configured_totp_code,
|
||||
consume_rotation_verified,
|
||||
consume_totp_code,
|
||||
generate_recovery_code,
|
||||
is_totp_enabled,
|
||||
is_totp_trusted_device_valid,
|
||||
issue_totp_trusted_device,
|
||||
revoke_user_trusted_devices,
|
||||
set_pending_totp_secret,
|
||||
set_rotation_verified,
|
||||
verify_configured_2fa_code,
|
||||
)
|
||||
from astrbot.dashboard.password_state import (
|
||||
get_dashboard_password_hash,
|
||||
is_password_change_required,
|
||||
is_password_storage_upgraded,
|
||||
set_dashboard_password_hashes,
|
||||
set_password_change_required,
|
||||
set_password_storage_upgraded,
|
||||
AuthService,
|
||||
AuthServiceResult,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
DASHBOARD_JWT_COOKIE_NAME = "astrbot_dashboard_jwt"
|
||||
DASHBOARD_JWT_COOKIE_MAX_AGE = 7 * 24 * 60 * 60
|
||||
SKIP_DEFAULT_PASSWORD_AUTH_ENV = "ASTRBOT_DASHBOARD_SKIP_DEFAULT_PASSWORD_AUTH"
|
||||
SKIP_DEFAULT_PASSWORD_AUTH_ENV_LEGACY = "DASHBOARD_SKIP_DEFAULT_PASSWORD_AUTH"
|
||||
LOCAL_DASHBOARD_HOSTS = {"127.0.0.1", "localhost", "::1"}
|
||||
DEFAULT_PASSWORD_LOGIN_FAILURE_MESSAGE = (
|
||||
"Login failed. If this is your first time using AstrBot, the old default "
|
||||
"astrbot password has been replaced by a random strong password printed in "
|
||||
"the startup logs. Check the initial password in the logs and try again. "
|
||||
"Learn more: https://docs.astrbot.app/en/faq.html\n\n"
|
||||
"登录失败。如果您是初次使用,旧版默认 astrbot 密码已改为启动日志中输出的"
|
||||
"随机强密码。请使用日志中提供的的初始密码来登录。了解更多:"
|
||||
"https://docs.astrbot.app/faq.html"
|
||||
)
|
||||
LEGACY_PASSWORD_LOGIN_FAILURE_MESSAGE = (
|
||||
"Incorrect username or password. If you cannot log in after upgrading "
|
||||
"AstrBot even though the password is correct, see "
|
||||
"https://docs.astrbot.app/en/faq.html\n\n"
|
||||
"用户名或密码错误。如果你在升级 AstrBot 后遇到了密码正确但无法登录的情况,"
|
||||
"请参考 https://docs.astrbot.app/faq.html"
|
||||
)
|
||||
__all__ = ("AuthRoute",)
|
||||
|
||||
|
||||
class AuthRoute(Route):
|
||||
def __init__(self, context: RouteContext, db) -> None:
|
||||
super().__init__(context)
|
||||
self.db = db
|
||||
self.routes = {
|
||||
"/auth/login": ("POST", self.login),
|
||||
"/auth/logout": ("POST", self.logout),
|
||||
@@ -78,261 +32,43 @@ class AuthRoute(Route):
|
||||
"/auth/totp/recovery": ("POST", self.totp_recovery),
|
||||
"/auth/account/edit": ("POST", self.edit_account),
|
||||
}
|
||||
self.service = AuthService(db, self.config)
|
||||
self.register_routes()
|
||||
|
||||
async def setup_status(self):
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"setup_required": await self._is_setup_required(),
|
||||
"skip_default_password_auth": self._can_skip_default_password_auth(),
|
||||
"password_upgrade_required": not await is_password_storage_upgraded(
|
||||
self.db,
|
||||
self.config,
|
||||
),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
async def _json_body(self):
|
||||
return await request.json
|
||||
|
||||
async def _service_json_response(self, operation, *args, **kwargs):
|
||||
return await self._service_response(
|
||||
await operation(await self._json_body(), *args, **kwargs)
|
||||
)
|
||||
|
||||
async def setup_status(self):
|
||||
return await self._service_response(await self.service.setup_status())
|
||||
|
||||
async def totp_setup(self):
|
||||
post_data = await request.json
|
||||
|
||||
if isinstance(post_data, dict) and post_data.get("secret"):
|
||||
secret = post_data["secret"]
|
||||
code = post_data.get("code")
|
||||
if not isinstance(secret, str) or not secret.strip():
|
||||
return Response().error("Invalid request payload").__dict__
|
||||
|
||||
if not isinstance(code, str) or not code.strip():
|
||||
return Response().error("TOTP 验证码是必需的").__dict__
|
||||
if not await consume_totp_code(secret, code):
|
||||
return Response().error("TOTP 验证码无效").__dict__
|
||||
|
||||
if is_totp_enabled(self.config) and not consume_rotation_verified():
|
||||
return Response().error("需要先验证当前 TOTP").__dict__
|
||||
|
||||
set_pending_totp_secret(secret)
|
||||
recovery_code, recovery_code_hash = generate_recovery_code()
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"recovery_code": recovery_code,
|
||||
"recovery_code_hash": recovery_code_hash,
|
||||
},
|
||||
"TOTP verified",
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
if is_totp_enabled(self.config):
|
||||
if not isinstance(post_data, dict):
|
||||
return Response().error("Invalid request payload").__dict__
|
||||
|
||||
set_rotation_verified(False)
|
||||
|
||||
code = post_data.get("code")
|
||||
if isinstance(code, str) and code.strip():
|
||||
if await consume_configured_totp_code(self.config, code):
|
||||
set_rotation_verified(True)
|
||||
return Response().ok({"secret": pyotp.random_base32()}).__dict__
|
||||
return Response().error("当前 TOTP 验证码无效").__dict__
|
||||
|
||||
return Response().error("需要提供 TOTP 验证码或新密钥").__dict__
|
||||
|
||||
return Response().ok({"secret": pyotp.random_base32()}).__dict__
|
||||
return await self._service_json_response(self.service.totp_setup)
|
||||
|
||||
async def totp_recovery(self):
|
||||
# This endpoint MUST NOT persist the recovery code.
|
||||
recovery_code, recovery_code_hash = generate_recovery_code()
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"recovery_code": recovery_code,
|
||||
"recovery_code_hash": recovery_code_hash,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
return await self._service_response(await self.service.totp_recovery())
|
||||
|
||||
async def setup(self):
|
||||
if not self._can_skip_default_password_auth():
|
||||
return Response().error("Setup without password is not enabled").__dict__
|
||||
if not await self._is_setup_required():
|
||||
return Response().error("Setup is not required").__dict__
|
||||
|
||||
return await self._complete_setup()
|
||||
return await self._service_json_response(self.service.setup)
|
||||
|
||||
async def setup_authenticated(self):
|
||||
if not await self._is_setup_required():
|
||||
return Response().error("Setup is not required").__dict__
|
||||
if not isinstance(getattr(g, "username", None), str):
|
||||
return Response().error("未授权").__dict__
|
||||
|
||||
return await self._complete_setup()
|
||||
|
||||
async def _complete_setup(self):
|
||||
post_data = await request.json
|
||||
if not isinstance(post_data, dict):
|
||||
return Response().error("Invalid request payload").__dict__
|
||||
|
||||
new_username = post_data.get("username")
|
||||
new_password = post_data.get("password")
|
||||
confirm_password = post_data.get("confirm_password")
|
||||
if not isinstance(new_username, str) or len(new_username.strip()) < 3:
|
||||
return Response().error("用户名长度至少3位").__dict__
|
||||
if not isinstance(new_password, str):
|
||||
return Response().error("新密码无效").__dict__
|
||||
if not isinstance(confirm_password, str) or confirm_password != new_password:
|
||||
return Response().error("两次输入的新密码不一致").__dict__
|
||||
|
||||
try:
|
||||
validate_dashboard_password(new_password)
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
username = new_username.strip()
|
||||
self.config["dashboard"]["username"] = username
|
||||
set_dashboard_password_hashes(self.config, new_password)
|
||||
await set_password_storage_upgraded(self.db, self.config, True)
|
||||
await set_password_change_required(self.db, self.config, False)
|
||||
self.config.save_config()
|
||||
|
||||
token = self.generate_jwt(username)
|
||||
payload = Response().ok(
|
||||
{
|
||||
"token": token,
|
||||
"username": username,
|
||||
"change_pwd_hint": False,
|
||||
"legacy_pwd_hint": False,
|
||||
"password_upgrade_required": False,
|
||||
},
|
||||
"Setup completed successfully",
|
||||
return await self._service_json_response(
|
||||
self.service.setup_authenticated,
|
||||
getattr(g, "username", None),
|
||||
)
|
||||
response = await make_response(jsonify(payload.__dict__))
|
||||
self._set_dashboard_jwt_cookie(response, token)
|
||||
return response
|
||||
|
||||
async def login(self):
|
||||
username = self.config["dashboard"]["username"]
|
||||
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
|
||||
password = get_dashboard_password_hash(self.config, upgraded=storage_upgraded)
|
||||
post_data = await request.json
|
||||
|
||||
req_username = (
|
||||
post_data.get("username") if isinstance(post_data, dict) else None
|
||||
return await self._service_json_response(
|
||||
self.service.login,
|
||||
trusted_device_cookie_token=request.cookies.get(
|
||||
TOTP_TRUSTED_DEVICE_COOKIE_NAME,
|
||||
"",
|
||||
).strip(),
|
||||
)
|
||||
req_password = (
|
||||
post_data.get("password") if isinstance(post_data, dict) else None
|
||||
)
|
||||
totp_code = post_data.get("code") if isinstance(post_data, dict) else None
|
||||
trust_device_flag = (
|
||||
post_data.get("trust_device_flag") is True
|
||||
if isinstance(post_data, dict)
|
||||
else False
|
||||
)
|
||||
if not isinstance(req_username, str) or not isinstance(req_password, str):
|
||||
return Response().error("Invalid request payload").__dict__
|
||||
|
||||
login_verified = req_username == username and verify_dashboard_password(
|
||||
password, req_password
|
||||
)
|
||||
|
||||
if not login_verified:
|
||||
await asyncio.sleep(3)
|
||||
if req_password == "astrbot":
|
||||
return Response().error(DEFAULT_PASSWORD_LOGIN_FAILURE_MESSAGE).__dict__
|
||||
if is_legacy_dashboard_password(password):
|
||||
return Response().error(LEGACY_PASSWORD_LOGIN_FAILURE_MESSAGE).__dict__
|
||||
return await self._error_response(
|
||||
"用户名或密码错误",
|
||||
401,
|
||||
)
|
||||
|
||||
totp_verified = False
|
||||
|
||||
if is_totp_enabled(self.config):
|
||||
cookie_token = request.cookies.get(
|
||||
TOTP_TRUSTED_DEVICE_COOKIE_NAME, ""
|
||||
).strip()
|
||||
if not await is_totp_trusted_device_valid(
|
||||
self.config, self.db, cookie_token
|
||||
):
|
||||
if not isinstance(totp_code, str) or not totp_code.strip():
|
||||
response = await make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "需要 TOTP 验证",
|
||||
"data": {"totp_required": True},
|
||||
}
|
||||
)
|
||||
)
|
||||
response.status_code = 401
|
||||
return response
|
||||
verified_type = await verify_configured_2fa_code(
|
||||
self.config, totp_code, allow_recovery=True
|
||||
)
|
||||
if verified_type is TwoFactorCodeType.TOTP:
|
||||
totp_verified = True
|
||||
elif verified_type is TwoFactorCodeType.RECOVERY:
|
||||
self.config["dashboard"]["totp"] = {
|
||||
"enable": False,
|
||||
"secret": "",
|
||||
"recovery_code_hash": "",
|
||||
}
|
||||
await revoke_user_trusted_devices(self.db)
|
||||
self.config.save_config()
|
||||
elif len(totp_code) == 6 and totp_code.isdigit():
|
||||
return await self._error_response("TOTP 验证码无效", 401)
|
||||
else:
|
||||
return await self._error_response("恢复码无效", 401)
|
||||
|
||||
change_pwd_hint = False
|
||||
legacy_pwd_hint = is_legacy_dashboard_password(password)
|
||||
password_change_required = await is_password_change_required(
|
||||
self.db,
|
||||
self.config,
|
||||
)
|
||||
if (
|
||||
storage_upgraded
|
||||
and username == "astrbot"
|
||||
and is_default_dashboard_password(password)
|
||||
and not DEMO_MODE
|
||||
):
|
||||
change_pwd_hint = True
|
||||
legacy_pwd_hint = True
|
||||
logger.warning("为了保证安全,请尽快修改默认密码。")
|
||||
if password_change_required and not DEMO_MODE:
|
||||
change_pwd_hint = True
|
||||
token = self.generate_jwt(username)
|
||||
login_data = {
|
||||
"token": token,
|
||||
"username": username,
|
||||
"change_pwd_hint": change_pwd_hint,
|
||||
"legacy_pwd_hint": legacy_pwd_hint,
|
||||
"password_upgrade_required": not storage_upgraded,
|
||||
}
|
||||
payload = Response().ok(login_data)
|
||||
response = await make_response(jsonify(payload.__dict__))
|
||||
self._set_dashboard_jwt_cookie(response, token)
|
||||
|
||||
if totp_verified and trust_device_flag:
|
||||
raw_token = await issue_totp_trusted_device(self.config, self.db)
|
||||
if raw_token:
|
||||
response.set_cookie(
|
||||
TOTP_TRUSTED_DEVICE_COOKIE_NAME,
|
||||
raw_token,
|
||||
max_age=TOTP_TRUSTED_DEVICE_MAX_AGE,
|
||||
httponly=True,
|
||||
samesite="Strict",
|
||||
secure=AuthRoute._use_secure_dashboard_jwt_cookie(),
|
||||
path="/api/auth",
|
||||
)
|
||||
return response
|
||||
|
||||
async def logout(self):
|
||||
response = await make_response(
|
||||
@@ -342,117 +78,38 @@ class AuthRoute(Route):
|
||||
return response
|
||||
|
||||
async def edit_account(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
|
||||
password = get_dashboard_password_hash(self.config, upgraded=storage_upgraded)
|
||||
post_data = await request.json
|
||||
if not isinstance(post_data, dict):
|
||||
return Response().error("Invalid request payload").__dict__
|
||||
|
||||
req_password = post_data.get("password")
|
||||
if not isinstance(req_password, str):
|
||||
return Response().error("Invalid request payload").__dict__
|
||||
|
||||
if not verify_dashboard_password(password, req_password):
|
||||
return Response().error("原密码错误").__dict__
|
||||
|
||||
new_pwd = post_data.get("new_password", None)
|
||||
new_username = post_data.get("new_username", None)
|
||||
password_change_required = await is_password_change_required(
|
||||
self.db,
|
||||
self.config,
|
||||
)
|
||||
if (not storage_upgraded or password_change_required) and not new_pwd:
|
||||
return Response().error("请设置新密码以完成安全升级").__dict__
|
||||
if not new_pwd and not new_username:
|
||||
return Response().error("新用户名和新密码不能同时为空").__dict__
|
||||
|
||||
# Verify password confirmation
|
||||
if new_pwd:
|
||||
if not isinstance(new_pwd, str):
|
||||
return Response().error("新密码无效").__dict__
|
||||
confirm_pwd = post_data.get("confirm_password", None)
|
||||
if not isinstance(confirm_pwd, str) or confirm_pwd != new_pwd:
|
||||
return Response().error("两次输入的新密码不一致").__dict__
|
||||
try:
|
||||
validate_dashboard_password(new_pwd)
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
set_dashboard_password_hashes(self.config, new_pwd)
|
||||
await set_password_storage_upgraded(self.db, self.config, True)
|
||||
await set_password_change_required(self.db, self.config, False)
|
||||
if is_totp_enabled(self.config):
|
||||
await revoke_user_trusted_devices(self.db)
|
||||
if new_username:
|
||||
self.config["dashboard"]["username"] = new_username
|
||||
|
||||
self.config.save_config()
|
||||
|
||||
return Response().ok(None, "Updated account successfully").__dict__
|
||||
return await self._service_json_response(self.service.edit_account)
|
||||
|
||||
def generate_jwt(self, username):
|
||||
payload = {
|
||||
"username": username,
|
||||
"exp": datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(days=7),
|
||||
}
|
||||
jwt_token = self.config["dashboard"].get("jwt_secret", None)
|
||||
if not jwt_token:
|
||||
raise ValueError("JWT secret is not set in the cmd_config.")
|
||||
token = jwt.encode(payload, jwt_token, algorithm="HS256")
|
||||
return token
|
||||
return self.service.generate_jwt(username)
|
||||
|
||||
async def _is_setup_required(self) -> bool:
|
||||
if DEMO_MODE:
|
||||
return False
|
||||
|
||||
dashboard_config = self.config["dashboard"]
|
||||
password_change_required = await is_password_change_required(
|
||||
self.db,
|
||||
self.config,
|
||||
async def _service_response(self, result: AuthServiceResult):
|
||||
payload = (
|
||||
Response().error(result.message or "")
|
||||
if result.status == "error"
|
||||
else Response().ok(result.data, result.message)
|
||||
)
|
||||
if password_change_required:
|
||||
return True
|
||||
if result.status == "error" and result.data is not None:
|
||||
payload.data = result.data
|
||||
|
||||
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
|
||||
if not storage_upgraded:
|
||||
return False
|
||||
response = await make_response(jsonify(payload.__dict__))
|
||||
response.status_code = result.status_code
|
||||
|
||||
return dashboard_config.get(
|
||||
"username"
|
||||
) == "astrbot" and is_default_dashboard_password(
|
||||
dashboard_config.get("pbkdf2_password", "")
|
||||
)
|
||||
if result.jwt_token:
|
||||
self._set_dashboard_jwt_cookie(response, result.jwt_token)
|
||||
|
||||
@staticmethod
|
||||
async def _error_response(message: str, status_code: int):
|
||||
response = await make_response(jsonify(Response().error(message).__dict__))
|
||||
response.status_code = status_code
|
||||
if result.trusted_device_token:
|
||||
response.set_cookie(
|
||||
TOTP_TRUSTED_DEVICE_COOKIE_NAME,
|
||||
result.trusted_device_token,
|
||||
max_age=TOTP_TRUSTED_DEVICE_MAX_AGE,
|
||||
httponly=True,
|
||||
samesite="Strict",
|
||||
secure=AuthRoute._use_secure_dashboard_jwt_cookie(),
|
||||
path="/api/auth",
|
||||
)
|
||||
return response
|
||||
|
||||
def _can_skip_default_password_auth(self) -> bool:
|
||||
if not self._env_flag_enabled(SKIP_DEFAULT_PASSWORD_AUTH_ENV):
|
||||
return False
|
||||
host = (
|
||||
os.environ.get("DASHBOARD_HOST")
|
||||
or os.environ.get("ASTRBOT_DASHBOARD_HOST")
|
||||
or self.config["dashboard"].get("host", "")
|
||||
)
|
||||
return str(host).strip().lower() in LOCAL_DASHBOARD_HOSTS
|
||||
|
||||
@staticmethod
|
||||
def _env_flag_enabled(name: str) -> bool:
|
||||
value = os.environ.get(name)
|
||||
if value is None and name == SKIP_DEFAULT_PASSWORD_AUTH_ENV:
|
||||
value = os.environ.get(SKIP_DEFAULT_PASSWORD_AUTH_ENV_LEGACY)
|
||||
return str(value or "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
@staticmethod
|
||||
def _use_secure_dashboard_jwt_cookie() -> bool:
|
||||
return bool(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,9 @@
|
||||
from quart import g, request
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
from astrbot.dashboard.fastapi_compat import g, request
|
||||
from astrbot.dashboard.services.chatui_project_service import (
|
||||
ChatUIProjectService,
|
||||
ChatUIProjectServiceError,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -22,225 +24,96 @@ class ChatUIProjectRoute(Route):
|
||||
),
|
||||
"/chatui_project/get_sessions": ("GET", self.get_project_sessions),
|
||||
}
|
||||
self.db = db
|
||||
self.service = ChatUIProjectService(db)
|
||||
self.register_routes()
|
||||
|
||||
@staticmethod
|
||||
def _username() -> str:
|
||||
return g.get("username", "guest")
|
||||
|
||||
@staticmethod
|
||||
def _service_error(exc: ChatUIProjectServiceError):
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _ok(data=None):
|
||||
return Response().ok(data=data).__dict__
|
||||
|
||||
@staticmethod
|
||||
async def _json_body() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(self, operation):
|
||||
try:
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
return self._ok(result)
|
||||
except ChatUIProjectServiceError as exc:
|
||||
return self._service_error(exc)
|
||||
|
||||
async def _run_json(self, operation):
|
||||
async def invoke():
|
||||
data = await self._json_body()
|
||||
return operation(data)
|
||||
|
||||
return await self._run(invoke)
|
||||
|
||||
async def create_project(self):
|
||||
"""Create a new ChatUI project."""
|
||||
username = g.get("username", "guest")
|
||||
post_data = await request.json
|
||||
|
||||
title = post_data.get("title")
|
||||
emoji = post_data.get("emoji", "📁")
|
||||
description = post_data.get("description")
|
||||
|
||||
if not title:
|
||||
return Response().error("Missing key: title").__dict__
|
||||
|
||||
project = await self.db.create_chatui_project(
|
||||
creator=username,
|
||||
title=title,
|
||||
emoji=emoji,
|
||||
description=description,
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"project_id": project.project_id,
|
||||
"title": project.title,
|
||||
"emoji": project.emoji,
|
||||
"description": project.description,
|
||||
"created_at": to_utc_isoformat(project.created_at),
|
||||
"updated_at": to_utc_isoformat(project.updated_at),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
return await self._run_json(
|
||||
lambda data: self.service.create_project(self._username(), data)
|
||||
)
|
||||
|
||||
async def list_projects(self):
|
||||
"""Get all ChatUI projects for the current user."""
|
||||
username = g.get("username", "guest")
|
||||
|
||||
projects = await self.db.get_chatui_projects_by_creator(creator=username)
|
||||
|
||||
projects_data = [
|
||||
{
|
||||
"project_id": project.project_id,
|
||||
"title": project.title,
|
||||
"emoji": project.emoji,
|
||||
"description": project.description,
|
||||
"created_at": to_utc_isoformat(project.created_at),
|
||||
"updated_at": to_utc_isoformat(project.updated_at),
|
||||
}
|
||||
for project in projects
|
||||
]
|
||||
|
||||
return Response().ok(data=projects_data).__dict__
|
||||
return await self._run(lambda: self.service.list_projects(self._username()))
|
||||
|
||||
async def get_project(self):
|
||||
"""Get a specific ChatUI project."""
|
||||
project_id = request.args.get("project_id")
|
||||
if not project_id:
|
||||
return Response().error("Missing key: project_id").__dict__
|
||||
|
||||
username = g.get("username", "guest")
|
||||
|
||||
project = await self.db.get_chatui_project_by_id(project_id)
|
||||
if not project:
|
||||
return Response().error(f"Project {project_id} not found").__dict__
|
||||
|
||||
# Verify ownership
|
||||
if project.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"project_id": project.project_id,
|
||||
"title": project.title,
|
||||
"emoji": project.emoji,
|
||||
"description": project.description,
|
||||
"created_at": to_utc_isoformat(project.created_at),
|
||||
"updated_at": to_utc_isoformat(project.updated_at),
|
||||
}
|
||||
return await self._run(
|
||||
lambda: self.service.get_project_from_legacy_query(
|
||||
self._username(),
|
||||
request.args.get("project_id"),
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def update_chatui_project(self):
|
||||
"""Update a ChatUI project."""
|
||||
post_data = await request.json
|
||||
|
||||
project_id = post_data.get("project_id")
|
||||
title = post_data.get("title")
|
||||
emoji = post_data.get("emoji")
|
||||
description = post_data.get("description")
|
||||
|
||||
if not project_id:
|
||||
return Response().error("Missing key: project_id").__dict__
|
||||
|
||||
username = g.get("username", "guest")
|
||||
|
||||
# Verify ownership
|
||||
project = await self.db.get_chatui_project_by_id(project_id)
|
||||
if not project:
|
||||
return Response().error(f"Project {project_id} not found").__dict__
|
||||
if project.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
await self.db.update_chatui_project(
|
||||
project_id=project_id,
|
||||
title=title,
|
||||
emoji=emoji,
|
||||
description=description,
|
||||
return await self._run_json(
|
||||
lambda data: self.service.update_project(self._username(), data)
|
||||
)
|
||||
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def delete_project(self):
|
||||
"""Delete a ChatUI project."""
|
||||
project_id = request.args.get("project_id")
|
||||
if not project_id:
|
||||
return Response().error("Missing key: project_id").__dict__
|
||||
|
||||
username = g.get("username", "guest")
|
||||
|
||||
# Verify ownership
|
||||
project = await self.db.get_chatui_project_by_id(project_id)
|
||||
if not project:
|
||||
return Response().error(f"Project {project_id} not found").__dict__
|
||||
if project.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
await self.db.delete_chatui_project(project_id)
|
||||
|
||||
return Response().ok().__dict__
|
||||
return await self._run(
|
||||
lambda: self.service.delete_project_from_legacy_query(
|
||||
self._username(),
|
||||
request.args.get("project_id"),
|
||||
)
|
||||
)
|
||||
|
||||
async def add_session_to_project(self):
|
||||
"""Add a session to a project."""
|
||||
post_data = await request.json
|
||||
|
||||
session_id = post_data.get("session_id")
|
||||
project_id = post_data.get("project_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("Missing key: session_id").__dict__
|
||||
if not project_id:
|
||||
return Response().error("Missing key: project_id").__dict__
|
||||
|
||||
username = g.get("username", "guest")
|
||||
|
||||
# Verify project ownership
|
||||
project = await self.db.get_chatui_project_by_id(project_id)
|
||||
if not project:
|
||||
return Response().error(f"Project {project_id} not found").__dict__
|
||||
if project.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
# Verify session ownership
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
if not session:
|
||||
return Response().error(f"Session {session_id} not found").__dict__
|
||||
if session.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
await self.db.add_session_to_project(session_id, project_id)
|
||||
|
||||
return Response().ok().__dict__
|
||||
return await self._run_json(
|
||||
lambda data: self.service.add_session_to_project(self._username(), data)
|
||||
)
|
||||
|
||||
async def remove_session_from_project(self):
|
||||
"""Remove a session from its project."""
|
||||
post_data = await request.json
|
||||
|
||||
session_id = post_data.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("Missing key: session_id").__dict__
|
||||
|
||||
username = g.get("username", "guest")
|
||||
|
||||
# Verify session ownership
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
if not session:
|
||||
return Response().error(f"Session {session_id} not found").__dict__
|
||||
if session.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
await self.db.remove_session_from_project(session_id)
|
||||
|
||||
return Response().ok().__dict__
|
||||
return await self._run_json(
|
||||
lambda data: self.service.remove_session_from_project(
|
||||
self._username(),
|
||||
data,
|
||||
)
|
||||
)
|
||||
|
||||
async def get_project_sessions(self):
|
||||
"""Get all sessions in a project."""
|
||||
project_id = request.args.get("project_id")
|
||||
if not project_id:
|
||||
return Response().error("Missing key: project_id").__dict__
|
||||
|
||||
username = g.get("username", "guest")
|
||||
|
||||
# Verify project ownership
|
||||
project = await self.db.get_chatui_project_by_id(project_id)
|
||||
if not project:
|
||||
return Response().error(f"Project {project_id} not found").__dict__
|
||||
if project.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
sessions = await self.db.get_project_sessions(project_id)
|
||||
|
||||
sessions_data = [
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"platform_id": session.platform_id,
|
||||
"creator": session.creator,
|
||||
"display_name": session.display_name,
|
||||
"is_group": session.is_group,
|
||||
"created_at": to_utc_isoformat(session.created_at),
|
||||
"updated_at": to_utc_isoformat(session.updated_at),
|
||||
}
|
||||
for session in sessions
|
||||
]
|
||||
|
||||
return Response().ok(data=sessions_data).__dict__
|
||||
return await self._run(
|
||||
lambda: self.service.get_project_sessions_from_legacy_query(
|
||||
self._username(),
|
||||
request.args.get("project_id"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,17 +1,7 @@
|
||||
from quart import request
|
||||
|
||||
from astrbot.core.star.command_management import (
|
||||
list_command_conflicts,
|
||||
list_commands,
|
||||
)
|
||||
from astrbot.core.star.command_management import (
|
||||
rename_command as rename_command_service,
|
||||
)
|
||||
from astrbot.core.star.command_management import (
|
||||
toggle_command as toggle_command_service,
|
||||
)
|
||||
from astrbot.core.star.command_management import (
|
||||
update_command_permission as update_command_permission_service,
|
||||
from astrbot.dashboard.fastapi_compat import request
|
||||
from astrbot.dashboard.services.command_service import (
|
||||
CommandService,
|
||||
CommandServiceError,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
@@ -20,7 +10,7 @@ from .route import Response, Route, RouteContext
|
||||
class CommandRoute(Route):
|
||||
def __init__(self, context: RouteContext, core_lifecycle=None) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.service = CommandService(self.config, core_lifecycle)
|
||||
self.routes = {
|
||||
"/commands": ("GET", self.get_commands),
|
||||
"/commands/conflicts": ("GET", self.get_conflicts),
|
||||
@@ -30,88 +20,50 @@ class CommandRoute(Route):
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
@staticmethod
|
||||
def _ok(data=None):
|
||||
return Response().ok(data).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str):
|
||||
return Response().error(message).__dict__
|
||||
|
||||
@staticmethod
|
||||
async def _json_body() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(self, operation):
|
||||
try:
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
return self._ok(result)
|
||||
except CommandServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
async def _run_json(self, operation):
|
||||
async def invoke():
|
||||
data = await self._json_body()
|
||||
return operation(data)
|
||||
|
||||
return await self._run(invoke)
|
||||
|
||||
async def get_commands(self):
|
||||
commands = await list_commands()
|
||||
summary = {
|
||||
"total": len(commands),
|
||||
"disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
|
||||
"conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
|
||||
}
|
||||
# 优先从指定 config_id 的配置中读取唤醒词,否则使用默认配置
|
||||
config_id = request.args.get("config_id", "").strip()
|
||||
wake_prefix = self.config.get("wake_prefix", ["/"])
|
||||
if config_id and self.core_lifecycle:
|
||||
acm = getattr(self.core_lifecycle, "astrbot_config_mgr", None)
|
||||
if acm and config_id in acm.confs:
|
||||
wake_prefix = acm.confs[config_id].get("wake_prefix", wake_prefix)
|
||||
return (
|
||||
Response()
|
||||
.ok({"items": commands, "summary": summary, "wake_prefix": wake_prefix})
|
||||
.__dict__
|
||||
return await self._run(
|
||||
self.service.list_commands_from_legacy_query(
|
||||
request.args.get("config_id", "")
|
||||
)
|
||||
)
|
||||
|
||||
async def get_conflicts(self):
|
||||
conflicts = await list_command_conflicts()
|
||||
return Response().ok(conflicts).__dict__
|
||||
return await self._run(self.service.list_conflicts())
|
||||
|
||||
async def toggle_command(self):
|
||||
data = await request.get_json()
|
||||
handler_full_name = data.get("handler_full_name")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if handler_full_name is None or enabled is None:
|
||||
return Response().error("handler_full_name 与 enabled 均为必填。").__dict__
|
||||
|
||||
if isinstance(enabled, str):
|
||||
enabled = enabled.lower() in ("1", "true", "yes", "on")
|
||||
|
||||
try:
|
||||
await toggle_command_service(handler_full_name, bool(enabled))
|
||||
except ValueError as exc:
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
payload = await _get_command_payload(handler_full_name)
|
||||
return Response().ok(payload).__dict__
|
||||
return await self._run_json(self.service.toggle_command_from_legacy_payload)
|
||||
|
||||
async def rename_command(self):
|
||||
data = await request.get_json()
|
||||
handler_full_name = data.get("handler_full_name")
|
||||
new_name = data.get("new_name")
|
||||
aliases = data.get("aliases")
|
||||
|
||||
if not handler_full_name or not new_name:
|
||||
return Response().error("handler_full_name 与 new_name 均为必填。").__dict__
|
||||
|
||||
try:
|
||||
await rename_command_service(handler_full_name, new_name, aliases=aliases)
|
||||
except ValueError as exc:
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
payload = await _get_command_payload(handler_full_name)
|
||||
return Response().ok(payload).__dict__
|
||||
return await self._run_json(self.service.rename_command_from_legacy_payload)
|
||||
|
||||
async def update_permission(self):
|
||||
data = await request.get_json()
|
||||
handler_full_name = data.get("handler_full_name")
|
||||
permission = data.get("permission")
|
||||
|
||||
if not handler_full_name or not permission:
|
||||
return (
|
||||
Response().error("handler_full_name 与 permission 均为必填。").__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
await update_command_permission_service(handler_full_name, permission)
|
||||
except ValueError as exc:
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
payload = await _get_command_payload(handler_full_name)
|
||||
return Response().ok(payload).__dict__
|
||||
|
||||
|
||||
async def _get_command_payload(handler_full_name: str):
|
||||
commands = await list_commands()
|
||||
for cmd in commands:
|
||||
if cmd["handler_full_name"] == handler_full_name:
|
||||
return cmd
|
||||
return {}
|
||||
return await self._run_json(self.service.update_permission_from_legacy_payload)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,11 @@
|
||||
import json
|
||||
import traceback
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
from quart import request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.umo_alias import build_umo_alias_map, parse_umo, serialize_umo_alias
|
||||
from astrbot.dashboard.fastapi_compat import request, send_file
|
||||
from astrbot.dashboard.services.conversation_service import (
|
||||
ConversationService,
|
||||
ConversationServiceError,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -24,379 +20,102 @@ class ConversationRoute(Route):
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/conversation/list": ("GET", self.list_conversations),
|
||||
"/conversation/detail": (
|
||||
"POST",
|
||||
self.get_conv_detail,
|
||||
),
|
||||
"/conversation/detail": ("POST", self.get_conv_detail),
|
||||
"/conversation/update": ("POST", self.upd_conv),
|
||||
"/conversation/delete": ("POST", self.del_conv),
|
||||
"/conversation/update_history": (
|
||||
"POST",
|
||||
self.update_history,
|
||||
),
|
||||
"/conversation/update_history": ("POST", self.update_history),
|
||||
"/conversation/export": ("POST", self.export_conversations),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.service = ConversationService(db_helper, core_lifecycle)
|
||||
self.register_routes()
|
||||
|
||||
def _build_umo_info(self, umo: str | None, alias_map: dict) -> dict:
|
||||
umo_str = umo or ""
|
||||
return {
|
||||
"umo": umo_str,
|
||||
**parse_umo(umo_str),
|
||||
**serialize_umo_alias(alias_map.get(umo_str), umo_str),
|
||||
}
|
||||
@staticmethod
|
||||
def _error(message: str):
|
||||
return Response().error(message).__dict__
|
||||
|
||||
def _serialize_conversation(self, conversation, alias_map: dict) -> dict:
|
||||
return {
|
||||
**asdict(conversation),
|
||||
"umo_info": self._build_umo_info(conversation.user_id, alias_map),
|
||||
}
|
||||
@staticmethod
|
||||
def _ok(data=None):
|
||||
return Response().ok(data).__dict__
|
||||
|
||||
@staticmethod
|
||||
async def _json_body() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(self, operation, *, label: str):
|
||||
try:
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
return self._ok(result)
|
||||
except ConversationServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
except Exception as exc:
|
||||
logger.error("%s: %s", label, exc, exc_info=True)
|
||||
return self._error(f"{label}: {exc!s}")
|
||||
|
||||
async def _run_json(self, operation, *, label: str):
|
||||
async def invoke():
|
||||
data = await self._json_body()
|
||||
return operation(data)
|
||||
|
||||
return await self._run(invoke, label=label)
|
||||
|
||||
async def list_conversations(self):
|
||||
"""获取对话列表,支持分页、排序和筛选"""
|
||||
try:
|
||||
# 获取分页参数
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
# 获取筛选参数
|
||||
platforms = request.args.get("platforms", "")
|
||||
message_types = request.args.get("message_types", "")
|
||||
search_query = request.args.get("search", "")
|
||||
exclude_ids = request.args.get("exclude_ids", "")
|
||||
exclude_platforms = request.args.get("exclude_platforms", "")
|
||||
|
||||
# 转换为列表
|
||||
platform_list = platforms.split(",") if platforms else []
|
||||
message_type_list = message_types.split(",") if message_types else []
|
||||
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
|
||||
exclude_platform_list = (
|
||||
exclude_platforms.split(",") if exclude_platforms else []
|
||||
)
|
||||
|
||||
page = max(page, 1)
|
||||
if page_size < 1:
|
||||
page_size = 20
|
||||
page_size = min(page_size, 100)
|
||||
|
||||
try:
|
||||
(
|
||||
conversations,
|
||||
total_count,
|
||||
) = await self.conv_mgr.get_filtered_conversations(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
platforms=platform_list,
|
||||
message_types=message_type_list,
|
||||
search_query=search_query,
|
||||
exclude_ids=exclude_id_list,
|
||||
exclude_platforms=exclude_platform_list,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库查询出错: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"数据库查询出错: {e!s}").__dict__
|
||||
|
||||
# 计算总页数
|
||||
total_pages = (
|
||||
(total_count + page_size - 1) // page_size if total_count > 0 else 1
|
||||
)
|
||||
umos = sorted({conv.user_id for conv in conversations if conv.user_id})
|
||||
alias_map = build_umo_alias_map(await self.db_helper.get_umo_aliases(umos))
|
||||
|
||||
result = {
|
||||
"conversations": [
|
||||
self._serialize_conversation(conversation, alias_map)
|
||||
for conversation in conversations
|
||||
],
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total_count,
|
||||
"total_pages": total_pages,
|
||||
},
|
||||
}
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取对话列表失败: {e!s}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取对话列表失败: {e!s}").__dict__
|
||||
return await self._run(
|
||||
self.service.list_conversations_from_legacy_query(
|
||||
page=request.args.get("page", 1),
|
||||
page_size=request.args.get("page_size", 20),
|
||||
platforms=request.args.get("platforms", ""),
|
||||
message_types=request.args.get("message_types", ""),
|
||||
search_query=request.args.get("search", ""),
|
||||
exclude_ids=request.args.get("exclude_ids", ""),
|
||||
exclude_platforms=request.args.get("exclude_platforms", ""),
|
||||
),
|
||||
label="获取对话列表失败",
|
||||
)
|
||||
|
||||
async def get_conv_detail(self):
|
||||
"""获取指定对话详情(通过POST请求)"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
user_id = data.get("user_id")
|
||||
cid = data.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
if not conversation:
|
||||
return Response().error("对话不存在").__dict__
|
||||
|
||||
alias_map = build_umo_alias_map(
|
||||
await self.db_helper.get_umo_aliases([user_id])
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"cid": cid,
|
||||
"title": conversation.title,
|
||||
"persona_id": conversation.persona_id,
|
||||
"history": conversation.history,
|
||||
"created_at": conversation.created_at,
|
||||
"updated_at": conversation.updated_at,
|
||||
"umo_info": self._build_umo_info(user_id, alias_map),
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取对话详情失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取对话详情失败: {e!s}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.get_conversation_detail,
|
||||
label="获取对话详情失败",
|
||||
)
|
||||
|
||||
async def upd_conv(self):
|
||||
"""更新对话信息(标题和角色ID)"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
user_id = data.get("user_id")
|
||||
cid = data.get("cid")
|
||||
title = data.get("title")
|
||||
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
if not conversation:
|
||||
return Response().error("对话不存在").__dict__
|
||||
|
||||
persona_id = data.get("persona_id", conversation.persona_id)
|
||||
|
||||
if title is not None or persona_id is not None:
|
||||
await self.conv_mgr.update_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
return Response().ok({"message": "对话信息更新成功"}).__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新对话信息失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新对话信息失败: {e!s}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.update_conversation,
|
||||
label="更新对话信息失败",
|
||||
)
|
||||
|
||||
async def del_conv(self):
|
||||
"""删除对话"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
|
||||
# 检查是否是批量删除
|
||||
if "conversations" in data:
|
||||
# 批量删除
|
||||
conversations = data.get("conversations", [])
|
||||
if not conversations:
|
||||
return (
|
||||
Response().error("批量删除时conversations参数不能为空").__dict__
|
||||
)
|
||||
|
||||
deleted_count = 0
|
||||
failed_items = []
|
||||
|
||||
for conv in conversations:
|
||||
user_id = conv.get("user_id")
|
||||
cid = conv.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
failed_items.append(
|
||||
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
|
||||
|
||||
message = f"成功删除 {deleted_count} 个对话"
|
||||
if failed_items:
|
||||
message += f",失败 {len(failed_items)} 个"
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": message,
|
||||
"deleted_count": deleted_count,
|
||||
"failed_count": len(failed_items),
|
||||
"failed_items": failed_items,
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
# 单个删除
|
||||
user_id = data.get("user_id")
|
||||
cid = data.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
|
||||
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
return Response().ok({"message": "对话删除成功"}).__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除对话失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"删除对话失败: {e!s}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.delete_conversation,
|
||||
label="删除对话失败",
|
||||
)
|
||||
|
||||
async def update_history(self):
|
||||
"""更新对话历史内容"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
user_id = data.get("user_id")
|
||||
cid = data.get("cid")
|
||||
history = data.get("history")
|
||||
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
|
||||
if history is None:
|
||||
return Response().error("缺少必要参数: history").__dict__
|
||||
|
||||
# 历史记录必须是合法的 JSON 字符串
|
||||
try:
|
||||
if isinstance(history, list):
|
||||
history = json.dumps(history)
|
||||
else:
|
||||
# 验证是否为有效的 JSON 字符串
|
||||
json.loads(history)
|
||||
except json.JSONDecodeError:
|
||||
return (
|
||||
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
|
||||
)
|
||||
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
if not conversation:
|
||||
return Response().error("对话不存在").__dict__
|
||||
|
||||
history = json.loads(history) if isinstance(history, str) else history
|
||||
|
||||
await self.conv_mgr.update_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
history=history,
|
||||
)
|
||||
|
||||
return Response().ok({"message": "对话历史更新成功"}).__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新对话历史失败: {e!s}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.update_history,
|
||||
label="更新对话历史失败",
|
||||
)
|
||||
|
||||
async def export_conversations(self):
|
||||
"""批量导出对话为 JSONL 格式"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
conversations_to_export = data.get("conversations", [])
|
||||
|
||||
if not conversations_to_export:
|
||||
return Response().error("导出列表不能为空").__dict__
|
||||
|
||||
# 收集所有对话的内容
|
||||
jsonl_lines = []
|
||||
exported_count = 0
|
||||
failed_items = []
|
||||
|
||||
for conv_info in conversations_to_export:
|
||||
user_id = conv_info.get("user_id")
|
||||
cid = conv_info.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
failed_items.append(
|
||||
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
failed_items.append(
|
||||
f"user_id:{user_id}, cid:{cid} - 对话不存在"
|
||||
)
|
||||
continue
|
||||
|
||||
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
|
||||
content = json.loads(conversation.history)
|
||||
|
||||
# 创建导出记录
|
||||
export_record = {
|
||||
"cid": cid,
|
||||
"user_id": user_id,
|
||||
"platform_id": conversation.platform_id,
|
||||
"title": conversation.title,
|
||||
"persona_id": conversation.persona_id,
|
||||
"created_at": conversation.created_at,
|
||||
"updated_at": conversation.updated_at,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# 将记录转换为 JSON 字符串并添加到 JSONL
|
||||
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
|
||||
exported_count += 1
|
||||
|
||||
except Exception as e:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
|
||||
logger.error(
|
||||
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
|
||||
)
|
||||
|
||||
if exported_count == 0:
|
||||
return Response().error("没有成功导出任何对话").__dict__
|
||||
|
||||
# 创建 JSONL 内容
|
||||
jsonl_content = "\n".join(jsonl_lines)
|
||||
|
||||
# 创建一个内存文件对象
|
||||
file_obj = BytesIO(jsonl_content.encode("utf-8"))
|
||||
file_obj.seek(0)
|
||||
|
||||
# 生成文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
|
||||
|
||||
# 返回文件流
|
||||
export = await self.service.export_conversations(await self._json_body())
|
||||
return await send_file(
|
||||
file_obj,
|
||||
mimetype="application/jsonl",
|
||||
export.file_obj,
|
||||
mimetype=export.mimetype,
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
attachment_filename=export.filename,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"批量导出对话失败: {e!s}").__dict__
|
||||
except ConversationServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
except Exception as exc:
|
||||
logger.error("批量导出对话失败: %s", exc, exc_info=True)
|
||||
return self._error(f"批量导出对话失败: {exc!s}")
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from quart import jsonify, request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.dashboard.fastapi_compat import jsonify, request
|
||||
from astrbot.dashboard.services.cron_service import CronService, CronServiceError
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -15,8 +10,7 @@ class CronRoute(Route):
|
||||
self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
self.service = CronService(core_lifecycle)
|
||||
self.routes = [
|
||||
("/cron/jobs", ("GET", self.list_jobs)),
|
||||
("/cron/jobs", ("POST", self.create_job)),
|
||||
@@ -26,276 +20,50 @@ class CronRoute(Route):
|
||||
]
|
||||
self.register_routes()
|
||||
|
||||
def _serialize_job(self, job) -> dict:
|
||||
data = job.model_dump() if hasattr(job, "model_dump") else job.__dict__
|
||||
for k in ["created_at", "updated_at", "last_run_at", "next_run_time"]:
|
||||
v = data.get(k)
|
||||
if isinstance(v, datetime):
|
||||
# Attach UTC
|
||||
if v.tzinfo is None:
|
||||
v = v.replace(tzinfo=timezone.utc)
|
||||
data[k] = v.isoformat()
|
||||
# expose note explicitly for UI (prefer payload.note then description)
|
||||
payload = data.get("payload") or {}
|
||||
data["note"] = payload.get("note") or data.get("description") or ""
|
||||
data["run_at"] = payload.get("run_at")
|
||||
data["run_once"] = data.get("run_once", False)
|
||||
# status is internal; hide to avoid implying one-time completion for recurring jobs
|
||||
data.pop("status", None)
|
||||
return data
|
||||
@staticmethod
|
||||
def _ok(data=None, message: str | None = None):
|
||||
return jsonify(Response().ok(data=data, message=message).__dict__)
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str):
|
||||
return jsonify(Response().error(message).__dict__)
|
||||
|
||||
@staticmethod
|
||||
async def _json_body() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(self, operation, *, message: str | None = None):
|
||||
try:
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
return self._ok(result, message)
|
||||
except CronServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
async def _run_json(self, operation, *, message: str | None = None):
|
||||
async def invoke():
|
||||
data = await self._json_body()
|
||||
return operation(data)
|
||||
|
||||
return await self._run(invoke, message=message)
|
||||
|
||||
async def list_jobs(self):
|
||||
try:
|
||||
cron_mgr = self.core_lifecycle.cron_manager
|
||||
if cron_mgr is None:
|
||||
return jsonify(
|
||||
Response().error("Cron manager not initialized").__dict__
|
||||
)
|
||||
job_type = request.args.get("type")
|
||||
jobs = await cron_mgr.list_jobs(job_type)
|
||||
data = [self._serialize_job(j) for j in jobs]
|
||||
return jsonify(Response().ok(data=data).__dict__)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify(Response().error(f"Failed to list jobs: {e!s}").__dict__)
|
||||
return await self._run(
|
||||
self.service.list_jobs_from_legacy_query(request.args.get("type"))
|
||||
)
|
||||
|
||||
async def create_job(self):
|
||||
try:
|
||||
cron_mgr = self.core_lifecycle.cron_manager
|
||||
if cron_mgr is None:
|
||||
return jsonify(
|
||||
Response().error("Cron manager not initialized").__dict__
|
||||
)
|
||||
|
||||
payload = await request.json
|
||||
if not isinstance(payload, dict):
|
||||
return jsonify(Response().error("Invalid payload").__dict__)
|
||||
|
||||
name = payload.get("name") or "active_agent_task"
|
||||
cron_expression = payload.get("cron_expression")
|
||||
note = payload.get("note") or payload.get("description") or name
|
||||
session = str(payload.get("session") or "").strip()
|
||||
persona_id = payload.get("persona_id")
|
||||
provider_id = payload.get("provider_id")
|
||||
timezone = payload.get("timezone")
|
||||
enabled = bool(payload.get("enabled", True))
|
||||
run_once = bool(payload.get("run_once", False))
|
||||
run_at = payload.get("run_at")
|
||||
|
||||
if run_once and not run_at:
|
||||
return jsonify(
|
||||
Response().error("run_at is required when run_once=true").__dict__
|
||||
)
|
||||
if (not run_once) and not cron_expression:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error("cron_expression is required when run_once=false")
|
||||
.__dict__
|
||||
)
|
||||
if run_once and cron_expression:
|
||||
cron_expression = None # ignore cron when run_once specified
|
||||
run_at_dt = None
|
||||
if run_at:
|
||||
try:
|
||||
run_at_dt = datetime.fromisoformat(str(run_at))
|
||||
except Exception:
|
||||
return jsonify(
|
||||
Response().error("run_at must be ISO datetime").__dict__
|
||||
)
|
||||
|
||||
job_payload = {
|
||||
"session": session,
|
||||
"note": note,
|
||||
"persona_id": persona_id,
|
||||
"provider_id": provider_id,
|
||||
"run_at": run_at,
|
||||
"origin": "api",
|
||||
}
|
||||
|
||||
job = await cron_mgr.add_active_job(
|
||||
name=name,
|
||||
cron_expression=cron_expression,
|
||||
payload=job_payload,
|
||||
description=note,
|
||||
timezone=timezone,
|
||||
enabled=enabled,
|
||||
run_once=run_once,
|
||||
run_at=run_at_dt,
|
||||
)
|
||||
|
||||
return jsonify(Response().ok(data=self._serialize_job(job)).__dict__)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify(Response().error(f"Failed to create job: {e!s}").__dict__)
|
||||
return await self._run_json(self.service.create_job)
|
||||
|
||||
async def update_job(self, job_id: str):
|
||||
try:
|
||||
cron_mgr = self.core_lifecycle.cron_manager
|
||||
if cron_mgr is None:
|
||||
return jsonify(
|
||||
Response().error("Cron manager not initialized").__dict__
|
||||
)
|
||||
|
||||
payload = await request.json
|
||||
if not isinstance(payload, dict):
|
||||
return jsonify(Response().error("Invalid payload").__dict__)
|
||||
|
||||
job = await cron_mgr.db.get_cron_job(job_id)
|
||||
if not job:
|
||||
return jsonify(Response().error("Job not found").__dict__)
|
||||
|
||||
updates = {}
|
||||
if "name" in payload:
|
||||
name = str(payload.get("name") or "").strip()
|
||||
if not name:
|
||||
return jsonify(Response().error("name cannot be empty").__dict__)
|
||||
updates["name"] = name
|
||||
|
||||
if "enabled" in payload:
|
||||
updates["enabled"] = bool(payload.get("enabled"))
|
||||
|
||||
if "timezone" in payload:
|
||||
timezone = payload.get("timezone")
|
||||
updates["timezone"] = str(timezone).strip() or None
|
||||
|
||||
next_run_once = (
|
||||
bool(payload.get("run_once"))
|
||||
if "run_once" in payload
|
||||
else bool(job.run_once)
|
||||
)
|
||||
|
||||
if job.job_type == "active_agent":
|
||||
merged_payload = (
|
||||
dict(job.payload) if isinstance(job.payload, dict) else {}
|
||||
)
|
||||
if "payload" in payload and isinstance(payload.get("payload"), dict):
|
||||
merged_payload.update(payload["payload"])
|
||||
|
||||
if "session" in payload:
|
||||
session = str(payload.get("session") or "").strip()
|
||||
if session:
|
||||
merged_payload["session"] = session
|
||||
else:
|
||||
merged_payload.pop("session", None)
|
||||
|
||||
note_updated = False
|
||||
if "note" in payload:
|
||||
note = str(payload.get("note") or "").strip()
|
||||
if not note:
|
||||
return jsonify(
|
||||
Response().error("note cannot be empty").__dict__
|
||||
)
|
||||
merged_payload["note"] = note
|
||||
updates["description"] = note
|
||||
note_updated = True
|
||||
elif "description" in payload:
|
||||
description = str(payload.get("description") or "").strip()
|
||||
if not description:
|
||||
return jsonify(
|
||||
Response().error("description cannot be empty").__dict__
|
||||
)
|
||||
updates["description"] = description
|
||||
merged_payload["note"] = description
|
||||
note_updated = True
|
||||
|
||||
if not note_updated and updates.get("description") is None:
|
||||
existing_note = str(
|
||||
merged_payload.get("note") or job.description or ""
|
||||
).strip()
|
||||
if existing_note:
|
||||
merged_payload["note"] = existing_note
|
||||
|
||||
next_cron_expression = (
|
||||
payload.get("cron_expression")
|
||||
if "cron_expression" in payload
|
||||
else job.cron_expression
|
||||
)
|
||||
if next_cron_expression is not None:
|
||||
next_cron_expression = str(next_cron_expression).strip() or None
|
||||
|
||||
run_at_raw = (
|
||||
payload.get("run_at")
|
||||
if "run_at" in payload
|
||||
else merged_payload.get("run_at")
|
||||
)
|
||||
run_at_iso = None
|
||||
if run_at_raw:
|
||||
try:
|
||||
run_at_iso = datetime.fromisoformat(str(run_at_raw)).isoformat()
|
||||
except Exception:
|
||||
return jsonify(
|
||||
Response().error("run_at must be ISO datetime").__dict__
|
||||
)
|
||||
|
||||
if next_run_once:
|
||||
if not run_at_iso:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error("run_at is required when run_once=true")
|
||||
.__dict__
|
||||
)
|
||||
next_cron_expression = None
|
||||
merged_payload["run_at"] = run_at_iso
|
||||
else:
|
||||
if not next_cron_expression:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error("cron_expression is required when run_once=false")
|
||||
.__dict__
|
||||
)
|
||||
merged_payload.pop("run_at", None)
|
||||
|
||||
updates["run_once"] = next_run_once
|
||||
updates["cron_expression"] = next_cron_expression
|
||||
updates["payload"] = merged_payload
|
||||
else:
|
||||
if "cron_expression" in payload:
|
||||
cron_expression = str(payload.get("cron_expression") or "").strip()
|
||||
if not cron_expression:
|
||||
return jsonify(
|
||||
Response().error("cron_expression cannot be empty").__dict__
|
||||
)
|
||||
updates["cron_expression"] = cron_expression
|
||||
|
||||
if "description" in payload:
|
||||
description = str(payload.get("description") or "").strip()
|
||||
updates["description"] = description or None
|
||||
|
||||
job = await cron_mgr.update_job(job_id, **updates)
|
||||
if not job:
|
||||
return jsonify(Response().error("Job not found").__dict__)
|
||||
return jsonify(Response().ok(data=self._serialize_job(job)).__dict__)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify(Response().error(f"Failed to update job: {e!s}").__dict__)
|
||||
return await self._run_json(
|
||||
lambda payload: self.service.update_job(job_id, payload)
|
||||
)
|
||||
|
||||
async def delete_job(self, job_id: str):
|
||||
try:
|
||||
cron_mgr = self.core_lifecycle.cron_manager
|
||||
if cron_mgr is None:
|
||||
return jsonify(
|
||||
Response().error("Cron manager not initialized").__dict__
|
||||
)
|
||||
await cron_mgr.delete_job(job_id)
|
||||
return jsonify(Response().ok(message="deleted").__dict__)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify(Response().error(f"Failed to delete job: {e!s}").__dict__)
|
||||
return await self._run(self.service.delete_job(job_id), message="deleted")
|
||||
|
||||
async def run_job_now(self, job_id: str):
|
||||
try:
|
||||
cron_mgr = self.core_lifecycle.cron_manager
|
||||
if cron_mgr is None:
|
||||
return jsonify(
|
||||
Response().error("Cron manager not initialized").__dict__
|
||||
)
|
||||
job = await cron_mgr.db.get_cron_job(job_id)
|
||||
if not job:
|
||||
return jsonify(Response().error("Job not found").__dict__)
|
||||
task = asyncio.create_task(cron_mgr.run_job_now(job_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return jsonify(Response().ok(message="started").__dict__)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify(Response().error(f"Failed to run job: {e!s}").__dict__)
|
||||
return await self._run(self.service.run_job_now(job_id), message="started")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from quart import abort, send_file
|
||||
|
||||
from astrbot.core import file_token_service
|
||||
from astrbot.dashboard.fastapi_compat import abort, send_file
|
||||
from astrbot.dashboard.services.file_service import FileService, FileServiceError
|
||||
|
||||
from .route import Route, RouteContext
|
||||
|
||||
@@ -11,6 +10,7 @@ class FileRoute(Route):
|
||||
context: RouteContext,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.service = FileService()
|
||||
self.routes = {
|
||||
"/file/<file_token>": ("GET", self.serve_file),
|
||||
}
|
||||
@@ -18,7 +18,7 @@ class FileRoute(Route):
|
||||
|
||||
async def serve_file(self, file_token: str):
|
||||
try:
|
||||
file_path = await file_token_service.handle_file(file_token)
|
||||
file_path = await self.service.resolve_token_file(file_token)
|
||||
return await send_file(file_path)
|
||||
except (FileNotFoundError, KeyError):
|
||||
except FileServiceError:
|
||||
return abort(404)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,115 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import wave
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from quart import websocket
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
||||
build_webchat_message_parts,
|
||||
create_attachment_part_from_existing_file,
|
||||
strip_message_parts_path_fields,
|
||||
webchat_message_parts_have_content,
|
||||
)
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
|
||||
from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
from astrbot.dashboard.fastapi_compat import websocket
|
||||
from astrbot.dashboard.services.live_chat_service import LiveChatService
|
||||
|
||||
from .chat import (
|
||||
BotMessageAccumulator,
|
||||
build_bot_history_content,
|
||||
collect_plain_text_from_message_parts,
|
||||
)
|
||||
from .route import Route, RouteContext
|
||||
|
||||
|
||||
class LiveChatSession:
|
||||
"""Live Chat 会话管理器"""
|
||||
|
||||
def __init__(self, session_id: str, username: str) -> None:
|
||||
self.session_id = session_id
|
||||
self.username = username
|
||||
self.conversation_id = str(uuid.uuid4())
|
||||
self.is_speaking = False
|
||||
self.is_processing = False
|
||||
self.should_interrupt = False
|
||||
self.audio_frames: list[bytes] = []
|
||||
self.current_stamp: str | None = None
|
||||
self.temp_audio_path: str | None = None
|
||||
self.chat_subscriptions: dict[str, str] = {}
|
||||
self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
|
||||
self.ws_send_lock = asyncio.Lock()
|
||||
|
||||
def start_speaking(self, stamp: str) -> None:
|
||||
"""开始说话"""
|
||||
self.is_speaking = True
|
||||
self.current_stamp = stamp
|
||||
self.audio_frames = []
|
||||
logger.debug(f"[Live Chat] {self.username} 开始说话 stamp={stamp}")
|
||||
|
||||
def add_audio_frame(self, data: bytes) -> None:
|
||||
"""添加音频帧"""
|
||||
if self.is_speaking:
|
||||
self.audio_frames.append(data)
|
||||
|
||||
async def end_speaking(self, stamp: str) -> tuple[str | None, float]:
|
||||
"""结束说话,返回组装的 WAV 文件路径和耗时"""
|
||||
start_time = time.time()
|
||||
if not self.is_speaking or stamp != self.current_stamp:
|
||||
logger.warning(
|
||||
f"[Live Chat] stamp 不匹配或未在说话状态: {stamp} vs {self.current_stamp}"
|
||||
)
|
||||
return None, 0.0
|
||||
|
||||
self.is_speaking = False
|
||||
|
||||
if not self.audio_frames:
|
||||
logger.warning("[Live Chat] 没有音频帧数据")
|
||||
return None, 0.0
|
||||
|
||||
# 组装 WAV 文件
|
||||
try:
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav")
|
||||
|
||||
# 假设前端发送的是 PCM 数据,采样率 16000Hz,单声道,16位
|
||||
with wave.open(audio_path, "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # 单声道
|
||||
wav_file.setsampwidth(2) # 16位 = 2字节
|
||||
wav_file.setframerate(16000) # 采样率 16000Hz
|
||||
for frame in self.audio_frames:
|
||||
wav_file.writeframes(frame)
|
||||
|
||||
self.temp_audio_path = audio_path
|
||||
logger.info(
|
||||
f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes"
|
||||
)
|
||||
return audio_path, time.time() - start_time
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 组装 WAV 文件失败: {e}", exc_info=True)
|
||||
return None, 0.0
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""清理临时文件"""
|
||||
if self.temp_audio_path and os.path.exists(self.temp_audio_path):
|
||||
try:
|
||||
os.remove(self.temp_audio_path)
|
||||
logger.debug(f"[Live Chat] 已删除临时文件: {self.temp_audio_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Live Chat] 删除临时文件失败: {e}")
|
||||
self.temp_audio_path = None
|
||||
|
||||
|
||||
class LiveChatRoute(Route):
|
||||
"""Live Chat WebSocket 路由"""
|
||||
|
||||
@@ -120,16 +17,9 @@ class LiveChatRoute(Route):
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.db = db
|
||||
self.plugin_manager = core_lifecycle.plugin_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
self.sessions: dict[str, LiveChatSession] = {}
|
||||
self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
|
||||
self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.attachments_dir, exist_ok=True)
|
||||
self.service = LiveChatService(db, core_lifecycle)
|
||||
self.sessions = self.service.sessions
|
||||
|
||||
# 注册 WebSocket 路由
|
||||
self.app.websocket("/api/live_chat/ws")(self.live_chat_ws)
|
||||
self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws)
|
||||
|
||||
@@ -142,819 +32,13 @@ class LiveChatRoute(Route):
|
||||
await self._unified_ws_loop(force_ct=None)
|
||||
|
||||
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
|
||||
"""统一 WebSocket 循环"""
|
||||
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
|
||||
# 注意:WebSocket 上下文使用 websocket.args 而不是 request.args
|
||||
token = websocket.args.get("token")
|
||||
if not token:
|
||||
await websocket.close(1008, "Missing authentication token")
|
||||
return
|
||||
|
||||
try:
|
||||
jwt_secret = self.config["dashboard"].get("jwt_secret")
|
||||
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
||||
username = payload["username"]
|
||||
except jwt.ExpiredSignatureError:
|
||||
await websocket.close(1008, "Token expired")
|
||||
return
|
||||
except jwt.InvalidTokenError:
|
||||
await websocket.close(1008, "Invalid token")
|
||||
return
|
||||
|
||||
session_id = f"webchat_live!{username}!{uuid.uuid4()}"
|
||||
live_session = LiveChatSession(session_id, username)
|
||||
self.sessions[session_id] = live_session
|
||||
|
||||
logger.info(f"[Live Chat] WebSocket 连接建立: {username}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
ct = force_ct or message.get("ct", "live")
|
||||
if ct == "chat":
|
||||
await self._handle_chat_message(live_session, message)
|
||||
else:
|
||||
await self._handle_message(live_session, message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# 清理会话
|
||||
if session_id in self.sessions:
|
||||
await self._cleanup_chat_subscriptions(live_session)
|
||||
live_session.cleanup()
|
||||
del self.sessions[session_id]
|
||||
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
|
||||
|
||||
async def _create_attachment_from_file(
|
||||
self, filename: str, attach_type: str
|
||||
) -> dict | None:
|
||||
"""从本地文件创建 attachment 并返回消息部分。"""
|
||||
return await create_attachment_part_from_existing_file(
|
||||
filename,
|
||||
attach_type=attach_type,
|
||||
insert_attachment=self.db.insert_attachment,
|
||||
attachments_dir=self.attachments_dir,
|
||||
fallback_dirs=[self.legacy_img_dir],
|
||||
await self.service.run_websocket_session(
|
||||
token=websocket.args.get("token"),
|
||||
force_ct=force_ct,
|
||||
receive_json=websocket.receive_json,
|
||||
send_json=websocket.send_json,
|
||||
close=websocket.close,
|
||||
)
|
||||
|
||||
def _extract_web_search_refs(
|
||||
self, accumulated_text: str, accumulated_parts: list
|
||||
) -> dict:
|
||||
"""从消息中提取 web_search 引用。"""
|
||||
supported = [
|
||||
"web_search_baidu",
|
||||
"web_search_tavily",
|
||||
"web_search_bocha",
|
||||
"web_search_brave",
|
||||
]
|
||||
web_search_results = {}
|
||||
tool_call_parts = [
|
||||
p
|
||||
for p in accumulated_parts
|
||||
if p.get("type") == "tool_call" and p.get("tool_calls")
|
||||
]
|
||||
|
||||
for part in tool_call_parts:
|
||||
for tool_call in part["tool_calls"]:
|
||||
if tool_call.get("name") not in supported or not tool_call.get(
|
||||
"result"
|
||||
):
|
||||
continue
|
||||
try:
|
||||
result_data = json.loads(tool_call["result"])
|
||||
for item in result_data.get("results", []):
|
||||
if idx := item.get("index"):
|
||||
web_search_results[idx] = {
|
||||
"url": item.get("url"),
|
||||
"title": item.get("title"),
|
||||
"snippet": item.get("snippet"),
|
||||
}
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
if not web_search_results:
|
||||
return {}
|
||||
|
||||
ref_indices = {
|
||||
m.strip() for m in re.findall(r"<ref>(.*?)</ref>", accumulated_text)
|
||||
}
|
||||
|
||||
used_refs = []
|
||||
for ref_index in ref_indices:
|
||||
if ref_index not in web_search_results:
|
||||
continue
|
||||
payload = {"index": ref_index, **web_search_results[ref_index]}
|
||||
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
|
||||
payload["favicon"] = favicon
|
||||
used_refs.append(payload)
|
||||
|
||||
return {"used": used_refs} if used_refs else {}
|
||||
|
||||
async def _save_bot_message(
|
||||
self,
|
||||
webchat_conv_id: str,
|
||||
message_parts: list[dict],
|
||||
agent_stats: dict,
|
||||
refs: dict,
|
||||
llm_checkpoint_id: str | None = None,
|
||||
):
|
||||
"""保存 bot 消息到历史记录。"""
|
||||
new_his = build_bot_history_content(
|
||||
message_parts,
|
||||
agent_stats=agent_stats,
|
||||
refs=refs,
|
||||
)
|
||||
|
||||
return await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
llm_checkpoint_id=llm_checkpoint_id,
|
||||
)
|
||||
|
||||
async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None:
|
||||
async with session.ws_send_lock:
|
||||
await websocket.send_json(payload)
|
||||
|
||||
async def _forward_chat_subscription(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
chat_session_id: str,
|
||||
request_id: str,
|
||||
) -> None:
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id, chat_session_id
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
result = await back_queue.get()
|
||||
if not result:
|
||||
continue
|
||||
await self._send_chat_payload(session, {"ct": "chat", **result})
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(request_id)
|
||||
if session.chat_subscriptions.get(chat_session_id) == request_id:
|
||||
session.chat_subscriptions.pop(chat_session_id, None)
|
||||
session.chat_subscription_tasks.pop(chat_session_id, None)
|
||||
|
||||
async def _ensure_chat_subscription(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
chat_session_id: str,
|
||||
) -> str:
|
||||
existing_request_id = session.chat_subscriptions.get(chat_session_id)
|
||||
existing_task = session.chat_subscription_tasks.get(chat_session_id)
|
||||
if existing_request_id and existing_task and not existing_task.done():
|
||||
return existing_request_id
|
||||
|
||||
request_id = f"ws_sub_{uuid.uuid4().hex}"
|
||||
session.chat_subscriptions[chat_session_id] = request_id
|
||||
task = asyncio.create_task(
|
||||
self._forward_chat_subscription(session, chat_session_id, request_id),
|
||||
name=f"chat_ws_sub_{chat_session_id}",
|
||||
)
|
||||
session.chat_subscription_tasks[chat_session_id] = task
|
||||
return request_id
|
||||
|
||||
async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
|
||||
tasks = list(session.chat_subscription_tasks.values())
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for request_id in list(session.chat_subscriptions.values()):
|
||||
webchat_queue_mgr.remove_back_queue(request_id)
|
||||
session.chat_subscriptions.clear()
|
||||
session.chat_subscription_tasks.clear()
|
||||
|
||||
async def _handle_chat_message(
|
||||
self, session: LiveChatSession, message: dict
|
||||
) -> None:
|
||||
"""处理 Chat Mode 消息(ct=chat)"""
|
||||
msg_type = message.get("t")
|
||||
|
||||
if msg_type == "bind":
|
||||
chat_session_id = message.get("session_id")
|
||||
if not isinstance(chat_session_id, str) or not chat_session_id:
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "session_id is required",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
request_id = await self._ensure_chat_subscription(session, chat_session_id)
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "session_bound",
|
||||
"session_id": chat_session_id,
|
||||
"message_id": request_id,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if msg_type == "interrupt":
|
||||
session.should_interrupt = True
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "INTERRUPTED",
|
||||
"code": "INTERRUPTED",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if msg_type != "send":
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": f"Unsupported message type: {msg_type}",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if session.is_processing:
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "Session is busy",
|
||||
"code": "PROCESSING_ERROR",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
payload = message.get("message")
|
||||
session_id = message.get("session_id") or session.session_id
|
||||
message_id = message.get("message_id") or str(uuid.uuid4())
|
||||
selected_provider = message.get("selected_provider")
|
||||
selected_model = message.get("selected_model")
|
||||
selected_stt_provider = message.get("selected_stt_provider")
|
||||
selected_tts_provider = message.get("selected_tts_provider")
|
||||
persona_prompt = message.get("persona_prompt")
|
||||
show_reasoning = message.get("show_reasoning")
|
||||
enable_streaming = message.get("enable_streaming", True)
|
||||
|
||||
if not isinstance(payload, list):
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "message must be list",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
message_parts = await self._build_chat_message_parts(payload)
|
||||
has_content = webchat_message_parts_have_content(message_parts)
|
||||
if not has_content:
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "Message content is empty",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
await self._ensure_chat_subscription(session, session_id)
|
||||
|
||||
session.is_processing = True
|
||||
session.should_interrupt = False
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
|
||||
llm_checkpoint_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
pending_bot_message_flusher = None
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
|
||||
await chat_queue.put(
|
||||
(
|
||||
session.username,
|
||||
session_id,
|
||||
{
|
||||
"message": message_parts,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"selected_stt_provider": selected_stt_provider,
|
||||
"selected_tts_provider": selected_tts_provider,
|
||||
"persona_prompt": persona_prompt,
|
||||
"show_reasoning": show_reasoning,
|
||||
"enable_streaming": enable_streaming,
|
||||
"message_id": message_id,
|
||||
"llm_checkpoint_id": llm_checkpoint_id,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
|
||||
saved_user_record = await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=session_id,
|
||||
content={"type": "user", "message": message_parts_for_storage},
|
||||
sender_id=session.username,
|
||||
sender_name=session.username,
|
||||
llm_checkpoint_id=llm_checkpoint_id,
|
||||
)
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "user_message_saved",
|
||||
"data": {
|
||||
"id": saved_user_record.id,
|
||||
"created_at": to_utc_isoformat(saved_user_record.created_at),
|
||||
"llm_checkpoint_id": llm_checkpoint_id,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
message_accumulator = BotMessageAccumulator()
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
|
||||
async def flush_pending_bot_message():
|
||||
nonlocal message_accumulator, agent_stats, refs
|
||||
if not (message_accumulator.has_content() or refs or agent_stats):
|
||||
return None
|
||||
|
||||
message_parts_to_save = message_accumulator.build_message_parts(
|
||||
include_pending_tool_calls=True
|
||||
)
|
||||
plain_text = collect_plain_text_from_message_parts(
|
||||
message_parts_to_save
|
||||
)
|
||||
try:
|
||||
extracted_refs = self._extract_web_search_refs(
|
||||
plain_text,
|
||||
message_parts_to_save,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"[Live Chat] Failed to extract web search refs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
extracted_refs = refs
|
||||
|
||||
saved_record = await self._save_bot_message(
|
||||
session_id,
|
||||
message_parts_to_save,
|
||||
agent_stats,
|
||||
extracted_refs,
|
||||
llm_checkpoint_id,
|
||||
)
|
||||
message_accumulator = BotMessageAccumulator()
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
return saved_record
|
||||
|
||||
pending_bot_message_flusher = flush_pending_bot_message
|
||||
|
||||
async def send_attachment_saved_event(part: dict | None) -> None:
|
||||
if not part or not part.get("attachment_id") or not part.get("type"):
|
||||
return
|
||||
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "attachment_saved",
|
||||
"data": {
|
||||
"id": part["attachment_id"],
|
||||
"type": part["type"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
session.should_interrupt = False
|
||||
await flush_pending_bot_message()
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
if result.get("message_id") and result.get("message_id") != message_id:
|
||||
continue
|
||||
|
||||
result_text = result.get("data", "")
|
||||
msg_type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
chain_type = result.get("chain_type")
|
||||
if chain_type == "agent_stats":
|
||||
try:
|
||||
parsed_agent_stats = json.loads(result_text)
|
||||
agent_stats = parsed_agent_stats
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "agent_stats",
|
||||
"data": parsed_agent_stats,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
outgoing = {"ct": "chat", **result}
|
||||
await self._send_chat_payload(session, outgoing)
|
||||
|
||||
if msg_type == "plain":
|
||||
message_accumulator.add_plain(
|
||||
result_text,
|
||||
chain_type=chain_type,
|
||||
streaming=streaming,
|
||||
)
|
||||
elif msg_type == "image":
|
||||
filename = str(result_text).replace("[IMAGE]", "")
|
||||
part = await self._create_attachment_from_file(filename, "image")
|
||||
message_accumulator.add_attachment(part)
|
||||
await send_attachment_saved_event(part)
|
||||
elif msg_type == "record":
|
||||
filename = str(result_text).replace("[RECORD]", "")
|
||||
part = await self._create_attachment_from_file(filename, "record")
|
||||
message_accumulator.add_attachment(part)
|
||||
await send_attachment_saved_event(part)
|
||||
elif msg_type == "file":
|
||||
filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
|
||||
part = await self._create_attachment_from_file(filename, "file")
|
||||
message_accumulator.add_attachment(part)
|
||||
await send_attachment_saved_event(part)
|
||||
elif msg_type == "video":
|
||||
filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
|
||||
part = await self._create_attachment_from_file(filename, "video")
|
||||
message_accumulator.add_attachment(part)
|
||||
await send_attachment_saved_event(part)
|
||||
|
||||
should_save = False
|
||||
if msg_type == "end":
|
||||
should_save = bool(
|
||||
message_accumulator.has_content() or refs or agent_stats
|
||||
)
|
||||
elif (streaming and msg_type == "complete") or not streaming:
|
||||
if chain_type not in (
|
||||
"tool_call",
|
||||
"tool_call_result",
|
||||
"agent_stats",
|
||||
):
|
||||
should_save = True
|
||||
|
||||
if should_save:
|
||||
saved_record = await flush_pending_bot_message()
|
||||
if saved_record:
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "message_saved",
|
||||
"data": {
|
||||
"id": saved_record.id,
|
||||
"created_at": to_utc_isoformat(
|
||||
saved_record.created_at
|
||||
),
|
||||
"llm_checkpoint_id": llm_checkpoint_id,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if msg_type == "end":
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True)
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": f"处理失败: {str(e)}",
|
||||
"code": "PROCESSING_ERROR",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
if pending_bot_message_flusher is not None:
|
||||
await pending_bot_message_flusher()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"[Live Chat] Failed to persist pending chat message: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
session.is_processing = False
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]:
|
||||
"""构建 chat websocket 用户消息段(复用 webchat 逻辑)"""
|
||||
return await build_webchat_message_parts(
|
||||
message,
|
||||
get_attachment_by_id=self.db.get_attachment_by_id,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
async def _handle_message(self, session: LiveChatSession, message: dict) -> None:
|
||||
"""处理 WebSocket 消息"""
|
||||
msg_type = message.get("t") # 使用 t 代替 type
|
||||
|
||||
if msg_type == "start_speaking":
|
||||
# 开始说话
|
||||
stamp = message.get("stamp")
|
||||
if not stamp:
|
||||
logger.warning("[Live Chat] start_speaking 缺少 stamp")
|
||||
return
|
||||
session.start_speaking(stamp)
|
||||
|
||||
elif msg_type == "speaking_part":
|
||||
# 音频片段
|
||||
audio_data_b64 = message.get("data")
|
||||
if not audio_data_b64:
|
||||
return
|
||||
|
||||
# 解码 base64
|
||||
import base64
|
||||
|
||||
try:
|
||||
audio_data = base64.b64decode(audio_data_b64)
|
||||
session.add_audio_frame(audio_data)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解码音频数据失败: {e}")
|
||||
|
||||
elif msg_type == "end_speaking":
|
||||
# 结束说话
|
||||
stamp = message.get("stamp")
|
||||
if not stamp:
|
||||
logger.warning("[Live Chat] end_speaking 缺少 stamp")
|
||||
return
|
||||
|
||||
audio_path, assemble_duration = await session.end_speaking(stamp)
|
||||
if not audio_path:
|
||||
await websocket.send_json({"t": "error", "data": "音频组装失败"})
|
||||
return
|
||||
|
||||
# 处理音频:STT -> LLM -> TTS
|
||||
await self._process_audio(session, audio_path, assemble_duration)
|
||||
|
||||
elif msg_type == "interrupt":
|
||||
# 用户打断
|
||||
session.should_interrupt = True
|
||||
logger.info(f"[Live Chat] 用户打断: {session.username}")
|
||||
|
||||
async def _process_audio(
|
||||
self, session: LiveChatSession, audio_path: str, assemble_duration: float
|
||||
) -> None:
|
||||
"""处理音频:STT -> LLM -> 流式 TTS"""
|
||||
try:
|
||||
# 发送 WAV 组装耗时
|
||||
await websocket.send_json(
|
||||
{"t": "metrics", "data": {"wav_assemble_time": assemble_duration}}
|
||||
)
|
||||
wav_assembly_finish_time = time.time()
|
||||
|
||||
session.is_processing = True
|
||||
session.should_interrupt = False
|
||||
|
||||
# 1. STT - 语音转文字
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.provider_manager.stt_provider_insts[0]
|
||||
|
||||
if not stt_provider:
|
||||
logger.error("[Live Chat] STT Provider 未配置")
|
||||
await websocket.send_json({"t": "error", "data": "语音识别服务未配置"})
|
||||
return
|
||||
|
||||
await websocket.send_json(
|
||||
{"t": "metrics", "data": {"stt": stt_provider.meta().type}}
|
||||
)
|
||||
|
||||
user_text = await stt_provider.get_text(audio_path)
|
||||
if not user_text:
|
||||
logger.warning("[Live Chat] STT 识别结果为空")
|
||||
return
|
||||
|
||||
logger.info(f"[Live Chat] STT 结果: {user_text}")
|
||||
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "user_msg",
|
||||
"data": {"text": user_text, "ts": int(time.time() * 1000)},
|
||||
}
|
||||
)
|
||||
|
||||
# 2. 构造消息事件并发送到 pipeline
|
||||
# 使用 webchat queue 机制
|
||||
cid = session.conversation_id
|
||||
queue = webchat_queue_mgr.get_or_create_queue(cid)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
payload = {
|
||||
"message_id": message_id,
|
||||
"message": [{"type": "plain", "text": user_text}], # 直接发送文本
|
||||
"action_type": "live", # 标记为 live mode
|
||||
}
|
||||
|
||||
# 将消息放入队列
|
||||
await queue.put((session.username, cid, payload))
|
||||
|
||||
# 3. 等待响应并流式发送 TTS 音频
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, cid)
|
||||
|
||||
bot_text = ""
|
||||
audio_playing = False
|
||||
|
||||
try:
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
# 用户打断,停止处理
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
await websocket.send_json({"t": "stop_play"})
|
||||
# 保存消息并标记为被打断
|
||||
await self._save_interrupted_message(
|
||||
session, user_text, bot_text
|
||||
)
|
||||
# 清空队列中未处理的消息
|
||||
while not back_queue.empty():
|
||||
try:
|
||||
back_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
result_message_id = result.get("message_id")
|
||||
if result_message_id != message_id:
|
||||
logger.warning(
|
||||
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
result_type = result.get("type")
|
||||
result_chain_type = result.get("chain_type")
|
||||
data = result.get("data", "")
|
||||
|
||||
if result_chain_type == "agent_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"llm_ttft": stats.get("time_to_first_token", 0),
|
||||
"llm_total_time": stats.get("end_time", 0)
|
||||
- stats.get("start_time", 0),
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if result_chain_type == "tts_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": stats,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if result_type == "plain":
|
||||
# 普通文本消息
|
||||
bot_text += data
|
||||
|
||||
elif result_type == "audio_chunk":
|
||||
# 流式音频数据
|
||||
if not audio_playing:
|
||||
audio_playing = True
|
||||
logger.debug("[Live Chat] 开始播放音频流")
|
||||
|
||||
# Calculate latency from wav assembly finish to first audio chunk
|
||||
speak_to_first_frame_latency = (
|
||||
time.time() - wav_assembly_finish_time
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"speak_to_first_frame": speak_to_first_frame_latency
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
text = result.get("text")
|
||||
if text:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_text_chunk",
|
||||
"data": {"text": text},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送音频数据给前端
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "response",
|
||||
"data": data, # base64 编码的音频数据
|
||||
}
|
||||
)
|
||||
|
||||
elif result_type in ["complete", "end"]:
|
||||
# 处理完成
|
||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||
|
||||
# 如果没有音频流,发送 bot 消息文本
|
||||
if not audio_playing:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_msg",
|
||||
"data": {
|
||||
"text": bot_text,
|
||||
"ts": int(time.time() * 1000),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束标记
|
||||
await websocket.send_json({"t": "end"})
|
||||
|
||||
# 发送总耗时
|
||||
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
|
||||
}
|
||||
)
|
||||
break
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
|
||||
await websocket.send_json({"t": "error", "data": f"处理失败: {str(e)}"})
|
||||
|
||||
finally:
|
||||
session.is_processing = False
|
||||
session.should_interrupt = False
|
||||
|
||||
async def _save_interrupted_message(
|
||||
self, session: LiveChatSession, user_text: str, bot_text: str
|
||||
) -> None:
|
||||
"""保存被打断的消息"""
|
||||
interrupted_text = bot_text + " [用户打断]"
|
||||
logger.info(f"[Live Chat] 保存打断消息: {interrupted_text}")
|
||||
|
||||
# 简单记录到日志,实际保存逻辑可以后续完善
|
||||
try:
|
||||
timestamp = int(time.time() * 1000)
|
||||
logger.info(
|
||||
f"[Live Chat] 用户消息: {user_text} (session: {session.session_id}, ts: {timestamp})"
|
||||
)
|
||||
if bot_text:
|
||||
logger.info(
|
||||
f"[Live Chat] Bot 消息(打断): {interrupted_text} (session: {session.session_id}, ts: {timestamp})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 记录消息失败: {e}", exc_info=True)
|
||||
__all__ = ["LiveChatRoute"]
|
||||
|
||||
@@ -1,30 +1,18 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import cast
|
||||
|
||||
from quart import Response as QuartResponse
|
||||
from quart import make_response, request
|
||||
|
||||
from astrbot.core import LogBroker, logger
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.dashboard.fastapi_compat import Response as CompatResponse
|
||||
from astrbot.dashboard.fastapi_compat import make_response, request
|
||||
from astrbot.dashboard.services.log_service import LogService, LogServiceError
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def _format_log_sse(log: dict, ts: float) -> str:
|
||||
"""辅助函数:格式化 SSE 消息"""
|
||||
payload = {
|
||||
"type": "log",
|
||||
**log,
|
||||
}
|
||||
return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
class LogRoute(Route):
|
||||
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
|
||||
super().__init__(context)
|
||||
self.log_broker = log_broker
|
||||
self.service = LogService(log_broker, self.config)
|
||||
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
|
||||
self.app.add_url_rule(
|
||||
"/api/log-history",
|
||||
@@ -42,51 +30,46 @@ class LogRoute(Route):
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
async def _replay_cached_logs(
|
||||
self, last_event_id: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助生成器:重放缓存的日志"""
|
||||
@staticmethod
|
||||
def _ok(data=None, message: str | None = None):
|
||||
return Response().ok(data=data, message=message).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str):
|
||||
return Response().error(message).__dict__
|
||||
|
||||
@staticmethod
|
||||
async def _json_body() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(self, operation, *, result_as_message: bool = False):
|
||||
try:
|
||||
last_ts = float(last_event_id)
|
||||
cached_logs = list(self.log_broker.log_cache)
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
if result_as_message:
|
||||
return self._ok(message=str(result))
|
||||
return self._ok(result)
|
||||
except LogServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
for log_item in cached_logs:
|
||||
log_ts = float(log_item.get("time", 0))
|
||||
async def _run_json(self, operation, *, result_as_message: bool = False):
|
||||
async def invoke():
|
||||
data = await self._json_body()
|
||||
return operation(data)
|
||||
|
||||
if log_ts > last_ts:
|
||||
yield _format_log_sse(log_item, log_ts)
|
||||
return await self._run(invoke, result_as_message=result_as_message)
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Log SSE 补发历史错误: {e}")
|
||||
|
||||
async def log(self) -> QuartResponse:
|
||||
async def log(self) -> CompatResponse:
|
||||
last_event_id = request.headers.get("Last-Event-ID")
|
||||
|
||||
async def stream():
|
||||
queue = None
|
||||
try:
|
||||
if last_event_id:
|
||||
async for event in self._replay_cached_logs(last_event_id):
|
||||
yield event
|
||||
|
||||
queue = self.log_broker.register()
|
||||
while True:
|
||||
message = await queue.get()
|
||||
current_ts = message.get("time", time.time())
|
||||
yield _format_log_sse(message, current_ts)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Log SSE 连接错误: {e}")
|
||||
finally:
|
||||
if queue:
|
||||
self.log_broker.unregister(queue)
|
||||
async for event in self.service.stream_log_events(last_event_id):
|
||||
yield event
|
||||
|
||||
response = cast(
|
||||
QuartResponse,
|
||||
CompatResponse,
|
||||
await make_response(
|
||||
stream(),
|
||||
{
|
||||
@@ -102,43 +85,15 @@ class LogRoute(Route):
|
||||
|
||||
async def log_history(self):
|
||||
"""获取日志历史"""
|
||||
try:
|
||||
logs = list(self.log_broker.log_cache)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"logs": logs,
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
return await self._run(self.service.get_log_history)
|
||||
|
||||
async def get_trace_settings(self):
|
||||
"""获取 Trace 设置"""
|
||||
try:
|
||||
trace_enable = self.config.get("trace_enable", True)
|
||||
return Response().ok(data={"trace_enable": trace_enable}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Trace 设置失败: {e}")
|
||||
return Response().error(f"获取 Trace 设置失败: {e}").__dict__
|
||||
return await self._run(self.service.get_trace_settings)
|
||||
|
||||
async def update_trace_settings(self):
|
||||
"""更新 Trace 设置"""
|
||||
try:
|
||||
data = await request.json
|
||||
if data is None:
|
||||
return Response().error("请求数据为空").__dict__
|
||||
|
||||
trace_enable = data.get("trace_enable")
|
||||
if trace_enable is not None:
|
||||
self.config["trace_enable"] = bool(trace_enable)
|
||||
self.config.save_config()
|
||||
|
||||
return Response().ok(message="Trace 设置已更新").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Trace 设置失败: {e}")
|
||||
return Response().error(f"更新 Trace 设置失败: {e}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.update_trace_settings_from_legacy_payload,
|
||||
result_as_message=True,
|
||||
)
|
||||
|
||||
@@ -1,28 +1,27 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from uuid import uuid4
|
||||
from typing import cast
|
||||
|
||||
from quart import g, request, websocket
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.platform.message_session import MessageSesion
|
||||
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
||||
build_message_chain_from_payload,
|
||||
strip_message_parts_path_fields,
|
||||
webchat_message_parts_have_content,
|
||||
from astrbot.dashboard.fastapi_compat import (
|
||||
Response as CompatResponse,
|
||||
)
|
||||
from astrbot.dashboard.fastapi_compat import (
|
||||
make_response,
|
||||
request,
|
||||
send_file,
|
||||
websocket,
|
||||
)
|
||||
from astrbot.dashboard.services.chat_service import (
|
||||
ChatService,
|
||||
ChatServiceError,
|
||||
extract_web_search_refs,
|
||||
)
|
||||
from astrbot.dashboard.services.open_api_service import (
|
||||
OpenApiService,
|
||||
OpenApiServiceError,
|
||||
OpenApiWebSocketChatBridge,
|
||||
)
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
|
||||
from .api_key import ALL_OPEN_API_SCOPES
|
||||
from .chat import (
|
||||
BotMessageAccumulator,
|
||||
ChatRoute,
|
||||
collect_plain_text_from_message_parts,
|
||||
)
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
@@ -32,13 +31,14 @@ class OpenApiRoute(Route):
|
||||
context: RouteContext,
|
||||
db: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
chat_route: ChatRoute,
|
||||
chat_service: ChatService,
|
||||
*,
|
||||
register_routes: bool = True,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.platform_manager = core_lifecycle.platform_manager
|
||||
self.chat_route = chat_route
|
||||
self.chat_service = chat_service
|
||||
self.service = OpenApiService(db, core_lifecycle)
|
||||
|
||||
self.routes = {
|
||||
"/v1/chat": ("POST", self.chat_send),
|
||||
@@ -51,151 +51,85 @@ class OpenApiRoute(Route):
|
||||
"/v1/im/message": ("POST", self.send_message),
|
||||
"/v1/im/bots": ("GET", self.get_bots),
|
||||
}
|
||||
self.register_routes()
|
||||
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
|
||||
if register_routes:
|
||||
self.register_routes()
|
||||
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_open_username(
|
||||
raw_username: str | None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
if raw_username is None:
|
||||
return None, "Missing key: username"
|
||||
username = str(raw_username).strip()
|
||||
if not username:
|
||||
return None, "username is empty"
|
||||
return username, None
|
||||
def _ok(data=None):
|
||||
return Response().ok(data=data).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str):
|
||||
return Response().error(message).__dict__
|
||||
|
||||
@staticmethod
|
||||
async def _json_body() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(self, operation):
|
||||
try:
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
return self._ok(result)
|
||||
except (OpenApiServiceError, ChatServiceError) as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
async def _run_json(self, operation):
|
||||
async def invoke():
|
||||
data = await self._json_body()
|
||||
return operation(data)
|
||||
|
||||
return await self._run(invoke)
|
||||
|
||||
def _get_chat_config_list(self) -> list[dict]:
|
||||
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
|
||||
|
||||
result = []
|
||||
for conf_info in conf_list:
|
||||
conf_id = str(conf_info.get("id", "")).strip()
|
||||
result.append(
|
||||
{
|
||||
"id": conf_id,
|
||||
"name": str(conf_info.get("name", "")).strip(),
|
||||
"path": str(conf_info.get("path", "")).strip(),
|
||||
"is_default": conf_id == "default",
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def _resolve_chat_config_id(self, post_data: dict) -> tuple[str | None, str | None]:
|
||||
raw_config_id = post_data.get("config_id")
|
||||
raw_config_name = post_data.get("config_name")
|
||||
config_id = str(raw_config_id).strip() if raw_config_id is not None else ""
|
||||
config_name = (
|
||||
str(raw_config_name).strip() if raw_config_name is not None else ""
|
||||
)
|
||||
|
||||
if not config_id and not config_name:
|
||||
return None, None
|
||||
|
||||
conf_list = self._get_chat_config_list()
|
||||
conf_map = {item["id"]: item for item in conf_list}
|
||||
|
||||
if config_id:
|
||||
if config_id not in conf_map:
|
||||
return None, f"config_id not found: {config_id}"
|
||||
return config_id, None
|
||||
|
||||
if not config_name:
|
||||
return None, "config_name is empty"
|
||||
|
||||
matched = [item for item in conf_list if item["name"] == config_name]
|
||||
if not matched:
|
||||
return None, f"config_name not found: {config_name}"
|
||||
if len(matched) > 1:
|
||||
return (
|
||||
None,
|
||||
f"config_name is ambiguous, please use config_id: {config_name}",
|
||||
)
|
||||
|
||||
return matched[0]["id"], None
|
||||
|
||||
async def _ensure_chat_session(
|
||||
self,
|
||||
username: str,
|
||||
session_id: str,
|
||||
) -> str | None:
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
if session:
|
||||
if session.creator != username:
|
||||
return "session_id belongs to another username"
|
||||
return None
|
||||
|
||||
try:
|
||||
await self.db.create_platform_session(
|
||||
creator=username,
|
||||
platform_id="webchat",
|
||||
session_id=session_id,
|
||||
is_group=0,
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle rare race when same session_id is created concurrently.
|
||||
existing = await self.db.get_platform_session_by_id(session_id)
|
||||
if existing and existing.creator == username:
|
||||
return None
|
||||
logger.error("Failed to create chat session %s: %s", session_id, e)
|
||||
return f"Failed to create session: {e}"
|
||||
|
||||
return None
|
||||
return self.service.get_chat_config_list()
|
||||
|
||||
async def chat_send(self):
|
||||
post_data = await request.get_json(silent=True) or {}
|
||||
effective_username, username_err = self._resolve_open_username(
|
||||
post_data.get("username")
|
||||
)
|
||||
if username_err:
|
||||
return Response().error(username_err).__dict__
|
||||
if not effective_username:
|
||||
return Response().error("Invalid username").__dict__
|
||||
|
||||
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
|
||||
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
|
||||
if not session_id:
|
||||
session_id = str(uuid4())
|
||||
post_data["session_id"] = session_id
|
||||
ensure_session_err = await self._ensure_chat_session(
|
||||
effective_username,
|
||||
session_id,
|
||||
)
|
||||
if ensure_session_err:
|
||||
return Response().error(ensure_session_err).__dict__
|
||||
|
||||
config_id, resolve_err = self._resolve_chat_config_id(post_data)
|
||||
if resolve_err:
|
||||
return Response().error(resolve_err).__dict__
|
||||
|
||||
original_username = g.get("username", "guest")
|
||||
g.username = effective_username
|
||||
if config_id:
|
||||
umo = f"webchat:FriendMessage:webchat!{effective_username}!{session_id}"
|
||||
try:
|
||||
if config_id == "default":
|
||||
await self.core_lifecycle.umop_config_router.delete_route(umo)
|
||||
else:
|
||||
await self.core_lifecycle.umop_config_router.update_route(
|
||||
umo, config_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to update chat config route for %s with %s: %s",
|
||||
umo,
|
||||
config_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.error(f"Failed to update chat config route: {e}")
|
||||
.__dict__
|
||||
)
|
||||
try:
|
||||
return await self.chat_route.chat(post_data=post_data)
|
||||
finally:
|
||||
g.username = original_username
|
||||
(
|
||||
effective_username,
|
||||
session_id,
|
||||
config_id,
|
||||
) = await self.service.prepare_chat_send(
|
||||
post_data,
|
||||
self._get_chat_config_list(),
|
||||
)
|
||||
except OpenApiServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
config_err = await self.service.update_session_config_route(
|
||||
username=effective_username,
|
||||
session_id=session_id,
|
||||
config_id=config_id,
|
||||
)
|
||||
if config_err:
|
||||
return self._error(config_err)
|
||||
|
||||
return await self._chat_response(effective_username, post_data)
|
||||
|
||||
async def _chat_response(self, username: str, post_data: dict):
|
||||
try:
|
||||
stream = await self.chat_service.build_chat_stream(username, post_data)
|
||||
except ChatServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
response = cast(
|
||||
CompatResponse,
|
||||
await make_response(
|
||||
stream,
|
||||
{
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
),
|
||||
)
|
||||
response.timeout = None
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _extract_ws_api_key() -> str | None:
|
||||
@@ -213,451 +147,71 @@ class OpenApiRoute(Route):
|
||||
return auth_header.removeprefix("ApiKey ").strip()
|
||||
return None
|
||||
|
||||
async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]:
|
||||
raw_key = self._extract_ws_api_key()
|
||||
if not raw_key:
|
||||
return False, "Missing API key"
|
||||
|
||||
key_hash = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
raw_key.encode("utf-8"),
|
||||
b"astrbot_api_key",
|
||||
100_000,
|
||||
).hex()
|
||||
api_key = await self.db.get_active_api_key_by_hash(key_hash)
|
||||
if not api_key:
|
||||
return False, "Invalid API key"
|
||||
|
||||
if isinstance(api_key.scopes, list):
|
||||
scopes = api_key.scopes
|
||||
else:
|
||||
scopes = list(ALL_OPEN_API_SCOPES)
|
||||
|
||||
if "*" not in scopes and "chat" not in scopes:
|
||||
return False, "Insufficient API key scope"
|
||||
|
||||
await self.db.touch_api_key(api_key.key_id)
|
||||
return True, None
|
||||
|
||||
async def _send_chat_ws_error(self, message: str, code: str) -> None:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"code": code,
|
||||
"data": message,
|
||||
}
|
||||
)
|
||||
|
||||
async def _update_session_config_route(
|
||||
async def _insert_webchat_user_message(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
session_id: str,
|
||||
config_id: str | None,
|
||||
) -> str | None:
|
||||
if not config_id:
|
||||
return None
|
||||
|
||||
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
|
||||
try:
|
||||
if config_id == "default":
|
||||
await self.core_lifecycle.umop_config_router.delete_route(umo)
|
||||
else:
|
||||
await self.core_lifecycle.umop_config_router.update_route(
|
||||
umo, config_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to update chat config route for %s with %s: %s",
|
||||
umo,
|
||||
config_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return f"Failed to update chat config route: {e}"
|
||||
return None
|
||||
|
||||
async def _handle_chat_ws_send(self, post_data: dict) -> None:
|
||||
effective_username, username_err = self._resolve_open_username(
|
||||
post_data.get("username")
|
||||
)
|
||||
if username_err or not effective_username:
|
||||
await self._send_chat_ws_error(
|
||||
username_err or "Invalid username", "BAD_USER"
|
||||
)
|
||||
return
|
||||
|
||||
message = post_data.get("message")
|
||||
if message is None:
|
||||
await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE")
|
||||
return
|
||||
|
||||
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
|
||||
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
|
||||
if not session_id:
|
||||
session_id = str(uuid4())
|
||||
|
||||
ensure_session_err = await self._ensure_chat_session(
|
||||
effective_username,
|
||||
session_id,
|
||||
)
|
||||
if ensure_session_err:
|
||||
await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR")
|
||||
return
|
||||
|
||||
config_id, resolve_err = self._resolve_chat_config_id(post_data)
|
||||
if resolve_err:
|
||||
await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR")
|
||||
return
|
||||
|
||||
config_err = await self._update_session_config_route(
|
||||
username=effective_username,
|
||||
effective_username: str,
|
||||
message_parts: list,
|
||||
) -> None:
|
||||
await self.service.insert_webchat_user_message(
|
||||
session_id=session_id,
|
||||
config_id=config_id,
|
||||
effective_username=effective_username,
|
||||
message_parts=message_parts,
|
||||
)
|
||||
if config_err:
|
||||
await self._send_chat_ws_error(config_err, "CONFIG_ERROR")
|
||||
return
|
||||
|
||||
message_parts = await self.chat_route._build_user_message_parts(message)
|
||||
if not webchat_message_parts_have_content(message_parts):
|
||||
await self._send_chat_ws_error(
|
||||
"Message content is empty (reply only is not allowed)",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
return
|
||||
|
||||
message_id = str(post_data.get("message_id") or uuid4())
|
||||
selected_provider = post_data.get("selected_provider")
|
||||
selected_model = post_data.get("selected_model")
|
||||
enable_streaming = post_data.get("enable_streaming", True)
|
||||
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
|
||||
try:
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
|
||||
await chat_queue.put(
|
||||
(
|
||||
effective_username,
|
||||
session_id,
|
||||
{
|
||||
"message": message_parts,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"enable_streaming": enable_streaming,
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
|
||||
await self.chat_route.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=session_id,
|
||||
content={"type": "user", "message": message_parts_for_storage},
|
||||
sender_id=effective_username,
|
||||
sender_name=effective_username,
|
||||
)
|
||||
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "session_id",
|
||||
"data": None,
|
||||
"session_id": session_id,
|
||||
"message_id": message_id,
|
||||
}
|
||||
)
|
||||
|
||||
message_accumulator = BotMessageAccumulator()
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
if "message_id" in result and result["message_id"] != message_id:
|
||||
logger.warning("openapi ws stream message_id mismatch")
|
||||
continue
|
||||
|
||||
result_text = result.get("data", "")
|
||||
msg_type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
chain_type = result.get("chain_type")
|
||||
|
||||
if chain_type == "agent_stats":
|
||||
try:
|
||||
stats_info = {
|
||||
"type": "agent_stats",
|
||||
"data": json.loads(result_text),
|
||||
}
|
||||
await websocket.send_json(stats_info)
|
||||
agent_stats = stats_info["data"]
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
await websocket.send_json(result)
|
||||
|
||||
if msg_type == "plain":
|
||||
message_accumulator.add_plain(
|
||||
result_text,
|
||||
chain_type=chain_type,
|
||||
streaming=streaming,
|
||||
)
|
||||
elif msg_type == "image":
|
||||
filename = str(result_text).replace("[IMAGE]", "")
|
||||
part = await self.chat_route._create_attachment_from_file(
|
||||
filename, "image"
|
||||
)
|
||||
message_accumulator.add_attachment(part)
|
||||
elif msg_type == "record":
|
||||
filename = str(result_text).replace("[RECORD]", "")
|
||||
part = await self.chat_route._create_attachment_from_file(
|
||||
filename, "record"
|
||||
)
|
||||
message_accumulator.add_attachment(part)
|
||||
elif msg_type == "file":
|
||||
filename = str(result_text).replace("[FILE]", "")
|
||||
part = await self.chat_route._create_attachment_from_file(
|
||||
filename, "file"
|
||||
)
|
||||
message_accumulator.add_attachment(part)
|
||||
elif msg_type == "video":
|
||||
filename = str(result_text).replace("[VIDEO]", "")
|
||||
part = await self.chat_route._create_attachment_from_file(
|
||||
filename, "video"
|
||||
)
|
||||
message_accumulator.add_attachment(part)
|
||||
|
||||
should_save = False
|
||||
if msg_type == "end":
|
||||
should_save = bool(
|
||||
message_accumulator.has_content() or refs or agent_stats
|
||||
)
|
||||
elif (streaming and msg_type == "complete") or not streaming:
|
||||
if chain_type not in ("tool_call", "tool_call_result"):
|
||||
should_save = True
|
||||
|
||||
if should_save:
|
||||
message_parts_to_save = message_accumulator.build_message_parts(
|
||||
include_pending_tool_calls=True
|
||||
)
|
||||
plain_text = collect_plain_text_from_message_parts(
|
||||
message_parts_to_save
|
||||
)
|
||||
try:
|
||||
refs = self.chat_route._extract_web_search_refs(
|
||||
plain_text,
|
||||
message_parts_to_save,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Open API WS failed to extract web search refs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
saved_record = await self.chat_route._save_bot_message(
|
||||
session_id,
|
||||
message_parts_to_save,
|
||||
agent_stats,
|
||||
refs,
|
||||
)
|
||||
if saved_record:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "message_saved",
|
||||
"data": {
|
||||
"id": saved_record.id,
|
||||
"created_at": to_utc_isoformat(
|
||||
saved_record.created_at
|
||||
),
|
||||
},
|
||||
"session_id": session_id,
|
||||
}
|
||||
)
|
||||
message_accumulator = BotMessageAccumulator()
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
if msg_type == "end":
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception(f"Open API WS chat failed: {e}", exc_info=True)
|
||||
await self._send_chat_ws_error(
|
||||
f"Failed to process message: {e}", "PROCESSING_ERROR"
|
||||
)
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
def _build_chat_ws_bridge(self) -> OpenApiWebSocketChatBridge:
|
||||
return OpenApiWebSocketChatBridge(
|
||||
build_user_message_parts=self.chat_service.build_user_message_parts,
|
||||
create_attachment_from_file=self.chat_service.create_attachment_from_file,
|
||||
extract_web_search_refs=extract_web_search_refs,
|
||||
insert_user_message=self._insert_webchat_user_message,
|
||||
save_bot_message=self.chat_service.save_bot_message,
|
||||
)
|
||||
|
||||
async def chat_ws(self) -> None:
|
||||
authed, auth_err = await self._authenticate_chat_ws_api_key()
|
||||
if not authed:
|
||||
await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
|
||||
await websocket.close(1008, auth_err or "Unauthorized")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
if not isinstance(message, dict):
|
||||
await self._send_chat_ws_error(
|
||||
"message must be an object",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
continue
|
||||
|
||||
msg_type = message.get("t", "send")
|
||||
if msg_type == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
continue
|
||||
if msg_type != "send":
|
||||
await self._send_chat_ws_error(
|
||||
f"Unsupported message type: {msg_type}",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
continue
|
||||
|
||||
await self._handle_chat_ws_send(message)
|
||||
except Exception as e:
|
||||
logger.debug("Open API WS connection closed: %s", e)
|
||||
await self.service.run_chat_websocket(
|
||||
raw_api_key=self._extract_ws_api_key(),
|
||||
receive_json=websocket.receive_json,
|
||||
send_json=websocket.send_json,
|
||||
close=websocket.close,
|
||||
conf_list=self._get_chat_config_list(),
|
||||
chat_bridge=self._build_chat_ws_bridge(),
|
||||
)
|
||||
|
||||
async def openapi_upload_file(self):
|
||||
return await self.chat_route.post_file()
|
||||
return await self._run(
|
||||
self.chat_service.save_uploaded_file_from_legacy_files(await request.files)
|
||||
)
|
||||
|
||||
async def openapi_get_file(self):
|
||||
return await self.chat_route.get_attachment()
|
||||
try:
|
||||
(
|
||||
file_path,
|
||||
mimetype,
|
||||
) = await self.chat_service.resolve_attachment_file_from_legacy_query(
|
||||
request.args.get("attachment_id")
|
||||
)
|
||||
return await send_file(file_path, mimetype=mimetype)
|
||||
except ChatServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
except (FileNotFoundError, OSError):
|
||||
return self._error("File access error")
|
||||
|
||||
async def get_chat_sessions(self):
|
||||
username, username_err = self._resolve_open_username(
|
||||
request.args.get("username")
|
||||
)
|
||||
if username_err:
|
||||
return Response().error(username_err).__dict__
|
||||
|
||||
assert username is not None # for type checker
|
||||
|
||||
try:
|
||||
page = int(request.args.get("page", 1))
|
||||
page_size = int(request.args.get("page_size", 20))
|
||||
except ValueError:
|
||||
return Response().error("page and page_size must be integers").__dict__
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 1
|
||||
if page_size > 100:
|
||||
page_size = 100
|
||||
|
||||
platform_id = request.args.get("platform_id")
|
||||
|
||||
(
|
||||
paginated_sessions,
|
||||
total,
|
||||
) = await self.db.get_platform_sessions_by_creator_paginated(
|
||||
creator=username,
|
||||
platform_id=platform_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
exclude_project_sessions=True,
|
||||
)
|
||||
|
||||
sessions_data = []
|
||||
for item in paginated_sessions:
|
||||
session = item["session"]
|
||||
sessions_data.append(
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"platform_id": session.platform_id,
|
||||
"creator": session.creator,
|
||||
"display_name": session.display_name,
|
||||
"is_group": session.is_group,
|
||||
"created_at": to_utc_isoformat(session.created_at),
|
||||
"updated_at": to_utc_isoformat(session.updated_at),
|
||||
}
|
||||
return await self._run(
|
||||
self.service.get_chat_sessions_from_legacy_query(
|
||||
username=request.args.get("username"),
|
||||
page=request.args.get("page", 1),
|
||||
page_size=request.args.get("page_size", 20),
|
||||
platform_id=request.args.get("platform_id"),
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"sessions": sessions_data,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def get_chat_configs(self):
|
||||
conf_list = self._get_chat_config_list()
|
||||
return Response().ok(data={"configs": conf_list}).__dict__
|
||||
|
||||
async def _build_message_chain_from_payload(
|
||||
self,
|
||||
message_payload: str | list,
|
||||
):
|
||||
return await build_message_chain_from_payload(
|
||||
message_payload,
|
||||
get_attachment_by_id=self.db.get_attachment_by_id,
|
||||
strict=True,
|
||||
)
|
||||
return self._ok({"configs": self._get_chat_config_list()})
|
||||
|
||||
async def send_message(self):
|
||||
post_data = await request.json or {}
|
||||
message_payload = post_data.get("message", {})
|
||||
umo = post_data.get("umo")
|
||||
|
||||
if message_payload is None:
|
||||
return Response().error("Missing key: message").__dict__
|
||||
if not umo:
|
||||
return Response().error("Missing key: umo").__dict__
|
||||
|
||||
try:
|
||||
session = MessageSesion.from_str(str(umo))
|
||||
except Exception as e:
|
||||
return Response().error(f"Invalid umo: {e}").__dict__
|
||||
|
||||
platform_id = session.platform_name
|
||||
platform_inst = next(
|
||||
(
|
||||
inst
|
||||
for inst in self.platform_manager.platform_insts
|
||||
if inst.meta().id == platform_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not platform_inst:
|
||||
return (
|
||||
Response()
|
||||
.error(f"Bot not found or not running for platform: {platform_id}")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
message_chain = await self._build_message_chain_from_payload(
|
||||
message_payload
|
||||
)
|
||||
await platform_inst.send_by_session(session, message_chain)
|
||||
return Response().ok().__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"Open API send_message failed: {e}", exc_info=True)
|
||||
return Response().error(f"Failed to send message: {e}").__dict__
|
||||
return await self._run_json(self.service.send_message)
|
||||
|
||||
async def get_bots(self):
|
||||
bot_ids = []
|
||||
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
|
||||
platform_id = platform.get("id") if isinstance(platform, dict) else None
|
||||
if (
|
||||
isinstance(platform_id, str)
|
||||
and platform_id
|
||||
and platform_id not in bot_ids
|
||||
):
|
||||
bot_ids.append(platform_id)
|
||||
return Response().ok(data={"bot_ids": bot_ids}).__dict__
|
||||
return await self._run(self.service.get_bots)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import traceback
|
||||
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.sentinels import NOT_GIVEN
|
||||
from astrbot.dashboard.fastapi_compat import request
|
||||
from astrbot.dashboard.services.persona_service import (
|
||||
PersonaService,
|
||||
PersonaServiceError,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -14,7 +13,7 @@ class PersonaRoute(Route):
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db_helper: BaseDatabase,
|
||||
db_helper,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
@@ -26,7 +25,6 @@ class PersonaRoute(Route):
|
||||
"/persona/delete": ("POST", self.delete_persona),
|
||||
"/persona/move": ("POST", self.move_persona),
|
||||
"/persona/reorder": ("POST", self.reorder_items),
|
||||
# Folder routes
|
||||
"/persona/folder/list": ("GET", self.list_folders),
|
||||
"/persona/folder/tree": ("GET", self.get_folder_tree),
|
||||
"/persona/folder/detail": ("POST", self.get_folder_detail),
|
||||
@@ -34,464 +32,129 @@ class PersonaRoute(Route):
|
||||
"/persona/folder/update": ("POST", self.update_folder),
|
||||
"/persona/folder/delete": ("POST", self.delete_folder),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.persona_mgr = core_lifecycle.persona_mgr
|
||||
self.service = PersonaService(core_lifecycle)
|
||||
self.register_routes()
|
||||
|
||||
@staticmethod
|
||||
def _ok(data):
|
||||
return Response().ok(data).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str):
|
||||
return Response().error(message).__dict__
|
||||
|
||||
@staticmethod
|
||||
async def _run(self, operation, *, label: str):
|
||||
try:
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
return self._ok(result)
|
||||
except (PersonaServiceError, ValueError) as exc:
|
||||
return self._error(str(exc))
|
||||
except Exception as exc:
|
||||
logger.error("%s: %s", label, exc, exc_info=True)
|
||||
return self._error(f"{label}: {exc!s}")
|
||||
|
||||
async def list_personas(self):
|
||||
"""获取所有人格列表"""
|
||||
try:
|
||||
# 支持按文件夹筛选
|
||||
folder_id = request.args.get("folder_id")
|
||||
if folder_id is not None:
|
||||
personas = await self.persona_mgr.get_personas_by_folder(
|
||||
folder_id if folder_id else None
|
||||
)
|
||||
else:
|
||||
personas = await self.persona_mgr.get_all_personas()
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
[
|
||||
{
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools,
|
||||
"skills": persona.skills,
|
||||
"custom_error_message": persona.custom_error_message,
|
||||
"folder_id": persona.folder_id,
|
||||
"sort_order": persona.sort_order,
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
}
|
||||
for persona in personas
|
||||
],
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取人格列表失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取人格列表失败: {e!s}").__dict__
|
||||
return await self._run(
|
||||
lambda: self.service.list_personas_from_legacy_query(
|
||||
folder_id=request.args.get("folder_id"),
|
||||
has_folder_id="folder_id" in request.args,
|
||||
),
|
||||
label="获取人格列表失败",
|
||||
)
|
||||
|
||||
async def get_persona_detail(self):
|
||||
"""获取指定人格的详细信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
persona = await self.persona_mgr.get_persona(persona_id)
|
||||
if not persona:
|
||||
return Response().error("人格不存在").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools,
|
||||
"skills": persona.skills,
|
||||
"custom_error_message": persona.custom_error_message,
|
||||
"folder_id": persona.folder_id,
|
||||
"sort_order": persona.sort_order,
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取人格详情失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取人格详情失败: {e!s}").__dict__
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.get_persona_detail(data),
|
||||
label="获取人格详情失败",
|
||||
)
|
||||
|
||||
async def create_persona(self):
|
||||
"""创建新人格"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id", "").strip()
|
||||
system_prompt = data.get("system_prompt", "").strip()
|
||||
begin_dialogs = data.get("begin_dialogs", [])
|
||||
tools = data.get("tools")
|
||||
skills = data.get("skills")
|
||||
custom_error_message = data.get("custom_error_message")
|
||||
folder_id = data.get("folder_id") # None 表示根目录
|
||||
sort_order = data.get("sort_order", 0)
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("人格ID不能为空").__dict__
|
||||
|
||||
if not system_prompt:
|
||||
return Response().error("系统提示词不能为空").__dict__
|
||||
|
||||
if custom_error_message is not None:
|
||||
if not isinstance(custom_error_message, str):
|
||||
return Response().error("自定义报错回复信息必须是字符串").__dict__
|
||||
custom_error_message = custom_error_message.strip() or None
|
||||
|
||||
# 验证 begin_dialogs 格式
|
||||
if begin_dialogs and len(begin_dialogs) % 2 != 0:
|
||||
return (
|
||||
Response()
|
||||
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
persona = await self.persona_mgr.create_persona(
|
||||
persona_id=persona_id,
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs if begin_dialogs else None,
|
||||
tools=tools if tools else None,
|
||||
skills=skills if skills else None,
|
||||
custom_error_message=custom_error_message,
|
||||
folder_id=folder_id,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": "人格创建成功",
|
||||
"persona": {
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools or [],
|
||||
"skills": persona.skills or [],
|
||||
"custom_error_message": persona.custom_error_message,
|
||||
"folder_id": persona.folder_id,
|
||||
"sort_order": persona.sort_order,
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
},
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"创建人格失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"创建人格失败: {e!s}").__dict__
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.create_persona(data),
|
||||
label="创建人格失败",
|
||||
)
|
||||
|
||||
async def update_persona(self):
|
||||
"""更新人格信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
system_prompt = data.get("system_prompt")
|
||||
begin_dialogs = data.get("begin_dialogs")
|
||||
has_tools = "tools" in data
|
||||
tools = data.get("tools")
|
||||
has_skills = "skills" in data
|
||||
skills = data.get("skills")
|
||||
has_custom_error_message = "custom_error_message" in data
|
||||
custom_error_message = data.get("custom_error_message")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
if has_custom_error_message:
|
||||
if custom_error_message is not None and not isinstance(
|
||||
custom_error_message, str
|
||||
):
|
||||
return Response().error("自定义报错回复信息必须是字符串").__dict__
|
||||
if isinstance(custom_error_message, str):
|
||||
custom_error_message = custom_error_message.strip() or None
|
||||
|
||||
# 验证 begin_dialogs 格式
|
||||
if begin_dialogs is not None and len(begin_dialogs) % 2 != 0:
|
||||
return (
|
||||
Response()
|
||||
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
update_kwargs = {
|
||||
"persona_id": persona_id,
|
||||
"system_prompt": system_prompt,
|
||||
"begin_dialogs": begin_dialogs,
|
||||
}
|
||||
if has_tools:
|
||||
update_kwargs["tools"] = tools
|
||||
if has_skills:
|
||||
update_kwargs["skills"] = skills
|
||||
if has_custom_error_message:
|
||||
update_kwargs["custom_error_message"] = custom_error_message
|
||||
|
||||
await self.persona_mgr.update_persona(**update_kwargs)
|
||||
|
||||
return Response().ok({"message": "人格更新成功"}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"更新人格失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新人格失败: {e!s}").__dict__
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.update_persona(data),
|
||||
label="更新人格失败",
|
||||
)
|
||||
|
||||
async def delete_persona(self):
|
||||
"""删除人格"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
await self.persona_mgr.delete_persona(persona_id)
|
||||
|
||||
return Response().ok({"message": "人格删除成功"}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"删除人格失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"删除人格失败: {e!s}").__dict__
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.delete_persona(data),
|
||||
label="删除人格失败",
|
||||
)
|
||||
|
||||
async def move_persona(self):
|
||||
"""移动人格到指定文件夹"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
folder_id = data.get("folder_id") # None 表示移动到根目录
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
await self.persona_mgr.move_persona_to_folder(persona_id, folder_id)
|
||||
|
||||
return Response().ok({"message": "人格移动成功"}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"移动人格失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"移动人格失败: {e!s}").__dict__
|
||||
|
||||
# ====
|
||||
# Folder Routes
|
||||
# ====
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.move_persona(data),
|
||||
label="移动人格失败",
|
||||
)
|
||||
|
||||
async def list_folders(self):
|
||||
"""获取文件夹列表"""
|
||||
try:
|
||||
parent_id = request.args.get("parent_id")
|
||||
# 空字符串视为 None(根目录)
|
||||
if parent_id == "":
|
||||
parent_id = None
|
||||
folders = await self.persona_mgr.get_folders(parent_id)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
[
|
||||
{
|
||||
"folder_id": folder.folder_id,
|
||||
"name": folder.name,
|
||||
"parent_id": folder.parent_id,
|
||||
"description": folder.description,
|
||||
"sort_order": folder.sort_order,
|
||||
"created_at": folder.created_at.isoformat()
|
||||
if folder.created_at
|
||||
else None,
|
||||
"updated_at": folder.updated_at.isoformat()
|
||||
if folder.updated_at
|
||||
else None,
|
||||
}
|
||||
for folder in folders
|
||||
],
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取文件夹列表失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取文件夹列表失败: {e!s}").__dict__
|
||||
return await self._run(
|
||||
lambda: self.service.list_folders_from_legacy_query(
|
||||
request.args.get("parent_id")
|
||||
),
|
||||
label="获取文件夹列表失败",
|
||||
)
|
||||
|
||||
async def get_folder_tree(self):
|
||||
"""获取文件夹树形结构"""
|
||||
try:
|
||||
tree = await self.persona_mgr.get_folder_tree()
|
||||
return Response().ok(tree).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取文件夹树失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取文件夹树失败: {e!s}").__dict__
|
||||
return await self._run(self.service.get_folder_tree, label="获取文件夹树失败")
|
||||
|
||||
async def get_folder_detail(self):
|
||||
"""获取指定文件夹的详细信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
folder_id = data.get("folder_id")
|
||||
|
||||
if not folder_id:
|
||||
return Response().error("缺少必要参数: folder_id").__dict__
|
||||
|
||||
folder = await self.persona_mgr.get_folder(folder_id)
|
||||
if not folder:
|
||||
return Response().error("文件夹不存在").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"folder_id": folder.folder_id,
|
||||
"name": folder.name,
|
||||
"parent_id": folder.parent_id,
|
||||
"description": folder.description,
|
||||
"sort_order": folder.sort_order,
|
||||
"created_at": folder.created_at.isoformat()
|
||||
if folder.created_at
|
||||
else None,
|
||||
"updated_at": folder.updated_at.isoformat()
|
||||
if folder.updated_at
|
||||
else None,
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取文件夹详情失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取文件夹详情失败: {e!s}").__dict__
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.get_folder_detail(data),
|
||||
label="获取文件夹详情失败",
|
||||
)
|
||||
|
||||
async def create_folder(self):
|
||||
"""创建文件夹"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
name = data.get("name", "").strip()
|
||||
parent_id = data.get("parent_id")
|
||||
description = data.get("description")
|
||||
sort_order = data.get("sort_order", 0)
|
||||
|
||||
if not name:
|
||||
return Response().error("文件夹名称不能为空").__dict__
|
||||
|
||||
folder = await self.persona_mgr.create_folder(
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
description=description,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": "文件夹创建成功",
|
||||
"folder": {
|
||||
"folder_id": folder.folder_id,
|
||||
"name": folder.name,
|
||||
"parent_id": folder.parent_id,
|
||||
"description": folder.description,
|
||||
"sort_order": folder.sort_order,
|
||||
"created_at": folder.created_at.isoformat()
|
||||
if folder.created_at
|
||||
else None,
|
||||
"updated_at": folder.updated_at.isoformat()
|
||||
if folder.updated_at
|
||||
else None,
|
||||
},
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建文件夹失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"创建文件夹失败: {e!s}").__dict__
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.create_folder(data),
|
||||
label="创建文件夹失败",
|
||||
)
|
||||
|
||||
async def update_folder(self):
|
||||
"""更新文件夹信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
folder_id = data.get("folder_id")
|
||||
name = data.get("name")
|
||||
parent_id = data.get("parent_id") if "parent_id" in data else NOT_GIVEN
|
||||
description = (
|
||||
data.get("description") if "description" in data else NOT_GIVEN
|
||||
)
|
||||
sort_order = data.get("sort_order")
|
||||
|
||||
if not folder_id:
|
||||
return Response().error("缺少必要参数: folder_id").__dict__
|
||||
|
||||
await self.persona_mgr.update_folder(
|
||||
folder_id=folder_id,
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
description=description,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
return Response().ok({"message": "文件夹更新成功"}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"更新文件夹失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新文件夹失败: {e!s}").__dict__
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.update_folder(data),
|
||||
label="更新文件夹失败",
|
||||
)
|
||||
|
||||
async def delete_folder(self):
|
||||
"""删除文件夹"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
folder_id = data.get("folder_id")
|
||||
|
||||
if not folder_id:
|
||||
return Response().error("缺少必要参数: folder_id").__dict__
|
||||
|
||||
await self.persona_mgr.delete_folder(folder_id)
|
||||
|
||||
return Response().ok({"message": "文件夹删除成功"}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"删除文件夹失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"删除文件夹失败: {e!s}").__dict__
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.delete_folder(data),
|
||||
label="删除文件夹失败",
|
||||
)
|
||||
|
||||
async def reorder_items(self):
|
||||
"""批量更新排序顺序
|
||||
|
||||
请求体格式:
|
||||
{
|
||||
"items": [
|
||||
{"id": "persona_id_1", "type": "persona", "sort_order": 0},
|
||||
{"id": "persona_id_2", "type": "persona", "sort_order": 1},
|
||||
{"id": "folder_id_1", "type": "folder", "sort_order": 0},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
items = data.get("items", [])
|
||||
|
||||
if not items:
|
||||
return Response().error("items 不能为空").__dict__
|
||||
|
||||
# 验证每个 item 的格式
|
||||
for item in items:
|
||||
if not all(k in item for k in ("id", "type", "sort_order")):
|
||||
return (
|
||||
Response()
|
||||
.error("每个 item 必须包含 id, type, sort_order 字段")
|
||||
.__dict__
|
||||
)
|
||||
if item["type"] not in ("persona", "folder"):
|
||||
return (
|
||||
Response()
|
||||
.error("type 字段必须是 'persona' 或 'folder'")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
await self.persona_mgr.batch_update_sort_order(items)
|
||||
|
||||
return Response().ok({"message": "排序更新成功"}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"更新排序失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新排序失败: {e!s}").__dict__
|
||||
"""批量更新排序顺序"""
|
||||
data = await request.get_json()
|
||||
return await self._run(
|
||||
self.service.reorder_items(data),
|
||||
label="更新排序失败",
|
||||
)
|
||||
|
||||
@@ -1,39 +1,20 @@
|
||||
"""统一 Webhook 路由
|
||||
"""Unified webhook routes.
|
||||
|
||||
提供统一的 webhook 回调入口,支持多个平台使用同一端口接收回调。
|
||||
Provides a unified webhook callback entrypoint for multiple platforms.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.sources.dingtalk.app_registration import (
|
||||
poll_dingtalk_app_registration_once,
|
||||
request_dingtalk_app_registration,
|
||||
)
|
||||
from astrbot.core.platform.sources.lark.app_registration import (
|
||||
poll_app_registration_once,
|
||||
request_app_registration,
|
||||
)
|
||||
from astrbot.core.platform.sources.lark.bot_info import request_lark_bot_info
|
||||
from astrbot.core.platform.sources.weixin_oc.login_registration import (
|
||||
poll_weixin_oc_login_once,
|
||||
request_weixin_oc_login_qr,
|
||||
from astrbot.dashboard.fastapi_compat import request
|
||||
from astrbot.dashboard.services.platform_service import (
|
||||
PlatformService,
|
||||
PlatformServiceError,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def _random_platform_id_suffix() -> str:
|
||||
return "_" + "".join(secrets.choice(string.ascii_lowercase) for _ in range(4))
|
||||
|
||||
|
||||
class PlatformRoute(Route):
|
||||
"""统一 Webhook 路由"""
|
||||
"""Unified webhook route."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -41,21 +22,17 @@ class PlatformRoute(Route):
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.platform_manager = core_lifecycle.platform_manager
|
||||
self.service = PlatformService(core_lifecycle)
|
||||
|
||||
self._register_webhook_routes()
|
||||
|
||||
def _register_webhook_routes(self) -> None:
|
||||
"""注册 webhook 路由"""
|
||||
# 统一 webhook 入口,支持 GET 和 POST
|
||||
self.app.add_url_rule(
|
||||
"/api/platform/webhook/<webhook_uuid>",
|
||||
view_func=self.unified_webhook_callback,
|
||||
methods=["GET", "POST"],
|
||||
)
|
||||
|
||||
# 平台统计信息接口
|
||||
self.app.add_url_rule(
|
||||
"/api/platform/stats",
|
||||
view_func=self.get_platform_stats,
|
||||
@@ -68,218 +45,37 @@ class PlatformRoute(Route):
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
async def unified_webhook_callback(self, webhook_uuid: str):
|
||||
"""统一 webhook 回调入口
|
||||
@staticmethod
|
||||
def _ok(data=None):
|
||||
return Response().ok(data).__dict__
|
||||
|
||||
Args:
|
||||
webhook_uuid: 平台配置中的 webhook_uuid
|
||||
@staticmethod
|
||||
def _error(exc: PlatformServiceError):
|
||||
return Response().error(str(exc)).__dict__, exc.status_code
|
||||
|
||||
Returns:
|
||||
根据平台适配器返回相应的响应
|
||||
"""
|
||||
# 根据 webhook_uuid 查找对应的平台
|
||||
platform_adapter = self._find_platform_by_uuid(webhook_uuid)
|
||||
|
||||
if not platform_adapter:
|
||||
logger.warning(f"未找到 webhook_uuid 为 {webhook_uuid} 的平台")
|
||||
return Response().error("未找到对应平台").__dict__, 404
|
||||
|
||||
# 调用平台适配器的 webhook_callback 方法
|
||||
async def _run(self, operation):
|
||||
try:
|
||||
result = await platform_adapter.webhook_callback(request)
|
||||
return result
|
||||
except NotImplementedError:
|
||||
logger.error(
|
||||
f"平台 {platform_adapter.meta().name} 未实现 webhook_callback 方法"
|
||||
)
|
||||
return Response().error("平台未支持统一 Webhook 模式").__dict__, 500
|
||||
except Exception as e:
|
||||
logger.error(f"处理 webhook 回调时发生错误: {e}", exc_info=True)
|
||||
return Response().error("处理回调失败").__dict__, 500
|
||||
return self._ok(await operation())
|
||||
except PlatformServiceError as exc:
|
||||
return self._error(exc)
|
||||
|
||||
def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None:
|
||||
"""根据 webhook_uuid 查找对应的平台适配器
|
||||
async def _run_sync(self, operation):
|
||||
try:
|
||||
return self._ok(operation())
|
||||
except PlatformServiceError as exc:
|
||||
return self._error(exc)
|
||||
|
||||
Args:
|
||||
webhook_uuid: webhook UUID
|
||||
|
||||
Returns:
|
||||
平台适配器实例,未找到则返回 None
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.config.get("webhook_uuid") == webhook_uuid:
|
||||
if platform.unified_webhook():
|
||||
return platform
|
||||
return None
|
||||
async def unified_webhook_callback(self, webhook_uuid: str):
|
||||
return await self._run(
|
||||
lambda: self.service.handle_webhook_callback(webhook_uuid, request)
|
||||
)
|
||||
|
||||
async def get_platform_stats(self):
|
||||
"""获取所有平台的统计信息
|
||||
|
||||
Returns:
|
||||
包含平台统计信息的响应
|
||||
"""
|
||||
try:
|
||||
stats = self.platform_manager.get_all_stats()
|
||||
return Response().ok(stats).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取平台统计信息失败: {e}", exc_info=True)
|
||||
return Response().error(f"获取统计信息失败: {e}").__dict__, 500
|
||||
return await self._run_sync(self.service.get_platform_stats)
|
||||
|
||||
async def handle_platform_registration(self, platform_type: str):
|
||||
"""Handle dashboard one-click platform registration actions."""
|
||||
try:
|
||||
payload = await request.get_json(silent=True) or {}
|
||||
action = str(payload.get("action", "")).strip().lower()
|
||||
if not action:
|
||||
return Response().error("Missing action").__dict__, 400
|
||||
|
||||
platform_config = payload.get("platform_config")
|
||||
if not isinstance(platform_config, dict):
|
||||
platform_config = {}
|
||||
|
||||
if platform_type == "lark":
|
||||
return await self._handle_lark_registration(
|
||||
action,
|
||||
payload,
|
||||
platform_config,
|
||||
)
|
||||
if platform_type == "weixin_oc":
|
||||
return await self._handle_weixin_oc_registration(
|
||||
action,
|
||||
payload,
|
||||
platform_config,
|
||||
)
|
||||
if platform_type == "dingtalk":
|
||||
return await self._handle_dingtalk_registration(action, payload)
|
||||
|
||||
return Response().error(
|
||||
f"Unsupported platform registration: {platform_type}"
|
||||
).__dict__, 404
|
||||
except Exception as e:
|
||||
logger.error(f"处理平台一键创建请求失败: {e}", exc_info=True)
|
||||
return Response().error(str(e)).__dict__, 500
|
||||
|
||||
async def _handle_lark_registration(
|
||||
self,
|
||||
action: str,
|
||||
payload: dict,
|
||||
platform_config: dict,
|
||||
):
|
||||
domain = str(platform_config.get("domain") or "").strip()
|
||||
|
||||
if action == "start":
|
||||
registration = await request_app_registration(domain)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"status": "pending",
|
||||
"device_code": registration.device_code,
|
||||
"registration_code": registration.device_code,
|
||||
"user_code": registration.user_code,
|
||||
"verification_uri": registration.verification_uri,
|
||||
"verification_uri_complete": registration.verification_uri_complete,
|
||||
"expires_in": registration.expires_in,
|
||||
"interval": registration.interval,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
if action == "poll":
|
||||
device_code = str(
|
||||
payload.get("device_code") or payload.get("registration_code") or ""
|
||||
).strip()
|
||||
if not device_code:
|
||||
return Response().error("Missing device_code").__dict__, 400
|
||||
result = await poll_app_registration_once(
|
||||
domain=domain,
|
||||
device_code=device_code,
|
||||
)
|
||||
if result.get("status") == "created":
|
||||
try:
|
||||
bot_info = await request_lark_bot_info(
|
||||
domain=str(result.get("domain") or domain),
|
||||
app_id=str(result.get("app_id") or ""),
|
||||
app_secret=str(result.get("app_secret") or ""),
|
||||
)
|
||||
if bot_info.app_name:
|
||||
result["bot_name"] = bot_info.app_name
|
||||
if bot_info.open_id:
|
||||
result["bot_open_id"] = bot_info.open_id
|
||||
except Exception as e:
|
||||
logger.error(f"获取飞书机器人信息失败: {e}", exc_info=True)
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
return Response().error(f"Unsupported action: {action}").__dict__, 400
|
||||
|
||||
async def _handle_dingtalk_registration(self, action: str, payload: dict):
|
||||
if action == "start":
|
||||
registration = await request_dingtalk_app_registration()
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"status": "pending",
|
||||
"device_code": registration.device_code,
|
||||
"registration_code": registration.device_code,
|
||||
"user_code": registration.user_code,
|
||||
"verification_uri": registration.verification_uri,
|
||||
"verification_uri_complete": registration.verification_uri_complete,
|
||||
"expires_in": registration.expires_in,
|
||||
"interval": registration.interval,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
if action == "poll":
|
||||
device_code = str(
|
||||
payload.get("device_code") or payload.get("registration_code") or ""
|
||||
).strip()
|
||||
if not device_code:
|
||||
return Response().error("Missing device_code").__dict__, 400
|
||||
result = await poll_dingtalk_app_registration_once(device_code)
|
||||
if result.get("status") == "created":
|
||||
result["platform_id_suffix"] = _random_platform_id_suffix()
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
return Response().error(f"Unsupported action: {action}").__dict__, 400
|
||||
|
||||
async def _handle_weixin_oc_registration(
|
||||
self,
|
||||
action: str,
|
||||
payload: dict,
|
||||
platform_config: dict,
|
||||
):
|
||||
if action == "start":
|
||||
registration = await request_weixin_oc_login_qr(platform_config)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"status": "pending",
|
||||
"registration_code": registration.qrcode,
|
||||
"qrcode": registration.qrcode,
|
||||
"qrcode_img_content": registration.qrcode_img_content,
|
||||
"interval": registration.interval,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
if action == "poll":
|
||||
qrcode = str(
|
||||
payload.get("qrcode") or payload.get("registration_code") or ""
|
||||
).strip()
|
||||
if not qrcode:
|
||||
return Response().error("Missing qrcode").__dict__, 400
|
||||
result = await poll_weixin_oc_login_once(
|
||||
platform_config=platform_config,
|
||||
qrcode=qrcode,
|
||||
)
|
||||
if result.get("status") == "created":
|
||||
result["platform_id_suffix"] = _random_platform_id_suffix()
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
return Response().error(f"Unsupported action: {action}").__dict__, 400
|
||||
payload = await request.get_json(silent=True) or {}
|
||||
return await self._run(
|
||||
lambda: self.service.handle_platform_registration(platform_type, payload)
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import Quart
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.dashboard.fastapi_compat import FastAPIAppAdapter
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouteContext:
|
||||
config: AstrBotConfig
|
||||
app: Quart
|
||||
app: FastAPIAppAdapter
|
||||
|
||||
|
||||
class Route:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,38 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import cmp_to_key
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import psutil
|
||||
from quart import request
|
||||
from sqlmodel import col, select
|
||||
|
||||
from astrbot.core import DEMO_MODE, logger
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.helper import check_migration_needed_v4
|
||||
from astrbot.core.db.po import ProviderStat
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.utils.auth_password import (
|
||||
is_default_dashboard_password,
|
||||
is_legacy_dashboard_password,
|
||||
)
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
from astrbot.core.utils.storage_cleaner import StorageCleaner
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
from astrbot.dashboard.password_state import (
|
||||
get_dashboard_password_hash,
|
||||
is_password_change_required,
|
||||
is_password_storage_upgraded,
|
||||
)
|
||||
from astrbot.dashboard.fastapi_compat import request
|
||||
from astrbot.dashboard.services.stat_service import StatService, StatServiceError
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -58,533 +27,79 @@ class StatRoute(Route):
|
||||
"/stat/storage": ("GET", self.get_storage_status),
|
||||
"/stat/storage/cleanup": ("POST", self.cleanup_storage),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.service = StatService(db_helper, core_lifecycle, self.config)
|
||||
self.register_routes()
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.storage_cleaner = StorageCleaner(self.config)
|
||||
|
||||
async def restart_core(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
await self.core_lifecycle.restart()
|
||||
return Response().ok().__dict__
|
||||
|
||||
def _get_running_time_components(self, total_seconds: int):
|
||||
"""将总秒数转换为时分秒组件"""
|
||||
minutes, seconds = divmod(total_seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return {"hours": hours, "minutes": minutes, "seconds": seconds}
|
||||
|
||||
async def is_default_cred(self):
|
||||
password_change_required = await is_password_change_required(
|
||||
self.db_helper,
|
||||
self.config,
|
||||
)
|
||||
if password_change_required:
|
||||
return not DEMO_MODE
|
||||
|
||||
storage_upgraded = await is_password_storage_upgraded(
|
||||
self.db_helper,
|
||||
self.config,
|
||||
)
|
||||
if not storage_upgraded:
|
||||
return False
|
||||
|
||||
username = self.config["dashboard"]["username"]
|
||||
password = get_dashboard_password_hash(self.config, upgraded=True)
|
||||
return (
|
||||
username == "astrbot" and is_default_dashboard_password(password)
|
||||
) and not DEMO_MODE
|
||||
|
||||
async def get_version(self):
|
||||
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
|
||||
storage_upgraded = await is_password_storage_upgraded(
|
||||
self.db_helper,
|
||||
self.config,
|
||||
)
|
||||
password = get_dashboard_password_hash(
|
||||
self.config,
|
||||
upgraded=storage_upgraded,
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"version": VERSION,
|
||||
"dashboard_version": await get_dashboard_version(),
|
||||
"change_pwd_hint": await self.is_default_cred(),
|
||||
"legacy_pwd_hint": is_legacy_dashboard_password(password),
|
||||
"password_upgrade_required": not storage_upgraded,
|
||||
"need_migration": need_migration,
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def get_start_time(self):
|
||||
return Response().ok({"start_time": self.core_lifecycle.start_time}).__dict__
|
||||
|
||||
async def get_storage_status(self):
|
||||
try:
|
||||
status = await asyncio.to_thread(self.storage_cleaner.get_status)
|
||||
return Response().ok(status).__dict__
|
||||
except Exception:
|
||||
logger.error("获取存储占用失败", exc_info=True)
|
||||
return (
|
||||
Response().error("获取存储占用失败,请查看后端日志了解详情。").__dict__
|
||||
)
|
||||
|
||||
async def cleanup_storage(self):
|
||||
try:
|
||||
data = await request.get_json(silent=True)
|
||||
target = "all"
|
||||
if isinstance(data, dict):
|
||||
target = str(data.get("target", "all"))
|
||||
|
||||
result = await asyncio.to_thread(self.storage_cleaner.cleanup, target)
|
||||
return Response().ok(result).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception:
|
||||
logger.error("清理存储失败", exc_info=True)
|
||||
return Response().error("清理存储失败,请查看后端日志了解详情。").__dict__
|
||||
|
||||
async def get_stat(self):
|
||||
offset_sec = request.args.get("offset_sec", 86400)
|
||||
offset_sec = int(offset_sec)
|
||||
try:
|
||||
stat = self.db_helper.get_base_stats(offset_sec)
|
||||
now = int(time.time())
|
||||
start_time = now - offset_sec
|
||||
message_time_based_stats = []
|
||||
|
||||
idx = 0
|
||||
for bucket_end in range(start_time, now, 3600):
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(stat.platform)
|
||||
and stat.platform[idx].timestamp < bucket_end
|
||||
):
|
||||
cnt += stat.platform[idx].count
|
||||
idx += 1
|
||||
message_time_based_stats.append([bucket_end, cnt])
|
||||
|
||||
stat_dict = stat.__dict__
|
||||
|
||||
cpu_percent = psutil.cpu_percent(interval=0.5)
|
||||
thread_count = threading.active_count()
|
||||
|
||||
# 获取插件信息
|
||||
plugins = self.core_lifecycle.star_context.get_all_stars()
|
||||
plugin_info = []
|
||||
for plugin in plugins:
|
||||
info = {
|
||||
"name": getattr(plugin, "name", plugin.__class__.__name__),
|
||||
"version": getattr(plugin, "version", "1.0.0"),
|
||||
"is_enabled": True,
|
||||
}
|
||||
plugin_info.append(info)
|
||||
|
||||
# 计算运行时长组件
|
||||
running_time = self._get_running_time_components(
|
||||
int(time.time()) - self.core_lifecycle.start_time,
|
||||
)
|
||||
|
||||
stat_dict.update(
|
||||
{
|
||||
"platform": self.db_helper.get_grouped_base_stats(
|
||||
offset_sec,
|
||||
).platform,
|
||||
"message_count": self.db_helper.get_total_message_count() or 0,
|
||||
"platform_count": len(
|
||||
self.core_lifecycle.platform_manager.get_insts(),
|
||||
),
|
||||
"plugin_count": len(plugins),
|
||||
"plugins": plugin_info,
|
||||
"message_time_series": message_time_based_stats,
|
||||
"running": running_time, # 现在返回时间组件而不是格式化的字符串
|
||||
"memory": {
|
||||
"process": psutil.Process().memory_info().rss >> 20,
|
||||
"system": psutil.virtual_memory().total >> 20,
|
||||
},
|
||||
"cpu_percent": round(cpu_percent, 1),
|
||||
"thread_count": thread_count,
|
||||
"start_time": self.core_lifecycle.start_time,
|
||||
},
|
||||
)
|
||||
|
||||
return Response().ok(stat_dict).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _ensure_aware_utc(value: datetime) -> datetime:
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc)
|
||||
def _ok(data=None):
|
||||
return Response().ok(data).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str):
|
||||
return Response().error(message).__dict__
|
||||
|
||||
async def _run(self, operation):
|
||||
try:
|
||||
return self._ok(await operation())
|
||||
except StatServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
async def _run_sync(self, operation):
|
||||
try:
|
||||
return self._ok(operation())
|
||||
except StatServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
async def _run_json(self, operation, *, silent: bool = False):
|
||||
payload = await request.get_json(silent=silent)
|
||||
return await self._run(lambda: operation(payload))
|
||||
|
||||
async def restart_core(self):
|
||||
return await self._run(self.service.restart_core)
|
||||
|
||||
async def get_version(self):
|
||||
return await self._run(self.service.get_version)
|
||||
|
||||
async def get_start_time(self):
|
||||
return await self._run_sync(self.service.get_start_time)
|
||||
|
||||
async def get_storage_status(self):
|
||||
return await self._run(self.service.get_storage_status)
|
||||
|
||||
async def cleanup_storage(self):
|
||||
return await self._run_json(
|
||||
self.service.cleanup_storage_from_legacy_payload,
|
||||
silent=True,
|
||||
)
|
||||
|
||||
async def get_stat(self):
|
||||
return await self._run(
|
||||
lambda: self.service.get_stat_from_legacy_query(
|
||||
request.args.get("offset_sec", 86400)
|
||||
)
|
||||
)
|
||||
|
||||
async def get_provider_token_stats(self):
|
||||
try:
|
||||
try:
|
||||
days = int(request.args.get("days", 1))
|
||||
except (TypeError, ValueError):
|
||||
days = 1
|
||||
if days not in (1, 3, 7):
|
||||
days = 1
|
||||
|
||||
local_tz = datetime.now().astimezone().tzinfo or timezone.utc
|
||||
now_local = datetime.now(local_tz)
|
||||
range_start_local = (now_local - timedelta(days=days)).replace(
|
||||
minute=0, second=0, microsecond=0
|
||||
return await self._run(
|
||||
lambda: self.service.get_provider_token_stats_from_legacy_query(
|
||||
request.args.get("days", 1)
|
||||
)
|
||||
today_start_local = now_local.replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
query_start_local = min(range_start_local, today_start_local)
|
||||
query_start_utc = query_start_local.astimezone(timezone.utc)
|
||||
|
||||
async with self.db_helper.get_db() as session:
|
||||
result = await session.execute(
|
||||
select(ProviderStat)
|
||||
.where(
|
||||
ProviderStat.agent_type == "internal",
|
||||
ProviderStat.created_at >= query_start_utc,
|
||||
)
|
||||
.order_by(col(ProviderStat.created_at).asc())
|
||||
)
|
||||
records = result.scalars().all()
|
||||
|
||||
bucket_timestamps: list[int] = []
|
||||
bucket_cursor = range_start_local
|
||||
while bucket_cursor <= now_local:
|
||||
bucket_timestamps.append(int(bucket_cursor.timestamp() * 1000))
|
||||
bucket_cursor += timedelta(hours=1)
|
||||
|
||||
trend_by_provider: dict[str, dict[int, int]] = defaultdict(
|
||||
lambda: defaultdict(int)
|
||||
)
|
||||
total_by_provider: dict[str, int] = defaultdict(int)
|
||||
total_by_umo: dict[str, int] = defaultdict(int)
|
||||
total_by_bucket: dict[int, int] = defaultdict(int)
|
||||
range_total_tokens = 0
|
||||
range_total_output_tokens = 0
|
||||
range_total_calls = 0
|
||||
range_success_calls = 0
|
||||
range_ttft_total_ms = 0.0
|
||||
range_ttft_samples = 0
|
||||
range_duration_total_ms = 0.0
|
||||
range_duration_samples = 0
|
||||
today_by_model: dict[str, int] = defaultdict(int)
|
||||
today_by_provider: dict[str, int] = defaultdict(int)
|
||||
today_total_tokens = 0
|
||||
today_total_calls = 0
|
||||
|
||||
for record in records:
|
||||
created_at_utc = self._ensure_aware_utc(record.created_at)
|
||||
created_at_local = created_at_utc.astimezone(local_tz)
|
||||
token_total = (
|
||||
record.token_input_other
|
||||
+ record.token_input_cached
|
||||
+ record.token_output
|
||||
)
|
||||
provider_id = record.provider_id or "unknown"
|
||||
provider_model = record.provider_model or "Unknown"
|
||||
|
||||
if created_at_local >= range_start_local:
|
||||
bucket_local = created_at_local.replace(
|
||||
minute=0, second=0, microsecond=0
|
||||
)
|
||||
bucket_ts = int(bucket_local.timestamp() * 1000)
|
||||
trend_by_provider[provider_id][bucket_ts] += token_total
|
||||
total_by_provider[provider_id] += token_total
|
||||
total_by_umo[record.umo or "unknown"] += token_total
|
||||
total_by_bucket[bucket_ts] += token_total
|
||||
range_total_tokens += token_total
|
||||
range_total_calls += 1
|
||||
if record.status != "error":
|
||||
range_success_calls += 1
|
||||
if record.time_to_first_token > 0:
|
||||
range_ttft_total_ms += record.time_to_first_token * 1000
|
||||
range_ttft_samples += 1
|
||||
if record.end_time > record.start_time:
|
||||
range_duration_total_ms += (
|
||||
record.end_time - record.start_time
|
||||
) * 1000
|
||||
range_duration_samples += 1
|
||||
range_total_output_tokens += record.token_output
|
||||
|
||||
if created_at_local >= today_start_local:
|
||||
today_total_calls += 1
|
||||
today_total_tokens += token_total
|
||||
today_by_model[provider_model] += token_total
|
||||
today_by_provider[provider_id] += token_total
|
||||
|
||||
sorted_provider_ids = sorted(
|
||||
total_by_provider.keys(),
|
||||
key=lambda item: total_by_provider[item],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
series = [
|
||||
{
|
||||
"name": provider_id,
|
||||
"data": [
|
||||
[bucket_ts, trend_by_provider[provider_id].get(bucket_ts, 0)]
|
||||
for bucket_ts in bucket_timestamps
|
||||
],
|
||||
"total_tokens": total_by_provider[provider_id],
|
||||
}
|
||||
for provider_id in sorted_provider_ids
|
||||
]
|
||||
|
||||
total_series = [
|
||||
[bucket_ts, total_by_bucket.get(bucket_ts, 0)]
|
||||
for bucket_ts in bucket_timestamps
|
||||
]
|
||||
|
||||
today_by_model_data = [
|
||||
{"provider_model": model_name, "tokens": tokens}
|
||||
for model_name, tokens in sorted(
|
||||
today_by_model.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
today_by_provider_data = [
|
||||
{"provider_id": provider_id, "tokens": tokens}
|
||||
for provider_id, tokens in sorted(
|
||||
today_by_provider.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
range_by_provider_data = [
|
||||
{"provider_id": provider_id, "tokens": tokens}
|
||||
for provider_id, tokens in sorted(
|
||||
total_by_provider.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
range_by_umo_data = [
|
||||
{"umo": umo, "tokens": tokens}
|
||||
for umo, tokens in sorted(
|
||||
total_by_umo.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"days": days,
|
||||
"trend": {
|
||||
"series": series,
|
||||
"total_series": total_series,
|
||||
},
|
||||
"range_total_tokens": range_total_tokens,
|
||||
"range_total_calls": range_total_calls,
|
||||
"range_avg_ttft_ms": (
|
||||
range_ttft_total_ms / range_ttft_samples
|
||||
if range_ttft_samples
|
||||
else 0
|
||||
),
|
||||
"range_avg_duration_ms": (
|
||||
range_duration_total_ms / range_duration_samples
|
||||
if range_duration_samples
|
||||
else 0
|
||||
),
|
||||
"range_avg_tpm": (
|
||||
range_total_output_tokens
|
||||
/ (range_duration_total_ms / 1000 / 60)
|
||||
if range_duration_total_ms > 0
|
||||
else 0
|
||||
),
|
||||
"range_success_rate": (
|
||||
range_success_calls / range_total_calls
|
||||
if range_total_calls
|
||||
else 0
|
||||
),
|
||||
"range_by_provider": range_by_provider_data,
|
||||
"range_by_umo": range_by_umo_data,
|
||||
"today_total_tokens": today_total_tokens,
|
||||
"today_total_calls": today_total_calls,
|
||||
"today_by_model": today_by_model_data,
|
||||
"today_by_provider": today_by_provider_data,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {e!s}").__dict__
|
||||
)
|
||||
|
||||
async def test_ghproxy_connection(self):
|
||||
"""测试 GitHub 代理连接是否可用。"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
proxy_url: str = data.get("proxy_url")
|
||||
|
||||
if not proxy_url:
|
||||
return Response().error("proxy_url is required").__dict__
|
||||
|
||||
proxy_url = proxy_url.rstrip("/")
|
||||
|
||||
test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
|
||||
start_time = time.time()
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
test_url,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as response,
|
||||
):
|
||||
if response.status == 200:
|
||||
end_time = time.time()
|
||||
_ = await response.text()
|
||||
ret = {
|
||||
"latency": round((end_time - start_time) * 1000, 2),
|
||||
}
|
||||
return Response().ok(data=ret).__dict__
|
||||
return (
|
||||
Response().error(f"Failed. Status code: {response.status}").__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {e!s}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.test_ghproxy_connection_from_legacy_payload
|
||||
)
|
||||
|
||||
async def get_changelog(self):
|
||||
"""获取指定版本的更新日志"""
|
||||
try:
|
||||
version = request.args.get("version")
|
||||
if not version:
|
||||
return Response().error("version parameter is required").__dict__
|
||||
|
||||
version = version.lstrip("v")
|
||||
|
||||
# 防止路径遍历攻击
|
||||
if not re.match(r"^[a-zA-Z0-9._-]+$", version):
|
||||
return Response().error("Invalid version format").__dict__
|
||||
if ".." in version or "/" in version or "\\" in version:
|
||||
return Response().error("Invalid version format").__dict__
|
||||
|
||||
filename = f"v{version}.md"
|
||||
project_path = get_astrbot_path()
|
||||
changelogs_dir = os.path.join(project_path, "changelogs")
|
||||
changelog_path = os.path.join(changelogs_dir, filename)
|
||||
|
||||
# 规范化路径,防止符号链接攻击
|
||||
changelog_path = os.path.realpath(changelog_path)
|
||||
changelogs_dir = os.path.realpath(changelogs_dir)
|
||||
|
||||
# 验证最终路径在预期的 changelogs 目录内(防止路径遍历)
|
||||
# 确保规范化后的路径以 changelogs_dir 开头,且是目录内的文件
|
||||
changelog_path_normalized = os.path.normpath(changelog_path)
|
||||
changelogs_dir_normalized = os.path.normpath(changelogs_dir)
|
||||
|
||||
# 检查路径是否在预期目录内(必须是目录的子文件,不能是目录本身)
|
||||
expected_prefix = changelogs_dir_normalized + os.sep
|
||||
if not changelog_path_normalized.startswith(expected_prefix):
|
||||
logger.warning(
|
||||
f"Path traversal attempt detected: {version} -> {changelog_path}",
|
||||
)
|
||||
return Response().error("Invalid version format").__dict__
|
||||
|
||||
if not os.path.exists(changelog_path):
|
||||
return (
|
||||
Response()
|
||||
.error(f"Changelog for version {version} not found")
|
||||
.__dict__
|
||||
)
|
||||
if not os.path.isfile(changelog_path):
|
||||
return (
|
||||
Response()
|
||||
.error(f"Changelog for version {version} not found")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
with open(changelog_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
return Response().ok({"content": content, "version": version}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {e!s}").__dict__
|
||||
return await self._run_sync(
|
||||
lambda: self.service.get_changelog(request.args.get("version"))
|
||||
)
|
||||
|
||||
async def list_changelog_versions(self):
|
||||
"""获取所有可用的更新日志版本列表"""
|
||||
try:
|
||||
project_path = get_astrbot_path()
|
||||
changelogs_dir = os.path.join(project_path, "changelogs")
|
||||
|
||||
if not os.path.exists(changelogs_dir):
|
||||
return Response().ok({"versions": []}).__dict__
|
||||
|
||||
versions = []
|
||||
for filename in os.listdir(changelogs_dir):
|
||||
if filename.endswith(".md") and filename.startswith("v"):
|
||||
# 提取版本号(去除 v 前缀和 .md 后缀)
|
||||
version = filename[1:-3] # 去掉 "v" 和 ".md"
|
||||
# 验证版本号格式
|
||||
if re.match(r"^[a-zA-Z0-9._-]+$", version):
|
||||
versions.append(version)
|
||||
|
||||
# 按版本号排序(降序,最新的在前)
|
||||
# 使用项目中的 VersionComparator 进行语义化版本号排序
|
||||
versions.sort(
|
||||
key=cmp_to_key(
|
||||
lambda v1, v2: VersionComparator.compare_version(v2, v1),
|
||||
),
|
||||
)
|
||||
|
||||
return Response().ok({"versions": versions}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {e!s}").__dict__
|
||||
return await self._run_sync(self.service.list_changelog_versions)
|
||||
|
||||
async def get_first_notice(self):
|
||||
"""读取项目根目录 FIRST_NOTICE.md 内容。"""
|
||||
try:
|
||||
locale = (request.args.get("locale") or "").strip()
|
||||
if not re.match(r"^[A-Za-z0-9_-]*$", locale):
|
||||
locale = ""
|
||||
|
||||
base_path = Path(get_astrbot_path())
|
||||
candidates: list[Path] = []
|
||||
|
||||
if locale:
|
||||
candidates.append(base_path / f"FIRST_NOTICE.{locale}.md")
|
||||
if locale.lower().startswith("zh"):
|
||||
candidates.append(base_path / "FIRST_NOTICE.md")
|
||||
candidates.append(base_path / "FIRST_NOTICE.zh-CN.md")
|
||||
elif locale.lower().startswith("en"):
|
||||
candidates.append(base_path / "FIRST_NOTICE.en-US.md")
|
||||
|
||||
candidates.extend(
|
||||
[
|
||||
base_path / "FIRST_NOTICE.md",
|
||||
base_path / "FIRST_NOTICE.en-US.md",
|
||||
],
|
||||
)
|
||||
|
||||
for notice_path in candidates:
|
||||
if not notice_path.is_file():
|
||||
continue
|
||||
content = notice_path.read_text(encoding="utf-8")
|
||||
if content.strip():
|
||||
return Response().ok({"content": content}).__dict__
|
||||
|
||||
return Response().ok({"content": None}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {e!s}").__dict__
|
||||
return await self._run_sync(
|
||||
lambda: self.service.get_first_notice(request.args.get("locale"))
|
||||
)
|
||||
|
||||
@@ -1,37 +1,19 @@
|
||||
from astrbot.dashboard.services.static_file_service import StaticFileService
|
||||
|
||||
from .route import Route, RouteContext
|
||||
|
||||
|
||||
class StaticFileRoute(Route):
|
||||
def __init__(self, context: RouteContext) -> None:
|
||||
super().__init__(context)
|
||||
self.service = StaticFileService()
|
||||
|
||||
index_ = [
|
||||
"/",
|
||||
"/auth/login",
|
||||
"/config",
|
||||
"/logs",
|
||||
"/extension",
|
||||
"/dashboard/default",
|
||||
"/alkaid",
|
||||
"/alkaid/knowledge-base",
|
||||
"/alkaid/long-term-memory",
|
||||
"/alkaid/other",
|
||||
"/console",
|
||||
"/chat",
|
||||
"/settings",
|
||||
"/platforms",
|
||||
"/providers",
|
||||
"/about",
|
||||
"/extension-marketplace",
|
||||
"/conversation",
|
||||
"/tool-use",
|
||||
]
|
||||
for i in index_:
|
||||
for i in self.service.list_index_routes():
|
||||
self.app.add_url_rule(i, view_func=self.index)
|
||||
|
||||
@self.app.errorhandler(404)
|
||||
async def page_not_found(e) -> str:
|
||||
return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://docs.astrbot.app/faq.html。如果你正在测试回调地址可达性,显示这段文字说明测试成功了。"
|
||||
return self.service.get_not_found_message()
|
||||
|
||||
async def index(self):
|
||||
return await self.app.send_static_file("index.html")
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import traceback
|
||||
|
||||
from quart import jsonify, request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.dashboard.fastapi_compat import jsonify, request
|
||||
from astrbot.dashboard.services.subagent_service import (
|
||||
SubAgentService,
|
||||
SubAgentServiceError,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -16,7 +15,7 @@ class SubAgentRoute(Route):
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.service = SubAgentService(core_lifecycle)
|
||||
# NOTE: dict cannot hold duplicate keys; use list form to register multiple
|
||||
# methods for the same path.
|
||||
self.routes = [
|
||||
@@ -26,92 +25,35 @@ class SubAgentRoute(Route):
|
||||
]
|
||||
self.register_routes()
|
||||
|
||||
async def get_config(self):
|
||||
@staticmethod
|
||||
def _response(data=None, message: str | None = None):
|
||||
return jsonify(Response().ok(data=data, message=message).__dict__)
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str):
|
||||
return jsonify(Response().error(message).__dict__)
|
||||
|
||||
async def _run(self, operation, *, message: str | None = None):
|
||||
try:
|
||||
cfg = self.core_lifecycle.astrbot_config
|
||||
data = cfg.get("subagent_orchestrator")
|
||||
return self._response(await operation(), message)
|
||||
except SubAgentServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
# First-time access: return a sane default instead of erroring.
|
||||
if not isinstance(data, dict):
|
||||
data = {
|
||||
"main_enable": False,
|
||||
"remove_main_duplicate_tools": False,
|
||||
"agents": [],
|
||||
}
|
||||
async def _run_sync(self, operation):
|
||||
try:
|
||||
return self._response(operation())
|
||||
except SubAgentServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
# Backward compatibility: older config used `enable`.
|
||||
if (
|
||||
isinstance(data, dict)
|
||||
and "main_enable" not in data
|
||||
and "enable" in data
|
||||
):
|
||||
data["main_enable"] = bool(data.get("enable", False))
|
||||
async def _run_json(self, operation, *, message: str | None = None):
|
||||
data = await request.json
|
||||
return await self._run(lambda: operation(data), message=message)
|
||||
|
||||
# Ensure required keys exist.
|
||||
data.setdefault("main_enable", False)
|
||||
data.setdefault("remove_main_duplicate_tools", False)
|
||||
data.setdefault("agents", [])
|
||||
|
||||
# Backward/forward compatibility: ensure each agent contains provider_id.
|
||||
# None means follow global/default provider settings.
|
||||
if isinstance(data.get("agents"), list):
|
||||
for a in data["agents"]:
|
||||
if isinstance(a, dict):
|
||||
a.setdefault("provider_id", None)
|
||||
a.setdefault("persona_id", None)
|
||||
return jsonify(Response().ok(data=data).__dict__)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify(Response().error(f"获取 subagent 配置失败: {e!s}").__dict__)
|
||||
async def get_config(self):
|
||||
return await self._run_sync(self.service.get_config)
|
||||
|
||||
async def update_config(self):
|
||||
try:
|
||||
data = await request.json
|
||||
if not isinstance(data, dict):
|
||||
return jsonify(Response().error("配置必须为 JSON 对象").__dict__)
|
||||
|
||||
cfg = self.core_lifecycle.astrbot_config
|
||||
cfg["subagent_orchestrator"] = data
|
||||
|
||||
# Persist to cmd_config.json
|
||||
# AstrBotConfigManager does not expose a `save()` method; persist via AstrBotConfig.
|
||||
cfg.save_config()
|
||||
|
||||
# Reload dynamic handoff tools if orchestrator exists
|
||||
orch = getattr(self.core_lifecycle, "subagent_orchestrator", None)
|
||||
if orch is not None:
|
||||
await orch.reload_from_config(data)
|
||||
|
||||
return jsonify(Response().ok(message="保存成功").__dict__)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify(Response().error(f"保存 subagent 配置失败: {e!s}").__dict__)
|
||||
return await self._run_json(self.service.update_config, message="保存成功")
|
||||
|
||||
async def get_available_tools(self):
|
||||
"""Return all registered tools (name/description/parameters/active/origin).
|
||||
|
||||
UI can use this to build a multi-select list for subagent tool assignment.
|
||||
"""
|
||||
try:
|
||||
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
tools_dict = []
|
||||
for tool in tool_mgr.func_list:
|
||||
# Prevent recursive routing: subagents should not be able to select
|
||||
# the handoff (transfer_to_*) tools as their own mounted tools.
|
||||
if isinstance(tool, HandoffTool):
|
||||
continue
|
||||
if tool.handler_module_path == "core.subagent_orchestrator":
|
||||
continue
|
||||
tools_dict.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
"active": tool.active,
|
||||
"handler_module_path": tool.handler_module_path,
|
||||
}
|
||||
)
|
||||
return jsonify(Response().ok(data=tools_dict).__dict__)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify(Response().error(f"获取可用工具失败: {e!s}").__dict__)
|
||||
return await self._run_sync(self.service.get_available_tools)
|
||||
|
||||
@@ -2,11 +2,9 @@
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
from quart import jsonify, request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.utils.t2i.template_manager import TemplateManager
|
||||
from astrbot.dashboard.fastapi_compat import jsonify, request
|
||||
from astrbot.dashboard.services.t2i_service import T2iService, T2iServiceError
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -16,9 +14,7 @@ class T2iRoute(Route):
|
||||
self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.manager = TemplateManager()
|
||||
self.service = T2iService(core_lifecycle)
|
||||
# 使用列表保证路由注册顺序,避免 /<name> 路由优先匹配 /reset_default
|
||||
self.routes = [
|
||||
("/t2i/templates", ("GET", self.list_templates)),
|
||||
@@ -28,7 +24,7 @@ class T2iRoute(Route):
|
||||
("/t2i/templates/set_active", ("POST", self.set_active_template)),
|
||||
# 动态路由应该在静态路由之后注册
|
||||
(
|
||||
"/t2i/templates/<name>",
|
||||
"/t2i/templates/<path:name>",
|
||||
[
|
||||
("GET", self.get_template),
|
||||
("PUT", self.update_template),
|
||||
@@ -38,200 +34,94 @@ class T2iRoute(Route):
|
||||
]
|
||||
self.register_routes()
|
||||
|
||||
async def _reload_all_pipeline_schedulers(self) -> None:
|
||||
"""热重载所有配置对应的 pipeline scheduler。"""
|
||||
for conf_id in self.core_lifecycle.astrbot_config_mgr.confs:
|
||||
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
|
||||
@staticmethod
|
||||
def _ok(data=None, message: str | None = None, status_code: int = 200):
|
||||
response = jsonify(asdict(Response().ok(data=data, message=message)))
|
||||
response.status_code = status_code
|
||||
return response
|
||||
|
||||
async def _sync_active_template_to_all_configs(self, name: str) -> None:
|
||||
"""同步当前激活模板到所有配置文件,并热重载对应流水线。"""
|
||||
for config in self.core_lifecycle.astrbot_config_mgr.confs.values():
|
||||
config["t2i_active_template"] = name
|
||||
config.save_config()
|
||||
await self._reload_all_pipeline_schedulers()
|
||||
@staticmethod
|
||||
def _service_error(exc: T2iServiceError):
|
||||
response = jsonify(asdict(Response().error(str(exc))))
|
||||
response.status_code = exc.status_code
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
async def _request_data() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
operation,
|
||||
*,
|
||||
message: str | None = None,
|
||||
status_code: int = 200,
|
||||
result_as_message: bool = False,
|
||||
):
|
||||
try:
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
if isinstance(result, tuple):
|
||||
payload, result_message = result
|
||||
return self._ok(data=payload, message=result_message)
|
||||
if result_as_message:
|
||||
return self._ok(message=str(result), status_code=status_code)
|
||||
return self._ok(data=result, message=message, status_code=status_code)
|
||||
except T2iServiceError as exc:
|
||||
return self._service_error(exc)
|
||||
|
||||
async def _run_json(self, operation, **kwargs):
|
||||
async def invoke():
|
||||
data = await self._request_data()
|
||||
return operation(data)
|
||||
|
||||
return await self._run(invoke, **kwargs)
|
||||
|
||||
async def list_templates(self):
|
||||
"""获取所有T2I模板列表"""
|
||||
try:
|
||||
templates = self.manager.list_templates()
|
||||
return jsonify(asdict(Response().ok(data=templates)))
|
||||
except Exception as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 500
|
||||
return response
|
||||
return await self._run(self.service.list_templates)
|
||||
|
||||
async def get_active_template(self):
|
||||
"""获取当前激活的T2I模板"""
|
||||
try:
|
||||
active_template = self.config.get("t2i_active_template", "base")
|
||||
return jsonify(
|
||||
asdict(Response().ok(data={"active_template": active_template})),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error in get_active_template", exc_info=True)
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 500
|
||||
return response
|
||||
return await self._run(self.service.get_active_template)
|
||||
|
||||
async def get_template(self, name: str):
|
||||
"""获取指定名称的T2I模板内容"""
|
||||
try:
|
||||
content = self.manager.get_template(name)
|
||||
return jsonify(
|
||||
asdict(Response().ok(data={"name": name, "content": content})),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
response = jsonify(asdict(Response().error("Template not found")))
|
||||
response.status_code = 404
|
||||
return response
|
||||
except Exception as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 500
|
||||
return response
|
||||
return await self._run(lambda: self.service.get_template(name))
|
||||
|
||||
async def create_template(self):
|
||||
"""创建一个新的T2I模板"""
|
||||
try:
|
||||
data = await request.json
|
||||
name = data.get("name")
|
||||
content = data.get("content")
|
||||
if not name or not content:
|
||||
response = jsonify(
|
||||
asdict(Response().error("Name and content are required.")),
|
||||
)
|
||||
response.status_code = 400
|
||||
return response
|
||||
name = name.strip()
|
||||
|
||||
self.manager.create_template(name, content)
|
||||
response = jsonify(
|
||||
asdict(
|
||||
Response().ok(
|
||||
data={"name": name},
|
||||
message="Template created successfully.",
|
||||
),
|
||||
),
|
||||
)
|
||||
response.status_code = 201
|
||||
return response
|
||||
except FileExistsError:
|
||||
response = jsonify(
|
||||
asdict(Response().error("Template with this name already exists.")),
|
||||
)
|
||||
response.status_code = 409
|
||||
return response
|
||||
except ValueError as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 400
|
||||
return response
|
||||
except Exception as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 500
|
||||
return response
|
||||
return await self._run_json(
|
||||
self.service.create_template_from_legacy_payload,
|
||||
message="Template created successfully.",
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
async def update_template(self, name: str):
|
||||
"""更新一个已存在的T2I模板"""
|
||||
try:
|
||||
name = name.strip()
|
||||
data = await request.json
|
||||
content = data.get("content")
|
||||
if content is None:
|
||||
response = jsonify(asdict(Response().error("Content is required.")))
|
||||
response.status_code = 400
|
||||
return response
|
||||
|
||||
self.manager.update_template(name, content)
|
||||
|
||||
# 检查更新的是否为当前激活的模板,如果是,则热重载
|
||||
active_template = self.config.get("t2i_active_template", "base")
|
||||
if name == active_template:
|
||||
await self._reload_all_pipeline_schedulers()
|
||||
message = f"模板 '{name}' 已更新并重新加载。"
|
||||
else:
|
||||
message = f"模板 '{name}' 已更新。"
|
||||
|
||||
return jsonify(asdict(Response().ok(data={"name": name}, message=message)))
|
||||
except ValueError as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 400
|
||||
return response
|
||||
except Exception as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 500
|
||||
return response
|
||||
return await self._run_json(
|
||||
lambda data: self.service.update_template_from_legacy_payload(name, data)
|
||||
)
|
||||
|
||||
async def delete_template(self, name: str):
|
||||
"""删除一个T2I模板"""
|
||||
try:
|
||||
name = name.strip()
|
||||
self.manager.delete_template(name)
|
||||
return jsonify(
|
||||
asdict(Response().ok(message="Template deleted successfully.")),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
response = jsonify(asdict(Response().error("Template not found.")))
|
||||
response.status_code = 404
|
||||
return response
|
||||
except ValueError as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 400
|
||||
return response
|
||||
except Exception as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 500
|
||||
return response
|
||||
return await self._run(
|
||||
lambda: self.service.delete_template(name),
|
||||
message="Template deleted successfully.",
|
||||
)
|
||||
|
||||
async def set_active_template(self):
|
||||
"""设置当前活动的T2I模板"""
|
||||
try:
|
||||
data = await request.json
|
||||
name = data.get("name")
|
||||
if not name:
|
||||
response = jsonify(asdict(Response().error("模板名称(name)不能为空。")))
|
||||
response.status_code = 400
|
||||
return response
|
||||
|
||||
# 验证模板文件是否存在
|
||||
self.manager.get_template(name)
|
||||
|
||||
# 更新所有配置并热重载以应用更改
|
||||
await self._sync_active_template_to_all_configs(name)
|
||||
|
||||
return jsonify(asdict(Response().ok(message=f"模板 '{name}' 已成功应用。")))
|
||||
|
||||
except FileNotFoundError:
|
||||
response = jsonify(
|
||||
asdict(Response().error(f"模板 '{name}' 不存在,无法应用。")),
|
||||
)
|
||||
response.status_code = 404
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error("Error in set_active_template", exc_info=True)
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 500
|
||||
return response
|
||||
return await self._run_json(
|
||||
self.service.set_active_template_from_legacy_payload,
|
||||
result_as_message=True,
|
||||
)
|
||||
|
||||
async def reset_default_template(self):
|
||||
"""重置默认的'base'模板"""
|
||||
try:
|
||||
self.manager.reset_default_template()
|
||||
|
||||
# 更新所有配置,将激活模板也重置为'base'
|
||||
await self._sync_active_template_to_all_configs("base")
|
||||
|
||||
return jsonify(
|
||||
asdict(
|
||||
Response().ok(
|
||||
message="Default template has been reset and activated.",
|
||||
),
|
||||
),
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 404
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error("Error in reset_default_template", exc_info=True)
|
||||
response = jsonify(asdict(Response().error(str(e))))
|
||||
response.status_code = 500
|
||||
return response
|
||||
return await self._run(
|
||||
self.service.reset_default_template(),
|
||||
result_as_message=True,
|
||||
)
|
||||
|
||||
@@ -1,43 +1,9 @@
|
||||
import traceback
|
||||
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.mcp_client import MCPTool, validate_mcp_stdio_config
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.star import star_map
|
||||
from astrbot.core.tools.registry import get_builtin_tool_config_statuses
|
||||
from astrbot.dashboard.fastapi_compat import request
|
||||
from astrbot.dashboard.services.tools_service import ToolsService, ToolsServiceError
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
|
||||
class EmptyMcpServersError(ValueError):
|
||||
"""Raised when mcpServers is empty."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _extract_mcp_server_config(mcp_servers_value: object) -> dict:
|
||||
"""Extract server configuration from user-submitted mcpServers field.
|
||||
|
||||
Raises:
|
||||
ValueError: Invalid configuration
|
||||
"""
|
||||
if not isinstance(mcp_servers_value, dict):
|
||||
raise ValueError("mcpServers must be a JSON object")
|
||||
if not mcp_servers_value:
|
||||
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
|
||||
key_0 = next(iter(mcp_servers_value))
|
||||
extracted = mcp_servers_value[key_0]
|
||||
if not isinstance(extracted, dict):
|
||||
raise ValueError(
|
||||
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
|
||||
"and each value is an object containing fields like command/url."
|
||||
)
|
||||
return extracted
|
||||
|
||||
|
||||
class ToolsRoute(Route):
|
||||
def __init__(
|
||||
@@ -46,7 +12,7 @@ class ToolsRoute(Route):
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.service = ToolsService(core_lifecycle)
|
||||
self.routes = {
|
||||
"/tools/mcp/servers": ("GET", self.get_mcp_servers),
|
||||
"/tools/mcp/add": ("POST", self.add_mcp_server),
|
||||
@@ -58,514 +24,80 @@ class ToolsRoute(Route):
|
||||
"/tools/mcp/sync-provider": ("POST", self.sync_provider),
|
||||
}
|
||||
self.register_routes()
|
||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
|
||||
def _rollback_mcp_server(self, name: str) -> bool:
|
||||
@staticmethod
|
||||
def _ok(data: dict | list | None = None, message: str | None = None) -> dict:
|
||||
return Response().ok(data, message).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _error(message: str) -> dict:
|
||||
return Response().error(message).__dict__
|
||||
|
||||
@staticmethod
|
||||
async def _json_body() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(self, operation, *, message: str | None = None) -> dict:
|
||||
try:
|
||||
rollback_config = self.tool_mgr.load_mcp_config()
|
||||
if name in rollback_config["mcpServers"]:
|
||||
rollback_config["mcpServers"].pop(name)
|
||||
return self.tool_mgr.save_mcp_config(rollback_config)
|
||||
return True
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
return self._ok(result, message)
|
||||
except ToolsServiceError as exc:
|
||||
return self._error(str(exc))
|
||||
|
||||
async def _run_json(
|
||||
self,
|
||||
operation,
|
||||
*,
|
||||
message: str | None = None,
|
||||
result_as_message: bool = False,
|
||||
) -> dict:
|
||||
async def invoke():
|
||||
data = await self._json_body()
|
||||
return operation(data)
|
||||
|
||||
result = await self._run(invoke)
|
||||
if result_as_message and result.get("status") == "ok":
|
||||
return self._ok(None, result["data"])
|
||||
if message and result.get("status") == "ok":
|
||||
return self._ok(result.get("data"), message)
|
||||
return result
|
||||
|
||||
async def get_mcp_servers(self):
|
||||
try:
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
servers = []
|
||||
mcp_servers = config.get("mcpServers", {})
|
||||
|
||||
if not isinstance(mcp_servers, dict):
|
||||
logger.warning(
|
||||
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
|
||||
)
|
||||
mcp_servers = {}
|
||||
|
||||
# 获取所有服务器并添加它们的工具列表
|
||||
for name, server_config in mcp_servers.items():
|
||||
if not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
|
||||
)
|
||||
continue
|
||||
|
||||
server_info = {
|
||||
"name": name,
|
||||
"active": server_config.get("active", True),
|
||||
}
|
||||
|
||||
# 复制所有配置字段
|
||||
for key, value in server_config.items():
|
||||
if key != "active": # active 已经处理
|
||||
server_info[key] = value
|
||||
|
||||
# 如果MCP客户端已初始化,从客户端获取工具名称
|
||||
for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items():
|
||||
if name_key == name:
|
||||
mcp_client = runtime.client
|
||||
server_info["tools"] = [tool.name for tool in mcp_client.tools]
|
||||
server_info["errlogs"] = mcp_client.server_errlogs
|
||||
break
|
||||
else:
|
||||
server_info["tools"] = []
|
||||
|
||||
servers.append(server_info)
|
||||
|
||||
return Response().ok(servers).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Failed to get MCP server list: {e!s}").__dict__
|
||||
return await self._run(self.service.get_mcp_servers)
|
||||
|
||||
async def add_mcp_server(self):
|
||||
try:
|
||||
server_data = await request.json
|
||||
|
||||
name = server_data.get("name", "")
|
||||
|
||||
# 检查必填字段
|
||||
if not name:
|
||||
return Response().error("Server name cannot be empty").__dict__
|
||||
|
||||
# 移除特殊字段并检查配置是否有效
|
||||
has_valid_config = False
|
||||
server_config = {"active": server_data.get("active", True)}
|
||||
|
||||
# 复制所有配置字段
|
||||
for key, value in server_data.items():
|
||||
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
|
||||
if key == "mcpServers":
|
||||
try:
|
||||
server_config = _extract_mcp_server_config(
|
||||
server_data["mcpServers"]
|
||||
)
|
||||
except ValueError as e:
|
||||
return Response().error(f"{e!s}").__dict__
|
||||
else:
|
||||
server_config[key] = value
|
||||
has_valid_config = True
|
||||
|
||||
if not has_valid_config:
|
||||
return (
|
||||
Response()
|
||||
.error("A valid server configuration is required")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
validate_mcp_stdio_config(server_config)
|
||||
except ValueError as e:
|
||||
return Response().error(f"{e!s}").__dict__
|
||||
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
|
||||
if name in config["mcpServers"]:
|
||||
return Response().error(f"Server {name} already exists").__dict__
|
||||
|
||||
try:
|
||||
await self.tool_mgr.test_mcp_server_connection(server_config)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"MCP connection test failed: {e!s}").__dict__
|
||||
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name,
|
||||
server_config,
|
||||
timeout=30,
|
||||
)
|
||||
except TimeoutError:
|
||||
rollback_ok = self._rollback_mcp_server(name)
|
||||
err_msg = f"Timed out while enabling MCP server {name}."
|
||||
if not rollback_ok:
|
||||
err_msg += " Configuration rollback failed. Please check the config manually."
|
||||
return Response().error(err_msg).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
rollback_ok = self._rollback_mcp_server(name)
|
||||
err_msg = f"Failed to enable MCP server {name}: {e!s}"
|
||||
if not rollback_ok:
|
||||
err_msg += " Configuration rollback failed. Please check the config manually."
|
||||
return Response().error(err_msg).__dict__
|
||||
return (
|
||||
Response()
|
||||
.ok(None, f"Successfully added MCP server {name}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().error("Failed to save configuration").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Failed to add MCP server: {e!s}").__dict__
|
||||
return await self._run_json(self.service.add_mcp_server, result_as_message=True)
|
||||
|
||||
async def update_mcp_server(self):
|
||||
try:
|
||||
server_data = await request.json
|
||||
|
||||
name = server_data.get("name", "")
|
||||
old_name = server_data.get("oldName") or name
|
||||
|
||||
if not name:
|
||||
return Response().error("Server name cannot be empty").__dict__
|
||||
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
|
||||
if old_name not in config["mcpServers"]:
|
||||
return Response().error(f"Server {old_name} does not exist").__dict__
|
||||
|
||||
is_rename = name != old_name
|
||||
|
||||
if name in config["mcpServers"] and is_rename:
|
||||
return Response().error(f"Server {name} already exists").__dict__
|
||||
|
||||
# 获取活动状态
|
||||
old_config = config["mcpServers"][old_name]
|
||||
if isinstance(old_config, dict):
|
||||
old_active = old_config.get("active", True)
|
||||
else:
|
||||
old_active = True
|
||||
active = server_data.get("active", old_active)
|
||||
|
||||
# 创建新的配置对象
|
||||
server_config = {"active": active}
|
||||
|
||||
# 仅更新活动状态的特殊处理
|
||||
only_update_active = True
|
||||
|
||||
# 复制所有配置字段
|
||||
for key, value in server_data.items():
|
||||
if key not in [
|
||||
"name",
|
||||
"active",
|
||||
"tools",
|
||||
"errlogs",
|
||||
"oldName",
|
||||
]: # 排除特殊字段
|
||||
if key == "mcpServers":
|
||||
try:
|
||||
server_config = _extract_mcp_server_config(
|
||||
server_data["mcpServers"]
|
||||
)
|
||||
except ValueError as e:
|
||||
return Response().error(f"{e!s}").__dict__
|
||||
else:
|
||||
server_config[key] = value
|
||||
only_update_active = False
|
||||
|
||||
# 如果只更新活动状态,保留原始配置
|
||||
if only_update_active and isinstance(old_config, dict):
|
||||
for key, value in old_config.items():
|
||||
if key != "active": # 除了active之外的所有字段都保留
|
||||
server_config[key] = value
|
||||
|
||||
try:
|
||||
validate_mcp_stdio_config(server_config)
|
||||
except ValueError as e:
|
||||
return Response().error(f"{e!s}").__dict__
|
||||
|
||||
# config["mcpServers"][name] = server_config
|
||||
if is_rename:
|
||||
config["mcpServers"].pop(old_name)
|
||||
config["mcpServers"][name] = server_config
|
||||
else:
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
# 处理MCP客户端状态变化
|
||||
if active:
|
||||
if (
|
||||
old_name in self.tool_mgr.mcp_server_runtime_view
|
||||
or not only_update_active
|
||||
or is_rename
|
||||
):
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
|
||||
except TimeoutError as e:
|
||||
return (
|
||||
Response()
|
||||
.error(
|
||||
f"Timed out while disabling MCP server {old_name} before enabling: {e!s}"
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(
|
||||
f"Failed to disable MCP server {old_name} before enabling: {e!s}"
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name,
|
||||
config["mcpServers"][name],
|
||||
timeout=30,
|
||||
)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response()
|
||||
.error(f"Timed out while enabling MCP server {name}.")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"Failed to enable MCP server {name}: {e!s}")
|
||||
.__dict__
|
||||
)
|
||||
# 如果要停用服务器
|
||||
elif old_name in self.tool_mgr.mcp_server_runtime_view:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response()
|
||||
.error(f"Timed out while disabling MCP server {old_name}.")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"Failed to disable MCP server {old_name}: {e!s}")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(None, f"Successfully updated MCP server {name}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().error("Failed to save configuration").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Failed to update MCP server: {e!s}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.update_mcp_server,
|
||||
result_as_message=True,
|
||||
)
|
||||
|
||||
async def delete_mcp_server(self):
|
||||
try:
|
||||
server_data = await request.json
|
||||
name = server_data.get("name", "")
|
||||
|
||||
if not name:
|
||||
return Response().error("Server name cannot be empty").__dict__
|
||||
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
|
||||
if name not in config["mcpServers"]:
|
||||
return Response().error(f"Server {name} does not exist").__dict__
|
||||
|
||||
del config["mcpServers"][name]
|
||||
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
if name in self.tool_mgr.mcp_server_runtime_view:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError:
|
||||
return (
|
||||
Response()
|
||||
.error(f"Timed out while disabling MCP server {name}.")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return (
|
||||
Response()
|
||||
.error(f"Failed to disable MCP server {name}: {e!s}")
|
||||
.__dict__
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.ok(None, f"Successfully deleted MCP server {name}")
|
||||
.__dict__
|
||||
)
|
||||
return Response().error("Failed to save configuration").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Failed to delete MCP server: {e!s}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.delete_mcp_server,
|
||||
result_as_message=True,
|
||||
)
|
||||
|
||||
async def test_mcp_connection(self):
|
||||
"""Test MCP server connection."""
|
||||
try:
|
||||
server_data = await request.json
|
||||
config = server_data.get("mcp_server_config", None)
|
||||
|
||||
if not isinstance(config, dict) or not config:
|
||||
return Response().error("Invalid MCP server configuration").__dict__
|
||||
|
||||
if "mcpServers" in config:
|
||||
mcp_servers = config["mcpServers"]
|
||||
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
|
||||
return (
|
||||
Response()
|
||||
.error(
|
||||
"Only one MCP server configuration can be tested at a time"
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
try:
|
||||
config = _extract_mcp_server_config(mcp_servers)
|
||||
except EmptyMcpServersError:
|
||||
return (
|
||||
Response()
|
||||
.error("MCP server configuration cannot be empty")
|
||||
.__dict__
|
||||
)
|
||||
except ValueError as e:
|
||||
return Response().error(f"{e!s}").__dict__
|
||||
elif not config:
|
||||
return (
|
||||
Response()
|
||||
.error("MCP server configuration cannot be empty")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
validate_mcp_stdio_config(config)
|
||||
except ValueError as e:
|
||||
return Response().error(f"{e!s}").__dict__
|
||||
|
||||
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
||||
return (
|
||||
Response()
|
||||
.ok(data=tools_name, message="🎉 MCP server is available!")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Failed to test MCP connection: {e!s}").__dict__
|
||||
return await self._run_json(
|
||||
self.service.test_mcp_connection,
|
||||
message="🎉 MCP server is available!",
|
||||
)
|
||||
|
||||
async def get_tool_list(self):
|
||||
"""Get all registered tools."""
|
||||
try:
|
||||
tools = list(self.tool_mgr.func_list)
|
||||
existing_names = {tool.name for tool in tools}
|
||||
for tool in self.tool_mgr.iter_builtin_tools():
|
||||
if tool.name not in existing_names:
|
||||
tools.append(tool)
|
||||
|
||||
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
|
||||
conf_name_map = {conf["id"]: conf["name"] for conf in conf_list}
|
||||
config_entries = []
|
||||
for conf_id, conf in self.core_lifecycle.astrbot_config_mgr.confs.items():
|
||||
config_entries.append(
|
||||
{
|
||||
"conf_id": conf_id,
|
||||
"conf_name": conf_name_map.get(conf_id, conf_id),
|
||||
"config": conf,
|
||||
}
|
||||
)
|
||||
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
readonly = False
|
||||
builtin_config_statuses = []
|
||||
builtin_config_tags = []
|
||||
if self.tool_mgr.is_builtin_tool(tool.name):
|
||||
origin = "builtin"
|
||||
origin_name = "AstrBot Core"
|
||||
readonly = True
|
||||
builtin_config_statuses = get_builtin_tool_config_statuses(
|
||||
tool.name,
|
||||
config_entries,
|
||||
)
|
||||
builtin_config_tags = [
|
||||
status
|
||||
for status in builtin_config_statuses
|
||||
if status["enabled"]
|
||||
]
|
||||
elif isinstance(tool, MCPTool):
|
||||
origin = "mcp"
|
||||
origin_name = tool.mcp_server_name
|
||||
elif tool.handler_module_path and star_map.get(
|
||||
tool.handler_module_path
|
||||
):
|
||||
star = star_map[tool.handler_module_path]
|
||||
origin = "plugin"
|
||||
origin_name = star.name
|
||||
else:
|
||||
origin = "unknown"
|
||||
origin_name = "unknown"
|
||||
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
"active": tool.active,
|
||||
"origin": origin,
|
||||
"origin_name": origin_name,
|
||||
"readonly": readonly,
|
||||
"builtin_config_statuses": builtin_config_statuses,
|
||||
"builtin_config_tags": builtin_config_tags,
|
||||
}
|
||||
tools_dict.append(tool_info)
|
||||
return Response().ok(data=tools_dict).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Failed to get tool list: {e!s}").__dict__
|
||||
return await self._run(self.service.get_tool_list)
|
||||
|
||||
async def toggle_tool(self):
|
||||
"""Activate or deactivate a specified tool."""
|
||||
try:
|
||||
data = await request.json
|
||||
tool_name = data.get("name")
|
||||
action = data.get("activate") # True or False
|
||||
|
||||
if not tool_name or action is None:
|
||||
return (
|
||||
Response()
|
||||
.error("Missing required parameters: name or activate")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
if self.tool_mgr.is_builtin_tool(tool_name):
|
||||
return (
|
||||
Response()
|
||||
.error("Builtin tools are read-only and cannot be toggled.")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
if action:
|
||||
try:
|
||||
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
|
||||
except ValueError as e:
|
||||
return Response().error(f"Failed to activate tool: {e!s}").__dict__
|
||||
else:
|
||||
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
|
||||
|
||||
if ok:
|
||||
return Response().ok(None, "Operation successful.").__dict__
|
||||
return (
|
||||
Response()
|
||||
.error(f"Tool {tool_name} does not exist or the operation failed.")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Failed to operate tool: {e!s}").__dict__
|
||||
return await self._run_json(self.service.toggle_tool, result_as_message=True)
|
||||
|
||||
async def sync_provider(self):
|
||||
"""Sync MCP provider configuration."""
|
||||
try:
|
||||
data = await request.json
|
||||
provider_name = data.get("name") # modelscope, or others
|
||||
match provider_name:
|
||||
case "modelscope":
|
||||
access_token = data.get("access_token", "")
|
||||
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
|
||||
case _:
|
||||
return (
|
||||
Response().error(f"Unknown provider: {provider_name}").__dict__
|
||||
)
|
||||
|
||||
return Response().ok(message="Sync completed").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Sync failed: {e!s}").__dict__
|
||||
return await self._run_json(self.service.sync_provider, result_as_message=True)
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
import traceback
|
||||
import uuid
|
||||
from __future__ import annotations
|
||||
|
||||
from quart import request
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.core import DEMO_MODE, logger, pip_installer
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db.migration.helper import check_migration_needed_v4, do_migration_v4
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.dashboard.fastapi_compat import request
|
||||
from astrbot.dashboard.services.update_service import (
|
||||
DEMO_MODE,
|
||||
UpdateService,
|
||||
UpdateServiceError,
|
||||
UpdateServiceResult,
|
||||
call_check_migration_needed_v4,
|
||||
call_do_migration_v4,
|
||||
call_download_dashboard,
|
||||
call_get_dashboard_version,
|
||||
call_pip_install,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
|
||||
CLEAR_SITE_DATA_HEADERS = {"Clear-Site-Data": '"cache"'}
|
||||
|
||||
|
||||
@@ -32,323 +41,82 @@ class UpdateRoute(Route):
|
||||
"/update/pip-install": ("POST", self.install_pip_package),
|
||||
"/update/migration": ("POST", self.do_migration),
|
||||
}
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.update_progress: dict[str, dict] = {}
|
||||
self.service = UpdateService(
|
||||
astrbot_updator,
|
||||
core_lifecycle,
|
||||
download_dashboard_func=call_download_dashboard,
|
||||
get_dashboard_version_func=call_get_dashboard_version,
|
||||
pip_install_func=call_pip_install,
|
||||
check_migration_needed_func=call_check_migration_needed_v4,
|
||||
do_migration_func=call_do_migration_v4,
|
||||
demo_mode=DEMO_MODE,
|
||||
clear_site_data_headers=CLEAR_SITE_DATA_HEADERS,
|
||||
)
|
||||
self.register_routes()
|
||||
|
||||
def _init_update_progress(self, progress_id: str, version: str) -> None:
|
||||
self.update_progress[progress_id] = {
|
||||
"id": progress_id,
|
||||
"status": "running",
|
||||
"stage": "preparing",
|
||||
"version": version or "latest",
|
||||
"message": "正在准备更新...",
|
||||
"overall_percent": 0,
|
||||
"stages": {
|
||||
"dashboard": self._empty_stage("pending"),
|
||||
"core": self._empty_stage("pending"),
|
||||
},
|
||||
}
|
||||
@staticmethod
|
||||
def _service_response(result: UpdateServiceResult):
|
||||
if result.status == "success":
|
||||
payload = Response(
|
||||
status="success",
|
||||
message=result.message,
|
||||
data=result.data,
|
||||
).__dict__
|
||||
else:
|
||||
payload = Response().ok(result.data, result.message).__dict__
|
||||
|
||||
if result.headers:
|
||||
return payload, 200, result.headers
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _empty_stage(status: str = "pending") -> dict:
|
||||
return {
|
||||
"status": status,
|
||||
"downloaded": 0,
|
||||
"total": 0,
|
||||
"percent": 0,
|
||||
"speed": 0,
|
||||
}
|
||||
|
||||
def _set_update_stage(
|
||||
self,
|
||||
progress_id: str,
|
||||
stage: str,
|
||||
status: str,
|
||||
message: str,
|
||||
overall_percent: int | None = None,
|
||||
) -> None:
|
||||
progress = self.update_progress.get(progress_id)
|
||||
if not progress:
|
||||
return
|
||||
progress["stage"] = stage
|
||||
progress["message"] = message
|
||||
progress["stages"].setdefault(stage, self._empty_stage())
|
||||
progress["stages"][stage]["status"] = status
|
||||
if overall_percent is not None:
|
||||
progress["overall_percent"] = overall_percent
|
||||
def _service_error(exc: UpdateServiceError):
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
@staticmethod
|
||||
def _normalize_percent(value) -> int:
|
||||
async def _json_body() -> dict:
|
||||
data = await request.get_json()
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
async def _run(self, operation):
|
||||
try:
|
||||
percent = float(value or 0)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
if percent <= 1:
|
||||
percent *= 100
|
||||
return max(0, min(100, int(percent)))
|
||||
result = operation() if callable(operation) else operation
|
||||
while hasattr(result, "__await__"):
|
||||
result = await result
|
||||
return self._service_response(result)
|
||||
except UpdateServiceError as exc:
|
||||
return self._service_error(exc)
|
||||
|
||||
def _make_progress_callback(
|
||||
self,
|
||||
progress_id: str,
|
||||
stage: str,
|
||||
stage_start: int,
|
||||
stage_weight: int,
|
||||
):
|
||||
def _callback(payload: dict) -> None:
|
||||
progress = self.update_progress.get(progress_id)
|
||||
if not progress:
|
||||
return
|
||||
stage_percent = self._normalize_percent(payload.get("percent"))
|
||||
progress["stage"] = stage
|
||||
progress["stages"][stage] = {
|
||||
"status": "running" if stage_percent < 100 else "done",
|
||||
"downloaded": payload.get("downloaded", 0),
|
||||
"total": payload.get("total", 0),
|
||||
"percent": stage_percent,
|
||||
"speed": payload.get("speed", 0),
|
||||
}
|
||||
progress["overall_percent"] = min(
|
||||
99,
|
||||
stage_start + int(stage_percent * stage_weight / 100),
|
||||
)
|
||||
async def _run_json(self, operation):
|
||||
async def invoke():
|
||||
data = await self._json_body()
|
||||
return operation(data)
|
||||
|
||||
return _callback
|
||||
return await self._run(invoke)
|
||||
|
||||
async def get_update_progress(self):
|
||||
progress_id = request.args.get("id", "")
|
||||
if not progress_id:
|
||||
return Response().error("缺少参数 id。").__dict__
|
||||
progress = self.update_progress.get(progress_id)
|
||||
if not progress:
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{"id": progress_id, "status": "idle"},
|
||||
"没有正在进行的更新。",
|
||||
)
|
||||
.__dict__
|
||||
return await self._run(
|
||||
lambda: self.service.get_update_progress_from_legacy_query(
|
||||
request.args.get("id")
|
||||
)
|
||||
return Response().ok(progress).__dict__
|
||||
)
|
||||
|
||||
async def do_migration(self):
|
||||
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
|
||||
if not need_migration:
|
||||
return Response().ok(None, "不需要进行迁移。").__dict__
|
||||
try:
|
||||
data = await request.json
|
||||
pim = data.get("platform_id_map", {})
|
||||
await do_migration_v4(
|
||||
self.core_lifecycle.db,
|
||||
pim,
|
||||
self.core_lifecycle.astrbot_config,
|
||||
)
|
||||
return Response().ok(None, "迁移成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"迁移失败: {traceback.format_exc()}")
|
||||
return Response().error(f"迁移失败: {e!s}").__dict__
|
||||
return await self._run_json(self.service.do_migration_v4)
|
||||
|
||||
async def check_update(self):
|
||||
type_ = request.args.get("type", None)
|
||||
|
||||
try:
|
||||
dv = await get_dashboard_version()
|
||||
if type_ == "dashboard":
|
||||
return (
|
||||
Response()
|
||||
.ok({"has_new_version": dv != f"v{VERSION}", "current_version": dv})
|
||||
.__dict__
|
||||
)
|
||||
ret = await self.astrbot_updator.check_update(None, None, False)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
data={
|
||||
"version": f"v{VERSION}",
|
||||
"has_new_version": ret is not None,
|
||||
"dashboard_version": dv,
|
||||
"dashboard_has_new_version": bool(dv and dv != f"v{VERSION}"),
|
||||
},
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.warning(f"检查更新失败: {e!s} (不影响除项目更新外的正常使用)")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
return await self._run(
|
||||
self.service.check_update_from_legacy_query(request.args.get("type"))
|
||||
)
|
||||
|
||||
async def get_releases(self):
|
||||
try:
|
||||
ret = await self.astrbot_updator.get_releases()
|
||||
return Response().ok(ret).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update/releases: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
return await self._run(self.service.get_releases())
|
||||
|
||||
async def update_project(self):
|
||||
data = await request.json
|
||||
version = data.get("version", "")
|
||||
reboot = data.get("reboot", True)
|
||||
progress_id = data.get("progress_id") or uuid.uuid4().hex
|
||||
if version == "" or version == "latest":
|
||||
latest = True
|
||||
version = ""
|
||||
else:
|
||||
latest = False
|
||||
|
||||
proxy: str = data.get("proxy", None)
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
|
||||
self._init_update_progress(progress_id, version)
|
||||
try:
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"dashboard",
|
||||
"running",
|
||||
"正在下载 WebUI...",
|
||||
0,
|
||||
)
|
||||
await download_dashboard(
|
||||
latest=latest,
|
||||
version=version,
|
||||
proxy=proxy,
|
||||
progress_callback=self._make_progress_callback(
|
||||
progress_id,
|
||||
"dashboard",
|
||||
0,
|
||||
45,
|
||||
),
|
||||
)
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"dashboard",
|
||||
"done",
|
||||
"WebUI 下载完成。",
|
||||
45,
|
||||
)
|
||||
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"core",
|
||||
"running",
|
||||
"正在下载 AstrBot 项目代码...",
|
||||
45,
|
||||
)
|
||||
await self.astrbot_updator.update(
|
||||
latest=latest,
|
||||
version=version,
|
||||
proxy=proxy,
|
||||
progress_callback=self._make_progress_callback(
|
||||
progress_id,
|
||||
"core",
|
||||
45,
|
||||
45,
|
||||
),
|
||||
)
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"core",
|
||||
"done",
|
||||
"项目代码下载完成。",
|
||||
90,
|
||||
)
|
||||
|
||||
# pip 更新依赖
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"dependencies",
|
||||
"running",
|
||||
"正在更新依赖...",
|
||||
92,
|
||||
)
|
||||
logger.info("更新依赖中...")
|
||||
try:
|
||||
await pip_installer.install(requirements_path="requirements.txt")
|
||||
except Exception as e:
|
||||
logger.error(f"更新依赖失败: {e}")
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"dependencies",
|
||||
"done",
|
||||
"依赖更新完成。",
|
||||
96,
|
||||
)
|
||||
|
||||
if reboot:
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"restart",
|
||||
"running",
|
||||
"更新成功,正在准备重启...",
|
||||
98,
|
||||
)
|
||||
await self.core_lifecycle.restart()
|
||||
self.update_progress[progress_id].update(
|
||||
{
|
||||
"status": "success",
|
||||
"stage": "done",
|
||||
"message": "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。",
|
||||
"overall_percent": 100,
|
||||
},
|
||||
)
|
||||
ret = (
|
||||
Response()
|
||||
.ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。")
|
||||
.__dict__
|
||||
)
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
self.update_progress[progress_id].update(
|
||||
{
|
||||
"status": "success",
|
||||
"stage": "done",
|
||||
"message": "更新成功,AstrBot 将在下次启动时应用新的代码。",
|
||||
"overall_percent": 100,
|
||||
},
|
||||
)
|
||||
ret = (
|
||||
Response()
|
||||
.ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。")
|
||||
.__dict__
|
||||
)
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
except Exception as e:
|
||||
self.update_progress[progress_id].update(
|
||||
{
|
||||
"status": "error",
|
||||
"message": e.__str__(),
|
||||
},
|
||||
)
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
return await self._run_json(self.service.update_project)
|
||||
|
||||
async def update_dashboard(self):
|
||||
try:
|
||||
try:
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
return Response().error(f"下载管理面板文件失败: {e}").__dict__
|
||||
ret = Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
return await self._run(self.service.update_dashboard())
|
||||
|
||||
async def install_pip_package(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
data = await request.json
|
||||
package = data.get("package", "")
|
||||
mirror = data.get("mirror", None)
|
||||
if not package:
|
||||
return Response().error("缺少参数 package 或不合法。").__dict__
|
||||
try:
|
||||
await pip_installer.install(package, mirror=mirror)
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_pip: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
return await self._run_json(self.service.install_pip_package)
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
"""Dashboard 路由工具集。
|
||||
|
||||
这里放一些 dashboard routes 可复用的小工具函数。
|
||||
|
||||
目前主要用于「配置文件上传(file 类型配置项)」功能:
|
||||
- 清洗/规范化用户可控的文件名与相对路径
|
||||
- 将配置 key 映射到配置项独立子目录
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_schema_item(schema: dict | None, key_path: str) -> dict | None:
|
||||
"""按 dot-path 获取 schema 的节点。
|
||||
|
||||
同时支持:
|
||||
- 扁平 schema(直接 key 命中)
|
||||
- 嵌套 object schema({type: "object", items: {...}})
|
||||
- template_list schema(<field>.templates.<template>.items)
|
||||
"""
|
||||
|
||||
if not isinstance(schema, dict) or not key_path:
|
||||
return None
|
||||
if key_path in schema:
|
||||
return schema.get(key_path)
|
||||
|
||||
parts = key_path.split(".")
|
||||
current = schema
|
||||
idx = 0
|
||||
while idx < len(parts):
|
||||
part = parts[idx]
|
||||
if part not in current:
|
||||
return None
|
||||
meta = current.get(part)
|
||||
if idx == len(parts) - 1:
|
||||
return meta
|
||||
if not isinstance(meta, dict) or meta.get("type") != "object":
|
||||
if not isinstance(meta, dict) or meta.get("type") != "template_list":
|
||||
return None
|
||||
if idx + 2 >= len(parts) or parts[idx + 1] != "templates":
|
||||
return None
|
||||
template_meta = meta.get("templates", {}).get(parts[idx + 2])
|
||||
if not isinstance(template_meta, dict):
|
||||
return None
|
||||
if idx + 2 == len(parts) - 1:
|
||||
return template_meta
|
||||
current = template_meta.get("items", {})
|
||||
idx += 3
|
||||
continue
|
||||
current = meta.get("items", {})
|
||||
idx += 1
|
||||
return None
|
||||
|
||||
|
||||
def sanitize_filename(name: str) -> str:
|
||||
"""清洗上传文件名,避免路径穿越与非法名称。
|
||||
|
||||
- 丢弃目录部分,仅保留 basename
|
||||
- 将路径分隔符替换为下划线
|
||||
- 拒绝空字符串 / "." / ".."
|
||||
"""
|
||||
|
||||
cleaned = os.path.basename(name).strip()
|
||||
if not cleaned or cleaned in {".", ".."}:
|
||||
return ""
|
||||
for sep in (os.sep, os.altsep):
|
||||
if sep:
|
||||
cleaned = cleaned.replace(sep, "_")
|
||||
return cleaned
|
||||
|
||||
|
||||
def sanitize_path_segment(segment: str) -> str:
|
||||
"""清洗目录片段(URL/path 安全,避免穿越)。
|
||||
|
||||
仅保留 [A-Za-z0-9_-],其余替换为 "_"
|
||||
"""
|
||||
|
||||
cleaned = []
|
||||
for ch in segment:
|
||||
if (
|
||||
("a" <= ch <= "z")
|
||||
or ("A" <= ch <= "Z")
|
||||
or ch.isdigit()
|
||||
or ch
|
||||
in {
|
||||
"-",
|
||||
"_",
|
||||
}
|
||||
):
|
||||
cleaned.append(ch)
|
||||
else:
|
||||
cleaned.append("_")
|
||||
result = "".join(cleaned).strip("_")
|
||||
return result or "_"
|
||||
|
||||
|
||||
def config_key_to_folder(key_path: str) -> str:
|
||||
"""将 dot-path 的配置 key 转成稳定的文件夹路径。"""
|
||||
|
||||
parts = [sanitize_path_segment(p) for p in key_path.split(".") if p]
|
||||
return "/".join(parts) if parts else "_"
|
||||
|
||||
|
||||
def normalize_rel_path(rel_path: str | None) -> str | None:
|
||||
"""规范化用户传入的相对路径,并阻止路径穿越。"""
|
||||
|
||||
if not isinstance(rel_path, str):
|
||||
return None
|
||||
rel = rel_path.replace("\\", "/").lstrip("/")
|
||||
if not rel:
|
||||
return None
|
||||
parts = [p for p in rel.split("/") if p]
|
||||
if any(part in {".", ".."} for part in parts):
|
||||
return None
|
||||
if rel.startswith("../") or "/../" in rel:
|
||||
return None
|
||||
return "/".join(parts)
|
||||
@@ -1,49 +1,68 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Protocol, cast
|
||||
|
||||
import jwt
|
||||
import psutil
|
||||
from flask.json.provider import DefaultJSONProvider
|
||||
from hypercorn.asyncio import serve
|
||||
from hypercorn.config import Config as HyperConfig
|
||||
from hypercorn.logging import AccessLogAtoms
|
||||
from hypercorn.logging import Logger as HypercornLogger
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
from werkzeug.exceptions import MethodNotAllowed, NotFound
|
||||
from werkzeug.routing import Map, Rule
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
from astrbot.core.utils.io import (
|
||||
get_bundled_dashboard_dist_path,
|
||||
get_local_ip_addresses,
|
||||
should_use_bundled_dashboard_dist,
|
||||
)
|
||||
from astrbot.dashboard.fastapi_compat import (
|
||||
CompatG,
|
||||
FastAPIAppAdapter,
|
||||
bind_request_context,
|
||||
g,
|
||||
jsonify,
|
||||
request,
|
||||
)
|
||||
|
||||
from .plugin_page_auth import PluginPageAuth
|
||||
from .routes import *
|
||||
from .routes.api_key import ALL_OPEN_API_SCOPES
|
||||
from .routes.auth import DASHBOARD_JWT_COOKIE_NAME
|
||||
from .routes.api_key import ApiKeyRoute
|
||||
from .routes.auth import AuthRoute
|
||||
from .routes.backup import BackupRoute
|
||||
from .routes.chat import ChatRoute
|
||||
from .routes.chatui_project import ChatUIProjectRoute
|
||||
from .routes.command import CommandRoute
|
||||
from .routes.config import ConfigRoute
|
||||
from .routes.conversation import ConversationRoute
|
||||
from .routes.cron import CronRoute
|
||||
from .routes.file import FileRoute
|
||||
from .routes.knowledge_base import KnowledgeBaseRoute
|
||||
from .routes.live_chat import LiveChatRoute
|
||||
from .routes.log import LogRoute
|
||||
from .routes.open_api import OpenApiRoute
|
||||
from .routes.persona import PersonaRoute
|
||||
from .routes.platform import PlatformRoute
|
||||
from .routes.plugin import PluginRoute
|
||||
from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
from .routes.skills import SkillsRoute
|
||||
from .routes.stat import StatRoute
|
||||
from .routes.static_file import StaticFileRoute
|
||||
from .routes.subagent import SubAgentRoute
|
||||
from .routes.t2i import T2iRoute
|
||||
from .routes.tools import ToolsRoute
|
||||
from .routes.update import UpdateRoute
|
||||
from .services.auth_service import DASHBOARD_JWT_COOKIE_NAME
|
||||
from .services.chat_service import ChatService
|
||||
from .v1.app import create_v1_asgi_app
|
||||
|
||||
_RATE_LIMITED_ENDPOINTS: frozenset = frozenset(
|
||||
{
|
||||
@@ -120,7 +139,7 @@ class _AddrWithPort(Protocol):
|
||||
port: int
|
||||
|
||||
|
||||
APP: Quart
|
||||
APP: FastAPIAppAdapter | None = None
|
||||
|
||||
|
||||
def _normalize_plugin_api_route(route: str) -> str:
|
||||
@@ -139,27 +158,28 @@ def _match_registered_web_api(registered_web_apis, subpath: str, method: str):
|
||||
if request_method not in allowed_methods:
|
||||
continue
|
||||
|
||||
url_map = Map(
|
||||
[
|
||||
Rule(
|
||||
_normalize_plugin_api_route(route),
|
||||
endpoint="plugin_api",
|
||||
methods=allowed_methods,
|
||||
),
|
||||
]
|
||||
)
|
||||
try:
|
||||
_, path_values = url_map.bind("").match(
|
||||
request_path,
|
||||
method=request_method,
|
||||
)
|
||||
except (MethodNotAllowed, NotFound):
|
||||
pattern = _plugin_api_route_pattern(route)
|
||||
matched = re.fullmatch(pattern, request_path)
|
||||
if not matched:
|
||||
continue
|
||||
return view_handler, path_values
|
||||
return view_handler, matched.groupdict()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _plugin_api_route_pattern(route: str) -> str:
|
||||
normalized = _normalize_plugin_api_route(route)
|
||||
chunks = []
|
||||
pos = 0
|
||||
for match in re.finditer(r"<(?:(path):)?([A-Za-z_][A-Za-z0-9_]*)>", normalized):
|
||||
chunks.append(re.escape(normalized[pos : match.start()]))
|
||||
name = match.group(2)
|
||||
chunks.append(f"(?P<{name}>.*)" if match.group(1) else f"(?P<{name}>[^/]+)")
|
||||
pos = match.end()
|
||||
chunks.append(re.escape(normalized[pos:]))
|
||||
return "".join(chunks)
|
||||
|
||||
|
||||
def _parse_env_bool(value: str | None, default: bool) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
@@ -213,13 +233,6 @@ class _ProxyAwareHypercornLogger(HypercornLogger):
|
||||
return atoms
|
||||
|
||||
|
||||
class AstrBotJSONProvider(DefaultJSONProvider):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return to_utc_isoformat(obj)
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
class AstrBotDashboard:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -254,16 +267,30 @@ class AstrBotDashboard:
|
||||
self.data_path = os.path.abspath(user_dist)
|
||||
|
||||
self._rate_limiter_registry = _RateLimiterRegistry()
|
||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||
APP = self.app # noqa
|
||||
self._init_jwt_secret()
|
||||
self.asgi_app = create_v1_asgi_app(
|
||||
core_lifecycle=core_lifecycle,
|
||||
db=db,
|
||||
jwt_secret=self._jwt_secret,
|
||||
)
|
||||
self.app = FastAPIAppAdapter(self.asgi_app, static_folder=self.data_path)
|
||||
self.asgi_app.state.dashboard_app_adapter = self.app
|
||||
self.app._dashboard_server = self
|
||||
global APP
|
||||
APP = self.app
|
||||
self.app.config["MAX_CONTENT_LENGTH"] = (
|
||||
128 * 1024 * 1024
|
||||
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
|
||||
self.app.json = AstrBotJSONProvider(self.app)
|
||||
self.app.json.sort_keys = False
|
||||
self.app.before_request(self.auth_middleware)
|
||||
# token 用于验证请求
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
|
||||
@self.asgi_app.middleware("http")
|
||||
async def dashboard_auth_middleware(request_, call_next):
|
||||
request_.state.dashboard_g = CompatG()
|
||||
with bind_request_context(request_, self.app, request_.state.dashboard_g):
|
||||
auth_response = await self.auth_middleware()
|
||||
if auth_response is not None:
|
||||
return auth_response
|
||||
return await call_next(request_)
|
||||
|
||||
self.context = RouteContext(self.config, self.app)
|
||||
self.ur = UpdateRoute(
|
||||
self.context,
|
||||
@@ -282,13 +309,21 @@ class AstrBotDashboard:
|
||||
self.sfr = StaticFileRoute(self.context)
|
||||
self.ar = AuthRoute(self.context, db)
|
||||
self.api_key_route = ApiKeyRoute(self.context, db)
|
||||
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
|
||||
self.chat_service = ChatService(db, core_lifecycle)
|
||||
self.chat_route = ChatRoute(
|
||||
self.context,
|
||||
db,
|
||||
core_lifecycle,
|
||||
service=self.chat_service,
|
||||
)
|
||||
self.open_api_route = OpenApiRoute(
|
||||
self.context,
|
||||
db,
|
||||
core_lifecycle,
|
||||
self.chat_route,
|
||||
self.chat_service,
|
||||
register_routes=False,
|
||||
)
|
||||
self.asgi_app.state.open_api_route = self.open_api_route
|
||||
self.chatui_project_route = ChatUIProjectRoute(self.context, db)
|
||||
self.tools_root = ToolsRoute(self.context, core_lifecycle)
|
||||
self.subagent_route = SubAgentRoute(self.context, core_lifecycle)
|
||||
@@ -307,6 +342,7 @@ class AstrBotDashboard:
|
||||
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
||||
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
|
||||
self.live_chat_route = LiveChatRoute(self.context, db, core_lifecycle)
|
||||
self.asgi_app.state.live_chat_route = self.live_chat_route
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
@@ -316,8 +352,6 @@ class AstrBotDashboard:
|
||||
|
||||
self.shutdown_event = shutdown_event
|
||||
|
||||
self._init_jwt_secret()
|
||||
|
||||
async def srv_plug_route(self, subpath, *args, **kwargs):
|
||||
"""插件路由"""
|
||||
registered_web_apis = self.core_lifecycle.star_context.registered_web_apis
|
||||
@@ -335,37 +369,6 @@ class AstrBotDashboard:
|
||||
if not request.path.startswith("/api"):
|
||||
return None
|
||||
if request.path.startswith("/api/v1"):
|
||||
raw_key = self._extract_raw_api_key()
|
||||
if not raw_key:
|
||||
r = jsonify(Response().error("Missing API key").__dict__)
|
||||
r.status_code = 401
|
||||
return r
|
||||
key_hash = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
raw_key.encode("utf-8"),
|
||||
b"astrbot_api_key",
|
||||
100_000,
|
||||
).hex()
|
||||
api_key = await self.db.get_active_api_key_by_hash(key_hash)
|
||||
if not api_key:
|
||||
r = jsonify(Response().error("Invalid API key").__dict__)
|
||||
r.status_code = 401
|
||||
return r
|
||||
|
||||
if isinstance(api_key.scopes, list):
|
||||
scopes = api_key.scopes
|
||||
else:
|
||||
scopes = list(ALL_OPEN_API_SCOPES)
|
||||
required_scope = self._get_required_open_api_scope(request.path)
|
||||
if required_scope and "*" not in scopes and required_scope not in scopes:
|
||||
r = jsonify(Response().error("Insufficient API key scope").__dict__)
|
||||
r.status_code = 403
|
||||
return r
|
||||
|
||||
g.api_key_id = api_key.key_id
|
||||
g.api_key_scopes = scopes
|
||||
g.username = f"api_key:{api_key.key_id}"
|
||||
await self.db.touch_api_key(api_key.key_id)
|
||||
return None
|
||||
|
||||
if (
|
||||
@@ -403,6 +406,7 @@ class AstrBotDashboard:
|
||||
}
|
||||
allowed_endpoint_prefixes = [
|
||||
"/api/file",
|
||||
"/api/v1/files/tokens",
|
||||
"/api/platform/webhook",
|
||||
"/api/stat/start-time",
|
||||
"/api/backup/download", # 备份下载使用 URL 参数传递 token
|
||||
@@ -486,34 +490,6 @@ class AstrBotDashboard:
|
||||
return cookie_token
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_raw_api_key() -> str | None:
|
||||
if key := request.args.get("api_key"):
|
||||
return key.strip()
|
||||
if key := request.args.get("key"):
|
||||
return key.strip()
|
||||
if key := request.headers.get("X-API-Key"):
|
||||
return key.strip()
|
||||
auth_header = request.headers.get("Authorization", "").strip()
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header.removeprefix("Bearer ").strip()
|
||||
if auth_header.startswith("ApiKey "):
|
||||
return auth_header.removeprefix("ApiKey ").strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_required_open_api_scope(path: str) -> str | None:
|
||||
scope_map = {
|
||||
"/api/v1/chat": "chat",
|
||||
"/api/v1/chat/ws": "chat",
|
||||
"/api/v1/chat/sessions": "chat",
|
||||
"/api/v1/configs": "config",
|
||||
"/api/v1/file": "file",
|
||||
"/api/v1/im/message": "im",
|
||||
"/api/v1/im/bots": "im",
|
||||
}
|
||||
return scope_map.get(path)
|
||||
|
||||
def check_port_in_use(self, port: int) -> bool:
|
||||
"""跨平台检测端口是否被占用"""
|
||||
try:
|
||||
@@ -725,7 +701,7 @@ class AstrBotDashboard:
|
||||
config.accesslog = "-"
|
||||
config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s"
|
||||
|
||||
return serve(self.app, config, shutdown_trigger=self.shutdown_trigger)
|
||||
return serve(self.asgi_app, config, shutdown_trigger=self.shutdown_trigger)
|
||||
|
||||
async def shutdown_trigger(self) -> None:
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
1
astrbot/dashboard/services/__init__.py
Normal file
1
astrbot/dashboard/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Application services for dashboard HTTP APIs."""
|
||||
139
astrbot/dashboard/services/api_key_service.py
Normal file
139
astrbot/dashboard/services/api_key_service.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.datetime_utils import normalize_datetime_utc
|
||||
|
||||
from .auth_service import ALL_OPEN_API_SCOPES
|
||||
|
||||
|
||||
class ApiKeyServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ApiKeyService:
|
||||
def __init__(self, db: BaseDatabase) -> None:
|
||||
self.db = db
|
||||
|
||||
@staticmethod
|
||||
def hash_key(raw_key: str) -> str:
|
||||
return hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
raw_key.encode("utf-8"),
|
||||
b"astrbot_api_key",
|
||||
100_000,
|
||||
).hex()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_utc(dt: datetime | None) -> datetime | None:
|
||||
return normalize_datetime_utc(dt)
|
||||
|
||||
@classmethod
|
||||
def _serialize_datetime(cls, dt: datetime | None) -> str | None:
|
||||
normalized = cls._normalize_utc(dt)
|
||||
if normalized is None:
|
||||
return None
|
||||
return normalized.astimezone().isoformat()
|
||||
|
||||
@classmethod
|
||||
def serialize_api_key(cls, key) -> dict:
|
||||
expires_at = cls._normalize_utc(key.expires_at)
|
||||
return {
|
||||
"key_id": key.key_id,
|
||||
"name": key.name,
|
||||
"key_prefix": key.key_prefix,
|
||||
"scopes": key.scopes or [],
|
||||
"created_by": key.created_by,
|
||||
"created_at": cls._serialize_datetime(key.created_at),
|
||||
"updated_at": cls._serialize_datetime(key.updated_at),
|
||||
"last_used_at": cls._serialize_datetime(key.last_used_at),
|
||||
"expires_at": cls._serialize_datetime(key.expires_at),
|
||||
"revoked_at": cls._serialize_datetime(key.revoked_at),
|
||||
"is_revoked": key.revoked_at is not None,
|
||||
"is_expired": bool(expires_at and expires_at < datetime.now(timezone.utc)),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scopes(raw_scopes: Any) -> list[str]:
|
||||
if raw_scopes is None:
|
||||
return list(ALL_OPEN_API_SCOPES)
|
||||
if not isinstance(raw_scopes, list):
|
||||
raise ApiKeyServiceError("Invalid scopes")
|
||||
|
||||
scopes = [
|
||||
scope
|
||||
for scope in raw_scopes
|
||||
if isinstance(scope, str) and scope in ALL_OPEN_API_SCOPES
|
||||
]
|
||||
normalized_scopes = list(dict.fromkeys(scopes))
|
||||
if not normalized_scopes:
|
||||
raise ApiKeyServiceError("At least one valid scope is required")
|
||||
return normalized_scopes
|
||||
|
||||
@staticmethod
|
||||
def _resolve_expires_at(expires_in_days: Any) -> datetime | None:
|
||||
if expires_in_days is None:
|
||||
return None
|
||||
try:
|
||||
expires_in_days_int = int(expires_in_days)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ApiKeyServiceError("expires_in_days must be an integer") from exc
|
||||
if expires_in_days_int <= 0:
|
||||
raise ApiKeyServiceError("expires_in_days must be greater than 0")
|
||||
return datetime.now(timezone.utc) + timedelta(days=expires_in_days_int)
|
||||
|
||||
async def list_api_keys(self) -> list[dict]:
|
||||
keys = await self.db.list_api_keys()
|
||||
return [self.serialize_api_key(key) for key in keys]
|
||||
|
||||
async def create_api_key(self, payload: dict, *, created_by: str) -> dict:
|
||||
name = str(payload.get("name", "")).strip() or "Untitled API Key"
|
||||
scopes = self._normalize_scopes(payload.get("scopes"))
|
||||
expires_at = self._resolve_expires_at(payload.get("expires_in_days"))
|
||||
|
||||
raw_key = f"abk_{secrets.token_urlsafe(32)}"
|
||||
api_key = await self.db.create_api_key(
|
||||
name=name,
|
||||
key_hash=self.hash_key(raw_key),
|
||||
key_prefix=raw_key[:12],
|
||||
scopes=scopes, # type: ignore
|
||||
created_by=created_by,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
result = self.serialize_api_key(api_key)
|
||||
result["api_key"] = raw_key
|
||||
return result
|
||||
|
||||
async def create_api_key_from_legacy_payload(
|
||||
self,
|
||||
payload: object,
|
||||
*,
|
||||
created_by: str,
|
||||
) -> dict:
|
||||
data = payload if isinstance(payload, dict) else {}
|
||||
return await self.create_api_key(data, created_by=created_by)
|
||||
|
||||
async def revoke_api_key(self, key_id: str | None) -> bool:
|
||||
if not key_id:
|
||||
raise ApiKeyServiceError("Missing key: key_id")
|
||||
return await self.db.revoke_api_key(key_id)
|
||||
|
||||
async def revoke_api_key_from_legacy_payload(self, payload: object) -> None:
|
||||
data = payload if isinstance(payload, dict) else {}
|
||||
if not await self.revoke_api_key(data.get("key_id")):
|
||||
raise ApiKeyServiceError("API key not found")
|
||||
|
||||
async def delete_api_key(self, key_id: str | None) -> bool:
|
||||
if not key_id:
|
||||
raise ApiKeyServiceError("Missing key: key_id")
|
||||
return await self.db.delete_api_key(key_id)
|
||||
|
||||
async def delete_api_key_from_legacy_payload(self, payload: object) -> None:
|
||||
data = payload if isinstance(payload, dict) else {}
|
||||
if not await self.delete_api_key(data.get("key_id")):
|
||||
raise ApiKeyServiceError("API key not found")
|
||||
452
astrbot/dashboard/services/auth_service.py
Normal file
452
astrbot/dashboard/services/auth_service.py
Normal file
@@ -0,0 +1,452 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import jwt
|
||||
import pyotp
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core import DEMO_MODE
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.auth_password import (
|
||||
is_default_dashboard_password,
|
||||
is_legacy_dashboard_password,
|
||||
validate_dashboard_password,
|
||||
verify_dashboard_password,
|
||||
)
|
||||
from astrbot.core.utils.totp import (
|
||||
TOTP_TRUSTED_DEVICE_COOKIE_NAME as _TOTP_TRUSTED_DEVICE_COOKIE_NAME,
|
||||
)
|
||||
from astrbot.core.utils.totp import (
|
||||
TOTP_TRUSTED_DEVICE_MAX_AGE as _TOTP_TRUSTED_DEVICE_MAX_AGE,
|
||||
)
|
||||
from astrbot.core.utils.totp import (
|
||||
TwoFactorCodeType,
|
||||
consume_configured_totp_code,
|
||||
consume_rotation_verified,
|
||||
consume_totp_code,
|
||||
generate_recovery_code,
|
||||
is_totp_enabled,
|
||||
is_totp_trusted_device_valid,
|
||||
issue_totp_trusted_device,
|
||||
revoke_user_trusted_devices,
|
||||
set_pending_totp_secret,
|
||||
set_rotation_verified,
|
||||
verify_configured_2fa_code,
|
||||
)
|
||||
from astrbot.dashboard.password_state import (
|
||||
get_dashboard_password_hash,
|
||||
is_password_change_required,
|
||||
is_password_storage_upgraded,
|
||||
set_dashboard_password_hashes,
|
||||
set_password_change_required,
|
||||
set_password_storage_upgraded,
|
||||
)
|
||||
|
||||
ALL_OPEN_API_SCOPES = (
|
||||
"chat",
|
||||
"config",
|
||||
"file",
|
||||
"im",
|
||||
"plugin",
|
||||
"tool",
|
||||
"skill",
|
||||
"kb",
|
||||
"persona",
|
||||
"data",
|
||||
"system",
|
||||
)
|
||||
|
||||
DASHBOARD_JWT_COOKIE_NAME = "astrbot_dashboard_jwt"
|
||||
DASHBOARD_JWT_COOKIE_MAX_AGE = 7 * 24 * 60 * 60
|
||||
SKIP_DEFAULT_PASSWORD_AUTH_ENV = "ASTRBOT_DASHBOARD_SKIP_DEFAULT_PASSWORD_AUTH"
|
||||
SKIP_DEFAULT_PASSWORD_AUTH_ENV_LEGACY = "DASHBOARD_SKIP_DEFAULT_PASSWORD_AUTH"
|
||||
LOCAL_DASHBOARD_HOSTS = {"127.0.0.1", "localhost", "::1"}
|
||||
DEFAULT_PASSWORD_LOGIN_FAILURE_MESSAGE = (
|
||||
"Login failed. If this is your first time using AstrBot, the old default "
|
||||
"astrbot password has been replaced by a random strong password printed in "
|
||||
"the startup logs. Check the initial password in the logs and try again. "
|
||||
"Learn more: https://docs.astrbot.app/en/faq.html\n\n"
|
||||
"登录失败。如果您是初次使用,旧版默认 astrbot 密码已改为启动日志中输出的"
|
||||
"随机强密码。请使用日志中提供的的初始密码来登录。了解更多:"
|
||||
"https://docs.astrbot.app/faq.html"
|
||||
)
|
||||
LEGACY_PASSWORD_LOGIN_FAILURE_MESSAGE = (
|
||||
"Incorrect username or password. If you cannot log in after upgrading "
|
||||
"AstrBot even though the password is correct, see "
|
||||
"https://docs.astrbot.app/en/faq.html\n\n"
|
||||
"用户名或密码错误。如果你在升级 AstrBot 后遇到了密码正确但无法登录的情况,"
|
||||
"请参考 https://docs.astrbot.app/faq.html"
|
||||
)
|
||||
TOTP_TRUSTED_DEVICE_COOKIE_NAME = _TOTP_TRUSTED_DEVICE_COOKIE_NAME
|
||||
TOTP_TRUSTED_DEVICE_MAX_AGE = _TOTP_TRUSTED_DEVICE_MAX_AGE
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthServiceResult:
|
||||
status: str = "ok"
|
||||
data: dict | None = None
|
||||
message: str | None = None
|
||||
status_code: int = 200
|
||||
jwt_token: str | None = None
|
||||
trusted_device_token: str | None = None
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(
|
||||
self,
|
||||
db: BaseDatabase,
|
||||
config: AstrBotConfig,
|
||||
*,
|
||||
demo_mode: bool = DEMO_MODE,
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.config = config
|
||||
self.demo_mode = demo_mode
|
||||
|
||||
async def setup_status(self) -> AuthServiceResult:
|
||||
return AuthServiceResult(
|
||||
data={
|
||||
"setup_required": await self.is_setup_required(),
|
||||
"skip_default_password_auth": self.can_skip_default_password_auth(),
|
||||
"password_upgrade_required": not await is_password_storage_upgraded(
|
||||
self.db,
|
||||
self.config,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
async def totp_setup(self, post_data: object) -> AuthServiceResult:
|
||||
if isinstance(post_data, dict) and post_data.get("secret"):
|
||||
secret = post_data["secret"]
|
||||
code = post_data.get("code")
|
||||
if not isinstance(secret, str) or not secret.strip():
|
||||
return self.error("Invalid request payload")
|
||||
|
||||
if not isinstance(code, str) or not code.strip():
|
||||
return self.error("TOTP 验证码是必需的")
|
||||
if not await consume_totp_code(secret, code):
|
||||
return self.error("TOTP 验证码无效")
|
||||
|
||||
if is_totp_enabled(self.config) and not consume_rotation_verified():
|
||||
return self.error("需要先验证当前 TOTP")
|
||||
|
||||
set_pending_totp_secret(secret)
|
||||
recovery_code, recovery_code_hash = generate_recovery_code()
|
||||
return AuthServiceResult(
|
||||
data={
|
||||
"recovery_code": recovery_code,
|
||||
"recovery_code_hash": recovery_code_hash,
|
||||
},
|
||||
message="TOTP verified",
|
||||
)
|
||||
|
||||
if is_totp_enabled(self.config):
|
||||
if not isinstance(post_data, dict):
|
||||
return self.error("Invalid request payload")
|
||||
|
||||
set_rotation_verified(False)
|
||||
|
||||
code = post_data.get("code")
|
||||
if isinstance(code, str) and code.strip():
|
||||
if await consume_configured_totp_code(self.config, code):
|
||||
set_rotation_verified(True)
|
||||
return AuthServiceResult(data={"secret": pyotp.random_base32()})
|
||||
return self.error("当前 TOTP 验证码无效")
|
||||
|
||||
return self.error("需要提供 TOTP 验证码或新密钥")
|
||||
|
||||
return AuthServiceResult(data={"secret": pyotp.random_base32()})
|
||||
|
||||
async def totp_recovery(self) -> AuthServiceResult:
|
||||
recovery_code, recovery_code_hash = generate_recovery_code()
|
||||
return AuthServiceResult(
|
||||
data={
|
||||
"recovery_code": recovery_code,
|
||||
"recovery_code_hash": recovery_code_hash,
|
||||
}
|
||||
)
|
||||
|
||||
async def setup(self, post_data: object) -> AuthServiceResult:
|
||||
if not self.can_skip_default_password_auth():
|
||||
return self.error("Setup without password is not enabled")
|
||||
if not await self.is_setup_required():
|
||||
return self.error("Setup is not required")
|
||||
|
||||
return await self.complete_setup(post_data)
|
||||
|
||||
async def setup_authenticated(
|
||||
self,
|
||||
post_data: object,
|
||||
authenticated_username,
|
||||
) -> AuthServiceResult:
|
||||
if not await self.is_setup_required():
|
||||
return self.error("Setup is not required")
|
||||
if not isinstance(authenticated_username, str):
|
||||
return self.error("未授权")
|
||||
|
||||
return await self.complete_setup(post_data)
|
||||
|
||||
async def complete_setup(self, post_data: object) -> AuthServiceResult:
|
||||
if not isinstance(post_data, dict):
|
||||
return self.error("Invalid request payload")
|
||||
|
||||
new_username = post_data.get("username")
|
||||
new_password = post_data.get("password")
|
||||
confirm_password = post_data.get("confirm_password")
|
||||
if not isinstance(new_username, str) or len(new_username.strip()) < 3:
|
||||
return self.error("用户名长度至少3位")
|
||||
if not isinstance(new_password, str):
|
||||
return self.error("新密码无效")
|
||||
if not isinstance(confirm_password, str) or confirm_password != new_password:
|
||||
return self.error("两次输入的新密码不一致")
|
||||
|
||||
try:
|
||||
validate_dashboard_password(new_password)
|
||||
except ValueError as exc:
|
||||
return self.error(str(exc))
|
||||
|
||||
username = new_username.strip()
|
||||
self.config["dashboard"]["username"] = username
|
||||
set_dashboard_password_hashes(self.config, new_password)
|
||||
await set_password_storage_upgraded(self.db, self.config, True)
|
||||
await set_password_change_required(self.db, self.config, False)
|
||||
self.config.save_config()
|
||||
|
||||
token = self.generate_jwt(username)
|
||||
return AuthServiceResult(
|
||||
data={
|
||||
"token": token,
|
||||
"username": username,
|
||||
"change_pwd_hint": False,
|
||||
"legacy_pwd_hint": False,
|
||||
"password_upgrade_required": False,
|
||||
},
|
||||
message="Setup completed successfully",
|
||||
jwt_token=token,
|
||||
)
|
||||
|
||||
async def login(
|
||||
self,
|
||||
post_data: object,
|
||||
*,
|
||||
trusted_device_cookie_token: str,
|
||||
) -> AuthServiceResult:
|
||||
username = self.config["dashboard"]["username"]
|
||||
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
|
||||
password = get_dashboard_password_hash(self.config, upgraded=storage_upgraded)
|
||||
|
||||
req_username = (
|
||||
post_data.get("username") if isinstance(post_data, dict) else None
|
||||
)
|
||||
req_password = (
|
||||
post_data.get("password") if isinstance(post_data, dict) else None
|
||||
)
|
||||
totp_code = post_data.get("code") if isinstance(post_data, dict) else None
|
||||
trust_device_flag = (
|
||||
post_data.get("trust_device_flag") is True
|
||||
if isinstance(post_data, dict)
|
||||
else False
|
||||
)
|
||||
if not isinstance(req_username, str) or not isinstance(req_password, str):
|
||||
return self.error("Invalid request payload")
|
||||
|
||||
login_verified = req_username == username and verify_dashboard_password(
|
||||
password,
|
||||
req_password,
|
||||
)
|
||||
|
||||
if not login_verified:
|
||||
await asyncio.sleep(3)
|
||||
if req_password == "astrbot":
|
||||
return self.error(DEFAULT_PASSWORD_LOGIN_FAILURE_MESSAGE)
|
||||
if is_legacy_dashboard_password(password):
|
||||
return self.error(LEGACY_PASSWORD_LOGIN_FAILURE_MESSAGE)
|
||||
return self.error("用户名或密码错误", status_code=401)
|
||||
|
||||
totp_verified = False
|
||||
|
||||
if is_totp_enabled(self.config):
|
||||
if not await is_totp_trusted_device_valid(
|
||||
self.config,
|
||||
self.db,
|
||||
trusted_device_cookie_token,
|
||||
):
|
||||
if not isinstance(totp_code, str) or not totp_code.strip():
|
||||
return self.error(
|
||||
"需要 TOTP 验证",
|
||||
data={"totp_required": True},
|
||||
status_code=401,
|
||||
)
|
||||
verified_type = await verify_configured_2fa_code(
|
||||
self.config,
|
||||
totp_code,
|
||||
allow_recovery=True,
|
||||
)
|
||||
if verified_type is TwoFactorCodeType.TOTP:
|
||||
totp_verified = True
|
||||
elif verified_type is TwoFactorCodeType.RECOVERY:
|
||||
self.config["dashboard"]["totp"] = {
|
||||
"enable": False,
|
||||
"secret": "",
|
||||
"recovery_code_hash": "",
|
||||
}
|
||||
await revoke_user_trusted_devices(self.db)
|
||||
self.config.save_config()
|
||||
elif len(totp_code) == 6 and totp_code.isdigit():
|
||||
return self.error("TOTP 验证码无效", status_code=401)
|
||||
else:
|
||||
return self.error("恢复码无效", status_code=401)
|
||||
|
||||
change_pwd_hint = False
|
||||
legacy_pwd_hint = is_legacy_dashboard_password(password)
|
||||
password_change_required = await is_password_change_required(
|
||||
self.db,
|
||||
self.config,
|
||||
)
|
||||
if (
|
||||
storage_upgraded
|
||||
and username == "astrbot"
|
||||
and is_default_dashboard_password(password)
|
||||
and not self.demo_mode
|
||||
):
|
||||
change_pwd_hint = True
|
||||
legacy_pwd_hint = True
|
||||
logger.warning("为了保证安全,请尽快修改默认密码。")
|
||||
if password_change_required and not self.demo_mode:
|
||||
change_pwd_hint = True
|
||||
token = self.generate_jwt(username)
|
||||
result = AuthServiceResult(
|
||||
data={
|
||||
"token": token,
|
||||
"username": username,
|
||||
"change_pwd_hint": change_pwd_hint,
|
||||
"legacy_pwd_hint": legacy_pwd_hint,
|
||||
"password_upgrade_required": not storage_upgraded,
|
||||
},
|
||||
jwt_token=token,
|
||||
)
|
||||
|
||||
if totp_verified and trust_device_flag:
|
||||
result.trusted_device_token = await issue_totp_trusted_device(
|
||||
self.config,
|
||||
self.db,
|
||||
)
|
||||
return result
|
||||
|
||||
async def edit_account(self, post_data: object) -> AuthServiceResult:
|
||||
if self.demo_mode:
|
||||
return self.error("You are not permitted to do this operation in demo mode")
|
||||
|
||||
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
|
||||
password = get_dashboard_password_hash(self.config, upgraded=storage_upgraded)
|
||||
if not isinstance(post_data, dict):
|
||||
return self.error("Invalid request payload")
|
||||
|
||||
req_password = post_data.get("password")
|
||||
if not isinstance(req_password, str):
|
||||
return self.error("Invalid request payload")
|
||||
|
||||
if not verify_dashboard_password(password, req_password):
|
||||
return self.error("原密码错误")
|
||||
|
||||
new_pwd = post_data.get("new_password", None)
|
||||
new_username = post_data.get("new_username", None)
|
||||
password_change_required = await is_password_change_required(
|
||||
self.db,
|
||||
self.config,
|
||||
)
|
||||
if (not storage_upgraded or password_change_required) and not new_pwd:
|
||||
return self.error("请设置新密码以完成安全升级")
|
||||
if not new_pwd and not new_username:
|
||||
return self.error("新用户名和新密码不能同时为空")
|
||||
|
||||
if new_pwd:
|
||||
if not isinstance(new_pwd, str):
|
||||
return self.error("新密码无效")
|
||||
confirm_pwd = post_data.get("confirm_password", None)
|
||||
if not isinstance(confirm_pwd, str) or confirm_pwd != new_pwd:
|
||||
return self.error("两次输入的新密码不一致")
|
||||
try:
|
||||
validate_dashboard_password(new_pwd)
|
||||
except ValueError as exc:
|
||||
return self.error(str(exc))
|
||||
set_dashboard_password_hashes(self.config, new_pwd)
|
||||
await set_password_storage_upgraded(self.db, self.config, True)
|
||||
await set_password_change_required(self.db, self.config, False)
|
||||
if is_totp_enabled(self.config):
|
||||
await revoke_user_trusted_devices(self.db)
|
||||
if new_username:
|
||||
self.config["dashboard"]["username"] = new_username
|
||||
|
||||
self.config.save_config()
|
||||
|
||||
return AuthServiceResult(message="Updated account successfully")
|
||||
|
||||
def generate_jwt(self, username: str):
|
||||
payload = {
|
||||
"username": username,
|
||||
"exp": datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(days=7),
|
||||
}
|
||||
jwt_token = self.config["dashboard"].get("jwt_secret", None)
|
||||
if not jwt_token:
|
||||
raise ValueError("JWT secret is not set in the cmd_config.")
|
||||
return jwt.encode(payload, jwt_token, algorithm="HS256")
|
||||
|
||||
async def is_setup_required(self) -> bool:
|
||||
if self.demo_mode:
|
||||
return False
|
||||
|
||||
dashboard_config = self.config["dashboard"]
|
||||
password_change_required = await is_password_change_required(
|
||||
self.db,
|
||||
self.config,
|
||||
)
|
||||
if password_change_required:
|
||||
return True
|
||||
|
||||
storage_upgraded = await is_password_storage_upgraded(self.db, self.config)
|
||||
if not storage_upgraded:
|
||||
return False
|
||||
|
||||
return dashboard_config.get(
|
||||
"username"
|
||||
) == "astrbot" and is_default_dashboard_password(
|
||||
dashboard_config.get("pbkdf2_password", "")
|
||||
)
|
||||
|
||||
def can_skip_default_password_auth(self) -> bool:
|
||||
if not self.env_flag_enabled(SKIP_DEFAULT_PASSWORD_AUTH_ENV):
|
||||
return False
|
||||
host = (
|
||||
os.environ.get("DASHBOARD_HOST")
|
||||
or os.environ.get("ASTRBOT_DASHBOARD_HOST")
|
||||
or self.config["dashboard"].get("host", "")
|
||||
)
|
||||
return str(host).strip().lower() in LOCAL_DASHBOARD_HOSTS
|
||||
|
||||
@staticmethod
|
||||
def env_flag_enabled(name: str) -> bool:
|
||||
value = os.environ.get(name)
|
||||
if value is None and name == SKIP_DEFAULT_PASSWORD_AUTH_ENV:
|
||||
value = os.environ.get(SKIP_DEFAULT_PASSWORD_AUTH_ENV_LEGACY)
|
||||
return str(value or "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
@staticmethod
|
||||
def error(
|
||||
message: str,
|
||||
*,
|
||||
data: dict | None = None,
|
||||
status_code: int = 200,
|
||||
) -> AuthServiceResult:
|
||||
return AuthServiceResult(
|
||||
status="error",
|
||||
data=data,
|
||||
message=message,
|
||||
status_code=status_code,
|
||||
)
|
||||
694
astrbot/dashboard/services/backup_service.py
Normal file
694
astrbot/dashboard/services/backup_service.py
Normal file
@@ -0,0 +1,694 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.backup.exporter import AstrBotExporter
|
||||
from astrbot.core.backup.importer import AstrBotImporter
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
CHUNK_SIZE = 1024 * 1024
|
||||
UPLOAD_EXPIRE_SECONDS = 3600
|
||||
|
||||
|
||||
class BackupServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackupDownload:
|
||||
path: str
|
||||
filename: str
|
||||
|
||||
|
||||
def secure_filename(filename: str) -> str:
|
||||
filename = filename.replace("\\", "/")
|
||||
filename = os.path.basename(filename)
|
||||
filename = filename.replace("..", "_")
|
||||
filename = re.sub(r"[^\w\-.]", "_", filename)
|
||||
filename = filename.strip(".")
|
||||
if not filename or filename.replace("_", "") == "":
|
||||
filename = "backup"
|
||||
return filename
|
||||
|
||||
|
||||
def generate_unique_filename(original_filename: str) -> str:
|
||||
name, ext = os.path.splitext(original_filename)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return f"{name}_{timestamp}{ext}"
|
||||
|
||||
|
||||
class BackupService:
|
||||
def __init__(
|
||||
self,
|
||||
db: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.backup_dir = get_astrbot_backups_path()
|
||||
self.data_dir = get_astrbot_data_path()
|
||||
self.chunks_dir = os.path.join(self.backup_dir, ".chunks")
|
||||
self.backup_tasks: dict[str, dict] = {}
|
||||
self.backup_progress: dict[str, dict] = {}
|
||||
self.upload_sessions: dict[str, dict] = {}
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
|
||||
@staticmethod
|
||||
def _payload(data: object) -> dict[str, Any]:
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
async def _save_upload(file: Any, target_path: str) -> None:
|
||||
if hasattr(file, "save"):
|
||||
result = file.save(target_path)
|
||||
if hasattr(result, "__await__"):
|
||||
await result
|
||||
return
|
||||
|
||||
if hasattr(file, "read"):
|
||||
data = file.read()
|
||||
if hasattr(data, "__await__"):
|
||||
data = await data
|
||||
Path(target_path).write_bytes(data)
|
||||
return
|
||||
|
||||
raise BackupServiceError("无效的上传文件")
|
||||
|
||||
@staticmethod
|
||||
def _validate_backup_filename(filename: str | None, *, missing: str) -> str:
|
||||
if not filename:
|
||||
raise BackupServiceError(missing)
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
raise BackupServiceError("无效的文件名")
|
||||
return filename
|
||||
|
||||
def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None:
|
||||
self.backup_tasks[task_id] = {
|
||||
"type": task_type,
|
||||
"status": status,
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
self.backup_progress[task_id] = {
|
||||
"status": status,
|
||||
"stage": "waiting",
|
||||
"current": 0,
|
||||
"total": 100,
|
||||
"message": "",
|
||||
}
|
||||
|
||||
def _set_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
result: dict | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
if task_id in self.backup_tasks:
|
||||
self.backup_tasks[task_id]["status"] = status
|
||||
self.backup_tasks[task_id]["result"] = result
|
||||
self.backup_tasks[task_id]["error"] = error
|
||||
if task_id in self.backup_progress:
|
||||
self.backup_progress[task_id]["status"] = status
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
task_id: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
stage: str | None = None,
|
||||
current: int | None = None,
|
||||
total: int | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
if task_id not in self.backup_progress:
|
||||
return
|
||||
progress = self.backup_progress[task_id]
|
||||
if status is not None:
|
||||
progress["status"] = status
|
||||
if stage is not None:
|
||||
progress["stage"] = stage
|
||||
if current is not None:
|
||||
progress["current"] = current
|
||||
if total is not None:
|
||||
progress["total"] = total
|
||||
if message is not None:
|
||||
progress["message"] = message
|
||||
|
||||
def _make_progress_callback(self, task_id: str):
|
||||
async def _callback(
|
||||
stage: str,
|
||||
current: int,
|
||||
total: int,
|
||||
message: str = "",
|
||||
) -> None:
|
||||
self._update_progress(
|
||||
task_id,
|
||||
status="processing",
|
||||
stage=stage,
|
||||
current=current,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
|
||||
return _callback
|
||||
|
||||
def ensure_cleanup_task_started(self) -> None:
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
try:
|
||||
self._cleanup_task = asyncio.create_task(
|
||||
self._cleanup_expired_uploads()
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def _cleanup_expired_uploads(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300)
|
||||
current_time = time.time()
|
||||
expired_sessions = []
|
||||
|
||||
for upload_id, session in self.upload_sessions.items():
|
||||
last_activity = session.get("last_activity", session["created_at"])
|
||||
if current_time - last_activity > UPLOAD_EXPIRE_SECONDS:
|
||||
expired_sessions.append(upload_id)
|
||||
|
||||
for upload_id in expired_sessions:
|
||||
await self.cleanup_upload_session(upload_id)
|
||||
logger.info(f"清理过期的上传会话: {upload_id}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.error(f"清理过期上传会话失败: {exc}")
|
||||
|
||||
async def cleanup_upload_session(self, upload_id: str) -> None:
|
||||
if upload_id in self.upload_sessions:
|
||||
session = self.upload_sessions[upload_id]
|
||||
chunk_dir = session.get("chunk_dir")
|
||||
if chunk_dir and os.path.exists(chunk_dir):
|
||||
try:
|
||||
shutil.rmtree(chunk_dir)
|
||||
except Exception as exc:
|
||||
logger.warning(f"清理分片目录失败: {exc}")
|
||||
del self.upload_sessions[upload_id]
|
||||
|
||||
def get_backup_manifest(self, zip_path: str) -> dict | None:
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
if "manifest.json" in zf.namelist():
|
||||
manifest_data = zf.read("manifest.json")
|
||||
return json.loads(manifest_data.decode("utf-8"))
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.debug(f"读取备份 manifest 失败: {exc}")
|
||||
return None
|
||||
|
||||
def list_backups(self, *, page: int, page_size: int) -> dict:
|
||||
self.ensure_cleanup_task_started()
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
backup_files = []
|
||||
for filename in os.listdir(self.backup_dir):
|
||||
if not filename.endswith(".zip") or filename.startswith("."):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.isfile(file_path):
|
||||
continue
|
||||
|
||||
manifest = self.get_backup_manifest(file_path)
|
||||
if manifest is None:
|
||||
logger.debug(f"跳过无效备份文件: {filename}")
|
||||
continue
|
||||
|
||||
stat = os.stat(file_path)
|
||||
backup_files.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"size": stat.st_size,
|
||||
"created_at": stat.st_mtime,
|
||||
"type": manifest.get("origin", "exported"),
|
||||
"astrbot_version": manifest.get("astrbot_version", "未知"),
|
||||
"exported_at": manifest.get("exported_at"),
|
||||
}
|
||||
)
|
||||
|
||||
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
return {
|
||||
"items": backup_files[start:end],
|
||||
"total": len(backup_files),
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
def list_backups_from_legacy_query(self, *, page, page_size) -> dict:
|
||||
return self.list_backups(
|
||||
page=self._to_int(page, 1),
|
||||
page_size=self._to_int(page_size, 20),
|
||||
)
|
||||
|
||||
def export_backup(self) -> dict:
|
||||
task_id = str(uuid.uuid4())
|
||||
self._init_task(task_id, "export", "pending")
|
||||
asyncio.create_task(self.background_export_task(task_id))
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"message": "export task created, processing in background",
|
||||
}
|
||||
|
||||
async def background_export_task(self, task_id: str) -> None:
|
||||
try:
|
||||
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
exporter = AstrBotExporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
zip_path = await exporter.export_all(
|
||||
output_dir=self.backup_dir,
|
||||
progress_callback=self._make_progress_callback(task_id),
|
||||
)
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result={
|
||||
"filename": os.path.basename(zip_path),
|
||||
"path": zip_path,
|
||||
"size": os.path.getsize(zip_path),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"后台导出任务 {task_id} 失败: {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._set_task_result(task_id, "failed", error=str(exc))
|
||||
|
||||
async def upload_backup(self, file: Any | None) -> dict:
|
||||
if not file:
|
||||
raise BackupServiceError("缺少备份文件")
|
||||
if not file.filename or not file.filename.endswith(".zip"):
|
||||
raise BackupServiceError("请上传 ZIP 格式的备份文件")
|
||||
|
||||
safe_filename = secure_filename(file.filename)
|
||||
unique_filename = generate_unique_filename(safe_filename)
|
||||
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
zip_path = os.path.join(self.backup_dir, unique_filename)
|
||||
await self._save_upload(file, zip_path)
|
||||
|
||||
logger.info(
|
||||
f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})"
|
||||
)
|
||||
return {
|
||||
"filename": unique_filename,
|
||||
"original_filename": file.filename,
|
||||
"size": os.path.getsize(zip_path),
|
||||
}
|
||||
|
||||
def upload_init(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
filename = payload.get("filename")
|
||||
total_size = payload.get("total_size", 0)
|
||||
|
||||
if not filename:
|
||||
raise BackupServiceError("缺少 filename 参数")
|
||||
if not filename.endswith(".zip"):
|
||||
raise BackupServiceError("请上传 ZIP 格式的备份文件")
|
||||
if total_size <= 0:
|
||||
raise BackupServiceError("无效的文件大小")
|
||||
|
||||
total_chunks = math.ceil(total_size / CHUNK_SIZE)
|
||||
upload_id = str(uuid.uuid4())
|
||||
chunk_dir = os.path.join(self.chunks_dir, upload_id)
|
||||
Path(chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
safe_filename = secure_filename(filename)
|
||||
unique_filename = generate_unique_filename(safe_filename)
|
||||
current_time = time.time()
|
||||
self.upload_sessions[upload_id] = {
|
||||
"filename": unique_filename,
|
||||
"original_filename": filename,
|
||||
"total_size": total_size,
|
||||
"total_chunks": total_chunks,
|
||||
"received_chunks": set(),
|
||||
"created_at": current_time,
|
||||
"last_activity": current_time,
|
||||
"chunk_dir": chunk_dir,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"初始化分片上传: upload_id={upload_id}, "
|
||||
f"filename={unique_filename}, total_chunks={total_chunks}"
|
||||
)
|
||||
|
||||
return {
|
||||
"upload_id": upload_id,
|
||||
"chunk_size": CHUNK_SIZE,
|
||||
"total_chunks": total_chunks,
|
||||
"filename": unique_filename,
|
||||
}
|
||||
|
||||
async def upload_chunk(
|
||||
self,
|
||||
*,
|
||||
upload_id: str | None,
|
||||
chunk_index_str: str | None,
|
||||
chunk_file: Any | None,
|
||||
) -> dict:
|
||||
if not upload_id or chunk_index_str is None:
|
||||
raise BackupServiceError("缺少必要参数")
|
||||
|
||||
try:
|
||||
chunk_index = int(chunk_index_str)
|
||||
except ValueError as exc:
|
||||
raise BackupServiceError("无效的分片索引") from exc
|
||||
|
||||
if not chunk_file:
|
||||
raise BackupServiceError("缺少分片数据")
|
||||
if upload_id not in self.upload_sessions:
|
||||
raise BackupServiceError("上传会话不存在或已过期")
|
||||
|
||||
session = self.upload_sessions[upload_id]
|
||||
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
|
||||
raise BackupServiceError("分片索引超出范围")
|
||||
|
||||
chunk_path = os.path.join(session["chunk_dir"], f"{chunk_index}.part")
|
||||
await self._save_upload(chunk_file, chunk_path)
|
||||
session["received_chunks"].add(chunk_index)
|
||||
session["last_activity"] = time.time()
|
||||
|
||||
received_count = len(session["received_chunks"])
|
||||
total_chunks = session["total_chunks"]
|
||||
logger.debug(
|
||||
f"接收分片: upload_id={upload_id}, chunk={chunk_index + 1}/{total_chunks}"
|
||||
)
|
||||
|
||||
return {
|
||||
"received": received_count,
|
||||
"total": total_chunks,
|
||||
"chunk_index": chunk_index,
|
||||
}
|
||||
|
||||
def mark_backup_as_uploaded(self, zip_path: str) -> None:
|
||||
try:
|
||||
manifest = {"origin": "uploaded", "uploaded_at": datetime.now().isoformat()}
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
if "manifest.json" in zf.namelist():
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data.decode("utf-8"))
|
||||
manifest["origin"] = "uploaded"
|
||||
manifest["uploaded_at"] = datetime.now().isoformat()
|
||||
|
||||
with zipfile.ZipFile(zip_path, "a") as zf:
|
||||
new_manifest = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||
zf.writestr("manifest.json", new_manifest)
|
||||
|
||||
logger.debug(f"已标记备份为上传来源: {zip_path}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"标记备份来源失败: {exc}")
|
||||
|
||||
async def upload_complete(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
upload_id = payload.get("upload_id")
|
||||
|
||||
if not upload_id:
|
||||
raise BackupServiceError("缺少 upload_id 参数")
|
||||
if upload_id not in self.upload_sessions:
|
||||
raise BackupServiceError("上传会话不存在或已过期")
|
||||
|
||||
session = self.upload_sessions[upload_id]
|
||||
received = session["received_chunks"]
|
||||
total = session["total_chunks"]
|
||||
|
||||
if len(received) != total:
|
||||
missing = set(range(total)) - received
|
||||
raise BackupServiceError(f"分片不完整,缺少: {sorted(missing)[:10]}...")
|
||||
|
||||
chunk_dir = session["chunk_dir"]
|
||||
filename = session["filename"]
|
||||
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
output_path = os.path.join(self.backup_dir, filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "wb") as outfile:
|
||||
for i in range(total):
|
||||
chunk_path = os.path.join(chunk_dir, f"{i}.part")
|
||||
with open(chunk_path, "rb") as chunk_file:
|
||||
while True:
|
||||
data_block = chunk_file.read(8192)
|
||||
if not data_block:
|
||||
break
|
||||
outfile.write(data_block)
|
||||
|
||||
file_size = os.path.getsize(output_path)
|
||||
self.mark_backup_as_uploaded(output_path)
|
||||
logger.info(f"分片上传完成: {filename}, size={file_size}, chunks={total}")
|
||||
await self.cleanup_upload_session(upload_id)
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"original_filename": session["original_filename"],
|
||||
"size": file_size,
|
||||
}
|
||||
except Exception:
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path)
|
||||
raise
|
||||
|
||||
async def upload_abort(self, data: object) -> tuple[dict | None, str | None]:
|
||||
payload = self._payload(data)
|
||||
upload_id = payload.get("upload_id")
|
||||
if not upload_id:
|
||||
raise BackupServiceError("缺少 upload_id 参数")
|
||||
|
||||
if upload_id in self.upload_sessions:
|
||||
await self.cleanup_upload_session(upload_id)
|
||||
logger.info(f"取消分片上传: {upload_id}")
|
||||
|
||||
return None, "上传已取消"
|
||||
|
||||
def check_backup(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
filename = self._validate_backup_filename(
|
||||
payload.get("filename"),
|
||||
missing="缺少 filename 参数",
|
||||
)
|
||||
zip_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(zip_path):
|
||||
raise BackupServiceError(f"备份文件不存在: {filename}")
|
||||
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
importer = AstrBotImporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
return importer.pre_check(zip_path).to_dict()
|
||||
|
||||
def import_backup(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
filename = self._validate_backup_filename(
|
||||
payload.get("filename"),
|
||||
missing="缺少 filename 参数",
|
||||
)
|
||||
confirmed = payload.get("confirmed", False)
|
||||
if not confirmed:
|
||||
raise BackupServiceError(
|
||||
"请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。"
|
||||
)
|
||||
|
||||
zip_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(zip_path):
|
||||
raise BackupServiceError(f"备份文件不存在: {filename}")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
self._init_task(task_id, "import", "pending")
|
||||
asyncio.create_task(self.background_import_task(task_id, zip_path))
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"message": "import task created, processing in background",
|
||||
}
|
||||
|
||||
async def background_import_task(self, task_id: str, zip_path: str) -> None:
|
||||
try:
|
||||
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
importer = AstrBotImporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
result = await importer.import_all(
|
||||
zip_path=zip_path,
|
||||
mode="replace",
|
||||
progress_callback=self._make_progress_callback(task_id),
|
||||
)
|
||||
|
||||
if result.success:
|
||||
self._set_task_result(task_id, "completed", result=result.to_dict())
|
||||
else:
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"failed",
|
||||
error="; ".join(result.errors),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"后台导入任务 {task_id} 失败: {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._set_task_result(task_id, "failed", error=str(exc))
|
||||
|
||||
def get_progress(self, task_id: str | None) -> dict:
|
||||
if not task_id:
|
||||
raise BackupServiceError("缺少参数 task_id")
|
||||
if task_id not in self.backup_tasks:
|
||||
raise BackupServiceError("找不到该任务")
|
||||
|
||||
task_info = self.backup_tasks[task_id]
|
||||
status = task_info["status"]
|
||||
response_data = {
|
||||
"task_id": task_id,
|
||||
"type": task_info["type"],
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if status == "processing" and task_id in self.backup_progress:
|
||||
response_data["progress"] = self.backup_progress[task_id]
|
||||
if status == "completed":
|
||||
response_data["result"] = task_info["result"]
|
||||
if status == "failed":
|
||||
response_data["error"] = task_info["error"]
|
||||
|
||||
return response_data
|
||||
|
||||
def get_progress_from_legacy_query(self, task_id: str | None) -> dict:
|
||||
return self.get_progress(task_id)
|
||||
|
||||
def prepare_download(
|
||||
self,
|
||||
*,
|
||||
filename: str | None,
|
||||
token: str | None,
|
||||
jwt_secret: str | None,
|
||||
) -> BackupDownload:
|
||||
if not filename:
|
||||
raise BackupServiceError("缺少参数 filename")
|
||||
if not token:
|
||||
raise BackupServiceError("缺少参数 token")
|
||||
if not jwt_secret:
|
||||
raise BackupServiceError("服务器配置错误")
|
||||
|
||||
try:
|
||||
jwt.decode(
|
||||
token,
|
||||
jwt_secret,
|
||||
algorithms=["HS256"],
|
||||
options={
|
||||
"require": ["exp"],
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
},
|
||||
)
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
raise BackupServiceError("Token 已过期,请刷新页面后重试") from exc
|
||||
except jwt.InvalidTokenError as exc:
|
||||
raise BackupServiceError("Token 无效") from exc
|
||||
|
||||
filename = self._validate_backup_filename(filename, missing="缺少参数 filename")
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
raise BackupServiceError("备份文件不存在")
|
||||
return BackupDownload(path=file_path, filename=filename)
|
||||
|
||||
def prepare_download_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
filename: str | None,
|
||||
token: str | None,
|
||||
jwt_secret: str | None = None,
|
||||
) -> BackupDownload:
|
||||
return self.prepare_download(
|
||||
filename=filename,
|
||||
token=token,
|
||||
jwt_secret=jwt_secret or self.config.get("dashboard", {}).get("jwt_secret"),
|
||||
)
|
||||
|
||||
def delete_backup(self, data: object) -> tuple[dict | None, str | None]:
|
||||
payload = self._payload(data)
|
||||
filename = self._validate_backup_filename(
|
||||
payload.get("filename"),
|
||||
missing="缺少参数 filename",
|
||||
)
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
raise BackupServiceError("备份文件不存在")
|
||||
|
||||
os.remove(file_path)
|
||||
return None, "删除备份成功"
|
||||
|
||||
def rename_backup(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
filename = self._validate_backup_filename(
|
||||
payload.get("filename"),
|
||||
missing="缺少参数 filename",
|
||||
)
|
||||
new_name = payload.get("new_name")
|
||||
if not new_name:
|
||||
raise BackupServiceError("缺少参数 new_name")
|
||||
|
||||
new_name = secure_filename(new_name)
|
||||
if new_name.endswith(".zip"):
|
||||
new_name = new_name[:-4]
|
||||
if not new_name or new_name.replace("_", "") == "":
|
||||
raise BackupServiceError("新文件名无效")
|
||||
|
||||
new_filename = f"{new_name}.zip"
|
||||
old_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(old_path):
|
||||
raise BackupServiceError("备份文件不存在")
|
||||
|
||||
new_path = os.path.join(self.backup_dir, new_filename)
|
||||
if os.path.exists(new_path):
|
||||
raise BackupServiceError(f"文件名 '{new_filename}' 已存在")
|
||||
|
||||
os.rename(old_path, new_path)
|
||||
logger.info(f"备份文件重命名: {filename} -> {new_filename}")
|
||||
return {
|
||||
"old_filename": filename,
|
||||
"new_filename": new_filename,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
1626
astrbot/dashboard/services/chat_service.py
Normal file
1626
astrbot/dashboard/services/chat_service.py
Normal file
File diff suppressed because it is too large
Load Diff
162
astrbot/dashboard/services/chatui_project_service.py
Normal file
162
astrbot/dashboard/services/chatui_project_service.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
|
||||
|
||||
class ChatUIProjectServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ChatUIProjectService:
|
||||
def __init__(self, db: BaseDatabase) -> None:
|
||||
self.db = db
|
||||
|
||||
async def create_project(self, username: str, data: object) -> dict:
|
||||
payload = self._as_payload(data)
|
||||
title = payload.get("title")
|
||||
emoji = payload.get("emoji", "📁")
|
||||
description = payload.get("description")
|
||||
|
||||
if not title:
|
||||
raise ChatUIProjectServiceError("Missing key: title")
|
||||
|
||||
project = await self.db.create_chatui_project(
|
||||
creator=username,
|
||||
title=title,
|
||||
emoji=emoji,
|
||||
description=description,
|
||||
)
|
||||
return self._serialize_project(project)
|
||||
|
||||
async def list_projects(self, username: str) -> list[dict]:
|
||||
projects = await self.db.get_chatui_projects_by_creator(creator=username)
|
||||
return [self._serialize_project(project) for project in projects]
|
||||
|
||||
async def get_project(self, username: str, project_id: str | None) -> dict:
|
||||
if not project_id:
|
||||
raise ChatUIProjectServiceError("Missing key: project_id")
|
||||
|
||||
project = await self._get_owned_project(username, project_id)
|
||||
return self._serialize_project(project)
|
||||
|
||||
async def get_project_from_legacy_query(
|
||||
self,
|
||||
username: str,
|
||||
project_id: str | None,
|
||||
) -> dict:
|
||||
return await self.get_project(username, project_id)
|
||||
|
||||
async def update_project(self, username: str, data: object) -> None:
|
||||
payload = self._as_payload(data)
|
||||
project_id = payload.get("project_id")
|
||||
if not project_id:
|
||||
raise ChatUIProjectServiceError("Missing key: project_id")
|
||||
|
||||
await self._get_owned_project(username, project_id)
|
||||
await self.db.update_chatui_project(
|
||||
project_id=project_id,
|
||||
title=payload.get("title"),
|
||||
emoji=payload.get("emoji"),
|
||||
description=payload.get("description"),
|
||||
)
|
||||
|
||||
async def delete_project(self, username: str, project_id: str | None) -> None:
|
||||
if not project_id:
|
||||
raise ChatUIProjectServiceError("Missing key: project_id")
|
||||
|
||||
await self._get_owned_project(username, project_id)
|
||||
await self.db.delete_chatui_project(project_id)
|
||||
|
||||
async def delete_project_from_legacy_query(
|
||||
self,
|
||||
username: str,
|
||||
project_id: str | None,
|
||||
) -> None:
|
||||
await self.delete_project(username, project_id)
|
||||
|
||||
async def add_session_to_project(self, username: str, data: object) -> None:
|
||||
payload = self._as_payload(data)
|
||||
session_id = payload.get("session_id")
|
||||
project_id = payload.get("project_id")
|
||||
|
||||
if not session_id:
|
||||
raise ChatUIProjectServiceError("Missing key: session_id")
|
||||
if not project_id:
|
||||
raise ChatUIProjectServiceError("Missing key: project_id")
|
||||
|
||||
await self._get_owned_project(username, project_id)
|
||||
await self._get_owned_session(username, session_id)
|
||||
await self.db.add_session_to_project(session_id, project_id)
|
||||
|
||||
async def remove_session_from_project(self, username: str, data: object) -> None:
|
||||
payload = self._as_payload(data)
|
||||
session_id = payload.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
raise ChatUIProjectServiceError("Missing key: session_id")
|
||||
|
||||
await self._get_owned_session(username, session_id)
|
||||
await self.db.remove_session_from_project(session_id)
|
||||
|
||||
async def get_project_sessions(
|
||||
self,
|
||||
username: str,
|
||||
project_id: str | None,
|
||||
) -> list[dict]:
|
||||
if not project_id:
|
||||
raise ChatUIProjectServiceError("Missing key: project_id")
|
||||
|
||||
await self._get_owned_project(username, project_id)
|
||||
sessions = await self.db.get_project_sessions(project_id)
|
||||
return [self._serialize_session(session) for session in sessions]
|
||||
|
||||
async def get_project_sessions_from_legacy_query(
|
||||
self,
|
||||
username: str,
|
||||
project_id: str | None,
|
||||
) -> list[dict]:
|
||||
return await self.get_project_sessions(username, project_id)
|
||||
|
||||
async def _get_owned_project(self, username: str, project_id: str):
|
||||
project = await self.db.get_chatui_project_by_id(project_id)
|
||||
if not project:
|
||||
raise ChatUIProjectServiceError(f"Project {project_id} not found")
|
||||
if project.creator != username:
|
||||
raise ChatUIProjectServiceError("Permission denied")
|
||||
return project
|
||||
|
||||
async def _get_owned_session(self, username: str, session_id: str):
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
if not session:
|
||||
raise ChatUIProjectServiceError(f"Session {session_id} not found")
|
||||
if session.creator != username:
|
||||
raise ChatUIProjectServiceError("Permission denied")
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def _serialize_project(project) -> dict:
|
||||
return {
|
||||
"project_id": project.project_id,
|
||||
"title": project.title,
|
||||
"emoji": project.emoji,
|
||||
"description": project.description,
|
||||
"created_at": to_utc_isoformat(project.created_at),
|
||||
"updated_at": to_utc_isoformat(project.updated_at),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_session(session) -> dict:
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
"platform_id": session.platform_id,
|
||||
"creator": session.creator,
|
||||
"display_name": session.display_name,
|
||||
"is_group": session.is_group,
|
||||
"created_at": to_utc_isoformat(session.created_at),
|
||||
"updated_at": to_utc_isoformat(session.updated_at),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _as_payload(data: object) -> dict:
|
||||
return data if isinstance(data, dict) else {}
|
||||
133
astrbot/dashboard/services/command_service.py
Normal file
133
astrbot/dashboard/services/command_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.star.command_management import (
|
||||
list_command_conflicts,
|
||||
list_commands,
|
||||
rename_command,
|
||||
toggle_command,
|
||||
update_command_permission,
|
||||
)
|
||||
|
||||
|
||||
class CommandServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CommandService:
|
||||
def __init__(
|
||||
self,
|
||||
config: AstrBotConfig,
|
||||
core_lifecycle: AstrBotCoreLifecycle | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.core_lifecycle = core_lifecycle
|
||||
|
||||
async def list_commands(self, config_id: str = "") -> dict:
|
||||
commands = await list_commands()
|
||||
summary = {
|
||||
"total": len(commands),
|
||||
"disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
|
||||
"conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
|
||||
}
|
||||
wake_prefix = self._get_wake_prefix(config_id)
|
||||
return {
|
||||
"items": commands,
|
||||
"summary": summary,
|
||||
"wake_prefix": wake_prefix,
|
||||
}
|
||||
|
||||
async def list_commands_from_legacy_query(self, config_id: str | None) -> dict:
|
||||
return await self.list_commands(config_id or "")
|
||||
|
||||
async def list_conflicts(self):
|
||||
return await list_command_conflicts()
|
||||
|
||||
async def toggle_command(self, handler_full_name: str | None, enabled) -> dict:
|
||||
if handler_full_name is None or enabled is None:
|
||||
raise CommandServiceError("handler_full_name 与 enabled 均为必填。")
|
||||
|
||||
if isinstance(enabled, str):
|
||||
enabled = enabled.lower() in ("1", "true", "yes", "on")
|
||||
|
||||
try:
|
||||
await toggle_command(handler_full_name, bool(enabled))
|
||||
except ValueError as exc:
|
||||
raise CommandServiceError(str(exc)) from exc
|
||||
|
||||
return await self._get_command_payload(handler_full_name)
|
||||
|
||||
async def toggle_command_from_legacy_payload(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
return await self.toggle_command(
|
||||
payload.get("handler_full_name"),
|
||||
payload.get("enabled"),
|
||||
)
|
||||
|
||||
async def rename_command(
|
||||
self,
|
||||
handler_full_name: str | None,
|
||||
new_name: str | None,
|
||||
aliases=None,
|
||||
) -> dict:
|
||||
if not handler_full_name or not new_name:
|
||||
raise CommandServiceError("handler_full_name 与 new_name 均为必填。")
|
||||
|
||||
try:
|
||||
await rename_command(handler_full_name, new_name, aliases=aliases)
|
||||
except ValueError as exc:
|
||||
raise CommandServiceError(str(exc)) from exc
|
||||
|
||||
return await self._get_command_payload(handler_full_name)
|
||||
|
||||
async def rename_command_from_legacy_payload(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
return await self.rename_command(
|
||||
payload.get("handler_full_name"),
|
||||
payload.get("new_name"),
|
||||
aliases=payload.get("aliases"),
|
||||
)
|
||||
|
||||
async def update_permission(
|
||||
self,
|
||||
handler_full_name: str | None,
|
||||
permission: str | None,
|
||||
) -> dict:
|
||||
if not handler_full_name or not permission:
|
||||
raise CommandServiceError("handler_full_name 与 permission 均为必填。")
|
||||
|
||||
try:
|
||||
await update_command_permission(handler_full_name, permission)
|
||||
except ValueError as exc:
|
||||
raise CommandServiceError(str(exc)) from exc
|
||||
|
||||
return await self._get_command_payload(handler_full_name)
|
||||
|
||||
async def update_permission_from_legacy_payload(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
return await self.update_permission(
|
||||
payload.get("handler_full_name"),
|
||||
payload.get("permission"),
|
||||
)
|
||||
|
||||
def _get_wake_prefix(self, config_id: str) -> list:
|
||||
wake_prefix = self.config.get("wake_prefix", ["/"])
|
||||
config_id = config_id.strip()
|
||||
if config_id and self.core_lifecycle:
|
||||
config_mgr = getattr(self.core_lifecycle, "astrbot_config_mgr", None)
|
||||
if config_mgr and config_id in config_mgr.confs:
|
||||
return config_mgr.confs[config_id].get("wake_prefix", wake_prefix)
|
||||
return wake_prefix
|
||||
|
||||
@staticmethod
|
||||
async def _get_command_payload(handler_full_name: str) -> dict:
|
||||
commands = await list_commands()
|
||||
for cmd in commands:
|
||||
if cmd["handler_full_name"] == handler_full_name:
|
||||
return cmd
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _payload(data: object) -> dict:
|
||||
return data if isinstance(data, dict) else {}
|
||||
1701
astrbot/dashboard/services/config_service.py
Normal file
1701
astrbot/dashboard/services/config_service.py
Normal file
File diff suppressed because it is too large
Load Diff
337
astrbot/dashboard/services/conversation_service.py
Normal file
337
astrbot/dashboard/services/conversation_service.py
Normal file
@@ -0,0 +1,337 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import traceback
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.umo_alias import build_umo_alias_map, parse_umo, serialize_umo_alias
|
||||
|
||||
|
||||
class ConversationServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationExport:
|
||||
file_obj: BytesIO
|
||||
filename: str
|
||||
mimetype: str = "application/jsonl"
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def __init__(
|
||||
self,
|
||||
db_helper: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
self.db_helper = db_helper
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
|
||||
async def list_conversations(
|
||||
self,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
platforms: str,
|
||||
message_types: str,
|
||||
search_query: str,
|
||||
exclude_ids: str,
|
||||
exclude_platforms: str,
|
||||
) -> dict:
|
||||
platform_list = platforms.split(",") if platforms else []
|
||||
message_type_list = message_types.split(",") if message_types else []
|
||||
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
|
||||
exclude_platform_list = (
|
||||
exclude_platforms.split(",") if exclude_platforms else []
|
||||
)
|
||||
|
||||
page = max(page, 1)
|
||||
if page_size < 1:
|
||||
page_size = 20
|
||||
page_size = min(page_size, 100)
|
||||
|
||||
try:
|
||||
conversations, total_count = await self.conv_mgr.get_filtered_conversations(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
platforms=platform_list,
|
||||
message_types=message_type_list,
|
||||
search_query=search_query,
|
||||
exclude_ids=exclude_id_list,
|
||||
exclude_platforms=exclude_platform_list,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"数据库查询出错: {exc!s}\n{traceback.format_exc()}")
|
||||
raise ConversationServiceError(f"数据库查询出错: {exc!s}") from exc
|
||||
|
||||
total_pages = (
|
||||
(total_count + page_size - 1) // page_size if total_count > 0 else 1
|
||||
)
|
||||
umos = sorted({conv.user_id for conv in conversations if conv.user_id})
|
||||
alias_map = build_umo_alias_map(await self.db_helper.get_umo_aliases(umos))
|
||||
|
||||
return {
|
||||
"conversations": [
|
||||
self._serialize_conversation(conversation, alias_map)
|
||||
for conversation in conversations
|
||||
],
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total_count,
|
||||
"total_pages": total_pages,
|
||||
},
|
||||
}
|
||||
|
||||
async def list_conversations_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
page,
|
||||
page_size,
|
||||
platforms: str | None,
|
||||
message_types: str | None,
|
||||
search_query: str | None,
|
||||
exclude_ids: str | None,
|
||||
exclude_platforms: str | None,
|
||||
) -> dict:
|
||||
return await self.list_conversations(
|
||||
page=self._to_int(page, 1),
|
||||
page_size=self._to_int(page_size, 20),
|
||||
platforms=platforms or "",
|
||||
message_types=message_types or "",
|
||||
search_query=search_query or "",
|
||||
exclude_ids=exclude_ids or "",
|
||||
exclude_platforms=exclude_platforms or "",
|
||||
)
|
||||
|
||||
async def get_conversation_detail(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
user_id, cid = self._require_user_and_cid(payload)
|
||||
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
if not conversation:
|
||||
raise ConversationServiceError("对话不存在")
|
||||
|
||||
alias_map = build_umo_alias_map(await self.db_helper.get_umo_aliases([user_id]))
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"cid": cid,
|
||||
"title": conversation.title,
|
||||
"persona_id": conversation.persona_id,
|
||||
"history": conversation.history,
|
||||
"created_at": conversation.created_at,
|
||||
"updated_at": conversation.updated_at,
|
||||
"umo_info": self._build_umo_info(user_id, alias_map),
|
||||
}
|
||||
|
||||
async def update_conversation(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
user_id, cid = self._require_user_and_cid(payload)
|
||||
title = payload.get("title")
|
||||
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
if not conversation:
|
||||
raise ConversationServiceError("对话不存在")
|
||||
|
||||
persona_id = payload.get("persona_id", conversation.persona_id)
|
||||
|
||||
if title is not None or persona_id is not None:
|
||||
await self.conv_mgr.update_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
return {"message": "对话信息更新成功"}
|
||||
|
||||
async def delete_conversation(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
if "conversations" in payload:
|
||||
return await self._delete_conversations(payload.get("conversations", []))
|
||||
|
||||
user_id, cid = self._require_user_and_cid(payload)
|
||||
await self.conv_mgr.delete_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
return {"message": "对话删除成功"}
|
||||
|
||||
async def update_history(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
user_id, cid = self._require_user_and_cid(payload)
|
||||
history = payload.get("history")
|
||||
|
||||
if history is None:
|
||||
raise ConversationServiceError("缺少必要参数: history")
|
||||
|
||||
history = self._normalize_history(history)
|
||||
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
if not conversation:
|
||||
raise ConversationServiceError("对话不存在")
|
||||
|
||||
await self.conv_mgr.update_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
history=history,
|
||||
)
|
||||
|
||||
return {"message": "对话历史更新成功"}
|
||||
|
||||
async def export_conversations(self, data: object) -> ConversationExport:
|
||||
payload = self._payload(data)
|
||||
conversations_to_export = payload.get("conversations", [])
|
||||
|
||||
if not conversations_to_export:
|
||||
raise ConversationServiceError("导出列表不能为空")
|
||||
|
||||
jsonl_lines = []
|
||||
exported_count = 0
|
||||
failed_items = []
|
||||
|
||||
for conv_info in conversations_to_export:
|
||||
user_id = conv_info.get("user_id")
|
||||
cid = conv_info.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - 缺少必要参数")
|
||||
continue
|
||||
|
||||
try:
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - 对话不存在")
|
||||
continue
|
||||
|
||||
content = json.loads(conversation.history)
|
||||
export_record = {
|
||||
"cid": cid,
|
||||
"user_id": user_id,
|
||||
"platform_id": conversation.platform_id,
|
||||
"title": conversation.title,
|
||||
"persona_id": conversation.persona_id,
|
||||
"created_at": conversation.created_at,
|
||||
"updated_at": conversation.updated_at,
|
||||
"content": content,
|
||||
}
|
||||
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
|
||||
exported_count += 1
|
||||
except Exception as exc:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - {exc!s}")
|
||||
logger.error(
|
||||
f"导出对话失败: user_id={user_id}, cid={cid}, error={exc!s}"
|
||||
)
|
||||
|
||||
if exported_count == 0:
|
||||
raise ConversationServiceError("没有成功导出任何对话")
|
||||
|
||||
jsonl_content = "\n".join(jsonl_lines)
|
||||
file_obj = BytesIO(jsonl_content.encode("utf-8"))
|
||||
file_obj.seek(0)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return ConversationExport(
|
||||
file_obj=file_obj,
|
||||
filename=f"astrbot_conversations_export_{timestamp}.jsonl",
|
||||
)
|
||||
|
||||
async def _delete_conversations(self, conversations: object) -> dict:
|
||||
if not conversations:
|
||||
raise ConversationServiceError("批量删除时conversations参数不能为空")
|
||||
|
||||
deleted_count = 0
|
||||
failed_items = []
|
||||
|
||||
for conv in conversations:
|
||||
user_id = conv.get("user_id")
|
||||
cid = conv.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - 缺少必要参数")
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.conv_mgr.delete_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
deleted_count += 1
|
||||
except Exception as exc:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - {exc!s}")
|
||||
|
||||
message = f"成功删除 {deleted_count} 个对话"
|
||||
if failed_items:
|
||||
message += f",失败 {len(failed_items)} 个"
|
||||
|
||||
return {
|
||||
"message": message,
|
||||
"deleted_count": deleted_count,
|
||||
"failed_count": len(failed_items),
|
||||
"failed_items": failed_items,
|
||||
}
|
||||
|
||||
def _serialize_conversation(self, conversation, alias_map: dict) -> dict:
|
||||
return {
|
||||
**asdict(conversation),
|
||||
"umo_info": self._build_umo_info(conversation.user_id, alias_map),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_umo_info(umo: str | None, alias_map: dict) -> dict:
|
||||
umo_str = umo or ""
|
||||
return {
|
||||
"umo": umo_str,
|
||||
**parse_umo(umo_str),
|
||||
**serialize_umo_alias(alias_map.get(umo_str), umo_str),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _require_user_and_cid(payload: dict) -> tuple[str, str]:
|
||||
user_id = payload.get("user_id")
|
||||
cid = payload.get("cid")
|
||||
if not user_id or not cid:
|
||||
raise ConversationServiceError("缺少必要参数: user_id 和 cid")
|
||||
return user_id, cid
|
||||
|
||||
@staticmethod
|
||||
def _normalize_history(history):
|
||||
try:
|
||||
if isinstance(history, list):
|
||||
history = json.dumps(history)
|
||||
else:
|
||||
json.loads(history)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ConversationServiceError(
|
||||
"history 必须是有效的 JSON 字符串或数组"
|
||||
) from exc
|
||||
|
||||
return json.loads(history) if isinstance(history, str) else history
|
||||
|
||||
@staticmethod
|
||||
def _payload(data: object) -> dict:
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
285
astrbot/dashboard/services/cron_service.py
Normal file
285
astrbot/dashboard/services/cron_service.py
Normal file
@@ -0,0 +1,285 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
|
||||
class CronServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CronService:
|
||||
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _get_cron_manager(self):
|
||||
cron_mgr = self.core_lifecycle.cron_manager
|
||||
if cron_mgr is None:
|
||||
raise CronServiceError("Cron manager not initialized")
|
||||
return cron_mgr
|
||||
|
||||
@staticmethod
|
||||
def serialize_job(job) -> dict:
|
||||
data = job.model_dump() if hasattr(job, "model_dump") else job.__dict__
|
||||
for key in ["created_at", "updated_at", "last_run_at", "next_run_time"]:
|
||||
value = data.get(key)
|
||||
if isinstance(value, datetime):
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
data[key] = value.isoformat()
|
||||
|
||||
payload = data.get("payload") or {}
|
||||
data["note"] = payload.get("note") or data.get("description") or ""
|
||||
data["run_at"] = payload.get("run_at")
|
||||
data["run_once"] = data.get("run_once", False)
|
||||
data.pop("status", None)
|
||||
return data
|
||||
|
||||
async def list_jobs(self, job_type: str | None = None) -> list[dict]:
|
||||
try:
|
||||
cron_mgr = self._get_cron_manager()
|
||||
jobs = await cron_mgr.list_jobs(job_type)
|
||||
return [self.serialize_job(job) for job in jobs]
|
||||
except CronServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise CronServiceError(f"Failed to list jobs: {exc!s}") from exc
|
||||
|
||||
async def list_jobs_from_legacy_query(
|
||||
self,
|
||||
job_type: str | None,
|
||||
) -> list[dict]:
|
||||
return await self.list_jobs(job_type)
|
||||
|
||||
async def create_job(self, payload: object) -> dict:
|
||||
try:
|
||||
cron_mgr = self._get_cron_manager()
|
||||
if not isinstance(payload, dict):
|
||||
raise CronServiceError("Invalid payload")
|
||||
|
||||
name = payload.get("name") or "active_agent_task"
|
||||
cron_expression = payload.get("cron_expression")
|
||||
note = payload.get("note") or payload.get("description") or name
|
||||
session = str(payload.get("session") or "").strip()
|
||||
persona_id = payload.get("persona_id")
|
||||
provider_id = payload.get("provider_id")
|
||||
timezone_name = payload.get("timezone")
|
||||
enabled = bool(payload.get("enabled", True))
|
||||
run_once = bool(payload.get("run_once", False))
|
||||
run_at = payload.get("run_at")
|
||||
|
||||
if run_once and not run_at:
|
||||
raise CronServiceError("run_at is required when run_once=true")
|
||||
if (not run_once) and not cron_expression:
|
||||
raise CronServiceError(
|
||||
"cron_expression is required when run_once=false"
|
||||
)
|
||||
if run_once and cron_expression:
|
||||
cron_expression = None
|
||||
|
||||
run_at_dt = self._parse_optional_run_at(run_at)
|
||||
job_payload = {
|
||||
"session": session,
|
||||
"note": note,
|
||||
"persona_id": persona_id,
|
||||
"provider_id": provider_id,
|
||||
"run_at": run_at,
|
||||
"origin": "api",
|
||||
}
|
||||
|
||||
job = await cron_mgr.add_active_job(
|
||||
name=name,
|
||||
cron_expression=cron_expression,
|
||||
payload=job_payload,
|
||||
description=note,
|
||||
timezone=timezone_name,
|
||||
enabled=enabled,
|
||||
run_once=run_once,
|
||||
run_at=run_at_dt,
|
||||
)
|
||||
return self.serialize_job(job)
|
||||
except CronServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise CronServiceError(f"Failed to create job: {exc!s}") from exc
|
||||
|
||||
async def update_job(self, job_id: str, payload: object) -> dict:
|
||||
try:
|
||||
cron_mgr = self._get_cron_manager()
|
||||
if not isinstance(payload, dict):
|
||||
raise CronServiceError("Invalid payload")
|
||||
|
||||
job = await cron_mgr.db.get_cron_job(job_id)
|
||||
if not job:
|
||||
raise CronServiceError("Job not found")
|
||||
|
||||
updates = {}
|
||||
if "name" in payload:
|
||||
name = str(payload.get("name") or "").strip()
|
||||
if not name:
|
||||
raise CronServiceError("name cannot be empty")
|
||||
updates["name"] = name
|
||||
|
||||
if "enabled" in payload:
|
||||
updates["enabled"] = bool(payload.get("enabled"))
|
||||
|
||||
if "timezone" in payload:
|
||||
timezone_name = payload.get("timezone")
|
||||
updates["timezone"] = str(timezone_name).strip() or None
|
||||
|
||||
if job.job_type == "active_agent":
|
||||
self._merge_active_agent_updates(job, payload, updates)
|
||||
else:
|
||||
self._merge_generic_updates(payload, updates)
|
||||
|
||||
updated_job = await cron_mgr.update_job(job_id, **updates)
|
||||
if not updated_job:
|
||||
raise CronServiceError("Job not found")
|
||||
return self.serialize_job(updated_job)
|
||||
except CronServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise CronServiceError(f"Failed to update job: {exc!s}") from exc
|
||||
|
||||
async def delete_job(self, job_id: str) -> None:
|
||||
try:
|
||||
cron_mgr = self._get_cron_manager()
|
||||
await cron_mgr.delete_job(job_id)
|
||||
except CronServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise CronServiceError(f"Failed to delete job: {exc!s}") from exc
|
||||
|
||||
async def run_job_now(self, job_id: str) -> None:
|
||||
try:
|
||||
cron_mgr = self._get_cron_manager()
|
||||
job = await cron_mgr.db.get_cron_job(job_id)
|
||||
if not job:
|
||||
raise CronServiceError("Job not found")
|
||||
task = asyncio.create_task(cron_mgr.run_job_now(job_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
except CronServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise CronServiceError(f"Failed to run job: {exc!s}") from exc
|
||||
|
||||
@staticmethod
|
||||
def _parse_optional_run_at(run_at: object) -> datetime | None:
|
||||
if not run_at:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(str(run_at))
|
||||
except Exception as exc:
|
||||
raise CronServiceError("run_at must be ISO datetime") from exc
|
||||
|
||||
@staticmethod
|
||||
def _normalize_run_at_iso(run_at: object) -> str | None:
|
||||
if not run_at:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(str(run_at)).isoformat()
|
||||
except Exception as exc:
|
||||
raise CronServiceError("run_at must be ISO datetime") from exc
|
||||
|
||||
def _merge_active_agent_updates(self, job, payload: dict, updates: dict) -> None:
|
||||
merged_payload = dict(job.payload) if isinstance(job.payload, dict) else {}
|
||||
if "payload" in payload and isinstance(payload.get("payload"), dict):
|
||||
merged_payload.update(payload["payload"])
|
||||
|
||||
if "session" in payload:
|
||||
session = str(payload.get("session") or "").strip()
|
||||
if session:
|
||||
merged_payload["session"] = session
|
||||
else:
|
||||
merged_payload.pop("session", None)
|
||||
|
||||
self._merge_note(payload, job, merged_payload, updates)
|
||||
|
||||
next_run_once = (
|
||||
bool(payload.get("run_once"))
|
||||
if "run_once" in payload
|
||||
else bool(job.run_once)
|
||||
)
|
||||
next_cron_expression = (
|
||||
payload.get("cron_expression")
|
||||
if "cron_expression" in payload
|
||||
else job.cron_expression
|
||||
)
|
||||
if next_cron_expression is not None:
|
||||
next_cron_expression = str(next_cron_expression).strip() or None
|
||||
|
||||
run_at_raw = (
|
||||
payload.get("run_at")
|
||||
if "run_at" in payload
|
||||
else merged_payload.get("run_at")
|
||||
)
|
||||
run_at_iso = self._normalize_run_at_iso(run_at_raw)
|
||||
|
||||
if next_run_once:
|
||||
if not run_at_iso:
|
||||
raise CronServiceError("run_at is required when run_once=true")
|
||||
next_cron_expression = None
|
||||
merged_payload["run_at"] = run_at_iso
|
||||
else:
|
||||
if not next_cron_expression:
|
||||
raise CronServiceError(
|
||||
"cron_expression is required when run_once=false"
|
||||
)
|
||||
merged_payload.pop("run_at", None)
|
||||
|
||||
updates["run_once"] = next_run_once
|
||||
updates["cron_expression"] = next_cron_expression
|
||||
updates["payload"] = merged_payload
|
||||
|
||||
@staticmethod
|
||||
def _merge_note(
|
||||
payload: dict,
|
||||
job,
|
||||
merged_payload: dict,
|
||||
updates: dict,
|
||||
) -> None:
|
||||
note_updated = False
|
||||
if "note" in payload:
|
||||
note = str(payload.get("note") or "").strip()
|
||||
if not note:
|
||||
raise CronServiceError("note cannot be empty")
|
||||
merged_payload["note"] = note
|
||||
updates["description"] = note
|
||||
note_updated = True
|
||||
elif "description" in payload:
|
||||
description = str(payload.get("description") or "").strip()
|
||||
if not description:
|
||||
raise CronServiceError("description cannot be empty")
|
||||
updates["description"] = description
|
||||
merged_payload["note"] = description
|
||||
note_updated = True
|
||||
|
||||
if not note_updated and updates.get("description") is None:
|
||||
existing_note = str(
|
||||
merged_payload.get("note") or job.description or ""
|
||||
).strip()
|
||||
if existing_note:
|
||||
merged_payload["note"] = existing_note
|
||||
|
||||
@staticmethod
|
||||
def _merge_generic_updates(payload: dict, updates: dict) -> None:
|
||||
if "cron_expression" in payload:
|
||||
cron_expression = str(payload.get("cron_expression") or "").strip()
|
||||
if not cron_expression:
|
||||
raise CronServiceError("cron_expression cannot be empty")
|
||||
updates["cron_expression"] = cron_expression
|
||||
|
||||
if "description" in payload:
|
||||
description = str(payload.get("description") or "").strip()
|
||||
updates["description"] = description or None
|
||||
15
astrbot/dashboard/services/file_service.py
Normal file
15
astrbot/dashboard/services/file_service.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from astrbot.core import file_token_service
|
||||
|
||||
|
||||
class FileServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FileService:
|
||||
async def resolve_token_file(self, file_token: str) -> str:
|
||||
try:
|
||||
return await file_token_service.handle_file(file_token)
|
||||
except (FileNotFoundError, KeyError) as exc:
|
||||
raise FileServiceError(str(exc)) from exc
|
||||
863
astrbot/dashboard/services/knowledge_base_service.py
Normal file
863
astrbot/dashboard/services/knowledge_base_service.py
Normal file
@@ -0,0 +1,863 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.dashboard.utils import generate_tsne_visualization
|
||||
|
||||
|
||||
class KnowledgeBaseServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class KnowledgeBaseService:
|
||||
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.upload_progress: dict[str, dict[str, Any]] = {}
|
||||
self.upload_tasks: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _payload(data: object) -> dict[str, Any]:
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
def get_kb_manager(self):
|
||||
return self.core_lifecycle.kb_manager
|
||||
|
||||
def init_task(self, task_id: str, status: str = "pending") -> None:
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": status,
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
def set_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
result: Any = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error,
|
||||
}
|
||||
if task_id in self.upload_progress:
|
||||
self.upload_progress[task_id]["status"] = status
|
||||
|
||||
def update_progress(
|
||||
self,
|
||||
task_id: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
file_index: int | None = None,
|
||||
file_name: str | None = None,
|
||||
stage: str | None = None,
|
||||
current: int | None = None,
|
||||
total: int | None = None,
|
||||
) -> None:
|
||||
if task_id not in self.upload_progress:
|
||||
return
|
||||
progress = self.upload_progress[task_id]
|
||||
if status is not None:
|
||||
progress["status"] = status
|
||||
if file_index is not None:
|
||||
progress["file_index"] = file_index
|
||||
if file_name is not None:
|
||||
progress["file_name"] = file_name
|
||||
if stage is not None:
|
||||
progress["stage"] = stage
|
||||
if current is not None:
|
||||
progress["current"] = current
|
||||
if total is not None:
|
||||
progress["total"] = total
|
||||
|
||||
def make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
|
||||
async def _callback(stage: str, current: int, total: int) -> None:
|
||||
self.update_progress(
|
||||
task_id,
|
||||
status="processing",
|
||||
file_index=file_idx,
|
||||
file_name=file_name,
|
||||
stage=stage,
|
||||
current=current,
|
||||
total=total,
|
||||
)
|
||||
|
||||
return _callback
|
||||
|
||||
@staticmethod
|
||||
def format_failed_doc_error(file_name: str, error: Exception) -> str:
|
||||
message = str(error).strip() or "上传失败:发生未知错误。"
|
||||
if message.startswith(file_name):
|
||||
return message
|
||||
return f"{file_name}: {message}"
|
||||
|
||||
async def background_upload_task(
|
||||
self,
|
||||
task_id: str,
|
||||
kb_helper,
|
||||
files_to_upload: list[dict[str, Any]],
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
batch_size: int,
|
||||
tasks_limit: int,
|
||||
max_retries: int,
|
||||
) -> None:
|
||||
try:
|
||||
self.init_task(task_id, status="processing")
|
||||
self.upload_progress[task_id] = {
|
||||
"status": "processing",
|
||||
"file_index": 0,
|
||||
"file_total": len(files_to_upload),
|
||||
"stage": "waiting",
|
||||
"current": 0,
|
||||
"total": 100,
|
||||
}
|
||||
|
||||
uploaded_docs = []
|
||||
failed_docs = []
|
||||
|
||||
for file_idx, file_info in enumerate(files_to_upload):
|
||||
try:
|
||||
self.update_progress(
|
||||
task_id,
|
||||
status="processing",
|
||||
file_index=file_idx,
|
||||
file_name=file_info["file_name"],
|
||||
stage="parsing",
|
||||
current=0,
|
||||
total=100,
|
||||
)
|
||||
progress_callback = self.make_progress_callback(
|
||||
task_id, file_idx, file_info["file_name"]
|
||||
)
|
||||
doc = await kb_helper.upload_document(
|
||||
file_name=file_info["file_name"],
|
||||
file_content=file_info["file_content"],
|
||||
file_type=file_info["file_type"],
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
uploaded_docs.append(doc.model_dump())
|
||||
except Exception as exc:
|
||||
logger.error(f"上传文档 {file_info['file_name']} 失败: {exc}")
|
||||
failed_docs.append(
|
||||
{
|
||||
"file_name": file_info["file_name"],
|
||||
"error": self.format_failed_doc_error(
|
||||
file_info["file_name"], exc
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
self.set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result={
|
||||
"task_id": task_id,
|
||||
"uploaded": uploaded_docs,
|
||||
"failed": failed_docs,
|
||||
"total": len(files_to_upload),
|
||||
"success_count": len(uploaded_docs),
|
||||
"failed_count": len(failed_docs),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"后台上传任务 {task_id} 失败: {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
self.set_task_result(task_id, "failed", error=str(exc))
|
||||
|
||||
async def background_import_task(
|
||||
self,
|
||||
task_id: str,
|
||||
kb_helper,
|
||||
documents: list[dict[str, Any]],
|
||||
batch_size: int,
|
||||
tasks_limit: int,
|
||||
max_retries: int,
|
||||
) -> None:
|
||||
try:
|
||||
self.init_task(task_id, status="processing")
|
||||
self.upload_progress[task_id] = {
|
||||
"status": "processing",
|
||||
"file_index": 0,
|
||||
"file_total": len(documents),
|
||||
"stage": "waiting",
|
||||
"current": 0,
|
||||
"total": 100,
|
||||
}
|
||||
|
||||
uploaded_docs = []
|
||||
failed_docs = []
|
||||
|
||||
for file_idx, doc_info in enumerate(documents):
|
||||
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
|
||||
chunks = doc_info.get("chunks", [])
|
||||
|
||||
try:
|
||||
self.update_progress(
|
||||
task_id,
|
||||
status="processing",
|
||||
file_index=file_idx,
|
||||
file_name=file_name,
|
||||
stage="importing",
|
||||
current=0,
|
||||
total=100,
|
||||
)
|
||||
progress_callback = self.make_progress_callback(
|
||||
task_id, file_idx, file_name
|
||||
)
|
||||
doc = await kb_helper.upload_document(
|
||||
file_name=file_name,
|
||||
file_content=None,
|
||||
file_type=doc_info.get("file_type")
|
||||
or (
|
||||
file_name.rsplit(".", 1)[-1].lower()
|
||||
if "." in file_name
|
||||
else "txt"
|
||||
),
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
pre_chunked_text=chunks,
|
||||
)
|
||||
uploaded_docs.append(doc.model_dump())
|
||||
except Exception as exc:
|
||||
logger.error(f"导入文档 {file_name} 失败: {exc}")
|
||||
failed_docs.append(
|
||||
{
|
||||
"file_name": file_name,
|
||||
"error": self.format_failed_doc_error(file_name, exc),
|
||||
},
|
||||
)
|
||||
|
||||
self.set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result={
|
||||
"task_id": task_id,
|
||||
"uploaded": uploaded_docs,
|
||||
"failed": failed_docs,
|
||||
"total": len(documents),
|
||||
"success_count": len(uploaded_docs),
|
||||
"failed_count": len(failed_docs),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"后台导入任务 {task_id} 失败: {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
self.set_task_result(task_id, "failed", error=str(exc))
|
||||
|
||||
async def list_kbs(self, *, page: int, page_size: int) -> dict[str, Any]:
|
||||
kb_manager = self.get_kb_manager()
|
||||
kbs = await kb_manager.list_kbs()
|
||||
|
||||
kb_list = []
|
||||
for kb in kbs:
|
||||
kb_dict = kb.model_dump()
|
||||
kb_helper = await kb_manager.get_kb(kb.kb_id)
|
||||
if kb_helper and kb_helper.init_error:
|
||||
kb_dict["init_error"] = kb_helper.init_error
|
||||
kb_list.append(kb_dict)
|
||||
|
||||
return {"items": kb_list, "page": page, "page_size": page_size}
|
||||
|
||||
async def list_kbs_from_legacy_query(self, *, page, page_size) -> dict[str, Any]:
|
||||
return await self.list_kbs(
|
||||
page=self._to_int(page, 1),
|
||||
page_size=self._to_int(page_size, 20),
|
||||
)
|
||||
|
||||
async def create_kb(self, data: object) -> tuple[dict[str, Any], str]:
|
||||
kb_manager = self.get_kb_manager()
|
||||
payload = self._payload(data)
|
||||
kb_name = payload.get("kb_name")
|
||||
if not kb_name:
|
||||
raise KnowledgeBaseServiceError("知识库名称不能为空")
|
||||
|
||||
embedding_provider_id = payload.get("embedding_provider_id")
|
||||
rerank_provider_id = payload.get("rerank_provider_id")
|
||||
|
||||
if not embedding_provider_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 embedding_provider_id")
|
||||
provider = await kb_manager.provider_manager.get_provider_by_id(
|
||||
embedding_provider_id,
|
||||
)
|
||||
if not provider or not isinstance(provider, EmbeddingProvider):
|
||||
raise KnowledgeBaseServiceError(
|
||||
f"嵌入模型不存在或类型错误({type(provider)})"
|
||||
)
|
||||
try:
|
||||
vec = await provider.get_embedding("astrbot")
|
||||
if len(vec) != provider.get_dim():
|
||||
raise ValueError(
|
||||
f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {provider.get_dim()}",
|
||||
)
|
||||
except Exception as exc:
|
||||
raise KnowledgeBaseServiceError(f"测试嵌入模型失败: {exc!s}") from exc
|
||||
|
||||
if rerank_provider_id:
|
||||
rerank_provider: RerankProvider = (
|
||||
await kb_manager.provider_manager.get_provider_by_id(
|
||||
rerank_provider_id,
|
||||
)
|
||||
)
|
||||
if not rerank_provider:
|
||||
raise KnowledgeBaseServiceError("重排序模型不存在")
|
||||
try:
|
||||
result = await rerank_provider.rerank(
|
||||
query="astrbot",
|
||||
documents=["astrbot knowledge base"],
|
||||
)
|
||||
if not result:
|
||||
raise ValueError("重排序模型返回结果异常")
|
||||
except Exception as exc:
|
||||
raise KnowledgeBaseServiceError(
|
||||
f"测试重排序模型失败: {exc!s},请检查平台日志输出。"
|
||||
) from exc
|
||||
|
||||
kb_helper = await kb_manager.create_kb(
|
||||
kb_name=kb_name,
|
||||
description=payload.get("description"),
|
||||
emoji=payload.get("emoji"),
|
||||
embedding_provider_id=embedding_provider_id,
|
||||
rerank_provider_id=rerank_provider_id,
|
||||
chunk_size=payload.get("chunk_size"),
|
||||
chunk_overlap=payload.get("chunk_overlap"),
|
||||
top_k_dense=payload.get("top_k_dense"),
|
||||
top_k_sparse=payload.get("top_k_sparse"),
|
||||
top_m_final=payload.get("top_m_final"),
|
||||
)
|
||||
return kb_helper.kb.model_dump(), "创建知识库成功"
|
||||
|
||||
async def get_kb(self, kb_id: str | None) -> dict[str, Any]:
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
return kb_helper.kb.model_dump()
|
||||
|
||||
async def get_kb_from_legacy_query(self, kb_id: str | None) -> dict[str, Any]:
|
||||
return await self.get_kb(kb_id)
|
||||
|
||||
async def update_kb(self, data: object) -> tuple[dict[str, Any], str]:
|
||||
payload = self._payload(data)
|
||||
kb_id = payload.get("kb_id")
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
|
||||
update_keys = [
|
||||
"kb_name",
|
||||
"description",
|
||||
"emoji",
|
||||
"embedding_provider_id",
|
||||
"rerank_provider_id",
|
||||
"chunk_size",
|
||||
"chunk_overlap",
|
||||
"top_k_dense",
|
||||
"top_k_sparse",
|
||||
"top_m_final",
|
||||
]
|
||||
if all(payload.get(key) is None for key in update_keys):
|
||||
raise KnowledgeBaseServiceError("至少需要提供一个更新字段")
|
||||
|
||||
kb_helper = await self.get_kb_manager().update_kb(
|
||||
kb_id=kb_id,
|
||||
kb_name=payload.get("kb_name"),
|
||||
description=payload.get("description"),
|
||||
emoji=payload.get("emoji"),
|
||||
embedding_provider_id=payload.get("embedding_provider_id"),
|
||||
rerank_provider_id=payload.get("rerank_provider_id"),
|
||||
chunk_size=payload.get("chunk_size"),
|
||||
chunk_overlap=payload.get("chunk_overlap"),
|
||||
top_k_dense=payload.get("top_k_dense"),
|
||||
top_k_sparse=payload.get("top_k_sparse"),
|
||||
top_m_final=payload.get("top_m_final"),
|
||||
)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
return kb_helper.kb.model_dump(), "更新知识库成功"
|
||||
|
||||
async def delete_kb(self, data: object) -> tuple[None, str]:
|
||||
payload = self._payload(data)
|
||||
kb_id = payload.get("kb_id")
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
success = await self.get_kb_manager().delete_kb(kb_id)
|
||||
if not success:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
return None, "删除知识库成功"
|
||||
|
||||
async def get_kb_stats(self, kb_id: str | None) -> dict[str, Any]:
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
kb = kb_helper.kb
|
||||
return {
|
||||
"kb_id": kb.kb_id,
|
||||
"kb_name": kb.kb_name,
|
||||
"doc_count": kb.doc_count,
|
||||
"chunk_count": kb.chunk_count,
|
||||
"created_at": kb.created_at.isoformat(),
|
||||
"updated_at": kb.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
async def get_kb_stats_from_legacy_query(
|
||||
self,
|
||||
kb_id: str | None,
|
||||
) -> dict[str, Any]:
|
||||
return await self.get_kb_stats(kb_id)
|
||||
|
||||
async def list_documents(
|
||||
self,
|
||||
*,
|
||||
kb_id: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> dict[str, Any]:
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
doc_list = await kb_helper.list_documents(offset=offset, limit=page_size)
|
||||
return {
|
||||
"items": [doc.model_dump() for doc in doc_list],
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
async def list_documents_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
kb_id: str | None,
|
||||
page,
|
||||
page_size,
|
||||
) -> dict[str, Any]:
|
||||
return await self.list_documents(
|
||||
kb_id=kb_id,
|
||||
page=self._to_int(page, 1),
|
||||
page_size=self._to_int(page_size, 100),
|
||||
)
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
*,
|
||||
content_type: str | None,
|
||||
form_data,
|
||||
files,
|
||||
) -> dict[str, Any]:
|
||||
if content_type and "multipart/form-data" not in content_type:
|
||||
raise KnowledgeBaseServiceError("Content-Type 须为 multipart/form-data")
|
||||
|
||||
kb_id = form_data.get("kb_id")
|
||||
chunk_size = int(form_data.get("chunk_size", 512))
|
||||
chunk_overlap = int(form_data.get("chunk_overlap", 50))
|
||||
batch_size = int(form_data.get("batch_size", 32))
|
||||
tasks_limit = int(form_data.get("tasks_limit", 3))
|
||||
max_retries = int(form_data.get("max_retries", 3))
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
|
||||
file_list = []
|
||||
for key in files.keys():
|
||||
if key == "file" or key.startswith("file") or key == "files[]":
|
||||
file_list.extend(files.getlist(key))
|
||||
if not file_list:
|
||||
raise KnowledgeBaseServiceError("缺少文件")
|
||||
if len(file_list) > 10:
|
||||
raise KnowledgeBaseServiceError("最多只能上传10个文件")
|
||||
|
||||
files_to_upload = []
|
||||
for file in file_list:
|
||||
file_name = file.filename
|
||||
temp_file_path = (
|
||||
Path(get_astrbot_temp_path()) / f"kb_upload_{uuid.uuid4()}_{file_name}"
|
||||
)
|
||||
await file.save(temp_file_path)
|
||||
try:
|
||||
async with aiofiles.open(temp_file_path, "rb") as file_obj:
|
||||
file_content = await file_obj.read()
|
||||
file_type = (
|
||||
file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
|
||||
)
|
||||
files_to_upload.append(
|
||||
{
|
||||
"file_name": file_name,
|
||||
"file_content": file_content,
|
||||
"file_type": file_type,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
temp_file_path.unlink(missing_ok=True)
|
||||
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
self.init_task(task_id, status="pending")
|
||||
asyncio.create_task(
|
||||
self.background_upload_task(
|
||||
task_id=task_id,
|
||||
kb_helper=kb_helper,
|
||||
files_to_upload=files_to_upload,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
),
|
||||
)
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"file_count": len(files_to_upload),
|
||||
"message": "task created, processing in background",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def validate_import_request(data: dict[str, Any]):
|
||||
kb_id = data.get("kb_id")
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
|
||||
documents = data.get("documents")
|
||||
if not documents or not isinstance(documents, list):
|
||||
raise KnowledgeBaseServiceError("缺少参数 documents 或格式错误")
|
||||
|
||||
for doc in documents:
|
||||
if (
|
||||
not isinstance(doc, dict)
|
||||
or "file_name" not in doc
|
||||
or "chunks" not in doc
|
||||
):
|
||||
raise KnowledgeBaseServiceError(
|
||||
"文档格式错误,必须包含 file_name 和 chunks"
|
||||
)
|
||||
if not isinstance(doc["chunks"], list):
|
||||
raise KnowledgeBaseServiceError("chunks 必须是列表")
|
||||
if not all(
|
||||
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
|
||||
):
|
||||
raise KnowledgeBaseServiceError("chunks 必须是非空字符串列表")
|
||||
|
||||
return (
|
||||
kb_id,
|
||||
documents,
|
||||
data.get("batch_size", 32),
|
||||
data.get("tasks_limit", 3),
|
||||
data.get("max_retries", 3),
|
||||
)
|
||||
|
||||
async def import_documents(self, data: object) -> dict[str, Any]:
|
||||
payload = self._payload(data)
|
||||
kb_id, documents, batch_size, tasks_limit, max_retries = (
|
||||
self.validate_import_request(payload)
|
||||
)
|
||||
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
self.init_task(task_id, status="pending")
|
||||
asyncio.create_task(
|
||||
self.background_import_task(
|
||||
task_id=task_id,
|
||||
kb_helper=kb_helper,
|
||||
documents=documents,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
),
|
||||
)
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"doc_count": len(documents),
|
||||
"message": "import task created, processing in background",
|
||||
}
|
||||
|
||||
def get_upload_progress(self, task_id: str | None) -> dict[str, Any]:
|
||||
if not task_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 task_id")
|
||||
if task_id not in self.upload_tasks:
|
||||
raise KnowledgeBaseServiceError("找不到该任务")
|
||||
|
||||
task_info = self.upload_tasks[task_id]
|
||||
status = task_info["status"]
|
||||
response_data = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
}
|
||||
if status == "processing" and task_id in self.upload_progress:
|
||||
response_data["progress"] = self.upload_progress[task_id]
|
||||
if status == "completed":
|
||||
response_data["result"] = task_info["result"]
|
||||
if status == "failed":
|
||||
response_data["error"] = task_info["error"]
|
||||
return response_data
|
||||
|
||||
def get_upload_progress_from_legacy_query(
|
||||
self,
|
||||
task_id: str | None,
|
||||
) -> dict[str, Any]:
|
||||
return self.get_upload_progress(task_id)
|
||||
|
||||
async def get_document(
|
||||
self,
|
||||
*,
|
||||
kb_id: str | None,
|
||||
doc_id: str | None,
|
||||
) -> dict[str, Any]:
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
if not doc_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 doc_id")
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
doc = await kb_helper.get_document(doc_id)
|
||||
if not doc:
|
||||
raise KnowledgeBaseServiceError("文档不存在")
|
||||
return doc.model_dump()
|
||||
|
||||
async def get_document_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
kb_id: str | None,
|
||||
doc_id: str | None,
|
||||
) -> dict[str, Any]:
|
||||
return await self.get_document(kb_id=kb_id, doc_id=doc_id)
|
||||
|
||||
async def delete_document(self, data: object) -> tuple[None, str]:
|
||||
payload = self._payload(data)
|
||||
kb_id = payload.get("kb_id")
|
||||
doc_id = payload.get("doc_id")
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
if not doc_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 doc_id")
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
await kb_helper.delete_document(doc_id)
|
||||
return None, "删除文档成功"
|
||||
|
||||
async def delete_chunk(self, data: object) -> tuple[None, str]:
|
||||
payload = self._payload(data)
|
||||
kb_id = payload.get("kb_id")
|
||||
chunk_id = payload.get("chunk_id")
|
||||
doc_id = payload.get("doc_id")
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
if not chunk_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 chunk_id")
|
||||
if not doc_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 doc_id")
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
await kb_helper.delete_chunk(chunk_id, doc_id)
|
||||
return None, "删除文本块成功"
|
||||
|
||||
async def list_chunks(
|
||||
self,
|
||||
*,
|
||||
kb_id: str | None,
|
||||
doc_id: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> dict[str, Any]:
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
if not doc_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 doc_id")
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
return {
|
||||
"items": await kb_helper.get_chunks_by_doc_id(
|
||||
doc_id=doc_id,
|
||||
offset=offset,
|
||||
limit=page_size,
|
||||
),
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": await kb_helper.get_chunk_count_by_doc_id(doc_id),
|
||||
}
|
||||
|
||||
async def list_chunks_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
kb_id: str | None,
|
||||
doc_id: str | None,
|
||||
page,
|
||||
page_size,
|
||||
) -> dict[str, Any]:
|
||||
return await self.list_chunks(
|
||||
kb_id=kb_id,
|
||||
doc_id=doc_id,
|
||||
page=self._to_int(page, 1),
|
||||
page_size=self._to_int(page_size, 100),
|
||||
)
|
||||
|
||||
async def retrieve(self, data: object) -> dict[str, Any]:
|
||||
payload = self._payload(data)
|
||||
query = payload.get("query")
|
||||
kb_names = payload.get("kb_names")
|
||||
debug = payload.get("debug", False)
|
||||
|
||||
if not query:
|
||||
raise KnowledgeBaseServiceError("缺少参数 query")
|
||||
if not kb_names or not isinstance(kb_names, list):
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_names 或格式错误")
|
||||
|
||||
top_k = payload.get("top_k", 5)
|
||||
kb_manager = self.get_kb_manager()
|
||||
results = await kb_manager.retrieve(
|
||||
query=query,
|
||||
kb_names=kb_names,
|
||||
top_m_final=top_k,
|
||||
)
|
||||
result_list = results["results"] if results else []
|
||||
response_data = {
|
||||
"results": result_list,
|
||||
"total": len(result_list),
|
||||
"query": query,
|
||||
}
|
||||
|
||||
if debug:
|
||||
try:
|
||||
img_base64 = await generate_tsne_visualization(
|
||||
query,
|
||||
kb_names,
|
||||
kb_manager,
|
||||
)
|
||||
if img_base64:
|
||||
response_data["visualization"] = img_base64
|
||||
except Exception as exc:
|
||||
logger.error(f"生成 t-SNE 可视化失败: {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
response_data["visualization_error"] = str(exc)
|
||||
|
||||
return response_data
|
||||
|
||||
async def upload_document_from_url(self, data: object) -> dict[str, Any]:
|
||||
payload = self._payload(data)
|
||||
kb_id = payload.get("kb_id")
|
||||
if not kb_id:
|
||||
raise KnowledgeBaseServiceError("缺少参数 kb_id")
|
||||
url = payload.get("url")
|
||||
if not url:
|
||||
raise KnowledgeBaseServiceError("缺少参数 url")
|
||||
|
||||
kb_helper = await self.get_kb_manager().get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise KnowledgeBaseServiceError("知识库不存在")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
self.init_task(task_id, status="pending")
|
||||
asyncio.create_task(
|
||||
self.background_upload_from_url_task(
|
||||
task_id=task_id,
|
||||
kb_helper=kb_helper,
|
||||
url=url,
|
||||
chunk_size=payload.get("chunk_size", 512),
|
||||
chunk_overlap=payload.get("chunk_overlap", 50),
|
||||
batch_size=payload.get("batch_size", 32),
|
||||
tasks_limit=payload.get("tasks_limit", 3),
|
||||
max_retries=payload.get("max_retries", 3),
|
||||
enable_cleaning=payload.get("enable_cleaning", False),
|
||||
cleaning_provider_id=payload.get("cleaning_provider_id"),
|
||||
),
|
||||
)
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"url": url,
|
||||
"message": "URL upload task created, processing in background",
|
||||
}
|
||||
|
||||
async def background_upload_from_url_task(
|
||||
self,
|
||||
task_id: str,
|
||||
kb_helper,
|
||||
url: str,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
batch_size: int,
|
||||
tasks_limit: int,
|
||||
max_retries: int,
|
||||
enable_cleaning: bool,
|
||||
cleaning_provider_id: str | None,
|
||||
) -> None:
|
||||
try:
|
||||
self.init_task(task_id, status="processing")
|
||||
self.upload_progress[task_id] = {
|
||||
"status": "processing",
|
||||
"file_index": 0,
|
||||
"file_total": 1,
|
||||
"file_name": f"URL: {url}",
|
||||
"stage": "extracting",
|
||||
"current": 0,
|
||||
"total": 100,
|
||||
}
|
||||
progress_callback = self.make_progress_callback(task_id, 0, f"URL: {url}")
|
||||
doc = await kb_helper.upload_from_url(
|
||||
url=url,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
enable_cleaning=enable_cleaning,
|
||||
cleaning_provider_id=cleaning_provider_id,
|
||||
)
|
||||
self.set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result={
|
||||
"task_id": task_id,
|
||||
"uploaded": [doc.model_dump()],
|
||||
"failed": [],
|
||||
"total": 1,
|
||||
"success_count": 1,
|
||||
"failed_count": 0,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"后台上传URL任务 {task_id} 失败: {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
self.set_task_result(task_id, "failed", error=str(exc))
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
__all__ = ["KnowledgeBaseService", "KnowledgeBaseServiceError"]
|
||||
982
astrbot/dashboard/services/live_chat_service.py
Normal file
982
astrbot/dashboard/services/live_chat_service.py
Normal file
@@ -0,0 +1,982 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import wave
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
||||
build_webchat_message_parts,
|
||||
create_attachment_part_from_existing_file,
|
||||
strip_message_parts_path_fields,
|
||||
webchat_message_parts_have_content,
|
||||
)
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
|
||||
from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
from astrbot.dashboard.services.chat_service import (
|
||||
BotMessageAccumulator,
|
||||
build_bot_history_content,
|
||||
collect_plain_text_from_message_parts,
|
||||
)
|
||||
|
||||
SendJson = Callable[[dict], Awaitable[None]]
|
||||
ReceiveJson = Callable[[], Awaitable[dict]]
|
||||
CloseWebSocket = Callable[[int, str], Awaitable[None]]
|
||||
|
||||
|
||||
class LiveChatAuthError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LiveChatSession:
|
||||
"""Live Chat 会话管理器"""
|
||||
|
||||
def __init__(self, session_id: str, username: str) -> None:
|
||||
self.session_id = session_id
|
||||
self.username = username
|
||||
self.conversation_id = str(uuid.uuid4())
|
||||
self.is_speaking = False
|
||||
self.is_processing = False
|
||||
self.should_interrupt = False
|
||||
self.audio_frames: list[bytes] = []
|
||||
self.current_stamp: str | None = None
|
||||
self.temp_audio_path: str | None = None
|
||||
self.chat_subscriptions: dict[str, str] = {}
|
||||
self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
|
||||
self.ws_send_lock = asyncio.Lock()
|
||||
|
||||
def start_speaking(self, stamp: str) -> None:
|
||||
self.is_speaking = True
|
||||
self.current_stamp = stamp
|
||||
self.audio_frames = []
|
||||
logger.debug(f"[Live Chat] {self.username} 开始说话 stamp={stamp}")
|
||||
|
||||
def add_audio_frame(self, data: bytes) -> None:
|
||||
if self.is_speaking:
|
||||
self.audio_frames.append(data)
|
||||
|
||||
async def end_speaking(self, stamp: str) -> tuple[str | None, float]:
|
||||
start_time = time.time()
|
||||
if not self.is_speaking or stamp != self.current_stamp:
|
||||
logger.warning(
|
||||
f"[Live Chat] stamp 不匹配或未在说话状态: {stamp} vs {self.current_stamp}"
|
||||
)
|
||||
return None, 0.0
|
||||
|
||||
self.is_speaking = False
|
||||
|
||||
if not self.audio_frames:
|
||||
logger.warning("[Live Chat] 没有音频帧数据")
|
||||
return None, 0.0
|
||||
|
||||
try:
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav")
|
||||
|
||||
with wave.open(audio_path, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(16000)
|
||||
for frame in self.audio_frames:
|
||||
wav_file.writeframes(frame)
|
||||
|
||||
self.temp_audio_path = audio_path
|
||||
logger.info(
|
||||
f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes"
|
||||
)
|
||||
return audio_path, time.time() - start_time
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[Live Chat] 组装 WAV 文件失败: {exc}", exc_info=True)
|
||||
return None, 0.0
|
||||
|
||||
def cleanup(self) -> None:
|
||||
if self.temp_audio_path and os.path.exists(self.temp_audio_path):
|
||||
try:
|
||||
os.remove(self.temp_audio_path)
|
||||
logger.debug(f"[Live Chat] 已删除临时文件: {self.temp_audio_path}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Live Chat] 删除临时文件失败: {exc}")
|
||||
self.temp_audio_path = None
|
||||
|
||||
|
||||
class LiveChatService:
|
||||
def __init__(
|
||||
self,
|
||||
db: Any,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.plugin_manager = core_lifecycle.plugin_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
self.sessions: dict[str, LiveChatSession] = {}
|
||||
self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
|
||||
self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.attachments_dir, exist_ok=True)
|
||||
|
||||
def authenticate_token(
|
||||
self,
|
||||
token: str | None,
|
||||
jwt_secret: str | None = None,
|
||||
) -> str:
|
||||
if not token:
|
||||
raise LiveChatAuthError("Missing authentication token")
|
||||
jwt_secret = jwt_secret or self.config["dashboard"].get("jwt_secret")
|
||||
try:
|
||||
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
||||
return payload["username"]
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
raise LiveChatAuthError("Token expired") from exc
|
||||
except jwt.InvalidTokenError as exc:
|
||||
raise LiveChatAuthError("Invalid token") from exc
|
||||
|
||||
def create_session(self, username: str) -> LiveChatSession:
|
||||
session_id = f"webchat_live!{username}!{uuid.uuid4()}"
|
||||
session = LiveChatSession(session_id, username)
|
||||
self.sessions[session_id] = session
|
||||
return session
|
||||
|
||||
async def cleanup_session(self, session: LiveChatSession) -> None:
|
||||
if session.session_id in self.sessions:
|
||||
await self.cleanup_chat_subscriptions(session)
|
||||
session.cleanup()
|
||||
del self.sessions[session.session_id]
|
||||
|
||||
async def run_websocket_session(
|
||||
self,
|
||||
*,
|
||||
token: str | None,
|
||||
force_ct: str | None,
|
||||
receive_json: ReceiveJson,
|
||||
send_json: SendJson,
|
||||
close: CloseWebSocket,
|
||||
) -> None:
|
||||
try:
|
||||
username = self.authenticate_token(token)
|
||||
except LiveChatAuthError as exc:
|
||||
await close(1008, str(exc))
|
||||
return
|
||||
|
||||
live_session = self.create_session(username)
|
||||
logger.info(f"[Live Chat] WebSocket 连接建立: {username}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await receive_json()
|
||||
ct = force_ct or message.get("ct", "live")
|
||||
if ct == "chat":
|
||||
await self.handle_chat_message(
|
||||
live_session,
|
||||
message,
|
||||
send_json,
|
||||
)
|
||||
else:
|
||||
await self.handle_live_message(
|
||||
live_session,
|
||||
message,
|
||||
send_json,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[Live Chat] WebSocket 错误: {exc}", exc_info=True)
|
||||
|
||||
finally:
|
||||
await self.cleanup_session(live_session)
|
||||
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
|
||||
|
||||
async def create_attachment_from_file(
|
||||
self, filename: str, attach_type: str
|
||||
) -> dict | None:
|
||||
return await create_attachment_part_from_existing_file(
|
||||
filename,
|
||||
attach_type=attach_type,
|
||||
insert_attachment=self.db.insert_attachment,
|
||||
attachments_dir=self.attachments_dir,
|
||||
fallback_dirs=[self.legacy_img_dir],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_web_search_refs(accumulated_text: str, accumulated_parts: list) -> dict:
|
||||
supported = [
|
||||
"web_search_baidu",
|
||||
"web_search_tavily",
|
||||
"web_search_bocha",
|
||||
"web_search_brave",
|
||||
]
|
||||
web_search_results = {}
|
||||
tool_call_parts = [
|
||||
p
|
||||
for p in accumulated_parts
|
||||
if p.get("type") == "tool_call" and p.get("tool_calls")
|
||||
]
|
||||
|
||||
for part in tool_call_parts:
|
||||
for tool_call in part["tool_calls"]:
|
||||
if tool_call.get("name") not in supported or not tool_call.get(
|
||||
"result"
|
||||
):
|
||||
continue
|
||||
try:
|
||||
result_data = json.loads(tool_call["result"])
|
||||
for item in result_data.get("results", []):
|
||||
if idx := item.get("index"):
|
||||
web_search_results[idx] = {
|
||||
"url": item.get("url"),
|
||||
"title": item.get("title"),
|
||||
"snippet": item.get("snippet"),
|
||||
}
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
if not web_search_results:
|
||||
return {}
|
||||
|
||||
ref_indices = {
|
||||
match.strip() for match in re.findall(r"<ref>(.*?)</ref>", accumulated_text)
|
||||
}
|
||||
|
||||
used_refs = []
|
||||
for ref_index in ref_indices:
|
||||
if ref_index not in web_search_results:
|
||||
continue
|
||||
payload = {"index": ref_index, **web_search_results[ref_index]}
|
||||
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
|
||||
payload["favicon"] = favicon
|
||||
used_refs.append(payload)
|
||||
|
||||
return {"used": used_refs} if used_refs else {}
|
||||
|
||||
async def save_bot_message(
|
||||
self,
|
||||
webchat_conv_id: str,
|
||||
message_parts: list[dict],
|
||||
agent_stats: dict,
|
||||
refs: dict,
|
||||
llm_checkpoint_id: str | None = None,
|
||||
):
|
||||
new_his = build_bot_history_content(
|
||||
message_parts,
|
||||
agent_stats=agent_stats,
|
||||
refs=refs,
|
||||
)
|
||||
|
||||
return await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
llm_checkpoint_id=llm_checkpoint_id,
|
||||
)
|
||||
|
||||
async def send_chat_payload(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
payload: dict,
|
||||
send_json: SendJson,
|
||||
) -> None:
|
||||
async with session.ws_send_lock:
|
||||
await send_json(payload)
|
||||
|
||||
async def forward_chat_subscription(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
chat_session_id: str,
|
||||
request_id: str,
|
||||
send_json: SendJson,
|
||||
) -> None:
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id, chat_session_id
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
result = await back_queue.get()
|
||||
if not result:
|
||||
continue
|
||||
await self.send_chat_payload(
|
||||
session, {"ct": "chat", **result}, send_json
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"[Live Chat] chat subscription forward failed ({chat_session_id}): {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(request_id)
|
||||
if session.chat_subscriptions.get(chat_session_id) == request_id:
|
||||
session.chat_subscriptions.pop(chat_session_id, None)
|
||||
session.chat_subscription_tasks.pop(chat_session_id, None)
|
||||
|
||||
async def ensure_chat_subscription(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
chat_session_id: str,
|
||||
send_json: SendJson,
|
||||
) -> str:
|
||||
existing_request_id = session.chat_subscriptions.get(chat_session_id)
|
||||
existing_task = session.chat_subscription_tasks.get(chat_session_id)
|
||||
if existing_request_id and existing_task and not existing_task.done():
|
||||
return existing_request_id
|
||||
|
||||
request_id = f"ws_sub_{uuid.uuid4().hex}"
|
||||
session.chat_subscriptions[chat_session_id] = request_id
|
||||
task = asyncio.create_task(
|
||||
self.forward_chat_subscription(
|
||||
session,
|
||||
chat_session_id,
|
||||
request_id,
|
||||
send_json,
|
||||
),
|
||||
name=f"chat_ws_sub_{chat_session_id}",
|
||||
)
|
||||
session.chat_subscription_tasks[chat_session_id] = task
|
||||
return request_id
|
||||
|
||||
async def cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
|
||||
tasks = list(session.chat_subscription_tasks.values())
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for request_id in list(session.chat_subscriptions.values()):
|
||||
webchat_queue_mgr.remove_back_queue(request_id)
|
||||
session.chat_subscriptions.clear()
|
||||
session.chat_subscription_tasks.clear()
|
||||
|
||||
async def handle_chat_message(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
message: dict,
|
||||
send_json: SendJson,
|
||||
) -> None:
|
||||
msg_type = message.get("t")
|
||||
|
||||
if msg_type == "bind":
|
||||
chat_session_id = message.get("session_id")
|
||||
if not isinstance(chat_session_id, str) or not chat_session_id:
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "session_id is required",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
return
|
||||
|
||||
request_id = await self.ensure_chat_subscription(
|
||||
session, chat_session_id, send_json
|
||||
)
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "session_bound",
|
||||
"session_id": chat_session_id,
|
||||
"message_id": request_id,
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
return
|
||||
|
||||
if msg_type == "interrupt":
|
||||
session.should_interrupt = True
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "INTERRUPTED",
|
||||
"code": "INTERRUPTED",
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
return
|
||||
|
||||
if msg_type != "send":
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": f"Unsupported message type: {msg_type}",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
return
|
||||
|
||||
if session.is_processing:
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "Session is busy",
|
||||
"code": "PROCESSING_ERROR",
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
return
|
||||
|
||||
payload = message.get("message")
|
||||
session_id = message.get("session_id") or session.session_id
|
||||
message_id = message.get("message_id") or str(uuid.uuid4())
|
||||
selected_provider = message.get("selected_provider")
|
||||
selected_model = message.get("selected_model")
|
||||
selected_stt_provider = message.get("selected_stt_provider")
|
||||
selected_tts_provider = message.get("selected_tts_provider")
|
||||
persona_prompt = message.get("persona_prompt")
|
||||
show_reasoning = message.get("show_reasoning")
|
||||
enable_streaming = message.get("enable_streaming", True)
|
||||
|
||||
if not isinstance(payload, list):
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "message must be list",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
return
|
||||
|
||||
message_parts = await self.build_chat_message_parts(payload)
|
||||
has_content = webchat_message_parts_have_content(message_parts)
|
||||
if not has_content:
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "Message content is empty",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
return
|
||||
|
||||
await self.ensure_chat_subscription(session, session_id, send_json)
|
||||
|
||||
session.is_processing = True
|
||||
session.should_interrupt = False
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
|
||||
llm_checkpoint_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
pending_bot_message_flusher = None
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
|
||||
await chat_queue.put(
|
||||
(
|
||||
session.username,
|
||||
session_id,
|
||||
{
|
||||
"message": message_parts,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"selected_stt_provider": selected_stt_provider,
|
||||
"selected_tts_provider": selected_tts_provider,
|
||||
"persona_prompt": persona_prompt,
|
||||
"show_reasoning": show_reasoning,
|
||||
"enable_streaming": enable_streaming,
|
||||
"message_id": message_id,
|
||||
"llm_checkpoint_id": llm_checkpoint_id,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
|
||||
saved_user_record = await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=session_id,
|
||||
content={"type": "user", "message": message_parts_for_storage},
|
||||
sender_id=session.username,
|
||||
sender_name=session.username,
|
||||
llm_checkpoint_id=llm_checkpoint_id,
|
||||
)
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "user_message_saved",
|
||||
"data": {
|
||||
"id": saved_user_record.id,
|
||||
"created_at": to_utc_isoformat(saved_user_record.created_at),
|
||||
"llm_checkpoint_id": llm_checkpoint_id,
|
||||
},
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
|
||||
message_accumulator = BotMessageAccumulator()
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
|
||||
async def flush_pending_bot_message():
|
||||
nonlocal message_accumulator, agent_stats, refs
|
||||
if not (message_accumulator.has_content() or refs or agent_stats):
|
||||
return None
|
||||
|
||||
message_parts_to_save = message_accumulator.build_message_parts(
|
||||
include_pending_tool_calls=True
|
||||
)
|
||||
plain_text = collect_plain_text_from_message_parts(
|
||||
message_parts_to_save
|
||||
)
|
||||
try:
|
||||
extracted_refs = self.extract_web_search_refs(
|
||||
plain_text,
|
||||
message_parts_to_save,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
f"[Live Chat] Failed to extract web search refs: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
extracted_refs = refs
|
||||
|
||||
saved_record = await self.save_bot_message(
|
||||
session_id,
|
||||
message_parts_to_save,
|
||||
agent_stats,
|
||||
extracted_refs,
|
||||
llm_checkpoint_id,
|
||||
)
|
||||
message_accumulator = BotMessageAccumulator()
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
return saved_record
|
||||
|
||||
pending_bot_message_flusher = flush_pending_bot_message
|
||||
|
||||
async def send_attachment_saved_event(part: dict | None) -> None:
|
||||
if not part or not part.get("attachment_id") or not part.get("type"):
|
||||
return
|
||||
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "attachment_saved",
|
||||
"data": {
|
||||
"id": part["attachment_id"],
|
||||
"type": part["type"],
|
||||
},
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
session.should_interrupt = False
|
||||
await flush_pending_bot_message()
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
if result.get("message_id") and result.get("message_id") != message_id:
|
||||
continue
|
||||
|
||||
result_text = result.get("data", "")
|
||||
result_type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
chain_type = result.get("chain_type")
|
||||
if chain_type == "agent_stats":
|
||||
try:
|
||||
parsed_agent_stats = json.loads(result_text)
|
||||
agent_stats = parsed_agent_stats
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "agent_stats",
|
||||
"data": parsed_agent_stats,
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
outgoing = {"ct": "chat", **result}
|
||||
await self.send_chat_payload(session, outgoing, send_json)
|
||||
|
||||
if result_type == "plain":
|
||||
message_accumulator.add_plain(
|
||||
result_text,
|
||||
chain_type=chain_type,
|
||||
streaming=streaming,
|
||||
)
|
||||
elif result_type == "image":
|
||||
filename = str(result_text).replace("[IMAGE]", "")
|
||||
part = await self.create_attachment_from_file(filename, "image")
|
||||
message_accumulator.add_attachment(part)
|
||||
await send_attachment_saved_event(part)
|
||||
elif result_type == "record":
|
||||
filename = str(result_text).replace("[RECORD]", "")
|
||||
part = await self.create_attachment_from_file(filename, "record")
|
||||
message_accumulator.add_attachment(part)
|
||||
await send_attachment_saved_event(part)
|
||||
elif result_type == "file":
|
||||
filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
|
||||
part = await self.create_attachment_from_file(filename, "file")
|
||||
message_accumulator.add_attachment(part)
|
||||
await send_attachment_saved_event(part)
|
||||
elif result_type == "video":
|
||||
filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
|
||||
part = await self.create_attachment_from_file(filename, "video")
|
||||
message_accumulator.add_attachment(part)
|
||||
await send_attachment_saved_event(part)
|
||||
|
||||
should_save = False
|
||||
if result_type == "end":
|
||||
should_save = bool(
|
||||
message_accumulator.has_content() or refs or agent_stats
|
||||
)
|
||||
elif (streaming and result_type == "complete") or not streaming:
|
||||
if chain_type not in (
|
||||
"tool_call",
|
||||
"tool_call_result",
|
||||
"agent_stats",
|
||||
):
|
||||
should_save = True
|
||||
|
||||
if should_save:
|
||||
saved_record = await flush_pending_bot_message()
|
||||
if saved_record:
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "message_saved",
|
||||
"data": {
|
||||
"id": saved_record.id,
|
||||
"created_at": to_utc_isoformat(
|
||||
saved_record.created_at
|
||||
),
|
||||
"llm_checkpoint_id": llm_checkpoint_id,
|
||||
},
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
|
||||
if result_type == "end":
|
||||
break
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[Live Chat] 处理 chat 消息失败: {exc}", exc_info=True)
|
||||
await self.send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": f"处理失败: {str(exc)}",
|
||||
"code": "PROCESSING_ERROR",
|
||||
},
|
||||
send_json,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
if pending_bot_message_flusher is not None:
|
||||
await pending_bot_message_flusher()
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
f"[Live Chat] Failed to persist pending chat message: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
session.is_processing = False
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
async def build_chat_message_parts(self, message: list[dict]) -> list[dict]:
|
||||
return await build_webchat_message_parts(
|
||||
message,
|
||||
get_attachment_by_id=self.db.get_attachment_by_id,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
async def handle_live_message(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
message: dict,
|
||||
send_json: SendJson,
|
||||
) -> None:
|
||||
msg_type = message.get("t")
|
||||
|
||||
if msg_type == "start_speaking":
|
||||
stamp = message.get("stamp")
|
||||
if not stamp:
|
||||
logger.warning("[Live Chat] start_speaking 缺少 stamp")
|
||||
return
|
||||
session.start_speaking(stamp)
|
||||
return
|
||||
|
||||
if msg_type == "speaking_part":
|
||||
audio_data_b64 = message.get("data")
|
||||
if not audio_data_b64:
|
||||
return
|
||||
try:
|
||||
audio_data = base64.b64decode(audio_data_b64)
|
||||
session.add_audio_frame(audio_data)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Live Chat] 解码音频数据失败: {exc}")
|
||||
return
|
||||
|
||||
if msg_type == "end_speaking":
|
||||
stamp = message.get("stamp")
|
||||
if not stamp:
|
||||
logger.warning("[Live Chat] end_speaking 缺少 stamp")
|
||||
return
|
||||
|
||||
audio_path, assemble_duration = await session.end_speaking(stamp)
|
||||
if not audio_path:
|
||||
await send_json({"t": "error", "data": "音频组装失败"})
|
||||
return
|
||||
|
||||
await self.process_audio(session, audio_path, assemble_duration, send_json)
|
||||
return
|
||||
|
||||
if msg_type == "interrupt":
|
||||
session.should_interrupt = True
|
||||
logger.info(f"[Live Chat] 用户打断: {session.username}")
|
||||
|
||||
async def process_audio(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
audio_path: str,
|
||||
assemble_duration: float,
|
||||
send_json: SendJson,
|
||||
) -> None:
|
||||
try:
|
||||
await send_json(
|
||||
{"t": "metrics", "data": {"wav_assemble_time": assemble_duration}}
|
||||
)
|
||||
wav_assembly_finish_time = time.time()
|
||||
|
||||
session.is_processing = True
|
||||
session.should_interrupt = False
|
||||
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.provider_manager.stt_provider_insts[0]
|
||||
|
||||
if not stt_provider:
|
||||
logger.error("[Live Chat] STT Provider 未配置")
|
||||
await send_json({"t": "error", "data": "语音识别服务未配置"})
|
||||
return
|
||||
|
||||
await send_json({"t": "metrics", "data": {"stt": stt_provider.meta().type}})
|
||||
|
||||
user_text = await stt_provider.get_text(audio_path)
|
||||
if not user_text:
|
||||
logger.warning("[Live Chat] STT 识别结果为空")
|
||||
return
|
||||
|
||||
logger.info(f"[Live Chat] STT 结果: {user_text}")
|
||||
|
||||
await send_json(
|
||||
{
|
||||
"t": "user_msg",
|
||||
"data": {"text": user_text, "ts": int(time.time() * 1000)},
|
||||
}
|
||||
)
|
||||
|
||||
conversation_id = session.conversation_id
|
||||
queue = webchat_queue_mgr.get_or_create_queue(conversation_id)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
payload = {
|
||||
"message_id": message_id,
|
||||
"message": [{"type": "plain", "text": user_text}],
|
||||
"action_type": "live",
|
||||
}
|
||||
|
||||
await queue.put((session.username, conversation_id, payload))
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
message_id, conversation_id
|
||||
)
|
||||
|
||||
bot_text = ""
|
||||
audio_playing = False
|
||||
|
||||
try:
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
await send_json({"t": "stop_play"})
|
||||
await self.save_interrupted_message(
|
||||
session, user_text, bot_text
|
||||
)
|
||||
while not back_queue.empty():
|
||||
try:
|
||||
back_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
result_message_id = result.get("message_id")
|
||||
if result_message_id != message_id:
|
||||
logger.warning(
|
||||
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
result_type = result.get("type")
|
||||
result_chain_type = result.get("chain_type")
|
||||
data = result.get("data", "")
|
||||
|
||||
if result_chain_type == "agent_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"llm_ttft": stats.get("time_to_first_token", 0),
|
||||
"llm_total_time": stats.get("end_time", 0)
|
||||
- stats.get("start_time", 0),
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Live Chat] 解析 AgentStats 失败: {exc}")
|
||||
continue
|
||||
|
||||
if result_chain_type == "tts_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": stats,
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {exc}")
|
||||
continue
|
||||
|
||||
if result_type == "plain":
|
||||
bot_text += data
|
||||
|
||||
elif result_type == "audio_chunk":
|
||||
if not audio_playing:
|
||||
audio_playing = True
|
||||
logger.debug("[Live Chat] 开始播放音频流")
|
||||
|
||||
speak_to_first_frame_latency = (
|
||||
time.time() - wav_assembly_finish_time
|
||||
)
|
||||
await send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"speak_to_first_frame": speak_to_first_frame_latency
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
text = result.get("text")
|
||||
if text:
|
||||
await send_json(
|
||||
{
|
||||
"t": "bot_text_chunk",
|
||||
"data": {"text": text},
|
||||
}
|
||||
)
|
||||
|
||||
await send_json(
|
||||
{
|
||||
"t": "response",
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
|
||||
elif result_type in ["complete", "end"]:
|
||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||
|
||||
if not audio_playing:
|
||||
await send_json(
|
||||
{
|
||||
"t": "bot_msg",
|
||||
"data": {
|
||||
"text": bot_text,
|
||||
"ts": int(time.time() * 1000),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
await send_json({"t": "end"})
|
||||
|
||||
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||
await send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
|
||||
}
|
||||
)
|
||||
break
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[Live Chat] 处理音频失败: {exc}", exc_info=True)
|
||||
await send_json({"t": "error", "data": f"处理失败: {str(exc)}"})
|
||||
|
||||
finally:
|
||||
session.is_processing = False
|
||||
session.should_interrupt = False
|
||||
|
||||
@staticmethod
|
||||
async def save_interrupted_message(
|
||||
session: LiveChatSession, user_text: str, bot_text: str
|
||||
) -> None:
|
||||
interrupted_text = bot_text + " [用户打断]"
|
||||
logger.info(f"[Live Chat] 保存打断消息: {interrupted_text}")
|
||||
|
||||
try:
|
||||
timestamp = int(time.time() * 1000)
|
||||
logger.info(
|
||||
f"[Live Chat] 用户消息: {user_text} (session: {session.session_id}, ts: {timestamp})"
|
||||
)
|
||||
if bot_text:
|
||||
logger.info(
|
||||
f"[Live Chat] Bot 消息(打断): {interrupted_text} (session: {session.session_id}, ts: {timestamp})"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Live Chat] 记录消息失败: {exc}", exc_info=True)
|
||||
|
||||
|
||||
__all__ = ["LiveChatAuthError", "LiveChatService", "LiveChatSession"]
|
||||
97
astrbot/dashboard/services/log_service.py
Normal file
97
astrbot/dashboard/services/log_service.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import LogBroker, logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
|
||||
|
||||
class LogServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LogService:
|
||||
def __init__(self, log_broker: LogBroker, config: AstrBotConfig) -> None:
|
||||
self.log_broker = log_broker
|
||||
self.config = config
|
||||
|
||||
@staticmethod
|
||||
def format_log_sse(log: dict, ts: float) -> str:
|
||||
payload = {
|
||||
"type": "log",
|
||||
**log,
|
||||
}
|
||||
return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
async def replay_cached_logs(self, last_event_id: str) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
last_ts = float(last_event_id)
|
||||
cached_logs = list(self.log_broker.log_cache)
|
||||
|
||||
for log_item in cached_logs:
|
||||
log_ts = float(log_item.get("time", 0))
|
||||
if log_ts > last_ts:
|
||||
yield self.format_log_sse(log_item, log_ts)
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.error(f"Log SSE 补发历史错误: {exc}")
|
||||
|
||||
async def stream_log_events(
|
||||
self, last_event_id: str | None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
queue = None
|
||||
try:
|
||||
if last_event_id:
|
||||
async for event in self.replay_cached_logs(last_event_id):
|
||||
yield event
|
||||
|
||||
queue = self.log_broker.register()
|
||||
while True:
|
||||
message = await queue.get()
|
||||
current_ts = message.get("time", time.time())
|
||||
yield self.format_log_sse(message, current_ts)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.error(f"Log SSE 连接错误: {exc}")
|
||||
finally:
|
||||
if queue:
|
||||
self.log_broker.unregister(queue)
|
||||
|
||||
def get_log_history(self) -> dict:
|
||||
try:
|
||||
return {"logs": list(self.log_broker.log_cache)}
|
||||
except Exception as exc:
|
||||
logger.error(f"获取日志历史失败: {exc}")
|
||||
raise LogServiceError(f"获取日志历史失败: {exc}") from exc
|
||||
|
||||
def get_trace_settings(self) -> dict:
|
||||
try:
|
||||
return {"trace_enable": self.config.get("trace_enable", True)}
|
||||
except Exception as exc:
|
||||
logger.error(f"获取 Trace 设置失败: {exc}")
|
||||
raise LogServiceError(f"获取 Trace 设置失败: {exc}") from exc
|
||||
|
||||
def update_trace_settings(self, payload: dict | None) -> str:
|
||||
try:
|
||||
if payload is None:
|
||||
raise LogServiceError("请求数据为空")
|
||||
|
||||
trace_enable = payload.get("trace_enable")
|
||||
if trace_enable is not None:
|
||||
self.config["trace_enable"] = bool(trace_enable)
|
||||
self.config.save_config()
|
||||
|
||||
return "Trace 设置已更新"
|
||||
except LogServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"更新 Trace 设置失败: {exc}")
|
||||
raise LogServiceError(f"更新 Trace 设置失败: {exc}") from exc
|
||||
|
||||
def update_trace_settings_from_legacy_payload(self, data: object) -> str:
|
||||
return self.update_trace_settings(data if isinstance(data, dict) else None)
|
||||
650
astrbot/dashboard/services/open_api_service.py
Normal file
650
astrbot/dashboard/services/open_api_service.py
Normal file
@@ -0,0 +1,650 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.platform.message_session import MessageSesion
|
||||
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
||||
build_message_chain_from_payload,
|
||||
strip_message_parts_path_fields,
|
||||
webchat_message_parts_have_content,
|
||||
)
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
from astrbot.dashboard.services.api_key_service import ApiKeyService
|
||||
from astrbot.dashboard.services.auth_service import ALL_OPEN_API_SCOPES
|
||||
from astrbot.dashboard.services.chat_service import (
|
||||
BotMessageAccumulator,
|
||||
collect_plain_text_from_message_parts,
|
||||
)
|
||||
|
||||
SendJson = Callable[[dict], Awaitable[None]]
|
||||
ReceiveJson = Callable[[], Awaitable[Any]]
|
||||
CloseWebSocket = Callable[[int, str], Awaitable[None]]
|
||||
|
||||
|
||||
class OpenApiServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenApiWebSocketChatBridge:
|
||||
build_user_message_parts: Callable[[object], Awaitable[list]]
|
||||
create_attachment_from_file: Callable[[str, str], Awaitable[Any]]
|
||||
extract_web_search_refs: Callable[[str, list], dict]
|
||||
insert_user_message: Callable[[str, str, list], Awaitable[None]]
|
||||
save_bot_message: Callable[[str, list, dict, dict], Awaitable[Any]]
|
||||
|
||||
|
||||
class OpenApiService:
|
||||
def __init__(
|
||||
self,
|
||||
db: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.platform_manager = core_lifecycle.platform_manager
|
||||
self.platform_history_mgr = getattr(
|
||||
core_lifecycle,
|
||||
"platform_message_history_manager",
|
||||
None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def resolve_open_username(
|
||||
raw_username: str | None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
if raw_username is None:
|
||||
return None, "Missing key: username"
|
||||
username = str(raw_username).strip()
|
||||
if not username:
|
||||
return None, "username is empty"
|
||||
return username, None
|
||||
|
||||
def get_chat_config_list(self) -> list[dict]:
|
||||
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
|
||||
|
||||
result = []
|
||||
for conf_info in conf_list:
|
||||
conf_id = str(conf_info.get("id", "")).strip()
|
||||
result.append(
|
||||
{
|
||||
"id": conf_id,
|
||||
"name": str(conf_info.get("name", "")).strip(),
|
||||
"path": str(conf_info.get("path", "")).strip(),
|
||||
"is_default": conf_id == "default",
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def resolve_chat_config_id(
|
||||
post_data: dict,
|
||||
conf_list: list[dict],
|
||||
) -> tuple[str | None, str | None]:
|
||||
raw_config_id = post_data.get("config_id")
|
||||
raw_config_name = post_data.get("config_name")
|
||||
config_id = str(raw_config_id).strip() if raw_config_id is not None else ""
|
||||
config_name = (
|
||||
str(raw_config_name).strip() if raw_config_name is not None else ""
|
||||
)
|
||||
|
||||
if not config_id and not config_name:
|
||||
return None, None
|
||||
|
||||
conf_map = {item["id"]: item for item in conf_list}
|
||||
|
||||
if config_id:
|
||||
if config_id not in conf_map:
|
||||
return None, f"config_id not found: {config_id}"
|
||||
return config_id, None
|
||||
|
||||
if not config_name:
|
||||
return None, "config_name is empty"
|
||||
|
||||
matched = [item for item in conf_list if item["name"] == config_name]
|
||||
if not matched:
|
||||
return None, f"config_name not found: {config_name}"
|
||||
if len(matched) > 1:
|
||||
return (
|
||||
None,
|
||||
f"config_name is ambiguous, please use config_id: {config_name}",
|
||||
)
|
||||
|
||||
return matched[0]["id"], None
|
||||
|
||||
async def prepare_chat_send(
|
||||
self,
|
||||
post_data: dict,
|
||||
conf_list: list[dict],
|
||||
) -> tuple[str, str, str | None]:
|
||||
effective_username, username_err = self.resolve_open_username(
|
||||
post_data.get("username")
|
||||
)
|
||||
if username_err:
|
||||
raise OpenApiServiceError(username_err)
|
||||
if not effective_username:
|
||||
raise OpenApiServiceError("Invalid username")
|
||||
|
||||
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
|
||||
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
|
||||
if not session_id:
|
||||
session_id = str(uuid4())
|
||||
post_data["session_id"] = session_id
|
||||
|
||||
ensure_session_err = await self.ensure_chat_session(
|
||||
effective_username,
|
||||
session_id,
|
||||
)
|
||||
if ensure_session_err:
|
||||
raise OpenApiServiceError(ensure_session_err)
|
||||
|
||||
config_id, resolve_err = self.resolve_chat_config_id(post_data, conf_list)
|
||||
if resolve_err:
|
||||
raise OpenApiServiceError(resolve_err)
|
||||
|
||||
return effective_username, session_id, config_id
|
||||
|
||||
async def ensure_chat_session(
|
||||
self,
|
||||
username: str,
|
||||
session_id: str,
|
||||
) -> str | None:
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
if session:
|
||||
if session.creator != username:
|
||||
return "session_id belongs to another username"
|
||||
return None
|
||||
|
||||
try:
|
||||
await self.db.create_platform_session(
|
||||
creator=username,
|
||||
platform_id="webchat",
|
||||
session_id=session_id,
|
||||
is_group=0,
|
||||
)
|
||||
except Exception as exc:
|
||||
existing = await self.db.get_platform_session_by_id(session_id)
|
||||
if existing and existing.creator == username:
|
||||
return None
|
||||
logger.error("Failed to create chat session %s: %s", session_id, exc)
|
||||
return f"Failed to create session: {exc}"
|
||||
|
||||
return None
|
||||
|
||||
async def authenticate_api_key(
|
||||
self, raw_key: str | None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not raw_key:
|
||||
return False, "Missing API key"
|
||||
|
||||
key_hash = ApiKeyService.hash_key(raw_key)
|
||||
api_key = await self.db.get_active_api_key_by_hash(key_hash)
|
||||
if not api_key:
|
||||
return False, "Invalid API key"
|
||||
|
||||
if isinstance(api_key.scopes, list):
|
||||
scopes = api_key.scopes
|
||||
else:
|
||||
scopes = list(ALL_OPEN_API_SCOPES)
|
||||
|
||||
if "*" not in scopes and "chat" not in scopes:
|
||||
return False, "Insufficient API key scope"
|
||||
|
||||
await self.db.touch_api_key(api_key.key_id)
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
async def send_chat_ws_error(
|
||||
send_json: SendJson,
|
||||
message: str,
|
||||
code: str,
|
||||
) -> None:
|
||||
await send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"code": code,
|
||||
"data": message,
|
||||
}
|
||||
)
|
||||
|
||||
async def run_chat_websocket(
|
||||
self,
|
||||
*,
|
||||
raw_api_key: str | None,
|
||||
receive_json: ReceiveJson,
|
||||
send_json: SendJson,
|
||||
close: CloseWebSocket,
|
||||
conf_list: list[dict],
|
||||
chat_bridge: OpenApiWebSocketChatBridge,
|
||||
) -> None:
|
||||
authed, auth_err = await self.authenticate_api_key(raw_api_key)
|
||||
if not authed:
|
||||
message = auth_err or "Unauthorized"
|
||||
await self.send_chat_ws_error(send_json, message, "UNAUTHORIZED")
|
||||
await close(1008, message)
|
||||
return
|
||||
|
||||
async def send_error(message: str, code: str) -> None:
|
||||
await self.send_chat_ws_error(send_json, message, code)
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await receive_json()
|
||||
if not isinstance(message, dict):
|
||||
await send_error(
|
||||
"message must be an object",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
continue
|
||||
|
||||
msg_type = message.get("t", "send")
|
||||
if msg_type == "ping":
|
||||
await send_json({"type": "pong"})
|
||||
continue
|
||||
if msg_type != "send":
|
||||
await send_error(
|
||||
f"Unsupported message type: {msg_type}",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
continue
|
||||
|
||||
await self.handle_chat_ws_send(
|
||||
post_data=message,
|
||||
conf_list=conf_list,
|
||||
chat_bridge=chat_bridge,
|
||||
send_json=send_json,
|
||||
send_error=send_error,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Open API WS connection closed: %s", exc)
|
||||
|
||||
async def update_session_config_route(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
session_id: str,
|
||||
config_id: str | None,
|
||||
) -> str | None:
|
||||
if not config_id:
|
||||
return None
|
||||
|
||||
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
|
||||
try:
|
||||
if config_id == "default":
|
||||
await self.core_lifecycle.umop_config_router.delete_route(umo)
|
||||
else:
|
||||
await self.core_lifecycle.umop_config_router.update_route(
|
||||
umo, config_id
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to update chat config route for %s with %s: %s",
|
||||
umo,
|
||||
config_id,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return f"Failed to update chat config route: {exc}"
|
||||
return None
|
||||
|
||||
async def insert_webchat_user_message(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
effective_username: str,
|
||||
message_parts: list,
|
||||
) -> None:
|
||||
if self.platform_history_mgr is None:
|
||||
raise OpenApiServiceError("Platform message history manager is unavailable")
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=session_id,
|
||||
content={"type": "user", "message": message_parts},
|
||||
sender_id=effective_username,
|
||||
sender_name=effective_username,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_chat_send_error_code(message: str) -> str:
|
||||
if message in ("Missing key: username", "username is empty"):
|
||||
return "BAD_USER"
|
||||
if message.startswith("config_"):
|
||||
return "CONFIG_ERROR"
|
||||
if "session" in message:
|
||||
return "SESSION_ERROR"
|
||||
return "INVALID_MESSAGE"
|
||||
|
||||
async def handle_chat_ws_send(
|
||||
self,
|
||||
*,
|
||||
post_data: dict,
|
||||
conf_list: list[dict],
|
||||
chat_bridge: OpenApiWebSocketChatBridge,
|
||||
send_json: SendJson,
|
||||
send_error: Callable[[str, str], Awaitable[None]],
|
||||
) -> None:
|
||||
message = post_data.get("message")
|
||||
if message is None:
|
||||
await send_error("Missing key: message", "INVALID_MESSAGE")
|
||||
return
|
||||
|
||||
try:
|
||||
(
|
||||
effective_username,
|
||||
session_id,
|
||||
config_id,
|
||||
) = await self.prepare_chat_send(
|
||||
post_data,
|
||||
conf_list,
|
||||
)
|
||||
except OpenApiServiceError as exc:
|
||||
message = str(exc)
|
||||
await send_error(message, self.get_chat_send_error_code(message))
|
||||
return
|
||||
|
||||
config_err = await self.update_session_config_route(
|
||||
username=effective_username,
|
||||
session_id=session_id,
|
||||
config_id=config_id,
|
||||
)
|
||||
if config_err:
|
||||
await send_error(config_err, "CONFIG_ERROR")
|
||||
return
|
||||
|
||||
message_parts = await chat_bridge.build_user_message_parts(message)
|
||||
if not webchat_message_parts_have_content(message_parts):
|
||||
await send_error(
|
||||
"Message content is empty (reply only is not allowed)",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
return
|
||||
|
||||
message_id = str(post_data.get("message_id") or uuid4())
|
||||
selected_provider = post_data.get("selected_provider")
|
||||
selected_model = post_data.get("selected_model")
|
||||
enable_streaming = post_data.get("enable_streaming", True)
|
||||
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
|
||||
try:
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
|
||||
await chat_queue.put(
|
||||
(
|
||||
effective_username,
|
||||
session_id,
|
||||
{
|
||||
"message": message_parts,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"enable_streaming": enable_streaming,
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
|
||||
await chat_bridge.insert_user_message(
|
||||
session_id,
|
||||
effective_username,
|
||||
message_parts_for_storage,
|
||||
)
|
||||
|
||||
await send_json(
|
||||
{
|
||||
"type": "session_id",
|
||||
"data": None,
|
||||
"session_id": session_id,
|
||||
"message_id": message_id,
|
||||
}
|
||||
)
|
||||
|
||||
message_accumulator = BotMessageAccumulator()
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
if "message_id" in result and result["message_id"] != message_id:
|
||||
logger.warning("openapi ws stream message_id mismatch")
|
||||
continue
|
||||
|
||||
result_text = result.get("data", "")
|
||||
msg_type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
chain_type = result.get("chain_type")
|
||||
|
||||
if chain_type == "agent_stats":
|
||||
try:
|
||||
stats_info = {
|
||||
"type": "agent_stats",
|
||||
"data": json.loads(result_text),
|
||||
}
|
||||
await send_json(stats_info)
|
||||
agent_stats = stats_info["data"]
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
await send_json(result)
|
||||
|
||||
if msg_type == "plain":
|
||||
message_accumulator.add_plain(
|
||||
result_text,
|
||||
chain_type=chain_type,
|
||||
streaming=streaming,
|
||||
)
|
||||
elif msg_type in {"image", "record", "file", "video"}:
|
||||
filename = str(result_text).replace(f"[{msg_type.upper()}]", "")
|
||||
part = await chat_bridge.create_attachment_from_file(
|
||||
filename,
|
||||
msg_type,
|
||||
)
|
||||
message_accumulator.add_attachment(part)
|
||||
|
||||
should_save = False
|
||||
if msg_type == "end":
|
||||
should_save = bool(
|
||||
message_accumulator.has_content() or refs or agent_stats
|
||||
)
|
||||
elif (streaming and msg_type == "complete") or not streaming:
|
||||
if chain_type not in ("tool_call", "tool_call_result"):
|
||||
should_save = True
|
||||
|
||||
if should_save:
|
||||
message_parts_to_save = message_accumulator.build_message_parts(
|
||||
include_pending_tool_calls=True
|
||||
)
|
||||
plain_text = collect_plain_text_from_message_parts(
|
||||
message_parts_to_save
|
||||
)
|
||||
try:
|
||||
refs = chat_bridge.extract_web_search_refs(
|
||||
plain_text,
|
||||
message_parts_to_save,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
f"Open API WS failed to extract web search refs: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
saved_record = await chat_bridge.save_bot_message(
|
||||
session_id,
|
||||
message_parts_to_save,
|
||||
agent_stats,
|
||||
refs,
|
||||
)
|
||||
if saved_record:
|
||||
await send_json(
|
||||
{
|
||||
"type": "message_saved",
|
||||
"data": {
|
||||
"id": saved_record.id,
|
||||
"created_at": to_utc_isoformat(
|
||||
saved_record.created_at
|
||||
),
|
||||
},
|
||||
"session_id": session_id,
|
||||
}
|
||||
)
|
||||
message_accumulator = BotMessageAccumulator()
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
if msg_type == "end":
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.exception(f"Open API WS chat failed: {exc}", exc_info=True)
|
||||
await send_error(f"Failed to process message: {exc}", "PROCESSING_ERROR")
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
async def get_chat_sessions(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
page_raw,
|
||||
page_size_raw,
|
||||
platform_id: str | None,
|
||||
) -> dict:
|
||||
try:
|
||||
page = int(page_raw)
|
||||
page_size = int(page_size_raw)
|
||||
except ValueError as exc:
|
||||
raise OpenApiServiceError("page and page_size must be integers") from exc
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 1
|
||||
if page_size > 100:
|
||||
page_size = 100
|
||||
|
||||
(
|
||||
paginated_sessions,
|
||||
total,
|
||||
) = await self.db.get_platform_sessions_by_creator_paginated(
|
||||
creator=username,
|
||||
platform_id=platform_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
exclude_project_sessions=True,
|
||||
)
|
||||
|
||||
sessions_data = []
|
||||
for item in paginated_sessions:
|
||||
session = item["session"]
|
||||
sessions_data.append(
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"platform_id": session.platform_id,
|
||||
"creator": session.creator,
|
||||
"display_name": session.display_name,
|
||||
"is_group": session.is_group,
|
||||
"created_at": to_utc_isoformat(session.created_at),
|
||||
"updated_at": to_utc_isoformat(session.updated_at),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"sessions": sessions_data,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
async def get_chat_sessions_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
username: str | None,
|
||||
page,
|
||||
page_size,
|
||||
platform_id: str | None,
|
||||
) -> dict:
|
||||
resolved_username, username_err = self.resolve_open_username(username)
|
||||
if username_err:
|
||||
raise OpenApiServiceError(username_err)
|
||||
if not resolved_username:
|
||||
raise OpenApiServiceError("Invalid username")
|
||||
|
||||
return await self.get_chat_sessions(
|
||||
username=resolved_username,
|
||||
page_raw=page,
|
||||
page_size_raw=page_size,
|
||||
platform_id=platform_id,
|
||||
)
|
||||
|
||||
def get_chat_configs(self) -> dict:
|
||||
return {"configs": self.get_chat_config_list()}
|
||||
|
||||
async def build_message_chain_from_payload(self, message_payload: str | list):
|
||||
return await build_message_chain_from_payload(
|
||||
message_payload,
|
||||
get_attachment_by_id=self.db.get_attachment_by_id,
|
||||
strict=True,
|
||||
)
|
||||
|
||||
async def send_message(self, post_data: object) -> None:
|
||||
payload = post_data if isinstance(post_data, dict) else {}
|
||||
message_payload = payload.get("message", {})
|
||||
umo = payload.get("umo")
|
||||
|
||||
if message_payload is None:
|
||||
raise OpenApiServiceError("Missing key: message")
|
||||
if not umo:
|
||||
raise OpenApiServiceError("Missing key: umo")
|
||||
|
||||
try:
|
||||
session = MessageSesion.from_str(str(umo))
|
||||
except Exception as exc:
|
||||
raise OpenApiServiceError(f"Invalid umo: {exc}") from exc
|
||||
|
||||
platform_id = session.platform_name
|
||||
platform_inst = next(
|
||||
(
|
||||
inst
|
||||
for inst in self.platform_manager.platform_insts
|
||||
if inst.meta().id == platform_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not platform_inst:
|
||||
raise OpenApiServiceError(
|
||||
f"Bot not found or not running for platform: {platform_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
message_chain = await self.build_message_chain_from_payload(message_payload)
|
||||
await platform_inst.send_by_session(session, message_chain)
|
||||
except OpenApiServiceError:
|
||||
raise
|
||||
except ValueError as exc:
|
||||
raise OpenApiServiceError(str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"Open API send_message failed: {exc}", exc_info=True)
|
||||
raise OpenApiServiceError(f"Failed to send message: {exc}") from exc
|
||||
|
||||
def get_bots(self) -> dict:
|
||||
bot_ids = []
|
||||
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
|
||||
platform_id = platform.get("id") if isinstance(platform, dict) else None
|
||||
if (
|
||||
isinstance(platform_id, str)
|
||||
and platform_id
|
||||
and platform_id not in bot_ids
|
||||
):
|
||||
bot_ids.append(platform_id)
|
||||
return {"bot_ids": bot_ids}
|
||||
293
astrbot/dashboard/services/persona_service.py
Normal file
293
astrbot/dashboard/services/persona_service.py
Normal file
@@ -0,0 +1,293 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.sentinels import NOT_GIVEN
|
||||
|
||||
|
||||
class PersonaServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PersonaService:
|
||||
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
self.persona_mgr = core_lifecycle.persona_mgr
|
||||
|
||||
async def list_personas(
|
||||
self,
|
||||
folder_id: str | None,
|
||||
filter_by_folder: bool,
|
||||
) -> list[dict]:
|
||||
if filter_by_folder:
|
||||
personas = await self.persona_mgr.get_personas_by_folder(
|
||||
folder_id if folder_id else None
|
||||
)
|
||||
else:
|
||||
personas = await self.persona_mgr.get_all_personas()
|
||||
return [self.serialize_persona(persona) for persona in personas]
|
||||
|
||||
async def list_personas_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
folder_id: str | None,
|
||||
has_folder_id: bool,
|
||||
) -> list[dict]:
|
||||
return await self.list_personas(folder_id, has_folder_id)
|
||||
|
||||
async def get_persona_detail(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
persona_id = payload.get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
raise PersonaServiceError("缺少必要参数: persona_id")
|
||||
|
||||
persona = await self.persona_mgr.get_persona(persona_id)
|
||||
if not persona:
|
||||
raise PersonaServiceError("人格不存在")
|
||||
|
||||
return self.serialize_persona(persona)
|
||||
|
||||
async def create_persona(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
persona_id = str(payload.get("persona_id", "")).strip()
|
||||
system_prompt = str(payload.get("system_prompt", "")).strip()
|
||||
begin_dialogs = payload.get("begin_dialogs", [])
|
||||
tools = payload.get("tools")
|
||||
skills = payload.get("skills")
|
||||
custom_error_message = self._normalize_custom_error_message(
|
||||
payload.get("custom_error_message")
|
||||
)
|
||||
folder_id = payload.get("folder_id")
|
||||
sort_order = payload.get("sort_order", 0)
|
||||
|
||||
if not persona_id:
|
||||
raise PersonaServiceError("人格ID不能为空")
|
||||
if not system_prompt:
|
||||
raise PersonaServiceError("系统提示词不能为空")
|
||||
|
||||
self._validate_begin_dialogs(begin_dialogs)
|
||||
|
||||
persona = await self.persona_mgr.create_persona(
|
||||
persona_id=persona_id,
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs if begin_dialogs else None,
|
||||
tools=tools if tools else None,
|
||||
skills=skills if skills else None,
|
||||
custom_error_message=custom_error_message,
|
||||
folder_id=folder_id,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "人格创建成功",
|
||||
"persona": self.serialize_persona(persona, empty_lists_for_tools=True),
|
||||
}
|
||||
|
||||
async def update_persona(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
persona_id = payload.get("persona_id")
|
||||
system_prompt = payload.get("system_prompt")
|
||||
begin_dialogs = payload.get("begin_dialogs")
|
||||
has_tools = "tools" in payload
|
||||
tools = payload.get("tools")
|
||||
has_skills = "skills" in payload
|
||||
skills = payload.get("skills")
|
||||
has_custom_error_message = "custom_error_message" in payload
|
||||
custom_error_message = payload.get("custom_error_message")
|
||||
|
||||
if not persona_id:
|
||||
raise PersonaServiceError("缺少必要参数: persona_id")
|
||||
|
||||
if has_custom_error_message:
|
||||
custom_error_message = self._normalize_custom_error_message(
|
||||
custom_error_message
|
||||
)
|
||||
|
||||
if begin_dialogs is not None:
|
||||
self._validate_begin_dialogs(begin_dialogs)
|
||||
|
||||
update_kwargs = {
|
||||
"persona_id": persona_id,
|
||||
"system_prompt": system_prompt,
|
||||
"begin_dialogs": begin_dialogs,
|
||||
}
|
||||
if has_tools:
|
||||
update_kwargs["tools"] = tools
|
||||
if has_skills:
|
||||
update_kwargs["skills"] = skills
|
||||
if has_custom_error_message:
|
||||
update_kwargs["custom_error_message"] = custom_error_message
|
||||
|
||||
await self.persona_mgr.update_persona(**update_kwargs)
|
||||
return {"message": "人格更新成功"}
|
||||
|
||||
async def delete_persona(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
persona_id = payload.get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
raise PersonaServiceError("缺少必要参数: persona_id")
|
||||
|
||||
await self.persona_mgr.delete_persona(persona_id)
|
||||
return {"message": "人格删除成功"}
|
||||
|
||||
async def move_persona(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
persona_id = payload.get("persona_id")
|
||||
folder_id = payload.get("folder_id")
|
||||
|
||||
if not persona_id:
|
||||
raise PersonaServiceError("缺少必要参数: persona_id")
|
||||
|
||||
await self.persona_mgr.move_persona_to_folder(persona_id, folder_id)
|
||||
return {"message": "人格移动成功"}
|
||||
|
||||
async def list_folders(self, parent_id: str | None) -> list[dict]:
|
||||
folders = await self.persona_mgr.get_folders(parent_id)
|
||||
return [self.serialize_folder(folder) for folder in folders]
|
||||
|
||||
async def list_folders_from_legacy_query(
|
||||
self,
|
||||
parent_id: str | None,
|
||||
) -> list[dict]:
|
||||
if parent_id == "":
|
||||
parent_id = None
|
||||
return await self.list_folders(parent_id)
|
||||
|
||||
async def get_folder_tree(self):
|
||||
return await self.persona_mgr.get_folder_tree()
|
||||
|
||||
async def get_folder_detail(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
folder_id = payload.get("folder_id")
|
||||
|
||||
if not folder_id:
|
||||
raise PersonaServiceError("缺少必要参数: folder_id")
|
||||
|
||||
folder = await self.persona_mgr.get_folder(folder_id)
|
||||
if not folder:
|
||||
raise PersonaServiceError("文件夹不存在")
|
||||
|
||||
return self.serialize_folder(folder)
|
||||
|
||||
async def create_folder(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
name = str(payload.get("name", "")).strip()
|
||||
parent_id = payload.get("parent_id")
|
||||
description = payload.get("description")
|
||||
sort_order = payload.get("sort_order", 0)
|
||||
|
||||
if not name:
|
||||
raise PersonaServiceError("文件夹名称不能为空")
|
||||
|
||||
folder = await self.persona_mgr.create_folder(
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
description=description,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "文件夹创建成功",
|
||||
"folder": self.serialize_folder(folder),
|
||||
}
|
||||
|
||||
async def update_folder(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
folder_id = payload.get("folder_id")
|
||||
name = payload.get("name")
|
||||
parent_id = payload.get("parent_id") if "parent_id" in payload else NOT_GIVEN
|
||||
description = (
|
||||
payload.get("description") if "description" in payload else NOT_GIVEN
|
||||
)
|
||||
sort_order = payload.get("sort_order")
|
||||
|
||||
if not folder_id:
|
||||
raise PersonaServiceError("缺少必要参数: folder_id")
|
||||
|
||||
await self.persona_mgr.update_folder(
|
||||
folder_id=folder_id,
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
description=description,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
return {"message": "文件夹更新成功"}
|
||||
|
||||
async def delete_folder(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
folder_id = payload.get("folder_id")
|
||||
|
||||
if not folder_id:
|
||||
raise PersonaServiceError("缺少必要参数: folder_id")
|
||||
|
||||
await self.persona_mgr.delete_folder(folder_id)
|
||||
return {"message": "文件夹删除成功"}
|
||||
|
||||
async def reorder_items(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
items = payload.get("items", [])
|
||||
|
||||
if not items:
|
||||
raise PersonaServiceError("items 不能为空")
|
||||
|
||||
for item in items:
|
||||
if not all(key in item for key in ("id", "type", "sort_order")):
|
||||
raise PersonaServiceError(
|
||||
"每个 item 必须包含 id, type, sort_order 字段"
|
||||
)
|
||||
if item["type"] not in ("persona", "folder"):
|
||||
raise PersonaServiceError("type 字段必须是 'persona' 或 'folder'")
|
||||
|
||||
await self.persona_mgr.batch_update_sort_order(items)
|
||||
return {"message": "排序更新成功"}
|
||||
|
||||
@staticmethod
|
||||
def serialize_persona(persona, empty_lists_for_tools: bool = False) -> dict:
|
||||
return {
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": (persona.tools or []) if empty_lists_for_tools else persona.tools,
|
||||
"skills": (persona.skills or [])
|
||||
if empty_lists_for_tools
|
||||
else persona.skills,
|
||||
"custom_error_message": persona.custom_error_message,
|
||||
"folder_id": persona.folder_id,
|
||||
"sort_order": persona.sort_order,
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def serialize_folder(folder) -> dict:
|
||||
return {
|
||||
"folder_id": folder.folder_id,
|
||||
"name": folder.name,
|
||||
"parent_id": folder.parent_id,
|
||||
"description": folder.description,
|
||||
"sort_order": folder.sort_order,
|
||||
"created_at": folder.created_at.isoformat() if folder.created_at else None,
|
||||
"updated_at": folder.updated_at.isoformat() if folder.updated_at else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_custom_error_message(value):
|
||||
if value is not None:
|
||||
if not isinstance(value, str):
|
||||
raise PersonaServiceError("自定义报错回复信息必须是字符串")
|
||||
return value.strip() or None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_begin_dialogs(begin_dialogs) -> None:
|
||||
if begin_dialogs and len(begin_dialogs) % 2 != 0:
|
||||
raise PersonaServiceError("预设对话数量必须为偶数(用户和助手轮流对话)")
|
||||
|
||||
@staticmethod
|
||||
def _payload(data: object) -> dict:
|
||||
return data if isinstance(data, dict) else {}
|
||||
218
astrbot/dashboard/services/platform_service.py
Normal file
218
astrbot/dashboard/services/platform_service.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import string
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.sources.dingtalk.app_registration import (
|
||||
poll_dingtalk_app_registration_once,
|
||||
request_dingtalk_app_registration,
|
||||
)
|
||||
from astrbot.core.platform.sources.lark.app_registration import (
|
||||
poll_app_registration_once,
|
||||
request_app_registration,
|
||||
)
|
||||
from astrbot.core.platform.sources.lark.bot_info import request_lark_bot_info
|
||||
from astrbot.core.platform.sources.weixin_oc.login_registration import (
|
||||
poll_weixin_oc_login_once,
|
||||
request_weixin_oc_login_qr,
|
||||
)
|
||||
|
||||
|
||||
class PlatformServiceError(Exception):
|
||||
def __init__(self, message: str, status_code: int = 500) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
def random_platform_id_suffix() -> str:
|
||||
return "_" + "".join(secrets.choice(string.ascii_lowercase) for _ in range(4))
|
||||
|
||||
|
||||
class PlatformService:
|
||||
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
self.platform_manager = core_lifecycle.platform_manager
|
||||
|
||||
async def handle_webhook_callback(self, webhook_uuid: str, request_obj):
|
||||
platform_adapter = self.find_platform_by_uuid(webhook_uuid)
|
||||
|
||||
if not platform_adapter:
|
||||
logger.warning(f"未找到 webhook_uuid 为 {webhook_uuid} 的平台")
|
||||
raise PlatformServiceError("未找到对应平台", 404)
|
||||
|
||||
try:
|
||||
return await platform_adapter.webhook_callback(request_obj)
|
||||
except NotImplementedError as exc:
|
||||
logger.error(
|
||||
f"平台 {platform_adapter.meta().name} 未实现 webhook_callback 方法"
|
||||
)
|
||||
raise PlatformServiceError("平台未支持统一 Webhook 模式", 500) from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"处理 webhook 回调时发生错误: {exc}", exc_info=True)
|
||||
raise PlatformServiceError("处理回调失败", 500) from exc
|
||||
|
||||
def find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None:
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.config.get("webhook_uuid") == webhook_uuid:
|
||||
if platform.unified_webhook():
|
||||
return platform
|
||||
return None
|
||||
|
||||
def get_platform_stats(self):
|
||||
try:
|
||||
return self.platform_manager.get_all_stats()
|
||||
except Exception as exc:
|
||||
logger.error(f"获取平台统计信息失败: {exc}", exc_info=True)
|
||||
raise PlatformServiceError(f"获取统计信息失败: {exc}", 500) from exc
|
||||
|
||||
async def handle_platform_registration(
|
||||
self,
|
||||
platform_type: str,
|
||||
payload: dict,
|
||||
) -> dict:
|
||||
try:
|
||||
action = str(payload.get("action", "")).strip().lower()
|
||||
if not action:
|
||||
raise PlatformServiceError("Missing action", 400)
|
||||
|
||||
platform_config = payload.get("platform_config")
|
||||
if not isinstance(platform_config, dict):
|
||||
platform_config = {}
|
||||
|
||||
if platform_type == "lark":
|
||||
return await self._handle_lark_registration(
|
||||
action,
|
||||
payload,
|
||||
platform_config,
|
||||
)
|
||||
if platform_type == "weixin_oc":
|
||||
return await self._handle_weixin_oc_registration(
|
||||
action,
|
||||
payload,
|
||||
platform_config,
|
||||
)
|
||||
if platform_type == "dingtalk":
|
||||
return await self._handle_dingtalk_registration(action, payload)
|
||||
|
||||
raise PlatformServiceError(
|
||||
f"Unsupported platform registration: {platform_type}",
|
||||
404,
|
||||
)
|
||||
except PlatformServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"处理平台一键创建请求失败: {exc}", exc_info=True)
|
||||
raise PlatformServiceError(str(exc), 500) from exc
|
||||
|
||||
async def _handle_lark_registration(
|
||||
self,
|
||||
action: str,
|
||||
payload: dict,
|
||||
platform_config: dict,
|
||||
) -> dict:
|
||||
domain = str(platform_config.get("domain") or "").strip()
|
||||
|
||||
if action == "start":
|
||||
registration = await request_app_registration(domain)
|
||||
return {
|
||||
"status": "pending",
|
||||
"device_code": registration.device_code,
|
||||
"registration_code": registration.device_code,
|
||||
"user_code": registration.user_code,
|
||||
"verification_uri": registration.verification_uri,
|
||||
"verification_uri_complete": registration.verification_uri_complete,
|
||||
"expires_in": registration.expires_in,
|
||||
"interval": registration.interval,
|
||||
}
|
||||
|
||||
if action == "poll":
|
||||
device_code = str(
|
||||
payload.get("device_code") or payload.get("registration_code") or ""
|
||||
).strip()
|
||||
if not device_code:
|
||||
raise PlatformServiceError("Missing device_code", 400)
|
||||
result = await poll_app_registration_once(
|
||||
domain=domain,
|
||||
device_code=device_code,
|
||||
)
|
||||
if result.get("status") == "created":
|
||||
try:
|
||||
bot_info = await request_lark_bot_info(
|
||||
domain=str(result.get("domain") or domain),
|
||||
app_id=str(result.get("app_id") or ""),
|
||||
app_secret=str(result.get("app_secret") or ""),
|
||||
)
|
||||
if bot_info.app_name:
|
||||
result["bot_name"] = bot_info.app_name
|
||||
if bot_info.open_id:
|
||||
result["bot_open_id"] = bot_info.open_id
|
||||
except Exception as exc:
|
||||
logger.error(f"获取飞书机器人信息失败: {exc}", exc_info=True)
|
||||
return result
|
||||
|
||||
raise PlatformServiceError(f"Unsupported action: {action}", 400)
|
||||
|
||||
async def _handle_dingtalk_registration(
|
||||
self,
|
||||
action: str,
|
||||
payload: dict,
|
||||
) -> dict:
|
||||
if action == "start":
|
||||
registration = await request_dingtalk_app_registration()
|
||||
return {
|
||||
"status": "pending",
|
||||
"device_code": registration.device_code,
|
||||
"registration_code": registration.device_code,
|
||||
"user_code": registration.user_code,
|
||||
"verification_uri": registration.verification_uri,
|
||||
"verification_uri_complete": registration.verification_uri_complete,
|
||||
"expires_in": registration.expires_in,
|
||||
"interval": registration.interval,
|
||||
}
|
||||
|
||||
if action == "poll":
|
||||
device_code = str(
|
||||
payload.get("device_code") or payload.get("registration_code") or ""
|
||||
).strip()
|
||||
if not device_code:
|
||||
raise PlatformServiceError("Missing device_code", 400)
|
||||
result = await poll_dingtalk_app_registration_once(device_code)
|
||||
if result.get("status") == "created":
|
||||
result["platform_id_suffix"] = random_platform_id_suffix()
|
||||
return result
|
||||
|
||||
raise PlatformServiceError(f"Unsupported action: {action}", 400)
|
||||
|
||||
async def _handle_weixin_oc_registration(
|
||||
self,
|
||||
action: str,
|
||||
payload: dict,
|
||||
platform_config: dict,
|
||||
) -> dict:
|
||||
if action == "start":
|
||||
registration = await request_weixin_oc_login_qr(platform_config)
|
||||
return {
|
||||
"status": "pending",
|
||||
"registration_code": registration.qrcode,
|
||||
"qrcode": registration.qrcode,
|
||||
"qrcode_img_content": registration.qrcode_img_content,
|
||||
"interval": registration.interval,
|
||||
}
|
||||
|
||||
if action == "poll":
|
||||
qrcode = str(
|
||||
payload.get("qrcode") or payload.get("registration_code") or ""
|
||||
).strip()
|
||||
if not qrcode:
|
||||
raise PlatformServiceError("Missing qrcode", 400)
|
||||
result = await poll_weixin_oc_login_once(
|
||||
platform_config=platform_config,
|
||||
qrcode=qrcode,
|
||||
)
|
||||
if result.get("status") == "created":
|
||||
result["platform_id_suffix"] = random_platform_id_suffix()
|
||||
return result
|
||||
|
||||
raise PlatformServiceError(f"Unsupported action: {action}", 400)
|
||||
923
astrbot/dashboard/services/plugin_page_service.py
Normal file
923
astrbot/dashboard/services/plugin_page_service.py
Normal file
@@ -0,0 +1,923 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import posixpath
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qsl, quote, urlencode, urlsplit, urlunsplit
|
||||
|
||||
import aiofiles
|
||||
import jwt
|
||||
from aiofiles import ospath as aio_ospath
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.star.star import StarMetadata
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
|
||||
PLUGIN_PAGE_ASSET_TOKEN_TYPE = "plugin_page_asset"
|
||||
PLUGIN_PAGE_ASSET_TOKEN_TTL_SECONDS = 60
|
||||
PLUGIN_PAGE_ROOT_DIR_NAME = "pages"
|
||||
PLUGIN_PAGE_ENTRY_FILE_NAME = "index.html"
|
||||
PLUGIN_PAGE_BRIDGE_FILE = (
|
||||
Path(__file__).resolve().parent.parent / "plugin_page_bridge.js"
|
||||
)
|
||||
|
||||
_HTML_ASSET_ATTR_RE = re.compile(
|
||||
r"(?P<attr>src|href)=(?P<quote>[\"\'])(?P<url>.*?)(?P=quote)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_CSS_URL_RE = re.compile(
|
||||
r"url\(\s*(?P<quote>[\"\']?)(?P<url>.*?)(?P=quote)\s*\)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_JS_DYNAMIC_IMPORT_RE = re.compile(
|
||||
r"(?P<prefix>\bimport\s*\(\s*)(?P<quote>[\"\'])(?P<url>.*?)(?P=quote)(?P<suffix>\s*\))",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_JS_MODULE_FROM_RE = re.compile(
|
||||
r"(?P<prefix>\b(?:import|export)\s+(?:[^;]*?\s+from\s+))(?P<quote>[\"\'])(?P<url>.*?)(?P=quote)",
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
_JS_SIDE_EFFECT_IMPORT_RE = re.compile(
|
||||
r"(?P<prefix>\bimport\s+)(?P<quote>[\"\'])(?P<url>[^\"'\r\n]+)(?P=quote)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginPage:
|
||||
name: str
|
||||
title: str
|
||||
entry_file: str = PLUGIN_PAGE_ENTRY_FILE_NAME
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginPageContentPayload:
|
||||
content: str | bytes
|
||||
content_type: str
|
||||
|
||||
|
||||
class PluginPageServiceError(Exception):
|
||||
def __init__(self, message: str, status_code: int = 400) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class PluginPageService:
|
||||
def __init__(
|
||||
self,
|
||||
plugin_manager: PluginManager,
|
||||
core_lifecycle: AstrBotCoreLifecycle | None = None,
|
||||
config: AstrBotConfig | None = None,
|
||||
) -> None:
|
||||
self.plugin_manager = plugin_manager
|
||||
self.config = config or (
|
||||
core_lifecycle.astrbot_config if core_lifecycle is not None else None
|
||||
)
|
||||
self.bridge_file = PLUGIN_PAGE_BRIDGE_FILE
|
||||
|
||||
def _jwt_secret(self) -> str | None:
|
||||
if self.config is None:
|
||||
return None
|
||||
return self.config.get("dashboard", {}).get("jwt_secret")
|
||||
|
||||
def get_plugin_metadata_by_name(self, plugin_name: str) -> StarMetadata | None:
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
if plugin.name == plugin_name:
|
||||
return plugin
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_by_path(source: dict | None, key: str):
|
||||
if not isinstance(source, dict) or not key:
|
||||
return None
|
||||
current = source
|
||||
for part in key.split("."):
|
||||
if not isinstance(current, dict) or part not in current:
|
||||
return None
|
||||
current = current[part]
|
||||
return current
|
||||
|
||||
@staticmethod
|
||||
def apply_theme_to_html(html: str, theme: str) -> str:
|
||||
def _replace_html_tag(m: re.Match) -> str:
|
||||
attrs = m.group(1) or ""
|
||||
attrs = re.sub(
|
||||
r'\s+data-theme\s*=\s*["\'][^"\']*["\']',
|
||||
"",
|
||||
attrs,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
return f'<html{attrs} data-theme="{theme}">'
|
||||
|
||||
html = re.sub(
|
||||
r"<html(\b[^>]*)>",
|
||||
_replace_html_tag,
|
||||
html,
|
||||
count=1,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
meta_tag = f'<meta name="color-scheme" content="{theme}">'
|
||||
html = re.sub(
|
||||
r'<meta\s[^>]*name\s*=\s*["\']color-scheme["\'][^>]*>',
|
||||
"",
|
||||
html,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
head_match = re.search(r"<head\b[^>]*>", html, re.IGNORECASE)
|
||||
if head_match:
|
||||
html = html.replace(
|
||||
head_match.group(0), f"{head_match.group(0)}{meta_tag}", 1
|
||||
)
|
||||
else:
|
||||
html = re.sub(
|
||||
r"(<html\b[^>]*>)",
|
||||
rf"\1<head>{meta_tag}</head>",
|
||||
html,
|
||||
count=1,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
return html
|
||||
|
||||
def build_initial_context(
|
||||
self,
|
||||
*,
|
||||
asset_token: str,
|
||||
jwt_secret: str | None = None,
|
||||
locale: str,
|
||||
theme: str | None,
|
||||
) -> dict | None:
|
||||
if not asset_token:
|
||||
return None
|
||||
jwt_secret = jwt_secret or self._jwt_secret()
|
||||
if not isinstance(jwt_secret, str) or not jwt_secret.strip():
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = jwt.decode(asset_token, jwt_secret, algorithms=["HS256"])
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
if payload.get("token_type") != PLUGIN_PAGE_ASSET_TOKEN_TYPE:
|
||||
return None
|
||||
|
||||
plugin_name = payload.get("plugin_name")
|
||||
page_name = payload.get("page_name")
|
||||
if not isinstance(plugin_name, str) or not isinstance(page_name, str):
|
||||
return None
|
||||
|
||||
plugin = self.get_plugin_metadata_by_name(plugin_name)
|
||||
if not plugin:
|
||||
return None
|
||||
|
||||
resolved_locale = (
|
||||
payload.get("locale") if isinstance(payload.get("locale"), str) else locale
|
||||
)
|
||||
plugin_i18n = plugin.i18n or {}
|
||||
try:
|
||||
plugin_root = self.get_plugin_root_dir(plugin)
|
||||
fresh_i18n = PluginManager._load_plugin_i18n(str(plugin_root))
|
||||
if fresh_i18n:
|
||||
plugin_i18n = fresh_i18n
|
||||
except (OSError, ValueError):
|
||||
pass
|
||||
|
||||
locale_data = plugin_i18n.get(resolved_locale)
|
||||
display_name = (
|
||||
self.get_by_path(locale_data, "metadata.display_name")
|
||||
or plugin.display_name
|
||||
or plugin.name
|
||||
)
|
||||
page_title = (
|
||||
self.get_by_path(locale_data, f"pages.{page_name}.title") or page_name
|
||||
)
|
||||
|
||||
return {
|
||||
"pluginName": plugin.name,
|
||||
"displayName": display_name,
|
||||
"pageName": page_name,
|
||||
"pageTitle": page_title,
|
||||
"locale": resolved_locale,
|
||||
"i18n": plugin_i18n,
|
||||
"isDark": theme == "dark",
|
||||
}
|
||||
|
||||
async def get_plugin_page_entry_config(
|
||||
self,
|
||||
*,
|
||||
plugin_name: str | None,
|
||||
page_name: str | None,
|
||||
jwt_secret: str | None = None,
|
||||
username: str | None,
|
||||
locale: str,
|
||||
) -> dict:
|
||||
if not plugin_name:
|
||||
raise PluginPageServiceError("缺少插件名")
|
||||
if not page_name:
|
||||
raise PluginPageServiceError("缺少 Page 名称")
|
||||
|
||||
plugin = self.get_plugin_metadata_by_name(plugin_name)
|
||||
if not plugin:
|
||||
raise PluginPageServiceError("插件不存在")
|
||||
if not plugin.activated:
|
||||
raise PluginPageServiceError("插件未启用")
|
||||
|
||||
page = await self.serialize_plugin_page_for_request(
|
||||
plugin,
|
||||
page_name,
|
||||
include_content_path=True,
|
||||
jwt_secret=jwt_secret,
|
||||
username=username,
|
||||
locale=locale,
|
||||
)
|
||||
if not page:
|
||||
raise PluginPageServiceError("插件 Page 不存在")
|
||||
return page
|
||||
|
||||
async def serialize_plugin_page_for_request(
|
||||
self,
|
||||
plugin: StarMetadata,
|
||||
page_name: str,
|
||||
*,
|
||||
include_content_path: bool = False,
|
||||
jwt_secret: str | None = None,
|
||||
username: str | None,
|
||||
locale: str,
|
||||
) -> dict | None:
|
||||
asset_token = ""
|
||||
if include_content_path:
|
||||
plugin_name = plugin.name.strip() if isinstance(plugin.name, str) else ""
|
||||
asset_token = (
|
||||
self.issue_plugin_page_asset_token(
|
||||
plugin_name=plugin_name,
|
||||
page_name=page_name,
|
||||
jwt_secret=jwt_secret or self._jwt_secret(),
|
||||
username=username,
|
||||
locale=locale,
|
||||
)
|
||||
or ""
|
||||
)
|
||||
return await self.serialize_plugin_page(
|
||||
plugin,
|
||||
page_name,
|
||||
include_content_path=include_content_path,
|
||||
asset_token=asset_token,
|
||||
)
|
||||
|
||||
def prepare_plugin_page_query_params(
|
||||
self,
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
*,
|
||||
asset_token: str,
|
||||
jwt_secret: str | None = None,
|
||||
username: str | None,
|
||||
locale: str,
|
||||
theme: str | None,
|
||||
) -> dict[str, str] | None:
|
||||
if not asset_token:
|
||||
asset_token = (
|
||||
self.issue_plugin_page_asset_token(
|
||||
plugin_name=plugin_name,
|
||||
page_name=page_name,
|
||||
jwt_secret=jwt_secret or self._jwt_secret(),
|
||||
username=username,
|
||||
locale=locale,
|
||||
)
|
||||
or ""
|
||||
)
|
||||
|
||||
if not asset_token and not theme:
|
||||
return None
|
||||
|
||||
params: dict[str, str] = {}
|
||||
if asset_token:
|
||||
params["asset_token"] = asset_token
|
||||
if theme:
|
||||
params["theme"] = theme
|
||||
return params
|
||||
|
||||
async def serve_bridge_sdk(
|
||||
self,
|
||||
*,
|
||||
asset_token: str,
|
||||
jwt_secret: str | None = None,
|
||||
locale: str,
|
||||
theme: str | None,
|
||||
) -> PluginPageContentPayload:
|
||||
if not self.bridge_file.is_file():
|
||||
raise PluginPageServiceError(
|
||||
"Plugin Page bridge SDK not found",
|
||||
status_code=404,
|
||||
)
|
||||
bridge_js = await self.read_plugin_page_text(self.bridge_file)
|
||||
initial_context = self.build_initial_context(
|
||||
asset_token=asset_token,
|
||||
jwt_secret=jwt_secret,
|
||||
locale=locale,
|
||||
theme=theme,
|
||||
)
|
||||
if initial_context:
|
||||
context_json = json.dumps(initial_context, ensure_ascii=False)
|
||||
bridge_js += (
|
||||
f"\n;window.AstrBotPluginPage?.__setInitialContext({context_json});\n"
|
||||
)
|
||||
return PluginPageContentPayload(
|
||||
content=bridge_js,
|
||||
content_type="application/javascript; charset=utf-8",
|
||||
)
|
||||
|
||||
async def serve_page_content(
|
||||
self,
|
||||
*,
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
asset_path: str,
|
||||
asset_token: str,
|
||||
jwt_secret: str | None = None,
|
||||
username: str | None,
|
||||
locale: str,
|
||||
theme: str | None,
|
||||
) -> PluginPageContentPayload:
|
||||
plugin = self.get_plugin_metadata_by_name(plugin_name)
|
||||
if not plugin:
|
||||
raise PluginPageServiceError("Plugin not found", status_code=404)
|
||||
if not plugin.activated:
|
||||
raise PluginPageServiceError("Plugin is disabled", status_code=403)
|
||||
|
||||
try:
|
||||
page = await self.get_plugin_page(plugin, page_name)
|
||||
file_path = await self.resolve_plugin_page_file(
|
||||
plugin,
|
||||
page.name,
|
||||
asset_path,
|
||||
)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
raise PluginPageServiceError(
|
||||
"Plugin Page asset not found",
|
||||
status_code=404,
|
||||
) from exc
|
||||
|
||||
extra_query_params = self.prepare_plugin_page_query_params(
|
||||
plugin_name,
|
||||
page.name,
|
||||
asset_token=asset_token,
|
||||
jwt_secret=jwt_secret,
|
||||
username=username,
|
||||
locale=locale,
|
||||
theme=theme,
|
||||
)
|
||||
served_asset_path = asset_path or page.entry_file
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix == ".html":
|
||||
html_text = await self.read_plugin_page_text(file_path)
|
||||
return PluginPageContentPayload(
|
||||
content=self.rewrite_plugin_page_html(
|
||||
html_text,
|
||||
plugin_name,
|
||||
page.name,
|
||||
served_asset_path,
|
||||
theme=theme,
|
||||
extra_query_params=extra_query_params,
|
||||
),
|
||||
content_type="text/html; charset=utf-8",
|
||||
)
|
||||
if suffix == ".css":
|
||||
css_text = await self.read_plugin_page_text(file_path)
|
||||
return PluginPageContentPayload(
|
||||
content=self.rewrite_plugin_page_css(
|
||||
css_text,
|
||||
plugin_name,
|
||||
page.name,
|
||||
served_asset_path,
|
||||
extra_query_params=extra_query_params,
|
||||
),
|
||||
content_type="text/css; charset=utf-8",
|
||||
)
|
||||
if suffix in {".js", ".mjs"}:
|
||||
js_text = await self.read_plugin_page_text(file_path)
|
||||
return PluginPageContentPayload(
|
||||
content=self.rewrite_plugin_page_js(
|
||||
js_text,
|
||||
plugin_name,
|
||||
page.name,
|
||||
served_asset_path,
|
||||
extra_query_params=extra_query_params,
|
||||
),
|
||||
content_type="application/javascript; charset=utf-8",
|
||||
)
|
||||
return PluginPageContentPayload(
|
||||
content=await self.read_plugin_page_binary(file_path),
|
||||
content_type=self.guess_plugin_page_mime_type(file_path),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_security_headers() -> dict[str, str]:
|
||||
headers = {
|
||||
"Cache-Control": "no-store",
|
||||
"Referrer-Policy": "no-referrer",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Cross-Origin-Resource-Policy": "cross-origin",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
}
|
||||
|
||||
csp = "object-src 'none'; base-uri 'self'"
|
||||
if os.environ.get("ASTRBOT_LAUNCHER") not in ("1", "true"):
|
||||
headers["X-Frame-Options"] = "SAMEORIGIN"
|
||||
csp = f"frame-ancestors 'self'; {csp}"
|
||||
headers["Content-Security-Policy"] = csp
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def normalize_plugin_page_path(
|
||||
raw_path: str,
|
||||
*,
|
||||
base_dir: str | None = None,
|
||||
allow_empty: bool = False,
|
||||
) -> str:
|
||||
path = raw_path.replace("\\", "/").strip()
|
||||
if base_dir:
|
||||
path = posixpath.join(base_dir, path)
|
||||
normalized = posixpath.normpath(path)
|
||||
if normalized in {"", "."}:
|
||||
if allow_empty:
|
||||
return ""
|
||||
raise ValueError("Invalid plugin Page asset path")
|
||||
if (
|
||||
normalized.startswith("../")
|
||||
or normalized == ".."
|
||||
or normalized.startswith("/")
|
||||
):
|
||||
raise ValueError("Invalid plugin Page asset path")
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def normalize_plugin_page_name(raw_name: str) -> str:
|
||||
page_name = raw_name.strip()
|
||||
if not page_name:
|
||||
raise ValueError("Invalid plugin Page name")
|
||||
normalized = posixpath.normpath(page_name.replace("\\", "/"))
|
||||
if (
|
||||
normalized != page_name
|
||||
or normalized in {".", ".."}
|
||||
or normalized.startswith(".")
|
||||
or "/" in page_name
|
||||
or "\\" in page_name
|
||||
):
|
||||
raise ValueError("Invalid plugin Page name")
|
||||
return page_name
|
||||
|
||||
def get_plugin_root_dir(self, plugin: StarMetadata) -> Path:
|
||||
if not plugin.root_dir_name:
|
||||
raise FileNotFoundError("Plugin directory metadata is missing")
|
||||
|
||||
base_dir = Path(
|
||||
self.plugin_manager.reserved_plugin_path
|
||||
if plugin.reserved
|
||||
else self.plugin_manager.plugin_store_path
|
||||
).resolve(strict=False)
|
||||
plugin_root = (base_dir / plugin.root_dir_name).resolve(strict=False)
|
||||
plugin_root.relative_to(base_dir)
|
||||
return plugin_root
|
||||
|
||||
async def resolve_plugin_pages_root(self, plugin: StarMetadata) -> Path:
|
||||
plugin_root = self.get_plugin_root_dir(plugin)
|
||||
pages_root = (plugin_root / PLUGIN_PAGE_ROOT_DIR_NAME).resolve(strict=False)
|
||||
pages_root.relative_to(plugin_root)
|
||||
if pages_root == plugin_root:
|
||||
raise FileNotFoundError("Plugin Pages root directory is invalid")
|
||||
if not await aio_ospath.isdir(str(pages_root)):
|
||||
raise FileNotFoundError("Plugin Pages root directory does not exist")
|
||||
return pages_root
|
||||
|
||||
async def discover_plugin_pages(self, plugin: StarMetadata) -> list[PluginPage]:
|
||||
try:
|
||||
pages_root = await self.resolve_plugin_pages_root(plugin)
|
||||
except (FileNotFoundError, ValueError):
|
||||
return []
|
||||
|
||||
pages: list[PluginPage] = []
|
||||
try:
|
||||
page_dirs = sorted(
|
||||
(item for item in pages_root.iterdir() if item.is_dir()),
|
||||
key=lambda item: item.name.lower(),
|
||||
)
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
for page_dir in page_dirs:
|
||||
try:
|
||||
page_name = self.normalize_plugin_page_name(page_dir.name)
|
||||
except ValueError:
|
||||
continue
|
||||
entry_path = page_dir / PLUGIN_PAGE_ENTRY_FILE_NAME
|
||||
if not await aio_ospath.isfile(str(entry_path)):
|
||||
continue
|
||||
pages.append(
|
||||
PluginPage(
|
||||
name=page_name,
|
||||
title=page_name,
|
||||
entry_file=PLUGIN_PAGE_ENTRY_FILE_NAME,
|
||||
)
|
||||
)
|
||||
return pages
|
||||
|
||||
async def get_plugin_page(
|
||||
self,
|
||||
plugin: StarMetadata,
|
||||
page_name: str,
|
||||
) -> PluginPage:
|
||||
normalized_name = self.normalize_plugin_page_name(page_name)
|
||||
for page in await self.discover_plugin_pages(plugin):
|
||||
if page.name == normalized_name:
|
||||
return page
|
||||
raise FileNotFoundError("Plugin Page entry not found")
|
||||
|
||||
async def resolve_plugin_page_root(
|
||||
self,
|
||||
plugin: StarMetadata,
|
||||
page_name: str,
|
||||
) -> Path:
|
||||
normalized_name = self.normalize_plugin_page_name(page_name)
|
||||
pages_root = await self.resolve_plugin_pages_root(plugin)
|
||||
page_root = (pages_root / normalized_name).resolve(strict=False)
|
||||
page_root.relative_to(pages_root)
|
||||
if not await aio_ospath.isdir(str(page_root)):
|
||||
raise FileNotFoundError("Plugin Page root directory does not exist")
|
||||
return page_root
|
||||
|
||||
async def resolve_plugin_page_file(
|
||||
self,
|
||||
plugin: StarMetadata,
|
||||
page_name: str,
|
||||
asset_path: str,
|
||||
) -> Path:
|
||||
page = await self.get_plugin_page(plugin, page_name)
|
||||
page_root = await self.resolve_plugin_page_root(plugin, page.name)
|
||||
target_name = (
|
||||
self.normalize_plugin_page_path(asset_path, allow_empty=True)
|
||||
or page.entry_file
|
||||
)
|
||||
target_path = (page_root / target_name).resolve(strict=False)
|
||||
target_path.relative_to(page_root)
|
||||
if not await aio_ospath.isfile(str(target_path)):
|
||||
raise FileNotFoundError("Plugin Page asset not found")
|
||||
return target_path
|
||||
|
||||
@staticmethod
|
||||
def is_rewritable_asset_url(raw_url: str) -> bool:
|
||||
value = raw_url.strip()
|
||||
lower = value.lower()
|
||||
if not value:
|
||||
return False
|
||||
if value.startswith(("#", "/#")):
|
||||
return False
|
||||
if lower.startswith(
|
||||
(
|
||||
"http://",
|
||||
"https://",
|
||||
"//",
|
||||
"data:",
|
||||
"javascript:",
|
||||
"mailto:",
|
||||
"tel:",
|
||||
"blob:",
|
||||
)
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def resolve_referenced_asset_path(
|
||||
base_asset_path: str,
|
||||
referenced_url: str,
|
||||
) -> str:
|
||||
parts = urlsplit(referenced_url)
|
||||
referenced_path = parts.path.strip()
|
||||
if not referenced_path:
|
||||
raise ValueError("Plugin Page referenced asset path is empty")
|
||||
base_dir = posixpath.dirname(base_asset_path) if base_asset_path else ""
|
||||
normalized = PluginPageService.normalize_plugin_page_path(
|
||||
referenced_path,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
if not normalized:
|
||||
raise ValueError("Plugin Page referenced asset path is invalid")
|
||||
return normalized
|
||||
|
||||
def build_plugin_page_asset_url(
|
||||
self,
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
asset_path: str,
|
||||
original_query: str = "",
|
||||
original_fragment: str = "",
|
||||
extra_query_params: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
path = self.build_plugin_page_content_path(plugin_name, page_name, asset_path)
|
||||
query_dict = dict(parse_qsl(original_query, keep_blank_values=True))
|
||||
if extra_query_params:
|
||||
for key, value in extra_query_params.items():
|
||||
if value:
|
||||
query_dict[key] = value
|
||||
query = urlencode(query_dict)
|
||||
return urlunsplit(("", "", path, query, original_fragment))
|
||||
|
||||
@staticmethod
|
||||
def build_plugin_page_content_path(
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
asset_path: str = "",
|
||||
) -> str:
|
||||
encoded_plugin_name = quote(plugin_name, safe="")
|
||||
encoded_page_name = quote(
|
||||
PluginPageService.normalize_plugin_page_name(page_name),
|
||||
safe="",
|
||||
)
|
||||
if not asset_path:
|
||||
return (
|
||||
f"/api/plugin/page/content/{encoded_plugin_name}/{encoded_page_name}/"
|
||||
)
|
||||
safe_asset_path = PluginPageService.normalize_plugin_page_path(
|
||||
asset_path,
|
||||
allow_empty=True,
|
||||
)
|
||||
encoded_path = "/".join(
|
||||
quote(part, safe="") for part in safe_asset_path.split("/")
|
||||
)
|
||||
return (
|
||||
f"/api/plugin/page/content/{encoded_plugin_name}/"
|
||||
f"{encoded_page_name}/{encoded_path}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_plugin_page_bridge_sdk_url(
|
||||
extra_query_params: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
query = urlencode(extra_query_params or {})
|
||||
return urlunsplit(("", "", "/api/plugin/page/bridge-sdk.js", query, ""))
|
||||
|
||||
@staticmethod
|
||||
def is_js_relative_module_specifier(raw_url: str) -> bool:
|
||||
value = raw_url.strip()
|
||||
return value.startswith(("./", "../", "/"))
|
||||
|
||||
def rewrite_relative_asset_url(
|
||||
self,
|
||||
raw_url: str,
|
||||
base_asset_path: str,
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
extra_query_params: dict[str, str] | None = None,
|
||||
) -> str | None:
|
||||
candidate = raw_url.strip()
|
||||
if not self.is_rewritable_asset_url(candidate):
|
||||
return None
|
||||
parts = urlsplit(candidate)
|
||||
asset_path = self.resolve_referenced_asset_path(base_asset_path, candidate)
|
||||
return self.build_plugin_page_asset_url(
|
||||
plugin_name,
|
||||
page_name,
|
||||
asset_path,
|
||||
original_query=parts.query,
|
||||
original_fragment=parts.fragment,
|
||||
extra_query_params=extra_query_params,
|
||||
)
|
||||
|
||||
def rewrite_plugin_page_html(
|
||||
self,
|
||||
html_text: str,
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
entry_asset_path: str,
|
||||
*,
|
||||
theme: str | None,
|
||||
extra_query_params: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
def replace_attr(match: re.Match[str]) -> str:
|
||||
raw_url = match.group("url")
|
||||
attr = match.group("attr")
|
||||
quote_char = match.group("quote")
|
||||
|
||||
if raw_url.strip() == "/api/plugin/page/bridge-sdk.js":
|
||||
url = self.get_plugin_page_bridge_sdk_url(extra_query_params)
|
||||
return f"{attr}={quote_char}{url}{quote_char}"
|
||||
|
||||
if not self.is_rewritable_asset_url(raw_url):
|
||||
return match.group(0)
|
||||
|
||||
try:
|
||||
rewritten_url = self.rewrite_relative_asset_url(
|
||||
raw_url,
|
||||
entry_asset_path,
|
||||
plugin_name,
|
||||
page_name,
|
||||
extra_query_params=extra_query_params,
|
||||
)
|
||||
if not rewritten_url:
|
||||
return match.group(0)
|
||||
return f"{attr}={quote_char}{rewritten_url}{quote_char}"
|
||||
except ValueError:
|
||||
return match.group(0)
|
||||
|
||||
rewritten_html = _HTML_ASSET_ATTR_RE.sub(replace_attr, html_text)
|
||||
if theme:
|
||||
rewritten_html = self.apply_theme_to_html(rewritten_html, theme)
|
||||
if "/api/plugin/page/bridge-sdk.js" not in rewritten_html:
|
||||
bridge_tag = f'<script src="{self.get_plugin_page_bridge_sdk_url(extra_query_params)}"></script>'
|
||||
if "</body>" in rewritten_html:
|
||||
rewritten_html = rewritten_html.replace(
|
||||
"</body>", f"{bridge_tag}</body>", 1
|
||||
)
|
||||
else:
|
||||
rewritten_html += bridge_tag
|
||||
return rewritten_html
|
||||
|
||||
def rewrite_plugin_page_css(
|
||||
self,
|
||||
css_text: str,
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
css_asset_path: str,
|
||||
extra_query_params: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
def replace_url(match: re.Match[str]) -> str:
|
||||
raw_url = match.group("url").strip()
|
||||
quote_char = match.group("quote") or ""
|
||||
try:
|
||||
rewritten_url = self.rewrite_relative_asset_url(
|
||||
raw_url,
|
||||
css_asset_path,
|
||||
plugin_name,
|
||||
page_name,
|
||||
extra_query_params=extra_query_params,
|
||||
)
|
||||
if not rewritten_url:
|
||||
return match.group(0)
|
||||
return f"url({quote_char}{rewritten_url}{quote_char})"
|
||||
except ValueError:
|
||||
return match.group(0)
|
||||
|
||||
return _CSS_URL_RE.sub(replace_url, css_text)
|
||||
|
||||
def rewrite_plugin_page_js(
|
||||
self,
|
||||
js_text: str,
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
js_asset_path: str,
|
||||
extra_query_params: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
def rewrite_specifier(raw_url: str) -> str:
|
||||
if not self.is_js_relative_module_specifier(raw_url):
|
||||
return raw_url
|
||||
if not self.is_rewritable_asset_url(raw_url):
|
||||
return raw_url
|
||||
rewritten = self.rewrite_relative_asset_url(
|
||||
raw_url,
|
||||
js_asset_path,
|
||||
plugin_name,
|
||||
page_name,
|
||||
extra_query_params=extra_query_params,
|
||||
)
|
||||
return rewritten or raw_url
|
||||
|
||||
def replace_dynamic(match: re.Match[str]) -> str:
|
||||
raw_url = match.group("url")
|
||||
try:
|
||||
rewritten = rewrite_specifier(raw_url)
|
||||
except ValueError:
|
||||
return match.group(0)
|
||||
return (
|
||||
f"{match.group('prefix')}{match.group('quote')}{rewritten}"
|
||||
f"{match.group('quote')}{match.group('suffix')}"
|
||||
)
|
||||
|
||||
def replace_from(match: re.Match[str]) -> str:
|
||||
raw_url = match.group("url")
|
||||
try:
|
||||
rewritten = rewrite_specifier(raw_url)
|
||||
except ValueError:
|
||||
return match.group(0)
|
||||
return (
|
||||
f"{match.group('prefix')}{match.group('quote')}"
|
||||
f"{rewritten}{match.group('quote')}"
|
||||
)
|
||||
|
||||
rewritten_js = _JS_DYNAMIC_IMPORT_RE.sub(replace_dynamic, js_text)
|
||||
rewritten_js = _JS_MODULE_FROM_RE.sub(replace_from, rewritten_js)
|
||||
|
||||
def replace_side_effect(match: re.Match[str]) -> str:
|
||||
raw_url = match.group("url")
|
||||
if raw_url.startswith(("{", "*")):
|
||||
return match.group(0)
|
||||
try:
|
||||
rewritten = rewrite_specifier(raw_url)
|
||||
except ValueError:
|
||||
return match.group(0)
|
||||
return (
|
||||
f"{match.group('prefix')}{match.group('quote')}"
|
||||
f"{rewritten}{match.group('quote')}"
|
||||
)
|
||||
|
||||
return _JS_SIDE_EFFECT_IMPORT_RE.sub(replace_side_effect, rewritten_js)
|
||||
|
||||
@staticmethod
|
||||
async def read_plugin_page_text(file_path: Path) -> str:
|
||||
async with aiofiles.open(file_path, encoding="utf-8") as file:
|
||||
return await file.read()
|
||||
|
||||
@staticmethod
|
||||
async def read_plugin_page_binary(file_path: Path) -> bytes:
|
||||
async with aiofiles.open(file_path, mode="rb") as file:
|
||||
return await file.read()
|
||||
|
||||
@staticmethod
|
||||
def guess_plugin_page_mime_type(file_path: Path) -> str:
|
||||
return mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
|
||||
|
||||
async def serialize_plugin_page(
|
||||
self,
|
||||
plugin: StarMetadata,
|
||||
page_name: str,
|
||||
*,
|
||||
include_content_path: bool = False,
|
||||
asset_token: str = "",
|
||||
) -> dict | None:
|
||||
plugin_name = plugin.name.strip() if isinstance(plugin.name, str) else ""
|
||||
if not plugin_name:
|
||||
return None
|
||||
try:
|
||||
page = await self.get_plugin_page(plugin, page_name)
|
||||
await self.resolve_plugin_page_file(plugin, page.name, "")
|
||||
except (FileNotFoundError, ValueError):
|
||||
return None
|
||||
|
||||
page_data = {
|
||||
"name": page.name,
|
||||
"title": page.title,
|
||||
"i18n_key": f"pages.{page.name}",
|
||||
}
|
||||
if include_content_path:
|
||||
extra_query_params = {"asset_token": asset_token} if asset_token else None
|
||||
page_data["content_path"] = self.build_plugin_page_asset_url(
|
||||
plugin_name,
|
||||
page.name,
|
||||
"",
|
||||
extra_query_params=extra_query_params,
|
||||
)
|
||||
return page_data
|
||||
|
||||
async def serialize_plugin_pages(self, plugin: StarMetadata) -> list[dict]:
|
||||
pages = []
|
||||
for page in await self.discover_plugin_pages(plugin):
|
||||
page_data = await self.serialize_plugin_page(plugin, page.name)
|
||||
if page_data:
|
||||
pages.append(page_data)
|
||||
return pages
|
||||
|
||||
def issue_plugin_page_asset_token(
|
||||
self,
|
||||
*,
|
||||
plugin_name: str,
|
||||
page_name: str,
|
||||
jwt_secret: str | None = None,
|
||||
username: str | None,
|
||||
locale: str,
|
||||
) -> str | None:
|
||||
jwt_secret = jwt_secret or self._jwt_secret()
|
||||
if not isinstance(jwt_secret, str) or not jwt_secret.strip():
|
||||
return None
|
||||
if not isinstance(username, str) or not username.strip():
|
||||
return None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"username": username,
|
||||
"token_type": PLUGIN_PAGE_ASSET_TOKEN_TYPE,
|
||||
"plugin_name": plugin_name,
|
||||
"page_name": page_name,
|
||||
"locale": locale,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(seconds=PLUGIN_PAGE_ASSET_TOKEN_TTL_SECONDS),
|
||||
}
|
||||
return cast(str, jwt.encode(payload, jwt_secret, algorithm="HS256"))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PLUGIN_PAGE_ASSET_TOKEN_TYPE",
|
||||
"PLUGIN_PAGE_BRIDGE_FILE",
|
||||
"PLUGIN_PAGE_ENTRY_FILE_NAME",
|
||||
"PLUGIN_PAGE_ROOT_DIR_NAME",
|
||||
"PluginPage",
|
||||
"PluginPageContentPayload",
|
||||
"PluginPageService",
|
||||
"PluginPageServiceError",
|
||||
]
|
||||
1165
astrbot/dashboard/services/plugin_service.py
Normal file
1165
astrbot/dashboard/services/plugin_service.py
Normal file
File diff suppressed because it is too large
Load Diff
119
astrbot/dashboard/services/route_bridge_service.py
Normal file
119
astrbot/dashboard/services/route_bridge_service.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from fastapi import Request, Response
|
||||
|
||||
from astrbot.dashboard.v1.auth import AuthContext
|
||||
|
||||
_BODY_NOT_SET = object()
|
||||
_SKIP_RESPONSE_HEADERS = {
|
||||
"content-length",
|
||||
"transfer-encoding",
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"server",
|
||||
}
|
||||
|
||||
|
||||
class DashboardRouteBridgeService:
|
||||
"""Forward a v1 request to a dashboard compatibility route after v1 auth."""
|
||||
|
||||
def __init__(self, route_app, jwt_secret: str) -> None:
|
||||
self.route_app = route_app
|
||||
self.jwt_secret = jwt_secret
|
||||
|
||||
def _build_headers(
|
||||
self, request: Request, auth: AuthContext | None
|
||||
) -> dict[str, str]:
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in request.headers.items()
|
||||
if key.lower() not in {"authorization", "host"}
|
||||
}
|
||||
if auth is None:
|
||||
return headers
|
||||
token = jwt.encode(
|
||||
{"username": auth.username},
|
||||
self.jwt_secret,
|
||||
algorithm="HS256",
|
||||
)
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _merge_query(
|
||||
request: Request,
|
||||
query: Mapping[str, Any] | None,
|
||||
*,
|
||||
drop: set[str] | None = None,
|
||||
) -> list[tuple[str, str]]:
|
||||
drop = drop or set()
|
||||
pairs = [
|
||||
(key, value)
|
||||
for key, value in request.query_params.multi_items()
|
||||
if key not in drop
|
||||
]
|
||||
if query:
|
||||
for key, value in query.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, list | tuple):
|
||||
pairs.extend((key, str(item)) for item in value)
|
||||
else:
|
||||
pairs.append((key, str(value)))
|
||||
return pairs
|
||||
|
||||
@staticmethod
|
||||
def _as_response(response: httpx.Response) -> Response:
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in response.headers.items()
|
||||
if key.lower() not in _SKIP_RESPONSE_HEADERS
|
||||
}
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type=headers.get("content-type"),
|
||||
)
|
||||
|
||||
async def forward(
|
||||
self,
|
||||
request: Request,
|
||||
auth: AuthContext | None,
|
||||
*,
|
||||
method: str,
|
||||
target_path: str,
|
||||
query: Mapping[str, Any] | None = None,
|
||||
drop_query: set[str] | None = None,
|
||||
json_body: Any = _BODY_NOT_SET,
|
||||
) -> Response:
|
||||
headers = self._build_headers(request, auth)
|
||||
params = self._merge_query(request, query, drop=drop_query)
|
||||
transport = httpx.ASGITransport(app=self.route_app)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport,
|
||||
base_url="http://dashboard-routes",
|
||||
) as client:
|
||||
if json_body is _BODY_NOT_SET:
|
||||
response = await client.request(
|
||||
method,
|
||||
target_path,
|
||||
params=params,
|
||||
content=await request.body(),
|
||||
headers=headers,
|
||||
)
|
||||
else:
|
||||
response = await client.request(
|
||||
method,
|
||||
target_path,
|
||||
params=params,
|
||||
json=json_body,
|
||||
headers=headers,
|
||||
)
|
||||
return self._as_response(response)
|
||||
715
astrbot/dashboard/services/session_management_service.py
Normal file
715
astrbot/dashboard/services/session_management_service.py
Normal file
@@ -0,0 +1,715 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, select
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import ConversationV2, Preference
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.umo_alias import build_umo_alias_map, parse_umo, serialize_umo_alias
|
||||
|
||||
AVAILABLE_SESSION_RULE_KEYS = [
|
||||
"session_service_config",
|
||||
"session_plugin_config",
|
||||
"kb_config",
|
||||
f"provider_perf_{ProviderType.CHAT_COMPLETION.value}",
|
||||
f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}",
|
||||
f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}",
|
||||
]
|
||||
|
||||
|
||||
class SessionManagementServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SessionManagementService:
|
||||
def __init__(
|
||||
self,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
db_helper: BaseDatabase,
|
||||
) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.db_helper = db_helper
|
||||
|
||||
@staticmethod
|
||||
def _payload(data: object) -> dict[str, Any]:
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _is_group_umo(umo: str) -> bool:
|
||||
umo_lower = umo.lower()
|
||||
return ":group:" in umo_lower or ":groupmessage:" in umo_lower
|
||||
|
||||
@staticmethod
|
||||
def _is_private_umo(umo: str) -> bool:
|
||||
umo_lower = umo.lower()
|
||||
return (
|
||||
":private:" in umo_lower
|
||||
or ":friend:" in umo_lower
|
||||
or ":friendmessage:" in umo_lower
|
||||
)
|
||||
|
||||
async def list_known_umos(self) -> list[str]:
|
||||
async with self.db_helper.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(select(ConversationV2.user_id).distinct())
|
||||
umos = {str(row[0]) for row in result.fetchall() if row[0]}
|
||||
|
||||
aliases = await self.db_helper.get_umo_aliases()
|
||||
umos.update(str(alias.umo) for alias in aliases if alias.umo)
|
||||
return sorted(umos)
|
||||
|
||||
async def get_umo_alias_map(self, umos: list[str]) -> dict:
|
||||
return build_umo_alias_map(await self.db_helper.get_umo_aliases(umos))
|
||||
|
||||
def build_umo_info(self, umo: str | None, alias_map: dict) -> dict:
|
||||
umo_str = umo or ""
|
||||
return {
|
||||
"umo": umo_str,
|
||||
**parse_umo(umo_str),
|
||||
**serialize_umo_alias(alias_map.get(umo_str), umo_str),
|
||||
}
|
||||
|
||||
async def list_active_umos(self) -> dict:
|
||||
umos = await self.list_known_umos()
|
||||
alias_map = await self.get_umo_alias_map(umos)
|
||||
return {
|
||||
"umos": umos,
|
||||
"umo_infos": [self.build_umo_info(umo, alias_map) for umo in umos],
|
||||
}
|
||||
|
||||
async def get_umos_by_scope(
|
||||
self,
|
||||
scope: str,
|
||||
group_id: str = "",
|
||||
) -> list[str]:
|
||||
if scope == "custom_group":
|
||||
if not group_id:
|
||||
raise SessionManagementServiceError("请指定分组 ID")
|
||||
groups = self.get_groups()
|
||||
if group_id not in groups:
|
||||
raise SessionManagementServiceError(f"分组 '{group_id}' 不存在")
|
||||
return groups[group_id].get("umos", [])
|
||||
|
||||
all_umos = await self.list_known_umos()
|
||||
if scope == "group":
|
||||
return [umo for umo in all_umos if self._is_group_umo(umo)]
|
||||
if scope == "private":
|
||||
return [umo for umo in all_umos if self._is_private_umo(umo)]
|
||||
if scope == "all":
|
||||
return all_umos
|
||||
return []
|
||||
|
||||
async def get_umo_rules(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
search: str = "",
|
||||
) -> tuple[dict, int]:
|
||||
umo_rules = {}
|
||||
async with self.db_helper.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(
|
||||
select(Preference).where(
|
||||
col(Preference.scope) == "umo",
|
||||
col(Preference.key).in_(AVAILABLE_SESSION_RULE_KEYS),
|
||||
)
|
||||
)
|
||||
prefs = result.scalars().all()
|
||||
for pref in prefs:
|
||||
umo_id = pref.scope_id
|
||||
if umo_id not in umo_rules:
|
||||
umo_rules[umo_id] = {}
|
||||
if pref.key == "session_plugin_config" and umo_id in pref.value["val"]:
|
||||
umo_rules[umo_id][pref.key] = pref.value["val"][umo_id]
|
||||
else:
|
||||
umo_rules[umo_id][pref.key] = pref.value["val"]
|
||||
|
||||
alias_map = await self.get_umo_alias_map(list(umo_rules.keys()))
|
||||
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_rules = {}
|
||||
for umo_id, rules in umo_rules.items():
|
||||
if search_lower in umo_id.lower():
|
||||
filtered_rules[umo_id] = rules
|
||||
continue
|
||||
|
||||
svc_config = rules.get("session_service_config", {})
|
||||
custom_name = svc_config.get("custom_name", "") if svc_config else ""
|
||||
if custom_name and search_lower in custom_name.lower():
|
||||
filtered_rules[umo_id] = rules
|
||||
continue
|
||||
|
||||
alias_info = serialize_umo_alias(alias_map.get(umo_id), umo_id)
|
||||
if any(
|
||||
search_lower in alias_info[key].lower()
|
||||
for key in ("auto_name", "user_alias", "display_name")
|
||||
if alias_info.get(key)
|
||||
):
|
||||
filtered_rules[umo_id] = rules
|
||||
umo_rules = filtered_rules
|
||||
|
||||
total = len(umo_rules)
|
||||
all_umo_ids = list(umo_rules.keys())
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
paginated_umo_ids = all_umo_ids[start_idx:end_idx]
|
||||
|
||||
return {umo_id: umo_rules[umo_id] for umo_id in paginated_umo_ids}, total
|
||||
|
||||
async def list_session_rules(
|
||||
self,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
search: str,
|
||||
) -> dict:
|
||||
page, page_size = self._normalize_page(page, page_size, default_page_size=10)
|
||||
umo_rules, total = await self.get_umo_rules(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search,
|
||||
)
|
||||
|
||||
alias_map = await self.get_umo_alias_map(list(umo_rules.keys()))
|
||||
rules_list = [
|
||||
{
|
||||
"rules": rules,
|
||||
**self.build_umo_info(umo, alias_map),
|
||||
}
|
||||
for umo, rules in umo_rules.items()
|
||||
]
|
||||
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
persona_mgr = getattr(self.core_lifecycle, "persona_mgr", None)
|
||||
plugin_manager = getattr(self.core_lifecycle, "plugin_manager", None)
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
available_personas = [
|
||||
{"name": p["name"], "prompt": p.get("prompt", "")}
|
||||
for p in getattr(persona_mgr, "personas_v3", [])
|
||||
]
|
||||
available_plugins = []
|
||||
if plugin_manager and getattr(plugin_manager, "context", None):
|
||||
available_plugins = [
|
||||
{
|
||||
"name": p.name,
|
||||
"display_name": p.display_name or p.name,
|
||||
"desc": p.desc,
|
||||
}
|
||||
for p in plugin_manager.context.get_all_stars()
|
||||
if not p.reserved and p.name
|
||||
]
|
||||
|
||||
available_kbs = []
|
||||
if kb_manager:
|
||||
try:
|
||||
kbs = await kb_manager.list_kbs()
|
||||
available_kbs = [
|
||||
{
|
||||
"kb_id": kb.kb_id,
|
||||
"kb_name": kb.kb_name,
|
||||
"emoji": kb.emoji,
|
||||
}
|
||||
for kb in kbs
|
||||
]
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取知识库列表失败: {exc!s}")
|
||||
|
||||
return {
|
||||
"rules": rules_list,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"available_personas": available_personas,
|
||||
"available_chat_providers": self._serialize_provider_insts(
|
||||
getattr(provider_manager, "provider_insts", [])
|
||||
),
|
||||
"available_stt_providers": self._serialize_provider_insts(
|
||||
getattr(provider_manager, "stt_provider_insts", [])
|
||||
),
|
||||
"available_tts_providers": self._serialize_provider_insts(
|
||||
getattr(provider_manager, "tts_provider_insts", [])
|
||||
),
|
||||
"available_plugins": available_plugins,
|
||||
"available_kbs": available_kbs,
|
||||
"available_rule_keys": AVAILABLE_SESSION_RULE_KEYS,
|
||||
}
|
||||
|
||||
async def list_session_rules_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
page,
|
||||
page_size,
|
||||
search,
|
||||
) -> dict:
|
||||
return await self.list_session_rules(
|
||||
page=self._to_int(page, 1),
|
||||
page_size=self._to_int(page_size, 10),
|
||||
search=str(search or "").strip(),
|
||||
)
|
||||
|
||||
async def update_session_rule(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
umo = payload.get("umo")
|
||||
rule_key = payload.get("rule_key")
|
||||
rule_value = payload.get("rule_value")
|
||||
|
||||
if not umo:
|
||||
raise SessionManagementServiceError("缺少必要参数: umo")
|
||||
if not rule_key:
|
||||
raise SessionManagementServiceError("缺少必要参数: rule_key")
|
||||
if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
|
||||
raise SessionManagementServiceError(f"不支持的规则键: {rule_key}")
|
||||
|
||||
if rule_key == "session_plugin_config":
|
||||
rule_value = {umo: rule_value}
|
||||
|
||||
await sp.session_put(umo, rule_key, rule_value)
|
||||
return {"message": f"规则 {rule_key} 已更新", "umo": umo}
|
||||
|
||||
async def delete_session_rule(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
umo = payload.get("umo")
|
||||
rule_key = payload.get("rule_key")
|
||||
|
||||
if not umo:
|
||||
raise SessionManagementServiceError("缺少必要参数: umo")
|
||||
|
||||
if rule_key:
|
||||
if rule_key not in AVAILABLE_SESSION_RULE_KEYS:
|
||||
raise SessionManagementServiceError(f"不支持的规则键: {rule_key}")
|
||||
await sp.session_remove(umo, rule_key)
|
||||
return {"message": f"规则 {rule_key} 已删除", "umo": umo}
|
||||
|
||||
await sp.clear_async("umo", umo)
|
||||
return {"message": "所有规则已删除", "umo": umo}
|
||||
|
||||
async def delete_session_rules(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
if payload.get("umo") and not payload.get("umos") and not payload.get("scope"):
|
||||
return await self.delete_session_rule(payload)
|
||||
return await self.batch_delete_session_rule(payload)
|
||||
|
||||
async def batch_delete_session_rule(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
umos = payload.get("umos", [])
|
||||
scope = payload.get("scope", "")
|
||||
group_id = payload.get("group_id", "")
|
||||
rule_key = payload.get("rule_key")
|
||||
|
||||
if scope and not umos:
|
||||
umos = await self.get_umos_by_scope(scope, group_id)
|
||||
|
||||
if not umos:
|
||||
raise SessionManagementServiceError("缺少必要参数: umos 或有效的 scope")
|
||||
if not isinstance(umos, list):
|
||||
raise SessionManagementServiceError("参数 umos 必须是数组")
|
||||
if rule_key and rule_key not in AVAILABLE_SESSION_RULE_KEYS:
|
||||
raise SessionManagementServiceError(f"不支持的规则键: {rule_key}")
|
||||
|
||||
success_count = 0
|
||||
failed_umos = []
|
||||
for umo in umos:
|
||||
try:
|
||||
if rule_key:
|
||||
await sp.session_remove(umo, rule_key)
|
||||
else:
|
||||
await sp.clear_async("umo", umo)
|
||||
success_count += 1
|
||||
except Exception as exc:
|
||||
logger.error(f"删除 umo {umo} 的规则失败: {exc!s}")
|
||||
failed_umos.append(umo)
|
||||
|
||||
message = f"已删除 {success_count} 条规则"
|
||||
if rule_key:
|
||||
message = f"已删除 {success_count} 条 {rule_key} 规则"
|
||||
|
||||
result = {
|
||||
"message": message,
|
||||
"success_count": success_count,
|
||||
}
|
||||
if failed_umos:
|
||||
result.update(
|
||||
{
|
||||
"message": f"{message},{len(failed_umos)} 条删除失败",
|
||||
"failed_umos": failed_umos,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
async def list_all_umos_with_status(
|
||||
self,
|
||||
*,
|
||||
page: int,
|
||||
page_size: int,
|
||||
search: str,
|
||||
message_type: str,
|
||||
platform: str,
|
||||
) -> dict:
|
||||
page, page_size = self._normalize_page(page, page_size, default_page_size=20)
|
||||
all_umos = await self.list_known_umos()
|
||||
alias_map = await self.get_umo_alias_map(all_umos)
|
||||
umo_rules, _ = await self.get_umo_rules(page=1, page_size=99999, search="")
|
||||
|
||||
umos_with_status = []
|
||||
for umo in all_umos:
|
||||
umo_info = self.build_umo_info(umo, alias_map)
|
||||
umo_platform = umo_info["platform"]
|
||||
umo_message_type = umo_info["message_type"]
|
||||
|
||||
if message_type != "all":
|
||||
if message_type == "group" and umo_message_type not in [
|
||||
"group",
|
||||
"GroupMessage",
|
||||
]:
|
||||
continue
|
||||
if message_type == "private" and umo_message_type not in [
|
||||
"private",
|
||||
"FriendMessage",
|
||||
"friend",
|
||||
]:
|
||||
continue
|
||||
|
||||
if platform and umo_platform != platform:
|
||||
continue
|
||||
|
||||
rules = umo_rules.get(umo, {})
|
||||
svc_config = rules.get("session_service_config", {})
|
||||
|
||||
custom_name = svc_config.get("custom_name", "") if svc_config else ""
|
||||
session_enabled = (
|
||||
svc_config.get("session_enabled", True) if svc_config else True
|
||||
)
|
||||
llm_enabled = svc_config.get("llm_enabled", True) if svc_config else True
|
||||
tts_enabled = svc_config.get("tts_enabled", True) if svc_config else True
|
||||
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
search_targets = [
|
||||
umo,
|
||||
custom_name,
|
||||
umo_info["auto_name"],
|
||||
umo_info["user_alias"],
|
||||
umo_info["display_name"],
|
||||
]
|
||||
if not any(
|
||||
search_lower in target.lower()
|
||||
for target in search_targets
|
||||
if target
|
||||
):
|
||||
continue
|
||||
|
||||
chat_provider_key = f"provider_perf_{ProviderType.CHAT_COMPLETION.value}"
|
||||
tts_provider_key = f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}"
|
||||
stt_provider_key = f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}"
|
||||
|
||||
umos_with_status.append(
|
||||
{
|
||||
**umo_info,
|
||||
"custom_name": custom_name,
|
||||
"session_enabled": session_enabled,
|
||||
"llm_enabled": llm_enabled,
|
||||
"tts_enabled": tts_enabled,
|
||||
"has_rules": umo in umo_rules,
|
||||
"chat_provider": rules.get(chat_provider_key),
|
||||
"tts_provider": rules.get(tts_provider_key),
|
||||
"stt_provider": rules.get(stt_provider_key),
|
||||
}
|
||||
)
|
||||
|
||||
total = len(umos_with_status)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
paginated = umos_with_status[start_idx:end_idx]
|
||||
platforms = list({u["platform"] for u in umos_with_status})
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
|
||||
return {
|
||||
"sessions": paginated,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"platforms": platforms,
|
||||
"available_chat_providers": self._serialize_provider_insts(
|
||||
getattr(provider_manager, "provider_insts", [])
|
||||
),
|
||||
"available_tts_providers": self._serialize_provider_insts(
|
||||
getattr(provider_manager, "tts_provider_insts", [])
|
||||
),
|
||||
"available_stt_providers": self._serialize_provider_insts(
|
||||
getattr(provider_manager, "stt_provider_insts", [])
|
||||
),
|
||||
}
|
||||
|
||||
async def list_all_umos_with_status_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
page,
|
||||
page_size,
|
||||
search,
|
||||
message_type,
|
||||
platform,
|
||||
) -> dict:
|
||||
return await self.list_all_umos_with_status(
|
||||
page=self._to_int(page, 1),
|
||||
page_size=self._to_int(page_size, 20),
|
||||
search=str(search or "").strip(),
|
||||
message_type=str(message_type or "all"),
|
||||
platform=str(platform or ""),
|
||||
)
|
||||
|
||||
async def batch_update_service(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
umos = payload.get("umos", [])
|
||||
scope = payload.get("scope", "")
|
||||
group_id = payload.get("group_id", "")
|
||||
llm_enabled = payload.get("llm_enabled")
|
||||
tts_enabled = payload.get("tts_enabled")
|
||||
session_enabled = payload.get("session_enabled")
|
||||
|
||||
if llm_enabled is None and tts_enabled is None and session_enabled is None:
|
||||
raise SessionManagementServiceError("至少需要指定一个要修改的状态")
|
||||
|
||||
if scope and not umos:
|
||||
umos = await self.get_umos_by_scope(scope, group_id)
|
||||
|
||||
if not umos:
|
||||
raise SessionManagementServiceError("没有找到符合条件的会话")
|
||||
|
||||
success_count = 0
|
||||
failed_umos = []
|
||||
|
||||
for umo in umos:
|
||||
try:
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=umo)
|
||||
or {}
|
||||
)
|
||||
|
||||
if llm_enabled is not None:
|
||||
session_config["llm_enabled"] = llm_enabled
|
||||
if tts_enabled is not None:
|
||||
session_config["tts_enabled"] = tts_enabled
|
||||
if session_enabled is not None:
|
||||
session_config["session_enabled"] = session_enabled
|
||||
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
scope="umo",
|
||||
scope_id=umo,
|
||||
)
|
||||
success_count += 1
|
||||
except Exception as exc:
|
||||
logger.error(f"更新 {umo} 服务状态失败: {exc!s}")
|
||||
failed_umos.append(umo)
|
||||
|
||||
status_changes = []
|
||||
if llm_enabled is not None:
|
||||
status_changes.append(f"LLM={'启用' if llm_enabled else '禁用'}")
|
||||
if tts_enabled is not None:
|
||||
status_changes.append(f"TTS={'启用' if tts_enabled else '禁用'}")
|
||||
if session_enabled is not None:
|
||||
status_changes.append(f"会话={'启用' if session_enabled else '禁用'}")
|
||||
|
||||
return {
|
||||
"message": f"已更新 {success_count} 个会话 ({', '.join(status_changes)})",
|
||||
"success_count": success_count,
|
||||
"failed_count": len(failed_umos),
|
||||
"failed_umos": failed_umos,
|
||||
}
|
||||
|
||||
async def batch_update_provider(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
umos = payload.get("umos", [])
|
||||
scope = payload.get("scope", "")
|
||||
provider_type = payload.get("provider_type")
|
||||
provider_id = payload.get("provider_id")
|
||||
|
||||
if not provider_type or not provider_id:
|
||||
raise SessionManagementServiceError(
|
||||
"缺少必要参数: provider_type, provider_id"
|
||||
)
|
||||
|
||||
provider_type_map = {
|
||||
"chat_completion": ProviderType.CHAT_COMPLETION,
|
||||
"text_to_speech": ProviderType.TEXT_TO_SPEECH,
|
||||
"speech_to_text": ProviderType.SPEECH_TO_TEXT,
|
||||
}
|
||||
if provider_type not in provider_type_map:
|
||||
raise SessionManagementServiceError(
|
||||
f"不支持的 provider_type: {provider_type}"
|
||||
)
|
||||
|
||||
group_id = payload.get("group_id", "")
|
||||
if scope and not umos:
|
||||
umos = await self.get_umos_by_scope(scope, group_id)
|
||||
|
||||
if not umos:
|
||||
raise SessionManagementServiceError("没有找到符合条件的会话")
|
||||
|
||||
success_count = 0
|
||||
failed_umos = []
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
|
||||
for umo in umos:
|
||||
try:
|
||||
await provider_manager.set_provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type_map[provider_type],
|
||||
umo=umo,
|
||||
)
|
||||
success_count += 1
|
||||
except Exception as exc:
|
||||
logger.error(f"更新 {umo} Provider 失败: {exc!s}")
|
||||
failed_umos.append(umo)
|
||||
|
||||
return {
|
||||
"message": f"已更新 {success_count} 个会话的 {provider_type} 为 {provider_id}",
|
||||
"success_count": success_count,
|
||||
"failed_count": len(failed_umos),
|
||||
"failed_umos": failed_umos,
|
||||
}
|
||||
|
||||
def get_groups(self) -> dict:
|
||||
return sp.get("session_groups", {})
|
||||
|
||||
def save_groups(self, groups: dict) -> None:
|
||||
sp.put("session_groups", groups)
|
||||
|
||||
def list_groups(self) -> dict:
|
||||
groups = self.get_groups()
|
||||
return {
|
||||
"groups": [
|
||||
{
|
||||
"id": group_id,
|
||||
"name": group_data.get("name", ""),
|
||||
"umos": group_data.get("umos", []),
|
||||
"umo_count": len(group_data.get("umos", [])),
|
||||
}
|
||||
for group_id, group_data in groups.items()
|
||||
]
|
||||
}
|
||||
|
||||
def create_group(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
name = str(payload.get("name", "")).strip()
|
||||
umos = payload.get("umos", [])
|
||||
|
||||
if not name:
|
||||
raise SessionManagementServiceError("分组名称不能为空")
|
||||
|
||||
groups = self.get_groups()
|
||||
group_id = str(uuid.uuid4())[:8]
|
||||
groups[group_id] = {
|
||||
"name": name,
|
||||
"umos": umos,
|
||||
}
|
||||
self.save_groups(groups)
|
||||
|
||||
return {
|
||||
"message": f"分组 '{name}' 创建成功",
|
||||
"group": {
|
||||
"id": group_id,
|
||||
"name": name,
|
||||
"umos": umos,
|
||||
"umo_count": len(umos),
|
||||
},
|
||||
}
|
||||
|
||||
def update_group(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
group_id = payload.get("id") or payload.get("group_id")
|
||||
name = payload.get("name")
|
||||
umos = payload.get("umos")
|
||||
add_umos = payload.get("add_umos", [])
|
||||
remove_umos = payload.get("remove_umos", [])
|
||||
|
||||
if not group_id:
|
||||
raise SessionManagementServiceError("分组 ID 不能为空")
|
||||
|
||||
groups = self.get_groups()
|
||||
if group_id not in groups:
|
||||
raise SessionManagementServiceError(f"分组 '{group_id}' 不存在")
|
||||
|
||||
group = groups[group_id]
|
||||
if name is not None:
|
||||
group["name"] = name.strip()
|
||||
|
||||
if umos is not None:
|
||||
group["umos"] = umos
|
||||
else:
|
||||
current_umos = set(group.get("umos", []))
|
||||
if add_umos:
|
||||
current_umos.update(add_umos)
|
||||
if remove_umos:
|
||||
current_umos.difference_update(remove_umos)
|
||||
group["umos"] = list(current_umos)
|
||||
|
||||
self.save_groups(groups)
|
||||
|
||||
return {
|
||||
"message": f"分组 '{group['name']}' 更新成功",
|
||||
"group": {
|
||||
"id": group_id,
|
||||
"name": group["name"],
|
||||
"umos": group["umos"],
|
||||
"umo_count": len(group["umos"]),
|
||||
},
|
||||
}
|
||||
|
||||
def delete_group(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
group_id = payload.get("id") or payload.get("group_id")
|
||||
|
||||
if not group_id:
|
||||
raise SessionManagementServiceError("分组 ID 不能为空")
|
||||
|
||||
groups = self.get_groups()
|
||||
if group_id not in groups:
|
||||
raise SessionManagementServiceError(f"分组 '{group_id}' 不存在")
|
||||
|
||||
group_name = groups[group_id].get("name", group_id)
|
||||
del groups[group_id]
|
||||
self.save_groups(groups)
|
||||
return {"message": f"分组 '{group_name}' 已删除"}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_page(
|
||||
page: int,
|
||||
page_size: int,
|
||||
*,
|
||||
default_page_size: int,
|
||||
) -> tuple[int, int]:
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = default_page_size
|
||||
if page_size > 100:
|
||||
page_size = 100
|
||||
return page, page_size
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _serialize_provider_insts(provider_insts: list) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"id": provider.meta().id,
|
||||
"name": provider.meta().id,
|
||||
"model": provider.meta().model,
|
||||
}
|
||||
for provider in provider_insts
|
||||
]
|
||||
877
astrbot/dashboard/services/skills_service.py
Normal file
877
astrbot/dashboard/services/skills_service.py
Normal file
@@ -0,0 +1,877 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import traceback
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core import DEMO_MODE, logger
|
||||
from astrbot.core.computer.computer_client import (
|
||||
_discover_bay_credentials,
|
||||
sync_skills_to_active_sandboxes,
|
||||
)
|
||||
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
|
||||
from astrbot.core.skills.skill_manager import SkillManager
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
||||
_SKILL_FILE_MAX_BYTES = 512 * 1024
|
||||
_EDITABLE_SKILL_FILE_SUFFIXES = {
|
||||
".css",
|
||||
".html",
|
||||
".ini",
|
||||
".js",
|
||||
".json",
|
||||
".md",
|
||||
".py",
|
||||
".sh",
|
||||
".toml",
|
||||
".ts",
|
||||
".txt",
|
||||
".yaml",
|
||||
".yml",
|
||||
}
|
||||
_EDITABLE_SKILL_FILENAMES = {"Dockerfile", "Makefile"}
|
||||
|
||||
|
||||
class SkillsServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillsOperationResult:
|
||||
ok: bool = True
|
||||
data: dict | list | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillArchive:
|
||||
path: Path
|
||||
filename: str
|
||||
|
||||
|
||||
def _to_jsonable(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {key: _to_jsonable(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_to_jsonable(item) for item in value]
|
||||
if hasattr(value, "model_dump"):
|
||||
return _to_jsonable(value.model_dump())
|
||||
return value
|
||||
|
||||
|
||||
def _to_bool(value: Any, default: bool = False) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _next_available_temp_path(temp_dir: str, filename: str) -> str:
|
||||
stem = Path(filename).stem
|
||||
suffix = Path(filename).suffix
|
||||
candidate = filename
|
||||
index = 1
|
||||
while os.path.exists(os.path.join(temp_dir, candidate)):
|
||||
candidate = f"{stem}_{index}{suffix}"
|
||||
index += 1
|
||||
return os.path.join(temp_dir, candidate)
|
||||
|
||||
|
||||
class SkillsService:
|
||||
def __init__(self, core_lifecycle) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
|
||||
@staticmethod
|
||||
def _payload(data: object) -> dict[str, Any]:
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _ensure_mutation_allowed() -> None:
|
||||
if DEMO_MODE:
|
||||
raise SkillsServiceError(
|
||||
"You are not permitted to do this operation in demo mode"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _save_upload(file: Any, target_path: str) -> None:
|
||||
if hasattr(file, "save"):
|
||||
maybe_awaitable = file.save(target_path)
|
||||
if hasattr(maybe_awaitable, "__await__"):
|
||||
await maybe_awaitable
|
||||
return
|
||||
|
||||
if hasattr(file, "read"):
|
||||
data = file.read()
|
||||
if hasattr(data, "__await__"):
|
||||
data = await data
|
||||
Path(target_path).write_bytes(data)
|
||||
return
|
||||
|
||||
raise SkillsServiceError("Invalid upload file")
|
||||
|
||||
def resolve_local_skill_dir(self, name: str) -> Path:
|
||||
skill_name = str(name or "").strip()
|
||||
if not skill_name:
|
||||
raise ValueError("Missing skill name")
|
||||
if not _SKILL_NAME_RE.match(skill_name):
|
||||
raise ValueError("Invalid skill name")
|
||||
|
||||
skill_mgr = SkillManager()
|
||||
if skill_mgr.is_sandbox_only_skill(skill_name):
|
||||
raise PermissionError(
|
||||
"Sandbox preset skill cannot be opened from local skill files."
|
||||
)
|
||||
|
||||
plugin_skill_dir = skill_mgr._get_plugin_skill_dir(skill_name)
|
||||
if plugin_skill_dir is not None:
|
||||
return plugin_skill_dir.resolve(strict=True)
|
||||
|
||||
skills_root = Path(skill_mgr.skills_root).resolve(strict=True)
|
||||
skill_dir = (skills_root / skill_name).resolve(strict=True)
|
||||
if not skill_dir.is_relative_to(skills_root):
|
||||
raise PermissionError("Invalid skill path")
|
||||
if not skill_dir.is_dir() or not (skill_dir / "SKILL.md").exists():
|
||||
raise FileNotFoundError("Local skill not found")
|
||||
return skill_dir
|
||||
|
||||
@staticmethod
|
||||
def resolve_skill_relative_path(
|
||||
skill_dir: Path,
|
||||
relative_path: str | None,
|
||||
*,
|
||||
expect_file: bool,
|
||||
) -> Path:
|
||||
raw_path = str(relative_path or ".").strip() or "."
|
||||
normalized = Path(raw_path.replace("\\", "/"))
|
||||
if normalized.is_absolute() or ".." in normalized.parts:
|
||||
raise ValueError("Invalid relative path")
|
||||
|
||||
target = (skill_dir / normalized).resolve(strict=True)
|
||||
if not target.is_relative_to(skill_dir):
|
||||
raise PermissionError("Path escapes skill directory")
|
||||
if expect_file and not target.is_file():
|
||||
raise FileNotFoundError("Skill file not found")
|
||||
if not expect_file and not target.is_dir():
|
||||
raise FileNotFoundError("Skill directory not found")
|
||||
return target
|
||||
|
||||
@staticmethod
|
||||
def skill_relative_path(skill_dir: Path, target: Path) -> str:
|
||||
rel = target.relative_to(skill_dir).as_posix()
|
||||
return "" if rel == "." else rel
|
||||
|
||||
@staticmethod
|
||||
def is_editable_skill_file(path: Path) -> bool:
|
||||
return (
|
||||
path.name in _EDITABLE_SKILL_FILENAMES
|
||||
or path.suffix.lower() in _EDITABLE_SKILL_FILE_SUFFIXES
|
||||
)
|
||||
|
||||
def serialize_skill_file_entry(
|
||||
self,
|
||||
skill_dir: Path,
|
||||
path: Path,
|
||||
*,
|
||||
readonly: bool = False,
|
||||
) -> dict:
|
||||
stat = path.stat()
|
||||
is_dir = path.is_dir()
|
||||
return {
|
||||
"name": path.name,
|
||||
"path": self.skill_relative_path(skill_dir, path),
|
||||
"type": "directory" if is_dir else "file",
|
||||
"size": 0 if is_dir else stat.st_size,
|
||||
"editable": (
|
||||
not readonly
|
||||
and (not is_dir)
|
||||
and self.is_editable_skill_file(path)
|
||||
and stat.st_size <= _SKILL_FILE_MAX_BYTES
|
||||
),
|
||||
}
|
||||
|
||||
def get_neo_client_config(self) -> tuple[str, str]:
|
||||
provider_settings = self.core_lifecycle.astrbot_config.get(
|
||||
"provider_settings",
|
||||
{},
|
||||
)
|
||||
sandbox = provider_settings.get("sandbox", {})
|
||||
endpoint = sandbox.get("shipyard_neo_endpoint", "")
|
||||
access_token = sandbox.get("shipyard_neo_access_token", "")
|
||||
|
||||
if not access_token and endpoint:
|
||||
access_token = _discover_bay_credentials(endpoint)
|
||||
|
||||
if not endpoint or not access_token:
|
||||
raise ValueError(
|
||||
"Shipyard Neo endpoint or access token not configured. "
|
||||
"Set them in Dashboard or ensure Bay's credentials.json is accessible."
|
||||
)
|
||||
return endpoint, access_token
|
||||
|
||||
async def with_neo_client(
|
||||
self,
|
||||
operation: Callable[[Any], Awaitable[Any]],
|
||||
) -> SkillsOperationResult:
|
||||
try:
|
||||
endpoint, access_token = self.get_neo_client_config()
|
||||
|
||||
from shipyard_neo import BayClient
|
||||
|
||||
async with BayClient(
|
||||
endpoint_url=endpoint,
|
||||
access_token=access_token,
|
||||
) as client:
|
||||
result = await operation(client)
|
||||
if isinstance(result, SkillsOperationResult):
|
||||
return result
|
||||
return SkillsOperationResult(data=_to_jsonable(result))
|
||||
except ValueError as exc:
|
||||
logger.debug("[Neo] %s", exc)
|
||||
return SkillsOperationResult(ok=False, message=str(exc))
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
return SkillsOperationResult(ok=False, message=str(exc))
|
||||
|
||||
def get_skills(self) -> dict:
|
||||
provider_settings = self.core_lifecycle.astrbot_config.get(
|
||||
"provider_settings", {}
|
||||
)
|
||||
runtime = provider_settings.get("computer_use_runtime", "local")
|
||||
skill_mgr = SkillManager()
|
||||
skills = skill_mgr.list_skills(
|
||||
active_only=False,
|
||||
runtime=runtime,
|
||||
show_sandbox_path=False,
|
||||
)
|
||||
return {
|
||||
"skills": [skill.__dict__ for skill in skills],
|
||||
"runtime": runtime,
|
||||
"sandbox_cache": skill_mgr.get_sandbox_skills_cache_status(),
|
||||
}
|
||||
|
||||
async def upload_skill(self, file: Any | None) -> SkillsOperationResult:
|
||||
self._ensure_mutation_allowed()
|
||||
temp_path = None
|
||||
if not file:
|
||||
raise SkillsServiceError("Missing file")
|
||||
|
||||
filename = os.path.basename(file.filename or "skill.zip")
|
||||
if not filename.lower().endswith(".zip"):
|
||||
raise SkillsServiceError("Only .zip files are supported")
|
||||
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
skill_mgr = SkillManager()
|
||||
temp_path = _next_available_temp_path(temp_dir, filename)
|
||||
|
||||
try:
|
||||
await self._save_upload(file, temp_path)
|
||||
try:
|
||||
skill_name = skill_mgr.install_skill_from_zip(
|
||||
temp_path,
|
||||
overwrite=False,
|
||||
skill_name_hint=Path(filename).stem,
|
||||
)
|
||||
except TypeError:
|
||||
skill_name = skill_mgr.install_skill_from_zip(
|
||||
temp_path,
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning("Failed to sync uploaded skills to active sandboxes.")
|
||||
|
||||
return SkillsOperationResult(
|
||||
data={"name": skill_name},
|
||||
message="Skill uploaded successfully.",
|
||||
)
|
||||
finally:
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to remove temp skill file: {temp_path}")
|
||||
|
||||
async def batch_upload_skills(self, file_list: list[Any]) -> SkillsOperationResult:
|
||||
self._ensure_mutation_allowed()
|
||||
|
||||
if not file_list:
|
||||
raise SkillsServiceError("No files provided")
|
||||
|
||||
succeeded = []
|
||||
failed = []
|
||||
skipped = []
|
||||
skill_mgr = SkillManager()
|
||||
temp_dir = get_astrbot_temp_path()
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
for file in file_list:
|
||||
filename = os.path.basename(file.filename or "unknown.zip")
|
||||
temp_path = None
|
||||
|
||||
try:
|
||||
if not filename.lower().endswith(".zip"):
|
||||
failed.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"error": "Only .zip files are supported",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
temp_path = _next_available_temp_path(temp_dir, filename)
|
||||
await self._save_upload(file, temp_path)
|
||||
|
||||
try:
|
||||
skill_name = skill_mgr.install_skill_from_zip(
|
||||
temp_path,
|
||||
overwrite=False,
|
||||
skill_name_hint=Path(filename).stem,
|
||||
)
|
||||
except TypeError:
|
||||
try:
|
||||
skill_name = skill_mgr.install_skill_from_zip(
|
||||
temp_path,
|
||||
overwrite=False,
|
||||
)
|
||||
except FileExistsError:
|
||||
skipped.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"name": Path(filename).stem,
|
||||
"error": "Skill already exists.",
|
||||
}
|
||||
)
|
||||
skill_name = None
|
||||
except FileExistsError:
|
||||
skipped.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"name": Path(filename).stem,
|
||||
"error": "Skill already exists.",
|
||||
}
|
||||
)
|
||||
skill_name = None
|
||||
|
||||
if skill_name is None:
|
||||
continue
|
||||
succeeded.append({"filename": filename, "name": skill_name})
|
||||
|
||||
except Exception as exc:
|
||||
failed.append({"filename": filename, "error": str(exc)})
|
||||
finally:
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if succeeded:
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning("Failed to sync uploaded skills to active sandboxes.")
|
||||
|
||||
total = len(file_list)
|
||||
success_count = len(succeeded)
|
||||
skipped_count = len(skipped)
|
||||
failed_count = len(failed)
|
||||
data = {
|
||||
"total": total,
|
||||
"succeeded": succeeded,
|
||||
"failed": failed,
|
||||
"skipped": skipped,
|
||||
}
|
||||
|
||||
if failed_count == 0 and success_count == total:
|
||||
return SkillsOperationResult(
|
||||
data=data,
|
||||
message=f"All {total} skill(s) uploaded successfully.",
|
||||
)
|
||||
if failed_count == 0 and success_count == 0:
|
||||
return SkillsOperationResult(
|
||||
data=data,
|
||||
message=f"All {total} file(s) were skipped.",
|
||||
)
|
||||
if success_count == 0 and skipped_count == 0:
|
||||
return SkillsOperationResult(
|
||||
ok=False,
|
||||
data=data,
|
||||
message=f"Upload failed for all {total} file(s).",
|
||||
)
|
||||
|
||||
return SkillsOperationResult(
|
||||
data=data,
|
||||
message=f"Partial success: {success_count}/{total} skill(s) uploaded.",
|
||||
)
|
||||
|
||||
def prepare_skill_archive(self, name: str) -> SkillArchive:
|
||||
skill_name = str(name or "").strip()
|
||||
if not skill_name:
|
||||
raise SkillsServiceError("Missing skill name")
|
||||
if not _SKILL_NAME_RE.match(skill_name):
|
||||
raise SkillsServiceError("Invalid skill name")
|
||||
|
||||
skill_mgr = SkillManager()
|
||||
if skill_mgr.is_sandbox_only_skill(skill_name):
|
||||
raise SkillsServiceError(
|
||||
"Sandbox preset skill cannot be downloaded from local skill files."
|
||||
)
|
||||
if skill_mgr.is_plugin_skill(skill_name):
|
||||
raise SkillsServiceError(
|
||||
"Plugin-provided skill cannot be downloaded from local skill files."
|
||||
)
|
||||
|
||||
skill_dir = Path(skill_mgr.skills_root) / skill_name
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
if not skill_dir.is_dir() or not skill_md.exists():
|
||||
raise SkillsServiceError("Local skill not found")
|
||||
|
||||
export_dir = Path(get_astrbot_temp_path()) / "skill_exports"
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
zip_base = export_dir / skill_name
|
||||
zip_path = zip_base.with_suffix(".zip")
|
||||
if zip_path.exists():
|
||||
zip_path.unlink()
|
||||
|
||||
shutil.make_archive(
|
||||
str(zip_base),
|
||||
"zip",
|
||||
root_dir=str(skill_mgr.skills_root),
|
||||
base_dir=skill_name,
|
||||
)
|
||||
return SkillArchive(path=zip_path, filename=f"{skill_name}.zip")
|
||||
|
||||
def prepare_skill_archive_from_legacy_query(self, name: str | None) -> SkillArchive:
|
||||
return self.prepare_skill_archive(name or "")
|
||||
|
||||
def list_skill_files(self, name: str, relative_path: str | None = "") -> dict:
|
||||
skill_name = str(name or "").strip()
|
||||
readonly = SkillManager().is_plugin_skill(skill_name)
|
||||
skill_dir = self.resolve_local_skill_dir(skill_name)
|
||||
target_dir = self.resolve_skill_relative_path(
|
||||
skill_dir,
|
||||
relative_path,
|
||||
expect_file=False,
|
||||
)
|
||||
|
||||
entries = []
|
||||
for entry in sorted(
|
||||
target_dir.iterdir(),
|
||||
key=lambda item: (not item.is_dir(), item.name.lower()),
|
||||
):
|
||||
try:
|
||||
resolved = entry.resolve(strict=True)
|
||||
except OSError:
|
||||
continue
|
||||
if not resolved.is_relative_to(skill_dir):
|
||||
continue
|
||||
if not resolved.is_dir() and not resolved.is_file():
|
||||
continue
|
||||
entries.append(
|
||||
self.serialize_skill_file_entry(
|
||||
skill_dir,
|
||||
resolved,
|
||||
readonly=readonly,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"name": skill_name,
|
||||
"path": self.skill_relative_path(skill_dir, target_dir),
|
||||
"entries": entries,
|
||||
}
|
||||
|
||||
def list_skill_files_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
name: str | None,
|
||||
relative_path: str | None,
|
||||
) -> dict:
|
||||
return self.list_skill_files(name or "", relative_path or "")
|
||||
|
||||
def get_skill_file(self, name: str, relative_path: str | None = "SKILL.md") -> dict:
|
||||
skill_name = str(name or "").strip()
|
||||
skill_dir = self.resolve_local_skill_dir(skill_name)
|
||||
target_file = self.resolve_skill_relative_path(
|
||||
skill_dir,
|
||||
relative_path,
|
||||
expect_file=True,
|
||||
)
|
||||
if not self.is_editable_skill_file(target_file):
|
||||
raise SkillsServiceError("Unsupported file type")
|
||||
|
||||
size = target_file.stat().st_size
|
||||
if size > _SKILL_FILE_MAX_BYTES:
|
||||
raise SkillsServiceError("File is too large")
|
||||
|
||||
try:
|
||||
content = target_file.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise SkillsServiceError("File is not valid UTF-8 text") from exc
|
||||
|
||||
return {
|
||||
"name": skill_name,
|
||||
"path": self.skill_relative_path(skill_dir, target_file),
|
||||
"content": content,
|
||||
"size": size,
|
||||
"editable": not SkillManager().is_plugin_skill(skill_name),
|
||||
}
|
||||
|
||||
def get_skill_file_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
name: str | None,
|
||||
relative_path: str | None,
|
||||
) -> dict:
|
||||
return self.get_skill_file(name or "", relative_path or "SKILL.md")
|
||||
|
||||
async def update_skill_file(self, data: object) -> dict:
|
||||
self._ensure_mutation_allowed()
|
||||
payload = self._payload(data)
|
||||
skill_name = str(payload.get("name") or "").strip()
|
||||
relative_path = payload.get("path", "SKILL.md")
|
||||
content = payload.get("content")
|
||||
if not isinstance(content, str):
|
||||
raise SkillsServiceError("Missing file content")
|
||||
|
||||
encoded = content.encode("utf-8")
|
||||
if len(encoded) > _SKILL_FILE_MAX_BYTES:
|
||||
raise SkillsServiceError("File content is too large")
|
||||
|
||||
skill_dir = self.resolve_local_skill_dir(skill_name)
|
||||
if SkillManager().is_plugin_skill(skill_name):
|
||||
raise SkillsServiceError("Plugin-provided skill is read-only.")
|
||||
target_file = self.resolve_skill_relative_path(
|
||||
skill_dir,
|
||||
relative_path,
|
||||
expect_file=True,
|
||||
)
|
||||
if not self.is_editable_skill_file(target_file):
|
||||
raise SkillsServiceError("Unsupported file type")
|
||||
|
||||
target_file.write_text(content, encoding="utf-8")
|
||||
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning("Failed to sync edited skills to active sandboxes.")
|
||||
|
||||
return {
|
||||
"name": skill_name,
|
||||
"path": self.skill_relative_path(skill_dir, target_file),
|
||||
"size": len(encoded),
|
||||
}
|
||||
|
||||
def update_skill(self, data: object) -> dict:
|
||||
self._ensure_mutation_allowed()
|
||||
payload = self._payload(data)
|
||||
name = payload.get("name")
|
||||
active = payload.get("active", True)
|
||||
if not name:
|
||||
raise SkillsServiceError("Missing skill name")
|
||||
SkillManager().set_skill_active(name, bool(active))
|
||||
return {"name": name, "active": bool(active)}
|
||||
|
||||
async def delete_skill(self, data: object) -> dict:
|
||||
self._ensure_mutation_allowed()
|
||||
payload = self._payload(data)
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise SkillsServiceError("Missing skill name")
|
||||
SkillManager().delete_skill(name)
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning("Failed to sync deleted skills to active sandboxes.")
|
||||
return {"name": name}
|
||||
|
||||
async def get_neo_candidates(self, query: dict[str, Any]) -> SkillsOperationResult:
|
||||
logger.info("[Neo] GET /skills/neo/candidates requested.")
|
||||
status = query.get("status")
|
||||
skill_key = query.get("skill_key")
|
||||
limit = int(query.get("limit", 100))
|
||||
offset = int(query.get("offset", 0))
|
||||
|
||||
async def _do(client):
|
||||
candidates = await client.skills.list_candidates(
|
||||
status=status,
|
||||
skill_key=skill_key,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
result = _to_jsonable(candidates)
|
||||
total = result.get("total", "?") if isinstance(result, dict) else "?"
|
||||
logger.info(f"[Neo] Candidates fetched: total={total}")
|
||||
return result
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
async def get_neo_candidates_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
status: str | None,
|
||||
skill_key: str | None,
|
||||
limit: str | None,
|
||||
offset: str | None,
|
||||
) -> SkillsOperationResult:
|
||||
return await self.get_neo_candidates(
|
||||
self._legacy_query(
|
||||
status=status,
|
||||
skill_key=skill_key,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
)
|
||||
|
||||
async def get_neo_releases(self, query: dict[str, Any]) -> SkillsOperationResult:
|
||||
logger.info("[Neo] GET /skills/neo/releases requested.")
|
||||
skill_key = query.get("skill_key")
|
||||
stage = query.get("stage")
|
||||
active_only = _to_bool(query.get("active_only"), False)
|
||||
limit = int(query.get("limit", 100))
|
||||
offset = int(query.get("offset", 0))
|
||||
|
||||
async def _do(client):
|
||||
releases = await client.skills.list_releases(
|
||||
skill_key=skill_key,
|
||||
active_only=active_only,
|
||||
stage=stage,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
result = _to_jsonable(releases)
|
||||
total = result.get("total", "?") if isinstance(result, dict) else "?"
|
||||
logger.info(f"[Neo] Releases fetched: total={total}")
|
||||
return result
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
async def get_neo_releases_from_legacy_query(
|
||||
self,
|
||||
*,
|
||||
skill_key: str | None,
|
||||
stage: str | None,
|
||||
active_only: str | None,
|
||||
limit: str | None,
|
||||
offset: str | None,
|
||||
) -> SkillsOperationResult:
|
||||
return await self.get_neo_releases(
|
||||
self._legacy_query(
|
||||
skill_key=skill_key,
|
||||
stage=stage,
|
||||
active_only=active_only,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
)
|
||||
|
||||
async def get_neo_payload(self, query: dict[str, Any]) -> SkillsOperationResult:
|
||||
logger.info("[Neo] GET /skills/neo/payload requested.")
|
||||
payload_ref = query.get("payload_ref", "")
|
||||
if not payload_ref:
|
||||
return SkillsOperationResult(ok=False, message="Missing payload_ref")
|
||||
|
||||
async def _do(client):
|
||||
payload = await client.skills.get_payload(payload_ref)
|
||||
logger.info(f"[Neo] Payload fetched: ref={payload_ref}")
|
||||
return payload
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
async def get_neo_payload_from_legacy_query(
|
||||
self,
|
||||
payload_ref: str | None,
|
||||
) -> SkillsOperationResult:
|
||||
return await self.get_neo_payload(self._legacy_query(payload_ref=payload_ref))
|
||||
|
||||
async def evaluate_neo_candidate(
|
||||
self,
|
||||
data: object,
|
||||
) -> SkillsOperationResult:
|
||||
self._ensure_mutation_allowed()
|
||||
logger.info("[Neo] POST /skills/neo/evaluate requested.")
|
||||
payload = self._payload(data)
|
||||
candidate_id = payload.get("candidate_id")
|
||||
passed_value = payload.get("passed")
|
||||
if not candidate_id or passed_value is None:
|
||||
return SkillsOperationResult(
|
||||
ok=False,
|
||||
message="Missing candidate_id or passed",
|
||||
)
|
||||
passed = _to_bool(passed_value, False)
|
||||
|
||||
async def _do(client):
|
||||
result = await client.skills.evaluate_candidate(
|
||||
candidate_id,
|
||||
passed=passed,
|
||||
score=payload.get("score"),
|
||||
benchmark_id=payload.get("benchmark_id"),
|
||||
report=payload.get("report"),
|
||||
)
|
||||
logger.info(
|
||||
f"[Neo] Candidate evaluated: id={candidate_id}, passed={passed}"
|
||||
)
|
||||
return result
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
async def promote_neo_candidate(self, data: object) -> SkillsOperationResult:
|
||||
self._ensure_mutation_allowed()
|
||||
logger.info("[Neo] POST /skills/neo/promote requested.")
|
||||
payload = self._payload(data)
|
||||
candidate_id = payload.get("candidate_id")
|
||||
stage = payload.get("stage", "canary")
|
||||
sync_to_local = _to_bool(payload.get("sync_to_local"), True)
|
||||
if not candidate_id:
|
||||
return SkillsOperationResult(ok=False, message="Missing candidate_id")
|
||||
if stage not in {"canary", "stable"}:
|
||||
return SkillsOperationResult(
|
||||
ok=False,
|
||||
message="Invalid stage, must be canary/stable",
|
||||
)
|
||||
|
||||
async def _do(client):
|
||||
sync_mgr = NeoSkillSyncManager()
|
||||
result = await sync_mgr.promote_with_optional_sync(
|
||||
client,
|
||||
candidate_id=candidate_id,
|
||||
stage=stage,
|
||||
sync_to_local=sync_to_local,
|
||||
)
|
||||
release_json = result.get("release")
|
||||
logger.info(f"[Neo] Candidate promoted: id={candidate_id}, stage={stage}")
|
||||
|
||||
sync_json = result.get("sync")
|
||||
did_sync_to_local = bool(sync_json)
|
||||
if did_sync_to_local:
|
||||
logger.info(
|
||||
"[Neo] Stable release synced to local: "
|
||||
f"skill={sync_json.get('local_skill_name', '')}"
|
||||
)
|
||||
|
||||
if result.get("sync_error"):
|
||||
return SkillsOperationResult(
|
||||
ok=False,
|
||||
message=(
|
||||
"Stable promote synced failed and has been rolled back. "
|
||||
f"sync_error={result['sync_error']}"
|
||||
),
|
||||
data={
|
||||
"release": release_json,
|
||||
"rollback": result.get("rollback"),
|
||||
},
|
||||
)
|
||||
|
||||
if not did_sync_to_local:
|
||||
try:
|
||||
await sync_skills_to_active_sandboxes()
|
||||
except Exception:
|
||||
logger.warning("Failed to sync skills to active sandboxes.")
|
||||
|
||||
return {"release": release_json, "sync": sync_json}
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
async def rollback_neo_release(self, data: object) -> SkillsOperationResult:
|
||||
self._ensure_mutation_allowed()
|
||||
logger.info("[Neo] POST /skills/neo/rollback requested.")
|
||||
payload = self._payload(data)
|
||||
release_id = payload.get("release_id")
|
||||
if not release_id:
|
||||
return SkillsOperationResult(ok=False, message="Missing release_id")
|
||||
|
||||
async def _do(client):
|
||||
result = await client.skills.rollback_release(release_id)
|
||||
logger.info(f"[Neo] Release rolled back: id={release_id}")
|
||||
return result
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
async def sync_neo_release(self, data: object) -> SkillsOperationResult:
|
||||
self._ensure_mutation_allowed()
|
||||
logger.info("[Neo] POST /skills/neo/sync requested.")
|
||||
payload = self._payload(data)
|
||||
release_id = payload.get("release_id")
|
||||
skill_key = payload.get("skill_key")
|
||||
require_stable = _to_bool(payload.get("require_stable"), True)
|
||||
if not release_id and not skill_key:
|
||||
return SkillsOperationResult(
|
||||
ok=False,
|
||||
message="Missing release_id or skill_key",
|
||||
)
|
||||
|
||||
async def _do(client):
|
||||
sync_mgr = NeoSkillSyncManager()
|
||||
result = await sync_mgr.sync_release(
|
||||
client,
|
||||
release_id=release_id,
|
||||
skill_key=skill_key,
|
||||
require_stable=require_stable,
|
||||
)
|
||||
logger.info(
|
||||
f"[Neo] Release synced to local: skill={result.local_skill_name}, "
|
||||
f"release_id={result.release_id}"
|
||||
)
|
||||
return {
|
||||
"skill_key": result.skill_key,
|
||||
"local_skill_name": result.local_skill_name,
|
||||
"release_id": result.release_id,
|
||||
"candidate_id": result.candidate_id,
|
||||
"payload_ref": result.payload_ref,
|
||||
"map_path": result.map_path,
|
||||
"synced_at": result.synced_at,
|
||||
}
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
async def delete_neo_candidate(self, data: object) -> SkillsOperationResult:
|
||||
self._ensure_mutation_allowed()
|
||||
logger.info("[Neo] POST /skills/neo/delete-candidate requested.")
|
||||
payload = self._payload(data)
|
||||
candidate_id = payload.get("candidate_id")
|
||||
reason = payload.get("reason")
|
||||
if not candidate_id:
|
||||
return SkillsOperationResult(ok=False, message="Missing candidate_id")
|
||||
|
||||
async def _do(client):
|
||||
result = await client.skills.delete_candidate(candidate_id, reason=reason)
|
||||
logger.info(f"[Neo] Candidate deleted: id={candidate_id}")
|
||||
return result
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
async def delete_neo_release(self, data: object) -> SkillsOperationResult:
|
||||
self._ensure_mutation_allowed()
|
||||
logger.info("[Neo] POST /skills/neo/delete-release requested.")
|
||||
payload = self._payload(data)
|
||||
release_id = payload.get("release_id")
|
||||
reason = payload.get("reason")
|
||||
if not release_id:
|
||||
return SkillsOperationResult(ok=False, message="Missing release_id")
|
||||
|
||||
async def _do(client):
|
||||
result = await client.skills.delete_release(release_id, reason=reason)
|
||||
logger.info(f"[Neo] Release deleted: id={release_id}")
|
||||
return result
|
||||
|
||||
return await self.with_neo_client(_do)
|
||||
|
||||
@staticmethod
|
||||
def _legacy_query(**values: Any) -> dict[str, Any]:
|
||||
return {
|
||||
key: value
|
||||
for key, value in values.items()
|
||||
if value is not None and value != ""
|
||||
}
|
||||
532
astrbot/dashboard/services/stat_service.py
Normal file
532
astrbot/dashboard/services/stat_service.py
Normal file
@@ -0,0 +1,532 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import cmp_to_key
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import psutil
|
||||
from sqlmodel import col, select
|
||||
|
||||
from astrbot.core import DEMO_MODE, logger
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.helper import check_migration_needed_v4
|
||||
from astrbot.core.db.po import ProviderStat
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.utils.auth_password import (
|
||||
is_default_dashboard_password,
|
||||
is_legacy_dashboard_password,
|
||||
)
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
from astrbot.core.utils.storage_cleaner import StorageCleaner
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
from astrbot.dashboard.password_state import (
|
||||
get_dashboard_password_hash,
|
||||
is_password_change_required,
|
||||
is_password_storage_upgraded,
|
||||
)
|
||||
|
||||
|
||||
class StatServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class StatService:
|
||||
def __init__(
|
||||
self,
|
||||
db_helper: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
config: AstrBotConfig,
|
||||
) -> None:
|
||||
self.db_helper = db_helper
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = config
|
||||
self.storage_cleaner = StorageCleaner(config)
|
||||
|
||||
async def restart_core(self) -> None:
|
||||
if DEMO_MODE:
|
||||
raise StatServiceError(
|
||||
"You are not permitted to do this operation in demo mode"
|
||||
)
|
||||
await self.core_lifecycle.restart()
|
||||
|
||||
@staticmethod
|
||||
def get_running_time_components(total_seconds: int):
|
||||
minutes, seconds = divmod(total_seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return {"hours": hours, "minutes": minutes, "seconds": seconds}
|
||||
|
||||
async def is_default_cred(self):
|
||||
password_change_required = await is_password_change_required(
|
||||
self.db_helper,
|
||||
self.config,
|
||||
)
|
||||
if password_change_required:
|
||||
return not DEMO_MODE
|
||||
|
||||
storage_upgraded = await is_password_storage_upgraded(
|
||||
self.db_helper,
|
||||
self.config,
|
||||
)
|
||||
if not storage_upgraded:
|
||||
return False
|
||||
|
||||
username = self.config["dashboard"]["username"]
|
||||
password = get_dashboard_password_hash(self.config, upgraded=True)
|
||||
return (
|
||||
username == "astrbot" and is_default_dashboard_password(password)
|
||||
) and not DEMO_MODE
|
||||
|
||||
async def get_version(self) -> dict:
|
||||
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
|
||||
storage_upgraded = await is_password_storage_upgraded(
|
||||
self.db_helper,
|
||||
self.config,
|
||||
)
|
||||
password = get_dashboard_password_hash(
|
||||
self.config,
|
||||
upgraded=storage_upgraded,
|
||||
)
|
||||
|
||||
return {
|
||||
"version": VERSION,
|
||||
"dashboard_version": await get_dashboard_version(),
|
||||
"change_pwd_hint": await self.is_default_cred(),
|
||||
"legacy_pwd_hint": is_legacy_dashboard_password(password),
|
||||
"password_upgrade_required": not storage_upgraded,
|
||||
"need_migration": need_migration,
|
||||
}
|
||||
|
||||
def get_start_time(self) -> dict:
|
||||
return {"start_time": self.core_lifecycle.start_time}
|
||||
|
||||
async def get_storage_status(self) -> dict:
|
||||
try:
|
||||
return await asyncio.to_thread(self.storage_cleaner.get_status)
|
||||
except Exception as exc:
|
||||
logger.error("获取存储占用失败", exc_info=True)
|
||||
raise StatServiceError(
|
||||
"获取存储占用失败,请查看后端日志了解详情。"
|
||||
) from exc
|
||||
|
||||
async def cleanup_storage(self, target: str) -> dict:
|
||||
try:
|
||||
return await asyncio.to_thread(self.storage_cleaner.cleanup, target)
|
||||
except ValueError as exc:
|
||||
raise StatServiceError(str(exc)) from exc
|
||||
except Exception as exc:
|
||||
logger.error("清理存储失败", exc_info=True)
|
||||
raise StatServiceError("清理存储失败,请查看后端日志了解详情。") from exc
|
||||
|
||||
async def cleanup_storage_from_legacy_payload(self, payload: object) -> dict:
|
||||
target = "all"
|
||||
if isinstance(payload, dict):
|
||||
target = str(payload.get("target", "all"))
|
||||
return await self.cleanup_storage(target)
|
||||
|
||||
async def get_stat(self, offset_sec: int) -> dict:
|
||||
try:
|
||||
stat = self.db_helper.get_base_stats(offset_sec)
|
||||
now = int(time.time())
|
||||
start_time = now - offset_sec
|
||||
message_time_based_stats = []
|
||||
|
||||
idx = 0
|
||||
for bucket_end in range(start_time, now, 3600):
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(stat.platform)
|
||||
and stat.platform[idx].timestamp < bucket_end
|
||||
):
|
||||
cnt += stat.platform[idx].count
|
||||
idx += 1
|
||||
message_time_based_stats.append([bucket_end, cnt])
|
||||
|
||||
stat_dict = stat.__dict__
|
||||
|
||||
cpu_percent = psutil.cpu_percent(interval=0.5)
|
||||
thread_count = threading.active_count()
|
||||
|
||||
plugins = self.core_lifecycle.star_context.get_all_stars()
|
||||
plugin_info = []
|
||||
for plugin in plugins:
|
||||
info = {
|
||||
"name": getattr(plugin, "name", plugin.__class__.__name__),
|
||||
"version": getattr(plugin, "version", "1.0.0"),
|
||||
"is_enabled": True,
|
||||
}
|
||||
plugin_info.append(info)
|
||||
|
||||
running_time = self.get_running_time_components(
|
||||
int(time.time()) - self.core_lifecycle.start_time,
|
||||
)
|
||||
|
||||
stat_dict.update(
|
||||
{
|
||||
"platform": self.db_helper.get_grouped_base_stats(
|
||||
offset_sec,
|
||||
).platform,
|
||||
"message_count": self.db_helper.get_total_message_count() or 0,
|
||||
"platform_count": len(
|
||||
self.core_lifecycle.platform_manager.get_insts(),
|
||||
),
|
||||
"plugin_count": len(plugins),
|
||||
"plugins": plugin_info,
|
||||
"message_time_series": message_time_based_stats,
|
||||
"running": running_time,
|
||||
"memory": {
|
||||
"process": psutil.Process().memory_info().rss >> 20,
|
||||
"system": psutil.virtual_memory().total >> 20,
|
||||
},
|
||||
"cpu_percent": round(cpu_percent, 1),
|
||||
"thread_count": thread_count,
|
||||
"start_time": self.core_lifecycle.start_time,
|
||||
},
|
||||
)
|
||||
return stat_dict
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise StatServiceError(str(exc)) from exc
|
||||
|
||||
async def get_stat_from_legacy_query(self, offset_sec_raw) -> dict:
|
||||
try:
|
||||
offset_sec = int(offset_sec_raw if offset_sec_raw is not None else 86400)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise StatServiceError("offset_sec must be an integer") from exc
|
||||
return await self.get_stat(offset_sec)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_aware_utc(value: datetime) -> datetime:
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc)
|
||||
|
||||
async def get_provider_token_stats(self, days: int) -> dict:
|
||||
try:
|
||||
if days not in (1, 3, 7):
|
||||
days = 1
|
||||
|
||||
local_tz = datetime.now().astimezone().tzinfo or timezone.utc
|
||||
now_local = datetime.now(local_tz)
|
||||
range_start_local = (now_local - timedelta(days=days)).replace(
|
||||
minute=0, second=0, microsecond=0
|
||||
)
|
||||
today_start_local = now_local.replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
query_start_local = min(range_start_local, today_start_local)
|
||||
query_start_utc = query_start_local.astimezone(timezone.utc)
|
||||
|
||||
async with self.db_helper.get_db() as session:
|
||||
result = await session.execute(
|
||||
select(ProviderStat)
|
||||
.where(
|
||||
ProviderStat.agent_type == "internal",
|
||||
ProviderStat.created_at >= query_start_utc,
|
||||
)
|
||||
.order_by(col(ProviderStat.created_at).asc())
|
||||
)
|
||||
records = result.scalars().all()
|
||||
|
||||
bucket_timestamps: list[int] = []
|
||||
bucket_cursor = range_start_local
|
||||
while bucket_cursor <= now_local:
|
||||
bucket_timestamps.append(int(bucket_cursor.timestamp() * 1000))
|
||||
bucket_cursor += timedelta(hours=1)
|
||||
|
||||
trend_by_provider: dict[str, dict[int, int]] = defaultdict(
|
||||
lambda: defaultdict(int)
|
||||
)
|
||||
total_by_provider: dict[str, int] = defaultdict(int)
|
||||
total_by_umo: dict[str, int] = defaultdict(int)
|
||||
total_by_bucket: dict[int, int] = defaultdict(int)
|
||||
range_total_tokens = 0
|
||||
range_total_output_tokens = 0
|
||||
range_total_calls = 0
|
||||
range_success_calls = 0
|
||||
range_ttft_total_ms = 0.0
|
||||
range_ttft_samples = 0
|
||||
range_duration_total_ms = 0.0
|
||||
range_duration_samples = 0
|
||||
today_by_model: dict[str, int] = defaultdict(int)
|
||||
today_by_provider: dict[str, int] = defaultdict(int)
|
||||
today_total_tokens = 0
|
||||
today_total_calls = 0
|
||||
|
||||
for record in records:
|
||||
created_at_utc = self._ensure_aware_utc(record.created_at)
|
||||
created_at_local = created_at_utc.astimezone(local_tz)
|
||||
token_total = (
|
||||
record.token_input_other
|
||||
+ record.token_input_cached
|
||||
+ record.token_output
|
||||
)
|
||||
provider_id = record.provider_id or "unknown"
|
||||
provider_model = record.provider_model or "Unknown"
|
||||
|
||||
if created_at_local >= range_start_local:
|
||||
bucket_local = created_at_local.replace(
|
||||
minute=0, second=0, microsecond=0
|
||||
)
|
||||
bucket_ts = int(bucket_local.timestamp() * 1000)
|
||||
trend_by_provider[provider_id][bucket_ts] += token_total
|
||||
total_by_provider[provider_id] += token_total
|
||||
total_by_umo[record.umo or "unknown"] += token_total
|
||||
total_by_bucket[bucket_ts] += token_total
|
||||
range_total_tokens += token_total
|
||||
range_total_calls += 1
|
||||
if record.status != "error":
|
||||
range_success_calls += 1
|
||||
if record.time_to_first_token > 0:
|
||||
range_ttft_total_ms += record.time_to_first_token * 1000
|
||||
range_ttft_samples += 1
|
||||
if record.end_time > record.start_time:
|
||||
range_duration_total_ms += (
|
||||
record.end_time - record.start_time
|
||||
) * 1000
|
||||
range_duration_samples += 1
|
||||
range_total_output_tokens += record.token_output
|
||||
|
||||
if created_at_local >= today_start_local:
|
||||
today_total_calls += 1
|
||||
today_total_tokens += token_total
|
||||
today_by_model[provider_model] += token_total
|
||||
today_by_provider[provider_id] += token_total
|
||||
|
||||
sorted_provider_ids = sorted(
|
||||
total_by_provider.keys(),
|
||||
key=lambda item: total_by_provider[item],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
series = [
|
||||
{
|
||||
"name": provider_id,
|
||||
"data": [
|
||||
[bucket_ts, trend_by_provider[provider_id].get(bucket_ts, 0)]
|
||||
for bucket_ts in bucket_timestamps
|
||||
],
|
||||
"total_tokens": total_by_provider[provider_id],
|
||||
}
|
||||
for provider_id in sorted_provider_ids
|
||||
]
|
||||
|
||||
total_series = [
|
||||
[bucket_ts, total_by_bucket.get(bucket_ts, 0)]
|
||||
for bucket_ts in bucket_timestamps
|
||||
]
|
||||
|
||||
today_by_model_data = [
|
||||
{"provider_model": model_name, "tokens": tokens}
|
||||
for model_name, tokens in sorted(
|
||||
today_by_model.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
today_by_provider_data = [
|
||||
{"provider_id": provider_id, "tokens": tokens}
|
||||
for provider_id, tokens in sorted(
|
||||
today_by_provider.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
range_by_provider_data = [
|
||||
{"provider_id": provider_id, "tokens": tokens}
|
||||
for provider_id, tokens in sorted(
|
||||
total_by_provider.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
range_by_umo_data = [
|
||||
{"umo": umo, "tokens": tokens}
|
||||
for umo, tokens in sorted(
|
||||
total_by_umo.items(),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
|
||||
return {
|
||||
"days": days,
|
||||
"trend": {
|
||||
"series": series,
|
||||
"total_series": total_series,
|
||||
},
|
||||
"range_total_tokens": range_total_tokens,
|
||||
"range_total_calls": range_total_calls,
|
||||
"range_avg_ttft_ms": (
|
||||
range_ttft_total_ms / range_ttft_samples
|
||||
if range_ttft_samples
|
||||
else 0
|
||||
),
|
||||
"range_avg_duration_ms": (
|
||||
range_duration_total_ms / range_duration_samples
|
||||
if range_duration_samples
|
||||
else 0
|
||||
),
|
||||
"range_avg_tpm": (
|
||||
range_total_output_tokens / (range_duration_total_ms / 1000 / 60)
|
||||
if range_duration_total_ms > 0
|
||||
else 0
|
||||
),
|
||||
"range_success_rate": (
|
||||
range_success_calls / range_total_calls if range_total_calls else 0
|
||||
),
|
||||
"range_by_provider": range_by_provider_data,
|
||||
"range_by_umo": range_by_umo_data,
|
||||
"today_total_tokens": today_total_tokens,
|
||||
"today_total_calls": today_total_calls,
|
||||
"today_by_model": today_by_model_data,
|
||||
"today_by_provider": today_by_provider_data,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise StatServiceError(f"Error: {exc!s}") from exc
|
||||
|
||||
async def get_provider_token_stats_from_legacy_query(self, days_raw) -> dict:
|
||||
try:
|
||||
days = int(days_raw if days_raw is not None else 1)
|
||||
except (TypeError, ValueError):
|
||||
days = 1
|
||||
return await self.get_provider_token_stats(days)
|
||||
|
||||
async def test_ghproxy_connection(self, proxy_url: str | None) -> dict:
|
||||
try:
|
||||
if not proxy_url:
|
||||
raise StatServiceError("proxy_url is required")
|
||||
|
||||
proxy_url = proxy_url.rstrip("/")
|
||||
test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
|
||||
start_time = time.time()
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
test_url,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as response,
|
||||
):
|
||||
if response.status == 200:
|
||||
end_time = time.time()
|
||||
_ = await response.text()
|
||||
return {
|
||||
"latency": round((end_time - start_time) * 1000, 2),
|
||||
}
|
||||
raise StatServiceError(f"Failed. Status code: {response.status}")
|
||||
except StatServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise StatServiceError(f"Error: {exc!s}") from exc
|
||||
|
||||
async def test_ghproxy_connection_from_legacy_payload(
|
||||
self,
|
||||
payload: object,
|
||||
) -> dict:
|
||||
proxy_url = payload.get("proxy_url") if isinstance(payload, dict) else None
|
||||
return await self.test_ghproxy_connection(proxy_url)
|
||||
|
||||
def get_changelog(self, version: str | None) -> dict:
|
||||
try:
|
||||
if not version:
|
||||
raise StatServiceError("version parameter is required")
|
||||
|
||||
version = version.lstrip("v")
|
||||
if not re.match(r"^[a-zA-Z0-9._-]+$", version):
|
||||
raise StatServiceError("Invalid version format")
|
||||
if ".." in version or "/" in version or "\\" in version:
|
||||
raise StatServiceError("Invalid version format")
|
||||
|
||||
changelogs_dir = (Path(get_astrbot_path()) / "changelogs").resolve()
|
||||
changelog_path = (changelogs_dir / f"v{version}.md").resolve(strict=False)
|
||||
if not changelog_path.is_relative_to(changelogs_dir):
|
||||
logger.warning(
|
||||
"Path traversal attempt detected: %s -> %s",
|
||||
version,
|
||||
changelog_path,
|
||||
)
|
||||
raise StatServiceError("Invalid version format")
|
||||
|
||||
if not changelog_path.is_file():
|
||||
raise StatServiceError(f"Changelog for version {version} not found")
|
||||
|
||||
content = changelog_path.read_text(encoding="utf-8")
|
||||
return {"content": content, "version": version}
|
||||
except StatServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise StatServiceError(f"Error: {exc!s}") from exc
|
||||
|
||||
def list_changelog_versions(self) -> dict:
|
||||
try:
|
||||
changelogs_dir = Path(get_astrbot_path()) / "changelogs"
|
||||
if not changelogs_dir.exists():
|
||||
return {"versions": []}
|
||||
|
||||
versions = []
|
||||
for path in changelogs_dir.iterdir():
|
||||
filename = path.name
|
||||
if filename.endswith(".md") and filename.startswith("v"):
|
||||
version = filename[1:-3]
|
||||
if re.match(r"^[a-zA-Z0-9._-]+$", version):
|
||||
versions.append(version)
|
||||
|
||||
versions.sort(
|
||||
key=cmp_to_key(
|
||||
lambda v1, v2: VersionComparator.compare_version(v2, v1),
|
||||
),
|
||||
)
|
||||
|
||||
return {"versions": versions}
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise StatServiceError(f"Error: {exc!s}") from exc
|
||||
|
||||
def get_first_notice(self, locale: str | None) -> dict:
|
||||
try:
|
||||
locale = (locale or "").strip()
|
||||
if not re.match(r"^[A-Za-z0-9_-]*$", locale):
|
||||
locale = ""
|
||||
|
||||
base_path = Path(get_astrbot_path())
|
||||
candidates: list[Path] = []
|
||||
|
||||
if locale:
|
||||
candidates.append(base_path / f"FIRST_NOTICE.{locale}.md")
|
||||
if locale.lower().startswith("zh"):
|
||||
candidates.append(base_path / "FIRST_NOTICE.md")
|
||||
candidates.append(base_path / "FIRST_NOTICE.zh-CN.md")
|
||||
elif locale.lower().startswith("en"):
|
||||
candidates.append(base_path / "FIRST_NOTICE.en-US.md")
|
||||
|
||||
candidates.extend(
|
||||
[
|
||||
base_path / "FIRST_NOTICE.md",
|
||||
base_path / "FIRST_NOTICE.en-US.md",
|
||||
],
|
||||
)
|
||||
|
||||
for notice_path in candidates:
|
||||
if not notice_path.is_file():
|
||||
continue
|
||||
content = notice_path.read_text(encoding="utf-8")
|
||||
if content.strip():
|
||||
return {"content": content}
|
||||
|
||||
return {"content": None}
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise StatServiceError(f"Error: {exc!s}") from exc
|
||||
36
astrbot/dashboard/services/static_file_service.py
Normal file
36
astrbot/dashboard/services/static_file_service.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class StaticFileService:
|
||||
INDEX_ROUTES = (
|
||||
"/",
|
||||
"/auth/login",
|
||||
"/config",
|
||||
"/logs",
|
||||
"/extension",
|
||||
"/dashboard/default",
|
||||
"/alkaid",
|
||||
"/alkaid/knowledge-base",
|
||||
"/alkaid/long-term-memory",
|
||||
"/alkaid/other",
|
||||
"/console",
|
||||
"/chat",
|
||||
"/settings",
|
||||
"/platforms",
|
||||
"/providers",
|
||||
"/about",
|
||||
"/extension-marketplace",
|
||||
"/conversation",
|
||||
"/tool-use",
|
||||
)
|
||||
NOT_FOUND_MESSAGE = (
|
||||
"404 Not found。如果你初次使用打开面板发现 404, 请参考文档: "
|
||||
"https://docs.astrbot.app/faq.html。如果你正在测试回调地址可达性,"
|
||||
"显示这段文字说明测试成功了。"
|
||||
)
|
||||
|
||||
def list_index_routes(self) -> tuple[str, ...]:
|
||||
return self.INDEX_ROUTES
|
||||
|
||||
def get_not_found_message(self) -> str:
|
||||
return self.NOT_FOUND_MESSAGE
|
||||
96
astrbot/dashboard/services/subagent_service.py
Normal file
96
astrbot/dashboard/services/subagent_service.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
|
||||
class SubAgentServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SubAgentService:
|
||||
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
|
||||
def get_config(self) -> dict:
|
||||
try:
|
||||
config_data = self.core_lifecycle.astrbot_config.get(
|
||||
"subagent_orchestrator"
|
||||
)
|
||||
return self._normalize_config(config_data)
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise SubAgentServiceError(f"获取 subagent 配置失败: {exc!s}") from exc
|
||||
|
||||
async def update_config(self, data: object) -> None:
|
||||
try:
|
||||
if not isinstance(data, dict):
|
||||
raise SubAgentServiceError("配置必须为 JSON 对象")
|
||||
|
||||
config = self.core_lifecycle.astrbot_config
|
||||
config["subagent_orchestrator"] = data
|
||||
config.save_config()
|
||||
|
||||
orchestrator = getattr(self.core_lifecycle, "subagent_orchestrator", None)
|
||||
if orchestrator is not None:
|
||||
await orchestrator.reload_from_config(data)
|
||||
except SubAgentServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise SubAgentServiceError(f"保存 subagent 配置失败: {exc!s}") from exc
|
||||
|
||||
def get_available_tools(self) -> list[dict]:
|
||||
try:
|
||||
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
tools = []
|
||||
for tool in tool_mgr.func_list:
|
||||
if self._is_subagent_internal_tool(tool):
|
||||
continue
|
||||
tools.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
"active": tool.active,
|
||||
"handler_module_path": tool.handler_module_path,
|
||||
}
|
||||
)
|
||||
return tools
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise SubAgentServiceError(f"获取可用工具失败: {exc!s}") from exc
|
||||
|
||||
@staticmethod
|
||||
def _normalize_config(data: object) -> dict:
|
||||
if not isinstance(data, dict):
|
||||
data = {
|
||||
"main_enable": False,
|
||||
"remove_main_duplicate_tools": False,
|
||||
"agents": [],
|
||||
}
|
||||
|
||||
if "main_enable" not in data and "enable" in data:
|
||||
data["main_enable"] = bool(data.get("enable", False))
|
||||
|
||||
data.setdefault("main_enable", False)
|
||||
data.setdefault("remove_main_duplicate_tools", False)
|
||||
data.setdefault("agents", [])
|
||||
|
||||
if isinstance(data.get("agents"), list):
|
||||
for agent in data["agents"]:
|
||||
if isinstance(agent, dict):
|
||||
agent.setdefault("provider_id", None)
|
||||
agent.setdefault("persona_id", None)
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _is_subagent_internal_tool(tool) -> bool:
|
||||
return (
|
||||
isinstance(tool, HandoffTool)
|
||||
or tool.handler_module_path == "core.subagent_orchestrator"
|
||||
)
|
||||
150
astrbot/dashboard/services/t2i_service.py
Normal file
150
astrbot/dashboard/services/t2i_service.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.utils.t2i.template_manager import TemplateManager
|
||||
|
||||
|
||||
class T2iServiceError(Exception):
|
||||
def __init__(self, message: str, status_code: int = 500) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class T2iService:
|
||||
def __init__(
|
||||
self,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
manager: TemplateManager | None = None,
|
||||
) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.manager = manager or TemplateManager()
|
||||
|
||||
async def reload_all_pipeline_schedulers(self) -> None:
|
||||
for conf_id in self.core_lifecycle.astrbot_config_mgr.confs:
|
||||
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
|
||||
|
||||
async def sync_active_template_to_all_configs(self, name: str) -> None:
|
||||
for config in self.core_lifecycle.astrbot_config_mgr.confs.values():
|
||||
config["t2i_active_template"] = name
|
||||
config.save_config()
|
||||
await self.reload_all_pipeline_schedulers()
|
||||
|
||||
def list_templates(self):
|
||||
try:
|
||||
return self.manager.list_templates()
|
||||
except Exception as exc:
|
||||
raise T2iServiceError(str(exc)) from exc
|
||||
|
||||
def get_active_template(self) -> dict:
|
||||
try:
|
||||
return {"active_template": self.config.get("t2i_active_template", "base")}
|
||||
except Exception as exc:
|
||||
logger.error("Error in get_active_template", exc_info=True)
|
||||
raise T2iServiceError(str(exc)) from exc
|
||||
|
||||
def get_template(self, name: str) -> dict:
|
||||
try:
|
||||
return {"name": name, "content": self.manager.get_template(name)}
|
||||
except FileNotFoundError as exc:
|
||||
raise T2iServiceError("Template not found", 404) from exc
|
||||
except Exception as exc:
|
||||
raise T2iServiceError(str(exc)) from exc
|
||||
|
||||
def create_template(self, name: str | None, content: str | None) -> dict:
|
||||
if not name or not content:
|
||||
raise T2iServiceError("Name and content are required.", 400)
|
||||
|
||||
name = name.strip()
|
||||
try:
|
||||
self.manager.create_template(name, content)
|
||||
except FileExistsError as exc:
|
||||
raise T2iServiceError(
|
||||
"Template with this name already exists.",
|
||||
409,
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
raise T2iServiceError(str(exc), 400) from exc
|
||||
except Exception as exc:
|
||||
raise T2iServiceError(str(exc)) from exc
|
||||
|
||||
return {"name": name}
|
||||
|
||||
def create_template_from_legacy_payload(self, data: object) -> dict:
|
||||
payload = self._payload(data)
|
||||
return self.create_template(payload.get("name"), payload.get("content"))
|
||||
|
||||
async def update_template(self, name: str, content: str | None) -> tuple[dict, str]:
|
||||
name = name.strip()
|
||||
if content is None:
|
||||
raise T2iServiceError("Content is required.", 400)
|
||||
|
||||
try:
|
||||
self.manager.update_template(name, content)
|
||||
active_template = self.config.get("t2i_active_template", "base")
|
||||
if name == active_template:
|
||||
await self.reload_all_pipeline_schedulers()
|
||||
message = f"模板 '{name}' 已更新并重新加载。"
|
||||
else:
|
||||
message = f"模板 '{name}' 已更新。"
|
||||
except ValueError as exc:
|
||||
raise T2iServiceError(str(exc), 400) from exc
|
||||
except Exception as exc:
|
||||
raise T2iServiceError(str(exc)) from exc
|
||||
|
||||
return {"name": name}, message
|
||||
|
||||
async def update_template_from_legacy_payload(
|
||||
self,
|
||||
name: str,
|
||||
data: object,
|
||||
) -> tuple[dict, str]:
|
||||
payload = self._payload(data)
|
||||
return await self.update_template(name, payload.get("content"))
|
||||
|
||||
def delete_template(self, name: str) -> None:
|
||||
name = name.strip()
|
||||
try:
|
||||
self.manager.delete_template(name)
|
||||
except FileNotFoundError as exc:
|
||||
raise T2iServiceError("Template not found.", 404) from exc
|
||||
except ValueError as exc:
|
||||
raise T2iServiceError(str(exc), 400) from exc
|
||||
except Exception as exc:
|
||||
raise T2iServiceError(str(exc)) from exc
|
||||
|
||||
async def set_active_template(self, name: str | None) -> str:
|
||||
if not name:
|
||||
raise T2iServiceError("模板名称(name)不能为空。", 400)
|
||||
|
||||
try:
|
||||
self.manager.get_template(name)
|
||||
await self.sync_active_template_to_all_configs(name)
|
||||
except FileNotFoundError as exc:
|
||||
raise T2iServiceError(f"模板 '{name}' 不存在,无法应用。", 404) from exc
|
||||
except Exception as exc:
|
||||
logger.error("Error in set_active_template", exc_info=True)
|
||||
raise T2iServiceError(str(exc)) from exc
|
||||
|
||||
return f"模板 '{name}' 已成功应用。"
|
||||
|
||||
async def set_active_template_from_legacy_payload(self, data: object) -> str:
|
||||
payload = self._payload(data)
|
||||
return await self.set_active_template(payload.get("name"))
|
||||
|
||||
async def reset_default_template(self) -> str:
|
||||
try:
|
||||
self.manager.reset_default_template()
|
||||
await self.sync_active_template_to_all_configs("base")
|
||||
except FileNotFoundError as exc:
|
||||
raise T2iServiceError(str(exc), 404) from exc
|
||||
except Exception as exc:
|
||||
logger.error("Error in reset_default_template", exc_info=True)
|
||||
raise T2iServiceError(str(exc)) from exc
|
||||
|
||||
return "Default template has been reset and activated."
|
||||
|
||||
@staticmethod
|
||||
def _payload(data: object) -> dict:
|
||||
return data if isinstance(data, dict) else {}
|
||||
492
astrbot/dashboard/services/tools_service.py
Normal file
492
astrbot/dashboard/services/tools_service.py
Normal file
@@ -0,0 +1,492 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.mcp_client import MCPTool, validate_mcp_stdio_config
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.star import star_map
|
||||
from astrbot.core.tools.registry import get_builtin_tool_config_statuses
|
||||
|
||||
|
||||
class ToolsServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class EmptyMcpServersError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def extract_mcp_server_config(mcp_servers_value: object) -> dict:
|
||||
if not isinstance(mcp_servers_value, dict):
|
||||
raise ValueError("mcpServers must be a JSON object")
|
||||
if not mcp_servers_value:
|
||||
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
|
||||
key_0 = next(iter(mcp_servers_value))
|
||||
extracted = mcp_servers_value[key_0]
|
||||
if not isinstance(extracted, dict):
|
||||
raise ValueError(
|
||||
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
|
||||
"and each value is an object containing fields like command/url."
|
||||
)
|
||||
return extracted
|
||||
|
||||
|
||||
class ToolsService:
|
||||
def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.tool_mgr = core_lifecycle.provider_manager.llm_tools
|
||||
|
||||
def rollback_mcp_server(self, name: str) -> bool:
|
||||
try:
|
||||
rollback_config = self.tool_mgr.load_mcp_config()
|
||||
if name in rollback_config["mcpServers"]:
|
||||
rollback_config["mcpServers"].pop(name)
|
||||
return self.tool_mgr.save_mcp_config(rollback_config)
|
||||
return True
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def get_mcp_servers(self) -> list[dict]:
|
||||
try:
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
servers = []
|
||||
mcp_servers = config.get("mcpServers", {})
|
||||
|
||||
if not isinstance(mcp_servers, dict):
|
||||
logger.warning(
|
||||
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
|
||||
)
|
||||
mcp_servers = {}
|
||||
|
||||
for name, server_config in mcp_servers.items():
|
||||
if not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
|
||||
)
|
||||
continue
|
||||
|
||||
server_info = {
|
||||
"name": name,
|
||||
"active": server_config.get("active", True),
|
||||
}
|
||||
for key, value in server_config.items():
|
||||
if key != "active":
|
||||
server_info[key] = value
|
||||
|
||||
for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items():
|
||||
if name_key == name:
|
||||
mcp_client = runtime.client
|
||||
server_info["tools"] = [tool.name for tool in mcp_client.tools]
|
||||
server_info["errlogs"] = mcp_client.server_errlogs
|
||||
break
|
||||
else:
|
||||
server_info["tools"] = []
|
||||
|
||||
servers.append(server_info)
|
||||
|
||||
return servers
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"Failed to get MCP server list: {exc!s}") from exc
|
||||
|
||||
async def add_mcp_server(self, server_data: Any) -> str:
|
||||
try:
|
||||
name = server_data.get("name", "")
|
||||
if not name:
|
||||
raise ToolsServiceError("Server name cannot be empty")
|
||||
|
||||
has_valid_config, server_config = self._build_server_config(server_data)
|
||||
if not has_valid_config:
|
||||
raise ToolsServiceError("A valid server configuration is required")
|
||||
|
||||
self._validate_server_config(server_config)
|
||||
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
if name in config["mcpServers"]:
|
||||
raise ToolsServiceError(f"Server {name} already exists")
|
||||
|
||||
try:
|
||||
await self.tool_mgr.test_mcp_server_connection(server_config)
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"MCP connection test failed: {exc!s}") from exc
|
||||
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
await self._enable_added_server(name, server_config)
|
||||
return f"Successfully added MCP server {name}"
|
||||
raise ToolsServiceError("Failed to save configuration")
|
||||
except ToolsServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"Failed to add MCP server: {exc!s}") from exc
|
||||
|
||||
async def update_mcp_server(self, server_data: Any) -> str:
|
||||
try:
|
||||
name = server_data.get("name", "")
|
||||
old_name = server_data.get("oldName") or name
|
||||
|
||||
if not name:
|
||||
raise ToolsServiceError("Server name cannot be empty")
|
||||
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
|
||||
if old_name not in config["mcpServers"]:
|
||||
raise ToolsServiceError(f"Server {old_name} does not exist")
|
||||
|
||||
is_rename = name != old_name
|
||||
if name in config["mcpServers"] and is_rename:
|
||||
raise ToolsServiceError(f"Server {name} already exists")
|
||||
|
||||
old_config = config["mcpServers"][old_name]
|
||||
old_active = (
|
||||
old_config.get("active", True) if isinstance(old_config, dict) else True
|
||||
)
|
||||
active = server_data.get("active", old_active)
|
||||
|
||||
only_update_active, server_config = self._build_updated_server_config(
|
||||
server_data,
|
||||
old_config,
|
||||
active,
|
||||
)
|
||||
self._validate_server_config(server_config)
|
||||
|
||||
if is_rename:
|
||||
config["mcpServers"].pop(old_name)
|
||||
config["mcpServers"][name] = server_config
|
||||
else:
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
await self._sync_updated_server_runtime(
|
||||
name=name,
|
||||
old_name=old_name,
|
||||
active=active,
|
||||
is_rename=is_rename,
|
||||
only_update_active=only_update_active,
|
||||
server_config=config["mcpServers"][name],
|
||||
)
|
||||
return f"Successfully updated MCP server {name}"
|
||||
raise ToolsServiceError("Failed to save configuration")
|
||||
except ToolsServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"Failed to update MCP server: {exc!s}") from exc
|
||||
|
||||
async def delete_mcp_server(self, server_data: Any) -> str:
|
||||
try:
|
||||
name = server_data.get("name", "")
|
||||
|
||||
if not name:
|
||||
raise ToolsServiceError("Server name cannot be empty")
|
||||
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
|
||||
if name not in config["mcpServers"]:
|
||||
raise ToolsServiceError(f"Server {name} does not exist")
|
||||
|
||||
del config["mcpServers"][name]
|
||||
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
if name in self.tool_mgr.mcp_server_runtime_view:
|
||||
await self._disable_server(name)
|
||||
return f"Successfully deleted MCP server {name}"
|
||||
raise ToolsServiceError("Failed to save configuration")
|
||||
except ToolsServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"Failed to delete MCP server: {exc!s}") from exc
|
||||
|
||||
async def test_mcp_connection(self, server_data: Any) -> list:
|
||||
try:
|
||||
config = server_data.get("mcp_server_config", None)
|
||||
|
||||
if not isinstance(config, dict) or not config:
|
||||
raise ToolsServiceError("Invalid MCP server configuration")
|
||||
|
||||
if "mcpServers" in config:
|
||||
mcp_servers = config["mcpServers"]
|
||||
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
|
||||
raise ToolsServiceError(
|
||||
"Only one MCP server configuration can be tested at a time"
|
||||
)
|
||||
try:
|
||||
config = extract_mcp_server_config(mcp_servers)
|
||||
except EmptyMcpServersError as exc:
|
||||
raise ToolsServiceError(
|
||||
"MCP server configuration cannot be empty"
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
raise ToolsServiceError(f"{exc!s}") from exc
|
||||
elif not config:
|
||||
raise ToolsServiceError("MCP server configuration cannot be empty")
|
||||
|
||||
self._validate_server_config(config)
|
||||
return await self.tool_mgr.test_mcp_server_connection(config)
|
||||
except ToolsServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"Failed to test MCP connection: {exc!s}") from exc
|
||||
|
||||
def get_tool_list(self) -> list[dict]:
|
||||
try:
|
||||
tools = list(self.tool_mgr.func_list)
|
||||
existing_names = {tool.name for tool in tools}
|
||||
for tool in self.tool_mgr.iter_builtin_tools():
|
||||
if tool.name not in existing_names:
|
||||
tools.append(tool)
|
||||
|
||||
config_entries = self._get_config_entries()
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
tools_dict.append(self._serialize_tool(tool, config_entries))
|
||||
return tools_dict
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"Failed to get tool list: {exc!s}") from exc
|
||||
|
||||
def toggle_tool(self, data: Any) -> str:
|
||||
try:
|
||||
tool_name = data.get("name")
|
||||
action = data.get("activate")
|
||||
|
||||
if not tool_name or action is None:
|
||||
raise ToolsServiceError("Missing required parameters: name or activate")
|
||||
|
||||
if self.tool_mgr.is_builtin_tool(tool_name):
|
||||
raise ToolsServiceError(
|
||||
"Builtin tools are read-only and cannot be toggled."
|
||||
)
|
||||
|
||||
if action:
|
||||
try:
|
||||
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
|
||||
except ValueError as exc:
|
||||
raise ToolsServiceError(
|
||||
f"Failed to activate tool: {exc!s}"
|
||||
) from exc
|
||||
else:
|
||||
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
|
||||
|
||||
if ok:
|
||||
return "Operation successful."
|
||||
raise ToolsServiceError(
|
||||
f"Tool {tool_name} does not exist or the operation failed."
|
||||
)
|
||||
except ToolsServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"Failed to operate tool: {exc!s}") from exc
|
||||
|
||||
async def sync_provider(self, data: Any) -> str:
|
||||
try:
|
||||
provider_name = data.get("name")
|
||||
match provider_name:
|
||||
case "modelscope":
|
||||
access_token = data.get("access_token", "")
|
||||
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
|
||||
case _:
|
||||
raise ToolsServiceError(f"Unknown provider: {provider_name}")
|
||||
|
||||
return "Sync completed"
|
||||
except ToolsServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(f"Sync failed: {exc!s}") from exc
|
||||
|
||||
@staticmethod
|
||||
def _build_server_config(server_data: dict) -> tuple[bool, dict]:
|
||||
has_valid_config = False
|
||||
server_config = {"active": server_data.get("active", True)}
|
||||
|
||||
for key, value in server_data.items():
|
||||
if key in ["name", "active", "tools", "errlogs"]:
|
||||
continue
|
||||
if key == "mcpServers":
|
||||
try:
|
||||
server_config = extract_mcp_server_config(server_data["mcpServers"])
|
||||
except ValueError as exc:
|
||||
raise ToolsServiceError(f"{exc!s}") from exc
|
||||
else:
|
||||
server_config[key] = value
|
||||
has_valid_config = True
|
||||
|
||||
return has_valid_config, server_config
|
||||
|
||||
@staticmethod
|
||||
def _build_updated_server_config(
|
||||
server_data: dict,
|
||||
old_config: object,
|
||||
active: bool,
|
||||
) -> tuple[bool, dict]:
|
||||
server_config = {"active": active}
|
||||
only_update_active = True
|
||||
|
||||
for key, value in server_data.items():
|
||||
if key in ["name", "active", "tools", "errlogs", "oldName"]:
|
||||
continue
|
||||
if key == "mcpServers":
|
||||
try:
|
||||
server_config = extract_mcp_server_config(server_data["mcpServers"])
|
||||
except ValueError as exc:
|
||||
raise ToolsServiceError(f"{exc!s}") from exc
|
||||
else:
|
||||
server_config[key] = value
|
||||
only_update_active = False
|
||||
|
||||
if only_update_active and isinstance(old_config, dict):
|
||||
for key, value in old_config.items():
|
||||
if key != "active":
|
||||
server_config[key] = value
|
||||
|
||||
return only_update_active, server_config
|
||||
|
||||
@staticmethod
|
||||
def _validate_server_config(server_config: dict) -> None:
|
||||
try:
|
||||
validate_mcp_stdio_config(server_config)
|
||||
except ValueError as exc:
|
||||
raise ToolsServiceError(f"{exc!s}") from exc
|
||||
|
||||
async def _enable_added_server(self, name: str, server_config: dict) -> None:
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(name, server_config, timeout=30)
|
||||
except TimeoutError as exc:
|
||||
rollback_ok = self.rollback_mcp_server(name)
|
||||
err_msg = f"Timed out while enabling MCP server {name}."
|
||||
if not rollback_ok:
|
||||
err_msg += (
|
||||
" Configuration rollback failed. Please check the config manually."
|
||||
)
|
||||
raise ToolsServiceError(err_msg) from exc
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
rollback_ok = self.rollback_mcp_server(name)
|
||||
err_msg = f"Failed to enable MCP server {name}: {exc!s}"
|
||||
if not rollback_ok:
|
||||
err_msg += (
|
||||
" Configuration rollback failed. Please check the config manually."
|
||||
)
|
||||
raise ToolsServiceError(err_msg) from exc
|
||||
|
||||
async def _sync_updated_server_runtime(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
old_name: str,
|
||||
active: bool,
|
||||
is_rename: bool,
|
||||
only_update_active: bool,
|
||||
server_config: dict,
|
||||
) -> None:
|
||||
if active:
|
||||
if (
|
||||
old_name in self.tool_mgr.mcp_server_runtime_view
|
||||
or not only_update_active
|
||||
or is_rename
|
||||
):
|
||||
await self._disable_server_before_enable(old_name)
|
||||
await self._enable_updated_server(name, server_config)
|
||||
elif old_name in self.tool_mgr.mcp_server_runtime_view:
|
||||
await self._disable_server(old_name)
|
||||
|
||||
async def _disable_server_before_enable(self, old_name: str) -> None:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(old_name, timeout=10)
|
||||
except TimeoutError as exc:
|
||||
raise ToolsServiceError(
|
||||
f"Timed out while disabling MCP server {old_name} before enabling: {exc!s}"
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(
|
||||
f"Failed to disable MCP server {old_name} before enabling: {exc!s}"
|
||||
) from exc
|
||||
|
||||
async def _enable_updated_server(self, name: str, server_config: dict) -> None:
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(name, server_config, timeout=30)
|
||||
except TimeoutError as exc:
|
||||
raise ToolsServiceError(
|
||||
f"Timed out while enabling MCP server {name}."
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(
|
||||
f"Failed to enable MCP server {name}: {exc!s}"
|
||||
) from exc
|
||||
|
||||
async def _disable_server(self, name: str) -> None:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
except TimeoutError as exc:
|
||||
raise ToolsServiceError(
|
||||
f"Timed out while disabling MCP server {name}."
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error(traceback.format_exc())
|
||||
raise ToolsServiceError(
|
||||
f"Failed to disable MCP server {name}: {exc!s}"
|
||||
) from exc
|
||||
|
||||
def _get_config_entries(self) -> list[dict]:
|
||||
conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list()
|
||||
conf_name_map = {conf["id"]: conf["name"] for conf in conf_list}
|
||||
config_entries = []
|
||||
for conf_id, conf in self.core_lifecycle.astrbot_config_mgr.confs.items():
|
||||
config_entries.append(
|
||||
{
|
||||
"conf_id": conf_id,
|
||||
"conf_name": conf_name_map.get(conf_id, conf_id),
|
||||
"config": conf,
|
||||
}
|
||||
)
|
||||
return config_entries
|
||||
|
||||
def _serialize_tool(self, tool, config_entries: list[dict]) -> dict:
|
||||
readonly = False
|
||||
builtin_config_statuses = []
|
||||
builtin_config_tags = []
|
||||
if self.tool_mgr.is_builtin_tool(tool.name):
|
||||
origin = "builtin"
|
||||
origin_name = "AstrBot Core"
|
||||
readonly = True
|
||||
builtin_config_statuses = get_builtin_tool_config_statuses(
|
||||
tool.name,
|
||||
config_entries,
|
||||
)
|
||||
builtin_config_tags = [
|
||||
status for status in builtin_config_statuses if status["enabled"]
|
||||
]
|
||||
elif isinstance(tool, MCPTool):
|
||||
origin = "mcp"
|
||||
origin_name = tool.mcp_server_name
|
||||
elif tool.handler_module_path and star_map.get(tool.handler_module_path):
|
||||
star = star_map[tool.handler_module_path]
|
||||
origin = "plugin"
|
||||
origin_name = star.name
|
||||
else:
|
||||
origin = "unknown"
|
||||
origin_name = "unknown"
|
||||
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
"active": tool.active,
|
||||
"origin": origin,
|
||||
"origin_name": origin_name,
|
||||
"readonly": readonly,
|
||||
"builtin_config_statuses": builtin_config_statuses,
|
||||
"builtin_config_tags": builtin_config_tags,
|
||||
}
|
||||
412
astrbot/dashboard/services/update_service.py
Normal file
412
astrbot/dashboard/services/update_service.py
Normal file
@@ -0,0 +1,412 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core import (
|
||||
DEMO_MODE as _DEMO_MODE,
|
||||
)
|
||||
from astrbot.core import (
|
||||
logger,
|
||||
)
|
||||
from astrbot.core import (
|
||||
pip_installer as _pip_installer,
|
||||
)
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db.migration.helper import (
|
||||
check_migration_needed_v4 as _check_migration_needed_v4,
|
||||
)
|
||||
from astrbot.core.db.migration.helper import (
|
||||
do_migration_v4 as _do_migration_v4,
|
||||
)
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.io import (
|
||||
download_dashboard as _download_dashboard,
|
||||
)
|
||||
from astrbot.core.utils.io import (
|
||||
get_dashboard_version as _get_dashboard_version,
|
||||
)
|
||||
|
||||
DEMO_MODE = _DEMO_MODE
|
||||
pip_installer = _pip_installer
|
||||
download_dashboard = _download_dashboard
|
||||
get_dashboard_version = _get_dashboard_version
|
||||
default_check_migration_needed_v4 = _check_migration_needed_v4
|
||||
default_do_migration_v4 = _do_migration_v4
|
||||
|
||||
|
||||
async def call_download_dashboard(*args, **kwargs):
|
||||
return await download_dashboard(*args, **kwargs)
|
||||
|
||||
|
||||
async def call_get_dashboard_version(*args, **kwargs):
|
||||
return await get_dashboard_version(*args, **kwargs)
|
||||
|
||||
|
||||
async def call_pip_install(*args, **kwargs):
|
||||
return await pip_installer.install(*args, **kwargs)
|
||||
|
||||
|
||||
async def call_check_migration_needed_v4(*args, **kwargs):
|
||||
return await default_check_migration_needed_v4(*args, **kwargs)
|
||||
|
||||
|
||||
async def call_do_migration_v4(*args, **kwargs):
|
||||
return await default_do_migration_v4(*args, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateServiceResult:
|
||||
data: Any = None
|
||||
message: str | None = None
|
||||
status: str = "ok"
|
||||
headers: dict | None = None
|
||||
|
||||
|
||||
class UpdateServiceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UpdateService:
|
||||
def __init__(
|
||||
self,
|
||||
astrbot_updator: AstrBotUpdator,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
*,
|
||||
download_dashboard_func: Callable[..., Awaitable[Any]],
|
||||
get_dashboard_version_func: Callable[..., Awaitable[str]],
|
||||
pip_install_func: Callable[..., Awaitable[Any]],
|
||||
check_migration_needed_func: Callable[..., Awaitable[bool]],
|
||||
do_migration_func: Callable[..., Awaitable[Any]],
|
||||
demo_mode: bool,
|
||||
clear_site_data_headers: dict,
|
||||
) -> None:
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.download_dashboard = download_dashboard_func
|
||||
self.get_dashboard_version = get_dashboard_version_func
|
||||
self.pip_install = pip_install_func
|
||||
self.check_migration_needed = check_migration_needed_func
|
||||
self.do_migration = do_migration_func
|
||||
self.demo_mode = demo_mode
|
||||
self.clear_site_data_headers = clear_site_data_headers
|
||||
self.update_progress: dict[str, dict] = {}
|
||||
|
||||
def get_update_progress(self, progress_id: str) -> UpdateServiceResult:
|
||||
if not progress_id:
|
||||
raise UpdateServiceError("缺少参数 id。")
|
||||
progress = self.update_progress.get(progress_id)
|
||||
if not progress:
|
||||
return UpdateServiceResult(
|
||||
data={"id": progress_id, "status": "idle"},
|
||||
message="没有正在进行的更新。",
|
||||
)
|
||||
return UpdateServiceResult(data=progress)
|
||||
|
||||
def get_update_progress_from_legacy_query(
|
||||
self,
|
||||
progress_id: str | None,
|
||||
) -> UpdateServiceResult:
|
||||
return self.get_update_progress(progress_id or "")
|
||||
|
||||
async def do_migration_v4(self, data: object) -> UpdateServiceResult:
|
||||
need_migration = await self.check_migration_needed(self.core_lifecycle.db)
|
||||
if not need_migration:
|
||||
return UpdateServiceResult(message="不需要进行迁移。")
|
||||
try:
|
||||
payload = data if isinstance(data, dict) else {}
|
||||
platform_id_map = payload.get("platform_id_map", {})
|
||||
await self.do_migration(
|
||||
self.core_lifecycle.db,
|
||||
platform_id_map,
|
||||
self.core_lifecycle.astrbot_config,
|
||||
)
|
||||
return UpdateServiceResult(message="迁移成功。")
|
||||
except Exception as exc:
|
||||
logger.error(f"迁移失败: {traceback.format_exc()}")
|
||||
raise UpdateServiceError(f"迁移失败: {exc!s}") from exc
|
||||
|
||||
async def check_update(self, update_type: str | None) -> UpdateServiceResult:
|
||||
try:
|
||||
dashboard_version = await self.get_dashboard_version()
|
||||
if update_type == "dashboard":
|
||||
return UpdateServiceResult(
|
||||
data={
|
||||
"has_new_version": dashboard_version != f"v{VERSION}",
|
||||
"current_version": dashboard_version,
|
||||
}
|
||||
)
|
||||
update_result = await self.astrbot_updator.check_update(None, None, False)
|
||||
return UpdateServiceResult(
|
||||
status="success",
|
||||
message=str(update_result)
|
||||
if update_result is not None
|
||||
else "已经是最新版本了。",
|
||||
data={
|
||||
"version": f"v{VERSION}",
|
||||
"has_new_version": update_result is not None,
|
||||
"dashboard_version": dashboard_version,
|
||||
"dashboard_has_new_version": bool(
|
||||
dashboard_version and dashboard_version != f"v{VERSION}"
|
||||
),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"检查更新失败: {exc!s} (不影响除项目更新外的正常使用)")
|
||||
raise UpdateServiceError(exc.__str__()) from exc
|
||||
|
||||
async def check_update_from_legacy_query(
|
||||
self,
|
||||
update_type: str | None,
|
||||
) -> UpdateServiceResult:
|
||||
return await self.check_update(update_type)
|
||||
|
||||
async def get_releases(self) -> UpdateServiceResult:
|
||||
try:
|
||||
releases = await self.astrbot_updator.get_releases()
|
||||
return UpdateServiceResult(data=releases)
|
||||
except Exception as exc:
|
||||
logger.error(f"/api/update/releases: {traceback.format_exc()}")
|
||||
raise UpdateServiceError(exc.__str__()) from exc
|
||||
|
||||
async def update_project(self, data: object) -> UpdateServiceResult:
|
||||
payload = data if isinstance(data, dict) else {}
|
||||
version = payload.get("version", "")
|
||||
reboot = payload.get("reboot", True)
|
||||
progress_id = payload.get("progress_id") or uuid.uuid4().hex
|
||||
if version == "" or version == "latest":
|
||||
latest = True
|
||||
version = ""
|
||||
else:
|
||||
latest = False
|
||||
|
||||
proxy: str = payload.get("proxy", None)
|
||||
if proxy:
|
||||
proxy = proxy.removesuffix("/")
|
||||
|
||||
self._init_update_progress(progress_id, version)
|
||||
try:
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"dashboard",
|
||||
"running",
|
||||
"正在下载 WebUI...",
|
||||
0,
|
||||
)
|
||||
await self.download_dashboard(
|
||||
latest=latest,
|
||||
version=version,
|
||||
proxy=proxy,
|
||||
progress_callback=self._make_progress_callback(
|
||||
progress_id,
|
||||
"dashboard",
|
||||
0,
|
||||
45,
|
||||
),
|
||||
)
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"dashboard",
|
||||
"done",
|
||||
"WebUI 下载完成。",
|
||||
45,
|
||||
)
|
||||
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"core",
|
||||
"running",
|
||||
"正在下载 AstrBot 项目代码...",
|
||||
45,
|
||||
)
|
||||
await self.astrbot_updator.update(
|
||||
latest=latest,
|
||||
version=version,
|
||||
proxy=proxy,
|
||||
progress_callback=self._make_progress_callback(
|
||||
progress_id,
|
||||
"core",
|
||||
45,
|
||||
45,
|
||||
),
|
||||
)
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"core",
|
||||
"done",
|
||||
"项目代码下载完成。",
|
||||
90,
|
||||
)
|
||||
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"dependencies",
|
||||
"running",
|
||||
"正在更新依赖...",
|
||||
92,
|
||||
)
|
||||
logger.info("更新依赖中...")
|
||||
try:
|
||||
await self.pip_install(requirements_path="requirements.txt")
|
||||
except Exception as exc:
|
||||
logger.error(f"更新依赖失败: {exc}")
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"dependencies",
|
||||
"done",
|
||||
"依赖更新完成。",
|
||||
96,
|
||||
)
|
||||
|
||||
if reboot:
|
||||
self._set_update_stage(
|
||||
progress_id,
|
||||
"restart",
|
||||
"running",
|
||||
"更新成功,正在准备重启...",
|
||||
98,
|
||||
)
|
||||
await self.core_lifecycle.restart()
|
||||
message = "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。"
|
||||
else:
|
||||
message = "更新成功,AstrBot 将在下次启动时应用新的代码。"
|
||||
|
||||
self.update_progress[progress_id].update(
|
||||
{
|
||||
"status": "success",
|
||||
"stage": "done",
|
||||
"message": message,
|
||||
"overall_percent": 100,
|
||||
},
|
||||
)
|
||||
return UpdateServiceResult(
|
||||
message=message,
|
||||
headers=self.clear_site_data_headers,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.update_progress[progress_id].update(
|
||||
{
|
||||
"status": "error",
|
||||
"message": exc.__str__(),
|
||||
},
|
||||
)
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
raise UpdateServiceError(exc.__str__()) from exc
|
||||
|
||||
async def update_dashboard(self) -> UpdateServiceResult:
|
||||
try:
|
||||
try:
|
||||
await self.download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
except Exception as exc:
|
||||
logger.error(f"下载管理面板文件失败: {exc}。")
|
||||
raise UpdateServiceError(f"下载管理面板文件失败: {exc}") from exc
|
||||
return UpdateServiceResult(
|
||||
message="更新成功。刷新页面即可应用新版本面板。",
|
||||
headers=self.clear_site_data_headers,
|
||||
)
|
||||
except UpdateServiceError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
|
||||
raise UpdateServiceError(exc.__str__()) from exc
|
||||
|
||||
async def install_pip_package(self, data: object) -> UpdateServiceResult:
|
||||
if self.demo_mode:
|
||||
raise UpdateServiceError(
|
||||
"You are not permitted to do this operation in demo mode"
|
||||
)
|
||||
|
||||
payload = data if isinstance(data, dict) else {}
|
||||
package = payload.get("package", "")
|
||||
mirror = payload.get("mirror", None)
|
||||
if not package:
|
||||
raise UpdateServiceError("缺少参数 package 或不合法。")
|
||||
try:
|
||||
await self.pip_install(package, mirror=mirror)
|
||||
return UpdateServiceResult(message="安装成功。")
|
||||
except Exception as exc:
|
||||
logger.error(f"/api/update_pip: {traceback.format_exc()}")
|
||||
raise UpdateServiceError(exc.__str__()) from exc
|
||||
|
||||
def _init_update_progress(self, progress_id: str, version: str) -> None:
|
||||
self.update_progress[progress_id] = {
|
||||
"id": progress_id,
|
||||
"status": "running",
|
||||
"stage": "preparing",
|
||||
"version": version or "latest",
|
||||
"message": "正在准备更新...",
|
||||
"overall_percent": 0,
|
||||
"stages": {
|
||||
"dashboard": self._empty_stage("pending"),
|
||||
"core": self._empty_stage("pending"),
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _empty_stage(status: str = "pending") -> dict:
|
||||
return {
|
||||
"status": status,
|
||||
"downloaded": 0,
|
||||
"total": 0,
|
||||
"percent": 0,
|
||||
"speed": 0,
|
||||
}
|
||||
|
||||
def _set_update_stage(
|
||||
self,
|
||||
progress_id: str,
|
||||
stage: str,
|
||||
status: str,
|
||||
message: str,
|
||||
overall_percent: int | None = None,
|
||||
) -> None:
|
||||
progress = self.update_progress.get(progress_id)
|
||||
if not progress:
|
||||
return
|
||||
progress["stage"] = stage
|
||||
progress["message"] = message
|
||||
progress["stages"].setdefault(stage, self._empty_stage())
|
||||
progress["stages"][stage]["status"] = status
|
||||
if overall_percent is not None:
|
||||
progress["overall_percent"] = overall_percent
|
||||
|
||||
@staticmethod
|
||||
def _normalize_percent(value) -> int:
|
||||
try:
|
||||
percent = float(value or 0)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
if percent <= 1:
|
||||
percent *= 100
|
||||
return max(0, min(100, int(percent)))
|
||||
|
||||
def _make_progress_callback(
|
||||
self,
|
||||
progress_id: str,
|
||||
stage: str,
|
||||
stage_start: int,
|
||||
stage_weight: int,
|
||||
):
|
||||
def _callback(payload: dict) -> None:
|
||||
progress = self.update_progress.get(progress_id)
|
||||
if not progress:
|
||||
return
|
||||
stage_percent = self._normalize_percent(payload.get("percent"))
|
||||
progress["stage"] = stage
|
||||
progress["stages"][stage] = {
|
||||
"status": "running" if stage_percent < 100 else "done",
|
||||
"downloaded": payload.get("downloaded", 0),
|
||||
"total": payload.get("total", 0),
|
||||
"percent": stage_percent,
|
||||
"speed": payload.get("speed", 0),
|
||||
}
|
||||
progress["overall_percent"] = min(
|
||||
99,
|
||||
stage_start + int(stage_percent * stage_weight / 100),
|
||||
)
|
||||
|
||||
return _callback
|
||||
1
astrbot/dashboard/v1/__init__.py
Normal file
1
astrbot/dashboard/v1/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""FastAPI based OpenAPI v1 surface for AstrBot dashboard."""
|
||||
74
astrbot/dashboard/v1/app.py
Normal file
74
astrbot/dashboard/v1/app.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.dashboard.services.config_service import (
|
||||
BotConfigService,
|
||||
ConfigProfileService,
|
||||
ConfigRoutingService,
|
||||
ProviderConfigService,
|
||||
)
|
||||
from astrbot.dashboard.services.open_api_service import OpenApiService
|
||||
from astrbot.dashboard.services.plugin_service import PluginService
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
from astrbot.dashboard.services.session_management_service import (
|
||||
SessionManagementService,
|
||||
)
|
||||
|
||||
from .compat_aliases import router as compat_alias_router
|
||||
from .responses import ApiError, error
|
||||
from .routers import build_v1_router
|
||||
|
||||
|
||||
def create_v1_asgi_app(
|
||||
*,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
db: BaseDatabase,
|
||||
jwt_secret: str,
|
||||
) -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="AstrBot OpenAPI",
|
||||
version="1.0.0",
|
||||
openapi_url="/api/v1/openapi.json",
|
||||
docs_url="/api/v1/docs",
|
||||
redoc_url="/api/v1/redoc",
|
||||
)
|
||||
app.state.core_lifecycle = core_lifecycle
|
||||
app.state.db = db
|
||||
app.state.jwt_secret = jwt_secret
|
||||
app.state.services = SimpleNamespace(
|
||||
config_profiles=ConfigProfileService(core_lifecycle, db),
|
||||
config_routes=ConfigRoutingService(core_lifecycle),
|
||||
bots=BotConfigService(core_lifecycle),
|
||||
providers=ProviderConfigService(core_lifecycle),
|
||||
plugins=PluginService(core_lifecycle, core_lifecycle.plugin_manager),
|
||||
open_api=OpenApiService(db, core_lifecycle),
|
||||
sessions=SessionManagementService(core_lifecycle, db),
|
||||
route_bridge=None,
|
||||
)
|
||||
app.state.services.route_bridge = DashboardRouteBridgeService(app, jwt_secret)
|
||||
|
||||
@app.exception_handler(ApiError)
|
||||
async def api_error_handler(_request: Request, exc: ApiError):
|
||||
return JSONResponse(
|
||||
error(exc.message, exc.data),
|
||||
status_code=exc.status_code,
|
||||
)
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def value_error_handler(_request: Request, exc: ValueError):
|
||||
return JSONResponse(error(str(exc)), status_code=400)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_error_handler(_request: Request, exc: HTTPException):
|
||||
detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
|
||||
return JSONResponse(error(detail), status_code=exc.status_code)
|
||||
|
||||
app.include_router(compat_alias_router)
|
||||
app.include_router(build_v1_router())
|
||||
return app
|
||||
86
astrbot/dashboard/v1/auth.py
Normal file
86
astrbot/dashboard/v1/auth.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
|
||||
from astrbot.dashboard.services.api_key_service import ApiKeyService
|
||||
from astrbot.dashboard.services.auth_service import ALL_OPEN_API_SCOPES
|
||||
|
||||
from .responses import ApiError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthContext:
|
||||
username: str
|
||||
scopes: list[str]
|
||||
api_key_id: str | None = None
|
||||
via: str = "jwt"
|
||||
|
||||
|
||||
def _extract_raw_api_key(request: Request) -> str | None:
|
||||
if key := request.query_params.get("api_key"):
|
||||
return key.strip()
|
||||
if key := request.query_params.get("key"):
|
||||
return key.strip()
|
||||
if key := request.headers.get("X-API-Key"):
|
||||
return key.strip()
|
||||
auth_header = request.headers.get("Authorization", "").strip()
|
||||
if auth_header.startswith("ApiKey "):
|
||||
return auth_header.removeprefix("ApiKey ").strip()
|
||||
return None
|
||||
|
||||
|
||||
async def _require_api_key_scope(
|
||||
request: Request,
|
||||
raw_key: str,
|
||||
scope: str,
|
||||
) -> AuthContext:
|
||||
key_hash = ApiKeyService.hash_key(raw_key)
|
||||
api_key = await request.app.state.db.get_active_api_key_by_hash(key_hash)
|
||||
if not api_key:
|
||||
raise ApiError("Invalid API key", status_code=401)
|
||||
scopes = (
|
||||
api_key.scopes
|
||||
if isinstance(api_key.scopes, list)
|
||||
else list(ALL_OPEN_API_SCOPES)
|
||||
)
|
||||
if "*" not in scopes and scope not in scopes:
|
||||
raise ApiError("Insufficient API key scope", status_code=403)
|
||||
await request.app.state.db.touch_api_key(api_key.key_id)
|
||||
return AuthContext(
|
||||
username=f"api_key:{api_key.key_id}",
|
||||
scopes=scopes,
|
||||
api_key_id=api_key.key_id,
|
||||
via="api_key",
|
||||
)
|
||||
|
||||
|
||||
async def require_scope(request: Request, scope: str) -> AuthContext:
|
||||
raw_key = _extract_raw_api_key(request)
|
||||
if raw_key:
|
||||
return await _require_api_key_scope(request, raw_key, scope)
|
||||
|
||||
auth_header = request.headers.get("Authorization", "").strip()
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise ApiError("Missing API key", status_code=401)
|
||||
token = auth_header.removeprefix("Bearer ").strip()
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
request.app.state.jwt_secret,
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
raise ApiError("Token expired", status_code=401) from exc
|
||||
except jwt.InvalidTokenError as exc:
|
||||
try:
|
||||
return await _require_api_key_scope(request, token, scope)
|
||||
except ApiError as api_key_exc:
|
||||
raise api_key_exc from exc
|
||||
|
||||
username = payload.get("username")
|
||||
if not isinstance(username, str) or not username.strip():
|
||||
raise ApiError("Invalid token", status_code=401)
|
||||
return AuthContext(username=username, scopes=["*"], via="jwt")
|
||||
424
astrbot/dashboard/v1/compat_aliases.py
Normal file
424
astrbot/dashboard/v1/compat_aliases.py
Normal file
@@ -0,0 +1,424 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from astrbot.dashboard.services.config_service import (
|
||||
BotConfigService,
|
||||
ConfigProfileService,
|
||||
ConfigRoutingService,
|
||||
ProviderConfigService,
|
||||
)
|
||||
|
||||
from .auth import AuthContext, require_scope
|
||||
from .responses import error, ok
|
||||
from .schemas import BotConfigRequest, ProviderConfigRequest, ProviderSourceRequest
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api",
|
||||
tags=["Compatibility Aliases"],
|
||||
include_in_schema=False,
|
||||
)
|
||||
|
||||
|
||||
async def require_config_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "config")
|
||||
|
||||
|
||||
def get_config_profile_service(request: Request) -> ConfigProfileService:
|
||||
return request.app.state.services.config_profiles
|
||||
|
||||
|
||||
def get_config_routing_service(request: Request) -> ConfigRoutingService:
|
||||
return request.app.state.services.config_routes
|
||||
|
||||
|
||||
def get_bot_service(request: Request) -> BotConfigService:
|
||||
return request.app.state.services.bots
|
||||
|
||||
|
||||
def get_provider_service(request: Request) -> ProviderConfigService:
|
||||
return request.app.state.services.providers
|
||||
|
||||
|
||||
async def _json_or_empty(request: Request) -> dict:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def _compat_error(message: str):
|
||||
return error(message)
|
||||
|
||||
|
||||
@router.get("/config/default")
|
||||
async def get_legacy_default_config(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_config_profile_service),
|
||||
):
|
||||
return ok(service.get_profile_schema())
|
||||
|
||||
|
||||
@router.get("/config/abconfs")
|
||||
async def list_legacy_config_profiles(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_config_profile_service),
|
||||
):
|
||||
return ok(service.list_profiles())
|
||||
|
||||
|
||||
@router.post("/config/abconf/new")
|
||||
async def create_legacy_config_profile(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_config_profile_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
try:
|
||||
return ok(
|
||||
await service.create_profile(
|
||||
body.get("name"),
|
||||
body.get("config"),
|
||||
),
|
||||
"创建成功",
|
||||
)
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.get("/config/abconf")
|
||||
async def get_legacy_config_profile(
|
||||
id: str | None = Query(default=None),
|
||||
system_config: str = Query(default="0"),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_config_profile_service),
|
||||
):
|
||||
if system_config.lower() == "1":
|
||||
return ok(service.get_system_schema())
|
||||
if not id:
|
||||
return _compat_error("缺少配置文件 ID")
|
||||
try:
|
||||
return ok(service.get_profile(id))
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/abconf/delete")
|
||||
async def delete_legacy_config_profile(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_config_profile_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
config_id = body.get("id")
|
||||
if not config_id:
|
||||
return _compat_error("缺少配置文件 ID")
|
||||
try:
|
||||
service.delete_profile(str(config_id))
|
||||
return ok(message="删除成功")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/abconf/update")
|
||||
async def rename_legacy_config_profile(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_config_profile_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
config_id = body.get("id")
|
||||
if not config_id:
|
||||
return _compat_error("缺少配置文件 ID")
|
||||
try:
|
||||
service.rename_profile(str(config_id), body.get("name"))
|
||||
return ok(message="更新成功")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/astrbot/update")
|
||||
async def update_legacy_astrbot_config(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_config_profile_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
config = body.get("config")
|
||||
config_id = body.get("conf_id")
|
||||
if not isinstance(config, dict):
|
||||
return _compat_error("Invalid config payload")
|
||||
if not config_id:
|
||||
return _compat_error("Config file None does not exist")
|
||||
try:
|
||||
message = await service.update_profile(
|
||||
str(config_id),
|
||||
config,
|
||||
two_factor_code=request.headers.get("X-2FA-Code"),
|
||||
)
|
||||
return ok(message=message or "保存成功~")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.get("/config/umo_abconf_routes")
|
||||
async def get_legacy_config_routes(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigRoutingService = Depends(get_config_routing_service),
|
||||
):
|
||||
return ok(service.list_routes())
|
||||
|
||||
|
||||
@router.post("/config/umo_abconf_route/update_all")
|
||||
async def update_legacy_config_routes(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigRoutingService = Depends(get_config_routing_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
try:
|
||||
await service.replace_routes(body)
|
||||
except ValueError:
|
||||
return _compat_error("缺少或错误的路由表数据")
|
||||
return ok(message="更新成功")
|
||||
|
||||
|
||||
@router.post("/config/umo_abconf_route/update")
|
||||
async def upsert_legacy_config_route(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigRoutingService = Depends(get_config_routing_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
try:
|
||||
await service.upsert_route(body)
|
||||
except ValueError:
|
||||
return _compat_error("缺少 UMO 或配置文件 ID")
|
||||
return ok(message="更新成功")
|
||||
|
||||
|
||||
@router.post("/config/umo_abconf_route/delete")
|
||||
async def delete_legacy_config_route(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigRoutingService = Depends(get_config_routing_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
try:
|
||||
await service.delete_route(body)
|
||||
except ValueError:
|
||||
return _compat_error("缺少 UMO")
|
||||
return ok(message="删除成功")
|
||||
|
||||
|
||||
@router.get("/config/platform/list")
|
||||
async def list_legacy_platforms(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_bot_service),
|
||||
):
|
||||
return ok({"platforms": service.list_bots()["bots"]})
|
||||
|
||||
|
||||
@router.post("/config/platform/new")
|
||||
async def create_legacy_platform(
|
||||
payload: BotConfigRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_bot_service),
|
||||
):
|
||||
try:
|
||||
await service.create_bot(payload.to_legacy_config())
|
||||
return ok(message="新增平台配置成功~")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/platform/update")
|
||||
async def update_legacy_platform(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_bot_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
bot_id = body.get("id")
|
||||
config = body.get("config")
|
||||
if not bot_id or not isinstance(config, dict):
|
||||
return _compat_error("参数错误")
|
||||
try:
|
||||
await service.update_bot(
|
||||
str(bot_id),
|
||||
BotConfigRequest(config=config).to_legacy_config(fallback_id=str(bot_id)),
|
||||
)
|
||||
return ok(message="更新平台配置成功~")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/platform/delete")
|
||||
async def delete_legacy_platform(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_bot_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
bot_id = body.get("id")
|
||||
if not bot_id:
|
||||
return _compat_error("缺少参数 id")
|
||||
try:
|
||||
await service.delete_bot(str(bot_id))
|
||||
return ok(message="删除平台配置成功~")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.get("/config/provider/template")
|
||||
async def get_legacy_provider_template(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
return ok(service.get_provider_schema())
|
||||
|
||||
|
||||
@router.get("/config/provider/list")
|
||||
async def list_legacy_providers(
|
||||
provider_type: str | None = Query(default=None),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
if not provider_type:
|
||||
return _compat_error("缺少参数 provider_type")
|
||||
providers = []
|
||||
seen_ids = set()
|
||||
for item in provider_type.split(","):
|
||||
for provider in service.list_providers(capability=item)["providers"]:
|
||||
provider_id = provider.get("id")
|
||||
if provider_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(provider_id)
|
||||
providers.append(provider)
|
||||
return ok(providers)
|
||||
|
||||
|
||||
@router.post("/config/provider/new")
|
||||
async def create_legacy_provider(
|
||||
payload: ProviderConfigRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
try:
|
||||
await service.create_provider(payload.to_legacy_config())
|
||||
return ok(message="新增服务提供商配置成功")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/provider/update")
|
||||
async def update_legacy_provider(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
provider_id = body.get("id")
|
||||
config = body.get("config")
|
||||
if not provider_id or not isinstance(config, dict):
|
||||
return _compat_error("参数错误")
|
||||
try:
|
||||
await service.update_provider(
|
||||
str(provider_id),
|
||||
ProviderConfigRequest(config=config).to_legacy_config(
|
||||
fallback_id=str(provider_id),
|
||||
),
|
||||
)
|
||||
return ok(message="更新成功,已经实时生效~")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/provider/delete")
|
||||
async def delete_legacy_provider(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
provider_id = body.get("id")
|
||||
if not provider_id:
|
||||
return _compat_error("缺少参数 id")
|
||||
try:
|
||||
await service.delete_provider(str(provider_id))
|
||||
return ok(message="删除成功,已经实时生效。")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.get("/config/provider/check_one")
|
||||
async def check_legacy_provider(
|
||||
id: str | None = Query(default=None),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
if not id:
|
||||
return _compat_error("Missing provider_id parameter")
|
||||
try:
|
||||
return ok(await service.test_provider(id))
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.get("/config/provider_sources/models")
|
||||
async def list_legacy_provider_source_models(
|
||||
source_id: str | None = Query(default=None),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
if not source_id:
|
||||
return _compat_error("缺少参数 source_id")
|
||||
try:
|
||||
data = await service.list_provider_source_models(source_id)
|
||||
data.pop("provider_source_id", None)
|
||||
return ok(data)
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/provider_sources/update")
|
||||
async def update_legacy_provider_source(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
source_id = body.get("original_id")
|
||||
config = body.get("config") or body
|
||||
if not source_id:
|
||||
return _compat_error("缺少 original_id")
|
||||
if not isinstance(config, dict):
|
||||
return _compat_error("缺少或错误的配置数据")
|
||||
try:
|
||||
await service.upsert_provider_source(
|
||||
str(source_id),
|
||||
ProviderSourceRequest(config=config).to_legacy_config(
|
||||
fallback_id=str(source_id),
|
||||
),
|
||||
)
|
||||
return ok(message="更新 provider source 成功")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
|
||||
|
||||
@router.post("/config/provider_sources/delete")
|
||||
async def delete_legacy_provider_source(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_provider_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
source_id = body.get("id")
|
||||
if not source_id:
|
||||
return _compat_error("缺少 provider_source_id")
|
||||
try:
|
||||
await service.delete_provider_source(str(source_id))
|
||||
return ok(message="删除 provider source 成功")
|
||||
except ValueError as exc:
|
||||
return _compat_error(str(exc))
|
||||
22
astrbot/dashboard/v1/responses.py
Normal file
22
astrbot/dashboard/v1/responses.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiError(Exception):
|
||||
message: str
|
||||
status_code: int = 400
|
||||
data: Any = None
|
||||
|
||||
|
||||
def ok(data: Any = None, message: str | None = None) -> dict[str, Any]:
|
||||
return {"status": "ok", "message": message, "data": {} if data is None else data}
|
||||
|
||||
|
||||
def error(message: str, data: Any = None) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"status": "error", "message": message}
|
||||
if data is not None:
|
||||
payload["data"] = data
|
||||
return payload
|
||||
22
astrbot/dashboard/v1/routers/__init__.py
Normal file
22
astrbot/dashboard/v1/routers/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .bots import router as bots_router
|
||||
from .compat import compat_routers
|
||||
from .config_profiles import router as config_profiles_router
|
||||
from .extensions import router as extensions_router
|
||||
from .open_api_compat import router as open_api_compat_router
|
||||
from .plugins import router as plugins_router
|
||||
from .providers import router as providers_router
|
||||
|
||||
|
||||
def build_v1_router() -> APIRouter:
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
router.include_router(config_profiles_router)
|
||||
router.include_router(bots_router)
|
||||
router.include_router(providers_router)
|
||||
router.include_router(plugins_router)
|
||||
router.include_router(extensions_router)
|
||||
router.include_router(open_api_compat_router)
|
||||
for compat_router in compat_routers:
|
||||
router.include_router(compat_router)
|
||||
return router
|
||||
188
astrbot/dashboard/v1/routers/bots.py
Normal file
188
astrbot/dashboard/v1/routers/bots.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from astrbot.dashboard.services.config_service import BotConfigService
|
||||
|
||||
from ..auth import AuthContext, require_scope
|
||||
from ..responses import ok
|
||||
from ..schemas import BotConfigRequest, EnabledPatch
|
||||
|
||||
router = APIRouter(tags=["Bots"])
|
||||
|
||||
|
||||
async def require_config_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "config")
|
||||
|
||||
|
||||
def get_service(request: Request) -> BotConfigService:
|
||||
return request.app.state.services.bots
|
||||
|
||||
|
||||
async def _json_or_empty(request: Request) -> dict:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def _required_text(value: object, name: str) -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
raise ValueError(f"Missing key: {name}")
|
||||
return text
|
||||
|
||||
|
||||
def _config_from_body(body: dict) -> dict:
|
||||
config = body.get("config")
|
||||
if isinstance(config, dict):
|
||||
return config
|
||||
return {
|
||||
key: value
|
||||
for key, value in body.items()
|
||||
if key not in {"bot_id", "config", "enabled"}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/bot-types")
|
||||
async def list_bot_types(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.list_bot_types())
|
||||
|
||||
|
||||
@router.get("/bots")
|
||||
async def list_bots(
|
||||
enabled: bool | None = Query(default=None),
|
||||
type_: str | None = Query(default=None, alias="type"),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.list_bots(enabled=enabled, type_=type_))
|
||||
|
||||
|
||||
@router.post("/bots")
|
||||
async def create_bot(
|
||||
payload: BotConfigRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
await service.create_bot(payload.to_legacy_config())
|
||||
return ok(message="新增平台配置成功~")
|
||||
|
||||
|
||||
@router.get("/bots/stats")
|
||||
async def list_bot_stats(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_bot_stats())
|
||||
|
||||
|
||||
@router.get("/bots/by-id")
|
||||
async def get_bot_by_id(
|
||||
bot_id: str = Query(...),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_bot(bot_id))
|
||||
|
||||
|
||||
@router.put("/bots/by-id")
|
||||
async def update_bot_by_id(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
bot_id = _required_text(body.get("bot_id"), "bot_id")
|
||||
await service.update_bot(
|
||||
bot_id,
|
||||
BotConfigRequest(config=_config_from_body(body)).to_legacy_config(
|
||||
fallback_id=bot_id,
|
||||
),
|
||||
)
|
||||
return ok(message="更新平台配置成功~")
|
||||
|
||||
|
||||
@router.delete("/bots/by-id")
|
||||
async def delete_bot_by_id(
|
||||
bot_id: str = Query(...),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
await service.delete_bot(bot_id)
|
||||
return ok(message="删除平台配置成功~")
|
||||
|
||||
|
||||
@router.patch("/bots/enabled")
|
||||
async def set_bot_enabled_by_id(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
bot_id = _required_text(body.get("bot_id"), "bot_id")
|
||||
await service.set_bot_enabled(bot_id, bool(body.get("enabled")))
|
||||
return ok(message="更新平台配置成功~")
|
||||
|
||||
|
||||
@router.post("/bots/test")
|
||||
async def test_bot_by_id(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
bot_id = _required_text(body.get("bot_id"), "bot_id")
|
||||
return ok({"id": bot_id, "status": "unsupported"})
|
||||
|
||||
|
||||
@router.patch("/bots/{bot_id:path}/enabled")
|
||||
async def set_bot_enabled(
|
||||
bot_id: str,
|
||||
payload: EnabledPatch,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
await service.set_bot_enabled(bot_id, payload.enabled)
|
||||
return ok(message="更新平台配置成功~")
|
||||
|
||||
|
||||
@router.post("/bots/{bot_id:path}/test")
|
||||
async def test_bot(
|
||||
bot_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
):
|
||||
return ok({"id": bot_id, "status": "unsupported"})
|
||||
|
||||
|
||||
@router.get("/bots/{bot_id:path}")
|
||||
async def get_bot(
|
||||
bot_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_bot(bot_id))
|
||||
|
||||
|
||||
@router.put("/bots/{bot_id:path}")
|
||||
async def update_bot(
|
||||
bot_id: str,
|
||||
payload: BotConfigRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
await service.update_bot(bot_id, payload.to_legacy_config(fallback_id=bot_id))
|
||||
return ok(message="更新平台配置成功~")
|
||||
|
||||
|
||||
@router.delete("/bots/{bot_id:path}")
|
||||
async def delete_bot(
|
||||
bot_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: BotConfigService = Depends(get_service),
|
||||
):
|
||||
await service.delete_bot(bot_id)
|
||||
return ok(message="删除平台配置成功~")
|
||||
19
astrbot/dashboard/v1/routers/compat/__init__.py
Normal file
19
astrbot/dashboard/v1/routers/compat/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from .auth import router as auth_router
|
||||
from .chat import router as chat_router
|
||||
from .conversations import router as conversations_router
|
||||
from .files import router as files_router
|
||||
from .knowledge_bases import router as knowledge_bases_router
|
||||
from .personas import router as personas_router
|
||||
from .sessions import router as sessions_router
|
||||
from .system import router as system_router
|
||||
|
||||
compat_routers = [
|
||||
auth_router,
|
||||
chat_router,
|
||||
files_router,
|
||||
conversations_router,
|
||||
system_router,
|
||||
sessions_router,
|
||||
personas_router,
|
||||
knowledge_bases_router,
|
||||
]
|
||||
152
astrbot/dashboard/v1/routers/compat/auth.py
Normal file
152
astrbot/dashboard/v1/routers/compat/auth.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ...auth import AuthContext
|
||||
from .common import (
|
||||
auth_optional,
|
||||
get_bridge,
|
||||
require_config_scope,
|
||||
require_system_scope,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Auth"])
|
||||
|
||||
|
||||
@router.post("/auth/login")
|
||||
async def login(
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, None, method="POST", target_path="/api/auth/login"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, None, method="POST", target_path="/api/auth/logout"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/auth/setup-status")
|
||||
async def setup_status(
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, None, method="GET", target_path="/api/auth/setup-status"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/setup")
|
||||
async def setup(
|
||||
request: Request,
|
||||
auth: AuthContext | None = Depends(auth_optional),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
target_path = "/api/auth/setup-authenticated" if auth else "/api/auth/setup"
|
||||
return await bridge.forward(request, auth, method="POST", target_path=target_path)
|
||||
|
||||
|
||||
@router.post("/auth/totp/setup")
|
||||
async def totp_setup(
|
||||
request: Request,
|
||||
auth: AuthContext | None = Depends(auth_optional),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/auth/totp/setup"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/auth/totp/recovery")
|
||||
async def totp_recovery(
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, None, method="POST", target_path="/api/auth/totp/recovery"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/system-config/runtime")
|
||||
async def get_system_config_runtime(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/config/get"
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/auth/account")
|
||||
async def update_account(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/auth/account/edit"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api-keys")
|
||||
async def list_api_keys(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/apikey/list"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api-keys")
|
||||
async def create_api_key(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/apikey/create"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api-keys/{key_id}/revoke")
|
||||
async def revoke_api_key(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/apikey/revoke",
|
||||
json_body={"key_id": key_id},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/api-keys/{key_id}")
|
||||
async def delete_api_key(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/apikey/delete",
|
||||
json_body={"key_id": key_id},
|
||||
)
|
||||
342
astrbot/dashboard/v1/routers/compat/chat.py
Normal file
342
astrbot/dashboard/v1/routers/compat/chat.py
Normal file
@@ -0,0 +1,342 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ...auth import AuthContext
|
||||
from .common import get_bridge, require_chat_scope, require_config_scope
|
||||
from .common import (
|
||||
json_or_empty as _json_or_empty,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Chat"])
|
||||
|
||||
|
||||
@router.post("/bot-types/{bot_type}/registration")
|
||||
async def register_bot_type(
|
||||
bot_type: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path=f"/api/platform/registration/{bot_type}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/sessions/new")
|
||||
async def create_chat_session(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/chat/new_session"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/sessions/batch-delete")
|
||||
async def batch_delete_chat_sessions(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/chat/batch_delete_sessions"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/sessions/{session_id}")
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/chat/get_session",
|
||||
query={"session_id": session_id},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/chat/sessions/{session_id}")
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chat/update_session_display_name",
|
||||
json_body={"session_id": session_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/chat/sessions/{session_id}")
|
||||
async def delete_chat_session(
|
||||
session_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/chat/delete_session",
|
||||
query={"session_id": session_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/sessions/{session_id}/stop")
|
||||
async def stop_chat_session(
|
||||
session_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chat/stop",
|
||||
json_body={"session_id": session_id},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/chat/sessions/{session_id}/messages/{message_id}")
|
||||
async def update_chat_message(
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chat/message/edit",
|
||||
json_body={"session_id": session_id, "message_id": message_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/sessions/{session_id}/messages/{message_id}/regenerate")
|
||||
async def regenerate_chat_message(
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chat/message/regenerate",
|
||||
json_body={"session_id": session_id, "message_id": message_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/configs")
|
||||
async def chat_configs(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/config/abconfs"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/threads")
|
||||
async def create_chat_thread(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/chat/thread/create"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/threads/{thread_id}")
|
||||
async def get_chat_thread(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/chat/thread/get",
|
||||
query={"thread_id": thread_id},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/chat/threads/{thread_id}")
|
||||
async def delete_chat_thread(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chat/thread/delete",
|
||||
json_body={"thread_id": thread_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/threads/{thread_id}/messages")
|
||||
async def send_chat_thread_message(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chat/thread/send",
|
||||
json_body={"thread_id": thread_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/projects")
|
||||
async def list_chat_projects(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/chatui_project/list"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/projects")
|
||||
async def create_chat_project(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/chatui_project/create"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/projects/{project_id}")
|
||||
async def get_chat_project(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/chatui_project/get",
|
||||
query={"project_id": project_id},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/chat/projects/{project_id}")
|
||||
async def update_chat_project(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chatui_project/update",
|
||||
json_body={"project_id": project_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/chat/projects/{project_id}")
|
||||
async def delete_chat_project(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/chatui_project/delete",
|
||||
query={"project_id": project_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/projects/{project_id}/sessions")
|
||||
async def list_chat_project_sessions(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/chatui_project/get_sessions",
|
||||
query={"project_id": project_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/projects/{project_id}/sessions/{session_id}")
|
||||
async def add_chat_project_session(
|
||||
project_id: str,
|
||||
session_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chatui_project/add_session",
|
||||
json_body={"project_id": project_id, "session_id": session_id},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/chat/projects/sessions/{session_id}")
|
||||
async def remove_chat_project_session(
|
||||
session_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/chatui_project/remove_session",
|
||||
json_body={"session_id": session_id},
|
||||
)
|
||||
59
astrbot/dashboard/v1/routers/compat/common.py
Normal file
59
astrbot/dashboard/v1/routers/compat/common.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ...auth import AuthContext, require_scope
|
||||
|
||||
|
||||
def get_bridge(request: Request) -> DashboardRouteBridgeService:
|
||||
return request.app.state.services.route_bridge
|
||||
|
||||
|
||||
async def json_or_empty(request: Request) -> dict:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
async def auth_optional(request: Request) -> AuthContext | None:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
has_api_key = bool(
|
||||
request.query_params.get("api_key")
|
||||
or request.query_params.get("key")
|
||||
or request.headers.get("X-API-Key")
|
||||
)
|
||||
if auth_header.startswith(("Bearer ", "ApiKey ")) or has_api_key:
|
||||
return await require_scope(request, "system")
|
||||
return None
|
||||
|
||||
|
||||
async def require_chat_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "chat")
|
||||
|
||||
|
||||
async def require_config_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "config")
|
||||
|
||||
|
||||
async def require_data_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "data")
|
||||
|
||||
|
||||
async def require_file_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "file")
|
||||
|
||||
|
||||
async def require_kb_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "kb")
|
||||
|
||||
|
||||
async def require_persona_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "persona")
|
||||
|
||||
|
||||
async def require_system_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "system")
|
||||
116
astrbot/dashboard/v1/routers/compat/conversations.py
Normal file
116
astrbot/dashboard/v1/routers/compat/conversations.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ...auth import AuthContext
|
||||
from .common import get_bridge, require_data_scope
|
||||
from .common import json_or_empty as _json_or_empty
|
||||
|
||||
router = APIRouter(tags=["Conversations"])
|
||||
|
||||
|
||||
@router.get("/conversations")
|
||||
async def list_conversations(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_data_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/conversation/list"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/conversations/export")
|
||||
async def export_conversations(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_data_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/conversation/export"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/conversations/batch-delete")
|
||||
async def batch_delete_conversations(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_data_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/conversation/delete"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}")
|
||||
async def get_conversation(
|
||||
conversation_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_data_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
user_id = request.query_params.get("user_id")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/conversation/detail",
|
||||
json_body={"user_id": user_id, "cid": conversation_id},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/conversations/{conversation_id}")
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_data_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
user_id = body.pop("user_id", None) or request.query_params.get("user_id")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/conversation/update",
|
||||
json_body={"user_id": user_id, "cid": conversation_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_data_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
user_id = request.query_params.get("user_id")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/conversation/delete",
|
||||
json_body={"user_id": user_id, "cid": conversation_id},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/messages")
|
||||
async def update_conversation_messages(
|
||||
conversation_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_data_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
user_id = body.pop("user_id", None) or request.query_params.get("user_id")
|
||||
if "messages" in body and "history" not in body:
|
||||
body["history"] = body.pop("messages")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/conversation/update_history",
|
||||
json_body={"user_id": user_id, "cid": conversation_id, **body},
|
||||
)
|
||||
72
astrbot/dashboard/v1/routers/compat/files.py
Normal file
72
astrbot/dashboard/v1/routers/compat/files.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ...auth import AuthContext
|
||||
from .common import get_bridge, require_file_scope
|
||||
|
||||
router = APIRouter(tags=["Files"])
|
||||
|
||||
|
||||
@router.post("/files")
|
||||
async def upload_file(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_file_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/chat/post_file"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/content")
|
||||
async def get_file_by_name(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_file_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/chat/get_file"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/tokens/{file_token}")
|
||||
async def get_token_file(
|
||||
file_token: str,
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
None,
|
||||
method="GET",
|
||||
target_path=f"/api/file/{file_token}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{attachment_id}")
|
||||
@router.get("/files/{attachment_id}/content")
|
||||
async def get_file(
|
||||
attachment_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_file_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/chat/get_attachment",
|
||||
query={"attachment_id": attachment_id},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/files/{attachment_id}")
|
||||
async def delete_file(
|
||||
attachment_id: str,
|
||||
_request: Request,
|
||||
_auth: AuthContext = Depends(require_file_scope),
|
||||
):
|
||||
return {"status": "ok", "data": {"attachment_id": attachment_id}}
|
||||
267
astrbot/dashboard/v1/routers/compat/knowledge_bases.py
Normal file
267
astrbot/dashboard/v1/routers/compat/knowledge_bases.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ...auth import AuthContext
|
||||
from .common import get_bridge, require_kb_scope
|
||||
from .common import json_or_empty as _json_or_empty
|
||||
|
||||
router = APIRouter(tags=["Knowledge Bases"])
|
||||
|
||||
|
||||
@router.get("/knowledge-bases")
|
||||
async def list_knowledge_bases(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(request, auth, method="GET", target_path="/api/kb/list")
|
||||
|
||||
|
||||
@router.post("/knowledge-bases")
|
||||
async def create_knowledge_base(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/kb/create"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/tasks/{task_id}")
|
||||
async def get_knowledge_base_task(
|
||||
task_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/kb/document/upload/progress",
|
||||
query={"task_id": task_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}")
|
||||
async def get_knowledge_base(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/kb/get", query={"kb_id": kb_id}
|
||||
)
|
||||
|
||||
|
||||
@router.put("/knowledge-bases/{kb_id}")
|
||||
async def update_knowledge_base(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/kb/update",
|
||||
json_body={"kb_id": kb_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/knowledge-bases/{kb_id}")
|
||||
async def delete_knowledge_base(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/kb/delete",
|
||||
json_body={"kb_id": kb_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}/stats")
|
||||
async def get_knowledge_base_stats(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/kb/stats",
|
||||
query={"kb_id": kb_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}/documents")
|
||||
async def list_knowledge_base_documents(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
query = dict(request.query_params)
|
||||
query["kb_id"] = kb_id
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/kb/document/list",
|
||||
query=query,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/documents")
|
||||
async def upload_knowledge_base_document(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/kb/document/upload",
|
||||
query={"kb_id": kb_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/documents/import")
|
||||
async def import_knowledge_base_documents(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/kb/document/import",
|
||||
json_body={"kb_id": kb_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/documents/import-url")
|
||||
async def import_knowledge_base_document_url(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/kb/document/upload/url",
|
||||
json_body={"kb_id": kb_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}/documents/{document_id}")
|
||||
async def get_knowledge_base_document(
|
||||
kb_id: str,
|
||||
document_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/kb/document/get",
|
||||
query={"kb_id": kb_id, "doc_id": document_id},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/knowledge-bases/{kb_id}/documents/{document_id}")
|
||||
async def delete_knowledge_base_document(
|
||||
kb_id: str,
|
||||
document_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/kb/document/delete",
|
||||
json_body={"kb_id": kb_id, "doc_id": document_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}/chunks")
|
||||
async def list_knowledge_base_chunks(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
query = dict(request.query_params)
|
||||
query["kb_id"] = kb_id
|
||||
if "document_id" in query and "doc_id" not in query:
|
||||
query["doc_id"] = query.pop("document_id")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/kb/chunk/list",
|
||||
query=query,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/knowledge-bases/{kb_id}/chunks/{chunk_id}")
|
||||
async def delete_knowledge_base_chunk(
|
||||
kb_id: str,
|
||||
chunk_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
document_id = request.query_params.get("document_id") or request.query_params.get(
|
||||
"doc_id"
|
||||
)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/kb/chunk/delete",
|
||||
json_body={"kb_id": kb_id, "chunk_id": chunk_id, "doc_id": document_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/retrieve")
|
||||
async def retrieve_knowledge_base(
|
||||
kb_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_kb_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/kb/retrieve",
|
||||
json_body={"kb_id": kb_id, **body},
|
||||
)
|
||||
224
astrbot/dashboard/v1/routers/compat/personas.py
Normal file
224
astrbot/dashboard/v1/routers/compat/personas.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ...auth import AuthContext
|
||||
from .common import get_bridge, require_persona_scope
|
||||
from .common import json_or_empty as _json_or_empty
|
||||
|
||||
router = APIRouter(tags=["Personas"])
|
||||
|
||||
|
||||
@router.get("/personas/tree")
|
||||
async def persona_tree(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/persona/folder/tree"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/personas")
|
||||
async def list_personas(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/persona/list"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/personas")
|
||||
async def create_persona(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/persona/create"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/personas/by-id")
|
||||
async def get_persona_by_id(
|
||||
request: Request,
|
||||
persona_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/persona/detail",
|
||||
json_body={"persona_id": persona_id},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/personas/by-id")
|
||||
async def update_persona_by_id(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
persona_id = str(body.get("persona_id") or "").strip()
|
||||
if not persona_id:
|
||||
raise ValueError("Missing key: persona_id")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/persona/update",
|
||||
json_body={
|
||||
"persona_id": persona_id,
|
||||
**{key: value for key, value in body.items() if key != "persona_id"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/personas/by-id")
|
||||
async def delete_persona_by_id(
|
||||
request: Request,
|
||||
persona_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/persona/delete",
|
||||
json_body={"persona_id": persona_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/personas/{persona_id:path}")
|
||||
async def get_persona(
|
||||
persona_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/persona/detail",
|
||||
json_body={"persona_id": persona_id},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/personas/{persona_id:path}")
|
||||
async def update_persona(
|
||||
persona_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/persona/update",
|
||||
json_body={"persona_id": persona_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/personas/{persona_id:path}")
|
||||
async def delete_persona(
|
||||
persona_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/persona/delete",
|
||||
json_body={"persona_id": persona_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/personas/move")
|
||||
async def move_persona(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/persona/move"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/personas/reorder")
|
||||
async def reorder_personas(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/persona/reorder"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/persona-folders")
|
||||
async def list_persona_folders(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/persona/folder/list"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/persona-folders")
|
||||
async def create_persona_folder(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/persona/folder/create"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/persona-folders/{folder_id}")
|
||||
async def update_persona_folder(
|
||||
folder_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/persona/folder/update",
|
||||
json_body={"folder_id": folder_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/persona-folders/{folder_id}")
|
||||
async def delete_persona_folder(
|
||||
folder_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_persona_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/persona/folder/delete",
|
||||
json_body={"folder_id": folder_id},
|
||||
)
|
||||
217
astrbot/dashboard/v1/routers/compat/sessions.py
Normal file
217
astrbot/dashboard/v1/routers/compat/sessions.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.dashboard.services.session_management_service import (
|
||||
SessionManagementService,
|
||||
SessionManagementServiceError,
|
||||
)
|
||||
|
||||
from ...auth import AuthContext
|
||||
from ...responses import error, ok
|
||||
from ...schemas import (
|
||||
BatchSessionProviderRequest,
|
||||
BatchSessionServiceRequest,
|
||||
SessionGroupRequest,
|
||||
SessionRuleRequest,
|
||||
UmoListRequest,
|
||||
)
|
||||
from .common import require_data_scope
|
||||
|
||||
router = APIRouter(tags=["Sessions"])
|
||||
|
||||
|
||||
def get_service(request: Request) -> SessionManagementService:
|
||||
return request.app.state.services.sessions
|
||||
|
||||
|
||||
def _service_error(exc: SessionManagementServiceError) -> dict:
|
||||
return error(str(exc))
|
||||
|
||||
|
||||
def _unexpected_error(prefix: str, exc: Exception) -> dict:
|
||||
logger.error(f"{prefix}: {exc!s}")
|
||||
return error(f"{prefix}: {exc!s}")
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def list_sessions(
|
||||
page: int = Query(1),
|
||||
page_size: int = Query(20),
|
||||
search: str = Query(""),
|
||||
message_type: str = Query("all"),
|
||||
platform: str = Query(""),
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(
|
||||
await service.list_all_umos_with_status(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search.strip(),
|
||||
message_type=message_type,
|
||||
platform=platform,
|
||||
)
|
||||
)
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("获取会话状态列表失败", exc)
|
||||
|
||||
|
||||
@router.get("/sessions/active-umos")
|
||||
async def list_active_umos(
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(await service.list_active_umos())
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("获取 UMO 列表失败", exc)
|
||||
|
||||
|
||||
@router.get("/sessions/rules")
|
||||
async def list_session_rules(
|
||||
page: int = Query(1),
|
||||
page_size: int = Query(10),
|
||||
search: str = Query(""),
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(
|
||||
await service.list_session_rules(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search.strip(),
|
||||
)
|
||||
)
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("获取规则列表失败", exc)
|
||||
|
||||
|
||||
@router.post("/sessions/rules")
|
||||
async def update_session_rule(
|
||||
payload: SessionRuleRequest,
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(
|
||||
await service.update_session_rule(payload.model_dump(exclude_none=True))
|
||||
)
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("更新会话规则失败", exc)
|
||||
|
||||
|
||||
@router.post("/sessions/rules/delete")
|
||||
async def delete_session_rule(
|
||||
payload: UmoListRequest,
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(
|
||||
await service.delete_session_rules(payload.model_dump(exclude_none=True))
|
||||
)
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("删除会话规则失败", exc)
|
||||
|
||||
|
||||
@router.patch("/sessions/provider")
|
||||
async def update_session_provider(
|
||||
payload: BatchSessionProviderRequest,
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(
|
||||
await service.batch_update_provider(payload.model_dump(exclude_none=True))
|
||||
)
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("批量更新 Provider 失败", exc)
|
||||
|
||||
|
||||
@router.patch("/sessions/service")
|
||||
async def update_session_service(
|
||||
payload: BatchSessionServiceRequest,
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(
|
||||
await service.batch_update_service(payload.model_dump(exclude_none=True))
|
||||
)
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("批量更新服务状态失败", exc)
|
||||
|
||||
|
||||
@router.get("/session-groups")
|
||||
async def list_session_groups(
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(service.list_groups())
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("获取分组列表失败", exc)
|
||||
|
||||
|
||||
@router.post("/session-groups")
|
||||
async def create_session_group(
|
||||
payload: SessionGroupRequest,
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(service.create_group(payload.model_dump(exclude_none=True)))
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("创建分组失败", exc)
|
||||
|
||||
|
||||
@router.put("/session-groups/{group_id}")
|
||||
async def update_session_group(
|
||||
group_id: str,
|
||||
payload: SessionGroupRequest,
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
body = payload.model_dump(exclude_none=True)
|
||||
return ok(service.update_group({"group_id": group_id, **body}))
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("更新分组失败", exc)
|
||||
|
||||
|
||||
@router.delete("/session-groups/{group_id}")
|
||||
async def delete_session_group(
|
||||
group_id: str,
|
||||
_auth: AuthContext = Depends(require_data_scope),
|
||||
service: SessionManagementService = Depends(get_service),
|
||||
):
|
||||
try:
|
||||
return ok(service.delete_group({"group_id": group_id}))
|
||||
except SessionManagementServiceError as exc:
|
||||
return _service_error(exc)
|
||||
except Exception as exc:
|
||||
return _unexpected_error("删除分组失败", exc)
|
||||
647
astrbot/dashboard/v1/routers/compat/system.py
Normal file
647
astrbot/dashboard/v1/routers/compat/system.py
Normal file
@@ -0,0 +1,647 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ...auth import AuthContext
|
||||
from .common import (
|
||||
get_bridge,
|
||||
require_config_scope,
|
||||
require_system_scope,
|
||||
)
|
||||
from .common import (
|
||||
json_or_empty as _json_or_empty,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["System"])
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/stat/get"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats/provider-tokens")
|
||||
async def provider_tokens(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/stat/provider-tokens"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats/version")
|
||||
async def version(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/stat/version"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats/first-notice")
|
||||
async def first_notice(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/stat/first-notice"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stats/ghproxy/test")
|
||||
async def test_ghproxy_connection(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/stat/test-ghproxy-connection",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/changelogs")
|
||||
async def list_changelog_versions(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/stat/changelog/list"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/changelogs/{version}")
|
||||
async def get_changelog(
|
||||
version: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/stat/changelog",
|
||||
query={"version": version},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats/start-time")
|
||||
async def start_time(
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, None, method="GET", target_path="/api/stat/start-time"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats/storage")
|
||||
async def storage_status(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/stat/storage"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stats/storage/cleanup")
|
||||
async def cleanup_storage(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/stat/storage/cleanup"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/system/restart")
|
||||
async def restart_system(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/stat/restart-core"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/logs/history")
|
||||
async def log_history(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/log-history"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/logs/live")
|
||||
async def live_logs(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/live-log"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cron/jobs")
|
||||
async def list_cron_jobs(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/cron/jobs"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cron/jobs")
|
||||
async def create_cron_job(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/cron/jobs"
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/cron/jobs/{job_id}")
|
||||
async def update_cron_job(
|
||||
job_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="PATCH",
|
||||
target_path=f"/api/cron/jobs/{job_id}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/cron/jobs/{job_id}")
|
||||
async def delete_cron_job(
|
||||
job_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="DELETE",
|
||||
target_path=f"/api/cron/jobs/{job_id}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cron/jobs/{job_id}/run")
|
||||
async def run_cron_job(
|
||||
job_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path=f"/api/cron/jobs/{job_id}/run",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/trace/settings")
|
||||
async def get_trace_settings(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/trace/settings"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/trace/settings")
|
||||
async def update_trace_settings(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/trace/settings"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/updates/check")
|
||||
async def check_updates(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/update/check"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/updates/releases")
|
||||
async def update_releases(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/update/releases"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/updates/progress/{task_id}")
|
||||
async def update_progress(
|
||||
task_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/update/progress",
|
||||
query={"id": task_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/updates/core")
|
||||
async def update_core(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/update/do"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/updates/dashboard")
|
||||
async def update_dashboard(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/update/dashboard"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/pip/install")
|
||||
async def install_pip_package(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/update/pip-install"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/migrations")
|
||||
async def run_migration(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/update/migration"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/subagents/config")
|
||||
async def get_subagent_config(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/subagent/config"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/subagents/config")
|
||||
async def update_subagent_config(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/subagent/config"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/subagents/available-tools")
|
||||
async def get_subagent_tools(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/subagent/available-tools"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/backups")
|
||||
async def list_backups(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/backup/list"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/backups")
|
||||
async def export_backup(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/backup/export"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/backups/upload")
|
||||
async def upload_backup(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/backup/upload"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/backups/upload/init")
|
||||
async def init_backup_upload(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/backup/upload/init"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/backups/upload/chunk")
|
||||
async def upload_backup_chunk(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/backup/upload/chunk"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/backups/upload/complete")
|
||||
async def complete_backup_upload(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/backup/upload/complete"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/backups/upload/abort")
|
||||
async def abort_backup_upload(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/backup/upload/abort"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/backups/tasks/{task_id}")
|
||||
async def get_backup_progress(
|
||||
task_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/backup/progress",
|
||||
query={"task_id": task_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/backups/{filename}")
|
||||
async def download_backup(
|
||||
filename: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/backup/download",
|
||||
query={"filename": filename},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/backups/{filename}")
|
||||
async def rename_backup(
|
||||
filename: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/backup/rename",
|
||||
json_body={"filename": filename, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/backups/{filename}")
|
||||
async def delete_backup(
|
||||
filename: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/backup/delete",
|
||||
json_body={"filename": filename},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/backups/{filename}/check")
|
||||
async def check_backup(
|
||||
filename: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/backup/check",
|
||||
json_body={"filename": filename, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/backups/{filename}/import")
|
||||
async def import_backup(
|
||||
filename: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_system_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/backup/import",
|
||||
json_body={"filename": filename, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/t2i/templates")
|
||||
async def list_t2i_templates(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/t2i/templates"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/t2i/templates")
|
||||
async def create_t2i_template(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/t2i/templates/create"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/t2i/templates/active")
|
||||
async def get_active_t2i_template(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/t2i/templates/active"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/t2i/templates/active")
|
||||
async def set_active_t2i_template(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/t2i/templates/set_active"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/t2i/templates/default/reset")
|
||||
async def reset_default_t2i_template(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/t2i/templates/reset_default"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/t2i/templates/{name}")
|
||||
async def get_t2i_template(
|
||||
name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path=f"/api/t2i/templates/{name}",
|
||||
)
|
||||
|
||||
|
||||
@router.put("/t2i/templates/{name}")
|
||||
async def update_t2i_template(
|
||||
name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="PUT",
|
||||
target_path=f"/api/t2i/templates/{name}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/t2i/templates/{name}")
|
||||
async def delete_t2i_template(
|
||||
name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="DELETE",
|
||||
target_path=f"/api/t2i/templates/{name}",
|
||||
)
|
||||
174
astrbot/dashboard/v1/routers/config_profiles.py
Normal file
174
astrbot/dashboard/v1/routers/config_profiles.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from astrbot.dashboard.services.config_service import (
|
||||
ConfigProfileService,
|
||||
ConfigRoutingService,
|
||||
)
|
||||
|
||||
from ..auth import AuthContext, require_scope
|
||||
from ..responses import ok
|
||||
from ..schemas import (
|
||||
ConfigProfileCreateRequest,
|
||||
ConfigRoutesReplaceRequest,
|
||||
ConfigRouteUpsertRequest,
|
||||
RenameRequest,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Config Profiles"])
|
||||
|
||||
|
||||
async def require_config_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "config")
|
||||
|
||||
|
||||
def get_service(request: Request) -> ConfigProfileService:
|
||||
return request.app.state.services.config_profiles
|
||||
|
||||
|
||||
def get_routing_service(request: Request) -> ConfigRoutingService:
|
||||
return request.app.state.services.config_routes
|
||||
|
||||
|
||||
@router.get("/config-profiles/schema")
|
||||
async def get_config_profile_schema(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_profile_schema())
|
||||
|
||||
|
||||
@router.get("/config-profiles")
|
||||
async def list_config_profiles(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
return ok(service.list_profiles())
|
||||
|
||||
|
||||
@router.post("/config-profiles")
|
||||
async def create_config_profile(
|
||||
payload: ConfigProfileCreateRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
return ok(await service.create_profile(payload.name, payload.config), "创建成功")
|
||||
|
||||
|
||||
@router.get("/config-profiles/{config_id}")
|
||||
async def get_config_profile(
|
||||
config_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_profile(config_id))
|
||||
|
||||
|
||||
@router.put("/config-profiles/{config_id}")
|
||||
async def update_config_profile(
|
||||
config_id: str,
|
||||
payload: dict[str, Any],
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
message = await service.update_profile(
|
||||
config_id,
|
||||
payload,
|
||||
two_factor_code=request.headers.get("X-2FA-Code"),
|
||||
)
|
||||
return ok(message=message or "保存成功")
|
||||
|
||||
|
||||
@router.patch("/config-profiles/{config_id}")
|
||||
async def rename_config_profile(
|
||||
config_id: str,
|
||||
payload: RenameRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
service.rename_profile(config_id, payload.name)
|
||||
return ok(message="更新成功")
|
||||
|
||||
|
||||
@router.delete("/config-profiles/{config_id}")
|
||||
async def delete_config_profile(
|
||||
config_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
service.delete_profile(config_id)
|
||||
return ok(message="删除成功")
|
||||
|
||||
|
||||
@router.get("/system-config/schema")
|
||||
async def get_system_config_schema(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_system_schema())
|
||||
|
||||
|
||||
@router.get("/system-config")
|
||||
async def get_system_config(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_profile("default"))
|
||||
|
||||
|
||||
@router.put("/system-config")
|
||||
async def update_system_config(
|
||||
payload: dict[str, Any],
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigProfileService = Depends(get_service),
|
||||
):
|
||||
message = await service.update_profile(
|
||||
"default",
|
||||
payload,
|
||||
two_factor_code=request.headers.get("X-2FA-Code"),
|
||||
)
|
||||
return ok(message=message or "保存成功")
|
||||
|
||||
|
||||
@router.get("/config-routes")
|
||||
async def list_config_routes(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigRoutingService = Depends(get_routing_service),
|
||||
):
|
||||
return ok(service.list_routes())
|
||||
|
||||
|
||||
@router.put("/config-routes")
|
||||
async def replace_config_routes(
|
||||
payload: ConfigRoutesReplaceRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigRoutingService = Depends(get_routing_service),
|
||||
):
|
||||
await service.replace_route_mapping(payload.routing)
|
||||
return ok(message="更新成功")
|
||||
|
||||
|
||||
@router.put("/config-routes/{umo}")
|
||||
async def upsert_config_route(
|
||||
umo: str,
|
||||
payload: ConfigRouteUpsertRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigRoutingService = Depends(get_routing_service),
|
||||
):
|
||||
await service.set_route(umo, payload.config_id)
|
||||
return ok(message="更新成功")
|
||||
|
||||
|
||||
@router.delete("/config-routes/{umo}")
|
||||
async def delete_config_route(
|
||||
umo: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ConfigRoutingService = Depends(get_routing_service),
|
||||
):
|
||||
await service.delete_route_by_umo(umo)
|
||||
return ok(message="删除成功")
|
||||
732
astrbot/dashboard/v1/routers/extensions.py
Normal file
732
astrbot/dashboard/v1/routers/extensions.py
Normal file
@@ -0,0 +1,732 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ..auth import AuthContext, require_scope
|
||||
|
||||
router = APIRouter(tags=["Extension Components"])
|
||||
|
||||
|
||||
def get_bridge(request: Request) -> DashboardRouteBridgeService:
|
||||
return request.app.state.services.route_bridge
|
||||
|
||||
|
||||
async def require_tool_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "tool")
|
||||
|
||||
|
||||
async def require_skill_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "skill")
|
||||
|
||||
|
||||
async def _json_or_empty(request: Request) -> dict:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def _required_text(value: object, name: str) -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
raise ValueError(f"Missing key: {name}")
|
||||
return text
|
||||
|
||||
|
||||
def _config_from_body(body: dict, id_key: str) -> dict:
|
||||
config = body.get("config")
|
||||
if isinstance(config, dict):
|
||||
return dict(config)
|
||||
return {
|
||||
key: value
|
||||
for key, value in body.items()
|
||||
if key not in {id_key, "config", "enabled"}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/commands")
|
||||
async def list_commands(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/commands",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/commands/conflicts")
|
||||
async def list_command_conflicts(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/commands/conflicts",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/commands/{command_id}")
|
||||
async def update_command(
|
||||
command_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
if "enabled" in body:
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/commands/toggle",
|
||||
json_body={
|
||||
"handler_full_name": command_id,
|
||||
"enabled": body["enabled"],
|
||||
},
|
||||
)
|
||||
if "alias" in body:
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/commands/rename",
|
||||
json_body={
|
||||
"handler_full_name": command_id,
|
||||
"new_name": body["alias"],
|
||||
"aliases": body.get("aliases"),
|
||||
},
|
||||
)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/commands/permission",
|
||||
json_body={
|
||||
"handler_full_name": command_id,
|
||||
"permission": body.get("permission_group"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tools")
|
||||
async def list_tools(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/tools/list",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/tools/{tool_id}/enabled")
|
||||
async def set_tool_enabled(
|
||||
tool_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/toggle-tool",
|
||||
json_body={"name": tool_id, "activate": body.get("enabled")},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp/servers")
|
||||
async def list_mcp_servers(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/tools/mcp/servers",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mcp/servers")
|
||||
async def create_mcp_server(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
if "enabled" in body and "active" not in body:
|
||||
body["active"] = body.pop("enabled")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/add",
|
||||
json_body=body,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/mcp/servers/by-name")
|
||||
async def update_mcp_server_by_name(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
server_name = _required_text(body.get("server_name"), "server_name")
|
||||
config = _config_from_body(body, "server_name")
|
||||
if "enabled" in body and "active" not in config:
|
||||
config["active"] = body["enabled"]
|
||||
config.setdefault("name", server_name)
|
||||
config.setdefault("oldName", server_name)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/update",
|
||||
json_body=config,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/mcp/servers/by-name")
|
||||
async def delete_mcp_server_by_name(
|
||||
server_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/delete",
|
||||
json_body={"name": server_name},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/mcp/servers/enabled")
|
||||
async def set_mcp_server_enabled_by_name(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
server_name = _required_text(body.get("server_name"), "server_name")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/update",
|
||||
json_body={
|
||||
"name": server_name,
|
||||
"oldName": server_name,
|
||||
"active": body.get("enabled"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mcp/servers/test")
|
||||
async def test_mcp_server_by_name(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
server_name = _required_text(body.get("server_name"), "server_name")
|
||||
config = body.get("mcp_server_config") or body.get("config")
|
||||
config = dict(config) if isinstance(config, dict) else {"name": server_name}
|
||||
config.setdefault("name", server_name)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/test",
|
||||
json_body={"mcp_server_config": config},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/mcp/servers/{server_name:path}/enabled")
|
||||
async def set_mcp_server_enabled(
|
||||
server_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/update",
|
||||
json_body={
|
||||
"name": server_name,
|
||||
"oldName": server_name,
|
||||
"active": body.get("enabled"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mcp/servers/{server_name:path}/test")
|
||||
async def test_mcp_server(
|
||||
server_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
config = body.get("mcp_server_config") or body or {"name": server_name}
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/test",
|
||||
json_body={"mcp_server_config": config},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/mcp/servers/{server_name:path}")
|
||||
async def update_mcp_server(
|
||||
server_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
if "enabled" in body and "active" not in body:
|
||||
body["active"] = body.pop("enabled")
|
||||
body.setdefault("name", server_name)
|
||||
body.setdefault("oldName", server_name)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/update",
|
||||
json_body=body,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/mcp/servers/{server_name:path}")
|
||||
async def delete_mcp_server(
|
||||
server_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/delete",
|
||||
json_body={"name": server_name},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mcp/providers/modelscope/sync")
|
||||
async def sync_modelscope_mcp_servers(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_tool_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/tools/mcp/sync-provider",
|
||||
json_body={
|
||||
"name": "modelscope",
|
||||
"access_token": body.get("access_token", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills")
|
||||
async def list_skills(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/skills")
|
||||
async def upload_skill(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/upload",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/skills/batch")
|
||||
async def upload_skills_batch(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/batch-upload",
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/skills/by-name")
|
||||
async def update_skill_by_name(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
skill_name = _required_text(body.get("skill_name"), "skill_name")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/update",
|
||||
json_body={
|
||||
"name": skill_name,
|
||||
"active": body.get("enabled", body.get("active", True)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/skills/by-name")
|
||||
async def delete_skill_by_name(
|
||||
skill_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/delete",
|
||||
json_body={"name": skill_name},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/archive")
|
||||
async def download_skill_by_name(
|
||||
skill_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/download",
|
||||
query={"name": skill_name},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/files")
|
||||
async def list_skill_files_by_name(
|
||||
skill_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
path = request.query_params.get("path", "")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/files",
|
||||
query={"name": skill_name, "path": path},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/file")
|
||||
async def get_skill_file_by_name(
|
||||
skill_name: str,
|
||||
path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/file",
|
||||
query={"name": skill_name, "path": path},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/skills/file")
|
||||
async def update_skill_file_by_name(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
skill_name = _required_text(body.get("skill_name"), "skill_name")
|
||||
path = _required_text(body.get("path"), "path")
|
||||
content = str(body.get("content", ""))
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/file",
|
||||
json_body={"name": skill_name, "path": path, "content": content},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/{skill_name:path}/archive")
|
||||
async def download_skill(
|
||||
skill_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/download",
|
||||
query={"name": skill_name},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/{skill_name:path}/files")
|
||||
async def list_skill_files(
|
||||
skill_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
path = request.query_params.get("path", "")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/files",
|
||||
query={"name": skill_name, "path": path},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/{skill_name:path}/files/{file_path:path}")
|
||||
async def get_skill_file(
|
||||
skill_name: str,
|
||||
file_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/file",
|
||||
query={"name": skill_name, "path": file_path},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/skills/{skill_name:path}/files/{file_path:path}")
|
||||
async def update_skill_file(
|
||||
skill_name: str,
|
||||
file_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
content = (await request.body()).decode("utf-8")
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/file",
|
||||
json_body={"name": skill_name, "path": file_path, "content": content},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/skills/{skill_name:path}")
|
||||
async def update_skill(
|
||||
skill_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/update",
|
||||
json_body={
|
||||
"name": skill_name,
|
||||
"active": body.get("enabled", body.get("active", True)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/skills/{skill_name:path}")
|
||||
async def delete_skill(
|
||||
skill_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/delete",
|
||||
json_body={"name": skill_name},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/neo/candidates")
|
||||
async def list_neo_skill_candidates(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/neo/candidates",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/neo/releases")
|
||||
async def list_neo_skill_releases(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/neo/releases",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/neo/payload")
|
||||
async def get_neo_skill_payload(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/skills/neo/payload",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/skills/neo/evaluate")
|
||||
async def evaluate_neo_skill_candidate(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/neo/evaluate",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/skills/neo/promote")
|
||||
async def promote_neo_skill_candidate(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/neo/promote",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/skills/neo/rollback")
|
||||
async def rollback_neo_skill_release(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/neo/rollback",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/skills/neo/sync")
|
||||
async def sync_neo_skill_release(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/neo/sync",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/skills/neo/candidates/delete")
|
||||
async def delete_neo_skill_candidate(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/neo/delete-candidate",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/skills/neo/releases/delete")
|
||||
async def delete_neo_skill_release(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_skill_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/skills/neo/delete-release",
|
||||
)
|
||||
223
astrbot/dashboard/v1/routers/open_api_compat.py
Normal file
223
astrbot/dashboard/v1/routers/open_api_compat.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, WebSocket
|
||||
|
||||
from astrbot.dashboard.fastapi_compat import (
|
||||
CompatG,
|
||||
call_request_view,
|
||||
call_websocket_view,
|
||||
)
|
||||
from astrbot.dashboard.services.open_api_service import (
|
||||
OpenApiService,
|
||||
OpenApiServiceError,
|
||||
)
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ..auth import AuthContext, require_scope
|
||||
from ..responses import ApiError, ok
|
||||
|
||||
router = APIRouter(tags=["Open API Compatibility"])
|
||||
|
||||
|
||||
async def require_im_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "im")
|
||||
|
||||
|
||||
async def require_chat_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "chat")
|
||||
|
||||
|
||||
async def require_config_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "config")
|
||||
|
||||
|
||||
async def require_file_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "file")
|
||||
|
||||
|
||||
def get_bridge(request: Request) -> DashboardRouteBridgeService:
|
||||
return request.app.state.services.route_bridge
|
||||
|
||||
|
||||
def get_service(request: Request) -> OpenApiService:
|
||||
return request.app.state.services.open_api
|
||||
|
||||
|
||||
async def _json_or_empty(request: Request) -> dict[str, Any]:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def _compat_g(auth: AuthContext) -> CompatG:
|
||||
g_obj = CompatG()
|
||||
g_obj.username = auth.username
|
||||
return g_obj
|
||||
|
||||
|
||||
async def _call_open_api_route(
|
||||
request: Request,
|
||||
auth: AuthContext,
|
||||
handler_name: str,
|
||||
):
|
||||
route = getattr(request.app.state, "open_api_route", None)
|
||||
app_adapter = getattr(request.app.state, "dashboard_app_adapter", None)
|
||||
if route is None or app_adapter is None:
|
||||
raise ApiError("OpenAPI compatibility route is unavailable", status_code=503)
|
||||
return await call_request_view(
|
||||
request,
|
||||
app_adapter,
|
||||
getattr(route, handler_name),
|
||||
g_obj=_compat_g(auth),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def chat(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
if auth.via == "api_key":
|
||||
return await _call_open_api_route(request, auth, "chat_send")
|
||||
return await bridge.forward(
|
||||
request, auth, method="POST", target_path="/api/chat/send"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/sessions")
|
||||
async def chat_sessions(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_chat_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
if auth.via == "api_key":
|
||||
return await _call_open_api_route(request, auth, "get_chat_sessions")
|
||||
return await bridge.forward(
|
||||
request, auth, method="GET", target_path="/api/chat/sessions"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/configs", include_in_schema=False)
|
||||
async def get_chat_configs(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
):
|
||||
return await _call_open_api_route(request, auth, "get_chat_configs")
|
||||
|
||||
|
||||
@router.post("/file", include_in_schema=False)
|
||||
async def upload_open_api_file(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_file_scope),
|
||||
):
|
||||
return await _call_open_api_route(request, auth, "openapi_upload_file")
|
||||
|
||||
|
||||
@router.get("/file", include_in_schema=False)
|
||||
async def get_open_api_file(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_file_scope),
|
||||
):
|
||||
return await _call_open_api_route(request, auth, "openapi_get_file")
|
||||
|
||||
|
||||
@router.websocket("/chat/ws")
|
||||
async def chat_ws(websocket: WebSocket) -> None:
|
||||
route = getattr(websocket.app.state, "open_api_route", None)
|
||||
app_adapter = getattr(websocket.app.state, "dashboard_app_adapter", None)
|
||||
if route is not None and app_adapter is not None:
|
||||
await call_websocket_view(websocket, app_adapter, route.chat_ws)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
await websocket.close(1011, "OpenAPI chat websocket route is unavailable")
|
||||
|
||||
|
||||
async def _forward_route_websocket(websocket: WebSocket, target_path: str) -> None:
|
||||
route_app = websocket.app.state.services.route_bridge.route_app
|
||||
receive = getattr(websocket, "_receive")
|
||||
send = getattr(websocket, "_send")
|
||||
scope = {
|
||||
**websocket.scope,
|
||||
"path": target_path,
|
||||
"raw_path": target_path.encode(),
|
||||
}
|
||||
await route_app(scope, receive, send)
|
||||
|
||||
|
||||
@router.websocket("/live-chat/ws")
|
||||
async def live_chat_ws(websocket: WebSocket) -> None:
|
||||
await _forward_route_websocket(websocket, "/api/live_chat/ws")
|
||||
|
||||
|
||||
@router.websocket("/unified-chat/ws")
|
||||
async def unified_chat_ws(websocket: WebSocket) -> None:
|
||||
await _forward_route_websocket(websocket, "/api/unified_chat/ws")
|
||||
|
||||
|
||||
@router.post("/im/messages")
|
||||
async def send_im_message(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_im_scope),
|
||||
service: OpenApiService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
try:
|
||||
await service.send_message(body)
|
||||
except OpenApiServiceError as exc:
|
||||
raise ApiError(str(exc)) from exc
|
||||
|
||||
return ok()
|
||||
|
||||
|
||||
@router.post("/im/message", include_in_schema=False)
|
||||
async def send_legacy_im_message(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_im_scope),
|
||||
):
|
||||
return await _call_open_api_route(request, auth, "send_message")
|
||||
|
||||
|
||||
@router.get("/im/bots")
|
||||
async def list_im_bots(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_im_scope),
|
||||
service: OpenApiService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_bots())
|
||||
|
||||
|
||||
async def _forward_platform_webhook(
|
||||
webhook_uuid: str,
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService,
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
None,
|
||||
method=request.method,
|
||||
target_path=f"/api/platform/webhook/{webhook_uuid}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/webhooks/platforms/{webhook_uuid}")
|
||||
async def verify_platform_webhook(
|
||||
webhook_uuid: str,
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await _forward_platform_webhook(webhook_uuid, request, bridge)
|
||||
|
||||
|
||||
@router.post("/webhooks/platforms/{webhook_uuid}")
|
||||
async def receive_platform_webhook(
|
||||
webhook_uuid: str,
|
||||
request: Request,
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await _forward_platform_webhook(webhook_uuid, request, bridge)
|
||||
894
astrbot/dashboard/v1/routers/plugins.py
Normal file
894
astrbot/dashboard/v1/routers/plugins.py
Normal file
@@ -0,0 +1,894 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from astrbot.dashboard.services.plugin_service import PluginService, PluginServiceError
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ..auth import AuthContext, require_scope
|
||||
from ..responses import ApiError, ok
|
||||
|
||||
router = APIRouter(tags=["Plugins"])
|
||||
|
||||
|
||||
async def require_plugin_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "plugin")
|
||||
|
||||
|
||||
def get_bridge(request: Request) -> DashboardRouteBridgeService:
|
||||
return request.app.state.services.route_bridge
|
||||
|
||||
|
||||
def get_service(request: Request) -> PluginService:
|
||||
return request.app.state.services.plugins
|
||||
|
||||
|
||||
async def _proxy_plugin_extension(
|
||||
plugin_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext,
|
||||
bridge: DashboardRouteBridgeService,
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method=request.method,
|
||||
target_path=f"/api/plug/{plugin_path.lstrip('/')}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/extensions/{plugin_path:path}")
|
||||
async def get_plugin_extension_route(
|
||||
plugin_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
|
||||
|
||||
|
||||
@router.post("/plugins/extensions/{plugin_path:path}")
|
||||
async def post_plugin_extension_route(
|
||||
plugin_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
|
||||
|
||||
|
||||
@router.put("/plugins/extensions/{plugin_path:path}")
|
||||
async def put_plugin_extension_route(
|
||||
plugin_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
|
||||
|
||||
|
||||
@router.patch("/plugins/extensions/{plugin_path:path}")
|
||||
async def patch_plugin_extension_route(
|
||||
plugin_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
|
||||
|
||||
|
||||
@router.delete("/plugins/extensions/{plugin_path:path}")
|
||||
async def delete_plugin_extension_route(
|
||||
plugin_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await _proxy_plugin_extension(plugin_path, request, auth, bridge)
|
||||
|
||||
|
||||
async def _json_or_empty(request: Request) -> dict:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def _required_text(value: object, name: str) -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
raise ValueError(f"Missing key: {name}")
|
||||
return text
|
||||
|
||||
|
||||
def _plugin_id_from_body(body: dict) -> str:
|
||||
return _required_text(body.get("plugin_id"), "plugin_id")
|
||||
|
||||
|
||||
@router.get("/plugins/failed")
|
||||
async def list_failed_plugins(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/source/get-failed-plugins",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/update")
|
||||
async def update_plugins(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
if body.get("plugin_id"):
|
||||
plugin_id = _plugin_id_from_body(body)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/update",
|
||||
json_body={
|
||||
"name": plugin_id,
|
||||
**{key: value for key, value in body.items() if key != "plugin_id"},
|
||||
},
|
||||
)
|
||||
legacy_body = {
|
||||
**body,
|
||||
"names": body.get("names") or body.get("plugin_ids") or [],
|
||||
}
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/update-all",
|
||||
json_body=legacy_body,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/compatibility/check")
|
||||
async def check_plugin_compatibility(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/check-compat",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/install/github")
|
||||
async def install_plugin_from_github(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
repository = str(body.get("repository") or body.get("url") or "").strip()
|
||||
if repository and not repository.startswith(("http://", "https://")):
|
||||
repository = f"https://github.com/{repository}"
|
||||
legacy_body = {
|
||||
"url": repository,
|
||||
"proxy": body.get("proxy"),
|
||||
"ignore_version_check": body.get("ignore_version_check", False),
|
||||
}
|
||||
if body.get("download_url"):
|
||||
legacy_body["download_url"] = body["download_url"]
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/install",
|
||||
json_body=legacy_body,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/install/url")
|
||||
async def install_plugin_from_url(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
url = str(body.get("url") or "").strip()
|
||||
legacy_body = {
|
||||
"url": body.get("repository") or url,
|
||||
"download_url": url,
|
||||
"proxy": body.get("proxy"),
|
||||
"ignore_version_check": body.get("ignore_version_check", False),
|
||||
}
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/install",
|
||||
json_body=legacy_body,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/install/upload")
|
||||
async def install_plugin_from_upload(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/install-upload",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/market")
|
||||
async def list_plugin_market(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/market_list",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/market/categories")
|
||||
async def list_plugin_market_categories(
|
||||
_request: Request,
|
||||
_auth: AuthContext = Depends(require_plugin_scope),
|
||||
):
|
||||
return ok({"categories": []})
|
||||
|
||||
|
||||
@router.get("/plugin-sources")
|
||||
async def list_plugin_sources(
|
||||
_request: Request,
|
||||
_auth: AuthContext = Depends(require_plugin_scope),
|
||||
service: PluginService = Depends(get_service),
|
||||
):
|
||||
return ok({"sources": await service.get_custom_sources()})
|
||||
|
||||
|
||||
@router.post("/plugin-sources")
|
||||
async def create_plugin_source(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_plugin_scope),
|
||||
service: PluginService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
try:
|
||||
sources = await service.create_custom_source(body)
|
||||
except PluginServiceError as exc:
|
||||
raise ApiError(str(exc)) from exc
|
||||
return ok({"sources": sources}, message="保存成功")
|
||||
|
||||
|
||||
@router.put("/plugin-sources")
|
||||
async def replace_plugin_sources(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_plugin_scope),
|
||||
service: PluginService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
try:
|
||||
sources = await service.replace_custom_sources(body)
|
||||
except PluginServiceError as exc:
|
||||
raise ApiError(str(exc)) from exc
|
||||
return ok({"sources": sources}, message="保存成功")
|
||||
|
||||
|
||||
@router.delete("/plugin-sources/by-id")
|
||||
async def delete_plugin_source_by_id(
|
||||
source_id: str = Query(...),
|
||||
_auth: AuthContext = Depends(require_plugin_scope),
|
||||
service: PluginService = Depends(get_service),
|
||||
):
|
||||
return ok(
|
||||
{"sources": await service.delete_custom_source(source_id)},
|
||||
message="保存成功",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/plugin-sources/{source_id}")
|
||||
async def delete_plugin_source(
|
||||
source_id: str,
|
||||
_request: Request,
|
||||
_auth: AuthContext = Depends(require_plugin_scope),
|
||||
service: PluginService = Depends(get_service),
|
||||
):
|
||||
return ok(
|
||||
{"sources": await service.delete_custom_source(source_id)},
|
||||
message="保存成功",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/page-bridge-sdk.js")
|
||||
async def get_plugin_page_bridge_sdk(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/page/bridge-sdk.js",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins")
|
||||
async def list_plugins(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/get",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/by-id")
|
||||
async def get_plugin_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/detail",
|
||||
query={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/plugins/by-id")
|
||||
async def uninstall_plugin_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/uninstall",
|
||||
json_body={"name": plugin_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/config")
|
||||
async def get_plugin_config_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/config/get",
|
||||
query={"plugin_name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/plugins/config")
|
||||
async def update_plugin_config_by_id(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
plugin_id = _plugin_id_from_body(body)
|
||||
config = body.get("config")
|
||||
config = config if isinstance(config, dict) else {}
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/config/plugin/update",
|
||||
query={"plugin_name": plugin_id},
|
||||
json_body=config,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/config/schema")
|
||||
async def get_plugin_config_schema_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/config/get",
|
||||
query={"plugin_name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/config-files")
|
||||
async def list_plugin_config_files_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
config_key: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/config/file/get",
|
||||
query={"scope": "plugin", "name": plugin_id, "key": config_key},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/config-files")
|
||||
async def upload_plugin_config_files_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
config_key: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/config/file/upload",
|
||||
query={"scope": "plugin", "name": plugin_id, "key": config_key},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/plugins/config-files")
|
||||
async def delete_plugin_config_file_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/config/file/delete",
|
||||
query={"scope": "plugin", "name": plugin_id},
|
||||
json_body=body,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/readme")
|
||||
async def get_plugin_readme_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/readme",
|
||||
query={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/changelog")
|
||||
async def get_plugin_changelog_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/changelog",
|
||||
query={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/reload")
|
||||
async def reload_plugin_by_id(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
plugin_id = _plugin_id_from_body(body)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/reload",
|
||||
json_body={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/plugins/enabled")
|
||||
async def set_plugin_enabled_by_id(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
plugin_id = _plugin_id_from_body(body)
|
||||
target_path = "/api/plugin/on" if body.get("enabled") else "/api/plugin/off"
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path=target_path,
|
||||
json_body={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/pages")
|
||||
async def list_plugin_pages_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/detail",
|
||||
query={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/page")
|
||||
async def get_plugin_page_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
page_name: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/page/entry",
|
||||
query={"name": plugin_id, "page": page_name},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/page/assets")
|
||||
async def get_plugin_page_asset_by_id(
|
||||
request: Request,
|
||||
plugin_id: str = Query(...),
|
||||
page_name: str = Query(...),
|
||||
asset_path: str = Query(...),
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path=f"/api/plugin/page/content/{plugin_id}/{page_name}/{asset_path}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}")
|
||||
async def get_plugin(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/detail",
|
||||
query={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/plugins/{plugin_id}")
|
||||
async def uninstall_plugin(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/uninstall",
|
||||
json_body={"name": plugin_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/plugins/failed/{plugin_id}")
|
||||
async def uninstall_failed_plugin(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/uninstall-failed",
|
||||
json_body={"dir_name": plugin_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/failed/{plugin_id}/reload")
|
||||
async def reload_failed_plugin(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/reload-failed",
|
||||
json_body={"dir_name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}/config")
|
||||
async def get_plugin_config(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/config/get",
|
||||
query={"plugin_name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/plugins/{plugin_id}/config")
|
||||
async def update_plugin_config(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/config/plugin/update",
|
||||
query={"plugin_name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}/config/schema")
|
||||
async def get_plugin_config_schema(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/config/get",
|
||||
query={"plugin_name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}/config-files/{config_key:path}")
|
||||
async def list_plugin_config_files(
|
||||
plugin_id: str,
|
||||
config_key: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/config/file/get",
|
||||
query={"scope": "plugin", "name": plugin_id, "key": config_key},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/{plugin_id}/config-files/{config_key:path}")
|
||||
async def upload_plugin_config_files(
|
||||
plugin_id: str,
|
||||
config_key: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/config/file/upload",
|
||||
query={"scope": "plugin", "name": plugin_id, "key": config_key},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/plugins/{plugin_id}/config-files")
|
||||
async def delete_plugin_config_file(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/config/file/delete",
|
||||
query={"scope": "plugin", "name": plugin_id},
|
||||
json_body=body,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}/readme")
|
||||
async def get_plugin_readme(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/readme",
|
||||
query={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}/changelog")
|
||||
async def get_plugin_changelog(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/changelog",
|
||||
query={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/{plugin_id}/reload")
|
||||
async def reload_plugin(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/reload",
|
||||
json_body={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/plugins/{plugin_id}/enabled")
|
||||
async def set_plugin_enabled(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
target_path = "/api/plugin/on" if body.get("enabled") else "/api/plugin/off"
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path=target_path,
|
||||
json_body={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/plugins/{plugin_id}/update")
|
||||
async def update_plugin(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/plugin/update",
|
||||
json_body={"name": plugin_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}/pages")
|
||||
async def list_plugin_pages(
|
||||
plugin_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/detail",
|
||||
query={"name": plugin_id},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}/pages/{page_name}")
|
||||
async def get_plugin_page(
|
||||
plugin_id: str,
|
||||
page_name: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path="/api/plugin/page/entry",
|
||||
query={"name": plugin_id, "page": page_name},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/plugins/{plugin_id}/pages/{page_name}/assets/{asset_path:path}")
|
||||
async def get_plugin_page_asset(
|
||||
plugin_id: str,
|
||||
page_name: str,
|
||||
asset_path: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_plugin_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="GET",
|
||||
target_path=f"/api/plugin/page/content/{plugin_id}/{page_name}/{asset_path}",
|
||||
)
|
||||
403
astrbot/dashboard/v1/routers/providers.py
Normal file
403
astrbot/dashboard/v1/routers/providers.py
Normal file
@@ -0,0 +1,403 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from astrbot.dashboard.services.config_service import ProviderConfigService
|
||||
from astrbot.dashboard.services.route_bridge_service import DashboardRouteBridgeService
|
||||
|
||||
from ..auth import AuthContext, require_scope
|
||||
from ..responses import ok
|
||||
from ..schemas import EnabledPatch, ProviderConfigRequest, ProviderSourceRequest
|
||||
|
||||
router = APIRouter(tags=["Providers"])
|
||||
|
||||
|
||||
async def require_config_scope(request: Request) -> AuthContext:
|
||||
return await require_scope(request, "config")
|
||||
|
||||
|
||||
def get_service(request: Request) -> ProviderConfigService:
|
||||
return request.app.state.services.providers
|
||||
|
||||
|
||||
def get_bridge(request: Request) -> DashboardRouteBridgeService:
|
||||
return request.app.state.services.route_bridge
|
||||
|
||||
|
||||
async def _json_or_empty(request: Request) -> dict:
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def _required_text(value: object, name: str) -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
raise ValueError(f"Missing key: {name}")
|
||||
return text
|
||||
|
||||
|
||||
def _config_from_body(body: dict) -> dict:
|
||||
config = body.get("config")
|
||||
if isinstance(config, dict):
|
||||
return config
|
||||
return {
|
||||
key: value
|
||||
for key, value in body.items()
|
||||
if key
|
||||
not in {
|
||||
"provider_id",
|
||||
"source_id",
|
||||
"config",
|
||||
"enabled",
|
||||
"provider_config",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/providers/schema")
|
||||
async def get_provider_schema(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_provider_schema())
|
||||
|
||||
|
||||
@router.get("/provider-sources")
|
||||
async def list_provider_sources(
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.list_provider_sources())
|
||||
|
||||
|
||||
@router.post("/provider-sources")
|
||||
async def create_provider_source(
|
||||
payload: ProviderSourceRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
config = payload.to_legacy_config()
|
||||
source_id = config.get("id")
|
||||
if not source_id:
|
||||
raise ValueError("Provider source config must have an 'id' field")
|
||||
await service.upsert_provider_source(source_id, config)
|
||||
return ok(message="更新 provider source 成功")
|
||||
|
||||
|
||||
@router.get("/provider-sources/by-id")
|
||||
async def get_provider_source_by_id(
|
||||
source_id: str = Query(...),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_provider_source(source_id))
|
||||
|
||||
|
||||
@router.put("/provider-sources/by-id")
|
||||
async def upsert_provider_source_by_id(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
source_id = _required_text(body.get("source_id"), "source_id")
|
||||
await service.upsert_provider_source(
|
||||
source_id,
|
||||
ProviderSourceRequest(config=_config_from_body(body)).to_legacy_config(
|
||||
fallback_id=source_id,
|
||||
),
|
||||
)
|
||||
return ok(message="更新 provider source 成功")
|
||||
|
||||
|
||||
@router.delete("/provider-sources/by-id")
|
||||
async def delete_provider_source_by_id(
|
||||
source_id: str = Query(...),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.delete_provider_source(source_id)
|
||||
return ok(message="删除 provider source 成功")
|
||||
|
||||
|
||||
@router.get("/provider-sources/models")
|
||||
async def list_provider_source_models_by_id(
|
||||
source_id: str = Query(...),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(await service.list_provider_source_models(source_id))
|
||||
|
||||
|
||||
@router.get("/provider-sources/providers")
|
||||
async def list_providers_by_source_id(
|
||||
source_id: str = Query(...),
|
||||
capability: str | None = Query(default=None),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.list_providers(capability=capability, source_id=source_id))
|
||||
|
||||
|
||||
@router.post("/provider-sources/providers")
|
||||
async def create_provider_in_source_by_id(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
source_id = _required_text(body.get("source_id"), "source_id")
|
||||
await service.create_provider(
|
||||
ProviderConfigRequest(config=_config_from_body(body)).to_legacy_config(
|
||||
source_id=source_id,
|
||||
),
|
||||
source_id,
|
||||
)
|
||||
return ok(message="新增服务提供商配置成功")
|
||||
|
||||
|
||||
@router.get("/provider-sources/{source_id:path}/models")
|
||||
async def list_provider_source_models(
|
||||
source_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(await service.list_provider_source_models(source_id))
|
||||
|
||||
|
||||
@router.get("/provider-sources/{source_id:path}/providers")
|
||||
async def list_providers_by_source(
|
||||
source_id: str,
|
||||
capability: str | None = Query(default=None),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.list_providers(capability=capability, source_id=source_id))
|
||||
|
||||
|
||||
@router.post("/provider-sources/{source_id:path}/providers")
|
||||
async def create_provider_in_source(
|
||||
source_id: str,
|
||||
payload: ProviderConfigRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.create_provider(
|
||||
payload.to_legacy_config(source_id=source_id), source_id
|
||||
)
|
||||
return ok(message="新增服务提供商配置成功")
|
||||
|
||||
|
||||
@router.get("/provider-sources/{source_id:path}")
|
||||
async def get_provider_source(
|
||||
source_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_provider_source(source_id))
|
||||
|
||||
|
||||
@router.put("/provider-sources/{source_id:path}")
|
||||
async def upsert_provider_source(
|
||||
source_id: str,
|
||||
payload: ProviderSourceRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.upsert_provider_source(
|
||||
source_id,
|
||||
payload.to_legacy_config(),
|
||||
)
|
||||
return ok(message="更新 provider source 成功")
|
||||
|
||||
|
||||
@router.delete("/provider-sources/{source_id:path}")
|
||||
async def delete_provider_source(
|
||||
source_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.delete_provider_source(source_id)
|
||||
return ok(message="删除 provider source 成功")
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def list_providers(
|
||||
capability: str | None = Query(default=None),
|
||||
source_id: str | None = Query(default=None),
|
||||
enabled: bool | None = Query(default=None),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(
|
||||
service.list_providers(
|
||||
capability=capability,
|
||||
source_id=source_id,
|
||||
enabled=enabled,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/providers")
|
||||
async def create_provider(
|
||||
payload: ProviderConfigRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.create_provider(payload.to_legacy_config())
|
||||
return ok(message="新增服务提供商配置成功")
|
||||
|
||||
|
||||
@router.get("/providers/by-id")
|
||||
async def get_provider_by_id(
|
||||
provider_id: str = Query(...),
|
||||
merged: bool = Query(default=False),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_provider(provider_id, merged=merged))
|
||||
|
||||
|
||||
@router.put("/providers/by-id")
|
||||
async def update_provider_by_id(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
provider_id = _required_text(body.get("provider_id"), "provider_id")
|
||||
await service.update_provider(
|
||||
provider_id,
|
||||
ProviderConfigRequest(config=_config_from_body(body)).to_legacy_config(
|
||||
fallback_id=provider_id,
|
||||
),
|
||||
)
|
||||
return ok(message="更新成功,已经实时生效~")
|
||||
|
||||
|
||||
@router.delete("/providers/by-id")
|
||||
async def delete_provider_by_id(
|
||||
provider_id: str = Query(...),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.delete_provider(provider_id)
|
||||
return ok(message="删除成功,已经实时生效。")
|
||||
|
||||
|
||||
@router.patch("/providers/enabled")
|
||||
async def set_provider_enabled_by_id(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
provider_id = _required_text(body.get("provider_id"), "provider_id")
|
||||
await service.set_provider_enabled(provider_id, bool(body.get("enabled")))
|
||||
return ok(message="更新成功,已经实时生效~")
|
||||
|
||||
|
||||
@router.post("/providers/test")
|
||||
async def test_provider_by_id(
|
||||
request: Request,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
provider_id = _required_text(body.get("provider_id"), "provider_id")
|
||||
return ok(await service.test_provider(provider_id))
|
||||
|
||||
|
||||
@router.post("/providers/embedding-dimension")
|
||||
async def get_embedding_dimension_by_id(
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
provider_id = _required_text(body.get("provider_id"), "provider_id")
|
||||
provider_config = body.get("provider_config")
|
||||
legacy_body = {"provider_id": provider_id}
|
||||
if isinstance(provider_config, dict):
|
||||
legacy_body["provider_config"] = provider_config
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/config/provider/get_embedding_dim",
|
||||
json_body=legacy_body,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/providers/{provider_id:path}/enabled")
|
||||
async def set_provider_enabled(
|
||||
provider_id: str,
|
||||
payload: EnabledPatch,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.set_provider_enabled(provider_id, payload.enabled)
|
||||
return ok(message="更新成功,已经实时生效~")
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id:path}/test")
|
||||
async def test_provider(
|
||||
provider_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(await service.test_provider(provider_id))
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id:path}/embedding-dimension")
|
||||
async def get_embedding_dimension(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
auth: AuthContext = Depends(require_config_scope),
|
||||
bridge: DashboardRouteBridgeService = Depends(get_bridge),
|
||||
):
|
||||
body = await _json_or_empty(request)
|
||||
return await bridge.forward(
|
||||
request,
|
||||
auth,
|
||||
method="POST",
|
||||
target_path="/api/config/provider/get_embedding_dim",
|
||||
json_body={"provider_id": provider_id, **body},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id:path}")
|
||||
async def get_provider(
|
||||
provider_id: str,
|
||||
merged: bool = Query(default=False),
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
return ok(service.get_provider(provider_id, merged=merged))
|
||||
|
||||
|
||||
@router.put("/providers/{provider_id:path}")
|
||||
async def update_provider(
|
||||
provider_id: str,
|
||||
payload: ProviderConfigRequest,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.update_provider(
|
||||
provider_id,
|
||||
payload.to_legacy_config(fallback_id=provider_id),
|
||||
)
|
||||
return ok(message="更新成功,已经实时生效~")
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id:path}")
|
||||
async def delete_provider(
|
||||
provider_id: str,
|
||||
_auth: AuthContext = Depends(require_config_scope),
|
||||
service: ProviderConfigService = Depends(get_service),
|
||||
):
|
||||
await service.delete_provider(provider_id)
|
||||
return ok(message="删除成功,已经实时生效。")
|
||||
172
astrbot/dashboard/v1/schemas.py
Normal file
172
astrbot/dashboard/v1/schemas.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class OpenModel(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ConfigProfileCreateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class RenameRequest(BaseModel):
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class EnabledPatch(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class BotConfigRequest(OpenModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
type: str | None = None
|
||||
enabled: bool | None = None
|
||||
enable: bool | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
def to_legacy_config(self, *, fallback_id: str | None = None) -> dict[str, Any]:
|
||||
config = dict(
|
||||
self.config
|
||||
or self.model_dump(
|
||||
exclude={"config", "enabled"},
|
||||
exclude_none=True,
|
||||
)
|
||||
)
|
||||
if fallback_id and "id" not in config:
|
||||
config["id"] = fallback_id
|
||||
if self.type and "type" not in config:
|
||||
config["type"] = self.type
|
||||
if self.id and "id" not in config:
|
||||
config["id"] = self.id
|
||||
if self.enabled is not None:
|
||||
config["enable"] = self.enabled
|
||||
elif self.enable is not None:
|
||||
config["enable"] = self.enable
|
||||
elif "enable" not in config:
|
||||
config["enable"] = True
|
||||
return config
|
||||
|
||||
|
||||
class ProviderSourceRequest(OpenModel):
|
||||
id: str | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
def to_legacy_config(self, *, fallback_id: str | None = None) -> dict[str, Any]:
|
||||
config = dict(
|
||||
self.config or self.model_dump(exclude={"config"}, exclude_none=True)
|
||||
)
|
||||
if fallback_id:
|
||||
config["id"] = fallback_id
|
||||
elif self.id and "id" not in config:
|
||||
config["id"] = self.id
|
||||
return config
|
||||
|
||||
|
||||
class ProviderConfigRequest(OpenModel):
|
||||
id: str | None = None
|
||||
provider_source_id: str | None = None
|
||||
capability: str | None = None
|
||||
enabled: bool | None = None
|
||||
enable: bool | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
def to_legacy_config(
|
||||
self,
|
||||
*,
|
||||
fallback_id: str | None = None,
|
||||
source_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
config = dict(
|
||||
self.config
|
||||
or self.model_dump(
|
||||
exclude={"config", "capability", "enabled"},
|
||||
exclude_none=True,
|
||||
)
|
||||
)
|
||||
if fallback_id and "id" not in config:
|
||||
config["id"] = fallback_id
|
||||
if self.id and "id" not in config:
|
||||
config["id"] = self.id
|
||||
if source_id:
|
||||
config["provider_source_id"] = source_id
|
||||
elif self.provider_source_id and "provider_source_id" not in config:
|
||||
config["provider_source_id"] = self.provider_source_id
|
||||
if self.enabled is not None:
|
||||
config["enable"] = self.enabled
|
||||
elif self.enable is not None:
|
||||
config["enable"] = self.enable
|
||||
elif "enable" not in config:
|
||||
config["enable"] = True
|
||||
if self.capability and "provider_type" not in config:
|
||||
capability_map = {
|
||||
"chat": "chat_completion",
|
||||
"agent": "agent_runner",
|
||||
"stt": "speech_to_text",
|
||||
"tts": "text_to_speech",
|
||||
"embedding": "embedding",
|
||||
"rerank": "rerank",
|
||||
}
|
||||
config["provider_type"] = capability_map.get(
|
||||
self.capability, self.capability
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
class ProviderListQuery(BaseModel):
|
||||
capability: str | None = None
|
||||
source_id: str | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class ConfigRoutesReplaceRequest(BaseModel):
|
||||
routing: dict[str, str]
|
||||
|
||||
|
||||
class ConfigRouteUpsertRequest(BaseModel):
|
||||
config_id: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class SessionRuleRequest(OpenModel):
|
||||
umo: str | None = None
|
||||
rule_key: str | None = None
|
||||
rule_value: Any = None
|
||||
|
||||
|
||||
class UmoListRequest(OpenModel):
|
||||
umo: str | None = None
|
||||
umos: list[str] | None = None
|
||||
scope: Literal["all", "group", "private", "custom_group"] | None = None
|
||||
group_id: str | None = None
|
||||
rule_key: str | None = None
|
||||
|
||||
|
||||
class BatchSessionProviderRequest(UmoListRequest):
|
||||
provider_id: str | None = None
|
||||
provider_type: (
|
||||
Literal[
|
||||
"chat_completion",
|
||||
"speech_to_text",
|
||||
"text_to_speech",
|
||||
]
|
||||
| None
|
||||
) = None
|
||||
|
||||
|
||||
class BatchSessionServiceRequest(UmoListRequest):
|
||||
session_enabled: bool | None = None
|
||||
llm_enabled: bool | None = None
|
||||
tts_enabled: bool | None = None
|
||||
|
||||
|
||||
class SessionGroupRequest(OpenModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
umos: list[str] | None = None
|
||||
add_umos: list[str] | None = None
|
||||
remove_umos: list[str] | None = None
|
||||
@@ -10,6 +10,7 @@
|
||||
"build-stage": "node scripts/subset-mdi-font.mjs && vue-tsc --noEmit && vite build --base=/vue/free/stage/",
|
||||
"build-prod": "node scripts/subset-mdi-font.mjs && vue-tsc --noEmit && vite build --base=/vue/free/",
|
||||
"preview": "vite preview --port 5050",
|
||||
"generate:api": "uv run python scripts/generate_openapi_client.py",
|
||||
"typecheck": "vue-tsc --noEmit",
|
||||
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
|
||||
},
|
||||
|
||||
364
dashboard/scripts/generate_openapi_client.py
Normal file
364
dashboard/scripts/generate_openapi_client.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[2]
|
||||
DASHBOARD_DIR = Path(__file__).resolve().parents[1]
|
||||
DEFAULT_SPEC = ROOT_DIR / "openspec" / "openapi-v1.yaml"
|
||||
DEFAULT_OUTPUT = DASHBOARD_DIR / "src" / "api" / "generated" / "openapi-v1.ts"
|
||||
|
||||
HTTP_METHODS = {"get", "post", "put", "patch", "delete", "head", "options"}
|
||||
|
||||
|
||||
def load_spec(source: str) -> dict[str, Any]:
|
||||
if source.startswith(("http://", "https://")):
|
||||
with urllib.request.urlopen(source, timeout=10) as response:
|
||||
return json.loads(response.read().decode("utf-8"))
|
||||
|
||||
spec_path = Path(source)
|
||||
if not spec_path.is_absolute():
|
||||
spec_path = (ROOT_DIR / spec_path).resolve()
|
||||
text = spec_path.read_text(encoding="utf-8")
|
||||
if spec_path.suffix.lower() == ".json":
|
||||
return json.loads(text)
|
||||
return yaml.safe_load(text)
|
||||
|
||||
|
||||
def pascal_case(value: str) -> str:
|
||||
words = re.split(r"[^a-zA-Z0-9]+", value)
|
||||
return "".join(word[:1].upper() + word[1:] for word in words if word)
|
||||
|
||||
|
||||
def camel_case(value: str) -> str:
|
||||
pascal = pascal_case(value)
|
||||
return pascal[:1].lower() + pascal[1:]
|
||||
|
||||
|
||||
def quote(value: str) -> str:
|
||||
return json.dumps(value, ensure_ascii=True)
|
||||
|
||||
|
||||
def property_name(name: str) -> str:
|
||||
if re.fullmatch(r"[A-Za-z_$][A-Za-z0-9_$]*", name):
|
||||
return name
|
||||
return quote(name)
|
||||
|
||||
|
||||
def ref_name(ref: str) -> str:
|
||||
return ref.rsplit("/", 1)[-1]
|
||||
|
||||
|
||||
class TypeScriptGenerator:
|
||||
def __init__(self, spec: dict[str, Any]) -> None:
|
||||
self.spec = spec
|
||||
self.components = spec.get("components", {})
|
||||
|
||||
def resolve_ref(self, obj: dict[str, Any]) -> dict[str, Any]:
|
||||
ref = obj.get("$ref")
|
||||
if not ref:
|
||||
return obj
|
||||
if not ref.startswith("#/"):
|
||||
raise ValueError(f"Unsupported external ref: {ref}")
|
||||
current: Any = self.spec
|
||||
for part in ref.removeprefix("#/").split("/"):
|
||||
current = current[part]
|
||||
return current
|
||||
|
||||
def schema_to_ts(self, schema: dict[str, Any] | None) -> str:
|
||||
if not schema:
|
||||
return "unknown"
|
||||
if "$ref" in schema:
|
||||
return ref_name(schema["$ref"])
|
||||
|
||||
if "allOf" in schema:
|
||||
parts = [self.schema_to_ts(item) for item in schema["allOf"]]
|
||||
return " & ".join(parts) or "unknown"
|
||||
if "oneOf" in schema:
|
||||
parts = [self.schema_to_ts(item) for item in schema["oneOf"]]
|
||||
return " | ".join(parts) or "unknown"
|
||||
if "anyOf" in schema:
|
||||
parts = [self.schema_to_ts(item) for item in schema["anyOf"]]
|
||||
return " | ".join(parts) or "unknown"
|
||||
|
||||
if "const" in schema:
|
||||
return quote(str(schema["const"]))
|
||||
if "enum" in schema:
|
||||
values = schema.get("enum") or []
|
||||
return " | ".join(quote(str(value)) for value in values) or "string"
|
||||
|
||||
schema_type = schema.get("type")
|
||||
if isinstance(schema_type, list):
|
||||
return " | ".join(
|
||||
self.schema_to_ts({**schema, "type": item}) for item in schema_type
|
||||
)
|
||||
|
||||
if schema_type == "string":
|
||||
if schema.get("format") == "binary":
|
||||
return "Blob | File"
|
||||
return "string"
|
||||
if schema_type in {"integer", "number"}:
|
||||
return "number"
|
||||
if schema_type == "boolean":
|
||||
return "boolean"
|
||||
if schema_type == "array":
|
||||
return f"{self.schema_to_ts(schema.get('items'))}[]"
|
||||
if schema_type == "object" or "properties" in schema:
|
||||
properties = schema.get("properties") or {}
|
||||
additional = schema.get("additionalProperties")
|
||||
if not properties:
|
||||
if isinstance(additional, dict):
|
||||
return f"Record<string, {self.schema_to_ts(additional)}>"
|
||||
return "Record<string, unknown>"
|
||||
|
||||
required = set(schema.get("required") or [])
|
||||
fields = []
|
||||
for name, prop_schema in properties.items():
|
||||
optional = "" if name in required else "?"
|
||||
fields.append(
|
||||
f"{property_name(name)}{optional}: {self.schema_to_ts(prop_schema)};"
|
||||
)
|
||||
if additional is True:
|
||||
fields.append("[key: string]: unknown;")
|
||||
elif isinstance(additional, dict):
|
||||
fields.append(f"[key: string]: {self.schema_to_ts(additional)};")
|
||||
return "{ " + " ".join(fields) + " }"
|
||||
|
||||
return "unknown"
|
||||
|
||||
def component_declarations(self) -> list[str]:
|
||||
declarations = []
|
||||
schemas = self.components.get("schemas") or {}
|
||||
for name, schema in schemas.items():
|
||||
if (
|
||||
schema.get("type") == "object"
|
||||
and "properties" in schema
|
||||
and "allOf" not in schema
|
||||
and "oneOf" not in schema
|
||||
and "anyOf" not in schema
|
||||
):
|
||||
declarations.append(self.object_interface(name, schema))
|
||||
else:
|
||||
declarations.append(
|
||||
f"export type {name} = {self.schema_to_ts(schema)};"
|
||||
)
|
||||
return declarations
|
||||
|
||||
def object_interface(self, name: str, schema: dict[str, Any]) -> str:
|
||||
required = set(schema.get("required") or [])
|
||||
lines = [f"export interface {name} {{"]
|
||||
for prop_name, prop_schema in (schema.get("properties") or {}).items():
|
||||
optional = "" if prop_name in required else "?"
|
||||
lines.append(
|
||||
f" {property_name(prop_name)}{optional}: "
|
||||
f"{self.schema_to_ts(prop_schema)};"
|
||||
)
|
||||
additional = schema.get("additionalProperties")
|
||||
if additional is True:
|
||||
lines.append(" [key: string]: unknown;")
|
||||
elif isinstance(additional, dict):
|
||||
lines.append(f" [key: string]: {self.schema_to_ts(additional)};")
|
||||
lines.append("}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def resolve_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
if "$ref" not in parameter:
|
||||
return parameter
|
||||
name = ref_name(parameter["$ref"])
|
||||
return self.components["parameters"][name]
|
||||
|
||||
def request_body_type(self, request_body: dict[str, Any] | None) -> str | None:
|
||||
if not request_body:
|
||||
return None
|
||||
if "$ref" in request_body:
|
||||
request_body = self.resolve_ref(request_body)
|
||||
content = request_body.get("content") or {}
|
||||
if "multipart/form-data" in content:
|
||||
return "FormData"
|
||||
if "application/octet-stream" in content:
|
||||
return "Blob | ArrayBuffer | string"
|
||||
media = content.get("application/json") or next(iter(content.values()), None)
|
||||
if not media:
|
||||
return "unknown"
|
||||
return self.schema_to_ts(media.get("schema"))
|
||||
|
||||
def response_type(self, operation: dict[str, Any]) -> str:
|
||||
responses = operation.get("responses") or {}
|
||||
response = responses.get("200") or responses.get("201") or responses.get("101")
|
||||
if not response:
|
||||
return "unknown"
|
||||
if "$ref" in response:
|
||||
response = self.resolve_ref(response)
|
||||
content = response.get("content") or {}
|
||||
if "application/json" in content:
|
||||
return self.schema_to_ts(content["application/json"].get("schema"))
|
||||
if "text/plain" in content or "text/html" in content:
|
||||
return "string"
|
||||
return "unknown"
|
||||
|
||||
def operation_parameters(
|
||||
self,
|
||||
operation: dict[str, Any],
|
||||
path_item: dict[str, Any],
|
||||
operation_id: str,
|
||||
) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
path_params: list[dict[str, Any]] = []
|
||||
query_params: list[dict[str, Any]] = []
|
||||
declarations: list[str] = []
|
||||
parameters = [
|
||||
*(path_item.get("parameters") or []),
|
||||
*(operation.get("parameters") or []),
|
||||
]
|
||||
|
||||
for raw_parameter in parameters:
|
||||
parameter = self.resolve_parameter(raw_parameter)
|
||||
target = path_params if parameter.get("in") == "path" else query_params
|
||||
if parameter.get("in") in {"path", "query"}:
|
||||
target.append(parameter)
|
||||
|
||||
def emit_params(name_suffix: str, params: list[dict[str, Any]]) -> str | None:
|
||||
if not params:
|
||||
return None
|
||||
type_name = f"{pascal_case(operation_id)}{name_suffix}"
|
||||
lines = [f"export interface {type_name} {{"]
|
||||
for param in params:
|
||||
required = bool(param.get("required"))
|
||||
optional = "" if required else "?"
|
||||
lines.append(
|
||||
f" {property_name(param['name'])}{optional}: "
|
||||
f"{self.schema_to_ts(param.get('schema'))};"
|
||||
)
|
||||
lines.append("}")
|
||||
declarations.append("\n".join(lines))
|
||||
return type_name
|
||||
|
||||
path_type = emit_params("Path", path_params)
|
||||
query_type = emit_params("Query", query_params)
|
||||
return declarations, [path_type or "undefined", query_type or "undefined"]
|
||||
|
||||
def operation_declaration(
|
||||
self,
|
||||
path: str,
|
||||
method: str,
|
||||
path_item: dict[str, Any],
|
||||
operation: dict[str, Any],
|
||||
) -> tuple[list[str], str]:
|
||||
operation_id = operation.get("operationId") or camel_case(f"{method}_{path}")
|
||||
operation_name = camel_case(operation_id)
|
||||
declarations, [path_type, query_type] = self.operation_parameters(
|
||||
operation,
|
||||
path_item,
|
||||
operation_id,
|
||||
)
|
||||
body_type = self.request_body_type(operation.get("requestBody")) or "undefined"
|
||||
response_type = self.response_type(operation)
|
||||
args_type_name = f"{pascal_case(operation_id)}Args"
|
||||
|
||||
members: list[str] = []
|
||||
if path_type != "undefined":
|
||||
members.append(f"path: {path_type};")
|
||||
if query_type != "undefined":
|
||||
members.append(f"query?: {query_type};")
|
||||
if body_type != "undefined":
|
||||
required = bool((operation.get("requestBody") or {}).get("required"))
|
||||
optional = "" if required else "?"
|
||||
members.append(f"body{optional}: {body_type};")
|
||||
|
||||
if members:
|
||||
declarations.append(
|
||||
"export interface "
|
||||
+ args_type_name
|
||||
+ " {\n "
|
||||
+ "\n ".join(members)
|
||||
+ "\n}"
|
||||
)
|
||||
args_signature = f"args: {args_type_name}"
|
||||
args_value = "args"
|
||||
else:
|
||||
args_signature = "args?: undefined"
|
||||
args_value = "args"
|
||||
|
||||
function = (
|
||||
f" {operation_name}({args_signature}, config?: AxiosRequestConfig) {{\n"
|
||||
f" return request<{response_type}>("
|
||||
f"{quote(method.upper())}, {quote(path)}, {args_value}, config"
|
||||
f");\n"
|
||||
f" }}"
|
||||
)
|
||||
return declarations, function
|
||||
|
||||
def generate(self) -> str:
|
||||
declarations = self.component_declarations()
|
||||
operation_functions = []
|
||||
|
||||
for path, path_item in sorted((self.spec.get("paths") or {}).items()):
|
||||
for method, operation in path_item.items():
|
||||
if method not in HTTP_METHODS:
|
||||
continue
|
||||
operation_declarations, operation_function = self.operation_declaration(
|
||||
path,
|
||||
method,
|
||||
path_item,
|
||||
operation,
|
||||
)
|
||||
declarations.extend(operation_declarations)
|
||||
operation_functions.append(operation_function)
|
||||
|
||||
return (
|
||||
"\n\n".join(
|
||||
[
|
||||
"/* eslint-disable */",
|
||||
"// This file is auto-generated by dashboard/scripts/generate_openapi_client.py.",
|
||||
"// Do not edit it manually; update openspec/openapi-v1.yaml and regenerate instead.",
|
||||
"import type { AxiosRequestConfig, AxiosResponse } from 'axios';",
|
||||
"import { apiV1Client } from '../http';",
|
||||
"type RequestArgs = { path?: object; query?: object; body?: unknown } | undefined;",
|
||||
"function encodePathValue(value: unknown): string {\n return encodeURIComponent(String(value));\n}",
|
||||
"function applyPathParams(path: string, params?: object): string {\n if (!params) return path;\n const values = params as Record<string, unknown>;\n return path.replace(/\\{([^}:]+)(?::path)?\\}/g, (_match, key) => encodePathValue(values[key]));\n}",
|
||||
"function request<T>(method: string, path: string, args?: RequestArgs, config?: AxiosRequestConfig): Promise<AxiosResponse<T>> {\n return apiV1Client.request<T>({\n ...config,\n method,\n url: applyPathParams(path, args?.path),\n params: args?.query,\n data: args?.body,\n });\n}",
|
||||
*declarations,
|
||||
"export const openApiV1 = {\n"
|
||||
+ ",\n".join(operation_functions)
|
||||
+ "\n};",
|
||||
]
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate the dashboard OpenAPI v1 client."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spec",
|
||||
default=str(DEFAULT_SPEC),
|
||||
help="OpenAPI source URL or file path. Defaults to openspec/openapi-v1.yaml.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out",
|
||||
default=str(DEFAULT_OUTPUT),
|
||||
help="Generated TypeScript output path.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
spec = load_spec(args.spec)
|
||||
output = Path(args.out)
|
||||
if not output.is_absolute():
|
||||
output = (DASHBOARD_DIR / output).resolve()
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
output.write_text(TypeScriptGenerator(spec).generate(), encoding="utf-8")
|
||||
print(f"Generated {output.relative_to(ROOT_DIR)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3307
dashboard/src/api/generated/openapi-v1.ts
Normal file
3307
dashboard/src/api/generated/openapi-v1.ts
Normal file
File diff suppressed because it is too large
Load Diff
104
dashboard/src/api/http.ts
Normal file
104
dashboard/src/api/http.ts
Normal file
@@ -0,0 +1,104 @@
|
||||
import axios, {
|
||||
type AxiosError,
|
||||
type AxiosInstance,
|
||||
type InternalAxiosRequestConfig,
|
||||
} from 'axios';
|
||||
|
||||
const AUTH_HEADER = 'Authorization';
|
||||
const LOCALE_HEADER = 'Accept-Language';
|
||||
|
||||
let configured = false;
|
||||
let originalFetch: typeof window.fetch | null = null;
|
||||
|
||||
export const httpClient = axios;
|
||||
export const legacyApiClient = axios.create({ baseURL: '/api' });
|
||||
export const apiV1Client = axios.create({ baseURL: '/api/v1' });
|
||||
|
||||
function getToken(): string | null {
|
||||
return localStorage.getItem('token');
|
||||
}
|
||||
|
||||
function getLocale(): string | null {
|
||||
return localStorage.getItem('astrbot-locale');
|
||||
}
|
||||
|
||||
function setAxiosHeader(
|
||||
headers: InternalAxiosRequestConfig['headers'],
|
||||
key: string,
|
||||
value: string,
|
||||
) {
|
||||
if (typeof headers.set === 'function') {
|
||||
headers.set(key, value);
|
||||
return;
|
||||
}
|
||||
headers[key] = value;
|
||||
}
|
||||
|
||||
function attachAxiosHeaders(config: InternalAxiosRequestConfig) {
|
||||
const token = getToken();
|
||||
if (token) {
|
||||
setAxiosHeader(config.headers, AUTH_HEADER, `Bearer ${token}`);
|
||||
}
|
||||
|
||||
const locale = getLocale();
|
||||
if (locale) {
|
||||
setAxiosHeader(config.headers, LOCALE_HEADER, locale);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
function normalizeAxiosError(error: AxiosError) {
|
||||
if (error.response?.status === 429) {
|
||||
const data = error.response.data as { message?: string } | undefined;
|
||||
if (data?.message) {
|
||||
return Promise.reject(data.message);
|
||||
}
|
||||
}
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
function installAxiosInterceptors(instance: AxiosInstance) {
|
||||
instance.interceptors.request.use(attachAxiosHeaders);
|
||||
instance.interceptors.response.use((response) => response, normalizeAxiosError);
|
||||
}
|
||||
|
||||
export function fetchWithAuth(input: RequestInfo | URL, init?: RequestInit) {
|
||||
const fetchImpl = originalFetch ?? window.fetch.bind(window);
|
||||
const token = getToken();
|
||||
const locale = getLocale();
|
||||
|
||||
if (!token && !locale) {
|
||||
return fetchImpl(input, init);
|
||||
}
|
||||
|
||||
const requestHeaders =
|
||||
typeof input !== 'string' && 'headers' in input
|
||||
? (input as Request).headers
|
||||
: undefined;
|
||||
const headers = new Headers(init?.headers || requestHeaders);
|
||||
|
||||
if (token && !headers.has(AUTH_HEADER)) {
|
||||
headers.set(AUTH_HEADER, `Bearer ${token}`);
|
||||
}
|
||||
if (locale && !headers.has(LOCALE_HEADER)) {
|
||||
headers.set(LOCALE_HEADER, locale);
|
||||
}
|
||||
|
||||
return fetchImpl(input, { ...init, headers });
|
||||
}
|
||||
|
||||
export function setupHttpClient() {
|
||||
if (configured) {
|
||||
return;
|
||||
}
|
||||
|
||||
installAxiosInterceptors(axios);
|
||||
installAxiosInterceptors(legacyApiClient);
|
||||
installAxiosInterceptors(apiV1Client);
|
||||
|
||||
originalFetch = window.fetch.bind(window);
|
||||
window.fetch = fetchWithAuth;
|
||||
|
||||
configured = true;
|
||||
}
|
||||
1552
dashboard/src/api/v1.ts
Normal file
1552
dashboard/src/api/v1.ts
Normal file
File diff suppressed because it is too large
Load Diff
@@ -503,7 +503,8 @@ import {
|
||||
} from "vue";
|
||||
import { useRoute, useRouter } from "vue-router";
|
||||
import { useDisplay } from "vuetify";
|
||||
import axios from "axios";
|
||||
import { isAxiosError } from "axios";
|
||||
import { chatApi } from "@/api/v1";
|
||||
import StyledMenu from "@/components/shared/StyledMenu.vue";
|
||||
import ProjectDialog, {
|
||||
type ProjectFormData,
|
||||
@@ -903,8 +904,7 @@ async function saveSessionTitleDialog() {
|
||||
try {
|
||||
const sessionId = editingSessionTitleId.value;
|
||||
const displayName = sessionTitleDraft.value.trim();
|
||||
await axios.post("/api/chat/update_session_display_name", {
|
||||
session_id: sessionId,
|
||||
await chatApi.updateSession(sessionId, {
|
||||
display_name: displayName,
|
||||
});
|
||||
updateSessionTitle(sessionId, displayName);
|
||||
@@ -1202,7 +1202,7 @@ async function createThreadFromSelection() {
|
||||
const message = threadSelection.message;
|
||||
if (!currSessionId.value || !message?.id || !threadSelection.selectedText) return;
|
||||
try {
|
||||
const response = await axios.post("/api/chat/thread/create", {
|
||||
const response = await chatApi.createThread({
|
||||
session_id: currSessionId.value,
|
||||
parent_message_id: message.id,
|
||||
selected_text: threadSelection.selectedText,
|
||||
@@ -1224,7 +1224,7 @@ async function createThreadFromSelection() {
|
||||
window.getSelection()?.removeAllRanges();
|
||||
} catch (error) {
|
||||
toast.error(
|
||||
axios.isAxiosError(error)
|
||||
isAxiosError(error)
|
||||
? error.response?.data?.message || error.message
|
||||
: tm("thread.createFailed"),
|
||||
);
|
||||
@@ -1269,9 +1269,7 @@ async function deleteThread(thread: ChatThread) {
|
||||
if (!(await askForConfirmation(tm("thread.confirmDelete"), confirmDialog))) return;
|
||||
deletingThread.value = true;
|
||||
try {
|
||||
await axios.post("/api/chat/thread/delete", {
|
||||
thread_id: thread.thread_id,
|
||||
});
|
||||
await chatApi.deleteThread(thread.thread_id);
|
||||
removeThreadFromMessages(thread.thread_id);
|
||||
if (activeThread.value?.thread_id === thread.thread_id) {
|
||||
threadPanelOpen.value = false;
|
||||
|
||||
@@ -321,7 +321,7 @@ import { useDisplay } from "vuetify";
|
||||
import { useModuleI18n } from "@/i18n/composables";
|
||||
import { useCustomizerStore } from "@/stores/customizer";
|
||||
import { isComposingEnter } from "@/utils/imeInput.mjs";
|
||||
import axios from "axios";
|
||||
import { commandApi } from "@/api/v1";
|
||||
import type { CommandItem } from "@/components/extension/componentPanel/types";
|
||||
import ConfigSelector from "./ConfigSelector.vue";
|
||||
import ProviderModelMenu from "./ProviderModelMenu.vue";
|
||||
@@ -742,16 +742,12 @@ async function fetchCommands() {
|
||||
if (commandSuggestionLoading.value) return;
|
||||
commandSuggestionLoading.value = true;
|
||||
try {
|
||||
const params: Record<string, string> = {};
|
||||
const cid = currentConfigId.value;
|
||||
if (cid && cid !== "default") {
|
||||
params.config_id = cid;
|
||||
}
|
||||
const res = await axios.get("/api/commands", { params });
|
||||
const res = await commandApi.list(cid && cid !== "default" ? cid : undefined);
|
||||
if (res.data.status === "ok") {
|
||||
allCommands.value = res.data.data.items || [];
|
||||
// 读取当前配置的唤醒词列表,用于指令候选的触发前缀
|
||||
const prefixes: string[] = res.data.data.wake_prefix;
|
||||
const prefixes: string[] = res.data.data.wake_prefix || [];
|
||||
if (prefixes && prefixes.length > 0) {
|
||||
wakePrefixes.value = prefixes;
|
||||
}
|
||||
|
||||
@@ -389,6 +389,7 @@
|
||||
<script setup lang="ts">
|
||||
import { computed, nextTick, reactive, ref } from "vue";
|
||||
import axios from "axios";
|
||||
import { fileApi } from "@/api/v1";
|
||||
import { setCustomComponents } from "markstream-vue";
|
||||
import "markstream-vue/index.css";
|
||||
import RegenerateMenu, {
|
||||
@@ -693,12 +694,10 @@ function partUrl(part: MessagePart) {
|
||||
if (part.embedded_url) return part.embedded_url;
|
||||
if (part.embedded_file?.url) return part.embedded_file.url;
|
||||
if (part.attachment_id) {
|
||||
return `/api/chat/get_attachment?attachment_id=${encodeURIComponent(
|
||||
part.attachment_id,
|
||||
)}`;
|
||||
return fileApi.contentUrl(part.attachment_id);
|
||||
}
|
||||
if (part.filename) {
|
||||
return `/api/chat/get_file?filename=${encodeURIComponent(part.filename)}`;
|
||||
return fileApi.byNameUrl(part.filename);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, ref, watch } from 'vue';
|
||||
import axios from 'axios';
|
||||
import { configProfileApi, configRouteApi } from '@/api/v1';
|
||||
import { useToast } from '@/utils/toast';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import {
|
||||
@@ -164,8 +164,11 @@ function closeDialog() {
|
||||
async function fetchConfigList() {
|
||||
loadingConfigs.value = true;
|
||||
try {
|
||||
const res = await axios.get('/api/config/abconfs');
|
||||
configOptions.value = res.data.data?.info_list || [];
|
||||
const res = await configProfileApi.list();
|
||||
configOptions.value = (res.data.data?.info_list || []).map((item: any) => ({
|
||||
id: String(item.id || ''),
|
||||
name: String(item.name || item.id || 'default')
|
||||
}));
|
||||
} catch (error) {
|
||||
console.error('加载配置文件列表失败', error);
|
||||
configOptions.value = [];
|
||||
@@ -176,7 +179,7 @@ async function fetchConfigList() {
|
||||
|
||||
async function fetchRoutingEntries() {
|
||||
try {
|
||||
const res = await axios.get('/api/config/umo_abconf_routes');
|
||||
const res = await configRouteApi.list();
|
||||
const routing = res.data.data?.routing || {};
|
||||
routingEntries.value = Object.entries(routing).map(([pattern, confId]) => ({
|
||||
pattern,
|
||||
@@ -214,10 +217,9 @@ async function getAgentRunnerType(confId: string): Promise<string> {
|
||||
return configCache.value[confId];
|
||||
}
|
||||
try {
|
||||
const res = await axios.get('/api/config/abconf', {
|
||||
params: { id: confId }
|
||||
});
|
||||
const type = res.data.data?.config?.provider_settings?.agent_runner_type || 'local';
|
||||
const res = await configProfileApi.get(confId);
|
||||
const config = ((res.data.data as any).config || {}) as any;
|
||||
const type = config?.provider_settings?.agent_runner_type || 'local';
|
||||
configCache.value[confId] = type;
|
||||
return type;
|
||||
} catch (error) {
|
||||
@@ -244,12 +246,11 @@ async function applySelectionToBackend(confId: string): Promise<boolean> {
|
||||
}
|
||||
saving.value = true;
|
||||
try {
|
||||
await axios.post('/api/config/umo_abconf_route/update', {
|
||||
umo: targetUmo.value,
|
||||
conf_id: confId
|
||||
});
|
||||
await configRouteApi.upsert(targetUmo.value, { config_id: confId });
|
||||
const filtered = routingEntries.value.filter((entry) => entry.pattern !== targetUmo.value);
|
||||
filtered.push({ pattern: targetUmo.value, confId });
|
||||
if (confId !== 'default') {
|
||||
filtered.push({ pattern: targetUmo.value, confId });
|
||||
}
|
||||
routingEntries.value = filtered;
|
||||
return true;
|
||||
} catch (error) {
|
||||
|
||||
@@ -55,6 +55,7 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onBeforeUnmount, watch } from 'vue';
|
||||
import { useTheme } from 'vuetify';
|
||||
import { chatApi } from '@/api/v1';
|
||||
import { useVADRecording } from '@/composables/useVADRecording';
|
||||
import SiriOrb from './LiveOrb.vue';
|
||||
|
||||
@@ -280,8 +281,7 @@ function connectWebSocket(): Promise<void> {
|
||||
return;
|
||||
}
|
||||
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const wsUrl = `${protocol}//localhost:6185/api/live_chat/ws?token=${encodeURIComponent(token)}`;
|
||||
const wsUrl = chatApi.liveWebSocketUrl(token);
|
||||
|
||||
ws = new WebSocket(wsUrl);
|
||||
|
||||
|
||||
@@ -256,6 +256,7 @@
|
||||
<script setup lang="ts">
|
||||
import { computed, nextTick, reactive, ref } from "vue";
|
||||
import axios from "axios";
|
||||
import { fileApi } from "@/api/v1";
|
||||
import { setCustomComponents } from "markstream-vue";
|
||||
import "markstream-vue/index.css";
|
||||
import IPythonToolBlock from "@/components/chat/message_list_comps/IPythonToolBlock.vue";
|
||||
@@ -347,12 +348,10 @@ function partUrl(part: MessagePart) {
|
||||
if (part.embedded_url) return part.embedded_url;
|
||||
if (part.embedded_file?.url) return part.embedded_file.url;
|
||||
if (part.attachment_id) {
|
||||
return `/api/chat/get_attachment?attachment_id=${encodeURIComponent(
|
||||
part.attachment_id,
|
||||
)}`;
|
||||
return fileApi.contentUrl(part.attachment_id);
|
||||
}
|
||||
if (part.filename) {
|
||||
return `/api/chat/get_file?filename=${encodeURIComponent(part.filename)}`;
|
||||
return fileApi.byNameUrl(part.filename);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user