diff --git a/mlx_vlm/responses_models.py b/mlx_vlm/responses_models.py
new file mode 100644
index 000000000..ecd4d6fee
--- /dev/null
+++ b/mlx_vlm/responses_models.py
@@ -0,0 +1,497 @@
+"""Pydantic models for the OpenAI Responses API (/v1/responses).
+
+This module defines all request, response, and streaming event models
+for the OpenAI-compatible Responses endpoint. Models are self-contained
+to avoid circular imports with server.py.
+
+Reference: https://developers.openai.com/api/reference/resources/responses
+"""
+
+import uuid
+from typing import Any, List, Literal, Optional, Union
+
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Required, TypeAlias, TypedDict
+
+# ---------------------------------------------------------------------------
+# Constants (mirrored from mlx_vlm.generate to avoid heavy imports)
+# ---------------------------------------------------------------------------
+
+DEFAULT_MAX_TOKENS = 4096
+DEFAULT_TEMPERATURE = 0.0
+DEFAULT_TOP_P = 1.0
+DEFAULT_THINKING_START_TOKEN = ""
+DEFAULT_THINKING_END_TOKEN = ""
+
+
+# ---------------------------------------------------------------------------
+# Base classes (duplicated from server.py for import independence)
+# ---------------------------------------------------------------------------
+
+
+class FlexibleBaseModel(BaseModel):
+ """Base model that silently accepts unknown fields for forward compatibility."""
+
+ model_config = ConfigDict(extra="allow")
+
+ def dump_kwargs(self, *fields: str) -> dict[str, Any]:
+ """Return a dict of the requested fields, omitting ``None`` values."""
+ return {
+ key: getattr(self, key)
+ for key in fields
+ if hasattr(self, key) and getattr(self, key) is not None
+ }
+
+
+class GenerationParams(FlexibleBaseModel):
+ """Sampling parameters shared across endpoints."""
+
+ temperature: float = Field(
+ DEFAULT_TEMPERATURE, description="Temperature for sampling."
+ )
+ top_p: float = Field(DEFAULT_TOP_P, description="Top-p sampling.")
+ top_k: Optional[int] = Field(None, description="Top-k sampling cutoff.")
+ min_p: Optional[float] = Field(None, description="Min-p sampling threshold.")
+ repetition_penalty: Optional[float] = Field(
+ None, description="Penalty applied to repeated tokens."
+ )
+ logit_bias: Optional[dict[int, float]] = Field(
+ None, description="Additive logit bias keyed by token id."
+ )
+
+ def shared_generation_kwargs(self) -> dict[str, Any]:
+ return self.dump_kwargs(
+ "temperature",
+ "top_p",
+ "top_k",
+ "min_p",
+ "repetition_penalty",
+ "logit_bias",
+ )
+
+
+class TemplateParams(FlexibleBaseModel):
+ """Chat template parameters (thinking mode, etc.)."""
+
+ enable_thinking: Optional[bool] = Field(
+ None, description="Enable thinking mode in the chat template."
+ )
+ thinking_budget: Optional[int] = Field(
+ None,
+ description="Maximum number of thinking tokens before forcing the end token.",
+ )
+ thinking_start_token: Optional[str] = Field(
+ DEFAULT_THINKING_START_TOKEN,
+ description="Token that marks the start of a thinking block.",
+ )
+ thinking_end_token: Optional[str] = Field(
+ DEFAULT_THINKING_END_TOKEN,
+ description="Token that marks the end of a thinking block.",
+ )
+
+ def template_kwargs(self) -> dict[str, Any]:
+ kwargs = self.dump_kwargs(
+ "enable_thinking",
+ "thinking_budget",
+ "thinking_start_token",
+ "thinking_end_token",
+ )
+ kwargs.setdefault("enable_thinking", False)
+ return kwargs
+
+
+# ---------------------------------------------------------------------------
+# Input content types (TypedDicts matching OpenAI SDK)
+# ---------------------------------------------------------------------------
+
+
+class ResponseInputTextParam(TypedDict, total=False):
+ """Text content item — accepts both ``input_text`` and ``text`` types."""
+
+ text: Required[str]
+ type: Required[Literal["input_text", "text"]]
+
+
+class ResponseInputImageParam(TypedDict, total=False):
+ """Image content item with a direct image URL."""
+
+ detail: Literal["high", "low", "auto"]
+ type: Required[Literal["input_image"]]
+ image_url: Required[str]
+ file_id: Optional[str]
+
+
+class InputAudio(TypedDict, total=False):
+ data: Required[str]
+ format: Required[str]
+
+
+class ResponseInputAudioParam(TypedDict, total=False):
+ """Audio content item."""
+
+ type: Required[Literal["input_audio"]]
+ input_audio: Required[InputAudio]
+
+
+class ImageUrl(TypedDict, total=False):
+ url: Required[str]
+
+
+class ResponseImageUrlParam(TypedDict, total=False):
+ """Image content item with nested ``image_url.url`` (chat/completions format)."""
+
+ type: Required[Literal["image_url"]]
+ image_url: Required[ImageUrl]
+
+
+class ResponseOutputText(TypedDict, total=False):
+ """Output text item used in multi-turn assistant messages."""
+
+ text: Required[str]
+ type: Required[Literal["output_text"]]
+
+
+ResponseInputContentParam: TypeAlias = Union[
+ ResponseInputTextParam,
+ ResponseInputImageParam,
+ ResponseImageUrlParam,
+ ResponseInputAudioParam,
+]
+
+ResponseInputMessageContentListParam: TypeAlias = List[ResponseInputContentParam]
+ResponseOutputMessageContentList: TypeAlias = List[ResponseOutputText]
+
+
+# ---------------------------------------------------------------------------
+# Chat message model
+# ---------------------------------------------------------------------------
+
+
+class ChatMessage(FlexibleBaseModel):
+ """A single message in the conversation input."""
+
+ role: Literal["user", "assistant", "system", "developer", "tool"] = Field(
+ ..., description="Role of the message sender."
+ )
+ content: Optional[
+ Union[
+ str,
+ ResponseInputMessageContentListParam,
+ ResponseOutputMessageContentList,
+ ]
+ ] = Field(None, description="Content of the message.")
+ tool_calls: List = Field(default_factory=list)
+
+
+# ---------------------------------------------------------------------------
+# Function tool definition
+# ---------------------------------------------------------------------------
+
+
+class ResponseFunctionTool(BaseModel):
+ """A function tool the model may call."""
+
+ type: Literal["function"] = "function"
+ name: str = Field(..., description="The name of the function.")
+ description: Optional[str] = Field(
+ None, description="A description of what the function does."
+ )
+ parameters: Optional[dict] = Field(
+ None, description="JSON Schema object describing the function parameters."
+ )
+ strict: Optional[bool] = Field(
+ None, description="Whether to enforce strict schema adherence."
+ )
+
+
+# ---------------------------------------------------------------------------
+# Function call input items (for multi-turn tool use)
+# ---------------------------------------------------------------------------
+
+
+class ResponseFunctionCallInputItem(BaseModel):
+ """A function call from a previous assistant turn, included in input."""
+
+ type: Literal["function_call"] = "function_call"
+ call_id: str = Field(..., description="Unique ID for this tool call.")
+ name: str = Field(..., description="The function name that was called.")
+ arguments: str = Field(..., description="JSON string of the function arguments.")
+ status: Optional[str] = "completed"
+
+
+class ResponseFunctionCallOutputInputItem(BaseModel):
+ """The output/result of a function call, sent back by the client."""
+
+ type: Literal["function_call_output"] = "function_call_output"
+ call_id: str = Field(
+ ..., description="The call_id of the function call this is a result for."
+ )
+ output: str = Field(..., description="The function output as a string.")
+
+
+# ---------------------------------------------------------------------------
+# Request model
+# ---------------------------------------------------------------------------
+
+
+class ResponsesRequest(GenerationParams, TemplateParams):
+ """OpenAI Responses API request body.
+
+ Reference: https://developers.openai.com/api/reference/resources/responses/create
+ """
+
+ input: Union[str, List[Any]] = Field(
+ ..., description="Input text or list of input items (messages, tool outputs)."
+ )
+ model: str = Field(..., description="The model to use for generation.")
+ max_output_tokens: int = Field(
+ DEFAULT_MAX_TOKENS, description="Maximum number of tokens to generate."
+ )
+ stream: bool = Field(
+ False, description="Whether to stream the response chunk by chunk."
+ )
+ tools: Optional[List[dict]] = Field(
+ None, description="Tool definitions the model may call."
+ )
+ tool_choice: Optional[Any] = Field(
+ "auto", description='Tool choice: "none", "auto", "required", or specific tool.'
+ )
+ parallel_tool_calls: bool = Field(
+ True, description="Allow parallel tool calls."
+ )
+ previous_response_id: Optional[str] = Field(
+ None,
+ description="ID of a previous response for multi-turn context replay.",
+ )
+ instructions: Optional[str] = Field(
+ None,
+ description="System/developer message inserted into context.",
+ )
+ metadata: Optional[dict] = Field(
+ None, description="Up to 16 key-value pairs of metadata."
+ )
+
+ def generation_kwargs(self) -> dict[str, Any]:
+ kwargs = self.dump_kwargs("max_output_tokens")
+ kwargs["max_tokens"] = kwargs.pop("max_output_tokens")
+ return {**kwargs, **self.shared_generation_kwargs()}
+
+
+# ---------------------------------------------------------------------------
+# Output item models
+# ---------------------------------------------------------------------------
+
+
+class ContentPartOutputText(BaseModel):
+ """A text content part in an output message."""
+
+ type: Literal["output_text"] = "output_text"
+ text: str = ""
+ annotations: List[str] = Field(default_factory=list)
+
+
+class ResponseMessageItem(BaseModel):
+ """An assistant message output item."""
+
+ id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4().hex[:24]}")
+ type: Literal["message"] = "message"
+ role: Literal["assistant"] = "assistant"
+ status: Literal["in_progress", "completed"] = "completed"
+ content: List[ContentPartOutputText] = Field(default_factory=list)
+
+
+class ResponseFunctionCallItem(BaseModel):
+ """A function call output item."""
+
+ type: Literal["function_call"] = "function_call"
+ id: str = Field(default_factory=lambda: f"fc_{uuid.uuid4().hex[:24]}")
+ call_id: str = Field(default_factory=lambda: f"call_{uuid.uuid4().hex[:24]}")
+ name: str = Field(..., description="The function name being called.")
+ arguments: str = Field(..., description="JSON string of the function arguments.")
+ status: Literal["completed"] = "completed"
+
+
+class ResponseIncompleteDetails(BaseModel):
+ """Details about why a response is incomplete."""
+
+ reason: Literal["max_output_tokens", "content_filter"]
+
+
+# ---------------------------------------------------------------------------
+# Usage and error models
+# ---------------------------------------------------------------------------
+
+
+class ResponseUsage(BaseModel):
+ """Token usage details."""
+
+ input_tokens: int
+ output_tokens: int
+ total_tokens: int
+
+
+class ResponseErrorObject(BaseModel):
+ """Error object returned when the model fails to generate a Response."""
+
+ code: Optional[str] = None
+ message: Optional[str] = None
+ param: Optional[str] = None
+ type: Optional[str] = None
+
+
+# ---------------------------------------------------------------------------
+# Response object
+# ---------------------------------------------------------------------------
+
+
+class ResponseObject(BaseModel):
+ """The top-level Response object returned by /v1/responses.
+
+ Reference: https://developers.openai.com/api/reference/resources/responses/object
+ """
+
+ id: str = Field(
+ default_factory=lambda: f"resp_{uuid.uuid4().hex[:24]}",
+ description="Unique identifier for this Response.",
+ )
+ object: Literal["response"] = Field(
+ "response", description="The object type — always ``response``."
+ )
+ created_at: int = Field(..., description="Unix timestamp of creation.")
+ status: Literal["completed", "failed", "in_progress", "incomplete"] = Field(
+ "completed", description="The status of the response generation."
+ )
+ error: Optional[ResponseErrorObject] = Field(None)
+ incomplete_details: Optional[ResponseIncompleteDetails] = Field(None)
+ instructions: Optional[str] = Field(None)
+ max_output_tokens: Optional[int] = Field(None)
+ model: str = Field(..., description="Model ID used to generate the response.")
+ output: List[Union[ResponseMessageItem, ResponseFunctionCallItem]] = Field(
+ default_factory=list,
+ description="An array of content items generated by the model.",
+ )
+ parallel_tool_calls: bool = Field(True)
+ previous_response_id: Optional[str] = Field(None)
+ temperature: Optional[float] = Field(None, ge=0, le=2)
+ top_p: Optional[float] = Field(None, ge=0, le=1)
+ tools: List = Field(default_factory=list)
+ tool_choice: Optional[Any] = Field("auto")
+ truncation: Literal["auto", "disabled"] = Field("disabled")
+ metadata: Optional[dict] = Field(None)
+ usage: ResponseUsage = Field(..., description="Token usage details.")
+ user: Optional[str] = Field(None)
+
+ @property
+ def output_text(self) -> str:
+ """Aggregate text from all output_text content parts."""
+ parts = []
+ for item in self.output:
+ if isinstance(item, ResponseMessageItem):
+ for part in item.content:
+ if part.type == "output_text" and part.text:
+ parts.append(part.text)
+ return "".join(parts) or ""
+
+
+# ---------------------------------------------------------------------------
+# Streaming event models
+# ---------------------------------------------------------------------------
+
+
+class BaseStreamEvent(BaseModel):
+ """Base class for all SSE streaming events."""
+
+ type: str
+ sequence_number: int = 0
+
+
+class ResponseCreatedEvent(BaseStreamEvent):
+ type: Literal["response.created"] = "response.created"
+ response: ResponseObject
+
+
+class ResponseInProgressEvent(BaseStreamEvent):
+ type: Literal["response.in_progress"] = "response.in_progress"
+ response: ResponseObject
+
+
+class ResponseOutputItemAddedEvent(BaseStreamEvent):
+ type: Literal["response.output_item.added"] = "response.output_item.added"
+ output_index: int
+ item: Union[ResponseMessageItem, ResponseFunctionCallItem]
+
+
+class ResponseContentPartAddedEvent(BaseStreamEvent):
+ type: Literal["response.content_part.added"] = "response.content_part.added"
+ item_id: str
+ output_index: int
+ content_index: int
+ part: ContentPartOutputText
+
+
+class ResponseOutputTextDeltaEvent(BaseStreamEvent):
+ type: Literal["response.output_text.delta"] = "response.output_text.delta"
+ item_id: str
+ output_index: int
+ content_index: int
+ delta: str
+
+
+class ResponseOutputTextDoneEvent(BaseStreamEvent):
+ type: Literal["response.output_text.done"] = "response.output_text.done"
+ item_id: str
+ output_index: int
+ content_index: int
+ text: str
+
+
+class ResponseContentPartDoneEvent(BaseStreamEvent):
+ type: Literal["response.content_part.done"] = "response.content_part.done"
+ item_id: str
+ output_index: int
+ content_index: int
+ part: ContentPartOutputText
+
+
+class ResponseOutputItemDoneEvent(BaseStreamEvent):
+ type: Literal["response.output_item.done"] = "response.output_item.done"
+ output_index: int
+ item: Union[ResponseMessageItem, ResponseFunctionCallItem]
+
+
+class ResponseFunctionCallArgumentsDeltaEvent(BaseStreamEvent):
+ type: Literal["response.function_call_arguments.delta"] = (
+ "response.function_call_arguments.delta"
+ )
+ item_id: str
+ output_index: int
+ delta: str
+
+
+class ResponseFunctionCallArgumentsDoneEvent(BaseStreamEvent):
+ type: Literal["response.function_call_arguments.done"] = (
+ "response.function_call_arguments.done"
+ )
+ item_id: str
+ output_index: int
+ arguments: str
+
+
+class ResponseCompletedEvent(BaseStreamEvent):
+ type: Literal["response.completed"] = "response.completed"
+ response: ResponseObject
+
+
+StreamEvent = Union[
+ ResponseCreatedEvent,
+ ResponseInProgressEvent,
+ ResponseOutputItemAddedEvent,
+ ResponseContentPartAddedEvent,
+ ResponseOutputTextDeltaEvent,
+ ResponseOutputTextDoneEvent,
+ ResponseContentPartDoneEvent,
+ ResponseOutputItemDoneEvent,
+ ResponseFunctionCallArgumentsDeltaEvent,
+ ResponseFunctionCallArgumentsDoneEvent,
+ ResponseCompletedEvent,
+]
diff --git a/mlx_vlm/responses_store.py b/mlx_vlm/responses_store.py
new file mode 100644
index 000000000..c671b5826
--- /dev/null
+++ b/mlx_vlm/responses_store.py
@@ -0,0 +1,125 @@
+"""LRU response store for OpenAI Responses API previous_response_id support."""
+
+import threading
+from collections import OrderedDict
+from typing import Any, Optional
+
+
+class ResponseStore:
+ """Bounded LRU store mapping response IDs to (input_items, response_object) pairs.
+
+ Used to support the ``previous_response_id`` parameter in the Responses API,
+ which allows clients to chain responses without resending full conversation
+ history.
+
+ Args:
+ maxsize: Maximum number of responses to store. When exceeded, the oldest
+ entry is evicted. Defaults to 256.
+ """
+
+ def __init__(self, maxsize: int = 256):
+ self._store: OrderedDict[str, dict] = OrderedDict()
+ self._maxsize = maxsize
+ self._lock = threading.Lock()
+
+ def save(
+ self,
+ response_id: str,
+ input_items: Any,
+ response_output: list,
+ ) -> None:
+ """Save a response for later replay.
+
+ Args:
+ response_id: The unique response ID (e.g., ``"resp_abc123"``).
+ input_items: The original request input (string or list of input items).
+ response_output: The response output items list (dicts or model instances).
+ """
+ with self._lock:
+ if response_id in self._store:
+ self._store.move_to_end(response_id)
+ self._store[response_id] = {
+ "input": input_items,
+ "output": response_output,
+ }
+ while len(self._store) > self._maxsize:
+ self._store.popitem(last=False)
+
+ def get(self, response_id: str) -> Optional[dict]:
+ """Retrieve a stored response by ID.
+
+ Args:
+ response_id: The response ID to look up.
+
+ Returns:
+ Dict with ``"input"`` and ``"output"`` keys, or ``None`` if not found.
+ """
+ with self._lock:
+ entry = self._store.get(response_id)
+ if entry is not None:
+ self._store.move_to_end(response_id)
+ return entry
+
+ def replay_input(self, response_id: str) -> Optional[list]:
+ """Build conversation input by replaying a previous response.
+
+ Reconstructs input items from the stored response: the original input
+ items followed by the output items converted to input format.
+
+ Args:
+ response_id: The previous response ID to replay.
+
+ Returns:
+ List of input items suitable for prepending to the current request,
+ or ``None`` if the response ID is not found.
+ """
+ entry = self.get(response_id)
+ if entry is None:
+ return None
+
+ items = []
+
+ # Add original input items
+ original_input = entry["input"]
+ if isinstance(original_input, str):
+ items.append({"role": "user", "content": original_input})
+ elif isinstance(original_input, list):
+ items.extend(original_input)
+
+ # Convert output items to input format
+ for output_item in entry.get("output", []):
+ if isinstance(output_item, dict):
+ item_type = output_item.get("type", "")
+ if item_type == "message":
+ # Collect all output_text parts into a single assistant message
+ content = output_item.get("content", [])
+ text_parts = [
+ {"type": "output_text", "text": part.get("text", "")}
+ for part in content
+ if isinstance(part, dict) and part.get("type") == "output_text"
+ ]
+ if text_parts:
+ items.append({
+ "role": "assistant",
+ "content": text_parts,
+ })
+ elif item_type == "function_call":
+ items.append(
+ {
+ "type": "function_call",
+ "call_id": output_item.get("call_id", ""),
+ "name": output_item.get("name", ""),
+ "arguments": output_item.get("arguments", ""),
+ }
+ )
+
+ return items
+
+ def __len__(self) -> int:
+ with self._lock:
+ return len(self._store)
+
+ def clear(self) -> None:
+ """Remove all stored responses."""
+ with self._lock:
+ self._store.clear()
diff --git a/mlx_vlm/server.py b/mlx_vlm/server.py
index 04c011f6a..e4b985e32 100644
--- a/mlx_vlm/server.py
+++ b/mlx_vlm/server.py
@@ -7,7 +7,6 @@
import traceback
import uuid
from contextlib import asynccontextmanager
-from datetime import datetime
from typing import Any, List, Literal, Optional, Union
import mlx.core as mx
@@ -40,10 +39,34 @@
from .utils import load
from .version import __version__
from .vision_cache import VisionFeatureCache
+from .responses_models import (
+ ResponsesRequest,
+ ResponseObject,
+ ResponseUsage,
+ ResponseErrorObject,
+ ResponseIncompleteDetails,
+ ResponseMessageItem,
+ ResponseFunctionCallItem,
+ ContentPartOutputText as RespContentPartOutputText,
+ ResponseCreatedEvent as RespCreatedEvent,
+ ResponseInProgressEvent as RespInProgressEvent,
+ ResponseOutputItemAddedEvent as RespOutputItemAddedEvent,
+ ResponseContentPartAddedEvent as RespContentPartAddedEvent,
+ ResponseOutputTextDeltaEvent as RespOutputTextDeltaEvent,
+ ResponseOutputTextDoneEvent as RespOutputTextDoneEvent,
+ ResponseContentPartDoneEvent as RespContentPartDoneEvent,
+ ResponseOutputItemDoneEvent as RespOutputItemDoneEvent,
+ ResponseFunctionCallArgumentsDeltaEvent as RespFuncCallArgsDeltaEvent,
+ ResponseFunctionCallArgumentsDoneEvent as RespFuncCallArgsDoneEvent,
+ ResponseCompletedEvent as RespCompletedEvent,
+)
+from .responses_store import ResponseStore
DEFAULT_SERVER_HOST = "0.0.0.0"
DEFAULT_SERVER_PORT = 8080
+_responses_store = ResponseStore()
+
def get_prefill_step_size():
return int(os.environ.get("PREFILL_STEP_SIZE", DEFAULT_PREFILL_STEP_SIZE))
@@ -704,199 +727,363 @@ class ModelsResponse(BaseModel):
data: List[ModelInfo]
-# OpenAI compatile endpoints
+# ---------------------------------------------------------------------------
+# Responses API helpers
+# ---------------------------------------------------------------------------
-@app.post("/responses")
-@app.post("/v1/responses", include_in_schema=False)
-async def responses_endpoint(openai_request: OpenAIRequest):
+_MAX_REPLAY_DEPTH = 50
+
+
+def responses_input_to_messages(
+ input_items: Union[str, list],
+ instructions: Optional[str] = None,
+ previous_response_id: Optional[str] = None,
+ _depth: int = 0,
+ _seen: Optional[set] = None,
+) -> tuple[list[dict], list[str]]:
+ """Convert Responses API input items to chat messages and images.
+
+ Args:
+ input_items: String input or list of input items.
+ instructions: Optional system instructions to prepend.
+ previous_response_id: Optional previous response ID for context replay.
+ _depth: Internal recursion depth counter.
+ _seen: Internal set of visited response IDs for cycle detection.
+
+ Returns:
+ Tuple of (chat_messages, image_urls).
"""
- OpenAI-compatible endpoint for generating text based on a prompt and optional images.
+ if _seen is None:
+ _seen = set()
+
+ chat_messages: list[dict] = []
+ images: list[str] = []
+
+ # Replay previous response context (with depth + cycle guard)
+ if previous_response_id:
+ if _depth >= _MAX_REPLAY_DEPTH:
+ raise HTTPException(
+ status_code=400,
+ detail=f"previous_response_id chain exceeds maximum depth ({_MAX_REPLAY_DEPTH}).",
+ )
+ if previous_response_id in _seen:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Cycle detected in previous_response_id chain: {previous_response_id}",
+ )
+ _seen.add(previous_response_id)
- using client.responses.create method.
+ replayed = _responses_store.replay_input(previous_response_id)
+ if replayed is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Previous response not found: {previous_response_id}",
+ )
+ prev_messages, prev_images = responses_input_to_messages(
+ replayed, _depth=_depth + 1, _seen=_seen,
+ )
+ chat_messages.extend(prev_messages)
+ images.extend(prev_images)
+
+ # Prepend instructions as system message
+ if instructions:
+ chat_messages.insert(0, {"role": "system", "content": instructions})
+
+ # Handle string input
+ if isinstance(input_items, str):
+ chat_messages.append({"role": "user", "content": input_items})
+ return chat_messages, images
+
+ # Handle list of input items
+ for item in input_items:
+ if isinstance(item, dict):
+ item_type = item.get("type", "")
+ role = item.get("role", "")
+
+ # Function call output item
+ if item_type == "function_call_output":
+ call_id = item.get("call_id", "unknown")
+ output = item.get("output", "")
+ chat_messages.append({
+ "role": "tool",
+ "content": output,
+ "tool_call_id": call_id,
+ })
+ continue
+
+ # Function call item (from previous assistant turn)
+ if item_type == "function_call":
+ chat_messages.append({
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [{
+ "id": item.get("call_id", ""),
+ "type": "function",
+ "function": {
+ "name": item.get("name", ""),
+ "arguments": item.get("arguments", ""),
+ },
+ }],
+ })
+ continue
+
+ # Regular message with role and content
+ if role:
+ content = item.get("content", "")
+
+ # Normalize developer role to system
+ msg_role = "system" if role == "developer" else role
+
+ if isinstance(content, str):
+ chat_messages.append({"role": msg_role, "content": content})
+ elif isinstance(content, list):
+ # Process content items
+ text_parts = []
+ for ci in content:
+ if isinstance(ci, dict):
+ ci_type = ci.get("type", "")
+ if ci_type in ("input_text", "text"):
+ text_parts.append(ci.get("text", ""))
+ elif ci_type == "input_image":
+ url = ci.get("image_url", "")
+ if url:
+ images.append(url)
+ elif ci_type == "image_url":
+ img = ci.get("image_url", {})
+ if isinstance(img, dict):
+ url = img.get("url", "")
+ if url:
+ images.append(url)
+ elif isinstance(img, str) and img:
+ images.append(img)
+ elif ci_type == "output_text":
+ # Multi-turn: previous assistant output
+ chat_messages.append({
+ "role": "assistant",
+ "content": ci.get("text", ""),
+ })
+ elif ci_type == "input_audio":
+ pass # Audio not yet supported in responses
+ else:
+ pass # Skip unsupported content types gracefully
+
+ if text_parts:
+ chat_messages.append({
+ "role": msg_role,
+ "content": "\n".join(text_parts),
+ })
+ else:
+ chat_messages.append({
+ "role": msg_role,
+ "content": str(content) if content else "",
+ })
+ continue
+
+ # Handle Pydantic ChatMessage objects
+ elif hasattr(item, "role"):
+ role = item.role
+ msg_role = "system" if role == "developer" else role
+ content = item.content
+
+ if content is None:
+ chat_messages.append({"role": msg_role, "content": ""})
+ elif isinstance(content, str):
+ chat_messages.append({"role": msg_role, "content": content})
+ elif isinstance(content, list):
+ text_parts = []
+ for ci in content:
+ if isinstance(ci, dict):
+ ci_type = ci.get("type", "")
+ if ci_type in ("input_text", "text"):
+ text_parts.append(ci.get("text", ""))
+ elif ci_type == "input_image":
+ url = ci.get("image_url", "")
+ if url:
+ images.append(url)
+ elif ci_type == "image_url":
+ img = ci.get("image_url", {})
+ if isinstance(img, dict):
+ url = img.get("url", "")
+ if url:
+ images.append(url)
+ elif isinstance(img, str) and img:
+ images.append(img)
+ elif ci_type == "output_text":
+ chat_messages.append({
+ "role": "assistant",
+ "content": ci.get("text", ""),
+ })
+
+ if text_parts:
+ chat_messages.append({
+ "role": msg_role,
+ "content": "\n".join(text_parts),
+ })
+
+ return chat_messages, images
+
+
+def build_responses_output(
+ raw_text: str,
+ tool_parser_type: Optional[str],
+ tool_module: Optional[Any],
+ tools: Optional[list],
+) -> list[Union[ResponseMessageItem, ResponseFunctionCallItem]]:
+ """Build structured Responses API output items from raw model text.
+
+ Parses tool calls from the raw text if a tool parser is available,
+ creating ResponseFunctionCallItem for each detected call and a
+ ResponseMessageItem for any remaining text.
+
+ Args:
+ raw_text: The raw text output from the model.
+ tool_parser_type: The detected tool parser type (e.g., "gemma4"), or None.
+ tool_module: The loaded tool parser module, or None.
+ tools: The tool definitions from the request, or None.
+
+ Returns:
+ List of output items (message items and/or function call items).
+ """
+ output_items: list[Union[ResponseMessageItem, ResponseFunctionCallItem]] = []
+ remaining_text = raw_text
- example:
+ # Try to parse tool calls
+ if tool_parser_type and tool_module and tools:
+ try:
+ result = process_tool_calls(raw_text, tool_module, tools)
+ if result["calls"]:
+ for call in result["calls"]:
+ func_info = call.get("function", {})
+ output_items.append(
+ ResponseFunctionCallItem(
+ name=func_info.get("name", ""),
+ arguments=func_info.get("arguments", "{}"),
+ call_id=call.get("id", f"call_{uuid.uuid4().hex[:24]}"),
+ )
+ )
+ remaining_text = result.get("remaining_text", "").strip()
+ except Exception as e:
+ print(f"Warning: tool call parsing failed: {e}")
+ remaining_text = raw_text
- from openai import OpenAI
+ # Create message item for any remaining text
+ if remaining_text or not output_items:
+ msg_item = ResponseMessageItem(
+ content=[RespContentPartOutputText(text=remaining_text)] if remaining_text else [],
+ )
+ # Insert message before function calls (matching OpenAI ordering)
+ output_items.insert(0, msg_item)
- API_URL = "http://0.0.0.0:8000"
- API_KEY = 'any'
+ return output_items
- def run_openai(prompt, img_url,system, stream=False, max_output_tokens=512, model="mlx-community/Qwen2.5-VL-3B-Instruct-8bit"):
- ''' Calls the OpenAI API
- '''
- client = OpenAI(base_url=f"{API_URL}", api_key=API_KEY)
+# OpenAI compatible endpoints
- try :
- response = client.responses.create(
- model=model,
- input=[
- {"role":"system",
- "content": f"{system}"
- },
- {
- "role": "user",
- "content": [
- {"type": "input_text", "text": prompt},
- {"type": "input_image", "image_url": f"{img_url}"},
- ],
- }
- ],
- max_output_tokens=max_output_tokens,
- stream=stream
- )
- if not stream:
- print(response.output[0].content[0].text)
- print(response.usage)
- else:
- for event in response:
- # Process different event types if needed
- if hasattr(event, 'delta') and event.delta:
- print(event.delta, end="", flush=True)
- elif event.type == 'response.completed':
- print("\n--- Usage ---")
- print(event.response.usage)
- except Exception as e:
- # building a response object to match the one returned when request is successful so that it can be processed in the same way
- return {"model - error":str(e),"content":{}, "model":model}
+@app.post("/responses")
+@app.post("/v1/responses", include_in_schema=False)
+async def responses_endpoint(request: ResponsesRequest):
+ """OpenAI-compatible Responses API endpoint.
+ Supports tool calling, multi-turn via previous_response_id, and streaming
+ with proper SSE event sequences including function_call argument events.
"""
try:
# Get model, processor, config - loading if necessary
- model, processor, config = get_cached_model(openai_request.model)
+ model, processor, config = get_cached_model(request.model)
- chat_messages = []
- images = []
- instructions = None
- if openai_request.input:
- if isinstance(openai_request.input, str):
- # If input is a string, treat it as a single text message
- chat_messages.append({"role": "user", "content": openai_request.input})
- elif isinstance(openai_request.input, list):
- # If input is a list, treat it as a series of chat messages
- for message in openai_request.input:
- if isinstance(message, ChatMessage):
- if message.content is None:
- chat_messages.append({"role": message.role, "content": ""})
- elif isinstance(message.content, str):
- chat_messages.append(
- {"role": message.role, "content": message.content}
- )
- if message.role == "system":
- instructions = message.content
- elif isinstance(message.content, list):
- # Handle list of content items
- for item in message.content:
- if isinstance(item, dict):
- if item["type"] == "input_text":
- chat_messages.append(
- {
- "role": message.role,
- "content": item["text"],
- }
- )
- if message.role == "system":
- instructions = item["text"]
- # examples for multiple images (https://platform.openai.com/docs/guides/images?api-mode=responses)
- elif item["type"] == "input_image":
- images.append(item["image_url"])
- else:
- print(
- f"invalid input item type: {item['type']}"
- )
- raise HTTPException(
- status_code=400,
- detail="Invalid input item type.",
- )
- else:
- print(
- f"Invalid message content item format: {item}"
- )
- raise HTTPException(
- status_code=400,
- detail="Missing type in input item.",
- )
- else:
- print("Invalid message content format.")
- raise HTTPException(
- status_code=400, detail="Invalid input format."
- )
- else:
- print("not a ChatMessage")
- raise HTTPException(
- status_code=400, detail="Invalid input format."
- )
- else:
- print("neither string not list")
- raise HTTPException(status_code=400, detail="Invalid input format.")
+ # Convert input to chat messages
+ chat_messages, images = responses_input_to_messages(
+ request.input,
+ instructions=request.instructions,
+ previous_response_id=request.previous_response_id,
+ )
- else:
- print("no input")
- raise HTTPException(status_code=400, detail="Missing input.")
+ # Set up tool parser
+ tools = request.tools
+ tool_parser_type = None
+ tool_module = None
+ tokenizer = (
+ processor.tokenizer if hasattr(processor, "tokenizer") else processor
+ )
+ if hasattr(tokenizer, "chat_template") and tools:
+ tool_parser_type = _infer_tool_parser(tokenizer.chat_template)
+ if tool_parser_type is not None:
+ tool_module = load_tool_module(tool_parser_type)
+
+ # Build template kwargs
+ template_kwargs = request.template_kwargs()
- template_kwargs = openai_request.template_kwargs()
+ # Apply chat template (pass tools so the template can include tool defs)
formatted_prompt = apply_chat_template(
processor,
config,
chat_messages,
num_images=len(images),
+ tools=tools,
**template_kwargs,
)
- generation_kwargs = build_generation_kwargs(openai_request, template_kwargs)
+ generation_kwargs = build_generation_kwargs(request, template_kwargs)
- generated_at = datetime.now().timestamp()
- response_id = f"resp_{uuid.uuid4().hex}"
- message_id = f"msg_{uuid.uuid4().hex}"
+ generated_at = int(time.time())
+ response_id = f"resp_{uuid.uuid4().hex[:24]}"
+ message_id = f"msg_{uuid.uuid4().hex[:24]}"
- if openai_request.stream:
+ if request.stream:
+ # ----------------------------------------------------------
# Streaming response
- async def stream_generator():
- token_iterator = None
+ # ----------------------------------------------------------
+ async def stream_responses_generator():
+ seq = 0 # sequence_number counter
+
+ def _evt(event_type: str, event_obj) -> str:
+ nonlocal seq
+ event_obj.sequence_number = seq
+ seq += 1
+ return f"event: {event_type}\ndata: {event_obj.model_dump_json()}\n\n"
+
try:
- # Create base response object (to match the openai pipeline)
- base_response = OpenAIResponse(
+ # Build base ResponseObject (in_progress, empty output)
+ base_response = ResponseObject(
id=response_id,
- object="response",
- created_at=int(generated_at),
+ created_at=generated_at,
status="in_progress",
- instructions=instructions,
- max_output_tokens=openai_request.max_output_tokens,
- model=openai_request.model,
+ model=request.model,
output=[],
- output_text="",
- temperature=openai_request.temperature,
- top_p=openai_request.top_p,
- usage={
- "input_tokens": 0, # get prompt tokens
- "output_tokens": 0,
- "total_tokens": 0,
- },
+ instructions=request.instructions,
+ max_output_tokens=request.max_output_tokens,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ tools=tools or [],
+ tool_choice=request.tool_choice,
+ parallel_tool_calls=request.parallel_tool_calls,
+ previous_response_id=request.previous_response_id,
+ metadata=request.metadata,
+ usage=ResponseUsage(input_tokens=0, output_tokens=0, total_tokens=0),
)
- # Send response.created event (to match the openai pipeline)
- yield f"event: response.created\ndata: {ResponseCreatedEvent(type='response.created', response=base_response).model_dump_json()}\n\n"
-
- # Send response.in_progress event (to match the openai pipeline)
- yield f"event: response.in_progress\ndata: {ResponseInProgressEvent(type='response.in_progress', response=base_response).model_dump_json()}\n\n"
+ # response.created
+ yield _evt("response.created", RespCreatedEvent(response=base_response))
+ # response.in_progress
+ yield _evt("response.in_progress", RespInProgressEvent(response=base_response))
- # Send response.output_item.added event (to match the openai pipeline)
- message_item = MessageItem(
- id=message_id,
- type="message",
- status="in_progress",
- role="assistant",
- content=[],
+ # output_item.added (message)
+ msg_item = ResponseMessageItem(id=message_id, status="in_progress", content=[])
+ yield _evt(
+ "response.output_item.added",
+ RespOutputItemAddedEvent(output_index=0, item=msg_item),
)
- yield f"event: response.output_item.added\ndata: {ResponseOutputItemAddedEvent(type='response.output_item.added', output_index=0, item=message_item).model_dump_json()}\n\n"
- # Send response.content_part.added event
- content_part = ContentPartOutputText(
- type="output_text", text="", annotations=[]
+ # content_part.added
+ empty_part = RespContentPartOutputText(text="")
+ yield _evt(
+ "response.content_part.added",
+ RespContentPartAddedEvent(
+ item_id=message_id, output_index=0, content_index=0, part=empty_part,
+ ),
)
- yield f"event: response.content_part.added\ndata: {ResponseContentPartAddedEvent(type='response.content_part.added', item_id=message_id, output_index=0, content_index=0, part=content_part).model_dump_json()}\n\n"
# Stream text deltas
token_iterator = stream_generate(
@@ -909,59 +1096,170 @@ async def stream_generator():
)
full_text = ""
+ visible_text = ""
+ usage_stats = {"input_tokens": 0, "output_tokens": 0}
+ in_tool_call = False
+ tool_call_start_tag = getattr(tool_module, "tool_call_start", "") if tool_module else None
+ tool_call_end_tag = getattr(tool_module, "tool_call_end", None) if tool_module else None
+
for chunk in token_iterator:
if chunk is None or not hasattr(chunk, "text"):
continue
delta = chunk.text
full_text += delta
-
usage_stats = {
"input_tokens": chunk.prompt_tokens,
"output_tokens": chunk.generation_tokens,
}
- # Send response.output_text.delta event
- yield f"event: response.output_text.delta\ndata: {ResponseOutputTextDeltaEvent(type='response.output_text.delta', item_id=message_id, output_index=0, content_index=0, delta=delta).model_dump_json()}\n\n"
+ # Suppress tool call markup from being streamed as text
+ if tool_call_start_tag and tools:
+ if not in_tool_call and tool_call_start_tag in full_text:
+ in_tool_call = True
+ elif in_tool_call and tool_call_end_tag and tool_call_end_tag in full_text:
+ in_tool_call = False
+ if in_tool_call:
+ continue
+
+ # Check for partial tag at end of buffer
+ if full_text.endswith(tool_call_start_tag[:1]):
+ tail = full_text[-(len(tool_call_start_tag)):]
+ if any(
+ tool_call_start_tag[:i] == tail[-i:]
+ for i in range(1, len(tool_call_start_tag))
+ ):
+ continue
+
+ visible_text += delta
+ yield _evt(
+ "response.output_text.delta",
+ RespOutputTextDeltaEvent(
+ item_id=message_id, output_index=0, content_index=0, delta=delta,
+ ),
+ )
+
+ # Determine finish reason
+ max_tok = request.max_output_tokens
+ is_length = usage_stats["output_tokens"] >= max_tok
+ status = "incomplete" if is_length else "completed"
- # Send response.output_text.done event (to match the openai pipeline)
- yield f"event: response.output_text.done\ndata: {ResponseOutputTextDoneEvent(type='response.output_text.done', item_id=message_id, output_index=0, content_index=0, text=full_text).model_dump_json()}\n\n"
+ # Use visible_text (sans tool call markup) for text events
+ display_text = visible_text.strip()
- # Send response.content_part.done event (to match the openai pipeline)
- final_content_part = ContentPartOutputText(
- type="output_text", text=full_text, annotations=[]
+ # output_text.done
+ yield _evt(
+ "response.output_text.done",
+ RespOutputTextDoneEvent(
+ item_id=message_id, output_index=0, content_index=0, text=display_text,
+ ),
)
- yield f"event: response.content_part.done\ndata: {ResponseContentPartDoneEvent(type='response.content_part.done', item_id=message_id, output_index=0, content_index=0, part=final_content_part).model_dump_json()}\n\n"
-
- # Send response.output_item.done event (to match the openai pipeline)
- final_message_item = MessageItem(
- id=message_id,
- type="message",
- status="completed",
- role="assistant",
- content=[final_content_part],
+
+ # content_part.done
+ final_part = RespContentPartOutputText(text=display_text)
+ yield _evt(
+ "response.content_part.done",
+ RespContentPartDoneEvent(
+ item_id=message_id, output_index=0, content_index=0, part=final_part,
+ ),
+ )
+
+ # output_item.done (message)
+ final_msg = ResponseMessageItem(
+ id=message_id, status="completed", content=[final_part],
)
- yield f"event: response.output_item.done\ndata: {ResponseOutputItemDoneEvent(type='response.output_item.done', output_index=0, item=final_message_item).model_dump_json()}\n\n"
+ yield _evt(
+ "response.output_item.done",
+ RespOutputItemDoneEvent(output_index=0, item=final_msg),
+ )
+
+ # Collect all output items for final response
+ all_output_items: list = [final_msg]
+
+ # Parse tool calls from accumulated text
+ if tool_parser_type and tool_module and tools:
+ try:
+ tc_result = process_tool_calls(full_text, tool_module, tools)
+ if tc_result["calls"]:
+ for idx, call in enumerate(tc_result["calls"]):
+ func_info = call.get("function", {})
+ fc_item = ResponseFunctionCallItem(
+ name=func_info.get("name", ""),
+ arguments=func_info.get("arguments", "{}"),
+ call_id=call.get("id", f"call_{uuid.uuid4().hex[:24]}"),
+ )
+ out_idx = len(all_output_items)
+
+ # output_item.added (function_call)
+ yield _evt(
+ "response.output_item.added",
+ RespOutputItemAddedEvent(output_index=out_idx, item=fc_item),
+ )
+
+ # function_call_arguments.delta (full arguments in one shot)
+ yield _evt(
+ "response.function_call_arguments.delta",
+ RespFuncCallArgsDeltaEvent(
+ item_id=fc_item.id,
+ output_index=out_idx,
+ delta=fc_item.arguments,
+ ),
+ )
+
+ # function_call_arguments.done
+ yield _evt(
+ "response.function_call_arguments.done",
+ RespFuncCallArgsDoneEvent(
+ item_id=fc_item.id,
+ output_index=out_idx,
+ arguments=fc_item.arguments,
+ ),
+ )
- # Send response.completed event (to match the openai pipeline)
+ # output_item.done (function_call)
+ yield _evt(
+ "response.output_item.done",
+ RespOutputItemDoneEvent(output_index=out_idx, item=fc_item),
+ )
+
+ all_output_items.append(fc_item)
+ except Exception as e:
+ print(f"Warning: streaming tool call parsing failed: {e}")
+
+ # response.completed
+ total_tokens = usage_stats["input_tokens"] + usage_stats["output_tokens"]
completed_response = base_response.model_copy(
update={
- "status": "completed",
- "output": [final_message_item],
- "usage": {
- "input_tokens": usage_stats["input_tokens"],
- "output_tokens": usage_stats["output_tokens"],
- "total_tokens": usage_stats["input_tokens"]
- + usage_stats["output_tokens"],
- },
+ "status": status,
+ "output": all_output_items,
+ "incomplete_details": (
+ ResponseIncompleteDetails(reason="max_output_tokens")
+ if status == "incomplete"
+ else None
+ ),
+ "usage": ResponseUsage(
+ input_tokens=usage_stats["input_tokens"],
+ output_tokens=usage_stats["output_tokens"],
+ total_tokens=total_tokens,
+ ),
}
)
- yield f"event: response.completed\ndata: {ResponseCompletedEvent(type='response.completed', response=completed_response).model_dump_json()}\n\n"
+ yield _evt("response.completed", RespCompletedEvent(response=completed_response))
+
+ # Save to store for previous_response_id
+ _responses_store.save(
+ response_id,
+ request.input if isinstance(request.input, str) else [
+ item.model_dump() if hasattr(item, "model_dump") else item
+ for item in request.input
+ ],
+ [item.model_dump() for item in all_output_items],
+ )
except Exception as e:
print(f"Error during stream generation: {e}")
traceback.print_exc()
- error_data = json.dumps({"error": str(e)})
+ error_data = json.dumps({"error": "Internal generation error"})
yield f"data: {error_data}\n\n"
finally:
@@ -970,7 +1268,7 @@ async def stream_generator():
print("Stream finished, cleared cache.")
return StreamingResponse(
- stream_generator(),
+ stream_responses_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@@ -980,51 +1278,71 @@ async def stream_generator():
)
else:
+ # ----------------------------------------------------------
# Non-streaming response
+ # ----------------------------------------------------------
try:
- # Use generate from generate.py
result = generate(
model=model,
processor=processor,
prompt=formatted_prompt,
image=images,
- verbose=False, # stats are passed in the response
+ verbose=False,
+ vision_cache=model_cache.get("vision_cache"),
**generation_kwargs,
)
- # Clean up resources
mx.clear_cache()
gc.collect()
print("Generation finished, cleared cache.")
- response = OpenAIResponse(
+ # Build output items (with tool call parsing)
+ output_items = build_responses_output(
+ result.text, tool_parser_type, tool_module, tools,
+ )
+
+ # Determine status
+ is_length = result.generation_tokens >= request.max_output_tokens
+ status = "incomplete" if is_length else "completed"
+ incomplete_details = (
+ ResponseIncompleteDetails(reason="max_output_tokens")
+ if status == "incomplete"
+ else None
+ )
+
+ response_obj = ResponseObject(
id=response_id,
- object="response",
- created_at=int(generated_at),
- status="completed",
- instructions=instructions,
- max_output_tokens=openai_request.max_output_tokens,
- model=openai_request.model,
- output=[
- {
- "role": "assistant",
- "content": [
- {
- "type": "output_text",
- "text": result.text,
- }
- ],
- }
+ created_at=generated_at,
+ model=request.model,
+ output=output_items,
+ status=status,
+ incomplete_details=incomplete_details,
+ instructions=request.instructions,
+ max_output_tokens=request.max_output_tokens,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ tools=tools or [],
+ tool_choice=request.tool_choice,
+ parallel_tool_calls=request.parallel_tool_calls,
+ previous_response_id=request.previous_response_id,
+ metadata=request.metadata,
+ usage=ResponseUsage(
+ input_tokens=result.prompt_tokens,
+ output_tokens=result.generation_tokens,
+ total_tokens=result.total_tokens,
+ ),
+ )
+
+ # Save to store for previous_response_id support
+ _responses_store.save(
+ response_obj.id,
+ request.input if isinstance(request.input, str) else [
+ item.model_dump() if hasattr(item, "model_dump") else item
+ for item in request.input
],
- output_text=result.text,
- temperature=openai_request.temperature,
- top_p=openai_request.top_p,
- usage={
- "input_tokens": result.prompt_tokens,
- "output_tokens": result.generation_tokens,
- "total_tokens": result.total_tokens,
- },
+ [item.model_dump() for item in output_items],
)
- return response
+
+ return response_obj.model_dump()
except Exception as e:
print(f"Error during generation: {e}")
@@ -1033,11 +1351,9 @@ async def stream_generator():
gc.collect()
raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
- except HTTPException as http_exc:
- # Re-raise HTTP exceptions (like model loading failure)
- raise http_exc
+ except HTTPException:
+ raise
except Exception as e:
- # Catch unexpected errors
print(f"Unexpected error in /responses endpoint: {e}")
traceback.print_exc()
mx.clear_cache()
diff --git a/mlx_vlm/tests/test_responses_api.py b/mlx_vlm/tests/test_responses_api.py
new file mode 100644
index 000000000..1546aad39
--- /dev/null
+++ b/mlx_vlm/tests/test_responses_api.py
@@ -0,0 +1,476 @@
+"""Tests for the OpenAI Responses API (/v1/responses) compliance.
+
+Covers:
+ A. Model validation (pure unit tests, no server/mlx needed)
+ B. Response store (pure unit tests)
+ C. Functional endpoint tests (TestClient, mocked model)
+ D. Streaming endpoint tests (TestClient, mocked model)
+"""
+
+import importlib.util
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+
+# ---------------------------------------------------------------------------
+# Helpers: load modules without triggering mlx_vlm.__init__ (no mlx needed)
+# ---------------------------------------------------------------------------
+
+def _load_module(name: str, filename: str):
+ """Load a sibling module by file path, bypassing package __init__."""
+ mod_path = Path(__file__).parent.parent / filename
+ spec = importlib.util.spec_from_file_location(name, str(mod_path))
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ return mod
+
+
+responses_models = _load_module("responses_models", "responses_models.py")
+responses_store = _load_module("responses_store", "responses_store.py")
+
+ResponsesRequest = responses_models.ResponsesRequest
+ResponseObject = responses_models.ResponseObject
+ResponseMessageItem = responses_models.ResponseMessageItem
+ResponseFunctionCallItem = responses_models.ResponseFunctionCallItem
+ContentPartOutputText = responses_models.ContentPartOutputText
+ResponseUsage = responses_models.ResponseUsage
+FlexibleBaseModel = responses_models.FlexibleBaseModel
+BaseStreamEvent = responses_models.BaseStreamEvent
+ResponseStore = responses_store.ResponseStore
+
+
+# =========================================================================
+# A. Model Validation Tests
+# =========================================================================
+
+
+class TestResponsesModels:
+ """Pure unit tests for Pydantic models in responses_models.py."""
+
+ def test_responses_request_accepts_string_input(self):
+ req = ResponsesRequest(input="Hello", model="test-model")
+ assert req.input == "Hello"
+
+ def test_responses_request_accepts_message_list(self):
+ msgs = [{"role": "user", "content": "hello"}]
+ req = ResponsesRequest(input=msgs, model="test-model")
+ assert isinstance(req.input, list)
+ assert len(req.input) == 1
+
+ def test_responses_request_accepts_tools(self):
+ tools = [
+ {
+ "type": "function",
+ "name": "get_weather",
+ "description": "Get the weather",
+ "parameters": {"type": "object", "properties": {}},
+ }
+ ]
+ req = ResponsesRequest(input="hi", model="m", tools=tools)
+ assert req.tools is not None
+ assert len(req.tools) == 1
+
+ def test_responses_request_default_tool_choice(self):
+ req = ResponsesRequest(input="hi", model="m")
+ assert req.tool_choice == "auto"
+
+ def test_responses_request_generation_kwargs(self):
+ req = ResponsesRequest(input="hi", model="m", max_output_tokens=128)
+ kwargs = req.generation_kwargs()
+ assert "max_tokens" in kwargs
+ assert kwargs["max_tokens"] == 128
+ assert "max_output_tokens" not in kwargs
+
+ def test_response_object_output_text_computed(self):
+ msg = ResponseMessageItem(
+ content=[
+ ContentPartOutputText(text="Hello "),
+ ContentPartOutputText(text="world!"),
+ ]
+ )
+ resp = ResponseObject(
+ created_at=0,
+ model="m",
+ output=[msg],
+ usage=ResponseUsage(input_tokens=1, output_tokens=2, total_tokens=3),
+ )
+ assert resp.output_text == "Hello world!"
+
+ def test_response_object_output_text_empty_when_only_function_calls(self):
+ fc = ResponseFunctionCallItem(name="fn", arguments='{"a":1}')
+ resp = ResponseObject(
+ created_at=0,
+ model="m",
+ output=[fc],
+ usage=ResponseUsage(input_tokens=1, output_tokens=2, total_tokens=3),
+ )
+ assert resp.output_text == ""
+
+ def test_function_call_item_auto_ids(self):
+ fc = ResponseFunctionCallItem(name="fn", arguments="{}")
+ assert fc.id.startswith("fc_")
+ assert fc.call_id.startswith("call_")
+ # IDs should be unique per instance
+ fc2 = ResponseFunctionCallItem(name="fn", arguments="{}")
+ assert fc.id != fc2.id
+
+ def test_function_call_item_schema(self):
+ fc = ResponseFunctionCallItem(name="get_weather", arguments='{"city":"NYC"}')
+ assert fc.name == "get_weather"
+ assert fc.arguments == '{"city":"NYC"}'
+ assert fc.type == "function_call"
+
+ def test_content_part_output_text_defaults(self):
+ part = ContentPartOutputText()
+ assert part.type == "output_text"
+ assert part.text == ""
+ assert part.annotations == []
+
+ def test_streaming_event_sequence_number(self):
+ evt = BaseStreamEvent(type="test.event", sequence_number=42)
+ assert evt.sequence_number == 42
+ evt_default = BaseStreamEvent(type="test.event")
+ assert evt_default.sequence_number == 0
+
+ def test_flexible_base_model_accepts_unknown_fields(self):
+ req = ResponsesRequest(
+ input="hi", model="m", some_unknown_field="surprise"
+ )
+ # Should not raise; extra field accessible via model_extra
+ assert req.model_extra.get("some_unknown_field") == "surprise"
+
+
+# =========================================================================
+# B. Response Store Tests
+# =========================================================================
+
+
+class TestResponseStore:
+ """Pure unit tests for the LRU ResponseStore."""
+
+ def test_store_save_and_get(self):
+ store = ResponseStore()
+ store.save("resp_1", "hello", [{"type": "message"}])
+ entry = store.get("resp_1")
+ assert entry is not None
+ assert entry["input"] == "hello"
+ assert entry["output"] == [{"type": "message"}]
+
+ def test_store_get_missing_returns_none(self):
+ store = ResponseStore()
+ assert store.get("resp_nonexistent") is None
+
+ def test_store_lru_eviction(self):
+ store = ResponseStore(maxsize=2)
+ store.save("resp_a", "a", [])
+ store.save("resp_b", "b", [])
+ store.save("resp_c", "c", []) # should evict resp_a
+ assert store.get("resp_a") is None
+ assert store.get("resp_b") is not None
+ assert store.get("resp_c") is not None
+
+ def test_store_replay_string_input(self):
+ store = ResponseStore()
+ store.save("resp_1", "hello", [])
+ items = store.replay_input("resp_1")
+ assert items is not None
+ assert len(items) == 1
+ assert items[0]["role"] == "user"
+ assert items[0]["content"] == "hello"
+
+ def test_store_replay_message_list_input(self):
+ original = [
+ {"role": "user", "content": "What is 2+2?"},
+ {"role": "system", "content": "You are helpful."},
+ ]
+ store = ResponseStore()
+ store.save("resp_1", original, [])
+ items = store.replay_input("resp_1")
+ assert items is not None
+ assert len(items) == 2
+ assert items[0]["role"] == "user"
+ assert items[1]["role"] == "system"
+
+ def test_store_replay_function_call_output(self):
+ output = [
+ {
+ "type": "function_call",
+ "call_id": "call_123",
+ "name": "get_weather",
+ "arguments": '{"city":"NYC"}',
+ }
+ ]
+ store = ResponseStore()
+ store.save("resp_1", "hello", output)
+ items = store.replay_input("resp_1")
+ assert items is not None
+ # First item is the original user input, second is the function call
+ fc_items = [i for i in items if i.get("type") == "function_call"]
+ assert len(fc_items) == 1
+ assert fc_items[0]["name"] == "get_weather"
+ assert fc_items[0]["call_id"] == "call_123"
+
+ def test_store_replay_missing_returns_none(self):
+ store = ResponseStore()
+ assert store.replay_input("resp_nope") is None
+
+ def test_store_clear(self):
+ store = ResponseStore()
+ store.save("resp_1", "a", [])
+ store.save("resp_2", "b", [])
+ assert len(store) == 2
+ store.clear()
+ assert len(store) == 0
+ assert store.get("resp_1") is None
+
+
+# =========================================================================
+# C. Functional Endpoint Tests (require mlx for server import)
+# =========================================================================
+
+# Guard: skip functional/streaming tests if mlx is unavailable, but let
+# the pure-unit tests above run on any platform.
+_has_mlx = importlib.util.find_spec("mlx") is not None
+
+if _has_mlx:
+ import mlx_vlm.server as server # noqa: E402
+ from fastapi.testclient import TestClient # noqa: E402
+
+_skip_no_mlx = pytest.mark.skipif(not _has_mlx, reason="mlx not installed")
+
+
+# Shared mock objects (safe to create even without mlx)
+mock_model = MagicMock()
+mock_processor = MagicMock()
+mock_processor.tokenizer = MagicMock()
+mock_processor.tokenizer.chat_template = ""
+mock_config = SimpleNamespace(model_type="test")
+
+
+def _mock_result(text="Hello world!", prompt_tokens=10, gen_tokens=5):
+ """Build a SimpleNamespace matching generate() return shape."""
+ return SimpleNamespace(
+ text=text,
+ prompt_tokens=prompt_tokens,
+ generation_tokens=gen_tokens,
+ total_tokens=prompt_tokens + gen_tokens,
+ )
+
+
+@pytest.fixture
+def client():
+ with TestClient(server.app) as c:
+ yield c
+
+
+def _patch_model():
+ return patch.object(
+ server, "get_cached_model",
+ return_value=(mock_model, mock_processor, mock_config),
+ )
+
+
+def _patch_template():
+ return patch.object(server, "apply_chat_template", return_value="prompt")
+
+
+def _patch_generate(result=None):
+ if result is None:
+ result = _mock_result()
+ return patch.object(server, "generate", return_value=result)
+
+
+@_skip_no_mlx
+class TestResponsesEndpoint:
+ """Functional tests for POST /responses."""
+
+ def test_basic_text_response(self, client):
+ with _patch_model(), _patch_template(), _patch_generate():
+ resp = client.post(
+ "/responses",
+ json={"model": "demo", "input": "Hello"},
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["object"] == "response"
+ assert data["status"] == "completed"
+ assert "id" in data
+ assert "output" in data
+ assert "usage" in data
+
+ def test_response_with_message_list(self, client):
+ with _patch_model(), _patch_template(), _patch_generate():
+ resp = client.post(
+ "/responses",
+ json={
+ "model": "demo",
+ "input": [{"role": "user", "content": "hello"}],
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["status"] == "completed"
+
+ def test_instructions_field_echoed(self, client):
+ with _patch_model(), _patch_template(), _patch_generate():
+ resp = client.post(
+ "/responses",
+ json={
+ "model": "demo",
+ "input": [
+ {"role": "system", "content": "Be brief."},
+ {"role": "user", "content": "hi"},
+ ],
+ "instructions": "Be brief.",
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ # The instructions field should be present in the response
+ assert data.get("instructions") is not None
+
+ def test_tools_field_echoed(self, client):
+ tools = [
+ {
+ "type": "function",
+ "name": "get_weather",
+ "description": "Get weather",
+ "parameters": {"type": "object", "properties": {}},
+ }
+ ]
+ with _patch_model(), _patch_template(), _patch_generate():
+ resp = client.post(
+ "/responses",
+ json={"model": "demo", "input": "hi", "tools": tools},
+ )
+ assert resp.status_code == 200
+
+ def test_previous_response_id_not_found(self, client):
+ """Referencing a non-existent previous_response_id should return an error."""
+ with _patch_model(), _patch_template(), _patch_generate():
+ resp = client.post(
+ "/responses",
+ json={
+ "model": "demo",
+ "input": "follow-up",
+ "previous_response_id": "resp_nonexistent999",
+ },
+ )
+ # The server should either 404 or 200 (ignoring unknown ID).
+ # We just verify it doesn't crash with a 500.
+ assert resp.status_code in (200, 404)
+
+ def test_developer_role_mapped_to_system(self, client):
+ """developer role should be accepted (mapped to system internally)."""
+ with _patch_model(), _patch_template(), _patch_generate():
+ resp = client.post(
+ "/responses",
+ json={
+ "model": "demo",
+ "input": [
+ {"role": "developer", "content": "You are helpful."},
+ {"role": "user", "content": "hi"},
+ ],
+ },
+ )
+ # Should not crash; accept 200 or 422 if server rejects developer role
+ assert resp.status_code in (200, 422)
+
+ def test_text_type_alias(self, client):
+ """'text' type should be accepted alongside 'input_text'."""
+ with _patch_model(), _patch_template(), _patch_generate():
+ resp = client.post(
+ "/responses",
+ json={
+ "model": "demo",
+ "input": [
+ {
+ "role": "user",
+ "content": [{"type": "input_text", "text": "hi"}],
+ }
+ ],
+ },
+ )
+ assert resp.status_code == 200
+
+ def test_function_call_output_input(self, client):
+ """function_call_output items in input should not crash the server."""
+ with _patch_model(), _patch_template(), _patch_generate():
+ resp = client.post(
+ "/responses",
+ json={
+ "model": "demo",
+ "input": [
+ {"role": "user", "content": "call a tool"},
+ {
+ "type": "function_call_output",
+ "call_id": "call_abc",
+ "output": '{"result": 42}',
+ },
+ ],
+ },
+ )
+ # May fail if server doesn't handle function_call_output yet;
+ # accept anything except unhandled 500
+ assert resp.status_code in (200, 400, 422)
+
+ def test_max_output_tokens_incomplete(self, client):
+ """When finish_reason is 'length', the response status should ideally be 'incomplete'."""
+ result = _mock_result(text="truncated...")
+ with _patch_model(), _patch_template(), _patch_generate(result):
+ resp = client.post(
+ "/responses",
+ json={
+ "model": "demo",
+ "input": "Write a very long essay",
+ "max_output_tokens": 5,
+ },
+ )
+ assert resp.status_code == 200
+ # Just verify the response is well-formed
+ data = resp.json()
+ assert data["status"] in ("completed", "incomplete")
+
+
+@_skip_no_mlx
+class TestResponsesStreaming:
+ """Streaming SSE tests for POST /responses with stream=true."""
+
+ def _stream_events(self, client, payload):
+ """Helper: POST with stream=True and collect SSE events."""
+ with _patch_model(), _patch_template():
+ # Mock stream_generate to yield chunks
+ chunks = [
+ SimpleNamespace(text="Hello", prompt_tokens=10, generation_tokens=1),
+ SimpleNamespace(text=" world", prompt_tokens=10, generation_tokens=2),
+ ]
+
+ def mock_stream_gen(**kwargs):
+ return iter(chunks)
+
+ with patch.object(server, "stream_generate", side_effect=mock_stream_gen):
+ resp = client.post("/responses", json=payload)
+ return resp
+
+ def test_streaming_sse_events(self, client):
+ payload = {"model": "demo", "input": "Hello", "stream": True}
+ resp = self._stream_events(client, payload)
+ assert resp.status_code == 200
+ body = resp.text
+ # Should contain key event types
+ assert "event: response.created" in body
+ assert "event: response.output_text.delta" in body
+ assert "event: response.completed" in body
+
+ def test_streaming_done_sentinel(self, client):
+ """The stream should end properly (response.completed is the last real event)."""
+ payload = {"model": "demo", "input": "Hello", "stream": True}
+ resp = self._stream_events(client, payload)
+ assert resp.status_code == 200
+ body = resp.text
+ # The last meaningful event should be response.completed
+ lines = [l for l in body.strip().split("\n") if l.startswith("event:")]
+ assert lines[-1] == "event: response.completed"