mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 01:10:21 +08:00
* fix: paginate knowledge base dashboard lists * fix: preserve knowledge document search pagination
508 lines
15 KiB
Python
508 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, Depends, Request
|
|
|
|
from astrbot.core import logger
|
|
from astrbot.dashboard.async_utils import run_maybe_async
|
|
from astrbot.dashboard.responses import error, ok
|
|
from astrbot.dashboard.schemas import (
|
|
KnowledgeBaseImportRequest,
|
|
KnowledgeBaseRequest,
|
|
KnowledgeBaseRetrieveRequest,
|
|
KnowledgeBaseUrlImportRequest,
|
|
)
|
|
from astrbot.dashboard.services.knowledge_base_service import (
|
|
KnowledgeBaseService,
|
|
KnowledgeBaseServiceError,
|
|
)
|
|
|
|
from .auth import AuthContext, require_dashboard_user, require_scope
|
|
from .multipart import multipart_parts
|
|
|
|
# Public REST-style routes (resource paths such as /knowledge-bases/{kb_id}).
router = APIRouter(tags=["Knowledge Bases"])

# Legacy dashboard endpoints mounted under /api/kb; excluded from the OpenAPI schema.
legacy_router = APIRouter(
    prefix="/api/kb",
    include_in_schema=False,
    tags=["Dashboard Knowledge Bases"],
)
|
|
|
|
|
|
def get_service(request: Request) -> KnowledgeBaseService:
    """FastAPI dependency returning the knowledge-base service.

    The service is read from ``request.app.state.services.knowledge_bases``.
    """
    registry = request.app.state.services
    return registry.knowledge_bases
|
|
|
|
|
|
async def require_kb_scope(request: Request) -> AuthContext:
    """FastAPI dependency: authorize the request for the ``"kb"`` scope.

    Delegates to ``require_scope`` and returns the resulting auth context.
    """
    context = await require_scope(request, "kb")
    return context
|
|
|
|
|
|
async def _json_or_empty(request: Request) -> dict[str, Any]:
|
|
try:
|
|
data = await request.json()
|
|
except Exception:
|
|
return {}
|
|
return data if isinstance(data, dict) else {}
|
|
|
|
|
|
def _to_int(value: Any, default: int) -> int:
|
|
try:
|
|
return int(value)
|
|
except (TypeError, ValueError):
|
|
return default
|
|
|
|
|
|
def _model_dict(payload) -> dict[str, Any]:
|
|
if payload is None:
|
|
return {}
|
|
if hasattr(payload, "model_dump"):
|
|
return payload.model_dump(exclude_none=True)
|
|
return payload if isinstance(payload, dict) else {}
|
|
|
|
|
|
async def _run(operation, *, prefix: str):
    """Invoke a service operation and convert its outcome into a response.

    *operation* is a zero-argument callable handed to ``run_maybe_async``.
    That helper presumably calls it and awaits the result when it is
    awaitable; confirm against ``astrbot.dashboard.async_utils``.

    * A tuple result is unpacked as ``(data, message)`` and returned via
      ``ok(data, message)``. Any other result is returned via ``ok(result)``.
    * ``KnowledgeBaseServiceError`` and ``ValueError`` become
      ``error(str(exc))`` without logging. A tuple whose length is not 2
      fails to unpack with a ``ValueError`` and therefore also ends up here.
    * Any other exception is logged with its traceback and returned as
      ``error(f"{prefix}: {exc}")``.
    """
    try:
        outcome = await run_maybe_async(operation)
        if not isinstance(outcome, tuple):
            return ok(outcome)
        payload, note = outcome
        return ok(payload, note)
    except (KnowledgeBaseServiceError, ValueError) as err:
        return error(str(err))
    except Exception as err:
        logger.error("%s: %s", prefix, err, exc_info=True)
        return error(f"{prefix}: {err!s}")
|
|
|
|
|
|
async def _run_json(
    request: Request,
    operation: Callable[[dict[str, Any]], Any],
    *,
    prefix: str,
):
    """Parse the JSON body and run ``operation(body)`` through ``_run``.

    The body is read once, up front, via ``_json_or_empty``: invalid JSON or
    a non-object payload becomes ``{}``. The call to *operation* itself is
    deferred into ``_run`` so that its errors are converted into responses.
    """
    payload = await _json_or_empty(request)

    def _invoke():
        return operation(payload)

    return await _run(_invoke, prefix=prefix)
