
ROCm: default GPT-OSS to BF16 and disable AITER#4021

Closed
danielhanchen wants to merge 3380 commits into main from rocm-gpt-oss-bf16-aiter

Conversation

@danielhanchen
Contributor

danielhanchen commented Feb 10, 2026

Summary

  • Default GPT-OSS model selection to BF16 on HIP to avoid MXFP4 and prequantized blocksize issues
  • Disable AITER and ROCm RoPE backend by default on HIP to avoid build locks and runtime faults

Testing

  • gpt-oss-(20B)-GRPO.ipynb (30 steps)
  • gpt-oss-(20B)-Fine-tuning.ipynb (30 steps)
  • Gemma3_(4B)-Vision.ipynb (30 steps)
  • Llama3.2_(1B_and_3B)-Conversational.ipynb (60 steps)

danielhanchen and others added 30 commits December 16, 2025 15:46
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update _utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [FIX] [Transformers] VLM input embeds fix for gradients (#3715)

* Fix get_input_embeds call for VLMs

* patch input_require_grads instead

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup old patch

* cleanup old patch

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestion from @danielhanchen

* use logger instead of prints

* Move unsloth present set

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <[email protected]>

* Update rope_embedding.py

* Fixes

* Update _utils.py

* Update import_fixes.py

* Update rl_replacements.py

* fix_openenv_no_vllm

* Fix

* Update __init__.py

* Update __init__.py

* Update __init__.py

* Update import_fixes.py

* Update import_fixes.py

* Update import_fixes.py

* logger

* Update __init__.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update __init__.py

* Update import_fixes.py

* Update __init__.py

* Update import_fixes.py

* Update import_fixes.py

* Update import_fixes.py

* Update import_fixes.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update import_fixes.py

* Update unsloth/import_fixes.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Datta Nimmaturi <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update save.py

* [fbgemm] Silence tma fbgemm (#3735)

* Silence fbgemm TMA print

Also safer .push_to_hub

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Update loader.py

* Update save.py

* Update save.py

* Update _utils.py

* Update _utils.py

* Diffusers warnings

* Update pyproject.toml

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Datta Nimmaturi <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@danielhanchen
Contributor Author

Validation update for ROCm:

  • Pushed commit 3daaff6b to guard patch_gradient_accumulation_fix when the generated _unsloth___init__ is missing.
  • Revalidated with temp/run_229_verify_pr4021_llama3_2_conv on ROCm (Llama3.2_1B_and_3B_Conversational.py, 60 steps).
  • Training completed with a sane loss trend (train_loss=0.8569) and inference outputs remained coherent.
  • Confirmed no leftover run process and temporary outputs/*lora* artifacts were cleaned automatically.

NVIDIA compatibility notes:

  • Existing non-ROCm behavior in make_fast_generate_wrapper remains unchanged via DEVICE_TYPE == "hip" guards.
  • The new init patch guard is platform-agnostic; it only changes failure handling from a hard crash to a safe skip when source patch generation fails (the guard pattern is sketched below).
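
As a rough sketch of that guard, assuming the patcher exec()s generated source into a local_scope dict that should define _unsloth___init__ (the helper name and logging here are illustrative, not the actual code in models/_utils.py):

```python
import logging

logger = logging.getLogger("unsloth")

def get_patched_trainer_init(local_scope: dict):
    # The generated source patch is expected to define `_unsloth___init__`
    # inside `local_scope`; if generation failed, skip instead of raising.
    patched = local_scope.get("_unsloth___init__")
    if patched is None:
        logger.warning(
            "Unsloth: generated trainer init patch missing; "
            "skipping gradient accumulation fix."
        )
    return patched  # caller applies the patch only when this is not None
```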

@danielhanchen
Contributor Author

Recheck update for PR #4021 against current ROCm notebook reruns.

Re-evaluated items

  • AITER defaults in device_type.py (AITER_DISABLE=1 and USE_ROCM_AITER_ROPE_BACKEND=0 via setdefault; sketched below).
  • Trainer init patch guard in models/_utils.py (local_scope + generated function fallback).
  • TorchAO Int4WeightOnlyConfig(version=2) with TypeError fallback.
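
For illustration, a minimal sketch of the setdefault-based defaulting from the first item; DEVICE_TYPE stands in for unsloth's device detection, and the variable names are taken from this PR's description:

```python
import os

# Stand-in for unsloth's detected device type ("cuda", "hip", "xpu", ...).
DEVICE_TYPE = "hip"

if DEVICE_TYPE == "hip":
    # setdefault only writes a value when the variable is unset, so an
    # explicit user override (e.g. AITER_DISABLE=0) is never clobbered.
    os.environ.setdefault("AITER_DISABLE", "1")
    os.environ.setdefault("USE_ROCM_AITER_ROPE_BACKEND", "0")
```

This is exactly what the run_236 override check below exercises: a variable set before import wins over the default.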

Findings

  • AITER override path works correctly: an explicit env override is honored and is not clobbered by the defaults.
    • Run: temp/run_236_aiter_on_llama3_2_conv
    • Result: SUCCESS with sane inference output.
  • Trainer init guard remains defensive only; no regression observed in reruns.
  • TorchAO fallback remains compatible with current torchao and protects older versions (sketched below).
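
A sketch of the TypeError fallback pattern for the third item, assuming group_size=128 purely for illustration:

```python
from torchao.quantization import Int4WeightOnlyConfig

try:
    # Newer torchao releases accept the `version` keyword.
    config = Int4WeightOnlyConfig(group_size=128, version=2)
except TypeError:
    # Older torchao does not know `version`; fall back gracefully.
    config = Int4WeightOnlyConfig(group_size=128)
```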


@danielhanchen
Contributor Author

Follow-up adjustment: HIP GPT-OSS routing is now capability-based, not blanket BF16.

What changed

  • Updated unsloth/models/loader.py in both loader paths (FastLanguageModel.from_pretrained and FastModel.from_pretrained).
  • Previous behavior forced BF16 for HIP GPT-OSS unconditionally.
  • New behavior checks device capability flags from device_type.py:
    • ALLOW_BITSANDBYTES
    • ALLOW_PREQUANTIZED_MODELS
  • BF16 fallback now applies only when prequantized GPT-OSS is not usable.

Logic

  • wants_prequantized = load_in_4bit or the model_name has a bnb-4bit suffix
  • can_use_prequantized = ALLOW_BITSANDBYTES and ALLOW_PREQUANTIZED_MODELS
  • Fall back to BF16 only if not (wants_prequantized and can_use_prequantized); see the sketch below.
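
A hypothetical reconstruction of this routing; the capability flags normally come from device_type.py, and the exact BF16 name rewrite is simplified:

```python
# Stand-ins for the capability flags from unsloth/device_type.py.
ALLOW_BITSANDBYTES = False
ALLOW_PREQUANTIZED_MODELS = False

def _route_hip_gpt_oss_model(model_name: str, load_in_4bit: bool):
    """Return (model_name, load_in_4bit), falling back to BF16 when the
    prequantized GPT-OSS path is not usable on this HIP device."""
    wants_prequantized = load_in_4bit or model_name.endswith("-bnb-4bit")
    can_use_prequantized = ALLOW_BITSANDBYTES and ALLOW_PREQUANTIZED_MODELS
    if wants_prequantized and can_use_prequantized:
        return model_name, load_in_4bit  # prequantized path is usable
    # Fallback: strip the 4-bit suffix and load BF16 weights instead.
    return model_name.removesuffix("-bnb-4bit"), False
```

On the test machine described in the Note below, both flags are False, so the BF16 branch is taken.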

Why

  • Radeon setups can support prequantized GPT-OSS 4bit.
  • Instinct/MI-class setups can require BF16 fallback.
  • This keeps behavior aligned with device_type.py capability detection.

Validation (ROCm)

  • temp/run_242_postloadercap_gpt_oss_grpo: SUCCESS
    • losses [0.0, 0.0]
    • grad norms [0.264757..., 3.079521...]
    • reward columns populated and varying.
  • temp/run_243_postloadercap_llama3_2_conv: SUCCESS
    • losses [0.6471, 0.6471, 0.6443, 0.6324, 0.5953] (decreasing)
    • grad norms logged and stable.

Note

  • On this test machine, ALLOW_PREQUANTIZED_MODELS=False, so BF16 fallback remains active here as expected.
  • The new gating enables prequantized path automatically on HIP devices where capability flags permit it.

Commit

  • 24c7f2ee


@danielhanchen
Contributor Author

Recheck complete. I cleaned up the PR to remove unintended fallback behavior that made the diff look botched.

What I removed

  • HIP-specific fallback in loader.py that silently downgraded fast_inference=True when vllm is missing.
  • HIP-specific behavior in make_fast_generate_wrapper that auto-dropped sampling_params / lora_request and auto-tokenized string inputs.

What remains intentionally in this PR

  • device_type.py: ROCm AITER defaults via setdefault so explicit user env overrides still win.
  • models/loader.py: capability-gated GPT-OSS BF16 fallback using ALLOW_BITSANDBYTES + ALLOW_PREQUANTIZED_MODELS (Radeon can still use prequantized checkpoints where supported).
  • models/_utils.py: defensive Trainer init patch guard + TorchAO Int4WeightOnlyConfig(version=2) with TypeError fallback for older TorchAO.

Commit pushed

  • a211e8c7

I will continue the AMD notebook reruns and keep logging outcomes in RUN_DETAILS.csv and TRANSCRIPT.md.

@danielhanchen
Contributor Author

Applied a fresh cleanup/refactor for the duplicated HIP GPT-OSS routing block in loader.py.

What changed

  • Introduced shared helper: _route_hip_gpt_oss_model(...).
  • Replaced both inline copies with helper calls in:
    • FastLanguageModel.from_pretrained
    • FastModel.from_pretrained

Behavior

  • No routing logic changes intended.
  • Same capability-gated behavior (ALLOW_BITSANDBYTES + ALLOW_PREQUANTIZED_MODELS) is preserved.
  • Kept both call sites because both loader entrypoints are valid and used.

Validation

  • python -m py_compile unsloth/models/loader.py passed in the ROCm container.

Commit

  • b56ab6ae

@danielhanchen
Contributor Author

Follow-up cleanup applied:

  • Moved _route_hip_gpt_oss_model(...) to the footer of unsloth/models/loader.py for readability.
  • No behavior changes; call sites remain unchanged.

Safety note

  • This does not break runtime resolution in Python for this case because the helper is looked up when from_pretrained executes, after module import completes (demonstrated below).
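
A tiny self-contained demonstration of that lookup order (names are illustrative):

```python
def load():
    # `_helper` is resolved when load() runs, not when load() is defined,
    # so defining the helper later in the module is safe.
    return _helper()

def _helper():
    return "ok"

print(load())  # prints "ok": module import finished before the call
```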

Validation

  • python -m py_compile unsloth/models/loader.py passed in the ROCm container.
  • Import check passed with helper present (helper_end_exists=True).

Commit

  • 4f138acb

@danielhanchen
Contributor Author

Quick status update from the latest ROCm notebook cycle.

  • No new unsloth code changes were required in this pass.
  • The newly identified blocker was in GRPO VLM log-prob handling and was fixed in unsloth-zoo PR #494 (c312e59).
  • Verification runs this cycle:
    • temp/run_261_Qwen3_VL_8B_Vision_nocompile_rerun: SUCCESS
    • temp/run_267_Gemma3_4B_Vision_GRPO_nocompile_no_flex_logitsfix: SUCCESS (30/30 GRPO steps)

I will continue the remaining notebook sweep and report any unsloth-side regressions if they appear.


@danielhanchen
Copy link
Contributor Author

Follow-up AMD notebook stability fixes are now pushed.

  • Commit: 9956d1d4
  • Files:
    • unsloth/__init__.py
    • unsloth/models/vision.py
    • unsloth/save.py

What changed:

  • Added is_bf16_supported() definitions for the HIP/XPU branches so downstream checks have a stable callable (sketched below).
  • Hooked Deepseek OCR patch invocation in FastVisionModel.from_pretrained() with a guarded import/call path.
  • Added an offline GGUF guard path in save.py (UNSLOTH_GGUF_OFFLINE=1) so Ollama notebook flows do not fail hard when GGUF conversion dependencies are unavailable.
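
Rough sketches of the first and third items above; the actual branch logic in unsloth/__init__.py and save.py may differ:

```python
import os
import torch

def is_bf16_supported() -> bool:
    # ROCm wheels expose HIP devices through the torch.cuda namespace.
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# Offline GGUF guard, using the flag name from the list above: skip the
# conversion instead of failing hard when GGUF tooling is unavailable.
if os.environ.get("UNSLOTH_GGUF_OFFLINE", "0") == "1":
    print("Unsloth: UNSLOTH_GGUF_OFFLINE=1 set, skipping GGUF conversion.")
```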

Validation evidence (ROCm runs):

  • Deepseek OCR fix path: temp/run_247_Deepseek_OCR_3B (fail) -> temp/run_249_Deepseek_OCR_3B_rerun (success)
  • Ollama flow: temp/run_271_Llama3_8B_Ollama (success)
  • Full tracked notebook status is now green in RUN_DETAILS.csv (latest row per notebook = SUCCESS for all 25 notebooks)

@danielhanchen
Contributor Author

Additional follow-up pushed:

  • Commit: f0da8260
  • File: unsloth/kernels/utils.py

Change:

  • Reinitialize the global dequant weight buffer when the dtype changes (WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype).
  • This avoids mixed bf16/fp16 reuse mismatches in global dequant buffers; see the sketch below.
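
A minimal sketch of the dtype-aware reuse check; buffer sizing and device handling are simplified relative to unsloth/kernels/utils.py:

```python
import torch

WEIGHT_BUFFER = None  # module-level dequant scratch buffer

def get_weight_buffer(numel: int, dtype: torch.dtype, device="cuda"):
    global WEIGHT_BUFFER
    # Reallocate when the buffer is missing, too small, or (the fix in
    # this commit) was created with a different dtype (bf16 vs fp16).
    if (
        WEIGHT_BUFFER is None
        or WEIGHT_BUFFER.dtype != dtype
        or WEIGHT_BUFFER.numel() < numel
    ):
        WEIGHT_BUFFER = torch.empty(numel, dtype=dtype, device=device)
    return WEIGHT_BUFFER[:numel]
```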


Note on PR #4029:

  • I used --disable-compile in the Gemma3N notebook harness for targeted triage and validation.
  • I did not add a blanket HIP compile-disable in library code in this branch to avoid over-broad behavior changes.

@danielhanchen
Contributor Author

Superseding my prior malformed CLI comment with corrected content.

Re-review update with fresh A/B checks on ROCm:

  1. Removed the redundant Deepseek OCR patch invocation in unsloth/models/vision.py (commit 734649e4).
  • The Deepseek fix is already applied via temporary patches initialization in unsloth-zoo, so this loader-level call was a duplicate.
  2. Validation evidence
  • temp/run_316_Deepseek_OCR_prunecheck (Deepseek masked_scatter patch removed): FAILED
    • RuntimeError: size mismatch in masked_scatter path.
  • temp/run_317_Deepseek_OCR_prunecheck_retry (Deepseek masked_scatter patch restored, loader duplicate call removed): SUCCESS.
  • temp/run_320_Llama3_2_conv_recheck: SUCCESS.
  3. Current assessment
  • Keep the Deepseek masked_scatter patch itself (required for the Deepseek OCR inference path).
  • Remove the duplicate invocation in vision.py (now done).
  • This cleanup commit does not introduce NVIDIA-specific behavior.

Additional note from sanity run:

  • temp/run_321_gpt_oss_20B_GRPO_recheck failed with CPU/CUDA embedding index device mismatch in a notebook generate path (embed_tokens offload interaction). This is separate from the duplicate-call cleanup and needs a dedicated fix path.

GoldenGrapeGentleman added a commit to GoldenGrapeGentleman/unsloth that referenced this pull request Feb 14, 2026
MI355X (gfx950) has the same 1024-thread workgroup limit as MI300X (gfx942),
but was missing from is_cdna(), causing all Triton kernels to use num_warps=32
(2048 threads) instead of 16 (1024 threads), resulting in an OutOfResources crash.

Also includes the ROCm GPT-OSS BF16 routing and dequant buffer dtype fixes from PR unslothai#4021
by @danielhanchen, cherry-picked for MI355X validation.

Tested on: 8x AMD Instinct MI355X (gfx950), ROCm 7.1
- Vision RL GRPO (Qwen2.5-VL-7B): 5/5 steps
- Code RL GRPO (gpt-oss-20b BF16): 20/20 steps
- gpt-oss-120b GRPO: 5/5 steps (B200 OOM'd on this)
- MoE expert LoRA + save_pretrained_merged: success
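
For reference, a sketch of the kind of arch gate the commit message describes; the real is_cdna() in unsloth's kernel utilities may differ, and the gfx list here is illustrative:

```python
import torch

def is_cdna() -> bool:
    # gcnArchName looks like "gfx942:sramecc+:xnack-" on MI300X.
    arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
    # CDNA parts share the 1024-thread workgroup limit; gfx950 is MI355X.
    return arch in ("gfx908", "gfx90a", "gfx942", "gfx950")
```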
@GoldenGrapeGentleman
Contributor

W7900 (gfx1100, RDNA3) Validation

Environment: ROCm 7.1 | PyTorch 2.8.0+rocm7.1 | Triton 3.4.0

@danielhanchen Tested the AITER disable logic on W7900 — confirmed AITER is not available on RDNA3 and the setdefault approach correctly prevents import failures.

Re: the GPT-OSS BF16 default — on RDNA3, MXFP4 is not supported (no HW matrix engine), so defaulting to BF16 is the correct choice.

Also note: we have submitted #4109 (compile disable for Gemma3 on HIP) and #4110 (RMS LayerNorm eps float32 fix) which complement this PR for broader RDNA stability.
