mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-02 18:50:15 +08:00
Compare commits
3 Commits
codex/fix-
...
fix/moonsh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e740604798 | ||
|
|
8a3b7d04cc | ||
|
|
b9d006814e |
@@ -87,9 +87,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 发送一个空消息, 以便于后续的处理
|
||||
if (
|
||||
isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent)
|
||||
):
|
||||
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -14,6 +15,8 @@ from .wecomai_webhook import WecomAIBotWebhookClient
|
||||
class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
"""企业微信智能机器人消息事件"""
|
||||
|
||||
STREAM_FLUSH_INTERVAL = 0.5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
@@ -242,6 +245,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
return
|
||||
|
||||
increment_plain = ""
|
||||
last_stream_update_time = 0.0
|
||||
async for chain in generator:
|
||||
if self.webhook_client:
|
||||
await self.webhook_client.send_message_chain(
|
||||
@@ -253,17 +257,20 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
chunk_text = self._extract_plain_text_from_chain(chain)
|
||||
if chunk_text:
|
||||
increment_plain += chunk_text
|
||||
await self.long_connection_sender(
|
||||
req_id,
|
||||
{
|
||||
"msgtype": "stream",
|
||||
"stream": {
|
||||
"id": stream_id,
|
||||
"finish": False,
|
||||
"content": increment_plain,
|
||||
now = asyncio.get_running_loop().time()
|
||||
if now - last_stream_update_time >= self.STREAM_FLUSH_INTERVAL:
|
||||
await self.long_connection_sender(
|
||||
req_id,
|
||||
{
|
||||
"msgtype": "stream",
|
||||
"stream": {
|
||||
"id": stream_id,
|
||||
"finish": False,
|
||||
"content": increment_plain,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
last_stream_update_time = now
|
||||
|
||||
await self.long_connection_sender(
|
||||
req_id,
|
||||
@@ -289,22 +296,31 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
return
|
||||
|
||||
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
|
||||
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,按间隔推送
|
||||
increment_plain = ""
|
||||
last_stream_update_time = 0.0
|
||||
|
||||
async def enqueue_stream_plain(text: str) -> None:
|
||||
if not text:
|
||||
return
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"data": text,
|
||||
"streaming": True,
|
||||
"session_id": stream_id,
|
||||
},
|
||||
)
|
||||
|
||||
async for chain in generator:
|
||||
if self.webhook_client:
|
||||
await self.webhook_client.send_message_chain(
|
||||
chain, unsupported_only=True
|
||||
)
|
||||
# 累积增量内容,并改写 Plain 段
|
||||
chain.squash_plain()
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
comp.text = increment_plain + comp.text
|
||||
increment_plain = comp.text
|
||||
break
|
||||
|
||||
if chain.type == "break" and final_data:
|
||||
if increment_plain:
|
||||
await enqueue_stream_plain(increment_plain)
|
||||
# 分割符
|
||||
await back_queue.put(
|
||||
{
|
||||
@@ -315,15 +331,30 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
},
|
||||
)
|
||||
final_data = ""
|
||||
increment_plain = ""
|
||||
continue
|
||||
|
||||
final_data += await WecomAIBotMessageEvent._send(
|
||||
chain,
|
||||
stream_id=stream_id,
|
||||
queue_mgr=self.queue_mgr,
|
||||
streaming=True,
|
||||
suppress_unsupported_log=self.webhook_client is not None,
|
||||
)
|
||||
chunk_text = self._extract_plain_text_from_chain(chain)
|
||||
if chunk_text:
|
||||
increment_plain += chunk_text
|
||||
final_data += chunk_text
|
||||
now = asyncio.get_running_loop().time()
|
||||
if now - last_stream_update_time >= self.STREAM_FLUSH_INTERVAL:
|
||||
await enqueue_stream_plain(increment_plain)
|
||||
last_stream_update_time = now
|
||||
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, (At, Plain)):
|
||||
continue
|
||||
await WecomAIBotMessageEvent._send(
|
||||
MessageChain([comp]),
|
||||
stream_id=stream_id,
|
||||
queue_mgr=self.queue_mgr,
|
||||
streaming=True,
|
||||
suppress_unsupported_log=self.webhook_client is not None,
|
||||
)
|
||||
|
||||
await enqueue_stream_plain(increment_plain)
|
||||
|
||||
await back_queue.put(
|
||||
{
|
||||
|
||||
@@ -313,7 +313,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
logger.warning("Saving chunk state error: " + str(e))
|
||||
if not chunk.choices:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
# logger.debug(f"chunk delta: {delta}")
|
||||
# handle the content delta
|
||||
reasoning = self._extract_reasoning_content(chunk)
|
||||
@@ -331,6 +332,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
_y = True
|
||||
if chunk.usage:
|
||||
llm_response.usage = self._extract_usage(chunk.usage)
|
||||
elif choice_usage := getattr(choice, "usage", None):
|
||||
# Workaround for some providers that only return usage in choices[].usage, e.g. MoonshotAI
|
||||
# See https://github.com/AstrBotDevs/AstrBot/issues/6614
|
||||
llm_response.usage = self._extract_usage(choice_usage)
|
||||
state.current_completion_snapshot.usage = choice_usage
|
||||
if _y:
|
||||
yield llm_response
|
||||
|
||||
@@ -359,13 +365,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
reasoning_text = str(reasoning_attr)
|
||||
return reasoning_text
|
||||
|
||||
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
|
||||
ptd = usage.prompt_tokens_details
|
||||
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
|
||||
prompt_tokens = 0 if usage.prompt_tokens is None else usage.prompt_tokens
|
||||
completion_tokens = (
|
||||
0 if usage.completion_tokens is None else usage.completion_tokens
|
||||
)
|
||||
def _extract_usage(self, usage: CompletionUsage | dict) -> TokenUsage:
|
||||
ptd = getattr(usage, "prompt_tokens_details", None)
|
||||
cached = getattr(ptd, "cached_tokens", 0) if ptd else 0
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
||||
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
||||
return TokenUsage(
|
||||
input_other=prompt_tokens - cached,
|
||||
input_cached=cached,
|
||||
|
||||
Reference in New Issue
Block a user