Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 108 additions & 26 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ class GenerationResult:
prompt_tokens: int = 0
generation_tokens: int = 0
total_tokens: int = 0
cached_tokens: int = 0
prompt_tps: float = 0.0
generation_tps: float = 0.0
peak_memory: float = 0.0
Expand All @@ -354,6 +355,13 @@ class PromptCacheState:
def __init__(self):
    """Create an empty prompt-cache slot.

    Attributes:
        cache: Saved KV cache (list of per-layer cache objects), or None
            when no state has been stored yet.
        token_ids: Token sequence the cache was built from, or None.
        last_used: Timestamp of the most recent cache access.
        created_at: Timestamp when this slot was created.
    """
    self.cache: Optional[List[Any]] = None
    self.token_ids: Optional[List[int]] = None
    # Capture a single timestamp so created_at can never be later than
    # last_used (the original made two separate time.time() calls, which
    # left created_at microseconds *after* last_used).
    now = time.time()
    self.last_used: float = now
    self.created_at: float = now

@property
def token_count(self) -> int:
    """How many tokens are currently stored in this cache slot (0 if empty)."""
    ids = self.token_ids
    if not ids:
        return 0
    return len(ids)

def find_prefix_length(self, new_ids: list) -> int:
"""Return the number of leading tokens that match the cached ids."""
Expand All @@ -365,10 +373,20 @@ def find_prefix_length(self, new_ids: list) -> int:
return i
return max_len

def touch(self):
    """Refresh the last-used timestamp to the current wall-clock time."""
    now = time.time()
    self.last_used = now

def update(self, token_ids: list, kv_cache: list):
    """Snapshot the full token sequence (copied) and its KV cache.

    Also refreshes the last-used timestamp so eviction policies see this
    slot as recently active.
    """
    self.cache = kv_cache
    self.token_ids = list(token_ids)
    self.last_used = time.time()

def invalidate(self):
    """Drop all cached state so the next turn performs a full prefill."""
    self.token_ids = None
    self.cache = None


