
Commit 5863d9b

Add stride support
1 parent 54f9819 commit 5863d9b

File tree

16 files changed: +143 −18 lines

tripy/tests/backend/test_compiler_api.py

Lines changed: 1 addition & 2 deletions
@@ -185,8 +185,7 @@ def test_function(self):
         inp = tp.ones((2, 2), dtype=tp.float32)
         out = compiled_gelu(inp)
 
-        # TODO (#225): Replace with tp.all
-        assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(tp.relu(inp)))
+        assert tp.equal(out, tp.relu(inp))
 
     def test_module(self):
         layernorm = tp.LayerNorm(2)

tripy/tests/flat_ir/ops/test_constant.py

Lines changed: 4 additions & 1 deletion
@@ -30,7 +30,10 @@ def test_str(self):
 
         const = flat_ir.ops[-1]
         assert isinstance(const, ConstantOp)
-        assert str(const) == "out: [rank=(1), shape=((2,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=[2.0, 3.0])"
+        assert (
+            str(const)
+            == "out: [rank=(1), shape=((2,)), stride=((1,)), dtype=(float32), loc=(gpu:0)] = ConstantOp(data=[2.0, 3.0])"
+        )
 
     def test_mlir(self):
         out = tp.Tensor([2, 3], dtype=tp.int32, name="out")

tripy/tests/frontend/trace/test_trace.py

Lines changed: 8 additions & 8 deletions
@@ -95,8 +95,8 @@ def test_str(self):
             str(trace)
             == dedent(
                 """
-                a = storage(data=[0], shape=(1,), dtype=int32, device=gpu:0)
-                b = storage(data=[1], shape=(1,), dtype=int32, device=gpu:0)
+                a = storage(data=[0], shape=(1,), stride=(1,), dtype=int32, device=gpu:0)
+                b = storage(data=[1], shape=(1,), stride=(1,), dtype=int32, device=gpu:0)
                 c = a + b
                 outputs:
                     c: [rank=(1), dtype=(int32), loc=(gpu:0)]
@@ -133,8 +133,8 @@ def test_multiple_outputs(self):
             str(trace)
             == dedent(
                 """
-                a = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
-                b = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
+                a = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
+                b = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
                 c = a + b
                 d = c + c
                 outputs:
@@ -168,8 +168,8 @@ def test_all_inputs(self):
             == dedent(
                 """
                 inputs:
-                    a: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
-                    b: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
+                    a: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
+                    b: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
                 c = a + b
                 outputs:
                     c: [rank=(1), dtype=(float32), loc=(gpu:0)]
@@ -191,8 +191,8 @@ def test_const_and_input(self):
             == dedent(
                 """
                 inputs:
-                    a: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
-                b = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
+                    a: [rank=(1), shape=((1,)), stride=((1,)), dtype=(float32), loc=(gpu:0)]
+                b = storage(data=[1.0000], shape=(1,), stride=(1,), dtype=float32, device=gpu:0)
                 c = a + b
                 outputs:
                     c: [rank=(1), dtype=(float32), loc=(gpu:0)]

tripy/tests/integration/test_quantize.py

Lines changed: 1 addition & 1 deletion
@@ -118,4 +118,4 @@ def test_non_constant_scale(self):
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)
 
-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert tp.equal(quantized, tp.ones((4, 4), dtype=tp.int8))

tripy/tripy/backend/mlir/executor.py

Lines changed: 8 additions & 2 deletions
@@ -114,7 +114,9 @@ def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
             is_static_shape = all(dim >= 0 for dim in memref.shape)
             if is_static_shape:
                 outputs_tensor_info.append(
-                    TensorInfo(len(memref.shape), tuple(memref.shape), dtype, device(device_type))
+                    TensorInfo(
+                        len(memref.shape), tuple(memref.shape), tuple(memref.strides), dtype, device(device_type)
+                    )
                 )
             else:
                 runtime_shape = [
@@ -124,6 +126,7 @@ def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
                 TensorInfo(
                     len(runtime_shape),
                     tuple(runtime_shape),
+                    tuple(memref.strides),
                     dtype,
                     device(device_type),
                 )
@@ -174,7 +177,10 @@ def execute(self, output_devices=List[device], inputs: List["Tensor"] = []) -> L
         # Allocate output memory and store buffer pointers.
         outputs = [
             create_empty_memref(
-                shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
+                shape=info.shape,
+                dtype=info.dtype,
+                device=info.device,
+                stream=self.stream._active_cuda_stream,
             )
             for info in out_tensor_info
         ]
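For context on what `memref.strides` encodes: a stride tuple maps a multi-dimensional index to a flat element offset, which is what lets `TensorInfo` describe non-contiguous buffers. A minimal sketch in plain Python (element strides, not bytes; an illustration, not the runtime's implementation):

    def flat_offset(index, strides):
        # Element offset of a multi-dimensional index into a strided buffer.
        return sum(i * s for i, s in zip(index, strides))

    # A contiguous 2x2 row-major buffer has strides (2, 1): adjacent rows are 2 elements apart.
    assert flat_offset((1, 0), (2, 1)) == 2
    assert flat_offset((0, 1), (2, 1)) == 1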

tripy/tripy/backend/utils.py

Lines changed: 3 additions & 1 deletion
@@ -27,6 +27,7 @@
 class TensorInfo:
     rank: int
     shape: Sequence[int]
+    stride: Sequence[int]
     dtype: "tripy.dtype"
     device: "tripy.device"
 
@@ -36,11 +37,12 @@ def encode(tensor_info: TensorInfo) -> Dict[str, Any]:
     return {
         "rank": tensor_info.rank,
         "shape": tensor_info.shape,
+        "stride": tensor_info.stride,
         "dtype": tensor_info.dtype,
         "device": tensor_info.device,
     }
 
 
 @Decoder.register(TensorInfo)
 def decode(dct: Dict[str, Any]) -> TensorInfo:
-    return TensorInfo(dct["rank"], dct["shape"], dct["dtype"], dct["device"])
+    return TensorInfo(dct["rank"], dct["shape"], dct["stride"], dct["dtype"], dct["device"])
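Because `TensorInfo` is constructed positionally both here in `decode` and in the executor above, the new `stride` field must keep its slot between `shape` and `dtype`. A round-trip sketch (assuming `TensorInfo` is a dataclass with value equality; the dtype and device strings are stand-ins for real tripy objects):

    info = TensorInfo(rank=2, shape=(2, 3), stride=(3, 1), dtype="float32", device="gpu:0")
    encoded = encode(info)
    # encoded == {"rank": 2, "shape": (2, 3), "stride": (3, 1), "dtype": "float32", "device": "gpu:0"}
    assert decode(encoded) == info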

tripy/tripy/flat_ir/tensor.py

Lines changed: 4 additions & 0 deletions
@@ -38,6 +38,7 @@ class FlatIRTensor:
     rank: int
     producer: "BaseFlatIROp" = None
     shape: Optional[List[int]] = None
+    stride: Optional[List[int]] = None
     reason_details: Optional[List[Any]] = None
     """
     Describes why this tensor was created.
@@ -69,6 +70,7 @@ def build(
         rank: int,
         reason_details: List[Any],
         shape: List[int] = None,
+        stride: List[int] = None,
     ) -> "FlatIRTensor":
         return FlatIRTensor(
             name=None,
@@ -80,6 +82,7 @@ def build(
             rank=rank,
             producer=None,
             shape=shape,
+            stride=utils.get_stride(shape, stride),
             reason_details=reason_details,
             reason_context=copy.copy(_BUILD_CONTEXT),
         )
@@ -88,6 +91,7 @@ def __str__(self) -> str:
         return (
             f"{self.name}: [rank=({self.rank}), "
             + (f"shape=({self.shape}), " if self.shape is not None else "")
+            + (f"stride=({self.stride}), " if self.stride is not None else "")
             + (f"dtype=({self.dtype.name}), " if self.dtype is not None else "")
             + f"loc=({self.device})]"
         )
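`utils.get_stride` is called here but not included in this diff. Judging from the call site (a shape plus an optional explicit stride), a plausible reading is a fallback to contiguous row-major strides; the helper below is hypothetical, not tripy's actual implementation:

    from typing import List, Optional

    def get_stride(shape: Optional[List[int]], stride: Optional[List[int]]) -> Optional[List[int]]:
        # Hypothetical: keep an explicit stride; otherwise derive a contiguous
        # row-major stride from the shape, if one is known.
        if stride is not None or shape is None:
            return stride
        result, running = [], 1
        for extent in reversed(shape):
            result.append(running)
            running *= extent
        return list(reversed(result))

    assert get_stride([2, 2], None) == [2, 1]
    assert get_stride([1], None) == [1]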

tripy/tripy/frontend/ops/equal.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from tripy import export, constraints
+from tripy.common.exception import raise_error
+
+
+@export.public_api(document_under="operations/functions")
+@constraints.dtype_info(
+    dtype_variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]},
+    dtype_constraints={"a": "T1", "b": "T1"},
+)
+def equal(a: "tripy.Tensor", b: "tripy.Tensor") -> bool:
+    r"""
+    Returns ``True`` if all elements in ``a`` and ``b`` are exactly equal.
+
+    This function performs an element-wise equality comparison between tensors ``a`` and ``b``,
+    and returns ``True`` only if all elements are exactly equal.
+
+    Args:
+        a: First tensor to compare.
+        b: Second tensor to compare.
+
+    Returns:
+        ``True`` if all elements in both tensors are exactly equal, ``False`` otherwise.
+
+    .. code-block:: python
+        :linenos:
+        :caption: Equal Tensors
+
+        # doc: print-locals out
+        out = tp.equal(tp.Tensor([1, 2, 3]), tp.Tensor([1, 2, 3]))
+        assert out
+
+    .. code-block:: python
+        :linenos:
+        :caption: Unequal Tensors
+
+        # doc: print-locals out
+        out = tp.equal(tp.Tensor([1, 2, 3]), tp.Tensor([1, 2, 4]))
+        assert not out
+    """
+    from tripy.frontend.trace.ops.reduce import all
+
+    return bool(all(a == b))
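The test changes earlier in this commit (test_compiler_api.py, test_quantize.py) migrate to this helper. Since the body is just `bool(all(a == b))`, the old and new assertion styles are equivalent; converting to a Python `bool` presumably forces evaluation of the lazy graph:

    a = tp.ones((4, 4), dtype=tp.int8)
    b = tp.ones((4, 4), dtype=tp.int8)

    assert bool(tp.all(a == b))  # pre-commit style, as in test_quantize.py
    assert tp.equal(a, b)        # post-commit equivalent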

tripy/tripy/frontend/tensor.py

Lines changed: 6 additions & 1 deletion
@@ -125,6 +125,7 @@ def __init__(
                 attr is not None
                 for attr in [
                     self.trace_tensor.shape,
+                    self.trace_tensor.stride,
                     self.trace_tensor.dtype,
                     self.trace_tensor.device,
                     self.trace_tensor.producer,
@@ -171,6 +172,10 @@ def dtype(self):
     def rank(self):
         return self.trace_tensor.rank
 
+    @property
+    def stride(self):
+        return self.trace_tensor.stride
+
     def eval(self) -> runtime.MemRefValue:
         from tripy.backend.mlir.compiler import Compiler
         from tripy.backend.mlir.executor import Executor
@@ -228,7 +233,7 @@ def __repr__(self) -> str:
         return (
             f"tensor({sep}"
             f"{indent(arr_str, prefix=indentation)}, {sep}"
-            f"{indent(f'dtype={self.dtype}, loc={self.device}, shape={arr.shape}', prefix=indentation)}"
+            f"{indent(f'dtype={self.dtype}, loc={self.device}, shape={arr.shape}, stride={arr.strides}', prefix=indentation)}"
             f")"
         )
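With the new property, stride is queryable from the frontend alongside rank. A small usage sketch; the (2, 1) value assumes a freshly created tensor is contiguous with element strides, consistent with the stride=(1,) strings in the trace tests above, and the property may return None before the trace tensor's stride has been populated:

    x = tp.ones((2, 2), dtype=tp.float32)
    print(x.rank)    # 2
    print(x.stride)  # (2, 1) expected for a contiguous row-major tensor (assumption)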

tripy/tripy/frontend/trace/ops/binary_elementwise.py

Lines changed: 4 additions & 0 deletions
@@ -140,6 +140,7 @@ def to_flat_ir(self, inputs, outputs):
             # First apply DivideOp
             divide_out = FlatIRTensor.build(
                 shape=outputs[0].shape,
+                stride=outputs[0].stride,
                 rank=outputs[0].rank,
                 dtype=outputs[0].dtype,
                 device=outputs[0].device,
@@ -152,6 +153,7 @@ def to_flat_ir(self, inputs, outputs):
             # Step 1: Perform DivideOp
             divide_out = FlatIRTensor.build(
                 shape=outputs[0].shape,
+                stride=outputs[0].stride,
                 rank=outputs[0].rank,
                 dtype=outputs[0].dtype,
                 device=outputs[0].device,
@@ -162,6 +164,7 @@ def to_flat_ir(self, inputs, outputs):
             # Step 2: Apply FloorOp
             floor_out = FlatIRTensor.build(
                 shape=outputs[0].shape,
+                stride=outputs[0].stride,
                 rank=outputs[0].rank,
                 dtype=outputs[0].dtype,
                 device=outputs[0].device,
@@ -172,6 +175,7 @@ def to_flat_ir(self, inputs, outputs):
             # Step 3: Multiply divisor with floored division result (FloorOp output)
             multiply_out = FlatIRTensor.build(
                 shape=outputs[0].shape,
+                stride=outputs[0].stride,
                 rank=outputs[0].rank,
                 dtype=outputs[0].dtype,
                 device=outputs[0].device,
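Threading `stride=outputs[0].stride` through every intermediate of the divide/floor/multiply decomposition keeps stride metadata consistent at each step of the lowering, so `FlatIRTensor.__str__` reports the same stride for intermediates as for the final output.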
