feat: add session and system capabilities for plugin management and event handling

- Implemented SessionCapabilityMixin with methods to manage session-level plugin states and handlers.
- Added SystemCapabilityMixin to handle system-level functionalities including file management, event handling, and dynamic command registration.
- Introduced methods for enabling/disabling plugins, filtering handlers, and managing LLM and TTS service states.
- Registered various system capabilities for data directory access, HTML rendering, and event streaming.
This commit is contained in:
whatevertogo
2026-03-19 07:35:01 +08:00
parent 5a46321a04
commit e961e36170
17 changed files with 3650 additions and 3398 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from .bridge_base import CapabilityRouterBridgeBase
from .capabilities import (
ConversationCapabilityMixin,
DBCapabilityMixin,
HttpCapabilityMixin,
KnowledgeBaseCapabilityMixin,
LLMCapabilityMixin,
MemoryCapabilityMixin,
MetadataCapabilityMixin,
PersonaCapabilityMixin,
PlatformCapabilityMixin,
ProviderCapabilityMixin,
SessionCapabilityMixin,
SystemCapabilityMixin,
)
class BuiltinCapabilityRouterMixin(
LLMCapabilityMixin,
MemoryCapabilityMixin,
DBCapabilityMixin,
PlatformCapabilityMixin,
HttpCapabilityMixin,
MetadataCapabilityMixin,
ProviderCapabilityMixin,
SessionCapabilityMixin,
PersonaCapabilityMixin,
ConversationCapabilityMixin,
KnowledgeBaseCapabilityMixin,
SystemCapabilityMixin,
CapabilityRouterBridgeBase,
):
def _register_builtin_capabilities(self) -> None:
self._register_llm_capabilities()
self._register_memory_capabilities()
self._register_db_capabilities()
self._register_platform_capabilities()
self._register_http_capabilities()
self._register_metadata_capabilities()
self._register_provider_capabilities()
self._register_agent_tool_capabilities()
self._register_session_capabilities()
self._register_persona_capabilities()
self._register_conversation_capabilities()
self._register_kb_capabilities()
self._register_provider_manager_capabilities()
self._register_platform_manager_capabilities()
self._register_system_capabilities()
__all__ = ["BuiltinCapabilityRouterMixin"]

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Any
from ...protocol.descriptors import CapabilityDescriptor
class CapabilityRouterHost:
memory_store: dict[str, dict[str, Any]]
_memory_index: dict[str, dict[str, Any]]
_memory_dirty_keys: set[str]
_memory_expires_at: dict[str, datetime | None]
db_store: dict[str, Any]
sent_messages: list[dict[str, Any]]
event_actions: list[dict[str, Any]]
http_api_store: list[dict[str, Any]]
_event_streams: dict[str, dict[str, Any]]
_plugins: dict[str, Any]
_request_overlays: dict[str, dict[str, Any]]
_provider_catalog: dict[str, list[dict[str, Any]]]
_provider_configs: dict[str, dict[str, Any]]
_active_provider_ids: dict[str, str | None]
_provider_change_subscriptions: dict[str, asyncio.Queue[dict[str, Any]]]
_system_data_root: Path
_session_waiters: dict[str, set[str]]
_session_plugin_configs: dict[str, dict[str, Any]]
_session_service_configs: dict[str, dict[str, Any]]
_db_watch_subscriptions: dict[str, tuple[str | None, asyncio.Queue[dict[str, Any]]]]
_dynamic_command_routes: dict[str, list[dict[str, Any]]]
_file_token_store: dict[str, str]
_platform_instances: list[dict[str, Any]]
_persona_store: dict[str, dict[str, Any]]
_conversation_store: dict[str, dict[str, Any]]
_session_current_conversation_ids: dict[str, str]
_kb_store: dict[str, dict[str, Any]]
def register(
self,
descriptor: CapabilityDescriptor,
*,
call_handler=None,
stream_handler=None,
finalize=None,
exposed: bool = True,
) -> None:
raise NotImplementedError
def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None:
raise NotImplementedError
@staticmethod
def _require_caller_plugin_id(capability_name: str) -> str:
raise NotImplementedError
def register_dynamic_command_route(
self,
*,
plugin_id: str,
command_name: str,
handler_full_name: str,
desc: str = "",
priority: int = 0,
use_regex: bool = False,
) -> None:
raise NotImplementedError
def get_platform_instances(self) -> list[dict[str, Any]]:
raise NotImplementedError
def _register_agent_tool_capabilities(self) -> None:
raise NotImplementedError
def _provider_entry(
self,
payload: dict[str, Any],
capability_name: str,
expected_kind: str | None = None,
) -> dict[str, Any]:
raise NotImplementedError
async def _provider_embedding_get_embedding(
self, request_id: str, payload: dict[str, Any], token
) -> dict[str, Any]:
raise NotImplementedError
async def _provider_embedding_get_embeddings(
self, request_id: str, payload: dict[str, Any], token
) -> dict[str, Any]:
raise NotImplementedError

View File

@@ -0,0 +1,183 @@
from __future__ import annotations
import copy
import hashlib
import math
import re
from datetime import datetime, timezone
from typing import Any
from ...protocol.descriptors import (
BUILTIN_CAPABILITY_SCHEMAS,
CapabilityDescriptor,
SessionRef,
)
from ._host import CapabilityRouterHost
def _clone_target_payload(value: Any) -> dict[str, Any] | None:
if not isinstance(value, dict):
return None
return {str(key): item for key, item in value.items()}
def _clone_chain_payload(value: Any) -> list[dict[str, Any]]:
if not isinstance(value, list):
return []
return [
{str(key): item for key, item in chunk.items()}
for chunk in value
if isinstance(chunk, dict)
]
_MOCK_EMBEDDING_DIM = 24
def _embedding_terms(text: str) -> list[str]:
"""Build stable tokens for the mock embedding implementation."""
normalized = re.sub(r"\s+", " ", str(text).strip().casefold())
compact = normalized.replace(" ", "")
if not normalized:
return []
terms = [word for word in re.findall(r"\w+", normalized, flags=re.UNICODE) if word]
if compact:
if len(compact) == 1:
terms.append(compact)
else:
terms.extend(
compact[index : index + 2] for index in range(len(compact) - 1)
)
terms.append(compact)
return terms or [normalized]
def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]:
"""Generate a deterministic normalized mock embedding vector."""
values = [0.0] * _MOCK_EMBEDDING_DIM
for term in _embedding_terms(text):
digest = hashlib.sha256(f"{provider_id}:{term}".encode()).digest()
index = int.from_bytes(digest[:2], "big") % _MOCK_EMBEDDING_DIM
values[index] += 1.0 + min(len(term), 8) * 0.05
norm = math.sqrt(sum(value * value for value in values))
if norm <= 0:
return values
return [value / norm for value in values]
class CapabilityRouterBridgeBase(CapabilityRouterHost):
def _builtin_descriptor(
self,
name: str,
description: str,
*,
supports_stream: bool = False,
cancelable: bool = False,
) -> CapabilityDescriptor:
schema = BUILTIN_CAPABILITY_SCHEMAS[name]
return CapabilityDescriptor(
name=name,
description=description,
input_schema=copy.deepcopy(schema["input"]),
output_schema=copy.deepcopy(schema["output"]),
supports_stream=supports_stream,
cancelable=cancelable,
)
def _resolve_target(
self, payload: dict[str, Any]
) -> tuple[str, dict[str, Any] | None]:
target_payload = payload.get("target")
if isinstance(target_payload, dict):
target = SessionRef.model_validate(target_payload)
return target.session, target.to_payload()
return str(payload.get("session", "")), None
@staticmethod
def _is_group_session(session: str) -> bool:
normalized = str(session).lower()
return ":group:" in normalized or ":groupmessage:" in normalized
@staticmethod
def _mock_group_payload(session: str) -> dict[str, Any] | None:
if not CapabilityRouterBridgeBase._is_group_session(session):
return None
members = [
{
"user_id": f"{session}:member-1",
"nickname": "Member 1",
"role": "member",
},
{
"user_id": f"{session}:member-2",
"nickname": "Member 2",
"role": "admin",
},
]
return {
"group_id": session.rsplit(":", maxsplit=1)[-1],
"group_name": f"Mock Group {session.rsplit(':', maxsplit=1)[-1]}",
"group_avatar": "",
"group_owner": members[0]["user_id"],
"group_admins": [members[1]["user_id"]],
"members": members,
}
def _session_plugin_config(self, session: str) -> dict[str, Any]:
config = self._session_plugin_configs.get(str(session), {})
return dict(config) if isinstance(config, dict) else {}
def _session_service_config(self, session: str) -> dict[str, Any]:
config = self._session_service_configs.get(str(session), {})
return dict(config) if isinstance(config, dict) else {}
@staticmethod
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
@staticmethod
def _session_platform_id(session: str) -> str:
parts = str(session).split(":", maxsplit=1)
if parts and parts[0].strip():
return parts[0].strip()
return "unknown"
@staticmethod
def _normalize_history_payload(value: Any) -> list[dict[str, Any]]:
if not isinstance(value, list):
return []
return [dict(item) for item in value if isinstance(item, dict)]
@staticmethod
def _normalize_persona_dialogs_payload(value: Any) -> list[str]:
if not isinstance(value, list):
return []
return [str(item) for item in value if isinstance(item, str)]
@staticmethod
def _optional_int(value: Any) -> int | None:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def _provider_entry(
self,
payload: dict[str, Any],
capability_name: str,
expected_kind: str | None = None,
) -> dict[str, Any]:
raise NotImplementedError
async def _provider_embedding_get_embedding(
self, request_id: str, payload: dict[str, Any], token
) -> dict[str, Any]:
raise NotImplementedError
async def _provider_embedding_get_embeddings(
self, request_id: str, payload: dict[str, Any], token
) -> dict[str, Any]:
raise NotImplementedError

