mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 01:10:21 +08:00
fix: preserve embedding api version suffixes (#8736)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import re
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
@@ -8,6 +10,13 @@ from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
def _normalize_api_base(api_base: str) -> str:
|
||||
api_base = api_base.strip().removesuffix("/").removesuffix("/embeddings")
|
||||
if api_base and not re.search(r"/v\d+$", api_base):
|
||||
api_base = api_base + "/v1"
|
||||
return api_base
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"openai_embedding",
|
||||
"OpenAI API Embedding 提供商适配器",
|
||||
@@ -24,15 +33,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
if proxy:
|
||||
logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}")
|
||||
http_client = httpx.AsyncClient(proxy=proxy)
|
||||
api_base = (
|
||||
api_base = _normalize_api_base(
|
||||
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
|
||||
.strip()
|
||||
.removesuffix("/")
|
||||
.removesuffix("/embeddings")
|
||||
)
|
||||
if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"):
|
||||
# /v4 see #5699
|
||||
api_base = api_base + "/v1"
|
||||
logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}")
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=provider_config.get("embedding_api_key"),
|
||||
|
||||
18
tests/test_openai_embedding_source.py
Normal file
18
tests/test_openai_embedding_source.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from astrbot.core.provider.sources.openai_embedding_source import _normalize_api_base
|
||||
|
||||
|
||||
def test_openai_embedding_api_base_keeps_version_suffixes():
|
||||
assert (
|
||||
_normalize_api_base("https://ark.cn-beijing.volces.com/api/plan/v3")
|
||||
== "https://ark.cn-beijing.volces.com/api/plan/v3"
|
||||
)
|
||||
assert _normalize_api_base("https://example.test/v4") == "https://example.test/v4"
|
||||
|
||||
|
||||
def test_openai_embedding_api_base_adds_default_version():
|
||||
assert _normalize_api_base("https://example.test/openai") == (
|
||||
"https://example.test/openai/v1"
|
||||
)
|
||||
assert _normalize_api_base("https://example.test/v1/embeddings") == (
|
||||
"https://example.test/v1"
|
||||
)
|
||||
Reference in New Issue
Block a user