mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-04 03:30:15 +08:00
Compare commits
1 Commits
pr-5943-de
...
copilot/cr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c46cec8742 |
185
.env.example
185
.env.example
@@ -1,184 +1,93 @@
|
||||
# ==========================================
|
||||
# AstrBot Instance Configuration: ${INSTANCE_NAME}
|
||||
# AstrBot 实例配置文件:${INSTANCE_NAME}
|
||||
# AstrBot Environment Configuration Example
|
||||
# ==========================================
|
||||
# 将此文件复制为 .env 并根据需要修改。
|
||||
# Copy this file to .env and modify as needed.
|
||||
# 注意:在此处设置的变量将覆盖默认配置。
|
||||
# Note: Variables set here override application defaults.
|
||||
# Copy this file to .env and adjust the values as needed.
|
||||
# Note: Variables set here will override default configurations.
|
||||
|
||||
# ------------------------------------------
|
||||
# 实例标识 / Instance Identity
|
||||
# Core Configuration (核心配置)
|
||||
# ------------------------------------------
|
||||
|
||||
# 实例名称(用于日志和服务名)
|
||||
# Instance name (used in logs/service names)
|
||||
INSTANCE_NAME="${INSTANCE_NAME}"
|
||||
# AstrBot root directory path. Defaults to current working directory or ~/.astrbot for desktop client.
|
||||
# ASTRBOT_ROOT=/path/to/astrbot
|
||||
|
||||
# ------------------------------------------
|
||||
# 核心配置 / Core Configuration
|
||||
# ------------------------------------------
|
||||
|
||||
# AstrBot 根目录路径
|
||||
# AstrBot root directory path
|
||||
# 默认 Default: 当前工作目录,桌面客户端为 ~/.astrbot,服务器为 /var/lib/astrbot/<instance>/
|
||||
# 示例 Example: /var/lib/astrbot/mybot
|
||||
ASTRBOT_ROOT="${ASTRBOT_ROOT}"
|
||||
|
||||
# 日志等级
|
||||
# Log level
|
||||
# 可选值 Values: DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
# 默认 Default: INFO
|
||||
# Log level. Options: DEBUG, INFO, WARNING, ERROR, CRITICAL. Default: INFO.
|
||||
# ASTRBOT_LOG_LEVEL=INFO
|
||||
|
||||
# 启用插件热重载(开发时有用)
|
||||
# Enable plugin hot reload (useful for development)
|
||||
# 可选值 Values: 0 (禁用 disabled), 1 (启用 enabled)
|
||||
# 默认 Default: 0
|
||||
# Enable plugin auto-reload. Set to "1" to enable. Useful for development.
|
||||
# ASTRBOT_RELOAD=0
|
||||
|
||||
# 禁用匿名使用统计
|
||||
# Disable anonymous usage statistics
|
||||
# 可选值 Values: 0 (启用统计 enabled), 1 (禁用统计 disabled)
|
||||
# 默认 Default: 0
|
||||
ASTRBOT_DISABLE_METRICS=0
|
||||
# Disable metrics upload. Set to "1" to disable anonymous usage statistics.
|
||||
# ASTRBOT_DISABLE_METRICS=0
|
||||
|
||||
# 覆盖 Python 可执行文件路径(用于本地代码执行功能)
|
||||
# Override Python executable path (for local code execution)
|
||||
# 示例 Example: /usr/bin/python3, /home/user/.pyenv/shims/python
|
||||
# Python executable path override (used for local code execution feature).
|
||||
# PYTHON=/usr/bin/python3
|
||||
|
||||
# 启用演示模式(可能限制部分功能)
|
||||
# Enable demo mode (may restrict certain features)
|
||||
# 可选值 Values: True, False
|
||||
# 默认 Default: False
|
||||
# Enable demo mode (might restrict some features).
|
||||
# DEMO_MODE=False
|
||||
|
||||
# 启用测试模式(影响日志和部分行为)
|
||||
# Enable testing mode (affects logging and behavior)
|
||||
# 可选值 Values: True, False
|
||||
# 默认 Default: False
|
||||
# Enable testing mode (affects logging and some behaviors).
|
||||
# TESTING=False
|
||||
|
||||
# 标记:是否通过桌面客户端执行(主要用于内部)
|
||||
# Flag: running via desktop client (internal use)
|
||||
# 可选值 Values: 0, 1
|
||||
# Flag indicating execution via desktop client (Internal use mostly).
|
||||
# ASTRBOT_DESKTOP_CLIENT=0
|
||||
|
||||
# 标记:是否通过 systemd 服务执行
|
||||
# Flag: running via systemd service
|
||||
# 可选值 Values: 0, 1
|
||||
ASTRBOT_SYSTEMD=1
|
||||
# Flag indicating execution via systemd service.
|
||||
# ASTRBOT_SYSTEMD=0
|
||||
|
||||
# ------------------------------------------
|
||||
# 管理面板配置 / Dashboard Configuration
|
||||
# Dashboard Configuration (管理面板配置)
|
||||
# ------------------------------------------
|
||||
|
||||
# 启用或禁用 WebUI 管理面板
|
||||
# Enable or disable WebUI dashboard
|
||||
# 可选值 Values: True, False
|
||||
# 默认 Default: True
|
||||
ASTRBOT_DASHBOARD_ENABLE=True
|
||||
# Enable or disable the WebUI Dashboard. Default: True.
|
||||
# ASTRBOT_DASHBOARD_ENABLE=True
|
||||
|
||||
# Dashboard bind host. Default: 0.0.0.0 (listen on all interfaces).
|
||||
# ASTRBOT_DASHBOARD_HOST=0.0.0.0
|
||||
|
||||
# Dashboard bind port. Default: 6185.
|
||||
# ASTRBOT_DASHBOARD_PORT=6185
|
||||
|
||||
# Enable SSL (HTTPS) for the dashboard.
|
||||
# ASTRBOT_DASHBOARD_SSL_ENABLE=False
|
||||
|
||||
# SSL Certificate path (required if SSL is enabled).
|
||||
# ASTRBOT_DASHBOARD_SSL_CERT=/path/to/cert.pem
|
||||
|
||||
# SSL Key path (required if SSL is enabled).
|
||||
# ASTRBOT_DASHBOARD_SSL_KEY=/path/to/key.pem
|
||||
|
||||
# SSL CA Certificates path (optional).
|
||||
# ASTRBOT_DASHBOARD_SSL_CA_CERTS=/path/to/ca.pem
|
||||
|
||||
# ------------------------------------------
|
||||
# 国际化配置 / Internationalization Configuration
|
||||
# Network Configuration (网络配置)
|
||||
# ------------------------------------------
|
||||
|
||||
# CLI 界面语言
|
||||
# CLI interface language
|
||||
# 可选值 Values: zh (中文), en (英文)
|
||||
# 默认 Default: zh (跟随系统 locale / follows system locale)
|
||||
# ASTRBOT_CLI_LANG=zh
|
||||
|
||||
# TUI 界面语言
|
||||
# TUI interface language
|
||||
# 可选值 Values: zh (中文), en (英文)
|
||||
# 默认 Default: zh
|
||||
# ASTRBOT_TUI_LANG=zh
|
||||
|
||||
# ------------------------------------------
|
||||
# 网络配置 / Network Configuration
|
||||
# ------------------------------------------
|
||||
|
||||
# API 绑定主机
|
||||
# API bind host
|
||||
# 示例 Example: 0.0.0.0 (所有接口 all interfaces), 127.0.0.1 (仅本地 localhost only)
|
||||
ASTRBOT_HOST="${ASTRBOT_HOST}"
|
||||
|
||||
# API 绑定端口
|
||||
# API bind port
|
||||
# 示例 Example: 3000, 6185, 8080
|
||||
ASTRBOT_PORT="${ASTRBOT_PORT}"
|
||||
|
||||
# 是否为 API 启用 SSL/TLS
|
||||
# Enable SSL/TLS for API
|
||||
# 可选值 Values: true, false
|
||||
# 默认 Default: false
|
||||
ASTRBOT_SSL_ENABLE=false
|
||||
|
||||
# SSL 证书路径(PEM 格式)
|
||||
# SSL certificate path (PEM format)
|
||||
# 示例 Example: /etc/astrbot/certs/myinstance/fullchain.pem
|
||||
ASTRBOT_SSL_CERT=""
|
||||
|
||||
# SSL 私钥路径(PEM 格式)
|
||||
# SSL private key path (PEM format)
|
||||
# 示例 Example: /etc/astrbot/certs/myinstance/privkey.pem
|
||||
ASTRBOT_SSL_KEY=""
|
||||
|
||||
# SSL CA 证书链路径(可选,用于客户端验证)
|
||||
# SSL CA certificates bundle (optional, for client verification)
|
||||
# 示例 Example: /etc/ssl/certs/ca-certificates.crt
|
||||
ASTRBOT_SSL_CA_CERTS=""
|
||||
|
||||
# ------------------------------------------
|
||||
# 代理配置 / Proxy Configuration
|
||||
# ------------------------------------------
|
||||
|
||||
# HTTP 代理地址
|
||||
# HTTP proxy URL
|
||||
# 示例 Example: http://127.0.0.1:7890, socks5://127.0.0.1:1080
|
||||
# HTTP/HTTPS Proxy URL (e.g., http://127.0.0.1:7890).
|
||||
# http_proxy=
|
||||
|
||||
# HTTPS 代理地址
|
||||
# HTTPS proxy URL
|
||||
# 示例 Example: http://127.0.0.1:7890, socks5://127.0.0.1:1080
|
||||
# https_proxy=
|
||||
|
||||
# 不走代理的主机列表(逗号分隔)
|
||||
# Hosts to bypass proxy (comma-separated)
|
||||
# 示例 Example: localhost,127.0.0.1,192.168.0.0/16,.local
|
||||
# No proxy list (comma-separated domains/IPs to bypass proxy).
|
||||
# no_proxy=localhost,127.0.0.1
|
||||
|
||||
# ------------------------------------------
|
||||
# 第三方集成 / Third-party Integrations
|
||||
# Integrations (第三方集成)
|
||||
# ------------------------------------------
|
||||
|
||||
# 阿里云 DashScope API 密钥(用于 Rerank 服务)
|
||||
# Alibaba DashScope API Key (for Rerank service)
|
||||
# 获取地址 Get from: https://dashscope.console.aliyun.com/
|
||||
# 示例 Example: sk-xxxxxxxxxxxx
|
||||
# DASHSCOPE_API_KEY=
|
||||
# Alibaba DashScope API Key (used for Rerank service).
|
||||
# DASHSCOPE_API_KEY=sk-xxxxxxxxxxxx
|
||||
|
||||
# Coze 集成
|
||||
# Coze integration
|
||||
# 获取地址 Get from: https://www.coze.com/
|
||||
# Coze Integration
|
||||
# COZE_API_KEY=
|
||||
# COZE_BOT_ID=
|
||||
|
||||
# 计算机控制相关的数据目录(用于截图/文件存储)
|
||||
# Computer control data directory (for screenshots/file storage)
|
||||
# 示例 Example: /var/lib/astrbot/bay_data
|
||||
# Computer Use data directory (for screenshot/file storage related to computer control).
|
||||
# BAY_DATA_DIR=
|
||||
|
||||
# ------------------------------------------
|
||||
# 平台特定配置 / Platform-specific Configuration
|
||||
# Platform Specific (平台特定配置)
|
||||
# ------------------------------------------
|
||||
|
||||
# QQ 官方机器人测试模式开关
|
||||
# QQ official bot test mode
|
||||
# 可选值 Values: on, off
|
||||
# 默认 Default: off
|
||||
# Test mode for QQ Official Bot.
|
||||
# TEST_MODE=off
|
||||
|
||||
# End of template / 模板结束
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -51,7 +51,7 @@ jobs:
|
||||
echo "tag=$tag" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v5.0.0
|
||||
uses: pnpm/action-setup@v4.4.0
|
||||
with:
|
||||
version: 10.28.2
|
||||
|
||||
|
||||
37
.github/workflows/unit_tests.yml
vendored
37
.github/workflows/unit_tests.yml
vendored
@@ -1,37 +0,0 @@
|
||||
name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'README*.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
unit-tests:
|
||||
name: Run pytest suite
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install uv
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
chmod +x scripts/run_pytests_ci.sh
|
||||
bash ./scripts/run_pytests_ci.sh ./tests
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -59,22 +59,7 @@ CharacterModels/
|
||||
GenieData/
|
||||
.agent/
|
||||
.codex/
|
||||
.claude/
|
||||
.opencode/
|
||||
.kilocode/
|
||||
.serena
|
||||
.worktrees/
|
||||
|
||||
.astrbot_sdk_testing/
|
||||
.env
|
||||
dashboard/warker.js
|
||||
dashboard/bun.lock
|
||||
.pua/
|
||||
|
||||
# Rust build artifacts
|
||||
rust/target/
|
||||
|
||||
# Build outputs
|
||||
dist/
|
||||
*.whl
|
||||
*.so
|
||||
|
||||
@@ -6,20 +6,20 @@ ci:
|
||||
autoupdate_schedule: weekly
|
||||
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.15.7
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
types_or: [python, pyi]
|
||||
args: [--fix]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
types_or: [python, pyi]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.21.2
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py312-plus]
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.14.1
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
types_or: [ python, pyi ]
|
||||
args: [ --fix ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
types_or: [ python, pyi ]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
|
||||
20
AGENTS.md
20
AGENTS.md
@@ -29,21 +29,7 @@ Runs on `http://localhost:3000` by default.
|
||||
4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`.
|
||||
5. Use English for all new comments.
|
||||
6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory.
|
||||
7. Use Python 3.12+ type hinting syntax (e.g., `list[str]` over `List[str]`, `int | None` over `Optional[int]`). Avoid using `Any` and `cast()` - use proper TypedDict, dataclass, or Protocol instead. When encountering dict access issues (e.g., `msg.get("key")` where ty infers wrong type), define a TypedDict with `total=False` to explicitly declare allowed keys.
|
||||
|
||||
Good example:
|
||||
```python
|
||||
class MessageComponent(TypedDict, total=False):
|
||||
type: str
|
||||
text: str
|
||||
path: str
|
||||
```
|
||||
|
||||
Bad example (avoid):
|
||||
```python
|
||||
msg: Any = something
|
||||
msg = cast(dict, msg)
|
||||
```
|
||||
7. Use Python 3.12+ type hinting syntax (e.g., `list[str]` over `List[str]`, `int | None` over `Optional[int]`). Avoid using `Any` and ensure comprehensive type annotations are provided.
|
||||
8. When introducing new environment variables:
|
||||
- Use the `ASTRBOT_` prefix for naming (e.g., `ASTRBOT_ENABLE_FEATURE`).
|
||||
- Add the variable and description to `.env.example`.
|
||||
@@ -51,9 +37,9 @@ Runs on `http://localhost:3000` by default.
|
||||
- Add to the module docstring under "Environment Variables Used in Project".
|
||||
- Add to the `keys_to_print` list in the `run` function for debug output.
|
||||
9. To check all available CLI commands and their usage recursively, run `astrbot help --all`.
|
||||
10. uv sync --group dev && uv run pytest --cov=astrbot tests/
|
||||
|
||||
|
||||
## PR instructions
|
||||
|
||||
1. Title format: use conventional commit messages
|
||||
2. Use English to write PR title and descriptions.
|
||||
2. Use English to write PR title and descriptions.
|
||||
|
||||
180
CLAUDE.md
180
CLAUDE.md
@@ -1,180 +0,0 @@
|
||||
# AstrBot - Claude Code Guidelines
|
||||
|
||||
AstrBot is an open-source, all-in-one Agentic personal and group chat assistant supporting multiple IM platforms (QQ, Telegram, Discord, etc.) and LLM providers.
|
||||
|
||||
## Project Overview
|
||||
|
||||
- **Main entry**: `astrbot/__main__.py` or via CLI `astrbot run`
|
||||
- **CLI commands**: `astrbot/cli/commands/`
|
||||
- **Core modules**: `astrbot/core/`
|
||||
- **Platform adapters**: `astrbot/core/platform/sources/`
|
||||
- **Star plugins**: `astrbot/builtin_stars/`
|
||||
- **Dashboard**: `dashboard/` (Vue.js frontend)
|
||||
|
||||
## Development Setup
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
uv tool install -e . --force
|
||||
|
||||
# Initialize AstrBot
|
||||
astrbot init
|
||||
|
||||
# Run development
|
||||
astrbot run
|
||||
|
||||
# Backend only (no WebUI)
|
||||
astrbot run --backend-only
|
||||
|
||||
# Dashboard frontend
|
||||
cd dashboard && bun dev
|
||||
|
||||
# Run tests
|
||||
uv sync --group dev && uv run pytest --cov=astrbot tests/
|
||||
```
|
||||
|
||||
## Code Style
|
||||
|
||||
### Python
|
||||
|
||||
1. **Type hints required** - Use Python 3.12+ syntax:
|
||||
- `list[str]` not `List[str]`
|
||||
- `int | None` not `Optional[int]`
|
||||
- Avoid `Any` when possible
|
||||
|
||||
2. **Path handling** - Always use `pathlib.Path`:
|
||||
```python
|
||||
from pathlib import Path
|
||||
# Use astrbot.core.utils.path_utils for data/temp directories
|
||||
from astrbot.core.utils.path_utils import get_astrbot_data_path
|
||||
```
|
||||
|
||||
3. **Formatting** - Run before committing:
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
4. **Comments** - Use English for all comments and docstrings
|
||||
|
||||
5. **Imports** - Use absolute imports via `astrbot.` prefix
|
||||
|
||||
### Environment Variables
|
||||
|
||||
When adding new environment variables:
|
||||
|
||||
1. Use `ASTRBOT_` prefix: `ASTRBOT_ENABLE_FEATURE`
|
||||
2. Add to `.env.example` with description
|
||||
3. Update `astrbot/cli/commands/cmd_run.py`:
|
||||
- Add to module docstring under "Environment Variables Used in Project"
|
||||
- Add to `keys_to_print` list for debug output
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
- `astrbot/core/` - Core bot functionality
|
||||
- `astrbot/core/platform/` - Platform adapter system
|
||||
- `astrbot/core/agent/` - Agent execution logic
|
||||
- `astrbot/core/star/` - Plugin/Star handler system
|
||||
- `astrbot/core/pipeline/` - Message processing pipeline
|
||||
- `astrbot/cli/` - Command-line interface
|
||||
|
||||
### Important Utilities
|
||||
|
||||
```python
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_root, # AstrBot root directory
|
||||
get_astrbot_data_path, # Data directory
|
||||
get_astrbot_config_path, # Config directory
|
||||
get_astrbot_plugin_path, # Plugin directory
|
||||
get_astrbot_temp_path, # Temp directory
|
||||
get_astrbot_skills_path, # Skills directory
|
||||
)
|
||||
```
|
||||
|
||||
### Platform Adapters
|
||||
|
||||
Platform adapters are in `astrbot/core/platform/sources/`:
|
||||
- Each adapter extends base platform classes
|
||||
- Use `@register_platform_adapter` decorator
|
||||
- Events flow through `commit_event()` to message queue
|
||||
|
||||
### Star (Plugin) System
|
||||
|
||||
Stars are plugins in `astrbot/builtin_stars/`:
|
||||
- Extend `Star` base class
|
||||
- Use decorators for command handlers: `@star.on_command`, `@star.on_message`, etc.
|
||||
- Access via `context` object
|
||||
|
||||
## Testing
|
||||
|
||||
1. Tests go in `tests/` directory
|
||||
2. Use `pytest` with `pytest-asyncio`
|
||||
3. Coverage target: `uv run pytest --cov=astrbot tests/`
|
||||
4. Test files: `test_*.py` or `*_test.py`
|
||||
|
||||
## Git Conventions
|
||||
|
||||
### Commit Messages
|
||||
|
||||
Use conventional commits:
|
||||
```
|
||||
feat: add new feature
|
||||
fix: resolve bug
|
||||
docs: update documentation
|
||||
refactor: restructure code
|
||||
test: add tests
|
||||
chore: maintenance tasks
|
||||
```
|
||||
|
||||
### PR Guidelines
|
||||
|
||||
1. Title: conventional commit format
|
||||
2. Description: English
|
||||
3. Target branch: `dev`
|
||||
4. Keep changes focused and atomic
|
||||
|
||||
## Project-Specific Guidelines
|
||||
|
||||
1. **No report files** - Do not add `xxx_SUMMARY.md` or similar
|
||||
2. **Componentization** - Maintain clean code, avoid duplication in WebUI
|
||||
3. **Backward compatibility** - When deprecating, add warnings
|
||||
4. **CLI help** - Run `astrbot help --all` to see all commands
|
||||
|
||||
## File Organization
|
||||
|
||||
```
|
||||
astrbot/
|
||||
├── __main__.py # Main entry point
|
||||
├── __init__.py # Package init, exports
|
||||
├── cli/ # CLI commands
|
||||
│ └── commands/ # Individual command modules
|
||||
├── core/ # Core functionality
|
||||
│ ├── agent/ # Agent execution
|
||||
│ ├── platform/ # Platform adapters
|
||||
│ ├── pipeline/ # Message processing
|
||||
│ ├── star/ # Plugin system
|
||||
│ └── config/ # Configuration
|
||||
├── builtin_stars/ # Built-in plugins
|
||||
├── dashboard/ # Vue.js frontend
|
||||
└── utils/ # Utilities
|
||||
```
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Adding a new platform adapter
|
||||
1. Create adapter in `astrbot/core/platform/sources/`
|
||||
2. Extend `Platform` base class
|
||||
3. Use `@register_platform_adapter` decorator
|
||||
4. Implement required methods: `run()`, `convert_message()`, `meta()`
|
||||
|
||||
### Adding a new command
|
||||
1. Add to appropriate module in `cli/commands/`
|
||||
2. Register with `@click.command()`
|
||||
3. Update `astrbot/cli/__main__.py` to add command
|
||||
|
||||
### Adding a new Star handler
|
||||
1. Create in `astrbot/builtin_stars/` or as plugin
|
||||
2. Extend `Star` class
|
||||
3. Use decorators: `@star.on_command()`, `@star.on_schedule()`, etc.
|
||||
14
README.md
14
README.md
@@ -31,8 +31,9 @@
|
||||
<a href="https://astrbot.app/">Docs</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a> |
|
||||
<a href="mailto:community@astrbot.app">Email Support</a>
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issues</a>
|
||||
<a href="mailto:community@astrbot.app">Email</a>
|
||||
|
||||
</div>
|
||||
|
||||
AstrBot is an open-source, all-in-one Agentic personal and group chat assistant that can be deployed on dozens of mainstream instant messaging platforms such as QQ, Telegram, WeCom, Lark, DingTalk, Slack, and more. It also features a built-in lightweight ChatUI similar to OpenWebUI, creating a reliable and scalable conversational AI infrastructure for individuals, developers, and teams. Whether it's a personal AI companion, smart customer service, automated assistant, or enterprise knowledge base, AstrBot enables you to quickly build AI applications within the workflow of your instant messaging platforms.
|
||||
@@ -198,13 +199,6 @@ Connect AstrBot to your favorite chat platforms.
|
||||
| Minimax TTS | Text-to-Speech |
|
||||
| Volcano Engine TTS | Text-to-Speech |
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
<p align="center">
|
||||
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
|
||||
</p>
|
||||
|
||||
|
||||
## ❤️ Contribution
|
||||
|
||||
Welcome any Issues/Pull Requests! Just submit your changes to this project :)
|
||||
@@ -304,4 +298,4 @@ _私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
@@ -199,13 +199,6 @@ Connectez AstrBot à vos plateformes de chat préférées.
|
||||
| Minimax TTS | Synthèse vocale (Text-to-Speech) |
|
||||
| Volcengine TTS | Synthèse vocale (Text-to-Speech) |
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
<p align="center">
|
||||
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
|
||||
</p>
|
||||
|
||||
|
||||
## ❤️ Contribution
|
||||
|
||||
Les Issues et Pull Requests sont les bienvenus ! Soumettez simplement vos modifications à ce projet :)
|
||||
@@ -301,4 +294,4 @@ _私は、高性能ですから!_ (Je suis performant !)
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
@@ -199,13 +199,6 @@ AstrBotを普段使用しているチャットプラットフォームに接続
|
||||
| Minimax TTS | 音声合成 (TTS) |
|
||||
| Volcengine TTS (火山エンジン) | 音声合成 (TTS) |
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
<p align="center">
|
||||
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
|
||||
</p>
|
||||
|
||||
|
||||
## ❤️ 貢献
|
||||
|
||||
IssueやPull Requestは大歓迎です!変更をこのプロジェクトに送信してください :)
|
||||
@@ -301,4 +294,4 @@ _私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
@@ -199,13 +199,6 @@ yay -S astrbot-git
|
||||
| Minimax TTS | Синтез речи (TTS) |
|
||||
| Volcengine TTS | Синтез речи (TTS) |
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
<p align="center">
|
||||
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
|
||||
</p>
|
||||
|
||||
|
||||
## ❤️ Вклад в проект
|
||||
|
||||
Мы приветствуем любые Issues и Pull Requests! Просто отправьте свои изменения в этот проект :)
|
||||
@@ -301,4 +294,4 @@ _私は、高性能ですから!_ (Я высокопроизводительны
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
@@ -199,16 +199,9 @@ yay -S astrbot-git
|
||||
| Minimax TTS | 文本轉語音 |
|
||||
| 火山引擎 TTS | 文本轉語音 |
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
<p align="center">
|
||||
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
|
||||
</p>
|
||||
|
||||
|
||||
## ❤️ 貢獻
|
||||
|
||||
歡迎任何 Issues/Pull Requests!只需要將你的更改提交到此項目 :)
|
||||
歡迎任何 Issues/Pull Requests!只需要將你的更改提交到此項目 :)
|
||||
|
||||
### 如何貢獻
|
||||
|
||||
@@ -301,4 +294,4 @@ _私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
@@ -199,16 +199,9 @@ yay -S astrbot-git
|
||||
| Minimax TTS | 文本转语音 |
|
||||
| 火山引擎 TTS | 文本转语音 |
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
<p align="center">
|
||||
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
|
||||
</p>
|
||||
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
|
||||
### 如何贡献
|
||||
|
||||
|
||||
@@ -1,16 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from .core.log import LogManager
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import logger as logger
|
||||
|
||||
__all__ = ["logger"]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "logger":
|
||||
from .core import logger
|
||||
|
||||
return logger
|
||||
raise AttributeError(name)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
|
||||
import anyio
|
||||
|
||||
import runtime_bootstrap
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
@@ -24,9 +25,8 @@ from astrbot.core.utils.io import (
|
||||
download_dashboard,
|
||||
get_dashboard_version,
|
||||
)
|
||||
from astrbot.runtime_bootstrap import initialize_runtime_bootstrap
|
||||
|
||||
initialize_runtime_bootstrap()
|
||||
runtime_bootstrap.initialize_runtime_bootstrap()
|
||||
|
||||
|
||||
# 将父目录添加到 sys.path
|
||||
@@ -44,9 +44,9 @@ logo_tmpl = r"""
|
||||
|
||||
|
||||
def check_env() -> None:
|
||||
# Python version check: require 3.12 or 3.13
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor in (12, 13)):
|
||||
sys.exit(1)
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 10):
|
||||
logger.error("请使用 Python3.10+ 运行本项目。")
|
||||
exit()
|
||||
|
||||
astrbot_root = get_astrbot_root()
|
||||
if astrbot_root not in sys.path:
|
||||
@@ -76,7 +76,7 @@ async def check_dashboard_files(webui_dir: str | None = None):
|
||||
if await anyio.Path(webui_dir).exists():
|
||||
logger.info(f"使用指定的 WebUI 目录: {webui_dir}")
|
||||
return webui_dir
|
||||
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。")
|
||||
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。")
|
||||
|
||||
data_dist_path = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if await anyio.Path(data_dist_path).exists():
|
||||
@@ -84,41 +84,41 @@ async def check_dashboard_files(webui_dir: str | None = None):
|
||||
if v is not None:
|
||||
# 存在文件
|
||||
if v == f"v{VERSION}":
|
||||
logger.info("WebUI 版本已是最新。")
|
||||
logger.info("WebUI 版本已是最新。")
|
||||
else:
|
||||
logger.warning(
|
||||
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。",
|
||||
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。",
|
||||
)
|
||||
return data_dist_path
|
||||
|
||||
logger.info(
|
||||
"开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。",
|
||||
"开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。",
|
||||
)
|
||||
|
||||
try:
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"下载指定版本(v{VERSION})的管理面板文件失败: {e},尝试下载最新版本。"
|
||||
f"下载指定版本(v{VERSION})的管理面板文件失败: {e},尝试下载最新版本。"
|
||||
)
|
||||
try:
|
||||
await download_dashboard(latest=True)
|
||||
except Exception as e:
|
||||
logger.critical(f"下载管理面板文件失败: {e}。")
|
||||
logger.critical(f"下载管理面板文件失败: {e}。")
|
||||
return None
|
||||
|
||||
logger.info("管理面板下载完成。")
|
||||
logger.info("管理面板下载完成。")
|
||||
return data_dist_path
|
||||
|
||||
|
||||
async def main_async(webui_dir_arg: str | None, log_broker: LogBroker) -> None:
|
||||
async def main_async(webui_dir_arg: str | None) -> None:
|
||||
"""主异步入口"""
|
||||
# 检查仪表板文件
|
||||
webui_dir = await check_dashboard_files(webui_dir_arg)
|
||||
if webui_dir is None:
|
||||
logger.warning(
|
||||
"管理面板文件检查失败,WebUI 功能将不可用。"
|
||||
"请检查网络连接或手动指定 --webui-dir 参数。"
|
||||
"管理面板文件检查失败,WebUI 功能将不可用。"
|
||||
"请检查网络连接或手动指定 --webui-dir 参数。"
|
||||
)
|
||||
|
||||
db = db_helper
|
||||
@@ -148,4 +148,4 @@ if __name__ == "__main__":
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
|
||||
# 只使用一次 asyncio.run()
|
||||
asyncio.run(main_async(args.webui_dir, log_broker))
|
||||
asyncio.run(main_async(args.webui_dir))
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
Astbot内部实现
|
||||
外部模块请勿导入
|
||||
|
||||
"""
|
||||
@@ -1,57 +0,0 @@
|
||||
"""
|
||||
ABP (AstrBot Protocol) client - in-process star communication.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAstrbotAbpClient(ABC):
|
||||
"""
|
||||
ABP client: in-process star (plugin) communication.
|
||||
|
||||
Stars register themselves; client delegates calls to registered instances.
|
||||
|
||||
Subclass must implement:
|
||||
- connect() -> None
|
||||
- register_star(name, instance) -> None
|
||||
- unregister_star(name) -> None
|
||||
- call_star_tool(star, tool, args) -> Any
|
||||
- shutdown() -> None
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connected(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Lightweight: just sets connected=True."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def register_star(self, star_name: str, star_instance: Any) -> None:
|
||||
"""Add star to internal registry."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def unregister_star(self, star_name: str) -> None:
|
||||
"""Remove star from registry (idempotent)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_star_tool(
|
||||
self,
|
||||
star_name: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Delegate to star_instance.call_tool(tool_name, arguments)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Set connected=False, cancel pending requests."""
|
||||
...
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
ACP (AstrBot Communication Protocol) client.
|
||||
|
||||
Transport: TCP | Unix Socket
|
||||
Messages: JSON with Content-Length header
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAstrbotAcpClient(ABC):
|
||||
"""
|
||||
ACP client: connects to ACP servers via TCP or Unix socket.
|
||||
|
||||
Subclass must implement:
|
||||
- connect() -> None
|
||||
- connect_to_server(host, port) -> None
|
||||
- connect_to_unix_socket(path) -> None
|
||||
- call_tool(server, tool, args) -> Any
|
||||
- send_notification(method, params) -> None
|
||||
- shutdown() -> None
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connected(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def connect_to_server(self, host: str, port: int) -> None:
|
||||
"""Connect via TCP."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def connect_to_unix_socket(self, socket_path: str) -> None:
|
||||
"""Connect via Unix domain socket."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Call tool on server, return result."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def send_notification(
|
||||
self,
|
||||
method: str,
|
||||
params: dict[str, Any],
|
||||
) -> None:
|
||||
"""Send one-way notification."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Close connection, cancel pending requests."""
|
||||
...
|
||||
@@ -1,68 +0,0 @@
|
||||
"""
|
||||
ACP (AstrBot Communication Protocol) server.
|
||||
|
||||
Transport: TCP listening socket
|
||||
Messages: JSON with Content-Length header
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAstrbotAcpServer(ABC):
|
||||
"""
|
||||
ACP server: listens for client connections, exposes tools.
|
||||
|
||||
Subclass must implement:
|
||||
- start(host, port) -> None
|
||||
- register_tool(name, handler) -> None
|
||||
- register_notification_handler(name, handler) -> None
|
||||
- broadcast_notification(method, params) -> None
|
||||
- shutdown() -> None
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def running(self) -> bool:
|
||||
"""True if server is accepting connections."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def start(self, host: str = "127.0.0.1", port: int = 8765) -> None:
|
||||
"""Bind and listen. Block until shutdown."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def register_tool(
|
||||
self,
|
||||
name: str,
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
"""Register async tool handler (receives params dict, returns result)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def register_notification_handler(
|
||||
self,
|
||||
name: str,
|
||||
handler: Callable[..., Any],
|
||||
) -> None:
|
||||
"""Register async notification handler (receives params dict)."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def broadcast_notification(
|
||||
self,
|
||||
method: str,
|
||||
params: dict[str, Any],
|
||||
) -> None:
|
||||
"""Send notification to all connected clients."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Stop accepting, close all client connections."""
|
||||
...
|
||||
@@ -1,73 +0,0 @@
|
||||
"""
|
||||
AstrBot Gateway - HTTP/WebSocket API server.
|
||||
|
||||
Built on FastAPI, provides:
|
||||
- HTTP REST API (stats, inspector, config)
|
||||
- WebSocket for real-time events
|
||||
- Static file serving (dashboard)
|
||||
- Authentication (JWT/API key)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseAstrbotGateway(ABC):
|
||||
"""
|
||||
Gateway: HTTP/WebSocket server built on FastAPI.
|
||||
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ FastAPI App │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ REST Endpoints WebSocket │
|
||||
│ ├─ GET /api/stats ├─ /ws (connection manager)│
|
||||
│ ├─ GET /api/inspector/* │ │
|
||||
│ ├─ GET /api/memory/* │ │
|
||||
│ └─ ... │ │
|
||||
│ │
|
||||
│ Middleware: CORS, Auth, Logging │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────┐
|
||||
│ Orchestrator │
|
||||
│ (owns protocol clients)│
|
||||
└─────────────────────────┘
|
||||
|
||||
Routes (typical):
|
||||
GET / → Dashboard static files
|
||||
GET /api/stats → System statistics
|
||||
GET /api/inspector/stars → List registered stars
|
||||
WS /ws → WebSocket for real-time events
|
||||
|
||||
serve() Lifecycle:
|
||||
1. Create FastAPI app
|
||||
2. Register routes
|
||||
3. Start WebSocket manager
|
||||
4. Bind to host:port
|
||||
5. Run ASGI server (uvicorn/hypercorn)
|
||||
6. Block until shutdown
|
||||
7. Close all connections
|
||||
|
||||
Subclass must implement:
|
||||
- serve(): start server, block until shutdown
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def serve(self) -> None:
|
||||
"""
|
||||
Start gateway server - blocks until shutdown.
|
||||
|
||||
Should:
|
||||
1. Create FastAPI app with routes
|
||||
2. Configure CORS, auth middleware
|
||||
3. Start WebSocket connection manager
|
||||
4. Bind to ASTRBOT_PORT (default 6185)
|
||||
5. Run ASGI server
|
||||
6. Handle graceful shutdown on SIGTERM/SIGINT
|
||||
|
||||
Raises:
|
||||
OSError: address already in use
|
||||
"""
|
||||
...
|
||||
@@ -1,352 +0,0 @@
|
||||
"""
|
||||
AstrBot Orchestrator - core runtime lifecycle manager.
|
||||
|
||||
Architecture
|
||||
============
|
||||
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ Orchestrator │
|
||||
│ (owns lifecycle of all protocol clients + stars) │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌──────────────┼──────────────┐
|
||||
▼ ▼ ▼
|
||||
┌─────────┐ ┌─────────┐ ┌─────────┐
|
||||
│ LSP │ │ MCP │ │ ACP │
|
||||
│ Client │ │ Client │ │ Client │
|
||||
└─────────┘ └─────────┘ └─────────┘
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
LSP Servers MCP Servers ACP Services
|
||||
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ ABP Client │
|
||||
│ (in-process star registry) │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────┐
|
||||
│ Stars │
|
||||
│(Plugins) │
|
||||
└─────────┘
|
||||
|
||||
|
||||
Lifecycle State Machine
|
||||
=======================
|
||||
|
||||
States:
|
||||
┌─────────┐
|
||||
│ INIT │───► orchestrator created, clients not initialized
|
||||
└────┬────┘
|
||||
│ start()
|
||||
▼
|
||||
┌─────────┐
|
||||
│ RUNNING │◄─── run_loop() executing
|
||||
└────┬────┘
|
||||
│ shutdown()
|
||||
▼
|
||||
┌──────────┐
|
||||
│ SHUTDOWN │─── all clients closed, ready for GC
|
||||
└──────────┘
|
||||
|
||||
Transitions:
|
||||
INIT + start() ──► RUNNING
|
||||
RUNNING + shutdown() ──► SHUTDOWN
|
||||
|
||||
For each protocol client, the orchestrator:
|
||||
1. Creates instance in __init__
|
||||
2. Calls connect() to initialize
|
||||
3. Calls protocol-specific setup (connect_to_server, etc)
|
||||
4. Manages via run_loop() heartbeat
|
||||
5. Calls shutdown() on final cleanup
|
||||
|
||||
|
||||
Star Registration Flow
|
||||
=====================
|
||||
|
||||
orchestrator.register_star("my-star", MyStar())
|
||||
│
|
||||
▼
|
||||
┌───────────────────┐
|
||||
│ ABP Client │
|
||||
│ .register_star() │
|
||||
└───────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────┐
|
||||
│ Internal dict │
|
||||
│ {"my-star": obj} │
|
||||
└───────────────────┘
|
||||
|
||||
|
||||
Message Routing (conceptual)
|
||||
===========================
|
||||
|
||||
External Tool Call
|
||||
│
|
||||
▼
|
||||
┌──────────────┐ list_tools() ┌──────────────┐
|
||||
│ MCP Client │────────────────────►│ MCP Server │
|
||||
└──────────────┘◄────────────────────└──────────────┘
|
||||
│ tool result
|
||||
▼
|
||||
┌──────────────┐ call_tool() ┌──────────────┐
|
||||
│ ABP │────────────────────►│ Star │
|
||||
│ Client │◄────────────────────└──────────────┘
|
||||
└──────────────┘ tool result
|
||||
│
|
||||
▼
|
||||
Return to caller
|
||||
|
||||
|
||||
run_loop() Responsibilities
|
||||
===========================
|
||||
|
||||
while running:
|
||||
│─ check LSP server health (ping/heartbeat)
|
||||
│─ check MCP session status (reconnect if needed)
|
||||
│─ check ACP client connections
|
||||
│─ process any pending star notifications
|
||||
│─ sleep(SLEEP_INTERVAL)
|
||||
|
||||
|
||||
Shutdown Sequence
|
||||
==================
|
||||
|
||||
shutdown()
|
||||
│
|
||||
├─ set _running = False
|
||||
│
|
||||
├─ LSP.shutdown()
|
||||
│ └─ send "shutdown" request
|
||||
│ └─ terminate subprocess
|
||||
│
|
||||
├─ ACP.shutdown()
|
||||
│ └─ close TCP/Unix connections
|
||||
│
|
||||
├─ ABP.shutdown()
|
||||
│ └─ cancel pending requests
|
||||
│
|
||||
└─ MCP.cleanup()
|
||||
└─ close all sessions
|
||||
└─ cleanup subprocesses
|
||||
|
||||
|
||||
Exception Handling
|
||||
==================
|
||||
|
||||
Each protocol client should:
|
||||
- Catch connection errors
|
||||
- Attempt reconnection with exponential backoff
|
||||
- Log errors but don't crash run_loop
|
||||
- Raise on irrecoverable failures
|
||||
|
||||
The orchestrator run_loop should:
|
||||
- Catch CancelledError on shutdown
|
||||
- Catch Exception and log (don't crash)
|
||||
- Ensure cleanup runs in finally block
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
|
||||
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
|
||||
from astrbot._internal.protocols.lsp.client import AstrbotLspClient
|
||||
from astrbot._internal.protocols.mcp.client import McpClient
|
||||
|
||||
|
||||
#: Default heartbeat interval for run_loop()
|
||||
DEFAULT_SLEEP_INTERVAL: float = 5.0
|
||||
|
||||
|
||||
class BaseAstrbotOrchestrator(ABC):
|
||||
"""
|
||||
Core runtime: owns lifecycle of all protocol clients and stars.
|
||||
|
||||
┌────────────────────────────────────────────────────────────┐
|
||||
│ Protocol Clients (always present, never None after init) │
|
||||
├────────────────────────────────────────────────────────────┤
|
||||
│ lsp: Language Server Protocol │
|
||||
│ Purpose: code completion, diagnostics, hover, etc │
|
||||
│ Transport: stdio subprocess │
|
||||
│ │
|
||||
│ mcp: Model Context Protocol │
|
||||
│ Purpose: external tool access │
|
||||
│ Transport: stdio | SSE | HTTP │
|
||||
│ │
|
||||
│ acp: AstrBot Communication Protocol │
|
||||
│ Purpose: inter-service communication │
|
||||
│ Transport: TCP | Unix Socket │
|
||||
│ │
|
||||
│ abp: AstrBot Protocol │
|
||||
│ Purpose: in-process star (plugin) communication │
|
||||
│ Transport: direct method calls │
|
||||
└────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌────────────────────────────────────────────────────────────┐
|
||||
│ Star Registry │
|
||||
├────────────────────────────────────────────────────────────┤
|
||||
│ _stars: dict[str, Any] │
|
||||
│ Stars are plugins registered by name │
|
||||
│ ABP client delegates calls to registered stars │
|
||||
└────────────────────────────────────────────────────────────┘
|
||||
|
||||
Subclass must implement:
|
||||
- __init__(): create all protocol client instances
|
||||
- run_loop(): main event loop (block until shutdown)
|
||||
- register_star(name, instance): add to registry + ABP
|
||||
- unregister_star(name): remove from registry + ABP
|
||||
- shutdown(): clean up all clients
|
||||
"""
|
||||
|
||||
#: LSP client for language intelligence
|
||||
lsp: AstrbotLspClient
|
||||
|
||||
#: MCP client for external tools
|
||||
mcp: McpClient
|
||||
|
||||
#: ACP client for inter-service communication
|
||||
acp: AstrbotAcpClient
|
||||
|
||||
#: ABP client for in-process star communication
|
||||
abp: AstrbotAbpClient
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize orchestrator and all protocol clients.
|
||||
|
||||
After __init__, all clients exist but are not connected.
|
||||
Call start() or run_loop() to begin operation.
|
||||
|
||||
Example:
|
||||
class MyOrchestrator(BaseAstrbotOrchestrator):
|
||||
def __init__(self):
|
||||
self.lsp = AstrbotLspClient()
|
||||
self.mcp = McpClient()
|
||||
self.acp = AstrbotAcpClient()
|
||||
self.abp = AstrbotAbpClient()
|
||||
self._stars: dict[str, Any] = {}
|
||||
self._running = False
|
||||
"""
|
||||
self._stars: dict[str, Any] = {}
|
||||
self._running: bool = False
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""True if run_loop() is executing."""
|
||||
return self._running
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Initialize all protocol clients.
|
||||
|
||||
Called once before run_loop(). Should:
|
||||
1. Call lsp.connect()
|
||||
2. Call mcp.connect()
|
||||
3. Call acp.connect()
|
||||
4. Call abp.connect()
|
||||
5. Set _running = True
|
||||
|
||||
Raises:
|
||||
Exception: if any client fails to initialize
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def run_loop(self) -> None:
|
||||
"""
|
||||
Main event loop - blocks until shutdown.
|
||||
|
||||
Execution:
|
||||
self._running = True
|
||||
try:
|
||||
while self._running:
|
||||
await self._heartbeat()
|
||||
await anyio.sleep(DEFAULT_SLEEP_INTERVAL)
|
||||
except asyncio.CancelledError:
|
||||
pass # shutdown requested
|
||||
finally:
|
||||
self._running = False
|
||||
|
||||
_heartbeat() responsibilities:
|
||||
- Check LSP server health (optional ping)
|
||||
- Check MCP session status, reconnect if needed
|
||||
- Check ACP connections
|
||||
- Process any pending star notifications
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: when shutdown() called
|
||||
|
||||
Note:
|
||||
Subclass defines _heartbeat() for periodic tasks.
|
||||
This method only handles the loop control.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def register_star(self, name: str, star_instance: Any) -> None:
|
||||
"""
|
||||
Register a star (plugin) with the orchestrator.
|
||||
|
||||
Args:
|
||||
name: Unique identifier for the star
|
||||
instance: Star plugin instance (must have .call_tool() method)
|
||||
|
||||
Does:
|
||||
self._stars[name] = star_instance
|
||||
self.abp.register_star(name, star_instance)
|
||||
|
||||
Raises:
|
||||
ValueError: if name already registered
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def unregister_star(self, name: str) -> None:
|
||||
"""
|
||||
Unregister a star (plugin) from the orchestrator.
|
||||
|
||||
Args:
|
||||
name: Identifier of star to remove
|
||||
|
||||
Does:
|
||||
del self._stars[name]
|
||||
self.abp.unregister_star(name)
|
||||
|
||||
Note:
|
||||
Idempotent - does nothing if name not found.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_star(self, name: str) -> Any | None:
|
||||
"""Get registered star by name. Returns None if not found."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_stars(self) -> list[str]:
|
||||
"""Return list of registered star names."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Graceful shutdown of orchestrator and all clients.
|
||||
|
||||
Execution order:
|
||||
1. self._running = False (stop run_loop)
|
||||
2. await lsp.shutdown()
|
||||
3. await acp.shutdown()
|
||||
4. await abp.shutdown()
|
||||
5. await mcp.cleanup()
|
||||
|
||||
Does NOT unregister stars - caller should do that first.
|
||||
|
||||
After shutdown, orchestrator is ready for garbage collection.
|
||||
"""
|
||||
...
|
||||
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
LSP (Language Server Protocol) client.
|
||||
|
||||
Transport: stdio subprocess
|
||||
Messages: JSON-RPC 2.0 with Content-Length header
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class LspMessage:
|
||||
"""JSON-RPC 2.0 message."""
|
||||
|
||||
jsonrpc: str = "2.0"
|
||||
id: int | str | None = None
|
||||
method: str | None = None
|
||||
params: dict[str, Any] | None = None
|
||||
result: Any = None
|
||||
error: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LspRequest(LspMessage):
|
||||
"""Outgoing request."""
|
||||
|
||||
def __init__(self, method: str, params: dict[str, Any] | None = None) -> None:
|
||||
self.id = id(self)
|
||||
self.method = method
|
||||
self.params = params
|
||||
|
||||
|
||||
class LspResponse(LspMessage):
|
||||
"""Incoming response."""
|
||||
|
||||
|
||||
class LspNotification(LspMessage):
|
||||
"""Incoming notification (no id)."""
|
||||
|
||||
|
||||
class BaseAstrbotLspClient(ABC):
|
||||
"""
|
||||
LSP client: connects to LSP servers via stdio subprocess.
|
||||
|
||||
Subclass must implement:
|
||||
- connect() -> None
|
||||
- connect_to_server(command, workspace_uri) -> None
|
||||
- send_request(method, params) -> dict
|
||||
- send_notification(method, params) -> None
|
||||
- shutdown() -> None
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connected(self) -> bool:
|
||||
"""True if connected to an LSP server."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
self._connected = False
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def connect_to_server(
|
||||
self,
|
||||
command: list[str],
|
||||
workspace_uri: str,
|
||||
) -> None:
|
||||
"""
|
||||
Start LSP server subprocess and complete handshake.
|
||||
|
||||
Steps:
|
||||
1. Spawn subprocess with stdin/stdout pipes
|
||||
2. Send initialize request
|
||||
3. Wait for response
|
||||
4. Send initialized notification
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def send_request(
|
||||
self,
|
||||
method: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Send JSON-RPC request and return result.
|
||||
|
||||
Raises:
|
||||
RuntimeError: not connected
|
||||
Exception: server returned error
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def send_notification(
|
||||
self,
|
||||
method: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send JSON-RPC notification (no response expected).
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""Send shutdown, terminate subprocess, cleanup."""
|
||||
...
|
||||
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) client.
|
||||
|
||||
Transport: stdio | SSE | streamable_http
|
||||
Messages: JSON-RPC 2.0
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class McpServerConfig(TypedDict, total=False):
|
||||
"""MCP server configuration."""
|
||||
|
||||
# Stdio transport
|
||||
command: str
|
||||
args: list[str]
|
||||
env: dict[str, str]
|
||||
cwd: str
|
||||
|
||||
# HTTP transport
|
||||
url: str
|
||||
headers: dict[str, str]
|
||||
transport: Literal["sse", "streamable_http"]
|
||||
|
||||
|
||||
class McpToolInfo(TypedDict):
|
||||
"""MCP tool descriptor."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
inputSchema: dict[str, Any]
|
||||
|
||||
|
||||
class BaseAstrbotMcpClient(ABC):
|
||||
"""
|
||||
MCP client: connects to MCP servers for external tools.
|
||||
|
||||
Subclass must implement:
|
||||
- connect() -> None
|
||||
- connect_to_server(config, name) -> None
|
||||
- list_tools() -> list[McpToolInfo]
|
||||
- call_tool(name, args, timeout) -> CallToolResult
|
||||
- cleanup() -> None
|
||||
"""
|
||||
|
||||
session: Any # mcp.ClientSession
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def connected(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Initialize client session."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def connect_to_server(
|
||||
self,
|
||||
config: McpServerConfig,
|
||||
name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Connect to MCP server.
|
||||
|
||||
Stdio: {"command": "python", "args": ["server.py"], "env": {...}}
|
||||
HTTP: {"url": "https://...", "transport": "sse"}
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_tools(self) -> list[McpToolInfo]:
|
||||
"""Call tools/list and return tools."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any],
|
||||
read_timeout_seconds: int = 60,
|
||||
) -> Any:
|
||||
"""Call tools/call with reconnection support."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self) -> None:
|
||||
"""Close all server connections."""
|
||||
...
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Gateway module - FastAPI server for the dashboard backend."""
|
||||
|
||||
from .server import AstrbotGateway
|
||||
from .ws_manager import WebSocketManager
|
||||
|
||||
__all__ = ["AstrbotGateway", "WebSocketManager"]
|
||||
@@ -1,4 +0,0 @@
|
||||
"""
|
||||
依赖注入
|
||||
|
||||
"""
|
||||
@@ -1,248 +0,0 @@
|
||||
"""
|
||||
AstrBot Gateway - FastAPI server for the dashboard backend.
|
||||
|
||||
Provides REST API endpoints and WebSocket connections for the frontend dashboard.
|
||||
The gateway acts as the communication bridge between the dashboard and the orchestrator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.base_astrbot_gateway import BaseAstrbotGateway
|
||||
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
|
||||
from astrbot._internal.geteway.ws_manager import WebSocketManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
else:
|
||||
try:
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
except ImportError:
|
||||
logger.warning("FastAPI not installed, gateway unavailable.")
|
||||
FastAPI = cast(Any, None)
|
||||
WebSocket = cast(Any, None)
|
||||
WebSocketDisconnect = cast(Any, None)
|
||||
CORSMiddleware = cast(Any, None)
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotGateway(BaseAstrbotGateway):
|
||||
"""
|
||||
FastAPI-based gateway server for AstrBot.
|
||||
|
||||
Handles:
|
||||
- REST API endpoints for configuration and stats
|
||||
- WebSocket connections for real-time communication
|
||||
- CORS middleware for dashboard access
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator: BaseAstrbotOrchestrator) -> None:
|
||||
self.orchestrator = orchestrator
|
||||
self.ws_manager = WebSocketManager()
|
||||
self._app: FastAPI | None = None
|
||||
self._host = "0.0.0.0"
|
||||
self._port = 8765
|
||||
|
||||
async def serve(self) -> None:
|
||||
"""
|
||||
Start the gateway server.
|
||||
|
||||
Creates and runs a FastAPI application with WebSocket support.
|
||||
"""
|
||||
if FastAPI is None:
|
||||
raise RuntimeError("FastAPI is not installed")
|
||||
|
||||
log.info(f"Starting AstrBot Gateway on {self._host}:{self._port}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
log.info("Gateway server started.")
|
||||
yield
|
||||
# Shutdown
|
||||
await self.ws_manager.broadcast({"type": "server_shutdown"})
|
||||
log.info("Gateway server stopped.")
|
||||
|
||||
self._app = FastAPI(
|
||||
title="AstrBot Gateway",
|
||||
description="Backend API for AstrBot dashboard",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
self._app.add_middleware(
|
||||
cast(Any, CORSMiddleware),
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
self._setup_routes()
|
||||
|
||||
# Run with uvicorn
|
||||
import uvicorn
|
||||
|
||||
config = uvicorn.Config(
|
||||
self._app,
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
log_level="info",
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
def _setup_routes(self) -> None:
|
||||
"""Set up API routes."""
|
||||
if self._app is None:
|
||||
return
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
# Health check
|
||||
@self._app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
# WebSocket endpoint
|
||||
@self._app.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket):
|
||||
await self.ws_manager.connect(ws)
|
||||
try:
|
||||
while True:
|
||||
data = await ws.receive_text()
|
||||
try:
|
||||
message = json.loads(data)
|
||||
response = await self._handle_ws_message(message)
|
||||
if response:
|
||||
await ws.send_json(response)
|
||||
except json.JSONDecodeError:
|
||||
await ws.send_json({"error": "Invalid JSON"})
|
||||
except WebSocketDisconnect:
|
||||
self.ws_manager.disconnect(ws)
|
||||
|
||||
# Stats router
|
||||
stats_router = APIRouter(prefix="/api/stats", tags=["stats"])
|
||||
|
||||
@stats_router.get("/overview")
|
||||
async def get_overview():
|
||||
return await self._get_stats_overview()
|
||||
|
||||
self._app.include_router(stats_router)
|
||||
|
||||
# Inspector router
|
||||
inspector_router = APIRouter(prefix="/api/inspector", tags=["inspector"])
|
||||
|
||||
@inspector_router.get("/stars")
|
||||
async def list_stars():
|
||||
return await self._list_stars()
|
||||
|
||||
@inspector_router.get("/stars/{star_name}")
|
||||
async def get_star(star_name: str):
|
||||
return await self._get_star_detail(star_name)
|
||||
|
||||
self._app.include_router(inspector_router)
|
||||
|
||||
# Memory router
|
||||
memory_router = APIRouter(prefix="/api/memory", tags=["memory"])
|
||||
|
||||
@memory_router.get("/")
|
||||
async def get_memory():
|
||||
return await self._get_memory_info()
|
||||
|
||||
self._app.include_router(memory_router)
|
||||
|
||||
async def _handle_ws_message(
|
||||
self, message: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Handle an incoming WebSocket message.
|
||||
|
||||
Args:
|
||||
message: Parsed JSON message from the client
|
||||
|
||||
Returns:
|
||||
Response message to send back, or None for no response
|
||||
"""
|
||||
msg_type = message.get("type")
|
||||
data = message.get("data", {})
|
||||
|
||||
if msg_type == "ping":
|
||||
return {"type": "pong", "data": {}}
|
||||
|
||||
if msg_type == "call_tool":
|
||||
return await self._handle_call_tool(data)
|
||||
|
||||
if msg_type == "get_stars":
|
||||
return {"type": "stars_list", "data": await self._list_stars()}
|
||||
|
||||
return {
|
||||
"type": "error",
|
||||
"data": {"message": f"Unknown message type: {msg_type}"},
|
||||
}
|
||||
|
||||
async def _handle_call_tool(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Handle a tool call request via WebSocket."""
|
||||
star_name = data.get("star")
|
||||
tool_name = data.get("tool")
|
||||
arguments = data.get("arguments", {})
|
||||
|
||||
if not star_name or not tool_name:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"data": {"error": "Missing star or tool name"},
|
||||
}
|
||||
|
||||
try:
|
||||
result = await self.orchestrator.abp.call_star_tool(
|
||||
star_name, tool_name, arguments
|
||||
)
|
||||
return {"type": "tool_result", "data": {"result": result}}
|
||||
except Exception as e:
|
||||
return {"type": "tool_result", "data": {"error": str(e)}}
|
||||
|
||||
async def _get_stats_overview(self) -> dict[str, Any]:
|
||||
"""Get overview statistics."""
|
||||
return {
|
||||
"stars_count": len(self.orchestrator.abp._stars),
|
||||
"lsp_connected": self.orchestrator.lsp._connected,
|
||||
"mcp_sessions": getattr(self.orchestrator.mcp, "session", None) is not None,
|
||||
"acp_clients": len(getattr(self.orchestrator.acp, "_clients", [])),
|
||||
}
|
||||
|
||||
async def _list_stars(self) -> list[dict[str, Any]]:
|
||||
"""List all registered stars."""
|
||||
stars = []
|
||||
for name in self.orchestrator.abp._stars:
|
||||
stars.append({"name": name, "status": "active"})
|
||||
return stars
|
||||
|
||||
async def _get_star_detail(self, star_name: str) -> dict[str, Any]:
|
||||
"""Get details of a specific star."""
|
||||
star = self.orchestrator.abp._stars.get(star_name)
|
||||
if not star:
|
||||
return {"error": f"Star '{star_name}' not found"}
|
||||
return {"name": star_name, "status": "active"}
|
||||
|
||||
async def _get_memory_info(self) -> dict[str, Any]:
|
||||
"""Get memory usage information."""
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
return {
|
||||
"gc_objects": len(gc.get_objects()),
|
||||
"python_memory": "N/A", # Would need psutil for actual values
|
||||
}
|
||||
|
||||
def set_listen_address(self, host: str, port: int) -> None:
|
||||
"""Set the listen address for the gateway server."""
|
||||
self._host = host
|
||||
self._port = port
|
||||
@@ -1,103 +0,0 @@
|
||||
"""
|
||||
WebSocket connection manager for the AstrBot gateway.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
else:
|
||||
try:
|
||||
from fastapi import WebSocket
|
||||
except ImportError:
|
||||
logger.warning("FastAPI not installed, WebSocketManager unavailable.")
|
||||
WebSocket = cast(Any, None)
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""
|
||||
Manages all active WebSocket connections.
|
||||
|
||||
Provides connection/disconnection handling and broadcast capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connections: set[WebSocket] = set()
|
||||
self._lock = anyio.Lock()
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
"""Accept and register a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
async with self._lock:
|
||||
self._connections.add(websocket)
|
||||
log.debug(f"WebSocket connected. Total: {len(self._connections)}")
|
||||
|
||||
async def disconnect(self, websocket: WebSocket) -> None:
|
||||
"""Remove a WebSocket connection."""
|
||||
async with self._lock:
|
||||
self._connections.discard(websocket)
|
||||
log.debug(f"WebSocket disconnected. Total: {len(self._connections)}")
|
||||
|
||||
async def send_json(self, websocket: WebSocket, data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Send JSON data to a specific WebSocket.
|
||||
|
||||
Args:
|
||||
websocket: Target WebSocket connection
|
||||
data: Data to send (must be JSON-serializable)
|
||||
"""
|
||||
try:
|
||||
await websocket.send_json(data)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to send to WebSocket: {e}")
|
||||
await self.disconnect(websocket)
|
||||
|
||||
async def broadcast(self, data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Broadcast JSON data to all connected WebSockets.
|
||||
|
||||
Args:
|
||||
data: Data to broadcast (must be JSON-serializable)
|
||||
"""
|
||||
async with self._lock:
|
||||
connections = list(self._connections)
|
||||
|
||||
for conn in connections:
|
||||
try:
|
||||
await conn.send_json(data)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to broadcast to WebSocket: {e}")
|
||||
async with self._lock:
|
||||
self._connections.discard(conn)
|
||||
|
||||
async def send_to(
|
||||
self, websocket: WebSocket, message: str | dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Send a message to a specific WebSocket.
|
||||
|
||||
Args:
|
||||
websocket: Target WebSocket connection
|
||||
message: Message to send (string or dict)
|
||||
"""
|
||||
try:
|
||||
if isinstance(message, str):
|
||||
await websocket.send_text(message)
|
||||
else:
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to send to WebSocket: {e}")
|
||||
await self.disconnect(websocket)
|
||||
|
||||
@property
|
||||
def connection_count(self) -> int:
|
||||
"""Return the number of active connections."""
|
||||
return len(self._connections)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""ABP module - AstrBot Protocol client implementation (built-in plugin protocol)."""
|
||||
|
||||
from .client import AstrbotAbpClient
|
||||
|
||||
__all__ = ["AstrbotAbpClient"]
|
||||
@@ -1,93 +0,0 @@
|
||||
"""
|
||||
ABP (AstrBot Protocol) client implementation.
|
||||
|
||||
ABP is the built-in plugin protocol where the orchestrator acts as client
|
||||
connecting to internal stars (plugins) embedded in the runtime.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.abp.base_astrbot_abp_client import BaseAstrbotAbpClient
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotAbpClient(BaseAstrbotAbpClient):
|
||||
"""
|
||||
ABP client for communicating with internal stars (built-in plugins).
|
||||
|
||||
The orchestrator acts as the client, sending requests to and receiving
|
||||
notifications from stars running within the same process.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connected = False
|
||||
self._stars: dict[str, Any] = {}
|
||||
# Use a simple dict for pending requests; we avoid asyncio.Future here.
|
||||
self._pending_requests: dict[str, Any] = {}
|
||||
self._request_id = 0
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""True if connected to stars registry."""
|
||||
return self._connected
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to internal stars registry."""
|
||||
log.debug("ABP client connecting to internal stars...")
|
||||
self._connected = True
|
||||
log.info("ABP client connected to internal stars registry.")
|
||||
|
||||
async def call_star_tool(
|
||||
self, star_name: str, tool_name: str, arguments: dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Call a tool on a registered star.
|
||||
|
||||
Args:
|
||||
star_name: Name of the star (plugin)
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool call result
|
||||
"""
|
||||
if not self._connected:
|
||||
raise RuntimeError("ABP client is not connected")
|
||||
|
||||
star = self._stars.get(star_name)
|
||||
if not star:
|
||||
raise ValueError(f"Star '{star_name}' not found")
|
||||
|
||||
request_id = f"{self._request_id}"
|
||||
self._request_id += 1
|
||||
|
||||
# No asyncio.Future used; store a placeholder entry for tracking if needed.
|
||||
self._pending_requests[request_id] = None
|
||||
|
||||
try:
|
||||
# Call the star's tool handler
|
||||
result = await star.call_tool(tool_name, arguments)
|
||||
return result
|
||||
finally:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
|
||||
def register_star(self, star_name: str, star_instance: Any) -> None:
|
||||
"""Register a star (plugin) with the ABP client."""
|
||||
self._stars[star_name] = star_instance
|
||||
log.debug(f"Star '{star_name}' registered with ABP client.")
|
||||
|
||||
def unregister_star(self, star_name: str) -> None:
|
||||
"""Unregister a star from the ABP client."""
|
||||
self._stars.pop(star_name, None)
|
||||
log.debug(f"Star '{star_name}' unregistered from ABP client.")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the ABP client connection."""
|
||||
self._connected = False
|
||||
# Clear any pending requests (no asyncio futures used in this implementation)
|
||||
self._pending_requests.clear()
|
||||
log.info("ABP client shut down.")
|
||||
@@ -1,6 +0,0 @@
|
||||
"""ACP module - AstrBot Communication Protocol client and server implementations."""
|
||||
|
||||
from .client import AstrbotAcpClient
|
||||
from .server import AstrbotAcpServer
|
||||
|
||||
__all__ = ["AstrbotAcpClient", "AstrbotAcpServer"]
|
||||
@@ -1,220 +0,0 @@
|
||||
"""
|
||||
ACP (AstrBot Communication Protocol) client implementation.
|
||||
|
||||
ACP is a client-server protocol for inter-service communication,
|
||||
similar to MCP but designed specifically for AstrBot's architecture.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.acp.base_astrbot_acp_client import BaseAstrbotAcpClient
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotAcpClient(BaseAstrbotAcpClient):
|
||||
"""
|
||||
ACP client for communicating with ACP servers.
|
||||
|
||||
The orchestrator acts as an ACP client, connecting to external
|
||||
ACP-compatible services.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connected = False
|
||||
self._reader: asyncio.StreamReader | None = None
|
||||
self._writer: asyncio.StreamWriter | None = None
|
||||
self._server_url: str | None = None
|
||||
self._pending_requests: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
||||
self._request_id = 0
|
||||
self._reader_task: asyncio.Task[None] | None = None
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""True if connected to an ACP server."""
|
||||
return self._connected
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Connect to configured ACP servers.
|
||||
|
||||
ACP servers can be accessed via TCP (host:port) or Unix socket.
|
||||
"""
|
||||
log.debug("ACP client connecting...")
|
||||
# TODO: Load ACP server configurations
|
||||
self._connected = True
|
||||
log.info("ACP client initialized.")
|
||||
|
||||
async def connect_to_server(self, host: str, port: int) -> None:
|
||||
"""
|
||||
Connect to an ACP server via TCP.
|
||||
|
||||
Args:
|
||||
host: Server hostname or IP
|
||||
port: Server port
|
||||
"""
|
||||
self._server_url = f"{host}:{port}"
|
||||
self._reader, self._writer = await asyncio.open_connection(host, port)
|
||||
self._connected = True
|
||||
|
||||
# Start reading responses
|
||||
self._reader_task = asyncio.create_task(self._read_messages())
|
||||
|
||||
log.info(f"ACP client connected to {self._server_url}")
|
||||
|
||||
async def connect_to_unix_socket(self, socket_path: str) -> None:
|
||||
"""
|
||||
Connect to an ACP server via Unix socket.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the Unix socket
|
||||
"""
|
||||
self._server_url = f"unix://{socket_path}"
|
||||
self._reader, self._writer = await asyncio.open_unix_connection(socket_path)
|
||||
self._connected = True
|
||||
|
||||
self._reader_task = asyncio.create_task(self._read_messages())
|
||||
|
||||
log.info(f"ACP client connected to {self._server_url}")
|
||||
|
||||
async def _read_messages(self) -> None:
|
||||
"""Background task to read ACP messages."""
|
||||
if not self._reader:
|
||||
return
|
||||
|
||||
buffer = b""
|
||||
while self._connected:
|
||||
try:
|
||||
data = await self._reader.read(4096)
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
|
||||
while True:
|
||||
header_end = buffer.find(b"\n")
|
||||
if header_end == -1:
|
||||
break
|
||||
|
||||
try:
|
||||
header = json.loads(buffer[:header_end].decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
buffer = buffer[header_end + 1 :]
|
||||
continue
|
||||
|
||||
content_length = header.get("content-length", 0)
|
||||
if (
|
||||
content_length == 0
|
||||
or len(buffer) < header_end + 1 + content_length
|
||||
):
|
||||
break
|
||||
|
||||
content = buffer[header_end + 1 : header_end + 1 + content_length]
|
||||
buffer = buffer[header_end + 1 + content_length :]
|
||||
|
||||
message = json.loads(content.decode("utf-8"))
|
||||
|
||||
if "id" in message:
|
||||
request_id = str(message["id"])
|
||||
future = self._pending_requests.pop(request_id, None)
|
||||
if future and not future.done():
|
||||
if "error" in message:
|
||||
future.set_exception(Exception(str(message["error"])))
|
||||
else:
|
||||
future.set_result(message.get("result", {}))
|
||||
else:
|
||||
await self._handle_notification(message)
|
||||
|
||||
except Exception as e:
|
||||
if self._connected:
|
||||
log.error(f"ACP read error: {e}")
|
||||
break
|
||||
|
||||
async def _handle_notification(self, notification: dict[str, Any]) -> None:
|
||||
"""Handle incoming ACP notifications."""
|
||||
method = notification.get("method", "")
|
||||
log.debug(f"ACP notification: {method}")
|
||||
|
||||
async def call_tool(
|
||||
self, server_name: str, tool_name: str, arguments: dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Call a tool on an ACP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the ACP server
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool call result
|
||||
"""
|
||||
if not self._connected:
|
||||
raise RuntimeError("ACP client is not connected")
|
||||
|
||||
request_id = str(self._request_id)
|
||||
self._request_id += 1
|
||||
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": f"{server_name}/{tool_name}",
|
||||
"params": arguments,
|
||||
}
|
||||
|
||||
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
await self._send_message(message)
|
||||
return await future
|
||||
|
||||
async def _send_message(self, message: dict[str, Any]) -> None:
|
||||
"""Send an ACP message."""
|
||||
if not self._writer:
|
||||
raise RuntimeError("ACP client not connected")
|
||||
|
||||
content = json.dumps(message)
|
||||
header = json.dumps({"content-length": len(content)}) + "\n"
|
||||
|
||||
self._writer.write((header + content).encode())
|
||||
await self._writer.drain()
|
||||
|
||||
async def send_notification(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Send a one-way notification to the server."""
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params or {},
|
||||
}
|
||||
await self._send_message(message)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the ACP client connection."""
|
||||
self._connected = False
|
||||
|
||||
if self._reader_task:
|
||||
self._reader_task.cancel()
|
||||
try:
|
||||
await self._reader_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._writer:
|
||||
self._writer.close()
|
||||
try:
|
||||
await self._writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
log.info("ACP client shut down.")
|
||||
@@ -1,223 +0,0 @@
|
||||
"""
|
||||
ACP (AstrBot Communication Protocol) server implementation.
|
||||
|
||||
ACP servers listen for connections from ACP clients and provide
|
||||
services/tools to the orchestrator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.acp.base_astrbot_acp_server import BaseAstrbotAcpServer
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotAcpServer(BaseAstrbotAcpServer):
|
||||
"""
|
||||
ACP server for accepting connections from ACP clients.
|
||||
|
||||
ACP servers expose tools/notifications that can be called by clients.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._running = False
|
||||
self._host: str = "127.0.0.1"
|
||||
self._port: int = 8765
|
||||
self._server: asyncio.Server | None = None
|
||||
self._clients: set[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = set()
|
||||
self._tool_handlers: dict[str, Callable[..., Any]] = {}
|
||||
self._notification_handlers: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
def register_tool(self, name: str, handler: Callable[..., Any]) -> None:
|
||||
"""
|
||||
Register a tool handler.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
handler: Async callable that handles tool calls
|
||||
"""
|
||||
self._tool_handlers[name] = handler
|
||||
log.debug(f"ACP server registered tool: {name}")
|
||||
|
||||
def register_notification_handler(
|
||||
self, name: str, handler: Callable[..., Any]
|
||||
) -> None:
|
||||
"""
|
||||
Register a notification handler.
|
||||
|
||||
Args:
|
||||
name: Notification method name
|
||||
handler: Async callable that handles notifications
|
||||
"""
|
||||
self._notification_handlers[name] = handler
|
||||
log.debug(f"ACP server registered notification handler: {name}")
|
||||
|
||||
async def start(self, host: str = "127.0.0.1", port: int = 8765) -> None:
|
||||
"""
|
||||
Start the ACP server.
|
||||
|
||||
Args:
|
||||
host: Host to bind to
|
||||
port: Port to listen on
|
||||
"""
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle_client,
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
self._running = True
|
||||
log.info(f"ACP server listening on {host}:{port}")
|
||||
|
||||
async def _handle_client(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
"""Handle an incoming ACP client connection."""
|
||||
addr = writer.get_extra_info("peername")
|
||||
log.debug(f"ACP client connected: {addr}")
|
||||
self._clients.add((reader, writer))
|
||||
|
||||
buffer = b""
|
||||
try:
|
||||
while self._running:
|
||||
try:
|
||||
data = await reader.read(4096)
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
|
||||
while True:
|
||||
header_end = buffer.find(b"\n")
|
||||
if header_end == -1:
|
||||
break
|
||||
|
||||
try:
|
||||
header = json.loads(buffer[:header_end].decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
buffer = buffer[header_end + 1 :]
|
||||
continue
|
||||
|
||||
content_length = header.get("content-length", 0)
|
||||
if (
|
||||
content_length == 0
|
||||
or len(buffer) < header_end + 1 + content_length
|
||||
):
|
||||
break
|
||||
|
||||
content = buffer[
|
||||
header_end + 1 : header_end + 1 + content_length
|
||||
]
|
||||
buffer = buffer[header_end + 1 + content_length :]
|
||||
|
||||
message = json.loads(content.decode("utf-8"))
|
||||
response = await self._handle_message(message)
|
||||
|
||||
if response:
|
||||
content = json.dumps(response)
|
||||
resp_header = (
|
||||
json.dumps({"content-length": len(content)}) + "\n"
|
||||
)
|
||||
writer.write(resp_header.encode() + content.encode())
|
||||
await writer.drain()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"ACP client error ({addr}): {e}")
|
||||
break
|
||||
|
||||
finally:
|
||||
self._clients.discard((reader, writer))
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
log.debug(f"ACP client disconnected: {addr}")
|
||||
|
||||
async def _handle_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Handle an incoming ACP message."""
|
||||
method = message.get("method", "")
|
||||
msg_id = message.get("id")
|
||||
params = message.get("params", {})
|
||||
|
||||
# Check if it's a notification (no id) or request (has id)
|
||||
if msg_id is None:
|
||||
# Notification
|
||||
handler = self._notification_handlers.get(method)
|
||||
if handler:
|
||||
try:
|
||||
await handler(params)
|
||||
except Exception as e:
|
||||
log.error(f"ACP notification handler error ({method}): {e}")
|
||||
return None
|
||||
|
||||
# Request
|
||||
result = None
|
||||
error = None
|
||||
|
||||
handler = self._tool_handlers.get(method)
|
||||
if handler:
|
||||
try:
|
||||
result = await handler(params)
|
||||
except Exception as e:
|
||||
error = str(e)
|
||||
log.error(f"ACP tool handler error ({method}): {e}")
|
||||
else:
|
||||
error = f"Unknown method: {method}"
|
||||
|
||||
response: dict[str, Any] = {"jsonrpc": "2.0", "id": msg_id}
|
||||
if error:
|
||||
response["error"] = {"code": -32601, "message": error}
|
||||
else:
|
||||
response["result"] = result
|
||||
|
||||
return response
|
||||
|
||||
async def broadcast_notification(self, method: str, params: dict[str, Any]) -> None:
|
||||
"""
|
||||
Broadcast a notification to all connected clients.
|
||||
|
||||
Args:
|
||||
method: Notification method name
|
||||
params: Notification parameters
|
||||
"""
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
content = json.dumps(message)
|
||||
header = json.dumps({"content-length": len(content)}) + "\n"
|
||||
data = header.encode() + content.encode()
|
||||
|
||||
for reader, writer in list(self._clients):
|
||||
try:
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to broadcast to client: {e}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the ACP server."""
|
||||
self._running = False
|
||||
|
||||
if self._server:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
|
||||
for reader, writer in list(self._clients):
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
self._clients.clear()
|
||||
|
||||
log.info("ACP server shut down.")
|
||||
@@ -1,5 +0,0 @@
|
||||
"""LSP module - Language Server Protocol client implementation."""
|
||||
|
||||
from .client import AstrbotLspClient
|
||||
|
||||
__all__ = ["AstrbotLspClient"]
|
||||
@@ -1,243 +0,0 @@
|
||||
"""
|
||||
LSP (Language Server Protocol) client implementation.
|
||||
|
||||
The orchestrator acts as an LSP client, connecting to LSP servers
|
||||
that provide language intelligence features (completions, diagnostics, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
from anyio.abc import ByteReceiveStream, ByteSendStream, Process
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.lsp.base_astrbot_lsp_client import BaseAstrbotLspClient
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotLspClient(BaseAstrbotLspClient):
|
||||
"""
|
||||
LSP client for communicating with LSP servers.
|
||||
|
||||
Implements the Microsoft Language Server Protocol for connecting to
|
||||
external language intelligence services.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connected = False
|
||||
self._reader: ByteReceiveStream | None = None
|
||||
self._writer: ByteSendStream | None = None
|
||||
self._server_process: Process | None = None
|
||||
self._pending_requests: dict[int, Any] = {}
|
||||
self._request_id = 0
|
||||
self._server_command: list[str] | None = None
|
||||
# anyio TaskGroup handle for background readers
|
||||
self._task_group: Any | None = None
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""True if connected to an LSP server."""
|
||||
return self._connected
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Connect to configured LSP servers.
|
||||
|
||||
LSP servers are typically stdio-based subprocesses. This method
|
||||
establishes the communication channel.
|
||||
"""
|
||||
log.debug("LSP client connecting...")
|
||||
# TODO: Load LSP server configurations and start subprocesses
|
||||
# For now, mark as connected in idle mode
|
||||
self._connected = True
|
||||
log.info("LSP client initialized.")
|
||||
|
||||
async def connect_to_server(self, command: list[str], workspace_uri: str) -> None:
|
||||
"""
|
||||
Connect to an LSP server subprocess.
|
||||
|
||||
Args:
|
||||
command: Command line to start the LSP server (e.g., ["python", "lsp_server.py"])
|
||||
workspace_uri: Root URI of the workspace to serve
|
||||
"""
|
||||
log.debug(f"Starting LSP server: {' '.join(command)}")
|
||||
|
||||
self._server_process = await anyio.open_process(
|
||||
command,
|
||||
stdin=-1,
|
||||
stdout=-1,
|
||||
stderr=-1,
|
||||
)
|
||||
self._reader = self._server_process.stdout
|
||||
self._writer = self._server_process.stdin
|
||||
self._server_command = command
|
||||
self._connected = True
|
||||
|
||||
# Start reading responses in background using anyio TaskGroup
|
||||
# Create and enter a TaskGroup so the reader runs until we close it at shutdown.
|
||||
self._task_group = anyio.create_task_group()
|
||||
await self._task_group.__aenter__()
|
||||
self._task_group.start_soon(self._read_responses)
|
||||
|
||||
# Send initialize request
|
||||
await self.send_request(
|
||||
"initialize",
|
||||
{
|
||||
"processId": None,
|
||||
"rootUri": workspace_uri,
|
||||
"capabilities": {},
|
||||
},
|
||||
)
|
||||
|
||||
# Send initialized notification
|
||||
await self.send_notification("initialized", {})
|
||||
|
||||
log.info(f"LSP client connected to server: {command[0]}")
|
||||
|
||||
async def send_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Send an LSP request and wait for response."""
|
||||
if not self._writer:
|
||||
raise RuntimeError("LSP client not connected")
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id += 1
|
||||
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params or {},
|
||||
}
|
||||
|
||||
# Use anyio.Event for request/response matching
|
||||
response_event: anyio.Event = anyio.Event()
|
||||
response_holder: dict[str, Any] = {}
|
||||
|
||||
async def set_response(response: dict[str, Any]) -> None:
|
||||
response_holder["response"] = response
|
||||
response_event.set()
|
||||
|
||||
self._pending_requests[request_id] = set_response
|
||||
|
||||
content = json.dumps(message)
|
||||
headers = f"Content-Length: {len(content)}\r\n\r\n"
|
||||
await self._writer.send((headers + content).encode())
|
||||
|
||||
# Wait for response with timeout
|
||||
with anyio.move_on_after(30):
|
||||
await response_event.wait()
|
||||
|
||||
if "response" in response_holder:
|
||||
return response_holder["response"]
|
||||
raise TimeoutError(f"LSP request {method} timed out")
|
||||
|
||||
async def send_notification(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Send an LSP notification (no response expected)."""
|
||||
if not self._writer:
|
||||
raise RuntimeError("LSP client not connected")
|
||||
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params or {},
|
||||
}
|
||||
|
||||
content = json.dumps(message)
|
||||
headers = f"Content-Length: {len(content)}\r\n\r\n"
|
||||
await self._writer.send((headers + content).encode())
|
||||
|
||||
async def _read_responses(self) -> None:
|
||||
"""Background task to read LSP responses."""
|
||||
if not self._reader:
|
||||
return
|
||||
|
||||
buffer = b""
|
||||
try:
|
||||
while self._connected:
|
||||
try:
|
||||
data = await self._reader.receive()
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
|
||||
while True:
|
||||
# Parse Content-Length header
|
||||
header_end = buffer.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
break
|
||||
|
||||
header = buffer[:header_end].decode("utf-8")
|
||||
content_length = 0
|
||||
for line in header.split("\r\n"):
|
||||
if line.startswith("Content-Length:"):
|
||||
content_length = int(line.split(":")[1].strip())
|
||||
|
||||
if content_length == 0:
|
||||
break
|
||||
|
||||
total_length = header_end + 4 + content_length
|
||||
if len(buffer) < total_length:
|
||||
break
|
||||
|
||||
content = buffer[header_end + 4 : total_length]
|
||||
buffer = buffer[total_length:]
|
||||
|
||||
response = json.loads(content.decode("utf-8"))
|
||||
|
||||
# Handle response vs notification
|
||||
if "id" in response:
|
||||
request_id = response["id"]
|
||||
handler = self._pending_requests.pop(request_id, None)
|
||||
if handler:
|
||||
await handler(response)
|
||||
else:
|
||||
# Notification (e.g., window/logMessage)
|
||||
await self._handle_notification(response)
|
||||
|
||||
except anyio.EndOfStream:
|
||||
break
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# Task was cancelled via the TaskGroup cancel/exit during shutdown
|
||||
pass
|
||||
|
||||
async def _handle_notification(self, notification: dict[str, Any]) -> None:
|
||||
"""Handle incoming LSP notifications."""
|
||||
method = notification.get("method", "")
|
||||
log.debug(f"LSP notification: {method}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the LSP client."""
|
||||
self._connected = False
|
||||
|
||||
if self._task_group:
|
||||
try:
|
||||
# Exit the TaskGroup, which cancels background tasks started within it
|
||||
await self._task_group.__aexit__(None, None, None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
pass
|
||||
self._task_group = None
|
||||
|
||||
if self._server_process:
|
||||
try:
|
||||
await self.send_notification("shutdown", {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._server_process.terminate()
|
||||
try:
|
||||
with anyio.move_on_after(5):
|
||||
await self._server_process.wait()
|
||||
except Exception:
|
||||
self._server_process.kill()
|
||||
self._server_process = None
|
||||
|
||||
self._pending_requests.clear()
|
||||
log.info("LSP client shut down.")
|
||||
@@ -1,63 +0,0 @@
|
||||
"""MCP module - Model Context Protocol client and tool implementations.
|
||||
|
||||
This module provides MCP client functionality and MCP tool wrappers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .client import McpClient
|
||||
from .config import (
|
||||
DEFAULT_MCP_CONFIG,
|
||||
get_mcp_config_path,
|
||||
load_mcp_config,
|
||||
save_mcp_config,
|
||||
)
|
||||
from .tool import MCPTool
|
||||
|
||||
|
||||
# Exceptions
|
||||
class MCPInitError(Exception):
|
||||
"""Base exception for MCP initialization failures."""
|
||||
|
||||
|
||||
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
|
||||
"""Raised when MCP client initialization exceeds the configured timeout."""
|
||||
|
||||
|
||||
class MCPAllServicesFailedError(MCPInitError):
|
||||
"""Raised when all configured MCP services fail to initialize."""
|
||||
|
||||
|
||||
class MCPShutdownTimeoutError(asyncio.TimeoutError):
|
||||
"""Raised when MCP shutdown exceeds the configured timeout."""
|
||||
|
||||
def __init__(self, names: list[str], timeout: float) -> None:
|
||||
self.names = names
|
||||
self.timeout = timeout
|
||||
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPInitSummary:
|
||||
"""Summary of MCP initialization results."""
|
||||
|
||||
total: int
|
||||
success: int
|
||||
failed: list[str]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_MCP_CONFIG",
|
||||
"MCPAllServicesFailedError",
|
||||
"MCPInitError",
|
||||
"MCPInitSummary",
|
||||
"MCPInitTimeoutError",
|
||||
"MCPShutdownTimeoutError",
|
||||
"MCPTool",
|
||||
"McpClient",
|
||||
"get_mcp_config_path",
|
||||
"load_mcp_config",
|
||||
"save_mcp_config",
|
||||
]
|
||||
@@ -1,486 +0,0 @@
|
||||
"""MCP client implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from astrbot._internal.abc.mcp.base_astrbot_mcp_client import (
|
||||
BaseAstrbotMcpClient,
|
||||
McpServerConfig,
|
||||
McpToolInfo,
|
||||
)
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
try:
|
||||
import anyio
|
||||
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
|
||||
)
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
|
||||
)
|
||||
|
||||
|
||||
class TenacityLogger:
|
||||
"""Wraps a logging.Logger to satisfy tenacity's LoggerProtocol."""
|
||||
|
||||
__slots__ = ("_logger",)
|
||||
_logger: logging.Logger
|
||||
|
||||
def __init__(self, logger: logging.Logger) -> None:
|
||||
self._logger = logger
|
||||
|
||||
def log(
|
||||
self,
|
||||
level: int,
|
||||
msg: str,
|
||||
/,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._logger.log(level, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""Prepare configuration, handle nested format."""
|
||||
if config.get("mcpServers"):
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
|
||||
def _prepare_stdio_env(config: dict) -> dict:
|
||||
"""Preserve Windows executable resolution for stdio subprocesses."""
|
||||
if sys.platform != "win32":
|
||||
return config
|
||||
|
||||
pathext = os.environ.get("PATHEXT")
|
||||
if not pathext:
|
||||
return config
|
||||
|
||||
prepared = config.copy()
|
||||
env = dict(prepared.get("env") or {})
|
||||
env.setdefault("PATHEXT", pathext)
|
||||
prepared["env"] = env
|
||||
return prepared
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""Quick test MCP server connectivity."""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
if "transport" in cfg:
|
||||
transport_type = cfg["transport"]
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP connection config missing transport or type field")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if transport_type == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"Connection timeout: {timeout} seconds"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
class McpClient(BaseAstrbotMcpClient):
|
||||
def __init__(self) -> None:
|
||||
# Initialize session and client objects
|
||||
self.session: mcp.ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
|
||||
|
||||
self.name: str | None = None
|
||||
self.active: bool = True
|
||||
self.tools: list[mcp.Tool] = []
|
||||
self.server_errlogs: list[str] = []
|
||||
self.running_event = anyio.Event()
|
||||
self.process_pid: int | None = None
|
||||
|
||||
# Store connection config for reconnection
|
||||
self._mcp_server_config: McpServerConfig | None = None
|
||||
self._server_name: str | None = None
|
||||
self._reconnect_lock = anyio.Lock() # Lock for thread-safe reconnection
|
||||
self._reconnecting: bool = False # For logging and debugging
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize the MCP client connection.
|
||||
|
||||
Note: Actual server connections are made via connect_to_server().
|
||||
This method prepares the client for use.
|
||||
"""
|
||||
# MCP client is initialized on-demand via connect_to_server
|
||||
# This is a no-op stub to satisfy BaseAstrbotMcpClient
|
||||
logger.debug("MCP client initialized.")
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""True if MCP client has an active session."""
|
||||
return self.session is not None
|
||||
|
||||
async def list_tools(self) -> list[McpToolInfo]:
|
||||
"""List all tools from connected MCP servers."""
|
||||
if not self.session:
|
||||
return []
|
||||
result = await self.list_tools_and_save()
|
||||
tools = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"inputSchema": tool.inputSchema,
|
||||
}
|
||||
for tool in result.tools
|
||||
]
|
||||
return cast(list[McpToolInfo], tools)
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any],
|
||||
read_timeout_seconds: int = 60,
|
||||
) -> Any:
|
||||
"""Call a tool on the MCP server with reconnection support."""
|
||||
return await self.call_tool_with_reconnect(
|
||||
tool_name=name,
|
||||
arguments=arguments,
|
||||
read_timeout_seconds=timedelta(seconds=read_timeout_seconds),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_stdio_process_pid(streams_context: object) -> int | None:
|
||||
"""Best-effort extraction for stdio subprocess PID used by lease cleanup.
|
||||
|
||||
TODO(refactor): replace this async-generator frame introspection with a
|
||||
stable MCP library hook once the upstream transport exposes process PID.
|
||||
"""
|
||||
generator = getattr(streams_context, "gen", None)
|
||||
frame = getattr(generator, "ag_frame", None)
|
||||
if frame is None:
|
||||
return None
|
||||
process = frame.f_locals.get("process")
|
||||
pid = getattr(process, "pid", None)
|
||||
try:
|
||||
return int(pid) if pid is not None else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
async def connect_to_server(self, config: McpServerConfig, name: str) -> None:
|
||||
"""Connect to MCP server
|
||||
|
||||
If `url` parameter exists:
|
||||
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
|
||||
2. When transport is specified as `sse`, use SSE connection.
|
||||
3. If not specified, default to SSE connection to MCP service.
|
||||
|
||||
Args:
|
||||
config: Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
|
||||
"""
|
||||
# Store config for reconnection
|
||||
self._mcp_server_config = config
|
||||
self._server_name = name
|
||||
self.process_pid = None
|
||||
|
||||
cfg = _prepare_config(dict(config))
|
||||
|
||||
def logging_callback(
|
||||
msg: str | mcp.types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
# Handle MCP service error logs
|
||||
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
|
||||
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
|
||||
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
|
||||
self.server_errlogs.append(log_msg)
|
||||
|
||||
if "url" in cfg:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if "transport" in cfg:
|
||||
transport_type = cfg["transport"]
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP connection config missing transport or type field")
|
||||
|
||||
if transport_type != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context,
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=cast(Any, logging_callback),
|
||||
),
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context,
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
cfg = _prepare_stdio_env(cfg)
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None:
|
||||
# Handle MCP service error logs
|
||||
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
|
||||
if msg.level in (
|
||||
"warning",
|
||||
"error",
|
||||
"critical",
|
||||
"alert",
|
||||
"emergency",
|
||||
):
|
||||
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
|
||||
self.server_errlogs.append(log_msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=cast(
|
||||
Any,
|
||||
LogPipe(
|
||||
level=logging.INFO,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
self.process_pid = self._extract_stdio_process_pid(stdio_transport)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport),
|
||||
)
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
if not self.session:
|
||||
raise Exception("MCP Client is not initialized")
|
||||
response = await self.session.list_tools()
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
"""Reconnect to the MCP server using the stored configuration.
|
||||
|
||||
Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
|
||||
|
||||
Raises:
|
||||
Exception: raised when reconnection fails
|
||||
"""
|
||||
async with self._reconnect_lock:
|
||||
# Check if already reconnecting (useful for logging)
|
||||
if self._reconnecting:
|
||||
logger.debug(
|
||||
f"MCP Client {self._server_name} is already reconnecting, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
if not self._mcp_server_config or not self._server_name:
|
||||
raise Exception("Cannot reconnect: missing connection configuration")
|
||||
|
||||
self._reconnecting = True
|
||||
try:
|
||||
logger.info(
|
||||
f"Attempting to reconnect to MCP server {self._server_name}..."
|
||||
)
|
||||
|
||||
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
|
||||
if self.exit_stack:
|
||||
self._old_exit_stacks.append(self.exit_stack)
|
||||
|
||||
# Mark old session as invalid
|
||||
self.session = None
|
||||
|
||||
# Create new exit stack for new connection
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
# Reconnect using stored config
|
||||
await self.connect_to_server(self._mcp_server_config, self._server_name)
|
||||
await self.list_tools_and_save()
|
||||
|
||||
logger.info(
|
||||
f"Successfully reconnected to MCP server {self._server_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to reconnect to MCP server {self._server_name}: {e}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
self._reconnecting = False
|
||||
|
||||
async def call_tool_with_reconnect(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict,
|
||||
read_timeout_seconds: timedelta,
|
||||
) -> mcp.types.CallToolResult:
|
||||
"""Call MCP tool with automatic reconnection on failure, max 2 retries.
|
||||
|
||||
Args:
|
||||
tool_name: tool name
|
||||
arguments: tool arguments
|
||||
read_timeout_seconds: read timeout
|
||||
|
||||
Returns:
|
||||
MCP tool call result
|
||||
|
||||
Raises:
|
||||
ValueError: MCP session is not available
|
||||
anyio.ClosedResourceError: raised after reconnection failure
|
||||
"""
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(anyio.ClosedResourceError),
|
||||
stop=stop_after_attempt(2),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=3),
|
||||
before_sleep=before_sleep_log(TenacityLogger(logger), logging.WARNING),
|
||||
reraise=True,
|
||||
)
|
||||
async def _call_with_retry():
|
||||
if not self.session:
|
||||
raise ValueError("MCP session is not available for MCP function tools.")
|
||||
|
||||
try:
|
||||
return await self.session.call_tool(
|
||||
name=tool_name,
|
||||
arguments=arguments,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
logger.warning(
|
||||
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
|
||||
)
|
||||
# Attempt to reconnect
|
||||
await self._reconnect()
|
||||
# Reraise the exception to trigger tenacity retry
|
||||
raise
|
||||
|
||||
return await _call_with_retry()
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up resources including old exit stacks from reconnections"""
|
||||
# Close current exit stack
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing current exit stack: {e}")
|
||||
|
||||
# Don't close old exit stacks as they may be in different task contexts
|
||||
# They will be garbage collected naturally
|
||||
# Just clear the list to release references
|
||||
self._old_exit_stacks.clear()
|
||||
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
self.process_pid = None
|
||||
@@ -1,55 +0,0 @@
|
||||
"""MCP configuration management."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
|
||||
def get_mcp_config_path() -> str:
|
||||
"""Get the path to the MCP configuration file."""
|
||||
data_dir = get_astrbot_data_path()
|
||||
return os.path.join(data_dir, "mcp_server.json")
|
||||
|
||||
|
||||
def load_mcp_config() -> dict:
|
||||
"""Load MCP configuration from file.
|
||||
|
||||
Returns:
|
||||
MCP configuration dict. If file doesn't exist, returns default config.
|
||||
|
||||
"""
|
||||
config_path = get_mcp_config_path()
|
||||
if not os.path.exists(config_path):
|
||||
# Create default config if not exists
|
||||
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
try:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
|
||||
def save_mcp_config(config: dict) -> bool:
|
||||
"""Save MCP configuration to file.
|
||||
|
||||
Args:
|
||||
config: MCP configuration dict to save.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
|
||||
"""
|
||||
config_path = get_mcp_config_path()
|
||||
try:
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -1,45 +0,0 @@
|
||||
"""MCP tool wrapper."""
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
try:
|
||||
import mcp
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
mcp = None # type: ignore
|
||||
|
||||
from astrbot._internal.tools.base import FunctionTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot._internal.protocols.mcp.client import McpClient
|
||||
|
||||
|
||||
class MCPTool(FunctionTool):
|
||||
"""A function tool that calls an MCP service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_tool: "mcp.types.Tool",
|
||||
mcp_client: "McpClient",
|
||||
mcp_server_name: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name=mcp_tool.name,
|
||||
description=mcp_tool.description or "",
|
||||
parameters=mcp_tool.inputSchema,
|
||||
)
|
||||
self.mcp_tool = mcp_tool
|
||||
self.mcp_client = mcp_client
|
||||
self.mcp_server_name = mcp_server_name
|
||||
self.source = "mcp"
|
||||
|
||||
async def call(self, **kwargs: Any) -> Any:
|
||||
"""Call the MCP tool with the given arguments."""
|
||||
# Note: For actual usage, context.tool_call_timeout is needed
|
||||
# but for simplicity we use a default timeout here
|
||||
return await self.mcp_client.call_tool_with_reconnect(
|
||||
tool_name=self.mcp_tool.name,
|
||||
arguments=kwargs,
|
||||
read_timeout_seconds=timedelta(seconds=60),
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
from astrbot._internal.runtime.__main__ import bootstrap
|
||||
|
||||
__all__ = ["bootstrap"]
|
||||
@@ -1,24 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot._internal.abc.base_astrbot_gateway import BaseAstrbotGateway
|
||||
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
|
||||
from astrbot._internal.geteway.server import AstrbotGateway
|
||||
from astrbot._internal.runtime.orchestrator import AstrbotOrchestrator
|
||||
|
||||
|
||||
async def bootstrap():
|
||||
orchestrator: BaseAstrbotOrchestrator = AstrbotOrchestrator()
|
||||
gw: BaseAstrbotGateway = AstrbotGateway(orchestrator)
|
||||
|
||||
# anyio 的结构化并发
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(orchestrator.lsp.connect) # 启动 LSP client
|
||||
tg.start_soon(orchestrator.mcp.connect) # 启动 MCP client
|
||||
tg.start_soon(orchestrator.acp.connect) # 启动 ACP client
|
||||
tg.start_soon(orchestrator.abp.connect) # 启动 ABP client
|
||||
await anyio.sleep(0.5)
|
||||
tg.start_soon(orchestrator.run_loop) # 启动编排器循环
|
||||
|
||||
tg.start_soon(gw.serve) # 面板后端服务
|
||||
@@ -1,164 +0,0 @@
|
||||
"""
|
||||
AstrBot Orchestrator - core runtime that coordinates all protocol clients.
|
||||
|
||||
The orchestrator manages the lifecycle of LSP, MCP, ACP, and ABP clients,
|
||||
and runs the main event loop that dispatches messages between components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
|
||||
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
|
||||
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
|
||||
from astrbot._internal.protocols.lsp.client import AstrbotLspClient
|
||||
from astrbot._internal.protocols.mcp.client import McpClient
|
||||
from astrbot._internal.stars import RuntimeStatusStar
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
class AstrbotOrchestrator(BaseAstrbotOrchestrator):
|
||||
"""
|
||||
Core runtime orchestrator for AstrBot.
|
||||
|
||||
Manages:
|
||||
- LSP client: Language Server Protocol for editor integrations
|
||||
- MCP client: Model Context Protocol for external tool servers
|
||||
- ACP client: AstrBot Communication Protocol for inter-service communication
|
||||
- ABP client: AstrBot Protocol for built-in star (plugin) communication
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Initialize protocol clients (use concrete types for full method access)
|
||||
self.lsp = AstrbotLspClient()
|
||||
self.mcp = McpClient()
|
||||
self.acp = AstrbotAcpClient()
|
||||
self.abp = AstrbotAbpClient()
|
||||
|
||||
self._running = False
|
||||
self._stars: dict[str, Any] = {}
|
||||
self._message_count: int = 0
|
||||
self._last_activity_timestamp: float | None = None
|
||||
|
||||
# Auto-register RuntimeStatusStar
|
||||
self._runtime_status_star = RuntimeStatusStar()
|
||||
self._runtime_status_star.set_orchestrator(self)
|
||||
self._stars["runtime-status-star"] = self._runtime_status_star
|
||||
self.abp.register_star("runtime-status-star", self._runtime_status_star)
|
||||
|
||||
log.debug("AstrbotOrchestrator initialized.")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Initialize all protocol clients.
|
||||
|
||||
Calls connect() on all protocol clients to prepare them for use.
|
||||
"""
|
||||
log.info("Starting AstrbotOrchestrator...")
|
||||
|
||||
await self.lsp.connect()
|
||||
await self.mcp.connect()
|
||||
await self.acp.connect()
|
||||
await self.abp.connect()
|
||||
|
||||
self._running = True
|
||||
log.info("AstrbotOrchestrator started.")
|
||||
|
||||
async def run_loop(self) -> None:
|
||||
"""
|
||||
Main orchestrator event loop.
|
||||
|
||||
This loop runs continuously, handling:
|
||||
- Periodic health checks of protocol clients
|
||||
- Message routing between protocols
|
||||
- Star (plugin) lifecycle management
|
||||
"""
|
||||
self._running = True
|
||||
log.info("AstrbotOrchestrator run loop started.")
|
||||
|
||||
stop_event = anyio.Event()
|
||||
|
||||
def set_stop() -> None:
|
||||
stop_event.set()
|
||||
|
||||
# Store the callback for cleanup
|
||||
self._stop_callback = set_stop
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
# TODO: Periodic tasks:
|
||||
# - Check LSP server health
|
||||
# - Check MCP session status
|
||||
# - Check ACP client connections
|
||||
# - Process any pending star notifications
|
||||
|
||||
# Wait for 5 seconds or until shutdown is called
|
||||
with anyio.move_on_after(5):
|
||||
await stop_event.wait()
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
log.info("Orchestrator run loop cancelled.")
|
||||
finally:
|
||||
self._running = False
|
||||
self._stop_callback = None
|
||||
log.info("AstrbotOrchestrator run loop stopped.")
|
||||
|
||||
async def register_star(self, name: str, star_instance: Any) -> None:
|
||||
"""
|
||||
Register a star (plugin) with the orchestrator.
|
||||
|
||||
Args:
|
||||
name: Unique name for the star
|
||||
star_instance: Star plugin instance
|
||||
"""
|
||||
self._stars[name] = star_instance
|
||||
self.abp.register_star(name, star_instance)
|
||||
log.info(f"Star '{name}' registered.")
|
||||
|
||||
async def unregister_star(self, name: str) -> None:
|
||||
"""
|
||||
Unregister a star (plugin) from the orchestrator.
|
||||
|
||||
Args:
|
||||
name: Name of the star to unregister
|
||||
"""
|
||||
self._stars.pop(name, None)
|
||||
self.abp.unregister_star(name)
|
||||
log.info(f"Star '{name}' unregistered.")
|
||||
|
||||
async def get_star(self, name: str) -> Any | None:
|
||||
"""Get a registered star by name."""
|
||||
return self._stars.get(name)
|
||||
|
||||
async def list_stars(self) -> list[str]:
|
||||
"""List all registered star names."""
|
||||
return list(self._stars.keys())
|
||||
|
||||
def record_activity(self) -> None:
|
||||
"""Record a message activity for stats tracking."""
|
||||
self._message_count += 1
|
||||
import time
|
||||
|
||||
self._last_activity_timestamp = time.time()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the orchestrator and all protocol clients.
|
||||
"""
|
||||
log.info("Shutting down AstrbotOrchestrator...")
|
||||
self._running = False
|
||||
|
||||
# Shutdown all protocol clients
|
||||
await self.lsp.shutdown()
|
||||
await self.acp.shutdown()
|
||||
await self.abp.shutdown()
|
||||
|
||||
# MCP cleanup
|
||||
await self.mcp.cleanup()
|
||||
|
||||
log.info("AstrbotOrchestrator shut down.")
|
||||
@@ -1,18 +0,0 @@
|
||||
import sys
|
||||
|
||||
try:
|
||||
from ._core import cli as _cli
|
||||
|
||||
def cli():
|
||||
if len(sys.argv) == 1:
|
||||
sys.argv.append("--help")
|
||||
return _cli()
|
||||
except ImportError:
|
||||
from click import echo
|
||||
|
||||
def cli():
|
||||
echo("""
|
||||
AstrBot CLI(rust) is not available.
|
||||
Developer: maturin dev
|
||||
User: uv run astrbot-rs
|
||||
""")
|
||||
@@ -1,16 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
class AstrbotOrchestrator:
|
||||
def start(self) -> None: ...
|
||||
def stop(self) -> None: ...
|
||||
def is_running(self) -> bool: ...
|
||||
def register_star(self, name: str, handler: str) -> None: ...
|
||||
def unregister_star(self, name: str) -> None: ...
|
||||
def list_stars(self) -> list[str]: ...
|
||||
def record_activity(self) -> None: ...
|
||||
def get_stats(self) -> dict[str, Any]: ...
|
||||
def set_protocol_connected(self, protocol: str, connected: bool) -> None: ...
|
||||
def get_protocol_status(self, protocol: str) -> dict[str, Any] | None: ...
|
||||
|
||||
def get_orchestrator() -> AstrbotOrchestrator: ...
|
||||
def cli() -> None: ...
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Internal skills module - re-exports from core.skills.skill_manager."""
|
||||
|
||||
from astrbot.core.skills.skill_manager import (
|
||||
SkillInfo,
|
||||
SkillManager,
|
||||
build_skills_prompt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SkillInfo",
|
||||
"SkillManager",
|
||||
"build_skills_prompt",
|
||||
]
|
||||
@@ -1,7 +0,0 @@
|
||||
"""
|
||||
Stars (built-in plugins) for AstrBot runtime.
|
||||
"""
|
||||
|
||||
from astrbot._internal.stars.runtime_status_star import RuntimeStatusStar
|
||||
|
||||
__all__ = ["RuntimeStatusStar"]
|
||||
@@ -1,127 +0,0 @@
|
||||
"""
|
||||
RuntimeStatusStar - ABP plugin that exposes core runtime internal state.
|
||||
|
||||
This star provides tools for querying:
|
||||
- Runtime status (running state, uptime)
|
||||
- Protocol client status (LSP, MCP, ACP, ABP)
|
||||
- Registered stars registry
|
||||
- Message counts and metrics
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeStatusStar:
|
||||
"""
|
||||
ABP star that exposes core runtime internal state as callable tools.
|
||||
|
||||
Tools provided:
|
||||
- get_runtime_status: Returns running state and uptime
|
||||
- get_protocol_status: Returns LSP, MCP, ACP, ABP status
|
||||
- get_star_registry: Returns registered star names
|
||||
- get_stats: Returns message counts and metrics
|
||||
"""
|
||||
|
||||
name: str = "runtime-status-star"
|
||||
description: str = "ABP plugin that exposes core runtime internal state"
|
||||
|
||||
_start_time: float = field(default_factory=time.time, init=False)
|
||||
_orchestrator: Any = field(default=None, init=False)
|
||||
|
||||
def set_orchestrator(self, orchestrator: Any) -> None:
|
||||
"""Set the orchestrator reference for status queries."""
|
||||
self._orchestrator = orchestrator
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""
|
||||
Handle tool calls from ABP client.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool result
|
||||
"""
|
||||
if tool_name == "get_runtime_status":
|
||||
return self._get_runtime_status()
|
||||
elif tool_name == "get_protocol_status":
|
||||
return await self._get_protocol_status()
|
||||
elif tool_name == "get_star_registry":
|
||||
return await self._get_star_registry()
|
||||
elif tool_name == "get_stats":
|
||||
return self._get_stats()
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {tool_name}")
|
||||
|
||||
def _get_runtime_status(self) -> dict[str, Any]:
|
||||
"""Get overall runtime state."""
|
||||
running = (
|
||||
getattr(self._orchestrator, "running", False)
|
||||
if self._orchestrator
|
||||
else False
|
||||
)
|
||||
uptime_seconds = time.time() - self._start_time
|
||||
return {
|
||||
"running": running,
|
||||
"uptime_seconds": uptime_seconds,
|
||||
}
|
||||
|
||||
async def _get_protocol_status(self) -> dict[str, Any]:
|
||||
"""Get status of each protocol client."""
|
||||
if not self._orchestrator:
|
||||
return {
|
||||
"lsp": {"connected": False, "name": "lsp-client"},
|
||||
"mcp": {"connected": False, "name": "mcp-client"},
|
||||
"acp": {"connected": False, "name": "acp-client"},
|
||||
"abp": {"connected": False, "name": "abp-client"},
|
||||
}
|
||||
|
||||
return {
|
||||
"lsp": {
|
||||
"connected": getattr(self._orchestrator.lsp, "connected", False),
|
||||
"name": "lsp-client",
|
||||
},
|
||||
"mcp": {
|
||||
"connected": getattr(self._orchestrator.mcp, "connected", False),
|
||||
"name": "mcp-client",
|
||||
},
|
||||
"acp": {
|
||||
"connected": getattr(self._orchestrator.acp, "connected", False),
|
||||
"name": "acp-client",
|
||||
},
|
||||
"abp": {
|
||||
"connected": getattr(self._orchestrator.abp, "connected", False),
|
||||
"name": "abp-client",
|
||||
},
|
||||
}
|
||||
|
||||
async def _get_star_registry(self) -> dict[str, Any]:
|
||||
"""Get list of registered stars."""
|
||||
if not self._orchestrator:
|
||||
return {"stars": []}
|
||||
|
||||
stars = await self._orchestrator.list_stars()
|
||||
return {"stars": stars}
|
||||
|
||||
def _get_stats(self) -> dict[str, Any]:
|
||||
"""Get message counts and metrics."""
|
||||
result: dict[str, Any] = {
|
||||
"uptime_seconds": time.time() - self._start_time,
|
||||
}
|
||||
if self._orchestrator:
|
||||
result["total_messages"] = getattr(self._orchestrator, "_message_count", 0)
|
||||
last_ts = getattr(self._orchestrator, "_last_activity_timestamp", None)
|
||||
if last_ts is not None:
|
||||
result["last_activity"] = datetime.fromtimestamp(
|
||||
last_ts, tz=timezone.utc
|
||||
).isoformat()
|
||||
else:
|
||||
result["last_activity"] = None
|
||||
return result
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Internal tools module for AstrBot runtime."""
|
||||
|
||||
from .base import FunctionTool, ToolSet
|
||||
|
||||
__all__ = ["FunctionTool", "ToolSet"]
|
||||
@@ -1,332 +0,0 @@
|
||||
"""Base tool classes for AstrBot internal runtime.
|
||||
|
||||
This module provides the FunctionTool base class used by MCP tools
|
||||
in the new internal architecture.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolSchema:
|
||||
"""A class representing the schema of a tool for function calling."""
|
||||
|
||||
name: str
|
||||
"""The name of the tool."""
|
||||
|
||||
description: str
|
||||
"""The description of the tool."""
|
||||
|
||||
parameters: ParametersType = field(default_factory=dict)
|
||||
"""The parameters of the tool, in JSON Schema format."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_parameters(self) -> "ToolSchema":
|
||||
"""Validate the parameters JSON schema."""
|
||||
import jsonschema
|
||||
|
||||
jsonschema.validate(
|
||||
self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool(ToolSchema):
|
||||
"""A callable tool, for function calling."""
|
||||
|
||||
handler: Callable[..., Awaitable[str | None] | AsyncGenerator[Any, None]] | None = (
|
||||
None
|
||||
)
|
||||
"""a callable that implements the tool's functionality. It should be an async function."""
|
||||
|
||||
handler_module_path: str | None = None
|
||||
"""
|
||||
The module path of the handler function. This is empty when the origin is mcp.
|
||||
This field must be retained, as the handler will be wrapped in functools.partial during initialization,
|
||||
causing the handler's __module__ to be functools
|
||||
"""
|
||||
|
||||
active: bool = True
|
||||
"""
|
||||
Whether the tool is active. This field is a special field for AstrBot.
|
||||
You can ignore it when integrating with other frameworks.
|
||||
"""
|
||||
|
||||
is_background_task: bool = False
|
||||
"""
|
||||
Declare this tool as a background task. Background tasks return immediately
|
||||
with a task identifier while the real work continues asynchronously.
|
||||
"""
|
||||
|
||||
source: str = "mcp"
|
||||
"""
|
||||
Origin of this tool: 'plugin' (from star plugins), 'internal' (AstrBot built-in),
|
||||
or 'mcp' (from MCP servers). Used by WebUI for display grouping.
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
|
||||
|
||||
async def call(self, **kwargs: Any) -> Any:
|
||||
"""Run the tool with the given arguments. The handler field has priority."""
|
||||
raise NotImplementedError(
|
||||
"FunctionTool.call() must be implemented by subclasses or set a handler."
|
||||
)
|
||||
|
||||
|
||||
class ToolSet:
|
||||
"""
|
||||
A collection of FunctionTools grouped under a namespace.
|
||||
|
||||
ToolSets allow organizing related tools together. The LLM sees tools
|
||||
as "namespace/tool_name" when calling.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str, tools: list[FunctionTool] | None = None) -> None:
|
||||
self.namespace = namespace
|
||||
self._tools: dict[str, FunctionTool] = {}
|
||||
if tools:
|
||||
for tool in tools:
|
||||
self.add(tool)
|
||||
|
||||
def add(self, tool: FunctionTool) -> None:
|
||||
"""Add a tool to the set."""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def add_tool(self, tool: FunctionTool) -> None:
|
||||
"""Add a tool to the set (alias for add())."""
|
||||
self.add(tool)
|
||||
|
||||
def remove(self, name: str) -> FunctionTool | None:
|
||||
"""Remove and return a tool by name."""
|
||||
return self._tools.pop(name, None)
|
||||
|
||||
def remove_tool(self, name: str) -> None:
|
||||
"""Remove a tool by its name."""
|
||||
self._tools.pop(name, None)
|
||||
|
||||
def get(self, name: str) -> FunctionTool | None:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def get_tool(self, name: str) -> FunctionTool | None:
|
||||
"""Get a tool by name (alias for get)."""
|
||||
return self.get(name)
|
||||
|
||||
def list_tools(self) -> list[FunctionTool]:
|
||||
"""List all tools in this set."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def __iter__(self) -> Iterator[FunctionTool]:
|
||||
return iter(self._tools.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._tools)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._tools)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ToolSet(namespace={self.namespace!r}, tools={self.list_tools()!r})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ToolSet({self.namespace}, {len(self)} tools)"
|
||||
|
||||
def names(self) -> list[str]:
|
||||
"""Get names of all tools in this set."""
|
||||
return [tool.name for tool in self.tools]
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the tool set is empty."""
|
||||
return len(self) == 0
|
||||
|
||||
def merge(self, other: "ToolSet") -> None:
|
||||
"""Merge another ToolSet into this one."""
|
||||
for tool in other.tools:
|
||||
self.add(tool)
|
||||
|
||||
def normalize(self) -> None:
|
||||
"""Sort tools by name for deterministic serialization."""
|
||||
self._tools = dict(sorted(self._tools.items(), key=lambda x: x[0]))
|
||||
|
||||
def get_light_tool_set(self) -> "ToolSet":
|
||||
"""Return a light tool set with only name/description."""
|
||||
light_tools = []
|
||||
for tool in self.tools:
|
||||
if hasattr(tool, "active") and not tool.active:
|
||||
continue
|
||||
light_tools.append(
|
||||
FunctionTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters={"type": "object", "properties": {}},
|
||||
handler=None,
|
||||
)
|
||||
)
|
||||
return ToolSet("default", light_tools)
|
||||
|
||||
def get_param_only_tool_set(self) -> "ToolSet":
|
||||
"""Return a tool set with name/parameters only (no description)."""
|
||||
param_tools = []
|
||||
for tool in self.tools:
|
||||
if hasattr(tool, "active") and not tool.active:
|
||||
continue
|
||||
params = (
|
||||
copy.deepcopy(tool.parameters)
|
||||
if tool.parameters
|
||||
else {"type": "object", "properties": {}}
|
||||
)
|
||||
param_tools.append(
|
||||
FunctionTool(
|
||||
name=tool.name,
|
||||
description="",
|
||||
parameters=params,
|
||||
handler=None,
|
||||
)
|
||||
)
|
||||
return ToolSet("default", param_tools)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[FunctionTool]:
|
||||
"""List all tools in this set."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def openai_schema(
|
||||
self, omit_empty_parameter_field: bool = False
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert tools to OpenAI API function calling schema format."""
|
||||
result: list[dict[str, Any]] = []
|
||||
for tool in self._tools.values():
|
||||
func_def: dict[str, Any] = {
|
||||
"type": "function",
|
||||
"function": {"name": tool.name},
|
||||
}
|
||||
if tool.description:
|
||||
func_def["function"]["description"] = tool.description
|
||||
|
||||
if tool.parameters is not None:
|
||||
if (
|
||||
tool.parameters.get("properties")
|
||||
) or not omit_empty_parameter_field:
|
||||
func_def["function"]["parameters"] = tool.parameters
|
||||
|
||||
result.append(func_def)
|
||||
return result
|
||||
|
||||
def anthropic_schema(self) -> list[dict]:
|
||||
"""Convert tools to Anthropic API format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
input_schema: dict[str, Any] = {"type": "object"}
|
||||
if tool.parameters:
|
||||
input_schema["properties"] = tool.parameters.get("properties", {})
|
||||
input_schema["required"] = tool.parameters.get("required", [])
|
||||
tool_def: dict[str, Any] = {"name": tool.name, "input_schema": input_schema}
|
||||
if tool.description:
|
||||
tool_def["description"] = tool.description
|
||||
result.append(tool_def)
|
||||
return result
|
||||
|
||||
def google_schema(self) -> dict:
|
||||
"""Convert tools to Google GenAI API format."""
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
origin_type = schema.get("type")
|
||||
target_type = origin_type
|
||||
|
||||
if isinstance(origin_type, list):
|
||||
target_type = next((t for t in origin_type if t != "null"), "string")
|
||||
|
||||
if target_type in supported_types:
|
||||
result["type"] = target_type
|
||||
if "format" in schema and schema["format"] in supported_formats.get(result["type"], set()):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
if "additionalProperties" in prop_value:
|
||||
del prop_value["additionalProperties"]
|
||||
properties[key] = prop_value
|
||||
if properties:
|
||||
result["properties"] = properties
|
||||
|
||||
if target_type == "array":
|
||||
items_schema = schema.get("items")
|
||||
if isinstance(items_schema, dict):
|
||||
result["items"] = convert_schema(items_schema)
|
||||
else:
|
||||
result["items"] = {"type": "string"}
|
||||
|
||||
return result
|
||||
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
d: dict[str, Any] = {"name": tool.name}
|
||||
if tool.description:
|
||||
d["description"] = tool.description
|
||||
if tool.parameters:
|
||||
d["parameters"] = convert_schema(tool.parameters)
|
||||
tools_list.append(d)
|
||||
|
||||
declarations: dict[str, Any] = {}
|
||||
if tools_list:
|
||||
declarations["function_declarations"] = tools_list
|
||||
return declarations
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
|
||||
"""Get tools in OpenAI function calling style (deprecated)."""
|
||||
return self.openai_schema(omit_empty_parameter_field)
|
||||
|
||||
def get_func_desc_anthropic_style(self):
|
||||
"""Get tools in Anthropic style (deprecated)."""
|
||||
return self.anthropic_schema()
|
||||
|
||||
def get_func_desc_google_genai_style(self):
|
||||
"""Get tools in Google GenAI style (deprecated)."""
|
||||
return self.google_schema()
|
||||
@@ -1,48 +0,0 @@
|
||||
"""
|
||||
Builtin tools for AstrBot - re-exports from core.tools for backward compatibility.
|
||||
|
||||
This module re-exports the builtin tools (cron, send_message, kb_query) from
|
||||
the deprecated core.tools module for backward compatibility.
|
||||
|
||||
TODO: These tools should be fully migrated to _internal and core.tools
|
||||
should be removed once all consumers update their imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Re-export cron tools
|
||||
from astrbot.core.tools.cron_tools import (
|
||||
CREATE_CRON_JOB_TOOL,
|
||||
DELETE_CRON_JOB_TOOL,
|
||||
LIST_CRON_JOBS_TOOL,
|
||||
CreateActiveCronTool,
|
||||
DeleteCronJobTool,
|
||||
ListCronJobsTool,
|
||||
)
|
||||
|
||||
# Re-export knowledge_base_query tool
|
||||
from astrbot.core.tools.kb_query import (
|
||||
KNOWLEDGE_BASE_QUERY_TOOL,
|
||||
KnowledgeBaseQueryTool,
|
||||
)
|
||||
|
||||
# Re-export send_message tool
|
||||
from astrbot.core.tools.send_message import (
|
||||
SEND_MESSAGE_TO_USER_TOOL,
|
||||
SendMessageToUserTool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Cron tools
|
||||
"CREATE_CRON_JOB_TOOL",
|
||||
"DELETE_CRON_JOB_TOOL",
|
||||
"KNOWLEDGE_BASE_QUERY_TOOL",
|
||||
"LIST_CRON_JOBS_TOOL",
|
||||
"SEND_MESSAGE_TO_USER_TOOL",
|
||||
# Classes
|
||||
"CreateActiveCronTool",
|
||||
"DeleteCronJobTool",
|
||||
"KnowledgeBaseQueryTool",
|
||||
"ListCronJobsTool",
|
||||
"SendMessageToUserTool",
|
||||
]
|
||||
@@ -1,278 +0,0 @@
|
||||
"""Tools registry for AstrBot internal runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Re-export from base
|
||||
from astrbot._internal.tools.base import FunctionTool, ToolSet
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_MCP_CONFIG",
|
||||
"ENABLE_MCP_TIMEOUT_ENV",
|
||||
"FuncCall",
|
||||
"FunctionTool",
|
||||
"FunctionToolManager",
|
||||
"MCPAllServicesFailedError",
|
||||
"MCPInitError",
|
||||
"MCPInitSummary",
|
||||
"MCPInitTimeoutError",
|
||||
"MCPShutdownTimeoutError",
|
||||
"ToolSet",
|
||||
]
|
||||
|
||||
|
||||
# MCP config constants (re-exported from protocols)
|
||||
try:
|
||||
from astrbot._internal.protocols.mcp import (
|
||||
DEFAULT_MCP_CONFIG,
|
||||
MCPAllServicesFailedError,
|
||||
MCPInitError,
|
||||
MCPInitSummary,
|
||||
MCPInitTimeoutError,
|
||||
MCPShutdownTimeoutError,
|
||||
)
|
||||
except ImportError:
|
||||
DEFAULT_MCP_CONFIG: dict[str, Any] = {}
|
||||
MCPAllServicesFailedError: type[Exception] = Exception
|
||||
MCPInitError: type[Exception] = Exception
|
||||
MCPInitSummary: type[dict] = dict
|
||||
MCPInitTimeoutError: type[TimeoutError] = TimeoutError
|
||||
MCPShutdownTimeoutError: type[TimeoutError] = TimeoutError
|
||||
|
||||
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_TIMEOUT_ENABLED"
|
||||
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
|
||||
|
||||
|
||||
class FunctionToolManager:
|
||||
"""Central registry for all function tools."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._func_list: list[FunctionTool] = []
|
||||
|
||||
@property
|
||||
def func_list(self) -> list[FunctionTool]:
|
||||
"""Get the list of function tools."""
|
||||
return self._func_list
|
||||
|
||||
@func_list.setter
|
||||
def func_list(self, value: list[FunctionTool]) -> None:
|
||||
"""Set the list of function tools."""
|
||||
self._func_list = value
|
||||
|
||||
def add(self, tool: FunctionTool) -> None:
|
||||
"""Add a tool to the registry."""
|
||||
self._func_list.append(tool)
|
||||
|
||||
def remove(self, name: str) -> bool:
|
||||
"""Remove a tool by name. Returns True if found."""
|
||||
for i, f in enumerate(self._func_list):
|
||||
if f.name == name:
|
||||
self._func_list.pop(i)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_func(self, name: str) -> FunctionTool | None:
|
||||
"""Get a tool by name. Returns the last active tool if multiple match."""
|
||||
last_match: FunctionTool | None = None
|
||||
for f in reversed(self._func_list):
|
||||
if f.name == name:
|
||||
if getattr(f, "active", True):
|
||||
return f
|
||||
if last_match is None:
|
||||
last_match = f
|
||||
return last_match
|
||||
|
||||
def get_full_tool_set(self) -> ToolSet:
|
||||
"""Return a ToolSet with all active tools, deduplicated by name."""
|
||||
seen: dict[str, FunctionTool] = {}
|
||||
for tool in reversed(self._func_list):
|
||||
if tool.name not in seen and getattr(tool, "active", True):
|
||||
seen[tool.name] = tool
|
||||
return ToolSet("default", list(seen.values()))
|
||||
|
||||
def register_internal_tools(self) -> None:
|
||||
"""Register built-in computer tools (shell, python, browser, neo)."""
|
||||
# Import here to avoid circular imports
|
||||
from astrbot.core.computer.computer_tool_provider import get_all_tools
|
||||
|
||||
for tool in get_all_tools():
|
||||
if self.get_func(tool.name) is None:
|
||||
self.add(tool)
|
||||
|
||||
# MCP-related stub methods for base class compatibility
|
||||
async def enable_mcp_server(
|
||||
self, name: str, config: dict[str, Any], init_timeout: int = 30
|
||||
) -> None:
|
||||
"""Enable an MCP server (stub)."""
|
||||
pass
|
||||
|
||||
async def disable_mcp_server(
|
||||
self, name: str = "", timeout: int = 10, shutdown_timeout: int = 10
|
||||
) -> None:
|
||||
"""Disable an MCP server (stub)."""
|
||||
pass
|
||||
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""Initialize MCP clients (stub)."""
|
||||
pass
|
||||
|
||||
async def test_mcp_server_connection(
|
||||
self, config: dict[str, Any]
|
||||
) -> tuple[bool, str]:
|
||||
"""Test MCP server connection (stub)."""
|
||||
return False, "Not implemented"
|
||||
|
||||
async def sync_modelscope_mcp_servers(self) -> None:
|
||||
"""Sync ModelScope MCP servers (stub)."""
|
||||
pass
|
||||
|
||||
def load_mcp_config(self) -> dict[str, Any]:
|
||||
"""Load MCP configuration (stub)."""
|
||||
return {"mcpServers": {}}
|
||||
|
||||
def save_mcp_config(self, config: dict[str, Any]) -> bool:
|
||||
"""Save MCP configuration (stub)."""
|
||||
return True
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
"""Activate an LLM tool (stub)."""
|
||||
return True
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""Deactivate an LLM tool (stub)."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def mcp_client_dict(self) -> dict[str, Any]:
|
||||
"""Return dict of MCP clients (stub)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def mcp_server_runtime_view(self) -> dict[str, Any]:
|
||||
"""Return runtime view of MCP servers (stub)."""
|
||||
return {}
|
||||
|
||||
|
||||
class FuncCall(FunctionToolManager):
|
||||
"""Alias for FunctionToolManager for backward compatibility."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._mcp_server_runtime_view: dict[str, Any] = {}
|
||||
self._mcp_client_dict: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def mcp_server_runtime_view(self) -> dict[str, Any]:
|
||||
"""Return runtime view of MCP servers."""
|
||||
return self._mcp_server_runtime_view
|
||||
|
||||
@property
|
||||
def mcp_client_dict(self) -> dict[str, Any]:
|
||||
"""Return dict of MCP clients (for backward compatibility)."""
|
||||
return self._mcp_client_dict
|
||||
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""Initialize MCP clients (stub implementation)."""
|
||||
pass
|
||||
|
||||
def add_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list[dict[str, Any]],
|
||||
desc: str,
|
||||
handler: Any,
|
||||
) -> None:
|
||||
"""Add a function tool (deprecated, use add() instead)."""
|
||||
params: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param.get("type", "string"),
|
||||
"description": param.get("description", ""),
|
||||
}
|
||||
func = FunctionTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
)
|
||||
self.add(func)
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
"""Remove a function tool by name (deprecated, use remove() instead)."""
|
||||
self.remove(name)
|
||||
|
||||
def get_func(self, name: str) -> FunctionTool | None:
|
||||
"""Get a function tool by name."""
|
||||
return super().get_func(name)
|
||||
|
||||
def names(self) -> list[str]:
|
||||
"""Get all tool names."""
|
||||
return [f.name for f in self.func_list]
|
||||
|
||||
def remove_tool(self, name: str) -> None:
|
||||
"""Remove a tool by its name (alias for remove)."""
|
||||
self.remove(name)
|
||||
|
||||
def get_func_desc_openai_style(
|
||||
self, omit_empty_parameter_field: bool = False
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get tools in OpenAI style (deprecated, use get_full_tool_set().openai_schema())."""
|
||||
tool_set = self.get_full_tool_set()
|
||||
return tool_set.openai_schema(omit_empty_parameter_field)
|
||||
|
||||
async def enable_mcp_server(
|
||||
self, name: str, config: dict[str, Any], init_timeout: int = 30
|
||||
) -> None:
|
||||
"""Enable an MCP server (stub implementation)."""
|
||||
pass
|
||||
|
||||
async def disable_mcp_server(
|
||||
self, name: str = "", timeout: int = 10, shutdown_timeout: int = 10
|
||||
) -> None:
|
||||
"""Disable an MCP server (stub implementation)."""
|
||||
pass
|
||||
|
||||
def load_mcp_config(self) -> dict[str, Any]:
|
||||
"""Load MCP configuration (stub implementation)."""
|
||||
return {"mcpServers": {}}
|
||||
|
||||
def save_mcp_config(self, config: dict[str, Any]) -> bool:
|
||||
"""Save MCP configuration (stub implementation)."""
|
||||
return True
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
"""Activate an LLM tool (stub implementation)."""
|
||||
return True
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""Deactivate an LLM tool (stub implementation)."""
|
||||
return True
|
||||
|
||||
async def test_mcp_server_connection(
|
||||
self, config: dict[str, Any]
|
||||
) -> tuple[bool, str]:
|
||||
"""Test MCP server connection (stub implementation)."""
|
||||
# Import the actual test function if available
|
||||
try:
|
||||
from astrbot._internal.protocols.mcp.client import (
|
||||
_quick_test_mcp_connection,
|
||||
)
|
||||
|
||||
success, message = await _quick_test_mcp_connection(config)
|
||||
if not success:
|
||||
raise Exception(message)
|
||||
return success, message
|
||||
except Exception as e:
|
||||
raise Exception(f"MCP connection test failed: {e!s}") from e
|
||||
|
||||
async def sync_modelscope_mcp_servers(self) -> None:
|
||||
"""Sync ModelScope MCP servers (stub implementation)."""
|
||||
pass
|
||||
|
||||
def get_full_tool_set(self) -> ToolSet:
|
||||
"""Return a ToolSet with all active tools."""
|
||||
return ToolSet("default", [t for t in self.func_list if t.active])
|
||||
@@ -1,64 +1,19 @@
|
||||
"""
|
||||
AstrBot Public API.
|
||||
|
||||
This package exposes the public interface for extending and integrating with
|
||||
AstrBot. All exports from this module are guaranteed to be stable across
|
||||
minor version updates.
|
||||
|
||||
Modules:
|
||||
tools: Tool registration and management API
|
||||
mcp: Model Context Protocol server and tool API
|
||||
skills: Skill management and conversion API
|
||||
"""
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
# Tool API
|
||||
from astrbot._internal.tools.base import FunctionTool, ToolSet
|
||||
|
||||
# MCP API
|
||||
from astrbot.api.mcp import (
|
||||
MCPClient,
|
||||
MCPTool,
|
||||
get_mcp_servers,
|
||||
register_mcp_server,
|
||||
unregister_mcp_server,
|
||||
)
|
||||
|
||||
# Skills API
|
||||
from astrbot.api.skills import (
|
||||
SkillInfo,
|
||||
SkillManager,
|
||||
get_skill_manager,
|
||||
skill_to_tool,
|
||||
)
|
||||
|
||||
# Tools API (public interface)
|
||||
from astrbot.api.tools import ToolRegistry, get_registry, tool
|
||||
from astrbot.core import html_renderer, sp
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.register import register_agent as agent
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
"BaseFunctionToolExecutor",
|
||||
"FunctionTool",
|
||||
"MCPClient",
|
||||
"MCPTool",
|
||||
"SkillInfo",
|
||||
"SkillManager",
|
||||
"ToolRegistry",
|
||||
"ToolSet",
|
||||
"agent",
|
||||
"get_mcp_servers",
|
||||
"get_registry",
|
||||
"get_skill_manager",
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
"logger",
|
||||
"register_mcp_server",
|
||||
"skill_to_tool",
|
||||
"sp",
|
||||
"tool",
|
||||
"unregister_mcp_server",
|
||||
]
|
||||
|
||||
@@ -29,7 +29,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
)
|
||||
from astrbot.core.star.register import (
|
||||
register_star as register, # 注册插件(Star)
|
||||
register_star as register, # 注册插件(Star)
|
||||
)
|
||||
from astrbot.core.star import Context, Star
|
||||
from astrbot.core.star.config import *
|
||||
|
||||
@@ -55,14 +55,14 @@ __all__ = [
|
||||
"on_decorating_result",
|
||||
"on_llm_request",
|
||||
"on_llm_response",
|
||||
"on_llm_tool_respond",
|
||||
"on_platform_loaded",
|
||||
"on_plugin_error",
|
||||
"on_plugin_loaded",
|
||||
"on_plugin_unloaded",
|
||||
"on_using_llm_tool",
|
||||
"on_platform_loaded",
|
||||
"on_waiting_llm_request",
|
||||
"permission_type",
|
||||
"platform_adapter_type",
|
||||
"regex",
|
||||
"on_using_llm_tool",
|
||||
"on_llm_tool_respond",
|
||||
]
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) Public API for AstrBot.
|
||||
|
||||
This module provides a simple, stable interface for MCP server management,
|
||||
delegating to the _internal package.
|
||||
|
||||
Example:
|
||||
from astrbot.api.mcp import get_mcp_servers, register_mcp_server
|
||||
|
||||
# List connected servers
|
||||
servers = get_mcp_servers()
|
||||
|
||||
# Register stdio MCP server
|
||||
await register_mcp_server(
|
||||
name="weather",
|
||||
command="uv",
|
||||
args=["tool", "run", "weather-mcp"],
|
||||
)
|
||||
|
||||
# Register SSE server
|
||||
await register_mcp_server(
|
||||
name="fileserver",
|
||||
url="http://localhost:8080/sse",
|
||||
transport="sse",
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Import from _internal package (the canonical source)
|
||||
# TODO: fix path - should be protocols.mcp.client
|
||||
from astrbot._internal.protocols.mcp.client import McpClient as MCPClient
|
||||
from astrbot._internal.protocols.mcp.tool import MCPTool
|
||||
|
||||
__all__ = [
|
||||
"MCPClient",
|
||||
"MCPTool",
|
||||
"get_mcp_servers",
|
||||
"register_mcp_server",
|
||||
"unregister_mcp_server",
|
||||
]
|
||||
|
||||
|
||||
def get_mcp_servers() -> dict[str, MCPClient]:
|
||||
"""Get all connected MCP servers."""
|
||||
from astrbot.core.provider.register import llm_tools as func_tool_manager
|
||||
|
||||
manager = func_tool_manager
|
||||
return dict(manager.mcp_client_dict)
|
||||
|
||||
|
||||
async def register_mcp_server(
|
||||
name: str,
|
||||
command: str | None = None,
|
||||
args: list[str] | None = None,
|
||||
url: str | None = None,
|
||||
transport: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Register and connect to an MCP server.
|
||||
|
||||
Args:
|
||||
name: Unique name for this server
|
||||
command: Command to run (for stdio transport)
|
||||
args: Command arguments
|
||||
url: URL (for SSE/Streamable HTTP transports)
|
||||
transport: "sse", "streamable_http", or None for stdio
|
||||
|
||||
Example - Stdio:
|
||||
await register_mcp_server(name="weather", command="uv",
|
||||
args=["tool", "run", "weather-mcp"])
|
||||
"""
|
||||
from astrbot.core.provider.register import llm_tools as func_tool_manager
|
||||
|
||||
manager = func_tool_manager
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
if command is not None:
|
||||
config["command"] = command
|
||||
if args is not None:
|
||||
config["args"] = args
|
||||
if url is not None:
|
||||
config["url"] = url
|
||||
if transport is not None:
|
||||
config["transport"] = transport
|
||||
config.update(kwargs)
|
||||
|
||||
await manager.enable_mcp_server(name=name, config=config)
|
||||
|
||||
|
||||
async def unregister_mcp_server(name: str) -> None:
|
||||
"""Disconnect and remove an MCP server."""
|
||||
from astrbot.core.provider.register import llm_tools as func_tool_manager
|
||||
|
||||
manager = func_tool_manager
|
||||
await manager.disable_mcp_server(name=name)
|
||||
@@ -1,58 +0,0 @@
|
||||
"""
|
||||
Skills Public API for AstrBot.
|
||||
|
||||
This module provides a simple, stable interface for skill management,
|
||||
delegating to the _internal package.
|
||||
|
||||
Two skill types:
|
||||
1. Prompt-based: SKILL.md files injected into system prompt
|
||||
2. Tool-based: Skills with input_schema converted to FunctionTool
|
||||
|
||||
Example:
|
||||
from astrbot.api.skills import get_skill_manager, skill_to_tool
|
||||
|
||||
# List skills
|
||||
mgr = get_skill_manager()
|
||||
skills = mgr.list_skills()
|
||||
|
||||
# Convert tool-based skill to FunctionTool
|
||||
tool_skills = [s for s in skills if s.input_schema]
|
||||
if tool_skills:
|
||||
func_tool = skill_to_tool(tool_skills[0])
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from astrbot._internal.tools.base import FunctionTool
|
||||
|
||||
# Import from _internal package (the canonical source)
|
||||
# TODO: fix path - should be core.skills.skill_manager
|
||||
from astrbot.core.skills.skill_manager import SkillInfo, SkillManager
|
||||
|
||||
__all__ = ["SkillInfo", "SkillManager", "get_skill_manager", "skill_to_tool"]
|
||||
|
||||
|
||||
def get_skill_manager() -> SkillManager:
|
||||
"""Get the global SkillManager instance."""
|
||||
return SkillManager()
|
||||
|
||||
|
||||
def skill_to_tool(skill: SkillInfo) -> FunctionTool | None:
|
||||
"""Convert a tool-based skill (with input_schema) to a FunctionTool.
|
||||
|
||||
Args:
|
||||
skill: A SkillInfo instance with an input_schema
|
||||
|
||||
Returns:
|
||||
A FunctionTool, or None if the skill has no input_schema
|
||||
"""
|
||||
if not skill.input_schema:
|
||||
return None
|
||||
|
||||
return FunctionTool(
|
||||
name=f"skill_{skill.name}",
|
||||
description=skill.description or f"Skill: {skill.name}",
|
||||
parameters=skill.input_schema,
|
||||
handler=None,
|
||||
source="skill",
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from astrbot.core.star import Context, Star, StarTools
|
||||
from astrbot.core.star.config import *
|
||||
from astrbot.core.star.register import (
|
||||
register_star as register, # 注册插件(Star)
|
||||
register_star as register, # 注册插件(Star)
|
||||
)
|
||||
|
||||
__all__ = ["Context", "Star", "StarTools", "register"]
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
"""
|
||||
Tools Public API for AstrBot.
|
||||
|
||||
This module provides a simple, stable interface for tool registration
|
||||
and management. All implementations are delegated to the _internal package.
|
||||
|
||||
Example:
|
||||
from astrbot.api.tools import tool, get_registry
|
||||
|
||||
@tool(name="weather", description="Get weather", parameters={...})
|
||||
async def get_weather(city: str) -> str:
|
||||
return f"Weather in {city} is sunny"
|
||||
|
||||
registry = get_registry()
|
||||
tools = registry.list_tools()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
# Import from _internal package (the canonical source)
|
||||
from astrbot._internal.tools.base import FunctionTool, ToolSet
|
||||
from astrbot._internal.tools.registry import FunctionToolManager
|
||||
|
||||
__all__ = ["FunctionTool", "ToolRegistry", "ToolSet", "get_registry", "tool"]
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Wrapper around FunctionToolManager for simplified tool registration.
|
||||
|
||||
This class provides a user-friendly interface for registering and
|
||||
managing tools, delegating to the internal FunctionToolManager.
|
||||
"""
|
||||
|
||||
_instance: ToolRegistry | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Import here to avoid circular imports
|
||||
from astrbot.core.provider.register import llm_tools as func_tool_manager
|
||||
|
||||
self._manager: FunctionToolManager = func_tool_manager
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> ToolRegistry:
|
||||
"""Get the singleton ToolRegistry instance."""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def register(self, tool: FunctionTool) -> None:
|
||||
"""Register a FunctionTool."""
|
||||
self._manager.func_list.append(tool)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Unregister a tool by name. Returns True if found and removed."""
|
||||
for i, f in enumerate(self._manager.func_list):
|
||||
if f.name == name:
|
||||
self._manager.func_list.pop(i)
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_tools(self) -> list[FunctionTool]:
|
||||
"""List all registered tools."""
|
||||
return self._manager.func_list.copy()
|
||||
|
||||
def get_tool(self, name: str) -> FunctionTool | None:
|
||||
"""Get a tool by name."""
|
||||
return self._manager.get_func(name)
|
||||
|
||||
|
||||
def get_registry() -> ToolRegistry:
|
||||
"""Get the global ToolRegistry instance."""
|
||||
return ToolRegistry.get_instance()
|
||||
|
||||
|
||||
def tool(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
) -> Callable[
|
||||
[Callable[..., Awaitable[str | None]]], Callable[..., Awaitable[str | None]]
|
||||
]:
|
||||
"""Decorator to register an async function as a tool.
|
||||
|
||||
Args:
|
||||
name: Tool name (used by LLM to invoke it)
|
||||
description: What the tool does
|
||||
parameters: JSON Schema for parameters (optional)
|
||||
|
||||
Example:
|
||||
@tool(name="weather", description="Get weather for a city", parameters={...})
|
||||
async def get_weather(city: str) -> str:
|
||||
return f"The weather in {city} is sunny"
|
||||
"""
|
||||
if parameters is None:
|
||||
parameters = {"type": "object", "properties": {}}
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Awaitable[str | None]],
|
||||
) -> Callable[..., Awaitable[str | None]]:
|
||||
func_tool = FunctionTool(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
handler=func,
|
||||
handler_module_path=getattr(func, "__module__", ""),
|
||||
source="api",
|
||||
)
|
||||
get_registry().register(func_tool)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> str | None:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -76,7 +76,7 @@ class LongTermMemory:
|
||||
if not provider:
|
||||
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
|
||||
if not isinstance(provider, Provider):
|
||||
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
|
||||
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
|
||||
response = await provider.text_chat(
|
||||
prompt=image_caption_prompt,
|
||||
session_id=uuid.uuid4().hex,
|
||||
@@ -149,7 +149,7 @@ class LongTermMemory:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
|
||||
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
|
||||
"""当触发 LLM 请求前,调用此方法修改 req"""
|
||||
"""当触发 LLM 请求前,调用此方法修改 req"""
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
@@ -164,7 +164,7 @@ class LongTermMemory:
|
||||
"Please react to it. Only output your response and do not output any other information. "
|
||||
"You MUST use the SAME language as the chatroom is using."
|
||||
)
|
||||
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
||||
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。
|
||||
else:
|
||||
req.system_prompt += (
|
||||
"You are now in a chatroom. The chat history is as follows: \n"
|
||||
|
||||
@@ -50,7 +50,7 @@ class Main(star.Star):
|
||||
"""主动回复"""
|
||||
provider = self.context.get_using_provider(event.unified_msg_origin)
|
||||
if not provider:
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
conv = None
|
||||
@@ -60,7 +60,7 @@ class Main(star.Star):
|
||||
|
||||
if not session_curr_cid:
|
||||
logger.error(
|
||||
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
|
||||
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
|
||||
)
|
||||
return
|
||||
|
||||
@@ -72,7 +72,7 @@ class Main(star.Star):
|
||||
prompt = event.message_str
|
||||
|
||||
if not conv:
|
||||
logger.error("未找到对话,无法主动回复")
|
||||
logger.error("未找到对话,无法主动回复")
|
||||
return
|
||||
|
||||
yield event.request_llm(
|
||||
@@ -88,7 +88,7 @@ class Main(star.Star):
|
||||
async def decorate_llm_req(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest
|
||||
) -> None:
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.on_req_llm(event, req)
|
||||
|
||||
@@ -9,56 +9,56 @@ class AdminCommands:
|
||||
self.context = context
|
||||
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
|
||||
"""授权管理员。op <admin_id>"""
|
||||
"""授权管理员。op <admin_id>"""
|
||||
if not admin_id:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /op <id> 授权管理员;/deop <id> 取消管理员。可通过 /sid 获取 ID。",
|
||||
"使用方法: /op <id> 授权管理员;/deop <id> 取消管理员。可通过 /sid 获取 ID。",
|
||||
),
|
||||
)
|
||||
return
|
||||
self.context.get_config()["admins_id"].append(str(admin_id))
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("授权成功。"))
|
||||
event.set_result(MessageEventResult().message("授权成功。"))
|
||||
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None:
|
||||
"""取消授权管理员。deop <admin_id>"""
|
||||
"""取消授权管理员。deop <admin_id>"""
|
||||
if not admin_id:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /deop <id> 取消管理员。可通过 /sid 获取 ID。",
|
||||
"使用方法: /deop <id> 取消管理员。可通过 /sid 获取 ID。",
|
||||
),
|
||||
)
|
||||
return
|
||||
try:
|
||||
self.context.get_config()["admins_id"].remove(str(admin_id))
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("取消授权成功。"))
|
||||
event.set_result(MessageEventResult().message("取消授权成功。"))
|
||||
except ValueError:
|
||||
event.set_result(
|
||||
MessageEventResult().message("此用户 ID 不在管理员名单内。"),
|
||||
MessageEventResult().message("此用户 ID 不在管理员名单内。"),
|
||||
)
|
||||
|
||||
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
|
||||
"""添加白名单。wl <sid>"""
|
||||
"""添加白名单。wl <sid>"""
|
||||
if not sid:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /wl <id> 添加白名单;/dwl <id> 删除白名单。可通过 /sid 获取 ID。",
|
||||
"使用方法: /wl <id> 添加白名单;/dwl <id> 删除白名单。可通过 /sid 获取 ID。",
|
||||
),
|
||||
)
|
||||
return
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
cfg["platform_settings"]["id_whitelist"].append(str(sid))
|
||||
cfg.save_config()
|
||||
event.set_result(MessageEventResult().message("添加白名单成功。"))
|
||||
event.set_result(MessageEventResult().message("添加白名单成功。"))
|
||||
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None:
|
||||
"""删除白名单。dwl <sid>"""
|
||||
"""删除白名单。dwl <sid>"""
|
||||
if not sid:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /dwl <id> 删除白名单。可通过 /sid 获取 ID。",
|
||||
"使用方法: /dwl <id> 删除白名单。可通过 /sid 获取 ID。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -66,12 +66,12 @@ class AdminCommands:
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
cfg["platform_settings"]["id_whitelist"].remove(str(sid))
|
||||
cfg.save_config()
|
||||
event.set_result(MessageEventResult().message("删除白名单成功。"))
|
||||
event.set_result(MessageEventResult().message("删除白名单成功。"))
|
||||
except ValueError:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
async def update_dashboard(self, event: AstrMessageEvent) -> None:
|
||||
"""更新管理面板"""
|
||||
await event.send(MessageChain().message("正在尝试更新管理面板..."))
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
await event.send(MessageChain().message("管理面板更新完成。"))
|
||||
await event.send(MessageChain().message("管理面板更新完成。"))
|
||||
|
||||
@@ -18,7 +18,7 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
"""更新reset命令在特定场景下的权限设置"""
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {}) or {}
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_cfg = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_cfg.get("reset", {})
|
||||
reset_cfg[scene_key] = perm_type
|
||||
@@ -31,7 +31,7 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
if token.len < 3:
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
"该指令用于设置指令或指令组的权限。\n"
|
||||
"该指令用于设置指令或指令组的权限。\n"
|
||||
"格式: /alter_cmd <cmd_name> <admin/member>\n"
|
||||
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
|
||||
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
|
||||
@@ -47,7 +47,7 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
if cmd_name == "reset" and cmd_type == "config":
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {}) or {}
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_.get("reset", {})
|
||||
|
||||
@@ -56,11 +56,11 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
private = reset_cfg.get("private", "member")
|
||||
|
||||
config_menu = f"""reset命令权限细粒度配置
|
||||
当前配置:
|
||||
当前配置:
|
||||
1. 群聊+会话隔离开: {group_unique_on}
|
||||
2. 群聊+会话隔离关: {group_unique_off}
|
||||
3. 私聊: {private}
|
||||
修改指令格式:
|
||||
修改指令格式:
|
||||
/alter_cmd reset scene <场景编号> <admin/member>
|
||||
例如: /alter_cmd reset scene 2 member"""
|
||||
await event.send(MessageChain().message(config_menu))
|
||||
@@ -82,7 +82,7 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
|
||||
if perm_type not in ["admin", "member"]:
|
||||
await event.send(
|
||||
MessageChain().message("权限类型错误,只能是 admin 或 member"),
|
||||
MessageChain().message("权限类型错误,只能是 admin 或 member"),
|
||||
)
|
||||
return
|
||||
|
||||
@@ -101,7 +101,7 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
|
||||
if cmd_type not in ["admin", "member"]:
|
||||
await event.send(
|
||||
MessageChain().message("指令类型错误,可选类型有 admin, member"),
|
||||
MessageChain().message("指令类型错误,可选类型有 admin, member"),
|
||||
)
|
||||
return
|
||||
|
||||
@@ -131,7 +131,7 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {}) or {}
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
|
||||
cfg = plugin_.get(found_command.handler_name, {})
|
||||
cfg["permission"] = cmd_type
|
||||
@@ -168,6 +168,6 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
cmd_group_str = "指令组" if cmd_group else "指令"
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。",
|
||||
f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ class ConversationCommands:
|
||||
|
||||
scene = RstScene.get_scene(is_group, is_unique_session)
|
||||
|
||||
alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {}) or {}
|
||||
alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {})
|
||||
plugin_config = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_config.get("reset", {})
|
||||
|
||||
@@ -60,8 +60,8 @@ class ConversationCommands:
|
||||
if required_perm == "admin" and message.role != "admin":
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"在{scene.name}场景下,reset命令需要管理员权限,"
|
||||
f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。",
|
||||
f"在{scene.name}场景下,reset命令需要管理员权限,"
|
||||
f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -74,12 +74,12 @@ class ConversationCommands:
|
||||
scope_id=umo,
|
||||
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
|
||||
)
|
||||
message.set_result(MessageEventResult().message("重置对话成功。"))
|
||||
message.set_result(MessageEventResult().message("重置对话成功。"))
|
||||
return
|
||||
|
||||
if not self.context.get_using_provider(umo):
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
)
|
||||
return
|
||||
|
||||
@@ -88,7 +88,7 @@ class ConversationCommands:
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前未处于对话状态,请 /switch 切换或者 /new 创建。",
|
||||
"当前未处于对话状态,请 /switch 切换或者 /new 创建。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -101,7 +101,7 @@ class ConversationCommands:
|
||||
[],
|
||||
)
|
||||
|
||||
ret = "清除聊天历史成功!"
|
||||
ret = "清除聊天历史成功!"
|
||||
|
||||
message.set_extra("_clean_ltm_session", True)
|
||||
|
||||
@@ -124,18 +124,18 @@ class ConversationCommands:
|
||||
if stopped_count > 0:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"已请求停止 {stopped_count} 个运行中的任务。"
|
||||
f"已请求停止 {stopped_count} 个运行中的任务。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
message.set_result(MessageEventResult().message("当前会话没有运行中的任务。"))
|
||||
message.set_result(MessageEventResult().message("当前会话没有运行中的任务。"))
|
||||
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
|
||||
"""查看对话记录"""
|
||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
)
|
||||
return
|
||||
|
||||
@@ -166,7 +166,7 @@ class ConversationCommands:
|
||||
|
||||
history = "".join(parts)
|
||||
ret = (
|
||||
f"当前对话历史记录:"
|
||||
f"当前对话历史记录:"
|
||||
f"{history or '无历史记录'}\n\n"
|
||||
f"第 {page} 页 | 共 {total_pages} 页\n"
|
||||
f"*输入 /history 2 跳转到第 2 页"
|
||||
@@ -181,7 +181,7 @@ class ConversationCommands:
|
||||
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。",
|
||||
f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -200,7 +200,7 @@ class ConversationCommands:
|
||||
end_idx = start_idx + size_per_page
|
||||
conversations_paged = conversations_all[start_idx:end_idx]
|
||||
|
||||
parts = ["对话列表:\n---\n"]
|
||||
parts = ["对话列表:\n---\n"]
|
||||
"""全局序号从当前页的第一个开始"""
|
||||
global_index = start_idx + 1
|
||||
|
||||
@@ -277,7 +277,7 @@ class ConversationCommands:
|
||||
scope_id=message.unified_msg_origin,
|
||||
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
|
||||
)
|
||||
message.set_result(MessageEventResult().message("已创建新对话。"))
|
||||
message.set_result(MessageEventResult().message("已创建新对话。"))
|
||||
return
|
||||
|
||||
active_event_registry.stop_all(message.unified_msg_origin, exclude=message)
|
||||
@@ -291,7 +291,7 @@ class ConversationCommands:
|
||||
message.set_extra("_clean_ltm_session", True)
|
||||
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
|
||||
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
|
||||
)
|
||||
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None:
|
||||
@@ -313,12 +313,12 @@ class ConversationCommands:
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。",
|
||||
f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。",
|
||||
),
|
||||
)
|
||||
else:
|
||||
message.set_result(
|
||||
MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。"),
|
||||
MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。"),
|
||||
)
|
||||
|
||||
async def switch_conv(
|
||||
@@ -329,14 +329,14 @@ class ConversationCommands:
|
||||
"""通过 /ls 前面的序号切换对话"""
|
||||
if not isinstance(index, int):
|
||||
message.set_result(
|
||||
MessageEventResult().message("类型错误,请输入数字对话序号。"),
|
||||
MessageEventResult().message("类型错误,请输入数字对话序号。"),
|
||||
)
|
||||
return
|
||||
|
||||
if index is None:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话",
|
||||
"请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -345,7 +345,7 @@ class ConversationCommands:
|
||||
)
|
||||
if index > len(conversations) or index < 1:
|
||||
message.set_result(
|
||||
MessageEventResult().message("对话序号错误,请使用 /ls 查看"),
|
||||
MessageEventResult().message("对话序号错误,请使用 /ls 查看"),
|
||||
)
|
||||
else:
|
||||
conversation = conversations[index - 1]
|
||||
@@ -356,20 +356,20 @@ class ConversationCommands:
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"切换到对话: {title}({conversation.cid[:4]})。",
|
||||
f"切换到对话: {title}({conversation.cid[:4]})。",
|
||||
),
|
||||
)
|
||||
|
||||
async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None:
|
||||
"""重命名对话"""
|
||||
if not new_name:
|
||||
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
|
||||
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
|
||||
return
|
||||
await self.context.conversation_manager.update_conversation_title(
|
||||
message.unified_msg_origin,
|
||||
new_name,
|
||||
)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
|
||||
async def del_conv(self, message: AstrMessageEvent) -> None:
|
||||
"""删除当前对话"""
|
||||
@@ -377,10 +377,10 @@ class ConversationCommands:
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
is_unique_session = cfg["platform_settings"]["unique_session"]
|
||||
if message.get_group_id() and not is_unique_session and message.role != "admin":
|
||||
# 群聊,没开独立会话,发送人不是管理员
|
||||
# 群聊,没开独立会话,发送人不是管理员
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。",
|
||||
f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -393,7 +393,7 @@ class ConversationCommands:
|
||||
scope_id=umo,
|
||||
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
|
||||
)
|
||||
message.set_result(MessageEventResult().message("重置对话成功。"))
|
||||
message.set_result(MessageEventResult().message("重置对话成功。"))
|
||||
return
|
||||
|
||||
session_curr_cid = (
|
||||
@@ -403,7 +403,7 @@ class ConversationCommands:
|
||||
if not session_curr_cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前未处于对话状态,请 /switch 序号 切换或 /new 创建。",
|
||||
"当前未处于对话状态,请 /switch 序号 切换或 /new 创建。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -415,6 +415,6 @@ class ConversationCommands:
|
||||
session_curr_cid,
|
||||
)
|
||||
|
||||
ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||
ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||
message.set_extra("_clean_ltm_session", True)
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
|
||||
@@ -24,7 +24,7 @@ class HelpCommand:
|
||||
|
||||
async def _build_reserved_command_lines(self) -> list[str]:
|
||||
"""
|
||||
使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。
|
||||
使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。
|
||||
"""
|
||||
try:
|
||||
commands = await command_management.list_commands()
|
||||
|
||||
@@ -17,4 +17,4 @@ class LLMCommands:
|
||||
cfg["provider_settings"]["enable"] = True
|
||||
status = "开启"
|
||||
cfg.save_config()
|
||||
await event.send(MessageChain().message(f"{status} LLM 聊天功能。"))
|
||||
await event.send(MessageChain().message(f"{status} LLM 聊天功能。"))
|
||||
|
||||
@@ -18,10 +18,10 @@ class PersonaCommands:
|
||||
all_personas: list["Persona"],
|
||||
depth: int = 0,
|
||||
) -> list[str]:
|
||||
"""递归构建树状输出,使用短线条表示层级"""
|
||||
"""递归构建树状输出,使用短线条表示层级"""
|
||||
lines: list[str] = []
|
||||
# 使用短线条作为缩进前缀,每层只用 "│" 加一个空格
|
||||
prefix = "│ " * depth
|
||||
# 使用短线条作为缩进前缀,每层只用 "│" 加一个空格
|
||||
prefix = "│ " * depth
|
||||
|
||||
for folder in folder_tree:
|
||||
# 输出文件夹
|
||||
@@ -31,7 +31,7 @@ class PersonaCommands:
|
||||
folder_personas = [
|
||||
p for p in all_personas if p.folder_id == folder["folder_id"]
|
||||
]
|
||||
child_prefix = "│ " * (depth + 1)
|
||||
child_prefix = "│ " * (depth + 1)
|
||||
|
||||
# 输出该文件夹下的人格
|
||||
for persona in folder_personas:
|
||||
@@ -71,7 +71,7 @@ class PersonaCommands:
|
||||
if conv is None:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前对话不存在,请先使用 /new 新建一个对话。",
|
||||
"当前对话不存在,请先使用 /new 新建一个对话。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -127,16 +127,16 @@ class PersonaCommands:
|
||||
folder_tree = await self.context.persona_manager.get_folder_tree()
|
||||
all_personas = self.context.persona_manager.personas
|
||||
|
||||
lines = ["📂 人格列表:\n"]
|
||||
lines = ["📂 人格列表:\n"]
|
||||
|
||||
# 构建树状输出
|
||||
tree_lines = self._build_tree_output(folder_tree, all_personas)
|
||||
lines.extend(tree_lines)
|
||||
|
||||
# 输出根目录下的人格(没有文件夹的)
|
||||
# 输出根目录下的人格(没有文件夹的)
|
||||
root_personas = [p for p in all_personas if p.folder_id is None]
|
||||
if root_personas:
|
||||
if tree_lines: # 如果有文件夹内容,加个空行
|
||||
if tree_lines: # 如果有文件夹内容,加个空行
|
||||
lines.append("")
|
||||
for persona in root_personas:
|
||||
lines.append(f"👤 {persona.persona_id}")
|
||||
@@ -161,7 +161,7 @@ class PersonaCommands:
|
||||
),
|
||||
None,
|
||||
):
|
||||
msg = f"人格{ps}的详细信息:\n"
|
||||
msg = f"人格{ps}的详细信息:\n"
|
||||
msg += f"{persona['prompt']}\n"
|
||||
else:
|
||||
msg = f"人格{ps}不存在"
|
||||
@@ -169,20 +169,20 @@ class PersonaCommands:
|
||||
elif parts[1] == "unset":
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message("当前没有对话,无法取消人格。"),
|
||||
MessageEventResult().message("当前没有对话,无法取消人格。"),
|
||||
)
|
||||
return
|
||||
await self.context.conversation_manager.update_conversation_persona_id(
|
||||
message.unified_msg_origin,
|
||||
"[%None]",
|
||||
)
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
else:
|
||||
ps = "".join(parts[1:]).strip()
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前没有对话,请先开始对话或使用 /new 创建一个对话。",
|
||||
"当前没有对话,请先开始对话或使用 /new 创建一个对话。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -199,16 +199,18 @@ class PersonaCommands:
|
||||
)
|
||||
force_warn_msg = ""
|
||||
if force_applied_persona_id:
|
||||
force_warn_msg = "提醒:由于自定义规则,您现在切换的人格将不会生效。"
|
||||
force_warn_msg = (
|
||||
"提醒:由于自定义规则,您现在切换的人格将不会生效。"
|
||||
)
|
||||
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。{force_warn_msg}",
|
||||
f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。{force_warn_msg}",
|
||||
),
|
||||
)
|
||||
else:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"不存在该人格情景。使用 /persona list 查看所有。",
|
||||
"不存在该人格情景。使用 /persona list 查看所有。",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -11,8 +11,8 @@ class PluginCommands:
|
||||
self.context = context
|
||||
|
||||
async def plugin_ls(self, event: AstrMessageEvent) -> None:
|
||||
"""获取已经安装的插件列表。"""
|
||||
parts = ["已加载的插件:\n"]
|
||||
"""获取已经安装的插件列表。"""
|
||||
parts = ["已加载的插件:\n"]
|
||||
for plugin in self.context.get_all_stars():
|
||||
line = f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
|
||||
if not plugin.activated:
|
||||
@@ -20,11 +20,11 @@ class PluginCommands:
|
||||
parts.append(line + "\n")
|
||||
|
||||
if len(parts) == 1:
|
||||
plugin_list_info = "没有加载任何插件。"
|
||||
plugin_list_info = "没有加载任何插件。"
|
||||
else:
|
||||
plugin_list_info = "".join(parts)
|
||||
|
||||
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
|
||||
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False),
|
||||
)
|
||||
@@ -32,51 +32,51 @@ class PluginCommands:
|
||||
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
|
||||
"""禁用插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
|
||||
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
|
||||
return
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin off <插件名> 禁用插件。"),
|
||||
MessageEventResult().message("/plugin off <插件名> 禁用插件。"),
|
||||
)
|
||||
return
|
||||
if self.context._star_manager is None:
|
||||
event.set_result(MessageEventResult().message("插件管理器未初始化。"))
|
||||
event.set_result(MessageEventResult().message("插件管理器未初始化。"))
|
||||
return
|
||||
await self.context._star_manager.turn_off_plugin(plugin_name)
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
|
||||
|
||||
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
|
||||
"""启用插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
|
||||
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
|
||||
return
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin on <插件名> 启用插件。"),
|
||||
MessageEventResult().message("/plugin on <插件名> 启用插件。"),
|
||||
)
|
||||
return
|
||||
if self.context._star_manager is None:
|
||||
event.set_result(MessageEventResult().message("插件管理器未初始化。"))
|
||||
event.set_result(MessageEventResult().message("插件管理器未初始化。"))
|
||||
return
|
||||
await self.context._star_manager.turn_on_plugin(plugin_name)
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
|
||||
|
||||
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
|
||||
"""安装插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
|
||||
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
|
||||
return
|
||||
if not plugin_repo:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin get <插件仓库地址> 安装插件"),
|
||||
)
|
||||
return
|
||||
logger.info(f"准备从 {plugin_repo} 安装插件。")
|
||||
logger.info(f"准备从 {plugin_repo} 安装插件。")
|
||||
if self.context._star_manager:
|
||||
star_mgr = self.context._star_manager
|
||||
try:
|
||||
await star_mgr.install_plugin(plugin_repo)
|
||||
event.set_result(MessageEventResult().message("安装插件成功。"))
|
||||
event.set_result(MessageEventResult().message("安装插件成功。"))
|
||||
except Exception as e:
|
||||
logger.error(f"安装插件失败: {e}")
|
||||
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
|
||||
@@ -86,12 +86,12 @@ class PluginCommands:
|
||||
"""获取插件帮助"""
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin help <插件名> 查看插件信息。"),
|
||||
MessageEventResult().message("/plugin help <插件名> 查看插件信息。"),
|
||||
)
|
||||
return
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if plugin is None:
|
||||
event.set_result(MessageEventResult().message("未找到此插件。"))
|
||||
event.set_result(MessageEventResult().message("未找到此插件。"))
|
||||
return
|
||||
help_msg = ""
|
||||
help_msg += f"\n\n✨ 作者: {plugin.author}\n✨ 版本: {plugin.version}"
|
||||
@@ -111,15 +111,15 @@ class PluginCommands:
|
||||
command_names.append(filter_.group_name)
|
||||
|
||||
if len(command_handlers) > 0:
|
||||
parts = ["\n\n🔧 指令列表:\n"]
|
||||
parts = ["\n\n🔧 指令列表:\n"]
|
||||
for i in range(len(command_handlers)):
|
||||
line = f"- {command_names[i]}"
|
||||
if command_handlers[i].desc:
|
||||
line += f": {command_handlers[i].desc}"
|
||||
parts.append(line + "\n")
|
||||
parts.append("\nTip: 指令的触发需要添加唤醒前缀,默认为 /。")
|
||||
parts.append("\nTip: 指令的触发需要添加唤醒前缀,默认为 /。")
|
||||
help_msg += "".join(parts)
|
||||
|
||||
ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg
|
||||
ret += "更多帮助信息请查看插件仓库 README。"
|
||||
ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg
|
||||
ret += "更多帮助信息请查看插件仓库 README。"
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@@ -127,7 +127,7 @@ class ProviderCommands:
|
||||
return self.context.get_config(umo).get("provider_settings", {}) or {}
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"读取 provider_settings 失败,使用默认值: %s",
|
||||
"读取 provider_settings 失败,使用默认值: %s",
|
||||
safe_error("", e),
|
||||
)
|
||||
return {}
|
||||
@@ -142,7 +142,7 @@ class ProviderCommands:
|
||||
return max(float(raw), 0.0)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"读取 %s 失败,回退默认值 %r: %s",
|
||||
"读取 %s 失败,回退默认值 %r: %s",
|
||||
MODEL_LIST_CACHE_TTL_KEY,
|
||||
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
|
||||
safe_error("", e),
|
||||
@@ -159,7 +159,7 @@ class ProviderCommands:
|
||||
value = int(raw)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"读取 %s 失败,回退默认值 %r: %s",
|
||||
"读取 %s 失败,回退默认值 %r: %s",
|
||||
MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
|
||||
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
|
||||
safe_error("", e),
|
||||
@@ -209,7 +209,7 @@ class ProviderCommands:
|
||||
) -> str:
|
||||
prov.set_model(model_name)
|
||||
self.invalidate_provider_models_cache(prov.meta().id, umo=umo)
|
||||
return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]"
|
||||
return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]"
|
||||
|
||||
async def _get_provider_models(
|
||||
self,
|
||||
@@ -265,7 +265,7 @@ class ProviderCommands:
|
||||
err_code: str,
|
||||
err_reason: str,
|
||||
) -> None:
|
||||
"""记录不可达原因到日志。"""
|
||||
"""记录不可达原因到日志。"""
|
||||
meta = provider.meta()
|
||||
logger.warning(
|
||||
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
|
||||
@@ -358,7 +358,7 @@ class ProviderCommands:
|
||||
provider_id for provider_id, _ in failed_provider_errors
|
||||
)
|
||||
logger.error(
|
||||
"跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络",
|
||||
"跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络",
|
||||
model_name,
|
||||
len(all_providers),
|
||||
failed_ids,
|
||||
@@ -405,7 +405,7 @@ class ProviderCommands:
|
||||
if all_providers:
|
||||
await event.send(
|
||||
MessageEventResult().message(
|
||||
"正在进行提供商可达性测试,请稍候..."
|
||||
"正在进行提供商可达性测试,请稍候..."
|
||||
)
|
||||
)
|
||||
check_results = await asyncio.gather(
|
||||
@@ -426,7 +426,7 @@ class ProviderCommands:
|
||||
if isinstance(reachable, asyncio.CancelledError):
|
||||
raise reachable
|
||||
if isinstance(reachable, Exception):
|
||||
# 异常情况下兜底处理,避免单个 provider 导致列表失败
|
||||
# 异常情况下兜底处理,避免单个 provider 导致列表失败
|
||||
self._log_reachability_failure(
|
||||
p,
|
||||
None,
|
||||
@@ -501,23 +501,23 @@ class ProviderCommands:
|
||||
line += " (当前使用)"
|
||||
parts.append(line + "\n")
|
||||
|
||||
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
|
||||
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
|
||||
ret = "".join(parts)
|
||||
|
||||
if ttss:
|
||||
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
|
||||
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
|
||||
if stts:
|
||||
ret += "\n使用 /provider stt <序号> 切换 STT 提供商。"
|
||||
ret += "\n使用 /provider stt <序号> 切换 STT 提供商。"
|
||||
if not reachability_check_enabled:
|
||||
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
|
||||
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
elif idx == "tts":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
@@ -526,13 +526,13 @@ class ProviderCommands:
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif idx == "stt":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
@@ -541,10 +541,10 @@ class ProviderCommands:
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
@@ -553,16 +553,16 @@ class ProviderCommands:
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
else:
|
||||
event.set_result(MessageEventResult().message("无效的参数。"))
|
||||
event.set_result(MessageEventResult().message("无效的参数。"))
|
||||
|
||||
async def _switch_model_by_name(
|
||||
self, message: AstrMessageEvent, model_name: str, prov: Provider
|
||||
) -> None:
|
||||
model_name = model_name.strip()
|
||||
if not model_name:
|
||||
message.set_result(MessageEventResult().message("模型名不能为空。"))
|
||||
message.set_result(MessageEventResult().message("模型名不能为空。"))
|
||||
return
|
||||
|
||||
umo = message.unified_msg_origin
|
||||
@@ -574,7 +574,7 @@ class ProviderCommands:
|
||||
prov,
|
||||
config,
|
||||
error_prefix="获取当前提供商模型列表失败: ",
|
||||
warning_log="获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s",
|
||||
warning_log="获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s",
|
||||
)
|
||||
if models is None:
|
||||
return
|
||||
@@ -597,7 +597,7 @@ class ProviderCommands:
|
||||
if target_prov is None or matched_target_model_name is None:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。",
|
||||
f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。",
|
||||
),
|
||||
)
|
||||
return
|
||||
@@ -612,7 +612,7 @@ class ProviderCommands:
|
||||
self._apply_model(target_prov, matched_target_model_name, umo=umo)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
|
||||
f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
|
||||
),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
@@ -633,7 +633,7 @@ class ProviderCommands:
|
||||
prov = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if not prov:
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
)
|
||||
return
|
||||
config = self._get_model_lookup_config(message.unified_msg_origin)
|
||||
@@ -655,7 +655,7 @@ class ProviderCommands:
|
||||
curr_model = prov.get_model() or "无"
|
||||
parts.append(f"\n当前模型: [{curr_model}]")
|
||||
parts.append(
|
||||
"\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换。"
|
||||
"\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换。"
|
||||
)
|
||||
|
||||
ret = "".join(parts)
|
||||
@@ -670,7 +670,7 @@ class ProviderCommands:
|
||||
if models is None:
|
||||
return
|
||||
if idx_or_name > len(models) or idx_or_name < 1:
|
||||
message.set_result(MessageEventResult().message("模型序号错误。"))
|
||||
message.set_result(MessageEventResult().message("模型序号错误。"))
|
||||
else:
|
||||
try:
|
||||
new_model = models[idx_or_name - 1]
|
||||
@@ -697,7 +697,7 @@ class ProviderCommands:
|
||||
prov = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if not prov:
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
|
||||
)
|
||||
return
|
||||
|
||||
@@ -710,14 +710,14 @@ class ProviderCommands:
|
||||
|
||||
parts.append(f"\n当前 Key: {curr_key[:8]}")
|
||||
parts.append("\n当前模型: " + prov.get_model())
|
||||
parts.append("\n使用 /key <idx> 切换 Key。")
|
||||
parts.append("\n使用 /key <idx> 切换 Key。")
|
||||
|
||||
ret = "".join(parts)
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
else:
|
||||
keys_data = prov.get_keys()
|
||||
if index > len(keys_data) or index < 1:
|
||||
message.set_result(MessageEventResult().message("Key 序号错误。"))
|
||||
message.set_result(MessageEventResult().message("Key 序号错误。"))
|
||||
else:
|
||||
try:
|
||||
new_key = keys_data[index - 1]
|
||||
@@ -726,7 +726,7 @@ class ProviderCommands:
|
||||
prov.meta().id,
|
||||
umo=message.unified_msg_origin,
|
||||
)
|
||||
message.set_result(MessageEventResult().message("切换 Key 成功。"))
|
||||
message.set_result(MessageEventResult().message("切换 Key 成功。"))
|
||||
except Exception as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
|
||||
@@ -9,28 +9,28 @@ class SetUnsetCommands:
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
|
||||
"""设置会话变量"""
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(uid, "session_variables", {}) or {}
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
session_var[key] = value
|
||||
await sp.session_put(uid, "session_variables", session_var)
|
||||
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。",
|
||||
f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。",
|
||||
),
|
||||
)
|
||||
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
|
||||
"""移除会话变量"""
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(uid, "session_variables", {}) or {}
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
|
||||
if key not in session_var:
|
||||
event.set_result(
|
||||
MessageEventResult().message("没有那个变量名。格式 /unset 变量名。"),
|
||||
MessageEventResult().message("没有那个变量名。格式 /unset 变量名。"),
|
||||
)
|
||||
else:
|
||||
del session_var[key]
|
||||
await sp.session_put(uid, "session_variables", session_var)
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功。"),
|
||||
MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功。"),
|
||||
)
|
||||
|
||||
@@ -18,19 +18,19 @@ class SIDCommand:
|
||||
umo_msg_type = event.session.message_type.value
|
||||
umo_session_id = event.session.session_id
|
||||
ret = (
|
||||
f"UMO: 「{sid}」 此值可用于设置白名单。\n"
|
||||
f"UID: 「{user_id}」 此值可用于设置管理员。\n"
|
||||
f"UMO: 「{sid}」 此值可用于设置白名单。\n"
|
||||
f"UID: 「{user_id}」 此值可用于设置管理员。\n"
|
||||
f"消息会话来源信息:\n"
|
||||
f" 机器人 ID: 「{umo_platform}」\n"
|
||||
f" 消息类型: 「{umo_msg_type}」\n"
|
||||
f" 会话 ID: 「{umo_session_id}」\n"
|
||||
f"消息来源可用于配置机器人的配置文件路由。"
|
||||
f" 机器人 ID: 「{umo_platform}」\n"
|
||||
f" 消息类型: 「{umo_msg_type}」\n"
|
||||
f" 会话 ID: 「{umo_session_id}」\n"
|
||||
f"消息来源可用于配置机器人的配置文件路由。"
|
||||
)
|
||||
|
||||
if (
|
||||
self.context.get_config()["platform_settings"]["unique_session"]
|
||||
and event.get_group_id()
|
||||
):
|
||||
ret += f"\n\n当前处于独立会话模式, 此群 ID: 「{event.get_group_id()}」, 也可将此 ID 加入白名单来放行整个群聊。"
|
||||
ret += f"\n\n当前处于独立会话模式, 此群 ID: 「{event.get_group_id()}」, 也可将此 ID 加入白名单来放行整个群聊。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@@ -16,8 +16,8 @@ class T2ICommand:
|
||||
if config["t2i"]:
|
||||
config["t2i"] = False
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已关闭文本转图片模式。"))
|
||||
event.set_result(MessageEventResult().message("已关闭文本转图片模式。"))
|
||||
return
|
||||
config["t2i"] = True
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))
|
||||
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))
|
||||
|
||||
@@ -12,7 +12,7 @@ class TTSCommand:
|
||||
self.context = context
|
||||
|
||||
async def tts(self, event: AstrMessageEvent) -> None:
|
||||
"""开关文本转语音(会话级别)"""
|
||||
"""开关文本转语音(会话级别)"""
|
||||
umo = event.unified_msg_origin
|
||||
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
@@ -27,10 +27,10 @@ class TTSCommand:
|
||||
if new_status and not tts_enable:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。",
|
||||
f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。",
|
||||
),
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{status_text}当前会话的文本转语音。"),
|
||||
MessageEventResult().message(f"{status_text}当前会话的文本转语音。"),
|
||||
)
|
||||
|
||||
@@ -51,7 +51,7 @@ class Main(star.Star):
|
||||
|
||||
@plugin.command("ls")
|
||||
async def plugin_ls(self, event: AstrMessageEvent) -> None:
|
||||
"""获取已经安装的插件列表。"""
|
||||
"""获取已经安装的插件列表。"""
|
||||
await self.plugin_c.plugin_ls(event)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@@ -84,7 +84,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command("tts")
|
||||
async def tts(self, event: AstrMessageEvent) -> None:
|
||||
"""开关文本转语音(会话级别)"""
|
||||
"""开关文本转语音(会话级别)"""
|
||||
await self.tts_c.tts(event)
|
||||
|
||||
@filter.command("sid")
|
||||
@@ -95,25 +95,25 @@ class Main(star.Star):
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("op")
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
|
||||
"""授权管理员。op <admin_id>"""
|
||||
"""授权管理员。op <admin_id>"""
|
||||
await self.admin_c.op(event, admin_id)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("deop")
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str) -> None:
|
||||
"""取消授权管理员。deop <admin_id>"""
|
||||
"""取消授权管理员。deop <admin_id>"""
|
||||
await self.admin_c.deop(event, admin_id)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("wl")
|
||||
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
|
||||
"""添加白名单。wl <sid>"""
|
||||
"""添加白名单。wl <sid>"""
|
||||
await self.admin_c.wl(event, sid)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dwl")
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str) -> None:
|
||||
"""删除白名单。dwl <sid>"""
|
||||
"""删除白名单。dwl <sid>"""
|
||||
await self.admin_c.dwl(event, sid)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
|
||||
@@ -72,9 +72,9 @@ class Main(Star):
|
||||
# 使用 LLM 生成回复
|
||||
yield event.request_llm(
|
||||
prompt=(
|
||||
"注意,你正在社交媒体上中与用户进行聊天,用户只是通过@来唤醒你,但并未在这条消息中输入内容,他可能会在接下来一条发送他想发送的内容。"
|
||||
"你友好地询问用户想要聊些什么或者需要什么帮助,回复要符合人设,不要太过机械化。"
|
||||
"请注意,你仅需要输出要回复用户的内容,不要输出其他任何东西"
|
||||
"注意,你正在社交媒体上中与用户进行聊天,用户只是通过@来唤醒你,但并未在这条消息中输入内容,他可能会在接下来一条发送他想发送的内容。"
|
||||
"你友好地询问用户想要聊些什么或者需要什么帮助,回复要符合人设,不要太过机械化。"
|
||||
"请注意,你仅需要输出要回复用户的内容,不要输出其他任何东西"
|
||||
),
|
||||
session_id=curr_cid,
|
||||
contexts=[],
|
||||
@@ -83,8 +83,8 @@ class Main(Star):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM response failed: {e!s}")
|
||||
# LLM 回复失败,使用原始预设回复
|
||||
yield event.plain_result("想要问什么呢?😄")
|
||||
# LLM 回复失败,使用原始预设回复
|
||||
yield event.plain_result("想要问什么呢?😄")
|
||||
|
||||
@session_waiter(60)
|
||||
async def empty_mention_waiter(
|
||||
@@ -106,7 +106,7 @@ class Main(Star):
|
||||
except TimeoutError as _:
|
||||
pass
|
||||
except Exception as e:
|
||||
yield event.plain_result("发生错误,请联系管理员: " + str(e))
|
||||
yield event.plain_result("发生错误,请联系管理员: " + str(e))
|
||||
finally:
|
||||
event.stop_event()
|
||||
except Exception as e:
|
||||
|
||||
@@ -81,7 +81,7 @@ class SearchEngine:
|
||||
return ret
|
||||
|
||||
def tidy_text(self, text: str) -> str:
|
||||
"""清理文本,去除空格、换行符等"""
|
||||
"""清理文本,去除空格、换行符等"""
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
def _get_url(self, tag: Tag) -> str:
|
||||
|
||||
@@ -34,14 +34,14 @@ class Main(star.Star):
|
||||
self.bocha_key_index = 0
|
||||
self.bocha_key_lock = asyncio.Lock()
|
||||
|
||||
# 将 str 类型的 key 迁移至 list[str],并保存
|
||||
# 将 str 类型的 key 迁移至 list[str],并保存
|
||||
cfg = self.context.get_config()
|
||||
provider_settings = cfg.get("provider_settings")
|
||||
if provider_settings:
|
||||
tavily_key = provider_settings.get("websearch_tavily_key")
|
||||
if isinstance(tavily_key, str):
|
||||
logger.info(
|
||||
"检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存。",
|
||||
"检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存。",
|
||||
)
|
||||
if tavily_key:
|
||||
provider_settings["websearch_tavily_key"] = [tavily_key]
|
||||
@@ -62,7 +62,7 @@ class Main(star.Star):
|
||||
self.baidu_initialized = False
|
||||
|
||||
async def _tidy_text(self, text: str) -> str:
|
||||
"""清理文本,去除空格、换行符等"""
|
||||
"""清理文本,去除空格、换行符等"""
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
async def _get_from_url(self, url: str) -> str:
|
||||
@@ -124,10 +124,10 @@ class Main(star.Star):
|
||||
return results
|
||||
|
||||
async def _get_tavily_key(self, cfg: AstrBotConfig) -> str:
|
||||
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
|
||||
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
|
||||
tavily_keys = cfg.get("provider_settings", {}).get("websearch_tavily_key", [])
|
||||
if not tavily_keys:
|
||||
raise ValueError("错误:Tavily API密钥未在AstrBot中配置。")
|
||||
raise ValueError("错误:Tavily API密钥未在AstrBot中配置。")
|
||||
|
||||
async with self.tavily_key_lock:
|
||||
key = tavily_keys[self.tavily_key_index]
|
||||
@@ -203,11 +203,11 @@ class Main(star.Star):
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
) -> str:
|
||||
"""搜索网络以回答用户的问题。当用户需要搜索网络以获取即时性的信息时调用此工具。
|
||||
"""搜索网络以回答用户的问题。当用户需要搜索网络以获取即时性的信息时调用此工具。
|
||||
|
||||
Args:
|
||||
query(string): 和用户的问题最相关的搜索关键词,用于在 Google 上搜索。
|
||||
max_results(number): 返回的最大搜索结果数量,默认为 5。
|
||||
query(string): 和用户的问题最相关的搜索关键词,用于在 Google 上搜索。
|
||||
max_results(number): 返回的最大搜索结果数量,默认为 5。
|
||||
|
||||
"""
|
||||
logger.info(f"web_searcher - search_from_search_engine: {query}")
|
||||
@@ -231,7 +231,7 @@ class Main(star.Star):
|
||||
ret += processed_result
|
||||
|
||||
if websearch_link:
|
||||
ret += "\n\n针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
|
||||
ret += "\n\n针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
|
||||
|
||||
return ret
|
||||
|
||||
@@ -384,10 +384,10 @@ class Main(star.Star):
|
||||
return ret
|
||||
|
||||
async def _get_bocha_key(self, cfg: AstrBotConfig) -> str:
|
||||
"""并发安全的从列表中获取并轮换BoCha API密钥。"""
|
||||
"""并发安全的从列表中获取并轮换BoCha API密钥。"""
|
||||
bocha_keys = cfg.get("provider_settings", {}).get("websearch_bocha_key", [])
|
||||
if not bocha_keys:
|
||||
raise ValueError("错误:BoCha API密钥未在AstrBot中配置。")
|
||||
raise ValueError("错误:BoCha API密钥未在AstrBot中配置。")
|
||||
|
||||
async with self.bocha_key_lock:
|
||||
key = bocha_keys[self.bocha_key_index]
|
||||
@@ -500,18 +500,18 @@ class Main(star.Star):
|
||||
"count": count,
|
||||
}
|
||||
|
||||
# freshness:时间范围
|
||||
# freshness:时间范围
|
||||
if freshness:
|
||||
payload["freshness"] = freshness
|
||||
|
||||
# 是否返回摘要
|
||||
payload["summary"] = summary
|
||||
|
||||
# include:限制搜索域
|
||||
# include:限制搜索域
|
||||
if include:
|
||||
payload["include"] = include
|
||||
|
||||
# exclude:排除搜索域
|
||||
# exclude:排除搜索域
|
||||
if exclude:
|
||||
payload["exclude"] = exclude
|
||||
|
||||
@@ -567,9 +567,9 @@ class Main(star.Star):
|
||||
if provider == "default":
|
||||
web_search_t = func_tool_mgr.get_func("web_search")
|
||||
fetch_url_t = func_tool_mgr.get_func("fetch_url")
|
||||
if web_search_t and web_search_t.active:
|
||||
if web_search_t:
|
||||
tool_set.add_tool(web_search_t)
|
||||
if fetch_url_t and fetch_url_t.active:
|
||||
if fetch_url_t:
|
||||
tool_set.add_tool(fetch_url_t)
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
@@ -578,9 +578,9 @@ class Main(star.Star):
|
||||
elif provider == "tavily":
|
||||
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
|
||||
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
|
||||
if web_search_tavily and web_search_tavily.active:
|
||||
if web_search_tavily:
|
||||
tool_set.add_tool(web_search_tavily)
|
||||
if tavily_extract_web_page and tavily_extract_web_page.active:
|
||||
if tavily_extract_web_page:
|
||||
tool_set.add_tool(tavily_extract_web_page)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
@@ -590,8 +590,9 @@ class Main(star.Star):
|
||||
try:
|
||||
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
|
||||
aisearch_tool = func_tool_mgr.get_func("AIsearch")
|
||||
if aisearch_tool and aisearch_tool.active:
|
||||
tool_set.add_tool(aisearch_tool)
|
||||
if not aisearch_tool:
|
||||
raise ValueError("Cannot get Baidu AI Search MCP tool.")
|
||||
tool_set.add_tool(aisearch_tool)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
@@ -601,7 +602,7 @@ class Main(star.Star):
|
||||
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")
|
||||
elif provider == "bocha":
|
||||
web_search_bocha = func_tool_mgr.get_func("web_search_bocha")
|
||||
if web_search_bocha and web_search_bocha.active:
|
||||
if web_search_bocha:
|
||||
tool_set.add_tool(web_search_bocha)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
|
||||
@@ -7,8 +7,7 @@ import click
|
||||
from click.shell_completion import get_completion_class
|
||||
|
||||
from . import __version__
|
||||
from .commands import bk, conf, init, plug, run, tui, uninstall
|
||||
from .i18n import t
|
||||
from .commands import bk, conf, init, plug, run, uninstall
|
||||
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
@@ -23,12 +22,10 @@ logo_tmpl = r"""
|
||||
@click.group()
|
||||
@click.version_option(__version__, prog_name="AstrBot")
|
||||
def cli() -> None:
|
||||
"""Astrbot
|
||||
Agentic IM Chatbot infrastructure that integrates lots of IM platforms, LLMs, plugins and AI feature, and can be your openclaw alternative. ✨
|
||||
"""
|
||||
"""The AstrBot CLI"""
|
||||
click.echo(logo_tmpl)
|
||||
click.echo(t("cli_welcome"))
|
||||
click.echo(t("cli_version", version=__version__))
|
||||
click.echo("Welcome to AstrBot CLI!")
|
||||
click.echo(f"AstrBot CLI version: {__version__}")
|
||||
|
||||
|
||||
@click.command()
|
||||
@@ -71,7 +68,7 @@ def help(command_name: str | None, all: bool) -> None:
|
||||
cmd_ctx = click.Context(command, info_name=command.name, parent=parent)
|
||||
click.echo(command.get_help(cmd_ctx))
|
||||
else:
|
||||
click.echo(t("cli_unknown_command", command=command_name))
|
||||
click.echo(f"Unknown command: {command_name}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Display general help information
|
||||
@@ -85,7 +82,6 @@ cli.add_command(plug)
|
||||
cli.add_command(conf)
|
||||
cli.add_command(uninstall)
|
||||
cli.add_command(bk)
|
||||
cli.add_command(tui)
|
||||
|
||||
|
||||
@click.command()
|
||||
@@ -108,9 +104,6 @@ def completion(shell: str | None) -> None:
|
||||
sys.exit(1)
|
||||
|
||||
comp_cls = get_completion_class(shell)
|
||||
if comp_cls is None:
|
||||
click.echo(f"No completion support for shell: {shell}", err=True)
|
||||
sys.exit(1)
|
||||
comp = comp_cls(
|
||||
cli, ctx_args={}, prog_name="astrbot", complete_var="_ASTRBOT_COMPLETE"
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ from .cmd_conf import conf
|
||||
from .cmd_init import init
|
||||
from .cmd_plug import plug
|
||||
from .cmd_run import run
|
||||
from .cmd_tui import tui
|
||||
from .cmd_uninstall import uninstall
|
||||
|
||||
__all__ = ["bk", "conf", "init", "plug", "run", "tui", "uninstall"]
|
||||
__all__ = ["conf", "init", "plug", "run", "uninstall", "bk"]
|
||||
|
||||
@@ -7,41 +7,36 @@ from pathlib import Path
|
||||
import anyio
|
||||
import click
|
||||
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core import astrbot_config, db_helper
|
||||
from astrbot.core.backup import AstrBotExporter, AstrBotImporter
|
||||
|
||||
# Try importing KnowledgeBaseManager to support KB backup
|
||||
try:
|
||||
from astrbot.core.knowledge.kb_manager import KnowledgeBaseManager
|
||||
except ImportError:
|
||||
try:
|
||||
from astrbot.core.knowledge_base.kb_manager import KnowledgeBaseManager
|
||||
except ImportError:
|
||||
KnowledgeBaseManager = None
|
||||
|
||||
|
||||
async def _get_kb_manager():
|
||||
"""Initialize and return a KnowledgeBaseManager with full dependency chain."""
|
||||
from astrbot.core import astrbot_config, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
if KnowledgeBaseManager is None:
|
||||
return None
|
||||
|
||||
ucr = UmopConfigRouter(sp=sp)
|
||||
await ucr.initialize()
|
||||
|
||||
acm = AstrBotConfigManager(
|
||||
default_config=astrbot_config,
|
||||
ucr=ucr,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
persona_mgr = PersonaManager(db_helper, acm)
|
||||
await persona_mgr.initialize()
|
||||
|
||||
provider_manager = ProviderManager(
|
||||
acm,
|
||||
db_helper,
|
||||
persona_mgr,
|
||||
)
|
||||
|
||||
kb_manager = KnowledgeBaseManager(provider_manager)
|
||||
await kb_manager.initialize()
|
||||
|
||||
return kb_manager
|
||||
try:
|
||||
# Best effort initialization
|
||||
kb_mgr = KnowledgeBaseManager(astrbot_config, db_helper)
|
||||
# If there are async load methods, we might need to call them
|
||||
if hasattr(kb_mgr, "load_kbs_from_db"):
|
||||
await kb_mgr.load_kbs_from_db()
|
||||
elif hasattr(kb_mgr, "load_all"):
|
||||
await kb_mgr.load_all()
|
||||
return kb_mgr
|
||||
except Exception:
|
||||
# If KB manager fails to load (e.g. missing dependencies), return None
|
||||
# so we can still backup other data
|
||||
return None
|
||||
|
||||
|
||||
@click.group(name="bk")
|
||||
@@ -143,7 +138,8 @@ def export_data(
|
||||
"GPG tool not found. Please install GnuPG to use encryption/signing features."
|
||||
)
|
||||
|
||||
exporter = AstrBotExporter(db_helper)
|
||||
kb_mgr = await _get_kb_manager()
|
||||
exporter = AstrBotExporter(db_helper, kb_mgr)
|
||||
|
||||
async def on_progress(stage, current, total, message):
|
||||
click.echo(f"[{stage}] {message}")
|
||||
|
||||
@@ -1,172 +1,95 @@
|
||||
"""
|
||||
Configuration CLI for AstrBot.
|
||||
|
||||
This module provides:
|
||||
- secure hashing utilities for the dashboard password (argon2)
|
||||
- validators for commonly configurable items
|
||||
- click CLI group with `set`, `get`, and `password` subcommands
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
import zoneinfo
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import argon2.exceptions as argon2_exceptions
|
||||
import click
|
||||
from argon2 import PasswordHasher
|
||||
|
||||
from astrbot.cli.i18n import t
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
_PASSWORD_HASHER = PasswordHasher()
|
||||
from ..utils import check_astrbot_root
|
||||
|
||||
DEFAULT_DASHBOARD_PASSWORD = "astrbot"
|
||||
DEFAULT_DASHBOARD_PASSWORD_MD5 = hashlib.md5(
|
||||
DEFAULT_DASHBOARD_PASSWORD.encode()
|
||||
).hexdigest()
|
||||
DEFAULT_DASHBOARD_PASSWORD_SHA256 = hashlib.sha256(
|
||||
DEFAULT_DASHBOARD_PASSWORD.encode()
|
||||
).hexdigest()
|
||||
|
||||
|
||||
PBKDF2_SALT = b"astrbot-dashboard"
|
||||
PBKDF2_ITER = 200_000
|
||||
def hash_dashboard_password(value: str) -> str:
|
||||
"""Hash Dashboard password for storage."""
|
||||
return hashlib.sha256(value.encode()).hexdigest()
|
||||
|
||||
|
||||
# --- Password hashing & validation utilities ---
|
||||
def hash_dashboard_password_md5(value: str) -> str:
|
||||
"""Hash Dashboard password with the legacy MD5 algorithm."""
|
||||
return hashlib.md5(value.encode()).hexdigest()
|
||||
|
||||
|
||||
def hash_dashboard_password_secure(value: str) -> str:
|
||||
"""
|
||||
Hash the dashboard password for storage.
|
||||
|
||||
Stored format:
|
||||
$argon2id$... (if Argon2 available) or pbkdf2_sha256 fallback.
|
||||
"""
|
||||
if _PASSWORD_HASHER is not None:
|
||||
try:
|
||||
return _PASSWORD_HASHER.hash(value)
|
||||
except Exception as e:
|
||||
raise click.ClickException(
|
||||
f"Failed to hash password securely (argon2): {e!s}"
|
||||
)
|
||||
|
||||
dk = hashlib.pbkdf2_hmac("sha256", value.encode("utf-8"), PBKDF2_SALT, PBKDF2_ITER)
|
||||
return f"pbkdf2_sha256${PBKDF2_ITER}${binascii.hexlify(PBKDF2_SALT).decode()}${dk.hex()}"
|
||||
|
||||
|
||||
def verify_dashboard_password(value: str, stored_hash: str) -> bool:
|
||||
"""
|
||||
Verify a plaintext password `value` against a stored hash.
|
||||
|
||||
Supported format:
|
||||
- Argon2 encoded string: $argon2id$...
|
||||
- PBKDF2 encoded string: pbkdf2_sha256$...
|
||||
- Legacy SHA-256 (64 hex chars) and MD5 (32 hex chars) for backward compatibility.
|
||||
"""
|
||||
if not stored_hash:
|
||||
return False
|
||||
|
||||
if stored_hash.startswith("$argon2"):
|
||||
try:
|
||||
return _PASSWORD_HASHER.verify(stored_hash, value)
|
||||
except argon2_exceptions.VerifyMismatchError:
|
||||
return False
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"Password verification failure (argon2): {e!s}")
|
||||
|
||||
if stored_hash.startswith("pbkdf2_sha256$"):
|
||||
try:
|
||||
_, iters_s, salt_hex, digest_hex = stored_hash.split("$", 3)
|
||||
iters = int(iters_s)
|
||||
salt = binascii.unhexlify(salt_hex)
|
||||
expected = digest_hex.lower()
|
||||
dk = hashlib.pbkdf2_hmac("sha256", value.encode("utf-8"), salt, iters)
|
||||
return dk.hex() == expected
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Legacy plain hex digests: SHA-256 (64 hex chars) and MD5 (32 hex chars).
|
||||
value_l = value.encode("utf-8")
|
||||
s = stored_hash.lower()
|
||||
if len(s) == 64 and all(ch in "0123456789abcdef" for ch in s):
|
||||
return hashlib.sha256(value_l).hexdigest() == s
|
||||
if len(s) == 32 and all(ch in "0123456789abcdef" for ch in s):
|
||||
return hashlib.md5(value_l).hexdigest() == s
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_dashboard_password_hash(value: str) -> bool:
|
||||
"""
|
||||
Heuristic: return True if `value` looks like a supported dashboard password hash.
|
||||
"""
|
||||
if not isinstance(value, str) or not value:
|
||||
return False
|
||||
return value.startswith("$argon2") or value.startswith("pbkdf2_sha256$")
|
||||
|
||||
|
||||
def is_legacy_dashboard_password_hash(value: str) -> bool:
|
||||
"""
|
||||
Heuristic: return True if `value` looks like a legacy password hash format.
|
||||
Legacy formats are plain SHA-256 (64 hex chars) or MD5 (32 hex chars) digests.
|
||||
"""
|
||||
if not isinstance(value, str) or not value:
|
||||
return False
|
||||
# Legacy plain hex digests: SHA-256 (64 hex chars) or MD5 (32 hex chars)
|
||||
if len(value) == 64 and all(ch in "0123456789abcdef" for ch in value.lower()):
|
||||
return True
|
||||
if len(value) == 32 and all(ch in "0123456789abcdef" for ch in value.lower()):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# --- Validators for CLI configuration items ---
|
||||
def is_dashboard_password_hash(value: str, *, algorithm: str) -> bool:
|
||||
expected_len = 64 if algorithm == "sha256" else 32
|
||||
return len(value) == expected_len and all(ch in "0123456789abcdef" for ch in value)
|
||||
|
||||
|
||||
def _validate_log_level(value: str) -> str:
|
||||
value_up = value.upper()
|
||||
allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
if value_up not in allowed:
|
||||
raise click.ClickException(t("config_log_level_invalid"))
|
||||
return value_up
|
||||
"""Validate log level"""
|
||||
value = value.upper()
|
||||
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
raise click.ClickException(
|
||||
"Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_port(value: str) -> int:
|
||||
"""Validate Dashboard port"""
|
||||
try:
|
||||
port = int(value)
|
||||
if port < 1 or port > 65535:
|
||||
raise click.ClickException("Port must be in range 1-65535")
|
||||
return port
|
||||
except ValueError:
|
||||
raise click.ClickException(t("config_port_must_be_number"))
|
||||
if port < 1 or port > 65535:
|
||||
raise click.ClickException(t("config_port_range_invalid"))
|
||||
return port
|
||||
raise click.ClickException("Port must be a number")
|
||||
|
||||
|
||||
def _validate_dashboard_username(value: str) -> str:
|
||||
if value is None or value.strip() == "":
|
||||
raise click.ClickException(t("config_username_empty"))
|
||||
return value.strip()
|
||||
"""Validate Dashboard username"""
|
||||
if not value:
|
||||
raise click.ClickException("Username cannot be empty")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_password(value: str) -> str:
|
||||
if value is None or value == "":
|
||||
raise click.ClickException(t("config_password_empty"))
|
||||
# Return the canonical stored representation.
|
||||
return hash_dashboard_password_secure(value)
|
||||
"""Validate Dashboard password"""
|
||||
if not value:
|
||||
raise click.ClickException("Password cannot be empty")
|
||||
return hash_dashboard_password(value)
|
||||
|
||||
|
||||
def _validate_timezone(value: str) -> str:
|
||||
"""Validate timezone"""
|
||||
try:
|
||||
zoneinfo.ZoneInfo(value)
|
||||
except Exception:
|
||||
raise click.ClickException(t("config_timezone_invalid", value=value))
|
||||
raise click.ClickException(
|
||||
f"Invalid timezone: {value}. Please use a valid IANA timezone name"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _validate_callback_api_base(value: str) -> str:
|
||||
if not (value.startswith("http://") or value.startswith("https://")):
|
||||
raise click.ClickException(t("config_callback_invalid"))
|
||||
"""Validate callback API base URL"""
|
||||
if not value.startswith("http://") and not value.startswith("https://"):
|
||||
raise click.ClickException(
|
||||
"Callback API base must start with http:// or https://"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
# Configuration items settable via CLI, mapping config keys to validator functions
|
||||
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||
"timezone": _validate_timezone,
|
||||
"log_level": _validate_log_level,
|
||||
@@ -177,23 +100,18 @@ CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||
}
|
||||
|
||||
|
||||
# --- Config file helpers ---
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
"""
|
||||
Load or initialize the CLI config file (data/cmd_config.json).
|
||||
Ensures the astrbot root is valid before proceeding.
|
||||
"""
|
||||
"""Load or initialize config file"""
|
||||
root = astrbot_paths.root
|
||||
if not astrbot_paths.is_root:
|
||||
if not check_astrbot_root(root):
|
||||
raise click.ClickException(
|
||||
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize"
|
||||
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
|
||||
)
|
||||
|
||||
config_path = astrbot_paths.data / "cmd_config.json"
|
||||
if not config_path.exists():
|
||||
# Write DEFAULT_CONFIG to disk if file missing
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8-sig",
|
||||
@@ -206,43 +124,50 @@ def _load_config() -> dict[str, Any]:
|
||||
|
||||
|
||||
def _save_config(config: dict[str, Any]) -> None:
|
||||
"""Save config file"""
|
||||
config_path = astrbot_paths.data / "cmd_config.json"
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
|
||||
json.dumps(config, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8-sig",
|
||||
)
|
||||
|
||||
|
||||
def ensure_config_file() -> dict[str, Any]:
|
||||
"""Ensure config file exists and return parsed config."""
|
||||
return _load_config()
|
||||
|
||||
|
||||
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
|
||||
"""Set a value in a nested dictionary"""
|
||||
parts = path.split(".")
|
||||
cur = obj
|
||||
for part in parts[:-1]:
|
||||
if part not in cur:
|
||||
cur[part] = {}
|
||||
elif not isinstance(cur[part], dict):
|
||||
if part not in obj:
|
||||
obj[part] = {}
|
||||
elif not isinstance(obj[part], dict):
|
||||
raise click.ClickException(
|
||||
f"Config path conflict: {'.'.join(parts[: parts.index(part) + 1])} is not a dict"
|
||||
f"Config path conflict: {'.'.join(parts[: parts.index(part) + 1])} is not a dict",
|
||||
)
|
||||
cur = cur[part]
|
||||
cur[parts[-1]] = value
|
||||
obj = obj[part]
|
||||
obj[parts[-1]] = value
|
||||
|
||||
|
||||
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||
"""Get a value from a nested dictionary"""
|
||||
parts = path.split(".")
|
||||
cur = obj
|
||||
for part in parts:
|
||||
cur = cur[part]
|
||||
return cur
|
||||
|
||||
|
||||
# --- CLI commands ---
|
||||
obj = obj[part]
|
||||
return obj
|
||||
|
||||
|
||||
def prompt_dashboard_password(prompt: str = "Dashboard password") -> str:
|
||||
password = click.prompt(prompt, hide_input=True, confirmation_prompt=True, type=str)
|
||||
"""Prompt for dashboard password with confirmation."""
|
||||
password = click.prompt(
|
||||
prompt,
|
||||
hide_input=True,
|
||||
confirmation_prompt=True,
|
||||
type=str,
|
||||
)
|
||||
return _validate_dashboard_password(password)
|
||||
|
||||
|
||||
@@ -252,69 +177,63 @@ def set_dashboard_credentials(
|
||||
username: str | None = None,
|
||||
password_hash: str | None = None,
|
||||
) -> None:
|
||||
"""Update dashboard credentials in config."""
|
||||
if username is not None:
|
||||
_set_nested_item(
|
||||
config, "dashboard.username", _validate_dashboard_username(username)
|
||||
config,
|
||||
"dashboard.username",
|
||||
_validate_dashboard_username(username),
|
||||
)
|
||||
if password_hash is not None:
|
||||
if isinstance(password_hash, str) and is_dashboard_password_hash(password_hash):
|
||||
_set_nested_item(config, "dashboard.password", password_hash)
|
||||
else:
|
||||
if is_legacy_dashboard_password_hash(password_hash):
|
||||
raise click.ClickException(
|
||||
"Storing legacy dashboard password hashes is no longer supported. "
|
||||
"Please provide the plaintext password (it will be hashed securely), "
|
||||
"or provide an Argon2-encoded hash string."
|
||||
)
|
||||
_set_nested_item(
|
||||
config,
|
||||
"dashboard.password",
|
||||
_validate_dashboard_password(password_hash),
|
||||
)
|
||||
_set_nested_item(config, "dashboard.password", password_hash)
|
||||
|
||||
|
||||
@click.group(name="conf")
|
||||
def conf() -> None:
|
||||
"""
|
||||
Configuration management commands.
|
||||
"""Configuration management commands
|
||||
|
||||
Supported config keys:
|
||||
- timezone
|
||||
- log_level
|
||||
- dashboard.port
|
||||
- dashboard.username
|
||||
- dashboard.password
|
||||
- callback_api_base
|
||||
|
||||
- timezone: Timezone setting (e.g. Asia/Shanghai)
|
||||
|
||||
- log_level: Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||
|
||||
- dashboard.port: Dashboard port
|
||||
|
||||
- dashboard.username: Dashboard username
|
||||
|
||||
- dashboard.password: Dashboard password
|
||||
|
||||
- callback_api_base: Callback API base URL
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@conf.command(name="set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def set_config(key: str, value: str) -> None:
|
||||
"""Set the value of a config item"""
|
||||
if key not in CONFIG_VALIDATORS:
|
||||
raise click.ClickException(f"Unsupported config key: {key}")
|
||||
|
||||
config = _load_config()
|
||||
try:
|
||||
# Attempt to get old value (may raise KeyError)
|
||||
try:
|
||||
old_value = _get_nested_item(config, key)
|
||||
except Exception:
|
||||
old_value = "<not set>"
|
||||
|
||||
try:
|
||||
old_value = _get_nested_item(config, key)
|
||||
validated_value = CONFIG_VALIDATORS[key](value)
|
||||
_set_nested_item(config, key, validated_value)
|
||||
_save_config(config)
|
||||
|
||||
click.echo(f"Config updated: {key}")
|
||||
click.echo(f" Old value: {old_value}")
|
||||
click.echo(f" New value: {validated_value}")
|
||||
if key == "dashboard.password":
|
||||
click.echo(" Old value: ********")
|
||||
click.echo(" New value: ********")
|
||||
else:
|
||||
click.echo(f" Old value: {old_value}")
|
||||
click.echo(f" New value: {validated_value}")
|
||||
|
||||
except KeyError:
|
||||
raise click.ClickException(f"Unknown config key: {key}")
|
||||
except click.ClickException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"Failed to set config: {e!s}")
|
||||
|
||||
@@ -322,10 +241,13 @@ def set_config(key: str, value: str) -> None:
|
||||
@conf.command(name="get")
|
||||
@click.argument("key", required=False)
|
||||
def get_config(key: str | None = None) -> None:
|
||||
"""Get the value of a config item. If no key is provided, show all configurable items"""
|
||||
config = _load_config()
|
||||
|
||||
if key:
|
||||
if key not in CONFIG_VALIDATORS:
|
||||
raise click.ClickException(f"Unsupported config key: {key}")
|
||||
|
||||
try:
|
||||
value = _get_nested_item(config, key)
|
||||
if key == "dashboard.password":
|
||||
@@ -337,51 +259,35 @@ def get_config(key: str | None = None) -> None:
|
||||
raise click.UsageError(f"Failed to get config: {e!s}")
|
||||
else:
|
||||
click.echo("Current config:")
|
||||
for k in CONFIG_VALIDATORS:
|
||||
for key in CONFIG_VALIDATORS:
|
||||
try:
|
||||
v = (
|
||||
value = (
|
||||
"********"
|
||||
if k == "dashboard.password"
|
||||
else _get_nested_item(config, k)
|
||||
if key == "dashboard.password"
|
||||
else _get_nested_item(config, key)
|
||||
)
|
||||
click.echo(f" {k}: {v}")
|
||||
click.echo(f" {key}: {value}")
|
||||
except (KeyError, TypeError):
|
||||
# Missing or non-dict paths are simply skipped in listing
|
||||
pass
|
||||
|
||||
|
||||
@conf.command(name="admin")
|
||||
@click.option("-u", "--username", type=str, help="Update admain username as well")
|
||||
@conf.command(name="password")
|
||||
@click.option("-u", "--username", type=str, help="Update dashboard username as well")
|
||||
@click.option(
|
||||
"-p",
|
||||
"--password",
|
||||
type=str,
|
||||
help="Set admain password directly without interactive prompt",
|
||||
help="Set dashboard password directly without interactive prompt",
|
||||
)
|
||||
def set_dashboard_password(username: str | None, password: str | None) -> None:
|
||||
"""
|
||||
Interactively set dashboard password (with confirmation) or set directly with -p.
|
||||
|
||||
Acceptable inputs:
|
||||
- Plaintext password (recommended): it will be hashed securely before storage.
|
||||
- Argon2 encoded hash (advanced): stored as-is.
|
||||
"""
|
||||
"""Interactively manage dashboard password."""
|
||||
config = _load_config()
|
||||
|
||||
if password is not None:
|
||||
if isinstance(password, str) and is_dashboard_password_hash(password):
|
||||
password_hash = password
|
||||
else:
|
||||
if is_legacy_dashboard_password_hash(password):
|
||||
raise click.ClickException(
|
||||
"Providing legacy dashboard password hashes is no longer supported. "
|
||||
"Please supply the plaintext password (it will be hashed securely), "
|
||||
"or provide an Argon2-encoded hash string."
|
||||
)
|
||||
password_hash = _validate_dashboard_password(password)
|
||||
else:
|
||||
password_hash = prompt_dashboard_password()
|
||||
|
||||
password_hash = (
|
||||
_validate_dashboard_password(password)
|
||||
if password is not None
|
||||
else prompt_dashboard_password()
|
||||
)
|
||||
set_dashboard_credentials(
|
||||
config,
|
||||
username=username.strip() if username is not None else None,
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from astrbot.cli.utils import DashboardManager
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
from ..utils import check_dashboard
|
||||
from .cmd_conf import (
|
||||
_validate_dashboard_password,
|
||||
ensure_config_file,
|
||||
prompt_dashboard_password,
|
||||
set_dashboard_credentials,
|
||||
)
|
||||
|
||||
@@ -60,88 +60,50 @@ async def initialize_astrbot(
|
||||
)
|
||||
click.echo(f"Created config file: {config_path}")
|
||||
|
||||
# Generate an .env for this instance from the bundled config.template (if available).
|
||||
# The generated file will be written to ASTRBOT_ROOT/.env and will be automatically
|
||||
# loaded by `astrbot run` (service-config/.env precedence applies).
|
||||
ASTRBOT_ROOT = astrbot_root
|
||||
env_file = ASTRBOT_ROOT / ".env"
|
||||
if not env_file.exists():
|
||||
tmpl_candidates = [
|
||||
Path("/opt/astrbot/config.template"),
|
||||
# project_root may point to the installed package directory; try it as well
|
||||
getattr(astrbot_paths, "project_root", Path.cwd()) / "config.template",
|
||||
Path.cwd() / "config.template",
|
||||
]
|
||||
tmpl = None
|
||||
for t in tmpl_candidates:
|
||||
try:
|
||||
if t.exists():
|
||||
tmpl = t
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if tmpl is not None:
|
||||
try:
|
||||
txt = tmpl.read_text(encoding="utf-8")
|
||||
# Determine instance name for template replacement (fallback to directory name)
|
||||
instance_name = astrbot_root.name or "astrbot"
|
||||
# Substitute ${VAR} and ${VAR:-default} for INSTANCE_NAME, PORT, ASTRBOT_ROOT
|
||||
txt = re.sub(r"\$\{INSTANCE_NAME(:-[^}]*)?\}", instance_name, txt)
|
||||
port_val = (
|
||||
os.environ.get("ASTRBOT_PORT") or os.environ.get("PORT") or "8000"
|
||||
)
|
||||
txt = re.sub(r"\$\{PORT(:-[^}]*)?\}", str(port_val), txt)
|
||||
txt = re.sub(r"\$\{ASTRBOT_ROOT(:-[^}]*)?\}", str(ASTRBOT_ROOT), txt)
|
||||
header = (
|
||||
f"# Generated from config.template by astrbot init for instance: {instance_name}\n"
|
||||
"# This file will be auto-loaded by 'astrbot run'\n\n"
|
||||
)
|
||||
env_file.write_text(header + txt, encoding="utf-8")
|
||||
env_file.chmod(0o644)
|
||||
click.echo(f"Created environment file from template: {env_file}")
|
||||
except Exception as e:
|
||||
click.echo(f"Warning: failed to generate .env from template: {e!s}")
|
||||
else:
|
||||
click.echo("No config.template found; skipping .env generation")
|
||||
|
||||
if admin_password is not None:
|
||||
if admin_password and not admin_username:
|
||||
raise click.ClickException(
|
||||
"--admin-password is no longer supported during init. "
|
||||
"Run 'astrbot conf admin' after initialization."
|
||||
"--admin-password requires --admin-username to be provided"
|
||||
)
|
||||
|
||||
effective_admin_username = (
|
||||
admin_username.strip()
|
||||
if admin_username
|
||||
else str(cast(dict[str, Any], DEFAULT_CONFIG)["dashboard"]["username"])
|
||||
)
|
||||
if admin_username:
|
||||
password_hash = (
|
||||
_validate_dashboard_password(admin_password)
|
||||
if admin_password is not None
|
||||
else None
|
||||
)
|
||||
if password_hash is None:
|
||||
if yes or os.environ.get("ASTRBOT_SYSTEMD") == "1":
|
||||
raise click.ClickException(
|
||||
"Non-interactive init requires --admin-password when --admin-username is set"
|
||||
)
|
||||
password_hash = prompt_dashboard_password("Dashboard admin password")
|
||||
|
||||
config = ensure_config_file()
|
||||
set_dashboard_credentials(
|
||||
config,
|
||||
username=effective_admin_username,
|
||||
password_hash=None,
|
||||
username=admin_username.strip(),
|
||||
password_hash=password_hash,
|
||||
)
|
||||
config_path.write_text(
|
||||
json.dumps(config, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8-sig",
|
||||
)
|
||||
click.echo(f"Configured dashboard admin username: {effective_admin_username}")
|
||||
click.echo(
|
||||
"Dashboard password is not initialized for interactive use. "
|
||||
"Run 'astrbot conf admin' before the first login."
|
||||
)
|
||||
click.echo(f"Configured dashboard admin username: {admin_username.strip()}")
|
||||
|
||||
if not backend_only and (
|
||||
yes
|
||||
or click.confirm(
|
||||
"是否需要集成式 WebUI?(个人电脑推荐,服务器不推荐)",
|
||||
"是否需要集成式 WebUI?(个人电脑推荐,服务器不推荐)",
|
||||
default=True,
|
||||
)
|
||||
):
|
||||
await DashboardManager().ensure_installed(astrbot_root)
|
||||
# 避免在 systemd 模式下因等待输入而阻塞
|
||||
if os.environ.get("ASTRBOT_SYSTEMD") == "1":
|
||||
click.echo("Systemd detected: Skipping dashboard check.")
|
||||
else:
|
||||
await check_dashboard(astrbot_root)
|
||||
else:
|
||||
click.echo("你可以使用在线面版(需支持配置后端)来控制。")
|
||||
click.echo("你可以使用在线面版(需支持配置后端)来控制。")
|
||||
|
||||
|
||||
@click.command()
|
||||
@@ -158,12 +120,7 @@ async def initialize_astrbot(
|
||||
"-p",
|
||||
"--admin-password",
|
||||
type=str,
|
||||
help="Deprecated. Run `astrbot conf admin` after initialization.",
|
||||
)
|
||||
@click.option(
|
||||
"--root",
|
||||
help="ASTRBOT root directory to initialize (overrides ASTRBOT_ROOT env)",
|
||||
type=str,
|
||||
help="Set dashboard admin password during initialization without prompting",
|
||||
)
|
||||
def init(
|
||||
yes: bool,
|
||||
@@ -171,16 +128,14 @@ def init(
|
||||
backup: str | None,
|
||||
admin_username: str | None,
|
||||
admin_password: str | None,
|
||||
root: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize AstrBot"""
|
||||
click.echo("Initializing AstrBot...")
|
||||
|
||||
if os.environ.get("ASTRBOT_SYSTEMD") == "1":
|
||||
yes = True
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
astrbot_root = Path(root) if root else astrbot_paths.root
|
||||
astrbot_root = astrbot_paths.root
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from astrbot.cli.i18n import t
|
||||
from astrbot.cli.utils import (
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
from ..utils import (
|
||||
PluginStatus,
|
||||
build_plug_list,
|
||||
check_astrbot_root,
|
||||
get_git_repo,
|
||||
manage_plugin,
|
||||
)
|
||||
@@ -17,6 +20,15 @@ def plug() -> None:
|
||||
"""Plugin management"""
|
||||
|
||||
|
||||
def _get_data_path() -> Path:
|
||||
base = astrbot_paths.root
|
||||
if not check_astrbot_root(base):
|
||||
raise click.ClickException(
|
||||
f"{base} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
|
||||
)
|
||||
return astrbot_paths.data.resolve()
|
||||
|
||||
|
||||
def display_plugins(plugins, title=None, color=None) -> None:
|
||||
if title:
|
||||
click.echo(click.style(title, fg=color, bold=True))
|
||||
@@ -38,13 +50,11 @@ def display_plugins(plugins, title=None, color=None) -> None:
|
||||
@click.argument("name")
|
||||
def new(name: str) -> None:
|
||||
"""Create a new plugin"""
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
base_path = astrbot_paths.data
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins" / name
|
||||
|
||||
if plug_path.exists():
|
||||
raise click.ClickException(t("plugin_already_exists", name=name))
|
||||
raise click.ClickException(f"Plugin {name} already exists")
|
||||
|
||||
author = click.prompt("Enter plugin author", type=str)
|
||||
desc = click.prompt("Enter plugin description", type=str)
|
||||
@@ -97,9 +107,7 @@ def new(name: str) -> None:
|
||||
@click.option("--all", "-a", is_flag=True, help="List uninstalled plugins")
|
||||
def list(all: bool) -> None:
|
||||
"""List plugins"""
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
base_path = astrbot_paths.data
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
# Unpublished plugins
|
||||
@@ -140,9 +148,7 @@ def list(all: bool) -> None:
|
||||
@click.option("--proxy", help="Proxy server address")
|
||||
def install(name: str, proxy: str | None) -> None:
|
||||
"""Install a plugin"""
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
base_path = astrbot_paths.data
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
@@ -156,7 +162,7 @@ def install(name: str, proxy: str | None) -> None:
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(t("plugin_not_found_or_installed", name=name))
|
||||
raise click.ClickException(f"Plugin {name} not found or already installed")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
|
||||
|
||||
@@ -165,26 +171,24 @@ def install(name: str, proxy: str | None) -> None:
|
||||
@click.argument("name")
|
||||
def remove(name: str) -> None:
|
||||
"""Uninstall a plugin"""
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
base_path = astrbot_paths.data
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
plugin = next((p for p in plugins if p["name"] == name), None)
|
||||
|
||||
if not plugin or not plugin.get("local_path"):
|
||||
raise click.ClickException(t("plugin_not_found_or_installed", name=name))
|
||||
raise click.ClickException(f"Plugin {name} does not exist or is not installed")
|
||||
|
||||
plugin_path = plugin["local_path"]
|
||||
|
||||
click.confirm(t("plugin_uninstall_confirm", name=name), default=False, abort=True)
|
||||
click.confirm(
|
||||
f"Are you sure you want to uninstall plugin {name}?", default=False, abort=True
|
||||
)
|
||||
|
||||
try:
|
||||
shutil.rmtree(plugin_path)
|
||||
click.echo(t("plugin_uninstall_success", name=name))
|
||||
click.echo(f"Plugin {name} has been uninstalled")
|
||||
except Exception as e:
|
||||
raise click.ClickException(
|
||||
t("plugin_uninstall_failed_ex", name=name, error=str(e))
|
||||
)
|
||||
raise click.ClickException(f"Failed to uninstall plugin {name}: {e}")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@@ -192,9 +196,7 @@ def remove(name: str) -> None:
|
||||
@click.option("--proxy", help="GitHub proxy address")
|
||||
def update(name: str, proxy: str | None) -> None:
|
||||
"""Update plugins"""
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
base_path = astrbot_paths.data
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
@@ -220,13 +222,13 @@ def update(name: str, proxy: str | None) -> None:
|
||||
]
|
||||
|
||||
if not need_update_plugins:
|
||||
click.echo(t("plugin_no_update_needed"))
|
||||
click.echo("No plugins need updating")
|
||||
return
|
||||
|
||||
click.echo(t("plugin_found_update", count=str(len(need_update_plugins))))
|
||||
click.echo(f"Found {len(need_update_plugins)} plugin(s) needing update")
|
||||
for plugin in need_update_plugins:
|
||||
plugin_name = plugin["name"]
|
||||
click.echo(t("plugin_updating", name=plugin_name))
|
||||
click.echo(f"Updating plugin {plugin_name}...")
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
|
||||
|
||||
@@ -234,9 +236,7 @@ def update(name: str, proxy: str | None) -> None:
|
||||
@click.argument("query")
|
||||
def search(query: str) -> None:
|
||||
"""Search for plugins"""
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
base_path = astrbot_paths.data
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
matched_plugins = [
|
||||
@@ -248,7 +248,7 @@ def search(query: str) -> None:
|
||||
]
|
||||
|
||||
if not matched_plugins:
|
||||
click.echo(t("plugin_search_no_result", query=query))
|
||||
click.echo(f"No plugins matching '{query}' found")
|
||||
return
|
||||
|
||||
display_plugins(matched_plugins, t("plugin_search_results", query=query), "cyan")
|
||||
display_plugins(matched_plugins, f"Search results: '{query}'", "cyan")
|
||||
|
||||
@@ -1,3 +1,40 @@
|
||||
"""AstrBot Run
|
||||
Environment Variables Used in Project:
|
||||
|
||||
Core:
|
||||
- `ASTRBOT_ROOT`: AstrBot root directory path.
|
||||
- `ASTRBOT_LOG_LEVEL`: Log level (e.g. INFO, DEBUG).
|
||||
- `ASTRBOT_CLI`: Flag indicating execution via CLI.
|
||||
- `ASTRBOT_DESKTOP_CLIENT`: Flag indicating execution via desktop client.
|
||||
- `ASTRBOT_SYSTEMD`: Flag indicating execution via systemd service.
|
||||
- `ASTRBOT_RELOAD`: Enable plugin auto-reload (set to "1").
|
||||
- `ASTRBOT_DISABLE_METRICS`: Disable metrics upload (set to "1").
|
||||
- `TESTING`: Enable testing mode.
|
||||
- `DEMO_MODE`: Enable demo mode.
|
||||
- `PYTHON`: Python executable path override (for local code execution).
|
||||
|
||||
Dashboard:
|
||||
- `ASTRBOT_DASHBOARD_ENABLE` / `DASHBOARD_ENABLE`: Enable/Disable Dashboard.
|
||||
- `ASTRBOT_DASHBOARD_HOST` / `DASHBOARD_HOST`: Dashboard bind host.
|
||||
- `ASTRBOT_DASHBOARD_PORT` / `DASHBOARD_PORT`: Dashboard bind port.
|
||||
- `ASTRBOT_DASHBOARD_SSL_ENABLE` / `DASHBOARD_SSL_ENABLE`: Enable SSL.
|
||||
- `ASTRBOT_DASHBOARD_SSL_CERT` / `DASHBOARD_SSL_CERT`: SSL Certificate path.
|
||||
- `ASTRBOT_DASHBOARD_SSL_KEY` / `DASHBOARD_SSL_KEY`: SSL Key path.
|
||||
- `ASTRBOT_DASHBOARD_SSL_CA_CERTS` / `DASHBOARD_SSL_CA_CERTS`: SSL CA Certs path.
|
||||
|
||||
Network:
|
||||
- `http_proxy` / `https_proxy`: Proxy URL.
|
||||
- `no_proxy`: No proxy list.
|
||||
|
||||
Integrations:
|
||||
- `DASHSCOPE_API_KEY`: Alibaba DashScope API Key (for Rerank).
|
||||
- `COZE_API_KEY` / `COZE_BOT_ID`: Coze integration.
|
||||
- `BAY_DATA_DIR`: Computer Use data directory.
|
||||
|
||||
Platform Specific:
|
||||
- `TEST_MODE`: Test mode for QQOfficial.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
@@ -7,8 +44,9 @@ from pathlib import Path
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from astrbot.cli.utils import DashboardManager
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_root
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
from ..utils import check_astrbot_root, check_dashboard
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path) -> None:
|
||||
@@ -16,7 +54,13 @@ async def run_astrbot(astrbot_root: Path) -> None:
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
await DashboardManager().ensure_installed(astrbot_root)
|
||||
if (
|
||||
os.environ.get("ASTRBOT_DASHBOARD_ENABLE", os.environ.get("DASHBOARD_ENABLE"))
|
||||
== "True"
|
||||
):
|
||||
# 避免在 systemd 模式下因等待输入而阻塞
|
||||
if os.environ.get("ASTRBOT_SYSTEMD") != "1":
|
||||
await check_dashboard(astrbot_root)
|
||||
|
||||
log_broker = LogBroker()
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
@@ -28,15 +72,96 @@ async def run_astrbot(astrbot_root: Path) -> None:
|
||||
|
||||
|
||||
@click.option("--reload", "-r", is_flag=True, help="Auto-reload plugins")
|
||||
@click.option("--host", "-H", help="AstrBot Dashboard Host", required=False, type=str)
|
||||
@click.option("--port", "-p", help="AstrBot Dashboard port", required=False, type=str)
|
||||
@click.option("--root", help="AstrBot root directory", required=False, type=str)
|
||||
@click.option(
|
||||
"--service-config",
|
||||
"-c",
|
||||
help="Service configuration file path",
|
||||
required=False,
|
||||
type=str,
|
||||
)
|
||||
@click.option(
|
||||
"--backend-only",
|
||||
"-b",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Disable WebUI, run backend only",
|
||||
)
|
||||
@click.option(
|
||||
"--log-level",
|
||||
"-l",
|
||||
help="Log level",
|
||||
required=False,
|
||||
type=str,
|
||||
default="INFO",
|
||||
)
|
||||
@click.option("--debug", is_flag=True, help="Enable debug mode")
|
||||
@click.command()
|
||||
def run(reload: bool, port: str) -> None:
|
||||
def run(
|
||||
reload: bool,
|
||||
host: str,
|
||||
port: str,
|
||||
root: str,
|
||||
service_config: str,
|
||||
backend_only: bool,
|
||||
log_level: str,
|
||||
debug: bool,
|
||||
) -> None:
|
||||
"""Run AstrBot"""
|
||||
try:
|
||||
os.environ["ASTRBOT_CLI"] = "1"
|
||||
astrbot_root = Path(get_astrbot_root())
|
||||
if debug:
|
||||
log_level = "DEBUG"
|
||||
|
||||
if not (astrbot_root / "data").exists():
|
||||
if service_config:
|
||||
svc_path = Path(service_config)
|
||||
if svc_path.exists():
|
||||
content = svc_path.read_text(encoding="utf-8")
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
# Remove quotes
|
||||
if (value.startswith('"') and value.endswith('"')) or (
|
||||
value.startswith("'") and value.endswith("'")
|
||||
):
|
||||
value = value[1:-1]
|
||||
|
||||
if key == "HOST" and not host:
|
||||
host = value
|
||||
elif key == "PORT" and not port:
|
||||
port = value
|
||||
elif key == "ASTRBOT_ROOT" and not root:
|
||||
root = value
|
||||
|
||||
# Normalize environment variables for backward compatibility
|
||||
# If the legacy env var is set but the new one isn't, copy it over.
|
||||
env_map = {
|
||||
"DASHBOARD_ENABLE": "ASTRBOT_DASHBOARD_ENABLE",
|
||||
"DASHBOARD_HOST": "ASTRBOT_DASHBOARD_HOST",
|
||||
"DASHBOARD_PORT": "ASTRBOT_DASHBOARD_PORT",
|
||||
"DASHBOARD_SSL_ENABLE": "ASTRBOT_DASHBOARD_SSL_ENABLE",
|
||||
"DASHBOARD_SSL_CERT": "ASTRBOT_DASHBOARD_SSL_CERT",
|
||||
"DASHBOARD_SSL_KEY": "ASTRBOT_DASHBOARD_SSL_KEY",
|
||||
"DASHBOARD_SSL_CA_CERTS": "ASTRBOT_DASHBOARD_SSL_CA_CERTS",
|
||||
}
|
||||
for legacy, new in env_map.items():
|
||||
if legacy in os.environ and new not in os.environ:
|
||||
os.environ[new] = os.environ[legacy]
|
||||
|
||||
os.environ["ASTRBOT_CLI"] = "1"
|
||||
if root:
|
||||
os.environ["ASTRBOT_ROOT"] = root
|
||||
astrbot_root = Path(root)
|
||||
else:
|
||||
astrbot_root = astrbot_paths.root
|
||||
|
||||
if not check_astrbot_root(astrbot_root):
|
||||
raise click.ClickException(
|
||||
f"{astrbot_root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
|
||||
)
|
||||
@@ -44,13 +169,67 @@ def run(reload: bool, port: str) -> None:
|
||||
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
|
||||
sys.path.insert(0, str(astrbot_root))
|
||||
|
||||
if port:
|
||||
os.environ["DASHBOARD_PORT"] = port
|
||||
if port is not None:
|
||||
os.environ["ASTRBOT_DASHBOARD_PORT"] = port
|
||||
os.environ["DASHBOARD_PORT"] = port # 今后应该移除
|
||||
if host is not None:
|
||||
os.environ["ASTRBOT_DASHBOARD_HOST"] = host
|
||||
os.environ["DASHBOARD_HOST"] = host # 今后应该移除
|
||||
os.environ["ASTRBOT_DASHBOARD_ENABLE"] = str(not backend_only)
|
||||
os.environ["DASHBOARD_ENABLE"] = str(not backend_only) # 今后应该移除
|
||||
os.environ["ASTRBOT_LOG_LEVEL"] = log_level
|
||||
|
||||
if reload:
|
||||
click.echo("Plugin auto-reload enabled")
|
||||
os.environ["ASTRBOT_RELOAD"] = "1"
|
||||
|
||||
if debug:
|
||||
keys_to_print = [
|
||||
"ASTRBOT_ROOT",
|
||||
"ASTRBOT_LOG_LEVEL",
|
||||
"ASTRBOT_CLI",
|
||||
"ASTRBOT_DESKTOP_CLIENT",
|
||||
"ASTRBOT_SYSTEMD",
|
||||
"ASTRBOT_RELOAD",
|
||||
"ASTRBOT_DISABLE_METRICS",
|
||||
"TESTING",
|
||||
"DEMO_MODE",
|
||||
"PYTHON",
|
||||
"ASTRBOT_DASHBOARD_ENABLE",
|
||||
"DASHBOARD_ENABLE",
|
||||
"ASTRBOT_DASHBOARD_HOST",
|
||||
"DASHBOARD_HOST",
|
||||
"ASTRBOT_DASHBOARD_PORT",
|
||||
"DASHBOARD_PORT",
|
||||
"ASTRBOT_DASHBOARD_SSL_ENABLE",
|
||||
"DASHBOARD_SSL_ENABLE",
|
||||
"ASTRBOT_DASHBOARD_SSL_CERT",
|
||||
"DASHBOARD_SSL_CERT",
|
||||
"ASTRBOT_DASHBOARD_SSL_KEY",
|
||||
"DASHBOARD_SSL_KEY",
|
||||
"ASTRBOT_DASHBOARD_SSL_CA_CERTS",
|
||||
"DASHBOARD_SSL_CA_CERTS",
|
||||
"http_proxy",
|
||||
"https_proxy",
|
||||
"no_proxy",
|
||||
"DASHSCOPE_API_KEY",
|
||||
"COZE_API_KEY",
|
||||
"COZE_BOT_ID",
|
||||
"BAY_DATA_DIR",
|
||||
"TEST_MODE",
|
||||
]
|
||||
click.secho("\n[Debug Mode] Environment Variables:", fg="yellow", bold=True)
|
||||
for key in keys_to_print:
|
||||
if key in os.environ:
|
||||
val = os.environ[key]
|
||||
if "KEY" in key or "PASSWORD" in key or "SECRET" in key:
|
||||
if len(val) > 8:
|
||||
val = val[:4] + "****" + val[-4:]
|
||||
else:
|
||||
val = "****"
|
||||
click.echo(f" {click.style(key, fg='cyan')}: {val}")
|
||||
click.echo("")
|
||||
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
with lock.acquire():
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
"""AstrBot Run TUI - A beautiful textual interface for running AstrBot.
|
||||
|
||||
This module provides a Textual-based TUI for `astrbot run` with:
|
||||
- Animated ASCII logo
|
||||
- Live log viewer
|
||||
- Platform status indicators
|
||||
- Only activates in interactive TTY environments
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import typing
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container, Horizontal, Vertical
|
||||
from textual.reactive import reactive
|
||||
from textual.widgets import Footer, Header, Log, Static
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
from rich.style import Style
|
||||
from rich.text import Text
|
||||
else:
|
||||
Console: Any = None
|
||||
Style: Any = None
|
||||
Text: Any = None
|
||||
|
||||
|
||||
# AstrBot ASCII Logo
|
||||
ASTRBOT_LOGO = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
"""
|
||||
|
||||
|
||||
class AstrBotRunTUI(App):
|
||||
"""Textual TUI for AstrBot run command."""
|
||||
|
||||
CSS = """
|
||||
Screen {
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
#logo-container {
|
||||
height: auto;
|
||||
padding: 1 2;
|
||||
background: $surface-darken-1;
|
||||
border: solid $primary;
|
||||
}
|
||||
|
||||
#logo-text {
|
||||
color: $primary;
|
||||
text-style: bold;
|
||||
font-family: "JetBrains Mono", "Fira Code", monospace;
|
||||
}
|
||||
|
||||
#main-container {
|
||||
height: 1fr;
|
||||
}
|
||||
|
||||
#log-section {
|
||||
border: solid $accent;
|
||||
height: 70%;
|
||||
margin: 1 2;
|
||||
}
|
||||
|
||||
#log-header {
|
||||
background: $accent-darken-1;
|
||||
padding: 1 2;
|
||||
color: $text;
|
||||
text-style: bold;
|
||||
}
|
||||
|
||||
Log {
|
||||
background: $surface-darken-2;
|
||||
color: $text;
|
||||
border: solid $accent-darken-2;
|
||||
}
|
||||
|
||||
#status-section {
|
||||
height: auto;
|
||||
padding: 1 2;
|
||||
background: $surface-darken-1;
|
||||
border-top: solid $primary;
|
||||
}
|
||||
|
||||
.status-item {
|
||||
padding: 0 2;
|
||||
}
|
||||
|
||||
.status-ok {
|
||||
color: $success;
|
||||
text-style: bold;
|
||||
}
|
||||
|
||||
.status-pending {
|
||||
color: $warning;
|
||||
}
|
||||
|
||||
.status-label {
|
||||
color: $text-muted;
|
||||
}
|
||||
|
||||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
"""
|
||||
|
||||
BINDINGS: typing.ClassVar[list[Binding]] = [
|
||||
Binding("q", "quit", "Quit", show=True),
|
||||
Binding("ctrl+c", "quit", "Quit", show=False),
|
||||
Binding("l", "toggle_logs", "Toggle Logs", show=True),
|
||||
]
|
||||
|
||||
log_visible = reactive(True)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
startup_coro: Callable[[], Awaitable[Any]],
|
||||
astrbot_root: Path,
|
||||
backend_only: bool = False,
|
||||
host: str | None = None,
|
||||
port: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.startup_coro = startup_coro
|
||||
self.astrbot_root = astrbot_root
|
||||
self.backend_only = backend_only
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._animation_frame = 0
|
||||
self._startup_done = False
|
||||
self._log_lines: list[str] = []
|
||||
self.console: Any = Console() if Console else None
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Create child widgets."""
|
||||
yield Header()
|
||||
|
||||
# Animated Logo
|
||||
with Container(id="logo-container"):
|
||||
yield Static(self._get_animated_logo(), id="logo-text")
|
||||
|
||||
# Main content
|
||||
with Vertical(id="main-container"):
|
||||
# Log viewer
|
||||
with Container(
|
||||
id="log-section", classes="" if self.log_visible else "hidden"
|
||||
):
|
||||
yield Static("📋 Live Logs", id="log-header")
|
||||
yield Log(id="log-viewer")
|
||||
|
||||
# Status bar
|
||||
with Horizontal(id="status-section"):
|
||||
yield Static("🌟 AstrBot", classes="status-item status-ok")
|
||||
yield Static(
|
||||
f"📁 {self.astrbot_root.name}",
|
||||
classes="status-item",
|
||||
id="root-status",
|
||||
)
|
||||
if not self.backend_only:
|
||||
dashboard_url = (
|
||||
f"http://{self.host or 'localhost'}:{self.port or '6185'}"
|
||||
)
|
||||
yield Static(
|
||||
f"🌐 Dashboard: [link]{dashboard_url}[/link]",
|
||||
classes="status-item",
|
||||
id="dashboard-status",
|
||||
)
|
||||
yield Static(
|
||||
"⚡ Running", classes="status-item status-ok", id="run-status"
|
||||
)
|
||||
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Called when app is mounted."""
|
||||
self.title = "AstrBot"
|
||||
self.sub_title = "AI Chatbot Framework"
|
||||
|
||||
# Start the startup coroutine
|
||||
self.set_timer(0.1, self._run_startup)
|
||||
|
||||
# Animate logo
|
||||
self.set_interval(0.5, self._animate_logo)
|
||||
|
||||
# Get the log widget and configure it
|
||||
log_widget = self.query_one("#log-viewer", Log)
|
||||
log_widget.write_line("🚀 AstrBot TUI initialized")
|
||||
log_widget.write_line(f"📁 Running from: {self.astrbot_root}")
|
||||
if not self.backend_only:
|
||||
log_widget.write_line(
|
||||
f"🌐 Dashboard will be available at: {self.host or 'localhost'}:{self.port or '6185'}"
|
||||
)
|
||||
log_widget.write_line("")
|
||||
|
||||
def _get_animated_logo(self) -> str:
|
||||
"""Get the logo with optional animation effect."""
|
||||
lines = ASTRBOT_LOGO.strip().split("\n")
|
||||
|
||||
if self.console and hasattr(self, "_animation_frame"):
|
||||
# Create animated version with color cycling
|
||||
frame = self._animation_frame % 4
|
||||
colors = ["#00D9FF", "#00FF87", "#FFD700", "#FF6B6B"]
|
||||
color = colors[frame]
|
||||
|
||||
text = Text()
|
||||
for i, line in enumerate(lines):
|
||||
style = Style(color=color, bold=True) if i == 0 else Style(color=color)
|
||||
text.append(line + "\n", style=style)
|
||||
return str(text)
|
||||
|
||||
return ASTRBOT_LOGO
|
||||
|
||||
def _animate_logo(self) -> None:
|
||||
"""Update the animated logo."""
|
||||
self._animation_frame = (self._animation_frame + 1) % 4
|
||||
logo_widget = self.query_one("#logo-text", Static)
|
||||
logo_widget.update(self._get_animated_logo())
|
||||
|
||||
async def _run_startup(self) -> None:
|
||||
"""Run the AstrBot startup coroutine."""
|
||||
if self._startup_done:
|
||||
return
|
||||
self._startup_done = True
|
||||
|
||||
try:
|
||||
log_widget = self.query_one("#log-viewer", Log)
|
||||
log_widget.write_line("⏳ Initializing AstrBot...")
|
||||
|
||||
await self.startup_coro()
|
||||
|
||||
log_widget.write_line("")
|
||||
log_widget.write_line("✅ AstrBot started successfully!")
|
||||
except Exception as e:
|
||||
log_widget = self.query_one("#log-viewer", Log)
|
||||
log_widget.write_line(f"❌ Error during startup: {e}")
|
||||
log_widget.write_line("Check logs for details.")
|
||||
|
||||
def action_toggle_logs(self) -> None:
|
||||
"""Toggle log visibility."""
|
||||
self.log_visible = not self.log_visible
|
||||
log_section = self.query_one("#log-section", Container)
|
||||
if self.log_visible:
|
||||
log_section.remove_class("hidden")
|
||||
else:
|
||||
log_section.add_class("hidden")
|
||||
|
||||
async def action_quit(self) -> None:
|
||||
"""Quit the application."""
|
||||
self.exit()
|
||||
|
||||
def write_log(self, message: str) -> None:
|
||||
"""Write a message to the log viewer (can be called from outside)."""
|
||||
log_widget = self.query_one("#log-viewer", Log)
|
||||
log_widget.write_line(message)
|
||||
|
||||
|
||||
def is_interactive_tty() -> bool:
|
||||
"""Check if we're running in an interactive TTY."""
|
||||
return sys.stdin.isatty() and sys.stdout.isatty()
|
||||
|
||||
|
||||
async def run_tui(
|
||||
startup_coro: Callable[[], Awaitable[Any]],
|
||||
astrbot_root: Path,
|
||||
backend_only: bool = False,
|
||||
host: str | None = None,
|
||||
port: str | None = None,
|
||||
) -> None:
|
||||
"""Run the AstrBot TUI.
|
||||
|
||||
Args:
|
||||
startup_coro: Coroutine to run on startup
|
||||
astrbot_root: AstrBot root directory
|
||||
backend_only: Whether backend-only mode is enabled
|
||||
host: Dashboard host
|
||||
port: Dashboard port
|
||||
"""
|
||||
if not is_interactive_tty():
|
||||
# Not interactive, run without TUI
|
||||
await startup_coro()
|
||||
return
|
||||
|
||||
app = AstrBotRunTUI(
|
||||
startup_coro=startup_coro,
|
||||
astrbot_root=astrbot_root,
|
||||
backend_only=backend_only,
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
|
||||
try:
|
||||
await app.run_async()
|
||||
except Exception:
|
||||
# Fallback to non-TUI mode
|
||||
await startup_coro()
|
||||
@@ -1,68 +0,0 @@
|
||||
"""TUI CLI command for AstrBot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
import click
|
||||
|
||||
|
||||
@click.command(name="tui")
|
||||
@click.option(
|
||||
"--debug",
|
||||
is_flag=True,
|
||||
help="Enable debug mode with verbose output.",
|
||||
)
|
||||
@click.option(
|
||||
"--host",
|
||||
default="http://localhost:6185",
|
||||
help="AstrBot dashboard host URL.",
|
||||
)
|
||||
@click.option(
|
||||
"--api-key",
|
||||
default=None,
|
||||
help="API key for authentication (optional, uses login if not provided).",
|
||||
)
|
||||
@click.option(
|
||||
"--username",
|
||||
default="astrbot",
|
||||
help="Username for login (if api-key not provided).",
|
||||
)
|
||||
@click.option(
|
||||
"--password",
|
||||
default="astrbot",
|
||||
help="Password for login (if api-key not provided).",
|
||||
)
|
||||
def tui(
|
||||
debug: bool,
|
||||
host: str,
|
||||
api_key: str | None,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> None:
|
||||
"""
|
||||
Launch the AstrBot Terminal User Interface (TUI).
|
||||
|
||||
This command starts an interactive terminal-based interface for AstrBot.
|
||||
The TUI connects to a running AstrBot instance via the dashboard API.
|
||||
"""
|
||||
try:
|
||||
from astrbot.cli.commands.tui_async import run_tui_async
|
||||
|
||||
run_tui_async(
|
||||
debug=debug,
|
||||
host=host,
|
||||
api_key=api_key,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
except ImportError as e:
|
||||
click.echo(f"Error: Failed to import TUI module: {e}", err=True)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
click.echo(f"Error: Failed to start TUI: {e}", err=True)
|
||||
if debug:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -1,511 +0,0 @@
|
||||
"""Async TUI implementation that connects to a running AstrBot instance via HTTP API.
|
||||
|
||||
This module provides a terminal UI that connects to AstrBot via the dashboard API,
|
||||
supporting streaming responses and all message types.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import curses
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
|
||||
from astrbot.tui.message_handler import (
|
||||
ChatResponse,
|
||||
MessageType,
|
||||
ParsedMessage,
|
||||
SSEMessageParser,
|
||||
)
|
||||
from astrbot.tui.screen import Screen
|
||||
|
||||
|
||||
class MessageSender(Enum):
|
||||
USER = "user"
|
||||
BOT = "bot"
|
||||
SYSTEM = "system"
|
||||
TOOL = "tool"
|
||||
REASONING = "reasoning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
sender: MessageSender
|
||||
text: str
|
||||
timestamp: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TUIState:
|
||||
messages: list[Message] = field(default_factory=list)
|
||||
input_buffer: str = ""
|
||||
cursor_x: int = 0
|
||||
status: str = "Connecting..."
|
||||
running: bool = True
|
||||
connected: bool = False
|
||||
|
||||
|
||||
class TUIClient:
|
||||
"""TUI client that connects to AstrBot via HTTP API.
|
||||
|
||||
Supports full streaming responses including:
|
||||
- Plain text (streaming)
|
||||
- Tool calls and results
|
||||
- Reasoning chains
|
||||
- Agent stats
|
||||
- Media (images, audio, files)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
screen: Screen,
|
||||
host: str,
|
||||
api_key: str | None,
|
||||
username: str,
|
||||
password: str,
|
||||
debug: bool = False,
|
||||
):
|
||||
self.screen = screen
|
||||
self.state = TUIState()
|
||||
self._input_history: list[str] = []
|
||||
self._history_index: int = -1
|
||||
self._max_history: int = 100
|
||||
self._max_messages: int = 1000
|
||||
self._pending_tasks: list[asyncio.Task[None]] = []
|
||||
|
||||
# Connection settings
|
||||
self.host = host.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.debug = debug
|
||||
|
||||
# Session info
|
||||
self.session_id: str | None = None
|
||||
self.conversation_id: str | None = None
|
||||
|
||||
# HTTP client
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._headers: dict[str, str] = {}
|
||||
|
||||
# SSE parser
|
||||
self._parser = SSEMessageParser()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to AstrBot and authenticate."""
|
||||
self._client = httpx.AsyncClient(base_url=self.host, timeout=30.0)
|
||||
|
||||
try:
|
||||
# Login or use API key
|
||||
if self.api_key:
|
||||
self._headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
else:
|
||||
login_resp = await self._client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": self.username, "password": self.password},
|
||||
)
|
||||
if login_resp.status_code != 200:
|
||||
self.state.status = f"Login failed: {login_resp.status_code}"
|
||||
return False
|
||||
data = login_resp.json()
|
||||
self._headers["Authorization"] = (
|
||||
f"Bearer {data.get('access_token', '')}"
|
||||
)
|
||||
|
||||
# Create new session for TUI
|
||||
new_session_resp = await self._client.get(
|
||||
"/api/tui/new_session",
|
||||
params={"platform_id": "tui"},
|
||||
headers=self._headers,
|
||||
)
|
||||
if new_session_resp.status_code != 200:
|
||||
self.state.status = (
|
||||
f"Failed to create session: {new_session_resp.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
session_data = new_session_resp.json()
|
||||
if session_data.get("code") != 0:
|
||||
self.state.status = f"Session error: {session_data.get('msg')}"
|
||||
return False
|
||||
|
||||
self.conversation_id = session_data.get("data", {}).get("session_id")
|
||||
if not self.conversation_id:
|
||||
self.state.status = "No session_id in response"
|
||||
return False
|
||||
|
||||
self.session_id = self.conversation_id
|
||||
self.state.connected = True
|
||||
self.state.status = "Connected"
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.state.status = f"Connection error: {e}"
|
||||
if self.debug:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from AstrBot."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self.state.connected = False
|
||||
|
||||
async def load_history(self) -> None:
|
||||
"""Load message history for the current session."""
|
||||
if not self._client or not self.conversation_id:
|
||||
return
|
||||
|
||||
try:
|
||||
resp = await self._client.get(
|
||||
"/api/tui/get_session",
|
||||
params={"session_id": self.conversation_id},
|
||||
headers=self._headers,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return
|
||||
|
||||
data = resp.json()
|
||||
history = data.get("data", {}).get("history", [])
|
||||
|
||||
for record in reversed(history):
|
||||
content = record.get("content", {})
|
||||
msg_type = content.get("type")
|
||||
message_parts = content.get("message", [])
|
||||
|
||||
if msg_type == "user":
|
||||
for part in message_parts:
|
||||
if part.get("type") == "plain":
|
||||
self.add_message(MessageSender.USER, part.get("text", ""))
|
||||
elif msg_type == "bot":
|
||||
for part in message_parts:
|
||||
if part.get("type") == "plain":
|
||||
self.add_message(MessageSender.BOT, part.get("text", ""))
|
||||
|
||||
except Exception:
|
||||
if self.debug:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def add_message(self, sender: MessageSender, text: str) -> None:
|
||||
"""Add a message to the chat log."""
|
||||
if not text:
|
||||
return
|
||||
self.state.messages.append(Message(sender=sender, text=text))
|
||||
if len(self.state.messages) > self._max_messages:
|
||||
self.state.messages = self.state.messages[-self._max_messages :]
|
||||
|
||||
def add_system_message(self, text: str) -> None:
|
||||
"""Add a system message."""
|
||||
self.add_message(MessageSender.SYSTEM, text)
|
||||
|
||||
def handle_key(self, key: int) -> bool:
|
||||
"""Handle a keypress. Returns True if the application should continue running."""
|
||||
if key in (curses.KEY_EXIT, 27): # ESC or ctrl-c
|
||||
return False
|
||||
|
||||
if key == curses.KEY_RESIZE:
|
||||
self.screen.resize()
|
||||
return True
|
||||
|
||||
# Handle arrow keys for navigation
|
||||
if key == curses.KEY_LEFT:
|
||||
if self.state.cursor_x > 0:
|
||||
self.state.cursor_x -= 1
|
||||
elif key == curses.KEY_RIGHT:
|
||||
if self.state.cursor_x < len(self.state.input_buffer):
|
||||
self.state.cursor_x += 1
|
||||
elif key == curses.KEY_HOME:
|
||||
self.state.cursor_x = 0
|
||||
elif key == curses.KEY_END:
|
||||
self.state.cursor_x = len(self.state.input_buffer)
|
||||
|
||||
# Handle backspace
|
||||
elif key in (curses.KEY_BACKSPACE, 127, 8):
|
||||
if self.state.cursor_x > 0:
|
||||
self.state.input_buffer = (
|
||||
self.state.input_buffer[: self.state.cursor_x - 1]
|
||||
+ self.state.input_buffer[self.state.cursor_x :]
|
||||
)
|
||||
self.state.cursor_x -= 1
|
||||
|
||||
# Handle delete
|
||||
elif key == curses.KEY_DC:
|
||||
if self.state.cursor_x < len(self.state.input_buffer):
|
||||
self.state.input_buffer = (
|
||||
self.state.input_buffer[: self.state.cursor_x]
|
||||
+ self.state.input_buffer[self.state.cursor_x + 1 :]
|
||||
)
|
||||
|
||||
# Handle Enter/Return - submit message
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
if self.state.input_buffer.strip():
|
||||
task = asyncio.create_task(self._submit_message())
|
||||
self._pending_tasks.append(task)
|
||||
return True
|
||||
|
||||
# Handle history navigation (up/down arrows)
|
||||
elif key == curses.KEY_UP:
|
||||
if (
|
||||
self._input_history
|
||||
and self._history_index < len(self._input_history) - 1
|
||||
):
|
||||
self._history_index += 1
|
||||
self.state.input_buffer = self._input_history[self._history_index]
|
||||
self.state.cursor_x = len(self.state.input_buffer)
|
||||
elif key == curses.KEY_DOWN:
|
||||
if self._history_index > 0:
|
||||
self._history_index -= 1
|
||||
self.state.input_buffer = self._input_history[self._history_index]
|
||||
self.state.cursor_x = len(self.state.input_buffer)
|
||||
elif self._history_index == 0:
|
||||
self._history_index = -1
|
||||
self.state.input_buffer = ""
|
||||
self.state.cursor_x = 0
|
||||
|
||||
# Regular character input
|
||||
elif 32 <= key <= 126:
|
||||
char = chr(key)
|
||||
self.state.input_buffer = (
|
||||
self.state.input_buffer[: self.state.cursor_x]
|
||||
+ char
|
||||
+ self.state.input_buffer[self.state.cursor_x :]
|
||||
)
|
||||
self.state.cursor_x += 1
|
||||
|
||||
# Clear input with Ctrl+L
|
||||
elif key == 12: # Ctrl+L
|
||||
self.state.input_buffer = ""
|
||||
self.state.cursor_x = 0
|
||||
|
||||
return True
|
||||
|
||||
async def _submit_message(self) -> None:
|
||||
"""Submit the current input buffer as a user message."""
|
||||
text = self.state.input_buffer.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
# Add to history
|
||||
self._input_history.insert(0, text)
|
||||
if len(self._input_history) > self._max_history:
|
||||
self._input_history = self._input_history[: self._max_history]
|
||||
self._history_index = -1
|
||||
|
||||
# Add user message to chat
|
||||
self.add_message(MessageSender.USER, text)
|
||||
|
||||
# Clear input
|
||||
self.state.input_buffer = ""
|
||||
self.state.cursor_x = 0
|
||||
|
||||
# Process the message via API
|
||||
await self._process_user_message(text)
|
||||
|
||||
async def _process_user_message(self, text: str) -> None:
|
||||
"""Send message to AstrBot and process the streaming response."""
|
||||
if not self.conversation_id or not self._client:
|
||||
self.add_system_message("Not connected to AstrBot")
|
||||
return
|
||||
|
||||
self.state.status = "Waiting for response..."
|
||||
|
||||
try:
|
||||
# Format umo for tui
|
||||
umo = f"tui:FriendMessage:tui!{self.username}!{self.conversation_id}"
|
||||
|
||||
# Reset parser for new stream
|
||||
self._parser.reset()
|
||||
|
||||
# Send message and stream response using proper SSE
|
||||
async with self._client.stream(
|
||||
"POST",
|
||||
"/api/tui/chat",
|
||||
headers=self._headers,
|
||||
json={
|
||||
"umo": umo,
|
||||
"message": text,
|
||||
"session_id": self.conversation_id,
|
||||
"streaming": True,
|
||||
},
|
||||
timeout=None,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
self._update_last_bot_message(f"Error: HTTP {response.status_code}")
|
||||
self.state.status = "Error"
|
||||
return
|
||||
|
||||
# Process streaming SSE
|
||||
async for line in response.aiter_lines():
|
||||
parsed = self._parser.parse_line(line)
|
||||
if parsed is None:
|
||||
continue
|
||||
|
||||
update, is_complete = self._process_parsed_message(parsed)
|
||||
|
||||
# Update display based on message type
|
||||
if parsed.type == MessageType.TOOL_CALL:
|
||||
tool_call = json.loads(parsed.data)
|
||||
self.add_message(
|
||||
MessageSender.TOOL,
|
||||
f"[Tool: {tool_call.get('name', 'unknown')}]",
|
||||
)
|
||||
self.state.status = "Running tool..."
|
||||
elif parsed.type == MessageType.TOOL_CALL_RESULT:
|
||||
try:
|
||||
tcr = json.loads(parsed.data)
|
||||
self.add_message(
|
||||
MessageSender.TOOL,
|
||||
f"[Result] {tcr.get('result', '')[:100]}...",
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
elif parsed.type == MessageType.REASONING:
|
||||
self._update_last_bot_message(
|
||||
f"[Thinking] {update.reasoning[-200:]}"
|
||||
)
|
||||
self.state.status = "Thinking..."
|
||||
elif parsed.type == MessageType.AGENT_STATS:
|
||||
self.state.status = (
|
||||
f"Tokens: {update.agent_stats.get('total_tokens', 0)}"
|
||||
)
|
||||
elif update.text:
|
||||
self._update_last_bot_message(update.text)
|
||||
|
||||
if is_complete:
|
||||
break
|
||||
|
||||
# Final status
|
||||
if update.reasoning:
|
||||
self.add_message(
|
||||
MessageSender.REASONING, f"[Reasoning]\n{update.reasoning}"
|
||||
)
|
||||
|
||||
for tool_display in update.get_tool_calls_display():
|
||||
self.add_message(MessageSender.TOOL, tool_display)
|
||||
|
||||
if update.error:
|
||||
self.add_message(MessageSender.SYSTEM, f"Error: {update.error}")
|
||||
|
||||
self.state.status = "Ready"
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.state.status = "Cancelled"
|
||||
except Exception as e:
|
||||
self.add_system_message(f"Error: {e}")
|
||||
self.state.status = f"Error: {e}"
|
||||
if self.debug:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def _process_parsed_message(self, msg: ParsedMessage) -> tuple[ChatResponse, bool]:
|
||||
"""Process a parsed message and return updated response state."""
|
||||
return self._parser.process_message(msg)
|
||||
|
||||
def _update_last_bot_message(self, text: str) -> None:
|
||||
"""Update the last bot message with new text (for streaming)."""
|
||||
for i in range(len(self.state.messages) - 1, -1, -1):
|
||||
if self.state.messages[i].sender == MessageSender.BOT:
|
||||
self.state.messages[i] = Message(
|
||||
sender=MessageSender.BOT,
|
||||
text=text,
|
||||
timestamp=self.state.messages[i].timestamp,
|
||||
)
|
||||
break
|
||||
else:
|
||||
self.add_message(MessageSender.BOT, text)
|
||||
|
||||
def render(self) -> None:
|
||||
"""Render the current state to the screen."""
|
||||
lines = [(msg.sender.value, msg.text) for msg in self.state.messages]
|
||||
|
||||
self.screen.draw_all(
|
||||
lines=lines,
|
||||
input_text=self.state.input_buffer,
|
||||
cursor_x=self.state.cursor_x,
|
||||
status=self.state.status,
|
||||
)
|
||||
|
||||
async def run_event_loop(self, stdscr: curses.window) -> None:
|
||||
"""Main event loop for the TUI."""
|
||||
# Setup
|
||||
self.screen.setup_colors()
|
||||
self.screen.layout_windows()
|
||||
|
||||
# Connect to AstrBot
|
||||
connected = await self.connect()
|
||||
if not connected:
|
||||
self.add_system_message(f"Failed to connect: {self.state.status}")
|
||||
else:
|
||||
self.add_system_message("Connected to AstrBot!")
|
||||
# Load history
|
||||
await self.load_history()
|
||||
|
||||
# Welcome message
|
||||
self.add_system_message("Type your message and press Enter to send.")
|
||||
self.add_system_message("Press ESC or Ctrl+C to exit.")
|
||||
|
||||
# Initial render
|
||||
self.render()
|
||||
|
||||
# Input loop
|
||||
while self.state.running:
|
||||
# Get input with timeout
|
||||
self.screen.input_win.nodelay(True)
|
||||
try:
|
||||
key = self.screen.input_win.getch()
|
||||
except curses.error:
|
||||
key = -1
|
||||
|
||||
if key != -1:
|
||||
if not self.handle_key(key):
|
||||
self.state.running = False
|
||||
break
|
||||
self.render()
|
||||
|
||||
# Small sleep to prevent CPU hogging
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Cleanup
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
def run_tui_async(
|
||||
debug: bool = False,
|
||||
host: str = "http://localhost:6185",
|
||||
api_key: str | None = None,
|
||||
username: str = "astrbot",
|
||||
password: str = "astrbot",
|
||||
) -> None:
|
||||
"""Entry point to run the TUI application."""
|
||||
from astrbot.tui.screen import run_curses
|
||||
|
||||
def main(stdscr: curses.window) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
scr = Screen(stdscr)
|
||||
client = TUIClient(
|
||||
screen=scr,
|
||||
host=host,
|
||||
api_key=api_key,
|
||||
username=username,
|
||||
password=password,
|
||||
debug=debug,
|
||||
)
|
||||
try:
|
||||
loop.run_until_complete(client.run_event_loop(stdscr))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
run_curses(main)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tui_async()
|
||||
@@ -1,285 +0,0 @@
|
||||
"""Internationalization support for AstrBot CLI.
|
||||
|
||||
This module provides i18n support with Chinese and English languages.
|
||||
Language is auto-detected from environment or can be set manually.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class Language(Enum):
|
||||
"""Supported languages."""
|
||||
|
||||
ZH = "zh"
|
||||
EN = "en"
|
||||
|
||||
|
||||
# Translation dictionaries
|
||||
_TRANSLATIONS: dict[Language, dict[str, str]] = {
|
||||
Language.ZH: {
|
||||
# CLI welcome and general
|
||||
"cli_welcome": "欢迎使用 AstrBot CLI!",
|
||||
"cli_version": "AstrBot CLI 版本: {version}",
|
||||
"cli_unknown_command": "未知命令: {command}",
|
||||
"cli_help_available": "使用 astrbot help --all 查看所有命令",
|
||||
# Dashboard commands
|
||||
"dashboard_bundled": "Dashboard 已打包在安装包中 - 跳过下载",
|
||||
"dashboard_not_installed": "Dashboard 未安装",
|
||||
"dashboard_install_confirm": "是否安装 Dashboard?",
|
||||
"dashboard_installing": "正在安装 Dashboard...",
|
||||
"dashboard_install_success": "Dashboard 安装成功",
|
||||
"dashboard_install_failed": "Dashboard 安装失败: {error}",
|
||||
"dashboard_not_needed": "Dashboard 不需要安装",
|
||||
"dashboard_declined": "Dashboard 安装已取消",
|
||||
"dashboard_already_up_to_date": "Dashboard 已是最新版本",
|
||||
"dashboard_version": "Dashboard 版本: {version}",
|
||||
"dashboard_download_failed": "Dashboard 下载失败: {error}",
|
||||
"dashboard_init_dir": "正在初始化 Dashboard 目录...",
|
||||
"dashboard_init_success": "Dashboard 初始化成功",
|
||||
# Plugin commands
|
||||
"plugin_installing": "正在安装插件: {name}",
|
||||
"plugin_install_success": "插件安装成功: {name}",
|
||||
"plugin_install_failed": "插件安装失败: {name}",
|
||||
"plugin_uninstall_confirm": "确定要卸载插件 {name} 吗?",
|
||||
"plugin_uninstall_success": "插件卸载成功: {name}",
|
||||
"plugin_uninstall_failed": "插件卸载失败: {name}",
|
||||
"plugin_list_empty": "未安装任何插件",
|
||||
"plugin_already_installed": "插件已安装: {name}",
|
||||
"plugin_not_found": "插件未找到: {name}",
|
||||
"plugin_already_exists": "插件已存在: {name}",
|
||||
"plugin_not_found_or_installed": "插件未找到或已安装: {name}",
|
||||
"plugin_uninstall_failed_ex": "插件卸载失败 {name}: {error}",
|
||||
"plugin_no_update_needed": "没有需要更新的插件",
|
||||
"plugin_found_update": "发现 {count} 个插件需要更新",
|
||||
"plugin_updating": "正在更新插件 {name}...",
|
||||
"plugin_search_no_result": "未找到匹配 '{query}' 的插件",
|
||||
"plugin_search_results": "搜索结果: '{query}'",
|
||||
# Config commands
|
||||
"config_show": "显示配置",
|
||||
"config_set_success": "配置项已更新: {key} = {value}",
|
||||
"config_set_failed": "配置项更新失败: {key}",
|
||||
"config_set_failed_ex": "设置配置失败: {error}",
|
||||
"config_get_success": "{key} = {value}",
|
||||
"config_get_not_found": "配置项未找到: {key}",
|
||||
"config_reset_confirm": "确定要重置所有配置吗?",
|
||||
"config_reset_success": "配置已重置",
|
||||
# Config validators
|
||||
"config_log_level_invalid": "日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
|
||||
"config_port_must_be_number": "端口必须是数字",
|
||||
"config_port_range_invalid": "端口必须在 1-65535 范围内",
|
||||
"config_username_empty": "用户名不能为空",
|
||||
"config_password_empty": "密码不能为空",
|
||||
"config_timezone_invalid": "无效的时区: {value}。请使用有效的 IANA 时区名称",
|
||||
"config_callback_invalid": "回调 API 基础路径必须以 http:// 或 https:// 开头",
|
||||
"config_key_unsupported": "不支持的配置项: {key}",
|
||||
"config_key_unknown": "未知的配置项: {key}",
|
||||
"config_updated": "配置已更新: {key}",
|
||||
# Init command
|
||||
"init_creating": "正在创建配置目录...",
|
||||
"init_created": "配置目录已创建: {path}",
|
||||
"init_copying": "正在复制配置文件...",
|
||||
"init_copied": "配置文件已复制",
|
||||
"init_success": "AstrBot 初始化完成!",
|
||||
"init_failed": "初始化失败: {error}",
|
||||
# Run command
|
||||
"run_starting": "正在启动 AstrBot...",
|
||||
"run_started": "AstrBot 已启动!",
|
||||
"run_backend_only": "以无界面模式启动",
|
||||
"run_failed": "启动失败: {error}",
|
||||
"run_stopped": "AstrBot 已停止",
|
||||
# TUI command
|
||||
"tui_starting": "正在启动 TUI...",
|
||||
"tui_started": "TUI 已启动",
|
||||
"tui_failed": "TUI 启动失败: {error}",
|
||||
# Common
|
||||
"yes": "是",
|
||||
"no": "否",
|
||||
"cancel": "取消",
|
||||
"confirm": "确认",
|
||||
"error": "错误",
|
||||
"success": "成功",
|
||||
"warning": "警告",
|
||||
"info": "信息",
|
||||
"loading": "加载中...",
|
||||
"done": "完成",
|
||||
"failed": "失败",
|
||||
"retry": "重试",
|
||||
"exit": "退出",
|
||||
"continue": "继续",
|
||||
},
|
||||
Language.EN: {
|
||||
# CLI welcome and general
|
||||
"cli_welcome": "Welcome to AstrBot CLI!",
|
||||
"cli_version": "AstrBot CLI version: {version}",
|
||||
"cli_unknown_command": "Unknown command: {command}",
|
||||
"cli_help_available": "Use astrbot help --all to see all commands",
|
||||
# Dashboard commands
|
||||
"dashboard_bundled": "Dashboard is bundled with the package - skipping download",
|
||||
"dashboard_not_installed": "Dashboard is not installed",
|
||||
"dashboard_install_confirm": "Install Dashboard?",
|
||||
"dashboard_installing": "Installing Dashboard...",
|
||||
"dashboard_install_success": "Dashboard installed successfully",
|
||||
"dashboard_install_failed": "Failed to install dashboard: {error}",
|
||||
"dashboard_not_needed": "Dashboard not needed",
|
||||
"dashboard_declined": "Dashboard installation declined.",
|
||||
"dashboard_already_up_to_date": "Dashboard is already up to date",
|
||||
"dashboard_version": "Dashboard version: {version}",
|
||||
"dashboard_download_failed": "Failed to download dashboard: {error}",
|
||||
"dashboard_init_dir": "Initializing dashboard directory...",
|
||||
"dashboard_init_success": "Dashboard initialized successfully",
|
||||
# Plugin commands
|
||||
"plugin_installing": "Installing plugin: {name}",
|
||||
"plugin_install_success": "Plugin installed successfully: {name}",
|
||||
"plugin_install_failed": "Failed to install plugin: {name}",
|
||||
"plugin_uninstall_confirm": "Uninstall plugin {name}?",
|
||||
"plugin_uninstall_success": "Plugin uninstalled successfully: {name}",
|
||||
"plugin_uninstall_failed": "Failed to uninstall plugin: {name}",
|
||||
"plugin_list_empty": "No plugins installed",
|
||||
"plugin_already_installed": "Plugin already installed: {name}",
|
||||
"plugin_not_found": "Plugin not found: {name}",
|
||||
"plugin_already_exists": "Plugin {name} already exists",
|
||||
"plugin_not_found_or_installed": "Plugin {name} not found or already installed",
|
||||
"plugin_uninstall_failed_ex": "Failed to uninstall plugin {name}: {error}",
|
||||
"plugin_no_update_needed": "No plugins need updating",
|
||||
"plugin_found_update": "Found {count} plugin(s) needing update",
|
||||
"plugin_updating": "Updating plugin {name}...",
|
||||
"plugin_search_no_result": "No plugins matching '{query}' found",
|
||||
"plugin_search_results": "Search results: '{query}'",
|
||||
# Config commands
|
||||
"config_show": "Show configuration",
|
||||
"config_set_success": "Configuration updated: {key} = {value}",
|
||||
"config_set_failed": "Failed to update configuration: {key}",
|
||||
"config_set_failed_ex": "Failed to set config: {error}",
|
||||
"config_get_success": "{key} = {value}",
|
||||
"config_get_not_found": "Configuration key not found: {key}",
|
||||
"config_reset_confirm": "Reset all configuration?",
|
||||
"config_reset_success": "Configuration reset",
|
||||
# Config validators
|
||||
"config_log_level_invalid": "Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
|
||||
"config_port_must_be_number": "Port must be a number",
|
||||
"config_port_range_invalid": "Port must be in range 1-65535",
|
||||
"config_username_empty": "Username cannot be empty",
|
||||
"config_password_empty": "Password cannot be empty",
|
||||
"config_timezone_invalid": "Invalid timezone: {value}. Please use a valid IANA timezone name",
|
||||
"config_callback_invalid": "Callback API base must start with http:// or https://",
|
||||
"config_key_unsupported": "Unsupported config key: {key}",
|
||||
"config_key_unknown": "Unknown config key: {key}",
|
||||
"config_updated": "Config updated: {key}",
|
||||
# Init command
|
||||
"init_creating": "Creating config directory...",
|
||||
"init_created": "Config directory created: {path}",
|
||||
"init_copying": "Copying config files...",
|
||||
"init_copied": "Config files copied",
|
||||
"init_success": "AstrBot initialized successfully!",
|
||||
"init_failed": "Initialization failed: {error}",
|
||||
# Run command
|
||||
"run_starting": "Starting AstrBot...",
|
||||
"run_started": "AstrBot started!",
|
||||
"run_backend_only": "Starting in backend-only mode",
|
||||
"run_failed": "Failed to start: {error}",
|
||||
"run_stopped": "AstrBot stopped",
|
||||
# TUI command
|
||||
"tui_starting": "Starting TUI...",
|
||||
"tui_started": "TUI started",
|
||||
"tui_failed": "Failed to start TUI: {error}",
|
||||
# Common
|
||||
"yes": "Yes",
|
||||
"no": "No",
|
||||
"cancel": "Cancel",
|
||||
"confirm": "Confirm",
|
||||
"error": "Error",
|
||||
"success": "Success",
|
||||
"warning": "Warning",
|
||||
"info": "Info",
|
||||
"loading": "Loading...",
|
||||
"done": "Done",
|
||||
"failed": "Failed",
|
||||
"retry": "Retry",
|
||||
"exit": "Exit",
|
||||
"continue": "Continue",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_current_language() -> Language:
|
||||
"""Get the current language based on environment or default.
|
||||
|
||||
Detection order:
|
||||
1. ASTRBOT_CLI_LANG environment variable (zh/en)
|
||||
2. LANG environment variable (if contains zh/cn)
|
||||
3. LC_ALL environment variable (if contains zh/cn)
|
||||
4. Default to Chinese (most users are Chinese)
|
||||
"""
|
||||
# Check explicit override first
|
||||
explicit = os.environ.get("ASTRBOT_CLI_LANG", "").lower()
|
||||
if explicit in ("zh", "en"):
|
||||
return Language.ZH if explicit == "zh" else Language.EN
|
||||
|
||||
# Check LANG/LC_ALL for Chinese
|
||||
for env_var in ("LANG", "LC_ALL"):
|
||||
lang = os.environ.get(env_var, "").lower()
|
||||
if "zh" in lang or "cn" in lang:
|
||||
return Language.ZH
|
||||
|
||||
# Default to Chinese for broader appeal
|
||||
return Language.ZH
|
||||
|
||||
|
||||
def set_language(lang: Language) -> None:
|
||||
"""Set the current language (clears all translation caches)."""
|
||||
get_current_language.cache_clear()
|
||||
_t_cached.cache_clear()
|
||||
# Set environment variable for persistence
|
||||
os.environ["ASTRBOT_CLI_LANG"] = lang.value
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _t_cached(key: str, lang: Language) -> str:
|
||||
"""Cached translation lookup."""
|
||||
return _TRANSLATIONS.get(lang, {}).get(key, key)
|
||||
|
||||
|
||||
def t(translation_key: str, **kwargs: str) -> str:
|
||||
"""Get translation for the given key in the current language.
|
||||
|
||||
Args:
|
||||
translation_key: Translation key (e.g., "cli_welcome", "plugin_installing")
|
||||
**kwargs: Format arguments for the translation string
|
||||
|
||||
Returns:
|
||||
Translated string, or the key itself if not found
|
||||
"""
|
||||
result = _t_cached(translation_key, get_current_language())
|
||||
if kwargs:
|
||||
result = result.format(**kwargs)
|
||||
return result
|
||||
|
||||
|
||||
def tr(key: str, **kwargs: str) -> str:
|
||||
"""Get translation (alias for t())."""
|
||||
return t(key, **kwargs)
|
||||
|
||||
|
||||
class CLITranslations:
|
||||
"""Translation accessor class for CLI contexts.
|
||||
|
||||
Usage:
|
||||
translations = CLITranslations()
|
||||
print(translations.cli_welcome)
|
||||
print(translations.plugin_installing(name="my_plugin"))
|
||||
"""
|
||||
|
||||
def __getattr__(self, key: str) -> str:
|
||||
return t(key)
|
||||
|
||||
def __call__(self, key: str, **kwargs: str) -> str:
|
||||
return t(key, **kwargs)
|
||||
|
||||
|
||||
# Convenience instance
|
||||
translations = CLITranslations()
|
||||
@@ -1,12 +1,18 @@
|
||||
from .dashboard import DashboardManager
|
||||
from .basic import (
|
||||
check_astrbot_root,
|
||||
check_dashboard,
|
||||
get_astrbot_root,
|
||||
)
|
||||
from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
__all__ = [
|
||||
"DashboardManager",
|
||||
"PluginStatus",
|
||||
"VersionComparator",
|
||||
"build_plug_list",
|
||||
"check_astrbot_root",
|
||||
"check_dashboard",
|
||||
"get_astrbot_root",
|
||||
"get_git_repo",
|
||||
"manage_plugin",
|
||||
]
|
||||
|
||||
91
astrbot/cli/utils/basic.py
Normal file
91
astrbot/cli/utils/basic.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from importlib import resources
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from astrbot.core.utils.astrbot_path import astrbot_paths
|
||||
|
||||
# Static assets bundled inside the installed wheel (built by hatch_build.py).
|
||||
# _BUNDLED_DIST = Path(__file__).parent.parent.parent / "dashboard" / "dist"
|
||||
_BUNDLED_DIST = resources.files("astrbot") / "dashboard" / "dist"
|
||||
|
||||
|
||||
def check_astrbot_root(path: str | Path) -> bool:
|
||||
"""Check if the path is an AstrBot root directory"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
if not path.exists() or not path.is_dir():
|
||||
return False
|
||||
if not (path / ".astrbot").exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_astrbot_root() -> Path:
|
||||
"""Get the AstrBot root directory path"""
|
||||
return astrbot_paths.root
|
||||
|
||||
|
||||
async def check_dashboard(astrbot_root: Path) -> None:
|
||||
"""Check if the dashboard is installed"""
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
# If the wheel ships bundled dashboard assets, no network download is needed.
|
||||
if _BUNDLED_DIST.is_dir():
|
||||
click.echo("Dashboard is bundled with the package – skipping download.")
|
||||
return
|
||||
|
||||
try:
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo("Dashboard is not installed")
|
||||
if click.confirm(
|
||||
"Install dashboard?",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("Installing dashboard...")
|
||||
try:
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root / "data"),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
click.echo("Dashboard installed successfully")
|
||||
except Exception as e:
|
||||
click.echo(f"Failed to install dashboard: {e}")
|
||||
|
||||
case str():
|
||||
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
|
||||
click.echo("Dashboard is already up to date")
|
||||
return
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"Dashboard version: {version}")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root / "data"),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"Failed to download dashboard: {e}")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo("Initializing dashboard directory...")
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "data" / "dashboard.zip"),
|
||||
extract_path=str(astrbot_root / "data"),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
click.echo("Dashboard initialized successfully")
|
||||
except Exception as e:
|
||||
click.echo(f"Failed to download dashboard: {e}")
|
||||
return
|
||||
@@ -1,79 +0,0 @@
|
||||
import sys
|
||||
from importlib import resources
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from astrbot.cli.i18n import t
|
||||
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
|
||||
class DashboardManager:
|
||||
_bundled_dist = resources.files("astrbot") / "dashboard" / "dist"
|
||||
|
||||
async def ensure_installed(self, astrbot_root: Path) -> None:
|
||||
"""Ensure the dashboard assets are installed and up to date."""
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
|
||||
if self._bundled_dist.is_dir():
|
||||
click.echo(t("dashboard_bundled"))
|
||||
return
|
||||
|
||||
try:
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo(t("dashboard_not_installed"))
|
||||
# Skip interactive prompt in non-interactive environments
|
||||
if not sys.stdin.isatty():
|
||||
click.echo(t("dashboard_not_needed"))
|
||||
return
|
||||
if click.confirm(t("dashboard_install_confirm"), default=True):
|
||||
click.echo(t("dashboard_installing"))
|
||||
try:
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root / "data"),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
click.echo(t("dashboard_install_success"))
|
||||
except Exception as e:
|
||||
click.echo(t("dashboard_install_failed", error=str(e)))
|
||||
else:
|
||||
click.echo(t("dashboard_declined"))
|
||||
|
||||
case str():
|
||||
if (
|
||||
VersionComparator.compare_version(VERSION, dashboard_version)
|
||||
<= 0
|
||||
):
|
||||
click.echo(t("dashboard_already_up_to_date"))
|
||||
return
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(t("dashboard_version", version=version))
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root / "data"),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(t("dashboard_download_failed", error=str(e)))
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo(t("dashboard_init_dir"))
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "data" / "dashboard.zip"),
|
||||
extract_path=str(astrbot_root / "data"),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
click.echo(t("dashboard_init_success"))
|
||||
except Exception as e:
|
||||
click.echo(t("dashboard_download_failed", error=str(e)))
|
||||
return
|
||||
@@ -66,16 +66,16 @@ pip_installer = PipInstaller(
|
||||
astrbot_config.get("pypi_index_url", None),
|
||||
)
|
||||
__all__ = [
|
||||
"DEMO_MODE",
|
||||
"AstrBotConfig",
|
||||
"LogBroker",
|
||||
"LogManager",
|
||||
"DEMO_MODE",
|
||||
"astrbot_config",
|
||||
"db_helper",
|
||||
"file_token_service",
|
||||
"t2i_base_url",
|
||||
"html_renderer",
|
||||
"logger",
|
||||
"pip_installer",
|
||||
"LogBroker",
|
||||
"LogManager",
|
||||
"db_helper",
|
||||
"sp",
|
||||
"t2i_base_url",
|
||||
"file_token_service",
|
||||
"pip_installer",
|
||||
]
|
||||
|
||||
@@ -212,7 +212,7 @@ class LLMSummaryCompressor:
|
||||
|
||||
# build payload
|
||||
instruction_message = Message(role="user", content=self.instruction_text)
|
||||
llm_payload = [*messages_to_summarize, instruction_message]
|
||||
llm_payload = messages_to_summarize + [instruction_message]
|
||||
|
||||
# generate summary
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from ..message import AudioURLPart, ImageURLPart, Message, TextPart, ThinkPart
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@@ -28,19 +28,9 @@ class TokenCounter(Protocol):
|
||||
...
|
||||
|
||||
|
||||
# 图片/音频 token 开销估算值,参考 OpenAI vision pricing:
|
||||
# low-res ~85 tokens, high-res ~170 per 512px tile, 通常几百到上千。
|
||||
# 这里取一个保守中位数,宁可偏高触发压缩也不要偏低导致 API 报错。
|
||||
IMAGE_TOKEN_ESTIMATE = 765
|
||||
AUDIO_TOKEN_ESTIMATE = 500
|
||||
|
||||
|
||||
class EstimateTokenCounter:
|
||||
"""Estimate token counter implementation.
|
||||
Provides a simple estimation of token count based on character types.
|
||||
|
||||
Supports multimodal content: images, audio, and thinking parts
|
||||
are all counted so that the context compressor can trigger in time.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
@@ -55,16 +45,12 @@ class EstimateTokenCounter:
|
||||
if isinstance(content, str):
|
||||
total += self._estimate_tokens(content)
|
||||
elif isinstance(content, list):
|
||||
# 处理多模态内容
|
||||
for part in content:
|
||||
if isinstance(part, TextPart):
|
||||
total += self._estimate_tokens(part.text)
|
||||
elif isinstance(part, ThinkPart):
|
||||
total += self._estimate_tokens(part.think)
|
||||
elif isinstance(part, ImageURLPart):
|
||||
total += IMAGE_TOKEN_ESTIMATE
|
||||
elif isinstance(part, AudioURLPart):
|
||||
total += AUDIO_TOKEN_ESTIMATE
|
||||
|
||||
# 处理 Tool Calls
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
|
||||
|
||||
@@ -12,74 +12,14 @@ class ContextTruncator:
|
||||
and len(message.tool_calls) > 0
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _split_system_rest(
|
||||
messages: list[Message],
|
||||
) -> tuple[list[Message], list[Message]]:
|
||||
"""Split messages into system messages and the rest.
|
||||
|
||||
Returns:
|
||||
tuple: (system_messages, non_system_messages)
|
||||
"""
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
return messages[:first_non_system], messages[first_non_system:]
|
||||
|
||||
@staticmethod
|
||||
def _ensure_user_message(
|
||||
system_messages: list[Message],
|
||||
truncated: list[Message],
|
||||
original_messages: list[Message],
|
||||
) -> list[Message]:
|
||||
"""Ensure the result always contains a `user` message immediately after
|
||||
system messages, as required by some LLM APIs.
|
||||
|
||||
Optimization strategy:
|
||||
- If `truncated` already begins with a `user` message, return it as-is.
|
||||
- If a `user` message exists later in `truncated`, move that message to
|
||||
be the first non-system message while preserving the relative order of
|
||||
the remaining truncated messages (without mutating the original list).
|
||||
- Otherwise, fall back to the first `user` message from
|
||||
`original_messages`.
|
||||
This reduces unnecessary duplication and ensures the required ordering.
|
||||
"""
|
||||
if truncated and truncated[0].role == "user":
|
||||
return system_messages + truncated
|
||||
|
||||
# If a user message exists inside the truncated list, promote it to the front.
|
||||
index_in_truncated = next(
|
||||
(i for i, m in enumerate(truncated) if m.role == "user"), None
|
||||
)
|
||||
if index_in_truncated is not None:
|
||||
# Build a new truncated list that places the found user message first,
|
||||
# preserving the order of the other messages and avoiding in-place mutation.
|
||||
user_msg = truncated[index_in_truncated]
|
||||
new_truncated = [
|
||||
user_msg,
|
||||
*truncated[:index_in_truncated],
|
||||
*truncated[index_in_truncated + 1 :],
|
||||
]
|
||||
return system_messages + new_truncated
|
||||
|
||||
# Fallback: find the first user message in the original messages.
|
||||
first_user = next((m for m in original_messages if m.role == "user"), None)
|
||||
if first_user is None:
|
||||
# No user messages at all; return system messages + whatever was truncated.
|
||||
return system_messages + truncated
|
||||
|
||||
return [*system_messages, first_user, *truncated]
|
||||
|
||||
def fix_messages(self, messages: list[Message]) -> list[Message]:
|
||||
"""Fix the message list to ensure the validity of tool call and tool response pairing.
|
||||
"""修复消息列表,确保 tool call 和 tool response 的配对关系有效。
|
||||
|
||||
This method ensures that:
|
||||
1. Each `tool` message is preceded by an `assistant` message containing `tool_calls`.
|
||||
2. Each `assistant` message containing `tool_calls` is followed by corresponding `
|
||||
此方法确保:
|
||||
1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息
|
||||
2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应
|
||||
|
||||
This is a requirement of the OpenAI Chat Completions API specification (Gemini enforces this strictly).
|
||||
这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
@@ -98,25 +38,24 @@ class ContextTruncator:
|
||||
|
||||
for msg in messages:
|
||||
if msg.role == "tool":
|
||||
# Only record tool responses when there is a pending assistant(tool_calls)
|
||||
# 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应
|
||||
if pending_assistant is not None:
|
||||
pending_tools.append(msg)
|
||||
# Isolated tool messages without a preceding assistant(tool_calls) are ignored
|
||||
# else: 孤立的 tool 消息,直接忽略
|
||||
continue
|
||||
|
||||
if self._has_tool_calls(msg):
|
||||
# When encountering a new assistant(tool_calls), first process the old pending chain
|
||||
# 遇到新的 assistant(tool_calls) 前,先处理旧的 pending 链
|
||||
flush_pending_if_valid()
|
||||
pending_assistant = msg
|
||||
continue
|
||||
|
||||
# Non-tool messages that do not contain tool_calls will break the pending chain.
|
||||
# Flush any pending chain first, then append the current message normally.
|
||||
# 非 tool,且不含 tool_calls 的消息
|
||||
# 先结束任何 pending 链,再正常追加
|
||||
flush_pending_if_valid()
|
||||
fixed_messages.append(msg)
|
||||
|
||||
# Flush the last pending chain at the end,
|
||||
# ensuring that any remaining valid assistant(tool_calls) and its tools are included in the final list.
|
||||
# 结束时处理最后一个 pending 链
|
||||
flush_pending_if_valid()
|
||||
|
||||
return fixed_messages
|
||||
@@ -127,23 +66,29 @@ class ContextTruncator:
|
||||
keep_most_recent_turns: int,
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Turn-based truncation strategy, which drops the oldest turns while keeping the most recent N turns.
|
||||
A turn consists of a user message and an assistant message.
|
||||
This method ensures that the truncated context list conforms to OpenAI's context format.
|
||||
"""截断上下文列表,确保不超过最大长度。
|
||||
一个 turn 包含一个 user 消息和一个 assistant 消息。
|
||||
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
|
||||
|
||||
Args:
|
||||
messages: The original list of messages in the context.
|
||||
keep_most_recent_turns: The number of most recent turns to keep. If set to -1, it means keeping all turns (no truncation).
|
||||
drop_turns: The number of turns to drop from the beginning.
|
||||
messages: 上下文列表
|
||||
keep_most_recent_turns: 保留最近的对话轮数
|
||||
drop_turns: 一次性丢弃的对话轮数
|
||||
|
||||
Returns:
|
||||
The truncated list of messages.
|
||||
截断后的上下文列表
|
||||
"""
|
||||
if keep_most_recent_turns == -1:
|
||||
return messages
|
||||
|
||||
system_messages, non_system_messages = self._split_system_rest(messages)
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= keep_most_recent_turns:
|
||||
return messages
|
||||
@@ -154,7 +99,7 @@ class ContextTruncator:
|
||||
else:
|
||||
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
|
||||
|
||||
# Find the first user message
|
||||
# 找到第一个 role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
|
||||
None,
|
||||
@@ -162,9 +107,8 @@ class ContextTruncator:
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
result = self._ensure_user_message(
|
||||
system_messages, truncated_contexts, messages
|
||||
)
|
||||
result = system_messages + truncated_contexts
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_dropping_oldest_turns(
|
||||
@@ -172,39 +116,53 @@ class ContextTruncator:
|
||||
messages: list[Message],
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""Drop the oldest N turns, regardless of the number of turns to keep."""
|
||||
"""丢弃最旧的 N 个对话轮次。"""
|
||||
if drop_turns <= 0:
|
||||
return messages
|
||||
|
||||
system_messages, non_system_messages = self._split_system_rest(messages)
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = non_system_messages[drop_turns * 2 :]
|
||||
|
||||
# Find the first user message
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
elif truncated_non_system:
|
||||
truncated_non_system = []
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
result = self._ensure_user_message(
|
||||
system_messages, truncated_non_system, messages
|
||||
)
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_halving(
|
||||
self,
|
||||
messages: list[Message],
|
||||
) -> list[Message]:
|
||||
"""Halve the number of messages, keeping the most recent ones."""
|
||||
"""对半砍策略,删除 50% 的消息"""
|
||||
if len(messages) <= 2:
|
||||
return messages
|
||||
|
||||
system_messages, non_system_messages = self._split_system_rest(messages)
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
messages_to_delete = len(non_system_messages) // 2
|
||||
if messages_to_delete == 0:
|
||||
@@ -212,7 +170,6 @@ class ContextTruncator:
|
||||
|
||||
truncated_non_system = non_system_messages[messages_to_delete:]
|
||||
|
||||
# Find the first user message
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
@@ -220,7 +177,6 @@ class ContextTruncator:
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
|
||||
result = self._ensure_user_message(
|
||||
system_messages, truncated_non_system, messages
|
||||
)
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
@@ -1,27 +1,8 @@
|
||||
"""
|
||||
MCP client - DEPRECATED
|
||||
|
||||
.. deprecated::
|
||||
This module has been moved to :mod:`astrbot._internal.mcp`.
|
||||
Please update your imports accordingly.
|
||||
|
||||
Old import (deprecated):
|
||||
from astrbot.core.agent.mcp_client import MCPClient, MCPTool
|
||||
|
||||
New import:
|
||||
from astrbot._internal.mcp import MCPClient, MCPTool
|
||||
|
||||
This file exists solely for backward compatibility and will be removed in a future version.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from typing import Any, Generic
|
||||
from typing import Generic
|
||||
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
@@ -31,20 +12,13 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
from .run_context import TContext
|
||||
from .tool import FunctionTool
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
warnings.warn(
|
||||
"astrbot.core.agent.mcp_client has been moved to astrbot._internal.mcp. "
|
||||
"Please update your imports.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
try:
|
||||
import anyio
|
||||
import mcp
|
||||
@@ -62,26 +36,6 @@ except (ModuleNotFoundError, ImportError):
|
||||
)
|
||||
|
||||
|
||||
class TenacityLogger:
|
||||
"""Wraps a logging.Logger to satisfy tenacity's LoggerProtocol."""
|
||||
|
||||
__slots__ = ("_logger",)
|
||||
_logger: logging.Logger
|
||||
|
||||
def __init__(self, logger: logging.Logger) -> None:
|
||||
self._logger = logger
|
||||
|
||||
def log(
|
||||
self,
|
||||
level: int,
|
||||
msg: str,
|
||||
/,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._logger.log(level, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""Prepare configuration, handle nested format"""
|
||||
if config.get("mcpServers"):
|
||||
@@ -91,22 +45,6 @@ def _prepare_config(config: dict) -> dict:
|
||||
return config
|
||||
|
||||
|
||||
def _prepare_stdio_env(config: dict) -> dict:
|
||||
"""Preserve Windows executable resolution for stdio subprocesses."""
|
||||
if sys.platform != "win32":
|
||||
return config
|
||||
|
||||
pathext = os.environ.get("PATHEXT")
|
||||
if not pathext:
|
||||
return config
|
||||
|
||||
prepared = config.copy()
|
||||
env = dict(prepared.get("env") or {})
|
||||
env.setdefault("PATHEXT", pathext)
|
||||
prepared["env"] = env
|
||||
return prepared
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""Quick test MCP server connectivity"""
|
||||
import aiohttp
|
||||
@@ -181,7 +119,6 @@ class MCPClient:
|
||||
self.tools: list[mcp.Tool] = []
|
||||
self.server_errlogs: list[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
self.process_pid: int | None = None
|
||||
|
||||
# Store connection config for reconnection
|
||||
self._mcp_server_config: dict | None = None
|
||||
@@ -189,24 +126,6 @@ class MCPClient:
|
||||
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
|
||||
self._reconnecting: bool = False # For logging and debugging
|
||||
|
||||
@staticmethod
|
||||
def _extract_stdio_process_pid(streams_context: object) -> int | None:
|
||||
"""Best-effort extraction for stdio subprocess PID used by lease cleanup.
|
||||
|
||||
TODO(refactor): replace this async-generator frame introspection with a
|
||||
stable MCP library hook once the upstream transport exposes process PID.
|
||||
"""
|
||||
generator = getattr(streams_context, "gen", None)
|
||||
frame = getattr(generator, "ag_frame", None)
|
||||
if frame is None:
|
||||
return None
|
||||
process = frame.f_locals.get("process")
|
||||
pid = getattr(process, "pid", None)
|
||||
try:
|
||||
return int(pid) if pid is not None else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:
|
||||
"""Connect to MCP server
|
||||
|
||||
@@ -222,7 +141,6 @@ class MCPClient:
|
||||
# Store config for reconnection
|
||||
self._mcp_server_config = mcp_server_config
|
||||
self._server_name = name
|
||||
self.process_pid = None
|
||||
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
@@ -232,7 +150,7 @@ class MCPClient:
|
||||
# Handle MCP service error logs
|
||||
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
|
||||
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
|
||||
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
|
||||
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
|
||||
self.server_errlogs.append(log_msg)
|
||||
|
||||
if "url" in cfg:
|
||||
@@ -265,7 +183,7 @@ class MCPClient:
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -291,12 +209,11 @@ class MCPClient:
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
cfg = _prepare_stdio_env(cfg)
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
@@ -311,7 +228,7 @@ class MCPClient:
|
||||
"alert",
|
||||
"emergency",
|
||||
):
|
||||
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
|
||||
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
|
||||
self.server_errlogs.append(log_msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
@@ -322,10 +239,9 @@ class MCPClient:
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
), # type: ignore
|
||||
),
|
||||
)
|
||||
self.process_pid = self._extract_stdio_process_pid(self._streams_context)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
@@ -416,7 +332,7 @@ class MCPClient:
|
||||
retry=retry_if_exception_type(anyio.ClosedResourceError),
|
||||
stop=stop_after_attempt(2),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=3),
|
||||
before_sleep=before_sleep_log(TenacityLogger(logger), logging.WARNING),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=True,
|
||||
)
|
||||
async def _call_with_retry():
|
||||
@@ -455,7 +371,6 @@ class MCPClient:
|
||||
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
self.process_pid = None
|
||||
|
||||
|
||||
class MCPTool(FunctionTool, Generic[TContext]):
|
||||
|
||||
@@ -16,7 +16,7 @@ class ContextWrapper(Generic[TContext]):
|
||||
context: TContext
|
||||
messages: list[Message] = Field(default_factory=list)
|
||||
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
|
||||
tool_call_timeout: int = 120 # Default tool call timeout in seconds
|
||||
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
||||
|
||||
|
||||
NoContext = ContextWrapper[None]
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import abc
|
||||
from collections.abc import AsyncGenerator
|
||||
import typing as T
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Generic
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.response import AgentResponse
|
||||
from astrbot.core.agent.run_context import ContextWrapper, TContext
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..response import AgentResponse
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
@@ -21,32 +19,13 @@ class AgentState(Enum):
|
||||
ERROR = auto() # Error state
|
||||
|
||||
|
||||
class BaseAgentRunner(Generic[TContext]):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
self.tasks: set = set()
|
||||
|
||||
class BaseAgentRunner(T.Generic[TContext]):
|
||||
@abc.abstractmethod
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
enforce_max_turns: int = -1,
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
truncate_turns: int = 1,
|
||||
custom_token_counter: Any = None,
|
||||
custom_compressor: Any = None,
|
||||
tool_schema_mode: str | None = "full",
|
||||
fallback_providers: list[Provider] | None = None,
|
||||
provider_config: dict | None = None,
|
||||
**kwargs: Any,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
"""Reset the agent to its initial state.
|
||||
This method should be called before starting a new run.
|
||||
@@ -54,14 +33,14 @@ class BaseAgentRunner(Generic[TContext]):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def step(self) -> AsyncGenerator[AgentResponse, None]:
|
||||
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""Process a single step of the agent."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def step_until_done(
|
||||
self, max_step: int
|
||||
) -> AsyncGenerator[AgentResponse, None]:
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""Process steps until the agent is done."""
|
||||
...
|
||||
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
import typing as T
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.response import AgentResponse, AgentResponseData
|
||||
from astrbot.core.agent.run_context import ContextWrapper, TContext
|
||||
from astrbot.core.agent.runners.base import AgentState, BaseAgentRunner
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ...hooks import BaseAgentRunHooks
|
||||
from ...response import AgentResponseData
|
||||
from ...run_context import ContextWrapper, TContext
|
||||
from ..base import AgentResponse, AgentState, BaseAgentRunner
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
@@ -33,45 +30,32 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
@override
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
enforce_max_turns: int = -1,
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
truncate_turns: int = 1,
|
||||
custom_token_counter: Any = None,
|
||||
custom_compressor: Any = None,
|
||||
tool_schema_mode: str | None = "full",
|
||||
fallback_providers: list[Provider] | None = None,
|
||||
provider_config: dict | None = None,
|
||||
**kwargs: Any,
|
||||
provider_config: dict,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = streaming
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
provider_config = provider_config or {}
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
self.bot_id = provider_config.get("bot_id", "")
|
||||
if not self.bot_id:
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||
|
||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||
("http://", "https://"),
|
||||
):
|
||||
raise Exception(
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
|
||||
)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
@@ -86,7 +70,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.file_id_cache: dict[str, dict[str, str]] = {}
|
||||
|
||||
@override
|
||||
async def step(self) -> AsyncGenerator[AgentResponse, None]:
|
||||
async def step(self):
|
||||
"""
|
||||
执行 Coze Agent 的一个步骤
|
||||
"""
|
||||
@@ -99,7 +83,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
|
||||
try:
|
||||
@@ -107,15 +91,15 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for response in self._execute_coze_request():
|
||||
yield response
|
||||
except Exception as e:
|
||||
logger.error(f"Coze 请求失败:{e!s}")
|
||||
logger.error(f"Coze 请求失败:{str(e)}")
|
||||
self._transition_state(AgentState.ERROR)
|
||||
self.final_llm_resp = LLMResponse(
|
||||
role="err", completion_text=f"Coze 请求失败:{e!s}"
|
||||
role="err", completion_text=f"Coze 请求失败:{str(e)}"
|
||||
)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"Coze 请求失败:{e!s}")
|
||||
chain=MessageChain().message(f"Coze 请求失败:{str(e)}")
|
||||
),
|
||||
)
|
||||
finally:
|
||||
@@ -123,8 +107,8 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
@override
|
||||
async def step_until_done(
|
||||
self, max_step: int
|
||||
) -> AsyncGenerator[AgentResponse, None]:
|
||||
self, max_step: int = 30
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
while not self.done():
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
@@ -168,7 +152,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 处理上下文中的图片
|
||||
content = ctx["content"]
|
||||
if isinstance(content, list):
|
||||
# 多模态内容,需要处理图片
|
||||
# 多模态内容,需要处理图片
|
||||
processed_content = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
@@ -293,7 +277,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
accumulated_content += content
|
||||
message_started = True
|
||||
|
||||
# 如果是流式响应,发送增量数据
|
||||
# 如果是流式响应,发送增量数据
|
||||
if self.streaming:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
@@ -344,7 +328,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
image_url: str,
|
||||
session_id: str | None = None,
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
import hashlib
|
||||
|
||||
# 计算哈希实现缓存
|
||||
@@ -365,7 +349,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
if session_id:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class CozeAPIClient:
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(
|
||||
@@ -75,7 +75,7 @@ class CozeAPIClient:
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"文件上传失败,状态码: {response.status}, 响应: {response_text}",
|
||||
f"文件上传失败,状态码: {response.status}, 响应: {response_text}",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -87,7 +87,7 @@ class CozeAPIClient:
|
||||
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
|
||||
|
||||
file_id = result["data"]["id"]
|
||||
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
||||
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
||||
return file_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
@@ -111,7 +111,7 @@ class CozeAPIClient:
|
||||
try:
|
||||
async with session.get(image_url) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"下载图片失败,状态码: {response.status}")
|
||||
raise Exception(f"下载图片失败,状态码: {response.status}")
|
||||
|
||||
image_data = await response.read()
|
||||
return image_data
|
||||
@@ -145,7 +145,7 @@ class CozeAPIClient:
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/chat"
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
payload = {
|
||||
"bot_id": bot_id,
|
||||
"user_id": user_id,
|
||||
"stream": stream,
|
||||
@@ -169,10 +169,10 @@ class CozeAPIClient:
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
||||
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
||||
|
||||
# SSE
|
||||
buffer = ""
|
||||
@@ -226,10 +226,10 @@ class CozeAPIClient:
|
||||
response_text = await response.text()
|
||||
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
|
||||
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
|
||||
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
@@ -299,6 +299,7 @@ if __name__ == "__main__":
|
||||
async with await anyio.open_file("README.md", "rb") as f:
|
||||
file_data = await f.read()
|
||||
file_id = await client.upload_file(file_data)
|
||||
print(f"Uploaded file_id: {file_id}")
|
||||
async for event in client.chat_messages(
|
||||
bot_id=bot_id,
|
||||
user_id="test_user",
|
||||
@@ -317,7 +318,7 @@ if __name__ == "__main__":
|
||||
],
|
||||
stream=True,
|
||||
):
|
||||
pass
|
||||
print(f"Event: {event}")
|
||||
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user