diff --git a/mlx_vlm/server.py b/mlx_vlm/server.py
index 04c011f6a..6f1a593f8 100644
--- a/mlx_vlm/server.py
+++ b/mlx_vlm/server.py
@@ -586,11 +586,14 @@ class GenerationRequest(VLMRequest):
     )
 
 
-class UsageStats(OpenAIUsage):
+class UsageStats(BaseModel):
     """
-    Inherits from OpenAIUsage and adds additional fields for usage statistics.
+    OpenAI chat-completion usage with extra MLX stats.
     """
 
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
     prompt_tps: float = Field(..., description="Tokens per second for the prompt.")
     generation_tps: float = Field(
         ..., description="Tokens per second for the generation."
@@ -1142,8 +1145,8 @@ async def stream_generator():
 
                     # Yield chunks in Server-Sent Events (SSE) format
                     usage_stats = {
-                        "input_tokens": chunk.prompt_tokens,
-                        "output_tokens": chunk.generation_tokens,
+                        "prompt_tokens": chunk.prompt_tokens,
+                        "completion_tokens": chunk.generation_tokens,
                         "total_tokens": chunk.prompt_tokens
                         + chunk.generation_tokens,
                         "prompt_tps": chunk.prompt_tps,
@@ -1240,8 +1243,8 @@ async def stream_generator():
                 print("Generation finished, cleared cache.")
 
             usage_stats = UsageStats(
-                input_tokens=gen_result.prompt_tokens,
-                output_tokens=gen_result.generation_tokens,
+                prompt_tokens=gen_result.prompt_tokens,
+                completion_tokens=gen_result.generation_tokens,
                 total_tokens=gen_result.total_tokens,
                 prompt_tps=gen_result.prompt_tps,
                 generation_tps=gen_result.generation_tps,
diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py
index 270a82d77..2fdad7be9 100644
--- a/mlx_vlm/tests/test_server.py
+++ b/mlx_vlm/tests/test_server.py
@@ -130,3 +130,46 @@ def test_chat_completions_endpoint_forwards_explicit_sampling_args(client):
     assert mock_generate.call_args.kwargs["repetition_penalty"] == 1.15
     assert mock_generate.call_args.kwargs["logit_bias"] == {12: -1.5}
     assert mock_generate.call_args.kwargs["resize_shape"] == (512, 512)
+
+
+def test_chat_completions_response_uses_openai_usage_field_names(client):
+    """Regression: /chat/completions must return `prompt_tokens` /
+    `completion_tokens` / `total_tokens` per the OpenAI Chat Completions
+    spec, not the `/responses` API's `input_tokens` / `output_tokens`.
+    """
+    model = SimpleNamespace()
+    processor = SimpleNamespace()
+    config = SimpleNamespace(model_type="qwen2_vl")
+    result = SimpleNamespace(
+        text="done",
+        prompt_tokens=8,
+        generation_tokens=4,
+        total_tokens=12,
+        prompt_tps=10.0,
+        generation_tps=5.0,
+        peak_memory=0.1,
+    )
+
+    with (
+        patch.object(
+            server, "get_cached_model", return_value=(model, processor, config)
+        ),
+        patch.object(server, "apply_chat_template", return_value="prompt"),
+        patch.object(server, "generate", return_value=result),
+    ):
+        response = client.post(
+            "/chat/completions",
+            json={
+                "model": "demo",
+                "messages": [{"role": "user", "content": "Hello"}],
+                "max_tokens": 12,
+            },
+        )
+
+    assert response.status_code == 200
+    usage = response.json()["usage"]
+    assert usage["prompt_tokens"] == 8
+    assert usage["completion_tokens"] == 4
+    assert usage["total_tokens"] == 12
+    assert "input_tokens" not in usage
+    assert "output_tokens" not in usage