blob: 4c8e6708f89993887323dac9a457f3981aea536f [file] [log] [blame]
Alex Zinenkoc59ce1f2021-06-07 16:33:421//===- DataLayoutAnalysis.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataLayoutAnalysis.h"
10#include "mlir/IR/BuiltinOps.h"
11#include "mlir/IR/Operation.h"
12#include "mlir/Interfaces/DataLayoutInterfaces.h"
Mehdi Amini285a2292023-10-20 08:12:3913#include "mlir/Support/LLVM.h"
14#include <memory>
Alex Zinenkoc59ce1f2021-06-07 16:33:4215
16using namespace mlir;
17
18DataLayoutAnalysis::DataLayoutAnalysis(Operation *root)
19 : defaultLayout(std::make_unique<DataLayout>(DataLayoutOpInterface())) {
20 // Construct a DataLayout if possible from the op.
21 auto computeLayout = [this](Operation *op) {
22 if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
23 layouts[op] = std::make_unique<DataLayout>(iface);
24 if (auto module = dyn_cast<ModuleOp>(op))
25 layouts[op] = std::make_unique<DataLayout>(module);
26 };
27
28 // Compute layouts for both ancestors and descendants.
29 root->walk(computeLayout);
30 for (Operation *ancestor = root->getParentOp(); ancestor != nullptr;
31 ancestor = ancestor->getParentOp()) {
32 computeLayout(ancestor);
33 }
34}
35
36const DataLayout &DataLayoutAnalysis::getAbove(Operation *operation) const {
37 for (Operation *ancestor = operation->getParentOp(); ancestor != nullptr;
38 ancestor = ancestor->getParentOp()) {
39 auto it = layouts.find(ancestor);
40 if (it != layouts.end())
41 return *it->getSecond();
42 }
43
44 // Fallback to the default layout.
45 return *defaultLayout;
46}
47
48const DataLayout &DataLayoutAnalysis::getAtOrAbove(Operation *operation) const {
49 auto it = layouts.find(operation);
50 if (it != layouts.end())
51 return *it->getSecond();
52 return getAbove(operation);
53}