|
16 | 16 | # |
17 | 17 |
|
18 | 18 | import tripy as tp |
19 | | -import math |
20 | 19 |
|
| 20 | +from typing import Union, Optional, get_origin, get_args, ForwardRef, List |
| 21 | +from tripy.common import datatype |
| 22 | +import inspect |
21 | 23 |
|
22 | | -def tensor_builder(func_obj, input_values, namespace): |
23 | | - shape = input_values.get("shape", None) |
24 | | - if not shape: |
25 | | - shape = (3, 2) |
26 | | - return tp.ones(dtype=namespace[input_values["dtype"]], shape=shape) |
27 | 24 |
|
| 25 | +def tensor_builder(init, dtype, namespace): |
| 26 | + if init is None: |
| 27 | + return tp.ones(dtype=namespace[dtype], shape=(3, 2)) |
| 28 | + elif not isinstance(init, tp.Tensor): |
| 29 | + assert dtype == None |
| 30 | + return init |
| 31 | + return tp.cast(init, dtype=namespace[dtype]) |
28 | 32 |
|
29 | | -def shape_tensor_builder(func_obj, input_values, namespace): |
30 | | - follow_tensor = input_values.get("follow_tensor", None) |
31 | | - return (math.prod((namespace[follow_tensor]).shape.tolist()),) |
32 | 33 |
|
33 | | - |
34 | | -def dtype_builder(func_obj, input_values, namespace): |
35 | | - dtype = input_values.get("dtype", None) |
| 34 | +def dtype_builder(init, dtype, namespace): |
36 | 35 | return namespace[dtype] |
37 | 36 |
|
38 | 37 |
|
39 | | -def int_builder(func_obj, input_values, namespace): |
40 | | - return input_values.get("value", None) |
| 38 | +def tensor_list_builder(init, dtype, namespace): |
| 39 | + if init is None: |
| 40 | + return [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)] |
| 41 | + else: |
| 42 | + return [tp.cast(tens, dtype=namespace[dtype]) for tens in init] |
| 43 | + |
| 44 | + |
| 45 | +def device_builder(init, dtype, namespace): |
| 46 | + if init is None: |
| 47 | + return tp.device("gpu") |
| 48 | + return init |
| 49 | + |
| 50 | + |
| 51 | +def default_builder(init, dtype, namespace): |
| 52 | + return init |
41 | 53 |
|
42 | 54 |
|
43 | 55 | find_func = { |
44 | | - "Tensor": tensor_builder, |
45 | | - "shape_tensor": shape_tensor_builder, |
46 | | - "dtype": dtype_builder, |
47 | | - "int": int_builder, |
| 56 | + "tripy.Tensor": tensor_builder, |
| 57 | + "tripy.Shape": tensor_builder, |
| 58 | + "tripy.dtype": dtype_builder, |
| 59 | + datatype.dtype: dtype_builder, |
| 60 | + List[Union["tripy.Tensor"]]: tensor_list_builder, |
| 61 | + "tripy.device": device_builder, |
| 62 | +} |
| 63 | + |
| 64 | +""" |
| 65 | +default_constraints_all: This dictionary helps set specific constraints and values for parameters. These constraints correspond to the type hint of each parameter. |
| 66 | +Some type have default values, so you might not need to pass other_constraints for every operation. |
| 67 | +If there is no default, you must specify an initialization value, or the testcase may fail. |
| 68 | +The dictionary's keys must be the name of the function that they are constraining and the value must be what the parameter should be initialized to. |
| 69 | +Here is the list of parameter types that have defaults or work differently from other types: |
| 70 | + - tensor - default: tp.ones(shape=(3,2)). If init is passed then value must be in the form of a list. Example: "scale": tp.Tensor([1,1,1]) or "scale": tp.ones((3,3)) |
| 71 | + - dtype - default: no default. Dtype parameters will be set using dtype_constraints input so using default_constraints_all will not change anything. |
| 72 | + - list/sequence of tensors - default: [tp.ones((3,2)),tp.ones((3,2))]. Example: "dim": [tp.ones((2,4)),tp.ones((1,2))]. |
| 73 | + This will create a list/sequence of tensors of size count and each tensor will follow the init and shape value similar to tensor parameters. |
| 74 | + - device - default: tp.device("gpu"). Example: {"device": tp.device("cpu")}. |
| 75 | +All other types do not have defaults and must be passed to the verifier using default_constraints_all. |
| 76 | +""" |
| 77 | +default_constraints_all = { |
| 78 | + "__rtruediv__": {"self": 1}, |
| 79 | + "__rsub__": {"self": 1}, |
| 80 | + "__radd__": {"self": 1}, |
| 81 | + "__rpow__": {"self": 1}, |
| 82 | + "__rmul__": {"self": 1}, |
| 83 | + "softmax": {"dim": 1}, |
| 84 | + "concatenate": {"dim": 0}, |
| 85 | + "expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))}, |
| 86 | + "full": {"shape": tp.Tensor([3]), "value": 1}, |
| 87 | + "full_like": {"value": 1}, |
| 88 | + "flip": {"dim": 1}, |
| 89 | + "gather": {"dim": 0, "index": tp.Tensor([1])}, |
| 90 | + "iota": {"shape": tp.Tensor([3])}, |
| 91 | + "__matmul__": {"self": tp.ones((2, 3))}, |
| 92 | + "transpose": {"dim0": 0, "dim1": 1}, |
| 93 | + "permute": {"perm": [1, 0]}, |
| 94 | + "quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, |
| 95 | + "sum": {"dim": 0}, |
| 96 | + "all": {"dim": 0}, |
| 97 | + "any": {"dim": 0}, |
| 98 | + "max": {"dim": 0}, |
| 99 | + "prod": {"dim": 0}, |
| 100 | + "mean": {"dim": 0}, |
| 101 | + "var": {"dim": 0}, |
| 102 | + "argmax": {"dim": 0}, |
| 103 | + "argmin": {"dim": 0}, |
| 104 | + "reshape": {"shape": tp.Tensor([6])}, |
| 105 | + "squeeze": {"input": tp.ones((3, 1)), "dims": (1)}, |
| 106 | + "__getitem__": {"index": 2}, |
| 107 | + "split": {"indices_or_sections": 2}, |
| 108 | + "unsqueeze": {"dim": 1}, |
| 109 | + "masked_fill": {"value": 1}, |
| 110 | + "ones": {"shape": tp.Tensor([3, 2])}, |
| 111 | + "zeros": {"shape": tp.Tensor([3, 2])}, |
| 112 | + "arange": {"start": 0, "stop": 5}, |
48 | 113 | } |
49 | 114 |
|
50 | 115 |
|
51 | | -def create_obj(func_obj, param_name, input_desc, namespace): |
52 | | - param_type = list(input_desc.keys())[0] |
53 | | - create_obj_func = find_func[param_type] |
54 | | - namespace[param_name] = create_obj_func(func_obj, input_desc[param_type], namespace) |
55 | | - return namespace[param_name] |
| 116 | +def create_obj(func_obj, func_name, param_name, param_dtype, namespace): |
| 117 | + # If type is an optional or union get the first type. |
| 118 | + # Get names and type hints for each param. |
| 119 | + func_sig = inspect.signature(func_obj) |
| 120 | + param_dict = func_sig.parameters |
| 121 | + param_type_annot = param_dict[param_name] |
| 122 | + init = None |
| 123 | + # Check if there is a value in default_constraints_all for func_name and param_name and use it. |
| 124 | + default_constraints = default_constraints_all.get(func_name, None) |
| 125 | + if default_constraints != None: |
| 126 | + other_constraint = default_constraints.get(param_name, None) |
| 127 | + if other_constraint is not None: |
| 128 | + init = other_constraint |
| 129 | + # If parameter had a default then use it otherwise skip. |
| 130 | + if init is None and param_type_annot.default is not param_type_annot.empty: |
| 131 | + # Checking if not equal to None since default can be 0 or similar. |
| 132 | + if param_type_annot.default != None: |
| 133 | + init = param_type_annot.default |
| 134 | + param_type = param_type_annot.annotation |
| 135 | + while get_origin(param_type) in [Union, Optional]: |
| 136 | + param_type = get_args(param_type)[0] |
| 137 | + # ForwardRef refers to any case where type hint is a string. |
| 138 | + if isinstance(param_type, ForwardRef): |
| 139 | + param_type = param_type.__forward_arg__ |
| 140 | + create_obj_func = find_func.get(param_type, default_builder) |
| 141 | + if create_obj_func: |
| 142 | + namespace[param_name] = create_obj_func(init, param_dtype, namespace) |
| 143 | + return namespace[param_name] |
0 commit comments