def generate_step(
Expand Down Expand Up @@ -668,33 +686,63 @@ def stream_generate(
reused_prefix_len = 0
full_input_ids_list = input_ids.flatten().tolist()

# Save original input_ids for fallback if cache reuse fails
_original_input_ids = input_ids
_original_pixel_values = pixel_values
_original_kwargs = {k: v for k, v in kwargs.items() if k == "cached_image_features"}

if prompt_cache_state is not None and prompt_cache_state.cache is not None:
prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list)
if prefix_len > 0 and prefix_len < input_ids.shape[1]:
reused_prefix_len = prefix_len
# Trim to only new tokens
input_ids = input_ids[:, prefix_len:]
# Only skip vision if no image tokens in the new (trimmed) tokens
image_token_id = getattr(model.config, "image_token_id", None) or getattr(
model.config, "image_token_index", None
)
new_ids = input_ids.flatten().tolist()
has_image_in_new = image_token_id is not None and image_token_id in new_ids
if not has_image_in_new:
pixel_values = None
kwargs.pop("cached_image_features", None)
# Reuse the saved KV cache (trimmed to prefix length)
kv_cache = prompt_cache_state.cache
# Trim cache to prefix_len in case it includes generated tokens
for c in kv_cache:
if hasattr(c, "keys") and c.keys is not None:
cached_len = c.keys.shape[2]
if cached_len > prefix_len:
c.keys = c.keys[:, :, :prefix_len, :]
c.values = c.values[:, :, :prefix_len, :]
if hasattr(c, "offset"):
try:
prefix_len = prompt_cache_state.find_prefix_length(full_input_ids_list)
cached_total = len(prompt_cache_state.token_ids) if prompt_cache_state.token_ids else 0
# Only reuse if a substantial prefix matches (>= 50% of cached tokens).
# Short matches on quantized KV caches (TurboQuant) can produce
# corrupted output because trim() only adjusts the offset without
# clearing stale quantized data.
min_reuse = max(512, cached_total // 2)
if prefix_len >= min_reuse and prefix_len < input_ids.shape[1]:
reused_prefix_len = prefix_len
# Trim to only new tokens
input_ids = input_ids[:, prefix_len:]
# Only skip vision if no image tokens in the new (trimmed) tokens
image_token_id = getattr(model.config, "image_token_id", None) or getattr(
model.config, "image_token_index", None
)
new_ids = input_ids.flatten().tolist()
has_image_in_new = image_token_id is not None and image_token_id in new_ids
if not has_image_in_new:
pixel_values = None
kwargs.pop("cached_image_features", None)
# Reuse the saved KV cache (trimmed to prefix length).
# Works with both standard KVCache (mx.array keys) and
# quantized caches (TurboQuant) via their trim() method.
kv_cache = prompt_cache_state.cache
for c in kv_cache:
if hasattr(c, "offset") and c.offset > prefix_len:
trim_amount = c.offset - prefix_len
if hasattr(c, "trim") and callable(c.trim):
c.trim(trim_amount)
elif hasattr(c, "keys") and c.keys is not None:
keys = c.keys
if hasattr(keys, "shape") and len(keys.shape) >= 3:
c.keys = keys[:, :, :prefix_len, :]
c.values = c.values[:, :, :prefix_len, :]
c.offset = prefix_len
elif hasattr(c, "offset") and c.offset > prefix_len:
# Quantized cache: just update offset if possible
c.offset = prefix_len
kwargs["prompt_cache"] = kv_cache
kwargs["prompt_cache"] = kv_cache
except Exception as e:
# Cache reuse failed (e.g., shape mismatch, stale KV state).
# Invalidate the cache and fall back to a fresh generation.
print(f"[prompt_cache] Cache reuse failed, invalidating: {e}")
prompt_cache_state.invalidate()
reused_prefix_len = 0
input_ids = _original_input_ids
pixel_values = _original_pixel_values
if "cached_image_features" in _original_kwargs:
kwargs["cached_image_features"] = _original_kwargs["cached_image_features"]
kwargs.pop("prompt_cache", None)

if thinking_budget is not None:
thinking_start_token_id = tokenizer.encode(
Expand All @@ -720,6 +768,34 @@ def stream_generate(
model.language_model,
max_kv_size=kwargs.get("max_kv_size", None),
)

# Validate cache shapes before generation. If the cached KV state has
# inconsistent shapes (e.g., stale after model reload), discard it and
# build a fresh cache to avoid broadcast_shapes errors during generation.
if reused_prefix_len > 0:
try:
for c in kwargs["prompt_cache"]:
if hasattr(c, "keys") and c.keys is not None:
expected_seq = reused_prefix_len
actual_seq = c.keys.shape[2] if len(c.keys.shape) >= 3 else c.offset
if actual_seq != expected_seq:
raise ValueError(
f"Cache shape mismatch: expected seq={expected_seq}, got {actual_seq}"
)
except (ValueError, IndexError, AttributeError) as e:
print(f"[prompt_cache] Cache validation failed, rebuilding: {e}")
if prompt_cache_state is not None:
prompt_cache_state.invalidate()
reused_prefix_len = 0
input_ids = _original_input_ids
pixel_values = _original_pixel_values
if "cached_image_features" in _original_kwargs:
kwargs["cached_image_features"] = _original_kwargs["cached_image_features"]
kwargs["prompt_cache"] = cache.make_prompt_cache(
model.language_model,
max_kv_size=kwargs.get("max_kv_size", None),
)

tracked_cache = kwargs["prompt_cache"]

total_prompt_tokens = reused_prefix_len + input_ids.size
Expand Down Expand Up @@ -758,6 +834,7 @@ def stream_generate(
prompt_tokens=total_prompt_tokens,
generation_tokens=n + 1,
total_tokens=total_prompt_tokens + n + 1,
cached_tokens=reused_prefix_len,
prompt_tps=prompt_tps,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
Expand All @@ -777,7 +854,12 @@ def stream_generate(
)

# Save cache state for potential reuse on next turn
if prompt_cache_state is not None:
# Save cache state for potential reuse on next turn.
# Only save if the prompt was substantial (>= 1024 tokens) to avoid
# polluting the cache with short probe/capability-check requests that
# some agent frameworks send before the real request.
_MIN_CACHE_TOKENS = 1024
if prompt_cache_state is not None and len(full_input_ids_list) >= _MIN_CACHE_TOKENS:
all_ids = full_input_ids_list + [
t.item() if hasattr(t, "item") else t for t in generated_tokens
]
Expand Down
Loading
Loading