blob: 213e46075a47d207218acde87e1f684ada7f241e [file] [log] [blame]
George Mitenkov89808ce2020-10-23 14:46:181//===- mlir-spirv-cpu-runner.cpp - MLIR SPIR-V Execution on CPU -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Main entry point to a command line utility that executes an MLIR file on the
// CPU by translating the MLIR GPU module and the host part to LLVM IR before
// JIT-compiling and executing the result.
12//
13//===----------------------------------------------------------------------===//
14
Lei Zhang930c74f2020-12-23 19:32:3115#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
16#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
George Mitenkov89808ce2020-10-23 14:46:1817#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
Mogballa54f4ea2021-10-12 23:14:5718#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
River Riddle23aa5a72022-02-26 22:49:5419#include "mlir/Dialect/Func/IR/FuncOps.h"
Alex Zinenko9a08f762021-02-10 18:21:1620#include "mlir/Dialect/GPU/GPUDialect.h"
George Mitenkov89808ce2020-10-23 14:46:1821#include "mlir/Dialect/GPU/Passes.h"
Alex Zinenko9a08f762021-02-10 18:21:1622#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Alex Zinenkob868a3e2021-03-15 17:35:4023#include "mlir/Dialect/MemRef/IR/MemRef.h"
Alex Zinenko9a08f762021-02-10 18:21:1624#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
Lei Zhang01178652020-12-17 15:55:4525#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
26#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
George Mitenkov89808ce2020-10-23 14:46:1827#include "mlir/ExecutionEngine/JitRunner.h"
28#include "mlir/ExecutionEngine/OptUtils.h"
George Mitenkov89808ce2020-10-23 14:46:1829#include "mlir/Pass/Pass.h"
30#include "mlir/Pass/PassManager.h"
Alex Zinenko19db8022021-03-03 13:08:3031#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
Alex Zinenkoce8f10d2021-02-16 16:36:4532#include "mlir/Target/LLVMIR/Export.h"
George Mitenkov89808ce2020-10-23 14:46:1833
34#include "llvm/IR/LLVMContext.h"
35#include "llvm/IR/Module.h"
36#include "llvm/Linker/Linker.h"
37#include "llvm/Support/InitLLVM.h"
38#include "llvm/Support/TargetSelect.h"
39
40using namespace mlir;
41
42/// A utility function that builds llvm::Module from two nested MLIR modules.
43///
44/// module @main {
45/// module @kernel {
46/// // Some ops
47/// }
48/// // Some other ops
49/// }
50///
51/// Each of these two modules is translated to LLVM IR module, then they are
52/// linked together and returned.
53static std::unique_ptr<llvm::Module>
54convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) {
55 // Verify that there is only one nested module.
56 auto modules = module.getOps<ModuleOp>();
57 if (!llvm::hasSingleElement(modules)) {
58 module.emitError("The module must contain exactly one nested module");
59 return nullptr;
60 }
61
62 // Translate nested module and erase it.
63 ModuleOp nested = *modules.begin();
64 std::unique_ptr<llvm::Module> kernelModule =
65 translateModuleToLLVMIR(nested, context);
66 nested.erase();
67
68 std::unique_ptr<llvm::Module> mainModule =
69 translateModuleToLLVMIR(module, context);
70 llvm::Linker::linkModules(*mainModule, std::move(kernelModule));
71 return mainModule;
72}
73
74static LogicalResult runMLIRPasses(ModuleOp module) {
75 PassManager passManager(module.getContext());
76 applyPassManagerCLOptions(passManager);
77 passManager.addPass(createGpuKernelOutliningPass());
78 passManager.addPass(createConvertGPUToSPIRVPass());
79
80 OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
81 nestedPM.addPass(spirv::createLowerABIAttributesPass());
82 nestedPM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
83 passManager.addPass(createLowerHostCodeToLLVMPass());
84 passManager.addPass(createConvertSPIRVToLLVMPass());
85 return passManager.run(module);
86}
87
88int main(int argc, char **argv) {
89 llvm::InitLLVM y(argc, argv);
90
91 llvm::InitializeNativeTarget();
92 llvm::InitializeNativeTargetAsmPrinter();
93 mlir::initializeLLVMPasses();
94
Eugene Zhulenevf6c9f6e2020-10-27 21:12:4795 mlir::JitRunnerConfig jitRunnerConfig;
Eugene Zhuleneva2973402020-10-28 00:02:1096 jitRunnerConfig.mlirTransformer = runMLIRPasses;
97 jitRunnerConfig.llvmModuleBuilder = convertMLIRModule;
Eugene Zhulenevf6c9f6e2020-10-27 21:12:4798
Alex Zinenko9a08f762021-02-10 18:21:1699 mlir::DialectRegistry registry;
Mogballa54f4ea2021-10-12 23:14:57100 registry.insert<mlir::arith::ArithmeticDialect, mlir::LLVM::LLVMDialect,
101 mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
River Riddle23aa5a72022-02-26 22:49:54102 mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
Alex Zinenkob77bac02021-02-11 14:01:33103 mlir::registerLLVMDialectTranslation(registry);
Alex Zinenko9a08f762021-02-10 18:21:16104
105 return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
George Mitenkov89808ce2020-10-23 14:46:18106}