7 changes: 7 additions & 0 deletions tripy/tests/backend/mlir/test_compiler.py
@@ -69,3 +69,10 @@ def test_reason_context(self):
match=".*This is the first level of context\n This is the second level of context.\n.*",
) as exc:
map_error_to_user_code_and_raise(flat_ir, exc, err_str)

def test_layer_metadata_callback(self):
# TODO: Finish this:
inp = tp.ones((2, 2))
out = tp.gelu(inp)

out.eval()
34 changes: 31 additions & 3 deletions tripy/tripy/backend/mlir/compiler.py
@@ -26,8 +26,12 @@
make_ir_context,
map_error_to_user_code_and_raise,
redirect_stderr,
parse_tensor_names_from_location,
UNKNOWN_LOC,
)
from tripy.logging import logger
from tripy.common.exception import str_from_source_info


G_COMPILER_CLIENT = None
G_TIMING_CACHE_FILE = cfg.timing_cache_file_path
@@ -53,7 +57,7 @@ def __init__(self, trt_builder_opt_level=0) -> None:
self.mlir_context, self.compiler_client = _get_compiler_objects()
self.trt_builder_opt_level = trt_builder_opt_level

def _make_mlir_opts(self, trt_builder_opt_level):
def _make_mlir_opts(self, trt_builder_opt_level, layer_metadata_callback=None):
opts = [
f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
@@ -67,7 +71,13 @@ def _make_mlir_opts(self, trt_builder_opt_level):
if config.enable_tensorrt_debug:
opts.append(f"--tensorrt-layer-info-dir={config.tensorrt_debug_path}")
opts.append(f"--tensorrt-engines-dir={config.tensorrt_debug_path}")
return compiler.StableHLOToExecutableOptions(self.compiler_client, opts)

opts = compiler.StableHLOToExecutableOptions(self.compiler_client, opts)

if layer_metadata_callback is not None:
opts.set_tensorrt_translation_metadata_callback(layer_metadata_callback)

return opts

def compile_stabehlo_program(self, code: str) -> compiler.Executable:
with self.mlir_context:
@@ -95,7 +105,25 @@ def infer_shapes(self, mlir_module: ir.Module, flat_ir: Optional["FlatIR"] = Non
@utils.log_time
def compile(self, mlir_module: ir.Module, flat_ir: Optional["FlatIR"] = None) -> compiler.Executable:
logger.mlir(lambda: f"{mlir_module.operation.get_asm(large_elements_limit=32)}\n")
opts = self._make_mlir_opts(self.trt_builder_opt_level)

def layer_metadata_callback(op):
if UNKNOWN_LOC in str(op.location):
return str(op.name)

# _, _, _, trace_outputs, _ = parse_tensor_names_from_location(str(op.location))

# for name in trace_outputs:
# if name in flat_ir.tensor_map:
# tensor = flat_ir.tensor_map[name]
# user_frame_index = tensor.stack_info.get_first_user_frame_index()
# if user_frame_index is None:
# continue
# user_frame = tensor.stack_info[user_frame_index]
# return str_from_source_info(user_frame, enable_color=False)

return str(op.name)

opts = self._make_mlir_opts(self.trt_builder_opt_level, layer_metadata_callback)

try:
with redirect_stderr() as outfile:
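For reference, a sketch of what layer_metadata_callback might look like once the commented-out lookup above is enabled. This is not part of the diff; it simply reassembles the commented code and assumes that parse_tensor_names_from_location returns the trace output names as its fourth result and that FlatIR exposes tensor_map and per-tensor stack_info as used there.

def layer_metadata_callback(op):
    # Fall back to the op name when no source location information is available.
    if UNKNOWN_LOC in str(op.location):
        return str(op.name)

    # Map the MLIR location back to trace output tensor names.
    _, _, _, trace_outputs, _ = parse_tensor_names_from_location(str(op.location))

    # flat_ir is closed over from compile() and may be None.
    if flat_ir is not None:
        for name in trace_outputs:
            if name not in flat_ir.tensor_map:
                continue
            tensor = flat_ir.tensor_map[name]
            user_frame_index = tensor.stack_info.get_first_user_frame_index()
            if user_frame_index is None:
                continue
            user_frame = tensor.stack_info[user_frame_index]
            return str_from_source_info(user_frame, enable_color=False)

    return str(op.name)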
26 changes: 12 additions & 14 deletions tripy/tripy/backend/mlir/executor.py
@@ -20,7 +20,7 @@
import mlir_tensorrt.runtime.api as runtime

from tripy.backend.api.stream import default_stream
from tripy.backend.mlir.memref import create_empty_memref
from tripy.backend.mlir.memref import create_memref
from tripy.backend.mlir.utils import MLIRRuntimeClient, convert_runtime_dtype_to_tripy_dtype
from tripy.backend.utils import TensorInfo
from tripy.common import datatype, device
@@ -31,6 +31,7 @@

class Executor:
def __init__(self, executable: runtime.Executable) -> None:

self.runtime_client = MLIRRuntimeClient()
session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
self.session = runtime.RuntimeSession(session_options, executable)
@@ -47,15 +48,16 @@ def __init__(self, executable: runtime.Executable) -> None:
def _create_shape_memref(self, shape):
shape = make_tuple(shape)
if len(shape) == 0:
# create an empty memref
return self.runtime_client.create_memref(
shape=(0,), dtype=runtime.runtime.ScalarTypeCode.i64, stream=self.stream._active_cuda_stream
return create_memref(
shape=(0,),
dtype=datatype.int64,
device=device("cpu"),
)
return self.runtime_client.create_memref(
convert_list_to_array(shape, datatype.int64),
return create_memref(
array=convert_list_to_array(shape, datatype.int64),
shape=(len(shape),),
dtype=runtime.ScalarTypeCode.i64,
stream=self.stream._active_cuda_stream,
dtype=datatype.int64,
device=device("cpu"),
)

def _get_outputs_shape(self):
@@ -134,12 +136,8 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->

# Allocate output memory and store buffer pointers.
outputs = [
create_empty_memref(
shape=info.shape,
dtype=info.dtype,
device=info.device,
stream=self.stream._active_cuda_stream,
use_cache=False,
create_memref(
shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
)
for info in out_tensor_info
]
69 changes: 32 additions & 37 deletions tripy/tripy/backend/mlir/memref.py
@@ -15,54 +15,49 @@
# limitations under the License.
#

import math
import re

from functools import lru_cache
from typing import Sequence
import mlir_tensorrt.runtime.api as runtime

from tripy.utils import raise_error
from tripy.backend.mlir import utils as mlir_utils
from tripy.common import device as tp_device
from tripy.common import utils as common_utils
from tripy.utils import raise_error

import mlir_tensorrt.runtime.api as runtime
EMPTY_MEMREF_CACHE = {}


@lru_cache(maxsize=None)
def _cached_create_empty_memref(shape: Sequence[int], dtype: str, device_kind: str, stream):
mlirtrt_device = mlir_utils.MLIRRuntimeClient().get_devices()[0] if device_kind == "gpu" else None
mlirtrt_stream = stream if device_kind == "gpu" else None
mlir_dtype = mlir_utils.convert_tripy_dtype_to_runtime_dtype(dtype)
return mlir_utils.MLIRRuntimeClient().create_memref(
shape=list(shape),
dtype=mlir_dtype,
device=mlirtrt_device,
stream=mlirtrt_stream,
)


def create_empty_memref(
shape: Sequence[int],
dtype: str,
device: tp_device = tp_device(("gpu", 0)),
stream=None,
use_cache: bool = True,
):
def create_memref(shape, dtype, device=tp_device("gpu"), stream=None, array=None):
"""
Creates a memref. If an array is provided, the memref is populated with its values;
otherwise, an empty memref is created.
"""
Creates an empty memref, used for allocating memory.
Caches the result for subsequent calls with the same parameters.
is_empty_shape = math.prod(shape) == 0
cache_key = (shape, dtype, device.kind, device.index)
if is_empty_shape and cache_key in EMPTY_MEMREF_CACHE:
return EMPTY_MEMREF_CACHE[cache_key]

Args:
use_cache (bool, optional): Whether to use cached results for repeated calls with the same parameters.
If True, returns cached results if available. If False, always creates a new memref.
Defaults to True. This ensures we reuse empty memref across functions.
mlir_dtype = mlir_utils.convert_tripy_dtype_to_runtime_dtype(dtype)

"""
if use_cache:
assert common_utils.is_shape_empty(shape)
return _cached_create_empty_memref(tuple(shape), dtype, device.kind, stream)
else:
return _cached_create_empty_memref.__wrapped__(tuple(shape), dtype, device.kind, stream)
args = []

# "array" is marked as a positional-only argument
if array is not None:
args.append(array)

kwargs = {"shape": shape, "dtype": mlir_dtype}

if device.kind == "gpu":
kwargs["device"] = mlir_utils.MLIRRuntimeClient().get_devices()[device.index]
# Streams are only allowed for GPU allocations.
kwargs["stream"] = stream

memref = mlir_utils.MLIRRuntimeClient().create_memref(*args, **kwargs)

if is_empty_shape:
EMPTY_MEMREF_CACHE[cache_key] = memref

return memref


def create_memref_view(data):
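A hypothetical usage sketch for the new create_memref helper, based only on the signature and the call sites shown in this diff (the import paths and the device constructor arguments are assumptions):

from tripy.backend.mlir.memref import create_memref
from tripy.common import datatype, device
from tripy.common.utils import convert_list_to_array

# Empty device allocation, as the executor does for output buffers.
out_buf = create_memref(shape=(2, 3), dtype=datatype.float32, device=device("gpu"))

# Host memref populated from data, as the executor does for shape memrefs.
shape_buf = create_memref(
    array=convert_list_to_array([2, 3], datatype.int64),
    shape=(2,),
    dtype=datatype.int64,
    device=device("cpu"),
)

# Zero-sized shapes go through EMPTY_MEMREF_CACHE, so repeated calls reuse one memref.
empty_a = create_memref(shape=(0,), dtype=datatype.float32, device=device("cpu"))
empty_b = create_memref(shape=(0,), dtype=datatype.float32, device=device("cpu"))
assert empty_a is empty_b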
23 changes: 11 additions & 12 deletions tripy/tripy/common/exception.py
@@ -116,28 +116,27 @@ def _make_stack_info_message(stack_info: "utils.StackInfo", enable_color: bool =
exclude_file_lines = {} # Maps filenames to ranges of lines that should be ignored.
for func in EXCLUDE_FUNCTIONS:
filename, start_line, end_line = _get_function_file_and_lines(func)

exclude_file_lines[filename] = (start_line, end_line)

def should_exclude(frame):
if frame.file not in exclude_file_lines:
def should_exclude(source_info):
if source_info.code is None:
return True

# Exclude frames from some modules that are not very useful to users:
if source_info.module in utils.get_module_names_to_exclude_from_stack_info():
return True

if source_info.file not in exclude_file_lines:
return False

start_line, end_line = exclude_file_lines[frame.file]
return frame.line >= start_line and frame.line <= end_line
start_line, end_line = exclude_file_lines[source_info.file]
return source_info.line >= start_line and source_info.line <= end_line

frame_strs = []
num_frames_printed = 0

stack_info.fetch_source_code()
for index, source_info in enumerate(stack_info):
if source_info.code is None:
continue

# Exclude frames from some modules that are not very useful to users:
if source_info.module in utils.get_module_names_to_exclude_from_stack_info():
continue

if should_exclude(source_info):
continue

16 changes: 1 addition & 15 deletions tripy/tripy/common/utils.py
@@ -16,11 +16,10 @@
#

import array
import struct
from typing import Any, List, Sequence

from tripy.common.exception import raise_error
import tripy.common.datatype
from tripy.common.exception import raise_error


def is_int32(data):
Expand Down Expand Up @@ -72,16 +71,3 @@ def convert_list_to_array(values: List[Any], dtype: str) -> bytes:

def is_empty(data: Sequence) -> bool:
return isinstance(data, Sequence) and all(map(is_empty, data))


def is_shape_empty(shape: Sequence[int]) -> bool:
"""
A shape is considered empty if any of its dimensions is zero.

Args:
shape (Tuple[int, ...]): A tuple representing the shape of a tensor.

Returns:
bool: True if the shape represents an empty tensor, False otherwise.
"""
return any(dim == 0 for dim in shape)
15 changes: 8 additions & 7 deletions tripy/tripy/flat_ir/ops/constant.py
@@ -15,18 +15,18 @@
# limitations under the License.
#

import numbers
from dataclasses import dataclass
from typing import Sequence, Set, Union

import mlir_tensorrt.runtime.api as runtime
from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import stablehlo

from tripy import utils
from tripy.backend.mlir.memref import create_memref
from tripy.common import device
from tripy.flat_ir.ops.base import BaseFlatIROp

import mlir_tensorrt.runtime.api as runtime


@dataclass(repr=False)
class ConstantOp(BaseFlatIROp):
@@ -40,6 +40,7 @@ def str_skip_fields(self) -> Set[str]:

def to_mlir(self, operands):
import array

import tripy.common.datatype as datatype
from tripy.backend.mlir import utils as mlir_utils

@@ -58,11 +59,11 @@ def to_mlir(self, operands):
# so we have to represent them as ints and then cast the result
if self.outputs[0].dtype == datatype.bool:
# need to use memoryview.cast to ensure that the view will be flattened
int_memref = runtime_client.create_memref(
array.array("i", memoryview(data_memref).cast("b").tolist()),
int_memref = create_memref(
array=array.array("i", memoryview(data_memref).cast("b").tolist()),
shape=self.data.shape,
dtype=mlir_utils.convert_tripy_dtype_to_runtime_dtype(datatype.int32),
device=None,
dtype=datatype.int32,
device=device("cpu"),
)
attr = ir.DenseElementsAttr.get(
array=int_memref, type=mlir_utils.get_mlir_dtype(datatype.int32), shape=data_memref.shape
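The bool branch above exists because stablehlo's DenseElementsAttr cannot be built directly from bool data, so the values are widened to int32 and the result is cast back afterwards. A minimal, self-contained illustration of the byte-level reinterpretation used there (hypothetical data; in the real code the bytes come from data_memref):

import array

raw = bytes([1, 0, 0, 1])                  # stand-in for the bool memref's storage, one byte per value
flat = memoryview(raw).cast("b").tolist()  # flatten to Python ints: [1, 0, 0, 1]
int_data = array.array("i", flat)          # int32 values used to build the intermediate memref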
10 changes: 5 additions & 5 deletions tripy/tripy/frontend/trace/ops/storage.py
@@ -18,18 +18,18 @@
from dataclasses import dataclass
from typing import List, Sequence, Set, Union

import mlir_tensorrt.runtime.api as runtime

from tripy import utils
from tripy.backend.mlir import memref
from tripy.backend.mlir import utils as mlir_utils
from tripy.common import datatype
from tripy.frontend import utils as frontend_utils
from tripy.common import utils as common_utils
from tripy.common import device as tp_device
from tripy.common import utils as common_utils
from tripy.frontend import utils as frontend_utils
from tripy.frontend.trace.ops import utils as op_utils
from tripy.frontend.trace.ops.base import BaseTraceOp

import mlir_tensorrt.runtime.api as runtime


@dataclass(repr=False)
class Storage(BaseTraceOp):
@@ -59,7 +59,7 @@ def __init__(
# special case: empty tensor
self.dtype = utils.default(dtype, datatype.float32)
self.shape = tuple(utils.get_shape(data))
self.data = memref.create_empty_memref(shape=self.shape, dtype=self.dtype)
self.data = memref.create_memref(shape=self.shape, dtype=self.dtype)
self.device = utils.default(device, tp_device(("gpu", 0)))
self.has_memref = True
else: