|
| 1 | +import av |
| 2 | +import torch |
| 3 | +from av.codec import CodecContext |
1 | 4 | from typing_extensions import override |
2 | 5 |
|
3 | 6 | from comfy_api.latest import IO, ComfyExtension, Input |
4 | 7 | from comfy_api_nodes.apis.bria import ( |
5 | 8 | BriaEditImageRequest, |
| 9 | + BriaImageEditResponse, |
6 | 10 | BriaRemoveBackgroundRequest, |
7 | 11 | BriaRemoveBackgroundResponse, |
8 | 12 | BriaRemoveVideoBackgroundRequest, |
9 | 13 | BriaRemoveVideoBackgroundResponse, |
10 | | - BriaImageEditResponse, |
11 | 14 | BriaStatusResponse, |
12 | 15 | InputModerationSettings, |
13 | 16 | ) |
@@ -316,13 +319,104 @@ async def execute( |
316 | 319 | return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) |
317 | 320 |
|
318 | 321 |
|
| 322 | +def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]: |
| 323 | + """Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask. |
| 324 | +
|
| 325 | + VP9 keeps its alpha in a side layer that PyAV's default vp9 decoder drops, so the frames |
| 326 | + are decoded with libvpx-vp9. Returns RGB images [B,H,W,3] in 0..1 and a mask [B,H,W] |
| 327 | + following the Load Image convention (1 = transparent) for compositing or Save WEBM. |
| 328 | + """ |
| 329 | + rgb_frames: list[torch.Tensor] = [] |
| 330 | + alpha_frames: list[torch.Tensor] = [] |
| 331 | + with av.open(video.get_stream_source(), mode="r") as container: |
| 332 | + stream = container.streams.video[0] |
| 333 | + decoder = CodecContext.create("libvpx-vp9", "r") if stream.codec_context.name == "vp9" else None |
| 334 | + for packet in container.demux(stream): |
| 335 | + for frame in (decoder.decode(packet) if decoder is not None else packet.decode()): |
| 336 | + rgba = torch.from_numpy(frame.to_ndarray(format="rgba")).float() / 255.0 |
| 337 | + rgb_frames.append(rgba[..., :3]) |
| 338 | + alpha_frames.append(rgba[..., 3]) |
| 339 | + images = torch.stack(rgb_frames) if rgb_frames else torch.zeros(0, 0, 0, 3) |
| 340 | + mask = (1.0 - torch.stack(alpha_frames)) if alpha_frames else torch.zeros((images.shape[0], 64, 64)) |
| 341 | + return images, mask |
| 342 | + |
| 343 | + |
| 344 | +class BriaTransparentVideoBackground(IO.ComfyNode): |
| 345 | + |
| 346 | + @classmethod |
| 347 | + def define_schema(cls): |
| 348 | + return IO.Schema( |
| 349 | + node_id="BriaTransparentVideoBackground", |
| 350 | + display_name="Bria Remove Video Background (Transparent)", |
| 351 | + category="partner/video/Bria", |
| 352 | + description="Remove the background from a video using Bria and return the cut-out frames " |
| 353 | + "plus an alpha mask. Connect both to a compositing node, or feed them to Save WEBM to " |
| 354 | + "write a transparent video.", |
| 355 | + inputs=[ |
| 356 | + IO.Video.Input("video"), |
| 357 | + IO.Int.Input( |
| 358 | + "seed", |
| 359 | + default=0, |
| 360 | + min=0, |
| 361 | + max=2147483647, |
| 362 | + display_mode=IO.NumberDisplay.number, |
| 363 | + control_after_generate=True, |
| 364 | + tooltip="Seed controls whether the node should re-run; " |
| 365 | + "results are non-deterministic regardless of seed.", |
| 366 | + ), |
| 367 | + ], |
| 368 | + outputs=[ |
| 369 | + IO.Image.Output(display_name="images"), |
| 370 | + IO.Mask.Output(display_name="mask"), |
| 371 | + ], |
| 372 | + hidden=[ |
| 373 | + IO.Hidden.auth_token_comfy_org, |
| 374 | + IO.Hidden.api_key_comfy_org, |
| 375 | + IO.Hidden.unique_id, |
| 376 | + ], |
| 377 | + is_api_node=True, |
| 378 | + price_badge=IO.PriceBadge( |
| 379 | + expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", |
| 380 | + ), |
| 381 | + ) |
| 382 | + |
| 383 | + @classmethod |
| 384 | + async def execute( |
| 385 | + cls, |
| 386 | + video: Input.Video, |
| 387 | + seed: int, |
| 388 | + ) -> IO.NodeOutput: |
| 389 | + validate_video_duration(video, max_duration=60.0) |
| 390 | + response = await sync_op( |
| 391 | + cls, |
| 392 | + ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"), |
| 393 | + data=BriaRemoveVideoBackgroundRequest( |
| 394 | + video=await upload_video_to_comfyapi(cls, video), |
| 395 | + background_color="Transparent", |
| 396 | + output_container_and_codec="webm_vp9", |
| 397 | + seed=seed, |
| 398 | + ), |
| 399 | + response_model=BriaStatusResponse, |
| 400 | + ) |
| 401 | + response = await poll_op( |
| 402 | + cls, |
| 403 | + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), |
| 404 | + status_extractor=lambda r: r.status, |
| 405 | + response_model=BriaRemoveVideoBackgroundResponse, |
| 406 | + ) |
| 407 | + video_out = await download_url_to_video_output(response.result.video_url) |
| 408 | + images, mask = _video_to_images_and_mask(video_out) |
| 409 | + return IO.NodeOutput(images, mask) |
| 410 | + |
| 411 | + |
319 | 412 | class BriaExtension(ComfyExtension): |
320 | 413 | @override |
321 | 414 | async def get_node_list(self) -> list[type[IO.ComfyNode]]: |
322 | 415 | return [ |
323 | 416 | BriaImageEditNode, |
324 | 417 | BriaRemoveImageBackground, |
325 | 418 | BriaRemoveVideoBackground, |
| 419 | + BriaTransparentVideoBackground, |
326 | 420 | ] |
327 | 421 |
|
328 | 422 |
|
|
0 commit comments