File tree Expand file tree Collapse file tree 4 files changed +10
-5
lines changed Expand file tree Collapse file tree 4 files changed +10
-5
lines changed Original file line number Diff line number Diff line change 2626class Cast (TraceOp ):
2727 dtype : "nvtripy.common.dtype"
2828
29- def infer_rank (self ):
30- # Casting does not change the shape, so we can simply copy it.
31- self .outputs [0 ].shape = self .inputs [0 ].shape
29+ infer_rank = op_utils .InferRankPolicies .same_shape_as_input ()
3230
3331 def infer_dtypes (self ):
3432 self .outputs [0 ].dtype = self .dtype
Original file line number Diff line number Diff line change @@ -29,7 +29,7 @@ class Dequantize(TraceOp):
2929 dtype : datatype .dtype
3030 dim : int
3131
32- infer_rank = op_utils .InferRankPolicies .same_as_input ()
32+ infer_rank = op_utils .InferRankPolicies .same_shape_as_input ()
3333
3434 def infer_dtypes (self ):
3535 self .outputs [0 ].dtype = self .dtype
Original file line number Diff line number Diff line change @@ -29,7 +29,7 @@ class Quantize(TraceOp):
2929 dtype : datatype .dtype
3030 dim : int
3131
32- infer_rank = op_utils .InferRankPolicies .same_as_input ()
32+ infer_rank = op_utils .InferRankPolicies .same_shape_as_input ()
3333
3434 def infer_dtypes (self ):
3535 self .outputs [0 ].dtype = self .dtype
Original file line number Diff line number Diff line change 2323
2424
2525class InferRankPolicies :
26+ # Indicates that the output has not only the same rank but also the same shape as the input.
27+ def same_shape_as_input (idx = 0 ):
28+ def impl (self ):
29+ self .outputs [0 ].shape = self .inputs [idx ].shape
30+
31+ return impl
32+
2633 def same_as_input (idx = 0 ):
2734 def impl (self ):
2835 self .outputs [0 ].rank = self .inputs [idx ].rank
You can’t perform that action at this time.
0 commit comments