mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 01:10:21 +08:00
fix: cap llm compression recent rounds
This commit is contained in:
@@ -130,6 +130,7 @@ class LLMSummaryCompressor:
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
token_counter: TokenCounter | None = None,
|
||||
max_recent_rounds: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
@@ -139,11 +140,17 @@ class LLMSummaryCompressor:
|
||||
exact context. Clamped to 0-0.3.
|
||||
instruction_text: Custom instruction for summary generation.
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
token_counter: Optional custom token counter.
|
||||
max_recent_rounds: Maximum exact recent rounds to preserve after
|
||||
summarization. If None, only the token ratio limits recent rounds.
|
||||
"""
|
||||
self.provider = provider
|
||||
self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3)
|
||||
self.compression_threshold = compression_threshold
|
||||
self.token_counter = token_counter or EstimateTokenCounter()
|
||||
self.max_recent_rounds = (
|
||||
None if max_recent_rounds is None else max(1, int(max_recent_rounds))
|
||||
)
|
||||
self.last_call_failed = False
|
||||
|
||||
self.instruction_text = instruction_text or (
|
||||
@@ -234,6 +241,14 @@ class LLMSummaryCompressor:
|
||||
old_rounds = old_rounds[:-1]
|
||||
recent_rounds = [latest_old_round, *recent_rounds]
|
||||
|
||||
if (
|
||||
self.max_recent_rounds is not None
|
||||
and len(recent_rounds) > self.max_recent_rounds
|
||||
):
|
||||
excess_count = len(recent_rounds) - self.max_recent_rounds
|
||||
old_rounds = old_rounds + recent_rounds[:excess_count]
|
||||
recent_rounds = recent_rounds[excess_count:]
|
||||
|
||||
if not old_rounds:
|
||||
if recent_rounds and messages and messages[-1].role == "user":
|
||||
return messages
|
||||
|
||||
@@ -37,6 +37,11 @@ class ContextManager:
|
||||
keep_recent_ratio=config.llm_compress_keep_recent_ratio,
|
||||
instruction_text=config.llm_compress_instruction,
|
||||
token_counter=self.token_counter,
|
||||
max_recent_rounds=(
|
||||
max(1, config.enforce_max_turns - 1)
|
||||
if config.enforce_max_turns != -1
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.compressor = TruncateByTurnsCompressor(
|
||||
@@ -60,6 +65,7 @@ class ContextManager:
|
||||
if self.config.enforce_max_turns != -1:
|
||||
turn_count = count_conversation_rounds(result)
|
||||
if turn_count > self.config.enforce_max_turns:
|
||||
should_truncate_by_turns = True
|
||||
if isinstance(self.compressor, LLMSummaryCompressor):
|
||||
logger.debug(
|
||||
"Turn limit (%s) exceeded (%s turns), "
|
||||
@@ -73,14 +79,10 @@ class ContextManager:
|
||||
"LLM summary compression failed; falling back "
|
||||
"to turn-based truncation.",
|
||||
)
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
else:
|
||||
result = compressed
|
||||
else:
|
||||
should_truncate_by_turns = False
|
||||
if should_truncate_by_turns:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ..message import Message
|
||||
from .round_utils import count_conversation_rounds, split_into_rounds
|
||||
from .round_utils import split_into_rounds
|
||||
|
||||
|
||||
class ContextTruncator:
|
||||
@@ -123,7 +123,12 @@ class ContextTruncator:
|
||||
system_messages, non_system_messages = self._split_system_rest(messages)
|
||||
rounds = split_into_rounds(non_system_messages)
|
||||
|
||||
if count_conversation_rounds(non_system_messages) <= keep_most_recent_turns:
|
||||
round_count = sum(
|
||||
1
|
||||
for round_segments in rounds
|
||||
if any(segment.role == "user" for segment in round_segments)
|
||||
)
|
||||
if round_count <= keep_most_recent_turns:
|
||||
return messages
|
||||
|
||||
num_to_keep = keep_most_recent_turns - drop_turns + 1
|
||||
@@ -161,7 +166,12 @@ class ContextTruncator:
|
||||
system_messages, non_system_messages = self._split_system_rest(messages)
|
||||
rounds = split_into_rounds(non_system_messages)
|
||||
|
||||
if count_conversation_rounds(non_system_messages) <= drop_turns:
|
||||
round_count = sum(
|
||||
1
|
||||
for round_segments in rounds
|
||||
if any(segment.role == "user" for segment in round_segments)
|
||||
)
|
||||
if round_count <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = [
|
||||
|
||||
@@ -12,6 +12,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from astrbot.core.agent.context.config import ContextConfig
|
||||
from astrbot.core.agent.context.manager import ContextManager
|
||||
from astrbot.core.agent.context.round_utils import count_conversation_rounds
|
||||
from astrbot.core.agent.message import AudioURLPart, ImageURLPart, Message, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
@@ -511,6 +512,24 @@ class TestContextManager:
|
||||
assert len(result) < len(messages)
|
||||
assert result[-1] is messages[-1]
|
||||
|
||||
@pytest.mark.asyncio
async def test_llm_enforce_max_turns_caps_recent_rounds(self):
    """The turn limit must still hold after LLM summary compression.

    Builds a manager with ``enforce_max_turns=2`` and an LLM compression
    provider, processes a 20-item history from ``create_messages``, and
    checks that the processed context contains no more conversation rounds
    than the configured limit.
    """
    summary_provider = MockProvider()
    context_config = ContextConfig(
        enforce_max_turns=2,
        llm_compress_provider=summary_provider,  # type: ignore[arg-type]
        llm_compress_keep_recent_ratio=0.3,
        custom_token_counter=MessageCountTokenCounter(),
    )
    context_manager = ContextManager(context_config)
    history = self.create_messages(20)

    processed = await context_manager.process(history)

    # NOTE(review): presumably the mock records the kwargs of its last
    # text_chat call, so a non-None value means the summary path ran --
    # confirm against MockProvider.
    recorded_kwargs = summary_provider.last_text_chat_kwargs
    assert recorded_kwargs is not None

    remaining_rounds = count_conversation_rounds(processed)
    assert remaining_rounds <= context_config.enforce_max_turns
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_counts_tool_chain_as_one_round(self):
|
||||
"""Tool messages in one round should not inflate turn count."""
|
||||
|
||||
Reference in New Issue
Block a user