Skip to content

Commit 069b237

Browse files
authored
Merge pull request #79 from hmorimitsu/v04
V04
2 parents 9f20eec + be833e9 commit 069b237

File tree

270 files changed

+11551
-6064
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

270 files changed

+11551
-6064
lines changed

.github/workflows/build.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
strategy:
1919
fail-fast: false
2020
matrix:
21-
python-version: ['3.10']
21+
python-version: ['3.12']
2222

2323
steps:
2424
- uses: actions/checkout@v4
@@ -29,10 +29,10 @@ jobs:
2929
- name: Install dependencies
3030
run: |
3131
python -m pip install --upgrade pip
32-
python -m pip install build==1.0.3
33-
python -m pip install --upgrade setuptools==68.0.0 wheel
34-
python -m pip install --upgrade pytest
35-
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu
32+
python -m pip install build==1.2.2.post1
33+
python -m pip install --upgrade setuptools==75.6.0 wheel==0.45.1
34+
python -m pip install --upgrade pytest==8.3.3
35+
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu
3636
- name: Install package and remove local dir
3737
run: |
3838
python -m build

.github/workflows/lightning.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,23 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
lightning: [1.9.5]
19+
lightning: ["2.1.4", "2.2.5", "2.3.3", "2.4.0"]
2020

2121
steps:
2222
- uses: actions/checkout@v4
2323
- name: Replace lightning
2424
uses: jacobtomlinson/gha-find-replace@v3
2525
with:
26-
find: "lightning<2"
27-
replace: "lightning==${{ matrix.lightning }}"
26+
find: "lightning[pytorch-extra]>=2,<2.5"
27+
replace: "lightning[pytorch-extra]==${{ matrix.lightning }}"
2828
regex: false
2929
include: "requirements.txt"
3030
- name: Install dependencies
3131
run: |
3232
python -m pip install --upgrade pip
33-
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu
33+
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu
3434
pip install -r requirements.txt
3535
- name: Test with pytest
3636
run: |
37-
pip install pytest
37+
pip install pytest==8.3.3
3838
python -m pytest tests/

.github/workflows/publish_pypi.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ jobs:
2626
- name: Install dependencies
2727
run: |
2828
python -m pip install --upgrade pip
29-
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu
30-
pip install build==1.0.3
29+
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu
30+
pip install build==1.2.2.post1
3131
- name: Build package
3232
run: python -m build
3333
- name: Publish package

.github/workflows/python.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
python-version: ['3.8', '3.9', '3.10', '3.11']
19+
python-version: ['3.10', '3.11', '3.12']
2020

2121
steps:
2222
- uses: actions/checkout@v4
@@ -27,9 +27,9 @@ jobs:
2727
- name: Install dependencies
2828
run: |
2929
python -m pip install --upgrade pip
30-
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu
30+
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu
3131
pip install -r requirements.txt
3232
- name: Test with pytest
3333
run: |
34-
pip install pytest
34+
pip install pytest==8.3.3
3535
python -m pytest tests/

.github/workflows/pytorch.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
pytorch: [
20-
'torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu',
21-
'torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cpu',
22-
'torch==1.13.1+cpu torchvision==0.14.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu',
20+
'torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu',
21+
'torch==2.4.1 torchvision==0.19.1 --index-url https://download.pytorch.org/whl/cpu',
22+
'torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cpu',
23+
'torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cpu',
2324
]
2425

2526
steps:
@@ -31,5 +32,5 @@ jobs:
3132
pip install -r requirements.txt
3233
- name: Test with pytest
3334
run: |
34-
pip install pytest
35+
pip install pytest==8.3.3
3536
python -m pytest tests/

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
lightning_logs/
12
ptlflow_logs/
23
ptlflow_scripts/
34
outputs/
5+
ckpts/
46

57
# Byte-compiled / optimized / DLL files
68
__pycache__/

README.md

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ This is still under development, so some things may not work as intended. I plan
2424

2525
## What's new
2626

