diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index 0857e38bd..c331d711c 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -4,6 +4,16 @@ from .context import Context from .server import MCPServer +from .utilities.roots import assert_within_roots, get_roots, within_roots_check from .utilities.types import Audio, Image -__all__ = ["MCPServer", "Context", "Image", "Audio", "Icon"] +__all__ = [ + "MCPServer", + "Context", + "Image", + "Audio", + "Icon", + "assert_within_roots", + "get_roots", + "within_roots_check", +] diff --git a/src/mcp/server/mcpserver/utilities/roots.py b/src/mcp/server/mcpserver/utilities/roots.py new file mode 100644 index 000000000..328aeaefb --- /dev/null +++ b/src/mcp/server/mcpserver/utilities/roots.py @@ -0,0 +1,127 @@ +"""Reusable roots enforcement utilities for MCPServer. + +Roots define filesystem boundaries that the MCP client declares to the server. +The MCP spec does not auto-enforce these — servers must do it themselves. +This module provides a simple reusable way to do that without rewriting +the logic in every server. + +Usage: + from mcp.server.mcpserver import Context, MCPServer + from mcp.server.mcpserver.utilities.roots import ( + get_roots, + assert_within_roots, + within_roots_check, + ) + + mcp = MCPServer("my-server") + + @mcp.tool() + async def read_file(path: str, ctx: Context) -> str: + await assert_within_roots(path, ctx) + return open(path).read() +""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import TYPE_CHECKING, ParamSpec, TypeVar + +if TYPE_CHECKING: + from mcp.server.mcpserver import Context + +P = ParamSpec("P") +R = TypeVar("R") + + +async def get_roots(ctx: Context) -> list[str]: + """Fetch the list of root URIs declared by the connected client. + + Returns a list of URI strings e.g. ["file:///home/user/project"]. + Returns an empty list if the client declared no roots or does not + support the roots capability. + + Args: + ctx: The MCPServer Context object available inside any tool. + + Example: + @mcp.tool() + async def my_tool(ctx: Context) -> str: + roots = await get_roots(ctx) + return str(roots) + """ + try: + result = await ctx.session.list_roots() + return [str(root.uri) for root in result.roots] + except Exception: + return [] + + +async def assert_within_roots(path: str | Path, ctx: Context) -> None: + """Raise PermissionError if path falls outside all client-declared roots. + + If the client declared no roots this is a no-op — no restriction applied. + Only file:// URIs are checked. Non-file roots are skipped. + + Args: + path: The filesystem path your tool wants to access. + ctx: The MCPServer Context object available inside any tool. + + Raises: + PermissionError: If the resolved path is outside all declared roots. + + Example: + @mcp.tool() + async def read_file(path: str, ctx: Context) -> str: + await assert_within_roots(path, ctx) + return open(path).read() + """ + roots = await get_roots(ctx) + + if not roots: + return + + file_roots = [str(Path(r.removeprefix("file://")).resolve()) for r in roots if r.startswith("file://")] + + if not file_roots: + return + + resolved = str(Path(path).resolve()) + + if not any(resolved.startswith(root) for root in file_roots): + raise PermissionError(f"Access denied: '{resolved}' is outside the allowed roots.\nAllowed roots: {file_roots}") + + +def within_roots_check(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + """Auto-enforce roots on any tool parameter named 'path' or ending with '_path'. + + Requires the tool to also accept a `ctx: Context` parameter. + + Example: + @mcp.tool() + @within_roots_check + async def read_file(path: str, ctx: Context) -> str: + return open(path).read() + """ + + @functools.wraps(fn) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + sig = inspect.signature(fn) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + arguments = bound.arguments + + ctx = arguments.get("ctx") + if ctx is None: + raise ValueError("@within_roots_check requires the tool to have a `ctx: Context` parameter.") + + for param_name, value in arguments.items(): + if value and isinstance(value, str | Path): + if param_name == "path" or param_name.endswith("_path"): + await assert_within_roots(value, ctx) + + return await fn(*args, **kwargs) + + return wrapper diff --git a/tests/server/mcpserver/test_roots.py b/tests/server/mcpserver/test_roots.py new file mode 100644 index 000000000..4bfe119d1 --- /dev/null +++ b/tests/server/mcpserver/test_roots.py @@ -0,0 +1,176 @@ +"""Tests for mcp.server.mcpserver.utilities.roots.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from mcp.server.mcpserver.utilities.roots import ( + assert_within_roots, + get_roots, + within_roots_check, +) + +pytestmark = pytest.mark.anyio + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_ctx(root_uris: list[str]) -> MagicMock: + root_objects = [MagicMock(uri=uri) for uri in root_uris] + list_roots_result = MagicMock() + list_roots_result.roots = root_objects + session = MagicMock() + session.list_roots = AsyncMock(return_value=list_roots_result) + ctx = MagicMock() + ctx.session = session + return ctx + + +def make_failing_ctx() -> MagicMock: + session = MagicMock() + session.list_roots = AsyncMock(side_effect=Exception("not supported")) + ctx = MagicMock() + ctx.session = session + return ctx + + +# --------------------------------------------------------------------------- +# get_roots +# --------------------------------------------------------------------------- + + +async def test_get_roots_returns_uris(): + ctx = make_ctx(["file:///home/user/project", "file:///tmp/work"]) + result = await get_roots(ctx) + assert result == ["file:///home/user/project", "file:///tmp/work"] + + +async def test_get_roots_returns_empty_when_no_roots(): + ctx = make_ctx([]) + result = await get_roots(ctx) + assert result == [] + + +async def test_get_roots_returns_empty_on_exception(): + ctx = make_failing_ctx() + result = await get_roots(ctx) + assert result == [] + + +# --------------------------------------------------------------------------- +# assert_within_roots +# --------------------------------------------------------------------------- + + +async def test_assert_passes_when_no_roots(): + ctx = make_ctx([]) + await assert_within_roots("/any/path/at/all", ctx) + + +async def test_assert_passes_when_path_inside_root(): + ctx = make_ctx(["file:///home/user/project"]) + await assert_within_roots("/home/user/project/src/main.py", ctx) + + +async def test_assert_raises_when_path_outside_root(): + ctx = make_ctx(["file:///home/user/project"]) + with pytest.raises(PermissionError, match="Access denied"): + await assert_within_roots("/etc/passwd", ctx) + + +async def test_assert_passes_with_multiple_roots_matching_second(): + ctx = make_ctx(["file:///home/user/project", "file:///tmp/work"]) + await assert_within_roots("/tmp/work/file.txt", ctx) + + +async def test_assert_raises_outside_all_roots(): + ctx = make_ctx(["file:///home/user/project", "file:///tmp/work"]) + with pytest.raises(PermissionError): + await assert_within_roots("/var/log/syslog", ctx) + + +async def test_assert_accepts_pathlib_path(): + ctx = make_ctx(["file:///home/user/project"]) + await assert_within_roots(Path("/home/user/project/file.txt"), ctx) + + +async def test_assert_skips_non_file_roots(): + ctx = make_ctx(["https://api.example.com/v1"]) + await assert_within_roots("/any/local/path", ctx) + + +async def test_assert_no_raise_when_client_doesnt_support_roots(): + ctx = make_failing_ctx() + await assert_within_roots("/any/path", ctx) + + +# --------------------------------------------------------------------------- +# within_roots_check decorator +# --------------------------------------------------------------------------- + + +async def test_decorator_passes_inside_root(): + ctx = make_ctx(["file:///home/user/project"]) + + @within_roots_check + async def read_file(path: str, ctx: MagicMock) -> str: + return "file contents" + + result = await read_file(path="/home/user/project/notes.txt", ctx=ctx) + assert result == "file contents" + + +async def test_decorator_raises_outside_root(): + ctx = make_ctx(["file:///home/user/project"]) + + @within_roots_check + async def read_file(path: str, ctx: MagicMock) -> str: + raise AssertionError("tool body must not run when decorator denies access") # pragma: no cover + + with pytest.raises(PermissionError): + await read_file(path="/etc/passwd", ctx=ctx) + + +async def test_decorator_checks_star_path_params(): + ctx = make_ctx(["file:///home/user/project"]) + + @within_roots_check + async def copy_file(source_path: str, dest_path: str, ctx: MagicMock) -> str: + raise AssertionError("tool body must not run when decorator denies access") # pragma: no cover + + with pytest.raises(PermissionError): + await copy_file( + source_path="/home/user/project/file.txt", + dest_path="/etc/shadow", + ctx=ctx, + ) + + +async def test_decorator_ignores_non_path_string_params(): + ctx = make_ctx(["file:///home/user/project"]) + + @within_roots_check + async def tool(name: str, path: str, ctx: MagicMock) -> str: + return f"{name}:{path}" + + result = await tool( + name="greeting", + path="/home/user/project/file.txt", + ctx=ctx, + ) + assert result == "greeting:/home/user/project/file.txt" + + +async def test_decorator_raises_without_ctx(): + @within_roots_check + async def bad_tool(path: str) -> str: + raise AssertionError("tool body must not run when ctx is missing") # pragma: no cover + + with pytest.raises(ValueError, match="ctx"): + await bad_tool(path="/some/path")