@@ -459,27 +459,23 @@ def output_patch(h, hsp, transformer_options):
459459 total_images = image .shape [0 ]
460460 captured_feat = None
461461
462- model_h = int (head .heatmap_size [0 ]) * 4 # e.g. 192 * 4 = 768
463- model_w = int (head .heatmap_size [1 ]) * 4 # e.g. 256 * 4 = 1024
462+ model_w = int (head .heatmap_size [0 ]) * 4 # 192 * 4 = 768
463+ model_h = int (head .heatmap_size [1 ]) * 4 # 256 * 4 = 1024
464464
465465 def _resize_to_model (imgs ):
466- """Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left) ."""
466+ """Stretch BHWC images to (model_h, model_w); the model expects no aspect preservation. Returns (resized_bhwc, scale_x, scale_y)."""
467467 h , w = imgs .shape [- 3 ], imgs .shape [- 2 ]
468- scale = min (model_h / h , model_w / w )
469- sh , sw = int (round (h * scale )), int (round (w * scale ))
470- pt , pl = (model_h - sh ) // 2 , (model_w - sw ) // 2
468+ method = "area" if (model_h <= h and model_w <= w ) else "bilinear"
471469 chw = imgs .permute (0 , 3 , 1 , 2 ).float ()
472- scaled = comfy .utils .common_upscale (chw , sw , sh , upscale_method = "bilinear" , crop = "disabled" )
473- padded = torch .zeros (scaled .shape [0 ], scaled .shape [1 ], model_h , model_w , dtype = scaled .dtype , device = scaled .device )
474- padded [:, :, pt :pt + sh , pl :pl + sw ] = scaled
475- return padded .permute (0 , 2 , 3 , 1 ), scale , pt , pl
470+ scaled = comfy .utils .common_upscale (chw , model_w , model_h , upscale_method = method , crop = "disabled" )
471+ return scaled .permute (0 , 2 , 3 , 1 ), model_w / w , model_h / h
476472
477- def _remap_keypoints (kp , scale , pad_top , pad_left , offset_x = 0 , offset_y = 0 ):
473+ def _remap_keypoints (kp , scale_x , scale_y , offset_x = 0 , offset_y = 0 ):
478474 """Remap keypoints from model space back to original image space."""
479475 kp = kp .copy () if isinstance (kp , np .ndarray ) else np .array (kp , dtype = np .float32 )
480476 invalid = kp [..., 0 ] < 0
481- kp [..., 0 ] = ( kp [..., 0 ] - pad_left ) / scale + offset_x
482- kp [..., 1 ] = ( kp [..., 1 ] - pad_top ) / scale + offset_y
477+ kp [..., 0 ] = kp [..., 0 ] / scale_x + offset_x
478+ kp [..., 1 ] = kp [..., 1 ] / scale_y + offset_y
483479 kp [invalid ] = - 1
484480 return kp
485481
@@ -529,18 +525,18 @@ def _run_on_latent(latent_batch):
529525 continue
530526
531527 crop = img [:, y1 :y2 , x1 :x2 , :] # (1, crop_h, crop_w, C)
532- crop_resized , scale , pad_top , pad_left = _resize_to_model (crop )
528+ crop_resized , sx , sy = _resize_to_model (crop )
533529
534530 latent_crop = vae .encode (crop_resized )
535531 kp_batch , sc_batch = _run_on_latent (latent_crop )
536- kp = _remap_keypoints (kp_batch [0 ], scale , pad_top , pad_left , x1 , y1 )
532+ kp = _remap_keypoints (kp_batch [0 ], sx , sy , x1 , y1 )
537533 img_keypoints .append (kp )
538534 img_scores .append (sc_batch [0 ])
539535 else :
540- img_resized , scale , pad_top , pad_left = _resize_to_model (img )
536+ img_resized , sx , sy = _resize_to_model (img )
541537 latent_img = vae .encode (img_resized )
542538 kp_batch , sc_batch = _run_on_latent (latent_img )
543- img_keypoints .append (_remap_keypoints (kp_batch [0 ], scale , pad_top , pad_left ))
539+ img_keypoints .append (_remap_keypoints (kp_batch [0 ], sx , sy ))
544540 img_scores .append (sc_batch [0 ])
545541
546542 all_keypoints .append (img_keypoints )
@@ -549,12 +545,12 @@ def _run_on_latent(latent_batch):
549545
550546 else : # full-image mode, batched
551547 for batch_start in tqdm (range (0 , total_images , batch_size ), desc = "Extracting keypoints" ):
552- batch_resized , scale , pad_top , pad_left = _resize_to_model (image [batch_start :batch_start + batch_size ])
548+ batch_resized , sx , sy = _resize_to_model (image [batch_start :batch_start + batch_size ])
553549 latent_batch = vae .encode (batch_resized )
554550 kp_batch , sc_batch = _run_on_latent (latent_batch )
555551
556552 for kp , sc in zip (kp_batch , sc_batch ):
557- all_keypoints .append ([_remap_keypoints (kp , scale , pad_top , pad_left )])
553+ all_keypoints .append ([_remap_keypoints (kp , sx , sy )])
558554 all_scores .append ([sc ])
559555
560556 pbar .update (len (kp_batch ))
@@ -727,13 +723,13 @@ def execute(cls, image, bboxes, output_width, output_height, padding, keep_aspec
727723 scale = min (output_width / crop_w , output_height / crop_h )
728724 scaled_w = int (round (crop_w * scale ))
729725 scaled_h = int (round (crop_h * scale ))
730- scaled = comfy .utils .common_upscale (crop_chw , scaled_w , scaled_h , upscale_method = "bilinear " , crop = "disabled" )
726+ scaled = comfy .utils .common_upscale (crop_chw , scaled_w , scaled_h , upscale_method = "area " , crop = "disabled" )
731727 pad_left = (output_width - scaled_w ) // 2
732728 pad_top = (output_height - scaled_h ) // 2
733729 resized = torch .zeros (1 , num_ch , output_height , output_width , dtype = image .dtype , device = image .device )
734730 resized [:, :, pad_top :pad_top + scaled_h , pad_left :pad_left + scaled_w ] = scaled
735731 else : # "stretch"
736- resized = comfy .utils .common_upscale (crop_chw , output_width , output_height , upscale_method = "bilinear " , crop = "disabled" )
732+ resized = comfy .utils .common_upscale (crop_chw , output_width , output_height , upscale_method = "area " , crop = "disabled" )
737733 crops .append (resized )
738734
739735 if not crops :
0 commit comments