Merge pull request #37 from united-pooh/sdk/whatevertogo

Refactor memory utility functions and enhance memory capability mixin
This commit is contained in:
whatevertogo
2026-03-19 20:23:21 +08:00
committed by GitHub
3 changed files with 279 additions and 516 deletions

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
import json
import math
from datetime import datetime, timedelta, timezone
from typing import Any
def is_ttl_memory_entry(value: Any) -> bool:
"""Return whether a stored memory payload uses the TTL wrapper shape."""
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
def memory_value_for_search(stored: Any) -> dict[str, Any] | None:
"""Unwrap the search payload from a stored memory record when possible."""
if not isinstance(stored, dict):
return None
if is_ttl_memory_entry(stored):
value = stored.get("value")
return value if isinstance(value, dict) else None
return stored
def extract_memory_text(stored: Any) -> str:
"""Pick the canonical text that keyword/vector search should index."""
value = memory_value_for_search(stored)
if not isinstance(value, dict):
return ""
for field_name in ("embedding_text", "content", "summary", "title", "text"):
item = value.get(field_name)
if isinstance(item, str) and item.strip():
return item.strip()
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
def memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
"""Translate a TTL in seconds into an absolute UTC expiration timestamp."""
try:
ttl = int(ttl_seconds)
except (TypeError, ValueError):
return None
if ttl < 1:
return None
return datetime.now(timezone.utc) + timedelta(seconds=ttl)
def memory_expiration_from_stored_payload(stored: Any) -> datetime | None:
"""Recover an absolute expiration timestamp from a stored TTL payload."""
if not is_ttl_memory_entry(stored) or not isinstance(stored, dict):
return None
raw_expires_at = stored.get("expires_at")
if isinstance(raw_expires_at, (int, float)):
return datetime.fromtimestamp(float(raw_expires_at), tz=timezone.utc)
if not isinstance(raw_expires_at, str):
return None
normalized = raw_expires_at.strip()
if not normalized:
return None
if normalized.endswith("Z"):
normalized = f"{normalized[:-1]}+00:00"
try:
expires_at = datetime.fromisoformat(normalized)
except ValueError:
return None
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return expires_at.astimezone(timezone.utc)
def memory_keyword_score(query: str, key: str, text: str) -> float:
"""Score a keyword hit the same way across runtime and core bridge."""
normalized_query = str(query).casefold()
if not normalized_query:
return 1.0
normalized_key = str(key).casefold()
normalized_text = str(text).casefold()
if normalized_query in normalized_key:
return 1.0
if normalized_query in normalized_text:
return 0.9
return 0.0
def cosine_similarity(left: list[float], right: list[float]) -> float:
"""Compute cosine similarity defensively for embedding vectors."""
if not left or not right or len(left) != len(right):
return 0.0
left_norm = math.sqrt(sum(value * value for value in left))
right_norm = math.sqrt(sum(value * value for value in right))
if left_norm <= 0 or right_norm <= 0:
return 0.0
return sum(a * b for a, b in zip(left, right, strict=False)) / (
left_norm * right_norm
)
def memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]:
"""Normalize cached sidecar data into a stable memory index record."""
if isinstance(entry, dict):
return {
"text": str(entry.get("text", text)),
"embedding": (
[float(item) for item in entry.get("embedding", [])]
if isinstance(entry.get("embedding"), list)
else None
),
"provider_id": (
str(entry.get("provider_id")).strip()
if entry.get("provider_id") is not None
else None
),
}
return {"text": text, "embedding": None, "provider_id": None}

View File

