diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6540273b216e3..654aff71f25be 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -3444,6 +3444,70 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> { let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// NVVM dot.accumulate.4way Op +//===----------------------------------------------------------------------===// + +def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">; +def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">; + +def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType", + "NVVM DotAccumulate4WayType", + [DotAccumulate4WayS8, DotAccumulate4WayU8]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def DotAccumulate4WayTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> { + let summary = "Four-way byte dot product-accumulate instruction."; + let description = [{ + Performs a four-way byte dot-product which is accumulated in a 32-bit + result. + Operand `a` and `b` are vectors of 4 bytes between which the dot product is + computed. + The `a_type` and `b_type` attributes specify the type of the elements in `a` + and `b` respectively. + If `a_type` or `b_type` is `s8`, then the elements in the corresponding + vector are sign-extended to 32-bit before the dot product is computed. + If `a_type` or `b_type` is `u8`, then the elements in the corresponding + vector are zero-extended to 32-bit instead. + Operand `c` is a 32-bit integer to which the result is accumulated. It is + treated as holding a signed integer if any of `a_type` or `b_type` is `s8`. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a) + }]; + + let arguments = (ins + VectorOfLengthAndType<[4], [I8]>:$a, + DotAccumulate4WayTypeAttr:$a_type, + VectorOfLengthAndType<[4], [I8]>:$b, + DotAccumulate4WayTypeAttr:$b_type, + I32:$c + ); + + let results = (outs I32:$res); + + let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)"; + + let extraClassDeclaration = [{ + static llvm::Intrinsic::ID + getIntrinsicID(NVVM::DotAccumulate4WayType a_type, + NVVM::DotAccumulate4WayType b_type); + llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder); + }]; + + string llvmBuilder = [{ + llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type); + llvm::Value* argA = op.getPackedArg($a, builder); + llvm::Value* argB = op.getPackedArg($b, builder); + $res = createIntrinsicCall(builder, id, {argA, argB, $c}); + }]; +} + //===----------------------------------------------------------------------===// // NVVM target attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 3c3731a63e268..1ea3f96fa75f5 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -33,6 +33,7 @@ #include "llvm/AsmParser/Parser.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/Type.h" #include "llvm/Support/Casting.h" @@ -1203,6 +1204,13 @@ LogicalResult NVVM::VoteSyncOp::verify() { return success(); } +llvm::Value * +NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg, + llvm::IRBuilderBase &builder) { + return builder.CreateBitCast(arg, + llvm::Type::getInt32Ty(builder.getContext())); +} + //===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// @@ -1590,6 +1598,26 @@ static void nvvmInferResultRanges(Operation *op, Value result, } } +llvm::Intrinsic::ID +DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type, + NVVM::DotAccumulate4WayType b_type) { + bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8; + bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8; + unsigned type = (is_a_siext << 1) | is_b_siext; + switch (type) { + case 0: + return llvm::Intrinsic::nvvm_idp4a_u_u; + case 1: + return llvm::Intrinsic::nvvm_idp4a_u_s; + case 2: + return llvm::Intrinsic::nvvm_idp4a_s_u; + case 3: + return llvm::Intrinsic::nvvm_idp4a_s_s; + default: + llvm_unreachable("Invalid DP4a type"); + } +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index d3915492c38a0..e8425638cc9be 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -578,6 +578,15 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64) return } +// CHECK-LABEL: @dot_accumulate_4way +func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) { + // CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8> + %1 = nvvm.dot.accumulate.4way %a_vec , %b_vec , %c: vector<4xi8>, vector<4xi8> + // CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8> + %3 = nvvm.dot.accumulate.4way %a_vec , %b_vec , %c: vector<4xi8>, vector<4xi8> + return +} + // ----- // Just check these don't emit errors. diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 3a0713f2feee8..894b72733a46a 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -844,3 +844,25 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3> llvm.return } + +// ----- +// CHECK-LABEL: @nvvm_dot_accumulate_4way +llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) { + // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 + // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 + // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}}) + %0 = nvvm.dot.accumulate.4way %a , %b , %c: vector<4xi8>, vector<4xi8> + // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 + // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 + // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}}) + %1 = nvvm.dot.accumulate.4way %a , %b , %c: vector<4xi8>, vector<4xi8> + // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 + // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 + // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}}) + %2 = nvvm.dot.accumulate.4way %a , %b , %c: vector<4xi8>, vector<4xi8> + // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 + // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 + // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}}) + %3 = nvvm.dot.accumulate.4way %a , %b , %c: vector<4xi8>, vector<4xi8> + llvm.return +}