Compare commits

..

2 Commits

596 changed files with 18368 additions and 62712 deletions

View File

@@ -3,8 +3,8 @@
### Modifications / 改动点
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
- [x] This is NOT a breaking change. / 这不是一个破坏性变更。
<!-- If your changes is a breaking change, please uncheck the checkbox above -->
@@ -21,14 +21,7 @@
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
- [ ] 😊 If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
/ 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。
- [ ] 👀 My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
/ 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。
- [ ] 🤓 I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
/ 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到 `requirements.txt``pyproject.toml` 文件相应位置。
- [ ] 😮 My changes do not introduce malicious code.
/ 我的更改没有引入恶意代码。
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.

View File

@@ -11,25 +11,19 @@ jobs:
runs-on: ubuntu-latest # 运行环境
steps:
- name: checkout
uses: actions/checkout@v6
- name: Setup pnpm
uses: pnpm/action-setup@v5.0.0
with:
version: 10.28.2
- name: Setup Node.js
uses: actions/checkout@master
- name: nodejs installation
uses: actions/setup-node@v6
with:
node-version: "24.13.0"
cache: "pnpm"
cache-dependency-path: docs/pnpm-lock.yaml
- name: Install dependencies
run: pnpm install --frozen-lockfile
working-directory: './docs'
- name: Build docs
run: pnpm run docs:build
node-version: "18"
- name: npm install
run: npm add -D vitepress
working-directory: './docs' # working-directory 指定 shell 命令运行目录
- name: npm run build
run: npm run docs:build
working-directory: './docs'
- name: scp
uses: appleboy/scp-action@v1.0.0
uses: appleboy/scp-action@master
with:
host: ${{ secrets.HOST_NEKO }}
username: ${{ secrets.USERNAME }}
@@ -37,7 +31,7 @@ jobs:
source: 'docs/.vitepress/dist/*'
target: '/tmp/'
- name: script
uses: appleboy/ssh-action@v1.2.5
uses: appleboy/ssh-action@master
with:
host: ${{ secrets.HOST_NEKO }}
username: ${{ secrets.USERNAME }}

View File

@@ -40,7 +40,6 @@ jobs:
pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
if: github.repository == 'AstrBotDevs/AstrBot'
uses: codecov/codecov-action@v6
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -8,28 +8,22 @@ on:
jobs:
build:
if: github.repository == 'AstrBotDevs/AstrBot'
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Setup pnpm
uses: pnpm/action-setup@v5.0.0
with:
version: 10.28.2
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '24.13.0'
cache: "pnpm"
cache-dependency-path: dashboard/pnpm-lock.yaml
- name: Install and Build
working-directory: dashboard
- name: npm install, build
run: |
pnpm install --frozen-lockfile
cd dashboard
npm install pnpm -g
pnpm install
pnpm i --save-dev @types/markdown-it
pnpm run build
- name: Inject Commit SHA
@@ -51,7 +45,7 @@ jobs:
- name: Create GitHub Release
if: github.event_name == 'push'
uses: ncipollo/release-action@v1.21.0
uses: ncipollo/release-action@v1
with:
tag: release-${{ github.sha }}
owner: AstrBotDevs

View File

@@ -11,7 +11,7 @@ on:
jobs:
build-nightly-image:
if: github.repository == 'AstrBotDevs/AstrBot' && github.event_name == 'schedule'
if: github.event_name == 'schedule'
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
@@ -64,20 +64,20 @@ jobs:
echo "build_date=$build_date" >> $GITHUB_OUTPUT
- name: Set QEMU
uses: docker/setup-qemu-action@v4.0.0
uses: docker/setup-qemu-action@v4
- name: Set Docker Buildx
uses: docker/setup-buildx-action@v4.0.0
uses: docker/setup-buildx-action@v4
- name: Log in to DockerHub
uses: docker/login-action@v4.1.0
uses: docker/login-action@v4
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v4.1.0
uses: docker/login-action@v4
with:
registry: ghcr.io
username: ${{ env.GHCR_OWNER }}
@@ -98,7 +98,7 @@ jobs:
echo "EOF" >> $GITHUB_OUTPUT
- name: Build and Push Nightly Image
uses: docker/build-push-action@v7.1.0
uses: docker/build-push-action@v7
with:
context: .
platforms: linux/amd64,linux/arm64
@@ -109,7 +109,7 @@ jobs:
run: echo "Test Docker image has been built and pushed successfully"
build-release-image:
if: github.repository == 'AstrBotDevs/AstrBot' && (github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')))
if: github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v'))
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
@@ -163,27 +163,27 @@ jobs:
cp -r dashboard/dist data/
- name: Set QEMU
uses: docker/setup-qemu-action@v4.0.0
uses: docker/setup-qemu-action@v4
- name: Set Docker Buildx
uses: docker/setup-buildx-action@v4.0.0
uses: docker/setup-buildx-action@v4
- name: Log in to DockerHub
uses: docker/login-action@v4.1.0
uses: docker/login-action@v4
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v4.1.0
uses: docker/login-action@v4
with:
registry: ghcr.io
username: ${{ env.GHCR_OWNER }}
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build and Push Release Image
uses: docker/build-push-action@v7.1.0
uses: docker/build-push-action@v7
with:
context: .
platforms: linux/amd64,linux/arm64

View File

@@ -1,54 +0,0 @@
name: PR Title Check
on:
pull_request_target:
types: [opened, edited, reopened, synchronize]
jobs:
title-format:
if: github.repository == 'AstrBotDevs/AstrBot'
runs-on: ubuntu-latest
permissions:
pull-requests: write
issues: write
steps:
- name: Validate PR title
uses: actions/github-script@v9
with:
script: |
const title = (context.payload.pull_request.title || "").trim();
// allow only:
// feat: xxx
// feat(scope): xxx
const pattern = /^(feat)(\([a-z0-9-]+\))?:\s.+$/i;
const isValid = pattern.test(title);
const isSameRepo =
context.payload.pull_request.head.repo.full_name === context.payload.repository.full_name;
if (!isValid) {
if (isSameRepo) {
try {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.payload.pull_request.number,
body: [
"⚠️ PR title format check failed.",
"Required formats:",
"- `feat: xxx`",
"- `feat(scope): xxx`",
"Please update your PR title and push again."
].join("\n")
});
} catch (e) {
core.warning(`Failed to post PR title comment: ${e.message}`);
}
} else {
core.warning("Fork PR: comment permission is restricted; skip posting review comment.");
}
}
if (!isValid) {
core.setFailed("Invalid PR title. Expected format: feat: xxx or feat(scope): xxx.");
}

View File

@@ -20,7 +20,6 @@ permissions:
jobs:
build-dashboard:
name: Build Dashboard
if: github.repository == 'AstrBotDevs/AstrBot'
runs-on: ubuntu-24.04
env:
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
@@ -51,7 +50,7 @@ jobs:
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Setup pnpm
uses: pnpm/action-setup@v5.0.0
uses: pnpm/action-setup@v4
with:
version: 10.28.2
@@ -64,11 +63,11 @@ jobs:
- name: Build dashboard dist
shell: bash
working-directory: dashboard
run: |
pnpm install --frozen-lockfile
pnpm run build
echo "${{ steps.tag.outputs.tag }}" > dist/assets/version
pnpm --dir dashboard install --frozen-lockfile
pnpm --dir dashboard run build
echo "${{ steps.tag.outputs.tag }}" > dashboard/dist/assets/version
cd dashboard
zip -r "AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip" dist
- name: Upload dashboard artifact
@@ -105,7 +104,6 @@ jobs:
publish-release:
name: Publish GitHub Release
if: github.repository == 'AstrBotDevs/AstrBot'
runs-on: ubuntu-24.04
needs:
- build-dashboard
@@ -185,7 +183,6 @@ jobs:
publish-pypi:
name: Publish PyPI
if: github.repository == 'AstrBotDevs/AstrBot'
runs-on: ubuntu-24.04
needs:
- publish-release

View File

@@ -13,23 +13,10 @@ on:
jobs:
smoke-test:
name: Smoke test (${{ matrix.os }}, Python ${{ matrix.python-version }})
runs-on: ${{ matrix.os }}
name: Run smoke tests
runs-on: ubuntu-latest
timeout-minutes: 10
strategy:
fail-fast: false
matrix:
os:
- ubuntu-latest
- macos-latest
- windows-latest
python-version:
- '3.10'
- '3.11'
- '3.12'
- '3.13'
- '3.14'
steps:
- name: Checkout
uses: actions/checkout@v6
@@ -39,21 +26,33 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: requirements.txt
- name: Install uv
python-version: '3.12'
- name: Install UV package manager
run: |
python -m pip install --upgrade pip
python -m pip install uv
pip install uv
- name: Install dependencies
run: |
uv pip install --system -r requirements.txt
uv sync
timeout-minutes: 15
- name: Run smoke tests
run: |
python scripts/smoke_startup_check.py
uv run main.py &
APP_PID=$!
echo "Waiting for application to start..."
for i in {1..60}; do
if curl -f http://localhost:6185 > /dev/null 2>&1; then
echo "Application started successfully!"
kill $APP_PID
exit 0
fi
sleep 1
done
echo "Application failed to start within 30 seconds"
kill $APP_PID 2>/dev/null || true
exit 1
timeout-minutes: 2

View File

@@ -18,7 +18,6 @@ on:
jobs:
stale:
if: github.repository == 'AstrBotDevs/AstrBot'
runs-on: ubuntu-latest
permissions:
issues: write

View File

@@ -18,7 +18,6 @@ concurrency:
jobs:
sync:
if: github.repository == 'AstrBotDevs/AstrBot'
runs-on: ubuntu-latest
permissions:
contents: read

View File

@@ -1,37 +0,0 @@
name: Unit Tests
on:
push:
branches:
- master
paths-ignore:
- 'README*.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
unit-tests:
name: Run pytest suite
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install uv
run: |
python -m pip install --upgrade pip
python -m pip install uv
- name: Run tests
run: |
chmod +x scripts/run_pytests_ci.sh
bash ./scripts/run_pytests_ci.sh ./tests

3
.gitignore vendored
View File

@@ -62,5 +62,4 @@ GenieData/
.opencode/
.kilocode/
.worktrees/
dashboard/bun.lock
docs/plans/

View File

@@ -12,11 +12,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
bash \
ffmpeg \
libavcodec-extra \
curl \
gnupg \
git \
ripgrep \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y --no-install-recommends nodejs \
&& apt-get clean \

View File

@@ -1,5 +1,4 @@
![astrbot-github-banner-v2-light-0405_副本](https://github.com/user-attachments/assets/36fb04e4-cc75-4454-bd8b-049d11aa86f9)
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<div align="center">
@@ -33,7 +32,7 @@
<a href="https://astrbot.app/">Documentation</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
<a href="mailto:community@astrbot.app">Email Support</a>
</div>
@@ -77,31 +76,27 @@ AstrBot is an open-source all-in-one Agent chatbot platform that integrates with
For users who want to quickly experience AstrBot, are familiar with command-line usage, and can install a `uv` environment on their own, we recommend the `uv` one-click deployment method ⚡️:
```bash
uv tool install astrbot --python 3.12
uv tool install astrbot
astrbot init # Only execute this command for the first time to initialize the environment
astrbot run
astrbot
```
> Requires [uv](https://docs.astral.sh/uv/) to be installed.
> AstrBot requires Python 3.12 or later. The `--python 3.12` option ensures that `uv` creates the tool environment with Python 3.12.
> [!NOTE]
> For macOS users: due to macOS security checks, the first run of the `astrbot` command may take longer (about 10-20s).
> For macOS user: due to macOS security checks, the first run of the `astrbot` command may take longer (about 10-20s).
Update `astrbot`:
```bash
uv tool upgrade astrbot --python 3.12
uv tool upgrade astrbot
```
> [!WARNING]
> AstrBot deployed via `uv` **does not support upgrading through the WebUI**. To update, please run the command above from the command line.
### Docker Deployment
For users familiar with containers and looking for a more stable, production-ready deployment method, we recommend deploying AstrBot with Docker / Docker Compose.
Please refer to the official documentation: [Deploy AstrBot with Docker](https://docs.astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Deploy on RainYun
@@ -139,7 +134,7 @@ yay -S astrbot-git
**More deployment methods**
If you need panel-based management or deeper customization, see [BT-Panel Deployment](https://docs.astrbot.app/deploy/astrbot/btpanel.html) for BT Panel app-store setup, [1Panel Deployment](https://docs.astrbot.app/deploy/astrbot/1panel.html) for 1Panel app-market deployment, [CasaOS Deployment](https://docs.astrbot.app/deploy/astrbot/casaos.html) for NAS/home-server visual deployment, and [Manual Deployment](https://docs.astrbot.app/deploy/astrbot/cli.html) for fully custom source-based installation with `uv`.
If you need panel-based management or deeper customization, see [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html) for BT Panel app-store setup, [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html) for 1Panel app-market deployment, [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html) for NAS/home-server visual deployment, and [Manual Deployment](https://astrbot.app/deploy/astrbot/cli.html) for fully custom source-based installation with `uv`.
## Supported Messaging Platforms
@@ -158,12 +153,10 @@ Connect AstrBot to your favorite chat platform.
| Discord | Official |
| LINE | Official |
| Satori | Official |
| KOOK | Official |
| Misskey | Official |
| Mattermost | Official |
| WhatsApp (Coming Soon) | Official |
| [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) | Community |
| [Rocket.Chat](https://github.com/NET-Homeless/astrbot_plugin_rocket_chat_adapter) | Community |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | Community |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | Community |
## Supported Model Services
@@ -191,7 +184,6 @@ Connect AstrBot to your favorite chat platform.
| Coze | LLMOps Platforms |
| OpenAI Whisper | Speech-to-Text Services |
| SenseVoice | Speech-to-Text Services |
| Xiaomi MiMo Omni | Speech-to-Text Services |
| OpenAI TTS | Text-to-Speech Services |
| Gemini TTS | Text-to-Speech Services |
| GPT-Sovits-Inference | Text-to-Speech Services |
@@ -201,7 +193,6 @@ Connect AstrBot to your favorite chat platform.
| Alibaba Cloud Bailian TTS | Text-to-Speech Services |
| Azure TTS | Text-to-Speech Services |
| Minimax TTS | Text-to-Speech Services |
| Xiaomi MiMo TTS | Text-to-Speech Services |
| Volcano Engine TTS | Text-to-Speech Services |
## ❤️ Sponsors
@@ -234,17 +225,14 @@ pre-commit install
### QQ Groups
- Group 12: 916228568 (New)
- Group 9: 1076659624 (Full)
- Group 10: 1078079676 (Full)
- Group 11: 704659519 (Full)
- Group 1: 322154837 (Full)
- Group 3: 630166526 (Full)
- Group 4: 1077826412 (Full)
- Group 5: 822130018 (Full)
- Group 6: 753075035 (Full)
- Group 7: 743746109 (Full)
- Group 8: 1030353265 (Full)
- Group 9: 1076659624 (New)
- Group 10: 1078079676 (New)
- Group 1: 322154837
- Group 3: 630166526
- Group 5: 822130018
- Group 6: 753075035
- Group 7: 743746109
- Group 8: 1030353265
- Developer Group(Chit-chat): 975206796
- Developer Group(Formal): 1039761811

View File

@@ -76,13 +76,12 @@ AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègr
Pour les utilisateurs qui veulent découvrir AstrBot rapidement, qui sont familiers avec la ligne de commande et peuvent installer eux-mêmes l'environnement `uv`, nous recommandons la méthode de déploiement en un clic avec `uv` ⚡️ :
```bash
uv tool install astrbot --python 3.12
uv tool install astrbot
astrbot init # Exécutez cette commande uniquement la première fois pour initialiser l'environnement
astrbot run
astrbot
```
> [uv](https://docs.astral.sh/uv/) doit être installé.
> AstrBot nécessite Python 3.12 ou une version plus récente. L'option `--python 3.12` garantit que `uv` crée l'environnement tool avec Python 3.12.
> [!NOTE]
> Pour les utilisateurs macOS : en raison des vérifications de sécurité de macOS, la première exécution de la commande `astrbot` peut prendre plus de temps (environ 10-20s).
@@ -90,17 +89,14 @@ astrbot run
Mettre à jour `astrbot` :
```bash
uv tool upgrade astrbot --python 3.12
uv tool upgrade astrbot
```
> [!WARNING]
> AstrBot déployé via `uv` **ne prend pas en charge la mise à jour via le WebUI**. Pour mettre à jour, exécutez la commande ci-dessus depuis le terminal.
### Déploiement Docker
Pour les utilisateurs familiers avec les conteneurs et qui souhaitent une méthode plus stable et adaptée à la production, nous recommandons de déployer AstrBot avec Docker / Docker Compose.
Veuillez consulter la documentation officielle [Déployer AstrBot avec Docker](https://docs.astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
Veuillez consulter la documentation officielle [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Déployer sur RainYun
@@ -138,7 +134,7 @@ yay -S astrbot-git
**Autres méthodes de déploiement**
Si vous avez besoin d'une gestion par panneau ou d'une personnalisation plus poussée, consultez [Déploiement BT-Panel](https://docs.astrbot.app/deploy/astrbot/btpanel.html) pour une installation via BT Panel, [Déploiement 1Panel](https://docs.astrbot.app/deploy/astrbot/1panel.html) pour le marketplace 1Panel, [Déploiement CasaOS](https://docs.astrbot.app/deploy/astrbot/casaos.html) pour un déploiement visuel sur NAS/serveur domestique, et [Déploiement manuel](https://docs.astrbot.app/deploy/astrbot/cli.html) pour une installation complète depuis les sources avec `uv`.
Si vous avez besoin d'une gestion par panneau ou d'une personnalisation plus poussée, consultez [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) pour une installation via BT Panel, [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) pour le marketplace 1Panel, [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) pour un déploiement visuel sur NAS/serveur domestique, et [Déploiement manuel](https://astrbot.app/deploy/astrbot/cli.html) pour une installation complète depuis les sources avec `uv`.
## Plateformes de messagerie prises en charge
@@ -157,12 +153,10 @@ Connectez AstrBot à vos plateformes de chat préférées.
| Discord | Officielle |
| LINE | Officielle |
| Satori | Officielle |
| KOOK | Officielle |
| Misskey | Officielle |
| Mattermost | Officielle |
| WhatsApp (Bientôt disponible) | Officielle |
| [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) | Communauté |
| [Rocket.Chat](https://github.com/NET-Homeless/astrbot_plugin_rocket_chat_adapter) | Communauté |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | Communauté |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | Communauté |
## Services de modèles pris en charge
@@ -190,7 +184,6 @@ Connectez AstrBot à vos plateformes de chat préférées.
| Coze | Plateformes LLMOps |
| OpenAI Whisper | Services de reconnaissance vocale |
| SenseVoice | Services de reconnaissance vocale |
| Xiaomi MiMo Omni | Services de reconnaissance vocale |
| OpenAI TTS | Services de synthèse vocale |
| Gemini TTS | Services de synthèse vocale |
| GPT-Sovits-Inference | Services de synthèse vocale |
@@ -200,7 +193,6 @@ Connectez AstrBot à vos plateformes de chat préférées.
| Alibaba Cloud Bailian TTS | Services de synthèse vocale |
| Azure TTS | Services de synthèse vocale |
| Minimax TTS | Services de synthèse vocale |
| Xiaomi MiMo TTS | Services de synthèse vocale |
| Volcano Engine TTS | Services de synthèse vocale |
## ❤️ Contribuer
@@ -225,17 +217,10 @@ pre-commit install
### Groupes QQ
- Groupe 12 : 916228568 (nouveau)
- Groupe 9 : 1076659624 (complet)
- Groupe 10 : 1078079676 (complet)
- Groupe 11 : 704659519 (complet)
- Groupe 1 : 322154837 (complet)
- Groupe 3 : 630166526 (complet)
- Groupe 4 : 1077826412 (complet)
- Groupe 5 : 822130018 (complet)
- Groupe 6 : 753075035 (complet)
- Groupe 7 : 743746109 (complet)
- Groupe 8 : 1030353265 (complet)
- Groupe 1 : 322154837
- Groupe 3 : 630166526
- Groupe 5 : 822130018
- Groupe 6 : 753075035
- Groupe développeurs : 975206796
- Groupe développeurs (officiel) : 1039761811

View File

@@ -76,13 +76,12 @@ AstrBot は、主要なインスタントメッセージングアプリと統合
AstrBot を素早く試したいユーザーで、コマンドラインに慣れており `uv` 環境を自分でインストールできる場合は、`uv` のワンクリックデプロイをおすすめします ⚡️:
```bash
uv tool install astrbot --python 3.12
uv tool install astrbot
astrbot init # 初回のみ実行して環境を初期化します
astrbot run
astrbot
```
> [uv](https://docs.astral.sh/uv/) のインストールが必要です。
> AstrBot には Python 3.12 以降が必要です。`--python 3.12` を指定すると、`uv` は Python 3.12 で tool 環境を作成します。
> [!NOTE]
> macOS ユーザーの場合macOS のセキュリティチェックにより、`astrbot` コマンドの初回実行に時間がかかる場合があります(約 10〜20 秒)。
@@ -90,17 +89,14 @@ astrbot run
`astrbot` の更新:
```bash
uv tool upgrade astrbot --python 3.12
uv tool upgrade astrbot
```
> [!WARNING]
> `uv` 経由でデプロイした AstrBot は、**WebUI からのバージョンアップグレードに対応していません**。更新するには、上記のコマンドをコマンドラインで実行してください。
### Docker デプロイ
コンテナ運用に慣れており、より安定した本番向けのデプロイ方法を求めるユーザーには、Docker / Docker Compose での AstrBot デプロイをおすすめします。
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://docs.astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
### 雨云でのデプロイ
@@ -138,7 +134,7 @@ yay -S astrbot-git
**その他のデプロイ方法**
パネル操作での導入やより高度なカスタマイズが必要な場合は、[宝塔パネルデプロイ](https://docs.astrbot.app/deploy/astrbot/btpanel.html)BT Panel 経由の導入)、[1Panel デプロイ](https://docs.astrbot.app/deploy/astrbot/1panel.html)1Panel アプリマーケット経由)、[CasaOS デプロイ](https://docs.astrbot.app/deploy/astrbot/casaos.html)NAS / ホームサーバー向け可視化導入)、[手動デプロイ](https://docs.astrbot.app/deploy/astrbot/cli.html)`uv` とソースベースのフルカスタム導入)を参照してください。
パネル操作での導入やより高度なカスタマイズが必要な場合は、[宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html)BT Panel 経由の導入)、[1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html)1Panel アプリマーケット経由)、[CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html)NAS / ホームサーバー向け可視化導入)、[手動デプロイ](https://astrbot.app/deploy/astrbot/cli.html)`uv` とソースベースのフルカスタム導入)を参照してください。
## サポートされているメッセージプラットフォーム
@@ -157,12 +153,10 @@ AstrBot をよく使うチャットプラットフォームに接続できます
| Discord | 公式 |
| LINE | 公式 |
| Satori | 公式 |
| KOOK | 公式 |
| Misskey | 公式 |
| Mattermost | 公式 |
| WhatsApp (近日対応予定) | 公式 |
| [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) | コミュニティ |
| [Rocket.Chat](https://github.com/NET-Homeless/astrbot_plugin_rocket_chat_adapter) | コミュニティ |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | コミュニティ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | コミュニティ |
@@ -191,7 +185,6 @@ AstrBot をよく使うチャットプラットフォームに接続できます
| Coze | LLMOps プラットフォーム |
| OpenAI Whisper | 音声認識サービス |
| SenseVoice | 音声認識サービス |
| Xiaomi MiMo Omni | 音声認識サービス |
| OpenAI TTS | 音声合成サービス |
| Gemini TTS | 音声合成サービス |
| GPT-Sovits-Inference | 音声合成サービス |
@@ -201,7 +194,6 @@ AstrBot をよく使うチャットプラットフォームに接続できます
| Alibaba Cloud 百炼 TTS | 音声合成サービス |
| Azure TTS | 音声合成サービス |
| Minimax TTS | 音声合成サービス |
| Xiaomi MiMo TTS | 音声合成サービス |
| Volcano Engine TTS | 音声合成サービス |
## ❤️ コントリビューション
@@ -226,17 +218,10 @@ pre-commit install
### QQ グループ
- 12群: 916228568 (新)
- 9群: 1076659624 (満員)
- 10群: 1078079676 (満員)
- 11群: 704659519 (満員)
- 1群: 322154837 (満員)
- 3群: 630166526 (満員)
- 4群: 1077826412 (満員)
- 5群: 822130018 (満員)
- 6群: 753075035 (満員)
- 7群: 743746109 (満員)
- 8群: 1030353265 (満員)
- 1群: 322154837
- 3群: 630166526
- 5群: 822130018
- 6群: 753075035
- 開発者群: 975206796
- 開発者群(正式): 1039761811

View File

@@ -76,13 +76,12 @@ AstrBot — это универсальная платформа Agent-чатб
Для пользователей, которые хотят быстро попробовать AstrBot, знакомы с командной строкой и могут самостоятельно установить окружение `uv`, мы рекомендуем использовать развёртывание в один клик через `uv` ⚡️:
```bash
uv tool install astrbot --python 3.12
uv tool install astrbot
astrbot init # Выполните эту команду только при первом запуске для инициализации окружения
astrbot run
astrbot
```
> Требуется установленный [uv](https://docs.astral.sh/uv/).
> Для AstrBot требуется Python 3.12 или новее. Параметр `--python 3.12` гарантирует, что `uv` создаст tool-окружение с Python 3.12.
> [!NOTE]
> Для пользователей macOS: из-за проверок безопасности macOS первый запуск команды `astrbot` может занять больше времени (около 10-20 секунд).
@@ -90,17 +89,14 @@ astrbot run
Обновить `astrbot`:
```bash
uv tool upgrade astrbot --python 3.12
uv tool upgrade astrbot
```
> [!WARNING]
> AstrBot, развёрнутый через `uv`, **не поддерживает обновление через WebUI**. Для обновления выполните указанную выше команду из командной строки.
### Развёртывание Docker
Для пользователей, знакомых с контейнерами и которым нужен более стабильный и подходящий для production способ, мы рекомендуем разворачивать AstrBot через Docker / Docker Compose.
См. официальную документацию [Развёртывание AstrBot с Docker](https://docs.astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
См. официальную документацию [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Развёртывание на RainYun
@@ -138,7 +134,7 @@ yay -S astrbot-git
**Другие способы развёртывания**
Если вам нужна панельная установка или более глубокая кастомизация, смотрите [Развёртывание BT-Panel](https://docs.astrbot.app/deploy/astrbot/btpanel.html) (установка через BT Panel), [Развёртывание 1Panel](https://docs.astrbot.app/deploy/astrbot/1panel.html) (развёртывание через маркетплейс 1Panel), [Развёртывание CasaOS](https://docs.astrbot.app/deploy/astrbot/casaos.html) (визуальный вариант для NAS и домашних серверов) и [Ручное развёртывание](https://docs.astrbot.app/deploy/astrbot/cli.html) (полностью настраиваемая установка из исходников через `uv`).
Если вам нужна панельная установка или более глубокая кастомизация, смотрите [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) (установка через BT Panel), [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) (развёртывание через маркетплейс 1Panel), [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) (визуальный вариант для NAS и домашних серверов) и [Ручное развёртывание](https://astrbot.app/deploy/astrbot/cli.html) (полностью настраиваемая установка из исходников через `uv`).
## Поддерживаемые платформы обмена сообщениями
@@ -157,12 +153,10 @@ yay -S astrbot-git
| Discord | Официальная |
| LINE | Официальная |
| Satori | Официальная |
| KOOK | Официальная |
| Misskey | Официальная |
| Mattermost | Официальная |
| WhatsApp (Скоро) | Официальная |
| [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) | Сообщество |
| [Rocket.Chat](https://github.com/NET-Homeless/astrbot_plugin_rocket_chat_adapter) | Сообщество |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | Сообщество |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | Сообщество |
## Поддерживаемые сервисы моделей
@@ -190,7 +184,6 @@ yay -S astrbot-git
| Coze | Платформы LLMOps |
| OpenAI Whisper | Сервисы распознавания речи |
| SenseVoice | Сервисы распознавания речи |
| Xiaomi MiMo Omni | Сервисы распознавания речи |
| OpenAI TTS | Сервисы синтеза речи |
| Gemini TTS | Сервисы синтеза речи |
| GPT-Sovits-Inference | Сервисы синтеза речи |
@@ -200,7 +193,6 @@ yay -S astrbot-git
| Alibaba Cloud Bailian TTS | Сервисы синтеза речи |
| Azure TTS | Сервисы синтеза речи |
| Minimax TTS | Сервисы синтеза речи |
| Xiaomi MiMo TTS | Сервисы синтеза речи |
| Volcano Engine TTS | Сервисы синтеза речи |
## ❤️ Вклад в проект
@@ -225,17 +217,10 @@ pre-commit install
### Группы QQ
- Группа 12: 916228568 (новая)
- Группа 9: 1076659624 (полная)
- Группа 10: 1078079676 (полная)
- Группа 11: 704659519 (полная)
- Группа 1: 322154837 (полная)
- Группа 3: 630166526 (полная)
- Группа 4: 1077826412 (полная)
- Группа 5: 822130018 (полная)
- Группа 6: 753075035 (полная)
- Группа 7: 743746109 (полная)
- Группа 8: 1030353265 (полная)
- Группа 1: 322154837
- Группа 3: 630166526
- Группа 5: 822130018
- Группа 6: 753075035
- Группа разработчиков: 975206796
- Группа разработчиков (официальная): 1039761811

View File

@@ -32,7 +32,7 @@
<a href="https://astrbot.app/">文件</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路線圖</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題回報</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題回報</a>
<a href="mailto:community@astrbot.app">Email</a>
</div>
@@ -76,13 +76,12 @@ AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主
對於想快速體驗 AstrBot、且熟悉命令列並能自行安裝 `uv` 環境的使用者,我們推薦使用 `uv` 一鍵部署方式 ⚡️。
```bash
uv tool install astrbot --python 3.12
uv tool install astrbot
astrbot init # 僅首次執行此命令以初始化環境
astrbot run
astrbot
```
> 需要安裝 [uv](https://docs.astral.sh/uv/)。
> AstrBot 需要 Python 3.12 或更高版本。`--python 3.12` 會確保 `uv` 使用 Python 3.12 建立 tool 環境。
> [!NOTE]
> 對於 macOS 使用者:由於 macOS 安全性檢查,首次執行 `astrbot` 指令可能需要較長時間(約 10-20 秒)。
@@ -90,17 +89,14 @@ astrbot run
更新 `astrbot`
```bash
uv tool upgrade astrbot --python 3.12
uv tool upgrade astrbot
```
> [!WARNING]
> 透過 `uv` 部署的 AstrBot **不支援在 WebUI 中進行版本升級**。如需更新,請透過命令列執行上述命令。
### Docker 部署
對於熟悉容器、希望獲得更穩定且更適合正式環境部署方式的使用者,我們推薦使用 Docker / Docker Compose 部署 AstrBot。
請參考官方文件 [使用 Docker 部署 AstrBot](https://docs.astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
請參考官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
### 在雨雲上部署
@@ -138,7 +134,7 @@ yay -S astrbot-git
**更多部署方式**
若你需要面板化或更高自訂程度的部署,可參考 [寶塔面板](https://docs.astrbot.app/deploy/astrbot/btpanel.html)BT Panel 應用商店安裝)、[1Panel](https://docs.astrbot.app/deploy/astrbot/1panel.html)1Panel 應用商店安裝)、[CasaOS](https://docs.astrbot.app/deploy/astrbot/casaos.html)NAS / 家用伺服器可視化部署)與 [手動部署](https://docs.astrbot.app/deploy/astrbot/cli.html)(基於原始碼與 `uv` 的完整自訂安裝)。
若你需要面板化或更高自訂程度的部署,可參考 [寶塔面板](https://astrbot.app/deploy/astrbot/btpanel.html)BT Panel 應用商店安裝)、[1Panel](https://astrbot.app/deploy/astrbot/1panel.html)1Panel 應用商店安裝)、[CasaOS](https://astrbot.app/deploy/astrbot/casaos.html)NAS / 家用伺服器可視化部署)與 [手動部署](https://astrbot.app/deploy/astrbot/cli.html)(基於原始碼與 `uv` 的完整自訂安裝)。
## 支援的訊息平台
@@ -157,12 +153,10 @@ yay -S astrbot-git
| Discord | 官方維護 |
| LINE | 官方維護 |
| Satori | 官方維護 |
| KOOK | 官方維護 |
| Misskey | 官方維護 |
| Mattermost | 官方維護 |
| WhatsApp即將支援 | 官方維護 |
| Whatsapp即將支援 | 官方維護 |
| [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) | 社群維護 |
| [Rocket.Chat](https://github.com/NET-Homeless/astrbot_plugin_rocket_chat_adapter) | 社群維護 |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | 社群維護 |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | 社群維護 |
## 支援的模型服務
@@ -190,7 +184,6 @@ yay -S astrbot-git
| Coze | LLMOps 平台 |
| OpenAI Whisper | 語音轉文字服務 |
| SenseVoice | 語音轉文字服務 |
| Xiaomi MiMo Omni | 語音轉文字服務 |
| OpenAI TTS | 文字轉語音服務 |
| Gemini TTS | 文字轉語音服務 |
| GPT-Sovits-Inference | 文字轉語音服務 |
@@ -200,7 +193,6 @@ yay -S astrbot-git
| 阿里雲百煉 TTS | 文字轉語音服務 |
| Azure TTS | 文字轉語音服務 |
| Minimax TTS | 文字轉語音服務 |
| Xiaomi MiMo TTS | 文字轉語音服務 |
| 火山引擎 TTS | 文字轉語音服務 |
## ❤️ 貢獻
@@ -225,17 +217,14 @@ pre-commit install
### QQ 群組
- 12 群916228568 (新)
- 9 群1076659624 (人滿)
- 10 群:1078079676 (人滿)
- 11 群:704659519 (人滿)
- 1 群:322154837 (人滿)
- 3 群:630166526 (人滿)
- 4 群:1077826412 (人滿)
- 5 群:822130018 (人滿)
- 6 群753075035 (人滿)
- 7 群743746109 (人滿)
- 8 群1030353265 (人滿)
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 8 群:1030353265
- 開發者群闲聊吹水975206796
- 開發者群正式1039761811

View File

@@ -31,12 +31,12 @@
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">博客</a>
<a href="https://astrbot.featurebase.app/roadmap">路线图</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
<a href="mailto:community@astrbot.app">Email</a>
</div>
AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、Telegram、企业微信、飞书、钉钉、Slack 等数十款主流即时通讯软件上部署,此外还内置类似 OpenWebUI 的轻量化 ChatUI为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手还是企业知识库AstrBot 都能在你的即时通讯软件平台的工作流中快速构建 AI 应用。
AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、Telegram、企业微信、飞书、钉钉、Slack等数十款主流即时通讯软件上部署,此外还内置类似 OpenWebUI 的轻量化 ChatUI为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手还是企业知识库AstrBot 都能在你的即时通讯软件平台的工作流中快速构建 AI 应用。
![landingpage](https://github.com/user-attachments/assets/45fc5699-cddf-4e21-af35-13040706f6c0)
@@ -76,13 +76,12 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、
对于想快速体验 AstrBot、且熟悉命令行并能够自行安装 `uv` 环境的用户,我们推荐使用 `uv` 一键部署方式 ⚡️。
```bash
uv tool install astrbot --python 3.12
uv tool install astrbot
astrbot init # 仅首次执行此命令以初始化环境
astrbot run
astrbot
```
> 需要安装 [uv](https://docs.astral.sh/uv/)。
> AstrBot 需要 Python 3.12 或更高版本。`--python 3.12` 会确保 `uv` 使用 Python 3.12 创建 tool 环境。
> [!NOTE]
> 对于 macOS 用户:由于 macOS 安全检查,首次运行 `astrbot` 命令可能需要较长时间(约 10-20 秒)。
@@ -90,17 +89,14 @@ astrbot run
更新 `astrbot`
```bash
uv tool upgrade astrbot --python 3.12
uv tool upgrade astrbot
```
> [!WARNING]
> 通过 `uv` 部署的 AstrBot **不支持在 WebUI 中进行版本升级**。如需更新,请通过命令行执行上述命令。
### Docker 部署
对于熟悉容器、希望获得更稳定且更适合生产环境部署方式的用户,我们推荐使用 Docker / Docker Compose 部署 AstrBot。
请参考官方文档 [使用 Docker 部署 AstrBot](https://docs.astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
请参考官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
### 在 雨云 上部署
@@ -138,7 +134,7 @@ yay -S astrbot-git
**更多部署方式**
若你需要面板化或更高自定义部署,可参考 [宝塔面板](https://docs.astrbot.app/deploy/astrbot/btpanel.html)BT Panel 应用商店安装)、[1Panel](https://docs.astrbot.app/deploy/astrbot/1panel.html)1Panel 应用商店安装)、[CasaOS](https://docs.astrbot.app/deploy/astrbot/casaos.html)NAS / 家庭服务器可视化部署)和 [手动部署](https://docs.astrbot.app/deploy/astrbot/cli.html)(基于源码与 `uv` 的完整自定义安装)。
若你需要面板化或更高自定义部署,可参考 [宝塔面板](https://astrbot.app/deploy/astrbot/btpanel.html)BT Panel 应用商店安装)、[1Panel](https://astrbot.app/deploy/astrbot/1panel.html)1Panel 应用商店安装)、[CasaOS](https://astrbot.app/deploy/astrbot/casaos.html)NAS / 家庭服务器可视化部署)和 [手动部署](https://astrbot.app/deploy/astrbot/cli.html)(基于源码与 `uv` 的完整自定义安装)。
## 支持的消息平台
@@ -157,12 +153,10 @@ yay -S astrbot-git
| **Discord** | 官方维护 |
| **LINE** | 官方维护 |
| **Satori** | 官方维护 |
| **KOOK** | 官方维护 |
| **Misskey** | 官方维护 |
| **Mattermost** | 官方维护 |
| **WhatsApp将支持** | 官方维护 |
| **Whatsapp (将支持)** | 官方维护 |
| [**Matrix**](https://github.com/stevessr/astrbot_plugin_matrix_adapter) | 社区维护 |
| [**Rocket.Chat**](https://github.com/NET-Homeless/astrbot_plugin_rocket_chat_adapter) | 社区维护 |
| [**KOOK**](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | 社区维护 |
| [**VoceChat**](https://github.com/HikariFroya/astrbot_plugin_vocechat) | 社区维护 |
## 支持的模型提供商
@@ -191,7 +185,6 @@ yay -S astrbot-git
| Coze | LLMOps 平台 |
| OpenAI Whisper | 语音转文本 |
| SenseVoice | 语音转文本 |
| Xiaomi MiMo Omni | 语音转文本 |
| OpenAI TTS | 文本转语音 |
| Gemini TTS | 文本转语音 |
| GPT-Sovits-Inference | 文本转语音 |
@@ -201,7 +194,6 @@ yay -S astrbot-git
| 阿里云百炼 TTS | 文本转语音 |
| Azure TTS | 文本转语音 |
| Minimax TTS | 文本转语音 |
| Xiaomi MiMo TTS | 文本转语音 |
| 火山引擎 TTS | 文本转语音 |
## ❤️ 贡献
@@ -226,17 +218,14 @@ pre-commit install
### QQ 群组
- 12 群916228568 (新)
- 9 群1076659624 (人满)
- 10 群:1078079676 (人满)
- 11 群:704659519 (人满)
- 1 群:322154837 (人满)
- 3 群:630166526 (人满)
- 4 群:1077826412 (人满)
- 5 群:822130018 (人满)
- 6 群753075035 (人满)
- 7 群743746109 (人满)
- 8 群1030353265 (人满)
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 8 群:1030353265
- 开发者群偏闲聊吹水975206796
- 开发者群正式1039761811

View File

@@ -14,8 +14,6 @@ from astrbot.core.star.register import register_command_group as command_group
from astrbot.core.star.register import register_custom_filter as custom_filter
from astrbot.core.star.register import register_event_message_type as event_message_type
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_on_agent_begin as on_agent_begin
from astrbot.core.star.register import register_on_agent_done as on_agent_done
from astrbot.core.star.register import register_on_astrbot_loaded as on_astrbot_loaded
from astrbot.core.star.register import (
register_on_decorating_result as on_decorating_result,
@@ -53,8 +51,6 @@ __all__ = [
"custom_filter",
"event_message_type",
"llm_tool",
"on_agent_begin",
"on_agent_done",
"on_astrbot_loaded",
"on_decorating_result",
"on_llm_request",

View File

@@ -36,9 +36,9 @@ class Main(star.Star):
if self.ltm_enabled(event) and self.ltm and has_image_or_plain:
need_active = await self.ltm.need_active_reply(event)
group_icl_enable = self.context.get_config(umo=event.unified_msg_origin)[
"provider_ltm_settings"
]["group_icl_enable"]
group_icl_enable = self.context.get_config()["provider_ltm_settings"][
"group_icl_enable"
]
if group_icl_enable:
"""记录对话"""
try:

View File

@@ -1,17 +1,29 @@
# Commands module
from .admin import AdminCommands
from .alter_cmd import AlterCmdCommands
from .conversation import ConversationCommands
from .help import HelpCommand
from .llm import LLMCommands
from .persona import PersonaCommands
from .plugin import PluginCommands
from .provider import ProviderCommands
from .setunset import SetUnsetCommands
from .sid import SIDCommand
from .t2i import T2ICommand
from .tts import TTSCommand
__all__ = [
"AdminCommands",
"AlterCmdCommands",
"ConversationCommands",
"HelpCommand",
"LLMCommands",
"PersonaCommands",
"PluginCommands",
"ProviderCommands",
"SetUnsetCommands",
"SIDCommand",
"SetUnsetCommands",
"T2ICommand",
"TTSCommand",
]

View File

@@ -1,5 +1,5 @@
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.event import AstrMessageEvent, MessageChain, MessageEventResult
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard
@@ -8,8 +8,70 @@ class AdminCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员。op <admin_id>"""
if not admin_id:
event.set_result(
MessageEventResult().message(
"使用方法: /op <id> 授权管理员;/deop <id> 取消管理员。可通过 /sid 获取 ID。",
),
)
return
self.context.get_config()["admins_id"].append(str(admin_id))
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("授权成功。"))
async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""取消授权管理员。deop <admin_id>"""
if not admin_id:
event.set_result(
MessageEventResult().message(
"使用方法: /deop <id> 取消管理员。可通过 /sid 获取 ID。",
),
)
return
try:
self.context.get_config()["admins_id"].remove(str(admin_id))
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("取消授权成功。"))
except ValueError:
event.set_result(
MessageEventResult().message("此用户 ID 不在管理员名单内。"),
)
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单。wl <sid>"""
if not sid:
event.set_result(
MessageEventResult().message(
"使用方法: /wl <id> 添加白名单;/dwl <id> 删除白名单。可通过 /sid 获取 ID。",
),
)
return
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg["platform_settings"]["id_whitelist"].append(str(sid))
cfg.save_config()
event.set_result(MessageEventResult().message("添加白名单成功。"))
async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""删除白名单。dwl <sid>"""
if not sid:
event.set_result(
MessageEventResult().message(
"使用方法: /dwl <id> 删除白名单。可通过 /sid 获取 ID。",
),
)
return
try:
cfg = self.context.get_config(umo=event.unified_msg_origin)
cfg["platform_settings"]["id_whitelist"].remove(str(sid))
cfg.save_config()
event.set_result(MessageEventResult().message("删除白名单成功。"))
except ValueError:
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""更新管理面板"""
await event.send(MessageChain().message("⏳ Updating dashboard..."))
await event.send(MessageChain().message("正在尝试更新管理面板..."))
await download_dashboard(version=f"v{VERSION}", latest=False)
await event.send(MessageChain().message("✅ Dashboard updated successfully."))
await event.send(MessageChain().message("管理面板更新完成。"))

View File

@@ -0,0 +1,173 @@
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
from astrbot.core.utils.command_parser import CommandParserMixin
from .utils.rst_scene import RstScene
class AlterCmdCommands(CommandParserMixin):
def __init__(self, context: star.Context) -> None:
self.context = context
async def update_reset_permission(self, scene_key: str, perm_type: str) -> None:
"""更新reset命令在特定场景下的权限设置"""
from astrbot.api import sp
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_cfg = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_cfg.get("reset", {})
reset_cfg[scene_key] = perm_type
plugin_cfg["reset"] = reset_cfg
alter_cmd_cfg["astrbot"] = plugin_cfg
await sp.global_put("alter_cmd", alter_cmd_cfg)
async def alter_cmd(self, event: AstrMessageEvent) -> None:
token = self.parse_commands(event.message_str)
if token.len < 3:
await event.send(
MessageChain().message(
"该指令用于设置指令或指令组的权限。\n"
"格式: /alter_cmd <cmd_name> <admin/member>\n"
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
"/alter_cmd reset config 打开 reset 权限配置",
),
)
return
# 兼容 reset scene 的专门配置
cmd_name = token.get(1)
cmd_type = token.get(2)
if cmd_name == "reset" and cmd_type == "config":
from astrbot.api import sp
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get("astrbot", {})
reset_cfg = plugin_.get("reset", {})
group_unique_on = reset_cfg.get("group_unique_on", "admin")
group_unique_off = reset_cfg.get("group_unique_off", "admin")
private = reset_cfg.get("private", "member")
config_menu = f"""reset命令权限细粒度配置
当前配置:
1. 群聊+会话隔离开: {group_unique_on}
2. 群聊+会话隔离关: {group_unique_off}
3. 私聊: {private}
修改指令格式:
/alter_cmd reset scene <场景编号> <admin/member>
例如: /alter_cmd reset scene 2 member"""
await event.send(MessageChain().message(config_menu))
return
if cmd_name == "reset" and cmd_type == "scene" and token.len >= 4:
scene_num = token.get(3)
perm_type = token.get(4)
if scene_num is None or perm_type is None:
await event.send(MessageChain().message("场景编号和权限类型不能为空"))
return
if not scene_num.isdigit() or int(scene_num) < 1 or int(scene_num) > 3:
await event.send(
MessageChain().message("场景编号必须是 1-3 之间的数字"),
)
return
if perm_type not in ["admin", "member"]:
await event.send(
MessageChain().message("权限类型错误,只能是 admin 或 member"),
)
return
scene_num = int(scene_num)
scene = RstScene.from_index(scene_num)
scene_key = scene.key
await self.update_reset_permission(scene_key, perm_type)
await event.send(
MessageChain().message(
f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}",
),
)
return
if cmd_type not in ["admin", "member"]:
await event.send(
MessageChain().message("指令类型错误,可选类型有 admin, member"),
)
return
# 查找指令
cmd_name = " ".join(token.tokens[1:-1])
cmd_type = token.get(-1)
found_command = None
cmd_group = False
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
if filter_.equals(cmd_name):
found_command = handler
break
elif isinstance(filter_, CommandGroupFilter):
if filter_.equals(cmd_name):
found_command = handler
cmd_group = True
break
if not found_command:
await event.send(MessageChain().message("未找到该指令"))
return
found_plugin = star_map[found_command.handler_module_path]
from astrbot.api import sp
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
cfg = plugin_.get(found_command.handler_name, {})
cfg["permission"] = cmd_type
plugin_[found_command.handler_name] = cfg
alter_cmd_cfg[found_plugin.name] = plugin_
await sp.global_put("alter_cmd", alter_cmd_cfg)
# 注入权限过滤器
found_permission_filter = False
for filter_ in found_command.event_filters:
if isinstance(filter_, PermissionTypeFilter):
if cmd_type == "admin":
from astrbot.api.event import filter
filter_.permission_type = filter.PermissionType.ADMIN
else:
from astrbot.api.event import filter
filter_.permission_type = filter.PermissionType.MEMBER
found_permission_filter = True
break
if not found_permission_filter:
from astrbot.api.event import filter
found_command.event_filters.insert(
0,
PermissionTypeFilter(
filter.PermissionType.ADMIN
if cmd_type == "admin"
else filter.PermissionType.MEMBER,
),
)
cmd_group_str = "指令组" if cmd_group else "指令"
await event.send(
MessageChain().message(
f"已将「{cmd_name}{cmd_group_str} 的权限级别调整为 {cmd_type}",
),
)

View File

@@ -1,16 +1,13 @@
from sqlalchemy import case, func, select
from sqlmodel import col
import datetime
from astrbot.api import sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core import logger
from astrbot.core.agent.runners.deerflow.constants import (
DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY,
DEERFLOW_PROVIDER_TYPE,
DEERFLOW_THREAD_ID_KEY,
)
from astrbot.core.agent.runners.deerflow.deerflow_api_client import DeerFlowAPIClient
from astrbot.core.db.po import ProviderStat
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.core.platform.message_type import MessageType
from astrbot.core.utils.active_event_registry import active_event_registry
from .utils.rst_scene import RstScene
@@ -24,85 +21,6 @@ THIRD_PARTY_AGENT_RUNNER_KEY = {
THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys())
async def _cleanup_deerflow_thread_if_present(
context: star.Context,
umo: str,
) -> None:
try:
thread_id = await sp.get_async(
scope="umo",
scope_id=umo,
key=DEERFLOW_THREAD_ID_KEY,
default="",
)
if not thread_id:
return
cfg = context.get_config(umo=umo)
provider_id = cfg["provider_settings"].get(
DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY,
"",
)
if not provider_id:
return
merged_provider_config = context.provider_manager.get_provider_config_by_id(
provider_id,
merged=True,
)
if not merged_provider_config:
logger.warning(
"Failed to resolve DeerFlow provider config for remote thread cleanup: provider_id=%s",
provider_id,
)
return
client = DeerFlowAPIClient(
api_base=merged_provider_config.get(
"deerflow_api_base",
"http://127.0.0.1:2026",
),
api_key=merged_provider_config.get("deerflow_api_key", ""),
auth_header=merged_provider_config.get("deerflow_auth_header", ""),
proxy=merged_provider_config.get("proxy", ""),
)
try:
await client.delete_thread(thread_id)
finally:
try:
await client.close()
except Exception as e:
logger.warning(
"Failed to close DeerFlow API client after thread cleanup: %s",
e,
)
except Exception as e:
logger.warning(
"Failed to clean up DeerFlow thread for session %s: %s",
umo,
e,
)
async def _clear_third_party_agent_runner_state(
context: star.Context,
umo: str,
agent_runner_type: str,
) -> None:
session_key = THIRD_PARTY_AGENT_RUNNER_KEY.get(agent_runner_type)
if not session_key:
return
if agent_runner_type == DEERFLOW_PROVIDER_TYPE:
await _cleanup_deerflow_thread_if_present(context, umo)
await sp.remove_async(
scope="umo",
scope_id=umo,
key=session_key,
)
class ConversationCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
@@ -142,8 +60,8 @@ class ConversationCommands:
if required_perm == "admin" and message.role != "admin":
message.set_result(
MessageEventResult().message(
f"Reset command requires admin permission in {scene.name} scenario, "
f"you (ID {message.get_sender_id()}) are not admin, cannot perform this action.",
f"{scene.name}场景下reset命令需要管理员权限"
f" (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。",
),
)
return
@@ -151,21 +69,17 @@ class ConversationCommands:
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
active_event_registry.stop_all(umo, exclude=message)
await _clear_third_party_agent_runner_state(
self.context,
umo,
agent_runner_type,
)
message.set_result(
MessageEventResult().message("✅ Conversation reset successfully.")
await sp.remove_async(
scope="umo",
scope_id=umo,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("重置对话成功。"))
return
if not self.context.get_using_provider(umo):
message.set_result(
MessageEventResult().message(
"😕 Cannot find any LLM provider. Configure one first."
),
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
@@ -174,7 +88,7 @@ class ConversationCommands:
if not cid:
message.set_result(
MessageEventResult().message(
"😕 You are not in a conversation. Use /new to create one.",
"当前未处于对话状态,请 /switch 切换或者 /new 创建。",
),
)
return
@@ -187,7 +101,7 @@ class ConversationCommands:
[],
)
ret = "✅ Conversation reset successfully."
ret = "清除聊天历史成功!"
message.set_extra("_clean_ltm_session", True)
@@ -210,29 +124,160 @@ class ConversationCommands:
if stopped_count > 0:
message.set_result(
MessageEventResult().message(
f"✅ Requested to stop {stopped_count} running tasks."
f"已请求停止 {stopped_count} 个运行中的任务。"
)
)
return
message.set_result(
MessageEventResult().message("✅ No running tasks in the current session.")
message.set_result(MessageEventResult().message("当前会话没有运行中的任务。"))
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
if not self.context.get_using_provider(message.unified_msg_origin):
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
size_per_page = 6
conv_mgr = self.context.conversation_manager
umo = message.unified_msg_origin
session_curr_cid = await conv_mgr.get_curr_conversation_id(umo)
if not session_curr_cid:
session_curr_cid = await conv_mgr.new_conversation(
umo,
message.get_platform_id(),
)
contexts, total_pages = await conv_mgr.get_human_readable_context(
umo,
session_curr_cid,
page,
size_per_page,
)
parts = []
for context in contexts:
if len(context) > 150:
context = context[:150] + "..."
parts.append(f"{context}\n")
history = "".join(parts)
ret = (
f"当前对话历史记录:"
f"{history or '无历史记录'}\n\n"
f"{page} 页 | 共 {total_pages}\n"
f"*输入 /history 2 跳转到第 2 页"
)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话列表"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
message.set_result(
MessageEventResult().message(
f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。",
),
)
return
size_per_page = 6
"""获取所有对话列表"""
conversations_all = await self.context.conversation_manager.get_conversations(
message.unified_msg_origin,
)
"""计算总页数"""
total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page
"""确保页码有效"""
page = max(1, min(page, total_pages))
"""分页处理"""
start_idx = (page - 1) * size_per_page
end_idx = start_idx + size_per_page
conversations_paged = conversations_all[start_idx:end_idx]
parts = ["对话列表:\n---\n"]
"""全局序号从当前页的第一个开始"""
global_index = start_idx + 1
"""生成所有对话的标题字典"""
_titles = {}
for conv in conversations_all:
title = conv.title if conv.title else "新对话"
_titles[conv.cid] = title
"""遍历分页后的对话生成列表显示"""
provider_settings = cfg.get("provider_settings", {})
platform_name = message.get_platform_name()
for conv in conversations_paged:
(
persona_id,
_,
force_applied_persona_id,
_,
) = await self.context.persona_manager.resolve_selected_persona(
umo=message.unified_msg_origin,
conversation_persona_id=conv.persona_id,
platform_name=platform_name,
provider_settings=provider_settings,
)
if persona_id == "[%None]":
persona_name = ""
elif persona_id:
persona_name = persona_id
else:
persona_name = ""
if force_applied_persona_id:
persona_name = f"{persona_name} (自定义规则)"
title = _titles.get(conv.cid, "新对话")
parts.append(
f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_name}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
)
global_index += 1
parts.append("---\n")
ret = "".join(parts)
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
message.unified_msg_origin,
)
if curr_cid:
"""从所有对话的标题字典中获取标题"""
title = _titles.get(curr_cid, "新对话")
ret += f"\n当前对话: {title}({curr_cid[:4]})"
else:
ret += "\n当前对话: 无"
cfg = self.context.get_config(umo=message.unified_msg_origin)
unique_session = cfg["platform_settings"]["unique_session"]
if unique_session:
ret += "\n会话隔离粒度: 个人"
else:
ret += "\n会话隔离粒度: 群聊"
ret += f"\n{page} 页 | 共 {total_pages}"
ret += "\n*输入 /ls 2 跳转到第 2 页"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
return
async def new_conv(self, message: AstrMessageEvent) -> None:
"""创建新对话"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
active_event_registry.stop_all(message.unified_msg_origin, exclude=message)
await _clear_third_party_agent_runner_state(
self.context,
message.unified_msg_origin,
agent_runner_type,
)
message.set_result(
MessageEventResult().message("✅ New conversation created.")
await sp.remove_async(
scope="umo",
scope_id=message.unified_msg_origin,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("已创建新对话。"))
return
active_event_registry.stop_all(message.unified_msg_origin, exclude=message)
@@ -246,66 +291,130 @@ class ConversationCommands:
message.set_extra("_clean_ltm_session", True)
message.set_result(
MessageEventResult().message(
f"✅ Switched to new conversation: {cid[:4]}"
),
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
)
async def stats(self, message: AstrMessageEvent) -> None:
"""Show token usage statistics for the current conversation."""
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None:
"""创建新群聊对话"""
if sid:
session = str(
MessageSession(
platform_name=message.platform_meta.id,
message_type=MessageType("GroupMessage"),
session_id=sid,
),
)
cpersona = await self._get_current_persona_id(session)
cid = await self.context.conversation_manager.new_conversation(
session,
message.get_platform_id(),
persona_id=cpersona,
)
message.set_result(
MessageEventResult().message(
f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。",
),
)
else:
message.set_result(
MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。"),
)
async def switch_conv(
self,
message: AstrMessageEvent,
index: int | None = None,
) -> None:
"""通过 /ls 前面的序号切换对话"""
if not isinstance(index, int):
message.set_result(
MessageEventResult().message("类型错误,请输入数字对话序号。"),
)
return
if index is None:
message.set_result(
MessageEventResult().message(
"请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话",
),
)
return
conversations = await self.context.conversation_manager.get_conversations(
message.unified_msg_origin,
)
if index > len(conversations) or index < 1:
message.set_result(
MessageEventResult().message("对话序号错误,请使用 /ls 查看"),
)
else:
conversation = conversations[index - 1]
title = conversation.title if conversation.title else "新对话"
await self.context.conversation_manager.switch_conversation(
message.unified_msg_origin,
conversation.cid,
)
message.set_result(
MessageEventResult().message(
f"切换到对话: {title}({conversation.cid[:4]})。",
),
)
async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None:
"""重命名对话"""
if not new_name:
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
return
await self.context.conversation_manager.update_conversation_title(
message.unified_msg_origin,
new_name,
)
message.set_result(MessageEventResult().message("重命名对话成功。"))
async def del_conv(self, message: AstrMessageEvent) -> None:
"""删除当前对话"""
umo = message.unified_msg_origin
cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
if not cid:
cfg = self.context.get_config(umo=umo)
is_unique_session = cfg["platform_settings"]["unique_session"]
if message.get_group_id() and not is_unique_session and message.role != "admin":
# 群聊,没开独立会话,发送人不是管理员
message.set_result(
MessageEventResult().message(
"❌ You are not in a conversation. Use /new to create one."
f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。",
),
)
return
db = self.context.get_db()
async with db.get_db() as session:
result = await session.execute(
select(
func.count(case((col(ProviderStat.id).is_not(None), 1))).label(
"record_count",
),
func.coalesce(func.sum(ProviderStat.token_input_other), 0).label(
"total_input_other",
),
func.coalesce(func.sum(ProviderStat.token_input_cached), 0).label(
"total_input_cached",
),
func.coalesce(func.sum(ProviderStat.token_output), 0).label(
"total_output",
),
).where(
col(ProviderStat.agent_type) == "internal",
col(ProviderStat.conversation_id) == cid,
)
)
stats = result.one()
if stats.record_count == 0:
message.set_result(
MessageEventResult().message(
"📊 No stats available for this conversation yet."
),
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
active_event_registry.stop_all(umo, exclude=message)
await sp.remove_async(
scope="umo",
scope_id=umo,
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
)
message.set_result(MessageEventResult().message("重置对话成功。"))
return
total_input_other = stats.total_input_other
total_input_cached = stats.total_input_cached
total_output = stats.total_output
total_tokens = total_input_other + total_input_cached + total_output
ret = (
f"📊 Conversation Token usage (ID: {cid[:8]}...)\n"
f"Total: {total_tokens:,}\n"
f"Input (cached): {total_input_cached:,}\n"
f"Input (other): {total_input_other:,}\n"
f"Output: {total_output:,}\n"
session_curr_cid = (
await self.context.conversation_manager.get_curr_conversation_id(umo)
)
if not session_curr_cid:
message.set_result(
MessageEventResult().message(
"当前未处于对话状态,请 /switch 序号 切换或 /new 创建。",
),
)
return
active_event_registry.stop_all(umo, exclude=message)
await self.context.conversation_manager.delete_conversation(
umo,
session_curr_cid,
)
ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
message.set_extra("_clean_ltm_session", True)
message.set_result(MessageEventResult().message(ret))

View File

@@ -32,6 +32,7 @@ class HelpCommand:
return []
lines: list[str] = []
hidden_commands = {"set", "unset", "websearch"}
def walk(items: list[dict], indent: int = 0) -> None:
for item in items:
@@ -48,12 +49,9 @@ class HelpCommand:
or item.get("original_command")
or item.get("handler_name")
)
if not effective or effective in [
"set",
"unset",
"help",
"dashboard_update",
]:
if not effective:
continue
if effective in hidden_commands:
continue
description = item.get("description") or ""
@@ -75,13 +73,12 @@ class HelpCommand:
dashboard_version = await get_dashboard_version()
command_lines = await self._build_reserved_command_lines()
commands_section = (
"\n".join(command_lines)
if command_lines
else "No enabled built-in commands."
"\n".join(command_lines) if command_lines else "暂无启用的内置指令"
)
msg_parts = [
f"AstrBot v{VERSION}(WebUI: {dashboard_version})",
"内置指令:",
commands_section,
]
if notice:

View File

@@ -0,0 +1,20 @@
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageChain
class LLMCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
async def llm(self, event: AstrMessageEvent) -> None:
"""开启/关闭 LLM"""
cfg = self.context.get_config(umo=event.unified_msg_origin)
enable = cfg["provider_settings"].get("enable", True)
if enable:
cfg["provider_settings"]["enable"] = False
status = "关闭"
else:
cfg["provider_settings"]["enable"] = True
status = "开启"
cfg.save_config()
await event.send(MessageChain().message(f"{status} LLM 聊天功能。"))

View File

@@ -0,0 +1,216 @@
import builtins
from typing import TYPE_CHECKING
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
if TYPE_CHECKING:
from astrbot.core.db.po import Persona
class PersonaCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
def _build_tree_output(
self,
folder_tree: list[dict],
all_personas: list["Persona"],
depth: int = 0,
) -> list[str]:
"""递归构建树状输出,使用短线条表示层级"""
lines: list[str] = []
# 使用短线条作为缩进前缀,每层只用 "│" 加一个空格
prefix = "" * depth
for folder in folder_tree:
# 输出文件夹
lines.append(f"{prefix}├ 📁 {folder['name']}/")
# 获取该文件夹下的人格
folder_personas = [
p for p in all_personas if p.folder_id == folder["folder_id"]
]
child_prefix = "" * (depth + 1)
# 输出该文件夹下的人格
for persona in folder_personas:
lines.append(f"{child_prefix}├ 👤 {persona.persona_id}")
# 递归处理子文件夹
children = folder.get("children", [])
if children:
lines.extend(
self._build_tree_output(
children,
all_personas,
depth + 1,
)
)
return lines
async def persona(self, message: AstrMessageEvent) -> None:
l = message.message_str.split(" ") # noqa: E741
umo = message.unified_msg_origin
curr_persona_name = ""
cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
default_persona = await self.context.persona_manager.get_default_persona_v3(
umo=umo,
)
force_applied_persona_id = None
curr_cid_title = ""
if cid:
conv = await self.context.conversation_manager.get_conversation(
unified_msg_origin=umo,
conversation_id=cid,
create_if_not_exists=True,
)
if conv is None:
message.set_result(
MessageEventResult().message(
"当前对话不存在,请先使用 /new 新建一个对话。",
),
)
return
provider_settings = self.context.get_config(umo=umo).get(
"provider_settings",
{},
)
(
persona_id,
_,
force_applied_persona_id,
_,
) = await self.context.persona_manager.resolve_selected_persona(
umo=umo,
conversation_persona_id=conv.persona_id,
platform_name=message.get_platform_name(),
provider_settings=provider_settings,
)
if persona_id == "[%None]":
curr_persona_name = ""
elif persona_id:
curr_persona_name = persona_id
if force_applied_persona_id:
curr_persona_name = f"{curr_persona_name} (自定义规则)"
curr_cid_title = conv.title if conv.title else "新对话"
curr_cid_title += f"({cid[:4]})"
if len(l) == 1:
message.set_result(
MessageEventResult()
.message(
f"""[Persona]
- 人格情景列表: `/persona list`
- 设置人格情景: `/persona 人格`
- 人格情景详细信息: `/persona view 人格`
- 取消人格: `/persona unset`
默认人格情景: {default_persona["name"]}
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
配置人格情景请前往管理面板-配置页
""",
)
.use_t2i(False),
)
elif l[1] == "list":
# 获取文件夹树和所有人格
folder_tree = await self.context.persona_manager.get_folder_tree()
all_personas = self.context.persona_manager.personas
lines = ["📂 人格列表:\n"]
# 构建树状输出
tree_lines = self._build_tree_output(folder_tree, all_personas)
lines.extend(tree_lines)
# 输出根目录下的人格(没有文件夹的)
root_personas = [p for p in all_personas if p.folder_id is None]
if root_personas:
if tree_lines: # 如果有文件夹内容,加个空行
lines.append("")
for persona in root_personas:
lines.append(f"👤 {persona.persona_id}")
# 统计信息
total_count = len(all_personas)
lines.append(f"\n{total_count} 个人格")
lines.append("\n*使用 `/persona <人格名>` 设置人格")
lines.append("*使用 `/persona view <人格名>` 查看详细信息")
msg = "\n".join(lines)
message.set_result(MessageEventResult().message(msg).use_t2i(False))
elif l[1] == "view":
if len(l) == 2:
message.set_result(MessageEventResult().message("请输入人格情景名"))
return
ps = l[2].strip()
if persona := next(
builtins.filter(
lambda persona: persona["name"] == ps,
self.context.provider_manager.personas,
),
None,
):
msg = f"人格{ps}的详细信息:\n"
msg += f"{persona['prompt']}\n"
else:
msg = f"人格{ps}不存在"
message.set_result(MessageEventResult().message(msg))
elif l[1] == "unset":
if not cid:
message.set_result(
MessageEventResult().message("当前没有对话,无法取消人格。"),
)
return
await self.context.conversation_manager.update_conversation_persona_id(
message.unified_msg_origin,
"[%None]",
)
message.set_result(MessageEventResult().message("取消人格成功。"))
else:
ps = "".join(l[1:]).strip()
if not cid:
message.set_result(
MessageEventResult().message(
"当前没有对话,请先开始对话或使用 /new 创建一个对话。",
),
)
return
if persona := next(
builtins.filter(
lambda persona: persona["name"] == ps,
self.context.provider_manager.personas,
),
None,
):
await self.context.conversation_manager.update_conversation_persona_id(
message.unified_msg_origin,
ps,
)
force_warn_msg = ""
if force_applied_persona_id:
force_warn_msg = (
"提醒:由于自定义规则,您现在切换的人格将不会生效。"
)
message.set_result(
MessageEventResult().message(
f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。{force_warn_msg}",
),
)
else:
message.set_result(
MessageEventResult().message(
"不存在该人格情景。使用 /persona list 查看所有。",
),
)

View File

@@ -0,0 +1,120 @@
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core import DEMO_MODE, logger
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
from astrbot.core.star.star_manager import PluginManager
class PluginCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表。"""
parts = ["已加载的插件:\n"]
for plugin in self.context.get_all_stars():
line = f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
if not plugin.activated:
line += " (未启用)"
parts.append(line + "\n")
if len(parts) == 1:
plugin_list_info = "没有加载任何插件。"
else:
plugin_list_info = "".join(parts)
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
event.set_result(
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False),
)
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
return
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin off <插件名> 禁用插件。"),
)
return
await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
return
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin on <插件名> 启用插件。"),
)
return
await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
return
if not plugin_repo:
event.set_result(
MessageEventResult().message("/plugin get <插件仓库地址> 安装插件"),
)
return
logger.info(f"准备从 {plugin_repo} 安装插件。")
if self.context._star_manager:
star_mgr: PluginManager = self.context._star_manager
try:
await star_mgr.install_plugin(plugin_repo) # type: ignore
event.set_result(MessageEventResult().message("安装插件成功。"))
except Exception as e:
logger.error(f"安装插件失败: {e}")
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
return
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""获取插件帮助"""
if not plugin_name:
event.set_result(
MessageEventResult().message("/plugin help <插件名> 查看插件信息。"),
)
return
plugin = self.context.get_registered_star(plugin_name)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件。"))
return
help_msg = ""
help_msg += f"\n\n✨ 作者: {plugin.author}\n✨ 版本: {plugin.version}"
command_handlers = []
command_names = []
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
if handler.handler_module_path != plugin.module_path:
continue
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
command_handlers.append(handler)
command_names.append(filter_.command_name)
break
if isinstance(filter_, CommandGroupFilter):
command_handlers.append(handler)
command_names.append(filter_.group_name)
if len(command_handlers) > 0:
parts = ["\n\n🔧 指令列表:\n"]
for i in range(len(command_handlers)):
line = f"- {command_names[i]}"
if command_handlers[i].desc:
line += f": {command_handlers[i].desc}"
parts.append(line + "\n")
parts.append("\nTip: 指令的触发需要添加唤醒前缀,默认为 /。")
help_msg += "".join(parts)
ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg
ret += "更多帮助信息请查看插件仓库 README。"
event.set_result(MessageEventResult().message(ret).use_t2i(False))

View File

@@ -1,6 +1,10 @@
from __future__ import annotations
import asyncio
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING
from astrbot import logger
from astrbot.api import star
@@ -8,10 +12,251 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType
from astrbot.core.utils.error_redaction import safe_error
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT = 30.0
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT = 4
MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND = 16
MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds"
MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency"
MODEL_CACHE_MAX_ENTRIES = 512
@dataclass(frozen=True)
class _ModelLookupConfig:
umo: str | None
cache_ttl_seconds: float
max_concurrency: int
class _ModelCache:
def __init__(self) -> None:
self._store: dict[tuple[str, str | None], tuple[float, list[str]]] = {}
def get(self, provider_id: str, umo: str | None, ttl: float) -> list[str] | None:
if ttl <= 0:
return None
entry = self._store.get((provider_id, umo))
if not entry:
return None
timestamp, models = entry
if time.monotonic() - timestamp > ttl:
self._store.pop((provider_id, umo), None)
return None
return models
def set(
self, provider_id: str, umo: str | None, models: list[str], ttl: float
) -> None:
if ttl <= 0:
return
self._store[(provider_id, umo)] = (time.monotonic(), list(models))
self._evict_if_needed()
def _evict_if_needed(self) -> None:
if len(self._store) <= MODEL_CACHE_MAX_ENTRIES:
return
# Drop oldest entries first when cache grows too large.
overflow = len(self._store) - MODEL_CACHE_MAX_ENTRIES
for key, _ in sorted(
self._store.items(),
key=lambda item: item[1][0],
)[:overflow]:
self._store.pop(key, None)
def invalidate(
self, provider_id: str | None = None, *, umo: str | None = None
) -> None:
if provider_id is None:
self._store.clear()
return
if umo is not None:
self._store.pop((provider_id, umo), None)
return
stale_keys = [
cache_key for cache_key in self._store if cache_key[0] == provider_id
]
for cache_key in stale_keys:
self._store.pop(cache_key, None)
class ProviderCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
self._model_cache = _ModelCache()
self._register_provider_change_hook()
def _register_provider_change_hook(self) -> None:
set_change_callback = getattr(
self.context.provider_manager,
"set_provider_change_callback",
None,
)
if callable(set_change_callback):
set_change_callback(self._on_provider_manager_changed)
return
register_change_hook = getattr(
self.context.provider_manager,
"register_provider_change_hook",
None,
)
if callable(register_change_hook):
register_change_hook(self._on_provider_manager_changed)
def invalidate_provider_models_cache(
self, provider_id: str | None = None, *, umo: str | None = None
) -> None:
"""Public hook for cache invalidation on external provider config changes."""
self._model_cache.invalidate(provider_id, umo=umo)
def _on_provider_manager_changed(
self,
provider_id: str,
provider_type: ProviderType,
umo: str | None,
) -> None:
if provider_type == ProviderType.CHAT_COMPLETION:
self.invalidate_provider_models_cache(provider_id, umo=umo)
def _get_provider_settings(self, umo: str | None) -> dict:
if not umo:
return {}
try:
return self.context.get_config(umo).get("provider_settings", {}) or {}
except Exception as e:
logger.debug(
"读取 provider_settings 失败,使用默认值: %s",
safe_error("", e),
)
return {}
def _get_model_cache_ttl(self, umo: str | None) -> float:
settings = self._get_provider_settings(umo)
raw = settings.get(
MODEL_LIST_CACHE_TTL_KEY,
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
)
try:
return max(float(raw), 0.0)
except Exception as e:
logger.debug(
"读取 %s 失败,回退默认值 %r: %s",
MODEL_LIST_CACHE_TTL_KEY,
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
safe_error("", e),
)
return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
def _get_model_lookup_concurrency(self, umo: str | None) -> int:
settings = self._get_provider_settings(umo)
raw = settings.get(
MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
)
try:
value = int(raw)
except Exception as e:
logger.debug(
"读取 %s 失败,回退默认值 %r: %s",
MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
safe_error("", e),
)
value = MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
return min(max(value, 1), MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND)
def _get_model_lookup_config(self, umo: str | None) -> _ModelLookupConfig:
return _ModelLookupConfig(
umo=umo,
cache_ttl_seconds=self._get_model_cache_ttl(umo),
max_concurrency=self._get_model_lookup_concurrency(umo),
)
def _resolve_model_name(
self,
model_name: str,
models: Sequence[str],
) -> str | None:
"""Resolve model name with precedence:
exact > case-insensitive > provider-qualified suffix.
"""
requested = model_name.strip()
if not requested:
return None
requested_norm = requested.casefold()
# exact / case-insensitive match
for candidate in models:
if candidate == requested or candidate.casefold() == requested_norm:
return candidate
# provider-qualified suffix match:
# e.g. candidate `openai/gpt-4o` should match requested `gpt-4o`.
for candidate in models:
cand_norm = candidate.casefold()
if cand_norm.endswith(f"/{requested_norm}") or cand_norm.endswith(
f":{requested_norm}"
):
return candidate
return None
def _apply_model(
self, prov: Provider, model_name: str, *, umo: str | None = None
) -> str:
prov.set_model(model_name)
self.invalidate_provider_models_cache(prov.meta().id, umo=umo)
return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]"
async def _get_provider_models(
self,
provider: Provider,
*,
config: _ModelLookupConfig,
use_cache: bool = True,
) -> list[str]:
provider_id = provider.meta().id
ttl_seconds = config.cache_ttl_seconds
umo = config.umo
if use_cache:
cached = self._model_cache.get(provider_id, umo, ttl_seconds)
if cached is not None:
return cached
models = list(await provider.get_models())
if use_cache:
self._model_cache.set(provider_id, umo, models, ttl_seconds)
return models
async def _get_models_or_reply_error(
self,
message: AstrMessageEvent,
prov: Provider,
config: _ModelLookupConfig,
*,
error_prefix: str,
disable_t2i: bool = False,
warning_log: str | None = None,
) -> list[str] | None:
try:
return await self._get_provider_models(prov, config=config)
except asyncio.CancelledError:
raise
except Exception as e:
if warning_log is not None:
logger.warning(
warning_log,
prov.meta().id,
safe_error("", e),
)
result = MessageEventResult().message(safe_error(error_prefix, e))
if disable_t2i:
result = result.use_t2i(False)
message.set_result(result)
return None
def _log_reachability_failure(
self,
@@ -20,6 +265,7 @@ class ProviderCommands:
err_code: str,
err_reason: str,
) -> None:
"""记录不可达原因到日志。"""
meta = provider.meta()
logger.warning(
"Provider reachability check failed: id=%s type=%s code=%s reason=%s",
@@ -30,6 +276,7 @@ class ProviderCommands:
)
async def _test_provider_capability(self, provider):
"""测试单个 provider 的可用性"""
meta = provider.meta()
provider_capability_type = meta.provider_type
@@ -44,69 +291,89 @@ class ProviderCommands:
)
return False, err_code, err_reason
async def _build_provider_display_data(
async def _find_provider_for_model(
self,
providers,
provider_type: str,
reachability_check_enabled: bool,
) -> list[dict]:
if not providers:
return []
model_name: str,
*,
exclude_provider_id: str | None = None,
config: _ModelLookupConfig,
use_cache: bool = True,
) -> tuple[Provider | None, str | None]:
all_providers = []
for provider in self.context.get_all_providers():
provider_meta = provider.meta()
if provider_meta.provider_type != ProviderType.CHAT_COMPLETION:
continue
if (
exclude_provider_id is not None
and provider_meta.id == exclude_provider_id
):
continue
all_providers.append(provider)
if not all_providers:
return None, None
if reachability_check_enabled:
check_results = await asyncio.gather(
*[self._test_provider_capability(provider) for provider in providers],
return_exceptions=True,
semaphore = asyncio.Semaphore(config.max_concurrency)
async def fetch_models(
provider: Provider,
) -> tuple[Provider, list[str] | None, str | None]:
async with semaphore:
try:
models = await self._get_provider_models(
provider,
config=config,
use_cache=use_cache,
)
return provider, models, None
except asyncio.CancelledError:
raise
except Exception as e:
err = safe_error("", e)
logger.debug(
"跨提供商查找模型 %s 获取 %s 模型列表失败: %s",
model_name,
provider.meta().id,
err,
)
return provider, None, err
results = await asyncio.gather(
*(fetch_models(provider) for provider in all_providers)
)
failed_provider_errors: list[tuple[str, str]] = []
for provider, models, err in results:
if err is not None:
failed_provider_errors.append((provider.meta().id, err))
continue
if models is None:
continue
matched_model_name = self._resolve_model_name(model_name, models)
if matched_model_name is not None:
return provider, matched_model_name
if failed_provider_errors and len(failed_provider_errors) == len(all_providers):
failed_ids = ",".join(
provider_id for provider_id, _ in failed_provider_errors
)
else:
check_results = [None for _ in providers]
display_data = []
for provider, reachable in zip(providers, check_results):
meta = provider.meta()
id_ = meta.id
error_code = None
if isinstance(reachable, asyncio.CancelledError):
raise reachable
if isinstance(reachable, Exception):
self._log_reachability_failure(
provider,
None,
reachable.__class__.__name__,
safe_error("", reachable),
)
reachable_flag = False
error_code = reachable.__class__.__name__
elif isinstance(reachable, tuple):
reachable_flag, error_code, _ = reachable
else:
reachable_flag = reachable
if provider_type == "llm":
info = f"{id_} ({meta.model})"
else:
info = f"{id_}"
if reachable_flag is True:
mark = ""
elif reachable_flag is False:
if error_code:
mark = f" ❌(errcode: {error_code})"
else:
mark = ""
else:
mark = ""
display_data.append(
{
"info": info,
"mark": mark,
"provider": provider,
}
logger.error(
"跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络",
model_name,
len(all_providers),
failed_ids,
)
return display_data
elif failed_provider_errors:
logger.debug(
"跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s",
model_name,
len(failed_provider_errors),
",".join(
f"{provider_id}({error})"
for provider_id, error in failed_provider_errors
),
)
return None, None
async def provider(
self,
@@ -120,82 +387,137 @@ class ProviderCommands:
reachability_check_enabled = cfg.get("reachability_check", True)
if idx is None:
parts = ["## LLM Providers\n"]
parts = ["## 载入的 LLM 提供商\n"]
# 获取所有类型的提供商
llms = list(self.context.get_all_providers())
ttss = self.context.get_all_tts_providers()
stts = self.context.get_all_stt_providers()
if reachability_check_enabled and (llms or ttss or stts):
await event.send(
MessageEventResult().message("👀 Testing provider reachability...")
# 构造待检测列表: [(provider, type_label), ...]
all_providers = []
all_providers.extend([(p, "llm") for p in llms])
all_providers.extend([(p, "tts") for p in ttss])
all_providers.extend([(p, "stt") for p in stts])
# 并发测试连通性
if reachability_check_enabled:
if all_providers:
await event.send(
MessageEventResult().message(
"正在进行提供商可达性测试,请稍候..."
)
)
check_results = await asyncio.gather(
*[self._test_provider_capability(p) for p, _ in all_providers],
return_exceptions=True,
)
else:
# 用 None 表示未检测
check_results = [None for _ in all_providers]
# 整合结果
display_data = []
for (p, p_type), reachable in zip(all_providers, check_results):
meta = p.meta()
id_ = meta.id
error_code = None
if isinstance(reachable, asyncio.CancelledError):
raise reachable
if isinstance(reachable, Exception):
# 异常情况下兜底处理,避免单个 provider 导致列表失败
self._log_reachability_failure(
p,
None,
reachable.__class__.__name__,
safe_error("", reachable),
)
reachable_flag = False
error_code = reachable.__class__.__name__
elif isinstance(reachable, tuple):
reachable_flag, error_code, _ = reachable
else:
reachable_flag = reachable
# 根据类型构建显示名称
if p_type == "llm":
info = f"{id_} ({meta.model})"
else:
info = f"{id_}"
# 确定状态标记
if reachable_flag is True:
mark = ""
elif reachable_flag is False:
if error_code:
mark = f" ❌(错误码: {error_code})"
else:
mark = ""
else:
mark = "" # 不支持检测时不显示标记
display_data.append(
{
"type": p_type,
"info": info,
"mark": mark,
"provider": p,
}
)
llm_data, tts_data, stt_data = await asyncio.gather(
self._build_provider_display_data(
llms,
"llm",
reachability_check_enabled,
),
self._build_provider_display_data(
ttss,
"tts",
reachability_check_enabled,
),
self._build_provider_display_data(
stts,
"stt",
reachability_check_enabled,
),
)
provider_using = self.context.get_using_provider(umo=umo)
# 分组输出
# 1. LLM
llm_data = [d for d in display_data if d["type"] == "llm"]
for i, d in enumerate(llm_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
provider_using = self.context.get_using_provider(umo=umo)
if (
provider_using
and provider_using.meta().id == d["provider"].meta().id
):
line += " 👈"
line += " (当前使用)"
parts.append(line + "\n")
# 2. TTS
tts_data = [d for d in display_data if d["type"] == "tts"]
if tts_data:
parts.append("\n## TTS Providers\n")
tts_using = self.context.get_using_tts_provider(umo=umo)
parts.append("\n## 载入的 TTS 提供商\n")
for i, d in enumerate(tts_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
tts_using = self.context.get_using_tts_provider(umo=umo)
if tts_using and tts_using.meta().id == d["provider"].meta().id:
line += " 👈"
line += " (当前使用)"
parts.append(line + "\n")
# 3. STT
stt_data = [d for d in display_data if d["type"] == "stt"]
if stt_data:
parts.append("\n## STT Providers\n")
stt_using = self.context.get_using_stt_provider(umo=umo)
parts.append("\n## 载入的 STT 提供商\n")
for i, d in enumerate(stt_data):
line = f"{i + 1}. {d['info']}{d['mark']}"
stt_using = self.context.get_using_stt_provider(umo=umo)
if stt_using and stt_using.meta().id == d["provider"].meta().id:
line += " 👈"
line += " (当前使用)"
parts.append(line + "\n")
parts.append("\nUse /provider <idx> to switch LLM providers.")
parts.append("\n使用 /provider <序号> 切换 LLM 提供商。")
ret = "".join(parts)
if ttss:
ret += "\nUse /provider tts <idx> to switch TTS providers."
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
if stts:
ret += "\nUse /provider stt <idx> to switch STT providers."
ret += "\n使用 /provider stt <序号> 切换 STT 提供商。"
if not reachability_check_enabled:
ret += "\n已跳过提供商可达性检测,如需检测请在配置文件中开启。"
event.set_result(MessageEventResult().message(ret))
elif idx == "tts":
if idx2 is None:
event.set_result(
MessageEventResult().message("Please enter the index.")
)
event.set_result(MessageEventResult().message("请输入序号。"))
return
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index.")
)
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_tts_providers()[idx2 - 1]
id_ = provider.meta().id
@@ -204,19 +526,13 @@ class ProviderCommands:
provider_type=ProviderType.TEXT_TO_SPEECH,
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
elif idx == "stt":
if idx2 is None:
event.set_result(
MessageEventResult().message("Please enter the index.")
)
event.set_result(MessageEventResult().message("请输入序号。"))
return
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index.")
)
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_stt_providers()[idx2 - 1]
id_ = provider.meta().id
@@ -225,14 +541,10 @@ class ProviderCommands:
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
elif isinstance(idx, int):
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index.")
)
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_providers()[idx - 1]
id_ = provider.meta().id
@@ -241,8 +553,184 @@ class ProviderCommands:
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
else:
event.set_result(MessageEventResult().message("❌ Invalid parameter."))
event.set_result(MessageEventResult().message("无效的参数。"))
async def _switch_model_by_name(
self, message: AstrMessageEvent, model_name: str, prov: Provider
) -> None:
model_name = model_name.strip()
if not model_name:
message.set_result(MessageEventResult().message("模型名不能为空。"))
return
umo = message.unified_msg_origin
config = self._get_model_lookup_config(umo)
curr_provider_id = prov.meta().id
models = await self._get_models_or_reply_error(
message,
prov,
config,
error_prefix="获取当前提供商模型列表失败: ",
warning_log="获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s",
)
if models is None:
return
matched_model_name = self._resolve_model_name(model_name, models)
if matched_model_name is not None:
message.set_result(
MessageEventResult().message(
self._apply_model(prov, matched_model_name, umo=umo)
),
)
return
target_prov, matched_target_model_name = await self._find_provider_for_model(
model_name,
exclude_provider_id=curr_provider_id,
config=config,
)
if target_prov is None or matched_target_model_name is None:
message.set_result(
MessageEventResult().message(
f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。",
),
)
return
target_id = target_prov.meta().id
try:
await self.context.provider_manager.set_provider(
provider_id=target_id,
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
self._apply_model(target_prov, matched_target_model_name, umo=umo)
message.set_result(
MessageEventResult().message(
f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
),
)
except asyncio.CancelledError:
raise
except Exception as e:
message.set_result(
MessageEventResult().message(
safe_error("跨提供商切换并设置模型失败: ", e)
),
)
async def model_ls(
self,
message: AstrMessageEvent,
idx_or_name: int | str | None = None,
) -> None:
"""查看或者切换模型"""
prov = self.context.get_using_provider(message.unified_msg_origin)
if not prov:
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
config = self._get_model_lookup_config(message.unified_msg_origin)
if idx_or_name is None:
models = await self._get_models_or_reply_error(
message,
prov,
config,
error_prefix="获取模型列表失败: ",
disable_t2i=True,
)
if models is None:
return
parts = ["下面列出了此模型提供商可用模型:"]
for i, model in enumerate(models, 1):
parts.append(f"\n{i}. {model}")
curr_model = prov.get_model() or ""
parts.append(f"\n当前模型: [{curr_model}]")
parts.append(
"\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换。"
)
ret = "".join(parts)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
elif isinstance(idx_or_name, int):
models = await self._get_models_or_reply_error(
message,
prov,
config,
error_prefix="获取模型列表失败: ",
)
if models is None:
return
if idx_or_name > len(models) or idx_or_name < 1:
message.set_result(MessageEventResult().message("模型序号错误。"))
else:
try:
new_model = models[idx_or_name - 1]
message.set_result(
MessageEventResult().message(
self._apply_model(
prov,
new_model,
umo=message.unified_msg_origin,
)
),
)
except Exception as e:
message.set_result(
MessageEventResult().message(
safe_error("切换模型未知错误: ", e)
),
)
return
else:
await self._switch_model_by_name(message, idx_or_name, prov)
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
prov = self.context.get_using_provider(message.unified_msg_origin)
if not prov:
message.set_result(
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
if index is None:
keys_data = prov.get_keys()
curr_key = prov.get_current_key()
parts = ["Key:"]
for i, k in enumerate(keys_data, 1):
parts.append(f"\n{i}. {k[:8]}")
parts.append(f"\n当前 Key: {curr_key[:8]}")
parts.append("\n当前模型: " + prov.get_model())
parts.append("\n使用 /key <idx> 切换 Key。")
ret = "".join(parts)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
else:
keys_data = prov.get_keys()
if index > len(keys_data) or index < 1:
message.set_result(MessageEventResult().message("Key 序号错误。"))
else:
try:
new_key = keys_data[index - 1]
prov.set_key(new_key)
self.invalidate_provider_models_cache(
prov.meta().id,
umo=message.unified_msg_origin,
)
message.set_result(MessageEventResult().message("切换 Key 成功。"))
except Exception as e:
message.set_result(
MessageEventResult().message(
safe_error("切换 Key 未知错误: ", e)
),
)
return

View File

@@ -18,19 +18,19 @@ class SIDCommand:
umo_msg_type = event.session.message_type.value
umo_session_id = event.session.session_id
ret = (
f"UMO: 「{sid}\n"
f"UID: 「{user_id}\n"
"*Use UMO to set whitelist and configure routing, use UID to set admin list(UMO 可用于设置白名单和配置文件路由UID 可用于设置管理员列表)\n\n"
f"Your session information:\n"
f"Bot ID: 「{umo_platform}\n"
f"Message Type: 「{umo_msg_type}\n"
f"Session ID: 「{umo_session_id}\n\n"
f"UMO: 「{sid} 此值可用于设置白名单。\n"
f"UID: 「{user_id} 此值可用于设置管理员。\n"
f"消息会话来源信息:\n"
f" 机器人 ID: 「{umo_platform}\n"
f" 消息类型: 「{umo_msg_type}\n"
f" 会话 ID: 「{umo_session_id}\n"
f"消息来源可用于配置机器人的配置文件路由。"
)
if (
self.context.get_config()["platform_settings"]["unique_session"]
and event.get_group_id()
):
ret += f"\n\nThe group's ID: 「{event.get_group_id()}. Set this ID to whitelist to allow the entire group."
ret += f"\n\n当前处于独立会话模式, 此群 ID: 「{event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊。"
event.set_result(MessageEventResult().message(ret).use_t2i(False))

View File

@@ -0,0 +1,23 @@
"""文本转图片命令"""
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
class T2ICommand:
"""文本转图片命令类"""
def __init__(self, context: star.Context) -> None:
self.context = context
async def t2i(self, event: AstrMessageEvent) -> None:
"""开关文本转图片"""
config = self.context.get_config(umo=event.unified_msg_origin)
if config["t2i"]:
config["t2i"] = False
config.save_config()
event.set_result(MessageEventResult().message("已关闭文本转图片模式。"))
return
config["t2i"] = True
config.save_config()
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))

View File

@@ -0,0 +1,36 @@
"""文本转语音命令"""
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.star.session_llm_manager import SessionServiceManager
class TTSCommand:
"""文本转语音命令类"""
def __init__(self, context: star.Context) -> None:
self.context = context
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
umo = event.unified_msg_origin
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
cfg = self.context.get_config(umo=umo)
tts_enable = cfg["provider_tts_settings"]["enable"]
# 切换状态
new_status = not ses_tts
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
status_text = "已开启" if new_status else "已关闭"
if new_status and not tts_enable:
event.set_result(
MessageEventResult().message(
f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。",
),
)
else:
event.set_result(
MessageEventResult().message(f"{status_text}当前会话的文本转语音。"),
)

View File

@@ -3,11 +3,17 @@ from astrbot.api.event import AstrMessageEvent, filter
from .commands import (
AdminCommands,
AlterCmdCommands,
ConversationCommands,
HelpCommand,
LLMCommands,
PersonaCommands,
PluginCommands,
ProviderCommands,
SetUnsetCommands,
SIDCommand,
T2ICommand,
TTSCommand,
)
@@ -15,42 +21,100 @@ class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
self.help_c = HelpCommand(self.context)
self.llm_c = LLMCommands(self.context)
self.plugin_c = PluginCommands(self.context)
self.admin_c = AdminCommands(self.context)
self.conversation_c = ConversationCommands(self.context)
self.help_c = HelpCommand(self.context)
self.provider_c = ProviderCommands(self.context)
self.persona_c = PersonaCommands(self.context)
self.alter_cmd_c = AlterCmdCommands(self.context)
self.setunset_c = SetUnsetCommands(self.context)
self.t2i_c = T2ICommand(self.context)
self.tts_c = TTSCommand(self.context)
self.sid_c = SIDCommand(self.context)
@filter.command("help")
async def help(self, event: AstrMessageEvent) -> None:
"""Show help message"""
"""查看帮助"""
await self.help_c.help(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("llm")
async def llm(self, event: AstrMessageEvent) -> None:
"""开启/关闭 LLM"""
await self.llm_c.llm(event)
@filter.command_group("plugin")
def plugin(self) -> None:
"""插件管理"""
@plugin.command("ls")
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表。"""
await self.plugin_c.plugin_ls(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("off")
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
await self.plugin_c.plugin_off(event, plugin_name)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("on")
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
await self.plugin_c.plugin_on(event, plugin_name)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("get")
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
await self.plugin_c.plugin_get(event, plugin_repo)
@plugin.command("help")
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""获取插件帮助"""
await self.plugin_c.plugin_help(event, plugin_name)
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent) -> None:
"""开关文本转图片"""
await self.t2i_c.t2i(event)
@filter.command("tts")
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
await self.tts_c.tts(event)
@filter.command("sid")
async def sid(self, event: AstrMessageEvent) -> None:
"""Get session ID and other related information"""
"""获取会话 ID 和 管理员 ID"""
await self.sid_c.sid(event)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent) -> None:
"""Reset conversation history"""
await self.conversation_c.reset(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("op")
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员。op <admin_id>"""
await self.admin_c.op(event, admin_id)
@filter.command("stop")
async def stop(self, message: AstrMessageEvent) -> None:
"""Stop agent execution"""
await self.conversation_c.stop(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("deop")
async def deop(self, event: AstrMessageEvent, admin_id: str) -> None:
"""取消授权管理员。deop <admin_id>"""
await self.admin_c.deop(event, admin_id)
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent) -> None:
"""Create new conversation"""
await self.conversation_c.new_conv(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("wl")
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单。wl <sid>"""
await self.admin_c.wl(event, sid)
@filter.command("stats")
async def stats(self, message: AstrMessageEvent) -> None:
"""Show token usage statistics for the current conversation"""
await self.conversation_c.stats(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dwl")
async def dwl(self, event: AstrMessageEvent, sid: str) -> None:
"""删除白名单。dwl <sid>"""
await self.admin_c.dwl(event, sid)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("provider")
@@ -60,21 +124,95 @@ class Main(star.Star):
idx: str | int | None = None,
idx2: int | None = None,
) -> None:
"""View or switch LLM Provider"""
"""查看或者切换 LLM Provider"""
await self.provider_c.provider(event, idx, idx2)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent) -> None:
"""重置 LLM 会话"""
await self.conversation_c.reset(message)
@filter.command("stop")
async def stop(self, message: AstrMessageEvent) -> None:
"""停止当前会话中正在运行的 Agent"""
await self.conversation_c.stop(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("model")
async def model_ls(
self,
message: AstrMessageEvent,
idx_or_name: int | str | None = None,
) -> None:
"""查看或者切换模型"""
await self.provider_c.model_ls(message, idx_or_name)
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
await self.conversation_c.his(message, page)
@filter.command("ls")
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话列表"""
await self.conversation_c.convs(message, page)
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent) -> None:
"""创建新对话"""
await self.conversation_c.new_conv(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("groupnew")
async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None:
"""创建新群聊对话"""
await self.conversation_c.groupnew_conv(message, sid)
@filter.command("switch")
async def switch_conv(
self, message: AstrMessageEvent, index: int | None = None
) -> None:
"""通过 /ls 前面的序号切换对话"""
await self.conversation_c.switch_conv(message, index)
@filter.command("rename")
async def rename_conv(self, message: AstrMessageEvent, new_name: str) -> None:
"""重命名对话"""
await self.conversation_c.rename_conv(message, new_name)
@filter.command("del")
async def del_conv(self, message: AstrMessageEvent) -> None:
"""删除当前对话"""
await self.conversation_c.del_conv(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
"""查看或者切换 Key"""
await self.provider_c.key(message, index)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("persona")
async def persona(self, message: AstrMessageEvent) -> None:
"""查看或者切换 Persona"""
await self.persona_c.persona(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard_update")
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""Update AstrBot WebUI"""
"""更新管理面板"""
await self.admin_c.update_dashboard(event)
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
"""Set session variable"""
await self.setunset_c.set_variable(event, key, value)
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
"""Unset session variable"""
await self.setunset_c.unset_variable(event, key)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd", alias={"alter"})
async def alter_cmd(self, event: AstrMessageEvent) -> None:
"""修改命令权限"""
await self.alter_cmd_c.alter_cmd(event)

View File

@@ -91,8 +91,6 @@ class Main(Star):
controller: SessionController,
event: AstrMessageEvent,
) -> None:
if not event.message_str or not event.message_str.strip():
return
event.message_obj.message.insert(
0,
Comp.At(qq=event.get_self_id(), name=event.get_self_id()),

View File

@@ -0,0 +1,112 @@
import random
import urllib.parse
from dataclasses import dataclass
from aiohttp import ClientSession
from bs4 import BeautifulSoup, Tag
HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
"Accept": "*/*",
"Connection": "keep-alive",
"Accept-Language": "en-GB,en;q=0.5",
}
USER_AGENT_BING = "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0"
USER_AGENTS = [
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36",
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0",
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0",
]
@dataclass
class SearchResult:
title: str
url: str
snippet: str
favicon: str | None = None
def __str__(self) -> str:
return f"{self.title} - {self.url}\n{self.snippet}"
class SearchEngine:
"""搜索引擎爬虫基类"""
def __init__(self) -> None:
self.TIMEOUT = 10
self.page = 1
self.headers = HEADERS
def _set_selector(self, selector: str) -> str:
raise NotImplementedError
async def _get_next_page(self, query: str) -> str:
raise NotImplementedError
async def _get_html(self, url: str, data: dict | None = None) -> str:
headers = self.headers
headers["Referer"] = url
headers["User-Agent"] = random.choice(USER_AGENTS)
if data:
async with (
ClientSession() as session,
session.post(
url,
headers=headers,
data=data,
timeout=self.TIMEOUT,
) as resp,
):
ret = await resp.text(encoding="utf-8")
return ret
else:
async with (
ClientSession() as session,
session.get(
url,
headers=headers,
timeout=self.TIMEOUT,
) as resp,
):
ret = await resp.text(encoding="utf-8")
return ret
def tidy_text(self, text: str) -> str:
"""清理文本,去除空格、换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
def _get_url(self, tag: Tag) -> str:
return self.tidy_text(tag.get_text())
async def search(self, query: str, num_results: int) -> list[SearchResult]:
query = urllib.parse.quote(query)
try:
resp = await self._get_next_page(query)
soup = BeautifulSoup(resp, "html.parser")
links = soup.select(self._set_selector("links"))
results = []
for link in links:
# Safely get the title text (select_one may return None)
title_elem = link.select_one(self._set_selector("title"))
title = ""
if title_elem is not None:
title = self.tidy_text(title_elem.get_text())
url_tag = link.select_one(self._set_selector("url"))
snippet = ""
if title and url_tag:
url = self._get_url(url_tag)
results.append(SearchResult(title=title, url=url, snippet=snippet))
return results[:num_results] if len(results) > num_results else results
except Exception as e:
raise e

View File

@@ -0,0 +1,30 @@
from . import USER_AGENT_BING, SearchEngine
class Bing(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.base_urls = ["https://cn.bing.com", "https://www.bing.com"]
self.headers.update({"User-Agent": USER_AGENT_BING})
def _set_selector(self, selector: str):
selectors = {
"url": "div.b_attribution cite",
"title": "h2",
"text": "p",
"links": "ol#b_results > li.b_algo",
"next": 'div#b_content nav[role="navigation"] a.sb_pagN',
}
return selectors[selector]
async def _get_next_page(self, query) -> str:
# if self.page == 1:
# await self._get_html(self.base_url)
for base_url in self.base_urls:
try:
url = f"{base_url}/search?q={query}"
return await self._get_html(url, None)
except Exception as _:
self.base_url = base_url
continue
raise Exception("Bing search failed")

View File

@@ -0,0 +1,52 @@
import random
import re
from typing import cast
from bs4 import BeautifulSoup, Tag
from . import USER_AGENTS, SearchEngine, SearchResult
class Sogo(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.base_url = "https://www.sogou.com"
self.headers["User-Agent"] = random.choice(USER_AGENTS)
def _set_selector(self, selector: str):
selectors = {
"url": "h3 > a",
"title": "h3",
"text": "",
"links": "div.results > div.vrwrap:not(.middle-better-hintBox)",
"next": "",
}
return selectors[selector]
async def _get_next_page(self, query) -> str:
url = f"{self.base_url}/web?query={query}"
return await self._get_html(url, None)
def _get_url(self, tag: Tag) -> str:
return cast(str, tag.get("href"))
async def search(self, query: str, num_results: int) -> list[SearchResult]:
results = await super().search(query, num_results)
for result in results:
if result.url.startswith("/link?"):
result.url = self.base_url + result.url
result.url = await self._parse_url(result.url)
return results
async def _parse_url(self, url) -> str:
html = await self._get_html(url)
soup = BeautifulSoup(html, "html.parser")
script = soup.find("script")
if script:
script_text = (
script.string if script.string is not None else script.get_text()
)
match = re.search(r'window.location.replace\("(.+?)"\)', script_text)
if match:
url = match.group(1)
return url

View File

@@ -0,0 +1,611 @@
import asyncio
import json
import random
import uuid
import aiohttp
from bs4 import BeautifulSoup
from readability import Document
from astrbot.api import AstrBotConfig, llm_tool, logger, sp, star
from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.provider import ProviderRequest
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from .engines import HEADERS, USER_AGENTS, SearchResult
from .engines.bing import Bing
from .engines.sogo import Sogo
class Main(star.Star):
TOOLS = [
"web_search",
"fetch_url",
"web_search_tavily",
"tavily_extract_web_page",
"web_search_bocha",
]
def __init__(self, context: star.Context) -> None:
self.context = context
self.tavily_key_index = 0
self.tavily_key_lock = asyncio.Lock()
self.bocha_key_index = 0
self.bocha_key_lock = asyncio.Lock()
# 将 str 类型的 key 迁移至 list[str],并保存
cfg = self.context.get_config()
provider_settings = cfg.get("provider_settings")
if provider_settings:
tavily_key = provider_settings.get("websearch_tavily_key")
if isinstance(tavily_key, str):
logger.info(
"检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存。",
)
if tavily_key:
provider_settings["websearch_tavily_key"] = [tavily_key]
else:
provider_settings["websearch_tavily_key"] = []
cfg.save_config()
bocha_key = provider_settings.get("websearch_bocha_key")
if isinstance(bocha_key, str):
if bocha_key:
provider_settings["websearch_bocha_key"] = [bocha_key]
else:
provider_settings["websearch_bocha_key"] = []
cfg.save_config()
self.bing_search = Bing()
self.sogo_search = Sogo()
self.baidu_initialized = False
async def _tidy_text(self, text: str) -> str:
"""清理文本,去除空格、换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
async def _get_from_url(self, url: str) -> str:
"""获取网页内容"""
header = HEADERS
header.update({"User-Agent": random.choice(USER_AGENTS)})
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url, headers=header) as response:
html = await response.text(encoding="utf-8")
doc = Document(html)
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, "html.parser")
ret = await self._tidy_text(soup.get_text())
return ret
async def _process_search_result(
self,
result: SearchResult,
idx: int,
websearch_link: bool,
) -> str:
"""处理单个搜索结果"""
logger.info(f"web_searcher - scraping web: {result.title} - {result.url}")
try:
site_result = await self._get_from_url(result.url)
except BaseException:
site_result = ""
site_result = (
f"{site_result[:700]}..." if len(site_result) > 700 else site_result
)
header = f"{idx}. {result.title} "
if websearch_link and result.url:
header += result.url
return f"{header}\n{result.snippet}\n{site_result}\n\n"
async def _web_search_default(
self,
query,
num_results: int = 5,
) -> list[SearchResult]:
results = []
try:
results = await self.bing_search.search(query, num_results)
except Exception as e:
logger.error(f"bing search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search bing failed")
try:
results = await self.sogo_search.search(query, num_results)
except Exception as e:
logger.error(f"sogo search error: {e}")
if len(results) == 0:
logger.debug("search sogo failed")
return []
return results
async def _get_tavily_key(self, cfg: AstrBotConfig) -> str:
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
tavily_keys = cfg.get("provider_settings", {}).get("websearch_tavily_key", [])
if not tavily_keys:
raise ValueError("错误Tavily API密钥未在AstrBot中配置。")
async with self.tavily_key_lock:
key = tavily_keys[self.tavily_key_index]
self.tavily_key_index = (self.tavily_key_index + 1) % len(tavily_keys)
return key
async def _web_search_tavily(
self,
cfg: AstrBotConfig,
payload: dict,
) -> list[SearchResult]:
"""使用 Tavily 搜索引擎进行搜索"""
tavily_key = await self._get_tavily_key(cfg)
url = "https://api.tavily.com/search"
header = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
url,
json=payload,
headers=header,
) as response:
if response.status != 200:
reason = await response.text()
raise Exception(
f"Tavily web search failed: {reason}, status: {response.status}",
)
data = await response.json()
results = []
for item in data.get("results", []):
result = SearchResult(
title=item.get("title"),
url=item.get("url"),
snippet=item.get("content"),
favicon=item.get("favicon"),
)
results.append(result)
return results
async def _extract_tavily(self, cfg: AstrBotConfig, payload: dict) -> list[dict]:
"""使用 Tavily 提取网页内容"""
tavily_key = await self._get_tavily_key(cfg)
url = "https://api.tavily.com/extract"
header = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
url,
json=payload,
headers=header,
) as response:
if response.status != 200:
reason = await response.text()
raise Exception(
f"Tavily web search failed: {reason}, status: {response.status}",
)
data = await response.json()
results: list[dict] = data.get("results", [])
if not results:
raise ValueError(
"Error: Tavily web searcher does not return any results.",
)
return results
@llm_tool(name="web_search")
async def search_from_search_engine(
self,
event: AstrMessageEvent,
query: str,
max_results: int = 5,
) -> str:
"""搜索网络以回答用户的问题。当用户需要搜索网络以获取即时性的信息时调用此工具。
Args:
query(string): 和用户的问题最相关的搜索关键词,用于在 Google 上搜索。
max_results(number): 返回的最大搜索结果数量,默认为 5。
"""
logger.info(f"web_searcher - search_from_search_engine: {query}")
cfg = self.context.get_config(umo=event.unified_msg_origin)
websearch_link = cfg["provider_settings"].get("web_search_link", False)
results = await self._web_search_default(query, max_results)
if not results:
return "Error: web searcher does not return any results."
tasks = []
for idx, result in enumerate(results, 1):
task = self._process_search_result(result, idx, websearch_link)
tasks.append(task)
processed_results = await asyncio.gather(*tasks, return_exceptions=True)
ret = ""
for processed_result in processed_results:
if isinstance(processed_result, BaseException):
logger.error(f"Error processing search result: {processed_result}")
continue
ret += processed_result
if websearch_link:
ret += "\n\n针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
return ret
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None:
if self.baidu_initialized:
return
cfg = self.context.get_config(umo=umo)
key = cfg.get("provider_settings", {}).get(
"websearch_baidu_app_builder_key",
"",
)
if not key:
raise ValueError(
"Error: Baidu AI Search API key is not configured in AstrBot.",
)
func_tool_mgr = self.context.get_llm_tool_manager()
await func_tool_mgr.enable_mcp_server(
"baidu_ai_search",
config={
"transport": "sse",
"url": f"http://appbuilder.baidu.com/v2/ai_search/mcp/sse?api_key={key}",
"headers": {},
"timeout": 600,
},
)
self.baidu_initialized = True
logger.info("Successfully initialized Baidu AI Search MCP server.")
@llm_tool(name="fetch_url")
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
"""Fetch the content of a website with the given web url
Args:
url(string): The url of the website to fetch content from
"""
resp = await self._get_from_url(url)
return resp
@llm_tool("web_search_tavily")
async def search_from_tavily(
self,
event: AstrMessageEvent,
query: str,
max_results: int = 7,
search_depth: str = "basic",
topic: str = "general",
days: int = 3,
time_range: str = "",
start_date: str = "",
end_date: str = "",
) -> str:
"""A web search tool that uses Tavily to search the web for relevant content.
Ideal for gathering current information, news, and detailed web content analysis.
Args:
query(string): Required. Search query.
max_results(number): Optional. The maximum number of results to return. Default is 7. Range is 5-20.
search_depth(string): Optional. The depth of the search, must be one of 'basic', 'advanced'. Default is "basic".
topic(string): Optional. The topic of the search, must be one of 'general', 'news'. Default is "general".
days(number): Optional. The number of days back from the current date to include in the search results. Please note that this feature is only available when using the 'news' search topic.
time_range(string): Optional. The time range back from the current date to include in the search results. This feature is available for both 'general' and 'news' search topics. Must be one of 'day', 'week', 'month', 'year'.
start_date(string): Optional. The start date for the search results in the format 'YYYY-MM-DD'.
end_date(string): Optional. The end date for the search results in the format 'YYYY-MM-DD'.
"""
logger.info(f"web_searcher - search_from_tavily: {query}")
cfg = self.context.get_config(umo=event.unified_msg_origin)
# websearch_link = cfg["provider_settings"].get("web_search_link", False)
if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []):
raise ValueError("Error: Tavily API key is not configured in AstrBot.")
# build payload
payload = {"query": query, "max_results": max_results, "include_favicon": True}
if search_depth not in ["basic", "advanced"]:
search_depth = "basic"
payload["search_depth"] = search_depth
if topic not in ["general", "news"]:
topic = "general"
payload["topic"] = topic
if topic == "news":
payload["days"] = days
if time_range in ["day", "week", "month", "year"]:
payload["time_range"] = time_range
if start_date:
payload["start_date"] = start_date
if end_date:
payload["end_date"] = end_date
results = await self._web_search_tavily(cfg, payload)
if not results:
return "Error: Tavily web searcher does not return any results."
ret_ls = []
ref_uuid = str(uuid.uuid4())[:4]
for idx, result in enumerate(results, 1):
index = f"{ref_uuid}.{idx}"
ret_ls.append(
{
"title": f"{result.title}",
"url": f"{result.url}",
"snippet": f"{result.snippet}",
# TODO: do not need ref for non-webchat platform adapter
"index": index,
}
)
if result.favicon:
sp.temporary_cache["_ws_favicon"][result.url] = result.favicon
# ret = "\n".join(ret_ls)
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
return ret
@llm_tool("tavily_extract_web_page")
async def tavily_extract_web_page(
self,
event: AstrMessageEvent,
url: str = "",
extract_depth: str = "basic",
) -> str:
"""Extract the content of a web page using Tavily.
Args:
url(string): Required. An URl to extract content from.
extract_depth(string): Optional. The depth of the extraction, must be one of 'basic', 'advanced'. Default is "basic".
"""
cfg = self.context.get_config(umo=event.unified_msg_origin)
if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []):
raise ValueError("Error: Tavily API key is not configured in AstrBot.")
if not url:
raise ValueError("Error: url must be a non-empty string.")
if extract_depth not in ["basic", "advanced"]:
extract_depth = "basic"
payload = {
"urls": [url],
"extract_depth": extract_depth,
}
results = await self._extract_tavily(cfg, payload)
ret_ls = []
for result in results:
ret_ls.append(f"URL: {result.get('url', 'No URL')}")
ret_ls.append(f"Content: {result.get('raw_content', 'No content')}")
ret = "\n".join(ret_ls)
if not ret:
return "Error: Tavily web searcher does not return any results."
return ret
async def _get_bocha_key(self, cfg: AstrBotConfig) -> str:
"""并发安全的从列表中获取并轮换BoCha API密钥。"""
bocha_keys = cfg.get("provider_settings", {}).get("websearch_bocha_key", [])
if not bocha_keys:
raise ValueError("错误BoCha API密钥未在AstrBot中配置。")
async with self.bocha_key_lock:
key = bocha_keys[self.bocha_key_index]
self.bocha_key_index = (self.bocha_key_index + 1) % len(bocha_keys)
return key
async def _web_search_bocha(
self,
cfg: AstrBotConfig,
payload: dict,
) -> list[SearchResult]:
"""使用 BoCha 搜索引擎进行搜索"""
bocha_key = await self._get_bocha_key(cfg)
url = "https://api.bochaai.com/v1/web-search"
header = {
"Authorization": f"Bearer {bocha_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
url,
json=payload,
headers=header,
) as response:
if response.status != 200:
reason = await response.text()
raise Exception(
f"BoCha web search failed: {reason}, status: {response.status}",
)
data = await response.json()
data = data["data"]["webPages"]["value"]
results = []
for item in data:
result = SearchResult(
title=item.get("name"),
url=item.get("url"),
snippet=item.get("snippet"),
favicon=item.get("siteIcon"),
)
results.append(result)
return results
@llm_tool("web_search_bocha")
async def search_from_bocha(
self,
event: AstrMessageEvent,
query: str,
freshness: str = "noLimit",
summary: bool = False,
include: str = "",
exclude: str = "",
count: int = 10,
) -> str:
"""
A web search tool based on Bocha Search API, used to retrieve web pages
related to the user's query.
Args:
query (string): Required. User's search query.
freshness (string): Optional. Specifies the time range of the search.
Supported values:
- "noLimit": No time limit (default, recommended).
- "oneDay": Within one day.
- "oneWeek": Within one week.
- "oneMonth": Within one month.
- "oneYear": Within one year.
- "YYYY-MM-DD..YYYY-MM-DD": Search within a specific date range.
Example: "2025-01-01..2025-04-06".
- "YYYY-MM-DD": Search on a specific date.
Example: "2025-04-06".
It is recommended to use "noLimit", as the search algorithm will
automatically optimize time relevance. Manually restricting the
time range may result in no search results.
summary (boolean): Optional. Whether to include a text summary
for each search result.
- True: Include summary.
- False: Do not include summary (default).
include (string): Optional. Specifies the domains to include in
the search. Multiple domains can be separated by "|" or ",".
A maximum of 100 domains is allowed.
Examples:
- "qq.com"
- "qq.com|m.163.com"
exclude (string): Optional. Specifies the domains to exclude from
the search. Multiple domains can be separated by "|" or ",".
A maximum of 100 domains is allowed.
Examples:
- "qq.com"
- "qq.com|m.163.com"
count (number): Optional. Number of search results to return.
- Range: 150
- Default: 10
The actual number of returned results may be less than the
specified count.
"""
logger.info(f"web_searcher - search_from_bocha: {query}")
cfg = self.context.get_config(umo=event.unified_msg_origin)
# websearch_link = cfg["provider_settings"].get("web_search_link", False)
if not cfg.get("provider_settings", {}).get("websearch_bocha_key", []):
raise ValueError("Error: BoCha API key is not configured in AstrBot.")
# build payload
payload = {
"query": query,
"count": count,
}
# freshness时间范围
if freshness:
payload["freshness"] = freshness
# 是否返回摘要
payload["summary"] = summary
# include限制搜索域
if include:
payload["include"] = include
# exclude排除搜索域
if exclude:
payload["exclude"] = exclude
results = await self._web_search_bocha(cfg, payload)
if not results:
return "Error: BoCha web searcher does not return any results."
ret_ls = []
ref_uuid = str(uuid.uuid4())[:4]
for idx, result in enumerate(results, 1):
index = f"{ref_uuid}.{idx}"
ret_ls.append(
{
"title": f"{result.title}",
"url": f"{result.url}",
"snippet": f"{result.snippet}",
"index": index,
}
)
if result.favicon:
sp.temporary_cache["_ws_favicon"][result.url] = result.favicon
# ret = "\n".join(ret_ls)
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
return ret
@filter.on_llm_request(priority=-10000)
async def edit_web_search_tools(
self,
event: AstrMessageEvent,
req: ProviderRequest,
) -> None:
"""Get the session conversation for the given event."""
cfg = self.context.get_config(umo=event.unified_msg_origin)
prov_settings = cfg.get("provider_settings", {})
websearch_enable = prov_settings.get("web_search", False)
provider = prov_settings.get("websearch_provider", "default")
tool_set = req.func_tool
if isinstance(tool_set, FunctionToolManager):
req.func_tool = tool_set.get_full_tool_set()
tool_set = req.func_tool
if not tool_set:
return
if not websearch_enable:
# pop tools
for tool_name in self.TOOLS:
tool_set.remove_tool(tool_name)
return
func_tool_mgr = self.context.get_llm_tool_manager()
if provider == "default":
web_search_t = func_tool_mgr.get_func("web_search")
fetch_url_t = func_tool_mgr.get_func("fetch_url")
if web_search_t:
tool_set.add_tool(web_search_t)
if fetch_url_t:
tool_set.add_tool(fetch_url_t)
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")
tool_set.remove_tool("AIsearch")
tool_set.remove_tool("web_search_bocha")
elif provider == "tavily":
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
if web_search_tavily:
tool_set.add_tool(web_search_tavily)
if tavily_extract_web_page:
tool_set.add_tool(tavily_extract_web_page)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("AIsearch")
tool_set.remove_tool("web_search_bocha")
elif provider == "baidu_ai_search":
try:
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
aisearch_tool = func_tool_mgr.get_func("AIsearch")
if not aisearch_tool:
raise ValueError("Cannot get Baidu AI Search MCP tool.")
tool_set.add_tool(aisearch_tool)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")
tool_set.remove_tool("web_search_bocha")
except Exception as e:
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")
elif provider == "bocha":
web_search_bocha = func_tool_mgr.get_func("web_search_bocha")
if web_search_bocha:
tool_set.add_tool(web_search_bocha)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
tool_set.remove_tool("AIsearch")
tool_set.remove_tool("web_search_tavily")
tool_set.remove_tool("tavily_extract_web_page")

View File

@@ -0,0 +1,4 @@
name: astrbot-web-searcher
desc: 让 LLM 具有网页检索能力
author: Soulter
version: 1.14.514

View File

@@ -1 +1 @@
__version__ = "4.23.5"
__version__ = "4.19.5"

View File

@@ -84,7 +84,7 @@ def new(name: str) -> None:
# Rewrite README.md
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
f.write(
f"# {name}\n\n{desc}\n\n# Support\n\n[Documentation](https://docs.astrbot.app)\n"
f"# {name}\n\n{desc}\n\n# Support\n\n[Documentation](https://astrbot.app)\n"
)
# Rewrite main.py

View File

@@ -1,7 +1,7 @@
import json
from typing import Protocol, runtime_checkable
from ..message import AudioURLPart, ImageURLPart, Message, TextPart, ThinkPart
from ..message import Message, TextPart
@runtime_checkable
@@ -28,19 +28,9 @@ class TokenCounter(Protocol):
...
# 图片/音频 token 开销估算值,参考 OpenAI vision pricing:
# low-res ~85 tokens, high-res ~170 per 512px tile, 通常几百到上千。
# 这里取一个保守中位数,宁可偏高触发压缩也不要偏低导致 API 报错。
IMAGE_TOKEN_ESTIMATE = 765
AUDIO_TOKEN_ESTIMATE = 500
class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.
Supports multimodal content: images, audio, and thinking parts
are all counted so that the context compressor can trigger in time.
"""
def count_tokens(
@@ -55,16 +45,12 @@ class EstimateTokenCounter:
if isinstance(content, str):
total += self._estimate_tokens(content)
elif isinstance(content, list):
# 处理多模态内容
for part in content:
if isinstance(part, TextPart):
total += self._estimate_tokens(part.text)
elif isinstance(part, ThinkPart):
total += self._estimate_tokens(part.think)
elif isinstance(part, ImageURLPart):
total += IMAGE_TOKEN_ESTIMATE
elif isinstance(part, AudioURLPart):
total += AUDIO_TOKEN_ESTIMATE
# 处理 Tool Calls
if msg.tool_calls:
for tc in msg.tool_calls:
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())

View File

@@ -12,50 +12,14 @@ class ContextTruncator:
and len(message.tool_calls) > 0
)
@staticmethod
def _split_system_rest(
messages: list[Message],
) -> tuple[list[Message], list[Message]]:
"""Split messages into system messages and the rest.
Returns:
tuple: (system_messages, non_system_messages)
"""
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
return messages[:first_non_system], messages[first_non_system:]
@staticmethod
def _ensure_user_message(
system_messages: list[Message],
truncated: list[Message],
original_messages: list[Message],
) -> list[Message]:
"""Ensure the result always contains the first user message right after
system messages. This is required by many LLM APIs (e.g. Zhipu) that
mandate a ``user`` message immediately following the ``system`` message.
"""
if truncated and truncated[0].role == "user":
return system_messages + truncated
# Locate the first user message from the *original* list.
first_user = next((m for m in original_messages if m.role == "user"), None)
if first_user is None:
return system_messages + truncated
return system_messages + [first_user] + truncated
def fix_messages(self, messages: list[Message]) -> list[Message]:
"""Fix the message list to ensure the validity of tool call and tool response pairing.
"""修复消息列表,确保 tool call tool response 的配对关系有效。
This method ensures that:
1. Each `tool` message is preceded by an `assistant` message containing `tool_calls`.
2. Each `assistant` message containing `tool_calls` is followed by corresponding `
此方法确保:
1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息
2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应
This is a requirement of the OpenAI Chat Completions API specification (Gemini enforces this strictly).
这是 OpenAI Chat Completions API 规范的要求Gemini 对此执行严格检查)。
"""
if not messages:
return messages
@@ -74,25 +38,24 @@ class ContextTruncator:
for msg in messages:
if msg.role == "tool":
# Only record tool responses when there is a pending assistant(tool_calls)
# 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应
if pending_assistant is not None:
pending_tools.append(msg)
# Isolated tool messages without a preceding assistant(tool_calls) are ignored
# else: 孤立的 tool 消息,直接忽略
continue
if self._has_tool_calls(msg):
# When encountering a new assistant(tool_calls), first process the old pending chain
# 遇到新的 assistant(tool_calls) 前,先处理旧的 pending
flush_pending_if_valid()
pending_assistant = msg
continue
# Non-tool messages that do not contain tool_calls will break the pending chain.
# Flush any pending chain first, then append the current message normally.
# 非 tool且不含 tool_calls 的消息
# 先结束任何 pending 链,再正常追加
flush_pending_if_valid()
fixed_messages.append(msg)
# Flush the last pending chain at the end,
# ensuring that any remaining valid assistant(tool_calls) and its tools are included in the final list.
# 结束时处理最后一个 pending 链
flush_pending_if_valid()
return fixed_messages
@@ -103,23 +66,29 @@ class ContextTruncator:
keep_most_recent_turns: int,
drop_turns: int = 1,
) -> list[Message]:
"""
Turn-based truncation strategy, which drops the oldest turns while keeping the most recent N turns.
A turn consists of a user message and an assistant message.
This method ensures that the truncated context list conforms to OpenAI's context format.
"""截断上下文列表,确保不超过最大长度。
一个 turn 包含一个 user 消息和一个 assistant 消息。
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
Args:
messages: The original list of messages in the context.
keep_most_recent_turns: The number of most recent turns to keep. If set to -1, it means keeping all turns (no truncation).
drop_turns: The number of turns to drop from the beginning.
messages: 上下文列表
keep_most_recent_turns: 保留最近的对话轮数
drop_turns: 一次性丢弃的对话轮数
Returns:
The truncated list of messages.
截断后的上下文列表
"""
if keep_most_recent_turns == -1:
return messages
system_messages, non_system_messages = self._split_system_rest(messages)
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
@@ -130,7 +99,7 @@ class ContextTruncator:
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
# Find the first user message
# 找到第一个 role 为 user 的索引,确保上下文格式正确
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
@@ -138,9 +107,8 @@ class ContextTruncator:
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
result = self._ensure_user_message(
system_messages, truncated_contexts, messages
)
result = system_messages + truncated_contexts
return self.fix_messages(result)
def truncate_by_dropping_oldest_turns(
@@ -148,39 +116,53 @@ class ContextTruncator:
messages: list[Message],
drop_turns: int = 1,
) -> list[Message]:
"""Drop the oldest N turns, regardless of the number of turns to keep."""
"""丢弃最旧的 N 个对话轮次。"""
if drop_turns <= 0:
return messages
system_messages, non_system_messages = self._split_system_rest(messages)
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
# Find the first user message
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []
result = system_messages + truncated_non_system
result = self._ensure_user_message(
system_messages, truncated_non_system, messages
)
return self.fix_messages(result)
def truncate_by_halving(
self,
messages: list[Message],
) -> list[Message]:
"""Halve the number of messages, keeping the most recent ones."""
"""对半砍策略,删除 50% 的消息"""
if len(messages) <= 2:
return messages
system_messages, non_system_messages = self._split_system_rest(messages)
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
@@ -188,7 +170,6 @@ class ContextTruncator:
truncated_non_system = non_system_messages[messages_to_delete:]
# Find the first user message
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
@@ -196,7 +177,6 @@ class ContextTruncator:
if index is not None:
truncated_non_system = truncated_non_system[index:]
result = self._ensure_user_message(
system_messages, truncated_non_system, messages
)
result = system_messages + truncated_non_system
return self.fix_messages(result)

View File

@@ -15,6 +15,7 @@ class HandoffTool(FunctionTool, Generic[TContext]):
tool_description: str | None = None,
**kwargs,
) -> None:
# Avoid passing duplicate `description` to the FunctionTool dataclass.
# Some call sites (e.g. SubAgentOrchestrator) pass `description` via kwargs
# to override what the main agent sees, while we also compute a default
@@ -61,4 +62,4 @@ class HandoffTool(FunctionTool, Generic[TContext]):
def default_description(self, agent_name: str | None) -> str:
agent_name = agent_name or "another"
return f"Delegate tasks to {agent_name} agent to handle the request."
return f"Delegate tasks to {self.name} agent to handle the request."

View File

@@ -1,13 +1,8 @@
import asyncio
import copy
import logging
import os
import re
import sys
from contextlib import AsyncExitStack
from datetime import timedelta
from pathlib import Path, PureWindowsPath
from typing import Any, Generic
from typing import Generic
from tenacity import (
before_sleep_log,
@@ -24,75 +19,6 @@ from astrbot.core.utils.log_pipe import LogPipe
from .run_context import TContext
from .tool import FunctionTool
_DEFAULT_STDIO_COMMAND_ALLOWLIST = frozenset(
{
"python",
"python3",
"py",
"node",
"npx",
"npm",
"pnpm",
"yarn",
"bun",
"bunx",
"deno",
"uv",
"uvx",
}
)
_DENIED_STDIO_COMMANDS = frozenset(
{
"bash",
"sh",
"zsh",
"fish",
"cmd",
"cmd.exe",
"powershell",
"powershell.exe",
"pwsh",
"pwsh.exe",
"osascript",
"open",
"curl",
"wget",
"nc",
"netcat",
"telnet",
"ssh",
"scp",
"rm",
"mv",
"cp",
"dd",
"mkfs",
"sudo",
"su",
"chmod",
"chown",
"kill",
"killall",
"shutdown",
"reboot",
"poweroff",
"halt",
}
)
_SHELL_META_RE = re.compile(r"[\r\n\x00;&|<>`$]")
_PYTHON_INLINE_CODE_FLAGS = frozenset({"-c"})
_JS_INLINE_CODE_FLAGS = frozenset({"-e", "--eval", "-p", "--print"})
_DENIED_DOCKER_ARGS = frozenset(
{
"--privileged",
"--pid=host",
"--network=host",
"--net=host",
"--ipc=host",
}
)
_STDIO_ALLOWLIST_ENV = "ASTRBOT_MCP_STDIO_ALLOWED_COMMANDS"
try:
import anyio
import mcp
@@ -114,156 +40,11 @@ def _prepare_config(config: dict) -> dict:
"""Prepare configuration, handle nested format"""
if config.get("mcpServers"):
first_key = next(iter(config["mcpServers"]))
config = dict(config["mcpServers"][first_key])
else:
config = dict(config)
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
def _normalize_stdio_command_name(command: str) -> str:
command = command.strip()
if "\\" in command:
command_name = PureWindowsPath(command).name
else:
command_name = Path(command).name
command_name = command_name.lower()
for suffix in (".exe", ".cmd", ".bat"):
if command_name.endswith(suffix):
return command_name[: -len(suffix)]
return command_name
def _get_stdio_command_allowlist() -> set[str]:
allowed = set(_DEFAULT_STDIO_COMMAND_ALLOWLIST)
configured = os.environ.get(_STDIO_ALLOWLIST_ENV, "")
if configured.strip():
allowed = {
_normalize_stdio_command_name(item)
for item in configured.split(",")
if item.strip()
}
return allowed
def _is_stdio_config(config: dict) -> bool:
cfg = _prepare_config(config.copy())
return "url" not in cfg
def _validate_stdio_args(command_name: str, args: object) -> None:
if args is None:
return
if not isinstance(args, list) or not all(isinstance(arg, str) for arg in args):
raise ValueError("MCP stdio args must be a list of strings.")
for arg in args:
if "\x00" in arg or "\r" in arg or "\n" in arg:
raise ValueError("MCP stdio args cannot contain control characters.")
if command_name.startswith("python") or command_name == "py":
if any(
arg == "-c"
or (arg.startswith("-") and not arg.startswith("--") and "c" in arg)
for arg in args
):
raise ValueError(
"MCP stdio Python servers must be launched from a module or file; inline code flags such as -c are not allowed."
)
elif command_name in {"node", "deno", "bun"} or command_name.startswith("node"):
if any(
arg in _JS_INLINE_CODE_FLAGS
or arg == "eval"
or (
arg.startswith("-")
and not arg.startswith("--")
and any(c in arg for c in "ep")
)
for arg in args
):
raise ValueError(
"MCP stdio JavaScript servers must be launched from a package or file; inline eval flags are not allowed."
)
elif command_name == "docker":
denied = []
for i, arg in enumerate(args):
if arg in _DENIED_DOCKER_ARGS:
denied.append(arg)
elif (
arg in {"--network", "--net", "--pid", "--ipc"}
and i + 1 < len(args)
and args[i + 1] == "host"
):
denied.append(f"{arg} {args[i + 1]}")
if denied:
raise ValueError(
f"MCP stdio Docker args are unsafe and not allowed: {', '.join(denied)}."
)
def validate_mcp_stdio_config(config: dict) -> None:
"""Validate stdio MCP config before any subprocess can be spawned."""
cfg = _prepare_config(config.copy())
if "url" in cfg:
return
command = cfg.get("command")
if not isinstance(command, str) or not command.strip():
raise ValueError("MCP stdio server requires a non-empty command.")
if _SHELL_META_RE.search(command):
raise ValueError("MCP stdio command contains unsafe shell metacharacters.")
command_name = _normalize_stdio_command_name(command)
if command_name in _DENIED_STDIO_COMMANDS:
raise ValueError(f"MCP stdio command `{command_name}` is not allowed.")
allowed = _get_stdio_command_allowlist()
if command_name not in allowed:
allowed_display = ", ".join(sorted(allowed))
raise ValueError(
f"MCP stdio command `{command_name}` is not allowed. "
f"Allowed commands: {allowed_display}. "
f"Set {_STDIO_ALLOWLIST_ENV} to override this list if you trust another launcher."
)
_validate_stdio_args(command_name, cfg.get("args"))
env = cfg.get("env")
if env is not None and not isinstance(env, dict):
raise ValueError("MCP stdio env must be an object.")
if isinstance(env, dict) and not all(
isinstance(key, str) and isinstance(value, str) for key, value in env.items()
):
raise ValueError("MCP stdio env keys and values must be strings.")
def _prepare_stdio_env(config: dict) -> dict:
"""Preserve Windows executable resolution for stdio subprocesses."""
if sys.platform != "win32":
return config
prepared = config.copy()
env = dict(prepared.get("env") or {})
env = _merge_environment_variables(env)
prepared["env"] = env
return prepared
def _merge_environment_variables(env: dict) -> dict:
"""合并环境变量处理Windows不区分大小写的情况"""
merged = env.copy()
# 将用户环境变量转换为统一的大小写形式便于比较
user_keys_lower = {k.lower(): k for k in merged.keys()}
for sys_key, sys_value in os.environ.items():
sys_key_lower = sys_key.lower()
if sys_key_lower not in user_keys_lower:
# 使用系统环境变量中的原始大小写
merged[sys_key] = sys_value
return merged
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""Quick test MCP server connectivity"""
import aiohttp
@@ -326,61 +107,6 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"{e!s}"
def _normalize_mcp_input_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Normalize common non-standard MCP JSON Schema variants.
Some MCP servers incorrectly mark required properties with a boolean
`required: true` on the property schema itself. Draft 2020-12 requires the
parent object to declare `required` as an array of property names instead.
We lift those booleans to the parent object so the schema remains usable
without disabling validation entirely.
"""
def _normalize(node: Any) -> Any:
if isinstance(node, list):
return [_normalize(item) for item in node]
if not isinstance(node, dict):
return node
normalized = {key: _normalize(value) for key, value in node.items()}
properties = normalized.get("properties")
if isinstance(properties, dict):
original_properties = (
node.get("properties")
if isinstance(node.get("properties"), dict)
else {}
)
required = normalized.get("required")
required_list = required[:] if isinstance(required, list) else []
for prop_name, prop_schema in properties.items():
if not isinstance(prop_schema, dict):
continue
original_prop_schema = original_properties.get(prop_name, {})
prop_required = (
original_prop_schema.get("required")
if isinstance(original_prop_schema, dict)
else None
)
if isinstance(prop_required, bool):
if prop_schema.get("required") is prop_required:
prop_schema.pop("required", None)
if prop_required:
required_list.append(prop_name)
if required_list:
normalized["required"] = list(dict.fromkeys(required_list))
elif isinstance(required, list):
normalized.pop("required", None)
return normalized
return _normalize(copy.deepcopy(schema))
class MCPClient:
def __init__(self) -> None:
# Initialize session and client objects
@@ -488,8 +214,6 @@ class MCPClient:
)
else:
validate_mcp_stdio_config(cfg)
cfg = _prepare_stdio_env(cfg)
server_params = mcp.StdioServerParameters(
**cfg,
)
@@ -658,7 +382,7 @@ class MCPTool(FunctionTool, Generic[TContext]):
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=_normalize_mcp_input_schema(mcp_tool.inputSchema),
parameters=mcp_tool.inputSchema,
)
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client

View File

@@ -7,7 +7,6 @@ from pydantic import (
BaseModel,
GetCoreSchemaHandler,
PrivateAttr,
ValidationError,
model_serializer,
model_validator,
)
@@ -166,15 +165,6 @@ class ToolCallPart(BaseModel):
"""A part of the arguments of the tool call."""
class CheckpointData(BaseModel):
"""Internal checkpoint data for linking LLM turns to platform history."""
id: str
CHECKPOINT_ROLE = "_checkpoint"
class Message(BaseModel):
"""A message in a conversation."""
@@ -183,10 +173,9 @@ class Message(BaseModel):
"user",
"assistant",
"tool",
"_checkpoint",
]
content: str | list[ContentPart] | CheckpointData | None = None
content: str | list[ContentPart] | None = None
"""The content of the message."""
tool_calls: list[ToolCall] | list[dict] | None = None
@@ -196,18 +185,9 @@ class Message(BaseModel):
"""The ID of the tool call."""
_no_save: bool = PrivateAttr(default=False)
_checkpoint_after: CheckpointData | None = PrivateAttr(default=None)
@model_validator(mode="after")
def check_content_required(self):
if self.role == CHECKPOINT_ROLE:
if not isinstance(self.content, CheckpointData):
raise ValueError("checkpoint message content must be CheckpointData")
return self
if isinstance(self.content, CheckpointData):
raise ValueError("CheckpointData is only allowed for role='_checkpoint'")
# assistant + tool_calls is not None: allow content to be None
if self.role == "assistant" and self.tool_calls is not None:
return self
@@ -251,87 +231,3 @@ class SystemMessageSegment(Message):
"""A message segment from the system."""
role: Literal["system"] = "system"
class CheckpointMessageSegment(Message):
"""Internal checkpoint segment for persisted conversation history."""
role: Literal["_checkpoint"] = "_checkpoint"
content: CheckpointData | None = None
def is_checkpoint_message(message: Message | dict) -> bool:
"""Return whether a message is an internal checkpoint."""
if isinstance(message, Message):
return message.role == CHECKPOINT_ROLE
return isinstance(message, dict) and message.get("role") == CHECKPOINT_ROLE
def get_checkpoint_id(message: Message | dict) -> str | None:
"""Return the checkpoint id from an internal checkpoint message."""
if not is_checkpoint_message(message):
return None
content = (
message.content if isinstance(message, Message) else message.get("content")
)
if isinstance(content, CheckpointData):
return content.id
if isinstance(content, dict):
checkpoint_id = content.get("id")
return (
checkpoint_id if isinstance(checkpoint_id, str) and checkpoint_id else None
)
return None
def strip_checkpoint_messages(history: list[dict]) -> list[dict]:
"""Remove internal checkpoint messages from provider-facing history."""
return [message for message in history if not is_checkpoint_message(message)]
def _get_checkpoint_data(message: Message | dict) -> CheckpointData | None:
if not is_checkpoint_message(message):
return None
content = (
message.content if isinstance(message, Message) else message.get("content")
)
if isinstance(content, CheckpointData):
return content
if isinstance(content, dict):
try:
return CheckpointData.model_validate(content)
except ValidationError:
return None
return None
def bind_checkpoint_messages(history: list[dict]) -> list[Message]:
"""Load persisted history and bind checkpoint segments to prior messages."""
messages: list[Message] = []
for item in history:
if is_checkpoint_message(item):
checkpoint = _get_checkpoint_data(item)
if checkpoint is not None and messages:
messages[-1]._checkpoint_after = checkpoint
continue
message = Message.model_validate(item)
if item.get("_no_save"):
message._no_save = True
messages.append(message)
return messages
def dump_messages_with_checkpoints(messages: list[Message]) -> list[dict]:
"""Dump runtime messages and reinsert bound checkpoint segments."""
dumped: list[dict] = []
for message in messages:
dumped.append(message.model_dump())
if message._checkpoint_after is not None:
dumped.append(
CheckpointMessageSegment(content=message._checkpoint_after).model_dump()
)
return dumped

View File

@@ -16,7 +16,7 @@ class ContextWrapper(Generic[TContext]):
context: TContext
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 120 # Default tool call timeout in seconds
tool_call_timeout: int = 60 # Default tool call timeout in seconds
NoContext = ContextWrapper[None]

View File

@@ -13,7 +13,6 @@ from astrbot.core.provider.entities import (
)
from ...hooks import BaseAgentRunHooks
from ...message import is_checkpoint_message
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
@@ -149,8 +148,6 @@ class CozeAgentRunner(BaseAgentRunner[TContext]):
# 处理历史上下文
if not self.auto_save_history and contexts:
for ctx in contexts:
if is_checkpoint_message(ctx):
continue
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
# 处理上下文中的图片
content = ctx["content"]

View File

@@ -410,20 +410,18 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
)
return messages
def _build_runtime_configurable(self, thread_id: str) -> dict[str, T.Any]:
runtime_configurable: dict[str, T.Any] = {
def _build_runtime_context(self, thread_id: str) -> dict[str, T.Any]:
runtime_context: dict[str, T.Any] = {
"thread_id": thread_id,
"thinking_enabled": self.thinking_enabled,
"is_plan_mode": self.plan_mode,
"subagent_enabled": self.subagent_enabled,
}
if self.subagent_enabled:
runtime_configurable["max_concurrent_subagents"] = (
self.max_concurrent_subagents
)
runtime_context["max_concurrent_subagents"] = self.max_concurrent_subagents
if self.model_name:
runtime_configurable["model_name"] = self.model_name
return runtime_configurable
runtime_context["model_name"] = self.model_name
return runtime_context
def _build_payload(
self,
@@ -432,19 +430,16 @@ class DeerFlowAgentRunner(BaseAgentRunner[TContext]):
image_urls: list[str],
system_prompt: str | None,
) -> dict[str, T.Any]:
runtime_configurable = self._build_runtime_configurable(thread_id)
return {
"assistant_id": self.assistant_id,
"input": {
"messages": self._build_messages(prompt, image_urls, system_prompt),
},
"stream_mode": ["values", "messages-tuple", "custom"],
# DeerFlow 2.0 consumes runtime overrides from config.configurable.
# Keep the legacy context mirror for older compat paths.
"context": dict(runtime_configurable),
# LangGraph 0.6+ prefers context instead of configurable.
"context": self._build_runtime_context(thread_id),
"config": {
"recursion_limit": self.recursion_limit,
"configurable": runtime_configurable,
},
}

View File

@@ -10,33 +10,6 @@ from astrbot.core import logger
SSE_MAX_BUFFER_CHARS = 1_048_576
class DeerFlowAPIError(Exception):
def __init__(
self,
*,
operation: str,
status: int,
body: str,
url: str,
thread_id: str | None = None,
) -> None:
self.operation = operation
self.status = status
self.body = body
self.url = url
self.thread_id = thread_id
message = (
f"DeerFlow {operation} failed: status={status}, url={url}, body={body}"
)
if thread_id is not None:
message = (
f"DeerFlow {operation} failed: thread_id={thread_id}, "
f"status={status}, url={url}, body={body}"
)
super().__init__(message)
def _normalize_sse_newlines(text: str) -> str:
"""Normalize CRLF/CR to LF so SSE block splitting works reliably."""
return text.replace("\r\n", "\n").replace("\r", "\n")
@@ -179,33 +152,11 @@ class DeerFlowAPIClient:
) as resp:
if resp.status not in (200, 201):
text = await resp.text()
raise DeerFlowAPIError(
operation="create thread",
status=resp.status,
body=text,
url=url,
raise Exception(
f"DeerFlow create thread failed: {resp.status}. {text}",
)
return await resp.json()
async def delete_thread(self, thread_id: str, timeout: float = 20) -> None:
session = self._get_session()
url = f"{self.api_base}/api/threads/{thread_id}"
async with session.delete(
url,
headers=self.headers,
timeout=timeout,
proxy=self.proxy,
) as resp:
if resp.status not in (200, 202, 204, 404):
text = await resp.text()
raise DeerFlowAPIError(
operation="delete thread",
status=resp.status,
body=text,
url=url,
thread_id=thread_id,
)
async def stream_run(
self,
thread_id: str,
@@ -249,12 +200,8 @@ class DeerFlowAPIClient:
) as resp:
if resp.status != 200:
text = await resp.text()
raise DeerFlowAPIError(
operation="runs/stream request",
status=resp.status,
body=text,
url=url,
thread_id=thread_id,
raise Exception(
f"DeerFlow runs/stream request failed: {resp.status}. {text}",
)
async for event in _stream_sse(resp):
yield event

File diff suppressed because it is too large Load Diff

View File

@@ -89,21 +89,11 @@ class ToolSet:
return len(self.tools) == 0
def add_tool(self, tool: FunctionTool) -> None:
"""Add a tool to the set.
If a tool with the same name already exists:
- Prefer the one that is active (active=True)
- If both have the same active state, use the new one (overwrite)
"""
"""Add a tool to the set."""
# 检查是否已存在同名工具
for i, existing_tool in enumerate(self.tools):
if existing_tool.name == tool.name:
# Use getattr with a default of True for compatibility with tools
# that may not define an `active` attribute (e.g., mocks).
existing_active = bool(getattr(existing_tool, "active", True))
new_active = bool(getattr(tool, "active", True))
# Overwrite if new tool is active, or if existing tool is not active
if new_active or not existing_active:
self.tools[i] = tool
self.tools[i] = tool
return
self.tools.append(tool)
@@ -303,15 +293,8 @@ class ToolSet:
if properties:
result["properties"] = properties
if target_type == "array":
items_schema = schema.get("items")
if isinstance(items_schema, dict):
result["items"] = convert_schema(items_schema)
else:
# Gemini requires array schemas to include an `items` schema.
# JSON Schema allows omitting it, so fall back to a permissive
# string item schema instead of emitting an invalid declaration.
result["items"] = {"type": "string"}
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result

View File

@@ -12,15 +12,6 @@ from astrbot.core.star.star_handler import EventType
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_begin(
self, run_context: ContextWrapper[AstrAgentContext]
) -> None:
await call_event_hook(
run_context.context.event,
EventType.OnAgentBeginEvent,
run_context,
)
async def on_agent_done(self, run_context, llm_response) -> None:
# 执行事件钩子
if llm_response and llm_response.reasoning_content:
@@ -34,12 +25,6 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
EventType.OnLLMResponseEvent,
llm_response,
)
await call_event_hook(
run_context.context.event,
EventType.OnAgentDoneEvent,
run_context,
llm_response,
)
async def on_tool_start(
self,
@@ -74,13 +59,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
platform_name = run_context.context.event.get_platform_name()
if (
platform_name == "webchat"
and tool.name
in [
"web_search_baidu",
"web_search_tavily",
"web_search_bocha",
"web_search_brave",
]
and tool.name in ["web_search_tavily", "web_search_bocha"]
and len(run_context.messages) > 0
and tool_result
and len(tool_result.content)

View File

@@ -87,31 +87,6 @@ def _build_tool_result_status_message(
return status_msg
def _should_buffer_llm_result(
buffer_intermediate_messages: bool,
stream_to_general: bool,
agent_runner: AgentRunner,
) -> bool:
return (
buffer_intermediate_messages
and not stream_to_general
and not agent_runner.streaming
)
def _merge_buffered_llm_chains(
buffered_llm_chains: list[MessageChain],
) -> MessageChain | None:
if not buffered_llm_chains:
return None
merged_chain = MessageChain()
for chain in buffered_llm_chains:
merged_chain.chain.extend(chain.chain)
buffered_llm_chains.clear()
return merged_chain
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
@@ -119,17 +94,10 @@ async def run_agent(
show_tool_call_result: bool = False,
stream_to_general: bool = False,
show_reasoning: bool = False,
buffer_intermediate_messages: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
tool_name_by_call_id: dict[str, str] = {}
buffered_llm_chains: list[MessageChain] = []
can_buffer_llm_result = _should_buffer_llm_result(
buffer_intermediate_messages,
stream_to_general,
agent_runner,
)
while step_idx < max_step + 1:
step_idx += 1
@@ -158,17 +126,6 @@ async def run_agent(
agent_runner.request_stop()
if resp.type == "aborted":
if can_buffer_llm_result:
merged_chain = _merge_buffered_llm_chains(buffered_llm_chains)
if merged_chain:
astr_event.set_result(
MessageEventResult(
chain=merged_chain.chain,
result_content_type=ResultContentType.LLM_RESULT,
),
)
yield merged_chain
astr_event.clear_result()
if not stop_watcher.done():
stop_watcher.cancel()
try:
@@ -208,13 +165,8 @@ async def run_agent(
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming and show_tool_use:
# 向下游平台发送 "break" 分段信号(空 MessageChain不携带数据
# 平台适配器收到后会关闭当前流式消息,并在后续文本到来时创建新消息。
# 仅在 show_tool_use 为 True 时才发送:此时紧接着会通过
# astr_event.send() 独立发送工具状态消息(如"🔨 调用工具: xxx"
# 需要分段才能保证消息顺序正确。
# 若 show_tool_use 为 False不会有独立消息插入无需分段。
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
tool_info = _extract_chain_json_data(resp.data["chain"])
@@ -235,21 +187,11 @@ async def run_agent(
)
await astr_event.send(chain)
continue
elif resp.type == "llm_result":
chain = resp.data["chain"]
if chain.type == "reasoning":
# For non-streaming mode, we handle reasoning in astrbot/core/astr_agent_hooks.py.
# For streaming mode, we yield content immediately when received a reasoning chunk but not in here, see below.
continue
if stream_to_general and resp.type == "streaming_delta":
continue
if stream_to_general or not agent_runner.streaming:
if can_buffer_llm_result and resp.type == "llm_result":
buffered_llm_chains.append(resp.data["chain"])
continue
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
@@ -261,7 +203,7 @@ async def run_agent(
result_content_type=content_typ,
),
)
yield resp.data["chain"]
yield
astr_event.clear_result()
elif resp.type == "streaming_delta":
chain = resp.data["chain"]
@@ -269,19 +211,6 @@ async def run_agent(
# display the reasoning content only when configured
continue
yield resp.data["chain"] # MessageChain
if can_buffer_llm_result and agent_runner.done():
merged_chain = _merge_buffered_llm_chains(buffered_llm_chains)
if merged_chain:
astr_event.set_result(
MessageEventResult(
chain=merged_chain.chain,
result_content_type=ResultContentType.LLM_RESULT,
),
)
yield merged_chain
astr_event.clear_result()
if not stop_watcher.done():
stop_watcher.cancel()
try:
@@ -354,7 +283,6 @@ async def run_live_agent(
show_tool_use: bool = True,
show_tool_call_result: bool = False,
show_reasoning: bool = False,
buffer_intermediate_messages: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""Live Mode 的 Agent 运行器,支持流式 TTS
@@ -378,7 +306,6 @@ async def run_live_agent(
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
buffer_intermediate_messages=buffer_intermediate_messages,
):
yield chain
return
@@ -411,7 +338,6 @@ async def run_live_agent(
show_tool_use,
show_tool_call_result,
show_reasoning,
buffer_intermediate_messages,
)
)
@@ -499,7 +425,6 @@ async def _run_agent_feeder(
show_tool_use: bool,
show_tool_call_result: bool,
show_reasoning: bool,
buffer_intermediate_messages: bool,
) -> None:
"""运行 Agent 并将文本输出分句放入队列"""
buffer = ""
@@ -511,7 +436,6 @@ async def _run_agent_feeder(
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
buffer_intermediate_messages=buffer_intermediate_messages,
):
if chain is None:
continue

View File

@@ -19,6 +19,13 @@ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.astr_main_agent_resources import (
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PYTHON_TOOL,
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
from astrbot.core.message.components import Image
@@ -30,18 +37,6 @@ from astrbot.core.message.message_event_result import (
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.tools.computer_tools import (
ExecuteShellTool,
FileDownloadTool,
FileEditTool,
FileReadTool,
FileUploadTool,
FileWriteTool,
GrepTool,
LocalPythonTool,
PythonTool,
)
from astrbot.core.tools.message_tools import SendMessageToUserTool
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.history_saver import persist_agent_history
from astrbot.core.utils.image_ref_utils import is_supported_image_ref
@@ -182,44 +177,18 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
return
@classmethod
def _get_runtime_computer_tools(
cls,
runtime: str,
tool_mgr,
) -> dict[str, FunctionTool]:
def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]:
if runtime == "sandbox":
shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool)
python_tool = tool_mgr.get_builtin_tool(PythonTool)
upload_tool = tool_mgr.get_builtin_tool(FileUploadTool)
download_tool = tool_mgr.get_builtin_tool(FileDownloadTool)
read_tool = tool_mgr.get_builtin_tool(FileReadTool)
write_tool = tool_mgr.get_builtin_tool(FileWriteTool)
edit_tool = tool_mgr.get_builtin_tool(FileEditTool)
grep_tool = tool_mgr.get_builtin_tool(GrepTool)
return {
shell_tool.name: shell_tool,
python_tool.name: python_tool,
upload_tool.name: upload_tool,
download_tool.name: download_tool,
read_tool.name: read_tool,
write_tool.name: write_tool,
edit_tool.name: edit_tool,
grep_tool.name: grep_tool,
EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL,
PYTHON_TOOL.name: PYTHON_TOOL,
FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL,
FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL,
}
if runtime == "local":
shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool)
python_tool = tool_mgr.get_builtin_tool(LocalPythonTool)
read_tool = tool_mgr.get_builtin_tool(FileReadTool)
write_tool = tool_mgr.get_builtin_tool(FileWriteTool)
edit_tool = tool_mgr.get_builtin_tool(FileEditTool)
grep_tool = tool_mgr.get_builtin_tool(GrepTool)
return {
shell_tool.name: shell_tool,
python_tool.name: python_tool,
read_tool.name: read_tool,
write_tool.name: write_tool,
edit_tool.name: edit_tool,
grep_tool.name: grep_tool,
LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL,
}
return {}
@@ -234,15 +203,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
cfg = ctx.get_config(umo=event.unified_msg_origin)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
tool_mgr = (
ctx.get_llm_tool_manager()
if hasattr(ctx, "get_llm_tool_manager")
else llm_tools
)
runtime_computer_tools = cls._get_runtime_computer_tools(
runtime,
tool_mgr,
)
runtime_computer_tools = cls._get_runtime_computer_tools(runtime)
# Keep persona semantics aligned with the main agent: tools=None means
# "all tools", including runtime computer-use tools.
@@ -342,7 +303,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
tools=toolset,
contexts=contexts,
max_steps=agent_max_step,
tool_call_timeout=run_context.tool_call_timeout,
stream=stream,
)
yield mcp.types.CallToolResult(
@@ -521,7 +481,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
)
cron_event.role = event.role
config = MainAgentBuildConfig(
tool_call_timeout=run_context.tool_call_timeout,
tool_call_timeout=3600,
streaming_response=ctx.get_config()
.get("provider_settings", {})
.get("stream", False),
@@ -554,9 +514,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
)
if not req.func_tool:
req.func_tool = ToolSet()
req.func_tool.add_tool(
ctx.get_llm_tool_manager().get_builtin_tool(SendMessageToUserTool)
)
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)
result = await build_main_agent(
event=cron_event, plugin_context=ctx, config=config, req=req

View File

@@ -9,7 +9,6 @@ import platform
import zoneinfo
from collections.abc import Coroutine
from dataclasses import dataclass, field
from pathlib import Path
from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
@@ -21,15 +20,38 @@ from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.astr_agent_run_util import AgentRunner
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.astr_main_agent_resources import (
ANNOTATE_EXECUTION_TOOL,
BROWSER_BATCH_EXEC_TOOL,
BROWSER_EXEC_TOOL,
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
CREATE_SKILL_CANDIDATE_TOOL,
CREATE_SKILL_PAYLOAD_TOOL,
EVALUATE_SKILL_CANDIDATE_TOOL,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
GET_EXECUTION_HISTORY_TOOL,
GET_SKILL_PAYLOAD_TOOL,
KNOWLEDGE_BASE_QUERY_TOOL,
LIST_SKILL_CANDIDATES_TOOL,
LIST_SKILL_RELEASES_TOOL,
LIVE_MODE_SYSTEM_PROMPT,
LLM_SAFETY_MODE_SYSTEM_PROMPT,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PROMOTE_SKILL_CANDIDATE_TOOL,
PYTHON_TOOL,
ROLLBACK_SKILL_RELEASE_TOOL,
RUN_BROWSER_SKILL_TOOL,
SANDBOX_MODE_PROMPT,
SEND_MESSAGE_TO_USER_TOOL,
SYNC_SKILL_RELEASE_TOOL,
TOOL_CALL_PROMPT,
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
retrieve_knowledge_base,
)
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Record, Reply, Video
from astrbot.core.message.components import File, Image, Reply
from astrbot.core.persona_error_reply import (
extract_persona_custom_error_message_from_persona,
set_persona_custom_error_message_on_event,
@@ -37,63 +59,16 @@ from astrbot.core.persona_error_reply import (
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import star_map
from astrbot.core.tools.computer_tools import (
AnnotateExecutionTool,
BrowserBatchExecTool,
BrowserExecTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
ExecuteShellTool,
FileDownloadTool,
FileEditTool,
FileReadTool,
FileUploadTool,
FileWriteTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
GrepTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
LocalPythonTool,
PromoteSkillCandidateTool,
PythonTool,
RollbackSkillReleaseTool,
RunBrowserSkillTool,
SyncSkillReleaseTool,
normalize_umo_for_workspace,
)
from astrbot.core.tools.cron_tools import FutureTaskTool
from astrbot.core.tools.knowledge_base_tools import (
KnowledgeBaseQueryTool,
retrieve_knowledge_base,
)
from astrbot.core.tools.message_tools import SendMessageToUserTool
from astrbot.core.tools.web_search_tools import (
BaiduWebSearchTool,
BochaWebSearchTool,
BraveWebSearchTool,
FirecrawlExtractWebPageTool,
FirecrawlWebSearchTool,
TavilyExtractWebPageTool,
TavilyWebSearchTool,
normalize_legacy_web_search_config,
)
from astrbot.core.utils.astrbot_path import (
get_astrbot_system_tmp_path,
get_astrbot_workspaces_path,
from astrbot.core.tools.cron_tools import (
CREATE_CRON_JOB_TOOL,
DELETE_CRON_JOB_TOOL,
LIST_CRON_JOBS_TOOL,
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.media_utils import (
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
IMAGE_COMPRESS_DEFAULT_QUALITY,
compress_image,
)
from astrbot.core.utils.quoted_message.settings import (
SETTINGS as DEFAULT_QUOTED_MESSAGE_SETTINGS,
)
@@ -238,11 +213,7 @@ async def _apply_kb(
else:
if req.func_tool is None:
req.func_tool = ToolSet()
req.func_tool.add_tool(
plugin_context.get_llm_tool_manager().get_builtin_tool(
KnowledgeBaseQueryTool
)
)
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
async def _apply_file_extract(
@@ -304,54 +275,11 @@ def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None:
req.prompt = f"{prefix}{req.prompt}"
def _get_workspace_path_for_umo(umo: str) -> Path:
normalized_umo = normalize_umo_for_workspace(umo)
return Path(get_astrbot_workspaces_path()) / normalized_umo
def _apply_workspace_extra_prompt(
event: AstrMessageEvent,
req: ProviderRequest,
) -> None:
extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / (
"EXTRA_PROMPT.md"
)
if not extra_prompt_path.is_file():
return
try:
extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip()
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to read workspace extra prompt for umo=%s from %s: %s",
event.unified_msg_origin,
extra_prompt_path,
exc,
)
return
if not extra_prompt:
return
req.system_prompt = (
f"{req.system_prompt or ''}\n"
"[Workspace Extra Prompt]\n"
"The following instructions are loaded from the current workspace "
"`EXTRA_PROMPT.md` file.\n"
f"{extra_prompt}\n"
)
def _apply_local_env_tools(req: ProviderRequest, plugin_context: Context) -> None:
def _apply_local_env_tools(req: ProviderRequest) -> None:
if req.func_tool is None:
req.func_tool = ToolSet()
tool_mgr = plugin_context.get_llm_tool_manager()
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(LocalPythonTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool))
req.func_tool.add_tool(LOCAL_EXECUTE_SHELL_TOOL)
req.func_tool.add_tool(LOCAL_PYTHON_TOOL)
req.system_prompt = f"{req.system_prompt or ''}\n{_build_local_mode_prompt()}\n"
@@ -462,9 +390,14 @@ async def _ensure_persona_and_skills(
persona_tools = None
pid = a.get("persona_id")
if pid:
persona = plugin_context.persona_manager.get_persona_v3_by_id(pid)
if persona is not None:
persona_tools = persona.get("tools")
persona_tools = next(
(
p.get("tools")
for p in plugin_context.persona_manager.personas_v3
if p["name"] == pid
),
None,
)
tools = a.get("tools", [])
if persona_tools is not None:
tools = persona_tools
@@ -545,23 +478,16 @@ async def _request_img_caption(
async def _ensure_img_caption(
event: AstrMessageEvent,
req: ProviderRequest,
cfg: dict,
plugin_context: Context,
image_caption_provider: str,
) -> None:
try:
compressed_urls = []
for url in req.image_urls:
compressed_url = await _compress_image_for_provider(url, cfg)
compressed_urls.append(compressed_url)
if _is_generated_compressed_image_path(url, compressed_url):
event.track_temporary_local_file(compressed_url)
caption = await _request_img_caption(
image_caption_provider,
cfg,
compressed_urls,
req.image_urls,
plugin_context,
)
if caption:
@@ -571,9 +497,6 @@ async def _ensure_img_caption(
req.image_urls = []
except Exception as exc: # noqa: BLE001
logger.error("处理图片描述失败: %s", exc)
req.extra_user_content_parts.append(TextPart(text="[Image Captioning Failed]"))
finally:
req.image_urls = []
def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> None:
@@ -582,45 +505,6 @@ def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> No
)
def _append_audio_attachment(req: ProviderRequest, audio_path: str) -> None:
req.extra_user_content_parts.append(
TextPart(text=f"[Audio Attachment: path {audio_path}]")
)
def _append_quoted_audio_attachment(req: ProviderRequest, audio_path: str) -> None:
req.extra_user_content_parts.append(
TextPart(text=f"[Audio Attachment in quoted message: path {audio_path}]")
)
async def _append_video_attachment(
req: ProviderRequest,
video: Video,
*,
quoted: bool = False,
) -> None:
try:
video_path = await video.convert_to_file_path()
except Exception as exc: # noqa: BLE001
if quoted:
logger.error("Error processing quoted video attachment: %s", exc)
else:
logger.error("Error processing video attachment: %s", exc)
return
video_name = os.path.basename(video_path)
if quoted:
text = (
f"[Video Attachment in quoted message: "
f"name {video_name}, path {video_path}]"
)
else:
text = f"[Video Attachment: name {video_name}, path {video_path}]"
req.extra_user_content_parts.append(TextPart(text=text))
def _get_quoted_message_parser_settings(
provider_settings: dict[str, object] | None,
) -> QuotedMessageParserSettings:
@@ -632,64 +516,12 @@ def _get_quoted_message_parser_settings(
return DEFAULT_QUOTED_MESSAGE_SETTINGS.with_overrides(overrides)
def _get_image_compress_args(
provider_settings: dict[str, object] | None,
) -> tuple[bool, int, int]:
if not isinstance(provider_settings, dict):
return True, IMAGE_COMPRESS_DEFAULT_MAX_SIZE, IMAGE_COMPRESS_DEFAULT_QUALITY
enabled = provider_settings.get("image_compress_enabled", True)
if not isinstance(enabled, bool):
enabled = True
raw_options = provider_settings.get("image_compress_options", {})
options = raw_options if isinstance(raw_options, dict) else {}
max_size = options.get("max_size", IMAGE_COMPRESS_DEFAULT_MAX_SIZE)
if not isinstance(max_size, int):
max_size = IMAGE_COMPRESS_DEFAULT_MAX_SIZE
max_size = max(max_size, 1)
quality = options.get("quality", IMAGE_COMPRESS_DEFAULT_QUALITY)
if not isinstance(quality, int):
quality = IMAGE_COMPRESS_DEFAULT_QUALITY
quality = min(max(quality, 1), 100)
return enabled, max_size, quality
async def _compress_image_for_provider(
url_or_path: str,
provider_settings: dict[str, object] | None,
) -> str:
try:
enabled, max_size, quality = _get_image_compress_args(provider_settings)
if not enabled:
return url_or_path
return await compress_image(url_or_path, max_size=max_size, quality=quality)
except Exception as exc: # noqa: BLE001
logger.error("Image compression failed: %s", exc)
return url_or_path
def _is_generated_compressed_image_path(
original_path: str,
compressed_path: str | None,
) -> bool:
if not compressed_path or compressed_path == original_path:
return False
if compressed_path.startswith("http") or compressed_path.startswith("data:image"):
return False
return os.path.exists(compressed_path)
async def _process_quote_message(
event: AstrMessageEvent,
req: ProviderRequest,
img_cap_prov_id: str,
plugin_context: Context,
quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS,
config: MainAgentBuildConfig | None = None,
) -> None:
quote = None
for comp in event.message_obj.message:
@@ -722,24 +554,15 @@ async def _process_quote_message(
if image_seg:
try:
prov = None
path = None
compress_path = None
if img_cap_prov_id:
prov = plugin_context.get_provider_by_id(img_cap_prov_id)
if prov is None:
prov = plugin_context.get_using_provider(event.unified_msg_origin)
if prov and isinstance(prov, Provider):
path = await image_seg.convert_to_file_path()
compress_path = await _compress_image_for_provider(
path,
config.provider_settings if config else None,
)
if path and _is_generated_compressed_image_path(path, compress_path):
event.track_temporary_local_file(compress_path)
llm_resp = await prov.text_chat(
prompt="Please describe the image content.",
image_urls=[compress_path],
image_urls=[await image_seg.convert_to_file_path()],
)
if llm_resp.completion_text:
content_parts.append(
@@ -749,16 +572,6 @@ async def _process_quote_message(
logger.warning("No provider found for image captioning in quote.")
except BaseException as exc:
logger.error("处理引用图片失败: %s", exc)
finally:
if (
compress_path
and compress_path != path
and os.path.exists(compress_path)
):
try:
os.remove(compress_path)
except Exception as exc: # noqa: BLE001
logger.warning("Fail to remove temporary compressed image: %s", exc)
quoted_content = "\n".join(content_parts)
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
@@ -827,7 +640,6 @@ async def _decorate_llm_request(
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
if img_cap_prov_id and req.image_urls:
await _ensure_img_caption(
event,
req,
cfg,
plugin_context,
@@ -842,14 +654,113 @@ async def _decorate_llm_request(
img_cap_prov_id,
plugin_context,
quoted_message_settings,
config,
)
tz = config.timezone
if tz is None:
tz = plugin_context.get_config().get("timezone")
_append_system_reminders(event, req, cfg, tz)
_apply_workspace_extra_prompt(event, req)
def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
logger.debug(
"Provider %s does not support image, using placeholder.", provider
)
image_count = len(req.image_urls)
placeholder = " ".join(["[图片]"] * image_count)
if req.prompt:
req.prompt = f"{placeholder} {req.prompt}"
else:
req.prompt = placeholder
req.image_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
if "tool_use" not in provider_cfg:
logger.debug(
"Provider %s does not support tool_use, clearing tools.", provider
)
req.func_tool = None
def _sanitize_context_by_modalities(
config: MainAgentBuildConfig,
provider: Provider,
req: ProviderRequest,
) -> None:
if not config.sanitize_context_by_modalities:
return
if not isinstance(req.contexts, list) or not req.contexts:
return
modalities = provider.provider_config.get("modalities", None)
if not modalities or not isinstance(modalities, list):
return
supports_image = bool("image" in modalities)
supports_tool_use = bool("tool_use" in modalities)
if supports_image and supports_tool_use:
return
sanitized_contexts: list[dict] = []
removed_image_blocks = 0
removed_tool_messages = 0
removed_tool_calls = 0
for msg in req.contexts:
if not isinstance(msg, dict):
continue
role = msg.get("role")
if not role:
continue
new_msg = msg
if not supports_tool_use:
if role == "tool":
removed_tool_messages += 1
continue
if role == "assistant" and "tool_calls" in new_msg:
if "tool_calls" in new_msg:
removed_tool_calls += 1
new_msg.pop("tool_calls", None)
new_msg.pop("tool_call_id", None)
if not supports_image:
content = new_msg.get("content")
if isinstance(content, list):
filtered_parts: list = []
removed_any_image = False
for part in content:
if isinstance(part, dict):
part_type = str(part.get("type", "")).lower()
if part_type in {"image_url", "image"}:
removed_any_image = True
removed_image_blocks += 1
continue
filtered_parts.append(part)
if removed_any_image:
new_msg["content"] = filtered_parts
if role == "assistant":
content = new_msg.get("content")
has_tool_calls = bool(new_msg.get("tool_calls"))
if not has_tool_calls:
if not content:
continue
if isinstance(content, str) and not content.strip():
continue
sanitized_contexts.append(new_msg)
if removed_image_blocks or removed_tool_messages or removed_tool_calls:
logger.debug(
"sanitize_context_by_modalities applied: "
"removed_image_blocks=%s, removed_tool_messages=%s, removed_tool_calls=%s",
removed_image_blocks,
removed_tool_messages,
removed_tool_calls,
)
req.contexts = sanitized_contexts
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
@@ -867,14 +778,9 @@ def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
continue
mp = tool.handler_module_path
if not mp:
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*
# 不应受到会话插件过滤影响。
new_tool_set.add_tool(tool)
continue
plugin = star_map.get(mp)
if not plugin:
# 无法解析插件归属时,保守保留工具,避免误过滤。
new_tool_set.add_tool(tool)
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
@@ -936,9 +842,7 @@ def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) -
def _apply_sandbox_tools(
config: MainAgentBuildConfig,
req: ProviderRequest,
session_id: str,
config: MainAgentBuildConfig, req: ProviderRequest, session_id: str
) -> None:
if req.func_tool is None:
req.func_tool = ToolSet()
@@ -954,15 +858,10 @@ def _apply_sandbox_tools(
os.environ["SHIPYARD_ENDPOINT"] = ep
os.environ["SHIPYARD_ACCESS_TOKEN"] = at
tool_mgr = llm_tools
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(PythonTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileUploadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileDownloadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool))
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
req.func_tool.add_tool(PYTHON_TOOL)
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
if booter == "shipyard_neo":
# Neo-specific path rule: filesystem tools operate relative to sandbox
# workspace root. Do not prepend "/workspace".
@@ -998,62 +897,32 @@ def _apply_sandbox_tools(
# Browser tools: only register if profile supports browser
# (or if capabilities are unknown because sandbox hasn't booted yet)
if sandbox_capabilities is None or "browser" in sandbox_capabilities:
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserExecTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserBatchExecTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(RunBrowserSkillTool))
req.func_tool.add_tool(BROWSER_EXEC_TOOL)
req.func_tool.add_tool(BROWSER_BATCH_EXEC_TOOL)
req.func_tool.add_tool(RUN_BROWSER_SKILL_TOOL)
# Neo-specific tools (always available for shipyard_neo)
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetExecutionHistoryTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(AnnotateExecutionTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillPayloadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetSkillPayloadTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillCandidateTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillCandidatesTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(EvaluateSkillCandidateTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(PromoteSkillCandidateTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillReleasesTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(RollbackSkillReleaseTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(SyncSkillReleaseTool))
req.func_tool.add_tool(GET_EXECUTION_HISTORY_TOOL)
req.func_tool.add_tool(ANNOTATE_EXECUTION_TOOL)
req.func_tool.add_tool(CREATE_SKILL_PAYLOAD_TOOL)
req.func_tool.add_tool(GET_SKILL_PAYLOAD_TOOL)
req.func_tool.add_tool(CREATE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(LIST_SKILL_CANDIDATES_TOOL)
req.func_tool.add_tool(EVALUATE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(PROMOTE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(LIST_SKILL_RELEASES_TOOL)
req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL)
req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL)
req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n"
def _proactive_cron_job_tools(req: ProviderRequest, plugin_context: Context) -> None:
def _proactive_cron_job_tools(req: ProviderRequest) -> None:
if req.func_tool is None:
req.func_tool = ToolSet()
tool_mgr = plugin_context.get_llm_tool_manager()
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FutureTaskTool))
async def _apply_web_search_tools(
event: AstrMessageEvent,
req: ProviderRequest,
plugin_context: Context,
) -> None:
cfg = plugin_context.get_config(umo=event.unified_msg_origin)
normalize_legacy_web_search_config(cfg)
prov_settings = cfg.get("provider_settings", {})
if not prov_settings.get("web_search", False):
return
if req.func_tool is None:
req.func_tool = ToolSet()
tool_mgr = plugin_context.get_llm_tool_manager()
provider = prov_settings.get("websearch_provider", "tavily")
if provider == "tavily":
req.func_tool.add_tool(tool_mgr.get_builtin_tool(TavilyWebSearchTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(TavilyExtractWebPageTool))
elif provider == "bocha":
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BochaWebSearchTool))
elif provider == "brave":
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BraveWebSearchTool))
elif provider == "firecrawl":
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlWebSearchTool))
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlExtractWebPageTool))
elif provider == "baidu_ai_search":
req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool))
req.func_tool.add_tool(CREATE_CRON_JOB_TOOL)
req.func_tool.add_tool(DELETE_CRON_JOB_TOOL)
req.func_tool.add_tool(LIST_CRON_JOBS_TOOL)
def _get_compress_provider(
@@ -1144,7 +1013,6 @@ async def build_main_agent(
req = ProviderRequest()
req.prompt = ""
req.image_urls = []
req.audio_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if config.provider_wake_prefix and not event.message_str.startswith(
@@ -1157,21 +1025,11 @@ async def build_main_agent(
# media files attachments
for comp in event.message_obj.message:
if isinstance(comp, Image):
path = await comp.convert_to_file_path()
image_path = await _compress_image_for_provider(
path,
config.provider_settings,
)
if _is_generated_compressed_image_path(path, image_path):
event.track_temporary_local_file(image_path)
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
req.extra_user_content_parts.append(
TextPart(text=f"[Image Attachment: path {image_path}]")
)
elif isinstance(comp, Record):
audio_path = await comp.convert_to_file_path()
req.audio_urls.append(audio_path)
_append_audio_attachment(req, audio_path)
elif isinstance(comp, File):
file_path = await comp.get_file()
file_name = comp.name or os.path.basename(file_path)
@@ -1180,8 +1038,6 @@ async def build_main_agent(
text=f"[File Attachment: name {file_name}, path {file_path}]"
)
)
elif isinstance(comp, Video):
await _append_video_attachment(req, comp)
# quoted message attachments
reply_comps = [
comp for comp in event.message_obj.message if isinstance(comp, Reply)
@@ -1196,19 +1052,9 @@ async def build_main_agent(
for reply_comp in comp.chain:
if isinstance(reply_comp, Image):
has_embedded_image = True
path = await reply_comp.convert_to_file_path()
image_path = await _compress_image_for_provider(
path,
config.provider_settings,
)
if _is_generated_compressed_image_path(path, image_path):
event.track_temporary_local_file(image_path)
image_path = await reply_comp.convert_to_file_path()
req.image_urls.append(image_path)
_append_quoted_image_attachment(req, image_path)
elif isinstance(reply_comp, Record):
audio_path = await reply_comp.convert_to_file_path()
req.audio_urls.append(audio_path)
_append_quoted_audio_attachment(req, audio_path)
elif isinstance(reply_comp, File):
file_path = await reply_comp.get_file()
file_name = reply_comp.name or os.path.basename(file_path)
@@ -1220,8 +1066,6 @@ async def build_main_agent(
)
)
)
elif isinstance(reply_comp, Video):
await _append_video_attachment(req, reply_comp, quoted=True)
# Fallback quoted image extraction for reply-id-only payloads, or when
# embedded reply chain only contains placeholders (e.g. [Forward Message], [Image]).
@@ -1277,19 +1121,7 @@ async def build_main_agent(
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
thread_selected_text = event.get_extra("thread_selected_text")
if isinstance(thread_selected_text, str) and thread_selected_text.strip():
req.extra_user_content_parts.append(
TextPart(
text=(
"The user is asking in a side thread about this selected "
"excerpt from the previous assistant answer:\n"
f"<selected_excerpt>{thread_selected_text.strip()}</selected_excerpt>"
)
)
)
req.image_urls = normalize_and_dedupe_strings(req.image_urls)
req.audio_urls = normalize_and_dedupe_strings(req.audio_urls)
if config.file_extract_enabled:
try:
@@ -1297,7 +1129,7 @@ async def build_main_agent(
except Exception as exc: # noqa: BLE001
logger.error("Error occurred while applying file extract: %s", exc)
if not req.prompt and not req.image_urls and not req.audio_urls:
if not req.prompt and not req.image_urls:
if not event.get_group_id() and req.extra_user_content_parts:
req.prompt = "<attachment>"
else:
@@ -1310,8 +1142,9 @@ async def build_main_agent(
if not req.session_id:
req.session_id = event.unified_msg_origin
_modalities_fix(provider, req)
_plugin_tool_fix(event, req)
await _apply_web_search_tools(event, req, plugin_context)
_sanitize_context_by_modalities(config, provider, req)
if config.llm_safety_mode:
_apply_llm_safety_mode(config, req)
@@ -1319,7 +1152,7 @@ async def build_main_agent(
if config.computer_use_runtime == "sandbox":
_apply_sandbox_tools(config, req, req.session_id)
elif config.computer_use_runtime == "local":
_apply_local_env_tools(req, plugin_context)
_apply_local_env_tools(req)
agent_runner = AgentRunner()
astr_agent_ctx = AstrAgentContext(
@@ -1328,16 +1161,12 @@ async def build_main_agent(
)
if config.add_cron_tools:
_proactive_cron_job_tools(req, plugin_context)
_proactive_cron_job_tools(req)
if event.platform_meta.support_proactive_message:
if req.func_tool is None:
req.func_tool = ToolSet()
req.func_tool.add_tool(
plugin_context.get_llm_tool_manager().get_builtin_tool(
SendMessageToUserTool
)
)
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)
if provider.provider_config.get("max_context_tokens", 0) <= 0:
model = provider.get_model()
@@ -1355,15 +1184,6 @@ async def build_main_agent(
if config.tool_schema_mode == "full"
else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE
)
if config.computer_use_runtime == "local":
tool_prompt += (
f"\nCurrent workspace you can use: "
f"`{_get_workspace_path_for_umo(event.unified_msg_origin)}`\n"
"Unless the user explicitly specifies a different directory, "
"perform all file-related operations in this workspace.\n"
)
req.system_prompt += f"\n{tool_prompt}\n"
action_type = event.get_extra("action_type")
@@ -1389,14 +1209,6 @@ async def build_main_agent(
fallback_providers=_get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
),
tool_result_overflow_dir=(
get_astrbot_system_tmp_path()
if req.func_tool and req.func_tool.get_tool("astrbot_file_read_tool")
else None
),
read_tool=(
req.func_tool.get_tool("astrbot_file_read_tool") if req.func_tool else None
),
)
if apply_reset:

View File

@@ -1,4 +1,42 @@
import base64
import json
import os
import uuid
from pydantic import Field
from pydantic.dataclasses import dataclass
import astrbot.core.message.components as Comp
from astrbot.api import logger, sp
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool, ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.computer.tools import (
AnnotateExecutionTool,
BrowserBatchExecTool,
BrowserExecTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
ExecuteShellTool,
FileDownloadTool,
FileUploadTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
LocalPythonTool,
PromoteSkillCandidateTool,
PythonTool,
RollbackSkillReleaseTool,
RunBrowserSkillTool,
SyncSkillReleaseTool,
)
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.star.context import Context
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
LLM_SAFETY_MODE_SYSTEM_PROMPT = """You are running in Safe Mode.
@@ -108,6 +146,351 @@ BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT = (
"{background_task_result}"
)
@dataclass
class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
name: str = "astr_kb_search"
description: str = (
"Query the knowledge base for facts or relevant context. "
"Use this tool when the user's question requires factual information, "
"definitions, background knowledge, or previously indexed content. "
"Only send short keywords or a concise question as the query."
)
parameters: dict = Field(
default_factory=lambda: {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "A concise keyword query for the knowledge base.",
},
},
"required": ["query"],
}
)
async def call(
self, context: ContextWrapper[AstrAgentContext], **kwargs
) -> ToolExecResult:
query = kwargs.get("query", "")
if not query:
return "error: Query parameter is empty."
result = await retrieve_knowledge_base(
query=kwargs.get("query", ""),
umo=context.context.event.unified_msg_origin,
context=context.context.context,
)
if not result:
return "No relevant knowledge found."
return result
@dataclass
class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
name: str = "send_message_to_user"
description: str = "Directly send message to the user. Only use this tool when you need to proactively message the user. Otherwise you can directly output the reply in the conversation."
parameters: dict = Field(
default_factory=lambda: {
"type": "object",
"properties": {
"messages": {
"type": "array",
"description": "An ordered list of message components to send. `mention_user` type can be used to mention the user.",
"items": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": (
"Component type. One of: "
"plain, image, record, video, file, mention_user. Record is voice message."
),
},
"text": {
"type": "string",
"description": "Text content for `plain` type.",
},
"path": {
"type": "string",
"description": "File path for `image`, `record`, or `file` types. Both local path and sandbox path are supported.",
},
"url": {
"type": "string",
"description": "URL for `image`, `record`, or `file` types.",
},
"mention_user_id": {
"type": "string",
"description": "User ID to mention for `mention_user` type.",
},
},
"required": ["type"],
},
},
},
"required": ["messages"],
}
)
async def _resolve_path_from_sandbox(
self, context: ContextWrapper[AstrAgentContext], path: str
) -> tuple[str, bool]:
"""
If the path exists locally, return it directly.
Otherwise, check if it exists in the sandbox and download it.
bool: indicates whether the file was downloaded from sandbox.
"""
if os.path.exists(path):
return path, False
# Try to check if the file exists in the sandbox
try:
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
# Use shell to check if the file exists in sandbox
result = await sb.shell.exec(f"test -f {path} && echo '_&exists_'")
if "_&exists_" in json.dumps(result):
# Download the file from sandbox
name = os.path.basename(path)
local_path = os.path.join(
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
)
await sb.download_file(path, local_path)
logger.info(f"Downloaded file from sandbox: {path} -> {local_path}")
return local_path, True
except Exception as e:
logger.warning(f"Failed to check/download file from sandbox: {e}")
# Return the original path (will likely fail later, but that's expected)
return path, False
async def call(
self, context: ContextWrapper[AstrAgentContext], **kwargs
) -> ToolExecResult:
session = kwargs.get("session") or context.context.event.unified_msg_origin
messages = kwargs.get("messages")
if not isinstance(messages, list) or not messages:
return "error: messages parameter is empty or invalid."
components: list[Comp.BaseMessageComponent] = []
for idx, msg in enumerate(messages):
if not isinstance(msg, dict):
return f"error: messages[{idx}] should be an object."
msg_type = str(msg.get("type", "")).lower()
if not msg_type:
return f"error: messages[{idx}].type is required."
file_from_sandbox = False
try:
if msg_type == "plain":
text = str(msg.get("text", "")).strip()
if not text:
return f"error: messages[{idx}].text is required for plain component."
components.append(Comp.Plain(text=text))
elif msg_type == "image":
path = msg.get("path")
url = msg.get("url")
if path:
(
local_path,
file_from_sandbox,
) = await self._resolve_path_from_sandbox(context, path)
components.append(Comp.Image.fromFileSystem(path=local_path))
elif url:
components.append(Comp.Image.fromURL(url=url))
else:
return f"error: messages[{idx}] must include path or url for image component."
elif msg_type == "record":
path = msg.get("path")
url = msg.get("url")
if path:
(
local_path,
file_from_sandbox,
) = await self._resolve_path_from_sandbox(context, path)
components.append(Comp.Record.fromFileSystem(path=local_path))
elif url:
components.append(Comp.Record.fromURL(url=url))
else:
return f"error: messages[{idx}] must include path or url for record component."
elif msg_type == "video":
path = msg.get("path")
url = msg.get("url")
if path:
(
local_path,
file_from_sandbox,
) = await self._resolve_path_from_sandbox(context, path)
components.append(Comp.Video.fromFileSystem(path=local_path))
elif url:
components.append(Comp.Video.fromURL(url=url))
else:
return f"error: messages[{idx}] must include path or url for video component."
elif msg_type == "file":
path = msg.get("path")
url = msg.get("url")
name = (
msg.get("text")
or (os.path.basename(path) if path else "")
or (os.path.basename(url) if url else "")
or "file"
)
if path:
(
local_path,
file_from_sandbox,
) = await self._resolve_path_from_sandbox(context, path)
components.append(Comp.File(name=name, file=local_path))
elif url:
components.append(Comp.File(name=name, url=url))
else:
return f"error: messages[{idx}] must include path or url for file component."
elif msg_type == "mention_user":
mention_user_id = msg.get("mention_user_id")
if not mention_user_id:
return f"error: messages[{idx}].mention_user_id is required for mention_user component."
components.append(
Comp.At(
qq=mention_user_id,
),
)
else:
return (
f"error: unsupported message type '{msg_type}' at index {idx}."
)
except Exception as exc: # 捕获组件构造异常,避免直接抛出
return f"error: failed to build messages[{idx}] component: {exc}"
try:
target_session = (
MessageSession.from_str(session)
if isinstance(session, str)
else session
)
except Exception as e:
return f"error: invalid session: {e}"
await context.context.context.send_message(
target_session,
MessageChain(chain=components),
)
# if file_from_sandbox:
# try:
# os.remove(local_path)
# except Exception as e:
# logger.error(f"Error removing temp file {local_path}: {e}")
return f"Message sent to session {target_session}"
async def retrieve_knowledge_base(
query: str,
umo: str,
context: Context,
) -> str | None:
"""Inject knowledge base context into the provider request
Args:
umo: Unique message object (session ID)
p_ctx: Pipeline context
"""
kb_mgr = context.kb_manager
config = context.get_config(umo=umo)
# 1. 优先读取会话级配置
session_config = await sp.session_get(umo, "kb_config", default={})
if session_config and "kb_ids" in session_config:
# 会话级配置
kb_ids = session_config.get("kb_ids", [])
# 如果配置为空列表,明确表示不使用知识库
if not kb_ids:
logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库")
return
top_k = session_config.get("top_k", 5)
# 将 kb_ids 转换为 kb_names
kb_names = []
invalid_kb_ids = []
for kb_id in kb_ids:
kb_helper = await kb_mgr.get_kb(kb_id)
if kb_helper:
kb_names.append(kb_helper.kb.kb_name)
else:
logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}")
invalid_kb_ids.append(kb_id)
if invalid_kb_ids:
logger.warning(
f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}",
)
if not kb_names:
return
logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
else:
kb_names = config.get("kb_names", [])
top_k = config.get("kb_final_top_k", 5)
logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")
top_k_fusion = config.get("kb_fusion_top_k", 20)
if not kb_names:
return
logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
kb_context = await kb_mgr.retrieve(
query=query,
kb_names=kb_names,
top_k_fusion=top_k_fusion,
top_m_final=top_k,
)
if not kb_context:
return
formatted = kb_context.get("context_text", "")
if formatted:
results = kb_context.get("results", [])
logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
return formatted
KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool()
SEND_MESSAGE_TO_USER_TOOL = SendMessageToUserTool()
EXECUTE_SHELL_TOOL = ExecuteShellTool()
LOCAL_EXECUTE_SHELL_TOOL = ExecuteShellTool(is_local=True)
PYTHON_TOOL = PythonTool()
LOCAL_PYTHON_TOOL = LocalPythonTool()
FILE_UPLOAD_TOOL = FileUploadTool()
FILE_DOWNLOAD_TOOL = FileDownloadTool()
BROWSER_EXEC_TOOL = BrowserExecTool()
BROWSER_BATCH_EXEC_TOOL = BrowserBatchExecTool()
RUN_BROWSER_SKILL_TOOL = RunBrowserSkillTool()
GET_EXECUTION_HISTORY_TOOL = GetExecutionHistoryTool()
ANNOTATE_EXECUTION_TOOL = AnnotateExecutionTool()
CREATE_SKILL_PAYLOAD_TOOL = CreateSkillPayloadTool()
GET_SKILL_PAYLOAD_TOOL = GetSkillPayloadTool()
CREATE_SKILL_CANDIDATE_TOOL = CreateSkillCandidateTool()
LIST_SKILL_CANDIDATES_TOOL = ListSkillCandidatesTool()
EVALUATE_SKILL_CANDIDATE_TOOL = EvaluateSkillCandidateTool()
PROMOTE_SKILL_CANDIDATE_TOOL = PromoteSkillCandidateTool()
LIST_SKILL_RELEASES_TOOL = ListSkillReleasesTool()
ROLLBACK_SKILL_RELEASE_TOOL = RollbackSkillReleaseTool()
SYNC_SKILL_RELEASE_TOOL = SyncSkillReleaseTool()
# we prevent astrbot from connecting to known malicious hosts
# these hosts are base64 encoded
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}

View File

@@ -7,7 +7,6 @@ from sqlmodel import SQLModel
from astrbot.core.db.po import (
Attachment,
ChatUIProject,
CommandConfig,
CommandConflict,
ConversationV2,
@@ -17,8 +16,6 @@ from astrbot.core.db.po import (
PlatformSession,
PlatformStat,
Preference,
SessionProjectRelation,
WebChatThread,
)
from astrbot.core.knowledge_base.models import (
KBDocument,
@@ -47,9 +44,6 @@ MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
"preferences": Preference,
"platform_message_history": PlatformMessageHistory,
"platform_sessions": PlatformSession,
"webchat_threads": WebChatThread,
"chatui_projects": ChatUIProject,
"session_project_relations": SessionProjectRelation,
"attachments": Attachment,
"command_configs": CommandConfig,
"command_conflicts": CommandConflict,

View File

@@ -59,20 +59,6 @@ def _get_major_version(version_str: str) -> str:
return "0.0"
def _validate_path_within(target_path: Path, base_dir: Path) -> bool:
"""Validate that target_path is within base_dir after resolving symlinks.
Prevents path traversal attacks (CWE-22) by ensuring the resolved
target path is relative to the resolved base directory.
"""
try:
resolved = target_path.resolve(strict=False)
base_resolved = base_dir.resolve(strict=False)
return resolved.is_relative_to(base_resolved)
except (OSError, ValueError):
return False
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
KB_PATH = get_astrbot_knowledge_base_path()
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
@@ -779,10 +765,6 @@ class AstrBotImporter:
try:
rel_path = name[len(media_prefix) :]
target_path = kb_dir / rel_path
# Validate path is within kb directory (CWE-22)
if not _validate_path_within(target_path, kb_dir):
logger.warning(f"媒体文件路径越界,已跳过: {target_path}")
continue
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
@@ -845,11 +827,6 @@ class AstrBotImporter:
else:
target_path = attachments_dir / os.path.basename(name)
# Validate path is within attachments directory (CWE-22)
if not _validate_path_within(target_path, attachments_dir):
logger.warning(f"附件路径越界,已跳过: {target_path}")
continue
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
@@ -927,10 +904,6 @@ class AstrBotImporter:
continue
target_path = target_dir / rel_path
# Validate path is within target directory (CWE-22)
if not _validate_path_within(target_path, target_dir):
result.add_warning(f"文件路径越界,已跳过: {name}")
continue
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:

View File

@@ -4,7 +4,7 @@ from typing import Any
import aiohttp
import boxlite
from shipyard import FileSystemComponent as ShipyardFileSystemComponent
from shipyard.filesystem import FileSystemComponent as ShipyardFileSystemComponent
from shipyard.python import PythonComponent as ShipyardPythonComponent
from shipyard.shell import ShellComponent as ShipyardShellComponent
@@ -12,7 +12,6 @@ from astrbot.api import logger
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .shipyard import ShipyardFileSystemWrapper
class MockShipyardSandboxClient:
@@ -151,6 +150,11 @@ class BoxliteBooter(ComputerBooter):
self.mocked = MockShipyardSandboxClient(
sb_url=f"http://127.0.0.1:{random_port}"
)
self._fs = ShipyardFileSystemComponent(
client=self.mocked, # type: ignore
ship_id=self.box.id,
session_id=session_id,
)
self._python = ShipyardPythonComponent(
client=self.mocked, # type: ignore
ship_id=self.box.id,
@@ -161,14 +165,6 @@ class BoxliteBooter(ComputerBooter):
ship_id=self.box.id,
session_id=session_id,
)
self._ship_fs = ShipyardFileSystemComponent(
client=self.mocked, # type: ignore
ship_id=self.box.id,
session_id=session_id,
)
self._fs = ShipyardFileSystemWrapper(
_shipyard_fs=self._ship_fs, _shipyard_shell=self._shell
)
await self.mocked.wait_healthy(self.box.id, session_id)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import locale
import os
import shutil
import subprocess
@@ -9,18 +8,15 @@ import sys
from dataclasses import dataclass
from typing import Any
from python_ripgrep import search
from astrbot.api import logger
from astrbot.core.computer.file_read_utils import (
detect_text_encoding,
read_local_text_range_sync,
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_root,
get_astrbot_temp_path,
)
from astrbot.core.utils.astrbot_path import get_astrbot_root
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .shipyard_search_file_util import _truncate_long_lines
_BLOCKED_COMMAND_PATTERNS = [
" rm -rf ",
@@ -44,43 +40,16 @@ def _is_safe_command(command: str) -> bool:
return not any(pat in cmd for pat in _BLOCKED_COMMAND_PATTERNS)
def _decode_bytes_with_fallback(
output: bytes | None,
*,
preferred_encoding: str | None = None,
) -> str:
if output is None:
return ""
preferred = locale.getpreferredencoding(False) or "utf-8"
attempted_encodings: list[str] = []
def _try_decode(encoding: str) -> str | None:
normalized = encoding.lower()
if normalized in attempted_encodings:
return None
attempted_encodings.append(normalized)
try:
return output.decode(encoding)
except (LookupError, UnicodeDecodeError):
return None
for encoding in filter(None, [preferred_encoding, "utf-8", "utf-8-sig"]):
if decoded := _try_decode(encoding):
return decoded
if os.name == "nt":
for encoding in ("mbcs", "cp936", "gbk", "gb18030", preferred):
if decoded := _try_decode(encoding):
return decoded
elif decoded := _try_decode(preferred):
return decoded
return output.decode("utf-8", errors="replace")
def _decode_shell_output(output: bytes | None) -> str:
return _decode_bytes_with_fallback(output, preferred_encoding="utf-8")
def _ensure_safe_path(path: str) -> str:
abs_path = os.path.abspath(path)
allowed_roots = [
os.path.abspath(get_astrbot_root()),
os.path.abspath(get_astrbot_data_path()),
os.path.abspath(get_astrbot_temp_path()),
]
if not any(abs_path.startswith(root) for root in allowed_roots):
raise PermissionError("Path is outside the allowed computer roots.")
return abs_path
@dataclass
@@ -101,34 +70,30 @@ class LocalShellComponent(ShellComponent):
run_env = os.environ.copy()
if env:
run_env.update({str(k): str(v) for k, v in env.items()})
working_dir = os.path.abspath(cwd) if cwd else get_astrbot_root()
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
if background:
# `command` is intentionally executed through the current shell so
# local computer-use behavior matches existing tool semantics.
# Safety relies on `_is_safe_command()` and the allowed-root checks.
proc = subprocess.Popen( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
proc = subprocess.Popen(
command,
shell=shell,
cwd=working_dir,
env=run_env,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
# `command` is intentionally executed through the current shell so
# local computer-use behavior matches existing tool semantics.
# Safety relies on `_is_safe_command()` and the allowed-root checks.
result = subprocess.run( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
result = subprocess.run(
command,
shell=shell,
cwd=working_dir,
env=run_env,
timeout=timeout,
capture_output=True,
text=True,
)
return {
"stdout": _decode_shell_output(result.stdout),
"stderr": _decode_shell_output(result.stderr),
"stdout": result.stdout,
"stderr": result.stderr,
"exit_code": result.returncode,
}
@@ -177,7 +142,7 @@ class LocalFileSystemComponent(FileSystemComponent):
self, path: str, content: str = "", mode: int = 0o644
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = os.path.abspath(path)
abs_path = _ensure_safe_path(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, "w", encoding="utf-8") as f:
f.write(content)
@@ -186,85 +151,12 @@ class LocalFileSystemComponent(FileSystemComponent):
return await asyncio.to_thread(_run)
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = os.path.abspath(path)
detected_encoding = encoding
if encoding == "utf-8":
with open(abs_path, "rb") as f:
raw_sample = f.read(8192)
detected_encoding = detect_text_encoding(raw_sample) or encoding
return {
"success": True,
"content": read_local_text_range_sync(
abs_path,
encoding=detected_encoding,
offset=offset,
limit=limit,
),
}
return await asyncio.to_thread(_run)
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
results = search(
patterns=[pattern],
paths=[path] if path else None,
globs=[glob] if glob else None,
after_context=after_context,
before_context=before_context,
line_number=True,
)
return {"success": True, "content": _truncate_long_lines("".join(results))}
return await asyncio.to_thread(_run)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = os.path.abspath(path)
abs_path = _ensure_safe_path(path)
with open(abs_path, encoding=encoding) as f:
content = f.read()
occurrences = content.count(old_string)
if occurrences == 0:
return {
"success": False,
"error": "old string not found in file",
"replacements": 0,
}
if replace_all:
updated = content.replace(old_string, new_string)
replacements = occurrences
else:
updated = content.replace(old_string, new_string, 1)
replacements = 1
with open(abs_path, "w", encoding=encoding) as f:
f.write(updated)
return {
"success": True,
"path": abs_path,
"replacements": replacements,
}
return {"success": True, "content": content}
return await asyncio.to_thread(_run)
@@ -272,7 +164,7 @@ class LocalFileSystemComponent(FileSystemComponent):
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = os.path.abspath(path)
abs_path = _ensure_safe_path(path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
with open(abs_path, mode, encoding=encoding) as f:
f.write(content)
@@ -282,7 +174,7 @@ class LocalFileSystemComponent(FileSystemComponent):
async def delete_file(self, path: str) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = os.path.abspath(path)
abs_path = _ensure_safe_path(path)
if os.path.isdir(abs_path):
shutil.rmtree(abs_path)
else:
@@ -295,7 +187,7 @@ class LocalFileSystemComponent(FileSystemComponent):
self, path: str = ".", show_hidden: bool = False
) -> dict[str, Any]:
def _run() -> dict[str, Any]:
abs_path = os.path.abspath(path)
abs_path = _ensure_safe_path(path)
entries = os.listdir(abs_path)
if not show_hidden:
entries = [e for e in entries if not e.startswith(".")]

View File

@@ -1,87 +1,9 @@
from __future__ import annotations
from typing import Any
from shipyard import FileSystemComponent as ShipyardFileSystemComponent
from shipyard import ShipyardClient, Spec
from astrbot.api import logger
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from .base import ComputerBooter
from .shipyard_search_file_util import search_files_via_shell
class ShipyardFileSystemWrapper:
def __init__(
self, _shipyard_fs: ShipyardFileSystemComponent, _shipyard_shell: ShellComponent
):
self._fs = _shipyard_fs
self._shell = _shipyard_shell
async def create_file(
self, path: str, content: str = "", mode: int = 420
) -> dict[str, Any]:
return await self._fs.create_file(path=path, content=content, mode=mode)
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
return await self._fs.read_file(
path=path, encoding=encoding, offset=offset, limit=limit
)
async def write_file(
self, path: str, content: str, mode: str = "w", encoding: str = "utf-8"
) -> dict[str, Any]:
return await self._fs.write_file(
path=path, content=content, mode=mode, encoding=encoding
)
async def list_dir(
self, path: str = ".", show_hidden: bool = False
) -> dict[str, Any]:
return await self._fs.list_dir(path=path, show_hidden=show_hidden)
async def delete_file(self, path: str) -> dict[str, Any]:
return await self._fs.delete_file(path=path)
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
return await search_files_via_shell(
self._shell,
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
return await self._fs.edit_file(
path=path,
old_string=old_string,
new_string=new_string,
replace_all=replace_all,
encoding=encoding,
)
class ShipyardBooter(ComputerBooter):
@@ -107,14 +29,13 @@ class ShipyardBooter(ComputerBooter):
)
logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}")
self._ship = ship
self._fs = ShipyardFileSystemWrapper(self._ship.fs, self._ship.shell)
async def shutdown(self) -> None:
logger.info("[Computer] Shipyard booter shutdown.")
@property
def fs(self) -> FileSystemComponent:
return self._fs
return self._ship.fs
@property
def python(self) -> PythonComponent:

View File

@@ -13,15 +13,6 @@ from ..olayer import (
ShellComponent,
)
from .base import ComputerBooter
from .shipyard_search_file_util import search_files_via_shell
try:
from shipyard_neo import BayClient
from shipyard_neo.sandbox import Sandbox
except ImportError:
logger.warning(
"shipyard_neo_sdk is not installed. ShipyardNeoBooter will not work without it."
)
def _maybe_model_dump(value: Any) -> dict[str, Any]:
@@ -34,20 +25,8 @@ def _maybe_model_dump(value: Any) -> dict[str, Any]:
return {}
def _slice_content_by_lines(
content: str,
*,
offset: int | None = None,
limit: int | None = None,
) -> str:
lines = content.splitlines(keepends=True)
start = 0 if offset is None else offset
selected = lines[start:] if limit is None else lines[start : start + limit]
return "".join(selected)
class NeoPythonComponent(PythonComponent):
def __init__(self, sandbox: Sandbox) -> None:
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
@@ -88,7 +67,7 @@ class NeoPythonComponent(PythonComponent):
class NeoShellComponent(ShellComponent):
def __init__(self, sandbox: Sandbox) -> None:
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
@@ -157,9 +136,8 @@ class NeoShellComponent(ShellComponent):
class NeoFileSystemComponent(FileSystemComponent):
def __init__(self, sandbox: Sandbox, shell: ShellComponent) -> None:
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
self._shell = shell
async def create_file(
self,
@@ -171,71 +149,10 @@ class NeoFileSystemComponent(FileSystemComponent):
await self._sandbox.filesystem.write_file(path, content)
return {"success": True, "path": path}
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
_ = encoding
content = await self._sandbox.filesystem.read_file(path)
return {
"success": True,
"path": path,
"content": _slice_content_by_lines(
content,
offset=offset,
limit=limit,
),
}
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
return await search_files_via_shell(
self._shell,
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
_ = encoding
content = await self._sandbox.filesystem.read_file(path)
occurrences = content.count(old_string)
if occurrences == 0:
return {
"success": False,
"error": "old string not found in file",
"replacements": 0,
}
if replace_all:
updated = content.replace(old_string, new_string)
replacements = occurrences
else:
updated = content.replace(old_string, new_string, 1)
replacements = 1
await self._sandbox.filesystem.write_file(path, updated)
return {
"success": True,
"path": path,
"replacements": replacements,
}
return {"success": True, "path": path, "content": content}
async def write_file(
self,
@@ -269,7 +186,7 @@ class NeoFileSystemComponent(FileSystemComponent):
class NeoBrowserComponent(BrowserComponent):
def __init__(self, sandbox: Sandbox) -> None:
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
@@ -354,8 +271,8 @@ class ShipyardNeoBooter(ComputerBooter):
self._access_token = access_token
self._profile = profile
self._ttl = ttl
self._client: BayClient | None = None
self._sandbox: Sandbox | None = None
self._client: Any = None
self._sandbox: Any = None
self._bay_manager: Any = None # BayContainerManager when auto-started
self._fs: FileSystemComponent | None = None
self._python: PythonComponent | None = None
@@ -419,6 +336,8 @@ class ShipyardNeoBooter(ComputerBooter):
"or ensure Bay's credentials.json is accessible for auto-discovery."
)
from shipyard_neo import BayClient
self._client = BayClient(
endpoint_url=self._endpoint_url,
access_token=self._access_token,
@@ -433,9 +352,9 @@ class ShipyardNeoBooter(ComputerBooter):
ttl=self._ttl,
)
self._shell = NeoShellComponent(self._sandbox)
self._fs = NeoFileSystemComponent(self._sandbox, self._shell)
self._fs = NeoFileSystemComponent(self._sandbox)
self._python = NeoPythonComponent(self._sandbox)
self._shell = NeoShellComponent(self._sandbox)
caps = self.capabilities or ()
self._browser = (

View File

@@ -1,148 +0,0 @@
from __future__ import annotations
import shlex
from typing import Any
from ..olayer import ShellComponent
_MAX_SEARCH_LINE_COLUMNS = 1000
def _truncate_long_lines(text: str) -> str:
output_lines: list[str] = []
for line in text.splitlines(keepends=True):
line_ending = ""
line_body = line
if line.endswith("\r\n"):
line_body = line[:-2]
line_ending = "\r\n"
elif line.endswith("\n") or line.endswith("\r"):
line_body = line[:-1]
line_ending = line[-1]
if len(line_body) > _MAX_SEARCH_LINE_COLUMNS:
line_body = line_body[:_MAX_SEARCH_LINE_COLUMNS]
output_lines.append(f"{line_body}{line_ending}")
return "".join(output_lines)
def _build_rg_command(
*,
pattern: str,
path: str,
glob: str | None,
after_context: int | None,
before_context: int | None,
) -> list[str]:
command = [
"rg",
"--color=never",
"-n",
"--max-columns",
str(_MAX_SEARCH_LINE_COLUMNS),
"-e",
pattern,
]
if glob:
command.extend(["-g", glob])
if after_context is not None:
command.extend(["-A", str(after_context)])
if before_context is not None:
command.extend(["-B", str(before_context)])
command.extend(["--", path])
return command
def _build_grep_command(
*,
pattern: str,
path: str,
glob: str | None,
after_context: int | None,
before_context: int | None,
) -> list[str]:
command = ["grep", "-R", "-H", "-n", "-e", pattern]
if glob:
command.append(f"--include={glob}")
if after_context is not None:
command.extend(["-A", str(after_context)])
if before_context is not None:
command.extend(["-B", str(before_context)])
command.extend(["--", path])
return command
def _quote_command(command: list[str]) -> str:
return " ".join(shlex.quote(part) for part in command)
def build_search_command(
*,
pattern: str,
path: str,
glob: str | None,
after_context: int | None,
before_context: int | None,
) -> str:
rg_command = _quote_command(
_build_rg_command(
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
)
grep_command = _quote_command(
_build_grep_command(
pattern=pattern,
path=path,
glob=glob,
after_context=after_context,
before_context=before_context,
)
)
return (
"if command -v rg >/dev/null 2>&1; then "
f"{rg_command}; "
"elif command -v grep >/dev/null 2>&1; then "
f"{grep_command}; "
"else "
"echo 'Neither rg nor grep is available in the sandbox.' >&2; "
"exit 127; "
"fi"
)
async def search_files_via_shell(
shell: ShellComponent,
*,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
timeout: int = 30,
) -> dict[str, Any]:
command = build_search_command(
pattern=pattern,
path=path or ".",
glob=glob,
after_context=after_context,
before_context=before_context,
)
result = await shell.exec(command, timeout=timeout)
stdout = _truncate_long_lines(str(result.get("stdout", "") or ""))
stderr = str(result.get("stderr", "") or "")
exit_code = result.get("exit_code")
if exit_code in (0, None):
return {"success": True, "content": stdout}
if exit_code == 1:
return {"success": True, "content": ""}
return {
"success": False,
"content": "",
"error": stderr or f"command exited with code {exit_code}",
"exit_code": exit_code,
}

View File

@@ -213,24 +213,13 @@ def parse_description(text: str) -> str:
break
if end_idx is None:
return ""
frontmatter = "\\n".join(lines[1:end_idx])
try:
import yaml
except ImportError:
return ""
try:
payload = yaml.safe_load(frontmatter) or dict()
except yaml.YAMLError:
return ""
if not isinstance(payload, dict):
return ""
description = payload.get("description", "")
if not isinstance(description, str):
return ""
return description.strip()
for line in lines[1:end_idx]:
if ":" not in line:
continue
key, value = line.split(":", 1)
if key.strip().lower() == "description":
return value.strip().strip('"').strip("'")
return ""
def load_managed_skills() -> list[str]:

View File

@@ -1,744 +0,0 @@
from __future__ import annotations
import base64
import hashlib
import io
import json
import zipfile
from asyncio import to_thread
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import mcp
from astrbot.core.agent.context.token_counter import EstimateTokenCounter
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import (
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
IMAGE_COMPRESS_DEFAULT_OPTIMIZE,
IMAGE_COMPRESS_DEFAULT_QUALITY,
_compress_image_sync,
)
from .booters.base import ComputerBooter
_MAX_FILE_READ_BYTES = 128 * 1024
_MAX_FILE_READ_TOKENS = 25_000
_MAX_TEXT_FILE_FULL_READ_BYTES = 256 * 1024
_FILE_SNIFF_BYTES = 512
_TOKEN_COUNTER = EstimateTokenCounter()
_TEXT_ENCODINGS = (
"utf-8-sig",
"utf-8",
"gb18030",
"utf-16",
"utf-16-le",
"utf-16-be",
"utf-32",
"utf-32-le",
"utf-32-be",
)
_UTF_BOMS = (
b"\xef\xbb\xbf",
b"\xff\xfe",
b"\xfe\xff",
b"\xff\xfe\x00\x00",
b"\x00\x00\xfe\xff",
)
_ZIP_MAGIC_PREFIXES = (
b"PK\x03\x04",
b"PK\x05\x06",
b"PK\x07\x08",
)
_BINARY_MAGIC_PREFIXES = (
b"%PDF-",
b"\x1f\x8b",
b"7z\xbc\xaf\x27\x1c",
b"Rar!\x1a\x07",
b"\x7fELF",
b"MZ",
)
@dataclass(frozen=True)
class FileProbe:
kind: Literal["text", "image", "binary"]
encoding: str | None
mime_type: str | None
size_bytes: int
@dataclass(frozen=True)
class ParsedDocument:
kind: Literal["docx", "epub", "pdf"]
file_bytes: bytes
text: str
def _build_probe_script(path: str) -> str:
return f"""
import base64
import json
from pathlib import Path
path = Path({path!r})
with path.open("rb") as file_obj:
sample = file_obj.read({_FILE_SNIFF_BYTES})
print(
json.dumps(
{{
"size_bytes": path.stat().st_size,
"sample_b64": base64.b64encode(sample).decode("utf-8"),
}}
)
)
""".strip()
def _build_text_read_script(
path: str,
*,
encoding: str,
offset: int | None,
limit: int | None,
) -> str:
start_expr = "0" if offset is None else str(offset)
limit_expr = "None" if limit is None else str(limit)
return f"""
import json
from pathlib import Path
path = Path({path!r})
start = {start_expr}
limit = {limit_expr}
end = None if limit is None else start + limit
lines = []
with path.open("r", encoding={encoding!r}, newline="") as file_obj:
for index, line in enumerate(file_obj):
if index < start:
continue
if end is not None and index >= end:
break
lines.append(line)
content = "".join(lines)
print(json.dumps({{"content": content}}, ensure_ascii=False))
""".strip()
def _build_image_read_script(path: str) -> str:
return f"""
import base64
import json
from pathlib import Path
path = Path({path!r})
data = path.read_bytes()
print(
json.dumps(
{{
"size_bytes": len(data),
"base64": base64.b64encode(data).decode("utf-8"),
}}
)
)
""".strip()
def _looks_like_text(decoded: str) -> bool:
if not decoded:
return True
disallowed = 0
printable = 0
for char in decoded:
if char in "\n\r\t\f\b":
printable += 1
continue
if char.isprintable():
printable += 1
code = ord(char)
if (0 <= code < 32) or (127 <= code < 160):
disallowed += 1
total = max(len(decoded), 1)
return disallowed / total <= 0.02 and printable / total >= 0.85
def detect_text_encoding(sample: bytes) -> str | None:
if not sample:
return "utf-8"
if b"\x00" in sample and not sample.startswith(_UTF_BOMS):
odd_bytes = sample[1::2]
even_bytes = sample[0::2]
odd_zero_ratio = odd_bytes.count(0) / max(len(odd_bytes), 1)
even_zero_ratio = even_bytes.count(0) / max(len(even_bytes), 1)
if odd_zero_ratio < 0.8 and even_zero_ratio < 0.8:
return None
for encoding in _TEXT_ENCODINGS:
try:
decoded = sample.decode(encoding)
except UnicodeDecodeError as exc:
# Probe samples can end in the middle of a multibyte sequence.
# When the decode failure only happens at the sample tail, trim a few
# bytes and retry so UTF-8 text is not misclassified as binary.
if exc.start >= len(sample) - 4:
decoded = ""
for trim_bytes in range(1, min(4, len(sample)) + 1):
try:
decoded = sample[:-trim_bytes].decode(encoding)
break
except UnicodeDecodeError:
continue
if not decoded:
continue
else:
continue
if _looks_like_text(decoded):
return encoding
return None
def read_local_text_range_sync(
path: str,
*,
encoding: str,
offset: int | None,
limit: int | None,
) -> str:
lines: list[str] = []
start = 0 if offset is None else offset
end = None if limit is None else start + limit
with open(path, encoding=encoding, newline="") as file_obj:
for index, line in enumerate(file_obj):
if index < start:
continue
if end is not None and index >= end:
break
lines.append(line)
return "".join(lines)
async def read_local_text_range(
path: str,
*,
encoding: str,
offset: int | None,
limit: int | None,
) -> str:
return await to_thread(
read_local_text_range_sync,
path,
encoding=encoding,
offset=offset,
limit=limit,
)
async def _exec_python_json(
booter: ComputerBooter,
script: str,
*,
action: str,
) -> dict:
result = await booter.python.exec(script)
data = result.get("data") if isinstance(result.get("data"), dict) else {}
if not isinstance(data, dict):
raise RuntimeError(f"{action} failed: invalid result format")
output = data.get("output") if isinstance(data.get("output"), dict) else {}
if not isinstance(output, dict):
raise RuntimeError(f"{action} failed: invalid output format")
error_text = str(data.get("error", "") or result.get("error", "") or "").strip()
if error_text:
raise RuntimeError(f"{action} failed: {error_text}")
text = str(output.get("text", "") or "").strip()
if not text:
raise RuntimeError(f"{action} failed: empty output")
try:
payload = json.loads(text)
except json.JSONDecodeError as exc:
raise RuntimeError(f"{action} failed: invalid JSON output") from exc
if not isinstance(payload, dict):
raise RuntimeError(f"{action} failed: invalid JSON payload")
return payload
async def _probe_local_file(path: str) -> dict[str, str | int]:
def _run() -> dict[str, str | int]:
file_path = Path(path)
with file_path.open("rb") as file_obj:
sample = file_obj.read(_FILE_SNIFF_BYTES)
return {
"size_bytes": file_path.stat().st_size,
"sample_b64": base64.b64encode(sample).decode("utf-8"),
}
return await to_thread(_run)
async def _read_local_image_base64(path: str) -> dict[str, str | int]:
def _run() -> dict[str, str | int]:
data = Path(path).read_bytes()
return {
"size_bytes": len(data),
"base64": base64.b64encode(data).decode("utf-8"),
}
return await to_thread(_run)
async def _read_local_file_bytes(path: str) -> bytes:
return await to_thread(Path(path).read_bytes)
async def _compress_image_bytes_to_base64(data: bytes) -> dict[str, str | int]:
def _run() -> dict[str, str | int]:
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
compressed_path = Path(
_compress_image_sync(
data,
temp_dir,
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
IMAGE_COMPRESS_DEFAULT_QUALITY,
IMAGE_COMPRESS_DEFAULT_OPTIMIZE,
)
)
try:
compressed_bytes = compressed_path.read_bytes()
finally:
compressed_path.unlink(missing_ok=True)
return {
"size_bytes": len(compressed_bytes),
"base64": base64.b64encode(compressed_bytes).decode("utf-8"),
"mime_type": "image/jpeg",
}
return await to_thread(_run)
def _detect_image_mime(sample: bytes) -> str | None:
if sample.startswith(b"\x89PNG\r\n\x1a\n"):
return "image/png"
if sample.startswith(b"\xff\xd8\xff"):
return "image/jpeg"
if sample.startswith((b"GIF87a", b"GIF89a")):
return "image/gif"
if sample.startswith(b"BM"):
return "image/bmp"
if sample.startswith((b"II*\x00", b"MM\x00*")):
return "image/tiff"
if sample.startswith(b"\x00\x00\x01\x00"):
return "image/x-icon"
if len(sample) >= 12 and sample[:4] == b"RIFF" and sample[8:12] == b"WEBP":
return "image/webp"
if len(sample) >= 12 and sample[4:12] in (b"ftypavif", b"ftypavis"):
return "image/avif"
return None
def _looks_like_known_binary(sample: bytes) -> bool:
return any(sample.startswith(prefix) for prefix in _BINARY_MAGIC_PREFIXES)
def _looks_like_pdf(path: str, sample: bytes) -> bool:
return Path(path).suffix.lower() == ".pdf" or sample.startswith(b"%PDF-")
def _looks_like_zip_container(sample: bytes) -> bool:
return any(sample.startswith(prefix) for prefix in _ZIP_MAGIC_PREFIXES)
def _is_docx_bytes(file_bytes: bytes) -> bool:
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as archive:
names = set(archive.namelist())
except (OSError, zipfile.BadZipFile):
return False
if "[Content_Types].xml" not in names:
return False
return any(name.startswith("word/") for name in names)
def _is_epub_bytes(file_bytes: bytes) -> bool:
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as archive:
names = set(archive.namelist())
with archive.open("mimetype") as mimetype_file:
mimetype = mimetype_file.read(64).decode("utf-8").strip()
except (KeyError, OSError, UnicodeDecodeError, zipfile.BadZipFile):
return False
return mimetype == "application/epub+zip" and "META-INF/container.xml" in names
async def _parse_local_docx_text(file_bytes: bytes, file_name: str) -> str:
from astrbot.core.knowledge_base.parsers.markitdown_parser import (
MarkitdownParser,
)
result = await MarkitdownParser().parse(file_bytes, file_name)
return result.text
async def _parse_local_pdf_text(file_bytes: bytes, file_name: str) -> str:
from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser
result = await PDFParser().parse(file_bytes, file_name)
return result.text
async def _parse_local_epub_text(file_bytes: bytes, file_name: str) -> str:
from astrbot.core.knowledge_base.parsers.epub_parser import EpubParser
result = await EpubParser().parse(file_bytes, file_name)
return result.text
async def _parse_local_supported_document(
path: str,
sample: bytes,
) -> ParsedDocument | None:
file_name = Path(path).name
suffix = Path(path).suffix.lower()
if _looks_like_pdf(path, sample):
file_bytes = await _read_local_file_bytes(path)
text = await _parse_local_pdf_text(file_bytes, file_name)
return ParsedDocument(kind="pdf", file_bytes=file_bytes, text=text)
if suffix == ".epub":
file_bytes = await _read_local_file_bytes(path)
if not _is_epub_bytes(file_bytes):
return None
text = await _parse_local_epub_text(file_bytes, file_name)
return ParsedDocument(kind="epub", file_bytes=file_bytes, text=text)
if suffix == ".docx":
file_bytes = await _read_local_file_bytes(path)
if not _is_docx_bytes(file_bytes):
return None
text = await _parse_local_docx_text(file_bytes, file_name)
return ParsedDocument(kind="docx", file_bytes=file_bytes, text=text)
if _looks_like_zip_container(sample):
file_bytes = await _read_local_file_bytes(path)
if _is_epub_bytes(file_bytes):
text = await _parse_local_epub_text(file_bytes, file_name)
return ParsedDocument(kind="epub", file_bytes=file_bytes, text=text)
if _is_docx_bytes(file_bytes):
text = await _parse_local_docx_text(file_bytes, file_name)
return ParsedDocument(kind="docx", file_bytes=file_bytes, text=text)
return None
return None
def _probe_file(sample: bytes, *, size_bytes: int) -> FileProbe:
if image_mime := _detect_image_mime(sample):
return FileProbe(
kind="image",
encoding=None,
mime_type=image_mime,
size_bytes=size_bytes,
)
if _looks_like_known_binary(sample):
return FileProbe(
kind="binary",
encoding=None,
mime_type=None,
size_bytes=size_bytes,
)
if encoding := detect_text_encoding(sample):
return FileProbe(
kind="text",
encoding=encoding,
mime_type="text/plain",
size_bytes=size_bytes,
)
return FileProbe(
kind="binary",
encoding=None,
mime_type=None,
size_bytes=size_bytes,
)
def _validate_text_output(content: str) -> str | None:
content_bytes = len(content.encode("utf-8"))
if content_bytes > _MAX_FILE_READ_BYTES:
return (
"Error reading file: "
f"output exceeds {_MAX_FILE_READ_BYTES} bytes "
f"({content_bytes} bytes). Use `offset`, `limit` to narrow the read window."
)
content_tokens = _TOKEN_COUNTER.count_tokens(
[Message(role="user", content=content)]
)
if content_tokens > _MAX_FILE_READ_TOKENS:
return (
"Error reading file: "
f"output exceeds {_MAX_FILE_READ_TOKENS} tokens "
f"({content_tokens} tokens). Use `offset`, `limit` to narrow the read window."
)
return None
def _text_exceeds_read_thresholds(content: str) -> bool:
return _validate_text_output(content) is not None
def _validate_full_text_read_request(probe: FileProbe) -> str | None:
if probe.size_bytes > _MAX_TEXT_FILE_FULL_READ_BYTES:
return (
"Error reading file: "
f"text file exceeds {_MAX_TEXT_FILE_FULL_READ_BYTES} bytes "
f"({probe.size_bytes} bytes). Use `offset` and `limit` to narrow the read window."
)
return None
def _slice_text_by_lines(
content: str,
*,
offset: int | None,
limit: int | None,
) -> str:
if offset is None and limit is None:
return content
lines = content.splitlines(keepends=True)
start = 0 if offset is None else offset
end = None if limit is None else start + limit
return "".join(lines[start:end])
async def _store_converted_text_for_workspace(
*,
workspace_dir: str,
original_path: str,
original_bytes: bytes,
content: str,
) -> str:
def _run() -> str:
original_name = Path(original_path).name
digest_suffix = hashlib.md5(original_bytes).hexdigest()[-6:]
target_dir = (
Path(workspace_dir) / "converted_files" / f"{original_name}_{digest_suffix}"
)
target_dir.mkdir(parents=True, exist_ok=True)
target_path = target_dir / "text.txt"
target_path.write_text(content, encoding="utf-8")
return str(target_path)
return await to_thread(_run)
def _build_converted_text_notice(
converted_text_path: str,
*,
selection_returned: bool,
selection_too_large: bool = False,
) -> str:
if selection_too_large:
return (
"Converted text was saved to "
f"`{converted_text_path}`. The requested output is still too large to "
"return directly. Read or grep that file with a narrower window."
)
if selection_returned:
return (
"Full converted text is also available at "
f"`{converted_text_path}`. Read or grep that file with a narrow "
"window for additional reads."
)
return (
"Converted text was saved to "
f"`{converted_text_path}` because the parsed document is too large to "
"return directly. Read or grep that file with a narrow window."
)
async def _read_local_supported_document_result(
*,
path: str,
parsed_document: ParsedDocument,
workspace_dir: str | None,
offset: int | None,
limit: int | None,
) -> ToolExecResult:
content = parsed_document.text
if not content:
return "No content found at the requested line offset."
if not _text_exceeds_read_thresholds(content):
selected_content = _slice_text_by_lines(content, offset=offset, limit=limit)
if not selected_content:
return "No content found at the requested line offset."
if validation_error := _validate_text_output(selected_content):
return validation_error
return selected_content
if not workspace_dir:
return (
"Error reading file: parsed document exceeds the read output limit and "
"no workspace is available for storing converted text."
)
converted_text_path = await _store_converted_text_for_workspace(
workspace_dir=workspace_dir,
original_path=path,
original_bytes=parsed_document.file_bytes,
content=content,
)
if offset is None and limit is None:
return _build_converted_text_notice(
converted_text_path,
selection_returned=False,
)
selected_content = _slice_text_by_lines(content, offset=offset, limit=limit)
if not selected_content:
return (
"No content found at the requested line offset. "
+ _build_converted_text_notice(
converted_text_path,
selection_returned=False,
)
)
notice = _build_converted_text_notice(
converted_text_path,
selection_returned=True,
)
combined_output = f"{selected_content}\n\n[{notice}]"
if _validate_text_output(combined_output):
if _validate_text_output(selected_content):
return _build_converted_text_notice(
converted_text_path,
selection_returned=False,
selection_too_large=True,
)
return selected_content
return combined_output
async def read_file_tool_result(
booter: ComputerBooter,
*,
local_mode: bool,
path: str,
offset: int | None,
limit: int | None,
workspace_dir: str | None = None,
) -> ToolExecResult:
if local_mode:
probe_payload = await _probe_local_file(path)
else:
probe_payload = await _exec_python_json(
booter,
_build_probe_script(path),
action="file probe",
)
sample_b64 = str(probe_payload.get("sample_b64", "") or "")
sample = base64.b64decode(sample_b64) if sample_b64 else b""
size_bytes = int(probe_payload.get("size_bytes", 0) or 0)
probe = _probe_file(sample, size_bytes=size_bytes)
if local_mode:
try:
parsed_document = await _parse_local_supported_document(path, sample)
except Exception as exc:
return f"Error reading file: failed to parse document: {exc}"
if parsed_document is not None:
return await _read_local_supported_document_result(
path=path,
parsed_document=parsed_document,
workspace_dir=workspace_dir,
offset=offset,
limit=limit,
)
if probe.kind == "binary":
return "Error reading file: binary files are not supported by this tool."
if probe.kind == "image":
if local_mode:
image_payload = await _read_local_image_base64(path)
else:
image_payload = await _exec_python_json(
booter,
_build_image_read_script(path),
action="image read",
)
raw_base64_data = str(image_payload.get("base64", "") or "")
if not raw_base64_data:
return "Error reading file: image payload is empty."
raw_bytes = base64.b64decode(raw_base64_data)
compressed_payload = await _compress_image_bytes_to_base64(raw_bytes)
compressed_base64_data = str(compressed_payload.get("base64", "") or "")
if not compressed_base64_data:
return "Error reading file: compressed image payload is empty."
return mcp.types.CallToolResult(
content=[
mcp.types.ImageContent(
type="image",
data=compressed_base64_data,
mimeType=str(
compressed_payload.get("mime_type", "") or "image/jpeg"
),
)
]
)
if offset is None and limit is None:
if validation_error := _validate_full_text_read_request(probe):
return validation_error
if local_mode:
content = await read_local_text_range(
path,
encoding=probe.encoding or "utf-8",
offset=offset,
limit=limit,
)
else:
text_payload = await _exec_python_json(
booter,
_build_text_read_script(
path,
encoding=probe.encoding or "utf-8",
offset=offset,
limit=limit,
),
action="text read",
)
content = str(text_payload.get("content", "") or "")
if not content:
return "No content found at the requested line offset."
if validation_error := _validate_text_output(content):
return validation_error
return content

View File

@@ -12,36 +12,8 @@ class FileSystemComponent(Protocol):
"""Create a file with the specified content"""
...
async def read_file(
self,
path: str,
encoding: str = "utf-8",
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Read file content by line window"""
...
async def search_files(
self,
pattern: str,
path: str | None = None,
glob: str | None = None,
after_context: int | None = None,
before_context: int | None = None,
) -> dict[str, Any]:
"""Search file contents"""
...
async def edit_file(
self,
path: str,
old_string: str,
new_string: str,
replace_all: bool = False,
encoding: str = "utf-8",
) -> dict[str, Any]:
"""Edit file content by string replacement"""
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
"""Read file content"""
...
async def write_file(

View File

@@ -1,4 +1,5 @@
from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool
from .fs import FileDownloadTool, FileUploadTool
from .neo_skills import (
AnnotateExecutionTool,
CreateSkillCandidateTool,
@@ -12,20 +13,27 @@ from .neo_skills import (
RollbackSkillReleaseTool,
SyncSkillReleaseTool,
)
from .python import LocalPythonTool, PythonTool
from .shell import ExecuteShellTool
__all__ = [
"AnnotateExecutionTool",
"BrowserBatchExecTool",
"BrowserExecTool",
"CreateSkillCandidateTool",
"CreateSkillPayloadTool",
"EvaluateSkillCandidateTool",
"GetExecutionHistoryTool",
"GetSkillPayloadTool",
"ListSkillCandidatesTool",
"ListSkillReleasesTool",
"PromoteSkillCandidateTool",
"RollbackSkillReleaseTool",
"BrowserBatchExecTool",
"RunBrowserSkillTool",
"GetExecutionHistoryTool",
"AnnotateExecutionTool",
"CreateSkillPayloadTool",
"GetSkillPayloadTool",
"CreateSkillCandidateTool",
"ListSkillCandidatesTool",
"EvaluateSkillCandidateTool",
"PromoteSkillCandidateTool",
"ListSkillReleasesTool",
"RollbackSkillReleaseTool",
"SyncSkillReleaseTool",
"FileUploadTool",
"PythonTool",
"LocalPythonTool",
"ExecuteShellTool",
"FileDownloadTool",
]

View File

@@ -6,20 +6,23 @@ from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.tools.computer_tools.util import check_admin_permission
from astrbot.core.tools.registry import builtin_tool
_SHIPYARD_NEO_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
}
from ..computer_client import get_booter
def _to_json(data: Any) -> str:
return json.dumps(data, ensure_ascii=False, default=str)
def _ensure_admin(context: ContextWrapper[AstrAgentContext]) -> str | None:
if context.context.event.role != "admin":
return (
"error: Permission denied. Browser and skill lifecycle tools are only allowed "
"for admin users."
)
return None
async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> Any:
booter = await get_booter(
context.context.context,
@@ -34,7 +37,6 @@ async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> A
return browser
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class BrowserExecTool(FunctionTool):
name: str = "astrbot_execute_browser"
@@ -75,7 +77,7 @@ class BrowserExecTool(FunctionTool):
learn: bool = False,
include_trace: bool = False,
) -> ToolExecResult:
if err := check_admin_permission(context, "Using browser tools"):
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
@@ -92,7 +94,6 @@ class BrowserExecTool(FunctionTool):
return f"Error executing browser command: {str(e)}"
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class BrowserBatchExecTool(FunctionTool):
name: str = "astrbot_execute_browser_batch"
@@ -139,7 +140,7 @@ class BrowserBatchExecTool(FunctionTool):
learn: bool = False,
include_trace: bool = False,
) -> ToolExecResult:
if err := check_admin_permission(context, "Using browser tools"):
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
@@ -157,7 +158,6 @@ class BrowserBatchExecTool(FunctionTool):
return f"Error executing browser batch command: {str(e)}"
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class RunBrowserSkillTool(FunctionTool):
name: str = "astrbot_run_browser_skill"
@@ -187,7 +187,7 @@ class RunBrowserSkillTool(FunctionTool):
description: str | None = None,
tags: str | None = None,
) -> ToolExecResult:
if err := check_admin_permission(context, "Using browser tools"):
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)

View File

@@ -0,0 +1,204 @@
import os
import uuid
from dataclasses import dataclass, field
from astrbot.api import FunctionTool, logger
from astrbot.api.event import MessageChain
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.components import File
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..computer_client import get_booter
from .permissions import check_admin_permission
# @dataclass
# class CreateFileTool(FunctionTool):
# name: str = "astrbot_create_file"
# description: str = "Create a new file in the sandbox."
# parameters: dict = field(
# default_factory=lambda: {
# "type": "object",
# "properties": {
# "path": {
# "path": "string",
# "description": "The path where the file should be created, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
# },
# "content": {
# "type": "string",
# "description": "The content to write into the file.",
# },
# },
# "required": ["path", "content"],
# }
# )
# async def call(
# self, context: ContextWrapper[AstrAgentContext], path: str, content: str
# ) -> ToolExecResult:
# sb = await get_booter(
# context.context.context,
# context.context.event.unified_msg_origin,
# )
# try:
# result = await sb.fs.create_file(path, content)
# return json.dumps(result)
# except Exception as e:
# return f"Error creating file: {str(e)}"
# @dataclass
# class ReadFileTool(FunctionTool):
# name: str = "astrbot_read_file"
# description: str = "Read the content of a file in the sandbox."
# parameters: dict = field(
# default_factory=lambda: {
# "type": "object",
# "properties": {
# "path": {
# "type": "string",
# "description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.",
# },
# },
# "required": ["path"],
# }
# )
# async def call(self, context: ContextWrapper[AstrAgentContext], path: str):
# sb = await get_booter(
# context.context.context,
# context.context.event.unified_msg_origin,
# )
# try:
# result = await sb.fs.read_file(path)
# return result
# except Exception as e:
# return f"Error reading file: {str(e)}"
@dataclass
class FileUploadTool(FunctionTool):
name: str = "astrbot_upload_file"
description: str = "Upload a local file to the sandbox. The file must exist on the local filesystem."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"local_path": {
"type": "string",
"description": "The local file path to upload. This must be an absolute path to an existing file on the local filesystem.",
},
# "remote_path": {
# "type": "string",
# "description": "The filename to use in the sandbox. If not provided, file will be saved to the working directory with the same name as the local file.",
# },
},
"required": ["local_path"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
local_path: str,
) -> str | None:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
# Check if file exists
if not os.path.exists(local_path):
return f"Error: File does not exist: {local_path}"
if not os.path.isfile(local_path):
return f"Error: Path is not a file: {local_path}"
# Use basename if sandbox_filename is not provided
remote_path = os.path.basename(local_path)
# Upload file to sandbox
result = await sb.upload_file(local_path, remote_path)
logger.debug(f"Upload result: {result}")
success = result.get("success", False)
if not success:
return f"Error uploading file: {result.get('message', 'Unknown error')}"
file_path = result.get("file_path", "")
logger.info(f"File {local_path} uploaded to sandbox at {file_path}")
return f"File uploaded successfully to {file_path}"
except Exception as e:
logger.error(f"Error uploading file {local_path}: {e}")
return f"Error uploading file: {str(e)}"
@dataclass
class FileDownloadTool(FunctionTool):
name: str = "astrbot_download_file"
description: str = "Download a file from the sandbox. Only call this when user explicitly need you to download a file."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"remote_path": {
"type": "string",
"description": "The path of the file in the sandbox to download.",
},
"also_send_to_user": {
"type": "boolean",
"description": "Whether to also send the downloaded file to the user via message. Defaults to true.",
},
},
"required": ["remote_path"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
remote_path: str,
also_send_to_user: bool = True,
) -> ToolExecResult:
if permission_error := check_admin_permission(context, "File upload/download"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
name = os.path.basename(remote_path)
local_path = os.path.join(
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
)
# Download file from sandbox
await sb.download_file(remote_path, local_path)
logger.info(f"File {remote_path} downloaded from sandbox to {local_path}")
if also_send_to_user:
try:
name = os.path.basename(local_path)
await context.context.event.send(
MessageChain(chain=[File(name=name, file=local_path)])
)
except Exception as e:
logger.error(f"Error sending file message: {e}")
# remove
# try:
# os.remove(local_path)
# except Exception as e:
# logger.error(f"Error removing temp file {local_path}: {e}")
return f"File downloaded successfully to {local_path} and sent to user."
return f"File downloaded successfully to {local_path}"
except Exception as e:
logger.error(f"Error downloading file {remote_path}: {e}")
return f"Error downloading file: {str(e)}"

View File

@@ -7,15 +7,9 @@ from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
from astrbot.core.tools.computer_tools.util import check_admin_permission
from astrbot.core.tools.registry import builtin_tool
_SHIPYARD_NEO_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
}
from ..computer_client import get_booter
def _to_jsonable(model_like: Any) -> Any:
@@ -32,6 +26,12 @@ def _to_json_text(data: Any) -> str:
return json.dumps(_to_jsonable(data), ensure_ascii=False, default=str)
def _ensure_admin(context: ContextWrapper[AstrAgentContext]) -> str | None:
if context.context.event.role != "admin":
return "error: Permission denied. Skill lifecycle tools are only allowed for admin users."
return None
async def _get_neo_context(
context: ContextWrapper[AstrAgentContext],
) -> tuple[Any, Any]:
@@ -59,7 +59,7 @@ class NeoSkillToolBase(FunctionTool):
neo_call: Callable[[Any, Any], Awaitable[Any]],
error_action: str,
) -> ToolExecResult:
if err := check_admin_permission(context, "Using skill lifecycle tools"):
if err := _ensure_admin(context):
return err
try:
client, sandbox = await _get_neo_context(context)
@@ -69,7 +69,6 @@ class NeoSkillToolBase(FunctionTool):
return f"{self.error_prefix} {error_action}: {str(e)}"
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class GetExecutionHistoryTool(NeoSkillToolBase):
name: str = "astrbot_get_execution_history"
@@ -116,7 +115,6 @@ class GetExecutionHistoryTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class AnnotateExecutionTool(NeoSkillToolBase):
name: str = "astrbot_annotate_execution"
@@ -154,7 +152,6 @@ class AnnotateExecutionTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class CreateSkillPayloadTool(NeoSkillToolBase):
name: str = "astrbot_create_skill_payload"
@@ -167,10 +164,7 @@ class CreateSkillPayloadTool(NeoSkillToolBase):
"type": "object",
"properties": {
"payload": {
"anyOf": [
{"type": "object"},
{"type": "array", "items": {"type": "object"}},
],
"anyOf": [{"type": "object"}, {"type": "array"}],
"description": (
"Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. "
"This only stores content and returns payload_ref; it does not create a candidate or release."
@@ -202,7 +196,6 @@ class CreateSkillPayloadTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class GetSkillPayloadTool(NeoSkillToolBase):
name: str = "astrbot_get_skill_payload"
@@ -229,7 +222,6 @@ class GetSkillPayloadTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class CreateSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_create_skill_candidate"
@@ -283,7 +275,6 @@ class CreateSkillCandidateTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class ListSkillCandidatesTool(NeoSkillToolBase):
name: str = "astrbot_list_skill_candidates"
@@ -321,7 +312,6 @@ class ListSkillCandidatesTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class EvaluateSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_evaluate_skill_candidate"
@@ -362,7 +352,6 @@ class EvaluateSkillCandidateTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class PromoteSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_promote_skill_candidate"
@@ -400,7 +389,7 @@ class PromoteSkillCandidateTool(NeoSkillToolBase):
stage: str = "canary",
sync_to_local: bool = True,
) -> ToolExecResult:
if err := check_admin_permission(context, "Using skill lifecycle tools"):
if err := _ensure_admin(context):
return err
if stage not in {"canary", "stable"}:
return "Error promoting skill candidate: stage must be canary or stable."
@@ -433,7 +422,6 @@ class PromoteSkillCandidateTool(NeoSkillToolBase):
return f"Error promoting skill candidate: {str(e)}"
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class ListSkillReleasesTool(NeoSkillToolBase):
name: str = "astrbot_list_skill_releases"
@@ -474,7 +462,6 @@ class ListSkillReleasesTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class RollbackSkillReleaseTool(NeoSkillToolBase):
name: str = "astrbot_rollback_skill_release"
@@ -501,7 +488,6 @@ class RollbackSkillReleaseTool(NeoSkillToolBase):
)
@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG)
@dataclass
class SyncSkillReleaseTool(NeoSkillToolBase):
name: str = "astrbot_sync_skill_release"

View File

@@ -1,29 +1,5 @@
import re
from pathlib import Path
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.utils.astrbot_path import get_astrbot_workspaces_path
def normalize_umo_for_workspace(umo: str) -> str:
normalized = re.sub(r"[^A-Za-z0-9._-]+", "_", umo.strip())
return normalized or "unknown"
def workspace_root(umo: str) -> Path:
"""Root directory for relative paths in local runtime"""
normalized_umo = normalize_umo_for_workspace(umo)
return (Path(get_astrbot_workspaces_path()) / normalized_umo).resolve(strict=False)
def is_local_runtime(context: ContextWrapper[AstrAgentContext]) -> bool:
cfg = context.context.context.get_config(
umo=context.context.event.unified_msg_origin
)
provider_settings = cfg.get("provider_settings", {})
runtime = str(provider_settings.get("computer_use_runtime", "local"))
return runtime == "local"
def check_admin_permission(

View File

@@ -8,18 +8,10 @@ from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent
from astrbot.core.computer.computer_client import get_booter, get_local_booter
from astrbot.core.computer.tools.permissions import check_admin_permission
from astrbot.core.message.message_event_result import MessageChain
from ..registry import builtin_tool
from .util import check_admin_permission
_OS_NAME = platform.system()
_SANDBOX_PYTHON_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "sandbox",
}
_LOCAL_PYTHON_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": "local",
}
param_schema = {
"type": "object",
@@ -69,7 +61,6 @@ async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult
return resp
@builtin_tool(config=_SANDBOX_PYTHON_TOOL_CONFIG)
@dataclass
class PythonTool(FunctionTool):
name: str = "astrbot_execute_ipython"
@@ -92,7 +83,6 @@ class PythonTool(FunctionTool):
return f"Error executing code: {str(e)}"
@builtin_tool(config=_LOCAL_PYTHON_TOOL_CONFIG)
@dataclass
class LocalPythonTool(FunctionTool):
name: str = "astrbot_execute_python"

View File

@@ -5,17 +5,11 @@ from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from ..registry import builtin_tool
from .util import check_admin_permission, is_local_runtime, workspace_root
_COMPUTER_RUNTIME_TOOL_CONFIG = {
"provider_settings.computer_use_runtime": ("local", "sandbox"),
}
from ..computer_client import get_booter, get_local_booter
from .permissions import check_admin_permission
@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG)
@dataclass
class ExecuteShellTool(FunctionTool):
name: str = "astrbot_execute_shell"
@@ -44,6 +38,8 @@ class ExecuteShellTool(FunctionTool):
}
)
is_local: bool = False
async def call(
self,
context: ContextWrapper[AstrAgentContext],
@@ -54,25 +50,15 @@ class ExecuteShellTool(FunctionTool):
if permission_error := check_admin_permission(context, "Shell execution"):
return permission_error
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
try:
cwd: str | None = None
if is_local_runtime(context):
current_workspace_root = workspace_root(
context.context.event.unified_msg_origin
)
current_workspace_root.mkdir(parents=True, exist_ok=True)
cwd = str(current_workspace_root)
result = await sb.shell.exec(
command,
cwd=cwd,
background=background,
env=env,
if self.is_local:
sb = get_local_booter()
else:
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
return json.dumps(result, ensure_ascii=False)
try:
result = await sb.shell.exec(command, background=background, env=env)
return json.dumps(result)
except Exception as e:
return f"Error executing command: {str(e)}"

View File

@@ -178,6 +178,4 @@ class AstrBotConfig(dict):
self[key] = value
def check_exist(self) -> bool:
if not self.config_path: # 加判空
return False
return os.path.exists(self.config_path)

View File

@@ -5,40 +5,8 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.23.5"
VERSION = "4.19.5"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
PERSONAL_WECHAT_CONFIG_METADATA = {
"weixin_oc_base_url": {
"description": "Base URL",
"type": "string",
"hint": "默认值: https://ilinkai.weixin.qq.com",
},
"weixin_oc_bot_type": {
"description": "扫码参数 bot_type",
"type": "string",
"hint": "默认值: 3",
},
"weixin_oc_qr_poll_interval": {
"description": "二维码状态轮询间隔(秒)",
"type": "int",
"hint": "每隔多少秒轮询一次二维码状态。",
},
"weixin_oc_long_poll_timeout_ms": {
"description": "getUpdates 长轮询超时时间(毫秒)",
"type": "int",
"hint": "会话消息拉取接口超时参数。",
},
"weixin_oc_api_timeout_ms": {
"description": "HTTP 请求超时(毫秒)",
"type": "int",
"hint": "通用 API 请求超时参数。",
},
"weixin_oc_token": {
"description": "登录后 token可留空",
"type": "string",
"hint": "扫码登录成功后会自动写入;高级场景可手动填写。",
},
}
WEBHOOK_SUPPORTED_PLATFORMS = [
"qq_official_webhook",
@@ -106,10 +74,9 @@ DEFAULT_CONFIG = {
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
"wake_prefix": "",
"web_search": False,
"websearch_provider": "tavily",
"websearch_provider": "default",
"websearch_tavily_key": [],
"websearch_bocha_key": [],
"websearch_brave_key": [],
"websearch_baidu_app_builder_key": "",
"web_search_link": False,
"display_reasoning_text": False,
@@ -134,7 +101,6 @@ DEFAULT_CONFIG = {
"streaming_response": False,
"show_tool_use_status": False,
"show_tool_call_result": False,
"buffer_intermediate_messages": False,
"sanitize_context_by_modalities": False,
"max_quoted_fallback_images": 20,
"quoted_message_parser": {
@@ -151,7 +117,7 @@ DEFAULT_CONFIG = {
"unsupported_streaming_strategy": "realtime_segmenting",
"reachability_check": False,
"max_agent_step": 30,
"tool_call_timeout": 120,
"tool_call_timeout": 60,
"tool_schema_mode": "full",
"llm_safety_mode": True,
"safety_mode_strategy": "system_prompt", # TODO: llm judge
@@ -176,11 +142,6 @@ DEFAULT_CONFIG = {
"shipyard_neo_profile": "python-default",
"shipyard_neo_ttl": 3600,
},
"image_compress_enabled": True,
"image_compress_options": {
"max_size": 1280,
"quality": 95,
},
},
# SubAgent orchestrator mode:
# - main_enable = False: disabled; main LLM mounts tools normally (persona selection).
@@ -403,16 +364,6 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0",
"port": 6198,
},
"个人微信": {
"id": "weixin_personal",
"type": "weixin_oc",
"enable": False,
"weixin_oc_base_url": "https://ilinkai.weixin.qq.com",
"weixin_oc_bot_type": "3",
"weixin_oc_qr_poll_interval": 1,
"weixin_oc_long_poll_timeout_ms": 35_000,
"weixin_oc_api_timeout_ms": 15_000,
},
"飞书(Lark)": {
"id": "lark",
"type": "lark",
@@ -445,7 +396,6 @@ CONFIG_METADATA_2 = {
"telegram_command_register": True,
"telegram_command_auto_refresh": True,
"telegram_command_register_interval": 300,
"telegram_polling_restart_delay": 5.0,
},
"Discord": {
"id": "discord",
@@ -455,7 +405,6 @@ CONFIG_METADATA_2 = {
"discord_proxy": "",
"discord_command_register": True,
"discord_activity_name": "",
"discord_allow_bot_messages": False,
},
"Misskey": {
"id": "misskey",
@@ -509,11 +458,12 @@ CONFIG_METADATA_2 = {
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
"KOOK": {
"kook": {
"id": "kook",
"type": "kook",
"enable": False,
"kook_bot_token": "",
"kook_bot_nickname": "",
"kook_reconnect_delay": 1,
"kook_max_reconnect_delay": 60,
"kook_max_retry_delay": 60,
@@ -522,14 +472,6 @@ CONFIG_METADATA_2 = {
"kook_max_heartbeat_failures": 3,
"kook_max_consecutive_failures": 5,
},
"Mattermost": {
"id": "mattermost",
"type": "mattermost",
"enable": False,
"mattermost_url": "https://chat.example.com",
"mattermost_bot_token": "",
"mattermost_reconnect_delay": 5.0,
},
# "WebChat": {
# "id": "webchat",
# "type": "webchat",
@@ -664,21 +606,6 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
},
"mattermost_url": {
"description": "Mattermost URL",
"type": "string",
"hint": "Mattermost 服务地址,例如 https://chat.example.com。",
},
"mattermost_bot_token": {
"description": "Mattermost Bot Token",
"type": "string",
"hint": "在 Mattermost 中创建 Bot 账户后生成的访问令牌。",
},
"mattermost_reconnect_delay": {
"description": "Mattermost 重连延迟",
"type": "float",
"hint": "WebSocket 断开后的重连等待时间,单位为秒。默认 5 秒。",
},
"misskey_instance_url": {
"description": "Misskey 实例 URL",
"type": "string",
@@ -760,11 +687,6 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "Telegram 命令自动刷新间隔,单位为秒。",
},
"telegram_polling_restart_delay": {
"description": "Telegram 轮询重启延迟",
"type": "float",
"hint": "当轮询意外结束尝试自动重启时的延迟时间,理论上越短恢复越快,但过短(<0.1s)可能导致死循环针对 API 服务器的请求阻断。单位为秒。默认为 5s。",
},
"id": {
"description": "机器人名称",
"type": "string",
@@ -783,7 +705,7 @@ CONFIG_METADATA_2 = {
"appid": {
"description": "appid",
"type": "string",
"hint": "必填项。当前消息平台的 AppID。如何获取请参考对应平台接入文档。",
"hint": "必填项。QQ 官方机器人平台的 appid。如何获取请参考文档。",
},
"secret": {
"description": "secret",
@@ -921,11 +843,6 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
},
"discord_allow_bot_messages": {
"description": "允许接收机器人消息",
"type": "bool",
"hint": "启用后AstrBot 将接收来自其他 Discord 机器人的消息。适用于机器人间通信场景(如消息转发)。默认关闭。",
},
"port": {
"description": "回调服务器端口",
"type": "int",
@@ -947,7 +864,6 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "Webhook 模式下使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}",
},
**PERSONAL_WECHAT_CONFIG_METADATA,
"webhook_uuid": {
"invisible": True,
"description": "Webhook UUID",
@@ -959,6 +875,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "必填项。从 KOOK 开发者平台获取的机器人 Token。",
},
"kook_bot_nickname": {
"description": "Bot Nickname",
"type": "string",
"hint": "可选项。若发送者昵称与此值一致,将忽略该消息以避免广播风暴。",
},
"kook_reconnect_delay": {
"description": "重连延迟",
"type": "int",
@@ -1153,7 +1074,7 @@ CONFIG_METADATA_2 = {
"type": "list",
# provider sources templates
"config_template": {
"OpenAI Compatible": {
"OpenAI": {
"id": "openai",
"provider": "openai",
"type": "openai_chat_completion",
@@ -1197,20 +1118,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
"proxy": "",
"custom_headers": {},
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
},
"Kimi Coding Plan": {
"id": "kimi-code",
"provider": "kimi-code",
"type": "kimi_code_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.kimi.com/coding",
"timeout": 120,
"proxy": "",
"custom_headers": {"User-Agent": "claude-code/0.1.0"},
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
},
"Moonshot": {
@@ -1225,31 +1132,6 @@ CONFIG_METADATA_2 = {
"proxy": "",
"custom_headers": {},
},
"MiniMax": {
"id": "minimax",
"provider": "minimax",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.minimaxi.com/v1",
"timeout": 120,
"proxy": "",
"custom_headers": {},
},
"MiniMax Token Plan": {
"id": "minimax-token-plan",
"provider": "minimax-token-plan",
"type": "minimax_token_plan",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.minimaxi.com/anthropic",
"timeout": 120,
"proxy": "",
"custom_headers": {"User-Agent": "claude-code/0.1.0"},
"anth_thinking_config": {"type": "", "budget": 0, "effort": ""},
},
"xAI": {
"id": "xai",
"provider": "xai",
@@ -1287,18 +1169,6 @@ CONFIG_METADATA_2 = {
"proxy": "",
"custom_headers": {},
},
"LongCat": {
"id": "longcat",
"provider": "longcat",
"type": "longcat_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.longcat.chat/openai",
"timeout": 120,
"proxy": "",
"custom_headers": {},
},
"AIHubMix": {
"id": "aihubmix",
"provider": "aihubmix",
@@ -1358,7 +1228,6 @@ CONFIG_METADATA_2 = {
"api_base": "http://127.0.0.1:11434/v1",
"proxy": "",
"custom_headers": {},
"ollama_disable_thinking": False,
},
"LM Studio": {
"id": "lm_studio",
@@ -1556,20 +1425,6 @@ CONFIG_METADATA_2 = {
"model": "whisper-1",
"proxy": "",
},
"MiMo STT(API)": {
"id": "mimo_stt",
"provider": "mimo",
"type": "mimo_stt_api",
"provider_type": "speech_to_text",
"enable": False,
"api_key": "",
"api_base": "https://api.xiaomimimo.com/v1",
"model": "mimo-v2-omni",
"mimo-stt-system-prompt": "You are a speech transcription assistant. Transcribe the spoken content from the audio exactly and return only the transcription text.",
"mimo-stt-user-prompt": "Please transcribe the content of the audio and return only the transcription text.",
"timeout": "20",
"proxy": "",
},
"Whisper(Local)": {
"provider": "openai",
"type": "openai_whisper_selfhost",
@@ -1577,7 +1432,6 @@ CONFIG_METADATA_2 = {
"enable": False,
"id": "whisper_selfhost",
"model": "tiny",
"whisper_device": "cpu",
},
"SenseVoice(Local)": {
"type": "sensevoice_stt_selfhost",
@@ -1601,23 +1455,6 @@ CONFIG_METADATA_2 = {
"timeout": "20",
"proxy": "",
},
"MiMo TTS(API)": {
"id": "mimo_tts",
"type": "mimo_tts_api",
"provider": "mimo",
"provider_type": "text_to_speech",
"enable": False,
"api_key": "",
"api_base": "https://api.xiaomimimo.com/v1",
"model": "mimo-v2-tts",
"mimo-tts-voice": "mimo_default",
"mimo-tts-format": "wav",
"mimo-tts-style-prompt": "",
"mimo-tts-dialect": "",
"mimo-tts-seed-text": "Hello, MiMo, have you had lunch?",
"timeout": "20",
"proxy": "",
},
"Genie TTS": {
"id": "genie_tts",
"provider": "genie_tts",
@@ -1680,14 +1517,10 @@ CONFIG_METADATA_2 = {
"type": "gsvi_tts_api",
"provider": "gpt_sovits_inference",
"provider_type": "text_to_speech",
"enable": False,
"api_key": "",
"api_base": "http://127.0.0.1:8000",
"version": "v4",
"api_base": "http://127.0.0.1:5000",
"character": "",
"prompt_text_lang": "中文",
"emotion": "默认",
"text_lang": "中文",
"emotion": "default",
"enable": False,
"timeout": 20,
},
"FishAudio TTS(API)": {
@@ -1818,7 +1651,6 @@ CONFIG_METADATA_2 = {
"enable": True,
"rerank_api_key": "",
"rerank_api_base": "http://127.0.0.1:8000",
"rerank_api_suffix": "/v1/rerank",
"rerank_model": "BAAI/bge-reranker-base",
"timeout": 20,
},
@@ -1847,19 +1679,6 @@ CONFIG_METADATA_2 = {
"return_documents": False,
"instruct": "",
},
"NVIDIA Rerank": {
"id": "nvidia_rerank",
"type": "nvidia_rerank",
"provider": "nvidia",
"provider_type": "rerank",
"enable": True,
"nvidia_rerank_api_key": "",
"nvidia_rerank_api_base": "https://ai.api.nvidia.com/v1/retrieval",
"nvidia_rerank_model": "nv-rerank-qa-mistral-4b:1",
"nvidia_rerank_model_endpoint": "/reranking",
"timeout": 20,
"nvidia_rerank_truncate": "",
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
@@ -1897,12 +1716,7 @@ CONFIG_METADATA_2 = {
"rerank_api_base": {
"description": "重排序模型 API Base URL",
"type": "string",
"hint": "最终请求路径由 Base URL 和路径后缀拼接而成(默认为 /v1/rerank",
},
"rerank_api_suffix": {
"description": "API URL 路径后缀",
"type": "string",
"hint": "追加到 base_url 后的路径,如 /v1/rerank。留空则不追加。",
"hint": "AstrBot 会在请求时在末尾加上 /v1/rerank。",
},
"rerank_api_key": {
"description": "API Key",
@@ -1928,40 +1742,12 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
},
"nvidia_rerank_api_base": {
"description": "API Base URL",
"type": "string",
},
"nvidia_rerank_api_key": {
"description": "API Key",
"type": "string",
},
"nvidia_rerank_model": {
"description": "重排序模型名称",
"type": "string",
"hint": "请参照NVIDIA Docs中模型名称填写。",
},
"nvidia_rerank_model_endpoint": {
"description": "自定义模型端点",
"type": "string",
"hint": "自定义URL末尾端点默认为 /reranking",
},
"nvidia_rerank_truncate": {
"description": "文本截断策略",
"type": "string",
"hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。",
"options": [
"",
"NONE",
"END",
],
},
"modalities": {
"description": "模型能力",
"type": "list",
"items": {"type": "string"},
"options": ["text", "image", "audio", "tool_use"],
"labels": ["文本", "图像", "音频", "工具使用"],
"options": ["text", "image", "tool_use"],
"labels": ["文本", "图像", "工具使用"],
"render_type": "checkbox",
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
},
@@ -1971,11 +1757,6 @@ CONFIG_METADATA_2 = {
"items": {},
"hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。",
},
"ollama_disable_thinking": {
"description": "关闭思考模式",
"type": "bool",
"hint": "关闭 Ollama 思考模式。",
},
"custom_extra_body": {
"description": "自定义请求体参数",
"type": "dict",
@@ -2522,46 +2303,11 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "超时时间,单位为秒。",
},
"mimo-stt-system-prompt": {
"description": "系统提示词",
"type": "string",
"hint": "用于指导 MiMo STT 转录行为的 system prompt。",
},
"mimo-stt-user-prompt": {
"description": "用户提示词",
"type": "string",
"hint": "附加给 MiMo STT 的用户提示词,用于约束返回结果格式。",
},
"openai-tts-voice": {
"description": "voice",
"type": "string",
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
},
"mimo-tts-voice": {
"description": "音色",
"type": "string",
"hint": "MiMo TTS 的音色名称。可选值包括 'mimo_default''default_en''default_zh'",
},
"mimo-tts-format": {
"description": "输出格式",
"type": "string",
"hint": "MiMo TTS 生成音频的格式。支持 'wav''mp3''pcm'",
},
"mimo-tts-style-prompt": {
"description": "风格提示词",
"type": "string",
"hint": "会以 <style>...</style> 标签形式添加到待合成文本开头,用于控制语速、情绪、角色或风格,例如 开心、变快、孙悟空、悄悄话。可留空。",
},
"mimo-tts-dialect": {
"description": "方言",
"type": "string",
"hint": "会与风格提示词一起写入开头的 <style>...</style> 标签中,例如 东北话、四川话、河南话、粤语。可留空。",
},
"mimo-tts-seed-text": {
"description": "种子文本",
"type": "string",
"hint": "作为可选的 user 消息发送,用于辅助调节语气和风格,不会拼接到待合成文本中。",
},
"fishaudio-tts-character": {
"description": "character",
"type": "string",
@@ -2577,12 +2323,6 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "启用前请 pip 安装 openai-whisper 库N卡用户大约下载 2GB主要是 torch 和 cudaCPU 用户大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。",
},
"whisper_device": {
"description": "推理设备",
"type": "string",
"hint": "Whisper 推理设备。Apple Silicon 可选 mps其他环境建议使用 cpu。若指定 mps 但当前环境不可用,将自动回退到 cpu。",
"options": ["cpu", "mps"],
},
"id": {
"description": "ID",
"type": "string",
@@ -2685,12 +2425,12 @@ CONFIG_METADATA_2 = {
"deerflow_assistant_id": {
"description": "Assistant ID",
"type": "string",
"hint": "DeerFlow 2.0 LangGraph assistant_id默认为 lead_agent。",
"hint": "LangGraph assistant_id默认为 lead_agent。",
},
"deerflow_model_name": {
"description": "模型名称覆盖",
"type": "string",
"hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name",
"hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name",
},
"deerflow_thinking_enabled": {
"description": "启用思考模式",
@@ -2699,17 +2439,17 @@ CONFIG_METADATA_2 = {
"deerflow_plan_mode": {
"description": "启用计划模式",
"type": "bool",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。",
"hint": "对应 DeerFlow 的 is_plan_mode。",
},
"deerflow_subagent_enabled": {
"description": "启用子智能体",
"type": "bool",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。",
"hint": "对应 DeerFlow 的 subagent_enabled。",
},
"deerflow_max_concurrent_subagents": {
"description": "子智能体最大并发数",
"type": "int",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效默认 3。",
"hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效默认 3。",
},
"deerflow_recursion_limit": {
"description": "递归深度上限",
@@ -2778,9 +2518,6 @@ CONFIG_METADATA_2 = {
"show_tool_call_result": {
"type": "bool",
},
"buffer_intermediate_messages": {
"type": "bool",
},
"unsupported_streaming_strategy": {
"type": "string",
},
@@ -3197,13 +2934,7 @@ CONFIG_METADATA_3 = {
"provider_settings.websearch_provider": {
"description": "网页搜索提供商",
"type": "string",
"options": [
"tavily",
"baidu_ai_search",
"bocha",
"brave",
"firecrawl",
],
"options": ["default", "tavily", "baidu_ai_search", "bocha"],
"condition": {
"provider_settings.web_search": True,
},
@@ -3228,26 +2959,6 @@ CONFIG_METADATA_3 = {
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_brave_key": {
"description": "Brave Search API Key",
"type": "list",
"items": {"type": "string"},
"hint": "可添加多个 Key 进行轮询。",
"condition": {
"provider_settings.websearch_provider": "brave",
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_firecrawl_key": {
"description": "Firecrawl API Key",
"type": "list",
"items": {"type": "string"},
"hint": "可添加多个 Key 进行轮询。",
"condition": {
"provider_settings.websearch_provider": "firecrawl",
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_baidu_app_builder_key": {
"description": "百度千帆智能云 APP Builder API Key",
"type": "string",
@@ -3558,15 +3269,6 @@ CONFIG_METADATA_3 = {
"provider_settings.show_tool_use_status": True,
},
},
"provider_settings.buffer_intermediate_messages": {
"description": "合并 Agent 中间消息",
"type": "bool",
"hint": "开启后,非流式模式下多步工具调用过程中产生的中间文本将缓冲,待 Agent 完成后合并为一条回复发送。",
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.streaming_response": False,
},
},
"provider_settings.sanitize_context_by_modalities": {
"description": "按模型能力清理历史上下文",
"type": "bool",
@@ -3609,39 +3311,14 @@ CONFIG_METADATA_3 = {
"type": "string",
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
},
"provider_settings.image_compress_enabled": {
"description": "启用图片压缩",
"type": "bool",
"hint": "启用后,发送给多模态模型前会先压缩本地大图片。",
},
"provider_settings.image_compress_options.max_size": {
"description": "最大边长",
"type": "int",
"hint": "压缩后图片的最长边,单位为像素。超过该尺寸时会按比例缩放。",
"condition": {
"provider_settings.image_compress_enabled": True,
},
"slider": {"min": 256, "max": 4096, "step": 64},
},
"provider_settings.image_compress_options.quality": {
"description": "压缩质量",
"type": "int",
"hint": "JPEG 输出质量,范围为 1-100。值越高画质越好文件也越大。",
"condition": {
"provider_settings.image_compress_enabled": True,
},
"slider": {"min": 1, "max": 100, "step": 1},
},
"provider_tts_settings.dual_output": {
"description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool",
"collapsed": True,
},
"provider_settings.reachability_check": {
"description": "提供商可达性检测",
"type": "bool",
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
"collapsed": True,
},
"provider_settings.max_quoted_fallback_images": {
"description": "引用图片回退解析上限",
@@ -3650,7 +3327,6 @@ CONFIG_METADATA_3 = {
"condition": {
"provider_settings.agent_runner_type": "local",
},
"collapsed": True,
},
"provider_settings.quoted_message_parser.max_component_chain_depth": {
"description": "引用解析组件链深度",
@@ -3659,7 +3335,6 @@ CONFIG_METADATA_3 = {
"condition": {
"provider_settings.agent_runner_type": "local",
},
"collapsed": True,
},
"provider_settings.quoted_message_parser.max_forward_node_depth": {
"description": "引用解析转发节点深度",
@@ -3668,7 +3343,6 @@ CONFIG_METADATA_3 = {
"condition": {
"provider_settings.agent_runner_type": "local",
},
"collapsed": True,
},
"provider_settings.quoted_message_parser.max_forward_fetch": {
"description": "引用解析转发拉取上限",
@@ -3677,7 +3351,6 @@ CONFIG_METADATA_3 = {
"condition": {
"provider_settings.agent_runner_type": "local",
},
"collapsed": True,
},
"provider_settings.quoted_message_parser.warn_on_action_failure": {
"description": "引用解析 action 失败告警",
@@ -3686,7 +3359,6 @@ CONFIG_METADATA_3 = {
"condition": {
"provider_settings.agent_runner_type": "local",
},
"collapsed": True,
},
},
"condition": {
@@ -3761,7 +3433,7 @@ CONFIG_METADATA_3 = {
"description": "白名单 ID 列表",
"type": "list",
"items": {"type": "string"},
"hint": "使用 /sid 获取 ID。当白名单列表为空时,代表不启用白名单(即所有 ID 都在白名单内)。",
"hint": "使用 /sid 获取 ID。",
},
"platform_settings.id_whitelist_log": {
"description": "输出日志",
@@ -4188,9 +3860,9 @@ CONFIG_METADATA_3_SYSTEM = {
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
},
"http_proxy": {
"description": "代理",
"description": "HTTP 代理",
"type": "string",
"hint": "启用后,会以添加环境变量的方式设置代理。支持 http://、https://、socks5:// 格式例如http://127.0.0.1:7890 或 socks5://127.0.0.1:7891",
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
},
"no_proxy": {
"description": "直连地址列表",

View File

@@ -22,12 +22,6 @@ if TYPE_CHECKING:
from astrbot.core.star.context import Context
class CronJobSchedulingError(Exception):
"""Raised when a cron job fails to be scheduled."""
pass
class CronJobManager:
"""Central scheduler for BasicCronJob and ActiveAgentCronJob."""
@@ -65,10 +59,7 @@ class CronJobManager:
job.job_id,
)
continue
try:
self._schedule_job(job)
except CronJobSchedulingError:
continue # Error already logged in _schedule_job
self._schedule_job(job)
async def add_basic_job(
self,
@@ -190,15 +181,12 @@ class CronJobManager:
job.job_id, next_run_time=self._get_next_run_time(job.job_id)
)
)
except (ValueError, TypeError) as e:
logger.exception("Failed to schedule cron job %s", job.job_id)
raise CronJobSchedulingError(str(e)) from e
except Exception as e:
logger.error(f"Failed to schedule cron job {job.job_id}: {e!s}")
def _get_next_run_time(self, job_id: str):
aps_job = self.scheduler.get_job(job_id)
if not aps_job or aps_job.next_run_time is None:
return None
return aps_job.next_run_time.astimezone(timezone.utc)
return aps_job.next_run_time if aps_job else None
async def _run_job(self, job_id: str) -> None:
job = await self.db.get_cron_job(job_id)
@@ -287,8 +275,8 @@ class CronJobManager:
)
from astrbot.core.astr_main_agent_resources import (
PROACTIVE_AGENT_CRON_WOKE_SYSTEM_PROMPT,
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.tools.message_tools import SendMessageToUserTool
try:
session = (
@@ -319,11 +307,8 @@ class CronJobManager:
if cron_payload.get("origin", "tool") == "api":
cron_event.role = "admin"
tool_call_timeout = cfg.get("provider_settings", {}).get(
"tool_call_timeout", 120
)
config = MainAgentBuildConfig(
tool_call_timeout=tool_call_timeout,
tool_call_timeout=3600,
llm_safety_mode=False,
streaming_response=False,
)
@@ -347,16 +332,14 @@ class CronJobManager:
cron_job=cron_job_str
)
req.prompt = (
"You are now responding to a scheduled task. "
"You are now responding to a scheduled task"
"Proceed according to your system instructions. "
"Output using same language as previous conversation. "
"Output using same language as previous conversation."
"After completing your task, summarize and output your actions and results."
)
if not req.func_tool:
req.func_tool = ToolSet()
req.func_tool.add_tool(
self.ctx.get_llm_tool_manager().get_builtin_tool(SendMessageToUserTool)
)
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)
result = await build_main_agent(
event=cron_event, plugin_context=self.ctx, config=config, req=req

View File

@@ -21,10 +21,8 @@ from astrbot.core.db.po import (
PlatformSession,
PlatformStat,
Preference,
ProviderStat,
SessionProjectRelation,
Stats,
WebChatThread,
)
@@ -35,18 +33,10 @@ class BaseDatabase(abc.ABC):
DATABASE_URL = ""
def __init__(self) -> None:
# SQLite only supports a single writer at a time. Without a busy
# timeout the driver raises "database is locked" instantly when a
# second write is attempted. Setting timeout=30 tells SQLite to
# wait up to 30 s for the lock, which is enough to ride out brief
# write bursts from concurrent agent/metrics/session operations.
is_sqlite = "sqlite" in self.DATABASE_URL
connect_args = {"timeout": 30} if is_sqlite else {}
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
connect_args=connect_args,
)
self.AsyncSessionLocal = async_sessionmaker(
self.engine,
@@ -107,21 +97,6 @@ class BaseDatabase(abc.ABC):
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
...
@abc.abstractmethod
async def insert_provider_stat(
self,
*,
umo: str,
provider_id: str,
provider_model: str | None = None,
conversation_id: str | None = None,
status: str = "completed",
stats: dict | None = None,
agent_type: str = "internal",
) -> ProviderStat:
"""Insert a per-response provider stat record."""
...
@abc.abstractmethod
async def get_conversations(
self,
@@ -205,26 +180,10 @@ class BaseDatabase(abc.ABC):
content: dict,
sender_id: str | None = None,
sender_name: str | None = None,
llm_checkpoint_id: str | None = None,
) -> PlatformMessageHistory:
"""Insert a new platform message history record."""
...
@abc.abstractmethod
async def update_platform_message_history(
self,
message_id: int,
content: dict | None = None,
llm_checkpoint_id: str | None = None,
) -> None:
"""Update a platform message history record."""
...
@abc.abstractmethod
async def delete_platform_message_history_by_id(self, message_id: int) -> None:
"""Delete a platform message history record by its ID."""
...
@abc.abstractmethod
async def delete_platform_message_offset(
self,
@@ -254,68 +213,6 @@ class BaseDatabase(abc.ABC):
"""Get a platform message history record by its ID."""
...
@abc.abstractmethod
async def create_webchat_thread(
self,
creator: str,
parent_session_id: str,
parent_message_id: int,
base_checkpoint_id: str,
selected_text: str,
) -> WebChatThread:
"""Create a WebChat side thread."""
...
@abc.abstractmethod
async def get_webchat_thread_by_id(
self,
thread_id: str,
) -> WebChatThread | None:
"""Get a WebChat side thread by thread_id."""
...
@abc.abstractmethod
async def get_webchat_threads_by_parent_session(
self,
parent_session_id: str,
creator: str | None = None,
) -> list[WebChatThread]:
"""Get side threads for a parent WebChat session."""
...
@abc.abstractmethod
async def get_webchat_thread_by_parent_message_and_text(
self,
parent_session_id: str,
parent_message_id: int,
selected_text: str,
creator: str | None = None,
) -> WebChatThread | None:
"""Get an existing side thread for the same selected text."""
...
@abc.abstractmethod
async def delete_webchat_thread(self, thread_id: str) -> None:
"""Delete a WebChat side thread."""
...
@abc.abstractmethod
async def delete_webchat_threads_by_parent_session(
self,
parent_session_id: str,
) -> list[str]:
"""Delete side threads for a parent WebChat session."""
...
@abc.abstractmethod
async def delete_webchat_threads_by_parent_message_ids(
self,
parent_session_id: str,
parent_message_ids: list[int],
) -> list[str]:
"""Delete side threads linked to parent message IDs."""
...
@abc.abstractmethod
async def insert_attachment(
self,
@@ -750,13 +647,6 @@ class BaseDatabase(abc.ABC):
"""Get a Platform session by its ID."""
...
@abc.abstractmethod
async def get_platform_sessions_by_ids(
self, session_ids: list[str]
) -> list[PlatformSession]:
"""Get platform sessions by IDs."""
...
@abc.abstractmethod
async def get_platform_sessions_by_creator(
self,

View File

@@ -38,30 +38,6 @@ class PlatformStat(SQLModel, table=True):
)
class ProviderStat(TimestampMixin, SQLModel, table=True):
"""Per-response provider stats for internal agent runs."""
__tablename__: str = "provider_stats"
id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
agent_type: str = Field(default="internal", nullable=False, index=True)
status: str = Field(default="completed", nullable=False, index=True)
umo: str = Field(nullable=False, index=True)
conversation_id: str | None = Field(default=None, index=True)
provider_id: str = Field(nullable=False, index=True)
provider_model: str | None = Field(default=None, index=True)
token_input_other: int = Field(default=0, nullable=False)
token_input_cached: int = Field(default=0, nullable=False)
token_output: int = Field(default=0, nullable=False)
start_time: float = Field(default=0.0, nullable=False)
end_time: float = Field(default=0.0, nullable=False)
time_to_first_token: float = Field(default=0.0, nullable=False)
class ConversationV2(TimestampMixin, SQLModel, table=True):
__tablename__: str = "conversations"
@@ -244,37 +220,6 @@ class PlatformMessageHistory(TimestampMixin, SQLModel, table=True):
default=None,
) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
llm_checkpoint_id: str | None = Field(default=None, index=True)
class WebChatThread(TimestampMixin, SQLModel, table=True):
"""A side thread created from a selected WebChat assistant response."""
__tablename__: str = "webchat_threads"
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
thread_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
creator: str = Field(nullable=False, index=True)
parent_session_id: str = Field(nullable=False, index=True)
parent_message_id: int = Field(nullable=False, index=True)
base_checkpoint_id: str = Field(nullable=False, index=True)
selected_text: str = Field(sa_type=Text, nullable=False)
__table_args__ = (
UniqueConstraint(
"thread_id",
name="uix_webchat_thread_id",
),
)
class PlatformSession(TimestampMixin, SQLModel, table=True):

View File

@@ -23,10 +23,8 @@ from astrbot.core.db.po import (
PlatformSession,
PlatformStat,
Preference,
ProviderStat,
SessionProjectRelation,
SQLModel,
WebChatThread,
)
from astrbot.core.db.po import (
Platform as DeprecatedPlatformStat,
@@ -61,7 +59,6 @@ class SQLiteDatabase(BaseDatabase):
await self._ensure_persona_folder_columns(conn)
await self._ensure_persona_skills_column(conn)
await self._ensure_persona_custom_error_message_column(conn)
await self._ensure_platform_message_history_checkpoint_column(conn)
await conn.commit()
async def _ensure_persona_folder_columns(self, conn) -> None:
@@ -106,26 +103,6 @@ class SQLiteDatabase(BaseDatabase):
text("ALTER TABLE personas ADD COLUMN custom_error_message TEXT")
)
async def _ensure_platform_message_history_checkpoint_column(self, conn) -> None:
"""Ensure platform_message_history has llm_checkpoint_id."""
result = await conn.execute(text("PRAGMA table_info(platform_message_history)"))
columns = {row[1] for row in result.fetchall()}
if "llm_checkpoint_id" not in columns:
await conn.execute(
text(
"ALTER TABLE platform_message_history "
"ADD COLUMN llm_checkpoint_id VARCHAR DEFAULT NULL"
)
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS "
"ix_platform_message_history_llm_checkpoint_id "
"ON platform_message_history (llm_checkpoint_id)"
)
)
# ====
# Platform Statistics
# ====
@@ -192,51 +169,6 @@ class SQLiteDatabase(BaseDatabase):
)
return list(result.scalars().all())
async def insert_provider_stat(
self,
*,
umo: str,
provider_id: str,
provider_model: str | None = None,
conversation_id: str | None = None,
status: str = "completed",
stats: dict | None = None,
agent_type: str = "internal",
) -> ProviderStat:
"""Insert a provider stat record for a single agent response."""
stats = stats or {}
token_usage = stats.get("token_usage", {})
token_input_other = int(token_usage.get("input_other", 0) or 0)
token_input_cached = int(token_usage.get("input_cached", 0) or 0)
token_output = int(token_usage.get("output", 0) or 0)
start_time = float(stats.get("start_time", 0.0) or 0.0)
end_time = float(stats.get("end_time", 0.0) or 0.0)
time_to_first_token = float(stats.get("time_to_first_token", 0.0) or 0.0)
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
record = ProviderStat(
agent_type=agent_type,
status=status,
umo=umo,
conversation_id=conversation_id,
provider_id=provider_id,
provider_model=provider_model,
token_input_other=token_input_other,
token_input_cached=token_input_cached,
token_output=token_output,
start_time=start_time,
end_time=end_time,
time_to_first_token=time_to_first_token,
)
session.add(record)
await session.flush()
await session.refresh(record)
return record
# ====
# Conversation Management
# ====
@@ -521,7 +453,6 @@ class SQLiteDatabase(BaseDatabase):
content,
sender_id=None,
sender_name=None,
llm_checkpoint_id=None,
):
"""Insert a new platform message history record."""
async with self.get_db() as session:
@@ -533,46 +464,10 @@ class SQLiteDatabase(BaseDatabase):
content=content,
sender_id=sender_id,
sender_name=sender_name,
llm_checkpoint_id=llm_checkpoint_id,
)
session.add(new_history)
return new_history
async def update_platform_message_history(
self,
message_id: int,
content: dict | None = None,
llm_checkpoint_id: str | None = None,
) -> None:
"""Update a platform message history record."""
values = {}
if content is not None:
values["content"] = content
if llm_checkpoint_id is not None:
values["llm_checkpoint_id"] = llm_checkpoint_id
if not values:
return
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
update(PlatformMessageHistory)
.where(PlatformMessageHistory.id == message_id)
.values(**values)
)
async def delete_platform_message_history_by_id(self, message_id: int) -> None:
"""Delete a platform message history record by ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(PlatformMessageHistory).where(
PlatformMessageHistory.id == message_id
)
)
async def delete_platform_message_offset(
self,
platform_id,
@@ -627,136 +522,6 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalar_one_or_none()
async def create_webchat_thread(
self,
creator: str,
parent_session_id: str,
parent_message_id: int,
base_checkpoint_id: str,
selected_text: str,
) -> WebChatThread:
"""Create a WebChat side thread."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
thread = WebChatThread(
creator=creator,
parent_session_id=parent_session_id,
parent_message_id=parent_message_id,
base_checkpoint_id=base_checkpoint_id,
selected_text=selected_text,
)
session.add(thread)
await session.flush()
await session.refresh(thread)
return thread
async def get_webchat_thread_by_id(
self,
thread_id: str,
) -> WebChatThread | None:
"""Get a WebChat side thread by thread_id."""
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(WebChatThread).where(WebChatThread.thread_id == thread_id)
)
return result.scalar_one_or_none()
async def get_webchat_threads_by_parent_session(
self,
parent_session_id: str,
creator: str | None = None,
) -> list[WebChatThread]:
"""Get side threads for a parent WebChat session."""
async with self.get_db() as session:
session: AsyncSession
query = select(WebChatThread).where(
WebChatThread.parent_session_id == parent_session_id
)
if creator is not None:
query = query.where(WebChatThread.creator == creator)
query = query.order_by(WebChatThread.created_at)
result = await session.execute(query)
return list(result.scalars().all())
async def get_webchat_thread_by_parent_message_and_text(
self,
parent_session_id: str,
parent_message_id: int,
selected_text: str,
creator: str | None = None,
) -> WebChatThread | None:
"""Get an existing side thread for the same selected text."""
async with self.get_db() as session:
session: AsyncSession
query = select(WebChatThread).where(
WebChatThread.parent_session_id == parent_session_id,
WebChatThread.parent_message_id == parent_message_id,
WebChatThread.selected_text == selected_text,
)
if creator is not None:
query = query.where(WebChatThread.creator == creator)
result = await session.execute(query)
return result.scalar_one_or_none()
async def delete_webchat_thread(self, thread_id: str) -> None:
"""Delete a WebChat side thread."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(WebChatThread).where(WebChatThread.thread_id == thread_id)
)
async def delete_webchat_threads_by_parent_session(
self,
parent_session_id: str,
) -> list[str]:
"""Delete side threads for a parent WebChat session."""
threads = await self.get_webchat_threads_by_parent_session(parent_session_id)
thread_ids = [thread.thread_id for thread in threads]
if not thread_ids:
return []
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(WebChatThread).where(
col(WebChatThread.thread_id).in_(thread_ids)
)
)
return thread_ids
async def delete_webchat_threads_by_parent_message_ids(
self,
parent_session_id: str,
parent_message_ids: list[int],
) -> list[str]:
"""Delete side threads linked to parent message IDs."""
if not parent_message_ids:
return []
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(WebChatThread.thread_id).where(
WebChatThread.parent_session_id == parent_session_id,
col(WebChatThread.parent_message_id).in_(parent_message_ids),
)
)
thread_ids = list(result.scalars().all())
if not thread_ids:
return []
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(WebChatThread).where(
col(WebChatThread.thread_id).in_(thread_ids)
)
)
return thread_ids
async def insert_attachment(self, path, type, mime_type):
"""Insert a new attachment record."""
async with self.get_db() as session:
@@ -1652,21 +1417,6 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_platform_sessions_by_ids(
self, session_ids: list[str]
) -> list[PlatformSession]:
"""Get platform sessions by IDs."""
if not session_ids:
return []
async with self.get_db() as session:
session: AsyncSession
query = select(PlatformSession).where(
col(PlatformSession.session_id).in_(session_ids)
)
result = await session.execute(query)
return list(result.scalars().all())
async def get_platform_sessions_by_creator(
self,
creator: str,

View File

@@ -2,22 +2,13 @@ import json
import os
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from sqlalchemy import Column, Text, bindparam
from sqlalchemy import Column, Text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, MetaData, SQLModel, col, func, select, text
from astrbot.core import logger
from astrbot.core.knowledge_base.retrieval.tokenizer import (
build_fts5_or_query,
load_stopwords,
to_fts5_search_text,
)
FTS_TABLE_NAME = "documents_fts"
FTS_REBUILD_BATCH_SIZE = 1000
class BaseDocModel(SQLModel, table=False):
@@ -51,10 +42,6 @@ class DocumentStorage:
os.path.dirname(__file__),
"sqlite_init.sql",
)
self.fts5_available = False
self._fts_contentless_delete = False
self._fts_index_ready = False
self._stopwords: set[str] | None = None
async def initialize(self) -> None:
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
@@ -91,105 +78,8 @@ class DocumentStorage:
except BaseException:
pass
await self._initialize_fts5(conn)
await conn.commit()
async def _initialize_fts5(self, executor) -> None:
try:
await self._create_fts5_table(executor, if_not_exists=True)
is_valid_fts5, has_contentless_delete = await self._inspect_fts5_table(
executor,
)
if not is_valid_fts5:
logger.warning(
f"Detected incompatible legacy table `{FTS_TABLE_NAME}` in "
f"{self.db_path}; recreating FTS5 table.",
)
await executor.execute(text(f"DROP TABLE IF EXISTS {FTS_TABLE_NAME}"))
await self._create_fts5_table(executor, if_not_exists=False)
is_valid_fts5, has_contentless_delete = await self._inspect_fts5_table(
executor,
)
if not is_valid_fts5:
raise RuntimeError(
f"Failed to create a valid FTS5 table `{FTS_TABLE_NAME}`",
)
self.fts5_available = True
self._fts_contentless_delete = has_contentless_delete
except Exception as e:
self.fts5_available = False
self._fts_contentless_delete = False
logger.warning(
f"SQLite FTS5 is unavailable for document storage {self.db_path}; "
f"falling back to in-memory BM25 sparse retrieval: {e}",
)
async def _create_fts5_table(self, executor, if_not_exists: bool) -> None:
create_clause = (
"CREATE VIRTUAL TABLE IF NOT EXISTS"
if if_not_exists
else "CREATE VIRTUAL TABLE"
)
try:
await executor.execute(
text(
f"""
{create_clause} {FTS_TABLE_NAME}
USING fts5(
search_text,
content='',
contentless_delete=1,
tokenize='unicode61'
)
""",
),
)
except Exception:
await executor.execute(
text(
f"""
{create_clause} {FTS_TABLE_NAME}
USING fts5(
search_text,
content='',
tokenize='unicode61'
)
""",
),
)
async def _inspect_fts5_table(self, executor) -> tuple[bool, bool]:
schema_result = await executor.execute(
text(
"""
SELECT sql
FROM sqlite_master
WHERE type='table' AND name=:table_name
""",
),
{"table_name": FTS_TABLE_NAME},
)
create_sql = schema_result.scalar_one_or_none()
if not create_sql:
return False, False
normalized_sql = create_sql.lower()
if "virtual table" not in normalized_sql or "using fts5" not in normalized_sql:
return False, False
pragma_result = await executor.execute(
text(f"PRAGMA table_info({FTS_TABLE_NAME})"),
)
columns = {row[1] for row in pragma_result.fetchall()}
if "search_text" not in columns:
return False, False
normalized_sql_no_whitespace = "".join(normalized_sql.split())
return True, "contentless_delete=1" in normalized_sql_no_whitespace
async def connect(self) -> None:
"""Connect to the SQLite database."""
if self.engine is None:
@@ -210,18 +100,6 @@ class DocumentStorage:
async with self.async_session_maker() as session: # type: ignore
yield session
@property
def stopwords(self) -> set[str]:
if self._stopwords is None:
stopwords_path = (
Path(__file__).parents[3]
/ "knowledge_base"
/ "retrieval"
/ "hit_stopwords.txt"
)
self._stopwords = load_stopwords(stopwords_path)
return self._stopwords
async def get_documents(
self,
metadata_filters: dict,
@@ -294,8 +172,6 @@ class DocumentStorage:
)
session.add(document)
await session.flush() # Flush to get the ID
if document.id is not None:
await self._insert_fts_row(session, int(document.id), text)
return document.id # type: ignore
async def insert_documents_batch(
@@ -333,7 +209,6 @@ class DocumentStorage:
session.add(document)
await session.flush() # Flush to get all IDs
await self._insert_fts_rows_batch(session, documents, texts)
return [doc.id for doc in documents] # type: ignore
async def delete_document_by_doc_id(self, doc_id: str) -> None:
@@ -351,8 +226,6 @@ class DocumentStorage:
document = result.scalar_one_or_none()
if document:
if document.id is not None:
await self._delete_fts_row(session, int(document.id), document.text)
await session.delete(document)
async def get_document_by_doc_id(self, doc_id: str):
@@ -392,13 +265,9 @@ class DocumentStorage:
document = result.scalar_one_or_none()
if document:
if document.id is not None:
await self._delete_fts_row(session, int(document.id), document.text)
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
if document.id is not None:
await self._insert_fts_row(session, int(document.id), new_text)
async def delete_documents(self, metadata_filters: dict) -> None:
"""Delete documents by their metadata filters.
@@ -424,7 +293,6 @@ class DocumentStorage:
result = await session.execute(query)
documents = result.scalars().all()
await self._delete_fts_rows_batch(session, documents)
for doc in documents:
await session.delete(doc)
@@ -455,286 +323,6 @@ class DocumentStorage:
count = result.scalar_one_or_none()
return count if count is not None else 0
async def ensure_fts_index(self) -> bool:
"""Ensure the FTS5 sparse index exists and matches the documents table."""
if not self.fts5_available:
return False
if self._fts_index_ready:
return True
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
doc_count = await self._count_documents_in_session(session)
fts_count = await self._count_fts_rows(session)
if doc_count == fts_count:
self._fts_index_ready = True
return True
logger.info(
f"Rebuilding FTS5 sparse index for {self.db_path}: "
f"documents={doc_count}, fts_rows={fts_count}",
)
await self.rebuild_fts_index()
return self.fts5_available
async def rebuild_fts_index(self) -> None:
"""Rebuild the contentless FTS5 sparse index from documents."""
if not self.fts5_available:
return
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session, session.begin():
await session.execute(text(f"DROP TABLE IF EXISTS {FTS_TABLE_NAME}"))
await self._initialize_fts5(session)
if not self.fts5_available:
return
last_id = 0
while True:
query = (
select(Document)
.where(col(Document.id) > last_id)
.order_by(col(Document.id))
.limit(FTS_REBUILD_BATCH_SIZE)
)
result = await session.execute(query)
documents = result.scalars().all()
if not documents:
break
await self._insert_fts_rows_batch(
session,
documents,
[doc.text for doc in documents],
)
last_id = int(documents[-1].id or last_id)
self._fts_index_ready = True
async def search_sparse(
self,
query_tokens: list[str],
limit: int,
) -> list[dict] | None:
"""Search chunks using the FTS5 sparse index.
Returns None when FTS5 is unavailable so callers can fall back to another
sparse retrieval implementation.
"""
if limit <= 0:
return []
if not await self.ensure_fts_index():
return None
match_query = build_fts5_or_query(query_tokens)
if not match_query:
return []
async with self.get_session() as session:
try:
result = await session.execute(
text(
f"""
SELECT
d.id AS id,
d.doc_id AS doc_id,
d.text AS text,
d.metadata AS metadata,
d.created_at AS created_at,
d.updated_at AS updated_at,
bm25({FTS_TABLE_NAME}) AS score
FROM {FTS_TABLE_NAME}
JOIN documents d ON d.id = {FTS_TABLE_NAME}.rowid
WHERE {FTS_TABLE_NAME} MATCH :query
ORDER BY score ASC, d.id ASC
LIMIT :limit
""",
),
{"query": match_query, "limit": int(limit)},
)
except Exception as e:
logger.warning(
f"FTS5 sparse search failed for {self.db_path}; "
f"falling back to in-memory BM25: {e}",
)
self.fts5_available = False
return None
rows = result.mappings().all()
return [
{
"id": row["id"],
"doc_id": row["doc_id"],
"text": row["text"],
"metadata": row["metadata"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"score": float(row["score"]),
}
for row in rows
]
async def _count_documents_in_session(self, session: AsyncSession) -> int:
result = await session.execute(select(func.count(col(Document.id))))
count = result.scalar_one_or_none()
return int(count or 0)
async def _count_fts_rows(self, session: AsyncSession) -> int:
result = await session.execute(
text(f"SELECT count(*) FROM {FTS_TABLE_NAME}"),
)
count = result.scalar_one_or_none()
return int(count or 0)
async def _insert_fts_row(
self,
session: AsyncSession,
rowid: int,
content: str,
) -> None:
if not self.fts5_available:
return
search_text = to_fts5_search_text(content, self.stopwords)
await session.execute(
text(
f"""
INSERT INTO {FTS_TABLE_NAME}(rowid, search_text)
VALUES (:rowid, :search_text)
""",
),
{"rowid": rowid, "search_text": search_text},
)
async def _insert_fts_rows_batch(
self,
session: AsyncSession,
documents: list[Document],
contents: list[str],
) -> None:
if not self.fts5_available:
return
fts_params = [
{
"rowid": int(doc.id),
"search_text": to_fts5_search_text(content, self.stopwords),
}
for doc, content in zip(documents, contents)
if doc.id is not None
]
if not fts_params:
return
await session.execute(
text(
f"""
INSERT INTO {FTS_TABLE_NAME}(rowid, search_text)
VALUES (:rowid, :search_text)
""",
),
fts_params,
)
async def _delete_fts_row(
self,
session: AsyncSession,
rowid: int,
content: str,
) -> None:
if not self.fts5_available:
return
if self._fts_contentless_delete:
await session.execute(
text(f"DELETE FROM {FTS_TABLE_NAME} WHERE rowid = :rowid"),
{"rowid": rowid},
)
return
if not await self._fts_row_exists(session, rowid):
return
search_text = to_fts5_search_text(content, self.stopwords)
await session.execute(
text(
f"""
INSERT INTO {FTS_TABLE_NAME}({FTS_TABLE_NAME}, rowid, search_text)
VALUES ('delete', :rowid, :search_text)
""",
),
{"rowid": rowid, "search_text": search_text},
)
async def _delete_fts_rows_batch(
self,
session: AsyncSession,
documents: list[Document],
) -> None:
if not self.fts5_available:
return
docs_with_ids = [doc for doc in documents if doc.id is not None]
if not docs_with_ids:
return
if self._fts_contentless_delete:
await session.execute(
text(f"DELETE FROM {FTS_TABLE_NAME} WHERE rowid = :rowid"),
[{"rowid": int(doc.id)} for doc in docs_with_ids if doc.id is not None],
)
return
existing_rowids = await self._existing_fts_rowids(
session,
[int(doc.id) for doc in docs_with_ids if doc.id is not None],
)
fts_params = [
{
"rowid": int(doc.id),
"search_text": to_fts5_search_text(doc.text, self.stopwords),
}
for doc in docs_with_ids
if doc.id is not None and int(doc.id) in existing_rowids
]
if not fts_params:
return
await session.execute(
text(
f"""
INSERT INTO {FTS_TABLE_NAME}({FTS_TABLE_NAME}, rowid, search_text)
VALUES ('delete', :rowid, :search_text)
""",
),
fts_params,
)
async def _fts_row_exists(self, session: AsyncSession, rowid: int) -> bool:
result = await session.execute(
text(f"SELECT 1 FROM {FTS_TABLE_NAME} WHERE rowid = :rowid LIMIT 1"),
{"rowid": rowid},
)
return result.scalar_one_or_none() is not None
async def _existing_fts_rowids(
self,
session: AsyncSession,
rowids: list[int],
) -> set[int]:
if not rowids:
return set()
result = await session.execute(
text(
f"SELECT rowid FROM {FTS_TABLE_NAME} WHERE rowid IN :rowids"
).bindparams(bindparam("rowids", expanding=True)),
{"rowids": rowids},
)
return {int(row[0]) for row in result.fetchall()}
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.

View File

@@ -4,7 +4,6 @@ import uuid
import numpy as np
from astrbot import logger
from astrbot.core.exceptions import KnowledgeBaseUploadError
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from ..base import BaseVecDB, Result
@@ -75,38 +74,6 @@ class FaissVecDB(BaseVecDB):
metadatas = metadatas or [{} for _ in contents]
ids = ids or [str(uuid.uuid4()) for _ in contents]
if not contents:
logger.debug(
"No contents provided for batch insert; skipping embedding generation."
)
return []
content_count = len(contents)
if len(metadatas) != content_count:
raise KnowledgeBaseUploadError(
stage="storage",
user_message=(
f"存储失败:文本分块数量与元数据数量不一致(期望 {content_count}"
f"实际 {len(metadatas)})。"
),
details={
"expected_contents": content_count,
"actual_metadatas": len(metadatas),
},
)
if len(ids) != content_count:
raise KnowledgeBaseUploadError(
stage="storage",
user_message=(
f"存储失败:文本分块数量与文档 ID 数量不一致(期望 {content_count}"
f"实际 {len(ids)})。"
),
details={
"expected_contents": content_count,
"actual_ids": len(ids),
},
)
start = time.time()
logger.debug(f"Generating embeddings for {len(contents)} contents...")
vectors = await self.embedding_provider.get_embeddings_batch(
@@ -120,20 +87,6 @@ class FaissVecDB(BaseVecDB):
logger.debug(
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds.",
)
if len(vectors) != content_count:
raise KnowledgeBaseUploadError(
stage="embedding",
user_message=(
"向量化失败:嵌入模型返回的向量数量与文本分块数量不一致"
f"(期望 {content_count},实际 {len(vectors)})。"
"这通常说明当前 Embedding 接口未完整返回批量结果,"
"或该服务不兼容当前批量请求格式。"
),
details={
"expected_contents": content_count,
"actual_vectors": len(vectors),
},
)
# 使用 DocumentStorage 的批量插入方法
int_ids = await self.document_storage.insert_documents_batch(
@@ -141,52 +94,9 @@ class FaissVecDB(BaseVecDB):
contents,
metadatas,
)
if len(int_ids) != content_count:
raise KnowledgeBaseUploadError(
stage="storage",
user_message=(
f"存储失败:写入文档索引后返回的内部 ID 数量与文本分块数量不一致"
f"(期望 {content_count},实际 {len(int_ids)})。"
),
details={
"expected_contents": content_count,
"actual_int_ids": len(int_ids),
},
)
# 批量插入向量到 FAISS
try:
vectors_array = np.asarray(vectors, dtype=np.float32)
except (TypeError, ValueError) as exc:
raise KnowledgeBaseUploadError(
stage="embedding",
user_message=(
"向量化失败:嵌入模型返回的向量格式不正确,"
"无法转换为统一的浮点向量矩阵。"
),
details={"vector_count": len(vectors)},
) from exc
if vectors_array.ndim != 2:
raise KnowledgeBaseUploadError(
stage="embedding",
user_message=(
"向量化失败:嵌入模型返回的向量格式不正确,无法构造成二维向量矩阵。"
),
details={"actual_ndim": int(vectors_array.ndim)},
)
if vectors_array.shape[1] != self.embedding_storage.dimension:
raise KnowledgeBaseUploadError(
stage="embedding",
user_message=(
"向量化失败:返回向量维度与当前知识库索引维度不一致"
f"(期望 {self.embedding_storage.dimension}"
f"实际 {vectors_array.shape[1]})。"
),
details={
"expected_dimension": self.embedding_storage.dimension,
"actual_dimension": int(vectors_array.shape[1]),
},
)
vectors_array = np.array(vectors).astype("float32")
await self.embedding_storage.insert_batch(vectors_array, int_ids)
return int_ids

View File

@@ -7,26 +7,3 @@ class AstrBotError(Exception):
class ProviderNotFoundError(AstrBotError):
"""Raised when a specified provider is not found."""
class EmptyModelOutputError(AstrBotError):
"""Raised when the model response contains no usable assistant output."""
class KnowledgeBaseUploadError(AstrBotError):
"""Raised when knowledge base upload fails with a user-facing message."""
def __init__(
self,
*,
stage: str,
user_message: str,
details: dict | None = None,
) -> None:
super().__init__(user_message)
self.stage = stage
self.user_message = user_message
self.details = details or {}
def __str__(self) -> str:
return self.user_message

View File

@@ -1,12 +1,12 @@
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from sqlalchemy import delete, func, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import col, desc
from astrbot.core import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
@@ -15,9 +15,6 @@ from astrbot.core.knowledge_base.models import (
)
from astrbot.core.utils.astrbot_path import get_astrbot_knowledge_base_path
if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
class KBSQLiteDatabase:
def __init__(self, db_path: str | None = None) -> None:
@@ -299,7 +296,7 @@ class KBSQLiteDatabase:
return metadata_map
async def delete_document_by_id(self, doc_id: str, vec_db: "FaissVecDB") -> None:
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None:
"""删除单个文档及其相关数据"""
# 在知识库表中删除
async with self.get_db() as session, session.begin():
@@ -327,7 +324,7 @@ class KBSQLiteDatabase:
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def update_kb_stats(self, kb_id: str, vec_db: "FaissVecDB") -> None:
async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
"""更新知识库统计信息"""
chunk_cnt = await vec_db.count_documents()

View File

@@ -4,13 +4,12 @@ import re
import time
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
import aiofiles
from astrbot.core import logger
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.exceptions import KnowledgeBaseUploadError
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import (
EmbeddingProvider,
@@ -28,9 +27,6 @@ from .parsers.url_parser import extract_text_from_url
from .parsers.util import select_parser
from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
class RateLimiter:
"""一个简单的速率限制器"""
@@ -109,14 +105,9 @@ Text chunk to process:
return [chunk]
def _compact_chunks(chunks: list[str]) -> list[str]:
return [chunk.strip() for chunk in chunks if chunk and chunk.strip()]
class KBHelper:
vec_db: BaseVecDB
kb: KnowledgeBase
init_error: str | None
def __init__(
self,
@@ -131,7 +122,6 @@ class KBHelper:
self.prov_mgr = provider_manager
self.kb_root_dir = kb_root_dir
self.chunker = chunker
self.init_error = None
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
@@ -158,30 +148,21 @@ class KBHelper:
async def get_rp(self) -> RerankProvider | None:
if not self.kb.rerank_provider_id:
return None
rp: RerankProvider | None = await self.prov_mgr.get_provider_by_id(
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
self.kb.rerank_provider_id,
) # type: ignore
if not rp:
logger.warning(
f"知识库 {self.kb.kb_name}({self.kb.kb_id}) 的 Rerank Provider({self.kb.rerank_provider_id}) 不可用,将跳过重排序。",
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider",
)
return None
return rp
async def _ensure_vec_db(self) -> "FaissVecDB":
async def _ensure_vec_db(self) -> FaissVecDB:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep = await self.get_ep()
rp: RerankProvider | None = None
try:
rp = await self.get_rp()
except Exception as e:
logger.warning(
f"知识库 {self.kb.kb_name}({self.kb.kb_id}) 初始化重排序能力失败,将跳过重排序: {e}",
)
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
rp = await self.get_rp()
vec_db = FaissVecDB(
doc_store_path=str(self.kb_dir / "doc.db"),
@@ -191,8 +172,6 @@ class KBHelper:
)
await vec_db.initialize()
self.vec_db = vec_db
# Clear stale init_error once initialization succeeds.
self.init_error = None
return vec_db
async def delete_vec_db(self) -> None:
@@ -204,7 +183,7 @@ class KBHelper:
shutil.rmtree(self.kb_dir)
async def terminate(self) -> None:
if hasattr(self, "vec_db") and self.vec_db:
if self.vec_db:
await self.vec_db.close()
async def upload_document(
@@ -253,7 +232,7 @@ class KBHelper:
if pre_chunked_text is not None:
# 如果提供了预分块文本,直接使用
chunks_text = _compact_chunks(pre_chunked_text)
chunks_text = pre_chunked_text
file_size = sum(len(chunk) for chunk in chunks_text)
logger.info(f"使用预分块文本进行上传,共 {len(chunks_text)} 个块。")
else:
@@ -269,31 +248,10 @@ class KBHelper:
if progress_callback:
await progress_callback("parsing", 0, 100)
try:
parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
except KnowledgeBaseUploadError:
raise
except Exception as exc:
raise KnowledgeBaseUploadError(
stage="parsing",
user_message=(
"文档解析失败:无法读取或解析上传文件。"
"请确认文件格式受支持且文件内容未损坏。"
),
details={"file_name": file_name},
) from exc
parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
if not text_content or not text_content.strip():
raise KnowledgeBaseUploadError(
stage="parsing",
user_message=(
"文档解析失败:未能从文件中提取可索引文本。"
"该文件可能是扫描件、纯图片 PDF或格式暂不受支持。"
),
details={"file_name": file_name},
)
if progress_callback:
await progress_callback("parsing", 100, 100)
@@ -314,41 +272,11 @@ class KBHelper:
if progress_callback:
await progress_callback("chunking", 0, 100)
try:
chunks_text = await self.chunker.chunk(
text_content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
chunks_text = _compact_chunks(chunks_text)
except KnowledgeBaseUploadError:
raise
except Exception as exc:
raise KnowledgeBaseUploadError(
stage="chunking",
user_message=(
"分块失败:文档内容在切分文本块时发生错误。"
"请稍后重试,或调整分块参数后再次上传。"
),
details={"file_name": file_name},
) from exc
if not chunks_text or not any(chunk.strip() for chunk in chunks_text):
if pre_chunked_text is not None:
raise KnowledgeBaseUploadError(
stage="validation",
user_message=("预分块文本为空,未提供任何可索引文本块。"),
details={"file_name": file_name},
)
else:
raise KnowledgeBaseUploadError(
stage="chunking",
user_message=(
"分块失败:文档内容为空,未生成任何可索引文本块。"
),
details={"file_name": file_name},
)
chunks_text = await self.chunker.chunk(
text_content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
contents = []
metadatas = []
for idx, chunk_text in enumerate(chunks_text):
@@ -369,23 +297,14 @@ class KBHelper:
if progress_callback:
await progress_callback("embedding", current, total)
try:
await self.vec_db.insert_batch(
contents=contents,
metadatas=metadatas,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=embedding_progress_callback,
)
except KnowledgeBaseUploadError:
raise
except Exception as exc:
raise KnowledgeBaseUploadError(
stage="storage",
user_message=("存储失败:文本块已生成,但写入知识库索引时出错。"),
details={"file_name": file_name},
) from exc
await self.vec_db.insert_batch(
contents=contents,
metadatas=metadatas,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=embedding_progress_callback,
)
# 保存文档的元数据
doc = KBDocument(
@@ -399,47 +318,22 @@ class KBHelper:
chunk_count=len(chunks_text),
media_count=0,
)
try:
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
for media in saved_media:
session.add(media)
await session.commit()
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
for media in saved_media:
session.add(media)
await session.commit()
await session.refresh(doc)
except KnowledgeBaseUploadError:
raise
except Exception as exc:
raise KnowledgeBaseUploadError(
stage="metadata",
user_message=(
"元数据保存失败:文本块已写入知识库,但文档记录保存失败。"
),
details={"file_name": file_name, "doc_id": doc_id},
) from exc
await session.refresh(doc)
vec_db: FaissVecDB = self.vec_db # type: ignore
try:
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
await self.refresh_kb()
await self.refresh_document(doc_id)
except KnowledgeBaseUploadError:
raise
except Exception as exc:
raise KnowledgeBaseUploadError(
stage="metadata",
user_message=(
"元数据更新失败:文档已上传,但知识库统计信息刷新失败。"
),
details={"file_name": file_name, "doc_id": doc_id},
) from exc
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
await self.refresh_kb()
await self.refresh_document(doc_id)
return doc
except Exception as e:
if isinstance(e, KnowledgeBaseUploadError):
logger.warning(f"上传文档失败: {e}", extra={"details": e.details})
else:
logger.error(f"上传文档失败: {e}", exc_info=True)
logger.error(f"上传文档失败: {e}")
# if file_path.exists():
# file_path.unlink()
@@ -450,7 +344,7 @@ class KBHelper:
except Exception as me:
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
raise
raise e
async def list_documents(
self,
@@ -733,8 +627,6 @@ class KBHelper:
elif isinstance(result, list):
final_chunks.extend(result)
final_chunks = _compact_chunks(final_chunks)
logger.info(
f"文本修复完成: {len(initial_chunks)} 个原始块 -> {len(final_chunks)} 个最终块。"
)

View File

@@ -1,3 +1,4 @@
import traceback
from pathlib import Path
from astrbot.core import logger
@@ -55,7 +56,8 @@ class KnowledgeBaseManager:
logger.error(f"知识库模块导入失败: {e}")
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
except Exception as e:
logger.error(f"知识库模块初始化失败: {e}", exc_info=True)
logger.error(f"知识库模块初始化失败: {e}")
logger.error(traceback.format_exc())
async def _init_kb_database(self) -> None:
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
@@ -74,14 +76,7 @@ class KnowledgeBaseManager:
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
)
try:
await kb_helper.initialize()
except Exception as e:
kb_helper.init_error = str(e)
logger.error(
f"知识库 {record.kb_name}({record.kb_id}) 初始化失败: {e}",
exc_info=True,
)
await kb_helper.initialize()
self.kb_insts[record.kb_id] = kb_helper
async def create_kb(
@@ -184,20 +179,6 @@ class KnowledgeBaseManager:
return None
kb = kb_helper.kb
previous_state = {
"kb_name": kb.kb_name,
"description": kb.description,
"emoji": kb.emoji,
"embedding_provider_id": kb.embedding_provider_id,
"rerank_provider_id": kb.rerank_provider_id,
"chunk_size": kb.chunk_size,
"chunk_overlap": kb.chunk_overlap,
"top_k_dense": kb.top_k_dense,
"top_k_sparse": kb.top_k_sparse,
"top_m_final": kb.top_m_final,
}
previous_init_error = kb_helper.init_error
if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
@@ -217,47 +198,12 @@ class KnowledgeBaseManager:
kb.top_k_sparse = top_k_sparse
if top_m_final is not None:
kb.top_m_final = top_m_final
# Build a new helper first. Keep current vec_db alive until new init succeeds.
new_helper = KBHelper(
kb_db=self.kb_db,
kb=kb,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
)
try:
await new_helper.initialize()
except Exception as e:
# Roll back in-memory settings and keep current helper available.
kb.kb_name = previous_state["kb_name"]
kb.description = previous_state["description"]
kb.emoji = previous_state["emoji"]
kb.embedding_provider_id = previous_state["embedding_provider_id"]
kb.rerank_provider_id = previous_state["rerank_provider_id"]
kb.chunk_size = previous_state["chunk_size"]
kb.chunk_overlap = previous_state["chunk_overlap"]
kb.top_k_dense = previous_state["top_k_dense"]
kb.top_k_sparse = previous_state["top_k_sparse"]
kb.top_m_final = previous_state["top_m_final"]
kb_helper.init_error = previous_init_error
logger.error(
f"知识库 {kb.kb_name}({kb.kb_id}) 重新初始化失败,继续使用旧实例: {e}",
exc_info=True,
)
return kb_helper
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
old_helper = kb_helper
self.kb_insts[kb_id] = new_helper
await old_helper.terminate()
new_helper.init_error = None
return new_helper
return kb_helper
async def retrieve(
self,
@@ -269,21 +215,11 @@ class KnowledgeBaseManager:
"""从指定知识库中检索相关内容"""
kb_ids = []
kb_id_helper_map = {}
unavailable_kbs = []
for kb_name in kb_names:
if kb_helper := await self.get_kb_by_name(kb_name):
if kb_helper.init_error:
unavailable_kbs.append((kb_name, kb_helper.init_error))
logger.warning(f"知识库 {kb_name} 不可用: {kb_helper.init_error}")
continue
kb_ids.append(kb_helper.kb.kb_id)
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
# all requested KBs are unavailable
if not kb_ids and unavailable_kbs:
errors = "; ".join(f"{n}: {e}" for n, e in unavailable_kbs)
raise ValueError(f"所有请求的知识库均不可用: {errors}")
if not kb_ids:
return {}

View File

@@ -1,13 +1,11 @@
"""文档解析器模块"""
from .base import BaseParser, MediaItem, ParseResult
from .epub_parser import EpubParser
from .pdf_parser import PDFParser
from .text_parser import TextParser
__all__ = [
"BaseParser",
"EpubParser",
"MediaItem",
"PDFParser",
"ParseResult",

View File

@@ -1,162 +0,0 @@
"""EPUB document parser."""
import html
import re
from astrbot.core.knowledge_base.parsers.base import BaseParser, ParseResult
_KEYS = (
"Title|Author|Creator|Language|Publisher|Date|Modified|Identifier|ISBN|Description|"
"Subject|Rights|Source|Series|标题|书名|作者|语言|出版社|日期|出版日期|标识符|简介|描述|"
"主题|版权|来源|系列|タイトル|書名|著者|言語|出版社|日付|識別子|説明|件名|権利|ソース|シリーズ"
)
_META_RE = re.compile(rf"^\s*(?:[-*]\s*)?\*\*(?:{_KEYS})\s*[:]\*\*\s+\S")
_TOC_HEAD_RE = re.compile(
r"^\s{0,3}(?:#{1,6}\s*)?(?:table of contents|contents|toc|目录|目次|もくじ)\s*$",
re.I,
)
_LINK_RE = re.compile(r"(?<!!)\[([^\]]+)\]\(([^)]+)\)")
_IMG_RE = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")
_EMPTY_IMG_LINK_RE = re.compile(
r"\[\s*\]\([^)]+\.(?:png|jpe?g|gif|webp|svg)(?:#[^)]+)?\)", re.I
)
_FOOTNOTE_LABEL_RE = re.compile(
r"^(?:\d{1,3}|[ivxlcdm]{1,8}|[*†‡§¶]|↩|↑|back|return|返回|回到正文)$", re.I
)
_FOOTNOTE_HREF_RE = re.compile(
r"(?:^#|[#/_-](?:fn|footnote|note|noteref|backlink|return|filepos)\b)", re.I
)
_DOTTED_TOC_RE = re.compile(r"^\s*.+?\.{2,}\s*(?:\d+|[ivxlcdm]+)\s*$", re.I)
_SEP_RE = re.compile(r"^\s*(?:[-=*_]){3,}\s*$")
_NOISE_RE = re.compile(
r"^\s*(?:\[\s*)?(?:\d{1,3}|[ivxlcdm]{1,8}|[*†‡§¶]|↩|↑)(?:\s*\])?\s*$", re.I
)
_GENERIC_ALT_RE = re.compile(
r"^(?:image|img|picture|photo|illustration|figure|fig|cover|插图|图片|图像|封面)\s*[\d._-]*$",
re.I,
)
_FILENAME_ALT_RE = re.compile(r"^[\w.\- ]+\.(?:png|jpe?g|gif|webp|svg)$", re.I)
def _n(s: str) -> str:
return (
html.unescape(s)
.replace("\r\n", "\n")
.replace("\r", "\n")
.replace("\ufeff", "")
.replace("\u00a0", " ")
.replace("\u200b", "")
)
def _is_internal(href: str) -> bool:
href = html.unescape(href).strip().lower()
return (
href.startswith("#")
or href.endswith(".html")
or href.endswith(".xhtml")
or ".html#" in href
or ".xhtml#" in href
)
def _is_toc_line(s: str) -> bool:
s = s.strip()
if not s:
return False
s = re.sub(r"^\s*(?:[-*+]|\d+\.)\s+", "", s)
m = re.fullmatch(r"\[([^\]]+)\]\(([^)]+)\)", s)
return bool((m and _is_internal(m.group(2))) or _DOTTED_TOC_RE.match(s))
def _strip_head(text: str) -> str:
lines = _n(text).split("\n")
i = 0
while i < len(lines) and not lines[i].strip():
i += 1
start = i
while i < len(lines) and _META_RE.match(lines[i].strip()):
i += 1
if i - start >= 2:
while i < len(lines) and not lines[i].strip():
i += 1
else:
i = start
toc0, had_head = i, False
if i < len(lines) and _TOC_HEAD_RE.match(lines[i].strip()):
had_head = True
i += 1
while i < len(lines) and not lines[i].strip():
i += 1
toc = 0
while i < len(lines) and i - toc0 < 120:
s = lines[i].strip()
if not s:
if toc and i + 1 < len(lines) and _is_toc_line(lines[i + 1]):
i += 1
continue
break
if not _is_toc_line(s):
break
toc += 1
i += 1
if toc >= 2 and (had_head or toc >= 3):
while i < len(lines) and not lines[i].strip():
i += 1
return "\n".join(lines[i:]).strip()
return "\n".join(lines[toc0:]).strip()
def _strip_links(text: str) -> str:
def repl(m: re.Match[str]) -> str:
label = html.unescape(m.group(1)).strip()
href = html.unescape(m.group(2)).strip().lower()
if not _is_internal(href):
return m.group(0)
if _FOOTNOTE_HREF_RE.search(href) or (
href.startswith("#") and _FOOTNOTE_LABEL_RE.fullmatch(label)
):
return ""
return label
return _LINK_RE.sub(repl, _n(text))
def _img_alt(m: re.Match[str]) -> str:
alt = re.sub(r"\s+", " ", html.unescape(m.group(1)).strip())
if not alt or _GENERIC_ALT_RE.fullmatch(alt) or _FILENAME_ALT_RE.fullmatch(alt):
return ""
return alt
def _sanitize(text: str) -> str:
out, prev_blank, prev = [], True, ""
for raw in _n(text).split("\n"):
line = _IMG_RE.sub(_img_alt, raw)
line = _EMPTY_IMG_LINK_RE.sub("", line).rstrip()
s = line.strip()
if not s:
if not prev_blank:
out.append("")
prev_blank = True
continue
if _SEP_RE.match(s) or _NOISE_RE.match(s):
continue
norm = re.sub(r"^\s{0,3}#{1,6}\s*", "", s).strip("*_ ").casefold()
if norm and norm == prev and len(norm) <= 120:
continue
out.append(line)
prev_blank = False
prev = norm
return "\n".join(out).strip()
class EpubParser(BaseParser):
"""Parse EPUB files via MarkItDown."""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
from .markitdown_parser import MarkitdownParser
result = await MarkitdownParser().parse(file_content, file_name)
text = _sanitize(_strip_links(_strip_head(result.text)))
return ParseResult(text=text, media=result.media)

View File

@@ -6,10 +6,6 @@ async def select_parser(ext: str) -> BaseParser:
from .markitdown_parser import MarkitdownParser
return MarkitdownParser()
if ext == ".epub":
from .epub_parser import EpubParser
return EpubParser()
if ext == ".pdf":
from .pdf_parser import PDFParser

View File

@@ -1,11 +1,8 @@
"""检索模块"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .manager import RetrievalManager, RetrievalResult
from .rank_fusion import FusedResult, RankFusion
from .sparse_retriever import SparseResult, SparseRetriever
from .manager import RetrievalManager, RetrievalResult
from .rank_fusion import FusedResult, RankFusion
from .sparse_retriever import SparseResult, SparseRetriever
__all__ = [
"FusedResult",
@@ -15,31 +12,3 @@ __all__ = [
"SparseResult",
"SparseRetriever",
]
def __getattr__(name: str):
if name in {"RetrievalManager", "RetrievalResult"}:
from .manager import RetrievalManager, RetrievalResult
return {
"RetrievalManager": RetrievalManager,
"RetrievalResult": RetrievalResult,
}[name]
if name in {"FusedResult", "RankFusion"}:
from .rank_fusion import FusedResult, RankFusion
return {
"FusedResult": FusedResult,
"RankFusion": RankFusion,
}[name]
if name in {"SparseResult", "SparseRetriever"}:
from .sparse_retriever import SparseResult, SparseRetriever
return {
"SparseResult": SparseResult,
"SparseRetriever": SparseRetriever,
}[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -5,10 +5,10 @@
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING
from astrbot import logger
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
@@ -16,9 +16,6 @@ from astrbot.core.provider.provider import RerankProvider
from ..kb_helper import KBHelper
if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
@dataclass
class RetrievalResult:
@@ -173,31 +170,26 @@ class RetrievalManager:
first_rerank = None
for kb_id in kb_ids:
vec_db = kb_options[kb_id]["vec_db"]
rerank_provider = (
getattr(vec_db, "rerank_provider", None) if vec_db else None
)
if rerank_provider is None:
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
continue
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
and rerank_provider
and vec_db.rerank_provider
and rerank_pi
and rerank_pi == rerank_provider.meta().id
and rerank_pi == vec_db.rerank_provider.meta().id
):
first_rerank = rerank_provider
first_rerank = vec_db.rerank_provider
break
if first_rerank and retrieval_results:
try:
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=first_rerank,
)
except Exception as e:
logger.warning(f"Rerank 执行失败,已跳过重排序并使用融合结果: {e}")
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=first_rerank,
)
return retrieval_results[:top_m_final]
@@ -237,10 +229,10 @@ class RetrievalManager:
all_results.extend(vec_results)
except Exception as e:
logger.error(f"知识库 {kb_id} 稠密检索失败: {e}", exc_info=True)
if len(kb_ids) == 1:
raise RuntimeError(f"知识库 {kb_id} 稠密检索失败: {e}") from e
# multi-KB: skip the faulty KB and continue
from astrbot.core import logger
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
continue
# 按相似度排序并返回 top_k
all_results.sort(key=lambda x: x.similarity, reverse=True)

View File

@@ -6,18 +6,12 @@
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import jieba
from rank_bm25 import BM25Okapi
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.tokenizer import (
load_stopwords,
tokenize_text,
)
if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
@dataclass
@@ -50,9 +44,13 @@ class SparseRetriever:
self.kb_db = kb_db
self._index_cache = {} # 缓存 BM25 索引
self.hit_stopwords = load_stopwords(
with open(
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
)
encoding="utf-8",
) as f:
self.hit_stopwords = {
word.strip() for word in set(f.read().splitlines()) if word.strip()
}
async def retrieve(
self,
@@ -71,56 +69,11 @@ class SparseRetriever:
List[SparseResult]: 检索结果列表
"""
fts_results = []
fallback_kb_ids = []
query_tokens = tokenize_text(query, self.hit_stopwords)
for kb_id in kb_ids:
vec_db: FaissVecDB | None = kb_options.get(kb_id, {}).get("vec_db")
if not vec_db:
continue
top_k_sparse = kb_options.get(kb_id, {}).get("top_k_sparse", 50)
result = await vec_db.document_storage.search_sparse(
query_tokens=query_tokens,
limit=top_k_sparse,
)
if result is None:
fallback_kb_ids.append(kb_id)
continue
for doc in result:
chunk_md = json.loads(doc["metadata"])
fts_results.append(
SparseResult(
chunk_id=doc["doc_id"],
chunk_index=chunk_md["chunk_index"],
doc_id=chunk_md["kb_doc_id"],
kb_id=kb_id,
content=doc["text"],
score=-float(doc["score"]),
),
)
fallback_results = []
if fallback_kb_ids:
fallback_results = await self._retrieve_with_bm25(
query=query,
kb_ids=fallback_kb_ids,
kb_options=kb_options,
)
results = fts_results + fallback_results
results.sort(key=lambda x: x.score, reverse=True)
return results
async def _retrieve_with_bm25(
self,
query: str,
kb_ids: list[str],
kb_options: dict,
) -> list[SparseResult]:
# 1. 获取所有相关块
top_k_sparse = 0
chunks = []
for kb_id in kb_ids:
vec_db: FaissVecDB | None = kb_options.get(kb_id, {}).get("vec_db")
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
if not vec_db:
continue
result = await vec_db.document_storage.get_documents(
@@ -147,13 +100,20 @@ class SparseRetriever:
# 2. 准备文档和索引
corpus = [chunk["text"] for chunk in chunks]
tokenized_corpus = [tokenize_text(doc, self.hit_stopwords) for doc in corpus]
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
tokenized_corpus = [
[word for word in doc if word not in self.hit_stopwords]
for doc in tokenized_corpus
]
# 3. 构建 BM25 索引
bm25 = BM25Okapi(tokenized_corpus)
# 4. 执行检索
tokenized_query = tokenize_text(query, self.hit_stopwords)
tokenized_query = list(jieba.cut(query))
tokenized_query = [
word for word in tokenized_query if word not in self.hit_stopwords
]
scores = bm25.get_scores(tokenized_query)
# 5. 排序并返回 Top-K

View File

@@ -1,39 +0,0 @@
"""Tokenization helpers shared by sparse retrieval indexes."""
import re
from pathlib import Path
from re import Pattern
import jieba
_TERM_PATTERN: Pattern[str] = re.compile(r"\w", re.UNICODE)
def load_stopwords(path: Path | str) -> set[str]:
with Path(path).open(encoding="utf-8") as f:
return {word.strip() for word in set(f.read().splitlines()) if word.strip()}
def tokenize_text(text: str, stopwords: set[str]) -> list[str]:
tokens = []
for token in jieba.cut(text or ""):
token = token.strip()
if not token or token in stopwords:
continue
if not _TERM_PATTERN.search(token):
continue
tokens.append(token)
return tokens
def to_fts5_search_text(text: str, stopwords: set[str]) -> str:
return " ".join(tokenize_text(text, stopwords))
def quote_fts5_token(token: str) -> str:
return '"' + token.replace('"', '""') + '"'
def build_fts5_or_query(tokens: list[str]) -> str:
quoted_tokens = [quote_fts5_token(token) for token in tokens if token]
return " OR ".join(quoted_tokens)

View File

@@ -64,6 +64,7 @@ class ComponentType(str, Enum):
Music = "Music"
Json = "Json"
Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
class BaseMessageComponent(BaseModel):
@@ -90,14 +91,15 @@ class BaseMessageComponent(BaseModel):
class Plain(BaseMessageComponent):
type: ComponentType = ComponentType.Plain
text: str
convert: bool | None = True
def __init__(self, text: str, convert: bool = True, **_) -> None:
super().__init__(text=text, convert=convert, **_)
def toDict(self) -> dict:
return {"type": "text", "data": {"text": self.text}}
def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}}
async def to_dict(self) -> dict:
async def to_dict(self):
return {"type": "text", "data": {"text": self.text}}
@@ -112,11 +114,15 @@ class Face(BaseMessageComponent):
class Record(BaseMessageComponent):
type: ComponentType = ComponentType.Record
file: str | None = ""
magic: bool | None = False
url: str | None = ""
cache: bool | None = True
proxy: bool | None = True
timeout: int | None = 0
# Original text content (e.g. TTS source text), used as caption in fallback scenarios
text: str | None = None
# 额外
path: str | None = None
path: str | None
def __init__(self, file: str | None, **_) -> None:
for k in _:
@@ -218,6 +224,7 @@ class Video(BaseMessageComponent):
type: ComponentType = ComponentType.Video
file: str
cover: str | None = ""
c: int | None = 2
# 额外
path: str | None = ""
@@ -394,9 +401,14 @@ class Image(BaseMessageComponent):
type: ComponentType = ComponentType.Image
file: str | None = ""
_type: str | None = ""
subType: int | None = 0
url: str | None = ""
cache: bool | None = True
id: int | None = 40000
c: int | None = 2
# 额外
path: str | None = ""
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: str | None, **_) -> None:
super().__init__(file=file, **_)
@@ -827,6 +839,16 @@ class File(BaseMessageComponent):
}
class WechatEmoji(BaseMessageComponent):
type: ComponentType = ComponentType.WechatEmoji
md5: str | None = ""
md5_len: int | None = 0
cdnurl: str | None = ""
def __init__(self, **_) -> None:
super().__init__(**_)
ComponentTypes = {
# Basic Message Segments
"plain": Plain,
@@ -852,4 +874,5 @@ ComponentTypes = {
"nodes": Nodes,
"json": Json,
"unknown": Unknown,
"WechatEmoji": WechatEmoji,
}

View File

@@ -44,22 +44,6 @@ class PersonaManager:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
return persona
def get_persona_v3_by_id(self, persona_id: str | None) -> Personality | None:
"""Resolve a v3 persona object by id.
- None/empty id returns None.
- "default" maps to in-memory DEFAULT_PERSONALITY.
- Otherwise search in personas_v3 by persona name.
"""
if not persona_id:
return None
if persona_id == "default":
return DEFAULT_PERSONALITY
return next(
(persona for persona in self.personas_v3 if persona["name"] == persona_id),
None,
)
async def get_default_persona_v3(
self,
umo: str | MessageSession | None = None,
@@ -70,7 +54,12 @@ class PersonaManager:
"default_personality",
"default",
)
return self.get_persona_v3_by_id(default_persona_id) or DEFAULT_PERSONALITY
if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY
try:
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
except Exception:
return DEFAULT_PERSONALITY
async def resolve_selected_persona(
self,

View File

@@ -6,7 +6,6 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.message.components import Image, Plain, Record
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.media_utils import ensure_wav
from ..context import PipelineContext
from ..stage import Stage, register_stage
@@ -65,21 +64,6 @@ class PreProcessStage(Stage):
logger.debug(f"路径映射: {url} -> {component.url}")
message_chain[idx] = component
# In here, we convert all Record components to wav format and update the file path.
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record):
try:
original_path = await component.convert_to_file_path()
record_path = await ensure_wav(original_path)
if record_path != original_path:
event.track_temporary_local_file(record_path)
component.file = record_path
component.path = record_path
message_chain[idx] = component
except Exception as e:
logger.warning(f"Voice processing failed: {e}")
# STT
if self.stt_settings.get("enable", False):
# TODO: 独立
@@ -92,8 +76,8 @@ class PreProcessStage(Stage):
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record):
path = await component.convert_to_file_path()
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:

View File

@@ -172,9 +172,6 @@ def try_capture_follow_up(event: AstrMessageEvent) -> FollowUpCapture | None:
if not active_sender_id or active_sender_id != sender_id:
return None
if runner_event.get_extra("agent_stop_requested"):
return None
ticket = runner.follow_up(message_text=_event_follow_up_text(event))
if not ticket:
return None

View File

@@ -5,20 +5,15 @@ import base64
from collections.abc import AsyncGenerator
from dataclasses import replace
from astrbot.core import db_helper, logger
from astrbot.core.agent.message import (
CheckpointData,
CheckpointMessageSegment,
Message,
dump_messages_with_checkpoints,
)
from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.response import AgentStats
from astrbot.core.astr_main_agent import (
MainAgentBuildConfig,
MainAgentBuildResult,
build_main_agent,
)
from astrbot.core.message.components import File, Image, Record, Video
from astrbot.core.message.components import File, Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
@@ -71,10 +66,6 @@ class InternalAgentSubStage(Stage):
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_tool_call_result: bool = settings.get("show_tool_call_result", False)
self.buffer_intermediate_messages: bool = settings.get(
"buffer_intermediate_messages",
False,
)
self.show_reasoning = settings.get("display_reasoning_text", False)
self.sanitize_context_by_modalities: bool = settings.get(
"sanitize_context_by_modalities",
@@ -153,7 +144,6 @@ class InternalAgentSubStage(Stage):
follow_up_capture: FollowUpCapture | None = None
follow_up_consumed_marked = False
follow_up_activated = False
typing_requested = False
try:
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
@@ -162,8 +152,7 @@ class InternalAgentSubStage(Stage):
has_provider_request = event.get_extra("provider_request") is not None
has_valid_message = bool(event.message_str and event.message_str.strip())
has_media_content = any(
isinstance(comp, (Image, File, Record, Video))
for comp in event.message_obj.message
isinstance(comp, Image | File) for comp in event.message_obj.message
)
if (
@@ -189,11 +178,7 @@ class InternalAgentSubStage(Stage):
)
return
try:
typing_requested = True
await event.send_typing()
except Exception:
logger.warning("send_typing failed", exc_info=True)
await event.send_typing()
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
@@ -289,7 +274,6 @@ class InternalAgentSubStage(Stage):
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
buffer_intermediate_messages=self.buffer_intermediate_messages,
),
),
)
@@ -320,7 +304,6 @@ class InternalAgentSubStage(Stage):
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
buffer_intermediate_messages=self.buffer_intermediate_messages,
),
),
)
@@ -351,7 +334,6 @@ class InternalAgentSubStage(Stage):
self.show_tool_call_result,
stream_to_general,
show_reasoning=self.show_reasoning,
buffer_intermediate_messages=self.buffer_intermediate_messages,
):
yield
@@ -363,15 +345,6 @@ class InternalAgentSubStage(Stage):
resp=final_resp.completion_text if final_resp else None,
)
asyncio.create_task(
_record_internal_agent_stats(
event,
req,
agent_runner,
final_resp,
)
)
# 检查事件是否被停止,如果被停止则不保存历史记录
if not event.is_stopped() or agent_runner.was_aborted():
await self._save_to_history(
@@ -404,11 +377,6 @@ class InternalAgentSubStage(Stage):
)
await event.send(MessageChain().message(error_text))
finally:
if typing_requested:
try:
await event.stop_typing()
except Exception:
logger.warning("stop_typing failed", exc_info=True)
if follow_up_capture:
await finalize_follow_up_capture(
follow_up_capture,
@@ -449,7 +417,7 @@ class InternalAgentSubStage(Stage):
logger.debug("LLM 响应为空,不保存记录。")
return
messages_to_save: list[Message] = []
message_to_save = []
skipped_initial_system = False
for message in all_messages:
if message.role == "system" and not skipped_initial_system:
@@ -457,16 +425,7 @@ class InternalAgentSubStage(Stage):
continue
if message.role in ["assistant", "user"] and message._no_save:
continue
messages_to_save.append(message)
checkpoint_id = event.get_extra("llm_checkpoint_id")
message_to_save = dump_messages_with_checkpoints(messages_to_save)
if isinstance(checkpoint_id, str) and checkpoint_id:
message_to_save.append(
CheckpointMessageSegment(
content=CheckpointData(id=checkpoint_id),
).model_dump()
)
message_to_save.append(message.model_dump())
# if user_aborted:
# message_to_save.append(
@@ -493,46 +452,3 @@ class InternalAgentSubStage(Stage):
# these hosts are base64 encoded
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}
decoded_blocked = [base64.b64decode(b).decode("utf-8") for b in BLOCKED]
async def _record_internal_agent_stats(
event: AstrMessageEvent,
req: ProviderRequest | None,
agent_runner: AgentRunner | None,
final_resp: LLMResponse | None,
) -> None:
"""Persist internal agent stats without affecting the user response flow."""
if agent_runner is None:
return
provider = agent_runner.provider
stats = agent_runner.stats
if provider is None or stats is None:
return
try:
provider_config = getattr(provider, "provider_config", {}) or {}
conversation_id = (
req.conversation.cid
if req is not None and req.conversation is not None
else None
)
if agent_runner.was_aborted():
status = "aborted"
elif final_resp is not None and final_resp.role == "err":
status = "error"
else:
status = "completed"
await db_helper.insert_provider_stat(
umo=event.unified_msg_origin,
conversation_id=conversation_id,
provider_id=provider_config.get("id", "") or provider.meta().id,
provider_model=provider.get_model(),
status=status,
stats=stats.to_dict(),
agent_type="internal",
)
except Exception as e:
logger.warning("Persist provider stats failed: %s", e, exc_info=True)

View File

@@ -17,7 +17,7 @@ from astrbot.core.agent.runners.deerflow.deerflow_agent_runner import (
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.message.components import Image, Record
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
@@ -317,11 +317,8 @@ class ThirdPartyAgentSubStage(Stage):
if isinstance(comp, Image):
image_path = await comp.convert_to_base64()
req.image_urls.append(image_path)
elif isinstance(comp, Record):
audio_path = await comp.convert_to_file_path()
req.audio_urls.append(audio_path)
if not req.prompt and not req.image_urls and not req.audio_urls:
if not req.prompt and not req.image_urls:
return
custom_error_message = await self._resolve_persona_custom_error_message(event)
@@ -381,7 +378,7 @@ class ThirdPartyAgentSubStage(Stage):
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=120,
tool_call_timeout=60,
),
agent_hooks=MAIN_AGENT_HOOKS,
provider_config=self.prov_cfg,

Some files were not shown because too many files have changed in this diff Show More