Compare commits

...

6 Commits

Author SHA1 Message Date
邹永赫
6f5541bc7e fix: restore orchestrator logger binding 2026-03-25 21:50:48 +09:00
邹永赫
e7f57ae8ef fix: simplify runtime lifecycle coordination 2026-03-25 21:33:49 +09:00
邹永赫
0100f8d20c fix: streamline runtime guard handling 2026-03-25 20:12:05 +09:00
邹永赫
c1e2040f43 fix: harden deferred startup recovery 2026-03-25 19:42:40 +09:00
邹永赫
ae53b9fc9f fix: harden runtime cleanup review fixes
Continue terminating remaining providers and disable MCP servers even if one provider terminate hook fails.

Also add InitialLoader failure-path coverage and extract guarded plugin routes into a shared constant for easier review and maintenance.
2026-03-25 14:03:06 +09:00
邹永赫
63cbab610a feat: add two-phase startup lifecycle
Allow the dashboard to become available before plugin bootstrap completes and surface runtime readiness and failure states to API callers.

Guard plugin-facing endpoints until runtime is ready and clean up provider and plugin runtime state safely across bootstrap failures, retries, stop, and restart flows.
2026-03-25 14:02:02 +09:00
15 changed files with 3204 additions and 209 deletions

View File

@@ -11,6 +11,7 @@ from typing import Any
import anyio
from astrbot import logger
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
@@ -18,6 +19,8 @@ from astrbot._internal.protocols.lsp.client import AstrbotLspClient
from astrbot._internal.protocols.mcp.client import McpClient
from astrbot._internal.stars import RuntimeStatusStar
log = logger
class AstrbotOrchestrator(BaseAstrbotOrchestrator):
"""

View File

@@ -10,11 +10,14 @@
"""
import asyncio
import inspect
import os
import threading
import time
import traceback
from asyncio import Queue
from enum import Enum
from typing import Any
from astrbot.api import logger, sp
from astrbot.core import LogBroker, LogManager
@@ -43,6 +46,15 @@ from . import astrbot_config, html_renderer
from .event_bus import EventBus
class LifecycleState(str, Enum):
    """Lifecycle phases used by the split (core first, runtime later) startup.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # No completed core initialization (initial / post-shutdown state).
    CREATED = "created"
    # Core phase finished; runtime bootstrap not (yet) completed.
    CORE_READY = "core_ready"
    # Runtime bootstrap raised an error.
    RUNTIME_FAILED = "runtime_failed"
    # Runtime bootstrap completed successfully.
    RUNTIME_READY = "runtime_ready"
class AstrBotCoreLifecycle:
"""AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作.
@@ -56,9 +68,36 @@ class AstrBotCoreLifecycle:
self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库
self.umop_config_router: UmopConfigRouter | None = None
self.astrbot_config_mgr: AstrBotConfigManager | None = None
self.event_queue: Queue | None = None
self.persona_mgr: PersonaManager | None = None
self.provider_manager: ProviderManager | None = None
self.platform_manager: PlatformManager | None = None
self.conversation_manager: ConversationManager | None = None
self.platform_message_history_manager: PlatformMessageHistoryManager | None = (
None
)
self.kb_manager: KnowledgeBaseManager | None = None
self.subagent_orchestrator: SubAgentOrchestrator | None = None
self.cron_manager: CronJobManager | None = None
self.temp_dir_cleaner: TempDirCleaner | None = None
self.star_context: Context | None = None
self.plugin_manager: PluginManager | None = None
self.pipeline_scheduler_mapping: dict[str, PipelineScheduler] = {}
self.astrbot_updator: AstrBotUpdator | None = None
self.event_bus: EventBus | None = None
self.dashboard_shutdown_event: asyncio.Event | None = None
self.curr_tasks: list[asyncio.Task] = []
self.metadata_update_task: asyncio.Task[None] | None = None
self.start_time = 0
self.runtime_bootstrap_task: asyncio.Task[None] | None = None
self.runtime_bootstrap_error: BaseException | None = None
self.runtime_ready_event = asyncio.Event()
self.runtime_failed_event = asyncio.Event()
self.runtime_request_ready = False
self._runtime_wait_interrupted = False
self._set_lifecycle_state(LifecycleState.CREATED)
# 设置代理
proxy_config = self.astrbot_config.get("http_proxy", "")
@@ -79,6 +118,18 @@ class AstrBotCoreLifecycle:
del os.environ["no_proxy"]
logger.debug("HTTP proxy cleared")
@property
def core_initialized(self) -> bool:
    """True for every lifecycle state other than ``CREATED``."""
    current = self.lifecycle_state
    return current is not LifecycleState.CREATED
@property
def runtime_ready(self) -> bool:
    """True only while the lifecycle state is ``RUNTIME_READY``."""
    current = self.lifecycle_state
    return current is LifecycleState.RUNTIME_READY
@property
def runtime_failed(self) -> bool:
    """True only while the lifecycle state is ``RUNTIME_FAILED``."""
    current = self.lifecycle_state
    return current is LifecycleState.RUNTIME_FAILED
