
[XPU] Fix precision for paddle.nn.functional.gelu backward with bfloat16#78845

Open
YqGe585 wants to merge 1 commit into PaddlePaddle:develop from YqGe585:xpu-api-fixer/PAD-2-xpu-precision

Conversation


YqGe585 (Member) commented Apr 29, 2026

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

paddle.nn.functional.gelu

Problem: the backward pass of paddle.nn.functional.gelu(Tensor([1, 8192, 6912], "bfloat16"), approximate=True) has a precision issue on XPU. The XPU kernel calls xpu::gelu_grad&lt;bfloat16&gt;, so all intermediate computation runs in native bfloat16 precision, leading to numerical overflow (max_abs_diff=32768 = 2^15).

Root cause: the GPU kernel uses MPTypeTrait&lt;T&gt;::Type to promote bfloat16 to float32 for all intermediate computation (x_sq, x_cube, tanh, tanh_derivative, etc.) and casts the result back to bfloat16 at the end. The XPU kernel lacks this precision-promotion mechanism.

Fix: in paddle/phi/kernels/xpu/gelu_grad_kernel.cc, call xpu::gelu_grad_highprecision&lt;XPUType&gt; instead of xpu::gelu_grad&lt;XPUType&gt; for bfloat16 and float16 inputs. This high-precision variant performs its intermediate computation in float32 internally, matching the GPU kernel's behavior. float32 inputs keep the original code path.

Verification: with the fix applied, the backward pass yields max_abs_diff=0.004571 (atol=0.01 → PASS) and max_rel_diff=0.004049 (rtol=0.01 → PASS), with no NaN or Inf in the gradients.

Does this change numerical precision?

Yes: the gelu backward precision for bfloat16/float16 on XPU is now aligned with the GPU, fixing the overflow issue.


paddle-bot Bot commented Apr 29, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.
