diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py
index f73f5b2cdcdd..16bda11e8ae1 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py
@@ -292,8 +292,14 @@ def add_new_req(
             remote_engine_id=kv_transfer_params["remote_engine_id"],
             remote_host=kv_transfer_params["remote_host"],
             remote_port=kv_transfer_params["remote_port"],
-            remote_handshake_port=kv_transfer_params["remote_handshake_port"],
-            remote_notify_port=kv_transfer_params["remote_notify_port"],
+            remote_handshake_port=kv_transfer_params.get(
+                "remote_handshake_port",
+                int(MoRIIOConstants.DEFAULT_HANDSHAKE_PORT),
+            ),
+            remote_notify_port=kv_transfer_params.get(
+                "remote_notify_port",
+                int(MoRIIOConstants.DEFAULT_NOTIFY_PORT),
+            ),
             tp_size=kv_transfer_params.get("tp_size", 1),
             remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
         )
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py
index 2494857c6c69..15cb7a9fd8b2 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py
@@ -266,6 +266,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
         self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
 
+        # Snapshot of kv_transfer_params captured at allocation time.
+        # The Request object's kv_transfer_params may be mutated or
+        # cleared between scheduler steps, so we cache a copy here
+        # to ensure build_connector_meta has access to the original
+        # values (remote_block_ids, remote_engine_id, etc.).
+        self._req_kv_params: dict[ReqId, dict] = {}
+
         # For chunked prefill, we perform layer-wise access within the final chunk.
         # TODO: Perform transfer at end chunk.
         self._reqs_need_pending_save: dict[ReqId, tuple[Request, list[int]]] = {}
@@ -341,6 +348,7 @@ def update_state_after_alloc(
         if params.get("do_remote_decode"):
             local_block_ids = blocks.get_block_ids()[0]
             self._reqs_need_save[request.request_id] = (request, local_block_ids)
+            self._req_kv_params[request.request_id] = dict(params)
 
         if params is not None and params.get("do_remote_prefill"):
             if self.mode == MoRIIOMode.READ:
@@ -365,6 +373,9 @@ def update_state_after_alloc(
                         request,
                         local_block_ids,
                     )
+                    self._req_kv_params[request.request_id] = dict(
+                        params
+                    )
                 else:
                     logger.warning(
                         "Got invalid KVTransferParams: %s. This "
@@ -458,15 +469,19 @@ def build_connector_meta(
         # Loop through scheduled reqs and convert to ReqMeta.
         for req_id, (req, block_ids) in self._reqs_need_recv.items():
-            assert req.kv_transfer_params is not None
+            kv_params = self._req_kv_params.get(
+                req_id, req.kv_transfer_params or {}
+            )
             meta.add_new_req(
                 request_id=req_id,
                 local_block_ids=block_ids,
-                kv_transfer_params=req.kv_transfer_params,
+                kv_transfer_params=kv_params,
             )
 
         for req_id, (req, block_ids) in self._reqs_need_save.items():
-            assert req.kv_transfer_params is not None
+            kv_params = self._req_kv_params.get(
+                req_id, req.kv_transfer_params or {}
+            )
             if req.num_prompt_tokens > len(block_ids) * self.block_size:
                 # not last chunk prefill
                 self._reqs_need_pending_save[req_id] = (req, block_ids)
                 continue
@@ -474,13 +489,17 @@
             meta.add_new_req(
                 request_id=req_id,
                 local_block_ids=block_ids,
-                kv_transfer_params=req.kv_transfer_params,
+                kv_transfer_params=kv_params,
                 write_mode=True,
             )
 
         # Clear the list once workers start the transfers
         meta.reqs_to_send = self._reqs_need_send
 
+        for req_id in self._reqs_need_recv:
+            self._req_kv_params.pop(req_id, None)
+        for req_id in self._reqs_need_save:
+            self._req_kv_params.pop(req_id, None)
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
         self._reqs_need_send = {}
@@ -1086,7 +1105,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.slot_size_bytes = (
             kv_elem_size * n_kv_heads * head_dim
         )  # 1 token 1 layer size , slot size
-        assert block_size == self.block_size
+        # The attention backend may override the configured block_size
+        # (e.g. FlashMLA forces block_size=64 for MLA models regardless
+        # of the --block-size CLI flag). Trust the actual tensor shape.
+        if block_size != self.block_size:
+            logger.info(
+                "KV cache block_size=%d differs from config block_size=%d; "
+                "using actual tensor shape (attention backend override).",
+                block_size,
+                self.block_size,
+            )
+            self.block_size = block_size
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
         # block size in bytes
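
The scheduler-side snapshot is the heart of this change: kv_transfer_params is copied at allocation time and consulted again in build_connector_meta, with the live request as a fallback. A minimal, self-contained sketch of that pattern follows; Request and ConnectorSchedulerSketch below are simplified stand-ins, not the vLLM types.

# Sketch of the snapshot-at-allocation pattern, assuming a
# simplified Request; this is not the real vLLM scheduler code.


class Request:
    def __init__(self, request_id: str, kv_transfer_params: dict | None):
        self.request_id = request_id
        self.kv_transfer_params = kv_transfer_params


class ConnectorSchedulerSketch:
    def __init__(self) -> None:
        # request_id -> snapshot taken in update_state_after_alloc
        self._req_kv_params: dict[str, dict] = {}

    def update_state_after_alloc(self, request: Request) -> None:
        # Copy now: the Request's params may be mutated or cleared
        # before build_connector_meta runs on a later scheduler step.
        if request.kv_transfer_params:
            self._req_kv_params[request.request_id] = dict(
                request.kv_transfer_params
            )

    def build_connector_meta(self, request: Request) -> dict:
        # Prefer the snapshot; fall back to the live request, then {}.
        return self._req_kv_params.pop(
            request.request_id, request.kv_transfer_params or {}
        )


sched = ConnectorSchedulerSketch()
req = Request("req-0", {"remote_engine_id": "prefill-0"})
sched.update_state_after_alloc(req)
req.kv_transfer_params = None  # simulate mutation between steps
assert sched.build_connector_meta(req) == {"remote_engine_id": "prefill-0"}

The sketch pops the snapshot on use; the diff achieves the same cleanup with the explicit pop loops at the end of build_connector_meta, since there a single request can appear in both the recv and save maps.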
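The moriio_common.py hunk makes the same defensive move for the port fields, swapping KeyError-raising indexing for .get() with defaults so senders that omit the ports still interoperate. A sketch of the resulting behavior, with placeholder values standing in for the real MoRIIOConstants:

# Sketch of the .get()-with-default port resolution; the constant
# values below are placeholders, not the actual MoRIIOConstants.


class MoRIIOConstants:
    DEFAULT_HANDSHAKE_PORT = "7777"  # placeholder value
    DEFAULT_NOTIFY_PORT = "7778"  # placeholder value


def resolve_ports(kv_transfer_params: dict) -> tuple[int, int]:
    # Missing keys fall back to defaults instead of raising KeyError.
    handshake = kv_transfer_params.get(
        "remote_handshake_port",
        int(MoRIIOConstants.DEFAULT_HANDSHAKE_PORT),
    )
    notify = kv_transfer_params.get(
        "remote_notify_port",
        int(MoRIIOConstants.DEFAULT_NOTIFY_PORT),
    )
    return int(handshake), int(notify)


assert resolve_ports({}) == (7777, 7778)
assert resolve_ports({"remote_notify_port": 9000}) == (7777, 9000)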
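Finally, the register_kv_caches hunk replaces a hard assert with reconciliation: when the attention backend picks a different block size than the config (FlashMLA's forced 64-token blocks being the cited case), the connector now trusts the tensor. A sketch of that reconciliation, assuming an illustrative (num_blocks, block_size, n_kv_heads, head_dim) layout; the real connector derives shapes per backend.

# Sketch of deriving block_size from the KV cache tensor instead of
# asserting it matches the config. The layout here is illustrative.
import torch


def reconcile_block_size(kv_cache: torch.Tensor, configured: int) -> int:
    actual = kv_cache.shape[1]  # assumed block_size dimension
    if actual != configured:
        print(
            f"block_size mismatch: tensor has {actual}, config has "
            f"{configured}; trusting the tensor shape"
        )
    return actual


cache = torch.zeros(8, 64, 4, 128)  # e.g. a backend forcing 64-token blocks
assert reconcile_block_size(cache, configured=16) == 64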