merge: pull latest master into dev

Resolved conflicts:
- openai_source.py: keep dev version with abort_signal filtering
- customizer.ts: keep dev version with viewMode functionality
- useSessions.ts: keep dev version with pendingSessionId handling
- platformUtils.js: keep dev version with correct tutorial links
- AddNewPlatform.vue: keep dev version with correct docs link
- FullLayout.vue: keep dev version with viewMode-based logic
- VerticalHeader.vue: keep dev version with viewMode-based logic
This commit is contained in:
LIghtJUNction
2026-03-28 12:14:10 +08:00
57 changed files with 6394 additions and 1948 deletions

View File

@@ -0,0 +1,57 @@
"""
ABP (AstrBot Protocol) client - in-process star communication.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAstrbotAbpClient(ABC):
"""
ABP client: in-process star (plugin) communication.
Stars register themselves; client delegates calls to registered instances.
Subclass must implement:
- connect() -> None
- register_star(name, instance) -> None
- unregister_star(name) -> None
- call_star_tool(star, tool, args) -> Any
- shutdown() -> None
"""
@property
@abstractmethod
def connected(self) -> bool: ...
@abstractmethod
async def connect(self) -> None:
"""Lightweight: just sets connected=True."""
...
@abstractmethod
def register_star(self, star_name: str, star_instance: Any) -> None:
"""Add star to internal registry."""
...
@abstractmethod
def unregister_star(self, star_name: str) -> None:
"""Remove star from registry (idempotent)."""
...
@abstractmethod
async def call_star_tool(
self,
star_name: str,
tool_name: str,
arguments: dict[str, Any],
) -> Any:
"""Delegate to star_instance.call_tool(tool_name, arguments)."""
...
@abstractmethod
async def shutdown(self) -> None:
"""Set connected=False, cancel pending requests."""
...

View File

@@ -0,0 +1,66 @@
"""
ACP (AstrBot Communication Protocol) client.
Transport: TCP | Unix Socket
Messages: JSON with Content-Length header
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAstrbotAcpClient(ABC):
"""
ACP client: connects to ACP servers via TCP or Unix socket.
Subclass must implement:
- connect() -> None
- connect_to_server(host, port) -> None
- connect_to_unix_socket(path) -> None
- call_tool(server, tool, args) -> Any
- send_notification(method, params) -> None
- shutdown() -> None
"""
@property
@abstractmethod
def connected(self) -> bool: ...
@abstractmethod
async def connect(self) -> None: ...
@abstractmethod
async def connect_to_server(self, host: str, port: int) -> None:
"""Connect via TCP."""
...
@abstractmethod
async def connect_to_unix_socket(self, socket_path: str) -> None:
"""Connect via Unix domain socket."""
...
@abstractmethod
async def call_tool(
self,
server_name: str,
tool_name: str,
arguments: dict[str, Any],
) -> Any:
"""Call tool on server, return result."""
...
@abstractmethod
async def send_notification(
self,
method: str,
params: dict[str, Any],
) -> None:
"""Send one-way notification."""
...
@abstractmethod
async def shutdown(self) -> None:
"""Close connection, cancel pending requests."""
...

View File

@@ -0,0 +1,68 @@
"""
ACP (AstrBot Communication Protocol) server.
Transport: TCP listening socket
Messages: JSON with Content-Length header
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
class BaseAstrbotAcpServer(ABC):
"""
ACP server: listens for client connections, exposes tools.
Subclass must implement:
- start(host, port) -> None
- register_tool(name, handler) -> None
- register_notification_handler(name, handler) -> None
- broadcast_notification(method, params) -> None
- shutdown() -> None
"""
@property
@abstractmethod
def running(self) -> bool:
"""True if server is accepting connections."""
...
@abstractmethod
async def start(self, host: str = "127.0.0.1", port: int = 8765) -> None:
"""Bind and listen. Block until shutdown."""
...
@abstractmethod
def register_tool(
self,
name: str,
handler: Callable[..., Any],
) -> None:
"""Register async tool handler (receives params dict, returns result)."""
...
@abstractmethod
def register_notification_handler(
self,
name: str,
handler: Callable[..., Any],
) -> None:
"""Register async notification handler (receives params dict)."""
...
@abstractmethod
async def broadcast_notification(
self,
method: str,
params: dict[str, Any],
) -> None:
"""Send notification to all connected clients."""
...
@abstractmethod
async def shutdown(self) -> None:
"""Stop accepting, close all client connections."""
...

View File

@@ -0,0 +1,73 @@
"""
AstrBot Gateway - HTTP/WebSocket API server.
Built on FastAPI, provides:
- HTTP REST API (stats, inspector, config)
- WebSocket for real-time events
- Static file serving (dashboard)
- Authentication (JWT/API key)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
class BaseAstrbotGateway(ABC):
"""
Gateway: HTTP/WebSocket server built on FastAPI.
┌─────────────────────────────────────────────────────────┐
│ FastAPI App │
├─────────────────────────────────────────────────────────┤
│ REST Endpoints WebSocket │
│ ├─ GET /api/stats ├─ /ws (connection manager)│
│ ├─ GET /api/inspector/* │ │
│ ├─ GET /api/memory/* │ │
│ └─ ... │ │
│ │
│ Middleware: CORS, Auth, Logging │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────┐
│ Orchestrator │
│ (owns protocol clients)│
└─────────────────────────┘
Routes (typical):
GET / → Dashboard static files
GET /api/stats → System statistics
GET /api/inspector/stars → List registered stars
WS /ws → WebSocket for real-time events
serve() Lifecycle:
1. Create FastAPI app
2. Register routes
3. Start WebSocket manager
4. Bind to host:port
5. Run ASGI server (uvicorn/hypercorn)
6. Block until shutdown
7. Close all connections
Subclass must implement:
- serve(): start server, block until shutdown
"""
@abstractmethod
async def serve(self) -> None:
"""
Start gateway server - blocks until shutdown.
Should:
1. Create FastAPI app with routes
2. Configure CORS, auth middleware
3. Start WebSocket connection manager
4. Bind to ASTRBOT_PORT (default 6185)
5. Run ASGI server
6. Handle graceful shutdown on SIGTERM/SIGINT
Raises:
OSError: address already in use
"""
...

View File

@@ -0,0 +1,352 @@
"""
AstrBot Orchestrator - core runtime lifecycle manager.
Architecture
============
┌─────────────────────────────────────────────────────┐
│ Orchestrator │
│ (owns lifecycle of all protocol clients + stars) │
└─────────────────────────────────────────────────────┘
┌──────────────┼──────────────┐
▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────┐
│ LSP │ │ MCP │ │ ACP │
│ Client │ │ Client │ │ Client │
└─────────┘ └─────────┘ └─────────┘
│ │ │
▼ ▼ ▼
LSP Servers MCP Servers ACP Services
┌─────────────────────────────────────────────────────┐
│ ABP Client │
│ (in-process star registry) │
└─────────────────────────────────────────────────────┘
┌─────────┐
│ Stars │
│(Plugins) │
└─────────┘
Lifecycle State Machine
=======================
States:
┌─────────┐
│ INIT │───► orchestrator created, clients not initialized
└────┬────┘
│ start()
┌─────────┐
│ RUNNING │◄─── run_loop() executing
└────┬────┘
│ shutdown()
┌──────────┐
│ SHUTDOWN │─── all clients closed, ready for GC
└──────────┘
Transitions:
INIT + start() ──► RUNNING
RUNNING + shutdown() ──► SHUTDOWN
For each protocol client, the orchestrator:
1. Creates instance in __init__
2. Calls connect() to initialize
3. Calls protocol-specific setup (connect_to_server, etc)
4. Manages via run_loop() heartbeat
5. Calls shutdown() on final cleanup
Star Registration Flow
=====================
orchestrator.register_star("my-star", MyStar())
┌───────────────────┐
│ ABP Client │
│ .register_star() │
└───────────────────┘
┌───────────────────┐
│ Internal dict │
{"my-star": obj} │
└───────────────────┘
Message Routing (conceptual)
===========================
External Tool Call
┌──────────────┐ list_tools() ┌──────────────┐
│ MCP Client │────────────────────►│ MCP Server │
└──────────────┘◄────────────────────└──────────────┘
│ tool result
┌──────────────┐ call_tool() ┌──────────────┐
│ ABP │────────────────────►│ Star │
│ Client │◄────────────────────└──────────────┘
└──────────────┘ tool result
Return to caller
run_loop() Responsibilities
===========================
while running:
│─ check LSP server health (ping/heartbeat)
│─ check MCP session status (reconnect if needed)
│─ check ACP client connections
│─ process any pending star notifications
│─ sleep(SLEEP_INTERVAL)
Shutdown Sequence
==================
shutdown()
├─ set _running = False
├─ LSP.shutdown()
│ └─ send "shutdown" request
│ └─ terminate subprocess
├─ ACP.shutdown()
│ └─ close TCP/Unix connections
├─ ABP.shutdown()
│ └─ cancel pending requests
└─ MCP.cleanup()
└─ close all sessions
└─ cleanup subprocesses
Exception Handling
==================
Each protocol client should:
- Catch connection errors
- Attempt reconnection with exponential backoff
- Log errors but don't crash run_loop
- Raise on irrecoverable failures
The orchestrator run_loop should:
- Catch CancelledError on shutdown
- Catch Exception and log (don't crash)
- Ensure cleanup runs in finally block
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
from astrbot._internal.protocols.lsp.client import AstrbotLspClient
from astrbot._internal.protocols.mcp.client import McpClient
#: Default heartbeat interval for run_loop()
DEFAULT_SLEEP_INTERVAL: float = 5.0
class BaseAstrbotOrchestrator(ABC):
"""
Core runtime: owns lifecycle of all protocol clients and stars.
┌────────────────────────────────────────────────────────────┐
│ Protocol Clients (always present, never None after init) │
├────────────────────────────────────────────────────────────┤
│ lsp: Language Server Protocol │
│ Purpose: code completion, diagnostics, hover, etc │
│ Transport: stdio subprocess │
│ │
│ mcp: Model Context Protocol │
│ Purpose: external tool access │
│ Transport: stdio | SSE | HTTP │
│ │
│ acp: AstrBot Communication Protocol │
│ Purpose: inter-service communication │
│ Transport: TCP | Unix Socket │
│ │
│ abp: AstrBot Protocol │
│ Purpose: in-process star (plugin) communication │
│ Transport: direct method calls │
└────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────┐
│ Star Registry │
├────────────────────────────────────────────────────────────┤
│ _stars: dict[str, Any] │
│ Stars are plugins registered by name │
│ ABP client delegates calls to registered stars │
└────────────────────────────────────────────────────────────┘
Subclass must implement:
- __init__(): create all protocol client instances
- run_loop(): main event loop (block until shutdown)
- register_star(name, instance): add to registry + ABP
- unregister_star(name): remove from registry + ABP
- shutdown(): clean up all clients
"""
#: LSP client for language intelligence
lsp: AstrbotLspClient
#: MCP client for external tools
mcp: McpClient
#: ACP client for inter-service communication
acp: AstrbotAcpClient
#: ABP client for in-process star communication
abp: AstrbotAbpClient
def __init__(self) -> None:
"""
Initialize orchestrator and all protocol clients.
After __init__, all clients exist but are not connected.
Call start() or run_loop() to begin operation.
Example:
class MyOrchestrator(BaseAstrbotOrchestrator):
def __init__(self):
self.lsp = AstrbotLspClient()
self.mcp = McpClient()
self.acp = AstrbotAcpClient()
self.abp = AstrbotAbpClient()
self._stars: dict[str, Any] = {}
self._running = False
"""
self._stars: dict[str, Any] = {}
self._running: bool = False
@property
def running(self) -> bool:
"""True if run_loop() is executing."""
return self._running
@abstractmethod
async def start(self) -> None:
"""
Initialize all protocol clients.
Called once before run_loop(). Should:
1. Call lsp.connect()
2. Call mcp.connect()
3. Call acp.connect()
4. Call abp.connect()
5. Set _running = True
Raises:
Exception: if any client fails to initialize
"""
...
@abstractmethod
async def run_loop(self) -> None:
"""
Main event loop - blocks until shutdown.
Execution:
self._running = True
try:
while self._running:
await self._heartbeat()
await anyio.sleep(DEFAULT_SLEEP_INTERVAL)
except asyncio.CancelledError:
pass # shutdown requested
finally:
self._running = False
_heartbeat() responsibilities:
- Check LSP server health (optional ping)
- Check MCP session status, reconnect if needed
- Check ACP connections
- Process any pending star notifications
Raises:
asyncio.CancelledError: when shutdown() called
Note:
Subclass defines _heartbeat() for periodic tasks.
This method only handles the loop control.
"""
...
@abstractmethod
async def register_star(self, name: str, star_instance: Any) -> None:
"""
Register a star (plugin) with the orchestrator.
Args:
name: Unique identifier for the star
instance: Star plugin instance (must have .call_tool() method)
Does:
self._stars[name] = star_instance
self.abp.register_star(name, star_instance)
Raises:
ValueError: if name already registered
"""
...
@abstractmethod
async def unregister_star(self, name: str) -> None:
"""
Unregister a star (plugin) from the orchestrator.
Args:
name: Identifier of star to remove
Does:
del self._stars[name]
self.abp.unregister_star(name)
Note:
Idempotent - does nothing if name not found.
"""
...
@abstractmethod
async def get_star(self, name: str) -> Any | None:
"""Get registered star by name. Returns None if not found."""
...
@abstractmethod
async def list_stars(self) -> list[str]:
"""Return list of registered star names."""
...
@abstractmethod
async def shutdown(self) -> None:
"""
Graceful shutdown of orchestrator and all clients.
Execution order:
1. self._running = False (stop run_loop)
2. await lsp.shutdown()
3. await acp.shutdown()
4. await abp.shutdown()
5. await mcp.cleanup()
Does NOT unregister stars - caller should do that first.
After shutdown, orchestrator is ready for garbage collection.
"""
...

View File

@@ -0,0 +1,114 @@
"""
LSP (Language Server Protocol) client.
Transport: stdio subprocess
Messages: JSON-RPC 2.0 with Content-Length header
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
pass
class LspMessage:
"""JSON-RPC 2.0 message."""
jsonrpc: str = "2.0"
id: int | str | None = None
method: str | None = None
params: dict[str, Any] | None = None
result: Any = None
error: dict[str, Any] | None = None
class LspRequest(LspMessage):
"""Outgoing request."""
def __init__(self, method: str, params: dict[str, Any] | None = None) -> None:
self.id = id(self)
self.method = method
self.params = params
class LspResponse(LspMessage):
"""Incoming response."""
class LspNotification(LspMessage):
"""Incoming notification (no id)."""
class BaseAstrbotLspClient(ABC):
"""
LSP client: connects to LSP servers via stdio subprocess.
Subclass must implement:
- connect() -> None
- connect_to_server(command, workspace_uri) -> None
- send_request(method, params) -> dict
- send_notification(method, params) -> None
- shutdown() -> None
"""
@property
@abstractmethod
def connected(self) -> bool:
"""True if connected to an LSP server."""
...
@abstractmethod
async def connect(self) -> None:
self._connected = False
...
@abstractmethod
async def connect_to_server(
self,
command: list[str],
workspace_uri: str,
) -> None:
"""
Start LSP server subprocess and complete handshake.
Steps:
1. Spawn subprocess with stdin/stdout pipes
2. Send initialize request
3. Wait for response
4. Send initialized notification
"""
...
@abstractmethod
async def send_request(
self,
method: str,
params: dict[str, Any] | None = None,
) -> Any:
"""
Send JSON-RPC request and return result.
Raises:
RuntimeError: not connected
Exception: server returned error
"""
...
@abstractmethod
async def send_notification(
self,
method: str,
params: dict[str, Any] | None = None,
) -> None:
"""
Send JSON-RPC notification (no response expected).
"""
...
@abstractmethod
async def shutdown(self) -> None:
"""Send shutdown, terminate subprocess, cleanup."""
...

View File

