Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
4e6a932f31 fix: restore webui 401 login redirect 2026-06-19 22:39:03 +08:00
17 changed files with 122 additions and 600 deletions

View File

@@ -16,11 +16,8 @@ venv*/
ENV/
.conda/
dashboard/
!astrbot/dashboard/
!astrbot/dashboard/dist/
!astrbot/dashboard/dist/**
data/
tests/
.ruff_cache/
.astrbot
astrbot.lock
astrbot.lock

View File

@@ -46,21 +46,14 @@ jobs:
- name: Build Dashboard
run: |
dashboard_version=$(python3 - <<'PY'
import tomllib
with open("pyproject.toml", "rb") as f:
print("v" + tomllib.load(f)["project"]["version"])
PY
)
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo "$dashboard_version" > dist/assets/version
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p astrbot/dashboard
rm -rf astrbot/dashboard/dist
cp -r dashboard/dist astrbot/dashboard/dist
mkdir -p data
cp -r dashboard/dist data/
- name: Determine test image tags
id: test-meta
@@ -164,11 +157,10 @@ jobs:
npm install
npm run build
mkdir -p dist/assets
echo "${{ steps.release-meta.outputs.version }}" > dist/assets/version
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p astrbot/dashboard
rm -rf astrbot/dashboard/dist
cp -r dashboard/dist astrbot/dashboard/dist
mkdir -p data
cp -r dashboard/dist data/
- name: Set QEMU
uses: docker/setup-qemu-action@v4.1.0

View File

@@ -1,3 +1,3 @@
import logging
from .core.log import LogManager
logger = logging.getLogger("astrbot")
logger = LogManager.GetLogger(log_name="astrbot")

View File

@@ -1,32 +1,3 @@
import re
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as package_version
from pathlib import Path
from astrbot.core.config.default import VERSION
try:
import tomllib
except ModuleNotFoundError:
tomllib = None
try:
__version__ = package_version("astrbot")
except PackageNotFoundError:
pyproject_path = Path(__file__).resolve().parents[2] / "pyproject.toml"
try:
if tomllib is None:
match = re.search(
r"(?m)^version\s*=\s*[\"']([^\"']+)[\"']",
pyproject_path.read_text(encoding="utf-8"),
)
__version__ = match.group(1) if match else "0.0.0"
else:
with pyproject_path.open("rb") as f:
__version__ = tomllib.load(f)["project"]["version"]
except (FileNotFoundError, IndexError, KeyError, TypeError, ValueError):
__version__ = "0.0.0"
match = re.match(r"^(\d+(?:\.\d+)*)(a|b|rc)(\d+)$", __version__)
if match:
release, prerelease, number = match.groups()
prerelease = {"a": "alpha", "b": "beta", "rc": "rc"}[prerelease]
__version__ = f"{release}-{prerelease}.{number}"
__version__ = VERSION

View File

@@ -1,11 +1,16 @@
import json
import os
import zoneinfo
from collections.abc import Callable
from typing import Any
import click
from astrbot.core.utils.auth_password import (
hash_dashboard_password,
hash_md5_dashboard_password,
validate_dashboard_password,
)
from ..utils import check_astrbot_root, get_astrbot_root
@@ -39,8 +44,6 @@ def _validate_dashboard_username(value: str) -> str:
def _validate_dashboard_password(value: str) -> str:
"""Validate Dashboard password"""
from astrbot.core.utils.auth_password import validate_dashboard_password
try:
validate_dashboard_password(value)
except ValueError as e:
@@ -86,7 +89,6 @@ def _load_config() -> dict[str, Any]:
raise click.ClickException(
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
os.environ["ASTRBOT_ROOT"] = str(root)
config_path = root / "data" / "cmd_config.json"
if not config_path.exists():
@@ -105,8 +107,7 @@ def _load_config() -> dict[str, Any]:
def _save_config(config: dict[str, Any]) -> None:
"""Save config file"""
root = get_astrbot_root()
config_path = root / "data" / "cmd_config.json"
config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2),
@@ -138,11 +139,6 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
def _set_dashboard_password(config: dict[str, Any], raw_password: str) -> None:
"""Set dashboard password hashes and clear password migration flags."""
from astrbot.core.utils.auth_password import (
hash_dashboard_password,
hash_md5_dashboard_password,
)
_set_nested_item(
config,
"dashboard.pbkdf2_password",

View File

@@ -21,16 +21,17 @@ def _initialize_config_from_env(astrbot_root: Path) -> None:
async def initialize_astrbot(astrbot_root: Path) -> None:
"""Execute AstrBot initialization logic.
Args:
astrbot_root: Runtime root directory to initialize.
"""
"""Execute AstrBot initialization logic"""
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
dot_astrbot.touch()
click.echo(f"Created {dot_astrbot}")
if click.confirm(
f"Install AstrBot to this directory? {astrbot_root}",
default=True,
abort=True,
):
dot_astrbot.touch()
click.echo(f"Created {dot_astrbot}")
paths = {
"data": astrbot_root / "data",
@@ -40,9 +41,8 @@ async def initialize_astrbot(astrbot_root: Path) -> None:
}
for name, path in paths.items():
path_exists = path.exists()
path.mkdir(parents=True, exist_ok=True)
click.echo(f"{'Directory exists' if path_exists else 'Created'}: {path}")
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
_initialize_config_from_env(astrbot_root)
@@ -53,25 +53,7 @@ async def initialize_astrbot(astrbot_root: Path) -> None:
def init() -> None:
"""Initialize AstrBot"""
click.echo("Initializing AstrBot...")
if os.environ.get("ASTRBOT_ROOT"):
astrbot_root = get_astrbot_root()
click.echo(f"Using ASTRBOT_ROOT: {astrbot_root}")
else:
user_root = (Path.home() / ".astrbot").resolve()
current_root = Path.cwd().resolve()
click.echo("Choose AstrBot runtime directory:")
click.echo(f"1. {user_root} (recommended)")
click.echo(f"2. Current directory: {current_root}")
choice = click.prompt(
"Select",
type=click.Choice(["1", "2"]),
default="1",
show_choices=False,
)
astrbot_root = user_root if choice == "1" else current_root
astrbot_root.mkdir(parents=True, exist_ok=True)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
astrbot_root = get_astrbot_root()
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
@@ -83,8 +65,6 @@ def init() -> None:
raise click.ClickException(
"Cannot acquire lock file. Please check if another instance is running"
)
except click.Abort:
raise
except Exception as e:
raise click.ClickException(f"Initialization failed: {e!s}")

View File

@@ -1,4 +1,3 @@
import os
from pathlib import Path
import click
@@ -8,14 +7,7 @@ _BUNDLED_DIST = Path(__file__).parent.parent.parent / "dashboard" / "dist"
def check_astrbot_root(path: str | Path) -> bool:
"""Check whether a path is an AstrBot root directory.
Args:
path: Directory path to inspect.
Returns:
Whether the directory contains the AstrBot root marker.
"""
"""Check if the path is an AstrBot root directory"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
@@ -26,24 +18,8 @@ def check_astrbot_root(path: str | Path) -> bool:
def get_astrbot_root() -> Path:
"""Get the AstrBot root directory path.
Returns:
The explicit root, current local root, default user root, or current
directory when no initialized root exists.
"""
if root := os.environ.get("ASTRBOT_ROOT"):
return Path(root).expanduser().resolve()
current_root = Path.cwd().resolve()
if check_astrbot_root(current_root):
return current_root
user_root = (Path.home() / ".astrbot").resolve()
if check_astrbot_root(user_root):
return user_root
return current_root
"""Get the AstrBot root directory path"""
return Path.cwd()
async def check_dashboard(astrbot_root: Path) -> None:

View File

@@ -278,11 +278,10 @@ async def _apply_kb(
)
if not kb_result:
return
req.extra_user_content_parts.append(
TextPart(
text=f"[Related Knowledge Base Results]:\n{kb_result}",
).mark_as_temp()
)
if req.system_prompt is not None:
req.system_prompt += (
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
)
except Exception as exc: # noqa: BLE001
logger.error("Error occurred while retrieving knowledge base: %s", exc)
else:

View File

@@ -183,22 +183,8 @@ async def download_file(
path: str,
show_progress: bool = False,
progress_callback=None,
allow_insecure_ssl_fallback: bool = True,
) -> None:
"""Download a remote file to a local path.
Args:
url: Remote URL to download.
path: Local destination path.
show_progress: Whether to print progress to stdout.
progress_callback: Optional callback for progress payloads.
allow_insecure_ssl_fallback: Whether certificate failures may retry with
TLS certificate verification disabled.
Returns:
None.
"""
"""从指定 url 下载文件到指定路径 path"""
try:
ssl_context = ssl.create_default_context(
cafile=certifi.where(),
@@ -273,8 +259,6 @@ async def download_file(
},
)
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
if not allow_insecure_ssl_fallback:
raise
# 关闭SSL验证仅在证书验证失败时作为fallback
logger.warning(
f"SSL certificate verification failed for {_safe_url_for_log(url)}. "
@@ -371,22 +355,10 @@ def get_local_ip_addresses():
return network_ips
def get_dashboard_dist_version(dist_dir: str | Path) -> str | None:
"""Read the WebUI version from a dashboard dist directory.
Args:
dist_dir: Dashboard dist directory path.
Returns:
The version string from assets/version, or None when unavailable.
"""
def _read_dashboard_dist_version(dist_dir: str | Path) -> str | None:
version_file = Path(dist_dir) / "assets" / "version"
try:
if version_file.exists():
return version_file.read_text(encoding="utf-8").strip()
except (OSError, UnicodeDecodeError) as exc:
logger.warning("Failed to read WebUI version from %s: %s", version_file, exc)
if version_file.exists():
return version_file.read_text(encoding="utf-8").strip()
return None
@@ -408,106 +380,42 @@ def _normalize_dashboard_version(version: str) -> str:
return version
def is_dashboard_version_compatible(
dashboard_version: str | None, current_version: str
def should_use_bundled_dashboard_dist(
user_dist: str | Path, current_version: str
) -> bool:
"""Check whether a WebUI version matches the current core version.
Args:
dashboard_version: Version read from the WebUI assets/version file.
current_version: Current AstrBot core version.
Returns:
True when both versions are valid SemVer values and compare equal.
"""
if dashboard_version is None:
user_version = _read_dashboard_dist_version(user_dist)
bundled_dist = get_bundled_dashboard_dist_path()
if user_version is None or not bundled_dist.exists():
return False
try:
return (
VersionComparator.compare_version(
_normalize_dashboard_version(dashboard_version),
_normalize_dashboard_version(current_version),
_normalize_dashboard_version(user_version),
)
== 0
> 0
)
except (TypeError, ValueError):
return False
def is_dashboard_dist_compatible(dist_dir: str | Path, current_version: str) -> bool:
"""Check whether a WebUI dist is complete and matches the core version.
Args:
dist_dir: Dashboard dist directory path.
current_version: Current AstrBot core version.
Returns:
True when the dist has an index file and a compatible assets/version.
"""
dist_path = Path(dist_dir)
return (dist_path / "index.html").is_file() and is_dashboard_version_compatible(
get_dashboard_dist_version(dist_path),
current_version,
)
def should_use_bundled_dashboard_dist(
user_dist: str | Path, current_version: str
) -> bool:
"""Decide whether bundled WebUI should replace a user data dist.
Args:
user_dist: Runtime dashboard dist directory under data/.
current_version: Current AstrBot core version.
Returns:
True when user_dist exists but is missing or mismatched against the
current core version, and bundled WebUI matches the current core version.
"""
user_dist = Path(user_dist)
user_version = get_dashboard_dist_version(user_dist)
bundled_dist = get_bundled_dashboard_dist_path()
if not user_dist.exists() or not is_dashboard_dist_compatible(
bundled_dist,
current_version,
):
return False
if user_version is None or not (user_dist / "index.html").is_file():
return True
try:
return not is_dashboard_version_compatible(user_version, current_version)
except (TypeError, ValueError):
return False
async def get_dashboard_version():
"""Return the effective WebUI version for the current runtime.
Returns:
The matching data/dist version, matching bundled version, or the raw
data/dist version when no compatible bundled WebUI is available.
"""
from astrbot.core.config.default import VERSION
# First check user data directory (manually updated / downloaded dashboard).
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
if os.path.exists(dist_dir):
user_version = get_dashboard_dist_version(dist_dir)
if is_dashboard_dist_compatible(dist_dir, VERSION):
return user_version
from astrbot.core.config.default import VERSION
bundled = get_bundled_dashboard_dist_path()
if is_dashboard_dist_compatible(bundled, VERSION):
return get_dashboard_dist_version(bundled)
return user_version
if should_use_bundled_dashboard_dist(dist_dir, VERSION):
bundled_version = _read_dashboard_dist_version(
get_bundled_dashboard_dist_path()
)
if bundled_version is not None:
return bundled_version
return _read_dashboard_dist_version(dist_dir)
bundled = get_bundled_dashboard_dist_path()
if is_dashboard_dist_compatible(bundled, VERSION):
return get_dashboard_dist_version(bundled)
if bundled.exists():
return _read_dashboard_dist_version(bundled)
return None
@@ -519,7 +427,6 @@ async def download_dashboard(
proxy: str | None = None,
progress_callback=None,
extract: bool = True,
allow_insecure_ssl_fallback: bool = True,
) -> None:
"""Download dashboard assets and optionally extract them.
@@ -531,8 +438,6 @@ async def download_dashboard(
proxy: Optional download proxy prefix.
progress_callback: Optional callback for download progress payloads.
extract: Whether to extract the archive after download.
allow_insecure_ssl_fallback: Whether certificate failures may retry with
TLS certificate verification disabled.
Returns:
None.
@@ -555,7 +460,6 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError(
@@ -587,7 +491,6 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError(
@@ -603,7 +506,6 @@ async def download_dashboard(
str(zip_path),
show_progress=True,
progress_callback=progress_callback,
allow_insecure_ssl_fallback=allow_insecure_ssl_fallback,
)
if not zipfile.is_zipfile(zip_path):
raise RuntimeError("Downloaded dashboard package is not a valid ZIP file")

View File

@@ -22,9 +22,7 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import (
get_bundled_dashboard_dist_path,
get_dashboard_dist_version,
get_local_ip_addresses,
is_dashboard_dist_compatible,
should_use_bundled_dashboard_dist,
)
from astrbot.dashboard.asgi_runtime import (
@@ -184,32 +182,21 @@ class AstrBotDashboard:
# Path priority:
# 1. Explicit webui_dir argument
# 2. data/dist/ when it matches the core version
# 3. astrbot/dashboard/dist/ when it matches the core version
# 2. data/dist/ (user-installed / manually updated dashboard)
# 3. astrbot/dashboard/dist/ (bundled with the wheel)
if webui_dir and os.path.exists(webui_dir):
self.data_path = os.path.abspath(webui_dir)
else:
user_dist = os.path.join(get_astrbot_data_path(), "dist")
bundled_dist = get_bundled_dashboard_dist_path()
user_version = get_dashboard_dist_version(user_dist)
if os.path.exists(user_dist) and is_dashboard_dist_compatible(
if os.path.exists(user_dist) and not should_use_bundled_dashboard_dist(
user_dist,
VERSION,
):
self.data_path = os.path.abspath(user_dist)
elif should_use_bundled_dashboard_dist(
user_dist,
VERSION,
) or is_dashboard_dist_compatible(bundled_dist, VERSION):
elif bundled_dist.exists():
self.data_path = str(bundled_dist)
logger.info("Using bundled dashboard dist: %s", self.data_path)
elif os.path.exists(user_dist):
logger.warning(
"Ignoring data/dist because WebUI version mismatches core: %s, expected v%s.",
user_version,
VERSION,
)
self.data_path = None
else:
# Fall back to expected user path (will fail gracefully later)
self.data_path = os.path.abspath(user_dist)
@@ -558,7 +545,7 @@ class AstrBotDashboard:
raise Exception(f"端口 {port} 已被占用")
if self.data_path and (Path(self.data_path) / "index.html").is_file():
if (Path(self.data_path) / "index.html").is_file():
webui_status = "WebUI is ready"
else:
webui_status = (

91
main.py
View File

@@ -2,7 +2,6 @@ import argparse
import asyncio
import mimetypes
import os
import shutil
import sys
from pathlib import Path
@@ -47,10 +46,7 @@ from astrbot.core.utils.astrbot_path import ( # noqa: E402
from astrbot.core.utils.io import ( # noqa: E402
download_dashboard,
get_bundled_dashboard_dist_path,
get_dashboard_dist_version,
is_dashboard_dist_compatible,
is_dashboard_version_compatible,
remove_dir,
get_dashboard_version,
should_use_bundled_dashboard_dist,
)
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime # noqa: E402
@@ -95,15 +91,7 @@ def check_env() -> None:
async def check_dashboard_files(webui_dir: str | None = None):
"""Resolve and repair dashboard static files for startup.
Args:
webui_dir: Optional explicit WebUI directory path from CLI.
Returns:
The directory path to serve, or None when no usable WebUI can be prepared.
"""
"""下载管理面板文件"""
# 指定webui目录
if webui_dir:
if os.path.exists(webui_dir):
@@ -111,81 +99,40 @@ async def check_dashboard_files(webui_dir: str | None = None):
return webui_dir
logger.warning("WebUI directory not found: %s. Using default.", webui_dir)
data_dist_path = Path(get_astrbot_data_path()) / "dist"
bundled_dist = get_bundled_dashboard_dist_path()
if data_dist_path.exists():
v = get_dashboard_dist_version(data_dist_path)
if is_dashboard_dist_compatible(data_dist_path, VERSION):
logger.info("WebUI is up to date.")
return str(data_dist_path)
data_dist_path = os.path.join(get_astrbot_data_path(), "dist")
if os.path.exists(data_dist_path):
v = await get_dashboard_version()
if should_use_bundled_dashboard_dist(data_dist_path, VERSION):
bundled_dist = get_bundled_dashboard_dist_path()
logger.info(
"Replacing data/dist with bundled WebUI because its version does not match core version v%s.",
"Using bundled WebUI because data/dist is older than core version v%s.",
VERSION,
)
try:
remove_dir(str(data_dist_path))
shutil.copytree(bundled_dist, data_dist_path)
return str(data_dist_path)
except Exception as e:
return str(bundled_dist)
if v is not None:
# 存在文件
if v == f"v{VERSION}":
logger.info("WebUI is up to date.")
else:
logger.warning(
"Failed to replace data/dist with bundled WebUI: %s. Using bundled WebUI directly.",
e,
"WebUI version mismatch: %s, expected v%s.",
v,
VERSION,
)
return str(bundled_dist)
if is_dashboard_version_compatible(v, VERSION):
logger.warning(
"WebUI files are incomplete for v%s. Re-downloading WebUI.",
VERSION,
)
elif v is not None:
logger.warning(
"WebUI version mismatch: %s, expected v%s. Re-downloading WebUI.",
v,
VERSION,
)
else:
logger.warning(
"WebUI version file is missing. Re-downloading WebUI v%s.",
VERSION,
)
try:
await download_dashboard(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成。")
return str(data_dist_path)
if is_dashboard_dist_compatible(bundled_dist, VERSION):
logger.info(
"Using bundled WebUI v%s.", get_dashboard_dist_version(bundled_dist)
)
return str(bundled_dist)
return data_dist_path
logger.info(
"Downloading WebUI. If it fails, download dist.zip from https://github.com/AstrBotDevs/AstrBot/releases/latest and extract dist to data/.",
)
try:
await download_dashboard(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
await download_dashboard(version=f"v{VERSION}", latest=False)
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return None
logger.info("管理面板下载完成。")
return str(data_dist_path)
return data_dist_path
async def main_async(webui_dir_arg: str | None) -> None:

View File

@@ -1,11 +1,6 @@
import json
import os
import subprocess
import sys
from pathlib import Path
import pytest
from click.testing import CliRunner
from astrbot.cli.commands import cmd_init
from astrbot.core.utils.auth_password import verify_dashboard_password
@@ -19,7 +14,6 @@ async def test_init_without_initial_password_env_does_not_create_config(
async def fake_check_dashboard(_data_path):
return None
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.delenv(cmd_init.DASHBOARD_INITIAL_PASSWORD_ENV, raising=False)
monkeypatch.setattr(cmd_init, "check_dashboard", fake_check_dashboard)
(tmp_path / ".astrbot").touch()
@@ -38,7 +32,6 @@ async def test_init_uses_initial_password_env_to_create_config(
return None
initial_password = "AstrBotInitialPassword123"
monkeypatch.setenv("ASTRBOT_ROOT", str(tmp_path))
monkeypatch.setenv(cmd_init.DASHBOARD_INITIAL_PASSWORD_ENV, initial_password)
monkeypatch.setattr(cmd_init, "check_dashboard", fake_check_dashboard)
(tmp_path / ".astrbot").touch()
@@ -59,71 +52,3 @@ async def test_init_uses_initial_password_env_to_create_config(
)
assert dashboard_config["password_change_required"] is True
assert dashboard_config["password_storage_upgraded"] is True
def test_cli_main_import_does_not_create_cwd_data(tmp_path):
repo_root = Path(__file__).resolve().parents[1]
env = os.environ.copy()
env.pop("ASTRBOT_ROOT", None)
env["HOME"] = str(tmp_path / "home")
env["PYTHONPATH"] = (
str(repo_root)
if not env.get("PYTHONPATH")
else f"{repo_root}{os.pathsep}{env['PYTHONPATH']}"
)
result = subprocess.run(
[sys.executable, "-c", "import astrbot.cli.__main__"],
cwd=tmp_path,
env=env,
capture_output=True,
text=True,
check=False,
)
assert result.returncode == 0, result.stderr
assert not (tmp_path / "data").exists()
def test_init_defaults_to_user_runtime(monkeypatch, tmp_path):
async def fake_check_dashboard(_data_path):
return None
home = tmp_path / "home"
workdir = tmp_path / "workdir"
home.mkdir()
workdir.mkdir()
monkeypatch.setenv("HOME", str(home))
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(workdir)
monkeypatch.setattr(cmd_init, "check_dashboard", fake_check_dashboard)
result = CliRunner().invoke(cmd_init.init, input="\n", env={"ASTRBOT_ROOT": ""})
assert result.exit_code == 0, result.output
assert (home / ".astrbot" / ".astrbot").exists()
assert (home / ".astrbot" / "data" / "config").is_dir()
assert not (workdir / "data").exists()
def test_init_can_install_to_current_directory(monkeypatch, tmp_path):
async def fake_check_dashboard(_data_path):
return None
home = tmp_path / "home"
workdir = tmp_path / "workdir"
home.mkdir()
workdir.mkdir()
monkeypatch.setenv("HOME", str(home))
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(workdir)
monkeypatch.setattr(cmd_init, "check_dashboard", fake_check_dashboard)
result = CliRunner().invoke(cmd_init.init, input="2\n", env={"ASTRBOT_ROOT": ""})
assert result.exit_code == 0, result.output
assert (workdir / ".astrbot").exists()
assert (workdir / "data" / "config").is_dir()
assert not (home / ".astrbot").exists()

View File

@@ -30,7 +30,6 @@ def _read_config(config_path):
def test_password_command_changes_dashboard_password(monkeypatch, tmp_path):
config_path = _write_config(tmp_path)
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(tmp_path)
runner = CliRunner()
@@ -56,7 +55,6 @@ def test_password_command_changes_dashboard_password(monkeypatch, tmp_path):
def test_password_command_can_update_dashboard_username(monkeypatch, tmp_path):
config_path = _write_config(tmp_path)
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(tmp_path)
runner = CliRunner()
@@ -73,7 +71,6 @@ def test_password_command_can_update_dashboard_username(monkeypatch, tmp_path):
def test_conf_set_dashboard_password_updates_password_state(monkeypatch, tmp_path):
config_path = _write_config(tmp_path)
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.chdir(tmp_path)
runner = CliRunner()

View File

@@ -273,7 +273,6 @@ def test_dashboard_uses_bundled_dist_when_data_dist_is_stale(
bundled_dist = tmp_path / "bundled-dist"
user_dist.mkdir(parents=True)
bundled_dist.mkdir()
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
monkeypatch.setattr(
"astrbot.dashboard.server.get_astrbot_data_path",
@@ -294,32 +293,6 @@ def test_dashboard_uses_bundled_dist_when_data_dist_is_stale(
assert server.data_path == str(bundled_dist)
def test_dashboard_ignores_mismatched_data_dist_without_bundled(
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
tmp_path,
):
data_dir = tmp_path / "data"
user_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
monkeypatch.setattr(
"astrbot.dashboard.server.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.dashboard.server.get_bundled_dashboard_dist_path",
lambda: bundled_dist,
)
shutdown_event = asyncio.Event()
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
assert server.data_path is None
async def _set_dashboard_password_change_required(
core_lifecycle_td: AstrBotCoreLifecycle,
required: bool,

View File

@@ -9,7 +9,7 @@ from unittest import mock
import pytest
from astrbot.core.utils.io import get_dashboard_version, should_use_bundled_dashboard_dist
from astrbot.core.utils.io import should_use_bundled_dashboard_dist
from main import (
DASHBOARD_RESET_PASSWORD_ENV,
_apply_startup_env_flags,
@@ -173,108 +173,49 @@ def test_version_info_comparisons():
@pytest.mark.asyncio
async def test_check_dashboard_files_not_exists(tmp_path):
async def test_check_dashboard_files_not_exists(monkeypatch):
"""Tests dashboard download when files do not exist."""
data_dir = tmp_path / "data"
bundled_dist = tmp_path / "bundled-dist"
monkeypatch.setattr(os.path, "exists", lambda x: False)
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
from main import VERSION
assert result == str(data_dir / "dist")
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
mock_download.assert_called_once()
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_and_version_match(tmp_path):
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
"""Tests that dashboard is not downloaded when it exists and version matches."""
from main import VERSION
# Mock os.path.exists to return True
monkeypatch.setattr(os.path, "exists", lambda x: True)
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(data_dist / "index.html").write_text("user", encoding="utf-8")
# Mock get_dashboard_version to return the current version
with mock.patch("main.get_dashboard_version") as mock_get_version:
# We need to import VERSION from main's context
from main import VERSION
mock_get_version.return_value = f"v{VERSION}"
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
assert result == str(data_dist)
await check_dashboard_files()
# Assert that download_dashboard was NOT called
mock_download.assert_not_called()
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_but_version_mismatch_downloads(tmp_path):
"""Tests that a mismatched dashboard is downloaded on startup."""
from main import VERSION
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
"""Tests that a warning is logged when dashboard version mismatches."""
monkeypatch.setattr(os.path, "exists", lambda x: True)
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.logger.warning") as mock_logger_warning:
result = await check_dashboard_files()
assert result == str(data_dist)
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
with mock.patch(
"main.get_dashboard_version", mock.AsyncMock(return_value="v0.0.1")
):
with mock.patch("main.logger.warning") as mock_logger_warning:
await check_dashboard_files()
mock_logger_warning.assert_called_once()
call_args, _ = mock_logger_warning.call_args
assert "WebUI version mismatch" in call_args[0]
@pytest.mark.asyncio
async def test_check_dashboard_files_downloads_when_matching_dist_is_incomplete(
tmp_path,
):
"""Tests that a version match alone is not enough to serve WebUI."""
from main import VERSION
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
assert result == str(data_dist)
mock_download.assert_called_once_with(
version=f"v{VERSION}",
latest=False,
allow_insecure_ssl_fallback=False,
)
def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
@@ -282,7 +223,6 @@ def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
(bundled_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("v4.24.2", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
@@ -291,94 +231,46 @@ def test_should_use_bundled_dashboard_dist_when_data_dist_is_stale(tmp_path):
assert should_use_bundled_dashboard_dist(user_dist, "v4.24.4") is True
def test_should_use_bundled_dashboard_dist_when_version_file_is_malformed(tmp_path):
def test_should_keep_data_dist_when_version_file_is_malformed(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(user_dist / "assets" / "version").write_text("not-a-version", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is True
def test_should_use_bundled_dashboard_dist_when_data_version_file_is_missing(tmp_path):
user_dist = tmp_path / "user-dist"
bundled_dist = tmp_path / "bundled-dist"
(user_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets" / "version").write_text("v4.24.4", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is True
assert should_use_bundled_dashboard_dist(user_dist, "4.24.4") is False
@pytest.mark.asyncio
async def test_get_dashboard_version_uses_bundled_dist_when_data_dist_is_missing(
async def test_check_dashboard_files_uses_bundled_dist_when_data_dist_is_stale(
tmp_path,
):
"""Tests bundled WebUI version lookup when data/dist is absent."""
from main import VERSION
data_dir = tmp_path / "data"
bundled_dist = tmp_path / "bundled-dist"
(bundled_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
with mock.patch(
"astrbot.core.utils.io.get_astrbot_data_path",
return_value=str(data_dir),
):
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=bundled_dist,
):
assert await get_dashboard_version() == f"v{VERSION}"
@pytest.mark.asyncio
async def test_check_dashboard_files_replaces_stale_data_dist_with_bundled_dist(
tmp_path,
):
"""Tests that a stale data/dist is repaired from bundled dashboard assets."""
from main import VERSION
"""Tests that a stale data/dist does not override bundled dashboard assets."""
data_dir = tmp_path / "data"
data_dist = data_dir / "dist"
bundled_dist = tmp_path / "bundled-dist"
(data_dist / "assets").mkdir(parents=True)
(bundled_dist / "assets").mkdir(parents=True)
(data_dist / "assets" / "version").write_text("v0.0.1", encoding="utf-8")
(data_dist / "old.txt").write_text("old", encoding="utf-8")
(bundled_dist / "assets" / "version").write_text(f"v{VERSION}", encoding="utf-8")
(bundled_dist / "index.html").write_text("bundled", encoding="utf-8")
data_dist.mkdir(parents=True)
bundled_dist.mkdir()
with mock.patch("main.get_astrbot_data_path", return_value=str(data_dir)):
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
"main.get_dashboard_version", mock.AsyncMock(return_value="v0.0.1")
):
with mock.patch(
"astrbot.core.utils.io.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
"main.should_use_bundled_dashboard_dist", return_value=True
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
with mock.patch(
"main.get_bundled_dashboard_dist_path",
return_value=Path(bundled_dist),
):
with mock.patch("main.download_dashboard") as mock_download:
result = await check_dashboard_files()
assert result == str(data_dist)
assert (data_dist / "assets" / "version").read_text(encoding="utf-8") == f"v{VERSION}"
assert (data_dist / "index.html").read_text(encoding="utf-8") == "bundled"
assert not (data_dist / "old.txt").exists()
assert result == str(bundled_dist)
mock_download.assert_not_called()
@@ -389,7 +281,7 @@ async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.get_dashboard_dist_version") as mock_get_version:
with mock.patch("main.get_dashboard_version") as mock_get_version:
result = await check_dashboard_files(webui_dir=valid_dir)
assert result == valid_dir
mock_download.assert_not_called()

View File

@@ -440,7 +440,6 @@ async def test_download_dashboard_falls_back_when_hosted_package_is_not_zip(
path: str,
show_progress: bool = False, # noqa: ARG001
progress_callback=None, # noqa: ARG001
allow_insecure_ssl_fallback: bool = True, # noqa: ARG001
) -> None:
calls.append(url)
parsed = urlparse(url)

View File

@@ -8,7 +8,6 @@ import pytest
from astrbot.core import astr_main_agent as ama
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.message import Message, dump_messages_with_checkpoints
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Plain, Reply, Video
@@ -378,18 +377,8 @@ class TestApplyKb:
):
await module._apply_kb(mock_event, req, mock_context, config)
assert req.system_prompt == "System prompt"
assert len(req.extra_user_content_parts) == 1
kb_part = req.extra_user_content_parts[0]
assert kb_part.text == "[Related Knowledge Base Results]:\nKB result"
message = Message.model_validate(await req.assemble_context())
assert isinstance(message.content, list)
assert message.content[0].text == "test question"
assert message.content[1].text == "[Related Knowledge Base Results]:\nKB result"
assert dump_messages_with_checkpoints([message]) == [
{"role": "user", "content": [{"type": "text", "text": "test question"}]}
]
assert "[Related Knowledge Base Results]:" in req.system_prompt
assert "KB result" in req.system_prompt
@pytest.mark.asyncio
async def test_apply_kb_with_agentic_mode(self, mock_event, mock_context):