diff --git a/paddlenlp/transformers/conversion_utils.py b/paddlenlp/transformers/conversion_utils.py
index 11b9acf8267f..de289621c1f5 100644
--- a/paddlenlp/transformers/conversion_utils.py
+++ b/paddlenlp/transformers/conversion_utils.py
@@ -748,6 +748,8 @@ def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=Fals
         if x is None:
             return None
         if transpose:
+            if hasattr(x, "get"):
+                x = x.get()
             if isinstance(x, paddle.Tensor):
                 x = paddle.transpose(x, [1, 0])
             else:
diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py
index 62f3660a5bfe..4c8f216e46d5 100644
--- a/paddlenlp/transformers/llama/fusion_ops.py
+++ b/paddlenlp/transformers/llama/fusion_ops.py
@@ -21,7 +21,10 @@
     from paddle.incubate.nn.functional import fused_rotary_position_embedding
 except ImportError:
     fused_rotary_position_embedding = None
-
+try:
+    from paddle.incubate.nn.functional import fused_rms_norm_ext
+except ImportError:
+    fused_rms_norm_ext = None
 try:
     from paddle.incubate.nn.functional import swiglu
 except ImportError:
@@ -132,6 +135,8 @@ def fusion_rope(


 def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
+    if fused_rms_norm_ext is not None:
+        return fused_rms_norm_ext(x_in, w, eps)[0].astype(w.dtype)
     if use_fast_ln:
         fast_ln = try_import("fast_ln")
         return fast_ln.fast_rms_norm(x_in, w, eps)[0]
diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py
index 34c1e14e39e9..a75d6f367c23 100644
--- a/paddlenlp/transformers/qwen2/modeling.py
+++ b/paddlenlp/transformers/qwen2/modeling.py
@@ -92,6 +92,21 @@
     "Qwen2ForTokenClassification",
     "Qwen2SentenceEmbedding",
 ]
+import os
+
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    elif v.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif v.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise ValueError("Unsupported value encountered.")
+
+
+FLAGS_ALIGN_PADDLEFORMERS = str2bool(os.getenv("FLAGS_ALIGN_PADDLEFORMERS", "True"))


 def get_triangle_upper_mask(x, mask=None):
@@ -329,45 +344,6 @@ def forward(self, hidden_states):
         return hidden_states * self.weight


-class Qwen2RotaryEmbedding(nn.Layer):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
-        super().__init__()
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        # [dim / 2]
-        self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim))
-        self._set_cos_sin_cache(seq_len=max_position_embeddings)
-
-    def _set_cos_sin_cache(self, seq_len):
-        self.max_seq_len_cached = seq_len
-        if self.inv_freq.dtype != paddle.float32:
-            self.inv_freq = 1.0 / (
-                self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
-            )
-        # [seq_len]
-        t = paddle.arange(seq_len, dtype="float32")
-        # [seq_len, dim/2]
-        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        # [seq_len, dim]
-        emb = paddle.concat([freqs, freqs], axis=-1)
-        # [1, seqlen, 1, dim]
-        self.cos_cached = emb.cos()[None, :, None, :]
-        self.sin_cached = emb.sin()[None, :, None, :]
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len)
-        cos = self.cos_cached[:, :seq_len, :, :]
-        sin = self.sin_cached[:, :seq_len, :, :]
-        return (
-            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
-            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
-        )
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -375,19 +351,118 @@ def rotate_half(x):
     return paddle.concat([-x2, x1], axis=-1)  # shape is the same as x


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    if position_ids is None:
-        # Note: Only for Qwen2MoEForCausalLMPipe model pretraining
-        cos = cos[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
-        sin = sin[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
-    else:
-        cos = cos.squeeze(axis=[0, 2])  # [seq_len, dim]
-        sin = sin.squeeze(axis=[0, 2])  # [seq_len, dim]
-        cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
-        sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+if FLAGS_ALIGN_PADDLEFORMERS:
+
+    def _apply_rotary_emb(
+        x: paddle.Tensor,
+        cos: paddle.Tensor,
+        sin: paddle.Tensor,
+    ) -> paddle.Tensor:
+        x = x.transpose([0, 2, 1, 3])
+        x_embed = (x * cos) + (rotate_half(x) * sin)
+        return x_embed.transpose([0, 2, 1, 3])
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors."""
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
+        q_embed = _apply_rotary_emb(q, cos, sin)
+        k_embed = _apply_rotary_emb(k, cos, sin)
+        return q_embed.astype(q.dtype), k_embed.astype(k.dtype)
+
+    class Qwen2RotaryEmbedding(nn.Layer):
+        def __init__(self, config: Qwen2Config):
+            super().__init__()
+            self.config = config
+            base = config.rope_theta
+            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+            head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+            dim = int(head_dim * partial_rotary_factor)
+
+            inv_freq = 1.0 / (
+                base ** (paddle.arange(0, dim, 2, dtype=paddle.int64).astype(dtype=paddle.float32) / dim)
+            )
+            self.attention_scaling = 1.0
+            self.register_buffer("inv_freq", inv_freq, persistable=False)
+            self.original_inv_freq = self.inv_freq
+
+        def forward(self, x, position_ids):
+            # NOTE: Paddle's Automatic Mixed Precision (AMP) has a default op whitelist that may automatically cast
+            # certain operations (like matmul) to FP16/BF16 for performance optimization. However, in scenarios where
+            # numerical stability is critical (e.g., RoPE init/compute), this conversion can lead to precision loss.
+            # Disabling auto_cast here ensures the matmul operation runs in the original precision (FP32) as intended.
+            with paddle.amp.auto_cast(False):
+                inv_freq_expanded = (
+                    self.inv_freq.unsqueeze(0)
+                    .unsqueeze(-1)
+                    .cast(paddle.float32)
+                    .expand([position_ids.shape[0], -1, 1])
+                    .to(x.place)
+                )
+                position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)
+
+                freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
+                emb = paddle.cat((freqs, freqs), axis=-1)
+                cos = paddle.cos(emb) * self.attention_scaling
+                sin = paddle.sin(emb) * self.attention_scaling
+
+            return cos.cast(dtype="float32"), sin.cast(dtype="float32")
+
+else:
+
+    class Qwen2RotaryEmbedding(nn.Layer):
+        def __init__(self, dim, max_position_embeddings=2048, base=10000):
+            super().__init__()
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+            # [dim / 2]
+            self.inv_freq = 1.0 / (
+                self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
+            )
+            self._set_cos_sin_cache(seq_len=max_position_embeddings)
+
+        def _set_cos_sin_cache(self, seq_len):
+            self.max_seq_len_cached = seq_len
+            if self.inv_freq.dtype != paddle.float32:
+                self.inv_freq = 1.0 / (
+                    self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
+                )
+            # [seq_len]
+            t = paddle.arange(seq_len, dtype="float32")
+            # [seq_len, dim/2]
+            freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            # [seq_len, dim]
+            emb = paddle.concat([freqs, freqs], axis=-1)
+            # [1, seqlen, 1, dim]
+            self.cos_cached = emb.cos()[None, :, None, :]
+            self.sin_cached = emb.sin()[None, :, None, :]
+
+        def forward(self, x, seq_len=None):
+            # x: [bs, num_attention_heads, seq_len, head_size]
+            if seq_len > self.max_seq_len_cached:
+                self._set_cos_sin_cache(seq_len)
+            cos = self.cos_cached[:, :seq_len, :, :]
+            sin = self.sin_cached[:, :seq_len, :, :]
+            return (
+                cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
+                sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+            )
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        if position_ids is None:
+            # Note: Only for Qwen2MoEForCausalLMPipe model pretraining
+            cos = cos[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
+            sin = sin[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
+        else:
+            cos = cos.squeeze(axis=[0, 2])  # [seq_len, dim]
+            sin = sin.squeeze(axis=[0, 2])  # [seq_len, dim]
+            cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+            sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed


 class Qwen2MLP(nn.Layer):
@@ -612,11 +687,14 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, skip_r
             )
         self.o_proj = Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias_attr=False)

-        self.rotary_emb = Qwen2RotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
-        )
+        if FLAGS_ALIGN_PADDLEFORMERS:
+            self.rotary_emb = Qwen2RotaryEmbedding(config=config)
+        else:
+            self.rotary_emb = Qwen2RotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta,
+            )

         self.attn_func = scaled_dot_product_attention

@@ -692,7 +770,10 @@ def forward(
                 use_neox_rotary_style=False,
             )
         else:
-            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            if FLAGS_ALIGN_PADDLEFORMERS:
+                cos, sin = self.rotary_emb(hidden_states, position_ids)
+            else:
+                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         # [bs, seq_len, num_head, head_dim]
@@ -948,8 +1029,9 @@ def get_tensor_parallel_split_mappings(num_layers):
                 "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
                 "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
             }
-
-            if config.tie_word_embeddings:
+            if FLAGS_ALIGN_PADDLEFORMERS:
+                base_actions["lm_head.weight"] = partial(fn, is_column=False, transpose=True)
+            elif config.tie_word_embeddings:
                 base_actions["lm_head.weight"] = partial(fn, is_column=False)
             else:
                 base_actions["lm_head.weight"] = partial(fn, is_column=True)
@@ -1518,7 +1600,7 @@ def __init__(self, config: Qwen2Config):
             self.lm_head = Qwen2LMHead(config, embedding_weights=self.qwen2.embed_tokens.weight, transpose_y=True)
             self.tie_weights()
         else:
-            self.lm_head = Qwen2LMHead(config)
+            self.lm_head = Qwen2LMHead(config, transpose_y=FLAGS_ALIGN_PADDLEFORMERS)
         self.criterion = Qwen2PretrainingCriterion(config)
         self.vocab_size = config.vocab_size
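
Standalone sanity sketch (not part of the patch): the snippet below checks that the PaddleFormers-style RoPE application enabled by FLAGS_ALIGN_PADDLEFORMERS rotates q the same way as the legacy cached cos/sin path when position_ids is a plain arange. It mirrors, rather than imports, the two code paths from the diff, and the tensor sizes are made up for illustration.

import paddle


def rotate_half(x):
    # Same helper as in modeling.py: swap and negate the two halves of the last dim.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return paddle.concat([-x2, x1], axis=-1)


bs, seq_len, num_heads, head_dim, base = 2, 8, 4, 16, 10000.0
q = paddle.rand([bs, seq_len, num_heads, head_dim], dtype="float32")
inv_freq = 1.0 / (base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))

# Legacy path: cos/sin cached as [1, seq_len, 1, head_dim] and broadcast over q.
t = paddle.arange(seq_len, dtype="float32")
freqs = paddle.einsum("i,j->ij", t, inv_freq)                  # [seq_len, head_dim/2]
emb = paddle.concat([freqs, freqs], axis=-1)                   # [seq_len, head_dim]
cos_old = emb.cos()[None, :, None, :]
sin_old = emb.sin()[None, :, None, :]
q_old = (q * cos_old) + (rotate_half(q) * sin_old)

# New path: cos/sin computed per batch as [bs, seq_len, head_dim]; q is rotated in the
# [bs, num_heads, seq_len, head_dim] layout and transposed back, as in _apply_rotary_emb.
position_ids = paddle.arange(seq_len, dtype="float32").unsqueeze(0).expand([bs, seq_len])
freqs_new = paddle.matmul(
    inv_freq.unsqueeze(0).unsqueeze(-1).expand([bs, -1, 1]),   # [bs, head_dim/2, 1]
    position_ids.unsqueeze(1),                                 # [bs, 1, seq_len]
).transpose([0, 2, 1])                                         # [bs, seq_len, head_dim/2]
emb_new = paddle.concat([freqs_new, freqs_new], axis=-1)       # [bs, seq_len, head_dim]
cos_new = emb_new.cos().unsqueeze(1)                           # [bs, 1, seq_len, head_dim]
sin_new = emb_new.sin().unsqueeze(1)

q_t = q.transpose([0, 2, 1, 3])                                # [bs, num_heads, seq_len, head_dim]
q_new = ((q_t * cos_new) + (rotate_half(q_t) * sin_new)).transpose([0, 2, 1, 3])

print(bool(paddle.allclose(q_old, q_new)))                     # expected: True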