@@ -0,0 +1,95 @@
"""
MCP (Model Context Protocol) client.
Transport: stdio | SSE | streamable_http
Messages: JSON-RPC 2.0
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Literal, TypedDict
if TYPE_CHECKING:
pass
class McpServerConfig(TypedDict, total=False):
"""MCP server configuration."""
# Stdio transport
command: str
args: list[str]
env: dict[str, str]
cwd: str
# HTTP transport
url: str
headers: dict[str, str]
transport: Literal["sse", "streamable_http"]
class McpToolInfo(TypedDict):
"""MCP tool descriptor."""
name: str
description: str
inputSchema: dict[str, Any]
class BaseAstrbotMcpClient(ABC):
"""
MCP client: connects to MCP servers for external tools.
Subclass must implement:
- connect() -> None
- connect_to_server(config, name) -> None
- list_tools() -> list[McpToolInfo]
- call_tool(name, args, timeout) -> CallToolResult
- cleanup() -> None
"""
session: Any # mcp.ClientSession
@property
@abstractmethod
def connected(self) -> bool: ...
@abstractmethod
async def connect(self) -> None:
"""Initialize client session."""
...
@abstractmethod
async def connect_to_server(
self,
config: McpServerConfig,
name: str,
) -> None:
"""
Connect to MCP server.
Stdio: {"command": "python", "args": ["server.py"], "env": {...}}
HTTP: {"url": "https://...", "transport": "sse"}
"""
...
@abstractmethod
async def list_tools(self) -> list[McpToolInfo]:
"""Call tools/list and return tools."""
...
@abstractmethod
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
read_timeout_seconds: int = 60,
) -> Any:
"""Call tools/call with reconnection support."""
...
@abstractmethod
async def cleanup(self) -> None:
"""Close all server connections."""
...

View File

@@ -0,0 +1,6 @@
"""Gateway module - FastAPI server for the dashboard backend."""
from .server import AstrbotGateway
from .ws_manager import WebSocketManager
__all__ = ["AstrbotGateway", "WebSocketManager"]

View File

@@ -0,0 +1,4 @@
"""
依赖注入
"""

View File

@@ -0,0 +1,248 @@
"""
AstrBot Gateway - FastAPI server for the dashboard backend.
Provides REST API endpoints and WebSocket connections for the frontend dashboard.
The gateway acts as the communication bridge between the dashboard and the orchestrator.
"""
from __future__ import annotations
import json
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, cast
from astrbot import logger
from astrbot._internal.abc.base_astrbot_gateway import BaseAstrbotGateway
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
from astrbot._internal.geteway.ws_manager import WebSocketManager
if TYPE_CHECKING:
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
else:
try:
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
except ImportError:
logger.warning("FastAPI not installed, gateway unavailable.")
FastAPI = cast(Any, None)
WebSocket = cast(Any, None)
WebSocketDisconnect = cast(Any, None)
CORSMiddleware = cast(Any, None)
log = logger
class AstrbotGateway(BaseAstrbotGateway):
"""
FastAPI-based gateway server for AstrBot.
Handles:
- REST API endpoints for configuration and stats
- WebSocket connections for real-time communication
- CORS middleware for dashboard access
"""
def __init__(self, orchestrator: BaseAstrbotOrchestrator) -> None:
self.orchestrator = orchestrator
self.ws_manager = WebSocketManager()
self._app: FastAPI | None = None
self._host = "0.0.0.0"
self._port = 8765
async def serve(self) -> None:
"""
Start the gateway server.
Creates and runs a FastAPI application with WebSocket support.
"""
if FastAPI is None:
raise RuntimeError("FastAPI is not installed")
log.info(f"Starting AstrBot Gateway on {self._host}:{self._port}")
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
log.info("Gateway server started.")
yield
# Shutdown
await self.ws_manager.broadcast({"type": "server_shutdown"})
log.info("Gateway server stopped.")
self._app = FastAPI(
title="AstrBot Gateway",
description="Backend API for AstrBot dashboard",
version="1.0.0",
lifespan=lifespan,
)
# CORS middleware
self._app.add_middleware(
cast(Any, CORSMiddleware),
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
self._setup_routes()
# Run with uvicorn
import uvicorn
config = uvicorn.Config(
self._app,
host=self._host,
port=self._port,
log_level="info",
)
server = uvicorn.Server(config)
await server.serve()
def _setup_routes(self) -> None:
"""Set up API routes."""
if self._app is None:
return
from fastapi import APIRouter
# Health check
@self._app.get("/health")
async def health():
return {"status": "ok"}
# WebSocket endpoint
@self._app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
await self.ws_manager.connect(ws)
try:
while True:
data = await ws.receive_text()
try:
message = json.loads(data)
response = await self._handle_ws_message(message)
if response:
await ws.send_json(response)
except json.JSONDecodeError:
await ws.send_json({"error": "Invalid JSON"})
except WebSocketDisconnect:
self.ws_manager.disconnect(ws)
# Stats router
stats_router = APIRouter(prefix="/api/stats", tags=["stats"])
@stats_router.get("/overview")
async def get_overview():
return await self._get_stats_overview()
self._app.include_router(stats_router)
# Inspector router
inspector_router = APIRouter(prefix="/api/inspector", tags=["inspector"])
@inspector_router.get("/stars")
async def list_stars():
return await self._list_stars()
@inspector_router.get("/stars/{star_name}")
async def get_star(star_name: str):
return await self._get_star_detail(star_name)
self._app.include_router(inspector_router)
# Memory router
memory_router = APIRouter(prefix="/api/memory", tags=["memory"])
@memory_router.get("/")
async def get_memory():
return await self._get_memory_info()
self._app.include_router(memory_router)
async def _handle_ws_message(
self, message: dict[str, Any]
) -> dict[str, Any] | None:
"""
Handle an incoming WebSocket message.
Args:
message: Parsed JSON message from the client
Returns:
Response message to send back, or None for no response
"""
msg_type = message.get("type")
data = message.get("data", {})
if msg_type == "ping":
return {"type": "pong", "data": {}}
if msg_type == "call_tool":
return await self._handle_call_tool(data)
if msg_type == "get_stars":
return {"type": "stars_list", "data": await self._list_stars()}
return {
"type": "error",
"data": {"message": f"Unknown message type: {msg_type}"},
}
async def _handle_call_tool(self, data: dict[str, Any]) -> dict[str, Any]:
"""Handle a tool call request via WebSocket."""
star_name = data.get("star")
tool_name = data.get("tool")
arguments = data.get("arguments", {})
if not star_name or not tool_name:
return {
"type": "tool_result",
"data": {"error": "Missing star or tool name"},
}
try:
result = await self.orchestrator.abp.call_star_tool(
star_name, tool_name, arguments
)
return {"type": "tool_result", "data": {"result": result}}
except Exception as e:
return {"type": "tool_result", "data": {"error": str(e)}}
async def _get_stats_overview(self) -> dict[str, Any]:
"""Get overview statistics."""
return {
"stars_count": len(self.orchestrator.abp._stars),
"lsp_connected": self.orchestrator.lsp._connected,
"mcp_sessions": getattr(self.orchestrator.mcp, "session", None) is not None,
"acp_clients": len(getattr(self.orchestrator.acp, "_clients", [])),
}
async def _list_stars(self) -> list[dict[str, Any]]:
"""List all registered stars."""
stars = []
for name in self.orchestrator.abp._stars:
stars.append({"name": name, "status": "active"})
return stars
async def _get_star_detail(self, star_name: str) -> dict[str, Any]:
"""Get details of a specific star."""
star = self.orchestrator.abp._stars.get(star_name)
if not star:
return {"error": f"Star '{star_name}' not found"}
return {"name": star_name, "status": "active"}
async def _get_memory_info(self) -> dict[str, Any]:
"""Get memory usage information."""
import gc
gc.collect()
return {
"gc_objects": len(gc.get_objects()),
"python_memory": "N/A", # Would need psutil for actual values
}
def set_listen_address(self, host: str, port: int) -> None:
"""Set the listen address for the gateway server."""
self._host = host
self._port = port

View File

@@ -0,0 +1,103 @@
"""
WebSocket connection manager for the AstrBot gateway.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
import anyio
from astrbot import logger
if TYPE_CHECKING:
from fastapi import WebSocket
else:
try:
from fastapi import WebSocket
except ImportError:
logger.warning("FastAPI not installed, WebSocketManager unavailable.")
WebSocket = cast(Any, None)
log = logger
class WebSocketManager:
"""
Manages all active WebSocket connections.
Provides connection/disconnection handling and broadcast capabilities.
"""
def __init__(self) -> None:
self._connections: set[WebSocket] = set()
self._lock = anyio.Lock()
async def connect(self, websocket: WebSocket) -> None:
"""Accept and register a new WebSocket connection."""
await websocket.accept()
async with self._lock:
self._connections.add(websocket)
log.debug(f"WebSocket connected. Total: {len(self._connections)}")
async def disconnect(self, websocket: WebSocket) -> None:
"""Remove a WebSocket connection."""
async with self._lock:
self._connections.discard(websocket)
log.debug(f"WebSocket disconnected. Total: {len(self._connections)}")
async def send_json(self, websocket: WebSocket, data: dict[str, Any]) -> None:
"""
Send JSON data to a specific WebSocket.
Args:
websocket: Target WebSocket connection
data: Data to send (must be JSON-serializable)
"""
try:
await websocket.send_json(data)
except Exception as e:
log.warning(f"Failed to send to WebSocket: {e}")
await self.disconnect(websocket)
async def broadcast(self, data: dict[str, Any]) -> None:
"""
Broadcast JSON data to all connected WebSockets.
Args:
data: Data to broadcast (must be JSON-serializable)
"""
async with self._lock:
connections = list(self._connections)
for conn in connections:
try:
await conn.send_json(data)
except Exception as e:
log.warning(f"Failed to broadcast to WebSocket: {e}")
async with self._lock:
self._connections.discard(conn)
async def send_to(
self, websocket: WebSocket, message: str | dict[str, Any]
) -> None:
"""
Send a message to a specific WebSocket.
Args:
websocket: Target WebSocket connection
message: Message to send (string or dict)
"""
try:
if isinstance(message, str):
await websocket.send_text(message)
else:
await websocket.send_json(message)
except Exception as e:
log.warning(f"Failed to send to WebSocket: {e}")
await self.disconnect(websocket)
@property
def connection_count(self) -> int:
"""Return the number of active connections."""
return len(self._connections)

View File

@@ -0,0 +1,5 @@
"""ABP module - AstrBot Protocol client implementation (built-in plugin protocol)."""
from .client import AstrbotAbpClient
__all__ = ["AstrbotAbpClient"]

View File

@@ -0,0 +1,93 @@
"""
ABP (AstrBot Protocol) client implementation.
ABP is the built-in plugin protocol where the orchestrator acts as client
connecting to internal stars (plugins) embedded in the runtime.
"""
from __future__ import annotations
from typing import Any
from astrbot import logger
from astrbot._internal.abc.abp.base_astrbot_abp_client import BaseAstrbotAbpClient
log = logger
class AstrbotAbpClient(BaseAstrbotAbpClient):
"""
ABP client for communicating with internal stars (built-in plugins).
The orchestrator acts as the client, sending requests to and receiving
notifications from stars running within the same process.
"""
def __init__(self) -> None:
self._connected = False
self._stars: dict[str, Any] = {}
# Use a simple dict for pending requests; we avoid asyncio.Future here.
self._pending_requests: dict[str, Any] = {}
self._request_id = 0
@property
def connected(self) -> bool:
"""True if connected to stars registry."""
return self._connected
async def connect(self) -> None:
"""Connect to internal stars registry."""
log.debug("ABP client connecting to internal stars...")
self._connected = True
log.info("ABP client connected to internal stars registry.")
async def call_star_tool(
self, star_name: str, tool_name: str, arguments: dict[str, Any]
) -> Any:
"""
Call a tool on a registered star.
Args:
star_name: Name of the star (plugin)
tool_name: Name of the tool to call
arguments: Tool arguments
Returns:
Tool call result
"""
if not self._connected:
raise RuntimeError("ABP client is not connected")
star = self._stars.get(star_name)
if not star:
raise ValueError(f"Star '{star_name}' not found")
request_id = f"{self._request_id}"
self._request_id += 1
# No asyncio.Future used; store a placeholder entry for tracking if needed.
self._pending_requests[request_id] = None
try:
# Call the star's tool handler
result = await star.call_tool(tool_name, arguments)
return result
finally:
self._pending_requests.pop(request_id, None)
def register_star(self, star_name: str, star_instance: Any) -> None:
"""Register a star (plugin) with the ABP client."""
self._stars[star_name] = star_instance
log.debug(f"Star '{star_name}' registered with ABP client.")
def unregister_star(self, star_name: str) -> None:
"""Unregister a star from the ABP client."""
self._stars.pop(star_name, None)
log.debug(f"Star '{star_name}' unregistered from ABP client.")
async def shutdown(self) -> None:
"""Shutdown the ABP client connection."""
self._connected = False
# Clear any pending requests (no asyncio futures used in this implementation)
self._pending_requests.clear()
log.info("ABP client shut down.")

View File

@@ -0,0 +1,6 @@
"""ACP module - AstrBot Communication Protocol client and server implementations."""
from .client import AstrbotAcpClient
from .server import AstrbotAcpServer
__all__ = ["AstrbotAcpClient", "AstrbotAcpServer"]

View File

@@ -0,0 +1,220 @@
"""
ACP (AstrBot Communication Protocol) client implementation.
ACP is a client-server protocol for inter-service communication,
similar to MCP but designed specifically for AstrBot's architecture.
"""
from __future__ import annotations
import asyncio
import json
from typing import Any
from astrbot import logger
from astrbot._internal.abc.acp.base_astrbot_acp_client import BaseAstrbotAcpClient
log = logger
class AstrbotAcpClient(BaseAstrbotAcpClient):
"""
ACP client for communicating with ACP servers.
The orchestrator acts as an ACP client, connecting to external
ACP-compatible services.
"""
def __init__(self) -> None:
self._connected = False
self._reader: asyncio.StreamReader | None = None
self._writer: asyncio.StreamWriter | None = None
self._server_url: str | None = None
self._pending_requests: dict[str, asyncio.Future[dict[str, Any]]] = {}
self._request_id = 0
self._reader_task: asyncio.Task[None] | None = None
@property
def connected(self) -> bool:
"""True if connected to an ACP server."""
return self._connected
async def connect(self) -> None:
"""
Connect to configured ACP servers.
ACP servers can be accessed via TCP (host:port) or Unix socket.
"""
log.debug("ACP client connecting...")
# TODO: Load ACP server configurations
self._connected = True
log.info("ACP client initialized.")
async def connect_to_server(self, host: str, port: int) -> None:
"""
Connect to an ACP server via TCP.
Args:
host: Server hostname or IP
port: Server port
"""
self._server_url = f"{host}:{port}"
self._reader, self._writer = await asyncio.open_connection(host, port)
self._connected = True
# Start reading responses
self._reader_task = asyncio.create_task(self._read_messages())
log.info(f"ACP client connected to {self._server_url}")
async def connect_to_unix_socket(self, socket_path: str) -> None:
"""
Connect to an ACP server via Unix socket.
Args:
socket_path: Path to the Unix socket
"""
self._server_url = f"unix://{socket_path}"
self._reader, self._writer = await asyncio.open_unix_connection(socket_path)
self._connected = True
self._reader_task = asyncio.create_task(self._read_messages())
log.info(f"ACP client connected to {self._server_url}")
async def _read_messages(self) -> None:
"""Background task to read ACP messages."""
if not self._reader:
return
buffer = b""
while self._connected:
try:
data = await self._reader.read(4096)
if not data:
break
buffer += data
while True:
header_end = buffer.find(b"\n")
if header_end == -1:
break
try:
header = json.loads(buffer[:header_end].decode("utf-8"))
except json.JSONDecodeError:
buffer = buffer[header_end + 1 :]
continue
content_length = header.get("content-length", 0)
if (
content_length == 0
or len(buffer) < header_end + 1 + content_length
):
break
content = buffer[header_end + 1 : header_end + 1 + content_length]
buffer = buffer[header_end + 1 + content_length :]
message = json.loads(content.decode("utf-8"))
if "id" in message:
request_id = str(message["id"])
future = self._pending_requests.pop(request_id, None)
if future and not future.done():
if "error" in message:
future.set_exception(Exception(str(message["error"])))
else:
future.set_result(message.get("result", {}))
else:
await self._handle_notification(message)
except Exception as e:
if self._connected:
log.error(f"ACP read error: {e}")
break
async def _handle_notification(self, notification: dict[str, Any]) -> None:
"""Handle incoming ACP notifications."""
method = notification.get("method", "")
log.debug(f"ACP notification: {method}")
async def call_tool(
self, server_name: str, tool_name: str, arguments: dict[str, Any]
) -> Any:
"""
Call a tool on an ACP server.
Args:
server_name: Name of the ACP server
tool_name: Name of the tool to call
arguments: Tool arguments
Returns:
Tool call result
"""
if not self._connected:
raise RuntimeError("ACP client is not connected")
request_id = str(self._request_id)
self._request_id += 1
message = {
"jsonrpc": "2.0",
"id": request_id,
"method": f"{server_name}/{tool_name}",
"params": arguments,
}
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
self._pending_requests[request_id] = future
await self._send_message(message)
return await future
async def _send_message(self, message: dict[str, Any]) -> None:
"""Send an ACP message."""
if not self._writer:
raise RuntimeError("ACP client not connected")
content = json.dumps(message)
header = json.dumps({"content-length": len(content)}) + "\n"
self._writer.write((header + content).encode())
await self._writer.drain()
async def send_notification(
self, method: str, params: dict[str, Any] | None = None
) -> None:
"""Send a one-way notification to the server."""
message = {
"jsonrpc": "2.0",
"method": method,
"params": params or {},
}
await self._send_message(message)
async def shutdown(self) -> None:
"""Shutdown the ACP client connection."""
self._connected = False
if self._reader_task:
self._reader_task.cancel()
try:
await self._reader_task
except asyncio.CancelledError:
pass
if self._writer:
self._writer.close()
try:
await self._writer.wait_closed()
except Exception:
pass
for future in self._pending_requests.values():
if not future.done():
future.cancel()
self._pending_requests.clear()
log.info("ACP client shut down.")

