Compare commits

..

27 Commits
dev ... v4.26.3

Author SHA1 Message Date
Soulter
7831c68660 chore: bump version to 4.26.3 2026-06-30 20:10:58 +08:00
LIghtJUNction
6067a70803 feat: support local plugin install (#8448)
* feat: support local plugin install

* fix: make editable plugin install symlink

* fix: harden local plugin install

* Update tests/test_cli_plugin.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update astrbot/cli/commands/cmd_plug.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-06-29 01:44:06 +08:00
Weilong Liao
89ec07a92b style: standardize dashboard dialog styling (#9062) 2026-06-28 14:24:19 +08:00
lxfight
758e43273d fix: paginate knowledge base dashboard lists (#9055)
* fix: paginate knowledge base dashboard lists

* fix: preserve knowledge document search pagination
2026-06-28 14:00:45 +08:00
Weilong Liao
3d4c4ed01b fix: validate plugin install sources (#9061) 2026-06-28 13:45:59 +08:00
Gargantua
c1cc74b6bc fix: preserve fallback models for future tasks (#9054)
* fix: preserve fallback models for future tasks

* chore: remove local test markdown ignore

---------

Co-authored-by: Gargantua <22532097@zju.edu.cn>
2026-06-28 00:03:31 +08:00
Weilong Liao
a619988d2d chore: bump version to 4.26.2 (#9052) 2026-06-27 17:50:04 +08:00
Weilong Liao
de572e3fe0 fix: avoid duplicate send_message_to_user replies (#9051) 2026-06-27 16:57:17 +08:00
F. Abyssalis
d4fa9d3d5d fix: align OpenAI tool message sanitizer (#8350)
Co-authored-by: Soulter <905617992@qq.com>
2026-06-27 16:27:32 +08:00
Copilot
771911a893 fix: keep Tab navigation within reset-password inputs in account dialog (#9049)
* Initial plan

* fix: keep tab navigation on account dialog inputs

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-06-27 16:21:24 +08:00
Weilong Liao
8f30978c8d fix: separate plugin and tool activation state (#9048)
* fix: separate plugin and tool activation state

* style: apply ruff formatting

* fix: preserve manual tool deactivation on load

* fix: harden plugin tool state migration
2026-06-27 16:15:31 +08:00
Weilong Liao
dd2865d28c chore: remove plugin publish issue template (#9050) 2026-06-27 16:11:13 +08:00
Fiber
298078b536 Fix: KV storage not cleared on plugin uninstall. Improved cleanup logic and updated i18n strings to indicate database KV data removal. (#8291)
* fix: 完善插件卸载时的清理逻辑,新增KV数据清理,更新了多语言文案以说明会清理数据库KV数据

* fix: 修复插件关闭时不清理KV的问题,更新单元测试

* refactor: 统一插件ID生成逻辑

将插件ID生成逻辑抽离到StarMetadata类中,移除重复的代码实现,
同时在__post_init__中自动补全plugin_id字段。

* refactor: 将plugin_id属性从方法转换为属性,确保在属性赋值后正确计算
2026-06-27 16:06:14 +08:00
伊尔弥亚 - Irmia
3667487dd7 fix: DeepSeek V4 proxy model recognition — substring match instead of exact set match for reasoning_content injection (#9015)
* fix: DeepSeek V4 proxy model recognition — substring match for reasoning_content

* fix: remove deepseek-chat/reasoner exclusion per review feedback
2026-06-27 15:46:11 +08:00
renchonghan
3db778ff09 fix: prevent API 400 errors by ensuring assistant messages with reasoning_content but no content or tool_calls are preserved with a placeholder content value (#8483)
* fix: preserve assistant messages with reasoning_content in sanitize pass

When _sanitize_assistant_messages encounters an assistant message with empty content and no tool_calls but with reasoning_content, keep it with content set to empty string instead of dropping it. Reasoning models (DeepSeek V4, MiMo, etc.) require this history for subsequent turn validation.

* Update astrbot/core/provider/sources/openai_source.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* fix: default to empty TokenUsage when completion.usage is None

When completion.usage is None (e.g. certain proxy/streaming edge cases), llm_response.usage stayed unset (None). Plugins accessing .input_tokens on it would crash with AttributeError.

Always assign llm_response.usage — extract from completion.usage if present, otherwise fall back to a zeroed TokenUsage().

Closes #8605

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: renchonghan <renchonghan@users.noreply.github.com>
2026-06-27 15:26:09 +08:00
F. Abyssalis
8213c14cc6 fix: sanitize orphaned tool_result blocks in Anthropic provider (#8952)
* sanitize orphaned tool_result blocks in Anthropic provider

Sanitize tool_result blocks and merge consecutive messages with the same role to comply with Anthropic API requirements.

* Fix content handling in message merging logic

* fix: sanitize anthropic assistant messages

* fix: validate anthropic tool result ordering

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-06-27 15:18:33 +08:00
Florance
b5784abc55 fix: only show plugin update for newer market versions
Only mark installed plugins as updatable when the marketplace version is a known newer version, including pre-release-to-stable transitions. Conflict resolution also keeps per-plugin marketplace source lookup from current master.
2026-06-27 15:12:53 +08:00
FuShang114
b8f4c7d515 fix: handle MiMo STT audio and reasoning output (#8938)
* fix: handle MiMo STT audio and reasoning output

* fix: 移除 MiMo STT 的系统和用户提示词配置

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-06-27 15:04:41 +08:00
Weilong Liao
0ef790d289 fix: track plugin install source for updates (#9037)
* fix: track plugin install source for updates

* feat: refine plugin source management flow

* test: align plugin update expectation
2026-06-27 15:02:04 +08:00
fan
534ad0ccc7 fix: _KeyRotator index bounds check (#9040) 2026-06-27 13:56:26 +08:00
MUHAMED FAZAL PS
b5e29511ac fix: reliably kill shell process tree on Windows timeout (#8822)
* fix: reliably kill shell process tree on Windows timeout

Fixes #8809

* fix: remove redundant import and wrap taskkill in try/except

- Remove 'import subprocess as _sp' (subprocess already imported at top)
- Use subprocess.run directly with DEVNULL for stdout/stderr
- Wrap taskkill in try/except to avoid masking original TimeoutExpired
- If taskkill fails, cleanup failures don't prevent proc.wait() or re-raise

https://buymeacoffee.com/muhamedfazalps

* style: apply ruff formatting to local.py

* test: fix shell component tests to match Popen-based implementation

The tests were monkeypatching subprocess.run but the implementation
now uses subprocess.Popen + communicate() for timeout handling.
Updated tests to mock Popen instead.

Fixes CI Unit Tests failure

* fix: harden windows shell timeout cleanup

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-06-27 00:10:59 +08:00
Weilong Liao
d6738a03f3 fix: preserve jpeg quality during conversion (#9031) 2026-06-26 22:43:15 +08:00
FuShang114
6dd5e1e080 fix: prevent plugin detail marketplace mismatches (#9028) 2026-06-26 22:12:29 +08:00
NayukiChiba
c93bedf04d fix: unify handling of whitespace in streamed message segments and ensure trailing buffers are stripped before sending. (#9029)
* perf(streaming): 将流式消息分段发送的空白过滤逻辑前移至核心方法

- 在 AstrMessageEvent 的 process_buffer 中统一进行 strip 并跳过空段,移除各平台子类的重复处理
- aiocqhttp 平台回退尾部 buffer 直接使用核心逻辑过滤,避免自身再次判断空白
- 简化流式消息发送链路,提升代码可维护性与执行效率

* test(strip-stream): 添加流式消息分段空白过滤的单元测试

- 新增 CollectingMessageEvent 与 CollectingAiocqhttpMessageEvent 辅助类,模拟消息发送与断言
- 覆盖 AstrMessageEvent.process_buffer 对段落前导空白行的 strip 和空段跳过逻辑
- 覆盖 AiocqhttpMessageEvent.send_streaming 回退缓冲区的空白过滤,确保跨平台行为一致

* fix(streaming): 调整流式消息分段发送限速等待位置

- 将限速等待(sleep)从循环末尾移动至消息发送成功后立即执行
- 仅在成功发送文本片段时才进行等待,避免无实际发送时的无谓延迟
- 修复因等待时机不当可能导致的流式消息发送频率异常问题

* test(strip-stream): 添加空白分段时跳过睡眠的单元测试
- 新增测试用例,验证当分段为空时不调用睡眠函数
- 确保处理缓冲区时正确剥离空白字符
- 测试通过验证发送消息的行为

* Delete tests/test_streaming_segment_strip.py

---------

Co-authored-by: Weilong Liao <37870767+Soulter@users.noreply.github.com>
2026-06-26 22:11:31 +08:00
Lovely Moe Moli
110cc8736c fix: 'DashboardRequest' object has no attribute 'get_data' (#9021) (#9023) 2026-06-26 22:04:29 +08:00
Light
ce05ac0db6 docs: add spanish readme (#9020)
* docs(i18n): add Spanish (es) translation for README

* docs(i18n): add Spanish (es) translation for README

* fix(i18n): 修复了导航栏的分隔符与多余字符
2026-06-26 21:52:46 +08:00
Jia
9ae33e9344 docs: update Python requirement to 3.12 (#9022) 2026-06-26 19:43:50 +08:00
1087 changed files with 56449 additions and 111510 deletions

View File

@@ -1,184 +0,0 @@
# ==========================================
# AstrBot Instance Configuration: ${INSTANCE_NAME}
# AstrBot 实例配置文件:${INSTANCE_NAME}
# ==========================================
# 将此文件复制为 .env 并根据需要修改。
# Copy this file to .env and modify as needed.
# 注意:在此处设置的变量将覆盖默认配置。
# Note: Variables set here override application defaults.
# ------------------------------------------
# 实例标识 / Instance Identity
# ------------------------------------------
# 实例名称(用于日志和服务名)
# Instance name (used in logs/service names)
INSTANCE_NAME="${INSTANCE_NAME}"
# ------------------------------------------
# 核心配置 / Core Configuration
# ------------------------------------------
# AstrBot 根目录路径
# AstrBot root directory path
# 默认 Default: 当前工作目录,桌面客户端为 ~/.astrbot服务器为 /var/lib/astrbot/<instance>/
# 示例 Example: /var/lib/astrbot/mybot
ASTRBOT_ROOT="${ASTRBOT_ROOT}"
# 日志等级
# Log level
# 可选值 Values: DEBUG, INFO, WARNING, ERROR, CRITICAL
# 默认 Default: INFO
# ASTRBOT_LOG_LEVEL=INFO
# 启用插件热重载(开发时有用)
# Enable plugin hot reload (useful for development)
# 可选值 Values: 0 (禁用 disabled), 1 (启用 enabled)
# 默认 Default: 0
# ASTRBOT_RELOAD=0
# 禁用匿名使用统计
# Disable anonymous usage statistics
# 可选值 Values: 0 (启用统计 enabled), 1 (禁用统计 disabled)
# 默认 Default: 0
ASTRBOT_DISABLE_METRICS=0
# 覆盖 Python 可执行文件路径(用于本地代码执行功能)
# Override Python executable path (for local code execution)
# 示例 Example: /usr/bin/python3, /home/user/.pyenv/shims/python
# PYTHON=/usr/bin/python3
# 启用演示模式(可能限制部分功能)
# Enable demo mode (may restrict certain features)
# 可选值 Values: True, False
# 默认 Default: False
# DEMO_MODE=False
# 启用测试模式(影响日志和部分行为)
# Enable testing mode (affects logging and behavior)
# 可选值 Values: True, False
# 默认 Default: False
# TESTING=False
# 标记:是否通过桌面客户端执行(主要用于内部)
# Flag: running via desktop client (internal use)
# 可选值 Values: 0, 1
# ASTRBOT_DESKTOP_CLIENT=0
# 标记:是否通过 systemd 服务执行
# Flag: running via systemd service
# 可选值 Values: 0, 1
ASTRBOT_SYSTEMD=1
# ------------------------------------------
# 管理面板配置 / Dashboard Configuration
# ------------------------------------------
# 启用或禁用 WebUI 管理面板
# Enable or disable WebUI dashboard
# 可选值 Values: True, False
# 默认 Default: True
ASTRBOT_DASHBOARD_ENABLE=True
# 允许跨域请求的来源域名(多个用逗号分隔,允许所有则用 *
# Allowed CORS origins for WebUI dashboard (comma-separated, or * for all)
# 示例 Example: https://dash.astrbot.men
# 默认 Default: *
# ASTRBOT_CORS_ALLOW_ORIGIN="*"
# ------------------------------------------
# 国际化配置 / Internationalization Configuration
# ------------------------------------------
# CLI 界面语言
# CLI interface language
# 可选值 Values: zh (中文), en (英文)
# 默认 Default: zh (跟随系统 locale / follows system locale)
# ASTRBOT_CLI_LANG=zh
# ------------------------------------------
# 网络配置 / Network Configuration
# ------------------------------------------
# API 绑定主机
# API bind host
# 示例 Example: 0.0.0.0 (所有接口 all interfaces), 127.0.0.1 (仅本地 localhost only)
ASTRBOT_HOST="${ASTRBOT_HOST}"
# API 绑定端口
# API bind port
# 示例 Example: 3000, 6185, 8080
ASTRBOT_PORT="${ASTRBOT_PORT}"
# 是否为 API 启用 SSL/TLS
# Enable SSL/TLS for API
# 可选值 Values: true, false
# 默认 Default: false
ASTRBOT_SSL_ENABLE=false
# SSL 证书路径PEM 格式)
# SSL certificate path (PEM format)
# 示例 Example: /etc/astrbot/certs/myinstance/fullchain.pem
ASTRBOT_SSL_CERT=""
# SSL 私钥路径PEM 格式)
# SSL private key path (PEM format)
# 示例 Example: /etc/astrbot/certs/myinstance/privkey.pem
ASTRBOT_SSL_KEY=""
# SSL CA 证书链路径(可选,用于客户端验证)
# SSL CA certificates bundle (optional, for client verification)
# 示例 Example: /etc/ssl/certs/ca-certificates.crt
ASTRBOT_SSL_CA_CERTS=""
# ------------------------------------------
# 代理配置 / Proxy Configuration
# ------------------------------------------
# HTTP 代理地址
# HTTP proxy URL
# 示例 Example: http://127.0.0.1:7890, socks5://127.0.0.1:1080
# http_proxy=
# HTTPS 代理地址
# HTTPS proxy URL
# 示例 Example: http://127.0.0.1:7890, socks5://127.0.0.1:1080
# https_proxy=
# 不走代理的主机列表(逗号分隔)
# Hosts to bypass proxy (comma-separated)
# 示例 Example: localhost,127.0.0.1,192.168.0.0/16,.local
# no_proxy=localhost,127.0.0.1
# ------------------------------------------
# 第三方集成 / Third-party Integrations
# ------------------------------------------
# 阿里云 DashScope API 密钥(用于 Rerank 服务)
# Alibaba DashScope API Key (for Rerank service)
# 获取地址 Get from: https://dashscope.console.aliyun.com/
# 示例 Example: sk-xxxxxxxxxxxx
# DASHSCOPE_API_KEY=
# Coze 集成
# Coze integration
# 获取地址 Get from: https://www.coze.com/
# COZE_API_KEY=
# COZE_BOT_ID=
# 计算机控制相关的数据目录(用于截图/文件存储)
# Computer control data directory (for screenshots/file storage)
# 示例 Example: /var/lib/astrbot/bay_data
# BAY_DATA_DIR=
# ------------------------------------------
# 平台特定配置 / Platform-specific Configuration
# ------------------------------------------
# QQ 官方机器人测试模式开关
# QQ official bot test mode
# 可选值 Values: on, off
# 默认 Default: off
# TEST_MODE=off
# End of template / 模板结束

2
.envrc
View File

@@ -1,2 +0,0 @@
git pull
git status

View File

@@ -1,57 +0,0 @@
name: 🥳 发布插件
description: 提交插件到插件市场
title: "[Plugin] 插件名"
labels: ["plugin-publish"]
assignees: []
body:
- type: markdown
attributes:
value: |
欢迎发布插件到插件市场!
- type: markdown
attributes:
value: |
## 插件基本信息
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?可以从 [此站](https://plugins.astrbot.app) 右下角提交。
- type: textarea
id: plugin-info
attributes:
label: 插件信息
description: 请在下方代码块中填写您的插件信息确保反引号包裹了JSON
value: |
```json
{
"name": "插件名,请以 astrbot_plugin_ 开头",
"display_name": "用于展示的插件名,方便人类阅读",
"desc": "插件的简短介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": "",
}
```
validations:
required: true
- type: markdown
attributes:
value: |
## 检查
- type: checkboxes
id: checks
attributes:
label: 插件检查清单
description: 请确认以下所有项目
options:
- label: 我的插件经过完整的测试
required: true
- label: 我的插件不包含恶意代码
required: true
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true

View File

@@ -30,7 +30,6 @@ jobs:
working-directory: dashboard
run: |
pnpm install --frozen-lockfile
pnpm lint:check
pnpm run build
- name: Inject Commit SHA

View File

@@ -58,7 +58,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: "24.13.0"
node-version: '24.13.0'
cache: "pnpm"
cache-dependency-path: dashboard/pnpm-lock.yaml
@@ -156,6 +156,7 @@ jobs:
name: Dashboard-${{ steps.tag.outputs.tag }}
path: release-assets
- name: Resolve release notes
id: notes
shell: bash

View File

@@ -5,17 +5,30 @@ on:
branches:
- master
paths-ignore:
- "README*.md"
- "changelogs/**"
- "dashboard/**"
- 'README*.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
smoke-test:
name: Run smoke tests
runs-on: ubuntu-latest
name: Smoke test (${{ matrix.os }}, Python ${{ matrix.python-version }})
runs-on: ${{ matrix.os }}
timeout-minutes: 10
strategy:
fail-fast: false
matrix:
os:
- ubuntu-latest
- macos-latest
- windows-latest
python-version:
- '3.10'
- '3.11'
- '3.12'
- '3.13'
- '3.14'
steps:
- name: Checkout
@@ -26,36 +39,21 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: requirements.txt
- name: Install UV package manager
- name: Install uv
run: |
pip install uv
python -m pip install --upgrade pip
python -m pip install uv
- name: Install dependencies
run: |
uv sync
uv pip install --system -r requirements.txt
timeout-minutes: 15
- name: Run smoke tests
run: |
uv run main.py &
# uv tool install -e . --force
# astrbot init -y
# astrbot run --backend-only &
APP_PID=$!
echo "Waiting for application to start..."
for i in {1..60}; do
if curl -f http://localhost:6185 > /dev/null 2>&1; then
echo "Application started successfully!"
kill $APP_PID
exit 0
fi
sleep 1
done
echo "Application failed to start within 30 seconds"
kill $APP_PID 2>/dev/null || true
exit 1
python scripts/smoke_startup_check.py
timeout-minutes: 2

34
.gitignore vendored
View File

@@ -1,5 +1,6 @@
# Python related
__pycache__
.mypy_cache
.venv*
.conda/
uv.lock
@@ -8,6 +9,7 @@ uv.lock
# IDE and editors
.vscode
.idea
.zed/
# Logs and temporary files
botpy.log
@@ -50,46 +52,16 @@ astrbot.lock
chroma
venv/*
pytest.ini
AGENTS.md
IFLOW.md
CLAUDE.md
# genie_tts data
CharacterModels/
GenieData/
.agent/
.codex/
.claude/
.opencode/
.kilocode/
.serena
.worktrees/
.astrbot_sdk_testing/
.env
dashboard/warker.js
dashboard/bun.lock
.pua/
# Rust build artifacts
rust/target/
# Build outputs
dist/
*.whl
# 拓展模块
*.so
*.dll
# MDI font subset (generated by dashboard/scripts/subset-mdi-font.mjs)
dashboard/src/assets/mdi-subset/*.woff
dashboard/src/assets/mdi-subset/*.woff2
.planning
*cache
node_modules
*pinokio*
dashboard/pnpm-lock.yaml
.obsidian
dashboard/.codex
.codex
.zed/settings.json

View File

@@ -6,20 +6,20 @@ ci:
autoupdate_schedule: weekly
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.15.7
hooks:
# Run the linter.
- id: ruff-check
types_or: [python, pyi]
args: [--fix]
# Run the formatter.
- id: ruff-format
types_or: [python, pyi]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.2
hooks:
- id: pyupgrade
args: [--py312-plus]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.1
hooks:
# Run the linter.
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]

View File

@@ -1 +1 @@
3.12
3.12

296
AGENTS.md
View File

@@ -3,10 +3,8 @@
### Core
```
uv tool install -e . --force
astrbot init
astrbot run # start the bot
astrbot run --backend-only # start the backend only
uv sync
uv run main.py
```
Exposed an API server on `http://localhost:6185` by default.
@@ -15,8 +13,8 @@ Exposed an API server on `http://localhost:6185` by default.
```
cd dashboard
bun install # First time only.
bun dev
pnpm install # First time only. Use npm install -g pnpm if pnpm is not installed.
pnpm dev
```
Runs on `http://localhost:3000` by default.
@@ -43,215 +41,95 @@ ruff check .
## Dev environment tips
- **Main entry**: `astrbot/__main__.py` or via CLI `astrbot run`
- **CLI commands**: `astrbot/cli/commands/`
- **Core modules**: `astrbot/core/`
- **Platform adapters**: `astrbot/core/platform/sources/`
- **Star plugins**: `astrbot/builtin_stars/`
- **Dashboard**: `dashboard/` (Vue.js frontend)
### Basic
## File Organization
1. When modifying the WebUI, be sure to maintain componentization and clean code. Avoid duplicate code.
2. Do not add any report files such as xxx_SUMMARY.md.
3. After finishing, use `ruff format .` and `ruff check .` to format and check the code.
4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`.
5. Use English for all new comments.
6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory.
7. When backend API routes, request/response schemas, or OpenAPI definitions change, regenerate the frontend API client by running `cd dashboard && pnpm generate:api`.
8. When updating the project version, keep `[project].version` in `pyproject.toml` and `__version__` in `astrbot/__init__.py` in sync. `VERSION` in `astrbot/core/config/default.py` should derive from `astrbot.__version__` instead of hardcoding a separate version string.
9. When designing WebUI dialogs, use `text-h3 pa-4 pb-0 pl-6` as the base class for dialog titles, and use `variant="text"` or `variant="tonal"` for dialog buttons.
```
astrbot/
├── __main__.py # Main entry point
├── __init__.py # Package init, exports
├── cli/ # CLI commands
│ └── commands/ # Individual command modules
├── core/ # Core functionality
│ ├── agent/ # Agent execution
│ ├── platform/ # Platform adapters
│ ├── pipeline/ # Message processing
│ ├── star/ # Plugin system
│ └── config/ # Configuration
├── builtin_stars/ # Built-in plugins
├── dashboard/ # Vue.js frontend
└── utils/ # Utilities
### KISS and First Principles
Follow the KISS principle and reason from first principles during development. Start by identifying the real problem, required behavior, and smallest useful change before adding code. Do not pile on features, configuration switches, abstractions, dependencies, or compatibility layers unless they directly solve the current problem and have clear evidence of need.
Prefer the simplest implementation that is correct, maintainable, and consistent with the existing codebase. If a broader design seems attractive, reduce it to the essential behavior needed now and leave optional expansion for a later, explicit requirement.
### No Unnecessary Helpers
Prioritize inline implementation over abstraction. Avoid over-engineering and do not create helper functions unless absolutely necessary.
1. **Inline-First Rule**: If a logic block can be implemented directly within the main function without breaking overall readability, **do not** extract it into a new helper function.
2. **Strict Justification for Helpers**: You may only create a separate helper function if it meets at least one of these criteria:
- **High Reuse**: The exact same logic is repeated across **3 or more** different locations.
- **Extreme Complexity**: Inlining the logic makes the main function too long (e.g., >50 lines) or severely derails the main execution flow.
3. **No Fragmentation**: Do not split continuous linear logic (e.g., a single API call, simple form validation, or one-time data formatting) into tiny functions just for the sake of "clean code."
4. **Keep Context Compact**: Handle edge cases, error catching, and logging directly inside the main function block instead of offloading them.
5. **Refactoring Constraint**: When modifying existing code, do not alter the current function structure or extract code into new helpers unless the existing code already violates the complexity or reuse rules above.
### Mandatory Google-Style Docstrings
* **Comment the complex**: Add clear comments to any non-obvious function, method, or parameter.
* **Google Format**: All docstrings must strictly use the Google format (`Args:`, `Returns:`, `Raises:`).
#### Example:
```py
def calculate_metrics(user_id: int, force_refresh: bool = False) -> dict:
"""Brief description of the function.
Args:
user_id: Description of the ID.
force_refresh: Description of the flag.
Returns:
Description of the returned dict.
Raises:
ValueError: Description of when this occurs.
"""
# Inline implementation here...
```
## Architecture
### Core Components
## PR instructions
- `astrbot/core/` - Core bot functionality
- `astrbot/core/platform/` - Platform adapter system
- `astrbot/core/agent/` - Agent execution logic
- `astrbot/core/star/` - Plugin/Star handler system
- `astrbot/core/pipeline/` - Message processing pipeline
- `astrbot/cli/` - Command-line interface
### Important Utilities
```python
from astrbot.core.utils.astrbot_path import (
get_astrbot_root, # AstrBot root directory
get_astrbot_data_path, # Data directory
get_astrbot_config_path, # Config directory
get_astrbot_plugin_path, # Plugin directory
get_astrbot_temp_path, # Temp directory
get_astrbot_skills_path, # Skills directory
)
```
### Platform Adapters
Platform adapters are in `astrbot/core/platform/sources/`:
- Each adapter extends base platform classes
- Use `@register_platform_adapter` decorator
- Events flow through `commit_event()` to message queue
### Star (Plugin) System
Stars are plugins in `astrbot/builtin_stars/`:
- Extend `Star` base class
- Use decorators for command handlers: `@star.on_command`, `@star.on_message`, etc.
- Access via `context` object
## Code Style
1. **Type hints required** - Use Python 3.12+ syntax:
- `list[str]` not `List[str]`
- `int | None` not `Optional[int]`
- Avoid `Any` when possible. Use proper `TypedDict`, `dataclass`, or `Protocol` instead.
- When encountering dict access issues (e.g., `msg.get("key")` where type inference is wrong), define a `TypedDict` with `total=False` to explicitly declare allowed keys.
Good example:
```python
class MessageComponent(TypedDict, total=False):
type: str
text: str
path: str
```
Bad example (avoid):
```python
msg: Any = something
msg = cast(dict, msg)
```
2. **Path handling** - Always use `pathlib.Path`:
```python
from pathlib import Path
# Use astrbot.core.utils.path_utils for data/temp directories
from astrbot.core.utils.path_utils import get_astrbot_data_path
```
3. **Formatting** - Run before committing:
```bash
ruff format .
ruff check .
```
4. **Comments** - Use English for all comments and docstrings
5. **Imports** - Use absolute imports via `astrbot.` prefix
### Environment Variables
When adding new environment variables:
1. Use `ASTRBOT_` prefix: `ASTRBOT_ENABLE_FEATURE`
2. Add to `.env.example` with description
3. Update `astrbot/cli/commands/cmd_run.py`:
- Add to module docstring under "Environment Variables Used in Project"
- Add to `keys_to_print` list for debug output
## Testing
1. Tests go in `tests/` directory
2. Use `pytest` with `pytest-asyncio`
3. Run: `uv sync --group dev && uv run pytest --cov=astrbot tests/`
4. Test files: `test_*.py` or `*_test.py`
### Code Quality Scoring Test
The project enforces a **code quality score** via `tests/test_code_quality_typing.py`. All agents must treat this as a hard constraint when modifying code.
**Run the test:**
```bash
uv run pytest tests/test_code_quality_typing.py -v
```
**Scoring rules (target: 100/100, threshold for PASS: 80/100):**
| Pattern | Cost |
|---------|------|
| `cast(Any, ...)` | -1 pt each |
| `# type: ignore` | -0.5 pt each |
| **BAD** `# type: ignore[...]` (unresolved-import, class-alias, no-name-module, attr-defined, etc.) | **-3 pt each** |
| `bare except:` (no exception type) | -0.5 pt each |
| Duplicate code block (5+ identical lines, ≥2 occurrences) | -2 pt each |
**Why bad type: ignore is heavily penalized:**
- `# type: ignore[unresolved-import]` — hides missing module/stub issues
- `# type: ignore[class-alias]` — hides improper type alias patterns
- `# type: ignore[attr-defined]` — hides missing attribute errors
- These are **workarounds, not fixes** — they paper over real type errors
**Scoring formula:**
```
score = max(0, 100 - cast_any - type_ignore*0.5 - bad_type_ignore*3 - bare_except*0.5 - dup_blocks*2)
```
**Agent rules when modifying code:**
1. **Do not add** `# type: ignore[unresolved-import]` or `# type: ignore[class-alias]` — fix the underlying issue instead
2. **Do not use** `cast(Any, ...)` to suppress type errors — use proper type annotations
3. **Do not add** bare `except:` clauses — use `except SomeSpecificException:`
4. **Do not copy-paste** 5+ line blocks — extract to a shared helper function
5. Before committing, run the scoring test and ensure score ≥ 80
## Git Conventions
### Commit Messages
Use conventional commits:
```
feat: add new feature
fix: resolve bug
docs: update documentation
refactor: restructure code
test: add tests
chore: maintenance tasks
```
### PR Guidelines
1. Title: conventional commit format
2. Description: English
3. Target branch: `dev`
4. Keep changes focused and atomic
## Project-Specific Guidelines
1. **No report files** - Do not add `xxx_SUMMARY.md` or similar
2. **Componentization** - Maintain clean code, avoid duplication in WebUI
3. **Backward compatibility** - When deprecating, add warnings
4. **CLI help** - Run `astrbot help --all` to see all commands
5. When modifying frontend/dashboard code, use the project's custom request module `@/utils/request` for HTTP calls
6. For fetch or SSE URLs, use `resolveApiUrl('/api/your-path')` so the configured `VITE_API_BASE` and dev proxy rules are respected
7. Do not import the plain `axios` package directly in dashboard source files
## Common Tasks
### Adding a new platform adapter
1. Create adapter in `astrbot/core/platform/sources/`
2. Extend `Platform` base class
3. Use `@register_platform_adapter` decorator
4. Implement required methods: `run()`, `convert_message()`, `meta()`
### Adding a new command
1. Add to appropriate module in `cli/commands/`
2. Register with `@click.command()`
3. Update `astrbot/cli/__main__.py` to add command
### Adding a new Star handler
1. Create in `astrbot/builtin_stars/` or as plugin
2. Extend `Star` class
3. Use decorators: `@star.on_command()`, `@star.on_schedule()`, etc.
1. Title format: use conventional commit messages
2. Use English to write PR title and descriptions.
## Release versions
1. Replace current version name to specific version name.
2. Write changelog in `changelogs/`, you can refer to the full commit messages between the latest tag to the latest commit.
3. Make and push a commit into master branch with message format like: `chore: bump version to 4.25.0`
4. Create a tag and push the tag. For example: `git tag v4.25.0 && git push origin v4.25.0`
Use a short-lived `release/*` branch for each release. The release branch is the stabilization area for version bumps, changelog updates, release-blocking fixes, and final validation only. Do not add unrelated features or broad refactors to a release branch.
Prepare a release from a clean worktree with:
```bash
uv run python scripts/prepare_release.py 4.25.0
```
The script updates `pyproject.toml` and `astrbot/__init__.py`, creates `changelogs/v4.25.0.md`, runs the required Python checks, and prints the remaining steps. Use these flags when needed:
```bash
uv run python scripts/prepare_release.py 4.25.0 --generate-api-client
uv run python scripts/prepare_release.py 4.25.0 --dashboard-build
uv run python scripts/prepare_release.py 4.25.0 --commit --push
```
Open a PR from `release/4.25.0` to `master`. The PR title must use the conventional commit format, for example `chore: bump version to 4.25.0`. After the release PR is merged, create and push the tag from the updated `master` branch so the tag points to the exact code that was merged:
```bash
git checkout master
git pull --ff-only origin master
git tag v4.25.0
git push origin v4.25.0
```
For one-off release candidate branches, delete the release branch after the tag is pushed and verified. For maintained release lines, use a branch such as `release/4.25` and keep it until that line reaches EOL.
```bash
git branch -d release/4.25.0
git push origin --delete release/4.25.0
```

View File

@@ -1,4 +1,4 @@
![astrbot-github-banner-v2-light-0405_副本](https://github.com/user-attachments/assets/36fb04e4-cc75-4454-bd8b-049d11aa86f9)
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<div align="center">
@@ -7,7 +7,8 @@
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_es.md">Español</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<br>
@@ -20,7 +21,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>

289
README_es.md Normal file
View File

@@ -0,0 +1,289 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<div align="center">
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh.md">简体中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<br>
<div>
<a href="https://trendshift.io/repositories/21369" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21369" alt="AstrBotDevs%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&label=Marketplace&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
</div>
<br>
<a href="https://astrbot.app/">Documentación</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">Hoja de ruta</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Registro de incidencias</a>
<a href="mailto:community@astrbot.app">Soporte por correo</a>
</div>
AstrBot es una plataforma de chatbot Agent todo en uno de código abierto que se integra con las principales aplicaciones de mensajería instantánea. Proporciona una infraestructura de IA conversacional confiable y escalable para individuos, desarrolladores y equipos. Ya sea que estés construyendo un compañero de IA personal, un servicio de atención al cliente inteligente, un asistente de automatización o una base de conocimiento empresarial, AstrBot te permite crear rápidamente aplicaciones de IA listas para producción dentro de los flujos de trabajo de tu plataforma de mensajería instantánea.
![screenshot_1 5x_postspark_2026-02-27_22-37-45](https://github.com/user-attachments/assets/f17cdb90-52d7-4773-be2e-ff64b566af6b)
## Características principales
1. 💯 Gratis y de código abierto.
2. ✨ Conversaciones con LLM de IA, multimodal, Agent, MCP, habilidades, base de conocimiento, configuración de personalidad, compresión automática de contexto.
3. 🤖 Soporta integración con Dify, Alibaba Cloud Bailian, Coze y otras plataformas de Agent.
4. 🌐 Multiplataforma: QQ, WeChat Work, Feishu, DingTalk, cuentas oficiales de WeChat, Telegram, Slack y [más](#plataformas-de-mensajería-soportadas).
5. 📦 Extensiones mediante plugins con más de 1000 plugins disponibles para instalación en un clic.
6. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) para ejecución aislada y segura de código, llamadas a shell y reutilización de recursos a nivel de sesión.
7. 💻 Soporte de WebUI.
8. 🌈 Soporte de Web ChatUI con Agent Sandbox integrado y búsqueda web.
9. 🌐 Soporte de internacionalización (i18n).
<br>
<table align="center">
<tr align="center">
<th>💙 Juego de roles y compañía emocional</th>
<th>✨ Agent proactivo</th>
<th>🚀 Capacidades Agentic generales</th>
<th>🧩 Más de 1000 plugins de la comunidad</th>
</tr>
<tr>
<td align="center"><p align="center"><img width="984" height="1746" alt="99b587c5d35eea09d84f33e6cf6cfd4f" src="https://github.com/user-attachments/assets/89196061-3290-458d-b51f-afa178049f84" /></p></td>
<td align="center"><p align="center"><img width="976" height="1612" alt="c449acd838c41d0915cc08a3824025b1" src="https://github.com/user-attachments/assets/f75368b4-e022-41dc-a9e0-131c3e73e32e" /></p></td>
<td align="center"><p align="center"><img width="974" height="1732" alt="image" src="https://github.com/user-attachments/assets/e22a3968-87d7-4708-a7cd-e7f198c7c32e" /></p></td>
<td align="center"><p align="center"><img width="976" height="1734" alt="image" src="https://github.com/user-attachments/assets/0952b395-6b4a-432a-8a50-c294b7f89750" /></p></td>
</tr>
</table>
## Inicio rápido
### Despliegue en un clic
Para los usuarios que quieran experimentar AstrBot rápidamente, estén familiarizados con el uso de la línea de comandos y puedan instalar un entorno `uv` por su cuenta, recomendamos el método de despliegue en un clic con `uv` ⚡️:
```bash
uv tool install astrbot --python 3.12
astrbot init # Ejecuta este comando solo la primera vez para inicializar el entorno
astrbot run
```
> Requiere tener [uv](https://docs.astral.sh/uv/) instalado.
> AstrBot requiere Python 3.12 o superior. La opción `--python 3.12` asegura que `uv` cree el entorno de la herramienta con Python 3.12.
> [!NOTE]
> Para usuarios de macOS: debido a las comprobaciones de seguridad de macOS, la primera ejecución del comando `astrbot` puede tardar más (aproximadamente 10-20s).
Actualizar `astrbot`:
```bash
uv tool upgrade astrbot --python 3.12
```
> [!WARNING]
> AstrBot desplegado mediante `uv` **no soporta la actualización a través de la WebUI**. Para actualizar, ejecuta el comando anterior desde la línea de comandos.
### Despliegue con Docker
Para usuarios familiarizados con contenedores y que buscan un método de despliegue más estable y listo para producción, recomendamos desplegar AstrBot con Docker / Docker Compose.
Consulta la documentación oficial: [Desplegar AstrBot con Docker](https://docs.astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Desplegar en RainYun
Para usuarios que desean un despliegue en un clic y no quieren administrar servidores por sí mismos, recomendamos el servicio de despliegue en la nube en un clic de RainYun ☁️:
[![Desplegar en RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### Despliegue como aplicación de escritorio
Para usuarios que quieran usar AstrBot en el escritorio y principalmente usen ChatUI, recomendamos AstrBot App.
Visita [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) para descargar e instalar; este método está diseñado para uso en escritorio y no se recomienda para escenarios de servidor.
### Despliegue con Launcher
Para usuarios de escritorio que también desean un despliegue rápido y uso aislado de múltiples instancias, recomendamos AstrBot Launcher.
Visita [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) para descargar e instalar.
### Desplegar en Replit
El despliegue en Replit es mantenido por la comunidad y es adecuado para demostraciones en línea y pruebas ligeras.
[![Ejecutar en Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR
El despliegue mediante AUR está dirigido a usuarios de Arch Linux que prefieren instalar AstrBot a través del flujo de trabajo de paquetes del sistema.
Ejecuta el siguiente comando para instalar `astrbot-git`, luego inicia AstrBot en tu entorno local.
```bash
yay -S astrbot-git
```
**Más métodos de despliegue**
Si necesitas gestión basada en panel o una personalización más profunda, consulta [Despliegue con BT-Panel](https://docs.astrbot.app/deploy/astrbot/btpanel.html) para la configuración desde la tienda de aplicaciones de BT Panel, [Despliegue con 1Panel](https://docs.astrbot.app/deploy/astrbot/1panel.html) para el despliegue desde el mercado de aplicaciones de 1Panel, [Despliegue con CasaOS](https://docs.astrbot.app/deploy/astrbot/casaos.html) para despliegue visual en NAS/servidor doméstico, y [Despliegue manual](https://docs.astrbot.app/deploy/astrbot/cli.html) para una instalación completamente personalizada desde el código fuente con `uv`.
## Plataformas de mensajería soportadas
Conecta AstrBot a tu plataforma de chat favorita.
| Plataforma | Mantenedor |
|---------|---------------|
| QQ | Oficial |
| Implementación del protocolo OneBot v11 | Oficial |
| Telegram | Oficial |
| Wecom y Wecom AI Bot | Oficial |
| Cuentas oficiales de WeChat | Oficial |
| Feishu (Lark) | Oficial |
| DingTalk | Oficial |
| Slack | Oficial |
| Discord | Oficial |
| LINE | Oficial |
| Satori | Oficial |
| KOOK | Oficial |
| Misskey | Oficial |
| Mattermost | Oficial |
| WhatsApp (Próximamente) | Oficial |
| [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) | Comunidad |
| [Rocket.Chat](https://github.com/NET-Homeless/astrbot_plugin_rocket_chat_adapter) | Comunidad |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | Comunidad |
## Servicios de modelo soportados
| Servicio | Tipo |
|---------|---------------|
| OpenAI y servicios compatibles | Servicios LLM |
| Anthropic | Servicios LLM |
| Google Gemini | Servicios LLM |
| Moonshot AI | Servicios LLM |
| Zhipu AI | Servicios LLM |
| DeepSeek | Servicios LLM |
| Ollama (Autoalojado) | Servicios LLM |
| LM Studio (Autoalojado) | Servicios LLM |
| [AIHubMix](https://aihubmix.com/?aff=4bfH) | Servicios LLM (API Gateway, soporta todos los modelos) |
| [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | Servicios LLM |
| [302.AI](https://share.302.ai/rr1M3l) | Servicios LLM |
| [TokenPony](https://www.tokenpony.cn/3YPyf) | Servicios LLM |
| [SiliconFlow](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot) | Servicios LLM |
| [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE) | Servicios LLM |
| ModelScope | Servicios LLM |
| OneAPI | Servicios LLM |
| Dify | Plataformas LLMOps |
| Aplicaciones de Alibaba Cloud Bailian | Plataformas LLMOps |
| Coze | Plataformas LLMOps |
| OpenAI Whisper | Servicios de voz a texto |
| SenseVoice | Servicios de voz a texto |
| Xiaomi MiMo Omni | Servicios de voz a texto |
| OpenAI TTS | Servicios de texto a voz |
| Gemini TTS | Servicios de texto a voz |
| GPT-Sovits-Inference | Servicios de texto a voz |
| GPT-Sovits | Servicios de texto a voz |
| FishAudio | Servicios de texto a voz |
| Edge TTS | Servicios de texto a voz |
| Alibaba Cloud Bailian TTS | Servicios de texto a voz |
| Azure TTS | Servicios de texto a voz |
| Minimax TTS | Servicios de texto a voz |
| Xiaomi MiMo TTS | Servicios de texto a voz |
| Volcano Engine TTS | Servicios de texto a voz |
## ❤️ Patrocinadores
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ Contribuir
¡Issues y Pull Requests son siempre bienvenidos! No dudes en enviar tus cambios a este proyecto :)
### Cómo contribuir
Puedes contribuir revisando issues o ayudando con la revisión de pull requests. Cualquier issue o PR es bienvenido para fomentar la participación de la comunidad. Por supuesto, estas son solo sugerencias: puedes contribuir de la manera que prefieras. Para agregar nuevas funcionalidades, por favor discútelo primero a través de un Issue.
### Entorno de desarrollo
AstrBot usa `ruff` para el formateo y linting de código.
```bash
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```
## 🌍 Comunidad
### Grupos de QQ
- Grupo 1: 322154837 (Lleno)
- Grupo 3: 630166526 (Lleno)
- Grupo 4: 1077826412 (Lleno)
- Grupo 5: 822130018 (Lleno)
- Grupo 6: 753075035 (Lleno)
- Grupo 7: 743746109 (Lleno)
- Grupo 8: 1030353265 (Lleno)
- Grupo 9: 1076659624 (Lleno)
- Grupo 10: 1078079676 (Lleno)
- Grupo 11: 704659519 (Lleno)
- Grupo 12: 916228568 (Lleno)
- Grupo 13: 1092185289
- Grupo 14: 1103419483
- Grupo de desarrolladores (Charla): 975206796
- Grupo de desarrolladores (Formal): 1039761811
### Servidor de Discord
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ❤️ Agradecimientos especiales
Un agradecimiento especial a todos los contribuidores y desarrolladores de plugins por sus contribuciones a AstrBot ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot&max=300&columns=15" />
</a>
Además, el nacimiento de este proyecto no habría sido posible sin la ayuda de los siguientes proyectos de código abierto:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - El increíble framework felino
## ⭐ Historial de estrellas
> [!TIP]
> Si este proyecto te ha ayudado en tu vida o trabajo, o si estás interesado en su desarrollo futuro, por favor dale una estrella al proyecto. Es la fuerza impulsora detrás del mantenimiento de este proyecto de código abierto <3
<div align="center">
[![Gráfico de historial de estrellas](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
</div>
<div align="center">
_La compañía y la capacidad nunca deberían estar en conflicto. Lo que aspiramos a crear es un robot que pueda entender emociones, proporcionar compañía genuina y realizar tareas de manera confiable._
_私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>

View File

@@ -6,6 +6,7 @@
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_es.md">Español</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<br>
@@ -19,7 +20,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>

View File

@@ -6,6 +6,7 @@
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_es.md">Español</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<br>
@@ -19,7 +20,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0LjYxNTZDNS4zMTUwMiAxNC4zOTk5IDUuNjAxNTYgMTQuMTEzNCA1LjYwMTU2IDEzLjc1OTlWMTEuMDM5OUM1LjYwMTU2IDEwLjY4NjQgNS4zMTUwMiAxMC4zOTk5IDQuOTYxNTYgMTAuMzk5OVoiIGZpbGw9IiNmZmYiLz4KPHBhdGggZD0iTTEzLjc1ODQgMS42MDAxSDExLjAzODRDMTAuNjg1IDEuNjAwMSAxMC4zOTg0IDEuODg2NjQgMTAuMzk4NCAyLjI0MDFWNC45NjAxQzEwLjM5ODQgNS4zMTM1NiAxMC42ODUgNS42MDAxIDExLjAzODQgNS42MDAxSDEzLjc1ODRDMTQuMTExOSA1LjYwMDEgMTQuMzk4NCA1LjMxMzU2IDE0LjM5ODQgNC45NjAxVjIuMjQwMUMxNC4zOTg0IDEuODg2NjQgMTQuMTExOSAxLjYwMDEgMTMuNzU4NCAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDRMNCAxMlpFIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>

View File

@@ -6,7 +6,8 @@
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_es.md">Español</a>
<br>
@@ -19,7 +20,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFZIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjczODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>

View File

@@ -6,6 +6,7 @@
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_es.md">Español</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<br>
@@ -19,7 +20,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>

View File

@@ -6,6 +6,7 @@
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_es.md">Español</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<div>
@@ -17,7 +18,7 @@
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
@@ -78,10 +79,7 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、
```bash
uv tool install astrbot --python 3.12
astrbot init # 仅首次执行此命令以初始化环境
astrbot run # astrbot run --backend-only 仅启动后端服务
# 安装开发版本(更多修复,新功能,但不够稳定,适合开发者)
uv tool install git+https://github.com/AstrBotDevs/AstrBot@dev
astrbot run
```
> 需要安装 [uv](https://docs.astral.sh/uv/)。
@@ -207,25 +205,13 @@ yay -S astrbot-git
| Xiaomi MiMo TTS | 文本转语音 |
| 火山引擎 TTS | 文本转语音 |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 :)
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
### 如何贡献
你可以通过查看问题或帮助审核 PR拉取请求来贡献。任何问题或 PR 都欢迎参与,以促进社区贡献。当然,这些只是建议,你可以以任何方式进行贡献。对于新功能的添加,请先通过 Issue 讨论。
建议将功能性PR合并至dev分支将在测试修改后合并到主分支并发布新版本。
为了减少冲突,建议如下:
1. 工作分支最好基于 `dev` 分支创建,避免直接在 `main` 分支上工作。
2. 提交 PR 时,选择 `dev` 分支作为目标分支。
3. 定期同步 `dev` 分支到本地多使用git pull。
### 开发环境
@@ -233,23 +219,11 @@ AstrBot 使用 `ruff` 进行代码格式化和检查。
```bash
git clone https://github.com/AstrBotDevs/AstrBot
git switch dev # 切换到开发分支
pip install pre-commit # 或者uv tool install pre-commit
pip install pre-commit
pre-commit install
```
推荐使用uv本地安装进行测试
```bash
uv tool install -e . --force
astrbot init
astrbot run
```
调试前端
```bash
astrbot run --backend-only
cd dashboard
bun install # 或者pnpm 等
bun dev
```
## 🌍 社区
### QQ 群组

View File

@@ -1,30 +1,4 @@
from __future__ import annotations
import logging
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _pkg_version
from typing import TYPE_CHECKING, Any
try:
__version__ = _pkg_version("astrbot")
except PackageNotFoundError:
__version__ = "4.26.1"
if TYPE_CHECKING:
from .core import logger as logger
__all__ = ["logger"]
def __getattr__(name: str) -> Any:
if name == "cli":
from astrbot.cli.__main__ import cli
return cli()
if name == "logger":
from .core import logger
return logger
raise AttributeError(name)
__version__ = "4.26.3"
logger = logging.getLogger("astrbot")

View File

@@ -1,147 +0,0 @@
import argparse
import asyncio
import mimetypes
import os
import sys
from pathlib import Path
import anyio
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.config.default import VERSION
from astrbot.core.initial_loader import InitialLoader
from astrbot.core.utils.astrbot_path import (
get_astrbot_config_path,
get_astrbot_data_path,
get_astrbot_knowledge_base_path,
get_astrbot_plugin_path,
get_astrbot_root,
get_astrbot_site_packages_path,
get_astrbot_skills_path,
get_astrbot_temp_path,
)
from astrbot.core.utils.io import (
download_dashboard,
get_dashboard_version,
)
# 将父目录添加到 sys.path
sys.path.append(Path(__file__).parent.as_posix())
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
def check_env() -> None:
# Python version check: require 3.12 or 3.13
if not (sys.version_info.major == 3 and sys.version_info.minor in (12, 13)):
sys.exit(1)
astrbot_root = get_astrbot_root()
if astrbot_root not in sys.path:
sys.path.insert(0, astrbot_root)
site_packages_path = get_astrbot_site_packages_path()
if site_packages_path not in sys.path:
sys.path.insert(0, site_packages_path)
os.makedirs(get_astrbot_config_path(), exist_ok=True)
os.makedirs(get_astrbot_plugin_path(), exist_ok=True)
os.makedirs(get_astrbot_temp_path(), exist_ok=True)
os.makedirs(get_astrbot_knowledge_base_path(), exist_ok=True)
os.makedirs(get_astrbot_skills_path(), exist_ok=True)
os.makedirs(site_packages_path, exist_ok=True)
# 针对问题 #181 的临时解决方案
mimetypes.add_type("text/javascript", ".js")
mimetypes.add_type("text/javascript", ".mjs")
mimetypes.add_type("application/json", ".json")
async def check_dashboard_files(webui_dir: str | None = None):
"""下载管理面板文件"""
# 指定webui目录
if webui_dir:
if await anyio.Path(webui_dir).exists():
logger.info(f"使用指定的 WebUI 目录: {webui_dir}")
return webui_dir
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。")
data_dist_path = os.path.join(get_astrbot_data_path(), "dist")
if await anyio.Path(data_dist_path).exists():
v = await get_dashboard_version()
if v is not None:
# 存在文件
if v == f"v{VERSION}":
logger.info("WebUI 版本已是最新。")
else:
logger.warning(
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。",
)
return data_dist_path
logger.info(
"开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。",
)
try:
await download_dashboard(version=f"v{VERSION}", latest=False)
except Exception as e:
logger.warning(
f"下载指定版本(v{VERSION})的管理面板文件失败: {e},尝试下载最新版本。",
)
try:
await download_dashboard(latest=True)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成。")
return data_dist_path
async def main_async(webui_dir_arg: str | None, log_broker: LogBroker) -> None:
"""主异步入口"""
# 检查仪表板文件
webui_dir = await check_dashboard_files(webui_dir_arg)
if webui_dir is None:
logger.warning(
"管理面板文件检查失败,WebUI 功能将不可用。"
"请检查网络连接或手动指定 --webui-dir 参数。",
)
db = db_helper
# 打印 logo
logger.info(logo_tmpl)
core_lifecycle = InitialLoader(db, log_broker)
core_lifecycle.webui_dir = webui_dir
await core_lifecycle.start()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="AstrBot")
parser.add_argument(
"--webui-dir",
type=str,
help="指定 WebUI 静态文件目录路径",
default=None,
)
args = parser.parse_args()
check_env()
# 启动日志代理
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
# 只使用一次 asyncio.run()
asyncio.run(main_async(args.webui_dir, log_broker))

View File

@@ -30,7 +30,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
)
from astrbot.core.star.register import (
register_star as register, # 注册插件(Star)
register_star as register, # 注册插件Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star.config import *

View File

@@ -59,14 +59,14 @@ __all__ = [
"on_decorating_result",
"on_llm_request",
"on_llm_response",
"on_llm_tool_respond",
"on_platform_loaded",
"on_plugin_error",
"on_plugin_loaded",
"on_plugin_unloaded",
"on_using_llm_tool",
"on_platform_loaded",
"on_waiting_llm_request",
"permission_type",
"platform_adapter_type",
"regex",
"on_using_llm_tool",
"on_llm_tool_respond",
]

View File

@@ -1,7 +1,7 @@
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
from astrbot.core.star.register import (
register_star as register, # 注册插件(Star)
register_star as register, # 注册插件Star
)
__all__ = ["Context", "Star", "StarTools", "register"]

View File

@@ -1,29 +1,19 @@
# Commands module
from .admin import AdminCommands
from .alter_cmd import AlterCmdCommands
from .conversation import ConversationCommands
from .help import HelpCommand
from .llm import LLMCommands
from .name import NameCommand
from .plugin import PluginCommands
from .provider import ProviderCommands
from .setunset import SetUnsetCommands
from .sid import SIDCommand
from .t2i import T2ICommand
from .tts import TTSCommand
__all__ = [
"AdminCommands",
"AlterCmdCommands",
"ConversationCommands",
"HelpCommand",
"LLMCommands",
"NameCommand",
"PluginCommands",
"ProviderCommands",
"SIDCommand",
"SetUnsetCommands",
"T2ICommand",
"TTSCommand",
"SIDCommand",
]

View File

@@ -1,5 +1,5 @@
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageChain, MessageEventResult
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard
@@ -8,70 +8,8 @@ class AdminCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员。op <admin_id>"""
if not admin_id:
event.set_result(
MessageEventResult().message(
"使用方法: /op <id> 授权管理员;/deop <id> 取消管理员。可通过 /sid 获取 ID。",
),
)
return
self.context.get_config()["admins_id"].append(str(admin_id))
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("授权成功。"))
async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""取消授权管理员。deop <admin_id>"""
if not admin_id:
event.set_result(
MessageEventResult().message(
"使用方法: /deop <id> 取消管理员。可通过 /sid 获取 ID。",
),
)
return
try:
self.context.get_config()["admins_id"].remove(str(admin_id))
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("取消授权成功。"))
except ValueError:
event.set_result(
MessageEventResult().message("此用户 ID 不在管理员名单内。"),
)
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单。wl <sid>"""
if not sid:
event.set_result(
MessageEventResult().message(
"使用方法: /wl <id> 添加白名单;/dwl <id> 删除白名单。可通过 /sid 获取 ID。",
),
)
return
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg["platform_settings"]["id_whitelist"].append(str(sid))
cfg.save_config()
event.set_result(MessageEventResult().message("添加白名单成功。"))
async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""删除白名单。dwl <sid>"""
if not sid:
event.set_result(
MessageEventResult().message(
"使用方法: /dwl <id> 删除白名单。可通过 /sid 获取 ID。",
),
)
return
try:
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg["platform_settings"]["id_whitelist"].remove(str(sid))
cfg.save_config()
event.set_result(MessageEventResult().message("删除白名单成功。"))
except ValueError:
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""更新管理面板"""
await event.send(MessageChain().message("正在尝试更新管理面板..."))
await event.send(MessageChain().message("⏳ Updating dashboard..."))
await download_dashboard(version=f"v{VERSION}", latest=False)
await event.send(MessageChain().message("管理面板更新完成。"))
await event.send(MessageChain().message("✅ Dashboard updated successfully."))

View File

@@ -1,187 +0,0 @@
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
from astrbot.core.utils.command_parser import CommandParserMixin
from .utils.rst_scene import RstScene
class AlterCmdCommands(CommandParserMixin):
def __init__(self, context: star.Context) -> None:
self.context = context
async def update_reset_permission(self, scene_key: str, perm_type: str) -> None:
"""更新reset命令在特定场景下的权限设置"""
from astrbot.api import sp
alter_cmd_cfg: dict[str, dict[str, dict[str, str]]] = (
await sp.global_get("alter_cmd", {}) or {}
)
plugin_cfg = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_cfg.get("reset", {})
reset_cfg[scene_key] = perm_type
plugin_cfg["reset"] = reset_cfg
alter_cmd_cfg["astrbot"] = plugin_cfg
await sp.global_put("alter_cmd", alter_cmd_cfg)
async def alter_cmd(self, event: AstrMessageEvent) -> None:
token = self.parse_commands(event.message_str)
if token.len < 3:
await event.send(
MessageChain().message(
"该指令用于设置指令或指令组的权限。\n"
"格式: /alter_cmd <cmd_name> <admin/member>\n"
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
"/alter_cmd reset config 打开 reset 权限配置",
),
)
return
# 兼容 reset scene 的专门配置
cmd_name = token.get(1)
cmd_type = token.get(2)
if cmd_name == "reset" and cmd_type == "config":
from astrbot.api import sp
alter_cmd_cfg: dict[str, dict[str, dict[str, str]]] = (
await sp.global_get("alter_cmd", {}) or {}
)
plugin_ = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_.get("reset", {})
group_unique_on = reset_cfg.get("group_unique_on", "admin")
group_unique_off = reset_cfg.get("group_unique_off", "admin")
private = reset_cfg.get("private", "member")
config_menu = f"""reset命令权限细粒度配置
当前配置:
1. 群聊+会话隔离开: {group_unique_on}
2. 群聊+会话隔离关: {group_unique_off}
3. 私聊: {private}
修改指令格式:
/alter_cmd reset scene <场景编号> <admin/member>
例如: /alter_cmd reset scene 2 member"""
await event.send(MessageChain().message(config_menu))
return
if cmd_name == "reset" and cmd_type == "scene" and token.len >= 4:
scene_num = token.get(3)
perm_type = token.get(4)
if scene_num is None or perm_type is None:
await event.send(MessageChain().message("场景编号和权限类型不能为空"))
return
if not scene_num.isdigit() or int(scene_num) < 1 or int(scene_num) > 3:
await event.send(
MessageChain().message("场景编号必须是 1-3 之间的数字"),
)
return
if perm_type not in ["admin", "member"]:
await event.send(
MessageChain().message("权限类型错误,只能是 admin 或 member"),
)
return
scene_index = int(scene_num)
scene = RstScene.from_index(scene_index)
scene_key = scene.key
await self.update_reset_permission(scene_key, perm_type)
await event.send(
MessageChain().message(
f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}",
),
)
return
if cmd_type not in ["admin", "member"]:
await event.send(
MessageChain().message("指令类型错误,可选类型有 admin, member"),
)
return
# 查找指令
cmd_name = " ".join(token.tokens[1:-1])
permission_type = token.get(-1)
if permission_type not in ["admin", "member"]:
await event.send(
MessageChain().message("指令类型错误,可选类型有 admin, member"),
)
return
found_command = None
cmd_group = False
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
if filter_.equals(cmd_name):
found_command = handler
break
elif isinstance(filter_, CommandGroupFilter):
if filter_.equals(cmd_name):
found_command = handler
cmd_group = True
break
if not found_command:
await event.send(MessageChain().message("未找到该指令"))
return
found_plugin = star_map[found_command.handler_module_path]
from astrbot.api import sp
stored_alter_cmd_cfg: dict[str, dict[str, dict[str, str]]] = (
await sp.global_get("alter_cmd", {}) or {}
)
if found_plugin.name is None:
await event.send(MessageChain().message("未找到指令对应的插件名称"))
return
plugin_ = stored_alter_cmd_cfg.get(found_plugin.name, {})
cfg = plugin_.get(found_command.handler_name, {})
cfg["permission"] = permission_type
plugin_[found_command.handler_name] = cfg
stored_alter_cmd_cfg[found_plugin.name] = plugin_
await sp.global_put("alter_cmd", stored_alter_cmd_cfg)
# 注入权限过滤器
found_permission_filter = False
for filter_ in found_command.event_filters:
if isinstance(filter_, PermissionTypeFilter):
if permission_type == "admin":
from astrbot.api.event import filter
filter_.permission_type = filter.PermissionType.ADMIN
else:
from astrbot.api.event import filter
filter_.permission_type = filter.PermissionType.MEMBER
found_permission_filter = True
break
if not found_permission_filter:
from astrbot.api.event import filter
found_command.event_filters.insert(
0,
PermissionTypeFilter(
filter.PermissionType.ADMIN
if permission_type == "admin"
else filter.PermissionType.MEMBER,
),
)
cmd_group_str = "指令组" if cmd_group else "指令"
await event.send(
MessageChain().message(
f"已将「{cmd_name}{cmd_group_str} 的权限级别调整为 {permission_type}",
),
)

View File

@@ -23,7 +23,9 @@ class HelpCommand:
return ""
async def _build_reserved_command_lines(self) -> list[str]:
"""使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。"""
"""
使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。
"""
try:
commands = await command_management.list_commands()
except BaseException:

View File

@@ -1,20 +0,0 @@
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageChain
class LLMCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
async def llm(self, event: AstrMessageEvent) -> None:
"""开启/关闭 LLM"""
cfg = self.context.get_config(umo=event.unified_msg_origin)
enable = cfg["provider_settings"].get("enable", True)
if enable:
cfg["provider_settings"]["enable"] = False
status = "关闭"
else:
cfg["provider_settings"]["enable"] = True
status = "开启"
cfg.save_config()
await event.send(MessageChain().message(f"{status} LLM 聊天功能。"))

View File

@@ -1,214 +0,0 @@
import builtins
from typing import TYPE_CHECKING
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
if TYPE_CHECKING:
from astrbot.core.db.po import Persona
class PersonaCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
def _build_tree_output(
self,
folder_tree: list[dict],
all_personas: list["Persona"],
depth: int = 0,
) -> list[str]:
"""递归构建树状输出,使用短线条表示层级"""
lines: list[str] = []
# 使用短线条作为缩进前缀,每层只用 "" 加一个空格
prefix = " " * depth
for folder in folder_tree:
# 输出文件夹
lines.append(f"{prefix}├ 📁 {folder['name']}/")
# 获取该文件夹下的人格
folder_personas = [
p for p in all_personas if p.folder_id == folder["folder_id"]
]
child_prefix = " " * (depth + 1)
# 输出该文件夹下的人格
for persona in folder_personas:
lines.append(f"{child_prefix}├ 👤 {persona.persona_id}")
# 递归处理子文件夹
children = folder.get("children", [])
if children:
lines.extend(
self._build_tree_output(
children,
all_personas,
depth + 1,
),
)
return lines
async def persona(self, message: AstrMessageEvent) -> None:
parts = message.message_str.split(" ")
umo = message.unified_msg_origin
curr_persona_name = ""
cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
default_persona = await self.context.persona_manager.get_default_persona_v3(
umo=umo,
)
force_applied_persona_id = None
curr_cid_title = ""
if cid:
conv = await self.context.conversation_manager.get_conversation(
unified_msg_origin=umo,
conversation_id=cid,
create_if_not_exists=True,
)
if conv is None:
message.set_result(
MessageEventResult().message(
"当前对话不存在,请先使用 /new 新建一个对话。",
),
)
return
provider_settings = self.context.get_config(umo=umo).get(
"provider_settings",
{},
)
(
persona_id,
_,
force_applied_persona_id,
_,
) = await self.context.persona_manager.resolve_selected_persona(
umo=umo,
conversation_persona_id=conv.persona_id,
platform_name=message.get_platform_name(),
provider_settings=provider_settings,
)
if persona_id == "[%None]":
curr_persona_name = ""
elif persona_id:
curr_persona_name = persona_id
if force_applied_persona_id:
curr_persona_name = f"{curr_persona_name} (自定义规则)"
curr_cid_title = conv.title or "新对话"
curr_cid_title += f"({cid[:4]})"
if len(parts) == 1:
message.set_result(
MessageEventResult()
.message(
f"""[Persona]
- 人格情景列表: `/persona list`
- 设置人格情景: `/persona 人格`
- 人格情景详细信息: `/persona view 人格`
- 取消人格: `/persona unset`
默认人格情景: {default_persona["name"]}
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
配置人格情景请前往管理面板-配置页
""",
)
.use_t2i(False),
)
elif parts[1] == "list":
# 获取文件夹树和所有人格
folder_tree = await self.context.persona_manager.get_folder_tree()
all_personas = self.context.persona_manager.personas
lines = ["📂 人格列表:\n"]
# 构建树状输出
tree_lines = self._build_tree_output(folder_tree, all_personas)
lines.extend(tree_lines)
# 输出根目录下的人格(没有文件夹的)
root_personas = [p for p in all_personas if p.folder_id is None]
if root_personas:
if tree_lines: # 如果有文件夹内容,加个空行
lines.append("")
for persona in root_personas:
lines.append(f"👤 {persona.persona_id}")
# 统计信息
total_count = len(all_personas)
lines.append(f"\n{total_count} 个人格")
lines.append("\n*使用 `/persona <人格名>` 设置人格")
lines.append("*使用 `/persona view <人格名>` 查看详细信息")
msg = "\n".join(lines)
message.set_result(MessageEventResult().message(msg).use_t2i(False))
elif parts[1] == "view":
if len(parts) == 2:
message.set_result(MessageEventResult().message("请输入人格情景名"))
return
ps = parts[2].strip()
if persona_info := next(
builtins.filter(
lambda persona: persona["name"] == ps,
self.context.provider_manager.personas,
),
None,
):
msg = f"人格{ps}的详细信息:\n"
msg += f"{persona_info['prompt']}\n"
else:
msg = f"人格{ps}不存在"
message.set_result(MessageEventResult().message(msg))
elif parts[1] == "unset":
if not cid:
message.set_result(
MessageEventResult().message("当前没有对话,无法取消人格。"),
)
return
await self.context.conversation_manager.update_conversation_persona_id(
message.unified_msg_origin,
"[%None]",
)
message.set_result(MessageEventResult().message("取消人格成功。"))
else:
ps = "".join(parts[1:]).strip()
if not cid:
message.set_result(
MessageEventResult().message(
"当前没有对话,请先开始对话或使用 /new 创建一个对话。",
),
)
return
if persona_info := next(
builtins.filter(
lambda persona: persona["name"] == ps,
self.context.provider_manager.personas,
),
None,
):
await self.context.conversation_manager.update_conversation_persona_id(
message.unified_msg_origin,
ps,
)
force_warn_msg = ""
if force_applied_persona_id:
force_warn_msg = "提醒:由于自定义规则,您现在切换的人格将不会生效。"
message.set_result(
MessageEventResult().message(
f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。{force_warn_msg}",
),
)
else:
message.set_result(
MessageEventResult().message(
"不存在该人格情景。使用 /persona list 查看所有。",
),
)

View File

@@ -1,125 +0,0 @@
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core import DEMO_MODE, logger
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
class PluginCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表。"""
parts = ["已加载的插件:\n"]
for plugin in self.context.get_all_stars():
line = f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
if not plugin.activated:
line += " (未启用)"
parts.append(line + "\n")
if len(parts) == 1:
plugin_list_info = "没有加载任何插件。"
else:
plugin_list_info = "".join(parts)
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
event.set_result(
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False),
)
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
return
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin off <插件名> 禁用插件。"),
)
return
if self.context._star_manager is None:
event.set_result(MessageEventResult().message("插件管理器未初始化。"))
return
await self.context._star_manager.turn_off_plugin(plugin_name)
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
return
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin on <插件名> 启用插件。"),
)
return
if self.context._star_manager is None:
event.set_result(MessageEventResult().message("插件管理器未初始化。"))
return
await self.context._star_manager.turn_on_plugin(plugin_name)
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
return
if not plugin_repo:
event.set_result(
MessageEventResult().message("/plugin get <插件仓库地址> 安装插件"),
)
return
logger.info(f"准备从 {plugin_repo} 安装插件。")
if self.context._star_manager:
star_mgr = self.context._star_manager
try:
await star_mgr.install_plugin(plugin_repo)
event.set_result(MessageEventResult().message("安装插件成功。"))
except Exception as e:
logger.error(f"安装插件失败: {e}")
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
return
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""获取插件帮助"""
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin help <插件名> 查看插件信息。"),
)
return
plugin = self.context.get_registered_star(plugin_name)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件。"))
return
help_msg = ""
help_msg += f"\n\n✨ 作者: {plugin.author}\n✨ 版本: {plugin.version}"
command_handlers = []
command_names = []
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
if handler.handler_module_path != plugin.module_path:
continue
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
command_handlers.append(handler)
command_names.append(filter_.command_name)
break
if isinstance(filter_, CommandGroupFilter):
command_handlers.append(handler)
command_names.append(filter_.group_name)
if len(command_handlers) > 0:
parts = ["\n\n🔧 指令列表:\n"]
for i in range(len(command_handlers)):
line = f"- {command_names[i]}"
if command_handlers[i].desc:
line += f": {command_handlers[i].desc}"
parts.append(line + "\n")
parts.append("\nTip: 指令的触发需要添加唤醒前缀,默认为 /。")
help_msg += "".join(parts)
ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg
ret += "更多帮助信息请查看插件仓库 README。"
event.set_result(MessageEventResult().message(ret).use_t2i(False))

View File

@@ -40,10 +40,7 @@ class ProviderCommands:
err_code = "TEST_FAILED"
err_reason = safe_error("", e)
self._log_reachability_failure(
provider,
provider_capability_type,
err_code,
err_reason,
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
@@ -65,7 +62,7 @@ class ProviderCommands:
check_results = [None for _ in providers]
display_data = []
for provider, reachable in zip(providers, check_results, strict=False):
for provider, reachable in zip(providers, check_results):
meta = provider.meta()
id_ = meta.id
error_code = None
@@ -106,7 +103,7 @@ class ProviderCommands:
"info": info,
"mark": mark,
"provider": provider,
},
}
)
return display_data
@@ -131,7 +128,7 @@ class ProviderCommands:
if reachability_check_enabled and (llms or ttss or stts):
await event.send(
MessageEventResult().message("👀 Testing provider reachability..."),
MessageEventResult().message("👀 Testing provider reachability...")
)
llm_data, tts_data, stt_data = await asyncio.gather(
@@ -192,12 +189,12 @@ class ProviderCommands:
elif idx == "tts":
if idx2 is None:
event.set_result(
MessageEventResult().message("Please enter the index."),
MessageEventResult().message("Please enter the index.")
)
return
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index."),
MessageEventResult().message("❌ Invalid provider index.")
)
return
provider = self.context.get_all_tts_providers()[idx2 - 1]
@@ -208,17 +205,17 @@ class ProviderCommands:
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}."),
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
elif idx == "stt":
if idx2 is None:
event.set_result(
MessageEventResult().message("Please enter the index."),
MessageEventResult().message("Please enter the index.")
)
return
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index."),
MessageEventResult().message("❌ Invalid provider index.")
)
return
provider = self.context.get_all_stt_providers()[idx2 - 1]
@@ -229,12 +226,12 @@ class ProviderCommands:
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}."),
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
elif isinstance(idx, int):
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index."),
MessageEventResult().message("❌ Invalid provider index.")
)
return
provider = self.context.get_all_providers()[idx - 1]
@@ -245,86 +242,7 @@ class ProviderCommands:
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}."),
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
else:
event.set_result(MessageEventResult().message("❌ Invalid parameter."))
async def model_ls(
self,
event: AstrMessageEvent,
idx_or_name: int | str | None = None,
) -> None:
"""查看或者切换当前 Provider 的模型。"""
umo = event.unified_msg_origin
provider = self.context.get_using_provider(umo=umo)
if provider is None:
event.set_result(
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
try:
models = await provider.get_models()
except Exception as e:
event.set_result(
MessageEventResult().message(
f"获取模型列表失败: {safe_error('', e)}",
),
)
return
current_model = provider.get_model()
if idx_or_name is None:
if not models:
event.set_result(
MessageEventResult().message(
f"当前模型: {current_model}\n此提供商未返回可切换模型列表。",
),
)
return
parts = [f"当前模型: {current_model}\n\n可用模型:\n"]
for index, model_name in enumerate(models, start=1):
suffix = " 👈" if model_name == current_model else ""
parts.append(f"{index}. {model_name}{suffix}\n")
parts.append("\n使用 /model <序号> 或 /model <模型名> 切换模型。")
event.set_result(MessageEventResult().message("".join(parts)))
return
selected_model: str | None = None
if isinstance(idx_or_name, int):
if 1 <= idx_or_name <= len(models):
selected_model = models[idx_or_name - 1]
else:
text = idx_or_name.strip()
if text.isdigit():
model_index = int(text)
if 1 <= model_index <= len(models):
selected_model = models[model_index - 1]
elif text:
selected_model = text
if not selected_model:
event.set_result(MessageEventResult().message("❌ Invalid model index."))
return
provider.set_model(selected_model)
provider.provider_config["model"] = selected_model
cfg = self.context.get_config(umo)
providers_config = cfg.get("provider", [])
if isinstance(providers_config, list):
for provider_config in providers_config:
if not isinstance(provider_config, dict):
continue
if provider_config.get("id") == provider.meta().id:
provider_config["model"] = selected_model
break
cfg.save_config()
event.set_result(
MessageEventResult().message(
f"✅ Successfully switched model to {selected_model}.",
),
)

View File

@@ -2,16 +2,6 @@ from astrbot.api import sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
def _normalize_session_variables(value: object) -> dict[str, str]:
if not isinstance(value, dict):
return {}
return {
key: value
for key, value in value.items()
if isinstance(key, str) and isinstance(value, str)
}
class SetUnsetCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
@@ -19,32 +9,28 @@ class SetUnsetCommands:
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
"""设置会话变量"""
uid = event.unified_msg_origin
session_var = _normalize_session_variables(
await sp.session_get(uid, "session_variables", {}),
)
session_var = await sp.session_get(uid, "session_variables", {})
session_var[key] = value
await sp.session_put(uid, "session_variables", session_var)
event.set_result(
MessageEventResult().message(
f"会话 {uid} 变量 {key} 存储成功使用 /unset 移除",
f"会话 {uid} 变量 {key} 存储成功使用 /unset 移除",
),
)
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
"""移除会话变量"""
uid = event.unified_msg_origin
session_var = _normalize_session_variables(
await sp.session_get(uid, "session_variables", {}),
)
session_var = await sp.session_get(uid, "session_variables", {})
if key not in session_var:
event.set_result(
MessageEventResult().message("没有那个变量名格式 /unset 变量名"),
MessageEventResult().message("没有那个变量名格式 /unset 变量名"),
)
else:
del session_var[key]
await sp.session_put(uid, "session_variables", session_var)
event.set_result(
MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功"),
MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功"),
)

View File

@@ -18,19 +18,19 @@ class SIDCommand:
umo_msg_type = event.session.message_type.value
umo_session_id = event.session.session_id
ret = (
f"UMO: {sid}」 此值可用于设置白名单。\n"
f"UID: {user_id}」 此值可用于设置管理员。\n"
f"消息会话来源信息:\n"
f" 机器人 ID: 「{umo_platform}\n"
f" 消息类型: {umo_msg_type}\n"
f" 会话 ID: {umo_session_id}\n"
f"消息来源可用于配置机器人的配置文件路由。"
f"UMO: {sid}\n"
f"UID: {user_id}\n"
"*Use UMO to set whitelist and configure routing, use UID to set admin list(UMO 可用于设置白名单和配置文件路由UID 可用于设置管理员列表)\n\n"
f"Your session information:\n"
f"Bot ID: {umo_platform}\n"
f"Message Type: {umo_msg_type}\n"
f"Session ID: 「{umo_session_id}\n\n"
)
if (
self.context.get_config()["platform_settings"]["unique_session"]
and event.get_group_id()
):
ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}」, 也可将此 ID 加入白名单来放行整个群聊。"
ret += f"\n\nThe group's ID: {event.get_group_id()}」. Set this ID to whitelist to allow the entire group."
event.set_result(MessageEventResult().message(ret).use_t2i(False))

View File

@@ -1,23 +0,0 @@
"""文本转图片命令"""
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
class T2ICommand:
"""文本转图片命令类"""
def __init__(self, context: star.Context) -> None:
self.context = context
async def t2i(self, event: AstrMessageEvent) -> None:
"""开关文本转图片"""
config = self.context.get_config(umo=event.unified_msg_origin)
if config["t2i"]:
config["t2i"] = False
config.save_config()
event.set_result(MessageEventResult().message("已关闭文本转图片模式。"))
return
config["t2i"] = True
config.save_config()
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))

View File

@@ -1,36 +0,0 @@
"""文本转语音命令"""
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.star.session_llm_manager import SessionServiceManager
class TTSCommand:
"""文本转语音命令类"""
def __init__(self, context: star.Context) -> None:
self.context = context
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
umo = event.unified_msg_origin
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
cfg = self.context.get_config(umo=umo)
tts_enable = cfg["provider_tts_settings"]["enable"]
# 切换状态
new_status = not ses_tts
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
status_text = "已开启" if new_status else "已关闭"
if new_status and not tts_enable:
event.set_result(
MessageEventResult().message(
f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。",
),
)
else:
event.set_result(
MessageEventResult().message(f"{status_text}当前会话的文本转语音。"),
)

View File

@@ -4,17 +4,12 @@ from astrbot.core.star.filter.command import GreedyStr
from .commands import (
AdminCommands,
AlterCmdCommands,
ConversationCommands,
HelpCommand,
LLMCommands,
NameCommand,
PluginCommands,
ProviderCommands,
SetUnsetCommands,
SIDCommand,
T2ICommand,
TTSCommand,
)
@@ -22,75 +17,22 @@ class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
self.help_c = HelpCommand(self.context)
self.llm_c = LLMCommands(self.context)
self.plugin_c = PluginCommands(self.context)
self.admin_c = AdminCommands(self.context)
self.conversation_c = ConversationCommands(self.context)
self.help_c = HelpCommand(self.context)
self.name_c = NameCommand(self.context)
self.provider_c = ProviderCommands(self.context)
self.setunset_c = SetUnsetCommands(self.context)
self.t2i_c = T2ICommand(self.context)
self.tts_c = TTSCommand(self.context)
self.sid_c = SIDCommand(self.context)
self.alter_cmd_c = AlterCmdCommands(self.context)
@filter.command("help")
async def help(self, event: AstrMessageEvent) -> None:
"""查看帮助"""
"""Show help message"""
await self.help_c.help(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("llm")
async def llm(self, event: AstrMessageEvent) -> None:
"""开启/关闭 LLM"""
await self.llm_c.llm(event)
@filter.command_group("plugin")
def plugin(self) -> None:
"""插件管理"""
@plugin.command("ls")
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表。"""
await self.plugin_c.plugin_ls(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("off")
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
await self.plugin_c.plugin_off(event, plugin_name)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("on")
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
await self.plugin_c.plugin_on(event, plugin_name)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("get")
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
await self.plugin_c.plugin_get(event, plugin_repo)
@plugin.command("help")
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""获取插件帮助"""
await self.plugin_c.plugin_help(event, plugin_name)
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent) -> None:
"""开关文本转图片"""
await self.t2i_c.t2i(event)
@filter.command("tts")
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
await self.tts_c.tts(event)
@filter.command("sid")
async def sid(self, event: AstrMessageEvent) -> None:
"""获取会话 ID 和 管理员 ID"""
"""Get session ID and other related information"""
await self.sid_c.sid(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@@ -99,29 +41,20 @@ class Main(star.Star):
"""Set display name for current UMO"""
await self.name_c.name(event, alias)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("op")
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员。op <admin_id>"""
await self.admin_c.op(event, admin_id)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent) -> None:
"""Reset conversation history"""
await self.conversation_c.reset(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("deop")
async def deop(self, event: AstrMessageEvent, admin_id: str) -> None:
"""取消授权管理员。deop <admin_id>"""
await self.admin_c.deop(event, admin_id)
@filter.command("stop")
async def stop(self, message: AstrMessageEvent) -> None:
"""Stop agent execution"""
await self.conversation_c.stop(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("wl")
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单。wl <sid>"""
await self.admin_c.wl(event, sid)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dwl")
async def dwl(self, event: AstrMessageEvent, sid: str) -> None:
"""删除白名单。dwl <sid>"""
await self.admin_c.dwl(event, sid)
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent) -> None:
"""Create new conversation"""
await self.conversation_c.new_conv(message)
@filter.command("stats")
async def stats(self, message: AstrMessageEvent) -> None:
@@ -136,60 +69,21 @@ class Main(star.Star):
idx: str | int | None = None,
idx2: int | None = None,
) -> None:
"""查看或者切换 LLM Provider"""
"""View or switch LLM Provider"""
await self.provider_c.provider(event, idx, idx2)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent) -> None:
"""重置 LLM 会话"""
await self.conversation_c.reset(message)
@filter.command("stop")
async def stop(self, message: AstrMessageEvent) -> None:
"""停止当前会话中正在运行的 Agent"""
await self.conversation_c.stop(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("model")
async def model_ls(
self,
message: AstrMessageEvent,
idx_or_name: int | str | None = None,
) -> None:
"""查看或者切换模型"""
await self.provider_c.model_ls(message, idx_or_name)
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
await self.conversation_c.his(message, page)
@filter.command("ls")
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话列表"""
await self.conversation_c.convs(message, page)
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent) -> None:
"""创建新对话"""
await self.conversation_c.new_conv(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard_update")
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""更新管理面板"""
"""Update AstrBot WebUI"""
await self.admin_c.update_dashboard(event)
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
"""Set session variable"""
await self.setunset_c.set_variable(event, key, value)
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
"""Unset session variable"""
await self.setunset_c.unset_variable(event, key)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd", alias={"alter"})
async def alter_cmd(self, event: AstrMessageEvent) -> None:
"""修改命令权限"""
await self.alter_cmd_c.alter_cmd(event)

View File

@@ -1,115 +0,0 @@
import copy
from sys import maxsize
import astrbot.api.message_components as Comp
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.star import Context, Star
from astrbot.core.utils.session_waiter import (
FILTERS,
USER_SESSIONS,
SessionController,
SessionWaiter,
session_waiter,
)
class Main(Star):
"""会话控制"""
def __init__(self, context: Context) -> None:
super().__init__(context)
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize)
async def handle_session_control_agent(self, event: AstrMessageEvent) -> None:
"""会话控制代理"""
for session_filter in FILTERS:
session_id = session_filter.filter(event)
if session_id in USER_SESSIONS:
await SessionWaiter.trigger(session_id, event)
event.stop_event()
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize - 1)
async def handle_empty_mention(self, event: AstrMessageEvent):
"""实现了对只有一个 @ 的消息内容的处理"""
try:
messages = event.get_messages()
cfg = self.context.get_config(umo=event.unified_msg_origin)
p_settings = cfg["platform_settings"]
wake_prefix = cfg.get("wake_prefix", [])
if len(messages) == 1:
if (
isinstance(messages[0], Comp.At)
and str(messages[0].qq) == str(event.get_self_id())
and p_settings.get("empty_mention_waiting", True)
) or (
isinstance(messages[0], Comp.Plain)
and messages[0].text.strip() in wake_prefix
):
if p_settings.get("empty_mention_waiting_need_reply", True):
try:
# 尝试使用 LLM 生成更生动的回复
# func_tools_mgr = self.context.get_llm_tool_manager()
# 获取用户当前的对话信息
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
event.unified_msg_origin,
)
conversation = None
if curr_cid:
conversation = await self.context.conversation_manager.get_conversation(
event.unified_msg_origin,
curr_cid,
)
else:
# 创建新对话
curr_cid = await self.context.conversation_manager.new_conversation(
event.unified_msg_origin,
platform_id=event.get_platform_id(),
)
# 使用 LLM 生成回复
yield event.request_llm(
prompt=(
"注意,你正在社交媒体上中与用户进行聊天,用户只是通过@来唤醒你,但并未在这条消息中输入内容,他可能会在接下来一条发送他想发送的内容。"
"你友好地询问用户想要聊些什么或者需要什么帮助,回复要符合人设,不要太过机械化。"
"请注意,你仅需要输出要回复用户的内容,不要输出其他任何东西"
),
session_id=curr_cid,
contexts=[],
system_prompt="",
conversation=conversation,
)
except Exception as e:
logger.error(f"LLM response failed: {e!s}")
# LLM 回复失败,使用原始预设回复
yield event.plain_result("想要问什么呢?😄")
@session_waiter(60)
async def empty_mention_waiter(
controller: SessionController,
event: AstrMessageEvent,
) -> None:
if not event.message_str or not event.message_str.strip():
return
event.message_obj.message.insert(
0,
Comp.At(qq=event.get_self_id(), name=event.get_self_id()),
)
new_event = copy.copy(event)
# 重新推入事件队列
self.context.get_event_queue().put_nowait(new_event)
event.stop_event()
controller.stop()
try:
await empty_mention_waiter(event)
except TimeoutError as _:
pass
except Exception as e:
yield event.plain_result("发生错误,请联系管理员: " + str(e))
finally:
event.stop_event()
except Exception as e:
logger.error("handle_empty_mention error: " + str(e))

View File

@@ -1,146 +0,0 @@
import random
import urllib.parse
from collections.abc import Callable
from dataclasses import dataclass
from aiohttp import ClientSession, ClientTimeout
from bs4 import BeautifulSoup, Tag
HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
"Accept": "*/*",
"Connection": "keep-alive",
"Accept-Language": "en-GB,en;q=0.5",
}
USER_AGENT_BING = "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0"
USER_AGENTS = [
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36",
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0",
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0",
]
@dataclass
class SearchResult:
title: str
url: str
snippet: str
favicon: str | None = None
def __str__(self) -> str:
return f"{self.title} - {self.url}\n{self.snippet}"
class SearchEngine:
"""搜索引擎爬虫基类"""
def __init__(self) -> None:
self.TIMEOUT = ClientTimeout(total=10)
self.page = 1
self.headers = HEADERS
def _set_selector(self, selector: str) -> str:
raise NotImplementedError
async def _get_next_page(self, query: str) -> str:
raise NotImplementedError
async def _get_html(self, url: str, data: dict | None = None) -> str:
headers = self.headers
headers["Referer"] = url
headers["User-Agent"] = random.choice(USER_AGENTS)
if data:
async with (
ClientSession() as session,
session.post(
url,
headers=headers,
data=data,
timeout=self.TIMEOUT,
) as resp,
):
ret = await resp.text(encoding="utf-8")
return ret
else:
async with (
ClientSession() as session,
session.get(
url,
headers=headers,
timeout=self.TIMEOUT,
) as resp,
):
ret = await resp.text(encoding="utf-8")
return ret
def tidy_text(self, text: str) -> str:
"""清理文本,去除空格、换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
def _get_url(self, tag: Tag) -> str:
return self.tidy_text(tag.get_text())
async def search(self, query: str, num_results: int) -> list[SearchResult]:
query = urllib.parse.quote(query)
try:
resp = await self._get_next_page(query)
soup = BeautifulSoup(resp, "html.parser")
links = soup.select(self._set_selector("links"))
results = []
try:
text_selector = self._set_selector("text")
except (KeyError, NotImplementedError):
# Keep backward compatibility with engines that only expose
# title/url/link selectors and do not provide snippets.
text_selector = ""
for link in links:
# Safely get the title text (select_one may return None)
title_elem = link.select_one(self._set_selector("title"))
title = ""
if title_elem is not None:
title = self.tidy_text(title_elem.get_text())
url_tag = link.select_one(self._set_selector("url"))
snippet = ""
if text_selector:
text_elem = link.select_one(text_selector)
if text_elem is not None:
snippet = self.tidy_text(text_elem.get_text())
if title and url_tag:
url = self._get_url(url_tag)
if not url:
continue
if url.startswith("//"):
url = f"https:{url}"
results.append(SearchResult(title=title, url=url, snippet=snippet))
return results[:num_results] if len(results) > num_results else results
except Exception as e:
raise e
async def _search_with_result_filter(
self,
query: str,
num_results: int,
predicate: Callable[[SearchResult], bool],
) -> list[SearchResult]:
if num_results <= 0:
return []
rough_results = await SearchEngine.search(self, query, max(num_results * 2, 10))
final_results: list[SearchResult] = []
for result in rough_results:
if not predicate(result):
continue
final_results.append(result)
if len(final_results) >= num_results:
break
return final_results

View File

@@ -1,33 +0,0 @@
from . import USER_AGENT_BING, SearchEngine
class Bing(SearchEngine):
NAME = "bing"
def __init__(self) -> None:
super().__init__()
# Prefer international Bing first, keep cn endpoint as compatibility fallback.
self.base_urls = ["https://www.bing.com", "https://cn.bing.com"]
self.headers.update({"User-Agent": USER_AGENT_BING})
def _set_selector(self, selector: str):
selectors = {
"url": "div.b_attribution cite",
"title": "h2",
"text": "p",
"links": "ol#b_results > li.b_algo",
"next": 'div#b_content nav[role="navigation"] a.sb_pagN',
}
return selectors[selector]
async def _get_next_page(self, query) -> str:
# if self.page == 1:
# await self._get_html(self.base_url)
for base_url in self.base_urls:
try:
url = f"{base_url}/search?q={query}"
return await self._get_html(url, None)
except Exception as _:
self.base_url = base_url
continue
raise Exception("Bing search failed")

View File

@@ -1,64 +0,0 @@
from urllib.parse import unquote, urlencode, urlparse
from bs4 import Tag
from . import SearchEngine, SearchResult
class Comet(SearchEngine):
"""Best-effort search via public Perplexity/Comet page.
Note:
- This endpoint is often protected by anti-bot challenges.
- We intentionally treat failures as non-fatal and rely on fallback engines.
"""
NAME = "comet"
def __init__(self) -> None:
super().__init__()
self.base_url = "https://www.perplexity.ai"
def _set_selector(self, selector: str):
selectors = {
"url": "a[href^='http'], a[href^='//']",
"title": "main h1, main h2, main h3, h3, h2",
"text": "main article, main div[role='article'], main section, main p, p",
"links": "main article, main div[role='article'], main li, main div.result, article, div[role='article'], li, div.result",
"next": "",
}
return selectors[selector]
async def _get_next_page(self, query: str) -> str:
url = f"{self.base_url}/search?{urlencode({'q': unquote(query)})}"
return await self._get_html(url, None)
def _get_url(self, tag: Tag) -> str:
href = str(tag.get("href") or "")
if href.startswith("//"):
return f"https:{href}"
return href
@staticmethod
def _is_valid_result_url(url: str) -> bool:
lowered = (url or "").strip().lower()
if not lowered:
return False
if lowered.startswith(("#", "javascript:", "mailto:")):
return False
if not lowered.startswith(("http://", "https://")):
return False
netloc = urlparse(lowered).netloc
if not netloc:
return False
if netloc.endswith("perplexity.ai"):
return False
return True
async def search(self, query: str, num_results: int) -> list[SearchResult]:
return await self._search_with_result_filter(
query=query,
num_results=num_results,
predicate=lambda result: self._is_valid_result_url(result.url),
)

View File

@@ -1,43 +0,0 @@
import urllib.parse
from bs4 import Tag
from . import SearchEngine, SearchResult
class DuckDuckGo(SearchEngine):
NAME = "duckduckgo"
def __init__(self) -> None:
super().__init__()
self.base_url = "https://html.duckduckgo.com/html"
def _set_selector(self, selector: str):
selectors = {
"url": "a.result__a, h2 a",
"title": "a.result__a, h2",
"text": "a.result__snippet, div.result__snippet",
"links": "div.result, div.web-result",
"next": "a.result--more__btn",
}
return selectors[selector]
async def _get_next_page(self, query: str) -> str:
params = {"q": urllib.parse.unquote(query), "kl": "us-en"}
url = f"{self.base_url}/?{urllib.parse.urlencode(params)}"
return await self._get_html(url, None)
def _get_url(self, tag: Tag) -> str:
href = str(tag.get("href") or "")
if "duckduckgo.com/l/?" in href:
parsed = urllib.parse.urlparse(href)
target = urllib.parse.parse_qs(parsed.query).get("uddg", [""])[0]
return urllib.parse.unquote(target)
return href
async def search(self, query: str, num_results: int) -> list[SearchResult]:
return await self._search_with_result_filter(
query=query,
num_results=num_results,
predicate=lambda result: result.url.startswith("http"),
)

View File

@@ -1,51 +0,0 @@
import urllib.parse
from bs4 import Tag
from . import SearchEngine, SearchResult
class Google(SearchEngine):
NAME = "google"
def __init__(self) -> None:
super().__init__()
self.base_url = "https://www.google.com"
def _set_selector(self, selector: str):
selectors = {
"url": "a[href]",
"title": "h3",
"text": "div.VwiC3b, span.aCOpRe",
"links": "div#search div.g, div#search div.MjjYud",
"next": "a#pnnext",
}
return selectors[selector]
async def _get_next_page(self, query: str) -> str:
params = {
"q": urllib.parse.unquote(query),
"hl": "en",
"gl": "us",
"pws": "0",
"num": "10",
}
url = f"{self.base_url}/search?{urllib.parse.urlencode(params)}"
return await self._get_html(url, None)
def _get_url(self, tag: Tag) -> str:
href = str(tag.get("href") or "")
if href.startswith("/url?"):
parsed = urllib.parse.urlparse(href)
q = urllib.parse.parse_qs(parsed.query).get("q", [""])[0]
return urllib.parse.unquote(q)
return href
async def search(self, query: str, num_results: int) -> list[SearchResult]:
return await self._search_with_result_filter(
query=query,
num_results=num_results,
predicate=lambda result: (
result.url.startswith("http") and "google.com/search?" not in result.url
),
)

View File

@@ -1,53 +0,0 @@
import random
import re
from bs4 import BeautifulSoup, Tag
from . import USER_AGENTS, SearchEngine, SearchResult
class Sogo(SearchEngine):
NAME = "sogo"
def __init__(self) -> None:
super().__init__()
self.base_url = "https://www.sogou.com"
self.headers["User-Agent"] = random.choice(USER_AGENTS)
def _set_selector(self, selector: str):
selectors = {
"url": "h3 > a",
"title": "h3",
"text": "",
"links": "div.results > div.vrwrap:not(.middle-better-hintBox)",
"next": "",
}
return selectors[selector]
async def _get_next_page(self, query) -> str:
url = f"{self.base_url}/web?query={query}"
return await self._get_html(url, None)
def _get_url(self, tag: Tag) -> str:
return str(tag.get("href") or "")
async def search(self, query: str, num_results: int) -> list[SearchResult]:
results = await super().search(query, num_results)
for result in results:
if result.url.startswith("/link?"):
result.url = self.base_url + result.url
result.url = await self._parse_url(result.url)
return results
async def _parse_url(self, url) -> str:
html = await self._get_html(url)
soup = BeautifulSoup(html, "html.parser")
script = soup.find("script")
if script:
script_text = (
script.string if script.string is not None else script.get_text()
)
match = re.search('window.location.replace\\("(.+?)"\\)', script_text)
if match:
url = match.group(1)
return url

View File

@@ -1,663 +0,0 @@
import asyncio
import json
import random
import uuid
from typing import ClassVar
import aiohttp
from bs4 import BeautifulSoup
from readability import Document
from astrbot.api import AstrBotConfig, llm_tool, logger, sp, star
from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.provider import ProviderRequest
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from .engines import HEADERS, USER_AGENTS, SearchResult
from .engines.bing import Bing
from .engines.comet import Comet
from .engines.duckduckgo import DuckDuckGo
from .engines.google import Google
from .engines.sogo import Sogo
from .provider_routing import (
DEFAULT_WEB_SEARCH_PROVIDER,
build_default_engine_order,
normalize_websearch_provider,
normalize_websearch_provider_for_tools,
validate_default_engine_registry,
)
class Main(star.Star):
TOOLS: ClassVar[list[str]] = [
"web_search",
"fetch_url",
"web_search_tavily",
"tavily_extract_web_page",
"web_search_bocha",
]
def __init__(self, context: star.Context) -> None:
self.context = context
self.tavily_key_index = 0
self.tavily_key_lock = asyncio.Lock()
self.bocha_key_index = 0
self.bocha_key_lock = asyncio.Lock()
# 将 str 类型的 key 迁移至 list[str],并保存
cfg = self.context.get_config()
provider_settings = cfg.get("provider_settings")
if provider_settings:
tavily_key = provider_settings.get("websearch_tavily_key")
if isinstance(tavily_key, str):
logger.info(
"检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存。",
)
if tavily_key:
provider_settings["websearch_tavily_key"] = [tavily_key]
else:
provider_settings["websearch_tavily_key"] = []
cfg.save_config()
bocha_key = provider_settings.get("websearch_bocha_key")
if isinstance(bocha_key, str):
if bocha_key:
provider_settings["websearch_bocha_key"] = [bocha_key]
else:
provider_settings["websearch_bocha_key"] = []
cfg.save_config()
self.google_search = Google()
self.bing_search = Bing()
self.ddg_search = DuckDuckGo()
self.comet_search = Comet()
self.sogo_search = Sogo()
self.default_search_engines = {
engine.NAME: engine
for engine in (
self.google_search,
self.bing_search,
self.ddg_search,
self.comet_search,
self.sogo_search,
)
}
validate_default_engine_registry(self.default_search_engines)
self.baidu_initialized = False
async def _tidy_text(self, text: str) -> str:
"""清理文本,去除空格、换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
async def _get_from_url(self, url: str) -> str:
"""获取网页内容"""
header = HEADERS
header.update({"User-Agent": random.choice(USER_AGENTS)})
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url, headers=header) as response:
html = await response.text(encoding="utf-8")
doc = Document(html)
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, "html.parser")
ret = await self._tidy_text(soup.get_text())
return ret
async def _process_search_result(
self,
result: SearchResult,
idx: int,
websearch_link: bool,
) -> str:
"""处理单个搜索结果"""
logger.info(f"web_searcher - scraping web: {result.title} - {result.url}")
try:
site_result = await self._get_from_url(result.url)
except BaseException:
site_result = ""
site_result = (
f"{site_result[:700]}..." if len(site_result) > 700 else site_result
)
header = f"{idx}. {result.title} "
if websearch_link and result.url:
header += result.url
return f"{header}\n{result.snippet}\n{site_result}\n\n"
async def _web_search_default(
self,
query,
num_results: int = 5,
preferred_provider: str = DEFAULT_WEB_SEARCH_PROVIDER,
) -> list[SearchResult]:
for engine_name in build_default_engine_order(preferred_provider):
engine = self.default_search_engines.get(engine_name)
if not engine:
continue
try:
results = await engine.search(query, num_results)
except Exception as e:
logger.error(
f"{engine_name} search error: {e}, try the next one...",
)
continue
if results:
logger.info(
f"web_searcher - provider `{engine_name}` success: {len(results)} results",
)
return results
logger.debug(f"search {engine_name} returned no results")
return []
async def _get_tavily_key(self, cfg: AstrBotConfig) -> str:
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
tavily_keys = cfg.get("provider_settings", {}).get("websearch_tavily_key", [])
if not tavily_keys:
raise ValueError("错误:Tavily API密钥未在AstrBot中配置。")
async with self.tavily_key_lock:
key = tavily_keys[self.tavily_key_index]
self.tavily_key_index = (self.tavily_key_index + 1) % len(tavily_keys)
return key
async def _web_search_tavily(
self,
cfg: AstrBotConfig,
payload: dict,
) -> list[SearchResult]:
"""使用 Tavily 搜索引擎进行搜索"""
tavily_key = await self._get_tavily_key(cfg)
url = "https://api.tavily.com/search"
header = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
url,
json=payload,
headers=header,
) as response:
if response.status != 200:
reason = await response.text()
raise Exception(
f"Tavily web search failed: {reason}, status: {response.status}",
)
data = await response.json()
results = []
for item in data.get("results", []):
result = SearchResult(
title=item.get("title"),
url=item.get("url"),
snippet=item.get("content"),
favicon=item.get("favicon"),
)
results.append(result)
return results
async def _extract_tavily(self, cfg: AstrBotConfig, payload: dict) -> list[dict]:
"""使用 Tavily 提取网页内容"""
tavily_key = await self._get_tavily_key(cfg)
url = "https://api.tavily.com/extract"
header = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
url,
json=payload,
headers=header,
) as response:
if response.status != 200:
reason = await response.text()
raise Exception(
f"Tavily web search failed: {reason}, status: {response.status}",
)
data = await response.json()
results: list[dict] = data.get("results", [])
if not results:
raise ValueError(
"Error: Tavily web searcher does not return any results.",
)
return results
@llm_tool(name="web_search")
async def search_from_search_engine(
self,
event: AstrMessageEvent,
query: str,
max_results: int = 5,
) -> str:
"""搜索网络以回答用户的问题。当用户需要搜索网络以获取即时性的信息时调用此工具。
Args:
query(string): 和用户的问题最相关的搜索关键词,用于在 Google 上搜索。
max_results(number): 返回的最大搜索结果数量,默认为 5。
"""
logger.info(f"web_searcher - search_from_search_engine: {query}")
cfg = self.context.get_config(umo=event.unified_msg_origin)
websearch_link = cfg["provider_settings"].get("web_search_link", False)
preferred_provider = normalize_websearch_provider(
cfg.get("provider_settings", {}).get(
"websearch_provider",
DEFAULT_WEB_SEARCH_PROVIDER,
),
)
results = await self._web_search_default(
query,
max_results,
preferred_provider=preferred_provider,
)
if not results:
return "Error: web searcher does not return any results."
tasks = []
for idx, result in enumerate(results, 1):
task = self._process_search_result(result, idx, websearch_link)
tasks.append(task)
processed_results = await asyncio.gather(*tasks, return_exceptions=True)
ret = ""
for processed_result in processed_results:
if isinstance(processed_result, BaseException):
logger.error(f"Error processing search result: {processed_result}")
continue
ret += processed_result
if websearch_link:
ret += "\n\n针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
return ret
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None:
if self.baidu_initialized:
return
cfg = self.context.get_config(umo=umo)
key = cfg.get("provider_settings", {}).get(
"websearch_baidu_app_builder_key",
"",
)
if not key:
raise ValueError(
"Error: Baidu AI Search API key is not configured in AstrBot.",
)
func_tool_mgr = self.context.get_llm_tool_manager()
await func_tool_mgr.enable_mcp_server(
"baidu_ai_search",
config={
"transport": "sse",
"url": f"http://appbuilder.baidu.com/v2/ai_search/mcp/sse?api_key={key}",
"headers": {},
"timeout": 600,
},
)
self.baidu_initialized = True
logger.info("Successfully initialized Baidu AI Search MCP server.")
@llm_tool(name="fetch_url")
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
"""Fetch the content of a website with the given web url
Args:
url(string): The url of the website to fetch content from
"""
resp = await self._get_from_url(url)
return resp
@llm_tool("web_search_tavily")
async def search_from_tavily(
self,
event: AstrMessageEvent,
query: str,
max_results: int = 7,
search_depth: str = "basic",
topic: str = "general",
days: int = 3,
time_range: str = "",
start_date: str = "",
end_date: str = "",
) -> str:
"""A web search tool that uses Tavily to search the web for relevant content.
Ideal for gathering current information, news, and detailed web content analysis.
Args:
query(string): Required. Search query.
max_results(number): Optional. The maximum number of results to return. Default is 7. Range is 5-20.
search_depth(string): Optional. The depth of the search, must be one of 'basic', 'advanced'. Default is "basic".
topic(string): Optional. The topic of the search, must be one of 'general', 'news'. Default is "general".
days(number): Optional. The number of days back from the current date to include in the search results. Please note that this feature is only available when using the 'news' search topic.
time_range(string): Optional. The time range back from the current date to include in the search results. This feature is available for both 'general' and 'news' search topics. Must be one of 'day', 'week', 'month', 'year'.
start_date(string): Optional. The start date for the search results in the format 'YYYY-MM-DD'.
end_date(string): Optional. The end date for the search results in the format 'YYYY-MM-DD'.
"""
logger.info(f"web_searcher - search_from_tavily: {query}")
cfg = self.context.get_config(umo=event.unified_msg_origin)
# websearch_link = cfg["provider_settings"].get("web_search_link", False)
if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []):
raise ValueError("Error: Tavily API key is not configured in AstrBot.")
# build payload
payload = {"query": query, "max_results": max_results, "include_favicon": True}
if search_depth not in ["basic", "advanced"]:
search_depth = "basic"
payload["search_depth"] = search_depth
if topic not in ["general", "news"]:
topic = "general"
payload["topic"] = topic
if topic == "news":
payload["days"] = days
if time_range in ["day", "week", "month", "year"]:
payload["time_range"] = time_range
if start_date:
payload["start_date"] = start_date
if end_date:
payload["end_date"] = end_date
results = await self._web_search_tavily(cfg, payload)
if not results:
return "Error: Tavily web searcher does not return any results."
ret_ls = []
ref_uuid = str(uuid.uuid4())[:4]
for idx, result in enumerate(results, 1):
index = f"{ref_uuid}.{idx}"
ret_ls.append(
{
"title": f"{result.title}",
"url": f"{result.url}",
"snippet": f"{result.snippet}",
# TODO: do not need ref for non-webchat platform adapter
"index": index,
},
)
if result.favicon:
sp.temporary_cache["_ws_favicon"][result.url] = result.favicon
# ret = "\n".join(ret_ls)
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
return ret
@llm_tool("tavily_extract_web_page")
async def tavily_extract_web_page(
self,
event: AstrMessageEvent,
url: str = "",
extract_depth: str = "basic",
) -> str:
"""Extract the content of a web page using Tavily.
Args:
url(string): Required. An URl to extract content from.
extract_depth(string): Optional. The depth of the extraction, must be one of 'basic', 'advanced'. Default is "basic".
"""
cfg = self.context.get_config(umo=event.unified_msg_origin)
if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []):
raise ValueError("Error: Tavily API key is not configured in AstrBot.")
if not url:
raise ValueError("Error: url must be a non-empty string.")
if extract_depth not in ["basic", "advanced"]:
extract_depth = "basic"
payload = {
"urls": [url],
"extract_depth": extract_depth,
}
results = await self._extract_tavily(cfg, payload)
ret_ls = []
for result in results:
ret_ls.append(f"URL: {result.get('url', 'No URL')}")
ret_ls.append(f"Content: {result.get('raw_content', 'No content')}")
ret = "\n".join(ret_ls)
if not ret:
return "Error: Tavily web searcher does not return any results."
return ret
async def _get_bocha_key(self, cfg: AstrBotConfig) -> str:
"""并发安全的从列表中获取并轮换BoCha API密钥。"""
bocha_keys = cfg.get("provider_settings", {}).get("websearch_bocha_key", [])
if not bocha_keys:
raise ValueError("错误:BoCha API密钥未在AstrBot中配置。")
async with self.bocha_key_lock:
key = bocha_keys[self.bocha_key_index]
self.bocha_key_index = (self.bocha_key_index + 1) % len(bocha_keys)
return key
async def _web_search_bocha(
self,
cfg: AstrBotConfig,
payload: dict,
) -> list[SearchResult]:
"""使用 BoCha 搜索引擎进行搜索"""
bocha_key = await self._get_bocha_key(cfg)
url = "https://api.bochaai.com/v1/web-search"
header = {
"Authorization": f"Bearer {bocha_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
url,
json=payload,
headers=header,
) as response:
if response.status != 200:
reason = await response.text()
raise Exception(
f"BoCha web search failed: {reason}, status: {response.status}",
)
data = await response.json()
data = data["data"]["webPages"]["value"]
results = []
for item in data:
result = SearchResult(
title=item.get("name"),
url=item.get("url"),
snippet=item.get("snippet"),
favicon=item.get("siteIcon"),
)
results.append(result)
return results
@llm_tool("web_search_bocha")
async def search_from_bocha(
self,
event: AstrMessageEvent,
query: str,
freshness: str = "noLimit",
summary: bool = False,
include: str = "",
exclude: str = "",
count: int = 10,
) -> str:
"""A web search tool based on Bocha Search API, used to retrieve web pages
related to the user's query.
Args:
query (string): Required. User's search query.
freshness (string): Optional. Specifies the time range of the search.
Supported values:
- "noLimit": No time limit (default, recommended).
- "oneDay": Within one day.
- "oneWeek": Within one week.
- "oneMonth": Within one month.
- "oneYear": Within one year.
- "YYYY-MM-DD..YYYY-MM-DD": Search within a specific date range.
Example: "2025-01-01..2025-04-06".
- "YYYY-MM-DD": Search on a specific date.
Example: "2025-04-06".
It is recommended to use "noLimit", as the search algorithm will
automatically optimize time relevance. Manually restricting the
time range may result in no search results.
summary (boolean): Optional. Whether to include a text summary
for each search result.
- True: Include summary.
- False: Do not include summary (default).
include (string): Optional. Specifies the domains to include in
the search. Multiple domains can be separated by "|" or ",".
A maximum of 100 domains is allowed.
Examples:
- "qq.com"
- "qq.com|m.163.com"
exclude (string): Optional. Specifies the domains to exclude from
the search. Multiple domains can be separated by "|" or ",".
A maximum of 100 domains is allowed.
Examples:
- "qq.com"
- "qq.com|m.163.com"
count (number): Optional. Number of search results to return.
- Range: 150
- Default: 10
The actual number of returned results may be less than the
specified count.
"""
logger.info(f"web_searcher - search_from_bocha: {query}")
cfg = self.context.get_config(umo=event.unified_msg_origin)
# websearch_link = cfg["provider_settings"].get("web_search_link", False)
if not cfg.get("provider_settings", {}).get("websearch_bocha_key", []):
raise ValueError("Error: BoCha API key is not configured in AstrBot.")
# build payload
payload = {
"query": query,
"count": count,
}
# freshness:时间范围
if freshness:
payload["freshness"] = freshness
# 是否返回摘要
payload["summary"] = summary
# include:限制搜索域
if include:
payload["include"] = include
# exclude:排除搜索域
if exclude:
payload["exclude"] = exclude
results = await self._web_search_bocha(cfg, payload)
if not results:
return "Error: BoCha web searcher does not return any results."
ret_ls = []
ref_uuid = str(uuid.uuid4())[:4]
for idx, result in enumerate(results, 1):
index = f"{ref_uuid}.{idx}"
ret_ls.append(
{
"title": f"{result.title}",
"url": f"{result.url}",
"snippet": f"{result.snippet}",
"index": index,
},
)
if result.favicon:
sp.temporary_cache["_ws_favicon"][result.url] = result.favicon
# ret = "\n".join(ret_ls)
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
return ret
@filter.on_llm_request(priority=-10000)
async def edit_web_search_tools(
self,
event: AstrMessageEvent,
req: ProviderRequest,
) -> None:
"""Get the session conversation for the given event."""
cfg = self.context.get_config(umo=event.unified_msg_origin)
prov_settings = cfg.get("provider_settings", {})
websearch_enable = prov_settings.get("web_search", False)
raw_provider = prov_settings.get(
"websearch_provider",
DEFAULT_WEB_SEARCH_PROVIDER,
)
branch_provider, is_known_provider = normalize_websearch_provider_for_tools(
raw_provider,
)
tool_set = req.func_tool
if isinstance(tool_set, FunctionToolManager):
req.func_tool = tool_set.get_full_tool_set()
tool_set = req.func_tool
if not tool_set:
return
if not websearch_enable:
# pop tools
for tool_name in self.TOOLS:
tool_set.remove_tool(tool_name)
return
func_tool_mgr = self.context.get_llm_tool_manager()
if branch_provider == "default":
if not is_known_provider:
logger.warning(
"Unsupported websearch_provider `%s`, fallback to default search tool branch.",
raw_provider,
)
web_search_t = func_tool_mgr.get_func("web_search")
fetch_url_t = func_tool_mgr.get_func("fetch_url")
if web_search_t and web_search_t.active:
tool_set.add_tool(web_search_t)
if fetch_url_t and fetch_url_t.active:
tool_set.add_tool(fetch_url_t)
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")
tool_set.remove_tool("AIsearch")
tool_set.remove_tool("web_search_bocha")
elif branch_provider == "tavily":
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
if web_search_tavily and web_search_tavily.active:
tool_set.add_tool(web_search_tavily)
if tavily_extract_web_page and tavily_extract_web_page.active:
tool_set.add_tool(tavily_extract_web_page)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("AIsearch")
tool_set.remove_tool("web_search_bocha")
elif branch_provider == "baidu_ai_search":
try:
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
aisearch_tool = func_tool_mgr.get_func("AIsearch")
if aisearch_tool and aisearch_tool.active:
tool_set.add_tool(aisearch_tool)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")
tool_set.remove_tool("web_search_bocha")
except Exception as e:
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")
elif branch_provider == "bocha":
web_search_bocha = func_tool_mgr.get_func("web_search_bocha")
if web_search_bocha and web_search_bocha.active:
tool_set.add_tool(web_search_bocha)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("AIsearch")
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")

View File

@@ -1,24 +0,0 @@
from __future__ import annotations
DEFAULT_WEB_SEARCH_PROVIDER = "default"
# Canonical provider ids shown in config UI options.
WEB_SEARCH_PROVIDER_OPTIONS: tuple[str, ...] = (
DEFAULT_WEB_SEARCH_PROVIDER,
"duckduckgo",
"google",
"bing",
"comet",
"sogo",
"tavily",
"baidu_ai_search",
"bocha",
)
# Provider ids that select non-default tool branches directly.
WEB_SEARCH_TOOL_BRANCH_PROVIDERS: tuple[str, ...] = (
DEFAULT_WEB_SEARCH_PROVIDER,
"tavily",
"baidu_ai_search",
"bocha",
)

View File

@@ -1,132 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from .engines.bing import Bing
from .engines.comet import Comet
from .engines.duckduckgo import DuckDuckGo
from .engines.google import Google
from .engines.sogo import Sogo
from .provider_constants import (
DEFAULT_WEB_SEARCH_PROVIDER,
WEB_SEARCH_PROVIDER_OPTIONS,
WEB_SEARCH_TOOL_BRANCH_PROVIDERS,
)
ENGINE_REGISTRY: tuple[tuple[str, type[object], bool], ...] = (
(Bing.NAME, Bing, True),
(Sogo.NAME, Sogo, True),
# Compatibility first: DDG should stay as fallback and cannot become primary.
(DuckDuckGo.NAME, DuckDuckGo, False),
(Google.NAME, Google, True),
(Comet.NAME, Comet, True),
)
DEFAULT_ENGINE_ORDER: tuple[str, ...] = tuple(name for name, _, _ in ENGINE_REGISTRY)
_ENGINE_PROVIDER_SET = {name for name, _, _ in ENGINE_REGISTRY}
_ENGINE_CAN_BE_PRIMARY = {
name: can_be_primary for name, _, can_be_primary in ENGINE_REGISTRY
}
_TOOL_BRANCH_PROVIDER_SET = set(WEB_SEARCH_TOOL_BRANCH_PROVIDERS)
_CANONICAL_PROVIDER_SET = _ENGINE_PROVIDER_SET | _TOOL_BRANCH_PROVIDER_SET
if not _CANONICAL_PROVIDER_SET.issubset(set(WEB_SEARCH_PROVIDER_OPTIONS)):
raise RuntimeError(
"web search provider options and routing providers are out of sync: "
f"canonical={sorted(_CANONICAL_PROVIDER_SET)} options={list(WEB_SEARCH_PROVIDER_OPTIONS)}",
)
_WEB_SEARCH_PROVIDER_ALIASES: dict[str, str] = {
"": DEFAULT_WEB_SEARCH_PROVIDER,
"default": DEFAULT_WEB_SEARCH_PROVIDER,
"native": DEFAULT_WEB_SEARCH_PROVIDER,
}
_WEB_SEARCH_PROVIDER_ALIASES.update({name: name for name in _CANONICAL_PROVIDER_SET})
_WEB_SEARCH_PROVIDER_ALIASES.update(
{
"duckduck_go": DuckDuckGo.NAME,
"duckduck-go": DuckDuckGo.NAME,
"ddg": DuckDuckGo.NAME,
"baidu_ai": "baidu_ai_search",
"baidu": "baidu_ai_search",
"bochaai": "bocha",
# ZeroClaw compatibility: AstrBot has no Brave provider yet, so downgrade to default.
"brave": DEFAULT_WEB_SEARCH_PROVIDER,
},
)
@dataclass(frozen=True)
class NormalizedProvider:
canonical: str
tool_branch: str
is_known: bool
def _normalize_raw_provider(provider: object) -> str:
return str(provider or "").strip().lower().replace(" ", "")
def normalize_websearch(provider: object) -> NormalizedProvider:
raw = _normalize_raw_provider(provider)
alias = _WEB_SEARCH_PROVIDER_ALIASES.get(raw, raw)
canonical = alias or DEFAULT_WEB_SEARCH_PROVIDER
is_engine = canonical in _ENGINE_PROVIDER_SET
is_tool_branch = canonical in _TOOL_BRANCH_PROVIDER_SET
is_known = is_engine or is_tool_branch
tool_branch = canonical if is_tool_branch else DEFAULT_WEB_SEARCH_PROVIDER
return NormalizedProvider(
canonical=canonical,
tool_branch=tool_branch,
is_known=is_known,
)
def normalize_websearch_provider(provider: object) -> str:
return normalize_websearch(provider).canonical
def normalize_websearch_provider_for_tools(provider: object) -> tuple[str, bool]:
normalized = normalize_websearch(provider)
return normalized.tool_branch, normalized.is_known
def resolve_tool_branch_provider(provider: object) -> str:
return normalize_websearch(provider).tool_branch
def build_default_engine_order(provider: object) -> tuple[str, ...]:
normalized = normalize_websearch(provider)
engine_name = normalized.canonical
if engine_name not in _ENGINE_PROVIDER_SET:
return DEFAULT_ENGINE_ORDER
if not _ENGINE_CAN_BE_PRIMARY.get(engine_name, False):
return DEFAULT_ENGINE_ORDER
return (
engine_name,
*tuple(name for name in DEFAULT_ENGINE_ORDER if name != engine_name),
)
def is_known_websearch_provider(provider: object) -> bool:
return normalize_websearch(provider).is_known
def validate_default_engine_registry(engines_by_name: Mapping[str, object]) -> None:
expected_names = {name for name, _, _ in ENGINE_REGISTRY}
missing = [name for name in DEFAULT_ENGINE_ORDER if name not in engines_by_name]
extra = [name for name in engines_by_name if name not in expected_names]
if not missing and not extra:
return
raise ValueError(
"default search engine registry mismatch. "
f"missing={missing}, extra={extra}, expected_order={list(DEFAULT_ENGINE_ORDER)}",
)

View File

@@ -1,128 +1,48 @@
"""AstrBot CLI entry point"""
import os
import platform
import sys
from pathlib import Path
import click
from click.shell_completion import get_completion_class
from . import __version__
from .commands import bk, config, init, password, plugin, run, service, uninstall
from .i18n import t
from .commands import conf, init, password, plug, run
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
def print_version_detail() -> None:
"""Print detailed version info (same for --version and version command)"""
from astrbot.core.utils.astrbot_path import astrbot_paths
click.echo(f"AstrBot: {__version__}")
click.echo(f"Python: {sys.version.split()[0]}")
click.echo(f"System: {platform.system()} {platform.release()}")
click.echo(f"Machine: {platform.machine()}")
git_root = Path(astrbot_paths.root) / ".git"
if git_root.exists():
import subprocess
try:
git_hash = subprocess.check_output(
["git", "rev-parse", "--short", "HEAD"],
cwd=str(astrbot_paths.root),
text=True,
).strip()
git_branch = subprocess.check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
cwd=str(astrbot_paths.root),
text=True,
).strip()
click.echo(f"Git Branch: {git_branch}")
click.echo(f"Git Commit: {git_hash}")
except Exception:
pass
click.echo(f"AstrBot Root: {astrbot_paths.root}")
click.echo(f"Platform: {platform.platform()}")
def version_callback(ctx: click.Context, param: click.Parameter, value: bool) -> bool:
"""Callback for --version to show detailed version and exit."""
if not value:
return value
print_version_detail()
ctx.exit()
return value
class AstrBotCLIGroup(click.Group):
COMMAND_ALIASES = {
"conf": "config",
"plug": "plugin",
}
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
command = super().get_command(ctx, cmd_name)
if command is not None:
return command
alias_target = self.COMMAND_ALIASES.get(cmd_name)
if alias_target is None:
return None
return super().get_command(ctx, alias_target)
@click.group(cls=AstrBotCLIGroup)
@click.group()
@click.version_option(__version__, prog_name="AstrBot")
def cli() -> None:
"""Astrbot
Agentic IM Chatbot infrastructure that integrates lots of IM platforms, LLMs, plugins and AI feature, and can be your openclaw alternative. ✨
"""
"""The AstrBot CLI"""
click.echo(logo_tmpl)
click.echo("Welcome to AstrBot CLI!")
click.echo(f"AstrBot CLI version: {__version__}")
@click.command()
@click.argument("command_name", required=False, type=str)
@click.option(
"--all",
"-a",
is_flag=True,
help="Show help for all commands recursively.",
)
def help(command_name: str | None, all: bool) -> None:
def help(command_name: str | None) -> None:
"""Display help information for commands
If COMMAND_NAME is provided, display detailed help for that command.
Otherwise, display general help information.
"""
ctx = click.get_current_context()
if all:
def print_recursive_help(command, parent_ctx):
name = command.name
if parent_ctx is None:
name = "astrbot"
cmd_ctx = click.Context(command, info_name=name, parent=parent_ctx)
click.echo(command.get_help(cmd_ctx))
click.echo("\n" + "-" * 50 + "\n")
if isinstance(command, click.Group):
for subcommand in command.commands.values():
print_recursive_help(subcommand, cmd_ctx)
print_recursive_help(cli, None)
return
if command_name:
# Find the specified command
command = cli.get_command(ctx, command_name)
if command:
# Display help for the specific command
parent = ctx.parent or ctx
cmd_ctx = click.Context(command, info_name=command.name, parent=parent)
click.echo(command.get_help(cmd_ctx))
click.echo(command.get_help(ctx))
else:
click.echo(t("cli_unknown_command", command=command_name))
click.echo(f"Unknown command: {command_name}")
sys.exit(1)
else:
# Display general help information
@@ -132,56 +52,9 @@ def help(command_name: str | None, all: bool) -> None:
cli.add_command(init)
cli.add_command(run)
cli.add_command(help)
cli.add_command(plugin)
cli.add_command(config)
cli.add_command(uninstall)
cli.add_command(bk)
cli.add_command(plug)
cli.add_command(conf)
cli.add_command(password)
cli.add_command(service)
@click.command()
@click.argument("shell", required=False, type=click.Choice(["bash", "zsh", "fish"]))
def completion(shell: str | None) -> None:
"""Generate shell completion script"""
if shell is None:
shell_path = os.environ.get("SHELL", "")
if "zsh" in shell_path:
shell = "zsh"
elif "bash" in shell_path:
shell = "bash"
elif "fish" in shell_path:
shell = "fish"
else:
click.echo(
"Could not detect shell. Please specify one of: bash, zsh, fish",
err=True,
)
sys.exit(1)
comp_cls = get_completion_class(shell)
if comp_cls is None:
click.echo(f"No completion support for shell: {shell}", err=True)
sys.exit(1)
comp = comp_cls(
cli,
ctx_args={},
prog_name="astrbot",
complete_var="_ASTRBOT_COMPLETE",
)
click.echo(comp.source())
cli.add_command(completion)
@click.command(name="version")
def version_cmd() -> None:
"""Display detailed version information"""
print_version_detail()
cli.add_command(version_cmd)
if __name__ == "__main__":
cli()

View File

@@ -1,28 +0,0 @@
"""ASCII logo and interactive mode utilities for CLI"""
import sys
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
def is_interactive() -> bool:
"""Check if stdout is connected to a TTY (interactive terminal)"""
try:
return sys.stdout.isatty()
except Exception:
return False
def print_logo() -> None:
"""Print ASCII logo if in interactive mode"""
import click
if is_interactive():
click.echo(logo_tmpl)

View File

@@ -1,24 +1,7 @@
from .cmd_bk import bk
from .cmd_conf import conf as config
from .cmd_conf import conf
from .cmd_init import init
from .cmd_password import password
from .cmd_plug import plug as plugin
from .cmd_plug import plug
from .cmd_run import run
from .cmd_service import service
from .cmd_uninstall import uninstall
conf = config
plug = plugin
__all__ = [
"bk",
"conf",
"config",
"init",
"password",
"plug",
"plugin",
"run",
"service",
"uninstall",
]
__all__ = ["conf", "init", "password", "plug", "run"]

View File

@@ -1,392 +0,0 @@
import asyncio
import hashlib
import shutil
import subprocess
from pathlib import Path
import anyio
import click
from astrbot.core import db_helper
from astrbot.core.backup import AstrBotExporter, AstrBotImporter
async def _get_kb_manager():
"""Initialize and return a KnowledgeBaseManager with full dependency chain."""
from astrbot.core import astrbot_config, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.umop_config_router import UmopConfigRouter
ucr = UmopConfigRouter(sp=sp)
await ucr.initialize()
acm = AstrBotConfigManager(
default_config=astrbot_config,
ucr=ucr,
sp=sp,
)
persona_mgr = PersonaManager(db_helper, acm)
await persona_mgr.initialize()
provider_manager = ProviderManager(
acm,
db_helper,
persona_mgr,
)
kb_manager = KnowledgeBaseManager(provider_manager)
await kb_manager.initialize()
return kb_manager
@click.group(name="bk")
def bk():
"""Backup management (Export/Import)"""
@bk.command(name="export")
@click.option("--output", "-o", help="Output directory", default=None)
@click.option(
"--gpg-sign",
"-S",
is_flag=True,
help="Sign backup with GPG default private key",
)
@click.option(
"--gpg-encrypt",
"-E",
help="Encrypt for GPG recipient (Asymmetric)",
metavar="RECIPIENT",
)
@click.option(
"--gpg-symmetric",
"-C",
is_flag=True,
help="Encrypt with symmetric cipher (GPG)",
)
@click.option(
"--digest",
"-d",
type=click.Choice(["md5", "sha1", "sha256", "sha512"]),
help="Generate digital digest",
)
def export_data(
output: str | None,
gpg_sign: bool,
gpg_encrypt: str | None,
gpg_symmetric: bool,
digest: str | None,
):
"""Export all AstrBot data to a backup archive.
If any GPG option (-S, -E, -C) is used, the output file will be processed by GPG
and saved with a .gpg extension.
Examples:
\b
1. Standard Export:
astrbot bk export
-> Generates a plain .zip file.
\b
2. Signed Backup (Integrity Check):
astrbot bk export -S
-> Generates a .zip.gpg file containing the backup and your signature.
-> NOT ENCRYPTED, but packaged in OpenPGP format.
-> Use 'astrbot bk import' or 'gpg --verify' to check integrity.
\b
3. Password Protected (Symmetric Encryption):
astrbot bk export -C
-> Generates an encrypted .zip.gpg file.
-> Prompts for a passphrase.
-> Only accessible with the passphrase.
\b
4. Encrypted for Recipient (Asymmetric Encryption):
astrbot bk export -E "alice@example.com"
-> Generates an encrypted .zip.gpg file for Alice.
-> Only Alice's private key can decrypt it.
\b
5. Signed and Encrypted with Digest:
astrbot bk export -S -E "bob@example.com" -d sha256
-> Signs, encrypts for Bob, and generates a SHA256 checksum file.
"""
# Handle case where -E consumes the next flag (e.g. -E -S)
if gpg_encrypt and gpg_encrypt.startswith("-"):
consumed_flag = gpg_encrypt
click.echo(
click.style(
f"Warning: Flag '{consumed_flag}' was interpreted as the recipient for -E.",
fg="yellow",
),
)
# Recover flags
if consumed_flag == "-S":
gpg_sign = True
click.echo("Recovered flag -S (Sign).")
elif consumed_flag == "-C":
gpg_symmetric = True
click.echo("Recovered flag -C (Symmetric).")
# Prompt for the actual recipient
gpg_encrypt = click.prompt("Please enter the GPG recipient (email or key ID)")
async def _run():
if gpg_sign or gpg_encrypt or gpg_symmetric:
if not shutil.which("gpg"):
raise click.ClickException(
"GPG tool not found. Please install GnuPG to use encryption/signing features.",
)
exporter = AstrBotExporter(db_helper)
async def on_progress(stage, current, total, message):
click.echo(f"[{stage}] {message}")
try:
path_str = await exporter.export_all(output, progress_callback=on_progress)
final_path = Path(path_str)
click.echo(
click.style(f"\nRaw backup exported to: {final_path}", fg="green"),
)
# GPG Operations
if gpg_sign or gpg_encrypt or gpg_symmetric:
# Construct GPG command
# output file usually ends with .gpg
gpg_output = final_path.with_name(final_path.name + ".gpg")
cmd = ["gpg", "--output", str(gpg_output), "--yes"]
if gpg_symmetric:
if gpg_encrypt:
click.echo(
click.style(
"Warning: Symmetric encryption selected, ignoring asymmetric recipient.",
fg="yellow",
),
)
cmd.append("--symmetric")
# No --batch to allow interactive passphrase entry on TTY
else:
# Asymmetric or just Sign
# Note: If encrypting, -s adds signature to the encrypted packet.
if gpg_encrypt:
cmd.extend(["--encrypt", "--recipient", gpg_encrypt])
if gpg_sign:
cmd.append("--sign")
cmd.append(str(final_path))
click.echo(f"Running GPG: {' '.join(cmd)}")
# Replace subprocess.run with asyncio.create_subprocess_exec to avoid blocking the event loop
process = await asyncio.create_subprocess_exec(*cmd)
await process.wait()
if process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode or 1, cmd)
# Clean up original file
await anyio.Path(final_path).unlink()
final_path = gpg_output
click.echo(
click.style(f"Processed backup created: {final_path}", fg="green"),
)
# Digest Generation
if digest:
click.echo(f"Calculating {digest} digest...")
hash_func = getattr(hashlib, digest)()
# Read file in chunks
async with await anyio.open_file(final_path, "rb") as f:
while chunk := await f.read(8192):
hash_func.update(chunk)
digest_val = hash_func.hexdigest()
digest_file = final_path.with_name(final_path.name + f".{digest}")
await anyio.Path(digest_file).write_text(
f"{digest_val} *{final_path.name}\n",
encoding="utf-8",
)
click.echo(click.style(f"Digest generated: {digest_file}", fg="green"))
except subprocess.CalledProcessError as e:
click.echo(click.style(f"\nGPG process failed: {e}", fg="red"), err=True)
except Exception as e:
click.echo(click.style(f"\nExport failed: {e}", fg="red"), err=True)
asyncio.run(_run())
@bk.command(name="import")
@click.argument("backup_file")
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompts")
def import_data_command(backup_file: str, yes: bool):
"""Import AstrBot data from a backup archive.
Automatically handles .zip files and .gpg files (signed or encrypted).
If the file is encrypted, you will be prompted for the passphrase.
If a digest file (.sha256, .md5, etc.) exists, it will be verified automatically.
"""
backup_path = Path(backup_file)
if not backup_path.exists():
raise click.ClickException(f"Backup file not found: {backup_file}")
# 1. Verify Digest if exists
def _verify_digest(file_path: Path) -> bool:
supported_digests = ["sha256", "sha512", "md5", "sha1"]
digest_verified = True # Default true if no digest file found
for algo in supported_digests:
digest_file = file_path.with_name(f"{file_path.name}.{algo}")
if digest_file.exists():
click.echo(f"Found digest file: {digest_file.name}")
try:
# Parse digest file
content = digest_file.read_text(encoding="utf-8").strip()
# Format: "digest *filename" or "digest filename"
# We expect the hash to be the first part
if " " in content:
expected_digest = content.split()[0].lower()
else:
expected_digest = content.lower()
click.echo(f"Verifying {algo} digest...")
hash_func = getattr(hashlib, algo)()
with open(file_path, "rb") as f:
while chunk := f.read(8192):
hash_func.update(chunk)
calculated_digest = hash_func.hexdigest().lower()
if calculated_digest == expected_digest:
click.echo(
click.style("Digest verification PASSED.", fg="green"),
)
else:
click.echo(
click.style(
"Digest verification FAILED!",
fg="red",
bold=True,
),
)
click.echo(f" Expected: {expected_digest}")
click.echo(f" Actual: {calculated_digest}")
digest_verified = False
except Exception as e:
click.echo(click.style(f"Error checking digest: {e}", fg="red"))
digest_verified = False
return digest_verified
if not _verify_digest(backup_path):
if not yes:
if not click.confirm(
"Digest verification failed. Abort import?",
default=True,
abort=True,
):
pass
else:
click.echo(
click.style(
"Warning: Digest verification failed. Continuing due to --yes.",
fg="yellow",
),
)
if not yes:
click.confirm(
"This will OVERWRITE all current data (DB, Config, Plugins). Continue?",
abort=True,
default=False,
)
async def _run():
zip_path = backup_path
is_temp_file = False
# Handle GPG encrypted files
if backup_path.suffix == ".gpg":
if not shutil.which("gpg"):
raise click.ClickException(
"GPG tool not found. Cannot decrypt .gpg file.",
)
# Remove .gpg extension for output
decrypted_path = backup_path.with_suffix("")
# If it doesn't look like a zip after stripping .gpg, maybe append .zip?
# But the exporter creates .zip.gpg, so stripping .gpg gives .zip.
click.echo(f"Processing GPG file {backup_path}...")
try:
cmd = [
"gpg",
"--output",
str(decrypted_path),
"--decrypt", # This handles both decryption and signature verification/extraction
str(backup_path),
]
# Allow interactive passphrase
process = await asyncio.create_subprocess_exec(*cmd)
await process.wait()
if process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode or 1, cmd)
zip_path = decrypted_path
is_temp_file = True
except subprocess.CalledProcessError:
click.echo(
click.style(
"GPG processing failed. Verify signature or decryption key.",
fg="red",
),
err=True,
)
return
kb_mgr = await _get_kb_manager()
importer = AstrBotImporter(db_helper, kb_mgr)
async def on_progress(stage, current, total, message):
click.echo(f"[{stage}] {message}")
try:
result = await importer.import_all(
str(zip_path),
progress_callback=on_progress,
)
if result.errors:
click.echo(
click.style("\nImport failed with errors:", fg="red"),
err=True,
)
for err in result.errors:
click.echo(f" - {err}", err=True)
else:
click.echo(click.style("\nImport completed successfully!", fg="green"))
if result.warnings:
click.echo(click.style("\nWarnings:", fg="yellow"))
for warn in result.warnings:
click.echo(f" - {warn}")
finally:
if is_temp_file and await anyio.Path(zip_path).exists():
await anyio.Path(zip_path).unlink()
click.echo(f"Cleaned up temporary file: {zip_path}")
asyncio.run(_run())

View File

@@ -1,95 +1,73 @@
"""Configuration CLI for AstrBot.
This module provides:
- secure hashing utilities for the dashboard password (argon2)
- validators for commonly configurable items
- click CLI group with `set`, `get`, and `password` subcommands
"""
from __future__ import annotations
import json
import zoneinfo
from collections.abc import Callable
from typing import Any
import click
from filelock import FileLock, Timeout
from astrbot.cli.i18n import t
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.utils.astrbot_path import astrbot_paths
from astrbot.core.utils.auth_password import (
_is_argon2_hash,
_is_pbkdf2_hash,
hash_dashboard_password,
hash_legacy_dashboard_password,
is_legacy_dashboard_password,
validate_dashboard_password,
)
# --- Password hashing & validation utilities ---
def is_dashboard_password_hash(value: str) -> bool:
"""Heuristic: return True if `value` looks like a supported dashboard password hash."""
if not isinstance(value, str) or not value:
return False
return _is_argon2_hash(value) or _is_pbkdf2_hash(value)
# --- Validators for CLI configuration items ---
from ..utils import check_astrbot_root, get_astrbot_root
def _validate_log_level(value: str) -> str:
value_up = value.upper()
allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
if value_up not in allowed:
raise click.ClickException(t("config_log_level_invalid"))
return value_up
"""Validate log level"""
value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException(
"Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
)
return value
def _validate_dashboard_port(value: str) -> int:
"""Validate Dashboard port"""
try:
port = int(value)
if port < 1 or port > 65535:
raise click.ClickException("Port must be in range 1-65535")
return port
except ValueError:
raise click.ClickException(t("config_port_must_be_number")) from None
if port < 1 or port > 65535:
raise click.ClickException(t("config_port_range_invalid"))
return port
raise click.ClickException("Port must be a number")
def _validate_dashboard_username(value: str) -> str:
if value is None or value.strip() == "":
raise click.ClickException(t("config_username_empty"))
return value.strip()
"""Validate Dashboard username"""
if not value:
raise click.ClickException("Username cannot be empty")
return value
def _validate_dashboard_password(value: str) -> str:
if value is None or value == "":
raise click.ClickException(t("config_password_empty"))
"""Validate Dashboard password"""
from astrbot.core.utils.auth_password import validate_dashboard_password
try:
validate_dashboard_password(value)
except ValueError as e:
raise click.ClickException(str(e)) from e
# Return the plaintext value; callers hash it before storage.
raise click.ClickException(str(e))
return value
def _validate_timezone(value: str) -> str:
"""Validate timezone"""
try:
zoneinfo.ZoneInfo(value)
except Exception as e:
raise click.ClickException(t("config_timezone_invalid", value=value)) from e
except Exception:
raise click.ClickException(
f"Invalid timezone: {value}. Please use a valid IANA timezone name"
)
return value
def _validate_callback_api_base(value: str) -> str:
if not (value.startswith("http://") or value.startswith("https://")):
raise click.ClickException(t("config_callback_invalid"))
"""Validate callback API base URL"""
if not value.startswith("http://") and not value.startswith("https://"):
raise click.ClickException(
"Callback API base must start with http:// or https://"
)
return value
# Configuration items settable via CLI, mapping config keys to validator functions
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
"timezone": _validate_timezone,
"log_level": _validate_log_level,
@@ -100,22 +78,18 @@ CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
}
# --- Config file helpers ---
def _load_config() -> dict[str, Any]:
"""Load or initialize the CLI config file (data/cmd_config.json).
Ensures the astrbot root is valid before proceeding.
"""
root = astrbot_paths.root
if not astrbot_paths.is_root:
"""Load or initialize config file"""
root = get_astrbot_root()
if not check_astrbot_root(root):
raise click.ClickException(
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
config_path = astrbot_paths.data / "cmd_config.json"
config_path = root / "data" / "cmd_config.json"
if not config_path.exists():
# Write DEFAULT_CONFIG to disk if file missing
from astrbot.core.config.default import DEFAULT_CONFIG
config_path.write_text(
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
@@ -124,98 +98,48 @@ def _load_config() -> dict[str, Any]:
try:
return json.loads(config_path.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError as e:
raise click.ClickException(f"Failed to parse config file: {e!s}") from e
raise click.ClickException(f"Failed to parse config file: {e!s}")
def _save_config(config: dict[str, Any]) -> None:
config_path = astrbot_paths.data / "cmd_config.json"
"""Save config file"""
config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
)
def ensure_config_file() -> dict[str, Any]:
return _load_config()
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
"""Set a value in a nested dictionary"""
parts = path.split(".")
cur = obj
for part in parts[:-1]:
if part not in cur:
cur[part] = {}
elif not isinstance(cur[part], dict):
if part not in obj:
obj[part] = {}
elif not isinstance(obj[part], dict):
raise click.ClickException(
f"Config path conflict: {'.'.join(parts[: parts.index(part) + 1])} is not a dict",
)
cur = cur[part]
cur[parts[-1]] = value
obj = obj[part]
obj[parts[-1]] = value
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
"""Get a value from a nested dictionary"""
parts = path.split(".")
cur = obj
for part in parts:
cur = cur[part]
return cur
# --- CLI commands ---
def prompt_dashboard_password(prompt: str = "Dashboard password") -> str:
# 显示密码规则提示
click.echo()
click.echo("密码规则:")
click.echo(" - 至少 12 个字符")
click.echo(" - 必须包含至少一个大写字母")
click.echo(" - 必须包含至少一个小写字母")
click.echo(" - 必须包含至少一个数字")
click.echo()
password = click.prompt(prompt, hide_input=True, confirmation_prompt=True, type=str)
click.echo(f"密码长度: {len(password)} 字符")
return _validate_dashboard_password(password)
def set_dashboard_credentials(
config: dict[str, Any],
*,
username: str | None = None,
password_hash: str | None = None,
) -> None:
if username is not None:
_set_nested_item(
config,
"dashboard.username",
_validate_dashboard_username(username),
)
if password_hash is not None:
if isinstance(password_hash, str) and is_dashboard_password_hash(password_hash):
_set_nested_item(config, "dashboard.password", password_hash)
else:
if is_legacy_dashboard_password(password_hash):
raise click.ClickException(
"Storing legacy dashboard password hashes is no longer supported. "
"Please provide the plaintext password (it will be hashed securely), "
"or provide an Argon2-encoded hash string.",
)
validated = _validate_dashboard_password(password_hash)
_set_nested_item(
config,
"dashboard.pbkdf2_password",
hash_dashboard_password(validated),
)
_set_nested_item(
config,
"dashboard.password",
hash_legacy_dashboard_password(validated),
)
obj = obj[part]
return obj
def _set_dashboard_password(config: dict[str, Any], raw_password: str) -> None:
"""Set dashboard password hashes and clear password migration flags."""
from astrbot.core.utils.auth_password import (
hash_dashboard_password,
hash_md5_dashboard_password,
)
_set_nested_item(
config,
"dashboard.pbkdf2_password",
@@ -224,23 +148,29 @@ def _set_dashboard_password(config: dict[str, Any], raw_password: str) -> None:
_set_nested_item(
config,
"dashboard.password",
hash_legacy_dashboard_password(raw_password),
hash_md5_dashboard_password(raw_password),
)
_set_nested_item(config, "dashboard.password_storage_upgraded", True)
_set_nested_item(config, "dashboard.password_change_required", False)
@click.group(name="config")
@click.group(name="conf")
def conf() -> None:
"""Configuration management commands.
"""Configuration management commands
Supported config keys:
- timezone
- log_level
- dashboard.port
- dashboard.username
- dashboard.password
- callback_api_base
- timezone: Timezone setting (e.g. Asia/Shanghai)
- log_level: Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- dashboard.port: Dashboard port
- dashboard.username: Dashboard username
- dashboard.password: Dashboard password
- callback_api_base: Callback API base URL
"""
@@ -248,17 +178,14 @@ def conf() -> None:
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str) -> None:
"""Set the value of a config item"""
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"Unsupported config key: {key}")
config = _load_config()
try:
# Attempt to get old value (may raise KeyError)
try:
old_value = _get_nested_item(config, key)
except Exception:
old_value = "<not set>"
try:
old_value = _get_nested_item(config, key)
validated_value = CONFIG_VALIDATORS[key](value)
if key == "dashboard.password":
_set_dashboard_password(config, validated_value)
@@ -267,103 +194,47 @@ def set_config(key: str, value: str) -> None:
_save_config(config)
click.echo(f"Config updated: {key}")
click.echo(f" Old value: {old_value}")
click.echo(f" New value: {validated_value}")
except KeyError as e:
raise click.ClickException(f"Unknown config key: {key}") from e
except click.ClickException:
raise
if key == "dashboard.password":
click.echo(" Old value: ********")
click.echo(" New value: ********")
else:
click.echo(f" Old value: {old_value}")
click.echo(f" New value: {validated_value}")
except KeyError:
raise click.ClickException(f"Unknown config key: {key}")
except Exception as e:
raise click.UsageError(f"Failed to set config: {e!s}") from e
raise click.UsageError(f"Failed to set config: {e!s}")
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str | None = None) -> None:
"""Get the value of a config item. If no key is provided, show all configurable items"""
config = _load_config()
if key:
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"Unsupported config key: {key}")
try:
value = _get_nested_item(config, key)
if key == "dashboard.password":
value = "********"
click.echo(f"{key}: {value}")
except KeyError as e:
raise click.ClickException(f"Unknown config key: {key}") from e
except KeyError:
raise click.ClickException(f"Unknown config key: {key}")
except Exception as e:
raise click.UsageError(f"Failed to get config: {e!s}") from e
raise click.UsageError(f"Failed to get config: {e!s}")
else:
click.echo("Current config:")
for k in CONFIG_VALIDATORS:
for key in CONFIG_VALIDATORS:
try:
v = (
value = (
"********"
if k == "dashboard.password"
else _get_nested_item(config, k)
if key == "dashboard.password"
else _get_nested_item(config, key)
)
click.echo(f" {k}: {v}")
click.echo(f" {key}: {value}")
except (KeyError, TypeError):
# Missing or non-dict paths are simply skipped in listing
pass
def _check_astrbot_not_running() -> None:
"""Refuse to proceed if astrbot is currently running (lock file held)."""
lock_file = astrbot_paths.root / "astrbot.lock"
if not lock_file.exists():
return
lock = FileLock(lock_file, timeout=1)
try:
lock.acquire()
except Timeout:
raise click.ClickException(
"AstrBot is currently running. "
"Please stop it first before changing the password via CLI.",
) from None
else:
lock.release()
@conf.command(name="admin")
@click.option("-u", "--username", type=str, help="Update admain username as well")
@click.option(
"-p",
"--password",
type=str,
help="Set admain password directly without interactive prompt",
)
def set_dashboard_password(username: str | None, password: str | None) -> None:
"""Interactively set dashboard password (with confirmation) or set directly with -p.
Acceptable inputs:
- Plaintext password (recommended): it will be hashed securely before storage.
- Argon2 encoded hash (advanced): stored as-is.
"""
_check_astrbot_not_running()
config = _load_config()
if password is not None:
if isinstance(password, str) and is_dashboard_password_hash(password):
password_hash = password
else:
if is_legacy_dashboard_password(password):
raise click.ClickException(
"Providing legacy dashboard password hashes is no longer supported. "
"Please supply the plaintext password (it will be hashed securely), "
"or provide an Argon2-encoded hash string.",
)
password_hash = _validate_dashboard_password(password)
else:
password_hash = prompt_dashboard_password()
set_dashboard_credentials(
config,
username=username.strip() if username is not None else None,
password_hash=password_hash,
)
_save_config(config)
if username is not None:
click.echo(f"Dashboard username updated: {username.strip()}")
click.echo("Dashboard password updated.")

View File

@@ -1,21 +1,24 @@
import asyncio
import json
import os
import re
from pathlib import Path
import click
from filelock import FileLock, Timeout
from astrbot.cli.utils import DashboardManager
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.utils.astrbot_path import astrbot_paths
from .cmd_conf import ensure_config_file, set_dashboard_credentials
DASHBOARD_INITIAL_PASSWORD_ENV = "ASTRBOT_DASHBOARD_INITIAL_PASSWORD"
async def check_dashboard(astrbot_root: Path) -> None:
"""Check whether dashboard assets are available.
Args:
astrbot_root: AstrBot data directory path.
"""
from ..utils import check_dashboard as _check_dashboard
await _check_dashboard(astrbot_root)
def _initialize_config_from_env(astrbot_root: Path) -> None:
if DASHBOARD_INITIAL_PASSWORD_ENV not in os.environ:
return
@@ -26,212 +29,54 @@ def _initialize_config_from_env(astrbot_root: Path) -> None:
click.echo("Initialized data/cmd_config.json with dashboard initial password.")
async def initialize_astrbot(
astrbot_root: Path,
*,
yes: bool,
backend_only: bool,
admin_username: str | None,
admin_password: str | None,
) -> None:
async def initialize_astrbot(astrbot_root: Path) -> None:
"""Execute AstrBot initialization logic"""
from astrbot.cli.banner import print_logo
click.echo("=" * 60)
click.echo("AstrBot 初始化向导")
click.echo("=" * 60)
print_logo()
click.echo()
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
if yes or click.confirm(
f"确定要将 AstrBot 安装到以下目录吗?\n {astrbot_root}",
if click.confirm(
f"Install AstrBot to this directory? {astrbot_root}",
default=True,
abort=True,
):
dot_astrbot.touch()
click.echo(f"[OK] 已创建: {dot_astrbot}")
click.echo(f"Created {dot_astrbot}")
paths = {
"data": astrbot_root / "data",
"config": astrbot_root / "data" / "config",
"plugins": astrbot_root / "data" / "plugins",
"temp": astrbot_root / "data" / "temp",
"skills": astrbot_root / "data" / "skills",
}
for name, path in paths.items():
path.mkdir(parents=True, exist_ok=True)
status = "Created" if not path.exists() else "Exists"
click.echo(f" [{status}] {name.title()}: {path}")
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
_initialize_config_from_env(astrbot_root)
config_path = astrbot_root / "data" / "cmd_config.json"
if not config_path.exists():
config_path.write_text(
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
)
click.echo(f"[OK] 配置文件已创建: {config_path}")
ASTRBOT_ROOT = astrbot_root
env_file = ASTRBOT_ROOT / ".env"
if not env_file.exists():
tmpl_candidates = [
Path("/opt/astrbot/config.template"),
getattr(astrbot_paths, "project_root", Path.cwd()) / "config.template",
Path.cwd() / "config.template",
]
tmpl = None
for t in tmpl_candidates:
try:
if t.exists():
tmpl = t
break
except Exception:
continue
if tmpl is not None:
try:
txt = tmpl.read_text(encoding="utf-8")
instance_name = astrbot_root.name or "astrbot"
txt = re.sub("\\$\\{INSTANCE_NAME(:-[^}]*)?\\}", instance_name, txt)
port_val = (
os.environ.get("ASTRBOT_PORT") or os.environ.get("PORT") or "8000"
)
txt = re.sub("\\$\\{PORT(:-[^}]*)?\\}", str(port_val), txt)
txt = re.sub("\\$\\{ASTRBOT_ROOT(:-[^}]*)?\\}", str(ASTRBOT_ROOT), txt)
header = f"# Generated from config.template by astrbot init for instance: {instance_name}\n# This file will be auto-loaded by 'astrbot run'\n\n"
env_file.write_text(header + txt, encoding="utf-8")
env_file.chmod(420)
click.echo(f"[OK] 环境变量文件已创建: {env_file}")
except Exception as e:
click.echo(f"[警告] 无法从模板生成 .env 文件: {e!s}")
else:
click.echo("[提示] 未找到 config.template 文件,跳过 .env 生成")
if admin_password is not None:
raise click.ClickException(
"--admin-password is no longer supported during init. Run 'astrbot conf admin' after initialization.",
)
effective_admin_username = (
admin_username.strip()
if admin_username
else str(DEFAULT_CONFIG["dashboard"]["username"])
)
if admin_username:
config = ensure_config_file()
set_dashboard_credentials(
config,
username=effective_admin_username,
password_hash=None,
)
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
)
click.echo(f"[OK] Dashboard admin 用户名已设置为: {effective_admin_username}")
click.echo()
click.echo("!" * 60)
click.echo("重要提示:")
click.echo(" 1. Dashboard 密码尚未设置!首次登录前必须先设置密码")
click.echo(" 2. 设置命令: astrbot conf admin")
click.echo(" 3. 登录地址: http://localhost:6185 或 http://服务器IP:6185")
click.echo("!" * 60)
click.echo()
if not backend_only and (
yes
or click.confirm(
"是否需要集成式 WebUI个人电脑推荐服务器推荐使用后端模式",
default=True,
)
):
await DashboardManager().ensure_installed(astrbot_root)
else:
click.echo()
click.echo("[提示] 你选择了后端模式,可以使用以下方式管理 AstrBot")
click.echo(" - 使用在线 Dashboard: 在浏览器中访问远程服务器的 WebUI")
click.echo(" - 使用 CLI 命令: astrbot conf / astrbot plug 等")
click.echo()
click.echo("!" * 60)
click.echo("安全提示:")
click.echo(" HTTPS 前端只能安全连接 localhost 的 HTTP 后端")
click.echo(" 不支持远程 + HTTP 后端(不安全)")
click.echo(" 如需远程访问,请使用 HTTPS 后端或通过反向代理")
click.echo("!" * 60)
click.echo()
await check_dashboard(astrbot_root / "data")
@click.command()
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompts")
@click.option("--backend-only", "-b", is_flag=True, help="Only initialize the backend")
@click.option("--backup", "-f", help="Initialize from backup file", type=str)
@click.option(
"-u",
"--admin-username",
type=str,
help="Set dashboard admin username during initialization",
)
@click.option(
"-p",
"--admin-password",
type=str,
help="Deprecated. Run `astrbot conf admin` after initialization.",
)
@click.option(
"--root",
help="ASTRBOT root directory to initialize (overrides ASTRBOT_ROOT env)",
type=str,
)
def init(
yes: bool,
backend_only: bool,
backup: str | None,
admin_username: str | None,
admin_password: str | None,
root: str | None = None,
) -> None:
def init() -> None:
"""Initialize AstrBot"""
click.echo("Initializing AstrBot...")
if os.environ.get("ASTRBOT_SYSTEMD") == "1":
yes = True
from astrbot.core.utils.astrbot_path import astrbot_paths
from ..utils import get_astrbot_root
astrbot_root = Path(root) if root else astrbot_paths.root
click.echo("Initializing AstrBot...")
astrbot_root = get_astrbot_root()
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
try:
with lock.acquire():
asyncio.run(
initialize_astrbot(
astrbot_root,
yes=yes,
backend_only=backend_only,
admin_username=admin_username,
admin_password=admin_password,
),
)
if backup:
from .cmd_bk import import_data_command
click.echo(f"Restoring from backup: {backup}")
click.get_current_context().invoke(
import_data_command,
backup_file=backup,
yes=True,
)
click.echo()
click.echo("=" * 60)
click.echo("初始化完成!")
click.echo("=" * 60)
click.echo()
click.echo("启动 AstrBot")
click.echo(" 完整模式(含 Dashboard: astrbot run")
click.echo(" 仅后端模式: astrbot run --backend-only")
click.echo()
click.echo("首次使用前请先设置管理员密码:")
click.echo(" astrbot conf admin")
click.echo()
except Timeout as err:
asyncio.run(initialize_astrbot(astrbot_root))
click.echo("Done! You can now run 'astrbot run' to start AstrBot")
except Timeout:
raise click.ClickException(
"Cannot acquire lock file. Please check if another instance is running",
) from err
"Cannot acquire lock file. Please check if another instance is running"
)
except Exception as e:
raise click.ClickException(f"Initialization failed: {e!s}") from e
raise click.ClickException(f"Initialization failed: {e!s}")

View File

@@ -1,28 +1,40 @@
import re
import shutil
from pathlib import Path
import click
from astrbot.cli.i18n import t
from astrbot.cli.utils import (
from ..utils import (
PluginStatus,
build_plug_list,
check_astrbot_root,
get_astrbot_root,
get_git_repo,
install_local_plugin,
manage_plugin,
)
@click.group(name="plugin")
@click.group()
def plug() -> None:
"""Plugin management"""
def _get_data_path() -> Path:
base = get_astrbot_root()
if not check_astrbot_root(base):
raise click.ClickException(
f"{base} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
return (base / "data").resolve()
def display_plugins(plugins, title=None, color=None) -> None:
if title:
click.echo(click.style(title, fg=color, bold=True))
click.echo(
f"{'Name':<20} {'Version':<10} {'Status':<10} {'Author':<15} {'Description':<30}",
f"{'Name':<20} {'Version':<10} {'Status':<10} {'Author':<15} {'Description':<30}"
)
click.echo("-" * 85)
@@ -38,13 +50,11 @@ def display_plugins(plugins, title=None, color=None) -> None:
@click.argument("name")
def new(name: str) -> None:
"""Create a new plugin"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
if plug_path.exists():
raise click.ClickException(t("plugin_already_exists", name=name))
raise click.ClickException(f"Plugin {name} already exists")
author = click.prompt("Enter plugin author", type=str)
desc = click.prompt("Enter plugin description", type=str)
@@ -75,7 +85,7 @@ def new(name: str) -> None:
# Rewrite README.md
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
f.write(
f"# {name}\n\n{desc}\n\n# Support\n\n[Documentation](https://docs.astrbot.app)\n",
f"# {name}\n\n{desc}\n\n# Support\n\n[Documentation](https://docs.astrbot.app)\n"
)
# Rewrite main.py
@@ -97,9 +107,7 @@ def new(name: str) -> None:
@click.option("--all", "-a", is_flag=True, help="List uninstalled plugins")
def list(all: bool) -> None:
"""List plugins"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
# Unpublished plugins
@@ -136,14 +144,32 @@ def list(all: bool) -> None:
@plug.command()
@click.argument("name")
@click.argument("name", required=False)
@click.option(
"--editable",
"-e",
"local_path",
type=click.Path(exists=True, file_okay=False, path_type=Path),
help="Install a plugin from a local directory as a symlink",
)
@click.option("--proxy", help="Proxy server address")
def install(name: str, proxy: str | None) -> None:
def install(name: str | None, local_path: Path | None, proxy: str | None) -> None:
"""Install a plugin"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plug_path = base_path / "plugins"
if local_path is not None:
install_local_plugin(local_path, plug_path, editable=True)
return
if name is None:
raise click.ClickException("Missing plugin name or local plugin path")
local_name_path = Path(name).expanduser()
if local_name_path.exists() and local_name_path.is_dir():
install_local_plugin(local_name_path, plug_path, editable=False)
return
plugins = build_plug_list(base_path / "plugins")
plugin = next(
@@ -156,7 +182,7 @@ def install(name: str, proxy: str | None) -> None:
)
if not plugin:
raise click.ClickException(t("plugin_not_found_or_installed", name=name))
raise click.ClickException(f"Plugin {name} not found or already installed")
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
@@ -165,26 +191,24 @@ def install(name: str, proxy: str | None) -> None:
@click.argument("name")
def remove(name: str) -> None:
"""Uninstall a plugin"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
plugin = next((p for p in plugins if p["name"] == name), None)
if not plugin or not plugin.get("local_path"):
raise click.ClickException(t("plugin_not_found_or_installed", name=name))
raise click.ClickException(f"Plugin {name} does not exist or is not installed")
plugin_path = plugin["local_path"]
click.confirm(t("plugin_uninstall_confirm", name=name), default=False, abort=True)
click.confirm(
f"Are you sure you want to uninstall plugin {name}?", default=False, abort=True
)
try:
shutil.rmtree(plugin_path)
click.echo(t("plugin_uninstall_success", name=name))
click.echo(f"Plugin {name} has been uninstalled")
except Exception as e:
raise click.ClickException(
t("plugin_uninstall_failed_ex", name=name, error=str(e)),
) from e
raise click.ClickException(f"Failed to uninstall plugin {name}: {e}")
@plug.command()
@@ -192,9 +216,7 @@ def remove(name: str) -> None:
@click.option("--proxy", help="GitHub proxy address")
def update(name: str, proxy: str | None) -> None:
"""Update plugins"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
@@ -210,7 +232,7 @@ def update(name: str, proxy: str | None) -> None:
if not plugin:
raise click.ClickException(
f"Plugin {name} does not need updating or cannot be updated",
f"Plugin {name} does not need updating or cannot be updated"
)
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
@@ -220,13 +242,13 @@ def update(name: str, proxy: str | None) -> None:
]
if not need_update_plugins:
click.echo(t("plugin_no_update_needed"))
click.echo("No plugins need updating")
return
click.echo(t("plugin_found_update", count=str(len(need_update_plugins))))
click.echo(f"Found {len(need_update_plugins)} plugin(s) needing update")
for plugin in need_update_plugins:
plugin_name = plugin["name"]
click.echo(t("plugin_updating", name=plugin_name))
click.echo(f"Updating plugin {plugin_name}...")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
@@ -234,9 +256,7 @@ def update(name: str, proxy: str | None) -> None:
@click.argument("query")
def search(query: str) -> None:
"""Search for plugins"""
from astrbot.core.utils.astrbot_path import astrbot_paths
base_path = astrbot_paths.data
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
matched_plugins = [
@@ -248,7 +268,7 @@ def search(query: str) -> None:
]
if not matched_plugins:
click.echo(t("plugin_search_no_result", query=query))
click.echo(f"No plugins matching '{query}' found")
return
display_plugins(matched_plugins, t("plugin_search_results", query=query), "cyan")
display_plugins(matched_plugins, f"Search results: '{query}'", "cyan")

View File

@@ -1,92 +1,15 @@
"""AstrBot Run
Environment Variables Used in Project:
Core:
- `ASTRBOT_ROOT`: AstrBot root directory path.
- `ASTRBOT_LOG_LEVEL`: Log level (e.g. INFO, DEBUG).
- `ASTRBOT_CLI`: Flag indicating execution via CLI.
- `ASTRBOT_DESKTOP_CLIENT`: Flag indicating execution via desktop client.
- `ASTRBOT_SYSTEMD`: Flag indicating execution via systemd service.
- `ASTRBOT_RELOAD`: Enable plugin auto-reload (set to "1").
- `ASTRBOT_DISABLE_METRICS`: Disable metrics upload (set to "1").
- `TESTING`: Enable testing mode.
- `DEMO_MODE`: Enable demo mode.
- `PYTHON`: Python executable path override (for local code execution).
Dashboard / Backend:
- `ASTRBOT_DASHBOARD_ENABLE`: Enable/Disable Dashboard.
- `ASTRBOT_HOST`: Dashboard bind host.
- `ASTRBOT_PORT`: Dashboard bind port.
SSL (AstrBot-standard names):
- `ASTRBOT_SSL_ENABLE`: Enable SSL for API.
- `ASTRBOT_SSL_CERT`: SSL Certificate path for backend.
- `ASTRBOT_SSL_KEY`: SSL Key path for backend.
- `ASTRBOT_SSL_CA_CERTS`: SSL CA Certs path for backend.
Network:
- `http_proxy` / `https_proxy`: Proxy URL.
- `no_proxy`: No proxy list.
Internationalization:
- `ASTRBOT_CLI_LANG`: CLI interface language (zh/en).
Integrations:
- `DASHSCOPE_API_KEY`: Alibaba DashScope API Key (for Rerank).
- `COZE_API_KEY` / `COZE_BOT_ID`: Coze integration.
- `BAY_DATA_DIR`: Computer Use data directory.
Platform Specific:
- `TEST_MODE`: Test mode for QQOfficial.
"""
from __future__ import annotations
import asyncio
import os
import re
import sys
import traceback
from pathlib import Path
import click
from dotenv import load_dotenv
from filelock import FileLock, Timeout
from astrbot.cli.utils import DashboardManager
from astrbot.runtime_bootstrap import initialize_runtime_bootstrap
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
# Python version check: require 3.12 or 3.13
if not (sys.version_info.major == 3 and sys.version_info.minor in (12, 13)):
sys.exit(1)
# Regular expression to find bash-like parameter expansions:
# ${VAR:-default} or ${VAR}
_PARAM_EXPAND_RE = re.compile(r"\$\{([^}:]+?)(:-([^}]*))?\}")
def _expand_parameter(
match: re.Match,
env: dict[str, str],
local: dict[str, str],
) -> str:
"""Helper to expand a single ${VAR:-default} or ${VAR} occurrence.
Precedence:
1. local dict (parsed from the same file, earlier entries)
2. environment variables
3. default provided in the expansion (if any)
4. empty string
"""
var = match.group(1)
default = match.group(3) if match.group(3) is not None else ""
# Prefer 'local' parsed values first
if var in local and local[var] != "":
return local[var]
val = env.get(var, "")
if val != "":
return val
return default
DASHBOARD_RESET_PASSWORD_ENV = "ASTRBOT_RESET_DASHBOARD_PASSWORD"
async def run_astrbot(astrbot_root: Path) -> None:
@@ -94,11 +17,7 @@ async def run_astrbot(astrbot_root: Path) -> None:
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.initial_loader import InitialLoader
if (
os.environ.get("ASTRBOT_DASHBOARD_ENABLE", os.environ.get("DASHBOARD_ENABLE"))
== "True"
):
await DashboardManager().ensure_installed(astrbot_root)
await check_dashboard(astrbot_root / "data")
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
@@ -110,316 +29,46 @@ async def run_astrbot(astrbot_root: Path) -> None:
@click.option("--reload", "-r", is_flag=True, help="Auto-reload plugins")
@click.option("--host", "-H", help="AstrBot Dashboard Host", required=False, type=str)
@click.option("--port", "-p", help="AstrBot Dashboard port", required=False, type=str)
@click.option("--root", help="AstrBot root directory", required=False, type=str)
@click.option(
"--service-config",
"-c",
help="Service configuration file path (supports ${VAR:-default} style expansion)",
required=False,
type=str,
)
@click.option(
"--backend-only",
"-b",
"--reset-password",
is_flag=True,
default=False,
help="Disable WebUI, run backend only",
help="Reset dashboard initial password on startup",
)
@click.option(
"--log-level",
"-l",
help="Log level",
required=False,
type=str,
default="INFO",
)
@click.option(
"--ssl-cert",
help="SSL certificate file path for backend (preferred env name: ASTRBOT_SSL_CERT)",
required=False,
type=str,
)
@click.option(
"--ssl-key",
help="SSL private key file path for backend (preferred env name: ASTRBOT_SSL_KEY)",
required=False,
type=str,
)
@click.option(
"--ssl-ca",
help="SSL CA certificates file path for backend (preferred env name: ASTRBOT_SSL_CA_CERTS)",
required=False,
type=str,
)
@click.option("--debug", is_flag=True, help="Enable debug mode")
@click.command()
def run(
reload: bool,
host: str,
port: str,
root: str,
service_config: str,
backend_only: bool,
log_level: str,
ssl_cert: str,
ssl_key: str,
ssl_ca: str,
debug: bool,
) -> None:
def run(reload: bool, port: str | None, reset_password: bool) -> None:
"""Run AstrBot"""
initialize_runtime_bootstrap()
try:
if debug:
log_level = "DEBUG"
# --- Step 1: Resolve service-config path (if provided). We'll treat it as a .env file later. ---
svc_path: Path | None = None
if service_config:
candidate = Path(service_config)
if not candidate.exists():
# Try to expand user and resolve
candidate = Path(os.path.expanduser(service_config))
if candidate.exists():
svc_path = candidate
# NOTE:
# Loading of common .env files (CWD/.env, packaged project .env, ASTRBOT_ROOT/.env)
# has been moved to astrbot.core.utils.astrbot_path during import-time to avoid
# early-initialization ordering issues. Those files are loaded there using
# `override=False` so they do not clobber environment variables provided by the
# systemd unit or the caller.
#
# Here we only load an explicit service-config file (if given). Service-config
# should be able to override the common .env files, but CLI-provided values must
# still win; the CLI will set/overwrite corresponding environment variables
# below after this load.
if svc_path and svc_path.exists():
# Load service-config as an env file and allow it to override previously-loaded
# .env values (those were loaded by astrbot_path). CLI variables are applied
# after this point and will take precedence.
load_dotenv(dotenv_path=str(svc_path), override=True)
# Mark CLI execution
os.environ["ASTRBOT_CLI"] = "1"
astrbot_root = get_astrbot_root()
from astrbot.core.utils.astrbot_path import astrbot_paths
# Resolve astrbot_root with the following precedence:
# 1. CLI --root parameter (local variable `root`)
# 2. ASTRBOT_ROOT environment variable (possibly from .env or parsed service config)
# 3. packaged default astrbot_paths.root
if root:
os.environ["ASTRBOT_ROOT"] = root
astrbot_root = Path(root)
elif os.environ.get("ASTRBOT_ROOT"):
astrbot_root = Path(os.environ["ASTRBOT_ROOT"])
else:
astrbot_root = astrbot_paths.root
if not astrbot_paths.is_root:
if not check_astrbot_root(astrbot_root):
raise click.ClickException(
f"{astrbot_root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
# Ensure ASTRBOT_ROOT env var is set to the resolved root (without overriding a CLI-provided root value above)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
sys.path.insert(0, str(astrbot_root))
# Host/Port precedence: CLI args > parsed service config/env/.env > defaults.
if port is not None:
os.environ["ASTRBOT_PORT"] = port
if host is not None:
os.environ["ASTRBOT_HOST"] = host
# CLI-provided SSL paths should set backend-standard env names.
if ssl_cert is not None:
os.environ["ASTRBOT_SSL_CERT"] = ssl_cert
if ssl_key is not None:
os.environ["ASTRBOT_SSL_KEY"] = ssl_key
if ssl_ca is not None:
os.environ["ASTRBOT_SSL_CA_CERTS"] = ssl_ca
# Dashboard enable is derived from CLI flag (--backend-only). CLI decision should win.
os.environ["ASTRBOT_DASHBOARD_ENABLE"] = str(not backend_only)
os.environ["ASTRBOT_LOG_LEVEL"] = log_level
if port:
os.environ["DASHBOARD_PORT"] = port
if reload:
click.echo("Plugin auto-reload enabled")
os.environ["ASTRBOT_RELOAD"] = "1"
if debug:
keys_to_print = [
"ASTRBOT_ROOT",
"ASTRBOT_LOG_LEVEL",
"ASTRBOT_CLI",
"ASTRBOT_DESKTOP_CLIENT",
"ASTRBOT_SYSTEMD",
"ASTRBOT_RELOAD",
"ASTRBOT_DISABLE_METRICS",
"TESTING",
"DEMO_MODE",
"PYTHON",
"ASTRBOT_DASHBOARD_ENABLE",
"DASHBOARD_ENABLE",
"ASTRBOT_HOST",
"DASHBOARD_HOST",
"ASTRBOT_PORT",
"DASHBOARD_PORT",
# Dashboard SSL (legacy)
"ASTRBOT_SSL_ENABLE",
"DASHBOARD_SSL_ENABLE",
"ASTRBOT_SSL_CERT",
"DASHBOARD_SSL_CERT",
"ASTRBOT_SSL_KEY",
"DASHBOARD_SSL_KEY",
"ASTRBOT_SSL_CA_CERTS",
"DASHBOARD_SSL_CA_CERTS",
# Backend-standard SSL (preferred)
"ASTRBOT_SSL_ENABLE",
"ASTRBOT_SSL_CERT",
"ASTRBOT_SSL_KEY",
"ASTRBOT_SSL_CA_CERTS",
"http_proxy",
"https_proxy",
"no_proxy",
"DASHSCOPE_API_KEY",
"COZE_API_KEY",
"COZE_BOT_ID",
"BAY_DATA_DIR",
"TEST_MODE",
]
click.secho("\n[Debug Mode] Environment Variables:", fg="yellow", bold=True)
for key in keys_to_print:
if key in os.environ:
val = os.environ[key]
if "KEY" in key or "PASSWORD" in key or "SECRET" in key:
if len(val) > 8:
val = val[:4] + "****" + val[-4:]
else:
val = "****"
click.echo(f" {click.style(key, fg='cyan')}: {val}")
if svc_path:
click.echo(
f" {click.style('SERVICE_CONFIG', fg='cyan')}: {svc_path!s}",
)
click.echo("")
if reset_password:
os.environ[DASHBOARD_RESET_PASSWORD_ENV] = "1"
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
with lock.acquire():
async def run_with_logging() -> None:
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.initial_loader import InitialLoader
if (
os.environ.get(
"ASTRBOT_DASHBOARD_ENABLE",
os.environ.get("DASHBOARD_ENABLE"),
)
== "True"
):
await DashboardManager().ensure_installed(astrbot_root)
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
# Register a stdout subscriber for real-time log streaming
log_queue = log_broker.register()
db = db_helper
initial_loader = InitialLoader(db, log_broker)
# Start a task to stream logs to stdout
async def stream_logs() -> None:
"""Stream logs from LogBroker to stdout."""
while True:
try:
log_entry = await asyncio.wait_for(
log_queue.get(),
timeout=0.5,
)
# Format: [LEVEL] message
level = log_entry.get("level_name", "INFO")
message = log_entry.get("message", "")
if message:
level_color = {
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "red",
}.get(level, "white")
click.secho(
f"[{level}]",
fg=level_color,
bold=False,
nl=False,
)
click.echo(f" {message}")
except TimeoutError:
continue
except asyncio.CancelledError:
break
# Start streaming task
stream_task = asyncio.create_task(stream_logs())
try:
await initial_loader.start()
finally:
stream_task.cancel()
try:
await stream_task
except asyncio.CancelledError:
pass
click.echo()
click.echo("=" * 60)
click.echo("AstrBot 启动中...")
click.echo("=" * 60)
from astrbot.cli.banner import print_logo
print_logo()
click.echo()
if backend_only:
click.echo("[模式] 仅后端模式(无本地 Dashboard")
click.echo()
click.echo("[提示] 可以通过以下方式访问 WebUI")
click.echo(" - 使用远程服务器的在线 Dashboard")
click.echo(" - 地址: http://服务器IP:6185")
click.echo()
else:
dashboard_url = f"http://{host or 'localhost'}:{port or '6185'}"
click.echo("[模式] 完整模式(含本地 Dashboard")
click.echo()
click.echo(f"[Dashboard] 请访问: {dashboard_url}")
click.echo()
click.echo("!" * 60)
click.echo("安全提示:")
click.echo(" HTTPS 前端只能安全连接 localhost 的 HTTP 后端")
click.echo(" 不支持远程 + HTTP 后端(不安全)")
click.echo("!" * 60)
click.echo()
click.echo("正在启动服务...(日志输出中)")
click.echo()
asyncio.run(run_with_logging())
asyncio.run(run_astrbot(astrbot_root))
except KeyboardInterrupt:
click.echo("AstrBot has been shut down.")
except Timeout:
raise click.ClickException(
"Cannot acquire lock file. Please check if another instance is running",
) from None
"Cannot acquire lock file. Please check if another instance is running"
)
except Exception as e:
# Keep original traceback visible for diagnostics
raise click.ClickException(
f"Runtime error: {e}\n{traceback.format_exc()}",
) from e
raise click.ClickException(f"Runtime error: {e}\n{traceback.format_exc()}")

File diff suppressed because it is too large Load Diff

View File

@@ -1,69 +0,0 @@
import os
import shutil
from pathlib import Path
import click
from astrbot.core.utils.astrbot_path import astrbot_paths
@click.command()
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompts")
@click.option(
"--keep-data",
is_flag=True,
help="Keep data directory (config, plugins, etc.)",
)
def uninstall(yes: bool, keep_data: bool) -> None:
"""Remove AstrBot files from the current root directory."""
if os.environ.get("ASTRBOT_SYSTEMD") == "1":
yes = True
dot_astrbot = astrbot_paths.root / ".astrbot"
lock_file = astrbot_paths.root / "astrbot.lock"
data_dir = astrbot_paths.data
removable_paths: list[Path] = [dot_astrbot, lock_file]
if not keep_data:
removable_paths.insert(0, data_dir)
# Check if this looks like an AstrBot root before blowing things up
if not dot_astrbot.exists() and not data_dir.exists():
click.echo("No AstrBot initialization found in current directory.")
return
if keep_data:
click.echo("Keeping data directory as requested.")
if yes or click.confirm(
f"Are you sure you want to remove AstrBot data at {astrbot_paths.root}? \n"
f"This will delete:\n"
f" - {data_dir} (Config, Plugins, Database)\n"
f" - {dot_astrbot}\n"
f" - {lock_file}",
default=False,
abort=True,
):
removed_any = False
for path in removable_paths:
if not path.exists():
continue
removed_any = True
if path.is_dir():
click.echo(f"Removing directory: {path}")
shutil.rmtree(path)
else:
click.echo(f"Removing file: {path}")
path.unlink()
if removed_any:
click.echo("AstrBot files removed successfully.")
else:
click.echo("No removable AstrBot files were found.")
# TODO: Consider adding an explicit `--service` cleanup mode instead of
# touching systemd or other service managers during normal uninstall.
# TODO: Consider adding package-manager-specific uninstall helpers once
# the CLI can reliably detect the installation source.
click.echo("uv: uv tool uninstall astrbot")
click.echo("paru/yay: paru -R astrbot")

View File

@@ -1,278 +0,0 @@
"""Internationalization support for AstrBot CLI.
This module provides i18n support with Chinese and English languages.
Language is auto-detected from environment or can be set manually.
"""
from __future__ import annotations
import os
from enum import Enum
from functools import lru_cache
class Language(Enum):
"""Supported languages."""
ZH = "zh"
EN = "en"
# Translation dictionaries
_TRANSLATIONS: dict[Language, dict[str, str]] = {
Language.ZH: {
# CLI welcome and general
"cli_welcome": "欢迎使用 AstrBot CLI!",
"cli_version": "AstrBot CLI 版本: {version}",
"cli_unknown_command": "未知命令: {command}",
"cli_help_available": "使用 astrbot help --all 查看所有命令",
# Dashboard commands
"dashboard_bundled": "Dashboard 已打包在安装包中 - 跳过下载",
"dashboard_not_installed": "Dashboard 未安装",
"dashboard_install_confirm": "是否安装 Dashboard?",
"dashboard_installing": "正在安装 Dashboard...",
"dashboard_install_success": "Dashboard 安装成功",
"dashboard_install_failed": "Dashboard 安装失败: {error}",
"dashboard_not_needed": "Dashboard 不需要安装",
"dashboard_declined": "Dashboard 安装已取消",
"dashboard_already_up_to_date": "Dashboard 已是最新版本",
"dashboard_version": "Dashboard 版本: {version}",
"dashboard_download_failed": "Dashboard 下载失败: {error}",
"dashboard_init_dir": "正在初始化 Dashboard 目录...",
"dashboard_init_success": "Dashboard 初始化成功",
# Plugin commands
"plugin_installing": "正在安装插件: {name}",
"plugin_install_success": "插件安装成功: {name}",
"plugin_install_failed": "插件安装失败: {name}",
"plugin_uninstall_confirm": "确定要卸载插件 {name} 吗?",
"plugin_uninstall_success": "插件卸载成功: {name}",
"plugin_uninstall_failed": "插件卸载失败: {name}",
"plugin_list_empty": "未安装任何插件",
"plugin_already_installed": "插件已安装: {name}",
"plugin_not_found": "插件未找到: {name}",
"plugin_already_exists": "插件已存在: {name}",
"plugin_not_found_or_installed": "插件未找到或已安装: {name}",
"plugin_uninstall_failed_ex": "插件卸载失败 {name}: {error}",
"plugin_no_update_needed": "没有需要更新的插件",
"plugin_found_update": "发现 {count} 个插件需要更新",
"plugin_updating": "正在更新插件 {name}...",
"plugin_search_no_result": "未找到匹配 '{query}' 的插件",
"plugin_search_results": "搜索结果: '{query}'",
# Config commands
"config_show": "显示配置",
"config_set_success": "配置项已更新: {key} = {value}",
"config_set_failed": "配置项更新失败: {key}",
"config_set_failed_ex": "设置配置失败: {error}",
"config_get_success": "{key} = {value}",
"config_get_not_found": "配置项未找到: {key}",
"config_reset_confirm": "确定要重置所有配置吗?",
"config_reset_success": "配置已重置",
# Config validators
"config_log_level_invalid": "日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
"config_port_must_be_number": "端口必须是数字",
"config_port_range_invalid": "端口必须在 1-65535 范围内",
"config_username_empty": "用户名不能为空",
"config_password_empty": "密码不能为空",
"config_timezone_invalid": "无效的时区: {value}。请使用有效的 IANA 时区名称",
"config_callback_invalid": "回调 API 基础路径必须以 http:// 或 https:// 开头",
"config_key_unsupported": "不支持的配置项: {key}",
"config_key_unknown": "未知的配置项: {key}",
"config_updated": "配置已更新: {key}",
# Init command
"init_creating": "正在创建配置目录...",
"init_created": "配置目录已创建: {path}",
"init_copying": "正在复制配置文件...",
"init_copied": "配置文件已复制",
"init_success": "AstrBot 初始化完成!",
"init_failed": "初始化失败: {error}",
# Run command
"run_starting": "正在启动 AstrBot...",
"run_started": "AstrBot 已启动!",
"run_backend_only": "以无界面模式启动",
"run_failed": "启动失败: {error}",
"run_stopped": "AstrBot 已停止",
# Common
"yes": "",
"no": "",
"cancel": "取消",
"confirm": "确认",
"error": "错误",
"success": "成功",
"warning": "警告",
"info": "信息",
"loading": "加载中...",
"done": "完成",
"failed": "失败",
"retry": "重试",
"exit": "退出",
"continue": "继续",
},
Language.EN: {
# CLI welcome and general
"cli_welcome": "Welcome to AstrBot CLI!",
"cli_version": "AstrBot CLI version: {version}",
"cli_unknown_command": "Unknown command: {command}",
"cli_help_available": "Use astrbot help --all to see all commands",
# Dashboard commands
"dashboard_bundled": "Dashboard is bundled with the package - skipping download",
"dashboard_not_installed": "Dashboard is not installed",
"dashboard_install_confirm": "Install Dashboard?",
"dashboard_installing": "Installing Dashboard...",
"dashboard_install_success": "Dashboard installed successfully",
"dashboard_install_failed": "Failed to install dashboard: {error}",
"dashboard_not_needed": "Dashboard not needed",
"dashboard_declined": "Dashboard installation declined.",
"dashboard_already_up_to_date": "Dashboard is already up to date",
"dashboard_version": "Dashboard version: {version}",
"dashboard_download_failed": "Failed to download dashboard: {error}",
"dashboard_init_dir": "Initializing dashboard directory...",
"dashboard_init_success": "Dashboard initialized successfully",
# Plugin commands
"plugin_installing": "Installing plugin: {name}",
"plugin_install_success": "Plugin installed successfully: {name}",
"plugin_install_failed": "Failed to install plugin: {name}",
"plugin_uninstall_confirm": "Uninstall plugin {name}?",
"plugin_uninstall_success": "Plugin uninstalled successfully: {name}",
"plugin_uninstall_failed": "Failed to uninstall plugin: {name}",
"plugin_list_empty": "No plugins installed",
"plugin_already_installed": "Plugin already installed: {name}",
"plugin_not_found": "Plugin not found: {name}",
"plugin_already_exists": "Plugin {name} already exists",
"plugin_not_found_or_installed": "Plugin {name} not found or already installed",
"plugin_uninstall_failed_ex": "Failed to uninstall plugin {name}: {error}",
"plugin_no_update_needed": "No plugins need updating",
"plugin_found_update": "Found {count} plugin(s) needing update",
"plugin_updating": "Updating plugin {name}...",
"plugin_search_no_result": "No plugins matching '{query}' found",
"plugin_search_results": "Search results: '{query}'",
# Config commands
"config_show": "Show configuration",
"config_set_success": "Configuration updated: {key} = {value}",
"config_set_failed": "Failed to update configuration: {key}",
"config_set_failed_ex": "Failed to set config: {error}",
"config_get_success": "{key} = {value}",
"config_get_not_found": "Configuration key not found: {key}",
"config_reset_confirm": "Reset all configuration?",
"config_reset_success": "Configuration reset",
# Config validators
"config_log_level_invalid": "Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
"config_port_must_be_number": "Port must be a number",
"config_port_range_invalid": "Port must be in range 1-65535",
"config_username_empty": "Username cannot be empty",
"config_password_empty": "Password cannot be empty",
"config_timezone_invalid": "Invalid timezone: {value}. Please use a valid IANA timezone name",
"config_callback_invalid": "Callback API base must start with http:// or https://",
"config_key_unsupported": "Unsupported config key: {key}",
"config_key_unknown": "Unknown config key: {key}",
"config_updated": "Config updated: {key}",
# Init command
"init_creating": "Creating config directory...",
"init_created": "Config directory created: {path}",
"init_copying": "Copying config files...",
"init_copied": "Config files copied",
"init_success": "AstrBot initialized successfully!",
"init_failed": "Initialization failed: {error}",
# Run command
"run_starting": "Starting AstrBot...",
"run_started": "AstrBot started!",
"run_backend_only": "Starting in backend-only mode",
"run_failed": "Failed to start: {error}",
"run_stopped": "AstrBot stopped",
# Common
"yes": "Yes",
"no": "No",
"cancel": "Cancel",
"confirm": "Confirm",
"error": "Error",
"success": "Success",
"warning": "Warning",
"info": "Info",
"loading": "Loading...",
"done": "Done",
"failed": "Failed",
"retry": "Retry",
"exit": "Exit",
"continue": "Continue",
},
}
@lru_cache(maxsize=1)
def get_current_language() -> Language:
"""Get the current language based on environment or default.
Detection order:
1. ASTRBOT_CLI_LANG environment variable (zh/en)
2. LANG environment variable (if contains zh/cn)
3. LC_ALL environment variable (if contains zh/cn)
4. Default to Chinese (most users are Chinese)
"""
# Check explicit override first
explicit = os.environ.get("ASTRBOT_CLI_LANG", "").lower()
if explicit in ("zh", "en"):
return Language.ZH if explicit == "zh" else Language.EN
# Check LANG/LC_ALL for Chinese
for env_var in ("LANG", "LC_ALL"):
lang = os.environ.get(env_var, "").lower()
if "zh" in lang or "cn" in lang:
return Language.ZH
# Default to Chinese for broader appeal
return Language.ZH
def set_language(lang: Language) -> None:
"""Set the current language (clears all translation caches)."""
get_current_language.cache_clear()
_t_cached.cache_clear()
# Set environment variable for persistence
os.environ["ASTRBOT_CLI_LANG"] = lang.value
@lru_cache(maxsize=128)
def _t_cached(key: str, lang: Language) -> str:
"""Cached translation lookup."""
return _TRANSLATIONS.get(lang, {}).get(key, key)
def t(translation_key: str, **kwargs: str) -> str:
"""Get translation for the given key in the current language.
Args:
translation_key: Translation key (e.g., "cli_welcome", "plugin_installing")
**kwargs: Format arguments for the translation string
Returns:
Translated string, or the key itself if not found
"""
result = _t_cached(translation_key, get_current_language())
if kwargs:
result = result.format(**kwargs)
return result
def tr(key: str, **kwargs: str) -> str:
"""Get translation (alias for t())."""
return t(key, **kwargs)
class CLITranslations:
"""Translation accessor class for CLI contexts.
Usage:
translations = CLITranslations()
print(translations.cli_welcome)
print(translations.plugin_installing(name="my_plugin"))
"""
def __getattr__(self, key: str) -> str:
return t(key)
def __call__(self, key: str, **kwargs: str) -> str:
return t(key, **kwargs)
# Convenience instance
translations = CLITranslations()

View File

@@ -1,12 +1,25 @@
from .dashboard import DashboardManager
from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin
from .basic import (
check_astrbot_root,
check_dashboard,
get_astrbot_root,
)
from .plugin import (
PluginStatus,
build_plug_list,
get_git_repo,
install_local_plugin,
manage_plugin,
)
from .version_comparator import VersionComparator
__all__ = [
"DashboardManager",
"PluginStatus",
"VersionComparator",
"build_plug_list",
"check_astrbot_root",
"check_dashboard",
"get_astrbot_root",
"get_git_repo",
"install_local_plugin",
"manage_plugin",
]

View File

@@ -1,79 +0,0 @@
import sys
from importlib import resources
from pathlib import Path
import click
from astrbot.cli.i18n import t
from .version_comparator import VersionComparator
class DashboardManager:
_bundled_dist = resources.files("astrbot") / "dashboard" / "dist"
async def ensure_installed(self, astrbot_root: Path) -> None:
"""Ensure the dashboard assets are installed and up to date."""
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
if self._bundled_dist.is_dir():
click.echo(t("dashboard_bundled"))
return
try:
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo(t("dashboard_not_installed"))
# Skip interactive prompt in non-interactive environments
if not sys.stdin.isatty():
click.echo(t("dashboard_not_needed"))
return
if click.confirm(t("dashboard_install_confirm"), default=True):
click.echo(t("dashboard_installing"))
try:
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
click.echo(t("dashboard_install_success"))
except Exception as e:
click.echo(t("dashboard_install_failed", error=str(e)))
else:
click.echo(t("dashboard_declined"))
case str():
if (
VersionComparator.compare_version(VERSION, dashboard_version)
<= 0
):
click.echo(t("dashboard_already_up_to_date"))
return
try:
version = dashboard_version.split("v")[1]
click.echo(t("dashboard_version", version=version))
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
except Exception as e:
click.echo(t("dashboard_download_failed", error=str(e)))
return
except FileNotFoundError:
click.echo(t("dashboard_init_dir"))
try:
await download_dashboard(
path=str(astrbot_root / "data" / "dashboard.zip"),
extract_path=str(astrbot_root / "data"),
version=f"v{VERSION}",
latest=False,
)
click.echo(t("dashboard_init_success"))
except Exception as e:
click.echo(t("dashboard_download_failed", error=str(e)))
return

View File

@@ -1,9 +1,9 @@
import shutil
import tempfile
import uuid
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any
from zipfile import ZipFile
import click
@@ -20,6 +20,35 @@ class PluginStatus(str, Enum):
NOT_PUBLISHED = "unpublished"
LOCAL_PLUGIN_COPY_IGNORE = shutil.ignore_patterns(
".git",
"__pycache__",
"*.pyc",
".venv",
"venv",
".idea",
".vscode",
".zed",
)
def _validate_plugin_dir_name(plugin_name: str, source_path: Path) -> str:
plugin_name = plugin_name.strip()
plugin_path = Path(plugin_name)
has_separator = "/" in plugin_name or "\\" in plugin_name
if (
not plugin_name
or plugin_name in {".", ".."}
or plugin_path.is_absolute()
or has_separator
or plugin_path.name != plugin_name
):
raise click.ClickException(
f"Local plugin {source_path} metadata.yaml has invalid name: {plugin_name}"
)
return plugin_name
def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
"""Download code from a Git repository and extract to the specified path"""
temp_dir = Path(tempfile.mkdtemp())
@@ -33,7 +62,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
try:
with httpx.Client(
proxy=proxy or None,
proxy=proxy if proxy else None,
follow_redirects=True,
) as client:
resp = client.get(release_url)
@@ -57,7 +86,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
# Download and extract
with httpx.Client(
proxy=proxy or None,
proxy=proxy if proxy else None,
follow_redirects=True,
) as client:
resp = client.get(download_url)
@@ -84,7 +113,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
shutil.rmtree(temp_dir, ignore_errors=True)
def load_yaml_metadata(plugin_dir: Path) -> dict[str, Any]:
def load_yaml_metadata(plugin_dir: Path) -> dict:
"""Load plugin metadata from metadata.yaml file
Args:
@@ -97,10 +126,7 @@ def load_yaml_metadata(plugin_dir: Path) -> dict[str, Any]:
yaml_path = plugin_dir / "metadata.yaml"
if yaml_path.exists():
try:
data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if isinstance(data, dict):
return dict[str, Any](data)
return {}
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
except Exception as e:
click.echo(f"Failed to read {yaml_path}: {e}", err=True)
return {}
@@ -118,9 +144,10 @@ def build_plug_list(plugins_dir: Path) -> list:
"""
# Get local plugin info
result = []
if plugins_dir.exists():
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
plugin_dir = plugins_dir / plugin_name
if plugins_dir.is_dir():
for plugin_dir in plugins_dir.iterdir():
if not plugin_dir.is_dir():
continue
# Load metadata from metadata.yaml
metadata = load_yaml_metadata(plugin_dir)
@@ -145,58 +172,120 @@ def build_plug_list(plugins_dir: Path) -> list:
)
# Get online plugin list
online_plugins = []
online_plugins_dict = {}
try:
with httpx.Client() as client:
resp = client.get("https://api.soulter.top/astrbot/plugins")
resp.raise_for_status()
data = resp.json()
for plugin_id, plugin_info in data.items():
online_plugins.append(
{
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
},
)
online_plugins_dict[str(plugin_id)] = {
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
}
except Exception as e:
click.echo(f"Failed to get online plugin list: {e}", err=True)
# Compare with online plugins and update status
online_plugin_names = {plugin["name"] for plugin in online_plugins}
for local_plugin in result:
if local_plugin["name"] in online_plugin_names:
# Find the corresponding online plugin
online_plugin = next(
p for p in online_plugins if p["name"] == local_plugin["name"]
)
if (
VersionComparator.compare_version(
local_plugin["version"] or "",
online_plugin["version"] or "",
)
< 0
):
local_plugin["status"] = PluginStatus.NEED_UPDATE
else:
online_plugin = online_plugins_dict.pop(local_plugin["name"], None)
if online_plugin is None:
# Local plugin is not published online
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
continue
if (
VersionComparator.compare_version(
local_plugin["version"],
online_plugin["version"],
)
< 0
):
local_plugin["status"] = PluginStatus.NEED_UPDATE
# Add uninstalled online plugins
for online_plugin in online_plugins:
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
clean: dict[str, str] = {
k: v for k, v in online_plugin.items() if v is not None
}
result.append(clean)
result.extend(online_plugins_dict.values())
return result
def _cleanup_local_plugin_target(target_path: Path) -> None:
if target_path.is_symlink() or target_path.is_file():
target_path.unlink(missing_ok=True)
elif target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
def _copy_local_plugin(source_path: Path, plugins_dir: Path, target_path: Path) -> None:
temp_target = plugins_dir / f".{target_path.name}.tmp-{uuid.uuid4().hex}"
try:
shutil.copytree(source_path, temp_target, ignore=LOCAL_PLUGIN_COPY_IGNORE)
temp_target.rename(target_path)
except FileExistsError:
raise click.ClickException(
f"Plugin {target_path.name} already exists"
) from None
except Exception:
raise
finally:
if temp_target.exists() or temp_target.is_symlink():
_cleanup_local_plugin_target(temp_target)
def install_local_plugin(
source_path: Path,
plugins_dir: Path,
editable: bool = False,
) -> None:
"""Install a plugin from a local directory."""
source_path = source_path.expanduser().resolve()
plugins_dir = plugins_dir.resolve()
if not source_path.exists() or not source_path.is_dir():
raise click.ClickException(f"Local plugin path does not exist: {source_path}")
metadata = load_yaml_metadata(source_path)
plugin_name = metadata.get("name")
if not isinstance(plugin_name, str) or not plugin_name.strip():
raise click.ClickException(
f"Local plugin {source_path} must contain metadata.yaml with a valid name"
)
plugin_name = _validate_plugin_dir_name(plugin_name, source_path)
target_path = plugins_dir / plugin_name
if target_path.exists():
raise click.ClickException(f"Plugin {plugin_name} already exists")
try:
plugins_dir.mkdir(parents=True, exist_ok=True)
if editable:
try:
target_path.symlink_to(source_path, target_is_directory=True)
except OSError as e:
raise click.ClickException(
f"Failed to create symlink for editable install: {e}. "
"On Windows, you may need to run as Administrator or enable Developer Mode."
) from e
else:
_copy_local_plugin(source_path, plugins_dir, target_path)
click.echo(f"Plugin {plugin_name} installed successfully from {source_path}")
except FileExistsError:
raise click.ClickException(f"Plugin {plugin_name} already exists") from None
except click.ClickException:
raise
except Exception as e:
if editable and target_path.is_symlink():
_cleanup_local_plugin_target(target_path)
raise click.ClickException(
f"Error installing local plugin {plugin_name}: {e}"
) from e
def manage_plugin(
plugin: dict,
plugins_dir: Path,
@@ -226,7 +315,7 @@ def manage_plugin(
# Check if plugin exists
if is_update and not target_path.exists():
raise click.ClickException(
f"Plugin {plugin_name} is not installed and cannot be updated",
f"Plugin {plugin_name} is not installed and cannot be updated"
)
# Backup existing plugin
@@ -245,7 +334,7 @@ def manage_plugin(
if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(
f"Plugin {plugin_name} {'updated' if is_update else 'installed'} successfully",
f"Plugin {plugin_name} {'updated' if is_update else 'installed'} successfully"
)
except Exception as e:
if target_path.exists():
@@ -254,4 +343,4 @@ def manage_plugin(
shutil.move(backup_path, target_path)
raise click.ClickException(
f"Error {'updating' if is_update else 'installing'} plugin {plugin_name}: {e}",
) from e
)

View File

@@ -62,9 +62,12 @@ class VersionComparator:
return -1
if isinstance(p1, str) and isinstance(p2, int):
return 1
if (isinstance(p1, int) and isinstance(p2, int)) or (
isinstance(p1, str) and isinstance(p2, str)
):
if isinstance(p1, int) and isinstance(p2, int):
if p1 > p2:
return 1
if p1 < p2:
return -1
elif isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
if p1 < p2:

View File

@@ -22,29 +22,11 @@ from astrbot.core.utils.requirements_utils import (
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from .log import LogBroker, LogManager
from .utils.astrbot_path import (
get_astrbot_config_path,
get_astrbot_data_path,
get_astrbot_knowledge_base_path,
get_astrbot_plugin_path,
get_astrbot_site_packages_path,
get_astrbot_skills_path,
get_astrbot_temp_path,
)
from .log import LogBroker, LogManager # noqa
from .utils.astrbot_path import get_astrbot_data_path
# Initialize required data directories eagerly so later agent/tool flows do not
# fail on missing paths when the runtime root resolves to a fresh location.
for required_dir in (
get_astrbot_data_path(),
get_astrbot_config_path(),
get_astrbot_plugin_path(),
get_astrbot_temp_path(),
get_astrbot_knowledge_base_path(),
get_astrbot_skills_path(),
get_astrbot_site_packages_path(),
):
os.makedirs(required_dir, exist_ok=True)
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
DEMO_MODE = os.getenv("DEMO_MODE", "False").strip().lower() in ("true", "1", "t")
@@ -52,11 +34,7 @@ astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot")
LogManager.configure_logger(
logger,
astrbot_config,
override_level=os.getenv("ASTRBOT_LOG_LEVEL"),
)
LogManager.configure_logger(logger, astrbot_config)
LogManager.configure_trace_logger(astrbot_config)
db_helper = SQLiteDatabase(DB_PATH)
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
@@ -67,17 +45,3 @@ pip_installer = PipInstaller(
astrbot_config.get("pip_install_arg", ""),
astrbot_config.get("pypi_index_url", None),
)
__all__ = [
"DEMO_MODE",
"AstrBotConfig",
"LogBroker",
"LogManager",
"astrbot_config",
"db_helper",
"file_token_service",
"html_renderer",
"logger",
"pip_installer",
"sp",
"t2i_base_url",
]

View File

@@ -1,6 +1,6 @@
from astrbot import logger
from astrbot.core.agent.message import Message
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .token_counter import EstimateTokenCounter
@@ -22,7 +22,6 @@ class ContextManager:
Args:
config: The context configuration.
"""
self.config = config
@@ -40,13 +39,11 @@ class ContextManager:
)
else:
self.compressor = TruncateByTurnsCompressor(
truncate_turns=config.truncate_turns,
truncate_turns=config.truncate_turns
)
async def process(
self,
messages: list[Message],
trusted_token_usage: int = 0,
self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
"""Process the messages.
@@ -55,7 +52,6 @@ class ContextManager:
Returns:
The processed message list.
"""
try:
result = messages
@@ -71,14 +67,11 @@ class ContextManager:
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result,
trusted_token_usage,
result, trusted_token_usage
)
if self.compressor.should_compress(
result,
total_tokens,
self.config.max_context_tokens,
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
@@ -88,11 +81,10 @@ class ContextManager:
return messages
async def _run_compression(
self,
messages: list[Message],
prev_tokens: int,
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
"""Compress/truncate the messages.
"""
Compress/truncate the messages.
Args:
messages: The original message list.
@@ -100,7 +92,6 @@ class ContextManager:
Returns:
The compressed/truncated message list.
"""
logger.debug("Compress triggered, starting compression...")
@@ -119,12 +110,10 @@ class ContextManager:
# last check
if self.compressor.should_compress(
messages,
tokens_after_summary,
self.config.max_context_tokens,
messages, tokens_after_summary, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation...",
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)

View File

@@ -1,25 +1,18 @@
import json
from typing import Protocol, runtime_checkable
from astrbot.core.agent.message import (
AudioURLPart,
ImageURLPart,
Message,
TextPart,
ThinkPart,
)
from ..message import AudioURLPart, ImageURLPart, Message, TextPart, ThinkPart
@runtime_checkable
class TokenCounter(Protocol):
"""Protocol for token counters.
"""
Protocol for token counters.
Provides an interface for counting tokens in message lists.
"""
def count_tokens(
self,
messages: list[Message],
trusted_token_usage: int = 0,
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
"""Count the total tokens in the message list.
@@ -31,14 +24,13 @@ class TokenCounter(Protocol):
Returns:
The total token count.
"""
...
# 图片/音频 token 开销估算值,参考 OpenAI vision pricing:
# low-res ~85 tokens, high-res ~170 per 512px tile, 通常几百到上千
# 这里取一个保守中位数,宁可偏高触发压缩也不要偏低导致 API 报错
# 图片/音频 token 开销估算值参考 OpenAI vision pricing:
# low-res ~85 tokens, high-res ~170 per 512px tile, 通常几百到上千
# 这里取一个保守中位数宁可偏高触发压缩也不要偏低导致 API 报错
IMAGE_TOKEN_ESTIMATE = 765
AUDIO_TOKEN_ESTIMATE = 500
@@ -52,9 +44,7 @@ class EstimateTokenCounter:
"""
def count_tokens(
self,
messages: list[Message],
trusted_token_usage: int = 0,
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
if trusted_token_usage > 0:
return trusted_token_usage

View File

@@ -1,4 +1,4 @@
from astrbot.core.agent.message import Message
from ..message import Message
class ContextTruncator:
@@ -20,7 +20,6 @@ class ContextTruncator:
Returns:
tuple: (system_messages, non_system_messages)
"""
first_non_system = 0
for i, msg in enumerate(messages):
@@ -35,44 +34,19 @@ class ContextTruncator:
truncated: list[Message],
original_messages: list[Message],
) -> list[Message]:
"""Ensure the result always contains a `user` message immediately after
system messages, as required by some LLM APIs.
Optimization strategy:
- If `truncated` already begins with a `user` message, return it as-is.
- If a `user` message exists later in `truncated`, move that message to
be the first non-system message while preserving the relative order of
the remaining truncated messages (without mutating the original list).
- Otherwise, fall back to the first `user` message from
`original_messages`.
This reduces unnecessary duplication and ensures the required ordering.
"""Ensure the result always contains the first user message right after
system messages. This is required by many LLM APIs (e.g. Zhipu) that
mandate a ``user`` message immediately following the ``system`` message.
"""
if truncated and truncated[0].role == "user":
return system_messages + truncated
# If a user message exists inside the truncated list, promote it to the front.
index_in_truncated = next(
(i for i, m in enumerate(truncated) if m.role == "user"),
None,
)
if index_in_truncated is not None:
# Build a new truncated list that places the found user message first,
# preserving the order of the other messages and avoiding in-place mutation.
user_msg = truncated[index_in_truncated]
new_truncated = [
user_msg,
*truncated[:index_in_truncated],
*truncated[index_in_truncated + 1 :],
]
return system_messages + new_truncated
# Fallback: find the first user message in the original messages.
# Locate the first user message from the *original* list.
first_user = next((m for m in original_messages if m.role == "user"), None)
if first_user is None:
# No user messages at all; return system messages + whatever was truncated.
return system_messages + truncated
return [*system_messages, first_user, *truncated]
return system_messages + [first_user] + truncated
def fix_messages(self, messages: list[Message]) -> list[Message]:
"""Fix the message list to ensure the validity of tool call and tool response pairing.
@@ -129,7 +103,8 @@ class ContextTruncator:
keep_most_recent_turns: int,
drop_turns: int = 1,
) -> list[Message]:
"""Turn-based truncation strategy, which drops the oldest turns while keeping the most recent N turns.
"""
Turn-based truncation strategy, which drops the oldest turns while keeping the most recent N turns.
A turn consists of a user message and an assistant message.
This method ensures that the truncated context list conforms to OpenAI's context format.
@@ -140,7 +115,6 @@ class ContextTruncator:
Returns:
The truncated list of messages.
"""
if keep_most_recent_turns == -1:
return messages
@@ -165,9 +139,7 @@ class ContextTruncator:
truncated_contexts = truncated_contexts[index:]
result = self._ensure_user_message(
system_messages,
truncated_contexts,
messages,
system_messages, truncated_contexts, messages
)
return self.fix_messages(result)
@@ -196,9 +168,7 @@ class ContextTruncator:
truncated_non_system = truncated_non_system[index:]
result = self._ensure_user_message(
system_messages,
truncated_non_system,
messages,
system_messages, truncated_non_system, messages
)
return self.fix_messages(result)
@@ -227,8 +197,6 @@ class ContextTruncator:
truncated_non_system = truncated_non_system[index:]
result = self._ensure_user_message(
system_messages,
truncated_non_system,
messages,
system_messages, truncated_non_system, messages
)
return self.fix_messages(result)

View File

@@ -1,7 +1,3 @@
"""MCP client
This file exists solely for backward compatibility and will be removed in a future version.
"""
import asyncio
import copy
import logging
@@ -11,7 +7,7 @@ import sys
from contextlib import AsyncExitStack
from datetime import timedelta
from pathlib import Path, PureWindowsPath
from typing import Any, Generic, TextIO
from typing import Any, Generic
import httpx
from tenacity import (
@@ -22,14 +18,13 @@ from tenacity import (
wait_exponential,
)
from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe
from .run_context import TContext
from .tool import FunctionTool
logger = logging.getLogger("astrbot")
_DEFAULT_STDIO_COMMAND_ALLOWLIST = frozenset(
{
"python",
@@ -105,7 +100,7 @@ try:
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency, MCP services will be unavailable.",
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
)
streamable_http_client_legacy = None
@@ -126,31 +121,13 @@ except (ModuleNotFoundError, ImportError):
)
class TenacityLogger:
"""Wraps a logging.Logger to satisfy tenacity's LoggerProtocol."""
__slots__ = ("_logger",)
_logger: logging.Logger
def __init__(self, logger: logging.Logger) -> None:
self._logger = logger
def log(
self,
level: int,
msg: str,
/,
*args: Any,
**kwargs: Any,
) -> None:
self._logger.log(level, msg, *args, **kwargs)
def _prepare_config(config: dict) -> dict:
"""Prepare configuration, handle nested format"""
if config.get("mcpServers"):
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config = dict(config["mcpServers"][first_key])
else:
config = dict(config)
config.pop("active", None)
return config
@@ -180,6 +157,11 @@ def _get_stdio_command_allowlist() -> set[str]:
return allowed
def _is_stdio_config(config: dict) -> bool:
cfg = _prepare_config(config.copy())
return "url" not in cfg
def _validate_stdio_args(command_name: str, args: object) -> None:
if args is None:
return
@@ -231,7 +213,7 @@ def _validate_stdio_args(command_name: str, args: object) -> None:
def validate_mcp_stdio_config(config: dict) -> None:
"""Validate MCP stdio configuration in a backward-compatible way."""
"""Validate stdio MCP config before any subprocess can be spawned."""
cfg = _prepare_config(config.copy())
if "url" in cfg:
return
@@ -278,10 +260,10 @@ def _prepare_stdio_env(config: dict) -> dict:
def _merge_environment_variables(env: dict) -> dict:
"""Merge environment variables in case-insensitive systems."""
"""合并环境变量处理Windows不区分大小写的情况"""
merged = env.copy()
# Use lower-case keys for case-insensitive matching on Windows.
# 将用户环境变量转换为统一的大小写形式便于比较
user_keys_lower = {k.lower(): k for k in merged.keys()}
for sys_key, sys_value in os.environ.items():
@@ -349,7 +331,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
except TimeoutError:
except asyncio.TimeoutError:
return False, f"Connection timeout: {timeout} seconds"
except Exception as e:
return False, f"{e!s}"
@@ -386,7 +368,7 @@ def _normalize_mcp_input_schema(schema: dict[str, Any]) -> dict[str, Any]:
if not isinstance(prop_schema, dict):
continue
original_prop_schema = (original_properties or {}).get(prop_name, {})
original_prop_schema = original_properties.get(prop_name, {})
prop_required = (
original_prop_schema.get("required")
if isinstance(original_prop_schema, dict)
@@ -420,7 +402,6 @@ class MCPClient:
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()
self.process_pid: int | None = None
# Store connection config for reconnection
self._mcp_server_config: dict | None = None
@@ -428,24 +409,6 @@ class MCPClient:
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
self._reconnecting: bool = False # For logging and debugging
@staticmethod
def _extract_stdio_process_pid(streams_context: object) -> int | None:
"""Best-effort extraction for stdio subprocess PID used by lease cleanup.
TODO(refactor): replace this async-generator frame introspection with a
stable MCP library hook once the upstream transport exposes process PID.
"""
generator = getattr(streams_context, "gen", None)
frame = getattr(generator, "ag_frame", None)
if frame is None:
return None
process = frame.f_locals.get("process")
pid = getattr(process, "pid", None)
try:
return int(pid) if pid is not None else None
except (TypeError, ValueError):
return None
async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:
"""Connect to MCP server
@@ -461,17 +424,17 @@ class MCPClient:
# Store config for reconnection
self._mcp_server_config = mcp_server_config
self._server_name = name
self.process_pid = None
cfg = _prepare_config(mcp_server_config.copy())
async def logging_callback(
params: mcp.types.LoggingMessageNotificationParams,
def logging_callback(
msg: str | mcp.types.LoggingMessageNotificationParams,
) -> None:
# Handle MCP service error logs
if params.level in ("warning", "error", "critical", "alert", "emergency"):
log_msg = f"[{params.level.upper()}] {params.data!s}"
self.server_errlogs.append(log_msg)
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
@@ -493,21 +456,19 @@ class MCPClient:
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
read_stream, write_stream = await self.exit_stack.enter_async_context(
streams = await self.exit_stack.enter_async_context(
self._streams_context,
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
session = await self.exit_stack.enter_async_context(
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_stream,
write_stream=write_stream,
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback,
logging_callback=logging_callback, # type: ignore
),
)
self.session = session
else:
timeout_seconds = cfg.get("timeout", 30)
sse_read_timeout_seconds = cfg.get("sse_read_timeout", 60 * 5)
@@ -547,17 +508,17 @@ class MCPClient:
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
session = await self.exit_stack.enter_async_context(
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback,
logging_callback=logging_callback, # type: ignore
),
)
self.session = session
else:
validate_mcp_stdio_config(cfg)
cfg = _prepare_stdio_env(cfg)
server_params = mcp.StdioServerParameters(
**cfg,
@@ -573,35 +534,25 @@ class MCPClient:
"alert",
"emergency",
):
log_msg = f"[{msg.level.upper()}] {msg.data!s}"
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
log_pipe = self.exit_stack.enter_context(
LogPipe(
level=logging.INFO,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
),
)
errlog_stream: TextIO = self.exit_stack.enter_context(
os.fdopen(os.dup(log_pipe.fileno()), "w"),
)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=errlog_stream,
errlog=LogPipe(
level=logging.INFO,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
), # type: ignore
),
)
self.process_pid = self._extract_stdio_process_pid(stdio_transport)
# Create a new client session
session = await self.exit_stack.enter_async_context(
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport),
)
self.session = session
assert self.session is not None
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
@@ -619,13 +570,12 @@ class MCPClient:
Raises:
Exception: raised when reconnection fails
"""
async with self._reconnect_lock:
# Check if already reconnecting (useful for logging)
if self._reconnecting:
logger.debug(
f"MCP Client {self._server_name} is already reconnecting, skipping",
f"MCP Client {self._server_name} is already reconnecting, skipping"
)
return
@@ -635,7 +585,7 @@ class MCPClient:
self._reconnecting = True
try:
logger.info(
f"Attempting to reconnect to MCP server {self._server_name}...",
f"Attempting to reconnect to MCP server {self._server_name}..."
)
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
@@ -653,11 +603,11 @@ class MCPClient:
await self.list_tools_and_save()
logger.info(
f"Successfully reconnected to MCP server {self._server_name}",
f"Successfully reconnected to MCP server {self._server_name}"
)
except Exception as e:
logger.error(
f"Failed to reconnect to MCP server {self._server_name}: {e}",
f"Failed to reconnect to MCP server {self._server_name}: {e}"
)
raise
finally:
@@ -682,14 +632,13 @@ class MCPClient:
Raises:
ValueError: MCP session is not available
anyio.ClosedResourceError: raised after reconnection failure
"""
@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=1, max=3),
before_sleep=before_sleep_log(TenacityLogger(logger), logging.WARNING),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
async def _call_with_retry():
@@ -704,7 +653,7 @@ class MCPClient:
)
except anyio.ClosedResourceError:
logger.warning(
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect...",
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
)
# Attempt to reconnect
await self._reconnect()
@@ -728,18 +677,13 @@ class MCPClient:
# Set running_event first to unblock any waiting tasks
self.running_event.set()
self.process_pid = None
class MCPTool(FunctionTool, Generic[TContext]):
"""A function tool that calls an MCP service."""
def __init__(
self,
mcp_tool: mcp.Tool,
mcp_client: MCPClient,
mcp_server_name: str,
**kwargs,
self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
) -> None:
super().__init__(
name=mcp_tool.name,
@@ -749,12 +693,9 @@ class MCPTool(FunctionTool, Generic[TContext]):
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client
self.mcp_server_name = mcp_server_name
self.source = "mcp"
async def call(
self,
context: ContextWrapper[TContext],
**kwargs,
self, context: ContextWrapper[TContext], **kwargs
) -> mcp.types.CallToolResult:
return await self.mcp_client.call_tool_with_reconnect(
tool_name=self.mcp_tool.name,

View File

@@ -1,7 +1,7 @@
# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
# License: Apache License 2.0
from typing import Any, ClassVar, Literal, Self, TypeVar, cast
from typing import Any, ClassVar, Literal, TypeVar, cast
from pydantic import (
BaseModel,
@@ -37,9 +37,7 @@ class ContentPart(BaseModel):
@classmethod
def __get_pydantic_core_schema__(
cls,
source_type: Any,
handler: GetCoreSchemaHandler,
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# If we're dealing with the base ContentPart class, use custom validation
if cls.__name__ == "ContentPart":
@@ -51,12 +49,12 @@ class ContentPart(BaseModel):
# if it's a dict with a type field, dispatch to the appropriate subclass
if isinstance(value, dict) and "type" in value:
type_value: Any | None = cast("dict[str, Any]", value).get("type")
type_value: Any | None = cast(dict[str, Any], value).get("type")
if not isinstance(type_value, str):
raise ValueError(f"Cannot validate {value} as ContentPart")
target_class = cls.__content_part_registry[type_value]
part = target_class.model_validate(value)
if cast("dict[str, Any]", value).get("_no_save"):
if cast(dict[str, Any], value).get("_no_save"):
part._no_save = True
return part
@@ -67,7 +65,7 @@ class ContentPart(BaseModel):
# for subclasses, use the default schema
return handler(source_type)
def mark_as_temp(self) -> Self:
def mark_as_temp(self: ContentPartT) -> ContentPartT:
"""Mark this content part as provider-facing only, not persisted."""
self._no_save = True
return self
@@ -80,7 +78,8 @@ class ContentPart(BaseModel):
class TextPart(ContentPart):
"""TextPart(text="Hello, world!").model_dump()
"""
>>> TextPart(text="Hello, world!").model_dump()
{'type': 'text', 'text': 'Hello, world!'}
"""
@@ -89,7 +88,8 @@ class TextPart(ContentPart):
class ThinkPart(ContentPart):
"""ThinkPart(think="I think I need to think about this.").model_dump()
"""
>>> ThinkPart(think="I think I need to think about this.").model_dump()
{'type': 'think', 'think': 'I think I need to think about this.', 'encrypted': None}
"""
@@ -110,7 +110,8 @@ class ThinkPart(ContentPart):
class ImageURLPart(ContentPart):
"""ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
"""
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
{'type': 'image_url', 'image_url': 'http://example.com/image.jpg'}
"""
@@ -125,7 +126,8 @@ class ImageURLPart(ContentPart):
class AudioURLPart(ContentPart):
"""AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
"""
>>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
{'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
"""
@@ -140,9 +142,10 @@ class AudioURLPart(ContentPart):
class ToolCall(BaseModel):
"""A tool call requested by the assistant.
"""
A tool call requested by the assistant.
ToolCall(
>>> ToolCall(
... id="123",
... function=ToolCall.FunctionBody(
... name="function",
@@ -229,7 +232,7 @@ class Message(BaseModel):
# other all cases: content is required
if self.content is None:
raise ValueError(
"content is required unless role='assistant' and tool_calls is not None",
"content is required unless role='assistant' and tool_calls is not None"
)
return self
@@ -353,8 +356,6 @@ def dump_messages_with_checkpoints(messages: list[Message]) -> list[dict]:
dumped.append(message_data)
if message._checkpoint_after is not None:
dumped.append(
CheckpointMessageSegment(
content=message._checkpoint_after,
).model_dump(),
CheckpointMessageSegment(content=message._checkpoint_after).model_dump()
)
return dumped

View File

@@ -1,4 +1,4 @@
from typing import Any, Generic, cast
from typing import Any, Generic
from pydantic import Field
from pydantic.dataclasses import dataclass
@@ -13,7 +13,7 @@ TContext = TypeVar("TContext", default=Any)
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext = cast("TContext", None)
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 120 # Default tool call timeout in seconds

View File

@@ -1,16 +1,13 @@
import abc
import asyncio
from collections.abc import AsyncGenerator
import typing as T
from enum import Enum, auto
from typing import Any, Generic
from astrbot import logger
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.response import AgentResponse
from astrbot.core.agent.run_context import ContextWrapper, TContext
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.provider.entities import LLMResponse, ProviderRequest
from astrbot.core.provider.provider import Provider
from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
class AgentState(Enum):
@@ -22,33 +19,13 @@ class AgentState(Enum):
ERROR = auto() # Error state
class BaseAgentRunner(Generic[TContext]):
def __init__(
self,
):
self.tasks: set[asyncio.Task[object]] = set()
self._state = AgentState.IDLE
class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
enforce_max_turns: int = -1,
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
truncate_turns: int = 1,
custom_token_counter: Any = None,
custom_compressor: Any = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
provider_config: dict | None = None,
**kwargs: Any,
**kwargs: T.Any,
) -> None:
"""Reset the agent to its initial state.
This method should be called before starting a new run.
@@ -56,12 +33,14 @@ class BaseAgentRunner(Generic[TContext]):
...
@abc.abstractmethod
def step(self) -> AsyncGenerator[AgentResponse, None]:
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
"""Process a single step of the agent."""
...
@abc.abstractmethod
def step_until_done(self, max_step: int) -> AsyncGenerator[AgentResponse, None]:
async def step_until_done(
self, max_step: int
) -> T.AsyncGenerator[AgentResponse, None]:
"""Process steps until the agent is done."""
...

View File

@@ -1,25 +1,29 @@
import base64
import json
from typing import Any, override
import sys
import typing as T
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core import sp
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.message import is_checkpoint_message
from astrbot.core.agent.response import AgentResponse, AgentResponseData
from astrbot.core.agent.run_context import ContextWrapper, TContext
from astrbot.core.agent.runners.base import AgentState, BaseAgentRunner
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.provider import Provider
from astrbot.core.utils.media_utils import MediaResolver, describe_media_ref
from ...hooks import BaseAgentRunHooks
from ...message import is_checkpoint_message
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .coze_api_client import CozeAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class CozeAgentRunner(BaseAgentRunner[TContext]):
"""Coze Agent Runner"""
@@ -27,45 +31,32 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
enforce_max_turns: int = -1,
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
truncate_turns: int = 1,
custom_token_counter: Any = None,
custom_compressor: Any = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
provider_config: dict | None = None,
**kwargs: Any,
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = streaming
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
provider_config = provider_config or {}
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空")
raise Exception("Coze API Key 不能为空")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空")
raise Exception("Coze Bot ID 不能为空")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://"),
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头",
"Coze API Base URL 格式不正确必须以 http:// 或 https:// 开头",
)
self.timeout = provider_config.get("timeout", 120)
@@ -81,7 +72,9 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
@override
async def step(self):
"""执行 Coze Agent 的一个步骤"""
"""
执行 Coze Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
@@ -91,7 +84,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
# 开始处理转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
@@ -99,23 +92,24 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
async for response in self._execute_coze_request():
yield response
except Exception as e:
logger.error(f"Coze 请求失败:{e!s}")
logger.error(f"Coze 请求失败{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err",
completion_text=f"Coze 请求失败:{e!s}",
role="err", completion_text=f"Coze 请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Coze 请求失败:{e!s}"),
chain=MessageChain().message(f"Coze 请求失败{str(e)}")
),
)
finally:
await self.api_client.close()
@override
async def step_until_done(self, max_step: int):
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
@@ -161,7 +155,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
# 处理上下文中的图片
content = ctx["content"]
if isinstance(content, list):
# 多模态内容,需要处理图片
# 多模态内容需要处理图片
processed_content = []
for item in content:
if isinstance(item, dict):
@@ -175,8 +169,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
if url:
file_id = (
await self._download_and_upload_image(
url,
session_id,
url, session_id
)
)
processed_content.append(
@@ -184,7 +177,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
"type": "file",
"file_id": file_id,
"file_url": url,
},
}
)
except Exception as e:
logger.warning(f"处理上下文图片失败: {e}")
@@ -196,7 +189,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
},
}
)
else:
# 纯文本内容
@@ -205,7 +198,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
"role": ctx["role"],
"content": content,
"content_type": "text",
},
}
)
# 构建当前消息
@@ -217,18 +210,23 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
# the url is a base64 string
try:
image_data = base64.b64decode(url)
file_id = await self.api_client.upload_file(image_data)
file_id = await self._download_and_upload_image(
url,
session_id,
)
object_string_content.append(
{
"type": "image",
"file_id": file_id,
},
}
)
except Exception as e:
logger.warning(f"处理图片失败 {url}: {e}")
logger.warning(
"处理图片失败 %s: %s",
describe_media_ref(url),
e,
)
continue
if object_string_content:
@@ -238,7 +236,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
"role": "user",
"content": content,
"content_type": "object_string",
},
}
)
elif prompt:
# 纯文本
@@ -287,12 +285,12 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
accumulated_content += content
message_started = True
# 如果是流式响应,发送增量数据
# 如果是流式响应发送增量数据
if self.streaming:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(content),
chain=MessageChain().message(content)
),
)
@@ -338,7 +336,7 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
image_url: str,
session_id: str | None = None,
) -> str:
"""下载图片并上传到 Coze,返回 file_id"""
"""下载图片并上传到 Coze返回 file_id"""
import hashlib
# 计算哈希实现缓存
@@ -354,17 +352,20 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self.api_client.upload_file(image_data)
image_bytes = await MediaResolver(
image_url,
media_type="image",
).to_bytes()
file_id = await self.api_client.upload_file(image_bytes)
if session_id:
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {e!s}")
logger.error("处理图片失败 %s: %s", describe_media_ref(image_url), e)
raise Exception(f"处理图片失败: {e!s}") from e
@override

View File

@@ -66,7 +66,7 @@ class CozeAPIClient:
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
raise Exception("Coze API 认证失败请检查 API Key 是否正确")
response_text = await response.text()
logger.debug(
@@ -75,27 +75,27 @@ class CozeAPIClient:
if response.status != 200:
raise Exception(
f"文件上传失败,状态码: {response.status}, 响应: {response_text}",
f"文件上传失败状态码: {response.status}, 响应: {response_text}",
)
try:
result = await response.json()
except json.JSONDecodeError:
raise Exception(f"文件上传响应解析失败: {response_text}") from None
raise Exception(f"文件上传响应解析失败: {response_text}")
if result.get("code") != 0:
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
file_id = result["data"]["id"]
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
logger.debug(f"[Coze] 图片上传成功file_id: {file_id}")
return file_id
except TimeoutError:
except asyncio.TimeoutError:
logger.error("文件上传超时")
raise Exception("文件上传超时") from None
raise Exception("文件上传超时")
except Exception as e:
logger.error(f"文件上传失败: {e!s}")
raise Exception(f"文件上传失败: {e!s}") from e
raise Exception(f"文件上传失败: {e!s}")
async def download_image(self, image_url: str) -> bytes:
"""下载图片并返回字节数据
@@ -111,14 +111,14 @@ class CozeAPIClient:
try:
async with session.get(image_url) as response:
if response.status != 200:
raise Exception(f"下载图片失败,状态码: {response.status}")
raise Exception(f"下载图片失败状态码: {response.status}")
image_data = await response.read()
return image_data
except Exception as e:
logger.error(f"下载图片失败 {image_url}: {e!s}")
raise Exception(f"下载图片失败: {e!s}") from e
raise Exception(f"下载图片失败: {e!s}")
async def chat_messages(
self,
@@ -145,7 +145,7 @@ class CozeAPIClient:
session = await self._ensure_session()
url = f"{self.api_base}/v3/chat"
payload: dict[str, Any] = {
payload = {
"bot_id": bot_id,
"user_id": user_id,
"stream": stream,
@@ -169,10 +169,10 @@ class CozeAPIClient:
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
raise Exception("Coze API 认证失败请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
raise Exception(f"Coze API 流式请求失败状态码: {response.status}")
# SSE
buffer = ""
@@ -203,10 +203,10 @@ class CozeAPIClient:
except json.JSONDecodeError:
event_data = {"content": data_str}
except TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)") from None
except asyncio.TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
except Exception as e:
raise Exception(f"Coze API 流式请求失败: {e!s}") from e
raise Exception(f"Coze API 流式请求失败: {e!s}")
async def clear_context(self, conversation_id: str):
"""清空会话上下文
@@ -226,20 +226,20 @@ class CozeAPIClient:
response_text = await response.text()
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
raise Exception("Coze API 认证失败请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
raise Exception(f"Coze API 请求失败状态码: {response.status}")
try:
return json.loads(response_text)
except json.JSONDecodeError:
raise Exception("Coze API 返回非JSON格式") from None
raise Exception("Coze API 返回非JSON格式")
except TimeoutError:
raise Exception("Coze API 请求超时") from None
except asyncio.TimeoutError:
raise Exception("Coze API 请求超时")
except aiohttp.ClientError as e:
raise Exception(f"Coze API 请求失败: {e!s}") from e
raise Exception(f"Coze API 请求失败: {e!s}")
async def get_message_list(
self,
@@ -275,7 +275,7 @@ class CozeAPIClient:
except Exception as e:
logger.error(f"获取Coze消息列表失败: {e!s}")
raise Exception(f"获取Coze消息列表失败: {e!s}") from e
raise Exception(f"获取Coze消息列表失败: {e!s}")
async def close(self) -> None:
"""关闭会话"""
@@ -288,18 +288,17 @@ if __name__ == "__main__":
import asyncio
import os
import anyio
async def test_coze_api_client() -> None:
api_key = os.getenv("COZE_API_KEY", "")
bot_id = os.getenv("COZE_BOT_ID", "")
client = CozeAPIClient(api_key=api_key)
try:
async with await anyio.open_file("README.md", "rb") as f:
file_data = await f.read()
with open("README.md", "rb") as f:
file_data = f.read()
file_id = await client.upload_file(file_data)
async for _event in client.chat_messages(
print(f"Uploaded file_id: {file_id}")
async for event in client.chat_messages(
bot_id=bot_id,
user_id="test_user",
additional_messages=[
@@ -317,7 +316,7 @@ if __name__ == "__main__":
],
stream=True,
):
pass
print(f"Event: {event}")
finally:
await client.close()

View File

@@ -2,26 +2,30 @@ import asyncio
import functools
import queue
import re
import sys
import threading
from collections.abc import AsyncGenerator
from typing import Any, override
import typing as T
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.response import AgentResponseData
from astrbot.core.agent.run_context import ContextWrapper, TContext
from astrbot.core.agent.runners.base import AgentResponse, AgentState, BaseAgentRunner
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.provider import Provider
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
@@ -30,41 +34,28 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
enforce_max_turns: int = -1,
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
truncate_turns: int = 1,
custom_token_counter: Any = None,
custom_compressor: Any = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
provider_config: dict | None = None,
**kwargs: Any,
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = streaming
self.final_llm_resp: LLMResponse | None = None
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
provider_config = provider_config or {}
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
raise Exception("阿里云百炼 API Key 不能为空")
raise Exception("阿里云百炼 API Key 不能为空")
self.app_id = provider_config.get("dashscope_app_id", "")
if not self.app_id:
raise Exception("阿里云百炼 APP ID 不能为空")
raise Exception("阿里云百炼 APP ID 不能为空")
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空")
raise Exception("阿里云百炼 APP 类型不能为空")
self.variables: dict = provider_config.get("variables", {}) or {}
self.rag_options: dict = provider_config.get("rag_options", {})
@@ -92,7 +83,9 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
@override
async def step(self):
"""执行 Dashscope Agent 的一个步骤"""
"""
执行 Dashscope Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
@@ -102,7 +95,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
# 开始处理转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
@@ -110,29 +103,28 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
async for response in self._execute_dashscope_request():
yield response
except Exception as e:
logger.error(f"阿里云百炼请求失败:{e!s}")
logger.error(f"阿里云百炼请求失败{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err",
completion_text=f"阿里云百炼请求失败:{e!s}",
role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"阿里云百炼请求失败:{e!s}"),
chain=MessageChain().message(f"阿里云百炼请求失败{str(e)}")
),
)
@override
async def step_until_done(self, max_step: int):
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
def _consume_sync_generator(
self,
response: Any,
response_queue: queue.Queue,
self, response: T.Any, response_queue: queue.Queue
) -> None:
"""在线程中消费同步generator,将结果放入队列
@@ -153,9 +145,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
response_queue.put(("done", None))
async def _process_stream_chunk(
self,
chunk: ApplicationResponse,
output_text: str,
self, chunk: ApplicationResponse, output_text: str
) -> tuple[str, list | None, AgentResponse | None]:
"""处理流式响应的单个chunk
@@ -171,7 +161,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
if chunk.status_code != 200:
logger.error(
f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
)
self._transition_state(AgentState.ERROR)
error_msg = (
@@ -190,8 +180,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
),
)
chunk_text_value = chunk.output.get("text", "")
chunk_text = chunk_text_value if isinstance(chunk_text_value, str) else ""
chunk_text = chunk.output.get("text", "") or ""
# RAG 引用脚标格式化
chunk_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", chunk_text)
@@ -204,10 +193,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
)
# 获取文档引用
raw_doc_references = chunk.output.get("doc_references")
doc_references = (
raw_doc_references if isinstance(raw_doc_references, list) else None
)
doc_references = chunk.output.get("doc_references", None)
return output_text, doc_references, response
@@ -231,11 +217,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
return f"\n\n回答来源:\n{ref_str}"
async def _build_request_payload(
self,
prompt: str,
session_id: str,
contexts: list,
system_prompt: str,
self, prompt: str, session_id: str, contexts: list, system_prompt: str
) -> dict:
"""构建请求payload
@@ -256,17 +238,15 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
default="",
)
# 获得会话变量
payload_vars: dict = self.variables.copy()
session_var: dict = (
await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
or {}
payload_vars = self.variables.copy()
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
payload_vars.update(session_var)
if (
self.dashscope_app_type in ["agent", "dialog-workflow"]
and not self.has_rag_options()
@@ -283,24 +263,23 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
if conversation_id:
p["session_id"] = conversation_id
return p
# 不支持多轮对话的
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
"stream": self.streaming,
"incremental_output": True,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
return payload
else:
# 不支持多轮对话的
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
"stream": self.streaming,
"incremental_output": True,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
return payload
async def _handle_streaming_response(
self,
response: Any,
session_id: str,
) -> AsyncGenerator[AgentResponse, None]:
self, response: T.Any, session_id: str
) -> T.AsyncGenerator[AgentResponse, None]:
"""处理流式响应
Args:
@@ -310,7 +289,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
AgentResponse 对象
"""
response_queue: queue.Queue[tuple[str, Any]] = queue.Queue()
response_queue = queue.Queue()
consumer_thread = threading.Thread(
target=self._consume_sync_generator,
args=(response, response_queue),
@@ -324,10 +303,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
while True:
try:
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
None,
response_queue.get,
True,
1,
None, response_queue.get, True, 1
)
except queue.Empty:
continue
@@ -335,10 +311,6 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
if item_type == "done":
break
elif item_type == "error":
if not isinstance(item_data, BaseException):
raise RuntimeError(
f"Unexpected Dashscope error payload: {item_data!r}",
)
raise item_data
elif item_type == "data":
chunk = item_data
@@ -347,14 +319,14 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
(
output_text,
chunk_doc_refs,
agent_response,
response,
) = await self._process_stream_chunk(chunk, output_text)
if agent_response:
if agent_response.type == "err":
yield agent_response
if response:
if response.type == "err":
yield response
return
yield agent_response
yield response
if chunk_doc_refs:
doc_references = chunk_doc_refs
@@ -380,12 +352,11 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
# 创建最终响应
chain = MessageChain(chain=[Comp.Plain(output_text)])
final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self.final_llm_resp = final_llm_resp
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, final_llm_resp)
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
@@ -405,14 +376,11 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
# 检查图片输入
if image_urls:
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容")
logger.warning("阿里云百炼暂不支持图片输入将自动忽略图片内容")
# 构建请求payload
payload = await self._build_request_payload(
prompt,
session_id,
contexts,
system_prompt,
prompt, session_id, contexts, system_prompt
)
if not self.streaming:

View File

@@ -1,28 +1,26 @@
import asyncio
import hashlib
import json
import sys
import typing as T
from collections import deque
from dataclasses import dataclass, field
from typing import Any, override
from uuid import uuid4
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core import sp
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.response import AgentResponse, AgentResponseData
from astrbot.core.agent.run_context import ContextWrapper, TContext
from astrbot.core.agent.runners.base import AgentState, BaseAgentRunner
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.provider import Provider
from astrbot.core.utils.config_number import coerce_int_config
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .constants import DEERFLOW_SESSION_PREFIX, DEERFLOW_THREAD_ID_KEY
from .deerflow_api_client import DeerFlowAPIClient
from .deerflow_content_mapper import (
@@ -43,12 +41,16 @@ from .deerflow_stream_utils import (
get_message_id,
)
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
"""DeerFlow Agent Runner via LangGraph HTTP API."""
_MAX_VALUES_HISTORY = 200
final_llm_resp: LLMResponse | None
@dataclass(frozen=True)
class _RunnerConfig:
@@ -129,9 +131,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
async def _finish_with_result(
self,
chain: MessageChain,
role: str,
self, chain: MessageChain, role: str
) -> AgentResponse:
self.final_llm_resp = LLMResponse(
role=role,
@@ -248,7 +248,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
await old_client.close()
except Exception as e:
logger.warning(
f"Failed to close previous DeerFlow API client cleanly: {e}",
f"Failed to close previous DeerFlow API client cleanly: {e}"
)
self.api_client = DeerFlowAPIClient(
@@ -262,32 +262,20 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
enforce_max_turns: int = -1,
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
truncate_turns: int = 1,
custom_token_counter: Any = None,
custom_compressor: Any = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
provider_config: dict | None = None,
**kwargs: Any,
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = streaming
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
await self._load_config_and_client(provider_config or {})
await self._load_config_and_client(provider_config)
@override
async def step(self):
@@ -316,7 +304,9 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
yield await self._finish_with_error(err_msg)
@override
async def step_until_done(self, max_step: int):
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
if max_step <= 0:
raise ValueError("max_step must be greater than 0")
@@ -328,7 +318,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
if not self.done():
raise RuntimeError(
f"DeerFlow agent reached max_step ({max_step}) without completion.",
f"DeerFlow agent reached max_step ({max_step}) without completion."
)
def _extract_new_messages_from_values(
@@ -393,7 +383,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
thread_id = thread.get("thread_id", "")
if not thread_id:
raise Exception(
f"DeerFlow create thread returned invalid payload: {thread}",
f"DeerFlow create thread returned invalid payload: {thread}"
)
await sp.put_async(
@@ -549,7 +539,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=MessageChain().message(delta)),
),
)
]
if delta_text:
@@ -559,9 +549,9 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(delta_text),
chain=MessageChain().message(delta_text)
),
),
)
]
return []
@@ -613,7 +603,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
self._update_text_and_maybe_stream(
state=state,
new_full_text=latest_text or None,
),
)
)
return responses
@@ -630,7 +620,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
self._update_text_and_maybe_stream(
state=state,
delta_text=delta,
),
)
)
maybe_clarification = extract_clarification_from_event_data(data)
@@ -747,7 +737,7 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
if event_type == "end":
break
except TimeoutError:
except (asyncio.TimeoutError, TimeoutError):
logger.warning(
"DeerFlow stream timed out after %ss for thread_id=%s; returning partial result.",
self.timeout,

View File

@@ -1,8 +1,7 @@
import codecs
import json
import types
from collections.abc import AsyncGenerator
from typing import Any, Self
from typing import Any
from aiohttp import ClientResponse, ClientSession, ClientTimeout
@@ -156,26 +155,26 @@ class DeerFlowAPIClient:
self._session = ClientSession(trust_env=True)
return self._session
async def __aenter__(self) -> Self:
async def __aenter__(self) -> "DeerFlowAPIClient":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: types.TracebackType | None,
tb: object | None,
) -> None:
await self.close()
async def create_thread(self, timeout: float = 20) -> dict[str, Any]:
session = self._get_session()
url = f"{self.api_base}/api/langgraph/threads"
payload: dict[str, dict[str, object]] = {"metadata": {}}
payload = {"metadata": {}}
async with session.post(
url,
json=payload,
headers=self.headers,
timeout=ClientTimeout(total=timeout),
timeout=timeout,
proxy=self.proxy,
) as resp:
if resp.status not in (200, 201):
@@ -218,8 +217,7 @@ class DeerFlowAPIClient:
input_payload = payload.get("input")
message_count = 0
if isinstance(input_payload, dict) and isinstance(
input_payload.get("messages"),
list,
input_payload.get("messages"), list
):
message_count = len(input_payload["messages"])
# Log only a minimal summary to avoid exposing sensitive user content.
@@ -292,7 +290,7 @@ class DeerFlowAPIClient:
return
logger.warning(
"DeerFlowAPIClient garbage collected with unclosed session; "
"explicit close() should be called by runner lifecycle (or `async with`).",
"explicit close() should be called by runner lifecycle (or `async with`)."
)
@property

View File

@@ -62,7 +62,7 @@ def build_user_content(prompt: str, image_urls: list[str]) -> Any:
if not is_likely_base64_image(url):
skipped_invalid_images += 1
logger.debug(
"Skipped DeerFlow image input because it is neither URL/data URI nor valid base64.",
"Skipped DeerFlow image input because it is neither URL/data URI nor valid base64."
)
continue
compact_base64 = url.replace("\n", "").replace("\r", "")
@@ -250,18 +250,14 @@ def append_components_from_content(
if "content" in content:
append_components_from_content(
content.get("content"),
components,
image_resolver,
content.get("content"), components, image_resolver
)
return
kwargs = content.get("kwargs")
if isinstance(kwargs, dict) and "content" in kwargs:
append_components_from_content(
kwargs.get("content"),
components,
image_resolver,
kwargs.get("content"), components, image_resolver
)

View File

@@ -1,23 +1,25 @@
import base64
import os
from typing import Any, override
import sys
import typing as T
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.response import AgentResponseData
from astrbot.core.agent.run_context import ContextWrapper, TContext
from astrbot.core.agent.runners.base import AgentResponse, AgentState, BaseAgentRunner
from astrbot.core.agent.runners.dify.dify_api_client import DifyAPIClient
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.provider.provider import Provider
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.media_utils import MediaResolver
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .dify_api_client import DifyAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DifyAgentRunner(BaseAgentRunner[TContext]):
@@ -26,32 +28,19 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
enforce_max_turns: int = -1,
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
truncate_turns: int = 1,
custom_token_counter: Any = None,
custom_compressor: Any = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
provider_config: dict | None = None,
**kwargs: Any,
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = streaming
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
provider_config = provider_config or {}
self.api_key = provider_config.get("dify_api_key", "")
self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_type = provider_config.get("dify_api_type", "chat")
@@ -72,7 +61,9 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
@override
async def step(self):
"""执行 Dify Agent 的一个步骤"""
"""
执行 Dify Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
@@ -82,7 +73,7 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
# 开始处理转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
@@ -90,27 +81,64 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
async for response in self._execute_dify_request():
yield response
except Exception as e:
logger.error(f"Dify 请求失败:{e!s}")
logger.error(f"Dify 请求失败{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err",
completion_text=f"Dify 请求失败:{e!s}",
role="err", completion_text=f"Dify 请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Dify 请求失败:{e!s}"),
chain=MessageChain().message(f"Dify 请求失败{str(e)}")
),
)
finally:
await self.api_client.close()
@override
async def step_until_done(self, max_step: int):
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
async def _upload_image_for_dify(
self,
image_url: str,
session_id: str,
) -> dict[str, str] | None:
image_data = await MediaResolver(
image_url,
media_type="image",
).to_base64_data(strict=True)
if image_data is None:
logger.warning("Dify 图片预处理结果为空,将忽略。")
return None
image_extension = image_data.mime_type.split("/", 1)[-1] or "png"
if image_extension == "jpeg":
image_extension = "jpg"
file_response = await self.api_client.file_upload(
file_data=image_data.to_bytes(),
user=session_id,
mime_type=image_data.mime_type,
file_name=f"image.{image_extension}",
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
return None
return {
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
}
async def _execute_dify_request(self):
"""执行 Dify 请求的核心逻辑"""
prompt = self.req.prompt or ""
@@ -129,43 +157,22 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
# 处理图片上传
files_payload = []
for image_url in image_urls:
# image_url is a base64 string
try:
image_data = base64.b64decode(image_url)
file_response = await self.api_client.file_upload(
file_data=image_data,
user=session_id,
mime_type="image/png",
file_name="image.png",
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。",
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
},
)
image_payload = await self._upload_image_for_dify(image_url, session_id)
except Exception as e:
logger.warning(f"上传图片失败:{e}")
logger.warning(f"上传图片失败{e}")
continue
if image_payload:
files_payload.append(image_payload)
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_var: dict = (
await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
or {}
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
@@ -174,7 +181,7 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
match self.api_type:
case "chat" | "agent" | "chatflow":
if not prompt:
prompt = "请描述这张图片"
prompt = "请描述这张图片"
async for chunk in self.api_client.chat_messages(
inputs={
@@ -182,9 +189,9 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
},
query=prompt,
user=session_id,
conversation_id=conversation_id or "",
conversation_id=conversation_id,
files=files_payload,
request_timeout=self.timeout,
timeout=self.timeout,
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk["event"] == "message" or chunk["event"] == "agent_message":
@@ -198,21 +205,21 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
)
conversation_id = chunk["conversation_id"]
# 如果是流式响应,发送增量数据
# 如果是流式响应发送增量数据
if self.streaming and chunk["answer"]:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(chunk["answer"]),
chain=MessageChain().message(chunk["answer"])
),
)
elif chunk["event"] == "message_end":
logger.debug("Dify message end")
break
elif chunk["event"] == "error":
logger.error(f"Dify 出现错误:{chunk}")
logger.error(f"Dify 出现错误{chunk}")
raise Exception(
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}",
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
)
case "workflow":
@@ -224,17 +231,17 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
},
user=session_id,
files=files_payload,
request_timeout=self.timeout,
timeout=self.timeout,
):
logger.debug(f"dify workflow resp chunk: {chunk}")
match chunk["event"]:
case "workflow_started":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行",
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行"
)
case "node_finished":
logger.debug(
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束",
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束"
)
case "text_chunk":
if self.streaming and chunk["data"]["text"]:
@@ -242,32 +249,32 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(
chunk["data"]["text"],
),
chunk["data"]["text"]
)
),
)
case "workflow_finished":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束",
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
)
logger.debug(f"Dify 工作流结果:{chunk}")
logger.debug(f"Dify 工作流结果{chunk}")
if chunk["data"]["error"]:
logger.error(
f"Dify 工作流出现错误:{chunk['data']['error']}",
f"Dify 工作流出现错误{chunk['data']['error']}"
)
raise Exception(
f"Dify 工作流出现错误:{chunk['data']['error']}",
f"Dify 工作流出现错误{chunk['data']['error']}"
)
if self.workflow_output_key not in chunk["data"]["outputs"]:
raise Exception(
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}",
f"Dify 工作流的输出不包含指定的键名{self.workflow_output_key}"
)
result = chunk
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
raise Exception(f"未知的 Dify API 类型{self.api_type}")
if not result:
logger.warning("Dify 请求结果为空,请查看 Debug 日志")
logger.warning("Dify 请求结果为空请查看 Debug 日志")
# 解析结果
chain = await self.parse_dify_result(result)
@@ -293,23 +300,24 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict) -> Comp.BaseMessageComponent:
async def parse_file(item: dict):
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# 仅支持 wav
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"dify_{item['filename']}.wav")
await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"])
audio_path = await MediaResolver(
item["url"],
media_type="audio",
default_suffix=".wav",
).to_path(target_format="wav")
return Comp.Record(file=audio_path, url=audio_path)
case "video":
return Comp.Video(file=item["url"])
case _:
return Comp.File(name=item["filename"], file=item["url"])
output = chunk["data"]["outputs"][self.workflow_output_key]
chains: list[Comp.BaseMessageComponent] = []
chains = []
if isinstance(output, str):
# 纯文本输出
chains.append(Comp.Plain(output))

View File

@@ -3,8 +3,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any
import anyio
from aiohttp import ClientResponse, ClientSession, ClientTimeout, FormData
from aiohttp import ClientResponse, ClientSession, FormData
from astrbot.core import logger
@@ -36,74 +35,66 @@ class DifyAPIClient:
self.api_key = api_key
self.api_base = api_base
self.session = ClientSession(trust_env=True)
self.headers: dict[str, str] = {
self.headers = {
"Authorization": f"Bearer {self.api_key}",
}
async def chat_messages(
self,
inputs: dict[str, object],
inputs: dict,
query: str,
user: str,
response_mode: str = "streaming",
conversation_id: str = "",
files: list[dict[str, object]] | None = None,
request_timeout: float = 60,
files: list[dict[str, Any]] | None = None,
timeout: float = 60,
) -> AsyncGenerator[dict[str, Any], None]:
if files is None:
files = []
url = f"{self.api_base}/chat-messages"
payload: dict[str, object] = {
"inputs": inputs,
"query": query,
"user": user,
"response_mode": response_mode,
"conversation_id": conversation_id,
"files": files,
}
payload = locals()
payload.pop("self")
payload.pop("timeout")
logger.info(f"chat_messages payload: {payload}")
async with self.session.post(
url,
json=payload,
headers=self.headers,
timeout=ClientTimeout(total=request_timeout),
timeout=timeout,
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(
f"Dify /chat-messages 接口请求失败:{resp.status}. {text}",
f"Dify /chat-messages 接口请求失败{resp.status}. {text}",
)
async for event in _stream_sse(resp):
yield event
async def workflow_run(
self,
inputs: dict[str, object],
inputs: dict,
user: str,
response_mode: str = "streaming",
files: list[dict[str, object]] | None = None,
request_timeout: float = 60,
files: list[dict[str, Any]] | None = None,
timeout: float = 60,
):
if files is None:
files = []
url = f"{self.api_base}/workflows/run"
payload: dict[str, object] = {
"inputs": inputs,
"user": user,
"response_mode": response_mode,
"files": files,
}
payload = locals()
payload.pop("self")
payload.pop("timeout")
logger.info(f"workflow_run payload: {payload}")
async with self.session.post(
url,
json=payload,
headers=self.headers,
timeout=ClientTimeout(total=request_timeout),
timeout=timeout,
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(
f"Dify /workflows/run 接口请求失败:{resp.status}. {text}",
f"Dify /workflows/run 接口请求失败{resp.status}. {text}",
)
async for event in _stream_sse(resp):
yield event
@@ -123,10 +114,8 @@ class DifyAPIClient:
file_path: The path to the file to upload.
file_data: The file data in bytes.
file_name: Optional file name when using file_data.
Returns:
A dictionary containing the uploaded file information.
"""
url = f"{self.api_base}/files/upload"
@@ -145,8 +134,8 @@ class DifyAPIClient:
# 使用文件路径
import os
async with await anyio.open_file(file_path, "rb") as f:
file_content = await f.read()
with open(file_path, "rb") as f:
file_content = f.read()
form.add_field(
"file",
file_content,
@@ -159,11 +148,11 @@ class DifyAPIClient:
async with self.session.post(
url,
data=form,
headers=self.headers, # 不包含 Content-Type,让 aiohttp 自动设置
headers=self.headers, # 不包含 Content-Type让 aiohttp 自动设置
) as resp:
if resp.status != 200 and resp.status != 201:
text = await resp.text()
raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
raise Exception(f"Dify 文件上传失败{resp.status}. {text}")
return await resp.json() # {"id": "xxx", ...}
async def close(self) -> None:
@@ -172,11 +161,11 @@ class DifyAPIClient:
async def get_chat_convs(self, user: str, limit: int = 20):
# conversations. GET
url = f"{self.api_base}/conversations"
params: dict[str, str | int] = {
payload = {
"user": user,
"limit": limit,
}
async with self.session.get(url, params=params, headers=self.headers) as resp:
async with self.session.get(url, params=payload, headers=self.headers) as resp:
return await resp.json()
async def delete_chat_conv(self, user: str, conversation_id: str):

View File

@@ -1,5 +1,6 @@
import asyncio
import copy
import sys
import time
import traceback
import typing as T
@@ -7,7 +8,6 @@ import uuid
from contextlib import suppress
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import override
from mcp.types import (
BlobResourceContents,
@@ -63,6 +63,11 @@ from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@dataclass(slots=True)
class _HandleFunctionToolsResult:

View File

@@ -1,6 +1,6 @@
import copy
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic, TypedDict
from typing import Any, Generic
import jsonschema
import mcp
@@ -16,12 +16,6 @@ ParametersType = dict[str, Any]
ToolExecResult = str | mcp.types.CallToolResult
class ToolArgumentSpec(TypedDict):
name: str
type: str
description: str
@dataclass
class ToolSchema:
"""A class representing the schema of a tool for function calling."""
@@ -32,20 +26,14 @@ class ToolSchema:
description: str
"""The description of the tool."""
parameters: ParametersType | None = None
"""The parameters of the tool, in JSON Schema format."""
active: bool = True
"""Whether the tool is active."""
parameters: ParametersType
"""The parameters of the tool, in JSON Schema format."""
@model_validator(mode="after")
def validate_parameters(self) -> "ToolSchema":
if self.parameters is not None:
jsonschema.validate(
self.parameters,
jsonschema.Draft202012Validator.META_SCHEMA,
)
jsonschema.validate(
self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
)
return self
@@ -75,23 +63,14 @@ class FunctionTool(ToolSchema, Generic[TContext]):
Declare this tool as a background task. Background tasks return immediately
with a task identifier while the real work continues asynchronously.
"""
source: str = "plugin"
"""
Origin of this tool: 'plugin' (from star plugins), 'internal' (AstrBot built-in),
or 'mcp' (from MCP servers). Used by WebUI for display grouping.
"""
def __repr__(self) -> str:
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(
self,
context: ContextWrapper[TContext],
**kwargs: Any,
) -> ToolExecResult:
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
"""Run the tool with the given arguments. The handler field has priority."""
raise NotImplementedError(
"FunctionTool.call() must be implemented by subclasses or set a handler.",
"FunctionTool.call() must be implemented by subclasses or set a handler."
)
@@ -103,13 +82,13 @@ class ToolSet:
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
"""
tools: list[ToolSchema] = Field(default_factory=list)
tools: list[FunctionTool] = Field(default_factory=list)
def empty(self) -> bool:
"""Check if the tool set is empty."""
return len(self.tools) == 0
def add_tool(self, tool: ToolSchema) -> None:
def add_tool(self, tool: FunctionTool) -> None:
"""Add a tool to the set.
If a tool with the same name already exists:
@@ -132,26 +111,16 @@ class ToolSet:
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
def normalize(self) -> None:
"""Sort tools by name for deterministic serialization.
This ensures the serialized tool schema sent to the LLM is
identical across requests regardless of registration/injection
order, enabling LLM provider prefix cache hits.
"""
self.tools.sort(key=lambda t: t.name)
def get_tool(self, name: str) -> FunctionTool | None:
"""Get a tool by its name."""
for tool in self.tools:
if tool.name == name:
if isinstance(tool, FunctionTool):
return tool
return tool
return None
def get_light_tool_set(self) -> "ToolSet":
"""Return a light tool set with only name/description."""
light_tools: list[ToolSchema] = []
light_tools = []
for tool in self.tools:
if hasattr(tool, "active") and not tool.active:
continue
@@ -162,16 +131,16 @@ class ToolSet:
light_tools.append(
FunctionTool(
name=tool.name,
description=tool.description,
parameters=light_params,
description=tool.description,
handler=None,
),
)
)
return ToolSet(light_tools)
def get_param_only_tool_set(self) -> "ToolSet":
"""Return a tool set with name/parameters only (no description)."""
param_tools: list[ToolSchema] = []
param_tools = []
for tool in self.tools:
if hasattr(tool, "active") and not tool.active:
continue
@@ -183,10 +152,10 @@ class ToolSet:
param_tools.append(
FunctionTool(
name=tool.name,
description="",
parameters=params,
description="",
handler=None,
),
)
)
return ToolSet(param_tools)
@@ -194,18 +163,17 @@ class ToolSet:
def add_func(
self,
name: str,
func_args: list[ToolArgumentSpec],
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
) -> None:
"""Add a function tool to the set."""
properties: dict[str, dict[str, str]] = {}
params = {
"type": "object", # hard-coded here
"properties": properties,
"properties": {},
}
for param in func_args:
properties[param["name"]] = {
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
@@ -230,28 +198,22 @@ class ToolSet:
@property
def func_list(self) -> list[FunctionTool]:
"""Get the list of function tools."""
return [t for t in self.tools if isinstance(t, FunctionTool)]
def list_tools(self) -> list[FunctionTool]:
"""Get the list of function tools (alias for func_list)."""
return [t for t in self.tools if isinstance(t, FunctionTool)]
return self.tools
def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
"""Convert tools to OpenAI API function calling schema format."""
result = []
for tool in self.tools:
function_dict: dict[str, Any] = {"name": tool.name}
func_def = {"type": "function", "function": {"name": tool.name}}
if tool.description:
function_dict["description"] = tool.description
func_def["function"]["description"] = tool.description
if tool.parameters is not None:
if (
tool.parameters and tool.parameters.get("properties")
) or not omit_empty_parameter_field:
function_dict["parameters"] = tool.parameters
func_def: dict[str, Any] = {
"type": "function",
"function": function_dict,
}
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
return result

View File

@@ -1,4 +1,3 @@
import abc
from collections.abc import AsyncGenerator
from typing import Any, Generic
@@ -8,9 +7,8 @@ from .run_context import ContextWrapper, TContext
from .tool import FunctionTool
class BaseFunctionToolExecutor(abc.ABC, Generic[TContext]):
class BaseFunctionToolExecutor(Generic[TContext]):
@classmethod
@abc.abstractmethod
async def execute(
cls,
tool: FunctionTool,

View File

@@ -7,7 +7,7 @@ import base64
import os
import time
from dataclasses import dataclass, field
from typing import ClassVar, Self
from typing import ClassVar
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
@@ -35,20 +35,16 @@ class ToolImageCache:
Images are stored in data/temp/tool_images/ and can be retrieved by file path.
"""
_instance: ClassVar[Self | None] = None
_instance: ClassVar["ToolImageCache | None"] = None
CACHE_DIR_NAME: ClassVar[str] = "tool_images"
# Cache expiry time in seconds (1 hour)
CACHE_EXPIRY: ClassVar[int] = 3600
_initialized: bool
_cache_dir: str
def __new__(cls) -> Self:
instance = cls._instance
if instance is None:
instance = super().__new__(cls)
instance._initialized = False
cls._instance = instance
return instance
def __new__(cls) -> "ToolImageCache":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if self._initialized:
@@ -89,7 +85,6 @@ class ToolImageCache:
Returns:
CachedImage object with file path.
"""
ext = self._get_file_extension(mime_type)
file_name = f"{tool_call_id}_{index}{ext}"
@@ -113,9 +108,7 @@ class ToolImageCache:
)
def get_image_base64_by_path(
self,
file_path: str,
mime_type: str = "image/png",
self, file_path: str, mime_type: str = "image/png"
) -> tuple[str, str] | None:
"""Read an image file and return its base64 encoded data.
@@ -125,7 +118,6 @@ class ToolImageCache:
Returns:
Tuple of (base64_data, mime_type) if found, None otherwise.
"""
if not os.path.exists(file_path):
return None
@@ -144,7 +136,6 @@ class ToolImageCache:
Returns:
Number of images cleaned up.
"""
now = time.time()
cleaned = 0

View File

@@ -1,5 +1,3 @@
from typing import ClassVar
from pydantic import Field
from pydantic.dataclasses import dataclass
@@ -10,7 +8,7 @@ from astrbot.core.star.context import Context
@dataclass
class AstrAgentContext:
__pydantic_config__: ClassVar[dict[str, bool]] = {"arbitrary_types_allowed": True}
__pydantic_config__ = {"arbitrary_types_allowed": True}
context: Context
"""The star context instance"""

View File

@@ -10,31 +10,22 @@ from astrbot.core.pipeline.context_utils import call_event_hook
from astrbot.core.star.star_handler import EventType
def _sdk_safe_payload(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, list):
return [_sdk_safe_payload(item) for item in value]
if isinstance(value, dict):
return {str(key): _sdk_safe_payload(item) for key, item in value.items()}
model_dump = getattr(value, "model_dump", None)
if callable(model_dump):
try:
dumped = model_dump()
except Exception:
return str(value)
return _sdk_safe_payload(dumped)
return str(value)
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_begin(
self, run_context: ContextWrapper[AstrAgentContext]
) -> None:
await call_event_hook(
run_context.context.event,
EventType.OnAgentBeginEvent,
run_context,
)
async def on_agent_done(self, run_context, llm_response) -> None:
# 执行事件钩子
if llm_response and llm_response.reasoning_content:
# we will use this in result_decorate stage to inject reasoning content to chain
run_context.context.event.set_extra(
"_llm_reasoning_content",
llm_response.reasoning_content,
"_llm_reasoning_content", llm_response.reasoning_content
)
await call_event_hook(
@@ -42,32 +33,12 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
EventType.OnLLMResponseEvent,
llm_response,
)
sdk_plugin_bridge = getattr(
run_context.context.context,
"sdk_plugin_bridge",
None,
await call_event_hook(
run_context.context.event,
EventType.OnAgentDoneEvent,
run_context,
llm_response,
)
if sdk_plugin_bridge is not None:
try:
await sdk_plugin_bridge.dispatch_message_event(
"llm_response",
run_context.context.event,
{
"completion_text": (
llm_response.completion_text if llm_response else ""
),
"tool_call_names": (
list(llm_response.tools_call_name)
if llm_response and llm_response.tools_call_name
else []
),
},
llm_response=llm_response,
)
except Exception as exc:
from astrbot.core import logger
logger.warning("SDK llm_response dispatch failed: %s", exc)
async def on_tool_start(
self,
@@ -81,25 +52,6 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
tool,
tool_args,
)
sdk_plugin_bridge = getattr(
run_context.context.context,
"sdk_plugin_bridge",
None,
)
if sdk_plugin_bridge is not None:
try:
await sdk_plugin_bridge.dispatch_message_event(
"using_llm_tool",
run_context.context.event,
{
"tool_name": tool.name,
"tool_args": _sdk_safe_payload(tool_args),
},
)
except Exception as exc:
from astrbot.core import logger
logger.warning("SDK using_llm_tool dispatch failed: %s", exc)
async def on_tool_end(
self,
@@ -116,26 +68,6 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
tool_args,
tool_result,
)
sdk_plugin_bridge = getattr(
run_context.context.context,
"sdk_plugin_bridge",
None,
)
if sdk_plugin_bridge is not None:
try:
await sdk_plugin_bridge.dispatch_message_event(
"llm_tool_respond",
run_context.context.event,
{
"tool_name": tool.name,
"tool_args": _sdk_safe_payload(tool_args),
"tool_result": _sdk_safe_payload(tool_result),
},
)
except Exception as exc:
from astrbot.core import logger
logger.warning("SDK llm_tool_respond dispatch failed: %s", exc)
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):

View File

@@ -3,8 +3,7 @@ import re
import time
import traceback
from collections.abc import AsyncGenerator
import anyio
from typing import Any
from astrbot.core import logger
from astrbot.core.agent.message import Message
@@ -49,8 +48,7 @@ def _extract_chain_json_data(msg_chain: MessageChain) -> dict | None:
def _record_tool_call_name(
tool_info: dict | None,
tool_name_by_call_id: dict[str, str],
tool_info: dict | None, tool_name_by_call_id: dict[str, str]
) -> None:
if not isinstance(tool_info, dict):
return
@@ -68,8 +66,7 @@ def _build_tool_call_status_message(tool_info: dict | None) -> str:
def _build_tool_result_status_message(
msg_chain: MessageChain,
tool_name_by_call_id: dict[str, str],
msg_chain: MessageChain, tool_name_by_call_id: dict[str, str]
) -> str:
tool_name = "unknown"
tool_result = ""
@@ -118,7 +115,7 @@ def _merge_buffered_llm_chains(
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 3,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
stream_to_general: bool = False,
@@ -139,7 +136,7 @@ async def run_agent(
if step_idx == max_step + 1:
logger.warning(
f"Agent reached max steps ({max_step}), forcing a final response.",
f"Agent reached max steps ({max_step}), forcing a final response."
)
if not agent_runner.done():
# 拔掉所有工具
@@ -149,8 +146,8 @@ async def run_agent(
agent_runner.run_context.messages.append(
Message(
role="user",
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户",
),
content="工具调用次数已达到上限请停止使用工具并根据已经收集到的信息对你的任务和发现进行总结然后直接回复用户",
)
)
stop_watcher = asyncio.create_task(
@@ -192,7 +189,7 @@ async def run_agent(
astr_event.trace.record(
"agent_tool_result",
tool_result=msg_chain.get_plain_text(
with_other_comps_mark=True,
with_other_comps_mark=True
),
)
@@ -204,13 +201,12 @@ async def run_agent(
await astr_event.send(msg_chain)
elif show_tool_use and show_tool_call_result:
status_msg = _build_tool_result_status_message(
msg_chain,
tool_name_by_call_id,
msg_chain, tool_name_by_call_id
)
await astr_event.send(
MessageChain(type="tool_call").message(status_msg),
MessageChain(type="tool_call").message(status_msg)
)
# 对于其他情况,暂时先不处理
# 对于其他情况暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming and show_tool_use:
@@ -225,7 +221,7 @@ async def run_agent(
tool_info = _extract_chain_json_data(resp.data["chain"])
astr_event.trace.record(
"agent_tool_call",
tool_name=tool_info or "unknown",
tool_name=tool_info if tool_info else "unknown",
)
_record_tool_call_name(tool_info, tool_name_by_call_id)
@@ -236,7 +232,7 @@ async def run_agent(
# Delay tool status notification until tool_call_result.
continue
chain = MessageChain(type="tool_call").message(
_build_tool_call_status_message(tool_info),
_build_tool_call_status_message(tool_info)
)
await astr_event.send(chain)
continue
@@ -300,7 +296,7 @@ async def run_agent(
MessageChain(
type="agent_stats",
chain=[Json(data=agent_runner.stats.to_dict())],
),
)
)
break
@@ -315,7 +311,7 @@ async def run_agent(
logger.error(traceback.format_exc())
custom_error_message = extract_persona_custom_error_message_from_event(
astr_event,
astr_event
)
if custom_error_message:
err_msg = custom_error_message
@@ -323,7 +319,7 @@ async def run_agent(
err_msg = (
f"Error occurred during AI execution.\n"
f"Error Type: {type(e).__name__}\n"
f"Error Message: {e!s}"
f"Error Message: {str(e)}"
)
error_llm_response = LLMResponse(
@@ -332,8 +328,7 @@ async def run_agent(
)
try:
await agent_runner.agent_hooks.on_agent_done(
agent_runner.run_context,
error_llm_response,
agent_runner.run_context, error_llm_response
)
except Exception:
logger.exception("Error in on_agent_done hook")
@@ -356,13 +351,13 @@ async def _watch_agent_stop_signal(agent_runner: AgentRunner, astr_event) -> Non
async def run_live_agent(
agent_runner: AgentRunner,
tts_provider: TTSProvider | None = None,
max_step: int = 3,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
show_reasoning: bool = False,
buffer_intermediate_messages: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""Live Mode 的 Agent 运行器,支持流式 TTS
"""Live Mode 的 Agent 运行器支持流式 TTS
Args:
agent_runner: Agent 运行器
@@ -374,9 +369,8 @@ async def run_live_agent(
Yields:
MessageChain: 包含文本或音频数据的消息链
"""
# 如果没有 TTS Provider,直接发送文本
# 如果没有 TTS Provider直接发送文本
if not tts_provider:
async for chain in run_agent(
agent_runner,
@@ -392,11 +386,11 @@ async def run_live_agent(
support_stream = tts_provider.support_stream()
if support_stream:
logger.info("[Live Agent] 使用流式 TTS(原生支持 get_audio_stream)")
logger.info("[Live Agent] 使用流式 TTS原生支持 get_audio_stream")
else:
logger.info(
f"[Live Agent] 使用 TTS({tts_provider.meta().type} "
"使用 get_audio,将按句子分块生成音频)",
f"[Live Agent] 使用 TTS{tts_provider.meta().type} "
"使用 get_audio将按句子分块生成音频"
)
# 统计数据初始化
@@ -409,7 +403,7 @@ async def run_live_agent(
# audio_queue stored bytes or (text, bytes)
audio_queue: asyncio.Queue[bytes | tuple[str, bytes] | None] = asyncio.Queue()
# 1. 启动 Agent Feeder 任务:负责运行 Agent 并将文本分句喂给 text_queue
# 1. 启动 Agent Feeder 任务负责运行 Agent 并将文本分句喂给 text_queue
feeder_task = asyncio.create_task(
_run_agent_feeder(
agent_runner,
@@ -419,20 +413,25 @@ async def run_live_agent(
show_tool_call_result,
show_reasoning,
buffer_intermediate_messages,
),
)
)
# 2. 启动 TTS 任务:负责从 text_queue 读取文本并生成音频到 audio_queue
# 2. 启动 TTS 任务负责从 text_queue 读取文本并生成音频到 audio_queue
if support_stream:
tts_task = asyncio.create_task(
_safe_tts_stream_wrapper(tts_provider, text_queue, audio_queue),
_safe_tts_stream_wrapper(tts_provider, text_queue, audio_queue)
)
else:
tts_task = asyncio.create_task(
_simulated_stream_tts(tts_provider, text_queue, audio_queue),
_simulated_stream_tts(
tts_provider,
text_queue,
audio_queue,
agent_runner.run_context.context.event,
)
)
# 3. 主循环:从 audio_queue 读取音频并 yield
# 3. 主循环从 audio_queue 读取音频并 yield
try:
while True:
queue_item = await audio_queue.get()
@@ -447,7 +446,7 @@ async def run_live_agent(
audio_data = queue_item
if not first_chunk_received:
# 记录首帧延迟(从开始处理到收到第一个音频块)
# 记录首帧延迟从开始处理到收到第一个音频块
tts_first_frame_time = time.time() - tts_start_time
first_chunk_received = True
@@ -471,6 +470,7 @@ async def run_live_agent(
tts_task.cancel()
# 确保队列被消费
pass
tts_end_time = time.time()
@@ -489,10 +489,10 @@ async def run_live_agent(
"tts_first_frame_time": tts_first_frame_time,
"tts": tts_provider.meta().type,
"chat_model": agent_runner.provider.get_model(),
},
),
}
)
],
),
)
)
except Exception as e:
logger.error(f"发送 TTS 统计信息失败: {e}")
@@ -527,9 +527,9 @@ async def _run_agent_feeder(
if text:
buffer += text
# 分句逻辑:匹配标点符号
# r"([.。!!??\n]+)" 会保留分隔符
parts = re.split(r"([.。!!??\n]+)", buffer)
# 分句逻辑匹配标点符号
# r"([.。!?\n]+)" 会保留分隔符
parts = re.split(r"([.。!?\n]+)", buffer)
if len(parts) > 1:
# 处理完整的句子
@@ -579,8 +579,18 @@ async def _simulated_stream_tts(
tts_provider: TTSProvider,
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
astr_event: Any,
) -> None:
"""模拟流式 TTS 分句生成音频"""
"""模拟流式 TTS 分句生成音频.
Args:
tts_provider: Provider used to synthesize audio files.
text_queue: Text chunks to synthesize. ``None`` ends the worker.
audio_queue: Synthesized audio bytes output queue.
astr_event: Current event used to cleanup generated TTS files after the
event finishes.
"""
try:
while True:
text = await text_queue.get()
@@ -591,12 +601,13 @@ async def _simulated_stream_tts(
audio_path = await tts_provider.get_audio(text)
if audio_path:
async with await anyio.open_file(audio_path, "rb") as f:
audio_data = await f.read()
with open(audio_path, "rb") as f:
audio_data = f.read()
astr_event.track_temporary_local_file(audio_path)
await audio_queue.put((text, audio_data))
except Exception as e:
logger.error(
f"[Live TTS Simulated] Error processing text '{text[:20]}...': {e}",
f"[Live TTS Simulated] Error processing text '{text[:20]}...': {e}"
)
# 继续处理下一句

View File

@@ -2,10 +2,10 @@ import asyncio
import inspect
import json
import traceback
import typing as T
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
from collections.abc import Sequence
from collections.abc import Set as AbstractSet
from typing import Any
import mcp
@@ -14,13 +14,11 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.message import Message
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool, ToolSchema, ToolSet
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.astr_main_agent_resources import (
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
BACKGROUND_TASK_WOKE_USER_PROMPT,
CONVERSATION_HISTORY_INJECT_PREFIX,
)
from astrbot.core.cron.events import CronMessageEvent
from astrbot.core.message.components import Image
@@ -32,7 +30,21 @@ from astrbot.core.message.message_event_result import (
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.tools.send_message import SEND_MESSAGE_TO_USER_TOOL
from astrbot.core.tools.computer_tools import (
CuaKeyboardTypeTool,
CuaMouseClickTool,
CuaScreenshotTool,
ExecuteShellTool,
FileDownloadTool,
FileEditTool,
FileReadTool,
FileUploadTool,
FileWriteTool,
GrepTool,
LocalPythonTool,
PythonTool,
)
from astrbot.core.tools.message_tools import SendMessageToUserTool
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.history_saver import persist_agent_history
from astrbot.core.utils.image_ref_utils import is_supported_image_ref
@@ -41,15 +53,18 @@ from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
def _collect_image_urls_from_args(cls, image_urls_raw: Any) -> list[str]:
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
if image_urls_raw is None:
return []
if isinstance(image_urls_raw, str):
return [image_urls_raw]
if isinstance(image_urls_raw, (Sequence, AbstractSet)) and (
not isinstance(image_urls_raw, (str, bytes, bytearray))
if isinstance(image_urls_raw, (Sequence, AbstractSet)) and not isinstance(
image_urls_raw, (str, bytes, bytearray)
):
return [item for item in image_urls_raw if isinstance(item, str)]
logger.debug(
"Unsupported image_urls type in handoff tool args: %s",
type(image_urls_raw).__name__,
@@ -58,8 +73,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def _collect_image_urls_from_message(
cls,
run_context: ContextWrapper[AstrAgentContext],
cls, run_context: ContextWrapper[AstrAgentContext]
) -> list[str]:
urls: list[str] = []
event = getattr(run_context.context, "event", None)
@@ -86,11 +100,12 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
async def _collect_handoff_image_urls(
cls,
run_context: ContextWrapper[AstrAgentContext],
image_urls_raw: Any,
image_urls_raw: T.Any,
) -> list[str]:
candidates: list[str] = []
candidates.extend(cls._collect_image_urls_from_args(image_urls_raw))
candidates.extend(await cls._collect_image_urls_from_message(run_context))
normalized = normalize_and_dedupe_strings(candidates)
extensionless_local_roots = (get_astrbot_temp_path(),)
sanitized = [
@@ -112,13 +127,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用
"""执行函数调用
Args:
tool: The tool to execute.
run_context: The run context.
**tool_args: Tool-specific arguments.
**kwargs: 函数调用的参数。
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
@@ -128,19 +141,19 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
is_bg = tool_args.pop("background_task", False)
if is_bg:
async for r in cls._execute_handoff_background(
tool,
run_context,
**tool_args,
tool, run_context, **tool_args
):
yield r
return
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
elif isinstance(tool, MCPTool):
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
elif tool.is_background_task:
task_id = uuid.uuid4().hex
@@ -152,7 +165,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
task_id=task_id,
**tool_args,
)
except Exception as e:
except Exception as e: # noqa: BLE001
logger.error(
f"Background task {task_id} failed: {e!s}",
exc_info=True,
@@ -164,99 +177,68 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
text=f"Background task submitted. task_id={task_id}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
else:
rejection = cls._check_sandbox_capability(tool, run_context)
if rejection is not None:
yield rejection
return
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
_BROWSER_TOOL_NAMES: frozenset[str] = frozenset(
{
"astrbot_execute_browser",
"astrbot_execute_browser_batch",
"astrbot_run_browser_skill",
},
)
@classmethod
def _check_sandbox_capability(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
) -> mcp.types.CallToolResult | None:
"""Return a rejection result if the tool requires a sandbox capability
that is not available, or None if the tool may proceed.
"""
if tool.name not in cls._BROWSER_TOOL_NAMES:
return None
from astrbot.core.computer.computer_client import get_sandbox_capabilities
session_id = run_context.context.event.unified_msg_origin
caps = get_sandbox_capabilities(session_id)
if caps is None:
return None
if "browser" not in caps:
msg = f"Tool '{tool.name}' requires browser capability, but the current sandbox profile does not include it (capabilities: {list(caps)}). Please ask the administrator to switch to a sandbox profile with browser support, or use shell/python tools instead."
logger.warning(
"[ToolExec] capability_rejected tool=%s caps=%s",
tool.name,
list(caps),
)
return mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=msg)],
isError=True,
)
return None
@classmethod
def _get_runtime_computer_tools(
cls,
runtime: str,
tool_mgr: Any = None,
tool_mgr,
booter: str | None = None,
session_id: str = "",
sandbox_cfg: dict | None = None,
) -> dict[str, ToolSchema]:
"""Get computer runtime tools via ComputerToolProvider.
Delegates tool discovery to ComputerToolProvider for decoupled
sandbox / local tool injection. The *tool_mgr* parameter is kept
for backward compatibility but is no longer used.
Args:
runtime: ``'sandbox'``, ``'local'``, or ``'none'``.
tool_mgr: Kept for backward compatibility (unused).
booter: Short-form booter type (e.g. ``'shipyard_neo'``).
session_id: Session identifier.
sandbox_cfg: Full sandbox configuration dict (preferred over
*booter* when both are provided).
Returns:
Dict mapping tool name to FunctionTool instance.
"""
from astrbot.core.computer.computer_tool_provider import (
ComputerToolProvider,
)
from astrbot.core.tool_provider import ToolProviderContext
cfg: dict = {}
if sandbox_cfg is not None:
cfg = sandbox_cfg
elif booter:
cfg["booter"] = booter
ctx = ToolProviderContext(
computer_use_runtime=runtime,
sandbox_cfg=cfg,
session_id=session_id,
)
tools = ComputerToolProvider().get_tools(ctx)
return {t.name: t for t in tools}
) -> dict[str, FunctionTool]:
booter = "" if booter is None else str(booter).lower()
if runtime == "sandbox":
shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool)
python_tool = tool_mgr.get_builtin_tool(PythonTool)
upload_tool = tool_mgr.get_builtin_tool(FileUploadTool)
download_tool = tool_mgr.get_builtin_tool(FileDownloadTool)
read_tool = tool_mgr.get_builtin_tool(FileReadTool)
write_tool = tool_mgr.get_builtin_tool(FileWriteTool)
edit_tool = tool_mgr.get_builtin_tool(FileEditTool)
grep_tool = tool_mgr.get_builtin_tool(GrepTool)
tools = {
shell_tool.name: shell_tool,
python_tool.name: python_tool,
upload_tool.name: upload_tool,
download_tool.name: download_tool,
read_tool.name: read_tool,
write_tool.name: write_tool,
edit_tool.name: edit_tool,
grep_tool.name: grep_tool,
}
if booter == "cua":
screenshot_tool = tool_mgr.get_builtin_tool(CuaScreenshotTool)
mouse_click_tool = tool_mgr.get_builtin_tool(CuaMouseClickTool)
keyboard_type_tool = tool_mgr.get_builtin_tool(CuaKeyboardTypeTool)
tools.update(
{
screenshot_tool.name: screenshot_tool,
mouse_click_tool.name: mouse_click_tool,
keyboard_type_tool.name: keyboard_type_tool,
}
)
return tools
if runtime == "local":
shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool)
python_tool = tool_mgr.get_builtin_tool(LocalPythonTool)
read_tool = tool_mgr.get_builtin_tool(FileReadTool)
write_tool = tool_mgr.get_builtin_tool(FileWriteTool)
edit_tool = tool_mgr.get_builtin_tool(FileEditTool)
grep_tool = tool_mgr.get_builtin_tool(GrepTool)
return {
shell_tool.name: shell_tool,
python_tool.name: python_tool,
read_tool.name: read_tool,
write_tool.name: write_tool,
edit_tool.name: edit_tool,
grep_tool.name: grep_tool,
}
return {}
@classmethod
def _build_handoff_toolset(
@@ -269,12 +251,19 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
cfg = ctx.get_config(umo=event.unified_msg_origin)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
tool_mgr = ctx.get_llm_tool_manager()
tool_mgr = (
ctx.get_llm_tool_manager()
if hasattr(ctx, "get_llm_tool_manager")
else llm_tools
)
runtime_computer_tools = cls._get_runtime_computer_tools(
runtime,
tool_mgr,
provider_settings.get("sandbox", {}).get("booter"),
)
# Keep persona semantics aligned with the main agent: tools=None means
# "all tools", including runtime computer-use tools.
if tools is None:
toolset = ToolSet()
handoff_names = {
@@ -290,8 +279,10 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
for runtime_tool in runtime_computer_tools.values():
toolset.add_tool(runtime_tool)
return None if toolset.empty() else toolset
if not tools:
return None
toolset = ToolSet()
for tool_name_or_obj in tools:
if isinstance(tool_name_or_obj, str):
@@ -309,11 +300,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool[Any],
run_context: ContextWrapper[Any],
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
*,
image_urls_prepared: bool = False,
**tool_args: Any,
**tool_args: T.Any,
):
tool_args = dict(tool_args)
input_ = tool_args.get("input")
@@ -333,15 +324,21 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
tool_args.get("image_urls"),
)
tool_args["image_urls"] = image_urls
# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
ctx = run_context.context.context
event = run_context.context.event
umo = event.unified_msg_origin
# Use per-subagent provider override if configured; otherwise fall back
# to the current/default provider resolution.
prov_id = getattr(
tool,
"provider_id",
None,
tool, "provider_id", None
) or await ctx.get_current_chat_provider_id(umo)
# prepare begin dialogs
contexts = None
dialogs = tool.agent.begin_dialogs
if dialogs:
@@ -351,12 +348,13 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
contexts.append(
dialog
if isinstance(dialog, Message)
else Message.model_validate(dialog),
else Message.model_validate(dialog)
)
except Exception:
continue
prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {})
agent_max_step = int(prov_settings.get("max_agent_step", 3))
agent_max_step = int(prov_settings.get("max_agent_step", 30))
stream = prov_settings.get("streaming_response", False)
llm_resp = await ctx.tool_loop_agent(
event=event,
@@ -371,7 +369,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
stream=stream,
)
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)],
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
)
@classmethod
@@ -399,16 +397,21 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
task_id=task_id,
**tool_args,
)
except Exception as e:
except Exception as e: # noqa: BLE001
logger.error(
f"Background handoff {task_id} ({tool.name}) failed: {e!s}",
exc_info=True,
)
asyncio.create_task(_run_handoff_in_background())
text_content = mcp.types.TextContent(
type="text",
text=f"Background task dedicated to subagent '{tool.agent.name}' submitted. task_id={task_id}. The subagent '{tool.agent.name}' is working on the task on behalf of you. You will be notified when it finishes.",
text=(
f"Background task dedicated to subagent '{tool.agent.name}' submitted. task_id={task_id}. "
f"The subagent '{tool.agent.name}' is working on the task on hehalf you. "
f"You will be notified when it finishes."
),
)
yield mcp.types.CallToolResult(content=[text_content])
@@ -442,15 +445,19 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
result_text = (
f"error: Background task execution failed, internal error: {e!s}"
)
event = run_context.context.event
await cls._wake_main_agent_for_background_result(
run_context=run_context,
task_id=task_id,
tool_name=tool.name,
result_text=result_text,
tool_args=tool_args,
note=event.get_extra("background_note")
or f"Background task for subagent '{tool.agent.name}' finished.",
note=(
event.get_extra("background_note")
or f"Background task for subagent '{tool.agent.name}' finished."
),
summary_name=f"Dedicated to subagent `{tool.agent.name}`",
extra_result_fields={"subagent_name": tool.agent.name},
)
@@ -463,14 +470,13 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
task_id: str,
**tool_args,
) -> None:
# run the tool
result_text = ""
try:
async for r in cls._execute_local(
tool,
run_context,
tool_call_timeout=3600,
**tool_args,
tool, run_context, tool_call_timeout=3600, **tool_args
):
# collect results, currently we just collect the text results
if isinstance(r, mcp.types.CallToolResult):
result_text = ""
for content in r.content:
@@ -480,15 +486,19 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
result_text = (
f"error: Background task execution failed, internal error: {e!s}"
)
event = run_context.context.event
await cls._wake_main_agent_for_background_result(
run_context=run_context,
task_id=task_id,
tool_name=tool.name,
result_text=result_text,
tool_args=tool_args,
note=event.get_extra("background_note")
or f"Background task {tool.name} finished.",
note=(
event.get_extra("background_note")
or f"Background task {tool.name} finished."
),
summary_name=tool.name,
)
@@ -500,10 +510,10 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
task_id: str,
tool_name: str,
result_text: str,
tool_args: dict[str, Any],
tool_args: dict[str, T.Any],
note: str,
summary_name: str,
extra_result_fields: dict[str, Any] | None = None,
extra_result_fields: dict[str, T.Any] | None = None,
) -> None:
from astrbot.core.astr_main_agent import (
MainAgentBuildConfig,
@@ -513,6 +523,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
event = run_context.context.event
ctx = run_context.context.context
task_result = {
"task_id": task_id,
"tool_name": tool_name,
@@ -522,6 +533,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
if extra_result_fields:
task_result.update(extra_result_fields)
extras = {"background_task_result": task_result}
session = MessageSession.from_str(event.unified_msg_origin)
cron_event = CronMessageEvent(
context=ctx,
@@ -531,15 +543,14 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
message_type=session.message_type,
)
cron_event.role = event.role
config = MainAgentBuildConfig(
tool_call_timeout=run_context.tool_call_timeout,
streaming_response=ctx.get_config()
.get("provider_settings", {})
.get("stream", False),
)
req = ProviderRequest()
req.system_prompt = ""
conv = await _get_session_conv(event=cron_event, plugin_context=ctx)
req.conversation = conv
context = json.loads(conv.history)
@@ -547,30 +558,47 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
req.contexts = context
context_dump = req._print_friendly_context()
req.contexts = []
req.system_prompt += CONVERSATION_HISTORY_INJECT_PREFIX + context_dump
req.system_prompt += (
"\n\nBellow is you and user previous conversation history:\n"
f"{context_dump}"
)
bg = json.dumps(extras["background_task_result"], ensure_ascii=False)
req.system_prompt += BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT.format(
background_task_result=bg,
background_task_result=bg
)
req.prompt = (
"Proceed according to your system instructions. "
"Output using same language as previous conversation. "
"If you need to deliver the result to the user immediately, "
"you MUST use `send_message_to_user` tool to send the message directly to the user, "
"otherwise the user will not see the result. "
"After completing your task, summarize and output your actions and results. "
)
req.prompt = BACKGROUND_TASK_WOKE_USER_PROMPT
if not req.func_tool:
req.func_tool = ToolSet()
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)
req.func_tool.add_tool(
ctx.get_llm_tool_manager().get_builtin_tool(SendMessageToUserTool)
)
result = await build_main_agent(
event=cron_event,
plugin_context=ctx,
config=config,
req=req,
event=cron_event, plugin_context=ctx, config=config, req=req
)
if not result:
logger.error(f"Failed to build main agent for background task {tool_name}.")
return
runner = result.agent_runner
async for _ in runner.step_until_done(3):
async for _ in runner.step_until_done(30):
# agent will send message to user via using tools
pass
llm_resp = runner.get_final_llm_resp()
task_meta = extras.get("background_task_result", {})
summary_note = f"[BackgroundTask] {summary_name} (task_id={task_meta.get('task_id', task_id)}) finished. Result: {task_meta.get('result') or result_text or 'no content'}"
summary_note = (
f"[BackgroundTask] {summary_name} "
f"(task_id={task_meta.get('task_id', task_id)}) finished. "
f"Result: {task_meta.get('result') or result_text or 'no content'}"
)
if llm_resp and llm_resp.completion_text:
summary_note += (
f"I finished the task, here is the result: {llm_resp.completion_text}"
@@ -597,13 +625,17 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
event = run_context.context.event
if not event:
raise ValueError("Event must be provided for local function tools.")
is_override_call = False
for ty in type(tool).mro():
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
is_override_call = True
break
if not tool.handler and (not hasattr(tool, "run")) and (not is_override_call):
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
raise ValueError("Tool must have a valid handler or override 'run' method.")
awaitable = None
method_name = ""
if tool.handler:
@@ -612,36 +644,12 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
elif is_override_call:
awaitable = tool.call
method_name = "call"
else:
awaitable = getattr(tool, "run", None)
if awaitable is not None:
method_name = "run"
elif hasattr(tool, "run"):
awaitable = getattr(tool, "run")
method_name = "run"
if awaitable is None:
raise ValueError("Tool must have a valid handler or override 'run' method.")
sdk_plugin_bridge = getattr(
run_context.context.context,
"sdk_plugin_bridge",
None,
)
if sdk_plugin_bridge is not None:
try:
await sdk_plugin_bridge.dispatch_message_event(
"calling_func_tool",
event,
{
"tool_name": tool.name,
"tool_args": json.loads(
json.dumps(tool_args, ensure_ascii=False, default=str),
),
},
)
except Exception as exc:
logger.warning("SDK calling_func_tool dispatch failed: %s", exc)
_HandlerType = Callable[
...,
Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
]
wrapper = call_local_llm_tool(
context=run_context,
handler=awaitable,
@@ -664,31 +672,28 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
)
yield mcp.types.CallToolResult(content=[text_content])
else:
res = run_context.context.event.get_result()
if res and res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
),
)
except Exception as e:
logger.error(f"Tool 直接发送消息失败: {e}", exc_info=True)
yield None
else:
yield mcp.types.CallToolResult(
content=[
mcp.types.TextContent(
type="text",
text="Tool executed successfully with no output.",
),
],
)
except TimeoutError:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
if res := run_context.context.event.get_result():
if res.chain:
try:
await event.send(
MessageChain(
chain=res.chain,
type="tool_direct_result",
)
)
except Exception as e:
logger.error(
f"Tool 直接发送消息失败: {e}",
exc_info=True,
)
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.",
) from None
)
except StopAsyncIteration:
break
@@ -707,19 +712,22 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: Callable[
handler: T.Callable[
...,
Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
],
method_name: str,
*args,
**kwargs,
) -> AsyncGenerator[Any, None]:
) -> T.AsyncGenerator[T.Any, None]:
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
ready_to_call = None
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
event = context.context.event
try:
if method_name == "run" or method_name == "decorator_handler":
ready_to_call = handler(event, *args, **kwargs)
@@ -730,15 +738,19 @@ async def call_local_llm_tool(
except ValueError as e:
raise Exception(f"Tool execution ValueError: {e}") from e
except TypeError as e:
# 获取函数的签名(包括类型),除了第一个 event/context 参数。
try:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
# 跳过第一个参数event 或 context
if params:
params = params[1:]
param_strs = []
for param in params:
param_str = param.name
if param.annotation != inspect.Parameter.empty:
# 获取类型注解的字符串表示
if isinstance(param.annotation, type):
type_str = param.annotation.__name__
else:
@@ -747,35 +759,46 @@ async def call_local_llm_tool(
if param.default != inspect.Parameter.empty:
param_str += f" = {param.default!r}"
param_strs.append(param_str)
handler_param_str = (
", ".join(param_strs) if param_strs else "(no additional parameters)"
)
except Exception:
handler_param_str = "(unable to inspect signature)"
raise Exception(
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}",
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
) from e
except Exception as e:
trace_ = traceback.format_exc()
raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, MessageEventResult | CommandResult):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, MessageEventResult | CommandResult):
event.set_result(ret)

View File

@@ -99,26 +99,6 @@ BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT = (
"{background_task_result}"
)
CONVERSATION_HISTORY_INJECT_PREFIX = (
"\n\nBelow is your and the user's previous conversation history:\n"
)
BACKGROUND_TASK_WOKE_USER_PROMPT = (
"Proceed according to your system instructions. "
"Output using same language as previous conversation. "
"If you need to deliver the result to the user immediately, "
"you MUST use `send_message_to_user` tool to send the message directly to the user, "
"otherwise the user will not see the result. "
"After completing your task, summarize and output your actions and results. "
)
CRON_TASK_WOKE_USER_PROMPT = (
"You are now responding to a scheduled task. "
"Proceed according to your system instructions. "
"Output using same language as previous conversation. "
"After completing your task, summarize and output your actions and results."
)
# we prevent astrbot from connecting to known malicious hosts
# these hosts are base64 encoded
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}

View File

@@ -1,6 +1,6 @@
import os
import uuid
from typing import Any, TypedDict, TypeVar
from typing import TypedDict, TypeVar
from astrbot.core import AstrBotConfig, logger
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
@@ -13,7 +13,7 @@ from astrbot.core.utils.shared_preferences import SharedPreferences
_VT = TypeVar("_VT")
class ConfInfo(TypedDict, total=False):
class ConfInfo(TypedDict):
"""Configuration information for a specific session or platform."""
id: str # UUID of the configuration or "default"
@@ -42,7 +42,7 @@ class AstrBotConfigManager:
self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config
self.abconf_data: dict | None = None
self.abconf_data = None
self._load_all_configs()
def _get_abconf_data(self) -> dict:
@@ -54,7 +54,7 @@ class AstrBotConfigManager:
scope="global",
scope_id="global",
)
return self.abconf_data # type: ignore[return-value]
return self.abconf_data
def _load_all_configs(self) -> None:
"""Load all configurations from the shared preferences."""
@@ -107,13 +107,12 @@ class AstrBotConfigManager:
abconf_name: str | None = None,
) -> None:
"""保存配置文件的映射关系"""
raw_abconf: dict[str, Any] | None = self.sp.get(
abconf_data = self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
)
abconf_data: dict[str, dict[str, str]] = raw_abconf or {}
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
"path": abconf_path,
@@ -123,7 +122,7 @@ class AstrBotConfigManager:
self.abconf_data = abconf_data
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
"""获取指定 umo 的配置文件如果不存在,则 fallback 到默认配置文件"""
"""获取指定 umo 的配置文件如果不存在则 fallback 到默认配置文件"""
if not umo:
return self.confs["default"]
if isinstance(umo, MessageSession):
@@ -192,14 +191,11 @@ class AstrBotConfigManager:
raise ValueError("不能删除默认配置文件")
# 从映射中移除
abconf_data: dict[str, dict[str, str]] = (
self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
)
or {}
abconf_data = self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
@@ -246,14 +242,11 @@ class AstrBotConfigManager:
if conf_id == "default":
raise ValueError("不能更新默认配置文件的信息")
abconf_data: dict[str, dict[str, str]] = (
self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
)
or {}
abconf_data = self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
@@ -273,9 +266,9 @@ class AstrBotConfigManager:
self,
umo: str | None = None,
key: str | None = None,
default: _VT | None = None,
) -> _VT | None:
"""获取配置项umo 为 None 时使用默认配置"""
default: _VT = None,
) -> _VT:
"""获取配置项umo 为 None 时使用默认配置"""
if umo is None:
return self.confs["default"].get(key, default)
conf = self.get_conf(umo)

View File

@@ -1,6 +1,6 @@
"""AstrBot 备份与恢复模块
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据
提供数据导出和导入功能支持用户在服务器迁移时一键备份和恢复所有数据
"""
# 从 constants 模块导入共享常量
@@ -16,11 +16,11 @@ from .exporter import AstrBotExporter
from .importer import AstrBotImporter, ImportPreCheckResult
__all__ = [
"BACKUP_MANIFEST_VERSION",
"KB_METADATA_MODELS",
"MAIN_DB_MODELS",
"AstrBotExporter",
"AstrBotImporter",
"ImportPreCheckResult",
"MAIN_DB_MODELS",
"KB_METADATA_MODELS",
"get_backup_directories",
"BACKUP_MANIFEST_VERSION",
]

View File

@@ -1,6 +1,6 @@
"""AstrBot 备份模块共享常量
此文件定义了导出器和导入器共享的常量,确保两端配置一致
此文件定义了导出器和导入器共享的常量确保两端配置一致
"""
from sqlmodel import SQLModel
@@ -29,6 +29,7 @@ from astrbot.core.utils.astrbot_path import (
get_astrbot_config_path,
get_astrbot_plugin_data_path,
get_astrbot_plugin_path,
get_astrbot_skills_path,
get_astrbot_t2i_templates_path,
get_astrbot_temp_path,
get_astrbot_webchat_path,
@@ -66,11 +67,10 @@ KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
def get_backup_directories() -> dict[str, str]:
"""获取需要备份的目录列表
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录
使用 astrbot_path 模块动态获取路径支持通过环境变量 ASTRBOT_ROOT 自定义根目录
Returns:
dict: 键为备份文件中的目录名称,值为目录的绝对路径
dict: 键为备份文件中的目录名称值为目录的绝对路径
"""
return {
"plugins": get_astrbot_plugin_path(), # 插件本体
@@ -79,6 +79,7 @@ def get_backup_directories() -> dict[str, str]:
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
"webchat": get_astrbot_webchat_path(), # WebChat 数据
"temp": get_astrbot_temp_path(), # 临时文件
"skills": get_astrbot_skills_path(), # Skills
}

View File

@@ -1,7 +1,7 @@
"""AstrBot 数据导出器
负责将所有数据导出为 ZIP 备份文件
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移
负责将所有数据导出为 ZIP 备份文件
导出格式为 JSON这是数据库无关的方案支持未来向 MySQL/PostgreSQL 迁移
"""
import hashlib
@@ -12,7 +12,6 @@ from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
import anyio
from sqlalchemy import select
from astrbot.core import logger
@@ -40,19 +39,19 @@ CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
class AstrBotExporter:
"""AstrBot 数据导出器
导出内容:
- 主数据库所有表(data/data_v4.db)
- 知识库元数据(data/knowledge_base/kb.db)
导出内容
- 主数据库所有表data/data_v4.db
- 知识库元数据data/knowledge_base/kb.db
- 每个知识库的向量文档数据
- 配置文件(data/cmd_config.json)
- 配置文件data/cmd_config.json
- 附件文件
- 知识库多媒体文件
- 插件目录(data/plugins)
- 插件数据目录(data/plugin_data)
- 配置目录(data/config)
- T2I 模板目录(data/t2i_templates)
- WebChat 数据目录(data/webchat)
- 临时文件目录(data/temp)
- 插件目录data/plugins
- 插件数据目录data/plugin_data
- 配置目录data/config
- T2I 模板目录data/t2i_templates
- WebChat 数据目录data/webchat
- 临时文件目录data/temp
"""
def __init__(
@@ -75,17 +74,16 @@ class AstrBotExporter:
Args:
output_dir: 输出目录
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
progress_callback: 进度回调函数接收参数 (stage, current, total, message)
Returns:
str: 生成的 ZIP 文件路径
"""
if output_dir is None:
output_dir = get_astrbot_backups_path()
# 确保输出目录存在
await anyio.Path(output_dir).mkdir(parents=True, exist_ok=True)
Path(output_dir).mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"astrbot_backup_{timestamp}.zip"
@@ -100,10 +98,7 @@ class AstrBotExporter:
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
main_data = await self._export_main_database()
main_db_json = json.dumps(
main_data,
ensure_ascii=False,
indent=2,
default=str,
main_data, ensure_ascii=False, indent=2, default=str
)
zf.writestr("databases/main_db.json", main_db_json)
self._add_checksum("databases/main_db.json", main_db_json)
@@ -119,26 +114,17 @@ class AstrBotExporter:
if self.kb_manager:
if progress_callback:
await progress_callback(
"kb_metadata",
0,
100,
"正在导出知识库元数据...",
"kb_metadata", 0, 100, "正在导出知识库元数据..."
)
kb_meta_data = await self._export_kb_metadata()
kb_meta_json = json.dumps(
kb_meta_data,
ensure_ascii=False,
indent=2,
default=str,
kb_meta_data, ensure_ascii=False, indent=2, default=str
)
zf.writestr("databases/kb_metadata.json", kb_meta_json)
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
if progress_callback:
await progress_callback(
"kb_metadata",
100,
100,
"知识库元数据导出完成",
"kb_metadata", 100, 100, "知识库元数据导出完成"
)
# 导出每个知识库的文档数据
@@ -154,10 +140,7 @@ class AstrBotExporter:
)
doc_data = await self._export_kb_documents(kb_helper)
doc_json = json.dumps(
doc_data,
ensure_ascii=False,
indent=2,
default=str,
doc_data, ensure_ascii=False, indent=2, default=str
)
doc_path = f"databases/kb_{kb_id}/documents.json"
zf.writestr(doc_path, doc_json)
@@ -171,21 +154,15 @@ class AstrBotExporter:
if progress_callback:
await progress_callback(
"kb_documents",
total_kbs,
total_kbs,
"知识库文档导出完成",
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
)
# 3. 导出配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导出配置文件...")
if await anyio.Path(self.config_path).exists():
async with await anyio.open_file(
self.config_path,
encoding="utf-8",
) as f:
config_content = await f.read()
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config_content = f.read()
zf.writestr("config/cmd_config.json", config_content)
self._add_checksum("config/cmd_config.json", config_content)
if progress_callback:
@@ -201,10 +178,7 @@ class AstrBotExporter:
# 5. 导出插件和其他目录
if progress_callback:
await progress_callback(
"directories",
0,
100,
"正在导出插件和数据目录...",
"directories", 0, 100, "正在导出插件和数据目录..."
)
dir_stats = await self._export_directories(zf)
if progress_callback:
@@ -225,8 +199,8 @@ class AstrBotExporter:
except Exception as e:
logger.error(f"备份导出失败: {e}")
# 清理失败的文件
if await anyio.Path(zip_path).exists():
await anyio.Path(zip_path).unlink()
if os.path.exists(zip_path):
os.remove(zip_path)
raise
async def _export_main_database(self) -> dict[str, list[dict]]:
@@ -242,7 +216,7 @@ class AstrBotExporter:
self._model_to_dict(record) for record in records
]
logger.debug(
f"导出表 {table_name}: {len(export_data[table_name])} 条记录",
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
)
except Exception as e:
logger.warning(f"导出表 {table_name} 失败: {e}")
@@ -266,7 +240,7 @@ class AstrBotExporter:
self._model_to_dict(record) for record in records
]
logger.debug(
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录",
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
)
except Exception as e:
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
@@ -312,10 +286,7 @@ class AstrBotExporter:
logger.warning(f"导出 FAISS 索引失败: {e}")
async def _export_kb_media_files(
self,
zf: zipfile.ZipFile,
kb_helper: Any,
kb_id: str,
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
) -> None:
"""导出知识库的多媒体文件"""
try:
@@ -334,22 +305,20 @@ class AstrBotExporter:
logger.warning(f"导出知识库媒体文件失败: {e}")
async def _export_directories(
self,
zf: zipfile.ZipFile,
self, zf: zipfile.ZipFile
) -> dict[str, dict[str, int]]:
"""导出插件和其他数据目录
Returns:
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
"""
stats: dict[str, dict[str, int]] = {}
backup_directories = get_backup_directories()
for dir_name, dir_path in backup_directories.items():
full_path = Path(dir_path)
if not await anyio.Path(full_path).exists():
logger.debug(f"目录不存在,跳过: {full_path}")
if not full_path.exists():
logger.debug(f"目录不存在跳过: {full_path}")
continue
file_count = 0
@@ -378,7 +347,7 @@ class AstrBotExporter:
stats[dir_name] = {"files": file_count, "size": total_size}
logger.debug(
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节",
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
)
except Exception as e:
logger.warning(f"导出目录 {dir_path} 失败: {e}")
@@ -387,15 +356,13 @@ class AstrBotExporter:
return stats
async def _export_attachments(
self,
zf: zipfile.ZipFile,
attachments: list[dict],
self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
"""导出附件文件"""
for attachment in attachments:
try:
file_path = attachment.get("path", "")
if file_path and await anyio.Path(file_path).exists():
if file_path and os.path.exists(file_path):
# 使用 attachment_id 作为文件名
attachment_id = attachment.get("attachment_id", "")
ext = os.path.splitext(file_path)[1]
@@ -407,9 +374,9 @@ class AstrBotExporter:
def _model_to_dict(self, record: Any) -> dict:
"""将 SQLModel 实例转换为字典
这是数据库无关的序列化方式,支持未来迁移到其他数据库
这是数据库无关的序列化方式支持未来迁移到其他数据库
"""
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
# 使用 SQLModel 内置的 model_dump 方法如果可用
if hasattr(record, "model_dump"):
data = record.model_dump(mode="python")
# 处理 datetime 类型
@@ -470,7 +437,7 @@ class AstrBotExporter:
media_files: list[str] = []
media_dir = kb_helper.kb_medias_dir
if media_dir.exists():
for _root, _, files in os.walk(media_dir):
for root, _, files in os.walk(media_dir):
for file in files:
media_files.append(file)
if media_files:
@@ -480,7 +447,7 @@ class AstrBotExporter:
"version": BACKUP_MANIFEST_VERSION,
"astrbot_version": VERSION,
"exported_at": datetime.now(timezone.utc).isoformat(),
"origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传
"origin": "exported", # 标记备份来源exported=本实例导出, uploaded=用户上传
"schema_version": {
"main_db": "v4",
"kb_db": "v1",

View File

@@ -1,9 +1,9 @@
"""AstrBot 数据导入器
负责从 ZIP 备份文件恢复所有数据
导入时进行版本校验:
- 主版本(前两位)不同时直接拒绝导入
- 小版本(第三位)不同时提示警告,用户可选择强制导入
负责从 ZIP 备份文件恢复所有数据
导入时进行版本校验
- 主版本前两位不同时直接拒绝导入
- 小版本第三位不同时提示警告用户可选择强制导入
- 版本匹配时也需要用户确认
"""
@@ -16,7 +16,6 @@ from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
import anyio
from sqlalchemy import delete
from astrbot.core import logger
@@ -41,14 +40,13 @@ if TYPE_CHECKING:
def _get_major_version(version_str: str) -> str:
"""提取版本的主版本部分(前两位)
"""提取版本的主版本部分前两位
Args:
version_str: 版本字符串,"4.9.1", "4.10.0-beta"
version_str: 版本字符串"4.9.1", "4.10.0-beta"
Returns:
主版本字符串,"4.9", "4.10"
主版本字符串"4.9", "4.10"
"""
if not version_str:
return "0.0"
@@ -57,7 +55,7 @@ def _get_major_version(version_str: str) -> str:
parts = [p for p in version.split(".") if p] # 过滤空字符串
if len(parts) >= 2:
return f"{parts[0]}.{parts[1]}"
if len(parts) == 1 and parts[0]:
elif len(parts) == 1 and parts[0]:
return f"{parts[0]}.0"
return "0.0"
@@ -121,14 +119,14 @@ class _InvalidCountWarnLimiter:
if self.limit > 0:
if self._count < self.limit:
logger.warning(
"platform_stats count 非法,已按 0 处理: value=%r, key=%s",
"platform_stats count 非法已按 0 处理: value=%r, key=%s",
value,
key_for_log,
)
self._count += 1
if self._count == self.limit and not self._suppression_logged:
logger.warning(
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
"platform_stats 非法 count 告警已达到上限 (%d)后续将抑制",
self.limit,
)
self._suppression_logged = True
@@ -137,7 +135,7 @@ class _InvalidCountWarnLimiter:
if not self._suppression_logged:
# limit <= 0: emit only one suppression warning.
logger.warning(
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
"platform_stats 非法 count 告警已达到上限 (%d)后续将抑制",
self.limit,
)
self._suppression_logged = True
@@ -147,15 +145,15 @@ class _InvalidCountWarnLimiter:
class ImportPreCheckResult:
"""导入预检查结果
用于在实际导入前检查备份文件的版本兼容性,
并返回确认信息让用户决定是否继续导入
用于在实际导入前检查备份文件的版本兼容性
并返回确认信息让用户决定是否继续导入
"""
# 检查是否通过(文件有效且版本可导入)
# 检查是否通过文件有效且版本可导入
valid: bool = False
# 是否可以导入(版本兼容)
# 是否可以导入版本兼容
can_import: bool = False
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
# 版本状态: match完全匹配, minor_diff小版本差异, major_diff主版本不同拒绝
version_status: str = ""
# 备份文件中的 AstrBot 版本
backup_version: str = ""
@@ -163,11 +161,11 @@ class ImportPreCheckResult:
current_version: str = VERSION
# 备份创建时间
backup_time: str = ""
# 确认消息(显示给用户)
# 确认消息显示给用户
confirm_message: str = ""
# 警告消息列表
warnings: list[str] = field(default_factory=list)
# 错误消息(如果检查失败)
# 错误消息如果检查失败
error: str = ""
# 备份包含的内容摘要
backup_summary: dict = field(default_factory=dict)
@@ -225,18 +223,18 @@ class DatabaseClearError(RuntimeError):
class AstrBotImporter:
"""AstrBot 数据导入器
导入备份文件中的所有数据,包括:
导入备份文件中的所有数据包括
- 主数据库所有表
- 知识库元数据和文档
- 配置文件
- 附件文件
- 知识库多媒体文件
- 插件目录(data/plugins)
- 插件数据目录(data/plugin_data)
- 配置目录(data/config)
- T2I 模板目录(data/t2i_templates)
- WebChat 数据目录(data/webchat)
- 临时文件目录(data/temp)
- 插件目录data/plugins
- 插件数据目录data/plugin_data
- 配置目录data/config
- T2I 模板目录data/t2i_templates
- WebChat 数据目录data/webchat
- 临时文件目录data/temp
"""
def __init__(
@@ -254,15 +252,14 @@ class AstrBotImporter:
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
"""预检查备份文件
在实际导入前检查备份文件的有效性和版本兼容性
返回检查结果供前端显示确认对话框
在实际导入前检查备份文件的有效性和版本兼容性
返回检查结果供前端显示确认对话框
Args:
zip_path: ZIP 备份文件路径
Returns:
ImportPreCheckResult: 预检查结果
"""
result = ImportPreCheckResult()
result.current_version = VERSION
@@ -278,7 +275,7 @@ class AstrBotImporter:
manifest_data = zf.read("manifest.json")
manifest = json.loads(manifest_data)
except KeyError:
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
result.error = "备份文件缺少 manifest.json不是有效的 AstrBot 备份"
return result
except json.JSONDecodeError as e:
result.error = f"manifest.json 格式错误: {e}"
@@ -303,7 +300,7 @@ class AstrBotImporter:
result.can_import = version_check["can_import"]
# 版本信息由前端根据 version_status 和 i18n 生成显示
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
# 不再将版本消息添加到 warnings 列表中避免中文硬编码
# warnings 列表保留用于其他非版本相关的警告
return result
@@ -318,13 +315,12 @@ class AstrBotImporter:
def _check_version_compatibility(self, backup_version: str) -> dict:
"""检查版本兼容性
规则:
- 主版本(前两位,如 4.9)必须一致,否则拒绝
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
规则
- 主版本前两位如 4.9必须一致否则拒绝
- 小版本第三位如 4.9.1 vs 4.9.2不同时警告但允许导入
Returns:
dict: {status, can_import, message}
"""
if not backup_version:
return {
@@ -333,7 +329,7 @@ class AstrBotImporter:
"message": "备份文件缺少版本信息",
}
# 提取主版本(前两位)进行比较
# 提取主版本前两位进行比较
backup_major = _get_major_version(backup_version)
current_major = _get_major_version(VERSION)
@@ -343,8 +339,8 @@ class AstrBotImporter:
"status": "major_diff",
"can_import": False,
"message": (
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}"
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot"
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}"
f"跨主版本导入可能导致数据损坏请使用相同主版本的 AstrBot"
),
}
@@ -355,7 +351,7 @@ class AstrBotImporter:
"status": "minor_diff",
"can_import": True,
"message": (
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}"
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}"
),
}
@@ -375,16 +371,15 @@ class AstrBotImporter:
Args:
zip_path: ZIP 备份文件路径
mode: 导入模式,目前仅支持 "replace"(清空后导入)
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
mode: 导入模式目前仅支持 "replace"清空后导入
progress_callback: 进度回调函数接收参数 (stage, current, total, message)
Returns:
ImportResult: 导入结果
"""
result = ImportResult()
if not await anyio.Path(zip_path).exists():
if not os.path.exists(zip_path):
result.add_error(f"备份文件不存在: {zip_path}")
return result
@@ -466,12 +461,12 @@ class AstrBotImporter:
try:
config_content = zf.read("config/cmd_config.json")
# 备份现有配置
if await anyio.Path(self.config_path).exists():
if os.path.exists(self.config_path):
backup_path = f"{self.config_path}.bak"
shutil.copy2(self.config_path, backup_path)
async with await anyio.open_file(self.config_path, "wb") as f:
await f.write(config_content)
with open(self.config_path, "wb") as f:
f.write(config_content)
result.imported_files["config"] = 1
except Exception as e:
result.add_warning(f"导入配置文件失败: {e}")
@@ -484,8 +479,7 @@ class AstrBotImporter:
await progress_callback("attachments", 0, 100, "正在导入附件...")
attachment_count = await self._import_attachments(
zf,
main_data.get("attachments", []),
zf, main_data.get("attachments", [])
)
result.imported_files["attachments"] = attachment_count
@@ -495,10 +489,7 @@ class AstrBotImporter:
# 6. 导入插件和其他目录
if progress_callback:
await progress_callback(
"directories",
0,
100,
"正在导入插件和数据目录...",
"directories", 0, 100, "正在导入插件和数据目录..."
)
dir_stats = await self._import_directories(zf, manifest, result)
@@ -520,8 +511,8 @@ class AstrBotImporter:
def _validate_version(self, manifest: dict) -> None:
"""验证版本兼容性 - 仅允许相同主版本导入
注意:此方法仅在 import_all 中调用,用于双重校验
前端应先调用 pre_check 获取详细的版本信息并让用户确认
注意此方法仅在 import_all 中调用用于双重校验
前端应先调用 pre_check 获取详细的版本信息并让用户确认
"""
backup_version = manifest.get("astrbot_version")
if not backup_version:
@@ -539,15 +530,16 @@ class AstrBotImporter:
async def _clear_main_db(self) -> None:
"""清空主数据库所有表"""
async with self.main_db.get_db() as session, session.begin():
for table_name, model_class in MAIN_DB_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空表 {table_name}")
except Exception as e:
raise DatabaseClearError(
f"清空表 {table_name} 失败: {e}",
) from e
async with self.main_db.get_db() as session:
async with session.begin():
for table_name, model_class in MAIN_DB_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空表 {table_name}")
except Exception as e:
raise DatabaseClearError(
f"清空表 {table_name} 失败: {e}"
) from e
async def _clear_kb_data(self) -> None:
"""清空知识库数据"""
@@ -555,13 +547,14 @@ class AstrBotImporter:
return
# 清空知识库元数据表
async with self.kb_manager.kb_db.get_db() as session, session.begin():
for table_name, model_class in KB_METADATA_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空知识库表 {table_name}")
except Exception as e:
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
async with self.kb_manager.kb_db.get_db() as session:
async with session.begin():
for table_name, model_class in KB_METADATA_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空知识库表 {table_name}")
except Exception as e:
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
# 删除知识库文件目录
for kb_id in list(self.kb_manager.kb_insts.keys()):
@@ -576,47 +569,45 @@ class AstrBotImporter:
self.kb_manager.kb_insts.clear()
async def _import_main_database(
self,
data: dict[str, list[dict]],
self, data: dict[str, list[dict]]
) -> dict[str, int]:
"""导入主数据库数据"""
imported: dict[str, int] = {}
async with self.main_db.get_db() as session, session.begin():
for table_name, rows in data.items():
model_class = MAIN_DB_MODELS.get(table_name)
if not model_class:
logger.warning(f"未知的表: {table_name}")
continue
normalized_rows = self._preprocess_main_table_rows(table_name, rows)
async with self.main_db.get_db() as session:
async with session.begin():
for table_name, rows in data.items():
model_class = MAIN_DB_MODELS.get(table_name)
if not model_class:
logger.warning(f"未知的表: {table_name}")
continue
normalized_rows = self._preprocess_main_table_rows(table_name, rows)
count = 0
for row in normalized_rows:
try:
# 转换 datetime 字符串为 datetime 对象
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入记录到 {table_name} 失败: {e}")
count = 0
for row in normalized_rows:
try:
# 转换 datetime 字符串为 datetime 对象
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入记录到 {table_name} 失败: {e}")
imported[table_name] = count
logger.debug(f"导入表 {table_name}: {count} 条记录")
imported[table_name] = count
logger.debug(f"导入表 {table_name}: {count} 条记录")
return imported
def _preprocess_main_table_rows(
self,
table_name: str,
rows: list[dict[str, Any]],
self, table_name: str, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
if table_name == "platform_stats":
normalized_rows = self._merge_platform_stats_rows(rows)
duplicate_count = len(rows) - len(normalized_rows)
if duplicate_count > 0:
logger.warning(
"检测到 %s 重复键 %d,已在导入前聚合",
"检测到 %s 重复键 %d已在导入前聚合",
table_name,
duplicate_count,
)
@@ -624,8 +615,7 @@ class AstrBotImporter:
return rows
def _merge_platform_stats_rows(
self,
rows: list[dict[str, Any]],
self, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Merge duplicate platform_stats rows by normalized timestamp/platform key.
@@ -633,7 +623,6 @@ class AstrBotImporter:
- Invalid/empty timestamps are kept as distinct rows to avoid accidental merging.
- Non-string platform_id/platform_type are kept as distinct rows.
- Invalid count warnings are rate-limited per function invocation.
"""
merged: dict[tuple[str, str, str], dict[str, Any]] = {}
result: list[dict[str, Any]] = []
@@ -733,23 +722,24 @@ class AstrBotImporter:
return
# 1. 导入知识库元数据
async with self.kb_manager.kb_db.get_db() as session, session.begin():
for table_name, rows in kb_meta_data.items():
model_class = KB_METADATA_MODELS.get(table_name)
if not model_class:
continue
async with self.kb_manager.kb_db.get_db() as session:
async with session.begin():
for table_name, rows in kb_meta_data.items():
model_class = KB_METADATA_MODELS.get(table_name)
if not model_class:
continue
count = 0
for row in rows:
try:
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
count = 0
for row in rows:
try:
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
result.imported_tables[f"kb_{table_name}"] = count
result.imported_tables[f"kb_{table_name}"] = count
# 2. 导入每个知识库的文档和文件
for kb_data in kb_meta_data.get("knowledge_bases", []):
@@ -778,10 +768,8 @@ class AstrBotImporter:
if faiss_path in zf.namelist():
try:
target_path = kb_dir / "index.faiss"
with zf.open(faiss_path) as src:
content = src.read()
async with await anyio.open_file(target_path, "wb") as dst:
await dst.write(content)
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
@@ -797,10 +785,8 @@ class AstrBotImporter:
logger.warning(f"媒体文件路径越界,已跳过: {target_path}")
continue
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src:
content = src.read()
async with await anyio.open_file(target_path, "wb") as dst:
await dst.write(content)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
@@ -866,10 +852,8 @@ class AstrBotImporter:
continue
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src:
content = src.read()
async with await anyio.open_file(target_path, "wb") as dst:
await dst.write(content)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
count += 1
except Exception as e:
logger.warning(f"导入附件 {name} 失败: {e}")
@@ -891,14 +875,13 @@ class AstrBotImporter:
Returns:
dict: 每个目录导入的文件数量
"""
dir_stats: dict[str, int] = {}
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
# 检查备份版本是否支持目录备份需要版本 >= 1.1
backup_version = manifest.get("version", "1.0")
if VersionComparator.compare_version(backup_version, "1.1") < 0:
logger.info("备份版本不支持目录备份,跳过目录导入")
logger.info("备份版本不支持目录备份跳过目录导入")
return dir_stats
backed_up_dirs = manifest.get("directories", [])
@@ -925,16 +908,16 @@ class AstrBotImporter:
if not dir_files:
continue
# 备份现有目录(如果存在)
if await anyio.Path(target_dir).exists():
# 备份现有目录如果存在
if target_dir.exists():
backup_path = Path(f"{target_dir}.bak")
if await anyio.Path(backup_path).exists():
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(str(target_dir), str(backup_path))
logger.debug(f"已备份现有目录 {target_dir}{backup_path}")
# 创建目标目录
await anyio.Path(target_dir).mkdir(parents=True, exist_ok=True)
target_dir.mkdir(parents=True, exist_ok=True)
# 解压文件
for name in dir_files:
@@ -956,10 +939,8 @@ class AstrBotImporter:
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src:
content = src.read()
async with await anyio.open_file(target_path, "wb") as dst:
await dst.write(content)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
file_count += 1
except Exception as e:
result.add_warning(f"导入文件 {name} 失败: {e}")
@@ -979,10 +960,9 @@ class AstrBotImporter:
# 获取模型的 datetime 字段
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import Mapper
try:
mapper: Mapper[Any] = sa_inspect(model_class)
mapper = sa_inspect(model_class)
for column in mapper.columns:
if column.name in result and result[column.name] is not None:
# 检查是否是 datetime 类型的列

View File

@@ -1,9 +1,4 @@
from __future__ import annotations
import abc
from typing import TYPE_CHECKING
from astrbot.core.computer.olayer import (
from ..olayer import (
BrowserComponent,
FileSystemComponent,
GUIComponent,
@@ -11,25 +6,16 @@ from astrbot.core.computer.olayer import (
ShellComponent,
)
if TYPE_CHECKING:
from astrbot.core.agent.tool import ToolSchema
class ComputerBooter(abc.ABC):
class ComputerBooter:
@property
@abc.abstractmethod
def fs(self) -> FileSystemComponent:
raise NotImplementedError("Subclass must implement fs property")
def fs(self) -> FileSystemComponent: ...
@property
@abc.abstractmethod
def python(self) -> PythonComponent:
raise NotImplementedError("Subclass must implement python property")
def python(self) -> PythonComponent: ...
@property
@abc.abstractmethod
def shell(self) -> ShellComponent:
raise NotImplementedError("Subclass must implement shell property")
def shell(self) -> ShellComponent: ...
@property
def capabilities(self) -> tuple[str, ...] | None:
@@ -48,44 +34,29 @@ class ComputerBooter(abc.ABC):
def gui(self) -> GUIComponent | None:
return None
@abc.abstractmethod
async def boot(self, session_id: str) -> None: ...
@abc.abstractmethod
async def shutdown(self, **kwargs) -> None:
"""Shut down the computer sandbox.
Subclasses may accept type-specific keyword arguments.
Subclasses may accept extra keyword arguments for
type-specific cleanup (e.g. ``delete_sandbox`` for
ShipyardNeoBooter). The default implementation ignores
them.
"""
...
async def upload_file(self, path: str, file_name: str) -> dict:
"""Upload file to the computer.
Should return a dict with `success` (bool) and `file_path` (str) keys.
"""
raise NotImplementedError("Subclass must implement upload_file method")
...
async def download_file(self, remote_path: str, local_path: str) -> None:
"""Download file from the computer."""
raise NotImplementedError("Subclass must implement download_file method")
...
@abc.abstractmethod
async def available(self) -> bool:
"""Check if the computer is available."""
raise NotImplementedError("Subclass must implement available method")
@classmethod
def get_default_tools(cls) -> list[ToolSchema]:
"""Conservative full tool list (no instance needed, pre-boot)."""
return []
def get_tools(self) -> list[ToolSchema]:
"""Capability-filtered tool list (post-boot).
Defaults to get_default_tools().
"""
return self.__class__.get_default_tools()
@classmethod
def get_system_prompt_parts(cls) -> list[str]:
"""Booter-specific system prompt fragments (static text, no instance needed)."""
return []
...

View File

@@ -60,7 +60,7 @@ class BayContainerManager:
raise RuntimeError(
"Failed to connect to Docker daemon. "
"Ensure Docker is installed and running, or configure "
"an explicit Bay endpoint instead of auto-start mode.",
"an explicit Bay endpoint instead of auto-start mode."
) from exc
# 1. Look for an existing managed container
@@ -72,12 +72,13 @@ class BayContainerManager:
logger.info("[BayManager] Reusing existing Bay container: %s", cid)
self._container = await self._docker.containers.get(existing["Id"])
return f"http://127.0.0.1:{self._host_port}"
# Container exists but stopped — restart it
logger.info("[BayManager] Restarting stopped Bay container")
container = await self._docker.containers.get(existing["Id"])
await container.start()
self._container = container
return f"http://127.0.0.1:{self._host_port}"
else:
# Container exists but stopped — restart it
logger.info("[BayManager] Restarting stopped Bay container")
container = await self._docker.containers.get(existing["Id"])
await container.start()
self._container = container
return f"http://127.0.0.1:{self._host_port}"
# 2. Pull image if needed
await self._pull_image_if_needed()
@@ -95,7 +96,7 @@ class BayContainerManager:
"BAY_SERVER__HOST=0.0.0.0",
f"BAY_SERVER__PORT={BAY_PORT}",
"BAY_DATA_DIR=/app/data",
# allow_anonymous=false auto-provisions API key
# allow_anonymous=false auto-provisions API key
"BAY_SECURITY__ALLOW_ANONYMOUS=false",
],
"HostConfig": {
@@ -110,8 +111,7 @@ class BayContainerManager:
},
}
self._container = await self._docker.containers.create_or_replace(
BAY_CONTAINER_NAME,
config,
BAY_CONTAINER_NAME, config
)
await self._container.start()
logger.info("[BayManager] Bay container started: %s", BAY_CONTAINER_NAME)
@@ -129,8 +129,7 @@ class BayContainerManager:
while loop.time() < deadline:
try:
async with session.get(
url,
timeout=aiohttp.ClientTimeout(total=3),
url, timeout=aiohttp.ClientTimeout(total=3)
) as resp:
if resp.status == 200:
logger.info("[BayManager] Bay is healthy")
@@ -142,7 +141,7 @@ class BayContainerManager:
await asyncio.sleep(HEALTH_POLL_INTERVAL_S)
raise TimeoutError(
f"Bay did not become healthy within {timeout}s (last error: {last_error})",
f"Bay did not become healthy within {timeout}s (last error: {last_error})"
)
async def read_credentials(self) -> str:
@@ -203,8 +202,7 @@ class BayContainerManager:
return api_key
except Exception as exc:
logger.debug(
"[BayManager] Failed to read credentials from container: %s",
exc,
"[BayManager] Failed to read credentials from container: %s", exc
)
return ""

View File

@@ -1,30 +1,18 @@
from __future__ import annotations
import asyncio
import functools
import random
from typing import TYPE_CHECKING, Any, cast
from typing import Any
import aiohttp
import anyio
import boxlite
from shipyard.filesystem import FileSystemComponent as ShipyardFileSystemComponent
from shipyard import FileSystemComponent as ShipyardFileSystemComponent
from shipyard.python import PythonComponent as ShipyardPythonComponent
from shipyard.shell import ShellComponent as ShipyardShellComponent
from astrbot.api import logger
if TYPE_CHECKING:
from astrbot.core.agent.tool import ToolSchema
from astrbot.core.computer.olayer import (
FileSystemComponent,
PythonComponent,
ShellComponent,
)
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .shipyard import ShipyardFileSystemWrapper, ShipyardShellWrapper
from .shipyard import ShipyardFileSystemWrapper
class MockShipyardSandboxClient:
@@ -47,10 +35,11 @@ class MockShipyardSandboxClient:
) as response:
if response.status == 200:
return await response.json()
error_text = await response.text()
raise Exception(
f"Failed to exec operation: {response.status} {error_text}",
)
else:
error_text = await response.text()
raise Exception(
f"Failed to exec operation: {response.status} {error_text}"
)
async def upload_file(self, path: str, remote_path: str) -> dict:
"""Upload a file to the sandbox"""
@@ -58,15 +47,15 @@ class MockShipyardSandboxClient:
try:
# Read file content
async with await anyio.open_file(path, "rb") as f:
file_content = await f.read()
with open(path, "rb") as f:
file_content = f.read()
# Create multipart form data
data = aiohttp.FormData()
data.add_field(
"file",
file_content,
filename=remote_path.rsplit("/", maxsplit=1)[-1],
filename=remote_path.split("/")[-1],
content_type="application/octet-stream",
)
data.add_field("file_path", remote_path)
@@ -77,7 +66,7 @@ class MockShipyardSandboxClient:
async with session.post(url, data=data) as response:
if response.status == 200:
logger.info(
"[Computer] file_upload booter=boxlite remote_path=%s",
"[Computer] File uploaded to Boxlite sandbox: %s",
remote_path,
)
return {
@@ -85,52 +74,39 @@ class MockShipyardSandboxClient:
"message": "File uploaded successfully",
"file_path": remote_path,
}
error_text = await response.text()
logger.warning(
"[Computer] file_upload_failed booter=boxlite error=http_status status=%s remote_path=%s",
response.status,
remote_path,
)
return {
"success": False,
"error": f"Server returned {response.status}: {error_text}",
"message": "File upload failed",
}
else:
error_text = await response.text()
return {
"success": False,
"error": f"Server returned {response.status}: {error_text}",
"message": "File upload failed",
}
except aiohttp.ClientError as e:
logger.error("[Computer] file_upload_failed booter=boxlite error=%s", e)
logger.error(f"Failed to upload file: {e}")
return {
"success": False,
"error": f"Connection error: {e!s}",
"error": f"Connection error: {str(e)}",
"message": "File upload failed",
}
except TimeoutError:
logger.warning(
"[Computer] file_upload_failed booter=boxlite error=timeout remote_path=%s",
remote_path,
)
except asyncio.TimeoutError:
return {
"success": False,
"error": "File upload timeout",
"message": "File upload failed",
}
except FileNotFoundError:
logger.error(
"[Computer] file_upload_failed booter=boxlite error=file_not_found path=%s",
path,
)
logger.error(f"File not found: {path}")
return {
"success": False,
"error": f"File not found: {path}",
"message": "File upload failed",
}
except Exception as exc:
logger.exception(
"[Computer] file_upload_failed booter=boxlite error=unexpected",
)
except Exception as e:
logger.error(f"Unexpected error uploading file: {e}")
return {
"success": False,
"error": f"Internal error: {exc!s}",
"error": f"Internal error: {str(e)}",
"message": "File upload failed",
}
@@ -139,46 +115,27 @@ class MockShipyardSandboxClient:
loop = 60
while loop > 0:
try:
logger.debug(
"[Computer] health_check booter=boxlite ship_id=%s session=%s endpoint=%s attempt=%s healthy=pending",
ship_id,
session_id,
self.sb_url,
61 - loop,
logger.info(
f"Checking health for sandbox {ship_id} on {self.sb_url}..."
)
url = f"{self.sb_url}/health"
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
logger.debug(
"[Computer] health_check booter=boxlite ship_id=%s session=%s endpoint=%s healthy=true",
ship_id,
session_id,
self.sb_url,
)
return
await asyncio.sleep(1)
loop -= 1
logger.info(f"Sandbox {ship_id} is healthy")
return
except Exception:
await asyncio.sleep(1)
loop -= 1
logger.warning(
"[Computer] health_check_timeout booter=boxlite ship_id=%s session=%s endpoint=%s",
ship_id,
session_id,
self.sb_url,
)
class BoxliteBooter(ComputerBooter):
async def boot(self, session_id: str) -> None:
logger.info(
"[Computer] booter_boot booter=boxlite session=%s status=starting",
session_id,
f"Booting(Boxlite) for session: {session_id}, this may take a while..."
)
random_port = random.randint(20000, 30000)
SimpleBox = vars(boxlite)["SimpleBox"]
self.box = SimpleBox(
self.box = boxlite.SimpleBox(
image="soulter/shipyard-ship",
memory_mib=512,
cpus=1,
@@ -186,51 +143,39 @@ class BoxliteBooter(ComputerBooter):
{
"host_port": random_port,
"guest_port": 8123,
},
}
],
)
await self.box.start()
logger.info(
"[Computer] booter_boot booter=boxlite session=%s status=ready ship_id=%s",
session_id,
self.box.id,
)
logger.info(f"Boxlite booter started for session: {session_id}")
self.mocked = MockShipyardSandboxClient(
sb_url=f"http://127.0.0.1:{random_port}",
)
raw_fs = ShipyardFileSystemComponent(
client=cast("Any", self.mocked),
ship_id=self.box.id,
session_id=session_id,
sb_url=f"http://127.0.0.1:{random_port}"
)
self._python = ShipyardPythonComponent(
client=cast("Any", self.mocked),
client=self.mocked, # type: ignore
ship_id=self.box.id,
session_id=session_id,
)
raw_shell = ShipyardShellComponent(
client=cast("Any", self.mocked),
self._shell = ShipyardShellComponent(
client=self.mocked, # type: ignore
ship_id=self.box.id,
session_id=session_id,
)
self._shell = ShipyardShellWrapper(cast("Any", raw_shell))
self._fs = ShipyardFileSystemWrapper(cast("Any", raw_fs), self._shell)
self._ship_fs = ShipyardFileSystemComponent(
client=self.mocked, # type: ignore
ship_id=self.box.id,
session_id=session_id,
)
self._fs = ShipyardFileSystemWrapper(
_shipyard_fs=self._ship_fs, _shipyard_shell=self._shell
)
await self.mocked.wait_healthy(self.box.id, session_id)
async def shutdown(self, **kwargs) -> None:
logger.info(
"[Computer] booter_shutdown booter=boxlite ship_id=%s status=starting",
self.box.id,
)
async def shutdown(self) -> None:
logger.info(f"Shutting down Boxlite booter for ship: {self.box.id}")
self.box.shutdown()
logger.info(
"[Computer] booter_shutdown booter=boxlite ship_id=%s status=done",
self.box.id,
)
async def available(self) -> bool:
return hasattr(self, "box")
logger.info(f"Boxlite booter for ship: {self.box.id} stopped")
@property
def fs(self) -> FileSystemComponent:
@@ -247,24 +192,3 @@ class BoxliteBooter(ComputerBooter):
async def upload_file(self, path: str, file_name: str) -> dict:
"""Upload file to sandbox"""
return await self.mocked.upload_file(path, file_name)
@classmethod
@functools.cache
def _default_tools(cls) -> tuple[ToolSchema, ...]:
from astrbot.core.computer.tools import (
ExecuteShellTool,
FileDownloadTool,
FileUploadTool,
PythonTool,
)
return (
ExecuteShellTool(),
PythonTool(),
FileUploadTool(),
FileDownloadTool(),
)
@classmethod
def get_default_tools(cls) -> list[ToolSchema]:
return list(cls._default_tools())

View File

@@ -1,446 +0,0 @@
from __future__ import annotations
import asyncio
import locale
import os
import shlex
import shutil
import subprocess
import sys
from dataclasses import dataclass, field
from typing import Any
from astrbot.core.computer.olayer import (
FileSystemComponent,
PythonComponent,
ShellComponent,
)
from astrbot.core.utils.astrbot_path import (
get_astrbot_temp_path,
)
from .base import ComputerBooter
def _decode_shell_output(output: bytes | None) -> str:
if output is None:
return ""
preferred = locale.getpreferredencoding(False) or "utf-8"
try:
return output.decode("utf-8")
except (LookupError, UnicodeDecodeError):
pass
try:
return output.decode(preferred)
except (LookupError, UnicodeDecodeError):
pass
return output.decode("utf-8", errors="replace")
def _write_file_sync(path: str, content: str, mode: str, encoding: str) -> None:
with open(path, mode, encoding=encoding) as f:
f.write(content)
def _read_file_sync(path: str, encoding: str) -> str:
with open(path, encoding=encoding) as f:
return f.read()
@dataclass
class BwrapConfig:
workspace_dir: str
ro_binds: list[str] = field(default_factory=list)
rw_binds: list[str] = field(default_factory=list)
share_net: bool = True
def __post_init__(self):
# Merge default required system binds with any additional ro_binds passed
default_ro = ["/usr", "/lib", "/lib64", "/bin", "/etc", "/opt"]
for p in default_ro:
if p not in self.ro_binds:
self.ro_binds.append(p)
def build_bwrap_cmd(config: BwrapConfig, script_cmd: list[str]) -> list[str]:
"""Helper to build a bubblewrap command."""
cmd = ["bwrap"]
if not config.share_net:
cmd.append("--unshare-net")
# Bind paths to itself so paths match
for path in config.ro_binds:
if os.path.exists(path):
cmd.extend(["--ro-bind", path, path])
for path in config.rw_binds:
# Avoid bind mounting dangerous host paths
if path == "/" or path.startswith("/root"):
continue
if os.path.exists(path):
cmd.extend(["--bind", path, path])
# Make system binds the last to avoid issues about ro `/`
cmd.extend(
[
"--unshare-pid",
"--unshare-ipc",
"--unshare-uts",
"--die-with-parent",
"--dir",
"/tmp",
"--dir",
"/var/tmp",
"--proc",
"/proc",
"--dev",
"/dev",
"--bind",
config.workspace_dir,
config.workspace_dir,
],
)
cmd.extend(["--"])
cmd.extend(script_cmd)
return cmd
@dataclass
class BwrapShellComponent(ShellComponent):
config: BwrapConfig
async def exec(
self,
command: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
timeout: int | None = 30,
shell: bool = True,
background: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
_ = session_id
def _run() -> dict[str, Any]:
run_env = os.environ.copy()
if env:
run_env.update({str(k): str(v) for k, v in env.items()})
working_dir = cwd or self.config.workspace_dir
# Use /bin/sh -c to run the evaluated command
# The command must be run inside bwrap
script_cmd = ["/bin/sh", "-c", command] if shell else shlex.split(command)
bwrap_cmd = build_bwrap_cmd(self.config, script_cmd)
if background:
proc = subprocess.Popen(
bwrap_cmd,
cwd=working_dir,
env=run_env,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
result = subprocess.run(
bwrap_cmd,
check=False,
cwd=working_dir,
env=run_env,
timeout=timeout,
capture_output=True,
)
return {
"stdout": _decode_shell_output(result.stdout),
"stderr": _decode_shell_output(result.stderr),
"exit_code": result.returncode,
}
return await asyncio.to_thread(_run)
@dataclass
class BwrapPythonComponent(PythonComponent):
config: BwrapConfig
async def exec(
self,
code: str,
kernel_id: str | None = None,
timeout: int = 30,
silent: bool = False,
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
bwrap_cmd = build_bwrap_cmd(
self.config,
[os.environ.get("PYTHON", "python3"), "-c", code],
)
try:
result = subprocess.run(
bwrap_cmd,
check=False,
timeout=timeout,
capture_output=True,
text=True,
)
stdout = "" if silent else result.stdout
return {
"stdout": stdout,
"stderr": result.stderr,
"exit_code": result.returncode,
}
except subprocess.TimeoutExpired as e:
return {
"stdout": e.stdout.decode()
if isinstance(e.stdout, bytes)
else str(e.stdout or ""),
"stderr": f"Execution timed out after {timeout} seconds.",
"exit_code": 1,
}
except Exception as e:
return {
"stdout": "",
"stderr": str(e),
"exit_code": 1,
}
return await asyncio.to_thread(_run)
@dataclass
class HostBackedFileSystemComponent(FileSystemComponent):
"""File operations happen safely on host mapping to workspace, making I/O extremely fast."""
workspace_dir: str
def _safe_path(self, path: str) -> str:
# Simply maps it. In a stricter implementation, we could verify it's inside workspace_dir.
# But for this implementation, we trust the agent or restrict to workspace_dir.
if not path.startswith("/"):
path = os.path.join(self.workspace_dir, path)
return path
async def create_file(
self,
path: str,
content: str = "",
mode: int = 0o644,
) -> dict[str, Any]:
p = self._safe_path(path)
await asyncio.to_thread(os.makedirs, os.path.dirname(p), exist_ok=True)
await asyncio.to_thread(_write_file_sync, p, content, "w", "utf-8")
await asyncio.to_thread(os.chmod, p, mode)
return {"success": True, "path": p}
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
_ = offset, limit
p = self._safe_path(path)
try:
content = await asyncio.to_thread(_read_file_sync, p, encoding)
return {"success": True, "content": content}
except Exception as e:
return {"success": False, "error": str(e)}
async def write_file(
self,
path: str,
content: str,
mode: str = "w",
encoding: str = "utf-8",
) -> dict[str, Any]:
p = self._safe_path(path)
await asyncio.to_thread(os.makedirs, os.path.dirname(p), exist_ok=True)
try:
await asyncio.to_thread(_write_file_sync, p, content, mode, encoding)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def delete_file(self, path: str) -> dict[str, Any]:
p = self._safe_path(path)
try:
if await asyncio.to_thread(os.path.isdir, p):
await asyncio.to_thread(shutil.rmtree, p)
else:
await asyncio.to_thread(os.remove, p)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def list_dir(
self,
path: str = ".",
show_hidden: bool = False,
) -> dict[str, Any]:
p = self._safe_path(path)
try:
items = os.listdir(p)
if not show_hidden:
items = [item for item in items if not item.startswith(".")]
return {"success": True, "items": items}
except Exception as e:
return {"success": False, "error": str(e), "items": []}
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
p = path or self.workspace_dir
try:
import subprocess
cmd = ["grep", "-r", pattern, p]
result = await asyncio.to_thread(
subprocess.run,
cmd,
capture_output=True,
text=True,
)
return {
"success": True,
"matches": result.stdout.splitlines() if result.stdout else [],
}
except Exception as e:
return {"success": False, "error": str(e), "matches": []}
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
p = self._safe_path(path)
try:
content = await asyncio.to_thread(_read_file_sync, p, encoding)
if replace_all:
new_content = content.replace(old_string, new_string)
else:
parts = content.split(old_string, 1)
if len(parts) == 1:
return {"success": False, "error": "Pattern not found"}
new_content = new_string.join(parts)
await asyncio.to_thread(_write_file_sync, p, new_content, "w", encoding)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
class BwrapBooter(ComputerBooter):
def __init__(
self,
rw_binds: list[str] | None = None,
ro_binds: list[str] | None = None,
):
self._rw_binds = rw_binds or []
self._ro_binds = ro_binds or []
self._fs: HostBackedFileSystemComponent | None = None
self._python: BwrapPythonComponent | None = None
self._shell: BwrapShellComponent | None = None
self.config: BwrapConfig | None = None
@property
def fs(self) -> FileSystemComponent:
if self._fs is None:
raise RuntimeError("BwrapBooter filesystem is unavailable before boot")
return self._fs
@property
def python(self) -> PythonComponent:
if self._python is None:
raise RuntimeError("BwrapBooter python is unavailable before boot")
return self._python
@property
def shell(self) -> ShellComponent:
if self._shell is None:
raise RuntimeError("BwrapBooter shell is unavailable before boot")
return self._shell
@property
def capabilities(self) -> tuple[str, ...]:
return ("python", "shell", "filesystem")
async def boot(self, session_id: str) -> None:
workspace_dir = os.path.join(
get_astrbot_temp_path(),
f"sandbox_workspace_{session_id}",
)
await asyncio.to_thread(os.makedirs, workspace_dir, exist_ok=True)
self.config = BwrapConfig(
workspace_dir=await asyncio.to_thread(os.path.abspath, workspace_dir),
rw_binds=self._rw_binds,
ro_binds=self._ro_binds,
)
self._fs = HostBackedFileSystemComponent(self.config.workspace_dir)
self._python = BwrapPythonComponent(self.config)
self._shell = BwrapShellComponent(self.config)
if not await self.available():
raise RuntimeError(
"BubbleWrap sandbox unavailable on current machine for no bwrap executable.",
)
test_shl = await self._shell.exec(command="ls > /dev/null")
if test_shl["exit_code"] != 0:
raise RuntimeError(
"""BubbleWrap sandbox fails to exec test shell command "ls > /dev/null" with stderr:
{}""".format(test_shl["stderr"]),
)
test_py = await self._python.exec(code="print('Yes')")
if test_py["exit_code"] != 0:
raise RuntimeError(
"""BubbleWrap sandbox fails to exec test python code "print('Yes')" with stderr:
{}""".format(test_py["stderr"]),
)
async def shutdown(self, **kwargs) -> None:
config = self.config
if config is None:
return
if await asyncio.to_thread(os.path.exists, config.workspace_dir):
await asyncio.to_thread(
shutil.rmtree,
config.workspace_dir,
ignore_errors=True,
)
async def upload_file(self, path: str, file_name: str) -> dict:
if not self._fs or not self.config:
return {"success": False, "error": "Not booted"}
target = os.path.join(self.config.workspace_dir, file_name)
try:
shutil.copy2(path, target)
return {"success": True, "file_path": target}
except Exception as e:
return {"success": False, "error": str(e)}
async def download_file(self, remote_path: str, local_path: str) -> None:
if not self._fs or not self.config:
return
if not remote_path.startswith("/"):
remote_path = os.path.join(self.config.workspace_dir, remote_path)
shutil.copy2(remote_path, local_path)
async def available(self) -> bool:
if sys.platform == "win32":
return False
if shutil.which("bwrap") is None:
return False
return True

View File

@@ -1,3 +0,0 @@
BOOTER_SHIPYARD = "shipyard"
BOOTER_SHIPYARD_NEO = "shipyard_neo"
BOOTER_BOXLITE = "boxlite"

View File

@@ -2,31 +2,21 @@ from __future__ import annotations
import asyncio
import base64
import importlib
import inspect
import shlex
from collections.abc import Callable
from dataclasses import asdict, dataclass, is_dataclass
from pathlib import Path
from typing import Any
from astrbot.api import logger
from astrbot.core.computer.booters.base import ComputerBooter
from astrbot.core.computer.booters.cua_defaults import (
CUA_CONFIG_KEYS,
CUA_DEFAULT_CONFIG,
)
from astrbot.core.computer.booters.shipyard_search_file_util import (
search_files_via_shell,
)
from astrbot.core.computer.olayer import (
FileSystemComponent,
GUIComponent,
PythonComponent,
ShellComponent,
)
from ..olayer import FileSystemComponent, GUIComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .cua_defaults import CUA_CONFIG_KEYS, CUA_DEFAULT_CONFIG
from .shipyard_search_file_util import search_files_via_shell
_POSIX_OS_TYPES = {"linux", "darwin", "macos"}
_CUA_SANDBOX_HEALTH_PROBE = "_astrbot_cua_ok_"
_CUA_BACKGROUND_LAUNCHER = """
import subprocess, sys, time
@@ -67,10 +57,18 @@ async def _write_base64_via_shell(
encoded = base64.b64encode(data).decode("ascii")
decoder = (
"import base64,pathlib,sys; "
"pathlib.Path(sys.argv[1]).write_bytes(base64.b64decode(sys.stdin.read()))"
"path=pathlib.Path(sys.argv[1]); "
"path.parent.mkdir(parents=True, exist_ok=True); "
"path.write_bytes(base64.b64decode(sys.stdin.read()))"
)
chunk_size = 60_000
encoded_lines = "\n".join(
encoded[index : index + chunk_size]
for index in range(0, len(encoded), chunk_size)
)
return await shell.exec(
f"python3 -c {shlex.quote(decoder)} {shlex.quote(path)} <<'EOF'\n{encoded}\nEOF",
f"python3 -c {shlex.quote(decoder)} {shlex.quote(path)} <<'EOF'\n"
f"{encoded_lines}\nEOF"
)
@@ -92,9 +90,9 @@ def _maybe_model_dump(value: Any) -> dict[str, Any]:
dumped = model_dump()
if isinstance(dumped, dict):
return dumped
dict_method = getattr(value, "dict", None)
if callable(dict_method):
dumped = dict_method()
dict_attr = getattr(value, "dict", None)
if callable(dict_attr):
dumped = dict_attr()
if isinstance(dumped, dict):
return dumped
attr_payload = {
@@ -143,15 +141,18 @@ def _normalize_process_result(raw: Any) -> ProcessResult:
stdout = first_text("stdout", "output")
stderr = first_text("stderr", "error")
raw_exit_code = payload.get("exit_code")
if raw_exit_code is None:
raw_exit_code = payload.get("returncode")
if raw_exit_code is None:
raw_exit_code = payload.get("return_code")
if raw_exit_code is None:
exit_code = payload.get("exit_code")
if exit_code is None:
exit_code = payload.get("returncode")
if exit_code is None:
exit_code = payload.get("return_code")
if exit_code is not None:
try:
exit_code = int(exit_code)
except Exception:
exit_code = None
if exit_code is None:
exit_code = 0 if not stderr else 1
else:
exit_code = int(raw_exit_code)
success = bool(payload.get("success", not stderr and exit_code in (0, None)))
return ProcessResult(
stdout=stdout,
@@ -239,7 +240,7 @@ def _missing_component_method_error(
candidates = ", ".join(f"{component_name}.{name}" for name in names)
return RuntimeError(
f"CUA sandbox does not provide any of: {candidates}. "
"Please check the installed CUA SDK version and sandbox backend.",
"Please check the installed CUA SDK version and sandbox backend."
)
@@ -290,10 +291,9 @@ class CuaShellComponent(ShellComponent):
self._sandbox = sandbox
self._os_type = os_type.lower()
shell = sandbox.shell
exec_raw = getattr(shell, "exec", None) or getattr(shell, "run", None)
if exec_raw is None:
self._exec_raw = getattr(shell, "exec", None) or getattr(shell, "run", None)
if self._exec_raw is None:
raise RuntimeError("CUA sandbox shell must provide `.exec` or `.run`.")
self._exec_raw: Callable[..., Any] = exec_raw
async def exec(
self,
@@ -303,9 +303,7 @@ class CuaShellComponent(ShellComponent):
timeout: int | None = 30,
shell: bool = True,
background: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
_ = session_id
if not shell:
return {
"stdout": "",
@@ -363,9 +361,7 @@ class CuaPythonComponent(PythonComponent):
self._python_exec = None
if python is not None:
self._python_exec = getattr(python, "exec", None) or getattr(
python,
"run",
None,
python, "run", None
)
async def exec(
@@ -413,19 +409,9 @@ def _write_result(path: str, result: dict[str, Any]) -> dict[str, Any]:
return {"success": True, "path": path, **result}
CUA_DEFAULT_IMAGE = str(CUA_DEFAULT_CONFIG["image"])
CUA_DEFAULT_OS_TYPE = str(CUA_DEFAULT_CONFIG["os_type"])
CUA_DEFAULT_TTL = int(CUA_DEFAULT_CONFIG["ttl"])
CUA_DEFAULT_TELEMETRY_ENABLED = bool(CUA_DEFAULT_CONFIG["telemetry_enabled"])
CUA_DEFAULT_LOCAL = bool(CUA_DEFAULT_CONFIG["local"])
CUA_DEFAULT_API_KEY = str(CUA_DEFAULT_CONFIG["api_key"])
class CuaFileSystemComponent(FileSystemComponent):
def __init__(
self,
sandbox: Any,
os_type: str = CUA_DEFAULT_OS_TYPE,
self, sandbox: Any, os_type: str = CUA_DEFAULT_CONFIG["os_type"]
) -> None:
self._shell = CuaShellComponent(sandbox, os_type=os_type)
self._fs_components = _resolve_files_components(sandbox)
@@ -451,21 +437,19 @@ class CuaFileSystemComponent(FileSystemComponent):
limit: int | None = None,
) -> dict[str, Any]:
read_file = _resolve_files_method(
self._fs_components,
("read_file", "read_text"),
self._fs_components, ("read_file", "read_text")
)
if read_file is None:
return await self._fallback.read_file(path, encoding, offset, limit)
content = await _maybe_await(read_file(path))
else:
content = await _maybe_await(read_file(path))
if isinstance(content, bytes):
content = content.decode(encoding, errors="replace")
return {
"success": True,
"path": path,
"content": _slice_content_by_lines(
str(content),
offset=offset,
limit=limit,
str(content), offset=offset, limit=limit
),
}
@@ -478,22 +462,22 @@ class CuaFileSystemComponent(FileSystemComponent):
) -> dict[str, Any]:
_ = mode
write_file = _resolve_files_method(
self._fs_components,
("write_file", "write_text"),
self._fs_components, ("write_file", "write_text")
)
if write_file is None:
return await self._fallback.write_file(path, content, mode, encoding)
await _maybe_await(write_file(path, content))
else:
await _maybe_await(write_file(path, content))
return {"success": True, "path": path}
async def delete_file(self, path: str) -> dict[str, Any]:
delete = _resolve_files_method(
self._fs_components,
("delete", "delete_file", "remove"),
self._fs_components, ("delete", "delete_file", "remove")
)
if delete is None:
return await self._fallback.delete_file(path)
await _maybe_await(delete(path))
else:
await _maybe_await(delete(path))
return {"success": True, "path": path}
async def list_dir(
@@ -563,26 +547,6 @@ class _PosixShellFileSystem(FileSystemComponent):
return None
return _non_posix_filesystem_result(path, self._os_type)
async def create_file(
self,
path: str,
content: str = "",
mode: int = 0o644,
) -> dict[str, Any]:
write_result = await self.write_file(path, content)
if not write_result.get("success"):
return {**write_result, "mode": mode, "mode_applied": False}
chmod_result = await self._shell.exec(f"chmod {mode:o} {shlex.quote(path)}")
if chmod_result.get("stderr"):
return {
"success": True,
"path": path,
"mode": mode,
"mode_applied": False,
"mode_error": chmod_result["stderr"],
}
return {"success": True, "path": path, "mode": mode, "mode_applied": True}
async def read_file(
self,
path: str,
@@ -600,9 +564,7 @@ class _PosixShellFileSystem(FileSystemComponent):
"success": True,
"path": path,
"content": _slice_content_by_lines(
str(result.get("stdout", "")),
offset=offset,
limit=limit,
str(result.get("stdout", "")), offset=offset, limit=limit
),
}
@@ -617,9 +579,7 @@ class _PosixShellFileSystem(FileSystemComponent):
if error := self._ensure_posix(path):
return error
result = await _write_base64_via_shell(
self._shell,
path,
content.encode(encoding),
self._shell, path, content.encode(encoding)
)
return _write_result(path, result)
@@ -660,36 +620,6 @@ class _PosixShellFileSystem(FileSystemComponent):
before_context=before_context,
)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
read_result = await self.read_file(path, encoding=encoding)
if not read_result.get("success"):
return read_result
content = str(read_result.get("content", ""))
occurrences = content.count(old_string)
if occurrences == 0:
return {
"success": False,
"path": path,
"error": "old string not found in file",
"replacements": 0,
}
updated = content.replace(old_string, new_string, -1 if replace_all else 1)
write_result = await self.write_file(path, updated, encoding=encoding)
if not write_result.get("success"):
return write_result
return {
"success": True,
"path": path,
"replacements": occurrences if replace_all else 1,
}
async def _list_dir_via_shell(
shell: CuaShellComponent,
@@ -715,17 +645,15 @@ class CuaGUIComponent(GUIComponent):
self._click = _resolve_component_method(mouse, "click")
self._type_text = _resolve_component_method(keyboard, "type")
self._press_key = _resolve_component_method(
keyboard,
("press", "key_press", "press_key"),
keyboard, ("press", "key_press", "press_key")
)
async def screenshot(self, path: str | None = None) -> dict[str, Any]:
raw = await self._sandbox.screenshot()
data = _screenshot_to_bytes(raw)
if path:
_p = Path(path)
await asyncio.to_thread(_p.parent.mkdir, parents=True, exist_ok=True)
await asyncio.to_thread(_p.write_bytes, data)
Path(path).parent.mkdir(parents=True, exist_ok=True)
Path(path).write_bytes(data)
return {
"success": True,
"path": path,
@@ -750,8 +678,7 @@ class CuaGUIComponent(GUIComponent):
async def press_key(self, key: str) -> dict[str, Any]:
if self._press_key is None:
raise _missing_component_method_error(
"keyboard",
("press", "key_press", "press_key"),
"keyboard", ("press", "key_press", "press_key")
)
result = await _maybe_await(self._press_key(key))
payload = _maybe_model_dump(result)
@@ -801,12 +728,12 @@ class _CuaRuntime:
class CuaBooter(ComputerBooter):
def __init__(
self,
image: str = CUA_DEFAULT_IMAGE,
os_type: str = CUA_DEFAULT_OS_TYPE,
ttl: int = CUA_DEFAULT_TTL,
telemetry_enabled: bool = CUA_DEFAULT_TELEMETRY_ENABLED,
local: bool = CUA_DEFAULT_LOCAL,
api_key: str = CUA_DEFAULT_API_KEY,
image: str = CUA_DEFAULT_CONFIG["image"],
os_type: str = CUA_DEFAULT_CONFIG["os_type"],
ttl: int = CUA_DEFAULT_CONFIG["ttl"],
telemetry_enabled: bool = CUA_DEFAULT_CONFIG["telemetry_enabled"],
local: bool = CUA_DEFAULT_CONFIG["local"],
api_key: str = CUA_DEFAULT_CONFIG["api_key"],
) -> None:
self.image = image
self.os_type = os_type
@@ -819,16 +746,13 @@ class CuaBooter(ComputerBooter):
async def boot(self, session_id: str) -> None:
_ = session_id
try:
cua_module = importlib.import_module("cua")
from cua import Image, Sandbox
except ImportError as exc:
raise RuntimeError(
"CUA sandbox support requires the optional `cua` package. "
"Install it with `pip install cua` in the AstrBot environment.",
"Install it with `pip install cua` in the AstrBot environment."
) from exc
Image = vars(cua_module)["Image"]
Sandbox = vars(cua_module)["Sandbox"]
image_obj = self._build_image(Image)
ephemeral_kwargs = self._build_ephemeral_kwargs(Sandbox.ephemeral)
sandbox_cm = Sandbox.ephemeral(image_obj, **ephemeral_kwargs)
@@ -878,7 +802,7 @@ class CuaBooter(ComputerBooter):
kwargs["api_key"] = self.api_key
return kwargs
async def shutdown(self, **kwargs) -> None:
async def shutdown(self) -> None:
if self._runtime is not None:
await self._runtime.sandbox_cm.__aexit__(None, None, None)
self._runtime = None
@@ -927,12 +851,12 @@ class CuaBooter(ComputerBooter):
async def upload_file(self, path: str, file_name: str) -> dict:
local_path = Path(path)
if not await asyncio.to_thread(local_path.is_file):
if not local_path.is_file():
return {"success": False, "error": f"File not found: {path}"}
sandbox = None if self._runtime is None else self._runtime.sandbox
if sandbox is not None and hasattr(sandbox, "upload_file"):
return _maybe_model_dump(
await sandbox.upload_file(str(local_path), file_name),
await sandbox.upload_file(str(local_path), file_name)
)
files_components = () if sandbox is None else _resolve_files_components(sandbox)
upload = _resolve_files_method(files_components, "upload")
@@ -941,16 +865,12 @@ class CuaBooter(ComputerBooter):
return _normalize_native_upload_result(result, file_name)
write_bytes = _resolve_files_method(files_components, "write_bytes")
if write_bytes is not None:
data = await asyncio.to_thread(local_path.read_bytes)
result = await _maybe_await(write_bytes(file_name, data))
result = await _maybe_await(write_bytes(file_name, local_path.read_bytes()))
return _normalize_native_upload_result(result, file_name)
if not _is_posix_os_type(self.os_type):
return _non_posix_filesystem_result(file_name, self.os_type)
data = await asyncio.to_thread(local_path.read_bytes)
result = await _write_base64_via_shell(
self.shell,
file_name,
data,
self.shell, file_name, local_path.read_bytes()
)
return {
"success": not bool(result.get("stderr")),
@@ -968,12 +888,21 @@ class CuaBooter(ComputerBooter):
result = await self.shell.exec(f"base64 {shlex.quote(remote_path)}")
if result.get("stderr"):
raise RuntimeError(result["stderr"])
_p = Path(local_path)
await asyncio.to_thread(_p.parent.mkdir, parents=True, exist_ok=True)
await asyncio.to_thread(
_p.write_bytes,
base64.b64decode(result.get("stdout", "")),
)
Path(local_path).parent.mkdir(parents=True, exist_ok=True)
Path(local_path).write_bytes(base64.b64decode(result.get("stdout", "")))
async def available(self) -> bool:
return self._runtime is not None
if self._runtime is None:
return False
try:
result = await self._runtime.shell.exec(
f"echo {_CUA_SANDBOX_HEALTH_PROBE}", timeout=10
)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.debug("[Computer] CUA sandbox health check failed: %s", exc)
return False
if result.get("exit_code") != 0:
return False
return _CUA_SANDBOX_HEALTH_PROBE in str(result.get("stdout", ""))

View File

@@ -9,22 +9,18 @@ import sys
from dataclasses import dataclass
from typing import Any
from astrbot.api import logger
from astrbot.core.computer.olayer import (
FileSystemComponent,
PythonComponent,
ShellComponent,
)
from astrbot.core.computer.shell_session import PersistentShellSession
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_root,
get_astrbot_temp_path,
get_astrbot_workspaces_path,
)
from python_ripgrep import search
from astrbot.api import logger
from astrbot.core.computer.file_read_utils import (
detect_text_encoding,
read_local_text_range_sync,
)
from astrbot.core.utils.astrbot_path import get_astrbot_root
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .bwrap import _decode_shell_output
from .shipyard_search_file_util import _truncate_long_lines
_BLOCKED_COMMAND_PATTERNS = [
" rm -rf ",
@@ -48,19 +44,6 @@ def _is_safe_command(command: str) -> bool:
return not any(pat in cmd for pat in _BLOCKED_COMMAND_PATTERNS)
def _ensure_safe_path(path: str) -> str:
abs_path = os.path.abspath(path)
allowed_roots = [
os.path.abspath(get_astrbot_root()),
os.path.abspath(get_astrbot_data_path()),
os.path.abspath(get_astrbot_temp_path()),
os.path.abspath(get_astrbot_workspaces_path()),
]
if not any(abs_path.startswith(root) for root in allowed_roots):
raise PermissionError("Path is outside the allowed computer roots.")
return abs_path
def _decode_bytes_with_fallback(
output: bytes | None,
*,
@@ -96,6 +79,10 @@ def _decode_bytes_with_fallback(
return output.decode("utf-8", errors="replace")
def _decode_shell_output(output: bytes | None) -> str:
return _decode_bytes_with_fallback(output, preferred_encoding="utf-8")
@dataclass
class LocalShellComponent(ShellComponent):
async def exec(
@@ -106,24 +93,71 @@ class LocalShellComponent(ShellComponent):
timeout: int | None = 300,
shell: bool = True,
background: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
if not _is_safe_command(command):
raise PermissionError("Blocked unsafe shell command.")
key = session_id or "default"
session = PersistentShellSession.get_or_create(key)
return await session.exec(
command,
cwd=cwd,
env=env,
timeout=timeout,
background=background,
)
def _run() -> dict[str, Any]:
run_env = os.environ.copy()
if env:
run_env.update({str(k): str(v) for k, v in env.items()})
working_dir = os.path.abspath(cwd) if cwd else get_astrbot_root()
if background:
# `command` is intentionally executed through the current shell so
# local computer-use behavior matches existing tool semantics.
# Safety relies on `_is_safe_command()` and the allowed-root checks.
proc = subprocess.Popen( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
command,
shell=shell,
cwd=working_dir,
env=run_env,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
# `command` is intentionally executed through the current shell so
# local computer-use behavior matches existing tool semantics.
# Safety relies on `_is_safe_command()` and the allowed-root checks.
proc = subprocess.Popen( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
command,
shell=shell,
cwd=working_dir,
env=run_env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
try:
stdout, stderr = proc.communicate(timeout=timeout or 300)
except subprocess.TimeoutExpired:
should_kill_parent = sys.platform != "win32"
if sys.platform == "win32":
try:
taskkill_result = subprocess.run(
["taskkill", "/F", "/T", "/PID", str(proc.pid)],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
timeout=5,
)
should_kill_parent = taskkill_result.returncode != 0
except Exception:
should_kill_parent = True
if should_kill_parent:
try:
proc.kill()
except Exception:
pass
try:
proc.wait(timeout=5)
except Exception:
pass
raise
return {
"stdout": _decode_shell_output(stdout),
"stderr": _decode_shell_output(stderr),
"exit_code": proc.returncode,
}
@staticmethod
async def shutdown_all() -> None:
await PersistentShellSession.cleanup_all()
return await asyncio.to_thread(_run)
@dataclass
@@ -141,7 +175,6 @@ class LocalPythonComponent(PythonComponent):
working_dir = os.path.abspath(cwd) if cwd else get_astrbot_root()
result = subprocess.run(
[os.environ.get("PYTHON", sys.executable), "-c", code],
check=False,
timeout=timeout,
capture_output=True,
cwd=working_dir,
@@ -156,14 +189,14 @@ class LocalPythonComponent(PythonComponent):
"data": {
"output": {"text": stdout, "images": []},
"error": stderr,
},
}
}
except subprocess.TimeoutExpired:
return {
"data": {
"output": {"text": "", "images": []},
"error": "Execution timed out.",
},
}
}
return await asyncio.to_thread(_run)
@@ -172,13 +205,10 @@ class LocalPythonComponent(PythonComponent):
@dataclass
class LocalFileSystemComponent(FileSystemComponent):
async def create_file(
self,
path: str,
content: str = "",
mode: int = 0o644,
self, path: str, content: str = "", mode: int = 0o644
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, "w", encoding="utf-8") as f:
f.write(content)
@@ -195,25 +225,21 @@ class LocalFileSystemComponent(FileSystemComponent):
limit: int | None = None,
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
with open(abs_path, "rb") as f:
raw_content = f.read()
content = _decode_bytes_with_fallback(
raw_content,
preferred_encoding=encoding,
)
if offset is not None:
lines = content.splitlines(keepends=True)
start = offset
if limit is not None:
lines = lines[start : start + limit]
else:
lines = lines[start:]
content = "".join(lines)
elif limit is not None:
lines = content.splitlines(keepends=True)[:limit]
content = "".join(lines)
return {"success": True, "content": content}
abs_path = os.path.abspath(path)
detected_encoding = encoding
if encoding == "utf-8":
with open(abs_path, "rb") as f:
raw_sample = f.read(8192)
detected_encoding = detect_text_encoding(raw_sample) or encoding
return {
"success": True,
"content": read_local_text_range_sync(
abs_path,
encoding=detected_encoding,
offset=offset,
limit=limit,
),
}
return await asyncio.to_thread(_run)
@@ -225,36 +251,16 @@ class LocalFileSystemComponent(FileSystemComponent):
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
"""Search file contents using grep-like pattern matching."""
def _run() -> dict[str, Any]:
search_path = _ensure_safe_path(path) if path else "."
cmd = ["grep", "-rn", pattern, search_path]
if after_context is not None:
cmd.extend(["-A", str(after_context)])
if before_context is not None:
cmd.extend(["-B", str(before_context)])
if glob:
cmd.extend(["--include", glob])
try:
result = subprocess.run(
cmd,
check=False,
capture_output=True,
text=True,
timeout=30,
)
return {
"success": True,
"content": result.stdout,
"error": result.stderr if result.returncode != 0 else "",
}
except subprocess.TimeoutExpired:
return {
"success": False,
"output": "",
"error": "Search timed out.",
}
results = search(
patterns=[pattern],
paths=[path] if path else None,
globs=[glob] if glob else None,
after_context=after_context,
before_context=before_context,
line_number=True,
)
return {"success": True, "content": _truncate_long_lines("".join(results))}
return await asyncio.to_thread(_run)
@@ -267,33 +273,37 @@ class LocalFileSystemComponent(FileSystemComponent):
encoding: str = "utf-8",
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
with open(abs_path, encoding=encoding) as f:
content = f.read()
if replace_all:
new_content = content.replace(old_string, new_string)
else:
new_content = content.replace(old_string, new_string, 1)
if new_content == content:
occurrences = content.count(old_string)
if occurrences == 0:
return {
"success": False,
"error": f"String '{old_string}' not found in file.",
"error": "old string not found in file",
"replacements": 0,
}
if replace_all:
updated = content.replace(old_string, new_string)
replacements = occurrences
else:
updated = content.replace(old_string, new_string, 1)
replacements = 1
with open(abs_path, "w", encoding=encoding) as f:
f.write(new_content)
return {"success": True, "path": abs_path}
f.write(updated)
return {
"success": True,
"path": abs_path,
"replacements": replacements,
}
return await asyncio.to_thread(_run)
async def write_file(
self,
path: str,
content: str,
mode: str = "w",
encoding: str = "utf-8",
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, mode, encoding=encoding) as f:
f.write(content)
@@ -303,7 +313,7 @@ class LocalFileSystemComponent(FileSystemComponent):
async def delete_file(self, path: str) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
if os.path.isdir(abs_path):
shutil.rmtree(abs_path)
else:
@@ -313,12 +323,10 @@ class LocalFileSystemComponent(FileSystemComponent):
return await asyncio.to_thread(_run)
async def list_dir(
self,
path: str = ".",
show_hidden: bool = False,
self, path: str = ".", show_hidden: bool = False
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = _ensure_safe_path(path)
abs_path = os.path.abspath(path)
entries = os.listdir(abs_path)
if not show_hidden:
entries = [e for e in entries if not e.startswith(".")]
@@ -336,8 +344,7 @@ class LocalBooter(ComputerBooter):
async def boot(self, session_id: str) -> None:
logger.info(f"Local computer booter initialized for session: {session_id}")
async def shutdown(self, **kwargs) -> None:
await LocalShellComponent.shutdown_all()
async def shutdown(self) -> None:
logger.info("Local computer booter shutdown complete.")
@property
@@ -354,12 +361,12 @@ class LocalBooter(ComputerBooter):
async def upload_file(self, path: str, file_name: str) -> dict:
raise NotImplementedError(
"LocalBooter does not support upload_file operation. Use shell instead.",
"LocalBooter does not support upload_file operation. Use shell instead."
)
async def download_file(self, remote_path: str, local_path: str) -> None:
raise NotImplementedError(
"LocalBooter does not support download_file operation. Use shell instead.",
"LocalBooter does not support download_file operation. Use shell instead."
)
async def available(self) -> bool:

View File

@@ -1,22 +1,14 @@
from __future__ import annotations
import functools
import shlex
from typing import TYPE_CHECKING, Any
from typing import Any
from shipyard import FileSystemComponent as ShipyardFileSystemComponent
from shipyard import ShipyardClient, Spec
from astrbot.api import logger
if TYPE_CHECKING:
from astrbot.core.agent.tool import ToolSchema
from astrbot.core.computer.olayer import (
FileSystemComponent,
PythonComponent,
ShellComponent,
)
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .shell_background import build_detached_shell_command
from .shipyard_search_file_util import search_files_via_shell
@@ -44,9 +36,7 @@ class ShipyardShellWrapper:
timeout: int | None = 300,
shell: bool = True,
background: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
_ = session_id
if not shell:
return {
"stdout": "",
@@ -109,18 +99,13 @@ class ShipyardShellWrapper:
class ShipyardFileSystemWrapper:
def __init__(
self,
_shipyard_fs: FileSystemComponent,
_shipyard_shell: ShellComponent,
self, _shipyard_fs: ShipyardFileSystemComponent, _shipyard_shell: ShellComponent
):
self._fs = _shipyard_fs
self._shell = _shipyard_shell
async def create_file(
self,
path: str,
content: str = "",
mode: int = 420,
self, path: str, content: str = "", mode: int = 420
) -> dict[str, Any]:
return await self._fs.create_file(path=path, content=content, mode=mode)
@@ -132,30 +117,18 @@ class ShipyardFileSystemWrapper:
limit: int | None = None,
) -> dict[str, Any]:
return await self._fs.read_file(
path=path,
encoding=encoding,
offset=offset,
limit=limit,
path=path, encoding=encoding, offset=offset, limit=limit
)
async def write_file(
self,
path: str,
content: str,
mode: str = "w",
encoding: str = "utf-8",
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
) -> dict[str, Any]:
return await self._fs.write_file(
path=path,
content=content,
mode=mode,
encoding=encoding,
path=path, content=content, mode=mode, encoding=encoding
)
async def list_dir(
self,
path: str = ".",
show_hidden: bool = False,
self, path: str = ".", show_hidden: bool = False
) -> dict[str, Any]:
return await self._fs.list_dir(path=path, show_hidden=show_hidden)
@@ -197,27 +170,6 @@ class ShipyardFileSystemWrapper:
class ShipyardBooter(ComputerBooter):
@classmethod
@functools.cache
def _default_tools(cls) -> tuple[ToolSchema, ...]:
from astrbot.core.computer.tools import (
ExecuteShellTool,
FileDownloadTool,
FileUploadTool,
PythonTool,
)
return (
ExecuteShellTool(),
PythonTool(),
FileUploadTool(),
FileDownloadTool(),
)
@classmethod
def get_default_tools(cls) -> list[ToolSchema]:
return list(cls._default_tools())
def __init__(
self,
endpoint_url: str,
@@ -226,8 +178,7 @@ class ShipyardBooter(ComputerBooter):
session_num: int = 10,
) -> None:
self._sandbox_client = ShipyardClient(
endpoint_url=endpoint_url,
access_token=access_token,
endpoint_url=endpoint_url, access_token=access_token
)
self._ttl = ttl
self._session_num = session_num
@@ -239,21 +190,17 @@ class ShipyardBooter(ComputerBooter):
max_session_num=self._session_num,
session_id=session_id,
)
logger.info(
"[Computer] sandbox_created booter=shipyard ship_id=%s session=%s",
ship.id,
session_id,
)
logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}")
self._ship = ship
self._shell = ShipyardShellWrapper(self._ship.shell) # type: ignore[arg-type]
self._fs = ShipyardFileSystemWrapper(self._ship.fs, self._shell) # type: ignore[arg-type]
self._shell = ShipyardShellWrapper(self._ship.shell)
self._fs = ShipyardFileSystemWrapper(self._ship.fs, self._shell)
async def shutdown(self, **kwargs) -> None:
logger.info("[Computer] booter_shutdown booter=shipyard status=done")
async def shutdown(self) -> None:
logger.info("[Computer] Shipyard booter shutdown.")
@property
def fs(self) -> FileSystemComponent:
return self._ship.fs # type: ignore[return-value]
return self._fs
@property
def python(self) -> PythonComponent:
@@ -266,17 +213,14 @@ class ShipyardBooter(ComputerBooter):
async def upload_file(self, path: str, file_name: str) -> dict:
"""Upload file to sandbox"""
result = await self._ship.upload_file(path, file_name)
logger.info(
"[Computer] file_upload booter=shipyard remote_path=%s",
file_name,
)
logger.info("[Computer] File uploaded to Shipyard sandbox: %s", file_name)
return result
async def download_file(self, remote_path: str, local_path: str):
"""Download file from sandbox."""
result = await self._ship.download_file(remote_path, local_path)
logger.info(
"[Computer] file_download booter=shipyard remote_path=%s local_path=%s",
"[Computer] File downloaded from Shipyard sandbox: %s -> %s",
remote_path,
local_path,
)
@@ -288,21 +232,18 @@ class ShipyardBooter(ComputerBooter):
ship_id = self._ship.id
data = await self._sandbox_client.get_ship(ship_id)
if not data:
logger.debug(
"[Computer] health_check booter=shipyard ship_id=%s healthy=false reason=no_data",
logger.info(
"[Computer] Shipyard sandbox health check: id=%s, healthy=False (no data)",
ship_id,
)
return False
health = bool(data.get("status", 0) == 1)
logger.debug(
"[Computer] health_check booter=shipyard ship_id=%s healthy=%s",
logger.info(
"[Computer] Shipyard sandbox health check: id=%s, healthy=%s",
ship_id,
health,
)
return health
except Exception:
logger.exception(
"[Computer] health_check_failed booter=shipyard ship_id=%s",
getattr(getattr(self, "_ship", None), "id", "unknown"),
)
except Exception as e:
logger.error(f"Error checking Shipyard sandbox availability: {e}")
return False

View File

@@ -1,42 +1,34 @@
from __future__ import annotations
import asyncio
import functools
import os
import shlex
from typing import TYPE_CHECKING, Any
import anyio
from typing import Any, cast
from astrbot.api import logger
if TYPE_CHECKING:
from astrbot.core.agent.tool import ToolSchema
from astrbot.core.computer.booters.base import ComputerBooter
from astrbot.core.computer.olayer import (
from ..olayer import (
BrowserComponent,
FileSystemComponent,
PythonComponent,
ShellComponent,
)
from .base import ComputerBooter
from .shell_background import build_detached_shell_command
from .shipyard_search_file_util import search_files_via_shell
try:
from shipyard_neo import BayClient # noqa: F401
from shipyard_neo import BayClient
from shipyard_neo.sandbox import Sandbox
except ImportError:
logger.warning(
"shipyard_neo_sdk is not installed. ShipyardNeoBooter will not work without it.",
"shipyard_neo_sdk is not installed. ShipyardNeoBooter will not work without it."
)
def _maybe_model_dump(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
out: dict[str, Any] = {}
out.update(value)
return out
return value
if hasattr(value, "model_dump"):
dumped = value.model_dump()
if isinstance(dumped, dict):
@@ -44,34 +36,51 @@ def _maybe_model_dump(value: Any) -> dict[str, Any]:
return {}
def _slice_content_by_lines(
content: str,
*,
offset: int | None = None,
limit: int | None = None,
) -> str:
lines = content.splitlines(keepends=True)
start = 0 if offset is None else offset
selected = lines[start:] if limit is None else lines[start : start + limit]
return "".join(selected)
class NeoPythonComponent(PythonComponent):
def __init__(self, sandbox: Any) -> None:
def __init__(self, sandbox: Sandbox) -> None:
self._sandbox = sandbox
async def exec(
self,
code: str,
kernel_id: str | None = None,
timeout: int = 30, # noqa: ASYNC109
timeout: int = 30,
silent: bool = False,
) -> dict[str, Any]:
_ = kernel_id
with anyio.fail_after(timeout):
result = await self._sandbox.python.exec(code)
_ = kernel_id # Bay runtime does not expose kernel_id in current SDK.
result = await self._sandbox.python.exec(code, timeout=timeout)
payload = _maybe_model_dump(result)
output_text = payload.get("output", "") or ""
error_text = payload.get("error", "") or ""
data = payload.get("data") if isinstance(payload.get("data"), dict) else {}
rich_output = data.get("output") or {} if isinstance(data, dict) else {}
rich_output = data.get("output") if isinstance(data.get("output"), dict) else {}
if not isinstance(rich_output.get("images"), list):
rich_output["images"] = []
if "text" not in rich_output:
rich_output["text"] = output_text
if silent:
rich_output["text"] = ""
return {
"success": bool(payload.get("success", error_text == "")),
"data": {"output": rich_output, "error": error_text},
"data": {
"output": rich_output,
"error": error_text,
},
"execution_id": payload.get("execution_id"),
"execution_time_ms": payload.get("execution_time_ms"),
"code": payload.get("code"),
@@ -81,7 +90,7 @@ class NeoPythonComponent(PythonComponent):
class NeoShellComponent(ShellComponent):
def __init__(self, sandbox: Any) -> None:
def __init__(self, sandbox: Sandbox) -> None:
self._sandbox = sandbox
async def exec(
@@ -89,12 +98,10 @@ class NeoShellComponent(ShellComponent):
command: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
timeout: int | None = 300, # noqa: ASYNC109
timeout: int | None = 300,
shell: bool = True,
background: bool = False,
session_id: str | None = None,
) -> dict[str, Any]:
_ = session_id
if not shell:
return {
"stdout": "",
@@ -102,12 +109,14 @@ class NeoShellComponent(ShellComponent):
"exit_code": 2,
"success": False,
}
run_command = command
if env:
env_prefix = " ".join(
(f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items())),
f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items())
)
run_command = f"{env_prefix} {run_command}"
if background:
run_command = build_detached_shell_command(run_command)
@@ -117,6 +126,7 @@ class NeoShellComponent(ShellComponent):
cwd=cwd,
)
payload = _maybe_model_dump(result)
stdout = payload.get("output", "") or ""
stderr = payload.get("error", "") or ""
exit_code = payload.get("exit_code")
@@ -140,6 +150,7 @@ class NeoShellComponent(ShellComponent):
"execution_time_ms": payload.get("execution_time_ms"),
"command": payload.get("command"),
}
return {
"stdout": stdout,
"stderr": stderr,
@@ -152,7 +163,7 @@ class NeoShellComponent(ShellComponent):
class NeoFileSystemComponent(FileSystemComponent):
def __init__(self, sandbox: Any, shell: Any | None = None) -> None:
def __init__(self, sandbox: Sandbox, shell: ShellComponent) -> None:
self._sandbox = sandbox
self._shell = shell
@@ -160,7 +171,7 @@ class NeoFileSystemComponent(FileSystemComponent):
self,
path: str,
content: str = "",
mode: int = 420,
mode: int = 0o644,
) -> dict[str, Any]:
_ = mode
await self._sandbox.filesystem.write_file(path, content)
@@ -175,13 +186,62 @@ class NeoFileSystemComponent(FileSystemComponent):
) -> dict[str, Any]:
_ = encoding
content = await self._sandbox.filesystem.read_file(path)
text = str(content)
if offset is not None or limit is not None:
lines = text.splitlines(keepends=True)
start = 0 if offset is None else offset
selected = lines[start:] if limit is None else lines[start : start + limit]
text = "".join(selected)
return {"success": True, "path": path, "content": text}
return {
"success": True,
"path": path,
"content": _slice_content_by_lines(
content,
offset=offset,
limit=limit,
),
}
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
return await search_files_via_shell(
self._shell,
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
_ = encoding
content = await self._sandbox.filesystem.read_file(path)
occurrences = content.count(old_string)
if occurrences == 0:
return {
"success": False,
"error": "old string not found in file",
"replacements": 0,
}
if replace_all:
updated = content.replace(old_string, new_string)
replacements = occurrences
else:
updated = content.replace(old_string, new_string, 1)
replacements = 1
await self._sandbox.filesystem.write_file(path, updated)
return {
"success": True,
"path": path,
"replacements": replacements,
}
async def write_file(
self,
@@ -213,66 +273,15 @@ class NeoFileSystemComponent(FileSystemComponent):
data.append(item)
return {"success": True, "path": path, "entries": data}
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
if self._shell is None:
raise RuntimeError(
"NeoFileSystemComponent requires a shell for search_files.",
)
return await search_files_via_shell(
self._shell,
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
read_result = await self.read_file(path, encoding=encoding)
if not read_result.get("success"):
return read_result
content = str(read_result.get("content", ""))
occurrences = content.count(old_string)
if occurrences == 0:
return {
"success": False,
"path": path,
"error": "old string not found in file",
"replacements": 0,
}
updated = content.replace(old_string, new_string, -1 if replace_all else 1)
write_result = await self.write_file(path, updated, encoding=encoding)
if not write_result.get("success"):
return write_result
return {
"success": True,
"path": path,
"replacements": occurrences if replace_all else 1,
}
class NeoBrowserComponent(BrowserComponent):
def __init__(self, sandbox: Any) -> None:
def __init__(self, sandbox: Sandbox) -> None:
self._sandbox = sandbox
async def exec(
self,
cmd: str,
timeout_seconds: int = 30,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
@@ -280,7 +289,7 @@ class NeoBrowserComponent(BrowserComponent):
) -> dict[str, Any]:
result = await self._sandbox.browser.exec(
cmd,
timeout_seconds=timeout_seconds,
timeout=timeout,
description=description,
tags=tags,
learn=learn,
@@ -291,7 +300,7 @@ class NeoBrowserComponent(BrowserComponent):
async def exec_batch(
self,
commands: list[str],
timeout_seconds: int = 60,
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
@@ -300,7 +309,7 @@ class NeoBrowserComponent(BrowserComponent):
) -> dict[str, Any]:
result = await self._sandbox.browser.exec_batch(
commands,
timeout_seconds=timeout_seconds,
timeout=timeout,
stop_on_error=stop_on_error,
description=description,
tags=tags,
@@ -312,7 +321,7 @@ class NeoBrowserComponent(BrowserComponent):
async def run_skill(
self,
skill_key: str,
timeout_seconds: int = 60,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
@@ -320,7 +329,7 @@ class NeoBrowserComponent(BrowserComponent):
) -> dict[str, Any]:
result = await self._sandbox.browser.run_skill(
skill_key=skill_key,
timeout_seconds=timeout_seconds,
timeout=timeout,
stop_on_error=stop_on_error,
include_trace=include_trace,
description=description,
@@ -351,9 +360,9 @@ class ShipyardNeoBooter(ComputerBooter):
self._access_token = access_token
self._profile = profile.strip() if profile else ""
self._ttl = ttl
self._client: Any = None
self._sandbox: Any = None
self._bay_manager: Any = None
self._client: BayClient | None = None
self._sandbox: Sandbox | None = None
self._bay_manager: Any = None # BayContainerManager when auto-started
self._fs: FileSystemComponent | None = None
self._python: PythonComponent | None = None
self._shell: ShellComponent | None = None
@@ -386,30 +395,35 @@ class ShipyardNeoBooter(ComputerBooter):
async def boot(self, session_id: str) -> None:
_ = session_id
# --- Auto-start Bay if needed ---
if self.is_auto_mode:
from .bay_manager import BayContainerManager
# Clean up previous manager if re-booting
if self._bay_manager is not None:
await self._bay_manager.close_client()
logger.info("[Computer] bay_autostart status=starting")
logger.info("[Computer] Neo auto-start mode: launching Bay container")
self._bay_manager = BayContainerManager()
self._endpoint_url = await self._bay_manager.ensure_running()
await self._bay_manager.wait_healthy()
# Read auto-provisioned credentials
if not self._access_token:
self._access_token = await self._bay_manager.read_credentials()
logger.info(
"[Computer] bay_autostart status=ready endpoint=%s",
self._endpoint_url,
)
logger.info("[Computer] Bay auto-started at %s", self._endpoint_url)
if not self._endpoint_url or not self._access_token:
if self._bay_manager is not None:
raise ValueError(
"Bay container started but credentials could not be read. Ensure Bay generated credentials.json, or set access_token manually.",
"Bay container started but credentials could not be read. "
"Ensure Bay generated credentials.json, or set access_token manually."
)
raise ValueError(
"Shipyard Neo sandbox configuration is incomplete. Set endpoint (default http://127.0.0.1:8114) and access token, or ensure Bay's credentials.json is accessible for auto-discovery.",
"Shipyard Neo sandbox configuration is incomplete. "
"Set endpoint (default http://127.0.0.1:8114) and access token, "
"or ensure Bay's credentials.json is accessible for auto-discovery."
)
from shipyard_neo import BayClient
self._client = BayClient(
endpoint_url=self._endpoint_url,
@@ -421,23 +435,26 @@ class ShipyardNeoBooter(ComputerBooter):
# An empty profile means auto-select; any non-empty profile must be
# honoured as an explicit choice, including "python-default".
resolved_profile = await self._resolve_profile(self._client)
self._sandbox = await self._client.create_sandbox(
profile=resolved_profile,
ttl=self._ttl,
)
# --- Readiness gate: wait until sandbox session is READY ---
await self._wait_until_ready(self._sandbox)
self._shell = NeoShellComponent(self._sandbox)
self._fs = NeoFileSystemComponent(self._sandbox, self._shell)
self._python = NeoPythonComponent(self._sandbox)
self._shell = NeoShellComponent(self._sandbox)
caps = self.capabilities or ()
self._browser = (
NeoBrowserComponent(self._sandbox) if "browser" in caps else None
)
logger.info(
"[Computer] sandbox_created booter=shipyard_neo sandbox_id=%s profile=%s capabilities=%s auto=%s",
"Got Shipyard Neo sandbox: %s (profile=%s, capabilities=%s, auto=%s)",
self._sandbox.id,
resolved_profile,
list(caps),
@@ -484,7 +501,7 @@ class ShipyardNeoBooter(ComputerBooter):
del_err,
)
raise RuntimeError(
f"Sandbox {sandbox_id} is in terminal state: {status}",
f"Sandbox {sandbox_id} is in terminal state: {status}"
)
remaining = deadline - asyncio.get_running_loop().time()
@@ -506,7 +523,7 @@ class ShipyardNeoBooter(ComputerBooter):
)
raise TimeoutError(
f"Sandbox {sandbox_id} did not become ready within "
f"{READINESS_TIMEOUT}s (last status: {status})",
f"{READINESS_TIMEOUT}s (last status: {status})"
)
logger.debug(
@@ -533,20 +550,23 @@ class ShipyardNeoBooter(ComputerBooter):
if self._profile:
logger.info("[Computer] Using user-specified profile: %s", self._profile)
return self._profile
# Query Bay for available profiles
from shipyard_neo.errors import ForbiddenError, UnauthorizedError
try:
profile_list = await client.list_profiles()
profiles = profile_list.items
except (UnauthorizedError, ForbiddenError):
raise
raise # auth errors must not be silenced
except Exception as exc:
logger.warning(
"[Computer] profile_selection_fallback reason=query_failed fallback=%s error=%s",
"[Computer] Failed to query Bay profiles, falling back to %s: %s",
self.DEFAULT_PROFILE,
exc,
)
return self.DEFAULT_PROFILE
if not profiles:
return self.DEFAULT_PROFILE
@@ -557,17 +577,18 @@ class ShipyardNeoBooter(ComputerBooter):
best = max(profiles, key=_score)
chosen = getattr(best, "id", self.DEFAULT_PROFILE)
if chosen != self.DEFAULT_PROFILE:
caps = getattr(best, "capabilities", [])
logger.info(
"[Computer] profile_selected mode=auto profile=%s capabilities=%s",
"[Computer] Auto-selected profile %s (capabilities=%s)",
chosen,
caps,
)
return chosen
async def shutdown(self, **kwargs) -> None:
delete_sandbox = bool(kwargs.get("delete_sandbox", False))
async def shutdown(self, *, delete_sandbox: bool = False) -> None:
if self._client is not None:
sandbox_id = getattr(self._sandbox, "id", "unknown")
@@ -578,13 +599,11 @@ class ShipyardNeoBooter(ComputerBooter):
if delete_sandbox and self._sandbox is not None:
try:
logger.info(
"[Computer] Deleting Shipyard Neo sandbox: id=%s",
sandbox_id,
"[Computer] Deleting Shipyard Neo sandbox: id=%s", sandbox_id
)
await self._sandbox.delete()
logger.info(
"[Computer] Shipyard Neo sandbox deleted: id=%s",
sandbox_id,
"[Computer] Shipyard Neo sandbox deleted: id=%s", sandbox_id
)
except Exception as e:
logger.warning(
@@ -595,15 +614,14 @@ class ShipyardNeoBooter(ComputerBooter):
)
logger.info(
"[Computer] booter_shutdown booter=shipyard_neo sandbox_id=%s status=starting",
"[Computer] Shutting down Shipyard Neo sandbox client: id=%s",
sandbox_id,
)
await self._client.__aexit__(None, None, None)
self._client = None
self._sandbox = None
logger.info(
"[Computer] booter_shutdown booter=shipyard_neo sandbox_id=%s status=done",
sandbox_id,
"[Computer] Shipyard Neo sandbox client shut down: id=%s", sandbox_id
)
# NOTE: We intentionally do NOT stop the Bay container here.
@@ -631,20 +649,19 @@ class ShipyardNeoBooter(ComputerBooter):
return self._shell
@property
def browser(self) -> BrowserComponent | None:
def browser(self) -> BrowserComponent:
if self._browser is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._browser
async def upload_file(self, path: str, file_name: str) -> dict:
if self._sandbox is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
async with await anyio.open_file(path, "rb") as f:
content = await f.read()
with open(path, "rb") as f:
content = f.read()
remote_path = file_name.lstrip("/")
await self._sandbox.filesystem.upload(remote_path, content)
logger.info(
"[Computer] file_upload booter=shipyard_neo remote_path=%s",
remote_path,
)
logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path)
return {
"success": True,
"message": "File uploaded successfully",
@@ -657,11 +674,11 @@ class ShipyardNeoBooter(ComputerBooter):
content = await self._sandbox.filesystem.download(remote_path.lstrip("/"))
local_dir = os.path.dirname(local_path)
if local_dir:
await anyio.Path(local_dir).mkdir(parents=True, exist_ok=True)
async with await anyio.open_file(local_path, "wb") as f:
await f.write(content)
os.makedirs(local_dir, exist_ok=True)
with open(local_path, "wb") as f:
f.write(cast(bytes, content))
logger.info(
"[Computer] file_download booter=shipyard_neo remote_path=%s local_path=%s",
"[Computer] File downloaded from Neo sandbox: %s -> %s",
remote_path,
local_path,
)
@@ -673,91 +690,13 @@ class ShipyardNeoBooter(ComputerBooter):
await self._sandbox.refresh()
status = getattr(self._sandbox.status, "value", str(self._sandbox.status))
healthy = status not in {"failed", "expired"}
logger.debug(
"[Computer] health_check booter=shipyard_neo sandbox_id=%s status=%s healthy=%s",
logger.info(
"[Computer] Neo sandbox health check: id=%s, status=%s, healthy=%s",
getattr(self._sandbox, "id", "unknown"),
status,
healthy,
)
return healthy
except Exception:
logger.exception(
"[Computer] health_check_failed booter=shipyard_neo sandbox_id=%s",
getattr(self._sandbox, "id", "unknown"),
)
except Exception as e:
logger.error(f"Error checking Shipyard Neo sandbox availability: {e}")
return False
@classmethod
@functools.cache
def _base_tools(cls) -> tuple[ToolSchema, ...]:
"""4 base + 11 Neo lifecycle = 15 tools (all Neo profiles)."""
from astrbot.core.computer.tools import (
AnnotateExecutionTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
ExecuteShellTool,
FileDownloadTool,
FileUploadTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
PromoteSkillCandidateTool,
PythonTool,
RollbackSkillReleaseTool,
SyncSkillReleaseTool,
)
return (
ExecuteShellTool(),
PythonTool(),
FileUploadTool(),
FileDownloadTool(),
GetExecutionHistoryTool(),
AnnotateExecutionTool(),
CreateSkillPayloadTool(),
GetSkillPayloadTool(),
CreateSkillCandidateTool(),
ListSkillCandidatesTool(),
EvaluateSkillCandidateTool(),
PromoteSkillCandidateTool(),
ListSkillReleasesTool(),
RollbackSkillReleaseTool(),
SyncSkillReleaseTool(),
)
@classmethod
@functools.cache
def _browser_tools(cls) -> tuple[ToolSchema, ...]:
from astrbot.core.computer.tools import (
BrowserBatchExecTool,
BrowserExecTool,
RunBrowserSkillTool,
)
return (BrowserExecTool(), BrowserBatchExecTool(), RunBrowserSkillTool())
@classmethod
def get_default_tools(cls) -> list[ToolSchema]:
"""Pre-boot: conservative full list (including browser)."""
return list(cls._base_tools()) + list(cls._browser_tools())
def get_tools(self) -> list[ToolSchema]:
"""Post-boot: capability-filtered list."""
caps = self.capabilities
if caps is None:
return self.__class__.get_default_tools()
tools: list[ToolSchema] = list(self._base_tools())
if "browser" in caps:
tools.extend(self._browser_tools())
return tools
@classmethod
def get_system_prompt_parts(cls) -> list[str]:
from astrbot.core.computer.prompts import (
NEO_FILE_PATH_PROMPT,
NEO_SKILL_LIFECYCLE_PROMPT,
)
return [NEO_FILE_PATH_PROMPT, NEO_SKILL_LIFECYCLE_PROMPT]

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import shlex
from typing import Any
from astrbot.core.computer.olayer import ShellComponent
from ..olayer import ShellComponent
_MAX_SEARCH_LINE_COLUMNS = 1000
@@ -92,7 +92,7 @@ def build_search_command(
glob=glob,
after_context=after_context,
before_context=before_context,
),
)
)
grep_command = _quote_command(
_build_grep_command(
@@ -101,7 +101,7 @@ def build_search_command(
glob=glob,
after_context=after_context,
before_context=before_context,
),
)
)
return (
"if command -v rg >/dev/null 2>&1; then "

Some files were not shown because too many files have changed in this diff Show More