
Commit f5c7d6b

Merge pull request #66 from hmorimitsu/rapidflow
Adapt RAPIDFlow code to TensorRT and add simple test script
2 parents f42066c + 2831fc4 commit f5c7d6b

File tree

7 files changed: +280 −16 lines

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ jobs:
           mv ptlflow ptlflow_tmp
       - name: Test with pytest
         run: |
-          python -m pytest
+          python -m pytest tests/

.github/workflows/lightning.yml

Lines changed: 1 addition & 1 deletion
@@ -35,4 +35,4 @@ jobs:
       - name: Test with pytest
         run: |
           pip install pytest
-          python -m pytest
+          python -m pytest tests/

.github/workflows/python.yml

Lines changed: 1 addition & 1 deletion
@@ -32,4 +32,4 @@ jobs:
      - name: Test with pytest
        run: |
          pip install pytest
-          python -m pytest
+          python -m pytest tests/

.github/workflows/pytorch.yml

Lines changed: 1 addition & 1 deletion
@@ -32,4 +32,4 @@ jobs:
      - name: Test with pytest
        run: |
          pip install pytest
-          python -m pytest
+          python -m pytest tests/

ptlflow/models/rapidflow/README.md

Lines changed: 11 additions & 3 deletions
@@ -106,11 +106,19 @@ You can also provide your own images to test by providing an additional argument
 python onnx_infer.py rapidflow_it12.onnx --image_paths /path/to/first/image /path/to/second/image
 ```
 
-### ONNX example limitations
+## Compiling the model to TensorRT
 
-Directly converting the model to ONNX as shown in this example will work, but it is not optimal.
+The script [tensorrt_test.py](tensorrt_test.py) provides a simple example of how to compile RAPIDFlow models to TensorRT.
+Run it by typing:
+
+```bash
+python tensorrt_test.py rapidflow_it12 --checkpoint things
+```
+
+### ONNX and TensorRT example limitations
+
+Directly converting the model to ONNX or TensorRT as shown in these examples will work, but it is not optimal.
 To obtain the best conversion, it would be necessary to rewrite some parts of the code to remove conditions and operations that may change according to the input size.
-Also, ONNX conversion only supports `--corr_mode allpairs`, which is not suitable for large images.
+Also, these conversions only support `--corr_mode allpairs`, which is not suitable for large images.
 
 ## Code license
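As a usage note, the test script added below also saves the compiled module to disk. A minimal sketch of reloading it later for inference, assuming a Torch-TensorRT version that provides `torch_tensorrt.save`/`load` (as the script itself uses); the `.tc` filename and input shape follow the script's conventions:

```python
import torch
import torch_tensorrt

# Reload the module saved by tensorrt_test.py (name follows its convention).
trt_model = torch_tensorrt.load("rapidflow_it12.tc")
# Depending on the Torch-TensorRT version, load() may return an
# ExportedProgram; in that case, .module() yields a callable module.
if hasattr(trt_model, "module"):
    trt_model = trt_model.module()

# The script feeds a single (1, 2, 3, H, W) half-precision CUDA tensor
# holding the stacked image pair.
images = torch.rand(1, 2, 3, 384, 1280).half().to("cuda")
flow_pred = trt_model(images)
```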

ptlflow/models/rapidflow/rapidflow.py

Lines changed: 66 additions & 9 deletions
@@ -24,7 +24,7 @@
 import torch.nn.functional as F
 
 from ptlflow.utils.utils import forward_interpolate_batch
-from .pwc_modules import rescale_flow, upsample2d_as
+from .pwc_modules import rescale_flow
 from .update import UpdateBlock
 from .corr import get_corr_block
 from .local_timm.norm import LayerNorm2d
@@ -353,8 +353,11 @@ def forward(self, inputs):
             and "prev_flows" in inputs
             and inputs["prev_flows"] is not None
         ):
-            flow = upsample2d_as(
-                inputs["prev_flows"][:, 0], pass_pyramid1[0], mode="bilinear"
+            flow = F.interpolate(
+                inputs["prev_flows"][:, 0],
+                [pass_pyramid1[0].shape[-2], pass_pyramid1[0].shape[-1]],
+                mode="bilinear",
+                align_corners=True,
             )
             flow = rescale_flow(flow, width_im, height_im, to_local=True)
             flow = forward_interpolate_batch(flow)
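For reference, both calls perform the same bilinear resize; a minimal sketch, assuming `upsample2d_as` simply matched the spatial size of its second argument (as in the original pwc_modules helper; tensor shapes below are illustrative):

```python
import torch
import torch.nn.functional as F

def upsample2d_as(x, target, mode="bilinear"):
    # Old helper style: the output size is hidden inside another tensor.
    return F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=True)

x = torch.randn(1, 2, 12, 20)        # e.g. a coarse flow field
target = torch.randn(1, 32, 48, 80)  # e.g. a finer pyramid level

# New style used by this commit: the size expression is spelled out at the
# call site, which tends to trace and export to ONNX/TensorRT more cleanly.
a = upsample2d_as(x, target)
b = F.interpolate(
    x, [target.shape[-2], target.shape[-1]], mode="bilinear", align_corners=True
)
assert torch.equal(a, b)
```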
@@ -385,7 +388,12 @@ def forward(self, inputs):
             if net is None:
                 net = torch.tanh(net_tmp)
             else:
-                net = upsample2d_as(net, x1, mode="bilinear")
+                net = F.interpolate(
+                    net,
+                    [x1.shape[-2], x1.shape[-1]],
+                    mode="bilinear",
+                    align_corners=True,
+                )
 
             net_skip = torch.tanh(net_tmp)
             gate = torch.sigmoid(
@@ -395,7 +403,12 @@ def forward(self, inputs):
 
             if l > 0:
                 flow = rescale_flow(flow, x1.shape[-1], x1.shape[-2], to_local=False)
-                flow = upsample2d_as(flow, x1, mode="bilinear")
+                flow = F.interpolate(
+                    flow,
+                    [x1.shape[-2], x1.shape[-1]],
+                    mode="bilinear",
+                    align_corners=True,
+                )
 
             for k in range(iters_per_level[l]):
                 flow = flow.detach()
@@ -414,16 +427,60 @@ def forward(self, inputs):
                 out_flow = rescale_flow(flow, width_im, height_im, to_local=False)
                 if self.training:
                     if mask is not None and l == (output_level - start_level):
-                        out_flow = self.upsample_flow(out_flow, mask, pred_stride)
+                        if self.args.simple_io:
+                            # Code copied from self.upsample_flow.
+                            # For some reason, the TensorRT backend does not
+                            # compile when calling the function.
+                            N, _, H, W = out_flow.shape
+                            mask = mask.view(N, 1, 9, pred_stride, pred_stride, H, W)
+                            mask = torch.softmax(mask, dim=2)
+
+                            up_flow = F.unfold(flow, [3, 3], padding=1)
+                            up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+                            up_flow = torch.sum(mask * up_flow, dim=2)
+                            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+                            up_flow = up_flow.reshape(
+                                N, 2, pred_stride * H, pred_stride * W
+                            )
+                            out_flow = up_flow
+                        else:
+                            out_flow = self.upsample_flow(out_flow, mask, pred_stride)
                     else:
-                        out_flow = upsample2d_as(out_flow, x1_raw, mode="bilinear")
+                        out_flow = F.interpolate(
+                            out_flow,
+                            [x1_raw.shape[-2], x1_raw.shape[-1]],
+                            mode="bilinear",
+                            align_corners=True,
+                        )
                 elif l == (output_level - start_level) and k == (
                     iters_per_level[l] - 1
                 ):
                     if mask is not None:
-                        out_flow = self.upsample_flow(out_flow, mask, pred_stride)
+                        if self.args.simple_io:
+                            # Code copied from self.upsample_flow.
+                            # For some reason, the TensorRT backend does not
+                            # compile when calling the function.
+                            N, _, H, W = out_flow.shape
+                            mask = mask.view(N, 1, 9, pred_stride, pred_stride, H, W)
+                            mask = torch.softmax(mask, dim=2)
+
+                            up_flow = F.unfold(flow, [3, 3], padding=1)
+                            up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+                            up_flow = torch.sum(mask * up_flow, dim=2)
+                            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+                            up_flow = up_flow.reshape(
+                                N, 2, pred_stride * H, pred_stride * W
+                            )
+                            out_flow = up_flow
+                        else:
+                            out_flow = self.upsample_flow(out_flow, mask, pred_stride)
                     else:
-                        out_flow = upsample2d_as(out_flow, x1_raw, mode="bilinear")
+                        out_flow = F.interpolate(
+                            out_flow,
+                            [x1_raw.shape[-2], x1_raw.shape[-1]],
+                            mode="bilinear",
+                            align_corners=True,
+                        )
                 out_flow = self.postprocess_predictions(
                     out_flow, image_resizer, is_flow=True
                 )
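The `simple_io` branches above inline RAFT-style convex upsampling: each fine-resolution flow vector is a softmax-weighted combination of the 3x3 coarse neighborhood around it. A minimal sketch of the same operation as a standalone function (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def convex_upsample(flow: torch.Tensor, mask: torch.Tensor, stride: int) -> torch.Tensor:
    """Upsample flow by `stride`, mirroring the inlined block above."""
    N, _, H, W = flow.shape
    # One 9-way softmax per output sub-pixel position.
    mask = mask.view(N, 1, 9, stride, stride, H, W)
    mask = torch.softmax(mask, dim=2)
    # Gather the 3x3 neighborhood of every coarse flow vector.
    up_flow = F.unfold(flow, [3, 3], padding=1).view(N, 2, 9, 1, 1, H, W)
    # Convex combination, then rearrange sub-pixels into a dense grid.
    up_flow = torch.sum(mask * up_flow, dim=2)
    up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
    return up_flow.reshape(N, 2, stride * H, stride * W)

# Shape check with illustrative sizes (stride = 8).
flow = torch.randn(1, 2, 24, 32)
mask = torch.randn(1, 9 * 8 * 8, 24, 32)
assert convex_upsample(flow, mask, 8).shape == (1, 2, 192, 256)
```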
ptlflow/models/rapidflow/tensorrt_test.py

Lines changed: 199 additions & 0 deletions

# TensorRT conversion code comes from the tutorial:
# https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/torch_compile_resnet_example.html

import sys
from argparse import ArgumentParser
from pathlib import Path
import time

import cv2 as cv
import numpy as np
import torch
import torch_tensorrt

this_dir = Path(__file__).parent.resolve()
sys.path.insert(0, str(this_dir.parent.parent.parent))

from ptlflow import get_model, load_checkpoint
from ptlflow.models.rapidflow.rapidflow import RAPIDFlow
from ptlflow.utils import flow_utils


def _init_parser() -> ArgumentParser:
    parser = ArgumentParser()
    parser.add_argument(
        "model",
        type=str,
        choices=(
            "rapidflow",
            "rapidflow_it1",
            "rapidflow_it2",
            "rapidflow_it3",
            "rapidflow_it6",
            "rapidflow_it12",
        ),
        help="Name of the model to use.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to the checkpoint to be loaded. It can also be one of the following names: {chairs, things, sintel, kitti}, in which case the respective pretrained checkpoint will be downloaded.",
    )
    parser.add_argument(
        "--image_paths",
        type=str,
        nargs=2,
        default=(
            str(this_dir / "image_samples" / "000000_10.png"),
            str(this_dir / "image_samples" / "000000_11.png"),
        ),
        help="Paths to two images to estimate the optical flow with the TensorRT model.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=".",
        help="Path to the directory where the predictions will be saved.",
    )
    parser.add_argument(
        "--input_size",
        type=int,
        nargs=2,
        default=(384, 1280),
        help="Size of the input images.",
    )
    return parser


def compile_engine_and_infer(args):
    # Initialize the model with half precision and sample inputs
    model = load_model(args).half().eval().to("cuda")
    images = [
        torch.from_numpy(load_images(args.image_paths, args.input_size))
        .half()
        .to("cuda")
    ]

    # Time the original model (the first run is discarded as warmup)
    num_tries = 11
    total_time_orig = 0.0
    for i in range(num_tries):
        torch.cuda.synchronize()
        start = time.perf_counter()
        model(images[0])
        torch.cuda.synchronize()
        end = time.perf_counter()
        if i > 0:
            total_time_orig += end - start

    # Enabled precision for TensorRT optimization
    enabled_precisions = {torch.half}

    # Whether to print verbose logs
    debug = True

    # Workspace size for TensorRT
    workspace_size = 20 << 30

    # Maximum number of TRT Engines
    # (Lower value allows more graph segmentation)
    min_block_size = 7

    # Operations to run in Torch, regardless of converter support
    torch_executed_ops = {}

    # Build and compile the model with torch.compile, using the Torch-TensorRT backend
    compiled_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=images,
        enabled_precisions=enabled_precisions,
        debug=debug,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
    )

    # Time the compiled model (the first run triggers the compilation)
    total_time_optimized = 0.0
    for i in range(num_tries):
        torch.cuda.synchronize()
        start = time.perf_counter()
        flow_pred = compiled_model(*images)
        torch.cuda.synchronize()
        end = time.perf_counter()
        if i > 0:
            total_time_optimized += end - start

    try:
        torch_tensorrt.save(compiled_model, f"{args.model}.tc", inputs=images)
        print(f"Saved compiled model to {args.model}.tc")
        compiled_model = torch_tensorrt.load(f"{args.model}.tc")
        print(f"Loaded compiled model from {args.model}.tc")
    except Exception as e:
        print("WARNING: The compiled model was not saved due to the error:")
        print(e)

    print(f"Model: {args.model}. Average time of {num_tries - 1} runs:")
    print(f"Time (original): {(1000 * total_time_orig / (num_tries - 1)):.2f} ms.")
    print(f"Time (compiled): {(1000 * total_time_optimized / (num_tries - 1)):.2f} ms.")

    flow_pred_npy = flow_pred[0].permute(1, 2, 0).detach().cpu().numpy()

    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    flo_output_path = output_dir / "flow_pred.flo"
    flow_utils.flow_write(flo_output_path, flow_pred_npy)
    print(f"Saved flow prediction to: {flo_output_path}")

    viz_output_path = output_dir / "flow_pred_viz.png"
    flow_viz = flow_utils.flow_to_rgb(flow_pred_npy)
    cv.imwrite(str(viz_output_path), cv.cvtColor(flow_viz, cv.COLOR_RGB2BGR))
    print(f"Saved flow prediction visualization to: {viz_output_path}")

    # Finally, we use Torch utilities to clean up the workspace
    torch._dynamo.reset()


def load_images(image_paths, input_size):
    # Load both frames, resize them to the model input size, and stack them
    # into a single (1, 2, 3, H, W) float array in the [0, 1] range.
    images = [cv.imread(p) for p in image_paths]
    images = [cv.resize(im, input_size[::-1]) for im in images]
    images = np.stack(images)
    images = images.transpose(0, 3, 1, 2)[None]
    images = images.astype(np.float32) / 255.0
    return images


def load_model(args):
    model = get_model(args.model, args=args)
    ckpt = load_checkpoint(args.checkpoint, RAPIDFlow, "rapidflow")
    state_dict = fuse_checkpoint_next1d_layers(ckpt["state_dict"])
    model.load_state_dict(state_dict, strict=True)
    return model


def fuse_checkpoint_next1d_layers(state_dict):
    # Fuse each horizontal/vertical pair of 1D weights from the checkpoint
    # into a single 2D convolution weight.
    fused_sd = {}
    hv_pairs = {}
    for name, param in state_dict.items():
        if name.endswith("weight_h") or name.endswith("weight_v"):
            name_prefix = name[: -(len("weight_h") + 1)]
            orientation = name[-1]
            if name_prefix not in hv_pairs:
                hv_pairs[name_prefix] = {}
            hv_pairs[name_prefix][orientation] = param
        else:
            fused_sd[name] = param

    for name_prefix, param_pairs in hv_pairs.items():
        weight = torch.einsum("cijk,cimj->cimk", param_pairs["h"], param_pairs["v"])
        fused_sd[f"{name_prefix}.weight"] = weight
    return fused_sd


if __name__ == "__main__":
    parser = _init_parser()
    parser = RAPIDFlow.add_model_specific_args(parser)
    args = parser.parse_args()
    args.corr_mode = "allpairs"
    args.fuse_next1d_weights = True
    args.simple_io = True

    compile_engine_and_infer(args)
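In `fuse_checkpoint_next1d_layers`, since the contracted einsum index has size one, the fusion amounts to an outer product per (output, input) channel pair. A small sketch checking that, assuming the checkpoint stores a 1 x k horizontal factor and a k x 1 vertical factor per pair (shapes are illustrative):

```python
import torch

# Illustrative shapes: one k-wide horizontal factor and one k-tall vertical
# factor per (out_channel, in_channel) pair.
C_out, C_in, k = 4, 3, 7
weight_h = torch.randn(C_out, C_in, 1, k)  # 1 x k kernel factor
weight_v = torch.randn(C_out, C_in, k, 1)  # k x 1 kernel factor

# Same einsum as in the script: contracts over j, which has size 1 here.
fused = torch.einsum("cijk,cimj->cimk", weight_h, weight_v)  # (C_out, C_in, k, k)

# With a size-1 contracted index, the einsum reduces to an outer product:
# fused[c, i, m, n] == weight_v[c, i, m, 0] * weight_h[c, i, 0, n]
expected = weight_v[:, :, :, 0:1] * weight_h[:, :, 0:1, :]
assert torch.allclose(fused, expected)
```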