View File

@@ -0,0 +1,223 @@
"""
ACP (AstrBot Communication Protocol) server implementation.
ACP servers listen for connections from ACP clients and provide
services/tools to the orchestrator.
"""
from __future__ import annotations
import asyncio
import json
from collections.abc import Callable
from typing import Any
from astrbot import logger
from astrbot._internal.abc.acp.base_astrbot_acp_server import BaseAstrbotAcpServer
log = logger
class AstrbotAcpServer(BaseAstrbotAcpServer):
"""
ACP server for accepting connections from ACP clients.
ACP servers expose tools/notifications that can be called by clients.
"""
def __init__(self) -> None:
self._running = False
self._host: str = "127.0.0.1"
self._port: int = 8765
self._server: asyncio.Server | None = None
self._clients: set[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = set()
self._tool_handlers: dict[str, Callable[..., Any]] = {}
self._notification_handlers: dict[str, Callable[..., Any]] = {}
def register_tool(self, name: str, handler: Callable[..., Any]) -> None:
"""
Register a tool handler.
Args:
name: Tool name
handler: Async callable that handles tool calls
"""
self._tool_handlers[name] = handler
log.debug(f"ACP server registered tool: {name}")
def register_notification_handler(
self, name: str, handler: Callable[..., Any]
) -> None:
"""
Register a notification handler.
Args:
name: Notification method name
handler: Async callable that handles notifications
"""
self._notification_handlers[name] = handler
log.debug(f"ACP server registered notification handler: {name}")
async def start(self, host: str = "127.0.0.1", port: int = 8765) -> None:
"""
Start the ACP server.
Args:
host: Host to bind to
port: Port to listen on
"""
self._host = host
self._port = port
self._server = await asyncio.start_server(
self._handle_client,
host=host,
port=port,
)
self._running = True
log.info(f"ACP server listening on {host}:{port}")
async def _handle_client(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""Handle an incoming ACP client connection."""
addr = writer.get_extra_info("peername")
log.debug(f"ACP client connected: {addr}")
self._clients.add((reader, writer))
buffer = b""
try:
while self._running:
try:
data = await reader.read(4096)
if not data:
break
buffer += data
while True:
header_end = buffer.find(b"\n")
if header_end == -1:
break
try:
header = json.loads(buffer[:header_end].decode("utf-8"))
except json.JSONDecodeError:
buffer = buffer[header_end + 1 :]
continue
content_length = header.get("content-length", 0)
if (
content_length == 0
or len(buffer) < header_end + 1 + content_length
):
break
content = buffer[
header_end + 1 : header_end + 1 + content_length
]
buffer = buffer[header_end + 1 + content_length :]
message = json.loads(content.decode("utf-8"))
response = await self._handle_message(message)
if response:
content = json.dumps(response)
resp_header = (
json.dumps({"content-length": len(content)}) + "\n"
)
writer.write(resp_header.encode() + content.encode())
await writer.drain()
except Exception as e:
log.error(f"ACP client error ({addr}): {e}")
break
finally:
self._clients.discard((reader, writer))
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
log.debug(f"ACP client disconnected: {addr}")
async def _handle_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
"""Handle an incoming ACP message."""
method = message.get("method", "")
msg_id = message.get("id")
params = message.get("params", {})
# Check if it's a notification (no id) or request (has id)
if msg_id is None:
# Notification
handler = self._notification_handlers.get(method)
if handler:
try:
await handler(params)
except Exception as e:
log.error(f"ACP notification handler error ({method}): {e}")
return None
# Request
result = None
error = None
handler = self._tool_handlers.get(method)
if handler:
try:
result = await handler(params)
except Exception as e:
error = str(e)
log.error(f"ACP tool handler error ({method}): {e}")
else:
error = f"Unknown method: {method}"
response: dict[str, Any] = {"jsonrpc": "2.0", "id": msg_id}
if error:
response["error"] = {"code": -32601, "message": error}
else:
response["result"] = result
return response
async def broadcast_notification(self, method: str, params: dict[str, Any]) -> None:
"""
Broadcast a notification to all connected clients.
Args:
method: Notification method name
params: Notification parameters
"""
message = {
"jsonrpc": "2.0",
"method": method,
"params": params,
}
content = json.dumps(message)
header = json.dumps({"content-length": len(content)}) + "\n"
data = header.encode() + content.encode()
for reader, writer in list(self._clients):
try:
writer.write(data)
await writer.drain()
except Exception as e:
log.warning(f"Failed to broadcast to client: {e}")
async def shutdown(self) -> None:
"""Shutdown the ACP server."""
self._running = False
if self._server:
self._server.close()
await self._server.wait_closed()
self._server = None
for reader, writer in list(self._clients):
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
self._clients.clear()
log.info("ACP server shut down.")

View File

@@ -0,0 +1,5 @@
"""LSP module - Language Server Protocol client implementation."""
from .client import AstrbotLspClient
__all__ = ["AstrbotLspClient"]

View File

@@ -0,0 +1,243 @@
"""
LSP (Language Server Protocol) client implementation.
The orchestrator acts as an LSP client, connecting to LSP servers
that provide language intelligence features (completions, diagnostics, etc.).
"""
from __future__ import annotations
import json
from typing import Any
import anyio
from anyio.abc import ByteReceiveStream, ByteSendStream, Process
from astrbot import logger
from astrbot._internal.abc.lsp.base_astrbot_lsp_client import BaseAstrbotLspClient
log = logger
class AstrbotLspClient(BaseAstrbotLspClient):
"""
LSP client for communicating with LSP servers.
Implements the Microsoft Language Server Protocol for connecting to
external language intelligence services.
"""
def __init__(self) -> None:
self._connected = False
self._reader: ByteReceiveStream | None = None
self._writer: ByteSendStream | None = None
self._server_process: Process | None = None
self._pending_requests: dict[int, Any] = {}
self._request_id = 0
self._server_command: list[str] | None = None
# anyio TaskGroup handle for background readers
self._task_group: Any | None = None
@property
def connected(self) -> bool:
"""True if connected to an LSP server."""
return self._connected
async def connect(self) -> None:
"""
Connect to configured LSP servers.
LSP servers are typically stdio-based subprocesses. This method
establishes the communication channel.
"""
log.debug("LSP client connecting...")
# TODO: Load LSP server configurations and start subprocesses
# For now, mark as connected in idle mode
self._connected = True
log.info("LSP client initialized.")
async def connect_to_server(self, command: list[str], workspace_uri: str) -> None:
"""
Connect to an LSP server subprocess.
Args:
command: Command line to start the LSP server (e.g., ["python", "lsp_server.py"])
workspace_uri: Root URI of the workspace to serve
"""
log.debug(f"Starting LSP server: {' '.join(command)}")
self._server_process = await anyio.open_process(
command,
stdin=-1,
stdout=-1,
stderr=-1,
)
self._reader = self._server_process.stdout
self._writer = self._server_process.stdin
self._server_command = command
self._connected = True
# Start reading responses in background using anyio TaskGroup
# Create and enter a TaskGroup so the reader runs until we close it at shutdown.
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._read_responses)
# Send initialize request
await self.send_request(
"initialize",
{
"processId": None,
"rootUri": workspace_uri,
"capabilities": {},
},
)
# Send initialized notification
await self.send_notification("initialized", {})
log.info(f"LSP client connected to server: {command[0]}")
async def send_request(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""Send an LSP request and wait for response."""
if not self._writer:
raise RuntimeError("LSP client not connected")
request_id = self._request_id
self._request_id += 1
message = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
"params": params or {},
}
# Use anyio.Event for request/response matching
response_event: anyio.Event = anyio.Event()
response_holder: dict[str, Any] = {}
async def set_response(response: dict[str, Any]) -> None:
response_holder["response"] = response
response_event.set()
self._pending_requests[request_id] = set_response
content = json.dumps(message)
headers = f"Content-Length: {len(content)}\r\n\r\n"
await self._writer.send((headers + content).encode())
# Wait for response with timeout
with anyio.move_on_after(30):
await response_event.wait()
if "response" in response_holder:
return response_holder["response"]
raise TimeoutError(f"LSP request {method} timed out")
async def send_notification(
self, method: str, params: dict[str, Any] | None = None
) -> None:
"""Send an LSP notification (no response expected)."""
if not self._writer:
raise RuntimeError("LSP client not connected")
message = {
"jsonrpc": "2.0",
"method": method,
"params": params or {},
}
content = json.dumps(message)
headers = f"Content-Length: {len(content)}\r\n\r\n"
await self._writer.send((headers + content).encode())
async def _read_responses(self) -> None:
"""Background task to read LSP responses."""
if not self._reader:
return
buffer = b""
try:
while self._connected:
try:
data = await self._reader.receive()
if not data:
break
buffer += data
while True:
# Parse Content-Length header
header_end = buffer.find(b"\r\n\r\n")
if header_end == -1:
break
header = buffer[:header_end].decode("utf-8")
content_length = 0
for line in header.split("\r\n"):
if line.startswith("Content-Length:"):
content_length = int(line.split(":")[1].strip())
if content_length == 0:
break
total_length = header_end + 4 + content_length
if len(buffer) < total_length:
break
content = buffer[header_end + 4 : total_length]
buffer = buffer[total_length:]
response = json.loads(content.decode("utf-8"))
# Handle response vs notification
if "id" in response:
request_id = response["id"]
handler = self._pending_requests.pop(request_id, None)
if handler:
await handler(response)
else:
# Notification (e.g., window/logMessage)
await self._handle_notification(response)
except anyio.EndOfStream:
break
except anyio.get_cancelled_exc_class():
# Task was cancelled via the TaskGroup cancel/exit during shutdown
pass
async def _handle_notification(self, notification: dict[str, Any]) -> None:
"""Handle incoming LSP notifications."""
method = notification.get("method", "")
log.debug(f"LSP notification: {method}")
async def shutdown(self) -> None:
"""Shutdown the LSP client."""
self._connected = False
if self._task_group:
try:
# Exit the TaskGroup, which cancels background tasks started within it
await self._task_group.__aexit__(None, None, None)
except anyio.get_cancelled_exc_class():
pass
self._task_group = None
if self._server_process:
try:
await self.send_notification("shutdown", {})
except Exception:
pass
self._server_process.terminate()
try:
with anyio.move_on_after(5):
await self._server_process.wait()
except Exception:
self._server_process.kill()
self._server_process = None
self._pending_requests.clear()
log.info("LSP client shut down.")

View File

@@ -0,0 +1,63 @@
"""MCP module - Model Context Protocol client and tool implementations.
This module provides MCP client functionality and MCP tool wrappers.
"""
import asyncio
from dataclasses import dataclass
from .client import McpClient
from .config import (
DEFAULT_MCP_CONFIG,
get_mcp_config_path,
load_mcp_config,
save_mcp_config,
)
from .tool import MCPTool
# Exceptions
class MCPInitError(Exception):
"""Base exception for MCP initialization failures."""
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
"""Raised when MCP client initialization exceeds the configured timeout."""
class MCPAllServicesFailedError(MCPInitError):
"""Raised when all configured MCP services fail to initialize."""
class MCPShutdownTimeoutError(asyncio.TimeoutError):
"""Raised when MCP shutdown exceeds the configured timeout."""
def __init__(self, names: list[str], timeout: float) -> None:
self.names = names
self.timeout = timeout
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
super().__init__(message)
@dataclass
class MCPInitSummary:
"""Summary of MCP initialization results."""
total: int
success: int
failed: list[str]
__all__ = [
"DEFAULT_MCP_CONFIG",
"MCPAllServicesFailedError",
"MCPInitError",
"MCPInitSummary",
"MCPInitTimeoutError",
"MCPShutdownTimeoutError",
"MCPTool",
"McpClient",
"get_mcp_config_path",
"load_mcp_config",
"save_mcp_config",
]

View File

@@ -0,0 +1,466 @@
"""MCP client implementation."""
import asyncio
import logging
import os
import sys
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Any, cast
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from astrbot._internal.abc.mcp.base_astrbot_mcp_client import (
BaseAstrbotMcpClient,
McpServerConfig,
McpToolInfo,
)
from astrbot.core.utils.log_pipe import LogPipe
logger = logging.getLogger("astrbot")
try:
import anyio
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
)
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
)
def _prepare_config(config: dict) -> dict:
"""Prepare configuration, handle nested format."""
if config.get("mcpServers"):
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
def _prepare_stdio_env(config: dict) -> dict:
"""Preserve Windows executable resolution for stdio subprocesses."""
if sys.platform != "win32":
return config
pathext = os.environ.get("PATHEXT")
if not pathext:
return config
prepared = config.copy()
env = dict(prepared.get("env") or {})
env.setdefault("PATHEXT", pathext)
prepared["env"] = env
return prepared
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""Quick test MCP server connectivity."""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP connection config missing transport or type field")
async with aiohttp.ClientSession() as session:
if transport_type == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"Connection timeout: {timeout} seconds"
except Exception as e:
return False, f"{e!s}"
class McpClient(BaseAstrbotMcpClient):
def __init__(self) -> None:
# Initialize session and client objects
self.session: mcp.ClientSession | None = None
self.exit_stack = AsyncExitStack()
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = anyio.Event()
self.process_pid: int | None = None
# Store connection config for reconnection
self._mcp_server_config: McpServerConfig | None = None
self._server_name: str | None = None
self._reconnect_lock = anyio.Lock() # Lock for thread-safe reconnection
self._reconnecting: bool = False # For logging and debugging
async def connect(self) -> None:
"""Initialize the MCP client connection.
Note: Actual server connections are made via connect_to_server().
This method prepares the client for use.
"""
# MCP client is initialized on-demand via connect_to_server
# This is a no-op stub to satisfy BaseAstrbotMcpClient
logger.debug("MCP client initialized.")
@property
def connected(self) -> bool:
"""True if MCP client has an active session."""
return self.session is not None
async def list_tools(self) -> list[McpToolInfo]:
"""List all tools from connected MCP servers."""
if not self.session:
return []
result = await self.list_tools_and_save()
tools = [
{
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema,
}
for tool in result.tools
]
return cast(list[McpToolInfo], tools)
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
read_timeout_seconds: int = 60,
) -> Any:
"""Call a tool on the MCP server with reconnection support."""
return await self.call_tool_with_reconnect(
tool_name=name,
arguments=arguments,
read_timeout_seconds=timedelta(seconds=read_timeout_seconds),
)
@staticmethod
def _extract_stdio_process_pid(streams_context: object) -> int | None:
"""Best-effort extraction for stdio subprocess PID used by lease cleanup.
TODO(refactor): replace this async-generator frame introspection with a
stable MCP library hook once the upstream transport exposes process PID.
"""
generator = getattr(streams_context, "gen", None)
frame = getattr(generator, "ag_frame", None)
if frame is None:
return None
process = frame.f_locals.get("process")
pid = getattr(process, "pid", None)
try:
return int(pid) if pid is not None else None
except (TypeError, ValueError):
return None
async def connect_to_server(self, config: McpServerConfig, name: str) -> None:
"""Connect to MCP server
If `url` parameter exists:
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
2. When transport is specified as `sse`, use SSE connection.
3. If not specified, default to SSE connection to MCP service.
Args:
config: Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
# Store config for reconnection
self._mcp_server_config = config
self._server_name = name
self.process_pid = None
cfg = _prepare_config(dict(config))
def logging_callback(
msg: str | mcp.types.LoggingMessageNotificationParams,
) -> None:
# Handle MCP service error logs
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
self.server_errlogs.append(log_msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
if not success:
raise Exception(error_msg)
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP connection config missing transport or type field")
if transport_type != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context,
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=cast(Any, logging_callback),
),
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5),
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context,
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
),
)
else:
cfg = _prepare_stdio_env(cfg)
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None:
# Handle MCP service error logs
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
if msg.level in (
"warning",
"error",
"critical",
"alert",
"emergency",
):
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
self.server_errlogs.append(log_msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=cast(
Any,
LogPipe(
level=logging.INFO,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
),
),
),
)
self.process_pid = self._extract_stdio_process_pid(stdio_transport)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport),
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response
async def _reconnect(self) -> None:
"""Reconnect to the MCP server using the stored configuration.
Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
Raises:
Exception: raised when reconnection fails
"""
async with self._reconnect_lock:
# Check if already reconnecting (useful for logging)
if self._reconnecting:
logger.debug(
f"MCP Client {self._server_name} is already reconnecting, skipping"
)
return
if not self._mcp_server_config or not self._server_name:
raise Exception("Cannot reconnect: missing connection configuration")
self._reconnecting = True
try:
logger.info(
f"Attempting to reconnect to MCP server {self._server_name}..."
)
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
if self.exit_stack:
self._old_exit_stacks.append(self.exit_stack)
# Mark old session as invalid
self.session = None
# Create new exit stack for new connection
self.exit_stack = AsyncExitStack()
# Reconnect using stored config
await self.connect_to_server(self._mcp_server_config, self._server_name)
await self.list_tools_and_save()
logger.info(
f"Successfully reconnected to MCP server {self._server_name}"
)
except Exception as e:
logger.error(
f"Failed to reconnect to MCP server {self._server_name}: {e}"
)
raise
finally:
self._reconnecting = False
async def call_tool_with_reconnect(
self,
tool_name: str,
arguments: dict,
read_timeout_seconds: timedelta,
) -> mcp.types.CallToolResult:
"""Call MCP tool with automatic reconnection on failure, max 2 retries.
Args:
tool_name: tool name
arguments: tool arguments
read_timeout_seconds: read timeout
Returns:
MCP tool call result
Raises:
ValueError: MCP session is not available
anyio.ClosedResourceError: raised after reconnection failure
"""
@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=1, max=3),
before_sleep=before_sleep_log(logger, logging.WARNING), # type: ignore[arg-type]
reraise=True,
)
async def _call_with_retry():
if not self.session:
raise ValueError("MCP session is not available for MCP function tools.")
try:
return await self.session.call_tool(
name=tool_name,
arguments=arguments,
read_timeout_seconds=read_timeout_seconds,
)
except anyio.ClosedResourceError:
logger.warning(
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
)
# Attempt to reconnect
await self._reconnect()
# Reraise the exception to trigger tenacity retry
raise
return await _call_with_retry()
async def cleanup(self) -> None:
"""Clean up resources including old exit stacks from reconnections"""
# Close current exit stack
try:
await self.exit_stack.aclose()
except Exception as e:
logger.debug(f"Error closing current exit stack: {e}")
# Don't close old exit stacks as they may be in different task contexts
# They will be garbage collected naturally
# Just clear the list to release references
self._old_exit_stacks.clear()
# Set running_event first to unblock any waiting tasks
self.running_event.set()
self.process_pid = None

