Skip to content

Commit 9bd957d

Browse files
committed
[MLIR][NVVM] Add support for dp4a instructions
This change adds the `dp4a` Op to the NVVM dialect to perform a four-way byte dot product-accumulate operation. For more information, see the PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a
1 parent 20d6375 commit 9bd957d

File tree

4 files changed

+105
-0
lines changed

4 files changed

+105
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

+48
Original file line numberDiff line numberDiff line change
@@ -3444,6 +3444,54 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
34443444
let hasVerifier = 1;
34453445
}
34463446

3447+
//===----------------------------------------------------------------------===//
3448+
// NVVM dp4a Op
3449+
//===----------------------------------------------------------------------===//
3450+
3451+
def NVVM_Dp4aOp : NVVM_Op<"dp4a"> {
3452+
let summary = "Four-way byte dot product-accumulate instruction.";
3453+
let description = [{
3454+
Performs a four-way byte dot-product which is accumulated in a 32-bit
3455+
result.
3456+
Operands `a` and `b` can be passed either as packed 32-bit inputs holding
3457+
4 byte-inputs for the dot product, or as vectors of 4 i8 elements.
3458+
The `a_signed` and `b_signed` unit attributes specify whether the
3459+
individual byte inputs in operands `a` and `b` are signed or unsigned
3460+
respectively.
3461+
Operand `c` is a 32-bit integer to which the result is accumulated. It is
3462+
treated as holding a signed integer if either `a` or `b` is signed.
3463+
3464+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
3465+
}];
3466+
3467+
let arguments = (ins
3468+
AnyTypeOf<[I32, VectorOfLengthAndType<[4], [I8]>]>:$a,
3469+
AnyTypeOf<[I32, VectorOfLengthAndType<[4], [I8]>]>:$b,
3470+
I32:$c,
3471+
DefaultValuedAttr<UnitAttr, "false">:$a_signed,
3472+
DefaultValuedAttr<UnitAttr, "false">:$b_signed
3473+
);
3474+
3475+
let results = (outs I32:$res);
3476+
3477+
let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($a) `,` type($b)";
3478+
3479+
let extraClassDeclaration = [{
3480+
static llvm::Intrinsic::ID getIntrinsicID(bool a_signed, bool b_signed);
3481+
}];
3482+
3483+
string llvmBuilder = [{
3484+
auto id = NVVM::Dp4aOp::getIntrinsicID($a_signed, $b_signed);
3485+
llvm::Value* argA = $a;
3486+
llvm::Value* argB = $b;
3487+
// The intrinsic is called with i32 multiplicands, so an operand that is not
// already i32 (i.e. the vector-of-4-i8 form) is bitcast to i32 first.
if (!op.getA().getType().isInteger(32))
3488+
argA = builder.CreateBitCast(argA, llvm::Type::getInt32Ty(builder.getContext()));
3489+
if (!op.getB().getType().isInteger(32))
3490+
argB = builder.CreateBitCast(argB, llvm::Type::getInt32Ty(builder.getContext()));
3491+
$res = createIntrinsicCall(builder, id, {argA, argB, $c});
3492+
}];
3493+
}
3494+
34473495
//===----------------------------------------------------------------------===//
34483496
// NVVM target attribute.
34493497
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -1590,6 +1590,14 @@ static void nvvmInferResultRanges(Operation *op, Value result,
15901590
}
15911591
}
15921592

1593+
#define GET_DP4A_ID(a_sign, is_b_signed) \
1594+
is_b_signed ? llvm::Intrinsic::nvvm_idp4a_##a_sign##_s \
1595+
: llvm::Intrinsic::nvvm_idp4a_##a_sign##_u
1596+
1597+
llvm::Intrinsic::ID Dp4aOp::getIntrinsicID(bool a_signed, bool b_signed) {
1598+
return a_signed ? GET_DP4A_ID(s, b_signed) : GET_DP4A_ID(u, b_signed);
1599+
}
1600+
15931601
//===----------------------------------------------------------------------===//
15941602
// NVVMDialect initialization, type parsing, and registration.
15951603
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

+13
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,19 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
578578
return
579579
}
580580

581+
// Exercises the nvvm.dp4a assembly format with packed i32 and vector<4xi8>
// operands, with and without the signedness attributes.
// CHECK-LABEL: @dp4a_packed
582+
func.func @dp4a_packed(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
583+
// CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : i32, i32
584+
%0 = nvvm.dp4a %a, %b, %c: i32, i32
585+
// CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
586+
%1 = nvvm.dp4a %a_vec, %b_vec, %c: vector<4xi8>, vector<4xi8>
587+
// CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_signed, b_signed} : i32, i32
588+
%2 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: i32, i32
589+
// CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_signed, b_signed} : vector<4xi8>, vector<4xi8>
590+
%3 = nvvm.dp4a %a_vec, %b_vec, %c {a_signed, b_signed}: vector<4xi8>, vector<4xi8>
591+
return
592+
}
593+
581594
// -----
582595

583596
// Just check these don't emit errors.

mlir/test/Target/LLVMIR/nvvmir.mlir

+36
Original file line numberDiff line numberDiff line change
@@ -844,3 +844,39 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size:
844844
nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
845845
llvm.return
846846
}
847+
848+
// -----
849+
// Translation of the packed (i32) form: each a_signed/b_signed combination is
// expected to lower to the matching llvm.nvvm.idp4a.{u|s}.{u|s} intrinsic call.
// CHECK-LABEL: @nvvm_dp4a_packed
850+
llvm.func @nvvm_dp4a_packed(%a: i32, %b: i32, %c: i32) {
851+
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
852+
%0 = nvvm.dp4a %a, %b, %c: i32, i32
853+
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
854+
%1 = nvvm.dp4a %a, %b, %c {a_signed}: i32, i32
855+
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
856+
%2 = nvvm.dp4a %a, %b, %c {b_signed}: i32, i32
857+
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
858+
%3 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: i32, i32
859+
llvm.return
860+
}
861+
862+
// -----
863+
// Translation of the vector<4xi8> form: for every signedness combination, both
// multiplicands are expected to be bitcast to i32 before the intrinsic call.
// CHECK-LABEL: @nvvm_dp4a_vec
864+
llvm.func @nvvm_dp4a_vec(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) {
865+
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
866+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
867+
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
868+
%0 = nvvm.dp4a %a, %b, %c: vector<4xi8>, vector<4xi8>
869+
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
870+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
871+
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
872+
%1 = nvvm.dp4a %a, %b, %c {a_signed}: vector<4xi8>, vector<4xi8>
873+
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
874+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
875+
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
876+
%2 = nvvm.dp4a %a, %b, %c {b_signed}: vector<4xi8>, vector<4xi8>
877+
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
878+
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
879+
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
880+
%3 = nvvm.dp4a %a_vec, %b, %c {a_signed, b_signed}: vector<4xi8>, vector<4xi8>
881+
llvm.return
882+
}

0 commit comments

Comments
 (0)