Skip to content

Commit a22524e

Browse files
stabilizes perf test
1 parent 64153e8 commit a22524e

File tree

4 files changed

+22
-36
lines changed

4 files changed

+22
-36
lines changed

tripy/tests/frontend/test_tensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
import numpy as np
2323
import pytest
2424
import torch
25+
from tests.conftest import DATA_TYPE_TEST_CASES
26+
from tests.helper import NUMPY_TO_TRIPY
2527

2628
import tripy as tp
27-
from tests.conftest import DATA_TYPE_TEST_CASES
28-
from tests.helper import NUMPY_TYPES, np_to_tripy_dtype
29-
from tripy.utils.stack_info import SourceInfo
3029
from tripy.common.utils import get_element_type
30+
from tripy.utils.stack_info import SourceInfo
3131

3232

3333
class TestTensor:
@@ -52,12 +52,12 @@ def test_tensor_device(self, kind):
5252
assert isinstance(a.trace_tensor.producer, tp.frontend.trace.ops.Storage)
5353
assert a.trace_tensor.producer.device.kind == kind
5454

55-
@pytest.mark.parametrize("dtype", NUMPY_TYPES)
55+
@pytest.mark.parametrize("dtype", NUMPY_TO_TRIPY.keys())
5656
def test_dtype_from_numpy(self, dtype):
5757

5858
np_array = np.array([1, 2, 3], dtype=dtype)
5959
tensor = tp.Tensor(np_array)
60-
tp_dtype = np_to_tripy_dtype(dtype)
60+
tp_dtype = NUMPY_TO_TRIPY[dtype]
6161
assert tensor.dtype == tp_dtype
6262

6363
def test_bool_tensor(self):

tripy/tests/helper.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -102,35 +102,22 @@ def check_mlir(mlir, expected):
102102

103103

104104
# Supported NumPy data types
105-
NUMPY_TYPES = [
106-
np.int8,
105+
NUMPY_TO_TRIPY = {
106+
bool: tp.bool,
107+
np.int8: tp.int8,
108+
np.int32: tp.int32,
109+
np.int64: tp.int64,
110+
np.float16: tp.float16,
111+
np.float32: tp.float32,
107112
# np.int16, # TODO(#247): Add support for int16
108-
np.int32,
109-
np.int64,
110113
# np.uint8, # TODO(#247): Add support for uint8
111114
# np.uint16, # TODO(#190): Add support for unsupported MLIR-TensorRT types.
112115
# np.uint32, # TODO(#190): Add support for unsupported MLIR-TensorRT types.
113116
# np.uint64, # TODO(#190): Add support for unsupported MLIR-TensorRT types.
114-
np.float16,
115-
np.float32,
116117
# np.float64, # TODO(#247): Add support for float64
117-
]
118-
119-
120-
def np_to_tripy_dtype(dtype):
121-
return {
122-
bool: tp.bool,
123-
np.int8: tp.int8,
124-
np.int32: tp.int32,
125-
np.int64: tp.int64,
126-
np.float16: tp.float16,
127-
np.float32: tp.float32,
128-
}[dtype]
129-
118+
}
130119

131-
def torch_type_supported(data: np.ndarray):
132-
unsupported_dtypes = [np.int16, np.uint16, np.uint32, np.uint64]
133-
return data.dtype not in unsupported_dtypes
120+
TRIPY_TO_NUMPY = {v: k for k, v in NUMPY_TO_TRIPY.items()}
134121

135122

136123
TORCH_DTYPES = {

tripy/tests/integration/test_cast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import tripy as tp
2323
from tests.conftest import skip_if_older_than_sm89
24-
from tests.helper import np_to_tripy_dtype
24+
from tests.helper import NUMPY_TO_TRIPY
2525

2626

2727
class TestCast:
@@ -50,8 +50,8 @@ class TestCast:
5050
],
5151
)
5252
def test_cast(self, input_dtype, target_dtype):
53-
tp_input_dtype = np_to_tripy_dtype(input_dtype)
54-
tp_target_dtype = np_to_tripy_dtype(target_dtype)
53+
tp_input_dtype = NUMPY_TO_TRIPY[input_dtype]
54+
tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]
5555

5656
# TODO(#222): Integer casts with negative numbers fail in many cases
5757
input_tensor = tp.Tensor([0, 1, 2], dtype=tp_input_dtype)
@@ -71,7 +71,7 @@ def test_cast_quantized_dtypes_into_bool(self, source_dtype):
7171

7272
@pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int64, np.int8])
7373
def test_cast_from_bool(self, target_dtype):
74-
tp_target_dtype = np_to_tripy_dtype(target_dtype)
74+
tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]
7575

7676
# in principle, it is not important what *specific* values we convert to,
7777
# so long as false is mapped to 0 and true to nonzero

tripy/tests/performance/cases/linear_block.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
import pytest
15+
import numpy as np
1616
import torch
17+
from tests.helper import TRIPY_TO_NUMPY
1718
from tests.performance.conftest import PerfParam, perf_fixture
1819

1920
import tripy as tp
@@ -33,10 +34,8 @@ class LinearBlock(tp.Module):
3334
def __init__(self):
3435
self.layers = [tp.Linear(256, 256, bias=False, dtype=tripy_dtype) for _ in range(NUM_LAYERS)]
3536
for layer in self.layers:
36-
# Adjust the weights to prevent FP16 overflows.
37-
weight = torch.tile(
38-
torch.tensor([[-1, 1], [1, -1]], dtype=torch_dtype, device=torch.device("cuda")), (128, 128)
39-
)
37+
# Adjust the weights to prevent FP16 overflows:
38+
weight = np.tile(np.array([[-1, 1], [1, -1]], dtype=TRIPY_TO_NUMPY[tripy_dtype]), (128, 128))
4039
layer.weight = tp.Parameter(weight)
4140

4241
def __call__(self, input):

0 commit comments

Comments
 (0)