feat: add conversation.get_current capability and related schemas

- Introduced CONVERSATION_GET_CURRENT_INPUT_SCHEMA and CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA for handling current conversation requests.
- Implemented _conversation_get_current method in BuiltinCapabilityRouterMixin to manage current conversation retrieval and creation.
- Registered the new capability in CoreCapabilityBridge.
- Enhanced HandlerDispatcher to inject provider request, LLM response, and event result payloads into the event handling process.
- Updated tests to validate the new functionality and ensure proper payload handling.
This commit is contained in:
whatevertogo
2026-03-19 06:16:01 +08:00
parent bb361cf9dd
commit ed1b9665dd
8 changed files with 335 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from ..errors import AstrBotError, ErrorCodes
from ..message_session import MessageSession
from ._proxy import CapabilityProxy
@@ -138,7 +139,15 @@ class PersonaManagerClient:
self._proxy = proxy
async def get_persona(self, persona_id: str) -> PersonaRecord:
output = await self._proxy.call("persona.get", {"persona_id": str(persona_id)})
try:
output = await self._proxy.call(
"persona.get",
{"persona_id": str(persona_id)},
)
except AstrBotError as exc:
if exc.code == ErrorCodes.INVALID_INPUT:
raise ValueError(f"persona not found: {persona_id}") from exc
raise
persona = PersonaRecord.from_payload(output.get("persona"))
if persona is None:
raise ValueError(f"persona not found: {persona_id}")
@@ -251,6 +260,21 @@ class ConversationManagerClient:
)
return ConversationRecord.from_payload(output.get("conversation"))
async def get_current_conversation(
self,
session: str | MessageSession,
*,
create_if_not_exists: bool = False,
) -> ConversationRecord | None:
output = await self._proxy.call(
"conversation.get_current",
{
"session": _normalize_session(session),
"create_if_not_exists": bool(create_if_not_exists),
},
)
return ConversationRecord.from_payload(output.get("conversation"))
async def get_conversations(
self,
session: str | MessageSession | None = None,

View File

@@ -113,6 +113,11 @@ class MyPlugin(Star):
- 注册 LLM 工具
- 启动后台任务
**最佳实践:**
- `on_start()` 里只做初始化、能力注册和轻量状态恢复
- 需要长期保存的应是配置值、句柄、任务引用,不要把 `ctx` 实例长期挂到 `self`
- 如果要和 AstrBot 原生 persona / conversation 协作,优先在这里校验或创建所需资源
**示例:**
```python
@@ -150,6 +155,11 @@ class MyPlugin(Star):
- 注销 LLM 工具
- 保存状态数据
**最佳实践:**
-`on_stop()` 中释放 `on_start()` 注册的任务、监听器和外部资源
- 把需要持久化的状态尽量提前落库,不要把关键保存逻辑完全依赖在进程退出瞬间
- 始终把收到的 `ctx` 继续传给 `super().on_stop(ctx)`,不要手动丢掉它
**示例:**
```python

View File

@@ -1101,6 +1101,8 @@ from astrbot_sdk.clients import PersonaManagerClient
获取指定人格。
当人格不存在时会抛出 `ValueError`,而不是返回 `None`
---
#### `get_all_personas()`
@@ -1163,6 +1165,17 @@ from astrbot_sdk.clients import ConversationManagerClient
---
#### `get_current_conversation(session, create_if_not_exists=False)`
获取当前 session 正在使用的对话记录。
这个方法适合“跟随 AstrBot 原生当前会话状态”的插件,例如:
- 给当前会话切换 persona
- 判断当前主聊天是否已经在某个 persona 下
-`waiting_llm_request` / `llm_request` hook 中对当前对话做增强
---
#### `get_conversations(session=None, platform_id=None)`
获取对话列表。

View File

@@ -662,7 +662,7 @@ await ctx.conversations.delete_conversation(
await ctx.conversations.delete_conversation(event.session_id)
```
##### `get_conversation() / get_conversations()`
##### `get_conversation() / get_current_conversation() / get_conversations()`
获取对话。
@@ -674,6 +674,12 @@ conv = await ctx.conversations.get_conversation(
create_if_not_exists=True
)
# 获取当前选中的对话
current = await ctx.conversations.get_current_conversation(
event.session_id,
create_if_not_exists=True,
)
# 获取对话列表
convs = await ctx.conversations.get_conversations(event.session_id)
```

View File

@@ -210,11 +210,110 @@ async def handle_request(self, event, ctx: Context):
await ctx.platform.send(event.user_id, "已自动通过好友请求")
```
#### LLM Pipeline Hooks
`@on_event` 也用于挂接 AstrBot 原生消息处理链路中的系统事件。
常见事件及可注入对象:
| 事件名 | 常见可注入参数 | 是否可修改主链路 |
|------|------|------|
| `waiting_llm_request` | `MessageEvent`, `Context` | 间接可修改,例如切换当前对话 persona |
| `llm_request` | `MessageEvent`, `Context`, `ProviderRequest` | 是,可直接修改 `ProviderRequest` |
| `llm_response` | `MessageEvent`, `Context`, `LLMResponse` | 否,适合观察和提取回复内容 |
| `decorating_result` | `MessageEvent`, `Context`, `MessageEventResult` | 是,可直接修改结果消息链 |
| `after_message_sent` | `MessageEvent`, `Context` | 否,适合落库、记忆、统计 |
最小示例:
```python
from astrbot_sdk import Context, MessageEvent
from astrbot_sdk.decorators import on_event
from astrbot_sdk.llm.entities import ProviderRequest
@on_event("llm_request")
async def add_memory(self, event: MessageEvent, ctx: Context, request: ProviderRequest):
del event, ctx
request.system_prompt = (request.system_prompt or "") + "\n\nmemory: user likes tea"
```
完整示例:
```python
from astrbot_sdk import Context, MessageEvent, Star
from astrbot_sdk.clients.llm import LLMResponse
from astrbot_sdk.clients.managers import ConversationUpdateParams
from astrbot_sdk.decorators import on_event
from astrbot_sdk.llm.entities import ProviderRequest
from astrbot_sdk.message_result import MessageEventResult
from astrbot_sdk.message_components import Plain
class PersonaSample(Star):
@on_event("waiting_llm_request")
async def ensure_persona(self, event: MessageEvent, ctx: Context) -> None:
conversation = await ctx.conversations.get_current_conversation(
event.session_id,
create_if_not_exists=True,
)
if conversation is None or conversation.persona_id == "girlfriend":
return
await ctx.conversations.update_conversation(
event.session_id,
conversation.conversation_id,
ConversationUpdateParams(persona_id="girlfriend"),
)
@on_event("llm_request")
async def inject_context(
self,
event: MessageEvent,
ctx: Context,
request: ProviderRequest,
) -> None:
memories = await ctx.memory.search(event.text, limit=3)
facts = []
for item in memories:
value = item.get("value")
if isinstance(value, dict) and value.get("content"):
facts.append(f"- {value['content']}")
if facts:
request.system_prompt = (request.system_prompt or "") + "\n\n" + "\n".join(facts)
@on_event("llm_response")
async def capture_reply(
self,
event: MessageEvent,
ctx: Context,
response: LLMResponse,
) -> None:
del ctx
if response.text:
event.set_extra("last_reply", response.text)
@on_event("decorating_result")
async def decorate(
self,
event: MessageEvent,
ctx: Context,
result: MessageEventResult,
) -> None:
del event, ctx
result.chain.append(Plain("\n[persona active]", convert=False))
@on_event("after_message_sent")
async def persist(self, event: MessageEvent, ctx: Context) -> None:
reply = str(event.get_extra("last_reply", "") or "").strip()
if reply:
await ctx.db.set("sample:last_reply", reply)
```
#### 注意事项
1. 用于处理非消息类型的事件(如群成员变动、好友请求等
1. 用于处理平台事件,也可用于处理 AstrBot 原生消息链路中的系统事件(如 `llm_request`
2. 不能与 `@rate_limit``@cooldown` 一起使用
3. 不同平台的事件类型可能不同,需要查阅平台文档
4. `llm_request``decorating_result` 注入的是可变对象,修改会回写到 AstrBot 主链路
5. `llm_response` 主要用于观测和提取结果,不应用来替代主回复流程
---

View File

@@ -642,6 +642,15 @@ CONVERSATION_GET_OUTPUT_SCHEMA = _object_schema(
required=("conversation",),
conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
)
CONVERSATION_GET_CURRENT_INPUT_SCHEMA = _object_schema(
required=("session",),
session={"type": "string"},
create_if_not_exists={"type": "boolean"},
)
CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA = _object_schema(
required=("conversation",),
conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
)
CONVERSATION_LIST_INPUT_SCHEMA = _object_schema(
session=_nullable({"type": "string"}),
platform_id=_nullable({"type": "string"}),
@@ -1207,6 +1216,10 @@ BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = {
"input": CONVERSATION_GET_INPUT_SCHEMA,
"output": CONVERSATION_GET_OUTPUT_SCHEMA,
},
"conversation.get_current": {
"input": CONVERSATION_GET_CURRENT_INPUT_SCHEMA,
"output": CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA,
},
"conversation.list": {
"input": CONVERSATION_LIST_INPUT_SCHEMA,
"output": CONVERSATION_LIST_OUTPUT_SCHEMA,
@@ -1653,6 +1666,8 @@ __all__ = [
"CONVERSATION_CREATE_SCHEMA",
"CONVERSATION_DELETE_INPUT_SCHEMA",
"CONVERSATION_DELETE_OUTPUT_SCHEMA",
"CONVERSATION_GET_CURRENT_INPUT_SCHEMA",
"CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA",
"CONVERSATION_GET_INPUT_SCHEMA",
"CONVERSATION_GET_OUTPUT_SCHEMA",
"CONVERSATION_LIST_INPUT_SCHEMA",

View File

@@ -96,7 +96,7 @@ def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]:
"""
values = [0.0] * _MOCK_EMBEDDING_DIM
for term in _embedding_terms(text):
digest = hashlib.sha256(f"{provider_id}:{term}".encode("utf-8")).digest()
digest = hashlib.sha256(f"{provider_id}:{term}".encode()).digest()
index = int.from_bytes(digest[:2], "big") % _MOCK_EMBEDDING_DIM
values[index] += 1.0 + min(len(term), 8) * 0.05
norm = math.sqrt(sum(value * value for value in values))
@@ -2672,6 +2672,25 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
return {"conversation": None}
return {"conversation": dict(record)}
async def _conversation_get_current(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = self._session_current_conversation_ids.get(session, "")
if not conversation_id and bool(payload.get("create_if_not_exists", False)):
created = await self._conversation_new(
_request_id,
{"session": session, "conversation": {}},
_token,
)
conversation_id = str(created.get("conversation_id", "")).strip()
if not conversation_id:
return {"conversation": None}
record = self._conversation_store.get(conversation_id)
if record is None or str(record.get("session", "")) != session:
return {"conversation": None}
return {"conversation": dict(record)}
async def _conversation_list(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
@@ -2824,6 +2843,10 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
self._builtin_descriptor("conversation.get", "获取对话"),
call_handler=self._conversation_get,
)
self.register(
self._builtin_descriptor("conversation.get_current", "获取当前对话"),
call_handler=self._conversation_get_current,
)
self.register(
self._builtin_descriptor("conversation.list", "列出对话"),
call_handler=self._conversation_list,

View File

@@ -39,6 +39,7 @@ from .._invocation_context import caller_plugin_scope
from .._plugin_logger import PluginLogger
from .._star_runtime import bind_star_runtime
from .._typing_utils import unwrap_optional
from ..clients.llm import LLMResponse
from ..context import CancelToken, Context
from ..conversation import (
DEFAULT_BUSY_MESSAGE,
@@ -49,8 +50,13 @@ from ..conversation import (
)
from ..events import MessageEvent
from ..filters import LocalFilterBinding
from ..llm.entities import ProviderRequest
from ..message_components import BaseMessageComponent
from ..message_result import MessageChain, MessageEventResult, coerce_message_chain
from ..message_result import (
MessageChain,
MessageEventResult,
coerce_message_chain,
)
from ..protocol.descriptors import (
CommandTrigger,
MessageTrigger,
@@ -76,6 +82,13 @@ class _ActiveConversation:
task: asyncio.Task[Any]
@dataclass(slots=True)
class _InjectedEventPayloads:
provider_request: ProviderRequest | None = None
llm_response: LLMResponse | None = None
event_result: MessageEventResult | None = None
class HandlerDispatcher:
def __init__(
self, *, plugin_id: str, peer, handlers: Sequence[LoadedHandler]
@@ -205,6 +218,8 @@ class HandlerDispatcher:
schedule_context: ScheduleContext | None = None,
) -> dict[str, Any]:
summary = {"sent_message": False, "stop": False, "call_llm": False}
injected_payloads = _InjectedEventPayloads()
event_type = self._event_type_name(event)
try:
limiter = loaded.limiter
if limiter is not None:
@@ -254,6 +269,7 @@ class HandlerDispatcher:
plugin_id=self._resolve_plugin_id(loaded),
handler_ref=loaded.descriptor.id,
schedule_context=schedule_context,
injected_payloads=injected_payloads,
)
)
if inspect.isasyncgen(result):
@@ -263,6 +279,11 @@ class HandlerDispatcher:
await self._handle_result_item(item, event, ctx),
)
summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
self._append_injected_payloads(
summary,
injected_payloads,
event_type=event_type,
)
return summary
if inspect.isawaitable(result):
result = await result
@@ -272,6 +293,11 @@ class HandlerDispatcher:
await self._handle_result_item(result, event, ctx),
)
summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
self._append_injected_payloads(
summary,
injected_payloads,
event_type=event_type,
)
return summary
except Exception as exc:
await self._handle_error(
@@ -339,6 +365,7 @@ class HandlerDispatcher:
handler_ref: str | None = None,
schedule_context: ScheduleContext | None = None,
conversation_session: ConversationSession | None = None,
injected_payloads: _InjectedEventPayloads | None = None,
) -> list[Any]:
"""构建 handler 参数列表。"""
from loguru import logger
@@ -371,6 +398,7 @@ class HandlerDispatcher:
ctx,
schedule_context,
conversation_session,
injected_payloads=injected_payloads,
)
# 2. Fallback 按名字注入
@@ -423,6 +451,8 @@ class HandlerDispatcher:
if not str(key).startswith("__command_")
}
)
if not isinstance(loaded.descriptor.trigger, CommandTrigger):
return parsed_args, None
model_param = resolve_command_model_param(loaded.callable)
if model_param is None:
return parsed_args, None
@@ -584,6 +614,8 @@ class HandlerDispatcher:
ctx: Context,
schedule_context: ScheduleContext | None,
conversation_session: ConversationSession | None,
*,
injected_payloads: _InjectedEventPayloads | None = None,
) -> Any:
"""根据类型注解注入参数。"""
param_type, _is_optional = unwrap_optional(param_type)
@@ -612,9 +644,117 @@ class HandlerDispatcher:
isinstance(param_type, type) and issubclass(param_type, ConversationSession)
):
return conversation_session
if param_type is ProviderRequest or (
isinstance(param_type, type) and issubclass(param_type, ProviderRequest)
):
return self._inject_provider_request(event, injected_payloads)
if param_type is LLMResponse or (
isinstance(param_type, type) and issubclass(param_type, LLMResponse)
):
return self._inject_llm_response(event, injected_payloads)
if param_type is MessageEventResult or (
isinstance(param_type, type) and issubclass(param_type, MessageEventResult)
):
return self._inject_event_result(event, injected_payloads)
return None
@staticmethod
def _event_type_name(event: MessageEvent) -> str:
raw = event.raw if isinstance(event.raw, dict) else {}
value = raw.get("event_type") or raw.get("type")
return str(value or "")
@staticmethod
def _payload_from_event(event: MessageEvent, key: str) -> dict[str, Any] | None:
raw = event.raw if isinstance(event.raw, dict) else {}
payload = raw.get(key)
if isinstance(payload, dict):
return payload
nested_raw = raw.get("raw")
if isinstance(nested_raw, dict):
nested_payload = nested_raw.get(key)
if isinstance(nested_payload, dict):
return nested_payload
return None
def _inject_provider_request(
self,
event: MessageEvent,
injected_payloads: _InjectedEventPayloads | None,
) -> ProviderRequest | None:
if injected_payloads is None:
payload = self._payload_from_event(event, "provider_request")
return (
ProviderRequest.from_payload(payload) if payload is not None else None
)
if injected_payloads.provider_request is None:
payload = self._payload_from_event(event, "provider_request")
if payload is None:
return None
injected_payloads.provider_request = ProviderRequest.from_payload(payload)
return injected_payloads.provider_request
def _inject_llm_response(
self,
event: MessageEvent,
injected_payloads: _InjectedEventPayloads | None,
) -> LLMResponse | None:
if injected_payloads is None:
payload = self._payload_from_event(event, "llm_response")
return LLMResponse.model_validate(payload) if payload is not None else None
if injected_payloads.llm_response is None:
payload = self._payload_from_event(event, "llm_response")
if payload is None:
return None
injected_payloads.llm_response = LLMResponse.model_validate(payload)
return injected_payloads.llm_response
def _inject_event_result(
self,
event: MessageEvent,
injected_payloads: _InjectedEventPayloads | None,
) -> MessageEventResult | None:
if injected_payloads is None:
payload = self._payload_from_event(event, "event_result")
return (
MessageEventResult.from_payload(payload)
if payload is not None
else None
)
if injected_payloads.event_result is None:
payload = self._payload_from_event(event, "event_result")
if payload is None:
return None
injected_payloads.event_result = MessageEventResult.from_payload(payload)
return injected_payloads.event_result
@staticmethod
def _append_injected_payloads(
summary: dict[str, Any],
injected_payloads: _InjectedEventPayloads,
*,
event_type: str,
) -> None:
if (
event_type == "llm_request"
and injected_payloads.provider_request is not None
):
summary["provider_request"] = (
injected_payloads.provider_request.to_payload()
)
elif (
event_type == "llm_response" and injected_payloads.llm_response is not None
):
summary["llm_response"] = injected_payloads.llm_response.model_dump(
exclude_none=True
)
elif (
event_type == "decorating_result"
and injected_payloads.event_result is not None
):
summary["event_result"] = injected_payloads.event_result.to_payload()
def _format_handler_injection_error(
self,
*,