Compare commits

...

1 Commits

Author SHA1 Message Date
Soulter
af204f4273 feat: implement request retry mechanism for provider requests 2026-06-09 14:27:46 +08:00
7 changed files with 296 additions and 24 deletions

View File

@@ -24,6 +24,7 @@ from astrbot.core.utils.network_utils import (
)
from ..register import register_provider_adapter
from .request_retry import retry_provider_request, retry_provider_request_context
@register_provider_adapter(
@@ -366,8 +367,11 @@ class ProviderAnthropic(Provider):
self._apply_thinking_config(payloads)
try:
completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
completion = await retry_provider_request(
"Anthropic",
lambda: self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
),
)
except httpx.RequestError as e:
proxy = self.provider_config.get("proxy", "")
@@ -459,8 +463,9 @@ class ProviderAnthropic(Provider):
payloads["max_tokens"] = 65536
self._apply_thinking_config(payloads)
async with self.client.messages.stream(
**payloads, extra_body=extra_body
async with retry_provider_request_context(
"Anthropic",
lambda: self.client.messages.stream(**payloads, extra_body=extra_body),
) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
@@ -838,7 +843,10 @@ class ProviderAnthropic(Provider):
async def get_models(self) -> list[str]:
models_str = []
models = await self.client.models.list()
models = await retry_provider_request(
"Anthropic",
lambda: self.client.models.list(),
)
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)

View File

