[WiP] Support Linear State in SDPA Pipeline#3359
apaniukov wants to merge 4 commits into openvinotoolkit:master
Conversation
Pull request overview
This PR generalizes the “KV cache state” tracking to support fixed-size linear (and hybrid) cache state in stateful/SDPA-based pipelines by introducing a unified cache state type and propagating it through LLM/VLM/speculative decoding codepaths.
Changes:
- Replaced `KVCacheState` with `CacheState` across pipelines and embedders.
- Added cache kind detection (`CacheTypes`/`get_cache_types`) and updated cache-trimming behavior to reset for linear caches.
- Wired cache-kind awareness into speculative decoding wrappers and stateful LLM pipeline initialization.
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| src/cpp/src/visual_language/vision_token_pruning_processor.hpp | Updates pruning processor API to use CacheState. |
| src/cpp/src/visual_language/vision_token_pruning_processor.cpp | Updates pruning processor implementation signature to CacheState. |
| src/cpp/src/visual_language/pipeline.cpp | VLM pipeline now uses CacheState when managing chat history/cache trimming. |
| src/cpp/src/visual_language/phi4mm/classes.cpp | Switches embedder history/cache bookkeeping to m_cache_state. |
| src/cpp/src/visual_language/phi3_vision/classes.cpp | Switches embedder history/cache bookkeeping to m_cache_state. |
| src/cpp/src/visual_language/inputs_embedder.hpp | Replaces stored state from KVCacheState to CacheState. |
| src/cpp/src/visual_language/inputs_embedder.cpp | Updates chat/history alignment and rollback bookkeeping to CacheState. |
| src/cpp/src/utils.hpp | Introduces CacheTypes, CacheState, and get_cache_types() API. |
| src/cpp/src/utils.cpp | Implements cache kind detection and updates trim_kv_cache() behavior for linear caches. |
| src/cpp/src/speculative_decoding/stateful/fast_draft_strategy.hpp | Adds CacheTypes member to infer wrapper. |
| src/cpp/src/speculative_decoding/stateful/fast_draft_strategy.cpp | Initializes CacheTypes and uses it to build CacheState for trimming. |
| src/cpp/src/speculative_decoding/stateful/eagle3_strategy.hpp | Adds CacheTypes member to eagle3 infer wrapper base. |
| src/cpp/src/speculative_decoding/stateful/eagle3_strategy.cpp | Initializes CacheTypes and uses it to build CacheState for trimming. |
| src/cpp/src/lm_encoding.hpp | Updates encoding helpers to accept CacheState. |
| src/cpp/src/lm_encoding.cpp | Updates chat-history alignment logic and cache-state updates for CacheState. |
| src/cpp/src/llm/pipeline_stateful.hpp | Renames stored cache reflection to m_cache_state and renames reset helper. |
| src/cpp/src/llm/pipeline_stateful.cpp | Initializes CacheState from model and propagates it through chat/trim logic. |
Comments suppressed due to low confidence (1)
src/cpp/src/utils.cpp:525
- trim_kv_cache() resets the InferRequest when reset_mem_state is set (or when linear cache needs reset), but it returns without clearing cache_state.reset_mem_state / num_tokens_to_trim or updating the token reflection state. This can leave CacheState inconsistent (stale tokens / repeated resets) for subsequent steps. Consider resetting the CacheState fields when a reset happens (and clearing the token reflection if the underlying model state is cleared).
```cpp
void trim_kv_cache(ov::InferRequest request, CacheState& cache_state, std::optional<AdapterController> adapter_controller) {
    if (
        cache_state.reset_mem_state
        // linear cache stores only the last state, trimming is not possible, so we reset the whole cache in this case
        || (cache_state.num_tokens_to_trim > 0 && cache_state.has_linear())
    ) {
        if (adapter_controller) {
            for (auto& state : request.query_state()) {
                if (!adapter_controller->has_state_name(state.get_name())) {
                    state.reset();
                }
            }
        } else {
            request.reset_state();
        }
        return;
    }
}
```
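The suppressed comment above asks for the bookkeeping to be cleared when the reset path fires. A minimal sketch of that idea, assuming `CacheState` exposes the token reflection via `get_state()` as used elsewhere in this PR:

```cpp
// Sketch only: end of the reset branch above, with CacheState cleared so
// later steps don't see stale tokens or trigger repeated resets.
request.reset_state();                // drop the model-side cache
cache_state.get_state().clear();      // drop the token reflection
cache_state.num_tokens_to_trim = 0;   // nothing left to trim
cache_state.reset_mem_state = false;  // the reset has been performed
return;
```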
```cpp
if (rank == 4 && dynamic_axis_count == 2) {
```
get_cache_types() classifies KV-cache as rank==4 with dynamic_axis_count==2, but the code comment/example shape is [-1,4,0,64] (only one dynamic dim). This heuristic likely fails to detect KV-cache (and hybrid) correctly. Consider aligning the detection with get_kv_axes_pos() logic (e.g., rank==4 plus presence of the 0 "seq_len" marker) or otherwise document/handle the expected shape patterns.
Suggested change:

```diff
-if (rank == 4 && dynamic_axis_count == 2) {
+const bool is_rank4 = (rank == 4);
+// For KV-cache, expect 4D state with "seq_len" axis at index 2.
+// This axis can be marked by length 0 (e.g. [-1,4,0,64]) or be dynamic.
+const bool has_seq_len_marker =
+    is_rank4 &&
+    !shape[2].is_dynamic() &&
+    shape[2].get_length() == 0;
+const bool has_dynamic_seq_len =
+    is_rank4 &&
+    shape[2].is_dynamic();
+if (is_rank4 && (has_seq_len_marker || has_dynamic_seq_len)) {
```
```diff
 cache_state.num_tokens_to_trim += state.size() - first_diverse_tokens_idx;
 state.resize(first_diverse_tokens_idx);
-kv_cache_state.reset_mem_state = state.empty();
+cache_state.reset_mem_state = state.empty() || cache_state.has_linear();
```
align_kv_cache_and_history(): setting reset_mem_state to true whenever has_linear() makes linear-cache models reset their state on every chat iteration, which breaks incremental generation (get_chat_encoded_input will still send only the delta). reset_mem_state for linear caches should typically be triggered only when trimming would be required (num_tokens_to_trim>0); and when a reset is needed, the token-reflection state should be cleared so the next prompt replays the full history.
Suggested change:

```diff
-cache_state.reset_mem_state = state.empty() || cache_state.has_linear();
+const bool need_trim = cache_state.num_tokens_to_trim > 0;
+if (cache_state.has_linear()) {
+    // For linear caches, reset memory state only when trimming is required.
+    cache_state.reset_mem_state = need_trim;
+} else {
+    // For non-linear caches, preserve previous behavior: reset only when state becomes empty.
+    cache_state.reset_mem_state = state.empty();
+}
```
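The second half of the comment, clearing the token reflection when a reset fires, is not covered by the suggestion; a hedged follow-up using the names from this hunk might look like:

```cpp
// Sketch: when a linear cache must be reset, also drop the token reflection so
// get_chat_encoded_input() replays the full history on the next turn.
if (cache_state.has_linear() && cache_state.reset_mem_state) {
    state.clear();
    cache_state.num_tokens_to_trim = 0;
}
```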
```diff
 void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
-    m_kv_cache_state.num_tokens_to_trim = 0;
+    m_cache_state.num_tokens_to_trim = 0;
     if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
         // If chat generation process was cancelled by user, let's rollback to previous state of kv cache
-        std::vector<int64_t>& state = m_kv_cache_state.get_state();
+        std::vector<int64_t>& state = m_cache_state.get_state();

-        m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
+        m_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
         state.resize(m_prev_hist_length);
-        m_kv_cache_state.reset_mem_state = state.empty();
+        m_cache_state.reset_mem_state = state.empty() || m_cache_state.has_linear();
     }
```
update_chat_history() cancel path: for linear-cache models, resetting the model state is necessary (trimming isn't possible), but keeping a non-empty token reflection state will cause the next get_chat_encoded_input() to send only a delta even though the model cache was reset. When linear reset is required, clear the CacheState token history (and reset trim counters) so the next iteration replays the full templated history.
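One possible shape for the cancel path, sketched with the members visible in this hunk (not a verified implementation):

```cpp
if (m_cache_state.has_linear()) {
    // Linear cache cannot be trimmed: reset the model state and clear the
    // token reflection so the next prompt replays the full templated history.
    m_cache_state.reset_mem_state = true;
    m_cache_state.num_tokens_to_trim = 0;
    state.clear();
} else {
    m_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
    state.resize(m_prev_hist_length);
    m_cache_state.reset_mem_state = state.empty();
}
```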
```diff
 // reflection of tokens contained in the kv cache
-utils::KVCacheState m_kv_cache_state;
+utils::CacheState m_cache_state;
 // length of attention_mask/kv cache at the beginning of generation()
```
m_cache_state is default-constructed, so its cache type mask stays at 0 (has_linear()/has_kvcache() are false). For linear/hybrid attention models this prevents the new reset-vs-trim logic from working and can make trim_kv_cache slice tensors with the wrong rank. CacheState needs to be initialized with the language model (or have cache types injected) before it is used.
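A minimal sketch of the injection idea; `set_cache_types()` is hypothetical and only illustrates one way to populate the mask after construction:

```cpp
// Hypothetical: let CacheState carry its cache kinds, injected once the
// language model is available (set_cache_types is not part of this PR).
struct CacheState {
    // ... existing members: num_tokens_to_trim, reset_mem_state, get_state() ...
    void set_cache_types(CacheTypes types) { m_cache_types = types; }
    bool has_linear() const { return m_cache_types.has_linear(); }
private:
    CacheTypes m_cache_types;  // default-constructed mask stays 0 until injected
};

// e.g. during embedder construction:
// m_cache_state.set_cache_types(utils::get_cache_types(language_model));
```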
```diff
 m_use_full_chat_history = m_is_npu;

-utils::KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+utils::CacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
```
VLMPipeline initializes kv_cache_state.seq_length_axis but never initializes CacheState's cache types from the loaded language_model. That leaves has_linear()/has_kvcache() false and can break linear/hybrid-cache behavior (including potential invalid trimming in trim_kv_cache). Consider initializing the embedder's CacheState using the language_model (e.g., assign CacheState(language_model) or add a setter for cache types) during pipeline construction.
Suggested change:

```diff
 utils::CacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+kv_cache_state = utils::CacheState(language_model);
```
```cpp
    uint8_t mask = 0;
};

CacheTypes get_cache_types(std::shared_ptr<const ov::Model> model);
```
get_cache_types() takes the std::shared_ptr by value, which adds an unnecessary refcount inc/dec and makes it easier to accidentally pass nullptr without checks. Prefer taking it as a const reference (const std::shared_ptr&) here and in the definition for consistency with other helpers like get_kv_axes_pos().
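The const-reference form the reviewer suggests would read:

```cpp
// Matches get_kv_axes_pos() style and avoids a refcount bump per call.
CacheTypes get_cache_types(const std::shared_ptr<const ov::Model>& model);
```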
```diff
@@ -451,6 +486,10 @@ KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model) {

         // Shape example: [-1,4,0,64]
         auto shape = op->get_input_partial_shape(0);
+        if (shape.rank().get_length() != 4) {
+            // kv cache should have 4 dimensions
+            continue;
+        }

         for (size_t i = 0; i < shape.rank().get_length(); i++) {
             // Find axis = 0. This would be sequence length axis.
@@ -467,8 +506,12 @@ KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model) {
     return kv_pos;
 }

-void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller) {
-    if (kv_cache_state.reset_mem_state) {
+void trim_kv_cache(ov::InferRequest request, CacheState& cache_state, std::optional<AdapterController> adapter_controller) {
+    if (
+        cache_state.reset_mem_state
+        // linear cache stores only the last state, trimming is not possible, so we reset the whole cache in this case
+        || (cache_state.num_tokens_to_trim > 0 && cache_state.has_linear())
+    ) {
```
This change introduces new behavior for linear/hybrid cache handling (cache type detection + reset vs trim semantics) but there are no accompanying regression tests. Please add tests that exercise chat continuation and cancellation/rollback for linear-only and hybrid models to ensure state resets trigger a full replay and don't cause repeated resets or invalid trimming.
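A sketch of what that coverage could look like; the test names and steps below are placeholders, not the project's actual test fixtures:

```cpp
#include <gtest/gtest.h>

TEST(LinearCacheChat, SecondTurnSendsOnlyDelta) {
    // 1. Build a stateful pipeline on a linear-only test model.
    // 2. Run two chat turns and assert the second turn feeds only the new
    //    tokens, with no InferRequest::reset_state() between turns.
}

TEST(LinearCacheChat, CancelResetsAndReplaysFullHistory) {
    // 1. Cancel a generation mid-turn (GenerationStatus::CANCEL).
    // 2. Assert the model state was reset, the token reflection was cleared,
    //    and the next prompt replays the full templated history exactly once.
}
```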
|
I converted the PR to draft as it is labeled as WIP.
```cpp
CacheTypes get_cache_types(std::shared_ptr<const ov::Model> model) {
    // "ReadValue" node is cache representation in stateful model
    const std::string state_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name);
    CacheTypes cache_types;

    for (const auto op : model->get_ops()) {
        // check input size, as in LoRA adapters case it could be 0
        if (op->get_type_name() != state_node_type_name || op->get_input_size() < 1) {
            continue;
        }

        // Shape example: [-1,4,0,64]
        auto shape = op->get_input_partial_shape(0);
        const auto rank = shape.rank().get_length();
        size_t dynamic_axis_count = 0;
        for (size_t i = 0; i < rank; i++) {
            if (shape[i].is_dynamic()) {
                dynamic_axis_count++;
            }
        }

        if (rank == 4 && dynamic_axis_count == 2) {
            cache_types.add_kvcache();
        } else if (rank == 3 && dynamic_axis_count == 1) {
            cache_types.add_linear();
        } else {
            continue;
        }
```
get_cache_types() counts only Dimension::is_dynamic() axes, but elsewhere (get_kv_axes_pos) the sequence axis is detected via shape[i] == 0 (see the "Shape example: [-1,4,0,64]" comment). For models where seq_len is represented as a static 0 dimension, dynamic_axis_count will be 1 and KV-cache will never be detected, leaving CacheTypes empty/incorrect. Update the detection logic to also treat a zero-length seq dimension as the KV-cache pattern (or otherwise make the heuristic consistent with get_kv_axes_pos).
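One way to make the heuristic consistent, sketched against the loop above (treating a static zero-length dimension as an additional seq_len marker is an assumption about exported shapes):

```cpp
// Count axes that can mark seq_len: dynamic dims, or the static-0 marker that
// get_kv_axes_pos() relies on ("Shape example: [-1,4,0,64]").
size_t marker_axis_count = 0;
for (size_t i = 0; i < rank; i++) {
    if (shape[i].is_dynamic() ||
        (shape[i].is_static() && shape[i].get_length() == 0)) {
        marker_axis_count++;
    }
}
if (rank == 4 && marker_axis_count == 2) {
    cache_types.add_kvcache();
} else if (rank == 3 && marker_axis_count == 1) {
    cache_types.add_linear();
}
```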
```diff
 // in the case of beam_search the longest answer is in the kv cache, but the best one is needed
-// so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated serquence
+// so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated sequence
 kv_cache_state.num_tokens_to_trim += state.size() - first_diverse_tokens_idx;
```
The comment still refers to KVCacheState even though the type has been renamed to CacheState. Please update the wording to match the new type name for clarity.
Suggested change:

```diff
-// so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated sequence
+// so generated tokens were not added to CacheState and num_tokens_to_trim was set to the size of the generated sequence
```
```cpp
struct CacheTypes {
    CacheTypes() = default;
    explicit CacheTypes(uint8_t m) : mask(m) {}
    void add_kvcache() { mask |= (1u << 0); }
    void add_linear() { mask |= (1u << 1); }
    bool has_kvcache() const { return (mask & (1u << 0)) != 0; }
    bool has_linear() const { return (mask & (1u << 1)) != 0; }
    bool is_hybrid() const { return has_kvcache() && has_linear(); }
    uint8_t value() const { return mask; }
private:
    uint8_t mask = 0;
};

CacheTypes get_cache_types(std::shared_ptr<const ov::Model> model);
```
CacheTypes / get_cache_types() introduces new cache-kind detection logic that is central to linear/hybrid attention support, but there are no unit tests covering the detection heuristics (e.g., models with ReadValue shapes like [-1,4,0,64] for KV or [-1,0,64] for linear). Adding focused tests would help prevent regressions when model export patterns change.
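Pending model-level tests, the bitmask semantics themselves are cheap to pin down; a minimal sketch:

```cpp
#include <gtest/gtest.h>

TEST(CacheTypes, MaskSemantics) {
    CacheTypes types;
    EXPECT_FALSE(types.has_kvcache());
    EXPECT_FALSE(types.has_linear());
    types.add_kvcache();
    EXPECT_TRUE(types.has_kvcache());
    EXPECT_FALSE(types.is_hybrid());
    types.add_linear();
    EXPECT_TRUE(types.is_hybrid());
    EXPECT_EQ(types.value(), 0b11);
}
```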
Description
Support fixed-size cache state for linear/hybrid attention models.
CVS-181414
Checklist: