Commit 2eee4cf
Replaces tp.Compiler with tp.compile

Removes the `tp.Compiler` class and replaces it with a standalone function, `tp.compile`. Also splits up the `backend/api` tests to mirror the structure of the code.
1 parent 928b541 commit 2eee4cf
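
In short, call sites move from the removed two-step `Compiler` flow to a single standalone call. A minimal before/after sketch distilled from the diffs below (the `add` function and the `InputInfo` shapes are illustrative):

```py
import tripy as tp


def add(a, b):
    return a + b


a_info = tp.InputInfo((2, 2), dtype=tp.float32)
b_info = tp.InputInfo((2, 2), dtype=tp.float32)

# Before this commit: construct a tp.Compiler, then call its compile() method.
#   compiler = tp.Compiler(add)
#   compiled_add = compiler.compile(a_info, b_info)

# After this commit: a single call; InputInfo constraints are passed via `args`.
compiled_add = tp.compile(add, args=[a_info, b_info])
print(compiled_add(tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)))
```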

19 files changed: +724 −694 lines changed

tripy/README.md

Lines changed: 1 addition & 3 deletions

````diff
@@ -48,8 +48,6 @@ Tripy can also compile functions to generate efficient machine code for faster e
 def add(a, b):
     return a + b
 
-compiler = tp.Compiler(add)
-
 # When compiling, we need to specify shape and data type constraints on the inputs:
 
 # a is a 1D dynamic shape tensor of shape (d,), where `d` can range from 1 to 5.
@@ -59,7 +57,7 @@ a_info = tp.InputInfo(shape=([1, 2, 5],), dtype=tp.float32)
 # `b` is a 1D tensor of shape (1,).
 b_info = tp.InputInfo((1,), dtype=tp.float32)
 
-compiled_add = compiler.compile(a_info, b_info)
+compiled_add = tp.compile(add, args=[a_info, b_info])
 
 print(compiled_add(tp.Tensor([1., 2., 3.]), tp.Tensor([3.])))
 ```
````

tripy/docs/pre0_user_guides/00-introduction-to-tripy.md

Lines changed: 3 additions & 5 deletions

````diff
@@ -105,21 +105,19 @@ All the code we've seen so far has been using Tripy's eager mode. It is also pos
 functions or modules ahead of time, which can result in significantly better performance.
 
 *Note that the compiler imposes some requirements on the functions/modules it can compile.*
-*See {class}`tripy.Compiler` for details.*
+*See {func}`tripy.compile` for details.*
 
 Let's compile the MLP module we defined above as an example:
 
 ```py
 # doc: no-print-locals
-compiler = tp.Compiler(mlp)
-
 # When we compile, we need to indicate which parameters to the function should be runtime inputs.
 # In this case, MLP takes a single input tensor for which we can specify our desired shape and datatype.
-fast_mlp = compiler.compile(tp.InputInfo(shape=(1, 2), dtype=tp.float32))
+fast_mlp = tp.compile(mlp, args=[tp.InputInfo(shape=(1, 2), dtype=tp.float32)])
 ```
 
 It is also possible to compile for a range of possible input shapes.
-See {func}`tripy.Compiler.compile` for details.
+See {func}`tripy.compile` for details.
 
 Now let's benchmark the compiled version against eager mode:
 ```py
````

tripy/docs/pre0_user_guides/02-compiler.md

Lines changed: 13 additions & 19 deletions

````diff
@@ -28,17 +28,9 @@ inp = tp.ones((1, 2))
 out = layer(inp)
 ```
 
-Now, let's try to optimize this model for inference using Tripy's {class}`tripy.Compiler`.
+Now, let's try to optimize this model for inference using Tripy's {func}`tripy.compile`.
 
-First, let's initialize the compiler with the module we want to compile, `layer`,
-which lets the compiler know its properties, like the function signature.
-
-```py
-# doc: no-print-locals
-compiler = tp.Compiler(layer)
-```
-
-Next, we need to provide information about each input using {class}`tripy.InputInfo`.
+When we compile our module, we need to provide information about each input using {class}`tripy.InputInfo`.
 The first argument for `InputInfo` is `shape`, where we specify either the static or
 dynamic shape information for each dimension. In the example below, we assume the
 shape of `inp` is static (`(1, 2)`). The second argument specifies the `dtype` for the input:
@@ -51,7 +43,7 @@ Now, we can call the `compile` function to obtain a compiled function and use it
 
 ```py
 # doc: no-print-locals
-fast_geglu = compiler.compile(inp_info)
+fast_geglu = tp.compile(layer, args=[inp_info])
 fast_geglu(inp).eval()
 ```
 
@@ -67,7 +59,7 @@ and it should optimize for a size of 8.
 ```py
 # doc: print-locals out out_change_shape
 inp_info = tp.InputInfo(shape=((1, 8, 16), 2), dtype=tp.float32)
-fast_geglu = compiler.compile(inp_info)
+fast_geglu = tp.compile(layer, args=[inp_info])
 out = fast_geglu(inp)
 
 # Let's change the shape of input to (2, 2)
@@ -94,20 +86,23 @@ Saving an executable to disk:
 
 ```py
 # doc: no-print-locals
-import tempfile, os
-temp_dir = tempfile.mkdtemp()
-executable_file_path = os.path.join(temp_dir, "executable.json")
+import tempfile # doc: omit
+import os
+
+out_dir = tempfile.mkdtemp() # doc: omit
+executable_file_path = os.path.join(out_dir, "executable.json")
 fast_geglu.save(executable_file_path)
 ```
 
 Reading an executable and running inference:
 
 ```py
 # doc: no-print-locals
-inp = tp.Tensor([[1., 2.], [2., 3.]], dtype=tp.float32)
 loaded_fast_geglu = tp.Executable.load(executable_file_path)
+
+inp = tp.Tensor([[1., 2.], [2., 3.]], dtype=tp.float32)
 out = loaded_fast_geglu(inp)
-os.remove(executable_file_path)
+os.remove(executable_file_path) # doc: omit
 ```
 
 ### Querying Executable Properties
@@ -134,9 +129,8 @@ def add_times_two(a, b):
     print(f"c : {c}")
     return c + a + b
 
-compiler = tp.Compiler(add_times_two)
 inp_info = tp.InputInfo(shape=(1, 2), dtype=tp.float32)
-fast_myadd = compiler.compile(inp_info, inp_info)
+fast_myadd = tp.compile(add_times_two, args=[inp_info, inp_info])
 a = tp.Tensor([[1.0, 2.0]], dtype=tp.float32)
 b = tp.Tensor([[2.0, 3.0]], dtype=tp.float32)
 
````
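
The min/opt/max range syntax shown above works the same way through `tp.compile`. A minimal sketch, mirroring `test_dynamic_shapes` in the new tests further down (the `add` helper is the same two-tensor addition used there):

```py
import tripy as tp


def add(a, b):
    return a + b


# (1, 2, 3) declares min/opt/max sizes for the first dimension;
# the second dimension is fixed at 1.
dyn = tp.InputInfo(((1, 2, 3), 1), dtype=tp.float32)
compiled_add = tp.compile(add, args=[dyn, dyn])

# One executable serves every input shape inside the declared range.
print(compiled_add(tp.ones((2, 1), dtype=tp.float32), tp.ones((2, 1), dtype=tp.float32)))
print(compiled_add(tp.ones((3, 1), dtype=tp.float32), tp.ones((3, 1), dtype=tp.float32)))
```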

tripy/examples/nanogpt/example.py

Lines changed: 2 additions & 4 deletions

````diff
@@ -1,6 +1,5 @@
-
 #
-# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -99,14 +98,13 @@ def main():
 
     # Compile the model before running inference.
     compile_start_time = time.perf_counter()
-    compiler = tp.Compiler(model)
     input_shape = (
         1,
         # We can specify dynamic dimensions by using a sequence indicating the min/opt/max values that
         # a dimension should support:
         [1, len(input_ids), padded_seq_len],
     )
-    model = compiler.compile(tp.InputInfo(input_shape, dtype=tp.int32))
+    model = tp.compile(model, args=[tp.InputInfo(input_shape, dtype=tp.int32)])
     compile_end_time = time.perf_counter()
     print(f"Compilation took {compile_end_time - compile_start_time} seconds.")
 
````

Lines changed: 14 additions & 0 deletions

````diff
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
````
Lines changed: 40 additions & 0 deletions

````diff
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+def add(a, b):
+    return a + b
+
+
+def sub(a, b):
+    return a - b
+
+
+def returns_non_tensor(a):
+    return "not a tensor"
+
+
+def returns_nothing(a):
+    return
+
+
+def returns_multiple_tensors(a, b):
+    return a + b, a - b
+
+
+def variadic_positional(*args):
+    pass
+
+
+def variadic_keyword(**kwargs):
+    pass
````
Lines changed: 156 additions & 0 deletions

````diff
@@ -0,0 +1,156 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import cupy as cp
+import pytest
+from tests import helper
+from tests.backend.api.conftest import *
+
+import tripy as tp
+
+
+class TestCompile:
+    # TODO (#246): Verify that it's actually compiling somehow here and below.
+    # Need to return something programmatically queryable from compile to do this.
+    def test_function(self):
+        compiled_relu = tp.compile(tp.relu, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
+
+        inp = tp.ones((2, 2), dtype=tp.float32)
+        out = compiled_relu(inp)
+
+        # TODO (#225): Replace with tp.all
+        assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(tp.relu(inp)))
+
+    def test_module(self):
+        layernorm = tp.LayerNorm(2)
+        compiled_layernorm = tp.compile(layernorm, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
+
+        inp = tp.ones((2, 2), dtype=tp.float32)
+        out = compiled_layernorm(inp)
+
+        assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(layernorm(inp)))
+
+    def test_compile_arg_order_irrelevant(self):
+        # The order of arguments we specify to `compile` should not affect the order
+        # of the arguments in the compiled function, which should just follow the order
+        # of the original function.
+        compiled_sub = tp.compile(
+            sub, kwargs=dict(b=tp.InputInfo((2, 2), dtype=tp.float32), a=tp.InputInfo((2, 2), dtype=tp.float32))
+        )
+
+        a = tp.ones((2, 2), dtype=tp.float32) * 2
+        b = tp.ones((2, 2), dtype=tp.float32)
+
+        # Compiled function should still take arguments in (a, b) order.
+        out = compiled_sub(a, b)
+        assert cp.array_equal(cp.from_dlpack(out), cp.ones((2, 2), dtype=cp.float32))
+
+    @pytest.mark.parametrize("b", [2, tp.ones((2, 2), dtype=tp.float32) * 2])
+    def test_constants_baked(self, b):
+        # Any non-InputInfo argument to compile is baked into the compiled function.
+        compiled_add = tp.compile(add, args=[tp.InputInfo((2, 2), dtype=tp.float32), b])
+
+        a = tp.zeros((2, 2), dtype=tp.float32)
+
+        out = compiled_add(a)
+
+        assert cp.array_equal(cp.from_dlpack(out), cp.ones((2, 2), dtype=cp.float32) * 2)
+
+    @pytest.mark.parametrize("func", [variadic_positional, variadic_keyword])
+    def test_variadic_arguments_rejected(self, func):
+        with helper.raises(tp.TripyException, "Variadic positional/keyword arguments are not currently supported."):
+            tp.compile(func)
+
+    @pytest.mark.parametrize("func", [returns_non_tensor, returns_nothing])
+    def test_invalid_return_rejected(self, func):
+        with helper.raises(tp.TripyException, "Function must return 1 or more Tensors"):
+            tp.compile(func, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
+
+    def test_multiple_return_values(self):
+        compiled_func = tp.compile(
+            returns_multiple_tensors,
+            args=[tp.InputInfo((2, 2), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32)],
+        )
+
+        a = tp.ones((2, 2), dtype=tp.float32) * 2
+        b = tp.ones((2, 2), dtype=tp.float32)
+
+        plus, minus = compiled_func(a, b)
+
+        assert cp.array_equal(cp.from_dlpack(plus), cp.ones((2, 2), dtype=cp.float32) * 3)
+        assert cp.array_equal(cp.from_dlpack(minus), cp.ones((2, 2), dtype=cp.float32))
+
+    def test_incorrect_dtype_rejected(self):
+        a = tp.ones((2, 2), dtype=tp.int32)
+
+        with helper.raises(tp.TripyException, "Unexpected tensor data type.", has_stack_info_for=[a]):
+            compiled_add = tp.compile(
+                add, args=[tp.InputInfo((2, 2), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32)]
+            )
+            compiled_add(a, a)
+
+    def test_incorrect_shape_rejected(self):
+        a = tp.ones((1, 2), dtype=tp.float32)
+
+        with helper.raises(tp.TripyException, "Unexpected tensor shape.", has_stack_info_for=[a]):
+            compiled_add = tp.compile(
+                add, args=[tp.InputInfo((2, 2), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32)]
+            )
+            compiled_add(a, a)
+
+    @pytest.mark.skip("TODO (#155): Re-enable once we no longer implicitly copy inputs to device")
+    def test_incorrect_device_rejected(self):
+        compiled_add = tp.compile(
+            add, args=[tp.InputInfo((2, 2), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32)]
+        )
+        a = tp.copy(tp.ones((2, 2), dtype=tp.float32), device=tp.device("cpu"))
+
+        with helper.raises(tp.TripyException):
+            compiled_add(a, a)
+
+    # TODO (#244): Add multi-profile test
+    def test_dynamic_shapes(self):
+        compiled_add = tp.compile(
+            add, args=[tp.InputInfo(((1, 2, 3), 1), dtype=tp.float32), tp.InputInfo(((1, 2, 3), 1), dtype=tp.float32)]
+        )
+
+        out = compiled_add(tp.ones((2, 1), dtype=tp.float32), tp.ones((2, 1), dtype=tp.float32))
+        assert cp.array_equal(cp.from_dlpack(out), cp.ones((2, 1), dtype=cp.float32) * 2)
+
+        out = compiled_add(tp.ones((3, 1), dtype=tp.float32), tp.ones((3, 1), dtype=tp.float32))
+        assert cp.array_equal(cp.from_dlpack(out), cp.ones((3, 1), dtype=cp.float32) * 2)
+
+
+# TODO (#256): Remove these tests and replace with exhaustive integration testing
+class TestCompiledOps:
+    def test_cast(self):
+        compiled_cast = tp.compile(tp.cast, args=[tp.InputInfo((2, 2), dtype=tp.float32)], kwargs=dict(dtype=tp.int32))
+
+        a = tp.ones((2, 2), dtype=tp.float32)
+        out = compiled_cast(a)
+
+        assert cp.array_equal(cp.from_dlpack(out), cp.ones((2, 2), dtype=cp.int32))
+
+    def test_linear(self):
+        linear = tp.Linear(2, 3)
+
+        compiled_linear = tp.compile(linear, args=[tp.InputInfo((2, 2), dtype=tp.float32)])
+
+        a = tp.ones((2, 2), dtype=tp.float32)
+
+        out = compiled_linear(a)
+
+        assert cp.array_equal(cp.from_dlpack(out), cp.from_dlpack(linear(a)))
````
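
As `test_constants_baked` and `test_cast` above illustrate, any argument that is not an `InputInfo` is baked into the executable as a compile-time constant, and inputs can also be supplied through `kwargs`. A small usage sketch combining both behaviors:

```py
import tripy as tp

# `dtype=tp.int32` is not an InputInfo, so it becomes a baked-in constant:
# the compiled function takes only the one runtime tensor input.
compiled_cast = tp.compile(
    tp.cast,
    args=[tp.InputInfo((2, 2), dtype=tp.float32)],
    kwargs=dict(dtype=tp.int32),
)

out = compiled_cast(tp.ones((2, 2), dtype=tp.float32))  # result has dtype int32
```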
