Compare commits

...

2 Commits

Author SHA1 Message Date
Soulter
e3009e96a8 fix: cap llm compression recent rounds 2026-06-24 23:08:38 +08:00
Soulter
184c7cac5b fix: prefer llm context compression fallback 2026-06-24 23:02:18 +08:00
6 changed files with 260 additions and 15 deletions

View File

@@ -130,6 +130,7 @@ class LLMSummaryCompressor:
instruction_text: str | None = None,
compression_threshold: float = 0.82,
token_counter: TokenCounter | None = None,
max_recent_rounds: int | None = None,
) -> None:
"""Initialize the LLM summary compressor.
@@ -139,11 +140,18 @@ class LLMSummaryCompressor:
exact context. Clamped to 0-0.3.
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
token_counter: Optional custom token counter.
max_recent_rounds: Maximum exact recent rounds to preserve after
summarization. If None, only the token ratio limits recent rounds.
"""
self.provider = provider
self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3)
self.compression_threshold = compression_threshold
self.token_counter = token_counter or EstimateTokenCounter()
self.max_recent_rounds = (
None if max_recent_rounds is None else max(1, int(max_recent_rounds))
)
self.last_call_failed = False
self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
@@ -212,6 +220,8 @@ class LLMSummaryCompressor:
"""
from .round_utils import split_into_rounds
self.last_call_failed = False
rounds = split_into_rounds(messages)
message_rounds = [
[seg for seg in rnd if isinstance(seg, Message)] for rnd in rounds
@@ -231,6 +241,14 @@ class LLMSummaryCompressor:
old_rounds = old_rounds[:-1]
recent_rounds = [latest_old_round, *recent_rounds]
if (
self.max_recent_rounds is not None
and len(recent_rounds) > self.max_recent_rounds
):
excess_count = len(recent_rounds) - self.max_recent_rounds
old_rounds = old_rounds + recent_rounds[:excess_count]
recent_rounds = recent_rounds[excess_count:]
if not old_rounds:
if recent_rounds and messages and messages[-1].role == "user":
return messages
@@ -276,13 +294,19 @@ class LLMSummaryCompressor:
response = await self.provider.text_chat(
contexts=sanitized_summary_contexts,
)
if response.role == "err":
logger.error(f"Failed to generate summary: {response.completion_text}")
self.last_call_failed = True
return messages
summary_content = (response.completion_text or "").strip()
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
self.last_call_failed = True
return messages
if not summary_content:
logger.warning("LLM context compression returned an empty summary.")
self.last_call_failed = True
return messages
# Build result: system messages + summary pair + recent rounds

View File

@@ -3,6 +3,7 @@ from astrbot import logger
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .round_utils import count_conversation_rounds
from .token_counter import EstimateTokenCounter
from .truncator import ContextTruncator
@@ -36,6 +37,11 @@ class ContextManager:
keep_recent_ratio=config.llm_compress_keep_recent_ratio,
instruction_text=config.llm_compress_instruction,
token_counter=self.token_counter,
max_recent_rounds=(
max(1, config.enforce_max_turns - 1)
if config.enforce_max_turns != -1
else None
),
)
else:
self.compressor = TruncateByTurnsCompressor(
@@ -56,15 +62,33 @@ class ContextManager:
try:
result = messages
# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
turn_count = count_conversation_rounds(result)
if turn_count > self.config.enforce_max_turns:
should_truncate_by_turns = True
if isinstance(self.compressor, LLMSummaryCompressor):
logger.debug(
"Turn limit (%s) exceeded (%s turns), "
"trying LLM summary compression first.",
self.config.enforce_max_turns,
turn_count,
)
compressed = await self.compressor(result)
if self.compressor.last_call_failed or compressed == result:
logger.warning(
"LLM summary compression failed; falling back "
"to turn-based truncation.",
)
else:
result = compressed
should_truncate_by_turns = False
if should_truncate_by_turns:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
@@ -95,7 +119,17 @@ class ContextManager:
"""
logger.debug("Compress triggered, starting compression...")
messages = await self.compressor(messages)
compressed = await self.compressor(messages)
if isinstance(self.compressor, LLMSummaryCompressor):
if self.compressor.last_call_failed:
logger.warning(
"LLM summary compression failed; falling back to hard "
"truncation to keep the request within the token limit.",
)
else:
messages = compressed
else:
messages = compressed
# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
@@ -113,9 +147,23 @@ class ContextManager:
messages, tokens_after_summary, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
"Context still exceeds max tokens after compression, applying hard truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
while self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
):
truncated = self.truncator.truncate_by_dropping_oldest_turns(
messages,
drop_turns=self.config.truncate_turns,
)
if truncated == messages:
truncated = self.truncator.truncate_by_halving(messages)
if truncated == messages:
break
next_tokens = self.token_counter.count_tokens(truncated)
if next_tokens >= tokens_after_summary:
break
messages = truncated
tokens_after_summary = next_tokens
return messages

