mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-03 11:10:14 +08:00
Compare commits
2 Commits
fix/llm_to
...
codex/cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e3009e96a8 | ||
|
|
184c7cac5b |
@@ -130,6 +130,7 @@ class LLMSummaryCompressor:
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
token_counter: TokenCounter | None = None,
|
||||
max_recent_rounds: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
@@ -139,11 +140,18 @@ class LLMSummaryCompressor:
|
||||
exact context. Clamped to 0-0.3.
|
||||
instruction_text: Custom instruction for summary generation.
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
token_counter: Optional custom token counter.
|
||||
max_recent_rounds: Maximum exact recent rounds to preserve after
|
||||
summarization. If None, only the token ratio limits recent rounds.
|
||||
"""
|
||||
self.provider = provider
|
||||
self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3)
|
||||
self.compression_threshold = compression_threshold
|
||||
self.token_counter = token_counter or EstimateTokenCounter()
|
||||
self.max_recent_rounds = (
|
||||
None if max_recent_rounds is None else max(1, int(max_recent_rounds))
|
||||
)
|
||||
self.last_call_failed = False
|
||||
|
||||
self.instruction_text = instruction_text or (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
@@ -212,6 +220,8 @@ class LLMSummaryCompressor:
|
||||
"""
|
||||
from .round_utils import split_into_rounds
|
||||
|
||||
self.last_call_failed = False
|
||||
|
||||
rounds = split_into_rounds(messages)
|
||||
message_rounds = [
|
||||
[seg for seg in rnd if isinstance(seg, Message)] for rnd in rounds
|
||||
@@ -231,6 +241,14 @@ class LLMSummaryCompressor:
|
||||
old_rounds = old_rounds[:-1]
|
||||
recent_rounds = [latest_old_round, *recent_rounds]
|
||||
|
||||
if (
|
||||
self.max_recent_rounds is not None
|
||||
and len(recent_rounds) > self.max_recent_rounds
|
||||
):
|
||||
excess_count = len(recent_rounds) - self.max_recent_rounds
|
||||
old_rounds = old_rounds + recent_rounds[:excess_count]
|
||||
recent_rounds = recent_rounds[excess_count:]
|
||||
|
||||
if not old_rounds:
|
||||
if recent_rounds and messages and messages[-1].role == "user":
|
||||
return messages
|
||||
@@ -276,13 +294,19 @@ class LLMSummaryCompressor:
|
||||
response = await self.provider.text_chat(
|
||||
contexts=sanitized_summary_contexts,
|
||||
)
|
||||
if response.role == "err":
|
||||
logger.error(f"Failed to generate summary: {response.completion_text}")
|
||||
self.last_call_failed = True
|
||||
return messages
|
||||
summary_content = (response.completion_text or "").strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary: {e}")
|
||||
self.last_call_failed = True
|
||||
return messages
|
||||
|
||||
if not summary_content:
|
||||
logger.warning("LLM context compression returned an empty summary.")
|
||||
self.last_call_failed = True
|
||||
return messages
|
||||
|
||||
# Build result: system messages + summary pair + recent rounds
|
||||
|
||||
@@ -3,6 +3,7 @@ from astrbot import logger
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .config import ContextConfig
|
||||
from .round_utils import count_conversation_rounds
|
||||
from .token_counter import EstimateTokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
@@ -36,6 +37,11 @@ class ContextManager:
|
||||
keep_recent_ratio=config.llm_compress_keep_recent_ratio,
|
||||
instruction_text=config.llm_compress_instruction,
|
||||
token_counter=self.token_counter,
|
||||
max_recent_rounds=(
|
||||
max(1, config.enforce_max_turns - 1)
|
||||
if config.enforce_max_turns != -1
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.compressor = TruncateByTurnsCompressor(
|
||||
@@ -56,15 +62,33 @@ class ContextManager:
|
||||
try:
|
||||
result = messages
|
||||
|
||||
# 1. 基于轮次的截断 (Enforce max turns)
|
||||
if self.config.enforce_max_turns != -1:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
turn_count = count_conversation_rounds(result)
|
||||
if turn_count > self.config.enforce_max_turns:
|
||||
should_truncate_by_turns = True
|
||||
if isinstance(self.compressor, LLMSummaryCompressor):
|
||||
logger.debug(
|
||||
"Turn limit (%s) exceeded (%s turns), "
|
||||
"trying LLM summary compression first.",
|
||||
self.config.enforce_max_turns,
|
||||
turn_count,
|
||||
)
|
||||
compressed = await self.compressor(result)
|
||||
if self.compressor.last_call_failed or compressed == result:
|
||||
logger.warning(
|
||||
"LLM summary compression failed; falling back "
|
||||
"to turn-based truncation.",
|
||||
)
|
||||
else:
|
||||
result = compressed
|
||||
should_truncate_by_turns = False
|
||||
if should_truncate_by_turns:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
|
||||
# 2. 基于 token 的压缩
|
||||
if self.config.max_context_tokens > 0:
|
||||
total_tokens = self.token_counter.count_tokens(
|
||||
result, trusted_token_usage
|
||||
@@ -95,7 +119,17 @@ class ContextManager:
|
||||
"""
|
||||
logger.debug("Compress triggered, starting compression...")
|
||||
|
||||
messages = await self.compressor(messages)
|
||||
compressed = await self.compressor(messages)
|
||||
if isinstance(self.compressor, LLMSummaryCompressor):
|
||||
if self.compressor.last_call_failed:
|
||||
logger.warning(
|
||||
"LLM summary compression failed; falling back to hard "
|
||||
"truncation to keep the request within the token limit.",
|
||||
)
|
||||
else:
|
||||
messages = compressed
|
||||
else:
|
||||
messages = compressed
|
||||
|
||||
# double check
|
||||
tokens_after_summary = self.token_counter.count_tokens(messages)
|
||||
@@ -113,9 +147,23 @@ class ContextManager:
|
||||
messages, tokens_after_summary, self.config.max_context_tokens
|
||||
):
|
||||
logger.info(
|
||||
"Context still exceeds max tokens after compression, applying halving truncation..."
|
||||
"Context still exceeds max tokens after compression, applying hard truncation..."
|
||||
)
|
||||
# still need compress, truncate by half
|
||||
messages = self.truncator.truncate_by_halving(messages)
|
||||
while self.compressor.should_compress(
|
||||
messages, tokens_after_summary, self.config.max_context_tokens
|
||||
):
|
||||
truncated = self.truncator.truncate_by_dropping_oldest_turns(
|
||||
messages,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
if truncated == messages:
|
||||
truncated = self.truncator.truncate_by_halving(messages)
|
||||
if truncated == messages:
|
||||
break
|
||||
next_tokens = self.token_counter.count_tokens(truncated)
|
||||
if next_tokens >= tokens_after_summary:
|
||||
break
|
||||
messages = truncated
|
||||
tokens_after_summary = next_tokens
|
||||
|
||||
return messages
|
||||
|
||||
@@ -35,6 +35,22 @@ def split_into_rounds(
|
||||
return rounds
|
||||
|
||||
|
||||
def count_conversation_rounds(contexts: Sequence[RoundSegment]) -> int:
|
||||
"""Count logical user conversation rounds.
|
||||
|
||||
Args:
|
||||
contexts: Flat message contexts.
|
||||
|
||||
Returns:
|
||||
Number of rounds that contain a user message.
|
||||
"""
|
||||
return sum(
|
||||
1
|
||||
for round_segments in split_into_rounds(contexts)
|
||||
if any(_segment_role(seg) == "user" for seg in round_segments)
|
||||
)
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
if isinstance(content, list):
|
||||
normalized = [
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from ..message import Message
|
||||
from .round_utils import split_into_rounds
|
||||
|
||||
|
||||
class ContextTruncator:
|
||||
@@ -120,15 +121,25 @@ class ContextTruncator:
|
||||
return messages
|
||||
|
||||
system_messages, non_system_messages = self._split_system_rest(messages)
|
||||
rounds = split_into_rounds(non_system_messages)
|
||||
|
||||
if len(non_system_messages) // 2 <= keep_most_recent_turns:
|
||||
round_count = sum(
|
||||
1
|
||||
for round_segments in rounds
|
||||
if any(segment.role == "user" for segment in round_segments)
|
||||
)
|
||||
if round_count <= keep_most_recent_turns:
|
||||
return messages
|
||||
|
||||
num_to_keep = keep_most_recent_turns - drop_turns + 1
|
||||
if num_to_keep <= 0:
|
||||
truncated_contexts = []
|
||||
else:
|
||||
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
|
||||
truncated_contexts = [
|
||||
segment
|
||||
for round_segments in rounds[-num_to_keep:]
|
||||
for segment in round_segments
|
||||
]
|
||||
|
||||
# Find the first user message
|
||||
index = next(
|
||||
@@ -153,11 +164,21 @@ class ContextTruncator:
|
||||
return messages
|
||||
|
||||
system_messages, non_system_messages = self._split_system_rest(messages)
|
||||
rounds = split_into_rounds(non_system_messages)
|
||||
|
||||
if len(non_system_messages) // 2 <= drop_turns:
|
||||
round_count = sum(
|
||||
1
|
||||
for round_segments in rounds
|
||||
if any(segment.role == "user" for segment in round_segments)
|
||||
)
|
||||
if round_count <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = non_system_messages[drop_turns * 2 :]
|
||||
truncated_non_system = [
|
||||
segment
|
||||
for round_segments in rounds[drop_turns:]
|
||||
for segment in round_segments
|
||||
]
|
||||
|
||||
# Find the first user message
|
||||
index = next(
|
||||
|
||||
@@ -12,6 +12,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from astrbot.core.agent.context.config import ContextConfig
|
||||
from astrbot.core.agent.context.manager import ContextManager
|
||||
from astrbot.core.agent.context.round_utils import count_conversation_rounds
|
||||
from astrbot.core.agent.message import AudioURLPart, ImageURLPart, Message, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
@@ -42,6 +43,26 @@ class MockProvider:
|
||||
return MagicMock(id="test_provider", type="openai")
|
||||
|
||||
|
||||
class MessageCountTokenCounter:
|
||||
"""Token counter that assigns a fixed cost to each message."""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
"""Count tokens by message count for deterministic tests.
|
||||
|
||||
Args:
|
||||
messages: The messages to count.
|
||||
trusted_token_usage: A trusted token count to return when present.
|
||||
|
||||
Returns:
|
||||
The deterministic token count.
|
||||
"""
|
||||
if trusted_token_usage > 0:
|
||||
return trusted_token_usage
|
||||
return len(messages) * 100
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Test suite for ContextManager."""
|
||||
|
||||
@@ -467,6 +488,74 @@ class TestContextManager:
|
||||
assert len(system_msgs) >= 1
|
||||
assert system_msgs[0].content == "System instruction"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_enforce_max_turns_uses_summary_first(self):
|
||||
"""LLM strategy should summarize before falling back to turn truncation."""
|
||||
provider = MockProvider()
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=1,
|
||||
llm_compress_provider=provider, # type: ignore[arg-type]
|
||||
llm_compress_keep_recent_ratio=0,
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
messages = [
|
||||
self.create_message("user", "First"),
|
||||
self.create_message("assistant", "First answer"),
|
||||
self.create_message("user", "Second"),
|
||||
self.create_message("assistant", "Second answer"),
|
||||
self.create_message("user", "Continue"),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert provider.last_text_chat_kwargs is not None
|
||||
assert len(result) < len(messages)
|
||||
assert result[-1] is messages[-1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_enforce_max_turns_caps_recent_rounds(self):
|
||||
"""LLM summary should cap exact recent rounds for max-turn enforcement."""
|
||||
provider = MockProvider()
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=2,
|
||||
llm_compress_provider=provider, # type: ignore[arg-type]
|
||||
llm_compress_keep_recent_ratio=0.3,
|
||||
custom_token_counter=MessageCountTokenCounter(),
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
messages = self.create_messages(20)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert provider.last_text_chat_kwargs is not None
|
||||
assert count_conversation_rounds(result) <= config.enforce_max_turns
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_counts_tool_chain_as_one_round(self):
|
||||
"""Tool messages in one round should not inflate turn count."""
|
||||
config = ContextConfig(enforce_max_turns=1, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
messages = [
|
||||
self.create_message("user", "Run a tool"),
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Calling tool",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "lookup", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
Message(role="tool", content="Tool result", tool_call_id="call_1"),
|
||||
self.create_message("assistant", "Done"),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert result == messages
|
||||
|
||||
# ==================== Token-based Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -966,6 +1055,27 @@ class TestContextManager:
|
||||
# Should have been compressed
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_until_token_threshold(self):
|
||||
"""Failed LLM compression should hard truncate until tokens are acceptable."""
|
||||
mock_provider = MockProvider()
|
||||
mock_provider.text_chat = AsyncMock(
|
||||
return_value=LLMResponse(role="err", completion_text="compress failed")
|
||||
)
|
||||
config = ContextConfig(
|
||||
max_context_tokens=300,
|
||||
truncate_turns=1,
|
||||
llm_compress_provider=mock_provider, # type: ignore[arg-type]
|
||||
custom_token_counter=MessageCountTokenCounter(),
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
messages = self.create_messages(10)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
assert manager.token_counter.count_tokens(result) <= 246
|
||||
|
||||
# ==================== split_into_rounds Tests ====================
|
||||
|
||||
def test_split_rounds_ensures_user_start(self):
|
||||
|
||||
@@ -134,6 +134,32 @@ class TestContextTruncator:
|
||||
assert len(result) == 6
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_counts_tool_chain_as_one_round(self):
|
||||
"""Tool calls/results inside one round should not count as extra turns."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run a tool"),
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Calling tool",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "lookup", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
Message(role="tool", content="Tool result", tool_call_id="call_1"),
|
||||
self.create_message("assistant", "Done"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=1, drop_turns=1
|
||||
)
|
||||
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_ensures_user_first(self):
|
||||
"""Test that truncate_by_turns ensures user message comes first."""
|
||||
truncator = ContextTruncator()
|
||||
|
||||
Reference in New Issue
Block a user