Skip to content

Commit 51bc478

Browse files
author
Copybara Bot
committed
Move internal changes
GitOrigin-RevId: 8dd523bab228ccfb83aeba17793cb75895d31bc6
1 parent 6c3162a commit 51bc478

File tree

139 files changed

+5074
-1757
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+5074
-1757
lines changed

mlir-tensorrt/compiler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ add_mlir_tensorrt_compiler_dependency(MLIRNVVMTarget)
3131
add_mlir_tensorrt_compiler_dependency(MLIRPtrDialect)
3232
add_mlir_tensorrt_compiler_dependency(MLIRTargetLLVM)
3333
add_mlir_tensorrt_compiler_dependency(MLIRTensorTransformOps)
34+
add_mlir_tensorrt_compiler_dependency(MLIREmitCExtDataLayoutImpl)
3435

3536
add_subdirectory(include)
3637
add_subdirectory(lib)

mlir-tensorrt/compiler/include/mlir-tensorrt/Backends/Host/HostBackend.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ include "mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td"
55
include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td"
66

77
def Plan_HostClusterKindAttr : Plan_Attr<"HostClusterKind", "host_cluster",
8-
[DeclareAttrInterfaceMethods<ClusterKindAttrInterface>]> {
8+
[DeclareAttrInterfaceMethods<ClusterKindAttrInterface, ["getDefaultMemorySpace"]>]> {
99
let parameters = (ins "int64_t":$benefit);
1010
let assemblyFormat = "`<` struct(params) `>`";
1111
}

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Extension.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,6 @@ class TaskExtensionBase {
7373
template <typename T, typename... Mods>
7474
using ListOption = mlir::detail::PassOptions::ListOption<T, Mods...>;
7575

76-
protected:
77-
/// Whether this extension is disabled. Should default to false and be
78-
/// associated with a flag `--disable-[name]-extension`.
79-
bool disabled{false};
80-
8176
private:
8277
mlir::TypeID typeID;
8378

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
/// Data structures and functions for manipulating compiler options.
2222
///
2323
//===----------------------------------------------------------------------===//
24-
#ifndef MLIR_TENSORRT_COMPILER_OPTIONS
25-
#define MLIR_TENSORRT_COMPILER_OPTIONS
24+
#ifndef MLIR_TENSORRT_COMPILER_OPTIONSPROVIDERS
25+
#define MLIR_TENSORRT_COMPILER_OPTIONSPROVIDERS
2626

2727
#include "mlir-executor/Support/DeviceInfo.h"
2828
#include "mlir/Pass/PassManager.h"
@@ -210,7 +210,7 @@ struct DeviceOptions : public OptionsProvider {
210210

211211
private:
212212
/// Stores host device info. This is populated by the callback of
213-
/// `shouldInfoFromHost`. If present, then it will also override the other
213+
/// `shouldInferFromHost`. If present, then it will also override the other
214214
/// options in their callbacks.
215215
std::optional<DeviceInfo> hostDeviceInfo{};
216216
};
@@ -316,7 +316,6 @@ class CompilationTaskOptionsBase
316316
llvm::cl::desc("entrypoint function name")};
317317

318318
protected:
319-
std::vector<std::unique_ptr<CompilationTaskOptionsBase>> extensions;
320319
std::unique_ptr<DebugOptions> debugOptions{nullptr};
321320
};
322321

@@ -363,4 +362,4 @@ class CompilationTaskOptions : public CompilationTaskOptionsBase {
363362

364363
} // namespace mlirtrt::compiler
365364

366-
#endif // MLIR_TENSORRT_COMPILER_OPTIONS
365+
#endif // MLIR_TENSORRT_COMPILER_OPTIONSPROVIDERS

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/StablehloToExecutable.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,6 @@ class StablehloToExecutableTask
138138

139139
static void populatePassManager(mlir::OpPassManager &pm,
140140
const StablehloToExecutableOptions &options);
141-
142-
/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
143-
/// This is the "functional" entrypoint that will allocate a new PassManager
144-
/// for a single run.
145-
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
146-
compileStableHLOToExecutable(CompilerClient &client, mlir::ModuleOp module,
147-
const StablehloToExecutableOptions &options);
148141
};
149142

150143
/// Register the task/options with the client's registry.

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ class StablehloToExecutableTensorRTExtension
6060
this->workspaceMemoryPoolLimit = options.workspaceMemoryPoolLimit;
6161
}
6262

63-
Option<bool> disable{this->ctx, "disable-tensorrt-extension",
64-
llvm::cl::init(false)};
63+
Option<bool> disabled{this->ctx, "disable-tensorrt-extension",
64+
llvm::cl::init(false)};
6565

