mirror of
https://github.com/AstrBotDevs/AstrBot
synced 2026-07-01 01:10:21 +08:00
fix: make project update flow atomic (#8805)
* fix: make project update flow atomic * fix: address atomic update review feedback * fix: show update success after restart * fix: prevent update progress reset during restart * fix: align update success feedback styling
This commit is contained in:
@@ -2570,31 +2570,56 @@ async def test_do_update(
|
||||
release_path = temp_release_dir / "astrbot"
|
||||
calls = []
|
||||
|
||||
async def mock_update(*args, **kwargs):
|
||||
"""Mocks the update process by creating a directory in the temp path."""
|
||||
calls.append("core")
|
||||
async def mock_download_core(*args, **kwargs):
|
||||
calls.append("download-core")
|
||||
callback = kwargs.get("progress_callback")
|
||||
if callback:
|
||||
callback({"downloaded": 10, "total": 10, "percent": 1, "speed": 1})
|
||||
zip_path = kwargs["path"]
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("AstrBot-main/README.md", "core")
|
||||
return zip_path
|
||||
|
||||
def mock_apply_core(*args, **kwargs):
|
||||
del args, kwargs
|
||||
calls.append("apply-core")
|
||||
os.makedirs(release_path, exist_ok=True)
|
||||
|
||||
async def mock_download_dashboard(*args, **kwargs):
|
||||
"""Mocks the dashboard download to prevent network access."""
|
||||
calls.append("dashboard")
|
||||
calls.append("download-dashboard")
|
||||
callback = kwargs.get("progress_callback")
|
||||
if callback:
|
||||
callback({"downloaded": 10, "total": 10, "percent": 1, "speed": 1})
|
||||
with zipfile.ZipFile(kwargs["path"], "w") as zf:
|
||||
zf.writestr("dist/index.html", "dashboard")
|
||||
return
|
||||
|
||||
def mock_extract_dashboard(*args, **kwargs):
|
||||
del args, kwargs
|
||||
calls.append("apply-dashboard")
|
||||
|
||||
async def mock_pip_install(*args, **kwargs):
|
||||
"""Mocks pip install to prevent actual installation."""
|
||||
return
|
||||
|
||||
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"download_update_package",
|
||||
mock_download_core,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"apply_update_package",
|
||||
mock_apply_core,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.services.update_service.download_dashboard",
|
||||
mock_download_dashboard,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.services.update_service.extract_dashboard",
|
||||
mock_extract_dashboard,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.services.update_service.pip_installer.install",
|
||||
mock_pip_install,
|
||||
@@ -2609,7 +2634,12 @@ async def test_do_update(
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert os.path.exists(release_path)
|
||||
assert calls[:2] == ["dashboard", "core"]
|
||||
assert calls[:4] == [
|
||||
"download-dashboard",
|
||||
"download-core",
|
||||
"apply-core",
|
||||
"apply-dashboard",
|
||||
]
|
||||
|
||||
progress_response = await test_client.get(
|
||||
"/api/update/progress?id=test-progress",
|
||||
@@ -2621,6 +2651,142 @@ async def test_do_update(
|
||||
assert progress_data["data"]["overall_percent"] == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update_does_not_apply_files_when_core_download_fails(
|
||||
app: FastAPIAppAdapter,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
calls = []
|
||||
|
||||
async def mock_download_dashboard(*args, **kwargs):
|
||||
calls.append("download-dashboard")
|
||||
callback = kwargs.get("progress_callback")
|
||||
if callback:
|
||||
callback({"downloaded": 10, "total": 10, "percent": 1, "speed": 1})
|
||||
|
||||
async def mock_download_core(*args, **kwargs):
|
||||
del args, kwargs
|
||||
calls.append("download-core")
|
||||
raise RuntimeError("core download failed")
|
||||
|
||||
def mock_apply_core(*args, **kwargs):
|
||||
del args, kwargs
|
||||
calls.append("apply-core")
|
||||
|
||||
def mock_extract_dashboard(*args, **kwargs):
|
||||
del args, kwargs
|
||||
calls.append("apply-dashboard")
|
||||
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"download_update_package",
|
||||
mock_download_core,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"apply_update_package",
|
||||
mock_apply_core,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.services.update_service.download_dashboard",
|
||||
mock_download_dashboard,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.services.update_service.extract_dashboard",
|
||||
mock_extract_dashboard,
|
||||
)
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/update/do",
|
||||
headers=authenticated_header,
|
||||
json={"version": "v3.4.0", "reboot": False, "progress_id": "atomic-fail"},
|
||||
)
|
||||
data = await response.get_json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data["status"] == "error"
|
||||
assert calls == ["download-dashboard", "download-core"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update_does_not_apply_files_when_package_verification_fails(
|
||||
app: FastAPIAppAdapter,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
):
|
||||
test_client = app.test_client()
|
||||
calls = []
|
||||
|
||||
async def mock_download_dashboard(*args, **kwargs):
|
||||
del args
|
||||
calls.append("download-dashboard")
|
||||
Path(kwargs["path"]).write_bytes(b"not a zip")
|
||||
|
||||
async def mock_download_core(*args, **kwargs):
|
||||
del args
|
||||
calls.append("download-core")
|
||||
zip_path = kwargs["path"]
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("AstrBot-main/README.md", "core")
|
||||
return zip_path
|
||||
|
||||
def mock_apply_core(*args, **kwargs):
|
||||
del args, kwargs
|
||||
calls.append("apply-core")
|
||||
|
||||
def mock_extract_dashboard(*args, **kwargs):
|
||||
del args, kwargs
|
||||
calls.append("apply-dashboard")
|
||||
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"download_update_package",
|
||||
mock_download_core,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
core_lifecycle_td.astrbot_updator,
|
||||
"apply_update_package",
|
||||
mock_apply_core,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.services.update_service.download_dashboard",
|
||||
mock_download_dashboard,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.services.update_service.extract_dashboard",
|
||||
mock_extract_dashboard,
|
||||
)
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/update/do",
|
||||
headers=authenticated_header,
|
||||
json={"version": "v3.4.0", "reboot": False, "progress_id": "invalid-zip"},
|
||||
)
|
||||
data = await response.get_json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data["status"] == "error"
|
||||
assert calls == ["download-dashboard", "download-core"]
|
||||
|
||||
|
||||
def test_extract_dashboard_rejects_zip_path_traversal(tmp_path: Path):
|
||||
from astrbot.core.utils.io import extract_dashboard
|
||||
|
||||
archive_path = tmp_path / "dashboard.zip"
|
||||
extract_path = tmp_path / "data"
|
||||
with zipfile.ZipFile(archive_path, "w") as zf:
|
||||
zf.writestr("../evil.txt", "unsafe")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsafe dashboard archive path"):
|
||||
extract_dashboard(archive_path, extract_path)
|
||||
|
||||
assert not (tmp_path / "evil.txt").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update_hides_internal_error_message_in_response_and_progress(
|
||||
app: FastAPIAppAdapter,
|
||||
|
||||
@@ -540,6 +540,12 @@ class FakeAstrBotUpdator:
|
||||
async def update(self, *_args, **_kwargs) -> None:
|
||||
return None
|
||||
|
||||
async def download_update_package(self, *_args, **kwargs):
|
||||
return kwargs.get("path", "temp.zip")
|
||||
|
||||
def apply_update_package(self, *_args, **_kwargs) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class FakeAstrBotConfig(dict):
|
||||
def save_config(self, post_config: dict) -> None:
|
||||
|
||||
Reference in New Issue
Block a user