
[Bug] RuntimeError: shape mismatch in get_rope_index during GRPO training with Qwen3-VL(8B) #5176

@Saad1926Q

Description

  1. Did you update? pip install --upgrade unsloth unsloth_zoo
  • Yes, latest version installed
  2. Colab or Kaggle or local / cloud
  • Cloud (VastAI)
  3. Number of GPUs used, use nvidia-smi
  • 1 GPU, RTX 4080 with 16 GB VRAM
  4. Which notebook? Please link!
  5. Which Unsloth version, TRL version, transformers version, PyTorch version?
  • unsloth: 2026.4.8, TRL: 0.24.0, transformers: 5.5.0, PyTorch: 2.10.0+cu130
  6. Which trainer? SFTTrainer, GRPOTrainer etc
  • GRPOTrainer with Qwen3-VL-8B on a multimodal ChartQA dataset

Description

Training crashes with a shape mismatch in get_rope_index during the forward pass of GRPO's compute_loss, using Qwen3-VL-8B on a multimodal dataset. The mismatched sizes vary between runs (e.g. [3, 447] vs [3, 445], or [3, 454] vs [3, 447]), so the offset is not fixed; it depends on the batch contents. The crash occurs at the very first training step.

Minimal Code to Reproduce Error

from unsloth import FastVisionModel
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset
import re

model, tokenizer = FastVisionModel.from_pretrained(
    model_name="unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit",
    max_seq_length=16384,
    load_in_4bit=True,
)

def process(example):
    # Resize every chart to 512x512 RGB and build a single-turn multimodal prompt.
    image = example["image"].resize((512, 512)).convert("RGB")
    prompt = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": example["query"].strip()},
        ],
    }]
    return {"prompt": prompt, "image": image, "answer": example["label"][0]}

dataset = load_dataset("HuggingFaceM4/ChartQA", split="train[:500]").map(process)
dataset = dataset.select_columns(["prompt", "image", "answer"])

def reward_func(completions, answer, **kwargs):
    # Score 1.0 when the text inside <SOLUTION>...</SOLUTION> equals the answer exactly.
    scores = []
    for completion, ans in zip(completions, answer):
        if isinstance(completion, list):
            completion = completion[0]["content"] if completion else ""
        match = re.search(r"<SOLUTION>(.*?)</SOLUTION>", completion, re.DOTALL)
        predicted = match.group(1).strip() if match else ""
        scores.append(1.0 if predicted == ans.strip() else 0.0)
    return scores

training_args = GRPOConfig(
    per_device_train_batch_size=4,
    num_generations=4,
    max_prompt_length=1024,
    max_completion_length=1024,
    num_train_epochs=1,
    output_dir="outputs",
    bf16=True,
)
# Extra attributes set manually on the config object.
training_args.unsloth_num_chunks = -1
training_args.unsloth_grpo_mini_batch = None
training_args.unsloth_logit_chunk_multiplier = None
training_args.vllm_sampling_params = None

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    reward_funcs=[reward_func],
    train_dataset=dataset,
)

trainer.train()  # crashes here

Traceback

File "transformers/models/qwen3_vl/modeling_qwen3_vl.py", line 1117, in get_rope_index
    position_ids[:, batch_idx, attention_mask[batch_idx].bool()] = llm_positions.to(position_ids.device)
RuntimeError: shape mismatch: value tensor of shape [3, 454] cannot be broadcast to indexing result of shape [3, 447]
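
The failing line writes llm_positions into the positions selected by attention_mask[batch_idx].bool(), so it can only succeed when the number of nonzero mask entries in that row equals the last dimension of llm_positions. Below is a minimal sketch of that condition using plain Python lists instead of tensors. The helper mask_position_mismatch is hypothetical (it is not part of transformers, TRL, or Unsloth), and the toy mask only reuses the sizes from the traceback; it makes no claim about what actually produces the difference.

```python
def mask_position_mismatch(attention_mask_row, num_positions):
    """Return num_positions minus the number of attended tokens in one mask row.

    A result of 0 means a masked assignment of the form
    `position_ids[:, i, mask] = llm_positions` would have matching shapes;
    any other value corresponds to the shape mismatch in the traceback.
    """
    attended = sum(1 for m in attention_mask_row if m)
    return num_positions - attended


# Toy row (assumed layout, for illustration only): 447 attended entries, 7 masked out.
mask_row = [1] * 447 + [0] * 7

# 454 is the last dimension of the value tensor reported in the traceback.
print(mask_position_mismatch(mask_row, 454))
```

Running a check like this on each row of a batch just before the forward pass would show which samples have a mask count that disagrees with the computed positions, which may help narrow down the batch contents that trigger the crash.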
