Merge commit 'e45bade147ff44b43860ecff12067309e59c151a' into feat/sdk-integration

This commit is contained in:
whatevertogo
2026-03-19 02:20:59 +08:00
parent f8a7e25370
commit 96d1df8584
16 changed files with 1106 additions and 70 deletions

View File

@@ -41,12 +41,14 @@ ruff check . --fix # 使用 ruff 检查并自动修复全局格式问题
如果修改了内容可能影响现有功能,请运行测试以确保没有引入错误:
如果修改了bug或者更改了功能需要添加新的测试
当前仓库已统一使用 `tests/` 目录,`tests_v4/` 不再作为新增测试入口。
仓库当前没有 `run_tests.py`,请直接使用 `pytest`
```bash
python run_tests.py # 运行所有测试
python run_tests.py -v # 详细输出
python run_tests.py -k "test_peer" # 运行匹配模式的测试
python run_tests.py --cov # 运行测试并生成覆盖率报告
python -m pytest tests -q # 运行 tests 目录全部测试
python -m pytest tests -v # 详细输出
python -m pytest tests -k "test_context_register_task" # 运行匹配模式的测试
python -m pytest tests --cov=astrbot_sdk # 运行测试并生成覆盖率报告
```
## 设计原则

View File

@@ -41,12 +41,14 @@ ruff check . --fix # 使用 ruff 检查并自动修复全局格式问题
如果修改了内容可能影响现有功能,请运行测试以确保没有引入错误:
如果修改了bug或者更改了功能需要添加新的测试
当前仓库已统一使用 `tests/` 目录,`tests_v4/` 不再作为新增测试入口。
仓库当前没有 `run_tests.py`,请直接使用 `pytest`
```bash
python run_tests.py # 运行所有测试
python run_tests.py -v # 详细输出
python run_tests.py -k "test_peer" # 运行匹配模式的测试
python run_tests.py --cov # 运行测试并生成覆盖率报告
python -m pytest tests -q # 运行 tests 目录全部测试
python -m pytest tests -v # 详细输出
python -m pytest tests -k "test_context_register_task" # 运行匹配模式的测试
python -m pytest tests --cov=astrbot_sdk # 运行测试并生成覆盖率报告
```
## 设计原则
@@ -57,4 +59,4 @@ python run_tests.py --cov # 运行测试并生成覆盖率报告
---
# currentDate
Today's date is 2026-03-14.
Today's date is 2026-03-19.

View File

@@ -6,6 +6,7 @@ import asyncio
import typing
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, TextIO
from .context import CancelToken
@@ -121,10 +122,77 @@ class InMemoryDB:
class InMemoryMemory:
def __init__(self, store: dict[str, dict[str, Any]]) -> None:
def __init__(
self,
store: dict[str, dict[str, Any]],
*,
expires_at: dict[str, datetime | None] | None = None,
) -> None:
self._store = store
self._expires_at = expires_at if expires_at is not None else {}
@staticmethod
def _is_ttl_entry(value: Any) -> bool:
"""判断测试 memory 值是否使用 TTL 包装结构。
Args:
value: 待检查的存储值。
Returns:
bool: 如果包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
"""
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
@classmethod
def _search_text(cls, value: Any) -> str:
"""提取测试用 memory.search 的匹配文本。
Args:
value: 当前存储的 memory 值。
Returns:
str: 用于本地测试搜索的文本内容。
"""
if cls._is_ttl_entry(value):
value = value.get("value")
if not isinstance(value, dict):
return ""
for field_name in ("embedding_text", "content", "summary", "title", "text"):
item = value.get(field_name)
if isinstance(item, str) and item.strip():
return item.strip()
return str(value)
def _is_expired(self, key: str) -> bool:
"""判断测试 memory 键是否已经过期。
Args:
key: memory 条目的键。
Returns:
bool: 如果当前时间已超过过期时间则返回 ``True``。
"""
expires_at = self._expires_at.get(key)
return expires_at is not None and expires_at <= datetime.now(timezone.utc)
def _purge_if_expired(self, key: str) -> bool:
"""在测试 helper 中清理已过期的 memory 条目。
Args:
key: memory 条目的键。
Returns:
bool: 如果条目已过期并被清理则返回 ``True``。
"""
if not self._is_expired(key):
return False
self._store.pop(key, None)
self._expires_at.pop(key, None)
return True
def get(self, key: str, default: Any = None) -> Any:
    """Return the stored value for ``key``, honouring expiry.

    Args:
        key: The memory entry key.
        default: Value returned when the key is absent or has just been
            purged as expired.

    Returns:
        Any: The stored value, or ``default``.
    """
    was_expired = self._purge_if_expired(key)
    return default if was_expired else self._store.get(key, default)
def save(self, key: str, value: dict[str, Any]) -> None:
@@ -132,11 +200,14 @@ class InMemoryMemory:
def delete(self, key: str) -> None:
    """Remove ``key`` from the store and from the expiry table.

    Missing keys are ignored.

    Args:
        key: The memory entry key.
    """
    for table in (self._store, self._expires_at):
        table.pop(key, None)
def search(self, query: str) -> list[dict[str, Any]]:
    """Return entries whose key or search text contains ``query``.

    Expired entries are purged while scanning and never returned.

    Args:
        query: Substring to look for (case-sensitive).

    Returns:
        list[dict[str, Any]]: ``{"key": ..., "value": ...}`` items in store
        iteration order; ``value`` is the stored value as-is.
    """
    results: list[dict[str, Any]] = []
    # Iterate over a snapshot: purging an expired key mutates the store.
    for key, value in list(self._store.items()):
        if self._purge_if_expired(key):
            continue
        if query in key or query in self._search_text(value):
            results.append({"key": key, "value": value})
    return results
@@ -200,7 +271,10 @@ class MockCapabilityRouter(CapabilityRouter):
self._llm_stream_responses: list[str] = []
super().__init__()
self.db = InMemoryDB(self.db_store)
self.memory = InMemoryMemory(self.memory_store)
self.memory = InMemoryMemory(
self.memory_store,
expires_at=self._memory_expires_at,
)
def list_dynamic_command_routes(self, plugin_id: str) -> list[dict[str, Any]]:
# Pure pass-through: forwards ``plugin_id`` to the parent implementation and returns its result unchanged.
return super().list_dynamic_command_routes(plugin_id)

View File

