Compare commits

..

1 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
c46cec8742 Initial plan 2026-03-18 18:10:39 +00:00
664 changed files with 11009 additions and 41178 deletions

View File

@@ -1,184 +1,93 @@
# ==========================================
# AstrBot Instance Configuration: ${INSTANCE_NAME}
# AstrBot 实例配置文件:${INSTANCE_NAME}
# AstrBot Environment Configuration Example
# ==========================================
# 将此文件复制为 .env 并根据需要修改。
# Copy this file to .env and modify as needed.
# 注意:在此处设置的变量将覆盖默认配置。
# Note: Variables set here override application defaults.
# Copy this file to .env and adjust the values as needed.
# Note: Variables set here will override default configurations.
# ------------------------------------------
# 实例标识 / Instance Identity
# Core Configuration (核心配置)
# ------------------------------------------
# 实例名称(用于日志和服务名)
# Instance name (used in logs/service names)
INSTANCE_NAME="${INSTANCE_NAME}"
# AstrBot root directory path. Defaults to current working directory or ~/.astrbot for desktop client.
# ASTRBOT_ROOT=/path/to/astrbot
# ------------------------------------------
# 核心配置 / Core Configuration
# ------------------------------------------
# AstrBot 根目录路径
# AstrBot root directory path
# 默认 Default: 当前工作目录,桌面客户端为 ~/.astrbot服务器为 /var/lib/astrbot/<instance>/
# 示例 Example: /var/lib/astrbot/mybot
ASTRBOT_ROOT="${ASTRBOT_ROOT}"
# 日志等级
# Log level
# 可选值 Values: DEBUG, INFO, WARNING, ERROR, CRITICAL
# 默认 Default: INFO
# Log level. Options: DEBUG, INFO, WARNING, ERROR, CRITICAL. Default: INFO.
# ASTRBOT_LOG_LEVEL=INFO
# 启用插件热重载(开发时有用)
# Enable plugin hot reload (useful for development)
# 可选值 Values: 0 (禁用 disabled), 1 (启用 enabled)
# 默认 Default: 0
# Enable plugin auto-reload. Set to "1" to enable. Useful for development.
# ASTRBOT_RELOAD=0
# 禁用匿名使用统计
# Disable anonymous usage statistics
# 可选值 Values: 0 (启用统计 enabled), 1 (禁用统计 disabled)
# 默认 Default: 0
ASTRBOT_DISABLE_METRICS=0
# Disable metrics upload. Set to "1" to disable anonymous usage statistics.
# ASTRBOT_DISABLE_METRICS=0
# 覆盖 Python 可执行文件路径(用于本地代码执行功能)
# Override Python executable path (for local code execution)
# 示例 Example: /usr/bin/python3, /home/user/.pyenv/shims/python
# Python executable path override (used for local code execution feature).
# PYTHON=/usr/bin/python3
# 启用演示模式(可能限制部分功能)
# Enable demo mode (may restrict certain features)
# 可选值 Values: True, False
# 默认 Default: False
# Enable demo mode (might restrict some features).
# DEMO_MODE=False
# 启用测试模式(影响日志和部分行为)
# Enable testing mode (affects logging and behavior)
# 可选值 Values: True, False
# 默认 Default: False
# Enable testing mode (affects logging and some behaviors).
# TESTING=False
# 标记:是否通过桌面客户端执行(主要用于内部)
# Flag: running via desktop client (internal use)
# 可选值 Values: 0, 1
# Flag indicating execution via desktop client (Internal use mostly).
# ASTRBOT_DESKTOP_CLIENT=0
# 标记:是否通过 systemd 服务执行
# Flag: running via systemd service
# 可选值 Values: 0, 1
ASTRBOT_SYSTEMD=1
# Flag indicating execution via systemd service.
# ASTRBOT_SYSTEMD=0
# ------------------------------------------
# 管理面板配置 / Dashboard Configuration
# Dashboard Configuration (管理面板配置)
# ------------------------------------------
# 启用或禁用 WebUI 管理面板
# Enable or disable WebUI dashboard
# 可选值 Values: True, False
# 默认 Default: True
ASTRBOT_DASHBOARD_ENABLE=True
# Enable or disable the WebUI Dashboard. Default: True.
# ASTRBOT_DASHBOARD_ENABLE=True
# Dashboard bind host. Default: 0.0.0.0 (listen on all interfaces).
# ASTRBOT_DASHBOARD_HOST=0.0.0.0
# Dashboard bind port. Default: 6185.
# ASTRBOT_DASHBOARD_PORT=6185
# Enable SSL (HTTPS) for the dashboard.
# ASTRBOT_DASHBOARD_SSL_ENABLE=False
# SSL Certificate path (required if SSL is enabled).
# ASTRBOT_DASHBOARD_SSL_CERT=/path/to/cert.pem
# SSL Key path (required if SSL is enabled).
# ASTRBOT_DASHBOARD_SSL_KEY=/path/to/key.pem
# SSL CA Certificates path (optional).
# ASTRBOT_DASHBOARD_SSL_CA_CERTS=/path/to/ca.pem
# ------------------------------------------
# 国际化配置 / Internationalization Configuration
# Network Configuration (网络配置)
# ------------------------------------------
# CLI 界面语言
# CLI interface language
# 可选值 Values: zh (中文), en (英文)
# 默认 Default: zh (跟随系统 locale / follows system locale)
# ASTRBOT_CLI_LANG=zh
# TUI 界面语言
# TUI interface language
# 可选值 Values: zh (中文), en (英文)
# 默认 Default: zh
# ASTRBOT_TUI_LANG=zh
# ------------------------------------------
# 网络配置 / Network Configuration
# ------------------------------------------
# API 绑定主机
# API bind host
# 示例 Example: 0.0.0.0 (所有接口 all interfaces), 127.0.0.1 (仅本地 localhost only)
ASTRBOT_HOST="${ASTRBOT_HOST}"
# API 绑定端口
# API bind port
# 示例 Example: 3000, 6185, 8080
ASTRBOT_PORT="${ASTRBOT_PORT}"
# 是否为 API 启用 SSL/TLS
# Enable SSL/TLS for API
# 可选值 Values: true, false
# 默认 Default: false
ASTRBOT_SSL_ENABLE=false
# SSL 证书路径PEM 格式)
# SSL certificate path (PEM format)
# 示例 Example: /etc/astrbot/certs/myinstance/fullchain.pem
ASTRBOT_SSL_CERT=""
# SSL 私钥路径PEM 格式)
# SSL private key path (PEM format)
# 示例 Example: /etc/astrbot/certs/myinstance/privkey.pem
ASTRBOT_SSL_KEY=""
# SSL CA 证书链路径(可选,用于客户端验证)
# SSL CA certificates bundle (optional, for client verification)
# 示例 Example: /etc/ssl/certs/ca-certificates.crt
ASTRBOT_SSL_CA_CERTS=""
# ------------------------------------------
# 代理配置 / Proxy Configuration
# ------------------------------------------
# HTTP 代理地址
# HTTP proxy URL
# 示例 Example: http://127.0.0.1:7890, socks5://127.0.0.1:1080
# HTTP/HTTPS Proxy URL (e.g., http://127.0.0.1:7890).
# http_proxy=
# HTTPS 代理地址
# HTTPS proxy URL
# 示例 Example: http://127.0.0.1:7890, socks5://127.0.0.1:1080
# https_proxy=
# 不走代理的主机列表(逗号分隔)
# Hosts to bypass proxy (comma-separated)
# 示例 Example: localhost,127.0.0.1,192.168.0.0/16,.local
# No proxy list (comma-separated domains/IPs to bypass proxy).
# no_proxy=localhost,127.0.0.1
# ------------------------------------------
# 第三方集成 / Third-party Integrations
# Integrations (第三方集成)
# ------------------------------------------
# 阿里云 DashScope API 密钥(用于 Rerank 服务)
# Alibaba DashScope API Key (for Rerank service)
# 获取地址 Get from: https://dashscope.console.aliyun.com/
# 示例 Example: sk-xxxxxxxxxxxx
# DASHSCOPE_API_KEY=
# Alibaba DashScope API Key (used for Rerank service).
# DASHSCOPE_API_KEY=sk-xxxxxxxxxxxx
# Coze 集成
# Coze integration
# 获取地址 Get from: https://www.coze.com/
# Coze Integration
# COZE_API_KEY=
# COZE_BOT_ID=
# 计算机控制相关的数据目录(用于截图/文件存储)
# Computer control data directory (for screenshots/file storage)
# 示例 Example: /var/lib/astrbot/bay_data
# Computer Use data directory (for screenshot/file storage related to computer control).
# BAY_DATA_DIR=
# ------------------------------------------
# 平台特定配置 / Platform-specific Configuration
# Platform Specific (平台特定配置)
# ------------------------------------------
# QQ 官方机器人测试模式开关
# QQ official bot test mode
# 可选值 Values: on, off
# 默认 Default: off
# Test mode for QQ Official Bot.
# TEST_MODE=off
# End of template / 模板结束

View File

@@ -51,7 +51,7 @@ jobs:
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Setup pnpm
uses: pnpm/action-setup@v5.0.0
uses: pnpm/action-setup@v4.4.0
with:
version: 10.28.2

View File

@@ -1,37 +0,0 @@
name: Unit Tests
on:
push:
branches:
- master
paths-ignore:
- 'README*.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
unit-tests:
name: Run pytest suite
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install uv
run: |
python -m pip install --upgrade pip
python -m pip install uv
- name: Run tests
run: |
chmod +x scripts/run_pytests_ci.sh
bash ./scripts/run_pytests_ci.sh ./tests

15
.gitignore vendored
View File

@@ -59,22 +59,7 @@ CharacterModels/
GenieData/
.agent/
.codex/
.claude/
.opencode/
.kilocode/
.serena
.worktrees/
.astrbot_sdk_testing/
.env
dashboard/warker.js
dashboard/bun.lock
.pua/
# Rust build artifacts
rust/target/
# Build outputs
dist/
*.whl
*.so

View File

@@ -6,20 +6,20 @@ ci:
autoupdate_schedule: weekly
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.15.7
hooks:
# Run the linter.
- id: ruff-check
types_or: [python, pyi]
args: [--fix]
# Run the formatter.
- id: ruff-format
types_or: [python, pyi]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.2
hooks:
- id: pyupgrade
args: [--py312-plus]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.1
hooks:
# Run the linter.
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]

View File

@@ -29,21 +29,7 @@ Runs on `http://localhost:3000` by default.
4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`.
5. Use English for all new comments.
6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory.
7. Use Python 3.12+ type hinting syntax (e.g., `list[str]` over `List[str]`, `int | None` over `Optional[int]`). Avoid using `Any` and `cast()` - use proper TypedDict, dataclass, or Protocol instead. When encountering dict access issues (e.g., `msg.get("key")` where ty infers wrong type), define a TypedDict with `total=False` to explicitly declare allowed keys.
Good example:
```python
class MessageComponent(TypedDict, total=False):
type: str
text: str
path: str
```
Bad example (avoid):
```python
msg: Any = something
msg = cast(dict, msg)
```
7. Use Python 3.12+ type hinting syntax (e.g., `list[str]` over `List[str]`, `int | None` over `Optional[int]`). Avoid using `Any` and ensure comprehensive type annotations are provided.
8. When introducing new environment variables:
- Use the `ASTRBOT_` prefix for naming (e.g., `ASTRBOT_ENABLE_FEATURE`).
- Add the variable and description to `.env.example`.
@@ -51,9 +37,9 @@ Runs on `http://localhost:3000` by default.
- Add to the module docstring under "Environment Variables Used in Project".
- Add to the `keys_to_print` list in the `run` function for debug output.
9. To check all available CLI commands and their usage recursively, run `astrbot help --all`.
10. uv sync --group dev && uv run pytest --cov=astrbot tests/
## PR instructions
1. Title format: use conventional commit messages
2. Use English to write PR title and descriptions.
2. Use English to write PR title and descriptions.

180
CLAUDE.md
View File

@@ -1,180 +0,0 @@
# AstrBot - Claude Code Guidelines
AstrBot is an open-source, all-in-one Agentic personal and group chat assistant supporting multiple IM platforms (QQ, Telegram, Discord, etc.) and LLM providers.
## Project Overview
- **Main entry**: `astrbot/__main__.py` or via CLI `astrbot run`
- **CLI commands**: `astrbot/cli/commands/`
- **Core modules**: `astrbot/core/`
- **Platform adapters**: `astrbot/core/platform/sources/`
- **Star plugins**: `astrbot/builtin_stars/`
- **Dashboard**: `dashboard/` (Vue.js frontend)
## Development Setup
```bash
# Install dependencies
uv tool install -e . --force
# Initialize AstrBot
astrbot init
# Run development
astrbot run
# Backend only (no WebUI)
astrbot run --backend-only
# Dashboard frontend
cd dashboard && bun dev
# Run tests
uv sync --group dev && uv run pytest --cov=astrbot tests/
```
## Code Style
### Python
1. **Type hints required** - Use Python 3.12+ syntax:
- `list[str]` not `List[str]`
- `int | None` not `Optional[int]`
- Avoid `Any` when possible
2. **Path handling** - Always use `pathlib.Path`:
```python
from pathlib import Path
# Use astrbot.core.utils.path_utils for data/temp directories
from astrbot.core.utils.path_utils import get_astrbot_data_path
```
3. **Formatting** - Run before committing:
```bash
ruff format .
ruff check .
```
4. **Comments** - Use English for all comments and docstrings
5. **Imports** - Use absolute imports via `astrbot.` prefix
### Environment Variables
When adding new environment variables:
1. Use `ASTRBOT_` prefix: `ASTRBOT_ENABLE_FEATURE`
2. Add to `.env.example` with description
3. Update `astrbot/cli/commands/cmd_run.py`:
- Add to module docstring under "Environment Variables Used in Project"
- Add to `keys_to_print` list for debug output
## Architecture
### Core Components
- `astrbot/core/` - Core bot functionality
- `astrbot/core/platform/` - Platform adapter system
- `astrbot/core/agent/` - Agent execution logic
- `astrbot/core/star/` - Plugin/Star handler system
- `astrbot/core/pipeline/` - Message processing pipeline
- `astrbot/cli/` - Command-line interface
### Important Utilities
```python
from astrbot.core.utils.astrbot_path import (
get_astrbot_root, # AstrBot root directory
get_astrbot_data_path, # Data directory
get_astrbot_config_path, # Config directory
get_astrbot_plugin_path, # Plugin directory
get_astrbot_temp_path, # Temp directory
get_astrbot_skills_path, # Skills directory
)
```
### Platform Adapters
Platform adapters are in `astrbot/core/platform/sources/`:
- Each adapter extends base platform classes
- Use `@register_platform_adapter` decorator
- Events flow through `commit_event()` to message queue
### Star (Plugin) System
Stars are plugins in `astrbot/builtin_stars/`:
- Extend `Star` base class
- Use decorators for command handlers: `@star.on_command`, `@star.on_message`, etc.
- Access via `context` object
## Testing
1. Tests go in `tests/` directory
2. Use `pytest` with `pytest-asyncio`
3. Coverage target: `uv run pytest --cov=astrbot tests/`
4. Test files: `test_*.py` or `*_test.py`
## Git Conventions
### Commit Messages
Use conventional commits:
```
feat: add new feature
fix: resolve bug
docs: update documentation
refactor: restructure code
test: add tests
chore: maintenance tasks
```
### PR Guidelines
1. Title: conventional commit format
2. Description: English
3. Target branch: `dev`
4. Keep changes focused and atomic
## Project-Specific Guidelines
1. **No report files** - Do not add `xxx_SUMMARY.md` or similar
2. **Componentization** - Maintain clean code, avoid duplication in WebUI
3. **Backward compatibility** - When deprecating, add warnings
4. **CLI help** - Run `astrbot help --all` to see all commands
## File Organization
```
astrbot/
├── __main__.py # Main entry point
├── __init__.py # Package init, exports
├── cli/ # CLI commands
│ └── commands/ # Individual command modules
├── core/ # Core functionality
│ ├── agent/ # Agent execution
│ ├── platform/ # Platform adapters
│ ├── pipeline/ # Message processing
│ ├── star/ # Plugin system
│ └── config/ # Configuration
├── builtin_stars/ # Built-in plugins
├── dashboard/ # Vue.js frontend
└── utils/ # Utilities
```
## Common Tasks
### Adding a new platform adapter
1. Create adapter in `astrbot/core/platform/sources/`
2. Extend `Platform` base class
3. Use `@register_platform_adapter` decorator
4. Implement required methods: `run()`, `convert_message()`, `meta()`
### Adding a new command
1. Add to appropriate module in `cli/commands/`
2. Register with `@click.command()`
3. Update `astrbot/cli/__main__.py` to add command
### Adding a new Star handler
1. Create in `astrbot/builtin_stars/` or as plugin
2. Extend `Star` class
3. Use decorators: `@star.on_command()`, `@star.on_schedule()`, etc.

View File

@@ -31,8 +31,9 @@
<a href="https://astrbot.app/">Docs</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
<a href="mailto:community@astrbot.app">Email Support</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issues</a>
<a href="mailto:community@astrbot.app">Email</a>
</div>
AstrBot is an open-source, all-in-one Agentic personal and group chat assistant that can be deployed on dozens of mainstream instant messaging platforms such as QQ, Telegram, WeCom, Lark, DingTalk, Slack, and more. It also features a built-in lightweight ChatUI similar to OpenWebUI, creating a reliable and scalable conversational AI infrastructure for individuals, developers, and teams. Whether it's a personal AI companion, smart customer service, automated assistant, or enterprise knowledge base, AstrBot enables you to quickly build AI applications within the workflow of your instant messaging platforms.
@@ -198,13 +199,6 @@ Connect AstrBot to your favorite chat platforms.
| Minimax TTS | Text-to-Speech |
| Volcano Engine TTS | Text-to-Speech |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ Contribution
Welcome any Issues/Pull Requests! Just submit your changes to this project :)
@@ -304,4 +298,4 @@ _私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
</div>

View File

@@ -199,13 +199,6 @@ Connectez AstrBot à vos plateformes de chat préférées.
| Minimax TTS | Synthèse vocale (Text-to-Speech) |
| Volcengine TTS | Synthèse vocale (Text-to-Speech) |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ Contribution
Les Issues et Pull Requests sont les bienvenus ! Soumettez simplement vos modifications à ce projet :)
@@ -301,4 +294,4 @@ _私は、高性能ですから!_ (Je suis performant !)
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
</div>

View File

@@ -199,13 +199,6 @@ AstrBotを普段使用しているチャットプラットフォームに接続
| Minimax TTS | 音声合成 (TTS) |
| Volcengine TTS (火山エンジン) | 音声合成 (TTS) |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ 貢献
IssueやPull Requestは大歓迎です変更をこのプロジェクトに送信してください :)
@@ -301,4 +294,4 @@ _私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
</div>

View File

@@ -199,13 +199,6 @@ yay -S astrbot-git
| Minimax TTS | Синтез речи (TTS) |
| Volcengine TTS | Синтез речи (TTS) |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ Вклад в проект
Мы приветствуем любые Issues и Pull Requests! Просто отправьте свои изменения в этот проект :)
@@ -301,4 +294,4 @@ _私は、高性能ですから!_ (Я высокопроизводительны
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
</div>

View File

@@ -199,16 +199,9 @@ yay -S astrbot-git
| Minimax TTS | 文本轉語音 |
| 火山引擎 TTS | 文本轉語音 |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ 貢獻
歡迎任何 Issues/Pull Requests只需要將你的更改提交到此項目 :)
歡迎任何 Issues/Pull Requests只需要將你的更改提交到此項目 )
### 如何貢獻
@@ -301,4 +294,4 @@ _私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
</div>

View File

@@ -199,16 +199,9 @@ yay -S astrbot-git
| Minimax TTS | 文本转语音 |
| 火山引擎 TTS | 文本转语音 |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 :)
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
### 如何贡献

View File

@@ -1,16 +1,3 @@
from __future__ import annotations
from .core.log import LogManager
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from .core import logger as logger
__all__ = ["logger"]
def __getattr__(name: str) -> Any:
if name == "logger":
from .core import logger
return logger
raise AttributeError(name)
logger = LogManager.GetLogger(log_name="astrbot")

View File

@@ -7,6 +7,7 @@ from pathlib import Path
import anyio
import runtime_bootstrap
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.config.default import VERSION
from astrbot.core.initial_loader import InitialLoader
@@ -24,9 +25,8 @@ from astrbot.core.utils.io import (
download_dashboard,
get_dashboard_version,
)
from astrbot.runtime_bootstrap import initialize_runtime_bootstrap
initialize_runtime_bootstrap()
runtime_bootstrap.initialize_runtime_bootstrap()
# 将父目录添加到 sys.path
@@ -44,9 +44,9 @@ logo_tmpl = r"""
def check_env() -> None:
# Python version check: require 3.12 or 3.13
if not (sys.version_info.major == 3 and sys.version_info.minor in (12, 13)):
sys.exit(1)
if not (sys.version_info.major == 3 and sys.version_info.minor >= 10):
logger.error("请使用 Python3.10+ 运行本项目。")
exit()
astrbot_root = get_astrbot_root()
if astrbot_root not in sys.path:
@@ -76,7 +76,7 @@ async def check_dashboard_files(webui_dir: str | None = None):
if await anyio.Path(webui_dir).exists():
logger.info(f"使用指定的 WebUI 目录: {webui_dir}")
return webui_dir
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑")
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在将使用默认逻辑")
data_dist_path = os.path.join(get_astrbot_data_path(), "dist")
if await anyio.Path(data_dist_path).exists():
@@ -84,41 +84,41 @@ async def check_dashboard_files(webui_dir: str | None = None):
if v is not None:
# 存在文件
if v == f"v{VERSION}":
logger.info("WebUI 版本已是最新")
logger.info("WebUI 版本已是最新")
else:
logger.warning(
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符",
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符",
)
return data_dist_path
logger.info(
"开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下",
"开始下载管理面板文件...高峰期晚上可能导致较慢的速度如多次下载失败请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip并将其中的 dist 文件夹解压至 data 目录下",
)
try:
await download_dashboard(version=f"v{VERSION}", latest=False)
except Exception as e:
logger.warning(
f"下载指定版本(v{VERSION})的管理面板文件失败: {e},尝试下载最新版本"
f"下载指定版本(v{VERSION})的管理面板文件失败: {e}尝试下载最新版本"
)
try:
await download_dashboard(latest=True)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成")
logger.info("管理面板下载完成")
return data_dist_path
async def main_async(webui_dir_arg: str | None, log_broker: LogBroker) -> None:
async def main_async(webui_dir_arg: str | None) -> None:
"""主异步入口"""
# 检查仪表板文件
webui_dir = await check_dashboard_files(webui_dir_arg)
if webui_dir is None:
logger.warning(
"管理面板文件检查失败,WebUI 功能将不可用"
"请检查网络连接或手动指定 --webui-dir 参数"
"管理面板文件检查失败WebUI 功能将不可用"
"请检查网络连接或手动指定 --webui-dir 参数"
)
db = db_helper
@@ -148,4 +148,4 @@ if __name__ == "__main__":
LogManager.set_queue_handler(logger, log_broker)
# 只使用一次 asyncio.run()
asyncio.run(main_async(args.webui_dir, log_broker))
asyncio.run(main_async(args.webui_dir))

View File

@@ -1,5 +0,0 @@
"""
Astbot内部实现
外部模块请勿导入
"""

View File

@@ -1,57 +0,0 @@
"""
ABP (AstrBot Protocol) client - in-process star communication.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAstrbotAbpClient(ABC):
"""
ABP client: in-process star (plugin) communication.
Stars register themselves; client delegates calls to registered instances.
Subclass must implement:
- connect() -> None
- register_star(name, instance) -> None
- unregister_star(name) -> None
- call_star_tool(star, tool, args) -> Any
- shutdown() -> None
"""
@property
@abstractmethod
def connected(self) -> bool: ...
@abstractmethod
async def connect(self) -> None:
"""Lightweight: just sets connected=True."""
...
@abstractmethod
def register_star(self, star_name: str, star_instance: Any) -> None:
"""Add star to internal registry."""
...
@abstractmethod
def unregister_star(self, star_name: str) -> None:
"""Remove star from registry (idempotent)."""
...
@abstractmethod
async def call_star_tool(
self,
star_name: str,
tool_name: str,
arguments: dict[str, Any],
) -> Any:
"""Delegate to star_instance.call_tool(tool_name, arguments)."""
...
@abstractmethod
async def shutdown(self) -> None:
"""Set connected=False, cancel pending requests."""
...

View File

@@ -1,66 +0,0 @@
"""
ACP (AstrBot Communication Protocol) client.
Transport: TCP | Unix Socket
Messages: JSON with Content-Length header
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAstrbotAcpClient(ABC):
"""
ACP client: connects to ACP servers via TCP or Unix socket.
Subclass must implement:
- connect() -> None
- connect_to_server(host, port) -> None
- connect_to_unix_socket(path) -> None
- call_tool(server, tool, args) -> Any
- send_notification(method, params) -> None
- shutdown() -> None
"""
@property
@abstractmethod
def connected(self) -> bool: ...
@abstractmethod
async def connect(self) -> None: ...
@abstractmethod
async def connect_to_server(self, host: str, port: int) -> None:
"""Connect via TCP."""
...
@abstractmethod
async def connect_to_unix_socket(self, socket_path: str) -> None:
"""Connect via Unix domain socket."""
...
@abstractmethod
async def call_tool(
self,
server_name: str,
tool_name: str,
arguments: dict[str, Any],
) -> Any:
"""Call tool on server, return result."""
...
@abstractmethod
async def send_notification(
self,
method: str,
params: dict[str, Any],
) -> None:
"""Send one-way notification."""
...
@abstractmethod
async def shutdown(self) -> None:
"""Close connection, cancel pending requests."""
...

View File

@@ -1,68 +0,0 @@
"""
ACP (AstrBot Communication Protocol) server.
Transport: TCP listening socket
Messages: JSON with Content-Length header
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
class BaseAstrbotAcpServer(ABC):
"""
ACP server: listens for client connections, exposes tools.
Subclass must implement:
- start(host, port) -> None
- register_tool(name, handler) -> None
- register_notification_handler(name, handler) -> None
- broadcast_notification(method, params) -> None
- shutdown() -> None
"""
@property
@abstractmethod
def running(self) -> bool:
"""True if server is accepting connections."""
...
@abstractmethod
async def start(self, host: str = "127.0.0.1", port: int = 8765) -> None:
"""Bind and listen. Block until shutdown."""
...
@abstractmethod
def register_tool(
self,
name: str,
handler: Callable[..., Any],
) -> None:
"""Register async tool handler (receives params dict, returns result)."""
...
@abstractmethod
def register_notification_handler(
self,
name: str,
handler: Callable[..., Any],
) -> None:
"""Register async notification handler (receives params dict)."""
...
@abstractmethod
async def broadcast_notification(
self,
method: str,
params: dict[str, Any],
) -> None:
"""Send notification to all connected clients."""
...
@abstractmethod
async def shutdown(self) -> None:
"""Stop accepting, close all client connections."""
...

View File

