diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py
index 6a33a9ee..1ee0e93f 100644
--- a/torch2trt/torch2trt.py
+++ b/torch2trt/torch2trt.py
@@ -465,7 +465,15 @@ def forward(self, *inputs):
         for i, output_name in enumerate(self.output_names):
             idx = self.engine.get_binding_index(output_name)
             dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
-            shape = (batch_size,) + tuple(self.engine.get_binding_shape(idx))
+            oshape = tuple(self.engine.get_binding_shape(idx))
+            # A leading -1 is treated as a dynamic batch dimension already
+            # present in the binding shape: substitute batch_size for it
+            # instead of prepending. The emptiness check keeps an empty
+            # binding shape from raising IndexError.
+            if oshape and oshape[0] == -1:
+                shape = (batch_size,) + oshape[1:]
+            else:
+                shape = (batch_size,) + oshape
             device = torch_device_from_trt(self.engine.get_location(idx))
             output = torch.empty(size=shape, dtype=dtype, device=device)
             outputs[i] = output
@@ -474,6 +482,16 @@ def forward(self, *inputs):
         for i, input_name in enumerate(self.input_names):
             idx = self.engine.get_binding_index(input_name)
             bindings[idx] = inputs[i].contiguous().data_ptr()
+            ishape = tuple(self.engine.get_binding_shape(idx))
+            if ishape and ishape[0] == -1:
+                # Read the dynamic dimension from the tensor itself: batch_size
+                # is reset below, so reusing it here would give every later
+                # dynamic input a leading dimension of 1.
+                tmp_ishape = (inputs[i].shape[0],) + ishape[1:]
+                # NOTE(review): presumably a dynamic-shape engine wants 1 passed
+                # to execute_async -- confirm against the TensorRT docs.
+                batch_size = 1
+                self.context.set_binding_shape(idx, tmp_ishape)
 
         self.context.execute_async(
             batch_size, bindings, torch.cuda.current_stream().cuda_stream