Skip to content

Commit 4168b81

Browse files
Adds shape propagation to Q/DQ ops
1 parent c31f48f commit 4168b81

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

tripy/nvtripy/trace/ops/cast.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
class Cast(TraceOp):
2727
dtype: "nvtripy.common.dtype"
2828

29-
def infer_rank(self):
30-
# Casting does not change the shape, so we can simply copy it.
31-
self.outputs[0].shape = self.inputs[0].shape
29+
infer_rank = op_utils.InferRankPolicies.same_shape_as_input()
3230

3331
def infer_dtypes(self):
3432
self.outputs[0].dtype = self.dtype

tripy/nvtripy/trace/ops/dequantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Dequantize(TraceOp):
2929
dtype: datatype.dtype
3030
dim: int
3131

32-
infer_rank = op_utils.InferRankPolicies.same_as_input()
32+
infer_rank = op_utils.InferRankPolicies.same_shape_as_input()
3333

3434
def infer_dtypes(self):
3535
self.outputs[0].dtype = self.dtype

tripy/nvtripy/trace/ops/quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Quantize(TraceOp):
2929
dtype: datatype.dtype
3030
dim: int
3131

32-
infer_rank = op_utils.InferRankPolicies.same_as_input()
32+
infer_rank = op_utils.InferRankPolicies.same_shape_as_input()
3333

3434
def infer_dtypes(self):
3535
self.outputs[0].dtype = self.dtype

tripy/nvtripy/trace/ops/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323

2424

2525
class InferRankPolicies:
26+
# Indicates that the output has not only the same rank but also the same shape as the input.
27+
def same_shape_as_input(idx=0):
28+
def impl(self):
29+
self.outputs[0].shape = self.inputs[idx].shape
30+
31+
return impl
32+
2633
def same_as_input(idx=0):
2734
def impl(self):
2835
self.outputs[0].rank = self.inputs[idx].rank

0 commit comments

Comments
 (0)