Skip to content

Commit 794551d

Browse files
authored
[RISCV][llvm] Support PSRA, PSRAI, PSRL, PSRLI codegen for P extension (#171460)
1 parent 6ad0c7c commit 794551d

File tree

4 files changed

+507
-21
lines changed

4 files changed

+507
-21
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
526526
setOperationAction({ISD::AVGFLOORS, ISD::AVGFLOORU}, VTs, Legal);
527527
setOperationAction({ISD::ABDS, ISD::ABDU}, VTs, Legal);
528528
setOperationAction(ISD::SPLAT_VECTOR, VTs, Legal);
529-
setOperationAction(ISD::SHL, VTs, Custom);
529+
setOperationAction({ISD::SHL, ISD::SRL, ISD::SRA}, VTs, Custom);
530530
setOperationAction(ISD::BITCAST, VTs, Custom);
531531
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VTs, Custom);
532532
}
@@ -8662,22 +8662,21 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
86628662
case ISD::VSELECT:
86638663
return lowerToScalableOp(Op, DAG);
86648664
case ISD::SHL:
8665-
if (Subtarget.enablePExtCodeGen() &&
8666-
Op.getSimpleValueType().isFixedLengthVector()) {
8667-
// We have patterns for scalar/immediate shift amount, so no lowering
8668-
// needed.
8669-
if (Op.getOperand(1)->getOpcode() == ISD::SPLAT_VECTOR)
8670-
return Op;
8671-
8672-
// There's no vector-vector version of shift instruction in P extension so
8673-
// we need to unroll to scalar computation and pack them back.
8674-
return DAG.UnrollVectorOp(Op.getNode());
8675-
}
8676-
[[fallthrough]];
8677-
case ISD::SRA:
86788665
case ISD::SRL:
8679-
if (Op.getSimpleValueType().isFixedLengthVector())
8666+
case ISD::SRA:
8667+
if (Op.getSimpleValueType().isFixedLengthVector()) {
8668+
if (Subtarget.enablePExtCodeGen()) {
8669+
// We have patterns for scalar/immediate shift amount, so no lowering
8670+
// needed.
8671+
if (Op.getOperand(1)->getOpcode() == ISD::SPLAT_VECTOR)
8672+
return Op;
8673+
8674+
// There's no vector-vector version of shift instruction in P extension
8675+
// so we need to unroll to scalar computation and pack them back.
8676+
return DAG.UnrollVectorOp(Op.getNode());
8677+
}
86808678
return lowerToScalableOp(Op, DAG);
8679+
}
86818680
// This can be called for an i32 shift amount that needs to be promoted.
86828681
assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
86838682
"Unexpected custom legalisation");