@@ -28,6 +28,7 @@ from astrbot.core.utils.media_utils import ensure_wav
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
from ..register import register_provider_adapter
from .request_retry import retry_provider_request
class SuppressNonTextPartsWarning(logging.Filter):
@@ -630,10 +631,14 @@ class ProviderGoogleGenAI(Provider):
modalities,
temperature,
)
result = await self.client.models.generate_content(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
result = await retry_provider_request(
"Gemini",
lambda: self.client.models.generate_content(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
),
retry_rate_limits=False,
)
logger.debug(f"genai result: {result}")
@@ -710,10 +715,14 @@ class ProviderGoogleGenAI(Provider):
payloads.get("tool_choice", "auto"),
system_instruction,
)
result = await self.client.models.generate_content_stream(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
result = await retry_provider_request(
"Gemini",
lambda: self.client.models.generate_content_stream(
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
),
retry_rate_limits=False,
)
break
except APIError as e:
@@ -940,7 +949,10 @@ class ProviderGoogleGenAI(Provider):
async def get_models(self):
try:
models = await self.client.models.list()
models = await retry_provider_request(
"Gemini",
lambda: self.client.models.list(),
)
return [
m.name.replace("models/", "")
for m in models

View File

@@ -48,6 +48,7 @@ from astrbot.core.utils.network_utils import (
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from ..register import register_provider_adapter
from .request_retry import retry_provider_request
@register_provider_adapter(
@@ -560,7 +561,10 @@ class ProviderOpenAIOfficial(Provider):
async def get_models(self):
try:
models_str = []
models = await self.client.models.list()
models = await retry_provider_request(
"OpenAI",
lambda: self.client.models.list(),
)
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
@@ -636,10 +640,14 @@ class ProviderOpenAIOfficial(Provider):
self._sanitize_assistant_messages(payloads)
completion = await self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
completion = await retry_provider_request(
"OpenAI",
lambda: self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
),
retry_rate_limits=False,
)
if not isinstance(completion, ChatCompletion):
@@ -688,11 +696,15 @@ class ProviderOpenAIOfficial(Provider):
self._sanitize_assistant_messages(payloads)
stream = await self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
stream = await retry_provider_request(
"OpenAI",
lambda: self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
stream_options={"include_usage": True},
),
retry_rate_limits=False,
)
llm_response = LLMResponse("assistant", is_chunk=True)

View File

@@ -0,0 +1,141 @@
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import TypeVar
from tenacity import (
AsyncRetrying,
RetryCallState,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
from astrbot import logger
from astrbot.core.utils.network_utils import is_connection_error
# Generic result type produced by a wrapped provider request.
T = TypeVar("T")
# Attempt limit handed to tenacity's stop_after_attempt and shown in retry logs.
REQUEST_RETRY_ATTEMPTS = 5
# Lower and upper bounds handed to tenacity's wait_exponential (seconds).
# These are read when the retry policy is built, so tests may patch them.
REQUEST_RETRY_WAIT_MIN_S = 1
REQUEST_RETRY_WAIT_MAX_S = 8
# HTTP status codes treated as transient. 429 is skipped when the caller
# passes retry_rate_limits=False; any other 5xx code is retried as well.
REQUEST_RETRY_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504, 529}
def _get_status_code(error: BaseException) -> int | None:
for attr in ("status_code", "status", "code"):
value = getattr(error, attr, None)
if isinstance(value, int):
return value
response = getattr(error, "response", None)
if response is not None:
status_code = getattr(response, "status_code", None)
if isinstance(status_code, int):
return status_code
return None
def _is_retryable_provider_request_error(
    error: BaseException,
    *,
    retry_rate_limits: bool,
) -> bool:
    """Decide whether a failed provider request should be attempted again.

    Args:
        error: The exception raised by the request.
        retry_rate_limits: When ``False``, a 429 status is not retried.

    Returns:
        ``True`` for connection errors, for exceptions whose class is named
        ``APIConnectionError`` or ``APITimeoutError``, and for errors carrying
        a status code listed in ``REQUEST_RETRY_STATUS_CODES`` or in the
        500-599 range; ``False`` otherwise.
    """
    # Matching by class name avoids importing each provider SDK's exceptions.
    transient_class_names = {"APIConnectionError", "APITimeoutError"}
    if is_connection_error(error) or type(error).__name__ in transient_class_names:
        return True

    status_code = _get_status_code(error)
    if status_code is None:
        return False

    rate_limit_retry_disabled = status_code == 429 and not retry_rate_limits
    if rate_limit_retry_disabled:
        return False

    if status_code in REQUEST_RETRY_STATUS_CODES:
        return True
    return 500 <= status_code <= 599
def _log_retry(provider_label: str, retry_state: RetryCallState) -> None:
    """Emit a warning describing the failure that triggered the next attempt.

    Args:
        provider_label: Name shown in brackets at the start of the log line.
        retry_state: Retry state whose outcome holds the failure to report.
    """
    if retry_state.outcome:
        error = retry_state.outcome.exception()
    else:
        error = None

    # attempt_number is the attempt that just failed; +1 is the upcoming one.
    message = (
        f"[{provider_label}] Request failed with retryable error; "
        f"retrying ({retry_state.attempt_number + 1}/{REQUEST_RETRY_ATTEMPTS}): "
        f"{error}"
    )
    logger.warning(message)
def _build_retrying(
    provider_label: str,
    *,
    retry_rate_limits: bool,
) -> AsyncRetrying:
    """Create the retry policy used for one provider request.

    Args:
        provider_label: Name used in the warning logged before each retry.
        retry_rate_limits: Forwarded to the retryable-error check; controls
            whether a 429 status is retried.

    Returns:
        An ``AsyncRetrying`` that retries retryable errors with exponential
        backoff, stops after ``REQUEST_RETRY_ATTEMPTS`` attempts, and
        re-raises the last error.
    """

    def _should_retry(error: BaseException) -> bool:
        return _is_retryable_provider_request_error(
            error,
            retry_rate_limits=retry_rate_limits,
        )

    def _announce_retry(retry_state: RetryCallState) -> None:
        _log_retry(provider_label, retry_state)

    # The module constants are read here, at call time, so patched values
    # take effect on the next request.
    backoff = wait_exponential(
        multiplier=1,
        min=REQUEST_RETRY_WAIT_MIN_S,
        max=REQUEST_RETRY_WAIT_MAX_S,
    )
    return AsyncRetrying(
        retry=retry_if_exception(_should_retry),
        stop=stop_after_attempt(REQUEST_RETRY_ATTEMPTS),
        wait=backoff,
        before_sleep=_announce_retry,
        reraise=True,
    )
async def retry_provider_request(
    provider_label: str,
    request_factory: Callable[[], Awaitable[T]],
    *,
    retry_rate_limits: bool = True,
) -> T:
    """Await a provider request, retrying it on retryable errors.

    Args:
        provider_label: Name used in retry log messages.
        request_factory: Zero-argument callable returning a fresh awaitable
            for each attempt.
        retry_rate_limits: Whether a 429 status should be retried.

    Returns:
        The result of the first attempt that succeeds.

    Raises:
        RuntimeError: If the retry loop ends without returning or raising.
    """
    policy = _build_retrying(provider_label, retry_rate_limits=retry_rate_limits)
    async for attempt in policy:
        with attempt:
            # The factory is called per attempt so every retry gets a new awaitable.
            outcome = await request_factory()
            return outcome
    raise RuntimeError("Provider request retry loop exited unexpectedly.")
@asynccontextmanager
async def retry_provider_request_context(
    provider_label: str,
    context_manager_factory: Callable[[], AbstractAsyncContextManager[T]],
    *,
    retry_rate_limits: bool = True,
) -> AsyncIterator[T]:
    """Enter an async context manager, retrying the entry on retryable errors.

    Only ``__aenter__`` is retried; each attempt uses a new manager from the
    factory. Once entered, the body runs once and the manager's ``__aexit__``
    is invoked with the body's outcome.

    Args:
        provider_label: Name used in retry log messages.
        context_manager_factory: Zero-argument callable returning a fresh
            async context manager for each attempt.
        retry_rate_limits: Whether a 429 status should be retried.

    Yields:
        The value returned by the manager's ``__aenter__``.

    Raises:
        RuntimeError: If no manager was created by the time entry succeeded.
    """
    manager: AbstractAsyncContextManager[T] | None = None

    async def _enter_context() -> T:
        nonlocal manager
        # Record the manager before entering so the one that finally
        # succeeds is the one closed below.
        manager = context_manager_factory()
        entered = await manager.__aenter__()
        return entered

    value = await retry_provider_request(
        provider_label,
        _enter_context,
        retry_rate_limits=retry_rate_limits,
    )
    if manager is None:
        raise RuntimeError("Provider request context was not created.")

    try:
        yield value
    except BaseException as error:
        suppressed = await manager.__aexit__(
            type(error), error, error.__traceback__
        )
        # A truthy __aexit__ result means the manager swallowed the error.
        if not suppressed:
            raise
    else:
        await manager.__aexit__(None, None, None)

View File

@@ -1,9 +1,12 @@
import builtins
from types import SimpleNamespace
import httpx
import pytest
import astrbot.core.provider.sources.anthropic_source as anthropic_source
import astrbot.core.provider.sources.kimi_code_source as kimi_code_source
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse
@@ -171,6 +174,36 @@ def test_create_http_client_falls_back_to_global_httpx_module(monkeypatch):
assert captured["httpx_module"] is anthropic_source.httpx
@pytest.mark.asyncio
async def test_anthropic_get_models_retries_transient_request_error(monkeypatch):
    """get_models succeeds after a single transient connection failure."""
    # Zero out the backoff bounds so the retry does not slow the test down.
    for wait_setting in ("REQUEST_RETRY_WAIT_MIN_S", "REQUEST_RETRY_WAIT_MAX_S"):
        monkeypatch.setattr(request_retry, wait_setting, 0)

    class FlakyModels:
        """Models endpoint stub that fails on the first call only."""

        def __init__(self):
            self.attempts = 0

        async def list(self):
            self.attempts += 1
            if self.attempts > 1:
                return SimpleNamespace(
                    data=[
                        SimpleNamespace(id="claude-b"),
                        SimpleNamespace(id="claude-a"),
                    ]
                )
            raise httpx.ConnectError("temporary connection failure")

    endpoint = FlakyModels()
    provider_cls = anthropic_source.ProviderAnthropic
    # __new__ bypasses __init__, so no real client is constructed.
    provider = provider_cls.__new__(provider_cls)
    provider.client = SimpleNamespace(models=endpoint)

    listed = await provider.get_models()

    assert listed == ["claude-a", "claude-b"]
    assert endpoint.attempts == 2
@pytest.mark.asyncio
async def test_text_chat_wraps_string_system_prompt_as_list(monkeypatch):
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)

View File

@@ -1,6 +1,10 @@
from types import SimpleNamespace
import httpx
import pytest
from astrbot.core.exceptions import EmptyModelOutputError
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
@@ -27,3 +31,35 @@ def test_gemini_reasoning_only_output_is_allowed():
response_id="resp_reasoning",
finish_reason="STOP",
)
@pytest.mark.asyncio
async def test_gemini_get_models_retries_transient_request_error(monkeypatch):
    """get_models succeeds after a single transient connection failure."""
    # Zero out the backoff bounds so the retry does not slow the test down.
    for wait_setting in ("REQUEST_RETRY_WAIT_MIN_S", "REQUEST_RETRY_WAIT_MAX_S"):
        monkeypatch.setattr(request_retry, wait_setting, 0)

    class FlakyModels:
        """Models endpoint stub that fails on the first call only."""

        def __init__(self):
            self.attempts = 0

        async def list(self):
            self.attempts += 1
            if self.attempts > 1:
                return [
                    SimpleNamespace(
                        name="models/gemini-a",
                        supported_actions=["generateContent"],
                    ),
                    SimpleNamespace(
                        name="models/gemini-b",
                        supported_actions=["embedContent"],
                    ),
                ]
            raise httpx.ConnectError("temporary connection failure")

    endpoint = FlakyModels()
    # __new__ bypasses __init__, so no real client is constructed.
    provider = ProviderGoogleGenAI.__new__(ProviderGoogleGenAI)
    provider.client = SimpleNamespace(models=endpoint)

    listed = await provider.get_models()

    assert listed == ["gemini-a"]
    assert endpoint.attempts == 2