View File

@@ -0,0 +1,27 @@
from .conversation import ConversationCapabilityMixin
from .db import DBCapabilityMixin
from .http import HttpCapabilityMixin
from .kb import KnowledgeBaseCapabilityMixin
from .llm import LLMCapabilityMixin
from .memory import MemoryCapabilityMixin
from .metadata import MetadataCapabilityMixin
from .persona import PersonaCapabilityMixin
from .platform import PlatformCapabilityMixin
from .provider import ProviderCapabilityMixin
from .session import SessionCapabilityMixin
from .system import SystemCapabilityMixin
__all__ = [
"ConversationCapabilityMixin",
"DBCapabilityMixin",
"HttpCapabilityMixin",
"KnowledgeBaseCapabilityMixin",
"LLMCapabilityMixin",
"MemoryCapabilityMixin",
"MetadataCapabilityMixin",
"PersonaCapabilityMixin",
"PlatformCapabilityMixin",
"ProviderCapabilityMixin",
"SessionCapabilityMixin",
"SystemCapabilityMixin",
]

View File

@@ -0,0 +1,232 @@
from __future__ import annotations
import uuid
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class ConversationCapabilityMixin(CapabilityRouterBridgeBase):
async def _conversation_new(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
if not session:
raise AstrBotError.invalid_input("conversation.new requires session")
raw_conversation = payload.get("conversation")
if raw_conversation is None:
raw_conversation = {}
if not isinstance(raw_conversation, dict):
raise AstrBotError.invalid_input(
"conversation.new requires conversation object"
)
conversation_id = uuid.uuid4().hex
now = self._now_iso()
record = {
"conversation_id": conversation_id,
"session": session,
"platform_id": (
str(raw_conversation.get("platform_id"))
if raw_conversation.get("platform_id") is not None
else self._session_platform_id(session)
),
"history": self._normalize_history_payload(raw_conversation.get("history")),
"title": (
str(raw_conversation.get("title"))
if raw_conversation.get("title") is not None
else None
),
"persona_id": (
str(raw_conversation.get("persona_id"))
if raw_conversation.get("persona_id") is not None
else None
),
"created_at": now,
"updated_at": now,
"token_usage": None,
}
self._conversation_store[conversation_id] = record
self._session_current_conversation_ids[session] = conversation_id
return {"conversation_id": conversation_id}
async def _conversation_switch(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = str(payload.get("conversation_id", "")).strip()
record = self._conversation_store.get(conversation_id)
if record is None or str(record.get("session", "")) != session:
raise AstrBotError.invalid_input(
"conversation.switch requires a conversation in the same session"
)
self._session_current_conversation_ids[session] = conversation_id
return {}
async def _conversation_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = payload.get("conversation_id")
normalized_conversation_id = (
str(conversation_id).strip() if conversation_id is not None else ""
)
if not normalized_conversation_id:
normalized_conversation_id = self._session_current_conversation_ids.get(
session, ""
)
if not normalized_conversation_id:
return {}
record = self._conversation_store.get(normalized_conversation_id)
if record is None:
return {}
if str(record.get("session", "")) != session:
raise AstrBotError.invalid_input(
"conversation.delete requires a conversation in the same session"
)
del self._conversation_store[normalized_conversation_id]
current_conversation_id = self._session_current_conversation_ids.get(session)
if current_conversation_id == normalized_conversation_id:
replacement = next(
(
conversation_id
for conversation_id, item in self._conversation_store.items()
if str(item.get("session", "")) == session
),
None,
)
if replacement is None:
self._session_current_conversation_ids.pop(session, None)
else:
self._session_current_conversation_ids[session] = replacement
return {}
async def _conversation_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = str(payload.get("conversation_id", "")).strip()
record = self._conversation_store.get(conversation_id)
if record is None and bool(payload.get("create_if_not_exists", False)):
created = await self._conversation_new(
_request_id,
{"session": session, "conversation": {}},
_token,
)
record = self._conversation_store.get(
str(created.get("conversation_id", "")).strip()
)
if record is None:
return {"conversation": None}
if str(record.get("session", "")) != session:
return {"conversation": None}
return {"conversation": dict(record)}
async def _conversation_get_current(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = self._session_current_conversation_ids.get(session, "")
if not conversation_id and bool(payload.get("create_if_not_exists", False)):
created = await self._conversation_new(
_request_id,
{"session": session, "conversation": {}},
_token,
)
conversation_id = str(created.get("conversation_id", "")).strip()
if not conversation_id:
return {"conversation": None}
record = self._conversation_store.get(conversation_id)
if record is None or str(record.get("session", "")) != session:
return {"conversation": None}
return {"conversation": dict(record)}
async def _conversation_list(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = payload.get("session")
platform_id = payload.get("platform_id")
conversations = []
for conversation_id in sorted(self._conversation_store.keys()):
item = self._conversation_store[conversation_id]
if session is not None and str(item.get("session", "")) != str(session):
continue
if platform_id is not None and str(item.get("platform_id", "")) != str(
platform_id
):
continue
conversations.append(dict(item))
return {"conversations": conversations}
async def _conversation_update(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", "")).strip()
conversation_id = payload.get("conversation_id")
normalized_conversation_id = (
str(conversation_id).strip() if conversation_id is not None else ""
)
if not normalized_conversation_id:
normalized_conversation_id = self._session_current_conversation_ids.get(
session, ""
)
if not normalized_conversation_id:
return {}
record = self._conversation_store.get(normalized_conversation_id)
if record is None:
return {}
if str(record.get("session", "")) != session:
raise AstrBotError.invalid_input(
"conversation.update requires a conversation in the same session"
)
raw_conversation = payload.get("conversation")
if not isinstance(raw_conversation, dict):
raw_conversation = {}
if "history" in raw_conversation:
history = raw_conversation.get("history")
record["history"] = (
self._normalize_history_payload(history) if history is not None else []
)
if "title" in raw_conversation:
title = raw_conversation.get("title")
record["title"] = str(title) if title is not None else None
if "persona_id" in raw_conversation:
persona_id = raw_conversation.get("persona_id")
record["persona_id"] = str(persona_id) if persona_id is not None else None
if "token_usage" in raw_conversation:
token_usage = raw_conversation.get("token_usage")
record["token_usage"] = (
int(token_usage) if token_usage is not None else None
)
record["updated_at"] = self._now_iso()
return {}
def _register_conversation_capabilities(self) -> None:
self.register(
self._builtin_descriptor("conversation.new", "新建对话"),
call_handler=self._conversation_new,
)
self.register(
self._builtin_descriptor("conversation.switch", "切换对话"),
call_handler=self._conversation_switch,
)
self.register(
self._builtin_descriptor("conversation.delete", "删除对话"),
call_handler=self._conversation_delete,
)
self.register(
self._builtin_descriptor("conversation.get", "获取对话"),
call_handler=self._conversation_get,
)
self.register(
self._builtin_descriptor("conversation.get_current", "获取当前对话"),
call_handler=self._conversation_get_current,
)
self.register(
self._builtin_descriptor("conversation.list", "列出对话"),
call_handler=self._conversation_list,
)
self.register(
self._builtin_descriptor("conversation.update", "更新对话"),
call_handler=self._conversation_update,
)

View File

@@ -0,0 +1,129 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from ....errors import AstrBotError
from ..._streaming import StreamExecution
from ..bridge_base import CapabilityRouterBridgeBase
class DBCapabilityMixin(CapabilityRouterBridgeBase):
async def _db_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
return {"value": self.db_store.get(str(payload.get("key", "")))}
async def _db_set(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
value = payload.get("value")
self.db_store[key] = value
self._emit_db_change(op="set", key=key, value=value)
return {}
async def _db_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
self.db_store.pop(key, None)
self._emit_db_change(op="delete", key=key, value=None)
return {}
async def _db_list(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
prefix = payload.get("prefix")
keys = sorted(self.db_store.keys())
if isinstance(prefix, str):
keys = [item for item in keys if item.startswith(prefix)]
return {"keys": keys}
async def _db_get_many(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
keys_payload = payload.get("keys")
if not isinstance(keys_payload, (list, tuple)):
raise AstrBotError.invalid_input("db.get_many 的 keys 必须是数组")
keys = [str(item) for item in keys_payload]
items = [{"key": key, "value": self.db_store.get(key)} for key in keys]
return {"items": items}
async def _db_set_many(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
items_payload = payload.get("items")
if not isinstance(items_payload, (list, tuple)):
raise AstrBotError.invalid_input("db.set_many 的 items 必须是数组")
for entry in items_payload:
if not isinstance(entry, dict):
raise AstrBotError.invalid_input(
"db.set_many 的 items 必须是 object 数组"
)
key = str(entry.get("key", ""))
value = entry.get("value")
self.db_store[key] = value
self._emit_db_change(op="set", key=key, value=value)
return {}
async def _db_watch(
self, request_id: str, payload: dict[str, Any], _token
) -> StreamExecution:
prefix = payload.get("prefix")
prefix_value: str | None
if isinstance(prefix, str):
prefix_value = prefix
elif prefix is None:
prefix_value = None
else:
raise AstrBotError.invalid_input("db.watch 的 prefix 必须是 string 或 null")
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._db_watch_subscriptions[request_id] = (prefix_value, queue)
async def iterator() -> AsyncIterator[dict[str, Any]]:
try:
while True:
yield await queue.get()
finally:
self._db_watch_subscriptions.pop(request_id, None)
return StreamExecution(
iterator=iterator(),
finalize=lambda _chunks: {},
collect_chunks=False,
)
def _register_db_capabilities(self) -> None:
self.register(
self._builtin_descriptor("db.get", "读取 KV"), call_handler=self._db_get
)
self.register(
self._builtin_descriptor("db.set", "写入 KV"), call_handler=self._db_set
)
self.register(
self._builtin_descriptor("db.delete", "删除 KV"),
call_handler=self._db_delete,
)
self.register(
self._builtin_descriptor("db.list", "列出 KV"), call_handler=self._db_list
)
self.register(
self._builtin_descriptor("db.get_many", "批量读取 KV"),
call_handler=self._db_get_many,
)
self.register(
self._builtin_descriptor("db.set_many", "批量写入 KV"),
call_handler=self._db_set_many,
)
self.register(
self._builtin_descriptor(
"db.watch",
"订阅 KV 变更",
supports_stream=True,
cancelable=True,
),
stream_handler=self._db_watch,
)

View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class HttpCapabilityMixin(CapabilityRouterBridgeBase):
async def _http_register_api(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
methods_payload = payload.get("methods")
if not isinstance(methods_payload, list) or not all(
isinstance(item, str) for item in methods_payload
):
raise AstrBotError.invalid_input(
"http.register_api 的 methods 必须是 string 数组"
)
route = str(payload.get("route", "")).strip()
handler_capability = str(payload.get("handler_capability", "")).strip()
if not route or not handler_capability:
raise AstrBotError.invalid_input(
"http.register_api 需要 route 和 handler_capability"
)
plugin_name = self._require_caller_plugin_id("http.register_api")
methods = sorted({method.upper() for method in methods_payload if method})
entry: dict[str, Any] = {
"route": route,
"methods": methods,
"handler_capability": handler_capability,
"description": str(payload.get("description", "")),
"plugin_id": plugin_name,
}
self.http_api_store = [
item
for item in self.http_api_store
if not (
item.get("route") == route
and item.get("plugin_id") == entry["plugin_id"]
and item.get("methods") == methods
)
]
self.http_api_store.append(entry)
return {}
async def _http_unregister_api(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
route = str(payload.get("route", "")).strip()
methods_payload = payload.get("methods")
if not isinstance(methods_payload, list) or not all(
isinstance(item, str) for item in methods_payload
):
raise AstrBotError.invalid_input(
"http.unregister_api 的 methods 必须是 string 数组"
)
plugin_name = self._require_caller_plugin_id("http.unregister_api")
methods = {method.upper() for method in methods_payload if method}
updated: list[dict[str, Any]] = []
for entry in self.http_api_store:
if entry.get("route") != route:
updated.append(entry)
continue
if entry.get("plugin_id") != plugin_name:
updated.append(entry)
continue
if not methods:
continue
remaining_methods = [
method for method in entry.get("methods", []) if method not in methods
]
if remaining_methods:
updated.append({**entry, "methods": remaining_methods})
self.http_api_store = updated
return {}
async def _http_list_apis(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
plugin_name = self._require_caller_plugin_id("http.list_apis")
apis = [
dict(entry)
for entry in self.http_api_store
if entry.get("plugin_id") == plugin_name
]
return {"apis": apis}
def _register_http_capabilities(self) -> None:
self.register(
self._builtin_descriptor("http.register_api", "注册 HTTP 路由"),
call_handler=self._http_register_api,
)
self.register(
self._builtin_descriptor("http.unregister_api", "注销 HTTP 路由"),
call_handler=self._http_unregister_api,
)
self.register(
self._builtin_descriptor("http.list_apis", "列出 HTTP 路由"),
call_handler=self._http_list_apis,
)

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import uuid
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class KnowledgeBaseCapabilityMixin(CapabilityRouterBridgeBase):
async def _kb_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
kb_id = str(payload.get("kb_id", "")).strip()
record = self._kb_store.get(kb_id)
return {"kb": dict(record) if isinstance(record, dict) else None}
async def _kb_create(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
raw_kb = payload.get("kb")
if not isinstance(raw_kb, dict):
raise AstrBotError.invalid_input("kb.create requires kb object")
embedding_provider_id = str(raw_kb.get("embedding_provider_id", "")).strip()
if not embedding_provider_id:
raise AstrBotError.invalid_input("kb.create requires embedding_provider_id")
kb_id = uuid.uuid4().hex
now = self._now_iso()
record = {
"kb_id": kb_id,
"kb_name": str(raw_kb.get("kb_name", "")),
"description": (
str(raw_kb.get("description"))
if raw_kb.get("description") is not None
else None
),
"emoji": (
str(raw_kb.get("emoji")) if raw_kb.get("emoji") is not None else None
),
"embedding_provider_id": embedding_provider_id,
"rerank_provider_id": (
str(raw_kb.get("rerank_provider_id"))
if raw_kb.get("rerank_provider_id") is not None
else None
),
"chunk_size": self._optional_int(raw_kb.get("chunk_size")),
"chunk_overlap": self._optional_int(raw_kb.get("chunk_overlap")),
"top_k_dense": self._optional_int(raw_kb.get("top_k_dense")),
"top_k_sparse": self._optional_int(raw_kb.get("top_k_sparse")),
"top_m_final": self._optional_int(raw_kb.get("top_m_final")),
"doc_count": 0,
"chunk_count": 0,
"created_at": now,
"updated_at": now,
}
self._kb_store[kb_id] = record
return {"kb": dict(record)}
async def _kb_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
kb_id = str(payload.get("kb_id", "")).strip()
deleted = self._kb_store.pop(kb_id, None) is not None
return {"deleted": deleted}
def _register_kb_capabilities(self) -> None:
self.register(
self._builtin_descriptor("kb.get", "获取知识库"),
call_handler=self._kb_get,
)
self.register(
self._builtin_descriptor("kb.create", "创建知识库"),
call_handler=self._kb_create,
)
self.register(
self._builtin_descriptor("kb.delete", "删除知识库"),
call_handler=self._kb_delete,
)

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from ..bridge_base import CapabilityRouterBridgeBase
class LLMCapabilityMixin(CapabilityRouterBridgeBase):
async def _llm_chat(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
prompt = str(payload.get("prompt", ""))
return {"text": f"Echo: {prompt}"}
async def _llm_chat_raw(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
prompt = str(payload.get("prompt", ""))
text = f"Echo: {prompt}"
return {
"text": text,
"usage": {
"input_tokens": len(prompt),
"output_tokens": len(text),
},
"finish_reason": "stop",
"tool_calls": [],
}
async def _llm_stream(
self,
_request_id: str,
payload: dict[str, Any],
token,
) -> AsyncIterator[dict[str, Any]]:
text = f"Echo: {str(payload.get('prompt', ''))}"
for char in text:
token.raise_if_cancelled()
await asyncio.sleep(0)
yield {"text": char}
def _register_llm_capabilities(self) -> None:
self.register(
self._builtin_descriptor("llm.chat", "发送对话请求,返回文本"),
call_handler=self._llm_chat,
)
self.register(
self._builtin_descriptor("llm.chat_raw", "发送对话请求,返回完整响应"),
call_handler=self._llm_chat_raw,
)
self.register(
self._builtin_descriptor(
"llm.stream_chat",
"流式对话",
supports_stream=True,
cancelable=True,
),
stream_handler=self._llm_stream,
finalize=lambda chunks: {
"text": "".join(item.get("text", "") for item in chunks)
},
)

View File

@@ -0,0 +1,618 @@
from __future__ import annotations
import json
import math
from datetime import datetime, timedelta, timezone
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class MemoryCapabilityMixin(CapabilityRouterBridgeBase):
@staticmethod
def _is_ttl_memory_entry(value: Any) -> bool:
"""判断存储值是否使用了 TTL 包装结构。
Args:
value: 待检查的存储值。
Returns:
bool: 如果值包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
"""
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
@classmethod
def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None:
"""提取用于检索的原始 memory payload。
Args:
stored: memory_store 中保存的原始值。
Returns:
dict[str, Any] | None: 解开 TTL 包装后的字典,无法解析时返回 ``None``。
"""
if not isinstance(stored, dict):
return None
if cls._is_ttl_memory_entry(stored):
value = stored.get("value")
return value if isinstance(value, dict) else None
return stored
@classmethod
def _extract_memory_text(cls, stored: Any) -> str:
"""提取用于检索索引的首选文本。
Args:
stored: memory_store 中保存的原始值。
Returns:
str: 优先使用 ``embedding_text`` / ``content`` 等字段,兜底为 JSON 文本。
"""
value = cls._memory_value_for_search(stored)
if not isinstance(value, dict):
return ""
for field_name in ("embedding_text", "content", "summary", "title", "text"):
item = value.get(field_name)
if isinstance(item, str) and item.strip():
return item.strip()
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
@staticmethod
def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
"""将 TTL 秒数转换为 UTC 过期时间。
Args:
ttl_seconds: TTL 秒数。
Returns:
datetime | None: 绝对过期时间;当输入无效时返回 ``None``。
"""
try:
ttl = int(ttl_seconds)
except (TypeError, ValueError):
return None
if ttl < 1:
return None
return datetime.now(timezone.utc) + timedelta(seconds=ttl)
@staticmethod
def _memory_keyword_score(query: str, key: str, text: str) -> float:
"""计算关键词匹配分数。
Args:
query: 查询文本。
key: memory 条目的键。
text: 已索引的检索文本。
Returns:
float: 基于键名和文本命中的粗粒度关键词分数。
"""
normalized_query = str(query).casefold()
if not normalized_query:
return 1.0
normalized_key = str(key).casefold()
normalized_text = str(text).casefold()
if normalized_query in normalized_key:
return 1.0
if normalized_query in normalized_text:
return 0.9
return 0.0
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
"""计算两个向量之间的余弦相似度。
Args:
left: 左侧向量。
right: 右侧向量。
Returns:
float: 余弦相似度;输入不合法时返回 ``0.0``。
"""
if not left or not right or len(left) != len(right):
return 0.0
left_norm = math.sqrt(sum(value * value for value in left))
right_norm = math.sqrt(sum(value * value for value in right))
if left_norm <= 0 or right_norm <= 0:
return 0.0
return sum(a * b for a, b in zip(left, right, strict=False)) / (
left_norm * right_norm
)
def _resolve_memory_embedding_provider_id(
self,
provider_id: Any,
*,
required: bool,
) -> str | None:
"""解析 memory.search 要使用的 embedding provider。
Args:
provider_id: 调用方显式传入的 provider 标识。
required: 当前检索模式是否强制要求 embedding provider。
Returns:
str | None: 最终选中的 provider 标识;在非强制场景下允许返回 ``None``。
"""
normalized = str(provider_id).strip() if provider_id is not None else ""
if normalized:
self._provider_entry(
{"provider_id": normalized},
"memory.search",
"embedding",
)
return normalized
active_id = self._active_provider_ids.get("embedding")
if active_id is not None:
normalized_active = str(active_id).strip()
if normalized_active:
self._provider_entry(
{"provider_id": normalized_active},
"memory.search",
"embedding",
)
return normalized_active
if required:
raise AstrBotError.invalid_input(
"memory.search requires an embedding provider",
)
return None
@staticmethod
def _memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]:
"""将原始索引项规范化为内部统一结构。
Args:
entry: 当前索引表中的原始项。
text: 当前条目的索引文本。
Returns:
dict[str, Any]: 统一后的索引项,包含 ``text``、``embedding``、``provider_id``。
"""
if isinstance(entry, dict):
return {
"text": str(entry.get("text", text)),
"embedding": (
[float(item) for item in entry.get("embedding", [])]
if isinstance(entry.get("embedding"), list)
else None
),
"provider_id": (
str(entry.get("provider_id")).strip()
if entry.get("provider_id") is not None
else None
),
}
return {"text": text, "embedding": None, "provider_id": None}
def _clear_memory_sidecars(self, key: str) -> None:
"""清理指定 memory 键对应的所有 sidecar 状态。
Args:
key: memory 条目的键。
Returns:
None
"""
self._memory_index.pop(key, None)
self._memory_expires_at.pop(key, None)
self._memory_dirty_keys.discard(key)
def _delete_memory_entry(self, key: str) -> bool:
"""删除 memory 条目并同步清理 sidecar 状态。
Args:
key: memory 条目的键。
Returns:
bool: 条目存在并删除成功时返回 ``True``。
"""
deleted = self.memory_store.pop(key, None) is not None
self._clear_memory_sidecars(key)
return deleted
def _upsert_memory_sidecars(
self,
key: str,
stored: dict[str, Any],
*,
expires_at: datetime | None = None,
) -> None:
"""创建或更新单条 memory 的 sidecar 索引状态。
Args:
key: memory 条目的键。
stored: 需要建立索引的原始存储值。
expires_at: 可选的绝对过期时间。
Returns:
None
"""
self._memory_index[key] = {
"text": self._extract_memory_text(stored),
"embedding": None,
"provider_id": None,
}
if expires_at is None:
self._memory_expires_at.pop(key, None)
else:
self._memory_expires_at[key] = expires_at
self._memory_dirty_keys.add(key)
def _ensure_memory_sidecars(self, key: str, stored: Any) -> None:
"""确保 sidecar 状态与当前存储值保持一致。
Args:
key: memory 条目的键。
stored: memory_store 中的当前存储值。
Returns:
None
"""
if not isinstance(stored, dict):
return
text = self._extract_memory_text(stored)
existed = key in self._memory_index
entry = self._memory_index_entry(self._memory_index.get(key), text=text)
if entry["text"] != text:
entry["text"] = text
entry["embedding"] = None
entry["provider_id"] = None
self._memory_dirty_keys.add(key)
self._memory_index[key] = entry
if not existed:
self._memory_dirty_keys.add(key)
def _is_memory_expired(self, key: str) -> bool:
"""判断 memory 条目是否已过期。
Args:
key: memory 条目的键。
Returns:
bool: 如果当前时间已超过记录的过期时间则返回 ``True``。
"""
expires_at = self._memory_expires_at.get(key)
return expires_at is not None and expires_at <= datetime.now(timezone.utc)
def _purge_expired_memory_entry(self, key: str) -> bool:
"""在单条 memory 已过期时立即清理它。
Args:
key: memory 条目的键。
Returns:
bool: 如果条目已过期并被成功清理则返回 ``True``。
"""
if not self._is_memory_expired(key):
return False
self._delete_memory_entry(key)
return True
def _purge_expired_memory_entries(self) -> None:
"""批量清理所有已跟踪的过期 TTL 条目。
Returns:
None
"""
for key in list(self._memory_expires_at):
self._purge_expired_memory_entry(key)
async def _embedding_for_text(
self,
*,
provider_id: str,
text: str,
) -> list[float]:
"""通过 embedding capability 获取单条文本向量。
Args:
provider_id: 使用的 embedding provider 标识。
text: 待向量化的文本。
Returns:
list[float]: provider 返回的向量;异常场景下返回空列表。
"""
output = await self._provider_embedding_get_embedding(
"",
{"provider_id": provider_id, "text": text},
None,
)
embedding = output.get("embedding")
if not isinstance(embedding, list):
return []
return [float(item) for item in embedding]
async def _embeddings_for_texts(
self,
*,
provider_id: str,
texts: list[str],
) -> list[list[float]]:
"""批量获取多条文本的 embedding 向量。
Args:
provider_id: 使用的 embedding provider 标识。
texts: 待向量化的文本列表。
Returns:
list[list[float]]: 与输入顺序对应的向量列表。
"""
if not texts:
return []
output = await self._provider_embedding_get_embeddings(
"",
{"provider_id": provider_id, "texts": texts},
None,
)
embeddings = output.get("embeddings")
if not isinstance(embeddings, list):
return []
return [
[float(value) for value in item]
for item in embeddings
if isinstance(item, list)
]
async def _refresh_memory_embeddings(self, *, provider_id: str) -> None:
"""刷新当前 provider 下脏或过期的 memory 向量索引。
Args:
provider_id: 当前使用的 embedding provider 标识。
Returns:
None
"""
keys_to_refresh: list[str] = []
texts_to_refresh: list[str] = []
for key, stored in self.memory_store.items():
self._ensure_memory_sidecars(key, stored)
entry = self._memory_index_entry(
self._memory_index.get(key),
text=self._extract_memory_text(stored),
)
should_refresh = (
key in self._memory_dirty_keys
or entry["embedding"] is None
or entry["provider_id"] != provider_id
)
self._memory_index[key] = entry
if should_refresh:
keys_to_refresh.append(key)
texts_to_refresh.append(str(entry["text"]))
embeddings = await self._embeddings_for_texts(
provider_id=provider_id,
texts=texts_to_refresh,
)
for index, key in enumerate(keys_to_refresh):
entry = self._memory_index_entry(
self._memory_index.get(key),
text=str(texts_to_refresh[index]),
)
entry["embedding"] = embeddings[index] if index < len(embeddings) else []
entry["provider_id"] = provider_id
self._memory_index[key] = entry
self._memory_dirty_keys.discard(key)
async def _memory_search(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
query = str(payload.get("query", ""))
mode = str(payload.get("mode", "auto")).strip().lower() or "auto"
limit = self._optional_int(payload.get("limit"))
raw_min_score = payload.get("min_score")
min_score = float(raw_min_score) if raw_min_score is not None else None
self._purge_expired_memory_entries()
provider_id = self._resolve_memory_embedding_provider_id(
payload.get("provider_id"),
required=mode in {"vector", "hybrid"},
)
effective_mode = mode
if effective_mode == "auto":
effective_mode = "hybrid" if provider_id is not None else "keyword"
query_embedding: list[float] | None = None
if effective_mode in {"vector", "hybrid"}:
if provider_id is None:
raise AstrBotError.invalid_input(
"memory.search requires an embedding provider",
)
await self._refresh_memory_embeddings(provider_id=provider_id)
query_embedding = await self._embedding_for_text(
provider_id=provider_id,
text=query,
)
items: list[dict[str, Any]] = []
for key, value in self.memory_store.items():
self._ensure_memory_sidecars(key, value)
entry = self._memory_index_entry(
self._memory_index.get(key),
text=self._extract_memory_text(value),
)
text = str(entry.get("text", ""))
keyword_score = self._memory_keyword_score(query, key, text)
vector_score = 0.0
if query_embedding is not None:
embedding = entry.get("embedding")
if isinstance(embedding, list):
vector_score = max(
0.0,
self._cosine_similarity(query_embedding, embedding),
)
if effective_mode == "keyword":
score = keyword_score
elif effective_mode == "vector":
score = vector_score
else:
score = vector_score
if keyword_score > 0:
score = max(score, 0.4 + 0.6 * vector_score)
if score <= 0:
continue
if min_score is not None and score < min_score:
continue
if effective_mode == "keyword" or (keyword_score > 0 and vector_score <= 0):
match_type = "keyword"
elif effective_mode == "vector" or keyword_score <= 0:
match_type = "vector"
else:
match_type = "hybrid"
items.append(
{
"key": key,
"value": self._memory_value_for_search(value),
"score": score,
"match_type": match_type,
}
)
items.sort(key=lambda item: (-float(item["score"]), str(item["key"])))
if limit is not None and limit >= 0:
items = items[:limit]
return {"items": items}
async def _memory_save(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
value = payload.get("value")
if not isinstance(value, dict):
raise AstrBotError.invalid_input("memory.save 的 value 必须是 object")
self.memory_store[key] = value
self._upsert_memory_sidecars(key, value)
return {}
async def _memory_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
if self._purge_expired_memory_entry(key):
return {"value": None}
return {"value": self.memory_store.get(key)}
async def _memory_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._delete_memory_entry(str(payload.get("key", "")))
return {}
async def _memory_save_with_ttl(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
key = str(payload.get("key", ""))
value = payload.get("value")
ttl_seconds = payload.get("ttl_seconds", 0)
if not isinstance(value, dict):
raise AstrBotError.invalid_input(
"memory.save_with_ttl 的 value 必须是 object"
)
stored = {"value": value, "ttl_seconds": ttl_seconds}
self.memory_store[key] = stored
self._upsert_memory_sidecars(
key,
stored,
expires_at=self._memory_expiration_from_ttl(ttl_seconds),
)
return {}
async def _memory_get_many(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
keys_payload = payload.get("keys")
if not isinstance(keys_payload, (list, tuple)):
raise AstrBotError.invalid_input("memory.get_many 的 keys 必须是数组")
keys = [str(item) for item in keys_payload]
items = []
for key in keys:
if self._purge_expired_memory_entry(key):
items.append({"key": key, "value": None})
continue
stored = self.memory_store.get(key)
if (
isinstance(stored, dict)
and "value" in stored
and "ttl_seconds" in stored
):
value = stored["value"]
else:
value = stored
items.append({"key": key, "value": value})
return {"items": items}
async def _memory_delete_many(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
keys_payload = payload.get("keys")
if not isinstance(keys_payload, (list, tuple)):
raise AstrBotError.invalid_input("memory.delete_many 的 keys 必须是数组")
keys = [str(item) for item in keys_payload]
deleted_count = 0
for key in keys:
if self._delete_memory_entry(key):
deleted_count += 1
return {"deleted_count": deleted_count}
async def _memory_stats(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
self._purge_expired_memory_entries()
total_items = len(self.memory_store)
total_bytes = sum(
len(str(key)) + len(str(value)) for key, value in self.memory_store.items()
)
ttl_entries = len(self._memory_expires_at)
indexed_items = len(self._memory_index)
embedded_items = sum(
1
for entry in self._memory_index.values()
if isinstance(entry, dict)
and isinstance(entry.get("embedding"), list)
and bool(entry.get("embedding"))
)
dirty_items = len(self._memory_dirty_keys)
return {
"total_items": total_items,
"total_bytes": total_bytes,
"plugin_id": self._require_caller_plugin_id("memory.stats"),
"ttl_entries": ttl_entries,
"indexed_items": indexed_items,
"embedded_items": embedded_items,
"dirty_items": dirty_items,
}
def _register_memory_capabilities(self) -> None:
self.register(
self._builtin_descriptor("memory.search", "搜索记忆"),
call_handler=self._memory_search,
)
self.register(
self._builtin_descriptor("memory.save", "保存记忆"),
call_handler=self._memory_save,
)
self.register(
self._builtin_descriptor("memory.get", "读取单条记忆"),
call_handler=self._memory_get,
)
self.register(
self._builtin_descriptor("memory.delete", "删除记忆"),
call_handler=self._memory_delete,
)
self.register(
self._builtin_descriptor("memory.save_with_ttl", "保存带过期时间的记忆"),
call_handler=self._memory_save_with_ttl,
)
self.register(
self._builtin_descriptor("memory.get_many", "批量获取记忆"),
call_handler=self._memory_get_many,
)
self.register(
self._builtin_descriptor("memory.delete_many", "批量删除记忆"),
call_handler=self._memory_delete_many,
)
self.register(
self._builtin_descriptor("memory.stats", "获取记忆统计信息"),
call_handler=self._memory_stats,
)

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Any
from ..bridge_base import CapabilityRouterBridgeBase
class MetadataCapabilityMixin(CapabilityRouterBridgeBase):
async def _metadata_get_plugin(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
name = str(payload.get("name", "")).strip()
plugin = self._plugins.get(name)
if plugin is None:
return {"plugin": None}
return {"plugin": dict(plugin.metadata)}
async def _metadata_list_plugins(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
plugins = [
dict(self._plugins[name].metadata) for name in sorted(self._plugins.keys())
]
return {"plugins": plugins}
async def _metadata_get_plugin_config(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
name = str(payload.get("name", "")).strip()
caller_plugin_id = self._require_caller_plugin_id("metadata.get_plugin_config")
if name != caller_plugin_id:
return {"config": None}
plugin = self._plugins.get(name)
if plugin is None:
return {"config": None}
return {"config": dict(plugin.config)}
def _register_metadata_capabilities(self) -> None:
self.register(
self._builtin_descriptor("metadata.get_plugin", "获取单个插件元数据"),
call_handler=self._metadata_get_plugin,
)
self.register(
self._builtin_descriptor("metadata.list_plugins", "列出插件元数据"),
call_handler=self._metadata_list_plugins,
)
self.register(
self._builtin_descriptor(
"metadata.get_plugin_config",
"获取插件配置",
),
call_handler=self._metadata_get_plugin_config,
)

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class PersonaCapabilityMixin(CapabilityRouterBridgeBase):
async def _persona_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
persona_id = str(payload.get("persona_id", "")).strip()
record = self._persona_store.get(persona_id)
if record is None:
raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
return {"persona": dict(record)}
async def _persona_list(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
personas = [
dict(self._persona_store[persona_id])
for persona_id in sorted(self._persona_store.keys())
]
return {"personas": personas}
async def _persona_create(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
raw_persona = payload.get("persona")
if not isinstance(raw_persona, dict):
raise AstrBotError.invalid_input("persona.create requires persona object")
persona_id = str(raw_persona.get("persona_id", "")).strip()
if not persona_id:
raise AstrBotError.invalid_input("persona.create requires persona_id")
if persona_id in self._persona_store:
raise AstrBotError.invalid_input(f"persona already exists: {persona_id}")
now = self._now_iso()
record = {
"persona_id": persona_id,
"system_prompt": str(raw_persona.get("system_prompt", "")),
"begin_dialogs": self._normalize_persona_dialogs_payload(
raw_persona.get("begin_dialogs")
),
"tools": (
[str(item) for item in raw_persona.get("tools", [])]
if isinstance(raw_persona.get("tools"), list)
else None
),
"skills": (
[str(item) for item in raw_persona.get("skills", [])]
if isinstance(raw_persona.get("skills"), list)
else None
),
"custom_error_message": (
str(raw_persona.get("custom_error_message"))
if raw_persona.get("custom_error_message") is not None
else None
),
"folder_id": (
str(raw_persona.get("folder_id"))
if raw_persona.get("folder_id") is not None
else None
),
"sort_order": int(raw_persona.get("sort_order", 0)),
"created_at": now,
"updated_at": now,
}
self._persona_store[persona_id] = record
return {"persona": dict(record)}
async def _persona_update(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
persona_id = str(payload.get("persona_id", "")).strip()
record = self._persona_store.get(persona_id)
if record is None:
return {"persona": None}
raw_persona = payload.get("persona")
if not isinstance(raw_persona, dict):
raise AstrBotError.invalid_input("persona.update requires persona object")
if (
"system_prompt" in raw_persona
and raw_persona.get("system_prompt") is not None
):
record["system_prompt"] = str(raw_persona.get("system_prompt", ""))
if "begin_dialogs" in raw_persona:
begin_dialogs = raw_persona.get("begin_dialogs")
record["begin_dialogs"] = (
self._normalize_persona_dialogs_payload(begin_dialogs)
if begin_dialogs is not None
else []
)
if "tools" in raw_persona:
tools = raw_persona.get("tools")
record["tools"] = (
[str(item) for item in tools] if isinstance(tools, list) else None
)
if "skills" in raw_persona:
skills = raw_persona.get("skills")
record["skills"] = (
[str(item) for item in skills] if isinstance(skills, list) else None
)
if "custom_error_message" in raw_persona:
custom_error_message = raw_persona.get("custom_error_message")
record["custom_error_message"] = (
str(custom_error_message) if custom_error_message is not None else None
)
record["updated_at"] = self._now_iso()
return {"persona": dict(record)}
async def _persona_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
persona_id = str(payload.get("persona_id", "")).strip()
if persona_id not in self._persona_store:
raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
del self._persona_store[persona_id]
return {}
def _register_persona_capabilities(self) -> None:
self.register(
self._builtin_descriptor("persona.get", "获取人格"),
call_handler=self._persona_get,
)
self.register(
self._builtin_descriptor("persona.list", "列出人格"),
call_handler=self._persona_list,
)
self.register(
self._builtin_descriptor("persona.create", "创建人格"),
call_handler=self._persona_create,
)
self.register(
self._builtin_descriptor("persona.update", "更新人格"),
call_handler=self._persona_update,
)
self.register(
self._builtin_descriptor("persona.delete", "删除人格"),
call_handler=self._persona_delete,
)

View File

@@ -0,0 +1,231 @@
from __future__ import annotations
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class PlatformCapabilityMixin(CapabilityRouterBridgeBase):
async def _platform_send(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, target = self._resolve_target(payload)
text = str(payload.get("text", ""))
message_id = f"msg_{len(self.sent_messages) + 1}"
sent: dict[str, Any] = {
"message_id": message_id,
"session": session,
"text": text,
}
if target is not None:
sent["target"] = target
self.sent_messages.append(sent)
return {"message_id": message_id}
async def _platform_send_image(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, target = self._resolve_target(payload)
image_url = str(payload.get("image_url", ""))
message_id = f"img_{len(self.sent_messages) + 1}"
sent: dict[str, Any] = {
"message_id": message_id,
"session": session,
"image_url": image_url,
}
if target is not None:
sent["target"] = target
self.sent_messages.append(sent)
return {"message_id": message_id}
async def _platform_send_chain(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, target = self._resolve_target(payload)
chain = payload.get("chain")
if not isinstance(chain, list) or not all(
isinstance(item, dict) for item in chain
):
raise AstrBotError.invalid_input(
"platform.send_chain 的 chain 必须是 object 数组"
)
message_id = f"chain_{len(self.sent_messages) + 1}"
sent: dict[str, Any] = {
"message_id": message_id,
"session": session,
"chain": [dict(item) for item in chain],
}
if target is not None:
sent["target"] = target
self.sent_messages.append(sent)
return {"message_id": message_id}
async def _platform_send_by_session(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
chain = payload.get("chain")
if not isinstance(chain, list) or not all(
isinstance(item, dict) for item in chain
):
raise AstrBotError.invalid_input(
"platform.send_by_session 的 chain 必须是 object 数组"
)
session = str(payload.get("session", ""))
message_id = f"proactive_{len(self.sent_messages) + 1}"
self.sent_messages.append(
{
"message_id": message_id,
"session": session,
"chain": [dict(item) for item in chain],
}
)
return {"message_id": message_id}
async def _platform_get_group(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, _target = self._resolve_target(payload)
return {"group": self._mock_group_payload(session)}
async def _platform_get_members(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session, _target = self._resolve_target(payload)
group = self._mock_group_payload(session)
if group is None:
return {"members": []}
return {"members": list(group.get("members", []))}
async def _platform_list_instances(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
return {
"platforms": [
{
"id": str(item.get("id", "")),
"name": str(item.get("name", "")),
"type": str(item.get("type", "")),
"status": str(item.get("status", "unknown")),
}
for item in self.get_platform_instances()
if isinstance(item, dict)
]
}
def _register_platform_capabilities(self) -> None:
self.register(
self._builtin_descriptor("platform.send", "发送消息"),
call_handler=self._platform_send,
)
self.register(
self._builtin_descriptor("platform.send_image", "发送图片"),
call_handler=self._platform_send_image,
)
self.register(
self._builtin_descriptor("platform.send_chain", "发送消息链"),
call_handler=self._platform_send_chain,
)
self.register(
self._builtin_descriptor(
"platform.send_by_session", "按会话主动发送消息链"
),
call_handler=self._platform_send_by_session,
)
self.register(
self._builtin_descriptor("platform.get_group", "获取当前群信息"),
call_handler=self._platform_get_group,
)
self.register(
self._builtin_descriptor("platform.get_members", "获取群成员"),
call_handler=self._platform_get_members,
)
self.register(
self._builtin_descriptor("platform.list_instances", "列出平台实例元信息"),
call_handler=self._platform_list_instances,
)
async def _platform_manager_get_by_id(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._require_reserved_plugin("platform.manager.get_by_id")
platform_id = str(payload.get("platform_id", "")).strip()
platform = next(
(
dict(item)
for item in self._platform_instances
if str(item.get("id", "")) == platform_id
),
None,
)
return {"platform": platform}
async def _platform_manager_clear_errors(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._require_reserved_plugin("platform.manager.clear_errors")
platform_id = str(payload.get("platform_id", "")).strip()
for item in self._platform_instances:
if str(item.get("id", "")) != platform_id:
continue
item["errors"] = []
item["last_error"] = None
if str(item.get("status", "")) == "error":
item["status"] = "running"
break
return {}
async def _platform_manager_get_stats(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._require_reserved_plugin("platform.manager.get_stats")
platform_id = str(payload.get("platform_id", "")).strip()
for item in self._platform_instances:
if str(item.get("id", "")) != platform_id:
continue
stats = item.get("stats")
if isinstance(stats, dict):
return {"stats": dict(stats)}
errors = item.get("errors")
last_error = item.get("last_error")
meta = item.get("meta")
return {
"stats": {
"id": platform_id,
"type": str(item.get("type", "")),
"display_name": str(item.get("name", platform_id)),
"status": str(item.get("status", "pending")),
"started_at": item.get("started_at"),
"error_count": len(errors) if isinstance(errors, list) else 0,
"last_error": dict(last_error)
if isinstance(last_error, dict)
else None,
"unified_webhook": bool(item.get("unified_webhook", False)),
"meta": dict(meta) if isinstance(meta, dict) else {},
}
}
return {"stats": None}
def _register_platform_manager_capabilities(self) -> None:
self.register(
self._builtin_descriptor(
"platform.manager.get_by_id",
"按 ID 获取平台管理快照",
),
call_handler=self._platform_manager_get_by_id,
)
self.register(
self._builtin_descriptor(
"platform.manager.clear_errors",
"清除平台错误",
),
call_handler=self._platform_manager_clear_errors,
)
self.register(
self._builtin_descriptor(
"platform.manager.get_stats",
"获取平台统计信息",
),
call_handler=self._platform_manager_get_stats,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import CapabilityRouterBridgeBase
class SessionCapabilityMixin(CapabilityRouterBridgeBase):
async def _session_plugin_is_enabled(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
plugin_name = str(payload.get("plugin_name", ""))
config = self._session_plugin_config(session)
enabled_plugins = {
str(item) for item in config.get("enabled_plugins", []) if str(item).strip()
}
disabled_plugins = {
str(item)
for item in config.get("disabled_plugins", [])
if str(item).strip()
}
if plugin_name in enabled_plugins:
return {"enabled": True}
return {"enabled": plugin_name not in disabled_plugins}
async def _session_plugin_filter_handlers(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
handlers = payload.get("handlers")
if not isinstance(handlers, list):
raise AstrBotError.invalid_input(
"session.plugin.filter_handlers 的 handlers 必须是 object 数组"
)
disabled_plugins = {
str(item)
for item in self._session_plugin_config(session).get("disabled_plugins", [])
if str(item).strip()
}
reserved_plugins = {
str(plugin.metadata.get("name", ""))
for plugin in self._plugins.values()
if bool(plugin.metadata.get("reserved", False))
}
filtered = []
for item in handlers:
if not isinstance(item, dict):
continue
plugin_name = str(item.get("plugin_name", ""))
if (
plugin_name
and plugin_name in disabled_plugins
and plugin_name not in reserved_plugins
):
continue
filtered.append(dict(item))
return {"handlers": filtered}
async def _session_service_is_llm_enabled(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
config = self._session_service_config(session)
return {"enabled": bool(config.get("llm_enabled", True))}
async def _session_service_set_llm_status(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
config = self._session_service_config(session)
config["llm_enabled"] = bool(payload.get("enabled", False))
self._session_service_configs[session] = config
return {}
async def _session_service_is_tts_enabled(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
config = self._session_service_config(session)
return {"enabled": bool(config.get("tts_enabled", True))}
async def _session_service_set_tts_status(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
session = str(payload.get("session", ""))
config = self._session_service_config(session)
config["tts_enabled"] = bool(payload.get("enabled", False))
self._session_service_configs[session] = config
return {}
def _register_session_capabilities(self) -> None:
self.register(
self._builtin_descriptor("session.plugin.is_enabled", "获取会话级插件开关"),
call_handler=self._session_plugin_is_enabled,
)
self.register(
self._builtin_descriptor(
"session.plugin.filter_handlers",
"按会话过滤 handler 元数据",
),
call_handler=self._session_plugin_filter_handlers,
)
self.register(
self._builtin_descriptor(
"session.service.is_llm_enabled",
"获取会话级 LLM 开关",
),
call_handler=self._session_service_is_llm_enabled,
)
self.register(
self._builtin_descriptor(
"session.service.set_llm_status",
"写入会话级 LLM 开关",
),
call_handler=self._session_service_set_llm_status,
)
self.register(
self._builtin_descriptor(
"session.service.is_tts_enabled",
"获取会话级 TTS 开关",
),
call_handler=self._session_service_is_tts_enabled,
)
self.register(
self._builtin_descriptor(
"session.service.set_tts_status",
"写入会话级 TTS 开关",
),
call_handler=self._session_service_set_tts_status,
)

View File

@@ -0,0 +1,454 @@
from __future__ import annotations
import json
import uuid
from typing import Any
from ....errors import AstrBotError
from ..bridge_base import (
CapabilityRouterBridgeBase,
_clone_chain_payload,
_clone_target_payload,
)
class SystemCapabilityMixin(CapabilityRouterBridgeBase):
def _register_system_capabilities(self) -> None:
self.register(
self._builtin_descriptor("system.get_data_dir", "获取插件数据目录"),
call_handler=self._system_get_data_dir,
exposed=False,
)
self.register(
self._builtin_descriptor("system.text_to_image", "文本转图片"),
call_handler=self._system_text_to_image,
exposed=False,
)
self.register(
self._builtin_descriptor("system.html_render", "渲染 HTML 模板"),
call_handler=self._system_html_render,
exposed=False,
)
self.register(
self._builtin_descriptor("system.file.register", "注册文件令牌"),
call_handler=self._system_file_register,
exposed=False,
)
self.register(
self._builtin_descriptor("system.file.handle", "解析文件令牌"),
call_handler=self._system_file_handle,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.session_waiter.register",
"注册会话等待器",
),
call_handler=self._system_session_waiter_register,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.session_waiter.unregister",
"注销会话等待器",
),
call_handler=self._system_session_waiter_unregister,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.react", "发送事件表情回应"),
call_handler=self._system_event_react,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.send_typing", "发送输入中状态"),
call_handler=self._system_event_send_typing,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.send_streaming",
"发送事件流式消息",
),
call_handler=self._system_event_send_streaming,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.send_streaming_chunk",
"推送事件流式消息分片",
),
call_handler=self._system_event_send_streaming_chunk,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.send_streaming_close",
"关闭事件流式消息会话",
),
call_handler=self._system_event_send_streaming_close,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.llm.get_state",
"读取当前请求的默认 LLM 状态",
),
call_handler=self._system_event_llm_get_state,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.llm.request",
"请求当前事件继续进入默认 LLM 链路",
),
call_handler=self._system_event_llm_request,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.result.get", "读取当前请求结果"),
call_handler=self._system_event_result_get,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.result.set", "写入当前请求结果"),
call_handler=self._system_event_result_set,
exposed=False,
)
self.register(
self._builtin_descriptor("system.event.result.clear", "清理当前请求结果"),
call_handler=self._system_event_result_clear,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.handler_whitelist.get",
"读取当前请求 handler 白名单",
),
call_handler=self._system_event_handler_whitelist_get,
exposed=False,
)
self.register(
self._builtin_descriptor(
"system.event.handler_whitelist.set",
"写入当前请求 handler 白名单",
),
call_handler=self._system_event_handler_whitelist_set,
exposed=False,
)
self.register(
self._builtin_descriptor(
"registry.get_handlers_by_event_type",
"按事件类型列出 handler 元数据",
),
call_handler=self._registry_get_handlers_by_event_type,
)
self.register(
self._builtin_descriptor(
"registry.get_handler_by_full_name",
"按 full name 查询 handler 元数据",
),
call_handler=self._registry_get_handler_by_full_name,
)
self.register(
self._builtin_descriptor(
"registry.command.register",
"注册动态命令路由",
),
call_handler=self._registry_command_register,
)
def _ensure_request_overlay(self, request_id: str) -> dict[str, Any]:
overlay = self._request_overlays.get(request_id)
if overlay is None:
overlay = {
"should_call_llm": False,
"requested_llm": False,
"result": None,
"handler_whitelist": None,
}
self._request_overlays[request_id] = overlay
return overlay
async def _system_get_data_dir(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
plugin_id = self._require_caller_plugin_id("system.get_data_dir")
data_dir = self._system_data_root / plugin_id
data_dir.mkdir(parents=True, exist_ok=True)
return {"path": str(data_dir)}
async def _system_text_to_image(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
text = str(payload.get("text", ""))
if bool(payload.get("return_url", True)):
return {"result": f"mock://text_to_image/{text}"}
return {"result": f"<image>{text}</image>"}
async def _system_html_render(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
tmpl = str(payload.get("tmpl", ""))
data = payload.get("data")
if not isinstance(data, dict):
raise AstrBotError.invalid_input("system.html_render requires object data")
if bool(payload.get("return_url", True)):
return {"result": f"mock://html_render/{tmpl}"}
return {"result": json.dumps({"tmpl": tmpl, "data": data}, ensure_ascii=False)}
async def _system_file_register(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
path = str(payload.get("path", "")).strip()
if not path:
raise AstrBotError.invalid_input("system.file.register requires path")
file_token = uuid.uuid4().hex
self._file_token_store[file_token] = path
return {"token": file_token, "url": f"mock://file/{file_token}"}
async def _system_file_handle(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
file_token = str(payload.get("token", "")).strip()
if not file_token:
raise AstrBotError.invalid_input("system.file.handle requires token")
path = self._file_token_store.pop(file_token, None)
if path is None:
raise AstrBotError.invalid_input(f"Unknown file token: {file_token}")
return {"path": path}
async def _system_event_llm_get_state(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
return {
"should_call_llm": bool(overlay["should_call_llm"]),
"requested_llm": bool(overlay["requested_llm"]),
}
async def _system_event_llm_request(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
overlay["requested_llm"] = True
overlay["should_call_llm"] = True
return await self._system_event_llm_get_state(request_id, {}, _token)
async def _system_event_result_get(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
result = overlay.get("result")
return {"result": dict(result) if isinstance(result, dict) else None}
async def _system_event_result_set(
self, request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
result = payload.get("result")
if not isinstance(result, dict):
raise AstrBotError.invalid_input(
"system.event.result.set 的 result 必须是 object"
)
overlay = self._ensure_request_overlay(request_id)
overlay["result"] = dict(result)
return {"result": dict(result)}
async def _system_event_result_clear(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
overlay["result"] = None
return {}
async def _system_event_handler_whitelist_get(
self, request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
whitelist = overlay.get("handler_whitelist")
if whitelist is None:
return {"plugin_names": None}
return {"plugin_names": sorted(str(item) for item in whitelist)}
async def _system_event_handler_whitelist_set(
self, request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
overlay = self._ensure_request_overlay(request_id)
plugin_names_payload = payload.get("plugin_names")
if plugin_names_payload is None:
overlay["handler_whitelist"] = None
elif isinstance(plugin_names_payload, list):
overlay["handler_whitelist"] = {
str(item) for item in plugin_names_payload if str(item).strip()
}
else:
raise AstrBotError.invalid_input(
"system.event.handler_whitelist.set 的 plugin_names 必须是数组或 null"
)
return await self._system_event_handler_whitelist_get(request_id, {}, _token)
async def _registry_get_handlers_by_event_type(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
event_type = str(payload.get("event_type", "")).strip()
handlers: list[dict[str, Any]] = []
for plugin in self._plugins.values():
handlers.extend(
[
dict(handler)
for handler in plugin.handlers
if event_type in handler.get("event_types", [])
]
)
if event_type == "message":
for plugin_name, routes in self._dynamic_command_routes.items():
for route in routes:
if not isinstance(route, dict):
continue
handlers.append(
{
"plugin_name": str(route.get("plugin_name", plugin_name)),
"handler_full_name": str(
route.get("handler_full_name", "")
),
"trigger_type": (
"message"
if bool(route.get("use_regex", False))
else "command"
),
"event_types": ["message"],
"enabled": True,
"group_path": [],
}
)
return {"handlers": handlers}
async def _registry_get_handler_by_full_name(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
full_name = str(payload.get("full_name", "")).strip()
for plugin in self._plugins.values():
for handler in plugin.handlers:
if handler.get("handler_full_name") == full_name:
return {"handler": dict(handler)}
return {"handler": None}
async def _registry_command_register(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
source_event_type = str(payload.get("source_event_type", "")).strip()
if source_event_type not in {"astrbot_loaded", "platform_loaded"}:
raise AstrBotError.invalid_input(
"register_commands is only available in astrbot_loaded/platform_loaded events"
)
if bool(payload.get("ignore_prefix", False)):
raise AstrBotError.invalid_input(
"register_commands(ignore_prefix=True) is unsupported in SDK runtime"
)
priority_value = payload.get("priority", 0)
if isinstance(priority_value, bool) or not isinstance(priority_value, int):
raise AstrBotError.invalid_input(
"registry.command.register 的 priority 必须是 integer"
)
plugin_id = self._require_caller_plugin_id("registry.command.register")
self.register_dynamic_command_route(
plugin_id=plugin_id,
command_name=str(payload.get("command_name", "")),
handler_full_name=str(payload.get("handler_full_name", "")),
desc=str(payload.get("desc", "")),
priority=priority_value,
use_regex=bool(payload.get("use_regex", False)),
)
return {}
async def _system_session_waiter_register(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
plugin_id = self._require_caller_plugin_id("system.session_waiter.register")
session_key = str(payload.get("session_key", "")).strip()
if not session_key:
raise AstrBotError.invalid_input(
"system.session_waiter.register requires session_key"
)
self._session_waiters.setdefault(plugin_id, set()).add(session_key)
return {}
async def _system_session_waiter_unregister(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
plugin_id = self._require_caller_plugin_id("system.session_waiter.unregister")
session_key = str(payload.get("session_key", "")).strip()
plugin_waiters = self._session_waiters.get(plugin_id)
if plugin_waiters is None:
return {}
plugin_waiters.discard(session_key)
if not plugin_waiters:
self._session_waiters.pop(plugin_id, None)
return {}
async def _system_event_react(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self.event_actions.append(
{
"action": "react",
"emoji": str(payload.get("emoji", "")),
"target": _clone_target_payload(payload.get("target")),
}
)
return {"supported": True}
async def _system_event_send_typing(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self.event_actions.append(
{
"action": "send_typing",
"target": _clone_target_payload(payload.get("target")),
}
)
return {"supported": True}
async def _system_event_send_streaming(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
stream_id = f"mock-stream-{len(self._event_streams) + 1}"
stream_state: dict[str, Any] = {
"target": _clone_target_payload(payload.get("target")),
"chunks": [],
"use_fallback": bool(payload.get("use_fallback", False)),
}
self._event_streams[stream_id] = stream_state
return {"supported": True, "stream_id": stream_id}
async def _system_event_send_streaming_chunk(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
stream = self._event_streams.get(str(payload.get("stream_id", "")))
if stream is None:
raise AstrBotError.invalid_input("Unknown sdk event streaming session")
chain = payload.get("chain")
if not isinstance(chain, list):
raise AstrBotError.invalid_input(
"system.event.send_streaming_chunk requires a chain array"
)
stream["chunks"].append({"chain": _clone_chain_payload(chain)})
return {}
async def _system_event_send_streaming_close(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
stream_id = str(payload.get("stream_id", ""))
stream = self._event_streams.pop(stream_id, None)
if stream is None:
raise AstrBotError.invalid_input("Unknown sdk event streaming session")
self.event_actions.append(
{
"action": "send_streaming",
"target": stream["target"],
"chunks": list(stream["chunks"]),
"use_fallback": bool(stream["use_fallback"]),
}
)
return {"supported": True}