View File

@@ -0,0 +1,55 @@
"""MCP configuration management."""
import json
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
def get_mcp_config_path() -> str:
"""Get the path to the MCP configuration file."""
data_dir = get_astrbot_data_path()
return os.path.join(data_dir, "mcp_server.json")
def load_mcp_config() -> dict:
"""Load MCP configuration from file.
Returns:
MCP configuration dict. If file doesn't exist, returns default config.
"""
config_path = get_mcp_config_path()
if not os.path.exists(config_path):
# Create default config if not exists
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
return DEFAULT_MCP_CONFIG
try:
with open(config_path, encoding="utf-8") as f:
return json.load(f)
except Exception:
return DEFAULT_MCP_CONFIG
def save_mcp_config(config: dict) -> bool:
"""Save MCP configuration to file.
Args:
config: MCP configuration dict to save.
Returns:
True if successful, False otherwise.
"""
config_path = get_mcp_config_path()
try:
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
return True
except Exception:
return False

View File

@@ -0,0 +1,45 @@
"""MCP tool wrapper."""
from datetime import timedelta
from typing import TYPE_CHECKING, Any
try:
import mcp
except (ModuleNotFoundError, ImportError):
mcp = None # type: ignore
from astrbot._internal.tools.base import FunctionTool
if TYPE_CHECKING:
from astrbot._internal.protocols.mcp.client import McpClient
class MCPTool(FunctionTool):
"""A function tool that calls an MCP service."""
def __init__(
self,
mcp_tool: "mcp.types.Tool",
mcp_client: "McpClient",
mcp_server_name: str,
**kwargs: Any,
) -> None:
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=mcp_tool.inputSchema,
)
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client
self.mcp_server_name = mcp_server_name
self.source = "mcp"
async def call(self, **kwargs: Any) -> Any:
"""Call the MCP tool with the given arguments."""
# Note: For actual usage, context.tool_call_timeout is needed
# but for simplicity we use a default timeout here
return await self.mcp_client.call_tool_with_reconnect(
tool_name=self.mcp_tool.name,
arguments=kwargs,
read_timeout_seconds=timedelta(seconds=60),
)

View File

@@ -0,0 +1,3 @@
from astrbot._internal.runtime.__main__ import bootstrap
__all__ = ["bootstrap"]

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import anyio
from astrbot._internal.abc.base_astrbot_gateway import BaseAstrbotGateway
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
from astrbot._internal.geteway.server import AstrbotGateway
from astrbot._internal.runtime.orchestrator import AstrbotOrchestrator
async def bootstrap():
orchestrator: BaseAstrbotOrchestrator = AstrbotOrchestrator()
gw: BaseAstrbotGateway = AstrbotGateway(orchestrator)
# anyio 的结构化并发
async with anyio.create_task_group() as tg:
tg.start_soon(orchestrator.lsp.connect) # 启动 LSP client
tg.start_soon(orchestrator.mcp.connect) # 启动 MCP client
tg.start_soon(orchestrator.acp.connect) # 启动 ACP client
tg.start_soon(orchestrator.abp.connect) # 启动 ABP client
await anyio.sleep(0.5)
tg.start_soon(orchestrator.run_loop) # 启动编排器循环
tg.start_soon(gw.serve) # 面板后端服务

View File

@@ -0,0 +1,164 @@
"""
AstrBot Orchestrator - core runtime that coordinates all protocol clients.
The orchestrator manages the lifecycle of LSP, MCP, ACP, and ABP clients,
and runs the main event loop that dispatches messages between components.
"""
from __future__ import annotations
from typing import Any
import anyio
from astrbot import logger
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
from astrbot._internal.protocols.lsp.client import AstrbotLspClient
from astrbot._internal.protocols.mcp.client import McpClient
from astrbot._internal.stars import RuntimeStatusStar
log = logger
class AstrbotOrchestrator(BaseAstrbotOrchestrator):
"""
Core runtime orchestrator for AstrBot.
Manages:
- LSP client: Language Server Protocol for editor integrations
- MCP client: Model Context Protocol for external tool servers
- ACP client: AstrBot Communication Protocol for inter-service communication
- ABP client: AstrBot Protocol for built-in star (plugin) communication
"""
def __init__(self) -> None:
# Initialize protocol clients (use concrete types for full method access)
self.lsp = AstrbotLspClient()
self.mcp = McpClient()
self.acp = AstrbotAcpClient()
self.abp = AstrbotAbpClient()
self._running = False
self._stars: dict[str, Any] = {}
self._message_count: int = 0
self._last_activity_timestamp: float | None = None
# Auto-register RuntimeStatusStar
self._runtime_status_star = RuntimeStatusStar()
self._runtime_status_star.set_orchestrator(self)
self._stars["runtime-status-star"] = self._runtime_status_star
self.abp.register_star("runtime-status-star", self._runtime_status_star)
log.debug("AstrbotOrchestrator initialized.")
async def start(self) -> None:
"""
Initialize all protocol clients.
Calls connect() on all protocol clients to prepare them for use.
"""
log.info("Starting AstrbotOrchestrator...")
await self.lsp.connect()
await self.mcp.connect()
await self.acp.connect()
await self.abp.connect()
self._running = True
log.info("AstrbotOrchestrator started.")
async def run_loop(self) -> None:
"""
Main orchestrator event loop.
This loop runs continuously, handling:
- Periodic health checks of protocol clients
- Message routing between protocols
- Star (plugin) lifecycle management
"""
self._running = True
log.info("AstrbotOrchestrator run loop started.")
stop_event = anyio.Event()
def set_stop() -> None:
stop_event.set()
# Store the callback for cleanup
self._stop_callback = set_stop
try:
while self._running:
# TODO: Periodic tasks:
# - Check LSP server health
# - Check MCP session status
# - Check ACP client connections
# - Process any pending star notifications
# Wait for 5 seconds or until shutdown is called
with anyio.move_on_after(5):
await stop_event.wait()
except anyio.get_cancelled_exc_class():
log.info("Orchestrator run loop cancelled.")
finally:
self._running = False
self._stop_callback = None
log.info("AstrbotOrchestrator run loop stopped.")
async def register_star(self, name: str, star_instance: Any) -> None:
"""
Register a star (plugin) with the orchestrator.
Args:
name: Unique name for the star
star_instance: Star plugin instance
"""
self._stars[name] = star_instance
self.abp.register_star(name, star_instance)
log.info(f"Star '{name}' registered.")
async def unregister_star(self, name: str) -> None:
"""
Unregister a star (plugin) from the orchestrator.
Args:
name: Name of the star to unregister
"""
self._stars.pop(name, None)
self.abp.unregister_star(name)
log.info(f"Star '{name}' unregistered.")
async def get_star(self, name: str) -> Any | None:
"""Get a registered star by name."""
return self._stars.get(name)
async def list_stars(self) -> list[str]:
"""List all registered star names."""
return list(self._stars.keys())
def record_activity(self) -> None:
"""Record a message activity for stats tracking."""
self._message_count += 1
import time
self._last_activity_timestamp = time.time()
async def shutdown(self) -> None:
"""
Shutdown the orchestrator and all protocol clients.
"""
log.info("Shutting down AstrbotOrchestrator...")
self._running = False
# Shutdown all protocol clients
await self.lsp.shutdown()
await self.acp.shutdown()
await self.abp.shutdown()
# MCP cleanup
await self.mcp.cleanup()
log.info("AstrbotOrchestrator shut down.")

View File

@@ -0,0 +1,13 @@
"""Internal skills module - re-exports from core.skills.skill_manager."""
from astrbot.core.skills.skill_manager import (
SkillInfo,
SkillManager,
build_skills_prompt,
)
__all__ = [
"SkillInfo",
"SkillManager",
"build_skills_prompt",
]

View File

@@ -0,0 +1,7 @@
"""
Stars (built-in plugins) for AstrBot runtime.
"""
from astrbot._internal.stars.runtime_status_star import RuntimeStatusStar
__all__ = ["RuntimeStatusStar"]

View File

@@ -0,0 +1,127 @@
"""
RuntimeStatusStar - ABP plugin that exposes core runtime internal state.
This star provides tools for querying:
- Runtime status (running state, uptime)
- Protocol client status (LSP, MCP, ACP, ABP)
- Registered stars registry
- Message counts and metrics
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
@dataclass
class RuntimeStatusStar:
"""
ABP star that exposes core runtime internal state as callable tools.
Tools provided:
- get_runtime_status: Returns running state and uptime
- get_protocol_status: Returns LSP, MCP, ACP, ABP status
- get_star_registry: Returns registered star names
- get_stats: Returns message counts and metrics
"""
name: str = "runtime-status-star"
description: str = "ABP plugin that exposes core runtime internal state"
_start_time: float = field(default_factory=time.time, init=False)
_orchestrator: Any = field(default=None, init=False)
def set_orchestrator(self, orchestrator: Any) -> None:
"""Set the orchestrator reference for status queries."""
self._orchestrator = orchestrator
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""
Handle tool calls from ABP client.
Args:
tool_name: Name of the tool to call
arguments: Tool arguments
Returns:
Tool result
"""
if tool_name == "get_runtime_status":
return self._get_runtime_status()
elif tool_name == "get_protocol_status":
return await self._get_protocol_status()
elif tool_name == "get_star_registry":
return await self._get_star_registry()
elif tool_name == "get_stats":
return self._get_stats()
else:
raise ValueError(f"Unknown tool: {tool_name}")
def _get_runtime_status(self) -> dict[str, Any]:
"""Get overall runtime state."""
running = (
getattr(self._orchestrator, "running", False)
if self._orchestrator
else False
)
uptime_seconds = time.time() - self._start_time
return {
"running": running,
"uptime_seconds": uptime_seconds,
}
async def _get_protocol_status(self) -> dict[str, Any]:
"""Get status of each protocol client."""
if not self._orchestrator:
return {
"lsp": {"connected": False, "name": "lsp-client"},
"mcp": {"connected": False, "name": "mcp-client"},
"acp": {"connected": False, "name": "acp-client"},
"abp": {"connected": False, "name": "abp-client"},
}
return {
"lsp": {
"connected": getattr(self._orchestrator.lsp, "connected", False),
"name": "lsp-client",
},
"mcp": {
"connected": getattr(self._orchestrator.mcp, "connected", False),
"name": "mcp-client",
},
"acp": {
"connected": getattr(self._orchestrator.acp, "connected", False),
"name": "acp-client",
},
"abp": {
"connected": getattr(self._orchestrator.abp, "connected", False),
"name": "abp-client",
},
}
async def _get_star_registry(self) -> dict[str, Any]:
"""Get list of registered stars."""
if not self._orchestrator:
return {"stars": []}
stars = await self._orchestrator.list_stars()
return {"stars": stars}
def _get_stats(self) -> dict[str, Any]:
"""Get message counts and metrics."""
result: dict[str, Any] = {
"uptime_seconds": time.time() - self._start_time,
}
if self._orchestrator:
result["total_messages"] = getattr(self._orchestrator, "_message_count", 0)
last_ts = getattr(self._orchestrator, "_last_activity_timestamp", None)
if last_ts is not None:
result["last_activity"] = datetime.fromtimestamp(
last_ts, tz=timezone.utc
).isoformat()
else:
result["last_activity"] = None
return result

View File

@@ -0,0 +1,5 @@
"""Internal tools module for AstrBot runtime."""
from .base import FunctionTool, ToolSet
__all__ = ["FunctionTool", "ToolSet"]

View File

