diff --git a/roborock/devices/rpc/b01_q7_channel.py b/roborock/devices/rpc/b01_q7_channel.py index add5bc97..16170446 100644 --- a/roborock/devices/rpc/b01_q7_channel.py +++ b/roborock/devices/rpc/b01_q7_channel.py @@ -14,7 +14,7 @@ decode_rpc_response, encode_mqtt_payload, ) -from roborock.roborock_message import RoborockMessage +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol _LOGGER = logging.getLogger(__name__) _TIMEOUT = 10.0 @@ -99,3 +99,30 @@ def find_response(response_message: RoborockMessage) -> None: raise finally: unsub() + + +async def send_map_command(mqtt_channel: MqttChannel, request_message: Q7RequestMessage) -> bytes: + """Send map upload command and wait for MAP_RESPONSE payload bytes.""" + + roborock_message = encode_mqtt_payload(request_message) + future: asyncio.Future[bytes] = asyncio.get_running_loop().create_future() + + def find_response(response_message: RoborockMessage) -> None: + if future.done(): + return + + if ( + response_message.protocol == RoborockMessageProtocol.MAP_RESPONSE + and response_message.payload + and response_message.version == roborock_message.version + ): + future.set_result(response_message.payload) + + unsub = await mqtt_channel.subscribe(find_response) + try: + await mqtt_channel.publish(roborock_message) + return await asyncio.wait_for(future, timeout=_TIMEOUT) + except TimeoutError as ex: + raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex + finally: + unsub() diff --git a/roborock/devices/traits/b01/q7/__init__.py b/roborock/devices/traits/b01/q7/__init__.py index 9c09c05c..02a285cd 100644 --- a/roborock/devices/traits/b01/q7/__init__.py +++ b/roborock/devices/traits/b01/q7/__init__.py @@ -21,12 +21,18 @@ from roborock.roborock_typing import RoborockB01Q7Methods from .clean_summary import CleanSummaryTrait +from .map import MapTrait, Q7MapList, Q7MapListEntry __all__ = [ "Q7PropertiesApi", "CleanSummaryTrait", + "MapTrait", + "Q7MapList", + "Q7MapListEntry", ] +_Q7_DPS = 10000 + class Q7PropertiesApi(Trait): """API for interacting with B01 devices.""" @@ -34,10 +40,14 @@ class Q7PropertiesApi(Trait): clean_summary: CleanSummaryTrait """Trait for clean records / clean summary (Q7 `service.get_record_list`).""" + map: MapTrait + """Trait for map list metadata + raw map payload retrieval.""" + def __init__(self, channel: MqttChannel) -> None: """Initialize the B01Props API.""" self._channel = channel self.clean_summary = CleanSummaryTrait(channel) + self.map = MapTrait(channel) async def query_values(self, props: list[RoborockB01Props]) -> B01Props | None: """Query the device for the values of the given Q7 properties.""" @@ -87,6 +97,17 @@ async def start_clean(self) -> None: }, ) + async def clean_segments(self, segment_ids: list[int]) -> None: + """Start segment cleaning for the given ids (Q7 uses room ids).""" + await self.send( + command=RoborockB01Q7Methods.SET_ROOM_CLEAN, + params={ + "clean_type": CleanTaskTypeMapping.ROOM.code, + "ctrl_value": SCDeviceCleanParam.START.code, + "room_ids": segment_ids, + }, + ) + async def pause_clean(self) -> None: """Pause cleaning.""" await self.send( @@ -127,7 +148,7 @@ async def send(self, command: CommandType, params: ParamsType) -> Any: """Send a command to the device.""" return await send_decoded_command( self._channel, - Q7RequestMessage(dps=10000, command=command, params=params), + Q7RequestMessage(dps=_Q7_DPS, command=command, params=params), ) diff --git a/roborock/devices/traits/b01/q7/map.py b/roborock/devices/traits/b01/q7/map.py new file mode 100644 index 00000000..b98524c8 --- /dev/null +++ b/roborock/devices/traits/b01/q7/map.py @@ -0,0 +1,101 @@ +"""Map trait for B01 Q7 devices.""" + +import asyncio +from dataclasses import dataclass, field + +from roborock.data import RoborockBase +from roborock.devices.rpc.b01_q7_channel import send_decoded_command, send_map_command +from roborock.devices.traits import Trait +from roborock.devices.transport.mqtt_channel import MqttChannel +from roborock.exceptions import RoborockException +from roborock.protocols.b01_q7_protocol import Q7RequestMessage +from roborock.roborock_typing import RoborockB01Q7Methods + +_Q7_DPS = 10000 + + +@dataclass +class Q7MapListEntry(RoborockBase): + """Single map list entry returned by `service.get_map_list`.""" + + id: int | None = None + cur: bool | None = None + + +@dataclass +class Q7MapList(RoborockBase): + """Map list response returned by `service.get_map_list`.""" + + map_list: list[Q7MapListEntry] = field(default_factory=list) + + +class MapTrait(Trait): + """Map retrieval + map metadata helpers for Q7 devices.""" + + def __init__(self, channel: MqttChannel) -> None: + self._channel = channel + # Map uploads are serialized per-device to avoid response cross-wiring. + self._map_command_lock = asyncio.Lock() + self._map_list: Q7MapList | None = None + + @property + def map_list(self) -> Q7MapList | None: + """Latest cached map list metadata, populated by ``refresh()``.""" + return self._map_list + + @property + def current_map_id(self) -> int | None: + """Current map id derived from cached map list metadata.""" + if self._map_list is None: + return None + return self._extract_current_map_id(self._map_list) + + async def refresh(self) -> None: + """Refresh cached map list metadata from the device.""" + response = await send_decoded_command( + self._channel, + Q7RequestMessage(dps=_Q7_DPS, command=RoborockB01Q7Methods.GET_MAP_LIST, params={}), + ) + if not isinstance(response, dict): + raise TypeError(f"Unexpected response type for GET_MAP_LIST: {type(response).__name__}: {response!r}") + + parsed = Q7MapList.from_dict(response) + if parsed is None: + raise TypeError(f"Failed to decode map list response: {response!r}") + + self._map_list = parsed + + async def _get_map_payload(self, *, map_id: int) -> bytes: + """Fetch raw map payload bytes for the given map id.""" + request = Q7RequestMessage( + dps=_Q7_DPS, + command=RoborockB01Q7Methods.UPLOAD_BY_MAPID, + params={"map_id": map_id}, + ) + async with self._map_command_lock: + return await send_map_command(self._channel, request) + + async def get_current_map_payload(self) -> bytes: + """Fetch raw map payload bytes for the currently selected map.""" + if self._map_list is None: + await self.refresh() + + map_id = self.current_map_id + if map_id is None: + raise RoborockException(f"Unable to determine map_id from map list response: {self._map_list!r}") + return await self._get_map_payload(map_id=map_id) + + @staticmethod + def _extract_current_map_id(map_list_response: Q7MapList) -> int | None: + map_list = map_list_response.map_list + if not map_list: + return None + + for entry in map_list: + if entry.cur and isinstance(entry.id, int): + return entry.id + + first = map_list[0] + if isinstance(first.id, int): + return first.id + return None diff --git a/tests/devices/traits/b01/q7/test_init.py b/tests/devices/traits/b01/q7/test_init.py index cb16299c..403cc04d 100644 --- a/tests/devices/traits/b01/q7/test_init.py +++ b/tests/devices/traits/b01/q7/test_init.py @@ -17,7 +17,7 @@ from roborock.devices.traits.b01.q7 import Q7PropertiesApi from roborock.exceptions import RoborockException from roborock.protocols.b01_q7_protocol import B01_VERSION, Q7RequestMessage -from roborock.roborock_message import RoborockB01Props, RoborockMessageProtocol +from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol from tests.fixtures.channel_fixtures import FakeChannel from . import B01MessageBuilder @@ -257,3 +257,108 @@ async def test_q7_api_find_me(q7_api: Q7PropertiesApi, fake_channel: FakeChannel payload_data = json.loads(unpad(message.payload, AES.block_size)) assert payload_data["dps"]["10000"]["method"] == "service.find_device" assert payload_data["dps"]["10000"]["params"] == {} + + +async def test_q7_api_clean_segments( + q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder +): + """Test room/segment cleaning helper for Q7.""" + fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + await q7_api.clean_segments([10, 11]) + + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + payload_data = json.loads(unpad(message.payload, AES.block_size)) + assert payload_data["dps"]["10000"]["method"] == "service.set_room_clean" + assert payload_data["dps"]["10000"]["params"] == { + "clean_type": CleanTaskTypeMapping.ROOM.code, + "ctrl_value": SCDeviceCleanParam.START.code, + "room_ids": [10, 11], + } + + +async def test_q7_api_get_current_map_payload( + q7_api: Q7PropertiesApi, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, +): + """Fetch current map by map-list lookup, then upload_by_mapid.""" + fake_channel.response_queue.append(message_builder.build({"map_list": [{"id": 1772093512, "cur": True}]})) + fake_channel.response_queue.append( + RoborockMessage( + protocol=RoborockMessageProtocol.MAP_RESPONSE, + payload=b"raw-map-payload", + version=b"B01", + seq=message_builder.seq + 1, + ) + ) + + raw_payload = await q7_api.map.get_current_map_payload() + assert raw_payload == b"raw-map-payload" + + assert len(fake_channel.published_messages) == 2 + + first = fake_channel.published_messages[0] + first_payload = json.loads(unpad(first.payload, AES.block_size)) + assert first_payload["dps"]["10000"]["method"] == "service.get_map_list" + assert first_payload["dps"]["10000"]["params"] == {} + + second = fake_channel.published_messages[1] + second_payload = json.loads(unpad(second.payload, AES.block_size)) + assert second_payload["dps"]["10000"]["method"] == "service.upload_by_mapid" + assert second_payload["dps"]["10000"]["params"] == {"map_id": 1772093512} + + +async def test_q7_api_map_trait_refresh_populates_cached_values( + q7_api: Q7PropertiesApi, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, +): + """Map trait follows refresh + cached-value access pattern.""" + fake_channel.response_queue.append(message_builder.build({"map_list": [{"id": 101, "cur": True}]})) + + assert q7_api.map.map_list is None + assert q7_api.map.current_map_id is None + + await q7_api.map.refresh() + + assert len(fake_channel.published_messages) == 1 + assert q7_api.map.map_list is not None + assert q7_api.map.map_list.map_list[0].id == 101 + assert q7_api.map.map_list.map_list[0].cur is True + assert q7_api.map.current_map_id == 101 + + +async def test_q7_api_get_current_map_payload_falls_back_to_first_map( + q7_api: Q7PropertiesApi, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, +): + """If no current map marker exists, first map in list is used.""" + fake_channel.response_queue.append(message_builder.build({"map_list": [{"id": 111}, {"id": 222, "cur": False}]})) + fake_channel.response_queue.append( + RoborockMessage( + protocol=RoborockMessageProtocol.MAP_RESPONSE, + payload=b"raw-map-payload", + version=b"B01", + seq=message_builder.seq + 1, + ) + ) + + await q7_api.map.get_current_map_payload() + + second = fake_channel.published_messages[1] + second_payload = json.loads(unpad(second.payload, AES.block_size)) + assert second_payload["dps"]["10000"]["params"] == {"map_id": 111} + + +async def test_q7_api_get_current_map_payload_errors_without_map_list( + q7_api: Q7PropertiesApi, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, +): + """Current-map payload fetch should fail clearly when map list is unusable.""" + fake_channel.response_queue.append(message_builder.build({"map_list": []})) + + with pytest.raises(RoborockException, match="Unable to determine map_id"): + await q7_api.map.get_current_map_payload()