Skip to content

Commit 87ecff6

Browse files
committed
improve tests, review fixes
1 parent 48b5d7d commit 87ecff6

File tree

3 files changed

+105
-86
lines changed

3 files changed

+105
-86
lines changed

tripy/nvtripy/backend/api/compile.py

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def add(a, b):
156156
trace_input_map = {}
157157
input_names = set()
158158
input_infos = {}
159+
trace_inputs = [] # flattened list of trace input tensors in argument order
159160

160161
# Set up names for the weights in the module to make the trace easier to read.
161162
if isinstance(func, Module):
@@ -184,6 +185,7 @@ def process_arg_input_info(name, arg):
184185

185186
trace_input_map[name] = tensor
186187
input_names.add(name)
188+
trace_inputs.append(tensor.trace_tensor)
187189

188190
return tensor
189191

@@ -199,35 +201,44 @@ def process_arg_input_info(name, arg):
199201

200202
trace_input_map[name] = tensor
201203
input_names.add(name)
204+
trace_inputs.append(tensor.trace_tensor)
202205

203206
return tensor
204207

205208
return arg
206209

207-
def process_arg(name, arg):
210+
def process_arg_and_flag(name, arg):
208211
# Handle individual InputInfo or DimensionInputInfo objects
209212
if isinstance(arg, (InputInfo, DimensionInputInfo)):
210-
return process_arg_input_info(name, arg)
213+
return process_arg_input_info(name, arg), True
211214

212215
# Handle containers of InputInfo objects
213216
if isinstance(arg, dict):
214-
if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg.values()):
215-
input_names.add(name)
216-
result = {}
217-
for key, value in arg.items():
218-
nested_name = f"{name}.{key}"
219-
result[key] = process_arg(nested_name, value)
220-
return result
217+
result = {}
218+
has_input = False
219+
for key, value in arg.items():
220+
nested_name = f"{name}.{key}"
221+
processed_child, child_has_input = process_arg_and_flag(nested_name, value)
222+
result[key] = processed_child
223+
has_input = has_input or child_has_input
224+
return result, has_input
221225
elif isinstance(arg, (list, tuple)):
222-
if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg):
223-
input_names.add(name)
224-
result = []
225-
for idx, value in enumerate(arg):
226-
nested_name = f"{name}[{idx}]"
227-
result.append(process_arg(nested_name, value))
228-
return type(arg)(result)
226+
result_list = []
227+
has_input = False
228+
for idx, value in enumerate(arg):
229+
nested_name = f"{name}[{idx}]"
230+
processed_child, child_has_input = process_arg_and_flag(nested_name, value)
231+
result_list.append(processed_child)
232+
has_input = has_input or child_has_input
233+
return type(arg)(result_list), has_input # preserve sequence type
229234

230-
return arg
235+
return arg, False
236+
237+
def process_arg(name, arg):
238+
processed, has_input = process_arg_and_flag(name, arg)
239+
if has_input:
240+
input_names.add(name)
241+
return processed
231242

232243
compiled_arg_names = []
233244

@@ -283,25 +294,7 @@ def process_arg(name, arg):
283294
[f"Return value {index} was not a tensor: {repr(trace_out)}"],
284295
)
285296

286-
# Order of trace inputs also needs to match that of the compiled_arg_names
287-
# For containers, we need to collect all individual trace tensors
288-
def collect_trace_tensors(name):
289-
"""Collect trace tensors for a name, flattening containers."""
290-
if name in trace_input_map:
291-
# Regular InputInfo or DimensionInputInfo
292-
return [trace_input_map[name].trace_tensor]
293-
else:
294-
# Collect all nested trace tensors inside the container
295-
nested_tensors = []
296-
for nested_name in sorted(trace_input_map.keys()):
297-
if nested_name.startswith(f"{name}.") or nested_name.startswith(f"{name}["):
298-
nested_tensors.append(trace_input_map[nested_name].trace_tensor)
299-
return nested_tensors
300-
301-
# Flatten all trace tensors from containers and individual inputs
302-
trace_inputs = []
303-
for name in compiled_arg_names:
304-
trace_inputs.extend(collect_trace_tensors(name))
297+
# We collected flattened trace inputs during traversal
305298
trace = Trace(
306299
[tensor.trace_tensor for tensor in trace_outputs],
307300
trace_inputs,

tripy/nvtripy/backend/api/executable.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -195,40 +195,44 @@ def add(a, b):
195195
],
196196
)
197197