@@ -0,0 +1,332 @@
"""Base tool classes for AstrBot internal runtime.
This module provides the FunctionTool base class used by MCP tools
in the new internal architecture.
"""
import copy
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator
from dataclasses import dataclass, field
from typing import Any
from pydantic import model_validator
ParametersType = dict[str, Any]
@dataclass
class ToolSchema:
"""A class representing the schema of a tool for function calling."""
name: str
"""The name of the tool."""
description: str
"""The description of the tool."""
parameters: ParametersType = field(default_factory=dict)
"""The parameters of the tool, in JSON Schema format."""
@model_validator(mode="after")
def validate_parameters(self) -> "ToolSchema":
"""Validate the parameters JSON schema."""
import jsonschema
jsonschema.validate(
self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
)
return self
@dataclass
class FunctionTool(ToolSchema):
"""A callable tool, for function calling."""
handler: Callable[..., Awaitable[str | None] | AsyncGenerator[Any, None]] | None = (
None
)
"""a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None
"""
The module path of the handler function. This is empty when the origin is mcp.
This field must be retained, as the handler will be wrapped in functools.partial during initialization,
causing the handler's __module__ to be functools
"""
active: bool = True
"""
Whether the tool is active. This field is a special field for AstrBot.
You can ignore it when integrating with other frameworks.
"""
is_background_task: bool = False
"""
Declare this tool as a background task. Background tasks return immediately
with a task identifier while the real work continues asynchronously.
"""
source: str = "mcp"
"""
Origin of this tool: 'plugin' (from star plugins), 'internal' (AstrBot built-in),
or 'mcp' (from MCP servers). Used by WebUI for display grouping.
"""
def __repr__(self) -> str:
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(self, **kwargs: Any) -> Any:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
class ToolSet:
"""
A collection of FunctionTools grouped under a namespace.
ToolSets allow organizing related tools together. The LLM sees tools
as "namespace/tool_name" when calling.
"""
def __init__(self, namespace: str, tools: list[FunctionTool] | None = None) -> None:
self.namespace = namespace
self._tools: dict[str, FunctionTool] = {}
if tools:
for tool in tools:
self.add(tool)
def add(self, tool: FunctionTool) -> None:
"""Add a tool to the set."""
self._tools[tool.name] = tool
def add_tool(self, tool: FunctionTool) -> None:
"""Add a tool to the set (alias for add())."""
self.add(tool)
def remove(self, name: str) -> FunctionTool | None:
"""Remove and return a tool by name."""
return self._tools.pop(name, None)
def remove_tool(self, name: str) -> None:
"""Remove a tool by its name."""
self._tools.pop(name, None)
def get(self, name: str) -> FunctionTool | None:
"""Get a tool by name."""
return self._tools.get(name)
def get_tool(self, name: str) -> FunctionTool | None:
"""Get a tool by name (alias for get)."""
return self.get(name)
def list_tools(self) -> list[FunctionTool]:
"""List all tools in this set."""
return list(self._tools.values())
def __iter__(self) -> Iterator[FunctionTool]:
return iter(self._tools.values())
def __len__(self) -> int:
return len(self._tools)
def __bool__(self) -> bool:
return bool(self._tools)
def __repr__(self) -> str:
return f"ToolSet(namespace={self.namespace!r}, tools={self.list_tools()!r})"
def __str__(self) -> str:
return f"ToolSet({self.namespace}, {len(self)} tools)"
def names(self) -> list[str]:
"""Get names of all tools in this set."""
return [tool.name for tool in self.tools]
def empty(self) -> bool:
"""Check if the tool set is empty."""
return len(self) == 0
def merge(self, other: "ToolSet") -> None:
"""Merge another ToolSet into this one."""
for tool in other.tools:
self.add(tool)
def normalize(self) -> None:
"""Sort tools by name for deterministic serialization."""
self._tools = dict(sorted(self._tools.items(), key=lambda x: x[0]))
def get_light_tool_set(self) -> "ToolSet":
"""Return a light tool set with only name/description."""
light_tools = []
for tool in self.tools:
if hasattr(tool, "active") and not tool.active:
continue
light_tools.append(
FunctionTool(
name=tool.name,
description=tool.description,
parameters={"type": "object", "properties": {}},
handler=None,
)
)
return ToolSet("default", light_tools)
def get_param_only_tool_set(self) -> "ToolSet":
"""Return a tool set with name/parameters only (no description)."""
param_tools = []
for tool in self.tools:
if hasattr(tool, "active") and not tool.active:
continue
params = (
copy.deepcopy(tool.parameters)
if tool.parameters
else {"type": "object", "properties": {}}
)
param_tools.append(
FunctionTool(
name=tool.name,
description="",
parameters=params,
handler=None,
)
)
return ToolSet("default", param_tools)
@property
def tools(self) -> list[FunctionTool]:
"""List all tools in this set."""
return list(self._tools.values())
def openai_schema(
self, omit_empty_parameter_field: bool = False
) -> list[dict[str, Any]]:
"""Convert tools to OpenAI API function calling schema format."""
result: list[dict[str, Any]] = []
for tool in self._tools.values():
func_def: dict[str, Any] = {
"type": "function",
"function": {"name": tool.name},
}
if tool.description:
func_def["function"]["description"] = tool.description
if tool.parameters is not None:
if (
tool.parameters.get("properties")
) or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
return result
def anthropic_schema(self) -> list[dict]:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema: dict[str, Any] = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def: dict[str, Any] = {"name": tool.name, "input_schema": input_schema}
if tool.description:
tool_def["description"] = tool.description
result.append(tool_def)
return result
def google_schema(self) -> dict:
"""Convert tools to Google GenAI API format."""
def convert_schema(schema: dict) -> dict:
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
origin_type = schema.get("type")
target_type = origin_type
if isinstance(origin_type, list):
target_type = next((t for t in origin_type if t != "null"), "string")
if target_type in supported_types:
result["type"] = target_type
if "format" in schema and schema["format"] in supported_formats.get(result["type"], set()):
result["format"] = schema["format"]
else:
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
if "additionalProperties" in prop_value:
del prop_value["additionalProperties"]
properties[key] = prop_value
if properties:
result["properties"] = properties
if target_type == "array":
items_schema = schema.get("items")
if isinstance(items_schema, dict):
result["items"] = convert_schema(items_schema)
else:
result["items"] = {"type": "string"}
return result
tools_list = []
for tool in self.tools:
d: dict[str, Any] = {"name": tool.name}
if tool.description:
d["description"] = tool.description
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools_list.append(d)
declarations: dict[str, Any] = {}
if tools_list:
declarations["function_declarations"] = tools_list
return declarations
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
"""Get tools in OpenAI function calling style (deprecated)."""
return self.openai_schema(omit_empty_parameter_field)
def get_func_desc_anthropic_style(self):
"""Get tools in Anthropic style (deprecated)."""
return self.anthropic_schema()
def get_func_desc_google_genai_style(self):
"""Get tools in Google GenAI style (deprecated)."""
return self.google_schema()

View File

@@ -0,0 +1,48 @@
"""
Builtin tools for AstrBot - re-exports from core.tools for backward compatibility.
This module re-exports the builtin tools (cron, send_message, kb_query) from
the deprecated core.tools module for backward compatibility.
TODO: These tools should be fully migrated to _internal and core.tools
should be removed once all consumers update their imports.
"""
from __future__ import annotations
# Re-export cron tools
from astrbot.core.tools.cron_tools import (
CREATE_CRON_JOB_TOOL,
DELETE_CRON_JOB_TOOL,
LIST_CRON_JOBS_TOOL,
CreateActiveCronTool,
DeleteCronJobTool,
ListCronJobsTool,
)
# Re-export knowledge_base_query tool
from astrbot.core.tools.kb_query import (
KNOWLEDGE_BASE_QUERY_TOOL,
KnowledgeBaseQueryTool,
)
# Re-export send_message tool
from astrbot.core.tools.send_message import (
SEND_MESSAGE_TO_USER_TOOL,
SendMessageToUserTool,
)
__all__ = [
# Cron tools
"CREATE_CRON_JOB_TOOL",
"DELETE_CRON_JOB_TOOL",
"KNOWLEDGE_BASE_QUERY_TOOL",
"LIST_CRON_JOBS_TOOL",
"SEND_MESSAGE_TO_USER_TOOL",
# Classes
"CreateActiveCronTool",
"DeleteCronJobTool",
"KnowledgeBaseQueryTool",
"ListCronJobsTool",
"SendMessageToUserTool",
]

View File

@@ -0,0 +1,278 @@
"""Tools registry for AstrBot internal runtime."""
from __future__ import annotations
from typing import Any
# Re-export from base
from astrbot._internal.tools.base import FunctionTool, ToolSet
__all__ = [
"DEFAULT_MCP_CONFIG",
"ENABLE_MCP_TIMEOUT_ENV",
"FuncCall",
"FunctionTool",
"FunctionToolManager",
"MCPAllServicesFailedError",
"MCPInitError",
"MCPInitSummary",
"MCPInitTimeoutError",
"MCPShutdownTimeoutError",
"ToolSet",
]
# MCP config constants (re-exported from protocols)
try:
from astrbot._internal.protocols.mcp import (
DEFAULT_MCP_CONFIG,
MCPAllServicesFailedError,
MCPInitError,
MCPInitSummary,
MCPInitTimeoutError,
MCPShutdownTimeoutError,
)
except ImportError:
DEFAULT_MCP_CONFIG: dict[str, Any] = {}
MCPAllServicesFailedError: type[Exception] = Exception
MCPInitError: type[Exception] = Exception
MCPInitSummary: type[dict] = dict
MCPInitTimeoutError: type[TimeoutError] = TimeoutError
MCPShutdownTimeoutError: type[TimeoutError] = TimeoutError
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_TIMEOUT_ENABLED"
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
class FunctionToolManager:
"""Central registry for all function tools."""
def __init__(self) -> None:
self._func_list: list[FunctionTool] = []
@property
def func_list(self) -> list[FunctionTool]:
"""Get the list of function tools."""
return self._func_list
@func_list.setter
def func_list(self, value: list[FunctionTool]) -> None:
"""Set the list of function tools."""
self._func_list = value
def add(self, tool: FunctionTool) -> None:
"""Add a tool to the registry."""
self._func_list.append(tool)
def remove(self, name: str) -> bool:
"""Remove a tool by name. Returns True if found."""
for i, f in enumerate(self._func_list):
if f.name == name:
self._func_list.pop(i)
return True
return False
def get_func(self, name: str) -> FunctionTool | None:
"""Get a tool by name. Returns the last active tool if multiple match."""
last_match: FunctionTool | None = None
for f in reversed(self._func_list):
if f.name == name:
if getattr(f, "active", True):
return f
if last_match is None:
last_match = f
return last_match
def get_full_tool_set(self) -> ToolSet:
"""Return a ToolSet with all active tools, deduplicated by name."""
seen: dict[str, FunctionTool] = {}
for tool in reversed(self._func_list):
if tool.name not in seen and getattr(tool, "active", True):
seen[tool.name] = tool
return ToolSet("default", list(seen.values()))
def register_internal_tools(self) -> None:
"""Register built-in computer tools (shell, python, browser, neo)."""
# Import here to avoid circular imports
from astrbot.core.computer.computer_tool_provider import get_all_tools
for tool in get_all_tools():
if self.get_func(tool.name) is None:
self.add(tool)
# MCP-related stub methods for base class compatibility
async def enable_mcp_server(
self, name: str, config: dict[str, Any], init_timeout: int = 30
) -> None:
"""Enable an MCP server (stub)."""
pass
async def disable_mcp_server(
self, name: str = "", timeout: int = 10, shutdown_timeout: int = 10
) -> None:
"""Disable an MCP server (stub)."""
pass
async def init_mcp_clients(self) -> None:
"""Initialize MCP clients (stub)."""
pass
async def test_mcp_server_connection(
self, config: dict[str, Any]
) -> tuple[bool, str]:
"""Test MCP server connection (stub)."""
return False, "Not implemented"
async def sync_modelscope_mcp_servers(self) -> None:
"""Sync ModelScope MCP servers (stub)."""
pass
def load_mcp_config(self) -> dict[str, Any]:
"""Load MCP configuration (stub)."""
return {"mcpServers": {}}
def save_mcp_config(self, config: dict[str, Any]) -> bool:
"""Save MCP configuration (stub)."""
return True
def activate_llm_tool(self, name: str) -> bool:
"""Activate an LLM tool (stub)."""
return True
def deactivate_llm_tool(self, name: str) -> bool:
"""Deactivate an LLM tool (stub)."""
return True
@property
def mcp_client_dict(self) -> dict[str, Any]:
"""Return dict of MCP clients (stub)."""
return {}
@property
def mcp_server_runtime_view(self) -> dict[str, Any]:
"""Return runtime view of MCP servers (stub)."""
return {}
class FuncCall(FunctionToolManager):
"""Alias for FunctionToolManager for backward compatibility."""
def __init__(self) -> None:
super().__init__()
self._mcp_server_runtime_view: dict[str, Any] = {}
self._mcp_client_dict: dict[str, Any] = {}
@property
def mcp_server_runtime_view(self) -> dict[str, Any]:
"""Return runtime view of MCP servers."""
return self._mcp_server_runtime_view
@property
def mcp_client_dict(self) -> dict[str, Any]:
"""Return dict of MCP clients (for backward compatibility)."""
return self._mcp_client_dict
async def init_mcp_clients(self) -> None:
"""Initialize MCP clients (stub implementation)."""
pass
def add_func(
self,
name: str,
func_args: list[dict[str, Any]],
desc: str,
handler: Any,
) -> None:
"""Add a function tool (deprecated, use add() instead)."""
params: dict[str, Any] = {
"type": "object",
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param.get("type", "string"),
"description": param.get("description", ""),
}
func = FunctionTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.add(func)
def remove_func(self, name: str) -> None:
"""Remove a function tool by name (deprecated, use remove() instead)."""
self.remove(name)
def get_func(self, name: str) -> FunctionTool | None:
"""Get a function tool by name."""
return super().get_func(name)
def names(self) -> list[str]:
"""Get all tool names."""
return [f.name for f in self.func_list]
def remove_tool(self, name: str) -> None:
"""Remove a tool by its name (alias for remove)."""
self.remove(name)
def get_func_desc_openai_style(
self, omit_empty_parameter_field: bool = False
) -> list[dict[str, Any]]:
"""Get tools in OpenAI style (deprecated, use get_full_tool_set().openai_schema())."""
tool_set = self.get_full_tool_set()
return tool_set.openai_schema(omit_empty_parameter_field)
async def enable_mcp_server(
self, name: str, config: dict[str, Any], init_timeout: int = 30
) -> None:
"""Enable an MCP server (stub implementation)."""
pass
async def disable_mcp_server(
self, name: str = "", timeout: int = 10, shutdown_timeout: int = 10
) -> None:
"""Disable an MCP server (stub implementation)."""
pass
def load_mcp_config(self) -> dict[str, Any]:
"""Load MCP configuration (stub implementation)."""
return {"mcpServers": {}}
def save_mcp_config(self, config: dict[str, Any]) -> bool:
"""Save MCP configuration (stub implementation)."""
return True
def activate_llm_tool(self, name: str) -> bool:
"""Activate an LLM tool (stub implementation)."""
return True
def deactivate_llm_tool(self, name: str) -> bool:
"""Deactivate an LLM tool (stub implementation)."""
return True
async def test_mcp_server_connection(
self, config: dict[str, Any]
) -> tuple[bool, str]:
"""Test MCP server connection (stub implementation)."""
# Import the actual test function if available
try:
from astrbot._internal.protocols.mcp.client import (
_quick_test_mcp_connection,
)
success, message = await _quick_test_mcp_connection(config)
if not success:
raise Exception(message)
return success, message
except Exception as e:
raise Exception(f"MCP connection test failed: {e!s}") from e
async def sync_modelscope_mcp_servers(self) -> None:
"""Sync ModelScope MCP servers (stub implementation)."""
pass
def get_full_tool_set(self) -> ToolSet:
"""Return a ToolSet with all active tools."""
return ToolSet("default", [t for t in self.func_list if t.active])

View File

@@ -158,6 +158,7 @@ class InternalAgentSubStage(Stage):
follow_up_capture: FollowUpCapture | None = None
follow_up_consumed_marked = False
follow_up_activated = False
typing_requested = False
try:
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
@@ -192,7 +193,11 @@ class InternalAgentSubStage(Stage):
)
return
await event.send_typing()
try:
typing_requested = True
await event.send_typing()
except Exception:
logger.warning("send_typing failed", exc_info=True)
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
sdk_plugin_bridge = getattr(
self.ctx.plugin_manager.context, "sdk_plugin_bridge", None
@@ -424,6 +429,11 @@ class InternalAgentSubStage(Stage):
)
await event.send(MessageChain().message(error_text))
finally:
if typing_requested:
try:
await event.stop_typing()
except Exception:
logger.warning("stop_typing failed", exc_info=True)
if follow_up_capture:
await finalize_follow_up_capture(
follow_up_capture,

View File

@@ -298,6 +298,12 @@ class AstrMessageEvent(abc.ABC):
默认实现为空,由具体平台按需重写。
"""
async def stop_typing(self) -> None:
"""停止输入中状态。
默认实现为空,由具体平台按需重写。
"""
async def _pre_send(self) -> None:
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""

