Skip to content

Commit 96b804b

Browse files
authored
Cast fp16 to fp32 in tolist, remove FP16MemoryView (#254)
1 parent e73ba59 commit 96b804b

File tree

9 files changed

+62
-185
lines changed

9 files changed

+62
-185
lines changed

tripy/tests/common/test_utils.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from tripy.common.exception import TripyException
3030
from tripy.common.utils import (
3131
convert_list_to_array,
32-
Float16MemoryView,
3332
get_element_type,
3433
)
3534

@@ -50,43 +49,11 @@ def test_get_element_type():
5049
"values, dtype, expected",
5150
[
5251
([True, False, True], tripy.common.datatype.bool, b"\x01\x00\x01"),
53-
([1, 2, 3], tripy.common.datatype.int8, b"\x01\x02\x03"),
5452
([100000, 200000], tripy.common.datatype.int32, b"\xa0\x86\x01\x00@\x0d\x03\x00"),
5553
([1, 2], tripy.common.datatype.int64, b"\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00"),
56-
([1.0, 2.0], tripy.common.datatype.float16, b"\x00<\x00@"),
5754
([1.0, 2.0], tripy.common.datatype.float32, b"\x00\x00\x80?\x00\x00\x00@"),
5855
],
5956
)
60-
def convert_list_to_array(values, dtype, expected):
57+
def test_convert_list_to_array(values, dtype, expected):
6158
result = convert_list_to_array(values, dtype)
62-
assert result == expected
63-
64-
65-
def test_float16_memoryview():
66-
memview = Float16MemoryView(bytearray(struct.pack("5e", 1.5, 2.5, 3.5, 4.5, 5.5)))
67-
assert memview.itemsize == 2
68-
assert memview.format == "e"
69-
len(memview) == 5
70-
assert memview[0] == pytest.approx(1.5)
71-
assert memview[2] == pytest.approx(3.5)
72-
assert memview[1:4] == pytest.approx([2.5, 3.5, 4.5])
73-
expected = [1.5, 2.5, 3.5, 4.5, 5.5]
74-
assert memview.tolist() == pytest.approx(expected)
75-
76-
# Largest representable value in float16
77-
large_value = 65504.0
78-
buffer = struct.pack("e", large_value)
79-
mv = Float16MemoryView(buffer)
80-
assert mv[0] == pytest.approx(large_value)
81-
82-
# Smallest positive normal number for float16
83-
small_value = 6.1035e-5
84-
buffer = struct.pack("e", small_value)
85-
mv = Float16MemoryView(buffer)
86-
assert mv[0] == pytest.approx(small_value, rel=1e-3)
87-
88-
# Negative value
89-
negative_value = -42.5
90-
buffer = struct.pack("e", negative_value)
91-
mv = Float16MemoryView(buffer)
92-
assert mv[0] == pytest.approx(negative_value)
59+
assert result.tobytes() == expected