198-
# Recursively extract inputs from containers to get individual tensors for validation and execution
198+
# Recursively build a name->tensor map
199199
def extract_inputs(tensors, input_info_names):
200-
def extract_recursive(tensor, name_prefix):
201-
if isinstance(tensor, dict):
202-
result = []
203-
for key in sorted(tensor.keys()):
200+
name_to_tensor = {}
201+
202+
def extract_recursive(value, name_prefix):
203+
if name_prefix in input_info_names:
204+
name_to_tensor[name_prefix] = value
205+
return
206+
if isinstance(value, dict):
207+
for key, item in value.items():
204208
nested_name = f"{name_prefix}.{key}"
205-
if nested_name in input_info_names:
206-
result.append(tensor[key])
207-
else:
208-
result.extend(extract_recursive(tensor[key], nested_name))
209-
return result
210-
elif isinstance(tensor, (list, tuple)):
211-
result = []
212-
for idx, value in enumerate(tensor):
209+
extract_recursive(item, nested_name)
210+
elif isinstance(value, (list, tuple)):
211+
for idx, item in enumerate(value):
213212
nested_name = f"{name_prefix}[{idx}]"
214-
if nested_name in input_info_names:
215-
result.append(value)
216-
else:
217-
result.extend(extract_recursive(value, nested_name))
218-
return result
219-
else: # Regular tensor
220-
if name_prefix in input_info_names:
221-
return [tensor]
222-
else:
223-
return []
224-
225-
flattened = []
213+
extract_recursive(item, nested_name)
214+
else:
215+
print(f"Leaf tensor: {name_prefix}: {value}")
216+
return
217+
226218
for name_idx, tensor in enumerate(tensors):
227219
arg_name = self._arg_names[name_idx]
228-
flattened.extend(extract_recursive(tensor, arg_name))
229-
return flattened
220+
extract_recursive(tensor, arg_name)
221+
222+
return name_to_tensor
230223

231-
flattened_tensors = extract_inputs(input_tensors, set(self.input_infos.keys()))
224+
input_info_names = list(self.input_infos.keys())
225+
name_to_tensor = extract_inputs(input_tensors, set(input_info_names))
226+
try:
227+
flattened_tensors = [name_to_tensor[name] for name in input_info_names]
228+
except KeyError as missing:
229+
raise_error(
230+
f"Missing runtime tensor for input `{missing.args[0]}`.",
231+
[
232+
"Ensure your provided containers include tensors for all compiled inputs.",
233+
f"Expected inputs: {input_info_names}",
234+
],
235+
)
232236
expected_devices = ["gpu" if isinstance(info, InputInfo) else "cpu" for info in self.input_infos.values()]
233237

234238
# Validate flattened tensors against input_infos

tripy/tests/backend/api/test_compile.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -242,44 +242,55 @@ def test_dimension_input(self):
242242
expected = (inp_cp + inp_cp).reshape((-1, reshape_dim))
243243
assert cp.array_equal(cp.from_dlpack(out), expected)
244244

245-
def test_compile_dict_input_info(self):
246-
"""Test compilation with dictionary of InputInfo objects."""
247-
245+
def test_compile_nested_dict_input_info(self):
248246
def func(data_dict):
249-
return data_dict["a"] + data_dict["b"]
247+
return data_dict["a"]["inner"] + data_dict["b"]["list"][0] + data_dict["b"]["list"][1]
250248

251249
dict_input = {
252-
"a": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
253-
"b": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
250+
"a": {
251+
"inner": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
252+
},
253+
"b": {
254+
"list": [
255+
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
256+
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
257+
],
258+
},
254259
}
255260
compiled_func = tp.compile(func, args=[dict_input])
256261

257-
test_dict = {"a": tp.ones((2, 3), dtype=tp.float32).eval(), "b": (tp.ones((2, 3), dtype=tp.float32) * 2).eval()}
262+
test_dict = {
263+
"a": {"inner": tp.ones((2, 3), dtype=tp.float32).eval()},
264+
"b": {
265+
"list": [
266+
(tp.ones((2, 3), dtype=tp.float32) * 2).eval(),
267+
(tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
268+
]
269+
},
270+
}
258271
result = compiled_func(test_dict)
259-
expected = test_dict["a"] + test_dict["b"]
272+
expected = test_dict["a"]["inner"] + test_dict["b"]["list"][0] + test_dict["b"]["list"][1]
260273
assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))
261274

