Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up
@@ -394,6 +394,7 @@ def generate_step(
sampler: Optional[Callable[[mx.array], mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prefill_step_size: Optional[int] = DEFAULT_PREFILL_STEP_SIZE,
+ verbose: bool = False,
**kwargs,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
Expand Down
Expand Up
@@ -529,7 +530,7 @@ def _step(y, inputs_embeds=None):
if prefill_step_size is not None and inputs_embeds.shape[1] > prefill_step_size:
# Chunked prefill with embeddings
total_tokens = inputs_embeds.shape[1]
- with tqdm(total=total_tokens, desc="Prefill", unit="tok") as pbar:
+ with tqdm(total=total_tokens, desc="Prefill", unit="tok", disable=not verbose) as pbar:
while inputs_embeds.shape[1] > 1:
n_to_process = min(prefill_step_size, inputs_embeds.shape[1] - 1)
model.language_model(
Expand Down
Expand Up
@@ -600,6 +601,7 @@ def stream_generate(
containing the generated text, tokens, and statistics.
"""
tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
+ verbose = kwargs.pop("verbose", False)

# Set up thinking budget criteria if requested
thinking_budget = kwargs.pop("thinking_budget", None)
Expand Down
Expand Up
@@ -728,7 +730,7 @@ def stream_generate(
detokenizer = processor.detokenizer
detokenizer.reset()
thinking_criteria = getattr(tokenizer, "thinking_budget_criteria", None)
- gen = generate_step(input_ids, model, pixel_values, mask, **kwargs)
+ gen = generate_step(input_ids, model, pixel_values, mask, verbose=verbose, **kwargs)
tic = time.perf_counter()

generated_tokens = []
Expand Down
Expand Up
@@ -853,7 +855,7 @@ def generate(
else:
tokenizer.stopping_criteria.reset(model.config.eos_token_id)

- for response in stream_generate(model, processor, prompt, image, audio, **kwargs):
+ for response in stream_generate(model, processor, prompt, image, audio, verbose=verbose, **kwargs):
if verbose:
print(response.text, end="", flush=True)
text += response.text
Expand Down