tripy/tests/frontend/trace/test_trace.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def test_multiple_outputs(self):
133133
str(trace)
134134
== dedent(
135135
"""
136-
a = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
137-
b = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
136+
a = storage(shape=(1,), dtype=float32, device=gpu:0)
137+
b = storage(shape=(1,), dtype=float32, device=gpu:0)
138138
c = a + b
139139
d = c + c
140140
outputs:
@@ -192,7 +192,7 @@ def test_const_and_input(self):
192192
"""
193193
inputs:
194194
a: [rank=(1), shape=((1,)), dtype=(float32), loc=(gpu:0)]
195-
b = storage(data=[1.0000], shape=(1,), dtype=float32, device=gpu:0)
195+
b = storage(shape=(1,), dtype=float32, device=gpu:0)
196196
c = a + b
197197
outputs:
198198
c: [rank=(1), dtype=(float32), loc=(gpu:0)]

tripy/tripy/backend/mlir/memref.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -77,50 +77,4 @@ def tolist(memref):
7777
memref_value = memref
7878
if memref.address_space == runtime.PointerType.device:
7979
memref_value = mlir_utils.MLIRRuntimeClient().copy_to_host(device_memref=memref)
80-
try:
81-
return memoryview(memref_value).tolist()
82-
except NotImplementedError as e:
83-
if "memoryview: format e not supported" in str(e):
84-
assert memref_value.dtype == runtime.ScalarTypeCode.f16
85-
return common_utils.Float16MemoryView(bytearray(memref_value)).tolist()
86-
raise
87-
88-
89-
def pretty_print_memref(memref, threshold=1000, linewidth=10, edgeitems=3):
90-
"""
91-
Returns a pretty-print string of memref values.
92-
"""
93-
memref_shape = memref.shape
94-
95-
def _data_str(data, summarize, linewidth, edgeitems, indent=0):
96-
if isinstance(data, (float, int)):
97-
return str(data)
98-
99-
if len(data) == 0 or isinstance(data[0], (float, int)):
100-
if summarize and len(data) > 2 * edgeitems:
101-
data_lines = [data[:edgeitems] + [" ..."] + data[-edgeitems:]]
102-
else:
103-
data_lines = [data[i : i + linewidth] for i in range(0, len(data), linewidth)]
104-
lines = [", ".join([f"{e:.4f}" if isinstance(e, float) else str(e) for e in line]) for line in data_lines]
105-
return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"
106-
107-
if summarize and len(data) > 2 * edgeitems:
108-
slices = (
109-
[_data_str(data[i], summarize, linewidth, edgeitems, indent + 1) for i in range(0, edgeitems)]
110-
+ ["..."]
111-
+ [
112-
_data_str(data[i], summarize, linewidth, edgeitems, indent + 1)
113-
for i in range(len(data) - edgeitems, len(data))
114-
]
115-
)
116-
else:
117-
slices = [_data_str(data[i], summarize, linewidth, edgeitems, indent + 1) for i in range(0, len(data))]
118-
119-
tensor_str = ("," + "\n" * (max(len(memref_shape) - indent - 1, 1)) + " " * (indent + 1)).join(slices)
120-
return "[" + tensor_str + "]"
121-
122-
numel = 1
123-
for d in memref_shape:
124-
numel *= d
125-
summarize = numel > threshold
126-
return _data_str(tolist(memref), summarize, linewidth, edgeitems)
80+
return memoryview(memref_value).tolist()

tripy/tripy/common/utils.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -85,70 +85,3 @@ def is_shape_empty(shape: Sequence[int]) -> bool:
8585
bool: True if the shape represents an empty tensor, False otherwise.
8686
"""
8787
return any(dim == 0 for dim in shape)
88-
89-
90-
class Float16MemoryView:
91-
"""
92-
A custom memory view class for handling float16 data.
93-
"""
94-
95-
def __init__(self, buffer):
96-
"""
97-
Initialize the Float16MemoryView with a buffer.
98-
99-
Args:
100-
buffer (buffer): The buffer containing float16 data.
101-
"""
102-
self.buffer = buffer
103-
self.itemsize = 2 # size of float16 in bytes
104-
self.format = "e" # format character for float16
105-
106-
def __getitem__(self, index):
107-
"""
108-
Get an item or a slice from the buffer.
109-
110-
Args:
111-
index (int or slice): The index or slice to retrieve.
112-
113-
Returns:
114-
float or list of floats: The float16 value(s) at the specified index or slice.
115-
"""
116-
if isinstance(index, slice):
117-
return [
118-
self._unpack(self.buffer[i * self.itemsize : i * self.itemsize + self.itemsize])
119-
for i in range(*index.indices(len(self)))
120-
]
121-
else:
122-
start = index * self.itemsize
123-
end = start + self.itemsize
124-
return self._unpack(self.buffer[start:end])
125-
126-
def _unpack(self, data):
127-
"""
128-
Unpack a float16 value from bytes.
129-
130-
Args:
131-
data (bytes): The bytes to unpack.
132-
133-
Returns:
134-
float: The unpacked float16 value.
135-
"""
136-
return struct.unpack(self.format, data)[0]
137-
138-
def __len__(self):
139-
"""
140-
Get the number of float16 values in the buffer.
141-
142-
Returns:
143-
int: The number of float16 values.
144-
"""
145-
return len(self.buffer) // self.itemsize
146-
147-
def tolist(self):
148-
"""
149-
Convert the buffer to a list of float16 values.
150-
151-
Returns:
152-
list: The list of float16 values.
153-
"""
154-
return [self[i] for i in range(len(self))]

tripy/tripy/flat_ir/flat_ir.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -352,23 +352,16 @@ def register_tensor(self, tensor: "FlatIRTensor") -> "FlatIRTensor":
352352
return tensor
353353

354354
def _get_constant_key(self, op):
355-
from mlir_tensorrt.runtime._mlir_libs._api import MemRefValue
356-
from tripy.utils.utils import list_to_tuple, volume
355+
from mlir_tensorrt.runtime.api import MemRefValue
356+
from tripy.utils.utils import list_to_tuple
357357

358358
if isinstance(op.data, MemRefValue):
359-
from tripy.backend.mlir.memref import tolist
360-
361-
VOLUME_THRESHOLD_FOR_MEMREF = 50
362-
if volume(op.data.shape) < VOLUME_THRESHOLD_FOR_MEMREF:
363-
l = tolist(op.data)
364-
else:
365-
l = [op.data.ptr]
366-
data = list_to_tuple(l if isinstance(l, List) else [l])
367-
elif isinstance(op.data, int) or isinstance(op.data, float) or isinstance(op.data, bool):
368-
data = list_to_tuple(
369-
op.data,
370-
)
359+
# use data pointer as key when data is a memref,
360+
# usually come from users, no need to deduplicate
361+
data = (op.data.ptr,)
371362
else:
363+
# small constants can be deduplicated
364+
# when data is a list
372365
data = list_to_tuple(op.data)
373366

374367
# Create a unique key for the constant based on its data and type

tripy/tripy/flat_ir/ops/constant.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ class ConstantOp(BaseFlatIROp):
3434
data: Union[runtime.MemRefValue, Sequence]
3535

3636
def str_skip_fields(self) -> Set[str]:
37-
data_shape = self.data.shape if isinstance(self.data, runtime.MemRefValue) else self.outputs[0].shape
38-
if utils.should_omit_constant_in_str(data_shape):
37+
if not isinstance(self.data, Sequence) or utils.should_omit_constant_in_str(self.outputs[0].shape):
3938
return {"data"}
4039
return set()
4140

tripy/tripy/frontend/tensor.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ def eval(self) -> runtime.MemRefValue:
201201
def tolist(self):
202202
data_memref = self.eval()
203203
if self.dtype not in (
204-
datatype.float16,
205204
datatype.float32,
206205
datatype.int8,
207206
datatype.int32,
@@ -217,18 +216,20 @@ def __iter__(self):
217216
raise TypeError("Iterating over tensors is not supported")
218217

219218
def __repr__(self) -> str:
220-
# The Evaluation required before accessing self.trace_tensor.producer attributes.
221-
arr = self.eval()
222-
arr_str = memref.pretty_print_memref(arr)
219+
from tripy.frontend.utils import pretty_print
220+
221+
data_list = self.tolist()
222+
data_shape = self.trace_tensor.producer.shape
223+
arr_str = pretty_print(data_list, data_shape)
223224
indentation = ""
224225
sep = ""
225-
if len(arr.shape) > 1 and any(dim > 1 for dim in arr.shape):
226+
if len(data_shape) > 1 and any(dim > 1 for dim in data_shape):
226227
indentation = " " * 4
227228
sep = "\n"
228229
return (
229230
f"tensor({sep}"
230231
f"{indent(arr_str, prefix=indentation)}, {sep}"
231-
f"{indent(f'dtype={self.dtype}, loc={self.device}, shape={arr.shape}', prefix=indentation)}"
232+
f"{indent(f'dtype={self.dtype}, loc={self.device}, shape={data_shape}', prefix=indentation)}"
232233
f")"
233234
)
234235

@@ -240,7 +241,7 @@ def __dlpack_device__(self):
240241
return self.eval().__dlpack_device__()
241242

242243
def __bool__(self):
243-
data = memref.tolist(self.eval())
244+
data = self.tolist()
244245
if any(dim != 1 for dim in self.trace_tensor.producer.shape):
245246
raise_error(
246247
"Boolean value of a Tensor with more than one value is ambiguous",

tripy/tripy/frontend/trace/ops/storage.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,11 @@ def __init__(
7373
infer_shape_output_idxs = op_utils.ShapeOutputIdxPolicies.never_return_shape
7474

7575
def str_skip_fields(self) -> Set[str]:
76-
if utils.should_omit_constant_in_str(self.shape):
76+
# skip data if i) it is a MemRefValue or ii) its volume exceeds threshold
77+
if not isinstance(self.data, Sequence) or utils.should_omit_constant_in_str(self.shape):
7778
return {"data"}
7879
return set()
7980

80-
def __str__(self) -> str:
81-
skip_fields = self.str_skip_fields()
82-
args = []
83-
if "data" not in skip_fields:
84-
data_str = memref.pretty_print_memref(self.data) if self.has_memref else str(self.data)
85-
args.append(f"data={data_str}")
86-
args.extend([f"{field}={getattr(self, field)}" for field in ("shape", "dtype", "device")])
87-
return f"{self.outputs[0].name} = storage({', '.join([inp.name for inp in self.inputs] + args)})"
88-
8981
def __eq__(self, other) -> bool:
9082
return self.data == other.data if isinstance(other, Storage) else False
9183

tripy/tripy/frontend/utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,3 +633,41 @@ def process_dim(dim: int) -> int:
633633
return func(*args, **kwargs)
634634

635635
return wrapper
636+
637+
638+
def pretty_print(data_list, shape, threshold=1000, linewidth=10, edgeitems=3):
639+
"""
640+
Returns a pretty-print string of list format data.
641+
"""
642+
def _data_str(data, summarize, linewidth, edgeitems, indent=0):
643+
if isinstance(data, (float, int)):
644+
return str(data)
645+
646+
if len(data) == 0 or isinstance(data[0], (float, int)):
647+
if summarize and len(data) > 2 * edgeitems:
648+
data_lines = [data[:edgeitems] + [" ..."] + data[-edgeitems:]]
649+
else:
650+
data_lines = [data[i : i + linewidth] for i in range(0, len(data), linewidth)]
651+
lines = [", ".join([f"{e:.4f}" if isinstance(e, float) else str(e) for e in line]) for line in data_lines]
652+
return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"
653+
654+
if summarize and len(data) > 2 * edgeitems:
655+
slices = (
656+
[_data_str(data[i], summarize, linewidth, edgeitems, indent + 1) for i in range(0, edgeitems)]
657+
+ ["..."]
658+
+ [
659+
_data_str(data[i], summarize, linewidth, edgeitems, indent + 1)
660+
for i in range(len(data) - edgeitems, len(data))
661+
]
662+
)
663+
else:
664+
slices = [_data_str(data[i], summarize, linewidth, edgeitems, indent + 1) for i in range(0, len(data))]
665+
666+
tensor_str = ("," + "\n" * (max(len(shape) - indent - 1, 1)) + " " * (indent + 1)).join(slices)
667+
return "[" + tensor_str + "]"
668+
669+
numel = 1
670+
for d in shape:
671+
numel *= d
672+
summarize = numel > threshold
673+
return _data_str(data_list, summarize, linewidth, edgeitems)

0 commit comments

Comments
 (0)