Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,14 @@ def add_new_req(
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_handshake_port=kv_transfer_params["remote_handshake_port"],
remote_notify_port=kv_transfer_params["remote_notify_port"],
remote_handshake_port=kv_transfer_params.get(
"remote_handshake_port",
int(MoRIIOConstants.DEFAULT_HANDSHAKE_PORT),
),
remote_notify_port=kv_transfer_params.get(
"remote_notify_port",
int(MoRIIOConstants.DEFAULT_NOTIFY_PORT),
),
tp_size=kv_transfer_params.get("tp_size", 1),
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}

# Snapshot of kv_transfer_params captured at allocation time.
# The Request object's kv_transfer_params may be mutated or
# cleared between scheduler steps, so we cache a copy here
# to ensure build_connector_meta has access to the original
# values (remote_block_ids, remote_engine_id, etc.).
self._req_kv_params: dict[ReqId, dict] = {}

# For chunked prefill, we perform layer-wise access within the final chunk.
# TODO: Perform transfer at end chunk.
self._reqs_need_pending_save: dict[ReqId, tuple[Request, list[int]]] = {}
Expand Down Expand Up @@ -341,6 +348,7 @@ def update_state_after_alloc(
if params.get("do_remote_decode"):
local_block_ids = blocks.get_block_ids()[0]
self._reqs_need_save[request.request_id] = (request, local_block_ids)
self._req_kv_params[request.request_id] = dict(params)

if params is not None and params.get("do_remote_prefill"):
if self.mode == MoRIIOMode.READ:
Expand All @@ -365,6 +373,9 @@ def update_state_after_alloc(
request,
local_block_ids,
)
self._req_kv_params[request.request_id] = dict(
params
)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
Expand Down Expand Up @@ -458,29 +469,37 @@ def build_connector_meta(

# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
kv_params = self._req_kv_params.get(
req_id, req.kv_transfer_params or {}
)
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
kv_transfer_params=kv_params,
)

for req_id, (req, block_ids) in self._reqs_need_save.items():
assert req.kv_transfer_params is not None
kv_params = self._req_kv_params.get(
req_id, req.kv_transfer_params or {}
)
if req.num_prompt_tokens > len(block_ids) * self.block_size:
# not last chunk prefill
self._reqs_need_pending_save[req_id] = (req, block_ids)
continue
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
kv_transfer_params=kv_params,
write_mode=True,
)
# Clear the list once workers start the transfers

meta.reqs_to_send = self._reqs_need_send

for req_id in self._reqs_need_recv:
self._req_kv_params.pop(req_id, None)
for req_id in self._reqs_need_save:
self._req_kv_params.pop(req_id, None)
self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_need_send = {}
Expand Down Expand Up @@ -1086,7 +1105,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
self.slot_size_bytes = (
kv_elem_size * n_kv_heads * head_dim
) # 1 token 1 layer size , slot size
assert block_size == self.block_size
# The attention backend may override the configured block_size
# (e.g. FlashMLA forces block_size=64 for MLA models regardless
# of the --block-size CLI flag). Trust the actual tensor shape.
if block_size != self.block_size:
logger.info(
"KV cache block_size=%d differs from config block_size=%d; "
"using actual tensor shape (attention backend override).",
block_size,
self.block_size,
)
self.block_size = block_size
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
Expand Down
Loading