diff --git a/.gitignore b/.gitignore index 004481c61..4a02b8bb3 100644 --- a/.gitignore +++ b/.gitignore @@ -62,4 +62,4 @@ GenieData/ .opencode/ .kilocode/ .worktrees/ -docs/plans/ + diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index cf000c5a4..57be1e9a9 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -1,12 +1,14 @@ """插件的重载、启停、安装、卸载等操作。""" import asyncio +import contextlib import functools import inspect import json import logging import os import sys +import tempfile import traceback from types import ModuleType @@ -29,12 +31,12 @@ get_astrbot_config_path, get_astrbot_path, get_astrbot_plugin_path, + get_astrbot_temp_path, ) from astrbot.core.utils.io import remove_dir from astrbot.core.utils.metrics import Metric from astrbot.core.utils.requirements_utils import ( - RequirementsPrecheckFailed, - find_missing_requirements_or_raise, + plan_missing_requirements_install, ) from . import StarMetadata @@ -74,30 +76,78 @@ def __init__( self.error = error +@contextlib.contextmanager +def _temporary_filtered_requirements_file( + *, + install_lines: tuple[str, ...], +): + filtered_requirements_path: str | None = None + temp_dir = get_astrbot_temp_path() + + try: + os.makedirs(temp_dir, exist_ok=True) + with tempfile.NamedTemporaryFile( + mode="w", + suffix="_plugin_requirements.txt", + delete=False, + dir=temp_dir, + encoding="utf-8", + ) as filtered_requirements_file: + filtered_requirements_file.write("\n".join(install_lines) + "\n") + filtered_requirements_path = filtered_requirements_file.name + + yield filtered_requirements_path + finally: + if filtered_requirements_path and os.path.exists(filtered_requirements_path): + try: + os.remove(filtered_requirements_path) + except OSError as exc: + logger.warning( + "删除临时插件依赖文件失败:%s(路径:%s)", + exc, + filtered_requirements_path, + ) + + async def _install_requirements_with_precheck( *, plugin_label: str, requirements_path: str, ) -> None: - try: - missing = find_missing_requirements_or_raise(requirements_path) - except RequirementsPrecheckFailed: + install_plan = plan_missing_requirements_install(requirements_path) + + if install_plan is None: logger.info( - f"正在安装插件 {plugin_label} 的依赖库(预检查失败,回退到完整安装): " + f"正在安装插件 {plugin_label} 的依赖库(缺失依赖预检查不可裁剪,回退到完整安装): " f"{requirements_path}" ) await pip_installer.install(requirements_path=requirements_path) return - if not missing: + if not install_plan.missing_names: logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。") return + if not install_plan.install_lines: + fallback_reason = install_plan.fallback_reason or "unknown reason" + logger.info( + "检测到插件 %s 缺失依赖,但无法安全裁剪 requirements,回退到完整安装: %s (%s)", + plugin_label, + requirements_path, + fallback_reason, + ) + await pip_installer.install(requirements_path=requirements_path) + return + logger.info( f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: " - f"{requirements_path} -> {sorted(missing)}" + f"{requirements_path} -> {sorted(install_plan.missing_names)}" ) - await pip_installer.install(requirements_path=requirements_path) + + with _temporary_filtered_requirements_file( + install_lines=install_plan.install_lines, + ) as filtered_requirements_path: + await pip_installer.install(requirements_path=filtered_requirements_path) class PluginManager: diff --git a/astrbot/core/utils/requirements_utils.py b/astrbot/core/utils/requirements_utils.py index 7f3827256..e031de846 100644 --- a/astrbot/core/utils/requirements_utils.py +++ b/astrbot/core/utils/requirements_utils.py @@ -4,7 +4,7 @@ import re import shlex import sys -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Sequence from dataclasses import dataclass from packaging.requirements import InvalidRequirement, Requirement @@ -29,6 +29,13 @@ class ParsedPackageInput: requirement_names: frozenset[str] +@dataclass(frozen=True) +class MissingRequirementsPlan: + missing_names: frozenset[str] + install_lines: tuple[str, ...] + fallback_reason: str | None = None + + def canonicalize_distribution_name(name: str) -> str: return re.sub(r"[-_.]+", "-", name).strip("-").lower() @@ -364,8 +371,8 @@ def _load_requirement_lines_for_precheck( None, ) if fallback_line is not None: - logger.warning( - "预检查缺失依赖失败,将回退到完整安装: unresolved direct reference in %s: %s", + logger.info( + "缺失依赖预检查发现无法安全裁剪的 option/direct-reference 行,将回退到完整安装: %s (%s)", requirements_path, fallback_line, ) @@ -381,6 +388,13 @@ def find_missing_requirements(requirements_path: str) -> set[str] | None: if not can_precheck or requirement_lines is None: return None + return find_missing_requirements_from_lines(requirement_lines) + + +def find_missing_requirements_from_lines( + requirement_lines: Sequence[str], +) -> set[str] | None: + required = list(iter_requirements(lines=requirement_lines)) if not required: return set() @@ -401,6 +415,70 @@ def find_missing_requirements(requirements_path: str) -> set[str] | None: return missing +def build_missing_requirements_install_lines( + requirements_path: str, + requirement_lines: Sequence[str], + missing_names: set[str] | frozenset[str], +) -> tuple[str, ...] | None: + wanted_names = set(missing_names) + install_lines: list[str] = [] + for line in requirement_lines: + parsed = _parse_requirement_line(line) + if parsed is None: + if looks_like_direct_reference(line) or line.startswith(("-", "--")): + logger.debug( + "缺失依赖行筛选回退到完整安装:requirements 中包含无法安全裁剪的 option/direct-reference 行: %s (%s)", + requirements_path, + line, + ) + return None + continue + + name, _specifier = parsed + if name in wanted_names: + install_lines.append(line) + + return tuple(install_lines) + + +def plan_missing_requirements_install( + requirements_path: str, +) -> MissingRequirementsPlan | None: + can_precheck, requirement_lines = _load_requirement_lines_for_precheck( + requirements_path + ) + if not can_precheck or requirement_lines is None: + return None + + missing = find_missing_requirements_from_lines(requirement_lines) + if missing is None: + return None + + install_lines = build_missing_requirements_install_lines( + requirements_path, + requirement_lines, + missing, + ) + if install_lines is None: + return None + if missing and not install_lines: + logger.warning( + "预检查缺失依赖成功,但无法映射到可安装 requirement 行,将回退到完整安装: %s -> %s", + requirements_path, + sorted(missing), + ) + return MissingRequirementsPlan( + missing_names=frozenset(missing), + install_lines=(), + fallback_reason="unmapped missing requirement names", + ) + + return MissingRequirementsPlan( + missing_names=frozenset(missing), + install_lines=install_lines, + ) + + def find_missing_requirements_or_raise(requirements_path: str) -> set[str]: missing = find_missing_requirements(requirements_path) if missing is None: diff --git a/tests/test_pip_helper_modules.py b/tests/test_pip_helper_modules.py index dcb5cdb21..506dd0945 100644 --- a/tests/test_pip_helper_modules.py +++ b/tests/test_pip_helper_modules.py @@ -145,24 +145,182 @@ def test_find_missing_requirements_or_raise_uses_requirements_exception(tmp_path requirements_utils.find_missing_requirements_or_raise(str(requirements_path)) +def test_build_missing_requirements_install_lines_keeps_only_missing_lines(tmp_path): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text( + 'aiohttp>=3.0\nboto3==1.2; python_version >= "3.0"\nbotocore\n', + encoding="utf-8", + ) + + install_lines = requirements_utils.build_missing_requirements_install_lines( + str(requirements_path), + [ + "aiohttp>=3.0", + 'boto3==1.2; python_version >= "3.0"', + "botocore", + ], + {"boto3", "botocore"}, + ) + + assert install_lines == ( + 'boto3==1.2; python_version >= "3.0"', + "botocore", + ) + + +def test_build_missing_requirements_install_lines_returns_empty_tuple_when_all_satisfied( + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("aiohttp>=3.0\nboto3\n", encoding="utf-8") + + install_lines = requirements_utils.build_missing_requirements_install_lines( + str(requirements_path), ["aiohttp>=3.0", "boto3"], set() + ) + + assert install_lines == () + + +def test_build_missing_requirements_install_lines_returns_none_for_option_lines( + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text( + "--extra-index-url https://example.com/simple\nboto3\n", + encoding="utf-8", + ) + + install_lines = requirements_utils.build_missing_requirements_install_lines( + str(requirements_path), + ["--extra-index-url https://example.com/simple", "boto3"], + {"boto3"}, + ) + + assert install_lines is None + + +def test_build_missing_requirements_install_lines_skips_inactive_marker_lines( + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text( + 'boto3\nother-package; sys_platform == "win32"\n', + encoding="utf-8", + ) + + install_lines = requirements_utils.build_missing_requirements_install_lines( + str(requirements_path), + ["boto3", 'other-package; sys_platform == "win32"'], + {"boto3"}, + ) + + assert install_lines == ("boto3",) + + +def test_plan_missing_requirements_install_returns_none_when_missing_names_cannot_map_to_lines( + monkeypatch, + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("boto3\n", encoding="utf-8") + + monkeypatch.setattr( + requirements_utils, + "find_missing_requirements_from_lines", + lambda lines: {"botocore"}, + ) + + plan = requirements_utils.plan_missing_requirements_install(str(requirements_path)) + + assert plan is not None + assert plan.missing_names == frozenset({"botocore"}) + assert plan.install_lines == () + assert plan.fallback_reason == "unmapped missing requirement names" + + +def test_plan_missing_requirements_install_loads_requirement_lines_once( + monkeypatch, + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("boto3\nbotocore\n", encoding="utf-8") + calls = [] + + def mock_load(path): + calls.append(path) + return True, ["boto3", "botocore"] + + monkeypatch.setattr( + requirements_utils, + "_load_requirement_lines_for_precheck", + mock_load, + ) + monkeypatch.setattr( + requirements_utils, + "collect_installed_distribution_versions", + lambda paths: {}, + ) + monkeypatch.setattr( + requirements_utils, + "get_requirement_check_paths", + lambda: ["/tmp/site-packages"], + ) + + plan = requirements_utils.plan_missing_requirements_install(str(requirements_path)) + + assert plan is not None + assert plan.missing_names == frozenset({"boto3", "botocore"}) + assert plan.install_lines == ("boto3", "botocore") + assert calls == [str(requirements_path)] + + +def test_build_missing_requirements_install_lines_logs_why_option_lines_fall_back( + monkeypatch, + tmp_path, +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text( + "--extra-index-url https://example.com/simple\nboto3\n", + encoding="utf-8", + ) + + debug_logs = [] + + monkeypatch.setattr( + "astrbot.core.utils.requirements_utils.logger.debug", + lambda line, *args: debug_logs.append(line % args if args else line), + ) + + install_lines = requirements_utils.build_missing_requirements_install_lines( + str(requirements_path), + ["--extra-index-url https://example.com/simple", "boto3"], + {"boto3"}, + ) + + assert install_lines is None + assert any(str(requirements_path) in log for log in debug_logs) + assert any("option/direct-reference" in log for log in debug_logs) + + def test_find_missing_requirements_logs_path_and_reason_on_precheck_fallback( monkeypatch, tmp_path, ): requirements_path = tmp_path / "requirements.txt" requirements_path.write_text("git+https://example.com/demo.git\n", encoding="utf-8") - warning_logs = [] + + info_logs = [] monkeypatch.setattr( - "astrbot.core.utils.requirements_utils.logger.warning", - lambda line, *args: warning_logs.append(line % args if args else line), + "astrbot.core.utils.requirements_utils.logger.info", + lambda line, *args: info_logs.append(line % args if args else line), ) missing = requirements_utils.find_missing_requirements(str(requirements_path)) assert missing is None - assert any(str(requirements_path) in log for log in warning_logs) - assert any("direct reference" in log for log in warning_logs) + assert any(str(requirements_path) in log for log in info_logs) + assert any("option/direct-reference" in log for log in info_logs) def test_load_requirement_lines_for_precheck_uses_parse_requirement_line_result( diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 1b52990a5..b1dafc87e 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,4 +1,5 @@ import asyncio +import os from pathlib import Path import pytest @@ -6,6 +7,7 @@ from astrbot.core.star.star_manager import PluginDependencyInstallError, PluginManager from astrbot.core.utils.pip_installer import PipInstallError +from astrbot.core.utils.requirements_utils import MissingRequirementsPlan # --- Test Data & Helpers --- @@ -74,13 +76,25 @@ async def mock_reload(specified_dir_name=None): return mock_reload -def _build_dependency_install_mock(events, fail: bool): +def _build_dependency_install_mock( + events, + fail: bool, + *, + capture_content: bool = False, +): async def mock_install_requirements( - *, requirements_path: str = None, package_name: str = None, **kwargs + *, + requirements_path: str | None = None, + package_name: str | None = None, + **kwargs, ): del kwargs if requirements_path: - events.append(("deps", str(requirements_path))) + path = Path(requirements_path) + event = ("deps", str(path)) + if capture_content: + event = (*event, path.read_text(encoding="utf-8")) + events.append(event) if package_name: events.append(("deps_pkg", package_name)) if fail: @@ -90,24 +104,56 @@ async def mock_install_requirements( def _mock_missing_requirements(monkeypatch, missing: set[str]): + _mock_missing_requirements_plan(monkeypatch, missing, sorted(missing)) + + +def _mock_missing_requirements_plan( + monkeypatch, + missing_names, + install_lines, + *, + fallback_reason: str | None = None, +): monkeypatch.setattr( - "astrbot.core.star.star_manager.find_missing_requirements_or_raise", - lambda requirements_path: missing, + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: MissingRequirementsPlan( + missing_names=frozenset(missing_names), + install_lines=tuple(install_lines), + fallback_reason=fallback_reason, + ), ) def _mock_precheck_fails(monkeypatch): - from astrbot.core import RequirementsPrecheckFailed - - def mock_fail(requirements_path): - raise RequirementsPrecheckFailed("mock precheck failure") - monkeypatch.setattr( - "astrbot.core.star.star_manager.find_missing_requirements_or_raise", - mock_fail, + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: None, ) +def _assert_dependency_install_event_matches( + event, + *, + expected_original_path: Path, + expected_content: str | None = None, + expect_filtered_tempfile: bool | None = None, +): + assert event[0] == "deps" + used_path = Path(event[1]) + should_be_filtered = expected_content is not None + if expect_filtered_tempfile is not None: + should_be_filtered = expect_filtered_tempfile + + if not should_be_filtered: + assert used_path == expected_original_path + else: + assert used_path != expected_original_path + assert used_path.name.endswith("_plugin_requirements.txt") + if expected_content is not None: + if len(event) >= 3: + assert event[2] == expected_content + + # --- Fixtures --- @@ -188,13 +234,21 @@ def mock_load_and_register(*args, **kwargs): if dependency_install_fails: with pytest.raises(PluginDependencyInstallError, match="pip failed"): await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) - assert events == [("deps", str(plugin_path / "requirements.txt"))] + assert len(events) == 1 + _assert_dependency_install_event_matches( + events[0], + expected_original_path=plugin_path / "requirements.txt", + expected_content="networkx\n", + ) else: await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) - assert events == [ - ("deps", str(plugin_path / "requirements.txt")), - ("load", TEST_PLUGIN_DIR), - ] + assert len(events) == 2 + _assert_dependency_install_event_matches( + events[0], + expected_original_path=plugin_path / "requirements.txt", + expected_content="networkx\n", + ) + assert events[1] == ("load", TEST_PLUGIN_DIR) @pytest.mark.asyncio @@ -265,13 +319,21 @@ def mock_load_and_register(*args, **kwargs): if dependency_install_fails: with pytest.raises(PluginDependencyInstallError, match="pip failed"): await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR) - assert events == [("deps", str(local_updator / "requirements.txt"))] + assert len(events) == 1 + _assert_dependency_install_event_matches( + events[0], + expected_original_path=local_updator / "requirements.txt", + expected_content="networkx\n", + ) else: await plugin_manager_pm.reload_failed_plugin(TEST_PLUGIN_DIR) - assert events == [ - ("deps", str(local_updator / "requirements.txt")), - ("load", TEST_PLUGIN_DIR), - ] + assert len(events) == 2 + _assert_dependency_install_event_matches( + events[0], + expected_original_path=local_updator / "requirements.txt", + expected_content="networkx\n", + ) + assert events[1] == ("load", TEST_PLUGIN_DIR) @pytest.mark.asyncio @@ -337,7 +399,9 @@ async def mock_install_requirements(*args, **kwargs): mock_install_requirements, ) - with pytest.raises(PluginDependencyInstallError, match="install failed") as exc_info: + with pytest.raises( + PluginDependencyInstallError, match="install failed" + ) as exc_info: await plugin_manager_pm._ensure_plugin_requirements( str(local_updator), TEST_PLUGIN_DIR, @@ -403,10 +467,20 @@ async def mock_update(plugin, proxy=""): if dependency_install_fails: with pytest.raises(PluginDependencyInstallError, match="pip failed"): await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME) - assert ("deps", str(local_updator / "requirements.txt")) in events + dep_event = next(event for event in events if event[0] == "deps") + _assert_dependency_install_event_matches( + dep_event, + expected_original_path=local_updator / "requirements.txt", + expected_content="networkx\n", + ) else: await plugin_manager_pm.update_plugin(TEST_PLUGIN_NAME) - assert ("deps", str(local_updator / "requirements.txt")) in events + dep_event = next(event for event in events if event[0] == "deps") + _assert_dependency_install_event_matches( + dep_event, + expected_original_path=local_updator / "requirements.txt", + expected_content="networkx\n", + ) assert ("reload", TEST_PLUGIN_DIR) in events @@ -468,5 +542,144 @@ def mock_load_and_register(*args, **kwargs): await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) - assert ("deps", str(plugin_path / "requirements.txt")) in events + dep_event = next(event for event in events if event[0] == "deps") + _assert_dependency_install_event_matches( + dep_event, + expected_original_path=plugin_path / "requirements.txt", + ) assert ("load", TEST_PLUGIN_DIR) in events + + +@pytest.mark.asyncio +async def test_ensure_plugin_requirements_installs_only_missing_requirement_lines( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text( + "aiohttp>=3.0\nboto3==1.2\nbotocore\n", + encoding="utf-8", + ) + events = [] + _mock_missing_requirements_plan( + monkeypatch, {"boto3", "botocore"}, ["boto3==1.2", "botocore"] + ) + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, False, capture_content=True), + ) + + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert len(events) == 1 + kind, used_path, content = events[0] + assert kind == "deps" + assert used_path != str(requirements_path) + assert content == "boto3==1.2\nbotocore\n" + assert not Path(used_path).exists() + + +@pytest.mark.asyncio +async def test_ensure_plugin_requirements_creates_temp_dir_before_filtered_install( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch, tmp_path +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("boto3\n", encoding="utf-8") + temp_dir = tmp_path / "missing-temp-dir" + events = [] + _mock_missing_requirements_plan(monkeypatch, {"boto3"}, ["boto3"]) + + monkeypatch.setattr( + "astrbot.core.star.star_manager.get_astrbot_temp_path", + lambda: str(temp_dir), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, False, capture_content=True), + ) + + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert temp_dir.is_dir() + assert len(events) == 1 + + +@pytest.mark.asyncio +async def test_ensure_plugin_requirements_falls_back_when_missing_names_have_no_install_lines( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("boto3\n", encoding="utf-8") + events = [] + + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda path: MissingRequirementsPlan( + missing_names=frozenset({"botocore"}), + install_lines=(), + fallback_reason="unmapped missing requirement names", + ), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + _build_dependency_install_mock(events, False), + ) + + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert events == [("deps", str(requirements_path))] + + +@pytest.mark.asyncio +async def test_ensure_plugin_requirements_does_not_mask_install_error_when_cleanup_fails( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch, tmp_path +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("boto3\n", encoding="utf-8") + temp_dir = tmp_path / "cleanup-fails" + _mock_missing_requirements_plan(monkeypatch, {"boto3"}, ["boto3"]) + warning_logs = [] + + async def mock_install_requirements( + *, requirements_path: str | None = None, **kwargs + ): + del kwargs, requirements_path + raise RuntimeError("pip failed") + + original_remove = os.remove + + def flaky_remove(path): + if str(path).endswith("_plugin_requirements.txt"): + raise OSError("cleanup failed") + return original_remove(path) + + monkeypatch.setattr( + "astrbot.core.star.star_manager.get_astrbot_temp_path", + lambda: str(temp_dir), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + mock_install_requirements, + ) + monkeypatch.setattr("astrbot.core.star.star_manager.os.remove", flaky_remove) + monkeypatch.setattr( + "astrbot.core.star.star_manager.logger.warning", + lambda line, *args: warning_logs.append(line % args if args else line), + ) + + with pytest.raises(PluginDependencyInstallError, match="pip failed"): + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert any("删除临时插件依赖文件失败" in log for log in warning_logs)