[mlir][spirv] Initial support for 64-bit index type and builtins
Differential Revision: https://reviews.llvm.org/D108516
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index a94435c..c20f4ed 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -26,11 +26,11 @@
/// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
/// location invocation ID. This function will create necessary operations with
/// `builder` at the proper region containing `op`.
-static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
- OpBuilder *builder) {
+static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
+ Location loc, OpBuilder *builder) {
assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
Value invocation = spirv::getBuiltinVariableValue(
- op, spirv::BuiltIn::LocalInvocationId, *builder);
+ op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
Type xType = invocation.getType().cast<ShapedType>().getElementType();
return builder->create<spirv::CompositeExtractOp>(
loc, xType, invocation, builder->getI32ArrayAttr({dim}));
@@ -137,12 +137,15 @@
Value convertedInput = operands[0], convertedOutput = operands[1];
Location loc = genericOp.getLoc();
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ auto indexType = typeConverter->getIndexType();
+
// Get the invocation ID.
- Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);
+ Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
+ &rewriter);
// TODO: Load to Workgroup storage class first.
- auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
// Get the input element accessed by this invocation.
Value inputElementPtr = spirv::getElementPtr(
@@ -164,8 +167,7 @@
#undef CREATE_GROUP_NON_UNIFORM_BIN_OP
// Get the output element accessed by this reduction.
- Value zero = spirv::ConstantOp::getZero(
- typeConverter->getIndexType(rewriter.getContext()), loc, rewriter);
+ Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
Value outputElementPtr =
spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,