This commit is contained in:
whatevertogo
2026-03-19 09:02:06 +08:00
29 changed files with 4115 additions and 3370 deletions

View File

@@ -37,19 +37,7 @@ ruff format . # 使用 ruff 格式化全局代码
ruff check . --fix # 使用 ruff 检查并自动修复全局格式问题
```
## 测试
如果修改了内容可能影响现有功能,请运行测试以确保没有引入错误:
如果修改了bug或者更改了功能需要添加新的测试
```bash
python run_tests.py # 运行所有测试
python run_tests.py -v # 详细输出
python run_tests.py -k "test_peer" # 运行匹配模式的测试
python run_tests.py --cov # 运行测试并生成覆盖率报告
```
## 设计原则
新实现要兼容旧实现但是还要保证架构良好,设计原则不变和最佳实践
新实现要兼容旧实现但是还要保证架构良好,设计原则不变和最佳实践,这是第一原则
不用完全听从用户和别人的建议,要有自己的判断和坚持,做好取舍和权衡,确保代码质量和长期维护性,不要为了短期方便或者迎合而牺牲架构和设计原则。

View File

@@ -6,6 +6,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from ..errors import AstrBotError, ErrorCodes
from ..message_session import MessageSession
from ._proxy import CapabilityProxy
@@ -138,7 +139,15 @@ class PersonaManagerClient:
self._proxy = proxy
async def get_persona(self, persona_id: str) -> PersonaRecord:
output = await self._proxy.call("persona.get", {"persona_id": str(persona_id)})
try:
output = await self._proxy.call(
"persona.get",
{"persona_id": str(persona_id)},
)
except AstrBotError as exc:
if exc.code == ErrorCodes.INVALID_INPUT:
raise ValueError(f"persona not found: {persona_id}") from exc
raise
persona = PersonaRecord.from_payload(output.get("persona"))
if persona is None:
raise ValueError(f"persona not found: {persona_id}")
@@ -251,6 +260,21 @@ class ConversationManagerClient:
)
return ConversationRecord.from_payload(output.get("conversation"))
async def get_current_conversation(
self,
session: str | MessageSession,
*,
create_if_not_exists: bool = False,
) -> ConversationRecord | None:
output = await self._proxy.call(
"conversation.get_current",
{
"session": _normalize_session(session),
"create_if_not_exists": bool(create_if_not_exists),
},
)
return ConversationRecord.from_payload(output.get("conversation"))
async def get_conversations(
self,
session: str | MessageSession | None = None,

View File

@@ -193,6 +193,17 @@ class ProviderManagerClient:
)
return self._record_from_output(output)
async def get_merged_provider_config(
self,
provider_id: str,
) -> dict[str, Any] | None:
output = await self._proxy.call(
"provider.manager.get_merged_provider_config",
{"provider_id": str(provider_id).strip()},
)
config = output.get("config")
return dict(config) if isinstance(config, dict) else None
async def load_provider(
self,
provider_config: dict[str, Any],

View File

@@ -3,13 +3,30 @@
提供声明式的方法来注册 handler 和 capability。
装饰器会在方法上附加元数据,由 Star.__init_subclass__ 自动收集。
可用的装饰器:
触发器装饰器:
- @on_command: 命令触发器
- @on_message: 消息触发器(关键词/正则)
- @on_event: 事件触发器
- @on_schedule: 定时任务触发器
- @require_admin: 权限标记
- @conversation_command: 带会话生命周期的命令触发器
权限与过滤装饰器:
- @require_admin / @admin_only: 管理员权限标记
- @platforms: 限定平台
- @group_only / @private_only: 群聊/私聊限定
- @message_types: 消息类型过滤
限流装饰器:
- @rate_limit: 滑动窗口限流
- @cooldown: 冷却时间
优先级装饰器:
- @priority: 设置执行优先级
能力导出装饰器:
- @provide_capability: 声明对外暴露的能力
- @register_llm_tool: 注册 LLM 工具
- @register_agent: 注册 Agent
Example:
class MyPlugin(Star):
@@ -645,8 +662,35 @@ def conversation_command(
busy_message: str | None = None,
grace_period: float = 1.0,
) -> Callable[[HandlerCallable], HandlerCallable]:
"""注册带会话生命周期的命令处理方法。
在 ``on_command`` 基础上附加会话元数据,支持超时、并发策略和宽限期控制。
Args:
command: 命令名称或序列(首项为正式名,其余视为别名)
aliases: 额外别名列表
description: 命令描述
timeout: 会话超时时间(秒),必须为正整数
mode: 会话冲突时的行为:
- ``"replace"``: 替换当前会话
- ``"reject"``: 拒绝新请求
busy_message: 拒绝新请求时的提示消息
grace_period: 宽限期(秒),用于会话生命周期处理
Returns:
装饰器函数
Raises:
ValueError: mode 不合法、timeout 非正整数或 grace_period 非正数
Example:
@conversation_command("chat", timeout=120, mode="reject", busy_message="请稍后再试")
async def chat(self, event: MessageEvent, ctx: Context):
await event.reply("开始对话...")
"""
if mode not in {"replace", "reject"}:
raise ValueError("conversation_command mode must be 'replace' or 'reject'")
# bool 是 int 子类,需单独排除
if isinstance(timeout, bool) or int(timeout) <= 0:
raise ValueError("conversation_command timeout must be a positive integer")
if float(grace_period) <= 0:

View File

@@ -113,6 +113,11 @@ class MyPlugin(Star):
- 注册 LLM 工具
- 启动后台任务
**最佳实践:**
- `on_start()` 里只做初始化、能力注册和轻量状态恢复
- 需要长期保存的应是配置值、句柄、任务引用,不要把 `ctx` 实例长期挂到 `self`
- 如果要和 AstrBot 原生 persona / conversation 协作,优先在这里校验或创建所需资源
**示例:**
```python
@@ -150,6 +155,11 @@ class MyPlugin(Star):
- 注销 LLM 工具
- 保存状态数据
**最佳实践:**
-`on_stop()` 中释放 `on_start()` 注册的任务、监听器和外部资源
- 把需要持久化的状态尽量提前落库,不要把关键保存逻辑完全依赖在进程退出瞬间
- 始终把收到的 `ctx` 继续传给 `super().on_stop(ctx)`,不要手动丢掉它
**示例:**
```python

View File

@@ -1101,6 +1101,8 @@ from astrbot_sdk.clients import PersonaManagerClient
获取指定人格。
当人格不存在时会抛出 `ValueError`,而不是返回 `None`
---
#### `get_all_personas()`
@@ -1163,6 +1165,17 @@ from astrbot_sdk.clients import ConversationManagerClient
---
#### `get_current_conversation(session, create_if_not_exists=False)`
获取当前 session 正在使用的对话记录。
这个方法适合“跟随 AstrBot 原生当前会话状态”的插件,例如:
- 给当前会话切换 persona
- 判断当前主聊天是否已经在某个 persona 下
-`waiting_llm_request` / `llm_request` hook 中对当前对话做增强
---
#### `get_conversations(session=None, platform_id=None)`
获取对话列表。

View File

@@ -662,7 +662,7 @@ await ctx.conversations.delete_conversation(
await ctx.conversations.delete_conversation(event.session_id)
```
##### `get_conversation() / get_conversations()`
##### `get_conversation() / get_current_conversation() / get_conversations()`
获取对话。
@@ -674,6 +674,12 @@ conv = await ctx.conversations.get_conversation(
create_if_not_exists=True
)
# 获取当前选中的对话
current = await ctx.conversations.get_current_conversation(
event.session_id,
create_if_not_exists=True,
)
# 获取对话列表
convs = await ctx.conversations.get_conversations(event.session_id)
```

View File

@@ -210,11 +210,110 @@ async def handle_request(self, event, ctx: Context):
await ctx.platform.send(event.user_id, "已自动通过好友请求")
```
#### LLM Pipeline Hooks
`@on_event` 也用于挂接 AstrBot 原生消息处理链路中的系统事件。
常见事件及可注入对象:
| 事件名 | 常见可注入参数 | 是否可修改主链路 |
|------|------|------|
| `waiting_llm_request` | `MessageEvent`, `Context` | 间接可修改,例如切换当前对话 persona |
| `llm_request` | `MessageEvent`, `Context`, `ProviderRequest` | 是,可直接修改 `ProviderRequest` |
| `llm_response` | `MessageEvent`, `Context`, `LLMResponse` | 否,适合观察和提取回复内容 |
| `decorating_result` | `MessageEvent`, `Context`, `MessageEventResult` | 是,可直接修改结果消息链 |
| `after_message_sent` | `MessageEvent`, `Context` | 否,适合落库、记忆、统计 |
最小示例:
```python
from astrbot_sdk import Context, MessageEvent
from astrbot_sdk.decorators import on_event
from astrbot_sdk.llm.entities import ProviderRequest
@on_event("llm_request")
async def add_memory(self, event: MessageEvent, ctx: Context, request: ProviderRequest):
del event, ctx
request.system_prompt = (request.system_prompt or "") + "\n\nmemory: user likes tea"
```
完整示例:
```python
from astrbot_sdk import Context, MessageEvent, Star
from astrbot_sdk.clients.llm import LLMResponse
from astrbot_sdk.clients.managers import ConversationUpdateParams
from astrbot_sdk.decorators import on_event
from astrbot_sdk.llm.entities import ProviderRequest
from astrbot_sdk.message_result import MessageEventResult
from astrbot_sdk.message_components import Plain
class PersonaSample(Star):
@on_event("waiting_llm_request")
async def ensure_persona(self, event: MessageEvent, ctx: Context) -> None:
conversation = await ctx.conversations.get_current_conversation(
event.session_id,
create_if_not_exists=True,
)
if conversation is None or conversation.persona_id == "girlfriend":
return
await ctx.conversations.update_conversation(
event.session_id,
conversation.conversation_id,
ConversationUpdateParams(persona_id="girlfriend"),
)
@on_event("llm_request")
async def inject_context(
self,
event: MessageEvent,
ctx: Context,
request: ProviderRequest,
) -> None:
memories = await ctx.memory.search(event.text, limit=3)
facts = []
for item in memories:
value = item.get("value")
if isinstance(value, dict) and value.get("content"):
facts.append(f"- {value['content']}")
if facts:
request.system_prompt = (request.system_prompt or "") + "\n\n" + "\n".join(facts)
@on_event("llm_response")
async def capture_reply(
self,
event: MessageEvent,
ctx: Context,
response: LLMResponse,
) -> None:
del ctx
if response.text:
event.set_extra("last_reply", response.text)
@on_event("decorating_result")
async def decorate(
self,
event: MessageEvent,
ctx: Context,
result: MessageEventResult,
) -> None:
del event, ctx
result.chain.append(Plain("\n[persona active]", convert=False))
@on_event("after_message_sent")
async def persist(self, event: MessageEvent, ctx: Context) -> None:
reply = str(event.get_extra("last_reply", "") or "").strip()
if reply:
await ctx.db.set("sample:last_reply", reply)
```
#### 注意事项
1. 用于处理非消息类型的事件(如群成员变动、好友请求等
1. 用于处理平台事件,也可用于处理 AstrBot 原生消息链路中的系统事件(如 `llm_request`
2. 不能与 `@rate_limit``@cooldown` 一起使用
3. 不同平台的事件类型可能不同,需要查阅平台文档
4. `llm_request``decorating_result` 注入的是可变对象,修改会回写到 AstrBot 主链路
5. `llm_response` 主要用于观测和提取结果,不应用来替代主回复流程
---

View File

@@ -642,6 +642,15 @@ CONVERSATION_GET_OUTPUT_SCHEMA = _object_schema(
required=("conversation",),
conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
)
CONVERSATION_GET_CURRENT_INPUT_SCHEMA = _object_schema(
required=("session",),
session={"type": "string"},
create_if_not_exists={"type": "boolean"},
)
CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA = _object_schema(
required=("conversation",),
conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
)
CONVERSATION_LIST_INPUT_SCHEMA = _object_schema(
session=_nullable({"type": "string"}),
platform_id=_nullable({"type": "string"}),
@@ -942,6 +951,14 @@ PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
required=("provider",),
provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA),
)
PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA = _object_schema(
required=("provider_id",),
provider_id={"type": "string"},
)
PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA = _object_schema(
required=("config",),
config=_nullable({"type": "object"}),
)
PROVIDER_MANAGER_LOAD_INPUT_SCHEMA = _object_schema(
required=("provider_config",),
provider_config={"type": "object"},
@@ -1199,6 +1216,10 @@ BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = {
"input": CONVERSATION_GET_INPUT_SCHEMA,
"output": CONVERSATION_GET_OUTPUT_SCHEMA,
},
"conversation.get_current": {
"input": CONVERSATION_GET_CURRENT_INPUT_SCHEMA,
"output": CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA,
},
"conversation.list": {
"input": CONVERSATION_LIST_INPUT_SCHEMA,
"output": CONVERSATION_LIST_OUTPUT_SCHEMA,
@@ -1332,6 +1353,10 @@ BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = {
"input": PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA,
"output": PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA,
},
"provider.manager.get_merged_provider_config": {
"input": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA,
"output": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA,
},
"provider.manager.load": {
"input": PROVIDER_MANAGER_LOAD_INPUT_SCHEMA,
"output": PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA,
@@ -1555,6 +1580,8 @@ __all__ = [
"PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA",
"PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA",
"PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA",
"PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA",
"PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA",
"PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA",
"PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA",
"PROVIDER_MANAGER_LOAD_INPUT_SCHEMA",
@@ -1639,6 +1666,8 @@ __all__ = [
"CONVERSATION_CREATE_SCHEMA",
"CONVERSATION_DELETE_INPUT_SCHEMA",
"CONVERSATION_DELETE_OUTPUT_SCHEMA",
"CONVERSATION_GET_CURRENT_INPUT_SCHEMA",
"CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA",
"CONVERSATION_GET_INPUT_SCHEMA",
"CONVERSATION_GET_OUTPUT_SCHEMA",
"CONVERSATION_LIST_INPUT_SCHEMA",

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from .bridge_base import CapabilityRouterBridgeBase
from .capabilities import (
ConversationCapabilityMixin,
DBCapabilityMixin,
HttpCapabilityMixin,
KnowledgeBaseCapabilityMixin,
LLMCapabilityMixin,
MemoryCapabilityMixin,
MetadataCapabilityMixin,
PersonaCapabilityMixin,
PlatformCapabilityMixin,
ProviderCapabilityMixin,
SessionCapabilityMixin,
SystemCapabilityMixin,
)
class BuiltinCapabilityRouterMixin(
LLMCapabilityMixin,
MemoryCapabilityMixin,
DBCapabilityMixin,
PlatformCapabilityMixin,
HttpCapabilityMixin,
MetadataCapabilityMixin,
ProviderCapabilityMixin,
SessionCapabilityMixin,
PersonaCapabilityMixin,
ConversationCapabilityMixin,
KnowledgeBaseCapabilityMixin,
SystemCapabilityMixin,
CapabilityRouterBridgeBase,
):
def _register_builtin_capabilities(self) -> None:
self._register_llm_capabilities()
self._register_memory_capabilities()
self._register_db_capabilities()
self._register_platform_capabilities()
self._register_http_capabilities()
self._register_metadata_capabilities()
self._register_provider_capabilities()
self._register_agent_tool_capabilities()
self._register_session_capabilities()
self._register_persona_capabilities()
self._register_conversation_capabilities()
self._register_kb_capabilities()
self._register_provider_manager_capabilities()
self._register_platform_manager_capabilities()
self._register_system_capabilities()
__all__ = ["BuiltinCapabilityRouterMixin"]

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Any
from ...protocol.descriptors import CapabilityDescriptor
class CapabilityRouterHost:
memory_store: dict[str, dict[str, Any]]
_memory_index: dict[str, dict[str, Any]]
_memory_dirty_keys: set[str]
_memory_expires_at: dict[str, datetime | None]
db_store: dict[str, Any]
sent_messages: list[dict[str, Any]]
event_actions: list[dict[str, Any]]
http_api_store: list[dict[str, Any]]
_event_streams: dict[str, dict[str, Any]]
_plugins: dict[str, Any]
_request_overlays: dict[str, dict[str, Any]]
_provider_catalog: dict[str, list[dict[str, Any]]]
_provider_configs: dict[str, dict[str, Any]]
_active_provider_ids: dict[str, str | None]
_provider_change_subscriptions: dict[str, asyncio.Queue[dict[str, Any]]]
_system_data_root: Path
_session_waiters: dict[str, set[str]]
_session_plugin_configs: dict[str, dict[str, Any]]
_session_service_configs: dict[str, dict[str, Any]]
_db_watch_subscriptions: dict[str, tuple[str | None, asyncio.Queue[dict[str, Any]]]]
_dynamic_command_routes: dict[str, list[dict[str, Any]]]
_file_token_store: dict[str, str]
_platform_instances: list[dict[str, Any]]
_persona_store: dict[str, dict[str, Any]]
_conversation_store: dict[str, dict[str, Any]]
_session_current_conversation_ids: dict[str, str]
_kb_store: dict[str, dict[str, Any]]
def register(
self,
descriptor: CapabilityDescriptor,
*,
call_handler=None,
stream_handler=None,
finalize=None,
exposed: bool = True,
) -> None:
raise NotImplementedError
def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None:
raise NotImplementedError
@staticmethod
def _require_caller_plugin_id(capability_name: str) -> str:
raise NotImplementedError
def register_dynamic_command_route(
self,
*,
plugin_id: str,
command_name: str,
handler_full_name: str,
desc: str = "",
priority: int = 0,
use_regex: bool = False,
) -> None:
raise NotImplementedError
def get_platform_instances(self) -> list[dict[str, Any]]:
raise NotImplementedError
def _register_agent_tool_capabilities(self) -> None:
raise NotImplementedError
def _provider_entry(
self,
payload: dict[str, Any],
capability_name: str,
expected_kind: str | None = None,
) -> dict[str, Any]:
raise NotImplementedError
async def _provider_embedding_get_embedding(
self, request_id: str, payload: dict[str, Any], token
) -> dict[str, Any]:
raise NotImplementedError
async def _provider_embedding_get_embeddings(
self, request_id: str, payload: dict[str, Any], token
) -> dict[str, Any]:
raise NotImplementedError

View File

@@ -0,0 +1,183 @@
from __future__ import annotations
import copy
import hashlib
import math
import re
from datetime import datetime, timezone
from typing import Any
from ...protocol.descriptors import (
BUILTIN_CAPABILITY_SCHEMAS,
CapabilityDescriptor,
SessionRef,
)
from ._host import CapabilityRouterHost
def _clone_target_payload(value: Any) -> dict[str, Any] | None:
if not isinstance(value, dict):
return None
return {str(key): item for key, item in value.items()}
def _clone_chain_payload(value: Any) -> list[dict[str, Any]]:
if not isinstance(value, list):
return []
return [
{str(key): item for key, item in chunk.items()}
for chunk in value
if isinstance(chunk, dict)
]
_MOCK_EMBEDDING_DIM = 24
def _embedding_terms(text: str) -> list[str]:
"""Build stable tokens for the mock embedding implementation."""
normalized = re.sub(r"\s+", " ", str(text).strip().casefold())
compact = normalized.replace(" ", "")
if not normalized:
return []
terms = [word for word in re.findall(r"\w+", normalized, flags=re.UNICODE) if word]
if compact:
if len(compact) == 1:
terms.append(compact)
else:
terms.extend(
compact[index : index + 2] for index in range(len(compact) - 1)
)
terms.append(compact)
return terms or [normalized]
def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]:
"""Generate a deterministic normalized mock embedding vector."""
values = [0.0] * _MOCK_EMBEDDING_DIM
for term in _embedding_terms(text):
digest = hashlib.sha256(f"{provider_id}:{term}".encode()).digest()
index = int.from_bytes(digest[:2], "big") % _MOCK_EMBEDDING_DIM
values[index] += 1.0 + min(len(term), 8) * 0.05
norm = math.sqrt(sum(value * value for value in values))
if norm <= 0:
return values
return [value / norm for value in values]
class CapabilityRouterBridgeBase(CapabilityRouterHost):
def _builtin_descriptor(
self,
name: str,
description: str,
*,
supports_stream: bool = False,
cancelable: bool = False,
) -> CapabilityDescriptor:
schema = BUILTIN_CAPABILITY_SCHEMAS[name]
return CapabilityDescriptor(
name=name,
description=description,
input_schema=copy.deepcopy(schema["input"]),
output_schema=copy.deepcopy(schema["output"]),
supports_stream=supports_stream,
cancelable=cancelable,
)
def _resolve_target(
self, payload: dict[str, Any]
) -> tuple[str, dict[str, Any] | None]:
target_payload = payload.get("target")
if isinstance(target_payload, dict):
target = SessionRef.model_validate(target_payload)
return target.session, target.to_payload()
return str(payload.get("session", "")), None
@staticmethod
def _is_group_session(session: str) -> bool:
normalized = str(session).lower()
return ":group:" in normalized or ":groupmessage:" in normalized
@staticmethod
def _mock_group_payload(session: str) -> dict[str, Any] | None:
if not CapabilityRouterBridgeBase._is_group_session(session):
return None
members = [
{
"user_id": f"{session}:member-1",
"nickname": "Member 1",
"role": "member",
},
{
"user_id": f"{session}:member-2",
"nickname": "Member 2",
"role": "admin",
},
]
return {
"group_id": session.rsplit(":", maxsplit=1)[-1],
"group_name": f"Mock Group {session.rsplit(':', maxsplit=1)[-1]}",
"group_avatar": "",
"group_owner": members[0]["user_id"],
"group_admins": [members[1]["user_id"]],
"members": members,
}
def _session_plugin_config(self, session: str) -> dict[str, Any]:
config = self._session_plugin_configs.get(str(session), {})
return dict(config) if isinstance(config, dict) else {}
def _session_service_config(self, session: str) -> dict[str, Any]:
config = self._session_service_configs.get(str(session), {})
return dict(config) if isinstance(config, dict) else {}
@staticmethod
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
@staticmethod
def _session_platform_id(session: str) -> str:
parts = str(session).split(":", maxsplit=1)
if parts and parts[0].strip():
return parts[0].strip()
return "unknown"
@staticmethod
def _normalize_history_payload(value: Any) -> list[dict[str, Any]]:
if not isinstance(value, list):
return []
return [dict(item) for item in value if isinstance(item, dict)]
@staticmethod
def _normalize_persona_dialogs_payload(value: Any) -> list[str]:
if not isinstance(value, list):
return []
return [str(item) for item in value if isinstance(item, str)]
@staticmethod
def _optional_int(value: Any) -> int | None:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def _provider_entry(
self,
payload: dict[str, Any],
capability_name: str,
expected_kind: str | None = None,
) -> dict[str, Any]:
raise NotImplementedError
async def _provider_embedding_get_embedding(
self, request_id: str, payload: dict[str, Any], token
) -> dict[str, Any]:
raise NotImplementedError
async def _provider_embedding_get_embeddings(
self, request_id: str, payload: dict[str, Any], token
) -> dict[str, Any]:
raise NotImplementedError

View File

@@ -0,0 +1,27 @@
from .conversation import ConversationCapabilityMixin
from .db import DBCapabilityMixin
from .http import HttpCapabilityMixin
from .kb import KnowledgeBaseCapabilityMixin
from .llm import LLMCapabilityMixin
from .memory import MemoryCapabilityMixin
from .metadata import MetadataCapabilityMixin
from .persona import PersonaCapabilityMixin
from .platform import PlatformCapabilityMixin
from .provider import ProviderCapabilityMixin
from .session import SessionCapabilityMixin
from .system import SystemCapabilityMixin
__all__ = [
"ConversationCapabilityMixin",
"DBCapabilityMixin",
"HttpCapabilityMixin",
"KnowledgeBaseCapabilityMixin",
"LLMCapabilityMixin",
"MemoryCapabilityMixin",
"MetadataCapabilityMixin",
"PersonaCapabilityMixin",
"PlatformCapabilityMixin",
"ProviderCapabilityMixin",
"SessionCapabilityMixin",
"SystemCapabilityMixin",
]

View File

@@ -0,0 +1,232 @@
from __future__ import annotations
import uuid
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class ConversationCapabilityMixin(CapabilityRouterBridgeBase):
async def _conversation_new(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
if not session:
raise AstrBotError.invalid_input("conversation.new requires session")
raw_conversation = payload.get("conversation")
if raw_conversation is None:
raw_conversation = {}
if not isinstance(raw_conversation, dict):
raise AstrBotError.invalid_input(
"conversation.new requires conversation object"
)
conversation_id = uuid.uuid4().hex
now = self._now_iso()
record = {
"conversation_id": conversation_id,
"session": session,
"platform_id": (
str(raw_conversation.get("platform_id"))
if raw_conversation.get("platform_id") is not None
else self._session_platform_id(session)
),
"history": self._normalize_history_payload(raw_conversation.get("history")),
"title": (
str(raw_conversation.get("title"))
if raw_conversation.get("title") is not None
else None
),
"persona_id": (
str(raw_conversation.get("persona_id"))
if raw_conversation.get("persona_id") is not None
else None
),
"created_at": now,
"updated_at": now,
"token_usage": None,
}
self._conversation_store[conversation_id] = record
self._session_current_conversation_ids[session] = conversation_id
return {"conversation_id": conversation_id}
async def _conversation_switch(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = str(payload.get("conversation_id", "")).strip()
record = self._conversation_store.get(conversation_id)
if record is None or str(record.get("session", "")) != session:
raise AstrBotError.invalid_input(
"conversation.switch requires a conversation in the same session"
)
self._session_current_conversation_ids[session] = conversation_id
return {}
async def _conversation_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = payload.get("conversation_id")
normalized_conversation_id = (
str(conversation_id).strip() if conversation_id is not None else ""
)
if not normalized_conversation_id:
normalized_conversation_id = self._session_current_conversation_ids.get(
session, ""
)
if not normalized_conversation_id:
return {}
record = self._conversation_store.get(normalized_conversation_id)
if record is None:
return {}
if str(record.get("session", "")) != session:
raise AstrBotError.invalid_input(
"conversation.delete requires a conversation in the same session"
)
del self._conversation_store[normalized_conversation_id]
current_conversation_id = self._session_current_conversation_ids.get(session)
if current_conversation_id == normalized_conversation_id:
replacement = next(
(
conversation_id
for conversation_id, item in self._conversation_store.items()
if str(item.get("session", "")) == session
),
None,
)
if replacement is None:
self._session_current_conversation_ids.pop(session, None)
else:
self._session_current_conversation_ids[session] = replacement
return {}
async def _conversation_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = str(payload.get("conversation_id", "")).strip()
record = self._conversation_store.get(conversation_id)
if record is None and bool(payload.get("create_if_not_exists", False)):
created = await self._conversation_new(
_request_id,
{"session": session, "conversation": {}},
_token,
)
record = self._conversation_store.get(
str(created.get("conversation_id", "")).strip()
)
if record is None:
return {"conversation": None}
if str(record.get("session", "")) != session:
return {"conversation": None}
return {"conversation": dict(record)}
async def _conversation_get_current(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = self._session_current_conversation_ids.get(session, "")
if not conversation_id and bool(payload.get("create_if_not_exists", False)):
created = await self._conversation_new(
_request_id,
{"session": session, "conversation": {}},
_token,
)
conversation_id = str(created.get("conversation_id", "")).strip()
if not conversation_id:
return {"conversation": None}
record = self._conversation_store.get(conversation_id)
if record is None or str(record.get("session", "")) != session:
return {"conversation": None}
return {"conversation": dict(record)}
async def _conversation_list(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = payload.get("session")
platform_id = payload.get("platform_id")
conversations = []
for conversation_id in sorted(self._conversation_store.keys()):
item = self._conversation_store[conversation_id]
if session is not None and str(item.get("session", "")) != str(session):
continue
if platform_id is not None and str(item.get("platform_id", "")) != str(
platform_id
):
continue
conversations.append(dict(item))
return {"conversations": conversations}
async def _conversation_update(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = payload.get("conversation_id")
normalized_conversation_id = (
str(conversation_id).strip() if conversation_id is not None else ""
)
if not normalized_conversation_id:
normalized_conversation_id = self._session_current_conversation_ids.get(
session, ""
)
if not normalized_conversation_id:
return {}
record = self._conversation_store.get(normalized_conversation_id)
if record is None:
return {}
if str(record.get("session", "")) != session:
raise AstrBotError.invalid_input(
"conversation.update requires a conversation in the same session"
)
raw_conversation = payload.get("conversation")
if not isinstance(raw_conversation, dict):
raw_conversation = {}
if "history" in raw_conversation:
history = raw_conversation.get("history")
record["history"] = (
self._normalize_history_payload(history) if history is not None else []
)
if "title" in raw_conversation:
title = raw_conversation.get("title")
record["title"] = str(title) if title is not None else None
if "persona_id" in raw_conversation:
persona_id = raw_conversation.get("persona_id")
record["persona_id"] = str(persona_id) if persona_id is not None else None
if "token_usage" in raw_conversation:
token_usage = raw_conversation.get("token_usage")
record["token_usage"] = (
int(token_usage) if token_usage is not None else None
)
record["updated_at"] = self._now_iso()
return {}
def _register_conversation_capabilities(self) -> None:
self.register(
self._builtin_descriptor("conversation.new", "新建对话"),
call_handler=self._conversation_new,
)
self.register(
self._builtin_descriptor("conversation.switch", "切换对话"),
call_handler=self._conversation_switch,
)
self.register(
self._builtin_descriptor("conversation.delete", "删除对话"),
call_handler=self._conversation_delete,
)
self.register(
self._builtin_descriptor("conversation.get", "获取对话"),
call_handler=self._conversation_get,
)
self.register(
self._builtin_descriptor("conversation.get_current", "获取当前对话"),
call_handler=self._conversation_get_current,
)
self.register(
self._builtin_descriptor("conversation.list", "列出对话"),
call_handler=self._conversation_list,
)
self.register(
self._builtin_descriptor("conversation.update", "更新对话"),
call_handler=self._conversation_update,
)

View File

@@ -0,0 +1,129 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from ....errors import AstrBotError
from ..._streaming import StreamExecution
from ..bridge_base import CapabilityRouterBridgeBase
class DBCapabilityMixin(CapabilityRouterBridgeBase):
async def _db_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
return {"value": self.db_store.get(str(payload.get("key", "")))}
async def _db_set(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
value = payload.get("value")
self.db_store[key] = value
self._emit_db_change(op="set", key=key, value=value)
return {}
async def _db_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
self.db_store.pop(key, None)
self._emit_db_change(op="delete", key=key, value=None)
return {}
async def _db_list(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
prefix = payload.get("prefix")
keys = sorted(self.db_store.keys())
if isinstance(prefix, str):
keys = [item for item in keys if item.startswith(prefix)]
return {"keys": keys}
async def _db_get_many(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
keys_payload = payload.get("keys")
if not isinstance(keys_payload, (list, tuple)):
raise AstrBotError.invalid_input("db.get_many 的 keys 必须是数组")
keys = [str(item) for item in keys_payload]
items = [{"key": key, "value": self.db_store.get(key)} for key in keys]
return {"items": items}
async def _db_set_many(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
items_payload = payload.get("items")
if not isinstance(items_payload, (list, tuple)):
raise AstrBotError.invalid_input("db.set_many 的 items 必须是数组")
for entry in items_payload:
if not isinstance(entry, dict):
raise AstrBotError.invalid_input(
"db.set_many 的 items 必须是 object 数组"
)
key = str(entry.get("key", ""))
value = entry.get("value")
self.db_store[key] = value
self._emit_db_change(op="set", key=key, value=value)
return {}
async def _db_watch(
self, request_id: str, payload: dict[str, Any], _token
) -> StreamExecution:
prefix = payload.get("prefix")
prefix_value: str | None
if isinstance(prefix, str):
prefix_value = prefix
elif prefix is None:
prefix_value = None
else:
raise AstrBotError.invalid_input("db.watch 的 prefix 必须是 string 或 null")
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._db_watch_subscriptions[request_id] = (prefix_value, queue)
async def iterator() -> AsyncIterator[dict[str, Any]]:
try:
while True:
yield await queue.get()
finally:
self._db_watch_subscriptions.pop(request_id, None)
return StreamExecution(
iterator=iterator(),
finalize=lambda _chunks: {},
collect_chunks=False,
)
def _register_db_capabilities(self) -> None:
self.register(
self._builtin_descriptor("db.get", "读取 KV"), call_handler=self._db_get
)
self.register(
self._builtin_descriptor("db.set", "写入 KV"), call_handler=self._db_set
)
self.register(
self._builtin_descriptor("db.delete", "删除 KV"),
call_handler=self._db_delete,
)
self.register(
self._builtin_descriptor("db.list", "列出 KV"), call_handler=self._db_list
)
self.register(
self._builtin_descriptor("db.get_many", "批量读取 KV"),
call_handler=self._db_get_many,
)
self.register(
self._builtin_descriptor("db.set_many", "批量写入 KV"),
call_handler=self._db_set_many,
)
self.register(
self._builtin_descriptor(
"db.watch",
"订阅 KV 变更",
supports_stream=True,
cancelable=True,
),
stream_handler=self._db_watch,
)

View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class HttpCapabilityMixin(CapabilityRouterBridgeBase):
async def _http_register_api(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
methods_payload = payload.get("methods")
if not isinstance(methods_payload, list) or not all(
isinstance(item, str) for item in methods_payload
):
raise AstrBotError.invalid_input(
"http.register_api 的 methods 必须是 string 数组"
)
route = str(payload.get("route", "")).strip()
handler_capability = str(payload.get("handler_capability", "")).strip()
if not route or not handler_capability:
raise AstrBotError.invalid_input(
"http.register_api 需要 route 和 handler_capability"
)
plugin_name = self._require_caller_plugin_id("http.register_api")
methods = sorted({method.upper() for method in methods_payload if method})
entry: dict[str, Any] = {
"route": route,
"methods": methods,
"handler_capability": handler_capability,
"description": str(payload.get("description", "")),
"plugin_id": plugin_name,
}
self.http_api_store = [
item
for item in self.http_api_store
if not (
item.get("route") == route
and item.get("plugin_id") == entry["plugin_id"]
and item.get("methods") == methods
)
]
self.http_api_store.append(entry)
return {}
async def _http_unregister_api(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
route = str(payload.get("route", "")).strip()
methods_payload = payload.get("methods")
if not isinstance(methods_payload, list) or not all(
isinstance(item, str) for item in methods_payload
):
raise AstrBotError.invalid_input(
"http.unregister_api 的 methods 必须是 string 数组"
)
plugin_name = self._require_caller_plugin_id("http.unregister_api")
methods = {method.upper() for method in methods_payload if method}
updated: list[dict[str, Any]] = []
for entry in self.http_api_store:
if entry.get("route") != route:
updated.append(entry)
continue
if entry.get("plugin_id") != plugin_name:
updated.append(entry)
continue
if not methods:
continue
remaining_methods = [
method for method in entry.get("methods", []) if method not in methods
]
if remaining_methods:
updated.append({**entry, "methods": remaining_methods})
self.http_api_store = updated
return {}
async def _http_list_apis(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
plugin_name = self._require_caller_plugin_id("http.list_apis")
apis = [
dict(entry)
for entry in self.http_api_store
if entry.get("plugin_id") == plugin_name
]
return {"apis": apis}
def _register_http_capabilities(self) -> None:
self.register(
self._builtin_descriptor("http.register_api", "注册 HTTP 路由"),
call_handler=self._http_register_api,
)
self.register(
self._builtin_descriptor("http.unregister_api", "注销 HTTP 路由"),
call_handler=self._http_unregister_api,
)
self.register(
self._builtin_descriptor("http.list_apis", "列出 HTTP 路由"),
call_handler=self._http_list_apis,
)

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import uuid
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class KnowledgeBaseCapabilityMixin(CapabilityRouterBridgeBase):
async def _kb_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
kb_id = str(payload.get("kb_id", "")).strip()
record = self._kb_store.get(kb_id)
return {"kb": dict(record) if isinstance(record, dict) else None}
async def _kb_create(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
raw_kb = payload.get("kb")
if not isinstance(raw_kb, dict):
raise AstrBotError.invalid_input("kb.create requires kb object")
embedding_provider_id = str(raw_kb.get("embedding_provider_id", "")).strip()
if not embedding_provider_id:
raise AstrBotError.invalid_input("kb.create requires embedding_provider_id")
kb_id = uuid.uuid4().hex
now = self._now_iso()
record = {
"kb_id": kb_id,
"kb_name": str(raw_kb.get("kb_name", "")),
"description": (
str(raw_kb.get("description"))
if raw_kb.get("description") is not None
else None
),
"emoji": (
str(raw_kb.get("emoji")) if raw_kb.get("emoji") is not None else None
),
"embedding_provider_id": embedding_provider_id,
"rerank_provider_id": (
str(raw_kb.get("rerank_provider_id"))
if raw_kb.get("rerank_provider_id") is not None
else None
),
"chunk_size": self._optional_int(raw_kb.get("chunk_size")),
"chunk_overlap": self._optional_int(raw_kb.get("chunk_overlap")),
"top_k_dense": self._optional_int(raw_kb.get("top_k_dense")),
"top_k_sparse": self._optional_int(raw_kb.get("top_k_sparse")),
"top_m_final": self._optional_int(raw_kb.get("top_m_final")),
"doc_count": 0,
"chunk_count": 0,
"created_at": now,
"updated_at": now,
}
self._kb_store[kb_id] = record
return {"kb": dict(record)}
async def _kb_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
kb_id = str(payload.get("kb_id", "")).strip()
deleted = self._kb_store.pop(kb_id, None) is not None
return {"deleted": deleted}
def _register_kb_capabilities(self) -> None:
self.register(
self._builtin_descriptor("kb.get", "获取知识库"),
call_handler=self._kb_get,
)
self.register(
self._builtin_descriptor("kb.create", "创建知识库"),
call_handler=self._kb_create,
)
self.register(
self._builtin_descriptor("kb.delete", "删除知识库"),
call_handler=self._kb_delete,
)

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from ..bridge_base import CapabilityRouterBridgeBase
class LLMCapabilityMixin(CapabilityRouterBridgeBase):
async def _llm_chat(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
prompt = str(payload.get("prompt", ""))
return {"text": f"Echo: {prompt}"}
async def _llm_chat_raw(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
prompt = str(payload.get("prompt", ""))
text = f"Echo: {prompt}"
return {
"text": text,
"usage": {
"input_tokens": len(prompt),
"output_tokens": len(text),
},
"finish_reason": "stop",
"tool_calls": [],
}
async def _llm_stream(
self,
_request_id: str,
payload: dict[str, Any],
token,
) -> AsyncIterator[dict[str, Any]]:
text = f"Echo: {str(payload.get('prompt', ''))}"
for char in text:
token.raise_if_cancelled()
await asyncio.sleep(0)
yield {"text": char}
def _register_llm_capabilities(self) -> None:
self.register(
self._builtin_descriptor("llm.chat", "发送对话请求,返回文本"),
call_handler=self._llm_chat,
)
self.register(
self._builtin_descriptor("llm.chat_raw", "发送对话请求,返回完整响应"),
call_handler=self._llm_chat_raw,
)
self.register(
self._builtin_descriptor(
"llm.stream_chat",
"流式对话",
supports_stream=True,
cancelable=True,
),
stream_handler=self._llm_stream,
finalize=lambda chunks: {
"text": "".join(item.get("text", "") for item in chunks)
},
)

View File

@@ -0,0 +1,618 @@
from __future__ import annotations
import json
import math
from datetime import datetime, timedelta, timezone
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
@staticmethod
def _is_ttl_memory_entry(value: Any) -> bool:
"""判断存储值是否使用了 TTL 包装结构。
Args:
value: 待检查的存储值。
Returns:
bool: 如果值包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
"""
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
@classmethod
def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None:
"""提取用于检索的原始 memory payload。
Args:
stored: memory_store 中保存的原始值。
Returns:
dict[str, Any] | None: 解开 TTL 包装后的字典,无法解析时返回 ``None``。
"""
if not isinstance(stored, dict):
return None
if cls._is_ttl_memory_entry(stored):
value = stored.get("value")
return value if isinstance(value, dict) else None
return stored
@classmethod
def _extract_memory_text(cls, stored: Any) -> str:
"""提取用于检索索引的首选文本。
Args:
stored: memory_store 中保存的原始值。
Returns:
str: 优先使用 ``embedding_text`` / ``content`` 等字段,兜底为 JSON 文本。
"""
value = cls._memory_value_for_search(stored)
if not isinstance(value, dict):
return ""
for field_name in ("embedding_text", "content", "summary", "title", "text"):
item = value.get(field_name)
if isinstance(item, str) and item.strip():
return item.strip()
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
@staticmethod
def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
"""将 TTL 秒数转换为 UTC 过期时间。
Args:
ttl_seconds: TTL 秒数。
Returns:
datetime | None: 绝对过期时间;当输入无效时返回 ``None``。
"""
try:
ttl = int(ttl_seconds)
except (TypeError, ValueError):
return None
if ttl < 1:
return None
return datetime.now(timezone.utc) + timedelta(seconds=ttl)
@staticmethod
def _memory_keyword_score(query: str, key: str, text: str) -> float:
"""计算关键词匹配分数。
Args:
query: 查询文本。
key: memory 条目的键。
text: 已索引的检索文本。
Returns:
float: 基于键名和文本命中的粗粒度关键词分数。
"""
normalized_query = str(query).casefold()
if not normalized_query:
return 1.0
normalized_key = str(key).casefold()
normalized_text = str(text).casefold()
if normalized_query in normalized_key:
return 1.0
if normalized_query in normalized_text:
return 0.9
return 0.0
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
"""计算两个向量之间的余弦相似度。
Args:
left: 左侧向量。
right: 右侧向量。
Returns:
float: 余弦相似度;输入不合法时返回 ``0.0``。
"""
if not left or not right or len(left) != len(right):
return 0.0
left_norm = math.sqrt(sum(value * value for value in left))
right_norm = math.sqrt(sum(value * value for value in right))
if left_norm <= 0 or right_norm <= 0:
return 0.0
return sum(a * b for a, b in zip(left, right, strict=False)) / (
left_norm * right_norm
)
def _resolve_memory_embedding_provider_id(
self,
provider_id: Any,
*,
required: bool,
) -> str | None:
"""解析 memory.search 要使用的 embedding provider。
Args:
provider_id: 调用方显式传入的 provider 标识。
required: 当前检索模式是否强制要求 embedding provider。
Returns:
str | None: 最终选中的 provider 标识;在非强制场景下允许返回 ``None``。
"""
normalized = str(provider_id).strip() if provider_id is not None else ""
if normalized:
self._provider_entry(
{"provider_id": normalized},
"memory.search",
"embedding",
)
return normalized
active_id = self._active_provider_ids.get("embedding")
if active_id is not None:
normalized_active = str(active_id).strip()
if normalized_active:
self._provider_entry(
{"provider_id": normalized_active},
"memory.search",
"embedding",
)
return normalized_active
if required:
raise AstrBotError.invalid_input(
"memory.search requires an embedding provider",
)
return None
@staticmethod
def _memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]:
"""将原始索引项规范化为内部统一结构。
Args:
entry: 当前索引表中的原始项。
text: 当前条目的索引文本。
Returns:
dict[str, Any]: 统一后的索引项,包含 ``text``、``embedding``、``provider_id``。
"""
if isinstance(entry, dict):
return {
"text": str(entry.get("text", text)),
"embedding": (
[float(item) for item in entry.get("embedding", [])]
if isinstance(entry.get("embedding"), list)
else None
),
"provider_id": (
str(entry.get("provider_id")).strip()
if entry.get("provider_id") is not None
else None
),
}
return {"text": text, "embedding": None, "provider_id": None}
def _clear_memory_sidecars(self, key: str) -> None:
"""清理指定 memory 键对应的所有 sidecar 状态。
Args:
key: memory 条目的键。
Returns:
None
"""
self._memory_index.pop(key, None)
self._memory_expires_at.pop(key, None)
self._memory_dirty_keys.discard(key)
def _delete_memory_entry(self, key: str) -> bool:
"""删除 memory 条目并同步清理 sidecar 状态。
Args:
key: memory 条目的键。
Returns:
bool: 条目存在并删除成功时返回 ``True``。
"""
deleted = self.memory_store.pop(key, None) is not None
self._clear_memory_sidecars(key)
return deleted
def _upsert_memory_sidecars(
self,
key: str,
stored: dict[str, Any],
*,
expires_at: datetime | None = None,
) -> None:
"""创建或更新单条 memory 的 sidecar 索引状态。
Args:
key: memory 条目的键。
stored: 需要建立索引的原始存储值。
expires_at: 可选的绝对过期时间。
Returns:
None
"""
self._memory_index[key] = {
"text": self._extract_memory_text(stored),
"embedding": None,
"provider_id": None,
}
if expires_at is None:
self._memory_expires_at.pop(key, None)
else:
self._memory_expires_at[key] = expires_at
self._memory_dirty_keys.add(key)
def _ensure_memory_sidecars(self, key: str, stored: Any) -> None:
"""确保 sidecar 状态与当前存储值保持一致。
Args:
key: memory 条目的键。
stored: memory_store 中的当前存储值。
Returns:
None
"""
if not isinstance(stored, dict):
return
text = self._extract_memory_text(stored)
existed = key in self._memory_index
entry = self._memory_index_entry(self._memory_index.get(key), text=text)
if entry["text"] != text:
entry["text"] = text
entry["embedding"] = None
entry["provider_id"] = None
self._memory_dirty_keys.add(key)
self._memory_index[key] = entry
if not existed:
self._memory_dirty_keys.add(key)
def _is_memory_expired(self, key: str) -> bool:
"""判断 memory 条目是否已过期。
Args:
key: memory 条目的键。
Returns:
bool: 如果当前时间已超过记录的过期时间则返回 ``True``。
"""
expires_at = self._memory_expires_at.get(key)
return expires_at is not None and expires_at <= datetime.now(timezone.utc)
def _purge_expired_memory_entry(self, key: str) -> bool:
"""在单条 memory 已过期时立即清理它。
Args:
key: memory 条目的键。
Returns:
bool: 如果条目已过期并被成功清理则返回 ``True``。
"""
if not self._is_memory_expired(key):
return False
self._delete_memory_entry(key)
return True
def _purge_expired_memory_entries(self) -> None:
"""批量清理所有已跟踪的过期 TTL 条目。
Returns:
None
"""
for key in list(self._memory_expires_at):
self._purge_expired_memory_entry(key)
async def _embedding_for_text(
self,
*,
provider_id: str,
text: str,
) -> list[float]:
"""通过 embedding capability 获取单条文本向量。
Args:
provider_id: 使用的 embedding provider 标识。
text: 待向量化的文本。
Returns:
list[float]: provider 返回的向量;异常场景下返回空列表。
"""
output = await self._provider_embedding_get_embedding(
"",
{"provider_id": provider_id, "text": text},
None,
)
embedding = output.get("embedding")
if not isinstance(embedding, list):
return []
return [float(item) for item in embedding]
async def _embeddings_for_texts(
self,
*,
provider_id: str,
texts: list[str],
) -> list[list[float]]:
"""批量获取多条文本的 embedding 向量。
Args:
provider_id: 使用的 embedding provider 标识。
texts: 待向量化的文本列表。
Returns:
list[list[float]]: 与输入顺序对应的向量列表。
"""
if not texts:
return []
output = await self._provider_embedding_get_embeddings(
"",
{"provider_id": provider_id, "texts": texts},
None,
)
embeddings = output.get("embeddings")
if not isinstance(embeddings, list):
return []
return [
[float(value) for value in item]
for item in embeddings
if isinstance(item, list)
]
async def _refresh_memory_embeddings(self, *, provider_id: str) -> None:
"""刷新当前 provider 下脏或过期的 memory 向量索引。
Args:
provider_id: 当前使用的 embedding provider 标识。
Returns:
None
"""
keys_to_refresh: list[str] = []
texts_to_refresh: list[str] = []
for key, stored in self.memory_store.items():
self._ensure_memory_sidecars(key, stored)
entry = self._memory_index_entry(
self._memory_index.get(key),
text=self._extract_memory_text(stored),
)
should_refresh = (
key in self._memory_dirty_keys
or entry["embedding"] is None
or entry["provider_id"] != provider_id
)
self._memory_index[key] = entry
if should_refresh:
keys_to_refresh.append(key)
texts_to_refresh.append(str(entry["text"]))
embeddings = await self._embeddings_for_texts(
provider_id=provider_id,
texts=texts_to_refresh,
)
for index, key in enumerate(keys_to_refresh):
entry = self._memory_index_entry(
self._memory_index.get(key),
text=str(texts_to_refresh[index]),
)
entry["embedding"] = embeddings[index] if index < len(embeddings) else []
entry["provider_id"] = provider_id
self._memory_index[key] = entry
self._memory_dirty_keys.discard(key)
async def _memory_search(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
query = str(payload.get("query", ""))
mode = str(payload.get("mode", "auto")).strip().lower() or "auto"
limit = self._optional_int(payload.get("limit"))
raw_min_score = payload.get("min_score")
min_score = float(raw_min_score) if raw_min_score is not None else None
self._purge_expired_memory_entries()
provider_id = self._resolve_memory_embedding_provider_id(
payload.get("provider_id"),
required=mode in {"vector", "hybrid"},
)
effective_mode = mode
if effective_mode == "auto":
effective_mode = "hybrid" if provider_id is not None else "keyword"
query_embedding: list[float] | None = None
if effective_mode in {"vector", "hybrid"}:
if provider_id is None:
raise AstrBotError.invalid_input(
"memory.search requires an embedding provider",
)
await self._refresh_memory_embeddings(provider_id=provider_id)
query_embedding = await self._embedding_for_text(
provider_id=provider_id,
text=query,
)
items: list[dict[str, Any]] = []
for key, value in self.memory_store.items():
self._ensure_memory_sidecars(key, value)
entry = self._memory_index_entry(
self._memory_index.get(key),
text=self._extract_memory_text(value),
)
text = str(entry.get("text", ""))
keyword_score = self._memory_keyword_score(query, key, text)
vector_score = 0.0
if query_embedding is not None:
embedding = entry.get("embedding")
if isinstance(embedding, list):
vector_score = max(
0.0,
self._cosine_similarity(query_embedding, embedding),
)
if effective_mode == "keyword":
score = keyword_score
elif effective_mode == "vector":
score = vector_score
else:
score = vector_score
if keyword_score > 0:
score = max(score, 0.4 + 0.6 * vector_score)
if score <= 0:
continue
if min_score is not None and score < min_score:
continue
if effective_mode == "keyword" or (keyword_score > 0 and vector_score <= 0):
match_type = "keyword"
elif effective_mode == "vector" or keyword_score <= 0:
match_type = "vector"
else:
match_type = "hybrid"
items.append(
{
"key": key,
"value": self._memory_value_for_search(value),
"score": score,
"match_type": match_type,
}
)
items.sort(key=lambda item: (-float(item["score"]), str(item["key"])))
if limit is not None and limit >= 0:
items = items[:limit]
return {"items": items}
async def _memory_save(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
value = payload.get("value")
if not isinstance(value, dict):
raise AstrBotError.invalid_input("memory.save 的 value 必须是 object")
self.memory_store[key] = value
self._upsert_memory_sidecars(key, value)
return {}
async def _memory_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
if self._purge_expired_memory_entry(key):
return {"value": None}
return {"value": self.memory_store.get(key)}
async def _memory_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._delete_memory_entry(str(payload.get("key", "")))
return {}
async def _memory_save_with_ttl(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
value = payload.get("value")
ttl_seconds = payload.get("ttl_seconds", 0)
if not isinstance(value, dict):
raise AstrBotError.invalid_input(
"memory.save_with_ttl 的 value 必须是 object"
)
stored = {"value": value, "ttl_seconds": ttl_seconds}
self.memory_store[key] = stored
self._upsert_memory_sidecars(
key,
stored,
expires_at=self._memory_expiration_from_ttl(ttl_seconds),
)
return {}
async def _memory_get_many(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
keys_payload = payload.get("keys")
if not isinstance(keys_payload, (list, tuple)):
raise AstrBotError.invalid_input("memory.get_many 的 keys 必须是数组")
keys = [str(item) for item in keys_payload]
items = []
for key in keys:
if self._purge_expired_memory_entry(key):
items.append({"key": key, "value": None})
continue
stored = self.memory_store.get(key)
if (
isinstance(stored, dict)
and "value" in stored
and "ttl_seconds" in stored
):
value = stored["value"]
else:
value = stored
items.append({"key": key, "value": value})
return {"items": items}
async def _memory_delete_many(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
keys_payload = payload.get("keys")
if not isinstance(keys_payload, (list, tuple)):
raise AstrBotError.invalid_input("memory.delete_many 的 keys 必须是数组")
keys = [str(item) for item in keys_payload]
deleted_count = 0
for key in keys:
if self._delete_memory_entry(key):
deleted_count += 1
return {"deleted_count": deleted_count}
async def _memory_stats(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
self._purge_expired_memory_entries()
total_items = len(self.memory_store)
total_bytes = sum(
len(str(key)) + len(str(value)) for key, value in self.memory_store.items()
)
ttl_entries = len(self._memory_expires_at)
indexed_items = len(self._memory_index)
embedded_items = sum(
1
for entry in self._memory_index.values()
if isinstance(entry, dict)
and isinstance(entry.get("embedding"), list)
and bool(entry.get("embedding"))
)
dirty_items = len(self._memory_dirty_keys)
return {
"total_items": total_items,
"total_bytes": total_bytes,
"plugin_id": self._require_caller_plugin_id("memory.stats"),
"ttl_entries": ttl_entries,
"indexed_items": indexed_items,
"embedded_items": embedded_items,
"dirty_items": dirty_items,
}
def _register_memory_capabilities(self) -> None:
self.register(
self._builtin_descriptor("memory.search", "搜索记忆"),
call_handler=self._memory_search,
)
self.register(
self._builtin_descriptor("memory.save", "保存记忆"),
call_handler=self._memory_save,
)
self.register(
self._builtin_descriptor("memory.get", "读取单条记忆"),
call_handler=self._memory_get,
)
self.register(
self._builtin_descriptor("memory.delete", "删除记忆"),
call_handler=self._memory_delete,
)
self.register(
self._builtin_descriptor("memory.save_with_ttl", "保存带过期时间的记忆"),
call_handler=self._memory_save_with_ttl,
)
self.register(
self._builtin_descriptor("memory.get_many", "批量获取记忆"),
call_handler=self._memory_get_many,
)
self.register(
self._builtin_descriptor("memory.delete_many", "批量删除记忆"),
call_handler=self._memory_delete_many,
)
self.register(
self._builtin_descriptor("memory.stats", "获取记忆统计信息"),
call_handler=self._memory_stats,
)

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Any
from ..bridge_base import CapabilityRouterBridgeBase
class MetadataCapabilityMixin(CapabilityRouterBridgeBase):
async def _metadata_get_plugin(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
name = str(payload.get("name", "")).strip()
plugin = self._plugins.get(name)
if plugin is None:
return {"plugin": None}
return {"plugin": dict(plugin.metadata)}
async def _metadata_list_plugins(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
plugins = [
dict(self._plugins[name].metadata) for name in sorted(self._plugins.keys())
]
return {"plugins": plugins}
async def _metadata_get_plugin_config(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
name = str(payload.get("name", "")).strip()
caller_plugin_id = self._require_caller_plugin_id("metadata.get_plugin_config")
if name != caller_plugin_id:
return {"config": None}
plugin = self._plugins.get(name)
if plugin is None:
return {"config": None}
return {"config": dict(plugin.config)}
def _register_metadata_capabilities(self) -> None:
self.register(
self._builtin_descriptor("metadata.get_plugin", "获取单个插件元数据"),
call_handler=self._metadata_get_plugin,
)
self.register(
self._builtin_descriptor("metadata.list_plugins", "列出插件元数据"),
call_handler=self._metadata_list_plugins,
)
self.register(
self._builtin_descriptor(
"metadata.get_plugin_config",
"获取插件配置",
),
call_handler=self._metadata_get_plugin_config,
)

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class PersonaCapabilityMixin(CapabilityRouterBridgeBase):
async def _persona_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
persona_id = str(payload.get("persona_id", "")).strip()
record = self._persona_store.get(persona_id)
if record is None:
raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
return {"persona": dict(record)}
async def _persona_list(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
personas = [
dict(self._persona_store[persona_id])
for persona_id in sorted(self._persona_store.keys())
]
return {"personas": personas}
async def _persona_create(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
raw_persona = payload.get("persona")
if not isinstance(raw_persona, dict):
raise AstrBotError.invalid_input("persona.create requires persona object")
persona_id = str(raw_persona.get("persona_id", "")).strip()
if not persona_id:
raise AstrBotError.invalid_input("persona.create requires persona_id")
if persona_id in self._persona_store:
raise AstrBotError.invalid_input(f"persona already exists: {persona_id}")
now = self._now_iso()
record = {
"persona_id": persona_id,
"system_prompt": str(raw_persona.get("system_prompt", "")),
"begin_dialogs": self._normalize_persona_dialogs_payload(
raw_persona.get("begin_dialogs")
),
"tools": (
[str(item) for item in raw_persona.get("tools", [])]
if isinstance(raw_persona.get("tools"), list)
else None
),
"skills": (
[str(item) for item in raw_persona.get("skills", [])]
if isinstance(raw_persona.get("skills"), list)
else None
),
"custom_error_message": (
str(raw_persona.get("custom_error_message"))
if raw_persona.get("custom_error_message") is not None
else None
),
"folder_id": (
str(raw_persona.get("folder_id"))
if raw_persona.get("folder_id") is not None
else None
),
"sort_order": int(raw_persona.get("sort_order", 0)),
"created_at": now,
"updated_at": now,
}
self._persona_store[persona_id] = record
return {"persona": dict(record)}
async def _persona_update(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
persona_id = str(payload.get("persona_id", "")).strip()
record = self._persona_store.get(persona_id)
if record is None:
return {"persona": None}
raw_persona = payload.get("persona")
if not isinstance(raw_persona, dict):
raise AstrBotError.invalid_input("persona.update requires persona object")
if (
"system_prompt" in raw_persona
and raw_persona.get("system_prompt") is not None
):
record["system_prompt"] = str(raw_persona.get("system_prompt", ""))
if "begin_dialogs" in raw_persona:
begin_dialogs = raw_persona.get("begin_dialogs")
record["begin_dialogs"] = (
self._normalize_persona_dialogs_payload(begin_dialogs)
if begin_dialogs is not None
else []
)
if "tools" in raw_persona:
tools = raw_persona.get("tools")
record["tools"] = (
[str(item) for item in tools] if isinstance(tools, list) else None
)
if "skills" in raw_persona:
skills = raw_persona.get("skills")
record["skills"] = (
[str(item) for item in skills] if isinstance(skills, list) else None
)
if "custom_error_message" in raw_persona:
custom_error_message = raw_persona.get("custom_error_message")
record["custom_error_message"] = (
str(custom_error_message) if custom_error_message is not None else None
)
record["updated_at"] = self._now_iso()
return {"persona": dict(record)}
async def _persona_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
persona_id = str(payload.get("persona_id", "")).strip()
if persona_id not in self._persona_store:
raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
del self._persona_store[persona_id]
return {}
def _register_persona_capabilities(self) -> None:
self.register(
self._builtin_descriptor("persona.get", "获取人格"),
call_handler=self._persona_get,
)
self.register(
self._builtin_descriptor("persona.list", "列出人格"),
call_handler=self._persona_list,
)
self.register(
self._builtin_descriptor("persona.create", "创建人格"),
call_handler=self._persona_create,
)
self.register(
self._builtin_descriptor("persona.update", "更新人格"),
call_handler=self._persona_update,
)
self.register(
self._builtin_descriptor("persona.delete", "删除人格"),
call_handler=self._persona_delete,
)

View File

@@ -0,0 +1,231 @@
from __future__ import annotations
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class PlatformCapabilityMixin(CapabilityRouterBridgeBase):
async def _platform_send(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, target = self._resolve_target(payload)
text = str(payload.get("text", ""))
message_id = f"msg_{len(self.sent_messages) + 1}"
sent: dict[str, Any] = {
"message_id": message_id,
"session": session,
"text": text,
}
if target is not None:
sent["target"] = target
self.sent_messages.append(sent)
return {"message_id": message_id}
async def _platform_send_image(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, target = self._resolve_target(payload)
image_url = str(payload.get("image_url", ""))
message_id = f"img_{len(self.sent_messages) + 1}"
sent: dict[str, Any] = {
"message_id": message_id,
"session": session,
"image_url": image_url,
}
if target is not None:
sent["target"] = target
self.sent_messages.append(sent)
return {"message_id": message_id}
async def _platform_send_chain(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, target = self._resolve_target(payload)
chain = payload.get("chain")
if not isinstance(chain, list) or not all(
isinstance(item, dict) for item in chain
):
raise AstrBotError.invalid_input(
"platform.send_chain 的 chain 必须是 object 数组"
)
message_id = f"chain_{len(self.sent_messages) + 1}"
sent: dict[str, Any] = {
"message_id": message_id,
"session": session,
"chain": [dict(item) for item in chain],
}
if target is not None:
sent["target"] = target
self.sent_messages.append(sent)
return {"message_id": message_id}
async def _platform_send_by_session(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
chain = payload.get("chain")
if not isinstance(chain, list) or not all(
isinstance(item, dict) for item in chain
):
raise AstrBotError.invalid_input(
"platform.send_by_session 的 chain 必须是 object 数组"
)
session = str(payload.get("session", ""))
message_id = f"proactive_{len(self.sent_messages) + 1}"
self.sent_messages.append(
{
"message_id": message_id,
"session": session,
"chain": [dict(item) for item in chain],
}
)
return {"message_id": message_id}
async def _platform_get_group(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, _target = self._resolve_target(payload)
return {"group": self._mock_group_payload(session)}
async def _platform_get_members(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, _target = self._resolve_target(payload)
group = self._mock_group_payload(session)
if group is None:
return {"members": []}
return {"members": list(group.get("members", []))}
async def _platform_list_instances(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
return {
"platforms": [
{
"id": str(item.get("id", "")),
"name": str(item.get("name", "")),
"type": str(item.get("type", "")),
"status": str(item.get("status", "unknown")),
}
for item in self.get_platform_instances()
if isinstance(item, dict)
]
}
def _register_platform_capabilities(self) -> None:
self.register(
self._builtin_descriptor("platform.send", "发送消息"),
call_handler=self._platform_send,
)
self.register(
self._builtin_descriptor("platform.send_image", "发送图片"),
call_handler=self._platform_send_image,
)
self.register(
self._builtin_descriptor("platform.send_chain", "发送消息链"),
call_handler=self._platform_send_chain,
)
self.register(
self._builtin_descriptor(
"platform.send_by_session", "按会话主动发送消息链"
),
call_handler=self._platform_send_by_session,
)
self.register(
self._builtin_descriptor("platform.get_group", "获取当前群信息"),
call_handler=self._platform_get_group,
)
self.register(
self._builtin_descriptor("platform.get_members", "获取群成员"),
call_handler=self._platform_get_members,
)
self.register(
self._builtin_descriptor("platform.list_instances", "列出平台实例元信息"),
call_handler=self._platform_list_instances,
)
async def _platform_manager_get_by_id(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._require_reserved_plugin("platform.manager.get_by_id")
platform_id = str(payload.get("platform_id", "")).strip()
platform = next(
(
dict(item)
for item in self._platform_instances
if str(item.get("id", "")) == platform_id
),
None,
)
return {"platform": platform}
async def _platform_manager_clear_errors(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._require_reserved_plugin("platform.manager.clear_errors")
platform_id = str(payload.get("platform_id", "")).strip()
for item in self._platform_instances:
if str(item.get("id", "")) != platform_id:
continue
item["errors"] = []
item["last_error"] = None
if str(item.get("status", "")) == "error":
item["status"] = "running"
break
return {}
async def _platform_manager_get_stats(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._require_reserved_plugin("platform.manager.get_stats")
platform_id = str(payload.get("platform_id", "")).strip()
for item in self._platform_instances:
if str(item.get("id", "")) != platform_id:
continue
stats = item.get("stats")
if isinstance(stats, dict):
return {"stats": dict(stats)}
errors = item.get("errors")
last_error = item.get("last_error")
meta = item.get("meta")
return {
"stats": {
"id": platform_id,
"type": str(item.get("type", "")),
"display_name": str(item.get("name", platform_id)),
"status": str(item.get("status", "pending")),
"started_at": item.get("started_at"),
"error_count": len(errors) if isinstance(errors, list) else 0,
"last_error": dict(last_error)
if isinstance(last_error, dict)
else None,
"unified_webhook": bool(item.get("unified_webhook", False)),
"meta": dict(meta) if isinstance(meta, dict) else {},
}
}
return {"stats": None}
def _register_platform_manager_capabilities(self) -> None:
self.register(
self._builtin_descriptor(
"platform.manager.get_by_id",
"按 ID 获取平台管理快照",
),
call_handler=self._platform_manager_get_by_id,
)
self.register(
self._builtin_descriptor(
"platform.manager.clear_errors",
"清除平台错误",
),
call_handler=self._platform_manager_clear_errors,
)
self.register(
self._builtin_descriptor(
"platform.manager.get_stats",
"获取平台统计信息",
),
call_handler=self._platform_manager_get_stats,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class SessionCapabilityMixin(CapabilityRouterBridgeBase):
async def _session_plugin_is_enabled(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
plugin_name = str(payload.get("plugin_name", ""))
config = self._session_plugin_config(session)
enabled_plugins = {
str(item) for item in config.get("enabled_plugins", []) if str(item).strip()
}
disabled_plugins = {
str(item)
for item in config.get("disabled_plugins", [])
if str(item).strip()
}
if plugin_name in enabled_plugins:
return {"enabled": True}
return {"enabled": plugin_name not in disabled_plugins}
async def _session_plugin_filter_handlers(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
handlers = payload.get("handlers")
if not isinstance(handlers, list):
raise AstrBotError.invalid_input(
"session.plugin.filter_handlers 的 handlers 必须是 object 数组"
)
disabled_plugins = {
str(item)
for item in self._session_plugin_config(session).get("disabled_plugins", [])
if str(item).strip()
}
reserved_plugins = {
str(plugin.metadata.get("name", ""))
for plugin in self._plugins.values()
if bool(plugin.metadata.get("reserved", False))
}
filtered = []
for item in handlers:
if not isinstance(item, dict):
continue
plugin_name = str(item.get("plugin_name", ""))
if (
plugin_name
and plugin_name in disabled_plugins
and plugin_name not in reserved_plugins
):
continue
filtered.append(dict(item))
return {"handlers": filtered}
async def _session_service_is_llm_enabled(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
config = self._session_service_config(session)
return {"enabled": bool(config.get("llm_enabled", True))}
async def _session_service_set_llm_status(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
config = self._session_service_config(session)
config["llm_enabled"] = bool(payload.get("enabled", False))
self._session_service_configs[session] = config
return {}
async def _session_service_is_tts_enabled(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
config = self._session_service_config(session)
return {"enabled": bool(config.get("tts_enabled", True))}
async def _session_service_set_tts_status(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
config = self._session_service_config(session)
config["tts_enabled"] = bool(payload.get("enabled", False))
self._session_service_configs[session] = config
return {}
def _register_session_capabilities(self) -> None:
self.register(
self._builtin_descriptor("session.plugin.is_enabled", "获取会话级插件开关"),
call_handler=self._session_plugin_is_enabled,
)
self.register(
self._builtin_descriptor(
"session.plugin.filter_handlers",
"按会话过滤 handler 元数据",
),
call_handler=self._session_plugin_filter_handlers,
)
self.register(
self._builtin_descriptor(
"session.service.is_llm_enabled",
"获取会话级 LLM 开关",
),
call_handler=self._session_service_is_llm_enabled,
)
self.register(
self._builtin_descriptor(
"session.service.set_llm_status",
"写入会话级 LLM 开关",
),
call_handler=self._session_service_set_llm_status,
)
self.register(
self._builtin_descriptor(
"session.service.is_tts_enabled",
"获取会话级 TTS 开关",
),
call_handler=self._session_service_is_tts_enabled,
)
self.register(
self._builtin_descriptor(
"session.service.set_tts_status",
"写入会话级 TTS 开关",
),
call_handler=self._session_service_set_tts_status,
)

View File

@@ -0,0 +1,454 @@
from __future__ import annotations
import json
import uuid
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import (
CapabilityRouterBridgeBase,
_clone_chain_payload,
_clone_target_payload,
)
class SystemCapabilityMixin(CapabilityRouterBridgeBase):
def _register_system_capabilities(self) -> None:
self.register(
self._builtin_descriptor("system.get_data_dir", "获取插件数据目录"),
call_handler=self._system_get_data_dir,
exposed=False,
)
self.register(
self._builtin_descriptor("system.text_to_image", "文本转图片"),
call_handler=self._system_text_to_image,
exposed=False,
)
self.register(
self._builtin_descriptor("system.html_render", "渲染 HTML 模板"),
call_handler=self._system_html_render,
exposed=False,
)
self.register(
self._builtin_descriptor("system.file.register", "注册文件令牌"),
call_handler=self._system_file_register,
exposed=False,
)
self.register(
self._builtin_descriptor("system.file.handle", "解析文件令牌"),
call_handler=self._system_file_handle,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.session_waiter.register",
"注册会话等待器",
),
call_handler=self._system_session_waiter_register,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.session_waiter.unregister",
"注销会话等待器",
),
call_handler=self._system_session_waiter_unregister,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.react", "发送事件表情回应"),
call_handler=self._system_event_react,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.send_typing", "发送输入中状态"),
call_handler=self._system_event_send_typing,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.send_streaming",
"发送事件流式消息",
),
call_handler=self._system_event_send_streaming,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.send_streaming_chunk",
"推送事件流式消息分片",
),
call_handler=self._system_event_send_streaming_chunk,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.send_streaming_close",
"关闭事件流式消息会话",
),
call_handler=self._system_event_send_streaming_close,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.llm.get_state",
"读取当前请求的默认 LLM 状态",
),
call_handler=self._system_event_llm_get_state,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.llm.request",
"请求当前事件继续进入默认 LLM 链路",
),
call_handler=self._system_event_llm_request,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.result.get", "读取当前请求结果"),
call_handler=self._system_event_result_get,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.result.set", "写入当前请求结果"),
call_handler=self._system_event_result_set,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.result.clear", "清理当前请求结果"),
call_handler=self._system_event_result_clear,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.handler_whitelist.get",
"读取当前请求 handler 白名单",
),
call_handler=self._system_event_handler_whitelist_get,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.handler_whitelist.set",
"写入当前请求 handler 白名单",
),
call_handler=self._system_event_handler_whitelist_set,
exposed=False,
)
self.register(
self._builtin_descriptor(
"registry.get_handlers_by_event_type",
"按事件类型列出 handler 元数据",
),
call_handler=self._registry_get_handlers_by_event_type,
)
self.register(
self._builtin_descriptor(
"registry.get_handler_by_full_name",
"按 full name 查询 handler 元数据",
),
call_handler=self._registry_get_handler_by_full_name,
)
self.register(
self._builtin_descriptor(
"registry.command.register",
"注册动态命令路由",
),
call_handler=self._registry_command_register,
)
def _ensure_request_overlay(self, request_id: str) -> dict[str, Any]:
overlay = self._request_overlays.get(request_id)
if overlay is None:
overlay = {
"should_call_llm": False,
"requested_llm": False,
"result": None,
"handler_whitelist": None,
}
self._request_overlays[request_id] = overlay
return overlay
async def _system_get_data_dir(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
plugin_id = self._require_caller_plugin_id("system.get_data_dir")
data_dir = self._system_data_root / plugin_id
data_dir.mkdir(parents=True, exist_ok=True)
return {"path": str(data_dir)}
async def _system_text_to_image(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
text = str(payload.get("text", ""))
if bool(payload.get("return_url", True)):
return {"result": f"mock://text_to_image/{text}"}
return {"result": f"<image>{text}</image>"}
async def _system_html_render(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
tmpl = str(payload.get("tmpl", ""))
data = payload.get("data")
if not isinstance(data, dict):
raise AstrBotError.invalid_input("system.html_render requires object data")
if bool(payload.get("return_url", True)):
return {"result": f"mock://html_render/{tmpl}"}
return {"result": json.dumps({"tmpl": tmpl, "data": data}, ensure_ascii=False)}
async def _system_file_register(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
path = str(payload.get("path", "")).strip()
if not path:
raise AstrBotError.invalid_input("system.file.register requires path")
file_token = uuid.uuid4().hex
self._file_token_store[file_token] = path
return {"token": file_token, "url": f"mock://file/{file_token}"}
async def _system_file_handle(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
file_token = str(payload.get("token", "")).strip()
if not file_token:
raise AstrBotError.invalid_input("system.file.handle requires token")
path = self._file_token_store.pop(file_token, None)
if path is None:
raise AstrBotError.invalid_input(f"Unknown file token: {file_token}")
return {"path": path}
async def _system_event_llm_get_state(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
return {
"should_call_llm": bool(overlay["should_call_llm"]),
"requested_llm": bool(overlay["requested_llm"]),
}
async def _system_event_llm_request(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
overlay["requested_llm"] = True
overlay["should_call_llm"] = True
return await self._system_event_llm_get_state(request_id, {}, _token)
async def _system_event_result_get(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
result = overlay.get("result")
return {"result": dict(result) if isinstance(result, dict) else None}
async def _system_event_result_set(
self, request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
result = payload.get("result")
if not isinstance(result, dict):
raise AstrBotError.invalid_input(
"system.event.result.set 的 result 必须是 object"
)
overlay = self._ensure_request_overlay(request_id)
overlay["result"] = dict(result)
return {"result": dict(result)}
async def _system_event_result_clear(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
overlay["result"] = None
return {}
async def _system_event_handler_whitelist_get(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
whitelist = overlay.get("handler_whitelist")
if whitelist is None:
return {"plugin_names": None}
return {"plugin_names": sorted(str(item) for item in whitelist)}
async def _system_event_handler_whitelist_set(
self, request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
plugin_names_payload = payload.get("plugin_names")
if plugin_names_payload is None:
overlay["handler_whitelist"] = None
elif isinstance(plugin_names_payload, list):
overlay["handler_whitelist"] = {
str(item) for item in plugin_names_payload if str(item).strip()
}
else:
raise AstrBotError.invalid_input(
"system.event.handler_whitelist.set 的 plugin_names 必须是数组或 null"
)
return await self._system_event_handler_whitelist_get(request_id, {}, _token)
async def _registry_get_handlers_by_event_type(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
event_type = str(payload.get("event_type", "")).strip()
handlers: list[dict[str, Any]] = []
for plugin in self._plugins.values():
handlers.extend(
[
dict(handler)
for handler in plugin.handlers
if event_type in handler.get("event_types", [])
]
)
if event_type == "message":
for plugin_name, routes in self._dynamic_command_routes.items():
for route in routes:
if not isinstance(route, dict):
continue
handlers.append(
{
"plugin_name": str(route.get("plugin_name", plugin_name)),
"handler_full_name": str(
route.get("handler_full_name", "")
),
"trigger_type": (
"message"
if bool(route.get("use_regex", False))
else "command"
),
"event_types": ["message"],
"enabled": True,
"group_path": [],
}
)
return {"handlers": handlers}
async def _registry_get_handler_by_full_name(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
full_name = str(payload.get("full_name", "")).strip()
for plugin in self._plugins.values():
for handler in plugin.handlers:
if handler.get("handler_full_name") == full_name:
return {"handler": dict(handler)}
return {"handler": None}
async def _registry_command_register(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
source_event_type = str(payload.get("source_event_type", "")).strip()
if source_event_type not in {"astrbot_loaded", "platform_loaded"}:
raise AstrBotError.invalid_input(
"register_commands is only available in astrbot_loaded/platform_loaded events"
)
if bool(payload.get("ignore_prefix", False)):
raise AstrBotError.invalid_input(
"register_commands(ignore_prefix=True) is unsupported in SDK runtime"
)
priority_value = payload.get("priority", 0)
if isinstance(priority_value, bool) or not isinstance(priority_value, int):
raise AstrBotError.invalid_input(
"registry.command.register 的 priority 必须是 integer"
)
plugin_id = self._require_caller_plugin_id("registry.command.register")
self.register_dynamic_command_route(
plugin_id=plugin_id,
command_name=str(payload.get("command_name", "")),
handler_full_name=str(payload.get("handler_full_name", "")),
desc=str(payload.get("desc", "")),
priority=priority_value,
use_regex=bool(payload.get("use_regex", False)),
)
return {}
async def _system_session_waiter_register(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
plugin_id = self._require_caller_plugin_id("system.session_waiter.register")
session_key = str(payload.get("session_key", "")).strip()
if not session_key:
raise AstrBotError.invalid_input(
"system.session_waiter.register requires session_key"
)
self._session_waiters.setdefault(plugin_id, set()).add(session_key)
return {}
async def _system_session_waiter_unregister(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
plugin_id = self._require_caller_plugin_id("system.session_waiter.unregister")
session_key = str(payload.get("session_key", "")).strip()
plugin_waiters = self._session_waiters.get(plugin_id)
if plugin_waiters is None:
return {}
plugin_waiters.discard(session_key)
if not plugin_waiters:
self._session_waiters.pop(plugin_id, None)
return {}
async def _system_event_react(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self.event_actions.append(
{
"action": "react",
"emoji": str(payload.get("emoji", "")),
"target": _clone_target_payload(payload.get("target")),
}
)
return {"supported": True}
async def _system_event_send_typing(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self.event_actions.append(
{
"action": "send_typing",
"target": _clone_target_payload(payload.get("target")),
}
)
return {"supported": True}
async def _system_event_send_streaming(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
stream_id = f"mock-stream-{len(self._event_streams) + 1}"
stream_state: dict[str, Any] = {
"target": _clone_target_payload(payload.get("target")),
"chunks": [],
"use_fallback": bool(payload.get("use_fallback", False)),
}
self._event_streams[stream_id] = stream_state
return {"supported": True, "stream_id": stream_id}
async def _system_event_send_streaming_chunk(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
stream = self._event_streams.get(str(payload.get("stream_id", "")))
if stream is None:
raise AstrBotError.invalid_input("Unknown sdk event streaming session")
chain = payload.get("chain")
if not isinstance(chain, list):
raise AstrBotError.invalid_input(
"system.event.send_streaming_chunk requires a chain array"
)
stream["chunks"].append({"chain": _clone_chain_payload(chain)})
return {}
async def _system_event_send_streaming_close(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
stream_id = str(payload.get("stream_id", ""))
stream = self._event_streams.pop(stream_id, None)
if stream is None:
raise AstrBotError.invalid_input("Unknown sdk event streaming session")
self.event_actions.append(
{
"action": "send_streaming",
"target": stream["target"],
"chunks": list(stream["chunks"]),
"use_fallback": bool(stream["use_fallback"]),
}
)
return {"supported": True}

View File

@@ -67,6 +67,7 @@
provider.rerank.rerank: 文档重排序
provider.manager.set: 设置当前 Provider
provider.manager.get_by_id: 按 ID 获取 Provider 管理记录
provider.manager.get_merged_provider_config: 获取 Provider 合并配置
provider.manager.load: 运行时加载 Provider
provider.manager.terminate: 终止已加载的 Provider
provider.manager.create: 创建 Provider

View File

@@ -39,6 +39,7 @@ from .._invocation_context import caller_plugin_scope
from .._plugin_logger import PluginLogger
from .._star_runtime import bind_star_runtime
from .._typing_utils import unwrap_optional
from ..clients.llm import LLMResponse
from ..context import CancelToken, Context
from ..conversation import (
DEFAULT_BUSY_MESSAGE,
@@ -49,8 +50,13 @@ from ..conversation import (
)
from ..events import MessageEvent
from ..filters import LocalFilterBinding
from ..llm.entities import ProviderRequest
from ..message_components import BaseMessageComponent
from ..message_result import MessageChain, MessageEventResult, coerce_message_chain
from ..message_result import (
MessageChain,
MessageEventResult,
coerce_message_chain,
)
from ..protocol.descriptors import (
CommandTrigger,
MessageTrigger,
@@ -60,12 +66,12 @@ from ..protocol.descriptors import (
from ..schedule import ScheduleContext
from ..session_waiter import SessionWaiterManager
from ..star import Star
from .capability_dispatcher import CapabilityDispatcher
from ._command_matching import (
build_command_args,
build_regex_args,
match_command_name,
)
from .capability_dispatcher import CapabilityDispatcher
from .limiter import LimiterEngine
from .loader import LoadedHandler
@@ -76,6 +82,13 @@ class _ActiveConversation:
task: asyncio.Task[Any]
@dataclass(slots=True)
class _InjectedEventPayloads:
provider_request: ProviderRequest | None = None
llm_response: LLMResponse | None = None
event_result: MessageEventResult | None = None
class HandlerDispatcher:
def __init__(
self, *, plugin_id: str, peer, handlers: Sequence[LoadedHandler]
@@ -205,6 +218,8 @@ class HandlerDispatcher:
schedule_context: ScheduleContext | None = None,
) -> dict[str, Any]:
summary = {"sent_message": False, "stop": False, "call_llm": False}
injected_payloads = _InjectedEventPayloads()
event_type = self._event_type_name(event)
try:
limiter = loaded.limiter
if limiter is not None:
@@ -254,6 +269,7 @@ class HandlerDispatcher:
plugin_id=self._resolve_plugin_id(loaded),
handler_ref=loaded.descriptor.id,
schedule_context=schedule_context,
injected_payloads=injected_payloads,
)
)
if inspect.isasyncgen(result):
@@ -263,6 +279,11 @@ class HandlerDispatcher:
await self._handle_result_item(item, event, ctx),
)
summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
self._append_injected_payloads(
summary,
injected_payloads,
event_type=event_type,
)
return summary
if inspect.isawaitable(result):
result = await result
@@ -272,6 +293,11 @@ class HandlerDispatcher:
await self._handle_result_item(result, event, ctx),
)
summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
self._append_injected_payloads(
summary,
injected_payloads,
event_type=event_type,
)
return summary
except Exception as exc:
await self._handle_error(
@@ -339,6 +365,7 @@ class HandlerDispatcher:
handler_ref: str | None = None,
schedule_context: ScheduleContext | None = None,
conversation_session: ConversationSession | None = None,
injected_payloads: _InjectedEventPayloads | None = None,
) -> list[Any]:
"""构建 handler 参数列表。"""
from loguru import logger
@@ -371,6 +398,7 @@ class HandlerDispatcher:
ctx,
schedule_context,
conversation_session,
injected_payloads=injected_payloads,
)
# 2. Fallback 按名字注入
@@ -423,6 +451,8 @@ class HandlerDispatcher:
if not str(key).startswith("__command_")
}
)
if not isinstance(loaded.descriptor.trigger, CommandTrigger):
return parsed_args, None
model_param = resolve_command_model_param(loaded.callable)
if model_param is None:
return parsed_args, None
@@ -456,7 +486,7 @@ class HandlerDispatcher:
) -> dict[str, Any]:
assert loaded.conversation is not None
conversation_meta = loaded.conversation
summary = {"sent_message": False, "stop": False, "call_llm": False}
summary = {"sent_message": False, "stop": True, "call_llm": False}
key = f"{self._resolve_plugin_id(loaded)}:{event.session_id}"
active = self._conversations.get(key)
if active is not None and not active.task.done():
@@ -584,6 +614,8 @@ class HandlerDispatcher:
ctx: Context,
schedule_context: ScheduleContext | None,
conversation_session: ConversationSession | None,
*,
injected_payloads: _InjectedEventPayloads | None = None,
) -> Any:
"""根据类型注解注入参数。"""
param_type, _is_optional = unwrap_optional(param_type)
@@ -612,9 +644,117 @@ class HandlerDispatcher:
isinstance(param_type, type) and issubclass(param_type, ConversationSession)
):
return conversation_session
if param_type is ProviderRequest or (
isinstance(param_type, type) and issubclass(param_type, ProviderRequest)
):
return self._inject_provider_request(event, injected_payloads)
if param_type is LLMResponse or (
isinstance(param_type, type) and issubclass(param_type, LLMResponse)
):
return self._inject_llm_response(event, injected_payloads)
if param_type is MessageEventResult or (
isinstance(param_type, type) and issubclass(param_type, MessageEventResult)
):
return self._inject_event_result(event, injected_payloads)
return None
@staticmethod
def _event_type_name(event: MessageEvent) -> str:
raw = event.raw if isinstance(event.raw, dict) else {}
value = raw.get("event_type") or raw.get("type")
return str(value or "")
@staticmethod
def _payload_from_event(event: MessageEvent, key: str) -> dict[str, Any] | None:
raw = event.raw if isinstance(event.raw, dict) else {}
payload = raw.get(key)
if isinstance(payload, dict):
return payload
nested_raw = raw.get("raw")
if isinstance(nested_raw, dict):
nested_payload = nested_raw.get(key)
if isinstance(nested_payload, dict):
return nested_payload
return None
def _inject_provider_request(
self,
event: MessageEvent,
injected_payloads: _InjectedEventPayloads | None,
) -> ProviderRequest | None:
if injected_payloads is None:
payload = self._payload_from_event(event, "provider_request")
return (
ProviderRequest.from_payload(payload) if payload is not None else None
)
if injected_payloads.provider_request is None:
payload = self._payload_from_event(event, "provider_request")
if payload is None:
return None
injected_payloads.provider_request = ProviderRequest.from_payload(payload)
return injected_payloads.provider_request
def _inject_llm_response(
self,
event: MessageEvent,
injected_payloads: _InjectedEventPayloads | None,
) -> LLMResponse | None:
if injected_payloads is None:
payload = self._payload_from_event(event, "llm_response")
return LLMResponse.model_validate(payload) if payload is not None else None
if injected_payloads.llm_response is None:
payload = self._payload_from_event(event, "llm_response")
if payload is None:
return None
injected_payloads.llm_response = LLMResponse.model_validate(payload)
return injected_payloads.llm_response
def _inject_event_result(
self,
event: MessageEvent,
injected_payloads: _InjectedEventPayloads | None,
) -> MessageEventResult | None:
if injected_payloads is None:
payload = self._payload_from_event(event, "event_result")
return (
MessageEventResult.from_payload(payload)
if payload is not None
else None
)
if injected_payloads.event_result is None:
payload = self._payload_from_event(event, "event_result")
if payload is None:
return None
injected_payloads.event_result = MessageEventResult.from_payload(payload)
return injected_payloads.event_result
@staticmethod
def _append_injected_payloads(
summary: dict[str, Any],
injected_payloads: _InjectedEventPayloads,
*,
event_type: str,
) -> None:
if (
event_type == "llm_request"
and injected_payloads.provider_request is not None
):
summary["provider_request"] = (
injected_payloads.provider_request.to_payload()
)
elif (
event_type == "llm_response" and injected_payloads.llm_response is not None
):
summary["llm_response"] = injected_payloads.llm_response.model_dump(
exclude_none=True
)
elif (
event_type == "decorating_result"
and injected_payloads.event_result is not None
):
summary["event_result"] = injected_payloads.event_result.to_payload()
def _format_handler_injection_error(
self,
*,

View File

@@ -55,6 +55,7 @@ import copy
import importlib
import inspect
import json
import logging
import os
import re
import shutil
@@ -105,6 +106,7 @@ OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None
HandlerKind: TypeAlias = Literal["handler", "hook", "tool", "session"]
DiscoverySeverity: TypeAlias = Literal["warning", "error"]
DiscoveryPhase: TypeAlias = Literal["discovery", "load", "lifecycle", "reload"]
_LOGGER = logging.getLogger(__name__)
def _default_python_version() -> str:
@@ -502,17 +504,74 @@ def _normalize_config_value(field_schema: dict[str, Any], value: Any) -> Any:
return copy.deepcopy(value) if value is not None else default_value
def load_plugin_config(plugin: PluginSpec) -> dict[str, Any]:
"""加载插件配置,返回普通字典"""
def load_plugin_config_schema(plugin: PluginSpec) -> dict[str, Any]:
"""加载插件配置 schema解析失败时记录日志并返回空对象"""
schema_path = plugin.plugin_dir / CONFIG_SCHEMA_FILE
if not schema_path.exists():
return {}
try:
schema_payload = json.loads(schema_path.read_text(encoding="utf-8"))
except Exception:
schema_payload = {}
schema = schema_payload if isinstance(schema_payload, dict) else {}
except json.JSONDecodeError as exc:
_LOGGER.warning(
"Failed to parse SDK plugin config schema %s: %s",
schema_path,
exc,
)
return {}
except OSError as exc:
_LOGGER.warning(
"Failed to read SDK plugin config schema %s: %s",
schema_path,
exc,
)
return {}
if not isinstance(schema_payload, dict):
_LOGGER.warning(
"SDK plugin config schema %s must be a JSON object, got %s",
schema_path,
type(schema_payload).__name__,
)
return {}
return schema_payload
def save_plugin_config(
plugin: PluginSpec,
payload: dict[str, Any],
*,
schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""按 schema 归一化并写回插件配置。"""
active_schema = (
load_plugin_config_schema(plugin) if schema is None else dict(schema)
)
normalized = {
key: _normalize_config_value(field_schema, payload.get(key))
for key, field_schema in active_schema.items()
if isinstance(field_schema, dict)
}
config_path = _plugin_config_path(plugin.plugin_dir, plugin.name)
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(
json.dumps(normalized, ensure_ascii=False, indent=2),
encoding="utf-8",
)
return normalized
def load_plugin_config(
plugin: PluginSpec,
*,
schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""加载插件配置,返回普通字典。"""
active_schema = (
load_plugin_config_schema(plugin) if schema is None else dict(schema)
)
if not active_schema:
return {}
config_path = _plugin_config_path(plugin.plugin_dir, plugin.name)
try:
@@ -521,21 +580,29 @@ def load_plugin_config(plugin: PluginSpec) -> dict[str, Any]:
if config_path.exists()
else {}
)
except Exception:
except json.JSONDecodeError as exc:
_LOGGER.warning(
"Failed to parse SDK plugin config %s: %s",
config_path,
exc,
)
existing_payload = {}
except OSError as exc:
_LOGGER.warning(
"Failed to read SDK plugin config %s: %s",
config_path,
exc,
)
existing_payload = {}
existing = existing_payload if isinstance(existing_payload, dict) else {}
normalized = {
key: _normalize_config_value(field_schema, existing.get(key))
for key, field_schema in schema.items()
for key, field_schema in active_schema.items()
if isinstance(field_schema, dict)
}
if not config_path.exists() or normalized != existing:
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(
json.dumps(normalized, ensure_ascii=False, indent=2),
encoding="utf-8",
)
save_plugin_config(plugin, normalized, schema=active_schema)
return normalized