mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-03 11:10:14 +08:00
Compare commits
1 Commit
codex/fix-
...
feat--llm-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af204f4273 |
@@ -24,6 +24,7 @@ from astrbot.core.utils.network_utils import (
|
||||
)
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .request_retry import retry_provider_request, retry_provider_request_context
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -366,8 +367,11 @@ class ProviderAnthropic(Provider):
|
||||
self._apply_thinking_config(payloads)
|
||||
|
||||
try:
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
completion = await retry_provider_request(
|
||||
"Anthropic",
|
||||
lambda: self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
),
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
@@ -459,8 +463,9 @@ class ProviderAnthropic(Provider):
|
||||
payloads["max_tokens"] = 65536
|
||||
self._apply_thinking_config(payloads)
|
||||
|
||||
async with self.client.messages.stream(
|
||||
**payloads, extra_body=extra_body
|
||||
async with retry_provider_request_context(
|
||||
"Anthropic",
|
||||
lambda: self.client.messages.stream(**payloads, extra_body=extra_body),
|
||||
) as stream:
|
||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||
async for event in stream:
|
||||
@@ -838,7 +843,10 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
async def get_models(self) -> list[str]:
|
||||
models_str = []
|
||||
models = await self.client.models.list()
|
||||
models = await retry_provider_request(
|
||||
"Anthropic",
|
||||
lambda: self.client.models.list(),
|
||||
)
|
||||
models = sorted(models.data, key=lambda x: x.id)
|
||||
for model in models:
|
||||
models_str.append(model.id)
|
||||
|
||||
@@ -28,6 +28,7 @@ from astrbot.core.utils.media_utils import ensure_wav
|
||||
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .request_retry import retry_provider_request
|
||||
|
||||
|
||||
class SuppressNonTextPartsWarning(logging.Filter):
|
||||
@@ -630,10 +631,14 @@ class ProviderGoogleGenAI(Provider):
|
||||
modalities,
|
||||
temperature,
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
result = await retry_provider_request(
|
||||
"Gemini",
|
||||
lambda: self.client.models.generate_content(
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
),
|
||||
retry_rate_limits=False,
|
||||
)
|
||||
logger.debug(f"genai result: {result}")
|
||||
|
||||
@@ -710,10 +715,14 @@ class ProviderGoogleGenAI(Provider):
|
||||
payloads.get("tool_choice", "auto"),
|
||||
system_instruction,
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
result = await retry_provider_request(
|
||||
"Gemini",
|
||||
lambda: self.client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
),
|
||||
retry_rate_limits=False,
|
||||
)
|
||||
break
|
||||
except APIError as e:
|
||||
@@ -940,7 +949,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
models = await self.client.models.list()
|
||||
models = await retry_provider_request(
|
||||
"Gemini",
|
||||
lambda: self.client.models.list(),
|
||||
)
|
||||
return [
|
||||
m.name.replace("models/", "")
|
||||
for m in models
|
||||
|
||||
@@ -48,6 +48,7 @@ from astrbot.core.utils.network_utils import (
|
||||
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from .request_retry import retry_provider_request
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -560,7 +561,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def get_models(self):
|
||||
try:
|
||||
models_str = []
|
||||
models = await self.client.models.list()
|
||||
models = await retry_provider_request(
|
||||
"OpenAI",
|
||||
lambda: self.client.models.list(),
|
||||
)
|
||||
models = sorted(models.data, key=lambda x: x.id)
|
||||
for model in models:
|
||||
models_str.append(model.id)
|
||||
@@ -636,10 +640,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
self._sanitize_assistant_messages(payloads)
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False,
|
||||
extra_body=extra_body,
|
||||
completion = await retry_provider_request(
|
||||
"OpenAI",
|
||||
lambda: self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False,
|
||||
extra_body=extra_body,
|
||||
),
|
||||
retry_rate_limits=False,
|
||||
)
|
||||
|
||||
if not isinstance(completion, ChatCompletion):
|
||||
@@ -688,11 +696,15 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
self._sanitize_assistant_messages(payloads)
|
||||
|
||||
stream = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=True,
|
||||
extra_body=extra_body,
|
||||
stream_options={"include_usage": True},
|
||||
stream = await retry_provider_request(
|
||||
"OpenAI",
|
||||
lambda: self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=True,
|
||||
extra_body=extra_body,
|
||||
stream_options={"include_usage": True},
|
||||
),
|
||||
retry_rate_limits=False,
|
||||
)
|
||||
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
141
astrbot/core/provider/sources/request_retry.py
Normal file
141
astrbot/core/provider/sources/request_retry.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from typing import TypeVar
|
||||
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
RetryCallState,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.network_utils import is_connection_error
|
||||
|
||||
# Generic result type used in the signatures of the retry helpers below.
T = TypeVar("T")
|
||||
|
||||
# Attempt limit handed to stop_after_attempt(); also printed as the
# denominator in the retry warning emitted by _log_retry().
REQUEST_RETRY_ATTEMPTS = 5
|
||||
# Lower bound handed to wait_exponential(min=...).
# NOTE(review): the _S suffix suggests seconds — confirm against tenacity.
REQUEST_RETRY_WAIT_MIN_S = 1
|
||||
# Upper bound handed to wait_exponential(max=...).
REQUEST_RETRY_WAIT_MAX_S = 8
|
||||
# Status codes that _is_retryable_provider_request_error() treats as
# retryable; 429 is additionally gated by its retry_rate_limits flag.
REQUEST_RETRY_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504, 529}
|
||||
|
||||
|
||||
def _get_status_code(error: BaseException) -> int | None:
|
||||
for attr in ("status_code", "status", "code"):
|
||||
value = getattr(error, attr, None)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
|
||||
response = getattr(error, "response", None)
|
||||
if response is not None:
|
||||
status_code = getattr(response, "status_code", None)
|
||||
if isinstance(status_code, int):
|
||||
return status_code
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_retryable_provider_request_error(
    error: BaseException,
    *,
    retry_rate_limits: bool,
) -> bool:
    """Decide whether a failed provider request is worth another attempt.

    An error is retryable when any of the following holds:

    * ``is_connection_error(error)`` reports it as a connection error;
    * its class is named ``APIConnectionError`` or ``APITimeoutError``
      (matched by class name only — presumably so no provider SDK has to be
      imported here; confirm);
    * it carries a status code that is in ``REQUEST_RETRY_STATUS_CODES`` or
      in the 500-599 range.

    Args:
        error: The exception raised by the request.
        retry_rate_limits: When false, status 429 is never retried.

    Returns:
        ``True`` if the request should be retried, ``False`` otherwise.
    """
    # Transport-level failures are retried regardless of any status code.
    if is_connection_error(error) or type(error).__name__ in (
        "APIConnectionError",
        "APITimeoutError",
    ):
        return True

    status = _get_status_code(error)
    if status is None:
        return False
    if not retry_rate_limits and status == 429:
        # The caller opted out of retrying rate-limit responses.
        return False

    return 500 <= status <= 599 or status in REQUEST_RETRY_STATUS_CODES
