diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 0c319d53c..b4810c100 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -2,10 +2,12 @@ from __future__ import annotations +import functools import inspect from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any, Literal +import anyio.to_thread import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call @@ -155,10 +157,10 @@ async def render( # Add context to arguments if needed call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg) - # Call function and check if result is a coroutine - result = self.fn(**call_args) - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn(**call_args) + else: + result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args)) # Validate messages if not isinstance(result, list | tuple): diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index 2d612657c..542b5e6f8 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -2,12 +2,14 @@ from __future__ import annotations +import functools import inspect import re from collections.abc import Callable from typing import TYPE_CHECKING, Any from urllib.parse import unquote +import anyio.to_thread from pydantic import BaseModel, Field, validate_call from mcp.server.mcpserver.resources.types import FunctionResource, Resource @@ -110,10 +112,10 @@ async def create_resource( # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) - # Call function and check if result is a coroutine - result = self.fn(**params) - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn(**params) + else: + result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params)) return FunctionResource( uri=uri, # type: ignore diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 42aecd6e3..04763be8b 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -55,11 +55,10 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - # Call the function first to see if it returns a coroutine - result = self.fn() - # If it's a coroutine, await it - if inspect.iscoroutine(result): - result = await result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn() + else: + result = await anyio.to_thread.run_sync(self.fn) if isinstance(result, Resource): # pragma: no cover return await result.read() diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index fe18e91bd..d4e4e6b5a 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -1,3 +1,4 @@ +import threading from typing import Any import pytest @@ -190,3 +191,21 @@ async def fn() -> dict[str, Any]: ) ) ] + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync prompt functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn() -> str: + fn_thread.append(threading.get_ident()) + return "hello" + + prompt = Prompt.from_function(blocking_fn) + messages = await prompt.render(None, Context()) + + assert messages == [UserMessage(content=TextContent(type="text", text="hello"))] + assert fn_thread[0] != main_thread diff --git a/tests/server/mcpserver/resources/test_function_resources.py b/tests/server/mcpserver/resources/test_function_resources.py index 5f5c216ed..c1ff96061 100644 --- a/tests/server/mcpserver/resources/test_function_resources.py +++ b/tests/server/mcpserver/resources/test_function_resources.py @@ -1,3 +1,7 @@ +import threading + +import anyio +import anyio.from_thread import pytest from pydantic import BaseModel @@ -190,3 +194,51 @@ def get_data() -> str: # pragma: no cover ) assert resource.meta is None + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync resource functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn() -> str: + fn_thread.append(threading.get_ident()) + return "data" + + resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn) + result = await resource.read() + + assert result == "data" + assert fn_thread[0] != main_thread + + +@pytest.mark.anyio +async def test_sync_fn_does_not_block_event_loop(): + """A blocking sync resource function must not stall the event loop. + + On regression (sync runs inline), anyio.from_thread.run_sync raises + RuntimeError because there is no worker-thread context, failing fast. + """ + handler_entered = anyio.Event() + release = threading.Event() + + def blocking_fn() -> str: + anyio.from_thread.run_sync(handler_entered.set) + release.wait() + return "done" + + resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn) + result: list[str | bytes] = [] + + async def run() -> None: + result.append(await resource.read()) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(run) + await handler_entered.wait() + release.set() + + assert result == ["done"] diff --git a/tests/server/mcpserver/resources/test_resource_template.py b/tests/server/mcpserver/resources/test_resource_template.py index 640cfe803..2a7ba8d50 100644 --- a/tests/server/mcpserver/resources/test_resource_template.py +++ b/tests/server/mcpserver/resources/test_resource_template.py @@ -1,4 +1,5 @@ import json +import threading from typing import Any import pytest @@ -310,3 +311,22 @@ def get_item(item_id: str) -> str: assert resource.meta == metadata assert resource.meta["category"] == "inventory" assert resource.meta["cacheable"] is True + + +@pytest.mark.anyio +async def test_sync_fn_runs_in_worker_thread(): + """Sync template functions must run in a worker thread, not the event loop.""" + + main_thread = threading.get_ident() + fn_thread: list[int] = [] + + def blocking_fn(name: str) -> str: + fn_thread.append(threading.get_ident()) + return f"hello {name}" + + template = ResourceTemplate.from_function(fn=blocking_fn, uri_template="test://{name}") + resource = await template.create_resource("test://world", {"name": "world"}, Context()) + + assert isinstance(resource, FunctionResource) + assert await resource.read() == "hello world" + assert fn_thread[0] != main_thread