diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index a8ff0fdb90..10c14e66e3 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -2,7 +2,7 @@ import logging from contextlib import AsyncExitStack from datetime import timedelta -from typing import Generic +from typing import Any, Generic from tenacity import ( before_sleep_log, @@ -13,9 +13,16 @@ ) from astrbot import logger +from astrbot.core.agent.mcp_prompt_bridge import build_mcp_prompt_tool_names from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe +from .mcp_resource_bridge import build_mcp_resource_tool_names +from .mcp_stdio_client import tolerant_stdio_client +from .mcp_subcapability_bridge import ( + MCPClientSubCapabilityBridge, + normalize_mcp_server_config, +) from .run_context import TContext from .tool import FunctionTool @@ -41,7 +48,10 @@ def _prepare_config(config: dict) -> dict: if config.get("mcpServers"): first_key = next(iter(config["mcpServers"])) config = config["mcpServers"][first_key] + config = normalize_mcp_server_config(config) config.pop("active", None) + config.pop("client_capabilities", None) + config.pop("provider", None) return config @@ -117,14 +127,22 @@ def __init__(self) -> None: self.name: str | None = None self.active: bool = True self.tools: list[mcp.Tool] = [] + self.prompts: list[mcp.types.Prompt] = [] + self.prompt_bridge_tool_names: list[str] = [] + self.resources: list[mcp.types.Resource] = [] + self.resource_templates: list[mcp.types.ResourceTemplate] = [] + self.resource_templates_supported: bool = False + self.resource_bridge_tool_names: list[str] = [] self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() # Store connection config for reconnection self._mcp_server_config: dict | None = None self._server_name: str | None = None + self._server_capabilities: mcp.types.ServerCapabilities | None = None self._reconnect_lock = 
asyncio.Lock() # Lock for thread-safe reconnection self._reconnecting: bool = False # For logging and debugging + self.subcapability_bridge = MCPClientSubCapabilityBridge[Any]() async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """Connect to MCP server @@ -141,6 +159,8 @@ async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: # Store config for reconnection self._mcp_server_config = mcp_server_config self._server_name = name + self.subcapability_bridge.set_server_name(name) + self.subcapability_bridge.configure_from_server_config(mcp_server_config) cfg = _prepare_config(mcp_server_config.copy()) @@ -184,6 +204,22 @@ def logging_callback( *streams, read_timeout_seconds=read_timeout, logging_callback=logging_callback, # type: ignore + sampling_callback=( + self.subcapability_bridge.handle_sampling + if self.subcapability_bridge.sampling_enabled + else None + ), + elicitation_callback=( + self.subcapability_bridge.handle_elicitation + if self.subcapability_bridge.elicitation_enabled + else None + ), + list_roots_callback=( + self.subcapability_bridge.handle_list_roots + if self.subcapability_bridge.roots_enabled + else None + ), + sampling_capabilities=self.subcapability_bridge.get_sampling_capabilities(), ), ) else: @@ -210,6 +246,22 @@ def logging_callback( write_stream=write_s, read_timeout_seconds=read_timeout, logging_callback=logging_callback, # type: ignore + sampling_callback=( + self.subcapability_bridge.handle_sampling + if self.subcapability_bridge.sampling_enabled + else None + ), + elicitation_callback=( + self.subcapability_bridge.handle_elicitation + if self.subcapability_bridge.elicitation_enabled + else None + ), + list_roots_callback=( + self.subcapability_bridge.handle_list_roots + if self.subcapability_bridge.roots_enabled + else None + ), + sampling_capabilities=self.subcapability_bridge.get_sampling_capabilities(), ), ) @@ -232,7 +284,7 @@ def callback(msg: str | 
mcp.types.LoggingMessageNotificationParams) -> None: self.server_errlogs.append(log_msg) stdio_transport = await self.exit_stack.enter_async_context( - mcp.stdio_client( + tolerant_stdio_client( server_params, errlog=LogPipe( level=logging.INFO, @@ -245,9 +297,41 @@ def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None: # Create a new client session self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession(*stdio_transport), + mcp.ClientSession( + *stdio_transport, + sampling_callback=( + self.subcapability_bridge.handle_sampling + if self.subcapability_bridge.sampling_enabled + else None + ), + elicitation_callback=( + self.subcapability_bridge.handle_elicitation + if self.subcapability_bridge.elicitation_enabled + else None + ), + list_roots_callback=( + self.subcapability_bridge.handle_list_roots + if self.subcapability_bridge.roots_enabled + else None + ), + sampling_capabilities=self.subcapability_bridge.get_sampling_capabilities(), + ), ) await self.session.initialize() + get_server_capabilities = getattr( + self.session, + "get_server_capabilities", + None, + ) + self._server_capabilities = ( + get_server_capabilities() if callable(get_server_capabilities) else None + ) + self.resources = [] + self.resource_templates = [] + self.resource_templates_supported = False + self.prompts = [] + self.prompt_bridge_tool_names = [] + self.resource_bridge_tool_names = [] async def list_tools_and_save(self) -> mcp.ListToolsResult: """List all tools from the server and save them to self.tools""" @@ -257,6 +341,120 @@ async def list_tools_and_save(self) -> mcp.ListToolsResult: self.tools = response.tools return response + @property + def supports_resources(self) -> bool: + return bool(self._server_capabilities and self._server_capabilities.resources) + + @property + def supports_prompts(self) -> bool: + return bool(self._server_capabilities and self._server_capabilities.prompts) + + async def load_resource_capabilities(self) -> 
None: + self.resources = [] + self.resource_templates = [] + self.resource_templates_supported = False + self.resource_bridge_tool_names = [] + + if not self._server_name or not self.supports_resources: + return + + try: + await self.list_resources_and_save() + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to preload MCP resources for server %s: %s", + self._server_name, + exc, + ) + + try: + await self.list_resource_templates_and_save() + except Exception as exc: # noqa: BLE001 + logger.debug( + "Skipping MCP resource templates for server %s: %s", + self._server_name, + exc, + ) + + self.resource_bridge_tool_names = build_mcp_resource_tool_names( + self._server_name, + include_templates=self.resource_templates_supported, + ) + + async def load_prompt_capabilities(self) -> None: + self.prompts = [] + self.prompt_bridge_tool_names = [] + + if not self._server_name or not self.supports_prompts: + return + + try: + await self.list_prompts_and_save() + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to preload MCP prompts for server %s: %s", + self._server_name, + exc, + ) + + self.prompt_bridge_tool_names = build_mcp_prompt_tool_names( + self._server_name, + ) + + async def list_prompts_and_save( + self, + cursor: str | None = None, + ) -> mcp.types.ListPromptsResult: + if not self.session: + raise ValueError("MCP session is not available for prompt listing.") + + params = ( + mcp.types.PaginatedRequestParams(cursor=cursor) + if cursor is not None + else None + ) + response = await self.session.list_prompts(params=params) + if cursor is None: + self.prompts = response.prompts + return response + + async def list_resources_and_save( + self, + cursor: str | None = None, + ) -> mcp.types.ListResourcesResult: + if not self.session: + raise ValueError("MCP session is not available for resource listing.") + + params = ( + mcp.types.PaginatedRequestParams(cursor=cursor) + if cursor is not None + else None + ) + response = await 
self.session.list_resources(params=params) + if cursor is None: + self.resources = response.resources + return response + + async def list_resource_templates_and_save( + self, + cursor: str | None = None, + ) -> mcp.types.ListResourceTemplatesResult: + if not self.session: + raise ValueError( + "MCP session is not available for resource template listing." + ) + + params = ( + mcp.types.PaginatedRequestParams(cursor=cursor) + if cursor is not None + else None + ) + response = await self.session.list_resource_templates(params=params) + self.resource_templates_supported = True + if cursor is None: + self.resource_templates = response.resourceTemplates + return response + async def _reconnect(self) -> None: """Reconnect to the MCP server using the stored configuration. @@ -281,6 +479,7 @@ async def _reconnect(self) -> None: logger.info( f"Attempting to reconnect to MCP server {self._server_name}..." ) + self.subcapability_bridge.clear_runtime_state() # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues) if self.exit_stack: @@ -295,6 +494,8 @@ async def _reconnect(self) -> None: # Reconnect using stored config await self.connect_to_server(self._mcp_server_config, self._server_name) await self.list_tools_and_save() + await self.load_resource_capabilities() + await self.load_prompt_capabilities() logger.info( f"Successfully reconnected to MCP server {self._server_name}" @@ -312,6 +513,7 @@ async def call_tool_with_reconnect( tool_name: str, arguments: dict, read_timeout_seconds: timedelta, + run_context: ContextWrapper[Any] | None = None, ) -> mcp.types.CallToolResult: """Call MCP tool with automatic reconnection on failure, max 2 retries. @@ -336,28 +538,102 @@ async def call_tool_with_reconnect( reraise=True, ) async def _call_with_retry(): + async with self.subcapability_bridge.interactive_call(run_context): + if not self.session: + raise ValueError( + "MCP session is not available for MCP function tools." 
+ ) + + try: + return await self.session.call_tool( + name=tool_name, + arguments=arguments, + read_timeout_seconds=read_timeout_seconds, + ) + except anyio.ClosedResourceError: + logger.warning( + f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." + ) + # Attempt to reconnect + await self._reconnect() + # Reraise the exception to trigger tenacity retry + raise + + return await _call_with_retry() + + async def read_resource_with_reconnect( + self, + uri: str, + read_timeout_seconds: timedelta, + ) -> mcp.types.ReadResourceResult: + _ = read_timeout_seconds + + @retry( + retry=retry_if_exception_type(anyio.ClosedResourceError), + stop=stop_after_attempt(2), + wait=wait_exponential(multiplier=1, min=1, max=3), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + async def _read_with_retry(): if not self.session: - raise ValueError("MCP session is not available for MCP function tools.") + raise ValueError("MCP session is not available for MCP resources.") try: - return await self.session.call_tool( - name=tool_name, + return await self.session.read_resource(uri=uri) + except anyio.ClosedResourceError: + logger.warning( + "MCP resource read for %s failed (ClosedResourceError), attempting to reconnect...", + uri, + ) + await self._reconnect() + raise + + return await _read_with_retry() + + async def get_prompt_with_reconnect( + self, + name: str, + arguments: dict[str, str] | None, + read_timeout_seconds: timedelta, + ) -> mcp.types.GetPromptResult: + @retry( + retry=retry_if_exception_type(anyio.ClosedResourceError), + stop=stop_after_attempt(2), + wait=wait_exponential(multiplier=1, min=1, max=3), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + async def _get_with_retry(): + if not self.session: + raise ValueError("MCP session is not available for MCP prompts.") + + try: + return await self.session.get_prompt( + name=name, arguments=arguments, - 
read_timeout_seconds=read_timeout_seconds, ) except anyio.ClosedResourceError: logger.warning( - f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." + "MCP prompt read for %s failed (ClosedResourceError), attempting to reconnect...", + name, ) - # Attempt to reconnect await self._reconnect() - # Reraise the exception to trigger tenacity retry raise - return await _call_with_retry() + _ = read_timeout_seconds + return await _get_with_retry() async def cleanup(self) -> None: """Clean up resources including old exit stacks from reconnections""" + self.subcapability_bridge.clear_runtime_state() + self._server_capabilities = None + self.prompts = [] + self.prompt_bridge_tool_names = [] + self.resources = [] + self.resource_templates = [] + self.resource_templates_supported = False + self.resource_bridge_tool_names = [] # Close current exit stack try: await self.exit_stack.aclose() @@ -395,4 +671,5 @@ async def call( tool_name=self.mcp_tool.name, arguments=kwargs, read_timeout_seconds=timedelta(seconds=context.tool_call_timeout), + run_context=context, ) diff --git a/astrbot/core/agent/mcp_elicitation_registry.py b/astrbot/core/agent/mcp_elicitation_registry.py new file mode 100644 index 0000000000..767c77bf73 --- /dev/null +++ b/astrbot/core/agent/mcp_elicitation_registry.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +@dataclass(slots=True) +class MCPElicitationReply: + message_text: str + message_outline: str + + +@dataclass(slots=True) +class PendingMCPElicitation: + umo: str + sender_id: str + future: asyncio.Future[MCPElicitationReply] + + +_PENDING_MCP_ELICITATIONS: dict[str, PendingMCPElicitation] = {} + + +@asynccontextmanager +async def 
pending_mcp_elicitation( + umo: str, + sender_id: str, +) -> AsyncIterator[asyncio.Future[MCPElicitationReply]]: + loop = asyncio.get_running_loop() + future: asyncio.Future[MCPElicitationReply] = loop.create_future() + + current = _PENDING_MCP_ELICITATIONS.get(umo) + if current is not None and not current.future.done(): + raise RuntimeError( + f"Another MCP elicitation is already pending for session {umo}." + ) + + pending = PendingMCPElicitation( + umo=umo, + sender_id=sender_id, + future=future, + ) + _PENDING_MCP_ELICITATIONS[umo] = pending + + try: + yield future + finally: + current = _PENDING_MCP_ELICITATIONS.get(umo) + if current is pending: + _PENDING_MCP_ELICITATIONS.pop(umo, None) + if not future.done(): + future.cancel() + + +def try_capture_pending_mcp_elicitation(event: AstrMessageEvent) -> bool: + pending = _PENDING_MCP_ELICITATIONS.get(event.unified_msg_origin) + if pending is None: + return False + + sender_id = event.get_sender_id() + if not sender_id or sender_id != pending.sender_id: + return False + + if pending.future.done(): + _PENDING_MCP_ELICITATIONS.pop(event.unified_msg_origin, None) + return False + + pending.future.set_result( + MCPElicitationReply( + message_text=event.get_message_str() or "", + message_outline=event.get_message_outline(), + ) + ) + return True + + +def submit_pending_mcp_elicitation_reply( + umo: str, + sender_id: str, + reply_text: str, + *, + reply_outline: str | None = None, +) -> bool: + pending = _PENDING_MCP_ELICITATIONS.get(umo) + if pending is None or pending.sender_id != sender_id: + return False + + if pending.future.done(): + _PENDING_MCP_ELICITATIONS.pop(umo, None) + return False + + pending.future.set_result( + MCPElicitationReply( + message_text=reply_text, + message_outline=reply_outline or reply_text, + ) + ) + return True diff --git a/astrbot/core/agent/mcp_prompt_bridge.py b/astrbot/core/agent/mcp_prompt_bridge.py new file mode 100644 index 0000000000..b086accd5f --- /dev/null +++ 
b/astrbot/core/agent/mcp_prompt_bridge.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +import re +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Generic + +import mcp + +from astrbot.core.agent.run_context import ContextWrapper, TContext +from astrbot.core.agent.tool import FunctionTool + +if TYPE_CHECKING: + from .mcp_client import MCPClient + + +def build_mcp_prompt_tool_names(server_name: str) -> list[str]: + safe_server_name = _sanitize_tool_name_fragment(server_name) + return [ + f"mcp_{safe_server_name}_list_prompts", + f"mcp_{safe_server_name}_get_prompt", + ] + + +def build_mcp_prompt_tools( + mcp_client: MCPClient, + server_name: str, +) -> list[MCPPromptTool[TContext]]: + if not getattr(mcp_client, "supports_prompts", False): + return [] + + return [ + MCPListPromptsTool( + mcp_client=mcp_client, + mcp_server_name=server_name, + ), + MCPGetPromptTool( + mcp_client=mcp_client, + mcp_server_name=server_name, + ), + ] + + +class MCPPromptTool(FunctionTool, Generic[TContext]): + """Server-scoped synthetic tool for MCP prompts.""" + + def __init__(self, *, name: str, description: str, parameters: dict) -> None: + super().__init__( + name=name, + description=description, + parameters=parameters, + ) + self.mcp_client: MCPClient + self.mcp_server_name: str + + +class MCPListPromptsTool(MCPPromptTool[TContext]): + def __init__(self, *, mcp_client: MCPClient, mcp_server_name: str) -> None: + super().__init__( + name=build_mcp_prompt_tool_names(mcp_server_name)[0], + description=( + f"List MCP prompts exposed by server '{mcp_server_name}'. " + "Use this before getting a specific prompt template." + ), + parameters={ + "type": "object", + "properties": { + "cursor": { + "type": "string", + "description": ( + "Optional pagination cursor returned by a previous " + "prompt listing call." 
+ ), + } + }, + }, + ) + self.mcp_client = mcp_client + self.mcp_server_name = mcp_server_name + + async def call( + self, + context: ContextWrapper[TContext], + **kwargs, + ) -> mcp.types.CallToolResult: + _ = context + response = await self.mcp_client.list_prompts_and_save( + cursor=kwargs.get("cursor"), + ) + return _text_result( + _format_prompts_listing( + server_name=self.mcp_server_name, + response=response, + ) + ) + + +class MCPGetPromptTool(MCPPromptTool[TContext]): + def __init__(self, *, mcp_client: MCPClient, mcp_server_name: str) -> None: + super().__init__( + name=build_mcp_prompt_tool_names(mcp_server_name)[1], + description=( + f"Get a specific MCP prompt from server '{mcp_server_name}' by " + "name, optionally providing prompt arguments." + ), + parameters={ + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The MCP prompt name to resolve.", + }, + "arguments": { + "type": "object", + "description": ( + "Optional prompt arguments. Keys and values are sent to " + "the MCP server as strings." 
+ ), + "additionalProperties": { + "type": "string", + }, + }, + }, + "required": ["name"], + }, + ) + self.mcp_client = mcp_client + self.mcp_server_name = mcp_server_name + + async def call( + self, + context: ContextWrapper[TContext], + **kwargs, + ) -> mcp.types.CallToolResult: + read_timeout = timedelta(seconds=context.tool_call_timeout) + name = str(kwargs["name"]) + response = await self.mcp_client.get_prompt_with_reconnect( + name=name, + arguments=_normalize_prompt_arguments(kwargs.get("arguments")), + read_timeout_seconds=read_timeout, + ) + return _text_result( + shape_get_prompt_result( + server_name=self.mcp_server_name, + prompt_name=name, + response=response, + ) + ) + + +def shape_get_prompt_result( + *, + server_name: str, + prompt_name: str, + response: mcp.types.GetPromptResult, +) -> str: + lines = [ + f"MCP prompt from server '{server_name}':", + f"Prompt: {prompt_name}", + ] + if response.description: + lines.append(f"Description: {response.description}") + + if not response.messages: + lines.append("No prompt messages were returned.") + return "\n".join(lines) + + lines.append(f"Returned messages: {len(response.messages)}") + for idx, message in enumerate(response.messages, start=1): + lines.append("") + lines.append(f"Message {idx} ({message.role}):") + lines.extend(_format_prompt_message_content(message.content)) + return "\n".join(lines) + + +def _text_result(text: str) -> mcp.types.CallToolResult: + return mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text=text)] + ) + + +def _format_prompts_listing( + *, + server_name: str, + response: mcp.types.ListPromptsResult, +) -> str: + if not response.prompts: + text = f"No MCP prompts are currently exposed by server '{server_name}'." 
+ if response.nextCursor: + text += f"\nNext cursor: {response.nextCursor}" + return text + + lines = [f"MCP prompts from server '{server_name}':"] + for idx, prompt in enumerate(response.prompts, start=1): + lines.extend(_format_prompt_metadata(idx, prompt)) + if response.nextCursor: + lines.append(f"Next cursor: {response.nextCursor}") + return "\n".join(lines) + + +def _format_prompt_metadata(index: int, prompt: mcp.types.Prompt) -> list[str]: + lines = [f"{index}. {prompt.name}"] + if prompt.title: + lines.append(f" Title: {prompt.title}") + if prompt.description: + lines.append(f" Description: {prompt.description}") + if prompt.arguments: + lines.append(" Arguments:") + for argument in prompt.arguments: + lines.append(_format_prompt_argument(argument)) + return lines + + +def _format_prompt_argument(argument: mcp.types.PromptArgument) -> str: + required_suffix = "required" if argument.required else "optional" + if argument.description: + return f" - {argument.name} ({required_suffix}): {argument.description}" + return f" - {argument.name} ({required_suffix})" + + +def _format_prompt_message_content( + content: mcp.types.ContentBlock, +) -> list[str]: + if isinstance(content, mcp.types.TextContent): + return content.text.splitlines() or [content.text] + if isinstance(content, mcp.types.ImageContent): + return [ + "Image block returned.", + f"MIME type: {content.mimeType}", + f"Base64 length: {len(content.data)}", + ] + if isinstance(content, mcp.types.AudioContent): + return [ + "Audio block returned.", + f"MIME type: {content.mimeType}", + f"Base64 length: {len(content.data)}", + ] + if isinstance(content, mcp.types.EmbeddedResource): + resource = content.resource + if isinstance(resource, mcp.types.TextResourceContents): + lines = [ + "Embedded text resource returned.", + f"URI: {resource.uri}", + ] + if resource.mimeType: + lines.append(f"MIME type: {resource.mimeType}") + lines.append("Text:") + lines.extend(resource.text.splitlines() or [resource.text]) + 
return lines + if isinstance(resource, mcp.types.BlobResourceContents): + lines = [ + "Embedded binary resource returned.", + f"URI: {resource.uri}", + ] + if resource.mimeType: + lines.append(f"MIME type: {resource.mimeType}") + lines.append(f"Base64 length: {len(resource.blob)}") + return lines + return [f"Unsupported prompt content block: {type(content).__name__}"] + + +def _normalize_prompt_arguments( + raw_arguments: Any, +) -> dict[str, str] | None: + if raw_arguments is None: + return None + if isinstance(raw_arguments, str): + stripped = raw_arguments.strip() + if not stripped: + return None + try: + parsed = json.loads(stripped) + except json.JSONDecodeError: + return None + raw_arguments = parsed + if not isinstance(raw_arguments, dict): + return None + normalized: dict[str, str] = {} + for key, value in raw_arguments.items(): + key_text = str(key).strip() + if not key_text: + continue + normalized[key_text] = "" if value is None else str(value) + return normalized or None + + +def _sanitize_tool_name_fragment(name: str) -> str: + sanitized = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_").lower() + return sanitized or "server" diff --git a/astrbot/core/agent/mcp_resource_bridge.py b/astrbot/core/agent/mcp_resource_bridge.py new file mode 100644 index 0000000000..3663e44a60 --- /dev/null +++ b/astrbot/core/agent/mcp_resource_bridge.py @@ -0,0 +1,361 @@ +from __future__ import annotations + +import re +from datetime import timedelta +from typing import TYPE_CHECKING, Generic + +import mcp + +from astrbot.core.agent.run_context import ContextWrapper, TContext +from astrbot.core.agent.tool import FunctionTool + +if TYPE_CHECKING: + from .mcp_client import MCPClient + + +def build_mcp_resource_tool_names( + server_name: str, + *, + include_templates: bool, +) -> list[str]: + safe_server_name = _sanitize_tool_name_fragment(server_name) + names = [ + f"mcp_{safe_server_name}_list_resources", + f"mcp_{safe_server_name}_read_resource", + ] + if include_templates: 
+ names.append(f"mcp_{safe_server_name}_list_resource_templates") + return names + + +def build_mcp_resource_tools( + mcp_client: MCPClient, + server_name: str, +) -> list[MCPResourceTool[TContext]]: + if not getattr(mcp_client, "supports_resources", False): + return [] + + tools: list[MCPResourceTool[TContext]] = [ + MCPListResourcesTool( + mcp_client=mcp_client, + mcp_server_name=server_name, + ), + MCPReadResourceTool( + mcp_client=mcp_client, + mcp_server_name=server_name, + ), + ] + if mcp_client.resource_templates_supported: + tools.append( + MCPListResourceTemplatesTool( + mcp_client=mcp_client, + mcp_server_name=server_name, + ) + ) + return tools + + +class MCPResourceTool(FunctionTool, Generic[TContext]): + """Server-scoped synthetic tool for MCP resources.""" + + def __init__(self, *, name: str, description: str, parameters: dict) -> None: + super().__init__( + name=name, + description=description, + parameters=parameters, + ) + self.mcp_client: MCPClient + self.mcp_server_name: str + + +class MCPListResourcesTool(MCPResourceTool[TContext]): + def __init__(self, *, mcp_client: MCPClient, mcp_server_name: str) -> None: + super().__init__( + name=build_mcp_resource_tool_names( + mcp_server_name, + include_templates=False, + )[0], + description=( + f"List readable MCP resources exposed by server '{mcp_server_name}'. " + "Use this before reading a specific resource URI." + ), + parameters={ + "type": "object", + "properties": { + "cursor": { + "type": "string", + "description": ( + "Optional pagination cursor returned by a previous " + "resource listing call." 
+ ), + } + }, + }, + ) + self.mcp_client = mcp_client + self.mcp_server_name = mcp_server_name + + async def call( + self, + context: ContextWrapper[TContext], + **kwargs, + ) -> mcp.types.CallToolResult: + _ = context + response = await self.mcp_client.list_resources_and_save( + cursor=kwargs.get("cursor"), + ) + return _text_result( + _format_resources_listing( + server_name=self.mcp_server_name, + response=response, + ) + ) + + +class MCPListResourceTemplatesTool(MCPResourceTool[TContext]): + def __init__(self, *, mcp_client: MCPClient, mcp_server_name: str) -> None: + super().__init__( + name=build_mcp_resource_tool_names( + mcp_server_name, + include_templates=True, + )[2], + description=( + f"List MCP resource URI templates exposed by server " + f"'{mcp_server_name}'. Use the returned URI patterns to construct " + "resource URIs for read_resource." + ), + parameters={ + "type": "object", + "properties": { + "cursor": { + "type": "string", + "description": ( + "Optional pagination cursor returned by a previous " + "resource template listing call." + ), + } + }, + }, + ) + self.mcp_client = mcp_client + self.mcp_server_name = mcp_server_name + + async def call( + self, + context: ContextWrapper[TContext], + **kwargs, + ) -> mcp.types.CallToolResult: + _ = context + response = await self.mcp_client.list_resource_templates_and_save( + cursor=kwargs.get("cursor"), + ) + return _text_result( + _format_resource_templates_listing( + server_name=self.mcp_server_name, + response=response, + ) + ) + + +class MCPReadResourceTool(MCPResourceTool[TContext]): + def __init__(self, *, mcp_client: MCPClient, mcp_server_name: str) -> None: + super().__init__( + name=build_mcp_resource_tool_names( + mcp_server_name, + include_templates=False, + )[1], + description=( + f"Read a specific MCP resource from server '{mcp_server_name}' by " + "its URI." 
+ ), + parameters={ + "type": "object", + "properties": { + "uri": { + "type": "string", + "description": "The MCP resource URI to read.", + } + }, + "required": ["uri"], + }, + ) + self.mcp_client = mcp_client + self.mcp_server_name = mcp_server_name + + async def call( + self, + context: ContextWrapper[TContext], + **kwargs, + ) -> mcp.types.CallToolResult: + read_timeout = timedelta(seconds=context.tool_call_timeout) + uri = str(kwargs["uri"]) + response = await self.mcp_client.read_resource_with_reconnect( + uri=uri, + read_timeout_seconds=read_timeout, + ) + return shape_read_resource_result( + server_name=self.mcp_server_name, + requested_uri=uri, + response=response, + ) + + +def shape_read_resource_result( + *, + server_name: str, + requested_uri: str, + response: mcp.types.ReadResourceResult, +) -> mcp.types.CallToolResult: + contents = response.contents + if not contents: + return _text_result( + f"MCP server '{server_name}' returned no contents for resource " + f"'{requested_uri}'." + ) + + if len(contents) == 1: + content = contents[0] + if isinstance(content, mcp.types.TextResourceContents): + return _text_result(_format_single_text_resource(server_name, content)) + if ( + isinstance(content, mcp.types.BlobResourceContents) + and content.mimeType + and content.mimeType.startswith("image/") + ): + return mcp.types.CallToolResult( + content=[ + mcp.types.EmbeddedResource( + type="resource", + resource=content, + ) + ] + ) + + return _text_result( + _format_multi_part_resource( + server_name=server_name, + requested_uri=requested_uri, + contents=contents, + ) + ) + + +def _text_result(text: str) -> mcp.types.CallToolResult: + return mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text=text)] + ) + + +def _format_resources_listing( + *, + server_name: str, + response: mcp.types.ListResourcesResult, +) -> str: + if not response.resources: + text = f"No MCP resources are currently exposed by server '{server_name}'." 
+ if response.nextCursor: + text += f"\nNext cursor: {response.nextCursor}" + return text + + lines = [f"MCP resources from server '{server_name}':"] + for idx, resource in enumerate(response.resources, start=1): + lines.extend(_format_resource_metadata(idx, resource)) + if response.nextCursor: + lines.append(f"Next cursor: {response.nextCursor}") + return "\n".join(lines) + + +def _format_resource_templates_listing( + *, + server_name: str, + response: mcp.types.ListResourceTemplatesResult, +) -> str: + if not response.resourceTemplates: + text = ( + f"No MCP resource templates are currently exposed by server " + f"'{server_name}'." + ) + if response.nextCursor: + text += f"\nNext cursor: {response.nextCursor}" + return text + + lines = [f"MCP resource templates from server '{server_name}':"] + for idx, template in enumerate(response.resourceTemplates, start=1): + lines.extend(_format_resource_template_metadata(idx, template)) + if response.nextCursor: + lines.append(f"Next cursor: {response.nextCursor}") + return "\n".join(lines) + + +def _format_single_text_resource( + server_name: str, + content: mcp.types.TextResourceContents, +) -> str: + lines = [ + f"MCP text resource from server '{server_name}':", + f"URI: {content.uri}", + ] + if content.mimeType: + lines.append(f"MIME type: {content.mimeType}") + lines.extend(["", content.text]) + return "\n".join(lines) + + +def _format_multi_part_resource( + *, + server_name: str, + requested_uri: str, + contents: list[mcp.types.TextResourceContents | mcp.types.BlobResourceContents], +) -> str: + lines = [ + f"MCP resource read result from server '{server_name}':", + f"Requested URI: {requested_uri}", + f"Returned parts: {len(contents)}", + ] + for idx, content in enumerate(contents, start=1): + lines.append("") + lines.append(f"Part {idx}:") + lines.append(f"URI: {content.uri}") + if content.mimeType: + lines.append(f"MIME type: {content.mimeType}") + if isinstance(content, mcp.types.TextResourceContents): + 
lines.append("Text:") + lines.append(content.text) + else: + lines.append(f"Binary blob returned (base64 length: {len(content.blob)}).") + return "\n".join(lines) + + +def _format_resource_metadata( + index: int, + resource: mcp.types.Resource, +) -> list[str]: + lines = [f"{index}. {resource.name}", f" URI: {resource.uri}"] + if resource.title: + lines.append(f" Title: {resource.title}") + if resource.description: + lines.append(f" Description: {resource.description}") + if resource.mimeType: + lines.append(f" MIME type: {resource.mimeType}") + if resource.size is not None: + lines.append(f" Size: {resource.size} bytes") + return lines + + +def _format_resource_template_metadata( + index: int, + template: mcp.types.ResourceTemplate, +) -> list[str]: + lines = [f"{index}. {template.name}", f" URI template: {template.uriTemplate}"] + if template.title: + lines.append(f" Title: {template.title}") + if template.description: + lines.append(f" Description: {template.description}") + if template.mimeType: + lines.append(f" MIME type: {template.mimeType}") + return lines + + +def _sanitize_tool_name_fragment(name: str) -> str: + sanitized = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_").lower() + return sanitized or "server" diff --git a/astrbot/core/agent/mcp_stdio_client.py b/astrbot/core/agent/mcp_stdio_client.py new file mode 100644 index 0000000000..e0433106d8 --- /dev/null +++ b/astrbot/core/agent/mcp_stdio_client.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import logging +import sys +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, TextIO + +import anyio +import anyio.lowlevel +import mcp.types as types +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.streams.text import TextReceiveStream +from mcp.client.stdio import ( + PROCESS_TERMINATION_TIMEOUT, + _create_platform_compatible_process, + _get_executable_command, + _terminate_process_tree, + get_default_environment, +) 
from mcp.shared.message import SessionMessage

from astrbot import logger

if TYPE_CHECKING:
    import mcp


def _normalize_stdout_line(line: str) -> str:
    """Strip a trailing CR so CRLF-delimited output parses like LF output."""
    return line.rstrip("\r")


def _should_ignore_stdout_line(line: str) -> bool:
    """Return True for stdout lines that are clearly not JSON-RPC traffic."""
    stripped = _normalize_stdout_line(line).strip()
    if not stripped:
        return True

    # JSON-RPC messages are serialized as JSON objects. Wrapper banners from
    # tools such as npm/pnpm/yarn should not abort the session.
    return not stripped.startswith("{")


@asynccontextmanager
async def tolerant_stdio_client(
    server: mcp.StdioServerParameters,
    errlog: TextIO = sys.stderr,
):
    """A stdio MCP transport that ignores obvious non-protocol stdout noise.

    Mirrors ``mcp.client.stdio.stdio_client`` but drops stdout lines that do
    not look like JSON objects instead of failing the session, so wrapper
    banners (npm/pnpm/yarn, shell profiles, ...) do not kill the connection.
    Yields ``(read_stream, write_stream)`` of ``SessionMessage`` objects.
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    # Zero-capacity streams: producers block until a consumer is ready.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    try:
        command = _get_executable_command(server.command)
        process = await _create_platform_compatible_process(
            command=command,
            args=server.args,
            env=(
                {**get_default_environment(), **server.env}
                if server.env is not None
                else get_default_environment()
            ),
            errlog=errlog,
            cwd=server.cwd,
        )
    except OSError:
        # Spawning failed: close all four stream endpoints before re-raising
        # so no dangling memory-object streams leak.
        await read_stream.aclose()
        await write_stream.aclose()
        await read_stream_writer.aclose()
        await write_stream_reader.aclose()
        raise

    async def stdout_reader():
        # Reassemble newline-delimited JSON-RPC messages from raw chunks.
        assert process.stdout, "Opened process is missing stdout"

        try:
            async with read_stream_writer:
                buffer = ""
                async for chunk in TextReceiveStream(
                    process.stdout,
                    encoding=server.encoding,
                    errors=server.encoding_error_handler,
                ):
                    lines = (buffer + chunk).split("\n")
                    # Last element is the (possibly empty) unterminated tail.
                    # NOTE(review): a final message without a trailing newline
                    # stays in `buffer` and is dropped at EOF — same behavior
                    # as the upstream stdio client; confirm acceptable.
                    buffer = lines.pop()

                    for raw_line in lines:
                        line = _normalize_stdout_line(raw_line)
                        if _should_ignore_stdout_line(line):
                            if line.strip():
                                logger.debug(
                                    "Ignoring non-JSON stdout line from MCP stdio server: %s",
                                    line.strip(),
                                )
                            continue

                        try:
                            message = types.JSONRPCMessage.model_validate_json(
                                line.strip()
                            )
                        except Exception as exc:  # pragma: no cover
                            logging.getLogger("mcp.client.stdio").exception(
                                "Failed to parse JSONRPC message from server"
                            )
                            # Forward the parse error to the session instead
                            # of tearing the transport down.
                            await read_stream_writer.send(exc)
                            continue

                        await read_stream_writer.send(SessionMessage(message))
        except anyio.ClosedResourceError:  # pragma: no cover
            await anyio.lowlevel.checkpoint()

    async def stdin_writer():
        # Serialize outgoing session messages as newline-delimited JSON.
        assert process.stdin, "Opened process is missing stdin"

        try:
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    payload = session_message.message.model_dump_json(
                        by_alias=True,
                        exclude_none=True,
                    )
                    await process.stdin.send(
                        (payload + "\n").encode(
                            encoding=server.encoding,
                            errors=server.encoding_error_handler,
                        )
                    )
        except anyio.ClosedResourceError:  # pragma: no cover
            await anyio.lowlevel.checkpoint()

    async with (
        anyio.create_task_group() as tg,
        process,
    ):
        tg.start_soon(stdout_reader)
        tg.start_soon(stdin_writer)
        try:
            yield read_stream, write_stream
        finally:
            # Close stdin first so a well-behaved server exits on its own ...
            if process.stdin:  # pragma: no branch
                try:
                    await process.stdin.aclose()
                except Exception:  # pragma: no cover
                    pass

            # ... then give it PROCESS_TERMINATION_TIMEOUT before escalating
            # to terminating the whole process tree.
            try:
                with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
                    await process.wait()
            except TimeoutError:
                await _terminate_process_tree(process)
            except ProcessLookupError:  # pragma: no cover
                pass

            await read_stream.aclose()
            await write_stream.aclose()
            await read_stream_writer.aclose()
            await write_stream_reader.aclose()
@@ -0,0 +1,1302 @@ +from __future__ import annotations + +import asyncio +import copy +import json +import re +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, Generic + +from astrbot import logger +from astrbot.core.agent.mcp_elicitation_registry import pending_mcp_elicitation +from astrbot.core.agent.run_context import ContextWrapper, TContext +from astrbot.core.message.components import Json +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.utils.astrbot_path import ( + get_astrbot_backups_path, + get_astrbot_config_path, + get_astrbot_data_path, + get_astrbot_knowledge_base_path, + get_astrbot_plugin_data_path, + get_astrbot_plugin_path, + get_astrbot_root, + get_astrbot_skills_path, + get_astrbot_temp_path, +) + +if TYPE_CHECKING: + import mcp + + +DEFAULT_MCP_CLIENT_CAPABILITIES = { + "elicitation": { + "enabled": False, + "timeout_seconds": 300, + }, + "sampling": { + "enabled": False, + }, + "roots": { + "enabled": False, + "paths": [], + }, +} + +DEFAULT_MCP_ROOT_PATHS = ("data", "temp") +DEFAULT_MCP_ELICITATION_TIMEOUT_SECONDS = 300 +MCP_ELICITATION_ACCEPT_KEYWORDS = { + "accept", + "done", + "ok", + "okay", + "yes", + "完成", + "已完成", + "同意", +} +MCP_ELICITATION_DECLINE_KEYWORDS = { + "decline", + "reject", + "refuse", + "no", + "拒绝", + "不同意", +} +MCP_ELICITATION_CANCEL_KEYWORDS = { + "cancel", + "stop", + "退出", + "取消", +} + + +def get_root_path_alias_resolvers(): + return { + "root": get_astrbot_root, + "data": get_astrbot_data_path, + "config": get_astrbot_config_path, + "plugins": get_astrbot_plugin_path, + "plugin_data": get_astrbot_plugin_data_path, + "temp": get_astrbot_temp_path, + "skills": get_astrbot_skills_path, + "knowledge_base": get_astrbot_knowledge_base_path, + "backups": get_astrbot_backups_path, + } + + +class UnsupportedSamplingRequestError(ValueError): + """Raised when a sampling request cannot be 
safely mapped.""" + + +class UnsupportedElicitationRequestError(ValueError): + """Raised when an elicitation request cannot be safely mapped.""" + + +@dataclass(slots=True) +class MCPElicitationCapabilityConfig: + enabled: bool = False + timeout_seconds: int = DEFAULT_MCP_ELICITATION_TIMEOUT_SECONDS + + +@dataclass(slots=True) +class MCPSamplingCapabilityConfig: + enabled: bool = False + + +@dataclass(slots=True) +class MCPRootsCapabilityConfig: + enabled: bool = False + paths: list[str] = field(default_factory=list) + + +@dataclass(slots=True) +class MCPClientCapabilitiesConfig: + elicitation: MCPElicitationCapabilityConfig + sampling: MCPSamplingCapabilityConfig + roots: MCPRootsCapabilityConfig + + @classmethod + def from_server_config( + cls, server_config: dict[str, Any] | None + ) -> MCPClientCapabilitiesConfig: + normalized = normalize_mcp_server_config(server_config or {}) + elicitation_cfg = normalized["client_capabilities"]["elicitation"] + sampling_cfg = normalized["client_capabilities"]["sampling"] + roots_cfg = normalized["client_capabilities"]["roots"] + return cls( + elicitation=MCPElicitationCapabilityConfig( + enabled=bool(elicitation_cfg.get("enabled", False)), + timeout_seconds=int( + elicitation_cfg.get( + "timeout_seconds", + DEFAULT_MCP_ELICITATION_TIMEOUT_SECONDS, + ) + ), + ), + sampling=MCPSamplingCapabilityConfig( + enabled=bool(sampling_cfg.get("enabled", False)), + ), + roots=MCPRootsCapabilityConfig( + enabled=bool(roots_cfg.get("enabled", False)), + paths=list(roots_cfg.get("paths", [])), + ), + ) + + +def normalize_mcp_server_config(server_config: dict[str, Any]) -> dict[str, Any]: + """Normalize persisted MCP server config fields for backward compatibility.""" + normalized = copy.deepcopy(server_config) + + client_capabilities = normalized.get("client_capabilities") + if not isinstance(client_capabilities, dict): + client_capabilities = {} + + elicitation_cfg = client_capabilities.get("elicitation") + if isinstance(elicitation_cfg, 
bool): + elicitation_cfg = {"enabled": elicitation_cfg} + elif not isinstance(elicitation_cfg, dict): + elicitation_cfg = {} + + sampling_cfg = client_capabilities.get("sampling") + if isinstance(sampling_cfg, bool): + sampling_cfg = {"enabled": sampling_cfg} + elif not isinstance(sampling_cfg, dict): + sampling_cfg = {} + + roots_cfg = client_capabilities.get("roots") + if isinstance(roots_cfg, bool): + roots_cfg = {"enabled": roots_cfg} + elif not isinstance(roots_cfg, dict): + roots_cfg = {} + + raw_root_paths = roots_cfg.get("paths", []) + if not isinstance(raw_root_paths, list): + raw_root_paths = [] + normalized_root_paths = [ + str(path).strip() + for path in raw_root_paths + if isinstance(path, str) and path.strip() + ] + + client_capabilities["elicitation"] = { + "enabled": bool(elicitation_cfg.get("enabled", False)), + "timeout_seconds": _normalize_positive_int( + elicitation_cfg.get( + "timeout_seconds", + DEFAULT_MCP_ELICITATION_TIMEOUT_SECONDS, + ), + DEFAULT_MCP_ELICITATION_TIMEOUT_SECONDS, + ), + } + client_capabilities["sampling"] = { + "enabled": bool(sampling_cfg.get("enabled", False)), + } + client_capabilities["roots"] = { + "enabled": bool(roots_cfg.get("enabled", False)), + "paths": normalized_root_paths, + } + normalized["client_capabilities"] = client_capabilities + return normalized + + +def _normalize_positive_int(value: Any, default: int) -> int: + if isinstance(value, bool): + return default + try: + normalized = int(value) + except (TypeError, ValueError): + return default + if normalized <= 0: + return default + return normalized + + +def normalize_mcp_config(config: dict[str, Any] | None) -> dict[str, Any]: + """Normalize the full MCP configuration file structure.""" + normalized = {"mcpServers": {}} + if not isinstance(config, dict): + return normalized + + raw_servers = config.get("mcpServers", {}) + if not isinstance(raw_servers, dict): + return normalized + + for name, server_config in raw_servers.items(): + if not 
isinstance(server_config, dict): + continue + normalized["mcpServers"][name] = normalize_mcp_server_config(server_config) + return normalized + + +class MCPClientSubCapabilityBridge(Generic[TContext]): + """Bridge MCP client sub-capability requests into AstrBot runtime calls.""" + + def __init__(self, server_name: str | None = None) -> None: + self._server_name = server_name or "" + self._capabilities = MCPClientCapabilitiesConfig.from_server_config({}) + self._interaction_lock = asyncio.Lock() + self._active_run_context: ContextWrapper[TContext] | None = None + + def configure_from_server_config(self, server_config: dict[str, Any]) -> None: + self._capabilities = MCPClientCapabilitiesConfig.from_server_config( + server_config + ) + + def set_server_name(self, server_name: str | None) -> None: + if server_name: + self._server_name = server_name + + @property + def sampling_enabled(self) -> bool: + return self._capabilities.sampling.enabled + + @property + def elicitation_enabled(self) -> bool: + return self._capabilities.elicitation.enabled + + @property + def elicitation_timeout_seconds(self) -> int: + return self._capabilities.elicitation.timeout_seconds + + @property + def roots_enabled(self) -> bool: + return self._capabilities.roots.enabled + + def get_sampling_capabilities(self) -> mcp.types.SamplingCapability | None: + if not self.sampling_enabled: + return None + + import mcp + + return mcp.types.SamplingCapability() + + async def handle_list_roots( + self, + _request_context: Any, + ) -> mcp.types.ListRootsResult | mcp.types.ErrorData: + import mcp + + if not self.roots_enabled: + return mcp.types.ErrorData( + code=mcp.types.INVALID_REQUEST, + message="Roots are not enabled for this MCP server.", + ) + + try: + return mcp.types.ListRootsResult(roots=self._build_root_entries()) + except Exception as exc: # noqa: BLE001 + logger.error( + "Roots request failed for MCP server %s: %s", + self._server_name, + exc, + exc_info=True, + ) + return 
mcp.types.ErrorData( + code=mcp.types.INTERNAL_ERROR, + message="Roots request failed inside AstrBot.", + data=str(exc), + ) + + def clear_runtime_state(self) -> None: + self._active_run_context = None + + @asynccontextmanager + async def interactive_call( + self, + run_context: ContextWrapper[TContext] | None, + ): + if not (self.sampling_enabled or self.elicitation_enabled): + yield + return + + async with self._interaction_lock: + self._active_run_context = run_context + try: + yield + finally: + self._active_run_context = None + + async def handle_sampling( + self, + _request_context: Any, + params: mcp.types.CreateMessageRequestParams, + ) -> ( + mcp.types.CreateMessageResult + | mcp.types.CreateMessageResultWithTools + | mcp.types.ErrorData + ): + import mcp + + if not self.sampling_enabled: + return mcp.types.ErrorData( + code=mcp.types.INVALID_REQUEST, + message="Sampling is not enabled for this MCP server.", + ) + + run_context = self._active_run_context + if run_context is None: + return mcp.types.ErrorData( + code=mcp.types.INVALID_REQUEST, + message=( + "Sampling requests are only supported during an active AstrBot " + "MCP interaction." 
+ ), + ) + + try: + return await self._execute_sampling(run_context, params) + except UnsupportedSamplingRequestError as exc: + return mcp.types.ErrorData( + code=mcp.types.INVALID_REQUEST, + message=str(exc), + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "Sampling request failed for MCP server %s: %s", + self._server_name, + exc, + exc_info=True, + ) + return mcp.types.ErrorData( + code=mcp.types.INTERNAL_ERROR, + message="Sampling request failed inside AstrBot.", + data=str(exc), + ) + + async def handle_elicitation( + self, + _request_context: Any, + params: mcp.types.ElicitRequestParams, + ) -> mcp.types.ElicitResult | mcp.types.ErrorData: + import mcp + + if not self.elicitation_enabled: + return mcp.types.ErrorData( + code=mcp.types.INVALID_REQUEST, + message="Elicitation is not enabled for this MCP server.", + ) + + run_context = self._active_run_context + if run_context is None: + return mcp.types.ErrorData( + code=mcp.types.INVALID_REQUEST, + message=( + "Elicitation requests are only supported during an active AstrBot " + "MCP interaction." 
+ ), + ) + + try: + return await self._execute_elicitation(run_context, params) + except UnsupportedElicitationRequestError as exc: + return mcp.types.ErrorData( + code=mcp.types.INVALID_REQUEST, + message=str(exc), + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "Elicitation request failed for MCP server %s: %s", + self._server_name, + exc, + exc_info=True, + ) + return mcp.types.ErrorData( + code=mcp.types.INTERNAL_ERROR, + message="Elicitation request failed inside AstrBot.", + data=str(exc), + ) + + async def _execute_sampling( + self, + run_context: ContextWrapper[TContext], + params: mcp.types.CreateMessageRequestParams, + ) -> mcp.types.CreateMessageResult: + import mcp + + plugin_context, event = self._extract_bound_runtime(run_context) + if plugin_context is None or event is None: + raise UnsupportedSamplingRequestError( + "Sampling requires an AstrBot agent context bound to the MCP tool call." + ) + + if params.includeContext not in (None, "none"): + raise UnsupportedSamplingRequestError( + "Sampling includeContext is not supported in the initial AstrBot integration." + ) + + if params.tools or params.toolChoice: + raise UnsupportedSamplingRequestError( + "Tool-assisted sampling is not supported in the initial AstrBot integration." + ) + + contexts = self._translate_sampling_messages(params.messages) + umo = getattr(event, "unified_msg_origin", None) + if not isinstance(umo, str) or not umo: + raise UnsupportedSamplingRequestError( + "Sampling requires a valid unified message origin." + ) + + provider_id = await plugin_context.get_current_chat_provider_id(umo) + provider = plugin_context.get_using_provider(umo) + if provider is None: + raise UnsupportedSamplingRequestError( + "Sampling requires an active chat provider." 
+ ) + + provider_kwargs: dict[str, Any] = {"max_tokens": params.maxTokens} + if params.temperature is not None: + provider_kwargs["temperature"] = params.temperature + if params.stopSequences: + provider_kwargs["stop"] = params.stopSequences + provider_kwargs["stopSequences"] = params.stopSequences + if params.metadata: + provider_kwargs["metadata"] = params.metadata + + llm_resp = await plugin_context.llm_generate( + chat_provider_id=provider_id, + contexts=contexts, + system_prompt=params.systemPrompt or "", + **provider_kwargs, + ) + + if llm_resp.role == "err": + raise RuntimeError(llm_resp.completion_text or "Provider returned error") + if llm_resp.tools_call_args: + raise UnsupportedSamplingRequestError( + "Tool-assisted sampling responses are not supported in the initial AstrBot integration." + ) + + text = llm_resp.completion_text + if text is None: + raise RuntimeError("Provider returned no textual sampling result") + + model_name = provider.get_model() or provider.meta().model or provider.meta().id + return mcp.types.CreateMessageResult( + role="assistant", + content=mcp.types.TextContent(type="text", text=text), + model=model_name, + stopReason="endTurn", + ) + + @staticmethod + def _extract_bound_runtime( + run_context: ContextWrapper[TContext], + ) -> tuple[Any | None, Any | None]: + agent_context = getattr(run_context, "context", None) + plugin_context = getattr(agent_context, "context", None) + event = getattr(agent_context, "event", None) + return plugin_context, event + + async def _execute_elicitation( + self, + run_context: ContextWrapper[TContext], + params: mcp.types.ElicitRequestParams, + ) -> mcp.types.ElicitResult: + import mcp + + plugin_context, event = self._extract_bound_runtime(run_context) + if event is None: + raise UnsupportedElicitationRequestError( + "Elicitation requires an AstrBot event bound to the MCP tool call." 
+ ) + + sender_id = event.get_sender_id() + if not sender_id: + raise UnsupportedElicitationRequestError( + "Elicitation requires a stable sender ID." + ) + + if isinstance(params, mcp.types.ElicitRequestFormParams): + return await self._execute_form_elicitation( + plugin_context, + event, + sender_id, + params, + ) + if isinstance(params, mcp.types.ElicitRequestURLParams): + return await self._execute_url_elicitation( + event, + sender_id, + params, + ) + raise UnsupportedElicitationRequestError( + f"Unsupported elicitation params type: {type(params).__name__}" + ) + + @staticmethod + def _translate_sampling_messages( + messages: list[mcp.types.SamplingMessage], + ) -> list[dict[str, str]]: + translated: list[dict[str, str]] = [] + for message in messages: + text = MCPClientSubCapabilityBridge._sampling_message_to_text(message) + translated.append( + { + "role": message.role, + "content": text, + } + ) + return translated + + @staticmethod + def _sampling_message_to_text(message: mcp.types.SamplingMessage) -> str: + import mcp + + text_parts: list[str] = [] + for block in message.content_as_list: + if isinstance(block, mcp.types.TextContent): + text_parts.append(block.text) + continue + + if isinstance(block, mcp.types.ImageContent): + raise UnsupportedSamplingRequestError( + "Image sampling inputs are not supported in the initial AstrBot integration." + ) + if isinstance(block, mcp.types.AudioContent): + raise UnsupportedSamplingRequestError( + "Audio sampling inputs are not supported in the initial AstrBot integration." + ) + + raise UnsupportedSamplingRequestError( + f"Sampling content block '{type(block).__name__}' is not supported in the initial AstrBot integration." 
+ ) + + return "\n".join(text_parts) + + async def _execute_form_elicitation( + self, + plugin_context: Any, + event: Any, + sender_id: str, + params: mcp.types.ElicitRequestFormParams, + ) -> mcp.types.ElicitResult: + import mcp + + properties = self._get_elicitation_properties(params.requestedSchema) + deadline = asyncio.get_running_loop().time() + self.elicitation_timeout_seconds + await self._send_elicitation_message( + event, + self._build_form_elicitation_prompt(params, properties), + payload=self._build_form_elicitation_payload(params, properties), + ) + + while True: + reply_text = await self._wait_for_elicitation_reply( + event=event, + sender_id=sender_id, + deadline=deadline, + ) + if reply_text is None: + return mcp.types.ElicitResult(action="cancel") + + action = self._parse_cancel_or_decline_action(reply_text) + if action is not None: + return mcp.types.ElicitResult(action=action) + + try: + content = self._parse_form_elicitation_reply( + requested_schema=params.requestedSchema, + reply_text=reply_text, + ) + except UnsupportedElicitationRequestError as exc: + content = await self._try_llm_form_reply_fallback( + plugin_context=plugin_context, + event=event, + params=params, + reply_text=reply_text, + direct_parse_error=exc, + ) + if content is not None: + return mcp.types.ElicitResult( + action="accept", + content=content, + ) + await self._send_elicitation_message( + event, + self._build_form_retry_prompt(exc), + ) + continue + + return mcp.types.ElicitResult( + action="accept", + content=content, + ) + + async def _try_llm_form_reply_fallback( + self, + *, + plugin_context: Any, + event: Any, + params: mcp.types.ElicitRequestFormParams, + reply_text: str, + direct_parse_error: UnsupportedElicitationRequestError, + ) -> dict[str, str | int | float | bool | list[str] | None] | None: + if plugin_context is None: + return None + + umo = getattr(event, "unified_msg_origin", None) + if not isinstance(umo, str) or not umo: + return None + + try: + 
provider_id = await plugin_context.get_current_chat_provider_id(umo) + except Exception as exc: # noqa: BLE001 + logger.debug( + "Unable to resolve provider for MCP elicitation fallback on %s: %s", + self._server_name, + exc, + ) + return None + + prompt = self._build_elicitation_llm_fallback_prompt( + params=params, + reply_text=reply_text, + direct_parse_error=direct_parse_error, + ) + try: + llm_resp = await plugin_context.llm_generate( + chat_provider_id=provider_id, + prompt=prompt, + system_prompt=self._build_elicitation_llm_fallback_system_prompt(), + max_tokens=256, + ) + except Exception as exc: # noqa: BLE001 + logger.debug( + "LLM fallback failed during MCP elicitation for %s: %s", + self._server_name, + exc, + ) + return None + + if getattr(llm_resp, "role", None) == "err": + logger.debug( + "Provider returned error during MCP elicitation fallback for %s: %s", + self._server_name, + getattr(llm_resp, "completion_text", "") or "", + ) + return None + + raw_text = getattr(llm_resp, "completion_text", "") or "" + normalized = self._strip_code_fence(raw_text).strip() + if not normalized: + return None + + try: + payload = json.loads(normalized) + except json.JSONDecodeError: + logger.debug( + "LLM fallback returned non-JSON content during MCP elicitation for %s: %s", + self._server_name, + normalized, + ) + return None + + if not isinstance(payload, dict): + return None + + try: + return self._coerce_form_payload(payload, params.requestedSchema) + except UnsupportedElicitationRequestError as exc: + logger.debug( + "LLM fallback returned invalid MCP elicitation payload for %s: %s", + self._server_name, + exc, + ) + return None + + async def _execute_url_elicitation( + self, + event: Any, + sender_id: str, + params: mcp.types.ElicitRequestURLParams, + ) -> mcp.types.ElicitResult: + import mcp + + deadline = asyncio.get_running_loop().time() + self.elicitation_timeout_seconds + await self._send_elicitation_message( + event, + 
self._build_url_elicitation_prompt(params), + payload=self._build_url_elicitation_payload(params), + ) + + while True: + reply_text = await self._wait_for_elicitation_reply( + event=event, + sender_id=sender_id, + deadline=deadline, + ) + if reply_text is None: + return mcp.types.ElicitResult(action="cancel") + + action = self._parse_url_action(reply_text) + if action is not None: + return mcp.types.ElicitResult(action=action) + + await self._send_elicitation_message( + event, + "Please reply `done`, `decline`, or `cancel` to continue this MCP request.", + ) + + async def _send_elicitation_message( + self, + event: Any, + message: str, + *, + payload: dict[str, Any] | None = None, + ) -> None: + if payload and self._is_webchat_event(event): + try: + await event.send( + MessageChain( + chain=[Json(data=payload)], + type="elicitation", + ) + ) + return + except Exception as exc: # noqa: BLE001 + logger.debug( + "Falling back to plain-text MCP elicitation message for %s: %s", + self._server_name, + exc, + ) + + await event.send(MessageChain().message(message)) + + async def _wait_for_elicitation_reply( + self, + *, + event: Any, + sender_id: str, + deadline: float, + ) -> str | None: + remaining = deadline - asyncio.get_running_loop().time() + if remaining <= 0: + return None + + try: + async with pending_mcp_elicitation( + event.unified_msg_origin, + sender_id, + ) as future: + reply = await asyncio.wait_for(future, timeout=remaining) + except asyncio.TimeoutError: + return None + + reply_text = reply.message_text.strip() + if reply_text: + return self._strip_code_fence(reply_text) + return reply.message_outline.strip() + + def _build_form_elicitation_prompt( + self, + params: mcp.types.ElicitRequestFormParams, + properties: dict[str, dict[str, Any]], + ) -> str: + required_fields = set( + self._get_required_elicitation_fields(params.requestedSchema) + ) + lines = [f"MCP server `{self._server_name}` needs more information."] + if params.message.strip(): + 
lines.append(params.message.strip()) + if properties: + lines.append("Requested fields:") + for field_name, schema in properties.items(): + field_type = self._get_elicitation_field_type(schema) + desc = str(schema.get("description", "")).strip() + suffix = " required" if field_name in required_fields else " optional" + if desc: + lines.append(f"- {field_name} ({field_type},{suffix}): {desc}") + else: + lines.append(f"- {field_name} ({field_type},{suffix})") + if len(properties) == 1: + lines.append("Reply with plain text or JSON.") + elif len(properties) > 1: + lines.append("Reply with JSON or `field: value` lines.") + else: + lines.append("Reply `accept` to continue.") + lines.append("Reply `decline` to refuse or `cancel` to stop.") + return "\n".join(lines) + + def _build_form_elicitation_payload( + self, + params: mcp.types.ElicitRequestFormParams, + properties: dict[str, dict[str, Any]], + ) -> dict[str, Any]: + required_fields = set( + self._get_required_elicitation_fields(params.requestedSchema) + ) + fields: list[dict[str, Any]] = [] + for field_name, schema in properties.items(): + enum_values = schema.get("enum") + fields.append( + { + "name": field_name, + "label": str(schema.get("title") or field_name), + "description": str(schema.get("description", "")).strip(), + "required": field_name in required_fields, + "type": self._get_elicitation_field_type(schema), + "enum": ( + [str(value) for value in enum_values] + if isinstance(enum_values, list) + else [] + ), + } + ) + return { + "kind": "form", + "server_name": self._server_name, + "message": params.message.strip(), + "prompt": self._build_form_elicitation_prompt(params, properties), + "fields": fields, + } + + @staticmethod + def _build_form_retry_prompt(exc: UnsupportedElicitationRequestError) -> str: + return ( + "I could not use that reply for the MCP elicitation.\n" + f"Reason: {exc}\n" + "Please try again, or reply `decline` / `cancel`." 
+ ) + + def _build_url_elicitation_prompt( + self, + params: mcp.types.ElicitRequestURLParams, + ) -> str: + lines = [ + f"MCP server `{self._server_name}` needs an external confirmation step." + ] + if params.message.strip(): + lines.append(params.message.strip()) + lines.append(f"URL: {params.url}") + lines.append( + "Reply `done` after you finish, `decline` to refuse, or `cancel` to stop." + ) + return "\n".join(lines) + + def _build_url_elicitation_payload( + self, + params: mcp.types.ElicitRequestURLParams, + ) -> dict[str, Any]: + return { + "kind": "url", + "server_name": self._server_name, + "message": params.message.strip(), + "prompt": self._build_url_elicitation_prompt(params), + "url": params.url, + } + + @staticmethod + def _build_elicitation_llm_fallback_system_prompt() -> str: + return ( + "You extract structured MCP elicitation data from a user's natural-language reply.\n" + "Return only a JSON object.\n" + "Use only keys from the provided schema.\n" + "Do not invent facts. Omit fields that are not clearly supported.\n" + "Use proper JSON types for booleans, integers, numbers, and arrays.\n" + "Do not wrap the JSON in markdown fences." + ) + + def _build_elicitation_llm_fallback_prompt( + self, + *, + params: mcp.types.ElicitRequestFormParams, + reply_text: str, + direct_parse_error: UnsupportedElicitationRequestError, + ) -> str: + return ( + f"MCP server: {self._server_name}\n" + f"Original elicitation message:\n{params.message.strip() or ''}\n\n" + f"Requested JSON schema:\n" + f"{json.dumps(params.requestedSchema, ensure_ascii=False, indent=2)}\n\n" + f"User reply:\n{reply_text}\n\n" + f"Direct parser error:\n{direct_parse_error}\n\n" + "Produce the best possible JSON object that matches the schema." 
+ ) + + def _parse_form_elicitation_reply( + self, + *, + requested_schema: dict[str, Any], + reply_text: str, + ) -> dict[str, str | int | float | bool | list[str] | None]: + properties = self._get_elicitation_properties(requested_schema) + if not properties: + return {} + + normalized_reply = reply_text.strip() + if not normalized_reply: + raise UnsupportedElicitationRequestError("The reply is empty.") + + if normalized_reply.startswith("{"): + try: + payload = json.loads(normalized_reply) + except json.JSONDecodeError as exc: + raise UnsupportedElicitationRequestError( + "The JSON reply could not be parsed." + ) from exc + if not isinstance(payload, dict): + raise UnsupportedElicitationRequestError( + "The JSON reply must be an object." + ) + elif len(properties) == 1: + field_name = next(iter(properties)) + payload = {field_name: normalized_reply} + else: + payload = self._parse_key_value_lines(normalized_reply, properties) + if not payload: + payload = self._parse_natural_language_form_reply( + reply_text=normalized_reply, + requested_schema=requested_schema, + ) + if not payload: + raise UnsupportedElicitationRequestError( + "Please reply with JSON, natural language, or `field: value` lines." 
+ ) + + return self._coerce_form_payload(payload, requested_schema) + + def _parse_natural_language_form_reply( + self, + *, + reply_text: str, + requested_schema: dict[str, Any], + ) -> dict[str, Any]: + properties = self._get_elicitation_properties(requested_schema) + if not properties: + return {} + + parsed = self._parse_field_patterns(reply_text, properties) + parsed.update(self._match_enum_values(reply_text, properties, parsed.keys())) + if parsed: + return parsed + + target_fields = self._get_required_elicitation_fields(requested_schema) + if not target_fields: + target_fields = list(properties.keys()) + if len(target_fields) == 1: + return {target_fields[0]: reply_text} + + return {} + + def _coerce_form_payload( + self, + payload: dict[str, Any], + requested_schema: dict[str, Any], + ) -> dict[str, str | int | float | bool | list[str] | None]: + properties = self._get_elicitation_properties(requested_schema) + required_fields = self._get_required_elicitation_fields(requested_schema) + normalized_keys = { + field_name.casefold(): field_name for field_name in properties.keys() + } + + coerced: dict[str, str | int | float | bool | list[str] | None] = {} + for raw_key, raw_value in payload.items(): + normalized_key = str(raw_key).strip().casefold() + field_name = normalized_keys.get(normalized_key) + if field_name is None: + continue + coerced[field_name] = self._coerce_form_value( + field_name=field_name, + raw_value=raw_value, + schema=properties[field_name], + ) + + missing_required = [ + field_name for field_name in required_fields if field_name not in coerced + ] + if missing_required: + raise UnsupportedElicitationRequestError( + "Missing required field(s): " + ", ".join(missing_required) + ) + return coerced + + def _coerce_form_value( + self, + *, + field_name: str, + raw_value: Any, + schema: dict[str, Any], + ) -> str | int | float | bool | list[str] | None: + field_type = self._get_elicitation_field_type(schema) + if raw_value is None: + return None 
+ + if field_type == "string": + value = str(raw_value).strip() + elif field_type == "integer": + if isinstance(raw_value, bool): + raise UnsupportedElicitationRequestError( + f"Field `{field_name}` must be an integer." + ) + try: + value = int(str(raw_value).strip()) + except (TypeError, ValueError) as exc: + raise UnsupportedElicitationRequestError( + f"Field `{field_name}` must be an integer." + ) from exc + elif field_type == "number": + if isinstance(raw_value, bool): + raise UnsupportedElicitationRequestError( + f"Field `{field_name}` must be a number." + ) + try: + value = float(str(raw_value).strip()) + except (TypeError, ValueError) as exc: + raise UnsupportedElicitationRequestError( + f"Field `{field_name}` must be a number." + ) from exc + elif field_type == "boolean": + value = self._coerce_boolean_value(field_name, raw_value) + elif field_type == "array": + value = self._coerce_string_array_value(field_name, raw_value) + else: + raise UnsupportedElicitationRequestError( + f"Field `{field_name}` uses unsupported type `{field_type}`." + ) + + enum_values = schema.get("enum") + if isinstance(enum_values, list) and value not in enum_values: + raise UnsupportedElicitationRequestError( + f"Field `{field_name}` must be one of: {', '.join(map(str, enum_values))}." + ) + return value + + @staticmethod + def _coerce_boolean_value(field_name: str, raw_value: Any) -> bool: + if isinstance(raw_value, bool): + return raw_value + + normalized = str(raw_value).strip().casefold() + truthy = {"true", "1", "yes", "y", "on", "是", "好的"} + falsy = {"false", "0", "no", "n", "off", "否", "不是"} + if normalized in truthy: + return True + if normalized in falsy: + return False + raise UnsupportedElicitationRequestError( + f"Field `{field_name}` must be a boolean." 
+ ) + + @staticmethod + def _coerce_string_array_value(field_name: str, raw_value: Any) -> list[str]: + if isinstance(raw_value, list): + return [str(item).strip() for item in raw_value if str(item).strip()] + + normalized = str(raw_value).strip() + if not normalized: + return [] + parts = [ + part.strip() + for chunk in normalized.splitlines() + for part in chunk.split(",") + if part.strip() + ] + if not parts: + raise UnsupportedElicitationRequestError( + f"Field `{field_name}` must be a string array." + ) + return parts + + @staticmethod + def _parse_key_value_lines( + reply_text: str, + properties: dict[str, dict[str, Any]], + ) -> dict[str, str]: + normalized_keys = { + field_name.casefold(): field_name for field_name in properties.keys() + } + parsed: dict[str, str] = {} + for line in reply_text.splitlines(): + stripped = line.strip() + if not stripped: + continue + delimiter = ":" if ":" in stripped else ("：" if "：" in stripped else None) + if delimiter is None: + continue + raw_key, raw_value = stripped.split(delimiter, 1) + field_name = normalized_keys.get(raw_key.strip().casefold()) + if field_name is None: + continue + parsed[field_name] = raw_value.strip() + return parsed + + @staticmethod + def _parse_field_patterns( + reply_text: str, + properties: dict[str, dict[str, Any]], + ) -> dict[str, str]: + parsed: dict[str, str] = {} + separators = r"[:：=]|是|为" + boundaries = r"(?:[,，;；。]|$)" + for field_name in properties: + pattern = re.compile( + rf"{re.escape(field_name)}\s*(?:{separators})\s*(.+?)(?={boundaries})", + re.IGNORECASE, + ) + match = pattern.search(reply_text) + if match: + value = match.group(1).strip().strip("`'\"") + if value: + parsed[field_name] = value + return parsed + + @staticmethod + def _match_enum_values( + reply_text: str, + properties: dict[str, dict[str, Any]], + ignore_fields: set[str] | Any, + ) -> dict[str, str]: + normalized_reply = reply_text.casefold() + parsed: dict[str, str] = {} + ignored = set(ignore_fields) + for
field_name, schema in properties.items(): + if field_name in ignored: + continue + enum_values = schema.get("enum") + if not isinstance(enum_values, list) or not enum_values: + continue + + matches = [ + str(enum_value) + for enum_value in enum_values + if str(enum_value).casefold() in normalized_reply + ] + if len(matches) == 1: + parsed[field_name] = matches[0] + return parsed + + @staticmethod + def _get_elicitation_properties( + requested_schema: dict[str, Any], + ) -> dict[str, dict[str, Any]]: + properties = requested_schema.get("properties", {}) + if not isinstance(properties, dict): + raise UnsupportedElicitationRequestError( + "Form-mode elicitation requires a top-level properties object." + ) + normalized_properties: dict[str, dict[str, Any]] = {} + for field_name, field_schema in properties.items(): + if isinstance(field_name, str) and isinstance(field_schema, dict): + normalized_properties[field_name] = field_schema + return normalized_properties + + @staticmethod + def _get_required_elicitation_fields( + requested_schema: dict[str, Any], + ) -> list[str]: + required_fields = requested_schema.get("required", []) + if not isinstance(required_fields, list): + return [] + return [field for field in required_fields if isinstance(field, str)] + + @staticmethod + def _get_elicitation_field_type(field_schema: dict[str, Any]) -> str: + field_type = field_schema.get("type", "string") + if isinstance(field_type, list): + for candidate in field_type: + if candidate in {"string", "integer", "number", "boolean", "array"}: + return candidate + raise UnsupportedElicitationRequestError( + "Unsupported multi-type elicitation field." 
+ ) + if not isinstance(field_type, str): + return "string" + return field_type + + @staticmethod + def _parse_cancel_or_decline_action(reply_text: str) -> str | None: + normalized = reply_text.strip().casefold() + if normalized in MCP_ELICITATION_CANCEL_KEYWORDS: + return "cancel" + if normalized in MCP_ELICITATION_DECLINE_KEYWORDS: + return "decline" + return None + + @staticmethod + def _parse_url_action(reply_text: str) -> str | None: + normalized = reply_text.strip().casefold() + if normalized in MCP_ELICITATION_ACCEPT_KEYWORDS: + return "accept" + if normalized in MCP_ELICITATION_DECLINE_KEYWORDS: + return "decline" + if normalized in MCP_ELICITATION_CANCEL_KEYWORDS: + return "cancel" + return None + + @staticmethod + def _strip_code_fence(text: str) -> str: + stripped = text.strip() + if not stripped.startswith("```") or not stripped.endswith("```"): + return stripped + lines = stripped.splitlines() + if len(lines) <= 2: + return stripped.removeprefix("```").removesuffix("```").strip() + return "\n".join(lines[1:-1]).strip() + + @staticmethod + def _is_webchat_event(event: Any) -> bool: + platform_name = getattr(event, "get_platform_name", None) + if callable(platform_name): + try: + return platform_name() == "webchat" + except Exception: # noqa: BLE001 + return False + return False + + def _build_root_entries(self) -> list[mcp.types.Root]: + import mcp + + roots: list[mcp.types.Root] = [] + seen_paths: set[str] = set() + for name, path in self._iter_resolved_root_paths(): + normalized_path = str(path) + if normalized_path in seen_paths: + continue + seen_paths.add(normalized_path) + roots.append( + mcp.types.Root( + uri=path.as_uri(), + name=name, + ) + ) + return roots + + def _iter_resolved_root_paths(self) -> list[tuple[str, Path]]: + configured_paths = self._capabilities.roots.paths or list( + DEFAULT_MCP_ROOT_PATHS + ) + resolved_entries: list[tuple[str, Path]] = [] + for entry in configured_paths: + resolved = self._resolve_root_path_entry(entry) + if 
resolved is not None: + resolved_entries.append(resolved) + return resolved_entries + + def _resolve_root_path_entry(self, entry: str) -> tuple[str, Path] | None: + normalized_entry = entry.strip() + if not normalized_entry: + return None + + alias_key = normalized_entry.lower() + alias_resolvers = get_root_path_alias_resolvers() + if alias_key in alias_resolvers: + path = Path(alias_resolvers[alias_key]()).resolve() + display_name = alias_key + else: + candidate_path = Path(normalized_entry).expanduser() + if not candidate_path.is_absolute(): + candidate_path = Path(get_astrbot_root()) / candidate_path + path = candidate_path.resolve() + display_name = path.name or normalized_entry + + if not path.exists(): + logger.warning( + "Skipping missing MCP root path for server %s: %s", + self._server_name, + path, + ) + return None + + return display_name, path diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 523d758a0a..de3dafdd9a 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -6,6 +6,9 @@ from dataclasses import replace from astrbot.core import logger +from astrbot.core.agent.mcp_elicitation_registry import ( + try_capture_pending_mcp_elicitation, +) from astrbot.core.agent.message import Message from astrbot.core.agent.response import AgentStats from astrbot.core.astr_main_agent import ( @@ -164,6 +167,12 @@ async def process( return logger.debug("ready to request llm provider") + if try_capture_pending_mcp_elicitation(event): + logger.info( + "Captured MCP elicitation reply for active agent run, umo=%s", + event.unified_msg_origin, + ) + return follow_up_capture = try_capture_follow_up(event) if follow_up_capture: ( diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py 
b/astrbot/core/platform/sources/webchat/webchat_event.py index b7da864aae..5ceededf10 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -66,6 +66,17 @@ async def _send( }, ) elif isinstance(comp, Json): + if message.type == "elicitation" and isinstance(comp.data, dict): + await web_chat_back_queue.put( + { + "type": "elicitation", + "data": comp.data, + "streaming": streaming, + "chain_type": message.type, + "message_id": message_id, + }, + ) + continue await web_chat_back_queue.put( { "type": "plain", diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index f950b00250..e2a9635fb7 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -8,6 +8,7 @@ import urllib.parse from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from dataclasses import dataclass +from pathlib import Path from types import MappingProxyType from typing import Any @@ -16,6 +17,12 @@ from astrbot import logger from astrbot.core import sp from astrbot.core.agent.mcp_client import MCPClient, MCPTool +from astrbot.core.agent.mcp_prompt_bridge import build_mcp_prompt_tools +from astrbot.core.agent.mcp_resource_bridge import build_mcp_resource_tools +from astrbot.core.agent.mcp_subcapability_bridge import ( + normalize_mcp_config, + normalize_mcp_server_config, +) from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -145,7 +152,10 @@ def _prepare_config(config: dict) -> dict: if config.get("mcpServers"): first_key = next(iter(config["mcpServers"])) config = config["mcpServers"][first_key] + config = normalize_mcp_server_config(config) config.pop("active", None) + config.pop("client_capabilities", None) + config.pop("provider", None) return config @@ -320,6 +330,13 @@ def get_full_tool_set(self) -> ToolSet: tool_set = 
ToolSet(self.func_list.copy()) return tool_set + def _remove_mcp_bound_tools(self, name: str) -> None: + self.func_list = [ + tool + for tool in self.func_list + if getattr(tool, "mcp_server_name", None) != name + ] + @staticmethod def _log_safe_mcp_debug_config(cfg: dict) -> None: # 仅记录脱敏后的摘要,避免泄露 command/args/url 中的敏感信息 @@ -370,18 +387,15 @@ async def init_mcp_clients( - 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值。 - 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT(独立于初始化超时)。 """ - data_dir = get_astrbot_data_path() - - mcp_json_file = os.path.join(data_dir, "mcp_server.json") - if not os.path.exists(mcp_json_file): + mcp_json_file = Path(get_astrbot_data_path()) / "mcp_server.json" + if not mcp_json_file.exists(): # 配置文件不存在错误处理 - with open(mcp_json_file, "w", encoding="utf-8") as f: + with mcp_json_file.open("w", encoding="utf-8") as f: json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") return MCPInitSummary(total=0, success=0, failed=[]) - with open(mcp_json_file, encoding="utf-8") as f: - mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"] + mcp_server_json_obj = self.load_mcp_config()["mcpServers"] init_timeout = self._init_timeout_default timeout_display = f"{init_timeout:g}" @@ -583,6 +597,8 @@ async def _init_mcp_client(self, name: str, config: dict) -> MCPClient: try: await mcp_client.connect_to_server(config, name) tools_res = await mcp_client.list_tools_and_save() + await mcp_client.load_resource_capabilities() + await mcp_client.load_prompt_capabilities() except asyncio.CancelledError: await self._cleanup_mcp_client_safely(mcp_client, name) raise @@ -591,13 +607,11 @@ async def _init_mcp_client(self, name: str, config: dict) -> MCPClient: raise logger.debug(f"MCP server {name} list tools response: {tools_res}") tool_names = [tool.name for tool in tools_res.tools] + tool_names.extend(getattr(mcp_client, "resource_bridge_tool_names", [])) + tool_names.extend(getattr(mcp_client, 
"prompt_bridge_tool_names", [])) # 移除该MCP服务之前的工具(如有) - self.func_list = [ - f - for f in self.func_list - if not (isinstance(f, MCPTool) and f.mcp_server_name == name) - ] + self._remove_mcp_bound_tools(name) # 将 MCP 工具转换为 FuncTool 并添加到 func_list for tool in mcp_client.tools: @@ -608,6 +622,11 @@ async def _init_mcp_client(self, name: str, config: dict) -> MCPClient: ) self.func_list.append(func_tool) + for resource_tool in build_mcp_resource_tools(mcp_client, name): + self.func_list.append(resource_tool) + for prompt_tool in build_mcp_prompt_tools(mcp_client, name): + self.func_list.append(prompt_tool) + logger.info(f"Connected to MCP server {name}, Tools: {tool_names}") return mcp_client @@ -620,11 +639,7 @@ async def _terminate_mcp_client(self, name: str) -> None: # 关闭MCP连接 await self._cleanup_mcp_client_safely(client, name) # 移除关联的FuncTool - self.func_list = [ - f - for f in self.func_list - if not (isinstance(f, MCPTool) and f.mcp_server_name == name) - ] + self._remove_mcp_bound_tools(name) async with self._runtime_lock: self._mcp_server_runtime.pop(name, None) self._mcp_starting.discard(name) @@ -632,11 +647,7 @@ async def _terminate_mcp_client(self, name: str) -> None: return # Runtime missing but stale tools may still exist after failed flows. 
- self.func_list = [ - f - for f in self.func_list - if not (isinstance(f, MCPTool) and f.mcp_server_name == name) - ] + self._remove_mcp_bound_tools(name) async with self._runtime_lock: self._mcp_starting.discard(name) @@ -652,7 +663,11 @@ async def test_mcp_server_connection(config: dict) -> list[str]: logger.debug(f"testing MCP server connection with config: {config}") await mcp_client.connect_to_server(config, "test") tools_res = await mcp_client.list_tools_and_save() + await mcp_client.load_resource_capabilities() + await mcp_client.load_prompt_capabilities() tool_names = [tool.name for tool in tools_res.tools] + tool_names.extend(getattr(mcp_client, "resource_bridge_tool_names", [])) + tool_names.extend(getattr(mcp_client, "prompt_bridge_tool_names", [])) finally: logger.debug("Cleaning up MCP client after testing connection.") await mcp_client.cleanup() @@ -820,28 +835,29 @@ def activate_llm_tool(self, name: str, star_map: dict) -> bool: @property def mcp_config_path(self): - data_dir = get_astrbot_data_path() - return os.path.join(data_dir, "mcp_server.json") + return Path(get_astrbot_data_path()) / "mcp_server.json" def load_mcp_config(self): - if not os.path.exists(self.mcp_config_path): + if not self.mcp_config_path.exists(): # 配置文件不存在,创建默认配置 - os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True) - with open(self.mcp_config_path, "w", encoding="utf-8") as f: + self.mcp_config_path.parent.mkdir(parents=True, exist_ok=True) + with self.mcp_config_path.open("w", encoding="utf-8") as f: json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) - return DEFAULT_MCP_CONFIG + return normalize_mcp_config(DEFAULT_MCP_CONFIG) try: - with open(self.mcp_config_path, encoding="utf-8") as f: - return json.load(f) + with self.mcp_config_path.open(encoding="utf-8") as f: + return normalize_mcp_config(json.load(f)) except Exception as e: logger.error(f"加载 MCP 配置失败: {e}") - return DEFAULT_MCP_CONFIG + return normalize_mcp_config(DEFAULT_MCP_CONFIG) def 
save_mcp_config(self, config: dict) -> bool: try: - with open(self.mcp_config_path, "w", encoding="utf-8") as f: - json.dump(config, f, ensure_ascii=False, indent=4) + normalized = normalize_mcp_config(config) + self.mcp_config_path.parent.mkdir(parents=True, exist_ok=True) + with self.mcp_config_path.open("w", encoding="utf-8") as f: + json.dump(normalized, f, ensure_ascii=False, indent=4) return True except Exception as e: logger.error(f"保存 MCP 配置失败: {e}") diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index a914f3cbf0..84c6ccb515 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -10,6 +10,9 @@ from quart import g, make_response, request, send_file from astrbot.core import logger, sp +from astrbot.core.agent.mcp_elicitation_registry import ( + submit_pending_mcp_elicitation_reply, +) from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase from astrbot.core.platform.message_type import MessageType @@ -50,6 +53,7 @@ def __init__( "/chat/sessions": ("GET", self.get_sessions), "/chat/get_session": ("GET", self.get_session), "/chat/stop": ("POST", self.stop_session), + "/chat/respond_elicitation": ("POST", self.respond_elicitation), "/chat/delete_session": ("GET", self.delete_webchat_session), "/chat/update_session_display_name": ( "POST", @@ -73,6 +77,18 @@ def __init__( self.running_convs: dict[str, bool] = {} + @staticmethod + def _build_webchat_session_umo(session) -> str: + message_type = ( + MessageType.GROUP_MESSAGE.value + if session.is_group + else MessageType.FRIEND_MESSAGE.value + ) + return ( + f"{session.platform_id}:{message_type}:" + f"{session.platform_id}!{session.creator}!{session.session_id}" + ) + async def get_file(self): filename = request.args.get("filename") if not filename: @@ -447,6 +463,19 @@ async def stream(): ) if part: accumulated_parts.append(part) + elif msg_type == "elicitation": + if accumulated_text: + 
accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + if isinstance(result_text, dict): + accumulated_parts.append( + { + "type": "elicitation", + "payload": result_text, + } + ) # 消息结束处理 if msg_type == "end": @@ -565,19 +594,72 @@ async def stop_session(self): if session.creator != username: return Response().error("Permission denied").__dict__ - message_type = ( - MessageType.GROUP_MESSAGE.value - if session.is_group - else MessageType.FRIEND_MESSAGE.value - ) - umo = ( - f"{session.platform_id}:{message_type}:" - f"{session.platform_id}!{username}!{session_id}" - ) + umo = self._build_webchat_session_umo(session) stopped_count = active_event_registry.request_agent_stop_all(umo) return Response().ok(data={"stopped_count": stopped_count}).__dict__ + async def respond_elicitation(self): + post_data = await request.json + if post_data is None: + return Response().error("Missing JSON body").__dict__ + + session_id = str(post_data.get("session_id", "")).strip() + reply_text = str(post_data.get("reply_text", "")).strip() + display_text = str(post_data.get("display_text", reply_text)).strip() + if not session_id: + return Response().error("Missing key: session_id").__dict__ + if not reply_text: + return Response().error("Missing key: reply_text").__dict__ + + username = g.get("username", "guest") + session = await self.db.get_platform_session_by_id(session_id) + if not session: + return Response().error(f"Session {session_id} not found").__dict__ + if session.creator != username: + return Response().error("Permission denied").__dict__ + + umo = self._build_webchat_session_umo(session) + if not submit_pending_mcp_elicitation_reply( + umo, + username, + reply_text, + reply_outline=display_text, + ): + return ( + Response().error("No pending MCP elicitation for this session").__dict__ + ) + + saved_record = await self.platform_history_mgr.insert( + platform_id=session.platform_id, + user_id=session_id, + content={ + "type": 
"user", + "message": [{"type": "plain", "text": display_text or reply_text}], + }, + sender_id=username, + sender_name=username, + ) + + return ( + Response() + .ok( + data={ + "saved_message": { + "id": saved_record.id, + "created_at": to_utc_isoformat(saved_record.created_at), + "content": { + "type": "user", + "message": [ + {"type": "plain", "text": display_text or reply_text} + ], + }, + } + } + ) + .__dict__ + ) + async def delete_webchat_session(self): """Delete a Platform session and all its related data.""" session_id = request.args.get("session_id") diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 8d0af938d0..275c9afa85 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -576,6 +576,19 @@ async def _handle_chat_message( part = await self._create_attachment_from_file(filename, "video") if part: accumulated_parts.append(part) + elif msg_type == "elicitation": + if accumulated_text: + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + if isinstance(result_text, dict): + accumulated_parts.append( + { + "type": "elicitation", + "payload": result_text, + } + ) should_save = False if msg_type == "end": diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index 9a736b1763..2f96c61b82 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -451,6 +451,19 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: ) if part: accumulated_parts.append(part) + elif msg_type == "elicitation": + if accumulated_text: + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + if isinstance(result_text, dict): + accumulated_parts.append( + { + "type": "elicitation", + "payload": result_text, + } + ) if msg_type == "end": break diff --git a/astrbot/dashboard/routes/tools.py 
b/astrbot/dashboard/routes/tools.py index 84f8dcc6d7..19f6bf896c 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -104,7 +104,13 @@ async def get_mcp_servers(self): for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items(): if name_key == name: mcp_client = runtime.client - server_info["tools"] = [tool.name for tool in mcp_client.tools] + server_info["tools"] = ( + [tool.name for tool in mcp_client.tools] + + list( + getattr(mcp_client, "resource_bridge_tool_names", []) + ) + + list(getattr(mcp_client, "prompt_bridge_tool_names", [])) + ) server_info["errlogs"] = mcp_client.server_errlogs break else: @@ -431,7 +437,7 @@ async def get_tool_list(self): tools = self.tool_mgr.func_list tools_dict = [] for tool in tools: - if isinstance(tool, MCPTool): + if isinstance(tool, MCPTool) or getattr(tool, "mcp_server_name", None): origin = "mcp" origin_name = tool.mcp_server_name elif tool.handler_module_path and star_map.get( diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 7c25e1bc3a..3d9c60ba69 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -51,6 +51,7 @@ { + messageList.value?.scrollToBottom(); + }); +} + // 路由变化监听 watch( () => route.path, diff --git a/dashboard/src/components/chat/MessageList.vue b/dashboard/src/components/chat/MessageList.vue index 42da95aa58..80041efd3a 100644 --- a/dashboard/src/components/chat/MessageList.vue +++ b/dashboard/src/components/chat/MessageList.vue @@ -97,6 +97,8 @@ @@ -221,6 +223,10 @@ export default { isLoadingMessages: { type: Boolean, default: false + }, + submitElicitation: { + type: Function, + default: null } }, emits: ['openImagePreview', 'replyMessage', 'replyWithText', 'openRefs'], @@ -417,6 +423,14 @@ export default { return messageParts.some(part => part.type === 'record' && part.embedded_url); }, + isActiveElicitationMessage(index, msg) { + if (!this.isStreaming || index 
!== this.messages.length - 1) { + return false; + } + return Array.isArray(msg?.content?.message) + && msg.content.message.some(part => part.type === 'elicitation' && part.payload); + }, + // 获取被引用消息的内容 getReplyContent(messageId) { const replyMsg = this.messages.find(m => m.id === messageId); diff --git a/dashboard/src/components/chat/StandaloneChat.vue b/dashboard/src/components/chat/StandaloneChat.vue index 69fac13f9b..5c96e56b42 100644 --- a/dashboard/src/components/chat/StandaloneChat.vue +++ b/dashboard/src/components/chat/StandaloneChat.vue @@ -5,7 +5,8 @@
@@ -158,6 +159,7 @@ const { enableStreaming, getSessionMessages: getSessionMsg, sendMessage: sendMsg, + submitElicitationResponse, stopMessage: stopMsg, toggleStreaming } = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); @@ -243,6 +245,17 @@ async function handleStopMessage() { await stopMsg(); } +async function handleSubmitElicitation(replyText: string, displayText: string) { + if (!currSessionId.value) { + return; + } + + await submitElicitationResponse(currSessionId.value, replyText, displayText); + nextTick(() => { + messageList.value?.scrollToBottom(); + }); +} + onMounted(async () => { // 独立模式在挂载时创建新会话 try { diff --git a/dashboard/src/components/chat/message_list_comps/ElicitationCard.vue b/dashboard/src/components/chat/message_list_comps/ElicitationCard.vue new file mode 100644 index 0000000000..a7e6f679af --- /dev/null +++ b/dashboard/src/components/chat/message_list_comps/ElicitationCard.vue @@ -0,0 +1,352 @@ + + + + + diff --git a/dashboard/src/components/chat/message_list_comps/MessagePartsRenderer.vue b/dashboard/src/components/chat/message_list_comps/MessagePartsRenderer.vue index 5fd7f59bed..4c9a953b28 100644 --- a/dashboard/src/components/chat/message_list_comps/MessagePartsRenderer.vue +++ b/dashboard/src/components/chat/message_list_comps/MessagePartsRenderer.vue @@ -60,6 +60,14 @@ + + import { useI18n, useModuleI18n } from '@/i18n/composables'; import { MarkdownRender } from 'markstream-vue'; +import ElicitationCard from './ElicitationCard.vue'; import IPythonToolBlock from './IPythonToolBlock.vue'; import ToolCallItem from './ToolCallItem.vue'; @@ -136,6 +145,14 @@ const props = defineProps({ downloadingFiles: { type: Object, default: () => new Set() + }, + interactiveElicitation: { + type: Boolean, + default: false + }, + submitElicitation: { + type: Function, + default: null } }); diff --git a/dashboard/src/composables/useMessages.ts b/dashboard/src/composables/useMessages.ts index c593fb283b..9c9d5ca6ba 100644 --- 
a/dashboard/src/composables/useMessages.ts +++ b/dashboard/src/composables/useMessages.ts @@ -34,14 +34,33 @@ export interface FileInfo { attachment_id?: string; // 用于按需下载 } +export interface ElicitationField { + name: string; + label: string; + description?: string; + required: boolean; + type: string; + enum?: string[]; +} + +export interface ElicitationPayload { + kind: 'form' | 'url'; + server_name: string; + message: string; + prompt: string; + url?: string; + fields?: ElicitationField[]; +} + // 消息部分的类型定义 export interface MessagePart { - type: 'plain' | 'image' | 'record' | 'file' | 'video' | 'reply' | 'tool_call'; + type: 'plain' | 'image' | 'record' | 'file' | 'video' | 'reply' | 'tool_call' | 'elicitation'; text?: string; // for plain attachment_id?: string; // for image, record, file, video filename?: string; // for file (filename from backend) message_id?: number; // for reply (PlatformSessionHistoryMessage.id) tool_calls?: ToolCall[]; // for tool_call + payload?: ElicitationPayload; // for elicitation // embedded fields - 加载后填充 embedded_url?: string; // blob URL for image, record embedded_file?: FileInfo; // for file (保留 attachment_id 用于按需下载) @@ -84,6 +103,35 @@ type StreamChunk = { [key: string]: any; }; +function normalizeElicitationPayload(payload: any): ElicitationPayload | null { + if (!payload || typeof payload !== 'object') { + return null; + } + + const kind = payload.kind === 'url' ? 'url' : 'form'; + const fields = Array.isArray(payload.fields) + ? payload.fields + .filter((field: any) => field && typeof field === 'object' && typeof field.name === 'string') + .map((field: any) => ({ + name: String(field.name), + label: String(field.label || field.name), + description: String(field.description || ''), + required: Boolean(field.required), + type: String(field.type || 'string'), + enum: Array.isArray(field.enum) ? 
field.enum.map((value: any) => String(value)) : [] + })) + : []; + + return { + kind, + server_name: String(payload.server_name || ''), + message: String(payload.message || ''), + prompt: String(payload.prompt || ''), + url: typeof payload.url === 'string' ? payload.url : undefined, + fields + }; +} + type WsStreamContext = { handleChunk: (payload: StreamChunk) => Promise; finish: (err?: unknown) => void; @@ -207,6 +255,24 @@ export function useMessages( return; } + if (payload.type === 'elicitation') { + const normalizedPayload = normalizeElicitationPayload(payload.data); + if (!normalizedPayload) { + return; + } + + messages.value.push({ + content: { + type: 'bot', + message: [{ + type: 'elicitation', + payload: normalizedPayload + }] + } + }); + return; + } + if (payload.type === 'plain') { const chainType = payload.chain_type || 'normal'; if (chainType === 'reasoning') { @@ -440,6 +506,24 @@ export function useMessages( return; } + if (chunkJson.type === 'elicitation') { + const normalizedPayload = normalizeElicitationPayload(chunkJson.data); + if (!normalizedPayload) { + return; + } + + messages.value.push({ + content: { + type: 'bot', + message: [{ + type: 'elicitation', + payload: normalizedPayload + }] + } + }); + return; + } + if (chunkJson.type === 'image') { const img = String(chunkJson.data || '').replace('[IMAGE]', ''); const imageUrl = await getMediaFile(img); @@ -671,6 +755,9 @@ export function useMessages( }; } // plain, reply, tool_call, video 保持原样 + if (part.type === 'elicitation') { + part.payload = normalizeElicitationPayload(part.payload) || undefined; + } } } @@ -1003,6 +1090,49 @@ export function useMessages( } } + async function submitElicitationResponse( + sessionId: string, + replyText: string, + displayText: string + ) { + const normalizedSessionId = sessionId.trim(); + const normalizedReplyText = replyText.trim(); + const normalizedDisplayText = (displayText || replyText).trim(); + if (!normalizedSessionId || !normalizedReplyText) { + 
throw new Error('Missing elicitation reply payload'); + } + + const response = await axios.post('/api/chat/respond_elicitation', { + session_id: normalizedSessionId, + reply_text: normalizedReplyText, + display_text: normalizedDisplayText + }); + if (response.data?.status !== 'ok') { + throw new Error(response.data?.message || 'Failed to submit elicitation reply'); + } + + const savedMessage = response.data?.data?.saved_message; + if (savedMessage?.content) { + await parseMessageContent(savedMessage.content); + messages.value.push({ + id: savedMessage.id, + created_at: savedMessage.created_at, + content: savedMessage.content + }); + return savedMessage; + } + + const fallbackMessage: MessageContent = { + type: 'user', + message: [{ + type: 'plain', + text: normalizedDisplayText + }] + }; + messages.value.push({ content: fallbackMessage }); + return null; + } + async function stopMessage() { const sessionId = currentRunningSessionId.value || currSessionId.value; if (!sessionId) { @@ -1057,6 +1187,7 @@ export function useMessages( currentSessionProject, getSessionMessages, sendMessage, + submitElicitationResponse, stopMessage, toggleStreaming, setTransportMode, diff --git a/dashboard/src/i18n/locales/en-US/features/chat.json b/dashboard/src/i18n/locales/en-US/features/chat.json index 53e98c991a..5d1e79eb6a 100644 --- a/dashboard/src/i18n/locales/en-US/features/chat.json +++ b/dashboard/src/i18n/locales/en-US/features/chat.json @@ -97,6 +97,21 @@ "replyTo": "Reply to", "notFound": "Message not found" }, + "elicitation": { + "title": "{server} needs more information", + "submit": "Submit", + "cancel": "Cancel", + "decline": "Decline", + "done": "Done", + "otherInput": "If none fit, enter a custom reply", + "arrayPlaceholder": "Use commas or new lines for multiple values", + "requiredField": "Please fill in {field}", + "emptyReply": "Please provide at least one field", + "submitted": "Submitted. 
Waiting for AstrBot to continue", + "submitFailed": "Failed to submit the elicitation reply", + "accepted": "Accepted", + "booleanLabel": "Turn on for true, off for false" + }, "project": { "title": "Projects", "create": "Create Project", diff --git a/dashboard/src/i18n/locales/zh-CN/features/chat.json b/dashboard/src/i18n/locales/zh-CN/features/chat.json index 9d60db1767..74b91e75d8 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/chat.json +++ b/dashboard/src/i18n/locales/zh-CN/features/chat.json @@ -97,6 +97,21 @@ "replyTo": "引用", "notFound": "无法定位消息" }, + "elicitation": { + "title": "{server} 需要更多信息", + "submit": "提交", + "cancel": "取消", + "decline": "拒绝", + "done": "已完成", + "otherInput": "如果都不合适,可手动输入", + "arrayPlaceholder": "可使用逗号或换行分隔多个值", + "requiredField": "请填写 {field}", + "emptyReply": "请至少填写一个字段", + "submitted": "已提交,等待 AstrBot 继续处理", + "submitFailed": "提交 elicitation 回复失败", + "accepted": "已确认", + "booleanLabel": "打开以表示 true,关闭表示 false" + }, "project": { "title": "项目", "create": "创建项目", diff --git a/docs/en/use/mcp.md b/docs/en/use/mcp.md index 1681d31325..cebb5cde13 100644 --- a/docs/en/use/mcp.md +++ b/docs/en/use/mcp.md @@ -96,6 +96,100 @@ Configure it in the AstrBot WebUI: That's it. +## MCP Client Sub-Capabilities + +AstrBot now supports enabling MCP client sub-capabilities per server. The first integrated sub-capability is `sampling`, which allows an MCP server to send `sampling/createMessage` requests during a tool call and reuse the current bot's chat provider, persona context, and session origin. + +AstrBot also supports per-server `elicitation`. When enabled, an MCP server can ask the current user for missing input or require an external confirmation step while the current MCP tool call is still in progress. + +AstrBot also supports per-server `roots`. When enabled, an MCP server can call `roots/list` to learn which local file roots AstrBot is explicitly exposing to it. 
+ +You can add the following to an MCP server configuration: + +```json +{ + "url": "https://example.com/mcp", + "transport": "sse", + "client_capabilities": { + "elicitation": { + "enabled": true, + "timeout_seconds": 300 + }, + "sampling": { + "enabled": true + }, + "roots": { + "enabled": true, + "paths": ["data", "temp"] + } + } +} +``` + +Current limitations: + +- `elicitation` is disabled by default and is only advertised for servers that explicitly set `enabled: true`. +- `elicitation.timeout_seconds` bounds the total time AstrBot waits for the user's reply; timeouts are returned as `cancel`. +- `elicitation` is only available while the MCP server is actively serving the current bot tool call; requests outside the active interaction context are rejected. +- `elicitation` currently uses plain chat messages instead of a dedicated form UI. +- Form-mode elicitation currently supports flat top-level fields with simple types: `string`, `integer`, `number`, `boolean`, and `array[string]`. +- For single-field form elicitation, the user can reply with plain text. For multi-field form elicitation, AstrBot accepts JSON or `field: value` lines. +- URL-mode elicitation currently sends the URL and instructions to the user, then waits for a chat reply such as `done`, `decline`, or `cancel`. +- Enabling `elicitation` only affects the configured MCP server and does not change other MCP servers or standard chat flows. +- `sampling` is disabled by default and is only advertised for servers that explicitly set `enabled: true`. +- `roots` is disabled by default and is only advertised for servers that explicitly set `enabled: true`. +- `roots.paths` can contain built-in aliases such as `data`, `temp`, `config`, `skills`, `plugins`, `plugin_data`, `knowledge_base`, `backups`, and `root`, as well as absolute paths or paths relative to AstrBot root. +- If `roots.enabled` is `true` and `paths` is omitted, AstrBot currently exposes `data` and `temp` as the default safe roots. 
+- `sampling` is only available while the MCP server is actively serving the current bot tool call; requests outside the active interaction context are rejected. +- The initial implementation only returns text sampling results. +- Tool-assisted sampling and multimodal sampling inputs such as image or audio are not supported yet. +- Enabling `sampling` only affects the configured MCP server and does not change other MCP servers or standard chat flows. +- Enabling `roots` only affects the configured MCP server and does not change other MCP servers or standard chat flows. + +## Notes for stdio servers + +When using an MCP server over stdio, the server should reserve stdout for JSON-RPC protocol messages and write logs to stderr. + +AstrBot now tolerates blank lines and common launcher banners such as `npm run` output to reduce noise during local testing, but that behavior is only a compatibility fallback. The robust setup is still to keep stdout protocol-only. + +## MCP Resources Bridge + +If an MCP server advertises the `resources` capability during initialization, AstrBot now registers a small set of bridge tools so the bot can interact with those resources through the existing tool loop. + +The first iteration exposes these server-scoped tools: + +- `mcp__list_resources` +- `mcp__read_resource` +- `mcp__list_resource_templates` (only when the server supports listing resource templates) + +This lets the bot discover available resources and read a specific resource URI without changing the normal chat flow or provider integration. + +Current limitations: + +- The bridge is read-only and does not support `resources/subscribe` push updates yet. +- AstrBot does not auto-inject MCP resources into prompt context; the bot still needs to read them explicitly through tools. +- A single text resource is returned as text. +- A single image blob resource is returned as an image-style tool result. 
+- Multi-part resources, mixed results, and non-image binary blobs are summarized into text in the first iteration. + +## MCP Prompts Bridge + +If an MCP server advertises the `prompts` capability during initialization, AstrBot also registers a small set of bridge tools so the bot can discover and fetch MCP prompts through the existing tool loop. + +The first iteration exposes these server-scoped tools: + +- `mcp__list_prompts` +- `mcp__get_prompt` + +This lets the bot inspect available prompt templates and resolve a specific prompt by name with optional arguments, without changing the normal chat flow or provider integration. + +Current limitations: + +- AstrBot does not auto-inject MCP prompts into the active chat context; the bot still needs to fetch them explicitly through tools. +- `get_prompt` results are currently summarized into text, preserving descriptions, message roles, and text blocks. +- Non-text prompt blocks such as images, audio, and embedded resources are summarized into text in the first iteration instead of being converted into multimodal context. +- The bridge does not support `prompts/list_changed` push updates yet, and it does not use MCP completions to auto-complete prompt arguments. + Reference links: 1. 
Learn how to use MCP here: [Model Context Protocol](https://modelcontextprotocol.io/introduction) diff --git a/docs/zh/use/mcp.md b/docs/zh/use/mcp.md index 79e3757fda..929c8c8176 100644 --- a/docs/zh/use/mcp.md +++ b/docs/zh/use/mcp.md @@ -95,6 +95,100 @@ npx -v 即可。 +## MCP Client 子能力 + +AstrBot 现在支持按服务开启 MCP Client 子能力。当前首个接入的子能力是 `sampling`,它允许 MCP Server 在工具调用过程中向 AstrBot 发起 `sampling/createMessage` 请求,并复用当前 Bot 的聊天模型、人格上下文与会话来源。 + +目前也支持按服务开启 `elicitation`。启用后,MCP Server 可以在工具调用过程中向当前用户追问缺失信息,或要求用户完成一个外部确认步骤,再继续当前 MCP 交互。 + +目前也支持按服务开启 `roots`。启用后,MCP Server 可以通过 `roots/list` 请求 AstrBot 暴露给它的文件根目录列表。 + +您可以在对应 MCP Server 配置中加入: + +```json +{ + "url": "https://example.com/mcp", + "transport": "sse", + "client_capabilities": { + "elicitation": { + "enabled": true, + "timeout_seconds": 300 + }, + "sampling": { + "enabled": true + }, + "roots": { + "enabled": true, + "paths": ["data", "temp"] + } + } +} +``` + +当前版本的限制: + +- `elicitation` 默认关闭,只有显式配置 `enabled: true` 的服务器才会声明该能力。 +- `elicitation.timeout_seconds` 用于限制等待用户回复的总时长;超时后当前 elicitation 会以 `cancel` 结束。 +- `elicitation` 当前只在对应 MCP Server 正在为当前 Bot 执行工具调用时可用;脱离当前交互上下文的请求会被拒绝。 +- `elicitation` 当前通过聊天消息交互,不提供独立表单 UI。 +- `elicitation` 的 form 模式当前支持顶层简单字段:`string`、`integer`、`number`、`boolean`、`array[string]`。 +- `elicitation` 的 form 模式在单字段时可直接回复纯文本,多字段时可回复 JSON 或 `field: value` 形式。 +- `elicitation` 的 url 模式当前会把 URL 和提示文本发送给用户,并等待用户在聊天中回复 `done` / `decline` / `cancel` 之类的确认。 +- 开启 `elicitation` 的影响范围仅限对应的 MCP Server,不会改变其他 MCP Server 或普通聊天流程。 +- `sampling` 默认关闭,只有显式配置 `enabled: true` 的服务器才会声明该能力。 +- `roots` 默认关闭,只有显式配置 `enabled: true` 的服务器才会声明该能力。 +- `roots.paths` 可填写内置别名(如 `data`、`temp`、`config`、`skills`、`plugins`、`plugin_data`、`knowledge_base`、`backups`、`root`)、绝对路径,或相对 AstrBot 根目录的路径。 +- 当 `roots.enabled` 为 `true` 且未显式填写 `paths` 时,AstrBot 当前默认暴露 `data` 和 `temp`。 +- `sampling` 仅在该 MCP Server 正在为当前 Bot 执行工具调用时可用;脱离当前交互上下文的请求会被拒绝。 +- 当前版本仅支持文本采样结果。 +- 当前版本不支持带工具的 sampling,也不支持图片、音频等多模态 sampling 输入。 +- 开启 
`sampling` 的影响范围仅限对应的 MCP Server,不会改变其他 MCP Server 或普通聊天流程。 +- 开启 `roots` 的影响范围仅限对应的 MCP Server,不会改变其他 MCP Server 或普通聊天流程。 + +## stdio 服务输出说明 + +使用 stdio 方式接入 MCP Server 时,服务器应当只通过标准输出(stdout)发送 JSON-RPC 协议消息,并将日志写入标准错误(stderr)。 + +AstrBot 当前会尽量忽略空行以及 `npm run` 一类启动器输出的非协议横幅,减少测试时的噪声;但这只是兼容处理,不建议依赖。更稳妥的做法仍然是让 MCP Server 或启动脚本保持 stdout 干净。 + +## MCP Resources 桥接 + +如果某个 MCP Server 在初始化时声明了 `resources` 能力,AstrBot 会自动为它注册一组桥接工具,让 Bot 通过现有的工具调用流程与资源交互。 + +当前首版会按服务注册这些工具: + +- `mcp__list_resources` +- `mcp__read_resource` +- `mcp__list_resource_templates`(仅在服务支持模板列表时出现) + +这意味着 Bot 可以先列出资源,再读取某个具体 URI 的内容,而不需要修改普通聊天流程或额外配置 Provider。 + +当前版本的限制: + +- 当前只做只读桥接,不支持 `resources/subscribe` 订阅与推送刷新。 +- 当前不会自动把 MCP 资源注入提示词上下文,仍然需要 Bot 通过工具主动读取。 +- 单个文本资源会直接作为文本结果返回。 +- 单个图片 Blob 资源会按图片结果处理。 +- 多段资源、混合类型资源以及非图片二进制资源,当前会被整理为文本摘要返回。 + +## MCP Prompts 桥接 + +如果某个 MCP Server 在初始化时声明了 `prompts` 能力,AstrBot 也会自动为它注册一组桥接工具,让 Bot 通过现有的工具调用流程查看和获取 MCP prompt。 + +当前首版会按服务注册这些工具: + +- `mcp__list_prompts` +- `mcp__get_prompt` + +这意味着 Bot 可以先列出可用 prompt,再按名称和参数获取某个 prompt 的展开结果,而不需要修改普通聊天流程或额外配置 Provider。 + +当前版本的限制: + +- 当前不会自动把 MCP prompt 注入对话上下文,仍然需要 Bot 通过工具主动获取。 +- `get_prompt` 的结果当前会整理为文本摘要,保留描述、消息角色和文本内容。 +- 图片、音频、嵌入资源等非纯文本 prompt block,当前会被总结为文本说明,而不是直接转成多模态上下文。 +- 当前不支持 `prompts/list_changed` 推送刷新,也不支持通过 MCP completion 自动补全 prompt 参数。 + 参考链接: 1. 
在这里了解如何使用 MCP: [Model Context Protocol](https://modelcontextprotocol.io/introduction) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 6c575910a0..83e27cc231 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -5,6 +5,7 @@ import zipfile from datetime import datetime from types import SimpleNamespace +from unittest.mock import AsyncMock import pytest import pytest_asyncio @@ -275,6 +276,298 @@ async def test_commands_api(app: Quart, authenticated_header: dict): assert isinstance(data["data"], list) +@pytest.mark.asyncio +async def test_mcp_servers_api_exposes_client_capabilities( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + tool_mgr = core_lifecycle_td.provider_manager.llm_tools + monkeypatch.setattr( + tool_mgr, + "load_mcp_config", + lambda: { + "mcpServers": { + "demo": { + "url": "https://example.com/mcp", + "transport": "sse", + "active": True, + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 180, + }, + "sampling": { + "enabled": True, + } + }, + } + } + }, + ) + + test_client = app.test_client() + response = await test_client.get( + "/api/tools/mcp/servers", + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + assert data["data"][0]["client_capabilities"]["elicitation"]["enabled"] is True + assert data["data"][0]["client_capabilities"]["elicitation"]["timeout_seconds"] == 180 + assert data["data"][0]["client_capabilities"]["sampling"]["enabled"] is True + + +@pytest.mark.asyncio +async def test_mcp_servers_api_includes_resource_bridge_tools( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + tool_mgr = core_lifecycle_td.provider_manager.llm_tools + monkeypatch.setattr( + tool_mgr, + "load_mcp_config", + lambda: { + "mcpServers": { + "demo": { + "command": "node", + "args": 
["stdio.js"], + "active": True, + } + } + }, + ) + tool_mgr._mcp_server_runtime["demo"] = SimpleNamespace( + client=SimpleNamespace( + tools=[SimpleNamespace(name="demo_tool")], + resource_bridge_tool_names=["mcp_demo_list_resources"], + server_errlogs=[], + ) + ) + + test_client = app.test_client() + try: + response = await test_client.get( + "/api/tools/mcp/servers", + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + assert data["data"][0]["tools"] == [ + "demo_tool", + "mcp_demo_list_resources", + ] + finally: + tool_mgr._mcp_server_runtime.pop("demo", None) + + +@pytest.mark.asyncio +async def test_mcp_servers_api_includes_prompt_bridge_tools( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + tool_mgr = core_lifecycle_td.provider_manager.llm_tools + monkeypatch.setattr( + tool_mgr, + "load_mcp_config", + lambda: { + "mcpServers": { + "demo": { + "command": "node", + "args": ["stdio.js"], + "active": True, + } + } + }, + ) + tool_mgr._mcp_server_runtime["demo"] = SimpleNamespace( + client=SimpleNamespace( + tools=[SimpleNamespace(name="demo_tool")], + resource_bridge_tool_names=[], + prompt_bridge_tool_names=["mcp_demo_list_prompts", "mcp_demo_get_prompt"], + server_errlogs=[], + ) + ) + + test_client = app.test_client() + try: + response = await test_client.get( + "/api/tools/mcp/servers", + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + assert data["data"][0]["tools"] == [ + "demo_tool", + "mcp_demo_list_prompts", + "mcp_demo_get_prompt", + ] + finally: + tool_mgr._mcp_server_runtime.pop("demo", None) + + +@pytest.mark.asyncio +async def test_add_mcp_server_persists_client_capabilities( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + tmp_path, + monkeypatch, +): + 
monkeypatch.setattr( + "astrbot.core.provider.func_tool_manager.get_astrbot_data_path", + lambda: str(tmp_path), + ) + tool_mgr = core_lifecycle_td.provider_manager.llm_tools + monkeypatch.setattr( + tool_mgr, + "test_mcp_server_connection", + AsyncMock(return_value=["demo_tool"]), + ) + monkeypatch.setattr( + tool_mgr, + "enable_mcp_server", + AsyncMock(return_value=None), + ) + + test_client = app.test_client() + response = await test_client.post( + "/api/tools/mcp/add", + json={ + "name": "demo", + "url": "https://example.com/mcp", + "transport": "sse", + "active": True, + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 240, + }, + "sampling": { + "enabled": True, + } + }, + }, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + persisted = tool_mgr.load_mcp_config() + assert ( + persisted["mcpServers"]["demo"]["client_capabilities"]["elicitation"][ + "enabled" + ] + is True + ) + assert ( + persisted["mcpServers"]["demo"]["client_capabilities"]["elicitation"][ + "timeout_seconds" + ] + == 240 + ) + assert ( + persisted["mcpServers"]["demo"]["client_capabilities"]["sampling"]["enabled"] + is True + ) + + +@pytest.mark.asyncio +async def test_chat_respond_elicitation_resolves_pending_reply_and_persists_message( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + captured: dict[str, str] = {} + + def _fake_submit(umo: str, sender_id: str, reply_text: str, *, reply_outline=None): + captured["umo"] = umo + captured["sender_id"] = sender_id + captured["reply_text"] = reply_text + captured["reply_outline"] = reply_outline or "" + return True + + monkeypatch.setattr( + "astrbot.dashboard.routes.chat.submit_pending_mcp_elicitation_reply", + _fake_submit, + ) + + test_client = app.test_client() + new_session_response = await test_client.get( + "/api/chat/new_session", + 
headers=authenticated_header, + ) + session_data = await new_session_response.get_json() + session_id = session_data["data"]["session_id"] + + response = await test_client.post( + "/api/chat/respond_elicitation", + json={ + "session_id": session_id, + "reply_text": '{"topic":"MCP 最小实现"}', + "display_text": "topic: MCP 最小实现", + }, + headers=authenticated_header, + ) + + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + expected_username = core_lifecycle_td.astrbot_config["dashboard"]["username"] + assert captured["sender_id"] == expected_username + assert captured["reply_text"] == '{"topic":"MCP 最小实现"}' + assert captured["reply_outline"] == "topic: MCP 最小实现" + assert captured["umo"].endswith(f"webchat!{expected_username}!{session_id}") + assert data["data"]["saved_message"]["content"]["message"] == [ + {"type": "plain", "text": "topic: MCP 最小实现"} + ] + + +@pytest.mark.asyncio +async def test_chat_respond_elicitation_rejects_when_no_pending_request( + app: Quart, + authenticated_header: dict, + monkeypatch, +): + monkeypatch.setattr( + "astrbot.dashboard.routes.chat.submit_pending_mcp_elicitation_reply", + lambda *_args, **_kwargs: False, + ) + + test_client = app.test_client() + new_session_response = await test_client.get( + "/api/chat/new_session", + headers=authenticated_header, + ) + session_data = await new_session_response.get_json() + session_id = session_data["data"]["session_id"] + + response = await test_client.post( + "/api/chat/respond_elicitation", + json={ + "session_id": session_id, + "reply_text": "cancel", + "display_text": "cancel", + }, + headers=authenticated_header, + ) + + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "error" + assert "No pending MCP elicitation" in data["message"] + + @pytest.mark.asyncio async def test_check_update( app: Quart, diff --git a/tests/unit/test_mcp_prompt_bridge.py b/tests/unit/test_mcp_prompt_bridge.py new 
file mode 100644 index 0000000000..14a6c57894 --- /dev/null +++ b/tests/unit/test_mcp_prompt_bridge.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import mcp +import pytest + +from astrbot.core.agent.mcp_prompt_bridge import ( + MCPGetPromptTool, + MCPListPromptsTool, + build_mcp_prompt_tool_names, + shape_get_prompt_result, +) +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.provider.func_tool_manager import ( + FunctionToolManager, + _MCPServerRuntime, +) + + +class _FakePromptCapableMCPClient: + def __init__(self) -> None: + self.name: str | None = None + self.tools = [ + mcp.types.Tool( + name="draft_brief", + description="Draft a short brief.", + inputSchema={"type": "object", "properties": {}}, + ) + ] + self.resource_bridge_tool_names: list[str] = [] + self.prompt_bridge_tool_names: list[str] = [] + self.prompts: list[mcp.types.Prompt] = [] + self.server_errlogs: list[str] = [] + self.received_arguments: list[dict[str, str] | None] = [] + + @property + def supports_prompts(self) -> bool: + return True + + @property + def supports_resources(self) -> bool: + return False + + async def connect_to_server(self, config: dict, name: str) -> None: + _ = config + self.name = name + + async def list_tools_and_save(self) -> mcp.types.ListToolsResult: + return mcp.types.ListToolsResult(tools=self.tools) + + async def load_resource_capabilities(self) -> None: + self.resource_bridge_tool_names = [] + + async def load_prompt_capabilities(self) -> None: + self.prompt_bridge_tool_names = build_mcp_prompt_tool_names( + self.name or "server" + ) + + async def list_prompts_and_save( + self, + cursor: str | None = None, + ) -> mcp.types.ListPromptsResult: + _ = cursor + self.prompts = [ + mcp.types.Prompt( + name="draft_brief", + description="Draft a short brief from a topic.", + arguments=[ + mcp.types.PromptArgument( + name="topic", + description="Topic to summarize", + 
required=True, + ) + ], + ) + ] + return mcp.types.ListPromptsResult(prompts=self.prompts) + + async def get_prompt_with_reconnect( + self, + name: str, + arguments: dict[str, str] | None, + read_timeout_seconds, + ) -> mcp.types.GetPromptResult: + _ = read_timeout_seconds + self.received_arguments.append(arguments) + return mcp.types.GetPromptResult( + description=f"Prompt '{name}' resolved.", + messages=[ + mcp.types.PromptMessage( + role="user", + content=mcp.types.TextContent( + type="text", + text=f"Write a concise brief about {arguments.get('topic', 'the topic') if arguments else 'the topic'}.", + ), + ) + ], + ) + + async def cleanup(self) -> None: + return None + + +@pytest.mark.asyncio +async def test_prompt_bridge_tools_are_registered_and_removed( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr( + "astrbot.core.provider.func_tool_manager.MCPClient", + _FakePromptCapableMCPClient, + ) + + tool_mgr = FunctionToolManager() + client = await tool_mgr._init_mcp_client( + "demo-server", + {"command": "node", "args": ["stdio.js"]}, + ) + + tool_names = {tool.name for tool in tool_mgr.func_list} + assert "draft_brief" in tool_names + assert "mcp_demo_server_list_prompts" in tool_names + assert "mcp_demo_server_get_prompt" in tool_names + + completed_task = asyncio.create_task(asyncio.sleep(0)) + await completed_task + tool_mgr._mcp_server_runtime["demo-server"] = _MCPServerRuntime( + name="demo-server", + client=client, + shutdown_event=asyncio.Event(), + lifecycle_task=completed_task, + ) + + await tool_mgr._terminate_mcp_client("demo-server") + + assert not any( + getattr(tool, "mcp_server_name", None) == "demo-server" + for tool in tool_mgr.func_list + ) + + +@pytest.mark.asyncio +async def test_prompt_listing_tool_returns_text_summary(): + client = _FakePromptCapableMCPClient() + tool = MCPListPromptsTool( + mcp_client=client, + mcp_server_name="demo-server", + ) + context = ContextWrapper(context=SimpleNamespace()) + + result = await 
tool.call(context) + + assert isinstance(result.content[0], mcp.types.TextContent) + assert "draft_brief" in result.content[0].text + assert "topic" in result.content[0].text + + +@pytest.mark.asyncio +async def test_get_prompt_tool_passes_arguments_and_returns_text_summary(): + client = _FakePromptCapableMCPClient() + tool = MCPGetPromptTool( + mcp_client=client, + mcp_server_name="demo-server", + ) + context = ContextWrapper(context=SimpleNamespace(), tool_call_timeout=30) + + result = await tool.call( + context, + name="draft_brief", + arguments={"topic": "MCP 最小实现"}, + ) + + assert client.received_arguments == [{"topic": "MCP 最小实现"}] + assert isinstance(result.content[0], mcp.types.TextContent) + assert "Prompt: draft_brief" in result.content[0].text + assert "MCP 最小实现" in result.content[0].text + + +def test_shape_get_prompt_result_summarizes_non_text_blocks(): + result = shape_get_prompt_result( + server_name="demo-server", + prompt_name="rich_prompt", + response=mcp.types.GetPromptResult( + description="Rich prompt response.", + messages=[ + mcp.types.PromptMessage( + role="assistant", + content=mcp.types.ImageContent( + type="image", + data="ZmFrZQ==", + mimeType="image/png", + ), + ), + mcp.types.PromptMessage( + role="user", + content=mcp.types.EmbeddedResource( + type="resource", + resource=mcp.types.TextResourceContents( + uri="memo://prompt/context", + mimeType="text/plain", + text="embedded context", + ), + ), + ), + ], + ), + ) + + assert "Image block returned." in result + assert "Embedded text resource returned." 
in result + assert "embedded context" in result diff --git a/tests/unit/test_mcp_resource_bridge.py b/tests/unit/test_mcp_resource_bridge.py new file mode 100644 index 0000000000..e605ffcf15 --- /dev/null +++ b/tests/unit/test_mcp_resource_bridge.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest + +import mcp +from astrbot.core.agent.mcp_resource_bridge import ( + MCPListResourceTemplatesTool, + MCPListResourcesTool, + MCPReadResourceTool, + build_mcp_resource_tool_names, + shape_read_resource_result, +) +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.provider.func_tool_manager import ( + FunctionToolManager, + _MCPServerRuntime, +) + + +class _FakeResourceCapableMCPClient: + def __init__(self) -> None: + self.name: str | None = None + self.tools = [ + mcp.types.Tool( + name="draft_brief", + description="Draft a short brief.", + inputSchema={"type": "object", "properties": {}}, + ) + ] + self.resource_templates_supported = True + self.resource_bridge_tool_names: list[str] = [] + self.prompt_bridge_tool_names: list[str] = [] + self.server_errlogs: list[str] = [] + + @property + def supports_resources(self) -> bool: + return True + + @property + def supports_prompts(self) -> bool: + return False + + async def connect_to_server(self, config: dict, name: str) -> None: + _ = config + self.name = name + + async def list_tools_and_save(self) -> mcp.types.ListToolsResult: + return mcp.types.ListToolsResult(tools=self.tools) + + async def load_resource_capabilities(self) -> None: + self.resource_bridge_tool_names = build_mcp_resource_tool_names( + self.name or "server", + include_templates=True, + ) + + async def load_prompt_capabilities(self) -> None: + self.prompt_bridge_tool_names = [] + + async def list_resources_and_save( + self, + cursor: str | None = None, + ) -> mcp.types.ListResourcesResult: + _ = cursor + return mcp.types.ListResourcesResult( + resources=[ + 
mcp.types.Resource( + name="team_notes", + uri="memo://team/notes", + description="Shared team notes", + mimeType="text/plain", + ) + ] + ) + + async def list_resource_templates_and_save( + self, + cursor: str | None = None, + ) -> mcp.types.ListResourceTemplatesResult: + _ = cursor + self.resource_templates_supported = True + return mcp.types.ListResourceTemplatesResult( + resourceTemplates=[ + mcp.types.ResourceTemplate( + name="note_by_id", + uriTemplate="memo://notes/{id}", + description="Read a note by id", + mimeType="text/plain", + ) + ] + ) + + async def read_resource_with_reconnect( + self, + uri: str, + read_timeout_seconds, + ) -> mcp.types.ReadResourceResult: + _ = read_timeout_seconds + return mcp.types.ReadResourceResult( + contents=[ + mcp.types.TextResourceContents( + uri=uri, + mimeType="text/plain", + text="hello from resource", + ) + ] + ) + + async def cleanup(self) -> None: + return None + + +@pytest.mark.asyncio +async def test_resource_bridge_tools_are_registered_and_removed( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr( + "astrbot.core.provider.func_tool_manager.MCPClient", + _FakeResourceCapableMCPClient, + ) + + tool_mgr = FunctionToolManager() + client = await tool_mgr._init_mcp_client( + "demo-server", + {"command": "node", "args": ["stdio.js"]}, + ) + + tool_names = {tool.name for tool in tool_mgr.func_list} + assert "draft_brief" in tool_names + assert "mcp_demo_server_list_resources" in tool_names + assert "mcp_demo_server_read_resource" in tool_names + assert "mcp_demo_server_list_resource_templates" in tool_names + + completed_task = asyncio.create_task(asyncio.sleep(0)) + await completed_task + tool_mgr._mcp_server_runtime["demo-server"] = _MCPServerRuntime( + name="demo-server", + client=client, + shutdown_event=asyncio.Event(), + lifecycle_task=completed_task, + ) + + await tool_mgr._terminate_mcp_client("demo-server") + + assert not any( + getattr(tool, "mcp_server_name", None) == "demo-server" + for tool in 
tool_mgr.func_list + ) + + +@pytest.mark.asyncio +async def test_resource_listing_tools_return_text_summaries(): + client = _FakeResourceCapableMCPClient() + list_tool = MCPListResourcesTool( + mcp_client=client, + mcp_server_name="demo-server", + ) + template_tool = MCPListResourceTemplatesTool( + mcp_client=client, + mcp_server_name="demo-server", + ) + context = ContextWrapper(context=SimpleNamespace()) + + list_result = await list_tool.call(context) + template_result = await template_tool.call(context) + + assert isinstance(list_result.content[0], mcp.types.TextContent) + assert "memo://team/notes" in list_result.content[0].text + assert isinstance(template_result.content[0], mcp.types.TextContent) + assert "memo://notes/{id}" in template_result.content[0].text + + +@pytest.mark.asyncio +async def test_read_resource_tool_returns_text_resource_summary(): + client = _FakeResourceCapableMCPClient() + tool = MCPReadResourceTool( + mcp_client=client, + mcp_server_name="demo-server", + ) + context = ContextWrapper(context=SimpleNamespace(), tool_call_timeout=30) + + result = await tool.call(context, uri="memo://team/notes") + + assert isinstance(result.content[0], mcp.types.TextContent) + assert "hello from resource" in result.content[0].text + assert "memo://team/notes" in result.content[0].text + + +def test_shape_read_resource_result_returns_embedded_image_for_single_image_blob(): + response = mcp.types.ReadResourceResult( + contents=[ + mcp.types.BlobResourceContents( + uri="memo://images/cover", + mimeType="image/png", + blob="ZmFrZQ==", + ) + ] + ) + + result = shape_read_resource_result( + server_name="demo-server", + requested_uri="memo://images/cover", + response=response, + ) + + assert isinstance(result.content[0], mcp.types.EmbeddedResource) + assert result.content[0].resource.mimeType == "image/png" + + +def test_shape_read_resource_result_summarizes_multi_part_content(): + response = mcp.types.ReadResourceResult( + contents=[ + 
mcp.types.TextResourceContents( + uri="memo://notes/1", + mimeType="text/plain", + text="first part", + ), + mcp.types.BlobResourceContents( + uri="memo://notes/1.bin", + mimeType="application/octet-stream", + blob="Zm9v", + ), + ] + ) + + result = shape_read_resource_result( + server_name="demo-server", + requested_uri="memo://notes/1", + response=response, + ) + + assert isinstance(result.content[0], mcp.types.TextContent) + assert "Returned parts: 2" in result.content[0].text + assert "Binary blob returned" in result.content[0].text diff --git a/tests/unit/test_mcp_stdio_client.py b/tests/unit/test_mcp_stdio_client.py new file mode 100644 index 0000000000..796d4ef659 --- /dev/null +++ b/tests/unit/test_mcp_stdio_client.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import pytest + +from astrbot.core.agent.mcp_client import MCPClient +from astrbot.core.agent.mcp_stdio_client import _should_ignore_stdout_line + + +class _DummyAsyncContext: + def __init__(self, value): + self._value = value + + async def __aenter__(self): + return self._value + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class _RecordingClientSession: + constructor_calls: list[dict] = [] + + def __init__(self, *args, **kwargs): + self.__class__.constructor_calls.append( + { + "args": args, + "kwargs": kwargs, + } + ) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def initialize(self): + return None + + +@pytest.mark.parametrize( + "line", + [ + "", + " ", + "> mcp-minimal-server@1.0.0 start:stdio", + "> node stdio.js", + "npm notice", + ], +) +def test_should_ignore_non_protocol_stdio_stdout_lines(line: str): + assert _should_ignore_stdout_line(line) is True + + +@pytest.mark.parametrize( + "line", + [ + '{"jsonrpc":"2.0","id":1,"result":{}}', + '{"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info"}}', + "{not-valid-json-yet}", + ], +) +def 
test_should_keep_json_like_stdio_stdout_lines_for_protocol_parsing(line: str): + assert _should_ignore_stdout_line(line) is False + + +@pytest.mark.asyncio +async def test_mcp_client_stdio_path_uses_tolerant_transport( + monkeypatch: pytest.MonkeyPatch, +): + _RecordingClientSession.constructor_calls.clear() + transport_calls: list[dict] = [] + + def _fake_tolerant_stdio_client(server_params, errlog): + transport_calls.append( + { + "server_params": server_params, + "errlog": errlog, + } + ) + return _DummyAsyncContext(("read", "write")) + + monkeypatch.setattr( + "astrbot.core.agent.mcp_client.tolerant_stdio_client", + _fake_tolerant_stdio_client, + ) + monkeypatch.setattr( + "astrbot.core.agent.mcp_client.mcp.ClientSession", + _RecordingClientSession, + ) + + client = MCPClient() + await client.connect_to_server( + { + "command": "node", + "args": ["stdio.js"], + }, + "demo", + ) + + assert len(transport_calls) == 1 + server_params = transport_calls[0]["server_params"] + assert server_params.command == "node" + assert server_params.args == ["stdio.js"] + assert _RecordingClientSession.constructor_calls + + await client.cleanup() diff --git a/tests/unit/test_mcp_subcapability_bridge.py b/tests/unit/test_mcp_subcapability_bridge.py new file mode 100644 index 0000000000..5ae573ed47 --- /dev/null +++ b/tests/unit/test_mcp_subcapability_bridge.py @@ -0,0 +1,1295 @@ +# ruff: noqa: ASYNC110 + +from __future__ import annotations + +import asyncio +import json +from datetime import timedelta +from pathlib import Path +from types import SimpleNamespace + +import mcp +import pytest + +from astrbot.core.agent.mcp_client import MCPClient +from astrbot.core.agent.mcp_elicitation_registry import ( + pending_mcp_elicitation, + submit_pending_mcp_elicitation_reply, + try_capture_pending_mcp_elicitation, +) +from astrbot.core.agent.mcp_subcapability_bridge import MCPClientSubCapabilityBridge +from astrbot.core.agent.run_context import ContextWrapper +from 
astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.func_tool_manager import FunctionToolManager + + +class _DummyAsyncContext: + def __init__(self, value): + self._value = value + + async def __aenter__(self): + return self._value + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class _RecordingClientSession: + constructor_calls: list[dict] = [] + + def __init__(self, *args, **kwargs): + self.__class__.constructor_calls.append( + { + "args": args, + "kwargs": kwargs, + } + ) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def initialize(self): + return None + + def get_server_capabilities(self): + return None + + +class _SamplingAwareSession: + def __init__(self, bridge, params): + self.bridge = bridge + self.params = params + self.calls: list[dict] = [] + + async def call_tool(self, *, name, arguments, read_timeout_seconds): + self.calls.append( + { + "name": name, + "arguments": arguments, + "read_timeout_seconds": read_timeout_seconds, + } + ) + sampling_result = await self.bridge.handle_sampling(None, self.params) + assert isinstance(sampling_result, mcp.types.CreateMessageResult) + return mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text=sampling_result.content.text, + ) + ] + ) + + +class _DummyProvider: + def __init__(self, model: str = "gpt-4o-mini"): + self._model = model + + def get_model(self) -> str: + return self._model + + def meta(self): + return SimpleNamespace(model=self._model, id="provider-1") + + +class _DummyPluginContext: + def __init__( + self, + *, + completion_text: str, + release_event: asyncio.Event | None = None, + entered_event: asyncio.Event | None = None, + ): + self.provider = _DummyProvider() + self.completion_text = completion_text + self.release_event = release_event + self.entered_event = entered_event + self.requests: list[dict] = [] + + async def 
get_current_chat_provider_id(self, umo: str) -> str: + assert umo + return "provider-1" + + def get_using_provider(self, umo: str | None = None): + return self.provider + + async def llm_generate(self, **kwargs): + self.requests.append(kwargs) + if self.entered_event is not None: + self.entered_event.set() + if self.release_event is not None: + await self.release_event.wait() + return LLMResponse(role="assistant", completion_text=self.completion_text) + + +class _DummyEvent: + def __init__( + self, + *, + umo: str = "test:umo", + sender_id: str = "user-1", + message_text: str = "", + outline: str = "", + platform_name: str = "test", + ) -> None: + self.unified_msg_origin = umo + self._sender_id = sender_id + self._message_text = message_text + self._outline = outline or message_text + self._platform_name = platform_name + self.sent_messages: list[str] = [] + self.sent_payloads: list[dict] = [] + + def get_sender_id(self) -> str: + return self._sender_id + + def get_message_str(self) -> str: + return self._message_text + + def get_message_outline(self) -> str: + return self._outline + + def get_platform_name(self) -> str: + return self._platform_name + + async def send(self, message_chain) -> None: + self.sent_messages.append( + message_chain.get_plain_text(with_other_comps_mark=True) + ) + if ( + getattr(message_chain, "type", None) == "elicitation" + and message_chain.chain + ): + first = message_chain.chain[0] + payload = getattr(first, "data", None) + if isinstance(payload, dict): + self.sent_payloads.append(payload) + + +def _build_run_context( + plugin_context: _DummyPluginContext, + *, + umo: str = "test:umo", + event: _DummyEvent | None = None, +): + event = event or _DummyEvent(umo=umo) + agent_context = SimpleNamespace( + context=plugin_context, + event=event, + ) + return ContextWrapper(context=agent_context) + + +def _build_sampling_params(*, tools=None, content=None): + if content is None: + content = mcp.types.TextContent(type="text", text="hello from 
server") + return mcp.types.CreateMessageRequestParams( + messages=[ + mcp.types.SamplingMessage( + role="user", + content=content, + ) + ], + maxTokens=64, + tools=tools, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("enabled", "expected_sampling_enabled"), + [ + pytest.param(False, False, id="sampling-disabled"), + pytest.param(True, True, id="sampling-enabled"), + ], +) +async def test_mcp_client_capability_advertisement_depends_on_config( + monkeypatch: pytest.MonkeyPatch, + enabled: bool, + expected_sampling_enabled: bool, +): + _RecordingClientSession.constructor_calls.clear() + + async def _fake_quick_test(_config): + return True, "" + + monkeypatch.setattr( + "astrbot.core.agent.mcp_client._quick_test_mcp_connection", + _fake_quick_test, + ) + monkeypatch.setattr( + "astrbot.core.agent.mcp_client.sse_client", + lambda **_kwargs: _DummyAsyncContext(("read", "write")), + ) + monkeypatch.setattr( + "astrbot.core.agent.mcp_client.mcp.ClientSession", + _RecordingClientSession, + ) + + client = MCPClient() + await client.connect_to_server( + { + "url": "https://example.com/mcp", + "transport": "sse", + "client_capabilities": { + "sampling": { + "enabled": enabled, + } + }, + }, + "demo", + ) + + kwargs = _RecordingClientSession.constructor_calls[0]["kwargs"] + sampling_callback = kwargs.get("sampling_callback") + sampling_capabilities = kwargs.get("sampling_capabilities") + + if expected_sampling_enabled: + assert callable(sampling_callback) + assert sampling_capabilities is not None + else: + assert sampling_callback is None + assert sampling_capabilities is None + + await client.cleanup() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("enabled", "expected_roots_enabled"), + [ + pytest.param(False, False, id="roots-disabled"), + pytest.param(True, True, id="roots-enabled"), + ], +) +async def test_mcp_client_roots_capability_advertisement_depends_on_config( + monkeypatch: pytest.MonkeyPatch, + enabled: bool, + expected_roots_enabled: bool, 
+): + _RecordingClientSession.constructor_calls.clear() + + async def _fake_quick_test(_config): + return True, "" + + monkeypatch.setattr( + "astrbot.core.agent.mcp_client._quick_test_mcp_connection", + _fake_quick_test, + ) + monkeypatch.setattr( + "astrbot.core.agent.mcp_client.sse_client", + lambda **_kwargs: _DummyAsyncContext(("read", "write")), + ) + monkeypatch.setattr( + "astrbot.core.agent.mcp_client.mcp.ClientSession", + _RecordingClientSession, + ) + + client = MCPClient() + await client.connect_to_server( + { + "url": "https://example.com/mcp", + "transport": "sse", + "client_capabilities": { + "roots": { + "enabled": enabled, + } + }, + }, + "demo", + ) + + kwargs = _RecordingClientSession.constructor_calls[0]["kwargs"] + roots_callback = kwargs.get("list_roots_callback") + + if expected_roots_enabled: + assert callable(roots_callback) + else: + assert roots_callback is None + + await client.cleanup() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("enabled", "expected_elicitation_enabled"), + [ + pytest.param(False, False, id="elicitation-disabled"), + pytest.param(True, True, id="elicitation-enabled"), + ], +) +async def test_mcp_client_elicitation_capability_advertisement_depends_on_config( + monkeypatch: pytest.MonkeyPatch, + enabled: bool, + expected_elicitation_enabled: bool, +): + _RecordingClientSession.constructor_calls.clear() + + async def _fake_quick_test(_config): + return True, "" + + monkeypatch.setattr( + "astrbot.core.agent.mcp_client._quick_test_mcp_connection", + _fake_quick_test, + ) + monkeypatch.setattr( + "astrbot.core.agent.mcp_client.sse_client", + lambda **_kwargs: _DummyAsyncContext(("read", "write")), + ) + monkeypatch.setattr( + "astrbot.core.agent.mcp_client.mcp.ClientSession", + _RecordingClientSession, + ) + + client = MCPClient() + await client.connect_to_server( + { + "url": "https://example.com/mcp", + "transport": "sse", + "client_capabilities": { + "elicitation": { + "enabled": enabled, + "timeout_seconds": 
120, + } + }, + }, + "demo", + ) + + kwargs = _RecordingClientSession.constructor_calls[0]["kwargs"] + elicitation_callback = kwargs.get("elicitation_callback") + + if expected_elicitation_enabled: + assert callable(elicitation_callback) + else: + assert elicitation_callback is None + + await client.cleanup() + + +def test_load_and_save_mcp_config_normalizes_client_capabilities( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr( + "astrbot.core.provider.func_tool_manager.get_astrbot_data_path", + lambda: str(tmp_path), + ) + tool_mgr = FunctionToolManager() + + config = { + "mcpServers": { + "disabled-default": { + "command": "uv", + "args": ["run", "demo.py"], + }, + "enabled-sampling": { + "command": "uv", + "args": ["run", "demo.py"], + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 180, + }, + "sampling": { + "enabled": True, + }, + "roots": { + "enabled": True, + "paths": ["data", "temp"], + }, + }, + }, + } + } + + assert tool_mgr.save_mcp_config(config) is True + + saved_raw = json.loads((tmp_path / "mcp_server.json").read_text(encoding="utf-8")) + assert ( + saved_raw["mcpServers"]["disabled-default"]["client_capabilities"][ + "elicitation" + ]["enabled"] + is False + ) + assert ( + saved_raw["mcpServers"]["disabled-default"]["client_capabilities"][ + "elicitation" + ]["timeout_seconds"] + == 300 + ) + assert ( + saved_raw["mcpServers"]["disabled-default"]["client_capabilities"]["sampling"][ + "enabled" + ] + is False + ) + assert ( + saved_raw["mcpServers"]["disabled-default"]["client_capabilities"]["roots"][ + "enabled" + ] + is False + ) + assert ( + saved_raw["mcpServers"]["disabled-default"]["client_capabilities"]["roots"][ + "paths" + ] + == [] + ) + assert ( + saved_raw["mcpServers"]["enabled-sampling"]["client_capabilities"][ + "elicitation" + ]["enabled"] + is True + ) + assert ( + saved_raw["mcpServers"]["enabled-sampling"]["client_capabilities"][ + "elicitation" + ]["timeout_seconds"] + 
== 180 + ) + assert ( + saved_raw["mcpServers"]["enabled-sampling"]["client_capabilities"]["sampling"][ + "enabled" + ] + is True + ) + assert ( + saved_raw["mcpServers"]["enabled-sampling"]["client_capabilities"]["roots"][ + "enabled" + ] + is True + ) + assert saved_raw["mcpServers"]["enabled-sampling"]["client_capabilities"]["roots"][ + "paths" + ] == ["data", "temp"] + + loaded = tool_mgr.load_mcp_config() + assert ( + loaded["mcpServers"]["disabled-default"]["client_capabilities"]["elicitation"][ + "enabled" + ] + is False + ) + assert ( + loaded["mcpServers"]["disabled-default"]["client_capabilities"]["elicitation"][ + "timeout_seconds" + ] + == 300 + ) + assert ( + loaded["mcpServers"]["disabled-default"]["client_capabilities"]["sampling"][ + "enabled" + ] + is False + ) + assert ( + loaded["mcpServers"]["disabled-default"]["client_capabilities"]["roots"][ + "enabled" + ] + is False + ) + assert ( + loaded["mcpServers"]["enabled-sampling"]["client_capabilities"]["elicitation"][ + "enabled" + ] + is True + ) + assert ( + loaded["mcpServers"]["enabled-sampling"]["client_capabilities"]["elicitation"][ + "timeout_seconds" + ] + == 180 + ) + assert ( + loaded["mcpServers"]["enabled-sampling"]["client_capabilities"]["sampling"][ + "enabled" + ] + is True + ) + assert ( + loaded["mcpServers"]["enabled-sampling"]["client_capabilities"]["roots"][ + "enabled" + ] + is True + ) + assert loaded["mcpServers"]["enabled-sampling"]["client_capabilities"]["roots"][ + "paths" + ] == ["data", "temp"] + + +@pytest.mark.asyncio +async def test_roots_request_returns_default_roots_for_enabled_server( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +): + data_dir = tmp_path / "data" + temp_dir = data_dir / "temp" + data_dir.mkdir() + temp_dir.mkdir() + + monkeypatch.setattr( + "astrbot.core.agent.mcp_subcapability_bridge.get_astrbot_data_path", + lambda: str(data_dir), + ) + monkeypatch.setattr( + "astrbot.core.agent.mcp_subcapability_bridge.get_astrbot_temp_path", + lambda: 
str(temp_dir), + ) + + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "roots": { + "enabled": True, + } + } + } + ) + + result = await bridge.handle_list_roots(None) + + assert isinstance(result, mcp.types.ListRootsResult) + assert [root.name for root in result.roots] == ["data", "temp"] + assert str(result.roots[0].uri) == data_dir.resolve().as_uri() + assert str(result.roots[1].uri) == temp_dir.resolve().as_uri() + + +@pytest.mark.asyncio +async def test_roots_request_uses_explicit_paths_and_skips_missing_entries( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +): + root_dir = tmp_path / "astrbot-root" + root_dir.mkdir() + explicit_dir = tmp_path / "explicit" + explicit_dir.mkdir() + nested_dir = root_dir / "workspace" + nested_dir.mkdir() + + monkeypatch.setattr( + "astrbot.core.agent.mcp_subcapability_bridge.get_astrbot_root", + lambda: str(root_dir), + ) + + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "roots": { + "enabled": True, + "paths": [ + str(explicit_dir), + "workspace", + "missing-dir", + ], + } + } + } + ) + + result = await bridge.handle_list_roots(None) + + assert isinstance(result, mcp.types.ListRootsResult) + assert [ + Path(str(root.uri).removeprefix("file:///")).name for root in result.roots + ] == [ + "explicit", + "workspace", + ] + + +@pytest.mark.asyncio +async def test_roots_request_is_rejected_when_disabled(): + bridge = MCPClientSubCapabilityBridge("demo") + + result = await bridge.handle_list_roots(None) + + assert isinstance(result, mcp.types.ErrorData) + assert result.code == mcp.types.INVALID_REQUEST + assert "Roots are not enabled" in result.message + + +@pytest.mark.asyncio +async def test_sampling_request_uses_bound_astrbot_context(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "sampling": { + "enabled": True, + } 
+ } + } + ) + plugin_context = _DummyPluginContext(completion_text="reply from astrbot") + run_context = _build_run_context(plugin_context) + params = _build_sampling_params() + + async with bridge.interactive_call(run_context): + result = await bridge.handle_sampling(None, params) + + assert isinstance(result, mcp.types.CreateMessageResult) + assert result.content.text == "reply from astrbot" + assert result.model == "gpt-4o-mini" + assert plugin_context.requests[0]["contexts"] == [ + { + "role": "user", + "content": "hello from server", + } + ] + assert plugin_context.requests[0]["max_tokens"] == 64 + + +@pytest.mark.asyncio +async def test_sampling_request_without_bound_context_is_rejected(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "sampling": { + "enabled": True, + } + } + } + ) + + result = await bridge.handle_sampling(None, _build_sampling_params()) + + assert isinstance(result, mcp.types.ErrorData) + assert result.code == mcp.types.INVALID_REQUEST + assert "active AstrBot MCP interaction" in result.message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "params", + [ + pytest.param( + _build_sampling_params( + tools=[ + mcp.types.Tool( + name="demo_tool", + description="demo", + inputSchema={"type": "object", "properties": {}}, + ) + ] + ), + id="tool-assisted-sampling", + ), + pytest.param( + _build_sampling_params( + content=mcp.types.ImageContent( + type="image", + data="ZmFrZQ==", + mimeType="image/png", + ) + ), + id="image-input", + ), + ], +) +async def test_sampling_request_rejects_unsupported_inputs(params): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "sampling": { + "enabled": True, + } + } + } + ) + plugin_context = _DummyPluginContext(completion_text="reply") + run_context = _build_run_context(plugin_context) + + async with bridge.interactive_call(run_context): + result = await 
bridge.handle_sampling(None, params) + + assert isinstance(result, mcp.types.ErrorData) + assert result.code == mcp.types.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_sampling_enabled_interactive_calls_are_serialized(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "sampling": { + "enabled": True, + } + } + } + ) + + first_entered = asyncio.Event() + release_first = asyncio.Event() + first_plugin_context = _DummyPluginContext( + completion_text="first reply", + release_event=release_first, + entered_event=first_entered, + ) + second_plugin_context = _DummyPluginContext(completion_text="second reply") + order: list[str] = [] + params = _build_sampling_params() + + async def _first_call(): + async with bridge.interactive_call(_build_run_context(first_plugin_context)): + order.append("enter-1") + result = await bridge.handle_sampling(None, params) + order.append("exit-1") + return result + + async def _second_call(): + await first_entered.wait() + async with bridge.interactive_call(_build_run_context(second_plugin_context)): + order.append("enter-2") + result = await bridge.handle_sampling(None, params) + order.append("exit-2") + return result + + first_task = asyncio.create_task(_first_call()) + await first_entered.wait() + second_task = asyncio.create_task(_second_call()) + + await asyncio.sleep(0) + assert order == ["enter-1"] + + release_first.set() + first_result = await first_task + second_result = await second_task + + assert isinstance(first_result, mcp.types.CreateMessageResult) + assert isinstance(second_result, mcp.types.CreateMessageResult) + assert first_result.content.text == "first reply" + assert second_result.content.text == "second reply" + assert order == ["enter-1", "exit-1", "enter-2", "exit-2"] + + +@pytest.mark.asyncio +async def test_mcp_client_call_tool_with_reconnect_preserves_sampling_runtime_context(): + client = MCPClient() + 
client.subcapability_bridge.set_server_name("demo") + client.subcapability_bridge.configure_from_server_config( + { + "client_capabilities": { + "sampling": { + "enabled": True, + } + } + } + ) + client.session = _SamplingAwareSession( + client.subcapability_bridge, + _build_sampling_params(), + ) + + plugin_context = _DummyPluginContext(completion_text="reply from astrbot") + run_context = _build_run_context(plugin_context) + + result = await client.call_tool_with_reconnect( + tool_name="draft-brief", + arguments={"topic": "MCP 最小实现"}, + read_timeout_seconds=timedelta(seconds=60), + run_context=run_context, + ) + + assert isinstance(result, mcp.types.CallToolResult) + assert result.content[0].text == "reply from astrbot" + assert len(client.session.calls) == 1 + assert client.session.calls[0]["name"] == "draft-brief" + assert client.session.calls[0]["arguments"] == {"topic": "MCP 最小实现"} + assert plugin_context.requests[0]["contexts"] == [ + { + "role": "user", + "content": "hello from server", + } + ] + + +@pytest.mark.asyncio +async def test_pending_mcp_elicitation_captures_only_matching_sender(): + async with pending_mcp_elicitation("umo:1", "user-1") as future: + wrong_sender_event = _DummyEvent( + umo="umo:1", + sender_id="user-2", + message_text="ignored", + ) + assert try_capture_pending_mcp_elicitation(wrong_sender_event) is False + assert future.done() is False + + matching_event = _DummyEvent( + umo="umo:1", + sender_id="user-1", + message_text="accepted", + ) + assert try_capture_pending_mcp_elicitation(matching_event) is True + resolved_event = await future + assert resolved_event.message_text == "accepted" + assert resolved_event.message_outline == "accepted" + + +@pytest.mark.asyncio +async def test_pending_mcp_elicitation_accepts_direct_submission(): + async with pending_mcp_elicitation("umo:1", "user-1") as future: + assert ( + submit_pending_mcp_elicitation_reply( + "umo:1", + "user-1", + '{"topic":"MCP 最小实现"}', + reply_outline="topic: MCP 最小实现", + 
) + is True + ) + resolved_reply = await future + + assert resolved_reply.message_text == '{"topic":"MCP 最小实现"}' + assert resolved_reply.message_outline == "topic: MCP 最小实现" + + +@pytest.mark.asyncio +async def test_elicitation_form_request_uses_next_matching_reply(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 30, + } + } + } + ) + event = _DummyEvent(umo="test:umo", sender_id="user-1") + run_context = _build_run_context( + _DummyPluginContext(completion_text="unused"), + event=event, + ) + params = mcp.types.ElicitRequestFormParams( + message="Please provide the topic.", + requestedSchema={ + "type": "object", + "properties": { + "topic": { + "type": "string", + "description": "Brief topic name", + } + }, + "required": ["topic"], + }, + ) + + async def _resolve_reply(): + while not event.sent_messages: + await asyncio.sleep(0) + assert "topic" in event.sent_messages[0] + reply_event = _DummyEvent( + umo="test:umo", + sender_id="user-1", + message_text="MCP 最小实现", + ) + while not try_capture_pending_mcp_elicitation(reply_event): + await asyncio.sleep(0) + + async with bridge.interactive_call(run_context): + result_task = asyncio.create_task(bridge.handle_elicitation(None, params)) + await _resolve_reply() + result = await result_task + + assert isinstance(result, mcp.types.ElicitResult) + assert result.action == "accept" + assert result.content == {"topic": "MCP 最小实现"} + + +@pytest.mark.asyncio +async def test_elicitation_form_request_reprompts_after_invalid_reply(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 30, + } + } + } + ) + event = _DummyEvent(umo="test:umo", sender_id="user-1") + run_context = _build_run_context( + _DummyPluginContext(completion_text="unused"), + event=event, + ) + params = 
mcp.types.ElicitRequestFormParams( + message="How many sections do you need?", + requestedSchema={ + "type": "object", + "properties": { + "count": { + "type": "integer", + } + }, + "required": ["count"], + }, + ) + + async def _resolve_replies(): + while len(event.sent_messages) < 1: + await asyncio.sleep(0) + first_reply = _DummyEvent( + umo="test:umo", + sender_id="user-1", + message_text="not-a-number", + ) + while not try_capture_pending_mcp_elicitation(first_reply): + await asyncio.sleep(0) + + while len(event.sent_messages) < 2: + await asyncio.sleep(0) + assert "could not use that reply" in event.sent_messages[1].lower() + second_reply = _DummyEvent( + umo="test:umo", + sender_id="user-1", + message_text="2", + ) + while not try_capture_pending_mcp_elicitation(second_reply): + await asyncio.sleep(0) + + async with bridge.interactive_call(run_context): + result_task = asyncio.create_task(bridge.handle_elicitation(None, params)) + await _resolve_replies() + result = await result_task + + assert isinstance(result, mcp.types.ElicitResult) + assert result.action == "accept" + assert result.content == {"count": 2} + + +@pytest.mark.asyncio +async def test_elicitation_form_request_accepts_natural_language_key_value_patterns(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 30, + } + } + } + ) + event = _DummyEvent(umo="test:umo", sender_id="user-1") + run_context = _build_run_context( + _DummyPluginContext(completion_text="unused"), + event=event, + ) + params = mcp.types.ElicitRequestFormParams( + message="Collect plan details.", + requestedSchema={ + "type": "object", + "properties": { + "topic": {"type": "string"}, + "audience": {"type": "string"}, + }, + "required": ["topic", "audience"], + }, + ) + + async def _resolve_reply(): + while not event.sent_messages: + await asyncio.sleep(0) + reply_event = _DummyEvent( + umo="test:umo", 
+ sender_id="user-1", + message_text="topic 是 MCP 最小实现,audience 为 新手", + ) + while not try_capture_pending_mcp_elicitation(reply_event): + await asyncio.sleep(0) + + async with bridge.interactive_call(run_context): + result_task = asyncio.create_task(bridge.handle_elicitation(None, params)) + await _resolve_reply() + result = await result_task + + assert isinstance(result, mcp.types.ElicitResult) + assert result.action == "accept" + assert result.content == { + "topic": "MCP 最小实现", + "audience": "新手", + } + + +@pytest.mark.asyncio +async def test_elicitation_form_request_uses_llm_fallback_for_bot_reply(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 30, + } + } + } + ) + plugin_context = _DummyPluginContext( + completion_text='```json\n{"topic":"MCP 最小实现","audience":"新手"}\n```' + ) + event = _DummyEvent(umo="test:umo", sender_id="user-1") + run_context = _build_run_context( + plugin_context, + event=event, + ) + params = mcp.types.ElicitRequestFormParams( + message="Collect plan details.", + requestedSchema={ + "type": "object", + "properties": { + "topic": {"type": "string"}, + "audience": {"type": "string"}, + }, + "required": ["topic", "audience"], + }, + ) + + async def _resolve_reply(): + while not event.sent_messages: + await asyncio.sleep(0) + reply_event = _DummyEvent( + umo="test:umo", + sender_id="user-1", + message_text="面向新手,写一个关于 MCP 最小实现 的简要说明。", + ) + while not try_capture_pending_mcp_elicitation(reply_event): + await asyncio.sleep(0) + + async with bridge.interactive_call(run_context): + result_task = asyncio.create_task(bridge.handle_elicitation(None, params)) + await _resolve_reply() + result = await result_task + + assert isinstance(result, mcp.types.ElicitResult) + assert result.action == "accept" + assert result.content == { + "topic": "MCP 最小实现", + "audience": "新手", + } + assert len(plugin_context.requests) == 1 + 
assert plugin_context.requests[0]["chat_provider_id"] == "provider-1" + assert "Return only a JSON object." in plugin_context.requests[0]["system_prompt"] + + +@pytest.mark.asyncio +async def test_elicitation_form_request_retries_when_llm_fallback_returns_invalid_json(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 30, + } + } + } + ) + plugin_context = _DummyPluginContext(completion_text="not-json") + event = _DummyEvent(umo="test:umo", sender_id="user-1") + run_context = _build_run_context( + plugin_context, + event=event, + ) + params = mcp.types.ElicitRequestFormParams( + message="Collect plan details.", + requestedSchema={ + "type": "object", + "properties": { + "topic": {"type": "string"}, + "audience": {"type": "string"}, + }, + "required": ["topic", "audience"], + }, + ) + + async def _resolve_replies(): + while not event.sent_messages: + await asyncio.sleep(0) + first_reply = _DummyEvent( + umo="test:umo", + sender_id="user-1", + message_text="帮我给新手准备一个关于 MCP 最小实现 的说明。", + ) + while not try_capture_pending_mcp_elicitation(first_reply): + await asyncio.sleep(0) + + while len(event.sent_messages) < 2: + await asyncio.sleep(0) + assert "could not use that reply" in event.sent_messages[1].lower() + + second_reply = _DummyEvent( + umo="test:umo", + sender_id="user-1", + message_text='{"topic":"MCP 最小实现","audience":"新手"}', + ) + while not try_capture_pending_mcp_elicitation(second_reply): + await asyncio.sleep(0) + + async with bridge.interactive_call(run_context): + result_task = asyncio.create_task(bridge.handle_elicitation(None, params)) + await _resolve_replies() + result = await result_task + + assert isinstance(result, mcp.types.ElicitResult) + assert result.action == "accept" + assert result.content == { + "topic": "MCP 最小实现", + "audience": "新手", + } + assert len(plugin_context.requests) == 1 + + +@pytest.mark.asyncio +async 
def test_webchat_elicitation_message_uses_structured_payload(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 30, + } + } + } + ) + event = _DummyEvent( + umo="test:umo", + sender_id="user-1", + platform_name="webchat", + ) + run_context = _build_run_context( + _DummyPluginContext(completion_text="unused"), + event=event, + ) + params = mcp.types.ElicitRequestFormParams( + message="Choose a tone.", + requestedSchema={ + "type": "object", + "properties": { + "tone": { + "type": "string", + "enum": ["formal", "casual"], + } + }, + "required": ["tone"], + }, + ) + + async def _resolve_reply(): + while not event.sent_payloads: + await asyncio.sleep(0) + assert event.sent_payloads[0]["fields"][0]["enum"] == ["formal", "casual"] + submit_pending_mcp_elicitation_reply( + "test:umo", + "user-1", + "formal", + reply_outline="tone: formal", + ) + + async with bridge.interactive_call(run_context): + result_task = asyncio.create_task(bridge.handle_elicitation(None, params)) + await _resolve_reply() + result = await result_task + + assert isinstance(result, mcp.types.ElicitResult) + assert result.action == "accept" + assert result.content == {"tone": "formal"} + + +@pytest.mark.asyncio +async def test_elicitation_url_request_waits_for_confirmation(): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 30, + } + } + } + ) + event = _DummyEvent(umo="test:umo", sender_id="user-1") + run_context = _build_run_context( + _DummyPluginContext(completion_text="unused"), + event=event, + ) + params = mcp.types.ElicitRequestURLParams( + message="Authorize the test server.", + url="https://example.com/auth", + elicitationId="elic-1", + ) + + async def _resolve_reply(): + while not event.sent_messages: + await asyncio.sleep(0) + assert 
"https://example.com/auth" in event.sent_messages[0] + reply_event = _DummyEvent( + umo="test:umo", + sender_id="user-1", + message_text="done", + ) + while not try_capture_pending_mcp_elicitation(reply_event): + await asyncio.sleep(0) + + async with bridge.interactive_call(run_context): + result_task = asyncio.create_task(bridge.handle_elicitation(None, params)) + await _resolve_reply() + result = await result_task + + assert isinstance(result, mcp.types.ElicitResult) + assert result.action == "accept" + assert result.content is None + + +@pytest.mark.asyncio +async def test_elicitation_request_times_out_to_cancel(monkeypatch: pytest.MonkeyPatch): + bridge = MCPClientSubCapabilityBridge("demo") + bridge.configure_from_server_config( + { + "client_capabilities": { + "elicitation": { + "enabled": True, + "timeout_seconds": 30, + } + } + } + ) + event = _DummyEvent(umo="test:umo", sender_id="user-1") + run_context = _build_run_context( + _DummyPluginContext(completion_text="unused"), + event=event, + ) + params = mcp.types.ElicitRequestFormParams( + message="Please provide topic.", + requestedSchema={ + "type": "object", + "properties": { + "topic": {"type": "string"}, + }, + }, + ) + + async def _fake_wait_for_reply(*, event, sender_id, deadline): + del event, sender_id, deadline + return None + + monkeypatch.setattr(bridge, "_wait_for_elicitation_reply", _fake_wait_for_reply) + + async with bridge.interactive_call(run_context): + result = await bridge.handle_elicitation(None, params) + + assert isinstance(result, mcp.types.ElicitResult) + assert result.action == "cancel"