diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5c1459dff..f72ea8dc1 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -17,9 +17,11 @@ async def run_server(): ``` """ +import os import sys from contextlib import asynccontextmanager -from io import TextIOWrapper +from io import TextIOWrapper, UnsupportedOperation +from typing import TextIO import anyio import anyio.lowlevel @@ -29,6 +31,20 @@ async def run_server(): from mcp.shared.message import SessionMessage +def _wrap_stdio_text_stream(stream: TextIO, mode: str, errors: str = "strict") -> anyio.AsyncFile[str]: + """Wrap a stdio text stream without closing the original handle on teardown.""" + try: + wrapped_stream = TextIOWrapper( + os.fdopen(os.dup(stream.fileno()), mode, closefd=True), + encoding="utf-8", + errors=errors, + ) + except (AttributeError, UnsupportedOperation): + wrapped_stream = TextIOWrapper(stream.buffer, encoding="utf-8", errors=errors) + + return anyio.wrap_file(wrapped_stream) + + @asynccontextmanager async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): """Server transport for stdio: this communicates with an MCP client by reading @@ -38,10 +54,13 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. # standard process handles. Encoding of stdin/stdout as text streams on # python is platform-dependent (Windows is particularly problematic), so we # re-wrap the underlying binary stream to ensure UTF-8. + close_stdin = stdin is None + close_stdout = stdout is None + if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) + stdin = _wrap_stdio_text_stream(sys.stdin, "rb", errors="replace") if not stdout: - stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) + stdout = _wrap_stdio_text_stream(sys.stdout, "wb") read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) write_stream, write_stream_reader = create_context_streams[SessionMessage](0) @@ -71,7 +90,13 @@ async def stdout_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg: - tg.start_soon(stdin_reader) - tg.start_soon(stdout_writer) - yield read_stream, write_stream + try: + async with anyio.create_task_group() as tg: + tg.start_soon(stdin_reader) + tg.start_soon(stdout_writer) + yield read_stream, write_stream + finally: + if close_stdin: + await stdin.aclose() + if close_stdout: + await stdout.aclose() diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 677a99356..fddf80654 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,5 +1,6 @@ import io import sys +import tempfile from io import TextIOWrapper import anyio @@ -73,12 +74,15 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): """ # \xff\xfe are invalid UTF-8 start bytes. valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdin = tempfile.TemporaryFile("w+b") + raw_stdin.write(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdin.seek(0) + raw_stdout = tempfile.TemporaryFile("w+b") # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that # stdio_server()'s default path wraps it with errors='replace'. monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) - monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(raw_stdout, encoding="utf-8")) with anyio.fail_after(5): async with stdio_server() as (read_stream, write_stream): @@ -92,3 +96,60 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): second = await read_stream.receive() assert isinstance(second, SessionMessage) assert second.message == valid + + sys.stdin.close() + sys.stdout.close() + + +@pytest.mark.anyio +async def test_stdio_server_does_not_close_process_stdio(monkeypatch: pytest.MonkeyPatch): + """Default stdio_server() teardown must not close the caller's stdio handles.""" + valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + raw_stdin = tempfile.TemporaryFile("w+b") + raw_stdin.write(valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdin.seek(0) + raw_stdout = tempfile.TemporaryFile("w+b") + + monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(raw_stdout, encoding="utf-8")) + + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert received.message == valid + + assert not sys.stdin.closed + assert not sys.stdout.closed + + sys.stdout.write("still-open") + sys.stdout.flush() + raw_stdout.seek(0) + assert raw_stdout.read() == b"still-open" + + sys.stdin.close() + sys.stdout.close() + + +@pytest.mark.anyio +async def test_stdio_server_falls_back_when_stream_has_no_fileno(monkeypatch: pytest.MonkeyPatch): + """Streams without a real fd (e.g. pytest capture, in-memory buffers) must + fall back to wrapping the underlying ``.buffer`` instead of crashing.""" + valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + stdin_buf = io.BytesIO(valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + stdout_buf = io.BytesIO() + + # io.BytesIO raises UnsupportedOperation from .fileno(), forcing the + # buffer-wrapping fallback in _wrap_stdio_text_stream. + monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_buf, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(stdout_buf, encoding="utf-8")) + + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert received.message == valid