Skip to content

Commit 75bd2b1

Browse files
init commit
1 parent 251102e commit 75bd2b1

File tree

8 files changed

+74
-22
lines changed

8 files changed

+74
-22
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
7777
const char **debugTypes, size_t debugTypeSizes,
7878
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);
7979

80+
/// Installs `callback` on `options` as the TensorRT translation metadata
/// callback. The callback is handed an `MlirOperation` and returns a C string.
/// NOTE(review): this declaration does not state who owns the returned string
/// or how long it must stay valid -- confirm and document the contract.
MLIR_CAPI_EXPORTED MTRT_Status
81+
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
82+
MTRT_StableHLOToExecutableOptions options,
83+
const char *(*callback)(MlirOperation));
84+
8085
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
8186
MTRT_StableHLOToExecutableOptions options);
8287

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir-executor/Runtime/API/API.h"
3535
#include "mlir-executor/Support/Status.h"
3636
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
37+
#include "mlir-tensorrt-dialect/Utils/Types.h"
3738
#include "mlir-tensorrt/Compiler/Client.h"
3839
#include "mlir-tensorrt/Compiler/Extension.h"
3940
#include "mlir-tensorrt/Compiler/Options.h"
@@ -130,6 +131,9 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
130131

131132
DebugOptions debugOptions;
132133

134+
// TODO: Add a sane default here:
135+
MetadataCallbackT layerMetadataCallback = [](MlirOperation op) { return ""; };
136+
133137
/// Base class for extensions associated with StableHloToExecutableTask.
134138
class ExtensionBase : public TaskExtensionBase {
135139
public:

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,15 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
199199
return mtrtStatusGetOk();
200200
}
201201

