opti-qwen2-vl-pre-process #1094

Conversation
Summary of Changes
Hello @SangChengC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly optimizes the image preprocessing pipeline for Qwen2-VL models by refactoring the image processing logic to leverage BaseImageProcessorFast and torch-based operations.
Code Review
This pull request focuses on optimizing the vision pre-processing pipeline for Qwen2-VL models. The changes include refactoring the image processing logic to be more efficient by leveraging BaseImageProcessorFast and torch-based operations, which should yield significant performance improvements. Additionally, TCP_NODELAY is enabled for rpyc connections to reduce network latency. My review includes suggestions to improve code correctness, reusability, and maintainability, such as correcting a type hint, removing a hardcoded device to prevent redundant operations, and flagging the use of private library attributes which could be a future maintenance risk.
    min_pixels=self.min_pixels,
    max_pixels=self.max_pixels,
image_arr = np.asarray(image, dtype=np.uint8)
image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
The preprocess method hardcodes moving the image tensor to "cuda". This makes the function less flexible and tightly coupled to a specific device. Furthermore, the calling code in qwen2_5_visual.py and qwen2_visual.py also moves the tensors to "cuda", leading to redundant device transfers. It's better to keep this preprocessing function device-agnostic and let the caller manage device placement.
Suggested change:
-image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
+image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous()
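For illustration, a self-contained sketch of the caller-managed device placement the comment recommends; `preprocess_cpu` and the dummy image below are hypothetical stand-ins, not code from this PR:

```python
import torch

# Stand-in for the device-agnostic preprocessing suggested above: it returns a
# CPU tensor and leaves device placement entirely to the caller. (Illustrative only.)
def preprocess_cpu(image_arr: torch.Tensor) -> torch.Tensor:
    return image_arr.permute(2, 0, 1).contiguous()  # HWC -> CHW, stays on CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_arr = torch.zeros(224, 224, 3, dtype=torch.uint8)  # dummy HWC image
image_data = preprocess_cpu(image_arr)
# Single, explicit transfer decided by the caller, not by the preprocessing code.
image_data = image_data.to(device, non_blocking=True)
```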
 def resize_image(
     image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
-) -> tuple[Image.Image, int, int]:
+) -> tuple[Image.Image]:
The return type hint for resize_image is tuple[Image.Image], but the function returns a single Image.Image object. This mismatch can be confusing and might be flagged by static type checkers. It should be corrected to Image.Image to accurately reflect the return value.
Suggested change:
 def resize_image(
     image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
-) -> tuple[Image.Image]:
+) -> Image.Image:
 def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
     if self.do_convert_rgb:
         image = convert_to_rgb(image)
-    image = to_numpy_array(image)
-    input_data_format = infer_channel_dimension_format(image)
-    height, width = get_image_size(image, channel_dim=input_data_format)
-    resized_height, resized_width = height, width
-    if self.do_resize:
-        resized_height, resized_width = smart_resize(
-            height,
-            width,
-            factor=self.patch_size * self.merge_size,
-            min_pixels=self.min_pixels,
-            max_pixels=self.max_pixels,
-        )
+    image_arr = np.asarray(image, dtype=np.uint8)
+    image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
+    grouped_images, grouped_images_index = group_images_by_shape(
+        [image_data], disable_grouping=self.disable_grouping
+    )
+    resized_images_grouped = {}
+    for shape, stacked_images in grouped_images.items():
+        height, width = stacked_images.shape[-2:]
+        if self.do_resize:
+            resized_height, resized_width = smart_resize(
+                height,
+                width,
+                factor=self.patch_size * self.merge_size,
+                min_pixels=self.min_pixels,
+                max_pixels=self.max_pixels,
+            )
+            stacked_images = self.resize(
+                image=stacked_images,
+                size=SizeDict(height=resized_height, width=resized_width),
+                interpolation=self.interpolation,
+            )
+        resized_images_grouped[shape] = stacked_images
+    resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+    # Group images by size for further processing
+    # Needed in case do_resize is False, or resize returns images with different sizes
+    grouped_images, grouped_images_index = group_images_by_shape(
+        resized_images, disable_grouping=self.disable_grouping
+    )
+    processed_images_grouped = {}
+    processed_grids = {}
+    for shape, stacked_images in grouped_images.items():
+        resized_height, resized_width = stacked_images.shape[-2:]
+        # Fused rescale and normalize
+        patches = self.rescale_and_normalize(
+            stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std
+        )
-        image = resize(
-            image, size=(resized_height, resized_width), resample=self.resample, input_data_format=input_data_format
-        )
+        if patches.ndim == 4:
+            # add a temporal dimension if we have images
+            patches = patches.unsqueeze(1)
+        if patches.shape[1] % self.temporal_patch_size != 0:
+            repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
+            patches = torch.cat([patches, repeats], dim=1)
+        batch_size, grid_t, channel = patches.shape[:3]
+        grid_t = grid_t // self.temporal_patch_size
+        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+
+        patches = (
+            patches.view(
+                batch_size,
+                grid_t,
+                self.temporal_patch_size,
+                channel,
+                grid_h // self.merge_size,
+                self.merge_size,
+                self.patch_size,
+                grid_w // self.merge_size,
+                self.merge_size,
+                self.patch_size,
+            )
+            .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+            .contiguous()
+        )
-    if self.do_rescale:
-        image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
-    if self.do_normalize:
-        image = self.normalize(
-            image=image, mean=self.image_mean, std=self.image_std, input_data_format=input_data_format
-        )
+        # Reorder dimensions to group grid and patch information for subsequent flattening.
+        # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)
+        flatten_patches = patches.view(
+            batch_size,
+            grid_t * grid_h * grid_w,
+            channel * self.temporal_patch_size * self.patch_size * self.patch_size,
+        )
-    image = to_channel_dimension_format(image, self.data_format, input_channel_dim=input_data_format)
-    patches = np.array([image])
-    if patches.shape[0] == 1:
-        # why to copy image 2 times. use self.temporal_patch_size = 2.
-        patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
-    channel = patches.shape[1]
-    grid_t = patches.shape[0] // self.temporal_patch_size
-    grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
-    patches = patches.reshape(
-        grid_t,
-        self.temporal_patch_size,
-        channel,
-        grid_h // self.merge_size,
-        self.merge_size,
-        self.patch_size,
-        grid_w // self.merge_size,
-        self.merge_size,
-        self.patch_size,
-    )
-    patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
-    flatten_patches = patches.reshape(
-        grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
-    )
-    image_grid_thw = (grid_t, grid_h, grid_w)
-    pixel_values = torch.as_tensor(flatten_patches)
-    grid_thw = torch.as_tensor([image_grid_thw])
+        processed_images_grouped[shape] = flatten_patches
+        processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+    processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+    processed_grids = reorder_images(processed_grids, grouped_images_index)
+    pixel_values = torch.cat(processed_images, dim=0)  # (num_patches_total, C*T*ps*ps)
+    image_grid_thw = torch.as_tensor(processed_grids)
-    return pixel_values, grid_thw
+    return pixel_values, image_grid_thw
The preprocess method is implemented to handle a single image, but its implementation uses primitives like group_images_by_shape which are designed for batch processing. This is a missed opportunity for a significant performance improvement. By modifying the preprocess method to accept a list of images, you could leverage true batch processing and avoid calling it in a loop from the encode methods. While the calling encode methods are mostly outside the current diff, consider this for a future refactoring to further boost performance.
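A rough, self-contained sketch of the shape-grouping idea behind the suggested batch preprocessing; `group_by_shape` here is a simplified stand-in for the `group_images_by_shape` helper used in the diff, and the toy tensors are illustrative only:

```python
from collections import defaultdict
import torch

def group_by_shape(tensors):
    """Group CHW tensors by (H, W) so same-sized images can be stacked and
    processed together, instead of one preprocess call per image."""
    groups = defaultdict(list)
    for idx, t in enumerate(tensors):
        groups[tuple(t.shape[-2:])].append((idx, t))
    return groups

# Toy usage: the two 224x224 images land in one stacked batch of shape (2, 3, 224, 224).
images = [torch.rand(3, 224, 224), torch.rand(3, 224, 224), torch.rand(3, 448, 336)]
for shape, items in group_by_shape(images).items():
    batch = torch.stack([t for _, t in items])  # (N, C, H, W)
    print(shape, batch.shape)
```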
 self.enable_multimodal = args.enable_multimodal
 if self.enable_multimodal:
     self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
+    self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
Accessing private attributes like _channel and stream.sock to set socket options is risky. This implementation detail of rpyc might change in future versions, which would break this code. While this may be necessary if no public API is available on the client side, it's a maintainability concern. Consider adding a comment explaining why this is done and that it might need updates if rpyc is upgraded.
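One way to act on this suggestion, sketched under the assumption that rpyc still exposes no public client-side hook for setting socket options; the helper name `_enable_tcp_nodelay` is hypothetical, not part of this PR:

```python
import socket
import rpyc

def _enable_tcp_nodelay(conn: "rpyc.Connection") -> None:
    # NOTE: reaches into rpyc internals (_channel.stream.sock); there is no
    # public client-side API for socket options, so this may need updating
    # if rpyc changes its transport layout in a future release.
    try:
        sock = conn._channel.stream.sock
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    except AttributeError:
        # Fall back silently if the internal layout changed; latency is then
        # merely not optimized rather than the connection breaking.
        pass
```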
 self.zmq_recv_socket = context.socket(zmq.PULL)
 self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}")
 self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
+self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 weight_dir = kvargs["weight_dir"]
 self.vit_rank_id = kvargs["vit_rank_id"]
 self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
+self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 while repeat_count < 20:
     try:
         con = rpyc.connect("localhost", port, config={"allow_pickle": True})
+        con._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)