@@ -15,7 +15,7 @@
from __future__ import annotations
from typing import Any
from typing import Any, Literal
from ._proxy import CapabilityProxy
@@ -48,30 +48,47 @@ class MemoryClient:
"""
self._proxy = proxy
async def search(self, query: str) -> list[dict[str, Any]]:
"""Search memory items with the current bridge behavior.
async def search(
self,
query: str,
*,
mode: Literal["auto", "keyword", "vector", "hybrid"] = "auto",
limit: int | None = None,
min_score: float | None = None,
provider_id: str | None = None,
) -> list[dict[str, Any]]:
"""搜索记忆项。
The current core bridge matches `query` against the memory key and the
serialized memory payload. It does not provide vector or semantic
retrieval yet.
Returned items preserve the original `{"key": ..., "value": {...}}`
shape. When `value` is a mapping, its fields are also exposed at the
top level for compatibility with existing plugin examples.
默认会在有 embedding provider 时执行 hybrid 检索,
否则退化为关键词检索。返回结果包含 `score` 与 `match_type` 字段。
Args:
query: 搜索查询文本
mode: 搜索模式,支持 auto/keyword/vector/hybrid
limit: 最大返回条数
min_score: 最低分数阈值
provider_id: 指定 embedding provider默认使用当前激活的 provider
Returns:
匹配的记忆项列表,按相关度排序
示例:
# 搜索用户偏好相关的记忆
results = await ctx.memory.search("用户喜欢什么颜色")
results = await ctx.memory.search(
"用户喜欢什么颜色",
mode="hybrid",
limit=5,
)
for item in results:
print(item["key"], item["content"])
print(item["key"], item["score"], item["match_type"])
"""
output = await self._proxy.call("memory.search", {"query": query})
payload: dict[str, Any] = {"query": query, "mode": mode}
if limit is not None:
payload["limit"] = limit
if min_score is not None:
payload["min_score"] = min_score
if provider_id is not None:
payload["provider_id"] = provider_id
output = await self._proxy.call("memory.search", payload)
items = output.get("items")
if not isinstance(items, (list, tuple)):
return []
@@ -96,16 +113,20 @@ class MemoryClient:
key: 记忆项的唯一标识键
value: 要存储的数据字典
**extra: 额外的键值对,会合并到 value 中
Raises:
TypeError: 如果 value 不是 dict 类型
示例:
# 保存用户偏好
保存用户偏好
await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
# 使用关键字参数
使用关键字参数
await ctx.memory.save("note", None, content="重要笔记", tags=["work"])
使用 embedding_text 显式指定检索文本
await ctx.memory.save(
"profile",
{"name": "alice", "embedding_text": "Alice 喜欢蓝色和海边"},
)
"""
if value is not None and not isinstance(value, dict):
raise TypeError("memory.save 的 value 必须是 dict")
@@ -230,16 +251,22 @@ class MemoryClient:
async def stats(self) -> dict[str, Any]:
"""获取记忆系统统计信息。
返回记忆系统的当前状态,包括条目数等统计信息
返回记忆系统的当前状态,包括条目数、索引状态和脏索引数量
Returns:
统计信息字典,包含:
- total_items: 总记忆条目数
- total_bytes: 总占用字节数(可选)
- ttl_entries: 带过期时间的条目数(可选)
- indexed_items: 已建立检索索引的条目数(可选)
- embedded_items: 已生成向量的条目数(可选)
- dirty_items: 等待重建索引的条目数(可选)
示例:
stats = await ctx.memory.stats()
print(f"记忆库共有 {stats['total_items']} 条记录")
if "embedded_items" in stats:
print(f"其中 {stats['embedded_items']} 条已经向量化")
"""
output = await self._proxy.call("memory.stats", {})
stats = {
@@ -250,4 +277,10 @@ class MemoryClient:
stats["plugin_id"] = output.get("plugin_id")
if "ttl_entries" in output:
stats["ttl_entries"] = output.get("ttl_entries")
if "indexed_items" in output:
stats["indexed_items"] = output.get("indexed_items")
if "embedded_items" in output:
stats["embedded_items"] = output.get("embedded_items")
if "dirty_items" in output:
stats["dirty_items"] = output.get("dirty_items")
return stats

View File

@@ -159,12 +159,12 @@ async for chunk in ctx.llm.stream_chat("讲一个故事"):
### search()
语义搜索记忆项。
搜索记忆项。默认在有 embedding provider 时执行 hybrid 检索。
```python
results = await ctx.memory.search("用户喜欢什么颜色")
results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5)
for item in results:
print(item["key"], item["content"])
print(item["key"], item["score"], item["match_type"])
```
### save()
@@ -177,6 +177,12 @@ await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
# 使用关键字参数
await ctx.memory.save("note", None, content="重要笔记", tags=["work"])
# 显式指定检索文本
await ctx.memory.save(
"profile:alice",
{"name": "Alice", "embedding_text": "Alice 喜欢蓝色和海边"},
)
```
### get()
@@ -202,6 +208,15 @@ await ctx.memory.save_with_ttl(
)
```
### stats()
查看记忆索引状态。
```python
stats = await ctx.memory.stats()
print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items"))
```
---
## Database 客户端

View File

@@ -66,10 +66,12 @@ from astrbot_sdk.clients import MemoryClient
#### search()
语义搜索。
搜索记忆。默认在有 embedding provider 时执行 hybrid 检索。
```python
results = await ctx.memory.search("用户喜欢什么颜色")
results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5)
for item in results:
print(item["key"], item["score"], item["match_type"])
```
#### save()
@@ -78,6 +80,10 @@ results = await ctx.memory.search("用户喜欢什么颜色")
```python
await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
await ctx.memory.save(
"profile:alice",
{"name": "Alice", "embedding_text": "Alice 喜欢蓝色和海边"},
)
```
#### get()
@@ -108,6 +114,15 @@ await ctx.memory.save_with_ttl(
await ctx.memory.delete("old_note")
```
#### stats()
查看记忆索引状态。
```python
stats = await ctx.memory.stats()
print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items"))
```
---
## 3. DBClient - KV 数据库客户端

View File

@@ -142,25 +142,33 @@ from astrbot_sdk.clients import MemoryClient
### 方法
#### `search(query)`
#### `search(query, *, mode="auto", limit=None, min_score=None, provider_id=None)`
语义搜索记忆项。
搜索记忆项。默认会在存在 embedding provider 时执行 hybrid 检索,
否则退化为关键词检索。
**参数**:
- `query` (`str`): 搜索查询文本(自然语言)
- `mode` (`Literal["auto", "keyword", "vector", "hybrid"]`): 搜索模式
- `limit` (`int | None`): 最大返回条数
- `min_score` (`float | None`): 最低分数阈值
- `provider_id` (`str | None`): 指定 embedding provider
**返回**: `list[dict]` - 匹配的记忆项列表,按相关度排序
**返回**: `list[dict]` - 匹配的记忆项列表。每项至少包含 `key`、`value`、`score`、`match_type`。
**示例**:
```python
# 搜索用户偏好
results = await ctx.memory.search("用户喜欢什么颜色")
results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5)
for item in results:
print(f"Key: {item['key']}, Content: {item['content']}")
print(item["key"], item["score"], item["match_type"])
# 搜索对话摘要
summaries = await ctx.memory.search("之前讨论过什么技术话题")
# 强制使用关键词检索
keyword_hits = await ctx.memory.search("blue", mode="keyword", min_score=0.9)
# 使用当前激活的 embedding provider 执行向量检索
vector_hits = await ctx.memory.search("之前讨论过什么技术话题", mode="vector")
```
---
@@ -192,6 +200,16 @@ await ctx.memory.save(
tags=["work"],
timestamp="2024-01-01"
)
# 显式指定检索文本
await ctx.memory.save(
"profile:alice",
{
"name": "Alice",
"city": "Shanghai",
"embedding_text": "Alice 喜欢蓝色、海边和摄影",
},
)
```
---
@@ -314,6 +332,12 @@ stats = await ctx.memory.stats()
print(f"记忆库共有 {stats['total_items']} 条记录")
if 'ttl_entries' in stats:
print(f"其中 {stats['ttl_entries']} 条有过期时间")
if 'indexed_items' in stats:
print(f"已建立索引: {stats['indexed_items']}")
if 'embedded_items' in stats:
print(f"已向量化: {stats['embedded_items']}")
if 'dirty_items' in stats:
print(f"待重建索引: {stats['dirty_items']}")
```
---

View File

@@ -210,12 +210,16 @@ async for chunk in ctx.llm.stream_chat("讲一个故事"):
##### `search()`
语义搜索。
搜索记忆。默认在有 embedding provider 时执行 hybrid 检索。
```python
results = await ctx.memory.search("用户喜欢什么颜色")
results = await ctx.memory.search(
"用户喜欢什么颜色",
mode="hybrid",
limit=5,
)
for item in results:
print(item["key"], item["content"])
print(item["key"], item["score"], item["match_type"])
```
##### `save()`
@@ -228,6 +232,15 @@ await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
# 使用关键字参数
await ctx.memory.save("note", None, content="重要笔记", tags=["work"])
# 显式指定检索文本
await ctx.memory.save(
"profile:alice",
{
"name": "Alice",
"embedding_text": "Alice 喜欢蓝色和海边",
},
)
```
##### `get()`
@@ -261,6 +274,15 @@ await ctx.memory.save_with_ttl(
await ctx.memory.delete("old_note")
```
##### `stats()`
查看记忆索引状态。
```python
stats = await ctx.memory.stats()
print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items"))
```
---
### 3. DB 客户端 (ctx.db)

View File

@@ -75,11 +75,28 @@ LLM_STREAM_CHAT_OUTPUT_SCHEMA = _object_schema(
required=("text",), text={"type": "string"}
)
# Input schema for ``memory.search``: only ``query`` is required; the other
# fields tune retrieval mode, result count, score cut-off and provider choice.
MEMORY_SEARCH_INPUT_SCHEMA = _object_schema(
    required=("query",),
    query={"type": "string"},
    mode={"type": "string", "enum": ["auto", "keyword", "vector", "hybrid"]},
    limit={"type": "integer", "minimum": 1},
    min_score={"type": "number"},
    provider_id={"type": "string"},
)
# Output schema for ``memory.search``: every hit carries its key, the stored
# payload (nullable), a numeric score and the kind of match that produced it.
MEMORY_SEARCH_OUTPUT_SCHEMA = _object_schema(
    required=("items",),
    items={
        "type": "array",
        "items": _object_schema(
            required=("key", "value", "score", "match_type"),
            key={"type": "string"},
            value=_nullable({"type": "object"}),
            score={"type": "number"},
            match_type={
                "type": "string",
                "enum": ["keyword", "vector", "hybrid"],
            },
        ),
    },
)
MEMORY_SAVE_INPUT_SCHEMA = _object_schema(
required=("key", "value"),
@@ -133,6 +150,9 @@ MEMORY_STATS_OUTPUT_SCHEMA = _object_schema(
total_bytes=_nullable({"type": "integer"}),
plugin_id=_nullable({"type": "string"}),
ttl_entries=_nullable({"type": "integer"}),
indexed_items=_nullable({"type": "integer"}),
embedded_items=_nullable({"type": "integer"}),
dirty_items=_nullable({"type": "integer"}),
)
SYSTEM_GET_DATA_DIR_INPUT_SCHEMA = _object_schema()
SYSTEM_GET_DATA_DIR_OUTPUT_SCHEMA = _object_schema(

View File

@@ -20,10 +20,13 @@ from __future__ import annotations
import asyncio
import base64
import copy
import hashlib
import json
import math
import re
import uuid
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
@@ -52,8 +55,61 @@ def _clone_chain_payload(value: Any) -> list[dict[str, Any]]:
]
# Length of the vectors built by ``_mock_embedding_vector``; also the value the mock ``get_dim`` capability reports.
_MOCK_EMBEDDING_DIM = 24
def _embedding_terms(text: str) -> list[str]:
"""为 mock embedding 构造稳定的分词结果。
Args:
text: 待向量化的原始文本。
Returns:
list[str]: 用于生成 mock 向量的词项列表。
"""
normalized = re.sub(r"\s+", " ", str(text).strip().casefold())
compact = normalized.replace(" ", "")
if not normalized:
return []
terms = [word for word in re.findall(r"\w+", normalized, flags=re.UNICODE) if word]
if compact:
if len(compact) == 1:
terms.append(compact)
else:
terms.extend(
compact[index : index + 2] for index in range(len(compact) - 1)
)
terms.append(compact)
return terms or [normalized]
def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]:
    """Produce a deterministic mock embedding for ``text``.

    Each term from ``_embedding_terms`` is hashed together with the provider
    id (SHA-256) into one of ``_MOCK_EMBEDDING_DIM`` buckets, which receives a
    weight that grows with term length (capped at 8 characters).

    Args:
        text: Text to vectorise.
        provider_id: Identifier of the embedding provider; part of the hash
            input, so different providers yield different vectors.

    Returns:
        list[float]: The L2-normalised vector, or the all-zero vector when no
        term contributed.
    """
    buckets = [0.0] * _MOCK_EMBEDDING_DIM
    for term in _embedding_terms(text):
        seed = f"{provider_id}:{term}".encode("utf-8")
        slot = int.from_bytes(hashlib.sha256(seed).digest()[:2], "big")
        buckets[slot % _MOCK_EMBEDDING_DIM] += 1.0 + min(len(term), 8) * 0.05
    length = math.sqrt(sum(weight * weight for weight in buckets))
    if length <= 0:
        return buckets
    return [weight / length for weight in buckets]
class _CapabilityRouterHost:
memory_store: dict[str, dict[str, Any]]
_memory_index: dict[str, dict[str, Any]]
_memory_dirty_keys: set[str]
_memory_expires_at: dict[str, datetime | None]
db_store: dict[str, Any]
sent_messages: list[dict[str, Any]]
event_actions: list[dict[str, Any]]
@@ -278,15 +334,471 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
},
)
@staticmethod
def _is_ttl_memory_entry(value: Any) -> bool:
"""判断存储值是否使用了 TTL 包装结构。
Args:
value: 待检查的存储值。
Returns:
bool: 如果值包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
"""
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
@classmethod
def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None:
"""提取用于检索的原始 memory payload。
Args:
stored: memory_store 中保存的原始值。
Returns:
dict[str, Any] | None: 解开 TTL 包装后的字典,无法解析时返回 ``None``。
"""
if not isinstance(stored, dict):
return None
if cls._is_ttl_memory_entry(stored):
value = stored.get("value")
return value if isinstance(value, dict) else None
return stored
@classmethod
def _extract_memory_text(cls, stored: Any) -> str:
"""提取用于检索索引的首选文本。
Args:
stored: memory_store 中保存的原始值。
Returns:
str: 优先使用 ``embedding_text`` / ``content`` 等字段,兜底为 JSON 文本。
"""
value = cls._memory_value_for_search(stored)
if not isinstance(value, dict):
return ""
for field_name in ("embedding_text", "content", "summary", "title", "text"):
item = value.get(field_name)
if isinstance(item, str) and item.strip():
return item.strip()
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
@staticmethod
def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
"""将 TTL 秒数转换为 UTC 过期时间。
Args:
ttl_seconds: TTL 秒数。
Returns:
datetime | None: 绝对过期时间;当输入无效时返回 ``None``。
"""
try:
ttl = int(ttl_seconds)
except (TypeError, ValueError):
return None
if ttl < 1:
return None
return datetime.now(timezone.utc) + timedelta(seconds=ttl)
@staticmethod
def _memory_keyword_score(query: str, key: str, text: str) -> float:
"""计算关键词匹配分数。
Args:
query: 查询文本。
key: memory 条目的键。
text: 已索引的检索文本。
Returns:
float: 基于键名和文本命中的粗粒度关键词分数。
"""
normalized_query = str(query).casefold()
if not normalized_query:
return 1.0
normalized_key = str(key).casefold()
normalized_text = str(text).casefold()
if normalized_query in normalized_key:
return 1.0
if normalized_query in normalized_text:
return 0.9
return 0.0
@staticmethod
def _cosine_similarity(left: list[float], right: list[float]) -> float:
"""计算两个向量之间的余弦相似度。
Args:
left: 左侧向量。
right: 右侧向量。
Returns:
float: 余弦相似度;输入不合法时返回 ``0.0``。
"""
if not left or not right or len(left) != len(right):
return 0.0
left_norm = math.sqrt(sum(value * value for value in left))
right_norm = math.sqrt(sum(value * value for value in right))
if left_norm <= 0 or right_norm <= 0:
return 0.0
return sum(a * b for a, b in zip(left, right, strict=False)) / (
left_norm * right_norm
)
def _resolve_memory_embedding_provider_id(
self,
provider_id: Any,
*,
required: bool,
) -> str | None:
"""解析 memory.search 要使用的 embedding provider。
Args:
provider_id: 调用方显式传入的 provider 标识。
required: 当前检索模式是否强制要求 embedding provider。
Returns:
str | None: 最终选中的 provider 标识;在非强制场景下允许返回 ``None``。
"""
normalized = str(provider_id).strip() if provider_id is not None else ""
if normalized:
self._provider_entry(
{"provider_id": normalized},
"memory.search",
"embedding",
)
return normalized
active_id = self._active_provider_ids.get("embedding")
if active_id is not None:
normalized_active = str(active_id).strip()
if normalized_active:
self._provider_entry(
{"provider_id": normalized_active},
"memory.search",
"embedding",
)
return normalized_active
if required:
raise AstrBotError.invalid_input(
"memory.search requires an embedding provider",
)
return None
@staticmethod
def _memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]:
"""将原始索引项规范化为内部统一结构。
Args:
entry: 当前索引表中的原始项。
text: 当前条目的索引文本。
Returns:
dict[str, Any]: 统一后的索引项,包含 ``text``、``embedding``、``provider_id``。
"""
if isinstance(entry, dict):
return {
"text": str(entry.get("text", text)),
"embedding": (
[float(item) for item in entry.get("embedding", [])]
if isinstance(entry.get("embedding"), list)
else None
),
"provider_id": (
str(entry.get("provider_id")).strip()
if entry.get("provider_id") is not None
else None
),
}
return {"text": text, "embedding": None, "provider_id": None}
def _clear_memory_sidecars(self, key: str) -> None:
"""清理指定 memory 键对应的所有 sidecar 状态。
Args:
key: memory 条目的键。
Returns:
None
"""
self._memory_index.pop(key, None)
self._memory_expires_at.pop(key, None)
self._memory_dirty_keys.discard(key)
def _delete_memory_entry(self, key: str) -> bool:
"""删除 memory 条目并同步清理 sidecar 状态。
Args:
key: memory 条目的键。
Returns:
bool: 条目存在并删除成功时返回 ``True``。
"""
deleted = self.memory_store.pop(key, None) is not None
self._clear_memory_sidecars(key)
return deleted
def _upsert_memory_sidecars(
self,
key: str,
stored: dict[str, Any],
*,
expires_at: datetime | None = None,
) -> None:
"""创建或更新单条 memory 的 sidecar 索引状态。
Args:
key: memory 条目的键。
stored: 需要建立索引的原始存储值。
expires_at: 可选的绝对过期时间。
Returns:
None
"""
self._memory_index[key] = {
"text": self._extract_memory_text(stored),
"embedding": None,
"provider_id": None,
}
if expires_at is None:
self._memory_expires_at.pop(key, None)
else:
self._memory_expires_at[key] = expires_at
self._memory_dirty_keys.add(key)
def _ensure_memory_sidecars(self, key: str, stored: Any) -> None:
"""确保 sidecar 状态与当前存储值保持一致。
Args:
key: memory 条目的键。
stored: memory_store 中的当前存储值。
Returns:
None
"""
if not isinstance(stored, dict):
return
text = self._extract_memory_text(stored)
existed = key in self._memory_index
entry = self._memory_index_entry(self._memory_index.get(key), text=text)
if entry["text"] != text:
entry["text"] = text
entry["embedding"] = None
entry["provider_id"] = None
self._memory_dirty_keys.add(key)
self._memory_index[key] = entry
if not existed:
self._memory_dirty_keys.add(key)
def _is_memory_expired(self, key: str) -> bool:
"""判断 memory 条目是否已过期。
Args:
key: memory 条目的键。
Returns:
bool: 如果当前时间已超过记录的过期时间则返回 ``True``。
"""
expires_at = self._memory_expires_at.get(key)
return expires_at is not None and expires_at <= datetime.now(timezone.utc)
def _purge_expired_memory_entry(self, key: str) -> bool:
"""在单条 memory 已过期时立即清理它。
Args:
key: memory 条目的键。
Returns:
bool: 如果条目已过期并被成功清理则返回 ``True``。
"""
if not self._is_memory_expired(key):
return False
self._delete_memory_entry(key)
return True
def _purge_expired_memory_entries(self) -> None:
"""批量清理所有已跟踪的过期 TTL 条目。
Returns:
None
"""
for key in list(self._memory_expires_at):
self._purge_expired_memory_entry(key)
async def _embedding_for_text(
self,
*,
provider_id: str,
text: str,
) -> list[float]:
"""通过 embedding capability 获取单条文本向量。
Args:
provider_id: 使用的 embedding provider 标识。
text: 待向量化的文本。
Returns:
list[float]: provider 返回的向量;异常场景下返回空列表。
"""
output = await self._provider_embedding_get_embedding(
"",
{"provider_id": provider_id, "text": text},
None,
)
embedding = output.get("embedding")
if not isinstance(embedding, list):
return []
return [float(item) for item in embedding]
async def _embeddings_for_texts(
self,
*,
provider_id: str,
texts: list[str],
) -> list[list[float]]:
"""批量获取多条文本的 embedding 向量。
Args:
provider_id: 使用的 embedding provider 标识。
texts: 待向量化的文本列表。
Returns:
list[list[float]]: 与输入顺序对应的向量列表。
"""
if not texts:
return []
output = await self._provider_embedding_get_embeddings(
"",
{"provider_id": provider_id, "texts": texts},
None,
)
embeddings = output.get("embeddings")
if not isinstance(embeddings, list):
return []
return [
[float(value) for value in item]
for item in embeddings
if isinstance(item, list)
]
async def _refresh_memory_embeddings(self, *, provider_id: str) -> None:
"""刷新当前 provider 下脏或过期的 memory 向量索引。
Args:
provider_id: 当前使用的 embedding provider 标识。
Returns:
None
"""
keys_to_refresh: list[str] = []
texts_to_refresh: list[str] = []
for key, stored in self.memory_store.items():
self._ensure_memory_sidecars(key, stored)
entry = self._memory_index_entry(
self._memory_index.get(key),
text=self._extract_memory_text(stored),
)
should_refresh = (
key in self._memory_dirty_keys
or entry["embedding"] is None
or entry["provider_id"] != provider_id
)
self._memory_index[key] = entry
if should_refresh:
keys_to_refresh.append(key)
texts_to_refresh.append(str(entry["text"]))
embeddings = await self._embeddings_for_texts(
provider_id=provider_id,
texts=texts_to_refresh,
)
for index, key in enumerate(keys_to_refresh):
entry = self._memory_index_entry(
self._memory_index.get(key),
text=str(texts_to_refresh[index]),
)
entry["embedding"] = embeddings[index] if index < len(embeddings) else []
entry["provider_id"] = provider_id
self._memory_index[key] = entry
self._memory_dirty_keys.discard(key)
async def _memory_search(
    self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
    """Handle ``memory.search`` with keyword, vector or hybrid scoring.

    Expired entries are purged first. In ``auto`` mode the search runs as
    ``hybrid`` when an embedding provider can be resolved and as ``keyword``
    otherwise. Hybrid scoring takes the vector score and, when the keyword
    also matches, raises it to at least ``0.4 + 0.6 * vector_score``.

    Args:
        _request_id: Request identifier (unused).
        payload: Request payload. Reads ``query``, ``mode``, ``limit``,
            ``min_score`` and ``provider_id``.
        _token: Cancellation token (unused).

    Returns:
        dict[str, Any]: ``{"items": [...]}``. Each item has ``key``, ``value``
        (payload with any TTL wrapper removed), ``score`` and ``match_type``;
        items are sorted by descending score, then by key.

    Raises:
        AstrBotError: Invalid-input error for an unknown ``mode``, a
            non-numeric ``min_score``, or a vector/hybrid search without an
            embedding provider.
    """
    query = str(payload.get("query", ""))
    raw_mode = payload.get("mode")
    mode = "auto" if raw_mode is None else str(raw_mode).strip().lower() or "auto"
    if mode not in {"auto", "keyword", "vector", "hybrid"}:
        raise AstrBotError.invalid_input(
            "memory.search mode must be one of auto/keyword/vector/hybrid",
        )
    limit = self._optional_int(payload.get("limit"))
    raw_min_score = payload.get("min_score")
    try:
        min_score = None if raw_min_score is None else float(raw_min_score)
    except (TypeError, ValueError):
        raise AstrBotError.invalid_input(
            "memory.search min_score must be a number",
        ) from None

    self._purge_expired_memory_entries()
    provider_id = self._resolve_memory_embedding_provider_id(
        payload.get("provider_id"),
        required=mode in {"vector", "hybrid"},
    )
    effective_mode = mode
    if effective_mode == "auto":
        effective_mode = "hybrid" if provider_id is not None else "keyword"

    query_embedding: list[float] | None = None
    if effective_mode in {"vector", "hybrid"}:
        if provider_id is None:
            raise AstrBotError.invalid_input(
                "memory.search requires an embedding provider",
            )
        await self._refresh_memory_embeddings(provider_id=provider_id)
        query_embedding = await self._embedding_for_text(
            provider_id=provider_id,
            text=query,
        )

    items: list[dict[str, Any]] = []
    for key, value in self.memory_store.items():
        self._ensure_memory_sidecars(key, value)
        entry = self._memory_index_entry(
            self._memory_index.get(key),
            text=self._extract_memory_text(value),
        )
        text = str(entry.get("text", ""))
        keyword_score = self._memory_keyword_score(query, key, text)
        vector_score = 0.0
        if query_embedding is not None:
            embedding = entry.get("embedding")
            if isinstance(embedding, list):
                # Negative similarity is clamped to zero.
                vector_score = max(
                    0.0,
                    self._cosine_similarity(query_embedding, embedding),
                )

        if effective_mode == "keyword":
            score = keyword_score
        elif effective_mode == "vector":
            score = vector_score
        else:
            score = vector_score
            if keyword_score > 0:
                score = max(score, 0.4 + 0.6 * vector_score)

        if score <= 0:
            continue
        if min_score is not None and score < min_score:
            continue

        if effective_mode == "keyword" or (keyword_score > 0 and vector_score <= 0):
            match_type = "keyword"
        elif effective_mode == "vector" or keyword_score <= 0:
            match_type = "vector"
        else:
            match_type = "hybrid"
        items.append(
            {
                "key": key,
                "value": self._memory_value_for_search(value),
                "score": score,
                "match_type": match_type,
            }
        )

    items.sort(key=lambda item: (-float(item["score"]), str(item["key"])))
    if limit is not None and limit >= 0:
        items = items[:limit]
    return {"items": items}
async def _memory_save(
@@ -297,17 +809,21 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
if not isinstance(value, dict):
raise AstrBotError.invalid_input("memory.save 的 value 必须是 object")
self.memory_store[key] = value
self._upsert_memory_sidecars(key, value)
return {}
async def _memory_get(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
return {"value": self.memory_store.get(str(payload.get("key", "")))}
key = str(payload.get("key", ""))
if self._purge_expired_memory_entry(key):
return {"value": None}
return {"value": self.memory_store.get(key)}
async def _memory_delete(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self.memory_store.pop(str(payload.get("key", "")), None)
self._delete_memory_entry(str(payload.get("key", "")))
return {}
async def _memory_save_with_ttl(
@@ -320,7 +836,13 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
raise AstrBotError.invalid_input(
"memory.save_with_ttl 的 value 必须是 object"
)
self.memory_store[key] = {"value": value, "ttl_seconds": ttl_seconds}
stored = {"value": value, "ttl_seconds": ttl_seconds}
self.memory_store[key] = stored
self._upsert_memory_sidecars(
key,
stored,
expires_at=self._memory_expiration_from_ttl(ttl_seconds),
)
return {}
async def _memory_get_many(
@@ -332,6 +854,9 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
keys = [str(item) for item in keys_payload]
items = []
for key in keys:
if self._purge_expired_memory_entry(key):
items.append({"key": key, "value": None})
continue
stored = self.memory_store.get(key)
if (
isinstance(stored, dict)
@@ -353,28 +878,36 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
keys = [str(item) for item in keys_payload]
deleted_count = 0
for key in keys:
if key in self.memory_store:
del self.memory_store[key]
if self._delete_memory_entry(key):
deleted_count += 1
return {"deleted_count": deleted_count}
async def _memory_stats(
self, _request_id: str, _payload: dict[str, Any], _token
) -> dict[str, Any]:
self._purge_expired_memory_entries()
total_items = len(self.memory_store)
total_bytes = sum(
len(str(key)) + len(str(value)) for key, value in self.memory_store.items()
)
ttl_entries = sum(
ttl_entries = len(self._memory_expires_at)
indexed_items = len(self._memory_index)
embedded_items = sum(
1
for value in self.memory_store.values()
if isinstance(value, dict) and "value" in value and "ttl_seconds" in value
for entry in self._memory_index.values()
if isinstance(entry, dict)
and isinstance(entry.get("embedding"), list)
and bool(entry.get("embedding"))
)
dirty_items = len(self._memory_dirty_keys)
return {
"total_items": total_items,
"total_bytes": total_bytes,
"plugin_id": self._require_caller_plugin_id("memory.stats"),
"ttl_entries": ttl_entries,
"indexed_items": indexed_items,
"embedded_items": embedded_items,
"dirty_items": dirty_items,
}
def _register_memory_capabilities(self) -> None:
@@ -1072,17 +1605,22 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
async def _provider_embedding_get_embedding(
    self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
    """Handle ``provider.embedding.get_embedding`` with a mock vector.

    Args:
        _request_id: Request identifier (unused).
        payload: Request payload; ``text`` is the text to embed. The whole
            payload is also passed to ``_provider_entry`` to look up the
            provider.
        _token: Cancellation token (unused).

    Returns:
        dict[str, Any]: ``{"embedding": [...]}`` built by
        ``_mock_embedding_vector`` from the text and the provider's ``id``.
    """
    provider = self._provider_entry(
        payload,
        "provider.embedding.get_embedding",
        "embedding",
    )
    return {
        "embedding": _mock_embedding_vector(
            str(payload.get("text", "")),
            provider_id=str(provider.get("id", "")),
        )
    }
async def _provider_embedding_get_embeddings(
self, _request_id: str, payload: dict[str, Any], _token
) -> dict[str, Any]:
self._provider_entry(
provider = self._provider_entry(
payload,
"provider.embedding.get_embeddings",
"embedding",
@@ -1093,7 +1631,13 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
"provider.embedding.get_embeddings requires texts",
)
return {
"embeddings": [[0.0, 0.0, 0.0] for _ in texts],
"embeddings": [
_mock_embedding_vector(
str(text),
provider_id=str(provider.get("id", "")),
)
for text in texts
],
}
async def _provider_embedding_get_dim(
@@ -1104,7 +1648,7 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
"provider.embedding.get_dim",
"embedding",
)
return {"dim": 3}
return {"dim": _MOCK_EMBEDDING_DIM}
async def _provider_rerank_rerank(
self, _request_id: str, payload: dict[str, Any], _token

View File

@@ -217,6 +217,9 @@ class CapabilityRouter(BuiltinCapabilityRouterMixin):
self._registrations: dict[str, _CapabilityRegistration] = {}
self.db_store: dict[str, Any] = {}
self.memory_store: dict[str, dict[str, Any]] = {}
self._memory_index: dict[str, dict[str, Any]] = {}
self._memory_dirty_keys: set[str] = set()
self._memory_expires_at: dict[str, datetime | None] = {}
self.sent_messages: list[dict[str, Any]] = []
self.event_actions: list[dict[str, Any]] = []
self._event_streams: dict[str, dict[str, Any]] = {}

View File

@@ -838,7 +838,7 @@ class HandlerDispatcher:
if inspect.isawaitable(result):
await result
return
await Star().on_error(exc, event, ctx)
await Star.default_on_error(exc, event, ctx)
__all__ = ["CapabilityDispatcher", "HandlerDispatcher"]

View File

@@ -102,7 +102,9 @@ class Star(PluginKVStoreMixin):
options=options,
)
async def on_error(self, error: Exception, event, ctx) -> None:
@staticmethod
async def default_on_error(error: Exception, event, ctx) -> None:
del ctx
if isinstance(error, AstrBotError):
lines: list[str] = []
if error.retryable:
@@ -122,6 +124,9 @@ class Star(PluginKVStoreMixin):
await event.reply("出了点问题,请联系插件作者")
logger.error("handler 执行失败\n{}", traceback.format_exc())
async def on_error(self, error: Exception, event, ctx) -> None:
"""Handle a handler failure by delegating to ``Star.default_on_error`` with the same arguments."""
await Star.default_on_error(error, event, ctx)
@classmethod
def __astrbot_is_new_star__(cls) -> bool:
"""Marker hook; always returns ``True`` for this class."""
return True

View File

@@ -0,0 +1,277 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
import pytest
from astrbot_sdk._invocation_context import caller_plugin_scope
from astrbot_sdk.runtime.capability_router import CapabilityRouter
async def _call(
    router: CapabilityRouter,
    capability: str,
    payload: dict[str, object],
) -> dict[str, object]:
    """Run ``capability`` on ``router`` in non-streaming mode and return its result.

    The request id is derived from the capability name (``test-<capability>``)
    and a bare ``object()`` is passed as the cancel token.

    Raises:
        AssertionError: if ``router.execute`` does not return a ``dict``.
    """
    request_id = f"test-{capability}"
    # NOTE(review): a plain object stands in for the cancel token; this assumes
    # the capabilities exercised here never call methods on it -- confirm.
    placeholder_token = object()
    response = await router.execute(
        capability,
        payload,
        stream=False,
        cancel_token=placeholder_token,
        request_id=request_id,
    )
    assert isinstance(response, dict)
    return response
@pytest.mark.asyncio
async def test_memory_save_updates_sidecars_and_search() -> None:
    """``memory.save`` seeds the sidecar state; a later search embeds the entry.

    Directly after the save the index entry has text but no embedding and the
    key is marked dirty with no expiry. After a search without an explicit
    mode the hit is reported as ``hybrid``, the index entry carries an
    embedding tagged with the mock provider id, and the dirty mark is gone.
    """
    router = CapabilityRouter()
    await _call(
        router,
        "memory.save",
        {"key": "user-pref", "value": {"content": "user likes blue"}},
    )

    # State immediately after saving: stored, indexed, dirty, no TTL.
    expected_value = {"content": "user likes blue"}
    assert router.memory_store["user-pref"] == expected_value
    assert router._memory_index["user-pref"] == dict(
        text="user likes blue",
        embedding=None,
        provider_id=None,
    )
    assert "user-pref" in router._memory_dirty_keys
    assert "user-pref" not in router._memory_expires_at

    search = await _call(router, "memory.search", {"query": "likes blue"})
    hits = search["items"]
    assert len(hits) == 1
    hit = hits[0]
    assert hit["key"] == "user-pref"
    assert hit["value"] == {"content": "user likes blue"}
    assert hit["match_type"] == "hybrid"
    assert float(hit["score"]) > 0

    # State after searching: the entry is embedded and no longer dirty.
    refreshed_entry = router._memory_index["user-pref"]
    assert refreshed_entry["provider_id"] == "mock-embedding-provider"
    assert isinstance(refreshed_entry["embedding"], list)
    assert "user-pref" not in router._memory_dirty_keys
@pytest.mark.asyncio
async def test_memory_search_keyword_mode_keeps_dirty_embedding_state() -> None:
    """A ``keyword`` search finds the entry but leaves it unembedded and dirty."""
    router = CapabilityRouter()
    save_payload = {"key": "alpha-key", "value": {"content": "blue ocean memory"}}
    await _call(router, "memory.save", save_payload)

    # The query text occurs in the key only, not in the stored content.
    search_payload = {"query": "alpha", "mode": "keyword", "min_score": 0.95}
    outcome = await _call(router, "memory.search", search_payload)

    matched_keys = [hit["key"] for hit in outcome["items"]]
    assert matched_keys == ["alpha-key"]
    assert outcome["items"][0]["match_type"] == "keyword"

    # Keyword mode must not have touched the embedding side of the index.
    assert router._memory_index["alpha-key"]["embedding"] is None
    assert "alpha-key" in router._memory_dirty_keys
@pytest.mark.asyncio
async def test_memory_search_vector_mode_supports_ranking_and_limit() -> None:
    """A ``vector`` search with ``limit=1`` returns only the best-matching note."""
    router = CapabilityRouter()

    # Saved in this order: the fruit note first, then the ocean note.
    notes = {
        "fruit-note": "banana smoothie with mango",
        "ocean-note": "waves on the blue ocean",
    }
    for note_key, note_text in notes.items():
        await _call(
            router,
            "memory.save",
            {"key": note_key, "value": {"content": note_text}},
        )

    outcome = await _call(
        router,
        "memory.search",
        {"query": "banana smoothie", "mode": "vector", "limit": 1},
    )

    hits = outcome["items"]
    assert len(hits) == 1
    assert hits[0]["key"] == "fruit-note"
    assert hits[0]["match_type"] == "vector"
@pytest.mark.asyncio
async def test_memory_search_auto_falls_back_to_keyword_without_embedding_provider() -> (
    None
):
    """With the active embedding provider set to ``None``, ``auto`` yields keyword hits.

    The entry must also stay unembedded and remain in the dirty set.
    """
    router = CapabilityRouter()
    # Clear the active embedding provider before anything is saved.
    router._active_provider_ids["embedding"] = None

    save_payload = {"key": "alpha-key", "value": {"content": "blue ocean memory"}}
    await _call(router, "memory.save", save_payload)

    outcome = await _call(router, "memory.search", {"query": "alpha", "mode": "auto"})

    matched_keys = [hit["key"] for hit in outcome["items"]]
    assert matched_keys == ["alpha-key"]
    assert outcome["items"][0]["match_type"] == "keyword"
    assert router._memory_index["alpha-key"]["embedding"] is None
    assert "alpha-key" in router._memory_dirty_keys
@pytest.mark.asyncio
async def test_memory_search_reembeds_when_embedding_provider_changes() -> None:
    """Switching the active embedding provider makes the next search re-embed.

    A second mock embedding provider is registered, one entry is saved and
    searched with the original provider, then the active provider is switched
    and the same search is repeated. The index entry must then carry the new
    provider id and a different embedding.
    """
    alt_provider = {
        "id": "mock-embedding-provider-alt",
        "model": "mock-embedding-model-alt",
        "type": "mock",
        "provider_type": "embedding",
    }
    router = CapabilityRouter()
    # Catalog and config each get their own dict, as two separate registrations.
    router._provider_catalog["embedding"].append(dict(alt_provider))
    router._provider_configs["mock-embedding-provider-alt"] = {
        **alt_provider,
        "enable": True,
    }

    async def search_topic() -> dict[str, object]:
        # A fresh payload per call keeps the two searches independent.
        return await _call(router, "memory.search", {"query": "banana smoothie"})

    await _call(
        router,
        "memory.save",
        {"key": "topic", "value": {"content": "banana smoothie with mango"}},
    )

    first = await search_topic()
    # Snapshot the embedding so a later in-place change cannot alias it.
    first_embedding = list(router._memory_index["topic"]["embedding"])
    assert first["items"][0]["match_type"] == "hybrid"
    assert router._memory_index["topic"]["provider_id"] == "mock-embedding-provider"

    router._active_provider_ids["embedding"] = "mock-embedding-provider-alt"

    second = await search_topic()
    second_embedding = list(router._memory_index["topic"]["embedding"])
    assert second["items"][0]["match_type"] == "hybrid"
    assert (
        router._memory_index["topic"]["provider_id"] == "mock-embedding-provider-alt"
    )
    assert first_embedding != second_embedding
@pytest.mark.asyncio
async def test_memory_stats_reports_index_embedding_and_dirty_counts() -> None:
    """``memory.stats`` counters reflect the state before and after a search.

    Two entries are saved, one of them with a TTL. Before any search both are
    indexed and dirty with no embeddings; after one search both are embedded
    and nothing is dirty. Totals and the TTL count stay the same.
    """
    router = CapabilityRouter()

    async def read_stats() -> dict[str, object]:
        # Only the stats call runs inside the plugin caller scope.
        with caller_plugin_scope("test-plugin"):
            return await _call(router, "memory.stats", {})

    await _call(
        router,
        "memory.save",
        {"key": "a", "value": {"content": "alpha"}},
    )
    await _call(
        router,
        "memory.save_with_ttl",
        {"key": "b", "value": {"content": "beta"}, "ttl_seconds": 60},
    )

    before = await read_stats()
    expected_before = {
        "total_items": 2,
        "ttl_entries": 1,
        "indexed_items": 2,
        "embedded_items": 0,
        "dirty_items": 2,
    }
    for counter, expected in expected_before.items():
        assert before[counter] == expected

    await _call(router, "memory.search", {"query": "alpha"})

    after = await read_stats()
    expected_after = {
        "total_items": 2,
        "ttl_entries": 1,
        "indexed_items": 2,
        "embedded_items": 2,
        "dirty_items": 0,
    }
    for counter, expected in expected_after.items():
        assert after[counter] == expected
@pytest.mark.asyncio
async def test_memory_save_with_ttl_registers_expiry_and_purges_on_read() -> None:
    """A TTL save records an expiry; reading after expiry purges every trace.

    While the entry is live it is indexed, dirty, has an expiry and is
    searchable. Once the expiry is moved into the past, ``memory.get`` returns
    ``None`` and the key is gone from the store and from all sidecar maps.
    """
    router = CapabilityRouter()
    await _call(
        router,
        "memory.save_with_ttl",
        {
            "key": "temp-note",
            "value": {"content": "temporary note"},
            "ttl_seconds": 60,
        },
    )

    assert "temp-note" in router._memory_index
    assert "temp-note" in router._memory_dirty_keys
    assert router._memory_expires_at["temp-note"] is not None

    live_search = await _call(router, "memory.search", {"query": "temporary"})
    assert live_search["items"][0]["value"] == {"content": "temporary note"}

    # Force expiry by rewriting the deadline instead of waiting for the TTL.
    already_expired = datetime.now(timezone.utc) - timedelta(seconds=1)
    router._memory_expires_at["temp-note"] = already_expired

    fetched = await _call(router, "memory.get", {"key": "temp-note"})
    assert fetched == {"value": None}

    purged_containers = (
        router.memory_store,
        router._memory_index,
        router._memory_expires_at,
        router._memory_dirty_keys,
    )
    for container in purged_containers:
        assert "temp-note" not in container
@pytest.mark.asyncio
async def test_memory_get_many_unwraps_ttl_value_and_returns_none_after_expiry() -> (
    None
):
    """``memory.get_many`` returns the saved value for a live TTL entry.

    An unknown key is returned with ``None``, and so is the TTL entry once its
    expiry has been moved into the past.
    """
    router = CapabilityRouter()
    await _call(
        router,
        "memory.save_with_ttl",
        {
            "key": "session",
            "value": {"content": "active session"},
            "ttl_seconds": 60,
        },
    )

    live = await _call(router, "memory.get_many", {"keys": ["session", "missing"]})
    expected_live_items = [
        {"key": "session", "value": {"content": "active session"}},
        {"key": "missing", "value": None},
    ]
    assert live == {"items": expected_live_items}

    # Force expiry by rewriting the deadline instead of waiting for the TTL.
    already_expired = datetime.now(timezone.utc) - timedelta(seconds=1)
    router._memory_expires_at["session"] = already_expired

    expired = await _call(router, "memory.get_many", {"keys": ["session"]})
    assert expired == {"items": [{"key": "session", "value": None}]}
@pytest.mark.asyncio
async def test_memory_delete_many_clears_sidecars() -> None:
    """``memory.delete_many`` counts only existing keys and empties all sidecars.

    Two entries are saved (one with a TTL) and three keys are deleted; the
    unknown key ``"c"`` must not be counted.
    """
    router = CapabilityRouter()
    plain_entry = {"key": "a", "value": {"content": "alpha"}}
    ttl_entry = {"key": "b", "value": {"content": "beta"}, "ttl_seconds": 60}
    await _call(router, "memory.save", plain_entry)
    await _call(router, "memory.save_with_ttl", ttl_entry)

    outcome = await _call(router, "memory.delete_many", {"keys": ["a", "b", "c"]})

    assert outcome == {"deleted_count": 2}
    # The store and every sidecar structure must be left empty.
    assert router.memory_store == {}
    assert router._memory_index == {}
    assert router._memory_expires_at == {}
    assert router._memory_dirty_keys == set()