From bd5faee1916d9e3afc4b8afe601a35bfe9074032 Mon Sep 17 00:00:00 2001 From: Aashish Ghimire Date: Sat, 25 Apr 2026 22:06:42 -0700 Subject: [PATCH 1/2] Fix stdio_server closing process stdio --- src/mcp/server/stdio.py | 38 ++++++++++++++++++++++++++------- tests/server/test_stdio.py | 43 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 9 deletions(-) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5c1459dff..aafe3cd95 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -17,9 +17,10 @@ async def run_server(): ``` """ +import os import sys from contextlib import asynccontextmanager -from io import TextIOWrapper +from io import TextIOWrapper, UnsupportedOperation import anyio import anyio.lowlevel @@ -29,6 +30,20 @@ async def run_server(): from mcp.shared.message import SessionMessage +def _wrap_stdio_text_stream(stream: TextIOWrapper, mode: str, errors: str = "strict") -> anyio.AsyncFile[str]: + """Wrap a stdio text stream without closing the original handle on teardown.""" + try: + wrapped_stream = TextIOWrapper( + os.fdopen(os.dup(stream.fileno()), mode, closefd=True), + encoding="utf-8", + errors=errors, + ) + except (AttributeError, UnsupportedOperation): + wrapped_stream = TextIOWrapper(stream.buffer, encoding="utf-8", errors=errors) + + return anyio.wrap_file(wrapped_stream) + + @asynccontextmanager async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None): """Server transport for stdio: this communicates with an MCP client by reading @@ -38,10 +53,13 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. # standard process handles. Encoding of stdin/stdout as text streams on # python is platform-dependent (Windows is particularly problematic), so we # re-wrap the underlying binary stream to ensure UTF-8. + close_stdin = stdin is None + close_stdout = stdout is None + if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) + stdin = _wrap_stdio_text_stream(sys.stdin, "rb", errors="replace") if not stdout: - stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) + stdout = _wrap_stdio_text_stream(sys.stdout, "wb") read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) write_stream, write_stream_reader = create_context_streams[SessionMessage](0) @@ -71,7 +89,13 @@ async def stdout_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg: - tg.start_soon(stdin_reader) - tg.start_soon(stdout_writer) - yield read_stream, write_stream + try: + async with anyio.create_task_group() as tg: + tg.start_soon(stdin_reader) + tg.start_soon(stdout_writer) + yield read_stream, write_stream + finally: + if close_stdin: + await stdin.aclose() + if close_stdout: + await stdout.aclose() diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 677a99356..b8075ef2e 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,5 +1,6 @@ import io import sys +import tempfile from io import TextIOWrapper import anyio @@ -73,12 +74,15 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): """ # \xff\xfe are invalid UTF-8 start bytes. valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdin = tempfile.TemporaryFile("w+b") + raw_stdin.write(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdin.seek(0) + raw_stdout = tempfile.TemporaryFile("w+b") # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that # stdio_server()'s default path wraps it with errors='replace'. monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) - monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(raw_stdout, encoding="utf-8")) with anyio.fail_after(5): async with stdio_server() as (read_stream, write_stream): @@ -92,3 +96,38 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): second = await read_stream.receive() assert isinstance(second, SessionMessage) assert second.message == valid + + sys.stdin.close() + sys.stdout.close() + + +@pytest.mark.anyio +async def test_stdio_server_does_not_close_process_stdio(monkeypatch: pytest.MonkeyPatch): + """Default stdio_server() teardown must not close the caller's stdio handles.""" + valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + raw_stdin = tempfile.TemporaryFile("w+b") + raw_stdin.write(valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdin.seek(0) + raw_stdout = tempfile.TemporaryFile("w+b") + + monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(raw_stdout, encoding="utf-8")) + + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert received.message == valid + + assert not sys.stdin.closed + assert not sys.stdout.closed + + sys.stdout.write("still-open") + sys.stdout.flush() + raw_stdout.seek(0) + assert raw_stdout.read() == b"still-open" + + sys.stdin.close() + sys.stdout.close() From f9fa9c728c54919953af17949a1e887adcd6e608 Mon Sep 17 00:00:00 2001 From: Aashish Ghimire Date: Sat, 25 Apr 2026 22:52:49 -0700 Subject: [PATCH 2/2] fix(stdio): satisfy pyright and cover dup-fd fallback - Widen `_wrap_stdio_text_stream`'s parameter type from `TextIOWrapper` to `typing.TextIO` so `sys.stdin`/`sys.stdout` (typed `TextIO` in typeshed) pass type checking. - Add a regression test using `io.BytesIO`-backed streams to cover the `AttributeError`/`UnsupportedOperation` fallback path (needed to keep coverage at 100%). --- src/mcp/server/stdio.py | 3 ++- tests/server/test_stdio.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index aafe3cd95..f72ea8dc1 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -21,6 +21,7 @@ async def run_server(): import sys from contextlib import asynccontextmanager from io import TextIOWrapper, UnsupportedOperation +from typing import TextIO import anyio import anyio.lowlevel @@ -30,7 +31,7 @@ async def run_server(): from mcp.shared.message import SessionMessage -def _wrap_stdio_text_stream(stream: TextIOWrapper, mode: str, errors: str = "strict") -> anyio.AsyncFile[str]: +def _wrap_stdio_text_stream(stream: TextIO, mode: str, errors: str = "strict") -> anyio.AsyncFile[str]: """Wrap a stdio text stream without closing the original handle on teardown.""" try: wrapped_stream = TextIOWrapper( diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index b8075ef2e..fddf80654 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -131,3 +131,25 @@ async def test_stdio_server_does_not_close_process_stdio(monkeypatch: pytest.Mon sys.stdin.close() sys.stdout.close() + + +@pytest.mark.anyio +async def test_stdio_server_falls_back_when_stream_has_no_fileno(monkeypatch: pytest.MonkeyPatch): + """Streams without a real fd (e.g. pytest capture, in-memory buffers) must + fall back to wrapping the underlying ``.buffer`` instead of crashing.""" + valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + stdin_buf = io.BytesIO(valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + stdout_buf = io.BytesIO() + + # io.BytesIO raises UnsupportedOperation from .fileno(), forcing the + # buffer-wrapping fallback in _wrap_stdio_text_stream. + monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_buf, encoding="utf-8")) + monkeypatch.setattr(sys, "stdout", TextIOWrapper(stdout_buf, encoding="utf-8")) + + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert received.message == valid