Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions astrbot/core/platform/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,17 @@ async def load_platform(self, platform_config: dict) -> None:
)
return

platform_id = platform_config["id"]
if platform_id in self._inst_map:
logger.warning(
"平台适配器 %s(%s) 已存在,正在先终止旧实例再重新加载。",
platform_config["type"],
platform_id,
)
await self.terminate_platform(platform_id)

logger.info(
f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...",
f"载入 {platform_config['type']}({platform_id}) 平台适配器 ...",
)
match platform_config["type"]:
case "aiocqhttp":
Expand Down Expand Up @@ -255,24 +264,29 @@ async def reload(self, platform_config: dict) -> None:
await self.terminate_platform(key)

async def terminate_platform(self, platform_id: str) -> None:
if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
tracked_inst: Platform | None = None
info = self._inst_map.pop(platform_id, None)
if info:
tracked_inst = info["inst"]

# client_id = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id)
client_id = info["client_id"]
inst: Platform = info["inst"]
try:
self.platform_insts.remove(
next(
inst
for inst in self.platform_insts
if inst.client_self_id == client_id
),
)
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
insts_to_terminate: list[Platform] = []
if tracked_inst is not None:
insts_to_terminate.append(tracked_inst)

for inst in list(self.platform_insts):
if inst in insts_to_terminate:
continue
if getattr(inst, "config", {}).get("id") == platform_id:
insts_to_terminate.append(inst)

if not insts_to_terminate:
return

logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")

for inst in insts_to_terminate:
while inst in self.platform_insts:
self.platform_insts.remove(inst)
await self._terminate_inst_and_tasks(inst)
Comment on lines +267 to 290
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The terminate_platform method can be refactored for better performance and readability.

  1. Efficient Instance Collection: Using a set comprehension to collect all relevant instances is more efficient than the current loop with list membership checks.
  2. Efficient Instance Removal: A list comprehension is the idiomatic and performant way to filter self.platform_insts, rather than using remove() in a loop.
  3. Concurrent Termination: Terminating instances concurrently with asyncio.gather will improve performance, as these are independent I/O-bound operations.

Applying these changes will make the function more robust and faster.

        info = self._inst_map.pop(platform_id, None)
        tracked_inst = info["inst"] if info else None

        insts_to_terminate = {
            inst
            for inst in self.platform_insts
            if getattr(inst, "config", {}).get("id") == platform_id
        }
        if tracked_inst:
            insts_to_terminate.add(tracked_inst)

        if not insts_to_terminate:
            return

        logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")

        self.platform_insts = [
            inst for inst in self.platform_insts if inst not in insts_to_terminate
        ]
        await asyncio.gather(
            *(self._terminate_inst_and_tasks(inst) for inst in insts_to_terminate)
        )


async def terminate(self) -> None:
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_platform_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import asyncio

import pytest

from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform.platform import Platform
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.platform.register import platform_cls_map


class DummyAstrBotConfig(dict):
def save_config(self, replace_config: dict | None = None) -> None:
if replace_config is not None:
self.clear()
self.update(replace_config)


class DummyPlatform(Platform):
instances: list["DummyPlatform"] = []

def __init__(self, platform_config: dict, platform_settings: dict, event_queue):
super().__init__(platform_config, event_queue)
self.platform_settings = platform_settings
self.terminated = False
self._stop_event = asyncio.Event()
self.__class__.instances.append(self)

async def _run(self) -> None:
await self._stop_event.wait()

def run(self):
return self._run()

async def terminate(self) -> None:
self.terminated = True
self._stop_event.set()

def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="dummy",
description="dummy platform",
id=self.config["id"],
support_proactive_message=False,
)


@pytest.fixture
def manager(monkeypatch: pytest.MonkeyPatch) -> PlatformManager:
DummyPlatform.instances.clear()
monkeypatch.setitem(platform_cls_map, "dummy", DummyPlatform)
config = DummyAstrBotConfig({"platform": [], "platform_settings": {}})
return PlatformManager(config, asyncio.Queue())


@pytest.mark.asyncio
async def test_load_platform_replaces_existing_same_id(manager: PlatformManager):
config = {"id": "default", "type": "dummy", "enable": True}

await manager.load_platform(config.copy())
first_inst = DummyPlatform.instances[-1]

await manager.load_platform(config.copy())
second_inst = DummyPlatform.instances[-1]

assert first_inst is not second_inst
assert first_inst.terminated is True
assert second_inst.terminated is False
assert manager._inst_map["default"]["inst"] is second_inst
assert manager.platform_insts == [second_inst]


@pytest.mark.asyncio
async def test_terminate_platform_cleans_orphaned_instances(manager: PlatformManager):
config = {"id": "default", "type": "dummy", "enable": True}

await manager.load_platform(config.copy())
tracked_inst = DummyPlatform.instances[-1]

orphan_inst = DummyPlatform(config.copy(), {}, asyncio.Queue())
manager.platform_insts.append(orphan_inst)
manager._start_platform_task("orphan_default", orphan_inst)

await manager.terminate_platform("default")

assert tracked_inst.terminated is True
assert orphan_inst.terminated is True
assert manager.platform_insts == []
assert "default" not in manager._inst_map