202+
/// Stores `callback` in the `layerMetadataCallback` field of the options
/// object wrapped by `options`.
///
/// \param options  Handle to the StableHLO-to-executable options to modify.
/// \param callback Function mapping an `MlirOperation` to a C string. A null
///                 pointer installs a no-op callback that returns an empty
///                 string instead of leaving the stored callback empty.
/// \return An OK status.
MTRT_Status
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
    MTRT_StableHLOToExecutableOptions options,
    const char *(*callback)(MlirOperation)) {
  StableHLOToExecutableOptions *cppOpts = unwrap(options);
  if (!callback) {
    // Assigning a null function pointer would leave the std::function empty,
    // and invoking an empty std::function raises std::bad_function_call.
    // Fall back to a callable that yields an empty string.
    cppOpts->layerMetadataCallback = [](MlirOperation) { return ""; };
    return mtrtStatusGetOk();
  }
  cppOpts->layerMetadataCallback = callback;
  return mtrtStatusGetOk();
}
210+
202211
MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
203212
MTRT_StableHLOToExecutableOptions options) {
204213
delete reinterpret_cast<StableHLOToExecutableOptions *>(options.ptr);

mlir-tensorrt/compiler/lib/Compiler/TensorRTExtension/TensorRTExtension.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ void StableHLOToExecutableTensorRTExtension::populatePasses(
6464
auto &trtPM = pm.nest<tensorrt::TensorRTModuleOp>();
6565
tensorrt::buildTensorRTModuleTransformationPipeline(
6666
trtPM, translationOptions.enableStronglyTyped);
67-
trtPM.addPass(
68-
tensorrt::createTranslateTensorRTPass(nullptr, translationOptions));
67+
trtPM.addPass(tensorrt::createTranslateTensorRTPass(
68+
nullptr, translationOptions, options.layerMetadataCallback));
6969
return;
7070
}
7171

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424
#ifndef MLIR_TENSORRT_TARGET_TENSORRT_TENSORRTENCODINGOPINTERFACE_NETWORKENCODER
2525
#define MLIR_TENSORRT_TARGET_TENSORRT_TENSORRTENCODINGOPINTERFACE_NETWORKENCODER
2626

27+
#include "mlir-c/IR.h"
2728
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2829
#include "mlir-tensorrt-dialect/Utils/NvInferAdaptor.h"
2930
#include "mlir-tensorrt-dialect/Utils/NvInferPluginUtils.h"
31+
#include "mlir-tensorrt-dialect/Utils/Types.h"
3032
#include "llvm/ADT/ScopedHashTable.h"
3133
#include "llvm/ADT/StringSet.h"
3234

@@ -76,9 +78,11 @@ class NvInferNetworkEncoder {
7678
public:
7779
NvInferNetworkEncoder(nvinfer1::INetworkDefinition *network,
7880
nvinfer1::IOptimizationProfile *profile,
79-
TensorRTVersion version, bool usesStronglyTyped)
81+
TensorRTVersion version, bool usesStronglyTyped,
82+
mlirtrt::MetadataCallbackT metadataCallback)
8083
: network(network), profile(profile), version(std::move(version)),
81-
usesStronglyTyped(usesStronglyTyped) {}
84+
usesStronglyTyped(usesStronglyTyped),
85+
layerMetadataCallback(std::move(metadataCallback)) {}
8286

8387
/// Lookup the TRT ITensor* equivalent of a Value.
8488
nvinfer1::ITensor *lookup(Value v) const;
@@ -238,6 +242,10 @@ class NvInferNetworkEncoder {
238242
bool hasQDQOps{false};
239243

240244
PluginManager pluginMgr;
245+
246+
// TODO: Where to use this? encodeOp doesn't have a way for us to access the
247+
// layers.
248+
std::function<std::string(MlirOperation)> layerMetadataCallback;
241249
};
242250

243251
//===----------------------------------------------------------------------===//

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TranslateToTensorRT.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h"
2525
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2626
#include "mlir-tensorrt-dialect/Utils/Options.h"
27+
#include "mlir-tensorrt-dialect/Utils/Types.h"
2728
#include "mlir/Support/LogicalResult.h"
2829
#include "llvm/Support/raw_ostream.h"
2930

@@ -208,19 +209,26 @@ class TensorRTSerializedTimingCache {
208209
/// `tensorrt.shape_profile` arguments have been populated for each argument
209210
/// that has unknown dimensions.
210211
/// TODO(cbate): add additional options here for builder configuration.
211-
FailureOr<TensorRTEngineResult>
212-
buildFunction(mlir::FunctionOpInterface op,
213-
TensorRTBuilderContext &builderContext,
214-
TensorRTSerializedTimingCache &serializedTimingCache,
215-
const TensorRTTranslationOptions &options =
216-
TensorRTTranslationOptions::fromCLFlags());
212+
FailureOr<TensorRTEngineResult> buildFunction(
213+
mlir::FunctionOpInterface op, TensorRTBuilderContext &builderContext,
214+
TensorRTSerializedTimingCache &serializedTimingCache,
215+
const TensorRTTranslationOptions &options =
216+
TensorRTTranslationOptions::fromCLFlags(),
217+
// TODO: Add a sane default here:
218+
mlirtrt::MetadataCallbackT layerMetadataCallback = [](MlirOperation op) {
219+
return "";
220+
});
217221

218222
/// Create an instance of a translate-to-tensorrt pass using an existing
219223
/// TensorRTBuilderContext.
220224
std::unique_ptr<mlir::Pass> createTranslateTensorRTPass(
221225
std::shared_ptr<tensorrt::TensorRTBuilderContext> context,
222226
TensorRTTranslationOptions options =
223-
TensorRTTranslationOptions::fromCLFlags());
227+
TensorRTTranslationOptions::fromCLFlags(),
228+
// TODO: Add a sane default here:
229+
mlirtrt::MetadataCallbackT layerMetadataCallback = [](MlirOperation op) {
230+
return "";
231+
});
224232

225233
/// Register llvm::cl opts related to TensorRT translation. This should be
226234
/// called before having LLVM parse CL options.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef MLIR_TENSORRT_UTILS_TYPES_H
#define MLIR_TENSORRT_UTILS_TYPES_H

// `MlirOperation` is the MLIR C-API operation handle; it is declared in
// "mlir-c/IR.h", so include it here to keep this header self-contained.
#include "mlir-c/IR.h"
#include "mlir/IR/Operation.h"
#include <functional>

namespace mlirtrt {
/// Signature of a metadata callback: takes an MLIR C-API operation handle and
/// returns a C string.
/// NOTE(review): ownership and lifetime of the returned string are not
/// specified here -- confirm with the callers that consume it.
using MetadataCallbackT = std::function<const char *(MlirOperation)>;
} // namespace mlirtrt

#endif // MLIR_TENSORRT_UTILS_TYPES_H

mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ FailureOr<TensorRTEngineResult>
336336
tensorrt::buildFunction(mlir::FunctionOpInterface op,
337337
TensorRTBuilderContext &builderContext,
338338
TensorRTSerializedTimingCache &serializedTimingCache,
339-
const TensorRTTranslationOptions &opts) {
339+
const TensorRTTranslationOptions &opts,
340+
mlirtrt::MetadataCallbackT layerMetadataCallback) {
340341
assert(builderContext.getBuilder() && "expected valid builder context");
341342
std::unique_ptr<nvinfer1::IBuilder> &builder = builderContext.getBuilder();
342343

@@ -357,9 +358,9 @@ tensorrt::buildFunction(mlir::FunctionOpInterface op,
357358
nvinfer1::IOptimizationProfile *optimProfile =
358359
builder->createOptimizationProfile();
359360

360-
NvInferNetworkEncoder encoder(network.get(), optimProfile,
361-
builderContext.getTensorRTVersion(),
362-
opts.enableStronglyTyped);
361+
NvInferNetworkEncoder encoder(
362+
network.get(), optimProfile, builderContext.getTensorRTVersion(),
363+
opts.enableStronglyTyped, layerMetadataCallback);
363364

364365
// Currently we only support single-block functions with unique return
365366
// terminator ops.
@@ -673,9 +674,10 @@ class TranslateToTensorRTEnginePass
673674

674675
explicit TranslateToTensorRTEnginePass(
675676
std::shared_ptr<TensorRTBuilderContext> builderContext,
676-
TensorRTTranslationOptions options)
677-
: builderContext(builderContext), translationOptions(std::move(options)) {
678-
}
677+
TensorRTTranslationOptions options,
678+
mlirtrt::MetadataCallbackT metadataCallback)
679+
: builderContext(builderContext), translationOptions(std::move(options)),
680+
layerMetadataCallback(std::move(metadataCallback)) {}
679681

680682
LogicalResult initialize(MLIRContext *context) final {
681683
if (!this->builderContext) {
@@ -742,8 +744,9 @@ class TranslateToTensorRTEnginePass
742744
continue;
743745
}
744746

745-
FailureOr<TensorRTEngineResult> engineResult = buildFunction(
746-
func, *builderContext, *timingCache, translationOptions);
747+
FailureOr<TensorRTEngineResult> engineResult =
748+
buildFunction(func, *builderContext, *timingCache, translationOptions,
749+
layerMetadataCallback);
747750
if (failed(engineResult) || !engineResult->serializedEngine) {
748751
func.emitError() << "failed to translate function '" << func.getName()
749752
<< "' to a TensorRT engine";
@@ -820,11 +823,15 @@ class TranslateToTensorRTEnginePass
820823

821824
/// Options affecting TensorRT translation.
822825
TensorRTTranslationOptions translationOptions;
826+
827+
mlirtrt::MetadataCallbackT layerMetadataCallback;
823828
};
824829
} // namespace
825830

826831
std::unique_ptr<mlir::Pass> tensorrt::createTranslateTensorRTPass(
827832
std::shared_ptr<tensorrt::TensorRTBuilderContext> context,
828-
TensorRTTranslationOptions options) {
829-
return std::make_unique<TranslateToTensorRTEnginePass>(context, options);
833+
TensorRTTranslationOptions options,
834+
mlirtrt::MetadataCallbackT layerMetadataCallback) {
835+
return std::make_unique<TranslateToTensorRTEnginePass>(context, options,
836+
layerMetadataCallback);
830837
}

0 commit comments

Comments
 (0)