Skip to content

Commit 3b82968

Browse files
Adds a mechanism to invalidate cache entries when a callback is provided
1 parent 1765e21 commit 3b82968

File tree

8 files changed

+73
-14
lines changed

8 files changed

+73
-14
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ class CompilerClient {
106106
auto key = std::make_pair(mlir::TypeID::get<CompilationTaskType>(),
107107
options.getHash());
108108
auto it = cachedPassManagers.find(key);
109-
if (it == cachedPassManagers.end()) {
109+
if (it == cachedPassManagers.end() || options.shouldInvalidateCache()) {
110110
auto pm = std::make_unique<CompilationTaskType>(context, options);
111111
setupPassManagerLogging(*pm, options.debugOptions);
112112
auto *ptr = pm.get();
113-
cachedPassManagers.insert(std::make_pair(key, std::move(pm)));
113+
cachedPassManagers[key] = std::move(pm);
114114
return *ptr;
115115
}
116116
return *it->second;

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
105105
/// Get the mutable DebugOptions.
106106
DebugOptions &getDebugOptions() { return debugOptions; }
107107

108+
llvm::hash_code getHash() const override;
109+
110+
bool shouldInvalidateCache() const override {
111+
// If a callback is provided, we have no way of verifying whether it is
112+
// equivalent to a callback from another set of options. Therefore, we are
113+
// forced to invalidate any existing cache entry whenever a callback is set.
114+
return static_cast<bool>(layerMetadataCallback);
115+
}
116+
108117
/// The host index bit-width.
109118
int64_t executorIndexBitwidth{64};
110119

@@ -129,8 +138,7 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
129138

130139
DebugOptions debugOptions;
131140

132-
std::function<std::string(mlir::Operation *)> layerMetadataCallback =
133-
[](mlir::Operation *) { return ""; };
141+
std::function<std::string(mlir::Operation *)> layerMetadataCallback{nullptr};
134142

135143
/// Base class for extensions associated with StableHloToExecutableTask.
136144
class ExtensionBase : public TaskExtensionBase {

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,13 @@ Status StableHLOToExecutableOptions::inferDeviceOptionsFromHost() {
271271
return Status::getOk();
272272
}
273273

274+
llvm::hash_code StableHLOToExecutableOptions::getHash() const {
275+
llvm::hash_code hash = OptionsContext::getHash();
276+
if (layerMetadataCallback)
277+
return llvm::hash_combine(hash, &layerMetadataCallback);
278+
return hash;
279+
}
280+
274281
//===----------------------------------------------------------------------===//
275282
// StableHloToExecutableTask
276283
//===----------------------------------------------------------------------===//

mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class PyStableHLOToExecutableOptions
7171

7272
// We need this member so we can keep the Python callback alive long enough.
7373
std::function<std::string(MlirOperation)> callback;
74+
75+
~PyStableHLOToExecutableOptions() { callback = nullptr; }
7476
};
7577
} // namespace
7678

@@ -323,8 +325,7 @@ PYBIND11_MODULE(_api, m) {
323325
THROW_IF_MTRT_ERROR(status);
324326
return new PyExecutable(exe);
325327
},
326-
py::arg("client"), py::arg("module"), py::arg("options"),
327-
py::keep_alive<1, 3>());
328+
py::arg("client"), py::arg("module"), py::arg("options"));
328329

329330
m.def(
330331
"get_stablehlo_program_refined_signature",

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TranslateToTensorRT.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,7 @@ FailureOr<TensorRTEngineResult> buildFunction(
212212
TensorRTSerializedTimingCache &serializedTimingCache,
213213
const TensorRTTranslationOptions &options =
214214
TensorRTTranslationOptions::fromCLFlags(),
215-
std::function<std::string(Operation *)> layerMetadataCallback =
216-
[](Operation *op) { return ""; });
215+
std::function<std::string(Operation *)> layerMetadataCallback = nullptr);
217216

