Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 1 | //===- LinalgToSPIRV.cpp - Linalg to SPIR-V dialect conversion ------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
// See https://llvm.org/LICENSE.txt for license information.
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h" |
| 10 | #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| 11 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 12 | #include "mlir/Dialect/SPIRV/SPIRVDialect.h" |
| 13 | #include "mlir/Dialect/SPIRV/SPIRVLowering.h" |
| 14 | #include "mlir/Dialect/SPIRV/SPIRVOps.h" |
Rob Suderman | 69d757c | 2020-02-21 19:54:49 | [diff] [blame] | 15 | #include "mlir/Dialect/StandardOps/IR/Ops.h" |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 16 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 17 | #include "mlir/IR/AffineExpr.h" |
| 18 | |
| 19 | using namespace mlir; |
| 20 | |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | // Utilities |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | |
| 25 | /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V |
| 26 | /// location invocation ID. This function will create necessary operations with |
| 27 | /// `builder` at the proper region containing `op`. |
| 28 | static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc, |
| 29 | OpBuilder *builder) { |
| 30 | assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions"); |
| 31 | Value invocation = spirv::getBuiltinVariableValue( |
| 32 | op, spirv::BuiltIn::LocalInvocationId, *builder); |
| 33 | Type xType = invocation.getType().cast<ShapedType>().getElementType(); |
| 34 | return builder->create<spirv::CompositeExtractOp>( |
| 35 | loc, xType, invocation, builder->getI32ArrayAttr({dim})); |
| 36 | } |
| 37 | |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 38 | //===----------------------------------------------------------------------===// |
| 39 | // Reduction (single workgroup) |
| 40 | //===----------------------------------------------------------------------===// |
| 41 | |
| 42 | namespace { |
| 43 | |
| 44 | /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition |
| 45 | /// that the linalg.generic op is performing reduction with a workload size that |
| 46 | /// can fit in one workgroup. |
| 47 | class SingleWorkgroupReduction final |
| 48 | : public SPIRVOpLowering<linalg::GenericOp> { |
| 49 | public: |
| 50 | using SPIRVOpLowering<linalg::GenericOp>::SPIRVOpLowering; |
| 51 | |
| 52 | /// Matches the given linalg.generic op as performing reduction and returns |
| 53 | /// the binary op kind if successful. |
| 54 | static Optional<linalg::RegionMatcher::BinaryOpKind> |
| 55 | matchAsPerformingReduction(linalg::GenericOp genericOp); |
| 56 | |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 57 | LogicalResult |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 58 | matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands, |
| 59 | ConversionPatternRewriter &rewriter) const override; |
| 60 | }; |
| 61 | |
| 62 | } // namespace |
| 63 | |
| 64 | Optional<linalg::RegionMatcher::BinaryOpKind> |
| 65 | SingleWorkgroupReduction::matchAsPerformingReduction( |
| 66 | linalg::GenericOp genericOp) { |
| 67 | Operation *op = genericOp.getOperation(); |
| 68 | |
| 69 | // Make sure the linalg.generic is working on memrefs. |
| 70 | if (!genericOp.hasBufferSemantics()) |
| 71 | return llvm::None; |
| 72 | |
Kazuaki Ishizaki | e5a8512 | 2020-03-26 18:51:37 | [diff] [blame] | 73 | // Make sure this is reduction with one input and one output. |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 74 | if (genericOp.args_in().getZExtValue() != 1 || |
| 75 | genericOp.args_out().getZExtValue() != 1) |
| 76 | return llvm::None; |
| 77 | |
| 78 | auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); |
| 79 | auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); |
| 80 | |
| 81 | // Make sure the original input has one dimension. |
| 82 | if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1) |
| 83 | return llvm::None; |
| 84 | // Make sure the original output has one element. |
| 85 | if (!originalOutputType.hasStaticShape() || |
| 86 | originalOutputType.getNumElements() != 1) |
| 87 | return llvm::None; |
| 88 | |
| 89 | if (!genericOp.hasSingleReductionLoop()) |
| 90 | return llvm::None; |
| 91 | |
| 92 | if (genericOp.indexing_maps().getValue().size() != 2) |
| 93 | return llvm::None; |
| 94 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame^] | 95 | // TODO: create utility functions for these checks in Linalg |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 96 | // and use them. |
| 97 | auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>(); |
| 98 | auto outputMap = |
| 99 | genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>(); |
| 100 | // The indexing map for the input should be `(i) -> (i)`. |
| 101 | if (inputMap.getValue() != |
Jeremy Bruestle | 9f3ab92 | 2020-04-15 18:12:47 | [diff] [blame] | 102 | AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext()))) |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 103 | return llvm::None; |
| 104 | // The indexing map for the input should be `(i) -> (0)`. |
| 105 | if (outputMap.getValue() != |
Jeremy Bruestle | 9f3ab92 | 2020-04-15 18:12:47 | [diff] [blame] | 106 | AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext()))) |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 107 | return llvm::None; |
| 108 | |
| 109 | return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp); |
| 110 | } |
| 111 | |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 112 | LogicalResult SingleWorkgroupReduction::matchAndRewrite( |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 113 | linalg::GenericOp genericOp, ArrayRef<Value> operands, |
| 114 | ConversionPatternRewriter &rewriter) const { |
| 115 | Operation *op = genericOp.getOperation(); |
| 116 | auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); |
| 117 | auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); |
| 118 | |
| 119 | auto binaryOpKind = matchAsPerformingReduction(genericOp); |
| 120 | if (!binaryOpKind) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 121 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 122 | |
| 123 | // Query the shader interface for local workgroup size to make sure the |
| 124 | // invocation configuration fits with the input memref's shape. |
| 125 | DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp); |
| 126 | if (!localSize) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 127 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 128 | |
| 129 | if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0)) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 130 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 131 | if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1), |
| 132 | [](const APInt &size) { return !size.isOneValue(); })) |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 133 | return failure(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 134 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame^] | 135 | // TODO: Query the target environment to make sure the current |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 136 | // workload fits in a local workgroup. |
| 137 | |
| 138 | Value convertedInput = operands[0], convertedOutput = operands[1]; |
| 139 | Location loc = genericOp.getLoc(); |
| 140 | |
| 141 | // Get the invocation ID. |
| 142 | Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter); |
| 143 | |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame^] | 144 | // TODO: Load to Workgroup storage class first. |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 145 | |
| 146 | // Get the input element accessed by this invocation. |
| 147 | Value inputElementPtr = spirv::getElementPtr( |
| 148 | typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); |
| 149 | Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr); |
| 150 | |
| 151 | // Perform the group reduction operation. |
| 152 | Value groupOperation; |
| 153 | #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \ |
| 154 | case linalg::RegionMatcher::BinaryOpKind::opKind: { \ |
| 155 | groupOperation = rewriter.create<spirv::spvOp>( \ |
| 156 | loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \ |
| 157 | spirv::GroupOperation::Reduce, inputElement, \ |
Lei Zhang | a9cb529 | 2020-04-13 19:33:35 | [diff] [blame] | 158 | /*cluster_size=*/nullptr); \ |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 159 | } break |
| 160 | switch (*binaryOpKind) { |
| 161 | CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp); |
| 162 | } |
| 163 | #undef CREATE_GROUP_NON_UNIFORM_BIN_OP |
| 164 | |
| 165 | // Get the output element accessed by this reduction. |
| 166 | Value zero = spirv::ConstantOp::getZero( |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 167 | typeConverter.getIndexType(rewriter.getContext()), loc, rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 168 | SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero); |
| 169 | Value outputElementPtr = |
| 170 | spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput, |
| 171 | zeroIndices, loc, rewriter); |
| 172 | |
| 173 | // Write out the final reduction result. This should be only conducted by one |
| 174 | // invocation. We use spv.GroupNonUniformElect to find the invocation with the |
| 175 | // lowest ID. |
| 176 | // |
| 177 | // ``` |
| 178 | // if (spv.GroupNonUniformElect) { output = ... } |
| 179 | // ``` |
| 180 | |
| 181 | Value condition = rewriter.create<spirv::GroupNonUniformElectOp>( |
| 182 | loc, spirv::Scope::Subgroup); |
| 183 | |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 184 | auto createAtomicOp = [&](OpBuilder &builder) { |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 185 | #define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \ |
| 186 | case linalg::RegionMatcher::BinaryOpKind::opKind: { \ |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 187 | builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \ |
| 188 | spirv::MemorySemantics::AcquireRelease, \ |
| 189 | groupOperation); \ |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 190 | } break |
| 191 | switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); } |
| 192 | #undef CREATE_ATOMIC_BIN_OP |
| 193 | }; |
| 194 | |
Alex Zinenko | bb1d976 | 2020-04-23 14:02:46 | [diff] [blame] | 195 | spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 196 | |
| 197 | rewriter.eraseOp(genericOp); |
River Riddle | 3145427 | 2020-03-18 03:07:55 | [diff] [blame] | 198 | return success(); |
Lei Zhang | df71000 | 2020-01-26 16:10:29 | [diff] [blame] | 199 | } |
| 200 | |
| 201 | //===----------------------------------------------------------------------===// |
| 202 | // Pattern population |
| 203 | //===----------------------------------------------------------------------===// |
| 204 | |
| 205 | void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context, |
| 206 | SPIRVTypeConverter &typeConverter, |
| 207 | OwningRewritePatternList &patterns) { |
| 208 | patterns.insert<SingleWorkgroupReduction>(context, typeConverter); |
| 209 | } |