Compare commits

...

2 Commits

2 changed files with 57 additions and 28 deletions

View File

@@ -87,9 +87,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 发送一个空消息, 以便于后续的处理
if (
isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent)
):
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -1,5 +1,6 @@
"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收"""
import asyncio
from collections.abc import Awaitable, Callable
from astrbot.api import logger
@@ -14,6 +15,8 @@ from .wecomai_webhook import WecomAIBotWebhookClient
class WecomAIBotMessageEvent(AstrMessageEvent):
"""企业微信智能机器人消息事件"""
STREAM_FLUSH_INTERVAL = 0.5
def __init__(
self,
message_str: str,
@@ -242,6 +245,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return
increment_plain = ""
last_stream_update_time = 0.0
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(
@@ -253,17 +257,20 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
now = asyncio.get_running_loop().time()
if now - last_stream_update_time >= self.STREAM_FLUSH_INTERVAL:
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
},
},
},
)
)
last_stream_update_time = now
await self.long_connection_sender(
req_id,
@@ -289,22 +296,31 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
await super().send_streaming(generator, use_fallback)
return
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,按间隔推
increment_plain = ""
last_stream_update_time = 0.0
async def enqueue_stream_plain(text: str) -> None:
if not text:
return
await back_queue.put(
{
"type": "plain",
"data": text,
"streaming": True,
"session_id": stream_id,
},
)
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(
chain, unsupported_only=True
)
# 累积增量内容,并改写 Plain 段
chain.squash_plain()
for comp in chain.chain:
if isinstance(comp, Plain):
comp.text = increment_plain + comp.text
increment_plain = comp.text
break
if chain.type == "break" and final_data:
if increment_plain:
await enqueue_stream_plain(increment_plain)
# 分割符
await back_queue.put(
{
@@ -315,15 +331,30 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
},
)
final_data = ""
increment_plain = ""
continue
final_data += await WecomAIBotMessageEvent._send(
chain,
stream_id=stream_id,
queue_mgr=self.queue_mgr,
streaming=True,
suppress_unsupported_log=self.webhook_client is not None,
)
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
final_data += chunk_text
now = asyncio.get_running_loop().time()
if now - last_stream_update_time >= self.STREAM_FLUSH_INTERVAL:
await enqueue_stream_plain(increment_plain)
last_stream_update_time = now
for comp in chain.chain:
if isinstance(comp, (At, Plain)):
continue
await WecomAIBotMessageEvent._send(
MessageChain([comp]),
stream_id=stream_id,
queue_mgr=self.queue_mgr,
streaming=True,
suppress_unsupported_log=self.webhook_client is not None,
)
await enqueue_stream_plain(increment_plain)
await back_queue.put(
{