mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-02 02:30:16 +08:00
feat: add conversation.get_current capability and related schemas
- Introduced CONVERSATION_GET_CURRENT_INPUT_SCHEMA and CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA for handling current conversation requests. - Implemented _conversation_get_current method in BuiltinCapabilityRouterMixin to manage current conversation retrieval and creation. - Registered the new capability in CoreCapabilityBridge. - Enhanced HandlerDispatcher to inject provider request, LLM response, and event result payloads into the event handling process. - Updated tests to validate the new functionality and ensure proper payload handling.
This commit is contained in:
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from ..errors import AstrBotError, ErrorCodes
|
||||
from ..message_session import MessageSession
|
||||
from ._proxy import CapabilityProxy
|
||||
|
||||
@@ -138,7 +139,15 @@ class PersonaManagerClient:
|
||||
self._proxy = proxy
|
||||
|
||||
async def get_persona(self, persona_id: str) -> PersonaRecord:
|
||||
output = await self._proxy.call("persona.get", {"persona_id": str(persona_id)})
|
||||
try:
|
||||
output = await self._proxy.call(
|
||||
"persona.get",
|
||||
{"persona_id": str(persona_id)},
|
||||
)
|
||||
except AstrBotError as exc:
|
||||
if exc.code == ErrorCodes.INVALID_INPUT:
|
||||
raise ValueError(f"persona not found: {persona_id}") from exc
|
||||
raise
|
||||
persona = PersonaRecord.from_payload(output.get("persona"))
|
||||
if persona is None:
|
||||
raise ValueError(f"persona not found: {persona_id}")
|
||||
@@ -251,6 +260,21 @@ class ConversationManagerClient:
|
||||
)
|
||||
return ConversationRecord.from_payload(output.get("conversation"))
|
||||
|
||||
async def get_current_conversation(
|
||||
self,
|
||||
session: str | MessageSession,
|
||||
*,
|
||||
create_if_not_exists: bool = False,
|
||||
) -> ConversationRecord | None:
|
||||
output = await self._proxy.call(
|
||||
"conversation.get_current",
|
||||
{
|
||||
"session": _normalize_session(session),
|
||||
"create_if_not_exists": bool(create_if_not_exists),
|
||||
},
|
||||
)
|
||||
return ConversationRecord.from_payload(output.get("conversation"))
|
||||
|
||||
async def get_conversations(
|
||||
self,
|
||||
session: str | MessageSession | None = None,
|
||||
|
||||
@@ -113,6 +113,11 @@ class MyPlugin(Star):
|
||||
- 注册 LLM 工具
|
||||
- 启动后台任务
|
||||
|
||||
**最佳实践:**
|
||||
- `on_start()` 里只做初始化、能力注册和轻量状态恢复
|
||||
- 需要长期保存的应是配置值、句柄、任务引用,不要把 `ctx` 实例长期挂到 `self`
|
||||
- 如果要和 AstrBot 原生 persona / conversation 协作,优先在这里校验或创建所需资源
|
||||
|
||||
**示例:**
|
||||
|
||||
```python
|
||||
@@ -150,6 +155,11 @@ class MyPlugin(Star):
|
||||
- 注销 LLM 工具
|
||||
- 保存状态数据
|
||||
|
||||
**最佳实践:**
|
||||
- 在 `on_stop()` 中释放 `on_start()` 注册的任务、监听器和外部资源
|
||||
- 把需要持久化的状态尽量提前落库,不要把关键保存逻辑完全依赖在进程退出瞬间
|
||||
- 始终把收到的 `ctx` 继续传给 `super().on_stop(ctx)`,不要手动丢掉它
|
||||
|
||||
**示例:**
|
||||
|
||||
```python
|
||||
|
||||
@@ -1101,6 +1101,8 @@ from astrbot_sdk.clients import PersonaManagerClient
|
||||
|
||||
获取指定人格。
|
||||
|
||||
当人格不存在时会抛出 `ValueError`,而不是返回 `None`。
|
||||
|
||||
---
|
||||
|
||||
#### `get_all_personas()`
|
||||
@@ -1163,6 +1165,17 @@ from astrbot_sdk.clients import ConversationManagerClient
|
||||
|
||||
---
|
||||
|
||||
#### `get_current_conversation(session, create_if_not_exists=False)`
|
||||
|
||||
获取当前 session 正在使用的对话记录。
|
||||
|
||||
这个方法适合“跟随 AstrBot 原生当前会话状态”的插件,例如:
|
||||
- 给当前会话切换 persona
|
||||
- 判断当前主聊天是否已经在某个 persona 下
|
||||
- 在 `waiting_llm_request` / `llm_request` hook 中对当前对话做增强
|
||||
|
||||
---
|
||||
|
||||
#### `get_conversations(session=None, platform_id=None)`
|
||||
|
||||
获取对话列表。
|
||||
|
||||
@@ -662,7 +662,7 @@ await ctx.conversations.delete_conversation(
|
||||
await ctx.conversations.delete_conversation(event.session_id)
|
||||
```
|
||||
|
||||
##### `get_conversation() / get_conversations()`
|
||||
##### `get_conversation() / get_current_conversation() / get_conversations()`
|
||||
|
||||
获取对话。
|
||||
|
||||
@@ -674,6 +674,12 @@ conv = await ctx.conversations.get_conversation(
|
||||
create_if_not_exists=True
|
||||
)
|
||||
|
||||
# 获取当前选中的对话
|
||||
current = await ctx.conversations.get_current_conversation(
|
||||
event.session_id,
|
||||
create_if_not_exists=True,
|
||||
)
|
||||
|
||||
# 获取对话列表
|
||||
convs = await ctx.conversations.get_conversations(event.session_id)
|
||||
```
|
||||
|
||||
@@ -210,11 +210,110 @@ async def handle_request(self, event, ctx: Context):
|
||||
await ctx.platform.send(event.user_id, "已自动通过好友请求")
|
||||
```
|
||||
|
||||
#### LLM Pipeline Hooks
|
||||
|
||||
`@on_event` 也用于挂接 AstrBot 原生消息处理链路中的系统事件。
|
||||
|
||||
常见事件及可注入对象:
|
||||
|
||||
| 事件名 | 常见可注入参数 | 是否可修改主链路 |
|
||||
|------|------|------|
|
||||
| `waiting_llm_request` | `MessageEvent`, `Context` | 间接可修改,例如切换当前对话 persona |
|
||||
| `llm_request` | `MessageEvent`, `Context`, `ProviderRequest` | 是,可直接修改 `ProviderRequest` |
|
||||
| `llm_response` | `MessageEvent`, `Context`, `LLMResponse` | 否,适合观察和提取回复内容 |
|
||||
| `decorating_result` | `MessageEvent`, `Context`, `MessageEventResult` | 是,可直接修改结果消息链 |
|
||||
| `after_message_sent` | `MessageEvent`, `Context` | 否,适合落库、记忆、统计 |
|
||||
|
||||
最小示例:
|
||||
|
||||
```python
|
||||
from astrbot_sdk import Context, MessageEvent
|
||||
from astrbot_sdk.decorators import on_event
|
||||
from astrbot_sdk.llm.entities import ProviderRequest
|
||||
|
||||
@on_event("llm_request")
|
||||
async def add_memory(self, event: MessageEvent, ctx: Context, request: ProviderRequest):
|
||||
del event, ctx
|
||||
request.system_prompt = (request.system_prompt or "") + "\n\nmemory: user likes tea"
|
||||
```
|
||||
|
||||
完整示例:
|
||||
|
||||
```python
|
||||
from astrbot_sdk import Context, MessageEvent, Star
|
||||
from astrbot_sdk.clients.llm import LLMResponse
|
||||
from astrbot_sdk.clients.managers import ConversationUpdateParams
|
||||
from astrbot_sdk.decorators import on_event
|
||||
from astrbot_sdk.llm.entities import ProviderRequest
|
||||
from astrbot_sdk.message_result import MessageEventResult
|
||||
from astrbot_sdk.message_components import Plain
|
||||
|
||||
class PersonaSample(Star):
|
||||
@on_event("waiting_llm_request")
|
||||
async def ensure_persona(self, event: MessageEvent, ctx: Context) -> None:
|
||||
conversation = await ctx.conversations.get_current_conversation(
|
||||
event.session_id,
|
||||
create_if_not_exists=True,
|
||||
)
|
||||
if conversation is None or conversation.persona_id == "girlfriend":
|
||||
return
|
||||
await ctx.conversations.update_conversation(
|
||||
event.session_id,
|
||||
conversation.conversation_id,
|
||||
ConversationUpdateParams(persona_id="girlfriend"),
|
||||
)
|
||||
|
||||
@on_event("llm_request")
|
||||
async def inject_context(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
ctx: Context,
|
||||
request: ProviderRequest,
|
||||
) -> None:
|
||||
memories = await ctx.memory.search(event.text, limit=3)
|
||||
facts = []
|
||||
for item in memories:
|
||||
value = item.get("value")
|
||||
if isinstance(value, dict) and value.get("content"):
|
||||
facts.append(f"- {value['content']}")
|
||||
if facts:
|
||||
request.system_prompt = (request.system_prompt or "") + "\n\n" + "\n".join(facts)
|
||||
|
||||
@on_event("llm_response")
|
||||
async def capture_reply(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
ctx: Context,
|
||||
response: LLMResponse,
|
||||
) -> None:
|
||||
del ctx
|
||||
if response.text:
|
||||
event.set_extra("last_reply", response.text)
|
||||
|
||||
@on_event("decorating_result")
|
||||
async def decorate(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
ctx: Context,
|
||||
result: MessageEventResult,
|
||||
) -> None:
|
||||
del event, ctx
|
||||
result.chain.append(Plain("\n[persona active]", convert=False))
|
||||
|
||||
@on_event("after_message_sent")
|
||||
async def persist(self, event: MessageEvent, ctx: Context) -> None:
|
||||
reply = str(event.get_extra("last_reply", "") or "").strip()
|
||||
if reply:
|
||||
await ctx.db.set("sample:last_reply", reply)
|
||||
```
|
||||
|
||||
#### 注意事项
|
||||
|
||||
1. 用于处理非消息类型的事件(如群成员变动、好友请求等)
|
||||
1. 可用于处理平台事件,也可用于处理 AstrBot 原生消息链路中的系统事件(如 `llm_request`)
|
||||
2. 不能与 `@rate_limit` 或 `@cooldown` 一起使用
|
||||
3. 不同平台的事件类型可能不同,需要查阅平台文档
|
||||
4. `llm_request` 和 `decorating_result` 注入的是可变对象,修改会回写到 AstrBot 主链路
|
||||
5. `llm_response` 主要用于观测和提取结果,不应用来替代主回复流程
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -642,6 +642,15 @@ CONVERSATION_GET_OUTPUT_SCHEMA = _object_schema(
|
||||
required=("conversation",),
|
||||
conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
|
||||
)
|
||||
CONVERSATION_GET_CURRENT_INPUT_SCHEMA = _object_schema(
|
||||
required=("session",),
|
||||
session={"type": "string"},
|
||||
create_if_not_exists={"type": "boolean"},
|
||||
)
|
||||
CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA = _object_schema(
|
||||
required=("conversation",),
|
||||
conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
|
||||
)
|
||||
CONVERSATION_LIST_INPUT_SCHEMA = _object_schema(
|
||||
session=_nullable({"type": "string"}),
|
||||
platform_id=_nullable({"type": "string"}),
|
||||
@@ -1207,6 +1216,10 @@ BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = {
|
||||
"input": CONVERSATION_GET_INPUT_SCHEMA,
|
||||
"output": CONVERSATION_GET_OUTPUT_SCHEMA,
|
||||
},
|
||||
"conversation.get_current": {
|
||||
"input": CONVERSATION_GET_CURRENT_INPUT_SCHEMA,
|
||||
"output": CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA,
|
||||
},
|
||||
"conversation.list": {
|
||||
"input": CONVERSATION_LIST_INPUT_SCHEMA,
|
||||
"output": CONVERSATION_LIST_OUTPUT_SCHEMA,
|
||||
@@ -1653,6 +1666,8 @@ __all__ = [
|
||||
"CONVERSATION_CREATE_SCHEMA",
|
||||
"CONVERSATION_DELETE_INPUT_SCHEMA",
|
||||
"CONVERSATION_DELETE_OUTPUT_SCHEMA",
|
||||
"CONVERSATION_GET_CURRENT_INPUT_SCHEMA",
|
||||
"CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA",
|
||||
"CONVERSATION_GET_INPUT_SCHEMA",
|
||||
"CONVERSATION_GET_OUTPUT_SCHEMA",
|
||||
"CONVERSATION_LIST_INPUT_SCHEMA",
|
||||
|
||||
@@ -96,7 +96,7 @@ def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]:
|
||||
"""
|
||||
values = [0.0] * _MOCK_EMBEDDING_DIM
|
||||
for term in _embedding_terms(text):
|
||||
digest = hashlib.sha256(f"{provider_id}:{term}".encode("utf-8")).digest()
|
||||
digest = hashlib.sha256(f"{provider_id}:{term}".encode()).digest()
|
||||
index = int.from_bytes(digest[:2], "big") % _MOCK_EMBEDDING_DIM
|
||||
values[index] += 1.0 + min(len(term), 8) * 0.05
|
||||
norm = math.sqrt(sum(value * value for value in values))
|
||||
@@ -2672,6 +2672,25 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
return {"conversation": None}
|
||||
return {"conversation": dict(record)}
|
||||
|
||||
async def _conversation_get_current(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
session = str(payload.get("session", "")).strip()
|
||||
conversation_id = self._session_current_conversation_ids.get(session, "")
|
||||
if not conversation_id and bool(payload.get("create_if_not_exists", False)):
|
||||
created = await self._conversation_new(
|
||||
_request_id,
|
||||
{"session": session, "conversation": {}},
|
||||
_token,
|
||||
)
|
||||
conversation_id = str(created.get("conversation_id", "")).strip()
|
||||
if not conversation_id:
|
||||
return {"conversation": None}
|
||||
record = self._conversation_store.get(conversation_id)
|
||||
if record is None or str(record.get("session", "")) != session:
|
||||
return {"conversation": None}
|
||||
return {"conversation": dict(record)}
|
||||
|
||||
async def _conversation_list(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
@@ -2824,6 +2843,10 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
self._builtin_descriptor("conversation.get", "获取对话"),
|
||||
call_handler=self._conversation_get,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.get_current", "获取当前对话"),
|
||||
call_handler=self._conversation_get_current,
|
||||
)
|
||||
self.register(
|
||||
self._builtin_descriptor("conversation.list", "列出对话"),
|
||||
call_handler=self._conversation_list,
|
||||
|
||||
@@ -39,6 +39,7 @@ from .._invocation_context import caller_plugin_scope
|
||||
from .._plugin_logger import PluginLogger
|
||||
from .._star_runtime import bind_star_runtime
|
||||
from .._typing_utils import unwrap_optional
|
||||
from ..clients.llm import LLMResponse
|
||||
from ..context import CancelToken, Context
|
||||
from ..conversation import (
|
||||
DEFAULT_BUSY_MESSAGE,
|
||||
@@ -49,8 +50,13 @@ from ..conversation import (
|
||||
)
|
||||
from ..events import MessageEvent
|
||||
from ..filters import LocalFilterBinding
|
||||
from ..llm.entities import ProviderRequest
|
||||
from ..message_components import BaseMessageComponent
|
||||
from ..message_result import MessageChain, MessageEventResult, coerce_message_chain
|
||||
from ..message_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
coerce_message_chain,
|
||||
)
|
||||
from ..protocol.descriptors import (
|
||||
CommandTrigger,
|
||||
MessageTrigger,
|
||||
@@ -76,6 +82,13 @@ class _ActiveConversation:
|
||||
task: asyncio.Task[Any]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _InjectedEventPayloads:
|
||||
provider_request: ProviderRequest | None = None
|
||||
llm_response: LLMResponse | None = None
|
||||
event_result: MessageEventResult | None = None
|
||||
|
||||
|
||||
class HandlerDispatcher:
|
||||
def __init__(
|
||||
self, *, plugin_id: str, peer, handlers: Sequence[LoadedHandler]
|
||||
@@ -205,6 +218,8 @@ class HandlerDispatcher:
|
||||
schedule_context: ScheduleContext | None = None,
|
||||
) -> dict[str, Any]:
|
||||
summary = {"sent_message": False, "stop": False, "call_llm": False}
|
||||
injected_payloads = _InjectedEventPayloads()
|
||||
event_type = self._event_type_name(event)
|
||||
try:
|
||||
limiter = loaded.limiter
|
||||
if limiter is not None:
|
||||
@@ -254,6 +269,7 @@ class HandlerDispatcher:
|
||||
plugin_id=self._resolve_plugin_id(loaded),
|
||||
handler_ref=loaded.descriptor.id,
|
||||
schedule_context=schedule_context,
|
||||
injected_payloads=injected_payloads,
|
||||
)
|
||||
)
|
||||
if inspect.isasyncgen(result):
|
||||
@@ -263,6 +279,11 @@ class HandlerDispatcher:
|
||||
await self._handle_result_item(item, event, ctx),
|
||||
)
|
||||
summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
|
||||
self._append_injected_payloads(
|
||||
summary,
|
||||
injected_payloads,
|
||||
event_type=event_type,
|
||||
)
|
||||
return summary
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
@@ -272,6 +293,11 @@ class HandlerDispatcher:
|
||||
await self._handle_result_item(result, event, ctx),
|
||||
)
|
||||
summary["stop"] = bool(summary.get("stop")) or event.is_stopped()
|
||||
self._append_injected_payloads(
|
||||
summary,
|
||||
injected_payloads,
|
||||
event_type=event_type,
|
||||
)
|
||||
return summary
|
||||
except Exception as exc:
|
||||
await self._handle_error(
|
||||
@@ -339,6 +365,7 @@ class HandlerDispatcher:
|
||||
handler_ref: str | None = None,
|
||||
schedule_context: ScheduleContext | None = None,
|
||||
conversation_session: ConversationSession | None = None,
|
||||
injected_payloads: _InjectedEventPayloads | None = None,
|
||||
) -> list[Any]:
|
||||
"""构建 handler 参数列表。"""
|
||||
from loguru import logger
|
||||
@@ -371,6 +398,7 @@ class HandlerDispatcher:
|
||||
ctx,
|
||||
schedule_context,
|
||||
conversation_session,
|
||||
injected_payloads=injected_payloads,
|
||||
)
|
||||
|
||||
# 2. Fallback 按名字注入
|
||||
@@ -423,6 +451,8 @@ class HandlerDispatcher:
|
||||
if not str(key).startswith("__command_")
|
||||
}
|
||||
)
|
||||
if not isinstance(loaded.descriptor.trigger, CommandTrigger):
|
||||
return parsed_args, None
|
||||
model_param = resolve_command_model_param(loaded.callable)
|
||||
if model_param is None:
|
||||
return parsed_args, None
|
||||
@@ -584,6 +614,8 @@ class HandlerDispatcher:
|
||||
ctx: Context,
|
||||
schedule_context: ScheduleContext | None,
|
||||
conversation_session: ConversationSession | None,
|
||||
*,
|
||||
injected_payloads: _InjectedEventPayloads | None = None,
|
||||
) -> Any:
|
||||
"""根据类型注解注入参数。"""
|
||||
param_type, _is_optional = unwrap_optional(param_type)
|
||||
@@ -612,9 +644,117 @@ class HandlerDispatcher:
|
||||
isinstance(param_type, type) and issubclass(param_type, ConversationSession)
|
||||
):
|
||||
return conversation_session
|
||||
if param_type is ProviderRequest or (
|
||||
isinstance(param_type, type) and issubclass(param_type, ProviderRequest)
|
||||
):
|
||||
return self._inject_provider_request(event, injected_payloads)
|
||||
if param_type is LLMResponse or (
|
||||
isinstance(param_type, type) and issubclass(param_type, LLMResponse)
|
||||
):
|
||||
return self._inject_llm_response(event, injected_payloads)
|
||||
if param_type is MessageEventResult or (
|
||||
isinstance(param_type, type) and issubclass(param_type, MessageEventResult)
|
||||
):
|
||||
return self._inject_event_result(event, injected_payloads)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _event_type_name(event: MessageEvent) -> str:
|
||||
raw = event.raw if isinstance(event.raw, dict) else {}
|
||||
value = raw.get("event_type") or raw.get("type")
|
||||
return str(value or "")
|
||||
|
||||
@staticmethod
|
||||
def _payload_from_event(event: MessageEvent, key: str) -> dict[str, Any] | None:
|
||||
raw = event.raw if isinstance(event.raw, dict) else {}
|
||||
payload = raw.get(key)
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
nested_raw = raw.get("raw")
|
||||
if isinstance(nested_raw, dict):
|
||||
nested_payload = nested_raw.get(key)
|
||||
if isinstance(nested_payload, dict):
|
||||
return nested_payload
|
||||
return None
|
||||
|
||||
def _inject_provider_request(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
injected_payloads: _InjectedEventPayloads | None,
|
||||
) -> ProviderRequest | None:
|
||||
if injected_payloads is None:
|
||||
payload = self._payload_from_event(event, "provider_request")
|
||||
return (
|
||||
ProviderRequest.from_payload(payload) if payload is not None else None
|
||||
)
|
||||
if injected_payloads.provider_request is None:
|
||||
payload = self._payload_from_event(event, "provider_request")
|
||||
if payload is None:
|
||||
return None
|
||||
injected_payloads.provider_request = ProviderRequest.from_payload(payload)
|
||||
return injected_payloads.provider_request
|
||||
|
||||
def _inject_llm_response(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
injected_payloads: _InjectedEventPayloads | None,
|
||||
) -> LLMResponse | None:
|
||||
if injected_payloads is None:
|
||||
payload = self._payload_from_event(event, "llm_response")
|
||||
return LLMResponse.model_validate(payload) if payload is not None else None
|
||||
if injected_payloads.llm_response is None:
|
||||
payload = self._payload_from_event(event, "llm_response")
|
||||
if payload is None:
|
||||
return None
|
||||
injected_payloads.llm_response = LLMResponse.model_validate(payload)
|
||||
return injected_payloads.llm_response
|
||||
|
||||
def _inject_event_result(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
injected_payloads: _InjectedEventPayloads | None,
|
||||
) -> MessageEventResult | None:
|
||||
if injected_payloads is None:
|
||||
payload = self._payload_from_event(event, "event_result")
|
||||
return (
|
||||
MessageEventResult.from_payload(payload)
|
||||
if payload is not None
|
||||
else None
|
||||
)
|
||||
if injected_payloads.event_result is None:
|
||||
payload = self._payload_from_event(event, "event_result")
|
||||
if payload is None:
|
||||
return None
|
||||
injected_payloads.event_result = MessageEventResult.from_payload(payload)
|
||||
return injected_payloads.event_result
|
||||
|
||||
@staticmethod
|
||||
def _append_injected_payloads(
|
||||
summary: dict[str, Any],
|
||||
injected_payloads: _InjectedEventPayloads,
|
||||
*,
|
||||
event_type: str,
|
||||
) -> None:
|
||||
if (
|
||||
event_type == "llm_request"
|
||||
and injected_payloads.provider_request is not None
|
||||
):
|
||||
summary["provider_request"] = (
|
||||
injected_payloads.provider_request.to_payload()
|
||||
)
|
||||
elif (
|
||||
event_type == "llm_response" and injected_payloads.llm_response is not None
|
||||
):
|
||||
summary["llm_response"] = injected_payloads.llm_response.model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
elif (
|
||||
event_type == "decorating_result"
|
||||
and injected_payloads.event_result is not None
|
||||
):
|
||||
summary["event_result"] = injected_payloads.event_result.to_payload()
|
||||
|
||||
def _format_handler_injection_error(
|
||||
self,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user