mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-02 18:50:15 +08:00
merge: pull latest master into dev
Resolved conflicts: - openai_source.py: keep dev version with abort_signal filtering - customizer.ts: keep dev version with viewMode functionality - useSessions.ts: keep dev version with pendingSessionId handling - platformUtils.js: keep dev version with correct tutorial links - AddNewPlatform.vue: keep dev version with correct docs link - FullLayout.vue: keep dev version with viewMode-based logic - VerticalHeader.vue: keep dev version with viewMode-based logic
This commit is contained in:
57
astrbot/_internal/abc/abp/base_astrbot_abp_client.py
Normal file
57
astrbot/_internal/abc/abp/base_astrbot_abp_client.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
ABP (AstrBot Protocol) client - in-process star communication.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAstrbotAbpClient(ABC):
|
||||
"""
|
||||
ABP client: in-process star (plugin) communication.
|
||||
|
||||
Stars register themselves; client delegates calls to registered instances.
|
||||
|
||||
Subclass must implement:
|
||||
- connect() -> None
|
||||
- register_star(name, instance) -> None
|
||||
- unregister_star(name) -> None
|
||||
- call_star_tool(star, tool, args) -> Any
|
||||
- shutdown() -> None
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connected(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Lightweight: just sets connected=True."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def register_star(self, star_name: str, star_instance: Any) -> None:
|
||||
"""Add star to internal registry."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def unregister_star(self, star_name: str) -> None:
|
||||
"""Remove star from registry (idempotent)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_star_tool(
|
||||
self,
|
||||
star_name: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Delegate to star_instance.call_tool(tool_name, arguments)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Set connected=False, cancel pending requests."""
|
||||
...
|
||||
66
astrbot/_internal/abc/acp/base_astrbot_acp_client.py
Normal file
66
astrbot/_internal/abc/acp/base_astrbot_acp_client.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
ACP (AstrBot Communication Protocol) client.
|
||||
|
||||
Transport: TCP | Unix Socket
|
||||
Messages: JSON with Content-Length header
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAstrbotAcpClient(ABC):
|
||||
"""
|
||||
ACP client: connects to ACP servers via TCP or Unix socket.
|
||||
|
||||
Subclass must implement:
|
||||
- connect() -> None
|
||||
- connect_to_server(host, port) -> None
|
||||
- connect_to_unix_socket(path) -> None
|
||||
- call_tool(server, tool, args) -> Any
|
||||
- send_notification(method, params) -> None
|
||||
- shutdown() -> None
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connected(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def connect_to_server(self, host: str, port: int) -> None:
|
||||
"""Connect via TCP."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def connect_to_unix_socket(self, socket_path: str) -> None:
|
||||
"""Connect via Unix domain socket."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Call tool on server, return result."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def send_notification(
|
||||
self,
|
||||
method: str,
|
||||
params: dict[str, Any],
|
||||
) -> None:
|
||||
"""Send one-way notification."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Close connection, cancel pending requests."""
|
||||
...
|
||||
68
astrbot/_internal/abc/acp/base_astrbot_acp_server.py
Normal file
68
astrbot/_internal/abc/acp/base_astrbot_acp_server.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
ACP (AstrBot Communication Protocol) server.
|
||||
|
||||
Transport: TCP listening socket
|
||||
Messages: JSON with Content-Length header
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAstrbotAcpServer(ABC):
|
||||
"""
|
||||
ACP server: listens for client connections, exposes tools.
|
||||
|
||||
Subclass must implement:
|
||||
- start(host, port) -> None
|
||||
- register_tool(name, handler) -> None
|
||||
- register_notification_handler(name, handler) -> None
|
||||
- broadcast_notification(method, params) -> None
|
||||
- shutdown() -> None
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def running(self) -> bool:
|
||||
"""True if server is accepting connections."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def start(self, host: str = "127.0.0.1", port: int = 8765) -> None:
|
||||
"""Bind and listen. Block until shutdown."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def register_tool(
|
||||
self,
|
||||
name: str,
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
"""Register async tool handler (receives params dict, returns result)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def register_notification_handler(
|
||||
self,
|
||||
name: str,
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
"""Register async notification handler (receives params dict)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def broadcast_notification(
|
||||
self,
|
||||
method: str,
|
||||
params: dict[str, Any],
|
||||
) -> None:
|
||||
"""Send notification to all connected clients."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Stop accepting, close all client connections."""
|
||||
...
|
||||
73
astrbot/_internal/abc/base_astrbot_gateway.py
Normal file
73
astrbot/_internal/abc/base_astrbot_gateway.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
AstrBot Gateway - HTTP/WebSocket API server.
|
||||
|
||||
Built on FastAPI, provides:
|
||||
- HTTP REST API (stats, inspector, config)
|
||||
- WebSocket for real-time events
|
||||
- Static file serving (dashboard)
|
||||
- Authentication (JWT/API key)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseAstrbotGateway(ABC):
|
||||
"""
|
||||
Gateway: HTTP/WebSocket server built on FastAPI.
|
||||
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ FastAPI App │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ REST Endpoints WebSocket │
|
||||
│ ├─ GET /api/stats ├─ /ws (connection manager)│
|
||||
│ ├─ GET /api/inspector/* │ │
|
||||
│ ├─ GET /api/memory/* │ │
|
||||
│ └─ ... │ │
|
||||
│ │
|
||||
│ Middleware: CORS, Auth, Logging │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────┐
|
||||
│ Orchestrator │
|
||||
│ (owns protocol clients)│
|
||||
└─────────────────────────┘
|
||||
|
||||
Routes (typical):
|
||||
GET / → Dashboard static files
|
||||
GET /api/stats → System statistics
|
||||
GET /api/inspector/stars → List registered stars
|
||||
WS /ws → WebSocket for real-time events
|
||||
|
||||
serve() Lifecycle:
|
||||
1. Create FastAPI app
|
||||
2. Register routes
|
||||
3. Start WebSocket manager
|
||||
4. Bind to host:port
|
||||
5. Run ASGI server (uvicorn/hypercorn)
|
||||
6. Block until shutdown
|
||||
7. Close all connections
|
||||
|
||||
Subclass must implement:
|
||||
- serve(): start server, block until shutdown
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def serve(self) -> None:
|
||||
"""
|
||||
Start gateway server - blocks until shutdown.
|
||||
|
||||
Should:
|
||||
1. Create FastAPI app with routes
|
||||
2. Configure CORS, auth middleware
|
||||
3. Start WebSocket connection manager
|
||||
4. Bind to ASTRBOT_PORT (default 6185)
|
||||
5. Run ASGI server
|
||||
6. Handle graceful shutdown on SIGTERM/SIGINT
|
||||
|
||||
Raises:
|
||||
OSError: address already in use
|
||||
"""
|
||||
...
|
||||
352
astrbot/_internal/abc/base_astrbot_orchestrator.py
Normal file
352
astrbot/_internal/abc/base_astrbot_orchestrator.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
AstrBot Orchestrator - core runtime lifecycle manager.
|
||||
|
||||
Architecture
|
||||
============
|
||||
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ Orchestrator │
|
||||
│ (owns lifecycle of all protocol clients + stars) │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌──────────────┼──────────────┐
|
||||
▼ ▼ ▼
|
||||
┌─────────┐ ┌─────────┐ ┌─────────┐
|
||||
│ LSP │ │ MCP │ │ ACP │
|
||||
│ Client │ │ Client │ │ Client │
|
||||
└─────────┘ └─────────┘ └─────────┘
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
LSP Servers MCP Servers ACP Services
|
||||
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ ABP Client │
|
||||
│ (in-process star registry) │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────┐
|
||||
│ Stars │
|
||||
│(Plugins) │
|
||||
└─────────┘
|
||||
|
||||
|
||||
Lifecycle State Machine
|
||||
=======================
|
||||
|
||||
States:
|
||||
┌─────────┐
|
||||
│ INIT │───► orchestrator created, clients not initialized
|
||||
└────┬────┘
|
||||
│ start()
|
||||
▼
|
||||
┌─────────┐
|
||||
│ RUNNING │◄─── run_loop() executing
|
||||
└────┬────┘
|
||||
│ shutdown()
|
||||
▼
|
||||
┌──────────┐
|
||||
│ SHUTDOWN │─── all clients closed, ready for GC
|
||||
└──────────┘
|
||||
|
||||
Transitions:
|
||||
INIT + start() ──► RUNNING
|
||||
RUNNING + shutdown() ──► SHUTDOWN
|
||||
|
||||
For each protocol client, the orchestrator:
|
||||
1. Creates instance in __init__
|
||||
2. Calls connect() to initialize
|
||||
3. Calls protocol-specific setup (connect_to_server, etc)
|
||||
4. Manages via run_loop() heartbeat
|
||||
5. Calls shutdown() on final cleanup
|
||||
|
||||
|
||||
Star Registration Flow
|
||||
=====================
|
||||
|
||||
orchestrator.register_star("my-star", MyStar())
|
||||
│
|
||||
▼
|
||||
┌───────────────────┐
|
||||
│ ABP Client │
|
||||
│ .register_star() │
|
||||
└───────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────┐
|
||||
│ Internal dict │
|
||||
│ {"my-star": obj} │
|
||||
└───────────────────┘
|
||||
|
||||
|
||||
Message Routing (conceptual)
|
||||
===========================
|
||||
|
||||
External Tool Call
|
||||
│
|
||||
▼
|
||||
┌──────────────┐ list_tools() ┌──────────────┐
|
||||
│ MCP Client │────────────────────►│ MCP Server │
|
||||
└──────────────┘◄────────────────────└──────────────┘
|
||||
│ tool result
|
||||
▼
|
||||
┌──────────────┐ call_tool() ┌──────────────┐
|
||||
│ ABP │────────────────────►│ Star │
|
||||
│ Client │◄────────────────────└──────────────┘
|
||||
└──────────────┘ tool result
|
||||
│
|
||||
▼
|
||||
Return to caller
|
||||
|
||||
|
||||
run_loop() Responsibilities
|
||||
===========================
|
||||
|
||||
while running:
|
||||
│─ check LSP server health (ping/heartbeat)
|
||||
│─ check MCP session status (reconnect if needed)
|
||||
│─ check ACP client connections
|
||||
│─ process any pending star notifications
|
||||
│─ sleep(SLEEP_INTERVAL)
|
||||
|
||||
|
||||
Shutdown Sequence
|
||||
==================
|
||||
|
||||
shutdown()
|
||||
│
|
||||
├─ set _running = False
|
||||
│
|
||||
├─ LSP.shutdown()
|
||||
│ └─ send "shutdown" request
|
||||
│ └─ terminate subprocess
|
||||
│
|
||||
├─ ACP.shutdown()
|
||||
│ └─ close TCP/Unix connections
|
||||
│
|
||||
├─ ABP.shutdown()
|
||||
│ └─ cancel pending requests
|
||||
│
|
||||
└─ MCP.cleanup()
|
||||
└─ close all sessions
|
||||
└─ cleanup subprocesses
|
||||
|
||||
|
||||
Exception Handling
|
||||
==================
|
||||
|
||||
Each protocol client should:
|
||||
- Catch connection errors
|
||||
- Attempt reconnection with exponential backoff
|
||||
- Log errors but don't crash run_loop
|
||||
- Raise on irrecoverable failures
|
||||
|
||||
The orchestrator run_loop should:
|
||||
- Catch CancelledError on shutdown
|
||||
- Catch Exception and log (don't crash)
|
||||
- Ensure cleanup runs in finally block
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
|
||||
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
|
||||
from astrbot._internal.protocols.lsp.client import AstrbotLspClient
|
||||
from astrbot._internal.protocols.mcp.client import McpClient
|
||||
|
||||
|
||||
#: Default heartbeat interval for run_loop()
|
||||
DEFAULT_SLEEP_INTERVAL: float = 5.0
|
||||
|
||||
|
||||
class BaseAstrbotOrchestrator(ABC):
|
||||
"""
|
||||
Core runtime: owns lifecycle of all protocol clients and stars.
|
||||
|
||||
┌────────────────────────────────────────────────────────────┐
|
||||
│ Protocol Clients (always present, never None after init) │
|
||||
├────────────────────────────────────────────────────────────┤
|
||||
│ lsp: Language Server Protocol │
|
||||
│ Purpose: code completion, diagnostics, hover, etc │
|
||||
│ Transport: stdio subprocess │
|
||||
│ │
|
||||
│ mcp: Model Context Protocol │
|
||||
│ Purpose: external tool access │
|
||||
│ Transport: stdio | SSE | HTTP │
|
||||
│ │
|
||||
│ acp: AstrBot Communication Protocol │
|
||||
│ Purpose: inter-service communication │
|
||||
│ Transport: TCP | Unix Socket │
|
||||
│ │
|
||||
│ abp: AstrBot Protocol │
|
||||
│ Purpose: in-process star (plugin) communication │
|
||||
│ Transport: direct method calls │
|
||||
└────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌────────────────────────────────────────────────────────────┐
|
||||
│ Star Registry │
|
||||
├────────────────────────────────────────────────────────────┤
|
||||
│ _stars: dict[str, Any] │
|
||||
│ Stars are plugins registered by name │
|
||||
│ ABP client delegates calls to registered stars │
|
||||
└────────────────────────────────────────────────────────────┘
|
||||
|
||||
Subclass must implement:
|
||||
- __init__(): create all protocol client instances
|
||||
- run_loop(): main event loop (block until shutdown)
|
||||
- register_star(name, instance): add to registry + ABP
|
||||
- unregister_star(name): remove from registry + ABP
|
||||
- shutdown(): clean up all clients
|
||||
"""
|
||||
|
||||
#: LSP client for language intelligence
|
||||
lsp: AstrbotLspClient
|
||||
|
||||
#: MCP client for external tools
|
||||
mcp: McpClient
|
||||
|
||||
#: ACP client for inter-service communication
|
||||
acp: AstrbotAcpClient
|
||||
|
||||
#: ABP client for in-process star communication
|
||||
abp: AstrbotAbpClient
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize orchestrator and all protocol clients.
|
||||
|
||||
After __init__, all clients exist but are not connected.
|
||||
Call start() or run_loop() to begin operation.
|
||||
|
||||
Example:
|
||||
class MyOrchestrator(BaseAstrbotOrchestrator):
|
||||
def __init__(self):
|
||||
self.lsp = AstrbotLspClient()
|
||||
self.mcp = McpClient()
|
||||
self.acp = AstrbotAcpClient()
|
||||
self.abp = AstrbotAbpClient()
|
||||
self._stars: dict[str, Any] = {}
|
||||
self._running = False
|
||||
"""
|
||||
self._stars: dict[str, Any] = {}
|
||||
self._running: bool = False
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""True if run_loop() is executing."""
|
||||
return self._running
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Initialize all protocol clients.
|
||||
|
||||
Called once before run_loop(). Should:
|
||||
1. Call lsp.connect()
|
||||
2. Call mcp.connect()
|
||||
3. Call acp.connect()
|
||||
4. Call abp.connect()
|
||||
5. Set _running = True
|
||||
|
||||
Raises:
|
||||
Exception: if any client fails to initialize
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def run_loop(self) -> None:
|
||||
"""
|
||||
Main event loop - blocks until shutdown.
|
||||
|
||||
Execution:
|
||||
self._running = True
|
||||
try:
|
||||
while self._running:
|
||||
await self._heartbeat()
|
||||
await anyio.sleep(DEFAULT_SLEEP_INTERVAL)
|
||||
except asyncio.CancelledError:
|
||||
pass # shutdown requested
|
||||
finally:
|
||||
self._running = False
|
||||
|
||||
_heartbeat() responsibilities:
|
||||
- Check LSP server health (optional ping)
|
||||
- Check MCP session status, reconnect if needed
|
||||
- Check ACP connections
|
||||
- Process any pending star notifications
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: when shutdown() called
|
||||
|
||||
Note:
|
||||
Subclass defines _heartbeat() for periodic tasks.
|
||||
This method only handles the loop control.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def register_star(self, name: str, star_instance: Any) -> None:
|
||||
"""
|
||||
Register a star (plugin) with the orchestrator.
|
||||
|
||||
Args:
|
||||
name: Unique identifier for the star
|
||||
instance: Star plugin instance (must have .call_tool() method)
|
||||
|
||||
Does:
|
||||
self._stars[name] = star_instance
|
||||
self.abp.register_star(name, star_instance)
|
||||
|
||||
Raises:
|
||||
ValueError: if name already registered
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def unregister_star(self, name: str) -> None:
|
||||
"""
|
||||
Unregister a star (plugin) from the orchestrator.
|
||||
|
||||
Args:
|
||||
name: Identifier of star to remove
|
||||
|
||||
Does:
|
||||
del self._stars[name]
|
||||
self.abp.unregister_star(name)
|
||||
|
||||
Note:
|
||||
Idempotent - does nothing if name not found.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_star(self, name: str) -> Any | None:
|
||||
"""Get registered star by name. Returns None if not found."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_stars(self) -> list[str]:
|
||||
"""Return list of registered star names."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Graceful shutdown of orchestrator and all clients.
|
||||
|
||||
Execution order:
|
||||
1. self._running = False (stop run_loop)
|
||||
2. await lsp.shutdown()
|
||||
3. await acp.shutdown()
|
||||
4. await abp.shutdown()
|
||||
5. await mcp.cleanup()
|
||||
|
||||
Does NOT unregister stars - caller should do that first.
|
||||
|
||||
After shutdown, orchestrator is ready for garbage collection.
|
||||
"""
|
||||
...
|
||||
114
astrbot/_internal/abc/lsp/base_astrbot_lsp_client.py
Normal file
114
astrbot/_internal/abc/lsp/base_astrbot_lsp_client.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
LSP (Language Server Protocol) client.
|
||||
|
||||
Transport: stdio subprocess
|
||||
Messages: JSON-RPC 2.0 with Content-Length header
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class LspMessage:
|
||||
"""JSON-RPC 2.0 message."""
|
||||
|
||||
jsonrpc: str = "2.0"
|
||||
id: int | str | None = None
|
||||
method: str | None = None
|
||||
params: dict[str, Any] | None = None
|
||||
result: Any = None
|
||||
error: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LspRequest(LspMessage):
|
||||
"""Outgoing request."""
|
||||
|
||||
def __init__(self, method: str, params: dict[str, Any] | None = None) -> None:
|
||||
self.id = id(self)
|
||||
self.method = method
|
||||
self.params = params
|
||||
|
||||
|
||||
class LspResponse(LspMessage):
|
||||
"""Incoming response."""
|
||||
|
||||
|
||||
class LspNotification(LspMessage):
|
||||
"""Incoming notification (no id)."""
|
||||
|
||||
|
||||
class BaseAstrbotLspClient(ABC):
|
||||
"""
|
||||
LSP client: connects to LSP servers via stdio subprocess.
|
||||
|
||||
Subclass must implement:
|
||||
- connect() -> None
|
||||
- connect_to_server(command, workspace_uri) -> None
|
||||
- send_request(method, params) -> dict
|
||||
- send_notification(method, params) -> None
|
||||
- shutdown() -> None
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connected(self) -> bool:
|
||||
"""True if connected to an LSP server."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
self._connected = False
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def connect_to_server(
|
||||
self,
|
||||
command: list[str],
|
||||
workspace_uri: str,
|
||||
) -> None:
|
||||
"""
|
||||
Start LSP server subprocess and complete handshake.
|
||||
|
||||
Steps:
|
||||
1. Spawn subprocess with stdin/stdout pipes
|
||||
2. Send initialize request
|
||||
3. Wait for response
|
||||
4. Send initialized notification
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def send_request(
|
||||
self,
|
||||
method: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Send JSON-RPC request and return result.
|
||||
|
||||
Raises:
|
||||
RuntimeError: not connected
|
||||
Exception: server returned error
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def send_notification(
|
||||
self,
|
||||
method: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send JSON-RPC notification (no response expected).
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Send shutdown, terminate subprocess, cleanup."""
|
||||
...
|
||||
95
astrbot/_internal/abc/mcp/base_astrbot_mcp_client.py
Normal file
95
astrbot/_internal/abc/mcp/base_astrbot_mcp_client.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) client.
|
||||
|
||||
Transport: stdio | SSE | streamable_http
|
||||
Messages: JSON-RPC 2.0
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class McpServerConfig(TypedDict, total=False):
|
||||
"""MCP server configuration."""
|
||||
|
||||
# Stdio transport
|
||||
command: str
|
||||
args: list[str]
|
||||
env: dict[str, str]
|
||||
cwd: str
|
||||
|
||||
# HTTP transport
|
||||
url: str
|
||||
headers: dict[str, str]
|
||||
transport: Literal["sse", "streamable_http"]
|
||||
|
||||
|
||||
class McpToolInfo(TypedDict):
|
||||
"""MCP tool descriptor."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
inputSchema: dict[str, Any]
|
||||
|
||||
|
||||
class BaseAstrbotMcpClient(ABC):
|
||||
"""
|
||||
MCP client: connects to MCP servers for external tools.
|
||||
|
||||
Subclass must implement:
|
||||
- connect() -> None
|
||||
- connect_to_server(config, name) -> None
|
||||
- list_tools() -> list[McpToolInfo]
|
||||
- call_tool(name, args, timeout) -> CallToolResult
|
||||
- cleanup() -> None
|
||||
"""
|
||||
|
||||
session: Any # mcp.ClientSession
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connected(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Initialize client session."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def connect_to_server(
|
||||
self,
|
||||
config: McpServerConfig,
|
||||
name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Connect to MCP server.
|
||||
|
||||
Stdio: {"command": "python", "args": ["server.py"], "env": {...}}
|
||||
HTTP: {"url": "https://...", "transport": "sse"}
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_tools(self) -> list[McpToolInfo]:
|
||||
"""Call tools/list and return tools."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any],
|
||||
read_timeout_seconds: int = 60,
|
||||
) -> Any:
|
||||
"""Call tools/call with reconnection support."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self) -> None:
|
||||
"""Close all server connections."""
|
||||
...
|
||||
6
astrbot/_internal/geteway/__init__.py
Normal file
6
astrbot/_internal/geteway/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Gateway module - FastAPI server for the dashboard backend."""
|
||||
|
||||
from .server import AstrbotGateway
|
||||
from .ws_manager import WebSocketManager
|
||||
|
||||
__all__ = ["AstrbotGateway", "WebSocketManager"]
|
||||
4
astrbot/_internal/geteway/deps.py
Normal file
4
astrbot/_internal/geteway/deps.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
依赖注入
|
||||
|
||||
"""
|
||||
0
astrbot/_internal/geteway/routes/inspector.py
Normal file
0
astrbot/_internal/geteway/routes/inspector.py
Normal file
0
astrbot/_internal/geteway/routes/memory.py
Normal file
0
astrbot/_internal/geteway/routes/memory.py
Normal file
0
astrbot/_internal/geteway/routes/stats.py
Normal file
0
astrbot/_internal/geteway/routes/stats.py
Normal file
248
astrbot/_internal/geteway/server.py
Normal file
248
astrbot/_internal/geteway/server.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
AstrBot Gateway - FastAPI server for the dashboard backend.
|
||||
|
||||
Provides REST API endpoints and WebSocket connections for the frontend dashboard.
|
||||
The gateway acts as the communication bridge between the dashboard and the orchestrator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.base_astrbot_gateway import BaseAstrbotGateway
|
||||
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
|
||||
from astrbot._internal.geteway.ws_manager import WebSocketManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
else:
|
||||
try:
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
except ImportError:
|
||||
logger.warning("FastAPI not installed, gateway unavailable.")
|
||||
FastAPI = cast(Any, None)
|
||||
WebSocket = cast(Any, None)
|
||||
WebSocketDisconnect = cast(Any, None)
|
||||
CORSMiddleware = cast(Any, None)
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotGateway(BaseAstrbotGateway):
|
||||
"""
|
||||
FastAPI-based gateway server for AstrBot.
|
||||
|
||||
Handles:
|
||||
- REST API endpoints for configuration and stats
|
||||
- WebSocket connections for real-time communication
|
||||
- CORS middleware for dashboard access
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator: BaseAstrbotOrchestrator) -> None:
|
||||
self.orchestrator = orchestrator
|
||||
self.ws_manager = WebSocketManager()
|
||||
self._app: FastAPI | None = None
|
||||
self._host = "0.0.0.0"
|
||||
self._port = 8765
|
||||
|
||||
async def serve(self) -> None:
|
||||
"""
|
||||
Start the gateway server.
|
||||
|
||||
Creates and runs a FastAPI application with WebSocket support.
|
||||
"""
|
||||
if FastAPI is None:
|
||||
raise RuntimeError("FastAPI is not installed")
|
||||
|
||||
log.info(f"Starting AstrBot Gateway on {self._host}:{self._port}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
log.info("Gateway server started.")
|
||||
yield
|
||||
# Shutdown
|
||||
await self.ws_manager.broadcast({"type": "server_shutdown"})
|
||||
log.info("Gateway server stopped.")
|
||||
|
||||
self._app = FastAPI(
|
||||
title="AstrBot Gateway",
|
||||
description="Backend API for AstrBot dashboard",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
self._app.add_middleware(
|
||||
cast(Any, CORSMiddleware),
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
self._setup_routes()
|
||||
|
||||
# Run with uvicorn
|
||||
import uvicorn
|
||||
|
||||
config = uvicorn.Config(
|
||||
self._app,
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
log_level="info",
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
"""Set up API routes."""
|
||||
if self._app is None:
|
||||
return
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
# Health check
|
||||
@self._app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
# WebSocket endpoint
|
||||
@self._app.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket):
|
||||
await self.ws_manager.connect(ws)
|
||||
try:
|
||||
while True:
|
||||
data = await ws.receive_text()
|
||||
try:
|
||||
message = json.loads(data)
|
||||
response = await self._handle_ws_message(message)
|
||||
if response:
|
||||
await ws.send_json(response)
|
||||
except json.JSONDecodeError:
|
||||
await ws.send_json({"error": "Invalid JSON"})
|
||||
except WebSocketDisconnect:
|
||||
self.ws_manager.disconnect(ws)
|
||||
|
||||
# Stats router
|
||||
stats_router = APIRouter(prefix="/api/stats", tags=["stats"])
|
||||
|
||||
@stats_router.get("/overview")
|
||||
async def get_overview():
|
||||
return await self._get_stats_overview()
|
||||
|
||||
self._app.include_router(stats_router)
|
||||
|
||||
# Inspector router
|
||||
inspector_router = APIRouter(prefix="/api/inspector", tags=["inspector"])
|
||||
|
||||
@inspector_router.get("/stars")
|
||||
async def list_stars():
|
||||
return await self._list_stars()
|
||||
|
||||
@inspector_router.get("/stars/{star_name}")
|
||||
async def get_star(star_name: str):
|
||||
return await self._get_star_detail(star_name)
|
||||
|
||||
self._app.include_router(inspector_router)
|
||||
|
||||
# Memory router
|
||||
memory_router = APIRouter(prefix="/api/memory", tags=["memory"])
|
||||
|
||||
@memory_router.get("/")
|
||||
async def get_memory():
|
||||
return await self._get_memory_info()
|
||||
|
||||
self._app.include_router(memory_router)
|
||||
|
||||
async def _handle_ws_message(
|
||||
self, message: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Handle an incoming WebSocket message.
|
||||
|
||||
Args:
|
||||
message: Parsed JSON message from the client
|
||||
|
||||
Returns:
|
||||
Response message to send back, or None for no response
|
||||
"""
|
||||
msg_type = message.get("type")
|
||||
data = message.get("data", {})
|
||||
|
||||
if msg_type == "ping":
|
||||
return {"type": "pong", "data": {}}
|
||||
|
||||
if msg_type == "call_tool":
|
||||
return await self._handle_call_tool(data)
|
||||
|
||||
if msg_type == "get_stars":
|
||||
return {"type": "stars_list", "data": await self._list_stars()}
|
||||
|
||||
return {
|
||||
"type": "error",
|
||||
"data": {"message": f"Unknown message type: {msg_type}"},
|
||||
}
|
||||
|
||||
async def _handle_call_tool(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Handle a tool call request via WebSocket."""
|
||||
star_name = data.get("star")
|
||||
tool_name = data.get("tool")
|
||||
arguments = data.get("arguments", {})
|
||||
|
||||
if not star_name or not tool_name:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"data": {"error": "Missing star or tool name"},
|
||||
}
|
||||
|
||||
try:
|
||||
result = await self.orchestrator.abp.call_star_tool(
|
||||
star_name, tool_name, arguments
|
||||
)
|
||||
return {"type": "tool_result", "data": {"result": result}}
|
||||
except Exception as e:
|
||||
return {"type": "tool_result", "data": {"error": str(e)}}
|
||||
|
||||
async def _get_stats_overview(self) -> dict[str, Any]:
|
||||
"""Get overview statistics."""
|
||||
return {
|
||||
"stars_count": len(self.orchestrator.abp._stars),
|
||||
"lsp_connected": self.orchestrator.lsp._connected,
|
||||
"mcp_sessions": getattr(self.orchestrator.mcp, "session", None) is not None,
|
||||
"acp_clients": len(getattr(self.orchestrator.acp, "_clients", [])),
|
||||
}
|
||||
|
||||
async def _list_stars(self) -> list[dict[str, Any]]:
|
||||
"""List all registered stars."""
|
||||
stars = []
|
||||
for name in self.orchestrator.abp._stars:
|
||||
stars.append({"name": name, "status": "active"})
|
||||
return stars
|
||||
|
||||
async def _get_star_detail(self, star_name: str) -> dict[str, Any]:
|
||||
"""Get details of a specific star."""
|
||||
star = self.orchestrator.abp._stars.get(star_name)
|
||||
if not star:
|
||||
return {"error": f"Star '{star_name}' not found"}
|
||||
return {"name": star_name, "status": "active"}
|
||||
|
||||
async def _get_memory_info(self) -> dict[str, Any]:
|
||||
"""Get memory usage information."""
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
return {
|
||||
"gc_objects": len(gc.get_objects()),
|
||||
"python_memory": "N/A", # Would need psutil for actual values
|
||||
}
|
||||
|
||||
def set_listen_address(self, host: str, port: int) -> None:
|
||||
"""Set the listen address for the gateway server."""
|
||||
self._host = host
|
||||
self._port = port
|
||||
103
astrbot/_internal/geteway/ws_manager.py
Normal file
103
astrbot/_internal/geteway/ws_manager.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
WebSocket connection manager for the AstrBot gateway.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
else:
|
||||
try:
|
||||
from fastapi import WebSocket
|
||||
except ImportError:
|
||||
logger.warning("FastAPI not installed, WebSocketManager unavailable.")
|
||||
WebSocket = cast(Any, None)
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""
|
||||
Manages all active WebSocket connections.
|
||||
|
||||
Provides connection/disconnection handling and broadcast capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connections: set[WebSocket] = set()
|
||||
self._lock = anyio.Lock()
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
"""Accept and register a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
async with self._lock:
|
||||
self._connections.add(websocket)
|
||||
log.debug(f"WebSocket connected. Total: {len(self._connections)}")
|
||||
|
||||
async def disconnect(self, websocket: WebSocket) -> None:
|
||||
"""Remove a WebSocket connection."""
|
||||
async with self._lock:
|
||||
self._connections.discard(websocket)
|
||||
log.debug(f"WebSocket disconnected. Total: {len(self._connections)}")
|
||||
|
||||
async def send_json(self, websocket: WebSocket, data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Send JSON data to a specific WebSocket.
|
||||
|
||||
Args:
|
||||
websocket: Target WebSocket connection
|
||||
data: Data to send (must be JSON-serializable)
|
||||
"""
|
||||
try:
|
||||
await websocket.send_json(data)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to send to WebSocket: {e}")
|
||||
await self.disconnect(websocket)
|
||||
|
||||
async def broadcast(self, data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Broadcast JSON data to all connected WebSockets.
|
||||
|
||||
Args:
|
||||
data: Data to broadcast (must be JSON-serializable)
|
||||
"""
|
||||
async with self._lock:
|
||||
connections = list(self._connections)
|
||||
|
||||
for conn in connections:
|
||||
try:
|
||||
await conn.send_json(data)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to broadcast to WebSocket: {e}")
|
||||
async with self._lock:
|
||||
self._connections.discard(conn)
|
||||
|
||||
async def send_to(
|
||||
self, websocket: WebSocket, message: str | dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Send a message to a specific WebSocket.
|
||||
|
||||
Args:
|
||||
websocket: Target WebSocket connection
|
||||
message: Message to send (string or dict)
|
||||
"""
|
||||
try:
|
||||
if isinstance(message, str):
|
||||
await websocket.send_text(message)
|
||||
else:
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to send to WebSocket: {e}")
|
||||
await self.disconnect(websocket)
|
||||
|
||||
@property
|
||||
def connection_count(self) -> int:
|
||||
"""Return the number of active connections."""
|
||||
return len(self._connections)
|
||||
5
astrbot/_internal/protocols/abp/__init__.py
Normal file
5
astrbot/_internal/protocols/abp/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""ABP module - AstrBot Protocol client implementation (built-in plugin protocol)."""
|
||||
|
||||
from .client import AstrbotAbpClient
|
||||
|
||||
__all__ = ["AstrbotAbpClient"]
|
||||
93
astrbot/_internal/protocols/abp/client.py
Normal file
93
astrbot/_internal/protocols/abp/client.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
ABP (AstrBot Protocol) client implementation.
|
||||
|
||||
ABP is the built-in plugin protocol where the orchestrator acts as client
|
||||
connecting to internal stars (plugins) embedded in the runtime.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.abp.base_astrbot_abp_client import BaseAstrbotAbpClient
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotAbpClient(BaseAstrbotAbpClient):
|
||||
"""
|
||||
ABP client for communicating with internal stars (built-in plugins).
|
||||
|
||||
The orchestrator acts as the client, sending requests to and receiving
|
||||
notifications from stars running within the same process.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connected = False
|
||||
self._stars: dict[str, Any] = {}
|
||||
# Use a simple dict for pending requests; we avoid asyncio.Future here.
|
||||
self._pending_requests: dict[str, Any] = {}
|
||||
self._request_id = 0
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""True if connected to stars registry."""
|
||||
return self._connected
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to internal stars registry."""
|
||||
log.debug("ABP client connecting to internal stars...")
|
||||
self._connected = True
|
||||
log.info("ABP client connected to internal stars registry.")
|
||||
|
||||
async def call_star_tool(
|
||||
self, star_name: str, tool_name: str, arguments: dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Call a tool on a registered star.
|
||||
|
||||
Args:
|
||||
star_name: Name of the star (plugin)
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool call result
|
||||
"""
|
||||
if not self._connected:
|
||||
raise RuntimeError("ABP client is not connected")
|
||||
|
||||
star = self._stars.get(star_name)
|
||||
if not star:
|
||||
raise ValueError(f"Star '{star_name}' not found")
|
||||
|
||||
request_id = f"{self._request_id}"
|
||||
self._request_id += 1
|
||||
|
||||
# No asyncio.Future used; store a placeholder entry for tracking if needed.
|
||||
self._pending_requests[request_id] = None
|
||||
|
||||
try:
|
||||
# Call the star's tool handler
|
||||
result = await star.call_tool(tool_name, arguments)
|
||||
return result
|
||||
finally:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
|
||||
def register_star(self, star_name: str, star_instance: Any) -> None:
|
||||
"""Register a star (plugin) with the ABP client."""
|
||||
self._stars[star_name] = star_instance
|
||||
log.debug(f"Star '{star_name}' registered with ABP client.")
|
||||
|
||||
def unregister_star(self, star_name: str) -> None:
|
||||
"""Unregister a star from the ABP client."""
|
||||
self._stars.pop(star_name, None)
|
||||
log.debug(f"Star '{star_name}' unregistered from ABP client.")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the ABP client connection."""
|
||||
self._connected = False
|
||||
# Clear any pending requests (no asyncio futures used in this implementation)
|
||||
self._pending_requests.clear()
|
||||
log.info("ABP client shut down.")
|
||||
6
astrbot/_internal/protocols/acp/__init__.py
Normal file
6
astrbot/_internal/protocols/acp/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""ACP module - AstrBot Communication Protocol client and server implementations."""
|
||||
|
||||
from .client import AstrbotAcpClient
|
||||
from .server import AstrbotAcpServer
|
||||
|
||||
__all__ = ["AstrbotAcpClient", "AstrbotAcpServer"]
|
||||
220
astrbot/_internal/protocols/acp/client.py
Normal file
220
astrbot/_internal/protocols/acp/client.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
ACP (AstrBot Communication Protocol) client implementation.
|
||||
|
||||
ACP is a client-server protocol for inter-service communication,
|
||||
similar to MCP but designed specifically for AstrBot's architecture.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.acp.base_astrbot_acp_client import BaseAstrbotAcpClient
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotAcpClient(BaseAstrbotAcpClient):
|
||||
"""
|
||||
ACP client for communicating with ACP servers.
|
||||
|
||||
The orchestrator acts as an ACP client, connecting to external
|
||||
ACP-compatible services.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connected = False
|
||||
self._reader: asyncio.StreamReader | None = None
|
||||
self._writer: asyncio.StreamWriter | None = None
|
||||
self._server_url: str | None = None
|
||||
self._pending_requests: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
||||
self._request_id = 0
|
||||
self._reader_task: asyncio.Task[None] | None = None
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""True if connected to an ACP server."""
|
||||
return self._connected
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Connect to configured ACP servers.
|
||||
|
||||
ACP servers can be accessed via TCP (host:port) or Unix socket.
|
||||
"""
|
||||
log.debug("ACP client connecting...")
|
||||
# TODO: Load ACP server configurations
|
||||
self._connected = True
|
||||
log.info("ACP client initialized.")
|
||||
|
||||
async def connect_to_server(self, host: str, port: int) -> None:
|
||||
"""
|
||||
Connect to an ACP server via TCP.
|
||||
|
||||
Args:
|
||||
host: Server hostname or IP
|
||||
port: Server port
|
||||
"""
|
||||
self._server_url = f"{host}:{port}"
|
||||
self._reader, self._writer = await asyncio.open_connection(host, port)
|
||||
self._connected = True
|
||||
|
||||
# Start reading responses
|
||||
self._reader_task = asyncio.create_task(self._read_messages())
|
||||
|
||||
log.info(f"ACP client connected to {self._server_url}")
|
||||
|
||||
async def connect_to_unix_socket(self, socket_path: str) -> None:
|
||||
"""
|
||||
Connect to an ACP server via Unix socket.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the Unix socket
|
||||
"""
|
||||
self._server_url = f"unix://{socket_path}"
|
||||
self._reader, self._writer = await asyncio.open_unix_connection(socket_path)
|
||||
self._connected = True
|
||||
|
||||
self._reader_task = asyncio.create_task(self._read_messages())
|
||||
|
||||
log.info(f"ACP client connected to {self._server_url}")
|
||||
|
||||
async def _read_messages(self) -> None:
|
||||
"""Background task to read ACP messages."""
|
||||
if not self._reader:
|
||||
return
|
||||
|
||||
buffer = b""
|
||||
while self._connected:
|
||||
try:
|
||||
data = await self._reader.read(4096)
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
|
||||
while True:
|
||||
header_end = buffer.find(b"\n")
|
||||
if header_end == -1:
|
||||
break
|
||||
|
||||
try:
|
||||
header = json.loads(buffer[:header_end].decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
buffer = buffer[header_end + 1 :]
|
||||
continue
|
||||
|
||||
content_length = header.get("content-length", 0)
|
||||
if (
|
||||
content_length == 0
|
||||
or len(buffer) < header_end + 1 + content_length
|
||||
):
|
||||
break
|
||||
|
||||
content = buffer[header_end + 1 : header_end + 1 + content_length]
|
||||
buffer = buffer[header_end + 1 + content_length :]
|
||||
|
||||
message = json.loads(content.decode("utf-8"))
|
||||
|
||||
if "id" in message:
|
||||
request_id = str(message["id"])
|
||||
future = self._pending_requests.pop(request_id, None)
|
||||
if future and not future.done():
|
||||
if "error" in message:
|
||||
future.set_exception(Exception(str(message["error"])))
|
||||
else:
|
||||
future.set_result(message.get("result", {}))
|
||||
else:
|
||||
await self._handle_notification(message)
|
||||
|
||||
except Exception as e:
|
||||
if self._connected:
|
||||
log.error(f"ACP read error: {e}")
|
||||
break
|
||||
|
||||
async def _handle_notification(self, notification: dict[str, Any]) -> None:
|
||||
"""Handle incoming ACP notifications."""
|
||||
method = notification.get("method", "")
|
||||
log.debug(f"ACP notification: {method}")
|
||||
|
||||
async def call_tool(
|
||||
self, server_name: str, tool_name: str, arguments: dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Call a tool on an ACP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the ACP server
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool call result
|
||||
"""
|
||||
if not self._connected:
|
||||
raise RuntimeError("ACP client is not connected")
|
||||
|
||||
request_id = str(self._request_id)
|
||||
self._request_id += 1
|
||||
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": f"{server_name}/{tool_name}",
|
||||
"params": arguments,
|
||||
}
|
||||
|
||||
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
await self._send_message(message)
|
||||
return await future
|
||||
|
||||
async def _send_message(self, message: dict[str, Any]) -> None:
|
||||
"""Send an ACP message."""
|
||||
if not self._writer:
|
||||
raise RuntimeError("ACP client not connected")
|
||||
|
||||
content = json.dumps(message)
|
||||
header = json.dumps({"content-length": len(content)}) + "\n"
|
||||
|
||||
self._writer.write((header + content).encode())
|
||||
await self._writer.drain()
|
||||
|
||||
async def send_notification(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Send a one-way notification to the server."""
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params or {},
|
||||
}
|
||||
await self._send_message(message)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the ACP client connection."""
|
||||
self._connected = False
|
||||
|
||||
if self._reader_task:
|
||||
self._reader_task.cancel()
|
||||
try:
|
||||
await self._reader_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._writer:
|
||||
self._writer.close()
|
||||
try:
|
||||
await self._writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
log.info("ACP client shut down.")
|
||||
223
astrbot/_internal/protocols/acp/server.py
Normal file
223
astrbot/_internal/protocols/acp/server.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
ACP (AstrBot Communication Protocol) server implementation.
|
||||
|
||||
ACP servers listen for connections from ACP clients and provide
|
||||
services/tools to the orchestrator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.acp.base_astrbot_acp_server import BaseAstrbotAcpServer
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotAcpServer(BaseAstrbotAcpServer):
|
||||
"""
|
||||
ACP server for accepting connections from ACP clients.
|
||||
|
||||
ACP servers expose tools/notifications that can be called by clients.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._running = False
|
||||
self._host: str = "127.0.0.1"
|
||||
self._port: int = 8765
|
||||
self._server: asyncio.Server | None = None
|
||||
self._clients: set[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = set()
|
||||
self._tool_handlers: dict[str, Callable[..., Any]] = {}
|
||||
self._notification_handlers: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
def register_tool(self, name: str, handler: Callable[..., Any]) -> None:
|
||||
"""
|
||||
Register a tool handler.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
handler: Async callable that handles tool calls
|
||||
"""
|
||||
self._tool_handlers[name] = handler
|
||||
log.debug(f"ACP server registered tool: {name}")
|
||||
|
||||
def register_notification_handler(
|
||||
self, name: str, handler: Callable[..., Any]
|
||||
) -> None:
|
||||
"""
|
||||
Register a notification handler.
|
||||
|
||||
Args:
|
||||
name: Notification method name
|
||||
handler: Async callable that handles notifications
|
||||
"""
|
||||
self._notification_handlers[name] = handler
|
||||
log.debug(f"ACP server registered notification handler: {name}")
|
||||
|
||||
async def start(self, host: str = "127.0.0.1", port: int = 8765) -> None:
|
||||
"""
|
||||
Start the ACP server.
|
||||
|
||||
Args:
|
||||
host: Host to bind to
|
||||
port: Port to listen on
|
||||
"""
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle_client,
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
self._running = True
|
||||
log.info(f"ACP server listening on {host}:{port}")
|
||||
|
||||
async def _handle_client(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
"""Handle an incoming ACP client connection."""
|
||||
addr = writer.get_extra_info("peername")
|
||||
log.debug(f"ACP client connected: {addr}")
|
||||
self._clients.add((reader, writer))
|
||||
|
||||
buffer = b""
|
||||
try:
|
||||
while self._running:
|
||||
try:
|
||||
data = await reader.read(4096)
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
|
||||
while True:
|
||||
header_end = buffer.find(b"\n")
|
||||
if header_end == -1:
|
||||
break
|
||||
|
||||
try:
|
||||
header = json.loads(buffer[:header_end].decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
buffer = buffer[header_end + 1 :]
|
||||
continue
|
||||
|
||||
content_length = header.get("content-length", 0)
|
||||
if (
|
||||
content_length == 0
|
||||
or len(buffer) < header_end + 1 + content_length
|
||||
):
|
||||
break
|
||||
|
||||
content = buffer[
|
||||
header_end + 1 : header_end + 1 + content_length
|
||||
]
|
||||
buffer = buffer[header_end + 1 + content_length :]
|
||||
|
||||
message = json.loads(content.decode("utf-8"))
|
||||
response = await self._handle_message(message)
|
||||
|
||||
if response:
|
||||
content = json.dumps(response)
|
||||
resp_header = (
|
||||
json.dumps({"content-length": len(content)}) + "\n"
|
||||
)
|
||||
writer.write(resp_header.encode() + content.encode())
|
||||
await writer.drain()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"ACP client error ({addr}): {e}")
|
||||
break
|
||||
|
||||
finally:
|
||||
self._clients.discard((reader, writer))
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
log.debug(f"ACP client disconnected: {addr}")
|
||||
|
||||
async def _handle_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Handle an incoming ACP message."""
|
||||
method = message.get("method", "")
|
||||
msg_id = message.get("id")
|
||||
params = message.get("params", {})
|
||||
|
||||
# Check if it's a notification (no id) or request (has id)
|
||||
if msg_id is None:
|
||||
# Notification
|
||||
handler = self._notification_handlers.get(method)
|
||||
if handler:
|
||||
try:
|
||||
await handler(params)
|
||||
except Exception as e:
|
||||
log.error(f"ACP notification handler error ({method}): {e}")
|
||||
return None
|
||||
|
||||
# Request
|
||||
result = None
|
||||
error = None
|
||||
|
||||
handler = self._tool_handlers.get(method)
|
||||
if handler:
|
||||
try:
|
||||
result = await handler(params)
|
||||
except Exception as e:
|
||||
error = str(e)
|
||||
log.error(f"ACP tool handler error ({method}): {e}")
|
||||
else:
|
||||
error = f"Unknown method: {method}"
|
||||
|
||||
response: dict[str, Any] = {"jsonrpc": "2.0", "id": msg_id}
|
||||
if error:
|
||||
response["error"] = {"code": -32601, "message": error}
|
||||
else:
|
||||
response["result"] = result
|
||||
|
||||
return response
|
||||
|
||||
async def broadcast_notification(self, method: str, params: dict[str, Any]) -> None:
|
||||
"""
|
||||
Broadcast a notification to all connected clients.
|
||||
|
||||
Args:
|
||||
method: Notification method name
|
||||
params: Notification parameters
|
||||
"""
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
content = json.dumps(message)
|
||||
header = json.dumps({"content-length": len(content)}) + "\n"
|
||||
data = header.encode() + content.encode()
|
||||
|
||||
for reader, writer in list(self._clients):
|
||||
try:
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to broadcast to client: {e}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the ACP server."""
|
||||
self._running = False
|
||||
|
||||
if self._server:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
|
||||
for reader, writer in list(self._clients):
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
self._clients.clear()
|
||||
|
||||
log.info("ACP server shut down.")
|
||||
5
astrbot/_internal/protocols/lsp/__init__.py
Normal file
5
astrbot/_internal/protocols/lsp/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""LSP module - Language Server Protocol client implementation."""
|
||||
|
||||
from .client import AstrbotLspClient
|
||||
|
||||
__all__ = ["AstrbotLspClient"]
|
||||
243
astrbot/_internal/protocols/lsp/client.py
Normal file
243
astrbot/_internal/protocols/lsp/client.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
LSP (Language Server Protocol) client implementation.
|
||||
|
||||
The orchestrator acts as an LSP client, connecting to LSP servers
|
||||
that provide language intelligence features (completions, diagnostics, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ByteReceiveStream, ByteSendStream, Process
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.lsp.base_astrbot_lsp_client import BaseAstrbotLspClient
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotLspClient(BaseAstrbotLspClient):
|
||||
"""
|
||||
LSP client for communicating with LSP servers.
|
||||
|
||||
Implements the Microsoft Language Server Protocol for connecting to
|
||||
external language intelligence services.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connected = False
|
||||
self._reader: ByteReceiveStream | None = None
|
||||
self._writer: ByteSendStream | None = None
|
||||
self._server_process: Process | None = None
|
||||
self._pending_requests: dict[int, Any] = {}
|
||||
self._request_id = 0
|
||||
self._server_command: list[str] | None = None
|
||||
# anyio TaskGroup handle for background readers
|
||||
self._task_group: Any | None = None
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""True if connected to an LSP server."""
|
||||
return self._connected
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Connect to configured LSP servers.
|
||||
|
||||
LSP servers are typically stdio-based subprocesses. This method
|
||||
establishes the communication channel.
|
||||
"""
|
||||
log.debug("LSP client connecting...")
|
||||
# TODO: Load LSP server configurations and start subprocesses
|
||||
# For now, mark as connected in idle mode
|
||||
self._connected = True
|
||||
log.info("LSP client initialized.")
|
||||
|
||||
async def connect_to_server(self, command: list[str], workspace_uri: str) -> None:
|
||||
"""
|
||||
Connect to an LSP server subprocess.
|
||||
|
||||
Args:
|
||||
command: Command line to start the LSP server (e.g., ["python", "lsp_server.py"])
|
||||
workspace_uri: Root URI of the workspace to serve
|
||||
"""
|
||||
log.debug(f"Starting LSP server: {' '.join(command)}")
|
||||
|
||||
self._server_process = await anyio.open_process(
|
||||
command,
|
||||
stdin=-1,
|
||||
stdout=-1,
|
||||
stderr=-1,
|
||||
)
|
||||
self._reader = self._server_process.stdout
|
||||
self._writer = self._server_process.stdin
|
||||
self._server_command = command
|
||||
self._connected = True
|
||||
|
||||
# Start reading responses in background using anyio TaskGroup
|
||||
# Create and enter a TaskGroup so the reader runs until we close it at shutdown.
|
||||
self._task_group = anyio.create_task_group()
|
||||
await self._task_group.__aenter__()
|
||||
self._task_group.start_soon(self._read_responses)
|
||||
|
||||
# Send initialize request
|
||||
await self.send_request(
|
||||
"initialize",
|
||||
{
|
||||
"processId": None,
|
||||
"rootUri": workspace_uri,
|
||||
"capabilities": {},
|
||||
},
|
||||
)
|
||||
|
||||
# Send initialized notification
|
||||
await self.send_notification("initialized", {})
|
||||
|
||||
log.info(f"LSP client connected to server: {command[0]}")
|
||||
|
||||
async def send_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Send an LSP request and wait for response."""
|
||||
if not self._writer:
|
||||
raise RuntimeError("LSP client not connected")
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id += 1
|
||||
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params or {},
|
||||
}
|
||||
|
||||
# Use anyio.Event for request/response matching
|
||||
response_event: anyio.Event = anyio.Event()
|
||||
response_holder: dict[str, Any] = {}
|
||||
|
||||
async def set_response(response: dict[str, Any]) -> None:
|
||||
response_holder["response"] = response
|
||||
response_event.set()
|
||||
|
||||
self._pending_requests[request_id] = set_response
|
||||
|
||||
content = json.dumps(message)
|
||||
headers = f"Content-Length: {len(content)}\r\n\r\n"
|
||||
await self._writer.send((headers + content).encode())
|
||||
|
||||
# Wait for response with timeout
|
||||
with anyio.move_on_after(30):
|
||||
await response_event.wait()
|
||||
|
||||
if "response" in response_holder:
|
||||
return response_holder["response"]
|
||||
raise TimeoutError(f"LSP request {method} timed out")
|
||||
|
||||
async def send_notification(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Send an LSP notification (no response expected)."""
|
||||
if not self._writer:
|
||||
raise RuntimeError("LSP client not connected")
|
||||
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params or {},
|
||||
}
|
||||
|
||||
content = json.dumps(message)
|
||||
headers = f"Content-Length: {len(content)}\r\n\r\n"
|
||||
await self._writer.send((headers + content).encode())
|
||||
|
||||
async def _read_responses(self) -> None:
|
||||
"""Background task to read LSP responses."""
|
||||
if not self._reader:
|
||||
return
|
||||
|
||||
buffer = b""
|
||||
try:
|
||||
while self._connected:
|
||||
try:
|
||||
data = await self._reader.receive()
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
|
||||
while True:
|
||||
# Parse Content-Length header
|
||||
header_end = buffer.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
break
|
||||
|
||||
header = buffer[:header_end].decode("utf-8")
|
||||
content_length = 0
|
||||
for line in header.split("\r\n"):
|
||||
if line.startswith("Content-Length:"):
|
||||
content_length = int(line.split(":")[1].strip())
|
||||
|
||||
if content_length == 0:
|
||||
break
|
||||
|
||||
total_length = header_end + 4 + content_length
|
||||
if len(buffer) < total_length:
|
||||
break
|
||||
|
||||
content = buffer[header_end + 4 : total_length]
|
||||
buffer = buffer[total_length:]
|
||||
|
||||
response = json.loads(content.decode("utf-8"))
|
||||
|
||||
# Handle response vs notification
|
||||
if "id" in response:
|
||||
request_id = response["id"]
|
||||
handler = self._pending_requests.pop(request_id, None)
|
||||
if handler:
|
||||
await handler(response)
|
||||
else:
|
||||
# Notification (e.g., window/logMessage)
|
||||
await self._handle_notification(response)
|
||||
|
||||
except anyio.EndOfStream:
|
||||
break
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# Task was cancelled via the TaskGroup cancel/exit during shutdown
|
||||
pass
|
||||
|
||||
async def _handle_notification(self, notification: dict[str, Any]) -> None:
|
||||
"""Handle incoming LSP notifications."""
|
||||
method = notification.get("method", "")
|
||||
log.debug(f"LSP notification: {method}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the LSP client."""
|
||||
self._connected = False
|
||||
|
||||
if self._task_group:
|
||||
try:
|
||||
# Exit the TaskGroup, which cancels background tasks started within it
|
||||
await self._task_group.__aexit__(None, None, None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
pass
|
||||
self._task_group = None
|
||||
|
||||
if self._server_process:
|
||||
try:
|
||||
await self.send_notification("shutdown", {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._server_process.terminate()
|
||||
try:
|
||||
with anyio.move_on_after(5):
|
||||
await self._server_process.wait()
|
||||
except Exception:
|
||||
self._server_process.kill()
|
||||
self._server_process = None
|
||||
|
||||
self._pending_requests.clear()
|
||||
log.info("LSP client shut down.")
|
||||
63
astrbot/_internal/protocols/mcp/__init__.py
Normal file
63
astrbot/_internal/protocols/mcp/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""MCP module - Model Context Protocol client and tool implementations.
|
||||
|
||||
This module provides MCP client functionality and MCP tool wrappers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .client import McpClient
|
||||
from .config import (
|
||||
DEFAULT_MCP_CONFIG,
|
||||
get_mcp_config_path,
|
||||
load_mcp_config,
|
||||
save_mcp_config,
|
||||
)
|
||||
from .tool import MCPTool
|
||||
|
||||
|
||||
# Exceptions
|
||||
class MCPInitError(Exception):
|
||||
"""Base exception for MCP initialization failures."""
|
||||
|
||||
|
||||
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
|
||||
"""Raised when MCP client initialization exceeds the configured timeout."""
|
||||
|
||||
|
||||
class MCPAllServicesFailedError(MCPInitError):
|
||||
"""Raised when all configured MCP services fail to initialize."""
|
||||
|
||||
|
||||
class MCPShutdownTimeoutError(asyncio.TimeoutError):
|
||||
"""Raised when MCP shutdown exceeds the configured timeout."""
|
||||
|
||||
def __init__(self, names: list[str], timeout: float) -> None:
|
||||
self.names = names
|
||||
self.timeout = timeout
|
||||
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPInitSummary:
|
||||
"""Summary of MCP initialization results."""
|
||||
|
||||
total: int
|
||||
success: int
|
||||
failed: list[str]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_MCP_CONFIG",
|
||||
"MCPAllServicesFailedError",
|
||||
"MCPInitError",
|
||||
"MCPInitSummary",
|
||||
"MCPInitTimeoutError",
|
||||
"MCPShutdownTimeoutError",
|
||||
"MCPTool",
|
||||
"McpClient",
|
||||
"get_mcp_config_path",
|
||||
"load_mcp_config",
|
||||
"save_mcp_config",
|
||||
]
|
||||
466
astrbot/_internal/protocols/mcp/client.py
Normal file
466
astrbot/_internal/protocols/mcp/client.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""MCP client implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from astrbot._internal.abc.mcp.base_astrbot_mcp_client import (
|
||||
BaseAstrbotMcpClient,
|
||||
McpServerConfig,
|
||||
McpToolInfo,
|
||||
)
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
try:
|
||||
import anyio
|
||||
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
|
||||
)
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
|
||||
)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""Prepare configuration, handle nested format."""
|
||||
if config.get("mcpServers"):
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
|
||||
def _prepare_stdio_env(config: dict) -> dict:
|
||||
"""Preserve Windows executable resolution for stdio subprocesses."""
|
||||
if sys.platform != "win32":
|
||||
return config
|
||||
|
||||
pathext = os.environ.get("PATHEXT")
|
||||
if not pathext:
|
||||
return config
|
||||
|
||||
prepared = config.copy()
|
||||
env = dict(prepared.get("env") or {})
|
||||
env.setdefault("PATHEXT", pathext)
|
||||
prepared["env"] = env
|
||||
return prepared
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""Quick test MCP server connectivity."""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
if "transport" in cfg:
|
||||
transport_type = cfg["transport"]
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP connection config missing transport or type field")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if transport_type == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"Connection timeout: {timeout} seconds"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
class McpClient(BaseAstrbotMcpClient):
|
||||
def __init__(self) -> None:
|
||||
# Initialize session and client objects
|
||||
self.session: mcp.ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
|
||||
|
||||
self.name: str | None = None
|
||||
self.active: bool = True
|
||||
self.tools: list[mcp.Tool] = []
|
||||
self.server_errlogs: list[str] = []
|
||||
self.running_event = anyio.Event()
|
||||
self.process_pid: int | None = None
|
||||
|
||||
# Store connection config for reconnection
|
||||
self._mcp_server_config: McpServerConfig | None = None
|
||||
self._server_name: str | None = None
|
||||
self._reconnect_lock = anyio.Lock() # Lock for thread-safe reconnection
|
||||
self._reconnecting: bool = False # For logging and debugging
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize the MCP client connection.
|
||||
|
||||
Note: Actual server connections are made via connect_to_server().
|
||||
This method prepares the client for use.
|
||||
"""
|
||||
# MCP client is initialized on-demand via connect_to_server
|
||||
# This is a no-op stub to satisfy BaseAstrbotMcpClient
|
||||
logger.debug("MCP client initialized.")
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""True if MCP client has an active session."""
|
||||
return self.session is not None
|
||||
|
||||
async def list_tools(self) -> list[McpToolInfo]:
|
||||
"""List all tools from connected MCP servers."""
|
||||
if not self.session:
|
||||
return []
|
||||
result = await self.list_tools_and_save()
|
||||
tools = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"inputSchema": tool.inputSchema,
|
||||
}
|
||||
for tool in result.tools
|
||||
]
|
||||
return cast(list[McpToolInfo], tools)
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any],
|
||||
read_timeout_seconds: int = 60,
|
||||
) -> Any:
|
||||
"""Call a tool on the MCP server with reconnection support."""
|
||||
return await self.call_tool_with_reconnect(
|
||||
tool_name=name,
|
||||
arguments=arguments,
|
||||
read_timeout_seconds=timedelta(seconds=read_timeout_seconds),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_stdio_process_pid(streams_context: object) -> int | None:
|
||||
"""Best-effort extraction for stdio subprocess PID used by lease cleanup.
|
||||
|
||||
TODO(refactor): replace this async-generator frame introspection with a
|
||||
stable MCP library hook once the upstream transport exposes process PID.
|
||||
"""
|
||||
generator = getattr(streams_context, "gen", None)
|
||||
frame = getattr(generator, "ag_frame", None)
|
||||
if frame is None:
|
||||
return None
|
||||
process = frame.f_locals.get("process")
|
||||
pid = getattr(process, "pid", None)
|
||||
try:
|
||||
return int(pid) if pid is not None else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
async def connect_to_server(self, config: McpServerConfig, name: str) -> None:
|
||||
"""Connect to MCP server
|
||||
|
||||
If `url` parameter exists:
|
||||
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
|
||||
2. When transport is specified as `sse`, use SSE connection.
|
||||
3. If not specified, default to SSE connection to MCP service.
|
||||
|
||||
Args:
|
||||
config: Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
|
||||
"""
|
||||
# Store config for reconnection
|
||||
self._mcp_server_config = config
|
||||
self._server_name = name
|
||||
self.process_pid = None
|
||||
|
||||
cfg = _prepare_config(dict(config))
|
||||
|
||||
def logging_callback(
|
||||
msg: str | mcp.types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
# Handle MCP service error logs
|
||||
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
|
||||
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
|
||||
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
|
||||
self.server_errlogs.append(log_msg)
|
||||
|
||||
if "url" in cfg:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if "transport" in cfg:
|
||||
transport_type = cfg["transport"]
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP connection config missing transport or type field")
|
||||
|
||||
if transport_type != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context,
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=cast(Any, logging_callback),
|
||||
),
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context,
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
cfg = _prepare_stdio_env(cfg)
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None:
|
||||
# Handle MCP service error logs
|
||||
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
|
||||
if msg.level in (
|
||||
"warning",
|
||||
"error",
|
||||
"critical",
|
||||
"alert",
|
||||
"emergency",
|
||||
):
|
||||
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
|
||||
self.server_errlogs.append(log_msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=cast(
|
||||
Any,
|
||||
LogPipe(
|
||||
level=logging.INFO,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
self.process_pid = self._extract_stdio_process_pid(stdio_transport)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport),
|
||||
)
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
if not self.session:
|
||||
raise Exception("MCP Client is not initialized")
|
||||
response = await self.session.list_tools()
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
"""Reconnect to the MCP server using the stored configuration.
|
||||
|
||||
Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
|
||||
|
||||
Raises:
|
||||
Exception: raised when reconnection fails
|
||||
"""
|
||||
async with self._reconnect_lock:
|
||||
# Check if already reconnecting (useful for logging)
|
||||
if self._reconnecting:
|
||||
logger.debug(
|
||||
f"MCP Client {self._server_name} is already reconnecting, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
if not self._mcp_server_config or not self._server_name:
|
||||
raise Exception("Cannot reconnect: missing connection configuration")
|
||||
|
||||
self._reconnecting = True
|
||||
try:
|
||||
logger.info(
|
||||
f"Attempting to reconnect to MCP server {self._server_name}..."
|
||||
)
|
||||
|
||||
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
|
||||
if self.exit_stack:
|
||||
self._old_exit_stacks.append(self.exit_stack)
|
||||
|
||||
# Mark old session as invalid
|
||||
self.session = None
|
||||
|
||||
# Create new exit stack for new connection
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
# Reconnect using stored config
|
||||
await self.connect_to_server(self._mcp_server_config, self._server_name)
|
||||
await self.list_tools_and_save()
|
||||
|
||||
logger.info(
|
||||
f"Successfully reconnected to MCP server {self._server_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to reconnect to MCP server {self._server_name}: {e}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
self._reconnecting = False
|
||||
|
||||
async def call_tool_with_reconnect(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict,
|
||||
read_timeout_seconds: timedelta,
|
||||
) -> mcp.types.CallToolResult:
|
||||
"""Call MCP tool with automatic reconnection on failure, max 2 retries.
|
||||
|
||||
Args:
|
||||
tool_name: tool name
|
||||
arguments: tool arguments
|
||||
read_timeout_seconds: read timeout
|
||||
|
||||
Returns:
|
||||
MCP tool call result
|
||||
|
||||
Raises:
|
||||
ValueError: MCP session is not available
|
||||
anyio.ClosedResourceError: raised after reconnection failure
|
||||
"""
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(anyio.ClosedResourceError),
|
||||
stop=stop_after_attempt(2),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=3),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING), # type: ignore[arg-type]
|
||||
reraise=True,
|
||||
)
|
||||
async def _call_with_retry():
|
||||
if not self.session:
|
||||
raise ValueError("MCP session is not available for MCP function tools.")
|
||||
|
||||
try:
|
||||
return await self.session.call_tool(
|
||||
name=tool_name,
|
||||
arguments=arguments,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
logger.warning(
|
||||
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
|
||||
)
|
||||
# Attempt to reconnect
|
||||
await self._reconnect()
|
||||
# Reraise the exception to trigger tenacity retry
|
||||
raise
|
||||
|
||||
return await _call_with_retry()
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up resources including old exit stacks from reconnections"""
|
||||
# Close current exit stack
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing current exit stack: {e}")
|
||||
|
||||
# Don't close old exit stacks as they may be in different task contexts
|
||||
# They will be garbage collected naturally
|
||||
# Just clear the list to release references
|
||||
self._old_exit_stacks.clear()
|
||||
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
self.process_pid = None
|
||||
55
astrbot/_internal/protocols/mcp/config.py
Normal file
55
astrbot/_internal/protocols/mcp/config.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""MCP configuration management."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
|
||||
def get_mcp_config_path() -> str:
|
||||
"""Get the path to the MCP configuration file."""
|
||||
data_dir = get_astrbot_data_path()
|
||||
return os.path.join(data_dir, "mcp_server.json")
|
||||
|
||||
|
||||
def load_mcp_config() -> dict:
|
||||
"""Load MCP configuration from file.
|
||||
|
||||
Returns:
|
||||
MCP configuration dict. If file doesn't exist, returns default config.
|
||||
|
||||
"""
|
||||
config_path = get_mcp_config_path()
|
||||
if not os.path.exists(config_path):
|
||||
# Create default config if not exists
|
||||
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
try:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
|
||||
def save_mcp_config(config: dict) -> bool:
|
||||
"""Save MCP configuration to file.
|
||||
|
||||
Args:
|
||||
config: MCP configuration dict to save.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
|
||||
"""
|
||||
config_path = get_mcp_config_path()
|
||||
try:
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
45
astrbot/_internal/protocols/mcp/tool.py
Normal file
45
astrbot/_internal/protocols/mcp/tool.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""MCP tool wrapper."""
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
try:
|
||||
import mcp
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
mcp = None # type: ignore
|
||||
|
||||
from astrbot._internal.tools.base import FunctionTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot._internal.protocols.mcp.client import McpClient
|
||||
|
||||
|
||||
class MCPTool(FunctionTool):
|
||||
"""A function tool that calls an MCP service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_tool: "mcp.types.Tool",
|
||||
mcp_client: "McpClient",
|
||||
mcp_server_name: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name=mcp_tool.name,
|
||||
description=mcp_tool.description or "",
|
||||
parameters=mcp_tool.inputSchema,
|
||||
)
|
||||
self.mcp_tool = mcp_tool
|
||||
self.mcp_client = mcp_client
|
||||
self.mcp_server_name = mcp_server_name
|
||||
self.source = "mcp"
|
||||
|
||||
async def call(self, **kwargs: Any) -> Any:
|
||||
"""Call the MCP tool with the given arguments."""
|
||||
# Note: For actual usage, context.tool_call_timeout is needed
|
||||
# but for simplicity we use a default timeout here
|
||||
return await self.mcp_client.call_tool_with_reconnect(
|
||||
tool_name=self.mcp_tool.name,
|
||||
arguments=kwargs,
|
||||
read_timeout_seconds=timedelta(seconds=60),
|
||||
)
|
||||
3
astrbot/_internal/runtime/__init__.py
Normal file
3
astrbot/_internal/runtime/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from astrbot._internal.runtime.__main__ import bootstrap
|
||||
|
||||
__all__ = ["bootstrap"]
|
||||
24
astrbot/_internal/runtime/__main__.py
Normal file
24
astrbot/_internal/runtime/__main__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot._internal.abc.base_astrbot_gateway import BaseAstrbotGateway
|
||||
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
|
||||
from astrbot._internal.geteway.server import AstrbotGateway
|
||||
from astrbot._internal.runtime.orchestrator import AstrbotOrchestrator
|
||||
|
||||
|
||||
async def bootstrap():
|
||||
orchestrator: BaseAstrbotOrchestrator = AstrbotOrchestrator()
|
||||
gw: BaseAstrbotGateway = AstrbotGateway(orchestrator)
|
||||
|
||||
# anyio 的结构化并发
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(orchestrator.lsp.connect) # 启动 LSP client
|
||||
tg.start_soon(orchestrator.mcp.connect) # 启动 MCP client
|
||||
tg.start_soon(orchestrator.acp.connect) # 启动 ACP client
|
||||
tg.start_soon(orchestrator.abp.connect) # 启动 ABP client
|
||||
await anyio.sleep(0.5)
|
||||
tg.start_soon(orchestrator.run_loop) # 启动编排器循环
|
||||
|
||||
tg.start_soon(gw.serve) # 面板后端服务
|
||||
164
astrbot/_internal/runtime/orchestrator.py
Normal file
164
astrbot/_internal/runtime/orchestrator.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
AstrBot Orchestrator - core runtime that coordinates all protocol clients.
|
||||
|
||||
The orchestrator manages the lifecycle of LSP, MCP, ACP, and ABP clients,
|
||||
and runs the main event loop that dispatches messages between components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
|
||||
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
|
||||
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
|
||||
from astrbot._internal.protocols.lsp.client import AstrbotLspClient
|
||||
from astrbot._internal.protocols.mcp.client import McpClient
|
||||
from astrbot._internal.stars import RuntimeStatusStar
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotOrchestrator(BaseAstrbotOrchestrator):
|
||||
"""
|
||||
Core runtime orchestrator for AstrBot.
|
||||
|
||||
Manages:
|
||||
- LSP client: Language Server Protocol for editor integrations
|
||||
- MCP client: Model Context Protocol for external tool servers
|
||||
- ACP client: AstrBot Communication Protocol for inter-service communication
|
||||
- ABP client: AstrBot Protocol for built-in star (plugin) communication
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Initialize protocol clients (use concrete types for full method access)
|
||||
self.lsp = AstrbotLspClient()
|
||||
self.mcp = McpClient()
|
||||
self.acp = AstrbotAcpClient()
|
||||
self.abp = AstrbotAbpClient()
|
||||
|
||||
self._running = False
|
||||
self._stars: dict[str, Any] = {}
|
||||
self._message_count: int = 0
|
||||
self._last_activity_timestamp: float | None = None
|
||||
|
||||
# Auto-register RuntimeStatusStar
|
||||
self._runtime_status_star = RuntimeStatusStar()
|
||||
self._runtime_status_star.set_orchestrator(self)
|
||||
self._stars["runtime-status-star"] = self._runtime_status_star
|
||||
self.abp.register_star("runtime-status-star", self._runtime_status_star)
|
||||
|
||||
log.debug("AstrbotOrchestrator initialized.")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Initialize all protocol clients.
|
||||
|
||||
Calls connect() on all protocol clients to prepare them for use.
|
||||
"""
|
||||
log.info("Starting AstrbotOrchestrator...")
|
||||
|
||||
await self.lsp.connect()
|
||||
await self.mcp.connect()
|
||||
await self.acp.connect()
|
||||
await self.abp.connect()
|
||||
|
||||
self._running = True
|
||||
log.info("AstrbotOrchestrator started.")
|
||||
|
||||
async def run_loop(self) -> None:
|
||||
"""
|
||||
Main orchestrator event loop.
|
||||
|
||||
This loop runs continuously, handling:
|
||||
- Periodic health checks of protocol clients
|
||||
- Message routing between protocols
|
||||
- Star (plugin) lifecycle management
|
||||
"""
|
||||
self._running = True
|
||||
log.info("AstrbotOrchestrator run loop started.")
|
||||
|
||||
stop_event = anyio.Event()
|
||||
|
||||
def set_stop() -> None:
|
||||
stop_event.set()
|
||||
|
||||
# Store the callback for cleanup
|
||||
self._stop_callback = set_stop
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
# TODO: Periodic tasks:
|
||||
# - Check LSP server health
|
||||
# - Check MCP session status
|
||||
# - Check ACP client connections
|
||||
# - Process any pending star notifications
|
||||
|
||||
# Wait for 5 seconds or until shutdown is called
|
||||
with anyio.move_on_after(5):
|
||||
await stop_event.wait()
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
log.info("Orchestrator run loop cancelled.")
|
||||
finally:
|
||||
self._running = False
|
||||
self._stop_callback = None
|
||||
log.info("AstrbotOrchestrator run loop stopped.")
|
||||
|
||||
async def register_star(self, name: str, star_instance: Any) -> None:
|
||||
"""
|
||||
Register a star (plugin) with the orchestrator.
|
||||
|
||||
Args:
|
||||
name: Unique name for the star
|
||||
star_instance: Star plugin instance
|
||||
"""
|
||||
self._stars[name] = star_instance
|
||||
self.abp.register_star(name, star_instance)
|
||||
log.info(f"Star '{name}' registered.")
|
||||
|
||||
async def unregister_star(self, name: str) -> None:
|
||||
"""
|
||||
Unregister a star (plugin) from the orchestrator.
|
||||
|
||||
Args:
|
||||
name: Name of the star to unregister
|
||||
"""
|
||||
self._stars.pop(name, None)
|
||||
self.abp.unregister_star(name)
|
||||
log.info(f"Star '{name}' unregistered.")
|
||||
|
||||
async def get_star(self, name: str) -> Any | None:
|
||||
"""Get a registered star by name."""
|
||||
return self._stars.get(name)
|
||||
|
||||
async def list_stars(self) -> list[str]:
|
||||
"""List all registered star names."""
|
||||
return list(self._stars.keys())
|
||||
|
||||
def record_activity(self) -> None:
|
||||
"""Record a message activity for stats tracking."""
|
||||
self._message_count += 1
|
||||
import time
|
||||
|
||||
self._last_activity_timestamp = time.time()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the orchestrator and all protocol clients.
|
||||
"""
|
||||
log.info("Shutting down AstrbotOrchestrator...")
|
||||
self._running = False
|
||||
|
||||
# Shutdown all protocol clients
|
||||
await self.lsp.shutdown()
|
||||
await self.acp.shutdown()
|
||||
await self.abp.shutdown()
|
||||
|
||||
# MCP cleanup
|
||||
await self.mcp.cleanup()
|
||||
|
||||
log.info("AstrbotOrchestrator shut down.")
|
||||
13
astrbot/_internal/skills/__init__.py
Normal file
13
astrbot/_internal/skills/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Internal skills module - re-exports from core.skills.skill_manager."""
|
||||
|
||||
from astrbot.core.skills.skill_manager import (
|
||||
SkillInfo,
|
||||
SkillManager,
|
||||
build_skills_prompt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SkillInfo",
|
||||
"SkillManager",
|
||||
"build_skills_prompt",
|
||||
]
|
||||
7
astrbot/_internal/stars/__init__.py
Normal file
7
astrbot/_internal/stars/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Stars (built-in plugins) for AstrBot runtime.
|
||||
"""
|
||||
|
||||
from astrbot._internal.stars.runtime_status_star import RuntimeStatusStar
|
||||
|
||||
__all__ = ["RuntimeStatusStar"]
|
||||
127
astrbot/_internal/stars/runtime_status_star.py
Normal file
127
astrbot/_internal/stars/runtime_status_star.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
RuntimeStatusStar - ABP plugin that exposes core runtime internal state.
|
||||
|
||||
This star provides tools for querying:
|
||||
- Runtime status (running state, uptime)
|
||||
- Protocol client status (LSP, MCP, ACP, ABP)
|
||||
- Registered stars registry
|
||||
- Message counts and metrics
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeStatusStar:
|
||||
"""
|
||||
ABP star that exposes core runtime internal state as callable tools.
|
||||
|
||||
Tools provided:
|
||||
- get_runtime_status: Returns running state and uptime
|
||||
- get_protocol_status: Returns LSP, MCP, ACP, ABP status
|
||||
- get_star_registry: Returns registered star names
|
||||
- get_stats: Returns message counts and metrics
|
||||
"""
|
||||
|
||||
name: str = "runtime-status-star"
|
||||
description: str = "ABP plugin that exposes core runtime internal state"
|
||||
|
||||
_start_time: float = field(default_factory=time.time, init=False)
|
||||
_orchestrator: Any = field(default=None, init=False)
|
||||
|
||||
def set_orchestrator(self, orchestrator: Any) -> None:
|
||||
"""Set the orchestrator reference for status queries."""
|
||||
self._orchestrator = orchestrator
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""
|
||||
Handle tool calls from ABP client.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool result
|
||||
"""
|
||||
if tool_name == "get_runtime_status":
|
||||
return self._get_runtime_status()
|
||||
elif tool_name == "get_protocol_status":
|
||||
return await self._get_protocol_status()
|
||||
elif tool_name == "get_star_registry":
|
||||
return await self._get_star_registry()
|
||||
elif tool_name == "get_stats":
|
||||
return self._get_stats()
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {tool_name}")
|
||||
|
||||
def _get_runtime_status(self) -> dict[str, Any]:
|
||||
"""Get overall runtime state."""
|
||||
running = (
|
||||
getattr(self._orchestrator, "running", False)
|
||||
if self._orchestrator
|
||||
else False
|
||||
)
|
||||
uptime_seconds = time.time() - self._start_time
|
||||
return {
|
||||
"running": running,
|
||||
"uptime_seconds": uptime_seconds,
|
||||
}
|
||||
|
||||
async def _get_protocol_status(self) -> dict[str, Any]:
|
||||
"""Get status of each protocol client."""
|
||||
if not self._orchestrator:
|
||||
return {
|
||||
"lsp": {"connected": False, "name": "lsp-client"},
|
||||
"mcp": {"connected": False, "name": "mcp-client"},
|
||||
"acp": {"connected": False, "name": "acp-client"},
|
||||
"abp": {"connected": False, "name": "abp-client"},
|
||||
}
|
||||
|
||||
return {
|
||||
"lsp": {
|
||||
"connected": getattr(self._orchestrator.lsp, "connected", False),
|
||||
"name": "lsp-client",
|
||||
},
|
||||
"mcp": {
|
||||
"connected": getattr(self._orchestrator.mcp, "connected", False),
|
||||
"name": "mcp-client",
|
||||
},
|
||||
"acp": {
|
||||
"connected": getattr(self._orchestrator.acp, "connected", False),
|
||||
"name": "acp-client",
|
||||
},
|
||||
"abp": {
|
||||
"connected": getattr(self._orchestrator.abp, "connected", False),
|
||||
"name": "abp-client",
|
||||
},
|
||||
}
|
||||
|
||||
async def _get_star_registry(self) -> dict[str, Any]:
|
||||
"""Get list of registered stars."""
|
||||
if not self._orchestrator:
|
||||
return {"stars": []}
|
||||
|
||||
stars = await self._orchestrator.list_stars()
|
||||
return {"stars": stars}
|
||||
|
||||
def _get_stats(self) -> dict[str, Any]:
|
||||
"""Get message counts and metrics."""
|
||||
result: dict[str, Any] = {
|
||||
"uptime_seconds": time.time() - self._start_time,
|
||||
}
|
||||
if self._orchestrator:
|
||||
result["total_messages"] = getattr(self._orchestrator, "_message_count", 0)
|
||||
last_ts = getattr(self._orchestrator, "_last_activity_timestamp", None)
|
||||
if last_ts is not None:
|
||||
result["last_activity"] = datetime.fromtimestamp(
|
||||
last_ts, tz=timezone.utc
|
||||
).isoformat()
|
||||
else:
|
||||
result["last_activity"] = None
|
||||
return result
|
||||
5
astrbot/_internal/tools/__init__.py
Normal file
5
astrbot/_internal/tools/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Internal tools module for AstrBot runtime."""
|
||||
|
||||
from .base import FunctionTool, ToolSet
|
||||
|
||||
__all__ = ["FunctionTool", "ToolSet"]
|
||||
332
astrbot/_internal/tools/base.py
Normal file
332
astrbot/_internal/tools/base.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""Base tool classes for AstrBot internal runtime.
|
||||
|
||||
This module provides the FunctionTool base class used by MCP tools
|
||||
in the new internal architecture.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolSchema:
|
||||
"""A class representing the schema of a tool for function calling."""
|
||||
|
||||
name: str
|
||||
"""The name of the tool."""
|
||||
|
||||
description: str
|
||||
"""The description of the tool."""
|
||||
|
||||
parameters: ParametersType = field(default_factory=dict)
|
||||
"""The parameters of the tool, in JSON Schema format."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_parameters(self) -> "ToolSchema":
|
||||
"""Validate the parameters JSON schema."""
|
||||
import jsonschema
|
||||
|
||||
jsonschema.validate(
|
||||
self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool(ToolSchema):
|
||||
"""A callable tool, for function calling."""
|
||||
|
||||
handler: Callable[..., Awaitable[str | None] | AsyncGenerator[Any, None]] | None = (
|
||||
None
|
||||
)
|
||||
"""a callable that implements the tool's functionality. It should be an async function."""
|
||||
|
||||
handler_module_path: str | None = None
|
||||
"""
|
||||
The module path of the handler function. This is empty when the origin is mcp.
|
||||
This field must be retained, as the handler will be wrapped in functools.partial during initialization,
|
||||
causing the handler's __module__ to be functools
|
||||
"""
|
||||
|
||||
active: bool = True
|
||||
"""
|
||||
Whether the tool is active. This field is a special field for AstrBot.
|
||||
You can ignore it when integrating with other frameworks.
|
||||
"""
|
||||
|
||||
is_background_task: bool = False
|
||||
"""
|
||||
Declare this tool as a background task. Background tasks return immediately
|
||||
with a task identifier while the real work continues asynchronously.
|
||||
"""
|
||||
|
||||
source: str = "mcp"
|
||||
"""
|
||||
Origin of this tool: 'plugin' (from star plugins), 'internal' (AstrBot built-in),
|
||||
or 'mcp' (from MCP servers). Used by WebUI for display grouping.
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
|
||||
|
||||
async def call(self, **kwargs: Any) -> Any:
|
||||
"""Run the tool with the given arguments. The handler field has priority."""
|
||||
raise NotImplementedError(
|
||||
"FunctionTool.call() must be implemented by subclasses or set a handler."
|
||||
)
|
||||
|
||||
|
||||
class ToolSet:
|
||||
"""
|
||||
A collection of FunctionTools grouped under a namespace.
|
||||
|
||||
ToolSets allow organizing related tools together. The LLM sees tools
|
||||
as "namespace/tool_name" when calling.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str, tools: list[FunctionTool] | None = None) -> None:
|
||||
self.namespace = namespace
|
||||
self._tools: dict[str, FunctionTool] = {}
|
||||
if tools:
|
||||
for tool in tools:
|
||||
self.add(tool)
|
||||
|
||||
def add(self, tool: FunctionTool) -> None:
|
||||
"""Add a tool to the set."""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def add_tool(self, tool: FunctionTool) -> None:
|
||||
"""Add a tool to the set (alias for add())."""
|
||||
self.add(tool)
|
||||
|
||||
def remove(self, name: str) -> FunctionTool | None:
|
||||
"""Remove and return a tool by name."""
|
||||
return self._tools.pop(name, None)
|
||||
|
||||
def remove_tool(self, name: str) -> None:
|
||||
"""Remove a tool by its name."""
|
||||
self._tools.pop(name, None)
|
||||
|
||||
def get(self, name: str) -> FunctionTool | None:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_tool(self, name: str) -> FunctionTool | None:
|
||||
"""Get a tool by name (alias for get)."""
|
||||
return self.get(name)
|
||||
|
||||
def list_tools(self) -> list[FunctionTool]:
|
||||
"""List all tools in this set."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def __iter__(self) -> Iterator[FunctionTool]:
|
||||
return iter(self._tools.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._tools)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._tools)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ToolSet(namespace={self.namespace!r}, tools={self.list_tools()!r})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ToolSet({self.namespace}, {len(self)} tools)"
|
||||
|
||||
def names(self) -> list[str]:
|
||||
"""Get names of all tools in this set."""
|
||||
return [tool.name for tool in self.tools]
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the tool set is empty."""
|
||||
return len(self) == 0
|
||||
|
||||
def merge(self, other: "ToolSet") -> None:
|
||||
"""Merge another ToolSet into this one."""
|
||||
for tool in other.tools:
|
||||
self.add(tool)
|
||||
|
||||
def normalize(self) -> None:
|
||||
"""Sort tools by name for deterministic serialization."""
|
||||
self._tools = dict(sorted(self._tools.items(), key=lambda x: x[0]))
|
||||
|
||||
def get_light_tool_set(self) -> "ToolSet":
|
||||
"""Return a light tool set with only name/description."""
|
||||
light_tools = []
|
||||
for tool in self.tools:
|
||||
if hasattr(tool, "active") and not tool.active:
|
||||
continue
|
||||
light_tools.append(
|
||||
FunctionTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters={"type": "object", "properties": {}},
|
||||
handler=None,
|
||||
)
|
||||
)
|
||||
return ToolSet("default", light_tools)
|
||||
|
||||
def get_param_only_tool_set(self) -> "ToolSet":
|
||||
"""Return a tool set with name/parameters only (no description)."""
|
||||
param_tools = []
|
||||
for tool in self.tools:
|
||||
if hasattr(tool, "active") and not tool.active:
|
||||
continue
|
||||
params = (
|
||||
copy.deepcopy(tool.parameters)
|
||||
if tool.parameters
|
||||
else {"type": "object", "properties": {}}
|
||||
)
|
||||
param_tools.append(
|
||||
FunctionTool(
|
||||
name=tool.name,
|
||||
description="",
|
||||
parameters=params,
|
||||
handler=None,
|
||||
)
|
||||
)
|
||||
return ToolSet("default", param_tools)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[FunctionTool]:
|
||||
"""List all tools in this set."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def openai_schema(
|
||||
self, omit_empty_parameter_field: bool = False
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert tools to OpenAI API function calling schema format."""
|
||||
result: list[dict[str, Any]] = []
|
||||
for tool in self._tools.values():
|
||||
func_def: dict[str, Any] = {
|
||||
"type": "function",
|
||||
"function": {"name": tool.name},
|
||||
}
|
||||
if tool.description:
|
||||
func_def["function"]["description"] = tool.description
|
||||
|
||||
if tool.parameters is not None:
|
||||
if (
|
||||
tool.parameters.get("properties")
|
||||
) or not omit_empty_parameter_field:
|
||||
func_def["function"]["parameters"] = tool.parameters
|
||||
|
||||
result.append(func_def)
|
||||
return result
|
||||
|
||||
def anthropic_schema(self) -> list[dict]:
|
||||
"""Convert tools to Anthropic API format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
input_schema: dict[str, Any] = {"type": "object"}
|
||||
if tool.parameters:
|
||||
input_schema["properties"] = tool.parameters.get("properties", {})
|
||||
input_schema["required"] = tool.parameters.get("required", [])
|
||||
tool_def: dict[str, Any] = {"name": tool.name, "input_schema": input_schema}
|
||||
if tool.description:
|
||||
tool_def["description"] = tool.description
|
||||
result.append(tool_def)
|
||||
return result
|
||||
|
||||
def google_schema(self) -> dict:
|
||||
"""Convert tools to Google GenAI API format."""
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
origin_type = schema.get("type")
|
||||
target_type = origin_type
|
||||
|
||||
if isinstance(origin_type, list):
|
||||
target_type = next((t for t in origin_type if t != "null"), "string")
|
||||
|
||||
if target_type in supported_types:
|
||||
result["type"] = target_type
|
||||
if "format" in schema and schema["format"] in supported_formats.get(result["type"], set()):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
if "additionalProperties" in prop_value:
|
||||
del prop_value["additionalProperties"]
|
||||
properties[key] = prop_value
|
||||
if properties:
|
||||
result["properties"] = properties
|
||||
|
||||
if target_type == "array":
|
||||
items_schema = schema.get("items")
|
||||
if isinstance(items_schema, dict):
|
||||
result["items"] = convert_schema(items_schema)
|
||||
else:
|
||||
result["items"] = {"type": "string"}
|
||||
|
||||
return result
|
||||
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
d: dict[str, Any] = {"name": tool.name}
|
||||
if tool.description:
|
||||
d["description"] = tool.description
|
||||
if tool.parameters:
|
||||
d["parameters"] = convert_schema(tool.parameters)
|
||||
tools_list.append(d)
|
||||
|
||||
declarations: dict[str, Any] = {}
|
||||
if tools_list:
|
||||
declarations["function_declarations"] = tools_list
|
||||
return declarations
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
|
||||
"""Get tools in OpenAI function calling style (deprecated)."""
|
||||
return self.openai_schema(omit_empty_parameter_field)
|
||||
|
||||
def get_func_desc_anthropic_style(self):
|
||||
"""Get tools in Anthropic style (deprecated)."""
|
||||
return self.anthropic_schema()
|
||||
|
||||
def get_func_desc_google_genai_style(self):
|
||||
"""Get tools in Google GenAI style (deprecated)."""
|
||||
return self.google_schema()
|
||||
48
astrbot/_internal/tools/builtin.py
Normal file
48
astrbot/_internal/tools/builtin.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Builtin tools for AstrBot - re-exports from core.tools for backward compatibility.
|
||||
|
||||
This module re-exports the builtin tools (cron, send_message, kb_query) from
|
||||
the deprecated core.tools module for backward compatibility.
|
||||
|
||||
TODO: These tools should be fully migrated to _internal and core.tools
|
||||
should be removed once all consumers update their imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Re-export cron tools
|
||||
from astrbot.core.tools.cron_tools import (
|
||||
CREATE_CRON_JOB_TOOL,
|
||||
DELETE_CRON_JOB_TOOL,
|
||||
LIST_CRON_JOBS_TOOL,
|
||||
CreateActiveCronTool,
|
||||
DeleteCronJobTool,
|
||||
ListCronJobsTool,
|
||||
)
|
||||
|
||||
# Re-export knowledge_base_query tool
|
||||
from astrbot.core.tools.kb_query import (
|
||||
KNOWLEDGE_BASE_QUERY_TOOL,
|
||||
KnowledgeBaseQueryTool,
|
||||
)
|
||||
|
||||
# Re-export send_message tool
|
||||
from astrbot.core.tools.send_message import (
|
||||
SEND_MESSAGE_TO_USER_TOOL,
|
||||
SendMessageToUserTool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Cron tools
|
||||
"CREATE_CRON_JOB_TOOL",
|
||||
"DELETE_CRON_JOB_TOOL",
|
||||
"KNOWLEDGE_BASE_QUERY_TOOL",
|
||||
"LIST_CRON_JOBS_TOOL",
|
||||
"SEND_MESSAGE_TO_USER_TOOL",
|
||||
# Classes
|
||||
"CreateActiveCronTool",
|
||||
"DeleteCronJobTool",
|
||||
"KnowledgeBaseQueryTool",
|
||||
"ListCronJobsTool",
|
||||
"SendMessageToUserTool",
|
||||
]
|
||||
278
astrbot/_internal/tools/registry.py
Normal file
278
astrbot/_internal/tools/registry.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Tools registry for AstrBot internal runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Re-export from base
|
||||
from astrbot._internal.tools.base import FunctionTool, ToolSet
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_MCP_CONFIG",
|
||||
"ENABLE_MCP_TIMEOUT_ENV",
|
||||
"FuncCall",
|
||||
"FunctionTool",
|
||||
"FunctionToolManager",
|
||||
"MCPAllServicesFailedError",
|
||||
"MCPInitError",
|
||||
"MCPInitSummary",
|
||||
"MCPInitTimeoutError",
|
||||
"MCPShutdownTimeoutError",
|
||||
"ToolSet",
|
||||
]
|
||||
|
||||
|
||||
# MCP config constants (re-exported from protocols)
|
||||
try:
|
||||
from astrbot._internal.protocols.mcp import (
|
||||
DEFAULT_MCP_CONFIG,
|
||||
MCPAllServicesFailedError,
|
||||
MCPInitError,
|
||||
MCPInitSummary,
|
||||
MCPInitTimeoutError,
|
||||
MCPShutdownTimeoutError,
|
||||
)
|
||||
except ImportError:
|
||||
DEFAULT_MCP_CONFIG: dict[str, Any] = {}
|
||||
MCPAllServicesFailedError: type[Exception] = Exception
|
||||
MCPInitError: type[Exception] = Exception
|
||||
MCPInitSummary: type[dict] = dict
|
||||
MCPInitTimeoutError: type[TimeoutError] = TimeoutError
|
||||
MCPShutdownTimeoutError: type[TimeoutError] = TimeoutError
|
||||
|
||||
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_TIMEOUT_ENABLED"
|
||||
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
|
||||
|
||||
|
||||
class FunctionToolManager:
|
||||
"""Central registry for all function tools."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._func_list: list[FunctionTool] = []
|
||||
|
||||
@property
|
||||
def func_list(self) -> list[FunctionTool]:
|
||||
"""Get the list of function tools."""
|
||||
return self._func_list
|
||||
|
||||
@func_list.setter
|
||||
def func_list(self, value: list[FunctionTool]) -> None:
|
||||
"""Set the list of function tools."""
|
||||
self._func_list = value
|
||||
|
||||
def add(self, tool: FunctionTool) -> None:
|
||||
"""Add a tool to the registry."""
|
||||
self._func_list.append(tool)
|
||||
|
||||
def remove(self, name: str) -> bool:
|
||||
"""Remove a tool by name. Returns True if found."""
|
||||
for i, f in enumerate(self._func_list):
|
||||
if f.name == name:
|
||||
self._func_list.pop(i)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_func(self, name: str) -> FunctionTool | None:
|
||||
"""Get a tool by name. Returns the last active tool if multiple match."""
|
||||
last_match: FunctionTool | None = None
|
||||
for f in reversed(self._func_list):
|
||||
if f.name == name:
|
||||
if getattr(f, "active", True):
|
||||
return f
|
||||
if last_match is None:
|
||||
last_match = f
|
||||
return last_match
|
||||
|
||||
def get_full_tool_set(self) -> ToolSet:
|
||||
"""Return a ToolSet with all active tools, deduplicated by name."""
|
||||
seen: dict[str, FunctionTool] = {}
|
||||
for tool in reversed(self._func_list):
|
||||
if tool.name not in seen and getattr(tool, "active", True):
|
||||
seen[tool.name] = tool
|
||||
return ToolSet("default", list(seen.values()))
|
||||
|
||||
def register_internal_tools(self) -> None:
|
||||
"""Register built-in computer tools (shell, python, browser, neo)."""
|
||||
# Import here to avoid circular imports
|
||||
from astrbot.core.computer.computer_tool_provider import get_all_tools
|
||||
|
||||
for tool in get_all_tools():
|
||||
if self.get_func(tool.name) is None:
|
||||
self.add(tool)
|
||||
|
||||
# MCP-related stub methods for base class compatibility
|
||||
async def enable_mcp_server(
|
||||
self, name: str, config: dict[str, Any], init_timeout: int = 30
|
||||
) -> None:
|
||||
"""Enable an MCP server (stub)."""
|
||||
pass
|
||||
|
||||
async def disable_mcp_server(
|
||||
self, name: str = "", timeout: int = 10, shutdown_timeout: int = 10
|
||||
) -> None:
|
||||
"""Disable an MCP server (stub)."""
|
||||
pass
|
||||
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""Initialize MCP clients (stub)."""
|
||||
pass
|
||||
|
||||
async def test_mcp_server_connection(
|
||||
self, config: dict[str, Any]
|
||||
) -> tuple[bool, str]:
|
||||
"""Test MCP server connection (stub)."""
|
||||
return False, "Not implemented"
|
||||
|
||||
async def sync_modelscope_mcp_servers(self) -> None:
|
||||
"""Sync ModelScope MCP servers (stub)."""
|
||||
pass
|
||||
|
||||
def load_mcp_config(self) -> dict[str, Any]:
|
||||
"""Load MCP configuration (stub)."""
|
||||
return {"mcpServers": {}}
|
||||
|
||||
def save_mcp_config(self, config: dict[str, Any]) -> bool:
|
||||
"""Save MCP configuration (stub)."""
|
||||
return True
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
"""Activate an LLM tool (stub)."""
|
||||
return True
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""Deactivate an LLM tool (stub)."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def mcp_client_dict(self) -> dict[str, Any]:
|
||||
"""Return dict of MCP clients (stub)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def mcp_server_runtime_view(self) -> dict[str, Any]:
|
||||
"""Return runtime view of MCP servers (stub)."""
|
||||
return {}
|
||||
|
||||
|
||||
class FuncCall(FunctionToolManager):
|
||||
"""Alias for FunctionToolManager for backward compatibility."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._mcp_server_runtime_view: dict[str, Any] = {}
|
||||
self._mcp_client_dict: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def mcp_server_runtime_view(self) -> dict[str, Any]:
|
||||
"""Return runtime view of MCP servers."""
|
||||
return self._mcp_server_runtime_view
|
||||
|
||||
@property
|
||||
def mcp_client_dict(self) -> dict[str, Any]:
|
||||
"""Return dict of MCP clients (for backward compatibility)."""
|
||||
return self._mcp_client_dict
|
||||
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""Initialize MCP clients (stub implementation)."""
|
||||
pass
|
||||
|
||||
def add_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list[dict[str, Any]],
|
||||
desc: str,
|
||||
handler: Any,
|
||||
) -> None:
|
||||
"""Add a function tool (deprecated, use add() instead)."""
|
||||
params: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param.get("type", "string"),
|
||||
"description": param.get("description", ""),
|
||||
}
|
||||
func = FunctionTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
)
|
||||
self.add(func)
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
"""Remove a function tool by name (deprecated, use remove() instead)."""
|
||||
self.remove(name)
|
||||
|
||||
def get_func(self, name: str) -> FunctionTool | None:
|
||||
"""Get a function tool by name."""
|
||||
return super().get_func(name)
|
||||
|
||||
def names(self) -> list[str]:
|
||||
"""Get all tool names."""
|
||||
return [f.name for f in self.func_list]
|
||||
|
||||
def remove_tool(self, name: str) -> None:
|
||||
"""Remove a tool by its name (alias for remove)."""
|
||||
self.remove(name)
|
||||
|
||||
def get_func_desc_openai_style(
|
||||
self, omit_empty_parameter_field: bool = False
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get tools in OpenAI style (deprecated, use get_full_tool_set().openai_schema())."""
|
||||
tool_set = self.get_full_tool_set()
|
||||
return tool_set.openai_schema(omit_empty_parameter_field)
|
||||
|
||||
async def enable_mcp_server(
|
||||
self, name: str, config: dict[str, Any], init_timeout: int = 30
|
||||
) -> None:
|
||||
"""Enable an MCP server (stub implementation)."""
|
||||
pass
|
||||
|
||||
async def disable_mcp_server(
|
||||
self, name: str = "", timeout: int = 10, shutdown_timeout: int = 10
|
||||
) -> None:
|
||||
"""Disable an MCP server (stub implementation)."""
|
||||
pass
|
||||
|
||||
def load_mcp_config(self) -> dict[str, Any]:
|
||||
"""Load MCP configuration (stub implementation)."""
|
||||
return {"mcpServers": {}}
|
||||
|
||||
def save_mcp_config(self, config: dict[str, Any]) -> bool:
|
||||
"""Save MCP configuration (stub implementation)."""
|
||||
return True
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
"""Activate an LLM tool (stub implementation)."""
|
||||
return True
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""Deactivate an LLM tool (stub implementation)."""
|
||||
return True
|
||||
|
||||
async def test_mcp_server_connection(
|
||||
self, config: dict[str, Any]
|
||||
) -> tuple[bool, str]:
|
||||
"""Test MCP server connection (stub implementation)."""
|
||||
# Import the actual test function if available
|
||||
try:
|
||||
from astrbot._internal.protocols.mcp.client import (
|
||||
_quick_test_mcp_connection,
|
||||
)
|
||||
|
||||
success, message = await _quick_test_mcp_connection(config)
|
||||
if not success:
|
||||
raise Exception(message)
|
||||
return success, message
|
||||
except Exception as e:
|
||||
raise Exception(f"MCP connection test failed: {e!s}") from e
|
||||
|
||||
async def sync_modelscope_mcp_servers(self) -> None:
|
||||
"""Sync ModelScope MCP servers (stub implementation)."""
|
||||
pass
|
||||
|
||||
def get_full_tool_set(self) -> ToolSet:
|
||||
"""Return a ToolSet with all active tools."""
|
||||
return ToolSet("default", [t for t in self.func_list if t.active])
|
||||
@@ -158,6 +158,7 @@ class InternalAgentSubStage(Stage):
|
||||
follow_up_capture: FollowUpCapture | None = None
|
||||
follow_up_consumed_marked = False
|
||||
follow_up_activated = False
|
||||
typing_requested = False
|
||||
try:
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
@@ -192,7 +193,11 @@ class InternalAgentSubStage(Stage):
|
||||
)
|
||||
return
|
||||
|
||||
await event.send_typing()
|
||||
try:
|
||||
typing_requested = True
|
||||
await event.send_typing()
|
||||
except Exception:
|
||||
logger.warning("send_typing failed", exc_info=True)
|
||||
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
|
||||
sdk_plugin_bridge = getattr(
|
||||
self.ctx.plugin_manager.context, "sdk_plugin_bridge", None
|
||||
@@ -424,6 +429,11 @@ class InternalAgentSubStage(Stage):
|
||||
)
|
||||
await event.send(MessageChain().message(error_text))
|
||||
finally:
|
||||
if typing_requested:
|
||||
try:
|
||||
await event.stop_typing()
|
||||
except Exception:
|
||||
logger.warning("stop_typing failed", exc_info=True)
|
||||
if follow_up_capture:
|
||||
await finalize_follow_up_capture(
|
||||
follow_up_capture,
|
||||
|
||||
@@ -298,6 +298,12 @@ class AstrMessageEvent(abc.ABC):
|
||||
默认实现为空,由具体平台按需重写。
|
||||
"""
|
||||
|
||||
async def stop_typing(self) -> None:
|
||||
"""停止输入中状态。
|
||||
|
||||
默认实现为空,由具体平台按需重写。
|
||||
"""
|
||||
|
||||
async def _pre_send(self) -> None:
|
||||
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
|
||||
|
||||
|
||||
@@ -361,6 +361,18 @@ class TelegramPlatformAdapter(Platform):
|
||||
logger.warning("Received an update without a message.")
|
||||
return None
|
||||
|
||||
def _apply_caption() -> None:
|
||||
if update.message.caption:
|
||||
message.message_str = update.message.caption
|
||||
message.message.append(Comp.Plain(message.message_str))
|
||||
if update.message.caption and update.message.caption_entities:
|
||||
for entity in update.message.caption_entities:
|
||||
if entity.type == "mention":
|
||||
name = update.message.caption[
|
||||
entity.offset + 1 : entity.offset + entity.length
|
||||
]
|
||||
message.message.append(Comp.At(qq=name, name=name))
|
||||
|
||||
message = AstrBotMessage()
|
||||
message.session_id = str(update.message.chat.id)
|
||||
|
||||
@@ -480,16 +492,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
photo = update.message.photo[-1] # get the largest photo
|
||||
file = await photo.get_file()
|
||||
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
|
||||
if update.message.caption:
|
||||
message.message_str = update.message.caption
|
||||
message.message.append(Comp.Plain(message.message_str))
|
||||
if update.message.caption_entities:
|
||||
for entity in update.message.caption_entities:
|
||||
if entity.type == "mention":
|
||||
name = message.message_str[
|
||||
entity.offset + 1 : entity.offset + entity.length
|
||||
]
|
||||
message.message.append(Comp.At(qq=name, name=name))
|
||||
_apply_caption()
|
||||
|
||||
elif update.message.sticker:
|
||||
# 将sticker当作图片处理
|
||||
@@ -512,6 +515,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
message.message.append(
|
||||
Comp.File(file=file_path, name=file_name, url=file_path)
|
||||
)
|
||||
_apply_caption()
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
@@ -523,6 +527,7 @@ class TelegramPlatformAdapter(Platform):
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.Video(file=file_path, path=file.file_path))
|
||||
_apply_caption()
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
import aiofiles
|
||||
import anyio
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from wechatpy.exceptions import WeChatClientException
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -96,7 +97,19 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
|
||||
try:
|
||||
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
|
||||
except WeChatClientException as e:
|
||||
if getattr(e, "errcode", None) == 40096:
|
||||
# 40096: invalid external userid, fallback to regular message API
|
||||
logger.warning(
|
||||
f"kf API error 40096 for user {user_id}, falling back to regular message API"
|
||||
)
|
||||
self.client.message.send_text(
|
||||
self.get_self_id(), user_id, chunk
|
||||
)
|
||||
else:
|
||||
raise
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
@@ -6,7 +6,7 @@ import hashlib
|
||||
import io
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from urllib.parse import quote
|
||||
@@ -49,6 +49,17 @@ class OpenClawLoginSession:
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypingSessionState:
|
||||
ticket: str | None = None
|
||||
ticket_context_token: str | None = None
|
||||
refresh_after: float = 0.0
|
||||
keepalive_task: asyncio.Task | None = None
|
||||
cancel_task: asyncio.Task | None = None
|
||||
owners: set[str] = field(default_factory=set)
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"weixin_oc",
|
||||
"个人微信",
|
||||
@@ -105,7 +116,16 @@ class WeixinOCAdapter(Platform):
|
||||
self._sync_buf = ""
|
||||
self._qr_expired_count = 0
|
||||
self._context_tokens: dict[str, str] = {}
|
||||
self._typing_states: dict[str, TypingSessionState] = {}
|
||||
self._last_inbound_error = ""
|
||||
self._typing_keepalive_interval_s = max(
|
||||
1,
|
||||
int(platform_config.get("weixin_oc_typing_keepalive_interval", 5)),
|
||||
)
|
||||
self._typing_ticket_ttl_s = max(
|
||||
5,
|
||||
int(platform_config.get("weixin_oc_typing_ticket_ttl", 60)),
|
||||
)
|
||||
|
||||
self.token = str(platform_config.get("weixin_oc_token", "")).strip() or None
|
||||
self.account_id = (
|
||||
@@ -132,6 +152,316 @@ class WeixinOCAdapter(Platform):
|
||||
self.client.api_timeout_ms = self.api_timeout_ms
|
||||
self.client.token = self.token
|
||||
|
||||
def _get_typing_state(self, user_id: str) -> TypingSessionState:
|
||||
state = self._typing_states.get(user_id)
|
||||
if state is None:
|
||||
state = TypingSessionState()
|
||||
self._typing_states[user_id] = state
|
||||
return state
|
||||
|
||||
def _typing_supported_for(self, user_id: str) -> bool:
|
||||
if not self.token:
|
||||
return False
|
||||
return bool(self._context_tokens.get(user_id))
|
||||
|
||||
async def _cancel_task_safely(
|
||||
self,
|
||||
task: asyncio.Task | None,
|
||||
*,
|
||||
log_message: str | None = None,
|
||||
log_args: tuple[Any, ...] = (),
|
||||
) -> None:
|
||||
if task is None or task.done():
|
||||
return
|
||||
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if log_message is not None:
|
||||
logger.warning(log_message, *log_args, exc_info=True)
|
||||
|
||||
async def _ensure_typing_ticket(
|
||||
self,
|
||||
user_id: str,
|
||||
state: TypingSessionState,
|
||||
) -> str | None:
|
||||
now = time.monotonic()
|
||||
context_token = self._context_tokens.get(user_id)
|
||||
if not context_token:
|
||||
return None
|
||||
|
||||
if (
|
||||
state.ticket
|
||||
and state.ticket_context_token == context_token
|
||||
and state.refresh_after > now
|
||||
):
|
||||
return state.ticket
|
||||
|
||||
payload = await self.client.get_typing_config(user_id, context_token)
|
||||
if int(payload.get("ret") or 0) != 0:
|
||||
logger.warning(
|
||||
"weixin_oc(%s): getconfig failed for %s: %s",
|
||||
self.meta().id,
|
||||
user_id,
|
||||
payload.get("errmsg", ""),
|
||||
)
|
||||
return None
|
||||
|
||||
ticket = str(payload.get("typing_ticket", "")).strip()
|
||||
if not ticket:
|
||||
return None
|
||||
|
||||
state.ticket = ticket
|
||||
state.ticket_context_token = context_token
|
||||
state.refresh_after = time.monotonic() + self._typing_ticket_ttl_s
|
||||
return ticket
|
||||
|
||||
async def _send_typing_state(
|
||||
self,
|
||||
user_id: str,
|
||||
ticket: str,
|
||||
*,
|
||||
cancel: bool,
|
||||
) -> None:
|
||||
payload = await self.client.send_typing_state(user_id, ticket, cancel=cancel)
|
||||
if int(payload.get("ret") or 0) != 0:
|
||||
raise RuntimeError(
|
||||
f"sendtyping failed for {user_id}: {payload.get('errmsg', '')}"
|
||||
)
|
||||
|
||||
async def _run_typing_keepalive(self, user_id: str) -> None:
|
||||
restart_needed = False
|
||||
try:
|
||||
await self._typing_keepalive_loop(user_id)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
state = self._typing_states.get(user_id)
|
||||
if state is not None:
|
||||
async with state.lock:
|
||||
state.refresh_after = 0.0
|
||||
restart_needed = (
|
||||
bool(state.owners) and not self._shutdown_event.is_set()
|
||||
)
|
||||
logger.warning(
|
||||
"weixin_oc(%s): typing keepalive failed for %s: %s",
|
||||
self.meta().id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
finally:
|
||||
state = self._typing_states.get(user_id)
|
||||
current_task = asyncio.current_task()
|
||||
if state is not None and state.keepalive_task is current_task:
|
||||
state.keepalive_task = None
|
||||
|
||||
if not restart_needed:
|
||||
return
|
||||
|
||||
await asyncio.sleep(self._typing_keepalive_interval_s)
|
||||
state = self._typing_states.get(user_id)
|
||||
if state is None or self._shutdown_event.is_set():
|
||||
return
|
||||
|
||||
async with state.lock:
|
||||
if not state.owners or state.keepalive_task is not None:
|
||||
return
|
||||
state.keepalive_task = asyncio.create_task(
|
||||
self._run_typing_keepalive(user_id)
|
||||
)
|
||||
|
||||
async def _typing_keepalive_loop(self, user_id: str) -> None:
|
||||
while not self._shutdown_event.is_set():
|
||||
await asyncio.sleep(self._typing_keepalive_interval_s)
|
||||
state = self._typing_states.get(user_id)
|
||||
if state is None:
|
||||
return
|
||||
|
||||
async with state.lock:
|
||||
if not state.owners:
|
||||
return
|
||||
try:
|
||||
ticket = await self._ensure_typing_ticket(user_id, state)
|
||||
except Exception as e:
|
||||
state.refresh_after = 0.0
|
||||
logger.warning(
|
||||
"weixin_oc(%s): refresh typing ticket failed for %s: %s",
|
||||
self.meta().id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
if not ticket:
|
||||
continue
|
||||
try:
|
||||
await self._send_typing_state(user_id, ticket, cancel=False)
|
||||
except Exception as e:
|
||||
state.refresh_after = 0.0
|
||||
logger.warning(
|
||||
"weixin_oc(%s): typing keepalive send failed for %s: %s",
|
||||
self.meta().id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
|
||||
async def _delayed_cancel_typing(self, user_id: str, ticket: str) -> None:
|
||||
await asyncio.sleep(0)
|
||||
state = self._typing_states.get(user_id)
|
||||
if state is None:
|
||||
return
|
||||
|
||||
current_task = asyncio.current_task()
|
||||
async with state.lock:
|
||||
if state.cancel_task is not current_task:
|
||||
return
|
||||
if state.owners or state.keepalive_task is not None:
|
||||
state.cancel_task = None
|
||||
return
|
||||
|
||||
try:
|
||||
await self._send_typing_state(user_id, ticket, cancel=True)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"weixin_oc(%s): cancel typing failed for %s: %s",
|
||||
self.meta().id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
finally:
|
||||
state = self._typing_states.get(user_id)
|
||||
if state is None:
|
||||
return
|
||||
async with state.lock:
|
||||
if state.cancel_task is current_task:
|
||||
state.cancel_task = None
|
||||
|
||||
async def start_typing(self, user_id: str, owner_id: str) -> None:
|
||||
state = self._get_typing_state(user_id)
|
||||
cancel_task: asyncio.Task | None = None
|
||||
async with state.lock:
|
||||
if owner_id in state.owners:
|
||||
return
|
||||
if not self._typing_supported_for(user_id):
|
||||
return
|
||||
if state.cancel_task is not None and not state.cancel_task.done():
|
||||
cancel_task = state.cancel_task
|
||||
cancel_task.cancel()
|
||||
state.cancel_task = None
|
||||
try:
|
||||
ticket = await self._ensure_typing_ticket(user_id, state)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"weixin_oc(%s): ensure typing ticket failed for %s: %s",
|
||||
self.meta().id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
return
|
||||
if not ticket:
|
||||
return
|
||||
|
||||
state.ticket = ticket
|
||||
state.owners.add(owner_id)
|
||||
if state.keepalive_task is not None and not state.keepalive_task.done():
|
||||
return
|
||||
|
||||
try:
|
||||
await self._send_typing_state(user_id, ticket, cancel=False)
|
||||
except Exception as e:
|
||||
state.refresh_after = 0.0
|
||||
logger.warning(
|
||||
"weixin_oc(%s): send typing failed for %s: %s",
|
||||
self.meta().id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(self._run_typing_keepalive(user_id))
|
||||
state.keepalive_task = task
|
||||
|
||||
if cancel_task is not None:
|
||||
await self._cancel_task_safely(
|
||||
cancel_task,
|
||||
log_message="weixin_oc(%s): ignored error from cancelled typing task",
|
||||
log_args=(self.meta().id,),
|
||||
)
|
||||
|
||||
async def stop_typing(self, user_id: str, owner_id: str) -> None:
|
||||
state = self._typing_states.get(user_id)
|
||||
if state is None:
|
||||
return
|
||||
|
||||
task: asyncio.Task | None = None
|
||||
async with state.lock:
|
||||
if owner_id not in state.owners:
|
||||
return
|
||||
state.owners.remove(owner_id)
|
||||
|
||||
if state.owners:
|
||||
return
|
||||
|
||||
task = state.keepalive_task
|
||||
state.keepalive_task = None
|
||||
|
||||
await self._cancel_task_safely(
|
||||
task,
|
||||
log_message="weixin_oc(%s): typing keepalive stop failed for %s",
|
||||
log_args=(self.meta().id, user_id),
|
||||
)
|
||||
|
||||
async with state.lock:
|
||||
if state.owners:
|
||||
return
|
||||
ticket = state.ticket
|
||||
if ticket:
|
||||
if state.cancel_task is None or state.cancel_task.done():
|
||||
state.cancel_task = asyncio.create_task(
|
||||
self._delayed_cancel_typing(user_id, ticket)
|
||||
)
|
||||
|
||||
async def _cleanup_typing_tasks(self) -> None:
|
||||
tasks: list[asyncio.Task] = []
|
||||
cancels: list[tuple[str, str]] = []
|
||||
for user_id, state in list(self._typing_states.items()):
|
||||
if state.ticket and (
|
||||
state.owners
|
||||
or state.keepalive_task is not None
|
||||
or state.cancel_task is not None
|
||||
):
|
||||
cancels.append((user_id, state.ticket))
|
||||
state.owners.clear()
|
||||
if state.keepalive_task is not None and not state.keepalive_task.done():
|
||||
tasks.append(state.keepalive_task)
|
||||
state.keepalive_task.cancel()
|
||||
state.keepalive_task = None
|
||||
if state.cancel_task is not None and not state.cancel_task.done():
|
||||
tasks.append(state.cancel_task)
|
||||
state.cancel_task.cancel()
|
||||
state.cancel_task = None
|
||||
|
||||
for task in tasks:
|
||||
await self._cancel_task_safely(
|
||||
task,
|
||||
log_message="weixin_oc(%s): typing cleanup failed",
|
||||
log_args=(self.meta().id,),
|
||||
)
|
||||
|
||||
for user_id, ticket in cancels:
|
||||
try:
|
||||
await self._send_typing_state(user_id, ticket, cancel=True)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"weixin_oc(%s): typing cleanup cancel failed for %s: %s",
|
||||
self.meta().id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
|
||||
def _load_account_state(self) -> None:
|
||||
if not self.token:
|
||||
token = str(self.config.get("weixin_oc_token", "")).strip()
|
||||
@@ -902,15 +1232,24 @@ class WeixinOCAdapter(Platform):
|
||||
"weixin_oc(%s): inbound long-poll timeout",
|
||||
self.meta().id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"weixin_oc(%s): poll inbound updates failed, will retry after 5 seconds: %s",
|
||||
self.meta().id,
|
||||
e,
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("weixin_oc(%s): run failed: %s", self.meta().id, e)
|
||||
finally:
|
||||
await self._cleanup_typing_tasks()
|
||||
await self.client.close()
|
||||
|
||||
async def terminate(self) -> None:
|
||||
self._shutdown_event.set()
|
||||
await self._cleanup_typing_tasks()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
stat = super().get_stats()
|
||||
|
||||
@@ -226,3 +226,44 @@ class WeixinOCClient:
|
||||
if not text:
|
||||
return {}
|
||||
return cast(dict[str, Any], json.loads(text))
|
||||
|
||||
async def get_typing_config(
|
||||
self,
|
||||
user_id: str,
|
||||
context_token: str,
|
||||
) -> dict[str, Any]:
|
||||
return await self.request_json(
|
||||
"POST",
|
||||
"ilink/bot/getconfig",
|
||||
payload={
|
||||
"ilink_user_id": user_id,
|
||||
"context_token": context_token,
|
||||
"base_info": {
|
||||
"channel_version": "astrbot",
|
||||
},
|
||||
},
|
||||
token_required=True,
|
||||
timeout_ms=self.api_timeout_ms,
|
||||
)
|
||||
|
||||
async def send_typing_state(
|
||||
self,
|
||||
user_id: str,
|
||||
typing_ticket: str,
|
||||
*,
|
||||
cancel: bool,
|
||||
) -> dict[str, Any]:
|
||||
return await self.request_json(
|
||||
"POST",
|
||||
"ilink/bot/sendtyping",
|
||||
payload={
|
||||
"ilink_user_id": user_id,
|
||||
"typing_ticket": typing_ticket,
|
||||
"status": 2 if cancel else 1,
|
||||
"base_info": {
|
||||
"channel_version": "astrbot",
|
||||
},
|
||||
},
|
||||
token_required=True,
|
||||
timeout_ms=self.api_timeout_ms,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -29,6 +30,12 @@ class WeixinOCMessageEvent(AstrMessageEvent):
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.platform = platform
|
||||
self._typing_owner_id: str | None = None
|
||||
|
||||
def _get_typing_owner_id(self) -> str:
|
||||
if not self._typing_owner_id:
|
||||
self._typing_owner_id = uuid.uuid4().hex
|
||||
return self._typing_owner_id
|
||||
|
||||
@staticmethod
|
||||
def _segment_to_text(segment: BaseMessageComponent) -> str:
|
||||
@@ -58,6 +65,18 @@ class WeixinOCMessageEvent(AstrMessageEvent):
|
||||
await self.platform.send_by_session(self.session, message)
|
||||
await super().send(message)
|
||||
|
||||
async def send_typing(self) -> None:
|
||||
await self.platform.start_typing(
|
||||
self.session.session_id,
|
||||
self._get_typing_owner_id(),
|
||||
)
|
||||
|
||||
async def stop_typing(self) -> None:
|
||||
await self.platform.stop_typing(
|
||||
self.session.session_id,
|
||||
self._get_typing_owner_id(),
|
||||
)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
|
||||
@@ -516,7 +516,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {**kwargs, "messages": new_messages, "model": model}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
@@ -572,7 +572,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {**kwargs, "messages": new_messages, "model": model}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
|
||||
@@ -758,7 +758,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {**kwargs, "messages": context_query, "model": model}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
@@ -813,7 +813,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {**kwargs, "messages": context_query, "model": model}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
|
||||
@@ -27,8 +27,8 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
api_base = (
|
||||
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
|
||||
.strip()
|
||||
.rstrip("/")
|
||||
.rstrip("/embeddings")
|
||||
.removesuffix("/")
|
||||
.removesuffix("/embeddings")
|
||||
)
|
||||
if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"):
|
||||
# /v4 see #5699
|
||||
|
||||
20
astrbot/rust/__init__.py
Normal file
20
astrbot/rust/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""AstrBot Rust Core module.
|
||||
|
||||
This module exposes the Rust core functionality via PyO3 bindings.
|
||||
"""
|
||||
|
||||
from ._core import (
|
||||
PyAbpClient,
|
||||
PyOrchestrator,
|
||||
cli,
|
||||
get_abp_client,
|
||||
get_orchestrator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PyAbpClient",
|
||||
"PyOrchestrator",
|
||||
"cli",
|
||||
"get_abp_client",
|
||||
"get_orchestrator",
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -173,8 +173,8 @@ export function useSessions(chatboxMode = false) {
|
||||
);
|
||||
const currentSessionDeleted = Boolean(
|
||||
currentSessionId &&
|
||||
sessionIds.includes(currentSessionId) &&
|
||||
!failedSessionIds.has(currentSessionId),
|
||||
sessionIds.includes(currentSessionId) &&
|
||||
!failedSessionIds.has(currentSessionId),
|
||||
);
|
||||
|
||||
if (currentSessionDeleted) {
|
||||
|
||||
@@ -17,6 +17,7 @@ const customizer = useCustomizerStore();
|
||||
const { locale } = useI18n();
|
||||
const route = useRoute();
|
||||
const routerLoadingStore = useRouterLoadingStore();
|
||||
const isCurrentChatRoute = computed(() => route.path === '/chat' || route.path.startsWith('/chat/'));
|
||||
|
||||
const isChatPage = computed(() => {
|
||||
return route.path.startsWith("/chat");
|
||||
@@ -127,7 +128,7 @@ onMounted(() => {
|
||||
<v-container
|
||||
fluid
|
||||
class="page-wrapper"
|
||||
:class="{ 'chat-mode-container': showChatPage }"
|
||||
:class="{ 'chat-mode-container': isCurrentChatRoute }"
|
||||
:style="{
|
||||
height: showChatPage ? '100%' : 'calc(100% - 8px)',
|
||||
padding: isChatPage || showChatPage ? '0' : undefined,
|
||||
|
||||
@@ -525,7 +525,7 @@ watch(
|
||||
},
|
||||
);
|
||||
|
||||
// Merry Christmas! 🎄
|
||||
// Merry Christmas!
|
||||
const isChristmas = computed(() => {
|
||||
const today = new Date();
|
||||
const month = today.getMonth() + 1; // getMonth() 返回 0-11
|
||||
@@ -1250,14 +1250,12 @@ const isChristmas = computed(() => {
|
||||
|
||||
.markdown-content ol {
|
||||
padding-left: 24px;
|
||||
/* Adds indentation to ordered lists */
|
||||
margin-top: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.markdown-content ul {
|
||||
padding-left: 24px;
|
||||
/* Adds indentation to unordered lists */
|
||||
margin-top: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/**
|
||||
* Persona 文件夹管理 Store
|
||||
*/
|
||||
import { defineStore } from 'pinia';
|
||||
import axios from '@/utils/request';
|
||||
import { defineStore } from "pinia";
|
||||
import axios from "@/utils/request";
|
||||
|
||||
// 类型定义
|
||||
export interface PersonaFolder {
|
||||
@@ -39,11 +39,11 @@ export interface FolderTreeNode {
|
||||
|
||||
export interface ReorderItem {
|
||||
id: string;
|
||||
type: 'persona' | 'folder';
|
||||
type: "persona" | "folder";
|
||||
sort_order: number;
|
||||
}
|
||||
|
||||
export const usePersonaStore = defineStore('persona', {
|
||||
export const usePersonaStore = defineStore("persona", {
|
||||
state: () => ({
|
||||
folderTree: [] as FolderTreeNode[],
|
||||
currentFolderId: null as string | null,
|
||||
@@ -59,9 +59,11 @@ export const usePersonaStore = defineStore('persona', {
|
||||
// 当前文件夹名称
|
||||
currentFolderName(): string {
|
||||
if (this.breadcrumbPath.length === 0) {
|
||||
return '根目录';
|
||||
return "根目录";
|
||||
}
|
||||
return this.breadcrumbPath[this.breadcrumbPath.length - 1]?.name || '根目录';
|
||||
return (
|
||||
this.breadcrumbPath[this.breadcrumbPath.length - 1]?.name || "根目录"
|
||||
);
|
||||
},
|
||||
},
|
||||
|
||||
@@ -96,11 +98,11 @@ export const usePersonaStore = defineStore('persona', {
|
||||
async loadFolderTree(): Promise<void> {
|
||||
this.treeLoading = true;
|
||||
try {
|
||||
const response = await axios.get('/api/persona/folder/tree');
|
||||
if (response.data.status === 'ok') {
|
||||
const response = await axios.get("/api/persona/folder/tree");
|
||||
if (response.data.status === "ok") {
|
||||
this.folderTree = response.data.data || [];
|
||||
} else {
|
||||
throw new Error(response.data.message || '获取文件夹树失败');
|
||||
throw new Error(response.data.message || "获取文件夹树失败");
|
||||
}
|
||||
} finally {
|
||||
this.treeLoading = false;
|
||||
@@ -117,19 +119,19 @@ export const usePersonaStore = defineStore('persona', {
|
||||
|
||||
// 并行加载子文件夹和 Persona
|
||||
const [foldersRes, personasRes] = await Promise.all([
|
||||
axios.get('/api/persona/folder/list', {
|
||||
params: { parent_id: folderId ?? '' }
|
||||
axios.get("/api/persona/folder/list", {
|
||||
params: { parent_id: folderId ?? "" },
|
||||
}),
|
||||
axios.get('/api/persona/list', {
|
||||
params: { folder_id: folderId ?? '' }
|
||||
axios.get("/api/persona/list", {
|
||||
params: { folder_id: folderId ?? "" },
|
||||
}),
|
||||
]);
|
||||
|
||||
if (foldersRes.data.status === 'ok') {
|
||||
if (foldersRes.data.status === "ok") {
|
||||
this.currentFolders = foldersRes.data.data || [];
|
||||
}
|
||||
|
||||
if (personasRes.data.status === 'ok') {
|
||||
if (personasRes.data.status === "ok") {
|
||||
this.currentPersonas = personasRes.data.data || [];
|
||||
}
|
||||
|
||||
@@ -179,41 +181,41 @@ export const usePersonaStore = defineStore('persona', {
|
||||
/**
|
||||
* 移动 Persona 到文件夹
|
||||
*/
|
||||
async movePersonaToFolder(personaId: string, targetFolderId: string | null): Promise<void> {
|
||||
const response = await axios.post('/api/persona/move', {
|
||||
async movePersonaToFolder(
|
||||
personaId: string,
|
||||
targetFolderId: string | null,
|
||||
): Promise<void> {
|
||||
const response = await axios.post("/api/persona/move", {
|
||||
persona_id: personaId,
|
||||
folder_id: targetFolderId
|
||||
folder_id: targetFolderId,
|
||||
});
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '移动人格失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "移动人格失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容和文件夹树
|
||||
await Promise.all([
|
||||
this.refreshCurrentFolder(),
|
||||
this.loadFolderTree(),
|
||||
]);
|
||||
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
|
||||
},
|
||||
|
||||
/**
|
||||
* 移动文件夹到另一个文件夹
|
||||
*/
|
||||
async moveFolderToFolder(folderId: string, targetParentId: string | null): Promise<void> {
|
||||
const response = await axios.post('/api/persona/folder/update', {
|
||||
async moveFolderToFolder(
|
||||
folderId: string,
|
||||
targetParentId: string | null,
|
||||
): Promise<void> {
|
||||
const response = await axios.post("/api/persona/folder/update", {
|
||||
folder_id: folderId,
|
||||
parent_id: targetParentId
|
||||
parent_id: targetParentId,
|
||||
});
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '移动文件夹失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "移动文件夹失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容和文件夹树
|
||||
await Promise.all([
|
||||
this.refreshCurrentFolder(),
|
||||
this.loadFolderTree(),
|
||||
]);
|
||||
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
|
||||
},
|
||||
|
||||
/**
|
||||
@@ -224,20 +226,17 @@ export const usePersonaStore = defineStore('persona', {
|
||||
parent_id?: string | null;
|
||||
description?: string;
|
||||
}): Promise<PersonaFolder> {
|
||||
const response = await axios.post('/api/persona/folder/create', {
|
||||
const response = await axios.post("/api/persona/folder/create", {
|
||||
...data,
|
||||
parent_id: data.parent_id ?? this.currentFolderId,
|
||||
});
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '创建文件夹失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "创建文件夹失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容和文件夹树
|
||||
await Promise.all([
|
||||
this.refreshCurrentFolder(),
|
||||
this.loadFolderTree(),
|
||||
]);
|
||||
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
|
||||
|
||||
return response.data.data.folder;
|
||||
},
|
||||
@@ -250,48 +249,42 @@ export const usePersonaStore = defineStore('persona', {
|
||||
name?: string;
|
||||
description?: string;
|
||||
}): Promise<void> {
|
||||
const response = await axios.post('/api/persona/folder/update', data);
|
||||
const response = await axios.post("/api/persona/folder/update", data);
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '更新文件夹失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "更新文件夹失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容和文件夹树
|
||||
await Promise.all([
|
||||
this.refreshCurrentFolder(),
|
||||
this.loadFolderTree(),
|
||||
]);
|
||||
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
|
||||
},
|
||||
|
||||
/**
|
||||
* 删除文件夹
|
||||
*/
|
||||
async deleteFolder(folderId: string): Promise<void> {
|
||||
const response = await axios.post('/api/persona/folder/delete', {
|
||||
folder_id: folderId
|
||||
const response = await axios.post("/api/persona/folder/delete", {
|
||||
folder_id: folderId,
|
||||
});
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '删除文件夹失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "删除文件夹失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容和文件夹树
|
||||
await Promise.all([
|
||||
this.refreshCurrentFolder(),
|
||||
this.loadFolderTree(),
|
||||
]);
|
||||
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
|
||||
},
|
||||
|
||||
/**
|
||||
* 删除 Persona
|
||||
*/
|
||||
async deletePersona(personaId: string): Promise<void> {
|
||||
const response = await axios.post('/api/persona/delete', {
|
||||
persona_id: personaId
|
||||
const response = await axios.post("/api/persona/delete", {
|
||||
persona_id: personaId,
|
||||
});
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '删除人格失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "删除人格失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容
|
||||
@@ -301,14 +294,17 @@ export const usePersonaStore = defineStore('persona', {
|
||||
/**
|
||||
* 克隆 Persona
|
||||
*/
|
||||
async clonePersona(sourcePersonaId: string, newPersonaId: string): Promise<Persona> {
|
||||
const response = await axios.post('/api/persona/clone', {
|
||||
async clonePersona(
|
||||
sourcePersonaId: string,
|
||||
newPersonaId: string,
|
||||
): Promise<Persona> {
|
||||
const response = await axios.post("/api/persona/clone", {
|
||||
source_persona_id: sourcePersonaId,
|
||||
new_persona_id: newPersonaId
|
||||
new_persona_id: newPersonaId,
|
||||
});
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '克隆人格失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "克隆人格失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容
|
||||
@@ -321,10 +317,10 @@ export const usePersonaStore = defineStore('persona', {
|
||||
* 批量更新排序
|
||||
*/
|
||||
async reorderItems(items: ReorderItem[]): Promise<void> {
|
||||
const response = await axios.post('/api/persona/reorder', { items });
|
||||
const response = await axios.post("/api/persona/reorder", { items });
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '更新排序失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "更新排序失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容
|
||||
@@ -354,7 +350,7 @@ export const usePersonaStore = defineStore('persona', {
|
||||
* 导入人格数据
|
||||
*/
|
||||
async importPersona(data: Partial<Persona>): Promise<Persona> {
|
||||
const response = await axios.post('/api/persona/create', {
|
||||
const response = await axios.post("/api/persona/create", {
|
||||
persona_id: data.persona_id,
|
||||
system_prompt: data.system_prompt,
|
||||
begin_dialogs: data.begin_dialogs || [],
|
||||
@@ -362,8 +358,8 @@ export const usePersonaStore = defineStore('persona', {
|
||||
skills: data.skills,
|
||||
});
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message || '导入人格失败');
|
||||
if (response.data.status !== "ok") {
|
||||
throw new Error(response.data.message || "导入人格失败");
|
||||
}
|
||||
|
||||
// 刷新当前文件夹内容
|
||||
@@ -371,5 +367,5 @@ export const usePersonaStore = defineStore('persona', {
|
||||
|
||||
return response.data.data.persona;
|
||||
},
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<template>
|
||||
<v-card
|
||||
class="persona-card"
|
||||
:class="{ 'dragging': isDragging }"
|
||||
:class="{ dragging: isDragging }"
|
||||
rounded="lg"
|
||||
elevation="1"
|
||||
hover
|
||||
@@ -27,50 +27,38 @@
|
||||
<v-list density="compact">
|
||||
<v-list-item @click.stop="$emit('edit')">
|
||||
<template #prepend>
|
||||
<v-icon size="small">
|
||||
mdi-pencil
|
||||
</v-icon>
|
||||
<v-icon size="small"> mdi-pencil </v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ tm('buttons.edit') }}</v-list-item-title>
|
||||
<v-list-item-title>{{ tm("buttons.edit") }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item @click.stop="$emit('clone')">
|
||||
<template #prepend>
|
||||
<v-icon size="small">
|
||||
mdi-content-copy
|
||||
</v-icon>
|
||||
<v-icon size="small"> mdi-content-copy </v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ tm('buttons.clone') }}</v-list-item-title>
|
||||
<v-list-item-title>{{ tm("buttons.clone") }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item @click.stop="$emit('move')">
|
||||
<template #prepend>
|
||||
<v-icon size="small">
|
||||
mdi-folder-move
|
||||
</v-icon>
|
||||
<v-icon size="small"> mdi-folder-move </v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ tm('persona.contextMenu.moveTo') }}</v-list-item-title>
|
||||
<v-list-item-title>{{
|
||||
tm("persona.contextMenu.moveTo")
|
||||
}}</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item @click.stop="$emit('export')">
|
||||
<template #prepend>
|
||||
<v-icon size="small">
|
||||
mdi-download
|
||||
</v-icon>
|
||||
<v-icon size="small"> mdi-download </v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ tm('persona.contextMenu.export') }}</v-list-item-title>
|
||||
<v-list-item-title>{{
|
||||
tm("persona.contextMenu.export")
|
||||
}}</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-divider class="my-1" />
|
||||
<v-list-item
|
||||
class="text-error"
|
||||
@click.stop="$emit('delete')"
|
||||
>
|
||||
<v-list-item class="text-error" @click.stop="$emit('delete')">
|
||||
<template #prepend>
|
||||
<v-icon
|
||||
size="small"
|
||||
color="error"
|
||||
>
|
||||
mdi-delete
|
||||
</v-icon>
|
||||
<v-icon size="small" color="error"> mdi-delete </v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ tm('buttons.delete') }}</v-list-item-title>
|
||||
<v-list-item-title>{{ tm("buttons.delete") }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-menu>
|
||||
@@ -89,7 +77,11 @@
|
||||
variant="tonal"
|
||||
prepend-icon="mdi-chat"
|
||||
>
|
||||
{{ tm('labels.presetDialogs', { count: persona.begin_dialogs.length / 2 }) }}
|
||||
{{
|
||||
tm("labels.presetDialogs", {
|
||||
count: persona.begin_dialogs.length / 2,
|
||||
})
|
||||
}}
|
||||
</v-chip>
|
||||
<v-chip
|
||||
v-if="persona.tools === null"
|
||||
@@ -98,7 +90,7 @@
|
||||
variant="tonal"
|
||||
prepend-icon="mdi-tools"
|
||||
>
|
||||
{{ tm('form.allToolsAvailable') }}
|
||||
{{ tm("form.allToolsAvailable") }}
|
||||
</v-chip>
|
||||
<v-chip
|
||||
v-else-if="persona.tools && persona.tools.length > 0"
|
||||
@@ -107,7 +99,7 @@
|
||||
variant="tonal"
|
||||
prepend-icon="mdi-tools"
|
||||
>
|
||||
{{ persona.tools.length }} {{ tm('persona.toolsCount') }}
|
||||
{{ persona.tools.length }} {{ tm("persona.toolsCount") }}
|
||||
</v-chip>
|
||||
<v-chip
|
||||
v-if="persona.skills === null"
|
||||
@@ -116,7 +108,7 @@
|
||||
variant="tonal"
|
||||
prepend-icon="mdi-lightning-bolt"
|
||||
>
|
||||
{{ tm('form.allSkillsAvailable') }}
|
||||
{{ tm("form.allSkillsAvailable") }}
|
||||
</v-chip>
|
||||
<v-chip
|
||||
v-else-if="persona.skills && persona.skills.length > 0"
|
||||
@@ -125,142 +117,139 @@
|
||||
variant="tonal"
|
||||
prepend-icon="mdi-lightning-bolt"
|
||||
>
|
||||
{{ persona.skills.length }} {{ tm('persona.skillsCount') }}
|
||||
{{ persona.skills.length }} {{ tm("persona.skillsCount") }}
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<div class="mt-3 text-caption text-medium-emphasis">
|
||||
{{ tm('labels.createdAt') }}: {{ formatDate(persona.created_at) }}
|
||||
{{ tm("labels.createdAt") }}: {{ formatDate(persona.created_at) }}
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
<!-- Custom Drag Preview -->
|
||||
<div
|
||||
ref="dragPreview"
|
||||
class="drag-preview"
|
||||
>
|
||||
<v-icon
|
||||
size="small"
|
||||
class="mr-2"
|
||||
>
|
||||
mdi-account
|
||||
</v-icon>
|
||||
<div ref="dragPreview" class="drag-preview">
|
||||
<v-icon size="small" class="mr-2"> mdi-account </v-icon>
|
||||
<span class="text-subtitle-2">{{ persona.persona_id }}</span>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script lang="ts">
|
||||
import { defineComponent, type PropType } from 'vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import { defineComponent, type PropType } from "vue";
|
||||
import { useModuleI18n } from "@/i18n/composables";
|
||||
|
||||
interface Persona {
|
||||
persona_id: string;
|
||||
system_prompt: string;
|
||||
custom_error_message?: string | null;
|
||||
begin_dialogs?: string[] | null;
|
||||
tools?: string[] | null;
|
||||
skills?: string[] | null;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
folder_id?: string | null;
|
||||
[key: string]: any;
|
||||
persona_id: string;
|
||||
system_prompt: string;
|
||||
custom_error_message?: string | null;
|
||||
begin_dialogs?: string[] | null;
|
||||
tools?: string[] | null;
|
||||
skills?: string[] | null;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
folder_id?: string | null;
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
export default defineComponent({
|
||||
name: 'PersonaCard',
|
||||
props: {
|
||||
persona: {
|
||||
type: Object as PropType<Persona>,
|
||||
required: true
|
||||
}
|
||||
name: "PersonaCard",
|
||||
props: {
|
||||
persona: {
|
||||
type: Object as PropType<Persona>,
|
||||
required: true,
|
||||
},
|
||||
emits: ['view', 'edit', 'clone', 'move', 'export', 'delete'],
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/persona');
|
||||
return { tm };
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
isDragging: false
|
||||
};
|
||||
},
|
||||
methods: {
|
||||
handleDragStart(event: DragEvent) {
|
||||
this.isDragging = true;
|
||||
if (event.dataTransfer) {
|
||||
event.dataTransfer.effectAllowed = 'move';
|
||||
event.dataTransfer.setData('application/json', JSON.stringify({
|
||||
type: 'persona',
|
||||
persona_id: this.persona.persona_id,
|
||||
persona: this.persona
|
||||
}));
|
||||
},
|
||||
emits: ["view", "edit", "clone", "move", "export", "delete"],
|
||||
setup() {
|
||||
const { tm } = useModuleI18n("features/persona");
|
||||
return { tm };
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
isDragging: false,
|
||||
};
|
||||
},
|
||||
methods: {
|
||||
handleDragStart(event: DragEvent) {
|
||||
this.isDragging = true;
|
||||
if (event.dataTransfer) {
|
||||
event.dataTransfer.effectAllowed = "move";
|
||||
event.dataTransfer.setData(
|
||||
"application/json",
|
||||
JSON.stringify({
|
||||
type: "persona",
|
||||
persona_id: this.persona.persona_id,
|
||||
persona: this.persona,
|
||||
}),
|
||||
);
|
||||
|
||||
// Set custom drag image
|
||||
const dragPreview = this.$refs.dragPreview as HTMLElement;
|
||||
if (dragPreview) {
|
||||
event.dataTransfer.setDragImage(dragPreview, 15, 15);
|
||||
}
|
||||
}
|
||||
},
|
||||
handleDragEnd() {
|
||||
this.isDragging = false;
|
||||
},
|
||||
truncateText(text: string | undefined | null, maxLength: number): string {
|
||||
if (!text) return '';
|
||||
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
|
||||
},
|
||||
formatDate(dateString: string | undefined | null): string {
|
||||
if (!dateString) return '';
|
||||
return new Date(dateString).toLocaleString();
|
||||
// Set custom drag image
|
||||
const dragPreview = this.$refs.dragPreview as HTMLElement;
|
||||
if (dragPreview) {
|
||||
event.dataTransfer.setDragImage(dragPreview, 15, 15);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
handleDragEnd() {
|
||||
this.isDragging = false;
|
||||
},
|
||||
truncateText(text: string | undefined | null, maxLength: number): string {
|
||||
if (!text) return "";
|
||||
return text.length > maxLength
|
||||
? text.substring(0, maxLength) + "..."
|
||||
: text;
|
||||
},
|
||||
formatDate(dateString: string | undefined | null): string {
|
||||
if (!dateString) return "";
|
||||
return new Date(dateString).toLocaleString();
|
||||
},
|
||||
},
|
||||
});
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.persona-card {
|
||||
height: 100%;
|
||||
cursor: grab;
|
||||
transition: all 0.2s ease;
|
||||
height: 100%;
|
||||
cursor: grab;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.persona-card:active {
|
||||
cursor: grabbing;
|
||||
cursor: grabbing;
|
||||
}
|
||||
|
||||
.persona-card.dragging {
|
||||
opacity: 0.5;
|
||||
transform: scale(0.95);
|
||||
opacity: 0.5;
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.persona-card:hover {
|
||||
transform: translateY(-2px);
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.system-prompt-preview {
|
||||
font-size: 14px;
|
||||
line-height: 1.4;
|
||||
color: rgba(var(--v-theme-on-surface), 0.7);
|
||||
overflow: hidden;
|
||||
display: -webkit-box;
|
||||
-webkit-line-clamp: 3;
|
||||
line-clamp: 3;
|
||||
-webkit-box-orient: vertical;
|
||||
font-size: 14px;
|
||||
line-height: 1.4;
|
||||
color: rgba(var(--v-theme-on-surface), 0.7);
|
||||
overflow: hidden;
|
||||
display: -webkit-box;
|
||||
-webkit-line-clamp: 3;
|
||||
line-clamp: 3;
|
||||
-webkit-box-orient: vertical;
|
||||
}
|
||||
|
||||
.drag-preview {
|
||||
position: fixed;
|
||||
top: -1000px;
|
||||
left: -1000px;
|
||||
background: rgb(var(--v-theme-surface));
|
||||
padding: 12px 20px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
border: 1px solid rgba(var(--v-border-color), var(--v-border-opacity));
|
||||
z-index: 9999;
|
||||
pointer-events: none;
|
||||
position: fixed;
|
||||
top: -1000px;
|
||||
left: -1000px;
|
||||
background: rgb(var(--v-theme-surface));
|
||||
padding: 12px 20px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
border: 1px solid rgba(var(--v-border-color), var(--v-border-opacity));
|
||||
z-index: 9999;
|
||||
pointer-events: none;
|
||||
}
|
||||
</style>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
3
tests/fixtures/mocks/telegram.py
vendored
3
tests/fixtures/mocks/telegram.py
vendored
@@ -33,7 +33,8 @@ def create_mock_telegram_modules():
|
||||
|
||||
mock_telegram_ext = MagicMock()
|
||||
mock_telegram_ext.ApplicationBuilder = MagicMock
|
||||
mock_telegram_ext.ContextTypes = MagicMock
|
||||
mock_telegram_ext.ContextTypes = MagicMock()
|
||||
mock_telegram_ext.ContextTypes.DEFAULT_TYPE = MagicMock
|
||||
mock_telegram_ext.ExtBot = MagicMock
|
||||
mock_telegram_ext.filters = MagicMock()
|
||||
mock_telegram_ext.filters.ALL = MagicMock()
|
||||
|
||||
108
tests/test_telegram_adapter.py
Normal file
108
tests/test_telegram_adapter.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from tests.fixtures.helpers import (
|
||||
create_mock_file,
|
||||
create_mock_update,
|
||||
make_platform_config,
|
||||
)
|
||||
from tests.fixtures.mocks.telegram import create_mock_telegram_modules
|
||||
|
||||
_TELEGRAM_PLATFORM_ADAPTER = None
|
||||
|
||||
|
||||
def _load_telegram_adapter():
|
||||
global _TELEGRAM_PLATFORM_ADAPTER
|
||||
if _TELEGRAM_PLATFORM_ADAPTER is not None:
|
||||
return _TELEGRAM_PLATFORM_ADAPTER
|
||||
|
||||
mocks = create_mock_telegram_modules()
|
||||
patched_modules = {
|
||||
"telegram": mocks["telegram"],
|
||||
"telegram.constants": mocks["telegram"].constants,
|
||||
"telegram.error": mocks["telegram"].error,
|
||||
"telegram.ext": mocks["telegram.ext"],
|
||||
"telegramify_markdown": mocks["telegramify_markdown"],
|
||||
"apscheduler": mocks["apscheduler"],
|
||||
"apscheduler.schedulers": mocks["apscheduler"].schedulers,
|
||||
"apscheduler.schedulers.asyncio": mocks["apscheduler"].schedulers.asyncio,
|
||||
"apscheduler.schedulers.background": mocks["apscheduler"].schedulers.background,
|
||||
}
|
||||
with patch.dict(sys.modules, patched_modules):
|
||||
sys.modules.pop("astrbot.core.platform.sources.telegram.tg_adapter", None)
|
||||
module = importlib.import_module("astrbot.core.platform.sources.telegram.tg_adapter")
|
||||
_TELEGRAM_PLATFORM_ADAPTER = module.TelegramPlatformAdapter
|
||||
return _TELEGRAM_PLATFORM_ADAPTER
|
||||
|
||||
|
||||
def _build_context() -> MagicMock:
|
||||
context = MagicMock()
|
||||
context.bot.username = "test_bot"
|
||||
context.bot.id = 12345678
|
||||
return context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_telegram_document_caption_populates_message_text_and_plain():
|
||||
TelegramPlatformAdapter = _load_telegram_adapter()
|
||||
adapter = TelegramPlatformAdapter(
|
||||
make_platform_config("telegram"),
|
||||
{},
|
||||
asyncio.Queue(),
|
||||
)
|
||||
document = create_mock_file("https://api.telegram.org/file/test/report.md")
|
||||
document.file_name = "report.md"
|
||||
mention = MagicMock(type="mention", offset=0, length=6)
|
||||
update = create_mock_update(
|
||||
message_text=None,
|
||||
document=document,
|
||||
caption="@alice 请总结这份文档",
|
||||
caption_entities=[mention],
|
||||
)
|
||||
|
||||
result = await adapter.convert_message(update, _build_context())
|
||||
|
||||
assert result is not None
|
||||
assert result.message_str == "@alice 请总结这份文档"
|
||||
assert any(isinstance(component, Comp.File) for component in result.message)
|
||||
assert any(
|
||||
isinstance(component, Comp.Plain)
|
||||
and component.text == "@alice 请总结这份文档"
|
||||
for component in result.message
|
||||
)
|
||||
assert any(
|
||||
isinstance(component, Comp.At) and component.qq == "alice"
|
||||
for component in result.message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_telegram_video_caption_populates_message_text_and_plain():
|
||||
TelegramPlatformAdapter = _load_telegram_adapter()
|
||||
adapter = TelegramPlatformAdapter(
|
||||
make_platform_config("telegram"),
|
||||
{},
|
||||
asyncio.Queue(),
|
||||
)
|
||||
video = create_mock_file("https://api.telegram.org/file/test/lesson.mp4")
|
||||
video.file_name = "lesson.mp4"
|
||||
update = create_mock_update(
|
||||
message_text=None,
|
||||
video=video,
|
||||
caption="这段视频讲了什么",
|
||||
)
|
||||
|
||||
result = await adapter.convert_message(update, _build_context())
|
||||
|
||||
assert result is not None
|
||||
assert result.message_str == "这段视频讲了什么"
|
||||
assert any(isinstance(component, Comp.Video) for component in result.message)
|
||||
assert any(
|
||||
isinstance(component, Comp.Plain) and component.text == "这段视频讲了什么"
|
||||
for component in result.message
|
||||
)
|
||||
@@ -651,6 +651,15 @@ class TestSendTyping:
|
||||
await astr_message_event.send_typing()
|
||||
|
||||
|
||||
class TestStopTyping:
|
||||
"""Tests for stop_typing method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_typing_default_empty(self, astr_message_event):
|
||||
"""Test stop_typing default implementation is empty."""
|
||||
await astr_message_event.stop_typing()
|
||||
|
||||
|
||||
class TestReact:
|
||||
"""Tests for react method."""
|
||||
|
||||
@@ -772,10 +781,12 @@ class TestDefensiveGetattr:
|
||||
|
||||
def test_get_message_type_with_non_enum_type(self, astr_message_event):
|
||||
"""get_message_type should handle message_obj.type that is not a MessageType."""
|
||||
|
||||
class DummyMessage:
|
||||
def __init__(self):
|
||||
self.type = "not_an_enum"
|
||||
self.message = []
|
||||
|
||||
astr_message_event.message_obj = DummyMessage()
|
||||
message_type = astr_message_event.get_message_type()
|
||||
assert isinstance(message_type, MessageType)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -22,113 +22,184 @@ class FakeReader(ByteReceiveStream):
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lsp_read_responses_failure_disconnects_and_logs():
|
||||
"""Test reader failures are handled inside _read_responses."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
client._reader = FakeReader(AsyncMock(side_effect=RuntimeError("reader crashed")))
|
||||
class TestAstrbotLspClientInitialState:
|
||||
"""Test LSP client initial state."""
|
||||
|
||||
with patch("astrbot._internal.protocols.lsp.client.log") as mock_log:
|
||||
await client._read_responses()
|
||||
|
||||
assert client.connected is False
|
||||
mock_log.error.assert_called_once()
|
||||
def test_client_initial_state(self) -> None:
|
||||
"""Test client starts disconnected."""
|
||||
client = AstrbotLspClient()
|
||||
assert client.connected is False
|
||||
assert client._reader is None
|
||||
assert client._writer is None
|
||||
assert client._task_group is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lsp_read_responses_unexpected_exit_disconnects_and_warns():
|
||||
"""Test non-cancelled reader exit updates connection state."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
client._reader = FakeReader(AsyncMock(return_value=b""))
|
||||
class TestAstrbotLspClientConnect:
|
||||
"""Test LSP client connect method."""
|
||||
|
||||
with patch("astrbot._internal.protocols.lsp.client.log") as mock_log:
|
||||
await client._read_responses()
|
||||
|
||||
assert client.connected is False
|
||||
mock_log.warning.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lsp_read_responses_clears_reader_task_reference_on_exit():
|
||||
"""Test _read_responses clears the stored task reference when it exits."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
client._reader = FakeReader(AsyncMock(return_value=b""))
|
||||
|
||||
task = asyncio.create_task(client._read_responses())
|
||||
client._reader_task = task
|
||||
|
||||
await task
|
||||
|
||||
assert client._reader_task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lsp_stop_reader_task_swallows_failed_reader_exceptions():
|
||||
"""Test reader teardown does not re-raise prior reader failures."""
|
||||
client = AstrbotLspClient()
|
||||
|
||||
async def fail_reader() -> None:
|
||||
raise RuntimeError("reader crashed")
|
||||
|
||||
client._reader_task = asyncio.create_task(fail_reader())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await client._stop_reader_task()
|
||||
assert client._reader_task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lsp_connect_to_server_cancels_previous_reader_task_before_restart():
|
||||
"""Test reconnect tears down an existing reader task before replacing it."""
|
||||
client = AstrbotLspClient()
|
||||
fake_process = SimpleNamespace(stdout=MagicMock(), stdin=MagicMock())
|
||||
first_reader_cancelled = asyncio.Event()
|
||||
first_reader_started = asyncio.Event()
|
||||
|
||||
async def first_reader() -> None:
|
||||
first_reader_started.set()
|
||||
try:
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError:
|
||||
first_reader_cancelled.set()
|
||||
raise
|
||||
|
||||
with (
|
||||
patch(
|
||||
"astrbot._internal.protocols.lsp.client.anyio.open_process",
|
||||
AsyncMock(return_value=fake_process),
|
||||
),
|
||||
patch.object(client, "send_request", AsyncMock(return_value={})),
|
||||
patch.object(client, "send_notification", AsyncMock()),
|
||||
):
|
||||
client._read_responses = first_reader # type: ignore[method-assign]
|
||||
await client.connect_to_server(["python", "first_lsp.py"], "file:///tmp")
|
||||
await asyncio.wait_for(first_reader_started.wait(), timeout=1)
|
||||
assert client.connected is True
|
||||
|
||||
second_reader = AsyncMock(return_value=None)
|
||||
client._read_responses = second_reader # type: ignore[method-assign]
|
||||
await client.connect_to_server(["python", "second_lsp.py"], "file:///tmp")
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert first_reader_cancelled.is_set() is True
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_sets_connected_true(self) -> None:
|
||||
"""Test connect() sets connected state."""
|
||||
client = AstrbotLspClient()
|
||||
await client.connect()
|
||||
assert client.connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lsp_stop_reader_task_does_not_await_current_task():
|
||||
"""Test stopping the reader from within itself does not self-await."""
|
||||
client = AstrbotLspClient()
|
||||
done = asyncio.Event()
|
||||
class TestAstrbotLspClientSendRequest:
|
||||
"""Test LSP client send_request method."""
|
||||
|
||||
async def stop_self() -> None:
|
||||
client._reader_task = asyncio.current_task()
|
||||
await client._stop_reader_task()
|
||||
done.set()
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_request_requires_connection(self) -> None:
|
||||
"""Test send_request raises when not connected."""
|
||||
client = AstrbotLspClient()
|
||||
with pytest.raises(RuntimeError, match="not connected"):
|
||||
await client.send_request("initialize", {})
|
||||
|
||||
task = asyncio.create_task(stop_self())
|
||||
await asyncio.wait_for(done.wait(), timeout=1)
|
||||
await task
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_request_formats_jsonrpc_correctly(self) -> None:
|
||||
"""Test send_request formats message as JSON-RPC 2.0."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
mock_writer = AsyncMock()
|
||||
mock_event = MagicMock()
|
||||
mock_event.wait = AsyncMock()
|
||||
client._writer = mock_writer
|
||||
client._pending_requests[0] = AsyncMock()
|
||||
|
||||
with patch("astrbot._internal.protocols.lsp.client.anyio.Event", return_value=mock_event):
|
||||
# Timeout immediately to avoid hanging
|
||||
with pytest.raises(TimeoutError, match="timed out"):
|
||||
await client.send_request("initialize", {"processId": None})
|
||||
|
||||
|
||||
class TestAstrbotLspClientSendNotification:
|
||||
"""Test LSP client send_notification method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_notification_requires_connection(self) -> None:
|
||||
"""Test send_notification raises when not connected."""
|
||||
client = AstrbotLspClient()
|
||||
with pytest.raises(RuntimeError, match="not connected"):
|
||||
await client.send_notification("initialized", {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_notification_formats_jsonrpc_correctly(self) -> None:
|
||||
"""Test send_notification formats message as JSON-RPC 2.0."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
mock_writer = AsyncMock()
|
||||
client._writer = mock_writer
|
||||
|
||||
await client.send_notification("initialized", {})
|
||||
|
||||
mock_writer.send.assert_called_once()
|
||||
data = mock_writer.send.call_args[0][0]
|
||||
decoded = data.decode("utf-8")
|
||||
assert "Content-Length:" in decoded
|
||||
assert '"jsonrpc": "2.0"' in decoded
|
||||
assert '"method": "initialized"' in decoded
|
||||
|
||||
|
||||
class TestAstrbotLspClientShutdown:
|
||||
"""Test LSP client shutdown method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_sets_connected_false(self) -> None:
|
||||
"""Test shutdown disconnects the client."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
client._task_group = MagicMock()
|
||||
client._task_group.__aexit__ = AsyncMock()
|
||||
client._server_process = MagicMock()
|
||||
client._server_process.terminate = MagicMock()
|
||||
client._server_process.wait = AsyncMock()
|
||||
client._server_process.kill = MagicMock()
|
||||
client.send_notification = AsyncMock()
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
assert client.connected is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_clears_pending_requests(self) -> None:
|
||||
"""Test shutdown clears pending requests."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
client._pending_requests[1] = AsyncMock()
|
||||
client._task_group = MagicMock()
|
||||
client._task_group.__aexit__ = AsyncMock()
|
||||
client._server_process = None
|
||||
|
||||
await client.shutdown()
|
||||
|
||||
assert len(client._pending_requests) == 0
|
||||
|
||||
|
||||
class TestAstrbotLspClientReadResponses:
|
||||
"""Test LSP client _read_responses method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_responses_returns_immediately_if_no_reader(self) -> None:
|
||||
"""Test _read_responses exits early when _reader is None."""
|
||||
client = AstrbotLspClient()
|
||||
client._reader = None
|
||||
client._connected = True
|
||||
|
||||
await client._read_responses()
|
||||
|
||||
# Should return without error
|
||||
assert True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_responses_handles_empty_data_as_eof(self) -> None:
|
||||
"""Test _read_responses breaks on empty data (EOF)."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
client._reader = FakeReader(AsyncMock(return_value=b""))
|
||||
client._pending_requests = {}
|
||||
|
||||
# Should exit cleanly without raising
|
||||
await client._read_responses()
|
||||
|
||||
assert client._connected is True # Note: current impl doesn't auto-disconnect on EOF
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_responses_parses_jsonrpc_response(self) -> None:
|
||||
"""Test _read_responses parses and dispatches JSON-RPC responses."""
|
||||
client = AstrbotLspClient()
|
||||
client._connected = True
|
||||
|
||||
response = {"jsonrpc": "2.0", "id": 0, "result": {}}
|
||||
content = json.dumps(response).encode()
|
||||
header = f"Content-Length: {len(content)}\r\n\r\n".encode()
|
||||
|
||||
# First call returns the message, second call returns empty (EOF)
|
||||
fake_reader = FakeReader(AsyncMock(side_effect=[header + content, b""]))
|
||||
client._reader = fake_reader
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def handler(resp: dict) -> None:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
|
||||
client._pending_requests[0] = handler
|
||||
|
||||
await client._read_responses()
|
||||
|
||||
assert handler_called is True
|
||||
|
||||
|
||||
class TestAstrbotLspClientHandleNotification:
|
||||
"""Test LSP client _handle_notification method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_notification_logs_method_name(self) -> None:
|
||||
"""Test _handle_notification logs the notification method."""
|
||||
client = AstrbotLspClient()
|
||||
notification = {"jsonrpc": "2.0", "method": "window/showMessage", "params": {}}
|
||||
|
||||
with patch("astrbot._internal.protocols.lsp.client.log") as mock_log:
|
||||
await client._handle_notification(notification)
|
||||
|
||||
mock_log.debug.assert_called()
|
||||
|
||||
Reference in New Issue
Block a user