Skip to content

Commit 47c11d5

Browse files
Adds ExecutorOptions and DeviceOptions options providers
- Moves `executorIndexBitwidth` and `executorUsePackedMemRefCConv` into its own options bundle. - Adds a `DeviceOptions` provider and updates the OptionsContext and OptionsProviders to use `llvm::Error`s instead of `mlirtrt::Status` since the latter is not accessible to the OptionsContext.
1 parent 9163039 commit 47c11d5

File tree

14 files changed

+253
-176
lines changed

14 files changed

+253
-176
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#define MLIR_TENSORRT_COMPILER_CLIENT
2929

3030
#include "mlir-executor/Support/Status.h"
31-
#include "mlir-tensorrt/Compiler/Options.h"
31+
#include "mlir-tensorrt/Compiler/OptionsProviders.h"
3232
#include "mlir/IR/MLIRContext.h"
3333
#include "mlir/Pass/PassManager.h"
3434
#include "mlir/Support/TypeID.h"

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Options.h

Lines changed: 0 additions & 62 deletions
This file was deleted.
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
//===- OptionsProviders.h ---------------------------------------*- C++ -*-===//
2+
//
3+
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
4+
// All rights reserved.
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// Data structures and functions for manipulating compiler options.
22+
///
23+
//===----------------------------------------------------------------------===//
24+
#ifndef MLIR_TENSORRT_COMPILER_OPTIONS
25+
#define MLIR_TENSORRT_COMPILER_OPTIONS
26+
27+
#include "mlir-executor/Support/DeviceInfo.h"
28+
#include "mlir-tensorrt-dialect/Utils/Options.h"
29+
#include "mlir/Support/LLVM.h"
30+
#include "llvm/Support/CommandLine.h"
31+
#include "llvm/Support/Error.h"
32+
#include <string>
33+
34+
namespace mlirtrt::compiler {
35+
36+
// Use SFINAE to check whether the `finalizeImpl()` method is defined on a type.
37+
// If it is, the specialization (where the value is true) will be the better
38+
// match. Otherwise, we'll get the default value of false.
39+
template <typename, typename = void>
40+
constexpr bool has_finalize_impl_v = false;
41+
42+
template <typename T>
43+
constexpr bool has_finalize_impl_v<
44+
T, std::void_t<decltype(std::declval<T>().finalizeImpl())>> = true;
45+
46+
// We use CRTP here so we can call `finalizeImpl()` if it's defined or provide
47+
// a default implementation otherwise.
48+
template <typename Derived>
49+
struct OptionsProvider {
50+
/// Modifies options after parsing. This is required since we may need
51+
/// to make changes to options based on the values of other options.
52+
/// Do *not* override this method; instead, implement `finalizeImpl()`.
53+
llvm::Error finalize() {
54+
if constexpr (has_finalize_impl_v<Derived>)
55+
return static_cast<Derived *>(this)->finalizeImpl();
56+
else
57+
return llvm::Error::success();
58+
}
59+
};
60+
61+
/// DebugOptions are options that are common to different compiler API
62+
/// interfaces.
63+
struct DebugOptions : public OptionsProvider<DebugOptions> {
64+
public:
65+
/// A directory path where the IR will be dumped during compilation
66+
/// using the `mlir-print-ir-tree-dir` mechanism.
67+
std::string dumpIRPath = "";
68+
69+
/// Whether the LLVM 'debug' flag that enables execution of code guarded by
70+
/// the `LLVM_DEBUG` macro should be set to 'on'. This results in very verbose
71+
/// output from the compiler dumped to stderr.
72+
bool enableLLVMDebugFlag = false;
73+
74+
/// A set of names to be given to the LLVM 'debug types' option, akin to
75+
/// setting
76+
/// `-debug-types=...` from the command line.
77+
mlir::SmallVector<std::string> llvmDebugTypes = {};
78+
79+
public:
80+
void addToOptions(mlir::OptionsContext &context) {
81+
context.addOption("mlir-print-ir-tree-dir", dumpIRPath, llvm::cl::init(""));
82+
context.addOption("debug", enableLLVMDebugFlag);
83+
context.addList<std::string>("debug-only", llvmDebugTypes,
84+
llvm::cl::ZeroOrMore,
85+
llvm::cl::CommaSeparated);
86+
}
87+
};
88+
89+
struct ExecutorOptions : public OptionsProvider<ExecutorOptions> {
90+
public:
91+
/// The host index bit-width.
92+
int64_t indexBitwidth{64};
93+
94+
/// Whether to pass memref's as struct/table in function calls.
95+
bool usePackedMemRefCConv{true};
96+
97+
public:
98+
void addToOptions(mlir::OptionsContext &context) {
99+
context.addOption("executor-index-bitwidth", indexBitwidth,
100+
llvm::cl::init(64));
101+
}
102+
};
103+
104+
struct DeviceOptions : public OptionsProvider<DeviceOptions> {
105+
public:
106+
DeviceInfo info;
107+
108+
/// Whether to ignore `deviceX` options and instead infer them from the GPUs
109+
/// on the host system running the compilation.
110+
bool shouldInferFromHost = false;
111+
112+
public:
113+
void addToOptions(mlir::OptionsContext &context) {
114+
context.addOption(
115+
"device-compute-capability", info.computeCapability, llvm::cl::init(60),
116+
llvm::cl::desc("Sets the device compute capbility. Only relevant "
117+
"if '--device-infer-from-host=false'"));
118+
context.addOption("device-max-shared-memory-per-block-kb",
119+
info.maxSharedMemoryPerBlockKb, llvm::cl::init(48));
120+
context.addOption("device-max-registers-per-block",
121+
info.maxRegistersPerBlock, llvm::cl::init(65536));
122+
context.addOption("device-infer-from-host", shouldInferFromHost,
123+
llvm::cl::init(true),
124+
llvm::cl::desc("Infers device information from host"));
125+
}
126+
127+
llvm::Error finalizeImpl();
128+
};
129+
130+
} // namespace mlirtrt::compiler
131+
132+
#endif // MLIR_TENSORRT_COMPILER_OPTIONS

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsRegistry.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
3333
#include "llvm/ADT/ArrayRef.h"
3434
#include "llvm/ADT/StringRef.h"
35+
#include "llvm/Support/Error.h"
3536
#include <functional>
3637