View File

@@ -361,6 +361,18 @@ class TelegramPlatformAdapter(Platform):
logger.warning("Received an update without a message.")
return None
def _apply_caption() -> None:
if update.message.caption:
message.message_str = update.message.caption
message.message.append(Comp.Plain(message.message_str))
if update.message.caption and update.message.caption_entities:
for entity in update.message.caption_entities:
if entity.type == "mention":
name = update.message.caption[
entity.offset + 1 : entity.offset + entity.length
]
message.message.append(Comp.At(qq=name, name=name))
message = AstrBotMessage()
message.session_id = str(update.message.chat.id)
@@ -480,16 +492,7 @@ class TelegramPlatformAdapter(Platform):
photo = update.message.photo[-1] # get the largest photo
file = await photo.get_file()
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
if update.message.caption:
message.message_str = update.message.caption
message.message.append(Comp.Plain(message.message_str))
if update.message.caption_entities:
for entity in update.message.caption_entities:
if entity.type == "mention":
name = message.message_str[
entity.offset + 1 : entity.offset + entity.length
]
message.message.append(Comp.At(qq=name, name=name))
_apply_caption()
elif update.message.sticker:
# 将sticker当作图片处理
@@ -512,6 +515,7 @@ class TelegramPlatformAdapter(Platform):
message.message.append(
Comp.File(file=file_path, name=file_name, url=file_path)
)
_apply_caption()
elif update.message.video:
file = await update.message.video.get_file()
@@ -523,6 +527,7 @@ class TelegramPlatformAdapter(Platform):
)
else:
message.message.append(Comp.Video(file=file_path, path=file.file_path))
_apply_caption()
return message

View File

@@ -3,6 +3,7 @@ import asyncio
import aiofiles
import anyio
from wechatpy.enterprise import WeChatClient
from wechatpy.exceptions import WeChatClientException
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -96,7 +97,19 @@ class WecomPlatformEvent(AstrMessageEvent):
# Split long text messages if needed
plain_chunks = await self.split_plain(comp.text)
for chunk in plain_chunks:
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
try:
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
except WeChatClientException as e:
if getattr(e, "errcode", None) == 40096:
# 40096: invalid external userid, fallback to regular message API
logger.warning(
f"kf API error 40096 for user {user_id}, falling back to regular message API"
)
self.client.message.send_text(
self.get_self_id(), user_id, chunk
)
else:
raise
await asyncio.sleep(0.5) # Avoid sending too fast
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()

View File