View File

@@ -3,12 +3,14 @@ import builtins
from io import BytesIO
from types import SimpleNamespace
import httpx
import pytest
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from PIL import Image as PILImage
import astrbot.core.provider.sources.openai_source as openai_source_module
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.sources.groq_source import ProviderGroq
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
@@ -116,6 +118,34 @@ def test_create_http_client_falls_back_to_global_httpx_module(monkeypatch):
assert captured["httpx_module"] is openai_source_module.httpx
@pytest.mark.asyncio
async def test_get_models_retries_transient_request_error(monkeypatch):
    """get_models succeeds after a single transient connection failure."""
    # Zero out the backoff bounds so the retry does not slow the test down.
    for wait_setting in ("REQUEST_RETRY_WAIT_MIN_S", "REQUEST_RETRY_WAIT_MAX_S"):
        monkeypatch.setattr(request_retry, wait_setting, 0)

    class FlakyModels:
        """Models endpoint stub that fails on the first call only."""

        def __init__(self):
            self.attempts = 0

        async def list(self):
            self.attempts += 1
            if self.attempts > 1:
                return SimpleNamespace(
                    data=[
                        SimpleNamespace(id="gpt-b"),
                        SimpleNamespace(id="gpt-a"),
                    ]
                )
            raise httpx.ConnectError("temporary connection failure")

    endpoint = FlakyModels()
    # __new__ bypasses __init__, so no real client is constructed.
    provider = ProviderOpenAIOfficial.__new__(ProviderOpenAIOfficial)
    provider.client = SimpleNamespace(models=endpoint)

    listed = await provider.get_models()

    assert listed == ["gpt-a", "gpt-b"]
    assert endpoint.attempts == 2
@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_removes_images():
provider = _make_provider(