Skip to content

Commit 7c49ab0

Browse files
authored
[MLIR][NVVM] Add dot.accumulate.4way OP (#139043)
This change adds the `dot.accumulate.4way` Op to the NVVM dialect to perform a four-way byte dot product-accumulate operation. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a
1 parent 48a814c commit 7c49ab0

File tree

4 files changed

+123
-0
lines changed

4 files changed

+123
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

+64
Original file line numberDiff line numberDiff line change
@@ -3444,6 +3444,70 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
34443444
let hasVerifier = 1;
34453445
}
34463446

3447+
//===----------------------------------------------------------------------===//
3448+
// NVVM dot.accumulate.4way Op
3449+
//===----------------------------------------------------------------------===//
3450+
3451+
// Enum cases naming how the packed bytes of an operand are interpreted:
// `u8` has value 0 and `s8` has value 1.
def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;

// The 32-bit enum itself, placed in the ::mlir::NVVM C++ namespace.
def DotAccumulate4WayType
    : I32EnumAttr<"DotAccumulate4WayType", "NVVM DotAccumulate4WayType",
                  [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
  let cppNamespace = "::mlir::NVVM";
  // No specialized attribute is generated here; the attribute is declared by
  // the EnumAttr definition that follows.
  let genSpecializedAttr = 0;
}

// Dialect attribute wrapping the enum; written as `<u8>` / `<s8>` in assembly.
def DotAccumulate4WayTypeAttr
    : EnumAttr<NVVM_Dialect, DotAccumulate4WayType,
               "dot_accumulate_4way_type"> {
  let assemblyFormat = "`<` $value `>`";
}
3464+
3465+
// Op definition for the four-way byte dot product-accumulate. It takes two
// 4 x i8 vectors, a signedness attribute for each, and an i32 accumulator,
// and produces an i32 result.
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
  let summary = "Four-way byte dot product-accumulate instruction.";
  let description = [{
    Performs a four-way byte dot-product which is accumulated in a 32-bit
    result.
    Operands `a` and `b` are vectors of 4 bytes between which the dot product is
    computed.
    The `a_type` and `b_type` attributes specify the type of the elements in `a`
    and `b` respectively.
    If `a_type` or `b_type` is `s8`, then the elements in the corresponding
    vector are sign-extended to 32-bit before the dot product is computed.
    If `a_type` or `b_type` is `u8`, then the elements in the corresponding
    vector are zero-extended to 32-bit instead.
    Operand `c` is a 32-bit integer to which the result is accumulated. It is
    treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
  }];

  let arguments = (ins
    VectorOfLengthAndType<[4], [I8]>:$a,
    DotAccumulate4WayTypeAttr:$a_type,
    VectorOfLengthAndType<[4], [I8]>:$b,
    DotAccumulate4WayTypeAttr:$b_type,
    I32:$c
  );

  let results = (outs I32:$res);

  // Each vector operand is immediately followed by its element-type attribute,
  // e.g. `%a <u8>, %b <s8>, %c : vector<4xi8>, vector<4xi8>`.
  let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID
    getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
                   NVVM::DotAccumulate4WayType b_type);
    llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
  }];

  // LLVM IR translation: select the intrinsic from the two type attributes,
  // run both vector operands through getPackedArg, then call the intrinsic
  // with the two packed values and the accumulator.
  string llvmBuilder = [{
    llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
    llvm::Value* argA = op.getPackedArg($a, builder);
    llvm::Value* argB = op.getPackedArg($b, builder);
    $res = createIntrinsicCall(builder, id, {argA, argB, $c});
  }];
}
3510+
34473511
//===----------------------------------------------------------------------===//
34483512
// NVVM target attribute.
34493513
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

+28
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/AsmParser/Parser.h"
3434
#include "llvm/IR/Attributes.h"
3535
#include "llvm/IR/Function.h"
36+
#include "llvm/IR/IRBuilder.h"
3637
#include "llvm/IR/IntrinsicsNVPTX.h"
3738
#include "llvm/IR/Type.h"
3839
#include "llvm/Support/Casting.h"
@@ -1203,6 +1204,13 @@ LogicalResult NVVM::VoteSyncOp::verify() {
12031204
return success();
12041205
}
12051206