@@ -6,7 +6,7 @@ import hashlib
import io
import time
import uuid
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import quote
@@ -49,6 +49,17 @@ class OpenClawLoginSession:
error: str | None = None
@dataclass
class TypingSessionState:
ticket: str | None = None
ticket_context_token: str | None = None
refresh_after: float = 0.0
keepalive_task: asyncio.Task | None = None
cancel_task: asyncio.Task | None = None
owners: set[str] = field(default_factory=set)
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
@register_platform_adapter(
"weixin_oc",
"个人微信",
@@ -105,7 +116,16 @@ class WeixinOCAdapter(Platform):
self._sync_buf = ""
self._qr_expired_count = 0
self._context_tokens: dict[str, str] = {}
self._typing_states: dict[str, TypingSessionState] = {}
self._last_inbound_error = ""
self._typing_keepalive_interval_s = max(
1,
int(platform_config.get("weixin_oc_typing_keepalive_interval", 5)),
)
self._typing_ticket_ttl_s = max(
5,
int(platform_config.get("weixin_oc_typing_ticket_ttl", 60)),
)
self.token = str(platform_config.get("weixin_oc_token", "")).strip() or None
self.account_id = (
@@ -132,6 +152,316 @@ class WeixinOCAdapter(Platform):
self.client.api_timeout_ms = self.api_timeout_ms
self.client.token = self.token
def _get_typing_state(self, user_id: str) -> TypingSessionState:
state = self._typing_states.get(user_id)
if state is None:
state = TypingSessionState()
self._typing_states[user_id] = state
return state
def _typing_supported_for(self, user_id: str) -> bool:
if not self.token:
return False
return bool(self._context_tokens.get(user_id))
async def _cancel_task_safely(
self,
task: asyncio.Task | None,
*,
log_message: str | None = None,
log_args: tuple[Any, ...] = (),
) -> None:
if task is None or task.done():
return
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
except Exception:
if log_message is not None:
logger.warning(log_message, *log_args, exc_info=True)
async def _ensure_typing_ticket(
self,
user_id: str,
state: TypingSessionState,
) -> str | None:
now = time.monotonic()
context_token = self._context_tokens.get(user_id)
if not context_token:
return None
if (
state.ticket
and state.ticket_context_token == context_token
and state.refresh_after > now
):
return state.ticket
payload = await self.client.get_typing_config(user_id, context_token)
if int(payload.get("ret") or 0) != 0:
logger.warning(
"weixin_oc(%s): getconfig failed for %s: %s",
self.meta().id,
user_id,
payload.get("errmsg", ""),
)
return None
ticket = str(payload.get("typing_ticket", "")).strip()
if not ticket:
return None
state.ticket = ticket
state.ticket_context_token = context_token
state.refresh_after = time.monotonic() + self._typing_ticket_ttl_s
return ticket
async def _send_typing_state(
self,
user_id: str,
ticket: str,
*,
cancel: bool,
) -> None:
payload = await self.client.send_typing_state(user_id, ticket, cancel=cancel)
if int(payload.get("ret") or 0) != 0:
raise RuntimeError(
f"sendtyping failed for {user_id}: {payload.get('errmsg', '')}"
)
async def _run_typing_keepalive(self, user_id: str) -> None:
restart_needed = False
try:
await self._typing_keepalive_loop(user_id)
except asyncio.CancelledError:
raise
except Exception as e:
state = self._typing_states.get(user_id)
if state is not None:
async with state.lock:
state.refresh_after = 0.0
restart_needed = (
bool(state.owners) and not self._shutdown_event.is_set()
)
logger.warning(
"weixin_oc(%s): typing keepalive failed for %s: %s",
self.meta().id,
user_id,
e,
)
finally:
state = self._typing_states.get(user_id)
current_task = asyncio.current_task()
if state is not None and state.keepalive_task is current_task:
state.keepalive_task = None
if not restart_needed:
return
await asyncio.sleep(self._typing_keepalive_interval_s)
state = self._typing_states.get(user_id)
if state is None or self._shutdown_event.is_set():
return
async with state.lock:
if not state.owners or state.keepalive_task is not None:
return
state.keepalive_task = asyncio.create_task(
self._run_typing_keepalive(user_id)
)
async def _typing_keepalive_loop(self, user_id: str) -> None:
while not self._shutdown_event.is_set():
await asyncio.sleep(self._typing_keepalive_interval_s)
state = self._typing_states.get(user_id)
if state is None:
return
async with state.lock:
if not state.owners:
return
try:
ticket = await self._ensure_typing_ticket(user_id, state)
except Exception as e:
state.refresh_after = 0.0
logger.warning(
"weixin_oc(%s): refresh typing ticket failed for %s: %s",
self.meta().id,
user_id,
e,
)
continue
if not ticket:
continue
try:
await self._send_typing_state(user_id, ticket, cancel=False)
except Exception as e:
state.refresh_after = 0.0
logger.warning(
"weixin_oc(%s): typing keepalive send failed for %s: %s",
self.meta().id,
user_id,
e,
)
async def _delayed_cancel_typing(self, user_id: str, ticket: str) -> None:
await asyncio.sleep(0)
state = self._typing_states.get(user_id)
if state is None:
return
current_task = asyncio.current_task()
async with state.lock:
if state.cancel_task is not current_task:
return
if state.owners or state.keepalive_task is not None:
state.cancel_task = None
return
try:
await self._send_typing_state(user_id, ticket, cancel=True)
except asyncio.CancelledError:
raise
except Exception as e:
logger.warning(
"weixin_oc(%s): cancel typing failed for %s: %s",
self.meta().id,
user_id,
e,
)
finally:
state = self._typing_states.get(user_id)
if state is None:
return
async with state.lock:
if state.cancel_task is current_task:
state.cancel_task = None
async def start_typing(self, user_id: str, owner_id: str) -> None:
state = self._get_typing_state(user_id)
cancel_task: asyncio.Task | None = None
async with state.lock:
if owner_id in state.owners:
return
if not self._typing_supported_for(user_id):
return
if state.cancel_task is not None and not state.cancel_task.done():
cancel_task = state.cancel_task
cancel_task.cancel()
state.cancel_task = None
try:
ticket = await self._ensure_typing_ticket(user_id, state)
except Exception as e:
logger.warning(
"weixin_oc(%s): ensure typing ticket failed for %s: %s",
self.meta().id,
user_id,
e,
)
return
if not ticket:
return
state.ticket = ticket
state.owners.add(owner_id)
if state.keepalive_task is not None and not state.keepalive_task.done():
return
try:
await self._send_typing_state(user_id, ticket, cancel=False)
except Exception as e:
state.refresh_after = 0.0
logger.warning(
"weixin_oc(%s): send typing failed for %s: %s",
self.meta().id,
user_id,
e,
)
task = asyncio.create_task(self._run_typing_keepalive(user_id))
state.keepalive_task = task
if cancel_task is not None:
await self._cancel_task_safely(
cancel_task,
log_message="weixin_oc(%s): ignored error from cancelled typing task",
log_args=(self.meta().id,),
)
async def stop_typing(self, user_id: str, owner_id: str) -> None:
state = self._typing_states.get(user_id)
if state is None:
return
task: asyncio.Task | None = None
async with state.lock:
if owner_id not in state.owners:
return
state.owners.remove(owner_id)
if state.owners:
return
task = state.keepalive_task
state.keepalive_task = None
await self._cancel_task_safely(
task,
log_message="weixin_oc(%s): typing keepalive stop failed for %s",
log_args=(self.meta().id, user_id),
)
async with state.lock:
if state.owners:
return
ticket = state.ticket
if ticket:
if state.cancel_task is None or state.cancel_task.done():
state.cancel_task = asyncio.create_task(
self._delayed_cancel_typing(user_id, ticket)
)
async def _cleanup_typing_tasks(self) -> None:
tasks: list[asyncio.Task] = []
cancels: list[tuple[str, str]] = []
for user_id, state in list(self._typing_states.items()):
if state.ticket and (
state.owners
or state.keepalive_task is not None
or state.cancel_task is not None
):
cancels.append((user_id, state.ticket))
state.owners.clear()
if state.keepalive_task is not None and not state.keepalive_task.done():
tasks.append(state.keepalive_task)
state.keepalive_task.cancel()
state.keepalive_task = None
if state.cancel_task is not None and not state.cancel_task.done():
tasks.append(state.cancel_task)
state.cancel_task.cancel()
state.cancel_task = None
for task in tasks:
await self._cancel_task_safely(
task,
log_message="weixin_oc(%s): typing cleanup failed",
log_args=(self.meta().id,),
)
for user_id, ticket in cancels:
try:
await self._send_typing_state(user_id, ticket, cancel=True)
except Exception as e:
logger.warning(
"weixin_oc(%s): typing cleanup cancel failed for %s: %s",
self.meta().id,
user_id,
e,
)
def _load_account_state(self) -> None:
if not self.token:
token = str(self.config.get("weixin_oc_token", "")).strip()
@@ -902,15 +1232,24 @@ class WeixinOCAdapter(Platform):
"weixin_oc(%s): inbound long-poll timeout",
self.meta().id,
)
except Exception as e:
logger.error(
"weixin_oc(%s): poll inbound updates failed, will retry after 5 seconds: %s",
self.meta().id,
e,
)
await asyncio.sleep(5)
except asyncio.CancelledError:
raise
except Exception as e:
logger.exception("weixin_oc(%s): run failed: %s", self.meta().id, e)
finally:
await self._cleanup_typing_tasks()
await self.client.close()
async def terminate(self) -> None:
self._shutdown_event.set()
await self._cleanup_typing_tasks()
def get_stats(self) -> dict:
stat = super().get_stats()

View File

@@ -226,3 +226,44 @@ class WeixinOCClient:
if not text:
return {}
return cast(dict[str, Any], json.loads(text))
async def get_typing_config(
self,
user_id: str,
context_token: str,
) -> dict[str, Any]:
return await self.request_json(
"POST",
"ilink/bot/getconfig",
payload={
"ilink_user_id": user_id,
"context_token": context_token,
"base_info": {
"channel_version": "astrbot",
},
},
token_required=True,
timeout_ms=self.api_timeout_ms,
)
async def send_typing_state(
self,
user_id: str,
typing_ticket: str,
*,
cancel: bool,
) -> dict[str, Any]:
return await self.request_json(
"POST",
"ilink/bot/sendtyping",
payload={
"ilink_user_id": user_id,
"typing_ticket": typing_ticket,
"status": 2 if cancel else 1,
"base_info": {
"channel_version": "astrbot",
},
},
token_required=True,
timeout_ms=self.api_timeout_ms,
)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import uuid
from typing import TYPE_CHECKING
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -29,6 +30,12 @@ class WeixinOCMessageEvent(AstrMessageEvent):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.platform = platform
self._typing_owner_id: str | None = None
def _get_typing_owner_id(self) -> str:
if not self._typing_owner_id:
self._typing_owner_id = uuid.uuid4().hex
return self._typing_owner_id
@staticmethod
def _segment_to_text(segment: BaseMessageComponent) -> str:
@@ -58,6 +65,18 @@ class WeixinOCMessageEvent(AstrMessageEvent):
await self.platform.send_by_session(self.session, message)
await super().send(message)
async def send_typing(self) -> None:
await self.platform.start_typing(
self.session.session_id,
self._get_typing_owner_id(),
)
async def stop_typing(self) -> None:
await self.platform.stop_typing(
self.session.session_id,
self._get_typing_owner_id(),
)
async def send_streaming(self, generator, use_fallback: bool = False):
if not use_fallback:
buffer = None

View File

@@ -516,7 +516,7 @@ class ProviderAnthropic(Provider):
model = model or self.get_model()
payloads = {**kwargs, "messages": new_messages, "model": model}
payloads = {"messages": new_messages, "model": model}
# Anthropic has a different way of handling system prompts
if system_prompt:
@@ -572,7 +572,7 @@ class ProviderAnthropic(Provider):
model = model or self.get_model()
payloads = {**kwargs, "messages": new_messages, "model": model}
payloads = {"messages": new_messages, "model": model}
# Anthropic has a different way of handling system prompts
if system_prompt:

View File

@@ -758,7 +758,7 @@ class ProviderGoogleGenAI(Provider):
model = model or self.get_model()
payloads = {**kwargs, "messages": context_query, "model": model}
payloads = {"messages": context_query, "model": model}
retry = 10
keys = self.api_keys.copy()
@@ -813,7 +813,7 @@ class ProviderGoogleGenAI(Provider):
model = model or self.get_model()
payloads = {**kwargs, "messages": context_query, "model": model}
payloads = {"messages": context_query, "model": model}
retry = 10
keys = self.api_keys.copy()

View File

@@ -27,8 +27,8 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
api_base = (
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
.strip()
.rstrip("/")
.rstrip("/embeddings")
.removesuffix("/")
.removesuffix("/embeddings")
)
if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"):
# /v4 see #5699

20
astrbot/rust/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
"""AstrBot Rust Core module.
This module exposes the Rust core functionality via PyO3 bindings.
"""
from ._core import (
PyAbpClient,
PyOrchestrator,
cli,
get_abp_client,
get_orchestrator,
)
__all__ = [
"PyAbpClient",
"PyOrchestrator",
"cli",
"get_abp_client",
"get_orchestrator",
]

File diff suppressed because it is too large Load Diff

View File

@@ -173,8 +173,8 @@ export function useSessions(chatboxMode = false) {
);
const currentSessionDeleted = Boolean(
currentSessionId &&
sessionIds.includes(currentSessionId) &&
!failedSessionIds.has(currentSessionId),
sessionIds.includes(currentSessionId) &&
!failedSessionIds.has(currentSessionId),
);
if (currentSessionDeleted) {

View File

@@ -17,6 +17,7 @@ const customizer = useCustomizerStore();
const { locale } = useI18n();
const route = useRoute();
const routerLoadingStore = useRouterLoadingStore();
const isCurrentChatRoute = computed(() => route.path === '/chat' || route.path.startsWith('/chat/'));
const isChatPage = computed(() => {
return route.path.startsWith("/chat");
@@ -127,7 +128,7 @@ onMounted(() => {
<v-container
fluid
class="page-wrapper"
:class="{ 'chat-mode-container': showChatPage }"
:class="{ 'chat-mode-container': isCurrentChatRoute }"
:style="{
height: showChatPage ? '100%' : 'calc(100% - 8px)',
padding: isChatPage || showChatPage ? '0' : undefined,

View File

@@ -525,7 +525,7 @@ watch(
},
);
// Merry Christmas! 🎄
// Merry Christmas!
const isChristmas = computed(() => {
const today = new Date();
const month = today.getMonth() + 1; // getMonth() 返回 0-11
@@ -1250,14 +1250,12 @@ const isChristmas = computed(() => {
.markdown-content ol {
padding-left: 24px;
/* Adds indentation to ordered lists */
margin-top: 8px;
margin-bottom: 8px;
}
.markdown-content ul {
padding-left: 24px;
/* Adds indentation to unordered lists */
margin-top: 8px;
margin-bottom: 8px;
}

View File

@@ -1,8 +1,8 @@
/**
* Persona 文件夹管理 Store
*/
import { defineStore } from 'pinia';
import axios from '@/utils/request';
import { defineStore } from "pinia";
import axios from "@/utils/request";
// 类型定义
export interface PersonaFolder {
@@ -39,11 +39,11 @@ export interface FolderTreeNode {
export interface ReorderItem {
id: string;
type: 'persona' | 'folder';
type: "persona" | "folder";
sort_order: number;
}
export const usePersonaStore = defineStore('persona', {
export const usePersonaStore = defineStore("persona", {
state: () => ({
folderTree: [] as FolderTreeNode[],
currentFolderId: null as string | null,
@@ -59,9 +59,11 @@ export const usePersonaStore = defineStore('persona', {
// 当前文件夹名称
currentFolderName(): string {
if (this.breadcrumbPath.length === 0) {
return '根目录';
return "根目录";
}
return this.breadcrumbPath[this.breadcrumbPath.length - 1]?.name || '根目录';
return (
this.breadcrumbPath[this.breadcrumbPath.length - 1]?.name || "根目录"
);
},
},
@@ -96,11 +98,11 @@ export const usePersonaStore = defineStore('persona', {
async loadFolderTree(): Promise<void> {
this.treeLoading = true;
try {
const response = await axios.get('/api/persona/folder/tree');
if (response.data.status === 'ok') {
const response = await axios.get("/api/persona/folder/tree");
if (response.data.status === "ok") {
this.folderTree = response.data.data || [];
} else {
throw new Error(response.data.message || '获取文件夹树失败');
throw new Error(response.data.message || "获取文件夹树失败");
}
} finally {
this.treeLoading = false;
@@ -117,19 +119,19 @@ export const usePersonaStore = defineStore('persona', {
// 并行加载子文件夹和 Persona
const [foldersRes, personasRes] = await Promise.all([
axios.get('/api/persona/folder/list', {
params: { parent_id: folderId ?? '' }
axios.get("/api/persona/folder/list", {
params: { parent_id: folderId ?? "" },
}),
axios.get('/api/persona/list', {
params: { folder_id: folderId ?? '' }
axios.get("/api/persona/list", {
params: { folder_id: folderId ?? "" },
}),
]);
if (foldersRes.data.status === 'ok') {
if (foldersRes.data.status === "ok") {
this.currentFolders = foldersRes.data.data || [];
}
if (personasRes.data.status === 'ok') {
if (personasRes.data.status === "ok") {
this.currentPersonas = personasRes.data.data || [];
}
@@ -179,41 +181,41 @@ export const usePersonaStore = defineStore('persona', {
/**
* 移动 Persona 到文件夹
*/
async movePersonaToFolder(personaId: string, targetFolderId: string | null): Promise<void> {
const response = await axios.post('/api/persona/move', {
async movePersonaToFolder(
personaId: string,
targetFolderId: string | null,
): Promise<void> {
const response = await axios.post("/api/persona/move", {
persona_id: personaId,
folder_id: targetFolderId
folder_id: targetFolderId,
});
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '移动人格失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "移动人格失败");
}
// 刷新当前文件夹内容和文件夹树
await Promise.all([
this.refreshCurrentFolder(),
this.loadFolderTree(),
]);
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
},
/**
* 移动文件夹到另一个文件夹
*/
async moveFolderToFolder(folderId: string, targetParentId: string | null): Promise<void> {
const response = await axios.post('/api/persona/folder/update', {
async moveFolderToFolder(
folderId: string,
targetParentId: string | null,
): Promise<void> {
const response = await axios.post("/api/persona/folder/update", {
folder_id: folderId,
parent_id: targetParentId
parent_id: targetParentId,
});
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '移动文件夹失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "移动文件夹失败");
}
// 刷新当前文件夹内容和文件夹树
await Promise.all([
this.refreshCurrentFolder(),
this.loadFolderTree(),
]);
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
},
/**
@@ -224,20 +226,17 @@ export const usePersonaStore = defineStore('persona', {
parent_id?: string | null;
description?: string;
}): Promise<PersonaFolder> {
const response = await axios.post('/api/persona/folder/create', {
const response = await axios.post("/api/persona/folder/create", {
...data,
parent_id: data.parent_id ?? this.currentFolderId,
});
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '创建文件夹失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "创建文件夹失败");
}
// 刷新当前文件夹内容和文件夹树
await Promise.all([
this.refreshCurrentFolder(),
this.loadFolderTree(),
]);
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
return response.data.data.folder;
},
@@ -250,48 +249,42 @@ export const usePersonaStore = defineStore('persona', {
name?: string;
description?: string;
}): Promise<void> {
const response = await axios.post('/api/persona/folder/update', data);
const response = await axios.post("/api/persona/folder/update", data);
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '更新文件夹失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "更新文件夹失败");
}
// 刷新当前文件夹内容和文件夹树
await Promise.all([
this.refreshCurrentFolder(),
this.loadFolderTree(),
]);
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
},
/**
* 删除文件夹
*/
async deleteFolder(folderId: string): Promise<void> {
const response = await axios.post('/api/persona/folder/delete', {
folder_id: folderId
const response = await axios.post("/api/persona/folder/delete", {
folder_id: folderId,
});
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '删除文件夹失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "删除文件夹失败");
}
// 刷新当前文件夹内容和文件夹树
await Promise.all([
this.refreshCurrentFolder(),
this.loadFolderTree(),
]);
await Promise.all([this.refreshCurrentFolder(), this.loadFolderTree()]);
},
/**
* 删除 Persona
*/
async deletePersona(personaId: string): Promise<void> {
const response = await axios.post('/api/persona/delete', {
persona_id: personaId
const response = await axios.post("/api/persona/delete", {
persona_id: personaId,
});
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '删除人格失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "删除人格失败");
}
// 刷新当前文件夹内容
@@ -301,14 +294,17 @@ export const usePersonaStore = defineStore('persona', {
/**
* 克隆 Persona
*/
async clonePersona(sourcePersonaId: string, newPersonaId: string): Promise<Persona> {
const response = await axios.post('/api/persona/clone', {
async clonePersona(
sourcePersonaId: string,
newPersonaId: string,
): Promise<Persona> {
const response = await axios.post("/api/persona/clone", {
source_persona_id: sourcePersonaId,
new_persona_id: newPersonaId
new_persona_id: newPersonaId,
});
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '克隆人格失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "克隆人格失败");
}
// 刷新当前文件夹内容
@@ -321,10 +317,10 @@ export const usePersonaStore = defineStore('persona', {
* 批量更新排序
*/
async reorderItems(items: ReorderItem[]): Promise<void> {
const response = await axios.post('/api/persona/reorder', { items });
const response = await axios.post("/api/persona/reorder", { items });
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '更新排序失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "更新排序失败");
}
// 刷新当前文件夹内容
@@ -354,7 +350,7 @@ export const usePersonaStore = defineStore('persona', {
* 导入人格数据
*/
async importPersona(data: Partial<Persona>): Promise<Persona> {
const response = await axios.post('/api/persona/create', {
const response = await axios.post("/api/persona/create", {
persona_id: data.persona_id,
system_prompt: data.system_prompt,
begin_dialogs: data.begin_dialogs || [],
@@ -362,8 +358,8 @@ export const usePersonaStore = defineStore('persona', {
skills: data.skills,
});
if (response.data.status !== 'ok') {
throw new Error(response.data.message || '导入人格失败');
if (response.data.status !== "ok") {
throw new Error(response.data.message || "导入人格失败");
}
// 刷新当前文件夹内容
@@ -371,5 +367,5 @@ export const usePersonaStore = defineStore('persona', {
return response.data.data.persona;
},
}
},
});

View File

@@ -1,7 +1,7 @@
<template>
<v-card
class="persona-card"
:class="{ 'dragging': isDragging }"
:class="{ dragging: isDragging }"
rounded="lg"
elevation="1"
hover
@@ -27,50 +27,38 @@
<v-list density="compact">
<v-list-item @click.stop="$emit('edit')">
<template #prepend>
<v-icon size="small">
mdi-pencil
</v-icon>
<v-icon size="small"> mdi-pencil </v-icon>
</template>
<v-list-item-title>{{ tm('buttons.edit') }}</v-list-item-title>
<v-list-item-title>{{ tm("buttons.edit") }}</v-list-item-title>
</v-list-item>
<v-list-item @click.stop="$emit('clone')">
<template #prepend>
<v-icon size="small">
mdi-content-copy
</v-icon>
<v-icon size="small"> mdi-content-copy </v-icon>
</template>
<v-list-item-title>{{ tm('buttons.clone') }}</v-list-item-title>
<v-list-item-title>{{ tm("buttons.clone") }}</v-list-item-title>
</v-list-item>
<v-list-item @click.stop="$emit('move')">
<template #prepend>
<v-icon size="small">
mdi-folder-move
</v-icon>
<v-icon size="small"> mdi-folder-move </v-icon>
</template>
<v-list-item-title>{{ tm('persona.contextMenu.moveTo') }}</v-list-item-title>
<v-list-item-title>{{
tm("persona.contextMenu.moveTo")
}}</v-list-item-title>
</v-list-item>
<v-list-item @click.stop="$emit('export')">
<template #prepend>
<v-icon size="small">
mdi-download
</v-icon>
<v-icon size="small"> mdi-download </v-icon>
</template>
<v-list-item-title>{{ tm('persona.contextMenu.export') }}</v-list-item-title>
<v-list-item-title>{{
tm("persona.contextMenu.export")
}}</v-list-item-title>
</v-list-item>
<v-divider class="my-1" />
<v-list-item
class="text-error"
@click.stop="$emit('delete')"
>
<v-list-item class="text-error" @click.stop="$emit('delete')">
<template #prepend>
<v-icon
size="small"
color="error"
>
mdi-delete
</v-icon>
<v-icon size="small" color="error"> mdi-delete </v-icon>
</template>
<v-list-item-title>{{ tm('buttons.delete') }}</v-list-item-title>
<v-list-item-title>{{ tm("buttons.delete") }}</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
@@ -89,7 +77,11 @@
variant="tonal"
prepend-icon="mdi-chat"
>
{{ tm('labels.presetDialogs', { count: persona.begin_dialogs.length / 2 }) }}
{{
tm("labels.presetDialogs", {
count: persona.begin_dialogs.length / 2,
})
}}
</v-chip>
<v-chip
v-if="persona.tools === null"
@@ -98,7 +90,7 @@
variant="tonal"
prepend-icon="mdi-tools"
>
{{ tm('form.allToolsAvailable') }}
{{ tm("form.allToolsAvailable") }}
</v-chip>
<v-chip
v-else-if="persona.tools && persona.tools.length > 0"
@@ -107,7 +99,7 @@
variant="tonal"
prepend-icon="mdi-tools"
>
{{ persona.tools.length }} {{ tm('persona.toolsCount') }}
{{ persona.tools.length }} {{ tm("persona.toolsCount") }}
</v-chip>
<v-chip
v-if="persona.skills === null"
@@ -116,7 +108,7 @@
variant="tonal"
prepend-icon="mdi-lightning-bolt"
>
{{ tm('form.allSkillsAvailable') }}
{{ tm("form.allSkillsAvailable") }}
</v-chip>
<v-chip
v-else-if="persona.skills && persona.skills.length > 0"
@@ -125,142 +117,139 @@
variant="tonal"
prepend-icon="mdi-lightning-bolt"
>
{{ persona.skills.length }} {{ tm('persona.skillsCount') }}
{{ persona.skills.length }} {{ tm("persona.skillsCount") }}
</v-chip>
</div>
<div class="mt-3 text-caption text-medium-emphasis">
{{ tm('labels.createdAt') }}: {{ formatDate(persona.created_at) }}
{{ tm("labels.createdAt") }}: {{ formatDate(persona.created_at) }}
</div>
</v-card-text>
</v-card>
<!-- Custom Drag Preview -->
<div
ref="dragPreview"
class="drag-preview"
>
<v-icon
size="small"
class="mr-2"
>
mdi-account
</v-icon>
<div ref="dragPreview" class="drag-preview">
<v-icon size="small" class="mr-2"> mdi-account </v-icon>
<span class="text-subtitle-2">{{ persona.persona_id }}</span>
</div>
</template>
<script lang="ts">
import { defineComponent, type PropType } from 'vue';
import { useModuleI18n } from '@/i18n/composables';
import { defineComponent, type PropType } from "vue";
import { useModuleI18n } from "@/i18n/composables";
interface Persona {
persona_id: string;
system_prompt: string;
custom_error_message?: string | null;
begin_dialogs?: string[] | null;
tools?: string[] | null;
skills?: string[] | null;
created_at?: string;
updated_at?: string;
folder_id?: string | null;
[key: string]: any;
persona_id: string;
system_prompt: string;
custom_error_message?: string | null;
begin_dialogs?: string[] | null;
tools?: string[] | null;
skills?: string[] | null;
created_at?: string;
updated_at?: string;
folder_id?: string | null;
[key: string]: any;
}
export default defineComponent({
name: 'PersonaCard',
props: {
persona: {
type: Object as PropType<Persona>,
required: true
}
name: "PersonaCard",
props: {
persona: {
type: Object as PropType<Persona>,
required: true,
},
emits: ['view', 'edit', 'clone', 'move', 'export', 'delete'],
setup() {
const { tm } = useModuleI18n('features/persona');
return { tm };
},
data() {
return {
isDragging: false
};
},
methods: {
handleDragStart(event: DragEvent) {
this.isDragging = true;
if (event.dataTransfer) {
event.dataTransfer.effectAllowed = 'move';
event.dataTransfer.setData('application/json', JSON.stringify({
type: 'persona',
persona_id: this.persona.persona_id,
persona: this.persona
}));
},
emits: ["view", "edit", "clone", "move", "export", "delete"],
setup() {
const { tm } = useModuleI18n("features/persona");
return { tm };
},
data() {
return {
isDragging: false,
};
},
methods: {
handleDragStart(event: DragEvent) {
this.isDragging = true;
if (event.dataTransfer) {
event.dataTransfer.effectAllowed = "move";
event.dataTransfer.setData(
"application/json",
JSON.stringify({
type: "persona",
persona_id: this.persona.persona_id,
persona: this.persona,
}),
);
// Set custom drag image
const dragPreview = this.$refs.dragPreview as HTMLElement;
if (dragPreview) {
event.dataTransfer.setDragImage(dragPreview, 15, 15);
}
}
},
handleDragEnd() {
this.isDragging = false;
},
truncateText(text: string | undefined | null, maxLength: number): string {
if (!text) return '';
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
},
formatDate(dateString: string | undefined | null): string {
if (!dateString) return '';
return new Date(dateString).toLocaleString();
// Set custom drag image
const dragPreview = this.$refs.dragPreview as HTMLElement;
if (dragPreview) {
event.dataTransfer.setDragImage(dragPreview, 15, 15);
}
}
}
},
handleDragEnd() {
this.isDragging = false;
},
truncateText(text: string | undefined | null, maxLength: number): string {
if (!text) return "";
return text.length > maxLength
? text.substring(0, maxLength) + "..."
: text;
},
formatDate(dateString: string | undefined | null): string {
if (!dateString) return "";
return new Date(dateString).toLocaleString();
},
},
});
</script>
<style scoped>
.persona-card {
height: 100%;
cursor: grab;
transition: all 0.2s ease;
height: 100%;
cursor: grab;
transition: all 0.2s ease;
}
.persona-card:active {
cursor: grabbing;
cursor: grabbing;
}
.persona-card.dragging {
opacity: 0.5;
transform: scale(0.95);
opacity: 0.5;
transform: scale(0.95);
}
.persona-card:hover {
transform: translateY(-2px);
transform: translateY(-2px);
}
.system-prompt-preview {
font-size: 14px;
line-height: 1.4;
color: rgba(var(--v-theme-on-surface), 0.7);
overflow: hidden;
display: -webkit-box;
-webkit-line-clamp: 3;
line-clamp: 3;
-webkit-box-orient: vertical;
font-size: 14px;
line-height: 1.4;
color: rgba(var(--v-theme-on-surface), 0.7);
overflow: hidden;
display: -webkit-box;
-webkit-line-clamp: 3;
line-clamp: 3;
-webkit-box-orient: vertical;
}
.drag-preview {
position: fixed;
top: -1000px;
left: -1000px;
background: rgb(var(--v-theme-surface));
padding: 12px 20px;
border-radius: 8px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
display: flex;
align-items: center;
border: 1px solid rgba(var(--v-border-color), var(--v-border-opacity));
z-index: 9999;
pointer-events: none;
position: fixed;
top: -1000px;
left: -1000px;
background: rgb(var(--v-theme-surface));
padding: 12px 20px;
border-radius: 8px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
display: flex;
align-items: center;
border: 1px solid rgba(var(--v-border-color), var(--v-border-opacity));
z-index: 9999;
pointer-events: none;
}
</style>

File diff suppressed because it is too large Load Diff

View File

@@ -33,7 +33,8 @@ def create_mock_telegram_modules():
mock_telegram_ext = MagicMock()
mock_telegram_ext.ApplicationBuilder = MagicMock
mock_telegram_ext.ContextTypes = MagicMock
mock_telegram_ext.ContextTypes = MagicMock()
mock_telegram_ext.ContextTypes.DEFAULT_TYPE = MagicMock
mock_telegram_ext.ExtBot = MagicMock
mock_telegram_ext.filters = MagicMock()
mock_telegram_ext.filters.ALL = MagicMock()

View File

@@ -0,0 +1,108 @@
import asyncio
import importlib
import sys
from unittest.mock import MagicMock, patch
import pytest
import astrbot.api.message_components as Comp
from tests.fixtures.helpers import (
create_mock_file,
create_mock_update,
make_platform_config,
)
from tests.fixtures.mocks.telegram import create_mock_telegram_modules
_TELEGRAM_PLATFORM_ADAPTER = None
def _load_telegram_adapter():
global _TELEGRAM_PLATFORM_ADAPTER
if _TELEGRAM_PLATFORM_ADAPTER is not None:
return _TELEGRAM_PLATFORM_ADAPTER
mocks = create_mock_telegram_modules()
patched_modules = {
"telegram": mocks["telegram"],
"telegram.constants": mocks["telegram"].constants,
"telegram.error": mocks["telegram"].error,
"telegram.ext": mocks["telegram.ext"],
"telegramify_markdown": mocks["telegramify_markdown"],
"apscheduler": mocks["apscheduler"],
"apscheduler.schedulers": mocks["apscheduler"].schedulers,
"apscheduler.schedulers.asyncio": mocks["apscheduler"].schedulers.asyncio,
"apscheduler.schedulers.background": mocks["apscheduler"].schedulers.background,
}
with patch.dict(sys.modules, patched_modules):
sys.modules.pop("astrbot.core.platform.sources.telegram.tg_adapter", None)
module = importlib.import_module("astrbot.core.platform.sources.telegram.tg_adapter")
_TELEGRAM_PLATFORM_ADAPTER = module.TelegramPlatformAdapter
return _TELEGRAM_PLATFORM_ADAPTER
def _build_context() -> MagicMock:
context = MagicMock()
context.bot.username = "test_bot"
context.bot.id = 12345678
return context
@pytest.mark.asyncio
async def test_telegram_document_caption_populates_message_text_and_plain():
TelegramPlatformAdapter = _load_telegram_adapter()
adapter = TelegramPlatformAdapter(
make_platform_config("telegram"),
{},
asyncio.Queue(),
)
document = create_mock_file("https://api.telegram.org/file/test/report.md")
document.file_name = "report.md"
mention = MagicMock(type="mention", offset=0, length=6)
update = create_mock_update(
message_text=None,
document=document,
caption="@alice 请总结这份文档",
caption_entities=[mention],
)
result = await adapter.convert_message(update, _build_context())
assert result is not None
assert result.message_str == "@alice 请总结这份文档"
assert any(isinstance(component, Comp.File) for component in result.message)
assert any(
isinstance(component, Comp.Plain)
and component.text == "@alice 请总结这份文档"
for component in result.message
)
assert any(
isinstance(component, Comp.At) and component.qq == "alice"
for component in result.message
)
@pytest.mark.asyncio
async def test_telegram_video_caption_populates_message_text_and_plain():
TelegramPlatformAdapter = _load_telegram_adapter()
adapter = TelegramPlatformAdapter(
make_platform_config("telegram"),
{},
asyncio.Queue(),
)
video = create_mock_file("https://api.telegram.org/file/test/lesson.mp4")
video.file_name = "lesson.mp4"
update = create_mock_update(
message_text=None,
video=video,
caption="这段视频讲了什么",
)
result = await adapter.convert_message(update, _build_context())
assert result is not None
assert result.message_str == "这段视频讲了什么"
assert any(isinstance(component, Comp.Video) for component in result.message)
assert any(
isinstance(component, Comp.Plain) and component.text == "这段视频讲了什么"
for component in result.message
)

View File

@@ -651,6 +651,15 @@ class TestSendTyping:
await astr_message_event.send_typing()
class TestStopTyping:
"""Tests for stop_typing method."""
@pytest.mark.asyncio
async def test_stop_typing_default_empty(self, astr_message_event):
"""Test stop_typing default implementation is empty."""
await astr_message_event.stop_typing()
class TestReact:
"""Tests for react method."""
@@ -772,10 +781,12 @@ class TestDefensiveGetattr:
def test_get_message_type_with_non_enum_type(self, astr_message_event):
"""get_message_type should handle message_obj.type that is not a MessageType."""
class DummyMessage:
def __init__(self):
self.type = "not_an_enum"
self.message = []
astr_message_event.message_obj = DummyMessage()
message_type = astr_message_event.get_message_type()
assert isinstance(message_type, MessageType)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -22,113 +22,184 @@ class FakeReader(ByteReceiveStream):
return None
@pytest.mark.asyncio
async def test_lsp_read_responses_failure_disconnects_and_logs():
"""Test reader failures are handled inside _read_responses."""
client = AstrbotLspClient()
client._connected = True
client._reader = FakeReader(AsyncMock(side_effect=RuntimeError("reader crashed")))
class TestAstrbotLspClientInitialState:
"""Test LSP client initial state."""
with patch("astrbot._internal.protocols.lsp.client.log") as mock_log:
await client._read_responses()
assert client.connected is False
mock_log.error.assert_called_once()
def test_client_initial_state(self) -> None:
"""Test client starts disconnected."""
client = AstrbotLspClient()
assert client.connected is False
assert client._reader is None
assert client._writer is None
assert client._task_group is None
@pytest.mark.asyncio
async def test_lsp_read_responses_unexpected_exit_disconnects_and_warns():
"""Test non-cancelled reader exit updates connection state."""
client = AstrbotLspClient()
client._connected = True
client._reader = FakeReader(AsyncMock(return_value=b""))
class TestAstrbotLspClientConnect:
"""Test LSP client connect method."""
with patch("astrbot._internal.protocols.lsp.client.log") as mock_log:
await client._read_responses()
assert client.connected is False
mock_log.warning.assert_called_once()
@pytest.mark.asyncio
async def test_lsp_read_responses_clears_reader_task_reference_on_exit():
"""Test _read_responses clears the stored task reference when it exits."""
client = AstrbotLspClient()
client._connected = True
client._reader = FakeReader(AsyncMock(return_value=b""))
task = asyncio.create_task(client._read_responses())
client._reader_task = task
await task
assert client._reader_task is None
@pytest.mark.asyncio
async def test_lsp_stop_reader_task_swallows_failed_reader_exceptions():
"""Test reader teardown does not re-raise prior reader failures."""
client = AstrbotLspClient()
async def fail_reader() -> None:
raise RuntimeError("reader crashed")
client._reader_task = asyncio.create_task(fail_reader())
await asyncio.sleep(0)
await client._stop_reader_task()
assert client._reader_task is None
@pytest.mark.asyncio
async def test_lsp_connect_to_server_cancels_previous_reader_task_before_restart():
"""Test reconnect tears down an existing reader task before replacing it."""
client = AstrbotLspClient()
fake_process = SimpleNamespace(stdout=MagicMock(), stdin=MagicMock())
first_reader_cancelled = asyncio.Event()
first_reader_started = asyncio.Event()
async def first_reader() -> None:
first_reader_started.set()
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
first_reader_cancelled.set()
raise
with (
patch(
"astrbot._internal.protocols.lsp.client.anyio.open_process",
AsyncMock(return_value=fake_process),
),
patch.object(client, "send_request", AsyncMock(return_value={})),
patch.object(client, "send_notification", AsyncMock()),
):
client._read_responses = first_reader # type: ignore[method-assign]
await client.connect_to_server(["python", "first_lsp.py"], "file:///tmp")
await asyncio.wait_for(first_reader_started.wait(), timeout=1)
assert client.connected is True
second_reader = AsyncMock(return_value=None)
client._read_responses = second_reader # type: ignore[method-assign]
await client.connect_to_server(["python", "second_lsp.py"], "file:///tmp")
await asyncio.sleep(0)
assert first_reader_cancelled.is_set() is True
@pytest.mark.asyncio
async def test_connect_sets_connected_true(self) -> None:
"""Test connect() sets connected state."""
client = AstrbotLspClient()
await client.connect()
assert client.connected is True
@pytest.mark.asyncio
async def test_lsp_stop_reader_task_does_not_await_current_task():
"""Test stopping the reader from within itself does not self-await."""
client = AstrbotLspClient()
done = asyncio.Event()
class TestAstrbotLspClientSendRequest:
"""Test LSP client send_request method."""
async def stop_self() -> None:
client._reader_task = asyncio.current_task()
await client._stop_reader_task()
done.set()
@pytest.mark.asyncio
async def test_send_request_requires_connection(self) -> None:
"""Test send_request raises when not connected."""
client = AstrbotLspClient()
with pytest.raises(RuntimeError, match="not connected"):
await client.send_request("initialize", {})
task = asyncio.create_task(stop_self())
await asyncio.wait_for(done.wait(), timeout=1)
await task
@pytest.mark.asyncio
async def test_send_request_formats_jsonrpc_correctly(self) -> None:
"""Test send_request formats message as JSON-RPC 2.0."""
client = AstrbotLspClient()
client._connected = True
mock_writer = AsyncMock()
mock_event = MagicMock()
mock_event.wait = AsyncMock()
client._writer = mock_writer
client._pending_requests[0] = AsyncMock()
with patch("astrbot._internal.protocols.lsp.client.anyio.Event", return_value=mock_event):
# Timeout immediately to avoid hanging
with pytest.raises(TimeoutError, match="timed out"):
await client.send_request("initialize", {"processId": None})
class TestAstrbotLspClientSendNotification:
"""Test LSP client send_notification method."""
@pytest.mark.asyncio
async def test_send_notification_requires_connection(self) -> None:
"""Test send_notification raises when not connected."""
client = AstrbotLspClient()
with pytest.raises(RuntimeError, match="not connected"):
await client.send_notification("initialized", {})
@pytest.mark.asyncio
async def test_send_notification_formats_jsonrpc_correctly(self) -> None:
"""Test send_notification formats message as JSON-RPC 2.0."""
client = AstrbotLspClient()
client._connected = True
mock_writer = AsyncMock()
client._writer = mock_writer
await client.send_notification("initialized", {})
mock_writer.send.assert_called_once()
data = mock_writer.send.call_args[0][0]
decoded = data.decode("utf-8")
assert "Content-Length:" in decoded
assert '"jsonrpc": "2.0"' in decoded
assert '"method": "initialized"' in decoded
class TestAstrbotLspClientShutdown:
"""Test LSP client shutdown method."""
@pytest.mark.asyncio
async def test_shutdown_sets_connected_false(self) -> None:
"""Test shutdown disconnects the client."""
client = AstrbotLspClient()
client._connected = True
client._task_group = MagicMock()
client._task_group.__aexit__ = AsyncMock()
client._server_process = MagicMock()
client._server_process.terminate = MagicMock()
client._server_process.wait = AsyncMock()
client._server_process.kill = MagicMock()
client.send_notification = AsyncMock()
await client.shutdown()
assert client.connected is False
@pytest.mark.asyncio
async def test_shutdown_clears_pending_requests(self) -> None:
"""Test shutdown clears pending requests."""
client = AstrbotLspClient()
client._connected = True
client._pending_requests[1] = AsyncMock()
client._task_group = MagicMock()
client._task_group.__aexit__ = AsyncMock()
client._server_process = None
await client.shutdown()
assert len(client._pending_requests) == 0
class TestAstrbotLspClientReadResponses:
"""Test LSP client _read_responses method."""
@pytest.mark.asyncio
async def test_read_responses_returns_immediately_if_no_reader(self) -> None:
"""Test _read_responses exits early when _reader is None."""
client = AstrbotLspClient()
client._reader = None
client._connected = True
await client._read_responses()
# Should return without error
assert True
@pytest.mark.asyncio
async def test_read_responses_handles_empty_data_as_eof(self) -> None:
"""Test _read_responses breaks on empty data (EOF)."""
client = AstrbotLspClient()
client._connected = True
client._reader = FakeReader(AsyncMock(return_value=b""))
client._pending_requests = {}
# Should exit cleanly without raising
await client._read_responses()
assert client._connected is True # Note: current impl doesn't auto-disconnect on EOF
@pytest.mark.asyncio
async def test_read_responses_parses_jsonrpc_response(self) -> None:
"""Test _read_responses parses and dispatches JSON-RPC responses."""
client = AstrbotLspClient()
client._connected = True
response = {"jsonrpc": "2.0", "id": 0, "result": {}}
content = json.dumps(response).encode()
header = f"Content-Length: {len(content)}\r\n\r\n".encode()
# First call returns the message, second call returns empty (EOF)
fake_reader = FakeReader(AsyncMock(side_effect=[header + content, b""]))
client._reader = fake_reader
handler_called = False
async def handler(resp: dict) -> None:
nonlocal handler_called
handler_called = True
client._pending_requests[0] = handler
await client._read_responses()
assert handler_called is True
class TestAstrbotLspClientHandleNotification:
"""Test LSP client _handle_notification method."""
@pytest.mark.asyncio
async def test_handle_notification_logs_method_name(self) -> None:
"""Test _handle_notification logs the notification method."""
client = AstrbotLspClient()
notification = {"jsonrpc": "2.0", "method": "window/showMessage", "params": {}}
with patch("astrbot._internal.protocols.lsp.client.log") as mock_log:
await client._handle_notification(notification)
mock_log.debug.assert_called()