Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@ Start with the `Quick Demo` below to run the primary training command. The recom
</table>

```bash
# 0. If uv is not installed
# 0. Install uv if needed
# Linux / macOS:
curl -LsSf https://astral.sh/uv/install.sh | sh
#
# Windows:
# powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
# choco install make -y

# 1. Clone the repository
git clone https://github.com/unilabsim/UniLab.git
Expand All @@ -90,7 +95,7 @@ cd UniLab
# 2. Install dependencies
# Pick the setup command for your platform.

# Linux CUDA or macOS
# Linux CUDA, macOS, or Windows
make setup

# Linux AMD / ROCm
Expand All @@ -101,7 +106,7 @@ make setup

# Without shell completion setup:
# uv sync --extra mujoco --extra motrix
# If `make` is not installed:
# If `make` is not installed or unavailable:
# uv sync --extra mujoco --extra motrix && uv run --no-sync unilab-complete install

# 3. Pre-trained checkpoint playback (downloads from Hugging Face on first run)
Expand Down
9 changes: 7 additions & 2 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@

```bash
# 0. 如果还没有安装 uv
# Linux / macOS:
curl -LsSf https://astral.sh/uv/install.sh | sh
#
# Windows:
# powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
# choco install make -y

# 1. 克隆仓库
git clone https://github.com/unilabsim/UniLab.git
Expand All @@ -90,7 +95,7 @@ cd UniLab
# 2. 安装依赖
# 请按你的平台选择对应的安装命令。

# Linux CUDA 或 macOS
# Linux CUDA、macOSWindows
make setup

# Linux AMD / ROCm
Expand All @@ -101,7 +106,7 @@ make setup

# 不使用 shell completion 设置时:
# uv sync --extra mujoco --extra motrix
# 如果没有安装 `make`:
# 如果没有安装或无法使用 `make`:
# uv sync --extra mujoco --extra motrix && uv run --no-sync unilab-complete install

# 3. 预训练 checkpoint 回放(首次运行会从 Hugging Face 下载)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"onnxruntime>=1.20 ; python_version >= '3.11'",
"huggingface_hub>=0.25",
"ninja ; sys_platform == 'linux'",
"imageio-ffmpeg>=0.6.0",
]

[project.scripts]
Expand Down Expand Up @@ -62,7 +63,7 @@ explicit = true

[tool.uv.sources]
torch = [
{ index = "pytorch-cu128", marker = "sys_platform=='linux'" },
{ index = "pytorch-cu128", marker = "sys_platform=='linux' or sys_platform=='win32'" },
]

[tool.uv]
Expand Down
2 changes: 1 addition & 1 deletion scripts/train_appo.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def forward(self, obs: torch.Tensor) -> torch.Tensor:
on_plan=log_playback_plan,
)
if play_video_path is not None:
print(f"Saving video to {play_video_path} with mediapy...")
print(f"Saving video to {play_video_path} ...")
print("Done.")
return play_video_path

Expand Down
39 changes: 17 additions & 22 deletions scripts/train_mlx_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,28 +325,23 @@ def _play_step(current_obs):
return mx.nan_to_num(raw_obs, nan=0.0, posinf=0.0, neginf=0.0)