View File

@@ -35,6 +35,22 @@ def split_into_rounds(
return rounds
def count_conversation_rounds(contexts: Sequence[RoundSegment]) -> int:
"""Count logical user conversation rounds.
Args:
contexts: Flat message contexts.
Returns:
Number of rounds that contain a user message.
"""
return sum(
1
for round_segments in split_into_rounds(contexts)
if any(_segment_role(seg) == "user" for seg in round_segments)
)
def _content_to_text(content: Any) -> str:
if isinstance(content, list):
normalized = [

View File

@@ -1,4 +1,5 @@
from ..message import Message
from .round_utils import split_into_rounds
class ContextTruncator:
@@ -120,15 +121,25 @@ class ContextTruncator:
return messages
system_messages, non_system_messages = self._split_system_rest(messages)
rounds = split_into_rounds(non_system_messages)
if len(non_system_messages) // 2 <= keep_most_recent_turns:
round_count = sum(
1
for round_segments in rounds
if any(segment.role == "user" for segment in round_segments)
)
if round_count <= keep_most_recent_turns:
return messages
num_to_keep = keep_most_recent_turns - drop_turns + 1
if num_to_keep <= 0:
truncated_contexts = []
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
truncated_contexts = [
segment
for round_segments in rounds[-num_to_keep:]
for segment in round_segments
]
# Find the first user message
index = next(
@@ -153,11 +164,21 @@ class ContextTruncator:
return messages
system_messages, non_system_messages = self._split_system_rest(messages)
rounds = split_into_rounds(non_system_messages)
if len(non_system_messages) // 2 <= drop_turns:
round_count = sum(
1
for round_segments in rounds
if any(segment.role == "user" for segment in round_segments)
)
if round_count <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
truncated_non_system = [
segment
for round_segments in rounds[drop_turns:]
for segment in round_segments
]
# Find the first user message
index = next(

View File

@@ -12,6 +12,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from astrbot.core.agent.context.config import ContextConfig
from astrbot.core.agent.context.manager import ContextManager
from astrbot.core.agent.context.round_utils import count_conversation_rounds
from astrbot.core.agent.message import AudioURLPart, ImageURLPart, Message, TextPart
from astrbot.core.provider.entities import LLMResponse
@@ -42,6 +43,26 @@ class MockProvider:
return MagicMock(id="test_provider", type="openai")
class MessageCountTokenCounter:
"""Token counter that assigns a fixed cost to each message."""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
"""Count tokens by message count for deterministic tests.
Args:
messages: The messages to count.
trusted_token_usage: A trusted token count to return when present.
Returns:
The deterministic token count.
"""
if trusted_token_usage > 0:
return trusted_token_usage
return len(messages) * 100
class TestContextManager:
"""Test suite for ContextManager."""
@@ -467,6 +488,74 @@ class TestContextManager:
assert len(system_msgs) >= 1
assert system_msgs[0].content == "System instruction"
@pytest.mark.asyncio
async def test_llm_enforce_max_turns_uses_summary_first(self):
"""LLM strategy should summarize before falling back to turn truncation."""
provider = MockProvider()
config = ContextConfig(
enforce_max_turns=1,
llm_compress_provider=provider, # type: ignore[arg-type]
llm_compress_keep_recent_ratio=0,
)
manager = ContextManager(config)
messages = [
self.create_message("user", "First"),
self.create_message("assistant", "First answer"),
self.create_message("user", "Second"),
self.create_message("assistant", "Second answer"),
self.create_message("user", "Continue"),
]
result = await manager.process(messages)
assert provider.last_text_chat_kwargs is not None
assert len(result) < len(messages)
assert result[-1] is messages[-1]
@pytest.mark.asyncio
async def test_llm_enforce_max_turns_caps_recent_rounds(self):
"""LLM summary should cap exact recent rounds for max-turn enforcement."""
provider = MockProvider()
config = ContextConfig(
enforce_max_turns=2,
llm_compress_provider=provider, # type: ignore[arg-type]
llm_compress_keep_recent_ratio=0.3,
custom_token_counter=MessageCountTokenCounter(),
)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert provider.last_text_chat_kwargs is not None
assert count_conversation_rounds(result) <= config.enforce_max_turns
@pytest.mark.asyncio
async def test_enforce_max_turns_counts_tool_chain_as_one_round(self):
"""Tool messages in one round should not inflate turn count."""
config = ContextConfig(enforce_max_turns=1, truncate_turns=1)
manager = ContextManager(config)
messages = [
self.create_message("user", "Run a tool"),
Message(
role="assistant",
content="Calling tool",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {"name": "lookup", "arguments": "{}"},
}
],
),
Message(role="tool", content="Tool result", tool_call_id="call_1"),
self.create_message("assistant", "Done"),
]
result = await manager.process(messages)
assert result == messages
# ==================== Token-based Compression Tests ====================
@pytest.mark.asyncio
@@ -966,6 +1055,27 @@ class TestContextManager:
# Should have been compressed
assert len(result) <= len(messages)
@pytest.mark.asyncio
async def test_llm_failure_falls_back_until_token_threshold(self):
"""Failed LLM compression should hard truncate until tokens are acceptable."""
mock_provider = MockProvider()
mock_provider.text_chat = AsyncMock(
return_value=LLMResponse(role="err", completion_text="compress failed")
)
config = ContextConfig(
max_context_tokens=300,
truncate_turns=1,
llm_compress_provider=mock_provider, # type: ignore[arg-type]
custom_token_counter=MessageCountTokenCounter(),
)
manager = ContextManager(config)
messages = self.create_messages(10)
result = await manager.process(messages)
assert len(result) == 2
assert manager.token_counter.count_tokens(result) <= 246
# ==================== split_into_rounds Tests ====================
def test_split_rounds_ensures_user_start(self):

View File

@@ -134,6 +134,32 @@ class TestContextTruncator:
assert len(result) == 6
assert result == messages
def test_truncate_by_turns_counts_tool_chain_as_one_round(self):
"""Tool calls/results inside one round should not count as extra turns."""
truncator = ContextTruncator()
messages = [
self.create_message("user", "Run a tool"),
Message(
role="assistant",
content="Calling tool",
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {"name": "lookup", "arguments": "{}"},
}
],
),
Message(role="tool", content="Tool result", tool_call_id="call_1"),
self.create_message("assistant", "Done"),
]
result = truncator.truncate_by_turns(
messages, keep_most_recent_turns=1, drop_turns=1
)
assert result == messages
def test_truncate_by_turns_ensures_user_first(self):
"""Test that truncate_by_turns ensures user message comes first."""
truncator = ContextTruncator()