blob: 7a707e5aa54bc0d845473e12e8f83b62b060a823 [file] [log] [blame]
Alexander Belyaevd2ce4352019-10-24 08:41:251//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2//
Mehdi Amini30857102020-01-26 03:58:303// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
Mehdi Amini56222a02019-12-23 17:35:364// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Alexander Belyaevd2ce4352019-10-24 08:41:256//
Mehdi Amini56222a02019-12-23 17:35:367//===----------------------------------------------------------------------===//
Alexander Belyaev780a1082019-10-26 15:20:598#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
Alexander Belyaevd2ce4352019-10-24 08:41:2510
Alex Zinenko75e5f0a2021-07-08 16:35:1811#include "mlir/Conversion/LLVMCommon/Pattern.h"
River Riddle23aa5a72022-02-26 22:49:5412#include "mlir/Dialect/Func/IR/FuncOps.h"
Alexander Belyaevd2ce4352019-10-24 08:41:2513#include "mlir/Dialect/GPU/GPUDialect.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Alexander Belyaevd2ce4352019-10-24 08:41:2515#include "mlir/IR/Builders.h"
Alexander Belyaevd2ce4352019-10-24 08:41:2516
17namespace mlir {
18
Alexander Belyaev780a1082019-10-26 15:20:5919/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
20/// depending on the element type that Op operates upon. The function
21/// declaration is added in case it was not added before.
22///
Stephan Herhut2c8afe12020-06-09 15:20:5323/// If the input values are of f16 type, the value is first casted to f32, the
24/// function called and then the result casted back.
25///
Alexander Belyaev780a1082019-10-26 15:20:5926/// Example with NVVM:
River Riddle23aa5a72022-02-26 22:49:5427/// %exp_f32 = math.exp %arg_f32 : f32
Alexander Belyaev780a1082019-10-26 15:20:5928///
29/// will be transformed into
Alex Zinenkodd5165a2021-01-06 15:21:0830/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
Alexander Belyaevd2ce4352019-10-24 08:41:2531template <typename SourceOp>
Rahul Joshi563879b2020-12-10 02:18:3532struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
Alexander Belyaevd2ce4352019-10-24 08:41:2533public:
Mehdi Aminicac7aab2022-01-14 01:36:0434 explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
Alexander Belyaevd2ce4352019-10-24 08:41:2535 StringRef f64Func)
Mehdi Aminicac7aab2022-01-14 01:36:0436 : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
Rahul Joshi563879b2020-12-10 02:18:3537 f64Func(f64Func) {}
Alexander Belyaevd2ce4352019-10-24 08:41:2538
River Riddle31454272020-03-18 03:07:5539 LogicalResult
River Riddleef976332021-09-24 17:51:2040 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
Alexander Belyaevd2ce4352019-10-24 08:41:2541 ConversionPatternRewriter &rewriter) const override {
42 using LLVM::LLVMFuncOp;
Alexander Belyaevd2ce4352019-10-24 08:41:2543
44 static_assert(
45 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
46 "expected single result op");
47
Stephan Herhut2c8afe12020-06-09 15:20:5348 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
49 SourceOp>::value,
50 "expected op with same operand and result types");
51
52 SmallVector<Value, 1> castedOperands;
River Riddleef976332021-09-24 17:51:2053 for (Value operand : adaptor.getOperands())
Stephan Herhut2c8afe12020-06-09 15:20:5354 castedOperands.push_back(maybeCast(operand, rewriter));
55
Alex Zinenkoc69c9e02021-01-05 15:22:5356 Type resultType = castedOperands.front().getType();
57 Type funcType = getFunctionType(resultType, castedOperands);
Alex Zinenko8de43b92020-12-22 10:22:2158 StringRef funcName = getFunctionName(
59 funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
Alexander Belyaev780a1082019-10-26 15:20:5960 if (funcName.empty())
River Riddle31454272020-03-18 03:07:5561 return failure();
Alexander Belyaevd2ce4352019-10-24 08:41:2562
63 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
64 auto callOp = rewriter.create<LLVM::CallOp>(
Chris Lattnerfaf1c222021-08-30 16:31:4865 op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
Stephan Herhut2c8afe12020-06-09 15:20:5366
River Riddleef976332021-09-24 17:51:2067 if (resultType == adaptor.getOperands().front().getType()) {
Stephan Herhut2c8afe12020-06-09 15:20:5368 rewriter.replaceOp(op, {callOp.getResult(0)});
69 return success();
70 }
71
72 Value truncated = rewriter.create<LLVM::FPTruncOp>(
River Riddleef976332021-09-24 17:51:2073 op->getLoc(), adaptor.getOperands().front().getType(),
74 callOp.getResult(0));
Stephan Herhut2c8afe12020-06-09 15:20:5375 rewriter.replaceOp(op, {truncated});
River Riddle31454272020-03-18 03:07:5576 return success();
Alexander Belyaevd2ce4352019-10-24 08:41:2577 }
78
79private:
Stephan Herhut2c8afe12020-06-09 15:20:5380 Value maybeCast(Value operand, PatternRewriter &rewriter) const {
Alex Zinenkoc69c9e02021-01-05 15:22:5381 Type type = operand.getType();
Alex Zinenkodd5165a2021-01-06 15:21:0882 if (!type.isa<Float16Type>())
Stephan Herhut2c8afe12020-06-09 15:20:5383 return operand;
84
85 return rewriter.create<LLVM::FPExtOp>(
Alex Zinenkodd5165a2021-01-06 15:21:0886 operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
Stephan Herhut2c8afe12020-06-09 15:20:5387 }
88
River Riddleef976332021-09-24 17:51:2089 Type getFunctionType(Type resultType, ValueRange operands) const {
90 SmallVector<Type> operandTypes(operands.getTypes());
Alex Zinenko7ed9cfc2020-12-22 10:22:5691 return LLVM::LLVMFunctionType::get(resultType, operandTypes);
Alexander Belyaevd2ce4352019-10-24 08:41:2592 }
93
Alex Zinenkoc69c9e02021-01-05 15:22:5394 StringRef getFunctionName(Type type) const {
Alex Zinenkodd5165a2021-01-06 15:21:0895 if (type.isa<Float32Type>())
Alexander Belyaevd2ce4352019-10-24 08:41:2596 return f32Func;
Alex Zinenkodd5165a2021-01-06 15:21:0897 if (type.isa<Float64Type>())
Alexander Belyaevd2ce4352019-10-24 08:41:2598 return f64Func;
99 return "";
100 }
101
Alex Zinenkoc69c9e02021-01-05 15:22:53102 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
Alexander Belyaevd2ce4352019-10-24 08:41:25103 Operation *op) const {
104 using LLVM::LLVMFuncOp;
105
Chris Lattner41d4aa72021-08-29 21:22:24106 auto funcAttr = StringAttr::get(op->getContext(), funcName);
107 Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
Alexander Belyaevd2ce4352019-10-24 08:41:25108 if (funcOp)
River Riddle4562e382019-12-18 17:28:48109 return cast<LLVMFuncOp>(*funcOp);
Alexander Belyaevd2ce4352019-10-24 08:41:25110
111 mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
Alex Zinenkofdbb99c2019-12-03 08:26:13112 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
Alexander Belyaevd2ce4352019-10-24 08:41:25113 }
114
115 const std::string f32Func;
116 const std::string f64Func;
117};
118
119} // namespace mlir
120
Alexander Belyaev780a1082019-10-26 15:20:59121#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_