Compare commits

...

3 Commits

Author SHA1 Message Date
Soulter
e740604798 chore: ruff format 2026-03-19 19:01:29 +08:00
Soulter
8a3b7d04cc fix(openai): improve usage extraction for chunk choices and handle fallback for usage details 2026-03-19 18:58:01 +08:00
Soulter
b9d006814e fix(wecom-ai): add 0.5s interval for streaming responses 2026-03-19 17:10:36 +08:00
3 changed files with 69 additions and 36 deletions

View File

@@ -87,9 +87,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 发送一个空消息, 以便于后续的处理
if (
isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent)
):
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -1,5 +1,6 @@
"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收"""
import asyncio
from collections.abc import Awaitable, Callable
from astrbot.api import logger
@@ -14,6 +15,8 @@ from .wecomai_webhook import WecomAIBotWebhookClient
class WecomAIBotMessageEvent(AstrMessageEvent):
"""企业微信智能机器人消息事件"""
STREAM_FLUSH_INTERVAL = 0.5
def __init__(
self,
message_str: str,
@@ -242,6 +245,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return
increment_plain = ""
last_stream_update_time = 0.0
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(
@@ -253,17 +257,20 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
now = asyncio.get_running_loop().time()
if now - last_stream_update_time >= self.STREAM_FLUSH_INTERVAL:
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
},
},
},
)
)
last_stream_update_time = now
await self.long_connection_sender(
req_id,
@@ -289,22 +296,31 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
await super().send_streaming(generator, use_fallback)
return
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,按间隔推送
increment_plain = ""
last_stream_update_time = 0.0
# Push `text` onto `back_queue` as one streaming "plain" update.
# This is a nested helper: `back_queue` and `stream_id` are not parameters,
# they are taken from the enclosing scope (not visible in this hunk).
async def enqueue_stream_plain(text: str) -> None:
# Empty text produces no queue entry at all.
if not text:
return
# The queued item carries the text verbatim in "data" and tags it with the
# current stream id under "session_id".
await back_queue.put(
{
"type": "plain",
"data": text,
"streaming": True,
"session_id": stream_id,
},
)
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(
chain, unsupported_only=True
)
# 累积增量内容,并改写 Plain 段
chain.squash_plain()
for comp in chain.chain:
if isinstance(comp, Plain):
comp.text = increment_plain + comp.text
increment_plain = comp.text
break
if chain.type == "break" and final_data:
if increment_plain:
await enqueue_stream_plain(increment_plain)
# 分割符
await back_queue.put(
{
@@ -315,15 +331,30 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
},
)
final_data = ""
increment_plain = ""
continue
final_data += await WecomAIBotMessageEvent._send(
chain,
stream_id=stream_id,
queue_mgr=self.queue_mgr,
streaming=True,
suppress_unsupported_log=self.webhook_client is not None,
)
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
final_data += chunk_text
now = asyncio.get_running_loop().time()
if now - last_stream_update_time >= self.STREAM_FLUSH_INTERVAL:
await enqueue_stream_plain(increment_plain)
last_stream_update_time = now
for comp in chain.chain:
if isinstance(comp, (At, Plain)):
continue
await WecomAIBotMessageEvent._send(
MessageChain([comp]),
stream_id=stream_id,
queue_mgr=self.queue_mgr,
streaming=True,
suppress_unsupported_log=self.webhook_client is not None,
)
await enqueue_stream_plain(increment_plain)
await back_queue.put(
{

View File

@@ -313,7 +313,8 @@ class ProviderOpenAIOfficial(Provider):
logger.warning("Saving chunk state error: " + str(e))
if not chunk.choices:
continue
delta = chunk.choices[0].delta
choice = chunk.choices[0]
delta = choice.delta
# logger.debug(f"chunk delta: {delta}")
# handle the content delta
reasoning = self._extract_reasoning_content(chunk)
@@ -331,6 +332,11 @@ class ProviderOpenAIOfficial(Provider):
_y = True
if chunk.usage:
llm_response.usage = self._extract_usage(chunk.usage)
elif choice_usage := getattr(choice, "usage", None):
# Workaround for some providers that only return usage in choices[].usage, e.g. MoonshotAI
# See https://github.com/AstrBotDevs/AstrBot/issues/6614
llm_response.usage = self._extract_usage(choice_usage)
state.current_completion_snapshot.usage = choice_usage
if _y:
yield llm_response
@@ -359,13 +365,11 @@ class ProviderOpenAIOfficial(Provider):
reasoning_text = str(reasoning_attr)
return reasoning_text
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
ptd = usage.prompt_tokens_details
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
prompt_tokens = 0 if usage.prompt_tokens is None else usage.prompt_tokens
completion_tokens = (
0 if usage.completion_tokens is None else usage.completion_tokens
)
def _extract_usage(self, usage: CompletionUsage | dict) -> TokenUsage:
ptd = getattr(usage, "prompt_tokens_details", None)
cached = getattr(ptd, "cached_tokens", 0) if ptd else 0
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
return TokenUsage(
input_other=prompt_tokens - cached,
input_cached=cached,