Lei Zhang | 930c74f | 2020-12-23 19:32:31 | [diff] [blame] | 1 | //===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===// |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h" |
| 10 | #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| 11 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
Lei Zhang | 0117865 | 2020-12-17 15:55:45 | [diff] [blame] | 12 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| 13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 14 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
Rob Suderman | 69d757c | 2020-02-21 19:54:49 | [diff] [blame] | 15 | #include "mlir/Dialect/StandardOps/IR/Ops.h" |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 16 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 17 | #include "mlir/IR/AffineExpr.h" |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 18 | #include "mlir/Transforms/DialectConversion.h" |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 19 | |
| 20 | using namespace mlir; |
| 21 | |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | // Utilities |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | |
| 26 | /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V |
| 27 | /// location invocation ID. This function will create necessary operations with |
| 28 | /// `builder` at the proper region containing `op`. |
| 29 | static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc, |
| 30 | OpBuilder *builder) { |
| 31 | assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions"); |
| 32 | Value invocation = spirv::getBuiltinVariableValue( |
| 33 | op, spirv::BuiltIn::LocalInvocationId, *builder); |
| 34 | Type xType = invocation.getType().cast<ShapedType>().getElementType(); |
| 35 | return builder->create<spirv::CompositeExtractOp>( |
| 36 | loc, xType, invocation, builder->getI32ArrayAttr({dim})); |
| 37 | } |
| 38 | |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 39 | //===----------------------------------------------------------------------===// |
| 40 | // Reduction (single workgroup) |
| 41 | //===----------------------------------------------------------------------===// |
| 42 | |
| 43 | namespace { |
| 44 | |
| 45 | /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition |
| 46 | /// that the linalg.generic op is performing reduction with a workload size that |
| 47 | /// can fit in one workgroup. |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 48 | struct SingleWorkgroupReduction final |
| 49 | : public OpConversionPattern<linalg::GenericOp> { |
| 50 | using OpConversionPattern::OpConversionPattern; |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 51 | |
| 52 | /// Matches the given linalg.generic op as performing reduction and returns |
| 53 | /// the binary op kind if successful. |
| 54 | static Optional<linalg::RegionMatcher::BinaryOpKind> |
| 55 | matchAsPerformingReduction(linalg::GenericOp genericOp); |
| 56 | |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 57 | LogicalResult |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 58 | matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands, |
| 59 | ConversionPatternRewriter &rewriter) const override; |
| 60 | }; |
| 61 | |
| 62 | } // namespace |
| 63 | |
| 64 | Optional<linalg::RegionMatcher::BinaryOpKind> |
| 65 | SingleWorkgroupReduction::matchAsPerformingReduction( |
| 66 | linalg::GenericOp genericOp) { |
| 67 | Operation *op = genericOp.getOperation(); |
| 68 | |
| 69 | // Make sure the linalg.generic is working on memrefs. |
| 70 | if (!genericOp.hasBufferSemantics()) |
| 71 | return llvm::None; |
| 72 | |
Kazuaki Ishizaki | e5a8512 | 2020-03-26 18:51:37 | [diff] [blame] | 73 | // Make sure this is reduction with one input and one output. |
Nicolas Vasilache | ed22913 | 2020-09-21 19:30:42 | [diff] [blame] | 74 | if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1) |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 75 | return llvm::None; |
| 76 | |
| 77 | auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); |
| 78 | auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); |
| 79 | |
| 80 | // Make sure the original input has one dimension. |
| 81 | if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1) |
| 82 | return llvm::None; |
| 83 | // Make sure the original output has one element. |
| 84 | if (!originalOutputType.hasStaticShape() || |
| 85 | originalOutputType.getNumElements() != 1) |
| 86 | return llvm::None; |
| 87 | |
| 88 | if (!genericOp.hasSingleReductionLoop()) |
| 89 | return llvm::None; |
| 90 | |
| 91 | if (genericOp.indexing_maps().getValue().size() != 2) |
| 92 | return llvm::None; |
| 93 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame] | 94 | // TODO: create utility functions for these checks in Linalg |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 95 | // and use them. |
| 96 | auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>(); |
| 97 | auto outputMap = |
| 98 | genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>(); |
| 99 | // The indexing map for the input should be `(i) -> (i)`. |
| 100 | if (inputMap.getValue() != |
Jeremy Bruestle | 9f3ab92 | 2020-04-15 18:12:47 | [diff] [blame] | 101 | AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext()))) |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 102 | return llvm::None; |
| 103 | // The indexing map for the input should be `(i) -> (0)`. |
| 104 | if (outputMap.getValue() != |
Jeremy Bruestle | 9f3ab92 | 2020-04-15 18:12:47 | [diff] [blame] | 105 | AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext()))) |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 106 | return llvm::None; |
| 107 | |
| 108 | return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp); |
| 109 | } |
| 110 | |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 111 | LogicalResult SingleWorkgroupReduction::matchAndRewrite( |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 112 | linalg::GenericOp genericOp, ArrayRef<Value> operands, |
| 113 | ConversionPatternRewriter &rewriter) const { |
| 114 | Operation *op = genericOp.getOperation(); |
| 115 | auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); |
| 116 | auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); |
| 117 | |
| 118 | auto binaryOpKind = matchAsPerformingReduction(genericOp); |
| 119 | if (!binaryOpKind) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 120 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 121 | |
| 122 | // Query the shader interface for local workgroup size to make sure the |
| 123 | // invocation configuration fits with the input memref's shape. |
| 124 | DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp); |
| 125 | if (!localSize) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 126 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 127 | |
| 128 | if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0)) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 129 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 130 | if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1), |
| 131 | [](const APInt &size) { return !size.isOneValue(); })) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 132 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 133 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame] | 134 | // TODO: Query the target environment to make sure the current |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 135 | // workload fits in a local workgroup. |
| 136 | |
| 137 | Value convertedInput = operands[0], convertedOutput = operands[1]; |
| 138 | Location loc = genericOp.getLoc(); |
| 139 | |
| 140 | // Get the invocation ID. |
| 141 | Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter); |
| 142 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame] | 143 | // TODO: Load to Workgroup storage class first. |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 144 | |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 145 | auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); |
| 146 | |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 147 | // Get the input element accessed by this invocation. |
| 148 | Value inputElementPtr = spirv::getElementPtr( |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 149 | *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 150 | Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr); |
| 151 | |
| 152 | // Perform the group reduction operation. |
| 153 | Value groupOperation; |
| 154 | #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \ |
| 155 | case linalg::RegionMatcher::BinaryOpKind::opKind: { \ |
| 156 | groupOperation = rewriter.create<spirv::spvOp>( \ |
| 157 | loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \ |
| 158 | spirv::GroupOperation::Reduce, inputElement, \ |
Lei Zhang | a9cb529 | 2020-04-13 19:33:35 | [diff] [blame] | 159 | /*cluster_size=*/nullptr); \ |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 160 | } break |
| 161 | switch (*binaryOpKind) { |
| 162 | CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp); |
| 163 | } |
| 164 | #undef CREATE_GROUP_NON_UNIFORM_BIN_OP |
| 165 | |
| 166 | // Get the output element accessed by this reduction. |
| 167 | Value zero = spirv::ConstantOp::getZero( |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 168 | typeConverter->getIndexType(rewriter.getContext()), loc, rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 169 | SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero); |
| 170 | Value outputElementPtr = |
Lei Zhang | 7c3ae48 | 2021-01-09 13:04:49 | [diff] [blame] | 171 | spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput, |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 172 | zeroIndices, loc, rewriter); |
| 173 | |
| 174 | // Write out the final reduction result. This should be only conducted by one |
| 175 | // invocation. We use spv.GroupNonUniformElect to find the invocation with the |
| 176 | // lowest ID. |
| 177 | // |
| 178 | // ``` |
| 179 | // if (spv.GroupNonUniformElect) { output = ... } |
| 180 | // ``` |
| 181 | |
| 182 | Value condition = rewriter.create<spirv::GroupNonUniformElectOp>( |
| 183 | loc, spirv::Scope::Subgroup); |
| 184 | |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 185 | auto createAtomicOp = [&](OpBuilder &builder) { |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 186 | #define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \ |
| 187 | case linalg::RegionMatcher::BinaryOpKind::opKind: { \ |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 188 | builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \ |
| 189 | spirv::MemorySemantics::AcquireRelease, \ |
| 190 | groupOperation); \ |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 191 | } break |
| 192 | switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); } |
| 193 | #undef CREATE_ATOMIC_BIN_OP |
| 194 | }; |
| 195 | |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 196 | spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 197 | |
| 198 | rewriter.eraseOp(genericOp); |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 199 | return success(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 200 | } |
| 201 | |
| 202 | //===----------------------------------------------------------------------===// |
| 203 | // Pattern population |
| 204 | //===----------------------------------------------------------------------===// |
| 205 | |
Chris Lattner | 3a506b3 | 2021-03-20 23:29:41 | [diff] [blame] | 206 | void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter, |
Chris Lattner | dc4e913 | 2021-03-22 23:58:34 | [diff] [blame^] | 207 | RewritePatternSet &patterns) { |
| 208 | patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext()); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 209 | } |