mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 18:20:16 +08:00
Compare commits
1 Commits
codex/prep
...
feat/user-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3931c3ca79 |
@@ -6,6 +6,7 @@ from astrbot.core.computer.booters.cua_defaults import CUA_DEFAULT_CONFIG
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.25.2"
|
||||
ASTRBOT_USER_AGENT = f"astrbot/{VERSION.removeprefix('v')}"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
PERSONAL_WECHAT_CONFIG_METADATA = {
|
||||
"weixin_oc_base_url": {
|
||||
@@ -1199,7 +1200,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.kimi.com/coding",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {"User-Agent": "claude-code/0.1.0"},
|
||||
"custom_headers": {"User-Agent": ASTRBOT_USER_AGENT},
|
||||
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
|
||||
},
|
||||
"Moonshot": {
|
||||
@@ -1236,7 +1237,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.minimaxi.com/anthropic",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {"User-Agent": "claude-code/0.1.0"},
|
||||
"custom_headers": {"User-Agent": ASTRBOT_USER_AGENT},
|
||||
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
|
||||
},
|
||||
"Xiaomi": {
|
||||
@@ -1261,7 +1262,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://token-plan-cn.xiaomimimo.com/anthropic",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {"User-Agent": "claude-code/0.1.0"},
|
||||
"custom_headers": {"User-Agent": ASTRBOT_USER_AGENT},
|
||||
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
|
||||
},
|
||||
"xAI": {
|
||||
|
||||
@@ -13,9 +13,11 @@ from anthropic.types.usage import Usage
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import AudioURLPart, ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.config.default import ASTRBOT_USER_AGENT
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.network_utils import (
|
||||
create_proxy_client,
|
||||
@@ -50,13 +52,12 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
@staticmethod
|
||||
def _normalize_custom_headers(provider_config: dict) -> dict[str, str] | None:
|
||||
custom_headers = provider_config.get("custom_headers", {})
|
||||
if not isinstance(custom_headers, dict) or not custom_headers:
|
||||
normalized_headers = normalize_headers(
|
||||
provider_config.get("custom_headers", {})
|
||||
)
|
||||
if not normalized_headers:
|
||||
return None
|
||||
normalized_headers: dict[str, str] = {}
|
||||
for key, value in custom_headers.items():
|
||||
normalized_headers[str(key)] = str(value)
|
||||
return normalized_headers or None
|
||||
return normalized_headers
|
||||
|
||||
@classmethod
|
||||
def _resolve_custom_headers(
|
||||
@@ -67,9 +68,7 @@ class ProviderAnthropic(Provider):
|
||||
) -> dict[str, str] | None:
|
||||
merged_headers = cls._normalize_custom_headers(provider_config) or {}
|
||||
if required_headers:
|
||||
for header_name, header_value in required_headers.items():
|
||||
if not merged_headers.get(header_name, "").strip():
|
||||
merged_headers[header_name] = header_value
|
||||
merged_headers = apply_default_headers(merged_headers, required_headers)
|
||||
return merged_headers or None
|
||||
|
||||
def __init__(
|
||||
@@ -89,7 +88,10 @@ class ProviderAnthropic(Provider):
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.thinking_config = provider_config.get("anth_thinking_config", {})
|
||||
self.custom_headers = self._resolve_custom_headers(provider_config)
|
||||
self.custom_headers = self._resolve_custom_headers(
|
||||
provider_config,
|
||||
required_headers={"User-Agent": ASTRBOT_USER_AGENT},
|
||||
)
|
||||
|
||||
if use_api_key:
|
||||
self._init_api_key(provider_config)
|
||||
|
||||
@@ -18,11 +18,13 @@ import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import AudioURLPart, ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.config.default import ASTRBOT_USER_AGENT
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url
|
||||
from astrbot.core.utils.media_utils import ensure_wav
|
||||
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
|
||||
@@ -76,17 +78,41 @@ class ProviderGoogleGenAI(Provider):
|
||||
if self.api_base and self.api_base.endswith("/"):
|
||||
self.api_base = self.api_base[:-1]
|
||||
|
||||
self.custom_headers = self._resolve_custom_headers(provider_config)
|
||||
self._http_client: httpx.AsyncClient | None = None
|
||||
self._stale_http_clients: list[httpx.AsyncClient] = []
|
||||
self._init_client()
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
self._init_safety_settings()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_custom_headers(provider_config: dict) -> dict[str, str]:
|
||||
headers = apply_default_headers(
|
||||
normalize_headers(provider_config.get("custom_headers", {})),
|
||||
{"user-agent": ASTRBOT_USER_AGENT},
|
||||
)
|
||||
return {
|
||||
"user-agent" if key.lower() == "user-agent" else key: value
|
||||
for key, value in headers.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _set_gemini_user_agent(client: object, user_agent: str) -> None:
|
||||
api_client = getattr(client, "_api_client", None)
|
||||
http_options = getattr(api_client, "_http_options", None)
|
||||
if http_options is None or http_options.headers is None:
|
||||
return
|
||||
for key in list(http_options.headers):
|
||||
if key.lower() == "user-agent":
|
||||
http_options.headers.pop(key)
|
||||
http_options.headers["user-agent"] = user_agent
|
||||
|
||||
def _init_client(self) -> None:
|
||||
"""初始化Gemini客户端"""
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
http_options = types.HttpOptions(
|
||||
base_url=self.api_base,
|
||||
headers=dict(self.custom_headers),
|
||||
timeout=self.timeout * 1000, # 毫秒
|
||||
)
|
||||
|
||||
@@ -94,6 +120,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
# httpx.AsyncClient 的 timeout 单位为秒(与 HttpOptions 的毫秒不同)
|
||||
async_client_kwargs: dict = {
|
||||
"base_url": self.api_base,
|
||||
"headers": dict(self.custom_headers),
|
||||
"timeout": self.timeout,
|
||||
}
|
||||
if proxy:
|
||||
@@ -112,10 +139,15 @@ class ProviderGoogleGenAI(Provider):
|
||||
self._http_client = httpx.AsyncClient(**async_client_kwargs)
|
||||
http_options.httpx_async_client = self._http_client
|
||||
|
||||
self.client = genai.Client(
|
||||
genai_client = genai.Client(
|
||||
api_key=self.chosen_api_key,
|
||||
http_options=http_options,
|
||||
).aio
|
||||
)
|
||||
self._set_gemini_user_agent(
|
||||
genai_client,
|
||||
self.custom_headers["user-agent"],
|
||||
)
|
||||
self.client = genai_client.aio
|
||||
|
||||
def _init_safety_settings(self) -> None:
|
||||
"""初始化安全设置"""
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from astrbot.core.config.default import ASTRBOT_USER_AGENT
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .anthropic_source import ProviderAnthropic
|
||||
|
||||
KIMI_CODE_API_BASE = "https://api.kimi.com/coding"
|
||||
KIMI_CODE_DEFAULT_MODEL = "kimi-for-coding"
|
||||
KIMI_CODE_USER_AGENT = "claude-code/0.1.0"
|
||||
KIMI_CODE_USER_AGENT = ASTRBOT_USER_AGENT
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
|
||||
@@ -34,10 +34,12 @@ from astrbot.core.agent.message import (
|
||||
TextPart,
|
||||
)
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.config.default import ASTRBOT_USER_AGENT
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url
|
||||
from astrbot.core.utils.media_utils import ensure_wav
|
||||
from astrbot.core.utils.network_utils import (
|
||||
@@ -68,6 +70,13 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"AVIF": "image/avif",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _resolve_custom_headers(provider_config: dict) -> dict[str, str]:
|
||||
return apply_default_headers(
|
||||
normalize_headers(provider_config.get("custom_headers", {})),
|
||||
{"User-Agent": ASTRBOT_USER_AGENT},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _truncate_error_text_candidate(cls, text: str) -> str:
|
||||
if len(text) <= cls._ERROR_TEXT_CANDIDATE_MAX_CHARS:
|
||||
@@ -498,16 +507,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.api_keys: list = super().get_keys()
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
self.custom_headers = provider_config.get("custom_headers", {})
|
||||
self.custom_headers = self._resolve_custom_headers(provider_config)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
if not isinstance(self.custom_headers, dict) or not self.custom_headers:
|
||||
self.custom_headers = None
|
||||
else:
|
||||
for key in self.custom_headers:
|
||||
self.custom_headers[key] = str(self.custom_headers[key])
|
||||
|
||||
if "api_version" in provider_config:
|
||||
# Using Azure OpenAI API
|
||||
self.client = AsyncAzureOpenAI(
|
||||
|
||||
31
astrbot/core/utils/http_headers.py
Normal file
31
astrbot/core/utils/http_headers.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
def normalize_headers(headers: object) -> dict[str, str]:
|
||||
if not isinstance(headers, dict):
|
||||
return {}
|
||||
return {str(key): str(value) for key, value in headers.items()}
|
||||
|
||||
|
||||
def apply_default_headers(
|
||||
headers: dict[str, str],
|
||||
default_headers: Mapping[str, str],
|
||||
) -> dict[str, str]:
|
||||
merged_headers = dict(headers)
|
||||
for default_name, default_value in default_headers.items():
|
||||
existing_name = next(
|
||||
(
|
||||
header_name
|
||||
for header_name in merged_headers
|
||||
if header_name.lower() == default_name.lower()
|
||||
),
|
||||
None,
|
||||
)
|
||||
if existing_name is None:
|
||||
merged_headers[default_name] = default_value
|
||||
continue
|
||||
if merged_headers[existing_name].strip():
|
||||
continue
|
||||
merged_headers.pop(existing_name)
|
||||
merged_headers[default_name] = default_value
|
||||
return merged_headers
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
import astrbot.core.provider.sources.anthropic_source as anthropic_source
|
||||
import astrbot.core.provider.sources.kimi_code_source as kimi_code_source
|
||||
from astrbot.core.config.default import ASTRBOT_USER_AGENT
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
@@ -16,6 +17,25 @@ class _FakeAsyncAnthropic:
|
||||
return None
|
||||
|
||||
|
||||
def test_anthropic_provider_uses_astrbot_default_user_agent(monkeypatch):
|
||||
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)
|
||||
|
||||
provider = anthropic_source.ProviderAnthropic(
|
||||
provider_config={
|
||||
"id": "anthropic-test",
|
||||
"type": "anthropic_chat_completion",
|
||||
"model": "claude-test",
|
||||
"key": ["test-key"],
|
||||
},
|
||||
provider_settings={},
|
||||
)
|
||||
|
||||
assert provider.custom_headers == {"User-Agent": ASTRBOT_USER_AGENT}
|
||||
assert provider.client.kwargs["default_headers"] == {
|
||||
"User-Agent": ASTRBOT_USER_AGENT,
|
||||
}
|
||||
|
||||
|
||||
def test_anthropic_provider_passes_custom_headers_via_default_headers(monkeypatch):
|
||||
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)
|
||||
|
||||
|
||||
@@ -1,10 +1,48 @@
|
||||
import pytest
|
||||
|
||||
from astrbot.core.config.default import ASTRBOT_USER_AGENT
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
import astrbot.core.provider.sources.gemini_source as gemini_source
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
|
||||
|
||||
|
||||
class _FakeGenAIClient:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self._api_client = type(
|
||||
"FakeAPIClient",
|
||||
(),
|
||||
{"_http_options": kwargs["http_options"]},
|
||||
)()
|
||||
self.aio = type("FakeAioClient", (), {"_api_client": self._api_client})()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_provider_uses_astrbot_default_user_agent(monkeypatch):
|
||||
monkeypatch.setattr(gemini_source.genai, "Client", _FakeGenAIClient)
|
||||
|
||||
provider = ProviderGoogleGenAI(
|
||||
provider_config={
|
||||
"id": "gemini-test",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"model": "gemini-test",
|
||||
"key": ["test-key"],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
},
|
||||
provider_settings={},
|
||||
)
|
||||
|
||||
try:
|
||||
assert provider.custom_headers["user-agent"] == ASTRBOT_USER_AGENT
|
||||
assert provider.client._api_client._http_options.headers["user-agent"] == (
|
||||
ASTRBOT_USER_AGENT
|
||||
)
|
||||
assert provider._http_client.headers["user-agent"] == ASTRBOT_USER_AGENT
|
||||
finally:
|
||||
await provider.terminate()
|
||||
|
||||
|
||||
def test_gemini_empty_output_raises_empty_model_output_error():
|
||||
llm_response = LLMResponse(role="assistant")
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from PIL import Image as PILImage
|
||||
|
||||
import astrbot.core.provider.sources.openai_source as openai_source_module
|
||||
from astrbot.core.config.default import ASTRBOT_USER_AGENT
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
from astrbot.core.provider.sources.groq_source import ProviderGroq
|
||||
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
|
||||
@@ -26,6 +27,17 @@ class _ErrorWithResponse(Exception):
|
||||
self.response = SimpleNamespace(text=response_text)
|
||||
|
||||
|
||||
class _FakeChatCompletions:
|
||||
def create(self):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeAsyncOpenAI:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.chat = SimpleNamespace(completions=_FakeChatCompletions())
|
||||
|
||||
|
||||
def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial:
|
||||
provider_config = {
|
||||
"id": "test-openai",
|
||||
@@ -56,6 +68,39 @@ def _make_groq_provider(overrides: dict | None = None) -> ProviderGroq:
|
||||
)
|
||||
|
||||
|
||||
def test_openai_provider_uses_astrbot_default_user_agent(monkeypatch):
|
||||
monkeypatch.setattr(openai_source_module, "AsyncOpenAI", _FakeAsyncOpenAI)
|
||||
|
||||
provider = _make_provider()
|
||||
|
||||
assert provider.custom_headers == {"User-Agent": ASTRBOT_USER_AGENT}
|
||||
assert provider.client.kwargs["default_headers"] == {
|
||||
"User-Agent": ASTRBOT_USER_AGENT,
|
||||
}
|
||||
|
||||
|
||||
def test_openai_provider_preserves_custom_user_agent(monkeypatch):
|
||||
monkeypatch.setattr(openai_source_module, "AsyncOpenAI", _FakeAsyncOpenAI)
|
||||
|
||||
provider = _make_provider(
|
||||
{
|
||||
"custom_headers": {
|
||||
"User-Agent": "custom-agent/1.0",
|
||||
"X-Test-Header": 123,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert provider.custom_headers == {
|
||||
"User-Agent": "custom-agent/1.0",
|
||||
"X-Test-Header": "123",
|
||||
}
|
||||
assert provider.client.kwargs["default_headers"] == {
|
||||
"User-Agent": "custom-agent/1.0",
|
||||
"X-Test-Header": "123",
|
||||
}
|
||||
|
||||
|
||||
def test_create_http_client_uses_openai_httpx_module(monkeypatch):
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user