Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions src/trawl_mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
"""Stdio/HTTP MCP server exposing trawl's fetch_page and profile_page tools.

The pipeline uses sync_playwright internally, which can't run inside an
asyncio event loop on its own. We run each pipeline invocation in a worker
thread via `asyncio.to_thread`, keeping the MCP server responsive.
asyncio event loop on its own. We run every pipeline invocation on a
single dedicated worker thread so the process-wide sync_playwright
greenlet dispatcher — which is pinned to the thread that first called
sync_playwright() — always sees the same thread. Using
`asyncio.to_thread` (default executor) instead causes intermittent
"Cannot switch to a different thread" greenlet errors whenever a call
is dispatched to a different worker thread.
"""

from __future__ import annotations

import asyncio
import functools
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor

from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool

from trawl import fetch_relevant, to_dict

# Dedicated executor for every pipeline invocation. It is capped at a single
# worker so that all calls run on the same thread: the sync_playwright
# greenlet dispatcher is pinned to the thread that first called
# sync_playwright(), and dispatching a call to any other thread raises
# intermittent "Cannot switch to a different thread" greenlet errors.
# Do not raise max_workers without revisiting that constraint.
_pipeline_executor = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="trawl-pipeline"
)

logger = logging.getLogger("trawl_mcp")

server: Server = Server("trawl")
Expand Down Expand Up @@ -166,13 +177,17 @@ async def _call_fetch_page(arguments: dict) -> list[TextContent]:
use_hyde,
use_rerank,
)
result = await asyncio.to_thread(
fetch_relevant,
url,
query,
k=k,
use_hyde=use_hyde,
use_rerank=use_rerank,
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
_pipeline_executor,
functools.partial(
fetch_relevant,
url,
query,
k=k,
use_hyde=use_hyde,
use_rerank=use_rerank,
),
)
payload = to_dict(result)
payload["ok"] = not bool(payload.get("error"))
Expand All @@ -195,10 +210,14 @@ async def _call_profile_page(arguments: dict) -> list[TextContent]:
# when the tool is actually called.
from trawl.profiles import generate_profile

payload = await asyncio.to_thread(
generate_profile,
url,
force_refresh=force_refresh,
loop = asyncio.get_running_loop()
payload = await loop.run_in_executor(
_pipeline_executor,
functools.partial(
generate_profile,
url,
force_refresh=force_refresh,
),
)
return [TextContent(type="text", text=json.dumps(payload, ensure_ascii=False))]

Expand Down
Loading