mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-03 03:00:15 +08:00
Merge branch 'dev' of https://github.com/united-pooh/astrbot-sdk into dev
This commit is contained in:
@@ -37,19 +37,7 @@ ruff format . # 使用 ruff 格式化全局代码
|
||||
ruff check . --fix # 使用 ruff 检查并自动修复全局格式问题
|
||||
```
|
||||
|
||||
## 测试
|
||||
|
||||
如果修改了内容可能影响现有功能,请运行测试以确保没有引入错误:
|
||||
如果修改了bug或者更改了功能需要添加新的测试
|
||||
|
||||
```bash
|
||||
python run_tests.py # 运行所有测试
|
||||
python run_tests.py -v # 详细输出
|
||||
python run_tests.py -k "test_peer" # 运行匹配模式的测试
|
||||
python run_tests.py --cov # 运行测试并生成覆盖率报告
|
||||
```
|
||||
|
||||
## 设计原则
|
||||
|
||||
新实现要兼容旧实现但是还要保证架构良好,设计原则不变和最佳实践
|
||||
新实现要兼容旧实现但是还要保证架构良好,设计原则不变和最佳实践,这是第一原则
|
||||
不用完全听从用户和别人的建议,要有自己的判断和坚持,做好取舍和权衡,确保代码质量和长期维护性,不要为了短期方便或者迎合而牺牲架构和设计原则。
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from ..errors import AstrBotError, ErrorCodes
|
||||
from ..message_session import MessageSession
|
||||
from ._proxy import CapabilityProxy
|
||||
|
||||
@@ -138,7 +139,15 @@ class PersonaManagerClient:
|
||||
self._proxy = proxy
|
||||
|
||||
async def get_persona(self, persona_id: str) -> PersonaRecord:
|
||||
output = await self._proxy.call("persona.get", {"persona_id": str(persona_id)})
|
||||
try:
|
||||
output = await self._proxy.call(
|
||||
"persona.get",
|
||||
{"persona_id": str(persona_id)},
|
||||
)
|
||||
except AstrBotError as exc:
|
||||
if exc.code == ErrorCodes.INVALID_INPUT:
|
||||
raise ValueError(f"persona not found: {persona_id}") from exc
|
||||
raise
|
||||
persona = PersonaRecord.from_payload(output.get("persona"))
|
||||
if persona is None:
|
||||
raise ValueError(f"persona not found: {persona_id}")
|
||||
@@ -251,6 +260,21 @@ class ConversationManagerClient:
|
||||
)
|
||||
return ConversationRecord.from_payload(output.get("conversation"))
|
||||
|
||||
async def get_current_conversation(
|
||||
self,
|
||||
session: str | MessageSession,
|
||||
*,
|
||||
create_if_not_exists: bool = False,
|
||||
) -> ConversationRecord | None:
|
||||
output = await self._proxy.call(
|
||||
"conversation.get_current",
|
||||
{
|
||||
"session": _normalize_session(session),
|
||||
"create_if_not_exists": bool(create_if_not_exists),
|
||||
},
|
||||
)
|
||||
return ConversationRecord.from_payload(output.get("conversation"))
|
||||
|
||||
async def get_conversations(
|
||||
self,
|
||||
session: str | MessageSession | None = None,
|
||||
|
||||
@@ -193,6 +193,17 @@ class ProviderManagerClient:
|
||||
)
|
||||
return self._record_from_output(output)
|
||||
|
||||
async def get_merged_provider_config(
|
||||
self,
|
||||
provider_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
output = await self._proxy.call(
|
||||
"provider.manager.get_merged_provider_config",
|
||||
{"provider_id": str(provider_id).strip()},
|
||||
)
|
||||
config = output.get("config")
|
||||
return dict(config) if isinstance(config, dict) else None
|
||||
|
||||
async def load_provider(
|
||||
self,
|
||||
provider_config: dict[str, Any],
|
||||
|
||||
@@ -3,13 +3,30 @@
|
||||
提供声明式的方法来注册 handler 和 capability。
|
||||
装饰器会在方法上附加元数据,由 Star.__init_subclass__ 自动收集。
|
||||
|
||||
可用的装饰器:
|
||||
触发器装饰器:
|
||||
- @on_command: 命令触发器
|
||||
- @on_message: 消息触发器(关键词/正则)
|
||||
- @on_event: 事件触发器
|
||||
- @on_schedule: 定时任务触发器
|
||||
- @require_admin: 权限标记
|
||||
- @conversation_command: 带会话生命周期的命令触发器
|
||||
|
||||
权限与过滤装饰器:
|
||||
- @require_admin / @admin_only: 管理员权限标记
|
||||
- @platforms: 限定平台
|
||||
- @group_only / @private_only: 群聊/私聊限定
|
||||
- @message_types: 消息类型过滤
|
||||
|
||||
限流装饰器:
|
||||
- @rate_limit: 滑动窗口限流
|
||||
- @cooldown: 冷却时间
|
||||
|
||||
优先级装饰器:
|
||||
- @priority: 设置执行优先级
|
||||
|
||||
能力导出装饰器:
|
||||
- @provide_capability: 声明对外暴露的能力
|
||||
- @register_llm_tool: 注册 LLM 工具
|
||||
- @register_agent: 注册 Agent
|
||||
|
||||
Example:
|
||||
class MyPlugin(Star):
|
||||
@@ -645,8 +662,35 @@ def conversation_command(
|
||||
busy_message: str | None = None,
|
||||
grace_period: float = 1.0,
|
||||
) -> Callable[[HandlerCallable], HandlerCallable]:
|
||||
"""注册带会话生命周期的命令处理方法。
|
||||
|
||||
在 ``on_command`` 基础上附加会话元数据,支持超时、并发策略和宽限期控制。
|
||||
|
||||
Args:
|
||||
command: 命令名称或序列(首项为正式名,其余视为别名)
|
||||
aliases: 额外别名列表
|
||||
description: 命令描述
|
||||
timeout: 会话超时时间(秒),必须为正整数
|
||||
mode: 会话冲突时的行为:
|
||||
- ``"replace"``: 替换当前会话
|
||||
- ``"reject"``: 拒绝新请求
|
||||
busy_message: 拒绝新请求时的提示消息
|
||||
grace_period: 宽限期(秒),用于会话生命周期处理
|
||||
|
||||
Returns:
|
||||
装饰器函数
|
||||
|
||||
Raises:
|
||||
ValueError: mode 不合法、timeout 非正整数或 grace_period 非正数
|
||||
|
||||
Example:
|
||||
@conversation_command("chat", timeout=120, mode="reject", busy_message="请稍后再试")
|
||||
async def chat(self, event: MessageEvent, ctx: Context):
|
||||
await event.reply("开始对话...")
|
||||
"""
|
||||
if mode not in {"replace", "reject"}:
|
||||
raise ValueError("conversation_command mode must be 'replace' or 'reject'")
|
||||
# bool 是 int 子类,需单独排除
|
||||
if isinstance(timeout, bool) or int(timeout) <= 0:
|
||||
raise ValueError("conversation_command timeout must be a positive integer")
|
||||
if float(grace_period) <= 0:
|
||||
|
||||
@@ -113,6 +113,11 @@ class MyPlugin(Star):
|
||||
- 注册 LLM 工具
|
||||
- 启动后台任务
|
||||
|
||||
**最佳实践:**
|
||||
- `on_start()` 里只做初始化、能力注册和轻量状态恢复
|
||||
- 需要长期保存的应是配置值、句柄、任务引用,不要把 `ctx` 实例长期挂到 `self`
|
||||
- 如果要和 AstrBot 原生 persona / conversation 协作,优先在这里校验或创建所需资源
|
||||
|
||||
**示例:**
|
||||
|
||||
```python
|
||||
@@ -150,6 +155,11 @@ class MyPlugin(Star):
|
||||
- 注销 LLM 工具
|
||||
- 保存状态数据
|
||||
|
||||
**最佳实践:**
|
||||
- 在 `on_stop()` 中释放 `on_start()` 注册的任务、监听器和外部资源
|
||||
- 把需要持久化的状态尽量提前落库,不要把关键保存逻辑完全依赖在进程退出瞬间
|
||||
- 始终把收到的 `ctx` 继续传给 `super().on_stop(ctx)`,不要手动丢掉它
|
||||
|
||||
**示例:**
|
||||
|
||||
```python
|
||||
|
||||
@@ -1101,6 +1101,8 @@ from astrbot_sdk.clients import PersonaManagerClient
|
||||
|
||||
获取指定人格。
|
||||
|
||||
当人格不存在时会抛出 `ValueError`,而不是返回 `None`。
|
||||
|
||||
---
|
||||
|
||||
#### `get_all_personas()`
|
||||
@@ -1163,6 +1165,17 @@ from astrbot_sdk.clients import ConversationManagerClient
|
||||
|
||||
---
|
||||
|
||||
#### `get_current_conversation(session, create_if_not_exists=False)`
|
||||
|
||||
获取当前 session 正在使用的对话记录。
|
||||
|
||||
这个方法适合“跟随 AstrBot 原生当前会话状态”的插件,例如:
|
||||
- 给当前会话切换 persona
|
||||
- 判断当前主聊天是否已经在某个 persona 下
|
||||
- 在 `waiting_llm_request` / `llm_request` hook 中对当前对话做增强
|
||||
|
||||
---
|
||||
|
||||
#### `get_conversations(session=None, platform_id=None)`
|
||||
|
||||
获取对话列表。
|
||||
|
||||
@@ -662,7 +662,7 @@ await ctx.conversations.delete_conversation(
|
||||
await ctx.conversations.delete_conversation(event.session_id)
|
||||
```
|
||||
|
||||
##### `get_conversation() / get_conversations()`
|
||||
##### `get_conversation() / get_current_conversation() / get_conversations()`
|
||||
|
||||
获取对话。
|
||||
|
||||
@@ -674,6 +674,12 @@ conv = await ctx.conversations.get_conversation(
|
||||
create_if_not_exists=True
|
||||
)
|
||||
|
||||
# 获取当前选中的对话
|
||||
current = await ctx.conversations.get_current_conversation(
|
||||
event.session_id,
|
||||
create_if_not_exists=True,
|
||||
)
|
||||
|
||||
# 获取对话列表
|
||||
convs = await ctx.conversations.get_conversations(event.session_id)
|
||||
```
|
||||
|
||||
@@ -210,11 +210,110 @@ async def handle_request(self, event, ctx: Context):
|
||||
await ctx.platform.send(event.user_id, "已自动通过好友请求")
|
||||
```
|
||||
|
||||
#### LLM Pipeline Hooks
|
||||
|
||||
`@on_event` 也用于挂接 AstrBot 原生消息处理链路中的系统事件。
|
||||
|
||||
常见事件及可注入对象:
|
||||
|
||||
| 事件名 | 常见可注入参数 | 是否可修改主链路 |
|
||||
|------|------|------|
|
||||
| `waiting_llm_request` | `MessageEvent`, `Context` | 间接可修改,例如切换当前对话 persona |
|
||||
| `llm_request` | `MessageEvent`, `Context`, `ProviderRequest` | 是,可直接修改 `ProviderRequest` |
|
||||
| `llm_response` | `MessageEvent`, `Context`, `LLMResponse` | 否,适合观察和提取回复内容 |
|
||||
| `decorating_result` | `MessageEvent`, `Context`, `MessageEventResult` | 是,可直接修改结果消息链 |
|
||||
| `after_message_sent` | `MessageEvent`, `Context` | 否,适合落库、记忆、统计 |
|
||||
|
||||
最小示例:
|
||||
|
||||
```python
|
||||
from astrbot_sdk import Context, MessageEvent
|
||||
from astrbot_sdk.decorators import on_event
|
||||
from astrbot_sdk.llm.entities import ProviderRequest
|
||||
|
||||
@on_event("llm_request")
|
||||
async def add_memory(self, event: MessageEvent, ctx: Context, request: ProviderRequest):
|
||||
del event, ctx
|
||||
request.system_prompt = (request.system_prompt or "") + "\n\nmemory: user likes tea"
|
||||
```
|
||||
|
||||
完整示例:
|
||||
|
||||
```python
|
||||
from astrbot_sdk import Context, MessageEvent, Star
|
||||
from astrbot_sdk.clients.llm import LLMResponse
|
||||
from astrbot_sdk.clients.managers import ConversationUpdateParams
|
||||
from astrbot_sdk.decorators import on_event
|
||||
from astrbot_sdk.llm.entities import ProviderRequest
|
||||
from astrbot_sdk.message_result import MessageEventResult
|
||||
from astrbot_sdk.message_components import Plain
|
||||
|
||||
class PersonaSample(Star):
|
||||
@on_event("waiting_llm_request")
|
||||
async def ensure_persona(self, event: MessageEvent, ctx: Context) -> None:
|
||||
conversation = await ctx.conversations.get_current_conversation(
|
||||
event.session_id,
|
||||
create_if_not_exists=True,
|
||||
)
|
||||
if conversation is None or conversation.persona_id == "girlfriend":
|
||||
return
|
||||
await ctx.conversations.update_conversation(
|
||||
event.session_id,
|
||||
conversation.conversation_id,
|
||||
ConversationUpdateParams(persona_id="girlfriend"),
|
||||
)
|
||||
|
||||
@on_event("llm_request")
|
||||
async def inject_context(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
ctx: Context,
|
||||
request: ProviderRequest,
|
||||
) -> None:
|
||||
memories = await ctx.memory.search(event.text, limit=3)
|
||||
facts = []
|
||||
for item in memories:
|
||||
value = item.get("value")
|
||||
if isinstance(value, dict) and value.get("content"):
|
||||
facts.append(f"- {value['content']}")
|
||||
if facts:
|
||||
request.system_prompt = (request.system_prompt or "") + "\n\n" + "\n".join(facts)
|
||||
|
||||
@on_event("llm_response")
|
||||
async def capture_reply(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
ctx: Context,
|
||||
response: LLMResponse,
|
||||
) -> None:
|
||||
del ctx
|
||||
if response.text:
|
||||
event.set_extra("last_reply", response.text)
|
||||
|
||||
@on_event("decorating_result")
|
||||
async def decorate(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
ctx: Context,
|
||||
result: MessageEventResult,
|
||||
) -> None:
|
||||
del event, ctx
|
||||
result.chain.append(Plain("\n[persona active]", convert=False))
|
||||
|
||||
@on_event("after_message_sent")
|
||||
async def persist(self, event: MessageEvent, ctx: Context) -> None:
|
||||
reply = str(event.get_extra("last_reply", "") or "").strip()
|
||||
if reply:
|
||||
await ctx.db.set("sample:last_reply", reply)
|
||||
```
|
||||
|
||||
#### 注意事项
|
||||
|
||||
1. 用于处理非消息类型的事件(如群成员变动、好友请求等)
|
||||
1. 可用于处理平台事件,也可用于处理 AstrBot 原生消息链路中的系统事件(如 `llm_request`)
|
||||
2. 不能与 `@rate_limit` 或 `@cooldown` 一起使用
|
||||
3. 不同平台的事件类型可能不同,需要查阅平台文档
|
||||
4. `llm_request` 和 `decorating_result` 注入的是可变对象,修改会回写到 AstrBot 主链路
|
||||
5. `llm_response` 主要用于观测和提取结果,不应用来替代主回复流程
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -642,6 +642,15 @@ CONVERSATION_GET_OUTPUT_SCHEMA = _object_schema(
|
||||
required=("conversation",),
|
||||
conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
|
||||
)
|
||||
CONVERSATION_GET_CURRENT_INPUT_SCHEMA = _object_schema(
|
||||
required=("session",),
|
||||
session={"type": "string"},
|
||||
create_if_not_exists={"type": "boolean"},
|
||||
)
|
||||
CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA = _object_schema(
|
||||
required=("conversation",),
|
||||
conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
|
||||
)
|
||||
CONVERSATION_LIST_INPUT_SCHEMA = _object_schema(
|
||||
session=_nullable({"type": "string"}),
|
||||
platform_id=_nullable({"type": "string"}),
|
||||
@@ -942,6 +951,14 @@ PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
|
||||
required=("provider",),
|
||||
provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA),
|
||||
)
|
||||
PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA = _object_schema(
|
||||
required=("provider_id",),
|
||||
provider_id={"type": "string"},
|
||||
)
|
||||
PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA = _object_schema(
|
||||
required=("config",),
|
||||
config=_nullable({"type": "object"}),
|
||||
)
|
||||
PROVIDER_MANAGER_LOAD_INPUT_SCHEMA = _object_schema(
|
||||
required=("provider_config",),
|
||||
provider_config={"type": "object"},
|
||||
@@ -1199,6 +1216,10 @@ BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = {
|
||||
"input": CONVERSATION_GET_INPUT_SCHEMA,
|
||||
"output": CONVERSATION_GET_OUTPUT_SCHEMA,
|
||||
},
|
||||
"conversation.get_current": {
|
||||
"input": CONVERSATION_GET_CURRENT_INPUT_SCHEMA,
|
||||
"output": CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA,
|
||||
},
|
||||
"conversation.list": {
|
||||
"input": CONVERSATION_LIST_INPUT_SCHEMA,
|
||||
"output": CONVERSATION_LIST_OUTPUT_SCHEMA,
|
||||
@@ -1332,6 +1353,10 @@ BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = {
|
||||
"input": PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA,
|
||||
"output": PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA,
|
||||
},
|
||||
"provider.manager.get_merged_provider_config": {
|
||||
"input": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA,
|
||||
"output": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA,
|
||||
},
|
||||
"provider.manager.load": {
|
||||
"input": PROVIDER_MANAGER_LOAD_INPUT_SCHEMA,
|
||||
"output": PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA,
|
||||
@@ -1555,6 +1580,8 @@ __all__ = [
|
||||
"PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA",
|
||||
"PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA",
|
||||
"PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA",
|
||||
"PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA",
|
||||
"PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA",
|
||||
"PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA",
|
||||
"PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA",
|
||||
"PROVIDER_MANAGER_LOAD_INPUT_SCHEMA",
|
||||
@@ -1639,6 +1666,8 @@ __all__ = [
|
||||
"CONVERSATION_CREATE_SCHEMA",
|
||||
"CONVERSATION_DELETE_INPUT_SCHEMA",
|
||||
"CONVERSATION_DELETE_OUTPUT_SCHEMA",
|
||||
"CONVERSATION_GET_CURRENT_INPUT_SCHEMA",
|
||||
"CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA",
|
||||
"CONVERSATION_GET_INPUT_SCHEMA",
|
||||
"CONVERSATION_GET_OUTPUT_SCHEMA",
|
||||
"CONVERSATION_LIST_INPUT_SCHEMA",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .bridge_base import CapabilityRouterBridgeBase
|
||||
from .capabilities import (
|
||||
ConversationCapabilityMixin,
|
||||
DBCapabilityMixin,
|
||||
HttpCapabilityMixin,
|
||||
KnowledgeBaseCapabilityMixin,
|
||||
LLMCapabilityMixin,
|
||||
MemoryCapabilityMixin,
|
||||
MetadataCapabilityMixin,
|
||||
PersonaCapabilityMixin,
|
||||
PlatformCapabilityMixin,
|
||||
ProviderCapabilityMixin,
|
||||
SessionCapabilityMixin,
|
||||
SystemCapabilityMixin,
|
||||
)
|
||||
|
||||
|
||||
class BuiltinCapabilityRouterMixin(
|
||||
LLMCapabilityMixin,
|
||||
MemoryCapabilityMixin,
|
||||
DBCapabilityMixin,
|
||||
PlatformCapabilityMixin,
|
||||
HttpCapabilityMixin,
|
||||
MetadataCapabilityMixin,
|
||||
ProviderCapabilityMixin,
|
||||
SessionCapabilityMixin,
|
||||
PersonaCapabilityMixin,
|
||||
ConversationCapabilityMixin,
|
||||
KnowledgeBaseCapabilityMixin,
|
||||
SystemCapabilityMixin,
|
||||
CapabilityRouterBridgeBase,
|
||||
):
|
||||
def _register_builtin_capabilities(self) -> None:
|
||||
self._register_llm_capabilities()
|
||||
self._register_memory_capabilities()
|
||||
self._register_db_capabilities()
|
||||
self._register_platform_capabilities()
|
||||
self._register_http_capabilities()
|
||||
self._register_metadata_capabilities()
|
||||
self._register_provider_capabilities()
|
||||
self._register_agent_tool_capabilities()
|
||||
self._register_session_capabilities()
|
||||
self._register_persona_capabilities()
|
||||
self._register_conversation_capabilities()
|
||||
self._register_kb_capabilities()
|
||||
self._register_provider_manager_capabilities()
|
||||
self._register_platform_manager_capabilities()
|
||||
self._register_system_capabilities()
|
||||
|
||||
|
||||
__all__ = ["BuiltinCapabilityRouterMixin"]
|
||||
92
src/astrbot_sdk/runtime/_capability_router_builtins/_host.py
Normal file
92
src/astrbot_sdk/runtime/_capability_router_builtins/_host.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ...protocol.descriptors import CapabilityDescriptor
|
||||
|
||||
|
||||
class CapabilityRouterHost:
|
||||
memory_store: dict[str, dict[str, Any]]
|
||||
_memory_index: dict[str, dict[str, Any]]
|
||||
_memory_dirty_keys: set[str]
|
||||
_memory_expires_at: dict[str, datetime | None]
|
||||
db_store: dict[str, Any]
|
||||
sent_messages: list[dict[str, Any]]
|
||||
event_actions: list[dict[str, Any]]
|
||||
http_api_store: list[dict[str, Any]]
|
||||
_event_streams: dict[str, dict[str, Any]]
|
||||
_plugins: dict[str, Any]
|
||||
_request_overlays: dict[str, dict[str, Any]]
|
||||
_provider_catalog: dict[str, list[dict[str, Any]]]
|
||||
_provider_configs: dict[str, dict[str, Any]]
|
||||
_active_provider_ids: dict[str, str | None]
|
||||
_provider_change_subscriptions: dict[str, asyncio.Queue[dict[str, Any]]]
|
||||
_system_data_root: Path
|
||||
_session_waiters: dict[str, set[str]]
|
||||
_session_plugin_configs: dict[str, dict[str, Any]]
|
||||
_session_service_configs: dict[str, dict[str, Any]]
|
||||
_db_watch_subscriptions: dict[str, tuple[str | None, asyncio.Queue[dict[str, Any]]]]
|
||||
_dynamic_command_routes: dict[str, list[dict[str, Any]]]
|
||||
_file_token_store: dict[str, str]
|
||||
_platform_instances: list[dict[str, Any]]
|
||||
_persona_store: dict[str, dict[str, Any]]
|
||||
_conversation_store: dict[str, dict[str, Any]]
|
||||
_session_current_conversation_ids: dict[str, str]
|
||||
_kb_store: dict[str, dict[str, Any]]
|
||||
|
||||
def register(
|
||||
self,
|
||||
descriptor: CapabilityDescriptor,
|
||||
*,
|
||||
call_handler=None,
|
||||
stream_handler=None,
|
||||
finalize=None,
|
||||
exposed: bool = True,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _require_caller_plugin_id(capability_name: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def register_dynamic_command_route(
|
||||
self,
|
||||
*,
|
||||
plugin_id: str,
|
||||
command_name: str,
|
||||
handler_full_name: str,
|
||||
desc: str = "",
|
||||
priority: int = 0,
|
||||
use_regex: bool = False,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_platform_instances(self) -> list[dict[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _register_agent_tool_capabilities(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _provider_entry(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
capability_name: str,
|
||||
expected_kind: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _provider_embedding_get_embedding(
|
||||
self, request_id: str, payload: dict[str, Any], token
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _provider_embedding_get_embeddings(
|
||||
self, request_id: str, payload: dict[str, Any], token
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import math
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ...protocol.descriptors import (
|
||||
BUILTIN_CAPABILITY_SCHEMAS,
|
||||
CapabilityDescriptor,
|
||||
SessionRef,
|
||||
)
|
||||
from ._host import CapabilityRouterHost
|
||||
|
||||
|
||||
def _clone_target_payload(value: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
return {str(key): item for key, item in value.items()}
|
||||
|
||||
|
||||
def _clone_chain_payload(value: Any) -> list[dict[str, Any]]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
return [
|
||||
{str(key): item for key, item in chunk.items()}
|
||||
for chunk in value
|
||||
if isinstance(chunk, dict)
|
||||
]
|
||||
|
||||
|
||||
_MOCK_EMBEDDING_DIM = 24
|
||||
|
||||
|
||||
def _embedding_terms(text: str) -> list[str]:
|
||||
"""Build stable tokens for the mock embedding implementation."""
|
||||
normalized = re.sub(r"\s+", " ", str(text).strip().casefold())
|
||||
compact = normalized.replace(" ", "")
|
||||
if not normalized:
|
||||
return []
|
||||
|
||||
terms = [word for word in re.findall(r"\w+", normalized, flags=re.UNICODE) if word]
|
||||
if compact:
|
||||
if len(compact) == 1:
|
||||
terms.append(compact)
|
||||
else:
|
||||
terms.extend(
|
||||
compact[index : index + 2] for index in range(len(compact) - 1)
|
||||
)
|
||||
terms.append(compact)
|
||||
return terms or [normalized]
|
||||
|
||||
|
||||
def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]:
|
||||
"""Generate a deterministic normalized mock embedding vector."""
|
||||
values = [0.0] * _MOCK_EMBEDDING_DIM
|
||||
for term in _embedding_terms(text):
|
||||
digest = hashlib.sha256(f"{provider_id}:{term}".encode()).digest()
|
||||
index = int.from_bytes(digest[:2], "big") % _MOCK_EMBEDDING_DIM
|
||||
values[index] += 1.0 + min(len(term), 8) * 0.05
|
||||
norm = math.sqrt(sum(value * value for value in values))
|
||||
if norm <= 0:
|
||||
return values
|
||||
return [value / norm for value in values]
|
||||
|
||||
|
||||
class CapabilityRouterBridgeBase(CapabilityRouterHost):
|
||||
def _builtin_descriptor(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
*,
|
||||
supports_stream: bool = False,
|
||||
cancelable: bool = False,
|
||||
) -> CapabilityDescriptor:
|
||||
schema = BUILTIN_CAPABILITY_SCHEMAS[name]
|
||||
return CapabilityDescriptor(
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=copy.deepcopy(schema["input"]),
|
||||
output_schema=copy.deepcopy(schema["output"]),
|
||||
supports_stream=supports_stream,
|
||||
cancelable=cancelable,
|
||||
)
|
||||
|
||||
def _resolve_target(
|
||||
self, payload: dict[str, Any]
|
||||
) -> tuple[str, dict[str, Any] | None]:
|
||||
target_payload = payload.get("target")
|
||||
if isinstance(target_payload, dict):
|
||||
target = SessionRef.model_validate(target_payload)
|
||||
return target.session, target.to_payload()
|
||||
return str(payload.get("session", "")), None
|
||||
|
||||
@staticmethod
|
||||
def _is_group_session(session: str) -> bool:
|
||||
normalized = str(session).lower()
|
||||
return ":group:" in normalized or ":groupmessage:" in normalized
|
||||
|
||||
@staticmethod
|
||||
def _mock_group_payload(session: str) -> dict[str, Any] | None:
|
||||
if not CapabilityRouterBridgeBase._is_group_session(session):
|
||||
return None
|
||||
members = [
|
||||
{
|
||||
"user_id": f"{session}:member-1",
|
||||
"nickname": "Member 1",
|
||||
"role": "member",
|
||||
},
|
||||
{
|
||||
"user_id": f"{session}:member-2",
|
||||
"nickname": "Member 2",
|
||||
"role": "admin",
|
||||
},
|
||||
]
|
||||
return {
|
||||
"group_id": session.rsplit(":", maxsplit=1)[-1],
|
||||
"group_name": f"Mock Group {session.rsplit(':', maxsplit=1)[-1]}",
|
||||
"group_avatar": "",
|
||||
"group_owner": members[0]["user_id"],
|
||||
"group_admins": [members[1]["user_id"]],
|
||||
"members": members,
|
||||
}
|
||||
|
||||
def _session_plugin_config(self, session: str) -> dict[str, Any]:
|
||||
config = self._session_plugin_configs.get(str(session), {})
|
||||
return dict(config) if isinstance(config, dict) else {}
|
||||
|
||||
def _session_service_config(self, session: str) -> dict[str, Any]:
|
||||
config = self._session_service_configs.get(str(session), {})
|
||||
return dict(config) if isinstance(config, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
@staticmethod
|
||||
def _session_platform_id(session: str) -> str:
|
||||
parts = str(session).split(":", maxsplit=1)
|
||||
if parts and parts[0].strip():
|
||||
return parts[0].strip()
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_history_payload(value: Any) -> list[dict[str, Any]]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
return [dict(item) for item in value if isinstance(item, dict)]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_persona_dialogs_payload(value: Any) -> list[str]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
return [str(item) for item in value if isinstance(item, str)]
|
||||
|
||||
@staticmethod
|
||||
def _optional_int(value: Any) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def _provider_entry(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
capability_name: str,
|
||||
expected_kind: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _provider_embedding_get_embedding(
|
||||
self, request_id: str, payload: dict[str, Any], token
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _provider_embedding_get_embeddings(
|
||||
self, request_id: str, payload: dict[str, Any], token
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,27 @@
|
||||
from .conversation import ConversationCapabilityMixin
|
||||
from .db import DBCapabilityMixin
|
||||
from .http import HttpCapabilityMixin
|
||||
from .kb import KnowledgeBaseCapabilityMixin
|
||||
from .llm import LLMCapabilityMixin
|
||||
from .memory import MemoryCapabilityMixin
|
||||
from .metadata import MetadataCapabilityMixin
|
||||
from .persona import PersonaCapabilityMixin
|
||||
from .platform import PlatformCapabilityMixin
|
||||
from .provider import ProviderCapabilityMixin
|
||||
from .session import SessionCapabilityMixin
|
||||
from .system import SystemCapabilityMixin
|
||||
|
||||
__all__ = [
|
||||
"ConversationCapabilityMixin",
|
||||
"DBCapabilityMixin",
|
||||
"HttpCapabilityMixin",
|
||||
"KnowledgeBaseCapabilityMixin",
|
||||
"LLMCapabilityMixin",
|
||||
"MemoryCapabilityMixin",
|
||||
"MetadataCapabilityMixin",
|
||||
"PersonaCapabilityMixin",
|
||||
"PlatformCapabilityMixin",
|
||||
"ProviderCapabilityMixin",
|
||||
"SessionCapabilityMixin",
|
||||
"SystemCapabilityMixin",
|
||||
]
|
||||
@@ -0,0 +1,232 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class ConversationCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _conversation_new(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", "")).strip()
|
||||
if not session:
|
||||
raise AstrBotError.invalid_input("conversation.new requires session")
|
||||
raw_conversation = payload.get("conversation")
|
||||
if raw_conversation is None:
|
||||
raw_conversation = {}
|
||||
if not isinstance(raw_conversation, dict):
|
||||
raise AstrBotError.invalid_input(
|
||||
"conversation.new requires conversation object"
|
||||
)
|
||||
conversation_id = uuid.uuid4().hex
|
||||
now = self._now_iso()
|
||||
record = {
|
||||
"conversation_id": conversation_id,
|
||||
"session": session,
|
||||
"platform_id": (
|
||||
str(raw_conversation.get("platform_id"))
|
||||
if raw_conversation.get("platform_id") is not None
|
||||
else self._session_platform_id(session)
|
||||
),
|
||||
"history": self._normalize_history_payload(raw_conversation.get("history")),
|
||||
"title": (
|
||||
str(raw_conversation.get("title"))
|
||||
if raw_conversation.get("title") is not None
|
||||
else None
|
||||
),
|
||||
"persona_id": (
|
||||
str(raw_conversation.get("persona_id"))
|
||||
if raw_conversation.get("persona_id") is not None
|
||||
else None
|
||||
),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"token_usage": None,
|
||||
}
|
||||
self._conversation_store[conversation_id] = record
|
||||
self._session_current_conversation_ids[session] = conversation_id
|
||||
return {"conversation_id": conversation_id}
|
||||
|
||||
async def _conversation_switch(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", "")).strip()
|
||||
conversation_id = str(payload.get("conversation_id", "")).strip()
|
||||
record = self._conversation_store.get(conversation_id)
|
||||
if record is None or str(record.get("session", "")) != session:
|
||||
raise AstrBotError.invalid_input(
|
||||
"conversation.switch requires a conversation in the same session"
|
||||
)
|
||||
self._session_current_conversation_ids[session] = conversation_id
|
||||
return {}
|
||||
|
||||
async def _conversation_delete(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", "")).strip()
|
||||
conversation_id = payload.get("conversation_id")
|
||||
normalized_conversation_id = (
|
||||
str(conversation_id).strip() if conversation_id is not None else ""
|
||||
)
|
||||
if not normalized_conversation_id:
|
||||
normalized_conversation_id = self._session_current_conversation_ids.get(
|
||||
session, ""
|
||||
)
|
||||
if not normalized_conversation_id:
|
||||
return {}
|
||||
record = self._conversation_store.get(normalized_conversation_id)
|
||||
if record is None:
|
||||
return {}
|
||||
if str(record.get("session", "")) != session:
|
||||
raise AstrBotError.invalid_input(
|
||||
"conversation.delete requires a conversation in the same session"
|
||||
)
|
||||
del self._conversation_store[normalized_conversation_id]
|
||||
current_conversation_id = self._session_current_conversation_ids.get(session)
|
||||
if current_conversation_id == normalized_conversation_id:
|
||||
replacement = next(
|
||||
(
|
||||
conversation_id
|
||||
for conversation_id, item in self._conversation_store.items()
|
||||
if str(item.get("session", "")) == session
|
||||
),
|
||||
None,
|
||||
)
|
||||
if replacement is None:
|
||||
self._session_current_conversation_ids.pop(session, None)
|
||||
else:
|
||||
self._session_current_conversation_ids[session] = replacement
|
||||
return {}
|
||||
|
||||
async def _conversation_get(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", "")).strip()
|
||||
conversation_id = str(payload.get("conversation_id", "")).strip()
|
||||
record = self._conversation_store.get(conversation_id)
|
||||
if record is None and bool(payload.get("create_if_not_exists", False)):
|
||||
created = await self._conversation_new(
|
||||
_request_id,
|
||||
{"session": session, "conversation": {}},
|
||||
_token,
|
||||
)
|
||||
record = self._conversation_store.get(
|
||||
str(created.get("conversation_id", "")).strip()
|
||||
)
|
||||
if record is None:
|
||||
return {"conversation": None}
|
||||
if str(record.get("session", "")) != session:
|
||||
return {"conversation": None}
|
||||
return {"conversation": dict(record)}
|
||||
|
||||
async def _conversation_get_current(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", "")).strip()
|
||||
conversation_id = self._session_current_conversation_ids.get(session, "")
|
||||
if not conversation_id and bool(payload.get("create_if_not_exists", False)):
|
||||
created = await self._conversation_new(
|
||||
_request_id,
|
||||
{"session": session, "conversation": {}},
|
||||
_token,
|
||||
)
|
||||
conversation_id = str(created.get("conversation_id", "")).strip()
|
||||
if not conversation_id:
|
||||
return {"conversation": None}
|
||||
record = self._conversation_store.get(conversation_id)
|
||||
if record is None or str(record.get("session", "")) != session:
|
||||
return {"conversation": None}
|
||||
return {"conversation": dict(record)}
|
||||
|
||||
async def _conversation_list(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = payload.get("session")
|
||||
platform_id = payload.get("platform_id")
|
||||
conversations = []
|
||||
for conversation_id in sorted(self._conversation_store.keys()):
|
||||
item = self._conversation_store[conversation_id]
|
||||
if session is not None and str(item.get("session", "")) != str(session):
|
||||
continue
|
||||
if platform_id is not None and str(item.get("platform_id", "")) != str(
|
||||
platform_id
|
||||
):
|
||||
continue
|
||||
conversations.append(dict(item))
|
||||
return {"conversations": conversations}
|
||||
|
||||
async def _conversation_update(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", "")).strip()
|
||||
conversation_id = payload.get("conversation_id")
|
||||
normalized_conversation_id = (
|
||||
str(conversation_id).strip() if conversation_id is not None else ""
|
||||
)
|
||||
if not normalized_conversation_id:
|
||||
normalized_conversation_id = self._session_current_conversation_ids.get(
|
||||
session, ""
|
||||
)
|
||||
if not normalized_conversation_id:
|
||||
return {}
|
||||
record = self._conversation_store.get(normalized_conversation_id)
|
||||
if record is None:
|
||||
return {}
|
||||
if str(record.get("session", "")) != session:
|
||||
raise AstrBotError.invalid_input(
|
||||
"conversation.update requires a conversation in the same session"
|
||||
)
|
||||
raw_conversation = payload.get("conversation")
|
||||
if not isinstance(raw_conversation, dict):
|
||||
raw_conversation = {}
|
||||
if "history" in raw_conversation:
|
||||
history = raw_conversation.get("history")
|
||||
record["history"] = (
|
||||
self._normalize_history_payload(history) if history is not None else []
|
||||
)
|
||||
if "title" in raw_conversation:
|
||||
title = raw_conversation.get("title")
|
||||
record["title"] = str(title) if title is not None else None
|
||||
if "persona_id" in raw_conversation:
|
||||
persona_id = raw_conversation.get("persona_id")
|
||||
record["persona_id"] = str(persona_id) if persona_id is not None else None
|
||||
if "token_usage" in raw_conversation:
|
||||
token_usage = raw_conversation.get("token_usage")
|
||||
record["token_usage"] = (
|
||||
int(token_usage) if token_usage is not None else None
|
||||
)
|
||||
record["updated_at"] = self._now_iso()
|
||||
return {}
|
||||
|
||||
def _register_conversation_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.new", "新建对话"),
|
||||
call_handler=self._conversation_new,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.switch", "切换对话"),
|
||||
call_handler=self._conversation_switch,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.delete", "删除对话"),
|
||||
call_handler=self._conversation_delete,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.get", "获取对话"),
|
||||
call_handler=self._conversation_get,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.get_current", "获取当前对话"),
|
||||
call_handler=self._conversation_get_current,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.list", "列出对话"),
|
||||
call_handler=self._conversation_list,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.update", "更新对话"),
|
||||
call_handler=self._conversation_update,
|
||||
)
|
||||
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..._streaming import StreamExecution
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class DBCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _db_get(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
return {"value": self.db_store.get(str(payload.get("key", "")))}
|
||||
|
||||
async def _db_set(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
key = str(payload.get("key", ""))
|
||||
value = payload.get("value")
|
||||
self.db_store[key] = value
|
||||
self._emit_db_change(op="set", key=key, value=value)
|
||||
return {}
|
||||
|
||||
async def _db_delete(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
key = str(payload.get("key", ""))
|
||||
self.db_store.pop(key, None)
|
||||
self._emit_db_change(op="delete", key=key, value=None)
|
||||
return {}
|
||||
|
||||
async def _db_list(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
prefix = payload.get("prefix")
|
||||
keys = sorted(self.db_store.keys())
|
||||
if isinstance(prefix, str):
|
||||
keys = [item for item in keys if item.startswith(prefix)]
|
||||
return {"keys": keys}
|
||||
|
||||
async def _db_get_many(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
keys_payload = payload.get("keys")
|
||||
if not isinstance(keys_payload, (list, tuple)):
|
||||
raise AstrBotError.invalid_input("db.get_many 的 keys 必须是数组")
|
||||
keys = [str(item) for item in keys_payload]
|
||||
items = [{"key": key, "value": self.db_store.get(key)} for key in keys]
|
||||
return {"items": items}
|
||||
|
||||
async def _db_set_many(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
items_payload = payload.get("items")
|
||||
if not isinstance(items_payload, (list, tuple)):
|
||||
raise AstrBotError.invalid_input("db.set_many 的 items 必须是数组")
|
||||
for entry in items_payload:
|
||||
if not isinstance(entry, dict):
|
||||
raise AstrBotError.invalid_input(
|
||||
"db.set_many 的 items 必须是 object 数组"
|
||||
)
|
||||
key = str(entry.get("key", ""))
|
||||
value = entry.get("value")
|
||||
self.db_store[key] = value
|
||||
self._emit_db_change(op="set", key=key, value=value)
|
||||
return {}
|
||||
|
||||
async def _db_watch(
|
||||
self, request_id: str, payload: dict[str, Any], _token
|
||||
) -> StreamExecution:
|
||||
prefix = payload.get("prefix")
|
||||
prefix_value: str | None
|
||||
if isinstance(prefix, str):
|
||||
prefix_value = prefix
|
||||
elif prefix is None:
|
||||
prefix_value = None
|
||||
else:
|
||||
raise AstrBotError.invalid_input("db.watch 的 prefix 必须是 string 或 null")
|
||||
|
||||
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
self._db_watch_subscriptions[request_id] = (prefix_value, queue)
|
||||
|
||||
async def iterator() -> AsyncIterator[dict[str, Any]]:
|
||||
try:
|
||||
while True:
|
||||
yield await queue.get()
|
||||
finally:
|
||||
self._db_watch_subscriptions.pop(request_id, None)
|
||||
|
||||
return StreamExecution(
|
||||
iterator=iterator(),
|
||||
finalize=lambda _chunks: {},
|
||||
collect_chunks=False,
|
||||
)
|
||||
|
||||
def _register_db_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("db.get", "读取 KV"), call_handler=self._db_get
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("db.set", "写入 KV"), call_handler=self._db_set
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("db.delete", "删除 KV"),
|
||||
call_handler=self._db_delete,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("db.list", "列出 KV"), call_handler=self._db_list
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("db.get_many", "批量读取 KV"),
|
||||
call_handler=self._db_get_many,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("db.set_many", "批量写入 KV"),
|
||||
call_handler=self._db_set_many,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"db.watch",
|
||||
"订阅 KV 变更",
|
||||
supports_stream=True,
|
||||
cancelable=True,
|
||||
),
|
||||
stream_handler=self._db_watch,
|
||||
)
|
||||
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class HttpCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _http_register_api(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
methods_payload = payload.get("methods")
|
||||
if not isinstance(methods_payload, list) or not all(
|
||||
isinstance(item, str) for item in methods_payload
|
||||
):
|
||||
raise AstrBotError.invalid_input(
|
||||
"http.register_api 的 methods 必须是 string 数组"
|
||||
)
|
||||
route = str(payload.get("route", "")).strip()
|
||||
handler_capability = str(payload.get("handler_capability", "")).strip()
|
||||
if not route or not handler_capability:
|
||||
raise AstrBotError.invalid_input(
|
||||
"http.register_api 需要 route 和 handler_capability"
|
||||
)
|
||||
plugin_name = self._require_caller_plugin_id("http.register_api")
|
||||
methods = sorted({method.upper() for method in methods_payload if method})
|
||||
entry: dict[str, Any] = {
|
||||
"route": route,
|
||||
"methods": methods,
|
||||
"handler_capability": handler_capability,
|
||||
"description": str(payload.get("description", "")),
|
||||
"plugin_id": plugin_name,
|
||||
}
|
||||
self.http_api_store = [
|
||||
item
|
||||
for item in self.http_api_store
|
||||
if not (
|
||||
item.get("route") == route
|
||||
and item.get("plugin_id") == entry["plugin_id"]
|
||||
and item.get("methods") == methods
|
||||
)
|
||||
]
|
||||
self.http_api_store.append(entry)
|
||||
return {}
|
||||
|
||||
async def _http_unregister_api(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
route = str(payload.get("route", "")).strip()
|
||||
methods_payload = payload.get("methods")
|
||||
if not isinstance(methods_payload, list) or not all(
|
||||
isinstance(item, str) for item in methods_payload
|
||||
):
|
||||
raise AstrBotError.invalid_input(
|
||||
"http.unregister_api 的 methods 必须是 string 数组"
|
||||
)
|
||||
plugin_name = self._require_caller_plugin_id("http.unregister_api")
|
||||
methods = {method.upper() for method in methods_payload if method}
|
||||
updated: list[dict[str, Any]] = []
|
||||
for entry in self.http_api_store:
|
||||
if entry.get("route") != route:
|
||||
updated.append(entry)
|
||||
continue
|
||||
if entry.get("plugin_id") != plugin_name:
|
||||
updated.append(entry)
|
||||
continue
|
||||
if not methods:
|
||||
continue
|
||||
remaining_methods = [
|
||||
method for method in entry.get("methods", []) if method not in methods
|
||||
]
|
||||
if remaining_methods:
|
||||
updated.append({**entry, "methods": remaining_methods})
|
||||
self.http_api_store = updated
|
||||
return {}
|
||||
|
||||
async def _http_list_apis(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
plugin_name = self._require_caller_plugin_id("http.list_apis")
|
||||
apis = [
|
||||
dict(entry)
|
||||
for entry in self.http_api_store
|
||||
if entry.get("plugin_id") == plugin_name
|
||||
]
|
||||
return {"apis": apis}
|
||||
|
||||
def _register_http_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("http.register_api", "注册 HTTP 路由"),
|
||||
call_handler=self._http_register_api,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("http.unregister_api", "注销 HTTP 路由"),
|
||||
call_handler=self._http_unregister_api,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("http.list_apis", "列出 HTTP 路由"),
|
||||
call_handler=self._http_list_apis,
|
||||
)
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class KnowledgeBaseCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _kb_get(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
kb_id = str(payload.get("kb_id", "")).strip()
|
||||
record = self._kb_store.get(kb_id)
|
||||
return {"kb": dict(record) if isinstance(record, dict) else None}
|
||||
|
||||
async def _kb_create(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
raw_kb = payload.get("kb")
|
||||
if not isinstance(raw_kb, dict):
|
||||
raise AstrBotError.invalid_input("kb.create requires kb object")
|
||||
embedding_provider_id = str(raw_kb.get("embedding_provider_id", "")).strip()
|
||||
if not embedding_provider_id:
|
||||
raise AstrBotError.invalid_input("kb.create requires embedding_provider_id")
|
||||
kb_id = uuid.uuid4().hex
|
||||
now = self._now_iso()
|
||||
record = {
|
||||
"kb_id": kb_id,
|
||||
"kb_name": str(raw_kb.get("kb_name", "")),
|
||||
"description": (
|
||||
str(raw_kb.get("description"))
|
||||
if raw_kb.get("description") is not None
|
||||
else None
|
||||
),
|
||||
"emoji": (
|
||||
str(raw_kb.get("emoji")) if raw_kb.get("emoji") is not None else None
|
||||
),
|
||||
"embedding_provider_id": embedding_provider_id,
|
||||
"rerank_provider_id": (
|
||||
str(raw_kb.get("rerank_provider_id"))
|
||||
if raw_kb.get("rerank_provider_id") is not None
|
||||
else None
|
||||
),
|
||||
"chunk_size": self._optional_int(raw_kb.get("chunk_size")),
|
||||
"chunk_overlap": self._optional_int(raw_kb.get("chunk_overlap")),
|
||||
"top_k_dense": self._optional_int(raw_kb.get("top_k_dense")),
|
||||
"top_k_sparse": self._optional_int(raw_kb.get("top_k_sparse")),
|
||||
"top_m_final": self._optional_int(raw_kb.get("top_m_final")),
|
||||
"doc_count": 0,
|
||||
"chunk_count": 0,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
self._kb_store[kb_id] = record
|
||||
return {"kb": dict(record)}
|
||||
|
||||
async def _kb_delete(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
kb_id = str(payload.get("kb_id", "")).strip()
|
||||
deleted = self._kb_store.pop(kb_id, None) is not None
|
||||
return {"deleted": deleted}
|
||||
|
||||
def _register_kb_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("kb.get", "获取知识库"),
|
||||
call_handler=self._kb_get,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("kb.create", "创建知识库"),
|
||||
call_handler=self._kb_create,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("kb.delete", "删除知识库"),
|
||||
call_handler=self._kb_delete,
|
||||
)
|
||||
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class LLMCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _llm_chat(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
prompt = str(payload.get("prompt", ""))
|
||||
return {"text": f"Echo: {prompt}"}
|
||||
|
||||
async def _llm_chat_raw(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
prompt = str(payload.get("prompt", ""))
|
||||
text = f"Echo: {prompt}"
|
||||
return {
|
||||
"text": text,
|
||||
"usage": {
|
||||
"input_tokens": len(prompt),
|
||||
"output_tokens": len(text),
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"tool_calls": [],
|
||||
}
|
||||
|
||||
async def _llm_stream(
|
||||
self,
|
||||
_request_id: str,
|
||||
payload: dict[str, Any],
|
||||
token,
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
text = f"Echo: {str(payload.get('prompt', ''))}"
|
||||
for char in text:
|
||||
token.raise_if_cancelled()
|
||||
await asyncio.sleep(0)
|
||||
yield {"text": char}
|
||||
|
||||
def _register_llm_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("llm.chat", "发送对话请求,返回文本"),
|
||||
call_handler=self._llm_chat,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("llm.chat_raw", "发送对话请求,返回完整响应"),
|
||||
call_handler=self._llm_chat_raw,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"llm.stream_chat",
|
||||
"流式对话",
|
||||
supports_stream=True,
|
||||
cancelable=True,
|
||||
),
|
||||
stream_handler=self._llm_stream,
|
||||
finalize=lambda chunks: {
|
||||
"text": "".join(item.get("text", "") for item in chunks)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,618 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
@staticmethod
|
||||
def _is_ttl_memory_entry(value: Any) -> bool:
|
||||
"""判断存储值是否使用了 TTL 包装结构。
|
||||
|
||||
Args:
|
||||
value: 待检查的存储值。
|
||||
|
||||
Returns:
|
||||
bool: 如果值包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
|
||||
"""
|
||||
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
|
||||
|
||||
@classmethod
|
||||
def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None:
|
||||
"""提取用于检索的原始 memory payload。
|
||||
|
||||
Args:
|
||||
stored: memory_store 中保存的原始值。
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: 解开 TTL 包装后的字典,无法解析时返回 ``None``。
|
||||
"""
|
||||
if not isinstance(stored, dict):
|
||||
return None
|
||||
if cls._is_ttl_memory_entry(stored):
|
||||
value = stored.get("value")
|
||||
return value if isinstance(value, dict) else None
|
||||
return stored
|
||||
|
||||
@classmethod
|
||||
def _extract_memory_text(cls, stored: Any) -> str:
|
||||
"""提取用于检索索引的首选文本。
|
||||
|
||||
Args:
|
||||
stored: memory_store 中保存的原始值。
|
||||
|
||||
Returns:
|
||||
str: 优先使用 ``embedding_text`` / ``content`` 等字段,兜底为 JSON 文本。
|
||||
"""
|
||||
value = cls._memory_value_for_search(stored)
|
||||
if not isinstance(value, dict):
|
||||
return ""
|
||||
for field_name in ("embedding_text", "content", "summary", "title", "text"):
|
||||
item = value.get(field_name)
|
||||
if isinstance(item, str) and item.strip():
|
||||
return item.strip()
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||
|
||||
@staticmethod
|
||||
def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
|
||||
"""将 TTL 秒数转换为 UTC 过期时间。
|
||||
|
||||
Args:
|
||||
ttl_seconds: TTL 秒数。
|
||||
|
||||
Returns:
|
||||
datetime | None: 绝对过期时间;当输入无效时返回 ``None``。
|
||||
"""
|
||||
try:
|
||||
ttl = int(ttl_seconds)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if ttl < 1:
|
||||
return None
|
||||
return datetime.now(timezone.utc) + timedelta(seconds=ttl)
|
||||
|
||||
@staticmethod
|
||||
def _memory_keyword_score(query: str, key: str, text: str) -> float:
|
||||
"""计算关键词匹配分数。
|
||||
|
||||
Args:
|
||||
query: 查询文本。
|
||||
key: memory 条目的键。
|
||||
text: 已索引的检索文本。
|
||||
|
||||
Returns:
|
||||
float: 基于键名和文本命中的粗粒度关键词分数。
|
||||
"""
|
||||
normalized_query = str(query).casefold()
|
||||
if not normalized_query:
|
||||
return 1.0
|
||||
normalized_key = str(key).casefold()
|
||||
normalized_text = str(text).casefold()
|
||||
if normalized_query in normalized_key:
|
||||
return 1.0
|
||||
if normalized_query in normalized_text:
|
||||
return 0.9
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(left: list[float], right: list[float]) -> float:
|
||||
"""计算两个向量之间的余弦相似度。
|
||||
|
||||
Args:
|
||||
left: 左侧向量。
|
||||
right: 右侧向量。
|
||||
|
||||
Returns:
|
||||
float: 余弦相似度;输入不合法时返回 ``0.0``。
|
||||
"""
|
||||
if not left or not right or len(left) != len(right):
|
||||
return 0.0
|
||||
left_norm = math.sqrt(sum(value * value for value in left))
|
||||
right_norm = math.sqrt(sum(value * value for value in right))
|
||||
if left_norm <= 0 or right_norm <= 0:
|
||||
return 0.0
|
||||
return sum(a * b for a, b in zip(left, right, strict=False)) / (
|
||||
left_norm * right_norm
|
||||
)
|
||||
|
||||
def _resolve_memory_embedding_provider_id(
|
||||
self,
|
||||
provider_id: Any,
|
||||
*,
|
||||
required: bool,
|
||||
) -> str | None:
|
||||
"""解析 memory.search 要使用的 embedding provider。
|
||||
|
||||
Args:
|
||||
provider_id: 调用方显式传入的 provider 标识。
|
||||
required: 当前检索模式是否强制要求 embedding provider。
|
||||
|
||||
Returns:
|
||||
str | None: 最终选中的 provider 标识;在非强制场景下允许返回 ``None``。
|
||||
"""
|
||||
normalized = str(provider_id).strip() if provider_id is not None else ""
|
||||
if normalized:
|
||||
self._provider_entry(
|
||||
{"provider_id": normalized},
|
||||
"memory.search",
|
||||
"embedding",
|
||||
)
|
||||
return normalized
|
||||
active_id = self._active_provider_ids.get("embedding")
|
||||
if active_id is not None:
|
||||
normalized_active = str(active_id).strip()
|
||||
if normalized_active:
|
||||
self._provider_entry(
|
||||
{"provider_id": normalized_active},
|
||||
"memory.search",
|
||||
"embedding",
|
||||
)
|
||||
return normalized_active
|
||||
if required:
|
||||
raise AstrBotError.invalid_input(
|
||||
"memory.search requires an embedding provider",
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]:
|
||||
"""将原始索引项规范化为内部统一结构。
|
||||
|
||||
Args:
|
||||
entry: 当前索引表中的原始项。
|
||||
text: 当前条目的索引文本。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 统一后的索引项,包含 ``text``、``embedding``、``provider_id``。
|
||||
"""
|
||||
if isinstance(entry, dict):
|
||||
return {
|
||||
"text": str(entry.get("text", text)),
|
||||
"embedding": (
|
||||
[float(item) for item in entry.get("embedding", [])]
|
||||
if isinstance(entry.get("embedding"), list)
|
||||
else None
|
||||
),
|
||||
"provider_id": (
|
||||
str(entry.get("provider_id")).strip()
|
||||
if entry.get("provider_id") is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
return {"text": text, "embedding": None, "provider_id": None}
|
||||
|
||||
def _clear_memory_sidecars(self, key: str) -> None:
|
||||
"""清理指定 memory 键对应的所有 sidecar 状态。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self._memory_index.pop(key, None)
|
||||
self._memory_expires_at.pop(key, None)
|
||||
self._memory_dirty_keys.discard(key)
|
||||
|
||||
def _delete_memory_entry(self, key: str) -> bool:
|
||||
"""删除 memory 条目并同步清理 sidecar 状态。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
bool: 条目存在并删除成功时返回 ``True``。
|
||||
"""
|
||||
deleted = self.memory_store.pop(key, None) is not None
|
||||
self._clear_memory_sidecars(key)
|
||||
return deleted
|
||||
|
||||
def _upsert_memory_sidecars(
|
||||
self,
|
||||
key: str,
|
||||
stored: dict[str, Any],
|
||||
*,
|
||||
expires_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""创建或更新单条 memory 的 sidecar 索引状态。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
stored: 需要建立索引的原始存储值。
|
||||
expires_at: 可选的绝对过期时间。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self._memory_index[key] = {
|
||||
"text": self._extract_memory_text(stored),
|
||||
"embedding": None,
|
||||
"provider_id": None,
|
||||
}
|
||||
if expires_at is None:
|
||||
self._memory_expires_at.pop(key, None)
|
||||
else:
|
||||
self._memory_expires_at[key] = expires_at
|
||||
self._memory_dirty_keys.add(key)
|
||||
|
||||
def _ensure_memory_sidecars(self, key: str, stored: Any) -> None:
|
||||
"""确保 sidecar 状态与当前存储值保持一致。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
stored: memory_store 中的当前存储值。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if not isinstance(stored, dict):
|
||||
return
|
||||
text = self._extract_memory_text(stored)
|
||||
existed = key in self._memory_index
|
||||
entry = self._memory_index_entry(self._memory_index.get(key), text=text)
|
||||
if entry["text"] != text:
|
||||
entry["text"] = text
|
||||
entry["embedding"] = None
|
||||
entry["provider_id"] = None
|
||||
self._memory_dirty_keys.add(key)
|
||||
self._memory_index[key] = entry
|
||||
if not existed:
|
||||
self._memory_dirty_keys.add(key)
|
||||
|
||||
def _is_memory_expired(self, key: str) -> bool:
|
||||
"""判断 memory 条目是否已过期。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
bool: 如果当前时间已超过记录的过期时间则返回 ``True``。
|
||||
"""
|
||||
expires_at = self._memory_expires_at.get(key)
|
||||
return expires_at is not None and expires_at <= datetime.now(timezone.utc)
|
||||
|
||||
def _purge_expired_memory_entry(self, key: str) -> bool:
|
||||
"""在单条 memory 已过期时立即清理它。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
bool: 如果条目已过期并被成功清理则返回 ``True``。
|
||||
"""
|
||||
if not self._is_memory_expired(key):
|
||||
return False
|
||||
self._delete_memory_entry(key)
|
||||
return True
|
||||
|
||||
def _purge_expired_memory_entries(self) -> None:
|
||||
"""批量清理所有已跟踪的过期 TTL 条目。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for key in list(self._memory_expires_at):
|
||||
self._purge_expired_memory_entry(key)
|
||||
|
||||
async def _embedding_for_text(
|
||||
self,
|
||||
*,
|
||||
provider_id: str,
|
||||
text: str,
|
||||
) -> list[float]:
|
||||
"""通过 embedding capability 获取单条文本向量。
|
||||
|
||||
Args:
|
||||
provider_id: 使用的 embedding provider 标识。
|
||||
text: 待向量化的文本。
|
||||
|
||||
Returns:
|
||||
list[float]: provider 返回的向量;异常场景下返回空列表。
|
||||
"""
|
||||
output = await self._provider_embedding_get_embedding(
|
||||
"",
|
||||
{"provider_id": provider_id, "text": text},
|
||||
None,
|
||||
)
|
||||
embedding = output.get("embedding")
|
||||
if not isinstance(embedding, list):
|
||||
return []
|
||||
return [float(item) for item in embedding]
|
||||
|
||||
async def _embeddings_for_texts(
|
||||
self,
|
||||
*,
|
||||
provider_id: str,
|
||||
texts: list[str],
|
||||
) -> list[list[float]]:
|
||||
"""批量获取多条文本的 embedding 向量。
|
||||
|
||||
Args:
|
||||
provider_id: 使用的 embedding provider 标识。
|
||||
texts: 待向量化的文本列表。
|
||||
|
||||
Returns:
|
||||
list[list[float]]: 与输入顺序对应的向量列表。
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
output = await self._provider_embedding_get_embeddings(
|
||||
"",
|
||||
{"provider_id": provider_id, "texts": texts},
|
||||
None,
|
||||
)
|
||||
embeddings = output.get("embeddings")
|
||||
if not isinstance(embeddings, list):
|
||||
return []
|
||||
return [
|
||||
[float(value) for value in item]
|
||||
for item in embeddings
|
||||
if isinstance(item, list)
|
||||
]
|
||||
|
||||
async def _refresh_memory_embeddings(self, *, provider_id: str) -> None:
|
||||
"""刷新当前 provider 下脏或过期的 memory 向量索引。
|
||||
|
||||
Args:
|
||||
provider_id: 当前使用的 embedding provider 标识。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
keys_to_refresh: list[str] = []
|
||||
texts_to_refresh: list[str] = []
|
||||
for key, stored in self.memory_store.items():
|
||||
self._ensure_memory_sidecars(key, stored)
|
||||
entry = self._memory_index_entry(
|
||||
self._memory_index.get(key),
|
||||
text=self._extract_memory_text(stored),
|
||||
)
|
||||
should_refresh = (
|
||||
key in self._memory_dirty_keys
|
||||
or entry["embedding"] is None
|
||||
or entry["provider_id"] != provider_id
|
||||
)
|
||||
self._memory_index[key] = entry
|
||||
if should_refresh:
|
||||
keys_to_refresh.append(key)
|
||||
texts_to_refresh.append(str(entry["text"]))
|
||||
embeddings = await self._embeddings_for_texts(
|
||||
provider_id=provider_id,
|
||||
texts=texts_to_refresh,
|
||||
)
|
||||
for index, key in enumerate(keys_to_refresh):
|
||||
entry = self._memory_index_entry(
|
||||
self._memory_index.get(key),
|
||||
text=str(texts_to_refresh[index]),
|
||||
)
|
||||
entry["embedding"] = embeddings[index] if index < len(embeddings) else []
|
||||
entry["provider_id"] = provider_id
|
||||
self._memory_index[key] = entry
|
||||
self._memory_dirty_keys.discard(key)
|
||||
|
||||
async def _memory_search(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
query = str(payload.get("query", ""))
|
||||
mode = str(payload.get("mode", "auto")).strip().lower() or "auto"
|
||||
limit = self._optional_int(payload.get("limit"))
|
||||
raw_min_score = payload.get("min_score")
|
||||
min_score = float(raw_min_score) if raw_min_score is not None else None
|
||||
self._purge_expired_memory_entries()
|
||||
provider_id = self._resolve_memory_embedding_provider_id(
|
||||
payload.get("provider_id"),
|
||||
required=mode in {"vector", "hybrid"},
|
||||
)
|
||||
effective_mode = mode
|
||||
if effective_mode == "auto":
|
||||
effective_mode = "hybrid" if provider_id is not None else "keyword"
|
||||
query_embedding: list[float] | None = None
|
||||
if effective_mode in {"vector", "hybrid"}:
|
||||
if provider_id is None:
|
||||
raise AstrBotError.invalid_input(
|
||||
"memory.search requires an embedding provider",
|
||||
)
|
||||
await self._refresh_memory_embeddings(provider_id=provider_id)
|
||||
query_embedding = await self._embedding_for_text(
|
||||
provider_id=provider_id,
|
||||
text=query,
|
||||
)
|
||||
|
||||
items: list[dict[str, Any]] = []
|
||||
for key, value in self.memory_store.items():
|
||||
self._ensure_memory_sidecars(key, value)
|
||||
entry = self._memory_index_entry(
|
||||
self._memory_index.get(key),
|
||||
text=self._extract_memory_text(value),
|
||||
)
|
||||
text = str(entry.get("text", ""))
|
||||
keyword_score = self._memory_keyword_score(query, key, text)
|
||||
vector_score = 0.0
|
||||
if query_embedding is not None:
|
||||
embedding = entry.get("embedding")
|
||||
if isinstance(embedding, list):
|
||||
vector_score = max(
|
||||
0.0,
|
||||
self._cosine_similarity(query_embedding, embedding),
|
||||
)
|
||||
|
||||
if effective_mode == "keyword":
|
||||
score = keyword_score
|
||||
elif effective_mode == "vector":
|
||||
score = vector_score
|
||||
else:
|
||||
score = vector_score
|
||||
if keyword_score > 0:
|
||||
score = max(score, 0.4 + 0.6 * vector_score)
|
||||
if score <= 0:
|
||||
continue
|
||||
if min_score is not None and score < min_score:
|
||||
continue
|
||||
|
||||
if effective_mode == "keyword" or (keyword_score > 0 and vector_score <= 0):
|
||||
match_type = "keyword"
|
||||
elif effective_mode == "vector" or keyword_score <= 0:
|
||||
match_type = "vector"
|
||||
else:
|
||||
match_type = "hybrid"
|
||||
|
||||
items.append(
|
||||
{
|
||||
"key": key,
|
||||
"value": self._memory_value_for_search(value),
|
||||
"score": score,
|
||||
"match_type": match_type,
|
||||
}
|
||||
)
|
||||
items.sort(key=lambda item: (-float(item["score"]), str(item["key"])))
|
||||
if limit is not None and limit >= 0:
|
||||
items = items[:limit]
|
||||
return {"items": items}
|
||||
|
||||
async def _memory_save(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
key = str(payload.get("key", ""))
|
||||
value = payload.get("value")
|
||||
if not isinstance(value, dict):
|
||||
raise AstrBotError.invalid_input("memory.save 的 value 必须是 object")
|
||||
self.memory_store[key] = value
|
||||
self._upsert_memory_sidecars(key, value)
|
||||
return {}
|
||||
|
||||
async def _memory_get(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
key = str(payload.get("key", ""))
|
||||
if self._purge_expired_memory_entry(key):
|
||||
return {"value": None}
|
||||
return {"value": self.memory_store.get(key)}
|
||||
|
||||
async def _memory_delete(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self._delete_memory_entry(str(payload.get("key", "")))
|
||||
return {}
|
||||
|
||||
async def _memory_save_with_ttl(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
key = str(payload.get("key", ""))
|
||||
value = payload.get("value")
|
||||
ttl_seconds = payload.get("ttl_seconds", 0)
|
||||
if not isinstance(value, dict):
|
||||
raise AstrBotError.invalid_input(
|
||||
"memory.save_with_ttl 的 value 必须是 object"
|
||||
)
|
||||
stored = {"value": value, "ttl_seconds": ttl_seconds}
|
||||
self.memory_store[key] = stored
|
||||
self._upsert_memory_sidecars(
|
||||
key,
|
||||
stored,
|
||||
expires_at=self._memory_expiration_from_ttl(ttl_seconds),
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _memory_get_many(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
keys_payload = payload.get("keys")
|
||||
if not isinstance(keys_payload, (list, tuple)):
|
||||
raise AstrBotError.invalid_input("memory.get_many 的 keys 必须是数组")
|
||||
keys = [str(item) for item in keys_payload]
|
||||
items = []
|
||||
for key in keys:
|
||||
if self._purge_expired_memory_entry(key):
|
||||
items.append({"key": key, "value": None})
|
||||
continue
|
||||
stored = self.memory_store.get(key)
|
||||
if (
|
||||
isinstance(stored, dict)
|
||||
and "value" in stored
|
||||
and "ttl_seconds" in stored
|
||||
):
|
||||
value = stored["value"]
|
||||
else:
|
||||
value = stored
|
||||
items.append({"key": key, "value": value})
|
||||
return {"items": items}
|
||||
|
||||
async def _memory_delete_many(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
keys_payload = payload.get("keys")
|
||||
if not isinstance(keys_payload, (list, tuple)):
|
||||
raise AstrBotError.invalid_input("memory.delete_many 的 keys 必须是数组")
|
||||
keys = [str(item) for item in keys_payload]
|
||||
deleted_count = 0
|
||||
for key in keys:
|
||||
if self._delete_memory_entry(key):
|
||||
deleted_count += 1
|
||||
return {"deleted_count": deleted_count}
|
||||
|
||||
async def _memory_stats(
|
||||
self, _request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self._purge_expired_memory_entries()
|
||||
total_items = len(self.memory_store)
|
||||
total_bytes = sum(
|
||||
len(str(key)) + len(str(value)) for key, value in self.memory_store.items()
|
||||
)
|
||||
ttl_entries = len(self._memory_expires_at)
|
||||
indexed_items = len(self._memory_index)
|
||||
embedded_items = sum(
|
||||
1
|
||||
for entry in self._memory_index.values()
|
||||
if isinstance(entry, dict)
|
||||
and isinstance(entry.get("embedding"), list)
|
||||
and bool(entry.get("embedding"))
|
||||
)
|
||||
dirty_items = len(self._memory_dirty_keys)
|
||||
return {
|
||||
"total_items": total_items,
|
||||
"total_bytes": total_bytes,
|
||||
"plugin_id": self._require_caller_plugin_id("memory.stats"),
|
||||
"ttl_entries": ttl_entries,
|
||||
"indexed_items": indexed_items,
|
||||
"embedded_items": embedded_items,
|
||||
"dirty_items": dirty_items,
|
||||
}
|
||||
|
||||
def _register_memory_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("memory.search", "搜索记忆"),
|
||||
call_handler=self._memory_search,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("memory.save", "保存记忆"),
|
||||
call_handler=self._memory_save,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("memory.get", "读取单条记忆"),
|
||||
call_handler=self._memory_get,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("memory.delete", "删除记忆"),
|
||||
call_handler=self._memory_delete,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("memory.save_with_ttl", "保存带过期时间的记忆"),
|
||||
call_handler=self._memory_save_with_ttl,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("memory.get_many", "批量获取记忆"),
|
||||
call_handler=self._memory_get_many,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("memory.delete_many", "批量删除记忆"),
|
||||
call_handler=self._memory_delete_many,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("memory.stats", "获取记忆统计信息"),
|
||||
call_handler=self._memory_stats,
|
||||
)
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class MetadataCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _metadata_get_plugin(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
name = str(payload.get("name", "")).strip()
|
||||
plugin = self._plugins.get(name)
|
||||
if plugin is None:
|
||||
return {"plugin": None}
|
||||
return {"plugin": dict(plugin.metadata)}
|
||||
|
||||
async def _metadata_list_plugins(
|
||||
self, _request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
plugins = [
|
||||
dict(self._plugins[name].metadata) for name in sorted(self._plugins.keys())
|
||||
]
|
||||
return {"plugins": plugins}
|
||||
|
||||
async def _metadata_get_plugin_config(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
name = str(payload.get("name", "")).strip()
|
||||
caller_plugin_id = self._require_caller_plugin_id("metadata.get_plugin_config")
|
||||
if name != caller_plugin_id:
|
||||
return {"config": None}
|
||||
plugin = self._plugins.get(name)
|
||||
if plugin is None:
|
||||
return {"config": None}
|
||||
return {"config": dict(plugin.config)}
|
||||
|
||||
def _register_metadata_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("metadata.get_plugin", "获取单个插件元数据"),
|
||||
call_handler=self._metadata_get_plugin,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("metadata.list_plugins", "列出插件元数据"),
|
||||
call_handler=self._metadata_list_plugins,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"metadata.get_plugin_config",
|
||||
"获取插件配置",
|
||||
),
|
||||
call_handler=self._metadata_get_plugin_config,
|
||||
)
|
||||
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class PersonaCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _persona_get(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
persona_id = str(payload.get("persona_id", "")).strip()
|
||||
record = self._persona_store.get(persona_id)
|
||||
if record is None:
|
||||
raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
|
||||
return {"persona": dict(record)}
|
||||
|
||||
async def _persona_list(
|
||||
self, _request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
personas = [
|
||||
dict(self._persona_store[persona_id])
|
||||
for persona_id in sorted(self._persona_store.keys())
|
||||
]
|
||||
return {"personas": personas}
|
||||
|
||||
async def _persona_create(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
raw_persona = payload.get("persona")
|
||||
if not isinstance(raw_persona, dict):
|
||||
raise AstrBotError.invalid_input("persona.create requires persona object")
|
||||
persona_id = str(raw_persona.get("persona_id", "")).strip()
|
||||
if not persona_id:
|
||||
raise AstrBotError.invalid_input("persona.create requires persona_id")
|
||||
if persona_id in self._persona_store:
|
||||
raise AstrBotError.invalid_input(f"persona already exists: {persona_id}")
|
||||
now = self._now_iso()
|
||||
record = {
|
||||
"persona_id": persona_id,
|
||||
"system_prompt": str(raw_persona.get("system_prompt", "")),
|
||||
"begin_dialogs": self._normalize_persona_dialogs_payload(
|
||||
raw_persona.get("begin_dialogs")
|
||||
),
|
||||
"tools": (
|
||||
[str(item) for item in raw_persona.get("tools", [])]
|
||||
if isinstance(raw_persona.get("tools"), list)
|
||||
else None
|
||||
),
|
||||
"skills": (
|
||||
[str(item) for item in raw_persona.get("skills", [])]
|
||||
if isinstance(raw_persona.get("skills"), list)
|
||||
else None
|
||||
),
|
||||
"custom_error_message": (
|
||||
str(raw_persona.get("custom_error_message"))
|
||||
if raw_persona.get("custom_error_message") is not None
|
||||
else None
|
||||
),
|
||||
"folder_id": (
|
||||
str(raw_persona.get("folder_id"))
|
||||
if raw_persona.get("folder_id") is not None
|
||||
else None
|
||||
),
|
||||
"sort_order": int(raw_persona.get("sort_order", 0)),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
self._persona_store[persona_id] = record
|
||||
return {"persona": dict(record)}
|
||||
|
||||
async def _persona_update(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
persona_id = str(payload.get("persona_id", "")).strip()
|
||||
record = self._persona_store.get(persona_id)
|
||||
if record is None:
|
||||
return {"persona": None}
|
||||
raw_persona = payload.get("persona")
|
||||
if not isinstance(raw_persona, dict):
|
||||
raise AstrBotError.invalid_input("persona.update requires persona object")
|
||||
if (
|
||||
"system_prompt" in raw_persona
|
||||
and raw_persona.get("system_prompt") is not None
|
||||
):
|
||||
record["system_prompt"] = str(raw_persona.get("system_prompt", ""))
|
||||
if "begin_dialogs" in raw_persona:
|
||||
begin_dialogs = raw_persona.get("begin_dialogs")
|
||||
record["begin_dialogs"] = (
|
||||
self._normalize_persona_dialogs_payload(begin_dialogs)
|
||||
if begin_dialogs is not None
|
||||
else []
|
||||
)
|
||||
if "tools" in raw_persona:
|
||||
tools = raw_persona.get("tools")
|
||||
record["tools"] = (
|
||||
[str(item) for item in tools] if isinstance(tools, list) else None
|
||||
)
|
||||
if "skills" in raw_persona:
|
||||
skills = raw_persona.get("skills")
|
||||
record["skills"] = (
|
||||
[str(item) for item in skills] if isinstance(skills, list) else None
|
||||
)
|
||||
if "custom_error_message" in raw_persona:
|
||||
custom_error_message = raw_persona.get("custom_error_message")
|
||||
record["custom_error_message"] = (
|
||||
str(custom_error_message) if custom_error_message is not None else None
|
||||
)
|
||||
record["updated_at"] = self._now_iso()
|
||||
return {"persona": dict(record)}
|
||||
|
||||
async def _persona_delete(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
persona_id = str(payload.get("persona_id", "")).strip()
|
||||
if persona_id not in self._persona_store:
|
||||
raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
|
||||
del self._persona_store[persona_id]
|
||||
return {}
|
||||
|
||||
def _register_persona_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("persona.get", "获取人格"),
|
||||
call_handler=self._persona_get,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("persona.list", "列出人格"),
|
||||
call_handler=self._persona_list,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("persona.create", "创建人格"),
|
||||
call_handler=self._persona_create,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("persona.update", "更新人格"),
|
||||
call_handler=self._persona_update,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("persona.delete", "删除人格"),
|
||||
call_handler=self._persona_delete,
|
||||
)
|
||||
@@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class PlatformCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _platform_send(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session, target = self._resolve_target(payload)
|
||||
text = str(payload.get("text", ""))
|
||||
message_id = f"msg_{len(self.sent_messages) + 1}"
|
||||
sent: dict[str, Any] = {
|
||||
"message_id": message_id,
|
||||
"session": session,
|
||||
"text": text,
|
||||
}
|
||||
if target is not None:
|
||||
sent["target"] = target
|
||||
self.sent_messages.append(sent)
|
||||
return {"message_id": message_id}
|
||||
|
||||
async def _platform_send_image(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session, target = self._resolve_target(payload)
|
||||
image_url = str(payload.get("image_url", ""))
|
||||
message_id = f"img_{len(self.sent_messages) + 1}"
|
||||
sent: dict[str, Any] = {
|
||||
"message_id": message_id,
|
||||
"session": session,
|
||||
"image_url": image_url,
|
||||
}
|
||||
if target is not None:
|
||||
sent["target"] = target
|
||||
self.sent_messages.append(sent)
|
||||
return {"message_id": message_id}
|
||||
|
||||
async def _platform_send_chain(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session, target = self._resolve_target(payload)
|
||||
chain = payload.get("chain")
|
||||
if not isinstance(chain, list) or not all(
|
||||
isinstance(item, dict) for item in chain
|
||||
):
|
||||
raise AstrBotError.invalid_input(
|
||||
"platform.send_chain 的 chain 必须是 object 数组"
|
||||
)
|
||||
message_id = f"chain_{len(self.sent_messages) + 1}"
|
||||
sent: dict[str, Any] = {
|
||||
"message_id": message_id,
|
||||
"session": session,
|
||||
"chain": [dict(item) for item in chain],
|
||||
}
|
||||
if target is not None:
|
||||
sent["target"] = target
|
||||
self.sent_messages.append(sent)
|
||||
return {"message_id": message_id}
|
||||
|
||||
async def _platform_send_by_session(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
chain = payload.get("chain")
|
||||
if not isinstance(chain, list) or not all(
|
||||
isinstance(item, dict) for item in chain
|
||||
):
|
||||
raise AstrBotError.invalid_input(
|
||||
"platform.send_by_session 的 chain 必须是 object 数组"
|
||||
)
|
||||
session = str(payload.get("session", ""))
|
||||
message_id = f"proactive_{len(self.sent_messages) + 1}"
|
||||
self.sent_messages.append(
|
||||
{
|
||||
"message_id": message_id,
|
||||
"session": session,
|
||||
"chain": [dict(item) for item in chain],
|
||||
}
|
||||
)
|
||||
return {"message_id": message_id}
|
||||
|
||||
async def _platform_get_group(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session, _target = self._resolve_target(payload)
|
||||
return {"group": self._mock_group_payload(session)}
|
||||
|
||||
async def _platform_get_members(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session, _target = self._resolve_target(payload)
|
||||
group = self._mock_group_payload(session)
|
||||
if group is None:
|
||||
return {"members": []}
|
||||
return {"members": list(group.get("members", []))}
|
||||
|
||||
async def _platform_list_instances(
|
||||
self, _request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"platforms": [
|
||||
{
|
||||
"id": str(item.get("id", "")),
|
||||
"name": str(item.get("name", "")),
|
||||
"type": str(item.get("type", "")),
|
||||
"status": str(item.get("status", "unknown")),
|
||||
}
|
||||
for item in self.get_platform_instances()
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
}
|
||||
|
||||
def _register_platform_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("platform.send", "发送消息"),
|
||||
call_handler=self._platform_send,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("platform.send_image", "发送图片"),
|
||||
call_handler=self._platform_send_image,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("platform.send_chain", "发送消息链"),
|
||||
call_handler=self._platform_send_chain,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"platform.send_by_session", "按会话主动发送消息链"
|
||||
),
|
||||
call_handler=self._platform_send_by_session,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("platform.get_group", "获取当前群信息"),
|
||||
call_handler=self._platform_get_group,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("platform.get_members", "获取群成员"),
|
||||
call_handler=self._platform_get_members,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("platform.list_instances", "列出平台实例元信息"),
|
||||
call_handler=self._platform_list_instances,
|
||||
)
|
||||
|
||||
async def _platform_manager_get_by_id(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self._require_reserved_plugin("platform.manager.get_by_id")
|
||||
platform_id = str(payload.get("platform_id", "")).strip()
|
||||
platform = next(
|
||||
(
|
||||
dict(item)
|
||||
for item in self._platform_instances
|
||||
if str(item.get("id", "")) == platform_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
return {"platform": platform}
|
||||
|
||||
async def _platform_manager_clear_errors(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self._require_reserved_plugin("platform.manager.clear_errors")
|
||||
platform_id = str(payload.get("platform_id", "")).strip()
|
||||
for item in self._platform_instances:
|
||||
if str(item.get("id", "")) != platform_id:
|
||||
continue
|
||||
item["errors"] = []
|
||||
item["last_error"] = None
|
||||
if str(item.get("status", "")) == "error":
|
||||
item["status"] = "running"
|
||||
break
|
||||
return {}
|
||||
|
||||
async def _platform_manager_get_stats(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self._require_reserved_plugin("platform.manager.get_stats")
|
||||
platform_id = str(payload.get("platform_id", "")).strip()
|
||||
for item in self._platform_instances:
|
||||
if str(item.get("id", "")) != platform_id:
|
||||
continue
|
||||
stats = item.get("stats")
|
||||
if isinstance(stats, dict):
|
||||
return {"stats": dict(stats)}
|
||||
errors = item.get("errors")
|
||||
last_error = item.get("last_error")
|
||||
meta = item.get("meta")
|
||||
return {
|
||||
"stats": {
|
||||
"id": platform_id,
|
||||
"type": str(item.get("type", "")),
|
||||
"display_name": str(item.get("name", platform_id)),
|
||||
"status": str(item.get("status", "pending")),
|
||||
"started_at": item.get("started_at"),
|
||||
"error_count": len(errors) if isinstance(errors, list) else 0,
|
||||
"last_error": dict(last_error)
|
||||
if isinstance(last_error, dict)
|
||||
else None,
|
||||
"unified_webhook": bool(item.get("unified_webhook", False)),
|
||||
"meta": dict(meta) if isinstance(meta, dict) else {},
|
||||
}
|
||||
}
|
||||
return {"stats": None}
|
||||
|
||||
def _register_platform_manager_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"platform.manager.get_by_id",
|
||||
"按 ID 获取平台管理快照",
|
||||
),
|
||||
call_handler=self._platform_manager_get_by_id,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"platform.manager.clear_errors",
|
||||
"清除平台错误",
|
||||
),
|
||||
call_handler=self._platform_manager_clear_errors,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"platform.manager.get_stats",
|
||||
"获取平台统计信息",
|
||||
),
|
||||
call_handler=self._platform_manager_get_stats,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..bridge_base import CapabilityRouterBridgeBase
|
||||
|
||||
|
||||
class SessionCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
async def _session_plugin_is_enabled(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", ""))
|
||||
plugin_name = str(payload.get("plugin_name", ""))
|
||||
config = self._session_plugin_config(session)
|
||||
enabled_plugins = {
|
||||
str(item) for item in config.get("enabled_plugins", []) if str(item).strip()
|
||||
}
|
||||
disabled_plugins = {
|
||||
str(item)
|
||||
for item in config.get("disabled_plugins", [])
|
||||
if str(item).strip()
|
||||
}
|
||||
if plugin_name in enabled_plugins:
|
||||
return {"enabled": True}
|
||||
return {"enabled": plugin_name not in disabled_plugins}
|
||||
|
||||
async def _session_plugin_filter_handlers(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", ""))
|
||||
handlers = payload.get("handlers")
|
||||
if not isinstance(handlers, list):
|
||||
raise AstrBotError.invalid_input(
|
||||
"session.plugin.filter_handlers 的 handlers 必须是 object 数组"
|
||||
)
|
||||
disabled_plugins = {
|
||||
str(item)
|
||||
for item in self._session_plugin_config(session).get("disabled_plugins", [])
|
||||
if str(item).strip()
|
||||
}
|
||||
reserved_plugins = {
|
||||
str(plugin.metadata.get("name", ""))
|
||||
for plugin in self._plugins.values()
|
||||
if bool(plugin.metadata.get("reserved", False))
|
||||
}
|
||||
filtered = []
|
||||
for item in handlers:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
plugin_name = str(item.get("plugin_name", ""))
|
||||
if (
|
||||
plugin_name
|
||||
and plugin_name in disabled_plugins
|
||||
and plugin_name not in reserved_plugins
|
||||
):
|
||||
continue
|
||||
filtered.append(dict(item))
|
||||
return {"handlers": filtered}
|
||||
|
||||
async def _session_service_is_llm_enabled(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", ""))
|
||||
config = self._session_service_config(session)
|
||||
return {"enabled": bool(config.get("llm_enabled", True))}
|
||||
|
||||
async def _session_service_set_llm_status(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", ""))
|
||||
config = self._session_service_config(session)
|
||||
config["llm_enabled"] = bool(payload.get("enabled", False))
|
||||
self._session_service_configs[session] = config
|
||||
return {}
|
||||
|
||||
async def _session_service_is_tts_enabled(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", ""))
|
||||
config = self._session_service_config(session)
|
||||
return {"enabled": bool(config.get("tts_enabled", True))}
|
||||
|
||||
async def _session_service_set_tts_status(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", ""))
|
||||
config = self._session_service_config(session)
|
||||
config["tts_enabled"] = bool(payload.get("enabled", False))
|
||||
self._session_service_configs[session] = config
|
||||
return {}
|
||||
|
||||
def _register_session_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("session.plugin.is_enabled", "获取会话级插件开关"),
|
||||
call_handler=self._session_plugin_is_enabled,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"session.plugin.filter_handlers",
|
||||
"按会话过滤 handler 元数据",
|
||||
),
|
||||
call_handler=self._session_plugin_filter_handlers,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"session.service.is_llm_enabled",
|
||||
"获取会话级 LLM 开关",
|
||||
),
|
||||
call_handler=self._session_service_is_llm_enabled,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"session.service.set_llm_status",
|
||||
"写入会话级 LLM 开关",
|
||||
),
|
||||
call_handler=self._session_service_set_llm_status,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"session.service.is_tts_enabled",
|
||||
"获取会话级 TTS 开关",
|
||||
),
|
||||
call_handler=self._session_service_is_tts_enabled,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"session.service.set_tts_status",
|
||||
"写入会话级 TTS 开关",
|
||||
),
|
||||
call_handler=self._session_service_set_tts_status,
|
||||
)
|
||||
@@ -0,0 +1,454 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from ....errors import AstrBotError
|
||||
from ..bridge_base import (
|
||||
CapabilityRouterBridgeBase,
|
||||
_clone_chain_payload,
|
||||
_clone_target_payload,
|
||||
)
|
||||
|
||||
|
||||
class SystemCapabilityMixin(CapabilityRouterBridgeBase):
|
||||
def _register_system_capabilities(self) -> None:
|
||||
self.register(
|
||||
self._builtin_descriptor("system.get_data_dir", "获取插件数据目录"),
|
||||
call_handler=self._system_get_data_dir,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.text_to_image", "文本转图片"),
|
||||
call_handler=self._system_text_to_image,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.html_render", "渲染 HTML 模板"),
|
||||
call_handler=self._system_html_render,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.file.register", "注册文件令牌"),
|
||||
call_handler=self._system_file_register,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.file.handle", "解析文件令牌"),
|
||||
call_handler=self._system_file_handle,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.session_waiter.register",
|
||||
"注册会话等待器",
|
||||
),
|
||||
call_handler=self._system_session_waiter_register,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.session_waiter.unregister",
|
||||
"注销会话等待器",
|
||||
),
|
||||
call_handler=self._system_session_waiter_unregister,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.event.react", "发送事件表情回应"),
|
||||
call_handler=self._system_event_react,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.event.send_typing", "发送输入中状态"),
|
||||
call_handler=self._system_event_send_typing,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.event.send_streaming",
|
||||
"发送事件流式消息",
|
||||
),
|
||||
call_handler=self._system_event_send_streaming,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.event.send_streaming_chunk",
|
||||
"推送事件流式消息分片",
|
||||
),
|
||||
call_handler=self._system_event_send_streaming_chunk,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.event.send_streaming_close",
|
||||
"关闭事件流式消息会话",
|
||||
),
|
||||
call_handler=self._system_event_send_streaming_close,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.event.llm.get_state",
|
||||
"读取当前请求的默认 LLM 状态",
|
||||
),
|
||||
call_handler=self._system_event_llm_get_state,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.event.llm.request",
|
||||
"请求当前事件继续进入默认 LLM 链路",
|
||||
),
|
||||
call_handler=self._system_event_llm_request,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.event.result.get", "读取当前请求结果"),
|
||||
call_handler=self._system_event_result_get,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.event.result.set", "写入当前请求结果"),
|
||||
call_handler=self._system_event_result_set,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("system.event.result.clear", "清理当前请求结果"),
|
||||
call_handler=self._system_event_result_clear,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.event.handler_whitelist.get",
|
||||
"读取当前请求 handler 白名单",
|
||||
),
|
||||
call_handler=self._system_event_handler_whitelist_get,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"system.event.handler_whitelist.set",
|
||||
"写入当前请求 handler 白名单",
|
||||
),
|
||||
call_handler=self._system_event_handler_whitelist_set,
|
||||
exposed=False,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"registry.get_handlers_by_event_type",
|
||||
"按事件类型列出 handler 元数据",
|
||||
),
|
||||
call_handler=self._registry_get_handlers_by_event_type,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"registry.get_handler_by_full_name",
|
||||
"按 full name 查询 handler 元数据",
|
||||
),
|
||||
call_handler=self._registry_get_handler_by_full_name,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor(
|
||||
"registry.command.register",
|
||||
"注册动态命令路由",
|
||||
),
|
||||
call_handler=self._registry_command_register,
|
||||
)
|
||||
|
||||
def _ensure_request_overlay(self, request_id: str) -> dict[str, Any]:
|
||||
overlay = self._request_overlays.get(request_id)
|
||||
if overlay is None:
|
||||
overlay = {
|
||||
"should_call_llm": False,
|
||||
"requested_llm": False,
|
||||
"result": None,
|
||||
"handler_whitelist": None,
|
||||
}
|
||||
self._request_overlays[request_id] = overlay
|
||||
return overlay
|
||||
|
||||
async def _system_get_data_dir(
|
||||
self, _request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
plugin_id = self._require_caller_plugin_id("system.get_data_dir")
|
||||
data_dir = self._system_data_root / plugin_id
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
return {"path": str(data_dir)}
|
||||
|
||||
async def _system_text_to_image(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
text = str(payload.get("text", ""))
|
||||
if bool(payload.get("return_url", True)):
|
||||
return {"result": f"mock://text_to_image/{text}"}
|
||||
return {"result": f"<image>{text}</image>"}
|
||||
|
||||
async def _system_html_render(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
tmpl = str(payload.get("tmpl", ""))
|
||||
data = payload.get("data")
|
||||
if not isinstance(data, dict):
|
||||
raise AstrBotError.invalid_input("system.html_render requires object data")
|
||||
if bool(payload.get("return_url", True)):
|
||||
return {"result": f"mock://html_render/{tmpl}"}
|
||||
return {"result": json.dumps({"tmpl": tmpl, "data": data}, ensure_ascii=False)}
|
||||
|
||||
async def _system_file_register(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
path = str(payload.get("path", "")).strip()
|
||||
if not path:
|
||||
raise AstrBotError.invalid_input("system.file.register requires path")
|
||||
file_token = uuid.uuid4().hex
|
||||
self._file_token_store[file_token] = path
|
||||
return {"token": file_token, "url": f"mock://file/{file_token}"}
|
||||
|
||||
async def _system_file_handle(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
file_token = str(payload.get("token", "")).strip()
|
||||
if not file_token:
|
||||
raise AstrBotError.invalid_input("system.file.handle requires token")
|
||||
path = self._file_token_store.pop(file_token, None)
|
||||
if path is None:
|
||||
raise AstrBotError.invalid_input(f"Unknown file token: {file_token}")
|
||||
return {"path": path}
|
||||
|
||||
async def _system_event_llm_get_state(
|
||||
self, request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
overlay = self._ensure_request_overlay(request_id)
|
||||
return {
|
||||
"should_call_llm": bool(overlay["should_call_llm"]),
|
||||
"requested_llm": bool(overlay["requested_llm"]),
|
||||
}
|
||||
|
||||
async def _system_event_llm_request(
|
||||
self, request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
overlay = self._ensure_request_overlay(request_id)
|
||||
overlay["requested_llm"] = True
|
||||
overlay["should_call_llm"] = True
|
||||
return await self._system_event_llm_get_state(request_id, {}, _token)
|
||||
|
||||
async def _system_event_result_get(
|
||||
self, request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
overlay = self._ensure_request_overlay(request_id)
|
||||
result = overlay.get("result")
|
||||
return {"result": dict(result) if isinstance(result, dict) else None}
|
||||
|
||||
async def _system_event_result_set(
|
||||
self, request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
result = payload.get("result")
|
||||
if not isinstance(result, dict):
|
||||
raise AstrBotError.invalid_input(
|
||||
"system.event.result.set 的 result 必须是 object"
|
||||
)
|
||||
overlay = self._ensure_request_overlay(request_id)
|
||||
overlay["result"] = dict(result)
|
||||
return {"result": dict(result)}
|
||||
|
||||
async def _system_event_result_clear(
|
||||
self, request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
overlay = self._ensure_request_overlay(request_id)
|
||||
overlay["result"] = None
|
||||
return {}
|
||||
|
||||
async def _system_event_handler_whitelist_get(
|
||||
self, request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
overlay = self._ensure_request_overlay(request_id)
|
||||
whitelist = overlay.get("handler_whitelist")
|
||||
if whitelist is None:
|
||||
return {"plugin_names": None}
|
||||
return {"plugin_names": sorted(str(item) for item in whitelist)}
|
||||
|
||||
async def _system_event_handler_whitelist_set(
|
||||
self, request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
overlay = self._ensure_request_overlay(request_id)
|
||||
plugin_names_payload = payload.get("plugin_names")
|
||||
if plugin_names_payload is None:
|
||||
overlay["handler_whitelist"] = None
|
||||
elif isinstance(plugin_names_payload, list):
|
||||
overlay["handler_whitelist"] = {
|
||||
str(item) for item in plugin_names_payload if str(item).strip()
|
||||
}
|
||||
else:
|
||||
raise AstrBotError.invalid_input(
|
||||
"system.event.handler_whitelist.set 的 plugin_names 必须是数组或 null"
|
||||
)
|
||||
return await self._system_event_handler_whitelist_get(request_id, {}, _token)
|
||||
|
||||
async def _registry_get_handlers_by_event_type(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
event_type = str(payload.get("event_type", "")).strip()
|
||||
handlers: list[dict[str, Any]] = []
|
||||
for plugin in self._plugins.values():
|
||||
handlers.extend(
|
||||
[
|
||||
dict(handler)
|
||||
for handler in plugin.handlers
|
||||
if event_type in handler.get("event_types", [])
|
||||
]
|
||||
)
|
||||
if event_type == "message":
|
||||
for plugin_name, routes in self._dynamic_command_routes.items():
|
||||
for route in routes:
|
||||
if not isinstance(route, dict):
|
||||
continue
|
||||
handlers.append(
|
||||
{
|
||||
"plugin_name": str(route.get("plugin_name", plugin_name)),
|
||||
"handler_full_name": str(
|
||||
route.get("handler_full_name", "")
|
||||
),
|
||||
"trigger_type": (
|
||||
"message"
|
||||
if bool(route.get("use_regex", False))
|
||||
else "command"
|
||||
),
|
||||
"event_types": ["message"],
|
||||
"enabled": True,
|
||||
"group_path": [],
|
||||
}
|
||||
)
|
||||
return {"handlers": handlers}
|
||||
|
||||
async def _registry_get_handler_by_full_name(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
full_name = str(payload.get("full_name", "")).strip()
|
||||
for plugin in self._plugins.values():
|
||||
for handler in plugin.handlers:
|
||||
if handler.get("handler_full_name") == full_name:
|
||||
return {"handler": dict(handler)}
|
||||
return {"handler": None}
|
||||
|
||||
async def _registry_command_register(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
source_event_type = str(payload.get("source_event_type", "")).strip()
|
||||
if source_event_type not in {"astrbot_loaded", "platform_loaded"}:
|
||||
raise AstrBotError.invalid_input(
|
||||
"register_commands is only available in astrbot_loaded/platform_loaded events"
|
||||
)
|
||||
if bool(payload.get("ignore_prefix", False)):
|
||||
raise AstrBotError.invalid_input(
|
||||
"register_commands(ignore_prefix=True) is unsupported in SDK runtime"
|
||||
)
|
||||
priority_value = payload.get("priority", 0)
|
||||
if isinstance(priority_value, bool) or not isinstance(priority_value, int):
|
||||
raise AstrBotError.invalid_input(
|
||||
"registry.command.register 的 priority 必须是 integer"
|
||||
)
|
||||
plugin_id = self._require_caller_plugin_id("registry.command.register")
|
||||
self.register_dynamic_command_route(
|
||||
plugin_id=plugin_id,
|
||||
command_name=str(payload.get("command_name", "")),
|
||||
handler_full_name=str(payload.get("handler_full_name", "")),
|
||||
desc=str(payload.get("desc", "")),
|
||||
priority=priority_value,
|
||||
use_regex=bool(payload.get("use_regex", False)),
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _system_session_waiter_register(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
plugin_id = self._require_caller_plugin_id("system.session_waiter.register")
|
||||
session_key = str(payload.get("session_key", "")).strip()
|
||||
if not session_key:
|
||||
raise AstrBotError.invalid_input(
|
||||
"system.session_waiter.register requires session_key"
|
||||
)
|
||||
self._session_waiters.setdefault(plugin_id, set()).add(session_key)
|
||||
return {}
|
||||
|
||||
async def _system_session_waiter_unregister(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
plugin_id = self._require_caller_plugin_id("system.session_waiter.unregister")
|
||||
session_key = str(payload.get("session_key", "")).strip()
|
||||
plugin_waiters = self._session_waiters.get(plugin_id)
|
||||
if plugin_waiters is None:
|
||||
return {}
|
||||
plugin_waiters.discard(session_key)
|
||||
if not plugin_waiters:
|
||||
self._session_waiters.pop(plugin_id, None)
|
||||
return {}
|
||||
|
||||
async def _system_event_react(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self.event_actions.append(
|
||||
{
|
||||
"action": "react",
|
||||
"emoji": str(payload.get("emoji", "")),
|
||||
"target": _clone_target_payload(payload.get("target")),
|
||||
}
|
||||
)
|
||||
return {"supported": True}
|
||||
|
||||
async def _system_event_send_typing(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self.event_actions.append(
|
||||
{
|
||||
"action": "send_typing",
|
||||
"target": _clone_target_payload(payload.get("target")),
|
||||
}
|
||||
)
|
||||
return {"supported": True}
|
||||
|
||||
async def _system_event_send_streaming(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
stream_id = f"mock-stream-{len(self._event_streams) + 1}"
|
||||
stream_state: dict[str, Any] = {
|
||||
"target": _clone_target_payload(payload.get("target")),
|
||||
"chunks": [],
|
||||
"use_fallback": bool(payload.get("use_fallback", False)),
|
||||
}
|
||||
self._event_streams[stream_id] = stream_state
|
||||
return {"supported": True, "stream_id": stream_id}
|
||||
|
||||
async def _system_event_send_streaming_chunk(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
stream = self._event_streams.get(str(payload.get("stream_id", "")))
|
||||
if stream is None:
|
||||
raise AstrBotError.invalid_input("Unknown sdk event streaming session")
|
||||
chain = payload.get("chain")
|
||||
if not isinstance(chain, list):
|
||||
raise AstrBotError.invalid_input(
|
||||
"system.event.send_streaming_chunk requires a chain array"
|
||||
)
|
||||
stream["chunks"].append({"chain": _clone_chain_payload(chain)})
|
||||
return {}
|
||||
|
||||
async def _system_event_send_streaming_close(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
stream_id = str(payload.get("stream_id", ""))
|
||||
stream = self._event_streams.pop(stream_id, None)
|
||||
if stream is None:
|
||||
raise AstrBotError.invalid_input("Unknown sdk event streaming session")
|
||||
self.event_actions.append(
|
||||
{
|
||||
"action": "send_streaming",
|
||||
"target": stream["target"],
|
||||
"chunks": list(stream["chunks"]),
|
||||
"use_fallback": bool(stream["use_fallback"]),
|
||||
}
|
||||
)
|
||||
return {"supported": True}
|
||||
|
||||
@@ -67,6 +67,7 @@
|
||||
provider.rerank.rerank: 文档重排序
|
||||
provider.manager.set: 设置当前 Provider
|
||||
provider.manager.get_by_id: 按 ID 获取 Provider 管理记录
|
||||
provider.manager.get_merged_provider_config: 获取 Provider 合并配置
|
||||
provider.manager.load: 运行时加载 Provider
|
||||
provider.manager.terminate: 终止已加载的 Provider
|
||||
provider.manager.create: 创建 Provider
|
||||
|
||||
@@ -39,6 +39,7 @@ from .._invocation_context import caller_plugin_scope
|
||||
from .._plugin_logger import PluginLogger
|
||||
from .._star_runtime import bind_star_runtime
|
||||
from .._typing_utils import unwrap_optional
|
||||
from ..clients.llm import LLMResponse
|
||||
from ..context import CancelToken, Context
|
||||
from ..conversation import (
|
||||
DEFAULT_BUSY_MESSAGE,
|
||||
@@ -49,8 +50,13 @@ from ..conversation import (
|
||||
)
|
||||
from ..events import MessageEvent
|
||||
from ..filters import LocalFilterBinding
|
||||
from ..llm.entities import ProviderRequest
|
||||
from ..message_components import BaseMessageComponent
|
||||
from ..message_result import MessageChain, MessageEventResult, coerce_message_chain
|
||||
from ..message_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
coerce_message_chain,
|
||||
)
|
||||
from ..protocol.descriptors import (
|
||||
CommandTrigger,
|
||||
MessageTrigger,
|
||||
@@ -60,12 +66,12 @@ from ..protocol.descriptors import (
|
||||
from ..schedule import ScheduleContext
|
||||
from ..session_waiter import SessionWaiterManager
|
||||
from ..star import Star
|
||||
from .capability_dispatcher import CapabilityDispatcher
|
||||
from ._command_matching import (
|
||||
build_command_args,
|
||||
build_regex_args,
|
||||
match_command_name,
|
||||
)
|
||||
from .capability_dispatcher import CapabilityDispatcher
|
||||
from .limiter import LimiterEngine
|
||||
from .loader import LoadedHandler
|
||||
|
||||
@@ -76,6 +82,13 @@ class _ActiveConversation:
|
||||
task: asyncio.Task[Any]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _InjectedEventPayloads:
|
||||
provider_request: ProviderRequest | None = None
|
||||
llm_response: LLMResponse | None = None
|
||||
event_result: MessageEventResult | None = None
|
||||
|
||||
|
||||
class HandlerDispatcher:
|
||||
def __init__(
|
||||
self, *, plugin_id: str, peer, handlers: Sequence[LoadedHandler]
|
||||
@@ -205,6 +218,8 @@ class HandlerDispatcher:
|
||||
schedule_context: ScheduleContext | None = None,
|
||||
) -> dict[str, Any]:
|
||||
summary = {"sent_message": False, "stop": False, "call_llm": False}
|
||||
injected_payloads = _InjectedEventPayloads()
|
||||
event_type = self._event_type_name(event)
|
||||
try:
|
||||
limiter = loaded.limiter
|
||||
if limiter is not None:
|
||||
@@ -254,6 +269,7 @@ class HandlerDispatcher:
|
||||
plugin_id=self._resolve_plugin_id(loaded),
|
||||
handler_ref=loaded.descriptor.id,
|
||||
schedule_context=schedule_context,
|
||||
injected_payloads=injected_payloads,
|
||||
)
|
||||
)
|
||||
if inspect.isasyncgen(result):
|
||||
@@ -263,6 +279,11 @@ class HandlerDispatcher:
|
||||
await self._handle_result_item(item, event, ctx),
|
||||
)
|
||||
summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
|
||||
self._append_injected_payloads(
|
||||
summary,
|
||||
injected_payloads,
|
||||
event_type=event_type,
|
||||
)
|
||||
return summary
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
@@ -272,6 +293,11 @@ class HandlerDispatcher:
|
||||
await self._handle_result_item(result, event, ctx),
|
||||
)
|
||||
summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
|
||||
self._append_injected_payloads(
|
||||
summary,
|
||||
injected_payloads,
|
||||
event_type=event_type,
|
||||
)
|
||||
return summary
|
||||
except Exception as exc:
|
||||
await self._handle_error(
|
||||
@@ -339,6 +365,7 @@ class HandlerDispatcher:
|
||||
handler_ref: str | None = None,
|
||||
schedule_context: ScheduleContext | None = None,
|
||||
conversation_session: ConversationSession | None = None,
|
||||
injected_payloads: _InjectedEventPayloads | None = None,
|
||||
) -> list[Any]:
|
||||
"""构建 handler 参数列表。"""
|
||||
from loguru import logger
|
||||
@@ -371,6 +398,7 @@ class HandlerDispatcher:
|
||||
ctx,
|
||||
schedule_context,
|
||||
conversation_session,
|
||||
injected_payloads=injected_payloads,
|
||||
)
|
||||
|
||||
# 2. Fallback 按名字注入
|
||||
@@ -423,6 +451,8 @@ class HandlerDispatcher:
|
||||
if not str(key).startswith("__command_")
|
||||
}
|
||||
)
|
||||
if not isinstance(loaded.descriptor.trigger, CommandTrigger):
|
||||
return parsed_args, None
|
||||
model_param = resolve_command_model_param(loaded.callable)
|
||||
if model_param is None:
|
||||
return parsed_args, None
|
||||
@@ -456,7 +486,7 @@ class HandlerDispatcher:
|
||||
) -> dict[str, Any]:
|
||||
assert loaded.conversation is not None
|
||||
conversation_meta = loaded.conversation
|
||||
summary = {"sent_message": False, "stop": False, "call_llm": False}
|
||||
summary = {"sent_message": False, "stop": True, "call_llm": False}
|
||||
key = f"{self._resolve_plugin_id(loaded)}:{event.session_id}"
|
||||
active = self._conversations.get(key)
|
||||
if active is not None and not active.task.done():
|
||||
@@ -584,6 +614,8 @@ class HandlerDispatcher:
|
||||
ctx: Context,
|
||||
schedule_context: ScheduleContext | None,
|
||||
conversation_session: ConversationSession | None,
|
||||
*,
|
||||
injected_payloads: _InjectedEventPayloads | None = None,
|
||||
) -> Any:
|
||||
"""根据类型注解注入参数。"""
|
||||
param_type, _is_optional = unwrap_optional(param_type)
|
||||
@@ -612,9 +644,117 @@ class HandlerDispatcher:
|
||||
isinstance(param_type, type) and issubclass(param_type, ConversationSession)
|
||||
):
|
||||
return conversation_session
|
||||
if param_type is ProviderRequest or (
|
||||
isinstance(param_type, type) and issubclass(param_type, ProviderRequest)
|
||||
):
|
||||
return self._inject_provider_request(event, injected_payloads)
|
||||
if param_type is LLMResponse or (
|
||||
isinstance(param_type, type) and issubclass(param_type, LLMResponse)
|
||||
):
|
||||
return self._inject_llm_response(event, injected_payloads)
|
||||
if param_type is MessageEventResult or (
|
||||
isinstance(param_type, type) and issubclass(param_type, MessageEventResult)
|
||||
):
|
||||
return self._inject_event_result(event, injected_payloads)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _event_type_name(event: MessageEvent) -> str:
|
||||
raw = event.raw if isinstance(event.raw, dict) else {}
|
||||
value = raw.get("event_type") or raw.get("type")
|
||||
return str(value or "")
|
||||
|
||||
@staticmethod
|
||||
def _payload_from_event(event: MessageEvent, key: str) -> dict[str, Any] | None:
|
||||
raw = event.raw if isinstance(event.raw, dict) else {}
|
||||
payload = raw.get(key)
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
nested_raw = raw.get("raw")
|
||||
if isinstance(nested_raw, dict):
|
||||
nested_payload = nested_raw.get(key)
|
||||
if isinstance(nested_payload, dict):
|
||||
return nested_payload
|
||||
return None
|
||||
|
||||
def _inject_provider_request(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
injected_payloads: _InjectedEventPayloads | None,
|
||||
) -> ProviderRequest | None:
|
||||
if injected_payloads is None:
|
||||
payload = self._payload_from_event(event, "provider_request")
|
||||
return (
|
||||
ProviderRequest.from_payload(payload) if payload is not None else None
|
||||
)
|
||||
if injected_payloads.provider_request is None:
|
||||
payload = self._payload_from_event(event, "provider_request")
|
||||
if payload is None:
|
||||
return None
|
||||
injected_payloads.provider_request = ProviderRequest.from_payload(payload)
|
||||
return injected_payloads.provider_request
|
||||
|
||||
def _inject_llm_response(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
injected_payloads: _InjectedEventPayloads | None,
|
||||
) -> LLMResponse | None:
|
||||
if injected_payloads is None:
|
||||
payload = self._payload_from_event(event, "llm_response")
|
||||
return LLMResponse.model_validate(payload) if payload is not None else None
|
||||
if injected_payloads.llm_response is None:
|
||||
payload = self._payload_from_event(event, "llm_response")
|
||||
if payload is None:
|
||||
return None
|
||||
injected_payloads.llm_response = LLMResponse.model_validate(payload)
|
||||
return injected_payloads.llm_response
|
||||
|
||||
def _inject_event_result(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
injected_payloads: _InjectedEventPayloads | None,
|
||||
) -> MessageEventResult | None:
|
||||
if injected_payloads is None:
|
||||
payload = self._payload_from_event(event, "event_result")
|
||||
return (
|
||||
MessageEventResult.from_payload(payload)
|
||||
if payload is not None
|
||||
else None
|
||||
)
|
||||
if injected_payloads.event_result is None:
|
||||
payload = self._payload_from_event(event, "event_result")
|
||||
if payload is None:
|
||||
return None
|
||||
injected_payloads.event_result = MessageEventResult.from_payload(payload)
|
||||
return injected_payloads.event_result
|
||||
|
||||
@staticmethod
|
||||
def _append_injected_payloads(
|
||||
summary: dict[str, Any],
|
||||
injected_payloads: _InjectedEventPayloads,
|
||||
*,
|
||||
event_type: str,
|
||||
) -> None:
|
||||
if (
|
||||
event_type == "llm_request"
|
||||
and injected_payloads.provider_request is not None
|
||||
):
|
||||
summary["provider_request"] = (
|
||||
injected_payloads.provider_request.to_payload()
|
||||
)
|
||||
elif (
|
||||
event_type == "llm_response" and injected_payloads.llm_response is not None
|
||||
):
|
||||
summary["llm_response"] = injected_payloads.llm_response.model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
elif (
|
||||
event_type == "decorating_result"
|
||||
and injected_payloads.event_result is not None
|
||||
):
|
||||
summary["event_result"] = injected_payloads.event_result.to_payload()
|
||||
|
||||
def _format_handler_injection_error(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -55,6 +55,7 @@ import copy
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -105,6 +106,7 @@ OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None
|
||||
HandlerKind: TypeAlias = Literal["handler", "hook", "tool", "session"]
|
||||
DiscoverySeverity: TypeAlias = Literal["warning", "error"]
|
||||
DiscoveryPhase: TypeAlias = Literal["discovery", "load", "lifecycle", "reload"]
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_python_version() -> str:
|
||||
@@ -502,17 +504,74 @@ def _normalize_config_value(field_schema: dict[str, Any], value: Any) -> Any:
|
||||
return copy.deepcopy(value) if value is not None else default_value
|
||||
|
||||
|
||||
def load_plugin_config(plugin: PluginSpec) -> dict[str, Any]:
|
||||
"""加载插件配置,返回普通字典。"""
|
||||
def load_plugin_config_schema(plugin: PluginSpec) -> dict[str, Any]:
|
||||
"""加载插件配置 schema,解析失败时记录日志并返回空对象。"""
|
||||
schema_path = plugin.plugin_dir / CONFIG_SCHEMA_FILE
|
||||
if not schema_path.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
schema_payload = json.loads(schema_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
schema_payload = {}
|
||||
schema = schema_payload if isinstance(schema_payload, dict) else {}
|
||||
except json.JSONDecodeError as exc:
|
||||
_LOGGER.warning(
|
||||
"Failed to parse SDK plugin config schema %s: %s",
|
||||
schema_path,
|
||||
exc,
|
||||
)
|
||||
return {}
|
||||
except OSError as exc:
|
||||
_LOGGER.warning(
|
||||
"Failed to read SDK plugin config schema %s: %s",
|
||||
schema_path,
|
||||
exc,
|
||||
)
|
||||
return {}
|
||||
if not isinstance(schema_payload, dict):
|
||||
_LOGGER.warning(
|
||||
"SDK plugin config schema %s must be a JSON object, got %s",
|
||||
schema_path,
|
||||
type(schema_payload).__name__,
|
||||
)
|
||||
return {}
|
||||
return schema_payload
|
||||
|
||||
|
||||
def save_plugin_config(
|
||||
plugin: PluginSpec,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
schema: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""按 schema 归一化并写回插件配置。"""
|
||||
active_schema = (
|
||||
load_plugin_config_schema(plugin) if schema is None else dict(schema)
|
||||
)
|
||||
normalized = {
|
||||
key: _normalize_config_value(field_schema, payload.get(key))
|
||||
for key, field_schema in active_schema.items()
|
||||
if isinstance(field_schema, dict)
|
||||
}
|
||||
|
||||
config_path = _plugin_config_path(plugin.plugin_dir, plugin.name)
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
config_path.write_text(
|
||||
json.dumps(normalized, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def load_plugin_config(
|
||||
plugin: PluginSpec,
|
||||
*,
|
||||
schema: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""加载插件配置,返回普通字典。"""
|
||||
active_schema = (
|
||||
load_plugin_config_schema(plugin) if schema is None else dict(schema)
|
||||
)
|
||||
if not active_schema:
|
||||
return {}
|
||||
|
||||
config_path = _plugin_config_path(plugin.plugin_dir, plugin.name)
|
||||
try:
|
||||
@@ -521,21 +580,29 @@ def load_plugin_config(plugin: PluginSpec) -> dict[str, Any]:
|
||||
if config_path.exists()
|
||||
else {}
|
||||
)
|
||||
except Exception:
|
||||
except json.JSONDecodeError as exc:
|
||||
_LOGGER.warning(
|
||||
"Failed to parse SDK plugin config %s: %s",
|
||||
config_path,
|
||||
exc,
|
||||
)
|
||||
existing_payload = {}
|
||||
except OSError as exc:
|
||||
_LOGGER.warning(
|
||||
"Failed to read SDK plugin config %s: %s",
|
||||
config_path,
|
||||
exc,
|
||||
)
|
||||
existing_payload = {}
|
||||
existing = existing_payload if isinstance(existing_payload, dict) else {}
|
||||
normalized = {
|
||||
key: _normalize_config_value(field_schema, existing.get(key))
|
||||
for key, field_schema in schema.items()
|
||||
for key, field_schema in active_schema.items()
|
||||
if isinstance(field_schema, dict)
|
||||
}
|
||||
|
||||
if not config_path.exists() or normalized != existing:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
config_path.write_text(
|
||||
json.dumps(normalized, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
save_plugin_config(plugin, normalized, schema=active_schema)
|
||||
return normalized
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user