|
|
|
|
|
|
@router.get("/knowledge-bases")
async def list_knowledge_bases(
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Paginated knowledge-base listing.
    # Query params: page (default 1) and page_size (default 20); unparseable
    # values fall back to those defaults.
    def _list_page():
        query = request.query_params
        return service.list_kbs(
            page=_to_int(query.get("page"), 1),
            page_size=_to_int(query.get("page_size"), 20),
        )

    return await _run(_list_page, prefix="获取知识库列表失败")
|
|
|
|
|
|
@router.post("/knowledge-bases")
async def create_knowledge_base(
    payload: KnowledgeBaseRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Create a knowledge base. The validated payload is converted to a dict
    # (None fields dropped) inside the deferred call, so conversion errors
    # are also handled by _run.
    def _create():
        return service.create_kb(_model_dict(payload))

    return await _run(_create, prefix="创建知识库失败")
|
|
|
|
|
|
@router.get("/knowledge-bases/tasks/{task_id}")
async def get_knowledge_base_task(
    task_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Report progress for a document-upload task by its id.
    def _progress():
        return service.get_upload_progress(task_id)

    return await _run(_progress, prefix="获取上传进度失败")
|
|
|
|
|
|
@router.get("/knowledge-bases/{kb_id}")
async def get_knowledge_base(
    kb_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Fetch a single knowledge base's details.
    def _fetch():
        return service.get_kb(kb_id)

    return await _run(_fetch, prefix="获取知识库详情失败")
|
|
|
|
|
|
@router.put("/knowledge-bases/{kb_id}")
async def update_knowledge_base(
    kb_id: str,
    payload: KnowledgeBaseRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Update the knowledge base addressed by the URL path.
    # The path kb_id is merged *after* the body so it always wins. If the
    # request model carries its own kb_id field (not visible here; it may be
    # shared with the legacy body-based API), a body value can no longer
    # redirect the update to a different knowledge base than the URL names.
    body = _model_dict(payload)
    return await _run(
        lambda: service.update_kb({**body, "kb_id": kb_id}),
        prefix="更新知识库失败",
    )
|
|
|
|
|
|
@router.delete("/knowledge-bases/{kb_id}")
async def delete_knowledge_base(
    kb_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Delete a knowledge base. The service takes a dict payload, mirroring
    # the legacy JSON-body API.
    def _delete():
        return service.delete_kb({"kb_id": kb_id})

    return await _run(_delete, prefix="删除知识库失败")
|
|
|
|
|
|
@router.get("/knowledge-bases/{kb_id}/stats")
async def get_knowledge_base_stats(
    kb_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Statistics for a single knowledge base.
    def _stats():
        return service.get_kb_stats(kb_id)

    return await _run(_stats, prefix="获取知识库统计失败")
|
|
|
|
|
|
@router.get("/knowledge-bases/{kb_id}/documents")
async def list_knowledge_base_documents(
    kb_id: str,
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Paginated document listing for one knowledge base.
    # Query params: page (default 1), page_size (default 100), and an
    # optional free-text search forwarded as-is (None when absent).
    def _list_documents():
        query = request.query_params
        return service.list_documents(
            kb_id=kb_id,
            page=_to_int(query.get("page"), 1),
            page_size=_to_int(query.get("page_size"), 100),
            search=query.get("search"),
        )

    return await _run(_list_documents, prefix="获取文档列表失败")
|
|
|
|
|
|
@router.post("/knowledge-bases/{kb_id}/documents")
async def upload_knowledge_base_document(
    kb_id: str,
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Multipart document upload. The path kb_id is supplied to multipart_parts
    # as an extra form field. The whole parse-and-upload sequence runs inside
    # _run, so multipart parsing errors are reported like service errors.
    async def _upload():
        fields, uploads = await multipart_parts(
            request, extra_form={"kb_id": kb_id}
        )
        content_type = request.headers.get("content-type")
        return await service.upload_document(
            content_type=content_type,
            form_data=fields,
            files=uploads,
        )

    return await _run(_upload, prefix="上传文档失败")
|
|
|
|
|
|
@router.post("/knowledge-bases/{kb_id}/documents/import")
async def import_knowledge_base_documents(
    kb_id: str,
    payload: KnowledgeBaseImportRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Import documents into the knowledge base named by the URL path.
    # The path kb_id is applied last, so it takes precedence over any kb_id
    # the request model may carry. This keeps the import targeted at the
    # addressed knowledge base. The schema is not visible here; merging
    # last is a no-op when the body has no kb_id.
    body = _model_dict(payload)
    return await _run(
        lambda: service.import_documents({**body, "kb_id": kb_id}),
        prefix="导入文档失败",
    )
|
|
|
|
|
|
@router.post("/knowledge-bases/{kb_id}/documents/import-url")
async def import_knowledge_base_document_url(
    kb_id: str,
    payload: KnowledgeBaseUrlImportRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Fetch a document from a URL into the knowledge base named by the path.
    # The path kb_id is merged after the body so a kb_id possibly present in
    # the request model cannot override the URL-addressed knowledge base.
    body = _model_dict(payload)
    return await _run(
        lambda: service.upload_document_from_url({**body, "kb_id": kb_id}),
        prefix="从URL上传文档失败",
    )
|
|
|
|
|
|
@router.get("/knowledge-bases/{kb_id}/documents/{document_id}")
async def get_knowledge_base_document(
    kb_id: str,
    document_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Details of one document; the service names the document id `doc_id`.
    def _fetch_document():
        return service.get_document(kb_id=kb_id, doc_id=document_id)

    return await _run(_fetch_document, prefix="获取文档详情失败")
|
|
|
|
|
|
@router.delete("/knowledge-bases/{kb_id}/documents/{document_id}")
async def delete_knowledge_base_document(
    kb_id: str,
    document_id: str,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Delete one document. The service expects a dict keyed kb_id / doc_id.
    def _delete_document():
        target = {"kb_id": kb_id, "doc_id": document_id}
        return service.delete_document(target)

    return await _run(_delete_document, prefix="删除文档失败")
|
|
|
|
|
|
@router.get("/knowledge-bases/{kb_id}/chunks")
async def list_knowledge_base_chunks(
    kb_id: str,
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Paginated chunk listing. The document filter accepts either the
    # `document_id` or the legacy `doc_id` query param; the first truthy one
    # wins. page defaults to 1 and page_size to 100.
    query = request.query_params
    document_id = query.get("document_id") or query.get("doc_id")

    def _list_chunks():
        return service.list_chunks(
            kb_id=kb_id,
            doc_id=document_id,
            page=_to_int(query.get("page"), 1),
            page_size=_to_int(query.get("page_size"), 100),
        )

    return await _run(_list_chunks, prefix="获取块列表失败")
|
|
|
|
|
|
@router.delete("/knowledge-bases/{kb_id}/chunks/{chunk_id}")
async def delete_knowledge_base_chunk(
    kb_id: str,
    chunk_id: str,
    request: Request,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Delete one chunk. The owning document id is optional and read from
    # `document_id` or the legacy `doc_id` query param; it is None when both
    # are absent or empty.
    query = request.query_params
    document_id = query.get("document_id") or query.get("doc_id")

    def _delete_chunk():
        target = {"kb_id": kb_id, "chunk_id": chunk_id, "doc_id": document_id}
        return service.delete_chunk(target)

    return await _run(_delete_chunk, prefix="删除文本块失败")
|
|
|
|
|
|
@router.post("/knowledge-bases/{kb_id}/retrieve")
async def retrieve_knowledge_base(
    kb_id: str,
    payload: KnowledgeBaseRetrieveRequest,
    _auth: AuthContext = Depends(require_kb_scope),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Run a retrieval query against the knowledge base named by the path.
    # The path kb_id is merged last so it overrides any kb_id the request
    # model may carry; a query can only hit the URL-addressed knowledge base.
    body = _model_dict(payload)
    return await _run(
        lambda: service.retrieve({**body, "kb_id": kb_id}),
        prefix="检索失败",
    )
|
|
|
|
|
|
@legacy_router.get("/list")
async def dashboard_list_kbs(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy listing. Same pagination defaults as the REST route
    # (page=1, page_size=20).
    def _list_page():
        query = request.query_params
        return service.list_kbs(
            page=_to_int(query.get("page"), 1),
            page_size=_to_int(query.get("page_size"), 20),
        )

    return await _run(_list_page, prefix="获取知识库列表失败")
|
|
|
|
|
|
@legacy_router.post("/create")
async def dashboard_create_kb(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy create: the raw JSON object body is forwarded to the service
    # (invalid or non-object JSON is passed as {}).
    handler = service.create_kb
    return await _run_json(request, handler, prefix="创建知识库失败")
|
|
|
|
|
|
@legacy_router.get("/get")
async def dashboard_get_kb(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy detail lookup. kb_id comes from the query string and is passed
    # through as None when absent; validation, if any, happens in the service.
    def _fetch():
        kb_id = request.query_params.get("kb_id")
        return service.get_kb(kb_id)

    return await _run(_fetch, prefix="获取知识库详情失败")
|
|
|
|
|
|
@legacy_router.post("/update")
async def dashboard_update_kb(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy update: the JSON body (including its kb_id) is forwarded as-is.
    handler = service.update_kb
    return await _run_json(request, handler, prefix="更新知识库失败")
|
|
|
|
|
|
@legacy_router.post("/delete")
async def dashboard_delete_kb(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy delete: the JSON body is forwarded as-is to the service.
    handler = service.delete_kb
    return await _run_json(request, handler, prefix="删除知识库失败")
|
|
|
|
|
|
@legacy_router.get("/stats")
async def dashboard_get_kb_stats(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy stats lookup keyed by the kb_id query parameter (None if absent).
    def _stats():
        kb_id = request.query_params.get("kb_id")
        return service.get_kb_stats(kb_id)

    return await _run(_stats, prefix="获取知识库统计失败")
|
|
|
|
|
|
@legacy_router.get("/document/list")
async def dashboard_list_documents(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy document listing. kb_id, page (default 1), page_size
    # (default 100) and the optional search term all come from the query
    # string.
    def _list_documents():
        query = request.query_params
        return service.list_documents(
            kb_id=query.get("kb_id"),
            page=_to_int(query.get("page"), 1),
            page_size=_to_int(query.get("page_size"), 100),
            search=query.get("search"),
        )

    return await _run(_list_documents, prefix="获取文档列表失败")
|
|
|
|
|
|
@legacy_router.post("/document/upload")
async def dashboard_upload_document(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy multipart upload. Unlike the REST route, no kb_id is injected
    # here; it presumably arrives as a form field (confirm with the client).
    async def _upload():
        fields, uploads = await multipart_parts(request)
        content_type = request.headers.get("content-type")
        return await service.upload_document(
            content_type=content_type,
            form_data=fields,
            files=uploads,
        )

    return await _run(_upload, prefix="上传文档失败")
|
|
|
|
|
|
@legacy_router.post("/document/import")
async def dashboard_import_documents(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy document import: the JSON body is forwarded as-is.
    handler = service.import_documents
    return await _run_json(request, handler, prefix="导入文档失败")
|
|
|
|
|
|
@legacy_router.post("/document/upload/url")
async def dashboard_upload_document_from_url(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy import-from-URL: the JSON body is forwarded as-is.
    handler = service.upload_document_from_url
    return await _run_json(request, handler, prefix="从URL上传文档失败")
|
|
|
|
|
|
@legacy_router.get("/document/upload/progress")
async def dashboard_get_upload_progress(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy upload-progress lookup keyed by the task_id query parameter.
    def _progress():
        task_id = request.query_params.get("task_id")
        return service.get_upload_progress(task_id)

    return await _run(_progress, prefix="获取上传进度失败")
|
|
|
|
|
|
@legacy_router.get("/document/get")
async def dashboard_get_document(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy document detail lookup; kb_id and doc_id come from the query
    # string.
    def _fetch_document():
        query = request.query_params
        return service.get_document(
            kb_id=query.get("kb_id"),
            doc_id=query.get("doc_id"),
        )

    return await _run(_fetch_document, prefix="获取文档详情失败")
|
|
|
|
|
|
@legacy_router.post("/document/delete")
async def dashboard_delete_document(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy document delete: the JSON body is forwarded as-is.
    handler = service.delete_document
    return await _run_json(request, handler, prefix="删除文档失败")
|
|
|
|
|
|
@legacy_router.get("/chunk/list")
async def dashboard_list_chunks(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy chunk listing. Only the `doc_id` query param is honoured here
    # (the REST route also accepts `document_id`). page defaults to 1 and
    # page_size to 100.
    def _list_chunks():
        query = request.query_params
        return service.list_chunks(
            kb_id=query.get("kb_id"),
            doc_id=query.get("doc_id"),
            page=_to_int(query.get("page"), 1),
            page_size=_to_int(query.get("page_size"), 100),
        )

    return await _run(_list_chunks, prefix="获取块列表失败")
|
|
|
|
|
|
@legacy_router.post("/chunk/delete")
async def dashboard_delete_chunk(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy chunk delete: the JSON body is forwarded as-is.
    handler = service.delete_chunk
    return await _run_json(request, handler, prefix="删除文本块失败")
|
|
|
|
|
|
@legacy_router.post("/retrieve")
async def dashboard_retrieve(
    request: Request,
    _username: str = Depends(require_dashboard_user),
    service: KnowledgeBaseService = Depends(get_service),
):
    # Legacy retrieval: the JSON body (query + kb_id) is forwarded as-is.
    handler = service.retrieve
    return await _run_json(request, handler, prefix="检索失败")
|