15 changes: 9 additions & 6 deletions mlx_vlm/server.py
@@ -586,11 +586,14 @@ class GenerationRequest(VLMRequest):
     )
 
 
-class UsageStats(OpenAIUsage):
+class UsageStats(BaseModel):
     """
-    Inherits from OpenAIUsage and adds additional fields for usage statistics.
+    OpenAI chat-completion usage with extra MLX stats.
     """
 
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
     prompt_tps: float = Field(..., description="Tokens per second for the prompt.")
     generation_tps: float = Field(
         ..., description="Tokens per second for the generation."
@@ -1142,8 +1145,8 @@ async def stream_generator():
 
                 # Yield chunks in Server-Sent Events (SSE) format
                 usage_stats = {
-                    "input_tokens": chunk.prompt_tokens,
-                    "output_tokens": chunk.generation_tokens,
+                    "prompt_tokens": chunk.prompt_tokens,
+                    "completion_tokens": chunk.generation_tokens,
                     "total_tokens": chunk.prompt_tokens
                     + chunk.generation_tokens,
                     "prompt_tps": chunk.prompt_tps,
@@ -1240,8 +1243,8 @@ async def stream_generator():
             print("Generation finished, cleared cache.")
 
             usage_stats = UsageStats(
-                input_tokens=gen_result.prompt_tokens,
-                output_tokens=gen_result.generation_tokens,
+                prompt_tokens=gen_result.prompt_tokens,
+                completion_tokens=gen_result.generation_tokens,
                 total_tokens=gen_result.total_tokens,
                 prompt_tps=gen_result.prompt_tps,
                 generation_tps=gen_result.generation_tps,
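With the fields renamed above, a non-streaming /chat/completions response carries the OpenAI-standard usage keys alongside the MLX throughput extras. A minimal client-side sketch of reading them, assuming a locally running server; the host, port, and "demo" model name are illustrative assumptions, not part of this diff:

import requests

resp = requests.post(
    "http://localhost:8080/chat/completions",  # assumed local server address
    json={
        "model": "demo",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 12,
    },
)
usage = resp.json()["usage"]
# OpenAI-spec names, as asserted by the regression test below
print(usage["prompt_tokens"], usage["completion_tokens"], usage["total_tokens"])
# MLX-specific extras are still present
print(usage["prompt_tps"], usage["generation_tps"])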
43 changes: 43 additions & 0 deletions mlx_vlm/tests/test_server.py
@@ -130,3 +130,46 @@ def test_chat_completions_endpoint_forwards_explicit_sampling_args(client):
     assert mock_generate.call_args.kwargs["repetition_penalty"] == 1.15
     assert mock_generate.call_args.kwargs["logit_bias"] == {12: -1.5}
     assert mock_generate.call_args.kwargs["resize_shape"] == (512, 512)
+
+
+def test_chat_completions_response_uses_openai_usage_field_names(client):
+    """Regression: /chat/completions must return `prompt_tokens` /
+    `completion_tokens` / `total_tokens` per the OpenAI Chat Completions
+    spec, not the `/responses` API's `input_tokens` / `output_tokens`.
+    """
+    model = SimpleNamespace()
+    processor = SimpleNamespace()
+    config = SimpleNamespace(model_type="qwen2_vl")
+    result = SimpleNamespace(
+        text="done",
+        prompt_tokens=8,
+        generation_tokens=4,
+        total_tokens=12,
+        prompt_tps=10.0,
+        generation_tps=5.0,
+        peak_memory=0.1,
+    )
+
+    with (
+        patch.object(
+            server, "get_cached_model", return_value=(model, processor, config)
+        ),
+        patch.object(server, "apply_chat_template", return_value="prompt"),
+        patch.object(server, "generate", return_value=result),
+    ):
+        response = client.post(
+            "/chat/completions",
+            json={
+                "model": "demo",
+                "messages": [{"role": "user", "content": "Hello"}],
+                "max_tokens": 12,
+            },
+        )
+
+    assert response.status_code == 200
+    usage = response.json()["usage"]
+    assert usage["prompt_tokens"] == 8
+    assert usage["completion_tokens"] == 4
+    assert usage["total_tokens"] == 12
+    assert "input_tokens" not in usage
+    assert "output_tokens" not in usage