output_dir = run_dir if run_dir is not None else task_log_root
try:
play_video_path = env.run_playback_mode(
play_render_mode=getattr(cfg.training, "play_render_mode", "auto"),
play_steps=getattr(cfg.training, "play_steps", None),
output_video=output_dir / "play_video.mp4",
initialize=lambda: obs,
step=_play_step,
camera_kwargs={
"cam_distance": getattr(cfg.training, "cam_distance", 2.0),
"cam_elevation": getattr(cfg.training, "cam_elevation", -20.0),
"cam_azimuth": getattr(cfg.training, "cam_azimuth", 90.0),
"cam_lookat": getattr(cfg.training, "cam_lookat", None),
"cam_tracking": getattr(cfg.training, "cam_tracking", False),
"cam_tracking_env_idx": getattr(cfg.training, "cam_tracking_env_idx", 0),
"cam_tracking_extra_envs": getattr(cfg.training, "cam_tracking_extra_envs", 2),
},
on_plan=lambda plan: log_playback_plan(plan, prefix="[MLX PPO] "),
)
except ImportError:
print("mediapy is required for play video export. Install with `pip install mediapy`.")
env.close()
return None
play_video_path = env.run_playback_mode(
play_render_mode=getattr(cfg.training, "play_render_mode", "auto"),
play_steps=getattr(cfg.training, "play_steps", None),
output_video=output_dir / "play_video.mp4",
initialize=lambda: obs,
step=_play_step,
camera_kwargs={
"cam_distance": getattr(cfg.training, "cam_distance", 2.0),
"cam_elevation": getattr(cfg.training, "cam_elevation", -20.0),
"cam_azimuth": getattr(cfg.training, "cam_azimuth", 90.0),
"cam_lookat": getattr(cfg.training, "cam_lookat", None),
"cam_tracking": getattr(cfg.training, "cam_tracking", False),
"cam_tracking_env_idx": getattr(cfg.training, "cam_tracking_env_idx", 0),
"cam_tracking_extra_envs": getattr(cfg.training, "cam_tracking_extra_envs", 2),
},
on_plan=lambda plan: log_playback_plan(plan, prefix="[MLX PPO] "),
)
if play_video_path is not None:
print(f"[MLX PPO] Play video saved: {play_video_path}")
else:
Expand Down
8 changes: 5 additions & 3 deletions src/unilab/algos/torch/appo/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from rsl_rl.utils import resolve_optimizer
from tensordict import TensorDict

from unilab.algos.torch.common.compile import get_torch_compile_for_cuda

_LOG_2_PI = math.log(2.0 * math.pi)
_NORMAL_ENTROPY_OFFSET = 0.5 * (1.0 + _LOG_2_PI)

