Files
AstrBot/tests/test_gemini_source.py
Weilong Liao dd36979eca feat: implement request retry mechanism for provider requests (#8893)
* feat: implement request retry mechanism for provider requests

* feat: add request max retries configuration and implement retry logic for provider requests

* feat: update fake_query function to accept request_max_retries parameter

* feat: remove retry_rate_limits from provider request calls
2026-06-19 17:13:40 +08:00

66 lines
2.0 KiB
Python

from types import SimpleNamespace
import httpx
import pytest
from astrbot.core.exceptions import EmptyModelOutputError
import astrbot.core.provider.sources.request_retry as request_retry
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI
def test_gemini_empty_output_raises_empty_model_output_error():
    """A response carrying only a role must be rejected as empty output."""
    empty_response = LLMResponse(role="assistant")

    with pytest.raises(EmptyModelOutputError):
        ProviderGoogleGenAI._ensure_usable_response(
            empty_response,
            response_id="resp_empty",
            finish_reason="STOP",
        )
def test_gemini_reasoning_only_output_is_allowed():
    """A response with reasoning content but no other output must pass.

    The test succeeds when ``_ensure_usable_response`` returns without
    raising; no return value is inspected.
    """
    reasoning_only = LLMResponse(
        role="assistant",
        reasoning_content="chain of thought placeholder",
    )

    ProviderGoogleGenAI._ensure_usable_response(
        reasoning_only,
        response_id="resp_reasoning",
        finish_reason="STOP",
    )
@pytest.mark.asyncio
async def test_gemini_get_models_retries_transient_request_error(monkeypatch):
    """``get_models`` should succeed after one transient connection error.

    The fake ``models`` client raises ``httpx.ConnectError`` on its first
    ``list()`` call and returns two model entries on the second. The test
    expects exactly two calls and a result containing only ``"gemini-a"``.
    """
    # Zero out both retry wait bounds (presumably so the retry does not
    # sleep during the test -- confirm against request_retry).
    for wait_bound in ("REQUEST_RETRY_WAIT_MIN_S", "REQUEST_RETRY_WAIT_MAX_S"):
        monkeypatch.setattr(request_retry, wait_bound, 0)

    class FakeModels:
        """Stand-in for ``client.models`` that fails once, then succeeds."""

        def __init__(self):
            self.calls = 0

        async def list(self):
            self.calls += 1
            is_first_attempt = self.calls == 1
            if is_first_attempt:
                raise httpx.ConnectError("temporary connection failure")

            generation_model = SimpleNamespace(
                name="models/gemini-a",
                supported_actions=["generateContent"],
            )
            embedding_model = SimpleNamespace(
                name="models/gemini-b",
                supported_actions=["embedContent"],
            )
            return [generation_model, embedding_model]

    fake_models = FakeModels()
    # Bypass __init__ and attach only the client attribute set up here.
    gemini = ProviderGoogleGenAI.__new__(ProviderGoogleGenAI)
    gemini.client = SimpleNamespace(models=fake_models)

    model_names = await gemini.get_models()

    assert model_names == ["gemini-a"]
    assert fake_models.calls == 2