Skip to content

Commit 4f99ce0

Browse files
authored
[Partner Nodes] fix SaveWEBM node to save alpha channel; add BriaTransparentVideoBackground Partner node (Comfy-Org#14257)
1 parent 7758b9b commit 4f99ce0

2 files changed

Lines changed: 105 additions & 4 deletions

File tree

comfy_api_nodes/nodes_bria.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
import av
2+
import torch
3+
from av.codec import CodecContext
14
from typing_extensions import override
25

36
from comfy_api.latest import IO, ComfyExtension, Input
47
from comfy_api_nodes.apis.bria import (
58
BriaEditImageRequest,
9+
BriaImageEditResponse,
610
BriaRemoveBackgroundRequest,
711
BriaRemoveBackgroundResponse,
812
BriaRemoveVideoBackgroundRequest,
913
BriaRemoveVideoBackgroundResponse,
10-
BriaImageEditResponse,
1114
BriaStatusResponse,
1215
InputModerationSettings,
1316
)
@@ -316,13 +319,104 @@ async def execute(
316319
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
317320

318321

322+
def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]:
323+
"""Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask.
324+
325+
VP9 keeps its alpha in a side layer that PyAV's default vp9 decoder drops, so the frames
326+
are decoded with libvpx-vp9. Returns RGB images [B,H,W,3] in 0..1 and a mask [B,H,W]
327+
following the Load Image convention (1 = transparent) for compositing or Save WEBM.
328+
"""
329+
rgb_frames: list[torch.Tensor] = []
330+
alpha_frames: list[torch.Tensor] = []
331+
with av.open(video.get_stream_source(), mode="r") as container:
332+
stream = container.streams.video[0]
333+
decoder = CodecContext.create("libvpx-vp9", "r") if stream.codec_context.name == "vp9" else None
334+
for packet in container.demux(stream):
335+
for frame in (decoder.decode(packet) if decoder is not None else packet.decode()):
336+
rgba = torch.from_numpy(frame.to_ndarray(format="rgba")).float() / 255.0
337+
rgb_frames.append(rgba[..., :3])
338+
alpha_frames.append(rgba[..., 3])
339+
images = torch.stack(rgb_frames) if rgb_frames else torch.zeros(0, 0, 0, 3)
340+
mask = (1.0 - torch.stack(alpha_frames)) if alpha_frames else torch.zeros((images.shape[0], 64, 64))
341+
return images, mask
342+
343+
344+
class BriaTransparentVideoBackground(IO.ComfyNode):
345+
346+
@classmethod
347+
def define_schema(cls):
348+
return IO.Schema(
349+
node_id="BriaTransparentVideoBackground",
350+
display_name="Bria Remove Video Background (Transparent)",
351+
category="partner/video/Bria",
352+
description="Remove the background from a video using Bria and return the cut-out frames "
353+
"plus an alpha mask. Connect both to a compositing node, or feed them to Save WEBM to "
354+
"write a transparent video.",
355+
inputs=[
356+
IO.Video.Input("video"),
357+
IO.Int.Input(
358+
"seed",
359+
default=0,
360+
min=0,
361+
max=2147483647,
362+
display_mode=IO.NumberDisplay.number,
363+
control_after_generate=True,
364+
tooltip="Seed controls whether the node should re-run; "
365+
"results are non-deterministic regardless of seed.",
366+
),
367+
],
368+
outputs=[
369+
IO.Image.Output(display_name="images"),
370+
IO.Mask.Output(display_name="mask"),
371+
],
372+
hidden=[
373+
IO.Hidden.auth_token_comfy_org,
374+
IO.Hidden.api_key_comfy_org,
375+
IO.Hidden.unique_id,
376+
],
377+
is_api_node=True,
378+
price_badge=IO.PriceBadge(
379+
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
380+
),
381+
)
382+
383+
@classmethod
384+
async def execute(
385+
cls,
386+
video: Input.Video,
387+
seed: int,
388+
) -> IO.NodeOutput:
389+
validate_video_duration(video, max_duration=60.0)
390+
response = await sync_op(
391+
cls,
392+
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
393+
data=BriaRemoveVideoBackgroundRequest(
394+
video=await upload_video_to_comfyapi(cls, video),
395+
background_color="Transparent",
396+
output_container_and_codec="webm_vp9",
397+
seed=seed,
398+
),
399+
response_model=BriaStatusResponse,
400+
)
401+
response = await poll_op(
402+
cls,
403+
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
404+
status_extractor=lambda r: r.status,
405+
response_model=BriaRemoveVideoBackgroundResponse,
406+
)
407+
video_out = await download_url_to_video_output(response.result.video_url)
408+
images, mask = _video_to_images_and_mask(video_out)
409+
return IO.NodeOutput(images, mask)
410+
411+
319412
class BriaExtension(ComfyExtension):
320413
@override
321414
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
322415
return [
323416
BriaImageEditNode,
324417
BriaRemoveImageBackground,
325418
BriaRemoveVideoBackground,
419+
BriaTransparentVideoBackground,
326420
]
327421

328422

comfy_extras/nodes_video.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def define_schema(cls):
1919
category="video",
2020
is_experimental=True,
2121
inputs=[
22-
io.Image.Input("images"),
22+
io.Image.Input("images", tooltip="RGBA images are saved with their alpha channel as transparency (vp9 codec only)."),
2323
io.String.Input("filename_prefix", default="ComfyUI"),
2424
io.Combo.Input("codec", options=["vp9", "av1"]),
2525
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
@@ -45,18 +45,25 @@ def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
4545
for x in cls.hidden.extra_pnginfo:
4646
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
4747

48+
# Save transparency when the images carry an alpha channel (RGBA) and the codec supports it.
49+
# vp9 -> yuva420p; other codecs have no usable alpha path, so the alpha is ignored.
50+
save_alpha = images.shape[-1] == 4 and codec == "vp9"
51+
4852
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
4953
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
5054
stream.width = images.shape[-2]
5155
stream.height = images.shape[-3]
52-
stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
56+
stream.pix_fmt = "yuva420p" if save_alpha else ("yuv420p10le" if codec == "av1" else "yuv420p")
5357
stream.bit_rate = 0
5458
stream.options = {'crf': str(crf)}
5559
if codec == "av1":
5660
stream.options["preset"] = "6"
5761

5862
for frame in images:
59-
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
63+
if save_alpha:
64+
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :4] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgba")
65+
else:
66+
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
6067
for packet in stream.encode(frame):
6168
container.mux(packet)
6269
container.mux(stream.encode())

0 commit comments

Comments
 (0)