Skip to content

Commit 280ecaa

Browse files
Addresses review comments
1 parent f0a824f commit 280ecaa

File tree

6 files changed

+43
-18
lines changed

6 files changed

+43
-18
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ typedef struct MTRT_StableHLOToExecutableOptions {
6060
void *ptr;
6161
} MTRT_StableHLOToExecutableOptions;
6262

63+
/// A callback that allows the user to customize the metadata set for layers
64+
/// corresponding to each MLIR operation. The callback should invoke the
65+
/// provided append function in order to manipulate the result string.
6366
typedef void (*MTRT_MetadataCallback)(MlirOperation op,
6467
MlirStringCallback append,
6568
void *appendCtx, void *userData);
@@ -81,6 +84,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
8184
const char **debugTypes, size_t debugTypeSizes,
8285
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);
8386

87+
/// Sets the layer metadata callback. The `userData` argument is passed along
88+
/// to the callback when it is invoked.
8489
MLIR_CAPI_EXPORTED MTRT_Status
8590
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
8691
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,32 @@
3333
#include "mlir/Pass/PassManager.h"
3434
#include "mlir/Support/TypeID.h"
3535
#include "llvm/ADT/DenseMap.h"
36+
#include "llvm/ADT/DenseMapInfo.h"
37+
#include "llvm/ADT/Hashing.h"
3638
#include <memory>
3739

40+
namespace llvm {
41+
42+
template <>
43+
struct DenseMapInfo<std::optional<llvm::hash_code>> {
44+
static inline std::optional<llvm::hash_code> getEmptyKey() { return {}; }
45+
static inline std::optional<llvm::hash_code> getTombstoneKey() {
46+
// TODO: how to provide a tombstone key?
47+
return {};
48+
}
49+
static unsigned getHashValue(const std::optional<llvm::hash_code> &val) {
50+
if (!val)
51+
return 0U;
52+
return hash_value(*val);
53+
}
54+
static bool isEqual(const std::optional<llvm::hash_code> &LHS,
55+
const std::optional<llvm::hash_code> &RHS) {
56+
return LHS == RHS;
57+
}
58+
};
59+
60+
} // namespace llvm
61+
3862
namespace mlirtrt::compiler {
3963

4064
//===----------------------------------------------------------------------===//
@@ -103,10 +127,11 @@ class CompilerClient {
103127
/// string representation of the options.
104128
template <typename CompilationTaskType, typename OptionsType>
105129
mlir::PassManager &getOrCreatePassManager(const OptionsType &options) {
106-
auto key = std::make_pair(mlir::TypeID::get<CompilationTaskType>(),
107-
options.getHash());
130+
auto hash = options.getHash();
131+
auto key = std::make_pair(mlir::TypeID::get<CompilationTaskType>(), hash);
108132
auto it = cachedPassManagers.find(key);
109-
if (it == cachedPassManagers.end() || options.shouldInvalidateCache()) {
133+
134+
if (!hash || it == cachedPassManagers.end()) {
110135
auto pm = std::make_unique<CompilationTaskType>(context, options);
111136
setupPassManagerLogging(*pm, options.debugOptions);
112137
auto *ptr = pm.get();
@@ -131,7 +156,8 @@ class CompilerClient {
131156
mlir::MLIRContext *context;
132157

133158
/// Key pair of [PassManager Kind, Options Hash].
134-
using PassManagerKey = std::pair<mlir::TypeID, llvm::hash_code>;
159+
using PassManagerKey =
160+
std::pair<mlir::TypeID, std::optional<llvm::hash_code>>;
135161

136162
/// A registry of pass managers for specific kinds of tasks. The map is
137163
/// indexed by the TypeID of the PassManager kind and the hash of the options

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,7 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
105105
/// Get the mutable DebugOptions.
106106
DebugOptions &getDebugOptions() { return debugOptions; }
107107

108-
llvm::hash_code getHash() const override;
109-
110-
bool shouldInvalidateCache() const override {
111-
// If a callback is provided, we have no way of verifying whether it is
112-
// equivalent to a callback from another set of options. Therefore, we are
113-
// forced to invalidate the cache entry if it is present at all.
114-
return static_cast<bool>(layerMetadataCallback);
115-
}
108+
std::optional<llvm::hash_code> getHash() const override;
116109

117110
/// The host index bit-width.
118111
int64_t executorIndexBitwidth{64};

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,12 @@ Status StableHLOToExecutableOptions::inferDeviceOptionsFromHost() {
271271
return Status::getOk();
272272
}
273273

274-
llvm::hash_code StableHLOToExecutableOptions::getHash() const {
275-
llvm::hash_code hash = OptionsContext::getHash();
274+
std::optional<llvm::hash_code> StableHLOToExecutableOptions::getHash() const {
275+
// If a callback is provided, we have no way of reliably hashing it.
276276
if (layerMetadataCallback)
277-
return llvm::hash_combine(hash, &layerMetadataCallback);
278-
return hash;
277+
return std::nullopt;
278+
279+
return OptionsContext::getHash();
279280
}
280281

281282
//===----------------------------------------------------------------------===//

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Options.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class OptionsContext : public llvm::cl::SubCommand {
149149
void print(llvm::raw_ostream &os) const;
150150

151151
/// Get a hash derived from the string representation of the options.
152-
virtual llvm::hash_code getHash() const;
152+
virtual std::optional<llvm::hash_code> getHash() const;
153153

154154
virtual bool shouldInvalidateCache() const { return false; }
155155

mlir-tensorrt/tensorrt/lib/Utils/Options.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void OptionsContext::print(llvm::raw_ostream &os) const {
6363
" ");
6464
}
6565

66-
llvm::hash_code OptionsContext::getHash() const {
66+
std::optional<llvm::hash_code> OptionsContext::getHash() const {
6767
// We hash by just hashing the string representation.
6868
llvm::SmallString<128> str;
6969
{

0 commit comments

Comments
 (0)