mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-04 19:50:16 +08:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd53e0e751 | ||
|
|
383df74e34 | ||
|
|
26627887d1 | ||
|
|
a5e86c8b94 | ||
|
|
af6f9cfc5e | ||
|
|
8986d05309 | ||
|
|
045be7943d | ||
|
|
cd4e999526 | ||
|
|
6db9aef3ea |
@@ -335,6 +335,18 @@ class TelegramPlatformAdapter(Platform):
|
||||
logger.warning("Received an update without a message.")
|
||||
return None
|
||||
|
||||
def _apply_caption() -> None:
|
||||
if update.message.caption:
|
||||
message.message_str = update.message.caption
|
||||
message.message.append(Comp.Plain(message.message_str))
|
||||
if update.message.caption and update.message.caption_entities:
|
||||
for entity in update.message.caption_entities:
|
||||
if entity.type == "mention":
|
||||
name = update.message.caption[
|
||||
entity.offset + 1 : entity.offset + entity.length
|
||||
]
|
||||
message.message.append(Comp.At(qq=name, name=name))
|
||||
|
||||
message = AstrBotMessage()
|
||||
message.session_id = str(update.message.chat.id)
|
||||
|
||||
@@ -454,16 +466,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
photo = update.message.photo[-1] # get the largest photo
|
||||
file = await photo.get_file()
|
||||
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
|
||||
if update.message.caption:
|
||||
message.message_str = update.message.caption
|
||||
message.message.append(Comp.Plain(message.message_str))
|
||||
if update.message.caption_entities:
|
||||
for entity in update.message.caption_entities:
|
||||
if entity.type == "mention":
|
||||
name = message.message_str[
|
||||
entity.offset + 1 : entity.offset + entity.length
|
||||
]
|
||||
message.message.append(Comp.At(qq=name, name=name))
|
||||
_apply_caption()
|
||||
|
||||
elif update.message.sticker:
|
||||
# 将sticker当作图片处理
|
||||
@@ -486,6 +489,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
message.message.append(
|
||||
Comp.File(file=file_path, name=file_name, url=file_path)
|
||||
)
|
||||
_apply_caption()
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
@@ -497,6 +501,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.Video(file=file_path, path=file.file_path))
|
||||
_apply_caption()
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import os
|
||||
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from wechatpy.exceptions import WeChatClientException
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -95,7 +96,19 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
|
||||
try:
|
||||
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
|
||||
except WeChatClientException as e:
|
||||
if getattr(e, "errcode", None) == 40096:
|
||||
# 40096: invalid external userid, fallback to regular message API
|
||||
logger.warning(
|
||||
f"kf API error 40096 for user {user_id}, falling back to regular message API"
|
||||
)
|
||||
self.client.message.send_text(
|
||||
self.get_self_id(), user_id, chunk
|
||||
)
|
||||
else:
|
||||
raise
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
@@ -902,6 +902,13 @@ class WeixinOCAdapter(Platform):
|
||||
"weixin_oc(%s): inbound long-poll timeout",
|
||||
self.meta().id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"weixin_oc(%s): poll inbound updates failed, will retry after 5 seconds: %s",
|
||||
self.meta().id,
|
||||
e,
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -515,7 +515,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {**kwargs, "messages": new_messages, "model": model}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
@@ -571,7 +571,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {**kwargs, "messages": new_messages, "model": model}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
|
||||
@@ -757,7 +757,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {**kwargs, "messages": context_query, "model": model}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
@@ -812,7 +812,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {**kwargs, "messages": context_query, "model": model}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
|
||||
@@ -27,8 +27,8 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
api_base = (
|
||||
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
|
||||
.strip()
|
||||
.rstrip("/")
|
||||
.rstrip("/embeddings")
|
||||
.removesuffix("/")
|
||||
.removesuffix("/embeddings")
|
||||
)
|
||||
if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"):
|
||||
# /v4 see #5699
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
@@ -14,6 +18,8 @@ from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from PIL import Image as PILImage
|
||||
from PIL import UnidentifiedImageError
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
@@ -133,6 +139,186 @@ class ProviderOpenAIOfficial(Provider):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_invalid_attachment_error(self, error: Exception) -> bool:
|
||||
body = getattr(error, "body", None)
|
||||
code: str | None = None
|
||||
message: str | None = None
|
||||
if isinstance(body, dict):
|
||||
err_obj = body.get("error")
|
||||
if isinstance(err_obj, dict):
|
||||
raw_code = err_obj.get("code")
|
||||
raw_message = err_obj.get("message")
|
||||
code = raw_code.lower() if isinstance(raw_code, str) else None
|
||||
message = raw_message.lower() if isinstance(raw_message, str) else None
|
||||
|
||||
if code == "invalid_attachment":
|
||||
return True
|
||||
|
||||
text_sources: list[str] = []
|
||||
if message:
|
||||
text_sources.append(message)
|
||||
if code:
|
||||
text_sources.append(code)
|
||||
text_sources.extend(map(str, self._extract_error_text_candidates(error)))
|
||||
|
||||
error_text = " ".join(text.lower() for text in text_sources if text)
|
||||
if "invalid_attachment" in error_text:
|
||||
return True
|
||||
if "download attachment" in error_text and "404" in error_text:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _encode_image_file_to_data_url(
|
||||
cls,
|
||||
image_path: str,
|
||||
*,
|
||||
mode: Literal["safe", "strict"],
|
||||
) -> str | None:
|
||||
try:
|
||||
image_bytes = Path(image_path).read_bytes()
|
||||
except OSError:
|
||||
if mode == "strict":
|
||||
raise
|
||||
return None
|
||||
|
||||
try:
|
||||
with PILImage.open(BytesIO(image_bytes)) as image:
|
||||
image.verify()
|
||||
image_format = str(image.format or "").upper()
|
||||
except (OSError, UnidentifiedImageError):
|
||||
if mode == "strict":
|
||||
raise ValueError(f"Invalid image file: {image_path}")
|
||||
return None
|
||||
|
||||
mime_type = {
|
||||
"JPEG": "image/jpeg",
|
||||
"PNG": "image/png",
|
||||
"GIF": "image/gif",
|
||||
"WEBP": "image/webp",
|
||||
"BMP": "image/bmp",
|
||||
}.get(image_format, "image/jpeg")
|
||||
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{image_bs64}"
|
||||
|
||||
@staticmethod
|
||||
def _file_uri_to_path(file_uri: str) -> str:
|
||||
"""Normalize file URIs to paths.
|
||||
|
||||
`file://localhost/...` and drive-letter forms are treated as local paths.
|
||||
Other non-empty hosts are preserved as UNC-style paths.
|
||||
"""
|
||||
parsed = urlparse(file_uri)
|
||||
if parsed.scheme != "file":
|
||||
return file_uri
|
||||
|
||||
netloc = unquote(parsed.netloc or "")
|
||||
path = unquote(parsed.path or "")
|
||||
if re.fullmatch(r"[A-Za-z]:", netloc):
|
||||
return str(Path(f"{netloc}{path}"))
|
||||
if re.match(r"^/[A-Za-z]:/", path):
|
||||
path = path[1:]
|
||||
if netloc and netloc != "localhost":
|
||||
path = f"//{netloc}{path}"
|
||||
return str(Path(path))
|
||||
|
||||
async def _image_ref_to_data_url(
|
||||
self,
|
||||
image_ref: str,
|
||||
*,
|
||||
mode: Literal["safe", "strict"] = "safe",
|
||||
) -> str | None:
|
||||
if image_ref.startswith("base64://"):
|
||||
return image_ref.replace("base64://", "data:image/jpeg;base64,")
|
||||
|
||||
if image_ref.startswith("http"):
|
||||
image_path = await download_image_by_url(image_ref)
|
||||
elif image_ref.startswith("file://"):
|
||||
image_path = self._file_uri_to_path(image_ref)
|
||||
else:
|
||||
image_path = image_ref
|
||||
|
||||
return self._encode_image_file_to_data_url(
|
||||
image_path,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
async def _resolve_image_part(
|
||||
self,
|
||||
image_url: str,
|
||||
*,
|
||||
image_detail: str | None = None,
|
||||
) -> dict | None:
|
||||
if image_url.startswith("data:"):
|
||||
image_payload = {"url": image_url}
|
||||
else:
|
||||
image_data = await self._image_ref_to_data_url(image_url, mode="safe")
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
return None
|
||||
image_payload = {"url": image_data}
|
||||
|
||||
if image_detail:
|
||||
image_payload["detail"] = image_detail
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": image_payload,
|
||||
}
|
||||
|
||||
def _extract_image_part_info(self, part: dict) -> tuple[str | None, str | None]:
|
||||
if not isinstance(part, dict) or part.get("type") != "image_url":
|
||||
return None, None
|
||||
|
||||
image_url_data = part.get("image_url")
|
||||
if not isinstance(image_url_data, dict):
|
||||
logger.warning("图片内容块格式无效,将保留原始内容。")
|
||||
return None, None
|
||||
|
||||
url = image_url_data.get("url")
|
||||
if not isinstance(url, str) or not url:
|
||||
logger.warning("图片内容块缺少有效 URL,将保留原始内容。")
|
||||
return None, None
|
||||
|
||||
image_detail = image_url_data.get("detail")
|
||||
if not isinstance(image_detail, str):
|
||||
image_detail = None
|
||||
return url, image_detail
|
||||
|
||||
async def _transform_content_part(self, part: dict) -> dict:
|
||||
url, image_detail = self._extract_image_part_info(part)
|
||||
if not url:
|
||||
return part
|
||||
|
||||
try:
|
||||
resolved_part = await self._resolve_image_part(
|
||||
url, image_detail=image_detail
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"图片 %s 预处理失败,将保留原始内容。错误: %s",
|
||||
url,
|
||||
exc,
|
||||
)
|
||||
return part
|
||||
|
||||
return resolved_part or part
|
||||
|
||||
async def _materialize_message_image_parts(self, message: dict) -> dict:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
return {**message}
|
||||
|
||||
new_content = [await self._transform_content_part(part) for part in content]
|
||||
return {**message, "content": new_content}
|
||||
|
||||
async def _materialize_context_image_parts(
|
||||
self, context_query: list[dict]
|
||||
) -> list[dict]:
|
||||
return [
|
||||
await self._materialize_message_image_parts(message)
|
||||
for message in context_query
|
||||
]
|
||||
|
||||
async def _fallback_to_text_only_and_retry(
|
||||
self,
|
||||
payloads: dict,
|
||||
@@ -604,7 +790,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts))
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
if system_prompt:
|
||||
@@ -622,8 +808,12 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
if self._context_contains_image(context_query):
|
||||
context_query = await self._materialize_context_image_parts(context_query)
|
||||
|
||||
model = model or self.get_model()
|
||||
payloads = {**kwargs, "messages": context_query, "model": model}
|
||||
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
self._finally_convert_payload(payloads)
|
||||
|
||||
@@ -721,6 +911,18 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"image_content_moderated",
|
||||
image_fallback_used=True,
|
||||
)
|
||||
if self._is_invalid_attachment_error(e):
|
||||
if image_fallback_used or not self._context_contains_image(context_query):
|
||||
raise e
|
||||
return await self._fallback_to_text_only_and_retry(
|
||||
payloads,
|
||||
context_query,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
func_tool,
|
||||
"invalid_attachment",
|
||||
image_fallback_used=True,
|
||||
)
|
||||
|
||||
if (
|
||||
"Function calling is not enabled" in str(e)
|
||||
@@ -922,23 +1124,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
|
||||
async def resolve_image_part(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
return None
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
}
|
||||
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
@@ -958,7 +1143,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if isinstance(part, TextPart):
|
||||
content_blocks.append({"type": "text", "text": part.text})
|
||||
elif isinstance(part, ImageURLPart):
|
||||
image_part = await resolve_image_part(part.image_url.url)
|
||||
image_part = await self._resolve_image_part(
|
||||
part.image_url.url,
|
||||
)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
else:
|
||||
@@ -967,7 +1154,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_part = await resolve_image_part(image_url)
|
||||
image_part = await self._resolve_image_part(image_url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
|
||||
@@ -986,11 +1173,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
image_data = await self._image_ref_to_data_url(image_url, mode="strict")
|
||||
if image_data is None:
|
||||
raise RuntimeError(f"Failed to encode image data: {image_url}")
|
||||
return image_data
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
|
||||
@@ -259,7 +259,7 @@
|
||||
<v-card-text class="py-4">
|
||||
<p>{{ tm('dialog.securityWarning.aiocqhttpTokenMissing') }}</p>
|
||||
<span><a
|
||||
href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7"
|
||||
href="https://docs.astrbot.app/platform/aiocqhttp.html"
|
||||
target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
|
||||
</v-card-text>
|
||||
<v-card-actions class="px-4 pb-4">
|
||||
|
||||
@@ -19,7 +19,6 @@ export function useSessions(chatboxMode: boolean = false) {
|
||||
const selectedSessions = ref<string[]>([]);
|
||||
const currSessionId = ref('');
|
||||
const pendingSessionId = ref<string | null>(null);
|
||||
|
||||
// 编辑标题相关
|
||||
const editTitleDialog = ref(false);
|
||||
const editingTitle = ref('');
|
||||
@@ -30,29 +29,16 @@ export function useSessions(chatboxMode: boolean = false) {
|
||||
return sessions.value.find(s => s.session_id === currSessionId.value);
|
||||
});
|
||||
|
||||
|
||||
|
||||
async function getSessions() {
|
||||
try {
|
||||
const response = await axios.get('/api/chat/sessions');
|
||||
sessions.value = response.data.data;
|
||||
|
||||
// 处理待加载的会话
|
||||
if (pendingSessionId.value) {
|
||||
const session = sessions.value.find(s => s.session_id === pendingSessionId.value);
|
||||
if (session) {
|
||||
selectedSessions.value = [pendingSessionId.value];
|
||||
pendingSessionId.value = null;
|
||||
}
|
||||
} else if (currSessionId.value) {
|
||||
// 如果当前有选中的会话,确保它在列表中并被选中
|
||||
const session = sessions.value.find(s => s.session_id === currSessionId.value);
|
||||
if (session) {
|
||||
selectedSessions.value = [currSessionId.value];
|
||||
}
|
||||
} else if (sessions.value.length > 0) {
|
||||
// 默认选择第一个会话
|
||||
const firstSession = sessions.value[0];
|
||||
selectedSessions.value = [firstSession.session_id];
|
||||
}
|
||||
|
||||
|
||||
|
||||
} catch (err: any) {
|
||||
if (err.response?.status === 401) {
|
||||
router.push('/auth/login?redirect=/chatbox');
|
||||
|
||||
@@ -17,18 +17,10 @@ const customizer = useCustomizerStore();
|
||||
const { locale } = useI18n();
|
||||
const route = useRoute();
|
||||
const routerLoadingStore = useRouterLoadingStore();
|
||||
const isCurrentChatRoute = computed(() => route.path === '/chat' || route.path.startsWith('/chat/'));
|
||||
|
||||
const isChatPage = computed(() => {
|
||||
return route.path.startsWith('/chat');
|
||||
});
|
||||
|
||||
const showSidebar = computed(() => {
|
||||
return customizer.viewMode === 'bot';
|
||||
});
|
||||
|
||||
const showChatPage = computed(() => {
|
||||
return customizer.viewMode === 'chat';
|
||||
});
|
||||
const showSidebar = computed(() => !isCurrentChatRoute.value)
|
||||
|
||||
const migrationDialog = ref<InstanceType<typeof MigrationDialog> | null>(null);
|
||||
const showFirstNoticeDialog = ref(false);
|
||||
@@ -111,20 +103,20 @@ onMounted(() => {
|
||||
<VerticalHeaderVue />
|
||||
<VerticalSidebarVue v-if="showSidebar" />
|
||||
<v-main :style="{
|
||||
height: showChatPage ? 'calc(100vh - 55px)' : undefined,
|
||||
overflow: showChatPage ? 'hidden' : undefined
|
||||
height: isCurrentChatRoute ? 'calc(100vh - 55px)' : undefined,
|
||||
overflow: isCurrentChatRoute ? 'hidden' : undefined
|
||||
}">
|
||||
<v-container
|
||||
fluid
|
||||
class="page-wrapper"
|
||||
:class="{ 'chat-mode-container': showChatPage }"
|
||||
:class="{ 'chat-mode-container': isCurrentChatRoute }"
|
||||
:style="{
|
||||
height: showChatPage ? '100%' : 'calc(100% - 8px)',
|
||||
padding: (isChatPage || showChatPage) ? '0' : undefined,
|
||||
minHeight: showChatPage ? 'unset' : undefined
|
||||
height: isCurrentChatRoute ? '100%' : 'calc(100% - 8px)',
|
||||
padding: isCurrentChatRoute ? '0' : undefined,
|
||||
minHeight: isCurrentChatRoute ? 'unset' : undefined
|
||||
}">
|
||||
<div :style="{ height: '100%', width: '100%', overflow: showChatPage ? 'hidden' : undefined }">
|
||||
<div v-if="showChatPage" style="height: 100%; width: 100%; overflow: hidden;">
|
||||
<div :style="{ height: '100%', width: '100%', overflow: isCurrentChatRoute ? 'hidden' : undefined }">
|
||||
<div v-if="isCurrentChatRoute" style="height: 100%; width: 100%; overflow: hidden;">
|
||||
<Chat />
|
||||
</div>
|
||||
<RouterView v-else />
|
||||
|
||||
@@ -28,6 +28,7 @@ const theme = useTheme();
|
||||
const { t } = useI18n();
|
||||
const route = useRoute();
|
||||
const LAST_BOT_ROUTE_KEY = 'astrbot:last_bot_route';
|
||||
const LAST_CHAT_ROUTE_KEY = 'astrbot:last_chat_route';
|
||||
let dialog = ref(false);
|
||||
let accountWarning = ref(false)
|
||||
let updateStatusDialog = ref(false);
|
||||
@@ -58,7 +59,9 @@ const desktopUpdateHasNewVersion = ref(false);
|
||||
const desktopUpdateCurrentVersion = ref('-');
|
||||
const desktopUpdateLatestVersion = ref('-');
|
||||
const desktopUpdateStatus = ref('');
|
||||
|
||||
const isChatPath = computed(() =>
|
||||
route.path === '/chat' || route.path.startsWith('/chat/')
|
||||
);
|
||||
const getAppUpdaterBridge = (): AstrBotAppUpdaterBridge | null => {
|
||||
if (typeof window === 'undefined') {
|
||||
return null;
|
||||
@@ -380,7 +383,7 @@ function openReleaseNotesDialog(body: string, tag: string) {
|
||||
}
|
||||
|
||||
function handleLogoClick() {
|
||||
if (customizer.viewMode === 'chat') {
|
||||
if (isChatPath.value) {
|
||||
aboutDialog.value = true;
|
||||
} else {
|
||||
router.push('/about');
|
||||
@@ -395,10 +398,22 @@ commonStore.createEventSource(); // log
|
||||
commonStore.getStartTime();
|
||||
|
||||
// 视图模式切换
|
||||
const viewMode = computed({
|
||||
get: () => customizer.viewMode,
|
||||
set: (value: 'bot' | 'chat') => {
|
||||
customizer.SET_VIEW_MODE(value);
|
||||
onMounted(() => {
|
||||
// 初次加載時保存當前路由
|
||||
if (typeof window !== 'undefined') {
|
||||
if (isChatPath.value) {
|
||||
// 保存 chat ID
|
||||
const parts = route.fullPath.split('/');
|
||||
const sessionId = parts[2];
|
||||
if (sessionId) {
|
||||
sessionStorage.setItem(LAST_CHAT_ROUTE_KEY, sessionId);
|
||||
console.log('Initial save chat ID:', sessionId);
|
||||
}
|
||||
} else {
|
||||
// 保存 bot 路由(非 chat 頁面)
|
||||
sessionStorage.setItem(LAST_BOT_ROUTE_KEY, route.fullPath);
|
||||
console.log('Initial save bot route:', route.fullPath);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -406,26 +421,61 @@ const viewMode = computed({
|
||||
// 保存 bot 模式的最後路由
|
||||
// 監聽 route 變化,保存最後一次 bot 路由
|
||||
watch(() => route.fullPath, (newPath) => {
|
||||
if (customizer.viewMode === 'bot' && typeof window !== 'undefined') {
|
||||
try {
|
||||
localStorage.setItem(LAST_BOT_ROUTE_KEY, newPath);
|
||||
} catch (e) {
|
||||
console.error('Failed to save last bot route to localStorage:', e);
|
||||
if (typeof window === 'undefined') return;
|
||||
console.log('Route changed:', {
|
||||
newPath,
|
||||
isChat: isChatPath.value,
|
||||
currentChatId: route.params.id
|
||||
});
|
||||
try {
|
||||
// 使用現有的 isChatPath 計算屬性來避免名稱衝突
|
||||
const isChat = isChatPath.value; // 這裡使用已經計算好的 isChatPath
|
||||
|
||||
// ✅ bot:只存「非 chat 頁」
|
||||
if (!isChat) {
|
||||
sessionStorage.setItem(LAST_BOT_ROUTE_KEY, newPath);
|
||||
}
|
||||
|
||||
// ✅ chat:只存 sessionId
|
||||
if (isChat) {
|
||||
const parts = newPath.split('/');
|
||||
const sessionId = parts[2];
|
||||
|
||||
if (sessionId) {
|
||||
sessionStorage.setItem(LAST_CHAT_ROUTE_KEY, sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
} catch (e) {
|
||||
console.error('Failed to save route:', e);
|
||||
}
|
||||
});
|
||||
|
||||
// 監聽 viewMode 切換
|
||||
watch(() => customizer.viewMode, (newMode, oldMode) => {
|
||||
if (newMode === 'bot' && oldMode === 'chat' && typeof window !== 'undefined') {
|
||||
// 從 chat 切換回 bot,跳轉到最後一次的 bot 路由
|
||||
let lastBotRoute = '/';
|
||||
const currentMode = computed({
|
||||
get: () => (isChatPath.value ? 'chat' : 'bot'),
|
||||
set: (val: 'chat' | 'bot') => {
|
||||
try {
|
||||
lastBotRoute = localStorage.getItem(LAST_BOT_ROUTE_KEY) || '/';
|
||||
// 檢查 window 和 sessionStorage 是否存在
|
||||
if (typeof window === 'undefined' || typeof sessionStorage === 'undefined') {
|
||||
// 如果在非瀏覽器環境中,不做任何 sessionStorage 操作
|
||||
console.warn('sessionStorage is not available in this environment');
|
||||
return;
|
||||
}
|
||||
|
||||
if (val === 'chat') {
|
||||
const lastSessionId = sessionStorage.getItem(LAST_CHAT_ROUTE_KEY);
|
||||
router.push(lastSessionId ? `/chat/${lastSessionId}` : '/chat');
|
||||
} else {
|
||||
let lastBotRoute = sessionStorage.getItem(LAST_BOT_ROUTE_KEY) || '/';
|
||||
if (lastBotRoute.startsWith('/chat')) {
|
||||
lastBotRoute = '/';
|
||||
}
|
||||
router.push(lastBotRoute);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to read last bot route from localStorage:', e);
|
||||
// 在受限隱私模式等環境中,sessionStorage 操作可能會拋出 SecurityError
|
||||
console.warn('Failed to access sessionStorage in currentMode setter:', e);
|
||||
}
|
||||
router.push(lastBotRoute);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -465,29 +515,46 @@ onMounted(async () => {
|
||||
<v-app-bar elevation="0" height="50" class="top-header">
|
||||
|
||||
<!-- 桌面端 menu 按钮 - 仅在 bot 模式下显示 -->
|
||||
<v-btn v-if="customizer.viewMode === 'bot'"
|
||||
style="margin-left: 16px;"
|
||||
class="hidden-md-and-down" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
<!-- 移动端 menu 按钮 - 仅在 bot 模式下显示 -->
|
||||
<v-btn v-if="customizer.viewMode === 'bot'" class="hidden-lg-and-up ms-3" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.SET_SIDEBAR_DRAWER">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
<v-btn
|
||||
v-if="!isChatPath"
|
||||
style="margin-left: 16px;"
|
||||
class="hidden-md-and-down"
|
||||
icon
|
||||
rounded="sm"
|
||||
variant="flat"
|
||||
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)"
|
||||
>
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<!-- 移动端 chat sidebar 展开按钮 - 仅在 chat 模式下的小屏幕显示 -->
|
||||
<v-btn v-if="customizer.viewMode === 'chat'" class="hidden-lg-and-up ms-1" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.TOGGLE_CHAT_SIDEBAR()">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
<!-- 移动端 menu 按钮 -->
|
||||
<v-btn
|
||||
v-if="!isChatPath"
|
||||
class="hidden-lg-and-up ms-3"
|
||||
icon
|
||||
rounded="sm"
|
||||
variant="flat"
|
||||
@click.stop="customizer.SET_SIDEBAR_DRAWER"
|
||||
>
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs, 'chat-mode-logo': customizer.viewMode === 'chat' }" @click="handleLogoClick">
|
||||
<v-btn
|
||||
v-if="isChatPath"
|
||||
class="hidden-lg-and-up ms-1"
|
||||
icon
|
||||
rounded="sm"
|
||||
variant="flat"
|
||||
@click.stop="customizer.TOGGLE_CHAT_SIDEBAR()"
|
||||
>
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs, 'chat-mode-logo': isChatPath }" @click="handleLogoClick">
|
||||
<span class="logo-text Outfit">Astr<span class="logo-text bot-text-wrapper">Bot
|
||||
<img v-if="isChristmas" src="@/assets/images/xmas-hat.png" alt="Christmas hat" class="xmas-hat" />
|
||||
</span></span>
|
||||
<span class="logo-text logo-text-light Outfit" style="color: grey;" v-if="customizer.viewMode === 'chat'">ChatUI</span>
|
||||
<span class="logo-text logo-text-light Outfit" style="color: grey;" v-if="isChatPath">ChatUI</span>
|
||||
<span class="version-text hidden-xs">{{ botCurrVersion }}</span>
|
||||
</div>
|
||||
|
||||
@@ -504,23 +571,23 @@ onMounted(async () => {
|
||||
</div>
|
||||
|
||||
<!-- Bot/Chat 模式切换按钮 - 手机端隐藏,移入 ... 菜单 -->
|
||||
<v-btn-toggle
|
||||
v-model="viewMode"
|
||||
mandatory
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
class="mr-4 hidden-xs"
|
||||
color="primary"
|
||||
>
|
||||
<v-btn value="bot" size="small">
|
||||
<v-icon start>mdi-robot</v-icon>
|
||||
Bot
|
||||
</v-btn>
|
||||
<v-btn value="chat" size="small">
|
||||
<v-icon start>mdi-chat</v-icon>
|
||||
Chat
|
||||
</v-btn>
|
||||
</v-btn-toggle>
|
||||
<v-btn-toggle
|
||||
v-model="currentMode"
|
||||
mandatory
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
class="mr-4 hidden-xs"
|
||||
color="primary"
|
||||
>
|
||||
<v-btn value="bot" size="small">
|
||||
<v-icon start>mdi-robot</v-icon>
|
||||
Bot
|
||||
</v-btn>
|
||||
<v-btn value="chat" size="small">
|
||||
<v-icon start>mdi-chat</v-icon>
|
||||
Chat
|
||||
</v-btn>
|
||||
</v-btn-toggle>
|
||||
|
||||
|
||||
<!-- 功能菜单 -->
|
||||
@@ -542,14 +609,14 @@ onMounted(async () => {
|
||||
<!-- Bot/Chat 模式切换 - 仅在手机端显示 -->
|
||||
<template v-if="$vuetify.display.xs">
|
||||
<div class="mobile-mode-toggle-wrapper">
|
||||
<v-btn-toggle
|
||||
v-model="viewMode"
|
||||
mandatory
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
color="primary"
|
||||
class="mobile-mode-toggle"
|
||||
>
|
||||
<v-btn-toggle
|
||||
v-model="currentMode"
|
||||
mandatory
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
class="mobile-mode-toggle"
|
||||
color="primary"
|
||||
>
|
||||
<v-btn value="bot" size="small">
|
||||
<v-icon start>mdi-robot</v-icon>
|
||||
Bot
|
||||
|
||||
@@ -39,3 +39,10 @@ html {
|
||||
transform: rotate(270deg);
|
||||
}
|
||||
}
|
||||
|
||||
pre, code, .markdown pre, .markdown code, .release-notes pre, .release-notes code {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, "Roboto Mono", "Helvetica Neue", monospace;
|
||||
color: var(--astrbot-code-color);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,10 +11,12 @@ $font-size-root: 1rem;
|
||||
$border-radius-root: 8px;
|
||||
$cjk-sans-fallback: 'PingFang SC', 'Hiragino Sans GB', 'Noto Sans CJK SC', 'Microsoft YaHei' !default;
|
||||
$cjk-mono-fallback: 'PingFang SC', 'PingFang TC', 'Hiragino Sans GB', 'Noto Sans CJK SC', 'Microsoft YaHei' !default;
|
||||
$code-text-color: #111827 !default;
|
||||
|
||||
:root {
|
||||
--astrbot-font-cjk-sans: #{$cjk-sans-fallback};
|
||||
--astrbot-font-cjk-mono: #{$cjk-mono-fallback};
|
||||
--astrbot-code-color: #{$code-text-color};
|
||||
}
|
||||
|
||||
$body-font-family: 'Roboto', $cjk-sans-fallback, sans-serif !default;
|
||||
|
||||
@@ -10,7 +10,6 @@ export const useCustomizerStore = defineStore({
|
||||
fontTheme: "Poppins",
|
||||
uiTheme: config.uiTheme,
|
||||
inputBg: config.inputBg,
|
||||
viewMode: (localStorage.getItem('viewMode') as 'bot' | 'chat') || 'bot', // 'bot' 或 'chat'
|
||||
chatSidebarOpen: false // chat mode mobile sidebar state
|
||||
}),
|
||||
|
||||
@@ -29,10 +28,7 @@ export const useCustomizerStore = defineStore({
|
||||
this.uiTheme = payload;
|
||||
localStorage.setItem("uiTheme", payload);
|
||||
},
|
||||
SET_VIEW_MODE(payload: 'bot' | 'chat') {
|
||||
this.viewMode = payload;
|
||||
localStorage.setItem('viewMode', payload);
|
||||
},
|
||||
|
||||
TOGGLE_CHAT_SIDEBAR() {
|
||||
this.chatSidebarOpen = !this.chatSidebarOpen;
|
||||
},
|
||||
|
||||
@@ -50,7 +50,7 @@ export function getTutorialLink(platformType) {
|
||||
const tutorialMap = {
|
||||
"qq_official_webhook": "https://docs.astrbot.app/platform/qqofficial/webhook.html",
|
||||
"qq_official": "https://docs.astrbot.app/platform/qqofficial/websockets.html",
|
||||
"aiocqhttp": "https://docs.astrbot.app/platform/aiocqhttp/napcat.html",
|
||||
"aiocqhttp": "https://docs.astrbot.app/platform/aiocqhttp.html",
|
||||
"wecom": "https://docs.astrbot.app/platform/wecom.html",
|
||||
"weixin_oc": "https://docs.astrbot.app/platform/weixin_oc.html",
|
||||
"wecom_ai_bot": "https://docs.astrbot.app/platform/wecom_ai_bot.html",
|
||||
|
||||
3
tests/fixtures/mocks/telegram.py
vendored
3
tests/fixtures/mocks/telegram.py
vendored
@@ -33,7 +33,8 @@ def create_mock_telegram_modules():
|
||||
|
||||
mock_telegram_ext = MagicMock()
|
||||
mock_telegram_ext.ApplicationBuilder = MagicMock
|
||||
mock_telegram_ext.ContextTypes = MagicMock
|
||||
mock_telegram_ext.ContextTypes = MagicMock()
|
||||
mock_telegram_ext.ContextTypes.DEFAULT_TYPE = MagicMock
|
||||
mock_telegram_ext.ExtBot = MagicMock
|
||||
mock_telegram_ext.filters = MagicMock()
|
||||
mock_telegram_ext.filters.ALL = MagicMock()
|
||||
|
||||
@@ -2,6 +2,7 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from astrbot.core.provider.sources.groq_source import ProviderGroq
|
||||
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
|
||||
@@ -234,7 +235,9 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history():
|
||||
provider._finally_convert_payload(payloads)
|
||||
|
||||
assistant_message = payloads["messages"][0]
|
||||
assert assistant_message["content"] == [{"type": "text", "text": "final answer"}]
|
||||
assert assistant_message["content"] == [
|
||||
{"type": "text", "text": "final answer"}
|
||||
]
|
||||
assert assistant_message["reasoning_content"] == "step 1"
|
||||
finally:
|
||||
await provider.terminate()
|
||||
@@ -259,7 +262,9 @@ async def test_groq_payload_drops_reasoning_content_from_assistant_history():
|
||||
provider._finally_convert_payload(payloads)
|
||||
|
||||
assistant_message = payloads["messages"][0]
|
||||
assert assistant_message["content"] == [{"type": "text", "text": "final answer"}]
|
||||
assert assistant_message["content"] == [
|
||||
{"type": "text", "text": "final answer"}
|
||||
]
|
||||
assert "reasoning_content" not in assistant_message
|
||||
assert "reasoning" not in assistant_message
|
||||
finally:
|
||||
@@ -450,6 +455,604 @@ async def test_handle_api_error_unknown_image_error_raises():
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_invalid_attachment_removes_images_and_retries_text_only():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
err = _ErrorWithBody(
|
||||
"upstream error",
|
||||
{
|
||||
"error": {
|
||||
"code": "INVALID_ATTACHMENT",
|
||||
"message": "download attachment: unexpected status 404",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
success, *_rest = await provider._handle_api_error(
|
||||
err,
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
|
||||
assert success is False
|
||||
assert payloads["messages"][0]["content"] == [{"type": "text", "text": "hello"}]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_invalid_attachment_without_images_raises():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
err = _ErrorWithBody(
|
||||
"upstream error",
|
||||
{
|
||||
"error": {
|
||||
"code": "INVALID_ATTACHMENT",
|
||||
"message": "download attachment: unexpected status 404",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(_ErrorWithBody, match="upstream error"):
|
||||
await provider._handle_api_error(
|
||||
err,
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=0,
|
||||
max_retries=10,
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_invalid_attachment_after_fallback_raises():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
payloads = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
context_query = payloads["messages"]
|
||||
err = _ErrorWithBody(
|
||||
"upstream error",
|
||||
{
|
||||
"error": {
|
||||
"code": "INVALID_ATTACHMENT",
|
||||
"message": "download attachment: unexpected status 404",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(_ErrorWithBody, match="upstream error"):
|
||||
await provider._handle_api_error(
|
||||
err,
|
||||
payloads=payloads,
|
||||
context_query=context_query,
|
||||
func_tool=None,
|
||||
chosen_key="test-key",
|
||||
available_api_keys=["test-key"],
|
||||
retry_cnt=1,
|
||||
max_retries=10,
|
||||
image_fallback_used=True,
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_chat_payload_materializes_context_http_image_urls(monkeypatch):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
|
||||
async def fake_download(url: str) -> str:
|
||||
assert url == "https://example.com/quoted.png"
|
||||
return "/tmp/quoted.png"
|
||||
|
||||
def fake_encode(image_path: str, **_kwargs) -> str:
|
||||
assert image_path == "/tmp/quoted.png"
|
||||
return "data:image/png;base64,abcd"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.provider.sources.openai_source.download_image_by_url",
|
||||
fake_download,
|
||||
)
|
||||
monkeypatch.setattr(provider, "_encode_image_file_to_data_url", fake_encode)
|
||||
|
||||
contexts = [
|
||||
{
|
||||
"role": "user",
|
||||
"metadata": {"source": "quoted"},
|
||||
"content": [
|
||||
{"type": "text", "text": "look"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/quoted.png",
|
||||
"id": "ctx-img",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
payloads, _ = await provider._prepare_chat_payload(
|
||||
prompt=None,
|
||||
contexts=contexts,
|
||||
)
|
||||
|
||||
assert payloads["messages"][0]["content"] == [
|
||||
{"type": "text", "text": "look"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,abcd",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
]
|
||||
assert payloads["messages"][0]["content"][1]["image_url"].get("id") is None
|
||||
assert contexts[0]["content"][1]["image_url"] == {
|
||||
"url": "https://example.com/quoted.png",
|
||||
"id": "ctx-img",
|
||||
"detail": "high",
|
||||
}
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_chat_payload_skips_materialization_for_text_only_context(
|
||||
monkeypatch,
|
||||
):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
|
||||
async def fail_if_called(_context_query):
|
||||
raise AssertionError("materialization should be skipped")
|
||||
|
||||
monkeypatch.setattr(
|
||||
provider, "_materialize_context_image_parts", fail_if_called
|
||||
)
|
||||
|
||||
payloads, _ = await provider._prepare_chat_payload(
|
||||
prompt=None,
|
||||
contexts=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
assert payloads["messages"] == [{"role": "user", "content": "hello"}]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_chat_payload_skips_materialization_for_text_only_parts(
|
||||
monkeypatch,
|
||||
):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
|
||||
async def fail_if_called(_context_query):
|
||||
raise AssertionError("materialization should be skipped")
|
||||
|
||||
monkeypatch.setattr(
|
||||
provider, "_materialize_context_image_parts", fail_if_called
|
||||
)
|
||||
|
||||
payloads, _ = await provider._prepare_chat_payload(
|
||||
prompt=None,
|
||||
contexts=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert payloads["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
}
|
||||
]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_chat_payload_materializes_context_http_image_urls_with_detected_mime(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
image_path = tmp_path / "quoted-image.png"
|
||||
PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path)
|
||||
|
||||
async def fake_download(url: str) -> str:
|
||||
assert url == "https://example.com/quoted.png"
|
||||
return str(image_path)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.provider.sources.openai_source.download_image_by_url",
|
||||
fake_download,
|
||||
)
|
||||
|
||||
payloads, _ = await provider._prepare_chat_payload(
|
||||
prompt=None,
|
||||
contexts=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "look"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/quoted.png",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
image_payload = payloads["messages"][0]["content"][1]["image_url"]
|
||||
assert image_payload["url"].startswith("data:image/png;base64,")
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_chat_payload_materializes_context_file_uri_image_urls(tmp_path):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
image_path = tmp_path / "quoted-image.png"
|
||||
PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path)
|
||||
|
||||
payloads, _ = await provider._prepare_chat_payload(
|
||||
prompt=None,
|
||||
contexts=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "look"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_path.as_uri(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
image_payload = payloads["messages"][0]["content"][1]["image_url"]
|
||||
assert image_payload["url"].startswith("data:image/png;base64,")
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_uri_to_path_preserves_windows_drive_letter():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
assert provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") == (
|
||||
"C:/tmp/quoted-image.png"
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_uri_to_path_preserves_windows_netloc_drive_letter():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
assert provider._file_uri_to_path("file://C:/tmp/quoted-image.png") == (
|
||||
"C:/tmp/quoted-image.png"
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_uri_to_path_preserves_remote_netloc_as_unc_path():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
assert provider._file_uri_to_path("file://server/share/quoted-image.png") == (
|
||||
"//server/share/quoted-image.png"
|
||||
)
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_image_part_rejects_invalid_local_file(tmp_path):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
invalid_file = tmp_path / "not-image.txt"
|
||||
invalid_file.write_text("not an image")
|
||||
|
||||
assert await provider._resolve_image_part(str(invalid_file)) is None
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_image_part_rejects_invalid_file_uri(tmp_path):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
invalid_file = tmp_path / "not-image.txt"
|
||||
invalid_file.write_text("not an image")
|
||||
|
||||
assert await provider._resolve_image_part(invalid_file.as_uri()) is None
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_ref_to_data_url_mode_controls_invalid_file_behavior(tmp_path):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
invalid_file = tmp_path / "not-image.txt"
|
||||
invalid_file.write_text("not an image")
|
||||
|
||||
assert (
|
||||
await provider._image_ref_to_data_url(str(invalid_file), mode="safe")
|
||||
is None
|
||||
)
|
||||
with pytest.raises(ValueError, match="Invalid image file"):
|
||||
await provider._image_ref_to_data_url(str(invalid_file), mode="strict")
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_materialize_context_image_parts_returns_new_messages(monkeypatch):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
context_query = [
|
||||
{
|
||||
"role": "user",
|
||||
"metadata": {"source": "quoted"},
|
||||
"content": [
|
||||
{"type": "text", "text": "look"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/quoted.png",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": "plain text"},
|
||||
]
|
||||
|
||||
async def fake_resolve(image_url: str, *, image_detail: str | None = None):
|
||||
assert image_url == "https://example.com/quoted.png"
|
||||
assert image_detail == "high"
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,abcd",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(provider, "_resolve_image_part", fake_resolve)
|
||||
|
||||
materialized = await provider._materialize_context_image_parts(context_query)
|
||||
|
||||
assert materialized is not context_query
|
||||
assert materialized[0] is not context_query[0]
|
||||
assert materialized[0]["metadata"] is context_query[0]["metadata"]
|
||||
assert materialized[0]["content"][0] is context_query[0]["content"][0]
|
||||
assert (
|
||||
materialized[0]["content"][1]["image_url"]["url"]
|
||||
== "data:image/png;base64,abcd"
|
||||
)
|
||||
assert (
|
||||
context_query[0]["content"][1]["image_url"]["url"]
|
||||
== "https://example.com/quoted.png"
|
||||
)
|
||||
assert materialized[1] is not context_query[1]
|
||||
assert materialized[1]["content"] == "plain text"
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encode_image_bs64_missing_file_raises(tmp_path):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
missing_path = tmp_path / "missing-image.png"
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await provider.encode_image_bs64(str(missing_path))
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encode_image_bs64_invalid_file_raises(tmp_path):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
invalid_file = tmp_path / "not-image.txt"
|
||||
invalid_file.write_text("not an image")
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid image file"):
|
||||
await provider.encode_image_bs64(str(invalid_file))
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encode_image_bs64_supports_base64_scheme():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
image_data = await provider.encode_image_bs64("base64://abcd")
|
||||
|
||||
assert image_data == "data:image/jpeg;base64,abcd"
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encode_image_bs64_supports_file_uri(tmp_path):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
image_path = tmp_path / "quoted-image.png"
|
||||
PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path)
|
||||
|
||||
image_data = await provider.encode_image_bs64(image_path.as_uri())
|
||||
|
||||
assert image_data.startswith("data:image/png;base64,")
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_image_part_supports_base64_scheme():
|
||||
provider = _make_provider()
|
||||
try:
|
||||
assert await provider._resolve_image_part("base64://abcd") == {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,abcd"},
|
||||
}
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_chat_payload_materializes_context_localhost_file_uri_image_urls(
|
||||
tmp_path,
|
||||
):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
image_path = tmp_path / "quoted-image.png"
|
||||
PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path)
|
||||
|
||||
localhost_uri = f"file://localhost{image_path.as_posix()}"
|
||||
payloads, _ = await provider._prepare_chat_payload(
|
||||
prompt=None,
|
||||
contexts=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "look"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": localhost_uri,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
image_payload = payloads["messages"][0]["content"][1]["image_url"]
|
||||
assert image_payload["url"].startswith("data:image/png;base64,")
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_chat_payload_keeps_original_context_image_when_materialization_fails(
|
||||
monkeypatch,
|
||||
):
|
||||
provider = _make_provider()
|
||||
try:
|
||||
|
||||
async def fake_download(url: str) -> str:
|
||||
assert url == "https://example.com/expired.png"
|
||||
return "/tmp/not-an-image"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"astrbot.core.provider.sources.openai_source.download_image_by_url",
|
||||
fake_download,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
provider,
|
||||
"_encode_image_file_to_data_url",
|
||||
lambda _image_path, **_kwargs: None,
|
||||
)
|
||||
|
||||
payloads, _ = await provider._prepare_chat_payload(
|
||||
prompt=None,
|
||||
contexts=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "look"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/expired.png",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert payloads["messages"][0]["content"] == [
|
||||
{"type": "text", "text": "look"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/expired.png",
|
||||
},
|
||||
},
|
||||
]
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_provider_specific_extra_body_overrides_disables_ollama_thinking():
|
||||
provider = _make_provider(
|
||||
|
||||
108
tests/test_telegram_adapter.py
Normal file
108
tests/test_telegram_adapter.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from tests.fixtures.helpers import (
|
||||
create_mock_file,
|
||||
create_mock_update,
|
||||
make_platform_config,
|
||||
)
|
||||
from tests.fixtures.mocks.telegram import create_mock_telegram_modules
|
||||
|
||||
_TELEGRAM_PLATFORM_ADAPTER = None
|
||||
|
||||
|
||||
def _load_telegram_adapter():
|
||||
global _TELEGRAM_PLATFORM_ADAPTER
|
||||
if _TELEGRAM_PLATFORM_ADAPTER is not None:
|
||||
return _TELEGRAM_PLATFORM_ADAPTER
|
||||
|
||||
mocks = create_mock_telegram_modules()
|
||||
patched_modules = {
|
||||
"telegram": mocks["telegram"],
|
||||
"telegram.constants": mocks["telegram"].constants,
|
||||
"telegram.error": mocks["telegram"].error,
|
||||
"telegram.ext": mocks["telegram.ext"],
|
||||
"telegramify_markdown": mocks["telegramify_markdown"],
|
||||
"apscheduler": mocks["apscheduler"],
|
||||
"apscheduler.schedulers": mocks["apscheduler"].schedulers,
|
||||
"apscheduler.schedulers.asyncio": mocks["apscheduler"].schedulers.asyncio,
|
||||
"apscheduler.schedulers.background": mocks["apscheduler"].schedulers.background,
|
||||
}
|
||||
with patch.dict(sys.modules, patched_modules):
|
||||
sys.modules.pop("astrbot.core.platform.sources.telegram.tg_adapter", None)
|
||||
module = importlib.import_module("astrbot.core.platform.sources.telegram.tg_adapter")
|
||||
_TELEGRAM_PLATFORM_ADAPTER = module.TelegramPlatformAdapter
|
||||
return _TELEGRAM_PLATFORM_ADAPTER
|
||||
|
||||
|
||||
def _build_context() -> MagicMock:
|
||||
context = MagicMock()
|
||||
context.bot.username = "test_bot"
|
||||
context.bot.id = 12345678
|
||||
return context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_telegram_document_caption_populates_message_text_and_plain():
|
||||
TelegramPlatformAdapter = _load_telegram_adapter()
|
||||
adapter = TelegramPlatformAdapter(
|
||||
make_platform_config("telegram"),
|
||||
{},
|
||||
asyncio.Queue(),
|
||||
)
|
||||
document = create_mock_file("https://api.telegram.org/file/test/report.md")
|
||||
document.file_name = "report.md"
|
||||
mention = MagicMock(type="mention", offset=0, length=6)
|
||||
update = create_mock_update(
|
||||
message_text=None,
|
||||
document=document,
|
||||
caption="@alice 请总结这份文档",
|
||||
caption_entities=[mention],
|
||||
)
|
||||
|
||||
result = await adapter.convert_message(update, _build_context())
|
||||
|
||||
assert result is not None
|
||||
assert result.message_str == "@alice 请总结这份文档"
|
||||
assert any(isinstance(component, Comp.File) for component in result.message)
|
||||
assert any(
|
||||
isinstance(component, Comp.Plain)
|
||||
and component.text == "@alice 请总结这份文档"
|
||||
for component in result.message
|
||||
)
|
||||
assert any(
|
||||
isinstance(component, Comp.At) and component.qq == "alice"
|
||||
for component in result.message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_telegram_video_caption_populates_message_text_and_plain():
|
||||
TelegramPlatformAdapter = _load_telegram_adapter()
|
||||
adapter = TelegramPlatformAdapter(
|
||||
make_platform_config("telegram"),
|
||||
{},
|
||||
asyncio.Queue(),
|
||||
)
|
||||
video = create_mock_file("https://api.telegram.org/file/test/lesson.mp4")
|
||||
video.file_name = "lesson.mp4"
|
||||
update = create_mock_update(
|
||||
message_text=None,
|
||||
video=video,
|
||||
caption="这段视频讲了什么",
|
||||
)
|
||||
|
||||
result = await adapter.convert_message(update, _build_context())
|
||||
|
||||
assert result is not None
|
||||
assert result.message_str == "这段视频讲了什么"
|
||||
assert any(isinstance(component, Comp.Video) for component in result.message)
|
||||
assert any(
|
||||
isinstance(component, Comp.Plain) and component.text == "这段视频讲了什么"
|
||||
for component in result.message
|
||||
)
|
||||
Reference in New Issue
Block a user