diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index 5c2f7a37f3..c30f88a195 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -1,24 +1,35 @@ import asyncio import json +import re import time -from xml.etree import ElementTree as ET +from collections.abc import Sequence +from typing import cast import websockets from aiohttp import ClientSession, ClientTimeout +from satori import element +from satori.const import EventType +from satori.event import MessageEvent +from satori.model import Event, Identify, Login, Opcode, Ready +from satori.utils import decode, encode from websockets.asyncio.client import ClientConnection, connect from astrbot.api import logger from astrbot.api.event import MessageChain from astrbot.api.message_components import ( At, + AtAll, + Face, File, Image, Plain, Record, Reply, + Video, ) from astrbot.api.platform import ( AstrBotMessage, + Group, MessageMember, MessageType, Platform, @@ -27,6 +38,8 @@ ) from astrbot.core.platform.astr_message_event import MessageSession +b64_cap = re.compile(r"^data:([\w/.+-]+);base64,") + @register_platform_adapter( "satori", "Satori 协议适配器", support_streaming_message=False @@ -64,7 +77,7 @@ def __init__( self.ws: ClientConnection | None = None self.session: ClientSession | None = None self.sequence = 0 - self.logins = [] + self.logins: Sequence[Login] = [] self.running = False self.heartbeat_task: asyncio.Task | None = None self.ready_received = False @@ -73,6 +86,7 @@ async def send_by_session( self, session: MessageSession, message_chain: MessageChain, + referrer: dict | None = None, ) -> None: from .satori_event import SatoriPlatformEvent @@ -80,6 +94,7 @@ async def send_by_session( self, message_chain, session.session_id, + referrer=referrer, ) await super().send_by_session(session, message_chain) @@ -188,19 +203,15 @@ async def send_identify(self) -> None: if self._is_websocket_closed(self.ws): raise Exception("WebSocket连接已关闭") - identify_payload = { - "op": 3, # IDENTIFY - "body": { - "token": str(self.token) if self.token else "", # 字符串 - }, - } - + identify_payload = Identify(token=self.token) # 只有在有序列号时才添加sn字段 if self.sequence > 0: - identify_payload["body"]["sn"] = self.sequence + identify_payload.sn = self.sequence try: - message_str = json.dumps(identify_payload, ensure_ascii=False) + message_str = encode( + {"op": Opcode.IDENTIFY, "body": identify_payload.dump()} + ) await self.ws.send(message_str) except websockets.exceptions.ConnectionClosed as e: logger.error(f"发送 IDENTIFY 信令时连接关闭: {e}") @@ -209,6 +220,35 @@ async def send_identify(self) -> None: logger.error(f"发送 IDENTIFY 信令失败: {e}") raise + try: + response_str = await self.ws.recv() + except websockets.exceptions.ConnectionClosed as e: + logger.error(f"接收 READY 消息时连接关闭: {e}") + raise + except Exception as e: + logger.error(f"接收 READY 消息失败: {e}") + raise + payload = decode(response_str) + op = payload.get("op") + if op != Opcode.READY: + logger.error(f"预期收到 READY 消息,但收到的消息 op 是 {op}") + raise Exception(f"预期收到 READY 消息,但收到的消息 op 是 {op}") + body = payload.get("body", {}) + resp = Ready.parse(body) + self.logins = resp.logins + + # 输出连接成功的bot信息 + for i, login in enumerate(self.logins): + logger.info( + f"Satori 连接成功 - Bot {i + 1}: " + f"platform={login.platform}, " + f"user_id={login.user.id if login.user else ''}, " + f"user_name={login.user.name if login.user else ''}", + ) + if self.logins: + self.ready_received = True + logger.info("Satori 适配器已准备就绪") + async def heartbeat_loop(self) -> None: try: while self.running and self.ws: @@ -216,11 +256,8 @@ async def heartbeat_loop(self) -> None: if self.ws and not self._is_websocket_closed(self.ws): try: - ping_payload = { - "op": 1, # PING - "body": {}, - } - await self.ws.send(json.dumps(ping_payload, ensure_ascii=False)) + ping_payload = {"op": Opcode.PING} + await self.ws.send(encode(ping_payload)) except websockets.exceptions.ConnectionClosed as e: logger.error(f"Satori WebSocket 连接关闭: {e}") break @@ -236,39 +273,21 @@ async def heartbeat_loop(self) -> None: async def handle_message(self, message: str) -> None: try: - data = json.loads(message) + data = decode(message) + op = data.get("op") body = data.get("body", {}) - - if op == 4: # READY - self.logins = body.get("logins", []) - self.ready_received = True - - # 输出连接成功的bot信息 - if self.logins: - for i, login in enumerate(self.logins): - platform = login.get("platform", "") - user = login.get("user", {}) - user_id = user.get("id", "") - user_name = user.get("name", "") - logger.info( - f"Satori 连接成功 - Bot {i + 1}: platform={platform}, user_id={user_id}, user_name={user_name}", - ) - - if "sn" in body: - self.sequence = body["sn"] - - elif op == 2: # PONG + if op == Opcode.PONG: pass - elif op == 0: # EVENT + elif op == Opcode.EVENT: # EVENT await self.handle_event(body) - if "sn" in body: - self.sequence = body["sn"] - elif op == 5: # META - if "sn" in body: - self.sequence = body["sn"] + elif op == Opcode.META: + # TODO: META 消息会携带 satori-server 支持的 proxy_urls, 用于资源链接的下载 + pass + else: + logger.warning(f"收到未知的 WebSocket 消息: {data}") except json.JSONDecodeError as e: logger.error(f"解析 WebSocket 消息失败: {e}, 消息内容: {message}") @@ -277,93 +296,80 @@ async def handle_message(self, message: str) -> None: async def handle_event(self, event_data: dict) -> None: try: - event_type = event_data.get("type") - sn = event_data.get("sn") - if sn: - self.sequence = sn - - if event_type == "message-created": - message = event_data.get("message", {}) - user = event_data.get("user", {}) - channel = event_data.get("channel", {}) - guild = event_data.get("guild") - login = event_data.get("login", {}) - timestamp = event_data.get("timestamp") - - if user.get("id") == login.get("user", {}).get("id"): - return - - abm = await self.convert_satori_message( - message, - user, - channel, - guild, - login, - timestamp, + event = Event.parse(event_data) + except Exception as e: + if ( + "self_id" in event_data + or ("login" in event_data and "self_id" in event_data["login"]) + or ( + "login" in event_data + and "user" in event_data["login"] + and "self_id" in event_data["login"]["user"] ) - if abm: + ): + logger.error(f"解析事件失败: {e}") + else: + logger.debug(f"解析事件失败: {e}") + else: + if event.sn is not None: + self.sequence = event.sn + if event.type == EventType.MESSAGE_CREATED: + if event.user and event.user.id == event.login.user.id: + return + if abm := await self.convert_satori_message(cast(MessageEvent, event)): await self.handle_msg(abm) - except Exception as e: - logger.error(f"处理事件失败: {e}") - async def convert_satori_message( - self, - message: dict, - user: dict, - channel: dict, - guild: dict | None, - login: dict, - timestamp: int | None = None, + self, event: MessageEvent ) -> AstrBotMessage | None: try: abm = AstrBotMessage() - abm.message_id = message.get("id", "") + abm.message_id = event.message.id + abm.timestamp = int(event.timestamp.timestamp()) abm.raw_message = { - "message": message, - "user": user, - "channel": channel, - "guild": guild, - "login": login, + "type": event._type, + "data": event._data, + "message": event.message.dump(), + "user": event.user.dump(), + "channel": event.channel.dump(), + "guild": event.guild.dump() if event.guild else None, + "login": event.login.dump(), + "referrer": event.referrer, } - - if guild and guild.get("id"): - abm.type = MessageType.GROUP_MESSAGE - abm.group_id = guild.get("id", "") - abm.session_id = channel.get("id", "") - else: + channel_id = event.channel.id + if channel_id.startswith("private:"): abm.type = MessageType.FRIEND_MESSAGE - abm.session_id = channel.get("id", "") + abm.session_id = channel_id + else: + abm.type = MessageType.GROUP_MESSAGE + abm.group = Group( + group_id=channel_id, + group_name=event.channel.name, + group_avatar=event.guild.avatar if event.guild else None, + ) + if event.guild and event.guild.id != channel_id: # 二级频道 + abm.session_id = f"{event.guild.id}:{channel_id}" + else: # 一级群组 + abm.session_id = channel_id abm.sender = MessageMember( - user_id=user.get("id", ""), - nickname=user.get("nick", user.get("name", "")), + user_id=event.user.id, + nickname=event.user.nick or event.user.name or "", ) - - abm.self_id = login.get("user", {}).get("id", "") - + abm.self_id = event.login.user.id # 消息链 abm.message = [] - content = message.get("content", "") - - quote = message.get("quote") - content_for_parsing = content # 副本 - - # 提取标签 - if "标签时发生错误: {e}, 错误内容: {content}") - + elements = event.message.message + if raw_quote := event.message._raw_data.get("quote"): + quote: element.Quote | None = element.transform([raw_quote])[0] # type: ignore + elif quotes := element.select(elements, element.Quote): + quote = quotes[0] + else: + quote = None if quote: - # 引用消息 - quote_abm = await self._convert_quote_message(quote) - if quote_abm: + elements = [e for e in elements if not isinstance(e, element.Quote)] + if quote_abm := self._convert_quote_message(quote, abm.self_id): sender_id = quote_abm.sender.user_id if isinstance(sender_id, str) and sender_id.isdigit(): sender_id = int(sender_id) @@ -383,202 +389,51 @@ async def convert_satori_message( abm.message.append(reply_component) # 解析消息内容 - content_elements = await self.parse_satori_elements(content_for_parsing) + content_elements = self.parse_satori_elements(elements) abm.message.extend(content_elements) abm.message_str = "" for comp in content_elements: if isinstance(comp, Plain): abm.message_str += comp.text - - # 优先使用Satori事件中的时间戳 - if timestamp is not None: - abm.timestamp = timestamp - else: - abm.timestamp = int(time.time()) - return abm except Exception as e: logger.error(f"转换 Satori 消息失败: {e}") return None - def _extract_namespace_prefixes(self, content: str) -> set: - """提取XML内容中的命名空间前缀""" - prefixes = set() - - # 查找所有标签 - i = 0 - while i < len(content): - # 查找开始标签 - if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/": - # 找到标签结束位置 - tag_end = content.find(">", i) - if tag_end != -1: - # 提取标签内容 - tag_content = content[i + 1 : tag_end] - # 检查是否有命名空间前缀 - if ":" in tag_content and "xmlns:" not in tag_content: - # 分割标签名 - parts = tag_content.split() - if parts: - tag_name = parts[0] - if ":" in tag_name: - prefix = tag_name.split(":")[0] - # 确保是有效的命名空间前缀 - if ( - prefix.isalnum() - or prefix.replace("_", "").isalnum() - ): - prefixes.add(prefix) - i = tag_end + 1 - else: - i += 1 - # 查找结束标签 - elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/": - # 找到标签结束位置 - tag_end = content.find(">", i) - if tag_end != -1: - # 提取标签内容 - tag_content = content[i + 2 : tag_end] - # 检查是否有命名空间前缀 - if ":" in tag_content: - prefix = tag_content.split(":")[0] - # 确保是有效的命名空间前缀 - if prefix.isalnum() or prefix.replace("_", "").isalnum(): - prefixes.add(prefix) - i = tag_end + 1 - else: - i += 1 - else: - i += 1 - - return prefixes - - async def _extract_quote_element(self, content: str) -> dict | None: - """提取标签信息""" - try: - # 处理命名空间前缀问题 - processed_content = content - if ":" in content and not content.startswith("{content}" - elif not content.startswith("{content}" - else: - processed_content = content - - root = ET.fromstring(processed_content) - - # 查找标签 - quote_element = None - for elem in root.iter(): - tag_name = elem.tag - if "}" in tag_name: - tag_name = tag_name.split("}")[1] - if tag_name.lower() == "quote": - quote_element = elem - break - - if quote_element is not None: - # 提取quote标签的属性 - quote_id = quote_element.get("id", "") - - # 提取标签内部的内容 - inner_content = "" - if quote_element.text: - inner_content += quote_element.text - for child in quote_element: - inner_content += ET.tostring( - child, - encoding="unicode", - method="xml", - ) - if child.tail: - inner_content += child.tail - - # 构造移除了标签的内容 - content_without_quote = content.replace( - ET.tostring(quote_element, encoding="unicode", method="xml"), - "", - ) - - return { - "quote": {"id": quote_id, "content": inner_content}, - "content_without_quote": content_without_quote, - } - - return None - except ET.ParseError as e: - logger.warning(f"XML解析失败,使用正则提取: {e}") - return await self._extract_quote_with_regex(content) - except Exception as e: - logger.error(f"提取标签时发生错误: {e}") - return None - - async def _extract_quote_with_regex(self, content: str) -> dict | None: - """使用正则表达式提取quote标签信息""" - import re - - quote_pattern = r"]*)>(.*?)" - match = re.search(quote_pattern, content, re.DOTALL) - - if not match: - return None - - attrs_str = match.group(1) - inner_content = match.group(2) - - id_match = re.search(r'id\s*=\s*["\']([^"\']*)["\']', attrs_str) - quote_id = id_match.group(1) if id_match else "" - content_without_quote = content.replace(match.group(0), "") - content_without_quote = content_without_quote.strip() - - return { - "quote": {"id": quote_id, "content": inner_content}, - "content_without_quote": content_without_quote, - } - - async def _convert_quote_message(self, quote: dict) -> AstrBotMessage | None: + def _convert_quote_message( + self, quote: element.Quote, self_id: str + ) -> AstrBotMessage | None: """转换引用消息""" try: quote_abm = AstrBotMessage() - quote_abm.message_id = quote.get("id", "") + quote_abm.message_id = quote.id or "" # 解析引用消息的发送者 - quote_author = quote.get("author", {}) - if quote_author: + quote_authors = element.select(quote, element.Author) + if quote_authors: + quote_author = quote_authors[0] quote_abm.sender = MessageMember( - user_id=quote_author.get("id", ""), - nickname=quote_author.get("nick", quote_author.get("name", "")), + user_id=quote_author.id, + nickname=quote_author.name or "", ) else: # 如果没有作者信息,使用默认值 quote_abm.sender = MessageMember( - user_id=quote.get("user_id", ""), + user_id=self_id, nickname="内容", ) # 解析引用消息内容 - quote_content = quote.get("content", "") - quote_abm.message = await self.parse_satori_elements(quote_content) + quote_abm.message = self.parse_satori_elements(quote.children) quote_abm.message_str = "" for comp in quote_abm.message: if isinstance(comp, Plain): quote_abm.message_str += comp.text - quote_abm.timestamp = int(quote.get("timestamp", time.time())) + quote_abm.timestamp = int(time.time()) # 如果没有任何内容,使用默认文本 if not quote_abm.message_str.strip(): @@ -589,136 +444,89 @@ async def _convert_quote_message(self, quote: dict) -> AstrBotMessage | None: logger.error(f"转换引用消息失败: {e}") return None - async def parse_satori_elements(self, content: str) -> list: + def parse_satori_elements(self, elements: list[element.Element]) -> list: """解析 Satori 消息元素""" - elements = [] - - if not content: - return elements - - try: - # 处理命名空间前缀问题 - processed_content = content - if ":" in content and not content.startswith("{content}" - elif not content.startswith("{content}" - else: - processed_content = content - - root = ET.fromstring(processed_content) - await self._parse_xml_node(root, elements) - except ET.ParseError as e: - logger.warning(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}") - # 如果解析失败,将整个内容当作纯文本 - if content.strip(): - elements.append(Plain(text=content)) - except Exception as e: - logger.error(f"解析 Satori 元素时发生未知错误: {e}") - raise e - - # 如果没有解析到任何元素,将整个内容当作纯文本 - if not elements and content.strip(): - elements.append(Plain(text=content)) - - return elements - - async def _parse_xml_node(self, node: ET.Element, elements: list) -> None: - """递归解析 XML 节点""" - if node.text and node.text.strip(): - elements.append(Plain(text=node.text)) - - for child in node: - # 获取标签名,去除命名空间前缀 - tag_name = child.tag - if "}" in tag_name: - tag_name = tag_name.split("}")[1] - tag_name = tag_name.lower() - - attrs = child.attrib - - if tag_name == "at": - user_id = attrs.get("id") or attrs.get("name", "") - elements.append(At(qq=user_id, name=user_id)) - - elif tag_name in ("img", "image"): - src = attrs.get("src", "") - if not src: - continue - elements.append(Image(file=src)) - - elif tag_name == "file": - src = attrs.get("src", "") - name = attrs.get("name", "文件") - if src: - elements.append(File(name=name, file=src)) - - elif tag_name in ("audio", "record"): - src = attrs.get("src", "") - if not src: - continue - elements.append(Record(file=src)) - - elif tag_name == "quote": - # quote标签已经被特殊处理 - pass - - elif tag_name == "face": - face_id = attrs.get("id", "") - face_name = attrs.get("name", "") - face_type = attrs.get("type", "") - - if face_name: - elements.append(Plain(text=f"[表情:{face_name}]")) - elif face_id and face_type: - elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]")) - elif face_id: - elements.append(Plain(text=f"[表情ID:{face_id}]")) + parsed_elements = [] + + for item in elements: + if isinstance(item, element.Text): + parsed_elements.append(Plain(text=item.text)) + elif isinstance(item, element.Sharp): + parsed_elements.append(Plain(text=f"#{item.id}")) + elif isinstance(item, element.Link): + parsed_elements.extend(self.parse_satori_elements(item.children)) + if item.href: + parsed_elements.append(Plain(text=f" ({item.href})")) + elif isinstance(item, element.Br): + parsed_elements.append(Plain(text="\n")) + elif isinstance(item, element.Paragraph): + prev = parsed_elements[-1] if parsed_elements else None + if prev and isinstance(prev, Plain): + if not prev.text.endswith("\n"): + prev.text += "\n" else: - elements.append(Plain(text="[表情]")) - - elif tag_name == "ark": - # 作为纯文本添加到消息链中 - data = attrs.get("data", "") - if data: - import html - - decoded_data = html.unescape(data) - elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]")) + parsed_elements.append(Plain(text="\n")) + parsed_elements.extend(self.parse_satori_elements(item.children)) + parsed_elements.append(Plain(text="\n")) + elif isinstance(item, element.At): + if item.type in ("all", "here", "everyone"): + parsed_elements.append(AtAll()) else: - elements.append(Plain(text="[ARK卡片]")) - - elif tag_name == "json": - # JSON标签 视为ARK卡片消息 - data = attrs.get("data", "") - if data: - import html - - decoded_data = html.unescape(data) - elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]")) + user_id = item.id or "" + parsed_elements.append(At(qq=user_id, name=item.name or user_id)) + elif isinstance(item, element.Image): + file = item.src + if mat := b64_cap.match(item.src): + file = f"base64://{item.src[len(mat[0]) :]}" + parsed_elements.append(Image(file=file)) + elif isinstance(item, element.File): + file = item.src + if mat := b64_cap.match(item.src): + file = f"base64://{item.src[len(mat[0]) :]}" + parsed_elements.append(File(name=item.title or "文件", file=file)) + elif isinstance(item, element.Audio): + file = item.src + if mat := b64_cap.match(item.src): + file = f"base64://{item.src[len(mat[0]) :]}" + parsed_elements.append(Record(file=file)) + elif isinstance(item, element.Video): + file = item.src + if mat := b64_cap.match(item.src): + file = f"base64://{item.src[len(mat[0]) :]}" + parsed_elements.append(Video(file=file)) + elif isinstance(item, element.Emoji): + if item.name: + parsed_elements.append(Plain(text=f"[表情:{item.name}]")) else: - elements.append(Plain(text="[JSON卡片]")) - + parsed_elements.append(Face(id=item.id)) + elif isinstance(item, element.Custom): + if item.tag == "ark": + data = item._attrs.get("data", "") + if data: + import html + + decoded_data = html.unescape(data) + parsed_elements.append( + Plain(text=f"[ARK卡片数据: {decoded_data}]") + ) + else: + parsed_elements.append(Plain(text="[ARK卡片]")) + elif item.tag == "json": + data = item._attrs.get("data", "") + if data: + import html + + decoded_data = html.unescape(data) + parsed_elements.append( + Plain(text=f"[JSON卡片数据: {decoded_data}]") + ) + else: + parsed_elements.append(Plain(text="[JSON卡片]")) + else: + parsed_elements.extend(self.parse_satori_elements(item.children)) else: - # 未知标签,递归处理其内容 - if child.text and child.text.strip(): - elements.append(Plain(text=child.text)) - await self._parse_xml_node(child, elements) - - # 处理标签后的文本 - if child.tail and child.tail.strip(): - elements.append(Plain(text=child.tail)) + parsed_elements.extend(self.parse_satori_elements(item.children)) + return parsed_elements async def handle_msg(self, message: AstrBotMessage) -> None: from .satori_event import SatoriPlatformEvent @@ -751,13 +559,14 @@ async def send_http_request( headers["Authorization"] = f"Bearer {self.token}" if platform and user_id: - headers["satori-platform"] = platform - headers["satori-user-id"] = user_id + headers["Satori-Platform"] = platform + headers["Satori-User-Id"] = user_id elif self.logins: current_login = self.logins[0] - headers["satori-platform"] = current_login.get("platform", "") - user = current_login.get("user", {}) - headers["satori-user-id"] = user.get("id", "") if user else "" + headers["Satori-Platform"] = current_login.platform + headers["Satori-User-Id"] = ( + current_login.user.id if current_login.user else "" + ) if not path.startswith("/"): path = "/" + path diff --git a/astrbot/core/platform/sources/satori/satori_event.py b/astrbot/core/platform/sources/satori/satori_event.py index 0214222837..2231c88381 100644 --- a/astrbot/core/platform/sources/satori/satori_event.py +++ b/astrbot/core/platform/sources/satori/satori_event.py @@ -1,4 +1,10 @@ -from typing import TYPE_CHECKING +from base64 import b64decode +from collections.abc import Callable +from pathlib import Path +from typing import TYPE_CHECKING, TypeVar + +from satori.const import Api +from satori.element import E, Element, Resource from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -15,11 +21,44 @@ Video, ) from astrbot.api.platform import AstrBotMessage, PlatformMetadata +from astrbot.core.message.components import AtAll, Face +from astrbot.core.utils.io import download_image_by_url if TYPE_CHECKING: from .satori_adapter import SatoriPlatformAdapter +TR = TypeVar("TR", bound=Resource) + + +async def _components_to_element( + comp: Image | Record | Video | File, func: Callable[..., TR] +) -> TR: + if hasattr(comp, "url") and comp.url: + return func(url=comp.url) + if not hasattr(comp, "file") or not comp.file: + raise ValueError("No valid file or URL provided") + + if comp.file.startswith("file://"): + path = Path(comp.file[7:]) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + raw_data = path.read_bytes() + return func(raw=raw_data) + if comp.file.startswith("http"): + image_file_path = await download_image_by_url(comp.file) + raw_data = Path(image_file_path).read_bytes() + return func(raw=raw_data) + if comp.file.startswith("base64://"): + bs64_data = comp.file[9:] + return func(raw=b64decode(bs64_data)) + if Path(comp.file).exists(): + raw_data = Path(comp.file).read_bytes() + return func(raw=raw_data) + else: + raise Exception(f"not a valid file: {comp.file}") + + class SatoriPlatformEvent(AstrMessageEvent): def __init__( self, @@ -30,11 +69,10 @@ def __init__( adapter: "SatoriPlatformAdapter", ) -> None: # 更新平台元数据 - if adapter and hasattr(adapter, "logins") and adapter.logins: + if adapter.logins: current_login = adapter.logins[0] - platform_name = current_login.get("platform", "satori") - user = current_login.get("user", {}) - user_id = user.get("id", "") if user else "" + platform_name = current_login.platform or "satori" + user_id = current_login.user.id if current_login.user else None if not platform_meta.id and user_id: platform_meta.id = f"{platform_name}({user_id})" @@ -42,15 +80,13 @@ def __init__( self.adapter = adapter self.platform = None self.user_id = None - if ( - hasattr(message_obj, "raw_message") - and message_obj.raw_message - and isinstance(message_obj.raw_message, dict) - ): + self.referrer = None + if isinstance(message_obj.raw_message, dict): login = message_obj.raw_message.get("login", {}) self.platform = login.get("platform") user = login.get("user", {}) self.user_id = user.get("id") if user else None + self.referrer = message_obj.raw_message.get("referrer") @classmethod async def send_with_adapter( @@ -58,46 +94,43 @@ async def send_with_adapter( adapter: "SatoriPlatformAdapter", message: MessageChain, session_id: str, + referrer: dict | None = None, ): try: content_parts = [] for component in message.chain: - component_content = await cls._convert_component_to_satori_static( + component_content = await cls._convert_component_to_satori( component, ) - if component_content: - content_parts.append(component_content) + content_parts.append(component_content) # 特殊处理 Node 和 Nodes 组件 if isinstance(component, Node): # 单个转发节点 - node_content = await cls._convert_node_to_satori_static(component) - if node_content: - content_parts.append(node_content) + node_content = await cls._convert_node_to_satori(component) + content_parts.append(node_content) elif isinstance(component, Nodes): # 合并转发消息 - node_content = await cls._convert_nodes_to_satori_static(component) - if node_content: - content_parts.append(node_content) + node_content = await cls._convert_nodes_to_satori(component) + content_parts.append(node_content) - content = "".join(content_parts) + content = "".join(str(i) for i in content_parts) channel_id = session_id - data = {"channel_id": channel_id, "content": content} + data = {"channel_id": channel_id, "content": content, "referrer": referrer} platform = None user_id = None - if hasattr(adapter, "logins") and adapter.logins: + if adapter.logins: current_login = adapter.logins[0] - platform = current_login.get("platform", "") - user = current_login.get("user", {}) - user_id = user.get("id", "") if user else "" + platform = current_login.platform or "satori" + user_id = current_login.user.id if current_login.user else None result = await adapter.send_http_request( "POST", - "/message.create", + Api.MESSAGE_CREATE, data, platform, user_id, @@ -115,40 +148,40 @@ async def send(self, message: MessageChain) -> None: user_id = getattr(self, "user_id", None) if not platform or not user_id: - if hasattr(self.adapter, "logins") and self.adapter.logins: + if self.adapter.logins: current_login = self.adapter.logins[0] - platform = current_login.get("platform", "") - user = current_login.get("user", {}) - user_id = user.get("id", "") if user else "" + platform = current_login.platform or "satori" + user_id = current_login.user.id if current_login.user else None try: content_parts = [] for component in message.chain: component_content = await self._convert_component_to_satori(component) - if component_content: - content_parts.append(component_content) + content_parts.append(component_content) # 特殊处理 Node 和 Nodes 组件 if isinstance(component, Node): # 单个转发节点 node_content = await self._convert_node_to_satori(component) - if node_content: - content_parts.append(node_content) + content_parts.append(node_content) elif isinstance(component, Nodes): # 合并转发消息 node_content = await self._convert_nodes_to_satori(component) - if node_content: - content_parts.append(node_content) + content_parts.append(node_content) - content = "".join(content_parts) + content = "".join(str(i) for i in content_parts) channel_id = self.session_id - data = {"channel_id": channel_id, "content": content} + data = { + "channel_id": channel_id, + "content": content, + "referrer": self.referrer, + } result = await self.adapter.send_http_request( "POST", - "/message.create", + Api.MESSAGE_CREATE, data, platform, user_id, @@ -183,19 +216,7 @@ async def send_streaming(self, generator, use_fallback: bool = False): temp_chain = MessageChain([Plain(text=content)]) await self.send(temp_chain) content_parts = [] - try: - image_base64 = await component.convert_to_base64() - if image_base64: - img_chain = MessageChain( - [ - Plain( - text=f'', - ), - ], - ) - await self.send(img_chain) - except Exception as e: - logger.error(f"图片转换为base64失败: {e}") + await self.send(MessageChain([component])) else: content_parts.append(str(component)) @@ -209,224 +230,106 @@ async def send_streaming(self, generator, use_fallback: bool = False): return await super().send_streaming(generator, use_fallback) - async def _convert_component_to_satori(self, component) -> str: + @staticmethod + async def _convert_component_to_satori(component) -> Element: """将单个消息组件转换为 Satori 格式""" try: if isinstance(component, Plain): - text = ( - component.text.replace("&", "&") - .replace("<", "<") - .replace(">", ">") - ) - return text + return E.text(component.text) if isinstance(component, At): - if component.qq: - return f'' - if component.name: - return f'' - - elif isinstance(component, Image): - try: - image_base64 = await component.convert_to_base64() - if image_base64: - return f'' - except Exception as e: - logger.error(f"图片转换为base64失败: {e}") - - elif isinstance(component, File): - return ( - f'' + qq = ( + component.qq + if isinstance(component.qq, str) + else str(component.qq) + if isinstance(component.qq, int) + else None ) + if qq: + return E.at(id=qq, name=component.name) + return E.at(name=component.name) - elif isinstance(component, Record): - try: - record_base64 = await component.convert_to_base64() - if record_base64: - return f'