fix: reconnect MCP client on terminated session (#8694)

* fix: reconnect MCP client on terminated session

* Update astrbot/core/agent/mcp_client.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update mcp_client.py

---------

Co-authored-by: Weilong Liao <37870767+Soulter@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
EterUltimate
2026-06-24 22:43:24 +08:00
committed by GitHub
parent bc117038fb
commit 2bda4e4d96
2 changed files with 124 additions and 4 deletions

View File

@@ -13,7 +13,7 @@ import httpx
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
@@ -93,6 +93,10 @@ _DENIED_DOCKER_ARGS = frozenset(
}
)
_STDIO_ALLOWLIST_ENV = "ASTRBOT_MCP_STDIO_ALLOWED_COMMANDS"
_MCP_RECONNECT_ERROR_MESSAGES = (
"session terminated",
"session was terminated",
)
try:
import anyio
@@ -121,6 +125,13 @@ except (ModuleNotFoundError, ImportError):
)
def _is_mcp_reconnect_error(exc: BaseException) -> bool:
if "anyio" in globals() and isinstance(exc, anyio.ClosedResourceError):
return True
message = str(exc).lower()
return any(marker in message for marker in _MCP_RECONNECT_ERROR_MESSAGES)
def _prepare_config(config: dict) -> dict:
"""Prepare configuration, handle nested format"""
if config.get("mcpServers"):
@@ -635,7 +646,7 @@ class MCPClient:
"""
@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
retry=retry_if_exception(_is_mcp_reconnect_error),
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=1, max=3),
before_sleep=before_sleep_log(logger, logging.WARNING),
@@ -651,9 +662,15 @@ class MCPClient:
arguments=arguments,
read_timeout_seconds=read_timeout_seconds,
)
except anyio.ClosedResourceError:
except Exception as exc:
if not _is_mcp_reconnect_error(exc):
raise
logger.warning(
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
"MCP tool %s call failed (%s: %s), attempting to reconnect...",
tool_name,
type(exc).__name__,
exc,
)
# Attempt to reconnect
await self._reconnect()

View File

@@ -0,0 +1,103 @@
from datetime import timedelta
import anyio
import pytest
from tenacity import wait_none
from astrbot.core.agent import mcp_client
class FlakyMcpSession:
def __init__(self, first_error: Exception | None = None) -> None:
self.calls = 0
self.first_error = first_error or RuntimeError("Session terminated")
async def call_tool(
self,
*,
name: str,
arguments: dict,
read_timeout_seconds: timedelta,
) -> dict[str, object]:
self.calls += 1
if self.calls == 1:
raise self.first_error
return {
"name": name,
"arguments": arguments,
"timeout": read_timeout_seconds.total_seconds(),
}
@pytest.mark.parametrize(
("error", "expected"),
[
(RuntimeError("Session terminated"), True),
(RuntimeError("SESSION TERMINATED"), True),
(RuntimeError("session was terminated"), True),
(anyio.ClosedResourceError(), True),
(RuntimeError("business flow terminated normally"), False),
(RuntimeError("terminated"), False),
],
)
def test_mcp_reconnect_error_detection_is_narrow(
error: BaseException, expected: bool
) -> None:
assert mcp_client._is_mcp_reconnect_error(error) is expected
@pytest.mark.asyncio
async def test_call_tool_reconnects_on_session_terminated(monkeypatch) -> None:
monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none())
client = mcp_client.MCPClient()
session = FlakyMcpSession()
reconnects = 0
async def reconnect() -> None:
nonlocal reconnects
reconnects += 1
client.session = session
client.session = session
client._reconnect = reconnect
result = await client.call_tool_with_reconnect(
tool_name="lookup",
arguments={"url": "https://example.com"},
read_timeout_seconds=timedelta(seconds=5),
)
assert result == {
"name": "lookup",
"arguments": {"url": "https://example.com"},
"timeout": 5.0,
}
assert session.calls == 2
assert reconnects == 1
@pytest.mark.asyncio
async def test_call_tool_does_not_reconnect_on_business_error(monkeypatch) -> None:
monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none())
client = mcp_client.MCPClient()
session = FlakyMcpSession(first_error=ValueError("business logic failed"))
reconnects = 0
async def reconnect() -> None:
nonlocal reconnects
reconnects += 1
client.session = session
client._reconnect = reconnect
with pytest.raises(ValueError, match="business logic failed"):
await client.call_tool_with_reconnect(
tool_name="lookup",
arguments={"url": "https://example.com"},
read_timeout_seconds=timedelta(seconds=5),
)
assert session.calls == 1
assert reconnects == 0