6666
Option<TensorRTTargetFormat> format{
6767
this->ctx, "tensorrt-target",

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def ConvertStablehloToTensorRTPass : Pass<"convert-stablehlo-to-tensorrt"> {
4242
Option<"convertConditionals", "convert-conditionals", "bool", "true",
4343
"convert conditionals to TensorRT's conditional layer">,
4444
Option<"trtMajorVersion", "trt-major-version", "int64_t", "10",
45-
"target TensorRT version for conversion">
45+
"target TensorRT version for conversion">,
46+
Option<"preferEinsum", "prefer-einsum", "bool", "false",
47+
"prefer converting to 'tensorrt.einsum' over 'tensorrt.matrix_multiply'">
4648
];
4749
}
4850
#endif // MLIR_TENSORRT_ENABLE_HLO
@@ -321,7 +323,10 @@ def ConvertStablehloToScfPass : Pass<"convert-stablehlo-to-scf"> {
321323
}];
322324
let dependentDialects = [
323325
"::mlir::tensor::TensorDialect",
324-
"::mlir::scf::SCFDialect"
326+
"::mlir::scf::SCFDialect",
327+
"::mlir::tensor::TensorDialect",
328+
"::mlir::arith::ArithDialect",
329+
"::mlir::math::MathDialect"
325330
];
326331
}
327332

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,20 @@
3131
namespace mlir {
3232
class ConversionTarget;
3333

34-
// Collection of rewrite patterns for lowering of Stable HLO to TensorRT
35-
// dialect.
34+
/// Populate patterns for converting Stablehlo reduction and contraction ops to
35+
/// TensorRT.
36+
void populateStablehloReductionAndContractionToTensorRtConversionPattern(
37+
TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns,
38+
PatternBenefit benefit = 1, PatternBenefit dotToEinsumBenefit = 0);
39+
40+
/// Collection of rewrite patterns for lowering of Stable HLO to TensorRT
41+
/// dialect.
42+
/// The `preferEinsum` parameter controls whether `tensorrt.einsum` is used
43+
/// as the primary method for converting `stablehlo.dot_general` or only for
44+
/// fallback when conversion to `tensorrt.matrix_multiply` is not possible.
3645
void populateStablehloToTensorRtConversionPattern(
3746
TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns,
38-
ShapeInfoCallbacks shapeInfoCallbacks = {});
47+
ShapeInfoCallbacks shapeInfoCallbacks = {}, bool preferEinsum = false);
3948

4049
/// Populate patterns for convert Chlo ops to TensorRT ops.
4150
void populateChloToTensorRtLegalityAndPatterns(

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -208,26 +208,6 @@ class ConvertOpToTensorRTPattern : public ConvertToTensorRTPattern {
208208
: ConvertToTensorRTPattern(typeConverter, SourceOp::getOperationName(),
209209
benefit, context) {}
210210

211-
/// Wrappers around the ConversionPattern methods that pass the derived op
212-
/// type.
213-
LogicalResult match(Operation *op) const final {
214-
return match(cast<SourceOp>(op));
215-
}
216-
void rewrite(Operation *op, ArrayRef<Value> operands,
217-
ConversionPatternRewriter &rewriter) const final {
218-
if constexpr (SourceOp::hasProperties())
219-
return rewrite(cast<SourceOp>(op),
220-
OpAdaptor(operands, op->getAttrDictionary(),
221-
cast<SourceOp>(op).getProperties()),
222-
rewriter);
223-
rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
224-
rewriter);
225-
}
226-
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
227-
ConversionPatternRewriter &rewriter) const final {
228-
auto sourceOp = cast<SourceOp>(op);
229-
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
230-
}
231211
LogicalResult
232212
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
233213
ConversionPatternRewriter &rewriter) const override {
@@ -248,32 +228,10 @@ class ConvertOpToTensorRTPattern : public ConvertToTensorRTPattern {
248228
rewriter);
249229
}
250230

251-
/// Rewrite and Match methods that operate on the SourceOp type. These must be
252-
/// overridden by the derived pattern class.
253-
virtual LogicalResult match(SourceOp op) const {
254-
(void)op;
255-
llvm_unreachable("must override match or matchAndRewrite");
256-
}
257-
virtual void rewrite(SourceOp op, OpAdaptor adaptor,
258-
ConversionPatternRewriter &rewriter) const {
259-
(void)op;
260-
(void)adaptor;
261-
(void)rewriter;
262-
llvm_unreachable("must override matchAndRewrite or a rewrite method");
263-
}
264-
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
265-
ConversionPatternRewriter &rewriter) const {
266-
SmallVector<Value> oneToOneOperands =
267-
getOneToOneAdaptorOperands(adaptor.getOperands());
268-
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
269-
}
270231
virtual LogicalResult
271232
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
272233
ConversionPatternRewriter &rewriter) const {
273-
if (failed(match(op)))
274-
return failure();
275-
rewrite(op, adaptor, rewriter);
276-
return success();
234+
llvm_unreachable("must override matchAndRewrite");
277235
}
278236
virtual LogicalResult
279237
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.h"
2828
#include "mlir-tensorrt/Compiler/Extension.h"
29+
#include "mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h"
2930
#include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h"
3031
#include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h"
3132
#include "mlir/Bytecode/BytecodeOpInterface.h"
@@ -134,11 +135,6 @@ class PlanDialectExtension
134135
};
135136
} // namespace mlir::plan
136137

137-
//===----------------------------------------------------------------------===//
138-
// Plan Enums
139-
//===----------------------------------------------------------------------===//
140-
#include "mlir-tensorrt/Dialect/Plan/IR/PlanEnums.h.inc"
141-
142138
//===----------------------------------------------------------------------===//
143139
// Plan Attributes
144140
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)