Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 147 additions & 30 deletions docker/dockerfile
Original file line number Diff line number Diff line change
@@ -1,47 +1,164 @@
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu20.04

# To build (from repo root): docker build --tag foundationstereo -f docker/dockerfile .
# By default it includes the weights from the largest model (23-51-11): https://github.com/NVlabs/FoundationStereo?tab=readme-ov-file#model-weights
# To build with different weights: docker build --build-arg PRETRAINED_WEIGHTS="onnx" --tag foundationstereo -f ./docker/dockerfile .
# To run: docker run --gpus all -it foundationstereo /bin/bash
# Then you can run the demo script or whatever you want

ARG CUDA_VERSION=12.8.1
ARG PRETRAINED_WEIGHTS="23-51-11"

# https://github.com/anaconda/docker-images/blob/main/miniconda3/debian/Dockerfile
FROM continuumio/miniconda3 AS miniconda

FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu24.04 AS builder

ENV TZ=US/Pacific
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

RUN apt-get update --fix-missing && \
apt-get install -y libgtk2.0-dev && \
apt-get install -y wget bzip2 ca-certificates curl git vim tmux g++ gcc build-essential cmake checkinstall gfortran libjpeg8-dev libtiff5-dev pkg-config yasm libavcodec-dev libavformat-dev libswscale-dev libdc1394-22-dev libxine2-dev libv4l-dev qt5-default libgtk2.0-dev libtbb-dev libatlas-base-dev libfaac-dev libmp3lame-dev libtheora-dev libvorbis-dev libxvidcore-dev libopencore-amrnb-dev libopencore-amrwb-dev x264 v4l-utils libprotobuf-dev protobuf-compiler libgoogle-glog-dev libgflags-dev libgphoto2-dev libhdf5-dev doxygen libflann-dev libboost-all-dev proj-data libproj-dev libyaml-cpp-dev cmake-curses-gui libzmq3-dev freeglut3-dev
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
apt-utils \
build-essential \
bzip2 \
ca-certificates \
checkinstall \
cmake \
cmake-curses-gui \
curl \
doxygen \
ffmpeg \
g++ \
gcc \
gdb \
gfortran \
git \
htop \
iputils-ping \
libassimp-dev \
libavcodec-dev \
libavformat-dev \
libblas-dev \
libboost-all-dev \
libccd-dev \
libcgal-dev \
libdc1394-dev \
libfaac-dev \
libflann-dev \
libgflags-dev \
libglew-dev \
libgoogle-glog-dev \
libgphoto2-dev \
libgtk2.0-dev \
libhdf5-dev \
libjpeg8-dev \
liblapack-dev \
libmp3lame-dev \
libnoise-dev \
libopencore-amrnb-dev \
libopencore-amrwb-dev \
libproj-dev \
libprotobuf-dev \
libswscale-dev \
libtbb-dev \
libtheora-dev \
libtiff5-dev \
libtinyxml2-dev \
libturbojpeg-dev \
libv4l-dev \
libvorbis-dev \
libxine2-dev \
libxvidcore-dev \
net-tools \
openexr \
p7zip-full \
p7zip-rar \
parallel \
pkg-config \
proj-data \
protobuf-compiler \
rclone \
rsync \
tmux \
v4l-utils \
vim \
wget \
x264 \
xvfb \
yasm \
zlib1g-dev \
&& apt-get clean && \
rm -rf /var/lib/apt/lists/*

COPY ./docker/environment.yml /tmp/environment.yml

RUN apt-get update --fix-missing && \
apt-get install -y --no-install-recommends apt-utils git gdb pkg-config libgtk2.0-dev libusb-1.0-0-dev wget software-properties-common &&\
apt-get install -y wget ca-certificates curl git vim tmux build-essential cmake checkinstall gfortran pkg-config yasm libavcodec-dev libavformat-dev libswscale-dev libxine2-dev libgtk2.0-dev libtbb-dev libatlas-base-dev libfaac-dev libtheora-dev libvorbis-dev libxvidcore-dev libopencore-amrnb-dev libopencore-amrwb-dev x264 libprotobuf-dev protobuf-compiler libgoogle-glog-dev libgflags-dev libgphoto2-dev libhdf5-dev doxygen libflann-dev libboost-all-dev libblas-dev liblapack-dev proj-data libproj-dev libccd-dev libglew-dev zlib1g-dev libtinyxml2-dev p7zip-full p7zip-rar xvfb rsync libnoise-dev libcgal-dev libassimp-dev iputils-ping parallel htop net-tools ffmpeg libturbojpeg-dev rclone openexr

# setup conda
COPY --from=miniconda /opt/conda /opt/conda
# Put conda in path so we can use conda activate
ENV PATH=/opt/conda/bin:$PATH

SHELL ["/bin/bash", "--login", "-c"]
# install python and all dependencies
# install flash attention separately because it requires a committed installed pytorch
# the conda environment name is "foundation_stereo" and is defined in the environment.yml file
RUN conda env create --file /tmp/environment.yml && \
conda run -n foundation_stereo pip install --no-cache-dir flash-attn && \
conda clean --all

RUN cd / && wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /miniconda.sh && \
/bin/bash /miniconda.sh -b -p /opt/conda &&\
ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh &&\
echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc &&\
/bin/bash -c "source ~/.bashrc" && \
/opt/conda/bin/conda update -n base -c defaults conda -y &&\
/opt/conda/bin/conda create -n my python=3.9
### Final image
FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu24.04

ARG PRETRAINED_WEIGHTS

ENV PATH $PATH:/opt/conda/envs/my/bin
ENV RCLONE_CONFIG /rclone.conf
# otherwise apt-get will ask for timezone input
ENV TZ=US/Pacific
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

RUN conda init bash &&\
echo "conda activate my" >> ~/.bashrc &&\
conda activate my &&\
pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124
RUN apt-get update && apt-get install -y --no-install-recommends \
ffmpeg \
libavcodec-dev \
libavformat-dev \
libblas-dev \
libdc1394-dev \
libfaac-dev \
libflann-dev \
libgflags-dev \
libglew-dev \
libgoogle-glog-dev \
libgtk2.0-dev \
libhdf5-dev \
libjpeg8-dev \
liblapack-dev \
libmp3lame-dev \
libopencore-amrnb-dev \
libopencore-amrwb-dev \
libproj-dev \
libswscale-dev \
libtbb-dev \
libtheora-dev \
libtiff5-dev \
libtinyxml2-dev \
libturbojpeg-dev \
libv4l-dev \
libvorbis-dev \
libxvidcore-dev \
openexr \
proj-data \
v4l-utils \
wget \
x264 \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*

COPY --from=builder /opt/conda /opt/conda

ENV OPENCV_IO_ENABLE_OPENEXR=1
# Add Conda to PATH so we can use it
ENV PATH=/opt/conda/bin:$PATH

RUN conda activate my &&\
pip install scikit-image omegaconf opencv-contrib-python imgaug Ninja timm albumentations nodejs jupyterlab scipy joblib scikit-learn ruamel.yaml trimesh pyyaml imageio open3d transformations einops gdown &&\
pip install -U git+https://github.com/lilohuang/PyTurboJPEG.git &&\
pip install flash-attn --no-build-isolation &&\
pip install xformers==0.0.28.post1 --index-url https://download.pytorch.org/whl/cu124
ENV SHELL=/bin/bash
RUN ln -sf /bin/bash /bin/sh
COPY ./docker/download_weights.py /tmp/download_weights.py
RUN conda run -n foundation_stereo python /tmp/download_weights.py --weights ${PRETRAINED_WEIGHTS} && rm -rf /tmp/download_weights.py

COPY . /FoundationStereo
WORKDIR /FoundationStereo

SHELL ["/bin/bash", "-c", "source ~/.bashrc && conda activate my"]
# When commands are executed, they will be run in the conda environment
ENTRYPOINT ["sh", "-c", ". /opt/conda/etc/profile.d/conda.sh && conda activate foundation_stereo && exec \"$@\"", "--"]
70 changes: 70 additions & 0 deletions docker/download_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import urllib.request
import argparse
import torch
from timm.models import create_model

# This script downloads the pretrained weights for FoundationStereo and its dependencies

HF_BASE_URL: str = "https://huggingface.co/datasets/steve-redefine/FoundationStereoWeights/resolve/main"
ROOT_DIR: str = "/FoundationStereo/pretrained_models"

def download_file(url: str, dest_path: str) -> None:
"""Download a file from a URL to a given local path using standard Python."""
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
print(f"Downloading {url} → {dest_path}")
try:
urllib.request.urlretrieve(url, dest_path)
except Exception as e:
print(f"❌ Failed to download {url}: {e}")
raise

def download_pretrained_weights(model_name: str) -> None:
"""Download FoundationStereo pretrained weights from Hugging Face based on the model name."""
if model_name == "23-51-11":
model_dir = os.path.join(ROOT_DIR, "23-51-11")
download_file(f"{HF_BASE_URL}/23-51-11/model_best_bp2.pth", f"{model_dir}/model_best_bp2.pth")
download_file(f"{HF_BASE_URL}/23-51-11/cfg.yaml", f"{model_dir}/cfg.yaml")

elif model_name == "11-33-40":
model_dir = os.path.join(ROOT_DIR, "11-33-40")
download_file(f"{HF_BASE_URL}/11-33-40/model_best_bp2.pth", f"{model_dir}/model_best_bp2.pth")
download_file(f"{HF_BASE_URL}/11-33-40/cfg.yaml", f"{model_dir}/cfg.yaml")

elif model_name == "onnx":
model_dir = os.path.join(ROOT_DIR, "onnx")
download_file(f"{HF_BASE_URL}/onnx/foundation_stereo_23-51-11.onnx", f"{model_dir}/foundation_stereo_23-51-11.onnx")

else:
raise ValueError(f"❌ Unrecognized model name: {model_name}")


def download_torchhub_and_timm_models() -> None:
"""Preload Torch Hub (DINOv2) and timm (EdgeNeXt) model weights."""
print("⬇️ Downloading DINOv2 repo from Torch Hub...")
torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', source='github', trust_repo=True)

print("⬇️ Downloading timm model (edgenext_small.usi_in1k)...")
create_model('edgenext_small.usi_in1k', pretrained=True)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Download FoundationStereo weights and dependencies.")
parser.add_argument(
"--weights",
type=str,
choices=["23-51-11", "11-33-40", "onnx", ""],
help="Which pretrained weights to download. If empty, nothing, including dependency models, will be downloaded.",
)
return parser.parse_args()

if __name__ == "__main__":
args = parse_args()
weights = args.weights
if weights == "":
print("No pretrained weights selected. Skipping download. Also not downloading Torch Hub and timm models.")
exit(0)
print(f"Selected pretrained model: {weights}")
download_pretrained_weights(weights)
download_torchhub_and_timm_models()
print("✅ All model weights downloaded successfully.")
35 changes: 35 additions & 0 deletions docker/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: foundation_stereo
channels:
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
- python=3.9
- pip
- pip:
- torch==2.4.1
- torchvision==0.19.1
- torchaudio==2.4.1
- scikit-image
- omegaconf
- opencv-contrib-python
- imgaug
- ninja
- timm
- albumentations
- jupyterlab
- scipy
- joblib
- scikit-learn
- ruamel.yaml
- trimesh
- pyyaml
- imageio
- open3d
- transformations
- einops
- gdown
- nodejs
- xformers==0.0.28.post1
- huggingface-hub