3738
namespace mlirtrt::compiler {
@@ -71,14 +72,16 @@ optionsCreateFromArgs(const CompilerClient &client,
7172
llvm::iterator_range(args), err);
7273
}
7374

74-
// TODO: Figure out whether to add a method in the base class like
75-
// "finalizeOptions" or a callback here, or something else if
76-
// `inferDeviceOptionsFromHost` is unique to StableHLO.
77-
//
78-
// Populate device options from host information.
79-
Status inferStatus = result->inferDeviceOptionsFromHost();
80-
if (!inferStatus.isOk())
81-
return inferStatus;
75+
llvm::Error finalizeStatus = result->finalize();
76+
77+
std::optional<std::string> errMsg{};
78+
llvm::handleAllErrors(
79+
std::move(finalizeStatus),
80+
[&errMsg](const llvm::StringError &err) { errMsg = err.getMessage(); });
81+
82+
if (errMsg)
83+
return getInternalErrorStatus("failed to initialize options: %s",
84+
errMsg->c_str());
8285

8386
return std::unique_ptr<mlir::OptionsContext>(result.release());
8487
}

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include "mlir-tensorrt-dialect/Utils/OptionsBundle.h"
3737
#include "mlir-tensorrt/Compiler/Client.h"
3838
#include "mlir-tensorrt/Compiler/Extension.h"
39-
#include "mlir-tensorrt/Compiler/Options.h"
39+
#include "mlir-tensorrt/Compiler/OptionsProviders.h"
4040
#include "mlir/IR/BuiltinOps.h"
4141
#include "mlir/Pass/PassManager.h"
4242
#include "mlir/Support/TypeID.h"
@@ -51,46 +51,16 @@ namespace mlirtrt::compiler {
5151

5252
class StableHloToExecutableTask;
5353

54-
struct StableHLOToExecutableOptions : public mlir::OptionsBundle<DebugOptions> {
54+
struct StableHLOToExecutableOptions
55+
: public mlir::OptionsBundle<DebugOptions, ExecutorOptions, DeviceOptions> {
5556
/// Initializes the options. The extensions in the provided registry
5657
/// must be extensions for the StableHloToExecutable task.
5758
StableHLOToExecutableOptions(TaskExtensionRegistry extensions);
5859

59-
/// Set the target device compute capability (SM version) and max shared
60-
/// memory per block (in kilobytes). The `maxSharedMemoryPerBlockKb` is the
61-
/// maximum shared memory per block allowed for kernels and is passed to the
62-
/// TensorRT builder.
63-
StableHLOToExecutableOptions &
64-
setDeviceOptions(int64_t computeCapability,
65-
int64_t maxSharedMemoryPerBlockKb);
66-
67-
/// Infer target device information from the first visible CUDA device on the
68-
/// host executing this code.
69-
Status inferDeviceOptionsFromHost();
70-
7160
/// Return the hash of the options. Returns `nullopt` when the TensorRT
7261
/// layer metadata callback is set since that can't be reliably hashed.
7362
std::optional<llvm::hash_code> getHash() const override;
7463

75-
/// The host index bit-width.
76-
int64_t executorIndexBitwidth{64};
77-
78-
/// Whether to pass memref's as struct/table in function calls.
79-
bool executorUsePackedMemRefCConv{true};
80-
81-
/// Target device compute capability (SM version)
82-
int64_t deviceComputeCapability;
83-
84-
/// Target device max shared memory per block (kilobytes)
85-
int64_t deviceMaxSharedMemoryPerBlockKb;
86-
87-
/// Target device maximum 4-byte register sper block.
88-
uint64_t deviceMaxRegistersPerBlock;
89-
90-
/// Whether to ignore `deviceX` options and instead infer them from the GPUs
91-
/// on the host system running the compilation.
92-
bool shouldInferDeviceOptionsFromHost = false;
93-
9464
/// Whether to disallow host tensors in TensorRT clusters.
9565
bool disallowHostTensorsInTensorRTClusters = false;
9666

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,15 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreate(
169169
auto result =
170170
std::make_unique<StableHLOToExecutableOptions>(std::move(extensions));
171171

172-
/// Populate device options from host information.
173-
Status inferStatus = result->inferDeviceOptionsFromHost();
174-
if (!inferStatus.isOk())
175-
return wrap(inferStatus);
172+
llvm::Error finalizeStatus = result->finalize();
173+
174+
std::optional<std::string> errMsg{};
175+
llvm::handleAllErrors(
176+
std::move(finalizeStatus),
177+
[&errMsg](const llvm::StringError &err) { errMsg = err.getMessage(); });
178+
179+
if (errMsg)
180+
return wrap(getInternalErrorStatus(errMsg->c_str()));
176181

177182
*options = wrap(result.release());
178183
return mtrtStatusGetOk();
@@ -209,10 +214,15 @@ MTRT_Status mtrtStableHloToExecutableOptionsCreateFromArgs(
209214
"failed to parse options string {0} due to error: {1}", line, err));
210215
}
211216

212-
/// Populate device options from host information.
213-
Status inferStatus = result->inferDeviceOptionsFromHost();
214-
if (!inferStatus.isOk())
215-
return wrap(inferStatus);
217+
llvm::Error finalizeStatus = result->finalize();
218+
219+
std::optional<std::string> errMsg{};
220+
llvm::handleAllErrors(
221+
std::move(finalizeStatus),
222+
[&errMsg](const llvm::StringError &err) { errMsg = err.getMessage(); });
223+
224+
if (errMsg)
225+
return wrap(getInternalErrorStatus(errMsg->c_str()));
216226

217227
*options = wrap(result.release());
218228
return mtrtStatusGetOk();

mlir-tensorrt/compiler/lib/Compiler/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_tensorrt_library(MLIRTensorRTCompilerClient
22
Client.cpp
33
Extension.cpp
44
OptionsRegistry.cpp
5+
OptionsProviders.cpp
56
PARTIAL_SOURCES_INTENDED
67

78
LINK_LIBS PUBLIC
@@ -11,6 +12,7 @@ add_mlir_tensorrt_library(MLIRTensorRTCompilerClient
1112
MLIRTensorRTOptionUtils
1213
MLIRTensorRTTargetTensorRT
1314
StablehloLinalgTransforms
15+
MLIRTensorRTSupportDeviceInfo
1416
)
1517

1618
add_mlir_tensorrt_library(MLIRTensorRTCompilerStableHloToExecutable
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- OptionsProviders.cpp -------------------------------------*- C++ -*-===//
2+
//
3+
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
4+
// All rights reserved.
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// Data structures and functions for manipulating compiler options.
22+
///
23+
//===----------------------------------------------------------------------===//
24+
#include "mlir-tensorrt/Compiler/OptionsProviders.h"
25+
#include "mlir-executor/Support/DeviceInfo.h"
26+
#include "llvm/Support/Error.h"
27+
28+
llvm::Error mlirtrt::compiler::DeviceOptions::finalizeImpl() {
29+
if (shouldInferFromHost) {
30+
StatusOr<DeviceInfo> deviceInfo = getDeviceInformationFromHost();
31+
32+
if (!deviceInfo.isOk())
33+
return llvm::createStringError(deviceInfo.getString());
34+
35+
info = *deviceInfo;
36+
}
37+
return llvm::Error::success();
38+
}

0 commit comments

Comments
 (0)