# TensorRT conversion code comes from the tutorial:
# https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/torch_compile_resnet_example.html
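#
# This script compiles a RAPIDFlow optical flow model with Torch-TensorRT
# (through the torch.compile backend), benchmarks the original and compiled
# models on a pair of sample images, and saves the predicted flow and its
# visualization to disk.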


import sys
import time
from argparse import ArgumentParser
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
import torch_tensorrt

this_dir = Path(__file__).parent.resolve()
sys.path.insert(0, str(this_dir.parent.parent.parent))

from ptlflow import get_model, load_checkpoint
from ptlflow.models.rapidflow.rapidflow import RAPIDFlow
from ptlflow.utils import flow_utils


def _init_parser() -> ArgumentParser:
    parser = ArgumentParser()
    parser.add_argument(
        "model",
        type=str,
        choices=(
            "rapidflow",
            "rapidflow_it1",
            "rapidflow_it2",
            "rapidflow_it3",
            "rapidflow_it6",
            "rapidflow_it12",
        ),
        help="Name of the model to use.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to the checkpoint to be loaded. It can also be one of the following names: {chairs, things, sintel, kitti}, in which case the respective pretrained checkpoint will be downloaded.",
    )
    parser.add_argument(
        "--image_paths",
        type=str,
        nargs=2,
        default=(
            str(this_dir / "image_samples" / "000000_10.png"),
            str(this_dir / "image_samples" / "000000_11.png"),
        ),
        help="Paths to two images to estimate the optical flow with the TensorRT model.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=".",
        help="Path to the directory where the predictions will be saved.",
    )
    parser.add_argument(
        "--input_size",
        type=int,
        nargs=2,
        default=(384, 1280),
        help="Size of the input images as (height, width).",
    )
    return parser


def compile_engine_and_infer(args):
    # Initialize model with half precision and sample inputs
    model = load_model(args).half().eval().to("cuda")
    images = [
        torch.from_numpy(load_images(args.image_paths, args.input_size))
        .half()
        .to("cuda")
    ]
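
    # Benchmark the original (uncompiled) model. The first iteration is
    # treated as a warm-up run and is excluded from the accumulated time.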
    num_tries = 11
    total_time_orig = 0.0
    for i in range(num_tries):
        torch.cuda.synchronize()
        start = time.perf_counter()
        model(images[0])
        torch.cuda.synchronize()
        end = time.perf_counter()
        if i > 0:
            total_time_orig += end - start

    # Enabled precision for TensorRT optimization
    enabled_precisions = {torch.half}

    # Whether to print verbose logs
    debug = True

    # Workspace size for TensorRT
    workspace_size = 20 << 30

    # Minimum number of operators per TRT engine block
    # (a lower value allows more graph segmentation)
    min_block_size = 7

    # Operations to run in Torch, regardless of converter support
    torch_executed_ops = {}

    # Build and compile the model with torch.compile, using Torch-TensorRT backend
    compiled_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=images,
        enabled_precisions=enabled_precisions,
        debug=debug,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
    )
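
    # Note: with ir="torch_compile", compilation is deferred, so the TensorRT
    # engines are built during the first forward call below. That iteration
    # therefore also pays the compilation cost and is excluded from the average.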
    total_time_optimized = 0.0
    for i in range(num_tries):
        torch.cuda.synchronize()
        start = time.perf_counter()
        flow_pred = compiled_model(*images)
        torch.cuda.synchronize()
        end = time.perf_counter()
        if i > 0:
            total_time_optimized += end - start

    try:
        torch_tensorrt.save(compiled_model, f"{args.model}.tc", inputs=images)
        print(f"Saved compiled model to {args.model}.tc")
        compiled_model = torch_tensorrt.load(f"{args.model}.tc")
        print(f"Loaded compiled model from {args.model}.tc")
    except Exception as e:
        print("WARNING: The compiled model could not be saved/reloaded due to the error:")
        print(e)

    print(f"Model: {args.model}. Average time of {num_tries - 1} runs:")
    print(f"Time (original): {(1000 * total_time_orig / (num_tries - 1)):.2f} ms.")
    print(f"Time (compiled): {(1000 * total_time_optimized / (num_tries - 1)):.2f} ms.")

    flow_pred_npy = flow_pred[0].permute(1, 2, 0).detach().cpu().numpy()

    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    flo_output_path = output_dir / "flow_pred.flo"
    flow_utils.flow_write(flo_output_path, flow_pred_npy)
    print(f"Saved flow prediction to: {flo_output_path}")

    viz_output_path = output_dir / "flow_pred_viz.png"
    flow_viz = flow_utils.flow_to_rgb(flow_pred_npy)
    cv.imwrite(str(viz_output_path), cv.cvtColor(flow_viz, cv.COLOR_RGB2BGR))
    print(f"Saved flow prediction visualization to: {viz_output_path}")

    # Finally, we use Torch utilities to clean up the workspace
    torch._dynamo.reset()
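

# Load the two input images, resize them to input_size (height, width), and
# stack them into a float32 array of shape [1, 2, 3, H, W] scaled to [0, 1].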
def load_images(image_paths, input_size):
    images = [cv.imread(p) for p in image_paths]
    # cv.resize expects (width, height), while input_size is given as (height, width)
    images = [cv.resize(im, input_size[::-1]) for im in images]
    images = np.stack(images)
    images = images.transpose(0, 3, 1, 2)[None]
    images = images.astype(np.float32) / 255.0
    return images


def load_model(args):
    model = get_model(args.model, args=args)
    ckpt = load_checkpoint(args.checkpoint, RAPIDFlow, "rapidflow")
    state_dict = fuse_checkpoint_next1d_layers(ckpt["state_dict"])
    model.load_state_dict(state_dict, strict=True)
    return model
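

# Merge each pair of horizontal/vertical 1D convolution weights (weight_h /
# weight_v) from the checkpoint into a single 2D convolution weight, presumably
# matching the fused layers built when fuse_next1d_weights is enabled below.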
def fuse_checkpoint_next1d_layers(state_dict):
    fused_sd = {}
    hv_pairs = {}
    for name, param in state_dict.items():
        if name.endswith("weight_h") or name.endswith("weight_v"):
            name_prefix = name[: -(len("weight_h") + 1)]
            orientation = name[-1]
            if name_prefix not in hv_pairs:
                hv_pairs[name_prefix] = {}
            hv_pairs[name_prefix][orientation] = param
        else:
            fused_sd[name] = param

    for name_prefix, param_pairs in hv_pairs.items():
        weight = torch.einsum("cijk,cimj->cimk", param_pairs["h"], param_pairs["v"])
        fused_sd[f"{name_prefix}.weight"] = weight
    return fused_sd


if __name__ == "__main__":
    parser = _init_parser()
    parser = RAPIDFlow.add_model_specific_args(parser)
    args = parser.parse_args()

    # Override model-specific options used by this TensorRT example
    args.corr_mode = "allpairs"
    args.fuse_next1d_weights = True
    args.simple_io = True

    compile_engine_and_infer(args)