From ea1fef4990f226a8cbde5bbf014478be3429e1c5 Mon Sep 17 00:00:00 2001 From: stablegenius49 <185121704+stablegenius49@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:09:33 -0700 Subject: [PATCH] fix: clean up duplicate platform adapters on reload --- astrbot/core/platform/manager.py | 48 ++++++++++------ tests/unit/test_platform_manager.py | 88 +++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 17 deletions(-) create mode 100644 tests/unit/test_platform_manager.py diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 68737b2bcf..80090d71aa 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -122,8 +122,17 @@ async def load_platform(self, platform_config: dict) -> None: ) return + platform_id = platform_config["id"] + if platform_id in self._inst_map: + logger.warning( + "平台适配器 %s(%s) 已存在,正在先终止旧实例再重新加载。", + platform_config["type"], + platform_id, + ) + await self.terminate_platform(platform_id) + logger.info( - f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...", + f"载入 {platform_config['type']}({platform_id}) 平台适配器 ...", ) match platform_config["type"]: case "aiocqhttp": @@ -255,24 +264,29 @@ async def reload(self, platform_config: dict) -> None: await self.terminate_platform(key) async def terminate_platform(self, platform_id: str) -> None: - if platform_id in self._inst_map: - logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") + tracked_inst: Platform | None = None + info = self._inst_map.pop(platform_id, None) + if info: + tracked_inst = info["inst"] - # client_id = self._inst_map.pop(platform_id, None) - info = self._inst_map.pop(platform_id) - client_id = info["client_id"] - inst: Platform = info["inst"] - try: - self.platform_insts.remove( - next( - inst - for inst in self.platform_insts - if inst.client_self_id == client_id - ), - ) - except Exception: - logger.warning(f"可能未完全移除 {platform_id} 平台适配器") + insts_to_terminate: list[Platform] = [] + if tracked_inst is not None: + insts_to_terminate.append(tracked_inst) + + for inst in list(self.platform_insts): + if inst in insts_to_terminate: + continue + if getattr(inst, "config", {}).get("id") == platform_id: + insts_to_terminate.append(inst) + + if not insts_to_terminate: + return + + logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") + for inst in insts_to_terminate: + while inst in self.platform_insts: + self.platform_insts.remove(inst) await self._terminate_inst_and_tasks(inst) async def terminate(self) -> None: diff --git a/tests/unit/test_platform_manager.py b/tests/unit/test_platform_manager.py new file mode 100644 index 0000000000..2f25ebe0ac --- /dev/null +++ b/tests/unit/test_platform_manager.py @@ -0,0 +1,88 @@ +import asyncio + +import pytest + +from astrbot.core.platform.manager import PlatformManager +from astrbot.core.platform.platform import Platform +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.platform.register import platform_cls_map + + +class DummyAstrBotConfig(dict): + def save_config(self, replace_config: dict | None = None) -> None: + if replace_config is not None: + self.clear() + self.update(replace_config) + + +class DummyPlatform(Platform): + instances: list["DummyPlatform"] = [] + + def __init__(self, platform_config: dict, platform_settings: dict, event_queue): + super().__init__(platform_config, event_queue) + self.platform_settings = platform_settings + self.terminated = False + self._stop_event = asyncio.Event() + self.__class__.instances.append(self) + + async def _run(self) -> None: + await self._stop_event.wait() + + def run(self): + return self._run() + + async def terminate(self) -> None: + self.terminated = True + self._stop_event.set() + + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + name="dummy", + description="dummy platform", + id=self.config["id"], + support_proactive_message=False, + ) + + +@pytest.fixture +def manager(monkeypatch: pytest.MonkeyPatch) -> PlatformManager: + DummyPlatform.instances.clear() + monkeypatch.setitem(platform_cls_map, "dummy", DummyPlatform) + config = DummyAstrBotConfig({"platform": [], "platform_settings": {}}) + return PlatformManager(config, asyncio.Queue()) + + +@pytest.mark.asyncio +async def test_load_platform_replaces_existing_same_id(manager: PlatformManager): + config = {"id": "default", "type": "dummy", "enable": True} + + await manager.load_platform(config.copy()) + first_inst = DummyPlatform.instances[-1] + + await manager.load_platform(config.copy()) + second_inst = DummyPlatform.instances[-1] + + assert first_inst is not second_inst + assert first_inst.terminated is True + assert second_inst.terminated is False + assert manager._inst_map["default"]["inst"] is second_inst + assert manager.platform_insts == [second_inst] + + +@pytest.mark.asyncio +async def test_terminate_platform_cleans_orphaned_instances(manager: PlatformManager): + config = {"id": "default", "type": "dummy", "enable": True} + + await manager.load_platform(config.copy()) + tracked_inst = DummyPlatform.instances[-1] + + orphan_inst = DummyPlatform(config.copy(), {}, asyncio.Queue()) + manager.platform_insts.append(orphan_inst) + manager._start_platform_task("orphan_default", orphan_inst) + + await manager.terminate_platform("default") + + assert tracked_inst.terminated is True + assert orphan_inst.terminated is True + assert manager.platform_insts == [] + assert "default" not in manager._inst_map