mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 10:10:15 +08:00
Compare commits
1 Commits
fix/prerel
...
codex/fix-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8044f4f571 |
@@ -179,6 +179,29 @@ class RespondStage(Stage):
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
event.set_extra("_streaming_finished", True)
|
||||
return
|
||||
sent_plain_texts = event.get_extra(
|
||||
"_send_message_to_user_current_session_plain_texts",
|
||||
[],
|
||||
)
|
||||
result_plain_text = result.get_plain_text().strip()
|
||||
if (
|
||||
result_plain_text
|
||||
and isinstance(sent_plain_texts, list)
|
||||
and result_plain_text in sent_plain_texts
|
||||
and all(
|
||||
comp.type
|
||||
in {
|
||||
ComponentType.Plain,
|
||||
ComponentType.Reply,
|
||||
ComponentType.At,
|
||||
}
|
||||
for comp in result.chain
|
||||
)
|
||||
):
|
||||
logger.info(
|
||||
"send_message_to_user already delivered the same text in this session, skip respond stage to avoid duplicate reply.",
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}",
|
||||
|
||||
@@ -317,10 +317,23 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
||||
else:
|
||||
return f"error: invalid session: {session}"
|
||||
|
||||
await context.context.context.send_message(
|
||||
target_session,
|
||||
MessageChain(chain=components),
|
||||
)
|
||||
message_chain = MessageChain(chain=components)
|
||||
await context.context.context.send_message(target_session, message_chain)
|
||||
if str(target_session) == current_session:
|
||||
context.context.event._has_send_oper = True
|
||||
sent_plain_text = message_chain.get_plain_text().strip()
|
||||
if sent_plain_text:
|
||||
sent_plain_texts = context.context.event.get_extra(
|
||||
"_send_message_to_user_current_session_plain_texts",
|
||||
[],
|
||||
)
|
||||
if not isinstance(sent_plain_texts, list):
|
||||
sent_plain_texts = []
|
||||
sent_plain_texts.append(sent_plain_text)
|
||||
context.context.event.set_extra(
|
||||
"_send_message_to_user_current_session_plain_texts",
|
||||
sent_plain_texts,
|
||||
)
|
||||
return f"Message sent to session {target_session}"
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core.pipeline.respond.stage import RespondStage
|
||||
from astrbot.core.tools.message_tools import SendMessageToUserTool
|
||||
|
||||
|
||||
@@ -21,13 +23,18 @@ def _make_context(
|
||||
"computer_use_runtime": runtime,
|
||||
}
|
||||
}
|
||||
extras = {}
|
||||
event = SimpleNamespace(
|
||||
unified_msg_origin=current_session,
|
||||
role=role,
|
||||
_has_send_oper=False,
|
||||
get_sender_id=lambda: "user-1",
|
||||
)
|
||||
event.set_extra = lambda key, value: extras.__setitem__(key, value)
|
||||
event.get_extra = lambda key, default=None: extras.get(key, default)
|
||||
return SimpleNamespace(
|
||||
context=SimpleNamespace(
|
||||
event=SimpleNamespace(
|
||||
unified_msg_origin=current_session,
|
||||
role=role,
|
||||
get_sender_id=lambda: "user-1",
|
||||
),
|
||||
event=event,
|
||||
context=SimpleNamespace(
|
||||
get_config=lambda umo: cfg,
|
||||
send_message=AsyncMock(),
|
||||
@@ -36,6 +43,65 @@ def _make_context(
|
||||
)
|
||||
|
||||
|
||||
class _DummyRespondEvent:
|
||||
def __init__(self, result_text: str, sent_plain_texts: list[str]) -> None:
|
||||
self._extras = {
|
||||
"_send_message_to_user_current_session_plain_texts": sent_plain_texts,
|
||||
}
|
||||
self._result = MessageEventResult().message(result_text)
|
||||
self.send = AsyncMock()
|
||||
self.plugins_name = []
|
||||
|
||||
def get_result(self):
|
||||
"""Return the current message result."""
|
||||
return self._result
|
||||
|
||||
def set_extra(self, key, value) -> None:
|
||||
"""Set pipeline extra data."""
|
||||
self._extras[key] = value
|
||||
|
||||
def get_extra(self, key, default=None):
|
||||
"""Get pipeline extra data."""
|
||||
return self._extras.get(key, default)
|
||||
|
||||
def get_sender_name(self) -> str:
|
||||
"""Return a sender name for respond-stage logging."""
|
||||
return "tester"
|
||||
|
||||
def get_sender_id(self) -> str:
|
||||
"""Return a sender ID for respond-stage logging."""
|
||||
return "user-1"
|
||||
|
||||
def get_platform_id(self) -> str:
|
||||
"""Return a platform ID for respond-stage logging."""
|
||||
return "test"
|
||||
|
||||
def get_platform_name(self) -> str:
|
||||
"""Return a platform name for segmented-reply checks."""
|
||||
return "test"
|
||||
|
||||
def _outline_chain(self, chain) -> str:
|
||||
"""Return a readable outline for respond-stage logging."""
|
||||
return " ".join(comp.text for comp in chain if hasattr(comp, "text"))
|
||||
|
||||
def is_stopped(self) -> bool:
|
||||
"""Return whether this dummy event has stopped."""
|
||||
return False
|
||||
|
||||
def clear_result(self) -> None:
|
||||
"""Clear the current message result."""
|
||||
self._result = None
|
||||
|
||||
|
||||
def _make_respond_stage() -> RespondStage:
|
||||
"""Build a minimally initialized RespondStage for unit tests."""
|
||||
stage = RespondStage()
|
||||
stage.config = {"provider_settings": {}}
|
||||
stage.platform_settings = {"path_mapping": []}
|
||||
stage.enable_seg = False
|
||||
return stage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_with_full_three_part_session():
|
||||
"""LLM passes a complete three-part session string."""
|
||||
@@ -81,6 +147,61 @@ async def test_send_message_defaults_to_current_session():
|
||||
call_args = ctx.context.context.send_message.call_args
|
||||
target_session = call_args[0][0]
|
||||
assert str(target_session) == "feishu:GroupMessage:oc_xxx"
|
||||
assert ctx.context.event._has_send_oper is True
|
||||
assert ctx.context.event.get_extra(
|
||||
"_send_message_to_user_current_session_plain_texts",
|
||||
) == ["hello"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_other_session_does_not_record_current_text():
|
||||
"""Messages sent to another session do not affect current-session dedupe."""
|
||||
tool = SendMessageToUserTool()
|
||||
ctx = _make_context(current_session="feishu:GroupMessage:oc_xxx")
|
||||
result = await tool.call(
|
||||
ctx,
|
||||
messages=[{"type": "plain", "text": "hello"}],
|
||||
session="feishu:GroupMessage:oc_other",
|
||||
)
|
||||
assert "Message sent to session" in result
|
||||
assert ctx.context.event._has_send_oper is False
|
||||
assert (
|
||||
ctx.context.event.get_extra(
|
||||
"_send_message_to_user_current_session_plain_texts",
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respond_stage_skips_same_text_after_send_message_to_user():
|
||||
"""RespondStage skips only when the tool already sent the same text."""
|
||||
stage = _make_respond_stage()
|
||||
event = _DummyRespondEvent(
|
||||
result_text="duplicate reply",
|
||||
sent_plain_texts=["duplicate reply"],
|
||||
)
|
||||
|
||||
result = await stage.process(event)
|
||||
|
||||
assert result is None
|
||||
event.send.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respond_stage_sends_different_text_after_send_message_to_user():
|
||||
"""RespondStage still sends a distinct completion after the tool call."""
|
||||
stage = _make_respond_stage()
|
||||
event = _DummyRespondEvent(
|
||||
result_text="I have sent the message with the tool.",
|
||||
sent_plain_texts=["duplicate reply"],
|
||||
)
|
||||
|
||||
result = await stage.process(event)
|
||||
|
||||
assert result is None
|
||||
event.send.assert_awaited_once()
|
||||
assert event.get_result() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user