mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 01:10:21 +08:00
fix: reconnect MCP client on terminated session (#8694)
* fix: reconnect MCP client on terminated session * Update astrbot/core/agent/mcp_client.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update mcp_client.py --------- Co-authored-by: Weilong Liao <37870767+Soulter@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,7 @@ import httpx
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
@@ -93,6 +93,10 @@ _DENIED_DOCKER_ARGS = frozenset(
|
||||
}
|
||||
)
|
||||
_STDIO_ALLOWLIST_ENV = "ASTRBOT_MCP_STDIO_ALLOWED_COMMANDS"
|
||||
_MCP_RECONNECT_ERROR_MESSAGES = (
|
||||
"session terminated",
|
||||
"session was terminated",
|
||||
)
|
||||
|
||||
try:
|
||||
import anyio
|
||||
@@ -121,6 +125,13 @@ except (ModuleNotFoundError, ImportError):
|
||||
)
|
||||
|
||||
|
||||
def _is_mcp_reconnect_error(exc: BaseException) -> bool:
|
||||
if "anyio" in globals() and isinstance(exc, anyio.ClosedResourceError):
|
||||
return True
|
||||
message = str(exc).lower()
|
||||
return any(marker in message for marker in _MCP_RECONNECT_ERROR_MESSAGES)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""Prepare configuration, handle nested format"""
|
||||
if config.get("mcpServers"):
|
||||
@@ -635,7 +646,7 @@ class MCPClient:
|
||||
"""
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(anyio.ClosedResourceError),
|
||||
retry=retry_if_exception(_is_mcp_reconnect_error),
|
||||
stop=stop_after_attempt(2),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=3),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
@@ -651,9 +662,15 @@ class MCPClient:
|
||||
arguments=arguments,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
except Exception as exc:
|
||||
if not _is_mcp_reconnect_error(exc):
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
|
||||
"MCP tool %s call failed (%s: %s), attempting to reconnect...",
|
||||
tool_name,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
# Attempt to reconnect
|
||||
await self._reconnect()
|
||||
|
||||
103
tests/unit/test_mcp_client_reconnect.py
Normal file
103
tests/unit/test_mcp_client_reconnect.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from datetime import timedelta
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from tenacity import wait_none
|
||||
|
||||
from astrbot.core.agent import mcp_client
|
||||
|
||||
|
||||
class FlakyMcpSession:
|
||||
def __init__(self, first_error: Exception | None = None) -> None:
|
||||
self.calls = 0
|
||||
self.first_error = first_error or RuntimeError("Session terminated")
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
arguments: dict,
|
||||
read_timeout_seconds: timedelta,
|
||||
) -> dict[str, object]:
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
raise self.first_error
|
||||
return {
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
"timeout": read_timeout_seconds.total_seconds(),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("error", "expected"),
|
||||
[
|
||||
(RuntimeError("Session terminated"), True),
|
||||
(RuntimeError("SESSION TERMINATED"), True),
|
||||
(RuntimeError("session was terminated"), True),
|
||||
(anyio.ClosedResourceError(), True),
|
||||
(RuntimeError("business flow terminated normally"), False),
|
||||
(RuntimeError("terminated"), False),
|
||||
],
|
||||
)
|
||||
def test_mcp_reconnect_error_detection_is_narrow(
|
||||
error: BaseException, expected: bool
|
||||
) -> None:
|
||||
assert mcp_client._is_mcp_reconnect_error(error) is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_reconnects_on_session_terminated(monkeypatch) -> None:
|
||||
monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none())
|
||||
|
||||
client = mcp_client.MCPClient()
|
||||
session = FlakyMcpSession()
|
||||
reconnects = 0
|
||||
|
||||
async def reconnect() -> None:
|
||||
nonlocal reconnects
|
||||
reconnects += 1
|
||||
client.session = session
|
||||
|
||||
client.session = session
|
||||
client._reconnect = reconnect
|
||||
|
||||
result = await client.call_tool_with_reconnect(
|
||||
tool_name="lookup",
|
||||
arguments={"url": "https://example.com"},
|
||||
read_timeout_seconds=timedelta(seconds=5),
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"name": "lookup",
|
||||
"arguments": {"url": "https://example.com"},
|
||||
"timeout": 5.0,
|
||||
}
|
||||
assert session.calls == 2
|
||||
assert reconnects == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_does_not_reconnect_on_business_error(monkeypatch) -> None:
|
||||
monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none())
|
||||
|
||||
client = mcp_client.MCPClient()
|
||||
session = FlakyMcpSession(first_error=ValueError("business logic failed"))
|
||||
reconnects = 0
|
||||
|
||||
async def reconnect() -> None:
|
||||
nonlocal reconnects
|
||||
reconnects += 1
|
||||
|
||||
client.session = session
|
||||
client._reconnect = reconnect
|
||||
|
||||
with pytest.raises(ValueError, match="business logic failed"):
|
||||
await client.call_tool_with_reconnect(
|
||||
tool_name="lookup",
|
||||
arguments={"url": "https://example.com"},
|
||||
read_timeout_seconds=timedelta(seconds=5),
|
||||
)
|
||||
|
||||
assert session.calls == 1
|
||||
assert reconnects == 0
|
||||
Reference in New Issue
Block a user