blob: f8d81c57b7747db0dc0a836afc21e7bf130eb7db [file] [log] [blame]
//===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
River Riddle23aa5a72022-02-26 22:49:5410#include "mlir/Dialect/Func/IR/FuncOps.h"
gysitb7f2c102021-12-15 12:14:3511#include "mlir/Dialect/Linalg/IR/Linalg.h"
Lei Zhangdf710002020-01-26 16:10:2912#include "mlir/Dialect/Linalg/Utils/Utils.h"
Lei Zhang01178652020-12-17 15:55:4513#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
Lei Zhangdf710002020-01-26 16:10:2916#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17#include "mlir/IR/AffineExpr.h"
Lei Zhang7c3ae482021-01-09 13:04:4918#include "mlir/Transforms/DialectConversion.h"
Lei Zhangdf710002020-01-26 16:10:2919
20using namespace mlir;
21
22//===----------------------------------------------------------------------===//
23// Utilities
24//===----------------------------------------------------------------------===//
25
26/// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
27/// location invocation ID. This function will create necessary operations with
28/// `builder` at the proper region containing `op`.
Butygin1e35a762021-08-14 08:57:0229static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
30 Location loc, OpBuilder *builder) {
Lei Zhangdf710002020-01-26 16:10:2931 assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
32 Value invocation = spirv::getBuiltinVariableValue(
Butygin1e35a762021-08-14 08:57:0233 op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
Lei Zhangdf710002020-01-26 16:10:2934 Type xType = invocation.getType().cast<ShapedType>().getElementType();
35 return builder->create<spirv::CompositeExtractOp>(
36 loc, xType, invocation, builder->getI32ArrayAttr({dim}));
37}
38
Lei Zhangdf710002020-01-26 16:10:2939//===----------------------------------------------------------------------===//
40// Reduction (single workgroup)
41//===----------------------------------------------------------------------===//
42
43namespace {
44
45/// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
46/// that the linalg.generic op is performing reduction with a workload size that
47/// can fit in one workgroup.
Lei Zhang7c3ae482021-01-09 13:04:4948struct SingleWorkgroupReduction final
49 : public OpConversionPattern<linalg::GenericOp> {
50 using OpConversionPattern::OpConversionPattern;
Lei Zhangdf710002020-01-26 16:10:2951
52 /// Matches the given linalg.generic op as performing reduction and returns
53 /// the binary op kind if successful.
54 static Optional<linalg::RegionMatcher::BinaryOpKind>
55 matchAsPerformingReduction(linalg::GenericOp genericOp);
56
River Riddle31454272020-03-18 03:07:5557 LogicalResult
River Riddleb54c7242021-09-24 17:50:5858 matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor,
Lei Zhangdf710002020-01-26 16:10:2959 ConversionPatternRewriter &rewriter) const override;
60};
61
62} // namespace
63
64Optional<linalg::RegionMatcher::BinaryOpKind>
65SingleWorkgroupReduction::matchAsPerformingReduction(
66 linalg::GenericOp genericOp) {
67 Operation *op = genericOp.getOperation();
68
69 // Make sure the linalg.generic is working on memrefs.
70 if (!genericOp.hasBufferSemantics())
71 return llvm::None;
72
Kazuaki Ishizakie5a85122020-03-26 18:51:3773 // Make sure this is reduction with one input and one output.
Nicolas Vasilacheed229132020-09-21 19:30:4274 if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
Lei Zhangdf710002020-01-26 16:10:2975 return llvm::None;
76
77 auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
78 auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
79
80 // Make sure the original input has one dimension.
81 if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
82 return llvm::None;
83 // Make sure the original output has one element.
84 if (!originalOutputType.hasStaticShape() ||
85 originalOutputType.getNumElements() != 1)
86 return llvm::None;
87
88 if (!genericOp.hasSingleReductionLoop())
89 return llvm::None;
90
91 if (genericOp.indexing_maps().getValue().size() != 2)
92 return llvm::None;
93
River Riddle9db53a12020-07-07 08:35:2394 // TODO: create utility functions for these checks in Linalg
Lei Zhangdf710002020-01-26 16:10:2995 // and use them.
96 auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>();
97 auto outputMap =
98 genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>();
99 // The indexing map for the input should be `(i) -> (i)`.
100 if (inputMap.getValue() !=
Jeremy Bruestle9f3ab922020-04-15 18:12:47101 AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext())))
Lei Zhangdf710002020-01-26 16:10:29102 return llvm::None;
103 // The indexing map for the input should be `(i) -> (0)`.
104 if (outputMap.getValue() !=
Jeremy Bruestle9f3ab922020-04-15 18:12:47105 AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext())))
Lei Zhangdf710002020-01-26 16:10:29106 return llvm::None;
107
108 return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
109}
110
River Riddle31454272020-03-18 03:07:55111LogicalResult SingleWorkgroupReduction::matchAndRewrite(
River Riddleb54c7242021-09-24 17:50:58112 linalg::GenericOp genericOp, OpAdaptor adaptor,
Lei Zhangdf710002020-01-26 16:10:29113 ConversionPatternRewriter &rewriter) const {
114 Operation *op = genericOp.getOperation();
115 auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
116 auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
117
118 auto binaryOpKind = matchAsPerformingReduction(genericOp);
119 if (!binaryOpKind)
River Riddle31454272020-03-18 03:07:55120 return failure();
Lei Zhangdf710002020-01-26 16:10:29121
122 // Query the shader interface for local workgroup size to make sure the
123 // invocation configuration fits with the input memref's shape.
124 DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
125 if (!localSize)
River Riddle31454272020-03-18 03:07:55126 return failure();
Lei Zhangdf710002020-01-26 16:10:29127
128 if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
River Riddle31454272020-03-18 03:07:55129 return failure();
River Riddle0cb5d7f2021-09-21 01:40:22130 if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1),
Lei Zhangdf710002020-01-26 16:10:29131 [](const APInt &size) { return !size.isOneValue(); }))
River Riddle31454272020-03-18 03:07:55132 return failure();
Lei Zhangdf710002020-01-26 16:10:29133
River Riddle9db53a12020-07-07 08:35:23134 // TODO: Query the target environment to make sure the current
Lei Zhangdf710002020-01-26 16:10:29135 // workload fits in a local workgroup.
136
River Riddleb54c7242021-09-24 17:50:58137 Value convertedInput = adaptor.getOperands()[0];
138 Value convertedOutput = adaptor.getOperands()[1];
Lei Zhangdf710002020-01-26 16:10:29139 Location loc = genericOp.getLoc();
140
Butygin1e35a762021-08-14 08:57:02141 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
142 auto indexType = typeConverter->getIndexType();
143
Lei Zhangdf710002020-01-26 16:10:29144 // Get the invocation ID.
Butygin1e35a762021-08-14 08:57:02145 Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
146 &rewriter);
Lei Zhangdf710002020-01-26 16:10:29147
River Riddle9db53a12020-07-07 08:35:23148 // TODO: Load to Workgroup storage class first.
Lei Zhangdf710002020-01-26 16:10:29149
Lei Zhang7c3ae482021-01-09 13:04:49150
Lei Zhangdf710002020-01-26 16:10:29151 // Get the input element accessed by this invocation.
152 Value inputElementPtr = spirv::getElementPtr(
Lei Zhang7c3ae482021-01-09 13:04:49153 *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
Lei Zhangdf710002020-01-26 16:10:29154 Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
155
156 // Perform the group reduction operation.
157 Value groupOperation;
158#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \
159 case linalg::RegionMatcher::BinaryOpKind::opKind: { \
160 groupOperation = rewriter.create<spirv::spvOp>( \
161 loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \
162 spirv::GroupOperation::Reduce, inputElement, \
Lei Zhanga9cb5292020-04-13 19:33:35163 /*cluster_size=*/nullptr); \
Lei Zhangdf710002020-01-26 16:10:29164 } break
165 switch (*binaryOpKind) {
166 CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
167 }
168#undef CREATE_GROUP_NON_UNIFORM_BIN_OP
169
170 // Get the output element accessed by this reduction.
Butygin1e35a762021-08-14 08:57:02171 Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
Lei Zhangdf710002020-01-26 16:10:29172 SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
173 Value outputElementPtr =
Lei Zhang7c3ae482021-01-09 13:04:49174 spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,
Lei Zhangdf710002020-01-26 16:10:29175 zeroIndices, loc, rewriter);
176
177 // Write out the final reduction result. This should be only conducted by one
178 // invocation. We use spv.GroupNonUniformElect to find the invocation with the
179 // lowest ID.
180 //
181 // ```
182 // if (spv.GroupNonUniformElect) { output = ... }
183 // ```
184
185 Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
186 loc, spirv::Scope::Subgroup);
187
Alex Zinenkobb1d9762020-04-23 14:02:46188 auto createAtomicOp = [&](OpBuilder &builder) {
Lei Zhangdf710002020-01-26 16:10:29189#define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \
190 case linalg::RegionMatcher::BinaryOpKind::opKind: { \
Alex Zinenkobb1d9762020-04-23 14:02:46191 builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \
192 spirv::MemorySemantics::AcquireRelease, \
193 groupOperation); \
Lei Zhangdf710002020-01-26 16:10:29194 } break
195 switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
196#undef CREATE_ATOMIC_BIN_OP
197 };
198
Alex Zinenkobb1d9762020-04-23 14:02:46199 spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter);
Lei Zhangdf710002020-01-26 16:10:29200
201 rewriter.eraseOp(genericOp);
River Riddle31454272020-03-18 03:07:55202 return success();
Lei Zhangdf710002020-01-26 16:10:29203}
204
205//===----------------------------------------------------------------------===//
206// Pattern population
207//===----------------------------------------------------------------------===//
208
Chris Lattner3a506b32021-03-20 23:29:41209void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
Chris Lattnerdc4e9132021-03-22 23:58:34210 RewritePatternSet &patterns) {
211 patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext());
Lei Zhangdf710002020-01-26 16:10:29212}