Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from nemo_rl.environments.interfaces import EnvironmentInterface
from nemo_rl.experience.rollouts import (
run_multi_turn_rollout_async_generation,
run_async_multi_turn_rollout,
run_multi_turn_rollout,
)
Expand Down Expand Up @@ -984,7 +985,7 @@ def validate(
# Generate responses (updates the LLMMessageLogType in batch_with_msg_logs)
# Use async rollouts if vLLM async engine is enabled
if _should_use_async_rollouts(master_config):
val_batch, gen_metrics = run_async_multi_turn_rollout(
val_batch, gen_metrics = run_multi_turn_rollout_async_generation(
policy_generation,
val_batch,
tokenizer,
Expand Down
3 changes: 2 additions & 1 deletion nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.environments.interfaces import EnvironmentInterface
from nemo_rl.experience.rollouts import (
run_multi_turn_rollout_async_generation,
run_async_multi_turn_rollout,
run_async_nemo_gym_rollout,
run_multi_turn_rollout,
Expand Down Expand Up @@ -2263,7 +2264,7 @@ def validate(
gen_metrics = nemo_gym_rollout_result.rollout_metrics
additional_metrics_to_report = gen_metrics
elif _should_use_async_rollouts(master_config):
val_batch, gen_metrics = run_async_multi_turn_rollout(
val_batch, gen_metrics = run_multi_turn_rollout_async_generation(
policy_generation,
val_batch,
tokenizer,
Expand Down
210 changes: 210 additions & 0 deletions nemo_rl/experience/rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,213 @@ def run_multi_turn_rollout(
return current_batch, rollout_metrics


def run_multi_turn_rollout_async_generation(
    policy_generation: GenerationInterface,
    input_batch: BatchedDataDict[DatumSpec],
    tokenizer: TokenizerType,
    task_to_env: dict[str, EnvironmentInterface],
    max_seq_len: int,
    max_rollout_turns: int = 999999,
    greedy: bool = False,
) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]:
    """Run a batched multi-turn rollout using async generation.

    This mirrors `run_multi_turn_rollout()`'s batched environment interaction,
    but swaps synchronous generation for `generate_responses_async()`. It is
    intended for cases like validation where preserving batched environment
    evaluation matters more than per-sample pipelining across turns.

    Args:
        policy_generation: Generation backend used to produce assistant turns.
        input_batch: Batch of samples (message logs plus metadata). The caller's
            batch is not mutated; work happens on a copy.
        tokenizer: Used for padding generation inputs and tokenizing environment
            observations.
        task_to_env: Maps a sample's task name to the environment that scores it.
        max_seq_len: Hard cap on total tokens per sample; environment
            observations are clipped so the running total stays under it.
        max_rollout_turns: Maximum generation/environment turns per sample.
        greedy: Whether to use greedy decoding.

    Returns:
        Tuple of (batch with updated message logs, per-sample reward columns and
        a ``truncated`` flag; dict of rollout metrics).
    """

    async def _async_rollout_implementation():
        current_batch = input_batch.copy()  # Work on a copy
        batch_size = len(current_batch["message_log"])
        # Global indices (into the full batch) of samples still rolling out.
        active_indices = torch.arange(batch_size)
        total_rewards = torch.zeros(batch_size, dtype=torch.float32)

        # Multi_rewards: number of components inferred from first env_output (1 for single-reward envs)
        number_of_rewards: int | None = None
        multi_rewards: torch.Tensor | None = None

        # Initialize stop_strings from the initial batch if present
        current_stop_strings = current_batch.get("stop_strings", [None] * batch_size)

        # Tracking metrics for each sample
        sample_turn_counts = torch.zeros(batch_size, dtype=torch.int32)
        sample_token_counts = torch.zeros(batch_size, dtype=torch.int32)
        sample_assistant_token_counts = torch.zeros(batch_size, dtype=torch.int32)
        sample_env_token_counts = torch.zeros(batch_size, dtype=torch.int32)
        sample_terminated = torch.zeros(batch_size, dtype=torch.bool)
        sample_truncated = torch.zeros(batch_size, dtype=torch.bool)
        sample_max_turns_reached = torch.zeros(batch_size, dtype=torch.bool)

        # Tracking per-turn metrics
        total_gen_tokens_per_turn = []
        active_samples_per_turn = []

        for turn in range(max_rollout_turns):
            # All samples finished (terminated or truncated) — stop early.
            if len(active_indices) == 0:
                break

            active_samples_per_turn.append(len(active_indices))

            # Sub-batch of only the still-active samples for this turn.
            active_batch = current_batch.select_indices(active_indices)
            active_stop_strings = [
                current_stop_strings[i] for i in active_indices.tolist()
            ]

            # Flatten per-sample message logs into padded tensors for generation.
            active_flat_messages: BatchedDataDict[FlatMessagesType]
            active_flat_messages, active_input_lengths = (
                batched_message_log_to_flat_message(
                    active_batch["message_log"],
                    pad_value_dict={"token_ids": tokenizer.pad_token_id},
                )
            )

            active_input_ids = active_flat_messages["token_ids"]
            generation_input_data = BatchedDataDict[GenerationDatumSpec](
                {
                    "input_ids": active_input_ids,
                    "input_lengths": active_input_lengths,
                    "stop_strings": active_stop_strings,
                }
            )
            # Forward any multimodal payloads alongside the token inputs.
            multimodal_data = active_flat_messages.get_multimodal_dict(as_tensors=False)
            generation_input_data.update(multimodal_data)

            # Pass through vLLM-specific raw content/media if the batch carries them.
            if "vllm_content" in active_batch:
                generation_input_data["vllm_content"] = active_batch["vllm_content"]
            if "vllm_images" in active_batch:
                generation_input_data["vllm_images"] = active_batch["vllm_images"]
            if "vllm_videos" in active_batch:
                generation_input_data["vllm_videos"] = active_batch["vllm_videos"]

            # Async generation for the whole active sub-batch. NOTE(review):
            # presumably this appends the assistant turn to the (shared) message
            # logs, so `current_batch` sees it too — confirm against
            # generate_responses_async's contract.
            active_batch, generated_ids, gen_metrics = await generate_responses_async(
                policy_generation,
                generation_input_data,
                active_batch,
                tokenizer,
                input_lengths=active_input_lengths,
                greedy=greedy,
            )

            # Generation may flag responses clipped by the engine; record them
            # under the sample's global index.
            response_truncated = gen_metrics.pop("_response_truncated", None)
            if response_truncated is not None:
                for i, global_idx in enumerate(active_indices.tolist()):
                    if response_truncated[i]:
                        sample_truncated[global_idx] = True

            # Accumulate assistant-token counts per sample (global indexing).
            for i, global_idx in enumerate(active_indices.tolist()):
                sample_assistant_token_counts[global_idx] += len(generated_ids[i])
                sample_token_counts[global_idx] += len(generated_ids[i])

            total_gen_tokens_per_turn.append(sum(len(ids) for ids in generated_ids))

            # Batched environment step: rewards/observations for active samples.
            env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env)

            # Infer the reward layout from the first environment output:
            # 2-D rewards => multiple reward components, else a single scalar.
            if number_of_rewards is None:
                if env_output.rewards.ndim >= 2:
                    number_of_rewards = int(env_output.rewards.shape[1])
                    multi_rewards = torch.zeros(
                        batch_size, number_of_rewards, dtype=torch.float32
                    )
                else:
                    number_of_rewards = 1

            if number_of_rewards > 1:
                assert multi_rewards is not None
                multi_rewards[active_indices] += env_output.rewards
                total_rewards[active_indices] += env_output.rewards.sum(dim=1)
            else:
                total_rewards[active_indices] += env_output.rewards

            # Per-active-sample mask: True when the observation had to be clipped
            # to respect max_seq_len (indexed by local position `i`).
            truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool)
            for i, global_idx in enumerate(active_indices.tolist()):
                env_obs_content = env_output.observations[i]["content"]
                tokenized_obs = tokenizer(
                    env_obs_content, return_tensors="pt", add_special_tokens=False
                ).input_ids[0]
                tokenized_obs = tokenized_obs.to(dtype=torch.int64)

                # Would the full observation push this sample past max_seq_len?
                if (
                    len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i]
                    >= max_seq_len
                ):
                    tokens_left_for_obs = max_seq_len - (
                        len(generated_ids[i]) + active_input_lengths[i]
                    )
                    assert tokens_left_for_obs >= 0, (
                        f"tokens_left_for_obs={tokens_left_for_obs} should not be negative. This should not happen if the inference engine respects the max sequence length."
                    )
                    # Keep only as many observation tokens as fit; mark truncated.
                    tokenized_obs = tokenized_obs[:tokens_left_for_obs]
                    truncation_mask[i] = True
                    sample_truncated[active_indices[i]] = True

                # Append the (possibly clipped) environment turn to the global log.
                # NOTE(review): "content" keeps the full text even when token_ids
                # were clipped — verify this matches run_multi_turn_rollout().
                tokenized_env_obs_message = {
                    "role": env_output.observations[i]["role"],
                    "content": env_obs_content,
                    "token_ids": tokenized_obs,
                }
                current_batch["message_log"][global_idx].append(tokenized_env_obs_message)

                sample_env_token_counts[global_idx] += len(tokenized_obs)
                sample_token_counts[global_idx] += len(tokenized_obs)
                sample_turn_counts[global_idx] += 1

            # A sample is done when the env terminates it or we truncated it.
            terminateds = env_output.terminateds.bool()
            done = truncation_mask | terminateds
            sample_terminated[active_indices] |= done

            # Two-level indexing: `active_indices_local_next` are positions within
            # the current active set; map them back to global indices.
            active_indices_local_next = torch.where(~done)[0]
            active_indices = active_indices[active_indices_local_next]
            continuing_indices_global = active_indices
            continuing_next_stops = [
                env_output.next_stop_strings[i] for i in active_indices_local_next.tolist()
            ]
            continuing_metadata = [
                env_output.metadata[i] for i in active_indices_local_next.tolist()
            ]

            # Carry forward per-sample stop strings and env metadata for the
            # samples that keep going.
            for i, global_idx in enumerate(continuing_indices_global.tolist()):
                current_stop_strings[global_idx] = continuing_next_stops[i]
                if continuing_metadata[i] is not None:
                    current_batch["extra_env_info"][global_idx] = continuing_metadata[i]

        # Anything still active after the loop hit the turn limit.
        sample_max_turns_reached[active_indices] = True

        current_batch["total_reward"] = total_rewards
        current_batch["truncated"] = sample_truncated
        # Expose each reward component as its own 1-indexed column.
        if multi_rewards is not None:
            num_reward_components = multi_rewards.shape[1]
            for i in range(num_reward_components):
                current_batch[f"reward{i + 1}"] = multi_rewards[:, i].clone()

        rollout_metrics = {
            "total_turns": int(sample_turn_counts.sum().item()),
            "avg_turns_per_sample": float(sample_turn_counts.float().mean().item()),
            "max_turns_per_sample": int(sample_turn_counts.max().item()),
            "natural_termination_rate": float(sample_terminated.float().mean().item()),
            "truncation_rate": float(sample_truncated.float().mean().item()),
            "max_turns_reached_rate": float(
                sample_max_turns_reached.float().mean().item()
            ),
            "mean_total_tokens_per_sample": float(
                sample_token_counts.float().mean().item()
            ),
            "mean_gen_tokens_per_sample": float(
                sample_assistant_token_counts.float().mean().item()
            ),
            "max_gen_tokens_per_sample": float(
                sample_assistant_token_counts.float().max().item()
            ),
            "mean_env_tokens_per_sample": float(
                sample_env_token_counts.float().mean().item()
            ),
        }
        return current_batch, rollout_metrics

    return asyncio.run(_async_rollout_implementation())


async def async_generate_response_for_sample_turn(
policy_generation: GenerationInterface,
sample_message_log: list[dict],
Expand Down Expand Up @@ -872,6 +1079,9 @@ def run_async_multi_turn_rollout(

Each sample in the batch proceeds through its interaction independently.
Async generation is used internally when available but the function is synchronous.
This keeps sample-level pipelining across turns, which is useful for some
training paths, but it also means environment evaluation happens from the
per-sample loop rather than from a batched rollout loop.

Args:
policy_generation: The generation interface (policy)
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/algorithms/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,40 @@ def test_validate_function(mock_components):
# Note: validate() function itself doesn't call logger.log_metrics - that's done by the caller


def test_validate_function_uses_batched_async_generation_helper(mock_components):
    """Async distillation validation should use the batched async-generation helper."""
    # Flip the mock config into async-vLLM mode so validate() takes the async path.
    generation_cfg = mock_components["master_config"]["policy"]["generation"]
    generation_cfg["backend"] = "vllm"
    generation_cfg["vllm_cfg"] = {"async_engine": True}
    mock_components["student_generation"] = MagicMock()

    batched_helper_patch = patch(
        "nemo_rl.algorithms.distillation.run_multi_turn_rollout_async_generation"
    )
    # The per-sample helper must never run; make any call blow up immediately.
    per_sample_helper_patch = patch(
        "nemo_rl.algorithms.distillation.run_async_multi_turn_rollout",
        side_effect=AssertionError(
            "Validation should not use run_async_multi_turn_rollout"
        ),
    )
    with batched_helper_patch as batched_rollout_mock, per_sample_helper_patch:
        batched_rollout_mock.return_value = (
            next(iter(mock_components["val_dataloader"])),
            {"mean_gen_tokens_per_sample": 1.0},
        )
        validate(
            mock_components["student_generation"],
            mock_components["val_dataloader"],
            mock_components["tokenizer"],
            mock_components["val_task_to_env"],
            step=0,
            master_config=mock_components["master_config"],
        )

    batched_rollout_mock.assert_called_once()


def test_check_vocab_equality_pass(monkeypatch):
student_tokenizer = MagicMock()
student_tokenizer.get_vocab.return_value = {"a": 0, "b": 1}
Expand Down
94 changes: 94 additions & 0 deletions tests/unit/algorithms/test_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2127,6 +2127,100 @@ def test_validate_works_without_logger(self):
assert "accuracy" in val_metrics
assert "avg_length" in val_metrics

def test_validate_uses_batched_async_generation_helper(self):
"""Async validation should use the batched async-generation rollout helper."""
mock_policy_gen = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.pad_token_id = 0

mock_batch = BatchedDataDict[DatumSpec](
{
"message_log": [
[
{
"role": "user",
"content": "test1",
"token_ids": torch.tensor([1, 2, 3]),
},
{
"role": "assistant",
"content": "response1",
"token_ids": torch.tensor([4, 5, 6]),
},
]
],
"task_name": ["math"],
"extra_env_info": [{}],
"loss_multiplier": torch.tensor([1.0]),
"idx": torch.tensor([0]),
"length": torch.tensor([6]),
"total_reward": torch.tensor([1.0]),
}
)

mock_dataloader = MagicMock(spec=StatefulDataLoader)
mock_dataloader.__iter__ = MagicMock(return_value=iter([mock_batch]))

mock_env = MagicMock(spec=EnvironmentInterface)
mock_env.global_post_process_and_metrics.return_value = (mock_batch, {})

mock_config = {
"grpo": {
"max_val_samples": 10,
"val_batch_size": 1,
"max_rollout_turns": 1,
},
"policy": {
"max_total_sequence_length": 2048,
"generation": {
"temperature": 1.0,
"top_p": 1.0,
"top_k": None,
"backend": "vllm",
"colocated": {"enabled": True},
"vllm_cfg": {"async_engine": True},
},
},
"logger": {
"num_val_samples_to_print": 1,
},
}

mock_rollout_metrics = {"mean_gen_tokens_per_sample": 10.0}

with patch(
"nemo_rl.algorithms.grpo.run_multi_turn_rollout_async_generation"
) as mock_async_validation_rollout:
mock_async_validation_rollout.return_value = (
mock_batch,
mock_rollout_metrics,
)
with patch(
"nemo_rl.algorithms.grpo.run_async_multi_turn_rollout",
side_effect=AssertionError(
"Validation should not use run_async_multi_turn_rollout"
),
):
with patch(
"nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False
):
with patch(
"nemo_rl.algorithms.grpo._should_use_async_rollouts",
return_value=True,
):
with patch("nemo_rl.algorithms.grpo.print_message_log_samples"):
validate(
mock_policy_gen,
mock_dataloader,
mock_tokenizer,
{"math": mock_env},
step=5,
master_config=mock_config,
logger=None,
)

mock_async_validation_rollout.assert_called_once()

def test_validate_returns_empty_when_no_dataloader(self):
"""Test that validate returns empty dicts when no dataloader is provided."""
mock_policy_gen = MagicMock()
Expand Down
Loading
Loading