@@ -1,18 +1,20 @@
# AstrBot SDK 项目完整架构分析文档
# AstrBot SDK 架构概述文档
> 作者whatevertogo
> 生成日期2026-03-19
---
## 目录
1. [项目概述](#项目概述)
2. [目录结构](#目录结构)
3. [核心架构层次](#核心架构层次)
4. [协议层设计](#协议层设计)
5. [运行时架构](#运行时架构)
6. [客户端层设计](#客户端层设计)
7. [新旧架构对比](#新旧架构对比)
8. [插件开发指南](#插件开发指南)
9. [关键设计模式](#关键设计模式)
2. [核心架构层次](#核心架构层次)
3. [协议层设计](#协议层设计)
4. [运行时架构](#运行时架构)
5. [客户端层设计](#客户端层设计)
6. [插件开发指南](#插件开发指南)
7. [关键设计模式](#关键设计模式)
8. [文档与资源](#文档与资源)
---
@@ -24,12 +26,12 @@ AstrBot SDK 是一个基于 Python 3.12+ 的机器人插件开发框架,采用
| 特性 | 描述 |
|------|------|
| 进程隔离 | 每个插件运行在独立 Worker 进程,崩溃不影响其他插件 |
| 环境分组 | 多插件可共享同一 Python 虚拟环境,节省资源 |
| 能力路由 | 显式声明的 Capability 系统,支持 JSON Schema 验证 |
| 流式支持 | 原生支持流式 LLM 调用和增量结果返回 |
| 向后兼容 | 完整的旧版 API 兼容层,支持无修改迁移 |
| 协议优先 | 基于 v4 协议的统一通信模型,支持多种传输方式 |
| **进程隔离** | 每个插件运行在独立 Worker 进程,崩溃不影响其他插件 |
| **环境分组** | 多插件可共享同一 Python 虚拟环境,节省资源 |
| **能力路由** | 显式声明的 Capability 系统,支持 JSON Schema 验证 |
| **流式支持** | 原生支持流式 LLM 调用和增量结果返回 |
| **向后兼容** | 完整的旧版 API 兼容层,支持无修改迁移 |
| **协议优先** | 基于 v4 协议的统一通信模型,支持多种传输方式 |
### 技术栈
@@ -44,66 +46,6 @@ AstrBot SDK 是一个基于 Python 3.12+ 的机器人插件开发框架,采用
---
## 目录结构
```
astrbot_sdk/ # v4 SDK 主包
├── __init__.py # 顶层公共 API 导出
├── __main__.py # CLI 入口点 (python -m astrbot_sdk)
├── star.py # v4 原生插件基类
├── context.py # 运行时上下文 (Context, CancelToken)
├── decorators.py # v4 原生装饰器 (on_command, on_message, etc.)
├── events.py # v4 原生事件对象 (MessageEvent)
├── errors.py # 统一错误模型 (AstrBotError)
├── cli.py # 命令行工具 (init/validate/build/dev/run)
├── testing.py # 测试辅助模块 (PluginHarness)
├── _invocation_context.py # 调用上下文管理 (caller_plugin_scope)
├── _testing_support.py # 测试支持工具
├── commands.py # 命令分组工具 (CommandGroup)
├── filters.py # 事件过滤器 (PlatformFilter, CustomFilter)
├── message_components.py # 消息组件 (Plain, Image, At, etc.)
├── message_result.py # 消息结果对象 (MessageChain)
├── message_session.py # 会话标识符 (MessageSession)
├── schedule.py # 定时任务上下文 (ScheduleContext)
├── session_waiter.py # 会话等待器 (SessionController)
├── types.py # 参数类型助手 (GreedyStr)
├── clients/ # 能力客户端层
│ ├── __init__.py # 客户端公共导出
│ ├── _proxy.py # CapabilityProxy 能力代理
│ ├── llm.py # LLM 客户端 (chat, chat_raw, stream_chat)
│ ├── memory.py # 记忆存储客户端 (search, save, get)
│ ├── db.py # KV 存储客户端 (get, set, watch)
│ ├── platform.py # 平台消息客户端 (send, send_image)
│ ├── http.py # HTTP 注册客户端 (register_api)
│ └── metadata.py # 插件元数据客户端 (get_plugin)
├── protocol/ # 协议层
│ ├── __init__.py # 协议公共导出
│ ├── messages.py # v4 协议消息模型
│ ├── descriptors.py # Handler/Capability 描述符
│ └── _builtin_schemas.py # 内置能力 JSON Schema
└── runtime/ # 运行时层
├── __init__.py # 运行时公共导出 (延迟加载)
├── peer.py # 协议对等端 (Peer)
├── transport.py # 传输抽象 (Stdio, WebSocket)
├── handler_dispatcher.py # Handler 执行分发
├── capability_dispatcher.py # Capability 调用分发
├── capability_router.py # Capability 路由
├── _capability_router_builtins.py # 内置能力处理器
├── _loader_support.py # 加载器反射工具
├── _streaming.py # 流式执行原语 (StreamExecution)
├── loader.py # 插件加载器
├── bootstrap.py # 启动引导
├── worker.py # Worker 运行时
├── supervisor.py # Supervisor 运行时
└── environment_groups.py # 环境分组管理
```
---
## 核心架构层次
```
@@ -139,14 +81,8 @@ astrbot_sdk/ # v4 SDK 主包
│ - handler_dispatcher.py (Handler 执行分发、参数注入) │
│ - capability_dispatcher.py (Capability 调用分发) │
│ - capability_router.py (Capability 路由、Schema 验证) │
│ - _capability_router_builtins.py (内置能力实现) │
│ - _loader_support.py (反射工具、签名验证) │
│ - _streaming.py (流式执行原语) │
│ - peer.py (协议对等端) │
│ - transport.py (传输抽象) │
│ - supervisor.py (Supervisor 运行时) │
│ - worker.py (Worker 运行时) │
│ - environment_groups.py (环境分组规划) │
└────────────────────┬────────────────────────────────────────────┘
┌──────────────────▼─────────────────────────────────────────────┐
@@ -155,7 +91,6 @@ astrbot_sdk/ # v4 SDK 主包
│ protocol/ │
│ - messages.py (协议消息模型) │
│ - descriptors.py (Handler/Capability 描述符) │
│ - _builtin_schemas.py (内置能力 JSON Schema) │
│ transport 实现: │
│ - StdioTransport (标准输入输出) │
│ - WebSocketServerTransport (WebSocket 服务端) │
@@ -167,15 +102,15 @@ astrbot_sdk/ # v4 SDK 主包
| 层次 | 职责 | 主要模块 |
|------|------|---------|
| 用户层 | 插件开发者 API | `Star`, `Context`, `MessageEvent`, 装饰器, 过滤器, 命令组 |
| 高层 API | 类型化的能力客户端 | `clients/{llm, memory, db, platform, http, metadata}` |
| 执行边界 | 插件加载、路由、分发、参数注入 | `runtime/loader.py`, `runtime/*_dispatcher.py` |
| 协议层 | 消息模型、描述符、JSON Schema | `protocol/` |
| 传输层 | 底层通信抽象 | `runtime/transport.py` |
| **用户层** | 插件开发者 API | `Star`, `Context`, `MessageEvent`, 装饰器, 过滤器 |
| **高层 API** | 类型化的能力客户端 | `clients/{llm, memory, db, platform, http, metadata}` |
| **执行边界** | 插件加载、路由、分发 | `runtime/loader.py`, `runtime/*_dispatcher.py` |
| **协议层** | 消息模型、描述符、JSON Schema | `protocol/` |
| **传输层** | 底层通信抽象 | `runtime/transport.py` |
### 核心设计原则
1. **延迟加载**`runtime/__init__.py` 使用 `__getattr__` 避免导入时加载 websocket/aiohttp 等重型依赖
1. **延迟加载**`runtime/__init__.py` 使用 `__getattr__` 避免导入时加载重型依赖
2. **插件身份透传**:通过 `caller_plugin_scope()` 上下文管理器将 plugin_id 注入协议层
3. **声明式优先**:所有配置都是数据结构(描述符),便于序列化和跨进程传递
4. **类型安全**:使用 Pydantic 模型和类型注解提供验证和 IDE 支持
@@ -212,17 +147,12 @@ Worker (Plugin) Supervisor (Core)
| InitializeMessage |
| (handlers, capabilities) |
|----------------------------->|
| | 创建 CapabilityRouter
| | 注册 handler.invoke
| |
| ResultMessage(kind="init") |
|<-----------------------------|
| | 等待 handler.invoke 调用
| | 执行 CapabilityRouter.execute()
| |
| InvokeMessage(handler.invoke) |
|<-----------------------------|
| HandlerDispatcher.invoke() |
| 执行用户 handler |
| |
| ResultMessage(output) |
@@ -246,9 +176,8 @@ Worker (Plugin) Supervisor (Core)
"contract": "message_event", # message_event | schedule
"priority": 0,
"permissions": {"require_admin": False, "level": 0},
"filters": [], # 高级过滤器列表
"param_specs": [], # 参数规范
"command_route": {...} # 命令路由元信息
"filters": [],
"param_specs": []
}
```
@@ -259,46 +188,7 @@ Worker (Plugin) Supervisor (Core)
| `CommandTrigger` | command, aliases, platforms | 命令触发 |
| `MessageTrigger` | regex, keywords, platforms | 消息触发(正则/关键词) |
| `EventTrigger` | event_type | 事件触发 |
| `ScheduleTrigger` | cron, interval_seconds | 定时触发(二选一) |
#### FilterSpec 类型
| 类型 | 说明 |
|------|------|
| `PlatformFilterSpec` | 按平台名称过滤 |
| `MessageTypeFilterSpec` | 按消息类型过滤 |
| `LocalFilterRefSpec` | 引用本地自定义过滤器 |
| `CompositeFilterSpec` | 组合过滤器AND/OR |
#### CapabilityDescriptor
```python
{
"name": "llm.chat",
"description": "发送对话请求,返回文本",
"input_schema": {
"type": "object",
"properties": {"prompt": {"type": "string"}},
"required": ["prompt"]
},
"output_schema": {
"type": "object",
"properties": {"text": {"type": "string"}},
"required": ["text"]
},
"supports_stream": False,
"cancelable": False
}
```
### 命名空间治理
**保留前缀**
- `handler.` - 内部 handler.invoke
- `system.` - 系统内置能力
- `internal.` - 内部使用
**内置能力命名空间**`llm`, `memory`, `db`, `platform`, `http`, `metadata`
| `ScheduleTrigger` | cron, interval_seconds | 定时触发 |
### 内置 Capabilities (38个)
@@ -307,7 +197,7 @@ Worker (Plugin) Supervisor (Core)
| 能力 | 说明 |
|------|------|
| `llm.chat` | 同步对话,返回文本 |
| `llm.chat_raw` | 同步对话,返回完整响应(含 usage、tool_calls |
| `llm.chat_raw` | 同步对话,返回完整响应 |
| `llm.stream_chat` | 流式对话 |
#### Memory 命名空间
@@ -317,23 +207,19 @@ Worker (Plugin) Supervisor (Core)
| `memory.search` | 语义搜索记忆 |
| `memory.save` | 保存记忆 |
| `memory.save_with_ttl` | 保存带过期时间的记忆 |
| `memory.get` | 读取单条记忆 |
| `memory.get_many` | 批量获取记忆 |
| `memory.delete` | 删除记忆 |
| `memory.delete_many` | 批量删除记忆 |
| `memory.stats` | 获取记忆统计信息 |
| `memory.get` / `get_many` | 读取记忆 |
| `memory.delete` / `delete_many` | 删除记忆 |
| `memory.stats` | 获取统计信息 |
#### DB 命名空间
| 能力 | 说明 |
|------|------|
| `db.get` | 读取 KV |
| `db.set` | 写入 KV |
| `db.get` / `get_many` | 读取 KV |
| `db.set` / `set_many` | 写入 KV |
| `db.delete` | 删除 KV |
| `db.list` | 列出 KV 键(支持前缀过滤) |
| `db.get_many` | 批量读取 KV |
| `db.set_many` | 批量写入 KV |
| `db.watch` | 订阅 KV 变更(流式) |
| `db.list` | 列出键(支持前缀过滤) |
| `db.watch` | 订阅变更(流式) |
#### Platform 命名空间
@@ -367,13 +253,8 @@ Worker (Plugin) Supervisor (Core)
| `system.get_data_dir` | 获取插件数据目录 |
| `system.text_to_image` | 文本转图片 |
| `system.html_render` | 渲染 HTML 模板 |
| `system.session_waiter.register` | 注册会话等待器 |
| `system.session_waiter.unregister` | 注销会话等待器 |
| `system.event.react` | 发送表情回应 |
| `system.event.send_typing` | 发送输入中状态 |
| `system.event.send_streaming` | 开始流式消息会话 |
| `system.event.send_streaming_chunk` | 推送流式消息分片 |
| `system.event.send_streaming_close` | 关闭流式消息会话 |
| `system.session_waiter.*` | 会话等待器管理 |
| `system.event.*` | 表情回应、输入状态、流式消息 |
---
@@ -406,199 +287,26 @@ Worker (Plugin) Supervisor (Core)
│ │ │
┌─────▼─────┐ ┌─────▼─────┐ ┌─────▼─────┐
│ Plugin A │ │ Plugin B │ │ Plugin C │
│ (v4/old) │ │ (v4/old) │ │ (v4/old) │
└───────────┘ └───────────┘ └───────────┘
```
### SupervisorRuntime
### 核心运行时组件
职责:管理多个 Worker 进程,聚合所有 handler
| 组件 | 职责 |
|------|------|
| **SupervisorRuntime** | 管理多个 Worker 进程,聚合所有 handler |
| **WorkerSession** | 管理单个 Worker 进程的生命周期 |
| **PluginWorkerRuntime** | Worker 进程内的插件加载与执行 |
| **HandlerDispatcher** | 将 handler.invoke 请求转成真实 Python 调用 |
| **CapabilityRouter** | 能力注册、发现和执行路由 |
```python
class SupervisorRuntime:
def __init__(self, *, transport, plugins_dir, env_manager):
self.transport = transport # 与 Core 的传输层
self.plugins_dir = plugins_dir # 插件目录
self.capability_router = CapabilityRouter() # 能力路由器
self.peer = Peer(...) # 与 Core 的对等端
self.worker_sessions = {} # Worker 会话映射
self.handler_to_worker = {} # Handler → Worker 映射
### 参数注入优先级
async def start(self):
# 1. 发现所有插件
discovery = discover_plugins(self.plugins_dir)
HandlerDispatcher 支持参数注入,优先级为:
# 2. 规划环境分组
plan_result = self.env_manager.plan(discovery.plugins)
# 3. 为每个分组启动 Worker
for group in plan_result.groups:
session = WorkerSession(group=group, ...)
await session.start()
self.worker_sessions[group.id] = session
# 4. 聚合所有 handler 和 capability
await self.peer.initialize(
handlers=[...],
provided_capabilities=self.capability_router.descriptors()
)
```
### WorkerSession
职责:管理单个 Worker 进程的生命周期
```python
class WorkerSession:
def __init__(self, *, group, env_manager, capability_router):
self.group = group # 环境分组
self.peer = Peer(...) # 与 Worker 的对等端
self.capability_router = capability_router
self.handlers = [] # Worker 注册的 handlers
self.provided_capabilities = [] # Worker 提供的 capabilities
async def start(self):
# 启动 Worker 子进程
python_path = self.env_manager.prepare_group_environment(self.group)
transport = StdioTransport(
command=[python_path, "-m", "astrbot_sdk", "worker", "--group-metadata", ...]
)
self.peer = Peer(transport=transport, ...)
# 等待 Worker 初始化完成
await self.peer.start()
await self.peer.wait_until_remote_initialized()
# 获取 Worker 的注册信息
self.handlers = list(self.peer.remote_handlers)
self.provided_capabilities = list(self.peer.remote_provided_capabilities)
async def invoke_capability(self, capability_name, payload, *, request_id):
# 转发能力调用到 Worker
return await self.peer.invoke(capability_name, payload, request_id=request_id)
```
### PluginWorkerRuntime
职责Worker 进程内的插件加载与执行
```python
class PluginWorkerRuntime:
def __init__(self, *, plugin_dir, transport):
self.plugin = load_plugin_spec(plugin_dir)
self.loaded_plugin = load_plugin(self.plugin)
self.peer = Peer(transport=transport, ...)
self.dispatcher = HandlerDispatcher(...)
self.capability_dispatcher = CapabilityDispatcher(...)
async def start(self):
# 1. 向 Supervisor 注册 handlers 和 capabilities
await self.peer.initialize(
handlers=[h.descriptor for h in self.loaded_plugin.handlers],
provided_capabilities=[c.descriptor for c in self.loaded_plugin.capabilities]
)
# 2. 执行 on_start 生命周期
await self._run_lifecycle("on_start")
# 3. 设置消息处理器
self.peer.set_invoke_handler(self._handle_invoke)
self.peer.set_cancel_handler(self._handle_cancel)
async def _handle_invoke(self, message, cancel_token):
if message.capability == "handler.invoke":
return await self.dispatcher.invoke(message, cancel_token)
return await self.capability_dispatcher.invoke(message, cancel_token)
```
### HandlerDispatcher
职责:将 handler.invoke 请求转成真实 Python 调用
```python
class HandlerDispatcher:
def __init__(self, *, plugin_id, peer, handlers):
self._handlers = {item.descriptor.id: item for item in handlers}
self._peer = peer
self._active = {} # request_id → (task, cancel_token)
async def invoke(self, message, cancel_token):
# 1. 查找 handler
loaded = self._handlers[message.input["handler_id"]]
# 2. 创建上下文
ctx = Context(peer=self._peer, plugin_id=plugin_id, cancel_token=cancel_token)
event = MessageEvent.from_payload(message.input["event"], context=ctx)
# 3. 构建参数 (支持类型注解注入)
args = self._build_args(loaded.callable, event, ctx)
# 4. 执行 handler
result = loaded.callable(*args)
# 5. 处理返回值
await self._consume_result(result, event, ctx)
```
**参数注入优先级**:
1. 按类型注解注入(`MessageEvent`, `Context`
2. 按参数名注入(`event`, `ctx`, `context`
3. 从 legacy_args 注入(命令参数等)
### CapabilityRouter
职责:能力注册、发现和执行路由
```python
class CapabilityRouter:
def __init__(self):
self._registrations = {} # capability_name → registration
self.db_store = {} # 内置 KV 存储
self.memory_store = {} # 内置记忆存储
self._register_builtin_capabilities()
def register(self, descriptor, *, call_handler, stream_handler, finalize):
"""注册能力"""
self._registrations[descriptor.name] = _CapabilityRegistration(
descriptor=descriptor,
call_handler=call_handler,
stream_handler=stream_handler,
finalize=finalize
)
async def execute(self, capability, payload, *, stream, cancel_token, request_id):
"""执行能力调用"""
registration = self._registrations[capability]
if stream:
# 流式调用
raw_execution = registration.stream_handler(request_id, payload, cancel_token)
return StreamExecution(iterator=raw_execution, finalize=finalize)
else:
# 同步调用
output = await registration.call_handler(request_id, payload, cancel_token)
return output
```
### 环境分组管理
```python
class EnvironmentPlanner:
def plan(self, plugins):
"""根据 Python 版本和依赖兼容性分组"""
# 1. 按版本分组
# 2. 按依赖兼容性合并
# 3. 生成分组元数据
return EnvironmentPlanResult(groups=[...])
class GroupEnvironmentManager:
def prepare(self, group):
"""准备分组虚拟环境"""
# 1. 生成 lock/source/metadata 工件
# 2. 必要时重建虚拟环境
# 3. 返回 Python 解释器路径
return venv_python_path
```
1. **按类型注解注入**`MessageEvent`, `Context`
2. **按参数名注入**`event`, `ctx`, `context`
3. **从 legacy_args 注入**(命令参数等)
---
@@ -609,99 +317,35 @@ class GroupEnvironmentManager:
```
┌─────────────────────────────────────────────────────────────┐
│ User Plugin │
├─────────────────────────────────────────────────────────────┤
│ ctx.llm.chat() │
│ ctx.memory.save() │
│ ctx.db.set() │
│ ctx.platform.send() │
│ ctx.llm.chat() / ctx.memory.save() / ctx.db.set() │
└────────────┬──────────────────────────────────────────────┘
┌────────────▼──────────────────────────────────────────────┐
│ CapabilityProxy │
│ - call(name, payload)
│ - stream(name, payload)
│ - call(name, payload) 普通调用
│ - stream(name, payload) 流式调用
└────────────┬──────────────────────────────────────────────┘
┌────────────▼──────────────────────────────────────────────┐
│ Peer │
│ - invoke(capability, payload, stream=False)
│ Peer
│ - invoke(capability, payload)
│ - invoke_stream(capability, payload) │
└────────────┬──────────────────────────────────────────────┘
┌────────────▼──────────────────────────────────────────────┐
│ Transport │
│ - send(json_string) │
│ Transport
│ - send(json_string)
└─────────────────────────────────────────────────────────────┘
```
### CapabilityProxy
职责:封装 Peer 的能力调用接口
```python
class CapabilityProxy:
def __init__(self, peer):
self._peer = peer
async def call(self, name, payload):
"""普通能力调用"""
# 1. 检查能力是否可用
descriptor = self._peer.remote_capability_map.get(name)
if descriptor is None:
raise AstrBotError.capability_not_found(name)
# 2. 调用 Peer.invoke
return await self._peer.invoke(name, payload, stream=False)
async def stream(self, name, payload):
"""流式能力调用"""
# 1. 检查流式支持
descriptor = self._peer.remote_capability_map.get(name)
if not descriptor.supports_stream:
raise AstrBotError.invalid_input(f"{name} 不支持 stream")
# 2. 调用 Peer.invoke_stream
event_stream = await self._peer.invoke_stream(name, payload)
async for event in event_stream:
if event.phase == "delta":
yield event.data
```
### LLMClient
```python
class LLMClient:
def __init__(self, proxy: CapabilityProxy):
self._proxy = proxy
async def chat(self, prompt, *, system=None, history=None, **kwargs) -> str:
"""发送聊天请求,返回文本"""
output = await self._proxy.call("llm.chat", {
"prompt": prompt,
"system": system,
"history": self._serialize_history(history),
**kwargs
})
return output["text"]
async def chat_raw(self, prompt, **kwargs) -> LLMResponse:
"""发送聊天请求,返回完整响应"""
output = await self._proxy.call("llm.chat_raw", {"prompt": prompt, **kwargs})
return LLMResponse.model_validate(output)
async def stream_chat(self, prompt, **kwargs) -> AsyncGenerator[str]:
"""流式聊天"""
async for delta in self._proxy.stream("llm.stream_chat", {"prompt": prompt, **kwargs}):
yield delta["text"]
```
### 其他客户端
### 客户端一览
| 客户端 | 主要方法 | 对应 Capability |
|--------|---------|-----------------|
| `LLMClient` | `chat()`, `chat_raw()`, `stream_chat()` | `llm.*` |
| `MemoryClient` | `search()`, `save()`, `save_with_ttl()`, `get()`, `get_many()`, `delete()`, `delete_many()`, `stats()` | `memory.*` |
| `DBClient` | `get()`, `set()`, `delete()`, `list()`, `get_many()`, `set_many()`, `watch()` | `db.*` |
| `PlatformClient` | `send()`, `send_image()`, `send_chain()`, `send_by_session()`, `send_by_id()`, `get_members()` | `platform.*` |
| `DBClient` | `get()`, `set()`, `get_many()`, `set_many()`, `delete()`, `list()`, `watch()` | `db.*` |
| `PlatformClient` | `send()`, `send_image()`, `send_chain()`, `get_members()` | `platform.*` |
| `HTTPClient` | `register_api()`, `unregister_api()`, `list_apis()` | `http.*` |
| `MetadataClient` | `get_plugin()`, `list_plugins()`, `get_current_plugin()`, `get_plugin_config()` | `metadata.*` |
@@ -709,7 +353,7 @@ class LLMClient:
## 插件开发指南
### v4 原生插件
### v4 原生插件示例
#### plugin.yaml
@@ -756,11 +400,7 @@ class MyPlugin(Star):
"required": ["result"]
}
)
async def calculate_capability(
self,
payload: dict,
ctx: Context
) -> dict:
async def calculate_capability(self, payload: dict, ctx: Context) -> dict:
x = payload.get("x", 0)
return {"result": x * 2}
```
@@ -773,6 +413,51 @@ class MyPlugin(Star):
| `on_stop()` | 插件停止时调用 |
| `on_error(exc, event, ctx)` | Handler 执行出错时调用 |
### 常用功能速查
#### 1. LLM 对话
```python
# 简单对话
reply = await ctx.llm.chat("你好")
# 带历史对话
from astrbot_sdk.clients.llm import ChatMessage
history = [ChatMessage(role="user", content="我叫小明")]
reply = await ctx.llm.chat("你记得我吗?", history=history)
# 流式对话
async for chunk in ctx.llm.stream_chat("讲个故事"):
print(chunk, end="")
```
#### 2. 数据持久化
```python
# DB 客户端(精确匹配)
await ctx.db.set("user:123", {"name": "Alice"})
data = await ctx.db.get("user:123")
# Memory 客户端(语义搜索)
await ctx.memory.save("user_pref", {"theme": "dark"})
results = await ctx.memory.search("用户喜欢什么颜色")
```
#### 3. 消息发送
```python
# 简单文本
await ctx.platform.send(event.session_id, "消息内容")
# 图片
await ctx.platform.send_image(event.session_id, "https://example.com/img.jpg")
# 消息链
from astrbot_sdk.message_components import Plain, Image
chain = [Plain("文字"), Image(url="https://example.com/img.jpg")]
await ctx.platform.send_chain(event.session_id, chain)
```
---
## 关键设计模式
@@ -822,51 +507,49 @@ class MyPlugin(Star):
---
## 附录:关键文件速查
## 文档与资源
### 完整文档目录
SDK 文档按学习路径组织,位于 `src/astrbot_sdk/docs/`
| 级别 | 文档 | 内容 |
|------|------|------|
| **初级** | README.md | 快速开始、核心概念 |
| | 01_context_api.md | Context API 完整参考 |
| | 02_event_and_components.md | MessageEvent 和消息组件 |
| | 03_decorators.md | 装饰器详细说明 |
| | 04_star_lifecycle.md | 插件基类和生命周期 |
| | 05_clients.md | 客户端 API 文档 |
| **中级** | 06_error_handling.md | 错误处理与调试 |
| | 07_advanced_topics.md | 并发、性能优化、安全 |
| | 08_testing_guide.md | 测试指南 |
| **高级** | 09_api_reference.md | 完整 API 索引 |
| | 10_migration_guide.md | 迁移指南 |
| | 11_security_checklist.md | 安全检查清单 |
| | PROJECT_ARCHITECTURE.md | 架构设计文档 |
### 关键文件速查
| 文件 | 核心类/函数 | 说明 |
|------|------------|------|
| `astrbot_sdk/__init__.py` | `Star`, `Context`, `MessageEvent` | 顶层入口 |
| `astrbot_sdk/star.py` | `Star` | v4 原生插件基类 |
| `astrbot_sdk/context.py` | `Context` | 运行时上下文 |
| `astrbot_sdk/decorators.py` | `on_command`, `on_message`, `provide_capability` | v4 装饰器 |
| `astrbot_sdk/decorators.py` | `on_command`, `on_message` | v4 装饰器 |
| `astrbot_sdk/errors.py` | `AstrBotError` | 统一错误模型 |
| `astrbot_sdk/cli.py` | CLI 命令 | 命令行工具init/validate/build/dev/run/worker/websocket |
| `astrbot_sdk/testing.py` | `PluginHarness`, `MockContext` | 测试辅助 |
| `astrbot_sdk/commands.py` | `CommandGroup`, `command_group` | 命令分组工具 |
| `astrbot_sdk/filters.py` | `PlatformFilter`, `CustomFilter`, `all_of`, `any_of` | 事件过滤器 |
| `astrbot_sdk/message_result.py` | `MessageChain`, `MessageEventResult` | 消息结果对象 |
| `astrbot_sdk/message_session.py` | `MessageSession` | 会话标识符 |
| `astrbot_sdk/schedule.py` | `ScheduleContext` | 定时任务上下文 |
| `astrbot_sdk/session_waiter.py` | `SessionController`, `SessionWaiterManager` | 会话等待器 |
| `astrbot_sdk/types.py` | `GreedyStr` | 参数类型助手 |
| `astrbot_sdk/runtime/__init__.py` | 延迟导出 | 运行时公共 API延迟加载 |
| `astrbot_sdk/runtime/peer.py` | `Peer` | 协议对等端 |
| `astrbot_sdk/runtime/supervisor.py` | `SupervisorRuntime` | Supervisor 运行时 |
| `astrbot_sdk/runtime/worker.py` | `PluginWorkerRuntime` | Worker 运行时 |
| `astrbot_sdk/runtime/loader.py` | `load_plugin()`, `_ResolvedComponent` | 插件加载 |
| `astrbot_sdk/runtime/_loader_support.py` | `build_param_specs`, `is_injected_parameter` | 加载器反射工具 |
| `astrbot_sdk/runtime/_streaming.py` | `StreamExecution` | 流式执行原语 |
| `astrbot_sdk/runtime/handler_dispatcher.py` | `HandlerDispatcher` | Handler 执行分发 |
| `astrbot_sdk/runtime/capability_dispatcher.py` | `CapabilityDispatcher` | Capability 调用分发 |
| `astrbot_sdk/runtime/capability_router.py` | `CapabilityRouter` | Capability 路由 |
| `astrbot_sdk/runtime/_capability_router_builtins.py` | `BuiltinCapabilityRouterMixin` | 内置能力处理器 |
| `astrbot_sdk/runtime/environment_groups.py` | `EnvironmentGroup` | 环境分组 |
| `astrbot_sdk/protocol/messages.py` | `InitializeMessage`, `InvokeMessage` | 协议消息 |
| `astrbot_sdk/protocol/descriptors.py` | `HandlerDescriptor`, `CapabilityDescriptor` | 描述符 |
| `astrbot_sdk/protocol/_builtin_schemas.py` | `BUILTIN_CAPABILITY_SCHEMAS` | 内置能力 JSON Schema |
| `astrbot_sdk/clients/_proxy.py` | `CapabilityProxy` | 能力代理 |
| `astrbot_sdk/clients/llm.py` | `LLMClient` | LLM 客户端 |
| `astrbot_sdk/clients/memory.py` | `MemoryClient` | 记忆客户端 |
| `astrbot_sdk/clients/db.py` | `DBClient` | 数据库客户端 |
| `astrbot_sdk/clients/platform.py` | `PlatformClient` | 平台客户端 |
| `astrbot_sdk/clients/http.py` | `HTTPClient` | HTTP 客户端 |
| `astrbot_sdk/clients/metadata.py` | `MetadataClient`, `PluginMetadata` | 元数据客户端 |
| `astrbot_sdk/message_components.py` | `Plain`, `Image`, `At`, `Reply` | 消息组件 |
| `astrbot_sdk/events.py` | `MessageEvent` | 事件对象 |
| `astrbot_sdk/_testing_support.py` | 测试工具 | 测试支持 |
### 版本信息
- **SDK 版本**: v4.0
- **协议版本**: P0.6
- **Python 要求**: >=3.12
- **推荐版本**: 3.12+
---
> 本文档描述 AstrBot SDK v4 的设计与实现思想
> 如有疑问请查阅源代码或提交 Issue
> 本文档基于 AstrBot SDK v4 架构文档整理
> 详细内容请查阅 `src/astrbot_sdk/docs/` 目录下的完整文档

View File

@@ -1,10 +1,17 @@
from __future__ import annotations
import json
import math
from datetime import datetime, timedelta, timezone
from datetime import datetime, timezone
from typing import Any
from ...._memory_utils import (
cosine_similarity,
extract_memory_text,
is_ttl_memory_entry,
memory_expiration_from_ttl,
memory_index_entry,
memory_keyword_score,
memory_value_for_search,
)
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
@@ -20,7 +27,7 @@ class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
Returns:
bool: 如果值包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
"""
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
return is_ttl_memory_entry(value)
@classmethod
def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None:
@@ -32,12 +39,7 @@ class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
Returns:
dict[str, Any] | None: 解开 TTL 包装后的字典,无法解析时返回 ``None``。
"""
if not isinstance(stored, dict):
return None
if cls._is_ttl_memory_entry(stored):
value = stored.get("value")
return value if isinstance(value, dict) else None
return stored
return memory_value_for_search(stored)
@classmethod
def _extract_memory_text(cls, stored: Any) -> str:
@@ -49,14 +51,7 @@ class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
Returns:
str: 优先使用 ``embedding_text`` / ``content`` 等字段,兜底为 JSON 文本。
"""
value = cls._memory_value_for_search(stored)
if not isinstance(value, dict):
return ""
for field_name in ("embedding_text", "content", "summary", "title", "text"):
item = value.get(field_name)
if isinstance(item, str) and item.strip():
return item.strip()
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
return extract_memory_text(stored)
@staticmethod
def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
@@ -68,13 +63,7 @@ class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
Returns:
datetime | None: 绝对过期时间;当输入无效时返回 ``None``。
"""
try:
ttl = int(ttl_seconds)
except (TypeError, ValueError):
return None
if ttl < 1:
return None
return datetime.now(timezone.utc) + timedelta(seconds=ttl)
return memory_expiration_from_ttl(ttl_seconds)
@staticmethod
def _memory_keyword_score(query: str, key: str, text: str) -> float:
@@ -88,16 +77,7 @@ class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
Returns:
float: 基于键名和文本命中的粗粒度关键词分数。
"""
normalized_query = str(query).casefold()
if not normalized_query:
return 1.0
normalized_key = str(key).casefold()
normalized_text = str(text).casefold()
if normalized_query in normalized_key:
return 1.0
if normalized_query in normalized_text:
return 0.9
return 0.0
return memory_keyword_score(query, key, text)
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
@@ -110,15 +90,7 @@ class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
Returns:
float: 余弦相似度;输入不合法时返回 ``0.0``。
"""
if not left or not right or len(left) != len(right):
return 0.0
left_norm = math.sqrt(sum(value * value for value in left))
right_norm = math.sqrt(sum(value * value for value in right))
if left_norm <= 0 or right_norm <= 0:
return 0.0
return sum(a * b for a, b in zip(left, right, strict=False)) / (
left_norm * right_norm
)
return cosine_similarity(left, right)
def _resolve_memory_embedding_provider_id(
self,
@@ -170,21 +142,7 @@ class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
Returns:
dict[str, Any]: 统一后的索引项,包含 ``text``、``embedding``、``provider_id``。
"""
if isinstance(entry, dict):
return {
"text": str(entry.get("text", text)),
"embedding": (
[float(item) for item in entry.get("embedding", [])]
if isinstance(entry.get("embedding"), list)
else None
),
"provider_id": (
str(entry.get("provider_id")).strip()
if entry.get("provider_id") is not None
else None
),
}
return {"text": text, "embedding": None, "provider_id": None}
return memory_index_entry(entry, text=text)
def _clear_memory_sidecars(self, key: str) -> None:
"""清理指定 memory 键对应的所有 sidecar 状态。