async def _init_or_reload_subagent_orchestrator(self) -> None:
"""Create (if needed) and reload the subagent orchestrator from config.
@@ -86,10 +137,14 @@ class AstrBotCoreLifecycle:
to manage enable/disable and tool registration details.
"""
try:
if self.provider_manager is None or self.persona_mgr is None:
raise RuntimeError("core dependencies are not initialized")
provider_manager = self.provider_manager
persona_mgr = self.persona_mgr
if self.subagent_orchestrator is None:
self.subagent_orchestrator = SubAgentOrchestrator(
self.provider_manager.llm_tools,
self.persona_mgr,
provider_manager.llm_tools,
persona_mgr,
)
await self.subagent_orchestrator.reload_from_config(
self.astrbot_config.get("subagent_orchestrator", {}),
@@ -97,11 +152,199 @@ class AstrBotCoreLifecycle:
except Exception as e:
logger.error(f"Subagent orchestrator init failed: {e}", exc_info=True)
async def initialize(self) -> None:
"""初始化 AstrBot 核心生命周期管理类.
def _set_lifecycle_state(self, state: LifecycleState) -> None:
    """Store ``state`` and mirror it onto the two readiness events.

    ``runtime_ready_event`` is set only for ``RUNTIME_READY`` and
    ``runtime_failed_event`` only for ``RUNTIME_FAILED``; in every other
    case the respective event is cleared.
    """
    self.lifecycle_state = state

    # The ready event is always updated before the failed event.
    if state is LifecycleState.RUNTIME_READY:
        self.runtime_ready_event.set()
    else:
        self.runtime_ready_event.clear()

    if state is LifecycleState.RUNTIME_FAILED:
        self.runtime_failed_event.set()
    else:
        self.runtime_failed_event.clear()
def _clear_runtime_failure_for_retry(self) -> None:
    """Move a failed lifecycle back to ``CORE_READY``; otherwise do nothing."""
    if self.lifecycle_state is not LifecycleState.RUNTIME_FAILED:
        return
    self._set_lifecycle_state(LifecycleState.CORE_READY)
async def _cleanup_partial_runtime_bootstrap(self) -> None:
if self.star_context is not None and hasattr(
self.star_context,
"reset_runtime_registrations",
):
self.star_context.reset_runtime_registrations()
if self.plugin_manager is not None and hasattr(
self.plugin_manager,
"cleanup_loaded_plugins",
):
try:
cleanup_loaded_plugins = getattr(
self.plugin_manager,
"cleanup_loaded_plugins",
)
result = cleanup_loaded_plugins()
if inspect.isawaitable(result):
await result
except Exception as exc:
logger.warning(
f"Failed to clean up loaded plugin state: {exc}",
exc_info=True,
)
for manager in (self.platform_manager, self.kb_manager, self.provider_manager):
if manager is None:
continue
try:
terminate = getattr(manager, "terminate", None)
if not callable(terminate):
continue
result = terminate()
if inspect.isawaitable(result):
await result
except Exception as exc:
logger.warning(
f"Failed to clean up partial runtime bootstrap state: {exc}",
exc_info=True,
)
self._clear_runtime_artifacts()
def _reset_runtime_bootstrap_state(self) -> None:
self.runtime_bootstrap_task = None
self.runtime_bootstrap_error = None
def _interrupt_runtime_bootstrap_waiters(self) -> None:
self._runtime_wait_interrupted = True
self.runtime_bootstrap_error = None
self.runtime_failed_event.set()
async def _consume_completed_bootstrap_task(self) -> None:
task = self.runtime_bootstrap_task
if task is None or not task.done():
return
try:
await task
except asyncio.CancelledError:
pass
except Exception:
pass
async def _wait_for_runtime_ready(self) -> bool:
"""Wait for the scheduled runtime bootstrap and report readiness.

Returns True when the runtime is (or becomes) ready. Returns False when
the wait was interrupted, a failure is already recorded, or the
bootstrap task was cancelled or raised.

Raises:
RuntimeError: if none of the early exits applies and no bootstrap
task has been scheduled.
"""
# Early exits that do not need to await the bootstrap task.
if self.runtime_ready:
return True
if self._runtime_wait_interrupted:
return False
# A failure is already known: drain the finished task, then report failure.
if self.runtime_failed or self.runtime_bootstrap_error is not None:
await self._consume_completed_bootstrap_task()
return False
runtime_bootstrap_task = self.runtime_bootstrap_task
if runtime_bootstrap_task is None:
raise RuntimeError(
"runtime bootstrap task was not scheduled before start",
)
try:
await runtime_bootstrap_task
except asyncio.CancelledError:
# Cancellation surfaced while awaiting the task is reported as "not ready".
return False
except BaseException as exc:
# Make sure the failure is recorded and reflected in the lifecycle state
# when that has not already been done.
if self.runtime_bootstrap_error is None:
self.runtime_bootstrap_error = exc
if not self.runtime_failed:
self._set_lifecycle_state(LifecycleState.RUNTIME_FAILED)
return False
# The task finished; the interrupt flag may have been raised meanwhile.
if self._runtime_wait_interrupted:
return False
return self.runtime_ready
def _collect_runtime_bootstrap_task(self) -> list[asyncio.Task]:
task = self.runtime_bootstrap_task
self.runtime_bootstrap_task = None
if task is None:
return []
if not task.done():
task.cancel()
return [task]
def _collect_metadata_update_task(self) -> list[asyncio.Task]:
task = self.metadata_update_task
self.metadata_update_task = None
if task is None:
return []
if not task.done():
task.cancel()
return [task]
async def _await_tasks(self, tasks: list[asyncio.Task]) -> None:
for task in tasks:
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
def _require_runtime_bootstrap_components(
    self,
) -> tuple[PluginManager, ProviderManager, KnowledgeBaseManager, PlatformManager]:
    """Return the four managers runtime bootstrap depends on.

    Returns:
        ``(plugin_manager, provider_manager, kb_manager, platform_manager)``.

    Raises:
        RuntimeError: if any of the four is still ``None``.
    """
    plugin_manager = self.plugin_manager
    provider_manager = self.provider_manager
    kb_manager = self.kb_manager
    platform_manager = self.platform_manager
    if (
        plugin_manager is None
        or provider_manager is None
        or kb_manager is None
        or platform_manager is None
    ):
        raise RuntimeError("initialize_core must complete before runtime bootstrap")
    return plugin_manager, provider_manager, kb_manager, platform_manager
def _require_runtime_started_components(self) -> tuple[EventBus, Context]:
    """Return ``(event_bus, star_context)`` once the runtime is ready.

    Raises:
        RuntimeError: if the lifecycle state is not ``RUNTIME_READY``, or if
            the event bus or star context is still ``None``.
    """
    if self.lifecycle_state is not LifecycleState.RUNTIME_READY:
        raise RuntimeError("LifecycleState.RUNTIME_READY required before start")
    event_bus, star_context = self.event_bus, self.star_context
    if event_bus is None or star_context is None:
        raise RuntimeError("runtime bootstrap must complete before start")
    return event_bus, star_context
def _cancel_current_tasks(self) -> list[asyncio.Task]:
tasks_to_wait: list[asyncio.Task] = []
for task in self.curr_tasks:
task.cancel()
if isinstance(task, asyncio.Task):
tasks_to_wait.append(task)
self.curr_tasks = []
return tasks_to_wait
def _clear_runtime_artifacts(self) -> None:
self.metadata_update_task = None
self.runtime_request_ready = False
self.event_bus = None
self.pipeline_scheduler_mapping = {}
self.curr_tasks = []
self.start_time = 0
def _require_core_ready(self) -> None:
if not self.core_initialized:
raise RuntimeError("initialize_core must complete before this operation")
def _require_platform_manager(self) -> PlatformManager:
    """Return the platform manager.

    Raises:
        RuntimeError: if ``platform_manager`` is still ``None``.
    """
    manager = self.platform_manager
    if manager is None:
        raise RuntimeError("platform manager is not initialized")
    return manager
async def initialize_core(self) -> None:
"""Initialize the fast core phase without runtime bootstrap."""
if self.core_initialized:
return
self._runtime_wait_interrupted = False
self._reset_runtime_bootstrap_state()
负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
logger.info("AstrBot v" + VERSION)
if os.environ.get("TESTING", ""):
@@ -127,8 +370,11 @@ class AstrBotCoreLifecycle:
ucr=self.umop_config_router,
sp=sp,
)
if self.astrbot_config_mgr is None:
raise RuntimeError("config manager initialization failed")
astrbot_config_mgr = self.astrbot_config_mgr
self.temp_dir_cleaner = TempDirCleaner(
max_size_getter=lambda: self.astrbot_config_mgr.default_conf.get(
max_size_getter=lambda: astrbot_config_mgr.default_conf.get(
TempDirCleaner.CONFIG_KEY,
TempDirCleaner.DEFAULT_MAX_SIZE,
),
@@ -197,53 +443,100 @@ class AstrBotCoreLifecycle:
# 初始化插件管理器
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
# 扫描、注册插件、实例化插件类
await self.plugin_manager.reload()
# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
await self.kb_manager.initialize()
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
# 初始化更新器
# 为提前启动 Dashboard 准备核心依赖
self.astrbot_updator = AstrBotUpdator()
# 初始化事件总线
self.event_bus = EventBus(
self.event_queue,
self.pipeline_scheduler_mapping,
self.astrbot_config_mgr,
)
# 记录启动时间
self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: list[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
# 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event()
asyncio.create_task(update_llm_metadata()) # noqa: RUF006
self._set_lifecycle_state(LifecycleState.CORE_READY)
async def bootstrap_runtime(self) -> None:
"""Complete deferred runtime bootstrap after core initialization.

Reloads plugins, initializes the provider and knowledge-base managers,
builds the pipeline schedulers and the event bus, initializes the
platform manager, starts the LLM metadata update task, and finally
marks the lifecycle as RUNTIME_READY. Returns immediately when the
runtime is already ready.

Raises:
RuntimeError: if core initialization has not completed, or a
required core component is missing.
BaseException: any error raised during bootstrap is re-raised after
the partial state has been cleaned up.
"""
if not self.core_initialized:
raise RuntimeError(
"initialize_core must be called before bootstrap_runtime",
)
if self.runtime_ready:
return
# Start the attempt from a clean slate: leave a previous failed state and
# drop any stale error / event flags.
self._clear_runtime_failure_for_retry()
self.runtime_bootstrap_error = None
self.runtime_ready_event.clear()
self.runtime_failed_event.clear()
try:
plugin_manager, provider_manager, kb_manager, platform_manager = (
self._require_runtime_bootstrap_components()
)
# Scan and register plugins, and instantiate plugin classes
await plugin_manager.reload()
# Instantiate each provider according to the configuration
await provider_manager.initialize()
await kb_manager.initialize()
# Initialize the message-event pipeline schedulers
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
if self.event_queue is None or self.astrbot_config_mgr is None:
raise RuntimeError(
"initialize_core must complete before runtime bootstrap",
)
# Initialize the event bus
self.event_bus = EventBus(
self.event_queue,
self.pipeline_scheduler_mapping,
self.astrbot_config_mgr,
)
# Record the start time
self.start_time = int(time.time())
# Initialize the current task list
self.curr_tasks = []
# Instantiate each platform adapter according to the configuration
await platform_manager.initialize()
# Keep a handle on the metadata update task so it can be collected later.
self.metadata_update_task = asyncio.create_task(update_llm_metadata())
self._set_lifecycle_state(LifecycleState.RUNTIME_READY)
except asyncio.CancelledError:
# Cancellation is not recorded as a failure: clean up, fall back to
# CORE_READY, and let the cancellation propagate.
await self._cleanup_partial_runtime_bootstrap()
self._set_lifecycle_state(LifecycleState.CORE_READY)
self.runtime_bootstrap_error = None
raise
except BaseException as exc:
# Any other error: clean up, record the failure, and re-raise.
await self._cleanup_partial_runtime_bootstrap()
self._set_lifecycle_state(LifecycleState.RUNTIME_FAILED)
self.runtime_bootstrap_error = exc
raise
async def initialize(self) -> None:
    """Initialize the AstrBot core lifecycle in a single call.

    Runs the core phase (``initialize_core``) followed by the runtime phase
    (``bootstrap_runtime``), then flags the runtime as ready to serve
    requests.
    """
    for phase in (self.initialize_core, self.bootstrap_runtime):
        await phase()
    self.runtime_request_ready = True
def _load(self) -> None:
"""加载事件总线和任务并初始化."""
event_bus, star_context = self._require_runtime_started_components()
# 创建一个异步任务来执行事件总线的 dispatch() 方法
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
event_bus_task = asyncio.create_task(
self.event_bus.dispatch(),
event_bus.dispatch(),
name="event_bus",
)
cron_task = None
if self.cron_manager:
cron_task = asyncio.create_task(
self.cron_manager.start(self.star_context),
self.cron_manager.start(star_context),
name="cron_manager",
)
temp_dir_cleaner_task = None
@@ -254,9 +547,9 @@ class AstrBotCoreLifecycle:
)
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
if self.star_context._register_tasks is not None:
for task in self.star_context._register_tasks:
extra_tasks: list[asyncio.Task[Any]] = []
if star_context._register_tasks is not None:
for task in star_context._register_tasks:
task_name = getattr(task, "__name__", task.__class__.__name__)
extra_tasks.append(asyncio.create_task(task, name=task_name))
@@ -295,6 +588,18 @@ class AstrBotCoreLifecycle:
用load加载事件总线和任务并初始化, 执行启动完成事件钩子
"""
if not await self._wait_for_runtime_ready():
if self._runtime_wait_interrupted:
return
error = self.runtime_bootstrap_error
if error is None:
logger.error("AstrBot runtime bootstrap failed before start completed.")
else:
logger.error(
f"AstrBot runtime bootstrap failed before start completed: {error}",
)
return
self._load()
logger.info("AstrBot 启动完成。")
@@ -311,50 +616,59 @@ class AstrBotCoreLifecycle:
except BaseException:
logger.error(traceback.format_exc())
self.runtime_request_ready = True
# 同时运行curr_tasks中的所有任务
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def _shutdown_runtime(self) -> None:
    """Tear down the runtime and return the lifecycle to ``CREATED``.

    Steps, in order:

    1. Stop advertising request readiness and interrupt bootstrap waiters.
    2. Cancel the current tasks, then cancel and await the metadata-update
       task and the runtime bootstrap task.
    3. Shut down the cron manager, terminate every registered plugin
       (failures are logged as warnings), and terminate the provider,
       platform and knowledge-base managers. An exception raised by one of
       these manager calls propagates to the caller.
    4. Signal the dashboard shutdown event, clear runtime artifacts, and
       reset the lifecycle and bootstrap bookkeeping.
    5. Await the tasks cancelled in step 2.
    """
    self.runtime_request_ready = False
    self._interrupt_runtime_bootstrap_waiters()

    tasks_to_wait = self._cancel_current_tasks()
    await self._await_tasks(self._collect_metadata_update_task())
    # The bootstrap task is awaited exactly once, here. It is intentionally
    # not added to ``tasks_to_wait``: awaiting the finished task a second
    # time would make ``_await_tasks`` log the same failure twice.
    await self._await_tasks(self._collect_runtime_bootstrap_task())

    if self.cron_manager:
        await self.cron_manager.shutdown()

    if self.plugin_manager and self.plugin_manager.context:
        for plugin in self.plugin_manager.context.get_all_stars():
            try:
                await self.plugin_manager._terminate_plugin(plugin)
            except Exception as e:
                logger.warning(traceback.format_exc())
                logger.warning(
                    f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。",
                )

    if self.provider_manager:
        await self.provider_manager.terminate()
    if self.platform_manager:
        await self.platform_manager.terminate()
    if self.kb_manager:
        await self.kb_manager.terminate()
    if self.dashboard_shutdown_event:
        self.dashboard_shutdown_event.set()

    self._clear_runtime_artifacts()
    self._set_lifecycle_state(LifecycleState.CREATED)
    self._reset_runtime_bootstrap_state()

    # Wait for the tasks cancelled at the start to actually finish.
    await self._await_tasks(tasks_to_wait)
async def stop(self) -> None:
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器."""
if self.temp_dir_cleaner:
await self.temp_dir_cleaner.stop()
# 请求停止所有正在运行的异步任务
for task in self.curr_tasks:
task.cancel()
if self.cron_manager:
await self.cron_manager.shutdown()
for plugin in self.plugin_manager.context.get_all_stars():
try:
await self.plugin_manager._terminate_plugin(plugin)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。",
)
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
for task in self.curr_tasks:
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
await self._shutdown_runtime()
async def restart(self) -> None:
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
await self._shutdown_runtime()
if self.astrbot_updator is None:
return
threading.Thread(
target=self.astrbot_updator._reboot,
name="restart",
@@ -364,7 +678,7 @@ class AstrBotCoreLifecycle:
def load_platform(self) -> list[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()
platform_insts = self._require_platform_manager().get_insts()
for platform_inst in platform_insts:
tasks.append(
asyncio.create_task(
@@ -382,9 +696,14 @@ class AstrBotCoreLifecycle:
"""
mapping = {}
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
self._require_core_ready()
assert self.astrbot_config_mgr is not None
assert self.plugin_manager is not None
astrbot_config_mgr = self.astrbot_config_mgr
plugin_manager = self.plugin_manager
for conf_id, ab_config in astrbot_config_mgr.confs.items():
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id),
PipelineContext(ab_config, plugin_manager, conf_id),
)
await scheduler.initialize()
mapping[conf_id] = scheduler
@@ -397,11 +716,16 @@ class AstrBotCoreLifecycle:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
self._require_core_ready()
assert self.astrbot_config_mgr is not None
astrbot_config_mgr = self.astrbot_config_mgr
ab_config = astrbot_config_mgr.confs.get(conf_id)
if not ab_config:
raise ValueError(f"配置文件 {conf_id} 不存在")
assert self.plugin_manager is not None
plugin_manager = self.plugin_manager
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id),
PipelineContext(ab_config, plugin_manager, conf_id),
)
await scheduler.initialize()
self.pipeline_scheduler_mapping[conf_id] = scheduler

View File

@@ -7,6 +7,7 @@
import asyncio
import traceback
from typing import cast
from astrbot.core import LogBroker, logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -27,20 +28,28 @@ class InitialLoader:
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
try:
await core_lifecycle.initialize()
await core_lifecycle.initialize_core()
except Exception as e:
logger.critical(traceback.format_exc())
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
return
core_lifecycle.runtime_bootstrap_task = asyncio.create_task(
core_lifecycle.bootstrap_runtime(),
)
core_task = core_lifecycle.start()
shutdown_event = core_lifecycle.dashboard_shutdown_event
if shutdown_event is None:
raise RuntimeError("initialize_core must set dashboard_shutdown_event")
shutdown_event = cast(asyncio.Event, shutdown_event)
webui_dir = self.webui_dir
self.dashboard_server = AstrBotDashboard(
core_lifecycle,
self.db,
core_lifecycle.dashboard_shutdown_event,
shutdown_event,
webui_dir,
)
@@ -55,3 +64,6 @@ class InitialLoader:
except asyncio.CancelledError:
logger.info("🌈 正在关闭 AstrBot...")
await core_lifecycle.stop()
except Exception:
await core_lifecycle.stop()
raise

View File

