Skip to content

Commit c96fcdd

Browse files
authored
Radiance: support variant with nonzero txt_ids (Comfy-Org#14206)
1 parent e88a81d commit c96fcdd

3 files changed

Lines changed: 22 additions & 0 deletions

File tree

comfy/ldm/chroma_radiance/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
3838
# None means use the same dtype as the model.
3939
nerf_embedder_dtype: Optional[torch.dtype]
4040
use_x0: bool
41+
# Use sequential txt_ids instead of zeros
42+
use_sequential_txt_ids: bool
4143

4244
class ChromaRadiance(Chroma):
4345
"""
@@ -162,6 +164,9 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
162164
if params.use_x0:
163165
self.register_buffer("__x0__", torch.tensor([]))
164166

167+
if params.use_sequential_txt_ids:
168+
self.register_buffer("__sequential__", torch.tensor([]))
169+
165170
@property
166171
def _nerf_final_layer(self) -> nn.Module:
167172
if self.params.nerf_final_head_type == "linear":
@@ -313,6 +318,9 @@ def _forward(
313318
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
314319
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
315320
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
321+
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
322+
if params.use_sequential_txt_ids:
323+
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
316324

317325
img_out = self.forward_orig(
318326
img,

comfy/model_detection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
313313
dit_config["use_x0"] = True
314314
else:
315315
dit_config["use_x0"] = False
316+
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
317+
dit_config["use_sequential_txt_ids"] = True
318+
else:
319+
dit_config["use_sequential_txt_ids"] = False
316320
else:
317321
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
318322
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

comfy_extras/nodes_chroma_radiance.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def define_schema(cls) -> io.Schema:
6565
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
6666
advanced=True,
6767
),
68+
io.Boolean.Input(
69+
id="force_sequential_txt_ids",
70+
default=False,
71+
tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.",
72+
advanced=True,
73+
),
6874
],
6975
outputs=[io.Model.Output()],
7076
)
@@ -78,11 +84,15 @@ def execute(
7884
start_sigma: float,
7985
end_sigma: float,
8086
nerf_tile_size: int,
87+
force_sequential_txt_ids: bool,
8188
) -> io.NodeOutput:
8289
radiance_options = {}
8390
if nerf_tile_size >= 0:
8491
radiance_options["nerf_tile_size"] = nerf_tile_size
8592

93+
if force_sequential_txt_ids:
94+
radiance_options["use_sequential_txt_ids"] = True
95+
8696
if not radiance_options:
8797
return io.NodeOutput(model)
8898

0 commit comments

Comments
 (0)