[mlir][spirv] Initial support for 64-bit index type and builtins

Differential Revision: https://reviews.llvm.org/D108516
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index a94435c..c20f4ed 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -26,11 +26,11 @@
 /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
 /// location invocation ID. This function will create necessary operations with
 /// `builder` at the proper region containing `op`.
-static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
-                                       OpBuilder *builder) {
+static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
+                                       Location loc, OpBuilder *builder) {
   assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
   Value invocation = spirv::getBuiltinVariableValue(
-      op, spirv::BuiltIn::LocalInvocationId, *builder);
+      op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
   Type xType = invocation.getType().cast<ShapedType>().getElementType();
   return builder->create<spirv::CompositeExtractOp>(
       loc, xType, invocation, builder->getI32ArrayAttr({dim}));
@@ -137,12 +137,15 @@
   Value convertedInput = operands[0], convertedOutput = operands[1];
   Location loc = genericOp.getLoc();
 
+  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+  auto indexType = typeConverter->getIndexType();
+
   // Get the invocation ID.
-  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);
+  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
+                                      &rewriter);
 
   // TODO: Load to Workgroup storage class first.
 
-  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
 
   // Get the input element accessed by this invocation.
   Value inputElementPtr = spirv::getElementPtr(
@@ -164,8 +167,7 @@
 #undef CREATE_GROUP_NON_UNIFORM_BIN_OP
 
   // Get the output element accessed by this reduction.
-  Value zero = spirv::ConstantOp::getZero(
-      typeConverter->getIndexType(rewriter.getContext()), loc, rewriter);
+  Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
   SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
   Value outputElementPtr =
       spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,