Lei Zhang | 930c74f | 2020-12-23 19:32:31 | [diff] [blame] | 1 | //===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===// |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
// See https://llvm.org/LICENSE.txt for license information.
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h" |
River Riddle | 23aa5a7 | 2022-02-26 22:49:54 | [diff] [blame^] | 10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
gysit | b7f2c10 | 2021-12-15 12:14:35 | [diff] [blame] | 11 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 12 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
Lei Zhang | 0117865 | 2020-12-17 15:55:45 | [diff] [blame] | 13 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| 14 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 15 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 16 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 17 | #include "mlir/IR/AffineExpr.h" |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 18 | #include "mlir/Transforms/DialectConversion.h" |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 19 | |
| 20 | using namespace mlir; |
| 21 | |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | // Utilities |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | |
| 26 | /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V |
| 27 | /// location invocation ID. This function will create necessary operations with |
| 28 | /// `builder` at the proper region containing `op`. |
Butygin | 1e35a76 | 2021-08-14 08:57:02 | [diff] [blame] | 29 | static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType, |
| 30 | Location loc, OpBuilder *builder) { |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 31 | assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions"); |
| 32 | Value invocation = spirv::getBuiltinVariableValue( |
Butygin | 1e35a76 | 2021-08-14 08:57:02 | [diff] [blame] | 33 | op, spirv::BuiltIn::LocalInvocationId, integerType, *builder); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 34 | Type xType = invocation.getType().cast<ShapedType>().getElementType(); |
| 35 | return builder->create<spirv::CompositeExtractOp>( |
| 36 | loc, xType, invocation, builder->getI32ArrayAttr({dim})); |
| 37 | } |
| 38 | |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 39 | //===----------------------------------------------------------------------===// |
| 40 | // Reduction (single workgroup) |
| 41 | //===----------------------------------------------------------------------===// |
| 42 | |
| 43 | namespace { |
| 44 | |
| 45 | /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition |
| 46 | /// that the linalg.generic op is performing reduction with a workload size that |
| 47 | /// can fit in one workgroup. |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 48 | struct SingleWorkgroupReduction final |
| 49 | : public OpConversionPattern<linalg::GenericOp> { |
| 50 | using OpConversionPattern::OpConversionPattern; |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 51 | |
| 52 | /// Matches the given linalg.generic op as performing reduction and returns |
| 53 | /// the binary op kind if successful. |
| 54 | static Optional<linalg::RegionMatcher::BinaryOpKind> |
| 55 | matchAsPerformingReduction(linalg::GenericOp genericOp); |
| 56 | |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 57 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 58 | matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor, |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 59 | ConversionPatternRewriter &rewriter) const override; |
| 60 | }; |
| 61 | |
| 62 | } // namespace |
| 63 | |
| 64 | Optional<linalg::RegionMatcher::BinaryOpKind> |
| 65 | SingleWorkgroupReduction::matchAsPerformingReduction( |
| 66 | linalg::GenericOp genericOp) { |
| 67 | Operation *op = genericOp.getOperation(); |
| 68 | |
| 69 | // Make sure the linalg.generic is working on memrefs. |
| 70 | if (!genericOp.hasBufferSemantics()) |
| 71 | return llvm::None; |
| 72 | |
Kazuaki Ishizaki | e5a8512 | 2020-03-26 18:51:37 | [diff] [blame] | 73 | // Make sure this is reduction with one input and one output. |
Nicolas Vasilache | ed22913 | 2020-09-21 19:30:42 | [diff] [blame] | 74 | if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1) |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 75 | return llvm::None; |
| 76 | |
| 77 | auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); |
| 78 | auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); |
| 79 | |
| 80 | // Make sure the original input has one dimension. |
| 81 | if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1) |
| 82 | return llvm::None; |
| 83 | // Make sure the original output has one element. |
| 84 | if (!originalOutputType.hasStaticShape() || |
| 85 | originalOutputType.getNumElements() != 1) |
| 86 | return llvm::None; |
| 87 | |
| 88 | if (!genericOp.hasSingleReductionLoop()) |
| 89 | return llvm::None; |
| 90 | |
| 91 | if (genericOp.indexing_maps().getValue().size() != 2) |
| 92 | return llvm::None; |
| 93 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame] | 94 | // TODO: create utility functions for these checks in Linalg |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 95 | // and use them. |
| 96 | auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>(); |
| 97 | auto outputMap = |
| 98 | genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>(); |
| 99 | // The indexing map for the input should be `(i) -> (i)`. |
| 100 | if (inputMap.getValue() != |
Jeremy Bruestle | 9f3ab92 | 2020-04-15 18:12:47 | [diff] [blame] | 101 | AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext()))) |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 102 | return llvm::None; |
| 103 | // The indexing map for the input should be `(i) -> (0)`. |
| 104 | if (outputMap.getValue() != |
Jeremy Bruestle | 9f3ab92 | 2020-04-15 18:12:47 | [diff] [blame] | 105 | AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext()))) |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 106 | return llvm::None; |
| 107 | |
| 108 | return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp); |
| 109 | } |
| 110 | |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 111 | LogicalResult SingleWorkgroupReduction::matchAndRewrite( |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 112 | linalg::GenericOp genericOp, OpAdaptor adaptor, |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 113 | ConversionPatternRewriter &rewriter) const { |
| 114 | Operation *op = genericOp.getOperation(); |
| 115 | auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); |
| 116 | auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); |
| 117 | |
| 118 | auto binaryOpKind = matchAsPerformingReduction(genericOp); |
| 119 | if (!binaryOpKind) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 120 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 121 | |
| 122 | // Query the shader interface for local workgroup size to make sure the |
| 123 | // invocation configuration fits with the input memref's shape. |
| 124 | DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp); |
| 125 | if (!localSize) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 126 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 127 | |
| 128 | if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0)) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 129 | return failure(); |
River Riddle | 0cb5d7f | 2021-09-21 01:40:22 | [diff] [blame] | 130 | if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1), |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 131 | [](const APInt &size) { return !size.isOneValue(); })) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 132 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 133 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame] | 134 | // TODO: Query the target environment to make sure the current |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 135 | // workload fits in a local workgroup. |
| 136 | |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 137 | Value convertedInput = adaptor.getOperands()[0]; |
| 138 | Value convertedOutput = adaptor.getOperands()[1]; |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 139 | Location loc = genericOp.getLoc(); |
| 140 | |
Butygin | 1e35a76 | 2021-08-14 08:57:02 | [diff] [blame] | 141 | auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); |
| 142 | auto indexType = typeConverter->getIndexType(); |
| 143 | |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 144 | // Get the invocation ID. |
Butygin | 1e35a76 | 2021-08-14 08:57:02 | [diff] [blame] | 145 | Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc, |
| 146 | &rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 147 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame] | 148 | // TODO: Load to Workgroup storage class first. |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 149 | |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 150 | |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 151 | // Get the input element accessed by this invocation. |
| 152 | Value inputElementPtr = spirv::getElementPtr( |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 153 | *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 154 | Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr); |
| 155 | |
| 156 | // Perform the group reduction operation. |
| 157 | Value groupOperation; |
| 158 | #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \ |
| 159 | case linalg::RegionMatcher::BinaryOpKind::opKind: { \ |
| 160 | groupOperation = rewriter.create<spirv::spvOp>( \ |
| 161 | loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \ |
| 162 | spirv::GroupOperation::Reduce, inputElement, \ |
Lei Zhang | a9cb529 | 2020-04-13 19:33:35 | [diff] [blame] | 163 | /*cluster_size=*/nullptr); \ |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 164 | } break |
| 165 | switch (*binaryOpKind) { |
| 166 | CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp); |
| 167 | } |
| 168 | #undef CREATE_GROUP_NON_UNIFORM_BIN_OP |
| 169 | |
| 170 | // Get the output element accessed by this reduction. |
Butygin | 1e35a76 | 2021-08-14 08:57:02 | [diff] [blame] | 171 | Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 172 | SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero); |
| 173 | Value outputElementPtr = |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 174 | spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput, |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 175 | zeroIndices, loc, rewriter); |
| 176 | |
| 177 | // Write out the final reduction result. This should be only conducted by one |
| 178 | // invocation. We use spv.GroupNonUniformElect to find the invocation with the |
| 179 | // lowest ID. |
| 180 | // |
| 181 | // ``` |
| 182 | // if (spv.GroupNonUniformElect) { output = ... } |
| 183 | // ``` |
| 184 | |
| 185 | Value condition = rewriter.create<spirv::GroupNonUniformElectOp>( |
| 186 | loc, spirv::Scope::Subgroup); |
| 187 | |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 188 | auto createAtomicOp = [&](OpBuilder &builder) { |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 189 | #define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \ |
| 190 | case linalg::RegionMatcher::BinaryOpKind::opKind: { \ |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 191 | builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \ |
| 192 | spirv::MemorySemantics::AcquireRelease, \ |
| 193 | groupOperation); \ |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 194 | } break |
| 195 | switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); } |
| 196 | #undef CREATE_ATOMIC_BIN_OP |
| 197 | }; |
| 198 | |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 199 | spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 200 | |
| 201 | rewriter.eraseOp(genericOp); |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 202 | return success(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 203 | } |
| 204 | |
| 205 | //===----------------------------------------------------------------------===// |
| 206 | // Pattern population |
| 207 | //===----------------------------------------------------------------------===// |
| 208 | |
Chris Lattner | 3a506b3 | 2021-03-20 23:29:41 | [diff] [blame] | 209 | void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter, |
Chris Lattner | dc4e913 | 2021-03-22 23:58:34 | [diff] [blame] | 210 | RewritePatternSet &patterns) { |
| 211 | patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext()); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 212 | } |