nemo_rl/algorithms/grpo.py: 92 changes (56 additions, 36 deletions)
@@ -975,6 +975,39 @@ def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
     return repeated_batch


+def extract_initial_prompt_messages(
+    message_logs: list,
+    original_prompt_lengths: torch.Tensor,
+) -> list:
+    """Extract the original prompt messages from message logs using token length.
+
+    This function correctly identifies original prompt messages even when the prompt
+    contains assistant messages (e.g., multi-turn conversation history).
+
+    Args:
+        message_logs: List of message logs, where each log is a list of messages.
+        original_prompt_lengths: Tensor of original prompt token lengths per sample.
+
+    Returns:
+        List of message logs containing only the original prompt messages.
+    """
+    initial_prompt_message_logs = []
+    for i, message_log in enumerate(message_logs):
+        initial_prompt_log = []
+        cumulative_length = 0
+        target_length = original_prompt_lengths[i].item()
+
+        for message in message_log:
+            if cumulative_length >= target_length:
+                break
+            initial_prompt_log.append(message)
+            cumulative_length += len(message["token_ids"])
+
+        initial_prompt_message_logs.append(initial_prompt_log)
+
+    return initial_prompt_message_logs
+
+
 def _should_use_async_rollouts(master_config: MasterConfig) -> bool:
     """Determine if async rollouts should be used based on the configuration.

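Note: a minimal usage sketch of the new helper, for readers skimming the diff. The toy log and length below are invented for illustration and are not part of this PR.

# Illustrative only: invented toy data, not from the PR.
import torch

log = [
    {"role": "user", "token_ids": torch.tensor([1, 2, 3])},        # prompt turn 1
    {"role": "assistant", "token_ids": torch.tensor([4, 5])},      # prompt history
    {"role": "user", "token_ids": torch.tensor([6, 7])},           # prompt turn 2
    {"role": "assistant", "token_ids": torch.tensor([8, 9, 10])},  # generated response
]

# The original prompt spans the first 7 tokens (3 + 2 + 2), so the helper keeps
# the first three messages and stops before the generated response.
prompts = extract_initial_prompt_messages([log], torch.tensor([7]))
assert [m["role"] for m in prompts[0]] == ["user", "assistant", "user"]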
@@ -1072,28 +1105,6 @@ def _create_advantage_estimator(master_config: MasterConfig):
     return adv_estimator


-def _extract_prompt_only_messages(message_logs: list) -> list:
-    """Extract only prompt messages (user/system) from message logs.
-
-    This is used to get prompt IDs for advantage estimation, excluding
-    any assistant responses.
-
-    Args:
-        message_logs: List of message logs, where each log is a list of messages.
-
-    Returns:
-        List of message logs containing only user and system messages.
-    """
-    prompt_only_message_logs = []
-    for message_log in message_logs:
-        prompt_only_log = []
-        for message in message_log:
-            if message["role"] == "user" or message["role"] == "system":
-                prompt_only_log.append(message)
-        prompt_only_message_logs.append(prompt_only_log)
-    return prompt_only_message_logs
-
 def refit_policy_generation(
     policy: ColocatablePolicyInterface,
     policy_generation: GenerationInterface,
@@ -1655,16 +1666,20 @@ def grpo_train(
             # Save baseline for logging (before deletion)
             baseline_for_log = baseline.clone()

-            # Extract prompt-only messages for advantage estimation
-            prompt_only_message_logs = _extract_prompt_only_messages(
-                repeated_batch["message_log"]
+            # Extract original prompt messages using the length field
+            # This correctly handles multi-turn prompts that contain assistant messages
+            initial_prompt_message_logs = extract_initial_prompt_messages(
+                repeated_batch["message_log"],
+                repeated_batch["length"],
             )
-            prompt_batched_flat, _ = batched_message_log_to_flat_message(
-                prompt_only_message_logs,
-                pad_value_dict={"token_ids": tokenizer.pad_token_id},
+            prompt_batched_flat, prompt_input_lengths = (
+                batched_message_log_to_flat_message(
+                    initial_prompt_message_logs,
+                    pad_value_dict={"token_ids": tokenizer.pad_token_id},
+                )
             )
             prompt_ids_for_adv = prompt_batched_flat["token_ids"]
-            del prompt_only_message_logs
+            del initial_prompt_message_logs
             del prompt_batched_flat
             del input_ids
             del baseline
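Note: batched_message_log_to_flat_message is NeMo-RL's own utility; judging only from how it is called here, it concatenates each log's token_ids and right-pads to the batch maximum, returning the flat batch and the per-sample lengths. A hand-rolled stand-in under that assumption (not the library implementation):

# Illustrative stand-in under the assumption stated above, not NeMo-RL's
# batched_message_log_to_flat_message.
import torch

def flatten_and_pad(message_logs: list, pad_id: int):
    # Concatenate each log's token_ids into one 1-D tensor per sample.
    flat = [torch.cat([m["token_ids"] for m in log]) for log in message_logs]
    lengths = torch.tensor([t.numel() for t in flat])
    # Right-pad every sample to the longest sequence in the batch.
    padded = torch.full((len(flat), int(lengths.max())), pad_id, dtype=flat[0].dtype)
    for row, t in enumerate(flat):
        padded[row, : t.numel()] = t
    return padded, lengths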
@@ -2720,16 +2735,21 @@ def async_grpo_train(

        print("▶ Processing rewards...")
        with timer.time("reward_calculation"):
-            # Extract prompt-only messages for advantage estimation
-            prompt_only_message_logs = _extract_prompt_only_messages(
-                repeated_batch["message_log"]
+            # Extract original prompt messages using the length field
+            # This correctly handles multi-turn prompts that contain assistant messages
+            initial_prompt_message_logs = extract_initial_prompt_messages(
+                repeated_batch["message_log"],
+                repeated_batch["length"],
             )
-            prompt_batched_flat, _ = batched_message_log_to_flat_message(
-                prompt_only_message_logs,
-                pad_value_dict={"token_ids": tokenizer.pad_token_id},
+
+            prompt_batched_flat, prompt_input_lengths = (
+                batched_message_log_to_flat_message(
+                    initial_prompt_message_logs,
+                    pad_value_dict={"token_ids": tokenizer.pad_token_id},
+                )
             )
             prompt_ids_for_adv = prompt_batched_flat["token_ids"]
-            del prompt_only_message_logs
+            del initial_prompt_message_logs
             del prompt_batched_flat

             rewards = repeated_batch["total_reward"]
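Note: prompt_ids_for_adv feeds the advantage estimator built by _create_advantage_estimator; samples sharing a prompt form a group, and a GRPO-style baseline is that group's mean reward. A sketch under the assumption that each prompt is repeated a fixed num_generations times consecutively, which is how repeated_batch is typically constructed; the estimator NeMo-RL actually applies is configured elsewhere and may normalize differently:

# Hedged sketch of a group-relative baseline; assumes consecutive grouping.
import torch

def group_relative_advantages(rewards: torch.Tensor, num_generations: int) -> torch.Tensor:
    # rewards: shape (batch,), where each block of num_generations consecutive
    # samples was generated from the same prompt.
    grouped = rewards.view(-1, num_generations)
    baseline = grouped.mean(dim=1, keepdim=True)  # per-prompt mean reward
    return (grouped - baseline).flatten()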
tests/unit/algorithms/test_async_utils.py: 142 changes (141 additions, 1 deletion)
@@ -29,7 +29,7 @@
 os.environ["TMPDIR"] = _temp_dir  # System temp dir

 from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer
-from nemo_rl.algorithms.grpo import MasterConfig
+from nemo_rl.algorithms.grpo import MasterConfig, extract_initial_prompt_messages
 from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.environments.interfaces import (
@@ -698,3 +698,143 @@ def test_error_handling(self):
         assert sample_result is None

         ray.kill(buffer)
+
+
+class TestPromptExtraction:
+    """Test cases for prompt extraction logic used in async GRPO advantage calculation.
+
+    These tests verify that the length-based prompt extraction correctly handles
+    multi-turn conversation prompts where the original prompt itself contains
+    assistant messages (conversation history).
+    """
+
+    def test_prompt_extraction_with_multi_turn_history(self):
+        """Test that prompt extraction correctly handles prompts containing assistant messages.
+
+        This tests the fix for multi-turn conversation prompts where the original prompt
+        from the dataset itself contains assistant messages (conversation history).
+        The extraction should use the length field to identify original prompt messages,
+        not break at the first assistant message.
+        """
+        # Create a multi-turn prompt with assistant messages in the history
+        # Original prompt: user -> assistant -> user (3 messages, 15 tokens total)
+        original_prompt_messages = [
+            {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])},
+            {"role": "assistant", "content": "4", "token_ids": torch.tensor([6, 7, 8, 9, 10])},
+            {"role": "user", "content": "Now what is 3+3?", "token_ids": torch.tensor([11, 12, 13, 14, 15])},
+        ]
+
+        # Generated response (added after original prompt)
+        generated_message = {
+            "role": "assistant",
+            "content": "6",
+            "token_ids": torch.tensor([16, 17, 18]),
+        }
+
+        # Full message_log after generation
+        full_message_log = original_prompt_messages + [generated_message]
+
+        # Original prompt length = sum of token_ids in original prompt
+        original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages)  # 15
+
+        message_logs = [full_message_log]
+        original_prompt_lengths = torch.tensor([original_prompt_length])
+
+        result = extract_initial_prompt_messages(message_logs, original_prompt_lengths)
+        initial_prompt_log = result[0]
+
+        # Should extract all 3 original messages, NOT break at first assistant
+        assert len(initial_prompt_log) == 3, (
+            f"Expected 3 messages (user, assistant, user), got {len(initial_prompt_log)}. "
+            "The extraction should NOT break at the first assistant message when it's part of the original prompt."
+        )
+
+        assert initial_prompt_log[0]["role"] == "user"
+        assert initial_prompt_log[1]["role"] == "assistant"
+        assert initial_prompt_log[2]["role"] == "user"
+        assert generated_message not in initial_prompt_log
+
+    def test_prompt_extraction_with_single_turn(self):
+        """Test that prompt extraction works correctly for single-turn prompts (regression test)."""
+        original_prompt_messages = [
+            {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])},
+        ]
+
+        generated_message = {
+            "role": "assistant",
+            "content": "4",
+            "token_ids": torch.tensor([6, 7, 8]),
+        }
+
+        full_message_log = original_prompt_messages + [generated_message]
+        original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages)
+
+        result = extract_initial_prompt_messages(
+            [full_message_log], torch.tensor([original_prompt_length])
+        )
+        initial_prompt_log = result[0]
+
+        assert len(initial_prompt_log) == 1
+        assert initial_prompt_log[0]["role"] == "user"
+        assert generated_message not in initial_prompt_log
+
+    def test_prompt_extraction_with_system_message(self):
+        """Test prompt extraction with system message included."""
+        original_prompt_messages = [
+            {"role": "system", "content": "You are a math tutor.", "token_ids": torch.tensor([1, 2, 3])},
+            {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([4, 5, 6, 7])},
+        ]
+
+        generated_message = {
+            "role": "assistant",
+            "content": "4",
+            "token_ids": torch.tensor([8, 9]),
+        }
+
+        full_message_log = original_prompt_messages + [generated_message]
+        original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages)
+
+        result = extract_initial_prompt_messages(
+            [full_message_log], torch.tensor([original_prompt_length])
+        )
+        initial_prompt_log = result[0]
+
+        assert len(initial_prompt_log) == 2
+        assert initial_prompt_log[0]["role"] == "system"
+        assert initial_prompt_log[1]["role"] == "user"
+        assert generated_message not in initial_prompt_log
+
+    def test_prompt_extraction_complex_multi_turn(self):
+        """Test prompt extraction with complex multi-turn history (multiple assistant turns)."""
+        original_prompt_messages = [
+            {"role": "system", "content": "Math tutor", "token_ids": torch.tensor([1, 2])},
+            {"role": "user", "content": "2+2?", "token_ids": torch.tensor([3, 4])},
+            {"role": "assistant", "content": "4", "token_ids": torch.tensor([5, 6])},
+            {"role": "user", "content": "3+3?", "token_ids": torch.tensor([7, 8])},
+            {"role": "assistant", "content": "6", "token_ids": torch.tensor([9, 10])},
+            {"role": "user", "content": "4+4?", "token_ids": torch.tensor([11, 12])},
+        ]
+
+        generated_message = {
+            "role": "assistant",
+            "content": "8",
+            "token_ids": torch.tensor([13, 14]),
+        }
+
+        full_message_log = original_prompt_messages + [generated_message]
+        original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages)
+
+        result = extract_initial_prompt_messages(
+            [full_message_log], torch.tensor([original_prompt_length])
+        )
+        initial_prompt_log = result[0]
+
+        assert len(initial_prompt_log) == 6, (
+            f"Expected 6 messages, got {len(initial_prompt_log)}. "
+            "All history messages should be included in the prompt."
+        )
+
+        expected_roles = ["system", "user", "assistant", "user", "assistant", "user"]
+        actual_roles = [m["role"] for m in initial_prompt_log]
+        assert actual_roles == expected_roles
+        assert generated_message not in initial_prompt_log
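Note: the new tests all use lengths that land exactly on message boundaries. A hypothetical extra case, not part of this PR, probes a length that falls mid-message: the helper appends a message before re-checking the running total, so the straddling message is kept whole and the extracted prompt can exceed original_prompt_lengths. That is harmless when lengths come from summing whole prompt messages, as in grpo_train, but worth keeping in mind for truncated prompts.

# Hypothetical edge-case test, not included in the PR.
def test_prompt_extraction_length_mid_message():
    log = [
        {"role": "user", "token_ids": torch.tensor([1, 2, 3])},
        {"role": "assistant", "token_ids": torch.tensor([4, 5])},
    ]
    # Target length 4 falls inside the second message; both messages are kept.
    result = extract_initial_prompt_messages([log], torch.tensor([4]))
    assert len(result[0]) == 2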