mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-03 03:00:15 +08:00
Compare commits
1 Commits
codex/fix-
...
revert-857
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8707aec19 |
@@ -5,8 +5,6 @@ import base64
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import replace
|
||||
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from astrbot.core import db_helper, logger
|
||||
from astrbot.core.agent.message import (
|
||||
CheckpointData,
|
||||
@@ -521,15 +519,6 @@ class InternalAgentSubStage(Stage):
|
||||
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}
|
||||
decoded_blocked = [base64.b64decode(b).decode("utf-8") for b in BLOCKED]
|
||||
|
||||
PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS = 3
|
||||
PROVIDER_STATS_SQLITE_LOCK_RETRY_BASE_DELAY = 0.2
|
||||
|
||||
|
||||
def _is_sqlite_database_locked_error(exc: OperationalError) -> bool:
|
||||
raw = getattr(exc, "orig", exc)
|
||||
message = str(raw).lower()
|
||||
return "database" in message and "locked" in message
|
||||
|
||||
|
||||
async def _record_internal_agent_stats(
|
||||
event: AstrMessageEvent,
|
||||
@@ -560,35 +549,15 @@ async def _record_internal_agent_stats(
|
||||
status = "error"
|
||||
else:
|
||||
status = "completed"
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
await db_helper.insert_provider_stat(
|
||||
umo=event.unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
provider_id=provider_config.get("id", "") or provider.meta().id,
|
||||
provider_model=provider.get_model(),
|
||||
status=status,
|
||||
stats=stats.to_dict(),
|
||||
agent_type="internal",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Persist provider stats failed: %s", e, exc_info=True)
|
||||
return
|
||||
|
||||
for attempt in range(PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS):
|
||||
last_attempt = attempt == PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS - 1
|
||||
try:
|
||||
await db_helper.insert_provider_stat(
|
||||
umo=event.unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
provider_id=provider_config.get("id", "") or provider.meta().id,
|
||||
provider_model=provider.get_model(),
|
||||
status=status,
|
||||
stats=stats.to_dict(),
|
||||
agent_type="internal",
|
||||
)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except OperationalError as e:
|
||||
if _is_sqlite_database_locked_error(e) and not last_attempt:
|
||||
await asyncio.sleep(
|
||||
PROVIDER_STATS_SQLITE_LOCK_RETRY_BASE_DELAY * (2**attempt)
|
||||
)
|
||||
continue
|
||||
logger.warning("Persist provider stats failed: %s", e, exc_info=True)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Persist provider stats failed: %s", e, exc_info=True)
|
||||
break
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlmodel import select
|
||||
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
@@ -65,143 +63,3 @@ async def test_record_internal_agent_stats_persists_provider_stat(
|
||||
assert record.start_time == 100.0
|
||||
assert record.end_time == 108.5
|
||||
assert record.time_to_first_token == 0.6
|
||||
|
||||
|
||||
def _provider_stats_recording_args():
|
||||
event = SimpleNamespace(unified_msg_origin="webchat:FriendMessage:session-42")
|
||||
req = ProviderRequest(conversation=SimpleNamespace(cid="conv-123"))
|
||||
provider = SimpleNamespace(
|
||||
provider_config={"id": "provider-1"},
|
||||
meta=lambda: SimpleNamespace(id="provider-1", type="openai"),
|
||||
get_model=lambda: "gpt-4.1",
|
||||
)
|
||||
agent_runner = SimpleNamespace(
|
||||
provider=provider,
|
||||
stats=AgentStats(),
|
||||
was_aborted=lambda: False,
|
||||
)
|
||||
return event, req, agent_runner, SimpleNamespace(role="assistant")
|
||||
|
||||
|
||||
def _provider_stats_operational_error(message: str) -> OperationalError:
|
||||
return OperationalError("insert into provider_stats", {}, Exception(message))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"lock_message",
|
||||
["database is locked", "database table is locked"],
|
||||
)
|
||||
async def test_record_internal_agent_stats_retries_transient_database_locks(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
lock_message: str,
|
||||
):
|
||||
attempts = 0
|
||||
|
||||
class LockedOnceDb:
|
||||
async def insert_provider_stat(self, **kwargs):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise _provider_stats_operational_error(lock_message)
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
monkeypatch.setattr(internal, "db_helper", LockedOnceDb())
|
||||
|
||||
async def no_sleep(delay: float) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(internal.asyncio, "sleep", no_sleep)
|
||||
|
||||
await internal._record_internal_agent_stats(
|
||||
*_provider_stats_recording_args(),
|
||||
)
|
||||
|
||||
assert attempts == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_internal_agent_stats_logs_after_exhausting_database_lock_retries(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
attempts = 0
|
||||
sleep_delays = []
|
||||
warnings = []
|
||||
|
||||
class AlwaysLockedDb:
|
||||
async def insert_provider_stat(self, **kwargs):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
raise _provider_stats_operational_error("database is locked")
|
||||
|
||||
monkeypatch.setattr(internal, "db_helper", AlwaysLockedDb())
|
||||
|
||||
async def record_sleep(delay: float) -> None:
|
||||
sleep_delays.append(delay)
|
||||
|
||||
monkeypatch.setattr(internal.asyncio, "sleep", record_sleep)
|
||||
monkeypatch.setattr(
|
||||
internal.logger,
|
||||
"warning",
|
||||
lambda *args, **kwargs: warnings.append((args, kwargs)),
|
||||
)
|
||||
|
||||
await internal._record_internal_agent_stats(*_provider_stats_recording_args())
|
||||
|
||||
assert attempts == internal.PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS
|
||||
base_delay = internal.PROVIDER_STATS_SQLITE_LOCK_RETRY_BASE_DELAY
|
||||
expected_sleep_delays = [
|
||||
base_delay * (2**attempt)
|
||||
for attempt in range(internal.PROVIDER_STATS_SQLITE_LOCK_RETRY_ATTEMPTS - 1)
|
||||
]
|
||||
assert sleep_delays == expected_sleep_delays
|
||||
assert len(warnings) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_internal_agent_stats_does_not_retry_other_operational_errors(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
attempts = 0
|
||||
warnings = []
|
||||
|
||||
class FailingDb:
|
||||
async def insert_provider_stat(self, **kwargs):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
raise _provider_stats_operational_error("no such table: provider_stats")
|
||||
|
||||
monkeypatch.setattr(internal, "db_helper", FailingDb())
|
||||
monkeypatch.setattr(
|
||||
internal.logger,
|
||||
"warning",
|
||||
lambda *args, **kwargs: warnings.append((args, kwargs)),
|
||||
)
|
||||
|
||||
await internal._record_internal_agent_stats(*_provider_stats_recording_args())
|
||||
|
||||
assert attempts == 1
|
||||
assert len(warnings) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_internal_agent_stats_propagates_cancelled_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
warnings = []
|
||||
|
||||
class CancellingDb:
|
||||
async def insert_provider_stat(self, **kwargs):
|
||||
raise asyncio.CancelledError
|
||||
|
||||
monkeypatch.setattr(internal, "db_helper", CancellingDb())
|
||||
monkeypatch.setattr(
|
||||
internal.logger,
|
||||
"warning",
|
||||
lambda *args, **kwargs: warnings.append((args, kwargs)),
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await internal._record_internal_agent_stats(*_provider_stats_recording_args())
|
||||
|
||||
assert warnings == []
|
||||
|
||||
Reference in New Issue
Block a user