|
||||
|
||||
|
||||
def _log_retry(provider_label: str, retry_state: RetryCallState) -> None:
    """Emit a warning describing the failure that triggered a retry.

    Args:
        provider_label: Name shown in square brackets at the start of the
            log line.
        retry_state: The tenacity call state; its outcome's exception (when
            an outcome is present) is included in the message.
    """
    outcome = retry_state.outcome
    error = outcome.exception() if outcome else None
    # attempt_number is the attempt that just failed, so the upcoming one is +1.
    upcoming_attempt = retry_state.attempt_number + 1
    message = (
        f"[{provider_label}] Request failed with retryable error; "
        f"retrying ({upcoming_attempt}/{REQUEST_RETRY_ATTEMPTS}): "
        f"{error}"
    )
    logger.warning(message)
|
||||
|
||||
|
||||
def _build_retrying(
    provider_label: str,
    *,
    retry_rate_limits: bool,
) -> AsyncRetrying:
    """Create the ``AsyncRetrying`` controller used for provider requests.

    The module-level ``REQUEST_RETRY_*`` constants are read each time this
    function is called.

    Args:
        provider_label: Name forwarded to ``_log_retry`` for log messages.
        retry_rate_limits: Forwarded to
            ``_is_retryable_provider_request_error``.

    Returns:
        An ``AsyncRetrying`` configured with the retry predicate, the attempt
        limit, exponential waiting, a before-sleep logging hook and
        ``reraise=True``.
    """

    def _should_retry(error: BaseException) -> bool:
        return _is_retryable_provider_request_error(
            error,
            retry_rate_limits=retry_rate_limits,
        )

    def _announce_retry(retry_state: RetryCallState) -> None:
        _log_retry(provider_label, retry_state)

    retry_policy = retry_if_exception(_should_retry)
    stop_policy = stop_after_attempt(REQUEST_RETRY_ATTEMPTS)
    wait_policy = wait_exponential(
        multiplier=1,
        min=REQUEST_RETRY_WAIT_MIN_S,
        max=REQUEST_RETRY_WAIT_MAX_S,
    )
    return AsyncRetrying(
        retry=retry_policy,
        stop=stop_policy,
        wait=wait_policy,
        before_sleep=_announce_retry,
        reraise=True,
    )
