fix: cap llm compression recent rounds

This commit is contained in:
Soulter
2026-06-24 23:08:38 +08:00
parent 184c7cac5b
commit e3009e96a8
4 changed files with 55 additions and 9 deletions

View File

@@ -130,6 +130,7 @@ class LLMSummaryCompressor:
instruction_text: str | None = None,
compression_threshold: float = 0.82,
token_counter: TokenCounter | None = None,
max_recent_rounds: int | None = None,
) -> None:
"""Initialize the LLM summary compressor.
@@ -139,11 +140,17 @@ class LLMSummaryCompressor:
exact context. Clamped to 0-0.3.
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
token_counter: Optional custom token counter.
max_recent_rounds: Maximum exact recent rounds to preserve after
summarization. If None, only the token ratio limits recent rounds.
"""
self.provider = provider
self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3)
self.compression_threshold = compression_threshold
self.token_counter = token_counter or EstimateTokenCounter()
self.max_recent_rounds = (
None if max_recent_rounds is None else max(1, int(max_recent_rounds))
)
self.last_call_failed = False
self.instruction_text = instruction_text or (
@@ -234,6 +241,14 @@ class LLMSummaryCompressor:
old_rounds = old_rounds[:-1]
recent_rounds = [latest_old_round, *recent_rounds]
if (
self.max_recent_rounds is not None
and len(recent_rounds) > self.max_recent_rounds
):
excess_count = len(recent_rounds) - self.max_recent_rounds
old_rounds = old_rounds + recent_rounds[:excess_count]
recent_rounds = recent_rounds[excess_count:]
if not old_rounds:
if recent_rounds and messages and messages[-1].role == "user":
return messages

View File

@@ -37,6 +37,11 @@ class ContextManager:
keep_recent_ratio=config.llm_compress_keep_recent_ratio,
instruction_text=config.llm_compress_instruction,
token_counter=self.token_counter,
max_recent_rounds=(
max(1, config.enforce_max_turns - 1)
if config.enforce_max_turns != -1
else None
),
)
else:
self.compressor = TruncateByTurnsCompressor(
@@ -60,6 +65,7 @@ class ContextManager:
if self.config.enforce_max_turns != -1:
turn_count = count_conversation_rounds(result)
if turn_count > self.config.enforce_max_turns:
should_truncate_by_turns = True
if isinstance(self.compressor, LLMSummaryCompressor):
logger.debug(
"Turn limit (%s) exceeded (%s turns), "
@@ -73,14 +79,10 @@ class ContextManager:
"LLM summary compression failed; falling back "
"to turn-based truncation.",
)
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
else:
result = compressed
else:
should_truncate_by_turns = False
if should_truncate_by_turns:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,

View File

@@ -1,5 +1,5 @@
from ..message import Message
from .round_utils import count_conversation_rounds, split_into_rounds
from .round_utils import split_into_rounds
class ContextTruncator:
@@ -123,7 +123,12 @@ class ContextTruncator:
system_messages, non_system_messages = self._split_system_rest(messages)
rounds = split_into_rounds(non_system_messages)
if count_conversation_rounds(non_system_messages) <= keep_most_recent_turns:
round_count = sum(
1
for round_segments in rounds
if any(segment.role == "user" for segment in round_segments)
)
if round_count <= keep_most_recent_turns:
return messages
num_to_keep = keep_most_recent_turns - drop_turns + 1
@@ -161,7 +166,12 @@ class ContextTruncator:
system_messages, non_system_messages = self._split_system_rest(messages)
rounds = split_into_rounds(non_system_messages)
if count_conversation_rounds(non_system_messages) <= drop_turns:
round_count = sum(
1
for round_segments in rounds
if any(segment.role == "user" for segment in round_segments)
)
if round_count <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = [

View File

@@ -12,6 +12,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from astrbot.core.agent.context.config import ContextConfig
from astrbot.core.agent.context.manager import ContextManager
from astrbot.core.agent.context.round_utils import count_conversation_rounds
from astrbot.core.agent.message import AudioURLPart, ImageURLPart, Message, TextPart
from astrbot.core.provider.entities import LLMResponse
@@ -511,6 +512,24 @@ class TestContextManager:
assert len(result) < len(messages)
assert result[-1] is messages[-1]
@pytest.mark.asyncio
async def test_llm_enforce_max_turns_caps_recent_rounds(self):
"""LLM summary should cap exact recent rounds for max-turn enforcement."""
provider = MockProvider()
config = ContextConfig(
enforce_max_turns=2,
llm_compress_provider=provider, # type: ignore[arg-type]
llm_compress_keep_recent_ratio=0.3,
custom_token_counter=MessageCountTokenCounter(),
)
manager = ContextManager(config)
messages = self.create_messages(20)
result = await manager.process(messages)
assert provider.last_text_chat_kwargs is not None
assert count_conversation_rounds(result) <= config.enforce_max_turns
@pytest.mark.asyncio
async def test_enforce_max_turns_counts_tool_chain_as_one_round(self):
"""Tool messages in one round should not inflate turn count."""