Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
4e1e24c
ROCm: default GPT-OSS to BF16 and disable AITER
danielhanchen Feb 10, 2026
3daaff6
ROCm: guard Trainer init patch against missing generated function
danielhanchen Feb 11, 2026
cff534d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
24c7f2e
ROCm GPT-OSS: gate BF16 fallback by prequant capability
danielhanchen Feb 11, 2026
458af41
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
a211e8c
ROCm: trim unintended fast-inference fallback behaviors
danielhanchen Feb 11, 2026
b56ab6a
Refactor HIP GPT-OSS routing into shared loader helper
danielhanchen Feb 11, 2026
4f138ac
Move HIP GPT-OSS routing helper to loader footer
danielhanchen Feb 11, 2026
9956d1d
ROCm notebook stability: deepseek OCR hook + offline GGUF guard
danielhanchen Feb 11, 2026
28aa6c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
f0da826
Fix dequant global buffer dtype reuse across mixed precision
danielhanchen Feb 11, 2026
734649e
Remove redundant Deepseek OCR patch call from vision loader
danielhanchen Feb 11, 2026
8b12e72
ROCm notebook stability: deepseek OCR hook + offline GGUF guard
danielhanchen Feb 11, 2026
7369727
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
5ac8f45
Fix dequant global buffer dtype reuse across mixed precision
danielhanchen Feb 11, 2026
897a004
Remove redundant Deepseek OCR patch call from vision loader
danielhanchen Feb 11, 2026
8729bb5
Merge remote-tracking branch 'pr4021/rocm-gpt-oss-bf16-aiter' into fi…
billishyahao Feb 14, 2026
41c5a96
Add gfx950 (MI355X/CDNA4) to is_cdna() for correct Triton num_warps
GoldenGrapeGentleman Feb 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,18 @@ def is_bf16_supported():
del major_version, minor_version
elif DEVICE_TYPE == "hip":
SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()

def is_bf16_supported():
return SUPPORTS_BFLOAT16
elif DEVICE_TYPE == "xpu":
# torch.xpu.is_bf16_supported() does not have including_emulation
# set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported()
SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported()

def is_bf16_supported():
return SUPPORTS_BFLOAT16
Comment on lines 207 to +218
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability and reduce code duplication, you can refactor this logic. The is_bf16_supported function is defined identically for both hip and xpu device types. Combining the elif blocks for hip and xpu and defining the function only once would make the code cleaner.

Suggested change
elif DEVICE_TYPE == "hip":
SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()
def is_bf16_supported():
return SUPPORTS_BFLOAT16
elif DEVICE_TYPE == "xpu":
# torch.xpu.is_bf16_supported() does not have including_emulation
# set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported()
SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported()
def is_bf16_supported():
return SUPPORTS_BFLOAT16
elif DEVICE_TYPE in ("hip", "xpu"):
if DEVICE_TYPE == "hip":
SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()
else: # xpu
# torch.xpu.is_bf16_supported() does not have including_emulation
# set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported()
SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported()
def is_bf16_supported():
return SUPPORTS_BFLOAT16



# For Gradio HF Spaces?
# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
import triton
Expand Down
5 changes: 5 additions & 0 deletions unsloth/device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
]

import torch
import os
import functools
import inspect
from unsloth_zoo.utils import Version
Expand Down Expand Up @@ -94,6 +95,10 @@ def get_device_count():
# HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB
ALLOW_BITSANDBYTES: bool = True
if DEVICE_TYPE == "hip":
# Disable AITER by default on ROCm to avoid JIT build locks and runtime faults.
# Users can override by explicitly setting env vars.
os.environ.setdefault("AITER_DISABLE", "1")
os.environ.setdefault("USE_ROCM_AITER_ROPE_BACKEND", "0")
try:
import bitsandbytes
except:
Expand Down
5 changes: 3 additions & 2 deletions unsloth/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def is_cdna():
"gfx940",
"gfx941",
"gfx942",
"gfx950", # CDNA4 (MI350/MI355X)
)


