37 changes: 30 additions & 7 deletions chronoedit/_ext/imaginaire/utils/device.py
@@ -17,17 +17,26 @@
import math
import os

import pynvml
try:
import pynvml
PYNVML_AVAILABLE = True
except ImportError:
PYNVML_AVAILABLE = False

from loguru import logger as logging
import torch


def get_gpu_architecture():
"""
Retrieves the GPU architecture of the available GPUs.

Returns:
str: The GPU architecture, which can be "H100", "A100", or "Other".
str: The GPU architecture, which can be "H100", "A100", "L40S", "B200", "Other", or None (CPU mode).
"""
if not PYNVML_AVAILABLE or not torch.cuda.is_available():
return None

try:
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
@@ -47,10 +56,13 @@ def get_gpu_architecture():
return "L40S"
elif "B200" in model_name:
return "B200"
except pynvml.NVMLError as error:
except Exception as error:
print(f"Failed to get GPU info: {error}")
finally:
pynvml.nvmlShutdown()
try:
pynvml.nvmlShutdown()
except:
pass

# return "Other" incase of non hopper/ampere or error
return "Other"
@@ -65,13 +77,18 @@ class GPUArchitectureNotSupported(Exception):


def print_gpu_mem(str=None):
if not PYNVML_AVAILABLE or not torch.cuda.is_available():
if str:
logging.info(f"{str}: Running on CPU (no GPU memory info)")
return

try:
pynvml.nvmlInit()
meminfo = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(0))
logging.info(
f"{str}: {meminfo.used / 1024 / 1024}/{meminfo.total / 1024 / 1024}MiB used ({meminfo.free / 1024 / 1024}MiB free)"
)
except pynvml.NVMLError as error:
except Exception as error:
print(f"Failed to get GPU memory info: {error}")


@@ -86,19 +103,25 @@ def force_gc():


def gpu0_has_80gb_or_less():
if not PYNVML_AVAILABLE or not torch.cuda.is_available():
return True # Conservative default for CPU mode

try:
pynvml.nvmlInit()
meminfo = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(0))
return meminfo.total / 1024 / 1024 / 1024 <= 80
except pynvml.NVMLError as error:
except Exception as error:
print(f"Failed to get GPU memory info: {error}")
return True # Conservative default on error


class Device:
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64) if os.cpu_count() else 1 # type: ignore

def __init__(self, device_idx: int):
super().__init__()
if not PYNVML_AVAILABLE or not torch.cuda.is_available():
raise RuntimeError("Device class requires CUDA and pynvml to be available")
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)

def get_name(self) -> str:
19 changes: 19 additions & 0 deletions chronoedit/utils/__init__.py
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .device_utils import get_device, get_device_type, is_cuda_available

__all__ = ["get_device", "get_device_type", "is_cuda_available"]

168 changes: 168 additions & 0 deletions chronoedit/utils/device_utils.py
@@ -0,0 +1,168 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Device utility module for ChronoEdit.

Provides centralized device detection and management with automatic fallback
from CUDA to CPU when CUDA is not available.
"""

import warnings
from typing import Optional, Union

import torch


def is_cuda_available() -> bool:
"""
Check if CUDA is available.

Returns:
bool: True if CUDA is available, False otherwise.
"""
return torch.cuda.is_available()


def get_device_type(device: Optional[Union[str, torch.device]] = None) -> str:
"""
Get the device type string ('cuda' or 'cpu').

Args:
device: Device specification. Can be:
- None: Auto-detect (CUDA if available, else CPU)
- str: "cuda", "cpu", "cuda:0", etc.
- torch.device: Device object

Returns:
str: Device type ('cuda' or 'cpu')
"""
if device is None:
# Auto-detect
if is_cuda_available():
return "cuda"
else:
warnings.warn(
"CUDA is not available. Falling back to CPU. "
"Note: CPU inference will be significantly slower.",
UserWarning
)
return "cpu"

if isinstance(device, torch.device):
return device.type

if isinstance(device, str):
# Handle "cuda:0" -> "cuda", "cpu" -> "cpu"
device_lower = device.lower()
if device_lower.startswith("cuda"):
if not is_cuda_available():
warnings.warn(
f"CUDA device '{device}' requested but CUDA is not available. "
"Falling back to CPU.",
UserWarning
)
return "cpu"
return "cuda"
elif device_lower == "cpu":
return "cpu"
else:
raise ValueError(f"Unknown device type: {device}")

raise TypeError(f"Device must be None, str, or torch.device, got {type(device)}")


def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
"""
Get a torch.device object with automatic CUDA to CPU fallback.

Args:
device: Device specification. Can be:
- None: Auto-detect (CUDA if available, else CPU)
- str: "cuda", "cpu", "cuda:0", etc.
- torch.device: Device object (returned as-is after validation)

Returns:
torch.device: Device object ready for use

Examples:
>>> device = get_device() # Auto-detect
>>> device = get_device("cuda") # Use CUDA (falls back to CPU if unavailable)
>>> device = get_device("cpu") # Force CPU
>>> device = get_device("cuda:0") # Specific CUDA device
"""
if device is None:
# Auto-detect
device_type = get_device_type(None)
return torch.device(device_type)

if isinstance(device, torch.device):
# Validate and potentially fall back
if device.type == "cuda" and not is_cuda_available():
warnings.warn(
f"CUDA device '{device}' requested but CUDA is not available. "
"Falling back to CPU.",
UserWarning
)
return torch.device("cpu")
return device

if isinstance(device, str):
device_lower = device.lower()

if device_lower.startswith("cuda"):
if not is_cuda_available():
warnings.warn(
f"CUDA device '{device}' requested but CUDA is not available. "
"Falling back to CPU.",
UserWarning
)
return torch.device("cpu")
# Return the specific CUDA device (e.g., cuda:0, cuda:1)
return torch.device(device_lower)

elif device_lower == "cpu":
return torch.device("cpu")

else:
raise ValueError(f"Unknown device type: {device}")

raise TypeError(f"Device must be None, str, or torch.device, got {type(device)}")


def get_device_map(device: Optional[Union[str, torch.device]] = None) -> str:
"""
Get device_map string for HuggingFace models with automatic fallback.

Args:
device: Device specification (None for auto-detect)

Returns:
str: Device map string compatible with HuggingFace from_pretrained

Examples:
>>> device_map = get_device_map() # "cuda:0" or "cpu"
>>> device_map = get_device_map("cuda") # "cuda:0" or "cpu" (with fallback)
"""
device_obj = get_device(device)

if device_obj.type == "cuda":
# Use specific device index if available, otherwise default to 0
if device_obj.index is not None:
return f"cuda:{device_obj.index}"
return "cuda:0"
else:
return "cpu"

9 changes: 6 additions & 3 deletions chronoedit_diffusers/pipeline_chronoedit.py
@@ -659,7 +659,8 @@ def __call__(
self._num_timesteps = len(timesteps)

if offload_model:
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -692,7 +693,8 @@ def __call__(
)[0]

if offload_model:
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
@@ -727,7 +729,8 @@ def __call__(

if offload_model:
self.transformer.cpu()
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

self._current_timestep = None
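
The same is_available() guard recurs at each offload point; a hypothetical helper (not part of this PR) that captures the pattern:

# Hypothetical convenience helper, not part of the PR: only touch the
# CUDA allocator when CUDA is actually present.
import torch

def maybe_empty_cache() -> None:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()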

41 changes: 35 additions & 6 deletions scripts/prompt_enhancer.py
@@ -23,6 +23,12 @@
)
from qwen_vl_utils import process_vision_info

# Import device utilities
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from chronoedit.utils.device_utils import get_device_map


def parse_args():
parser = argparse.ArgumentParser(
@@ -59,12 +65,20 @@ def parse_args():
return parser.parse_args()


def pick_attn_implementation(prefer_flash: bool = True) -> str:
def pick_attn_implementation(prefer_flash: bool = True, device: str = "cuda") -> str:
"""
Decide the best attn_implementation based on environment.

Args:
prefer_flash: Whether to prefer flash attention if available
device: Target device ("cuda" or "cpu")

Returns one of: "flash_attention_2", "sdpa", "eager".
"""
# CPU only supports eager
if device == "cpu" or not torch.cuda.is_available():
return "eager"

# Try FlashAttention v2 first (needs SM80+ and the wheel to import)
if prefer_flash:
try:
@@ -84,18 +98,33 @@ def pick_attn_implementation(prefer_flash: bool = True) -> str:

# Fallback: eager (always works, slower)
return "eager"
def load_model(model_name):
"""Load the vision-language model and processor."""
def load_model(model_name, device=None):
"""
Load the vision-language model and processor.

Args:
model_name: Name/path of the model to load
device: Target device (None for auto-detect, "cuda", "cpu", etc.)

Returns:
Tuple of (model, processor)
"""
print(f"Loading model: {model_name}")

# Get device map
device_map = get_device_map(device)
device_type = "cpu" if device_map == "cpu" else "cuda"
print(f"Using device: {device_map}")

attn_impl = pick_attn_implementation(prefer_flash=True)
attn_impl = pick_attn_implementation(prefer_flash=True, device=device_type)
print(f"Using attention implementation: {attn_impl}")

if model_name == "Qwen/Qwen2.5-VL-7B-Instruct":
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_name,
dtype=torch.bfloat16,
attn_implementation=attn_impl,
device_map="cuda:0"
device_map=device_map
)
processor = AutoProcessor.from_pretrained(model_name)

@@ -104,7 +133,7 @@ def load_model(model_name):
model_name,
dtype=torch.bfloat16,
attn_implementation=attn_impl,
device_map="cuda:0"
device_map=device_map
)
processor = AutoProcessor.from_pretrained(model_name)
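
For reference, hedged call sites showing the new device-aware arguments (function names are from this script; the CPU-only path is an assumption about a machine without CUDA):

# Illustrative call sites only -- assumes a machine without CUDA.
attn_impl = pick_attn_implementation(prefer_flash=True, device="cpu")       # returns "eager"
model, processor = load_model("Qwen/Qwen2.5-VL-7B-Instruct", device="cpu")  # loads with device_map="cpu"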