@@ -1,73 +0,0 @@
"""
AstrBot Gateway - HTTP/WebSocket API server.
Built on FastAPI, provides:
- HTTP REST API (stats, inspector, config)
- WebSocket for real-time events
- Static file serving (dashboard)
- Authentication (JWT/API key)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
class BaseAstrbotGateway(ABC):
"""
Gateway: HTTP/WebSocket server built on FastAPI.
┌─────────────────────────────────────────────────────────┐
│ FastAPI App │
├─────────────────────────────────────────────────────────┤
│ REST Endpoints WebSocket │
│ ├─ GET /api/stats ├─ /ws (connection manager)│
│ ├─ GET /api/inspector/* │ │
│ ├─ GET /api/memory/* │ │
│ └─ ... │ │
│ │
│ Middleware: CORS, Auth, Logging │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────┐
│ Orchestrator │
│ (owns protocol clients)│
└─────────────────────────┘
Routes (typical):
GET / → Dashboard static files
GET /api/stats → System statistics
GET /api/inspector/stars → List registered stars
WS /ws → WebSocket for real-time events
serve() Lifecycle:
1. Create FastAPI app
2. Register routes
3. Start WebSocket manager
4. Bind to host:port
5. Run ASGI server (uvicorn/hypercorn)
6. Block until shutdown
7. Close all connections
Subclass must implement:
- serve(): start server, block until shutdown
"""
@abstractmethod
async def serve(self) -> None:
"""
Start gateway server - blocks until shutdown.
Should:
1. Create FastAPI app with routes
2. Configure CORS, auth middleware
3. Start WebSocket connection manager
4. Bind to ASTRBOT_PORT (default 6185)
5. Run ASGI server
6. Handle graceful shutdown on SIGTERM/SIGINT
Raises:
OSError: address already in use
"""
...

View File

@@ -1,352 +0,0 @@
"""
AstrBot Orchestrator - core runtime lifecycle manager.
Architecture
============
┌─────────────────────────────────────────────────────┐
│ Orchestrator │
│ (owns lifecycle of all protocol clients + stars) │
└─────────────────────────────────────────────────────┘
┌──────────────┼──────────────┐
▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────┐
│ LSP │ │ MCP │ │ ACP │
│ Client │ │ Client │ │ Client │
└─────────┘ └─────────┘ └─────────┘
│ │ │
▼ ▼ ▼
LSP Servers MCP Servers ACP Services
┌─────────────────────────────────────────────────────┐
│ ABP Client │
│ (in-process star registry) │
└─────────────────────────────────────────────────────┘
┌─────────┐
│ Stars │
│(Plugins) │
└─────────┘
Lifecycle State Machine
=======================
States:
┌─────────┐
│ INIT │───► orchestrator created, clients not initialized
└────┬────┘
│ start()
┌─────────┐
│ RUNNING │◄─── run_loop() executing
└────┬────┘
│ shutdown()
┌──────────┐
│ SHUTDOWN │─── all clients closed, ready for GC
└──────────┘
Transitions:
INIT + start() ──► RUNNING
RUNNING + shutdown() ──► SHUTDOWN
For each protocol client, the orchestrator:
1. Creates instance in __init__
2. Calls connect() to initialize
3. Calls protocol-specific setup (connect_to_server, etc)
4. Manages via run_loop() heartbeat
5. Calls shutdown() on final cleanup
Star Registration Flow
=====================
orchestrator.register_star("my-star", MyStar())
┌───────────────────┐
│ ABP Client │
│ .register_star() │
└───────────────────┘
┌───────────────────┐
│ Internal dict │
{"my-star": obj} │
└───────────────────┘
Message Routing (conceptual)
===========================
External Tool Call
┌──────────────┐ list_tools() ┌──────────────┐
│ MCP Client │────────────────────►│ MCP Server │
└──────────────┘◄────────────────────└──────────────┘
│ tool result
┌──────────────┐ call_tool() ┌──────────────┐
│ ABP │────────────────────►│ Star │
│ Client │◄────────────────────└──────────────┘
└──────────────┘ tool result
Return to caller
run_loop() Responsibilities
===========================
while running:
│─ check LSP server health (ping/heartbeat)
│─ check MCP session status (reconnect if needed)
│─ check ACP client connections
│─ process any pending star notifications
│─ sleep(SLEEP_INTERVAL)
Shutdown Sequence
==================
shutdown()
├─ set _running = False
├─ LSP.shutdown()
│ └─ send "shutdown" request
│ └─ terminate subprocess
├─ ACP.shutdown()
│ └─ close TCP/Unix connections
├─ ABP.shutdown()
│ └─ cancel pending requests
└─ MCP.cleanup()
└─ close all sessions
└─ cleanup subprocesses
Exception Handling
==================
Each protocol client should:
- Catch connection errors
- Attempt reconnection with exponential backoff
- Log errors but don't crash run_loop
- Raise on irrecoverable failures
The orchestrator run_loop should:
- Catch CancelledError on shutdown
- Catch Exception and log (don't crash)
- Ensure cleanup runs in finally block
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
from astrbot._internal.protocols.lsp.client import AstrbotLspClient
from astrbot._internal.protocols.mcp.client import McpClient
#: Default heartbeat interval for run_loop()
DEFAULT_SLEEP_INTERVAL: float = 5.0
class BaseAstrbotOrchestrator(ABC):
"""
Core runtime: owns lifecycle of all protocol clients and stars.
┌────────────────────────────────────────────────────────────┐
│ Protocol Clients (always present, never None after init) │
├────────────────────────────────────────────────────────────┤
│ lsp: Language Server Protocol │
│ Purpose: code completion, diagnostics, hover, etc │
│ Transport: stdio subprocess │
│ │
│ mcp: Model Context Protocol │
│ Purpose: external tool access │
│ Transport: stdio | SSE | HTTP │
│ │
│ acp: AstrBot Communication Protocol │
│ Purpose: inter-service communication │
│ Transport: TCP | Unix Socket │
│ │
│ abp: AstrBot Protocol │
│ Purpose: in-process star (plugin) communication │
│ Transport: direct method calls │
└────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────┐
│ Star Registry │
├────────────────────────────────────────────────────────────┤
│ _stars: dict[str, Any] │
│ Stars are plugins registered by name │
│ ABP client delegates calls to registered stars │
└────────────────────────────────────────────────────────────┘
Subclass must implement:
- __init__(): create all protocol client instances
- run_loop(): main event loop (block until shutdown)
- register_star(name, instance): add to registry + ABP
- unregister_star(name): remove from registry + ABP
- shutdown(): clean up all clients
"""
#: LSP client for language intelligence
lsp: AstrbotLspClient
#: MCP client for external tools
mcp: McpClient
#: ACP client for inter-service communication
acp: AstrbotAcpClient
#: ABP client for in-process star communication
abp: AstrbotAbpClient
def __init__(self) -> None:
"""
Initialize orchestrator and all protocol clients.
After __init__, all clients exist but are not connected.
Call start() or run_loop() to begin operation.
Example:
class MyOrchestrator(BaseAstrbotOrchestrator):
def __init__(self):
self.lsp = AstrbotLspClient()
self.mcp = McpClient()
self.acp = AstrbotAcpClient()
self.abp = AstrbotAbpClient()
self._stars: dict[str, Any] = {}
self._running = False
"""
self._stars: dict[str, Any] = {}
self._running: bool = False
@property
def running(self) -> bool:
"""True if run_loop() is executing."""
return self._running
@abstractmethod
async def start(self) -> None:
"""
Initialize all protocol clients.
Called once before run_loop(). Should:
1. Call lsp.connect()
2. Call mcp.connect()
3. Call acp.connect()
4. Call abp.connect()
5. Set _running = True
Raises:
Exception: if any client fails to initialize
"""
...
@abstractmethod
async def run_loop(self) -> None:
"""
Main event loop - blocks until shutdown.
Execution:
self._running = True
try:
while self._running:
await self._heartbeat()
await anyio.sleep(DEFAULT_SLEEP_INTERVAL)
except asyncio.CancelledError:
pass # shutdown requested
finally:
self._running = False
_heartbeat() responsibilities:
- Check LSP server health (optional ping)
- Check MCP session status, reconnect if needed
- Check ACP connections
- Process any pending star notifications
Raises:
asyncio.CancelledError: when shutdown() called
Note:
Subclass defines _heartbeat() for periodic tasks.
This method only handles the loop control.
"""
...
@abstractmethod
async def register_star(self, name: str, star_instance: Any) -> None:
"""
Register a star (plugin) with the orchestrator.
Args:
name: Unique identifier for the star
instance: Star plugin instance (must have .call_tool() method)
Does:
self._stars[name] = star_instance
self.abp.register_star(name, star_instance)
Raises:
ValueError: if name already registered
"""
...
@abstractmethod
async def unregister_star(self, name: str) -> None:
"""
Unregister a star (plugin) from the orchestrator.
Args:
name: Identifier of star to remove
Does:
del self._stars[name]
self.abp.unregister_star(name)
Note:
Idempotent - does nothing if name not found.
"""
...
@abstractmethod
async def get_star(self, name: str) -> Any | None:
"""Get registered star by name. Returns None if not found."""
...
@abstractmethod
async def list_stars(self) -> list[str]:
"""Return list of registered star names."""
...
@abstractmethod
async def shutdown(self) -> None:
"""
Graceful shutdown of orchestrator and all clients.
Execution order:
1. self._running = False (stop run_loop)
2. await lsp.shutdown()
3. await acp.shutdown()
4. await abp.shutdown()
5. await mcp.cleanup()
Does NOT unregister stars - caller should do that first.
After shutdown, orchestrator is ready for garbage collection.
"""
...

View File

@@ -1,114 +0,0 @@
"""
LSP (Language Server Protocol) client.
Transport: stdio subprocess
Messages: JSON-RPC 2.0 with Content-Length header
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
pass
class LspMessage:
"""JSON-RPC 2.0 message."""
jsonrpc: str = "2.0"
id: int | str | None = None
method: str | None = None
params: dict[str, Any] | None = None
result: Any = None
error: dict[str, Any] | None = None
class LspRequest(LspMessage):
"""Outgoing request."""
def __init__(self, method: str, params: dict[str, Any] | None = None) -> None:
self.id = id(self)
self.method = method
self.params = params
class LspResponse(LspMessage):
"""Incoming response."""
class LspNotification(LspMessage):
"""Incoming notification (no id)."""
class BaseAstrbotLspClient(ABC):
"""
LSP client: connects to LSP servers via stdio subprocess.
Subclass must implement:
- connect() -> None
- connect_to_server(command, workspace_uri) -> None
- send_request(method, params) -> dict
- send_notification(method, params) -> None
- shutdown() -> None
"""
@property
@abstractmethod
def connected(self) -> bool:
"""True if connected to an LSP server."""
...
@abstractmethod
async def connect(self) -> None:
self._connected = False
...
@abstractmethod
async def connect_to_server(
self,
command: list[str],
workspace_uri: str,
) -> None:
"""
Start LSP server subprocess and complete handshake.
Steps:
1. Spawn subprocess with stdin/stdout pipes
2. Send initialize request
3. Wait for response
4. Send initialized notification
"""
...
@abstractmethod
async def send_request(
self,
method: str,
params: dict[str, Any] | None = None,
) -> Any:
"""
Send JSON-RPC request and return result.
Raises:
RuntimeError: not connected
Exception: server returned error
"""
...
@abstractmethod
async def send_notification(
self,
method: str,
params: dict[str, Any] | None = None,
) -> None:
"""
Send JSON-RPC notification (no response expected).
"""
...
@abstractmethod
async def shutdown(self) -> None:
"""Send shutdown, terminate subprocess, cleanup."""
...

View File

@@ -1,95 +0,0 @@
"""
MCP (Model Context Protocol) client.
Transport: stdio | SSE | streamable_http
Messages: JSON-RPC 2.0
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Literal, TypedDict
if TYPE_CHECKING:
pass
class McpServerConfig(TypedDict, total=False):
"""MCP server configuration."""
# Stdio transport
command: str
args: list[str]
env: dict[str, str]
cwd: str
# HTTP transport
url: str
headers: dict[str, str]
transport: Literal["sse", "streamable_http"]
class McpToolInfo(TypedDict):
"""MCP tool descriptor."""
name: str
description: str
inputSchema: dict[str, Any]
class BaseAstrbotMcpClient(ABC):
"""
MCP client: connects to MCP servers for external tools.
Subclass must implement:
- connect() -> None
- connect_to_server(config, name) -> None
- list_tools() -> list[McpToolInfo]
- call_tool(name, args, timeout) -> CallToolResult
- cleanup() -> None
"""
session: Any # mcp.ClientSession
@property
@abstractmethod
def connected(self) -> bool: ...
@abstractmethod
async def connect(self) -> None:
"""Initialize client session."""
...
@abstractmethod
async def connect_to_server(
self,
config: McpServerConfig,
name: str,
) -> None:
"""
Connect to MCP server.
Stdio: {"command": "python", "args": ["server.py"], "env": {...}}
HTTP: {"url": "https://...", "transport": "sse"}
"""
...
@abstractmethod
async def list_tools(self) -> list[McpToolInfo]:
"""Call tools/list and return tools."""
...
@abstractmethod
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
read_timeout_seconds: int = 60,
) -> Any:
"""Call tools/call with reconnection support."""
...
@abstractmethod
async def cleanup(self) -> None:
"""Close all server connections."""
...

View File

@@ -1,6 +0,0 @@
"""Gateway module - FastAPI server for the dashboard backend."""
from .server import AstrbotGateway
from .ws_manager import WebSocketManager
__all__ = ["AstrbotGateway", "WebSocketManager"]

View File

@@ -1,4 +0,0 @@
"""
依赖注入
"""

View File

@@ -1,248 +0,0 @@
"""
AstrBot Gateway - FastAPI server for the dashboard backend.
Provides REST API endpoints and WebSocket connections for the frontend dashboard.
The gateway acts as the communication bridge between the dashboard and the orchestrator.
"""
from __future__ import annotations
import json
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, cast
from astrbot import logger
from astrbot._internal.abc.base_astrbot_gateway import BaseAstrbotGateway
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
from astrbot._internal.geteway.ws_manager import WebSocketManager
if TYPE_CHECKING:
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
else:
try:
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
except ImportError:
logger.warning("FastAPI not installed, gateway unavailable.")
FastAPI = cast(Any, None)
WebSocket = cast(Any, None)
WebSocketDisconnect = cast(Any, None)
CORSMiddleware = cast(Any, None)
log = logger
class AstrbotGateway(BaseAstrbotGateway):
"""
FastAPI-based gateway server for AstrBot.
Handles:
- REST API endpoints for configuration and stats
- WebSocket connections for real-time communication
- CORS middleware for dashboard access
"""
def __init__(self, orchestrator: BaseAstrbotOrchestrator) -> None:
self.orchestrator = orchestrator
self.ws_manager = WebSocketManager()
self._app: FastAPI | None = None
self._host = "0.0.0.0"
self._port = 8765
async def serve(self) -> None:
"""
Start the gateway server.
Creates and runs a FastAPI application with WebSocket support.
"""
if FastAPI is None:
raise RuntimeError("FastAPI is not installed")
log.info(f"Starting AstrBot Gateway on {self._host}:{self._port}")
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
log.info("Gateway server started.")
yield
# Shutdown
await self.ws_manager.broadcast({"type": "server_shutdown"})
log.info("Gateway server stopped.")
self._app = FastAPI(
title="AstrBot Gateway",
description="Backend API for AstrBot dashboard",
version="1.0.0",
lifespan=lifespan,
)
# CORS middleware
self._app.add_middleware(
cast(Any, CORSMiddleware),
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
self._setup_routes()
# Run with uvicorn
import uvicorn
config = uvicorn.Config(
self._app,
host=self._host,
port=self._port,
log_level="info",
)
server = uvicorn.Server(config)
await server.serve()
def _setup_routes(self) -> None:
"""Set up API routes."""
if self._app is None:
return
from fastapi import APIRouter
# Health check
@self._app.get("/health")
async def health():
return {"status": "ok"}
# WebSocket endpoint
@self._app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
await self.ws_manager.connect(ws)
try:
while True:
data = await ws.receive_text()
try:
message = json.loads(data)
response = await self._handle_ws_message(message)
if response:
await ws.send_json(response)
except json.JSONDecodeError:
await ws.send_json({"error": "Invalid JSON"})
except WebSocketDisconnect:
self.ws_manager.disconnect(ws)
# Stats router
stats_router = APIRouter(prefix="/api/stats", tags=["stats"])
@stats_router.get("/overview")
async def get_overview():
return await self._get_stats_overview()
self._app.include_router(stats_router)
# Inspector router
inspector_router = APIRouter(prefix="/api/inspector", tags=["inspector"])
@inspector_router.get("/stars")
async def list_stars():
return await self._list_stars()
@inspector_router.get("/stars/{star_name}")
async def get_star(star_name: str):
return await self._get_star_detail(star_name)
self._app.include_router(inspector_router)
# Memory router
memory_router = APIRouter(prefix="/api/memory", tags=["memory"])
@memory_router.get("/")
async def get_memory():
return await self._get_memory_info()
self._app.include_router(memory_router)
async def _handle_ws_message(
self, message: dict[str, Any]
) -> dict[str, Any] | None:
"""
Handle an incoming WebSocket message.
Args:
message: Parsed JSON message from the client
Returns:
Response message to send back, or None for no response
"""
msg_type = message.get("type")
data = message.get("data", {})
if msg_type == "ping":
return {"type": "pong", "data": {}}
if msg_type == "call_tool":
return await self._handle_call_tool(data)
if msg_type == "get_stars":
return {"type": "stars_list", "data": await self._list_stars()}
return {
"type": "error",
"data": {"message": f"Unknown message type: {msg_type}"},
}
async def _handle_call_tool(self, data: dict[str, Any]) -> dict[str, Any]:
"""Handle a tool call request via WebSocket."""
star_name = data.get("star")
tool_name = data.get("tool")
arguments = data.get("arguments", {})
if not star_name or not tool_name:
return {
"type": "tool_result",
"data": {"error": "Missing star or tool name"},
}
try:
result = await self.orchestrator.abp.call_star_tool(
star_name, tool_name, arguments
)
return {"type": "tool_result", "data": {"result": result}}
except Exception as e:
return {"type": "tool_result", "data": {"error": str(e)}}
async def _get_stats_overview(self) -> dict[str, Any]:
"""Get overview statistics."""
return {
"stars_count": len(self.orchestrator.abp._stars),
"lsp_connected": self.orchestrator.lsp._connected,
"mcp_sessions": getattr(self.orchestrator.mcp, "session", None) is not None,
"acp_clients": len(getattr(self.orchestrator.acp, "_clients", [])),
}
async def _list_stars(self) -> list[dict[str, Any]]:
"""List all registered stars."""
stars = []
for name in self.orchestrator.abp._stars:
stars.append({"name": name, "status": "active"})
return stars
async def _get_star_detail(self, star_name: str) -> dict[str, Any]:
"""Get details of a specific star."""
star = self.orchestrator.abp._stars.get(star_name)
if not star:
return {"error": f"Star '{star_name}' not found"}
return {"name": star_name, "status": "active"}
async def _get_memory_info(self) -> dict[str, Any]:
"""Get memory usage information."""
import gc
gc.collect()
return {
"gc_objects": len(gc.get_objects()),
"python_memory": "N/A", # Would need psutil for actual values
}
def set_listen_address(self, host: str, port: int) -> None:
"""Set the listen address for the gateway server."""
self._host = host
self._port = port

View File

@@ -1,103 +0,0 @@
"""
WebSocket connection manager for the AstrBot gateway.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
import anyio
from astrbot import logger
if TYPE_CHECKING:
from fastapi import WebSocket
else:
try:
from fastapi import WebSocket
except ImportError:
logger.warning("FastAPI not installed, WebSocketManager unavailable.")
WebSocket = cast(Any, None)
log = logger
class WebSocketManager:
"""
Manages all active WebSocket connections.
Provides connection/disconnection handling and broadcast capabilities.
"""
def __init__(self) -> None:
self._connections: set[WebSocket] = set()
self._lock = anyio.Lock()
async def connect(self, websocket: WebSocket) -> None:
"""Accept and register a new WebSocket connection."""
await websocket.accept()
async with self._lock:
self._connections.add(websocket)
log.debug(f"WebSocket connected. Total: {len(self._connections)}")
async def disconnect(self, websocket: WebSocket) -> None:
"""Remove a WebSocket connection."""
async with self._lock:
self._connections.discard(websocket)
log.debug(f"WebSocket disconnected. Total: {len(self._connections)}")
async def send_json(self, websocket: WebSocket, data: dict[str, Any]) -> None:
"""
Send JSON data to a specific WebSocket.
Args:
websocket: Target WebSocket connection
data: Data to send (must be JSON-serializable)
"""
try:
await websocket.send_json(data)
except Exception as e:
log.warning(f"Failed to send to WebSocket: {e}")
await self.disconnect(websocket)
async def broadcast(self, data: dict[str, Any]) -> None:
"""
Broadcast JSON data to all connected WebSockets.
Args:
data: Data to broadcast (must be JSON-serializable)
"""
async with self._lock:
connections = list(self._connections)
for conn in connections:
try:
await conn.send_json(data)
except Exception as e:
log.warning(f"Failed to broadcast to WebSocket: {e}")
async with self._lock:
self._connections.discard(conn)
async def send_to(
self, websocket: WebSocket, message: str | dict[str, Any]
) -> None:
"""
Send a message to a specific WebSocket.
Args:
websocket: Target WebSocket connection
message: Message to send (string or dict)
"""
try:
if isinstance(message, str):
await websocket.send_text(message)
else:
await websocket.send_json(message)
except Exception as e:
log.warning(f"Failed to send to WebSocket: {e}")
await self.disconnect(websocket)
@property
def connection_count(self) -> int:
"""Return the number of active connections."""
return len(self._connections)

View File

@@ -1,5 +0,0 @@
"""ABP module - AstrBot Protocol client implementation (built-in plugin protocol)."""
from .client import AstrbotAbpClient
__all__ = ["AstrbotAbpClient"]

View File

@@ -1,93 +0,0 @@
"""
ABP (AstrBot Protocol) client implementation.
ABP is the built-in plugin protocol where the orchestrator acts as client
connecting to internal stars (plugins) embedded in the runtime.
"""
from __future__ import annotations
from typing import Any
from astrbot import logger
from astrbot._internal.abc.abp.base_astrbot_abp_client import BaseAstrbotAbpClient
log = logger
class AstrbotAbpClient(BaseAstrbotAbpClient):
"""
ABP client for communicating with internal stars (built-in plugins).
The orchestrator acts as the client, sending requests to and receiving
notifications from stars running within the same process.
"""
def __init__(self) -> None:
self._connected = False
self._stars: dict[str, Any] = {}
# Use a simple dict for pending requests; we avoid asyncio.Future here.
self._pending_requests: dict[str, Any] = {}
self._request_id = 0
@property
def connected(self) -> bool:
"""True if connected to stars registry."""
return self._connected
async def connect(self) -> None:
"""Connect to internal stars registry."""
log.debug("ABP client connecting to internal stars...")
self._connected = True
log.info("ABP client connected to internal stars registry.")
async def call_star_tool(
self, star_name: str, tool_name: str, arguments: dict[str, Any]
) -> Any:
"""
Call a tool on a registered star.
Args:
star_name: Name of the star (plugin)
tool_name: Name of the tool to call
arguments: Tool arguments
Returns:
Tool call result
"""
if not self._connected:
raise RuntimeError("ABP client is not connected")
star = self._stars.get(star_name)
if not star:
raise ValueError(f"Star '{star_name}' not found")
request_id = f"{self._request_id}"
self._request_id += 1
# No asyncio.Future used; store a placeholder entry for tracking if needed.
self._pending_requests[request_id] = None
try:
# Call the star's tool handler
result = await star.call_tool(tool_name, arguments)
return result
finally:
self._pending_requests.pop(request_id, None)
def register_star(self, star_name: str, star_instance: Any) -> None:
"""Register a star (plugin) with the ABP client."""
self._stars[star_name] = star_instance
log.debug(f"Star '{star_name}' registered with ABP client.")
def unregister_star(self, star_name: str) -> None:
"""Unregister a star from the ABP client."""
self._stars.pop(star_name, None)
log.debug(f"Star '{star_name}' unregistered from ABP client.")
async def shutdown(self) -> None:
"""Shutdown the ABP client connection."""
self._connected = False
# Clear any pending requests (no asyncio futures used in this implementation)
self._pending_requests.clear()
log.info("ABP client shut down.")

View File

@@ -1,6 +0,0 @@
"""ACP module - AstrBot Communication Protocol client and server implementations."""
from .client import AstrbotAcpClient
from .server import AstrbotAcpServer
__all__ = ["AstrbotAcpClient", "AstrbotAcpServer"]

View File

@@ -1,220 +0,0 @@
"""
ACP (AstrBot Communication Protocol) client implementation.
ACP is a client-server protocol for inter-service communication,
similar to MCP but designed specifically for AstrBot's architecture.
"""
from __future__ import annotations
import asyncio
import json
from typing import Any
from astrbot import logger
from astrbot._internal.abc.acp.base_astrbot_acp_client import BaseAstrbotAcpClient
log = logger
class AstrbotAcpClient(BaseAstrbotAcpClient):
"""
ACP client for communicating with ACP servers.
The orchestrator acts as an ACP client, connecting to external
ACP-compatible services.
"""
def __init__(self) -> None:
self._connected = False
self._reader: asyncio.StreamReader | None = None
self._writer: asyncio.StreamWriter | None = None
self._server_url: str | None = None
self._pending_requests: dict[str, asyncio.Future[dict[str, Any]]] = {}
self._request_id = 0
self._reader_task: asyncio.Task[None] | None = None
@property
def connected(self) -> bool:
"""True if connected to an ACP server."""
return self._connected
async def connect(self) -> None:
"""
Connect to configured ACP servers.
ACP servers can be accessed via TCP (host:port) or Unix socket.
"""
log.debug("ACP client connecting...")
# TODO: Load ACP server configurations
self._connected = True
log.info("ACP client initialized.")
async def connect_to_server(self, host: str, port: int) -> None:
"""
Connect to an ACP server via TCP.
Args:
host: Server hostname or IP
port: Server port
"""
self._server_url = f"{host}:{port}"
self._reader, self._writer = await asyncio.open_connection(host, port)
self._connected = True
# Start reading responses
self._reader_task = asyncio.create_task(self._read_messages())
log.info(f"ACP client connected to {self._server_url}")
async def connect_to_unix_socket(self, socket_path: str) -> None:
"""
Connect to an ACP server via Unix socket.
Args:
socket_path: Path to the Unix socket
"""
self._server_url = f"unix://{socket_path}"
self._reader, self._writer = await asyncio.open_unix_connection(socket_path)
self._connected = True
self._reader_task = asyncio.create_task(self._read_messages())
log.info(f"ACP client connected to {self._server_url}")
async def _read_messages(self) -> None:
"""Background task to read ACP messages."""
if not self._reader:
return
buffer = b""
while self._connected:
try:
data = await self._reader.read(4096)
if not data:
break
buffer += data
while True:
header_end = buffer.find(b"\n")
if header_end == -1:
break
try:
header = json.loads(buffer[:header_end].decode("utf-8"))
except json.JSONDecodeError:
buffer = buffer[header_end + 1 :]
continue
content_length = header.get("content-length", 0)
if (
content_length == 0
or len(buffer) < header_end + 1 + content_length
):
break
content = buffer[header_end + 1 : header_end + 1 + content_length]
buffer = buffer[header_end + 1 + content_length :]
message = json.loads(content.decode("utf-8"))
if "id" in message:
request_id = str(message["id"])
future = self._pending_requests.pop(request_id, None)
if future and not future.done():
if "error" in message:
future.set_exception(Exception(str(message["error"])))
else:
future.set_result(message.get("result", {}))
else:
await self._handle_notification(message)
except Exception as e:
if self._connected:
log.error(f"ACP read error: {e}")
break
async def _handle_notification(self, notification: dict[str, Any]) -> None:
"""Handle incoming ACP notifications."""
method = notification.get("method", "")
log.debug(f"ACP notification: {method}")
async def call_tool(
self, server_name: str, tool_name: str, arguments: dict[str, Any]
) -> Any:
"""
Call a tool on an ACP server.
Args:
server_name: Name of the ACP server
tool_name: Name of the tool to call
arguments: Tool arguments
Returns:
Tool call result
"""
if not self._connected:
raise RuntimeError("ACP client is not connected")
request_id = str(self._request_id)
self._request_id += 1
message = {
"jsonrpc": "2.0",
"id": request_id,
"method": f"{server_name}/{tool_name}",
"params": arguments,
}
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
self._pending_requests[request_id] = future
await self._send_message(message)
return await future
async def _send_message(self, message: dict[str, Any]) -> None:
"""Send an ACP message."""
if not self._writer:
raise RuntimeError("ACP client not connected")
content = json.dumps(message)
header = json.dumps({"content-length": len(content)}) + "\n"
self._writer.write((header + content).encode())
await self._writer.drain()
async def send_notification(
self, method: str, params: dict[str, Any] | None = None
) -> None:
"""Send a one-way notification to the server."""
message = {
"jsonrpc": "2.0",
"method": method,
"params": params or {},
}
await self._send_message(message)
async def shutdown(self) -> None:
"""Shutdown the ACP client connection."""
self._connected = False
if self._reader_task:
self._reader_task.cancel()
try:
await self._reader_task
except asyncio.CancelledError:
pass
if self._writer:
self._writer.close()
try:
await self._writer.wait_closed()
except Exception:
pass
for future in self._pending_requests.values():
if not future.done():
future.cancel()
self._pending_requests.clear()
log.info("ACP client shut down.")

View File

@@ -1,223 +0,0 @@
"""
ACP (AstrBot Communication Protocol) server implementation.
ACP servers listen for connections from ACP clients and provide
services/tools to the orchestrator.
"""
from __future__ import annotations
import asyncio
import json
from collections.abc import Callable
from typing import Any
from astrbot import logger
from astrbot._internal.abc.acp.base_astrbot_acp_server import BaseAstrbotAcpServer
log = logger
class AstrbotAcpServer(BaseAstrbotAcpServer):
"""
ACP server for accepting connections from ACP clients.
ACP servers expose tools/notifications that can be called by clients.
"""
def __init__(self) -> None:
self._running = False
self._host: str = "127.0.0.1"
self._port: int = 8765
self._server: asyncio.Server | None = None
self._clients: set[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = set()
self._tool_handlers: dict[str, Callable[..., Any]] = {}
self._notification_handlers: dict[str, Callable[..., Any]] = {}
def register_tool(self, name: str, handler: Callable[..., Any]) -> None:
"""
Register a tool handler.
Args:
name: Tool name
handler: Async callable that handles tool calls
"""
self._tool_handlers[name] = handler
log.debug(f"ACP server registered tool: {name}")
def register_notification_handler(
self, name: str, handler: Callable[..., Any]
) -> None:
"""
Register a notification handler.
Args:
name: Notification method name
handler: Async callable that handles notifications
"""
self._notification_handlers[name] = handler
log.debug(f"ACP server registered notification handler: {name}")
async def start(self, host: str = "127.0.0.1", port: int = 8765) -> None:
"""
Start the ACP server.
Args:
host: Host to bind to
port: Port to listen on
"""
self._host = host
self._port = port
self._server = await asyncio.start_server(
self._handle_client,
host=host,
port=port,
)
self._running = True
log.info(f"ACP server listening on {host}:{port}")
async def _handle_client(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""Handle an incoming ACP client connection."""
addr = writer.get_extra_info("peername")
log.debug(f"ACP client connected: {addr}")
self._clients.add((reader, writer))
buffer = b""
try:
while self._running:
try:
data = await reader.read(4096)
if not data:
break
buffer += data
while True:
header_end = buffer.find(b"\n")
if header_end == -1:
break
try:
header = json.loads(buffer[:header_end].decode("utf-8"))
except json.JSONDecodeError:
buffer = buffer[header_end + 1 :]
continue
content_length = header.get("content-length", 0)
if (
content_length == 0
or len(buffer) < header_end + 1 + content_length
):
break
content = buffer[
header_end + 1 : header_end + 1 + content_length
]
buffer = buffer[header_end + 1 + content_length :]
message = json.loads(content.decode("utf-8"))
response = await self._handle_message(message)
if response:
content = json.dumps(response)
resp_header = (
json.dumps({"content-length": len(content)}) + "\n"
)
writer.write(resp_header.encode() + content.encode())
await writer.drain()
except Exception as e:
log.error(f"ACP client error ({addr}): {e}")
break
finally:
self._clients.discard((reader, writer))
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
log.debug(f"ACP client disconnected: {addr}")
async def _handle_message(self, message: dict[str, Any]) -> dict[str, Any] | None:
"""Handle an incoming ACP message."""
method = message.get("method", "")
msg_id = message.get("id")
params = message.get("params", {})
# Check if it's a notification (no id) or request (has id)
if msg_id is None:
# Notification
handler = self._notification_handlers.get(method)
if handler:
try:
await handler(params)
except Exception as e:
log.error(f"ACP notification handler error ({method}): {e}")
return None
# Request
result = None
error = None
handler = self._tool_handlers.get(method)
if handler:
try:
result = await handler(params)
except Exception as e:
error = str(e)
log.error(f"ACP tool handler error ({method}): {e}")
else:
error = f"Unknown method: {method}"
response: dict[str, Any] = {"jsonrpc": "2.0", "id": msg_id}
if error:
response["error"] = {"code": -32601, "message": error}
else:
response["result"] = result
return response
async def broadcast_notification(self, method: str, params: dict[str, Any]) -> None:
"""
Broadcast a notification to all connected clients.
Args:
method: Notification method name
params: Notification parameters
"""
message = {
"jsonrpc": "2.0",
"method": method,
"params": params,
}
content = json.dumps(message)
header = json.dumps({"content-length": len(content)}) + "\n"
data = header.encode() + content.encode()
for reader, writer in list(self._clients):
try:
writer.write(data)
await writer.drain()
except Exception as e:
log.warning(f"Failed to broadcast to client: {e}")
async def shutdown(self) -> None:
"""Shutdown the ACP server."""
self._running = False
if self._server:
self._server.close()
await self._server.wait_closed()
self._server = None
for reader, writer in list(self._clients):
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
self._clients.clear()
log.info("ACP server shut down.")

View File

@@ -1,5 +0,0 @@
"""LSP module - Language Server Protocol client implementation."""
from .client import AstrbotLspClient
__all__ = ["AstrbotLspClient"]

View File

@@ -1,243 +0,0 @@
"""
LSP (Language Server Protocol) client implementation.
The orchestrator acts as an LSP client, connecting to LSP servers
that provide language intelligence features (completions, diagnostics, etc.).
"""
from __future__ import annotations
import json
from typing import Any
import anyio
from anyio.abc import ByteReceiveStream, ByteSendStream, Process
from astrbot import logger
from astrbot._internal.abc.lsp.base_astrbot_lsp_client import BaseAstrbotLspClient
log = logger
class AstrbotLspClient(BaseAstrbotLspClient):
"""
LSP client for communicating with LSP servers.
Implements the Microsoft Language Server Protocol for connecting to
external language intelligence services.
"""
def __init__(self) -> None:
self._connected = False
self._reader: ByteReceiveStream | None = None
self._writer: ByteSendStream | None = None
self._server_process: Process | None = None
self._pending_requests: dict[int, Any] = {}
self._request_id = 0
self._server_command: list[str] | None = None
# anyio TaskGroup handle for background readers
self._task_group: Any | None = None
@property
def connected(self) -> bool:
"""True if connected to an LSP server."""
return self._connected
async def connect(self) -> None:
"""
Connect to configured LSP servers.
LSP servers are typically stdio-based subprocesses. This method
establishes the communication channel.
"""
log.debug("LSP client connecting...")
# TODO: Load LSP server configurations and start subprocesses
# For now, mark as connected in idle mode
self._connected = True
log.info("LSP client initialized.")
async def connect_to_server(self, command: list[str], workspace_uri: str) -> None:
"""
Connect to an LSP server subprocess.
Args:
command: Command line to start the LSP server (e.g., ["python", "lsp_server.py"])
workspace_uri: Root URI of the workspace to serve
"""
log.debug(f"Starting LSP server: {' '.join(command)}")
self._server_process = await anyio.open_process(
command,
stdin=-1,
stdout=-1,
stderr=-1,
)
self._reader = self._server_process.stdout
self._writer = self._server_process.stdin
self._server_command = command
self._connected = True
# Start reading responses in background using anyio TaskGroup
# Create and enter a TaskGroup so the reader runs until we close it at shutdown.
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._read_responses)
# Send initialize request
await self.send_request(
"initialize",
{
"processId": None,
"rootUri": workspace_uri,
"capabilities": {},
},
)
# Send initialized notification
await self.send_notification("initialized", {})
log.info(f"LSP client connected to server: {command[0]}")
async def send_request(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""Send an LSP request and wait for response."""
if not self._writer:
raise RuntimeError("LSP client not connected")
request_id = self._request_id
self._request_id += 1
message = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
"params": params or {},
}
# Use anyio.Event for request/response matching
response_event: anyio.Event = anyio.Event()
response_holder: dict[str, Any] = {}
async def set_response(response: dict[str, Any]) -> None:
response_holder["response"] = response
response_event.set()
self._pending_requests[request_id] = set_response
content = json.dumps(message)
headers = f"Content-Length: {len(content)}\r\n\r\n"
await self._writer.send((headers + content).encode())
# Wait for response with timeout
with anyio.move_on_after(30):
await response_event.wait()
if "response" in response_holder:
return response_holder["response"]
raise TimeoutError(f"LSP request {method} timed out")
async def send_notification(
self, method: str, params: dict[str, Any] | None = None
) -> None:
"""Send an LSP notification (no response expected)."""
if not self._writer:
raise RuntimeError("LSP client not connected")
message = {
"jsonrpc": "2.0",
"method": method,
"params": params or {},
}
content = json.dumps(message)
headers = f"Content-Length: {len(content)}\r\n\r\n"
await self._writer.send((headers + content).encode())
async def _read_responses(self) -> None:
"""Background task to read LSP responses."""
if not self._reader:
return
buffer = b""
try:
while self._connected:
try:
data = await self._reader.receive()
if not data:
break
buffer += data
while True:
# Parse Content-Length header
header_end = buffer.find(b"\r\n\r\n")
if header_end == -1:
break
header = buffer[:header_end].decode("utf-8")
content_length = 0
for line in header.split("\r\n"):
if line.startswith("Content-Length:"):
content_length = int(line.split(":")[1].strip())
if content_length == 0:
break
total_length = header_end + 4 + content_length
if len(buffer) < total_length:
break
content = buffer[header_end + 4 : total_length]
buffer = buffer[total_length:]
response = json.loads(content.decode("utf-8"))
# Handle response vs notification
if "id" in response:
request_id = response["id"]
handler = self._pending_requests.pop(request_id, None)
if handler:
await handler(response)
else:
# Notification (e.g., window/logMessage)
await self._handle_notification(response)
except anyio.EndOfStream:
break
except anyio.get_cancelled_exc_class():
# Task was cancelled via the TaskGroup cancel/exit during shutdown
pass
async def _handle_notification(self, notification: dict[str, Any]) -> None:
"""Handle incoming LSP notifications."""
method = notification.get("method", "")
log.debug(f"LSP notification: {method}")
async def shutdown(self) -> None:
"""Shutdown the LSP client."""
self._connected = False
if self._task_group:
try:
# Exit the TaskGroup, which cancels background tasks started within it
await self._task_group.__aexit__(None, None, None)
except anyio.get_cancelled_exc_class():
pass
self._task_group = None
if self._server_process:
try:
await self.send_notification("shutdown", {})
except Exception:
pass
self._server_process.terminate()
try:
with anyio.move_on_after(5):
await self._server_process.wait()
except Exception:
self._server_process.kill()
self._server_process = None
self._pending_requests.clear()
log.info("LSP client shut down.")

View File

@@ -1,63 +0,0 @@
"""MCP module - Model Context Protocol client and tool implementations.
This module provides MCP client functionality and MCP tool wrappers.
"""
import asyncio
from dataclasses import dataclass
from .client import McpClient
from .config import (
DEFAULT_MCP_CONFIG,
get_mcp_config_path,
load_mcp_config,
save_mcp_config,
)
from .tool import MCPTool
# Exceptions
class MCPInitError(Exception):
"""Base exception for MCP initialization failures."""
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
"""Raised when MCP client initialization exceeds the configured timeout."""
class MCPAllServicesFailedError(MCPInitError):
"""Raised when all configured MCP services fail to initialize."""
class MCPShutdownTimeoutError(asyncio.TimeoutError):
"""Raised when MCP shutdown exceeds the configured timeout."""
def __init__(self, names: list[str], timeout: float) -> None:
self.names = names
self.timeout = timeout
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
super().__init__(message)
@dataclass
class MCPInitSummary:
"""Summary of MCP initialization results."""
total: int
success: int
failed: list[str]
__all__ = [
"DEFAULT_MCP_CONFIG",
"MCPAllServicesFailedError",
"MCPInitError",
"MCPInitSummary",
"MCPInitTimeoutError",
"MCPShutdownTimeoutError",
"MCPTool",
"McpClient",
"get_mcp_config_path",
"load_mcp_config",
"save_mcp_config",
]

View File

@@ -1,486 +0,0 @@
"""MCP client implementation."""
import asyncio
import logging
import os
import sys
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Any, cast
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from astrbot._internal.abc.mcp.base_astrbot_mcp_client import (
BaseAstrbotMcpClient,
McpServerConfig,
McpToolInfo,
)
from astrbot.core.utils.log_pipe import LogPipe
logger = logging.getLogger("astrbot")
try:
import anyio
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
)
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
)
class TenacityLogger:
"""Wraps a logging.Logger to satisfy tenacity's LoggerProtocol."""
__slots__ = ("_logger",)
_logger: logging.Logger
def __init__(self, logger: logging.Logger) -> None:
self._logger = logger
def log(
self,
level: int,
msg: str,
/,
*args: Any,
**kwargs: Any,
) -> None:
self._logger.log(level, msg, *args, **kwargs)
def _prepare_config(config: dict) -> dict:
"""Prepare configuration, handle nested format."""
if config.get("mcpServers"):
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
def _prepare_stdio_env(config: dict) -> dict:
"""Preserve Windows executable resolution for stdio subprocesses."""
if sys.platform != "win32":
return config
pathext = os.environ.get("PATHEXT")
if not pathext:
return config
prepared = config.copy()
env = dict(prepared.get("env") or {})
env.setdefault("PATHEXT", pathext)
prepared["env"] = env
return prepared
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""Quick test MCP server connectivity."""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP connection config missing transport or type field")
async with aiohttp.ClientSession() as session:
if transport_type == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"Connection timeout: {timeout} seconds"
except Exception as e:
return False, f"{e!s}"
class McpClient(BaseAstrbotMcpClient):
def __init__(self) -> None:
# Initialize session and client objects
self.session: mcp.ClientSession | None = None
self.exit_stack = AsyncExitStack()
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = anyio.Event()
self.process_pid: int | None = None
# Store connection config for reconnection
self._mcp_server_config: McpServerConfig | None = None
self._server_name: str | None = None
self._reconnect_lock = anyio.Lock() # Lock for thread-safe reconnection
self._reconnecting: bool = False # For logging and debugging
async def connect(self) -> None:
"""Initialize the MCP client connection.
Note: Actual server connections are made via connect_to_server().
This method prepares the client for use.
"""
# MCP client is initialized on-demand via connect_to_server
# This is a no-op stub to satisfy BaseAstrbotMcpClient
logger.debug("MCP client initialized.")
@property
def connected(self) -> bool:
"""True if MCP client has an active session."""
return self.session is not None
async def list_tools(self) -> list[McpToolInfo]:
"""List all tools from connected MCP servers."""
if not self.session:
return []
result = await self.list_tools_and_save()
tools = [
{
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema,
}
for tool in result.tools
]
return cast(list[McpToolInfo], tools)
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
read_timeout_seconds: int = 60,
) -> Any:
"""Call a tool on the MCP server with reconnection support."""
return await self.call_tool_with_reconnect(
tool_name=name,
arguments=arguments,
read_timeout_seconds=timedelta(seconds=read_timeout_seconds),
)
@staticmethod
def _extract_stdio_process_pid(streams_context: object) -> int | None:
"""Best-effort extraction for stdio subprocess PID used by lease cleanup.
TODO(refactor): replace this async-generator frame introspection with a
stable MCP library hook once the upstream transport exposes process PID.
"""
generator = getattr(streams_context, "gen", None)
frame = getattr(generator, "ag_frame", None)
if frame is None:
return None
process = frame.f_locals.get("process")
pid = getattr(process, "pid", None)
try:
return int(pid) if pid is not None else None
except (TypeError, ValueError):
return None
async def connect_to_server(self, config: McpServerConfig, name: str) -> None:
"""Connect to MCP server
If `url` parameter exists:
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
2. When transport is specified as `sse`, use SSE connection.
3. If not specified, default to SSE connection to MCP service.
Args:
config: Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
# Store config for reconnection
self._mcp_server_config = config
self._server_name = name
self.process_pid = None
cfg = _prepare_config(dict(config))
def logging_callback(
msg: str | mcp.types.LoggingMessageNotificationParams,
) -> None:
# Handle MCP service error logs
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
self.server_errlogs.append(log_msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
if not success:
raise Exception(error_msg)
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP connection config missing transport or type field")
if transport_type != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context,
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=cast(Any, logging_callback),
),
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5),
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context,
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
),
)
else:
cfg = _prepare_stdio_env(cfg)
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None:
# Handle MCP service error logs
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
if msg.level in (
"warning",
"error",
"critical",
"alert",
"emergency",
):
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
self.server_errlogs.append(log_msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=cast(
Any,
LogPipe(
level=logging.INFO,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
),
),
),
)
self.process_pid = self._extract_stdio_process_pid(stdio_transport)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport),
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response
async def _reconnect(self) -> None:
"""Reconnect to the MCP server using the stored configuration.
Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
Raises:
Exception: raised when reconnection fails
"""
async with self._reconnect_lock:
# Check if already reconnecting (useful for logging)
if self._reconnecting:
logger.debug(
f"MCP Client {self._server_name} is already reconnecting, skipping"
)
return
if not self._mcp_server_config or not self._server_name:
raise Exception("Cannot reconnect: missing connection configuration")
self._reconnecting = True
try:
logger.info(
f"Attempting to reconnect to MCP server {self._server_name}..."
)
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
if self.exit_stack:
self._old_exit_stacks.append(self.exit_stack)
# Mark old session as invalid
self.session = None
# Create new exit stack for new connection
self.exit_stack = AsyncExitStack()
# Reconnect using stored config
await self.connect_to_server(self._mcp_server_config, self._server_name)
await self.list_tools_and_save()
logger.info(
f"Successfully reconnected to MCP server {self._server_name}"
)
except Exception as e:
logger.error(
f"Failed to reconnect to MCP server {self._server_name}: {e}"
)
raise
finally:
self._reconnecting = False
async def call_tool_with_reconnect(
self,
tool_name: str,
arguments: dict,
read_timeout_seconds: timedelta,
) -> mcp.types.CallToolResult:
"""Call MCP tool with automatic reconnection on failure, max 2 retries.
Args:
tool_name: tool name
arguments: tool arguments
read_timeout_seconds: read timeout
Returns:
MCP tool call result
Raises:
ValueError: MCP session is not available
anyio.ClosedResourceError: raised after reconnection failure
"""
@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=1, max=3),
before_sleep=before_sleep_log(TenacityLogger(logger), logging.WARNING),
reraise=True,
)
async def _call_with_retry():
if not self.session:
raise ValueError("MCP session is not available for MCP function tools.")
try:
return await self.session.call_tool(
name=tool_name,
arguments=arguments,
read_timeout_seconds=read_timeout_seconds,
)
except anyio.ClosedResourceError:
logger.warning(
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
)
# Attempt to reconnect
await self._reconnect()
# Reraise the exception to trigger tenacity retry
raise
return await _call_with_retry()
async def cleanup(self) -> None:
"""Clean up resources including old exit stacks from reconnections"""
# Close current exit stack
try:
await self.exit_stack.aclose()
except Exception as e:
logger.debug(f"Error closing current exit stack: {e}")
# Don't close old exit stacks as they may be in different task contexts
# They will be garbage collected naturally
# Just clear the list to release references
self._old_exit_stacks.clear()
# Set running_event first to unblock any waiting tasks
self.running_event.set()
self.process_pid = None

View File

@@ -1,55 +0,0 @@
"""MCP configuration management."""
import json
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
def get_mcp_config_path() -> str:
"""Get the path to the MCP configuration file."""
data_dir = get_astrbot_data_path()
return os.path.join(data_dir, "mcp_server.json")
def load_mcp_config() -> dict:
"""Load MCP configuration from file.
Returns:
MCP configuration dict. If file doesn't exist, returns default config.
"""
config_path = get_mcp_config_path()
if not os.path.exists(config_path):
# Create default config if not exists
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
return DEFAULT_MCP_CONFIG
try:
with open(config_path, encoding="utf-8") as f:
return json.load(f)
except Exception:
return DEFAULT_MCP_CONFIG
def save_mcp_config(config: dict) -> bool:
"""Save MCP configuration to file.
Args:
config: MCP configuration dict to save.
Returns:
True if successful, False otherwise.
"""
config_path = get_mcp_config_path()
try:
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
return True
except Exception:
return False

View File

@@ -1,45 +0,0 @@
"""MCP tool wrapper."""
from datetime import timedelta
from typing import TYPE_CHECKING, Any
try:
import mcp
except (ModuleNotFoundError, ImportError):
mcp = None # type: ignore
from astrbot._internal.tools.base import FunctionTool
if TYPE_CHECKING:
from astrbot._internal.protocols.mcp.client import McpClient
class MCPTool(FunctionTool):
"""A function tool that calls an MCP service."""
def __init__(
self,
mcp_tool: "mcp.types.Tool",
mcp_client: "McpClient",
mcp_server_name: str,
**kwargs: Any,
) -> None:
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=mcp_tool.inputSchema,
)
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client
self.mcp_server_name = mcp_server_name
self.source = "mcp"
async def call(self, **kwargs: Any) -> Any:
"""Call the MCP tool with the given arguments."""
# Note: For actual usage, context.tool_call_timeout is needed
# but for simplicity we use a default timeout here
return await self.mcp_client.call_tool_with_reconnect(
tool_name=self.mcp_tool.name,
arguments=kwargs,
read_timeout_seconds=timedelta(seconds=60),
)

View File

@@ -1,3 +0,0 @@
from astrbot._internal.runtime.__main__ import bootstrap
__all__ = ["bootstrap"]

View File

@@ -1,24 +0,0 @@
from __future__ import annotations
import anyio
from astrbot._internal.abc.base_astrbot_gateway import BaseAstrbotGateway
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
from astrbot._internal.geteway.server import AstrbotGateway
from astrbot._internal.runtime.orchestrator import AstrbotOrchestrator
async def bootstrap():
orchestrator: BaseAstrbotOrchestrator = AstrbotOrchestrator()
gw: BaseAstrbotGateway = AstrbotGateway(orchestrator)
# anyio 的结构化并发
async with anyio.create_task_group() as tg:
tg.start_soon(orchestrator.lsp.connect) # 启动 LSP client
tg.start_soon(orchestrator.mcp.connect) # 启动 MCP client
tg.start_soon(orchestrator.acp.connect) # 启动 ACP client
tg.start_soon(orchestrator.abp.connect) # 启动 ABP client
await anyio.sleep(0.5)
tg.start_soon(orchestrator.run_loop) # 启动编排器循环
tg.start_soon(gw.serve) # 面板后端服务

View File

@@ -1,164 +0,0 @@
"""
AstrBot Orchestrator - core runtime that coordinates all protocol clients.
The orchestrator manages the lifecycle of LSP, MCP, ACP, and ABP clients,
and runs the main event loop that dispatches messages between components.
"""
from __future__ import annotations
from typing import Any
import anyio
from astrbot import logger
from astrbot._internal.abc.base_astrbot_orchestrator import BaseAstrbotOrchestrator
from astrbot._internal.protocols.abp.client import AstrbotAbpClient
from astrbot._internal.protocols.acp.client import AstrbotAcpClient
from astrbot._internal.protocols.lsp.client import AstrbotLspClient
from astrbot._internal.protocols.mcp.client import McpClient
from astrbot._internal.stars import RuntimeStatusStar
log = logger
class AstrbotOrchestrator(BaseAstrbotOrchestrator):
"""
Core runtime orchestrator for AstrBot.
Manages:
- LSP client: Language Server Protocol for editor integrations
- MCP client: Model Context Protocol for external tool servers
- ACP client: AstrBot Communication Protocol for inter-service communication
- ABP client: AstrBot Protocol for built-in star (plugin) communication
"""
def __init__(self) -> None:
# Initialize protocol clients (use concrete types for full method access)
self.lsp = AstrbotLspClient()
self.mcp = McpClient()
self.acp = AstrbotAcpClient()
self.abp = AstrbotAbpClient()
self._running = False
self._stars: dict[str, Any] = {}
self._message_count: int = 0
self._last_activity_timestamp: float | None = None
# Auto-register RuntimeStatusStar
self._runtime_status_star = RuntimeStatusStar()
self._runtime_status_star.set_orchestrator(self)
self._stars["runtime-status-star"] = self._runtime_status_star
self.abp.register_star("runtime-status-star", self._runtime_status_star)
log.debug("AstrbotOrchestrator initialized.")
async def start(self) -> None:
"""
Initialize all protocol clients.
Calls connect() on all protocol clients to prepare them for use.
"""
log.info("Starting AstrbotOrchestrator...")
await self.lsp.connect()
await self.mcp.connect()
await self.acp.connect()
await self.abp.connect()
self._running = True
log.info("AstrbotOrchestrator started.")
async def run_loop(self) -> None:
"""
Main orchestrator event loop.
This loop runs continuously, handling:
- Periodic health checks of protocol clients
- Message routing between protocols
- Star (plugin) lifecycle management
"""
self._running = True
log.info("AstrbotOrchestrator run loop started.")
stop_event = anyio.Event()
def set_stop() -> None:
stop_event.set()
# Store the callback for cleanup
self._stop_callback = set_stop
try:
while self._running:
# TODO: Periodic tasks:
# - Check LSP server health
# - Check MCP session status
# - Check ACP client connections
# - Process any pending star notifications
# Wait for 5 seconds or until shutdown is called
with anyio.move_on_after(5):
await stop_event.wait()
except anyio.get_cancelled_exc_class():
log.info("Orchestrator run loop cancelled.")
finally:
self._running = False
self._stop_callback = None
log.info("AstrbotOrchestrator run loop stopped.")
async def register_star(self, name: str, star_instance: Any) -> None:
"""
Register a star (plugin) with the orchestrator.
Args:
name: Unique name for the star
star_instance: Star plugin instance
"""
self._stars[name] = star_instance
self.abp.register_star(name, star_instance)
log.info(f"Star '{name}' registered.")
async def unregister_star(self, name: str) -> None:
"""
Unregister a star (plugin) from the orchestrator.
Args:
name: Name of the star to unregister
"""
self._stars.pop(name, None)
self.abp.unregister_star(name)
log.info(f"Star '{name}' unregistered.")
async def get_star(self, name: str) -> Any | None:
"""Get a registered star by name."""
return self._stars.get(name)
async def list_stars(self) -> list[str]:
"""List all registered star names."""
return list(self._stars.keys())
def record_activity(self) -> None:
"""Record a message activity for stats tracking."""
self._message_count += 1
import time
self._last_activity_timestamp = time.time()
async def shutdown(self) -> None:
"""
Shutdown the orchestrator and all protocol clients.
"""
log.info("Shutting down AstrbotOrchestrator...")
self._running = False
# Shutdown all protocol clients
await self.lsp.shutdown()
await self.acp.shutdown()
await self.abp.shutdown()
# MCP cleanup
await self.mcp.cleanup()
log.info("AstrbotOrchestrator shut down.")

View File

@@ -1,18 +0,0 @@
import sys
try:
from ._core import cli as _cli
def cli():
if len(sys.argv) == 1:
sys.argv.append("--help")
return _cli()
except ImportError:
from click import echo
def cli():
echo("""
AstrBot CLI(rust) is not available.
Developer: maturin dev
User: uv run astrbot-rs
""")

View File

@@ -1,16 +0,0 @@
from typing import Any
class AstrbotOrchestrator:
def start(self) -> None: ...
def stop(self) -> None: ...
def is_running(self) -> bool: ...
def register_star(self, name: str, handler: str) -> None: ...
def unregister_star(self, name: str) -> None: ...
def list_stars(self) -> list[str]: ...
def record_activity(self) -> None: ...
def get_stats(self) -> dict[str, Any]: ...
def set_protocol_connected(self, protocol: str, connected: bool) -> None: ...
def get_protocol_status(self, protocol: str) -> dict[str, Any] | None: ...
def get_orchestrator() -> AstrbotOrchestrator: ...
def cli() -> None: ...

View File

@@ -1,13 +0,0 @@
"""Internal skills module - re-exports from core.skills.skill_manager."""
from astrbot.core.skills.skill_manager import (
SkillInfo,
SkillManager,
build_skills_prompt,
)
__all__ = [
"SkillInfo",
"SkillManager",
"build_skills_prompt",
]

View File

@@ -1,7 +0,0 @@
"""
Stars (built-in plugins) for AstrBot runtime.
"""
from astrbot._internal.stars.runtime_status_star import RuntimeStatusStar
__all__ = ["RuntimeStatusStar"]

View File

@@ -1,127 +0,0 @@
"""
RuntimeStatusStar - ABP plugin that exposes core runtime internal state.
This star provides tools for querying:
- Runtime status (running state, uptime)
- Protocol client status (LSP, MCP, ACP, ABP)
- Registered stars registry
- Message counts and metrics
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
@dataclass
class RuntimeStatusStar:
"""
ABP star that exposes core runtime internal state as callable tools.
Tools provided:
- get_runtime_status: Returns running state and uptime
- get_protocol_status: Returns LSP, MCP, ACP, ABP status
- get_star_registry: Returns registered star names
- get_stats: Returns message counts and metrics
"""
name: str = "runtime-status-star"
description: str = "ABP plugin that exposes core runtime internal state"
_start_time: float = field(default_factory=time.time, init=False)
_orchestrator: Any = field(default=None, init=False)
def set_orchestrator(self, orchestrator: Any) -> None:
"""Set the orchestrator reference for status queries."""
self._orchestrator = orchestrator
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""
Handle tool calls from ABP client.
Args:
tool_name: Name of the tool to call
arguments: Tool arguments
Returns:
Tool result
"""
if tool_name == "get_runtime_status":
return self._get_runtime_status()
elif tool_name == "get_protocol_status":
return await self._get_protocol_status()
elif tool_name == "get_star_registry":
return await self._get_star_registry()
elif tool_name == "get_stats":
return self._get_stats()
else:
raise ValueError(f"Unknown tool: {tool_name}")
def _get_runtime_status(self) -> dict[str, Any]:
"""Get overall runtime state."""
running = (
getattr(self._orchestrator, "running", False)
if self._orchestrator
else False
)
uptime_seconds = time.time() - self._start_time
return {
"running": running,
"uptime_seconds": uptime_seconds,
}
async def _get_protocol_status(self) -> dict[str, Any]:
"""Get status of each protocol client."""
if not self._orchestrator:
return {
"lsp": {"connected": False, "name": "lsp-client"},
"mcp": {"connected": False, "name": "mcp-client"},
"acp": {"connected": False, "name": "acp-client"},
"abp": {"connected": False, "name": "abp-client"},
}
return {
"lsp": {
"connected": getattr(self._orchestrator.lsp, "connected", False),
"name": "lsp-client",
},
"mcp": {
"connected": getattr(self._orchestrator.mcp, "connected", False),
"name": "mcp-client",
},
"acp": {
"connected": getattr(self._orchestrator.acp, "connected", False),
"name": "acp-client",
},
"abp": {
"connected": getattr(self._orchestrator.abp, "connected", False),
"name": "abp-client",
},
}
async def _get_star_registry(self) -> dict[str, Any]:
"""Get list of registered stars."""
if not self._orchestrator:
return {"stars": []}
stars = await self._orchestrator.list_stars()
return {"stars": stars}
def _get_stats(self) -> dict[str, Any]:
"""Get message counts and metrics."""
result: dict[str, Any] = {
"uptime_seconds": time.time() - self._start_time,
}
if self._orchestrator:
result["total_messages"] = getattr(self._orchestrator, "_message_count", 0)
last_ts = getattr(self._orchestrator, "_last_activity_timestamp", None)
if last_ts is not None:
result["last_activity"] = datetime.fromtimestamp(
last_ts, tz=timezone.utc
).isoformat()
else:
result["last_activity"] = None
return result

View File

@@ -1,5 +0,0 @@
"""Internal tools module for AstrBot runtime."""
from .base import FunctionTool, ToolSet
__all__ = ["FunctionTool", "ToolSet"]

View File

@@ -1,332 +0,0 @@
"""Base tool classes for AstrBot internal runtime.
This module provides the FunctionTool base class used by MCP tools
in the new internal architecture.
"""
import copy
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator
from dataclasses import dataclass, field
from typing import Any
from pydantic import model_validator
ParametersType = dict[str, Any]
@dataclass
class ToolSchema:
"""A class representing the schema of a tool for function calling."""
name: str
"""The name of the tool."""
description: str
"""The description of the tool."""
parameters: ParametersType = field(default_factory=dict)
"""The parameters of the tool, in JSON Schema format."""
@model_validator(mode="after")
def validate_parameters(self) -> "ToolSchema":
"""Validate the parameters JSON schema."""
import jsonschema
jsonschema.validate(
self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
)
return self
@dataclass
class FunctionTool(ToolSchema):
"""A callable tool, for function calling."""
handler: Callable[..., Awaitable[str | None] | AsyncGenerator[Any, None]] | None = (
None
)
"""a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None
"""
The module path of the handler function. This is empty when the origin is mcp.
This field must be retained, as the handler will be wrapped in functools.partial during initialization,
causing the handler's __module__ to be functools
"""
active: bool = True
"""
Whether the tool is active. This field is a special field for AstrBot.
You can ignore it when integrating with other frameworks.
"""
is_background_task: bool = False
"""
Declare this tool as a background task. Background tasks return immediately
with a task identifier while the real work continues asynchronously.
"""
source: str = "mcp"
"""
Origin of this tool: 'plugin' (from star plugins), 'internal' (AstrBot built-in),
or 'mcp' (from MCP servers). Used by WebUI for display grouping.
"""
def __repr__(self) -> str:
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(self, **kwargs: Any) -> Any:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
class ToolSet:
"""
A collection of FunctionTools grouped under a namespace.
ToolSets allow organizing related tools together. The LLM sees tools
as "namespace/tool_name" when calling.
"""
def __init__(self, namespace: str, tools: list[FunctionTool] | None = None) -> None:
self.namespace = namespace
self._tools: dict[str, FunctionTool] = {}
if tools:
for tool in tools:
self.add(tool)
def add(self, tool: FunctionTool) -> None:
"""Add a tool to the set."""
self._tools[tool.name] = tool
def add_tool(self, tool: FunctionTool) -> None:
"""Add a tool to the set (alias for add())."""
self.add(tool)
def remove(self, name: str) -> FunctionTool | None:
"""Remove and return a tool by name."""
return self._tools.pop(name, None)
def remove_tool(self, name: str) -> None:
"""Remove a tool by its name."""
self._tools.pop(name, None)
def get(self, name: str) -> FunctionTool | None:
"""Get a tool by name."""
return self._tools.get(name)
def get_tool(self, name: str) -> FunctionTool | None:
"""Get a tool by name (alias for get)."""
return self.get(name)
def list_tools(self) -> list[FunctionTool]:
"""List all tools in this set."""
return list(self._tools.values())
def __iter__(self) -> Iterator[FunctionTool]:
return iter(self._tools.values())
def __len__(self) -> int:
return len(self._tools)
def __bool__(self) -> bool:
return bool(self._tools)
def __repr__(self) -> str:
return f"ToolSet(namespace={self.namespace!r}, tools={self.list_tools()!r})"
def __str__(self) -> str:
return f"ToolSet({self.namespace}, {len(self)} tools)"
def names(self) -> list[str]:
"""Get names of all tools in this set."""
return [tool.name for tool in self.tools]
def empty(self) -> bool:
"""Check if the tool set is empty."""
return len(self) == 0
def merge(self, other: "ToolSet") -> None:
"""Merge another ToolSet into this one."""
for tool in other.tools:
self.add(tool)
def normalize(self) -> None:
"""Sort tools by name for deterministic serialization."""
self._tools = dict(sorted(self._tools.items(), key=lambda x: x[0]))
def get_light_tool_set(self) -> "ToolSet":
"""Return a light tool set with only name/description."""
light_tools = []
for tool in self.tools:
if hasattr(tool, "active") and not tool.active:
continue
light_tools.append(
FunctionTool(
name=tool.name,
description=tool.description,
parameters={"type": "object", "properties": {}},
handler=None,
)
)
return ToolSet("default", light_tools)
def get_param_only_tool_set(self) -> "ToolSet":
"""Return a tool set with name/parameters only (no description)."""
param_tools = []
for tool in self.tools:
if hasattr(tool, "active") and not tool.active:
continue
params = (
copy.deepcopy(tool.parameters)
if tool.parameters
else {"type": "object", "properties": {}}
)
param_tools.append(
FunctionTool(
name=tool.name,
description="",
parameters=params,
handler=None,
)
)
return ToolSet("default", param_tools)
@property
def tools(self) -> list[FunctionTool]:
"""List all tools in this set."""
return list(self._tools.values())
def openai_schema(
self, omit_empty_parameter_field: bool = False
) -> list[dict[str, Any]]:
"""Convert tools to OpenAI API function calling schema format."""
result: list[dict[str, Any]] = []
for tool in self._tools.values():
func_def: dict[str, Any] = {
"type": "function",
"function": {"name": tool.name},
}
if tool.description:
func_def["function"]["description"] = tool.description
if tool.parameters is not None:
if (
tool.parameters.get("properties")
) or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
return result
def anthropic_schema(self) -> list[dict]:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema: dict[str, Any] = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def: dict[str, Any] = {"name": tool.name, "input_schema": input_schema}
if tool.description:
tool_def["description"] = tool.description
result.append(tool_def)
return result
def google_schema(self) -> dict:
"""Convert tools to Google GenAI API format."""
def convert_schema(schema: dict) -> dict:
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
origin_type = schema.get("type")
target_type = origin_type
if isinstance(origin_type, list):
target_type = next((t for t in origin_type if t != "null"), "string")
if target_type in supported_types:
result["type"] = target_type
if "format" in schema and schema["format"] in supported_formats.get(result["type"], set()):
result["format"] = schema["format"]
else:
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
if "additionalProperties" in prop_value:
del prop_value["additionalProperties"]
properties[key] = prop_value
if properties:
result["properties"] = properties
if target_type == "array":
items_schema = schema.get("items")
if isinstance(items_schema, dict):
result["items"] = convert_schema(items_schema)
else:
result["items"] = {"type": "string"}
return result
tools_list = []
for tool in self.tools:
d: dict[str, Any] = {"name": tool.name}
if tool.description:
d["description"] = tool.description
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools_list.append(d)
declarations: dict[str, Any] = {}
if tools_list:
declarations["function_declarations"] = tools_list
return declarations
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
"""Get tools in OpenAI function calling style (deprecated)."""
return self.openai_schema(omit_empty_parameter_field)
def get_func_desc_anthropic_style(self):
"""Get tools in Anthropic style (deprecated)."""
return self.anthropic_schema()
def get_func_desc_google_genai_style(self):
"""Get tools in Google GenAI style (deprecated)."""
return self.google_schema()

View File

@@ -1,48 +0,0 @@
"""
Builtin tools for AstrBot - re-exports from core.tools for backward compatibility.
This module re-exports the builtin tools (cron, send_message, kb_query) from
the deprecated core.tools module for backward compatibility.
TODO: These tools should be fully migrated to _internal and core.tools
should be removed once all consumers update their imports.
"""
from __future__ import annotations
# Re-export cron tools
from astrbot.core.tools.cron_tools import (
CREATE_CRON_JOB_TOOL,
DELETE_CRON_JOB_TOOL,
LIST_CRON_JOBS_TOOL,
CreateActiveCronTool,
DeleteCronJobTool,
ListCronJobsTool,
)
# Re-export knowledge_base_query tool
from astrbot.core.tools.kb_query import (
KNOWLEDGE_BASE_QUERY_TOOL,
KnowledgeBaseQueryTool,
)
# Re-export send_message tool
from astrbot.core.tools.send_message import (
SEND_MESSAGE_TO_USER_TOOL,
SendMessageToUserTool,
)
__all__ = [
# Cron tools
"CREATE_CRON_JOB_TOOL",
"DELETE_CRON_JOB_TOOL",
"KNOWLEDGE_BASE_QUERY_TOOL",
"LIST_CRON_JOBS_TOOL",
"SEND_MESSAGE_TO_USER_TOOL",
# Classes
"CreateActiveCronTool",
"DeleteCronJobTool",
"KnowledgeBaseQueryTool",
"ListCronJobsTool",
"SendMessageToUserTool",
]

View File

@@ -1,278 +0,0 @@
"""Tools registry for AstrBot internal runtime."""
from __future__ import annotations
from typing import Any
# Re-export from base
from astrbot._internal.tools.base import FunctionTool, ToolSet
__all__ = [
"DEFAULT_MCP_CONFIG",
"ENABLE_MCP_TIMEOUT_ENV",
"FuncCall",
"FunctionTool",
"FunctionToolManager",
"MCPAllServicesFailedError",
"MCPInitError",
"MCPInitSummary",
"MCPInitTimeoutError",
"MCPShutdownTimeoutError",
"ToolSet",
]
# MCP config constants (re-exported from protocols)
try:
from astrbot._internal.protocols.mcp import (
DEFAULT_MCP_CONFIG,
MCPAllServicesFailedError,
MCPInitError,
MCPInitSummary,
MCPInitTimeoutError,
MCPShutdownTimeoutError,
)
except ImportError:
DEFAULT_MCP_CONFIG: dict[str, Any] = {}
MCPAllServicesFailedError: type[Exception] = Exception
MCPInitError: type[Exception] = Exception
MCPInitSummary: type[dict] = dict
MCPInitTimeoutError: type[TimeoutError] = TimeoutError
MCPShutdownTimeoutError: type[TimeoutError] = TimeoutError
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_TIMEOUT_ENABLED"
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
class FunctionToolManager:
"""Central registry for all function tools."""
def __init__(self) -> None:
self._func_list: list[FunctionTool] = []
@property
def func_list(self) -> list[FunctionTool]:
"""Get the list of function tools."""
return self._func_list
@func_list.setter
def func_list(self, value: list[FunctionTool]) -> None:
"""Set the list of function tools."""
self._func_list = value
def add(self, tool: FunctionTool) -> None:
"""Add a tool to the registry."""
self._func_list.append(tool)
def remove(self, name: str) -> bool:
"""Remove a tool by name. Returns True if found."""
for i, f in enumerate(self._func_list):
if f.name == name:
self._func_list.pop(i)
return True
return False
def get_func(self, name: str) -> FunctionTool | None:
"""Get a tool by name. Returns the last active tool if multiple match."""
last_match: FunctionTool | None = None
for f in reversed(self._func_list):
if f.name == name:
if getattr(f, "active", True):
return f
if last_match is None:
last_match = f
return last_match
def get_full_tool_set(self) -> ToolSet:
"""Return a ToolSet with all active tools, deduplicated by name."""
seen: dict[str, FunctionTool] = {}
for tool in reversed(self._func_list):
if tool.name not in seen and getattr(tool, "active", True):
seen[tool.name] = tool
return ToolSet("default", list(seen.values()))
def register_internal_tools(self) -> None:
"""Register built-in computer tools (shell, python, browser, neo)."""
# Import here to avoid circular imports
from astrbot.core.computer.computer_tool_provider import get_all_tools
for tool in get_all_tools():
if self.get_func(tool.name) is None:
self.add(tool)
# MCP-related stub methods for base class compatibility
async def enable_mcp_server(
self, name: str, config: dict[str, Any], init_timeout: int = 30
) -> None:
"""Enable an MCP server (stub)."""
pass
async def disable_mcp_server(
self, name: str = "", timeout: int = 10, shutdown_timeout: int = 10
) -> None:
"""Disable an MCP server (stub)."""
pass
async def init_mcp_clients(self) -> None:
"""Initialize MCP clients (stub)."""
pass
async def test_mcp_server_connection(
self, config: dict[str, Any]
) -> tuple[bool, str]:
"""Test MCP server connection (stub)."""
return False, "Not implemented"
async def sync_modelscope_mcp_servers(self) -> None:
"""Sync ModelScope MCP servers (stub)."""
pass
def load_mcp_config(self) -> dict[str, Any]:
"""Load MCP configuration (stub)."""
return {"mcpServers": {}}
def save_mcp_config(self, config: dict[str, Any]) -> bool:
"""Save MCP configuration (stub)."""
return True
def activate_llm_tool(self, name: str) -> bool:
"""Activate an LLM tool (stub)."""
return True
def deactivate_llm_tool(self, name: str) -> bool:
"""Deactivate an LLM tool (stub)."""
return True
@property
def mcp_client_dict(self) -> dict[str, Any]:
"""Return dict of MCP clients (stub)."""
return {}
@property
def mcp_server_runtime_view(self) -> dict[str, Any]:
"""Return runtime view of MCP servers (stub)."""
return {}
class FuncCall(FunctionToolManager):
"""Alias for FunctionToolManager for backward compatibility."""
def __init__(self) -> None:
super().__init__()
self._mcp_server_runtime_view: dict[str, Any] = {}
self._mcp_client_dict: dict[str, Any] = {}
@property
def mcp_server_runtime_view(self) -> dict[str, Any]:
"""Return runtime view of MCP servers."""
return self._mcp_server_runtime_view
@property
def mcp_client_dict(self) -> dict[str, Any]:
"""Return dict of MCP clients (for backward compatibility)."""
return self._mcp_client_dict
async def init_mcp_clients(self) -> None:
"""Initialize MCP clients (stub implementation)."""
pass
def add_func(
self,
name: str,
func_args: list[dict[str, Any]],
desc: str,
handler: Any,
) -> None:
"""Add a function tool (deprecated, use add() instead)."""
params: dict[str, Any] = {
"type": "object",
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param.get("type", "string"),
"description": param.get("description", ""),
}
func = FunctionTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.add(func)
def remove_func(self, name: str) -> None:
"""Remove a function tool by name (deprecated, use remove() instead)."""
self.remove(name)
def get_func(self, name: str) -> FunctionTool | None:
"""Get a function tool by name."""
return super().get_func(name)
def names(self) -> list[str]:
"""Get all tool names."""
return [f.name for f in self.func_list]
def remove_tool(self, name: str) -> None:
"""Remove a tool by its name (alias for remove)."""
self.remove(name)
def get_func_desc_openai_style(
self, omit_empty_parameter_field: bool = False
) -> list[dict[str, Any]]:
"""Get tools in OpenAI style (deprecated, use get_full_tool_set().openai_schema())."""
tool_set = self.get_full_tool_set()
return tool_set.openai_schema(omit_empty_parameter_field)
async def enable_mcp_server(
self, name: str, config: dict[str, Any], init_timeout: int = 30
) -> None:
"""Enable an MCP server (stub implementation)."""
pass
async def disable_mcp_server(
self, name: str = "", timeout: int = 10, shutdown_timeout: int = 10
) -> None:
"""Disable an MCP server (stub implementation)."""
pass
def load_mcp_config(self) -> dict[str, Any]:
"""Load MCP configuration (stub implementation)."""
return {"mcpServers": {}}
def save_mcp_config(self, config: dict[str, Any]) -> bool:
"""Save MCP configuration (stub implementation)."""
return True
def activate_llm_tool(self, name: str) -> bool:
"""Activate an LLM tool (stub implementation)."""
return True
def deactivate_llm_tool(self, name: str) -> bool:
"""Deactivate an LLM tool (stub implementation)."""
return True
async def test_mcp_server_connection(
self, config: dict[str, Any]
) -> tuple[bool, str]:
"""Test MCP server connection (stub implementation)."""
# Import the actual test function if available
try:
from astrbot._internal.protocols.mcp.client import (
_quick_test_mcp_connection,
)
success, message = await _quick_test_mcp_connection(config)
if not success:
raise Exception(message)
return success, message
except Exception as e:
raise Exception(f"MCP connection test failed: {e!s}") from e
async def sync_modelscope_mcp_servers(self) -> None:
"""Sync ModelScope MCP servers (stub implementation)."""
pass
def get_full_tool_set(self) -> ToolSet:
"""Return a ToolSet with all active tools."""
return ToolSet("default", [t for t in self.func_list if t.active])

View File

@@ -1,64 +1,19 @@
"""
AstrBot Public API.
This package exposes the public interface for extending and integrating with
AstrBot. All exports from this module are guaranteed to be stable across
minor version updates.
Modules:
tools: Tool registration and management API
mcp: Model Context Protocol server and tool API
skills: Skill management and conversion API
"""
from astrbot import logger
# Tool API
from astrbot._internal.tools.base import FunctionTool, ToolSet
# MCP API
from astrbot.api.mcp import (
MCPClient,
MCPTool,
get_mcp_servers,
register_mcp_server,
unregister_mcp_server,
)
# Skills API
from astrbot.api.skills import (
SkillInfo,
SkillManager,
get_skill_manager,
skill_to_tool,
)
# Tools API (public interface)
from astrbot.api.tools import ToolRegistry, get_registry, tool
from astrbot.core import html_renderer, sp
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.register import register_agent as agent
from astrbot.core.star.register import register_llm_tool as llm_tool
__all__ = [
"AstrBotConfig",
"BaseFunctionToolExecutor",
"FunctionTool",
"MCPClient",
"MCPTool",
"SkillInfo",
"SkillManager",
"ToolRegistry",
"ToolSet",
"agent",
"get_mcp_servers",
"get_registry",
"get_skill_manager",
"html_renderer",
"llm_tool",
"logger",
"register_mcp_server",
"skill_to_tool",
"sp",
"tool",
"unregister_mcp_server",
]

View File

@@ -29,7 +29,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
)
from astrbot.core.star.register import (
register_star as register, # 注册插件(Star)
register_star as register, # 注册插件Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star.config import *

View File

@@ -55,14 +55,14 @@ __all__ = [
"on_decorating_result",
"on_llm_request",
"on_llm_response",
"on_llm_tool_respond",
"on_platform_loaded",
"on_plugin_error",
"on_plugin_loaded",
"on_plugin_unloaded",
"on_using_llm_tool",
"on_platform_loaded",
"on_waiting_llm_request",
"permission_type",
"platform_adapter_type",
"regex",
"on_using_llm_tool",
"on_llm_tool_respond",
]

View File

@@ -1,98 +0,0 @@
"""
MCP (Model Context Protocol) Public API for AstrBot.
This module provides a simple, stable interface for MCP server management,
delegating to the _internal package.
Example:
from astrbot.api.mcp import get_mcp_servers, register_mcp_server
# List connected servers
servers = get_mcp_servers()
# Register stdio MCP server
await register_mcp_server(
name="weather",
command="uv",
args=["tool", "run", "weather-mcp"],
)
# Register SSE server
await register_mcp_server(
name="fileserver",
url="http://localhost:8080/sse",
transport="sse",
)
"""
from __future__ import annotations
from typing import Any
# Import from _internal package (the canonical source)
# TODO: fix path - should be protocols.mcp.client
from astrbot._internal.protocols.mcp.client import McpClient as MCPClient
from astrbot._internal.protocols.mcp.tool import MCPTool
__all__ = [
"MCPClient",
"MCPTool",
"get_mcp_servers",
"register_mcp_server",
"unregister_mcp_server",
]
def get_mcp_servers() -> dict[str, MCPClient]:
"""Get all connected MCP servers."""
from astrbot.core.provider.register import llm_tools as func_tool_manager
manager = func_tool_manager
return dict(manager.mcp_client_dict)
async def register_mcp_server(
name: str,
command: str | None = None,
args: list[str] | None = None,
url: str | None = None,
transport: str | None = None,
**kwargs: Any,
) -> None:
"""Register and connect to an MCP server.
Args:
name: Unique name for this server
command: Command to run (for stdio transport)
args: Command arguments
url: URL (for SSE/Streamable HTTP transports)
transport: "sse", "streamable_http", or None for stdio
Example - Stdio:
await register_mcp_server(name="weather", command="uv",
args=["tool", "run", "weather-mcp"])
"""
from astrbot.core.provider.register import llm_tools as func_tool_manager
manager = func_tool_manager
config: dict[str, Any] = {}
if command is not None:
config["command"] = command
if args is not None:
config["args"] = args
if url is not None:
config["url"] = url
if transport is not None:
config["transport"] = transport
config.update(kwargs)
await manager.enable_mcp_server(name=name, config=config)
async def unregister_mcp_server(name: str) -> None:
"""Disconnect and remove an MCP server."""
from astrbot.core.provider.register import llm_tools as func_tool_manager
manager = func_tool_manager
await manager.disable_mcp_server(name=name)

View File

@@ -1,58 +0,0 @@
"""
Skills Public API for AstrBot.
This module provides a simple, stable interface for skill management,
delegating to the _internal package.
Two skill types:
1. Prompt-based: SKILL.md files injected into system prompt
2. Tool-based: Skills with input_schema converted to FunctionTool
Example:
from astrbot.api.skills import get_skill_manager, skill_to_tool
# List skills
mgr = get_skill_manager()
skills = mgr.list_skills()
# Convert tool-based skill to FunctionTool
tool_skills = [s for s in skills if s.input_schema]
if tool_skills:
func_tool = skill_to_tool(tool_skills[0])
"""
from __future__ import annotations
from astrbot._internal.tools.base import FunctionTool
# Import from _internal package (the canonical source)
# TODO: fix path - should be core.skills.skill_manager
from astrbot.core.skills.skill_manager import SkillInfo, SkillManager
__all__ = ["SkillInfo", "SkillManager", "get_skill_manager", "skill_to_tool"]
def get_skill_manager() -> SkillManager:
"""Get the global SkillManager instance."""
return SkillManager()
def skill_to_tool(skill: SkillInfo) -> FunctionTool | None:
"""Convert a tool-based skill (with input_schema) to a FunctionTool.
Args:
skill: A SkillInfo instance with an input_schema
Returns:
A FunctionTool, or None if the skill has no input_schema
"""
if not skill.input_schema:
return None
return FunctionTool(
name=f"skill_{skill.name}",
description=skill.description or f"Skill: {skill.name}",
parameters=skill.input_schema,
handler=None,
source="skill",
)

View File

@@ -1,7 +1,7 @@
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
from astrbot.core.star.register import (
register_star as register, # 注册插件(Star)
register_star as register, # 注册插件Star
)
__all__ = ["Context", "Star", "StarTools", "register"]

View File

@@ -1,120 +0,0 @@
"""
Tools Public API for AstrBot.
This module provides a simple, stable interface for tool registration
and management. All implementations are delegated to the _internal package.
Example:
from astrbot.api.tools import tool, get_registry
@tool(name="weather", description="Get weather", parameters={...})
async def get_weather(city: str) -> str:
return f"Weather in {city} is sunny"
registry = get_registry()
tools = registry.list_tools()
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from functools import wraps
from typing import Any
# Import from _internal package (the canonical source)
from astrbot._internal.tools.base import FunctionTool, ToolSet
from astrbot._internal.tools.registry import FunctionToolManager
__all__ = ["FunctionTool", "ToolRegistry", "ToolSet", "get_registry", "tool"]
class ToolRegistry:
"""Wrapper around FunctionToolManager for simplified tool registration.
This class provides a user-friendly interface for registering and
managing tools, delegating to the internal FunctionToolManager.
"""
_instance: ToolRegistry | None = None
def __init__(self) -> None:
# Import here to avoid circular imports
from astrbot.core.provider.register import llm_tools as func_tool_manager
self._manager: FunctionToolManager = func_tool_manager
@classmethod
def get_instance(cls) -> ToolRegistry:
"""Get the singleton ToolRegistry instance."""
if cls._instance is None:
cls._instance = cls()
return cls._instance
def register(self, tool: FunctionTool) -> None:
"""Register a FunctionTool."""
self._manager.func_list.append(tool)
def unregister(self, name: str) -> bool:
"""Unregister a tool by name. Returns True if found and removed."""
for i, f in enumerate(self._manager.func_list):
if f.name == name:
self._manager.func_list.pop(i)
return True
return False
def list_tools(self) -> list[FunctionTool]:
"""List all registered tools."""
return self._manager.func_list.copy()
def get_tool(self, name: str) -> FunctionTool | None:
"""Get a tool by name."""
return self._manager.get_func(name)
def get_registry() -> ToolRegistry:
"""Get the global ToolRegistry instance."""
return ToolRegistry.get_instance()
def tool(
name: str,
description: str,
parameters: dict[str, Any] | None = None,
) -> Callable[
[Callable[..., Awaitable[str | None]]], Callable[..., Awaitable[str | None]]
]:
"""Decorator to register an async function as a tool.
Args:
name: Tool name (used by LLM to invoke it)
description: What the tool does
parameters: JSON Schema for parameters (optional)
Example:
@tool(name="weather", description="Get weather for a city", parameters={...})
async def get_weather(city: str) -> str:
return f"The weather in {city} is sunny"
"""
if parameters is None:
parameters = {"type": "object", "properties": {}}
def decorator(
func: Callable[..., Awaitable[str | None]],
) -> Callable[..., Awaitable[str | None]]:
func_tool = FunctionTool(
name=name,
description=description,
parameters=parameters,
handler=func,
handler_module_path=getattr(func, "__module__", ""),
source="api",
)
get_registry().register(func_tool)
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> str | None:
return await func(*args, **kwargs)
return wrapper
return decorator

View File

@@ -76,7 +76,7 @@ class LongTermMemory:
if not provider:
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
if not isinstance(provider, Provider):
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
raise Exception(f"提供商类型错误({type(provider)})无法获取图片描述")
response = await provider.text_chat(
prompt=image_caption_prompt,
session_id=uuid.uuid4().hex,
@@ -149,7 +149,7 @@ class LongTermMemory:
self.session_chats[event.unified_msg_origin].pop(0)
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
"""当触发 LLM 请求前,调用此方法修改 req"""
"""当触发 LLM 请求前调用此方法修改 req"""
if event.unified_msg_origin not in self.session_chats:
return
@@ -164,7 +164,7 @@ class LongTermMemory:
"Please react to it. Only output your response and do not output any other information. "
"You MUST use the SAME language as the chatroom is using."
)
req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中
req.contexts = [] # 清空上下文当使用了主动回复所有聊天记录都在一个prompt中
else:
req.system_prompt += (
"You are now in a chatroom. The chat history is as follows: \n"

View File

@@ -50,7 +50,7 @@ class Main(star.Star):
"""主动回复"""
provider = self.context.get_using_provider(event.unified_msg_origin)
if not provider:
logger.error("未找到任何 LLM 提供商请先配置无法主动回复")
logger.error("未找到任何 LLM 提供商请先配置无法主动回复")
return
try:
conv = None
@@ -60,7 +60,7 @@ class Main(star.Star):
if not session_curr_cid:
logger.error(
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话",
"当前未处于对话状态无法主动回复请确保 平台设置->会话隔离(unique_session) 未开启并使用 /switch 序号 切换或者 /new 创建一个会话",
)
return
@@ -72,7 +72,7 @@ class Main(star.Star):
prompt = event.message_str
if not conv:
logger.error("未找到对话,无法主动回复")
logger.error("未找到对话无法主动回复")
return
yield event.request_llm(
@@ -88,7 +88,7 @@ class Main(star.Star):
async def decorate_llm_req(
self, event: AstrMessageEvent, req: ProviderRequest
) -> None:
"""在请求 LLM 前注入人格信息Identifier时间回复内容等 System Prompt"""
"""在请求 LLM 前注入人格信息Identifier时间回复内容等 System Prompt"""
if self.ltm and self.ltm_enabled(event):
try:
await self.ltm.on_req_llm(event, req)

View File

@@ -9,56 +9,56 @@ class AdminCommands:
self.context = context
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员op <admin_id>"""
"""授权管理员op <admin_id>"""
if not admin_id:
event.set_result(
MessageEventResult().message(
"使用方法: /op <id> 授权管理员;/deop <id> 取消管理员可通过 /sid 获取 ID",
"使用方法: /op <id> 授权管理员/deop <id> 取消管理员可通过 /sid 获取 ID",
),
)
return
self.context.get_config()["admins_id"].append(str(admin_id))
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("授权成功"))
event.set_result(MessageEventResult().message("授权成功"))
async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""取消授权管理员deop <admin_id>"""
"""取消授权管理员deop <admin_id>"""
if not admin_id:
event.set_result(
MessageEventResult().message(
"使用方法: /deop <id> 取消管理员可通过 /sid 获取 ID",
"使用方法: /deop <id> 取消管理员可通过 /sid 获取 ID",
),
)
return
try:
self.context.get_config()["admins_id"].remove(str(admin_id))
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("取消授权成功"))
event.set_result(MessageEventResult().message("取消授权成功"))
except ValueError:
event.set_result(
MessageEventResult().message("此用户 ID 不在管理员名单内"),
MessageEventResult().message("此用户 ID 不在管理员名单内"),
)
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单wl <sid>"""
"""添加白名单wl <sid>"""
if not sid:
event.set_result(
MessageEventResult().message(
"使用方法: /wl <id> 添加白名单;/dwl <id> 删除白名单可通过 /sid 获取 ID",
"使用方法: /wl <id> 添加白名单/dwl <id> 删除白名单可通过 /sid 获取 ID",
),
)
return
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg["platform_settings"]["id_whitelist"].append(str(sid))
cfg.save_config()
event.set_result(MessageEventResult().message("添加白名单成功"))
event.set_result(MessageEventResult().message("添加白名单成功"))
async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""删除白名单dwl <sid>"""
"""删除白名单dwl <sid>"""
if not sid:
event.set_result(
MessageEventResult().message(
"使用方法: /dwl <id> 删除白名单可通过 /sid 获取 ID",
"使用方法: /dwl <id> 删除白名单可通过 /sid 获取 ID",
),
)
return
@@ -66,12 +66,12 @@ class AdminCommands:
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg["platform_settings"]["id_whitelist"].remove(str(sid))
cfg.save_config()
event.set_result(MessageEventResult().message("删除白名单成功"))
event.set_result(MessageEventResult().message("删除白名单成功"))
except ValueError:
event.set_result(MessageEventResult().message("此 SID 不在白名单内"))
event.set_result(MessageEventResult().message("此 SID 不在白名单内"))
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""更新管理面板"""
await event.send(MessageChain().message("正在尝试更新管理面板..."))
await download_dashboard(version=f"v{VERSION}", latest=False)
await event.send(MessageChain().message("管理面板更新完成"))
await event.send(MessageChain().message("管理面板更新完成"))

View File

@@ -18,7 +18,7 @@ class AlterCmdCommands(CommandParserMixin):
"""更新reset命令在特定场景下的权限设置"""
from astrbot.api import sp
alter_cmd_cfg = await sp.global_get("alter_cmd", {}) or {}
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_cfg = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_cfg.get("reset", {})
reset_cfg[scene_key] = perm_type
@@ -31,7 +31,7 @@ class AlterCmdCommands(CommandParserMixin):
if token.len < 3:
await event.send(
MessageChain().message(
"该指令用于设置指令或指令组的权限\n"
"该指令用于设置指令或指令组的权限\n"
"格式: /alter_cmd <cmd_name> <admin/member>\n"
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
@@ -47,7 +47,7 @@ class AlterCmdCommands(CommandParserMixin):
if cmd_name == "reset" and cmd_type == "config":
from astrbot.api import sp
alter_cmd_cfg = await sp.global_get("alter_cmd", {}) or {}
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_.get("reset", {})
@@ -56,11 +56,11 @@ class AlterCmdCommands(CommandParserMixin):
private = reset_cfg.get("private", "member")
config_menu = f"""reset命令权限细粒度配置
当前配置:
当前配置
1. 群聊+会话隔离开: {group_unique_on}
2. 群聊+会话隔离关: {group_unique_off}
3. 私聊: {private}
修改指令格式:
修改指令格式
/alter_cmd reset scene <场景编号> <admin/member>
例如: /alter_cmd reset scene 2 member"""
await event.send(MessageChain().message(config_menu))
@@ -82,7 +82,7 @@ class AlterCmdCommands(CommandParserMixin):
if perm_type not in ["admin", "member"]:
await event.send(
MessageChain().message("权限类型错误,只能是 admin 或 member"),
MessageChain().message("权限类型错误只能是 admin 或 member"),
)
return
@@ -101,7 +101,7 @@ class AlterCmdCommands(CommandParserMixin):
if cmd_type not in ["admin", "member"]:
await event.send(
MessageChain().message("指令类型错误,可选类型有 admin, member"),
MessageChain().message("指令类型错误可选类型有 admin, member"),
)
return
@@ -131,7 +131,7 @@ class AlterCmdCommands(CommandParserMixin):
from astrbot.api import sp
alter_cmd_cfg = await sp.global_get("alter_cmd", {}) or {}
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
cfg = plugin_.get(found_command.handler_name, {})
cfg["permission"] = cmd_type
@@ -168,6 +168,6 @@ class AlterCmdCommands(CommandParserMixin):
cmd_group_str = "指令组" if cmd_group else "指令"
await event.send(
MessageChain().message(
f"已将{cmd_name}{cmd_group_str} 的权限级别调整为 {cmd_type}",
f"已将{cmd_name}{cmd_group_str} 的权限级别调整为 {cmd_type}",
),
)

View File

@@ -48,7 +48,7 @@ class ConversationCommands:
scene = RstScene.get_scene(is_group, is_unique_session)
alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {}) or {}
alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {})
plugin_config = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_config.get("reset", {})
@@ -60,8 +60,8 @@ class ConversationCommands:
if required_perm == "admin" and message.role != "admin":
message.set_result(
MessageEventResult().message(
f"{scene.name}场景下,reset命令需要管理员权限,"
f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作",
f"{scene.name}场景下reset命令需要管理员权限"
f"您 (ID {message.get_sender_id()}) 不是管理员无法执行此操作",
),
)
return
@@ -74,12 +74,12 @@ class ConversationCommands:
scope_id=umo,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("重置对话成功"))
message.set_result(MessageEventResult().message("重置对话成功"))
return
if not self.context.get_using_provider(umo):
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商请先配置"),
MessageEventResult().message("未找到任何 LLM 提供商请先配置"),
)
return
@@ -88,7 +88,7 @@ class ConversationCommands:
if not cid:
message.set_result(
MessageEventResult().message(
"当前未处于对话状态,请 /switch 切换或者 /new 创建",
"当前未处于对话状态请 /switch 切换或者 /new 创建",
),
)
return
@@ -101,7 +101,7 @@ class ConversationCommands:
[],
)
ret = "清除聊天历史成功!"
ret = "清除聊天历史成功"
message.set_extra("_clean_ltm_session", True)
@@ -124,18 +124,18 @@ class ConversationCommands:
if stopped_count > 0:
message.set_result(
MessageEventResult().message(
f"已请求停止 {stopped_count} 个运行中的任务"
f"已请求停止 {stopped_count} 个运行中的任务"
)
)
return
message.set_result(MessageEventResult().message("当前会话没有运行中的任务"))
message.set_result(MessageEventResult().message("当前会话没有运行中的任务"))
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
if not self.context.get_using_provider(message.unified_msg_origin):
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商请先配置"),
MessageEventResult().message("未找到任何 LLM 提供商请先配置"),
)
return
@@ -166,7 +166,7 @@ class ConversationCommands:
history = "".join(parts)
ret = (
f"当前对话历史记录:"
f"当前对话历史记录"
f"{history or '无历史记录'}\n\n"
f"{page} 页 | 共 {total_pages}\n"
f"*输入 /history 2 跳转到第 2 页"
@@ -181,7 +181,7 @@ class ConversationCommands:
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
message.set_result(
MessageEventResult().message(
f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持",
f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持",
),
)
return
@@ -200,7 +200,7 @@ class ConversationCommands:
end_idx = start_idx + size_per_page
conversations_paged = conversations_all[start_idx:end_idx]
parts = ["对话列表:\n---\n"]
parts = ["对话列表\n---\n"]
"""全局序号从当前页的第一个开始"""
global_index = start_idx + 1
@@ -277,7 +277,7 @@ class ConversationCommands:
scope_id=message.unified_msg_origin,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("已创建新对话"))
message.set_result(MessageEventResult().message("已创建新对话"))
return
active_event_registry.stop_all(message.unified_msg_origin, exclude=message)
@@ -291,7 +291,7 @@ class ConversationCommands:
message.set_extra("_clean_ltm_session", True)
message.set_result(
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})"),
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})"),
)
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None:
@@ -313,12 +313,12 @@ class ConversationCommands:
)
message.set_result(
MessageEventResult().message(
f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})",
f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})",
),
)
else:
message.set_result(
MessageEventResult().message("请输入群聊 ID/groupnew 群聊ID"),
MessageEventResult().message("请输入群聊 ID/groupnew 群聊ID"),
)
async def switch_conv(
@@ -329,14 +329,14 @@ class ConversationCommands:
"""通过 /ls 前面的序号切换对话"""
if not isinstance(index, int):
message.set_result(
MessageEventResult().message("类型错误,请输入数字对话序号"),
MessageEventResult().message("类型错误请输入数字对话序号"),
)
return
if index is None:
message.set_result(
MessageEventResult().message(
"请输入对话序号/switch 对话序号/ls 查看对话 /new 新建对话",
"请输入对话序号/switch 对话序号/ls 查看对话 /new 新建对话",
),
)
return
@@ -345,7 +345,7 @@ class ConversationCommands:
)
if index > len(conversations) or index < 1:
message.set_result(
MessageEventResult().message("对话序号错误,请使用 /ls 查看"),
MessageEventResult().message("对话序号错误请使用 /ls 查看"),
)
else:
conversation = conversations[index - 1]
@@ -356,20 +356,20 @@ class ConversationCommands:
)
message.set_result(
MessageEventResult().message(
f"切换到对话: {title}({conversation.cid[:4]})",
f"切换到对话: {title}({conversation.cid[:4]})",
),
)
async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None:
"""重命名对话"""
if not new_name:
message.set_result(MessageEventResult().message("请输入新的对话名称"))
message.set_result(MessageEventResult().message("请输入新的对话名称"))
return
await self.context.conversation_manager.update_conversation_title(
message.unified_msg_origin,
new_name,
)
message.set_result(MessageEventResult().message("重命名对话成功"))
message.set_result(MessageEventResult().message("重命名对话成功"))
async def del_conv(self, message: AstrMessageEvent) -> None:
"""删除当前对话"""
@@ -377,10 +377,10 @@ class ConversationCommands:
cfg = self.context.get_config(umo=umo)
is_unique_session = cfg["platform_settings"]["unique_session"]
if message.get_group_id() and not is_unique_session and message.role != "admin":
# 群聊,没开独立会话,发送人不是管理员
# 群聊没开独立会话发送人不是管理员
message.set_result(
MessageEventResult().message(
f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话",
f"会话处于群聊并且未开启独立会话并且您 (ID {message.get_sender_id()}) 不是管理员因此没有权限删除当前对话",
),
)
return
@@ -393,7 +393,7 @@ class ConversationCommands:
scope_id=umo,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("重置对话成功"))
message.set_result(MessageEventResult().message("重置对话成功"))
return
session_curr_cid = (
@@ -403,7 +403,7 @@ class ConversationCommands:
if not session_curr_cid:
message.set_result(
MessageEventResult().message(
"当前未处于对话状态,请 /switch 序号 切换或 /new 创建",
"当前未处于对话状态请 /switch 序号 切换或 /new 创建",
),
)
return
@@ -415,6 +415,6 @@ class ConversationCommands:
session_curr_cid,
)
ret = "删除当前对话成功不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建"
ret = "删除当前对话成功不再处于对话状态使用 /switch 序号 切换到其他对话或 /new 创建"
message.set_extra("_clean_ltm_session", True)
message.set_result(MessageEventResult().message(ret))

View File

@@ -24,7 +24,7 @@ class HelpCommand:
async def _build_reserved_command_lines(self) -> list[str]:
"""
使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致
使用实时指令配置生成内置指令清单确保重命名/禁用后与实际生效状态保持一致
"""
try:
commands = await command_management.list_commands()

View File

@@ -17,4 +17,4 @@ class LLMCommands:
cfg["provider_settings"]["enable"] = True
status = "开启"
cfg.save_config()
await event.send(MessageChain().message(f"{status} LLM 聊天功能"))
await event.send(MessageChain().message(f"{status} LLM 聊天功能"))

View File

@@ -18,10 +18,10 @@ class PersonaCommands:
all_personas: list["Persona"],
depth: int = 0,
) -> list[str]:
"""递归构建树状输出,使用短线条表示层级"""
"""递归构建树状输出使用短线条表示层级"""
lines: list[str] = []
# 使用短线条作为缩进前缀,每层只用 "" 加一个空格
prefix = " " * depth
# 使用短线条作为缩进前缀每层只用 "" 加一个空格
prefix = " " * depth
for folder in folder_tree:
# 输出文件夹
@@ -31,7 +31,7 @@ class PersonaCommands:
folder_personas = [
p for p in all_personas if p.folder_id == folder["folder_id"]
]
child_prefix = " " * (depth + 1)
child_prefix = " " * (depth + 1)
# 输出该文件夹下的人格
for persona in folder_personas:
@@ -71,7 +71,7 @@ class PersonaCommands:
if conv is None:
message.set_result(
MessageEventResult().message(
"当前对话不存在,请先使用 /new 新建一个对话",
"当前对话不存在请先使用 /new 新建一个对话",
),
)
return
@@ -127,16 +127,16 @@ class PersonaCommands:
folder_tree = await self.context.persona_manager.get_folder_tree()
all_personas = self.context.persona_manager.personas
lines = ["📂 人格列表:\n"]
lines = ["📂 人格列表\n"]
# 构建树状输出
tree_lines = self._build_tree_output(folder_tree, all_personas)
lines.extend(tree_lines)
# 输出根目录下的人格(没有文件夹的)
# 输出根目录下的人格没有文件夹的
root_personas = [p for p in all_personas if p.folder_id is None]
if root_personas:
if tree_lines: # 如果有文件夹内容,加个空行
if tree_lines: # 如果有文件夹内容加个空行
lines.append("")
for persona in root_personas:
lines.append(f"👤 {persona.persona_id}")
@@ -161,7 +161,7 @@ class PersonaCommands:
),
None,
):
msg = f"人格{ps}的详细信息:\n"
msg = f"人格{ps}的详细信息\n"
msg += f"{persona['prompt']}\n"
else:
msg = f"人格{ps}不存在"
@@ -169,20 +169,20 @@ class PersonaCommands:
elif parts[1] == "unset":
if not cid:
message.set_result(
MessageEventResult().message("当前没有对话,无法取消人格"),
MessageEventResult().message("当前没有对话无法取消人格"),
)
return
await self.context.conversation_manager.update_conversation_persona_id(
message.unified_msg_origin,
"[%None]",
)
message.set_result(MessageEventResult().message("取消人格成功"))
message.set_result(MessageEventResult().message("取消人格成功"))
else:
ps = "".join(parts[1:]).strip()
if not cid:
message.set_result(
MessageEventResult().message(
"当前没有对话,请先开始对话或使用 /new 创建一个对话",
"当前没有对话请先开始对话或使用 /new 创建一个对话",
),
)
return
@@ -199,16 +199,18 @@ class PersonaCommands:
)
force_warn_msg = ""
if force_applied_persona_id:
force_warn_msg = "提醒:由于自定义规则,您现在切换的人格将不会生效。"
force_warn_msg = (
"提醒:由于自定义规则,您现在切换的人格将不会生效。"
)
message.set_result(
MessageEventResult().message(
f"设置成功如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格{force_warn_msg}",
f"设置成功如果您正在切换到不同的人格请注意使用 /reset 来清空上下文防止原人格对话影响现人格{force_warn_msg}",
),
)
else:
message.set_result(
MessageEventResult().message(
"不存在该人格情景使用 /persona list 查看所有",
"不存在该人格情景使用 /persona list 查看所有",
),
)

View File

@@ -11,8 +11,8 @@ class PluginCommands:
self.context = context
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表"""
parts = ["已加载的插件:\n"]
"""获取已经安装的插件列表"""
parts = ["已加载的插件\n"]
for plugin in self.context.get_all_stars():
line = f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
if not plugin.activated:
@@ -20,11 +20,11 @@ class PluginCommands:
parts.append(line + "\n")
if len(parts) == 1:
plugin_list_info = "没有加载任何插件"
plugin_list_info = "没有加载任何插件"
else:
plugin_list_info = "".join(parts)
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令\n使用 /plugin on/off <插件名> 启用或者禁用插件"
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令\n使用 /plugin on/off <插件名> 启用或者禁用插件"
event.set_result(
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False),
)
@@ -32,51 +32,51 @@ class PluginCommands:
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法禁用插件"))
event.set_result(MessageEventResult().message("演示模式下无法禁用插件"))
return
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin off <插件名> 禁用插件"),
MessageEventResult().message("/plugin off <插件名> 禁用插件"),
)
return
if self.context._star_manager is None:
event.set_result(MessageEventResult().message("插件管理器未初始化"))
event.set_result(MessageEventResult().message("插件管理器未初始化"))
return
await self.context._star_manager.turn_off_plugin(plugin_name)
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用"))
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用"))
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法启用插件"))
event.set_result(MessageEventResult().message("演示模式下无法启用插件"))
return
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin on <插件名> 启用插件"),
MessageEventResult().message("/plugin on <插件名> 启用插件"),
)
return
if self.context._star_manager is None:
event.set_result(MessageEventResult().message("插件管理器未初始化"))
event.set_result(MessageEventResult().message("插件管理器未初始化"))
return
await self.context._star_manager.turn_on_plugin(plugin_name)
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用"))
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用"))
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法安装插件"))
event.set_result(MessageEventResult().message("演示模式下无法安装插件"))
return
if not plugin_repo:
event.set_result(
MessageEventResult().message("/plugin get <插件仓库地址> 安装插件"),
)
return
logger.info(f"准备从 {plugin_repo} 安装插件")
logger.info(f"准备从 {plugin_repo} 安装插件")
if self.context._star_manager:
star_mgr = self.context._star_manager
try:
await star_mgr.install_plugin(plugin_repo)
event.set_result(MessageEventResult().message("安装插件成功"))
event.set_result(MessageEventResult().message("安装插件成功"))
except Exception as e:
logger.error(f"安装插件失败: {e}")
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
@@ -86,12 +86,12 @@ class PluginCommands:
"""获取插件帮助"""
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin help <插件名> 查看插件信息"),
MessageEventResult().message("/plugin help <插件名> 查看插件信息"),
)
return
plugin = self.context.get_registered_star(plugin_name)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件"))
event.set_result(MessageEventResult().message("未找到此插件"))
return
help_msg = ""
help_msg += f"\n\n✨ 作者: {plugin.author}\n✨ 版本: {plugin.version}"
@@ -111,15 +111,15 @@ class PluginCommands:
command_names.append(filter_.group_name)
if len(command_handlers) > 0:
parts = ["\n\n🔧 指令列表:\n"]
parts = ["\n\n🔧 指令列表\n"]
for i in range(len(command_handlers)):
line = f"- {command_names[i]}"
if command_handlers[i].desc:
line += f": {command_handlers[i].desc}"
parts.append(line + "\n")
parts.append("\nTip: 指令的触发需要添加唤醒前缀,默认为 /")
parts.append("\nTip: 指令的触发需要添加唤醒前缀默认为 /")
help_msg += "".join(parts)
ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg
ret += "更多帮助信息请查看插件仓库 README"
ret = f"🧩 插件 {plugin_name} 帮助信息\n" + help_msg
ret += "更多帮助信息请查看插件仓库 README"
event.set_result(MessageEventResult().message(ret).use_t2i(False))

View File

@@ -127,7 +127,7 @@ class ProviderCommands:
return self.context.get_config(umo).get("provider_settings", {}) or {}
except Exception as e:
logger.debug(
"读取 provider_settings 失败,使用默认值: %s",
"读取 provider_settings 失败使用默认值: %s",
safe_error("", e),
)
return {}
@@ -142,7 +142,7 @@ class ProviderCommands:
return max(float(raw), 0.0)
except Exception as e:
logger.debug(
"读取 %s 失败,回退默认值 %r: %s",
"读取 %s 失败回退默认值 %r: %s",
MODEL_LIST_CACHE_TTL_KEY,
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
safe_error("", e),
@@ -159,7 +159,7 @@ class ProviderCommands:
value = int(raw)
except Exception as e:
logger.debug(
"读取 %s 失败,回退默认值 %r: %s",
"读取 %s 失败回退默认值 %r: %s",
MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
safe_error("", e),
@@ -209,7 +209,7 @@ class ProviderCommands:
) -> str:
prov.set_model(model_name)
self.invalidate_provider_models_cache(prov.meta().id, umo=umo)
return f"切换模型成功当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]"
return f"切换模型成功当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]"
async def _get_provider_models(
self,
@@ -265,7 +265,7 @@ class ProviderCommands:
err_code: str,
err_reason: str,
) -> None:
"""记录不可达原因到日志"""
"""记录不可达原因到日志"""
meta = provider.meta()
logger.warning(
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
@@ -358,7 +358,7 @@ class ProviderCommands:
provider_id for provider_id, _ in failed_provider_errors
)
logger.error(
"跨提供商查找模型 %s,所有 %d 个提供商的 get_models() 均失败: %s请检查配置或网络",
"跨提供商查找模型 %s所有 %d 个提供商的 get_models() 均失败: %s请检查配置或网络",
model_name,
len(all_providers),
failed_ids,
@@ -405,7 +405,7 @@ class ProviderCommands:
if all_providers:
await event.send(
MessageEventResult().message(
"正在进行提供商可达性测试,请稍候..."
"正在进行提供商可达性测试请稍候..."
)
)
check_results = await asyncio.gather(
@@ -426,7 +426,7 @@ class ProviderCommands:
if isinstance(reachable, asyncio.CancelledError):
raise reachable
if isinstance(reachable, Exception):
# 异常情况下兜底处理,避免单个 provider 导致列表失败
# 异常情况下兜底处理避免单个 provider 导致列表失败
self._log_reachability_failure(
p,
None,
@@ -501,23 +501,23 @@ class ProviderCommands:
line += " (当前使用)"
parts.append(line + "\n")
parts.append("\n使用 /provider <序号> 切换 LLM 提供商")
parts.append("\n使用 /provider <序号> 切换 LLM 提供商")
ret = "".join(parts)
if ttss:
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商"
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商"
if stts:
ret += "\n使用 /provider stt <序号> 切换 STT 提供商"
ret += "\n使用 /provider stt <序号> 切换 STT 提供商"
if not reachability_check_enabled:
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启"
ret += "\n已跳过提供商可达性检测如需检测请在配置文件中开启"
event.set_result(MessageEventResult().message(ret))
elif idx == "tts":
if idx2 is None:
event.set_result(MessageEventResult().message("请输入序号"))
event.set_result(MessageEventResult().message("请输入序号"))
return
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
event.set_result(MessageEventResult().message("无效的提供商序号"))
event.set_result(MessageEventResult().message("无效的提供商序号"))
return
provider = self.context.get_all_tts_providers()[idx2 - 1]
id_ = provider.meta().id
@@ -526,13 +526,13 @@ class ProviderCommands:
provider_type=ProviderType.TEXT_TO_SPEECH,
umo=umo,
)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
elif idx == "stt":
if idx2 is None:
event.set_result(MessageEventResult().message("请输入序号"))
event.set_result(MessageEventResult().message("请输入序号"))
return
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
event.set_result(MessageEventResult().message("无效的提供商序号"))
event.set_result(MessageEventResult().message("无效的提供商序号"))
return
provider = self.context.get_all_stt_providers()[idx2 - 1]
id_ = provider.meta().id
@@ -541,10 +541,10 @@ class ProviderCommands:
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
elif isinstance(idx, int):
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(MessageEventResult().message("无效的提供商序号"))
event.set_result(MessageEventResult().message("无效的提供商序号"))
return
provider = self.context.get_all_providers()[idx - 1]
id_ = provider.meta().id
@@ -553,16 +553,16 @@ class ProviderCommands:
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
else:
event.set_result(MessageEventResult().message("无效的参数"))
event.set_result(MessageEventResult().message("无效的参数"))
async def _switch_model_by_name(
self, message: AstrMessageEvent, model_name: str, prov: Provider
) -> None:
model_name = model_name.strip()
if not model_name:
message.set_result(MessageEventResult().message("模型名不能为空"))
message.set_result(MessageEventResult().message("模型名不能为空"))
return
umo = message.unified_msg_origin
@@ -574,7 +574,7 @@ class ProviderCommands:
prov,
config,
error_prefix="获取当前提供商模型列表失败: ",
warning_log="获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s",
warning_log="获取当前提供商 %s 模型列表失败停止跨提供商查找: %s",
)
if models is None:
return
@@ -597,7 +597,7 @@ class ProviderCommands:
if target_prov is None or matched_target_model_name is None:
message.set_result(
MessageEventResult().message(
f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试",
f"模型 [{model_name}] 未在任何已配置的提供商中找到或所有提供商模型列表获取失败请检查配置或网络后重试",
),
)
return
@@ -612,7 +612,7 @@ class ProviderCommands:
self._apply_model(target_prov, matched_target_model_name, umo=umo)
message.set_result(
MessageEventResult().message(
f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型",
f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}]已自动切换提供商并设置模型",
),
)
except asyncio.CancelledError:
@@ -633,7 +633,7 @@ class ProviderCommands:
prov = self.context.get_using_provider(message.unified_msg_origin)
if not prov:
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商请先配置"),
MessageEventResult().message("未找到任何 LLM 提供商请先配置"),
)
return
config = self._get_model_lookup_config(message.unified_msg_origin)
@@ -655,7 +655,7 @@ class ProviderCommands:
curr_model = prov.get_model() or ""
parts.append(f"\n当前模型: [{curr_model}]")
parts.append(
"\nTips: 使用 /model <模型名/编号> 切换模型输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换"
"\nTips: 使用 /model <模型名/编号> 切换模型输入模型名时可自动跨提供商查找并切换跨提供商也可使用 /provider 切换"
)
ret = "".join(parts)
@@ -670,7 +670,7 @@ class ProviderCommands:
if models is None:
return
if idx_or_name > len(models) or idx_or_name < 1:
message.set_result(MessageEventResult().message("模型序号错误"))
message.set_result(MessageEventResult().message("模型序号错误"))
else:
try:
new_model = models[idx_or_name - 1]
@@ -697,7 +697,7 @@ class ProviderCommands:
prov = self.context.get_using_provider(message.unified_msg_origin)
if not prov:
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商请先配置"),
MessageEventResult().message("未找到任何 LLM 提供商请先配置"),
)
return
@@ -710,14 +710,14 @@ class ProviderCommands:
parts.append(f"\n当前 Key: {curr_key[:8]}")
parts.append("\n当前模型: " + prov.get_model())
parts.append("\n使用 /key <idx> 切换 Key")
parts.append("\n使用 /key <idx> 切换 Key")
ret = "".join(parts)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
else:
keys_data = prov.get_keys()
if index > len(keys_data) or index < 1:
message.set_result(MessageEventResult().message("Key 序号错误"))
message.set_result(MessageEventResult().message("Key 序号错误"))
else:
try:
new_key = keys_data[index - 1]
@@ -726,7 +726,7 @@ class ProviderCommands:
prov.meta().id,
umo=message.unified_msg_origin,
)
message.set_result(MessageEventResult().message("切换 Key 成功"))
message.set_result(MessageEventResult().message("切换 Key 成功"))
except Exception as e:
message.set_result(
MessageEventResult().message(

View File

@@ -9,28 +9,28 @@ class SetUnsetCommands:
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
"""设置会话变量"""
uid = event.unified_msg_origin
session_var = await sp.session_get(uid, "session_variables", {}) or {}
session_var = await sp.session_get(uid, "session_variables", {})
session_var[key] = value
await sp.session_put(uid, "session_variables", session_var)
event.set_result(
MessageEventResult().message(
f"会话 {uid} 变量 {key} 存储成功使用 /unset 移除",
f"会话 {uid} 变量 {key} 存储成功使用 /unset 移除",
),
)
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
"""移除会话变量"""
uid = event.unified_msg_origin
session_var = await sp.session_get(uid, "session_variables", {}) or {}
session_var = await sp.session_get(uid, "session_variables", {})
if key not in session_var:
event.set_result(
MessageEventResult().message("没有那个变量名格式 /unset 变量名"),
MessageEventResult().message("没有那个变量名格式 /unset 变量名"),
)
else:
del session_var[key]
await sp.session_put(uid, "session_variables", session_var)
event.set_result(
MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功"),
MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功"),
)

View File

@@ -18,19 +18,19 @@ class SIDCommand:
umo_msg_type = event.session.message_type.value
umo_session_id = event.session.session_id
ret = (
f"UMO: {sid} 此值可用于设置白名单\n"
f"UID: {user_id} 此值可用于设置管理员\n"
f"UMO: {sid} 此值可用于设置白名单\n"
f"UID: {user_id} 此值可用于设置管理员\n"
f"消息会话来源信息:\n"
f" 机器人 ID: {umo_platform}\n"
f" 消息类型: {umo_msg_type}\n"
f" 会话 ID: {umo_session_id}\n"
f"消息来源可用于配置机器人的配置文件路由"
f" 机器人 ID: {umo_platform}\n"
f" 消息类型: {umo_msg_type}\n"
f" 会话 ID: {umo_session_id}\n"
f"消息来源可用于配置机器人的配置文件路由"
)
if (
self.context.get_config()["platform_settings"]["unique_session"]
and event.get_group_id()
):
ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊"
ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊"
event.set_result(MessageEventResult().message(ret).use_t2i(False))

View File

@@ -16,8 +16,8 @@ class T2ICommand:
if config["t2i"]:
config["t2i"] = False
config.save_config()
event.set_result(MessageEventResult().message("已关闭文本转图片模式"))
event.set_result(MessageEventResult().message("已关闭文本转图片模式"))
return
config["t2i"] = True
config.save_config()
event.set_result(MessageEventResult().message("已开启文本转图片模式"))
event.set_result(MessageEventResult().message("已开启文本转图片模式"))

View File

@@ -12,7 +12,7 @@ class TTSCommand:
self.context = context
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
"""开关文本转语音会话级别"""
umo = event.unified_msg_origin
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
cfg = self.context.get_config(umo=umo)
@@ -27,10 +27,10 @@ class TTSCommand:
if new_status and not tts_enable:
event.set_result(
MessageEventResult().message(
f"{status_text}当前会话的文本转语音但 TTS 功能在配置中未启用,请前往 WebUI 开启",
f"{status_text}当前会话的文本转语音但 TTS 功能在配置中未启用请前往 WebUI 开启",
),
)
else:
event.set_result(
MessageEventResult().message(f"{status_text}当前会话的文本转语音"),
MessageEventResult().message(f"{status_text}当前会话的文本转语音"),
)

View File

@@ -51,7 +51,7 @@ class Main(star.Star):
@plugin.command("ls")
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表"""
"""获取已经安装的插件列表"""
await self.plugin_c.plugin_ls(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@@ -84,7 +84,7 @@ class Main(star.Star):
@filter.command("tts")
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
"""开关文本转语音会话级别"""
await self.tts_c.tts(event)
@filter.command("sid")
@@ -95,25 +95,25 @@ class Main(star.Star):
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("op")
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员op <admin_id>"""
"""授权管理员op <admin_id>"""
await self.admin_c.op(event, admin_id)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("deop")
async def deop(self, event: AstrMessageEvent, admin_id: str) -> None:
"""取消授权管理员deop <admin_id>"""
"""取消授权管理员deop <admin_id>"""
await self.admin_c.deop(event, admin_id)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("wl")
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单wl <sid>"""
"""添加白名单wl <sid>"""
await self.admin_c.wl(event, sid)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dwl")
async def dwl(self, event: AstrMessageEvent, sid: str) -> None:
"""删除白名单dwl <sid>"""
"""删除白名单dwl <sid>"""
await self.admin_c.dwl(event, sid)
@filter.permission_type(filter.PermissionType.ADMIN)

View File

@@ -72,9 +72,9 @@ class Main(Star):
# 使用 LLM 生成回复
yield event.request_llm(
prompt=(
"注意,你正在社交媒体上中与用户进行聊天,用户只是通过@来唤醒你,但并未在这条消息中输入内容,他可能会在接下来一条发送他想发送的内容"
"你友好地询问用户想要聊些什么或者需要什么帮助,回复要符合人设,不要太过机械化"
"请注意,你仅需要输出要回复用户的内容,不要输出其他任何东西"
"注意你正在社交媒体上中与用户进行聊天用户只是通过@来唤醒你但并未在这条消息中输入内容他可能会在接下来一条发送他想发送的内容"
"你友好地询问用户想要聊些什么或者需要什么帮助回复要符合人设不要太过机械化"
"请注意你仅需要输出要回复用户的内容不要输出其他任何东西"
),
session_id=curr_cid,
contexts=[],
@@ -83,8 +83,8 @@ class Main(Star):
)
except Exception as e:
logger.error(f"LLM response failed: {e!s}")
# LLM 回复失败,使用原始预设回复
yield event.plain_result("想要问什么呢?😄")
# LLM 回复失败使用原始预设回复
yield event.plain_result("想要问什么呢😄")
@session_waiter(60)
async def empty_mention_waiter(
@@ -106,7 +106,7 @@ class Main(Star):
except TimeoutError as _:
pass
except Exception as e:
yield event.plain_result("发生错误,请联系管理员: " + str(e))
yield event.plain_result("发生错误请联系管理员: " + str(e))
finally:
event.stop_event()
except Exception as e:

View File

@@ -81,7 +81,7 @@ class SearchEngine:
return ret
def tidy_text(self, text: str) -> str:
"""清理文本,去除空格换行符等"""
"""清理文本去除空格换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
def _get_url(self, tag: Tag) -> str:

View File

@@ -34,14 +34,14 @@ class Main(star.Star):
self.bocha_key_index = 0
self.bocha_key_lock = asyncio.Lock()
# 将 str 类型的 key 迁移至 list[str],并保存
# 将 str 类型的 key 迁移至 list[str]并保存
cfg = self.context.get_config()
provider_settings = cfg.get("provider_settings")
if provider_settings:
tavily_key = provider_settings.get("websearch_tavily_key")
if isinstance(tavily_key, str):
logger.info(
"检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存",
"检测到旧版 websearch_tavily_key (字符串格式)自动迁移为列表格式并保存",
)
if tavily_key:
provider_settings["websearch_tavily_key"] = [tavily_key]
@@ -62,7 +62,7 @@ class Main(star.Star):
self.baidu_initialized = False
async def _tidy_text(self, text: str) -> str:
"""清理文本,去除空格换行符等"""
"""清理文本去除空格换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
async def _get_from_url(self, url: str) -> str:
@@ -124,10 +124,10 @@ class Main(star.Star):
return results
async def _get_tavily_key(self, cfg: AstrBotConfig) -> str:
"""并发安全的从列表中获取并轮换Tavily API密钥"""
"""并发安全的从列表中获取并轮换Tavily API密钥"""
tavily_keys = cfg.get("provider_settings", {}).get("websearch_tavily_key", [])
if not tavily_keys:
raise ValueError("错误:Tavily API密钥未在AstrBot中配置")
raise ValueError("错误Tavily API密钥未在AstrBot中配置")
async with self.tavily_key_lock:
key = tavily_keys[self.tavily_key_index]
@@ -203,11 +203,11 @@ class Main(star.Star):
query: str,
max_results: int = 5,
) -> str:
"""搜索网络以回答用户的问题当用户需要搜索网络以获取即时性的信息时调用此工具
"""搜索网络以回答用户的问题当用户需要搜索网络以获取即时性的信息时调用此工具
Args:
query(string): 和用户的问题最相关的搜索关键词,用于在 Google 上搜索
max_results(number): 返回的最大搜索结果数量,默认为 5
query(string): 和用户的问题最相关的搜索关键词用于在 Google 上搜索
max_results(number): 返回的最大搜索结果数量默认为 5
"""
logger.info(f"web_searcher - search_from_search_engine: {query}")
@@ -231,7 +231,7 @@ class Main(star.Star):
ret += processed_result
if websearch_link:
ret += "\n\n针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
ret += "\n\n针对问题请根据上面的结果分点总结并且在结尾处附上对应内容的参考链接如有)。"
return ret
@@ -384,10 +384,10 @@ class Main(star.Star):
return ret
async def _get_bocha_key(self, cfg: AstrBotConfig) -> str:
"""并发安全的从列表中获取并轮换BoCha API密钥"""
"""并发安全的从列表中获取并轮换BoCha API密钥"""
bocha_keys = cfg.get("provider_settings", {}).get("websearch_bocha_key", [])
if not bocha_keys:
raise ValueError("错误:BoCha API密钥未在AstrBot中配置")
raise ValueError("错误BoCha API密钥未在AstrBot中配置")
async with self.bocha_key_lock:
key = bocha_keys[self.bocha_key_index]
@@ -500,18 +500,18 @@ class Main(star.Star):
"count": count,
}
# freshness:时间范围
# freshness时间范围
if freshness:
payload["freshness"] = freshness
# 是否返回摘要
payload["summary"] = summary
# include:限制搜索域
# include限制搜索域
if include:
payload["include"] = include
# exclude:排除搜索域
# exclude排除搜索域
if exclude:
payload["exclude"] = exclude
@@ -567,9 +567,9 @@ class Main(star.Star):
if provider == "default":
web_search_t = func_tool_mgr.get_func("web_search")
fetch_url_t = func_tool_mgr.get_func("fetch_url")
if web_search_t and web_search_t.active:
if web_search_t:
tool_set.add_tool(web_search_t)
if fetch_url_t and fetch_url_t.active:
if fetch_url_t:
tool_set.add_tool(fetch_url_t)
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")
@@ -578,9 +578,9 @@ class Main(star.Star):
elif provider == "tavily":
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
if web_search_tavily and web_search_tavily.active:
if web_search_tavily:
tool_set.add_tool(web_search_tavily)
if tavily_extract_web_page and tavily_extract_web_page.active:
if tavily_extract_web_page:
tool_set.add_tool(tavily_extract_web_page)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
@@ -590,8 +590,9 @@ class Main(star.Star):
try:
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
aisearch_tool = func_tool_mgr.get_func("AIsearch")
if aisearch_tool and aisearch_tool.active:
tool_set.add_tool(aisearch_tool)
if not aisearch_tool:
raise ValueError("Cannot get Baidu AI Search MCP tool.")
tool_set.add_tool(aisearch_tool)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("web_search_tavily")
@@ -601,7 +602,7 @@ class Main(star.Star):
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")
elif provider == "bocha":
web_search_bocha = func_tool_mgr.get_func("web_search_bocha")
if web_search_bocha and web_search_bocha.active:
if web_search_bocha:
tool_set.add_tool(web_search_bocha)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")

View File

@@ -7,8 +7,7 @@ import click
from click.shell_completion import get_completion_class
from . import __version__
from .commands import bk, conf, init, plug, run, tui, uninstall
from .i18n import t
from .commands import bk, conf, init, plug, run, uninstall
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
@@ -23,12 +22,10 @@ logo_tmpl = r"""
@click.group()
@click.version_option(__version__, prog_name="AstrBot")
def cli() -> None:
"""Astrbot
Agentic IM Chatbot infrastructure that integrates lots of IM platforms, LLMs, plugins and AI feature, and can be your openclaw alternative. ✨
"""
"""The AstrBot CLI"""
click.echo(logo_tmpl)
click.echo(t("cli_welcome"))
click.echo(t("cli_version", version=__version__))
click.echo("Welcome to AstrBot CLI!")
click.echo(f"AstrBot CLI version: {__version__}")
@click.command()
@@ -71,7 +68,7 @@ def help(command_name: str | None, all: bool) -> None:
cmd_ctx = click.Context(command, info_name=command.name, parent=parent)
click.echo(command.get_help(cmd_ctx))
else:
click.echo(t("cli_unknown_command", command=command_name))
click.echo(f"Unknown command: {command_name}")
sys.exit(1)
else:
# Display general help information
@@ -85,7 +82,6 @@ cli.add_command(plug)
cli.add_command(conf)
cli.add_command(uninstall)
cli.add_command(bk)
cli.add_command(tui)
@click.command()
@@ -108,9 +104,6 @@ def completion(shell: str | None) -> None:
sys.exit(1)
comp_cls = get_completion_class(shell)
if comp_cls is None:
click.echo(f"No completion support for shell: {shell}", err=True)
sys.exit(1)
comp = comp_cls(
cli, ctx_args={}, prog_name="astrbot", complete_var="_ASTRBOT_COMPLETE"
)

View File

@@ -3,7 +3,6 @@ from .cmd_conf import conf
from .cmd_init import init
from .cmd_plug import plug
from .cmd_run import run
from .cmd_tui import tui
from .cmd_uninstall import uninstall
__all__ = ["bk", "conf", "init", "plug", "run", "tui", "uninstall"]
__all__ = ["conf", "init", "plug", "run", "uninstall", "bk"]

View File

@@ -7,41 +7,36 @@ from pathlib import Path
import anyio
import click
from astrbot.core import db_helper
from astrbot.core import astrbot_config, db_helper
from astrbot.core.backup import AstrBotExporter, AstrBotImporter
# Try importing KnowledgeBaseManager to support KB backup
try:
from astrbot.core.knowledge.kb_manager import KnowledgeBaseManager
except ImportError:
try:
from astrbot.core.knowledge_base.kb_manager import KnowledgeBaseManager
except ImportError:
KnowledgeBaseManager = None
async def _get_kb_manager():
"""Initialize and return a KnowledgeBaseManager with full dependency chain."""
from astrbot.core import astrbot_config, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.umop_config_router import UmopConfigRouter
if KnowledgeBaseManager is None:
return None
ucr = UmopConfigRouter(sp=sp)
await ucr.initialize()
acm = AstrBotConfigManager(
default_config=astrbot_config,
ucr=ucr,
sp=sp,
)
persona_mgr = PersonaManager(db_helper, acm)
await persona_mgr.initialize()
provider_manager = ProviderManager(
acm,
db_helper,
persona_mgr,
)
kb_manager = KnowledgeBaseManager(provider_manager)
await kb_manager.initialize()
return kb_manager
try:
# Best effort initialization
kb_mgr = KnowledgeBaseManager(astrbot_config, db_helper)
# If there are async load methods, we might need to call them
if hasattr(kb_mgr, "load_kbs_from_db"):
await kb_mgr.load_kbs_from_db()
elif hasattr(kb_mgr, "load_all"):
await kb_mgr.load_all()
return kb_mgr
except Exception:
# If KB manager fails to load (e.g. missing dependencies), return None
# so we can still backup other data
return None
@click.group(name="bk")
@@ -143,7 +138,8 @@ def export_data(
"GPG tool not found. Please install GnuPG to use encryption/signing features."
)
exporter = AstrBotExporter(db_helper)
kb_mgr = await _get_kb_manager()
exporter = AstrBotExporter(db_helper, kb_mgr)
async def on_progress(stage, current, total, message):
click.echo(f"[{stage}] {message}")

View File

@@ -1,172 +1,95 @@
"""
Configuration CLI for AstrBot.
This module provides:
- secure hashing utilities for the dashboard password (argon2)
- validators for commonly configurable items
- click CLI group with `set`, `get`, and `password` subcommands
"""
from __future__ import annotations
import binascii
import hashlib
import json
import zoneinfo
from collections.abc import Callable
from typing import Any
import argon2.exceptions as argon2_exceptions
import click
from argon2 import PasswordHasher
from astrbot.cli.i18n import t
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.utils.astrbot_path import astrbot_paths
_PASSWORD_HASHER = PasswordHasher()
from ..utils import check_astrbot_root
DEFAULT_DASHBOARD_PASSWORD = "astrbot"
DEFAULT_DASHBOARD_PASSWORD_MD5 = hashlib.md5(
DEFAULT_DASHBOARD_PASSWORD.encode()
).hexdigest()
DEFAULT_DASHBOARD_PASSWORD_SHA256 = hashlib.sha256(
DEFAULT_DASHBOARD_PASSWORD.encode()
).hexdigest()
PBKDF2_SALT = b"astrbot-dashboard"
PBKDF2_ITER = 200_000
def hash_dashboard_password(value: str) -> str:
"""Hash Dashboard password for storage."""
return hashlib.sha256(value.encode()).hexdigest()
# --- Password hashing & validation utilities ---
def hash_dashboard_password_md5(value: str) -> str:
"""Hash Dashboard password with the legacy MD5 algorithm."""
return hashlib.md5(value.encode()).hexdigest()
def hash_dashboard_password_secure(value: str) -> str:
"""
Hash the dashboard password for storage.
Stored format:
$argon2id$... (if Argon2 available) or pbkdf2_sha256 fallback.
"""
if _PASSWORD_HASHER is not None:
try:
return _PASSWORD_HASHER.hash(value)
except Exception as e:
raise click.ClickException(
f"Failed to hash password securely (argon2): {e!s}"
)
dk = hashlib.pbkdf2_hmac("sha256", value.encode("utf-8"), PBKDF2_SALT, PBKDF2_ITER)
return f"pbkdf2_sha256${PBKDF2_ITER}${binascii.hexlify(PBKDF2_SALT).decode()}${dk.hex()}"
def verify_dashboard_password(value: str, stored_hash: str) -> bool:
"""
Verify a plaintext password `value` against a stored hash.
Supported format:
- Argon2 encoded string: $argon2id$...
- PBKDF2 encoded string: pbkdf2_sha256$...
- Legacy SHA-256 (64 hex chars) and MD5 (32 hex chars) for backward compatibility.
"""
if not stored_hash:
return False
if stored_hash.startswith("$argon2"):
try:
return _PASSWORD_HASHER.verify(stored_hash, value)
except argon2_exceptions.VerifyMismatchError:
return False
except Exception as e:
raise click.ClickException(f"Password verification failure (argon2): {e!s}")
if stored_hash.startswith("pbkdf2_sha256$"):
try:
_, iters_s, salt_hex, digest_hex = stored_hash.split("$", 3)
iters = int(iters_s)
salt = binascii.unhexlify(salt_hex)
expected = digest_hex.lower()
dk = hashlib.pbkdf2_hmac("sha256", value.encode("utf-8"), salt, iters)
return dk.hex() == expected
except Exception:
return False
# Legacy plain hex digests: SHA-256 (64 hex chars) and MD5 (32 hex chars).
value_l = value.encode("utf-8")
s = stored_hash.lower()
if len(s) == 64 and all(ch in "0123456789abcdef" for ch in s):
return hashlib.sha256(value_l).hexdigest() == s
if len(s) == 32 and all(ch in "0123456789abcdef" for ch in s):
return hashlib.md5(value_l).hexdigest() == s
return False
def is_dashboard_password_hash(value: str) -> bool:
"""
Heuristic: return True if `value` looks like a supported dashboard password hash.
"""
if not isinstance(value, str) or not value:
return False
return value.startswith("$argon2") or value.startswith("pbkdf2_sha256$")
def is_legacy_dashboard_password_hash(value: str) -> bool:
"""
Heuristic: return True if `value` looks like a legacy password hash format.
Legacy formats are plain SHA-256 (64 hex chars) or MD5 (32 hex chars) digests.
"""
if not isinstance(value, str) or not value:
return False
# Legacy plain hex digests: SHA-256 (64 hex chars) or MD5 (32 hex chars)
if len(value) == 64 and all(ch in "0123456789abcdef" for ch in value.lower()):
return True
if len(value) == 32 and all(ch in "0123456789abcdef" for ch in value.lower()):
return True
return False
# --- Validators for CLI configuration items ---
def is_dashboard_password_hash(value: str, *, algorithm: str) -> bool:
expected_len = 64 if algorithm == "sha256" else 32
return len(value) == expected_len and all(ch in "0123456789abcdef" for ch in value)
def _validate_log_level(value: str) -> str:
value_up = value.upper()
allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
if value_up not in allowed:
raise click.ClickException(t("config_log_level_invalid"))
return value_up
"""Validate log level"""
value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException(
"Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
)
return value
def _validate_dashboard_port(value: str) -> int:
"""Validate Dashboard port"""
try:
port = int(value)
if port < 1 or port > 65535:
raise click.ClickException("Port must be in range 1-65535")
return port
except ValueError:
raise click.ClickException(t("config_port_must_be_number"))
if port < 1 or port > 65535:
raise click.ClickException(t("config_port_range_invalid"))
return port
raise click.ClickException("Port must be a number")
def _validate_dashboard_username(value: str) -> str:
if value is None or value.strip() == "":
raise click.ClickException(t("config_username_empty"))
return value.strip()
"""Validate Dashboard username"""
if not value:
raise click.ClickException("Username cannot be empty")
return value
def _validate_dashboard_password(value: str) -> str:
if value is None or value == "":
raise click.ClickException(t("config_password_empty"))
# Return the canonical stored representation.
return hash_dashboard_password_secure(value)
"""Validate Dashboard password"""
if not value:
raise click.ClickException("Password cannot be empty")
return hash_dashboard_password(value)
def _validate_timezone(value: str) -> str:
"""Validate timezone"""
try:
zoneinfo.ZoneInfo(value)
except Exception:
raise click.ClickException(t("config_timezone_invalid", value=value))
raise click.ClickException(
f"Invalid timezone: {value}. Please use a valid IANA timezone name"
)
return value
def _validate_callback_api_base(value: str) -> str:
if not (value.startswith("http://") or value.startswith("https://")):
raise click.ClickException(t("config_callback_invalid"))
"""Validate callback API base URL"""
if not value.startswith("http://") and not value.startswith("https://"):
raise click.ClickException(
"Callback API base must start with http:// or https://"
)
return value
# Configuration items settable via CLI, mapping config keys to validator functions
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
"timezone": _validate_timezone,
"log_level": _validate_log_level,
@@ -177,23 +100,18 @@ CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
}
# --- Config file helpers ---
def _load_config() -> dict[str, Any]:
"""
Load or initialize the CLI config file (data/cmd_config.json).
Ensures the astrbot root is valid before proceeding.
"""
"""Load or initialize config file"""
root = astrbot_paths.root
if not astrbot_paths.is_root:
if not check_astrbot_root(root):
raise click.ClickException(
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize"
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
config_path = astrbot_paths.data / "cmd_config.json"
if not config_path.exists():
# Write DEFAULT_CONFIG to disk if file missing
from astrbot.core.config.default import DEFAULT_CONFIG
config_path.write_text(
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
@@ -206,43 +124,50 @@ def _load_config() -> dict[str, Any]:
def _save_config(config: dict[str, Any]) -> None:
"""Save config file"""
config_path = astrbot_paths.data / "cmd_config.json"
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
json.dumps(config, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
)
def ensure_config_file() -> dict[str, Any]:
"""Ensure config file exists and return parsed config."""
return _load_config()
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
"""Set a value in a nested dictionary"""
parts = path.split(".")
cur = obj
for part in parts[:-1]:
if part not in cur:
cur[part] = {}
elif not isinstance(cur[part], dict):
if part not in obj:
obj[part] = {}
elif not isinstance(obj[part], dict):
raise click.ClickException(
f"Config path conflict: {'.'.join(parts[: parts.index(part) + 1])} is not a dict"
f"Config path conflict: {'.'.join(parts[: parts.index(part) + 1])} is not a dict",
)
cur = cur[part]
cur[parts[-1]] = value
obj = obj[part]
obj[parts[-1]] = value
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
"""Get a value from a nested dictionary"""
parts = path.split(".")
cur = obj
for part in parts:
cur = cur[part]
return cur
# --- CLI commands ---
obj = obj[part]
return obj
def prompt_dashboard_password(prompt: str = "Dashboard password") -> str:
password = click.prompt(prompt, hide_input=True, confirmation_prompt=True, type=str)
"""Prompt for dashboard password with confirmation."""
password = click.prompt(
prompt,
hide_input=True,
confirmation_prompt=True,
type=str,
)
return _validate_dashboard_password(password)
@@ -252,69 +177,63 @@ def set_dashboard_credentials(
username: str | None = None,
password_hash: str | None = None,
) -> None:
"""Update dashboard credentials in config."""
if username is not None:
_set_nested_item(
config, "dashboard.username", _validate_dashboard_username(username)
config,
"dashboard.username",
_validate_dashboard_username(username),
)
if password_hash is not None:
if isinstance(password_hash, str) and is_dashboard_password_hash(password_hash):
_set_nested_item(config, "dashboard.password", password_hash)
else:
if is_legacy_dashboard_password_hash(password_hash):
raise click.ClickException(
"Storing legacy dashboard password hashes is no longer supported. "
"Please provide the plaintext password (it will be hashed securely), "
"or provide an Argon2-encoded hash string."
)
_set_nested_item(
config,
"dashboard.password",
_validate_dashboard_password(password_hash),
)
_set_nested_item(config, "dashboard.password", password_hash)
@click.group(name="conf")
def conf() -> None:
"""
Configuration management commands.
"""Configuration management commands
Supported config keys:
- timezone
- log_level
- dashboard.port
- dashboard.username
- dashboard.password
- callback_api_base
- timezone: Timezone setting (e.g. Asia/Shanghai)
- log_level: Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- dashboard.port: Dashboard port
- dashboard.username: Dashboard username
- dashboard.password: Dashboard password
- callback_api_base: Callback API base URL
"""
pass
@conf.command(name="set")
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str) -> None:
"""Set the value of a config item"""
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"Unsupported config key: {key}")
config = _load_config()
try:
# Attempt to get old value (may raise KeyError)
try:
old_value = _get_nested_item(config, key)
except Exception:
old_value = "<not set>"
try:
old_value = _get_nested_item(config, key)
validated_value = CONFIG_VALIDATORS[key](value)
_set_nested_item(config, key, validated_value)
_save_config(config)
click.echo(f"Config updated: {key}")
click.echo(f" Old value: {old_value}")
click.echo(f" New value: {validated_value}")
if key == "dashboard.password":
click.echo(" Old value: ********")
click.echo(" New value: ********")
else:
click.echo(f" Old value: {old_value}")
click.echo(f" New value: {validated_value}")
except KeyError:
raise click.ClickException(f"Unknown config key: {key}")
except click.ClickException:
raise
except Exception as e:
raise click.UsageError(f"Failed to set config: {e!s}")
@@ -322,10 +241,13 @@ def set_config(key: str, value: str) -> None:
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str | None = None) -> None:
"""Get the value of a config item. If no key is provided, show all configurable items"""
config = _load_config()
if key:
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"Unsupported config key: {key}")
try:
value = _get_nested_item(config, key)
if key == "dashboard.password":
@@ -337,51 +259,35 @@ def get_config(key: str | None = None) -> None:
raise click.UsageError(f"Failed to get config: {e!s}")
else:
click.echo("Current config:")
for k in CONFIG_VALIDATORS:
for key in CONFIG_VALIDATORS:
try:
v = (
value = (
"********"
if k == "dashboard.password"
else _get_nested_item(config, k)
if key == "dashboard.password"
else _get_nested_item(config, key)
)
click.echo(f" {k}: {v}")
click.echo(f" {key}: {value}")
except (KeyError, TypeError):
# Missing or non-dict paths are simply skipped in listing
pass
@conf.command(name="admin")
@click.option("-u", "--username", type=str, help="Update admain username as well")
@conf.command(name="password")
@click.option("-u", "--username", type=str, help="Update dashboard username as well")
@click.option(
"-p",
"--password",
type=str,
help="Set admain password directly without interactive prompt",
help="Set dashboard password directly without interactive prompt",
)
def set_dashboard_password(username: str | None, password: str | None) -> None:
"""
Interactively set dashboard password (with confirmation) or set directly with -p.
Acceptable inputs:
- Plaintext password (recommended): it will be hashed securely before storage.
- Argon2 encoded hash (advanced): stored as-is.
"""
"""Interactively manage dashboard password."""
config = _load_config()
if password is not None:
if isinstance(password, str) and is_dashboard_password_hash(password):
password_hash = password
else:
if is_legacy_dashboard_password_hash(password):
raise click.ClickException(
"Providing legacy dashboard password hashes is no longer supported. "
"Please supply the plaintext password (it will be hashed securely), "
"or provide an Argon2-encoded hash string."
)
password_hash = _validate_dashboard_password(password)
else:
password_hash = prompt_dashboard_password()
password_hash = (
_validate_dashboard_password(password)
if password is not None
else prompt_dashboard_password()
)
set_dashboard_credentials(
config,
username=username.strip() if username is not None else None,

View File

@@ -1,19 +1,19 @@
import asyncio
import json
import os
import re
from pathlib import Path
from typing import Any, cast
import click
from filelock import FileLock, Timeout
from astrbot.cli.utils import DashboardManager
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.utils.astrbot_path import astrbot_paths
from ..utils import check_dashboard
from .cmd_conf import (
_validate_dashboard_password,
ensure_config_file,
prompt_dashboard_password,
set_dashboard_credentials,
)
@@ -60,88 +60,50 @@ async def initialize_astrbot(
)
click.echo(f"Created config file: {config_path}")
# Generate an .env for this instance from the bundled config.template (if available).
# The generated file will be written to ASTRBOT_ROOT/.env and will be automatically
# loaded by `astrbot run` (service-config/.env precedence applies).
ASTRBOT_ROOT = astrbot_root
env_file = ASTRBOT_ROOT / ".env"
if not env_file.exists():
tmpl_candidates = [
Path("/opt/astrbot/config.template"),
# project_root may point to the installed package directory; try it as well
getattr(astrbot_paths, "project_root", Path.cwd()) / "config.template",
Path.cwd() / "config.template",
]
tmpl = None
for t in tmpl_candidates:
try:
if t.exists():
tmpl = t
break
except Exception:
continue
if tmpl is not None:
try:
txt = tmpl.read_text(encoding="utf-8")
# Determine instance name for template replacement (fallback to directory name)
instance_name = astrbot_root.name or "astrbot"
# Substitute ${VAR} and ${VAR:-default} for INSTANCE_NAME, PORT, ASTRBOT_ROOT
txt = re.sub(r"\$\{INSTANCE_NAME(:-[^}]*)?\}", instance_name, txt)
port_val = (
os.environ.get("ASTRBOT_PORT") or os.environ.get("PORT") or "8000"
)
txt = re.sub(r"\$\{PORT(:-[^}]*)?\}", str(port_val), txt)
txt = re.sub(r"\$\{ASTRBOT_ROOT(:-[^}]*)?\}", str(ASTRBOT_ROOT), txt)
header = (
f"# Generated from config.template by astrbot init for instance: {instance_name}\n"
"# This file will be auto-loaded by 'astrbot run'\n\n"
)
env_file.write_text(header + txt, encoding="utf-8")
env_file.chmod(0o644)
click.echo(f"Created environment file from template: {env_file}")
except Exception as e:
click.echo(f"Warning: failed to generate .env from template: {e!s}")
else:
click.echo("No config.template found; skipping .env generation")
if admin_password is not None:
if admin_password and not admin_username:
raise click.ClickException(
"--admin-password is no longer supported during init. "
"Run 'astrbot conf admin' after initialization."
"--admin-password requires --admin-username to be provided"
)
effective_admin_username = (
admin_username.strip()
if admin_username
else str(cast(dict[str, Any], DEFAULT_CONFIG)["dashboard"]["username"])
)
if admin_username:
password_hash = (
_validate_dashboard_password(admin_password)
if admin_password is not None
else None
)
if password_hash is None:
if yes or os.environ.get("ASTRBOT_SYSTEMD") == "1":
raise click.ClickException(
"Non-interactive init requires --admin-password when --admin-username is set"
)
password_hash = prompt_dashboard_password("Dashboard admin password")
config = ensure_config_file()
set_dashboard_credentials(
config,
username=effective_admin_username,
password_hash=None,
username=admin_username.strip(),
password_hash=password_hash,
)
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
)
click.echo(f"Configured dashboard admin username: {effective_admin_username}")
click.echo(
"Dashboard password is not initialized for interactive use. "
"Run 'astrbot conf admin' before the first login."
)
click.echo(f"Configured dashboard admin username: {admin_username.strip()}")
if not backend_only and (
yes
or click.confirm(
"是否需要集成式 WebUI?(个人电脑推荐,服务器不推荐)",
"是否需要集成式 WebUI个人电脑推荐服务器不推荐",
default=True,
)
):
await DashboardManager().ensure_installed(astrbot_root)
# 避免在 systemd 模式下因等待输入而阻塞
if os.environ.get("ASTRBOT_SYSTEMD") == "1":
click.echo("Systemd detected: Skipping dashboard check.")
else:
await check_dashboard(astrbot_root)
else:
click.echo("你可以使用在线面版(需支持配置后端)来控制")
click.echo("你可以使用在线面版需支持配置后端来控制")
@click.command()
@@ -158,12 +120,7 @@ async def initialize_astrbot(
"-p",
"--admin-password",
type=str,
help="Deprecated. Run `astrbot conf admin` after initialization.",
)
@click.option(
"--root",
help="ASTRBOT root directory to initialize (overrides ASTRBOT_ROOT env)",
type=str,
help="Set dashboard admin password during initialization without prompting",
)
def init(
yes: bool,
@@ -171,16 +128,14 @@ def init(
backup: str | None,
admin_username: str | None,
admin_password: str | None,
root: str | None = None,
) -> None:
"""Initialize AstrBot"""
click.echo("Initializing AstrBot...")
if os.environ.get("ASTRBOT_SYSTEMD") == "1":
yes = True
from astrbot.core.utils.astrbot_path import astrbot_paths
astrbot_root = Path(root) if root else astrbot_paths.root
astrbot_root = astrbot_paths.root
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)

View File

@@ -1,12 +1,15 @@
import re
import shutil
from pathlib import Path
import click
from astrbot.cli.i18n import t
from astrbot.cli.utils import (
from astrbot.core.utils.astrbot_path import astrbot_paths
from ..utils import (
PluginStatus,
build_plug_list,
check_astrbot_root,
get_git_repo,
manage_plugin,
)
@@ -17,6 +20,15 @@ def plug() -> None:
"""Plugin management"""
def _get_data_path() -> Path:
base = astrbot_paths.root
if not check_astrbot_root(base):
raise click.ClickException(
f"{base} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
return astrbot_paths.data.resolve()
def display_plugins(plugins, title=None, color=None) -> None:
if title:
click.echo(click.style(title, fg=color, bold=True))
@@ -38,13 +50,11 @@ def display_plugins(plugins, title=None, color=None) -> None:
@click.argument("name")
def new(name: str) -> None:
"""Create a new plugin"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
if plug_path.exists():
raise click.ClickException(t("plugin_already_exists", name=name))
raise click.ClickException(f"Plugin {name} already exists")
author = click.prompt("Enter plugin author", type=str)
desc = click.prompt("Enter plugin description", type=str)
@@ -97,9 +107,7 @@ def new(name: str) -> None:
@click.option("--all", "-a", is_flag=True, help="List uninstalled plugins")
def list(all: bool) -> None:
"""List plugins"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
# Unpublished plugins
@@ -140,9 +148,7 @@ def list(all: bool) -> None:
@click.option("--proxy", help="Proxy server address")
def install(name: str, proxy: str | None) -> None:
"""Install a plugin"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
@@ -156,7 +162,7 @@ def install(name: str, proxy: str | None) -> None:
)
if not plugin:
raise click.ClickException(t("plugin_not_found_or_installed", name=name))
raise click.ClickException(f"Plugin {name} not found or already installed")
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
@@ -165,26 +171,24 @@ def install(name: str, proxy: str | None) -> None:
@click.argument("name")
def remove(name: str) -> None:
"""Uninstall a plugin"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
plugin = next((p for p in plugins if p["name"] == name), None)
if not plugin or not plugin.get("local_path"):
raise click.ClickException(t("plugin_not_found_or_installed", name=name))
raise click.ClickException(f"Plugin {name} does not exist or is not installed")
plugin_path = plugin["local_path"]
click.confirm(t("plugin_uninstall_confirm", name=name), default=False, abort=True)
click.confirm(
f"Are you sure you want to uninstall plugin {name}?", default=False, abort=True
)
try:
shutil.rmtree(plugin_path)
click.echo(t("plugin_uninstall_success", name=name))
click.echo(f"Plugin {name} has been uninstalled")
except Exception as e:
raise click.ClickException(
t("plugin_uninstall_failed_ex", name=name, error=str(e))
)
raise click.ClickException(f"Failed to uninstall plugin {name}: {e}")
@plug.command()
@@ -192,9 +196,7 @@ def remove(name: str) -> None:
@click.option("--proxy", help="GitHub proxy address")
def update(name: str, proxy: str | None) -> None:
"""Update plugins"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
@@ -220,13 +222,13 @@ def update(name: str, proxy: str | None) -> None:
]
if not need_update_plugins:
click.echo(t("plugin_no_update_needed"))
click.echo("No plugins need updating")
return
click.echo(t("plugin_found_update", count=str(len(need_update_plugins))))
click.echo(f"Found {len(need_update_plugins)} plugin(s) needing update")
for plugin in need_update_plugins:
plugin_name = plugin["name"]
click.echo(t("plugin_updating", name=plugin_name))
click.echo(f"Updating plugin {plugin_name}...")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
@@ -234,9 +236,7 @@ def update(name: str, proxy: str | None) -> None:
@click.argument("query")
def search(query: str) -> None:
"""Search for plugins"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
matched_plugins = [
@@ -248,7 +248,7 @@ def search(query: str) -> None:
]
if not matched_plugins:
click.echo(t("plugin_search_no_result", query=query))
click.echo(f"No plugins matching '{query}' found")
return
display_plugins(matched_plugins, t("plugin_search_results", query=query), "cyan")
display_plugins(matched_plugins, f"Search results: '{query}'", "cyan")

View File

@@ -1,3 +1,40 @@
"""AstrBot Run
Environment Variables Used in Project:
Core:
- `ASTRBOT_ROOT`: AstrBot root directory path.
- `ASTRBOT_LOG_LEVEL`: Log level (e.g. INFO, DEBUG).
- `ASTRBOT_CLI`: Flag indicating execution via CLI.
- `ASTRBOT_DESKTOP_CLIENT`: Flag indicating execution via desktop client.
- `ASTRBOT_SYSTEMD`: Flag indicating execution via systemd service.
- `ASTRBOT_RELOAD`: Enable plugin auto-reload (set to "1").
- `ASTRBOT_DISABLE_METRICS`: Disable metrics upload (set to "1").
- `TESTING`: Enable testing mode.
- `DEMO_MODE`: Enable demo mode.
- `PYTHON`: Python executable path override (for local code execution).
Dashboard:
- `ASTRBOT_DASHBOARD_ENABLE` / `DASHBOARD_ENABLE`: Enable/Disable Dashboard.
- `ASTRBOT_DASHBOARD_HOST` / `DASHBOARD_HOST`: Dashboard bind host.
- `ASTRBOT_DASHBOARD_PORT` / `DASHBOARD_PORT`: Dashboard bind port.
- `ASTRBOT_DASHBOARD_SSL_ENABLE` / `DASHBOARD_SSL_ENABLE`: Enable SSL.
- `ASTRBOT_DASHBOARD_SSL_CERT` / `DASHBOARD_SSL_CERT`: SSL Certificate path.
- `ASTRBOT_DASHBOARD_SSL_KEY` / `DASHBOARD_SSL_KEY`: SSL Key path.
- `ASTRBOT_DASHBOARD_SSL_CA_CERTS` / `DASHBOARD_SSL_CA_CERTS`: SSL CA Certs path.
Network:
- `http_proxy` / `https_proxy`: Proxy URL.
- `no_proxy`: No proxy list.
Integrations:
- `DASHSCOPE_API_KEY`: Alibaba DashScope API Key (for Rerank).
- `COZE_API_KEY` / `COZE_BOT_ID`: Coze integration.
- `BAY_DATA_DIR`: Computer Use data directory.
Platform Specific:
- `TEST_MODE`: Test mode for QQOfficial.
"""
import asyncio
import os
import sys
@@ -7,8 +44,9 @@ from pathlib import Path
import click
from filelock import FileLock, Timeout
from astrbot.cli.utils import DashboardManager
from astrbot.core.utils.astrbot_path import get_astrbot_root
from astrbot.core.utils.astrbot_path import astrbot_paths
from ..utils import check_astrbot_root, check_dashboard
async def run_astrbot(astrbot_root: Path) -> None:
@@ -16,7 +54,13 @@ async def run_astrbot(astrbot_root: Path) -> None:
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.initial_loader import InitialLoader
await DashboardManager().ensure_installed(astrbot_root)
if (
os.environ.get("ASTRBOT_DASHBOARD_ENABLE", os.environ.get("DASHBOARD_ENABLE"))
== "True"
):
# 避免在 systemd 模式下因等待输入而阻塞
if os.environ.get("ASTRBOT_SYSTEMD") != "1":
await check_dashboard(astrbot_root)
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
@@ -28,15 +72,96 @@ async def run_astrbot(astrbot_root: Path) -> None:
@click.option("--reload", "-r", is_flag=True, help="Auto-reload plugins")
@click.option("--host", "-H", help="AstrBot Dashboard Host", required=False, type=str)
@click.option("--port", "-p", help="AstrBot Dashboard port", required=False, type=str)
@click.option("--root", help="AstrBot root directory", required=False, type=str)
@click.option(
"--service-config",
"-c",
help="Service configuration file path",
required=False,
type=str,
)
@click.option(
"--backend-only",
"-b",
is_flag=True,
default=False,
help="Disable WebUI, run backend only",
)
@click.option(
"--log-level",
"-l",
help="Log level",
required=False,
type=str,
default="INFO",
)
@click.option("--debug", is_flag=True, help="Enable debug mode")
@click.command()
def run(reload: bool, port: str) -> None:
def run(
reload: bool,
host: str,
port: str,
root: str,
service_config: str,
backend_only: bool,
log_level: str,
debug: bool,
) -> None:
"""Run AstrBot"""
try:
os.environ["ASTRBOT_CLI"] = "1"
astrbot_root = Path(get_astrbot_root())
if debug:
log_level = "DEBUG"
if not (astrbot_root / "data").exists():
if service_config:
svc_path = Path(service_config)
if svc_path.exists():
content = svc_path.read_text(encoding="utf-8")
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" in line:
key, value = line.split("=", 1)
key = key.strip()
value = value.strip()
# Remove quotes
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
value = value[1:-1]
if key == "HOST" and not host:
host = value
elif key == "PORT" and not port:
port = value
elif key == "ASTRBOT_ROOT" and not root:
root = value
# Normalize environment variables for backward compatibility
# If the legacy env var is set but the new one isn't, copy it over.
env_map = {
"DASHBOARD_ENABLE": "ASTRBOT_DASHBOARD_ENABLE",
"DASHBOARD_HOST": "ASTRBOT_DASHBOARD_HOST",
"DASHBOARD_PORT": "ASTRBOT_DASHBOARD_PORT",
"DASHBOARD_SSL_ENABLE": "ASTRBOT_DASHBOARD_SSL_ENABLE",
"DASHBOARD_SSL_CERT": "ASTRBOT_DASHBOARD_SSL_CERT",
"DASHBOARD_SSL_KEY": "ASTRBOT_DASHBOARD_SSL_KEY",
"DASHBOARD_SSL_CA_CERTS": "ASTRBOT_DASHBOARD_SSL_CA_CERTS",
}
for legacy, new in env_map.items():
if legacy in os.environ and new not in os.environ:
os.environ[new] = os.environ[legacy]
os.environ["ASTRBOT_CLI"] = "1"
if root:
os.environ["ASTRBOT_ROOT"] = root
astrbot_root = Path(root)
else:
astrbot_root = astrbot_paths.root
if not check_astrbot_root(astrbot_root):
raise click.ClickException(
f"{astrbot_root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
@@ -44,13 +169,67 @@ def run(reload: bool, port: str) -> None:
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
sys.path.insert(0, str(astrbot_root))
if port:
os.environ["DASHBOARD_PORT"] = port
if port is not None:
os.environ["ASTRBOT_DASHBOARD_PORT"] = port
os.environ["DASHBOARD_PORT"] = port # 今后应该移除
if host is not None:
os.environ["ASTRBOT_DASHBOARD_HOST"] = host
os.environ["DASHBOARD_HOST"] = host # 今后应该移除
os.environ["ASTRBOT_DASHBOARD_ENABLE"] = str(not backend_only)
os.environ["DASHBOARD_ENABLE"] = str(not backend_only) # 今后应该移除
os.environ["ASTRBOT_LOG_LEVEL"] = log_level
if reload:
click.echo("Plugin auto-reload enabled")
os.environ["ASTRBOT_RELOAD"] = "1"
if debug:
keys_to_print = [
"ASTRBOT_ROOT",
"ASTRBOT_LOG_LEVEL",
"ASTRBOT_CLI",
"ASTRBOT_DESKTOP_CLIENT",
"ASTRBOT_SYSTEMD",
"ASTRBOT_RELOAD",
"ASTRBOT_DISABLE_METRICS",
"TESTING",
"DEMO_MODE",
"PYTHON",
"ASTRBOT_DASHBOARD_ENABLE",
"DASHBOARD_ENABLE",
"ASTRBOT_DASHBOARD_HOST",
"DASHBOARD_HOST",
"ASTRBOT_DASHBOARD_PORT",
"DASHBOARD_PORT",
"ASTRBOT_DASHBOARD_SSL_ENABLE",
"DASHBOARD_SSL_ENABLE",
"ASTRBOT_DASHBOARD_SSL_CERT",
"DASHBOARD_SSL_CERT",
"ASTRBOT_DASHBOARD_SSL_KEY",
"DASHBOARD_SSL_KEY",
"ASTRBOT_DASHBOARD_SSL_CA_CERTS",
"DASHBOARD_SSL_CA_CERTS",
"http_proxy",
"https_proxy",
"no_proxy",
"DASHSCOPE_API_KEY",
"COZE_API_KEY",
"COZE_BOT_ID",
"BAY_DATA_DIR",
"TEST_MODE",
]
click.secho("\n[Debug Mode] Environment Variables:", fg="yellow", bold=True)
for key in keys_to_print:
if key in os.environ:
val = os.environ[key]
if "KEY" in key or "PASSWORD" in key or "SECRET" in key:
if len(val) > 8:
val = val[:4] + "****" + val[-4:]
else:
val = "****"
click.echo(f" {click.style(key, fg='cyan')}: {val}")
click.echo("")
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
with lock.acquire():

View File

@@ -1,307 +0,0 @@
"""AstrBot Run TUI - A beautiful textual interface for running AstrBot.
This module provides a Textual-based TUI for `astrbot run` with:
- Animated ASCII logo
- Live log viewer
- Platform status indicators
- Only activates in interactive TTY environments
"""
from __future__ import annotations
import sys
import typing
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, Vertical
from textual.reactive import reactive
from textual.widgets import Footer, Header, Log, Static
if typing.TYPE_CHECKING:
from rich.console import Console
from rich.style import Style
from rich.text import Text
else:
Console: Any = None
Style: Any = None
Text: Any = None
# AstrBot ASCII Logo
ASTRBOT_LOGO = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
class AstrBotRunTUI(App):
"""Textual TUI for AstrBot run command."""
CSS = """
Screen {
background: $surface;
}
#logo-container {
height: auto;
padding: 1 2;
background: $surface-darken-1;
border: solid $primary;
}
#logo-text {
color: $primary;
text-style: bold;
font-family: "JetBrains Mono", "Fira Code", monospace;
}
#main-container {
height: 1fr;
}
#log-section {
border: solid $accent;
height: 70%;
margin: 1 2;
}
#log-header {
background: $accent-darken-1;
padding: 1 2;
color: $text;
text-style: bold;
}
Log {
background: $surface-darken-2;
color: $text;
border: solid $accent-darken-2;
}
#status-section {
height: auto;
padding: 1 2;
background: $surface-darken-1;
border-top: solid $primary;
}
.status-item {
padding: 0 2;
}
.status-ok {
color: $success;
text-style: bold;
}
.status-pending {
color: $warning;
}
.status-label {
color: $text-muted;
}
.hidden {
display: none;
}
"""
BINDINGS: typing.ClassVar[list[Binding]] = [
Binding("q", "quit", "Quit", show=True),
Binding("ctrl+c", "quit", "Quit", show=False),
Binding("l", "toggle_logs", "Toggle Logs", show=True),
]
log_visible = reactive(True)
def __init__(
self,
startup_coro: Callable[[], Awaitable[Any]],
astrbot_root: Path,
backend_only: bool = False,
host: str | None = None,
port: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.startup_coro = startup_coro
self.astrbot_root = astrbot_root
self.backend_only = backend_only
self.host = host
self.port = port
self._animation_frame = 0
self._startup_done = False
self._log_lines: list[str] = []
self.console: Any = Console() if Console else None
def compose(self) -> ComposeResult:
"""Create child widgets."""
yield Header()
# Animated Logo
with Container(id="logo-container"):
yield Static(self._get_animated_logo(), id="logo-text")
# Main content
with Vertical(id="main-container"):
# Log viewer
with Container(
id="log-section", classes="" if self.log_visible else "hidden"
):
yield Static("📋 Live Logs", id="log-header")
yield Log(id="log-viewer")
# Status bar
with Horizontal(id="status-section"):
yield Static("🌟 AstrBot", classes="status-item status-ok")
yield Static(
f"📁 {self.astrbot_root.name}",
classes="status-item",
id="root-status",
)
if not self.backend_only:
dashboard_url = (
f"http://{self.host or 'localhost'}:{self.port or '6185'}"
)
yield Static(
f"🌐 Dashboard: [link]{dashboard_url}[/link]",
classes="status-item",
id="dashboard-status",
)
yield Static(
"⚡ Running", classes="status-item status-ok", id="run-status"
)
yield Footer()
def on_mount(self) -> None:
"""Called when app is mounted."""
self.title = "AstrBot"
self.sub_title = "AI Chatbot Framework"
# Start the startup coroutine
self.set_timer(0.1, self._run_startup)
# Animate logo
self.set_interval(0.5, self._animate_logo)
# Get the log widget and configure it
log_widget = self.query_one("#log-viewer", Log)
log_widget.write_line("🚀 AstrBot TUI initialized")
log_widget.write_line(f"📁 Running from: {self.astrbot_root}")
if not self.backend_only:
log_widget.write_line(
f"🌐 Dashboard will be available at: {self.host or 'localhost'}:{self.port or '6185'}"
)
log_widget.write_line("")
def _get_animated_logo(self) -> str:
"""Get the logo with optional animation effect."""
lines = ASTRBOT_LOGO.strip().split("\n")
if self.console and hasattr(self, "_animation_frame"):
# Create animated version with color cycling
frame = self._animation_frame % 4
colors = ["#00D9FF", "#00FF87", "#FFD700", "#FF6B6B"]
color = colors[frame]
text = Text()
for i, line in enumerate(lines):
style = Style(color=color, bold=True) if i == 0 else Style(color=color)
text.append(line + "\n", style=style)
return str(text)
return ASTRBOT_LOGO
def _animate_logo(self) -> None:
"""Update the animated logo."""
self._animation_frame = (self._animation_frame + 1) % 4
logo_widget = self.query_one("#logo-text", Static)
logo_widget.update(self._get_animated_logo())
async def _run_startup(self) -> None:
"""Run the AstrBot startup coroutine."""
if self._startup_done:
return
self._startup_done = True
try:
log_widget = self.query_one("#log-viewer", Log)
log_widget.write_line("⏳ Initializing AstrBot...")
await self.startup_coro()
log_widget.write_line("")
log_widget.write_line("✅ AstrBot started successfully!")
except Exception as e:
log_widget = self.query_one("#log-viewer", Log)
log_widget.write_line(f"❌ Error during startup: {e}")
log_widget.write_line("Check logs for details.")
def action_toggle_logs(self) -> None:
"""Toggle log visibility."""
self.log_visible = not self.log_visible
log_section = self.query_one("#log-section", Container)
if self.log_visible:
log_section.remove_class("hidden")
else:
log_section.add_class("hidden")
async def action_quit(self) -> None:
"""Quit the application."""
self.exit()
def write_log(self, message: str) -> None:
"""Write a message to the log viewer (can be called from outside)."""
log_widget = self.query_one("#log-viewer", Log)
log_widget.write_line(message)
def is_interactive_tty() -> bool:
"""Check if we're running in an interactive TTY."""
return sys.stdin.isatty() and sys.stdout.isatty()
async def run_tui(
startup_coro: Callable[[], Awaitable[Any]],
astrbot_root: Path,
backend_only: bool = False,
host: str | None = None,
port: str | None = None,
) -> None:
"""Run the AstrBot TUI.
Args:
startup_coro: Coroutine to run on startup
astrbot_root: AstrBot root directory
backend_only: Whether backend-only mode is enabled
host: Dashboard host
port: Dashboard port
"""
if not is_interactive_tty():
# Not interactive, run without TUI
await startup_coro()
return
app = AstrBotRunTUI(
startup_coro=startup_coro,
astrbot_root=astrbot_root,
backend_only=backend_only,
host=host,
port=port,
)
try:
await app.run_async()
except Exception:
# Fallback to non-TUI mode
await startup_coro()

View File

@@ -1,68 +0,0 @@
"""TUI CLI command for AstrBot."""
from __future__ import annotations
import sys
import click
@click.command(name="tui")
@click.option(
"--debug",
is_flag=True,
help="Enable debug mode with verbose output.",
)
@click.option(
"--host",
default="http://localhost:6185",
help="AstrBot dashboard host URL.",
)
@click.option(
"--api-key",
default=None,
help="API key for authentication (optional, uses login if not provided).",
)
@click.option(
"--username",
default="astrbot",
help="Username for login (if api-key not provided).",
)
@click.option(
"--password",
default="astrbot",
help="Password for login (if api-key not provided).",
)
def tui(
debug: bool,
host: str,
api_key: str | None,
username: str,
password: str,
) -> None:
"""
Launch the AstrBot Terminal User Interface (TUI).
This command starts an interactive terminal-based interface for AstrBot.
The TUI connects to a running AstrBot instance via the dashboard API.
"""
try:
from astrbot.cli.commands.tui_async import run_tui_async
run_tui_async(
debug=debug,
host=host,
api_key=api_key,
username=username,
password=password,
)
except ImportError as e:
click.echo(f"Error: Failed to import TUI module: {e}", err=True)
sys.exit(1)
except Exception as e:
click.echo(f"Error: Failed to start TUI: {e}", err=True)
if debug:
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -1,511 +0,0 @@
"""Async TUI implementation that connects to a running AstrBot instance via HTTP API.
This module provides a terminal UI that connects to AstrBot via the dashboard API,
supporting streaming responses and all message types.
"""
from __future__ import annotations
import asyncio
import curses
import json
from dataclasses import dataclass, field
from enum import Enum
import httpx
from astrbot.tui.message_handler import (
ChatResponse,
MessageType,
ParsedMessage,
SSEMessageParser,
)
from astrbot.tui.screen import Screen
class MessageSender(Enum):
USER = "user"
BOT = "bot"
SYSTEM = "system"
TOOL = "tool"
REASONING = "reasoning"
@dataclass
class Message:
sender: MessageSender
text: str
timestamp: float | None = None
@dataclass
class TUIState:
messages: list[Message] = field(default_factory=list)
input_buffer: str = ""
cursor_x: int = 0
status: str = "Connecting..."
running: bool = True
connected: bool = False
class TUIClient:
"""TUI client that connects to AstrBot via HTTP API.
Supports full streaming responses including:
- Plain text (streaming)
- Tool calls and results
- Reasoning chains
- Agent stats
- Media (images, audio, files)
"""
def __init__(
self,
screen: Screen,
host: str,
api_key: str | None,
username: str,
password: str,
debug: bool = False,
):
self.screen = screen
self.state = TUIState()
self._input_history: list[str] = []
self._history_index: int = -1
self._max_history: int = 100
self._max_messages: int = 1000
self._pending_tasks: list[asyncio.Task[None]] = []
# Connection settings
self.host = host.rstrip("/")
self.api_key = api_key
self.username = username
self.password = password
self.debug = debug
# Session info
self.session_id: str | None = None
self.conversation_id: str | None = None
# HTTP client
self._client: httpx.AsyncClient | None = None
self._headers: dict[str, str] = {}
# SSE parser
self._parser = SSEMessageParser()
async def connect(self) -> bool:
"""Connect to AstrBot and authenticate."""
self._client = httpx.AsyncClient(base_url=self.host, timeout=30.0)
try:
# Login or use API key
if self.api_key:
self._headers["Authorization"] = f"Bearer {self.api_key}"
else:
login_resp = await self._client.post(
"/api/auth/login",
json={"username": self.username, "password": self.password},
)
if login_resp.status_code != 200:
self.state.status = f"Login failed: {login_resp.status_code}"
return False
data = login_resp.json()
self._headers["Authorization"] = (
f"Bearer {data.get('access_token', '')}"
)
# Create new session for TUI
new_session_resp = await self._client.get(
"/api/tui/new_session",
params={"platform_id": "tui"},
headers=self._headers,
)
if new_session_resp.status_code != 200:
self.state.status = (
f"Failed to create session: {new_session_resp.status_code}"
)
return False
session_data = new_session_resp.json()
if session_data.get("code") != 0:
self.state.status = f"Session error: {session_data.get('msg')}"
return False
self.conversation_id = session_data.get("data", {}).get("session_id")
if not self.conversation_id:
self.state.status = "No session_id in response"
return False
self.session_id = self.conversation_id
self.state.connected = True
self.state.status = "Connected"
return True
except Exception as e:
self.state.status = f"Connection error: {e}"
if self.debug:
import traceback
traceback.print_exc()
return False
async def disconnect(self) -> None:
"""Disconnect from AstrBot."""
if self._client:
await self._client.aclose()
self.state.connected = False
async def load_history(self) -> None:
"""Load message history for the current session."""
if not self._client or not self.conversation_id:
return
try:
resp = await self._client.get(
"/api/tui/get_session",
params={"session_id": self.conversation_id},
headers=self._headers,
)
if resp.status_code != 200:
return
data = resp.json()
history = data.get("data", {}).get("history", [])
for record in reversed(history):
content = record.get("content", {})
msg_type = content.get("type")
message_parts = content.get("message", [])
if msg_type == "user":
for part in message_parts:
if part.get("type") == "plain":
self.add_message(MessageSender.USER, part.get("text", ""))
elif msg_type == "bot":
for part in message_parts:
if part.get("type") == "plain":
self.add_message(MessageSender.BOT, part.get("text", ""))
except Exception:
if self.debug:
import traceback
traceback.print_exc()
def add_message(self, sender: MessageSender, text: str) -> None:
"""Add a message to the chat log."""
if not text:
return
self.state.messages.append(Message(sender=sender, text=text))
if len(self.state.messages) > self._max_messages:
self.state.messages = self.state.messages[-self._max_messages :]
def add_system_message(self, text: str) -> None:
"""Add a system message."""
self.add_message(MessageSender.SYSTEM, text)
def handle_key(self, key: int) -> bool:
"""Handle a keypress. Returns True if the application should continue running."""
if key in (curses.KEY_EXIT, 27): # ESC or ctrl-c
return False
if key == curses.KEY_RESIZE:
self.screen.resize()
return True
# Handle arrow keys for navigation
if key == curses.KEY_LEFT:
if self.state.cursor_x > 0:
self.state.cursor_x -= 1
elif key == curses.KEY_RIGHT:
if self.state.cursor_x < len(self.state.input_buffer):
self.state.cursor_x += 1
elif key == curses.KEY_HOME:
self.state.cursor_x = 0
elif key == curses.KEY_END:
self.state.cursor_x = len(self.state.input_buffer)
# Handle backspace
elif key in (curses.KEY_BACKSPACE, 127, 8):
if self.state.cursor_x > 0:
self.state.input_buffer = (
self.state.input_buffer[: self.state.cursor_x - 1]
+ self.state.input_buffer[self.state.cursor_x :]
)
self.state.cursor_x -= 1
# Handle delete
elif key == curses.KEY_DC:
if self.state.cursor_x < len(self.state.input_buffer):
self.state.input_buffer = (
self.state.input_buffer[: self.state.cursor_x]
+ self.state.input_buffer[self.state.cursor_x + 1 :]
)
# Handle Enter/Return - submit message
elif key in (curses.KEY_ENTER, 10, 13):
if self.state.input_buffer.strip():
task = asyncio.create_task(self._submit_message())
self._pending_tasks.append(task)
return True
# Handle history navigation (up/down arrows)
elif key == curses.KEY_UP:
if (
self._input_history
and self._history_index < len(self._input_history) - 1
):
self._history_index += 1
self.state.input_buffer = self._input_history[self._history_index]
self.state.cursor_x = len(self.state.input_buffer)
elif key == curses.KEY_DOWN:
if self._history_index > 0:
self._history_index -= 1
self.state.input_buffer = self._input_history[self._history_index]
self.state.cursor_x = len(self.state.input_buffer)
elif self._history_index == 0:
self._history_index = -1
self.state.input_buffer = ""
self.state.cursor_x = 0
# Regular character input
elif 32 <= key <= 126:
char = chr(key)
self.state.input_buffer = (
self.state.input_buffer[: self.state.cursor_x]
+ char
+ self.state.input_buffer[self.state.cursor_x :]
)
self.state.cursor_x += 1
# Clear input with Ctrl+L
elif key == 12: # Ctrl+L
self.state.input_buffer = ""
self.state.cursor_x = 0
return True
async def _submit_message(self) -> None:
"""Submit the current input buffer as a user message."""
text = self.state.input_buffer.strip()
if not text:
return
# Add to history
self._input_history.insert(0, text)
if len(self._input_history) > self._max_history:
self._input_history = self._input_history[: self._max_history]
self._history_index = -1
# Add user message to chat
self.add_message(MessageSender.USER, text)
# Clear input
self.state.input_buffer = ""
self.state.cursor_x = 0
# Process the message via API
await self._process_user_message(text)
async def _process_user_message(self, text: str) -> None:
"""Send message to AstrBot and process the streaming response."""
if not self.conversation_id or not self._client:
self.add_system_message("Not connected to AstrBot")
return
self.state.status = "Waiting for response..."
try:
# Format umo for tui
umo = f"tui:FriendMessage:tui!{self.username}!{self.conversation_id}"
# Reset parser for new stream
self._parser.reset()
# Send message and stream response using proper SSE
async with self._client.stream(
"POST",
"/api/tui/chat",
headers=self._headers,
json={
"umo": umo,
"message": text,
"session_id": self.conversation_id,
"streaming": True,
},
timeout=None,
) as response:
if response.status_code != 200:
self._update_last_bot_message(f"Error: HTTP {response.status_code}")
self.state.status = "Error"
return
# Process streaming SSE
async for line in response.aiter_lines():
parsed = self._parser.parse_line(line)
if parsed is None:
continue
update, is_complete = self._process_parsed_message(parsed)
# Update display based on message type
if parsed.type == MessageType.TOOL_CALL:
tool_call = json.loads(parsed.data)
self.add_message(
MessageSender.TOOL,
f"[Tool: {tool_call.get('name', 'unknown')}]",
)
self.state.status = "Running tool..."
elif parsed.type == MessageType.TOOL_CALL_RESULT:
try:
tcr = json.loads(parsed.data)
self.add_message(
MessageSender.TOOL,
f"[Result] {tcr.get('result', '')[:100]}...",
)
except json.JSONDecodeError:
pass
elif parsed.type == MessageType.REASONING:
self._update_last_bot_message(
f"[Thinking] {update.reasoning[-200:]}"
)
self.state.status = "Thinking..."
elif parsed.type == MessageType.AGENT_STATS:
self.state.status = (
f"Tokens: {update.agent_stats.get('total_tokens', 0)}"
)
elif update.text:
self._update_last_bot_message(update.text)
if is_complete:
break
# Final status
if update.reasoning:
self.add_message(
MessageSender.REASONING, f"[Reasoning]\n{update.reasoning}"
)
for tool_display in update.get_tool_calls_display():
self.add_message(MessageSender.TOOL, tool_display)
if update.error:
self.add_message(MessageSender.SYSTEM, f"Error: {update.error}")
self.state.status = "Ready"
except asyncio.CancelledError:
self.state.status = "Cancelled"
except Exception as e:
self.add_system_message(f"Error: {e}")
self.state.status = f"Error: {e}"
if self.debug:
import traceback
traceback.print_exc()
def _process_parsed_message(self, msg: ParsedMessage) -> tuple[ChatResponse, bool]:
"""Process a parsed message and return updated response state."""
return self._parser.process_message(msg)
def _update_last_bot_message(self, text: str) -> None:
"""Update the last bot message with new text (for streaming)."""
for i in range(len(self.state.messages) - 1, -1, -1):
if self.state.messages[i].sender == MessageSender.BOT:
self.state.messages[i] = Message(
sender=MessageSender.BOT,
text=text,
timestamp=self.state.messages[i].timestamp,
)
break
else:
self.add_message(MessageSender.BOT, text)
def render(self) -> None:
"""Render the current state to the screen."""
lines = [(msg.sender.value, msg.text) for msg in self.state.messages]
self.screen.draw_all(
lines=lines,
input_text=self.state.input_buffer,
cursor_x=self.state.cursor_x,
status=self.state.status,
)
async def run_event_loop(self, stdscr: curses.window) -> None:
"""Main event loop for the TUI."""
# Setup
self.screen.setup_colors()
self.screen.layout_windows()
# Connect to AstrBot
connected = await self.connect()
if not connected:
self.add_system_message(f"Failed to connect: {self.state.status}")
else:
self.add_system_message("Connected to AstrBot!")
# Load history
await self.load_history()
# Welcome message
self.add_system_message("Type your message and press Enter to send.")
self.add_system_message("Press ESC or Ctrl+C to exit.")
# Initial render
self.render()
# Input loop
while self.state.running:
# Get input with timeout
self.screen.input_win.nodelay(True)
try:
key = self.screen.input_win.getch()
except curses.error:
key = -1
if key != -1:
if not self.handle_key(key):
self.state.running = False
break
self.render()
# Small sleep to prevent CPU hogging
await asyncio.sleep(0.01)
# Cleanup
await self.disconnect()
def run_tui_async(
debug: bool = False,
host: str = "http://localhost:6185",
api_key: str | None = None,
username: str = "astrbot",
password: str = "astrbot",
) -> None:
"""Entry point to run the TUI application."""
from astrbot.tui.screen import run_curses
def main(stdscr: curses.window) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
scr = Screen(stdscr)
client = TUIClient(
screen=scr,
host=host,
api_key=api_key,
username=username,
password=password,
debug=debug,
)
try:
loop.run_until_complete(client.run_event_loop(stdscr))
finally:
loop.close()
run_curses(main)
if __name__ == "__main__":
run_tui_async()

View File

@@ -1,285 +0,0 @@
"""Internationalization support for AstrBot CLI.
This module provides i18n support with Chinese and English languages.
Language is auto-detected from environment or can be set manually.
"""
from __future__ import annotations
import os
from enum import Enum
from functools import lru_cache
class Language(Enum):
"""Supported languages."""
ZH = "zh"
EN = "en"
# Translation dictionaries
_TRANSLATIONS: dict[Language, dict[str, str]] = {
Language.ZH: {
# CLI welcome and general
"cli_welcome": "欢迎使用 AstrBot CLI!",
"cli_version": "AstrBot CLI 版本: {version}",
"cli_unknown_command": "未知命令: {command}",
"cli_help_available": "使用 astrbot help --all 查看所有命令",
# Dashboard commands
"dashboard_bundled": "Dashboard 已打包在安装包中 - 跳过下载",
"dashboard_not_installed": "Dashboard 未安装",
"dashboard_install_confirm": "是否安装 Dashboard?",
"dashboard_installing": "正在安装 Dashboard...",
"dashboard_install_success": "Dashboard 安装成功",
"dashboard_install_failed": "Dashboard 安装失败: {error}",
"dashboard_not_needed": "Dashboard 不需要安装",
"dashboard_declined": "Dashboard 安装已取消",
"dashboard_already_up_to_date": "Dashboard 已是最新版本",
"dashboard_version": "Dashboard 版本: {version}",
"dashboard_download_failed": "Dashboard 下载失败: {error}",
"dashboard_init_dir": "正在初始化 Dashboard 目录...",
"dashboard_init_success": "Dashboard 初始化成功",
# Plugin commands
"plugin_installing": "正在安装插件: {name}",
"plugin_install_success": "插件安装成功: {name}",
"plugin_install_failed": "插件安装失败: {name}",
"plugin_uninstall_confirm": "确定要卸载插件 {name} 吗?",
"plugin_uninstall_success": "插件卸载成功: {name}",
"plugin_uninstall_failed": "插件卸载失败: {name}",
"plugin_list_empty": "未安装任何插件",
"plugin_already_installed": "插件已安装: {name}",
"plugin_not_found": "插件未找到: {name}",
"plugin_already_exists": "插件已存在: {name}",
"plugin_not_found_or_installed": "插件未找到或已安装: {name}",
"plugin_uninstall_failed_ex": "插件卸载失败 {name}: {error}",
"plugin_no_update_needed": "没有需要更新的插件",
"plugin_found_update": "发现 {count} 个插件需要更新",
"plugin_updating": "正在更新插件 {name}...",
"plugin_search_no_result": "未找到匹配 '{query}' 的插件",
"plugin_search_results": "搜索结果: '{query}'",
# Config commands
"config_show": "显示配置",
"config_set_success": "配置项已更新: {key} = {value}",
"config_set_failed": "配置项更新失败: {key}",
"config_set_failed_ex": "设置配置失败: {error}",
"config_get_success": "{key} = {value}",
"config_get_not_found": "配置项未找到: {key}",
"config_reset_confirm": "确定要重置所有配置吗?",
"config_reset_success": "配置已重置",
# Config validators
"config_log_level_invalid": "日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
"config_port_must_be_number": "端口必须是数字",
"config_port_range_invalid": "端口必须在 1-65535 范围内",
"config_username_empty": "用户名不能为空",
"config_password_empty": "密码不能为空",
"config_timezone_invalid": "无效的时区: {value}。请使用有效的 IANA 时区名称",
"config_callback_invalid": "回调 API 基础路径必须以 http:// 或 https:// 开头",
"config_key_unsupported": "不支持的配置项: {key}",
"config_key_unknown": "未知的配置项: {key}",
"config_updated": "配置已更新: {key}",
# Init command
"init_creating": "正在创建配置目录...",
"init_created": "配置目录已创建: {path}",
"init_copying": "正在复制配置文件...",
"init_copied": "配置文件已复制",
"init_success": "AstrBot 初始化完成!",
"init_failed": "初始化失败: {error}",
# Run command
"run_starting": "正在启动 AstrBot...",
"run_started": "AstrBot 已启动!",
"run_backend_only": "以无界面模式启动",
"run_failed": "启动失败: {error}",
"run_stopped": "AstrBot 已停止",
# TUI command
"tui_starting": "正在启动 TUI...",
"tui_started": "TUI 已启动",
"tui_failed": "TUI 启动失败: {error}",
# Common
"yes": "",
"no": "",
"cancel": "取消",
"confirm": "确认",
"error": "错误",
"success": "成功",
"warning": "警告",
"info": "信息",
"loading": "加载中...",
"done": "完成",
"failed": "失败",
"retry": "重试",
"exit": "退出",
"continue": "继续",
},
Language.EN: {
# CLI welcome and general
"cli_welcome": "Welcome to AstrBot CLI!",
"cli_version": "AstrBot CLI version: {version}",
"cli_unknown_command": "Unknown command: {command}",
"cli_help_available": "Use astrbot help --all to see all commands",
# Dashboard commands
"dashboard_bundled": "Dashboard is bundled with the package - skipping download",
"dashboard_not_installed": "Dashboard is not installed",
"dashboard_install_confirm": "Install Dashboard?",
"dashboard_installing": "Installing Dashboard...",
"dashboard_install_success": "Dashboard installed successfully",
"dashboard_install_failed": "Failed to install dashboard: {error}",
"dashboard_not_needed": "Dashboard not needed",
"dashboard_declined": "Dashboard installation declined.",
"dashboard_already_up_to_date": "Dashboard is already up to date",
"dashboard_version": "Dashboard version: {version}",
"dashboard_download_failed": "Failed to download dashboard: {error}",
"dashboard_init_dir": "Initializing dashboard directory...",
"dashboard_init_success": "Dashboard initialized successfully",
# Plugin commands
"plugin_installing": "Installing plugin: {name}",
"plugin_install_success": "Plugin installed successfully: {name}",
"plugin_install_failed": "Failed to install plugin: {name}",
"plugin_uninstall_confirm": "Uninstall plugin {name}?",
"plugin_uninstall_success": "Plugin uninstalled successfully: {name}",
"plugin_uninstall_failed": "Failed to uninstall plugin: {name}",
"plugin_list_empty": "No plugins installed",
"plugin_already_installed": "Plugin already installed: {name}",
"plugin_not_found": "Plugin not found: {name}",
"plugin_already_exists": "Plugin {name} already exists",
"plugin_not_found_or_installed": "Plugin {name} not found or already installed",
"plugin_uninstall_failed_ex": "Failed to uninstall plugin {name}: {error}",
"plugin_no_update_needed": "No plugins need updating",
"plugin_found_update": "Found {count} plugin(s) needing update",
"plugin_updating": "Updating plugin {name}...",
"plugin_search_no_result": "No plugins matching '{query}' found",
"plugin_search_results": "Search results: '{query}'",
# Config commands
"config_show": "Show configuration",
"config_set_success": "Configuration updated: {key} = {value}",
"config_set_failed": "Failed to update configuration: {key}",
"config_set_failed_ex": "Failed to set config: {error}",
"config_get_success": "{key} = {value}",
"config_get_not_found": "Configuration key not found: {key}",
"config_reset_confirm": "Reset all configuration?",
"config_reset_success": "Configuration reset",
# Config validators
"config_log_level_invalid": "Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
"config_port_must_be_number": "Port must be a number",
"config_port_range_invalid": "Port must be in range 1-65535",
"config_username_empty": "Username cannot be empty",
"config_password_empty": "Password cannot be empty",
"config_timezone_invalid": "Invalid timezone: {value}. Please use a valid IANA timezone name",
"config_callback_invalid": "Callback API base must start with http:// or https://",
"config_key_unsupported": "Unsupported config key: {key}",
"config_key_unknown": "Unknown config key: {key}",
"config_updated": "Config updated: {key}",
# Init command
"init_creating": "Creating config directory...",
"init_created": "Config directory created: {path}",
"init_copying": "Copying config files...",
"init_copied": "Config files copied",
"init_success": "AstrBot initialized successfully!",
"init_failed": "Initialization failed: {error}",
# Run command
"run_starting": "Starting AstrBot...",
"run_started": "AstrBot started!",
"run_backend_only": "Starting in backend-only mode",
"run_failed": "Failed to start: {error}",
"run_stopped": "AstrBot stopped",
# TUI command
"tui_starting": "Starting TUI...",
"tui_started": "TUI started",
"tui_failed": "Failed to start TUI: {error}",
# Common
"yes": "Yes",
"no": "No",
"cancel": "Cancel",
"confirm": "Confirm",
"error": "Error",
"success": "Success",
"warning": "Warning",
"info": "Info",
"loading": "Loading...",
"done": "Done",
"failed": "Failed",
"retry": "Retry",
"exit": "Exit",
"continue": "Continue",
},
}
@lru_cache(maxsize=1)
def get_current_language() -> Language:
"""Get the current language based on environment or default.
Detection order:
1. ASTRBOT_CLI_LANG environment variable (zh/en)
2. LANG environment variable (if contains zh/cn)
3. LC_ALL environment variable (if contains zh/cn)
4. Default to Chinese (most users are Chinese)
"""
# Check explicit override first
explicit = os.environ.get("ASTRBOT_CLI_LANG", "").lower()
if explicit in ("zh", "en"):
return Language.ZH if explicit == "zh" else Language.EN
# Check LANG/LC_ALL for Chinese
for env_var in ("LANG", "LC_ALL"):
lang = os.environ.get(env_var, "").lower()
if "zh" in lang or "cn" in lang:
return Language.ZH
# Default to Chinese for broader appeal
return Language.ZH
def set_language(lang: Language) -> None:
"""Set the current language (clears all translation caches)."""
get_current_language.cache_clear()
_t_cached.cache_clear()
# Set environment variable for persistence
os.environ["ASTRBOT_CLI_LANG"] = lang.value
@lru_cache(maxsize=128)
def _t_cached(key: str, lang: Language) -> str:
"""Cached translation lookup."""
return _TRANSLATIONS.get(lang, {}).get(key, key)
def t(translation_key: str, **kwargs: str) -> str:
"""Get translation for the given key in the current language.
Args:
translation_key: Translation key (e.g., "cli_welcome", "plugin_installing")
**kwargs: Format arguments for the translation string
Returns:
Translated string, or the key itself if not found
"""
result = _t_cached(translation_key, get_current_language())
if kwargs:
result = result.format(**kwargs)
return result
def tr(key: str, **kwargs: str) -> str:
"""Get translation (alias for t())."""
return t(key, **kwargs)
class CLITranslations:
"""Translation accessor class for CLI contexts.
Usage:
translations = CLITranslations()
print(translations.cli_welcome)
print(translations.plugin_installing(name="my_plugin"))
"""
def __getattr__(self, key: str) -> str:
return t(key)
def __call__(self, key: str, **kwargs: str) -> str:
return t(key, **kwargs)
# Convenience instance
translations = CLITranslations()

View File

@@ -1,12 +1,18 @@
from .dashboard import DashboardManager
from .basic import (
check_astrbot_root,
check_dashboard,
get_astrbot_root,
)
from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin
from .version_comparator import VersionComparator
__all__ = [
"DashboardManager",
"PluginStatus",
"VersionComparator",
"build_plug_list",
"check_astrbot_root",
"check_dashboard",
"get_astrbot_root",
"get_git_repo",
"manage_plugin",
]

View File

@@ -0,0 +1,91 @@
from importlib import resources
from pathlib import Path
import click
from astrbot.core.utils.astrbot_path import astrbot_paths
# Static assets bundled inside the installed wheel (built by hatch_build.py).
# _BUNDLED_DIST = Path(__file__).parent.parent.parent / "dashboard" / "dist"
_BUNDLED_DIST = resources.files("astrbot") / "dashboard" / "dist"
def check_astrbot_root(path: str | Path) -> bool:
"""Check if the path is an AstrBot root directory"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
return False
if not (path / ".astrbot").exists():
return False
return True
def get_astrbot_root() -> Path:
"""Get the AstrBot root directory path"""
return astrbot_paths.root
async def check_dashboard(astrbot_root: Path) -> None:
"""Check if the dashboard is installed"""
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from .version_comparator import VersionComparator
# If the wheel ships bundled dashboard assets, no network download is needed.
if _BUNDLED_DIST.is_dir():
click.echo("Dashboard is bundled with the package skipping download.")
return
try:
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo("Dashboard is not installed")
if click.confirm(
"Install dashboard?",
default=True,
abort=True,
):
click.echo("Installing dashboard...")
try:
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
click.echo("Dashboard installed successfully")
except Exception as e:
click.echo(f"Failed to install dashboard: {e}")
case str():
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
click.echo("Dashboard is already up to date")
return
try:
version = dashboard_version.split("v")[1]
click.echo(f"Dashboard version: {version}")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
except Exception as e:
click.echo(f"Failed to download dashboard: {e}")
return
except FileNotFoundError:
click.echo("Initializing dashboard directory...")
try:
await download_dashboard(
path=str(astrbot_root / "data" / "dashboard.zip"),
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
click.echo("Dashboard initialized successfully")
except Exception as e:
click.echo(f"Failed to download dashboard: {e}")
return

View File

@@ -1,79 +0,0 @@
import sys
from importlib import resources
from pathlib import Path
import click
from astrbot.cli.i18n import t
from .version_comparator import VersionComparator
class DashboardManager:
_bundled_dist = resources.files("astrbot") / "dashboard" / "dist"
async def ensure_installed(self, astrbot_root: Path) -> None:
"""Ensure the dashboard assets are installed and up to date."""
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
if self._bundled_dist.is_dir():
click.echo(t("dashboard_bundled"))
return
try:
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo(t("dashboard_not_installed"))
# Skip interactive prompt in non-interactive environments
if not sys.stdin.isatty():
click.echo(t("dashboard_not_needed"))
return
if click.confirm(t("dashboard_install_confirm"), default=True):
click.echo(t("dashboard_installing"))
try:
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
click.echo(t("dashboard_install_success"))
except Exception as e:
click.echo(t("dashboard_install_failed", error=str(e)))
else:
click.echo(t("dashboard_declined"))
case str():
if (
VersionComparator.compare_version(VERSION, dashboard_version)
<= 0
):
click.echo(t("dashboard_already_up_to_date"))
return
try:
version = dashboard_version.split("v")[1]
click.echo(t("dashboard_version", version=version))
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
except Exception as e:
click.echo(t("dashboard_download_failed", error=str(e)))
return
except FileNotFoundError:
click.echo(t("dashboard_init_dir"))
try:
await download_dashboard(
path=str(astrbot_root / "data" / "dashboard.zip"),
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
click.echo(t("dashboard_init_success"))
except Exception as e:
click.echo(t("dashboard_download_failed", error=str(e)))
return

View File

@@ -66,16 +66,16 @@ pip_installer = PipInstaller(
astrbot_config.get("pypi_index_url", None),
)
__all__ = [
"DEMO_MODE",
"AstrBotConfig",
"LogBroker",
"LogManager",
"DEMO_MODE",
"astrbot_config",
"db_helper",
"file_token_service",
"t2i_base_url",
"html_renderer",
"logger",
"pip_installer",
"LogBroker",
"LogManager",
"db_helper",
"sp",
"t2i_base_url",
"file_token_service",
"pip_installer",
]

View File

@@ -212,7 +212,7 @@ class LLMSummaryCompressor:
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = [*messages_to_summarize, instruction_message]
llm_payload = messages_to_summarize + [instruction_message]
# generate summary
try:

View File

@@ -1,7 +1,7 @@
import json
from typing import Protocol, runtime_checkable
from ..message import AudioURLPart, ImageURLPart, Message, TextPart, ThinkPart
from ..message import Message, TextPart
@runtime_checkable
@@ -28,19 +28,9 @@ class TokenCounter(Protocol):
...
# 图片/音频 token 开销估算值,参考 OpenAI vision pricing:
# low-res ~85 tokens, high-res ~170 per 512px tile, 通常几百到上千。
# 这里取一个保守中位数,宁可偏高触发压缩也不要偏低导致 API 报错。
IMAGE_TOKEN_ESTIMATE = 765
AUDIO_TOKEN_ESTIMATE = 500
class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.
Supports multimodal content: images, audio, and thinking parts
are all counted so that the context compressor can trigger in time.
"""
def count_tokens(
@@ -55,16 +45,12 @@ class EstimateTokenCounter:
if isinstance(content, str):
total += self._estimate_tokens(content)
elif isinstance(content, list):
# 处理多模态内容
for part in content:
if isinstance(part, TextPart):
total += self._estimate_tokens(part.text)
elif isinstance(part, ThinkPart):
total += self._estimate_tokens(part.think)
elif isinstance(part, ImageURLPart):
total += IMAGE_TOKEN_ESTIMATE
elif isinstance(part, AudioURLPart):
total += AUDIO_TOKEN_ESTIMATE
# 处理 Tool Calls
if msg.tool_calls:
for tc in msg.tool_calls:
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())

View File

@@ -12,74 +12,14 @@ class ContextTruncator:
and len(message.tool_calls) > 0
)
@staticmethod
def _split_system_rest(
messages: list[Message],
) -> tuple[list[Message], list[Message]]:
"""Split messages into system messages and the rest.
Returns:
tuple: (system_messages, non_system_messages)
"""
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
return messages[:first_non_system], messages[first_non_system:]
@staticmethod
def _ensure_user_message(
system_messages: list[Message],
truncated: list[Message],
original_messages: list[Message],
) -> list[Message]:
"""Ensure the result always contains a `user` message immediately after
system messages, as required by some LLM APIs.
Optimization strategy:
- If `truncated` already begins with a `user` message, return it as-is.
- If a `user` message exists later in `truncated`, move that message to
be the first non-system message while preserving the relative order of
the remaining truncated messages (without mutating the original list).
- Otherwise, fall back to the first `user` message from
`original_messages`.
This reduces unnecessary duplication and ensures the required ordering.
"""
if truncated and truncated[0].role == "user":
return system_messages + truncated
# If a user message exists inside the truncated list, promote it to the front.
index_in_truncated = next(
(i for i, m in enumerate(truncated) if m.role == "user"), None
)
if index_in_truncated is not None:
# Build a new truncated list that places the found user message first,
# preserving the order of the other messages and avoiding in-place mutation.
user_msg = truncated[index_in_truncated]
new_truncated = [
user_msg,
*truncated[:index_in_truncated],
*truncated[index_in_truncated + 1 :],
]
return system_messages + new_truncated
# Fallback: find the first user message in the original messages.
first_user = next((m for m in original_messages if m.role == "user"), None)
if first_user is None:
# No user messages at all; return system messages + whatever was truncated.
return system_messages + truncated
return [*system_messages, first_user, *truncated]
def fix_messages(self, messages: list[Message]) -> list[Message]:
"""Fix the message list to ensure the validity of tool call and tool response pairing.
"""修复消息列表,确保 tool call tool response 的配对关系有效。
This method ensures that:
1. Each `tool` message is preceded by an `assistant` message containing `tool_calls`.
2. Each `assistant` message containing `tool_calls` is followed by corresponding `
此方法确保:
1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息
2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应
This is a requirement of the OpenAI Chat Completions API specification (Gemini enforces this strictly).
这是 OpenAI Chat Completions API 规范的要求Gemini 对此执行严格检查)。
"""
if not messages:
return messages
@@ -98,25 +38,24 @@ class ContextTruncator:
for msg in messages:
if msg.role == "tool":
# Only record tool responses when there is a pending assistant(tool_calls)
# 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应
if pending_assistant is not None:
pending_tools.append(msg)
# Isolated tool messages without a preceding assistant(tool_calls) are ignored
# else: 孤立的 tool 消息,直接忽略
continue
if self._has_tool_calls(msg):
# When encountering a new assistant(tool_calls), first process the old pending chain
# 遇到新的 assistant(tool_calls) 前,先处理旧的 pending
flush_pending_if_valid()
pending_assistant = msg
continue
# Non-tool messages that do not contain tool_calls will break the pending chain.
# Flush any pending chain first, then append the current message normally.
# 非 tool且不含 tool_calls 的消息
# 先结束任何 pending 链,再正常追加
flush_pending_if_valid()
fixed_messages.append(msg)
# Flush the last pending chain at the end,
# ensuring that any remaining valid assistant(tool_calls) and its tools are included in the final list.
# 结束时处理最后一个 pending 链
flush_pending_if_valid()
return fixed_messages
@@ -127,23 +66,29 @@ class ContextTruncator:
keep_most_recent_turns: int,
drop_turns: int = 1,
) -> list[Message]:
"""
Turn-based truncation strategy, which drops the oldest turns while keeping the most recent N turns.
A turn consists of a user message and an assistant message.
This method ensures that the truncated context list conforms to OpenAI's context format.
"""截断上下文列表,确保不超过最大长度。
一个 turn 包含一个 user 消息和一个 assistant 消息。
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
Args:
messages: The original list of messages in the context.
keep_most_recent_turns: The number of most recent turns to keep. If set to -1, it means keeping all turns (no truncation).
drop_turns: The number of turns to drop from the beginning.
messages: 上下文列表
keep_most_recent_turns: 保留最近的对话轮数
drop_turns: 一次性丢弃的对话轮数
Returns:
The truncated list of messages.
截断后的上下文列表
"""
if keep_most_recent_turns == -1:
return messages
system_messages, non_system_messages = self._split_system_rest(messages)
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
@@ -154,7 +99,7 @@ class ContextTruncator:
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
# Find the first user message
# 找到第一个 role 为 user 的索引,确保上下文格式正确
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
@@ -162,9 +107,8 @@ class ContextTruncator:
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
result = self._ensure_user_message(
system_messages, truncated_contexts, messages
)
result = system_messages + truncated_contexts
return self.fix_messages(result)
def truncate_by_dropping_oldest_turns(
@@ -172,39 +116,53 @@ class ContextTruncator:
messages: list[Message],
drop_turns: int = 1,
) -> list[Message]:
"""Drop the oldest N turns, regardless of the number of turns to keep."""
"""丢弃最旧的 N 个对话轮次。"""
if drop_turns <= 0:
return messages
system_messages, non_system_messages = self._split_system_rest(messages)
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
# Find the first user message
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []
result = system_messages + truncated_non_system
result = self._ensure_user_message(
system_messages, truncated_non_system, messages
)
return self.fix_messages(result)
def truncate_by_halving(
self,
messages: list[Message],
) -> list[Message]:
"""Halve the number of messages, keeping the most recent ones."""
"""对半砍策略,删除 50% 的消息"""
if len(messages) <= 2:
return messages
system_messages, non_system_messages = self._split_system_rest(messages)
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
@@ -212,7 +170,6 @@ class ContextTruncator:
truncated_non_system = non_system_messages[messages_to_delete:]
# Find the first user message
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
@@ -220,7 +177,6 @@ class ContextTruncator:
if index is not None:
truncated_non_system = truncated_non_system[index:]
result = self._ensure_user_message(
system_messages, truncated_non_system, messages
)
result = system_messages + truncated_non_system
return self.fix_messages(result)

View File

@@ -1,27 +1,8 @@
"""
MCP client - DEPRECATED
.. deprecated::
This module has been moved to :mod:`astrbot._internal.mcp`.
Please update your imports accordingly.
Old import (deprecated):
from astrbot.core.agent.mcp_client import MCPClient, MCPTool
New import:
from astrbot._internal.mcp import MCPClient, MCPTool
This file exists solely for backward compatibility and will be removed in a future version.
"""
import asyncio
import logging
import os
import sys
import warnings
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Any, Generic
from typing import Generic
from tenacity import (
before_sleep_log,
@@ -31,20 +12,13 @@ from tenacity import (
wait_exponential,
)
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe
from .run_context import TContext
from .tool import FunctionTool
logger = logging.getLogger("astrbot")
warnings.warn(
"astrbot.core.agent.mcp_client has been moved to astrbot._internal.mcp. "
"Please update your imports.",
DeprecationWarning,
stacklevel=2,
)
try:
import anyio
import mcp
@@ -62,26 +36,6 @@ except (ModuleNotFoundError, ImportError):
)
class TenacityLogger:
"""Wraps a logging.Logger to satisfy tenacity's LoggerProtocol."""
__slots__ = ("_logger",)
_logger: logging.Logger
def __init__(self, logger: logging.Logger) -> None:
self._logger = logger
def log(
self,
level: int,
msg: str,
/,
*args: Any,
**kwargs: Any,
) -> None:
self._logger.log(level, msg, *args, **kwargs)
def _prepare_config(config: dict) -> dict:
"""Prepare configuration, handle nested format"""
if config.get("mcpServers"):
@@ -91,22 +45,6 @@ def _prepare_config(config: dict) -> dict:
return config
def _prepare_stdio_env(config: dict) -> dict:
"""Preserve Windows executable resolution for stdio subprocesses."""
if sys.platform != "win32":
return config
pathext = os.environ.get("PATHEXT")
if not pathext:
return config
prepared = config.copy()
env = dict(prepared.get("env") or {})
env.setdefault("PATHEXT", pathext)
prepared["env"] = env
return prepared
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""Quick test MCP server connectivity"""
import aiohttp
@@ -181,7 +119,6 @@ class MCPClient:
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()
self.process_pid: int | None = None
# Store connection config for reconnection
self._mcp_server_config: dict | None = None
@@ -189,24 +126,6 @@ class MCPClient:
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
self._reconnecting: bool = False # For logging and debugging
@staticmethod
def _extract_stdio_process_pid(streams_context: object) -> int | None:
"""Best-effort extraction for stdio subprocess PID used by lease cleanup.
TODO(refactor): replace this async-generator frame introspection with a
stable MCP library hook once the upstream transport exposes process PID.
"""
generator = getattr(streams_context, "gen", None)
frame = getattr(generator, "ag_frame", None)
if frame is None:
return None
process = frame.f_locals.get("process")
pid = getattr(process, "pid", None)
try:
return int(pid) if pid is not None else None
except (TypeError, ValueError):
return None
async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:
"""Connect to MCP server
@@ -222,7 +141,6 @@ class MCPClient:
# Store config for reconnection
self._mcp_server_config = mcp_server_config
self._server_name = name
self.process_pid = None
cfg = _prepare_config(mcp_server_config.copy())
@@ -232,7 +150,7 @@ class MCPClient:
# Handle MCP service error logs
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
if "url" in cfg:
@@ -265,7 +183,7 @@ class MCPClient:
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback,
logging_callback=logging_callback, # type: ignore
),
)
else:
@@ -291,12 +209,11 @@ class MCPClient:
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback,
logging_callback=logging_callback, # type: ignore
),
)
else:
cfg = _prepare_stdio_env(cfg)
server_params = mcp.StdioServerParameters(
**cfg,
)
@@ -311,7 +228,7 @@ class MCPClient:
"alert",
"emergency",
):
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
stdio_transport = await self.exit_stack.enter_async_context(
@@ -322,10 +239,9 @@ class MCPClient:
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
),
), # type: ignore
),
)
self.process_pid = self._extract_stdio_process_pid(self._streams_context)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
@@ -416,7 +332,7 @@ class MCPClient:
retry=retry_if_exception_type(anyio.ClosedResourceError),
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=1, max=3),
before_sleep=before_sleep_log(TenacityLogger(logger), logging.WARNING),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
async def _call_with_retry():
@@ -455,7 +371,6 @@ class MCPClient:
# Set running_event first to unblock any waiting tasks
self.running_event.set()
self.process_pid = None
class MCPTool(FunctionTool, Generic[TContext]):

View File

@@ -16,7 +16,7 @@ class ContextWrapper(Generic[TContext]):
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 120 # Default tool call timeout in seconds
tool_call_timeout: int = 60 # Default tool call timeout in seconds
NoContext = ContextWrapper[None]

View File

@@ -1,15 +1,13 @@
import abc
from collections.abc import AsyncGenerator
import typing as T
from enum import Enum, auto
from typing import Any, Generic
from astrbot import logger
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.response import AgentResponse
from astrbot.core.agent.run_context import ContextWrapper, TContext
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.provider.entities import LLMResponse, ProviderRequest
from astrbot.core.provider.provider import Provider
from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
class AgentState(Enum):
@@ -21,32 +19,13 @@ class AgentState(Enum):
ERROR = auto() # Error state
class BaseAgentRunner(Generic[TContext]):
def __init__(
self,
):
self.tasks: set = set()
class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
enforce_max_turns: int = -1,
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
truncate_turns: int = 1,
custom_token_counter: Any = None,
custom_compressor: Any = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
provider_config: dict | None = None,
**kwargs: Any,
**kwargs: T.Any,
) -> None:
"""Reset the agent to its initial state.
This method should be called before starting a new run.
@@ -54,14 +33,14 @@ class BaseAgentRunner(Generic[TContext]):
...
@abc.abstractmethod
async def step(self) -> AsyncGenerator[AgentResponse, None]:
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
"""Process a single step of the agent."""
...
@abc.abstractmethod
async def step_until_done(
self, max_step: int
) -> AsyncGenerator[AgentResponse, None]:
) -> T.AsyncGenerator[AgentResponse, None]:
"""Process steps until the agent is done."""
...

View File

@@ -1,24 +1,21 @@
import base64
import json
import sys
from collections.abc import AsyncGenerator
from typing import Any
import typing as T
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core import sp
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.response import AgentResponse, AgentResponseData
from astrbot.core.agent.run_context import ContextWrapper, TContext
from astrbot.core.agent.runners.base import AgentState, BaseAgentRunner
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.provider import Provider
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .coze_api_client import CozeAPIClient
if sys.version_info >= (3, 12):
@@ -33,45 +30,32 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
enforce_max_turns: int = -1,
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
truncate_turns: int = 1,
custom_token_counter: Any = None,
custom_compressor: Any = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
provider_config: dict | None = None,
**kwargs: Any,
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = streaming
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
provider_config = provider_config or {}
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空")
raise Exception("Coze API Key 不能为空")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空")
raise Exception("Coze Bot ID 不能为空")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://"),
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头",
"Coze API Base URL 格式不正确必须以 http:// 或 https:// 开头",
)
self.timeout = provider_config.get("timeout", 120)
@@ -86,7 +70,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
self.file_id_cache: dict[str, dict[str, str]] = {}
@override
async def step(self) -> AsyncGenerator[AgentResponse, None]:
async def step(self):
"""
执行 Coze Agent 的一个步骤
"""
@@ -99,7 +83,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
# 开始处理转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
@@ -107,15 +91,15 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
async for response in self._execute_coze_request():
yield response
except Exception as e:
logger.error(f"Coze 请求失败:{e!s}")
logger.error(f"Coze 请求失败{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"Coze 请求失败:{e!s}"
role="err", completion_text=f"Coze 请求失败{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Coze 请求失败:{e!s}")
chain=MessageChain().message(f"Coze 请求失败{str(e)}")
),
)
finally:
@@ -123,8 +107,8 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
@override
async def step_until_done(
self, max_step: int
) -> AsyncGenerator[AgentResponse, None]:
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
@@ -168,7 +152,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
# 处理上下文中的图片
content = ctx["content"]
if isinstance(content, list):
# 多模态内容,需要处理图片
# 多模态内容需要处理图片
processed_content = []
for item in content:
if isinstance(item, dict):
@@ -293,7 +277,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
accumulated_content += content
message_started = True
# 如果是流式响应,发送增量数据
# 如果是流式响应发送增量数据
if self.streaming:
yield AgentResponse(
type="streaming_delta",
@@ -344,7 +328,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
image_url: str,
session_id: str | None = None,
) -> str:
"""下载图片并上传到 Coze,返回 file_id"""
"""下载图片并上传到 Coze返回 file_id"""
import hashlib
# 计算哈希实现缓存
@@ -365,7 +349,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
if session_id:
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id

View File

@@ -66,7 +66,7 @@ class CozeAPIClient:
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
raise Exception("Coze API 认证失败请检查 API Key 是否正确")
response_text = await response.text()
logger.debug(
@@ -75,7 +75,7 @@ class CozeAPIClient:
if response.status != 200:
raise Exception(
f"文件上传失败,状态码: {response.status}, 响应: {response_text}",
f"文件上传失败状态码: {response.status}, 响应: {response_text}",
)
try:
@@ -87,7 +87,7 @@ class CozeAPIClient:
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
file_id = result["data"]["id"]
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
logger.debug(f"[Coze] 图片上传成功file_id: {file_id}")
return file_id
except asyncio.TimeoutError:
@@ -111,7 +111,7 @@ class CozeAPIClient:
try:
async with session.get(image_url) as response:
if response.status != 200:
raise Exception(f"下载图片失败,状态码: {response.status}")
raise Exception(f"下载图片失败状态码: {response.status}")
image_data = await response.read()
return image_data
@@ -145,7 +145,7 @@ class CozeAPIClient:
session = await self._ensure_session()
url = f"{self.api_base}/v3/chat"
payload: dict[str, Any] = {
payload = {
"bot_id": bot_id,
"user_id": user_id,
"stream": stream,
@@ -169,10 +169,10 @@ class CozeAPIClient:
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
raise Exception("Coze API 认证失败请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
raise Exception(f"Coze API 流式请求失败状态码: {response.status}")
# SSE
buffer = ""
@@ -226,10 +226,10 @@ class CozeAPIClient:
response_text = await response.text()
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
raise Exception("Coze API 认证失败请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
raise Exception(f"Coze API 请求失败状态码: {response.status}")
try:
return json.loads(response_text)
@@ -299,6 +299,7 @@ if __name__ == "__main__":
async with await anyio.open_file("README.md", "rb") as f:
file_data = await f.read()
file_id = await client.upload_file(file_data)
print(f"Uploaded file_id: {file_id}")
async for event in client.chat_messages(
bot_id=bot_id,
user_id="test_user",
@@ -317,7 +318,7 @@ if __name__ == "__main__":
],
stream=True,
):
pass
print(f"Event: {event}")
finally:
await client.close()

Some files were not shown because too many files have changed in this diff Show More