262-
def test_compile_nested_list_input_info(self):
263-
"""Test compilation with nested list containers."""
264-
275+
def test_compile_nested_sequence_input_info(self):
265276
def func(data_list):
266277
return data_list[0] + data_list[1][0] + data_list[1][1]
267278

268279
list_input = [
269280
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
270-
[ # Nested list
281+
[
271282
tp.InputInfo(shape=(2, 3), dtype=tp.float32),
272-
tp.ones((2, 3), dtype=tp.float32) * 2, # Constant in nested list
283+
tp.ones((2, 3), dtype=tp.float32) * 2,
273284
],
274285
]
275286
compiled_func = tp.compile(func, args=[list_input])
276287

277288
test_list = [
278289
tp.ones((2, 3), dtype=tp.float32).eval(),
279-
[ # Nested list in test data
290+
(
280291
(tp.ones((2, 3), dtype=tp.float32) * 3).eval(),
281-
tp.ones((2, 3), dtype=tp.float32) * 2, # Should match baked constant
282-
],
292+
tp.ones((2, 3), dtype=tp.float32) * 2,
293+
),
283294
]
284295
result = compiled_func(test_list)
285296
expected = test_list[0] + test_list[1][0] + test_list[1][1]
@@ -288,24 +299,35 @@ def func(data_list):
288299
def test_compile_mixed_containers_and_constants(self):
289300
"""Test compilation with comprehensive mix: regular InputInfo, dict container, list container, and standalone constant."""
290301

291-
def func(regular_input, data_dict, data_list, constant_value):
292-
return regular_input + data_dict["x"] + data_dict["y"] + data_list[0] + data_list[1] + constant_value
302+
def func(regular_input, data_dict, data_list, const_in_dict, const):
303+
return (
304+
regular_input
305+
+ data_dict["x"]
306+
+ data_dict["y"]
307+
+ data_list[0]
308+
+ data_list[1]
309+
+ const_in_dict["z"]
310+
+ const
311+
)
293312

294313
regular_input = tp.InputInfo(shape=(2, 3), dtype=tp.float32)
295314
dict_input = {
296315
"x": tp.InputInfo(shape=(2, 3), dtype=tp.float32),
297-
"y": tp.zeros((2, 3), dtype=tp.float32), # Constant in dict
316+
"y": tp.zeros((2, 3), dtype=tp.float32),
298317
}
299318
list_input = [tp.InputInfo(shape=(2, 3), dtype=tp.float32), tp.ones((2, 3), dtype=tp.float32) * 3]
300-
constant_value = tp.ones((2, 3), dtype=tp.float32) * 5
319+
const_in_dict = {"z": tp.ones((2, 3), dtype=tp.float32) * 5}
320+
const = tp.ones((2, 3), dtype=tp.float32) * 6
301321

302-
compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, constant_value])
322+
compiled_func = tp.compile(func, args=[regular_input, dict_input, list_input, const_in_dict, const])
303323

304324
# Only InputInfo arguments should be in function signature
305325
test_regular = tp.ones((2, 3), dtype=tp.float32).eval()
306326
test_dict = {"x": (tp.ones((2, 3), dtype=tp.float32) * 2).eval(), "y": tp.zeros((2, 3), dtype=tp.float32)}
307327
test_list = [(tp.ones((2, 3), dtype=tp.float32) * 4).eval(), tp.ones((2, 3), dtype=tp.float32) * 3]
308328

309329
result = compiled_func(test_regular, test_dict, test_list)
310-
expected = test_regular + test_dict["x"] + test_dict["y"] + test_list[0] + test_list[1] + constant_value
330+
expected = (
331+
test_regular + test_dict["x"] + test_dict["y"] + test_list[0] + test_list[1] + const_in_dict["z"] + const
332+
)
311333
assert cp.array_equal(cp.from_dlpack(result), cp.from_dlpack(expected))

0 commit comments

Comments
 (0)