66import json
77import os
88import re
9+ import math
910import torch
1011import comfy .utils
1112
@@ -682,6 +683,172 @@ def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
682683 upscale = execute # TODO: remove
683684
684685
686+ class SplitImageToTileList (IO .ComfyNode ):
687+ @classmethod
688+ def define_schema (cls ):
689+ return IO .Schema (
690+ node_id = "SplitImageToTileList" ,
691+ category = "image/batch" ,
692+ search_aliases = ["split image" , "tile image" , "slice image" ],
693+ display_name = "Split Image into List of Tiles" ,
694+ description = "Splits an image into a batched list of tiles with a specified overlap." ,
695+ inputs = [
696+ IO .Image .Input ("image" ),
697+ IO .Int .Input ("tile_width" , default = 1024 , min = 64 , max = MAX_RESOLUTION ),
698+ IO .Int .Input ("tile_height" , default = 1024 , min = 64 , max = MAX_RESOLUTION ),
699+ IO .Int .Input ("overlap" , default = 128 , min = 0 , max = 4096 ),
700+ ],
701+ outputs = [
702+ IO .Image .Output (is_output_list = True ),
703+ ],
704+ )
705+
706+ @staticmethod
707+ def get_grid_coords (width , height , tile_width , tile_height , overlap ):
708+ coords = []
709+ stride_x = max (1 , tile_width - overlap )
710+ stride_y = max (1 , tile_height - overlap )
711+
712+ y = 0
713+ while y < height :
714+ x = 0
715+ y_end = min (y + tile_height , height )
716+ y_start = max (0 , y_end - tile_height )
717+
718+ while x < width :
719+ x_end = min (x + tile_width , width )
720+ x_start = max (0 , x_end - tile_width )
721+
722+ coords .append ((x_start , y_start , x_end , y_end ))
723+
724+ if x_end >= width :
725+ break
726+ x += stride_x
727+
728+ if y_end >= height :
729+ break
730+ y += stride_y
731+
732+ return coords
733+
734+ @classmethod
735+ def execute (cls , image , tile_width , tile_height , overlap ):
736+ b , h , w , c = image .shape
737+ coords = cls .get_grid_coords (w , h , tile_width , tile_height , overlap )
738+
739+ output_list = []
740+ for (x_start , y_start , x_end , y_end ) in coords :
741+ tile = image [:, y_start :y_end , x_start :x_end , :]
742+ output_list .append (tile )
743+
744+ return IO .NodeOutput (output_list )
745+
746+
747+ class ImageMergeTileList (IO .ComfyNode ):
748+ @classmethod
749+ def define_schema (cls ):
750+ return IO .Schema (
751+ node_id = "ImageMergeTileList" ,
752+ display_name = "Merge List of Tiles to Image" ,
753+ category = "image/batch" ,
754+ search_aliases = ["split image" , "tile image" , "slice image" ],
755+ is_input_list = True ,
756+ inputs = [
757+ IO .Image .Input ("image_list" ),
758+ IO .Int .Input ("final_width" , default = 1024 , min = 64 , max = 32768 ),
759+ IO .Int .Input ("final_height" , default = 1024 , min = 64 , max = 32768 ),
760+ IO .Int .Input ("overlap" , default = 128 , min = 0 , max = 4096 ),
761+ ],
762+ outputs = [
763+ IO .Image .Output (is_output_list = False ),
764+ ],
765+ )
766+
767+ @staticmethod
768+ def get_grid_coords (width , height , tile_width , tile_height , overlap ):
769+ coords = []
770+ stride_x = max (1 , tile_width - overlap )
771+ stride_y = max (1 , tile_height - overlap )
772+
773+ y = 0
774+ while y < height :
775+ x = 0
776+ y_end = min (y + tile_height , height )
777+ y_start = max (0 , y_end - tile_height )
778+
779+ while x < width :
780+ x_end = min (x + tile_width , width )
781+ x_start = max (0 , x_end - tile_width )
782+
783+ coords .append ((x_start , y_start , x_end , y_end ))
784+
785+ if x_end >= width :
786+ break
787+ x += stride_x
788+
789+ if y_end >= height :
790+ break
791+ y += stride_y
792+
793+ return coords
794+
795+ @classmethod
796+ def execute (cls , image_list , final_width , final_height , overlap ):
797+ w = final_width [0 ]
798+ h = final_height [0 ]
799+ ovlp = overlap [0 ]
800+ feather_str = 1.0
801+
802+ first_tile = image_list [0 ]
803+ b , t_h , t_w , c = first_tile .shape
804+ device = first_tile .device
805+ dtype = first_tile .dtype
806+
807+ coords = cls .get_grid_coords (w , h , t_w , t_h , ovlp )
808+
809+ canvas = torch .zeros ((b , h , w , c ), device = device , dtype = dtype )
810+ weights = torch .zeros ((b , h , w , 1 ), device = device , dtype = dtype )
811+
812+ if ovlp > 0 :
813+ y_w = torch .sin (math .pi * torch .linspace (0 , 1 , t_h , device = device , dtype = dtype ))
814+ x_w = torch .sin (math .pi * torch .linspace (0 , 1 , t_w , device = device , dtype = dtype ))
815+ y_w = torch .clamp (y_w , min = 1e-5 )
816+ x_w = torch .clamp (x_w , min = 1e-5 )
817+
818+ sine_mask = (y_w .unsqueeze (1 ) * x_w .unsqueeze (0 )).unsqueeze (0 ).unsqueeze (- 1 )
819+ flat_mask = torch .ones_like (sine_mask )
820+
821+ weight_mask = torch .lerp (flat_mask , sine_mask , feather_str )
822+ else :
823+ weight_mask = torch .ones ((1 , t_h , t_w , 1 ), device = device , dtype = dtype )
824+
825+ for i , (x_start , y_start , x_end , y_end ) in enumerate (coords ):
826+ if i >= len (image_list ):
827+ break
828+
829+ tile = image_list [i ]
830+
831+ region_h = y_end - y_start
832+ region_w = x_end - x_start
833+
834+ real_h = min (region_h , tile .shape [1 ])
835+ real_w = min (region_w , tile .shape [2 ])
836+
837+ y_end_actual = y_start + real_h
838+ x_end_actual = x_start + real_w
839+
840+ tile_crop = tile [:, :real_h , :real_w , :]
841+ mask_crop = weight_mask [:, :real_h , :real_w , :]
842+
843+ canvas [:, y_start :y_end_actual , x_start :x_end_actual , :] += tile_crop * mask_crop
844+ weights [:, y_start :y_end_actual , x_start :x_end_actual , :] += mask_crop
845+
846+ weights [weights == 0 ] = 1.0
847+ merged_image = canvas / weights
848+
849+ return IO .NodeOutput (merged_image )
850+
851+
685852class ImagesExtension (ComfyExtension ):
686853 @override
687854 async def get_node_list (self ) -> list [type [IO .ComfyNode ]]:
@@ -701,6 +868,8 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
701868 ImageRotate ,
702869 ImageFlip ,
703870 ImageScaleToMaxDimension ,
871+ SplitImageToTileList ,
872+ ImageMergeTileList ,
704873 ]
705874
706875
0 commit comments