Skip to content

Commit 448ac1f

Browse files
AMDGPU/GlobalISel: Fix broken exp10 lowering for f16 (#170708)
1 parent c347b26 commit 448ac1f

File tree

3 files changed

+413
-337
lines changed

3 files changed

+413
-337
lines changed

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3728,24 +3728,39 @@ bool AMDGPULegalizerInfo::legalizeFExp2(MachineInstr &MI,
37283728
return true;
37293729
}
37303730

3731+
/// Build an exp2 of \p Src into \p Dst and tag the new instruction with
/// \p Flags.
///
/// The destination's LLT selects the form: a 32-bit scalar is emitted as the
/// amdgcn_exp2 intrinsic, anything else goes through
/// MachineIRBuilder::buildFExp2.
///
/// \returns the builder for the instruction that was created.
static MachineInstrBuilder buildExp(MachineIRBuilder &B, const DstOp &Dst,
                                    const SrcOp &Src, unsigned Flags) {
  const LLT S32 = LLT::scalar(32);
  const bool IsS32 = Dst.getLLTTy(*B.getMRI()) == S32;

  // Non-s32 destinations take the generic path.
  if (!IsS32)
    return B.buildFExp2(Dst, Src, Flags);

  return B.buildIntrinsic(Intrinsic::amdgcn_exp2, {Dst})
      .addUse(Src.getReg())
      .setMIFlags(Flags);
}
3742+
3743+
/// Lower exp / exp10 of \p X into \p Dst as one multiply followed by one exp2,
/// with no additional range handling:
///
///   exp(x)   -> exp2(log2(e)  * x)
///   exp10(x) -> exp2(log2(10) * x)
///
/// \p IsExp10 selects the multiplier; \p Flags is applied to the multiply and
/// to the exp2. The working type is taken from \p X.
///
/// \returns true unconditionally.
bool AMDGPULegalizerInfo::legalizeFExpUnsafeImpl(MachineIRBuilder &B,
                                                 Register Dst, Register X,
                                                 unsigned Flags,
                                                 bool IsExp10) const {
  const LLT Ty = B.getMRI()->getType(X);

  // Multiplier that converts the requested base into base 2.
  double Log2OfBase = numbers::log2e;
  if (IsExp10)
    Log2OfBase = 0x1.a934f0p+1f; // log2(10), written as a float literal.

  auto Scale = B.buildFConstant(Ty, Log2OfBase);
  auto ScaledX = B.buildFMul(Ty, X, Scale, Flags);
  buildExp(B, Dst, ScaledX, Flags);
  return true;
}
3756+
37313757
bool AMDGPULegalizerInfo::legalizeFExpUnsafe(MachineIRBuilder &B, Register Dst,
37323758
Register X, unsigned Flags) const {
37333759
LLT Ty = B.getMRI()->getType(Dst);
37343760
LLT F32 = LLT::scalar(32);
37353761

37363762
if (Ty != F32 || !needsDenormHandlingF32(B.getMF(), X, Flags)) {
3737-
auto Log2E = B.buildFConstant(Ty, numbers::log2e);
3738-
auto Mul = B.buildFMul(Ty, X, Log2E, Flags);
3739-
3740-
if (Ty == F32) {
3741-
B.buildIntrinsic(Intrinsic::amdgcn_exp2, ArrayRef<Register>{Dst})
3742-
.addUse(Mul.getReg(0))
3743-
.setMIFlags(Flags);
3744-
} else {
3745-
B.buildFExp2(Dst, Mul.getReg(0), Flags);
3746-
}
3747-
3748-
return true;
3763+
return legalizeFExpUnsafeImpl(B, Dst, X, Flags, /*IsExp10=*/false);
37493764
}
37503765

37513766
auto Threshold = B.buildFConstant(Ty, -0x1.5d58a0p+6f);
@@ -3768,6 +3783,55 @@ bool AMDGPULegalizerInfo::legalizeFExpUnsafe(MachineIRBuilder &B, Register Dst,
37683783
return true;
37693784
}
37703785

3786+
/// Lower exp10 of \p X into \p Dst as a product of two exp2 calls:
///
///   exp2(x * 0x1.a92000p+1f) * exp2(x * 0x1.4f0978p-11f)
///
/// The two multipliers add up to approximately log2(10).
///
/// If the type is not s32, or needsDenormHandlingF32 reports that no special
/// handling is required, only that product is emitted. Otherwise inputs below
/// a threshold are shifted up by 32 before the exp2 calls and the result is
/// multiplied by 0x1.9f623ep-107f (about 10^-32) to undo the shift.
///
/// \returns true unconditionally.
bool AMDGPULegalizerInfo::legalizeFExp10Unsafe(MachineIRBuilder &B,
                                               Register Dst, Register X,
                                               unsigned Flags) const {
  const LLT Ty = B.getMRI()->getType(Dst);

  // Emits exp2(In * 0x1.a92000p+1f) * exp2(In * 0x1.4f0978p-11f) into Out.
  // Build order is fixed: both constants, then the small-multiplier mul/exp2,
  // then the large-multiplier mul/exp2, then the final multiply.
  auto EmitExp2Product = [&](const DstOp &Out, const SrcOp &In) {
    auto BigK = B.buildFConstant(Ty, 0x1.a92000p+1f);
    auto SmallK = B.buildFConstant(Ty, 0x1.4f0978p-11f);

    auto SmallMul = B.buildFMul(Ty, In, SmallK, Flags);
    auto SmallExp = buildExp(B, Ty, SmallMul, Flags);
    auto BigMul = B.buildFMul(Ty, In, BigK, Flags);
    auto BigExp = buildExp(B, Ty, BigMul, Flags);
    return B.buildFMul(Out, BigExp, SmallExp, Flags);
  };

  if (Ty != LLT::scalar(32) || !needsDenormHandlingF32(B.getMF(), X, Flags)) {
    // No input rescaling: write the product straight into Dst.
    EmitExp2Product(Dst, X);
    return true;
  }

  // bool s = x < -0x1.2f7030p+5f;
  // x += s ? 0x1.0p+5f : 0.0f;
  // exp10 = exp2(x * 0x1.a92000p+1f) *
  //         exp2(x * 0x1.4f0978p-11f) *
  //         (s ? 0x1.9f623ep-107f : 1.0f);

  auto Limit = B.buildFConstant(Ty, -0x1.2f7030p+5f);
  auto BelowLimit = B.buildFCmp(CmpInst::FCMP_OLT, LLT::scalar(1), X, Limit);

  // Shift small inputs up by 32; larger inputs pass through unchanged.
  auto Shift = B.buildFConstant(Ty, 0x1.0p+5f);
  auto ShiftedX = B.buildFAdd(Ty, X, Shift, Flags);
  auto Input = B.buildSelect(Ty, BelowLimit, ShiftedX, X);

  auto Product = EmitExp2Product(Ty, Input);

  // Compensate for the +32 input shift on the lanes that took it.
  auto Compensation = B.buildFConstant(Ty, 0x1.9f623ep-107f);
  auto Compensated = B.buildFMul(Ty, Product, Compensation, Flags);

  B.buildSelect(Dst, BelowLimit, Compensated, Product);
  return true;
}
3834+
37713835
bool AMDGPULegalizerInfo::legalizeFExp(MachineInstr &MI,
37723836
MachineIRBuilder &B) const {
37733837
Register Dst = MI.getOperand(0).getReg();
@@ -3784,18 +3848,22 @@ bool AMDGPULegalizerInfo::legalizeFExp(MachineInstr &MI,
37843848
// v_exp_f16 (fmul x, log2e)
37853849
if (allowApproxFunc(MF, Flags)) {
37863850
// TODO: Does this really require fast?
3787-
legalizeFExpUnsafe(B, Dst, X, Flags);
3851+
IsExp10 ? legalizeFExp10Unsafe(B, Dst, X, Flags)
3852+
: legalizeFExpUnsafe(B, Dst, X, Flags);
37883853
MI.eraseFromParent();
37893854
return true;
37903855
}
37913856

3857+
// Nothing in half is a denormal when promoted to f32.
3858+
//
37923859
// exp(f16 x) ->
37933860
// fptrunc (v_exp_f32 (fmul (fpext x), log2e))
3794-
3795-
// Nothing in half is a denormal when promoted to f32.
3861+
//
3862+
// exp10(f16 x) ->
3863+
// fptrunc (v_exp_f32 (fmul (fpext x), log2(10)))
37963864
auto Ext = B.buildFPExt(F32, X, Flags);
37973865
Register Lowered = MRI.createGenericVirtualRegister(F32);
3798-
legalizeFExpUnsafe(B, Lowered, Ext.getReg(0), Flags);
3866+
legalizeFExpUnsafeImpl(B, Lowered, Ext.getReg(0), Flags, IsExp10);
37993867
B.buildFPTrunc(Dst, Lowered, Flags);
38003868
MI.eraseFromParent();
38013869
return true;
@@ -3806,7 +3874,8 @@ bool AMDGPULegalizerInfo::legalizeFExp(MachineInstr &MI,
38063874
// TODO: Interpret allowApproxFunc as ignoring DAZ. This is currently copying
38073875
// library behavior. Also, is known-not-daz source sufficient?
38083876
if (allowApproxFunc(MF, Flags)) {
3809-
legalizeFExpUnsafe(B, Dst, X, Flags);
3877+
IsExp10 ? legalizeFExp10Unsafe(B, Dst, X, Flags)
3878+
: legalizeFExpUnsafe(B, Dst, X, Flags);
38103879
MI.eraseFromParent();
38113880
return true;
38123881
}

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,12 @@ class AMDGPULegalizerInfo final : public LegalizerInfo {
9191
bool legalizeFlogUnsafe(MachineIRBuilder &B, Register Dst, Register Src,
9292
bool IsLog10, unsigned Flags) const;
9393
bool legalizeFExp2(MachineInstr &MI, MachineIRBuilder &B) const;
94+
bool legalizeFExpUnsafeImpl(MachineIRBuilder &B, Register Dst, Register Src,
95+
unsigned Flags, bool IsExp10) const;
9496
bool legalizeFExpUnsafe(MachineIRBuilder &B, Register Dst, Register Src,
9597
unsigned Flags) const;
98+
bool legalizeFExp10Unsafe(MachineIRBuilder &B, Register Dst, Register Src,
99+
unsigned Flags) const;
96100
bool legalizeFExp(MachineInstr &MI, MachineIRBuilder &B) const;
97101
bool legalizeFPow(MachineInstr &MI, MachineIRBuilder &B) const;
98102
bool legalizeFFloor(MachineInstr &MI, MachineRegisterInfo &MRI,

0 commit comments

Comments
 (0)