support continuous batching for eagle3 #3321
xufang-lisa wants to merge 4 commits into openvinotoolkit:master
Conversation
Pull request overview
This PR adds support for continuous batching with multiple prompts to the EAGLE3 speculative decoding pipeline. The implementation ensures that the draft model and main model process the same chunks by coordinating prefill completion across all requests.
Changes:
- Added synchronization logic to pause generation until all requests complete prefill in EAGLE3 mode
- Modified hidden state handling to support partial tensor copying for mismatched sequence lengths
- Extended test coverage to validate multiple prompt batching scenarios
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/python_tests/test_continuous_batching.py | Updated tests to support parameterized multi-prompt batching scenarios for both speculative decoding and EAGLE3 |
| src/cpp/src/speculative_decoding/continuous_batching/pipeline_impl.cpp | Added prefill synchronization logic and modified draft model execution control flow |
| src/cpp/src/sequence_group.hpp | Added has_finished_prefill() method to check if any sequence has begun generation |
| src/cpp/src/continuous_batching/model_runner.hpp | Enhanced hidden state handling with partial tensor copying for mismatched sequence lengths |
src/cpp/src/speculative_decoding/continuous_batching/pipeline_impl.cpp
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
src/cpp/src/speculative_decoding/continuous_batching/pipeline_impl.cpp:1
- Switching from `to_generate |= ...` to `to_generate &= ...` changes semantics from "any request can generate" to "all requests can generate". With this change, `to_generate` must be initialized to `true` immediately before the loop; otherwise `&=` can preserve a stale value from previous iterations (or remain `false`), preventing the draft model from running when it should.
// Copyright (C) 2023-2026 Intel Corporation
// To ensure that `draft model` and `main model` process the same chunks.
// The request that completes the prefill first needs to pause and wait for other requests to complete their
// prefill.
if (generated_requests.size() < m_requests.size() &&
In the non-validation branch, pause_gen_status is no longer reset to false before applying the generated_len conditions. This means it can carry over the earlier prefill-synchronization value and incorrectly pause generation even when the generated_len condition is not met. Fix by using a separate variable for validation-mode prefill pausing, or explicitly setting pause_gen_status = false; at the start of the !m_is_validation_mode_enabled block.
Suggested change:
```diff
-            if (generated_requests.size() < m_requests.size() &&
+            if (m_is_validation_mode_enabled &&
+                generated_requests.size() < m_requests.size() &&
```
generated_len -= result.removed_tokens_cnt;
generated_len += result.inserted_tokens_cnt;
if (generated_len >= max_new_tokens - 1 || generated_len != 0 && result.inserted_tokens_cnt == 0) {
    pause_gen_status = true;
}
request->pause_generation(pause_gen_status);
bool pause_gen_status = false;
std::vector<uint64_t> generated_requests;
for (auto& request : m_requests) {
    if (request->has_finished_prefill()) {
        generated_requests.push_back(request->get_request_id());
    }
}

// To ensure that `draft model` and `main model` process the same chunks.
// The request that completes the prefill first needs to pause and wait for other requests to complete their
// prefill.
if (generated_requests.size() < m_requests.size() &&
    std::find(generated_requests.begin(), generated_requests.end(), request_id) != generated_requests.end()) {
    pause_gen_status = true;
}
This allocates and fills generated_requests on every call, then does a linear std::find, which is avoidable overhead in a hot path. Consider computing the pause condition without building a vector (e.g., first determine whether all requests finished prefill, and separately whether the current request_id has finished prefill), or track a finished-prefill counter/flag updated when requests transition state.
Suggested change:
```diff
-            bool pause_gen_status = false;
-            std::vector<uint64_t> generated_requests;
-            for (auto& request : m_requests) {
-                if (request->has_finished_prefill()) {
-                    generated_requests.push_back(request->get_request_id());
-                }
-            }
-            // To ensure that `draft model` and `main model` process the same chunks.
-            // The request that completes the prefill first needs to pause and wait for other requests to complete their
-            // prefill.
-            if (generated_requests.size() < m_requests.size() &&
-                std::find(generated_requests.begin(), generated_requests.end(), request_id) != generated_requests.end()) {
-                pause_gen_status = true;
-            }
+            bool all_prefill_finished = true;
+            bool current_finished_prefill = false;
+            for (auto& request : m_requests) {
+                const bool finished = request->has_finished_prefill();
+                if (!finished) {
+                    all_prefill_finished = false;
+                }
+                if (request->get_request_id() == request_id && finished) {
+                    current_finished_prefill = true;
+                }
+            }
+            // To ensure that `draft model` and `main model` process the same chunks.
+            // The request that completes the prefill first needs to pause and wait for other requests to complete their
+            // prefill.
+            bool pause_gen_status = !all_prefill_finished && current_finished_prefill;
```
std::vector<uint64_t> generated_requests;
for (auto& request : m_requests) {
    if (request->has_finished_prefill()) {
        generated_requests.push_back(request->get_request_id());
    }
}

// To ensure that `draft model` and `main model` process the same chunks.
// The request that completes the prefill first needs to pause and wait for other requests to complete their
// prefill.
if (generated_requests.size() < m_requests.size() &&
    std::find(generated_requests.begin(), generated_requests.end(), request_id) != generated_requests.end()) {
This allocates and fills generated_requests on every call, then does a linear std::find, which is avoidable overhead in a hot path. Consider computing the pause condition without building a vector (e.g., first determine whether all requests finished prefill, and separately whether the current request_id has finished prefill), or track a finished-prefill counter/flag updated when requests transition state.
Suggested change:
```diff
-            std::vector<uint64_t> generated_requests;
-            for (auto& request : m_requests) {
-                if (request->has_finished_prefill()) {
-                    generated_requests.push_back(request->get_request_id());
-                }
-            }
-            // To ensure that `draft model` and `main model` process the same chunks.
-            // The request that completes the prefill first needs to pause and wait for other requests to complete their
-            // prefill.
-            if (generated_requests.size() < m_requests.size() &&
-                std::find(generated_requests.begin(), generated_requests.end(), request_id) != generated_requests.end()) {
+            std::size_t finished_prefill_count = 0;
+            bool current_request_finished_prefill = false;
+            for (const auto& request : m_requests) {
+                if (request->has_finished_prefill()) {
+                    ++finished_prefill_count;
+                    if (request->get_request_id() == request_id) {
+                        current_request_finished_prefill = true;
+                    }
+                }
+            }
+            // To ensure that `draft model` and `main model` process the same chunks.
+            // The request that completes the prefill first needs to pause and wait for other requests to complete their
+            // prefill.
+            if (finished_prefill_count < m_requests.size() && current_request_finished_prefill) {
```
bool has_finished_prefill() const {
    for (auto& sequence : get_running_sequences()) {
        if (sequence->get_generated_len() > 0) {
            return true;
        }
    }
    return false;
}
has_finished_prefill() currently returns true based on get_generated_len() > 0, which reads as “has started generating tokens” rather than “finished prefill”. This mismatch makes the new synchronization logic hard to reason about. Either (a) rename this helper to reflect what it actually checks, or (b) change the condition to a true prefill-completion signal (whatever the project uses to indicate prompt KV/cache is fully built) so the name matches behavior (C++ Core Guidelines: make interfaces precisely and strongly typed / self-descriptive; also aligns with guideline on descriptive function names).
Description
Support continuous batching with multiple prompts for the EAGLE3 pipeline.
CVS-179148
Checklist: