mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 18:20:16 +08:00
Merge commit 'e45bade147ff44b43860ecff12067309e59c151a' into feat/sdk-integration
This commit is contained in:
10
AGENTS.md
10
AGENTS.md
@@ -41,12 +41,14 @@ ruff check . --fix # 使用 ruff 检查并自动修复全局格式问题
|
||||
|
||||
如果修改了内容可能影响现有功能,请运行测试以确保没有引入错误:
|
||||
如果修改了bug或者更改了功能需要添加新的测试
|
||||
当前仓库已统一使用 `tests/` 目录,`tests_v4/` 不再作为新增测试入口。
|
||||
仓库当前没有 `run_tests.py`,请直接使用 `pytest`。
|
||||
|
||||
```bash
|
||||
python run_tests.py # 运行所有测试
|
||||
python run_tests.py -v # 详细输出
|
||||
python run_tests.py -k "test_peer" # 运行匹配模式的测试
|
||||
python run_tests.py --cov # 运行测试并生成覆盖率报告
|
||||
python -m pytest tests -q # 运行 tests 目录全部测试
|
||||
python -m pytest tests -v # 详细输出
|
||||
python -m pytest tests -k "test_context_register_task" # 运行匹配模式的测试
|
||||
python -m pytest tests --cov=astrbot_sdk # 运行测试并生成覆盖率报告
|
||||
```
|
||||
|
||||
## 设计原则
|
||||
|
||||
12
CLAUDE.md
12
CLAUDE.md
@@ -41,12 +41,14 @@ ruff check . --fix # 使用 ruff 检查并自动修复全局格式问题
|
||||
|
||||
如果修改了内容可能影响现有功能,请运行测试以确保没有引入错误:
|
||||
如果修改了bug或者更改了功能需要添加新的测试
|
||||
当前仓库已统一使用 `tests/` 目录,`tests_v4/` 不再作为新增测试入口。
|
||||
仓库当前没有 `run_tests.py`,请直接使用 `pytest`。
|
||||
|
||||
```bash
|
||||
python run_tests.py # 运行所有测试
|
||||
python run_tests.py -v # 详细输出
|
||||
python run_tests.py -k "test_peer" # 运行匹配模式的测试
|
||||
python run_tests.py --cov # 运行测试并生成覆盖率报告
|
||||
python -m pytest tests -q # 运行 tests 目录全部测试
|
||||
python -m pytest tests -v # 详细输出
|
||||
python -m pytest tests -k "test_context_register_task" # 运行匹配模式的测试
|
||||
python -m pytest tests --cov=astrbot_sdk # 运行测试并生成覆盖率报告
|
||||
```
|
||||
|
||||
## 设计原则
|
||||
@@ -57,4 +59,4 @@ python run_tests.py --cov # 运行测试并生成覆盖率报告
|
||||
---
|
||||
|
||||
# currentDate
|
||||
Today's date is 2026-03-14.
|
||||
Today's date is 2026-03-19.
|
||||
|
||||
@@ -6,6 +6,7 @@ import asyncio
|
||||
import typing
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, TextIO
|
||||
|
||||
from .context import CancelToken
|
||||
@@ -121,10 +122,77 @@ class InMemoryDB:
|
||||
|
||||
|
||||
class InMemoryMemory:
|
||||
def __init__(self, store: dict[str, dict[str, Any]]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
store: dict[str, dict[str, Any]],
|
||||
*,
|
||||
expires_at: dict[str, datetime | None] | None = None,
|
||||
) -> None:
|
||||
self._store = store
|
||||
self._expires_at = expires_at if expires_at is not None else {}
|
||||
|
||||
@staticmethod
|
||||
def _is_ttl_entry(value: Any) -> bool:
|
||||
"""判断测试 memory 值是否使用 TTL 包装结构。
|
||||
|
||||
Args:
|
||||
value: 待检查的存储值。
|
||||
|
||||
Returns:
|
||||
bool: 如果包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
|
||||
"""
|
||||
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
|
||||
|
||||
@classmethod
|
||||
def _search_text(cls, value: Any) -> str:
|
||||
"""提取测试用 memory.search 的匹配文本。
|
||||
|
||||
Args:
|
||||
value: 当前存储的 memory 值。
|
||||
|
||||
Returns:
|
||||
str: 用于本地测试搜索的文本内容。
|
||||
"""
|
||||
if cls._is_ttl_entry(value):
|
||||
value = value.get("value")
|
||||
if not isinstance(value, dict):
|
||||
return ""
|
||||
for field_name in ("embedding_text", "content", "summary", "title", "text"):
|
||||
item = value.get(field_name)
|
||||
if isinstance(item, str) and item.strip():
|
||||
return item.strip()
|
||||
return str(value)
|
||||
|
||||
def _is_expired(self, key: str) -> bool:
|
||||
"""判断测试 memory 键是否已经过期。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
bool: 如果当前时间已超过过期时间则返回 ``True``。
|
||||
"""
|
||||
expires_at = self._expires_at.get(key)
|
||||
return expires_at is not None and expires_at <= datetime.now(timezone.utc)
|
||||
|
||||
def _purge_if_expired(self, key: str) -> bool:
|
||||
"""在测试 helper 中清理已过期的 memory 条目。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
bool: 如果条目已过期并被清理则返回 ``True``。
|
||||
"""
|
||||
if not self._is_expired(key):
|
||||
return False
|
||||
self._store.pop(key, None)
|
||||
self._expires_at.pop(key, None)
|
||||
return True
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
if self._purge_if_expired(key):
|
||||
return default
|
||||
return self._store.get(key, default)
|
||||
|
||||
def save(self, key: str, value: dict[str, Any]) -> None:
|
||||
@@ -132,11 +200,14 @@ class InMemoryMemory:
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._store.pop(key, None)
|
||||
self._expires_at.pop(key, None)
|
||||
|
||||
def search(self, query: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
for key, value in self._store.items():
|
||||
if query in key or query in str(value):
|
||||
for key, value in list(self._store.items()):
|
||||
if self._purge_if_expired(key):
|
||||
continue
|
||||
if query in key or query in self._search_text(value):
|
||||
results.append({"key": key, "value": value})
|
||||
return results
|
||||
|
||||
@@ -200,7 +271,10 @@ class MockCapabilityRouter(CapabilityRouter):
|
||||
self._llm_stream_responses: list[str] = []
|
||||
super().__init__()
|
||||
self.db = InMemoryDB(self.db_store)
|
||||
self.memory = InMemoryMemory(self.memory_store)
|
||||
self.memory = InMemoryMemory(
|
||||
self.memory_store,
|
||||
expires_at=self._memory_expires_at,
|
||||
)
|
||||
|
||||
def list_dynamic_command_routes(self, plugin_id: str) -> list[dict[str, Any]]:
|
||||
return super().list_dynamic_command_routes(plugin_id)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from ._proxy import CapabilityProxy
|
||||
|
||||
@@ -48,30 +48,47 @@ class MemoryClient:
|
||||
"""
|
||||
self._proxy = proxy
|
||||
|
||||
async def search(self, query: str) -> list[dict[str, Any]]:
|
||||
"""Search memory items with the current bridge behavior.
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
mode: Literal["auto", "keyword", "vector", "hybrid"] = "auto",
|
||||
limit: int | None = None,
|
||||
min_score: float | None = None,
|
||||
provider_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""搜索记忆项。
|
||||
|
||||
The current core bridge matches `query` against the memory key and the
|
||||
serialized memory payload. It does not provide vector or semantic
|
||||
retrieval yet.
|
||||
|
||||
Returned items preserve the original `{"key": ..., "value": {...}}`
|
||||
shape. When `value` is a mapping, its fields are also exposed at the
|
||||
top level for compatibility with existing plugin examples.
|
||||
默认会在有 embedding provider 时执行 hybrid 检索,
|
||||
否则退化为关键词检索。返回结果包含 `score` 与 `match_type` 字段。
|
||||
|
||||
Args:
|
||||
query: 搜索查询文本
|
||||
mode: 搜索模式,支持 auto/keyword/vector/hybrid
|
||||
limit: 最大返回条数
|
||||
min_score: 最低分数阈值
|
||||
provider_id: 指定 embedding provider,默认使用当前激活的 provider
|
||||
|
||||
Returns:
|
||||
匹配的记忆项列表,按相关度排序
|
||||
|
||||
示例:
|
||||
# 搜索用户偏好相关的记忆
|
||||
results = await ctx.memory.search("用户喜欢什么颜色")
|
||||
results = await ctx.memory.search(
|
||||
"用户喜欢什么颜色",
|
||||
mode="hybrid",
|
||||
limit=5,
|
||||
)
|
||||
for item in results:
|
||||
print(item["key"], item["content"])
|
||||
print(item["key"], item["score"], item["match_type"])
|
||||
"""
|
||||
output = await self._proxy.call("memory.search", {"query": query})
|
||||
payload: dict[str, Any] = {"query": query, "mode": mode}
|
||||
if limit is not None:
|
||||
payload["limit"] = limit
|
||||
if min_score is not None:
|
||||
payload["min_score"] = min_score
|
||||
if provider_id is not None:
|
||||
payload["provider_id"] = provider_id
|
||||
output = await self._proxy.call("memory.search", payload)
|
||||
items = output.get("items")
|
||||
if not isinstance(items, (list, tuple)):
|
||||
return []
|
||||
@@ -96,16 +113,20 @@ class MemoryClient:
|
||||
key: 记忆项的唯一标识键
|
||||
value: 要存储的数据字典
|
||||
**extra: 额外的键值对,会合并到 value 中
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 value 不是 dict 类型
|
||||
|
||||
示例:
|
||||
# 保存用户偏好
|
||||
保存用户偏好
|
||||
await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
|
||||
|
||||
# 使用关键字参数
|
||||
使用关键字参数
|
||||
await ctx.memory.save("note", None, content="重要笔记", tags=["work"])
|
||||
|
||||
使用 embedding_text 显式指定检索文本
|
||||
await ctx.memory.save(
|
||||
"profile",
|
||||
{"name": "alice", "embedding_text": "Alice 喜欢蓝色和海边"},
|
||||
)
|
||||
"""
|
||||
if value is not None and not isinstance(value, dict):
|
||||
raise TypeError("memory.save 的 value 必须是 dict")
|
||||
@@ -230,16 +251,22 @@ class MemoryClient:
|
||||
async def stats(self) -> dict[str, Any]:
|
||||
"""获取记忆系统统计信息。
|
||||
|
||||
返回记忆系统的当前状态,包括总条目数等统计信息。
|
||||
返回记忆系统的当前状态,包括条目数、索引状态和脏索引数量。
|
||||
|
||||
Returns:
|
||||
统计信息字典,包含:
|
||||
- total_items: 总记忆条目数
|
||||
- total_bytes: 总占用字节数(可选)
|
||||
- ttl_entries: 带过期时间的条目数(可选)
|
||||
- indexed_items: 已建立检索索引的条目数(可选)
|
||||
- embedded_items: 已生成向量的条目数(可选)
|
||||
- dirty_items: 等待重建索引的条目数(可选)
|
||||
|
||||
示例:
|
||||
stats = await ctx.memory.stats()
|
||||
print(f"记忆库共有 {stats['total_items']} 条记录")
|
||||
if "embedded_items" in stats:
|
||||
print(f"其中 {stats['embedded_items']} 条已经向量化")
|
||||
"""
|
||||
output = await self._proxy.call("memory.stats", {})
|
||||
stats = {
|
||||
@@ -250,4 +277,10 @@ class MemoryClient:
|
||||
stats["plugin_id"] = output.get("plugin_id")
|
||||
if "ttl_entries" in output:
|
||||
stats["ttl_entries"] = output.get("ttl_entries")
|
||||
if "indexed_items" in output:
|
||||
stats["indexed_items"] = output.get("indexed_items")
|
||||
if "embedded_items" in output:
|
||||
stats["embedded_items"] = output.get("embedded_items")
|
||||
if "dirty_items" in output:
|
||||
stats["dirty_items"] = output.get("dirty_items")
|
||||
return stats
|
||||
|
||||
@@ -159,12 +159,12 @@ async for chunk in ctx.llm.stream_chat("讲一个故事"):
|
||||
|
||||
### search()
|
||||
|
||||
语义搜索记忆项。
|
||||
搜索记忆项。默认在有 embedding provider 时执行 hybrid 检索。
|
||||
|
||||
```python
|
||||
results = await ctx.memory.search("用户喜欢什么颜色")
|
||||
results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5)
|
||||
for item in results:
|
||||
print(item["key"], item["content"])
|
||||
print(item["key"], item["score"], item["match_type"])
|
||||
```
|
||||
|
||||
### save()
|
||||
@@ -177,6 +177,12 @@ await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
|
||||
|
||||
# 使用关键字参数
|
||||
await ctx.memory.save("note", None, content="重要笔记", tags=["work"])
|
||||
|
||||
# 显式指定检索文本
|
||||
await ctx.memory.save(
|
||||
"profile:alice",
|
||||
{"name": "Alice", "embedding_text": "Alice 喜欢蓝色和海边"},
|
||||
)
|
||||
```
|
||||
|
||||
### get()
|
||||
@@ -202,6 +208,15 @@ await ctx.memory.save_with_ttl(
|
||||
)
|
||||
```
|
||||
|
||||
### stats()
|
||||
|
||||
查看记忆索引状态。
|
||||
|
||||
```python
|
||||
stats = await ctx.memory.stats()
|
||||
print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items"))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Database 客户端
|
||||
|
||||
@@ -66,10 +66,12 @@ from astrbot_sdk.clients import MemoryClient
|
||||
|
||||
#### search()
|
||||
|
||||
语义搜索。
|
||||
搜索记忆。默认在有 embedding provider 时执行 hybrid 检索。
|
||||
|
||||
```python
|
||||
results = await ctx.memory.search("用户喜欢什么颜色")
|
||||
results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5)
|
||||
for item in results:
|
||||
print(item["key"], item["score"], item["match_type"])
|
||||
```
|
||||
|
||||
#### save()
|
||||
@@ -78,6 +80,10 @@ results = await ctx.memory.search("用户喜欢什么颜色")
|
||||
|
||||
```python
|
||||
await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
|
||||
await ctx.memory.save(
|
||||
"profile:alice",
|
||||
{"name": "Alice", "embedding_text": "Alice 喜欢蓝色和海边"},
|
||||
)
|
||||
```
|
||||
|
||||
#### get()
|
||||
@@ -108,6 +114,15 @@ await ctx.memory.save_with_ttl(
|
||||
await ctx.memory.delete("old_note")
|
||||
```
|
||||
|
||||
#### stats()
|
||||
|
||||
查看记忆索引状态。
|
||||
|
||||
```python
|
||||
stats = await ctx.memory.stats()
|
||||
print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items"))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. DBClient - KV 数据库客户端
|
||||
|
||||
@@ -142,25 +142,33 @@ from astrbot_sdk.clients import MemoryClient
|
||||
|
||||
### 方法
|
||||
|
||||
#### `search(query)`
|
||||
#### `search(query, *, mode="auto", limit=None, min_score=None, provider_id=None)`
|
||||
|
||||
语义搜索记忆项。
|
||||
搜索记忆项。默认会在存在 embedding provider 时执行 hybrid 检索,
|
||||
否则退化为关键词检索。
|
||||
|
||||
**参数**:
|
||||
- `query` (`str`): 搜索查询文本(自然语言)
|
||||
- `mode` (`Literal["auto", "keyword", "vector", "hybrid"]`): 搜索模式
|
||||
- `limit` (`int | None`): 最大返回条数
|
||||
- `min_score` (`float | None`): 最低分数阈值
|
||||
- `provider_id` (`str | None`): 指定 embedding provider
|
||||
|
||||
**返回**: `list[dict]` - 匹配的记忆项列表,按相关度排序
|
||||
**返回**: `list[dict]` - 匹配的记忆项列表。每项至少包含 `key`、`value`、`score`、`match_type`
|
||||
|
||||
**示例**:
|
||||
|
||||
```python
|
||||
# 搜索用户偏好
|
||||
results = await ctx.memory.search("用户喜欢什么颜色")
|
||||
results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5)
|
||||
for item in results:
|
||||
print(f"Key: {item['key']}, Content: {item['content']}")
|
||||
print(item["key"], item["score"], item["match_type"])
|
||||
|
||||
# 搜索对话摘要
|
||||
summaries = await ctx.memory.search("之前讨论过什么技术话题")
|
||||
# 强制使用关键词检索
|
||||
keyword_hits = await ctx.memory.search("blue", mode="keyword", min_score=0.9)
|
||||
|
||||
# 使用当前激活的 embedding provider 执行向量检索
|
||||
vector_hits = await ctx.memory.search("之前讨论过什么技术话题", mode="vector")
|
||||
```
|
||||
|
||||
---
|
||||
@@ -192,6 +200,16 @@ await ctx.memory.save(
|
||||
tags=["work"],
|
||||
timestamp="2024-01-01"
|
||||
)
|
||||
|
||||
# 显式指定检索文本
|
||||
await ctx.memory.save(
|
||||
"profile:alice",
|
||||
{
|
||||
"name": "Alice",
|
||||
"city": "Shanghai",
|
||||
"embedding_text": "Alice 喜欢蓝色、海边和摄影",
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
@@ -314,6 +332,12 @@ stats = await ctx.memory.stats()
|
||||
print(f"记忆库共有 {stats['total_items']} 条记录")
|
||||
if 'ttl_entries' in stats:
|
||||
print(f"其中 {stats['ttl_entries']} 条有过期时间")
|
||||
if 'indexed_items' in stats:
|
||||
print(f"已建立索引: {stats['indexed_items']}")
|
||||
if 'embedded_items' in stats:
|
||||
print(f"已向量化: {stats['embedded_items']}")
|
||||
if 'dirty_items' in stats:
|
||||
print(f"待重建索引: {stats['dirty_items']}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -210,12 +210,16 @@ async for chunk in ctx.llm.stream_chat("讲一个故事"):
|
||||
|
||||
##### `search()`
|
||||
|
||||
语义搜索。
|
||||
搜索记忆。默认在有 embedding provider 时执行 hybrid 检索。
|
||||
|
||||
```python
|
||||
results = await ctx.memory.search("用户喜欢什么颜色")
|
||||
results = await ctx.memory.search(
|
||||
"用户喜欢什么颜色",
|
||||
mode="hybrid",
|
||||
limit=5,
|
||||
)
|
||||
for item in results:
|
||||
print(item["key"], item["content"])
|
||||
print(item["key"], item["score"], item["match_type"])
|
||||
```
|
||||
|
||||
##### `save()`
|
||||
@@ -228,6 +232,15 @@ await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"})
|
||||
|
||||
# 使用关键字参数
|
||||
await ctx.memory.save("note", None, content="重要笔记", tags=["work"])
|
||||
|
||||
# 显式指定检索文本
|
||||
await ctx.memory.save(
|
||||
"profile:alice",
|
||||
{
|
||||
"name": "Alice",
|
||||
"embedding_text": "Alice 喜欢蓝色和海边",
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
##### `get()`
|
||||
@@ -261,6 +274,15 @@ await ctx.memory.save_with_ttl(
|
||||
await ctx.memory.delete("old_note")
|
||||
```
|
||||
|
||||
##### `stats()`
|
||||
|
||||
查看记忆索引状态。
|
||||
|
||||
```python
|
||||
stats = await ctx.memory.stats()
|
||||
print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items"))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. DB 客户端 (ctx.db)
|
||||
|
||||
@@ -75,11 +75,28 @@ LLM_STREAM_CHAT_OUTPUT_SCHEMA = _object_schema(
|
||||
required=("text",), text={"type": "string"}
|
||||
)
|
||||
MEMORY_SEARCH_INPUT_SCHEMA = _object_schema(
|
||||
required=("query",), query={"type": "string"}
|
||||
required=("query",),
|
||||
query={"type": "string"},
|
||||
mode={"type": "string", "enum": ["auto", "keyword", "vector", "hybrid"]},
|
||||
limit={"type": "integer", "minimum": 1},
|
||||
min_score={"type": "number"},
|
||||
provider_id={"type": "string"},
|
||||
)
|
||||
MEMORY_SEARCH_OUTPUT_SCHEMA = _object_schema(
|
||||
required=("items",),
|
||||
items={"type": "array", "items": {"type": "object"}},
|
||||
items={
|
||||
"type": "array",
|
||||
"items": _object_schema(
|
||||
required=("key", "value", "score", "match_type"),
|
||||
key={"type": "string"},
|
||||
value=_nullable({"type": "object"}),
|
||||
score={"type": "number"},
|
||||
match_type={
|
||||
"type": "string",
|
||||
"enum": ["keyword", "vector", "hybrid"],
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
MEMORY_SAVE_INPUT_SCHEMA = _object_schema(
|
||||
required=("key", "value"),
|
||||
@@ -133,6 +150,9 @@ MEMORY_STATS_OUTPUT_SCHEMA = _object_schema(
|
||||
total_bytes=_nullable({"type": "integer"}),
|
||||
plugin_id=_nullable({"type": "string"}),
|
||||
ttl_entries=_nullable({"type": "integer"}),
|
||||
indexed_items=_nullable({"type": "integer"}),
|
||||
embedded_items=_nullable({"type": "integer"}),
|
||||
dirty_items=_nullable({"type": "integer"}),
|
||||
)
|
||||
SYSTEM_GET_DATA_DIR_INPUT_SCHEMA = _object_schema()
|
||||
SYSTEM_GET_DATA_DIR_OUTPUT_SCHEMA = _object_schema(
|
||||
|
||||
@@ -20,10 +20,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -52,8 +55,61 @@ def _clone_chain_payload(value: Any) -> list[dict[str, Any]]:
|
||||
]
|
||||
|
||||
|
||||
_MOCK_EMBEDDING_DIM = 24
|
||||
|
||||
|
||||
def _embedding_terms(text: str) -> list[str]:
|
||||
"""为 mock embedding 构造稳定的分词结果。
|
||||
|
||||
Args:
|
||||
text: 待向量化的原始文本。
|
||||
|
||||
Returns:
|
||||
list[str]: 用于生成 mock 向量的词项列表。
|
||||
"""
|
||||
normalized = re.sub(r"\s+", " ", str(text).strip().casefold())
|
||||
compact = normalized.replace(" ", "")
|
||||
if not normalized:
|
||||
return []
|
||||
|
||||
terms = [word for word in re.findall(r"\w+", normalized, flags=re.UNICODE) if word]
|
||||
if compact:
|
||||
if len(compact) == 1:
|
||||
terms.append(compact)
|
||||
else:
|
||||
terms.extend(
|
||||
compact[index : index + 2] for index in range(len(compact) - 1)
|
||||
)
|
||||
terms.append(compact)
|
||||
return terms or [normalized]
|
||||
|
||||
|
||||
def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]:
|
||||
"""生成确定性的 mock embedding 向量。
|
||||
|
||||
Args:
|
||||
text: 待向量化的文本。
|
||||
provider_id: 当前使用的 embedding provider 标识。
|
||||
|
||||
Returns:
|
||||
list[float]: 归一化后的 mock 向量。
|
||||
"""
|
||||
values = [0.0] * _MOCK_EMBEDDING_DIM
|
||||
for term in _embedding_terms(text):
|
||||
digest = hashlib.sha256(f"{provider_id}:{term}".encode("utf-8")).digest()
|
||||
index = int.from_bytes(digest[:2], "big") % _MOCK_EMBEDDING_DIM
|
||||
values[index] += 1.0 + min(len(term), 8) * 0.05
|
||||
norm = math.sqrt(sum(value * value for value in values))
|
||||
if norm <= 0:
|
||||
return values
|
||||
return [value / norm for value in values]
|
||||
|
||||
|
||||
class _CapabilityRouterHost:
|
||||
memory_store: dict[str, dict[str, Any]]
|
||||
_memory_index: dict[str, dict[str, Any]]
|
||||
_memory_dirty_keys: set[str]
|
||||
_memory_expires_at: dict[str, datetime | None]
|
||||
db_store: dict[str, Any]
|
||||
sent_messages: list[dict[str, Any]]
|
||||
event_actions: list[dict[str, Any]]
|
||||
@@ -278,15 +334,471 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_ttl_memory_entry(value: Any) -> bool:
|
||||
"""判断存储值是否使用了 TTL 包装结构。
|
||||
|
||||
Args:
|
||||
value: 待检查的存储值。
|
||||
|
||||
Returns:
|
||||
bool: 如果值包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。
|
||||
"""
|
||||
return isinstance(value, dict) and "value" in value and "ttl_seconds" in value
|
||||
|
||||
@classmethod
|
||||
def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None:
|
||||
"""提取用于检索的原始 memory payload。
|
||||
|
||||
Args:
|
||||
stored: memory_store 中保存的原始值。
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: 解开 TTL 包装后的字典,无法解析时返回 ``None``。
|
||||
"""
|
||||
if not isinstance(stored, dict):
|
||||
return None
|
||||
if cls._is_ttl_memory_entry(stored):
|
||||
value = stored.get("value")
|
||||
return value if isinstance(value, dict) else None
|
||||
return stored
|
||||
|
||||
@classmethod
|
||||
def _extract_memory_text(cls, stored: Any) -> str:
|
||||
"""提取用于检索索引的首选文本。
|
||||
|
||||
Args:
|
||||
stored: memory_store 中保存的原始值。
|
||||
|
||||
Returns:
|
||||
str: 优先使用 ``embedding_text`` / ``content`` 等字段,兜底为 JSON 文本。
|
||||
"""
|
||||
value = cls._memory_value_for_search(stored)
|
||||
if not isinstance(value, dict):
|
||||
return ""
|
||||
for field_name in ("embedding_text", "content", "summary", "title", "text"):
|
||||
item = value.get(field_name)
|
||||
if isinstance(item, str) and item.strip():
|
||||
return item.strip()
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||
|
||||
@staticmethod
|
||||
def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None:
|
||||
"""将 TTL 秒数转换为 UTC 过期时间。
|
||||
|
||||
Args:
|
||||
ttl_seconds: TTL 秒数。
|
||||
|
||||
Returns:
|
||||
datetime | None: 绝对过期时间;当输入无效时返回 ``None``。
|
||||
"""
|
||||
try:
|
||||
ttl = int(ttl_seconds)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if ttl < 1:
|
||||
return None
|
||||
return datetime.now(timezone.utc) + timedelta(seconds=ttl)
|
||||
|
||||
@staticmethod
|
||||
def _memory_keyword_score(query: str, key: str, text: str) -> float:
|
||||
"""计算关键词匹配分数。
|
||||
|
||||
Args:
|
||||
query: 查询文本。
|
||||
key: memory 条目的键。
|
||||
text: 已索引的检索文本。
|
||||
|
||||
Returns:
|
||||
float: 基于键名和文本命中的粗粒度关键词分数。
|
||||
"""
|
||||
normalized_query = str(query).casefold()
|
||||
if not normalized_query:
|
||||
return 1.0
|
||||
normalized_key = str(key).casefold()
|
||||
normalized_text = str(text).casefold()
|
||||
if normalized_query in normalized_key:
|
||||
return 1.0
|
||||
if normalized_query in normalized_text:
|
||||
return 0.9
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(left: list[float], right: list[float]) -> float:
|
||||
"""计算两个向量之间的余弦相似度。
|
||||
|
||||
Args:
|
||||
left: 左侧向量。
|
||||
right: 右侧向量。
|
||||
|
||||
Returns:
|
||||
float: 余弦相似度;输入不合法时返回 ``0.0``。
|
||||
"""
|
||||
if not left or not right or len(left) != len(right):
|
||||
return 0.0
|
||||
left_norm = math.sqrt(sum(value * value for value in left))
|
||||
right_norm = math.sqrt(sum(value * value for value in right))
|
||||
if left_norm <= 0 or right_norm <= 0:
|
||||
return 0.0
|
||||
return sum(a * b for a, b in zip(left, right, strict=False)) / (
|
||||
left_norm * right_norm
|
||||
)
|
||||
|
||||
def _resolve_memory_embedding_provider_id(
|
||||
self,
|
||||
provider_id: Any,
|
||||
*,
|
||||
required: bool,
|
||||
) -> str | None:
|
||||
"""解析 memory.search 要使用的 embedding provider。
|
||||
|
||||
Args:
|
||||
provider_id: 调用方显式传入的 provider 标识。
|
||||
required: 当前检索模式是否强制要求 embedding provider。
|
||||
|
||||
Returns:
|
||||
str | None: 最终选中的 provider 标识;在非强制场景下允许返回 ``None``。
|
||||
"""
|
||||
normalized = str(provider_id).strip() if provider_id is not None else ""
|
||||
if normalized:
|
||||
self._provider_entry(
|
||||
{"provider_id": normalized},
|
||||
"memory.search",
|
||||
"embedding",
|
||||
)
|
||||
return normalized
|
||||
active_id = self._active_provider_ids.get("embedding")
|
||||
if active_id is not None:
|
||||
normalized_active = str(active_id).strip()
|
||||
if normalized_active:
|
||||
self._provider_entry(
|
||||
{"provider_id": normalized_active},
|
||||
"memory.search",
|
||||
"embedding",
|
||||
)
|
||||
return normalized_active
|
||||
if required:
|
||||
raise AstrBotError.invalid_input(
|
||||
"memory.search requires an embedding provider",
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]:
|
||||
"""将原始索引项规范化为内部统一结构。
|
||||
|
||||
Args:
|
||||
entry: 当前索引表中的原始项。
|
||||
text: 当前条目的索引文本。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 统一后的索引项,包含 ``text``、``embedding``、``provider_id``。
|
||||
"""
|
||||
if isinstance(entry, dict):
|
||||
return {
|
||||
"text": str(entry.get("text", text)),
|
||||
"embedding": (
|
||||
[float(item) for item in entry.get("embedding", [])]
|
||||
if isinstance(entry.get("embedding"), list)
|
||||
else None
|
||||
),
|
||||
"provider_id": (
|
||||
str(entry.get("provider_id")).strip()
|
||||
if entry.get("provider_id") is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
return {"text": text, "embedding": None, "provider_id": None}
|
||||
|
||||
def _clear_memory_sidecars(self, key: str) -> None:
|
||||
"""清理指定 memory 键对应的所有 sidecar 状态。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self._memory_index.pop(key, None)
|
||||
self._memory_expires_at.pop(key, None)
|
||||
self._memory_dirty_keys.discard(key)
|
||||
|
||||
def _delete_memory_entry(self, key: str) -> bool:
|
||||
"""删除 memory 条目并同步清理 sidecar 状态。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
bool: 条目存在并删除成功时返回 ``True``。
|
||||
"""
|
||||
deleted = self.memory_store.pop(key, None) is not None
|
||||
self._clear_memory_sidecars(key)
|
||||
return deleted
|
||||
|
||||
def _upsert_memory_sidecars(
|
||||
self,
|
||||
key: str,
|
||||
stored: dict[str, Any],
|
||||
*,
|
||||
expires_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""创建或更新单条 memory 的 sidecar 索引状态。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
stored: 需要建立索引的原始存储值。
|
||||
expires_at: 可选的绝对过期时间。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self._memory_index[key] = {
|
||||
"text": self._extract_memory_text(stored),
|
||||
"embedding": None,
|
||||
"provider_id": None,
|
||||
}
|
||||
if expires_at is None:
|
||||
self._memory_expires_at.pop(key, None)
|
||||
else:
|
||||
self._memory_expires_at[key] = expires_at
|
||||
self._memory_dirty_keys.add(key)
|
||||
|
||||
def _ensure_memory_sidecars(self, key: str, stored: Any) -> None:
|
||||
"""确保 sidecar 状态与当前存储值保持一致。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
stored: memory_store 中的当前存储值。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if not isinstance(stored, dict):
|
||||
return
|
||||
text = self._extract_memory_text(stored)
|
||||
existed = key in self._memory_index
|
||||
entry = self._memory_index_entry(self._memory_index.get(key), text=text)
|
||||
if entry["text"] != text:
|
||||
entry["text"] = text
|
||||
entry["embedding"] = None
|
||||
entry["provider_id"] = None
|
||||
self._memory_dirty_keys.add(key)
|
||||
self._memory_index[key] = entry
|
||||
if not existed:
|
||||
self._memory_dirty_keys.add(key)
|
||||
|
||||
def _is_memory_expired(self, key: str) -> bool:
|
||||
"""判断 memory 条目是否已过期。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
bool: 如果当前时间已超过记录的过期时间则返回 ``True``。
|
||||
"""
|
||||
expires_at = self._memory_expires_at.get(key)
|
||||
return expires_at is not None and expires_at <= datetime.now(timezone.utc)
|
||||
|
||||
def _purge_expired_memory_entry(self, key: str) -> bool:
|
||||
"""在单条 memory 已过期时立即清理它。
|
||||
|
||||
Args:
|
||||
key: memory 条目的键。
|
||||
|
||||
Returns:
|
||||
bool: 如果条目已过期并被成功清理则返回 ``True``。
|
||||
"""
|
||||
if not self._is_memory_expired(key):
|
||||
return False
|
||||
self._delete_memory_entry(key)
|
||||
return True
|
||||
|
||||
def _purge_expired_memory_entries(self) -> None:
|
||||
"""批量清理所有已跟踪的过期 TTL 条目。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for key in list(self._memory_expires_at):
|
||||
self._purge_expired_memory_entry(key)
|
||||
|
||||
async def _embedding_for_text(
|
||||
self,
|
||||
*,
|
||||
provider_id: str,
|
||||
text: str,
|
||||
) -> list[float]:
|
||||
"""通过 embedding capability 获取单条文本向量。
|
||||
|
||||
Args:
|
||||
provider_id: 使用的 embedding provider 标识。
|
||||
text: 待向量化的文本。
|
||||
|
||||
Returns:
|
||||
list[float]: provider 返回的向量;异常场景下返回空列表。
|
||||
"""
|
||||
output = await self._provider_embedding_get_embedding(
|
||||
"",
|
||||
{"provider_id": provider_id, "text": text},
|
||||
None,
|
||||
)
|
||||
embedding = output.get("embedding")
|
||||
if not isinstance(embedding, list):
|
||||
return []
|
||||
return [float(item) for item in embedding]
|
||||
|
||||
async def _embeddings_for_texts(
|
||||
self,
|
||||
*,
|
||||
provider_id: str,
|
||||
texts: list[str],
|
||||
) -> list[list[float]]:
|
||||
"""批量获取多条文本的 embedding 向量。
|
||||
|
||||
Args:
|
||||
provider_id: 使用的 embedding provider 标识。
|
||||
texts: 待向量化的文本列表。
|
||||
|
||||
Returns:
|
||||
list[list[float]]: 与输入顺序对应的向量列表。
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
output = await self._provider_embedding_get_embeddings(
|
||||
"",
|
||||
{"provider_id": provider_id, "texts": texts},
|
||||
None,
|
||||
)
|
||||
embeddings = output.get("embeddings")
|
||||
if not isinstance(embeddings, list):
|
||||
return []
|
||||
return [
|
||||
[float(value) for value in item]
|
||||
for item in embeddings
|
||||
if isinstance(item, list)
|
||||
]
|
||||
|
||||
async def _refresh_memory_embeddings(self, *, provider_id: str) -> None:
|
||||
"""刷新当前 provider 下脏或过期的 memory 向量索引。
|
||||
|
||||
Args:
|
||||
provider_id: 当前使用的 embedding provider 标识。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
keys_to_refresh: list[str] = []
|
||||
texts_to_refresh: list[str] = []
|
||||
for key, stored in self.memory_store.items():
|
||||
self._ensure_memory_sidecars(key, stored)
|
||||
entry = self._memory_index_entry(
|
||||
self._memory_index.get(key),
|
||||
text=self._extract_memory_text(stored),
|
||||
)
|
||||
should_refresh = (
|
||||
key in self._memory_dirty_keys
|
||||
or entry["embedding"] is None
|
||||
or entry["provider_id"] != provider_id
|
||||
)
|
||||
self._memory_index[key] = entry
|
||||
if should_refresh:
|
||||
keys_to_refresh.append(key)
|
||||
texts_to_refresh.append(str(entry["text"]))
|
||||
embeddings = await self._embeddings_for_texts(
|
||||
provider_id=provider_id,
|
||||
texts=texts_to_refresh,
|
||||
)
|
||||
for index, key in enumerate(keys_to_refresh):
|
||||
entry = self._memory_index_entry(
|
||||
self._memory_index.get(key),
|
||||
text=str(texts_to_refresh[index]),
|
||||
)
|
||||
entry["embedding"] = embeddings[index] if index < len(embeddings) else []
|
||||
entry["provider_id"] = provider_id
|
||||
self._memory_index[key] = entry
|
||||
self._memory_dirty_keys.discard(key)
|
||||
|
||||
async def _memory_search(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
query = str(payload.get("query", ""))
|
||||
items = [
|
||||
{"key": key, "value": value}
|
||||
for key, value in self.memory_store.items()
|
||||
if query in key or query in json.dumps(value, ensure_ascii=False)
|
||||
]
|
||||
mode = str(payload.get("mode", "auto")).strip().lower() or "auto"
|
||||
limit = self._optional_int(payload.get("limit"))
|
||||
min_score = (
|
||||
float(payload.get("min_score"))
|
||||
if payload.get("min_score") is not None
|
||||
else None
|
||||
)
|
||||
self._purge_expired_memory_entries()
|
||||
provider_id = self._resolve_memory_embedding_provider_id(
|
||||
payload.get("provider_id"),
|
||||
required=mode in {"vector", "hybrid"},
|
||||
)
|
||||
effective_mode = mode
|
||||
if effective_mode == "auto":
|
||||
effective_mode = "hybrid" if provider_id is not None else "keyword"
|
||||
query_embedding: list[float] | None = None
|
||||
if effective_mode in {"vector", "hybrid"}:
|
||||
if provider_id is None:
|
||||
raise AstrBotError.invalid_input(
|
||||
"memory.search requires an embedding provider",
|
||||
)
|
||||
await self._refresh_memory_embeddings(provider_id=provider_id)
|
||||
query_embedding = await self._embedding_for_text(
|
||||
provider_id=provider_id,
|
||||
text=query,
|
||||
)
|
||||
|
||||
items: list[dict[str, Any]] = []
|
||||
for key, value in self.memory_store.items():
|
||||
self._ensure_memory_sidecars(key, value)
|
||||
entry = self._memory_index_entry(
|
||||
self._memory_index.get(key),
|
||||
text=self._extract_memory_text(value),
|
||||
)
|
||||
text = str(entry.get("text", ""))
|
||||
keyword_score = self._memory_keyword_score(query, key, text)
|
||||
vector_score = 0.0
|
||||
if query_embedding is not None:
|
||||
embedding = entry.get("embedding")
|
||||
if isinstance(embedding, list):
|
||||
vector_score = max(
|
||||
0.0,
|
||||
self._cosine_similarity(query_embedding, embedding),
|
||||
)
|
||||
|
||||
if effective_mode == "keyword":
|
||||
score = keyword_score
|
||||
elif effective_mode == "vector":
|
||||
score = vector_score
|
||||
else:
|
||||
score = vector_score
|
||||
if keyword_score > 0:
|
||||
score = max(score, 0.4 + 0.6 * vector_score)
|
||||
if score <= 0:
|
||||
continue
|
||||
if min_score is not None and score < min_score:
|
||||
continue
|
||||
|
||||
if effective_mode == "keyword" or (keyword_score > 0 and vector_score <= 0):
|
||||
match_type = "keyword"
|
||||
elif effective_mode == "vector" or keyword_score <= 0:
|
||||
match_type = "vector"
|
||||
else:
|
||||
match_type = "hybrid"
|
||||
|
||||
items.append(
|
||||
{
|
||||
"key": key,
|
||||
"value": self._memory_value_for_search(value),
|
||||
"score": score,
|
||||
"match_type": match_type,
|
||||
}
|
||||
)
|
||||
items.sort(key=lambda item: (-float(item["score"]), str(item["key"])))
|
||||
if limit is not None and limit >= 0:
|
||||
items = items[:limit]
|
||||
return {"items": items}
|
||||
|
||||
async def _memory_save(
|
||||
@@ -297,17 +809,21 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
if not isinstance(value, dict):
|
||||
raise AstrBotError.invalid_input("memory.save 的 value 必须是 object")
|
||||
self.memory_store[key] = value
|
||||
self._upsert_memory_sidecars(key, value)
|
||||
return {}
|
||||
|
||||
async def _memory_get(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
return {"value": self.memory_store.get(str(payload.get("key", "")))}
|
||||
key = str(payload.get("key", ""))
|
||||
if self._purge_expired_memory_entry(key):
|
||||
return {"value": None}
|
||||
return {"value": self.memory_store.get(key)}
|
||||
|
||||
async def _memory_delete(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self.memory_store.pop(str(payload.get("key", "")), None)
|
||||
self._delete_memory_entry(str(payload.get("key", "")))
|
||||
return {}
|
||||
|
||||
async def _memory_save_with_ttl(
|
||||
@@ -320,7 +836,13 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
raise AstrBotError.invalid_input(
|
||||
"memory.save_with_ttl 的 value 必须是 object"
|
||||
)
|
||||
self.memory_store[key] = {"value": value, "ttl_seconds": ttl_seconds}
|
||||
stored = {"value": value, "ttl_seconds": ttl_seconds}
|
||||
self.memory_store[key] = stored
|
||||
self._upsert_memory_sidecars(
|
||||
key,
|
||||
stored,
|
||||
expires_at=self._memory_expiration_from_ttl(ttl_seconds),
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _memory_get_many(
|
||||
@@ -332,6 +854,9 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
keys = [str(item) for item in keys_payload]
|
||||
items = []
|
||||
for key in keys:
|
||||
if self._purge_expired_memory_entry(key):
|
||||
items.append({"key": key, "value": None})
|
||||
continue
|
||||
stored = self.memory_store.get(key)
|
||||
if (
|
||||
isinstance(stored, dict)
|
||||
@@ -353,28 +878,36 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
keys = [str(item) for item in keys_payload]
|
||||
deleted_count = 0
|
||||
for key in keys:
|
||||
if key in self.memory_store:
|
||||
del self.memory_store[key]
|
||||
if self._delete_memory_entry(key):
|
||||
deleted_count += 1
|
||||
return {"deleted_count": deleted_count}
|
||||
|
||||
async def _memory_stats(
|
||||
self, _request_id: str, _payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self._purge_expired_memory_entries()
|
||||
total_items = len(self.memory_store)
|
||||
total_bytes = sum(
|
||||
len(str(key)) + len(str(value)) for key, value in self.memory_store.items()
|
||||
)
|
||||
ttl_entries = sum(
|
||||
ttl_entries = len(self._memory_expires_at)
|
||||
indexed_items = len(self._memory_index)
|
||||
embedded_items = sum(
|
||||
1
|
||||
for value in self.memory_store.values()
|
||||
if isinstance(value, dict) and "value" in value and "ttl_seconds" in value
|
||||
for entry in self._memory_index.values()
|
||||
if isinstance(entry, dict)
|
||||
and isinstance(entry.get("embedding"), list)
|
||||
and bool(entry.get("embedding"))
|
||||
)
|
||||
dirty_items = len(self._memory_dirty_keys)
|
||||
return {
|
||||
"total_items": total_items,
|
||||
"total_bytes": total_bytes,
|
||||
"plugin_id": self._require_caller_plugin_id("memory.stats"),
|
||||
"ttl_entries": ttl_entries,
|
||||
"indexed_items": indexed_items,
|
||||
"embedded_items": embedded_items,
|
||||
"dirty_items": dirty_items,
|
||||
}
|
||||
|
||||
def _register_memory_capabilities(self) -> None:
|
||||
@@ -1072,17 +1605,22 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
async def _provider_embedding_get_embedding(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self._provider_entry(
|
||||
provider = self._provider_entry(
|
||||
payload,
|
||||
"provider.embedding.get_embedding",
|
||||
"embedding",
|
||||
)
|
||||
return {"embedding": [0.0, 0.0, 0.0]}
|
||||
return {
|
||||
"embedding": _mock_embedding_vector(
|
||||
str(payload.get("text", "")),
|
||||
provider_id=str(provider.get("id", "")),
|
||||
)
|
||||
}
|
||||
|
||||
async def _provider_embedding_get_embeddings(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
) -> dict[str, Any]:
|
||||
self._provider_entry(
|
||||
provider = self._provider_entry(
|
||||
payload,
|
||||
"provider.embedding.get_embeddings",
|
||||
"embedding",
|
||||
@@ -1093,7 +1631,13 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
"provider.embedding.get_embeddings requires texts",
|
||||
)
|
||||
return {
|
||||
"embeddings": [[0.0, 0.0, 0.0] for _ in texts],
|
||||
"embeddings": [
|
||||
_mock_embedding_vector(
|
||||
str(text),
|
||||
provider_id=str(provider.get("id", "")),
|
||||
)
|
||||
for text in texts
|
||||
],
|
||||
}
|
||||
|
||||
async def _provider_embedding_get_dim(
|
||||
@@ -1104,7 +1648,7 @@ class BuiltinCapabilityRouterMixin(_CapabilityRouterHost):
|
||||
"provider.embedding.get_dim",
|
||||
"embedding",
|
||||
)
|
||||
return {"dim": 3}
|
||||
return {"dim": _MOCK_EMBEDDING_DIM}
|
||||
|
||||
async def _provider_rerank_rerank(
|
||||
self, _request_id: str, payload: dict[str, Any], _token
|
||||
|
||||
@@ -217,6 +217,9 @@ class CapabilityRouter(BuiltinCapabilityRouterMixin):
|
||||
self._registrations: dict[str, _CapabilityRegistration] = {}
|
||||
self.db_store: dict[str, Any] = {}
|
||||
self.memory_store: dict[str, dict[str, Any]] = {}
|
||||
self._memory_index: dict[str, dict[str, Any]] = {}
|
||||
self._memory_dirty_keys: set[str] = set()
|
||||
self._memory_expires_at: dict[str, datetime | None] = {}
|
||||
self.sent_messages: list[dict[str, Any]] = []
|
||||
self.event_actions: list[dict[str, Any]] = []
|
||||
self._event_streams: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@@ -838,7 +838,7 @@ class HandlerDispatcher:
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
return
|
||||
await Star().on_error(exc, event, ctx)
|
||||
await Star.default_on_error(exc, event, ctx)
|
||||
|
||||
|
||||
__all__ = ["CapabilityDispatcher", "HandlerDispatcher"]
|
||||
|
||||
@@ -102,7 +102,9 @@ class Star(PluginKVStoreMixin):
|
||||
options=options,
|
||||
)
|
||||
|
||||
async def on_error(self, error: Exception, event, ctx) -> None:
|
||||
@staticmethod
|
||||
async def default_on_error(error: Exception, event, ctx) -> None:
|
||||
del ctx
|
||||
if isinstance(error, AstrBotError):
|
||||
lines: list[str] = []
|
||||
if error.retryable:
|
||||
@@ -122,6 +124,9 @@ class Star(PluginKVStoreMixin):
|
||||
await event.reply("出了点问题,请联系插件作者")
|
||||
logger.error("handler 执行失败\n{}", traceback.format_exc())
|
||||
|
||||
async def on_error(self, error: Exception, event, ctx) -> None:
|
||||
await Star.default_on_error(error, event, ctx)
|
||||
|
||||
@classmethod
|
||||
def __astrbot_is_new_star__(cls) -> bool:
|
||||
return True
|
||||
|
||||
277
tests/test_memory_runtime.py
Normal file
277
tests/test_memory_runtime.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot_sdk._invocation_context import caller_plugin_scope
|
||||
from astrbot_sdk.runtime.capability_router import CapabilityRouter
|
||||
|
||||
|
||||
async def _call(
|
||||
router: CapabilityRouter,
|
||||
capability: str,
|
||||
payload: dict[str, object],
|
||||
) -> dict[str, object]:
|
||||
result = await router.execute(
|
||||
capability,
|
||||
payload,
|
||||
stream=False,
|
||||
cancel_token=object(),
|
||||
request_id=f"test-{capability}",
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_save_updates_sidecars_and_search() -> None:
|
||||
router = CapabilityRouter()
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save",
|
||||
{"key": "user-pref", "value": {"content": "user likes blue"}},
|
||||
)
|
||||
|
||||
assert router.memory_store["user-pref"] == {"content": "user likes blue"}
|
||||
assert router._memory_index["user-pref"] == {
|
||||
"text": "user likes blue",
|
||||
"embedding": None,
|
||||
"provider_id": None,
|
||||
}
|
||||
assert "user-pref" in router._memory_dirty_keys
|
||||
assert "user-pref" not in router._memory_expires_at
|
||||
|
||||
result = await _call(router, "memory.search", {"query": "likes blue"})
|
||||
assert len(result["items"]) == 1
|
||||
item = result["items"][0]
|
||||
assert item["key"] == "user-pref"
|
||||
assert item["value"] == {"content": "user likes blue"}
|
||||
assert item["match_type"] == "hybrid"
|
||||
assert float(item["score"]) > 0
|
||||
assert router._memory_index["user-pref"]["provider_id"] == "mock-embedding-provider"
|
||||
assert isinstance(router._memory_index["user-pref"]["embedding"], list)
|
||||
assert "user-pref" not in router._memory_dirty_keys
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_search_keyword_mode_keeps_dirty_embedding_state() -> None:
|
||||
router = CapabilityRouter()
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save",
|
||||
{"key": "alpha-key", "value": {"content": "blue ocean memory"}},
|
||||
)
|
||||
|
||||
result = await _call(
|
||||
router,
|
||||
"memory.search",
|
||||
{"query": "alpha", "mode": "keyword", "min_score": 0.95},
|
||||
)
|
||||
|
||||
assert [item["key"] for item in result["items"]] == ["alpha-key"]
|
||||
assert result["items"][0]["match_type"] == "keyword"
|
||||
assert router._memory_index["alpha-key"]["embedding"] is None
|
||||
assert "alpha-key" in router._memory_dirty_keys
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_search_vector_mode_supports_ranking_and_limit() -> None:
|
||||
router = CapabilityRouter()
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save",
|
||||
{"key": "fruit-note", "value": {"content": "banana smoothie with mango"}},
|
||||
)
|
||||
await _call(
|
||||
router,
|
||||
"memory.save",
|
||||
{"key": "ocean-note", "value": {"content": "waves on the blue ocean"}},
|
||||
)
|
||||
|
||||
result = await _call(
|
||||
router,
|
||||
"memory.search",
|
||||
{"query": "banana smoothie", "mode": "vector", "limit": 1},
|
||||
)
|
||||
|
||||
assert len(result["items"]) == 1
|
||||
assert result["items"][0]["key"] == "fruit-note"
|
||||
assert result["items"][0]["match_type"] == "vector"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_search_auto_falls_back_to_keyword_without_embedding_provider() -> (
|
||||
None
|
||||
):
|
||||
router = CapabilityRouter()
|
||||
router._active_provider_ids["embedding"] = None
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save",
|
||||
{"key": "alpha-key", "value": {"content": "blue ocean memory"}},
|
||||
)
|
||||
|
||||
result = await _call(router, "memory.search", {"query": "alpha", "mode": "auto"})
|
||||
|
||||
assert [item["key"] for item in result["items"]] == ["alpha-key"]
|
||||
assert result["items"][0]["match_type"] == "keyword"
|
||||
assert router._memory_index["alpha-key"]["embedding"] is None
|
||||
assert "alpha-key" in router._memory_dirty_keys
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_search_reembeds_when_embedding_provider_changes() -> None:
|
||||
router = CapabilityRouter()
|
||||
router._provider_catalog["embedding"].append(
|
||||
{
|
||||
"id": "mock-embedding-provider-alt",
|
||||
"model": "mock-embedding-model-alt",
|
||||
"type": "mock",
|
||||
"provider_type": "embedding",
|
||||
}
|
||||
)
|
||||
router._provider_configs["mock-embedding-provider-alt"] = {
|
||||
"id": "mock-embedding-provider-alt",
|
||||
"model": "mock-embedding-model-alt",
|
||||
"type": "mock",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
}
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save",
|
||||
{"key": "topic", "value": {"content": "banana smoothie with mango"}},
|
||||
)
|
||||
|
||||
first = await _call(router, "memory.search", {"query": "banana smoothie"})
|
||||
first_embedding = list(router._memory_index["topic"]["embedding"])
|
||||
assert first["items"][0]["match_type"] == "hybrid"
|
||||
assert router._memory_index["topic"]["provider_id"] == "mock-embedding-provider"
|
||||
|
||||
router._active_provider_ids["embedding"] = "mock-embedding-provider-alt"
|
||||
|
||||
second = await _call(router, "memory.search", {"query": "banana smoothie"})
|
||||
second_embedding = list(router._memory_index["topic"]["embedding"])
|
||||
assert second["items"][0]["match_type"] == "hybrid"
|
||||
assert router._memory_index["topic"]["provider_id"] == "mock-embedding-provider-alt"
|
||||
assert first_embedding != second_embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_stats_reports_index_embedding_and_dirty_counts() -> None:
|
||||
router = CapabilityRouter()
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save",
|
||||
{"key": "a", "value": {"content": "alpha"}},
|
||||
)
|
||||
await _call(
|
||||
router,
|
||||
"memory.save_with_ttl",
|
||||
{"key": "b", "value": {"content": "beta"}, "ttl_seconds": 60},
|
||||
)
|
||||
|
||||
with caller_plugin_scope("test-plugin"):
|
||||
before = await _call(router, "memory.stats", {})
|
||||
assert before["total_items"] == 2
|
||||
assert before["ttl_entries"] == 1
|
||||
assert before["indexed_items"] == 2
|
||||
assert before["embedded_items"] == 0
|
||||
assert before["dirty_items"] == 2
|
||||
|
||||
await _call(router, "memory.search", {"query": "alpha"})
|
||||
|
||||
with caller_plugin_scope("test-plugin"):
|
||||
after = await _call(router, "memory.stats", {})
|
||||
assert after["total_items"] == 2
|
||||
assert after["ttl_entries"] == 1
|
||||
assert after["indexed_items"] == 2
|
||||
assert after["embedded_items"] == 2
|
||||
assert after["dirty_items"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_save_with_ttl_registers_expiry_and_purges_on_read() -> None:
|
||||
router = CapabilityRouter()
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save_with_ttl",
|
||||
{"key": "temp-note", "value": {"content": "temporary note"}, "ttl_seconds": 60},
|
||||
)
|
||||
|
||||
assert "temp-note" in router._memory_index
|
||||
assert "temp-note" in router._memory_dirty_keys
|
||||
assert router._memory_expires_at["temp-note"] is not None
|
||||
|
||||
search_result = await _call(router, "memory.search", {"query": "temporary"})
|
||||
assert search_result["items"][0]["value"] == {"content": "temporary note"}
|
||||
|
||||
router._memory_expires_at["temp-note"] = datetime.now(timezone.utc) - timedelta(
|
||||
seconds=1
|
||||
)
|
||||
|
||||
get_result = await _call(router, "memory.get", {"key": "temp-note"})
|
||||
assert get_result == {"value": None}
|
||||
assert "temp-note" not in router.memory_store
|
||||
assert "temp-note" not in router._memory_index
|
||||
assert "temp-note" not in router._memory_expires_at
|
||||
assert "temp-note" not in router._memory_dirty_keys
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_get_many_unwraps_ttl_value_and_returns_none_after_expiry() -> (
|
||||
None
|
||||
):
|
||||
router = CapabilityRouter()
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save_with_ttl",
|
||||
{"key": "session", "value": {"content": "active session"}, "ttl_seconds": 60},
|
||||
)
|
||||
|
||||
result = await _call(router, "memory.get_many", {"keys": ["session", "missing"]})
|
||||
assert result == {
|
||||
"items": [
|
||||
{"key": "session", "value": {"content": "active session"}},
|
||||
{"key": "missing", "value": None},
|
||||
]
|
||||
}
|
||||
|
||||
router._memory_expires_at["session"] = datetime.now(timezone.utc) - timedelta(
|
||||
seconds=1
|
||||
)
|
||||
|
||||
expired_result = await _call(router, "memory.get_many", {"keys": ["session"]})
|
||||
assert expired_result == {"items": [{"key": "session", "value": None}]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_delete_many_clears_sidecars() -> None:
|
||||
router = CapabilityRouter()
|
||||
|
||||
await _call(
|
||||
router,
|
||||
"memory.save",
|
||||
{"key": "a", "value": {"content": "alpha"}},
|
||||
)
|
||||
await _call(
|
||||
router,
|
||||
"memory.save_with_ttl",
|
||||
{"key": "b", "value": {"content": "beta"}, "ttl_seconds": 60},
|
||||
)
|
||||
|
||||
result = await _call(router, "memory.delete_many", {"keys": ["a", "b", "c"]})
|
||||
assert result == {"deleted_count": 2}
|
||||
assert router.memory_store == {}
|
||||
assert router._memory_index == {}
|
||||
assert router._memory_expires_at == {}
|
||||
assert router._memory_dirty_keys == set()
|
||||
Reference in New Issue
Block a user