1207+
/// Emits a bitcast of `arg` to a 32-bit integer and returns the resulting
/// value.
///
/// \param arg     value to reinterpret; the op's LLVM builder passes its
///                `a` and `b` operands here.
/// \param builder IR builder used to create the cast instruction.
/// \return the `i32` value produced by the bitcast.
llvm::Value *
NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
                                        llvm::IRBuilderBase &builder) {
  llvm::Type *packedTy = builder.getInt32Ty();
  return builder.CreateBitCast(arg, packedTy);
}
1213+
12061214
//===----------------------------------------------------------------------===//
12071215
// getIntrinsicID/getIntrinsicIDAndArgs methods
12081216
//===----------------------------------------------------------------------===//
@@ -1590,6 +1598,26 @@ static void nvvmInferResultRanges(Operation *op, Value result,
15901598
}
15911599
}
15921600

1601+
/// Selects the `idp4a` intrinsic variant matching the element types of the
/// two vector operands.
///
/// The intrinsic name suffix is `<a>.<b>`, where each position is `s` when the
/// corresponding type is `S8` and `u` otherwise.
///
/// \param a_type element type of operand `a`.
/// \param b_type element type of operand `b`.
/// \return the ID of the matching `llvm.nvvm.idp4a.*` intrinsic.
llvm::Intrinsic::ID
DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
                                    NVVM::DotAccumulate4WayType b_type) {
  const bool aIsSigned = a_type == NVVM::DotAccumulate4WayType::S8;
  const bool bIsSigned = b_type == NVVM::DotAccumulate4WayType::S8;

  if (aIsSigned)
    return bIsSigned ? llvm::Intrinsic::nvvm_idp4a_s_s
                     : llvm::Intrinsic::nvvm_idp4a_s_u;

  return bIsSigned ? llvm::Intrinsic::nvvm_idp4a_u_s
                   : llvm::Intrinsic::nvvm_idp4a_u_u;
}
1620+
15931621
//===----------------------------------------------------------------------===//
15941622
// NVVMDialect initialization, type parsing, and registration.
15951623
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/nvvm.mlir

+9
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,15 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
578578
return
579579
}
580580

581+
// Verifies that the op is accepted with both element-type spellings and that
// the <u8>/<s8> attributes appear in the printed form.
// CHECK-LABEL: @dot_accumulate_4way
func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
  // CHECK: nvvm.dot.accumulate.4way %{{.*}}<u8>, %{{.*}}<u8>, %{{.*}} : vector<4xi8>, vector<4xi8>
  %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
  // CHECK: nvvm.dot.accumulate.4way %{{.*}}<s8>, %{{.*}}<s8>, %{{.*}} : vector<4xi8>, vector<4xi8>
  %3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
  return
}
589+
581590
// -----
582591

583592
// Just check these don't emit errors.

mlir/test/Target/LLVMIR/nvvmir.mlir

+22
Original file line numberDiff line numberDiff line change
@@ -844,3 +844,25 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size:
844844
nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
845845
llvm.return
846846
}
847+
848+
// -----
849+
// Translation test: each of the four (a_type, b_type) combinations must emit
// two <4 x i8> -> i32 bitcasts followed by a call to the matching
// llvm.nvvm.idp4a.* intrinsic.
// CHECK-LABEL: @nvvm_dot_accumulate_4way
llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) {
  // a = u8, b = u8
  // CHECK: %[[PACKED_A:.*]] = bitcast <4 x i8> %{{.*}} to i32
  // CHECK: %[[PACKED_B:.*]] = bitcast <4 x i8> %{{.*}} to i32
  // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[PACKED_A]], i32 %[[PACKED_B]], i32 %{{.*}})
  %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>

  // a = s8, b = u8
  // CHECK: %[[PACKED_A:.*]] = bitcast <4 x i8> %{{.*}} to i32
  // CHECK: %[[PACKED_B:.*]] = bitcast <4 x i8> %{{.*}} to i32
  // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[PACKED_A]], i32 %[[PACKED_B]], i32 %{{.*}})
  %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>

  // a = u8, b = s8
  // CHECK: %[[PACKED_A:.*]] = bitcast <4 x i8> %{{.*}} to i32
  // CHECK: %[[PACKED_B:.*]] = bitcast <4 x i8> %{{.*}} to i32
  // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[PACKED_A]], i32 %[[PACKED_B]], i32 %{{.*}})
  %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>

  // a = s8, b = s8
  // CHECK: %[[PACKED_A:.*]] = bitcast <4 x i8> %{{.*}} to i32
  // CHECK: %[[PACKED_B:.*]] = bitcast <4 x i8> %{{.*}} to i32
  // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[PACKED_A]], i32 %[[PACKED_B]], i32 %{{.*}})
  %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
  llvm.return
}

0 commit comments

Comments
 (0)