Skip to content

Commit c17ace7

Browse files
Move internal change: [compiler] Add compilation task registry (#465)
This change adds a CompilationTask (cached pass managers) registry which enables creating and looking up cached compilation tasks from the Python API by just passing a mnemonic task name and a list of string options. GitOrigin-RevId: f19e634e8ff8338809fe2c1b8efa730ef2e14f21
1 parent 7dfa0fb commit c17ace7

File tree

20 files changed

+568
-199
lines changed

20 files changed

+568
-199
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) {
5252
return !options.ptr;
5353
}
5454

55+
MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerClientGetCompilationTask(
56+
MTRT_CompilerClient client, MlirStringRef taskMnemonic,
57+
const MlirStringRef *argv, unsigned argc, MlirPassManager *result);
58+
5559
//===----------------------------------------------------------------------===//
5660
// MTRT_OptionsContext
5761
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h

Lines changed: 125 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
#include "mlir-executor/Support/Status.h"
3131
#include "mlir-tensorrt/Compiler/OptionsProviders.h"
32+
#include "mlir-tensorrt/Compiler/OptionsRegistry.h"
3233
#include "mlir/IR/MLIRContext.h"
3334
#include "mlir/Pass/PassManager.h"
3435
#include "mlir/Support/TypeID.h"
@@ -100,28 +101,69 @@ class CompilerClient {
100101

101102
~CompilerClient() = default;
102103

103-
/// Create or retrieve a cached PassManager of the given derived type using
104-
/// the provided options. PassManagers are cached by type and a hash of the
105-
/// string representation of the options.
106-
/// This function should only be called if the options have a valid hash.
107-
template <typename CompilationTaskType, typename OptionsType>
108-
mlir::PassManager &getOrCreatePassManager(const OptionsType &options) {
109-
std::optional<llvm::hash_code> hash = options.getHash();
110-
if (!hash)
111-
llvm::report_fatal_error("attempted to lookup a PassManager from a cache "
112-
"with an un-hashable options key");
113-
114-
auto key =
115-
std::make_pair(mlir::TypeID::get<CompilationTaskType>(), hash.value());
104+
/// Create or retrieve from the cache a compilation task of the specified
105+
/// type and options. If an existing compilation task is not in the cache,
106+
/// then it is constructed using the registered construction function and
107+
/// inserted into the cache.
108+
StatusOr<CompilationTaskBase *>
109+
getCompilationTask(mlir::TypeID taskID,
110+
llvm::ArrayRef<llvm::StringRef> options);
111+
112+
/// Create or retrieve from the cache a compilation task of the specified
113+
/// type ID and options. If an existing compilation task is not in the cache,
114+
/// then it is constructed using the registered construction function and
115+
/// inserted into the cache.
116+
StatusOr<CompilationTaskBase *>
117+
getCompilationTask(mlir::TypeID taskID, llvm::ArrayRef<std::string> options) {
118+
return getCompilationTask(
119+
taskID, llvm::map_to_vector(options, [](const std::string &x) {
120+
return llvm::StringRef(x);
121+
}));
122+
}
123+
124+
StatusOr<CompilationTaskBase *>
125+
getCompilationTask(llvm::StringRef mnemonic,
126+
llvm::ArrayRef<llvm::StringRef> options);
127+
128+
/// Create or retrieve from the cache a compilation task of the specified
129+
/// type and options. If an existing compilation task is not in the cache,
130+
/// then it is constructed using the registered construction function and
131+
/// inserted into the cache.
132+
template <typename T, typename... Args>
133+
StatusOr<CompilationTaskBase *> getCompilationTask(Args &&...args) {
134+
return getCompilationTask(mlir::TypeID::get<T>(),
135+
std::forward<Args>(args)...);
136+
}
137+
138+
/// Insert a compilation task of type T with options hash `hash` into the
139+
/// cache.
140+
template <typename T>
141+
void updateCachedCompilationTask(const llvm::hash_code &hash,
142+
std::unique_ptr<CompilationTaskBase> task) {
143+
cachedPassManagers[std::make_pair(mlir::TypeID::get<T>(), hash)] =
144+
std::move(task);
145+
}
146+
147+
/// Check whether a CompilationTask with the specified typeID and whose
148+
/// options have the given hash is in the cache. If so, return it; otherwise
149+
/// returns nullptr.
150+
CompilationTaskBase *
151+
lookupCachedCompilationTask(mlir::TypeID taskID,
152+
const llvm::hash_code &optionsHash) {
153+
auto key = std::make_pair(taskID, optionsHash);
116154
auto it = cachedPassManagers.find(key);
117-
if (it == cachedPassManagers.end()) {
118-
auto pm = std::make_unique<CompilationTaskType>(context, options);
119-
setupPassManagerLogging(*pm, options.template get<DebugOptions>());
120-
auto *ptr = pm.get();
121-
cachedPassManagers[key] = std::move(pm);
122-
return *ptr;
123-
}
124-
return *it->second;
155+
if (it == cachedPassManagers.end())
156+
return nullptr;
157+
return it->second.get();
158+
}
159+
160+
/// Check whether a CompilationTask with the specified type T and whose
161+
/// options have the given hash is in the cache. If so, return it; otherwise
162+
/// returns nullptr.
163+
template <typename T>
164+
CompilationTaskBase *
165+
lookupCachedCompilationTask(const llvm::hash_code &optionsHash) {
166+
return lookupCachedCompilationTask(mlir::TypeID::get<T>(), optionsHash);
125167
}
126168

127169
/// Return the MLIRContext associated with the client.
@@ -147,6 +189,68 @@ class CompilerClient {
147189
cachedPassManagers;
148190
};
149191

192+
/// A registry function that creates a compilation task, or retrieves an
193+
/// existing one from the client's cache, from a list of string options.
194+
/// `client` is the CompilerClient the task is created for and `options` are
195+
/// the raw option strings to parse. Returns a pointer to the task on success,
196+
/// or an error status (e.g. if the options could not be parsed).
197+
using TaskRegistryFunction = std::function<StatusOr<CompilationTaskBase *>(
198+
CompilerClient &client, llvm::ArrayRef<llvm::StringRef> options)>;
199+
200+
struct TaskRegistration {
201+
TaskRegistryFunction registryFunc;
202+
};
203+
204+
void registerCompilationTask(llvm::StringRef mnemonic, mlir::TypeID typeID,
205+
TaskRegistryFunction func);
206+
207+
template <typename T>
208+
void registerCompilationTask(llvm::StringRef mnemonic,
209+
TaskRegistryFunction func) {
210+
return registerCompilationTask(mnemonic, mlir::TypeID::get<T>(),
211+
std::move(func));
212+
}
213+
214+
template <typename T, typename OptionsType>
215+
void registerCompilationTaskWithNoExtensions(llvm::StringRef mnemonic) {
216+
registerCompilationTask<T>(
217+
mnemonic,
218+
[](CompilerClient &client, llvm::ArrayRef<llvm::StringRef> options)
219+
-> StatusOr<CompilationTaskBase *> {
220+
OptionsType result;
221+
std::string err;
222+
if (failed(result.parse(options, err)))
223+
return getInvalidArgStatus(
224+
"failed to parse options string \"{0:$[ ]}\" due to error {1}",
225+
llvm::iterator_range(options), err);
226+
227+
llvm::Error finalizeStatus = result.finalize();
228+
std::optional<std::string> errMsg{};
229+
llvm::handleAllErrors(std::move(finalizeStatus),
230+
[&errMsg](const llvm::StringError &err) {
231+
errMsg = err.getMessage();
232+
});
233+
234+
if (errMsg)
235+
return getInvalidArgStatus("failed to parse options due to error {0}",
236+
errMsg);
237+
238+
std::optional<llvm::hash_code> hashCode = result.getHash();
239+
if (!hashCode)
240+
return getInvalidArgStatus("failed to hash options");
241+
242+
CompilationTaskBase *cached =
243+
client.lookupCachedCompilationTask<T>(*hashCode);
244+
if (cached)
245+
return cached;
246+
247+
auto newPM = std::make_unique<T>(client.getContext(), result);
248+
auto ptr = newPM.get();
249+
client.updateCachedCompilationTask<T>(*hashCode, std::move(newPM));
250+
return ptr;
251+
});
252+
}
253+
150254
} // namespace mlirtrt::compiler
151255

152256
#endif // MLIR_TENSORRT_COMPILER_CLIENT

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,24 @@ constexpr bool has_finalize_impl_v<
4747
// a default implementation otherwise.
4848
template <typename Derived>
4949
struct OptionsProvider {
50+
OptionsProvider(mlir::OptionsContext &ctx) : ctx(ctx) {}
51+
52+
// We don't allow move construction since the actual ptrs/locations of
53+
// individual member elements of an OptionsProvider are captured into the
54+
// OptionsContext. If the OptionsContext is populated upon construction,
55+
// moving can change the memory location of the owned values, which will cause
56+
// a crash later on. This in particular can happen if you are constructing
57+
// a tuple of `OptionsProviders`. Since we are deleting the move constructor,
58+
// one must instead use a tuple of `unique_ptr<OptionsProviders...>`.
59+
OptionsProvider(OptionsProvider &&) = delete;
60+
61+
mlir::OptionsContext &ctx;
62+
63+
template <typename T, typename... Mods>
64+
using Option = mlir::OptionsContext::Option<T, Mods...>;
65+
template <typename T, typename... Mods>
66+
using ListOption = mlir::OptionsContext::ListOption<T, Mods...>;
67+
5068
/// Modifies options after parsing. This is required since we may need
5169
/// to make changes to options based on the values of other options.
5270
/// Do *not* override this method; instead, implement `finalizeImpl()`.
@@ -62,67 +80,63 @@ struct OptionsProvider {
6280
/// interfaces.
6381
struct DebugOptions : public OptionsProvider<DebugOptions> {
6482
public:
83+
using OptionsProvider::OptionsProvider;
6584
/// A directory path where the IR will be dumped during compilation
6685
/// using the `mlir-print-ir-tree-dir` mechanism.
67-
std::string dumpIRPath = "";
86+
Option<std::string> dumpIRPath{&this->ctx, "mlir-print-ir-tree-dir",
87+
llvm::cl::init("")};
6888

6989
/// Whether the LLVM 'debug' flag that enables execution of code guarded by
7090
/// the `LLVM_DEBUG` macro should be set to 'on'. This results in very verbose
7191
/// output from the compiler dumped to stderr.
72-
bool enableLLVMDebugFlag = false;
92+
Option<bool> enableLLVMDebugFlag{&this->ctx, "debug", llvm::cl::init(false)};
7393

7494
/// A set of names to be given to the LLVM 'debug types' option, akin to
7595
/// setting
7696
/// `-debug-types=...` from the command line.
77-
mlir::SmallVector<std::string> llvmDebugTypes = {};
78-
79-
public:
80-
void addToOptions(mlir::OptionsContext &context) {
81-
context.addOption("mlir-print-ir-tree-dir", dumpIRPath, llvm::cl::init(""));
82-
context.addOption("debug", enableLLVMDebugFlag);
83-
context.addList<std::string>("debug-only", llvmDebugTypes,
84-
llvm::cl::ZeroOrMore,
85-
llvm::cl::CommaSeparated);
86-
}
97+
ListOption<std::string> llvmDebugTypes{
98+
&this->ctx, "debug-only", llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated};
8799
};
88100

89101
struct ExecutorOptions : public OptionsProvider<ExecutorOptions> {
90102
public:
91-
/// The host index bit-width.
92-
int64_t indexBitwidth{64};
103+
using OptionsProvider::OptionsProvider;
93104

94-
/// Whether to pass memref's as struct/table in function calls.
95-
bool usePackedMemRefCConv{true};
105+
Option<int64_t> indexBitwidth{&this->ctx, "executor-index-bitwidth",
106+
llvm::cl::init(64),
107+
llvm::cl::desc("executor index bitwidth")};
96108

97-
public:
98-
void addToOptions(mlir::OptionsContext &context) {
99-
context.addOption("executor-index-bitwidth", indexBitwidth,
100-
llvm::cl::init(64));
101-
}
109+
Option<bool> usePackedMemRefCConv{
110+
&this->ctx, "executor-use-packed-memref-cconv", llvm::cl::init(true),
111+
llvm::cl::desc(
112+
"whether to use packed or unpacked memref calling convention")};
102113
};
103114

104115
struct DeviceOptions : public OptionsProvider<DeviceOptions> {
105116
public:
117+
using OptionsProvider::OptionsProvider;
118+
119+
/// Device information. Members are manually bound to options in the
120+
/// constructor.
106121
DeviceInfo info;
107122

108-
/// Whether to ignore `deviceX` options and instead infer them from the GPUs
109-
/// on the host system running the compilation.
110-
bool shouldInferFromHost = false;
123+
Option<bool> shouldInferFromHost{
124+
&this->ctx, "device-infer-from-host", llvm::cl::init(true),
125+
llvm::cl::desc("whether to ignore `deviceX` options and instead infer "
126+
"them from the host GPU")};
127+
111128
Status inferFromHost();
112129

113130
public:
114-
void addToOptions(mlir::OptionsContext &context) {
115-
context.addOption(
131+
DeviceOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) {
132+
ctx.addOption(
116133
"device-compute-capability", info.computeCapability, llvm::cl::init(60),
117134
llvm::cl::desc("Sets the device compute capbility. Only relevant "
118135
"if '--device-infer-from-host=false'"));
119-
context.addOption("device-max-shared-memory-per-block-kb",
120-
info.maxSharedMemoryPerBlockKb, llvm::cl::init(48));
121-
context.addOption("device-max-registers-per-block",
122-
info.maxRegistersPerBlock, llvm::cl::init(65536));
123-
context.addOption("device-infer-from-host", shouldInferFromHost,
124-
llvm::cl::init(true),
125-
llvm::cl::desc("Infers device information from host"));
136+
ctx.addOption("device-max-shared-memory-per-block-kb",
137+
info.maxSharedMemoryPerBlockKb, llvm::cl::init(48));
138+
ctx.addOption("device-max-registers-per-block", info.maxRegistersPerBlock,
139+
llvm::cl::init(65536));
126140
}
127141

128142
llvm::Error finalizeImpl();

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsRegistry.h

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
#define MLIR_TENSORRT_COMPILER_OPTIONS_REGISTRY
2929

3030
#include "mlir-tensorrt-dialect/Utils/Options.h"
31-
#include "mlir-tensorrt/Compiler/Client.h"
3231
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
32+
#include "mlir/IR/MLIRContext.h"
3333
#include "llvm/ADT/ArrayRef.h"
3434
#include "llvm/ADT/StringRef.h"
3535
#include "llvm/Support/Error.h"
@@ -39,25 +39,23 @@ namespace mlirtrt::compiler {
3939

4040
using OptionsConstructorFuncT =
4141
std::function<StatusOr<std::unique_ptr<mlir::OptionsContext>>(
42-
const CompilerClient &client, const llvm::ArrayRef<llvm::StringRef>)>;
42+
mlir::MLIRContext *, llvm::ArrayRef<llvm::StringRef>)>;
4343

4444
/// Registers an options creation function for a specific options type.
45-
void registerOption(const llvm::StringRef optionsType,
46-
OptionsConstructorFuncT func);
45+
void registerOption(llvm::StringRef optionsType, OptionsConstructorFuncT func);
4746

4847
/// Creates an options instance for the specified options type using a creation
4948
/// function that was previously registered.
5049
StatusOr<std::unique_ptr<mlir::OptionsContext>>
51-
createOptions(const CompilerClient &client, const llvm::StringRef optionsType,
52-
const llvm::ArrayRef<llvm::StringRef> args);
50+
createOptions(mlir::MLIRContext *client, llvm::StringRef optionsType,
51+
llvm::ArrayRef<llvm::StringRef> args);
5352

5453
/// Helper to build callbacks that can create options.
5554
template <typename OptionsT, typename TaskT>
56-
StatusOr<std::unique_ptr<mlir::OptionsContext>>
57-
optionsCreateFromArgs(const CompilerClient &client,
58-
const llvm::ArrayRef<llvm::StringRef> args) {
55+
StatusOr<std::unique_ptr<OptionsT>>
56+
optionsCreateFromArgs(mlir::MLIRContext *context,
57+
llvm::ArrayRef<llvm::StringRef> args) {
5958
// Load available extensions.
60-
mlir::MLIRContext *context = client.getContext();
6159
mlir::plan::PlanDialect *planDialect =
6260
context->getLoadedDialect<mlir::plan::PlanDialect>();
6361
compiler::TaskExtensionRegistry extensions =
@@ -83,7 +81,7 @@ optionsCreateFromArgs(const CompilerClient &client,
8381
return getInternalErrorStatus("failed to initialize options: %s",
8482
errMsg->c_str());
8583

86-
return std::unique_ptr<mlir::OptionsContext>(result.release());
84+
return result;
8785
}
8886
} // namespace mlirtrt::compiler
8987

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,19 @@ struct StablehloToExecutableOptions
5858
StablehloToExecutableOptions(TaskExtensionRegistry extensions);
5959

6060
/// Whether to disallow host tensors in TensorRT clusters.
61-
bool disallowHostTensorsInTensorRTClusters = false;
61+
Option<bool> disallowHostTensorsInTensorRTClusters{
62+
this, "plan-clustering-disallow-host-tensors-in-tensorrt-clusters",
63+
llvm::cl::init(false),
64+
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
65+
"calculations (but they can still be inputs)")};
66+
67+
Option<std::string> entrypoint{this, "entrypoint", llvm::cl::init("main"),
68+
llvm::cl::desc("entrypoint function name")};
6269

6370
/// Use non-DPS style calling convention for entrypoint function
6471
/// and backend types that support allocating results.
6572
bool enableNonDPSReturns = false;
6673

67-
/// Entrypoint function name.
68-
std::string entrypoint = "main";
69-
7074
/// Base class for extensions associated with StableHloToExecutableTask.
7175
class ExtensionBase : public TaskExtensionBase {
7276
public:
@@ -134,13 +138,6 @@ class StablehloToExecutableTask
134138
static void populatePassManager(mlir::PassManager &pm,
135139
const StablehloToExecutableOptions &options);
136140

137-
/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
138-
/// This is the "functional" entrypoint that will allocate a new PassManager
139-
/// for a single run.
140-
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
141-
compileStableHLOToExecutable(mlir::ModuleOp module,
142-
const StablehloToExecutableOptions &options);
143-
144141
/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
145142
/// This is the "functional" entrypoint that will allocate a new PassManager
146143
/// for a single run.

0 commit comments

Comments
 (0)