-
-
Notifications
You must be signed in to change notification settings - Fork 4.5k
ROCm: Add gfx950 (MI355X/CDNA4) to is_cdna() and include PR #4021 fixes #4050
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4e1e24c
3daaff6
cff534d
24c7f2e
458af41
a211e8c
b56ab6a
4f138ac
9956d1d
28aa6c2
f0da826
734649e
8b12e72
7369727
5ac8f45
897a004
8729bb5
41c5a96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -275,6 +275,24 @@ def from_pretrained( | |
| ) | ||
| load_in_4bit = False | ||
|
|
||
| ( | ||
| model_name, | ||
| load_in_4bit, | ||
| load_in_8bit, | ||
| load_in_fp8, | ||
| load_in_16bit, | ||
| quantization_config, | ||
| ) = _route_hip_gpt_oss_model( | ||
| model_name = model_name, | ||
| use_exact_model_name = use_exact_model_name, | ||
| load_in_4bit = load_in_4bit, | ||
| load_in_8bit = load_in_8bit, | ||
| load_in_fp8 = load_in_fp8, | ||
| load_in_16bit = load_in_16bit, | ||
| quantization_config = quantization_config, | ||
| kwargs = kwargs, | ||
| ) | ||
|
Comment on lines
+278
to
+294
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This block of code for routing HIP GPT-OSS models is duplicated in both `from_pretrained` implementations; extracting the repeated call into a shared location would avoid the duplication. |
||
|
|
||
| # Find FP8, BnB 4bit, other mapped names | ||
| old_model_name = model_name | ||
| fp8_mode = None | ||
|
|
@@ -859,6 +877,24 @@ def from_pretrained( | |
| ) | ||
| load_in_4bit = False | ||
|
|
||
| ( | ||
| model_name, | ||
| load_in_4bit, | ||
| load_in_8bit, | ||
| load_in_fp8, | ||
| load_in_16bit, | ||
| quantization_config, | ||
| ) = _route_hip_gpt_oss_model( | ||
| model_name = model_name, | ||
| use_exact_model_name = use_exact_model_name, | ||
| load_in_4bit = load_in_4bit, | ||
| load_in_8bit = load_in_8bit, | ||
| load_in_fp8 = load_in_fp8, | ||
| load_in_16bit = load_in_16bit, | ||
| quantization_config = quantization_config, | ||
| kwargs = kwargs, | ||
| ) | ||
|
|
||
| if fast_inference: | ||
| if importlib.util.find_spec("vllm") is None: | ||
| raise ImportError( | ||
|
|
@@ -1407,3 +1443,50 @@ class FastVisionModel(FastModel): | |
|
|
||
| class FastTextModel(FastModel): | ||
| pass | ||
|
|
||
|
|
||
def _route_hip_gpt_oss_model(
    model_name,
    use_exact_model_name,
    load_in_4bit,
    load_in_8bit,
    load_in_fp8,
    load_in_16bit,
    quantization_config,
    kwargs,
):
    """Route GPT-OSS model names and quantization flags on AMD (HIP) GPUs.

    AMD GPT-OSS routing:
      - Radeon can often use prequantized bnb-4bit checkpoints.
      - Instinct/MI (warp=64) often cannot, so fallback to BF16.

    Args:
        model_name: Requested model name (checked case-insensitively).
        use_exact_model_name: If True, no routing is applied at all.
        load_in_4bit / load_in_8bit / load_in_fp8 / load_in_16bit:
            Quantization flags as passed by the caller.
        quantization_config: Caller-provided quantization config, if any.
        kwargs: The caller's remaining kwargs; a "quantization_config"
            entry is popped (mutated in place) when BF16 fallback is taken.

    Returns:
        Tuple of (model_name, load_in_4bit, load_in_8bit, load_in_fp8,
        load_in_16bit, quantization_config), possibly rewritten.
    """
    name_lc = model_name.lower()
    is_gpt_oss = "gpt-oss" in name_lc or "gpt_oss" in name_lc
    if is_hip() and is_gpt_oss and not use_exact_model_name:
        has_4bit_suffix = name_lc.endswith(("-unsloth-bnb-4bit", "-bnb-4bit"))
        prequant_requested = load_in_4bit or has_4bit_suffix
        prequant_allowed = ALLOW_BITSANDBYTES and ALLOW_PREQUANTIZED_MODELS
        if not (prequant_requested and prequant_allowed):
            # Redirect to a BF16 checkpoint unless one was already requested.
            if not name_lc.endswith("-bf16"):
                model_name = (
                    "unsloth/gpt-oss-120b-BF16"
                    if "120b" in name_lc
                    else "unsloth/gpt-oss-20b-BF16"
                )
            # Force full BF16 loading and drop any quantization request,
            # including one smuggled in through kwargs.
            load_in_4bit = load_in_8bit = load_in_fp8 = False
            load_in_16bit = True
            quantization_config = None
            kwargs.pop("quantization_config", None)
    return (
        model_name,
        load_in_4bit,
        load_in_8bit,
        load_in_fp8,
        load_in_16bit,
        quantization_config,
    )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve maintainability and reduce code duplication, you can refactor this logic. The `is_bf16_supported` function is defined identically for both `hip` and `xpu` device types. Combining the `elif` blocks for `hip` and `xpu` and defining the function only once would make the code cleaner.