218217
/// Create an instance of a translate-to-tensorrt pass using an existing
219218
/// TensorRTBuilderContext.

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Options.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ namespace mlir {
8383
/// ```
8484
class OptionsContext : public llvm::cl::SubCommand {
8585
public:
86+
OptionsContext() = default;
87+
OptionsContext(const OptionsContext &) = delete;
88+
OptionsContext(OptionsContext &&) = default;
89+
virtual ~OptionsContext() = default;
90+
8691
/// Add an option to this context. The storage `value` must outlive the
8792
/// OptionsContext.
8893
template <typename DataType, typename... Mods>
@@ -124,7 +129,9 @@ class OptionsContext : public llvm::cl::SubCommand {
124129
void print(llvm::raw_ostream &os) const;
125130

126131
/// Get a hash derived from the string representation of the options.
127-
llvm::hash_code getHash() const;
132+
virtual llvm::hash_code getHash() const;
133+
134+
virtual bool shouldInvalidateCache() const { return false; }
128135

129136
private:
130137
struct OptionInfo {

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ void NvInferNetworkEncoder::setMetadata(nvinfer1::ILayer *layer,
264264
Operation *sourceOp) {
265265
std::string name = createName(namesSet, sourceOp);
266266
layer->setName(name.c_str());
267-
layer->setMetadata(layerMetadataCallback(sourceOp).c_str());
267+
if (layerMetadataCallback)
268+
layer->setMetadata(layerMetadataCallback(sourceOp).c_str());
268269
}
269270

270271
nvinfer1::ITensor *NvInferNetworkEncoder::lookup(Value v) const {

mlir-tensorrt/test/python/mlir_tensorrt_compiler/compiler_api/test_layer_metadata_callback.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import glob
1010
import os
1111
import json
12+
import gc
1213

1314
STATIC_ASM = """
1415
func.func @main(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
@@ -19,17 +20,14 @@
1920

2021

2122
def layer_metadata_callback(op) -> str:
23+
print("layer_metadata_callback CALLED")
2224
return "TEST_CUSTOM_METADATA"
2325

2426

2527
def compile_asm():
2628
with Context() as context:
2729
m = Module.parse(STATIC_ASM)
2830
client = api.CompilerClient(context)
29-
opts = api.StableHLOToExecutableOptions(
30-
client,
31-
["--tensorrt-builder-opt-level=3", "--tensorrt-strongly-typed=false"],
32-
)
3331

3432
with tempfile.TemporaryDirectory() as tmp:
3533
opts = api.StableHLOToExecutableOptions(
@@ -58,7 +56,45 @@ def compile_asm():
5856
# CHECK-LABEL: Compiling ASM
5957
# CHECK: [translate-to-tensorrt] TranslateToTensorRTEnginePass is generating a new TensorRT builder
6058
# CHECK: [translate-to-tensorrt] timing cache path was not specified, creating a fresh timing cache
59+
# CHECK: layer_metadata_callback CALLED
6160
# CHECK: [translate-to-tensorrt] deserializing TensorRT builder timing cache (0 bytes)
6261
# CHECK: [translate-to-tensorrt] Setting builder optimization level to 3
6362
# CHECK: [translate-to-tensorrt] replacing cache with updated data (0 -> 2057 bytes)
6463
# CHECK: TEST_CUSTOM_METADATA
64+
65+
66+
def layer_metadata_callback2(op) -> str:
67+
print("layer_metadata_callback2 CALLED")
68+
return "TEST_CUSTOM_METADATA2"
69+
70+
71+
def compile_multiple():
72+
# Compile multiple times with different callbacks to ensure pass manager caching doesn't
73+
# cause issues.
74+
with Context() as context:
75+
m = Module.parse(STATIC_ASM)
76+
client = api.CompilerClient(context)
77+
opts0 = api.StableHLOToExecutableOptions(
78+
client,
79+
["--tensorrt-builder-opt-level=3", "--tensorrt-strongly-typed=false"],
80+
)
81+
opts0.set_tensorrt_translation_metadata_callback(layer_metadata_callback)
82+
api.compiler_stablehlo_to_executable(client, m.operation.clone(), opts0)
83+
84+
del opts0
85+
gc.collect()
86+
87+
opts1 = api.StableHLOToExecutableOptions(
88+
client,
89+
["--tensorrt-builder-opt-level=3", "--tensorrt-strongly-typed=false"],
90+
)
91+
opts1.set_tensorrt_translation_metadata_callback(layer_metadata_callback2)
92+
api.compiler_stablehlo_to_executable(client, m.operation.clone(), opts1)
93+
94+
95+
print("Checking multiple compile calls")
96+
compile_multiple()
97+
98+
# CHECK-LABEL: Checking multiple compile calls
99+
# CHECK: layer_metadata_callback CALLED
100+
# CHECK: layer_metadata_callback2 CALLED

0 commit comments

Comments
 (0)