llvm/lib/Target/RISCV/RISCVInstrInfoP.td

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,26 +1513,55 @@ let Predicates = [HasStdExtP] in {
15131513
def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
15141514
def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
15151515

1516-
// 8-bit logical shift left patterns
1516+
// 8-bit logical shift left/right patterns
15171517
def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
15181518
(PSLLI_B GPR:$rs1, uimm3:$shamt)>;
1519+
def: Pat<(XLenVecI8VT (srl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
1520+
(PSRLI_B GPR:$rs1, uimm3:$shamt)>;
15191521

1520-
// 16-bit logical shift left patterns
1522+
// 16-bit logical shift left/right patterns
15211523
def: Pat<(XLenVecI16VT (shl GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
15221524
(PSLLI_H GPR:$rs1, uimm4:$shamt)>;
1525+
def: Pat<(XLenVecI16VT (srl GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
1526+
(PSRLI_H GPR:$rs1, uimm4:$shamt)>;
1527+
1528+
// 8-bit arithmetic shift right patterns
1529+
def: Pat<(XLenVecI8VT (sra GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
1530+
(PSRAI_B GPR:$rs1, uimm3:$shamt)>;
1531+
1532+
// 16-bit arithmetic shift right patterns
1533+
def: Pat<(XLenVecI16VT (sra GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
1534+
(PSRAI_H GPR:$rs1, uimm4:$shamt)>;
15231535

15241536
// 16-bit signed saturation shift left patterns
15251537
def: Pat<(XLenVecI16VT (sshlsat GPR:$rs1, (XLenVecI16VT (splat_vector uimm4:$shamt)))),
15261538
(PSSLAI_H GPR:$rs1, uimm4:$shamt)>;
15271539

1528-
// 8-bit logical shift left
1540+
// 8-bit logical shift left/right
15291541
def: Pat<(XLenVecI8VT (shl GPR:$rs1,
15301542
(XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
15311543
(PSLL_BS GPR:$rs1, GPR:$rs2)>;
1532-
// 16-bit logical shift left
1544+
def: Pat<(XLenVecI8VT (srl GPR:$rs1,
1545+
(XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
1546+
(PSRL_BS GPR:$rs1, GPR:$rs2)>;
1547+
1548+
// 8-bit arithmetic shift left/right
1549+
def: Pat<(XLenVecI8VT (sra GPR:$rs1,
1550+
(XLenVecI8VT (splat_vector (XLenVT GPR:$rs2))))),
1551+
(PSRA_BS GPR:$rs1, GPR:$rs2)>;
1552+
1553+
// 16-bit logical shift left/right
15331554
def: Pat<(XLenVecI16VT (shl GPR:$rs1,
15341555
(XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
15351556
(PSLL_HS GPR:$rs1, GPR:$rs2)>;
1557+
def: Pat<(XLenVecI16VT (srl GPR:$rs1,
1558+
(XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
1559+
(PSRL_HS GPR:$rs1, GPR:$rs2)>;
1560+
1561+
// 16-bit arithmetic shift left/right
1562+
def: Pat<(XLenVecI16VT (sra GPR:$rs1,
1563+
(XLenVecI16VT (splat_vector (XLenVT GPR:$rs2))))),
1564+
(PSRA_HS GPR:$rs1, GPR:$rs2)>;
15361565

15371566
// 8-bit PLI SD node pattern
15381567
def: Pat<(XLenVecI8VT (splat_vector simm8_unsigned:$imm8)), (PLI_B simm8_unsigned:$imm8)>;
@@ -1580,16 +1609,28 @@ let Predicates = [HasStdExtP, IsRV64] in {
15801609
def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
15811610
def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
15821611

1583-
// 32-bit logical shift left
1612+
// 32-bit logical shift left/right
15841613
def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
15851614
(PSLL_WS GPR:$rs1, GPR:$rs2)>;
1615+
def: Pat<(v2i32 (srl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
1616+
(PSRL_WS GPR:$rs1, GPR:$rs2)>;
1617+
1618+
// 32-bit arithmetic shift left/right
1619+
def: Pat<(v2i32 (sra GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
1620+
(PSRA_WS GPR:$rs1, GPR:$rs2)>;
15861621

15871622
// splat pattern
15881623
def: Pat<(v2i32 (splat_vector (XLenVT GPR:$rs2))), (PADD_WS (XLenVT X0), GPR:$rs2)>;
15891624

1590-
// 32-bit logical shift left patterns
1625+
// 32-bit logical shift left/right patterns
15911626
def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
15921627
(PSLLI_W GPR:$rs1, uimm5:$shamt)>;
1628+
def: Pat<(v2i32 (srl GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
1629+
(PSRLI_W GPR:$rs1, uimm5:$shamt)>;
1630+
1631+
// 32-bit arithmetic shift left/right patterns
1632+
def: Pat<(v2i32 (sra GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),
1633+
(PSRAI_W GPR:$rs1, uimm5:$shamt)>;
15931634

15941635
// 32-bit signed saturation shift left patterns
15951636
def: Pat<(v2i32 (sshlsat GPR:$rs1, (v2i32 (splat_vector uimm5:$shamt)))),

0 commit comments

Comments
 (0)