From 79f2105fedcda44b8867dc906de7c6a42b68dead Mon Sep 17 00:00:00 2001
From: maiguangcan
Date: Wed, 29 Sep 2021 18:32:51 +0800
Subject: [PATCH 1/2] revise the TRTModule to support engine with explicit dynamic batch

---
 torch2trt/torch2trt.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py
index 6a33a9ee..d90fca36 100644
--- a/torch2trt/torch2trt.py
+++ b/torch2trt/torch2trt.py
@@ -465,7 +465,11 @@ def forward(self, *inputs):
         for i, output_name in enumerate(self.output_names):
             idx = self.engine.get_binding_index(output_name)
             dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
-            shape = (batch_size,) + tuple(self.engine.get_binding_shape(idx))
+            oshape = tuple(self.engine.get_binding_shape(idx))
+            if oshape[0] == -1:
+                shape = (batch_size, ) + oshape[1:]
+            else:
+                shape = (batch_size,) + oshape
             device = torch_device_from_trt(self.engine.get_location(idx))
             output = torch.empty(size=shape, dtype=dtype, device=device)
             outputs[i] = output
@@ -474,6 +478,10 @@ def forward(self, *inputs):
         for i, input_name in enumerate(self.input_names):
             idx = self.engine.get_binding_index(input_name)
             bindings[idx] = inputs[i].contiguous().data_ptr()
+            ishape = tuple(self.engine.get_binding_shape(idx))
+            if ishape[0] == -1:
+                batch_size = 1
+

         self.context.execute_async(
             batch_size, bindings, torch.cuda.current_stream().cuda_stream

From 8a27555b5e6609856bf2b2d911464f1e2215dc0c Mon Sep 17 00:00:00 2001
From: maiguangcan
Date: Thu, 30 Sep 2021 10:48:19 +0800
Subject: [PATCH 2/2] dynamic batch setting

---
 torch2trt/torch2trt.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py
index d90fca36..1ee0e93f 100644
--- a/torch2trt/torch2trt.py
+++ b/torch2trt/torch2trt.py
@@ -480,7 +480,9 @@ def forward(self, *inputs):
             bindings[idx] = inputs[i].contiguous().data_ptr()
             ishape = tuple(self.engine.get_binding_shape(idx))
             if ishape[0] == -1:
+                tmp_ishape = (batch_size,) + ishape[1:]
                 batch_size = 1
+                self.context.set_binding_shape(idx, tmp_ishape)

         self.context.execute_async(