@@ -742,25 +742,26 @@ class ProviderManager:
logger.info(
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...",
)
provider_inst = self.inst_map[provider_id]
if self.inst_map[provider_id] in self.provider_insts:
prov_inst = self.inst_map[provider_id]
if provider_inst in self.provider_insts:
prov_inst = provider_inst
if isinstance(prov_inst, Provider):
self.provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.stt_provider_insts:
prov_inst = self.inst_map[provider_id]
if provider_inst in self.stt_provider_insts:
prov_inst = provider_inst
if isinstance(prov_inst, STTProvider):
self.stt_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.tts_provider_insts:
prov_inst = self.inst_map[provider_id]
if provider_inst in self.tts_provider_insts:
prov_inst = provider_inst
if isinstance(prov_inst, TTSProvider):
self.tts_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] == self.curr_provider_inst:
if provider_inst == self.curr_provider_inst:
self.curr_provider_inst = None
if self.inst_map[provider_id] == self.curr_stt_provider_inst:
if provider_inst == self.curr_stt_provider_inst:
self.curr_stt_provider_inst = None
if self.inst_map[provider_id] == self.curr_tts_provider_inst:
if provider_inst == self.curr_tts_provider_inst:
self.curr_tts_provider_inst = None
inst = self.inst_map[provider_id]
@@ -836,6 +837,35 @@ class ProviderManager:
# sync in-memory config for API queries (e.g., embedding provider list)
self.providers_config = astrbot_config["provider"]
def _get_all_provider_instances(self) -> list[Providers]:
    """Return every loaded provider instance exactly once.

    Walks ``provider_insts``, ``stt_provider_insts``, ``tts_provider_insts``,
    ``embedding_provider_insts``, ``rerank_provider_insts`` and the values
    of ``inst_map`` in that order, keeping the first occurrence of each
    object. Deduplication is by object identity, not equality.
    """
    collections = (
        self.provider_insts,
        self.stt_provider_insts,
        self.tts_provider_insts,
        self.embedding_provider_insts,
        self.rerank_provider_insts,
        self.inst_map.values(),
    )
    visited_ids: set[int] = set()
    unique: list[Providers] = []
    for collection in collections:
        for candidate in collection:
            key = id(candidate)
            if key not in visited_ids:
                visited_ids.add(key)
                unique.append(candidate)
    return unique
def _clear_loaded_instances(self) -> None:
self.provider_insts = []
self.stt_provider_insts = []
self.tts_provider_insts = []
self.embedding_provider_insts = []
self.rerank_provider_insts = []
self.inst_map = {}
self.curr_provider_inst = None
self.curr_stt_provider_inst = None
self.curr_tts_provider_inst = None
async def terminate(self) -> None:
if self._mcp_init_task and not self._mcp_init_task.done():
self._mcp_init_task.cancel()
@@ -844,9 +874,20 @@ class ProviderManager:
except asyncio.CancelledError:
pass
for provider_inst in self.provider_insts:
if isinstance(provider_inst, SupportsTerminate):
self._mcp_init_task = None
provider_instances = self._get_all_provider_instances()
self._clear_loaded_instances()
for provider_inst in provider_instances:
if not isinstance(provider_inst, SupportsTerminate):
continue
try:
await provider_inst.terminate()
except Exception:
logger.error(
"Error while terminating provider instance",
exc_info=True,
)
try:
await self.llm_tools.disable_mcp_server()
except Exception:

View File

@@ -84,6 +84,8 @@ class Context:
cron_manager: CronJobManager,
subagent_orchestrator: SubAgentOrchestrator | None = None,
) -> None:
self.registered_web_apis = []
self._register_tasks = []
self._event_queue = event_queue
"""事件队列。消息平台通过事件队列传递消息事件。"""
self._config = config
@@ -109,10 +111,16 @@ class Context:
self.subagent_orchestrator = subagent_orchestrator
# Register built-in tools so they appear in WebUI and can be
# assigned to subagents. Done here (not at module-import time)
# assigned to subagents. Done here (not at module-import time)
# to avoid circular imports.
self.provider_manager.llm_tools.register_internal_tools()
def reset_runtime_registrations(self) -> None:
    """Empty the registered web APIs and registered tasks in place.

    A registry that is ``None`` is skipped.
    """
    for registry in (self.registered_web_apis, self._register_tasks):
        if registry is not None:
            registry.clear()
async def llm_generate(
self,
*,

View File

@@ -108,9 +108,9 @@ async def _temporary_filtered_requirements_file(
try:
yield filtered_requirements_path
finally:
if filtered_requirements_path and os.path.exists(
if filtered_requirements_path and await anyio.Path(
filtered_requirements_path
):
).exists():
try:
await to_thread.run_sync(os.remove, filtered_requirements_path)
except OSError as exc:
@@ -742,6 +742,24 @@ class PluginManager:
return result
async def cleanup_loaded_plugins(self) -> None:
    """Terminate and unbind all currently loaded plugins without reloading.

    Runs while holding ``_pm_lock``. A plugin whose termination raises is
    logged as a warning and processing continues. Plugins that have both a
    name and a module path are then unbound. Finally the handler registry,
    the star map and the star registry are emptied.
    """
    async with self._pm_lock:
        # Iterate over a snapshot of the registry.
        registered = list(star_registry)
        for smd in registered:
            try:
                await self._terminate_plugin(smd)
            except Exception as e:
                logger.warning(traceback.format_exc())
                logger.warning(
                    f"插件 {smd.name} 未被正常终止: {e!s}, 可能会导致该插件运行不正常。",
                )
            can_unbind = smd.name and smd.module_path
            if can_unbind:
                await self._unbind_plugin(smd.name, smd.module_path)

        for registry in (star_handlers_registry, star_map, star_registry):
            registry.clear()
async def load(
self,
specified_module_path=None,

View File

@@ -24,7 +24,26 @@ from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queu
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from astrbot.core.utils.datetime_utils import to_utc_isoformat
from .route import Route, RouteContext
from .route import (
Route,
RouteContext,
get_runtime_guard_message,
is_runtime_request_ready,
)
class _QueueTimeoutSentinel:
pass
_QUEUE_TIMEOUT = _QueueTimeoutSentinel()
class _ReceiveTimeoutSentinel:
pass
_RECEIVE_TIMEOUT = _ReceiveTimeoutSentinel()
class LiveChatSession:
@@ -137,6 +156,49 @@ class LiveChatRoute(Route):
"""Unified Chat WebSocket 处理器(支持 ct=live/chat)"""
await self._unified_ws_loop(force_ct=None)
async def _ensure_runtime_ready(self) -> bool:
    """Check runtime readiness, closing the WebSocket when it is not ready.

    Returns True when ``is_runtime_request_ready`` accepts the core
    lifecycle. Otherwise closes the WebSocket with code 1013 and the runtime
    guard message, and returns False.
    """
    if not is_runtime_request_ready(self.core_lifecycle):
        await websocket.close(
            1013,
            get_runtime_guard_message(self.core_lifecycle),
        )
        return False
    return True
async def _recv_ws_json_guarded(
    self,
    *,
    wait_timeout: float = 1.0,
) -> dict[str, Any] | _ReceiveTimeoutSentinel | None:
    """Receive one JSON message, checking runtime readiness before and after.

    Args:
        wait_timeout: seconds to wait for a message before giving up.

    Returns:
        The received message; ``_RECEIVE_TIMEOUT`` when nothing arrived
        within ``wait_timeout``; or ``None`` when a readiness check failed
        (in which case ``_ensure_runtime_ready`` has closed the socket).
    """
    if not await self._ensure_runtime_ready():
        return None

    try:
        payload = await asyncio.wait_for(
            websocket.receive_json(),
            timeout=wait_timeout,
        )
    except asyncio.TimeoutError:
        return _RECEIVE_TIMEOUT

    # Readiness may have changed while waiting for the message.
    return payload if await self._ensure_runtime_ready() else None
async def _guarded_queue_get(
    self,
    back_queue: asyncio.Queue,
    *,
    wait_timeout: float,
) -> dict[str, Any] | _QueueTimeoutSentinel | None:
    """Take one item from ``back_queue``, checking readiness before and after.

    Args:
        back_queue: queue to read from.
        wait_timeout: seconds to wait for an item before giving up.

    Returns:
        The dequeued item; ``_QUEUE_TIMEOUT`` when nothing arrived within
        ``wait_timeout``; or ``None`` when a readiness check failed (in which
        case ``_ensure_runtime_ready`` has closed the socket).
    """
    if not await self._ensure_runtime_ready():
        return None

    try:
        item = await asyncio.wait_for(back_queue.get(), timeout=wait_timeout)
    except asyncio.TimeoutError:
        return _QUEUE_TIMEOUT

    # Readiness may have changed while waiting on the queue.
    return item if await self._ensure_runtime_ready() else None
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
"""统一 WebSocket 循环"""
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
@@ -157,6 +219,9 @@ class LiveChatRoute(Route):
await websocket.close(1008, "Invalid token")
return
if not await self._ensure_runtime_ready():
return
session_id = f"webchat_live!{username}!{uuid.uuid4()}"
live_session = LiveChatSession(session_id, username)
self.sessions[session_id] = live_session
@@ -165,7 +230,11 @@ class LiveChatRoute(Route):
try:
while True:
message = await websocket.receive_json()
message = await self._recv_ws_json_guarded()
if isinstance(message, _ReceiveTimeoutSentinel):
continue
if message is None:
return
ct = force_ct or message.get("ct", "live")
if ct == "chat":
await self._handle_chat_message(live_session, message)
@@ -289,7 +358,11 @@ class LiveChatRoute(Route):
)
try:
while True:
result = await back_queue.get()
result = await self._guarded_queue_get(back_queue, wait_timeout=1)
if isinstance(result, _QueueTimeoutSentinel):
continue
if result is None:
break
if not result:
continue
await self._send_chat_payload(session, {"ct": "chat", **result})
@@ -486,14 +559,17 @@ class LiveChatRoute(Route):
refs = {}
while True:
if not await self._ensure_runtime_ready():
break
if session.should_interrupt:
session.should_interrupt = False
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
result = await self._guarded_queue_get(back_queue, wait_timeout=1)
if isinstance(result, _QueueTimeoutSentinel):
continue
if result is None:
break
if not result:
continue
@@ -773,6 +849,8 @@ class LiveChatRoute(Route):
try:
while True:
if not await self._ensure_runtime_ready():
break
if session.should_interrupt:
# 用户打断,停止处理
logger.info("[Live Chat] 检测到用户打断")
@@ -789,10 +867,14 @@ class LiveChatRoute(Route):
break
break
try:
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
except asyncio.TimeoutError:
result = await self._guarded_queue_get(
back_queue,
wait_timeout=0.5,
)
if isinstance(result, _QueueTimeoutSentinel):
continue
if result is None:
break
if not result:
continue

View File

@@ -19,7 +19,13 @@ from astrbot.core.utils.datetime_utils import to_utc_isoformat
from .api_key import ALL_OPEN_API_SCOPES
from .chat import ChatRoute
from .route import Response, Route, RouteContext
from .route import (
Response,
Route,
RouteContext,
get_runtime_guard_message,
is_runtime_request_ready,
)
class OpenApiRoute(Route):
@@ -244,6 +250,14 @@ class OpenApiRoute(Route):
}
)
async def _ensure_runtime_ready(self) -> bool:
    """Check runtime readiness, rejecting the WebSocket when it is not ready.

    Returns True when ``is_runtime_request_ready`` accepts the core
    lifecycle. Otherwise sends a ``RUNTIME_NOT_READY`` chat error carrying
    the runtime guard message, closes the WebSocket with code 1013, and
    returns False.
    """
    if not is_runtime_request_ready(self.core_lifecycle):
        reason = get_runtime_guard_message(self.core_lifecycle)
        await self._send_chat_ws_error(reason, "RUNTIME_NOT_READY")
        await websocket.close(1013, reason)
        return False
    return True
