fix: protect desktop plugin installs with core lock (#7872)

This commit is contained in:
エイカク
2026-04-28 21:10:19 +09:00
committed by GitHub
parent d8de0035a9
commit 2f33c34b5c
4 changed files with 220 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ from collections.abc import Iterator
from packaging.requirements import Requirement
from astrbot.core.utils.desktop_core_lock import get_desktop_core_lock_constraints
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name,
collect_installed_distribution_versions,
@@ -93,7 +94,14 @@ class CoreConstraintsProvider:
@contextlib.contextmanager
def constraints_file(self) -> Iterator[str | None]:
constraints = _get_core_constraints(self._core_dist_name)
constraints = tuple(
dict.fromkeys(
(
*_get_core_constraints(self._core_dist_name),
*get_desktop_core_lock_constraints(),
)
)
)
if not constraints:
yield None
return

View File

@@ -0,0 +1,108 @@
import json
import logging
import os
import re
from functools import lru_cache
from typing import Any
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
logger = logging.getLogger("astrbot")
DESKTOP_CORE_LOCK_PATH_ENV = "ASTRBOT_DESKTOP_CORE_LOCK_PATH"
def _canonicalize_distribution_name(name: str) -> str:
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
def _safe_requirement_pin(name: str, version: str) -> str | None:
if not name or not version:
return None
if any(char.isspace() for char in name) or any(char.isspace() for char in version):
return None
return f"{name}=={version}"
def _fallback_module_name(name: str) -> str:
return _canonicalize_distribution_name(name).replace("-", "_")
def _iter_distribution_records(data: Any):
if not isinstance(data, dict):
return
distributions = data.get("distributions", [])
if not isinstance(distributions, list):
return
for record in distributions:
if isinstance(record, dict):
yield record
@lru_cache(maxsize=8)
def _load_lock_data(lock_path: str) -> dict[str, Any] | None:
try:
with open(lock_path, encoding="utf-8") as file:
data = json.load(file)
except FileNotFoundError:
logger.warning("桌面端核心依赖锁不存在: %s", lock_path)
return None
except Exception as exc:
logger.warning("读取桌面端核心依赖锁失败: %s", exc)
return None
if not isinstance(data, dict):
logger.warning("桌面端核心依赖锁格式无效: %s", lock_path)
return None
return data
def _resolve_lock_data() -> dict[str, Any] | None:
if not is_packaged_desktop_runtime():
return None
lock_path = os.environ.get(DESKTOP_CORE_LOCK_PATH_ENV, "").strip()
if not lock_path:
return None
return _load_lock_data(lock_path)
def get_desktop_core_lock_constraints() -> tuple[str, ...]:
data = _resolve_lock_data()
if not data:
return ()
constraints: dict[str, str] = {}
for record in _iter_distribution_records(data):
name = record.get("name")
version = record.get("version")
if not isinstance(name, str) or not isinstance(version, str):
continue
pin = _safe_requirement_pin(name, version)
if not pin:
continue
constraints.setdefault(_canonicalize_distribution_name(name), pin)
return tuple(constraints[key] for key in sorted(constraints))
def get_desktop_core_lock_modules() -> frozenset[str]:
data = _resolve_lock_data()
if not data:
return frozenset()
modules: set[str] = set()
for record in _iter_distribution_records(data):
name = record.get("name")
top_level_modules = record.get("top_level_modules", [])
if isinstance(top_level_modules, list):
for module_name in top_level_modules:
if isinstance(module_name, str) and module_name:
modules.add(module_name.split(".", 1)[0])
if isinstance(name, str):
fallback = _fallback_module_name(name)
if fallback:
modules.add(fallback)
return frozenset(modules)

View File

@@ -18,6 +18,7 @@ from urllib.parse import urlparse
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
from astrbot.core.utils.desktop_core_lock import get_desktop_core_lock_modules
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name as _canonicalize_distribution_name,
)
@@ -811,6 +812,12 @@ def _ensure_plugin_dependencies_preferred(
if not candidate_modules:
return
locked_modules = get_desktop_core_lock_modules()
if locked_modules:
candidate_modules = candidate_modules.difference(locked_modules)
if not candidate_modules:
return
_ensure_preferred_modules(candidate_modules, target_site_packages)

View File

@@ -1,6 +1,8 @@
import asyncio
import json
import ntpath
import threading
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
@@ -1061,6 +1063,100 @@ def test_core_constraints_file_propagates_inner_conflict_without_fake_warning(
assert warning_logs == []
@pytest.mark.asyncio
async def test_install_adds_desktop_core_lock_constraints_for_packaged_runtime(
monkeypatch, tmp_path
):
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.delattr("sys.frozen", raising=False)
lock_path = tmp_path / "runtime-core-lock.json"
lock_path.write_text(
json.dumps(
{
"version": 1,
"distributions": [
{
"name": "desktop-only-core",
"version": "9.9.9",
"top_level_modules": ["desktop_only_core"],
}
],
}
),
encoding="utf-8",
)
monkeypatch.setenv("ASTRBOT_DESKTOP_CORE_LOCK_PATH", str(lock_path))
site_packages_path = tmp_path / "site-packages"
captured_constraints = []
async def capture_pip_args(self, args):
del self
constraints_path = args[args.index("-c") + 1]
captured_constraints.append(Path(constraints_path).read_text(encoding="utf-8"))
return 0
monkeypatch.setattr(PipInstaller, "_run_pip_in_process", capture_pip_args)
monkeypatch.setattr(
"astrbot.core.utils.pip_installer.get_astrbot_site_packages_path",
lambda: str(site_packages_path),
)
monkeypatch.setattr(
"astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred",
lambda path, requirements: None,
)
installer = PipInstaller("")
await installer.install(package_name="Cua")
assert captured_constraints
assert "desktop-only-core==9.9.9" in captured_constraints[0]
def test_ensure_plugin_dependencies_preferred_skips_desktop_core_lock_modules(
monkeypatch, tmp_path
):
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
lock_path = tmp_path / "runtime-core-lock.json"
lock_path.write_text(
json.dumps(
{
"version": 1,
"distributions": [
{
"name": "openai",
"version": "2.32.0",
"top_level_modules": ["openai"],
}
],
}
),
encoding="utf-8",
)
monkeypatch.setenv("ASTRBOT_DESKTOP_CORE_LOCK_PATH", str(lock_path))
preferred_calls = []
monkeypatch.setattr(
pip_installer_module,
"_collect_candidate_modules",
lambda requirements, site_packages_path: {"openai", "cua_agent"},
)
monkeypatch.setattr(
pip_installer_module,
"_ensure_preferred_modules",
lambda modules, site_packages_path: preferred_calls.append(modules),
)
pip_installer_module._ensure_plugin_dependencies_preferred(
str(tmp_path / "site-packages"),
{"Cua"},
)
assert preferred_calls == [{"cua_agent"}]
def test_iter_requirement_lines_expands_nested_requirement_files(tmp_path):
base_requirements = tmp_path / "base.txt"
base_requirements.write_text("demo-package==1.0\n", encoding="utf-8")