//===- LinalgToSPIRV.cpp - Linalg to SPIR-V dialect conversion ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

/// Returns a `Value` containing the `dim`-th dimension's size of the SPIR-V
/// local invocation ID. This function will create the necessary operations
/// with `builder` in the proper region containing `op`.
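///
/// For example, for `dim == 0` the generated IR is conceptually similar to
/// the following (illustrative sketch only; the builtin access is emitted by
/// spirv::getBuiltinVariableValue and the exact op spellings may differ):
///
///   %id = spv.Load "Input" %local_invocation_id_var : vector<3xi32>
///   %x  = spv.CompositeExtract %id[0 : i32] : vector<3xi32>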
static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
                                       OpBuilder *builder) {
  assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
  Value invocation = spirv::getBuiltinVariableValue(
      op, spirv::BuiltIn::LocalInvocationId, *builder);
  Type xType = invocation.getType().cast<ShapedType>().getElementType();
  return builder->create<spirv::CompositeExtractOp>(
      loc, xType, invocation, builder->getI32ArrayAttr({dim}));
}

//===----------------------------------------------------------------------===//
// Reduction (single workgroup)
//===----------------------------------------------------------------------===//

namespace {

/// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
/// that the linalg.generic op performs a reduction whose workload size fits in
/// one workgroup.
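///
/// For instance, an op along the following lines (illustrative only; the
/// textual syntax is abbreviated and approximate) is expected to match:
///
///   linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
///                   indexing_maps = [affine_map<(i) -> (i)>,
///                                    affine_map<(i) -> (0)>],
///                   iterator_types = ["reduction"]} %input, %output {
///   ^bb0(%in: i32, %out: i32):
///     %sum = addi %in, %out : i32
///     linalg.yield %sum : i32
///   } : memref<16xi32>, memref<1xi32>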
class SingleWorkgroupReduction final
    : public SPIRVOpLowering<linalg::GenericOp> {
public:
  using SPIRVOpLowering<linalg::GenericOp>::SPIRVOpLowering;

  /// Matches the given linalg.generic op as performing a reduction and returns
  /// the binary op kind if successful.
  static Optional<linalg::RegionMatcher::BinaryOpKind>
  matchAsPerformingReduction(linalg::GenericOp genericOp);

  LogicalResult
  matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

Optional<linalg::RegionMatcher::BinaryOpKind>
SingleWorkgroupReduction::matchAsPerformingReduction(
    linalg::GenericOp genericOp) {
  Operation *op = genericOp.getOperation();

  // Make sure the linalg.generic is working on memrefs.
  if (!genericOp.hasBufferSemantics())
    return llvm::None;

  // Make sure this is a reduction with one input and one output.
  if (genericOp.args_in().getZExtValue() != 1 ||
      genericOp.args_out().getZExtValue() != 1)
    return llvm::None;

  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();

  // Make sure the original input has one dimension.
  if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
    return llvm::None;
  // Make sure the original output has one element.
  if (!originalOutputType.hasStaticShape() ||
      originalOutputType.getNumElements() != 1)
    return llvm::None;

  if (!genericOp.hasSingleReductionLoop())
    return llvm::None;

  if (genericOp.indexing_maps().getValue().size() != 2)
    return llvm::None;

  // TODO(nicolasvasilache): create utility functions for these checks in
  // Linalg and use them.
  auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>();
  auto outputMap =
      genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>();
  // The indexing map for the input should be `(i) -> (i)`.
  if (inputMap.getValue() !=
      AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext())))
    return llvm::None;
  // The indexing map for the output should be `(i) -> (0)`.
  if (outputMap.getValue() !=
      AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext())))
    return llvm::None;

  return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
}

LogicalResult SingleWorkgroupReduction::matchAndRewrite(
    linalg::GenericOp genericOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = genericOp.getOperation();
  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();

  auto binaryOpKind = matchAsPerformingReduction(genericOp);
  if (!binaryOpKind)
    return failure();

  // Query the shader interface for local workgroup size to make sure the
  // invocation configuration fits with the input memref's shape.
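  //
  // As a sketch (illustrative only), the enclosing function is expected to
  // carry an ABI attribute along the lines of:
  //
  //   spv.entry_point_abi = {local_size = dense<[16, 1, 1]> : vector<3xi32>}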
  DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
  if (!localSize)
    return failure();

  if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
    return failure();
  if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
                   [](const APInt &size) { return !size.isOneValue(); }))
    return failure();

  // TODO(antiagainst): Query the target environment to make sure the current
  // workload fits in a local workgroup.

  Value convertedInput = operands[0], convertedOutput = operands[1];
  Location loc = genericOp.getLoc();

  // Get the invocation ID.
  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);

  // TODO(antiagainst): Load to Workgroup storage class first.

  // Get the input element accessed by this invocation.
  Value inputElementPtr = spirv::getElementPtr(
      typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
  Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);

  // Perform the group reduction operation.
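  // For integer addition this creates a subgroup reduction conceptually
  // similar to (illustrative only; the exact assembly form may differ):
  //
  //   %sum = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %inputElement : i32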
  Value groupOperation;
#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp)                         \
  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
    groupOperation = rewriter.create<spirv::spvOp>(                            \
        loc, originalInputType.getElementType(), spirv::Scope::Subgroup,       \
        spirv::GroupOperation::Reduce, inputElement,                           \
        /*cluster_size=*/nullptr);                                             \
  } break
  switch (*binaryOpKind) {
    CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
  }
#undef CREATE_GROUP_NON_UNIFORM_BIN_OP

  // Get the output element accessed by this reduction.
  Value zero = spirv::ConstantOp::getZero(
      typeConverter.getIndexType(rewriter.getContext()), loc, rewriter);
  SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
  Value outputElementPtr =
      spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput,
                           zeroIndices, loc, rewriter);

  // Write out the final reduction result. This should only be done by one
  // invocation. We use spv.GroupNonUniformElect to find the invocation with
  // the lowest ID.
  //
  // ```
  // if (spv.GroupNonUniformElect) { output = ... }
  // ```

  Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
      loc, spirv::Scope::Subgroup);

  auto createAtomicOp = [&](OpBuilder &builder) {
#define CREATE_ATOMIC_BIN_OP(opKind, spvOp)                                    \
  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
    builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device,  \
                                 spirv::MemorySemantics::AcquireRelease,       \
                                 groupOperation);                              \
  } break
    switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
#undef CREATE_ATOMIC_BIN_OP
  };

  spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter);

  rewriter.eraseOp(genericOp);
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
                                         SPIRVTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
  patterns.insert<SingleWorkgroupReduction>(context, typeConverter);
}