mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-03 03:00:15 +08:00
Compare commits
6 Commits
dev
...
pr-5943-de
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f5541bc7e | ||
|
|
e7f57ae8ef | ||
|
|
0100f8d20c | ||
|
|
c1e2040f43 | ||
|
|
ae53b9fc9f | ||
|
|
63cbab610a |
@@ -11,6 +11,7 @@ from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
|
||||
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
|
||||
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
|
||||
@@ -18,6 +19,8 @@ from astrbot._internal.protocols.lsp.client import AstrbotLspClient
|
||||
from astrbot._internal.protocols.mcp.client import McpClient
|
||||
from astrbot._internal.stars import RuntimeStatusStar
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotOrchestrator(BaseAstrbotOrchestrator):
|
||||
"""
|
||||
|
||||
@@ -10,11 +10,14 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from asyncio import Queue
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core import LogBroker, LogManager
|
||||
@@ -43,6 +46,15 @@ from . import astrbot_config, html_renderer
|
||||
from .event_bus import EventBus
|
||||
|
||||
|
||||
class LifecycleState(str, Enum):
|
||||
"""Minimal lifecycle contract for split initialization."""
|
||||
|
||||
CREATED = "created"
|
||||
CORE_READY = "core_ready"
|
||||
RUNTIME_FAILED = "runtime_failed"
|
||||
RUNTIME_READY = "runtime_ready"
|
||||
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
"""AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作.
|
||||
|
||||
@@ -56,9 +68,36 @@ class AstrBotCoreLifecycle:
|
||||
self.astrbot_config = astrbot_config # 初始化配置
|
||||
self.db = db # 初始化数据库
|
||||
|
||||
self.umop_config_router: UmopConfigRouter | None = None
|
||||
self.astrbot_config_mgr: AstrBotConfigManager | None = None
|
||||
self.event_queue: Queue | None = None
|
||||
self.persona_mgr: PersonaManager | None = None
|
||||
self.provider_manager: ProviderManager | None = None
|
||||
self.platform_manager: PlatformManager | None = None
|
||||
self.conversation_manager: ConversationManager | None = None
|
||||
self.platform_message_history_manager: PlatformMessageHistoryManager | None = (
|
||||
None
|
||||
)
|
||||
self.kb_manager: KnowledgeBaseManager | None = None
|
||||
self.subagent_orchestrator: SubAgentOrchestrator | None = None
|
||||
self.cron_manager: CronJobManager | None = None
|
||||
self.temp_dir_cleaner: TempDirCleaner | None = None
|
||||
self.star_context: Context | None = None
|
||||
self.plugin_manager: PluginManager | None = None
|
||||
self.pipeline_scheduler_mapping: dict[str, PipelineScheduler] = {}
|
||||
self.astrbot_updator: AstrBotUpdator | None = None
|
||||
self.event_bus: EventBus | None = None
|
||||
self.dashboard_shutdown_event: asyncio.Event | None = None
|
||||
self.curr_tasks: list[asyncio.Task] = []
|
||||
self.metadata_update_task: asyncio.Task[None] | None = None
|
||||
self.start_time = 0
|
||||
self.runtime_bootstrap_task: asyncio.Task[None] | None = None
|
||||
self.runtime_bootstrap_error: BaseException | None = None
|
||||
self.runtime_ready_event = asyncio.Event()
|
||||
self.runtime_failed_event = asyncio.Event()
|
||||
self.runtime_request_ready = False
|
||||
self._runtime_wait_interrupted = False
|
||||
self._set_lifecycle_state(LifecycleState.CREATED)
|
||||
|
||||
# 设置代理
|
||||
proxy_config = self.astrbot_config.get("http_proxy", "")
|
||||
@@ -79,6 +118,18 @@ class AstrBotCoreLifecycle:
|
||||
del os.environ["no_proxy"]
|
||||
logger.debug("HTTP proxy cleared")
|
||||
|
||||
@property
|
||||
def core_initialized(self) -> bool:
|
||||
return self.lifecycle_state is not LifecycleState.CREATED
|
||||
|
||||
@property
|
||||
def runtime_ready(self) -> bool:
|
||||
return self.lifecycle_state is LifecycleState.RUNTIME_READY
|
||||
|
||||
@property
|
||||
def runtime_failed(self) -> bool:
|
||||
return self.lifecycle_state is LifecycleState.RUNTIME_FAILED
|
||||
|
||||
async def _init_or_reload_subagent_orchestrator(self) -> None:
|
||||
"""Create (if needed) and reload the subagent orchestrator from config.
|
||||
|
||||
@@ -86,10 +137,14 @@ class AstrBotCoreLifecycle:
|
||||
to manage enable/disable and tool registration details.
|
||||
"""
|
||||
try:
|
||||
if self.provider_manager is None or self.persona_mgr is None:
|
||||
raise RuntimeError("core dependencies are not initialized")
|
||||
provider_manager = self.provider_manager
|
||||
persona_mgr = self.persona_mgr
|
||||
if self.subagent_orchestrator is None:
|
||||
self.subagent_orchestrator = SubAgentOrchestrator(
|
||||
self.provider_manager.llm_tools,
|
||||
self.persona_mgr,
|
||||
provider_manager.llm_tools,
|
||||
persona_mgr,
|
||||
)
|
||||
await self.subagent_orchestrator.reload_from_config(
|
||||
self.astrbot_config.get("subagent_orchestrator", {}),
|
||||
@@ -97,11 +152,199 @@ class AstrBotCoreLifecycle:
|
||||
except Exception as e:
|
||||
logger.error(f"Subagent orchestrator init failed: {e}", exc_info=True)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化 AstrBot 核心生命周期管理类.
|
||||
def _set_lifecycle_state(self, state: LifecycleState) -> None:
|
||||
"""Update lifecycle state and keep readiness events in sync."""
|
||||
self.lifecycle_state = state
|
||||
if state is LifecycleState.RUNTIME_READY:
|
||||
self.runtime_ready_event.set()
|
||||
self.runtime_failed_event.clear()
|
||||
elif state is LifecycleState.RUNTIME_FAILED:
|
||||
self.runtime_ready_event.clear()
|
||||
self.runtime_failed_event.set()
|
||||
else:
|
||||
self.runtime_ready_event.clear()
|
||||
self.runtime_failed_event.clear()
|
||||
|
||||
def _clear_runtime_failure_for_retry(self) -> None:
|
||||
if self.lifecycle_state is LifecycleState.RUNTIME_FAILED:
|
||||
self._set_lifecycle_state(LifecycleState.CORE_READY)
|
||||
|
||||
async def _cleanup_partial_runtime_bootstrap(self) -> None:
|
||||
if self.star_context is not None and hasattr(
|
||||
self.star_context,
|
||||
"reset_runtime_registrations",
|
||||
):
|
||||
self.star_context.reset_runtime_registrations()
|
||||
if self.plugin_manager is not None and hasattr(
|
||||
self.plugin_manager,
|
||||
"cleanup_loaded_plugins",
|
||||
):
|
||||
try:
|
||||
cleanup_loaded_plugins = getattr(
|
||||
self.plugin_manager,
|
||||
"cleanup_loaded_plugins",
|
||||
)
|
||||
result = cleanup_loaded_plugins()
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"Failed to clean up loaded plugin state: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
for manager in (self.platform_manager, self.kb_manager, self.provider_manager):
|
||||
if manager is None:
|
||||
continue
|
||||
try:
|
||||
terminate = getattr(manager, "terminate", None)
|
||||
if not callable(terminate):
|
||||
continue
|
||||
result = terminate()
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"Failed to clean up partial runtime bootstrap state: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
self._clear_runtime_artifacts()
|
||||
|
||||
def _reset_runtime_bootstrap_state(self) -> None:
|
||||
self.runtime_bootstrap_task = None
|
||||
self.runtime_bootstrap_error = None
|
||||
|
||||
def _interrupt_runtime_bootstrap_waiters(self) -> None:
|
||||
self._runtime_wait_interrupted = True
|
||||
self.runtime_bootstrap_error = None
|
||||
self.runtime_failed_event.set()
|
||||
|
||||
async def _consume_completed_bootstrap_task(self) -> None:
|
||||
task = self.runtime_bootstrap_task
|
||||
if task is None or not task.done():
|
||||
return
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _wait_for_runtime_ready(self) -> bool:
|
||||
if self.runtime_ready:
|
||||
return True
|
||||
if self._runtime_wait_interrupted:
|
||||
return False
|
||||
if self.runtime_failed or self.runtime_bootstrap_error is not None:
|
||||
await self._consume_completed_bootstrap_task()
|
||||
return False
|
||||
|
||||
runtime_bootstrap_task = self.runtime_bootstrap_task
|
||||
if runtime_bootstrap_task is None:
|
||||
raise RuntimeError(
|
||||
"runtime bootstrap task was not scheduled before start",
|
||||
)
|
||||
|
||||
try:
|
||||
await runtime_bootstrap_task
|
||||
except asyncio.CancelledError:
|
||||
return False
|
||||
except BaseException as exc:
|
||||
if self.runtime_bootstrap_error is None:
|
||||
self.runtime_bootstrap_error = exc
|
||||
if not self.runtime_failed:
|
||||
self._set_lifecycle_state(LifecycleState.RUNTIME_FAILED)
|
||||
return False
|
||||
|
||||
if self._runtime_wait_interrupted:
|
||||
return False
|
||||
|
||||
return self.runtime_ready
|
||||
|
||||
def _collect_runtime_bootstrap_task(self) -> list[asyncio.Task]:
|
||||
task = self.runtime_bootstrap_task
|
||||
self.runtime_bootstrap_task = None
|
||||
if task is None:
|
||||
return []
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
return [task]
|
||||
|
||||
def _collect_metadata_update_task(self) -> list[asyncio.Task]:
|
||||
task = self.metadata_update_task
|
||||
self.metadata_update_task = None
|
||||
if task is None:
|
||||
return []
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
return [task]
|
||||
|
||||
async def _await_tasks(self, tasks: list[asyncio.Task]) -> None:
|
||||
for task in tasks:
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
||||
|
||||
def _require_runtime_bootstrap_components(
|
||||
self,
|
||||
) -> tuple[PluginManager, ProviderManager, KnowledgeBaseManager, PlatformManager]:
|
||||
if (
|
||||
self.plugin_manager is None
|
||||
or self.provider_manager is None
|
||||
or self.kb_manager is None
|
||||
or self.platform_manager is None
|
||||
):
|
||||
raise RuntimeError("initialize_core must complete before runtime bootstrap")
|
||||
return (
|
||||
self.plugin_manager,
|
||||
self.provider_manager,
|
||||
self.kb_manager,
|
||||
self.platform_manager,
|
||||
)
|
||||
|
||||
def _require_runtime_started_components(self) -> tuple[EventBus, Context]:
|
||||
if self.lifecycle_state is not LifecycleState.RUNTIME_READY:
|
||||
raise RuntimeError("LifecycleState.RUNTIME_READY required before start")
|
||||
if self.event_bus is None or self.star_context is None:
|
||||
raise RuntimeError("runtime bootstrap must complete before start")
|
||||
return self.event_bus, self.star_context
|
||||
|
||||
def _cancel_current_tasks(self) -> list[asyncio.Task]:
|
||||
tasks_to_wait: list[asyncio.Task] = []
|
||||
for task in self.curr_tasks:
|
||||
task.cancel()
|
||||
if isinstance(task, asyncio.Task):
|
||||
tasks_to_wait.append(task)
|
||||
self.curr_tasks = []
|
||||
return tasks_to_wait
|
||||
|
||||
def _clear_runtime_artifacts(self) -> None:
|
||||
self.metadata_update_task = None
|
||||
self.runtime_request_ready = False
|
||||
self.event_bus = None
|
||||
self.pipeline_scheduler_mapping = {}
|
||||
self.curr_tasks = []
|
||||
self.start_time = 0
|
||||
|
||||
def _require_core_ready(self) -> None:
|
||||
if not self.core_initialized:
|
||||
raise RuntimeError("initialize_core must complete before this operation")
|
||||
|
||||
def _require_platform_manager(self) -> PlatformManager:
|
||||
if self.platform_manager is None:
|
||||
raise RuntimeError("platform manager is not initialized")
|
||||
return self.platform_manager
|
||||
|
||||
async def initialize_core(self) -> None:
|
||||
"""Initialize the fast core phase without runtime bootstrap."""
|
||||
if self.core_initialized:
|
||||
return
|
||||
|
||||
self._runtime_wait_interrupted = False
|
||||
self._reset_runtime_bootstrap_state()
|
||||
|
||||
负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
"""
|
||||
# 初始化日志代理
|
||||
logger.info("AstrBot v" + VERSION)
|
||||
if os.environ.get("TESTING", ""):
|
||||
@@ -127,8 +370,11 @@ class AstrBotCoreLifecycle:
|
||||
ucr=self.umop_config_router,
|
||||
sp=sp,
|
||||
)
|
||||
if self.astrbot_config_mgr is None:
|
||||
raise RuntimeError("config manager initialization failed")
|
||||
astrbot_config_mgr = self.astrbot_config_mgr
|
||||
self.temp_dir_cleaner = TempDirCleaner(
|
||||
max_size_getter=lambda: self.astrbot_config_mgr.default_conf.get(
|
||||
max_size_getter=lambda: astrbot_config_mgr.default_conf.get(
|
||||
TempDirCleaner.CONFIG_KEY,
|
||||
TempDirCleaner.DEFAULT_MAX_SIZE,
|
||||
),
|
||||
@@ -197,53 +443,100 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化插件管理器
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
# 扫描、注册插件、实例化插件类
|
||||
await self.plugin_manager.reload()
|
||||
|
||||
# 根据配置实例化各个 Provider
|
||||
await self.provider_manager.initialize()
|
||||
|
||||
await self.kb_manager.initialize()
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
# 初始化更新器
|
||||
# 为提前启动 Dashboard 准备核心依赖
|
||||
self.astrbot_updator = AstrBotUpdator()
|
||||
|
||||
# 初始化事件总线
|
||||
self.event_bus = EventBus(
|
||||
self.event_queue,
|
||||
self.pipeline_scheduler_mapping,
|
||||
self.astrbot_config_mgr,
|
||||
)
|
||||
|
||||
# 记录启动时间
|
||||
self.start_time = int(time.time())
|
||||
|
||||
# 初始化当前任务列表
|
||||
self.curr_tasks: list[asyncio.Task] = []
|
||||
|
||||
# 根据配置实例化各个平台适配器
|
||||
await self.platform_manager.initialize()
|
||||
|
||||
# 初始化关闭控制面板的事件
|
||||
self.dashboard_shutdown_event = asyncio.Event()
|
||||
|
||||
asyncio.create_task(update_llm_metadata()) # noqa: RUF006
|
||||
self._set_lifecycle_state(LifecycleState.CORE_READY)
|
||||
|
||||
async def bootstrap_runtime(self) -> None:
|
||||
"""Complete deferred runtime bootstrap after core initialization."""
|
||||
if not self.core_initialized:
|
||||
raise RuntimeError(
|
||||
"initialize_core must be called before bootstrap_runtime",
|
||||
)
|
||||
if self.runtime_ready:
|
||||
return
|
||||
|
||||
self._clear_runtime_failure_for_retry()
|
||||
self.runtime_bootstrap_error = None
|
||||
self.runtime_ready_event.clear()
|
||||
self.runtime_failed_event.clear()
|
||||
|
||||
try:
|
||||
plugin_manager, provider_manager, kb_manager, platform_manager = (
|
||||
self._require_runtime_bootstrap_components()
|
||||
)
|
||||
|
||||
# 扫描、注册插件、实例化插件类
|
||||
await plugin_manager.reload()
|
||||
|
||||
# 根据配置实例化各个 Provider
|
||||
await provider_manager.initialize()
|
||||
|
||||
await kb_manager.initialize()
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
if self.event_queue is None or self.astrbot_config_mgr is None:
|
||||
raise RuntimeError(
|
||||
"initialize_core must complete before runtime bootstrap",
|
||||
)
|
||||
|
||||
# 初始化事件总线
|
||||
self.event_bus = EventBus(
|
||||
self.event_queue,
|
||||
self.pipeline_scheduler_mapping,
|
||||
self.astrbot_config_mgr,
|
||||
)
|
||||
|
||||
# 记录启动时间
|
||||
self.start_time = int(time.time())
|
||||
|
||||
# 初始化当前任务列表
|
||||
self.curr_tasks = []
|
||||
|
||||
# 根据配置实例化各个平台适配器
|
||||
await platform_manager.initialize()
|
||||
|
||||
self.metadata_update_task = asyncio.create_task(update_llm_metadata())
|
||||
|
||||
self._set_lifecycle_state(LifecycleState.RUNTIME_READY)
|
||||
except asyncio.CancelledError:
|
||||
await self._cleanup_partial_runtime_bootstrap()
|
||||
self._set_lifecycle_state(LifecycleState.CORE_READY)
|
||||
self.runtime_bootstrap_error = None
|
||||
raise
|
||||
except BaseException as exc:
|
||||
await self._cleanup_partial_runtime_bootstrap()
|
||||
self._set_lifecycle_state(LifecycleState.RUNTIME_FAILED)
|
||||
self.runtime_bootstrap_error = exc
|
||||
raise
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化 AstrBot 核心生命周期管理类.
|
||||
|
||||
负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
"""
|
||||
await self.initialize_core()
|
||||
await self.bootstrap_runtime()
|
||||
self.runtime_request_ready = True
|
||||
|
||||
def _load(self) -> None:
|
||||
"""加载事件总线和任务并初始化."""
|
||||
event_bus, star_context = self._require_runtime_started_components()
|
||||
|
||||
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
|
||||
event_bus_task = asyncio.create_task(
|
||||
self.event_bus.dispatch(),
|
||||
event_bus.dispatch(),
|
||||
name="event_bus",
|
||||
)
|
||||
cron_task = None
|
||||
if self.cron_manager:
|
||||
cron_task = asyncio.create_task(
|
||||
self.cron_manager.start(self.star_context),
|
||||
self.cron_manager.start(star_context),
|
||||
name="cron_manager",
|
||||
)
|
||||
temp_dir_cleaner_task = None
|
||||
@@ -254,9 +547,9 @@ class AstrBotCoreLifecycle:
|
||||
)
|
||||
|
||||
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||
extra_tasks = []
|
||||
if self.star_context._register_tasks is not None:
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks: list[asyncio.Task[Any]] = []
|
||||
if star_context._register_tasks is not None:
|
||||
for task in star_context._register_tasks:
|
||||
task_name = getattr(task, "__name__", task.__class__.__name__)
|
||||
extra_tasks.append(asyncio.create_task(task, name=task_name))
|
||||
|
||||
@@ -295,6 +588,18 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
用load加载事件总线和任务并初始化, 执行启动完成事件钩子
|
||||
"""
|
||||
if not await self._wait_for_runtime_ready():
|
||||
if self._runtime_wait_interrupted:
|
||||
return
|
||||
error = self.runtime_bootstrap_error
|
||||
if error is None:
|
||||
logger.error("AstrBot runtime bootstrap failed before start completed.")
|
||||
else:
|
||||
logger.error(
|
||||
f"AstrBot runtime bootstrap failed before start completed: {error}",
|
||||
)
|
||||
return
|
||||
|
||||
self._load()
|
||||
logger.info("AstrBot 启动完成。")
|
||||
|
||||
@@ -311,50 +616,59 @@ class AstrBotCoreLifecycle:
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
self.runtime_request_ready = True
|
||||
|
||||
# 同时运行curr_tasks中的所有任务
|
||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||
|
||||
async def _shutdown_runtime(self) -> None:
|
||||
self.runtime_request_ready = False
|
||||
self._interrupt_runtime_bootstrap_waiters()
|
||||
|
||||
tasks_to_wait = self._cancel_current_tasks()
|
||||
await self._await_tasks(self._collect_metadata_update_task())
|
||||
runtime_bootstrap_tasks = self._collect_runtime_bootstrap_task()
|
||||
await self._await_tasks(runtime_bootstrap_tasks)
|
||||
tasks_to_wait.extend(runtime_bootstrap_tasks)
|
||||
|
||||
if self.cron_manager:
|
||||
await self.cron_manager.shutdown()
|
||||
|
||||
if self.plugin_manager and self.plugin_manager.context:
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
try:
|
||||
await self.plugin_manager._terminate_plugin(plugin)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。",
|
||||
)
|
||||
|
||||
if self.provider_manager:
|
||||
await self.provider_manager.terminate()
|
||||
if self.platform_manager:
|
||||
await self.platform_manager.terminate()
|
||||
if self.kb_manager:
|
||||
await self.kb_manager.terminate()
|
||||
if self.dashboard_shutdown_event:
|
||||
self.dashboard_shutdown_event.set()
|
||||
|
||||
self._clear_runtime_artifacts()
|
||||
self._set_lifecycle_state(LifecycleState.CREATED)
|
||||
self._reset_runtime_bootstrap_state()
|
||||
await self._await_tasks(tasks_to_wait)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器."""
|
||||
if self.temp_dir_cleaner:
|
||||
await self.temp_dir_cleaner.stop()
|
||||
|
||||
# 请求停止所有正在运行的异步任务
|
||||
for task in self.curr_tasks:
|
||||
task.cancel()
|
||||
|
||||
if self.cron_manager:
|
||||
await self.cron_manager.shutdown()
|
||||
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
try:
|
||||
await self.plugin_manager._terminate_plugin(plugin)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。",
|
||||
)
|
||||
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
await self.kb_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
|
||||
# 再次遍历curr_tasks等待每个任务真正结束
|
||||
for task in self.curr_tasks:
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
||||
await self._shutdown_runtime()
|
||||
|
||||
async def restart(self) -> None:
|
||||
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
await self.kb_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
await self._shutdown_runtime()
|
||||
if self.astrbot_updator is None:
|
||||
return
|
||||
threading.Thread(
|
||||
target=self.astrbot_updator._reboot,
|
||||
name="restart",
|
||||
@@ -364,7 +678,7 @@ class AstrBotCoreLifecycle:
|
||||
def load_platform(self) -> list[asyncio.Task]:
|
||||
"""加载平台实例并返回所有平台实例的异步任务列表"""
|
||||
tasks = []
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
platform_insts = self._require_platform_manager().get_insts()
|
||||
for platform_inst in platform_insts:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
@@ -382,9 +696,14 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
"""
|
||||
mapping = {}
|
||||
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
|
||||
self._require_core_ready()
|
||||
assert self.astrbot_config_mgr is not None
|
||||
assert self.plugin_manager is not None
|
||||
astrbot_config_mgr = self.astrbot_config_mgr
|
||||
plugin_manager = self.plugin_manager
|
||||
for conf_id, ab_config in astrbot_config_mgr.confs.items():
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id),
|
||||
PipelineContext(ab_config, plugin_manager, conf_id),
|
||||
)
|
||||
await scheduler.initialize()
|
||||
mapping[conf_id] = scheduler
|
||||
@@ -397,11 +716,16 @@ class AstrBotCoreLifecycle:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
|
||||
"""
|
||||
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
|
||||
self._require_core_ready()
|
||||
assert self.astrbot_config_mgr is not None
|
||||
astrbot_config_mgr = self.astrbot_config_mgr
|
||||
ab_config = astrbot_config_mgr.confs.get(conf_id)
|
||||
if not ab_config:
|
||||
raise ValueError(f"配置文件 {conf_id} 不存在")
|
||||
assert self.plugin_manager is not None
|
||||
plugin_manager = self.plugin_manager
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id),
|
||||
PipelineContext(ab_config, plugin_manager, conf_id),
|
||||
)
|
||||
await scheduler.initialize()
|
||||
self.pipeline_scheduler_mapping[conf_id] = scheduler
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
from astrbot.core import LogBroker, logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
@@ -27,20 +28,28 @@ class InitialLoader:
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
|
||||
try:
|
||||
await core_lifecycle.initialize()
|
||||
await core_lifecycle.initialize_core()
|
||||
except Exception as e:
|
||||
logger.critical(traceback.format_exc())
|
||||
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||
return
|
||||
|
||||
core_lifecycle.runtime_bootstrap_task = asyncio.create_task(
|
||||
core_lifecycle.bootstrap_runtime(),
|
||||
)
|
||||
|
||||
core_task = core_lifecycle.start()
|
||||
shutdown_event = core_lifecycle.dashboard_shutdown_event
|
||||
if shutdown_event is None:
|
||||
raise RuntimeError("initialize_core must set dashboard_shutdown_event")
|
||||
shutdown_event = cast(asyncio.Event, shutdown_event)
|
||||
|
||||
webui_dir = self.webui_dir
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle,
|
||||
self.db,
|
||||
core_lifecycle.dashboard_shutdown_event,
|
||||
shutdown_event,
|
||||
webui_dir,
|
||||
)
|
||||
|
||||
@@ -55,3 +64,6 @@ class InitialLoader:
|
||||
except asyncio.CancelledError:
|
||||
logger.info("🌈 正在关闭 AstrBot...")
|
||||
await core_lifecycle.stop()
|
||||
except Exception:
|
||||
await core_lifecycle.stop()
|
||||
raise
|
||||
|
||||
@@ -742,25 +742,26 @@ class ProviderManager:
|
||||
logger.info(
|
||||
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...",
|
||||
)
|
||||
provider_inst = self.inst_map[provider_id]
|
||||
|
||||
if self.inst_map[provider_id] in self.provider_insts:
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if provider_inst in self.provider_insts:
|
||||
prov_inst = provider_inst
|
||||
if isinstance(prov_inst, Provider):
|
||||
self.provider_insts.remove(prov_inst)
|
||||
if self.inst_map[provider_id] in self.stt_provider_insts:
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if provider_inst in self.stt_provider_insts:
|
||||
prov_inst = provider_inst
|
||||
if isinstance(prov_inst, STTProvider):
|
||||
self.stt_provider_insts.remove(prov_inst)
|
||||
if self.inst_map[provider_id] in self.tts_provider_insts:
|
||||
prov_inst = self.inst_map[provider_id]
|
||||
if provider_inst in self.tts_provider_insts:
|
||||
prov_inst = provider_inst
|
||||
if isinstance(prov_inst, TTSProvider):
|
||||
self.tts_provider_insts.remove(prov_inst)
|
||||
|
||||
if self.inst_map[provider_id] == self.curr_provider_inst:
|
||||
if provider_inst == self.curr_provider_inst:
|
||||
self.curr_provider_inst = None
|
||||
if self.inst_map[provider_id] == self.curr_stt_provider_inst:
|
||||
if provider_inst == self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = None
|
||||
if self.inst_map[provider_id] == self.curr_tts_provider_inst:
|
||||
if provider_inst == self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = None
|
||||
|
||||
inst = self.inst_map[provider_id]
|
||||
@@ -836,6 +837,35 @@ class ProviderManager:
|
||||
# sync in-memory config for API queries (e.g., embedding provider list)
|
||||
self.providers_config = astrbot_config["provider"]
|
||||
|
||||
def _get_all_provider_instances(self) -> list[Providers]:
|
||||
seen: set[int] = set()
|
||||
instances: list[Providers] = []
|
||||
for provider_inst in [
|
||||
*self.provider_insts,
|
||||
*self.stt_provider_insts,
|
||||
*self.tts_provider_insts,
|
||||
*self.embedding_provider_insts,
|
||||
*self.rerank_provider_insts,
|
||||
*self.inst_map.values(),
|
||||
]:
|
||||
marker = id(provider_inst)
|
||||
if marker in seen:
|
||||
continue
|
||||
seen.add(marker)
|
||||
instances.append(provider_inst)
|
||||
return instances
|
||||
|
||||
def _clear_loaded_instances(self) -> None:
|
||||
self.provider_insts = []
|
||||
self.stt_provider_insts = []
|
||||
self.tts_provider_insts = []
|
||||
self.embedding_provider_insts = []
|
||||
self.rerank_provider_insts = []
|
||||
self.inst_map = {}
|
||||
self.curr_provider_inst = None
|
||||
self.curr_stt_provider_inst = None
|
||||
self.curr_tts_provider_inst = None
|
||||
|
||||
async def terminate(self) -> None:
|
||||
if self._mcp_init_task and not self._mcp_init_task.done():
|
||||
self._mcp_init_task.cancel()
|
||||
@@ -844,9 +874,20 @@ class ProviderManager:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
for provider_inst in self.provider_insts:
|
||||
if isinstance(provider_inst, SupportsTerminate):
|
||||
self._mcp_init_task = None
|
||||
provider_instances = self._get_all_provider_instances()
|
||||
self._clear_loaded_instances()
|
||||
|
||||
for provider_inst in provider_instances:
|
||||
if not isinstance(provider_inst, SupportsTerminate):
|
||||
continue
|
||||
try:
|
||||
await provider_inst.terminate()
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Error while terminating provider instance",
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
await self.llm_tools.disable_mcp_server()
|
||||
except Exception:
|
||||
|
||||
@@ -84,6 +84,8 @@ class Context:
|
||||
cron_manager: CronJobManager,
|
||||
subagent_orchestrator: SubAgentOrchestrator | None = None,
|
||||
) -> None:
|
||||
self.registered_web_apis = []
|
||||
self._register_tasks = []
|
||||
self._event_queue = event_queue
|
||||
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
||||
self._config = config
|
||||
@@ -109,10 +111,16 @@ class Context:
|
||||
self.subagent_orchestrator = subagent_orchestrator
|
||||
|
||||
# Register built-in tools so they appear in WebUI and can be
|
||||
# assigned to subagents. Done here (not at module-import time)
|
||||
# assigned to subagents. Done here (not at module-import time)
|
||||
# to avoid circular imports.
|
||||
self.provider_manager.llm_tools.register_internal_tools()
|
||||
|
||||
def reset_runtime_registrations(self) -> None:
|
||||
if self.registered_web_apis is not None:
|
||||
self.registered_web_apis.clear()
|
||||
if self._register_tasks is not None:
|
||||
self._register_tasks.clear()
|
||||
|
||||
async def llm_generate(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -108,9 +108,9 @@ async def _temporary_filtered_requirements_file(
|
||||
try:
|
||||
yield filtered_requirements_path
|
||||
finally:
|
||||
if filtered_requirements_path and os.path.exists(
|
||||
if filtered_requirements_path and await anyio.Path(
|
||||
filtered_requirements_path
|
||||
):
|
||||
).exists():
|
||||
try:
|
||||
await to_thread.run_sync(os.remove, filtered_requirements_path)
|
||||
except OSError as exc:
|
||||
@@ -742,6 +742,24 @@ class PluginManager:
|
||||
|
||||
return result
|
||||
|
||||
async def cleanup_loaded_plugins(self) -> None:
|
||||
"""Terminate and unbind all currently loaded plugins without reloading."""
|
||||
async with self._pm_lock:
|
||||
for smd in list(star_registry):
|
||||
try:
|
||||
await self._terminate_plugin(smd)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {smd.name} 未被正常终止: {e!s}, 可能会导致该插件运行不正常。",
|
||||
)
|
||||
if smd.name and smd.module_path:
|
||||
await self._unbind_plugin(smd.name, smd.module_path)
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
|
||||
async def load(
|
||||
self,
|
||||
specified_module_path=None,
|
||||
|
||||
@@ -24,7 +24,26 @@ from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queu
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
|
||||
from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
|
||||
from .route import Route, RouteContext
|
||||
from .route import (
|
||||
Route,
|
||||
RouteContext,
|
||||
get_runtime_guard_message,
|
||||
is_runtime_request_ready,
|
||||
)
|
||||
|
||||
|
||||
class _QueueTimeoutSentinel:
|
||||
pass
|
||||
|
||||
|
||||
_QUEUE_TIMEOUT = _QueueTimeoutSentinel()
|
||||
|
||||
|
||||
class _ReceiveTimeoutSentinel:
|
||||
pass
|
||||
|
||||
|
||||
_RECEIVE_TIMEOUT = _ReceiveTimeoutSentinel()
|
||||
|
||||
|
||||
class LiveChatSession:
|
||||
@@ -137,6 +156,49 @@ class LiveChatRoute(Route):
|
||||
"""Unified Chat WebSocket 处理器(支持 ct=live/chat)"""
|
||||
await self._unified_ws_loop(force_ct=None)
|
||||
|
||||
async def _ensure_runtime_ready(self) -> bool:
|
||||
if is_runtime_request_ready(self.core_lifecycle):
|
||||
return True
|
||||
await websocket.close(
|
||||
1013,
|
||||
get_runtime_guard_message(self.core_lifecycle),
|
||||
)
|
||||
return False
|
||||
|
||||
async def _recv_ws_json_guarded(
|
||||
self,
|
||||
*,
|
||||
wait_timeout: float = 1.0,
|
||||
) -> dict[str, Any] | _ReceiveTimeoutSentinel | None:
|
||||
if not await self._ensure_runtime_ready():
|
||||
return None
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
websocket.receive_json(),
|
||||
timeout=wait_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return _RECEIVE_TIMEOUT
|
||||
if not await self._ensure_runtime_ready():
|
||||
return None
|
||||
return message
|
||||
|
||||
async def _guarded_queue_get(
|
||||
self,
|
||||
back_queue: asyncio.Queue,
|
||||
*,
|
||||
wait_timeout: float,
|
||||
) -> dict[str, Any] | _QueueTimeoutSentinel | None:
|
||||
if not await self._ensure_runtime_ready():
|
||||
return None
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=wait_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return _QUEUE_TIMEOUT
|
||||
if not await self._ensure_runtime_ready():
|
||||
return None
|
||||
return result
|
||||
|
||||
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
|
||||
"""统一 WebSocket 循环"""
|
||||
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
|
||||
@@ -157,6 +219,9 @@ class LiveChatRoute(Route):
|
||||
await websocket.close(1008, "Invalid token")
|
||||
return
|
||||
|
||||
if not await self._ensure_runtime_ready():
|
||||
return
|
||||
|
||||
session_id = f"webchat_live!{username}!{uuid.uuid4()}"
|
||||
live_session = LiveChatSession(session_id, username)
|
||||
self.sessions[session_id] = live_session
|
||||
@@ -165,7 +230,11 @@ class LiveChatRoute(Route):
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
message = await self._recv_ws_json_guarded()
|
||||
if isinstance(message, _ReceiveTimeoutSentinel):
|
||||
continue
|
||||
if message is None:
|
||||
return
|
||||
ct = force_ct or message.get("ct", "live")
|
||||
if ct == "chat":
|
||||
await self._handle_chat_message(live_session, message)
|
||||
@@ -289,7 +358,11 @@ class LiveChatRoute(Route):
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
result = await back_queue.get()
|
||||
result = await self._guarded_queue_get(back_queue, wait_timeout=1)
|
||||
if isinstance(result, _QueueTimeoutSentinel):
|
||||
continue
|
||||
if result is None:
|
||||
break
|
||||
if not result:
|
||||
continue
|
||||
await self._send_chat_payload(session, {"ct": "chat", **result})
|
||||
@@ -486,14 +559,17 @@ class LiveChatRoute(Route):
|
||||
refs = {}
|
||||
|
||||
while True:
|
||||
if not await self._ensure_runtime_ready():
|
||||
break
|
||||
if session.should_interrupt:
|
||||
session.should_interrupt = False
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
result = await self._guarded_queue_get(back_queue, wait_timeout=1)
|
||||
if isinstance(result, _QueueTimeoutSentinel):
|
||||
continue
|
||||
if result is None:
|
||||
break
|
||||
|
||||
if not result:
|
||||
continue
|
||||
@@ -773,6 +849,8 @@ class LiveChatRoute(Route):
|
||||
|
||||
try:
|
||||
while True:
|
||||
if not await self._ensure_runtime_ready():
|
||||
break
|
||||
if session.should_interrupt:
|
||||
# 用户打断,停止处理
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
@@ -789,10 +867,14 @@ class LiveChatRoute(Route):
|
||||
break
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
result = await self._guarded_queue_get(
|
||||
back_queue,
|
||||
wait_timeout=0.5,
|
||||
)
|
||||
if isinstance(result, _QueueTimeoutSentinel):
|
||||
continue
|
||||
if result is None:
|
||||
break
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
@@ -19,7 +19,13 @@ from astrbot.core.utils.datetime_utils import to_utc_isoformat
|
||||
|
||||
from .api_key import ALL_OPEN_API_SCOPES
|
||||
from .chat import ChatRoute
|
||||
from .route import Response, Route, RouteContext
|
||||
from .route import (
|
||||
Response,
|
||||
Route,
|
||||
RouteContext,
|
||||
get_runtime_guard_message,
|
||||
is_runtime_request_ready,
|
||||
)
|
||||
|
||||
|
||||
class OpenApiRoute(Route):
|
||||
@@ -244,6 +250,14 @@ class OpenApiRoute(Route):
|
||||
}
|
||||
)
|
||||
|
||||
async def _ensure_runtime_ready(self) -> bool:
|
||||
if is_runtime_request_ready(self.core_lifecycle):
|
||||
return True
|
||||
message = get_runtime_guard_message(self.core_lifecycle)
|
||||
await self._send_chat_ws_error(message, "RUNTIME_NOT_READY")
|
||||
await websocket.close(1013, message)
|
||||
return False
|
||||
|
||||
async def _update_session_config_route(
|
||||
self,
|
||||
*,
|
||||
@@ -370,11 +384,16 @@ class OpenApiRoute(Route):
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
while True:
|
||||
if not await self._ensure_runtime_ready():
|
||||
return
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not await self._ensure_runtime_ready():
|
||||
return
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
@@ -512,9 +531,16 @@ class OpenApiRoute(Route):
|
||||
await websocket.close(1008, auth_err or "Unauthorized")
|
||||
return
|
||||
|
||||
if not await self._ensure_runtime_ready():
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
if not await self._ensure_runtime_ready():
|
||||
return
|
||||
message = await websocket.receive_json()
|
||||
if not await self._ensure_runtime_ready():
|
||||
return
|
||||
if not isinstance(message, dict):
|
||||
await self._send_chat_ws_error(
|
||||
"message must be an object",
|
||||
|
||||
@@ -30,12 +30,33 @@ from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_temp_path,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
from .route import Response, Route, RouteContext, guard_runtime_ready
|
||||
|
||||
PLUGIN_UPDATE_CONCURRENCY = (
|
||||
3 # limit concurrent updates to avoid overwhelming plugin sources
|
||||
)
|
||||
|
||||
PLUGIN_ROUTE_DEFINITIONS = (
|
||||
("/plugin/get", "GET", "get_plugins", True),
|
||||
("/plugin/check-compat", "POST", "check_plugin_compatibility", False),
|
||||
("/plugin/install", "POST", "install_plugin", True),
|
||||
("/plugin/install-upload", "POST", "install_plugin_upload", True),
|
||||
("/plugin/update", "POST", "update_plugin", True),
|
||||
("/plugin/update-all", "POST", "update_all_plugins", True),
|
||||
("/plugin/uninstall", "POST", "uninstall_plugin", True),
|
||||
("/plugin/uninstall-failed", "POST", "uninstall_failed_plugin", False),
|
||||
("/plugin/market_list", "GET", "get_online_plugins", False),
|
||||
("/plugin/off", "POST", "off_plugin", True),
|
||||
("/plugin/on", "POST", "on_plugin", True),
|
||||
("/plugin/reload-failed", "POST", "reload_failed_plugins", False),
|
||||
("/plugin/reload", "POST", "reload_plugins", True),
|
||||
("/plugin/readme", "GET", "get_plugin_readme", True),
|
||||
("/plugin/changelog", "GET", "get_plugin_changelog", True),
|
||||
("/plugin/source/get", "GET", "get_custom_source", False),
|
||||
("/plugin/source/save", "POST", "save_custom_source", False),
|
||||
("/plugin/source/get-failed-plugins", "GET", "get_failed_plugins", False),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistrySource:
|
||||
@@ -52,28 +73,18 @@ class PluginRoute(Route):
|
||||
plugin_manager: PluginManager,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/plugin/get": ("GET", self.get_plugins),
|
||||
"/plugin/check-compat": ("POST", self.check_plugin_compatibility),
|
||||
"/plugin/install": ("POST", self.install_plugin),
|
||||
"/plugin/install-upload": ("POST", self.install_plugin_upload),
|
||||
"/plugin/update": ("POST", self.update_plugin),
|
||||
"/plugin/update-all": ("POST", self.update_all_plugins),
|
||||
"/plugin/uninstall": ("POST", self.uninstall_plugin),
|
||||
"/plugin/uninstall-failed": ("POST", self.uninstall_failed_plugin),
|
||||
"/plugin/market_list": ("GET", self.get_online_plugins),
|
||||
"/plugin/off": ("POST", self.off_plugin),
|
||||
"/plugin/on": ("POST", self.on_plugin),
|
||||
"/plugin/reload-failed": ("POST", self.reload_failed_plugins),
|
||||
"/plugin/reload": ("POST", self.reload_plugins),
|
||||
"/plugin/readme": ("GET", self.get_plugin_readme),
|
||||
"/plugin/changelog": ("GET", self.get_plugin_changelog),
|
||||
"/plugin/source/get": ("GET", self.get_custom_source),
|
||||
"/plugin/source/save": ("POST", self.save_custom_source),
|
||||
"/plugin/source/get-failed-plugins": ("GET", self.get_failed_plugins),
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.plugin_manager = plugin_manager
|
||||
self._guard_runtime_ready = lambda handler: guard_runtime_ready(
|
||||
self.core_lifecycle,
|
||||
handler,
|
||||
)
|
||||
self.routes = {}
|
||||
for path, method, handler_name, requires_runtime in PLUGIN_ROUTE_DEFINITIONS:
|
||||
handler = getattr(self, handler_name)
|
||||
if requires_runtime:
|
||||
handler = self._guard_runtime_ready(handler)
|
||||
self.routes[path] = (method, handler)
|
||||
self.register_routes()
|
||||
|
||||
self.translated_event_type = {
|
||||
|
||||
@@ -1,9 +1,90 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from quart import Quart
|
||||
from quart import Quart, jsonify
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
|
||||
RUNTIME_LOADING_MESSAGE = "Runtime is still loading. Please try again shortly."
|
||||
RUNTIME_FAILED_MESSAGE = "Runtime bootstrap failed. Please check logs and retry."
|
||||
|
||||
|
||||
def is_runtime_request_ready(core_lifecycle: "AstrBotCoreLifecycle") -> bool:
|
||||
return getattr(core_lifecycle, "runtime_request_ready", core_lifecycle.runtime_ready)
|
||||
|
||||
|
||||
def get_runtime_guard_message(core_lifecycle: "AstrBotCoreLifecycle") -> str:
|
||||
failed = (
|
||||
core_lifecycle.runtime_failed
|
||||
or core_lifecycle.runtime_bootstrap_error is not None
|
||||
)
|
||||
return RUNTIME_FAILED_MESSAGE if failed else RUNTIME_LOADING_MESSAGE
|
||||
|
||||
|
||||
def build_runtime_status_data(
|
||||
core_lifecycle: "AstrBotCoreLifecycle",
|
||||
*,
|
||||
include_failure_details: bool = True,
|
||||
) -> dict[str, str | bool | None]:
|
||||
failure_message = None
|
||||
if include_failure_details and core_lifecycle.runtime_bootstrap_error is not None:
|
||||
failure_message = str(core_lifecycle.runtime_bootstrap_error)
|
||||
return {
|
||||
"state": core_lifecycle.lifecycle_state.value,
|
||||
"ready": is_runtime_request_ready(core_lifecycle),
|
||||
"failed": core_lifecycle.runtime_failed,
|
||||
"failure_message": failure_message,
|
||||
}
|
||||
|
||||
|
||||
def runtime_status_response(
|
||||
core_lifecycle: "AstrBotCoreLifecycle",
|
||||
status_code: int = 503,
|
||||
*,
|
||||
include_failure_details: bool = True,
|
||||
):
|
||||
message = get_runtime_guard_message(core_lifecycle)
|
||||
response = jsonify(
|
||||
Response(
|
||||
status="error",
|
||||
message=message,
|
||||
data=build_runtime_status_data(
|
||||
core_lifecycle,
|
||||
include_failure_details=include_failure_details,
|
||||
),
|
||||
).__dict__
|
||||
)
|
||||
response.status_code = status_code
|
||||
return response
|
||||
|
||||
|
||||
def runtime_loading_response(
|
||||
core_lifecycle: "AstrBotCoreLifecycle",
|
||||
status_code: int = 503,
|
||||
*,
|
||||
include_failure_details: bool = True,
|
||||
):
|
||||
return runtime_status_response(
|
||||
core_lifecycle,
|
||||
status_code=status_code,
|
||||
include_failure_details=include_failure_details,
|
||||
)
|
||||
|
||||
|
||||
def guard_runtime_ready(core_lifecycle: "AstrBotCoreLifecycle", handler):
|
||||
@wraps(handler)
|
||||
async def wrapped(*args: Any, **kwargs: Any):
|
||||
if not is_runtime_request_ready(core_lifecycle):
|
||||
return runtime_status_response(core_lifecycle)
|
||||
return await handler(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouteContext:
|
||||
|
||||
@@ -22,7 +22,14 @@ from astrbot.core.utils.io import get_dashboard_version
|
||||
from astrbot.core.utils.storage_cleaner import StorageCleaner
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
from .route import (
|
||||
Response,
|
||||
Route,
|
||||
RouteContext,
|
||||
build_runtime_status_data,
|
||||
is_runtime_request_ready,
|
||||
runtime_loading_response,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_path(path: str | Path) -> Path:
|
||||
@@ -40,6 +47,7 @@ class StatRoute(Route):
|
||||
self.routes = {
|
||||
"/stat/get": ("GET", self.get_stat),
|
||||
"/stat/version": ("GET", self.get_version),
|
||||
"/stat/runtime-status": ("GET", self.get_runtime_status),
|
||||
"/stat/start-time": ("GET", self.get_start_time),
|
||||
"/stat/restart-core": ("POST", self.restart_core),
|
||||
"/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
|
||||
@@ -91,8 +99,16 @@ class StatRoute(Route):
|
||||
)
|
||||
|
||||
async def get_start_time(self):
|
||||
if not is_runtime_request_ready(self.core_lifecycle):
|
||||
return runtime_loading_response(
|
||||
self.core_lifecycle,
|
||||
include_failure_details=False,
|
||||
)
|
||||
return Response().ok({"start_time": self.core_lifecycle.start_time}).__dict__
|
||||
|
||||
async def get_runtime_status(self):
|
||||
return Response().ok(build_runtime_status_data(self.core_lifecycle)).__dict__
|
||||
|
||||
async def get_storage_status(self):
|
||||
try:
|
||||
status = await asyncio.to_thread(self.storage_cleaner.get_status)
|
||||
@@ -100,7 +116,7 @@ class StatRoute(Route):
|
||||
except Exception:
|
||||
logger.error("获取存储占用失败", exc_info=True)
|
||||
return (
|
||||
Response().error("获取存储占用失败,请查看后端日志了解详情。").__dict__
|
||||
Response().error("获取存储占用失败, 请查看后端日志了解详情。").__dict__
|
||||
)
|
||||
|
||||
async def cleanup_storage(self):
|
||||
@@ -116,9 +132,11 @@ class StatRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception:
|
||||
logger.error("清理存储失败", exc_info=True)
|
||||
return Response().error("清理存储失败,请查看后端日志了解详情。").__dict__
|
||||
return Response().error("清理存储失败, 请查看后端日志了解详情。").__dict__
|
||||
|
||||
async def get_stat(self):
|
||||
if not is_runtime_request_ready(self.core_lifecycle):
|
||||
return runtime_loading_response(self.core_lifecycle)
|
||||
offset_sec = request.args.get("offset_sec", 86400)
|
||||
offset_sec = int(offset_sec)
|
||||
try:
|
||||
|
||||
@@ -21,6 +21,7 @@ from hypercorn.asyncio import serve
|
||||
from hypercorn.config import Config as HyperConfig
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
from quart.typing import ResponseReturnValue
|
||||
from quart_cors import cors
|
||||
|
||||
from astrbot.core import logger
|
||||
@@ -62,10 +63,41 @@ from .routes import (
|
||||
UpdateRoute,
|
||||
)
|
||||
from .routes.api_key import ALL_OPEN_API_SCOPES
|
||||
from .routes.route import is_runtime_request_ready, runtime_loading_response
|
||||
|
||||
# Static assets shipped inside the wheel (built during `hatch build`).
|
||||
_BUNDLED_DIST = Path(__file__).parent / "dist"
|
||||
|
||||
_PUBLIC_ALLOWED_ENDPOINT_PREFIXES = (
|
||||
"/api/auth/login",
|
||||
"/api/file",
|
||||
"/api/platform/webhook",
|
||||
"/api/stat/start-time",
|
||||
"/api/backup/download",
|
||||
)
|
||||
_RUNTIME_EXTRA_BYPASS_ENDPOINT_PREFIXES = (
|
||||
"/api/stat/version",
|
||||
"/api/stat/runtime-status",
|
||||
"/api/stat/restart-core",
|
||||
"/api/stat/changelog",
|
||||
"/api/stat/changelog/list",
|
||||
"/api/stat/first-notice",
|
||||
)
|
||||
_RUNTIME_BYPASS_ENDPOINT_PREFIXES = (
|
||||
tuple(
|
||||
prefix
|
||||
for prefix in _PUBLIC_ALLOWED_ENDPOINT_PREFIXES
|
||||
if prefix != "/api/platform/webhook"
|
||||
)
|
||||
+ _RUNTIME_EXTRA_BYPASS_ENDPOINT_PREFIXES
|
||||
)
|
||||
_RUNTIME_FAILED_RECOVERY_ENDPOINT_PREFIXES = (
|
||||
"/api/config/",
|
||||
"/api/plugin/reload-failed",
|
||||
"/api/plugin/uninstall-failed",
|
||||
"/api/plugin/source/get-failed-plugins",
|
||||
)
|
||||
|
||||
|
||||
APP: Quart
|
||||
_ENV_PLACEHOLDER_RE = re.compile(
|
||||
@@ -122,12 +154,10 @@ class AstrBotJSONProvider(DefaultJSONProvider):
|
||||
class AstrBotDashboard:
|
||||
"""AstrBot Web Dashboard"""
|
||||
|
||||
ALLOWED_ENDPOINT_PREFIXES = (
|
||||
"/api/auth/login",
|
||||
"/api/file",
|
||||
"/api/platform/webhook",
|
||||
"/api/stat/start-time",
|
||||
"/api/backup/download",
|
||||
ALLOWED_ENDPOINT_PREFIXES = _PUBLIC_ALLOWED_ENDPOINT_PREFIXES
|
||||
RUNTIME_BYPASS_ENDPOINT_PREFIXES = _RUNTIME_BYPASS_ENDPOINT_PREFIXES
|
||||
RUNTIME_FAILED_RECOVERY_ENDPOINT_PREFIXES = (
|
||||
_RUNTIME_FAILED_RECOVERY_ENDPOINT_PREFIXES
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -182,8 +212,8 @@ class AstrBotDashboard:
|
||||
|
||||
if self.enable_webui and not (Path(self.data_path) / "index.html").exists():
|
||||
logger.warning(
|
||||
f"前端未内置或未初始化 (index.html missing in {self.data_path}),"
|
||||
"回退到仅启动后端。请访问在线面板:dash.astrbot.men"
|
||||
f"前端未内置或未初始化 (index.html missing in {self.data_path}), "
|
||||
"回退到仅启动后端. 请访问在线面板: dash.astrbot.men"
|
||||
)
|
||||
self.enable_webui = False
|
||||
self._webui_fallback = True
|
||||
@@ -233,7 +263,7 @@ class AstrBotDashboard:
|
||||
@self.app.route("/")
|
||||
async def index():
|
||||
if not self.enable_webui:
|
||||
return "前端未启用,请访问在线面板:dash.astrbot.men"
|
||||
return "前端未启用, 请访问在线面板: dash.astrbot.men"
|
||||
try:
|
||||
return await self.app.send_static_file("index.html")
|
||||
except werkzeug.exceptions.NotFound:
|
||||
@@ -243,7 +273,7 @@ class AstrBotDashboard:
|
||||
@self.app.errorhandler(404)
|
||||
async def not_found(e):
|
||||
if not self.enable_webui:
|
||||
return "前端未启用,请访问在线面板:dash.astrbot.men"
|
||||
return "前端未启用, 请访问在线面板: dash.astrbot.men"
|
||||
if request.path.startswith("/api/"):
|
||||
return jsonify(Response().error("Not Found").to_json()), 404
|
||||
try:
|
||||
@@ -263,13 +293,14 @@ class AstrBotDashboard:
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
|
||||
def _init_routes(self, db: BaseDatabase):
|
||||
UpdateRoute(
|
||||
self.context, self.core_lifecycle.astrbot_updator, self.core_lifecycle
|
||||
)
|
||||
astrbot_updator = self.core_lifecycle.astrbot_updator
|
||||
plugin_manager = self.core_lifecycle.plugin_manager
|
||||
assert astrbot_updator is not None
|
||||
assert plugin_manager is not None
|
||||
|
||||
UpdateRoute(self.context, astrbot_updator, self.core_lifecycle)
|
||||
StatRoute(self.context, db, self.core_lifecycle)
|
||||
PluginRoute(
|
||||
self.context, self.core_lifecycle, self.core_lifecycle.plugin_manager
|
||||
)
|
||||
PluginRoute(self.context, self.core_lifecycle, plugin_manager)
|
||||
|
||||
self.command_route = CommandRoute(self.context)
|
||||
self.cr = ConfigRoute(self.context, self.core_lifecycle)
|
||||
@@ -308,21 +339,24 @@ class AstrBotDashboard:
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
view_func=self.srv_plug_route,
|
||||
view_func=self.guarded_srv_plug_route,
|
||||
methods=["GET", "POST"],
|
||||
)
|
||||
|
||||
def _init_plugin_route_index(self):
|
||||
"""将插件路由索引,避免 O(n) 查找"""
|
||||
self._plugin_route_map: dict[tuple[str, str], Callable] = {}
|
||||
if self.core_lifecycle.star_context.registered_web_apis is None:
|
||||
self.core_lifecycle.star_context.registered_web_apis = []
|
||||
star_context = self.core_lifecycle.star_context
|
||||
if star_context is None:
|
||||
return
|
||||
if star_context.registered_web_apis is None:
|
||||
star_context.registered_web_apis = []
|
||||
for (
|
||||
route,
|
||||
handler,
|
||||
methods,
|
||||
_,
|
||||
) in self.core_lifecycle.star_context.registered_web_apis:
|
||||
) in star_context.registered_web_apis:
|
||||
for method in methods:
|
||||
self._plugin_route_map[(route, method)] = handler
|
||||
|
||||
@@ -334,6 +368,48 @@ class AstrBotDashboard:
|
||||
logger.info("Initialized random JWT secret for dashboard.")
|
||||
self._jwt_secret = dashboard_cfg["jwt_secret"]
|
||||
|
||||
async def guarded_srv_plug_route(
|
||||
self, subpath: str, *args, **kwargs
|
||||
) -> ResponseReturnValue:
|
||||
guard_resp = self._maybe_runtime_guard(request.path)
|
||||
if guard_resp is not None:
|
||||
return guard_resp
|
||||
return await self.srv_plug_route(subpath, *args, **kwargs)
|
||||
|
||||
def _should_bypass_runtime_guard(self, path: str) -> bool:
|
||||
return any(
|
||||
path.startswith(prefix)
|
||||
for prefix in self.RUNTIME_BYPASS_ENDPOINT_PREFIXES
|
||||
)
|
||||
|
||||
def _should_allow_failed_runtime_recovery(self, path: str) -> bool:
|
||||
if not (
|
||||
self.core_lifecycle.runtime_failed
|
||||
or self.core_lifecycle.runtime_bootstrap_error is not None
|
||||
):
|
||||
return False
|
||||
return any(
|
||||
path.startswith(prefix)
|
||||
for prefix in self.RUNTIME_FAILED_RECOVERY_ENDPOINT_PREFIXES
|
||||
)
|
||||
|
||||
def _maybe_runtime_guard(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
include_failure_details: bool = True,
|
||||
) -> ResponseReturnValue | None:
|
||||
if self._should_bypass_runtime_guard(path):
|
||||
return None
|
||||
if self._should_allow_failed_runtime_recovery(path):
|
||||
return None
|
||||
if not is_runtime_request_ready(self.core_lifecycle):
|
||||
return runtime_loading_response(
|
||||
self.core_lifecycle,
|
||||
include_failure_details=include_failure_details,
|
||||
)
|
||||
return None
|
||||
|
||||
async def auth_middleware(self):
|
||||
# 放行CORS预检请求
|
||||
if request.method == "OPTIONS":
|
||||
@@ -372,9 +448,21 @@ class AstrBotDashboard:
|
||||
g.api_key_scopes = scopes
|
||||
g.username = f"api_key:{api_key.key_id}"
|
||||
await self.db.touch_api_key(api_key.key_id)
|
||||
guard_resp = self._maybe_runtime_guard(
|
||||
request.path,
|
||||
include_failure_details=False,
|
||||
)
|
||||
if guard_resp is not None:
|
||||
return guard_resp
|
||||
return None
|
||||
|
||||
if any(request.path.startswith(p) for p in self.ALLOWED_ENDPOINT_PREFIXES):
|
||||
guard_resp = self._maybe_runtime_guard(
|
||||
request.path,
|
||||
include_failure_details=False,
|
||||
)
|
||||
if guard_resp is not None:
|
||||
return guard_resp
|
||||
return None
|
||||
|
||||
token = request.headers.get("Authorization")
|
||||
@@ -394,14 +482,25 @@ class AstrBotDashboard:
|
||||
except jwt.PyJWTError:
|
||||
return self._unauthorized("Token 无效")
|
||||
|
||||
guard_resp = self._maybe_runtime_guard(request.path)
|
||||
if guard_resp is not None:
|
||||
return guard_resp
|
||||
|
||||
@staticmethod
|
||||
def _unauthorized(msg: str):
|
||||
r = jsonify(Response().error(msg).to_json())
|
||||
r.status_code = 401
|
||||
return r
|
||||
|
||||
def _get_plugin_handler(self, subpath: str, method: str) -> Callable | None:
|
||||
handler = self._plugin_route_map.get((f"/{subpath}", method))
|
||||
if handler is not None:
|
||||
return handler
|
||||
self._init_plugin_route_index()
|
||||
return self._plugin_route_map.get((f"/{subpath}", method))
|
||||
|
||||
async def srv_plug_route(self, subpath: str, *args, **kwargs):
|
||||
handler = self._plugin_route_map.get((f"/{subpath}", request.method))
|
||||
handler = self._get_plugin_handler(subpath, request.method)
|
||||
if not handler:
|
||||
return jsonify(Response().error("未找到该路由").to_json())
|
||||
|
||||
@@ -481,10 +580,10 @@ class AstrBotDashboard:
|
||||
"""Run dashboard server (blocking)"""
|
||||
if self._webui_fallback:
|
||||
logger.warning(
|
||||
"前端未内置或未初始化,回退到仅启动后端。请访问在线面板:dash.astrbot.men"
|
||||
"前端未内置或未初始化, 回退到仅启动后端. 请访问在线面板: dash.astrbot.men"
|
||||
)
|
||||
elif not self.enable_webui:
|
||||
logger.warning("前端已禁用,请访问在线面板:dash.astrbot.men")
|
||||
logger.warning("前端已禁用, 请访问在线面板: dash.astrbot.men")
|
||||
|
||||
dashboard_config = self.config.get("dashboard", {})
|
||||
host_value = os.environ.get("ASTRBOT_HOST") or dashboard_config.get(
|
||||
|
||||
@@ -9,8 +9,9 @@ import uuid
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from quart import Quart
|
||||
@@ -18,11 +19,17 @@ from werkzeug.datastructures import FileStorage
|
||||
|
||||
from astrbot.cli.commands.cmd_conf import hash_dashboard_password_secure
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle, LifecycleState
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.utils.pip_installer import PipInstallError
|
||||
from astrbot.dashboard.routes.live_chat import (
|
||||
LiveChatRoute,
|
||||
LiveChatSession,
|
||||
_ReceiveTimeoutSentinel,
|
||||
)
|
||||
from astrbot.dashboard.routes.open_api import OpenApiRoute
|
||||
from astrbot.dashboard.routes.plugin import PluginRoute
|
||||
from astrbot.dashboard.server import AstrBotDashboard, _expand_env_placeholders
|
||||
from tests.fixtures.helpers import (
|
||||
@@ -150,6 +157,86 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _set_runtime_loading(core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
core_lifecycle._set_lifecycle_state(LifecycleState.CORE_READY)
|
||||
core_lifecycle.runtime_ready_event.clear()
|
||||
core_lifecycle.runtime_failed_event.clear()
|
||||
core_lifecycle.runtime_bootstrap_error = None
|
||||
core_lifecycle.runtime_request_ready = False
|
||||
|
||||
|
||||
def _restore_runtime_ready(core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
core_lifecycle._set_lifecycle_state(LifecycleState.RUNTIME_READY)
|
||||
core_lifecycle.runtime_ready_event.set()
|
||||
core_lifecycle.runtime_failed_event.clear()
|
||||
core_lifecycle.runtime_bootstrap_error = None
|
||||
core_lifecycle.runtime_request_ready = True
|
||||
|
||||
|
||||
def _set_runtime_failed(
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
message: str = "Runtime bootstrap failed.",
|
||||
) -> None:
|
||||
core_lifecycle._set_lifecycle_state(LifecycleState.RUNTIME_FAILED)
|
||||
core_lifecycle.runtime_ready_event.clear()
|
||||
core_lifecycle.runtime_failed_event.set()
|
||||
core_lifecycle.runtime_bootstrap_error = RuntimeError(message)
|
||||
core_lifecycle.runtime_request_ready = False
|
||||
|
||||
|
||||
def _set_runtime_starting(core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
core_lifecycle._set_lifecycle_state(LifecycleState.RUNTIME_READY)
|
||||
core_lifecycle.runtime_ready_event.set()
|
||||
core_lifecycle.runtime_failed_event.clear()
|
||||
core_lifecycle.runtime_bootstrap_error = None
|
||||
core_lifecycle.runtime_request_ready = False
|
||||
|
||||
|
||||
def _assert_runtime_loading_response(
|
||||
data: dict,
|
||||
state: LifecycleState = LifecycleState.CORE_READY,
|
||||
) -> None:
|
||||
assert data == {
|
||||
"status": "error",
|
||||
"message": "Runtime is still loading. Please try again shortly.",
|
||||
"data": {
|
||||
"state": state.value,
|
||||
"ready": False,
|
||||
"failed": False,
|
||||
"failure_message": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _assert_runtime_failed_response(
|
||||
data: dict,
|
||||
message: str | None,
|
||||
state: SimpleNamespace | LifecycleState = LifecycleState.RUNTIME_FAILED,
|
||||
) -> None:
|
||||
assert data == {
|
||||
"status": "error",
|
||||
"message": "Runtime bootstrap failed. Please check logs and retry.",
|
||||
"data": {
|
||||
"state": state.value,
|
||||
"ready": False,
|
||||
"failed": True,
|
||||
"failure_message": message,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_plugin_upload_files() -> dict:
|
||||
return {
|
||||
"files": {
|
||||
"file": FileStorage(
|
||||
stream=io.BytesIO(b"fake-plugin-archive"),
|
||||
filename="demo-plugin.zip",
|
||||
content_type="application/zip",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Tests the login functionality with both wrong and correct credentials."""
|
||||
@@ -517,6 +604,847 @@ async def test_batch_delete_sessions_uses_batch_lookup(
|
||||
assert called["batch_lookup_count"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_status_reports_current_state(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/stat/runtime-status",
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["state"] == LifecycleState.RUNTIME_READY.value
|
||||
assert data["data"]["ready"] is True
|
||||
assert data["data"]["failed"] is False
|
||||
assert data["data"]["failure_message"] is None
|
||||
|
||||
_set_runtime_loading(core_lifecycle_td)
|
||||
try:
|
||||
response = await test_client.get(
|
||||
"/api/stat/runtime-status",
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["state"] == LifecycleState.CORE_READY.value
|
||||
assert data["data"]["ready"] is False
|
||||
assert data["data"]["failed"] is False
|
||||
assert data["data"]["failure_message"] is None
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
_set_runtime_failed(core_lifecycle_td, "bootstrap exploded during provider init")
|
||||
try:
|
||||
response = await test_client.get(
|
||||
"/api/stat/runtime-status",
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["data"]["state"] == "runtime_failed"
|
||||
assert data["data"]["ready"] is False
|
||||
assert data["data"]["failed"] is True
|
||||
assert (
|
||||
data["data"]["failure_message"] == "bootstrap exploded during provider init"
|
||||
)
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("path", "headers", "misleading_field"),
|
||||
[
|
||||
("/api/stat/get", "auth", "platform"),
|
||||
("/api/stat/start-time", None, "start_time"),
|
||||
],
|
||||
ids=["stat-get", "stat-start-time"],
|
||||
)
|
||||
async def test_stat_endpoints_return_503_while_runtime_loading(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
path: str,
|
||||
headers: str | None,
|
||||
misleading_field: str,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
request_kwargs = {}
|
||||
if headers == "auth":
|
||||
request_kwargs["headers"] = authenticated_header
|
||||
|
||||
_set_runtime_loading(core_lifecycle_td)
|
||||
try:
|
||||
response = await test_client.get(path, **request_kwargs)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_loading_response(data)
|
||||
assert misleading_field not in data["data"]
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("path", "headers", "misleading_field"),
|
||||
[
|
||||
("/api/stat/get", "auth", "platform"),
|
||||
("/api/stat/start-time", None, "start_time"),
|
||||
],
|
||||
ids=["stat-get", "stat-start-time"],
|
||||
)
|
||||
async def test_stat_endpoints_return_failure_aware_503_after_bootstrap_failure(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
path: str,
|
||||
headers: str | None,
|
||||
misleading_field: str,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
request_kwargs = {}
|
||||
if headers == "auth":
|
||||
request_kwargs["headers"] = authenticated_header
|
||||
|
||||
_set_runtime_failed(core_lifecycle_td, "runtime bootstrap failed in plugin reload")
|
||||
try:
|
||||
response = await test_client.get(path, **request_kwargs)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_failed_response(
|
||||
data,
|
||||
(
|
||||
None
|
||||
if path == "/api/stat/start-time"
|
||||
else "runtime bootstrap failed in plugin reload"
|
||||
),
|
||||
)
|
||||
assert misleading_field not in data["data"]
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_dependent_dashboard_route_returns_503_while_runtime_starting(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
_set_runtime_starting(core_lifecycle_td)
|
||||
try:
|
||||
response = await test_client.get(
|
||||
"/api/config/get",
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_loading_response(data, state=LifecycleState.RUNTIME_READY)
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openapi_configs_returns_503_while_runtime_loading(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
create_res = await test_client.post(
|
||||
"/api/apikey/create",
|
||||
json={"name": f"runtime-guard-{uuid.uuid4().hex[:8]}", "scopes": ["config"]},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert create_res.status_code == 200
|
||||
create_data = await create_res.get_json()
|
||||
api_key = create_data["data"]["api_key"]
|
||||
|
||||
assert core_lifecycle_td.astrbot_config_mgr is not None
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_config_mgr,
|
||||
"get_conf_list",
|
||||
MagicMock(side_effect=AssertionError("config list should not be read")),
|
||||
)
|
||||
|
||||
_set_runtime_loading(core_lifecycle_td)
|
||||
try:
|
||||
response = await test_client.get(
|
||||
"/api/v1/configs",
|
||||
headers={"X-API-Key": api_key},
|
||||
)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_loading_response(data)
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openapi_failure_response_hides_bootstrap_error_details(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
create_res = await test_client.post(
|
||||
"/api/apikey/create",
|
||||
json={"name": f"runtime-failed-{uuid.uuid4().hex[:8]}", "scopes": ["config"]},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert create_res.status_code == 200
|
||||
create_data = await create_res.get_json()
|
||||
api_key = create_data["data"]["api_key"]
|
||||
|
||||
_set_runtime_failed(core_lifecycle_td, "provider secret exploded")
|
||||
try:
|
||||
response = await test_client.get(
|
||||
"/api/v1/configs",
|
||||
headers={"X-API-Key": api_key},
|
||||
)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_failed_response(data, None)
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openapi_chat_ws_closes_while_runtime_loading(
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
):
|
||||
route = OpenApiRoute.__new__(OpenApiRoute)
|
||||
route.core_lifecycle = core_lifecycle_td
|
||||
route._authenticate_chat_ws_api_key = AsyncMock(return_value=(True, None))
|
||||
route._send_chat_ws_error = AsyncMock()
|
||||
|
||||
fake_websocket = MagicMock()
|
||||
fake_websocket.receive_json = AsyncMock(
|
||||
side_effect=AssertionError("websocket should not consume messages")
|
||||
)
|
||||
fake_websocket.close = AsyncMock()
|
||||
monkeypatch.setattr("astrbot.dashboard.routes.open_api.websocket", fake_websocket)
|
||||
|
||||
_set_runtime_loading(core_lifecycle_td)
|
||||
try:
|
||||
await route.chat_ws()
|
||||
route._send_chat_ws_error.assert_awaited_once()
|
||||
fake_websocket.close.assert_awaited_once()
|
||||
fake_websocket.receive_json.assert_not_awaited()
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openapi_chat_ws_closes_when_runtime_stops_mid_session(
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
):
|
||||
route = OpenApiRoute.__new__(OpenApiRoute)
|
||||
route.core_lifecycle = core_lifecycle_td
|
||||
route._authenticate_chat_ws_api_key = AsyncMock(return_value=(True, None))
|
||||
route._send_chat_ws_error = AsyncMock()
|
||||
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
fake_websocket = MagicMock()
|
||||
fake_websocket.send_json = AsyncMock()
|
||||
fake_websocket.close = AsyncMock()
|
||||
receive_calls = 0
|
||||
|
||||
async def receive_json():
|
||||
nonlocal receive_calls
|
||||
receive_calls += 1
|
||||
if receive_calls > 1:
|
||||
raise AssertionError("websocket should close before next receive")
|
||||
core_lifecycle_td.runtime_request_ready = False
|
||||
return {"t": "ping"}
|
||||
|
||||
fake_websocket.receive_json = AsyncMock(side_effect=receive_json)
|
||||
monkeypatch.setattr("astrbot.dashboard.routes.open_api.websocket", fake_websocket)
|
||||
|
||||
await route.chat_ws()
|
||||
|
||||
route._send_chat_ws_error.assert_awaited_once()
|
||||
fake_websocket.send_json.assert_not_awaited()
|
||||
fake_websocket.close.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_chat_ws_closes_while_runtime_loading(
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
route = LiveChatRoute.__new__(LiveChatRoute)
|
||||
route.core_lifecycle = core_lifecycle_td
|
||||
route.db = MagicMock()
|
||||
route.plugin_manager = core_lifecycle_td.plugin_manager
|
||||
route.platform_history_mgr = core_lifecycle_td.platform_message_history_manager
|
||||
route.sessions = {}
|
||||
route.config = core_lifecycle_td.astrbot_config
|
||||
route.attachments_dir = str(tmp_path / "attachments")
|
||||
route.legacy_img_dir = str(tmp_path / "legacy")
|
||||
|
||||
fake_websocket = MagicMock()
|
||||
fake_websocket.args = {"token": "test-token"}
|
||||
fake_websocket.receive_json = AsyncMock(
|
||||
side_effect=AssertionError("live chat websocket should not consume messages")
|
||||
)
|
||||
fake_websocket.close = AsyncMock()
|
||||
monkeypatch.setattr("astrbot.dashboard.routes.live_chat.websocket", fake_websocket)
|
||||
|
||||
_set_runtime_loading(core_lifecycle_td)
|
||||
try:
|
||||
with patch(
|
||||
"astrbot.dashboard.routes.live_chat.jwt.decode",
|
||||
return_value={"username": "astrbot"},
|
||||
):
|
||||
await route._unified_ws_loop(force_ct="live")
|
||||
fake_websocket.close.assert_awaited_once()
|
||||
fake_websocket.receive_json.assert_not_awaited()
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_chat_ws_closes_when_runtime_stops_mid_session(
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
route = LiveChatRoute.__new__(LiveChatRoute)
|
||||
route.core_lifecycle = core_lifecycle_td
|
||||
route.db = MagicMock()
|
||||
route.plugin_manager = core_lifecycle_td.plugin_manager
|
||||
route.platform_history_mgr = core_lifecycle_td.platform_message_history_manager
|
||||
route.sessions = {}
|
||||
route.config = core_lifecycle_td.astrbot_config
|
||||
route.attachments_dir = str(tmp_path / "attachments")
|
||||
route.legacy_img_dir = str(tmp_path / "legacy")
|
||||
route._handle_message = AsyncMock(
|
||||
side_effect=lambda *_args, **_kwargs: setattr(
|
||||
core_lifecycle_td,
|
||||
"runtime_request_ready",
|
||||
False,
|
||||
)
|
||||
)
|
||||
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
fake_websocket = MagicMock()
|
||||
fake_websocket.args = {"token": "test-token"}
|
||||
fake_websocket.close = AsyncMock()
|
||||
fake_websocket.receive_json = AsyncMock(
|
||||
side_effect=[{}, AssertionError("live chat websocket should close before next receive")]
|
||||
)
|
||||
monkeypatch.setattr("astrbot.dashboard.routes.live_chat.websocket", fake_websocket)
|
||||
|
||||
with patch(
|
||||
"astrbot.dashboard.routes.live_chat.jwt.decode",
|
||||
return_value={"username": "astrbot"},
|
||||
):
|
||||
await route._unified_ws_loop(force_ct="live")
|
||||
|
||||
route._handle_message.assert_awaited_once()
|
||||
fake_websocket.close.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_chat_recv_guard_polls_until_runtime_stops(
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
):
|
||||
route = LiveChatRoute.__new__(LiveChatRoute)
|
||||
route.core_lifecycle = core_lifecycle_td
|
||||
|
||||
fake_websocket = MagicMock()
|
||||
fake_websocket.close = AsyncMock()
|
||||
fake_websocket.receive_json = AsyncMock(
|
||||
side_effect=AssertionError("receive_json should be wrapped by wait_for")
|
||||
)
|
||||
monkeypatch.setattr("astrbot.dashboard.routes.live_chat.websocket", fake_websocket)
|
||||
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
async def fake_wait_for(awaitable, **kwargs):
|
||||
del kwargs
|
||||
if hasattr(awaitable, "close"):
|
||||
awaitable.close()
|
||||
core_lifecycle_td.runtime_request_ready = False
|
||||
raise asyncio.TimeoutError
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.live_chat.asyncio.wait_for",
|
||||
fake_wait_for,
|
||||
)
|
||||
|
||||
first = await route._recv_ws_json_guarded(wait_timeout=0.01)
|
||||
assert isinstance(first, _ReceiveTimeoutSentinel)
|
||||
|
||||
second = await route._recv_ws_json_guarded(wait_timeout=0.01)
|
||||
assert second is None
|
||||
fake_websocket.close.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_chat_subscription_stops_forwarding_when_runtime_stops(
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
):
|
||||
route = LiveChatRoute.__new__(LiveChatRoute)
|
||||
route.core_lifecycle = core_lifecycle_td
|
||||
route._send_chat_payload = AsyncMock(
|
||||
side_effect=AssertionError("subscription should not forward after runtime stop")
|
||||
)
|
||||
|
||||
session = LiveChatSession("session", "astrbot")
|
||||
session.chat_subscriptions["chat-session"] = "request-id"
|
||||
session.chat_subscription_tasks["chat-session"] = MagicMock()
|
||||
|
||||
fake_queue = MagicMock()
|
||||
fake_queue.get = AsyncMock(
|
||||
side_effect=[{"type": "plain", "data": "hello"}, asyncio.CancelledError()]
|
||||
)
|
||||
remove_back_queue = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.live_chat.webchat_queue_mgr.get_or_create_back_queue",
|
||||
MagicMock(return_value=fake_queue),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.live_chat.webchat_queue_mgr.remove_back_queue",
|
||||
remove_back_queue,
|
||||
)
|
||||
|
||||
core_lifecycle_td.runtime_request_ready = False
|
||||
|
||||
await route._forward_chat_subscription(session, "chat-session", "request-id")
|
||||
|
||||
route._send_chat_payload.assert_not_awaited()
|
||||
remove_back_queue.assert_called_once_with("request-id")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_start_time_failure_response_does_not_leak_bootstrap_error(
|
||||
app: Quart,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
_set_runtime_failed(core_lifecycle_td, "provider api key leaked in stacktrace")
|
||||
try:
|
||||
response = await test_client.get("/api/stat/start-time")
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
assert data == {
|
||||
"status": "error",
|
||||
"message": "Runtime bootstrap failed. Please check logs and retry.",
|
||||
"data": {
|
||||
"state": LifecycleState.RUNTIME_FAILED.value,
|
||||
"ready": False,
|
||||
"failed": True,
|
||||
"failure_message": None,
|
||||
},
|
||||
}
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_webhook_failure_response_does_not_leak_bootstrap_error(
|
||||
app: Quart,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
_set_runtime_failed(core_lifecycle_td, "provider webhook secret leaked")
|
||||
try:
|
||||
response = await test_client.post("/api/platform/webhook/test-webhook")
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_failed_response(data, None)
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_routes_remain_available_after_runtime_bootstrap_failure(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
|
||||
_set_runtime_failed(core_lifecycle_td, "broken provider config")
|
||||
try:
|
||||
response = await test_client.get(
|
||||
"/api/config/get",
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("method", "path", "request_kwargs_factory", "setup_recovery"),
|
||||
[
|
||||
(
|
||||
"get",
|
||||
"/api/plugin/source/get-failed-plugins",
|
||||
lambda: {},
|
||||
lambda pm: pm.failed_plugin_dict.update(
|
||||
{"broken-plugin": {"error": "boom"}}
|
||||
),
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/reload-failed",
|
||||
lambda: {"json": {"dir_name": "broken-plugin"}},
|
||||
lambda pm: setattr(
|
||||
pm,
|
||||
"reload_failed_plugin",
|
||||
AsyncMock(return_value=(True, None)),
|
||||
),
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/uninstall-failed",
|
||||
lambda: {"json": {"dir_name": "broken-plugin"}},
|
||||
lambda pm: setattr(pm, "uninstall_failed_plugin", AsyncMock()),
|
||||
),
|
||||
],
|
||||
ids=["get-failed-plugins", "reload-failed", "uninstall-failed"],
|
||||
)
|
||||
async def test_failed_plugin_recovery_routes_remain_available_after_bootstrap_failure(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
method: str,
|
||||
path: str,
|
||||
request_kwargs_factory,
|
||||
setup_recovery,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
plugin_manager = core_lifecycle_td.plugin_manager
|
||||
assert plugin_manager is not None
|
||||
|
||||
setup_recovery(plugin_manager)
|
||||
_set_runtime_failed(core_lifecycle_td, "plugin bootstrap failed")
|
||||
try:
|
||||
response = await getattr(test_client, method)(
|
||||
path,
|
||||
headers=authenticated_header,
|
||||
**request_kwargs_factory(),
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("method", "path", "request_kwargs_factory", "guard_attr"),
|
||||
[
|
||||
("get", "/api/plugin/get", lambda: {}, "context.get_all_stars"),
|
||||
(
|
||||
"get",
|
||||
"/api/plugin/source/get-failed-plugins",
|
||||
lambda: {},
|
||||
None,
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/install",
|
||||
lambda: {"json": {"url": "https://example.com/plugin"}},
|
||||
"install_plugin",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/install-upload",
|
||||
_build_plugin_upload_files,
|
||||
"install_plugin_from_file",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/update",
|
||||
lambda: {"json": {"name": "demo-plugin"}},
|
||||
"update_plugin",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/update-all",
|
||||
lambda: {"json": {"names": ["demo-plugin"]}},
|
||||
"update_plugin",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/uninstall",
|
||||
lambda: {"json": {"name": "demo-plugin"}},
|
||||
"uninstall_plugin",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/uninstall-failed",
|
||||
lambda: {"json": {"dir_name": "demo-plugin"}},
|
||||
"uninstall_failed_plugin",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/off",
|
||||
lambda: {"json": {"name": "demo-plugin"}},
|
||||
"turn_off_plugin",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/on",
|
||||
lambda: {"json": {"name": "demo-plugin"}},
|
||||
"turn_on_plugin",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/reload",
|
||||
lambda: {"json": {"name": "demo-plugin"}},
|
||||
"reload",
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/reload-failed",
|
||||
lambda: {"json": {"dir_name": "demo-plugin"}},
|
||||
"reload_failed_plugin",
|
||||
),
|
||||
(
|
||||
"get",
|
||||
"/api/plugin/readme?name=demo-plugin",
|
||||
lambda: {},
|
||||
"context.get_all_stars",
|
||||
),
|
||||
(
|
||||
"get",
|
||||
"/api/plugin/changelog?name=demo-plugin",
|
||||
lambda: {},
|
||||
"context.get_all_stars",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"get",
|
||||
"get-failed-plugins",
|
||||
"install",
|
||||
"install-upload",
|
||||
"update",
|
||||
"update-all",
|
||||
"uninstall",
|
||||
"uninstall-failed",
|
||||
"off",
|
||||
"on",
|
||||
"reload",
|
||||
"reload-failed",
|
||||
"readme",
|
||||
"changelog",
|
||||
],
|
||||
)
|
||||
async def test_plugin_api_returns_503_while_runtime_loading(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
method: str,
|
||||
path: str,
|
||||
request_kwargs_factory,
|
||||
guard_attr: str,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
plugin_manager = core_lifecycle_td.plugin_manager
|
||||
assert plugin_manager is not None
|
||||
|
||||
if guard_attr == "context.get_all_stars":
|
||||
monkeypatch.setattr(
|
||||
plugin_manager.context,
|
||||
"get_all_stars",
|
||||
lambda: (_ for _ in ()).throw(
|
||||
AssertionError("plugin state should not be read")
|
||||
),
|
||||
)
|
||||
elif guard_attr is not None:
|
||||
monkeypatch.setattr(
|
||||
plugin_manager,
|
||||
guard_attr,
|
||||
AsyncMock(side_effect=AssertionError(f"{guard_attr} should not be called")),
|
||||
)
|
||||
|
||||
_set_runtime_loading(core_lifecycle_td)
|
||||
try:
|
||||
request_kwargs = request_kwargs_factory()
|
||||
response = await getattr(test_client, method)(
|
||||
path,
|
||||
headers=authenticated_header,
|
||||
**request_kwargs,
|
||||
)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_loading_response(data)
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("method", "path", "request_kwargs_factory", "guard_attr"),
|
||||
[
|
||||
("get", "/api/plugin/get", lambda: {}, "context.get_all_stars"),
|
||||
(
|
||||
"post",
|
||||
"/api/plugin/install",
|
||||
lambda: {"json": {"url": "https://example.com/plugin"}},
|
||||
"install_plugin",
|
||||
),
|
||||
],
|
||||
ids=["get", "install"],
|
||||
)
|
||||
async def test_plugin_api_returns_failed_response_after_runtime_bootstrap_failure(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
method: str,
|
||||
path: str,
|
||||
request_kwargs_factory,
|
||||
guard_attr: str,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
plugin_manager = core_lifecycle_td.plugin_manager
|
||||
assert plugin_manager is not None
|
||||
|
||||
if guard_attr == "context.get_all_stars":
|
||||
monkeypatch.setattr(
|
||||
plugin_manager.context,
|
||||
"get_all_stars",
|
||||
lambda: (_ for _ in ()).throw(
|
||||
AssertionError("plugin state should not be read")
|
||||
),
|
||||
)
|
||||
elif guard_attr is not None:
|
||||
monkeypatch.setattr(
|
||||
plugin_manager,
|
||||
guard_attr,
|
||||
AsyncMock(side_effect=AssertionError(f"{guard_attr} should not be called")),
|
||||
)
|
||||
|
||||
_set_runtime_failed(core_lifecycle_td, "plugin bootstrap failed")
|
||||
try:
|
||||
request_kwargs = request_kwargs_factory()
|
||||
response = await getattr(test_client, method)(
|
||||
path,
|
||||
headers=authenticated_header,
|
||||
**request_kwargs,
|
||||
)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_failed_response(data, "plugin bootstrap failed")
|
||||
finally:
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_web_route_returns_503_while_runtime_loading(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
route_called = False
|
||||
star_context = core_lifecycle_td.star_context
|
||||
assert star_context is not None
|
||||
|
||||
async def dummy_plugin_route(*args, **kwargs):
|
||||
nonlocal route_called
|
||||
route_called = True
|
||||
return {"status": "ok", "message": None, "data": {"called": True}}
|
||||
|
||||
registered_web_apis = star_context.registered_web_apis
|
||||
original_registered_web_apis = list(registered_web_apis)
|
||||
registered_web_apis[:] = [
|
||||
("/runtime-guard-test", dummy_plugin_route, ["GET"], "runtime guard test"),
|
||||
]
|
||||
|
||||
_set_runtime_loading(core_lifecycle_td)
|
||||
try:
|
||||
response = await test_client.get(
|
||||
"/api/plug/runtime-guard-test",
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_loading_response(data)
|
||||
assert route_called is False
|
||||
finally:
|
||||
registered_web_apis[:] = original_registered_web_apis
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_web_route_returns_failed_response_after_runtime_bootstrap_failure(
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
route_called = False
|
||||
star_context = core_lifecycle_td.star_context
|
||||
assert star_context is not None
|
||||
|
||||
async def dummy_plugin_route(*args, **kwargs):
|
||||
nonlocal route_called
|
||||
route_called = True
|
||||
return {"status": "ok", "message": None, "data": {"called": True}}
|
||||
|
||||
registered_web_apis = star_context.registered_web_apis
|
||||
original_registered_web_apis = list(registered_web_apis)
|
||||
registered_web_apis[:] = [
|
||||
(
|
||||
"/runtime-failed-guard-test",
|
||||
dummy_plugin_route,
|
||||
["GET"],
|
||||
"runtime guard test",
|
||||
),
|
||||
]
|
||||
|
||||
_set_runtime_failed(core_lifecycle_td, "plugin web runtime bootstrap failed")
|
||||
try:
|
||||
response = await test_client.get(
|
||||
"/api/plug/runtime-failed-guard-test",
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 503
|
||||
data = await response.get_json()
|
||||
_assert_runtime_failed_response(
|
||||
data,
|
||||
"plugin web runtime bootstrap failed",
|
||||
)
|
||||
assert route_called is False
|
||||
finally:
|
||||
registered_web_apis[:] = original_registered_web_apis
|
||||
_restore_runtime_ready(core_lifecycle_td)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(
|
||||
app: Quart,
|
||||
@@ -550,7 +1478,9 @@ async def test_plugins(
|
||||
assert data["status"] == "ok"
|
||||
|
||||
# 使用 MockPluginBuilder 创建测试插件
|
||||
plugin_store_path = core_lifecycle_td.plugin_manager.plugin_store_path
|
||||
plugin_manager = core_lifecycle_td.plugin_manager
|
||||
assert plugin_manager is not None
|
||||
plugin_store_path = plugin_manager.plugin_store_path
|
||||
builder = MockPluginBuilder(plugin_store_path)
|
||||
|
||||
# 定义测试插件
|
||||
@@ -565,10 +1495,8 @@ async def test_plugins(
|
||||
mock_update = create_mock_updater_update(builder)
|
||||
|
||||
# 设置 Mock
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.plugin_manager.updator, "install", mock_install
|
||||
)
|
||||
monkeypatch.setattr(core_lifecycle_td.plugin_manager.updator, "update", mock_update)
|
||||
monkeypatch.setattr(plugin_manager.updator, "install", mock_install)
|
||||
monkeypatch.setattr(plugin_manager.updator, "update", mock_update)
|
||||
|
||||
try:
|
||||
# 插件安装
|
||||
@@ -1103,7 +2031,7 @@ async def test_do_update(
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert os.path.exists(release_path)
|
||||
assert await anyio.Path(release_path).exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
171
tests/unit/test_initial_loader.py
Normal file
171
tests/unit/test_initial_loader.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Tests for InitialLoader."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_loader_start_awaits_initialize_core_and_schedules_runtime_bootstrap():
|
||||
"""Test InitialLoader.start splits core init from background runtime bootstrap."""
|
||||
loader = InitialLoader(MagicMock(), MagicMock())
|
||||
call_order: list[str] = []
|
||||
real_create_task = asyncio.create_task
|
||||
created_tasks: list[asyncio.Task] = []
|
||||
|
||||
lifecycle = MagicMock()
|
||||
lifecycle.dashboard_shutdown_event = asyncio.Event()
|
||||
lifecycle.runtime_bootstrap_task = None
|
||||
|
||||
async def initialize_core() -> None:
|
||||
call_order.append("initialize_core")
|
||||
|
||||
async def bootstrap_runtime() -> None:
|
||||
call_order.append("bootstrap_runtime")
|
||||
|
||||
async def start_core() -> None:
|
||||
call_order.append("core_start")
|
||||
|
||||
async def run_dashboard() -> None:
|
||||
call_order.append("dashboard_run")
|
||||
|
||||
lifecycle.initialize = AsyncMock(
|
||||
side_effect=AssertionError("initialize should not be used")
|
||||
)
|
||||
lifecycle.initialize_core = AsyncMock(side_effect=initialize_core)
|
||||
lifecycle.bootstrap_runtime = AsyncMock(side_effect=bootstrap_runtime)
|
||||
lifecycle.start = AsyncMock(side_effect=start_core)
|
||||
|
||||
dashboard = MagicMock()
|
||||
dashboard.run = AsyncMock(side_effect=run_dashboard)
|
||||
|
||||
def dashboard_factory(*args, **kwargs):
|
||||
del args, kwargs
|
||||
call_order.append("dashboard_init")
|
||||
return dashboard
|
||||
|
||||
def create_task(coro, *args, **kwargs):
|
||||
call_order.append("create_task")
|
||||
task = real_create_task(coro, *args, **kwargs)
|
||||
created_tasks.append(task)
|
||||
return task
|
||||
|
||||
with (
|
||||
patch(
|
||||
"astrbot.core.initial_loader.AstrBotCoreLifecycle", return_value=lifecycle
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.initial_loader.AstrBotDashboard",
|
||||
side_effect=dashboard_factory,
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.initial_loader.asyncio.create_task", side_effect=create_task
|
||||
),
|
||||
):
|
||||
await loader.start()
|
||||
|
||||
lifecycle.initialize.assert_not_called()
|
||||
lifecycle.initialize_core.assert_awaited_once()
|
||||
lifecycle.bootstrap_runtime.assert_awaited_once()
|
||||
lifecycle.start.assert_awaited_once()
|
||||
dashboard.run.assert_awaited_once()
|
||||
assert call_order[:3] == ["initialize_core", "create_task", "dashboard_init"]
|
||||
assert len(created_tasks) == 1
|
||||
assert lifecycle.runtime_bootstrap_task is created_tasks[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_loader_start_returns_without_partial_start_when_initialize_core_fails():
|
||||
"""Test InitialLoader.start aborts cleanly if initialize_core fails."""
|
||||
loader = InitialLoader(MagicMock(), MagicMock())
|
||||
|
||||
lifecycle = MagicMock()
|
||||
lifecycle.runtime_bootstrap_task = None
|
||||
expected_error = RuntimeError("core init failed")
|
||||
lifecycle.initialize_core = AsyncMock(side_effect=expected_error)
|
||||
lifecycle.bootstrap_runtime = AsyncMock()
|
||||
lifecycle.start = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"astrbot.core.initial_loader.AstrBotCoreLifecycle", return_value=lifecycle
|
||||
),
|
||||
patch("astrbot.core.initial_loader.AstrBotDashboard") as dashboard_cls,
|
||||
patch("astrbot.core.initial_loader.asyncio.create_task") as create_task,
|
||||
):
|
||||
await loader.start()
|
||||
|
||||
lifecycle.initialize_core.assert_awaited_once()
|
||||
dashboard_cls.assert_not_called()
|
||||
create_task.assert_not_called()
|
||||
lifecycle.bootstrap_runtime.assert_not_called()
|
||||
lifecycle.start.assert_not_called()
|
||||
assert lifecycle.runtime_bootstrap_task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("failing_component", "expected_order"),
|
||||
[
|
||||
("core", ["initialize_core", "bootstrap_runtime", "core_start", "dashboard_run"]),
|
||||
(
|
||||
"dashboard",
|
||||
["initialize_core", "bootstrap_runtime", "core_start", "dashboard_run"],
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_initial_loader_start_stops_lifecycle_when_runtime_task_raises(
|
||||
failing_component: str,
|
||||
expected_order: list[str],
|
||||
):
|
||||
"""Test InitialLoader.start stops lifecycle if a runtime task crashes."""
|
||||
loader = InitialLoader(MagicMock(), MagicMock())
|
||||
call_order: list[str] = []
|
||||
runtime_error = RuntimeError(f"{failing_component} failed")
|
||||
|
||||
lifecycle = MagicMock()
|
||||
lifecycle.dashboard_shutdown_event = asyncio.Event()
|
||||
lifecycle.runtime_bootstrap_task = None
|
||||
lifecycle.stop = AsyncMock()
|
||||
|
||||
async def initialize_core() -> None:
|
||||
call_order.append("initialize_core")
|
||||
|
||||
async def bootstrap_runtime() -> None:
|
||||
call_order.append("bootstrap_runtime")
|
||||
|
||||
async def start_core() -> None:
|
||||
call_order.append("core_start")
|
||||
if failing_component == "core":
|
||||
raise runtime_error
|
||||
|
||||
async def run_dashboard() -> None:
|
||||
call_order.append("dashboard_run")
|
||||
if failing_component == "dashboard":
|
||||
raise runtime_error
|
||||
|
||||
lifecycle.initialize_core = AsyncMock(side_effect=initialize_core)
|
||||
lifecycle.bootstrap_runtime = AsyncMock(side_effect=bootstrap_runtime)
|
||||
lifecycle.start = AsyncMock(side_effect=start_core)
|
||||
|
||||
dashboard = MagicMock()
|
||||
dashboard.run = AsyncMock(side_effect=run_dashboard)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"astrbot.core.initial_loader.AstrBotCoreLifecycle",
|
||||
return_value=lifecycle,
|
||||
),
|
||||
patch(
|
||||
"astrbot.core.initial_loader.AstrBotDashboard",
|
||||
return_value=dashboard,
|
||||
),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match=f"{failing_component} failed"):
|
||||
await loader.start()
|
||||
|
||||
lifecycle.stop.assert_awaited_once()
|
||||
assert call_order == expected_order
|
||||
Reference in New Issue
Block a user