diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319..290c9f304 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -94,6 +94,10 @@ def elicitations(self) -> list[ElicitRequestURLParams]: """The list of URL elicitations required before the request can proceed.""" return self._elicitations + def __reduce__(self) -> tuple[type, tuple[list[ElicitRequestURLParams], str]]: + """Support pickling by reconstructing with the original __init__ signature.""" + return (self.__class__, (self._elicitations, self.error.message)) + @classmethod def from_error(cls, error: ErrorData) -> UrlElicitationRequiredError: """Reconstruct from an ErrorData received over the wire.""" diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py index 9a7466264..92ea6d8a0 100644 --- a/tests/shared/test_exceptions.py +++ b/tests/shared/test_exceptions.py @@ -1,5 +1,7 @@ """Tests for MCP exception classes.""" +import pickle + import pytest from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError @@ -162,3 +164,61 @@ def test_url_elicitation_required_error_exception_message() -> None: # The exception's string representation should match the message assert str(error) == "URL elicitation required" + + +def test_mcp_error_pickle_roundtrip() -> None: + """Test that MCPError survives a pickle round-trip.""" + original = MCPError(code=-32600, message="Invalid request", data={"detail": "bad"}) + + restored = pickle.loads(pickle.dumps(original)) + + assert type(restored) is MCPError + assert restored.error.code == original.error.code + assert restored.error.message == original.error.message + assert restored.error.data == original.error.data + + +def test_url_elicitation_required_error_pickle_roundtrip() -> None: + """Test that UrlElicitationRequiredError survives a pickle round-trip.""" + elicitations = [ + ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitation_id="test-123", + ), + ] + original = UrlElicitationRequiredError(elicitations, message="Please authenticate") + + restored = pickle.loads(pickle.dumps(original)) + + assert type(restored) is UrlElicitationRequiredError + assert restored.error.code == URL_ELICITATION_REQUIRED + assert restored.error.message == "Please authenticate" + assert len(restored.elicitations) == 1 + assert restored.elicitations[0].elicitation_id == "test-123" + assert restored.elicitations[0].url == "https://example.com/auth" + + +def test_url_elicitation_required_error_pickle_default_message() -> None: + """Test pickle round-trip preserves the auto-generated default message.""" + elicitations = [ + ElicitRequestURLParams( + mode="url", + message="Auth", + url="https://example.com/auth", + elicitation_id="e1", + ), + ElicitRequestURLParams( + mode="url", + message="Auth2", + url="https://example.com/auth2", + elicitation_id="e2", + ), + ] + original = UrlElicitationRequiredError(elicitations) + + restored = pickle.loads(pickle.dumps(original)) + + assert restored.error.message == "URL elicitations required" + assert len(restored.elicitations) == 2