Expand Down Expand Up @@ -388,7 +389,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
global ABSMAX_BUFFERS
WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
if WEIGHT_BUFFER is None:
if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
size, dtype = dtype, device = device, requires_grad = False
)
Expand Down Expand Up @@ -498,7 +499,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
global ABSMAX_BUFFERS
WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
if WEIGHT_BUFFER is None:
if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
size, dtype = dtype, device = device, requires_grad = False
)
Expand Down
28 changes: 25 additions & 3 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2019,8 +2019,23 @@ def patch_gradient_accumulation_fix(Trainer):
'if hasattr(unwrapped_model, "accepts_loss_kwargs") and False:',
init_function,
)
exec(init_function, globals())
Trainer.__init__ = _unsloth___init__
local_scope = {}
try:
exec(init_function, globals(), local_scope)
except Exception as _patch_error:
print(
"Unsloth: gradient accumulation init patch skipped due to "
f"source patch error: {_patch_error}"
)
local_scope = {}
_patched_init = local_scope.get("_unsloth___init__")
if _patched_init is None:
print(
"Unsloth: gradient accumulation init patch skipped because "
"_unsloth___init__ was not generated."
)
else:
Trainer.__init__ = _patched_init


def patch_tokenizer(model, tokenizer):
Expand Down Expand Up @@ -2499,7 +2514,14 @@ def _prepare_model_for_qat(
except ImportError:
raise ImportError(TORCHAO_MSG)
group_size = 128
base_config = Int4WeightOnlyConfig(group_size = group_size)
try:
base_config = Int4WeightOnlyConfig(
group_size = group_size,
version = 2,
)
except TypeError:
# Older TorchAO versions do not support the version argument.
base_config = Int4WeightOnlyConfig(group_size = group_size)
filter_fn = (
lambda m, _: isinstance(m, torch.nn.Linear)
and m.in_features >= group_size
Expand Down
83 changes: 83 additions & 0 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,24 @@ def from_pretrained(
)
load_in_4bit = False

(
model_name,
load_in_4bit,
load_in_8bit,
load_in_fp8,
load_in_16bit,
quantization_config,
) = _route_hip_gpt_oss_model(
model_name = model_name,
use_exact_model_name = use_exact_model_name,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
load_in_fp8 = load_in_fp8,
load_in_16bit = load_in_16bit,
quantization_config = quantization_config,
kwargs = kwargs,
)
Comment on lines +278 to +294
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for routing HIP GPT-OSS models is duplicated in FastModel.from_pretrained at lines 880-896. To improve maintainability and reduce redundancy, consider refactoring this logic into a shared helper method that both FastLanguageModel.from_pretrained and FastModel.from_pretrained can call. This would centralize the model routing logic, making future updates easier.


# Find FP8, BnB 4bit, other mapped names
old_model_name = model_name
fp8_mode = None
Expand Down Expand Up @@ -859,6 +877,24 @@ def from_pretrained(
)
load_in_4bit = False

(
model_name,
load_in_4bit,
load_in_8bit,
load_in_fp8,
load_in_16bit,
quantization_config,
) = _route_hip_gpt_oss_model(
model_name = model_name,
use_exact_model_name = use_exact_model_name,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
load_in_fp8 = load_in_fp8,
load_in_16bit = load_in_16bit,
quantization_config = quantization_config,
kwargs = kwargs,
)

if fast_inference:
if importlib.util.find_spec("vllm") is None:
raise ImportError(
Expand Down Expand Up @@ -1407,3 +1443,50 @@ class FastVisionModel(FastModel):

class FastTextModel(FastModel):
pass


def _route_hip_gpt_oss_model(
    model_name,
    use_exact_model_name,
    load_in_4bit,
    load_in_8bit,
    load_in_fp8,
    load_in_16bit,
    quantization_config,
    kwargs,
):
    """Route GPT-OSS model loading on AMD (HIP) devices.

    AMD GPT-OSS routing:
      - Radeon can often use prequantized bnb-4bit checkpoints.
      - Instinct/MI (warp=64) often cannot, so fall back to BF16.

    Only canonical base checkpoints (``gpt-oss-20b`` / ``gpt-oss-120b``,
    optionally with a known quant/precision suffix) are ever remapped.
    Variant model IDs (e.g. ``gpt-oss-safeguard-20b``) and local checkpoint
    paths are left untouched, so the loader never silently swaps in
    different weights than the caller requested.

    Parameters mirror the loader's quantization flags; ``kwargs`` is the
    caller's keyword dict and may have ``quantization_config`` popped from
    it when falling back to BF16.

    Returns the (possibly updated) tuple:
    ``(model_name, load_in_4bit, load_in_8bit, load_in_fp8, load_in_16bit,
    quantization_config)``.
    """
    import os

    lower_model_name = model_name.lower()

    # Determine whether this is a canonical base GPT-OSS checkpoint name.
    # Strip the org prefix and any known quant/precision suffix first.
    base_name = lower_model_name.split("/")[-1]
    for _suffix in ("-unsloth-bnb-4bit", "-bnb-4bit", "-bf16"):
        if base_name.endswith(_suffix):
            base_name = base_name[: -len(_suffix)]
            break
    is_canonical_gpt_oss = base_name in ("gpt-oss-20b", "gpt-oss-120b")

    if (
        is_hip()
        and is_canonical_gpt_oss
        and not use_exact_model_name
        # Never rewrite a local checkpoint path - the user's own weights win.
        and not os.path.exists(model_name)
    ):
        gpt_oss_prequant_suffix = lower_model_name.endswith(
            ("-unsloth-bnb-4bit", "-bnb-4bit")
        )
        wants_prequantized = load_in_4bit or gpt_oss_prequant_suffix
        can_use_prequantized = ALLOW_BITSANDBYTES and ALLOW_PREQUANTIZED_MODELS
        if not (wants_prequantized and can_use_prequantized):
            if not lower_model_name.endswith("-bf16"):
                if "120b" in lower_model_name:
                    model_name = "unsloth/gpt-oss-120b-BF16"
                else:
                    model_name = "unsloth/gpt-oss-20b-BF16"
            # BF16 fallback: disable every quantized-loading path.
            load_in_4bit = False
            load_in_8bit = False
            load_in_fp8 = False
            load_in_16bit = True
            quantization_config = None
            kwargs.pop("quantization_config", None)

    return (
        model_name,
        load_in_4bit,
        load_in_8bit,
        load_in_fp8,
        load_in_16bit,
        quantization_config,
    )
20 changes: 19 additions & 1 deletion unsloth/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,7 +1996,14 @@ def unsloth_save_pretrained_gguf(
is_gpt_oss = is_gpt_oss, # Pass gpt_oss Flag
)
except Exception as e:
if IS_KAGGLE_ENVIRONMENT:
if os.environ.get("UNSLOTH_GGUF_OFFLINE", "0") == "1":
print(
"Unsloth: GGUF conversion skipped due to offline mode. " f"Reason: {e}"
)
all_file_locations = []
want_full_precision = None
is_vlm_update = False
elif IS_KAGGLE_ENVIRONMENT:
raise RuntimeError(
f"Unsloth: GGUF conversion failed in Kaggle environment.\n"
f"This is likely due to the 20GB disk space limit.\n"
Expand All @@ -2010,6 +2017,17 @@ def unsloth_save_pretrained_gguf(
gguf_directory = f"{save_directory}_gguf"
modelfile_location = None
ollama_success = False
if not all_file_locations:
# Offline or failed GGUF conversion: return early to avoid index errors
return {
"save_directory": save_directory,
"gguf_directory": gguf_directory,
"gguf_files": all_file_locations,
"modelfile_location": modelfile_location,
"want_full_precision": want_full_precision,
"is_vlm": is_vlm_update,
"fix_bos_token": fix_bos_token,
}
if all_file_locations:
try:
if is_vlm_update:
Expand Down