VLM: add support for mllama architecture (Llama-3.2-11B-Vision) #3315
RyanMetcalfeInt8 wants to merge 6 commits into openvinotoolkit:master
Conversation
Pull request overview
Adds initial C++ pipeline support for the mllama VLM architecture (e.g., Llama-3.2-11B-Vision) by introducing mllama-specific vision encoding outputs (cross-KV states) and plumbing additional LM inputs (notably cross_attention_mask) through the decode loop.
Changes:
- Added `VLMModelType::MLLAMA` and integrated mllama into the `VisionEncoder` / `InputsEmbedder` factories.
- Introduced a new `InputsEmbedder::get_language_model_inputs()` API to return named LM inputs (embeds + extra tensors).
- Implemented mllama preprocessing/tiling + cross-attention mask generation, and extended LM encoding to accept/update `cross_attention_mask`.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| src/cpp/src/visual_language/vlm_config.hpp | Adds MLLAMA model type enum value. |
| src/cpp/src/visual_language/vlm_config.cpp | Adds string-to-enum mapping for "mllama". |
| src/cpp/src/visual_language/vision_encoder.hpp | Extends EncodedImage to carry mllama cross-KV states and tile count. |
| src/cpp/src/visual_language/vision_encoder.cpp | Wires up VisionEncoderMLlama factory creation. |
| src/cpp/src/visual_language/processor_config.hpp | Adds mllama processor parameter max_image_tiles. |
| src/cpp/src/visual_language/processor_config.cpp | Reads max_image_tiles from JSON. |
| src/cpp/src/visual_language/pipeline.cpp | Switches VLM pipeline to get_language_model_inputs() and passes cross_attention_mask to LM encoding. |
| src/cpp/src/visual_language/mllama/classes.hpp | Declares mllama-specific VisionEncoder and InputsEmbedder. |
| src/cpp/src/visual_language/mllama/classes.cpp | Implements mllama image tiling, vision encode (cross-KV outputs), and LM inputs assembly (incl. cross_attention_mask). |
| src/cpp/src/visual_language/inputs_embedder.hpp | Adds get_language_model_inputs() to public + interface APIs. |
| src/cpp/src/visual_language/inputs_embedder.cpp | Implements default get_language_model_inputs() and registers InputsEmbedderMLlama. |
| src/cpp/src/lm_encoding.hpp | Extends get_lm_encoded_results signature with cross_attention_mask. |
| src/cpp/src/lm_encoding.cpp | Sets/updates cross_attention_mask during generation iterations. |
| src/cpp/src/llm/pipeline_stateful.cpp | Updates call site for new get_lm_encoded_results signature. |
Comments suppressed due to low confidence (1)
src/cpp/src/visual_language/vision_encoder.hpp:1

`num_tiles` is not default-initialized, so it may contain an indeterminate value for non-mllama paths (or if any code forgets to set it). Initialize it (e.g., `size_t num_tiles = 0;`) to avoid UB.
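A minimal sketch of the fix, assuming a simplified `EncodedImage` layout (the members other than `num_tiles` are placeholders, not the actual struct):

```cpp
#include <vector>
#include <openvino/runtime/tensor.hpp>

// Sketch only: fields besides num_tiles are placeholders for the real EncodedImage members.
struct EncodedImage {
    ov::Tensor resized_source;                       // vision-encoder output (placeholder)
    std::vector<ov::Tensor> mllama_cross_kv_states;  // mllama cross-KV outputs (placeholder name)
    size_t num_tiles = 0;                            // default-initialized so non-mllama paths read a defined value
};
```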
```cpp
                         const ov::AnyMap device_config);

    InputsEmbedderMLlama(const VLMConfig& vlm_config,
                         const ModelsMap& models_map,
                         const Tokenizer& tokenizer,
                         const std::filesystem::path& config_dir_path,
                         const std::string& device,
                         const ov::AnyMap device_config);
```
`ov::AnyMap device_config` is passed by value (and const), which forces a copy. Prefer `const ov::AnyMap& device_config` here (and in the corresponding definitions) to avoid unnecessary copying of potentially large maps.
Suggested change:

```diff
-                         const ov::AnyMap device_config);
+                         const ov::AnyMap& device_config);
     InputsEmbedderMLlama(const VLMConfig& vlm_config,
                          const ModelsMap& models_map,
                          const Tokenizer& tokenizer,
                          const std::filesystem::path& config_dir_path,
                          const std::string& device,
-                         const ov::AnyMap device_config);
+                         const ov::AnyMap& device_config);
```
```cpp
        ov::Coordinate in_coord_begin{b, 0};
        ov::Coordinate in_coord_end{b + 1, num_tokens};
        ov::Tensor input_ids_slice(input_ids, in_coord_begin, in_coord_end);
        cross_attention_token_mask.emplace_back(get_cross_attention_token_mask(input_ids, image_token_id));
```
This loop slices `input_ids` per batch, but then calls `get_cross_attention_token_mask(input_ids, ...)` using the full tensor instead of `input_ids_slice`. Additionally, `get_cross_attention_token_mask` explicitly requires shape `[1, L]` and will throw if batch > 1, making the loop incorrect as written. Use the slice (and ensure it has the expected shape) or assert batch == 1 and remove the loop.
Suggested change:

```diff
-        cross_attention_token_mask.emplace_back(get_cross_attention_token_mask(input_ids, image_token_id));
+        cross_attention_token_mask.emplace_back(get_cross_attention_token_mask(input_ids_slice, image_token_id));
```
```cpp
    //TODO: replace '4' here with max_tiles from vision preprocessor config.
    auto cross_attention_mask =
        convert_sparse_cross_attention_mask_to_dense(cross_attention_token_mask, num_tiles, 4, num_tokens);
```
The hardcoded `4` is a magic number that can go out of sync with `ProcessorConfig::max_image_tiles` (and you already have a TODO noting it). Thread `max_image_tiles` through to this call (e.g., from the processor config / model config) so the mask shape matches the actual preprocessing configuration.
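A minimal sketch of what that could look like, assuming the embedder keeps its parsed config in a member such as `m_processor_config` (name assumed):

```cpp
    // Sketch only: m_processor_config is an assumed member holding the parsed ProcessorConfig.
    const size_t max_image_tiles = m_processor_config.max_image_tiles;
    auto cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
        cross_attention_token_mask, num_tiles, max_image_tiles, num_tokens);
```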
```cpp
    //TODO: We can probably just adapt update_attention_mask_with_beams to support both
    // attention_mask & cross_attention_mask, as I think it also needs to be a function of next_beams (?)
```
`attention_mask` is updated with `next_beams`, but `cross_attention_mask` is not beam-reordered; it is only extended by 1 token. For beam search (or any multi-sequence scenario), this will desynchronize `cross_attention_mask` from the selected beams and can produce incorrect cross-attention behavior. The fix should reorder/expand `cross_attention_mask` in the same way `update_attention_mask_with_beams` does (generalized to 4D).
Suggested change:

```diff
-    //TODO: We can probably just adapt update_attention_mask_with_beams to support both
-    // attention_mask & cross_attention_mask, as I think it also needs to be a function of next_beams (?)
+    // Reorder cross_attention_mask to follow the selected beams in the batch dimension.
+    {
+        const ov::Shape original_shape = cross_attention_mask->get_shape();
+        const size_t original_batch = original_shape.at(0);
+        const size_t new_batch = next_beams.size();
+        OPENVINO_ASSERT(original_batch >= new_batch,
+                        "cross_attention_mask batch size is smaller than number of beams.");
+        ov::Shape reordered_shape = original_shape;
+        reordered_shape[0] = new_batch;
+        ov::Tensor reordered_mask(cross_attention_mask->get_element_type(), reordered_shape);
+        const size_t bytes_per_batch = cross_attention_mask->get_byte_size() / original_batch;
+        const auto* src = static_cast<const unsigned char*>(cross_attention_mask->data());
+        auto* dst = static_cast<unsigned char*>(reordered_mask.data());
+        for (size_t i = 0; i < new_batch; ++i) {
+            const size_t src_index = static_cast<size_t>(next_beams[i]);
+            OPENVINO_ASSERT(src_index < original_batch,
+                            "Beam index is out of range for cross_attention_mask batch size.");
+            const size_t src_offset = src_index * bytes_per_batch;
+            const size_t dst_offset = i * bytes_per_batch;
+            std::copy(src + src_offset, src + src_offset + bytes_per_batch, dst + dst_offset);
+        }
+        *cross_attention_mask = reordered_mask;
+    }
```
Description
This PR adds support for the mllama VLM architecture -- e.g. https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct
As there is no support in optimum-intel yet for this model, and there is still quite a lot of work to do, I am marking this as a draft PR for now.
I am testing with Llama-3.2-11B-Vision-Instruct models that have been exported with this fork / branch: https://github.com/RyanMetcalfeInt8/optimum-intel/tree/llama3.2_11b_wip_trans_4.55.4_1.27.0
It's a branch based off of optimum-intel 1.27.0, but adds some initial support for mllama.
Unlike other VLM architectures that this project currently supports, mllama doesn't merge text & vision embeddings into a consolidated `inputs_embeds` tensor. Instead, the outputs of the vision encoder are several cross-KV states that need to be set as inputs to the language model:



It additionally requires a `cross_attention_mask` to be set on the language model, and managed during the decode loop in a similar way as the self-attention `attention_mask` tensor.

To account for the extra tensors that need to be created by the mllama implementation, and ultimately passed as inputs to the language model, a new `get_language_model_inputs` API was added to `InputsEmbedder` / `IInputsEmbedder`. `get_language_model_inputs` returns a vector of named tensors -- `std::vector<std::pair<std::string, ov::Tensor>>`.

The VLM pipeline makes this call now:
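Roughly, the call and how its result is consumed look like the sketch below (member names such as `m_inputs_embedder` / `m_language` and the exact argument list are assumptions for illustration, not the actual code in this PR):

```cpp
// Illustrative sketch only -- signatures and member names are assumed.
std::vector<std::pair<std::string, ov::Tensor>> lm_inputs =
    m_inputs_embedder->get_language_model_inputs(prompt, images);

ov::Tensor input_embeds, cross_attention_mask;
for (auto& [name, tensor] : lm_inputs) {
    if (name == "input_embeds") {
        input_embeds = tensor;                // managed by lm_encoding during the decode loop
    } else if (name == "cross_attention_mask") {
        cross_attention_mask = tensor;        // also updated on every generation iteration
    } else {
        m_language.set_tensor(name, tensor);  // remaining inputs are set once on the LM request
    }
}
```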
It then grabs the 'special' tensors (`input_embeds`, `token_type_ids`, `cross_attention_mask`) that specifically need to be passed on to / managed by lm_encoder. The other returned inputs are just set directly (as it's assumed they don't change during generation).

Happy to receive any early feedback on this mechanism.
Other than that, there is quite a lot of work to do:
CVS-173891
Checklist: