diff --git a/unsloth/__init__.py b/unsloth/__init__.py index a505e89ad4..2c5a0ffe9e 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -206,11 +206,18 @@ def is_bf16_supported(): del major_version, minor_version elif DEVICE_TYPE == "hip": SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported() + + def is_bf16_supported(): + return SUPPORTS_BFLOAT16 elif DEVICE_TYPE == "xpu": # torch.xpu.is_bf16_supported() does not have including_emulation # set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported() SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported() + def is_bf16_supported(): + return SUPPORTS_BFLOAT16 + + # For Gradio HF Spaces? # if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ: import triton diff --git a/unsloth/device_type.py b/unsloth/device_type.py index 0f924bfdfd..8142fc58cf 100644 --- a/unsloth/device_type.py +++ b/unsloth/device_type.py @@ -23,6 +23,7 @@ ] import torch +import os import functools import inspect from unsloth_zoo.utils import Version @@ -94,6 +95,10 @@ def get_device_count(): # HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB ALLOW_BITSANDBYTES: bool = True if DEVICE_TYPE == "hip": + # Disable AITER by default on ROCm to avoid JIT build locks and runtime faults. + # Users can override by explicitly setting env vars. + os.environ.setdefault("AITER_DISABLE", "1") + os.environ.setdefault("USE_ROCM_AITER_ROPE_BACKEND", "0") try: import bitsandbytes except: diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 5dcc7c232c..eb50a5f617 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -82,6 +82,7 @@ def is_cdna(): "gfx940", "gfx941", "gfx942", + "gfx950", # CDNA4 (MI350/MI355X) ) @@ -388,7 +389,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False global ABSMAX_BUFFERS WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index] ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index] - if WEIGHT_BUFFER is None: + if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype: WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty( size, dtype = dtype, device = device, requires_grad = False ) @@ -498,7 +499,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False global ABSMAX_BUFFERS WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index] ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index] - if WEIGHT_BUFFER is None: + if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype: WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty( size, dtype = dtype, device = device, requires_grad = False ) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e5b73db079..aa306cea6e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -2019,8 +2019,23 @@ def patch_gradient_accumulation_fix(Trainer): 'if hasattr(unwrapped_model, "accepts_loss_kwargs") and False:', init_function, ) - exec(init_function, globals()) - Trainer.__init__ = _unsloth___init__ + local_scope = {} + try: + exec(init_function, globals(), local_scope) + except Exception as _patch_error: + print( + "Unsloth: gradient accumulation init patch skipped due to " + f"source patch error: {_patch_error}" + ) + local_scope = {} + _patched_init = local_scope.get("_unsloth___init__") + if _patched_init is None: + print( + "Unsloth: gradient accumulation init patch skipped because " + "_unsloth___init__ was not generated." + ) + else: + Trainer.__init__ = _patched_init def patch_tokenizer(model, tokenizer): @@ -2499,7 +2514,14 @@ def _prepare_model_for_qat( except ImportError: raise ImportError(TORCHAO_MSG) group_size = 128 - base_config = Int4WeightOnlyConfig(group_size = group_size) + try: + base_config = Int4WeightOnlyConfig( + group_size = group_size, + version = 2, + ) + except TypeError: + # Older TorchAO versions do not support the version argument. + base_config = Int4WeightOnlyConfig(group_size = group_size) filter_fn = ( lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 4054f1b7f5..393c46dec2 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -275,6 +275,24 @@ def from_pretrained( ) load_in_4bit = False + ( + model_name, + load_in_4bit, + load_in_8bit, + load_in_fp8, + load_in_16bit, + quantization_config, + ) = _route_hip_gpt_oss_model( + model_name = model_name, + use_exact_model_name = use_exact_model_name, + load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, + load_in_fp8 = load_in_fp8, + load_in_16bit = load_in_16bit, + quantization_config = quantization_config, + kwargs = kwargs, + ) + # Find FP8, BnB 4bit, other mapped names old_model_name = model_name fp8_mode = None @@ -859,6 +877,24 @@ def from_pretrained( ) load_in_4bit = False + ( + model_name, + load_in_4bit, + load_in_8bit, + load_in_fp8, + load_in_16bit, + quantization_config, + ) = _route_hip_gpt_oss_model( + model_name = model_name, + use_exact_model_name = use_exact_model_name, + load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, + load_in_fp8 = load_in_fp8, + load_in_16bit = load_in_16bit, + quantization_config = quantization_config, + kwargs = kwargs, + ) + if fast_inference: if importlib.util.find_spec("vllm") is None: raise ImportError( @@ -1407,3 +1443,50 @@ class FastVisionModel(FastModel): class FastTextModel(FastModel): pass + + +def _route_hip_gpt_oss_model( + model_name, + use_exact_model_name, + load_in_4bit, + load_in_8bit, + load_in_fp8, + load_in_16bit, + quantization_config, + kwargs, +): + # AMD GPT-OSS routing: + # - Radeon can often use prequantized bnb-4bit checkpoints. + # - Instinct/MI (warp=64) often cannot, so fallback to BF16. + lower_model_name = model_name.lower() + if ( + is_hip() + and ("gpt-oss" in lower_model_name or "gpt_oss" in lower_model_name) + and not use_exact_model_name + ): + gpt_oss_prequant_suffix = lower_model_name.endswith( + ("-unsloth-bnb-4bit", "-bnb-4bit") + ) + wants_prequantized = load_in_4bit or gpt_oss_prequant_suffix + can_use_prequantized = ALLOW_BITSANDBYTES and ALLOW_PREQUANTIZED_MODELS + if not (wants_prequantized and can_use_prequantized): + if not lower_model_name.endswith("-bf16"): + if "120b" in lower_model_name: + model_name = "unsloth/gpt-oss-120b-BF16" + else: + model_name = "unsloth/gpt-oss-20b-BF16" + load_in_4bit = False + load_in_8bit = False + load_in_fp8 = False + load_in_16bit = True + quantization_config = None + kwargs.pop("quantization_config", None) + + return ( + model_name, + load_in_4bit, + load_in_8bit, + load_in_fp8, + load_in_16bit, + quantization_config, + ) diff --git a/unsloth/save.py b/unsloth/save.py index fc3b7b8771..32c8889eb6 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -1996,7 +1996,14 @@ def unsloth_save_pretrained_gguf( is_gpt_oss = is_gpt_oss, # Pass gpt_oss Flag ) except Exception as e: - if IS_KAGGLE_ENVIRONMENT: + if os.environ.get("UNSLOTH_GGUF_OFFLINE", "0") == "1": + print( + "Unsloth: GGUF conversion skipped due to offline mode. " f"Reason: {e}" + ) + all_file_locations = [] + want_full_precision = None + is_vlm_update = False + elif IS_KAGGLE_ENVIRONMENT: raise RuntimeError( f"Unsloth: GGUF conversion failed in Kaggle environment.\n" f"This is likely due to the 20GB disk space limit.\n" @@ -2010,6 +2017,17 @@ def unsloth_save_pretrained_gguf( gguf_directory = f"{save_directory}_gguf" modelfile_location = None ollama_success = False + if not all_file_locations: + # Offline or failed GGUF conversion: return early to avoid index errors + return { + "save_directory": save_directory, + "gguf_directory": gguf_directory, + "gguf_files": all_file_locations, + "modelfile_location": modelfile_location, + "want_full_precision": want_full_precision, + "is_vlm": is_vlm_update, + "fix_bos_token": fix_bos_token, + } if all_file_locations: try: if is_vlm_update: