
[WiP] Support Linear State in SDPA Pipeline #3359

Draft

apaniukov wants to merge 4 commits into openvinotoolkit:master from apaniukov:lfm2-stateful-model

Conversation

@apaniukov
Contributor

Description

Support fixed-size cache state for linear/hybrid attention models.

CVS-181414

Checklist:

  • This PR follows GenAI Contributing guidelines.
  • Tests have been updated or added to cover the new code.
  • This PR fully addresses the ticket.
  • I have made corresponding changes to the documentation.

Copilot AI review requested due to automatic review settings February 19, 2026 12:36
@apaniukov apaniukov changed the title Support Linear State in SDPA Pipeline [WiP] Support Linear State in SDPA Pipeline Feb 19, 2026
@github-actions github-actions bot added the labels category: visual language (Visual language pipeline), category: LLM (LLM pipeline (stateful, static)), category: speculative decoding (Speculative decoding), and no-match-files on Feb 19, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR generalizes the “KV cache state” tracking to support fixed-size linear (and hybrid) cache state in stateful/SDPA-based pipelines by introducing a unified cache state type and propagating it through LLM/VLM/speculative decoding codepaths.

Changes:

  • Replaced KVCacheState with CacheState across pipelines and embedders.
  • Added cache kind detection (CacheTypes / get_cache_types) and updated cache-trimming behavior to reset for linear caches (sketched right after this list).
  • Wired cache-kind awareness into speculative decoding wrappers and stateful LLM pipeline initialization.
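For orientation before the per-file breakdown, the reset-vs-trim decision this introduces, taken from the trim_kv_cache() diff reviewed further down and shown here with a simplified, hypothetical call site, is roughly:

    // Linear caches keep only the last state and cannot be trimmed token-by-token,
    // so any pending trim forces a full reset instead.
    if (cache_state.reset_mem_state ||
        (cache_state.num_tokens_to_trim > 0 && cache_state.has_linear())) {
        request.reset_state();  // drop the whole cache; the next prompt replays the history
    } else if (cache_state.num_tokens_to_trim > 0) {
        // KV cache only: slice off the trailing tokens along the seq_len axis, as before
    }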

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 7 comments.

Show a summary per file
File | Description
src/cpp/src/visual_language/vision_token_pruning_processor.hpp | Updates pruning processor API to use CacheState.
src/cpp/src/visual_language/vision_token_pruning_processor.cpp | Updates pruning processor implementation signature to CacheState.
src/cpp/src/visual_language/pipeline.cpp | VLM pipeline now uses CacheState when managing chat history/cache trimming.
src/cpp/src/visual_language/phi4mm/classes.cpp | Switches embedder history/cache bookkeeping to m_cache_state.
src/cpp/src/visual_language/phi3_vision/classes.cpp | Switches embedder history/cache bookkeeping to m_cache_state.
src/cpp/src/visual_language/inputs_embedder.hpp | Replaces stored state from KVCacheState to CacheState.
src/cpp/src/visual_language/inputs_embedder.cpp | Updates chat/history alignment and rollback bookkeeping to CacheState.
src/cpp/src/utils.hpp | Introduces CacheTypes, CacheState, and get_cache_types() API.
src/cpp/src/utils.cpp | Implements cache kind detection and updates trim_kv_cache() behavior for linear caches.
src/cpp/src/speculative_decoding/stateful/fast_draft_strategy.hpp | Adds CacheTypes member to infer wrapper.
src/cpp/src/speculative_decoding/stateful/fast_draft_strategy.cpp | Initializes CacheTypes and uses it to build CacheState for trimming.
src/cpp/src/speculative_decoding/stateful/eagle3_strategy.hpp | Adds CacheTypes member to eagle3 infer wrapper base.
src/cpp/src/speculative_decoding/stateful/eagle3_strategy.cpp | Initializes CacheTypes and uses it to build CacheState for trimming.
src/cpp/src/lm_encoding.hpp | Updates encoding helpers to accept CacheState.
src/cpp/src/lm_encoding.cpp | Updates chat-history alignment logic and cache-state updates for CacheState.
src/cpp/src/llm/pipeline_stateful.hpp | Renames stored cache reflection to m_cache_state and renames reset helper.
src/cpp/src/llm/pipeline_stateful.cpp | Initializes CacheState from model and propagates it through chat/trim logic.
Comments suppressed due to low confidence (1)

src/cpp/src/utils.cpp:525

  • trim_kv_cache() resets the InferRequest when reset_mem_state is set (or when linear cache needs reset), but it returns without clearing cache_state.reset_mem_state / num_tokens_to_trim or updating the token reflection state. This can leave CacheState inconsistent (stale tokens / repeated resets) for subsequent steps. Consider resetting the CacheState fields when a reset happens (and clearing the token reflection if the underlying model state is cleared). A sketch of such a cleanup follows the snippet below.
void trim_kv_cache(ov::InferRequest request, CacheState& cache_state, std::optional<AdapterController> adapter_controller) {
    if (
        cache_state.reset_mem_state
        // linear cache stores only the last state, trimming is not possible, so we reset the whole cache in this case
        || (cache_state.num_tokens_to_trim > 0 && cache_state.has_linear())
    ) {
        if (adapter_controller) {
            for(auto& state: request.query_state()) {
                if(!adapter_controller->has_state_name(state.get_name())) {
                    state.reset();
                }
            }
        } else {
            request.reset_state();
        }

        return;
    }
}
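Following up on the point above, one possible shape of that cleanup (a sketch only; clearing the token reflection via get_state() at this point is an assumption, not something the current diff does):

        // Hypothetical cleanup on the reset path: drop pending trim/reset flags and,
        // since the underlying model state is gone, clear the token reflection so the
        // next prompt replays the full history instead of sending only a delta.
        cache_state.num_tokens_to_trim = 0;
        cache_state.reset_mem_state = false;
        cache_state.get_state().clear();

        return;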

if (rank == 4 && dynamic_axis_count == 2) {

Copilot AI Feb 19, 2026


get_cache_types() classifies KV-cache as rank==4 with dynamic_axis_count==2, but the code comment/example shape is [-1,4,0,64] (only one dynamic dim). This heuristic likely fails to detect KV-cache (and hybrid) correctly. Consider aligning the detection with get_kv_axes_pos() logic (e.g., rank==4 plus presence of the 0 "seq_len" marker) or otherwise document/handle the expected shape patterns.

Suggested change

-        if (rank == 4 && dynamic_axis_count == 2) {
+        const bool is_rank4 = (rank == 4);
+        // For KV-cache, expect 4D state with "seq_len" axis at index 2.
+        // This axis can be marked by length 0 (e.g. [-1,4,0,64]) or be dynamic.
+        const bool has_seq_len_marker =
+            is_rank4 &&
+            !shape[2].is_dynamic() &&
+            shape[2].get_length() == 0;
+        const bool has_dynamic_seq_len =
+            is_rank4 &&
+            shape[2].is_dynamic();
+        if (is_rank4 && (has_seq_len_marker || has_dynamic_seq_len)) {

     cache_state.num_tokens_to_trim += state.size() - first_diverse_tokens_idx;
     state.resize(first_diverse_tokens_idx);
-    kv_cache_state.reset_mem_state = state.empty();
+    cache_state.reset_mem_state = state.empty() || cache_state.has_linear();

Copilot AI Feb 19, 2026


align_kv_cache_and_history(): setting reset_mem_state to true whenever has_linear() makes linear-cache models reset their state on every chat iteration, which breaks incremental generation (get_chat_encoded_input will still send only the delta). reset_mem_state for linear caches should typically be triggered only when trimming would be required (num_tokens_to_trim>0); and when a reset is needed, the token-reflection state should be cleared so the next prompt replays the full history.

Suggested change

-    cache_state.reset_mem_state = state.empty() || cache_state.has_linear();
+    const bool need_trim = cache_state.num_tokens_to_trim > 0;
+    if (cache_state.has_linear()) {
+        // For linear caches, reset memory state only when trimming is required.
+        cache_state.reset_mem_state = need_trim;
+    } else {
+        // For non-linear caches, preserve previous behavior: reset only when state becomes empty.
+        cache_state.reset_mem_state = state.empty();
+    }

Comment on lines 49 to 58
 void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
-    m_kv_cache_state.num_tokens_to_trim = 0;
+    m_cache_state.num_tokens_to_trim = 0;
     if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
         // If chat generation process was cancelled by user, let's rollback to previous state of kv cache
-        std::vector<int64_t>& state = m_kv_cache_state.get_state();
+        std::vector<int64_t>& state = m_cache_state.get_state();

-        m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
+        m_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
         state.resize(m_prev_hist_length);
-        m_kv_cache_state.reset_mem_state = state.empty();
+        m_cache_state.reset_mem_state = state.empty() || m_cache_state.has_linear();
     }

Copilot AI Feb 19, 2026


update_chat_history() cancel path: for linear-cache models, resetting the model state is necessary (trimming isn't possible), but keeping a non-empty token reflection state will cause the next get_chat_encoded_input() to send only a delta even though the model cache was reset. When linear reset is required, clear the CacheState token history (and reset trim counters) so the next iteration replays the full templated history.
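A sketch of what that could look like in the cancel branch (field names follow this PR; the exact placement inside update_chat_history() is an assumption):

        if (m_cache_state.has_linear()) {
            // Trimming is impossible for a linear cache: reset the model state and also
            // drop the token reflection so the full templated history is replayed next turn.
            m_cache_state.reset_mem_state = true;
            m_cache_state.num_tokens_to_trim = 0;
            state.clear();
        } else {
            m_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
            state.resize(m_prev_hist_length);
            m_cache_state.reset_mem_state = state.empty();
        }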

Comment on lines 145 to 147
     // reflection of tokens contained in the kv cache
-    utils::KVCacheState m_kv_cache_state;
+    utils::CacheState m_cache_state;
     // length of attention_mask/kv cache at the beginning of generation()

Copilot AI Feb 19, 2026


m_cache_state is default-constructed, so its cache type mask stays at 0 (has_linear()/has_kvcache() are false). For linear/hybrid attention models this prevents the new reset-vs-trim logic from working and can make trim_kv_cache slice tensors with the wrong rank. CacheState needs to be initialized with the language model (or have cache types injected) before it is used.

     m_use_full_chat_history = m_is_npu;

-    utils::KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+    utils::CacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();

Copilot AI Feb 19, 2026


VLMPipeline initializes kv_cache_state.seq_length_axis but never initializes CacheState's cache types from the loaded language_model. That leaves has_linear()/has_kvcache() false and can break linear/hybrid-cache behavior (including potential invalid trimming in trim_kv_cache). Consider initializing the embedder's CacheState using the language_model (e.g., assign CacheState(language_model) or add a setter for cache types) during pipeline construction.

Suggested change

-    utils::CacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+    utils::CacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+    kv_cache_state = utils::CacheState(language_model);

uint8_t mask = 0;
};

CacheTypes get_cache_types(std::shared_ptr<const ov::Model> model);

Copilot AI Feb 19, 2026


get_cache_types() takes the std::shared_ptr by value, which adds an unnecessary refcount inc/dec and makes it easier to accidentally pass nullptr without checks. Prefer taking it as a const reference (const std::shared_ptr&) here and in the definition for consistency with other helpers like get_kv_axes_pos().
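For reference, the signature the comment asks for would read:

CacheTypes get_cache_types(const std::shared_ptr<const ov::Model>& model);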

Comment on lines 439 to 514
@@ -451,6 +486,10 @@ KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model) {

         // Shape example: [-1,4,0,64]
         auto shape = op->get_input_partial_shape(0);
+        if (shape.rank().get_length() != 4) {
+            // kv cache should have 4 dimensions
+            continue;
+        }

         for (size_t i = 0; i < shape.rank().get_length(); i++) {
             // Find axis = 0. This would be sequence length axis.
@@ -467,8 +506,12 @@ KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model) {
     return kv_pos;
 }

-void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller) {
-    if (kv_cache_state.reset_mem_state) {
+void trim_kv_cache(ov::InferRequest request, CacheState& cache_state, std::optional<AdapterController> adapter_controller) {
+    if (
+        cache_state.reset_mem_state
+        // linear cache stores only the last state, trimming is not possible, so we reset the whole cache in this case
+        || (cache_state.num_tokens_to_trim > 0 && cache_state.has_linear())
+    ) {

Copilot AI Feb 19, 2026


This change introduces new behavior for linear/hybrid cache handling (cache type detection + reset vs trim semantics) but there are no accompanying regression tests. Please add tests that exercise chat continuation and cancellation/rollback for linear-only and hybrid models to ensure state resets trigger a full replay and don't cause repeated resets or invalid trimming.
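For illustration, a chat-continuation check along those lines could look roughly like this (a sketch; the model directory and prompts are placeholders, and only the public ov::genai::LLMPipeline chat API is assumed):

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder path to an exported linear-cache (e.g. LFM2-style) model.
    ov::genai::LLMPipeline pipe("lfm2-model-dir", "CPU");
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 32;

    pipe.start_chat();
    auto first = pipe.generate("What is a linear attention cache?", config);
    // The second turn must not trigger repeated resets or invalid trimming;
    // for a linear cache, a reset should force a full history replay.
    auto second = pipe.generate("Summarize that in one sentence.", config);
    pipe.finish_chat();
    return 0;
}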

@as-suvorov
Collaborator

I converted the PR to draft since it is labeled as WIP.

@as-suvorov as-suvorov marked this pull request as draft February 19, 2026 12:54
Copilot AI review requested due to automatic review settings February 20, 2026 13:31
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 3 comments.

Comment on lines +439 to +466
CacheTypes get_cache_types(std::shared_ptr<const ov::Model> model) {
    // "ReadValue" node is cache representation in stateful model
    const std::string state_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name);
    CacheTypes cache_types;

    for (const auto op : model->get_ops()) {
        // check input size, as in LoRA adapters case it could be 0
        if (op->get_type_name() != state_node_type_name || op->get_input_size() < 1) {
            continue;
        }

        // Shape example: [-1,4,0,64]
        auto shape = op->get_input_partial_shape(0);
        const auto rank = shape.rank().get_length();
        size_t dynamic_axis_count = 0;
        for (size_t i = 0; i < rank; i++) {
            if (shape[i].is_dynamic()) {
                dynamic_axis_count++;
            }
        }

        if (rank == 4 && dynamic_axis_count == 2) {
            cache_types.add_kvcache();
        } else if (rank == 3 && dynamic_axis_count == 1) {
            cache_types.add_linear();
        } else {
            continue;
        }

Copilot AI Feb 20, 2026


get_cache_types() counts only Dimension::is_dynamic() axes, but elsewhere (get_kv_axes_pos) the sequence axis is detected via shape[i] == 0 (see the "Shape example: [-1,4,0,64]" comment). For models where seq_len is represented as a static 0 dimension, dynamic_axis_count will be 1 and KV-cache will never be detected, leaving CacheTypes empty/incorrect. Update the detection logic to also treat a zero-length seq dimension as the KV-cache pattern (or otherwise make the heuristic consistent with get_kv_axes_pos).
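One way to make the heuristic consistent (a sketch only, reusing the rank / dynamic_axis_count locals from the diff above; the static-zero check is the addition):

        // Also accept the static 0 "seq_len" marker that get_kv_axes_pos() looks for.
        bool has_zero_static_dim = false;
        for (size_t i = 0; i < rank; i++) {
            if (shape[i].is_static() && shape[i].get_length() == 0) {
                has_zero_static_dim = true;
            }
        }

        if (rank == 4 && (dynamic_axis_count == 2 ||
                          (dynamic_axis_count == 1 && has_zero_static_dim))) {
            cache_types.add_kvcache();
        } else if (rank == 3 && dynamic_axis_count == 1) {
            cache_types.add_linear();
        } else {
            continue;
        }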

// in the case of beam_search the longest answer is in the kv cache, but the best one is needed
// so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated serquence
kv_cache_state.num_tokens_to_trim += state.size() - first_diverse_tokens_idx;
// so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated sequence

Copilot AI Feb 20, 2026


The comment still refers to KVCacheState even though the type has been renamed to CacheState. Please update the wording to match the new type name for clarity.

Suggested change

-    // so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated sequence
+    // so generated tokens were not added to CacheState and num_tokens_to_trim was set to the size of the generated sequence

Comment on lines +148 to +162
struct CacheTypes {
    CacheTypes() = default;
    explicit CacheTypes(uint8_t m) : mask(m) {}
    void add_kvcache() { mask |= (1u << 0); }
    void add_linear() { mask |= (1u << 1); }
    bool has_kvcache() const { return (mask & (1u << 0)) != 0; }
    bool has_linear() const { return (mask & (1u << 1)) != 0; }
    bool is_hybrid() const { return has_kvcache() && has_linear(); }
    uint8_t value() const { return mask; }
private:
    uint8_t mask = 0;
};

CacheTypes get_cache_types(std::shared_ptr<const ov::Model> model);


Copilot AI Feb 20, 2026


CacheTypes / get_cache_types() introduces new cache-kind detection logic that is central to linear/hybrid attention support, but there are no unit tests covering the detection heuristics (e.g., models with ReadValue shapes like [-1,4,0,64] for KV or [-1,0,64] for linear). Adding focused tests would help prevent regressions when model export patterns change.
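As a starting point, the bitmask helper itself is easy to cover (a sketch; the namespace and header location are assumed from this diff, and a fuller test of get_cache_types() would additionally need small ov::Model fixtures with ReadValue states shaped like the KV and linear examples above):

#include <gtest/gtest.h>

#include "utils.hpp"  // assumed location of CacheTypes within src/cpp/src

TEST(CacheTypes, HybridIsKvCachePlusLinear) {
    ov::genai::utils::CacheTypes types;  // namespace assumed
    EXPECT_FALSE(types.has_kvcache());
    EXPECT_FALSE(types.has_linear());
    EXPECT_FALSE(types.is_hybrid());

    types.add_kvcache();
    types.add_linear();
    EXPECT_TRUE(types.is_hybrid());
    EXPECT_EQ(types.value(), 0b11);
}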

