
Commit da87d77

Active sampling simplified (#1143)

Authored by mnoukhov, root, and finbarrtimbers
* tmp
* initial conversion to reward in accumulate_inference_batches
* nearly working
* first test fixes
* running, just need to test reduced logging
* test scripts, tmp commit for integration test
* update tests
* intermediate commit
* fix accumulate_inference_batches inputs
* change model to actually solve
* refill filtered prompts; move weight sync directly after update; "episode" now refers to "training episode", not "generation episode" as previously
* fix test reward fn
* cleanup and move episode to later
* allow for not having time/reward metric
* always calculate advantage (becomes the same as reward when num_responses_per_prompt is 1; just because cursor keeps complaining)
* try to fix test
* fix ground truths and datasets; makes grpo and ppo reward functions the same
* fix test; we now return k repeats of a prompt, not just 1 in the batch
* active sampling in large tests
* Update open_instruct/grpo_fast.py (Co-authored-by: Finbarr Timbers <[email protected]>)
* Update open_instruct/grpo_fast.py (Co-authored-by: Finbarr Timbers <[email protected]>)
* cursor was right
* address comments
* nit
* 32b without active sampling
* repeat each fix

---------

Co-authored-by: root <[email protected]>
Co-authored-by: Finbarr Timbers <[email protected]>
1 parent e3f4341 commit da87d77


8 files changed (+347, −245 lines)


open_instruct/grpo_fast.py

Lines changed: 251 additions & 232 deletions
Large diffs are not rendered by default.
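
The grpo_fast.py diff itself is not rendered, but the updated tests further down show the new call shape: accumulate_inference_batches now receives the tokenizer and the reward function, computes rewards during accumulation, and returns reward metrics and batch statistics instead of prompt/response lengths. A rough sketch inferred only from those tests (argument values are placeholders, not the actual call site):

combined_result, batch, reward_metrics, batch_stats = grpo_fast.accumulate_inference_batches(
    inference_results_Q,
    pending_queries_map,
    args,
    generation_config=generation_config,
    num_prompts=num_prompts,
    model_dims=model_dims,
    tokenizer=tokenizer,
    reward_fn=reward_fn,
)
# Per the updated assertions, batch.queries now has one entry per sampled response
# (num_prompts * num_samples_per_prompt), not one per unique prompt.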

open_instruct/model_utils.py

Lines changed: 6 additions & 3 deletions
@@ -55,7 +55,9 @@ class Batch:
     ground_truths: list[list[int]]
     datasets: list[str]
     raw_queries: list[str] | None
+    decoded_responses: list[str] | None
     indices: list[int] | None
+    scores: list[float] | None

     def __getitem__(self, key: slice | int | list[int]) -> "Batch":
         """Enable indexing and slicing: batch[5], batch[start:end], or batch[[1,3,5]]."""
@@ -262,7 +264,8 @@ async def apply_verifiable_reward(
     reward_fn_mapping: dict[str, VerifierFunction],
     responses: list[torch.Tensor],
     decoded_responses: list[str],
-    batch: Batch,
+    ground_truths: list[float],
+    datasets: list[str],
     reward_mult: int = 10,
     queries: list[str] | None = None,
 ):
@@ -274,7 +277,7 @@ async def apply_verifiable_reward(
     task_metadata = []

     for i, (tok_prediction, prediction, ground_truth, dataset, query) in enumerate(
-        zip(responses, decoded_responses, batch.ground_truths, batch.datasets, queries)
+        zip(responses, decoded_responses, ground_truths, datasets, queries)
     ):
         # allow multiple ground truths and datasets for a single response

@@ -308,7 +311,7 @@ async def apply_verifiable_reward(
     # Execute all tasks in parallel
     if async_tasks:
         reward_results = await asyncio.gather(*async_tasks)
-        logger.info(f"Applied {len(reward_results)} ground truth rewards in parallel 🤗")
+        logger.debug(f"Applied {len(reward_results)} ground truth rewards in parallel 🤗")
     else:
         reward_results = []
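For context, the Batch dataclass now carries per-response decoded_responses and scores alongside the prompt-level fields. A minimal usage sketch, assuming the seven fields used in the updated tests are the complete constructor (values are illustrative only):

from open_instruct import model_utils

batch = model_utils.Batch(
    queries=[[101, 102], [103, 104]],            # tokenized prompts (illustrative)
    ground_truths=[[42], [7]],
    datasets=["gsm8k", "gsm8k"],                 # hypothetical verifier keys
    raw_queries=["What is 6 * 7?", "What is 3 + 4?"],
    indices=[0, 1],
    decoded_responses=["42", "7"],               # new field: decoded model outputs
    scores=[1.0, 0.0],                           # new field: per-response rewards
)

# __getitem__ supports int, slice, and list indexing, per its docstring.
first_prompt = batch[0]
subset = batch[[0, 1]]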
open_instruct/test_grpo_fast.py

Lines changed: 54 additions & 5 deletions
@@ -4,6 +4,7 @@
 import threading
 import time
 import unittest
+from typing import Any
 from unittest.mock import MagicMock, Mock

 import numpy as np
@@ -198,6 +199,26 @@ def create_mock_result(self, dataset_index, epoch_number, num_samples_per_prompt
             logprobs=[[0.0, 0.0, 0.0] for _ in range(total_responses)],
         )

+    def create_mock_tokenizer_and_reward_fn(self):
+        # Set up dummy tokenizer
+        tokenizer_name = "EleutherAI/pythia-14m"  # Using a small model for testing
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+        # Set up dummy reward fn that will guarantee nonzero std
+        async def reward_fn(
+            responses: list[torch.Tensor],
+            decoded_responses: list[str],
+            ground_truths: list[Any],
+            datasets: list[str],
+            finish_reasons: list[str],
+            infos: list[list[int]],
+            queries: list[str] | None = None,
+        ) -> (list[float], dict[str, Any]):
+            num_responses = len(responses)
+            return [i / num_responses for i in range(num_responses)], {"time/reward": 0.0}
+
+        return tokenizer, reward_fn
+
     def setup_and_split_batch(
         self, queries, ground_truths, datasets, raw_queries, indices, num_engines, training_step=1
     ):
@@ -212,7 +233,13 @@ def setup_and_split_batch(
         self._ray_queues.extend([param_prompt_Q, inference_results_Q])

         batch = model_utils.Batch(
-            queries=queries, ground_truths=ground_truths, datasets=datasets, raw_queries=raw_queries, indices=indices
+            queries=queries,
+            ground_truths=ground_truths,
+            datasets=datasets,
+            raw_queries=raw_queries,
+            indices=indices,
+            decoded_responses=None,
+            scores=None,
         )

         mock_generation_config = MagicMock()
@@ -558,6 +585,9 @@ def test_out_of_order_processing(self):
         # Create test data
         queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts)

+        # Create mock tokenizer and reward
+        tokenizer, reward_fn = self.create_mock_tokenizer_and_reward_fn()
+
         # Setup and split batch
         param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_split_batch(
             queries, ground_truths, datasets, raw_queries, indices, num_engines
@@ -580,17 +610,19 @@
         mock_generation_config.n = num_samples_per_prompt

         mock_model_dims = self.create_mock_model_dims()
-        combined_result, batch, prompt_lengths, response_lengths = grpo_fast.accumulate_inference_batches(
+        combined_result, batch, reward_metrics, batch_stats = grpo_fast.accumulate_inference_batches(
             inference_results_Q,
             pending_queries_map,
             mock_args,
             generation_config=mock_generation_config,
             num_prompts=num_prompts,
             model_dims=mock_model_dims,
+            tokenizer=tokenizer,
+            reward_fn=reward_fn,
         )

         # Verify results work correctly even with out-of-order processing
-        self.assertEqual(len(batch.queries), num_prompts)
+        self.assertEqual(len(batch.queries), num_prompts * num_samples_per_prompt)
         self.assertEqual(len(combined_result.responses), num_prompts * num_samples_per_prompt)
         self.assertEqual(len(pending_queries_map), 0)

@@ -643,6 +675,9 @@ def test_accumulate_waits_for_all_engines(self):
         num_engines = 4
         num_prompts = 16

+        # Create mock tokenizer and reward
+        tokenizer, reward_fn = self.create_mock_tokenizer_and_reward_fn()
+
         # Setup with results from only 3 engines
         # Queue size must be large enough for all results being put before accumulation starts
         expected_results = 3 * (num_prompts // num_engines)  # 3 engines * 4 results each = 12
@@ -682,6 +717,8 @@ def run_accumulate():
                     generation_config=mock_generation_config,
                     num_prompts=num_prompts,
                     model_dims=mock_model_dims,
+                    tokenizer=tokenizer,
+                    reward_fn=reward_fn,
                 )
                 completed.set()
             except Exception:
@@ -717,7 +754,13 @@ def test_more_engines_than_queries(self):
         self._ray_queues.append(param_prompt_Q)

         batch = model_utils.Batch(
-            queries=queries, ground_truths=ground_truths, datasets=datasets, raw_queries=raw_queries, indices=indices
+            queries=queries,
+            ground_truths=ground_truths,
+            datasets=datasets,
+            raw_queries=raw_queries,
+            indices=indices,
+            decoded_responses=None,
+            scores=None,
         )

         mock_generation_config = MagicMock()
@@ -768,7 +811,13 @@ def test_uneven_distribution_no_empty_batches(self):
         self._ray_queues.append(param_prompt_Q)

         batch = model_utils.Batch(
-            queries=queries, ground_truths=ground_truths, datasets=datasets, raw_queries=raw_queries, indices=indices
+            queries=queries,
+            ground_truths=ground_truths,
+            datasets=datasets,
+            raw_queries=raw_queries,
+            indices=indices,
+            decoded_responses=None,
+            scores=None,
         )

         mock_generation_config = MagicMock()
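
For reference, a quick sketch (not part of the diff) of what the dummy reward_fn from create_mock_tokenizer_and_reward_fn returns: with four responses the scores are evenly spaced, which is what guarantees a nonzero within-group standard deviation.

import asyncio

# tokenizer, reward_fn = self.create_mock_tokenizer_and_reward_fn()  (inside a test case)
# The mock only inspects len(responses), so placeholder values are enough here.
scores, metrics = asyncio.run(
    reward_fn(
        responses=[None, None, None, None],
        decoded_responses=["a", "b", "c", "d"],
        ground_truths=[None] * 4,
        datasets=["gsm8k"] * 4,
        finish_reasons=["stop"] * 4,
        infos=[[]] * 4,
    )
)
assert scores == [0.0, 0.25, 0.5, 0.75]
assert metrics == {"time/reward": 0.0}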

open_instruct/utils.py

Lines changed: 27 additions & 0 deletions
@@ -43,6 +43,7 @@
 import sys
 import threading
 import time
+from collections import defaultdict
 from collections.abc import Iterable
 from concurrent import futures
 from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p
@@ -2366,3 +2367,29 @@ def check_calculation(
     )

     logger.warning(warning_message)
+
+
+def combine_reward_metrics(reward_metrics: list[dict[str, Any]]) -> dict[str, Any]:
+    """Assumes same number of metric_records in each dict in the list"""
+    buckets = defaultdict(list)
+    for metrics in reward_metrics:
+        for key, value in metrics.items():
+            buckets[key].append(value)
+
+    combined: dict[str, Any] = {}
+    for key, records in buckets.items():
+        sample_value = records[0]
+        if isinstance(sample_value, np.ndarray):
+            combined[key] = [x for value in records for x in value]
+        elif isinstance(sample_value, (list | tuple)):
+            concatenated: list[Any] = []
+            for value in records:
+                concatenated.extend(list(value))
+            combined[key] = concatenated
+        elif isinstance(sample_value, (int | float | bool | np.integer | np.floating)):
+            # combine and get average value
+            combined[key] = sum(value for value in records) / len(records) if len(records) > 0 else sample_value
+        else:
+            # Fallback: keep the latest value if aggregation strategy is unclear.
+            combined[key] = records[-1]
+    return combined
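
A brief usage sketch for the new helper (the objective/per_prompt_scores key is made up for illustration; time/reward appears in the mock reward function above): scalar values are averaged across the input dicts, while list, tuple, and array values are concatenated.

from open_instruct.utils import combine_reward_metrics

combined = combine_reward_metrics(
    [
        {"time/reward": 0.25, "objective/per_prompt_scores": [1.0, 0.0]},
        {"time/reward": 0.75, "objective/per_prompt_scores": [0.5, 1.0]},
    ]
)
# Scalars are averaged, lists are concatenated:
# {"time/reward": 0.5, "objective/per_prompt_scores": [1.0, 0.0, 0.5, 1.0]}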

scripts/train/debug/grpo_fast.sh

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,6 @@ uv run python open_instruct/grpo_fast.py \
     --num_samples_per_prompt_rollout 4 \
     --model_name_or_path Qwen/Qwen3-0.6B \
     --stop_strings "</answer>" \
-    --apply_r1_style_format_reward \
     --apply_verifiable_reward true \
     --temperature 0.7 \
     --ground_truths_key ground_truth \
@@ -36,4 +35,5 @@ uv run python open_instruct/grpo_fast.py \
     --single_gpu_mode \
     --push_to_hub false \
     --system_prompt_override_file scripts/train/debug/cute_debug_system_prompt.txt \
+    --active_sampling --async_steps 8
     # --with_tracking

scripts/train/debug/large_test_script.sh

Lines changed: 2 additions & 0 deletions
@@ -61,4 +61,6 @@ uv run python mason.py \
     --oe_eval_max_length 32768 \
     --oe_eval_tasks "codex_humanevalplus:0-shot-chat-v1::tulu-thinker,mbppplus:0-shot-chat::tulu-thinker,livecodebench_codegeneration::tulu-thinker" \
     --dataset_skip_cache True \
+    --active_sampling \
+    --async_steps 4 \
     --push_to_hub False

scripts/train/debug/single_gpu_integration_test.sh

Lines changed: 5 additions & 3 deletions
@@ -28,12 +28,12 @@ uv run python mason.py \
     --dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 16 \
     --dataset_mixer_eval_list_splits train \
     --max_prompt_token_length 512 \
-    --response_length 512 \
-    --pack_length 1024 \
+    --response_length 1024 \
+    --pack_length 2048 \
     --per_device_train_batch_size 1 \
     --num_unique_prompts_rollout 8 \
     --num_samples_per_prompt_rollout 4 \
-    --model_name_or_path Qwen/Qwen3-1.7B \
+    --model_name_or_path Qwen/Qwen2.5-0.5B \
     --stop_strings "</answer>" \
     --apply_r1_style_format_reward \
     --apply_verifiable_reward true \
@@ -55,4 +55,6 @@ uv run python mason.py \
     --vllm_enforce_eager \
     --gradient_checkpointing \
     --push_to_hub false \
+    --active_sampling \
+    --async_steps 8 \
     --single_gpu_mode

scripts/train/olmo3/32b_rl_smoke_test.sh

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ uv run python mason.py \
     --verbose False \
     --ground_truths_key ground_truth \
     --sft_messages_key messages \
-    --total_episodes 200_000 \
+    --total_episodes 10240 \
     --gather_whole_model False \
     --deepspeed_stage 3 \
     --num_learners_per_node 8 8 8 \

0 commit comments