Commit 5c412aa

Mgluhovskoi, pranavm-nvidia, and parthchadha authored
Spec Verification (NVIDIA#25)
1) Create a standard for doc-string dtypes
2) Automatically verify doc-strings' dtypes - negative-test any dtypes that are not supported
3) Integrate verification into the test pipeline (L1 for now)
4) Add a README file to explain how to use the verifier/decorator

Side task: Add support for several dtypes within `cast`.

---------

Signed-off-by: Mgluhovskoi <[email protected]>
Co-authored-by: pranavm-nvidia <[email protected]>
Co-authored-by: Parth Chadha <[email protected]>
1 parent 35473f3 commit 5c412aa


43 files changed: +907 −344 lines

tripy/docs/README.md

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@ To view the documentation, you can open `build/docs/index.html` in a browser.
 The `export.public_api()` decorator allows you to specify metadata for documentation
 generation, such as where in the documentation hierarchy the API should be documented.
 
+The `constraints.dtype_info()` decorator verifies the data types a function claims to support and generates
+corresponding documentation. For more information, see [this guide](../tests/spec_verification/README.md).
+
 The `generate_rsts.py` script uses this information to automatically generate a directory
 structure and populate it with `.rst` files.
 

tripy/docs/conf.py

Lines changed: 38 additions & 25 deletions
@@ -27,7 +27,7 @@
 from tests import helper
 
 import tripy as tp
-from tripy.dtype_info import TYPE_VERIFICATION
+from tripy.constraints import TYPE_VERIFICATION, FUNC_W_DOC_VERIF
 
 PARAM_PAT = re.compile(":param .*?:")
 
@@ -161,12 +161,10 @@ def process_docstring(app, what, name, obj, options, lines):
         elif param.kind == inspect.Parameter.VAR_POSITIONAL:
             pname = "*" + pname
 
-        if pname == "self":
-            # Don't want a type annotation for the self parameter.
+        if pname != "self" or obj.__qualname__ in FUNC_W_DOC_VERIF:
             assert (
-                param.annotation == signature.empty
-            ), f"Avoid using type annotations for the `self` parameter since this will corrupt the rendered documentation!"
-        else:
+                pname in documented_args
+            ), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args} {doc}"
             assert (
                 pname in documented_args
             ), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args}"
@@ -179,6 +177,10 @@ def process_docstring(app, what, name, obj, options, lines):
             assert not inspect.ismodule(
                 param.annotation
             ), f"Type annotation cannot be a module, but got: '{param.annotation}' for parameter: '{pname}' in: '{obj}'. Please specify a type!"
+        else:
+            assert (
+                param.annotation == signature.empty
+            ), f"Avoid using type annotations for the `self` parameter since this will corrupt the rendered documentation! Note: Documented parameters were: {documented_args} {doc}"
 
     assert signature.return_annotation != signature.empty, (
         f"Missing return type annotation for: '{obj}'. "
@@ -190,39 +192,50 @@ def process_docstring(app, what, name, obj, options, lines):
         ":returns:" in doc
     ), f"For: {obj}, return value is not documented. Please ensure you've included a `Returns:` section"
 
-    # new docstring logic:
-    # first figure out if we should it is the new docstring
-    if name.split(".")[-1] in TYPE_VERIFICATION.keys():
-        cleaned_name = name.split(".")[-1]
+    # New docstring logic:
+    # First figure out if object is using the @constraints.dtype_info decorator.
+    unqual_name = name.split(".")[-1]
+    if unqual_name in TYPE_VERIFICATION.keys():
         add_text_index = -1
         for index, block in enumerate(blocks):
             if re.search(r".. code-block::", block):
-                type_dict = TYPE_VERIFICATION[cleaned_name][1]["types"]
+                type_dict = TYPE_VERIFICATION[unqual_name].dtypes
                 blocks.insert(index, "Type Constraints:")
                 index += 1
+                # Add the dtype constraint name and the dtypes that correlate.
                 for type_name, dt in type_dict.items():
-                    blocks.insert(index, f" - {type_name}: " + ", ".join(dt))
+                    blocks.insert(
+                        index,
+                        f" - **{type_name}**: :class:`" + "`, :class:`".join(set(dt)) + "`",
+                    )
                     index += 1
                 blocks.insert(index, "\n")
+                if TYPE_VERIFICATION[unqual_name].dtype_exceptions != []:
+                    # Add the dtype exceptions.
+                    index += 1
+                    blocks.insert(index, "**Unsupported Type Combinations**:")
+                    dtype_exception_text = []
+                    for exception_dict in TYPE_VERIFICATION[unqual_name].dtype_exceptions:
+                        dtype_exception_text.append(
+                            ", ".join([f"{key}: :class:`{val}`" for key, val in exception_dict.items()])
+                        )
+                    dtype_exception_text = "; ".join(dtype_exception_text) + "\n"
+                    index += 1
+                    blocks.insert(index, dtype_exception_text)
                 break
             if re.search(r":param \w+: ", block):
-                add_text_index = re.search(r":param \w+: ", block).span()[1]
                 param_name = re.match(r":param (\w+): ", block).group(1)
-                blocks[index] = (
-                    block[0:add_text_index]
-                    + "[dtype="
-                    + TYPE_VERIFICATION[cleaned_name][2][param_name]
-                    + "] "
-                    + block[add_text_index:]
-                )
+                # Add dtype constraint to start of each parameter description.
+                if TYPE_VERIFICATION[unqual_name].dtype_constraints.get(param_name, None):
+                    add_text_index = re.search(r":param \w+: ", block).span()[1]
+                    blocks[index] = (
+                        f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}"
+                    )
            if re.search(r":returns:", block):
                 add_text_index = re.search(r":returns:", block).span()[1] + 1
+                # Add dtype constraint to start of returns description.
                 blocks[index] = (
-                    block[0:add_text_index]
-                    + "[dtype="
-                    + list(TYPE_VERIFICATION[cleaned_name][1]["returns"].values())[0]["dtype"]
-                    + "] "
-                    + block[add_text_index:]
+                    f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}"
                 )
 
     seen_classes.add(name)
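
The hook above rewrites the reST blocks of a decorated API's docstring: it inserts a "Type Constraints:" list and prefixes parameter and return descriptions with their dtype variable. The following is a minimal, self-contained sketch of that same string-building logic on toy data; the constraint names and dtypes below are invented for illustration and are not the real `TYPE_VERIFICATION` registry.

```python
# Toy version of the docstring rewriting done in docs/conf.py (illustration only).
import re

dtypes = {"T1": ["float32", "float16", "int32"]}  # hypothetical constraint -> supported dtypes
dtype_constraints = {"input": "T1"}               # hypothetical parameter -> constraint name
return_dtype = "T1"

blocks = [
    ":param input: The input tensor.",
    ":returns: The result tensor.",
]

# Build the "Type Constraints:" entries the same way the hook does.
constraint_lines = ["Type Constraints:"]
for type_name, dt in dtypes.items():
    constraint_lines.append(f" - **{type_name}**: :class:`" + "`, :class:`".join(set(dt)) + "`")

# Tag each documented parameter and the return value with its dtype variable.
for index, block in enumerate(blocks):
    param = re.match(r":param (\w+): ", block)
    if param and param.group(1) in dtype_constraints:
        pos = param.span()[1]
        blocks[index] = f"{block[:pos]}[dtype=\\ **{dtype_constraints[param.group(1)]}**\\ ] {block[pos:]}"
    elif block.startswith(":returns:"):
        pos = len(":returns:") + 1
        blocks[index] = f"{block[:pos]}[dtype=\\ **{return_dtype}**\\ ] {block[pos:]}"

print("\n".join(constraint_lines + blocks))
```

Running the sketch prints the rewritten blocks, e.g. `:param input: [dtype=\ **T1**\ ] The input tensor.`, which is the form Sphinx then renders.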

tripy/tests/common/test_utils.py

Lines changed: 0 additions & 1 deletion
@@ -26,7 +26,6 @@
 import tripy.common.datatype
 
 from tests import helper
-from tripy.common.datatype import DATA_TYPES
 from tripy.common.exception import TripyException
 from tripy.common.utils import (
     convert_list_to_array,

tripy/tests/frontend/trace/ops/test_where.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import pytest
 import re
 import tripy as tp
 from tests import helper

tripy/tests/integration/test_conv_transpose.py

Lines changed: 0 additions & 2 deletions
@@ -18,12 +18,10 @@
 from collections.abc import Sequence
 from dataclasses import dataclass
 
-import cupy as cp
 import pytest
 import torch
 
 import tripy as tp
-from tests import helper
 
 DTYPES = [
     (torch.float16, tp.float16),

tripy/tests/integration/test_quantize.py

Lines changed: 1 addition & 2 deletions
@@ -15,8 +15,7 @@
 # limitations under the License.
 #
 
-import cupy as cp
-import numpy as np
+
 import pytest
 import re
 import torch
tripy/tests/spec_verification/README.md

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# Introduction
+
+Spec verification is designed to ensure that the datatypes documented in the operator documentation are accurate.
+
+# How to Verify an Operation
+
+To run the verification program on an operation, add the decorator `@constraints.dtype_info` to the operation. The inputs to the decorator will help the verifier determine the constraints on the inputs and the datatypes to verify.
+
+To learn more about how to use `@constraints.dtype_info`, check out `tripy/constraints.py`, `tests/spec_verification/test_dtype_constraints.py`, and `tests/spec_verification/object_builders.py`.
+
+After the decorator is set up, it will automatically run verification test cases alongside the regular test cases. If you only want to run the verifier, execute `pytest -s -v` within the tests/spec_verification folder.
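
For orientation, the sketch below shows roughly what applying the decorator might look like. It is a hypothetical illustration only: the keyword names (`dtype_variables`, `dtype_constraints`) and the operation are assumptions made for the example, and the authoritative signature lives in `tripy/constraints.py` and `tests/spec_verification/test_dtype_constraints.py` as the README says.

```python
# Hypothetical usage sketch; the decorator keyword names below are assumptions,
# not the documented API -- consult tripy/constraints.py for the real signature.
from tripy import constraints


@constraints.dtype_info(
    dtype_variables={"T1": ["float32", "float16", "int32"]},  # assumed keyword: dtypes to verify
    dtype_constraints={"input": "T1", "return": "T1"},        # assumed keyword: params tied to T1
)
def relu(input: "tripy.Tensor") -> "tripy.Tensor":
    """
    Applies a rectified linear unit.

    Args:
        input: The input tensor.

    Returns:
        A tensor of the same shape and dtype as the input.
    """
    ...
```

Once an operation is decorated like this, the verifier generates a test case per supported dtype, and `pytest -s -v` inside `tests/spec_verification` runs only those cases.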

tripy/tests/spec_verification/object_builders.py

Lines changed: 111 additions & 23 deletions
@@ -16,40 +16,128 @@
 #
 
 import tripy as tp
-import math
 
+from typing import Union, Optional, get_origin, get_args, ForwardRef, List
+from tripy.common import datatype
+import inspect
 
-def tensor_builder(func_obj, input_values, namespace):
-    shape = input_values.get("shape", None)
-    if not shape:
-        shape = (3, 2)
-    return tp.ones(dtype=namespace[input_values["dtype"]], shape=shape)
 
+def tensor_builder(init, dtype, namespace):
+    if init is None:
+        return tp.ones(dtype=namespace[dtype], shape=(3, 2))
+    elif not isinstance(init, tp.Tensor):
+        assert dtype == None
+        return init
+    return tp.cast(init, dtype=namespace[dtype])
 
-def shape_tensor_builder(func_obj, input_values, namespace):
-    follow_tensor = input_values.get("follow_tensor", None)
-    return (math.prod((namespace[follow_tensor]).shape.tolist()),)
 
-
-def dtype_builder(func_obj, input_values, namespace):
-    dtype = input_values.get("dtype", None)
+def dtype_builder(init, dtype, namespace):
     return namespace[dtype]
 
 
-def int_builder(func_obj, input_values, namespace):
-    return input_values.get("value", None)
+def tensor_list_builder(init, dtype, namespace):
+    if init is None:
+        return [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)]
+    else:
+        return [tp.cast(tens, dtype=namespace[dtype]) for tens in init]
+
+
+def device_builder(init, dtype, namespace):
+    if init is None:
+        return tp.device("gpu")
+    return init
+
+
+def default_builder(init, dtype, namespace):
+    return init
 
 
 find_func = {
-    "Tensor": tensor_builder,
-    "shape_tensor": shape_tensor_builder,
-    "dtype": dtype_builder,
-    "int": int_builder,
+    "tripy.Tensor": tensor_builder,
+    "tripy.Shape": tensor_builder,
+    "tripy.dtype": dtype_builder,
+    datatype.dtype: dtype_builder,
+    List[Union["tripy.Tensor"]]: tensor_list_builder,
+    "tripy.device": device_builder,
+}
+
+"""
+default_constraints_all: This dictionary helps set specific constraints and values for parameters. These constraints correspond to the type hint of each parameter.
+Some types have default values, so you might not need to pass other_constraints for every operation.
+If there is no default, you must specify an initialization value, or the testcase may fail.
+The dictionary's keys must be the name of the function that they are constraining and the value must be what the parameter should be initialized to.
+Here is the list of parameter types that have defaults or work differently from other types:
+- tensor - default: tp.ones(shape=(3,2)). If init is passed then value must be in the form of a list. Example: "scale": tp.Tensor([1,1,1]) or "scale": tp.ones((3,3))
+- dtype - default: no default. Dtype parameters will be set using dtype_constraints input so using default_constraints_all will not change anything.
+- list/sequence of tensors - default: [tp.ones((3,2)),tp.ones((3,2))]. Example: "dim": [tp.ones((2,4)),tp.ones((1,2))].
+    This will create a list/sequence of tensors of size count and each tensor will follow the init and shape value similar to tensor parameters.
+- device - default: tp.device("gpu"). Example: {"device": tp.device("cpu")}.
+All other types do not have defaults and must be passed to the verifier using default_constraints_all.
+"""
+default_constraints_all = {
+    "__rtruediv__": {"self": 1},
+    "__rsub__": {"self": 1},
+    "__radd__": {"self": 1},
+    "__rpow__": {"self": 1},
+    "__rmul__": {"self": 1},
+    "softmax": {"dim": 1},
+    "concatenate": {"dim": 0},
+    "expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))},
+    "full": {"shape": tp.Tensor([3]), "value": 1},
+    "full_like": {"value": 1},
+    "flip": {"dim": 1},
+    "gather": {"dim": 0, "index": tp.Tensor([1])},
+    "iota": {"shape": tp.Tensor([3])},
+    "__matmul__": {"self": tp.ones((2, 3))},
+    "transpose": {"dim0": 0, "dim1": 1},
+    "permute": {"perm": [1, 0]},
+    "quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
+    "sum": {"dim": 0},
+    "all": {"dim": 0},
+    "any": {"dim": 0},
+    "max": {"dim": 0},
+    "prod": {"dim": 0},
+    "mean": {"dim": 0},
+    "var": {"dim": 0},
+    "argmax": {"dim": 0},
+    "argmin": {"dim": 0},
+    "reshape": {"shape": tp.Tensor([6])},
+    "squeeze": {"input": tp.ones((3, 1)), "dims": (1)},
+    "__getitem__": {"index": 2},
+    "split": {"indices_or_sections": 2},
+    "unsqueeze": {"dim": 1},
+    "masked_fill": {"value": 1},
+    "ones": {"shape": tp.Tensor([3, 2])},
+    "zeros": {"shape": tp.Tensor([3, 2])},
+    "arange": {"start": 0, "stop": 5},
 }
 
 
-def create_obj(func_obj, param_name, input_desc, namespace):
-    param_type = list(input_desc.keys())[0]
-    create_obj_func = find_func[param_type]
-    namespace[param_name] = create_obj_func(func_obj, input_desc[param_type], namespace)
-    return namespace[param_name]
+def create_obj(func_obj, func_name, param_name, param_dtype, namespace):
+    # If type is an optional or union get the first type.
+    # Get names and type hints for each param.
+    func_sig = inspect.signature(func_obj)
+    param_dict = func_sig.parameters
+    param_type_annot = param_dict[param_name]
+    init = None
+    # Check if there is a value in default_constraints_all for func_name and param_name and use it.
+    default_constraints = default_constraints_all.get(func_name, None)
+    if default_constraints != None:
+        other_constraint = default_constraints.get(param_name, None)
+        if other_constraint is not None:
+            init = other_constraint
+    # If parameter had a default then use it otherwise skip.
+    if init is None and param_type_annot.default is not param_type_annot.empty:
+        # Checking if not equal to None since default can be 0 or similar.
+        if param_type_annot.default != None:
+            init = param_type_annot.default
    param_type = param_type_annot.annotation
+    while get_origin(param_type) in [Union, Optional]:
+        param_type = get_args(param_type)[0]
+    # ForwardRef refers to any case where type hint is a string.
+    if isinstance(param_type, ForwardRef):
+        param_type = param_type.__forward_arg__
+    create_obj_func = find_func.get(param_type, default_builder)
+    if create_obj_func:
+        namespace[param_name] = create_obj_func(init, param_dtype, namespace)
+        return namespace[param_name]
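
The key step in `create_obj` above is reducing a parameter's type hint to a lookup key for `find_func`: `Optional[...]`/`Union[...]` hints are unwrapped to their first option, and string hints stored as `ForwardRef` objects are reduced to the raw string. The sketch below isolates that step so it can be run without tripy installed; the example function is made up purely to exercise the three annotation shapes.

```python
# Standalone sketch of the annotation-unwrapping used by create_obj (illustrative only).
import inspect
from typing import ForwardRef, List, Optional, Union, get_args, get_origin


def example_op(input: "tripy.Tensor", dim: Optional[int] = None, other: Optional["tripy.Tensor"] = None):
    """Made-up signature covering a plain string hint, Optional[int], and Optional[ForwardRef]."""
    ...


def resolve_annotation(func, param_name):
    annotation = inspect.signature(func).parameters[param_name].annotation
    # Optional[...] / Union[...] hints: keep only the first option, as create_obj does.
    while get_origin(annotation) in [Union, Optional]:
        annotation = get_args(annotation)[0]
    # String hints nested inside typing constructs arrive as ForwardRef; use the raw string.
    if isinstance(annotation, ForwardRef):
        annotation = annotation.__forward_arg__
    return annotation


print(resolve_annotation(example_op, "input"))  # tripy.Tensor  (plain string hint)
print(resolve_annotation(example_op, "dim"))    # <class 'int'> (Optional[int] unwrapped)
print(resolve_annotation(example_op, "other"))  # tripy.Tensor  (ForwardRef resolved)
```

The resulting key is then looked up in `find_func` (falling back to `default_builder`), which is how, for example, a `"tripy.Tensor"` hint ends up routed to `tensor_builder`.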
