
Commit 0cc2e29

Merge branch 'allenai:main' into main
2 parents: bd4250a + 1986faa

24 files changed: +1289 −609 lines

.github/workflows/quality.yml

Lines changed: 3 additions & 9 deletions
@@ -17,14 +17,8 @@ jobs:
       - name: Install uv
         uses: astral-sh/setup-uv@v4
         with:
-          version: "0.5.11"
-      - name: Set up Python
-        run: uv python install 3.10
-      - name: Install dependencies
-        run: uv sync --frozen --only-group dev
+          version: "0.8.8"
       - name: Code quality
         run: |
-          source .venv/bin/activate
-          ruff format --check --diff open_instruct
-          ruff check --exit-non-zero-on-fix open_instruct
-
+          uv run ruff format --check --diff open_instruct mason.py
+          uv run ruff check --exit-non-zero-on-fix open_instruct mason.py

mason.py

Lines changed: 24 additions & 22 deletions
@@ -13,6 +13,8 @@
 from rich.console import Console
 from rich.text import Text
 
+from open_instruct.utils import GCP_CLUSTERS, INTERCONNECT_CLUSTERS, WEKA_CLUSTERS
+
 console = Console()
 
 
@@ -87,11 +89,6 @@ def parse_env_var(env_var_str: str) -> dict[str, str]:
     return {"name": name, "value": value}
 
 
-WEKA_CLUSTERS = ["ai2/jupiter", "ai2/saturn", "ai2/titan", "ai2/neptune", "ai2/ceres", "ai2/triton", "ai2/rhea"]
-GCP_CLUSTERS = ["ai2/augusta"]
-
-INTERCONNECT_CLUSTERS = ["ai2/jupiter", "ai2/ceres", "ai2/titan", "ai2/augusta"]
-
 # by default, we turn off vllm compile cache
 # torch compile caching seems consistently broken, but the actual compiling isn't.
 # Not sure why, for now we have disabled the caching (VLLM_DISABLE_COMPILE_CACHE=1).
@@ -589,24 +586,29 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
                 model_revision = command[idx + 1]
                 break
 
-        commit_hash = get_commit_hash(model_name_or_path, model_revision, "config.json", "model")
-        if os.path.exists(model_name_or_path):
-            path = model_name_or_path
-            assert args.gs_model_name is not None, "for local models to upload to gs, you must set --gs_model_name"
-            model_name_or_path = args.gs_model_name
-            commit_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
-            console.log(
-                f"Local model is already downloaded, using gs_model_name {model_name_or_path}, with hash of model path {commit_hash}"
-            )
+        if model_name_or_path.startswith("gs://"):
+            gs_saved_path = model_name_or_path
         else:
-            download_from_hf(model_name_or_path, model_revision)  # first download the model
-            path = download_from_hf(model_name_or_path, model_revision)  # then get the path
-        gs_saved_path = f"gs://ai2-llm/post-training/deletable_cache_models/{model_name_or_path}/{commit_hash}"
-        gs_folder = gs_folder_exists(
-            gs_saved_path
-        )  # race condition exists, but it's fine since we are launching mason sequentially
-        if not gs_folder:
-            upload_to_gs_bucket(path, gs_saved_path)
+            commit_hash = get_commit_hash(model_name_or_path, model_revision, "config.json", "model")
+            if os.path.exists(model_name_or_path):
+                path = model_name_or_path
+                assert args.gs_model_name is not None, (
+                    "for local models to upload to gs, you must set --gs_model_name"
+                )
+                model_name_or_path = args.gs_model_name
+                commit_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
+                console.log(
+                    f"Local model is already downloaded, using gs_model_name {model_name_or_path}, with hash of model path {commit_hash}"
+                )
+            else:
+                download_from_hf(model_name_or_path, model_revision)  # first download the model
+                path = download_from_hf(model_name_or_path, model_revision)  # then get the path
+            gs_saved_path = f"gs://ai2-llm/post-training/deletable_cache_models/{model_name_or_path}/{commit_hash}"
+            gs_folder = gs_folder_exists(
+                gs_saved_path
+            )  # race condition exists, but it's fine since we are launching mason sequentially
+            if not gs_folder:
+                upload_to_gs_bucket(path, gs_saved_path)
 
         download_path = gs_saved_path.replace("gs://", "/gs/")
         download_path_without_last_folder = download_path.rsplit("/", 1)[0]
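
Note on the hunk above: the restructured flow uses a gs:// path as-is, keys a local checkpoint by an explicit --gs_model_name, and otherwise downloads from Hugging Face and uploads the model to the GCS cache once. A rough standalone sketch of that flow, assuming the helpers defined in mason.py (get_commit_hash, download_from_hf, gs_folder_exists, upload_to_gs_bucket) are in scope; resolve_gs_model_path itself is a hypothetical wrapper for illustration, not a function in the diff:

    import hashlib
    import os

    def resolve_gs_model_path(model_name_or_path: str, model_revision: str, gs_model_name: str | None) -> str:
        # Hypothetical wrapper around the logic shown in the hunk above.
        if model_name_or_path.startswith("gs://"):
            # Already a GCS object: nothing to hash, download, or upload.
            return model_name_or_path
        commit_hash = get_commit_hash(model_name_or_path, model_revision, "config.json", "model")
        if os.path.exists(model_name_or_path):
            # Local checkpoint: an explicit --gs_model_name supplies the cache key.
            assert gs_model_name is not None, "for local models to upload to gs, you must set --gs_model_name"
            path = model_name_or_path
            model_name_or_path = gs_model_name
            commit_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
        else:
            # Hugging Face model: download it, then upload a copy to the GCS cache.
            path = download_from_hf(model_name_or_path, model_revision)
        gs_saved_path = f"gs://ai2-llm/post-training/deletable_cache_models/{model_name_or_path}/{commit_hash}"
        if not gs_folder_exists(gs_saved_path):  # benign race: mason launches jobs sequentially
            upload_to_gs_bucket(path, gs_saved_path)
        return gs_saved_path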

open_instruct/actor_manager.py

Lines changed: 3 additions & 3 deletions
@@ -110,9 +110,9 @@ async def api_status():
                 "queues": queues_data,
                 "token_stats": self.get_token_stats(),
                 "timing_stats": self.get_timing_stats(),
-                "kv_cache_max_concurrency": self._kv_cache_max_concurrency,
-                # This is less confusing to users.
-                "inference_batch_size": self._args.inference_batch_size * self._args.num_samples_per_prompt_rollout,
+                "concurrency_per_engine": self._kv_cache_max_concurrency,
+                "total_concurrency": self._kv_cache_max_concurrency * self._args.vllm_num_engines,
+                "batch_size": self._args.num_unique_prompts_rollout * self._args.num_samples_per_prompt_rollout,
             }
 
         def run_server():
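
The renamed status fields are plain products of existing settings: per-engine KV-cache concurrency times the number of vLLM engines gives the total concurrency, and unique prompts per rollout times samples per prompt gives the batch size. A minimal, self-contained sketch of the same arithmetic (the function name and example values are illustrative, not part of the diff):

    def derived_status_fields(
        kv_cache_max_concurrency: int,
        vllm_num_engines: int,
        num_unique_prompts_rollout: int,
        num_samples_per_prompt_rollout: int,
    ) -> dict[str, int]:
        # Recompute the three /api/status fields added in the hunk above.
        return {
            "concurrency_per_engine": kv_cache_max_concurrency,
            "total_concurrency": kv_cache_max_concurrency * vllm_num_engines,
            "batch_size": num_unique_prompts_rollout * num_samples_per_prompt_rollout,
        }

    # e.g. 8 engines each able to schedule 512 sequences, 128 prompts x 16 samples:
    # {'concurrency_per_engine': 512, 'total_concurrency': 4096, 'batch_size': 2048}
    print(derived_status_fields(512, 8, 128, 16))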

open_instruct/benchmark_generators.py

Lines changed: 0 additions & 2 deletions
@@ -263,8 +263,6 @@ def setup_vllm_engines(
         prompt_queue=param_prompt_Q,
         results_queue=inference_results_Q,
         actor_manager=actor_manager,
-        inference_batch_size=args.inference_batch_size,
-        use_fp8_kv_cache=args.use_fp8_kv_cache,
         inflight_updates=args.inflight_updates,
     )
 

open_instruct/dataset_transformation.py

Lines changed: 45 additions & 0 deletions
@@ -438,6 +438,51 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai
         "{% endif %}"
         "{% endfor %}"
     ),
+    "olmo_thinker_remove_intermediate_thinking": (
+        "{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}"
+        "{% if not has_system %}"
+        "{{ '<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n' }}"
+        "{% endif %}"
+        "{% for message in messages %}"
+        "{% if message['role'] == 'system' %}"
+        "{{ '<|im_start|>system\n' + message['content'] }}"
+        "{% if message.get('functions', none) is not none %}"
+        "{{ ' <functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
+        "{% else %}"
+        "{{ ' You do not currently have access to any functions. <functions></functions><|im_end|>\n' }}"
+        "{% endif %}"
+        "{% elif message['role'] == 'user' %}"
+        "{% if message.get('functions', none) is not none %}"
+        "{{ '<|im_start|>user\n' + message['content'] + '\n' + '<functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
+        "{% else %}"
+        "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}"
+        "{% endif %}"
+        "{% elif message['role'] == 'assistant' %}"
+        "{{ '<|im_start|>assistant\n' }}"
+        "{% set content = message.get('content', none) %}"
+        "{% if content is not none %}"
+        "{% set content = content | string %}"
+        "{% if not loop.last and '</think>' in content and '<think>' in content %}"
+        "{% set content = content.split('</think>')[-1].lstrip('\\n') %}"
+        "{% endif %}"
+        "{{ content }}"
+        "{% endif %}"
+        "{% if message.get('function_calls', none) is not none %}"
+        "{{ '<function_calls>' + message['function_calls'] + '</function_calls>' }}"
+        "{% endif %}"
+        "{% if not loop.last %}"
+        "{{ '<|im_end|>' + '\n' }}"
+        "{% else %}"
+        "{{ eos_token }}"
+        "{% endif %}"
+        "{% elif message['role'] == 'environment' %}"
+        "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}"
+        "{% endif %}"
+        "{% if loop.last and add_generation_prompt %}"
+        "{{ '<|im_start|>assistant\n<think>' }}"
+        "{% endif %}"
+        "{% endfor %}"
+    ),
     "olmo_thinker_no_think_sft_tokenization": (
         "{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}"
        "{% if not has_system %}"
