diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e3507e9..2bd1665 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,6 +16,6 @@ jobs: - name: Install prerequisites (for OpenCV) run: apt-get update && apt-get install ffmpeg libsm6 libxext6 -y - name: Install trajdata base version - run: python -m pip install . + run: python -m pip install ".[dev]" - name: Run tests - run: python -m unittest tests/test_state.py + run: python -m pytest tests/ -v --tb=short diff --git a/.gitignore b/.gitignore index 1e55dd9..177053b 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +.idea # PyInstaller # Usually these files are written by a python script from a template diff --git a/examples/eupeds_example.py b/examples/eupeds_example.py new file mode 100644 index 0000000..bfa036c --- /dev/null +++ b/examples/eupeds_example.py @@ -0,0 +1,43 @@ +from torch.utils.data import DataLoader +from tqdm import tqdm + +from trajdata import AgentBatch, AgentType, UnifiedDataset + + +def main(): + dataset = UnifiedDataset( + desired_data=["eupeds_eth-train"], + centric="agent", + desired_dt=0.4, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.PEDESTRIAN], + num_workers=0, + verbose=True, + data_dirs={ + "eupeds_eth": "~/datasets/eth_ucy", + }, + ) + + print(f"\n# Data Samples: {len(dataset):,}") + + dataloader = DataLoader( + dataset, + batch_size=8, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + batch: AgentBatch + for i, batch in enumerate(tqdm(dataloader, desc="Loading batches")): + print(f"\nBatch {i}: agent_hist shape={batch.agent_hist.shape}, future shape={batch.agent_fut.shape}") + if i >= 2: + print("... (showing first 3 batches only)") + break + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/examples/new_features_demo.py b/examples/new_features_demo.py new file mode 100644 index 0000000..5ba228d --- /dev/null +++ b/examples/new_features_demo.py @@ -0,0 +1,243 @@ +""" +Demo: 4 new trajdata features +============================= +1. Fast I/O – export dataset → zarr, reload with PrecomputedDataset +2. CSV Dataset – load any CSV directory as a dataset +3. Data Enrichment – MirrorAugmentation, SpeedScaleAugmentation, MotionTypeLabeler +4. 
Advanced Simulation – CollisionMetric, OffRoadRate, SimRunner + ConstantVelocityPolicy +""" +import os +import tempfile +from pathlib import Path +from collections import defaultdict + +import numpy as np +from torch.utils.data import DataLoader + +from trajdata import AgentType, UnifiedDataset +from trajdata.augmentation import MirrorAugmentation, MotionTypeLabeler, SpeedScaleAugmentation + +# ────────────────────────────────────────────────────────────────────────────── +# Shared base dataset (ETH/UCY, already downloaded) +# ────────────────────────────────────────────────────────────────────────────── +BASE_DATA_DIRS = {"eupeds_eth": "~/datasets/eth_ucy"} + + +def make_base_dataset(**kwargs): + return UnifiedDataset( + desired_data=["eupeds_eth-train"], + centric="agent", + desired_dt=0.4, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.PEDESTRIAN], + num_workers=0, + verbose=False, + data_dirs=BASE_DATA_DIRS, + **kwargs, + ) + + +# ══════════════════════════════════════════════════════════════════════════════ +# FEATURE 1 – Fast I/O (zarr + numpy) +# ══════════════════════════════════════════════════════════════════════════════ + +def demo_fast_io(): + print("\n" + "="*60) + print("FEATURE 1 – Fast I/O Formats") + print("="*60) + + from trajdata.io import DataExporter, PrecomputedDataset + + dataset = make_base_dataset() + print(f" Original dataset: {len(dataset):,} samples") + + with tempfile.TemporaryDirectory() as tmpdir: + # ── zarr export ── + zarr_path = Path(tmpdir) / "cache.zarr" + print(" Exporting to zarr …") + DataExporter.export(dataset, str(zarr_path), format="zarr", + batch_size=32, num_workers=0, verbose=False) + + fast_ds = PrecomputedDataset(str(zarr_path), format="zarr") + print(f" PrecomputedDataset (zarr): {len(fast_ds):,} samples") + sample = fast_ds[0] + print(f" Fields available: {fast_ds.fields}") + if "agent_hist" in sample: + print(f" agent_hist shape: {sample['agent_hist'].shape}") + + # ── numpy export ── + np_path = Path(tmpdir) / "cache_np" + print(" Exporting to numpy …") + DataExporter.export(dataset, str(np_path), format="numpy", + batch_size=32, num_workers=0, verbose=False) + + fast_np = PrecomputedDataset(str(np_path), format="numpy") + print(f" PrecomputedDataset (numpy): {len(fast_np):,} samples") + + print(" ✓ Feature 1 complete") + + +# ══════════════════════════════════════════════════════════════════════════════ +# FEATURE 2 – CSV Dataset Adapter +# ══════════════════════════════════════════════════════════════════════════════ + +def demo_csv_dataset(): + print("\n" + "="*60) + print("FEATURE 2 – CSV Dataset Support") + print("="*60) + + import json + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + # ── generate synthetic CSV scene ── + rng = np.random.default_rng(42) + for scene_idx in range(3): + rows = [] + for agent_id in range(5): + x, y = rng.uniform(0, 50, 2) + vx, vy = rng.uniform(-1, 1, 2) + for t in range(30): + rows.append({ + "frame_id": t, + "agent_id": agent_id, + "x": x + vx * t * 0.1, + "y": y + vy * t * 0.1, + }) + import pandas as pd + pd.DataFrame(rows).to_csv(tmpdir / f"scene_{scene_idx:03d}.csv", index=False) + + # ── config.json with splits ── + config = { + "dt": 0.1, + "splits": { + "train": ["scene_000", "scene_001"], + "val": ["scene_002"], + } + } + (tmpdir / "config.json").write_text(json.dumps(config)) + + # ── load via UnifiedDataset ── + dataset = UnifiedDataset( + desired_data=["csv_mydata-train"], + centric="agent", + desired_dt=0.1, + history_sec=(1.0, 
1.0), + future_sec=(1.0, 1.0), + only_predict=[AgentType.PEDESTRIAN], + num_workers=0, + verbose=False, + data_dirs={"csv_mydata": str(tmpdir)}, + ) + print(f" CSV dataset (train split): {len(dataset):,} samples") + + if len(dataset) > 0: + loader2 = DataLoader(dataset, batch_size=4, shuffle=False, + collate_fn=dataset.get_collate_fn(), num_workers=0) + batch2 = next(iter(loader2)) + print(f" Sample agent_hist shape: {batch2.agent_hist.shape}") + + print(" ✓ Feature 2 complete") + + +# ══════════════════════════════════════════════════════════════════════════════ +# FEATURE 3 – Data Enrichment Augmentations +# ══════════════════════════════════════════════════════════════════════════════ + +def demo_enrichment(): + print("\n" + "="*60) + print("FEATURE 3 – Data Enrichment & Auto-Labeling") + print("="*60) + + # Combine all three new augmentations + augmentations = [ + MirrorAugmentation(axis="x", prob=0.5), + SpeedScaleAugmentation(scale_min=0.8, scale_max=1.2), + MotionTypeLabeler(stationary_thresh=0.3, walking_thresh=2.0), + ] + + dataset = make_base_dataset(augmentations=augmentations) + loader = DataLoader( + dataset, + batch_size=16, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + batch = next(iter(loader)) + + print(f" agent_hist shape : {batch.agent_hist.shape}") + print(f" agent_fut shape : {batch.agent_fut.shape}") + + if "motion_type" in batch.extras: + mt = batch.extras["motion_type"] + labels = {0: "STATIONARY", 1: "WALKING", 2: "RUNNING", 3: "FAST"} + for lbl_id, lbl_name in labels.items(): + count = (mt == lbl_id).sum().item() + print(f" {lbl_name:12s}: {count} agents") + else: + print(" (motion_type not in extras – state format may lack velocity channels)") + + print(" ✓ Feature 3 complete") + + +# ══════════════════════════════════════════════════════════════════════════════ +# FEATURE 4 – Advanced Simulation +# ══════════════════════════════════════════════════════════════════════════════ + +def demo_simulation(): + print("\n" + "="*60) + print("FEATURE 4 – Advanced Simulation Features") + print("="*60) + + from trajdata.simulation import ( + ADE, FDE, CollisionMetric, OffRoadRate, + SimulationScene, ConstantVelocityPolicy, SimRunner, + ) + + dataset = make_base_dataset() + + # Pick the first available scene + loaded_scene = dataset.get_scene(0) + + print(f" Scene: {loaded_scene.name} ({loaded_scene.length_timesteps} timesteps)") + + sim_scene = SimulationScene( + env_name="sim_demo", + scene_name="sim_scene_001", + scene=loaded_scene, + dataset=dataset, + init_timestep=0, + freeze_agents=True, + ) + + policy = ConstantVelocityPolicy() + runner = SimRunner(sim_scene, policy, max_steps=10) + + metrics = [ADE(), FDE(), CollisionMetric(distance_thresh=1.0), OffRoadRate()] + results = runner.run(metrics=metrics, verbose=False) + + print(f" Simulation ran for {results['steps']} steps") + for metric_name, per_agent in results["metrics"].items(): + avg = np.mean(list(per_agent.values())) + print(f" {metric_name:20s}: mean={avg:.4f}") + + print(" ✓ Feature 4 complete") + + +# ══════════════════════════════════════════════════════════════════════════════ +# Main +# ══════════════════════════════════════════════════════════════════════════════ + +if __name__ == "__main__": + demo_fast_io() + demo_csv_dataset() + demo_enrichment() + demo_simulation() + + print("\n" + "="*60) + print("All 4 features demonstrated successfully!") + print("="*60) diff --git a/img/icon/icon_augment.svg b/img/icon/icon_augment.svg new file mode 100644 index 0000000..74d5ad1 
--- /dev/null
+++ b/img/icon/icon_augment.svg
@@ -0,0 +1,10 @@
+[SVG markup lost in extraction]
diff --git a/img/icon/icon_dashboard.svg
new file mode 100644
index 0000000..fecc302
--- /dev/null
+++ b/img/icon/icon_dashboard.svg
@@ -0,0 +1,7 @@
+[SVG markup lost in extraction]
diff --git a/img/icon/icon_dataset.svg
new file mode 100644
index 0000000..312ee72
--- /dev/null
+++ b/img/icon/icon_dataset.svg
@@ -0,0 +1,3 @@
+[SVG markup lost in extraction]
diff --git a/img/icon/icon_export.svg
new file mode 100644
index 0000000..0d61182
--- /dev/null
+++ b/img/icon/icon_export.svg
@@ -0,0 +1,5 @@
+[SVG markup lost in extraction]
diff --git a/img/icon/icon_run_demo.svg
new file mode 100644
index 0000000..509751a
--- /dev/null
+++ b/img/icon/icon_run_demo.svg
@@ -0,0 +1,14 @@
+[SVG markup lost in extraction]
diff --git a/img/icon/icon_simulate.svg
new file mode 100644
index 0000000..3832fa8
--- /dev/null
+++ b/img/icon/icon_simulate.svg
@@ -0,0 +1,6 @@
+[SVG markup lost in extraction]
diff --git a/img/icon/icon_visualize.svg
new file mode 100644
index 0000000..ce3aee7
--- /dev/null
+++ b/img/icon/icon_visualize.svg
@@ -0,0 +1,12 @@
+[SVG markup lost in extraction]
diff --git a/run_webui.sh
new file mode 100755
index 0000000..c82be09
--- /dev/null
+++ b/run_webui.sh
@@ -0,0 +1,55 @@
+#!/usr/bin/env bash
+# Launch, stop or restart the trajdata Web UI
+set -e
+cd "$(dirname "$0")"
+
+PORT=5006
+COMMAND="run"
+EXTRA_ARGS=""
+
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        stop) COMMAND="stop"; shift ;;
+        restart) COMMAND="restart"; shift ;;
+        --port) PORT="$2"; shift 2 ;;
+        --port=*) PORT="${1#*=}"; shift ;;
+        --no-browser) EXTRA_ARGS="$EXTRA_ARGS --no-browser"; shift ;;
+        *) shift ;;
+    esac
+done
+
+function stop_server() {
+    PID=$(lsof -ti :$PORT || true)
+    if [ -n "$PID" ]; then
+        echo "Stopping server on port $PORT (PID: $PID)..."
+        kill -9 $PID 2>/dev/null || true
+        sleep 1
+    else
+        echo "No server running on port $PORT."
+    fi
+}
+
+function run_server() {
+    # Auto-kill if port is busy before running
+    PID=$(lsof -ti :$PORT || true)
+    if [ -n "$PID" ]; then
+        echo "Force-closing existing process on port $PORT..."
+        kill -9 $PID 2>/dev/null || true
+        sleep 0.5
+    fi
+    echo "Launching trajdata Web UI at http://localhost:$PORT/..."
+    exec .venv/bin/python trajdata_webui/main.py --port "$PORT" $EXTRA_ARGS
+}
+
+case $COMMAND in
+    stop)
+        stop_server
+        ;;
+    restart)
+        stop_server
+        run_server
+        ;;
+    run)
+        run_server
+        ;;
+esac
diff --git a/src/trajdata/augmentation/__init__.py
index 9a55625..685ed25 100644
--- a/src/trajdata/augmentation/__init__.py
+++ b/src/trajdata/augmentation/__init__.py
@@ -1,3 +1,6 @@
 from .augmentation import Augmentation, BatchAugmentation, DatasetAugmentation
 from .low_vel_yaw_correction import LowSpeedYawCorrection
+from .mirror import MirrorAugmentation
+from .motion_type import MotionTypeLabeler
 from .noise_histories import NoiseHistories
+from .speed_scale import SpeedScaleAugmentation
diff --git a/src/trajdata/augmentation/mirror.py
new file mode 100644
index 0000000..8724ab9
--- /dev/null
+++ b/src/trajdata/augmentation/mirror.py
@@ -0,0 +1,84 @@
+"""
+MirrorAugmentation: randomly flip trajectories horizontally (x-axis) or
+vertically (y-axis) at batch time.
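+
+Note that a mirrored heading is not simply negated for both axes. A sketch of
+the identities (assuming the scalar-heading state format used below)::
+
+    import math
+    vx, vy = 1.0, 0.5                # example velocity components
+    h = math.atan2(vy, vx)           # original heading
+    h_flip_x = math.atan2(vy, -vx)   # axis="x": equals math.pi - h
+    h_flip_y = math.atan2(-vy, vx)   # axis="y": equals -h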
+
+Applies to: agent_hist, agent_fut, neigh_hist, neigh_fut.
+Heading and velocity components are adjusted consistently.
+
+Usage::
+
+    from trajdata.augmentation import MirrorAugmentation
+    dataset = UnifiedDataset(..., augmentations=[MirrorAugmentation(axis="x", prob=0.5)])
+"""
+import math
+
+import torch
+
+from trajdata.augmentation.augmentation import BatchAugmentation
+from trajdata.data_structures.batch import AgentBatch, SceneBatch
+
+
+class MirrorAugmentation(BatchAugmentation):
+    """Randomly mirror trajectories along the x or y axis.
+
+    Args:
+        axis: Which axis to mirror – ``"x"`` flips the x-coordinate and
+            ``"y"`` flips the y-coordinate (default ``"x"``).
+        prob: Probability of applying the flip to each sample (default 0.5).
+    """
+
+    def __init__(self, axis: str = "x", prob: float = 0.5) -> None:
+        if axis not in ("x", "y"):
+            raise ValueError("axis must be 'x' or 'y'")
+        self.axis = axis
+        self.prob = prob
+
+    # -----------------------------------------------------------------
+    # Internal helpers
+    # -----------------------------------------------------------------
+
+    def _flip_traj(self, traj: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        """Flip a [B, T, D] or [B, N, T, D] trajectory tensor.
+
+        Internal state format: x(0) y(1) z(2) xd(3) yd(4) xdd(5) ydd(6) h(7).
+        mask: bool tensor of shape [B] – True = apply flip.
+        """
+        result = traj.clone()
+        D = traj.shape[-1]
+        B = traj.shape[0]
+
+        # Position/velocity/acceleration channels to negate
+        if self.axis == "x":
+            flip_chs = [c for c in (0, 3, 5) if c < D]
+        else:
+            flip_chs = [c for c in (1, 4, 6) if c < D]
+
+        for ch in flip_chs:
+            sliced = result[..., ch]  # shape [B, T] or [B, N, T]
+            # Build mask with same ndim as sliced: [B, 1, ...1]
+            m = mask.view((B,) + (1,) * (sliced.dim() - 1))
+            result[..., ch] = torch.where(m.expand_as(sliced), -sliced, sliced)
+
+        # The heading channel transforms differently per axis: an x-flip maps
+        # h to pi - h, a y-flip maps h to -h; wrap back to (-pi, pi] via atan2.
+        if D > 7:
+            h = result[..., 7]
+            new_h = math.pi - h if self.axis == "x" else -h
+            new_h = torch.atan2(torch.sin(new_h), torch.cos(new_h))
+            m = mask.view((B,) + (1,) * (h.dim() - 1))
+            result[..., 7] = torch.where(m.expand_as(h), new_h, h)
+
+        return result
+
+    def _sample_mask(self, batch_size: int, device: torch.device) -> torch.Tensor:
+        return torch.rand(batch_size, device=device) < self.prob
+
+    # -----------------------------------------------------------------
+
+    def apply_agent(self, batch: AgentBatch) -> None:
+        B = batch.agent_hist.shape[0]
+        mask = self._sample_mask(B, batch.agent_hist.device)
+
+        batch.agent_hist = self._flip_traj(batch.agent_hist, mask)
+        batch.agent_fut = self._flip_traj(batch.agent_fut, mask)
+
+        if batch.neigh_hist is not None and batch.neigh_hist.numel() > 0:
+            # neigh tensors are [B, N, T, D]; _flip_traj broadcasts the [B] mask
+            batch.neigh_hist = self._flip_traj(batch.neigh_hist, mask)
+        if batch.neigh_fut is not None and batch.neigh_fut.numel() > 0:
+            batch.neigh_fut = self._flip_traj(batch.neigh_fut, mask)
+
+    def apply_scene(self, batch: SceneBatch) -> None:
+        B = batch.agent_hist.shape[0]
+        mask = self._sample_mask(B, batch.agent_hist.device)
+        batch.agent_hist = self._flip_traj(batch.agent_hist, mask)
+        batch.agent_fut = self._flip_traj(batch.agent_fut, mask)
diff --git a/src/trajdata/augmentation/motion_type.py
new file mode 100644
index 0000000..61cfcba
--- /dev/null
+++ b/src/trajdata/augmentation/motion_type.py
@@ -0,0 +1,96 @@
+"""
+MotionTypeLabeler: classify agent motion into discrete categories based on
+instantaneous speed, and store the label as a batch extra.
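+
+Per-step speed is read from the velocity channels of the state (xd, yd) and
+then averaged over the valid history steps only – a sketch of the rule
+implemented by ``_compute_labels`` below::
+
+    speed[t] = sqrt(vx[t] ** 2 + vy[t] ** 2)
+    mean_speed = sum(speed[t] for valid t) / hist_len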
+ +Categories (stored as integer in ``batch.extras["motion_type"]``): + ++----+--------------+---------------------------+ +| ID | Name | Speed range (m/s) | ++====+==============+===========================+ +| 0 | STATIONARY | v < ``stationary_thresh`` | +| 1 | WALKING | stationary – walking | +| 2 | RUNNING | walking – running | +| 3 | FAST | > ``running_thresh`` | ++----+--------------+---------------------------+ + +Usage:: + + from trajdata.augmentation import MotionTypeLabeler + dataset = UnifiedDataset(..., augmentations=[MotionTypeLabeler()]) + # batch.extras["motion_type"] → LongTensor [B] +""" +import torch + +from trajdata.augmentation.augmentation import BatchAugmentation +from trajdata.data_structures.batch import AgentBatch, SceneBatch + +STATIONARY = 0 +WALKING = 1 +RUNNING = 2 +FAST = 3 + + +class MotionTypeLabeler(BatchAugmentation): + """Add a ``motion_type`` integer label to ``batch.extras``. + + The label is derived from the mean speed over the observed history + (velocity channels xd, yd at indices 3 and 4 of the state). + + Args: + stationary_thresh: Max speed to be considered stationary (default 0.5 m/s). + walking_thresh: Max speed for walking category (default 2.5 m/s). + running_thresh: Max speed for running category (default 6.0 m/s). + Anything above is labelled FAST. + """ + + def __init__( + self, + stationary_thresh: float = 0.5, + walking_thresh: float = 2.5, + running_thresh: float = 6.0, + ) -> None: + self.stationary_thresh = stationary_thresh + self.walking_thresh = walking_thresh + self.running_thresh = running_thresh + + # ----------------------------------------------------------------- + + def _compute_labels(self, hist: torch.Tensor, hist_len: torch.Tensor) -> torch.Tensor: + """Compute per-sample motion type label from history tensor [B, T, D].""" + D = hist.shape[-1] + if D < 5: + # Not enough channels to extract velocity + return torch.zeros(hist.shape[0], dtype=torch.long, device=hist.device) + + vx = hist[..., 3] # [B, T] + vy = hist[..., 4] # [B, T] + speed = torch.sqrt(vx ** 2 + vy ** 2) # [B, T] + + # Mean speed over valid timesteps only + # hist_len: [B] + T = hist.shape[1] + valid_mask = ( + torch.arange(T, device=hist.device).unsqueeze(0) < hist_len.unsqueeze(1) + ).float() # [B, T] + mean_speed = (speed * valid_mask).sum(dim=1) / hist_len.float().clamp(min=1) + + labels = torch.zeros(hist.shape[0], dtype=torch.long, device=hist.device) + labels[mean_speed >= self.stationary_thresh] = WALKING + labels[mean_speed >= self.walking_thresh] = RUNNING + labels[mean_speed >= self.running_thresh] = FAST + + return labels + + # ----------------------------------------------------------------- + + def apply_agent(self, batch: AgentBatch) -> None: + labels = self._compute_labels(batch.agent_hist, batch.agent_hist_len) + batch.extras["motion_type"] = labels + + def apply_scene(self, batch: SceneBatch) -> None: + # For scene-centric batches, compute per-agent labels [B, A] + B, A, T, D = batch.agent_hist.shape + flat_hist = batch.agent_hist.view(B * A, T, D) + flat_len = batch.agent_hist_len.view(B * A) + labels = self._compute_labels(flat_hist, flat_len).view(B, A) + batch.extras["motion_type"] = labels diff --git a/src/trajdata/augmentation/speed_scale.py b/src/trajdata/augmentation/speed_scale.py new file mode 100644 index 0000000..e5635f2 --- /dev/null +++ b/src/trajdata/augmentation/speed_scale.py @@ -0,0 +1,85 @@ +""" +SpeedScaleAugmentation: randomly scale agent speeds (and derived quantities) +at batch time to simulate faster or slower 
motion.
+
+Velocity (xd, yd) and acceleration (xdd, ydd) channels are scaled; positions
+and the temporal extent of trajectories are left untouched – only the
+magnitude of the motion channels is altered.
+
+Usage::
+
+    from trajdata.augmentation import SpeedScaleAugmentation
+    dataset = UnifiedDataset(..., augmentations=[SpeedScaleAugmentation(0.7, 1.3)])
+"""
+import torch
+
+from trajdata.augmentation.augmentation import BatchAugmentation
+from trajdata.data_structures.batch import AgentBatch, SceneBatch
+
+
+class SpeedScaleAugmentation(BatchAugmentation):
+    """Randomly scale agent velocities and accelerations.
+
+    A random scalar drawn uniformly from ``[scale_min, scale_max]`` is
+    applied per sample to velocity channels (xd, yd → channels 3,4) and
+    acceleration channels (xdd, ydd → channels 5,6). Position channels are
+    left unchanged so that the history/future trajectories remain geometrically
+    consistent while the speed distribution is augmented.
+
+    Args:
+        scale_min: Lower bound of the uniform scale factor (default 0.8).
+        scale_max: Upper bound of the uniform scale factor (default 1.2).
+    """
+
+    # Internal state format: x(0), y(1), z(2), xd(3), yd(4), xdd(5), ydd(6), h(7)
+    _VEL_CHANNELS = (3, 4)
+    _ACC_CHANNELS = (5, 6)
+
+    def __init__(self, scale_min: float = 0.8, scale_max: float = 1.2) -> None:
+        if scale_min <= 0 or scale_max <= 0:
+            raise ValueError("Scale bounds must be positive.")
+        if scale_min > scale_max:
+            raise ValueError("scale_min must be ≤ scale_max.")
+        self.scale_min = scale_min
+        self.scale_max = scale_max
+
+    # -----------------------------------------------------------------
+
+    def _scale_traj(self, traj: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+        """Scale velocity/acc channels; traj shape [B, T, D] or [B, N, T, D]."""
+        result = traj.clone()
+        D = traj.shape[-1]
+        B = traj.shape[0]
+
+        for ch in self._VEL_CHANNELS + self._ACC_CHANNELS:
+            if ch < D:
+                sliced = result[..., ch]  # [B, T] or [B, N, T]
+                s = scale.view((B,) + (1,) * (sliced.dim() - 1))
+                result[..., ch] = sliced * s.expand_as(sliced)
+
+        return result
+
+    def _sample_scale(self, batch_size: int, device: torch.device) -> torch.Tensor:
+        return (
+            torch.rand(batch_size, device=device) * (self.scale_max - self.scale_min)
+            + self.scale_min
+        )
+
+    # -----------------------------------------------------------------
+
+    def apply_agent(self, batch: AgentBatch) -> None:
+        B = batch.agent_hist.shape[0]
+        scale = self._sample_scale(B, batch.agent_hist.device)
+
+        batch.agent_hist = self._scale_traj(batch.agent_hist, scale)
+        batch.agent_fut = self._scale_traj(batch.agent_fut, scale)
+
+        if batch.neigh_hist is not None and batch.neigh_hist.numel() > 0:
+            batch.neigh_hist = self._scale_traj(batch.neigh_hist, scale)
+        if batch.neigh_fut is not None and batch.neigh_fut.numel() > 0:
+            batch.neigh_fut = self._scale_traj(batch.neigh_fut, scale)
+
+    def apply_scene(self, batch: SceneBatch) -> None:
+        B = batch.agent_hist.shape[0]
+        scale = self._sample_scale(B, batch.agent_hist.device)
+        batch.agent_hist = self._scale_traj(batch.agent_hist, scale)
+        batch.agent_fut = self._scale_traj(batch.agent_fut, scale)
diff --git a/src/trajdata/caching/df_cache.py
index 329bb3d..1e0e165 100644
--- a/src/trajdata/caching/df_cache.py
+++ b/src/trajdata/caching/df_cache.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
+import json
 import warnings
 from decimal import Decimal
 from typing import TYPE_CHECKING
@@ -89,6 +90,11 @@ def _agent_data_file(scene_dt: float) -> str:
     @staticmethod
     def
_agent_data_index_file(scene_dt: float) -> str: + return f"scene_index_dt{scene_dt:.2f}.json" + + @staticmethod + def _agent_data_index_file_legacy(scene_dt: float) -> str: + """Legacy pickle format kept for backwards compatibility with existing caches.""" return f"scene_index_dt{scene_dt:.2f}.pkl" # AGENT STATE DATA @@ -148,18 +154,34 @@ def _load_agent_data(self, scene_dt: float) -> None: use_threads=False, ).set_index(["agent_id", "scene_ts"]) - with open( - self.scene_dir / DataFrameCache._agent_data_index_file(scene_dt), "rb" - ) as f: - self.index_dict: Dict[Tuple[str, int], int] = pickle.load(f) + json_index_path = self.scene_dir / DataFrameCache._agent_data_index_file(scene_dt) + pkl_index_path = self.scene_dir / DataFrameCache._agent_data_index_file_legacy(scene_dt) + if json_index_path.exists(): + with open(json_index_path, "r") as f: + raw: Dict[str, int] = json.load(f) + self.index_dict: Dict[Tuple[str, int], int] = { + (parts[0], int(parts[1])): v + for k, v in raw.items() + for parts in [k.split("::", 1)] + } + elif pkl_index_path.exists(): + # Legacy format: load from pickle for backwards compatibility. + with open(pkl_index_path, "rb") as f: + self.index_dict = pickle.load(f) + else: + raise FileNotFoundError( + f"No agent data index found at {json_index_path} or {pkl_index_path}." + ) self._get_and_reorder_col_idxs() def write_cache_to_disk(self) -> None: with open( - self.scene_dir / DataFrameCache._agent_data_index_file(self.dt), "wb" + self.scene_dir / DataFrameCache._agent_data_index_file(self.dt), "w" ) as f: - pickle.dump(self.index_dict, f) + json.dump( + {f"{k[0]}::{k[1]}": v for k, v in self.index_dict.items()}, f + ) self.scene_data_df.reset_index().to_feather( self.scene_dir / DataFrameCache._agent_data_file(self.dt) @@ -180,9 +202,11 @@ def save_agent_data( val: idx for idx, val in enumerate(agent_data.index) } with open( - scene_cache_dir / DataFrameCache._agent_data_index_file(scene.dt), "wb" + scene_cache_dir / DataFrameCache._agent_data_index_file(scene.dt), "w" ) as f: - pickle.dump(index_dict, f) + json.dump( + {f"{k[0]}::{k[1]}": v for k, v in index_dict.items()}, f + ) agent_data.reset_index().to_feather( scene_cache_dir / DataFrameCache._agent_data_file(scene.dt) diff --git a/src/trajdata/data_structures/data_index.py b/src/trajdata/data_structures/data_index.py index 54f70ca..8fe64d7 100644 --- a/src/trajdata/data_structures/data_index.py +++ b/src/trajdata/data_structures/data_index.py @@ -26,7 +26,7 @@ def __init__( ) self._len: int = self._cumulative_lengths[-1].item() - self._scene_paths: np.ndarray = np.array(scene_paths).astype(np.string_) + self._scene_paths: np.ndarray = np.array(scene_paths).astype(np.bytes_) def __len__(self) -> int: return self._len @@ -61,7 +61,7 @@ def __init__( ): agent_ids, agent_times = zip(*scene_data_index) - self._agent_ids.append(np.array(agent_ids).astype(np.string_)) + self._agent_ids.append(np.array(agent_ids).astype(np.bytes_)) agent_ts: np.ndarray = np.stack(agent_times) self._agent_times.append(agent_ts) diff --git a/src/trajdata/dataset.py b/src/trajdata/dataset.py index c3667b5..713d604 100644 --- a/src/trajdata/dataset.py +++ b/src/trajdata/dataset.py @@ -1,5 +1,6 @@ import gc import json +import logging import random import re import time @@ -51,6 +52,8 @@ ) from trajdata.utils.parallel_utils import parallel_iapply +logger = logging.getLogger(__name__) + class UnifiedDataset(Dataset): # @profile @@ -250,10 +253,9 @@ def __init__( matching_datasets: List[SceneTag] = 
self._get_matching_scene_tags(desired_data) if self.verbose: - print( - "Loading data for matched scene tags:", + logger.info( + "Loading data for matched scene tags: %s", string_utils.pretty_string_tags(matching_datasets), - flush=True, ) self.check_args_combinations(matching_datasets) @@ -348,7 +350,7 @@ def __init__( all_scenes_list, num_workers ) if self.verbose: - print(len(scene_paths), "scenes in the scene index.") + logger.info("%d scenes in the scene index.", len(scene_paths)) # Done with this list. Cutting memory usage because # of multiprocessing later on. @@ -381,9 +383,8 @@ def __init__( # Use only rank 0 process for caching when using multi-GPU torch training. if save_index and rank == 0: if self._index_cache_path().exists(): - print( - "WARNING: Overwriting already-cached data index (since save_index is True).", - flush=True, + logger.warning( + "Overwriting already-cached data index (since save_index is True)." ) self._cache_data_index(data_index) @@ -498,10 +499,7 @@ def _cache_data_index( with open(args_file, "w") as f: json.dump(index_args, f, indent=4) - print( - f"Cached data index to {str(index_cache_file)}", - flush=True, - ) + logger.info("Cached data index to %s", str(index_cache_file)) def _load_data_index( self, @@ -514,10 +512,7 @@ def _load_data_index( data_index = dill.load(f) if self.verbose: - print( - f"Loaded data index from {str(index_cache_file)}", - flush=True, - ) + logger.info("Loaded data index from %s", str(index_cache_file)) return data_index @@ -525,11 +520,11 @@ def load_or_create_cache( self, cache_path: str, num_workers=0, filter_fn=None ) -> None: if isfile(cache_path): - print(f"Loading cache from {cache_path} ...", end="") + logger.info("Loading cache from %s ...", cache_path) t = time.time() with open(cache_path, "rb") as f: self._cached_batch_elements, keep_ids = dill.load(f, encoding="latin1") - print(f" done in {time.time() - t:.1f}s.") + logger.info("Cache loaded in %.1fs.", time.time() - t) else: # Build cache @@ -563,11 +558,11 @@ def load_or_create_cache( # not self (in case it is set to that)! del cache_data_iterator - print(f"Saving cache to {cache_path} ....", end="") + logger.info("Saving cache to %s ...", cache_path) t = time.time() with open(cache_path, "wb") as f: dill.dump((cached_batch_elements, keep_ids), f) - print(f" done in {time.time() - t:.1f}s.") + logger.info("Cache saved in %.1fs.", time.time() - t) self._cached_batch_elements = cached_batch_elements @@ -624,8 +619,9 @@ def apply_filter( keep_count += 1 if max_count is not None and keep_count >= max_count: # Add False for remaining samples and break loop - print( - f"Reached maximum number of {max_count} elements, terminating early." + logger.info( + "Reached maximum number of %d elements, terminating early.", + max_count, ) break @@ -642,15 +638,18 @@ def apply_filter( # All proceses use the indices from rank 0 self._data_index = gathered_values[0] self._data_len = len(self._data_index) - print(f"Rank {self.rank} has {self._data_len} elements.") + logger.info("Rank %d has %d elements.", self.rank, self._data_len) def remove_elements(self, keep_ids: Union[np.ndarray, List[int]]): old_len = self._data_len self._data_index = [self._data_index[i] for i in keep_ids] self._data_len = len(self._data_index) - print( - f"Kept {self._data_len}/{old_len} elements, {self._data_len/old_len*100.0:.2f}%." 
+ logger.info( + "Kept %d/%d elements (%.2f%%).", + self._data_len, + old_len, + self._data_len / old_len * 100.0, ) def _get_data_index( diff --git a/src/trajdata/dataset_specific/argoverse2/av2_dataset.py b/src/trajdata/dataset_specific/argoverse2/av2_dataset.py index fe36f64..9d4def8 100644 --- a/src/trajdata/dataset_specific/argoverse2/av2_dataset.py +++ b/src/trajdata/dataset_specific/argoverse2/av2_dataset.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Any, Dict, List, Tuple, Type, Union @@ -25,6 +26,8 @@ from trajdata.dataset_specific.scene_records import Argoverse2Record from trajdata.utils import arr_utils +logger = logging.getLogger(__name__) + AV2_MOTION_FORECASTING = "av2_motion_forecasting" AV2_DT = 1 / AV2_SCENARIO_STEP_HZ @@ -47,7 +50,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) self.dataset_obj = Av2Object(self.metadata.data_dir) def _get_matching_scenes_from_obj( diff --git a/src/trajdata/dataset_specific/csv_dataset/__init__.py b/src/trajdata/dataset_specific/csv_dataset/__init__.py new file mode 100644 index 0000000..10755b2 --- /dev/null +++ b/src/trajdata/dataset_specific/csv_dataset/__init__.py @@ -0,0 +1 @@ +from .csv_dataset import CSVDataset diff --git a/src/trajdata/dataset_specific/csv_dataset/csv_dataset.py b/src/trajdata/dataset_specific/csv_dataset/csv_dataset.py new file mode 100644 index 0000000..341de4d --- /dev/null +++ b/src/trajdata/dataset_specific/csv_dataset/csv_dataset.py @@ -0,0 +1,327 @@ +""" +Generic CSV Dataset adapter for trajdata. + +Each CSV file represents one scene and must contain at minimum: + frame_id, agent_id, x, y + +Optional columns (computed from x/y if missing): + vx, vy, heading, agent_type + +CSV files are placed in a single directory; an optional ``config.json`` +specifies the time-step and train/val/test splits. + +Directory layout:: + + /path/to/csv_data/ + ├── config.json (optional) + ├── scene_001.csv + ├── scene_002.csv + └── ... 
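+
+A minimal scene file, to make the expected columns concrete (values are
+illustrative only)::
+
+    frame_id,agent_id,x,y
+    0,0,1.00,2.00
+    0,1,5.00,1.50
+    1,0,1.10,2.00
+    1,1,4.95,1.55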
+ +``config.json`` example:: + + { + "dt": 0.1, + "splits": { + "train": ["scene_001", "scene_002"], + "val": ["scene_003"] + } + } + +Register in UnifiedDataset as ``csv_``, e.g.:: + + dataset = UnifiedDataset( + desired_data=["csv_mydata-train"], + data_dirs={"csv_mydata": "/path/to/csv_data"}, + ) +""" +import json +import logging +from pathlib import Path +from typing import Any, Dict, Final, List, Optional, Tuple, Type + +import numpy as np +import pandas as pd + +from trajdata.caching import EnvCache, SceneCache +from trajdata.data_structures.agent import AgentMetadata, AgentType, FixedExtent +from trajdata.data_structures.environment import EnvMetadata +from trajdata.data_structures.scene_metadata import Scene, SceneMetadata +from trajdata.data_structures.scene_tag import SceneTag +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import EUPedsRecord +from trajdata.utils import arr_utils + +logger = logging.getLogger(__name__) + +DEFAULT_DT: Final[float] = 0.1 +DEFAULT_EXTENT: Final[Tuple[float, float, float]] = (0.5, 0.5, 1.7) + +_AGENT_TYPE_MAP: Final[Dict[str, AgentType]] = { + "pedestrian": AgentType.PEDESTRIAN, + "ped": AgentType.PEDESTRIAN, + "vehicle": AgentType.VEHICLE, + "car": AgentType.VEHICLE, + "bicycle": AgentType.BICYCLE, + "bike": AgentType.BICYCLE, + "motorcycle": AgentType.MOTORCYCLE, +} + + +def _parse_agent_type(value) -> AgentType: + if isinstance(value, str): + return _AGENT_TYPE_MAP.get(value.lower(), AgentType.UNKNOWN) + return AgentType.UNKNOWN + + +class CSVDataset(RawDataset): + """Adapter that loads any directory of per-scene CSV files.""" + + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + data_dir = Path(data_dir).expanduser() + config = self._load_config(data_dir) + + self._dt: float = config.get("dt", DEFAULT_DT) + self._config_splits: Dict[str, List[str]] = config.get("splits", {}) + + dataset_parts: List[Tuple[str, ...]] = [("train", "val", "test")] + scene_split_map: Dict[str, str] = {} + + # Assign each CSV scene to a split using config; default to "train" + for csv_path in sorted(data_dir.glob("*.csv")): + scene_name = csv_path.stem + assigned = "train" + for split, names in self._config_splits.items(): + if scene_name in names: + assigned = split + break + scene_split_map[scene_name] = assigned + + self._scene_split_map = scene_split_map + + return EnvMetadata( + name=env_name, + data_dir=str(data_dir), + dt=self._dt, + parts=dataset_parts, + scene_split_map=scene_split_map, + ) + + # ------------------------------------------------------------------ + # Dataset object loading + # ------------------------------------------------------------------ + + def load_dataset_obj(self, verbose: bool = False) -> None: + if verbose: + logger.info("Loading CSV dataset from %s ...", self.metadata.data_dir) + + data_dir = Path(self.metadata.data_dir) + self.dataset_obj: Dict[str, pd.DataFrame] = {} + + for csv_path in sorted(data_dir.glob("*.csv")): + scene_name = csv_path.stem + df = pd.read_csv(csv_path) + df = self._normalise_columns(df) + df["frame_id"] = pd.to_numeric( + df["frame_id"] - df["frame_id"].min(), downcast="integer" + ) + self.dataset_obj[scene_name] = df + + # ------------------------------------------------------------------ + # Scene discovery helpers + # ------------------------------------------------------------------ + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + 
) -> List[SceneMetadata]: + all_records: List[EUPedsRecord] = [] + scenes: List[SceneMetadata] = [] + + for idx, (scene_name, df) in enumerate(self.dataset_obj.items()): + split = self.metadata.scene_split_map.get(scene_name, "train") + length = int(df["frame_id"].max()) + 1 + + all_records.append( + EUPedsRecord(scene_name, "csv", length, split, idx) + ) + + if split in scene_tag and ( + scene_desc_contains is None + or any(k in scene_name for k in scene_desc_contains) + ): + scenes.append( + SceneMetadata( + env_name=self.metadata.name, + name=scene_name, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + ) + + self.cache_all_scenes_list(env_cache, all_records) + return scenes + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_records: List[EUPedsRecord] = env_cache.load_env_scenes_list(self.name) + scenes: List[Scene] = [] + + for record in all_records: + scene_name, _loc, length, split, data_idx = record + if split in scene_tag and ( + scene_desc_contains is None + or any(k in scene_name for k in scene_desc_contains) + ): + scenes.append( + Scene( + self.metadata, + scene_name, + "csv", + split, + length, + data_idx, + None, + ) + ) + return scenes + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + _, scene_name, _, data_idx = scene_info + df = self.dataset_obj[scene_name] + split = self.metadata.scene_split_map.get(scene_name, "train") + length = int(df["frame_id"].max()) + 1 + return Scene(self.metadata, scene_name, "csv", split, length, data_idx, None) + + # ------------------------------------------------------------------ + # Agent info extraction + # ------------------------------------------------------------------ + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + df = self.dataset_obj[scene.name].copy() + df.rename( + columns={"frame_id": "scene_ts", "agent_id": "agent_id"}, + inplace=True, + ) + df["agent_id"] = df["agent_id"].astype(str) + df.set_index(["agent_id", "scene_ts"], inplace=True) + df.sort_index(inplace=True) + df.reset_index(level=1, inplace=True) + + agent_ids = df.index.get_level_values(0).to_numpy() + + # z column + if "z" not in df.columns: + df["z"] = 0.0 + + # velocities + if "vx" not in df.columns or "vy" not in df.columns: + vel = ( + arr_utils.agent_aware_diff(df[["x", "y"]].to_numpy(), agent_ids) + / self.metadata.dt + ) + df["vx"], df["vy"] = vel[:, 0], vel[:, 1] + + # accelerations + if "ax" not in df.columns or "ay" not in df.columns: + acc = ( + arr_utils.agent_aware_diff(df[["vx", "vy"]].to_numpy(), agent_ids) + / self.metadata.dt + ) + df["ax"], df["ay"] = acc[:, 0], acc[:, 1] + + # heading + if "heading" not in df.columns: + df["heading"] = np.arctan2(df["vy"], df["vx"]) + + # agent_type per-row + if "agent_type" in df.columns: + df["_atype"] = df["agent_type"].apply(_parse_agent_type) + else: + df["_atype"] = AgentType.PEDESTRIAN + + # Build metadata lists + agent_list: List[AgentMetadata] = [] + agent_presence: List[List[AgentMetadata]] = [ + [] for _ in range(scene.length_timesteps) + ] + + for agent_id, frames_df in df.groupby(level=0): + frames = frames_df["scene_ts"] + if len(frames) <= 1: + continue + + # pandas may store enum as int; cast back to AgentType + raw_atype = frames_df["_atype"].iat[0] + atype = AgentType(int(raw_atype)) if not isinstance(raw_atype, AgentType) else raw_atype + t0, t1 = 
int(frames.iat[0]), int(frames.iat[-1]) + meta = AgentMetadata( + name=str(agent_id), + agent_type=atype, + first_timestep=t0, + last_timestep=t1, + extent=FixedExtent(*DEFAULT_EXTENT), + ) + agent_list.append(meta) + for ts in frames: + if 0 <= ts < scene.length_timesteps: + agent_presence[ts].append(meta) + + # Drop helper columns before caching + df.drop(columns=["_atype"], inplace=True, errors="ignore") + df.drop(columns=["agent_type"], inplace=True, errors="ignore") + + # Restore (agent_id, scene_ts) MultiIndex expected by save_agent_data + df.reset_index(inplace=True) # agent_id becomes column again + df["agent_id"] = df["agent_id"].astype(str) + df.set_index(["agent_id", "scene_ts"], inplace=True) + + cache_class.save_agent_data(df, cache_path, scene) + return agent_list, agent_presence + + # ------------------------------------------------------------------ + # Maps (none for CSV) + # ------------------------------------------------------------------ + + def cache_map(self, map_name, layer_names, cache_path, map_cache_class, resolution): + pass + + def cache_maps(self, cache_path, map_cache_class, map_params): + pass + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _load_config(data_dir: Path) -> dict: + cfg_path = data_dir / "config.json" + if cfg_path.exists(): + with open(cfg_path) as f: + return json.load(f) + return {} + + @staticmethod + def _normalise_columns(df: pd.DataFrame) -> pd.DataFrame: + """Rename common column aliases to canonical names.""" + rename = {} + lc = {c.lower(): c for c in df.columns} + for canonical, aliases in { + "frame_id": ["frame_id", "frame", "timestep", "t"], + "agent_id": ["agent_id", "track_id", "id", "object_id"], + "x": ["x", "pos_x", "position_x"], + "y": ["y", "pos_y", "position_y"], + }.items(): + for alias in aliases: + if alias in lc and canonical not in df.columns: + rename[lc[alias]] = canonical + break + return df.rename(columns=rename) diff --git a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py index 4234df4..53beef5 100644 --- a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py +++ b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Any, Dict, Final, List, Optional, Tuple, Type @@ -13,6 +14,8 @@ from trajdata.dataset_specific.scene_records import EUPedsRecord from trajdata.utils import arr_utils +logger = logging.getLogger(__name__) + TRAIN_SCENES: Final[List[str]] = [ "biwi_eth", "biwi_hotel", @@ -103,7 +106,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) self.dataset_obj: Dict[str, pd.DataFrame] = dict() for scene_name in TRAIN_SCENES: diff --git a/src/trajdata/dataset_specific/interaction/interaction_dataset.py b/src/trajdata/dataset_specific/interaction/interaction_dataset.py index 89cc29a..6f685ee 100644 --- a/src/trajdata/dataset_specific/interaction/interaction_dataset.py +++ b/src/trajdata/dataset_specific/interaction/interaction_dataset.py @@ -1,4 +1,6 @@ +import logging import os +import pickle import time from collections import defaultdict from pathlib import Path @@ -20,6 +22,8 @@ from 
trajdata.maps.vec_map_elements import Polyline, RoadLane from trajdata.utils import arr_utils +logger = logging.getLogger(__name__) + # SDD was captured at 10 frames per second. INTERACTION_DT: Final[float] = 0.1 @@ -137,7 +141,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) data_dir_path = Path(self.metadata.data_dir) @@ -167,9 +171,8 @@ def load_dataset_obj(self, verbose: bool = False) -> None: ) if verbose: - print( - f"The first ~60 iterations might be slow, don't worry the following ones will be fast.", - flush=True, + logger.info( + "The first ~60 iterations might be slow, don't worry the following ones will be fast." ) def _get_matching_scenes_from_obj( @@ -281,12 +284,23 @@ def get_agent_info( if scene_metadata_path.exists(): # Try repeatedly to open the file because it might still be # being created in another process. - while True: + max_retries: int = 30 + for attempt in range(max_retries): try: already_done_scene = EnvCache.load(scene_metadata_path) break - except: + except (OSError, EOFError, pickle.UnpicklingError) as e: + if attempt == max_retries - 1: + raise RuntimeError( + f"Failed to load cached scene metadata from " + f"{scene_metadata_path} after {max_retries} attempts." + ) from e time.sleep(1) + else: + raise RuntimeError( + f"Failed to load cached scene metadata from " + f"{scene_metadata_path} after {max_retries} attempts." + ) # Already processed, so we can immediately return our cached results. return ( diff --git a/src/trajdata/dataset_specific/lyft/lyft_dataset.py b/src/trajdata/dataset_specific/lyft/lyft_dataset.py index 61ad7d7..0059d8c 100644 --- a/src/trajdata/dataset_specific/lyft/lyft_dataset.py +++ b/src/trajdata/dataset_specific/lyft/lyft_dataset.py @@ -1,3 +1,4 @@ +import logging from collections import defaultdict from functools import partial from pathlib import Path @@ -25,6 +26,8 @@ from trajdata.maps import VectorMap from trajdata.utils import arr_utils +logger = logging.getLogger(__name__) + def const_lambda(const_val: Any) -> Any: return const_val @@ -86,7 +89,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) self.dataset_obj = ChunkedDataset(str(self.metadata.data_dir)).open() @@ -336,7 +339,7 @@ def cache_maps( ) -> None: resolution: float = map_params["px_per_m"] map_name: str = "palo_alto" - print(f"Caching {map_name} Map at {resolution:.2f} px/m...", flush=True) + logger.info("Caching %s Map at %.2f px/m...", map_name, resolution) # We have to do this .parent.parent stuff because the data_dir for lyft is scenes/*.zarr dm = LocalDataManager((self.metadata.data_dir.parent.parent).resolve()) diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py index 1c4df3f..70867b9 100644 --- a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py +++ b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Type @@ -24,6 +25,8 @@ from trajdata.maps.vec_map import VectorMap from trajdata.utils import arr_utils +logger = logging.getLogger(__name__) + class 
NuplanDataset(RawDataset): def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: @@ -71,7 +74,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) if self.name == "nuplan_mini": subfolder = "mini" diff --git a/src/trajdata/dataset_specific/nusc/nusc_dataset.py b/src/trajdata/dataset_specific/nusc/nusc_dataset.py index 209824e..a6806ed 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_dataset.py +++ b/src/trajdata/dataset_specific/nusc/nusc_dataset.py @@ -1,3 +1,4 @@ +import logging import warnings from copy import deepcopy from pathlib import Path @@ -26,6 +27,8 @@ from trajdata.dataset_specific.scene_records import NuscSceneRecord from trajdata.maps import VectorMap +logger = logging.getLogger(__name__) + class NuscDataset(RawDataset): def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: @@ -91,7 +94,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) if self.name == "nusc_mini": version_str = "v1.0-mini" diff --git a/src/trajdata/dataset_specific/nusc/nusc_utils.py b/src/trajdata/dataset_specific/nusc/nusc_utils.py index 356debd..7413fad 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_utils.py +++ b/src/trajdata/dataset_specific/nusc/nusc_utils.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Dict, Final, List, Tuple, Union import numpy as np @@ -21,6 +22,8 @@ ) from trajdata.utils import arr_utils, map_utils +logger = logging.getLogger(__name__) + NUSC_DT: Final[float] = 0.5 @@ -60,7 +63,7 @@ def agg_agent_data( ) -> Agent: """Loops through all annotations of a specific agent in a scene and aggregates their data into an Agent object.""" if agent_data["prev"]: - print("WARN: This is not the first frame of this agent!") + logger.warning("This is not the first frame of this agent!") translation_list = [np.array(agent_data["translation"][:3])[np.newaxis]] agent_size = agent_data["size"] diff --git a/src/trajdata/dataset_specific/sdd_peds/sddpeds_dataset.py b/src/trajdata/dataset_specific/sdd_peds/sddpeds_dataset.py index a94bf88..39f6b51 100644 --- a/src/trajdata/dataset_specific/sdd_peds/sddpeds_dataset.py +++ b/src/trajdata/dataset_specific/sdd_peds/sddpeds_dataset.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from random import Random from typing import Any, Dict, Final, List, Optional, Tuple, Type @@ -16,6 +17,8 @@ from .estimated_homography import SDD_HOMOGRAPHY_SCALES +logger = logging.getLogger(__name__) + # SDD was captured at 30 frames per second. SDDPEDS_DT: Final[float] = 1.0 / 30.0 @@ -77,7 +80,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) # Just storing the filepath and scene length (number of frames). 
# One could load the entire dataset here, but there's no need diff --git a/src/trajdata/dataset_specific/vod/vod_dataset.py b/src/trajdata/dataset_specific/vod/vod_dataset.py index be0dc9b..cc238d8 100644 --- a/src/trajdata/dataset_specific/vod/vod_dataset.py +++ b/src/trajdata/dataset_specific/vod/vod_dataset.py @@ -1,3 +1,4 @@ +import logging import warnings from copy import deepcopy from pathlib import Path @@ -25,6 +26,8 @@ from trajdata.dataset_specific.vod import vod_utils from trajdata.maps import VectorMap +logger = logging.getLogger(__name__) + class VODDataset(RawDataset): def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: @@ -86,7 +89,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) if self.name == "vod_trainval": version_str = "v1.0-trainval" diff --git a/src/trajdata/dataset_specific/vod/vod_utils.py b/src/trajdata/dataset_specific/vod/vod_utils.py index a4050e8..69e1db2 100644 --- a/src/trajdata/dataset_specific/vod/vod_utils.py +++ b/src/trajdata/dataset_specific/vod/vod_utils.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Dict, Final, List, Tuple, Union import numpy as np @@ -21,6 +22,8 @@ ) from trajdata.utils import arr_utils, map_utils +logger = logging.getLogger(__name__) + VOD_DT: Final[float] = 0.1 @@ -60,7 +63,7 @@ def agg_agent_data( ) -> Agent: """Loops through all annotations of a specific agent in a scene and aggregates their data into an Agent object.""" if agent_data["prev"]: - print("WARN: This is not the first frame of this agent!") + logger.warning("This is not the first frame of this agent!") translation_list = [np.array(agent_data["translation"][:3])[np.newaxis]] agent_size = agent_data["size"] diff --git a/src/trajdata/dataset_specific/waymo/waymo_dataset.py b/src/trajdata/dataset_specific/waymo/waymo_dataset.py index 4497a79..5fc8d73 100644 --- a/src/trajdata/dataset_specific/waymo/waymo_dataset.py +++ b/src/trajdata/dataset_specific/waymo/waymo_dataset.py @@ -1,3 +1,4 @@ +import logging import os from collections import defaultdict from functools import partial @@ -44,6 +45,8 @@ from trajdata.utils import arr_utils from trajdata.utils.parallel_utils import parallel_apply +logger = logging.getLogger(__name__) + def const_lambda(const_val: Any) -> Any: return const_val @@ -76,7 +79,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: - print(f"Loading {self.name} dataset...", flush=True) + logger.info("Loading %s dataset...", self.name) dataset_name: str = "" if self.name == "waymo_train": dataset_name = "training" diff --git a/src/trajdata/dataset_specific/waymo/waymo_utils.py b/src/trajdata/dataset_specific/waymo/waymo_utils.py index 83cc210..5c29883 100644 --- a/src/trajdata/dataset_specific/waymo/waymo_utils.py +++ b/src/trajdata/dataset_specific/waymo/waymo_utils.py @@ -1,3 +1,4 @@ +import logging import os from pathlib import Path from subprocess import check_call, check_output @@ -14,6 +15,8 @@ from trajdata.maps import TrafficLightStatus, VectorMap from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadLane +logger = logging.getLogger(__name__) + WAYMO_DT: Final[float] = 0.1 WAYMO_DATASET_NAMES = [ "testing", @@ -87,14 +90,20 @@ def __init__( def download_dataset(self) -> None: # check_call("snap 
install google-cloud-sdk --classic".split())
-        gsutil = check_output(["which", "gsutil"])
-        download_cmd = (
-            str(gsutil.decode("utf-8"))
-            + "-m cp -r gs://waymo_open_dataset_motion_v_1_1_0/uncompressed/scenario/"
-            + str(self.name)
-            + " "
-            + str(self.source_dir)
-        ).split()
+        # ``which`` exits non-zero when gsutil is absent, so check_output would
+        # raise before any friendly message could be shown; shutil.which returns
+        # None instead and lets us fail with actionable advice.
+        import shutil
+
+        gsutil_path = shutil.which("gsutil")
+        if gsutil_path is None:
+            raise RuntimeError(
+                "gsutil not found. Install the Google Cloud SDK before downloading Waymo data. "
+                "See: https://cloud.google.com/sdk/docs/install"
+            )
+        download_cmd: List[str] = [
+            gsutil_path,
+            "-m",
+            "cp",
+            "-r",
+            f"gs://waymo_open_dataset_motion_v_1_1_0/uncompressed/scenario/{self.name}",
+            str(self.source_dir),
+        ]
         check_call(download_cmd)

     def split_scenarios(
@@ -103,13 +112,13 @@
         source_it: Path = (self.source_dir / self.name).glob("*")
         file_names: List[str] = [str(file_name) for file_name in source_it]
         if verbose:
-            print("Loading tfrecord files...")
+            logger.info("Loading tfrecord files...")
         dataset = tf.data.TFRecordDataset(
             file_names, compression_type="", num_parallel_reads=num_parallel_reads
         )
         if verbose:
-            print("Splitting tfrecords...")
+            logger.info("Splitting tfrecords...")

         splitted_dir: Path = self.source_dir / f"{self.name}_splitted"
         if not splitted_dir.exists():
@@ -127,13 +136,11 @@
         self.num_scenarios = scenario_num
         if verbose:
-            print(
-                str(self.num_scenarios)
-                + " scenarios from "
-                + str(len(file_names))
-                + " file(s) have been split into "
-                + str(self.num_scenarios)
-                + " files."
+            logger.info(
+                "%d scenarios from %d file(s) have been split into %d files.",
+                self.num_scenarios,
+                len(file_names),
+                self.num_scenarios,
             )

     def get_filename(self, data_idx):
diff --git a/src/trajdata/io/__init__.py
new file mode 100644
index 0000000..cdb2508
--- /dev/null
+++ b/src/trajdata/io/__init__.py
@@ -0,0 +1,2 @@
+from .exporter import DataExporter
+from .precomputed_dataset import PrecomputedDataset
diff --git a/src/trajdata/io/exporter.py
new file mode 100644
index 0000000..4880cfc
--- /dev/null
+++ b/src/trajdata/io/exporter.py
@@ -0,0 +1,181 @@
+"""
+DataExporter: Precompute and save UnifiedDataset batches to disk for fast reloading.
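+
+A custom Zarr compressor can be passed through via ``compressor=`` – a sketch,
+assuming numcodecs (installed alongside zarr)::
+
+    from numcodecs import Blosc
+
+    DataExporter.export(dataset, "cache.zarr", compressor=Blosc(cname="zstd", clevel=5))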
+ +Supported formats: + - "zarr": Zarr compressed arrays (default, already a trajdata dependency) + - "numpy": Uncompressed .npy files (portable, no extra deps) + +Usage: + from trajdata.io import DataExporter + DataExporter.export(dataset, "my_cache.zarr", format="zarr", batch_size=64) +""" +import json +import logging +from pathlib import Path +from typing import List, Optional + +import numpy as np +import zarr +from torch.utils.data import DataLoader +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +# Fields from AgentBatch that we store (tensor → numpy) +_FLOAT_FIELDS = [ + "curr_agent_state", + "agent_hist", + "agent_fut", + "neigh_hist", + "neigh_fut", +] +_INT_FIELDS = [ + "agent_hist_len", + "agent_fut_len", + "neigh_hist_len", + "neigh_fut_len", + "neigh_types", + "agent_type", + "scene_ts", +] +_FLOAT_SCALAR = ["dt"] +_ALL_FIELDS = _FLOAT_FIELDS + _INT_FIELDS + _FLOAT_SCALAR + + +class DataExporter: + """Export a UnifiedDataset to a precomputed on-disk cache for fast reloading.""" + + @staticmethod + def export( + dataset, + output_path: str, + format: str = "zarr", + batch_size: int = 64, + num_workers: int = 0, + verbose: bool = True, + compressor: Optional[object] = None, + ) -> None: + """Export all samples in *dataset* to *output_path*. + + Args: + dataset: A :class:`~trajdata.UnifiedDataset` instance (agent-centric). + output_path: Directory/file path for the exported data. + format: ``"zarr"`` or ``"numpy"``. + batch_size: Batch size to use during export iteration. + num_workers: Worker count for the temporary DataLoader. + verbose: Show progress bar. + compressor: Custom Zarr compressor (``None`` → Zarr default Blosc). + """ + if format not in ("zarr", "numpy"): + raise ValueError(f"Unsupported format '{format}'. Choose 'zarr' or 'numpy'.") + + output_path = Path(output_path) + + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=dataset.get_collate_fn(), + num_workers=num_workers, + ) + + if format == "zarr": + DataExporter._export_zarr(loader, output_path, len(dataset), verbose, compressor) + else: + DataExporter._export_numpy(loader, output_path, len(dataset), verbose) + + if verbose: + logger.info("Export complete → %s", output_path) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _batch_to_numpy(batch) -> dict: + """Convert an AgentBatch to a dict of numpy arrays.""" + result = {} + for field in _ALL_FIELDS: + tensor = getattr(batch, field, None) + if tensor is None: + continue + try: + result[field] = tensor.numpy() + except Exception: + pass + # agent_name is a list of strings + if hasattr(batch, "agent_name") and batch.agent_name is not None: + result["agent_name"] = np.array(batch.agent_name, dtype=object) + return result + + @staticmethod + def _export_zarr(loader, output_path: Path, total: int, verbose: bool, compressor) -> None: + store = zarr.open(str(output_path), mode="w") + arrays = {} # field → zarr Array (created on first batch) + idx = 0 + + for batch in tqdm(loader, desc="Exporting (zarr)", disable=not verbose): + np_data = DataExporter._batch_to_numpy(batch) + n = next(iter(np_data.values())).shape[0] + + for field, arr in np_data.items(): + if field not in arrays: + shape = (total,) + arr.shape[1:] + dtype = arr.dtype if arr.dtype != object else str + arrays[field] = store.zeros( + field, + shape=shape, + dtype=dtype, + chunks=(min(64, total),) + 
arr.shape[1:], + compressor=compressor, + object_codec=zarr.codecs.VLenUTF8() if dtype == str else None, + ) + try: + arrays[field][idx : idx + n] = arr + except Exception: + pass + + idx += n + + store.attrs["total_samples"] = idx + store.attrs["fields"] = list(arrays.keys()) + + @staticmethod + def _pad_ragged(chunks: list) -> list: + """Pad ragged arrays along axis=1 (neighbour dimension) to the same size.""" + max_n = max(c.shape[1] for c in chunks) + padded = [] + for c in chunks: + pad_width = [(0, 0)] * c.ndim + pad_width[1] = (0, max_n - c.shape[1]) + padded.append(np.pad(c, pad_width)) + return padded + + @staticmethod + def _export_numpy(loader, output_path: Path, total: int, verbose: bool) -> None: + output_path.mkdir(parents=True, exist_ok=True) + buffers = {} # field → list of batches + + for batch in tqdm(loader, desc="Exporting (numpy)", disable=not verbose): + np_data = DataExporter._batch_to_numpy(batch) + for field, arr in np_data.items(): + buffers.setdefault(field, []).append(arr) + + # Neighbour fields are ragged (max_neigh varies per batch) → pad to max + _NEIGH_FIELDS = {"neigh_hist", "neigh_fut", "neigh_hist_len", "neigh_fut_len", "neigh_types"} + for field in list(buffers.keys()): + if field in _NEIGH_FIELDS: + buffers[field] = DataExporter._pad_ragged(buffers[field]) + + meta = {"total_samples": 0, "fields": []} + for field, chunks in buffers.items(): + try: + combined = np.concatenate(chunks, axis=0) + np.save(output_path / f"{field}.npy", combined) + meta["fields"].append(field) + meta["total_samples"] = combined.shape[0] + except Exception as e: + logger.warning("Skipping field '%s': %s", field, e) + + with open(output_path / "metadata.json", "w") as f: + json.dump(meta, f) diff --git a/src/trajdata/io/precomputed_dataset.py b/src/trajdata/io/precomputed_dataset.py new file mode 100644 index 0000000..55ade43 --- /dev/null +++ b/src/trajdata/io/precomputed_dataset.py @@ -0,0 +1,109 @@ +""" +PrecomputedDataset: Fast PyTorch Dataset that reads from a precomputed cache. + +Usage: + from trajdata.io import DataExporter, PrecomputedDataset + + # 1. Export once + DataExporter.export(original_dataset, "cache.zarr", format="zarr") + + # 2. Load fast (no on-the-fly preprocessing) + fast_ds = PrecomputedDataset("cache.zarr", format="zarr") + loader = DataLoader(fast_ds, batch_size=64, shuffle=True) +""" +import json +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +import torch +import zarr +from torch.utils.data import Dataset + + +class PrecomputedDataset(Dataset): + """PyTorch Dataset backed by a precomputed cache (zarr or numpy). + + Each ``__getitem__`` returns a dict of tensors corresponding to the + fields stored by :class:`~trajdata.io.DataExporter`. + """ + + def __init__(self, path: str, format: str = "zarr", fields: Optional[List[str]] = None): + """ + Args: + path: Path to the exported cache (zarr store or numpy directory). + format: ``"zarr"`` or ``"numpy"``. + fields: Optional subset of field names to load. ``None`` loads all. 
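+
+        Example (a minimal sketch; ``"cache.zarr"`` is a hypothetical path
+        produced by an earlier ``DataExporter.export`` call)::
+
+            from torch.utils.data import DataLoader
+
+            fast_ds = PrecomputedDataset("cache.zarr", format="zarr",
+                                         fields=["agent_hist", "agent_fut"])
+            loader = DataLoader(fast_ds, batch_size=64, shuffle=True,
+                                collate_fn=PrecomputedDataset.collate_fn)
+            sample = fast_ds[0]  # dict of tensors, one entry per loaded field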
+ """ + self.path = Path(path) + self.format = format + self._data: Dict[str, np.ndarray] = {} + + if format == "zarr": + self._load_zarr(fields) + elif format == "numpy": + self._load_numpy(fields) + else: + raise ValueError(f"Unsupported format '{format}'") + + self._len = next(iter(self._data.values())).shape[0] if self._data else 0 + + # ------------------------------------------------------------------ + + def _load_zarr(self, fields: Optional[List[str]]) -> None: + store = zarr.open(str(self.path), mode="r") + available = list(store.attrs.get("fields", store.array_keys())) + to_load = fields if fields is not None else available + for field in to_load: + if field in store: + self._data[field] = store[field] # lazy-loaded zarr array + + def _load_numpy(self, fields: Optional[List[str]]) -> None: + meta_path = self.path / "metadata.json" + if meta_path.exists(): + with open(meta_path) as f: + meta = json.load(f) + available = meta.get("fields", []) + else: + available = [p.stem for p in self.path.glob("*.npy")] + + to_load = fields if fields is not None else available + for field in to_load: + npy_path = self.path / f"{field}.npy" + if npy_path.exists(): + self._data[field] = np.load(npy_path, allow_pickle=True) + + # ------------------------------------------------------------------ + + def __len__(self) -> int: + return self._len + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + sample = {} + for field, arr in self._data.items(): + val = arr[idx] + if isinstance(val, np.ndarray) and val.dtype.kind in ("f", "i", "u"): + sample[field] = torch.from_numpy(np.array(val)) + else: + sample[field] = val # strings, objects etc. pass through + return sample + + @property + def fields(self) -> List[str]: + """List of available field names.""" + return list(self._data.keys()) + + @staticmethod + def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]: + """Default collate: stacks tensor fields, lists others.""" + result = {} + for key in batch[0]: + vals = [item[key] for item in batch] + if isinstance(vals[0], torch.Tensor): + try: + result[key] = torch.stack(vals, dim=0) + except RuntimeError: + result[key] = vals + else: + result[key] = vals + return result diff --git a/src/trajdata/simulation/__init__.py b/src/trajdata/simulation/__init__.py index 1cda25c..6d026b0 100644 --- a/src/trajdata/simulation/__init__.py +++ b/src/trajdata/simulation/__init__.py @@ -1 +1,3 @@ +from .sim_metrics import ADE, FDE, CollisionMetric, OffRoadRate +from .sim_scenarios import ConstantVelocityPolicy, RandomWalkPolicy, SimRunner from .sim_scene import SimulationScene diff --git a/src/trajdata/simulation/sim_metrics.py b/src/trajdata/simulation/sim_metrics.py index 9f1603d..787140f 100644 --- a/src/trajdata/simulation/sim_metrics.py +++ b/src/trajdata/simulation/sim_metrics.py @@ -30,3 +30,87 @@ def __call__(self, gt_df: pd.DataFrame, sim_df: pd.DataFrame) -> Dict[str, float err_df = pd.DataFrame(index=gt_df.index, columns=["error"]) err_df["error"] = np.linalg.norm(gt_df[["x", "y"]] - sim_df[["x", "y"]], axis=1) return err_df.groupby("agent_id")["error"].last().to_dict() + + +class CollisionMetric(SimMetric): + """Detect pairwise agent collisions based on Euclidean distance threshold. + + Returns per-agent collision rate (fraction of timesteps in which the + agent is within ``distance_thresh`` of at least one other agent). + + Args: + distance_thresh: Distance (metres) below which two agents are considered + to have collided (default 0.5 m). 
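+
+    Example (a sketch; assumes ``gt_df`` / ``sim_df`` are pandas frames
+    indexed by ``(agent_id, scene_ts)`` with ``x`` / ``y`` columns, as for
+    the other metrics in this module)::
+
+        metric = CollisionMetric(distance_thresh=1.0)
+        rates = metric(gt_df, sim_df)  # e.g. {"agent_0": 0.05, ...}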
+ """ + + def __init__(self, distance_thresh: float = 0.5) -> None: + super().__init__("collision_rate") + self.distance_thresh = distance_thresh + + def __call__(self, gt_df: pd.DataFrame, sim_df: pd.DataFrame) -> Dict[str, float]: + # sim_df index: (agent_id, scene_ts) + results: Dict[str, float] = {} + + sim_reset = sim_df.reset_index() + agent_ids = sim_reset["agent_id"].unique() + + for agent_id in agent_ids: + agent_ts = sim_reset[sim_reset["agent_id"] == agent_id][ + ["scene_ts", "x", "y"] + ].set_index("scene_ts") + others = sim_reset[sim_reset["agent_id"] != agent_id] + + collision_ts = set() + for _, row in others.iterrows(): + ts = row["scene_ts"] + if ts not in agent_ts.index: + continue + dx = agent_ts.loc[ts, "x"] - row["x"] + dy = agent_ts.loc[ts, "y"] - row["y"] + if np.sqrt(dx**2 + dy**2) < self.distance_thresh: + collision_ts.add(ts) + + n_ts = len(agent_ts) + results[str(agent_id)] = len(collision_ts) / n_ts if n_ts > 0 else 0.0 + + return results + + +class OffRoadRate(SimMetric): + """Fraction of timesteps an agent spends outside a bounding box. + + Useful as a proxy for off-road / out-of-bounds detection when no map is + available. Provide ``scene_bounds`` as ``(x_min, x_max, y_min, y_max)``. + + Args: + scene_bounds: Tuple ``(x_min, x_max, y_min, y_max)`` defining the + valid region. Derived automatically from the ground-truth + trajectory if not provided. + """ + + def __init__(self, scene_bounds=None) -> None: + super().__init__("off_road_rate") + self.scene_bounds = scene_bounds + + def __call__(self, gt_df: pd.DataFrame, sim_df: pd.DataFrame) -> Dict[str, float]: + if self.scene_bounds is not None: + x_min, x_max, y_min, y_max = self.scene_bounds + else: + x_min = gt_df["x"].min() - 5.0 + x_max = gt_df["x"].max() + 5.0 + y_min = gt_df["y"].min() - 5.0 + y_max = gt_df["y"].max() + 5.0 + + results: Dict[str, float] = {} + sim_reset = sim_df.reset_index() + + for agent_id, grp in sim_reset.groupby("agent_id"): + out_of_bounds = ( + (grp["x"] < x_min) + | (grp["x"] > x_max) + | (grp["y"] < y_min) + | (grp["y"] > y_max) + ) + results[str(agent_id)] = out_of_bounds.mean() + + return results diff --git a/src/trajdata/simulation/sim_scenarios.py b/src/trajdata/simulation/sim_scenarios.py new file mode 100644 index 0000000..832b5b3 --- /dev/null +++ b/src/trajdata/simulation/sim_scenarios.py @@ -0,0 +1,174 @@ +""" +Simulation scenario runners and baseline policies. + +A *policy* is any callable with signature:: + + policy(obs: AgentBatch) -> Dict[str, StateArray] + +where the returned dict maps ``agent_name → new_xyzh_state`` for the next +timestep. + +Built-in policies +----------------- +* :class:`ConstantVelocityPolicy` – each agent continues at its last observed + velocity (dead-reckoning). +* :class:`RandomWalkPolicy` – adds Gaussian noise to the last position. + +High-level runner +----------------- +:class:`SimRunner` wraps a :class:`~trajdata.simulation.sim_scene.SimulationScene` +and runs it for a fixed number of steps, collecting observations and metrics. 
+ +Usage:: + + from trajdata.simulation import SimulationScene + from trajdata.simulation.sim_scenarios import SimRunner, ConstantVelocityPolicy + from trajdata.simulation.sim_metrics import ADE, FDE + + policy = ConstantVelocityPolicy() + runner = SimRunner(sim_scene, policy, max_steps=20) + results = runner.run(metrics=[ADE(), FDE()]) + print(results) # {"ade": {"agent_0": 1.2, ...}, "fde": {...}} +""" +from typing import Callable, Dict, List, Optional + +import numpy as np + +from trajdata.data_structures.state import StateArray +from trajdata.simulation.sim_metrics import SimMetric +from trajdata.simulation.sim_stats import SimStatistic + + +# --------------------------------------------------------------------------- +# Built-in policies +# --------------------------------------------------------------------------- + +class ConstantVelocityPolicy: + """Propagate each agent at its last observed velocity (linear extrapolation). + + The policy reads the current-agent observation from the batch and steps + each agent forward by ``dt * velocity``. + """ + + def __call__(self, obs, dt: float) -> Dict[str, "StateArray"]: + """ + Args: + obs: :class:`~trajdata.AgentBatch` from ``SimulationScene.get_obs()``. + dt: Simulation time-step in seconds. + + Returns: + Dict mapping ``agent_name → next StateArray`` in ``"x,y,z,h"`` format. + """ + result = {} + names = obs.agent_name if not isinstance(obs, dict) else obs["agent_name"] + states = obs.curr_agent_state if not isinstance(obs, dict) else obs["curr_agent_state"] + + for i, name in enumerate(names): + s = states[i].numpy() # shape [D]; internal format x,y,z,xd,yd,xdd,ydd,h + x = float(s[0]) if s.shape[0] > 0 else 0.0 + y = float(s[1]) if s.shape[0] > 1 else 0.0 + z = float(s[2]) if s.shape[0] > 2 else 0.0 + vx = float(s[3]) if s.shape[0] > 3 else 0.0 + vy = float(s[4]) if s.shape[0] > 4 else 0.0 + h = float(s[7]) if s.shape[0] > 7 else float(np.arctan2(vy, vx)) + + next_state = np.array([x + vx * dt, y + vy * dt, z, h], dtype=np.float64) + result[name] = StateArray.from_array(next_state, "x,y,z,h") + return result + + +class RandomWalkPolicy: + """Each agent takes a random step drawn from a Gaussian. + + Args: + stddev: Standard deviation of position perturbation in metres (default 0.1). + seed: Optional random seed for reproducibility. + """ + + def __init__(self, stddev: float = 0.1, seed: Optional[int] = None) -> None: + self.stddev = stddev + self.rng = np.random.default_rng(seed) + + def __call__(self, obs, dt: float) -> Dict[str, "StateArray"]: + result = {} + names = obs.agent_name if not isinstance(obs, dict) else obs["agent_name"] + states = obs.curr_agent_state if not isinstance(obs, dict) else obs["curr_agent_state"] + + for i, name in enumerate(names): + s = states[i].numpy() + x = float(s[0]) if s.shape[0] > 0 else 0.0 + y = float(s[1]) if s.shape[0] > 1 else 0.0 + z = float(s[2]) if s.shape[0] > 2 else 0.0 + h = float(s[7]) if s.shape[0] > 7 else 0.0 + noise = self.rng.normal(0.0, self.stddev, size=2) + next_state = np.array([x + noise[0], y + noise[1], z, h], dtype=np.float64) + result[name] = StateArray.from_array(next_state, "x,y,z,h") + return result + + +# --------------------------------------------------------------------------- +# High-level runner +# --------------------------------------------------------------------------- + +PolicyFn = Callable # (obs, dt) → Dict[str, StateArray] + + +class SimRunner: + """Run a :class:`~trajdata.simulation.sim_scene.SimulationScene` with a policy. 
+ + Args: + sim_scene: An initialised ``SimulationScene`` instance. + policy: A callable ``(obs, dt) → Dict[agent_name, StateArray]``. + max_steps: Maximum number of simulation steps to run. + dt: Override the scene time-step; if ``None`` the scene's own dt is used. + """ + + def __init__( + self, + sim_scene, + policy: PolicyFn, + max_steps: int = 50, + dt: Optional[float] = None, + ) -> None: + self.sim_scene = sim_scene + self.policy = policy + self.max_steps = max_steps + self._dt = dt + + @property + def dt(self) -> float: + if self._dt is not None: + return self._dt + return self.sim_scene.scene.dt + + def run( + self, + metrics: Optional[List[SimMetric]] = None, + stats: Optional[List[SimStatistic]] = None, + verbose: bool = False, + ) -> dict: + """Execute the simulation loop. + + Returns: + A dict with keys ``"metrics"`` and/or ``"stats"`` containing the + computed values, plus ``"steps"`` (number of steps executed). + """ + obs = self.sim_scene.reset() + steps = 0 + + for step in range(self.max_steps): + action = self.policy(obs, self.dt) + obs = self.sim_scene.step(action, return_obs=True) + steps += 1 + if verbose: + print(f"Step {step + 1}/{self.max_steps}") + + self.sim_scene.finalize() + + results: dict = {"steps": steps} + if metrics: + results["metrics"] = self.sim_scene.get_metrics(metrics) + if stats: + results["stats"] = self.sim_scene.get_stats(stats) + + return results diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py index 49c22ec..dc5af7b 100644 --- a/src/trajdata/utils/env_utils.py +++ b/src/trajdata/utils/env_utils.py @@ -55,6 +55,11 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: return Av2Dataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + if dataset_name.startswith("csv_"): + from trajdata.dataset_specific.csv_dataset import CSVDataset + + return CSVDataset(dataset_name, data_dir, parallelizable=True, has_maps=False) + raise ValueError(f"Dataset with name '{dataset_name}' is not supported") diff --git a/trajdata_webui/__init__.py b/trajdata_webui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trajdata_webui/app_state.py b/trajdata_webui/app_state.py new file mode 100644 index 0000000..68da624 --- /dev/null +++ b/trajdata_webui/app_state.py @@ -0,0 +1,49 @@ +"""Shared per-session state passed between all tabs.""" +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class AppState: + # Loaded dataset + dataset: Optional[Any] = None + dataset_split: str = "eupeds_eth-train" + + # Visualization + current_sample_idx: int = 0 + + # Augmentation config (mirrors widget values) + aug_config: Dict[str, Any] = field(default_factory=lambda: { + "mirror": False, + "mirror_axis": "x", + "mirror_prob": 0.5, + "speed_scale": False, + "speed_min": 0.8, + "speed_max": 1.2, + "motion_labeler": False, + "stationary_thresh": 0.5, + "walking_thresh": 2.5, + "running_thresh": 6.0, + }) + + # Simulation config + sim_config: Dict[str, Any] = field(default_factory=lambda: { + "scene_idx": 0, + "max_steps": 30, + "metrics": ["ADE", "FDE", "Collision"], + }) + + # Export config + export_config: Dict[str, Any] = field(default_factory=lambda: { + "output_path": "~/trajdata_export", + "format": "zarr", + "batch_size": 64, + }) + + # UI State + active_panel: str = "dashboard" + lang: str = "en" + theme: str = "dark" + + # Status messages (set by background threads, read by UI) + status: str = "Ready. Load a dataset to begin." 
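+
+
+# Usage sketch (hypothetical snippet; trajdata_webui.ui creates one AppState
+# per Bokeh session and passes it to every tab builder, which read and mutate
+# the shared config dicts in place):
+#
+#     state = AppState()
+#     state.aug_config["mirror"] = True           # Augment tab toggle
+#     state.sim_config["max_steps"] = 50          # Simulate tab slider
+#     state.status = "Loading eupeds_eth-train…"  # surfaced to the UI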
diff --git a/trajdata_webui/backend/__init__.py b/trajdata_webui/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trajdata_webui/backend/aug_preview.py b/trajdata_webui/backend/aug_preview.py new file mode 100644 index 0000000..edbfe78 --- /dev/null +++ b/trajdata_webui/backend/aug_preview.py @@ -0,0 +1,30 @@ +"""Compute before/after augmentation trajectory sources.""" +from typing import Any, Dict, Tuple + +from trajdata import UnifiedDataset +from trajdata.data_structures.collation import agent_collate_fn + +from .dataset_loader import build_augmentations +from .traj_renderer import batch_to_sources + + +def compute_preview( + dataset: UnifiedDataset, + sample_idx: int, + aug_config: Dict[str, Any], +) -> Tuple[Dict, Dict, Dict, Dict]: + """Return (base_hist, base_fut, aug_hist, aug_fut) ColumnDataSource dicts.""" + elem = dataset[sample_idx] + base_batch = agent_collate_fn([elem], return_dict=False, pad_format="outside") + base_hist, base_fut = batch_to_sources(base_batch) + + # Apply augmentations to a fresh copy of the batch element + aug_batch = agent_collate_fn([elem], return_dict=False, pad_format="outside") + for aug in build_augmentations(aug_config): + try: + aug.apply_agent(aug_batch) + except Exception: + pass + aug_hist, aug_fut = batch_to_sources(aug_batch) + + return base_hist, base_fut, aug_hist, aug_fut diff --git a/trajdata_webui/backend/dataset_loader.py b/trajdata_webui/backend/dataset_loader.py new file mode 100644 index 0000000..e52fe42 --- /dev/null +++ b/trajdata_webui/backend/dataset_loader.py @@ -0,0 +1,63 @@ +"""Build a UnifiedDataset from a split name + aug config.""" +from collections import defaultdict +from typing import Any, Dict, List, Optional + +from trajdata import AgentType, UnifiedDataset +from trajdata.augmentation import ( + MirrorAugmentation, + MotionTypeLabeler, + SpeedScaleAugmentation, +) + +# Datasets available without extra downloads +AVAILABLE_SPLITS = [ + "eupeds_eth-train", + "eupeds_eth-val", + "eupeds_hotel-train", + "eupeds_hotel-val", + "eupeds_univ-train", + "eupeds_univ-val", + "eupeds_zara1-train", + "eupeds_zara1-val", + "eupeds_zara2-train", + "eupeds_zara2-val", +] + +DATA_DIRS = {"eupeds_eth": "~/datasets/eth_ucy", + "eupeds_hotel": "~/datasets/eth_ucy", + "eupeds_univ": "~/datasets/eth_ucy", + "eupeds_zara1": "~/datasets/eth_ucy", + "eupeds_zara2": "~/datasets/eth_ucy"} + + +def build_augmentations(cfg: Dict[str, Any]) -> list: + augs = [] + if cfg.get("mirror"): + augs.append(MirrorAugmentation(axis=cfg["mirror_axis"], prob=cfg["mirror_prob"])) + if cfg.get("speed_scale"): + augs.append(SpeedScaleAugmentation(cfg["speed_min"], cfg["speed_max"])) + if cfg.get("motion_labeler"): + augs.append(MotionTypeLabeler( + stationary_thresh=cfg["stationary_thresh"], + walking_thresh=cfg["walking_thresh"], + running_thresh=cfg["running_thresh"], + )) + return augs + + +def load_dataset(split: str, aug_config: Dict[str, Any]) -> UnifiedDataset: + """Build and return a UnifiedDataset for *split*.""" + augs = build_augmentations(aug_config) + dataset = UnifiedDataset( + desired_data=[split], + centric="agent", + desired_dt=0.4, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.PEDESTRIAN], + augmentations=augs, + num_workers=0, + verbose=False, + data_dirs=DATA_DIRS, + ) + return dataset diff --git a/trajdata_webui/backend/sim_runner.py b/trajdata_webui/backend/sim_runner.py new file mode 100644 index 0000000..88db102 --- /dev/null +++ b/trajdata_webui/backend/sim_runner.py @@ -0,0 +1,63 
@@ +"""Thin wrapper around SimRunner for the web UI.""" +from typing import Any, Dict, List + +import numpy as np + +from trajdata import UnifiedDataset +from trajdata.simulation import ( + ADE, FDE, CollisionMetric, OffRoadRate, + SimulationScene, ConstantVelocityPolicy, SimRunner, +) + + +def run_simulation( + dataset: UnifiedDataset, + scene_idx: int, + max_steps: int, + metric_names: List[str], +) -> Dict[str, Any]: + """Run a constant-velocity simulation and return formatted results.""" + scene = dataset.get_scene(scene_idx) + + sim_scene = SimulationScene( + env_name="webui_sim", + scene_name=f"sim_scene_{scene_idx:04d}", + scene=scene, + dataset=dataset, + init_timestep=0, + freeze_agents=True, + ) + + metric_map = { + "ADE": ADE(), + "FDE": FDE(), + "Collision": CollisionMetric(distance_thresh=1.0), + "OffRoad": OffRoadRate(), + } + metrics = [metric_map[m] for m in metric_names if m in metric_map] + + policy = ConstantVelocityPolicy() + runner = SimRunner(sim_scene, policy, max_steps=max_steps) + raw = runner.run(metrics=metrics) + + # Format: list of rows {agent, metric_name: value} + rows = [] + if "metrics" in raw: + # Collect all agents + all_agents = set() + for per_agent in raw["metrics"].values(): + all_agents.update(per_agent.keys()) + for agent in sorted(all_agents): + row: Dict[str, Any] = {"agent": agent} + for mname, per_agent in raw["metrics"].items(): + row[mname] = round(float(per_agent.get(agent, float("nan"))), 4) + rows.append(row) + + # Aggregate means + means: Dict[str, float] = {} + if "metrics" in raw: + for mname, per_agent in raw["metrics"].items(): + vals = [v for v in per_agent.values() if not np.isnan(v)] + means[mname] = round(float(np.mean(vals)), 4) if vals else float("nan") + + return {"rows": rows, "means": means, "steps": raw.get("steps", max_steps)} diff --git a/trajdata_webui/backend/stats_computer.py b/trajdata_webui/backend/stats_computer.py new file mode 100644 index 0000000..dde5428 --- /dev/null +++ b/trajdata_webui/backend/stats_computer.py @@ -0,0 +1,56 @@ +"""Compute summary statistics for a loaded UnifiedDataset.""" +from collections import Counter +from typing import Any, Dict + +import numpy as np +from torch.utils.data import DataLoader + +from trajdata import AgentType, UnifiedDataset + + +def compute_stats(dataset: UnifiedDataset, max_batches: int = 30) -> Dict[str, Any]: + """Return a dict of human-readable statistics for *dataset*.""" + stats: Dict[str, Any] = {} + stats["total_samples"] = len(dataset) + stats["num_scenes"] = dataset.num_scenes() + + # Scene info from first scene + try: + scene0 = dataset.get_scene(0) + stats["dt_s"] = round(scene0.dt, 3) + stats["scene0_name"] = scene0.name + stats["scene0_timesteps"] = scene0.length_timesteps + stats["scene0_agents"] = len(scene0.agents) + except Exception: + stats["dt_s"] = "?" + stats["scene0_name"] = "?" + stats["scene0_timesteps"] = "?" + stats["scene0_agents"] = "?" 
+ + # Agent type distribution (from a small loader sample) + type_counter: Counter = Counter() + hist_lens = [] + fut_lens = [] + + loader = DataLoader( + dataset, + batch_size=32, + shuffle=False, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + for i, batch in enumerate(loader): + if i >= max_batches: + break + for t in batch.agent_type.tolist(): + try: + type_counter[AgentType(t).name] += 1 + except ValueError: + type_counter["UNKNOWN"] += 1 + hist_lens.extend(batch.agent_hist_len.tolist()) + fut_lens.extend(batch.agent_fut_len.tolist()) + + stats["agent_type_counts"] = dict(type_counter) + stats["mean_hist_len"] = round(float(np.mean(hist_lens)), 2) if hist_lens else 0 + stats["mean_fut_len"] = round(float(np.mean(fut_lens)), 2) if fut_lens else 0 + return stats diff --git a/trajdata_webui/backend/traj_renderer.py b/trajdata_webui/backend/traj_renderer.py new file mode 100644 index 0000000..9c8897e --- /dev/null +++ b/trajdata_webui/backend/traj_renderer.py @@ -0,0 +1,97 @@ +"""Convert AgentBatch → Bokeh ColumnDataSource dicts for trajectory plotting.""" +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from torch.utils.data import DataLoader + +from trajdata import AgentBatch, UnifiedDataset +from trajdata.data_structures.collation import agent_collate_fn +from trajdata.utils import vis_utils + +# Colour used when type lookup fails +_FALLBACK_COLOR = "#888888" + + +def _agent_color(agent_type_val: int) -> str: + try: + return vis_utils.get_agent_type_color(agent_type_val) + except Exception: + return _FALLBACK_COLOR + + +def sample_to_batch(dataset: UnifiedDataset, idx: int) -> AgentBatch: + """Load a single sample and collate it into a 1-element AgentBatch.""" + elem = dataset[idx] + return agent_collate_fn([elem], return_dict=False, pad_format="outside") + + +def batch_to_sources( + batch: AgentBatch, + batch_idx: int = 0, + show_hist: bool = True, + show_fut: bool = True, + show_neighbors: bool = True, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Return two ColumnDataSource data dicts: (hist_data, fut_data). + + Each dict has keys: xs, ys, line_color, line_dash, legend_label. 
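+
+    Example (a sketch; shapes and values are illustrative)::
+
+        from bokeh.models import ColumnDataSource
+
+        hist_data, fut_data = batch_to_sources(batch)
+        hist_src = ColumnDataSource(hist_data)  # drives a MultiLine glyph
+        # hist_data["xs"] == [[x0, x1, ...], ...], one polyline per agent,
+        # parallel to "ys", "line_color", "line_dash" and "legend_label"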
+ """ + agent_type_val: int = batch.agent_type[batch_idx].item() + num_neigh: int = batch.num_neigh[batch_idx].item() + + hist_xs, hist_ys, hist_colors, hist_dashes, hist_labels = [], [], [], [], [] + fut_xs, fut_ys, fut_colors, fut_dashes, fut_labels = [], [], [], [], [] + + # ── ego agent ────────────────────────────────────────────────────── + if show_hist: + h = batch.agent_hist[batch_idx].cpu().numpy() + hist_xs.append(_safe_xy(h, "x")) + hist_ys.append(_safe_xy(h, "y")) + hist_colors.append(_agent_color(agent_type_val)) + hist_dashes.append("dashed") + hist_labels.append("Ego history") + + if show_fut: + f = batch.agent_fut[batch_idx].cpu().numpy() + fut_xs.append(_safe_xy(f, "x")) + fut_ys.append(_safe_xy(f, "y")) + fut_colors.append(_agent_color(agent_type_val)) + fut_dashes.append("solid") + fut_labels.append("Ego future") + + # ── neighbors ────────────────────────────────────────────────────── + if show_neighbors and num_neigh > 0: + neigh_types = batch.neigh_types[batch_idx].cpu().numpy() + for n in range(num_neigh): + c = _agent_color(int(neigh_types[n])) + if show_hist: + nh = batch.neigh_hist[batch_idx, n].cpu().numpy() + hist_xs.append(_safe_xy(nh, "x")) + hist_ys.append(_safe_xy(nh, "y")) + hist_colors.append(c) + hist_dashes.append("dashed") + hist_labels.append(f"Neigh {n} hist") + if show_fut: + nf = batch.neigh_fut[batch_idx, n].cpu().numpy() + fut_xs.append(_safe_xy(nf, "x")) + fut_ys.append(_safe_xy(nf, "y")) + fut_colors.append(c) + fut_dashes.append("solid") + fut_labels.append(f"Neigh {n} fut") + + hist_data = dict(xs=hist_xs, ys=hist_ys, + line_color=hist_colors, line_dash=hist_dashes, + legend_label=hist_labels) + fut_data = dict(xs=fut_xs, ys=fut_ys, + line_color=fut_colors, line_dash=fut_dashes, + legend_label=fut_labels) + return hist_data, fut_data + + +def _safe_xy(state_np, attr: str) -> List[float]: + """Return a list of floats for *attr* from a numpy StateArray; mask NaNs.""" + try: + vals = state_np.get_attr(attr) + return [float(v) for v in vals if not np.isnan(v)] + except Exception: + return [] diff --git a/trajdata_webui/main.py b/trajdata_webui/main.py new file mode 100644 index 0000000..8c6ac44 --- /dev/null +++ b/trajdata_webui/main.py @@ -0,0 +1,51 @@ +""" +trajdata Web UI – Tornado-based Bokeh Server at root path /. 
+ +Run: + python trajdata_webui/main.py [--port 5006] [--no-browser] +""" +import argparse +import sys +from pathlib import Path + +# Make src/ and project root importable +_webui_dir = Path(__file__).parent +_project_root = _webui_dir.parent +for _p in [str(_project_root), str(_project_root / "src")]: + if _p not in sys.path: + sys.path.insert(0, _p) + +from bokeh.application import Application +from bokeh.application.handlers.function import FunctionHandler +from bokeh.server.server import Server +from tornado.ioloop import IOLoop + +from trajdata_webui.ui import build_ui + + +def modify_doc(doc): + build_ui(doc) + + +def main(): + parser = argparse.ArgumentParser(description="trajdata Web UI") + parser.add_argument("--port", type=int, default=5006) + parser.add_argument("--no-browser", action="store_true") + args = parser.parse_args() + + app = Application(FunctionHandler(modify_doc)) + server = Server({"/": app}, port=args.port, num_procs=1) + server.start() + + if not args.no_browser: + server.io_loop.add_callback(server.show, "/") + + print(f"\n trajdata Web UI → http://localhost:{args.port}/\n") + try: + server.io_loop.start() + except KeyboardInterrupt: + print("\nShutting down.") + + +if __name__ == "__main__": + main() diff --git a/trajdata_webui/ui.py b/trajdata_webui/ui.py new file mode 100644 index 0000000..3a7a632 --- /dev/null +++ b/trajdata_webui/ui.py @@ -0,0 +1,1398 @@ +""" +trajdata Web UI – full layout: header · sidebar · main content · footer. + +Sections +-------- + Dashboard – overview cards + Dataset – load & stats + Visualize – interactive trajectory viewer + Augment – before/after augmentation preview + Simulate – simulation runner + metrics table + Export – precomputed cache export +""" +from __future__ import annotations + +import sys +from functools import partial +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any, Dict, List + +# ── Bokeh ────────────────────────────────────────────────────────────────── +from bokeh.layouts import column, row +from bokeh.models import ( + Button, CheckboxGroup, ColumnDataSource, CustomJS, + DataTable, Div, MultiLine, RangeSlider, Select, + Slider, TableColumn, TextInput, Toggle, SVGIcon +) +from bokeh.plotting import figure + +# ── trajdata ─────────────────────────────────────────────────────────────── +from trajdata_webui.app_state import AppState +from trajdata_webui.backend.dataset_loader import ( + AVAILABLE_SPLITS, build_augmentations, load_dataset, +) +from trajdata_webui.backend.stats_computer import compute_stats +from trajdata_webui.backend.traj_renderer import batch_to_sources, sample_to_batch +from trajdata_webui.backend.aug_preview import compute_preview +from trajdata_webui.backend.sim_runner import run_simulation + +_pool = ThreadPoolExecutor(max_workers=2) + +# ── Color Palettes & UI Strings ───────────────────────────────────────────── +THEMES = { + "dark": dict( + bg = "#030303", + surface = "#0c0c0e", + surface2 = "#111114", + surface3 = "#1a1a1f", + border = "#1f1f23", + accent = "#2563eb", # Premium Blue + accent2 = "#3b82f6", + success = "#10b981", + warning = "#f59e0b", + danger = "#ef4444", + text = "#f8fafc", + muted = "#94a3b8", + plot_bg = "#0c0c0e", + grid = "#1a1a1f", + ), + "light": dict( + bg = "#f8fafc", + surface = "#ffffff", + surface2 = "#f1f5f9", + surface3 = "#e2e8f0", + border = "#e2e8f0", + accent = "#2563eb", + accent2 = "#3b82f6", + success = "#059669", + warning = "#d97706", + danger = "#dc2626", + text = "#0f172a", + muted = "#64748b", 
+ plot_bg = "#ffffff", + grid = "#f1f5f9", + ) +} + +STRINGS = { + "en": { + "dashboard": "Dashboard", + "dataset": "Dataset", + "visualize": "Visualize", + "augment": "Augment", + "simulate": "Simulate", + "export": "Export", + "run_demo": "Run Demo", + "theme_light": "Light Mode", + "theme_dark": "Dark Mode", + "lang_switch": "Türkçe", + "developed": "Developed by", + "rights": "All rights reserved", + # Hero + "hero_title": "Project Overview", + "hero_desc": "Welcome to the trajdata workspace. Monitor global dataset metrics below or use the Quick Start guide to begin your analysis.", + # Quick Start + "qs_title": "Quick Start & Workflow", + "qs_1": "Ingest Dataset: Navigate to the Dataset tab to load raw trajectory splits. The system will pre-process agent distributions, sampling rates, and scenario counts for immediate analysis.", + "qs_2": "Trajectory Forensics: Use the Visualize tab to deep-dive into agent behaviors. Scrub through temporal sequences to observe past history (dashed) vs. future ground-truth (solid) paths.", + "qs_3": "Data Augmentation: In the Augment panel, apply complex kinematic transforms. Flip axes to simulate different driving cultures or scale velocities to expand the training manifold.", + "qs_4": "Validation & Simulation: Verify model logic or data integrity using Simulate. Run closed-loop experiments to detect collisions and measure prediction errors (ADE/FDE) across scenes.", + "qs_5": "High-Performance Cache: Once satisfied, use Export to precompute and save the workspace as a Zarr Cache. This ensures peak I/O performance during model training.", + "qs_6": "Interaction Demo: If you are new to the platform, click Run Demo in the sidebar. This automated walkthrough will guide you through each functional unit of the trajdata system.", + # Explainer: Visualize + "viz_exp_title": "Understanding Trajectories", + "viz_exp_hist_t": "History Paths (Dashed)", + "viz_exp_hist_d": "Observations leading up to current T=0. Shows past speed, heading, and curvature.", + "viz_exp_fut_t": "Future Ground-Truth (Solid)", + "viz_exp_fut_d": "Actual path taken by the agent. Used to validate prediction accuracy.", + "viz_exp_neigh_t": "Neighbor Agents (Muted)", + "viz_exp_neigh_d": "Other actors in the scene. Critical for multi-agent interaction modeling.", + "viz_exp_ctrl_t": "Visualization Control", + "viz_exp_ctrl_d": "Use the slider above to scrub through temporal sequences in the loaded split.", + # Explainer: Augment + "aug_exp_title": "Augmentation Strategy", + "aug_exp_spatial": "Spatial Transforms: Flipping across X/Y axes doubles the training manifold coverage (e.g., LHD to RHD scenarios).", + "aug_exp_velocity": "Velocity Scaling: Modifies temporal relationships to simulate agents moving at different speeds.", + "aug_exp_motion": "Motion Labeling: Automated heuristic labeling (Stationary, Walking, etc.) based on kinematic profiles.", + # Explainer: Export + "exp_note": "Pre-compute all batch elements to disk to skip on-the-fly cache building during model training. 
This ensures peak I/O throughput for large-scale experiments.", + "exp_btn_label": " Export Workspace", + }, + "tr": { + "dashboard": "Panel", + "dataset": "Veri Seti", + "visualize": "Görselleştir", + "augment": "Zenginleştir", + "simulate": "Simüle Et", + "export": "Dışa Aktar", + "run_demo": "Demoyu Çalıştır", + "theme_light": "Işık Modu", + "theme_dark": "Karanlık Mod", + "lang_switch": "English", + "developed": "Geliştiren", + "rights": "Tüm hakları saklıdır", + # Hero + "hero_title": "Proje Genel Bakışı", + "hero_desc": "trajdata çalışma alanına hoş geldiniz. Küresel veri seti ölçümlerini aşağıdan izleyebilir veya analizinize başlamak için Hızlı Başlangıç kılavuzunu kullanabilirsiniz.", + # Quick Start + "qs_title": "Hızlı Başlangıç ve İş Akışı", + "qs_1": "Veri Seti Yükleme: Ham yörünge verilerini içe aktarmak için Veri Seti sekmesini kullanın. Sistem; ajan dağılımını, örnekleme hızlarını ve senaryo sayılarını anında ön işleme tabi tutar.", + "qs_2": "Yörünge Analizi: Ajan davranışlarını derinlemesine incelemek için Görselleştir sekmesini kullanın. Geçmiş gözlemler (kesikli) ve gelecek gerçekliği (düz) arasındaki ilişkiyi zaman içinde izleyin.", + "qs_3": "Veri Zenginleştirme: Zenginleştir panelinde karmaşık kinematik dönüşümler uygulayın. Farklı sürüş kültürlerini simüle etmek için eksenleri aynalayın veya eğitim kapsamını genişletmek için hızları ölçeklendirin.", + "qs_4": "Doğrulama ve Simülasyon: Model mantığını veya veri bütünlüğünü Simüle Et sekmesinde doğrulayın. Çarpışmaları tespit etmek ve tahmin hatalarını (ADE/FDE) ölçmek için kapalı döngü deneyler yapın.", + "qs_5": "Yüksek Performanslı Önbellek: Ayarlarınızdan memnun kaldığınızda, çalışma alanını Zarr Önbelleği olarak kaydetmek için Dışa Aktar sekmesini kullanın. Bu, eğitim sırasında maksimum I/O hızı sağlar.", + "qs_6": "Etkileşimli Demo: Platformu ilk kez kullanıyorsanız yan menüdeki Demoyu Çalıştır butonuna tıklayın. Bu otomatik tur, trajdata sisteminin her bir biriminde size rehberlik edecektir.", + # Explainer: Export + "exp_note": "Eğitim sırasında anlık önbellek oluşturma adımını atlamak için tüm toplu iş öğelerini diske önceden hesaplayarak kaydedin. Bu, büyük ölçekli deneyler için en yüksek I/O veri akışını sağlar.", + "exp_btn_label": " Çalışma Alanını Dışa Aktar", + # Explainer: Visualize + "viz_exp_title": "Yörüngeleri Anlama", + "viz_exp_hist_t": "Geçmiş Yollar (Kesikli)", + "viz_exp_hist_d": "T=0 anına kadar olan gözlemler. Geçmiş hızı, rotayı ve eğriliği gösterir.", + "viz_exp_fut_t": "Gelecek Gerçekliği (Düz)", + "viz_exp_fut_d": "Ajanın izlediği gerçek yol. Tahmin doğruluğunu doğrulamak için kullanılır.", + "viz_exp_neigh_t": "Komşu Ajanlar (Sönük)", + "viz_exp_neigh_d": "Sahnedeki diğer aktörler. Çoklu ajan etkileşim modellemesi için kritiktir.", + "viz_exp_ctrl_t": "Görselleştirme Kontrolü", + "viz_exp_ctrl_d": "Yüklü veri seti içindeki zaman dizileri arasında geçiş yapmak için yukarıdaki sürgüyü kullanın.", + # Explainer: Augment + "aug_exp_title": "Zenginleştirme Stratejisi", + "aug_exp_spatial": "Uzamsal Dönüşümler: X/Y eksenleri boyunca aynalama, eğitim kapsamını iki katına çıkarır (örn. 
LHD'den RHD senaryolarına).", + "aug_exp_velocity": "Hız Ölçeklendirme: Farklı hızlarda hareket eden ajanları simüle etmek için zamansal ilişkileri değiştirir.", + "aug_exp_motion": "Hareket Etiketleme: Kinematik profillere dayalı otomatik etiketleme (Sabit, Yürüyen vb.).", + } +} + +# Default Active Theme +C = THEMES["dark"] +S = STRINGS["en"] + +_SIDEBAR_W = 248 +_TRAJ_PH = dict(xs=[[]], ys=[[]], line_color=["#252540"], + line_dash=["solid"], legend_label=[""]) + +_PANEL_TITLES = { + "dashboard": ("Dashboard", "Overview & quick start"), + "dataset": ("Dataset", "Load & explore datasets"), + "visualize": ("Visualize", "Interactive trajectory viewer"), + "augment": ("Augment", "Data augmentation preview"), + "simulate": ("Simulate", "Run & evaluate simulations"), + "export": ("Export", "Precompute & save caches"), +} + +# ═══════════════════════════════════════════════════════════════════════════ +# Helpers +# ═══════════════════════════════════════════════════════════════════════════ + +def _div(html: str, **kw) -> Div: + return Div(text=html, **kw) + + +def _card(title: str, value: str, color: str = C["accent"]) -> Div: + return _div( + f"""
+
{title}
+
{value}
+
""", + ) + + +def _section_title(text: str) -> Div: + return _div( + f"
{text}
", + width=800, + ) + + +def _title_html(title: str, subtitle: str) -> str: + return ( + f"
" + f"
{title}
" + f"
 {subtitle}
 "
+        f"
 "
+    )
+
+
+def _get_icon(name: str, color: str = "#ffffff") -> SVGIcon:
+    path = str(Path(__file__).parent.parent / "img" / "icon" / f"icon_{name}.svg")
+    try:
+        import re
+        with open(path, "r") as f:
+            svg = f.read()
+
+        # Color: Map named colors to proper hex for better SVG support
+        c = "#ffffff" if color == "white" else ("#111114" if color == "black" else color)
+
+        # Replace hex colors baked into the asset
+        svg = re.sub(r'#([0-9a-fA-F]{3,6})', c, svg)
+        # Named black replacements
+        svg = svg.replace('stroke="black"', f'stroke="{c}"').replace('fill="black"', f'fill="{c}"')
+
+        # Force-scale size (38px for maximum impact in 44px buttons)
+        svg = re.sub(r'<svg([^>]*?)width="[^"]+"', r'<svg\g<1>width="38px"', svg)
+        svg = re.sub(r'<svg([^>]*?)height="[^"]+"', r'<svg\g<1>height="38px"', svg)
+
+        return SVGIcon(svg=svg)
+    except Exception:
+        # Fall back to an empty icon if the asset is missing or unreadable
+        return SVGIcon(svg="<svg xmlns='http://www.w3.org/2000/svg'/>")
+
+
+def _toast_html(msg: str, visible: bool) -> str:
+    opp = "1" if visible else "0"
+    tra = "translateY(0)" if visible else "translateY(-20px)"
+    return f"""
+
+
i
+
+
Walkthrough Update
+
{msg}
+
+
+""" + + +def _traj_figure(title: str, w: int = 520, h: int = 400) -> figure: + p = figure( + title=title, width=w, height=h, + tools="pan,wheel_zoom,box_zoom,reset,save", + background_fill_color=C["plot_bg"], + border_fill_color=C["surface"], + outline_line_color=C["border"], + ) + p.title.text_color = C["text"] + p.title.text_font_size = "13px" + p.xaxis.axis_label = "x (m)" + p.yaxis.axis_label = "y (m)" + p.axis.axis_label_text_color = C["muted"] + p.axis.major_label_text_color = C["muted"] + p.axis.axis_line_color = C["border"] + p.axis.major_tick_line_color = C["border"] + p.grid.grid_line_color = C["grid"] + p.xaxis.axis_label_text_font_size = "11px" + p.yaxis.axis_label_text_font_size = "11px" + return p + + +def _add_traj_glyphs(fig: figure, hist_src: ColumnDataSource, + fut_src: ColumnDataSource) -> None: + fig.add_glyph(hist_src, MultiLine( + xs="xs", ys="ys", line_color="line_color", + line_dash="line_dash", line_width=2, line_alpha=0.9, + )) + fig.add_glyph(fut_src, MultiLine( + xs="xs", ys="ys", line_color="line_color", + line_dash="line_dash", line_width=2, line_alpha=0.55, + )) + + +def _nav_btn(label: str, icon_name: str, theme: str, on_click: callable, active: bool = False) -> Button: + # Color logic: + # Dark Mode -> Always white + # Light Mode -> Active? White : Black + if theme == "dark": + color = "white" + else: + color = "white" if active else "black" + + btn = Button( + label=label, + icon=_get_icon(icon_name, color), + width=_SIDEBAR_W - 24, + height=44, + stylesheets=[f""" + :host .bk-btn {{ + background: transparent; + border: none !important; + border-radius: 12px; + color: {C['muted']}; + font-weight: 500; + font-size: 15px; + font-family: 'Inter', -apple-system, sans-serif; + text-align: left; + padding-left: 12px; + display: flex; + align-items: center; + gap: 12px; + transition: all .2s; + cursor: pointer; + }} + :host .bk-btn:hover {{ + background: {C['surface3']}; + color: {C['text']}; + }} + :host .bk-btn-primary {{ + background: {C['accent']} !important; + color: white !important; + font-weight: 600; + box-shadow: 0 4px 12px {C['accent']}30 !important; + }} + """], + ) + btn.on_click(on_click) + if active: + btn.button_type = "primary" + return btn + + +# ═══════════════════════════════════════════════════════════════════════════ +# Panels +# ═══════════════════════════════════════════════════════════════════════════ + +def _build_dashboard(state: AppState, cards_refs: dict) -> column: + def _t(k): return STRINGS[state.lang].get(k, STRINGS["en"].get(k, k)) + + quick_div = _div( + f"""
+
+ {_t('qs_title')}
+
+
    +
  1. {_t('qs_1')}
+
  2. {_t('qs_2')}
+
  3. {_t('qs_3')}
+
+
    +
  4. {_t('qs_4')}
+
  5. {_t('qs_5')}
+
  6. {_t('qs_6')}
+
+
""", + sizing_mode="stretch_width", + ) + + return column(quick_div, sizing_mode="stretch_width") + + +def _build_stats_row(cards_refs: dict) -> row: + # We rebuild cards ONLY if they don't exist in refs to avoid duplicate state + if "samples" not in cards_refs: + cards_refs["samples"] = _card("Samples", "—", C["accent"]) + cards_refs["scenes"] = _card("Scenes", "—", C["success"]) + cards_refs["dt"] = _card("dt (s)", "—", C["warning"]) + cards_refs["agents"] = _card("Agent types","—", C["danger"]) + + return row( + cards_refs["samples"], cards_refs["scenes"], + cards_refs["dt"], cards_refs["agents"], + spacing=14, sizing_mode="stretch_width", + styles={"justify-content": "space-between"} + ) + +def _build_dataset_panel(doc, state: AppState, cards_refs: dict, + nav_fn) -> tuple: + """Returns (panel_column, status_div, stats_source).""" + split_sel = Select(title="Split", value=state.dataset_split, + options=AVAILABLE_SPLITS, width=320) + load_btn = Button(label="Load Dataset", button_type="primary", + width=160, height=36, margin=(24, 0, 0, 12)) + status_div = _div(f"Select a split and load.", + width=500) + + stats_src = ColumnDataSource({"stat": [], "value": []}) + table_style = f""" + :host {{ + background: {C['surface']} !important; + border: 1px solid {C['border']} !important; + border-radius: 12px; + overflow: hidden; + }} + .bk-data-table {{ + background: {C['surface']} !important; + color: {C['text']} !important; + font-family: 'Inter', sans-serif !important; + }} + .bk-cell-index {{ background: {C['surface2']} !important; color: {C['muted']} !important; }} + .bk-header-column {{ + background: {C['surface3']} !important; + color: {C['text']} !important; + font-weight: 600 !important; + border-bottom: 1px solid {C['border']} !important; + }} + .bk-header-column:hover {{ + background: {C['accent']} !important; + color: white !important; + }} + .slick-cell {{ border-right: 1px solid {C['border']} !important; border-bottom: 1px solid {C['border']} !important; }} + .slick-row {{ background: {C['surface']} !important; }} + .slick-row:hover {{ background: {C['accent']}15 !important; }} + .slick-row.even {{ background: {C['surface2']} !important; }} + """ + + stats_tbl = DataTable( + source=stats_src, + columns=[TableColumn(field="stat", title="Statistic", width=240), + TableColumn(field="value", title="Value", width=220)], + sizing_mode="stretch_width", height=400, index_position=None, + stylesheets=[table_style], + ) + + def _do_load(): + split = split_sel.value + try: + ds = load_dataset(split, state.aug_config) + stats = compute_stats(ds) + def _update(): + state.dataset = ds + state.dataset_split = split + state.current_sample_idx = 0 + rows = [ + ("Split", split), + ("Total samples", f"{stats['total_samples']:,}"), + ("Scenes", str(stats["num_scenes"])), + ("dt (s)", str(stats["dt_s"])), + ("First scene", str(stats["scene0_name"])), + ("Timesteps (s0)", str(stats["scene0_timesteps"])), + ("Agents (s0)", str(stats["scene0_agents"])), + ("Mean hist len", str(stats["mean_hist_len"])), + ("Mean fut len", str(stats["mean_fut_len"])), + ] + [(f"Type: {k}", str(v)) for k, v in stats["agent_type_counts"].items()] + stats_src.data = {"stat": [r[0] for r in rows], + "value": [r[1] for r in rows]} + # Update dashboard cards + if "samples" in cards_refs: + cards_refs["samples"].text = _card("Samples", + f"{stats['total_samples']:,}", C["accent"]).text + cards_refs["scenes"].text = _card("Scenes", + str(stats["num_scenes"]), C["success"]).text + cards_refs["dt"].text = _card("dt (s)", + 
str(stats["dt_s"]), C["warning"]).text + types_str = "/".join(stats["agent_type_counts"].keys()) + cards_refs["agents"].text = _card("Agent types", + types_str, C["danger"]).text + status_div.text = ( + f"" + f"✓ Loaded {split} — {stats['total_samples']:,} samples" + ) + doc.add_next_tick_callback(_update) + except Exception as e: + doc.add_next_tick_callback(lambda: setattr( + status_div, "text", + f"✗ {e}" + )) + + def on_load(_): + status_div.text = f"Loading…" + _pool.submit(_do_load) + + load_btn.on_click(on_load) + + panel = column(row(split_sel, load_btn, spacing=16, margin=(16,0,0,0)), + status_div, stats_tbl, sizing_mode="stretch_width") + return panel, status_div, stats_src, split_sel, load_btn + + +def _build_viz_panel(doc, state: AppState, cards_refs: dict) -> tuple: + hist_src = ColumnDataSource(_TRAJ_PH.copy()) + fut_src = ColumnDataSource(_TRAJ_PH.copy()) + + p = _traj_figure("Trajectory", w=600, h=450) + _add_traj_glyphs(p, hist_src, fut_src) + + legend_div = _div( + f"
" + f"- - History   " + f"--- Future   " + f"- - Neighbor hist   " + f"--- Neighbor fut
", + ) + + slider = Slider(title="Sample", start=0, end=1, step=1, value=0, width=380) + prev_btn = Button(label=" < ", width=50, height=34, button_type="default") + next_btn = Button(label=" > ", width=50, height=34, button_type="default") + opts = CheckboxGroup(labels=["History", "Future", "Neighbors"], + active=[0, 1, 2], inline=True) + info_div = _div(f"Load a dataset first.", + width=580) + + def _refresh(idx: int): + if state.dataset is None: + return + n = len(state.dataset) + idx = max(0, min(idx, n - 1)) + state.current_sample_idx = idx + slider.value = idx + sh, sf, sn = 0 in opts.active, 1 in opts.active, 2 in opts.active + try: + b = sample_to_batch(state.dataset, idx) + hd, fd = batch_to_sources(b, 0, sh, sf, sn) + hist_src.data, fut_src.data = hd, fd + info_div.text = ( + f"" + f"Sample {idx}/{n-1}  ·  " + f"Agent {b.agent_name[0]}  ·  " + f"Hist {b.agent_hist_len[0].item()}ts  ·  " + f"Fut {b.agent_fut_len[0].item()}ts  ·  " + f"Neighbors {b.num_neigh[0].item()}" + ) + except Exception as e: + info_div.text = f"{e}" + + slider.on_change("value", lambda a, o, n: _refresh(n)) + prev_btn.on_click(lambda _: _refresh(state.current_sample_idx - 1)) + next_btn.on_click(lambda _: _refresh(state.current_sample_idx + 1)) + opts.on_change("active", lambda a, o, n: _refresh(state.current_sample_idx)) + + def _periodic(): + if state.dataset is not None and slider.end == 1: + slider.end = max(1, len(state.dataset) - 1) + _refresh(0) + doc.add_periodic_callback(_periodic, 1200) + + def _t(k): return STRINGS[state.lang].get(k, STRINGS["en"].get(k, k)) + # Detailed Explainers (Collapsible details tag) + explain_div = _div(f""" +
+ + {_t('viz_exp_title')} + +
+
+ {_t('viz_exp_hist_t')}
+ {_t('viz_exp_hist_d')} +
+
+ {_t('viz_exp_fut_t')}
+ {_t('viz_exp_fut_d')} +
+
+ {_t('viz_exp_neigh_t')}
+ {_t('viz_exp_neigh_d')} +
+
+ {_t('viz_exp_ctrl_t')}
+ {_t('viz_exp_ctrl_d')} +
+
+ Collision Detection / Çarpışma Tespiti
+ Overlap of paths indicates potential collisions. In Simulations, collisions are automatically highlighted and categorized by agent type. +
+
+
+ """, sizing_mode="stretch_width") + + controls = column(slider, row(prev_btn, next_btn, sizing_mode="stretch_width"), opts, info_div, legend_div, + sizing_mode="stretch_width", max_width=420) + + panel_layout = row(controls, p, spacing=24, sizing_mode="stretch_width") + panel = column(explain_div, panel_layout, sizing_mode="stretch_width") + return panel, _refresh, slider + +def _build_aug_panel(doc, state: AppState, cards_refs: dict) -> tuple: + bh_src = ColumnDataSource(_TRAJ_PH.copy()) + bf_src = ColumnDataSource(_TRAJ_PH.copy()) + ah_src = ColumnDataSource(_TRAJ_PH.copy()) + af_src = ColumnDataSource(_TRAJ_PH.copy()) + + fig_o = _traj_figure("Original", w=420, h=380) + fig_a = _traj_figure("Augmented", w=420, h=380) + _add_traj_glyphs(fig_o, bh_src, bf_src) + _add_traj_glyphs(fig_a, ah_src, af_src) + + slider = Slider(title="Sample", start=0, end=1, step=1, value=0, width=360) + mirror_tog = Toggle(label="Mirror", active=False, button_type="default", + width=130, height=34) + mirror_ax = Select(title="Axis", value="x", options=["x","y"], width=90) + mirror_prob = Slider(title="Prob", start=0.0, end=1.0, step=0.05, value=0.5, width=200) + speed_tog = Toggle(label="Speed Scale", active=False, button_type="default", + width=150, height=34) + speed_range = RangeSlider(title="Scale", start=0.3, end=2.5, + step=0.05, value=(0.8, 1.2), width=260) + motion_tog = Toggle(label="Motion Labels", active=False, button_type="default", + width=160, height=34) + motion_div = _div("", width=380) + apply_btn = Button(label="Apply to Dataset & Reload", + button_type="success", width=240, height=36) + status_div = _div("", width=500) + + def _read(): + state.aug_config.update({ + "mirror": mirror_tog.active, "mirror_axis": mirror_ax.value, + "mirror_prob": mirror_prob.value, "speed_scale": speed_tog.active, + "speed_min": speed_range.value[0], "speed_max": speed_range.value[1], + "motion_labeler": motion_tog.active, + }) + + def _preview(idx: int): + if state.dataset is None: + return + _read() + idx = max(0, min(idx, len(state.dataset) - 1)) + try: + bh, bf, ah, af = compute_preview(state.dataset, idx, state.aug_config) + bh_src.data, bf_src.data = bh, bf + ah_src.data, af_src.data = ah, af + if state.aug_config["motion_labeler"]: + from trajdata.data_structures.collation import agent_collate_fn + elem = state.dataset[idx] + b = agent_collate_fn([elem], True, pad_format="outside") + for aug in build_augmentations(state.aug_config): + try: aug.apply_agent(b) + except Exception: pass + if "motion_type" in b.extras: + lm = {0:"STATIONARY",1:"WALKING",2:"RUNNING",3:"FAST"} + lbl = lm.get(b.extras["motion_type"][0].item(), "?") + motion_div.text = ( + f"" + f"Motion type: {lbl}" + ) + except Exception as e: + status_div.text = f"{e}" + + slider.on_change("value", lambda a, o, n: _preview(n)) + for w in (mirror_tog, mirror_ax, mirror_prob, speed_tog, speed_range, motion_tog): + prop = "active" if isinstance(w, Toggle) else "value" + w.on_change(prop, lambda a, o, n: _preview(slider.value)) + + def on_apply(_): + _read() + status_div.text = f"Reloading…" + def _do(): + try: + ds = load_dataset(state.dataset_split, state.aug_config) + def _done(): + state.dataset = ds + slider.end = max(1, len(ds) - 1) + status_div.text = ( + f"✓ Reloaded with augmentations" + ) + _preview(0) + doc.add_next_tick_callback(_done) + except Exception as e: + doc.add_next_tick_callback(lambda: setattr( + status_div, "text", + f"{e}" + )) + _pool.submit(_do) + + apply_btn.on_click(on_apply) + + def _periodic(): + if state.dataset is 
not None and slider.end == 1: + slider.end = max(1, len(state.dataset) - 1) + _preview(0) + doc.add_periodic_callback(_periodic, 1200) + + def _sep(t): return _div( + f"
{t}
", width=380) + + def _t(k): return STRINGS[state.lang].get(k, STRINGS["en"].get(k, k)) + # Detailed Augmentation Explanation + aug_explain = _div(f""" +
+

{_t('aug_exp_title')}

+

 {_t('aug_exp_spatial')}

 {_t('aug_exp_velocity')}

 {_t('aug_exp_motion')} +

+
+ """, width=320) + + controls = column( + slider, + _sep("Mirror Augmentation"), row(mirror_tog, mirror_ax, sizing_mode="stretch_width"), mirror_prob, + _sep("Speed Scale"), row(speed_tog, speed_range, sizing_mode="stretch_width"), + _sep("Motion Labeler"), motion_tog, motion_div, + apply_btn, status_div, + sizing_mode="stretch_width", max_width=380, + ) + figs = column(fig_o, fig_a, sizing_mode="stretch_width") + panel = column(row(controls, figs, aug_explain, spacing=24, sizing_mode="stretch_width"), + sizing_mode="stretch_width") + return panel, _preview, slider + + +def _build_sim_panel(doc, state: AppState, cards_refs: dict) -> tuple: + scene_sl = Slider(title="Scene", start=0, end=1, step=1, value=0, width=280) + steps_sl = Slider(title="Steps", start=5, end=200, step=5, value=30, width=280) + met_check = CheckboxGroup(labels=["ADE","FDE","Collision","OffRoad"], + active=[0,1,2], inline=True, + stylesheets=[f":host {{ color: {C['text']}; font-size: 13px; }}"]) + run_btn = Button(label=" > Run Simulation", button_type="success", + width=190, height=36) + status_div = _div(f"Load a dataset first.", + width=520) + means_div = _div("", width=560) + + _MCOLS = ["agent","ade","fde","collision","offroad"] + tbl_src = ColumnDataSource({c: [] for c in _MCOLS}) + table_style = f""" + :host {{ + background: {C['surface']} !important; + border: 1px solid {C['border']} !important; + border-radius: 12px; + overflow: hidden; + }} + .bk-data-table {{ + background: {C['surface']} !important; + color: {C['text']} !important; + font-family: 'Inter', sans-serif !important; + }} + .bk-cell-index {{ background: {C['surface2']} !important; color: {C['muted']} !important; }} + .bk-header-column {{ + background: {C['surface3']} !important; + color: {C['text']} !important; + font-weight: 600 !important; + border-bottom: 1px solid {C['border']} !important; + }} + .bk-header-column:hover {{ + background: {C['accent']} !important; + color: white !important; + }} + .slick-cell {{ border-right: 1px solid {C['border']} !important; border-bottom: 1px solid {C['border']} !important; }} + .slick-row {{ background: {C['surface']} !important; }} + .slick-row:hover {{ background: {C['accent']}15 !important; }} + .slick-row.even {{ background: {C['surface2']} !important; }} + """ + + tbl = DataTable( + source=tbl_src, + columns=[TableColumn(field=c, title=c.upper(), width=105) for c in _MCOLS], + sizing_mode="stretch_width", height=320, index_position=None, + stylesheets=[table_style], + ) + + def _run(): + names = ["ADE","FDE","Collision","OffRoad"] + sel = [names[i] for i in met_check.active] + try: + res = run_simulation(state.dataset, scene_sl.value, steps_sl.value, sel) + def _update(): + rows = res["rows"] + new = {c: [] for c in _MCOLS} + new["agent"] = [r["agent"] for r in rows] + for m in ["ade","fde","collision","offroad"]: + new[m] = [round(r.get(m, float("nan")), 4) for r in rows] + tbl_src.data = new + parts = [f"{k.upper()}" + f": {v:.4f}" + for k, v in res.get("means", {}).items()] + means_div.text = ( + f"
" + "  ·  ".join(parts) + "
" + ) + status_div.text = ( + f"✓ Done — " + f"{res['steps']} steps, {len(rows)} agents" + ) + doc.add_next_tick_callback(_update) + except Exception as e: + doc.add_next_tick_callback(lambda: setattr( + status_div, "text", + f"{e}" + )) + + def on_run(_): + if state.dataset is None: + status_div.text = f"Load a dataset first." + return + status_div.text = f"Running…" + _pool.submit(_run) + + run_btn.on_click(on_run) + + def _periodic(): + if state.dataset is not None and scene_sl.end == 1: + scene_sl.end = max(0, state.dataset.num_scenes() - 1) + doc.add_periodic_callback(_periodic, 1200) + + controls = column(scene_sl, steps_sl, + _div(f"
+    controls = column(scene_sl, steps_sl,
+                      _div("Metrics", width=300),
+                      met_check, run_btn, status_div, width=320)
+    results = column(
+        _div("Per-agent results", width=580),
+        tbl, means_div, sizing_mode="stretch_width",
+    )
+    panel = column(row(controls, results, spacing=24, sizing_mode="stretch_width"), sizing_mode="stretch_width")
+    return panel, run_btn, status_div, scene_sl
+
+
+def _build_export_panel(doc, state: AppState, cards_refs: dict) -> column:
+    def _t(k): return STRINGS[state.lang].get(k, STRINGS["en"].get(k, k))
+
+    note = _div(_t('exp_note'), sizing_mode="stretch_width")

" + f"{_t('exp_note')}

", + sizing_mode="stretch_width" + ) + path_in = TextInput(title="Output path", value="~/trajdata_export", sizing_mode="stretch_width", max_width=500) + fmt_sel = Select(title="Format", value="zarr", + options=["zarr","numpy"], width=160) + bs_sl = Slider(title="Batch size", start=8, end=256, step=8, + value=64, sizing_mode="stretch_width", max_width=340) + exp_btn = Button(label=_t('exp_btn_label'), button_type="primary", + width=220, height=36) + status_div = _div(f"Load a dataset first.", + sizing_mode="stretch_width") + result_div = _div("", sizing_mode="stretch_width") + + def _do(): + from trajdata.io import DataExporter + out = str(Path(path_in.value).expanduser()) + try: + DataExporter.export(state.dataset, out, format=fmt_sel.value, + batch_size=bs_sl.value, num_workers=0, verbose=True) + def _done(): + _s = C["surface2"]; _ok = C["success"]; _t = C["text"] + _ac = C["accent"]; _mu = C["muted"]; _fmt = fmt_sel.value + result_div.text = ( + f"
" + f"Export complete
" + f"Path: {out}

" + f"Load back with:
" + f"" + f"PrecomputedDataset('{out}', format='{_fmt}')" + f"
" + ) + status_div.text = f"Done." + doc.add_next_tick_callback(_done) + except Exception as e: + doc.add_next_tick_callback(lambda: setattr( + status_div, "text", + f"{e}" + )) + + def on_export(_): + if state.dataset is None: + status_div.text = f"Load a dataset first." + return + status_div.text = f"Exporting…" + result_div.text = "" + _pool.submit(_do) + + exp_btn.on_click(on_export) + + return column(note, path_in, row(fmt_sel, bs_sl, align="end"), + exp_btn, status_div, result_div, + sizing_mode="stretch_width") + + +# ═══════════════════════════════════════════════════════════════════════════ +# Run Demo +# ═══════════════════════════════════════════════════════════════════════════ + +def _run_demo(doc, state: AppState, refs: dict): + """Automated demo: load → visualize → augment → simulate.""" + toast = refs["toast"] + demo_btn = refs["demo_btn"] + nav_fn = refs["nav_fn"] + viz_refresh= refs["viz_refresh"] + aug_preview= refs["aug_preview"] + sim_run_btn= refs["sim_run_btn"] + sim_status = refs["sim_status"] + + demo_btn.disabled = True + demo_btn.label = "Running Demo..." + + def _log(msg: str): + toast.text = _toast_html(msg, visible=True) + # Auto-hide after 4s (or next log) + doc.add_timeout_callback(lambda: _hide_if_same(msg), 4000) + + def _hide_if_same(msg: str): + if msg in toast.text: + toast.text = _toast_html(msg, visible=False) + + def _step1(): + nav_fn("dataset") + _log(f"Step 1/4" + f" — Loading eupeds_eth-train…") + + def _do(): + try: + ds = load_dataset("eupeds_eth-train", state.aug_config) + stats = compute_stats(ds) + def _done(): + state.dataset = ds + state.dataset_split = "eupeds_eth-train" + refs["stats_src"].data = { + "stat": ["Split","Samples","Scenes","dt (s)"], + "value": ["eupeds_eth-train", + f"{stats['total_samples']:,}", + str(stats["num_scenes"]), + str(stats["dt_s"])], + } + refs["ds_status"].text = ( + f"" + f"✓ {stats['total_samples']:,} samples loaded" + ) + _log(f"✓ Dataset loaded — " + f"{stats['total_samples']:,} samples, " + f"{stats['num_scenes']} scenes") + doc.add_timeout_callback(_step2, 1500) + doc.add_next_tick_callback(_done) + except Exception as e: + doc.add_next_tick_callback(lambda: _log( + f"✗ {e}" + )) + _pool.submit(_do) + + def _step2(): + nav_fn("visualize") + _log(f"Step 2/4" + f" — Visualizing sample 10…") + if state.dataset is not None: + refs["viz_slider"].end = max(1, len(state.dataset) - 1) + viz_refresh(10) + doc.add_timeout_callback(_step3, 2000) + + def _step3(): + nav_fn("augment") + _log(f"Step 3/4" + f" — Previewing Mirror augmentation…") + state.aug_config["mirror"] = True + state.aug_config["mirror_axis"] = "x" + state.aug_config["mirror_prob"] = 1.0 + if state.dataset is not None: + refs["aug_slider"].end = max(1, len(state.dataset) - 1) + aug_preview(10) + doc.add_timeout_callback(_step4, 2500) + + def _step4(): + nav_fn("simulate") + _log(f"Step 4/4" + f" — Running simulation…") + if state.dataset is None: + doc.add_timeout_callback(_finish, 1000) + return + sim_status.text = f"Running…" + def _do(): + try: + res = run_simulation(state.dataset, 0, 20, ["ADE","FDE","Collision"]) + def _done(): + parts = [f"{k.upper()}: {v:.4f}" + for k, v in res.get("means", {}).items()] + sim_status.text = ( + f"✓ {res['steps']} steps" + ) + _log( + f"✓ Simulation done — " + + " · ".join(parts) + "" + ) + doc.add_timeout_callback(_finish, 1500) + doc.add_next_tick_callback(_done) + except Exception as e: + doc.add_next_tick_callback(lambda: doc.add_timeout_callback(_finish, 500)) + _pool.submit(_do) + + def _finish(): + 
+        nav_fn("dashboard")
+        _log("Demo complete! All features working.")
+        demo_btn.disabled = False
+        demo_btn.label = "Run Demo"
+
+    doc.add_timeout_callback(_step1, 300)
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# Root builder
+# ═══════════════════════════════════════════════════════════════════════════
+
+def build_ui(doc, state=None):
+    import bokeh
+    if state is None:
+        state = AppState()
+
+    # ── Current Config ──────────────────────────────────────────────────
+    global C, S
+    C = THEMES.get(state.theme, THEMES["dark"])
+    S = STRINGS.get(state.lang, STRINGS["en"])
+    cards = {}
+    refs = {}
+
+    # ── Panels ───────────────────────────────────────────────────────────
+    dash_panel = _build_dashboard(state, cards)
+
+    ds_panel, ds_status, stats_src, _, _ = _build_dataset_panel(
+        doc, state, cards, lambda s: None
+    )
+    refs["ds_status"] = ds_status
+    refs["stats_src"] = stats_src
+
+    viz_panel, viz_refresh, viz_slider = _build_viz_panel(doc, state, cards)
+    refs["viz_refresh"] = viz_refresh
+    refs["viz_slider"] = viz_slider
+
+    aug_panel, aug_preview, aug_slider = _build_aug_panel(doc, state, cards)
+    refs["aug_preview"] = aug_preview
+    refs["aug_slider"] = aug_slider
+
+    sim_panel, sim_run_btn, sim_status, _ = _build_sim_panel(doc, state, cards)
+    refs["sim_run_btn"] = sim_run_btn
+    refs["sim_status"] = sim_status
+
+    exp_panel = _build_export_panel(doc, state, cards)
+
+    # Wrap panels in containers
+    panels = {
+        "dashboard": dash_panel,
+        "dataset": ds_panel,
+        "visualize": viz_panel,
+        "augment": aug_panel,
+        "simulate": sim_panel,
+        "export": exp_panel,
+    }
+    for p in panels.values():
+        p.visible = False
+
+    # ── Sidebar nav buttons ──────────────────────────────────────────────
+    # Local helper for localized string
+    def _t(k): return S.get(k, STRINGS["en"].get(k, k))
+
+    NAV = [
+        ("dashboard", _t("dashboard")),
+        ("dataset", _t("dataset")),
+        ("visualize", _t("visualize")),
+        ("augment", _t("augment")),
+        ("simulate", _t("simulate")),
+        ("export", _t("export")),
+    ]
+
+    # Header title div (updated on nav change)
+    header_title_div = _div(
+        _title_html("Dashboard", "Overview & quick start"),
+        width=380,
+        styles={"display": "flex", "align-items": "center"},
+    )
+
+    def nav_fn(key: str):
+        old = state.active_panel
+        if old == key:
+            return
+        panels[old].visible = False
+        panels[key].visible = True
+        nav_btns[old].button_type = "default"
+        nav_btns[key].button_type = "primary"
+        state.active_panel = key
+        t, s = _PANEL_TITLES[key]
+        header_title_div.text = _title_html(t, s)
+
+    panels["dashboard"].visible = True
+
+    refs["nav_fn"] = nav_fn
+
+    # ── Demo button ──────────────────────────────────────────────────────
+    demo_btn = Button(
+        label=_t("run_demo"),
+        icon=_get_icon("run_demo", "white"),
+        button_type="default",
+        sizing_mode="stretch_width",
+        height=52,
+        stylesheets=[f"""
+        :host .bk-btn {{
+            background: {C['accent']} !important;
+            border: none !important;
+            color: white !important;
+            font-weight: 700 !important;
+            font-size: 16px !important;
+            border-radius: 14px !important;
+            height: 52px !important;
+            width: 100% !important;
+            display: flex !important;
+            align-items: center !important;
+            justify-content: center !important;
+            gap: 12px !important;
+            box-shadow: 0 6px 16px {C['accent']}40 !important;
+            transition: all 0.2s cubic-bezier(0.4, 0, 0.2, 1) !important;
+        }}
+        :host .bk-btn:hover {{
+            transform: scale(1.02) translateY(-1px) !important;
+            box-shadow: 0 8px 20px {C['accent']}60 !important;
+            filter: brightness(1.1) !important;
+        }}
+        :host .bk-btn:active {{
+            transform: scale(0.98) !important;
+        }}
+        :host .bk-btn-group {{ display: block; width: 100%; }}
+        """],
+    )
+    refs["demo_btn"] = demo_btn
+
+    # Toast notification div (fixed position)
+    toast_div = _div(_toast_html("Demo started", visible=False), width=0, height=0)
+    refs["toast"] = toast_div
+
+    def on_demo(_):
+        _run_demo(doc, state, refs)
+
+    demo_btn.on_click(on_demo)
+
+    # ── Sidebar HTML sections ────────────────────────────────────────────
+    sidebar_logo = _div(f"""
+    trajdata
+""", width=_SIDEBAR_W)
+
+    sidebar_run_btn_wrapper = column(
+        demo_btn,
+        width=_SIDEBAR_W,
+        styles={
+            "padding": "12px",
+            "border-bottom": f"1px solid {C['border']}",
+            "background": C['surface'],
+        }
+    )
+
+    # Navigation items (secondary); reuse NAV instead of restating the pairs
+    nav_btns = {}
+    for key, label in NAV:
+        is_active = (key == state.active_panel)
+        btn = _nav_btn(label, key, state.theme, partial(nav_fn, key), is_active)
+        nav_btns[key] = btn
+
+    sidebar_nav_wrapper = column(
+        *[nav_btns[k] for k in nav_btns],
+        spacing=2,
+        width=_SIDEBAR_W,
+        styles={"padding": "8px 12px", "flex": "1", "overflow-y": "auto"}
+    )
+
+    # ── Sidebar bottom (Theme & Language) ────────────────────────────────
+    def theme_fn():
+        state.theme = "light" if state.theme == "dark" else "dark"
+        doc.clear()
+        build_ui(doc, state)
+
+    def lang_fn():
+        state.lang = "tr" if state.lang == "en" else "en"
+        doc.clear()
+        build_ui(doc, state)
+
+    theme_btn = Button(
+        label=_t("theme_light") if state.theme == "dark" else _t("theme_dark"),
+        width=_SIDEBAR_W - 24,
+        height=36,
+        stylesheets=[f"""
+        :host .bk-btn {{ background: transparent; border: none !important;
+            color: {C['muted']}; font-size: 13px; font-weight: 500; text-align: left; }}
+        :host .bk-btn:hover {{ background: {C['surface3']}; color: {C['text']}; }}
+        """],
+    )
+    theme_btn.on_click(theme_fn)
+
+    lang_btn = Button(
+        label=_t("lang_switch"),
+        width=_SIDEBAR_W - 24,
+        height=36,
+        stylesheets=[f"""
+        :host .bk-btn {{ background: transparent; border: none !important;
+            color: {C['muted']}; font-size: 13px; font-weight: 500; text-align: left; }}
+        :host .bk-btn:hover {{ background: {C['surface3']}; color: {C['text']}; }}
+        """],
+    )
+    lang_btn.on_click(lang_fn)
+
+    sidebar_bottom = column(
+        theme_btn,
+        lang_btn,
+        width=_SIDEBAR_W,
+        styles={
+            "padding": "12px",
+            "border-top": f"1px solid {C['border']}",
+            "margin-top": "auto"
+        }
+    )
+
+    sidebar_content = column(
+        sidebar_logo,
+        sidebar_run_btn_wrapper,
+        sidebar_nav_wrapper,
+        sidebar_bottom,
+        spacing=0,
+        width=_SIDEBAR_W,
+        styles={
+            "background": C["surface"],
+            "height": "100%",
+            "overflow": "hidden"
+        }
+    )
+
+    # ── Header ───────────────────────────────────────────────────────────
+    header_right = _div(f"""
+
+""", width=120, styles={"flex-shrink": "0"}) + + spacer_div = _div("", sizing_mode="stretch_width") + + header = row( + header_title_div, + spacer_div, + header_right, + toast_div, + sizing_mode="stretch_width", + spacing=0, + styles={ + "background": C["surface"], + "border-bottom": f"1px solid {C['border']}", + "height": "72px", + "position": "sticky", + "top": "0", + "z-index": "100", + "align-items": "center", + "margin": "0", + "padding": "0", + "width": f"calc(100vw - {_SIDEBAR_W}px)", + }, + ) + + # ── Footer ────────────────────────────────────────────────────────── + footer = _div(f""" +
+
+ trajdata + Web UI  ·  + + build #1.0.0 + +
+ + © 2024 trajdata contributors  ·  {_t("rights")} +  ·  + + GitHub ↗ + + +
+""", sizing_mode="stretch_width", min_width=200, styles={"margin": "0", "width": "100%", "min-width": "100%"}) + + # ── Body ───────────────────────────────────────────────────────────── + main_area = column( + *panels.values(), + sizing_mode="stretch_width", + styles={ + "padding": "32px 40px", + "flex": "1", + "overflow-y": "auto", + }, + ) + + sidebar_col = column( + sidebar_content, + width=_SIDEBAR_W, + styles={ + "background": C["surface"], + "border-right": f"1px solid {C['border']}", + "height": "100vh", + "flex-shrink": "0", + }, + ) + + content_col = column( + header, + main_area, + footer, + sizing_mode="stretch_both", + spacing=0, + styles={ + "flex": "1", + "height": "100vh", + "overflow": "hidden", + "display": "flex", + "flex-direction": "column", + } + ) + + root = row( + sidebar_col, + content_col, + spacing=0, + sizing_mode="stretch_both", + styles={ + "background": C["bg"], + "margin": "0", + "padding": "0", + "overflow": "hidden", + "width": "100vw", + "height": "100vh", + }, + ) + doc.theme = "dark_minimal" if state.theme == "dark" else None + doc.add_root(root) + doc.title = "trajdata Web UI" + doc.template_variables["extra_head"] = f""" + + + + +"""