|
||||
|
||||
|
||||
async def retry_provider_request(
    provider_label: str,
    request_factory: Callable[[], Awaitable[T]],
    *,
    retry_rate_limits: bool = True,
) -> T:
    """Await a provider request, retrying it on retryable failures.

    Args:
        provider_label: Provider name used in retry log messages.
        request_factory: Zero-argument callable that returns a fresh
            awaitable for every attempt.
        retry_rate_limits: Whether status 429 counts as retryable.

    Returns:
        The value produced by the first attempt that succeeds.

    Raises:
        RuntimeError: If the retry loop finishes without returning or
            raising, which is not expected to happen.
    """
    controller = _build_retrying(
        provider_label,
        retry_rate_limits=retry_rate_limits,
    )

    async for attempt in controller:
        # The attempt context records the outcome so the controller can
        # decide whether another iteration is needed.
        with attempt:
            outcome = await request_factory()
            return outcome

    raise RuntimeError("Provider request retry loop exited unexpectedly.")
|
||||
|
||||
|
||||
@asynccontextmanager
async def retry_provider_request_context(
    provider_label: str,
    context_manager_factory: Callable[[], AbstractAsyncContextManager[T]],
    *,
    retry_rate_limits: bool = True,
) -> AsyncIterator[T]:
    """Enter an async context manager with retries, then expose its value.

    Only the *entering* step is retried: each attempt builds a new manager
    via ``context_manager_factory`` and awaits its ``__aenter__``. Exceptions
    raised later, inside the caller's ``async with`` body, are handed to the
    entered manager's ``__aexit__`` and are not retried.

    Args:
        provider_label: Provider name used in retry log messages.
        context_manager_factory: Zero-argument callable returning a fresh
            async context manager for every attempt.
        retry_rate_limits: Whether status 429 counts as retryable.

    Yields:
        The value returned by the successfully entered manager.

    Raises:
        RuntimeError: If no manager was created even though entering
            reported success.
    """
    active_manager: AbstractAsyncContextManager[T] | None = None

    async def _open_fresh_context() -> T:
        # Remember the manager of the current attempt so that the one which
        # finally succeeds can be exited below.
        nonlocal active_manager
        active_manager = context_manager_factory()
        return await active_manager.__aenter__()

    entered_value = await retry_provider_request(
        provider_label,
        _open_fresh_context,
        retry_rate_limits=retry_rate_limits,
    )

    if active_manager is None:
        raise RuntimeError("Provider request context was not created.")

    try:
        yield entered_value
    except BaseException as error:
        # Forward the failure to the manager; re-raise unless it asks for
        # the exception to be suppressed.
        suppressed = await active_manager.__aexit__(
            type(error), error, error.__traceback__
        )
        if not suppressed:
            raise
    else:
        await active_manager.__aexit__(None, None, None)
|
||||
@@ -1,9 +1,12 @@
|
||||
import builtins
|
||||
from types import SimpleNamespace
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
import astrbot.core.provider.sources.anthropic_source as anthropic_source
|
||||
import astrbot.core.provider.sources.kimi_code_source as kimi_code_source
|
||||
import astrbot.core.provider.sources.request_retry as request_retry
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
@@ -171,6 +174,36 @@ def test_create_http_client_falls_back_to_global_httpx_module(monkeypatch):
|
||||
assert captured["httpx_module"] is anthropic_source.httpx
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_anthropic_get_models_retries_transient_request_error(monkeypatch):
    """get_models() survives one httpx.ConnectError and returns sorted ids."""
    # Patch both backoff bounds to 0 (presumably so the retry does not sleep).
    for bound_name in ("REQUEST_RETRY_WAIT_MIN_S", "REQUEST_RETRY_WAIT_MAX_S"):
        monkeypatch.setattr(request_retry, bound_name, 0)

    class FlakyModels:
        """Fake models endpoint that fails on its first call only."""

        calls = 0

        async def list(self):
            self.calls += 1
            if self.calls < 2:
                raise httpx.ConnectError("temporary connection failure")
            # Deliberately unsorted so the test also checks ordering.
            unsorted_ids = ("claude-b", "claude-a")
            return SimpleNamespace(
                data=[SimpleNamespace(id=model_id) for model_id in unsorted_ids]
            )

    flaky_models = FlakyModels()
    provider_cls = anthropic_source.ProviderAnthropic
    # Bypass __init__ so only the fake client is needed.
    provider = provider_cls.__new__(provider_cls)
    provider.client = SimpleNamespace(models=flaky_models)

    listed = await provider.get_models()

    assert listed == ["claude-a", "claude-b"]
    assert flaky_models.calls == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_wraps_string_system_prompt_as_list(monkeypatch):