Expand Down Expand Up @@ -209,7 +211,7 @@ def __init__(
self._update_counter = 0
self.last_update_metrics: dict[str, float] = {}
self.enable_compile = (
bool(enable_compile) and self._device_type == "cuda" and hasattr(torch, "compile")
bool(enable_compile) and get_torch_compile_for_cuda(device, warn=True) is not None
)

# Optimizer
Expand All @@ -221,8 +223,8 @@ def __init__(
self._compile_training_methods()

def _compile_training_methods(self) -> None:
compile_fn = getattr(torch, "compile", None)
if compile_fn is None or self._device_type != "cuda":
compile_fn = get_torch_compile_for_cuda(self._device_type, warn=True)
if compile_fn is None:
return

self._minibatch_loss_fn = compile_fn(
Expand Down
48 changes: 48 additions & 0 deletions src/unilab/algos/torch/common/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

import importlib.util
import sys
from collections.abc import Callable
from typing import Any, cast

import torch

_WARNED_REASONS: set[str] = set()


def _warn_once(reason: str) -> None:
if reason in _WARNED_REASONS:
return
_WARNED_REASONS.add(reason)
message = f"WARNING: torch.compile is unavailable for CUDA; using eager mode ({reason})."
try:
from rich.console import Console

Console(stderr=True).print(message, style="yellow")
except Exception: # pragma: no cover - best-effort diagnostic only
print(message, file=sys.stderr)


def get_torch_compile_for_cuda(
device: torch.device | str, *, warn: bool = False
) -> Callable[..., Any] | None:
"""Return ``torch.compile`` when CUDA Inductor dependencies are available."""
compile_fn = getattr(torch, "compile", None)
if torch.device(device).type != "cuda":
return None
if compile_fn is None:
if warn:
_warn_once("torch.compile is not present in this PyTorch build")
return None
if (
getattr(compile_fn, "__module__", "") == "torch"
and importlib.util.find_spec("triton") is None
):
if warn:
_warn_once(
"Triton is not installed; this environment cannot use CUDA Inductor. "
"PyTorch's Windows torch.compile documentation currently covers "
"CPU/XPU Inductor, not the CUDA/Triton path"
)
return None
return cast(Callable[..., Any], compile_fn)
7 changes: 4 additions & 3 deletions src/unilab/algos/torch/fast_sac/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch.nn.functional as F
import torch.optim as optim

from unilab.algos.torch.common.compile import get_torch_compile_for_cuda
from unilab.base.augmentation import SymmetryAugmentation

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -406,7 +407,7 @@ def __init__(
self.use_autotune = use_autotune
self.use_amp = bool(use_amp) and self._device_type in ("cuda", "xpu")
self.use_compile = (
bool(use_compile) and self._device_type == "cuda" and hasattr(torch, "compile")
bool(use_compile) and get_torch_compile_for_cuda(self.device, warn=True) is not None
)
self.amp_dtype = amp_dtype
self._amp_dtype = self._resolve_amp_dtype(amp_dtype, self._device_type)
Expand Down Expand Up @@ -520,8 +521,8 @@ def _should_use_grad_scaler(
return bool(use_amp) and device_type == "cuda" and amp_dtype == torch.float16

def _compile_training_methods(self) -> None:
compile_fn = getattr(torch, "compile", None)
if compile_fn is None or torch.device(self.device).type != "cuda":
compile_fn = get_torch_compile_for_cuda(self.device, warn=True)
if compile_fn is None:
return

compile_kwargs = {"options": {"triton.cudagraphs": False}}
Expand Down
7 changes: 4 additions & 3 deletions src/unilab/algos/torch/flash_sac/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.nn as nn
import torch.optim as optim

from unilab.algos.torch.common.compile import get_torch_compile_for_cuda
from unilab.algos.torch.common.normalization import EmpiricalNormalization
from unilab.algos.torch.flash_sac.network import (
FlashSACActor,
Expand Down Expand Up @@ -182,7 +183,7 @@ def __init__(
self.amp_dtype = amp_dtype
self._amp_dtype = self._resolve_amp_dtype(amp_dtype, self.device.type)
self.use_compile = bool(
use_compile and hasattr(torch, "compile") and self.device.type == "cuda"
use_compile and get_torch_compile_for_cuda(self.device, warn=True) is not None
)

self.actor = FlashSACActor(
Expand Down Expand Up @@ -256,8 +257,8 @@ def __init__(
self._compile_training_methods()

def _compile_training_methods(self) -> None:
compile_fn = getattr(torch, "compile", None)
if compile_fn is None or self.device.type != "cuda":
compile_fn = get_torch_compile_for_cuda(self.device, warn=True)
if compile_fn is None:
return

compile_kwargs = {"options": {"triton.cudagraphs": False}}
Expand Down
2 changes: 1 addition & 1 deletion src/unilab/algos/torch/hora/appo.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def step_play_obs(obs_np: np.ndarray) -> np.ndarray:
),
)
if play_video_path is not None:
print(f"Saving video to {play_video_path} with mediapy...")
print(f"Saving video to {play_video_path} ...")
print("Done.")
return play_video_path

Expand Down
10 changes: 5 additions & 5 deletions src/unilab/algos/torch/rsl_rl_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from rsl_rl.algorithms import PPO
from tensordict import TensorDict

from unilab.algos.torch.common.compile import get_torch_compile_for_cuda

_LOG_2_PI = math.log(2.0 * math.pi)
_NORMAL_ENTROPY_OFFSET = 0.5 * (1.0 + _LOG_2_PI)

Expand All @@ -24,17 +26,15 @@ def __init__(
) -> None:
super().__init__(*args, **kwargs)
self.enable_compile = (
bool(enable_compile)
and torch.device(self.device).type == "cuda"
and hasattr(torch, "compile")
bool(enable_compile) and get_torch_compile_for_cuda(self.device, warn=True) is not None
)
self._minibatch_loss_fn = self._minibatch_loss_tensors
if self.enable_compile:
self._compile_training_methods()

def _compile_training_methods(self) -> None:
compile_fn = getattr(torch, "compile", None)
if compile_fn is None or torch.device(self.device).type != "cuda":
compile_fn = get_torch_compile_for_cuda(self.device, warn=True)
if compile_fn is None:
return

self._minibatch_loss_fn = compile_fn(
Expand Down
21 changes: 15 additions & 6 deletions src/unilab/assets/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from __future__ import annotations

import logging
import ntpath
import os
import posixpath
from collections.abc import Sequence
from pathlib import Path

Expand Down Expand Up @@ -89,22 +91,23 @@ def resolve_checkpoint_file(
def _resolve_single(path_str: str, *, repo_id: str = _HF_MOTIONS_REPO_ID) -> str:
"""Resolve one asset file path, downloading if absent."""
path = Path(path_str)
is_absolute_input = path.is_absolute() or ntpath.isabs(path_str) or posixpath.isabs(path_str)

# Already exists locally — fast path.
if path.exists():
return str(path)

# Try interpreting as ASSETS_ROOT_PATH-relative.
if not path.is_absolute():
if not is_absolute_input:
local = ASSETS_ROOT_PATH / path
if local.exists():
return str(local)
relative = path_str
relative = _hf_relative_path(path_str)
else:
# Extract the portion relative to ASSETS_ROOT_PATH so we can
# request the matching file from the HF repo.
try:
relative = str(path.relative_to(ASSETS_ROOT_PATH))
relative = path.relative_to(ASSETS_ROOT_PATH).as_posix()
except ValueError:
raise FileNotFoundError(
f"Asset file not found and path is not under "
Expand All @@ -114,6 +117,11 @@ def _resolve_single(path_str: str, *, repo_id: str = _HF_MOTIONS_REPO_ID) -> str
return _download_from_hf(relative, repo_id=repo_id)


def _hf_relative_path(path_str: str) -> str:
"""Return a repo-relative HF path with POSIX separators."""
return path_str.replace("\\", "/")


def _hf_download(hf_hub_download, relative_path: str, *, repo_id: str) -> str: # type: ignore[no-untyped-def]
"""Call ``hf_hub_download`` with the standard arguments."""
return str(
Expand Down Expand Up @@ -206,6 +214,7 @@ def _resolve_snapshot_dir(directory: str, *, repo_id: str, marker: str) -> Path:
Returns:
Absolute ``Path`` to the resolved directory.
"""
hf_directory = _hf_relative_path(directory)
target = ASSETS_ROOT_PATH / directory
if (target / marker).is_file():
return target
Expand All @@ -221,10 +230,10 @@ def _resolve_snapshot_dir(directory: str, *, repo_id: str, marker: str) -> Path:
" uv pip install huggingface_hub"
) from None

logger.info("Downloading %s from HF repo %s ...", directory, repo_id)
logger.info("Downloading %s from HF repo %s ...", hf_directory, repo_id)

try:
_snapshot_download(snapshot_download, directory, repo_id=repo_id)
_snapshot_download(snapshot_download, hf_directory, repo_id=repo_id)
except Exception:
current_endpoint = os.environ.get("HF_ENDPOINT", "")
if current_endpoint and current_endpoint != _HF_OFFICIAL_ENDPOINT:
Expand All @@ -236,7 +245,7 @@ def _resolve_snapshot_dir(directory: str, *, repo_id: str, marker: str) -> Path:
original = os.environ["HF_ENDPOINT"]
os.environ["HF_ENDPOINT"] = _HF_OFFICIAL_ENDPOINT
try:
_snapshot_download(snapshot_download, directory, repo_id=repo_id)
_snapshot_download(snapshot_download, hf_directory, repo_id=repo_id)
finally:
os.environ["HF_ENDPOINT"] = original
else:
Expand Down
Loading
Loading