Files
AstrBot/astrbot/dashboard/api/knowledge_bases.py
lxfight 758e43273d fix: paginate knowledge base dashboard lists (#9055)
* fix: paginate knowledge base dashboard lists

* fix: preserve knowledge document search pagination
2026-06-28 14:00:45 +08:00

508 lines
15 KiB
Python

from __future__ import annotations
from collections.abc import Callable
from typing import Any
from fastapi import APIRouter, Depends, Request
from astrbot.core import logger
from astrbot.dashboard.async_utils import run_maybe_async
from astrbot.dashboard.responses import error, ok
from astrbot.dashboard.schemas import (
KnowledgeBaseImportRequest,
KnowledgeBaseRequest,
KnowledgeBaseRetrieveRequest,
KnowledgeBaseUrlImportRequest,
)
from astrbot.dashboard.services.knowledge_base_service import (
KnowledgeBaseService,
KnowledgeBaseServiceError,
)
from .auth import AuthContext, require_dashboard_user, require_scope
from .multipart import multipart_parts
# Resource-style knowledge-base routes (``/knowledge-bases/...``). These are
# listed in the generated OpenAPI schema under the "Knowledge Bases" tag.
router = APIRouter(tags=["Knowledge Bases"])

# Legacy dashboard routes mounted under ``/api/kb``; they take ids from query
# strings / JSON bodies and are hidden from the OpenAPI schema.
legacy_router = APIRouter(
    include_in_schema=False,
    tags=["Dashboard Knowledge Bases"],
    prefix="/api/kb",
)
def get_service(request: Request) -> KnowledgeBaseService:
    """FastAPI dependency returning the knowledge-base service.

    The service is looked up on ``request.app.state.services.knowledge_bases``;
    populating that attribute happens elsewhere (not visible in this module).
    """
    services = request.app.state.services
    return services.knowledge_bases
async def require_kb_scope(request: Request) -> AuthContext:
    """FastAPI dependency authorizing the caller for the ``"kb"`` scope.

    Thin wrapper over ``require_scope``: its return value (or exception) is
    passed through unchanged.
    """
    scope_name = "kb"
    return await require_scope(request, scope_name)
async def _json_or_empty(request: Request) -> dict[str, Any]:
try:
data = await request.json()
except Exception:
return {}
return data if isinstance(data, dict) else {}
def _to_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
def _model_dict(payload) -> dict[str, Any]:
if payload is None:
return {}
if hasattr(payload, "model_dump"):
return payload.model_dump(exclude_none=True)
return payload if isinstance(payload, dict) else {}
async def _run(operation, *, prefix: str):
    """Run a service operation and wrap the outcome in a dashboard response.

    Args:
        operation: Zero-argument callable, sync or async; it is resolved
            through ``run_maybe_async``.
        prefix: Human-readable context prepended to unexpected-error messages
            and used in the log line.

    Returns:
        ``ok(data, message)`` when the result is a tuple (unpacked as
        ``(data, message)``), ``ok(result)`` for any other result, or
        ``error(...)`` when an exception is raised.

    ``KnowledgeBaseServiceError`` and ``ValueError`` are treated as expected
    failures: they are returned as ``error(str(exc))`` without logging. Any
    other exception is logged with its traceback and reported as
    ``error(f"{prefix}: {exc}")``.
    """
    try:
        outcome = await run_maybe_async(operation)
        if not isinstance(outcome, tuple):
            return ok(outcome)
        # A tuple that is not exactly 2 long raises ValueError here, which
        # the first handler below turns into an error response.
        payload, message = outcome
        return ok(payload, message)
    except (KnowledgeBaseServiceError, ValueError) as exc:
        return error(str(exc))
    except Exception as exc:
        logger.error("%s: %s", prefix, exc, exc_info=True)
        return error(f"{prefix}: {exc!s}")
async def _run_json(
    request: Request,
    operation: Callable[[dict[str, Any]], Any],
    *,
    prefix: str,
):
    """Read the JSON body and pass it to ``operation`` under ``_run``.

    The body is parsed with ``_json_or_empty``, so ``operation`` always
    receives a dict (``{}`` for a missing, malformed, or non-object body).
    The body is read before ``_run`` starts.
    """
    payload = await _json_or_empty(request)

    def _invoke():
        return operation(payload)

    return await _run(_invoke, prefix=prefix)
@router.get("/knowledge-bases")
async def list_knowledge_bases(
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """List knowledge bases one page at a time.

    Query parameters ``page`` (default 1) and ``page_size`` (default 20) are
    parsed leniently: missing or non-integer values fall back to the defaults.
    """

    def _list_page():
        query = request.query_params
        return service.list_kbs(
            page=_to_int(query.get("page"), 1),
            page_size=_to_int(query.get("page_size"), 20),
        )

    return await _run(_list_page, prefix="获取知识库列表失败")
@router.post("/knowledge-bases")
async def create_knowledge_base(
    payload: KnowledgeBaseRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Create a knowledge base from the validated request body.

    The payload is converted to a dict with ``_model_dict`` (``None`` fields
    dropped) and passed to ``service.create_kb``.
    """

    def _create():
        # Convert inside _run so conversion failures are reported the same
        # way as service failures.
        return service.create_kb(_model_dict(payload))

    return await _run(_create, prefix="创建知识库失败")
@router.get("/knowledge-bases/tasks/{task_id}")
async def get_knowledge_base_task(
    task_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Report the progress of the upload task identified by ``task_id``."""

    def _progress():
        return service.get_upload_progress(task_id)

    return await _run(_progress, prefix="获取上传进度失败")
@router.get("/knowledge-bases/{kb_id}")
async def get_knowledge_base(
    kb_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Return the details of the knowledge base named in the path."""

    def _fetch():
        return service.get_kb(kb_id)

    return await _run(_fetch, prefix="获取知识库详情失败")
@router.put("/knowledge-bases/{kb_id}")
async def update_knowledge_base(
    kb_id: str,
    payload: KnowledgeBaseRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Update the knowledge base identified by the ``kb_id`` path segment.

    The validated payload is converted to a dict (``None`` fields dropped)
    and passed to ``service.update_kb`` together with the path ``kb_id``.

    The path ``kb_id`` is merged last so it always takes precedence. If the
    body schema can carry its own ``kb_id``, a body value would otherwise
    silently redirect the update to a different knowledge base than the one
    named in the URL.
    """
    body = _model_dict(payload)
    return await _run(
        lambda: service.update_kb({**body, "kb_id": kb_id}),
        prefix="更新知识库失败",
    )
@router.delete("/knowledge-bases/{kb_id}")
async def delete_knowledge_base(
    kb_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Delete the knowledge base named in the path.

    ``service.delete_kb`` takes a dict, so the id is wrapped as
    ``{"kb_id": kb_id}``.
    """
    target = {"kb_id": kb_id}
    return await _run(lambda: service.delete_kb(target), prefix="删除知识库失败")
@router.get("/knowledge-bases/{kb_id}/stats")
async def get_knowledge_base_stats(
    kb_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Return statistics for the knowledge base named in the path."""

    def _stats():
        return service.get_kb_stats(kb_id)

    return await _run(_stats, prefix="获取知识库统计失败")
@router.get("/knowledge-bases/{kb_id}/documents")
async def list_knowledge_base_documents(
    kb_id: str,
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """List a knowledge base's documents with pagination and an optional filter.

    Query parameters:
        page: 1-based page number (default 1; unparsable values use it too).
        page_size: Items per page (default 100; same fallback rule).
        search: Forwarded verbatim, ``None`` when absent.
    """

    def _list_documents():
        query = request.query_params
        page = _to_int(query.get("page"), 1)
        page_size = _to_int(query.get("page_size"), 100)
        return service.list_documents(
            kb_id=kb_id,
            page=page,
            page_size=page_size,
            search=query.get("search"),
        )

    return await _run(_list_documents, prefix="获取文档列表失败")
@router.post("/knowledge-bases/{kb_id}/documents")
async def upload_knowledge_base_document(
    kb_id: str,
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Upload document(s) to a knowledge base from a multipart request.

    The path ``kb_id`` is supplied to ``multipart_parts`` via ``extra_form``.
    How it is merged with client-sent form fields is decided by that helper
    (not visible here). The raw ``Content-Type`` header is also forwarded to
    the service.
    """

    async def _upload():
        # Multipart parsing runs inside _run so parse errors become error
        # responses rather than unhandled exceptions.
        form_fields, uploaded = await multipart_parts(
            request, extra_form={"kb_id": kb_id}
        )
        return await service.upload_document(
            content_type=request.headers.get("content-type"),
            form_data=form_fields,
            files=uploaded,
        )

    return await _run(_upload, prefix="上传文档失败")
@router.post("/knowledge-bases/{kb_id}/documents/import")
async def import_knowledge_base_documents(
    kb_id: str,
    payload: KnowledgeBaseImportRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Import documents into the knowledge base named in the path.

    The validated payload is converted to a dict (``None`` fields dropped)
    and passed to ``service.import_documents``.

    The path ``kb_id`` is merged last so it always wins over any ``kb_id``
    the body may carry. Otherwise the body could direct the import into a
    different knowledge base than the URL names.
    """
    body = _model_dict(payload)
    return await _run(
        lambda: service.import_documents({**body, "kb_id": kb_id}),
        prefix="导入文档失败",
    )
@router.post("/knowledge-bases/{kb_id}/documents/import-url")
async def import_knowledge_base_document_url(
    kb_id: str,
    payload: KnowledgeBaseUrlImportRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Import a document from a URL into the knowledge base named in the path.

    The validated payload is converted to a dict (``None`` fields dropped)
    and passed to ``service.upload_document_from_url``.

    The path ``kb_id`` is merged last so it always wins over any ``kb_id``
    the body may carry, keeping the target consistent with the URL.
    """
    body = _model_dict(payload)
    return await _run(
        lambda: service.upload_document_from_url({**body, "kb_id": kb_id}),
        prefix="从URL上传文档失败",
    )
@router.get("/knowledge-bases/{kb_id}/documents/{document_id}")
async def get_knowledge_base_document(
    kb_id: str,
    document_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Return one document's details; the path ``document_id`` maps to ``doc_id``."""

    def _fetch_document():
        return service.get_document(kb_id=kb_id, doc_id=document_id)

    return await _run(_fetch_document, prefix="获取文档详情失败")
@router.delete("/knowledge-bases/{kb_id}/documents/{document_id}")
async def delete_knowledge_base_document(
    kb_id: str,
    document_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Delete one document from a knowledge base.

    ``service.delete_document`` takes a dict keyed by ``kb_id`` and ``doc_id``.
    """
    target = {"kb_id": kb_id, "doc_id": document_id}
    return await _run(
        lambda: service.delete_document(target),
        prefix="删除文档失败",
    )
@router.get("/knowledge-bases/{kb_id}/chunks")
async def list_knowledge_base_chunks(
    kb_id: str,
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """List text chunks of a knowledge base, optionally limited to one document.

    Query parameters:
        document_id / doc_id: Document filter. ``document_id`` is preferred;
            ``doc_id`` is used when ``document_id`` is absent or empty.
        page: 1-based page number (default 1).
        page_size: Items per page (default 100).
    """
    query = request.query_params
    document_id = query.get("document_id") or query.get("doc_id")

    def _list_chunks():
        return service.list_chunks(
            kb_id=kb_id,
            doc_id=document_id,
            page=_to_int(query.get("page"), 1),
            page_size=_to_int(query.get("page_size"), 100),
        )

    return await _run(_list_chunks, prefix="获取块列表失败")
@router.delete("/knowledge-bases/{kb_id}/chunks/{chunk_id}")
async def delete_knowledge_base_chunk(
    kb_id: str,
    chunk_id: str,
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Delete one text chunk.

    The owning document is taken from the ``document_id`` query parameter,
    falling back to ``doc_id``. It is ``None`` when neither is given.
    """
    query = request.query_params
    document_id = query.get("document_id") or query.get("doc_id")
    target = {"kb_id": kb_id, "chunk_id": chunk_id, "doc_id": document_id}
    return await _run(
        lambda: service.delete_chunk(target),
        prefix="删除文本块失败",
    )
@router.post("/knowledge-bases/{kb_id}/retrieve")
async def retrieve_knowledge_base(
    kb_id: str,
    payload: KnowledgeBaseRetrieveRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Run a retrieval query against the knowledge base named in the path.

    The validated payload is converted to a dict (``None`` fields dropped)
    and passed to ``service.retrieve``.

    The path ``kb_id`` is merged last so it always wins over any ``kb_id``
    the body may carry. Otherwise a body field could make the query hit a
    different knowledge base than the URL names.
    """
    body = _model_dict(payload)
    return await _run(
        lambda: service.retrieve({**body, "kb_id": kb_id}),
        prefix="检索失败",
    )
@legacy_router.get("/list")
async def dashboard_list_kbs(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: list knowledge bases with ``page`` (default 1) / ``page_size`` (default 20)."""

    def _list_page():
        query = request.query_params
        page = _to_int(query.get("page"), 1)
        page_size = _to_int(query.get("page_size"), 20)
        return service.list_kbs(page=page, page_size=page_size)

    return await _run(_list_page, prefix="获取知识库列表失败")
@legacy_router.post("/create")
async def dashboard_create_kb(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: create a knowledge base from the raw JSON request body."""
    create = service.create_kb
    return await _run_json(request, create, prefix="创建知识库失败")
@legacy_router.get("/get")
async def dashboard_get_kb(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: return details for the knowledge base in the ``kb_id`` query parameter.

    A missing ``kb_id`` is forwarded as ``None``; no validation happens here.
    """

    def _fetch():
        return service.get_kb(request.query_params.get("kb_id"))

    return await _run(_fetch, prefix="获取知识库详情失败")
@legacy_router.post("/update")
async def dashboard_update_kb(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: update a knowledge base; the JSON body (incl. its ``kb_id``) is forwarded as-is."""
    update = service.update_kb
    return await _run_json(request, update, prefix="更新知识库失败")
@legacy_router.post("/delete")
async def dashboard_delete_kb(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: delete a knowledge base described by the raw JSON request body."""
    delete = service.delete_kb
    return await _run_json(request, delete, prefix="删除知识库失败")
@legacy_router.get("/stats")
async def dashboard_get_kb_stats(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: return statistics for the knowledge base in the ``kb_id`` query parameter."""

    def _stats():
        return service.get_kb_stats(request.query_params.get("kb_id"))

    return await _run(_stats, prefix="获取知识库统计失败")
@legacy_router.get("/document/list")
async def dashboard_list_documents(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: list documents of a knowledge base.

    Query parameters: ``kb_id``, ``page`` (default 1), ``page_size``
    (default 100) and an optional ``search`` string forwarded verbatim.
    """

    def _list_documents():
        query = request.query_params
        kb_id = query.get("kb_id")
        page = _to_int(query.get("page"), 1)
        page_size = _to_int(query.get("page_size"), 100)
        search = query.get("search")
        return service.list_documents(
            kb_id=kb_id, page=page, page_size=page_size, search=search
        )

    return await _run(_list_documents, prefix="获取文档列表失败")
@legacy_router.post("/document/upload")
async def dashboard_upload_document(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: upload document(s) from a multipart request.

    Unlike the resource-style route, no ``kb_id`` is injected. The target
    knowledge base comes from whatever the client puts in the form.
    """

    async def _upload():
        form_fields, uploaded = await multipart_parts(request)
        return await service.upload_document(
            content_type=request.headers.get("content-type"),
            form_data=form_fields,
            files=uploaded,
        )

    return await _run(_upload, prefix="上传文档失败")
@legacy_router.post("/document/import")
async def dashboard_import_documents(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: import documents as described by the raw JSON request body."""
    import_docs = service.import_documents
    return await _run_json(request, import_docs, prefix="导入文档失败")
@legacy_router.post("/document/upload/url")
async def dashboard_upload_document_from_url(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: import a document from a URL given in the raw JSON request body."""
    from_url = service.upload_document_from_url
    return await _run_json(request, from_url, prefix="从URL上传文档失败")
@legacy_router.get("/document/upload/progress")
async def dashboard_get_upload_progress(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: report upload progress for the ``task_id`` query parameter."""

    def _progress():
        return service.get_upload_progress(request.query_params.get("task_id"))

    return await _run(_progress, prefix="获取上传进度失败")
@legacy_router.get("/document/get")
async def dashboard_get_document(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: return one document, addressed by ``kb_id`` and ``doc_id`` query parameters."""

    def _fetch_document():
        query = request.query_params
        return service.get_document(
            kb_id=query.get("kb_id"),
            doc_id=query.get("doc_id"),
        )

    return await _run(_fetch_document, prefix="获取文档详情失败")
@legacy_router.post("/document/delete")
async def dashboard_delete_document(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: delete a document described by the raw JSON request body."""
    delete = service.delete_document
    return await _run_json(request, delete, prefix="删除文档失败")
@legacy_router.get("/chunk/list")
async def dashboard_list_chunks(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: list text chunks.

    Query parameters: ``kb_id``, ``doc_id`` (only this spelling is read
    here), ``page`` (default 1) and ``page_size`` (default 100).
    """

    def _list_chunks():
        query = request.query_params
        kb_id = query.get("kb_id")
        doc_id = query.get("doc_id")
        page = _to_int(query.get("page"), 1)
        page_size = _to_int(query.get("page_size"), 100)
        return service.list_chunks(
            kb_id=kb_id, doc_id=doc_id, page=page, page_size=page_size
        )

    return await _run(_list_chunks, prefix="获取块列表失败")
@legacy_router.post("/chunk/delete")
async def dashboard_delete_chunk(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: delete a text chunk described by the raw JSON request body."""
    delete = service.delete_chunk
    return await _run_json(request, delete, prefix="删除文本块失败")
@legacy_router.post("/retrieve")
async def dashboard_retrieve(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    """Legacy: run a retrieval query described by the raw JSON request body."""
    retrieve = service.retrieve
    return await _run_json(request, retrieve, prefix="检索失败")