|
||||
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
import astrbot.core.provider.sources.request_retry as request_retry
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
|
||||
|
||||
@@ -27,3 +31,35 @@ def test_gemini_reasoning_only_output_is_allowed():
|
||||
response_id="resp_reasoning",
|
||||
finish_reason="STOP",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini_get_models_retries_transient_request_error(monkeypatch):
    """get_models() survives one httpx.ConnectError and returns only gemini-a."""
    # Patch both backoff bounds to 0 (presumably so the retry does not sleep).
    for bound_name in ("REQUEST_RETRY_WAIT_MIN_S", "REQUEST_RETRY_WAIT_MAX_S"):
        monkeypatch.setattr(request_retry, bound_name, 0)

    class FlakyModels:
        """Fake models endpoint that fails on its first call only."""

        calls = 0

        async def list(self):
            self.calls += 1
            if self.calls < 2:
                raise httpx.ConnectError("temporary connection failure")
            # Model name -> supported actions, in the order they are returned.
            catalog = {
                "models/gemini-a": ["generateContent"],
                "models/gemini-b": ["embedContent"],
            }
            return [
                SimpleNamespace(name=name, supported_actions=actions)
                for name, actions in catalog.items()
            ]

    flaky_models = FlakyModels()
    # Bypass __init__ so only the fake client is needed.
    provider = ProviderGoogleGenAI.__new__(ProviderGoogleGenAI)
    provider.client = SimpleNamespace(models=flaky_models)

    listed = await provider.get_models()

    assert listed == ["gemini-a"]
    assert flaky_models.calls == 2
|
||||
|
||||
@@ -3,12 +3,14 @@ import builtins
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from PIL import Image as PILImage
|
||||
|
||||
import astrbot.core.provider.sources.openai_source as openai_source_module
|
||||
import astrbot.core.provider.sources.request_retry as request_retry
|
||||
from astrbot.core.exceptions import EmptyModelOutputError
|
||||
from astrbot.core.provider.sources.groq_source import ProviderGroq
|
||||
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
|
||||
@@ -116,6 +118,34 @@ def test_create_http_client_falls_back_to_global_httpx_module(monkeypatch):
|
||||
assert captured["httpx_module"] is openai_source_module.httpx
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_models_retries_transient_request_error(monkeypatch):
    """get_models() survives one httpx.ConnectError and returns sorted ids."""
    # Patch both backoff bounds to 0 (presumably so the retry does not sleep).
    for bound_name in ("REQUEST_RETRY_WAIT_MIN_S", "REQUEST_RETRY_WAIT_MAX_S"):
        monkeypatch.setattr(request_retry, bound_name, 0)

    class FlakyModels:
        """Fake models endpoint that fails on its first call only."""

        calls = 0

        async def list(self):
            self.calls += 1
            if self.calls < 2:
                raise httpx.ConnectError("temporary connection failure")
            # Deliberately unsorted so the test also checks ordering.
            unsorted_ids = ("gpt-b", "gpt-a")
            return SimpleNamespace(
                data=[SimpleNamespace(id=model_id) for model_id in unsorted_ids]
            )

    flaky_models = FlakyModels()
    # Bypass __init__ so only the fake client is needed.
    provider = ProviderOpenAIOfficial.__new__(ProviderOpenAIOfficial)
    provider.client = SimpleNamespace(models=flaky_models)

    listed = await provider.get_models()

    assert listed == ["gpt-a", "gpt-b"]
    assert flaky_models.calls == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_api_error_content_moderated_removes_images():
|
||||
provider = _make_provider(
|
||||
|
||||
Reference in New Issue
Block a user