Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter
8044f4f571 fix: avoid duplicate send_message_to_user replies 2026-06-27 16:52:31 +08:00
3 changed files with 166 additions and 9 deletions

View File

@@ -179,6 +179,29 @@ class RespondStage(Stage):
if result.result_content_type == ResultContentType.STREAMING_FINISH:
event.set_extra("_streaming_finished", True)
return
sent_plain_texts = event.get_extra(
"_send_message_to_user_current_session_plain_texts",
[],
)
result_plain_text = result.get_plain_text().strip()
if (
result_plain_text
and isinstance(sent_plain_texts, list)
and result_plain_text in sent_plain_texts
and all(
comp.type
in {
ComponentType.Plain,
ComponentType.Reply,
ComponentType.At,
}
for comp in result.chain
)
):
logger.info(
"send_message_to_user already delivered the same text in this session, skip respond stage to avoid duplicate reply.",
)
return
logger.info(
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}",

View File

@@ -317,10 +317,23 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
else:
return f"error: invalid session: {session}"
await context.context.context.send_message(
target_session,
MessageChain(chain=components),
)
message_chain = MessageChain(chain=components)
await context.context.context.send_message(target_session, message_chain)
if str(target_session) == current_session:
context.context.event._has_send_oper = True
sent_plain_text = message_chain.get_plain_text().strip()
if sent_plain_text:
sent_plain_texts = context.context.event.get_extra(
"_send_message_to_user_current_session_plain_texts",
[],
)
if not isinstance(sent_plain_texts, list):
sent_plain_texts = []
sent_plain_texts.append(sent_plain_text)
context.context.event.set_extra(
"_send_message_to_user_current_session_plain_texts",
sent_plain_texts,
)
return f"Message sent to session {target_session}"

View File

@@ -5,6 +5,8 @@ from unittest.mock import AsyncMock
import pytest
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.pipeline.respond.stage import RespondStage
from astrbot.core.tools.message_tools import SendMessageToUserTool
@@ -21,13 +23,18 @@ def _make_context(
"computer_use_runtime": runtime,
}
}
extras = {}
event = SimpleNamespace(
unified_msg_origin=current_session,
role=role,
_has_send_oper=False,
get_sender_id=lambda: "user-1",
)
event.set_extra = lambda key, value: extras.__setitem__(key, value)
event.get_extra = lambda key, default=None: extras.get(key, default)
return SimpleNamespace(
context=SimpleNamespace(
event=SimpleNamespace(
unified_msg_origin=current_session,
role=role,
get_sender_id=lambda: "user-1",
),
event=event,
context=SimpleNamespace(
get_config=lambda umo: cfg,
send_message=AsyncMock(),
@@ -36,6 +43,65 @@ def _make_context(
)
class _DummyRespondEvent:
def __init__(self, result_text: str, sent_plain_texts: list[str]) -> None:
self._extras = {
"_send_message_to_user_current_session_plain_texts": sent_plain_texts,
}
self._result = MessageEventResult().message(result_text)
self.send = AsyncMock()
self.plugins_name = []
def get_result(self):
"""Return the current message result."""
return self._result
def set_extra(self, key, value) -> None:
"""Set pipeline extra data."""
self._extras[key] = value
def get_extra(self, key, default=None):
"""Get pipeline extra data."""
return self._extras.get(key, default)
def get_sender_name(self) -> str:
"""Return a sender name for respond-stage logging."""
return "tester"
def get_sender_id(self) -> str:
"""Return a sender ID for respond-stage logging."""
return "user-1"
def get_platform_id(self) -> str:
"""Return a platform ID for respond-stage logging."""
return "test"
def get_platform_name(self) -> str:
"""Return a platform name for segmented-reply checks."""
return "test"
def _outline_chain(self, chain) -> str:
"""Return a readable outline for respond-stage logging."""
return " ".join(comp.text for comp in chain if hasattr(comp, "text"))
def is_stopped(self) -> bool:
"""Return whether this dummy event has stopped."""
return False
def clear_result(self) -> None:
"""Clear the current message result."""
self._result = None
def _make_respond_stage() -> RespondStage:
"""Build a minimally initialized RespondStage for unit tests."""
stage = RespondStage()
stage.config = {"provider_settings": {}}
stage.platform_settings = {"path_mapping": []}
stage.enable_seg = False
return stage
@pytest.mark.asyncio
async def test_send_message_with_full_three_part_session():
"""LLM passes a complete three-part session string."""
@@ -81,6 +147,61 @@ async def test_send_message_defaults_to_current_session():
call_args = ctx.context.context.send_message.call_args
target_session = call_args[0][0]
assert str(target_session) == "feishu:GroupMessage:oc_xxx"
assert ctx.context.event._has_send_oper is True
assert ctx.context.event.get_extra(
"_send_message_to_user_current_session_plain_texts",
) == ["hello"]
@pytest.mark.asyncio
async def test_send_message_other_session_does_not_record_current_text():
"""Messages sent to another session do not affect current-session dedupe."""
tool = SendMessageToUserTool()
ctx = _make_context(current_session="feishu:GroupMessage:oc_xxx")
result = await tool.call(
ctx,
messages=[{"type": "plain", "text": "hello"}],
session="feishu:GroupMessage:oc_other",
)
assert "Message sent to session" in result
assert ctx.context.event._has_send_oper is False
assert (
ctx.context.event.get_extra(
"_send_message_to_user_current_session_plain_texts",
)
is None
)
@pytest.mark.asyncio
async def test_respond_stage_skips_same_text_after_send_message_to_user():
"""RespondStage skips only when the tool already sent the same text."""
stage = _make_respond_stage()
event = _DummyRespondEvent(
result_text="duplicate reply",
sent_plain_texts=["duplicate reply"],
)
result = await stage.process(event)
assert result is None
event.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_respond_stage_sends_different_text_after_send_message_to_user():
"""RespondStage still sends a distinct completion after the tool call."""
stage = _make_respond_stage()
event = _DummyRespondEvent(
result_text="I have sent the message with the tool.",
sent_plain_texts=["duplicate reply"],
)
result = await stage.process(event)
assert result is None
event.send.assert_awaited_once()
assert event.get_result() is None
@pytest.mark.asyncio