Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter
3931c3ca79 feat: implement user agent handling across providers and normalize headers 2026-05-30 21:02:39 +08:00
9 changed files with 197 additions and 23 deletions

View File

@@ -6,6 +6,7 @@ from astrbot.core.computer.booters.cua_defaults import CUA_DEFAULT_CONFIG
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.25.2"
ASTRBOT_USER_AGENT = f"astrbot/{VERSION.removeprefix('v')}"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
PERSONAL_WECHAT_CONFIG_METADATA = {
"weixin_oc_base_url": {
@@ -1199,7 +1200,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.kimi.com/coding",
"timeout": 120,
"proxy": "",
"custom_headers": {"User-Agent": "claude-code/0.1.0"},
"custom_headers": {"User-Agent": ASTRBOT_USER_AGENT},
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
},
"Moonshot": {
@@ -1236,7 +1237,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.minimaxi.com/anthropic",
"timeout": 120,
"proxy": "",
"custom_headers": {"User-Agent": "claude-code/0.1.0"},
"custom_headers": {"User-Agent": ASTRBOT_USER_AGENT},
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
},
"Xiaomi": {
@@ -1261,7 +1262,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://token-plan-cn.xiaomimimo.com/anthropic",
"timeout": 120,
"proxy": "",
"custom_headers": {"User-Agent": "claude-code/0.1.0"},
"custom_headers": {"User-Agent": ASTRBOT_USER_AGENT},
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
},
"xAI": {

View File

@@ -13,9 +13,11 @@ from anthropic.types.usage import Usage
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import AudioURLPart, ContentPart, ImageURLPart, TextPart
from astrbot.core.config.default import ASTRBOT_USER_AGENT
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.network_utils import (
create_proxy_client,
@@ -50,13 +52,12 @@ class ProviderAnthropic(Provider):
@staticmethod
def _normalize_custom_headers(provider_config: dict) -> dict[str, str] | None:
custom_headers = provider_config.get("custom_headers", {})
if not isinstance(custom_headers, dict) or not custom_headers:
normalized_headers = normalize_headers(
provider_config.get("custom_headers", {})
)
if not normalized_headers:
return None
normalized_headers: dict[str, str] = {}
for key, value in custom_headers.items():
normalized_headers[str(key)] = str(value)
return normalized_headers or None
return normalized_headers
@classmethod
def _resolve_custom_headers(
@@ -67,9 +68,7 @@ class ProviderAnthropic(Provider):
) -> dict[str, str] | None:
merged_headers = cls._normalize_custom_headers(provider_config) or {}
if required_headers:
for header_name, header_value in required_headers.items():
if not merged_headers.get(header_name, "").strip():
merged_headers[header_name] = header_value
merged_headers = apply_default_headers(merged_headers, required_headers)
return merged_headers or None
def __init__(
@@ -89,7 +88,10 @@ class ProviderAnthropic(Provider):
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.thinking_config = provider_config.get("anth_thinking_config", {})
self.custom_headers = self._resolve_custom_headers(provider_config)
self.custom_headers = self._resolve_custom_headers(
provider_config,
required_headers={"User-Agent": ASTRBOT_USER_AGENT},
)
if use_api_key:
self._init_api_key(provider_config)

View File

@@ -18,11 +18,13 @@ import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import AudioURLPart, ContentPart, ImageURLPart, TextPart
from astrbot.core.config.default import ASTRBOT_USER_AGENT
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers
from astrbot.core.utils.io import download_file, download_image_by_url
from astrbot.core.utils.media_utils import ensure_wav
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
@@ -76,17 +78,41 @@ class ProviderGoogleGenAI(Provider):
if self.api_base and self.api_base.endswith("/"):
self.api_base = self.api_base[:-1]
self.custom_headers = self._resolve_custom_headers(provider_config)
self._http_client: httpx.AsyncClient | None = None
self._stale_http_clients: list[httpx.AsyncClient] = []
self._init_client()
self.set_model(provider_config.get("model", "unknown"))
self._init_safety_settings()
@staticmethod
def _resolve_custom_headers(provider_config: dict) -> dict[str, str]:
headers = apply_default_headers(
normalize_headers(provider_config.get("custom_headers", {})),
{"user-agent": ASTRBOT_USER_AGENT},
)
return {
"user-agent" if key.lower() == "user-agent" else key: value
for key, value in headers.items()
}
@staticmethod
def _set_gemini_user_agent(client: object, user_agent: str) -> None:
api_client = getattr(client, "_api_client", None)
http_options = getattr(api_client, "_http_options", None)
if http_options is None or http_options.headers is None:
return
for key in list(http_options.headers):
if key.lower() == "user-agent":
http_options.headers.pop(key)
http_options.headers["user-agent"] = user_agent
def _init_client(self) -> None:
"""初始化Gemini客户端"""
proxy = self.provider_config.get("proxy", "")
http_options = types.HttpOptions(
base_url=self.api_base,
headers=dict(self.custom_headers),
timeout=self.timeout * 1000, # 毫秒
)
@@ -94,6 +120,7 @@ class ProviderGoogleGenAI(Provider):
# httpx.AsyncClient 的 timeout 单位为秒(与 HttpOptions 的毫秒不同)
async_client_kwargs: dict = {
"base_url": self.api_base,
"headers": dict(self.custom_headers),
"timeout": self.timeout,
}
if proxy:
@@ -112,10 +139,15 @@ class ProviderGoogleGenAI(Provider):
self._http_client = httpx.AsyncClient(**async_client_kwargs)
http_options.httpx_async_client = self._http_client
self.client = genai.Client(
genai_client = genai.Client(
api_key=self.chosen_api_key,
http_options=http_options,
).aio
)
self._set_gemini_user_agent(
genai_client,
self.custom_headers["user-agent"],
)
self.client = genai_client.aio
def _init_safety_settings(self) -> None:
"""初始化安全设置"""

View File

@@ -1,9 +1,11 @@
from astrbot.core.config.default import ASTRBOT_USER_AGENT
from ..register import register_provider_adapter
from .anthropic_source import ProviderAnthropic
KIMI_CODE_API_BASE = "https://api.kimi.com/coding"
KIMI_CODE_DEFAULT_MODEL = "kimi-for-coding"
KIMI_CODE_USER_AGENT = "claude-code/0.1.0"
KIMI_CODE_USER_AGENT = ASTRBOT_USER_AGENT
@register_provider_adapter(

View File

@@ -34,10 +34,12 @@ from astrbot.core.agent.message import (
TextPart,
)
from astrbot.core.agent.tool import ToolSet
from astrbot.core.config.default import ASTRBOT_USER_AGENT
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers
from astrbot.core.utils.io import download_file, download_image_by_url
from astrbot.core.utils.media_utils import ensure_wav
from astrbot.core.utils.network_utils import (
@@ -68,6 +70,13 @@ class ProviderOpenAIOfficial(Provider):
"AVIF": "image/avif",
}
@staticmethod
def _resolve_custom_headers(provider_config: dict) -> dict[str, str]:
return apply_default_headers(
normalize_headers(provider_config.get("custom_headers", {})),
{"User-Agent": ASTRBOT_USER_AGENT},
)
@classmethod
def _truncate_error_text_candidate(cls, text: str) -> str:
if len(text) <= cls._ERROR_TEXT_CANDIDATE_MAX_CHARS:
@@ -498,16 +507,10 @@ class ProviderOpenAIOfficial(Provider):
self.api_keys: list = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 120)
self.custom_headers = provider_config.get("custom_headers", {})
self.custom_headers = self._resolve_custom_headers(provider_config)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
if not isinstance(self.custom_headers, dict) or not self.custom_headers:
self.custom_headers = None
else:
for key in self.custom_headers:
self.custom_headers[key] = str(self.custom_headers[key])
if "api_version" in provider_config:
# Using Azure OpenAI API
self.client = AsyncAzureOpenAI(

View File

@@ -0,0 +1,31 @@
from collections.abc import Mapping
def normalize_headers(headers: object) -> dict[str, str]:
if not isinstance(headers, dict):
return {}
return {str(key): str(value) for key, value in headers.items()}
def apply_default_headers(
headers: dict[str, str],
default_headers: Mapping[str, str],
) -> dict[str, str]:
merged_headers = dict(headers)
for default_name, default_value in default_headers.items():
existing_name = next(
(
header_name
for header_name in merged_headers
if header_name.lower() == default_name.lower()
),
None,
)
if existing_name is None:
merged_headers[default_name] = default_value
continue
if merged_headers[existing_name].strip():
continue
merged_headers.pop(existing_name)
merged_headers[default_name] = default_value
return merged_headers

View File

@@ -4,6 +4,7 @@ import pytest
import astrbot.core.provider.sources.anthropic_source as anthropic_source
import astrbot.core.provider.sources.kimi_code_source as kimi_code_source
from astrbot.core.config.default import ASTRBOT_USER_AGENT
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse
@@ -16,6 +17,25 @@ class _FakeAsyncAnthropic:
return None
def test_anthropic_provider_uses_astrbot_default_user_agent(monkeypatch):
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)
provider = anthropic_source.ProviderAnthropic(
provider_config={
"id": "anthropic-test",
"type": "anthropic_chat_completion",
"model": "claude-test",
"key": ["test-key"],
},
provider_settings={},
)
assert provider.custom_headers == {"User-Agent": ASTRBOT_USER_AGENT}
assert provider.client.kwargs["default_headers"] == {
"User-Agent": ASTRBOT_USER_AGENT,
}
def test_anthropic_provider_passes_custom_headers_via_default_headers(monkeypatch):
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)

View File

@@ -1,10 +1,48 @@
import pytest
from astrbot.core.config.default import ASTRBOT_USER_AGENT
from astrbot.core.exceptions import EmptyModelOutputError
import astrbot.core.provider.sources.gemini_source as gemini_source
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
class _FakeGenAIClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
self._api_client = type(
"FakeAPIClient",
(),
{"_http_options": kwargs["http_options"]},
)()
self.aio = type("FakeAioClient", (), {"_api_client": self._api_client})()
@pytest.mark.asyncio
async def test_gemini_provider_uses_astrbot_default_user_agent(monkeypatch):
monkeypatch.setattr(gemini_source.genai, "Client", _FakeGenAIClient)
provider = ProviderGoogleGenAI(
provider_config={
"id": "gemini-test",
"type": "googlegenai_chat_completion",
"model": "gemini-test",
"key": ["test-key"],
"api_base": "https://generativelanguage.googleapis.com/",
},
provider_settings={},
)
try:
assert provider.custom_headers["user-agent"] == ASTRBOT_USER_AGENT
assert provider.client._api_client._http_options.headers["user-agent"] == (
ASTRBOT_USER_AGENT
)
assert provider._http_client.headers["user-agent"] == ASTRBOT_USER_AGENT
finally:
await provider.terminate()
def test_gemini_empty_output_raises_empty_model_output_error():
llm_response = LLMResponse(role="assistant")

View File

@@ -9,6 +9,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from PIL import Image as PILImage
import astrbot.core.provider.sources.openai_source as openai_source_module
from astrbot.core.config.default import ASTRBOT_USER_AGENT
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.sources.groq_source import ProviderGroq
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
@@ -26,6 +27,17 @@ class _ErrorWithResponse(Exception):
self.response = SimpleNamespace(text=response_text)
class _FakeChatCompletions:
def create(self):
return None
class _FakeAsyncOpenAI:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.chat = SimpleNamespace(completions=_FakeChatCompletions())
def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial:
provider_config = {
"id": "test-openai",
@@ -56,6 +68,39 @@ def _make_groq_provider(overrides: dict | None = None) -> ProviderGroq:
)
def test_openai_provider_uses_astrbot_default_user_agent(monkeypatch):
monkeypatch.setattr(openai_source_module, "AsyncOpenAI", _FakeAsyncOpenAI)
provider = _make_provider()
assert provider.custom_headers == {"User-Agent": ASTRBOT_USER_AGENT}
assert provider.client.kwargs["default_headers"] == {
"User-Agent": ASTRBOT_USER_AGENT,
}
def test_openai_provider_preserves_custom_user_agent(monkeypatch):
monkeypatch.setattr(openai_source_module, "AsyncOpenAI", _FakeAsyncOpenAI)
provider = _make_provider(
{
"custom_headers": {
"User-Agent": "custom-agent/1.0",
"X-Test-Header": 123,
},
},
)
assert provider.custom_headers == {
"User-Agent": "custom-agent/1.0",
"X-Test-Header": "123",
}
assert provider.client.kwargs["default_headers"] == {
"User-Agent": "custom-agent/1.0",
"X-Test-Header": "123",
}
def test_create_http_client_uses_openai_httpx_module(monkeypatch):
captured: dict[str, object] = {}