async def _update_session_config_route(
self,
*,
@@ -370,11 +384,16 @@ class OpenApiRoute(Route):
agent_stats = {}
refs = {}
while True:
if not await self._ensure_runtime_ready():
return
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not await self._ensure_runtime_ready():
return
if not result:
continue
@@ -512,9 +531,16 @@ class OpenApiRoute(Route):
await websocket.close(1008, auth_err or "Unauthorized")
return
if not await self._ensure_runtime_ready():
return
try:
while True:
if not await self._ensure_runtime_ready():
return
message = await websocket.receive_json()
if not await self._ensure_runtime_ready():
return
if not isinstance(message, dict):
await self._send_chat_ws_error(
"message must be an object",

View File

@@ -30,12 +30,33 @@ from astrbot.core.utils.astrbot_path import (
get_astrbot_temp_path,
)
from .route import Response, Route, RouteContext
from .route import Response, Route, RouteContext, guard_runtime_ready
# Limit concurrent updates to avoid overwhelming plugin sources.
PLUGIN_UPDATE_CONCURRENCY = 3

# Plugin route table, one row per endpoint:
#   (URL path, HTTP method, handler attribute name, requires_runtime)
PLUGIN_ROUTE_DEFINITIONS = (
    # Plugin listing / compatibility
    ("/plugin/get", "GET", "get_plugins", True),
    ("/plugin/check-compat", "POST", "check_plugin_compatibility", False),
    # Install / update / uninstall
    ("/plugin/install", "POST", "install_plugin", True),
    ("/plugin/install-upload", "POST", "install_plugin_upload", True),
    ("/plugin/update", "POST", "update_plugin", True),
    ("/plugin/update-all", "POST", "update_all_plugins", True),
    ("/plugin/uninstall", "POST", "uninstall_plugin", True),
    ("/plugin/uninstall-failed", "POST", "uninstall_failed_plugin", False),
    ("/plugin/market_list", "GET", "get_online_plugins", False),
    # Enable / disable / reload
    ("/plugin/off", "POST", "off_plugin", True),
    ("/plugin/on", "POST", "on_plugin", True),
    ("/plugin/reload-failed", "POST", "reload_failed_plugins", False),
    ("/plugin/reload", "POST", "reload_plugins", True),
    # Plugin documents
    ("/plugin/readme", "GET", "get_plugin_readme", True),
    ("/plugin/changelog", "GET", "get_plugin_changelog", True),
    # Custom sources
    ("/plugin/source/get", "GET", "get_custom_source", False),
    ("/plugin/source/save", "POST", "save_custom_source", False),
    ("/plugin/source/get-failed-plugins", "GET", "get_failed_plugins", False),
)
@dataclass
class RegistrySource:
@@ -52,28 +73,18 @@ class PluginRoute(Route):
plugin_manager: PluginManager,
) -> None:
super().__init__(context)
self.routes = {
"/plugin/get": ("GET", self.get_plugins),
"/plugin/check-compat": ("POST", self.check_plugin_compatibility),
"/plugin/install": ("POST", self.install_plugin),
"/plugin/install-upload": ("POST", self.install_plugin_upload),
"/plugin/update": ("POST", self.update_plugin),
"/plugin/update-all": ("POST", self.update_all_plugins),
"/plugin/uninstall": ("POST", self.uninstall_plugin),
"/plugin/uninstall-failed": ("POST", self.uninstall_failed_plugin),
"/plugin/market_list": ("GET", self.get_online_plugins),
"/plugin/off": ("POST", self.off_plugin),
"/plugin/on": ("POST", self.on_plugin),
"/plugin/reload-failed": ("POST", self.reload_failed_plugins),
"/plugin/reload": ("POST", self.reload_plugins),
"/plugin/readme": ("GET", self.get_plugin_readme),
"/plugin/changelog": ("GET", self.get_plugin_changelog),
"/plugin/source/get": ("GET", self.get_custom_source),
"/plugin/source/save": ("POST", self.save_custom_source),
"/plugin/source/get-failed-plugins": ("GET", self.get_failed_plugins),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
self._guard_runtime_ready = lambda handler: guard_runtime_ready(
self.core_lifecycle,
handler,
)
self.routes = {}
for path, method, handler_name, requires_runtime in PLUGIN_ROUTE_DEFINITIONS:
handler = getattr(self, handler_name)
if requires_runtime:
handler = self._guard_runtime_ready(handler)
self.routes[path] = (method, handler)
self.register_routes()
self.translated_event_type = {

View File

@@ -1,9 +1,90 @@
from dataclasses import dataclass
from functools import wraps
from typing import TYPE_CHECKING, Any
from quart import Quart
from quart import Quart, jsonify
from astrbot.core.config.astrbot_config import AstrBotConfig
if TYPE_CHECKING:
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
RUNTIME_LOADING_MESSAGE = "Runtime is still loading. Please try again shortly."
RUNTIME_FAILED_MESSAGE = "Runtime bootstrap failed. Please check logs and retry."
def is_runtime_request_ready(core_lifecycle: "AstrBotCoreLifecycle") -> bool:
    """Report whether runtime-dependent requests may be served.

    Prefers the lifecycle's ``runtime_request_ready`` attribute and falls back
    to ``runtime_ready`` only when that attribute is absent.

    The fallback is resolved lazily: an object that exposes
    ``runtime_request_ready`` but not ``runtime_ready`` no longer raises
    ``AttributeError`` just because the unused default was evaluated eagerly.

    Args:
        core_lifecycle: Object exposing ``runtime_request_ready`` and/or
            ``runtime_ready``.

    Returns:
        The value of ``runtime_request_ready`` when present, otherwise the
        value of ``runtime_ready``.

    Raises:
        AttributeError: If neither attribute exists.
    """
    missing = object()
    ready = getattr(core_lifecycle, "runtime_request_ready", missing)
    if ready is missing:
        return core_lifecycle.runtime_ready
    return ready
def get_runtime_guard_message(core_lifecycle: "AstrBotCoreLifecycle") -> str:
    """Pick the user-facing message for a request rejected by the runtime guard.

    Returns the "failed" message when the lifecycle reports a failure flag or
    carries a recorded bootstrap error, and the "still loading" message
    otherwise.
    """
    if core_lifecycle.runtime_failed:
        return RUNTIME_FAILED_MESSAGE
    if core_lifecycle.runtime_bootstrap_error is not None:
        return RUNTIME_FAILED_MESSAGE
    return RUNTIME_LOADING_MESSAGE
def build_runtime_status_data(
    core_lifecycle: "AstrBotCoreLifecycle",
    *,
    include_failure_details: bool = True,
) -> dict[str, str | bool | None]:
    """Build the JSON-serialisable runtime status payload.

    Args:
        core_lifecycle: Lifecycle object whose state is reported.
        include_failure_details: When False, ``failure_message`` is always
            ``None`` so the bootstrap error text is not exposed.

    Returns:
        A dict with the keys ``state``, ``ready``, ``failed`` and
        ``failure_message``.
    """
    details = None
    if include_failure_details:
        bootstrap_error = core_lifecycle.runtime_bootstrap_error
        if bootstrap_error is not None:
            details = str(bootstrap_error)

    status: dict[str, str | bool | None] = {}
    status["state"] = core_lifecycle.lifecycle_state.value
    status["ready"] = is_runtime_request_ready(core_lifecycle)
    status["failed"] = core_lifecycle.runtime_failed
    status["failure_message"] = details
    return status
def runtime_status_response(
    core_lifecycle: "AstrBotCoreLifecycle",
    status_code: int = 503,
    *,
    include_failure_details: bool = True,
):
    """Build the error response returned while the runtime cannot serve requests.

    Args:
        core_lifecycle: Lifecycle object describing the current runtime state.
        status_code: HTTP status assigned to the response (503 by default).
        include_failure_details: Forwarded to ``build_runtime_status_data``;
            False keeps the bootstrap error text out of the payload.

    Returns:
        The ``jsonify`` response with ``status_code`` applied.
    """
    message = get_runtime_guard_message(core_lifecycle)
    status_data = build_runtime_status_data(
        core_lifecycle,
        include_failure_details=include_failure_details,
    )
    envelope = Response(status="error", message=message, data=status_data)
    response = jsonify(envelope.__dict__)
    response.status_code = status_code
    return response
def runtime_loading_response(
    core_lifecycle: "AstrBotCoreLifecycle",
    status_code: int = 503,
    *,
    include_failure_details: bool = True,
):
    """Alias of ``runtime_status_response`` with the same parameters and result."""
    # Pure pass-through: every argument is forwarded unchanged.
    return runtime_status_response(
        core_lifecycle,
        status_code,
        include_failure_details=include_failure_details,
    )
def guard_runtime_ready(core_lifecycle: "AstrBotCoreLifecycle", handler):
    """Wrap an async handler so it only runs once the runtime accepts requests.

    Args:
        core_lifecycle: Lifecycle object consulted on every call.
        handler: Async callable to protect.

    Returns:
        An async wrapper (carrying ``handler``'s metadata) that returns
        ``runtime_status_response(core_lifecycle)`` instead of calling
        ``handler`` while the runtime is not ready.
    """

    async def wrapped(*args: Any, **kwargs: Any):
        if is_runtime_request_ready(core_lifecycle):
            return await handler(*args, **kwargs)
        return runtime_status_response(core_lifecycle)

    return wraps(handler)(wrapped)
@dataclass
class RouteContext:

View File

@@ -22,7 +22,14 @@ from astrbot.core.utils.io import get_dashboard_version
from astrbot.core.utils.storage_cleaner import StorageCleaner
from astrbot.core.utils.version_comparator import VersionComparator
from .route import Response, Route, RouteContext
from .route import (
Response,
Route,
RouteContext,
build_runtime_status_data,
is_runtime_request_ready,
runtime_loading_response,
)
def _resolve_path(path: str | Path) -> Path:
@@ -40,6 +47,7 @@ class StatRoute(Route):
self.routes = {
"/stat/get": ("GET", self.get_stat),
"/stat/version": ("GET", self.get_version),
"/stat/runtime-status": ("GET", self.get_runtime_status),
"/stat/start-time": ("GET", self.get_start_time),
"/stat/restart-core": ("POST", self.restart_core),
"/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
@@ -91,8 +99,16 @@ class StatRoute(Route):
)
async def get_start_time(self):
    """Return the core start time, or a runtime-status error while not ready.

    The not-ready response is built with ``include_failure_details=False``,
    so it never carries the bootstrap error text.
    """
    lifecycle = self.core_lifecycle
    if is_runtime_request_ready(lifecycle):
        return Response().ok({"start_time": lifecycle.start_time}).__dict__
    return runtime_loading_response(
        lifecycle,
        include_failure_details=False,
    )
async def get_runtime_status(self):
    """Return the current runtime status payload wrapped in an "ok" response."""
    envelope = Response()
    status_data = build_runtime_status_data(self.core_lifecycle)
    return envelope.ok(status_data).__dict__
async def get_storage_status(self):
try:
status = await asyncio.to_thread(self.storage_cleaner.get_status)
@@ -100,7 +116,7 @@ class StatRoute(Route):
except Exception:
logger.error("获取存储占用失败", exc_info=True)
return (
Response().error("获取存储占用失败请查看后端日志了解详情。").__dict__
Response().error("获取存储占用失败, 请查看后端日志了解详情。").__dict__
)
async def cleanup_storage(self):
@@ -116,9 +132,11 @@ class StatRoute(Route):
return Response().error(str(e)).__dict__
except Exception:
logger.error("清理存储失败", exc_info=True)
return Response().error("清理存储失败请查看后端日志了解详情。").__dict__
return Response().error("清理存储失败, 请查看后端日志了解详情。").__dict__
async def get_stat(self):
if not is_runtime_request_ready(self.core_lifecycle):
return runtime_loading_response(self.core_lifecycle)
offset_sec = request.args.get("offset_sec", 86400)
offset_sec = int(offset_sec)
try:

View File

@@ -21,6 +21,7 @@ from hypercorn.asyncio import serve
from hypercorn.config import Config as HyperConfig
from quart import Quart, g, jsonify, request
from quart.logging import default_handler
from quart.typing import ResponseReturnValue
from quart_cors import cors
from astrbot.core import logger
@@ -62,10 +63,41 @@ from .routes import (
UpdateRoute,
)
from .routes.api_key import ALL_OPEN_API_SCOPES
from .routes.route import is_runtime_request_ready, runtime_loading_response
# Static assets shipped inside the wheel (built during `hatch build`).
_BUNDLED_DIST = Path(__file__).parent / "dist"
# Path prefixes exposed as AstrBotDashboard.ALLOWED_ENDPOINT_PREFIXES.
# NOTE(review): auth_middleware appears to let these through before the JWT
# check - confirm against the full middleware.
_PUBLIC_ALLOWED_ENDPOINT_PREFIXES = (
"/api/auth/login",
"/api/file",
"/api/platform/webhook",
"/api/stat/start-time",
"/api/backup/download",
)
# Additional prefixes that skip the runtime-ready guard (on top of the public
# ones below). Matching is by str.startswith, so "/api/stat/changelog" already
# covers "/api/stat/changelog/list".
_RUNTIME_EXTRA_BYPASS_ENDPOINT_PREFIXES = (
"/api/stat/version",
"/api/stat/runtime-status",
"/api/stat/restart-core",
"/api/stat/changelog",
"/api/stat/changelog/list",
"/api/stat/first-notice",
)
# Every prefix that skips the runtime-ready guard: the public prefixes minus
# the platform webhook (which therefore stays guarded), plus the extras above.
_RUNTIME_BYPASS_ENDPOINT_PREFIXES = (
tuple(
prefix
for prefix in _PUBLIC_ALLOWED_ENDPOINT_PREFIXES
if prefix != "/api/platform/webhook"
)
+ _RUNTIME_EXTRA_BYPASS_ENDPOINT_PREFIXES
)
# Prefixes that skip the guard only while the runtime is in a failed state
# (see AstrBotDashboard._should_allow_failed_runtime_recovery).
_RUNTIME_FAILED_RECOVERY_ENDPOINT_PREFIXES = (
"/api/config/",
"/api/plugin/reload-failed",
"/api/plugin/uninstall-failed",
"/api/plugin/source/get-failed-plugins",
)
APP: Quart
_ENV_PLACEHOLDER_RE = re.compile(
@@ -122,12 +154,10 @@ class AstrBotJSONProvider(DefaultJSONProvider):
class AstrBotDashboard:
"""AstrBot Web Dashboard"""
ALLOWED_ENDPOINT_PREFIXES = (
"/api/auth/login",
"/api/file",
"/api/platform/webhook",
"/api/stat/start-time",
"/api/backup/download",
ALLOWED_ENDPOINT_PREFIXES = _PUBLIC_ALLOWED_ENDPOINT_PREFIXES
RUNTIME_BYPASS_ENDPOINT_PREFIXES = _RUNTIME_BYPASS_ENDPOINT_PREFIXES
RUNTIME_FAILED_RECOVERY_ENDPOINT_PREFIXES = (
_RUNTIME_FAILED_RECOVERY_ENDPOINT_PREFIXES
)
def __init__(
@@ -182,8 +212,8 @@ class AstrBotDashboard:
if self.enable_webui and not (Path(self.data_path) / "index.html").exists():
logger.warning(
f"前端未内置或未初始化 (index.html missing in {self.data_path})"
"回退到仅启动后端请访问在线面板dash.astrbot.men"
f"前端未内置或未初始化 (index.html missing in {self.data_path}), "
"回退到仅启动后端. 请访问在线面板: dash.astrbot.men"
)
self.enable_webui = False
self._webui_fallback = True
@@ -233,7 +263,7 @@ class AstrBotDashboard:
@self.app.route("/")
async def index():
if not self.enable_webui:
return "前端未启用请访问在线面板dash.astrbot.men"
return "前端未启用, 请访问在线面板: dash.astrbot.men"
try:
return await self.app.send_static_file("index.html")
except werkzeug.exceptions.NotFound:
@@ -243,7 +273,7 @@ class AstrBotDashboard:
@self.app.errorhandler(404)
async def not_found(e):
if not self.enable_webui:
return "前端未启用请访问在线面板dash.astrbot.men"
return "前端未启用, 请访问在线面板: dash.astrbot.men"
if request.path.startswith("/api/"):
return jsonify(Response().error("Not Found").to_json()), 404
try:
@@ -263,13 +293,14 @@ class AstrBotDashboard:
logging.getLogger(self.app.name).removeHandler(default_handler)
def _init_routes(self, db: BaseDatabase):
UpdateRoute(
self.context, self.core_lifecycle.astrbot_updator, self.core_lifecycle
)
astrbot_updator = self.core_lifecycle.astrbot_updator
plugin_manager = self.core_lifecycle.plugin_manager
assert astrbot_updator is not None
assert plugin_manager is not None
UpdateRoute(self.context, astrbot_updator, self.core_lifecycle)
StatRoute(self.context, db, self.core_lifecycle)
PluginRoute(
self.context, self.core_lifecycle, self.core_lifecycle.plugin_manager
)
PluginRoute(self.context, self.core_lifecycle, plugin_manager)
self.command_route = CommandRoute(self.context)
self.cr = ConfigRoute(self.context, self.core_lifecycle)
@@ -308,21 +339,24 @@ class AstrBotDashboard:
self.app.add_url_rule(
"/api/plug/<path:subpath>",
view_func=self.srv_plug_route,
view_func=self.guarded_srv_plug_route,
methods=["GET", "POST"],
)
def _init_plugin_route_index(self):
"""将插件路由索引,避免 O(n) 查找"""
self._plugin_route_map: dict[tuple[str, str], Callable] = {}
if self.core_lifecycle.star_context.registered_web_apis is None:
self.core_lifecycle.star_context.registered_web_apis = []
star_context = self.core_lifecycle.star_context
if star_context is None:
return
if star_context.registered_web_apis is None:
star_context.registered_web_apis = []
for (
route,
handler,
methods,
_,
) in self.core_lifecycle.star_context.registered_web_apis:
) in star_context.registered_web_apis:
for method in methods:
self._plugin_route_map[(route, method)] = handler
@@ -334,6 +368,48 @@ class AstrBotDashboard:
logger.info("Initialized random JWT secret for dashboard.")
self._jwt_secret = dashboard_cfg["jwt_secret"]
async def guarded_srv_plug_route(
    self, subpath: str, *args, **kwargs
) -> ResponseReturnValue:
    """Serve a plugin-registered route, but only once the runtime guard passes.

    Returns the guard's response when ``_maybe_runtime_guard`` blocks the
    current request path; otherwise delegates to ``srv_plug_route`` with the
    same arguments.
    """
    blocked = self._maybe_runtime_guard(request.path)
    if blocked is None:
        return await self.srv_plug_route(subpath, *args, **kwargs)
    return blocked
def _should_bypass_runtime_guard(self, path: str) -> bool:
return any(
path.startswith(prefix)
for prefix in self.RUNTIME_BYPASS_ENDPOINT_PREFIXES
)
def _should_allow_failed_runtime_recovery(self, path: str) -> bool:
if not (
self.core_lifecycle.runtime_failed
or self.core_lifecycle.runtime_bootstrap_error is not None
):
return False
return any(
path.startswith(prefix)
for prefix in self.RUNTIME_FAILED_RECOVERY_ENDPOINT_PREFIXES
)
def _maybe_runtime_guard(
    self,
    path: str,
    *,
    include_failure_details: bool = True,
) -> ResponseReturnValue | None:
    """Return a runtime-not-ready response for ``path``, or None to let it through.

    A request passes when its path is on the bypass list, when it is a
    recovery endpoint and the runtime has failed, or when the runtime is ready.

    Args:
        path: Request path to check.
        include_failure_details: Forwarded to ``runtime_loading_response``;
            False keeps the bootstrap error text out of the response.
    """
    exempt = self._should_bypass_runtime_guard(
        path
    ) or self._should_allow_failed_runtime_recovery(path)
    if exempt or is_runtime_request_ready(self.core_lifecycle):
        return None
    return runtime_loading_response(
        self.core_lifecycle,
        include_failure_details=include_failure_details,
    )
async def auth_middleware(self):
# 放行CORS预检请求
if request.method == "OPTIONS":
@@ -372,9 +448,21 @@ class AstrBotDashboard:
g.api_key_scopes = scopes
g.username = f"api_key:{api_key.key_id}"
await self.db.touch_api_key(api_key.key_id)
guard_resp = self._maybe_runtime_guard(
request.path,
include_failure_details=False,
)
if guard_resp is not None:
return guard_resp
return None
if any(request.path.startswith(p) for p in self.ALLOWED_ENDPOINT_PREFIXES):
guard_resp = self._maybe_runtime_guard(
request.path,
include_failure_details=False,
)
if guard_resp is not None:
return guard_resp
return None
token = request.headers.get("Authorization")
@@ -394,14 +482,25 @@ class AstrBotDashboard:
except jwt.PyJWTError:
return self._unauthorized("Token 无效")
guard_resp = self._maybe_runtime_guard(request.path)
if guard_resp is not None:
return guard_resp
@staticmethod
def _unauthorized(msg: str):
r = jsonify(Response().error(msg).to_json())
r.status_code = 401
return r
def _get_plugin_handler(self, subpath: str, method: str) -> Callable | None:
handler = self._plugin_route_map.get((f"/{subpath}", method))
if handler is not None:
return handler
self._init_plugin_route_index()
return self._plugin_route_map.get((f"/{subpath}", method))
async def srv_plug_route(self, subpath: str, *args, **kwargs):
handler = self._plugin_route_map.get((f"/{subpath}", request.method))
handler = self._get_plugin_handler(subpath, request.method)
if not handler:
return jsonify(Response().error("未找到该路由").to_json())
@@ -481,10 +580,10 @@ class AstrBotDashboard:
"""Run dashboard server (blocking)"""
if self._webui_fallback:
logger.warning(
"前端未内置或未初始化回退到仅启动后端请访问在线面板dash.astrbot.men"
"前端未内置或未初始化, 回退到仅启动后端. 请访问在线面板: dash.astrbot.men"
)
elif not self.enable_webui:
logger.warning("前端已禁用请访问在线面板dash.astrbot.men")
logger.warning("前端已禁用, 请访问在线面板: dash.astrbot.men")
dashboard_config = self.config.get("dashboard", {})
host_value = os.environ.get("ASTRBOT_HOST") or dashboard_config.get(

View File

@@ -9,8 +9,9 @@ import uuid
import zipfile
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch
import anyio
import pytest
import pytest_asyncio
from quart import Quart
@@ -18,11 +19,17 @@ from werkzeug.datastructures import FileStorage
from astrbot.cli.commands.cmd_conf import hash_dashboard_password_secure
from astrbot.core import LogBroker
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle, LifecycleState
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.star import star_registry
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.utils.pip_installer import PipInstallError
from astrbot.dashboard.routes.live_chat import (
LiveChatRoute,
LiveChatSession,
_ReceiveTimeoutSentinel,
)
from astrbot.dashboard.routes.open_api import OpenApiRoute
from astrbot.dashboard.routes.plugin import PluginRoute
from astrbot.dashboard.server import AstrBotDashboard, _expand_env_placeholders
from tests.fixtures.helpers import (
@@ -150,6 +157,86 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc
return {"Authorization": f"Bearer {token}"}
def _set_runtime_loading(core_lifecycle: AstrBotCoreLifecycle) -> None:
    """Put the lifecycle into CORE_READY with the runtime not yet ready.

    Clears both runtime events, drops any recorded bootstrap error and closes
    the request gate.
    """
    lifecycle = core_lifecycle
    lifecycle._set_lifecycle_state(LifecycleState.CORE_READY)
    # Neither "ready" nor "failed" has been signalled yet.
    lifecycle.runtime_ready_event.clear()
    lifecycle.runtime_failed_event.clear()
    lifecycle.runtime_bootstrap_error = None
    lifecycle.runtime_request_ready = False
def _restore_runtime_ready(core_lifecycle: AstrBotCoreLifecycle) -> None:
    """Reset the lifecycle to the fully ready state (used in test cleanup)."""
    lifecycle = core_lifecycle
    lifecycle._set_lifecycle_state(LifecycleState.RUNTIME_READY)
    # Ready is signalled, failure is not, and requests are accepted again.
    lifecycle.runtime_ready_event.set()
    lifecycle.runtime_failed_event.clear()
    lifecycle.runtime_bootstrap_error = None
    lifecycle.runtime_request_ready = True
def _set_runtime_failed(
    core_lifecycle: AstrBotCoreLifecycle,
    message: str = "Runtime bootstrap failed.",
) -> None:
    """Put the lifecycle into RUNTIME_FAILED with a recorded bootstrap error.

    Args:
        core_lifecycle: Lifecycle object to mutate.
        message: Text of the ``RuntimeError`` stored as the bootstrap error.
    """
    lifecycle = core_lifecycle
    lifecycle._set_lifecycle_state(LifecycleState.RUNTIME_FAILED)
    # Failure is signalled, readiness is not, and the request gate is closed.
    lifecycle.runtime_ready_event.clear()
    lifecycle.runtime_failed_event.set()
    lifecycle.runtime_bootstrap_error = RuntimeError(message)
    lifecycle.runtime_request_ready = False
def _set_runtime_starting(core_lifecycle: AstrBotCoreLifecycle) -> None:
    """Set RUNTIME_READY with the ready event set but the request gate closed.

    This is the state where ``runtime_request_ready`` is still False even
    though the lifecycle already reports readiness.
    """
    lifecycle = core_lifecycle
    lifecycle._set_lifecycle_state(LifecycleState.RUNTIME_READY)
    lifecycle.runtime_ready_event.set()
    lifecycle.runtime_failed_event.clear()
    lifecycle.runtime_bootstrap_error = None
    # The only difference from the fully ready state: requests are refused.
    lifecycle.runtime_request_ready = False
def _assert_runtime_loading_response(
    data: dict,
    state: LifecycleState = LifecycleState.CORE_READY,
) -> None:
    """Assert that ``data`` is exactly the "runtime still loading" error body.

    Args:
        data: Decoded JSON body of the response under test.
        state: Lifecycle state whose ``.value`` is expected in the payload.
    """
    expected_status = {
        "state": state.value,
        "ready": False,
        "failed": False,
        "failure_message": None,
    }
    expected_body = {
        "status": "error",
        "message": "Runtime is still loading. Please try again shortly.",
        "data": expected_status,
    }
    assert data == expected_body
def _assert_runtime_failed_response(
    data: dict,
    message: str | None,
    state: SimpleNamespace | LifecycleState = LifecycleState.RUNTIME_FAILED,
) -> None:
    """Assert that ``data`` is exactly the "runtime bootstrap failed" error body.

    Args:
        data: Decoded JSON body of the response under test.
        message: Expected ``failure_message``; None when details must be hidden.
        state: Object whose ``.value`` is expected as the reported state.
    """
    expected_status = {
        "state": state.value,
        "ready": False,
        "failed": True,
        "failure_message": message,
    }
    expected_body = {
        "status": "error",
        "message": "Runtime bootstrap failed. Please check logs and retry.",
        "data": expected_status,
    }
    assert data == expected_body
def _build_plugin_upload_files() -> dict:
    """Build request kwargs carrying a fake plugin zip as a multipart upload."""
    archive = FileStorage(
        stream=io.BytesIO(b"fake-plugin-archive"),
        filename="demo-plugin.zip",
        content_type="application/zip",
    )
    return {"files": {"file": archive}}
@pytest.mark.asyncio
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Tests the login functionality with both wrong and correct credentials."""
@@ -517,6 +604,847 @@ async def test_batch_delete_sessions_uses_batch_lookup(
assert called["batch_lookup_count"] == 1
@pytest.mark.asyncio
async def test_runtime_status_reports_current_state(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
):
    """The runtime-status endpoint mirrors the ready, loading and failed states."""
    test_client = app.test_client()

    async def fetch_status() -> dict:
        # The endpoint always answers 200/"ok"; the lifecycle state is in the payload.
        response = await test_client.get(
            "/api/stat/runtime-status",
            headers=authenticated_header,
        )
        assert response.status_code == 200
        body = await response.get_json()
        assert body["status"] == "ok"
        return body["data"]

    ready_status = await fetch_status()
    assert ready_status["state"] == LifecycleState.RUNTIME_READY.value
    assert ready_status["ready"] is True
    assert ready_status["failed"] is False
    assert ready_status["failure_message"] is None

    _set_runtime_loading(core_lifecycle_td)
    try:
        loading_status = await fetch_status()
        assert loading_status["state"] == LifecycleState.CORE_READY.value
        assert loading_status["ready"] is False
        assert loading_status["failed"] is False
        assert loading_status["failure_message"] is None
    finally:
        _restore_runtime_ready(core_lifecycle_td)

    _set_runtime_failed(core_lifecycle_td, "bootstrap exploded during provider init")
    try:
        failed_status = await fetch_status()
        assert failed_status["state"] == "runtime_failed"
        assert failed_status["ready"] is False
        assert failed_status["failed"] is True
        assert (
            failed_status["failure_message"]
            == "bootstrap exploded during provider init"
        )
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("path", "headers", "misleading_field"),
    [
        ("/api/stat/get", "auth", "platform"),
        ("/api/stat/start-time", None, "start_time"),
    ],
    ids=["stat-get", "stat-start-time"],
)
async def test_stat_endpoints_return_503_while_runtime_loading(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
    path: str,
    headers: str | None,
    misleading_field: str,
):
    """Stat endpoints answer 503 with the loading payload instead of real data."""
    test_client = app.test_client()
    # Only the "auth" cases send the dashboard token.
    request_kwargs = {"headers": authenticated_header} if headers == "auth" else {}

    _set_runtime_loading(core_lifecycle_td)
    try:
        response = await test_client.get(path, **request_kwargs)
        assert response.status_code == 503
        body = await response.get_json()
        _assert_runtime_loading_response(body)
        assert misleading_field not in body["data"]
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("path", "headers", "misleading_field"),
    [
        ("/api/stat/get", "auth", "platform"),
        ("/api/stat/start-time", None, "start_time"),
    ],
    ids=["stat-get", "stat-start-time"],
)
async def test_stat_endpoints_return_failure_aware_503_after_bootstrap_failure(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
    path: str,
    headers: str | None,
    misleading_field: str,
):
    """Stat endpoints answer 503 with the failed payload after a bootstrap failure.

    The start-time case expects the failure message to be hidden.
    """
    test_client = app.test_client()
    request_kwargs = {"headers": authenticated_header} if headers == "auth" else {}
    bootstrap_message = "runtime bootstrap failed in plugin reload"
    if path == "/api/stat/start-time":
        expected_detail = None
    else:
        expected_detail = bootstrap_message

    _set_runtime_failed(core_lifecycle_td, bootstrap_message)
    try:
        response = await test_client.get(path, **request_kwargs)
        assert response.status_code == 503
        body = await response.get_json()
        _assert_runtime_failed_response(body, expected_detail)
        assert misleading_field not in body["data"]
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_runtime_dependent_dashboard_route_returns_503_while_runtime_starting(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
):
    """A guarded route is refused while RUNTIME_READY is set but requests are not yet accepted."""
    client = app.test_client()
    _set_runtime_starting(core_lifecycle_td)
    try:
        response = await client.get("/api/config/get", headers=authenticated_header)
        assert response.status_code == 503
        body = await response.get_json()
        _assert_runtime_loading_response(body, state=LifecycleState.RUNTIME_READY)
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_openapi_configs_returns_503_while_runtime_loading(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
):
    """An API-key request to /api/v1/configs gets 503 without the config list being read."""
    test_client = app.test_client()
    key_request = {
        "name": f"runtime-guard-{uuid.uuid4().hex[:8]}",
        "scopes": ["config"],
    }
    create_res = await test_client.post(
        "/api/apikey/create",
        json=key_request,
        headers=authenticated_header,
    )
    assert create_res.status_code == 200
    api_key = (await create_res.get_json())["data"]["api_key"]

    config_mgr = core_lifecycle_td.astrbot_config_mgr
    assert config_mgr is not None
    # Any attempt to read the config list while loading fails the test.
    forbidden_read = MagicMock(
        side_effect=AssertionError("config list should not be read")
    )
    monkeypatch.setattr(config_mgr, "get_conf_list", forbidden_read)

    _set_runtime_loading(core_lifecycle_td)
    try:
        response = await test_client.get(
            "/api/v1/configs",
            headers={"X-API-Key": api_key},
        )
        assert response.status_code == 503
        body = await response.get_json()
        _assert_runtime_loading_response(body)
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_openapi_failure_response_hides_bootstrap_error_details(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
):
    """API-key callers get the failed 503 payload without the bootstrap error text."""
    test_client = app.test_client()
    key_request = {
        "name": f"runtime-failed-{uuid.uuid4().hex[:8]}",
        "scopes": ["config"],
    }
    create_res = await test_client.post(
        "/api/apikey/create",
        json=key_request,
        headers=authenticated_header,
    )
    assert create_res.status_code == 200
    api_key = (await create_res.get_json())["data"]["api_key"]

    _set_runtime_failed(core_lifecycle_td, "provider secret exploded")
    try:
        response = await test_client.get(
            "/api/v1/configs",
            headers={"X-API-Key": api_key},
        )
        assert response.status_code == 503
        body = await response.get_json()
        # failure_message must be None: details are not shown to API-key callers.
        _assert_runtime_failed_response(body, None)
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_openapi_chat_ws_closes_while_runtime_loading(
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
):
    """The OpenAPI chat websocket reports an error and closes before reading any message."""
    # Bypass __init__ and wire in only the attributes this test sets up.
    route = OpenApiRoute.__new__(OpenApiRoute)
    route.core_lifecycle = core_lifecycle_td
    route._authenticate_chat_ws_api_key = AsyncMock(return_value=(True, None))
    route._send_chat_ws_error = AsyncMock()

    ws_stub = MagicMock()
    ws_stub.receive_json = AsyncMock(
        side_effect=AssertionError("websocket should not consume messages")
    )
    ws_stub.close = AsyncMock()
    monkeypatch.setattr("astrbot.dashboard.routes.open_api.websocket", ws_stub)

    _set_runtime_loading(core_lifecycle_td)
    try:
        await route.chat_ws()
        route._send_chat_ws_error.assert_awaited_once()
        ws_stub.close.assert_awaited_once()
        ws_stub.receive_json.assert_not_awaited()
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_openapi_chat_ws_closes_when_runtime_stops_mid_session(
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
):
    """The OpenAPI chat websocket closes when the runtime stops after one message.

    The first ``receive_json`` call flips ``runtime_request_ready`` to False;
    the route must then send an error and close instead of reading again.
    """
    route = OpenApiRoute.__new__(OpenApiRoute)
    route.core_lifecycle = core_lifecycle_td
    route._authenticate_chat_ws_api_key = AsyncMock(return_value=(True, None))
    route._send_chat_ws_error = AsyncMock()
    _restore_runtime_ready(core_lifecycle_td)

    fake_websocket = MagicMock()
    fake_websocket.send_json = AsyncMock()
    fake_websocket.close = AsyncMock()
    receive_calls = 0

    async def receive_json():
        nonlocal receive_calls
        receive_calls += 1
        if receive_calls > 1:
            raise AssertionError("websocket should close before next receive")
        # Simulate the runtime stopping while the session is open.
        core_lifecycle_td.runtime_request_ready = False
        return {"t": "ping"}

    fake_websocket.receive_json = AsyncMock(side_effect=receive_json)
    monkeypatch.setattr("astrbot.dashboard.routes.open_api.websocket", fake_websocket)

    try:
        await route.chat_ws()
        route._send_chat_ws_error.assert_awaited_once()
        fake_websocket.send_json.assert_not_awaited()
        fake_websocket.close.assert_awaited_once()
    finally:
        # receive_json leaves the lifecycle not-ready; restore it as the sibling
        # tests do so the state cannot leak into tests sharing the fixture.
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_live_chat_ws_closes_while_runtime_loading(
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
    tmp_path,
):
    """The live-chat websocket loop closes without reading while the runtime loads."""
    # Bypass __init__ and assign the route attributes by hand.
    route = LiveChatRoute.__new__(LiveChatRoute)
    route.core_lifecycle = core_lifecycle_td
    route.db = MagicMock()
    route.plugin_manager = core_lifecycle_td.plugin_manager
    route.platform_history_mgr = core_lifecycle_td.platform_message_history_manager
    route.sessions = {}
    route.config = core_lifecycle_td.astrbot_config
    route.attachments_dir = str(tmp_path / "attachments")
    route.legacy_img_dir = str(tmp_path / "legacy")

    ws_stub = MagicMock()
    ws_stub.args = {"token": "test-token"}
    ws_stub.receive_json = AsyncMock(
        side_effect=AssertionError("live chat websocket should not consume messages")
    )
    ws_stub.close = AsyncMock()
    monkeypatch.setattr("astrbot.dashboard.routes.live_chat.websocket", ws_stub)

    _set_runtime_loading(core_lifecycle_td)
    try:
        # Token decoding is stubbed so the fake token yields a username.
        jwt_patch = patch(
            "astrbot.dashboard.routes.live_chat.jwt.decode",
            return_value={"username": "astrbot"},
        )
        with jwt_patch:
            await route._unified_ws_loop(force_ct="live")
        ws_stub.close.assert_awaited_once()
        ws_stub.receive_json.assert_not_awaited()
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_live_chat_ws_closes_when_runtime_stops_mid_session(
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
    tmp_path,
):
    """The live-chat websocket loop closes when the runtime stops after one message.

    Handling the first message flips ``runtime_request_ready`` to False; the
    loop must then close instead of calling ``receive_json`` again.
    """
    route = LiveChatRoute.__new__(LiveChatRoute)
    route.core_lifecycle = core_lifecycle_td
    route.db = MagicMock()
    route.plugin_manager = core_lifecycle_td.plugin_manager
    route.platform_history_mgr = core_lifecycle_td.platform_message_history_manager
    route.sessions = {}
    route.config = core_lifecycle_td.astrbot_config
    route.attachments_dir = str(tmp_path / "attachments")
    route.legacy_img_dir = str(tmp_path / "legacy")
    # Simulate the runtime stopping as a side effect of handling a message.
    route._handle_message = AsyncMock(
        side_effect=lambda *_args, **_kwargs: setattr(
            core_lifecycle_td,
            "runtime_request_ready",
            False,
        )
    )
    _restore_runtime_ready(core_lifecycle_td)

    fake_websocket = MagicMock()
    fake_websocket.args = {"token": "test-token"}
    fake_websocket.close = AsyncMock()
    fake_websocket.receive_json = AsyncMock(
        side_effect=[
            {},
            AssertionError("live chat websocket should close before next receive"),
        ]
    )
    monkeypatch.setattr("astrbot.dashboard.routes.live_chat.websocket", fake_websocket)

    try:
        with patch(
            "astrbot.dashboard.routes.live_chat.jwt.decode",
            return_value={"username": "astrbot"},
        ):
            await route._unified_ws_loop(force_ct="live")
        route._handle_message.assert_awaited_once()
        fake_websocket.close.assert_awaited_once()
    finally:
        # _handle_message leaves the lifecycle not-ready; restore it as the
        # sibling tests do so the state cannot leak into tests sharing the fixture.
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_live_chat_recv_guard_polls_until_runtime_stops(
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
):
    """``_recv_ws_json_guarded`` reports a timeout, then closes once the runtime stops.

    The first call times out while the runtime is still ready and yields the
    timeout sentinel; the timeout also flips ``runtime_request_ready`` to
    False, so the second call returns None and closes the websocket.
    """
    route = LiveChatRoute.__new__(LiveChatRoute)
    route.core_lifecycle = core_lifecycle_td
    fake_websocket = MagicMock()
    fake_websocket.close = AsyncMock()
    fake_websocket.receive_json = AsyncMock(
        side_effect=AssertionError("receive_json should be wrapped by wait_for")
    )
    monkeypatch.setattr("astrbot.dashboard.routes.live_chat.websocket", fake_websocket)
    _restore_runtime_ready(core_lifecycle_td)

    async def fake_wait_for(awaitable, **kwargs):
        del kwargs
        # Dispose of the un-awaited coroutine to avoid "never awaited" warnings.
        if hasattr(awaitable, "close"):
            awaitable.close()
        core_lifecycle_td.runtime_request_ready = False
        raise asyncio.TimeoutError

    monkeypatch.setattr(
        "astrbot.dashboard.routes.live_chat.asyncio.wait_for",
        fake_wait_for,
    )

    try:
        first = await route._recv_ws_json_guarded(wait_timeout=0.01)
        assert isinstance(first, _ReceiveTimeoutSentinel)
        second = await route._recv_ws_json_guarded(wait_timeout=0.01)
        assert second is None
        fake_websocket.close.assert_awaited_once()
    finally:
        # fake_wait_for leaves the lifecycle not-ready; restore it as the sibling
        # tests do so the state cannot leak into tests sharing the fixture.
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_live_chat_subscription_stops_forwarding_when_runtime_stops(
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
):
    """A chat subscription forwards nothing and drops its queue once the runtime stops."""
    route = LiveChatRoute.__new__(LiveChatRoute)
    route.core_lifecycle = core_lifecycle_td
    route._send_chat_payload = AsyncMock(
        side_effect=AssertionError("subscription should not forward after runtime stop")
    )
    session = LiveChatSession("session", "astrbot")
    session.chat_subscriptions["chat-session"] = "request-id"
    session.chat_subscription_tasks["chat-session"] = MagicMock()
    fake_queue = MagicMock()
    fake_queue.get = AsyncMock(
        side_effect=[{"type": "plain", "data": "hello"}, asyncio.CancelledError()]
    )
    remove_back_queue = MagicMock()
    monkeypatch.setattr(
        "astrbot.dashboard.routes.live_chat.webchat_queue_mgr.get_or_create_back_queue",
        MagicMock(return_value=fake_queue),
    )
    monkeypatch.setattr(
        "astrbot.dashboard.routes.live_chat.webchat_queue_mgr.remove_back_queue",
        remove_back_queue,
    )

    core_lifecycle_td.runtime_request_ready = False
    try:
        await route._forward_chat_subscription(session, "chat-session", "request-id")
        route._send_chat_payload.assert_not_awaited()
        remove_back_queue.assert_called_once_with("request-id")
    finally:
        # Restore the lifecycle as the sibling tests do so the not-ready state
        # cannot leak into tests sharing the fixture.
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_public_start_time_failure_response_does_not_leak_bootstrap_error(
    app: Quart,
    core_lifecycle_td: AstrBotCoreLifecycle,
):
    """The unauthenticated start-time endpoint hides the bootstrap error text."""
    client = app.test_client()
    expected_body = {
        "status": "error",
        "message": "Runtime bootstrap failed. Please check logs and retry.",
        "data": {
            "state": LifecycleState.RUNTIME_FAILED.value,
            "ready": False,
            "failed": True,
            "failure_message": None,
        },
    }

    _set_runtime_failed(core_lifecycle_td, "provider api key leaked in stacktrace")
    try:
        # No auth header: this request is made as an anonymous caller.
        response = await client.get("/api/stat/start-time")
        assert response.status_code == 503
        body = await response.get_json()
        assert body == expected_body
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_public_webhook_failure_response_does_not_leak_bootstrap_error(
    app: Quart,
    core_lifecycle_td: AstrBotCoreLifecycle,
):
    """The unauthenticated webhook endpoint gets 503 without the bootstrap error text."""
    client = app.test_client()
    _set_runtime_failed(core_lifecycle_td, "provider webhook secret leaked")
    try:
        response = await client.post("/api/platform/webhook/test-webhook")
        assert response.status_code == 503
        body = await response.get_json()
        _assert_runtime_failed_response(body, None)
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_config_routes_remain_available_after_runtime_bootstrap_failure(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
):
    """Config routes still answer 200 after a bootstrap failure."""
    client = app.test_client()
    _set_runtime_failed(core_lifecycle_td, "broken provider config")
    try:
        response = await client.get("/api/config/get", headers=authenticated_header)
        assert response.status_code == 200
        body = await response.get_json()
        assert body["status"] == "ok"
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("method", "path", "request_kwargs_factory", "setup_recovery"),
    [
        (
            "get",
            "/api/plugin/source/get-failed-plugins",
            lambda: {},
            lambda pm: pm.failed_plugin_dict.update(
                {"broken-plugin": {"error": "boom"}}
            ),
        ),
        (
            "post",
            "/api/plugin/reload-failed",
            lambda: {"json": {"dir_name": "broken-plugin"}},
            lambda pm: setattr(
                pm,
                "reload_failed_plugin",
                AsyncMock(return_value=(True, None)),
            ),
        ),
        (
            "post",
            "/api/plugin/uninstall-failed",
            lambda: {"json": {"dir_name": "broken-plugin"}},
            lambda pm: setattr(pm, "uninstall_failed_plugin", AsyncMock()),
        ),
    ],
    ids=["get-failed-plugins", "reload-failed", "uninstall-failed"],
)
async def test_failed_plugin_recovery_routes_remain_available_after_bootstrap_failure(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
    method: str,
    path: str,
    request_kwargs_factory,
    setup_recovery,
):
    """Failed-plugin recovery routes still answer 200 after a bootstrap failure.

    Each case first prepares the plugin manager via ``setup_recovery`` (seeding
    failed-plugin data or stubbing the recovery coroutine), then calls the
    route while the lifecycle is in the failed state.
    """
    client = app.test_client()
    plugin_manager = core_lifecycle_td.plugin_manager
    assert plugin_manager is not None
    setup_recovery(plugin_manager)

    _set_runtime_failed(core_lifecycle_td, "plugin bootstrap failed")
    try:
        send = getattr(client, method)
        response = await send(
            path,
            headers=authenticated_header,
            **request_kwargs_factory(),
        )
        assert response.status_code == 200
        body = await response.get_json()
        assert body["status"] == "ok"
    finally:
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("method", "path", "request_kwargs_factory", "guard_attr"),
    [
        ("get", "/api/plugin/get", lambda: {}, "context.get_all_stars"),
        (
            "get",
            "/api/plugin/source/get-failed-plugins",
            lambda: {},
            None,
        ),
        (
            "post",
            "/api/plugin/install",
            lambda: {"json": {"url": "https://example.com/plugin"}},
            "install_plugin",
        ),
        (
            "post",
            "/api/plugin/install-upload",
            _build_plugin_upload_files,
            "install_plugin_from_file",
        ),
        (
            "post",
            "/api/plugin/update",
            lambda: {"json": {"name": "demo-plugin"}},
            "update_plugin",
        ),
        (
            "post",
            "/api/plugin/update-all",
            lambda: {"json": {"names": ["demo-plugin"]}},
            "update_plugin",
        ),
        (
            "post",
            "/api/plugin/uninstall",
            lambda: {"json": {"name": "demo-plugin"}},
            "uninstall_plugin",
        ),
        (
            "post",
            "/api/plugin/uninstall-failed",
            lambda: {"json": {"dir_name": "demo-plugin"}},
            "uninstall_failed_plugin",
        ),
        (
            "post",
            "/api/plugin/off",
            lambda: {"json": {"name": "demo-plugin"}},
            "turn_off_plugin",
        ),
        (
            "post",
            "/api/plugin/on",
            lambda: {"json": {"name": "demo-plugin"}},
            "turn_on_plugin",
        ),
        (
            "post",
            "/api/plugin/reload",
            lambda: {"json": {"name": "demo-plugin"}},
            "reload",
        ),
        (
            "post",
            "/api/plugin/reload-failed",
            lambda: {"json": {"dir_name": "demo-plugin"}},
            "reload_failed_plugin",
        ),
        (
            "get",
            "/api/plugin/readme?name=demo-plugin",
            lambda: {},
            "context.get_all_stars",
        ),
        (
            "get",
            "/api/plugin/changelog?name=demo-plugin",
            lambda: {},
            "context.get_all_stars",
        ),
    ],
    ids=[
        "get",
        "get-failed-plugins",
        "install",
        "install-upload",
        "update",
        "update-all",
        "uninstall",
        "uninstall-failed",
        "off",
        "on",
        "reload",
        "reload-failed",
        "readme",
        "changelog",
    ],
)
async def test_plugin_api_returns_503_while_runtime_loading(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
    method: str,
    path: str,
    request_kwargs_factory,
    guard_attr: str | None,
):
    """Check every plugin API route answers 503 while the runtime is loading.

    Args:
        app: Quart application under test.
        authenticated_header: Headers sent with the request.
        core_lifecycle_td: Lifecycle whose runtime state is forced to "loading".
        monkeypatch: pytest fixture used to install the tripwire stubs.
        method: Name of the test-client method to call (``"get"``/``"post"``).
        path: Route to request.
        request_kwargs_factory: Zero-argument callable building extra request
            keyword arguments.
        guard_attr: What to booby-trap before the request. The sentinel
            ``"context.get_all_stars"`` patches ``plugin_manager.context``;
            any other string names a plugin-manager attribute; ``None`` (used
            by the ``get-failed-plugins`` case) installs no tripwire.
    """

    def _reject_plugin_state_read():
        # Tripwire: reading plugin state while loading is a test failure.
        raise AssertionError("plugin state should not be read")

    test_client = app.test_client()
    plugin_manager = core_lifecycle_td.plugin_manager
    assert plugin_manager is not None
    if guard_attr == "context.get_all_stars":
        monkeypatch.setattr(
            plugin_manager.context,
            "get_all_stars",
            _reject_plugin_state_read,
        )
    elif guard_attr is not None:
        monkeypatch.setattr(
            plugin_manager,
            guard_attr,
            AsyncMock(side_effect=AssertionError(f"{guard_attr} should not be called")),
        )
    _set_runtime_loading(core_lifecycle_td)
    try:
        request_kwargs = request_kwargs_factory()
        response = await getattr(test_client, method)(
            path,
            headers=authenticated_header,
            **request_kwargs,
        )
        assert response.status_code == 503
        data = await response.get_json()
        _assert_runtime_loading_response(data)
    finally:
        # Put the lifecycle back into the ready state even when an assert fails.
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("method", "path", "request_kwargs_factory", "guard_attr"),
    [
        pytest.param(
            "get",
            "/api/plugin/get",
            lambda: {},
            "context.get_all_stars",
            id="get",
        ),
        pytest.param(
            "post",
            "/api/plugin/install",
            lambda: {"json": {"url": "https://example.com/plugin"}},
            "install_plugin",
            id="install",
        ),
    ],
)
async def test_plugin_api_returns_failed_response_after_runtime_bootstrap_failure(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
    monkeypatch,
    method: str,
    path: str,
    request_kwargs_factory,
    guard_attr: str,
):
    """Check plugin API routes answer 503 with the failure payload after bootstrap fails.

    Before the request, the attribute selected by ``guard_attr`` is swapped for
    a stub that raises ``AssertionError``. The sentinel
    ``"context.get_all_stars"`` targets ``plugin_manager.context``; any other
    non-``None`` value names an attribute on the plugin manager itself.
    """

    def _reject_plugin_state_read():
        raise AssertionError("plugin state should not be read")

    client = app.test_client()
    manager = core_lifecycle_td.plugin_manager
    assert manager is not None

    # Work out which tripwire (if any) to install, then install it once.
    tripwire = None
    if guard_attr == "context.get_all_stars":
        tripwire = (manager.context, "get_all_stars", _reject_plugin_state_read)
    elif guard_attr is not None:
        tripwire = (
            manager,
            guard_attr,
            AsyncMock(side_effect=AssertionError(f"{guard_attr} should not be called")),
        )
    if tripwire is not None:
        monkeypatch.setattr(*tripwire)

    _set_runtime_failed(core_lifecycle_td, "plugin bootstrap failed")
    try:
        extra_kwargs = request_kwargs_factory()
        send_request = getattr(client, method)
        reply = await send_request(path, headers=authenticated_header, **extra_kwargs)
        assert reply.status_code == 503
        payload = await reply.get_json()
        _assert_runtime_failed_response(payload, "plugin bootstrap failed")
    finally:
        # Put the lifecycle back into the ready state even when an assert fails.
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_plugin_web_route_returns_503_while_runtime_loading(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
):
    """Check a plugin-registered web route is answered with 503 while loading.

    A probe handler is temporarily installed as the only registered web API.
    The request must yield the runtime-loading payload and the probe handler
    must never run.
    """
    client = app.test_client()
    handler_ran = False
    context = core_lifecycle_td.star_context
    assert context is not None

    async def probe_route(*args, **kwargs):
        nonlocal handler_ran
        handler_ran = True
        return {"status": "ok", "message": None, "data": {"called": True}}

    web_apis = context.registered_web_apis
    saved_web_apis = [*web_apis]
    # Slice assignment mutates the list in place, so the same list object
    # stays attached to the context.
    web_apis[:] = [
        ("/runtime-guard-test", probe_route, ["GET"], "runtime guard test"),
    ]
    _set_runtime_loading(core_lifecycle_td)
    try:
        reply = await client.get(
            "/api/plug/runtime-guard-test",
            headers=authenticated_header,
        )
        assert reply.status_code == 503
        payload = await reply.get_json()
        _assert_runtime_loading_response(payload)
        assert handler_ran is False
    finally:
        # Undo the route swap first, then mark the runtime ready again.
        web_apis[:] = saved_web_apis
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_plugin_web_route_returns_failed_response_after_runtime_bootstrap_failure(
    app: Quart,
    authenticated_header: dict,
    core_lifecycle_td: AstrBotCoreLifecycle,
):
    """Check a plugin-registered web route yields the failure payload after bootstrap fails.

    A probe handler is temporarily installed as the only registered web API.
    The request must come back as 503 with the runtime-failed payload and the
    probe handler must never run.
    """
    client = app.test_client()
    handler_ran = False
    context = core_lifecycle_td.star_context
    assert context is not None

    async def probe_route(*args, **kwargs):
        nonlocal handler_ran
        handler_ran = True
        return {"status": "ok", "message": None, "data": {"called": True}}

    web_apis = context.registered_web_apis
    saved_web_apis = [*web_apis]
    probe_entry = (
        "/runtime-failed-guard-test",
        probe_route,
        ["GET"],
        "runtime guard test",
    )
    # Slice assignment mutates the list in place, so the same list object
    # stays attached to the context.
    web_apis[:] = [probe_entry]
    _set_runtime_failed(core_lifecycle_td, "plugin web runtime bootstrap failed")
    try:
        reply = await client.get(
            "/api/plug/runtime-failed-guard-test",
            headers=authenticated_header,
        )
        assert reply.status_code == 503
        payload = await reply.get_json()
        _assert_runtime_failed_response(
            payload,
            "plugin web runtime bootstrap failed",
        )
        assert handler_ran is False
    finally:
        # Undo the route swap first, then mark the runtime ready again.
        web_apis[:] = saved_web_apis
        _restore_runtime_ready(core_lifecycle_td)
@pytest.mark.asyncio
async def test_plugins(
app: Quart,
@@ -550,7 +1478,9 @@ async def test_plugins(
assert data["status"] == "ok"
# 使用 MockPluginBuilder 创建测试插件
plugin_store_path = core_lifecycle_td.plugin_manager.plugin_store_path
plugin_manager = core_lifecycle_td.plugin_manager
assert plugin_manager is not None
plugin_store_path = plugin_manager.plugin_store_path
builder = MockPluginBuilder(plugin_store_path)
# 定义测试插件
@@ -565,10 +1495,8 @@ async def test_plugins(
mock_update = create_mock_updater_update(builder)
# 设置 Mock
monkeypatch.setattr(
core_lifecycle_td.plugin_manager.updator, "install", mock_install
)
monkeypatch.setattr(core_lifecycle_td.plugin_manager.updator, "update", mock_update)
monkeypatch.setattr(plugin_manager.updator, "install", mock_install)
monkeypatch.setattr(plugin_manager.updator, "update", mock_update)
try:
# 插件安装
@@ -1103,7 +2031,7 @@ async def test_do_update(
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert os.path.exists(release_path)
assert await anyio.Path(release_path).exists()
@pytest.mark.asyncio

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,171 @@
"""Tests for InitialLoader."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from astrbot.core.initial_loader import InitialLoader
@pytest.mark.asyncio
async def test_initial_loader_start_awaits_initialize_core_and_schedules_runtime_bootstrap():
    """Verify ``InitialLoader.start`` runs core init, then schedules bootstrap.

    The first three recorded events must be ``initialize_core``,
    ``create_task`` and ``dashboard_init``. Exactly one task may be created,
    and it must be stored on ``lifecycle.runtime_bootstrap_task``.
    """
    loader = InitialLoader(MagicMock(), MagicMock())
    events: list[str] = []
    spawned_tasks: list[asyncio.Task] = []
    # Keep a handle on the genuine function before it gets patched below.
    genuine_create_task = asyncio.create_task

    def recording(label: str):
        """Build a coroutine function that appends ``label`` to ``events``."""

        async def _record() -> None:
            events.append(label)

        return _record

    lifecycle = MagicMock()
    lifecycle.dashboard_shutdown_event = asyncio.Event()
    lifecycle.runtime_bootstrap_task = None
    # ``initialize`` is booby-trapped and asserted as never called below.
    lifecycle.initialize = AsyncMock(
        side_effect=AssertionError("initialize should not be used")
    )
    lifecycle.initialize_core = AsyncMock(side_effect=recording("initialize_core"))
    lifecycle.bootstrap_runtime = AsyncMock(side_effect=recording("bootstrap_runtime"))
    lifecycle.start = AsyncMock(side_effect=recording("core_start"))

    dashboard = MagicMock()
    dashboard.run = AsyncMock(side_effect=recording("dashboard_run"))

    def build_dashboard(*args, **kwargs):
        del args, kwargs
        events.append("dashboard_init")
        return dashboard

    def tracking_create_task(coro, *args, **kwargs):
        events.append("create_task")
        new_task = genuine_create_task(coro, *args, **kwargs)
        spawned_tasks.append(new_task)
        return new_task

    lifecycle_patch = patch(
        "astrbot.core.initial_loader.AstrBotCoreLifecycle", return_value=lifecycle
    )
    dashboard_patch = patch(
        "astrbot.core.initial_loader.AstrBotDashboard",
        side_effect=build_dashboard,
    )
    create_task_patch = patch(
        "astrbot.core.initial_loader.asyncio.create_task",
        side_effect=tracking_create_task,
    )
    with lifecycle_patch, dashboard_patch, create_task_patch:
        await loader.start()

    lifecycle.initialize.assert_not_called()
    lifecycle.initialize_core.assert_awaited_once()
    lifecycle.bootstrap_runtime.assert_awaited_once()
    lifecycle.start.assert_awaited_once()
    dashboard.run.assert_awaited_once()
    assert events[:3] == ["initialize_core", "create_task", "dashboard_init"]
    assert len(spawned_tasks) == 1
    assert lifecycle.runtime_bootstrap_task is spawned_tasks[0]
@pytest.mark.asyncio
async def test_initial_loader_start_returns_without_partial_start_when_initialize_core_fails():
    """Verify ``InitialLoader.start`` does nothing further when core init raises.

    ``initialize_core`` is made to raise ``RuntimeError``. ``start`` is
    awaited without expecting an exception. Afterwards no dashboard may have
    been built, no task created, and neither ``bootstrap_runtime`` nor
    ``start`` called on the lifecycle.
    """
    loader = InitialLoader(MagicMock(), MagicMock())
    core_failure = RuntimeError("core init failed")

    lifecycle = MagicMock()
    lifecycle.runtime_bootstrap_task = None
    lifecycle.initialize_core = AsyncMock(side_effect=core_failure)
    lifecycle.bootstrap_runtime = AsyncMock()
    lifecycle.start = AsyncMock()

    lifecycle_patch = patch(
        "astrbot.core.initial_loader.AstrBotCoreLifecycle", return_value=lifecycle
    )
    with (
        lifecycle_patch,
        patch("astrbot.core.initial_loader.AstrBotDashboard") as dashboard_cls,
        patch("astrbot.core.initial_loader.asyncio.create_task") as create_task,
    ):
        await loader.start()

    lifecycle.initialize_core.assert_awaited_once()
    dashboard_cls.assert_not_called()
    create_task.assert_not_called()
    lifecycle.bootstrap_runtime.assert_not_called()
    lifecycle.start.assert_not_called()
    assert lifecycle.runtime_bootstrap_task is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("failing_component", "expected_order"),
    [
        ("core", ["initialize_core", "bootstrap_runtime", "core_start", "dashboard_run"]),
        (
            "dashboard",
            ["initialize_core", "bootstrap_runtime", "core_start", "dashboard_run"],
        ),
    ],
)
async def test_initial_loader_start_stops_lifecycle_when_runtime_task_raises(
    failing_component: str,
    expected_order: list[str],
):
    """Verify ``InitialLoader.start`` stops the lifecycle when a runtime task fails.

    Either ``lifecycle.start`` or ``dashboard.run`` raises ``RuntimeError``,
    chosen by ``failing_component``. ``start`` must propagate the error,
    await ``lifecycle.stop`` exactly once, and record the stubbed steps in
    ``expected_order``.
    """
    loader = InitialLoader(MagicMock(), MagicMock())
    events: list[str] = []
    crash = RuntimeError(f"{failing_component} failed")

    async def fake_initialize_core() -> None:
        events.append("initialize_core")

    async def fake_bootstrap_runtime() -> None:
        events.append("bootstrap_runtime")

    async def fake_core_start() -> None:
        events.append("core_start")
        if failing_component == "core":
            raise crash

    async def fake_dashboard_run() -> None:
        events.append("dashboard_run")
        if failing_component == "dashboard":
            raise crash

    lifecycle = MagicMock()
    lifecycle.dashboard_shutdown_event = asyncio.Event()
    lifecycle.runtime_bootstrap_task = None
    lifecycle.stop = AsyncMock()
    lifecycle.initialize_core = AsyncMock(side_effect=fake_initialize_core)
    lifecycle.bootstrap_runtime = AsyncMock(side_effect=fake_bootstrap_runtime)
    lifecycle.start = AsyncMock(side_effect=fake_core_start)

    dashboard = MagicMock()
    dashboard.run = AsyncMock(side_effect=fake_dashboard_run)

    # ``pytest.raises`` is entered last so it wraps only the ``start`` call.
    with (
        patch(
            "astrbot.core.initial_loader.AstrBotCoreLifecycle",
            return_value=lifecycle,
        ),
        patch(
            "astrbot.core.initial_loader.AstrBotDashboard",
            return_value=dashboard,
        ),
        pytest.raises(RuntimeError, match=f"{failing_component} failed"),
    ):
        await loader.start()

    lifecycle.stop.assert_awaited_once()
    assert events == expected_order