27+
### - v0.4.0
28+
29+
Major update to support Lightning 2 (finally!). However, it also introduces breaking changes from the previous v0.3 code. See the details below.
30+
31+
- Transitioning from v0.3 to v0.4: check the [v0.4 upgrade guide](https://ptlflow.readthedocs.io/en/latest/starting/v04_upgrade_guide.html)
32+
- Added features:
33+
- Support for YAML config files. See the [config file documentation](https://ptlflow.readthedocs.io/en/latest/starting/config_files.html)
34+
- Table [comparing PTLFlow results with the original papers](https://ptlflow.readthedocs.io/en/latest/results/paper_ptlflow.html) to check the stability of the included models.
35+
- Added new models:
36+
- NeuFlow v2 [https://arxiv.org/abs/2408.10161](https://arxiv.org/abs/2408.10161)
37+
- Add support for more datasets:
38+
- Middlebury-ST [https://vision.middlebury.edu/stereo/data/scenes2014/]{https://vision.middlebury.edu/stereo/data/scenes2014/}
39+
- VIPER [https://playing-for-benchmarks.org/](https://playing-for-benchmarks.org/)
40+
2741
### - v0.3.2
2842

2943
- Added new models:
@@ -129,8 +143,13 @@ Please take a look at the [documentation](https://ptlflow.readthedocs.io/) to le
129143

130144
You can also check the notebooks below running on Google Colab for some practical examples:
131145

132-
- [Inference with a pretrained model](https://colab.research.google.com/drive/1YARBRUGplqTRnRuY9sKIs6LY_2kWAWZJ?usp=sharing).
133-
- [Training and using the learned weights for inference](https://colab.research.google.com/drive/1mbuAEF728_jZpFEsQHXDxjIGAcB1-nVs?usp=sharing).
146+
- [Inference with a pretrained model](https://colab.research.google.com/drive/1_WXvIRweQJgex0X-HS0LFXBb0IWZIvR4?usp=sharing).
147+
- [Training and using the learned weights for inference](https://colab.research.google.com/drive/1b_SMGSXh9F9TkinqZt0c64EH-GE87HVi?usp=sharing).
148+
149+
If you are using the previous v0.3.X code, then check the [v0.3.2 documentation](https://ptlflow.readthedocs.io/en/v0.3.2/) and the following example notebooks:
150+
151+
- [Inference with a pretrained model (PTLFlow v0.3)](https://colab.research.google.com/drive/1YARBRUGplqTRnRuY9sKIs6LY_2kWAWZJ?usp=sharing).
152+
- [Training and using the learned weights for inference (PTLFlow v0.3)](https://colab.research.google.com/drive/1mbuAEF728_jZpFEsQHXDxjIGAcB1-nVs?usp=sharing).
134153

135154
## Licenses
136155

compare_paper_results.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Create a side-by-side table comparing the results of PTLFlow with those reported in the original papers.
2+
3+
This script only evaluates results of models that provide the "things" pretrained models.
4+
5+
Tha parsing of this script is tightly connected to how the results are output by validate.py.
6+
"""
7+
8+
# =============================================================================
9+
# Copyright 2024 Henrique Morimitsu
10+
#
11+
# Licensed under the Apache License, Version 2.0 (the "License");
12+
# you may not use this file except in compliance with the License.
13+
# You may obtain a copy of the License at
14+
#
15+
# http://www.apache.org/licenses/LICENSE-2.0
16+
#
17+
# Unless required by applicable law or agreed to in writing, software
18+
# distributed under the License is distributed on an "AS IS" BASIS,
19+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20+
# See the License for the specific language governing permissions and
21+
# limitations under the License.
22+
# =============================================================================
23+
24+
import argparse
25+
import math
26+
from pathlib import Path
27+
28+
from loguru import logger
29+
import pandas as pd
30+
31+
PAPER_VAL_COLS = {
32+
"model": ("Model", "model"),
33+
"sclean": ("S.clean", "sintel-clean-val/epe"),
34+
"sfinal": ("S.final", "sintel-final-val/epe"),
35+
"k15epe": ("K15-epe", "kitti-2015-val/epe"),
36+
"k15fl": ("K15-fl", "kitti-2015-val/flall"),
37+
}
38+
39+
40+
def _init_parser() -> argparse.ArgumentParser:
41+
parser = argparse.ArgumentParser()
42+
parser.add_argument(
43+
"--paper_results_path",
44+
type=str,
45+
default=str(Path("docs/source/results/paper_results_things.csv")),
46+
help=("Path to the csv file containing the results from the papers."),
47+
)
48+
parser.add_argument(
49+
"--validate_results_path",
50+
type=str,
51+
default=str(Path("docs/source/results/metrics_all_things.csv")),
52+
help=(
53+
"Path to the csv file containing the results obtained by the validate script."
54+
),
55+
)
56+
parser.add_argument(
57+
"--output_dir",
58+
type=str,
59+
default=str(Path("outputs/metrics")),
60+
help=("Path to the directory where the outputs will be saved."),
61+
)
62+
parser.add_argument(
63+
"--add_delta",
64+
action="store_true",
65+
help=(
66+
"If set, adds one more column showing the difference between paper and validation results."
67+
),
68+
)
69+
70+
return parser
71+
72+
73+
def save_results(args: argparse.Namespace) -> None:
74+
paper_df = pd.read_csv(args.paper_results_path)
75+
val_df = pd.read_csv(args.validate_results_path)
76+
paper_df["model"] = paper_df[PAPER_VAL_COLS["model"][0]]
77+
val_df["model"] = val_df[PAPER_VAL_COLS["model"][1]]
78+
df = pd.merge(val_df, paper_df, "left", "model")
79+
80+
compare_cols = ["ptlflow", "paper"]
81+
if args.add_delta:
82+
compare_cols.append("delta")
83+
84+
out_dict = {"model": ["", ""]}
85+
for name in list(PAPER_VAL_COLS.keys())[1:]:
86+
for ic, col in enumerate(compare_cols):
87+
out_dict[f"{name}-{col}"] = [name if ic == 0 else "", col]
88+
89+
for _, row in df.iterrows():
90+
out_dict["model"].append(row["model"])
91+
for key in list(PAPER_VAL_COLS.keys())[1:]:
92+
paper_col_name = PAPER_VAL_COLS[key][0]
93+
paper_res = float(row[paper_col_name])
94+
val_col_name = PAPER_VAL_COLS[key][1]
95+
val_res = float(row[val_col_name])
96+
res_list = [val_res, paper_res]
97+
98+
if args.add_delta:
99+
delta = val_res - paper_res
100+
res_list.append(delta)
101+
102+
for name, res in zip(compare_cols, res_list):
103+
out_dict[f"{key}-{name}"].append(
104+
"" if (math.isinf(res) or math.isnan(res)) else f"{res:.3f}"
105+
)
106+
107+
out_df = pd.DataFrame(out_dict)
108+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
109+
output_path = Path(args.output_dir) / "paper_ptlflow_metrics.csv"
110+
out_df.to_csv(output_path, index=False, header=False)
111+
logger.info("Results saved to: {}", output_path)
112+
113+
114+
if __name__ == "__main__":
115+
parser = _init_parser()
116+
args = parser.parse_args()
117+
save_results(args)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# lightning.pytorch==2.4.0
2+
# Use this config to benchmark all the models.
3+
# python validate.py --config configs/results/model_benchmark_all.yaml
4+
all: true
5+
select: null
6+
ckpt_path: null
7+
exclude: null
8+
csv_path: null
9+
num_trials: 1
10+
num_samples: 10
11+
sleep_interval: 0.0
12+
input_size:
13+
- 500
14+
- 1000
15+
output_path: outputs/benchmark
16+
final_speed_mode: median
17+
final_memory_mode: first
18+
plot_axes: null
19+
plot_log_x: false
20+
plot_log_y: false
21+
datatypes:
22+
- fp32
23+
batch_size: 1
24+
seed_everything: true

configs/results/validate_all.yaml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# lightning.pytorch==2.4.0
2+
# Use this config to generate validation results for all models using all their pretrained ckpts.
3+
# python validate.py --config configs/results/validate_all.yaml
4+
all: true
5+
select: null
6+
exclude: null
7+
ckpt_path: things
8+
output_path: outputs/validate
9+
write_outputs: false
10+
show: false
11+
flow_format: original
12+
max_forward_side: null
13+
scale_factor: null
14+
max_show_side: 1000
15+
max_samples: null
16+
reversed: false
17+
fp16: false
18+
seq_val_mode: all
19+
write_individual_metrics: false
20+
epe_clip: 5.0
21+
seed_everything: true
22+
data:
23+
predict_dataset: null
24+
test_dataset: null
25+
train_dataset: null
26+
val_dataset: sintel-clean+sintel-final+kitti-2015
27+
train_batch_size: null
28+
train_num_workers: 4
29+
train_crop_size: null
30+
train_transform_cuda: false
31+
train_transform_fp16: false
32+
autoflow_root_dir: null
33+
flying_chairs_root_dir: null
34+
flying_chairs2_root_dir: null
35+
flying_things3d_root_dir: null
36+
flying_things3d_subset_root_dir: null
37+
mpi_sintel_root_dir: null
38+
kitti_2012_root_dir: null
39+
kitti_2015_root_dir: null
40+
hd1k_root_dir: null
41+
tartanair_root_dir: null
42+
spring_root_dir: null
43+
kubric_root_dir: null
44+
dataset_config_path: ./datasets.yaml

0 commit comments

Comments
 (0)