Skip to content

Commit 821f9ec

Browse files
Addresses review comments
1 parent f0a824f commit 821f9ec

File tree

6 files changed

+36
-21
lines changed

6 files changed

+36
-21
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ typedef struct MTRT_StableHLOToExecutableOptions {
6060
void *ptr;
6161
} MTRT_StableHLOToExecutableOptions;
6262

63+
/// A callback that allows the user to customize the metadata set for layers
64+
/// corresponding to each MLIR operation. The callback should invoke the
65+
/// provided append function in order to manipulate the result string.
6366
typedef void (*MTRT_MetadataCallback)(MlirOperation op,
6467
MlirStringCallback append,
6568
void *appendCtx, void *userData);
@@ -81,6 +84,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
8184
const char **debugTypes, size_t debugTypeSizes,
8285
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);
8386

87+
/// Sets the layer metadata callback. The `userData` argument is passed along
88+
/// to the callback when it is invoked.
8489
MLIR_CAPI_EXPORTED MTRT_Status
8590
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
8691
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/Pass/PassManager.h"
3434
#include "mlir/Support/TypeID.h"
3535
#include "llvm/ADT/DenseMap.h"
36+
#include "llvm/ADT/Hashing.h"
3637
#include <memory>
3738

3839
namespace mlirtrt::compiler {
@@ -101,12 +102,18 @@ class CompilerClient {
101102
/// Create or retrieve a cached PassManager of the given derived type using
102103
/// the provided options. PassManagers are cached by type and a hash of the
103104
/// string representation of the options.
105+
/// This function should only be called if the options have a valid hash.
104106
template <typename CompilationTaskType, typename OptionsType>
105107
mlir::PassManager &getOrCreatePassManager(const OptionsType &options) {
106-
auto key = std::make_pair(mlir::TypeID::get<CompilationTaskType>(),
107-
options.getHash());
108+
auto hash = options.getHash();
109+
110+
assert(hash);
111+
112+
auto key =
113+
std::make_pair(mlir::TypeID::get<CompilationTaskType>(), hash.value());
108114
auto it = cachedPassManagers.find(key);
109-
if (it == cachedPassManagers.end() || options.shouldInvalidateCache()) {
115+
116+
if (it == cachedPassManagers.end()) {
110117
auto pm = std::make_unique<CompilationTaskType>(context, options);
111118
setupPassManagerLogging(*pm, options.debugOptions);
112119
auto *ptr = pm.get();
@@ -119,7 +126,6 @@ class CompilerClient {
119126
/// Return the MLIRContext associated with the client.
120127
mlir::MLIRContext *getContext() const { return context; }
121128

122-
private:
123129
/// Helper for setting the correct logging options on cached PassManagers.
124130
static void setupPassManagerLogging(mlir::PassManager &pm,
125131
const DebugOptions &options);

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,7 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
105105
/// Get the mutable DebugOptions.
106106
DebugOptions &getDebugOptions() { return debugOptions; }
107107

108-
llvm::hash_code getHash() const override;
109-
110-
bool shouldInvalidateCache() const override {
111-
// If a callback is provided, we have no way of verifying whether it is
112-
// equivalent to a callback from another set of options. Therefore, we are
113-
// forced to invalidate the cache entry if it is present at all.
114-
return static_cast<bool>(layerMetadataCallback);
115-
}
108+
std::optional<llvm::hash_code> getHash() const override;
116109

117110
/// The host index bit-width.
118111
int64_t executorIndexBitwidth{64};

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@
4040
#include "mlir/Dialect/Arith/IR/Arith.h"
4141
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
4242
#include "mlir/Dialect/Func/IR/FuncOps.h"
43+
#include "mlir/Pass/Pass.h"
4344
#include "mlir/Pass/PassManager.h"
4445
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
4546
#include "mlir/Transforms/Passes.h"
4647
#include "stablehlo/dialect/StablehloOps.h"
4748
#include "llvm/Support/CommandLine.h"
4849
#include "llvm/Support/Debug.h"
4950
#include "llvm/Support/raw_ostream.h"
51+
#include <functional>
5052
#include <memory>
5153

5254
#define DEBUG_TYPE "compiler-api"
@@ -271,11 +273,12 @@ Status StableHLOToExecutableOptions::inferDeviceOptionsFromHost() {
271273
return Status::getOk();
272274
}
273275

274-
llvm::hash_code StableHLOToExecutableOptions::getHash() const {
275-
llvm::hash_code hash = OptionsContext::getHash();
276+
std::optional<llvm::hash_code> StableHLOToExecutableOptions::getHash() const {
277+
// If a callback is provided, we have no way of reliably hashing it.
276278
if (layerMetadataCallback)
277-
return llvm::hash_combine(hash, &layerMetadataCallback);
278-
return hash;
279+
return std::nullopt;
280+
281+
return OptionsContext::getHash();
279282
}
280283

281284
//===----------------------------------------------------------------------===//
@@ -473,11 +476,19 @@ StableHloToExecutableTask::compileStableHLOToExecutable(
473476
}
474477
#endif
475478

476-
mlir::PassManager &runner =
477-
client.getOrCreatePassManager<StableHloToExecutableTask>(options);
479+
mlir::PassManager *runner;
480+
481+
if (options.getHash())
482+
runner = &client.getOrCreatePassManager<StableHloToExecutableTask>(options);
483+
else {
484+
auto pm = std::make_unique<StableHloToExecutableTask>(client.getContext(),
485+
options);
486+
CompilerClient::setupPassManagerLogging(*pm, options.debugOptions);
487+
runner = pm.get();
488+
}
478489

479490
// Setup pass manager
480-
if (failed(runner.run(module)))
491+
if (failed(runner->run(module)))
481492
return getInternalErrorStatus(
482493
"failed to run compilation on module with symbol name: {0}",
483494
module.getName() ? *module.getName() : "no-symbol-name");

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Options.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class OptionsContext : public llvm::cl::SubCommand {
149149
void print(llvm::raw_ostream &os) const;
150150

151151
/// Get a hash derived from the string representation of the options.
152-
virtual llvm::hash_code getHash() const;
152+
virtual std::optional<llvm::hash_code> getHash() const;
153153

154154
virtual bool shouldInvalidateCache() const { return false; }
155155

mlir-tensorrt/tensorrt/lib/Utils/Options.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void OptionsContext::print(llvm::raw_ostream &os) const {
6363
" ");
6464
}
6565

66-
llvm::hash_code OptionsContext::getHash() const {
66+
std::optional<llvm::hash_code> OptionsContext::getHash() const {
6767
// We hash by just hashing the string representation.
6868
llvm::SmallString<128> str;
6969
{

0 commit comments

Comments
 (0)