Skip to content

Commit 9b1c63e

Browse files
Add SplitImageToTileList and ImageMergeTileList nodes. (Comfy-Org#12599)
With these you can split an image into tiles, do operations and then combine it back to a single image.
1 parent 7a7debc commit 9b1c63e

1 file changed

Lines changed: 169 additions & 0 deletions

File tree

comfy_extras/nodes_images.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import os
88
import re
9+
import math
910
import torch
1011
import comfy.utils
1112

@@ -682,6 +683,172 @@ def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
682683
upscale = execute # TODO: remove
683684

684685

686+
class SplitImageToTileList(IO.ComfyNode):
687+
@classmethod
688+
def define_schema(cls):
689+
return IO.Schema(
690+
node_id="SplitImageToTileList",
691+
category="image/batch",
692+
search_aliases=["split image", "tile image", "slice image"],
693+
display_name="Split Image into List of Tiles",
694+
description="Splits an image into a batched list of tiles with a specified overlap.",
695+
inputs=[
696+
IO.Image.Input("image"),
697+
IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION),
698+
IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION),
699+
IO.Int.Input("overlap", default=128, min=0, max=4096),
700+
],
701+
outputs=[
702+
IO.Image.Output(is_output_list=True),
703+
],
704+
)
705+
706+
@staticmethod
707+
def get_grid_coords(width, height, tile_width, tile_height, overlap):
708+
coords = []
709+
stride_x = max(1, tile_width - overlap)
710+
stride_y = max(1, tile_height - overlap)
711+
712+
y = 0
713+
while y < height:
714+
x = 0
715+
y_end = min(y + tile_height, height)
716+
y_start = max(0, y_end - tile_height)
717+
718+
while x < width:
719+
x_end = min(x + tile_width, width)
720+
x_start = max(0, x_end - tile_width)
721+
722+
coords.append((x_start, y_start, x_end, y_end))
723+
724+
if x_end >= width:
725+
break
726+
x += stride_x
727+
728+
if y_end >= height:
729+
break
730+
y += stride_y
731+
732+
return coords
733+
734+
@classmethod
735+
def execute(cls, image, tile_width, tile_height, overlap):
736+
b, h, w, c = image.shape
737+
coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap)
738+
739+
output_list = []
740+
for (x_start, y_start, x_end, y_end) in coords:
741+
tile = image[:, y_start:y_end, x_start:x_end, :]
742+
output_list.append(tile)
743+
744+
return IO.NodeOutput(output_list)
745+
746+
747+
class ImageMergeTileList(IO.ComfyNode):
748+
@classmethod
749+
def define_schema(cls):
750+
return IO.Schema(
751+
node_id="ImageMergeTileList",
752+
display_name="Merge List of Tiles to Image",
753+
category="image/batch",
754+
search_aliases=["split image", "tile image", "slice image"],
755+
is_input_list=True,
756+
inputs=[
757+
IO.Image.Input("image_list"),
758+
IO.Int.Input("final_width", default=1024, min=64, max=32768),
759+
IO.Int.Input("final_height", default=1024, min=64, max=32768),
760+
IO.Int.Input("overlap", default=128, min=0, max=4096),
761+
],
762+
outputs=[
763+
IO.Image.Output(is_output_list=False),
764+
],
765+
)
766+
767+
@staticmethod
768+
def get_grid_coords(width, height, tile_width, tile_height, overlap):
769+
coords = []
770+
stride_x = max(1, tile_width - overlap)
771+
stride_y = max(1, tile_height - overlap)
772+
773+
y = 0
774+
while y < height:
775+
x = 0
776+
y_end = min(y + tile_height, height)
777+
y_start = max(0, y_end - tile_height)
778+
779+
while x < width:
780+
x_end = min(x + tile_width, width)
781+
x_start = max(0, x_end - tile_width)
782+
783+
coords.append((x_start, y_start, x_end, y_end))
784+
785+
if x_end >= width:
786+
break
787+
x += stride_x
788+
789+
if y_end >= height:
790+
break
791+
y += stride_y
792+
793+
return coords
794+
795+
@classmethod
796+
def execute(cls, image_list, final_width, final_height, overlap):
797+
w = final_width[0]
798+
h = final_height[0]
799+
ovlp = overlap[0]
800+
feather_str = 1.0
801+
802+
first_tile = image_list[0]
803+
b, t_h, t_w, c = first_tile.shape
804+
device = first_tile.device
805+
dtype = first_tile.dtype
806+
807+
coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp)
808+
809+
canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
810+
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
811+
812+
if ovlp > 0:
813+
y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype))
814+
x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype))
815+
y_w = torch.clamp(y_w, min=1e-5)
816+
x_w = torch.clamp(x_w, min=1e-5)
817+
818+
sine_mask = (y_w.unsqueeze(1) * x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1)
819+
flat_mask = torch.ones_like(sine_mask)
820+
821+
weight_mask = torch.lerp(flat_mask, sine_mask, feather_str)
822+
else:
823+
weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype)
824+
825+
for i, (x_start, y_start, x_end, y_end) in enumerate(coords):
826+
if i >= len(image_list):
827+
break
828+
829+
tile = image_list[i]
830+
831+
region_h = y_end - y_start
832+
region_w = x_end - x_start
833+
834+
real_h = min(region_h, tile.shape[1])
835+
real_w = min(region_w, tile.shape[2])
836+
837+
y_end_actual = y_start + real_h
838+
x_end_actual = x_start + real_w
839+
840+
tile_crop = tile[:, :real_h, :real_w, :]
841+
mask_crop = weight_mask[:, :real_h, :real_w, :]
842+
843+
canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop
844+
weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop
845+
846+
weights[weights == 0] = 1.0
847+
merged_image = canvas / weights
848+
849+
return IO.NodeOutput(merged_image)
850+
851+
685852
class ImagesExtension(ComfyExtension):
686853
@override
687854
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -701,6 +868,8 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
701868
ImageRotate,
702869
ImageFlip,
703870
ImageScaleToMaxDimension,
871+
SplitImageToTileList,
872+
ImageMergeTileList,
704873
]
705874

706875

0 commit comments

Comments
 (0)