Skip to content

Commit cc91644

Browse files
authored
Merge pull request #84 from hmorimitsu/dpflow
Dpflow
2 parents 753922d + ace9ff7 commit cc91644

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+5366
-65
lines changed

.github/workflows/lightning.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
lightning: ["2.1.4", "2.2.5", "2.3.3", "2.4.0"]
19+
lightning: ["2.1.4", "2.2.5", "2.3.3", "2.4.0", "2.5.0"]
2020

2121
steps:
2222
- uses: actions/checkout@v4

.github/workflows/pytorch.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
pytorch: [
20+
'torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cpu',
2021
'torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu',
2122
'torch==2.4.1 torchvision==0.19.1 --index-url https://download.pytorch.org/whl/cpu',
2223
'torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cpu',
23-
'torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cpu',
2424
]
2525

2626
steps:

model_benchmark.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,7 @@
3636
from ptlflow.utils.lightning.ptlflow_cli import PTLFlowCLI
3737
from ptlflow.utils.registry import RegisteredModel
3838
from ptlflow.utils.timer import Timer
39-
from ptlflow.utils.utils import (
40-
count_parameters,
41-
make_divisible,
42-
)
39+
from ptlflow.utils.utils import count_parameters
4340

4441
NUM_COMMON_COLUMNS = 6
4542
TABLE_KEYS_LEGENDS = {
@@ -302,8 +299,8 @@ def benchmark(args: Namespace, device_handle) -> pd.DataFrame:
302299
1,
303300
2,
304301
3,
305-
make_divisible(input_size[0], model.output_stride),
306-
make_divisible(input_size[1], model.output_stride),
302+
input_size[0],
303+
input_size[1],
307304
)
308305
}
309306

@@ -372,7 +369,7 @@ def benchmark(args: Namespace, device_handle) -> pd.DataFrame:
372369
)
373370
except Exception as e: # noqa: B902
374371
logger.warning(
375-
"Skipping model %s with datatype %s due to exception %s",
372+
"Skipping model {} with datatype {} due to exception {}",
376373
mname,
377374
dtype_str,
378375
e,
@@ -440,8 +437,8 @@ def estimate_inference_time(
440437
args.batch_size,
441438
2,
442439
3,
443-
make_divisible(input_size[0], model.output_stride),
444-
make_divisible(input_size[1], model.output_stride),
440+
input_size[0],
441+
input_size[1],
445442
)
446443
}
447444
if torch.cuda.is_available():

plot_results.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
import argparse
2020
import logging
2121
from pathlib import Path
22-
from typing import Optional, Tuple, Union
22+
from typing import Union
2323

24-
import numpy as np
2524
import pandas as pd
2625
import plotly.express as px
2726

ptlflow/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,14 +236,17 @@ def load_checkpoint(ckpt_path: str, model_ref: BaseModel) -> Dict[str, Any]:
236236
device = "cuda" if torch.cuda.is_available() else "cpu"
237237

238238
if Path(ckpt_path).exists():
239-
ckpt = torch.load(ckpt_path, map_location=torch.device(device))
239+
ckpt = torch.load(
240+
ckpt_path, map_location=torch.device(device), weights_only=True
241+
)
240242
else:
241243
model_dir = Path(hub.get_dir()) / "checkpoints"
242244
ckpt = hub.load_state_dict_from_url(
243245
ckpt_path,
244246
model_dir=model_dir,
245247
map_location=torch.device(device),
246248
check_hash=True,
249+
weights_only=True,
247250
)
248251
return ckpt
249252

ptlflow/data/datasets.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from loguru import logger
2626
import numpy as np
2727
import torch
28+
import torch.nn.functional as F
2829
from torch.utils.data import Dataset
2930
from ptlflow.utils import flow_utils
3031

@@ -1662,7 +1663,6 @@ def __init__( # noqa: C901
16621663
reverse_only: bool = False,
16631664
subsample: bool = False,
16641665
is_image_4k: bool = False,
1665-
image_4k_split_dir_suffix: str = "_4k",
16661666
) -> None:
16671667
"""Initialize SintelDataset.
16681668
@@ -1705,11 +1705,6 @@ def __init__( # noqa: C901
17051705
If False, and is_image_4k is True, then the groundtruth is returned in its original 4D-shaped 4K resolution, but the flow values are doubled.
17061706
is_image_4k : bool, default False
17071707
If True, assumes the input images will be provided in 4K resolution, instead of the original 2K.
1708-
image_4k_split_dir_suffix : str, default "_4k"
1709-
Only used when is_image_4k == True. It indicates the suffix to add to the split folder name where the 4k images are located.
1710-
For example, by default, the 4K images need to be located inside folders called "train_4k" and/or "test_4k".
1711-
The structure of these folders should be the same as the original "train" and "test".
1712-
The "*_4k" folders only need to contain the image directories, the groundtruth will still be loaded from the original locations.
17131708
"""
17141709
if isinstance(side_names, str):
17151710
side_names = [side_names]
@@ -1731,7 +1726,6 @@ def __init__( # noqa: C901
17311726
self.sequence_position = sequence_position
17321727
self.subsample = subsample
17331728
self.is_image_4k = is_image_4k
1734-
self.image_4k_split_dir_suffix = image_4k_split_dir_suffix
17351729

17361730
if self.is_image_4k:
17371731
assert not self.subsample
@@ -1758,17 +1752,9 @@ def __init__( # noqa: C901
17581752
for side in side_names:
17591753
for direcs in directions:
17601754
rev = direcs[0] == "BW"
1761-
img_split_dir_name = (
1762-
f"{split_dir}{self.image_4k_split_dir_suffix}"
1763-
if self.is_image_4k
1764-
else split_dir
1765-
)
17661755
image_paths = sorted(
17671756
(
1768-
Path(self.root_dir)
1769-
/ img_split_dir_name
1770-
/ seq_name
1771-
/ f"frame_{side}"
1757+
Path(self.root_dir) / split_dir / seq_name / f"frame_{side}"
17721758
).glob("*.png"),
17731759
reverse=rev,
17741760
)
@@ -1883,12 +1869,38 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901
18831869
if self.transform is not None:
18841870
inputs = self.transform(inputs)
18851871
elif self.is_image_4k:
1872+
inputs["images"] = [
1873+
cv.resize(img, None, fx=2, fy=2, interpolation=cv.INTER_CUBIC)
1874+
for img in inputs["images"]
1875+
]
18861876
if self.transform is not None:
18871877
inputs = self.transform(inputs)
18881878
if "flows" in inputs:
18891879
inputs["flows"] = 2 * inputs["flows"]
18901880
if self.get_backward:
18911881
inputs["flows_b"] = 2 * inputs["flows_b"]
1882+
1883+
process_keys = [("flows", "valids")]
1884+
if self.get_backward:
1885+
process_keys.append(("flows_b", "valids_b"))
1886+
1887+
for flow_key, valid_key in process_keys:
1888+
flow = inputs[flow_key]
1889+
flow_stack = rearrange(
1890+
flow, "b c (h nh) (w nw) -> b (nh nw) c h w", nh=2, nw=2
1891+
)
1892+
flow_stack4 = flow_stack.repeat(1, 4, 1, 1, 1)
1893+
flow_stack4 = rearrange(
1894+
flow_stack4, "b (m n) c h w -> b m n c h w", m=4
1895+
)
1896+
diff = flow_stack[:, :, None] - flow_stack4
1897+
diff = rearrange(diff, "b m n c h w -> b (m n) c h w")
1898+
diff = torch.sqrt(torch.pow(diff, 2).sum(2))
1899+
max_diff, _ = diff.max(1)
1900+
max_diff = F.interpolate(
1901+
max_diff[:, None], scale_factor=2, mode="nearest"
1902+
)
1903+
inputs[valid_key] = (max_diff < 1.0).float()
18921904
else:
18931905
if self.transform is not None:
18941906
inputs = self.transform(inputs)

ptlflow/data/flow_datamodule.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def __init__(
6464
tartanair_root_dir: Optional[str] = None,
6565
spring_root_dir: Optional[str] = None,
6666
kubric_root_dir: Optional[str] = None,
67+
middlebury_st_root_dir: Optional[str] = None,
68+
viper_root_dir: Optional[str] = None,
6769
dataset_config_path: str = "./datasets.yaml",
6870
):
6971
super().__init__()
@@ -89,6 +91,8 @@ def __init__(
8991
self.tartanair_root_dir = tartanair_root_dir
9092
self.spring_root_dir = spring_root_dir
9193
self.kubric_root_dir = kubric_root_dir
94+
self.middlebury_st_root_dir = middlebury_st_root_dir
95+
self.viper_root_dir = viper_root_dir
9296
self.dataset_config_path = dataset_config_path
9397

9498
self.predict_dataset_parsed = None
@@ -935,6 +939,7 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
935939
sequence_position = "first"
936940
reverse_only = False
937941
subsample = False
942+
is_image_4k = False
938943
side_names = []
939944
fbocc_transform = False
940945
for v in args:
@@ -952,6 +957,8 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
952957
sequence_position = v.split("_")[1]
953958
elif v == "sub":
954959
subsample = True
960+
elif v == "4k":
961+
is_image_4k = True
955962
elif v == "left":
956963
side_names.append("left")
957964
elif v == "right":
@@ -1012,6 +1019,7 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
10121019
sequence_position=sequence_position,
10131020
reverse_only=reverse_only,
10141021
subsample=subsample,
1022+
is_image_4k=is_image_4k,
10151023
)
10161024
return dataset
10171025

ptlflow/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .csflow import *
44
from .dicl import *
55
from .dip import *
6+
from .dpflow import *
67
from .fastflownet import *
78
from .flow1d import *
89
from .flowformer import *

ptlflow/models/base_model/base_model.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,18 +69,25 @@ def __init__(
6969
lr: Optional[float] = None,
7070
wdecay: Optional[float] = None,
7171
warm_start: bool = False,
72+
metric_interpolate_pred_to_target_size: bool = False,
7273
) -> None:
7374
"""Initialize BaseModel.
7475
7576
Parameters
7677
----------
77-
args : Namespace
78-
A namespace with the required arguments. Typically, this can be gotten from add_model_specific_args().
79-
loss_fn : Callable
80-
A function to be used to compute the loss for the training. The input of this function must match the output of the
81-
forward() method. The output of this function must be a tensor with a single value.
8278
output_stride : int
8379
How many times the output of the network is smaller than the input.
80+
loss_fn : Optional[Callable]
81+
A function to be used to compute the loss for the training. The input of this function must match the output of the
82+
forward() method. The output of this function must be a tensor with a single value.
83+
lr : Optional[float]
84+
The learning rate to be used for training the model. If not provided, it will be set to 1e-4.
85+
wdecay : Optional[float]
86+
The weight decay to be used for training the model. If not provided, it will be set to 1e-4.
87+
warm_start : bool, default False
88+
If True, use warm start to initialize the flow prediction. The warm_start strategy was presented by the RAFT method and forward interpolates the prediction from the last frame.
89+
metric_interpolate_pred_to_target_size : bool, default False
90+
If True, the prediction is bilinearly interpolated to match the target size during metric calculation, if their sizes are different.
8491
"""
8592
super(BaseModel, self).__init__()
8693

@@ -89,13 +96,19 @@ def __init__(
8996
self.lr = lr
9097
self.wdecay = wdecay
9198
self.warm_start = warm_start
99+
self.metric_interpolate_pred_to_target_size = (
100+
metric_interpolate_pred_to_target_size
101+
)
92102

93103
self.train_size = None
94104
self.train_avg_length = None
95105

96106
self.extra_params = None
97107

98-
self.train_metrics = FlowMetrics(prefix="train/")
108+
self.train_metrics = FlowMetrics(
109+
prefix="train/",
110+
interpolate_pred_to_target_size=self.metric_interpolate_pred_to_target_size,
111+
)
99112
self.val_metrics = nn.ModuleList()
100113
self.val_dataset_names = []
101114

@@ -132,6 +145,7 @@ def add_extra_param(self, name, value):
132145
def preprocess_images(
133146
self,
134147
images: torch.Tensor,
148+
stride: Optional[int] = None,
135149
bgr_add: Union[float, Tuple[float, float, float], np.ndarray, torch.Tensor] = 0,
136150
bgr_mult: Union[
137151
float, Tuple[float, float, float], np.ndarray, torch.Tensor
@@ -201,7 +215,7 @@ def preprocess_images(
201215
if bgr_to_rgb:
202216
images = torch.flip(images, [-3])
203217

204-
stride = self.output_stride
218+
stride = self.output_stride if stride is None else stride
205219
if target_size is not None:
206220
stride = None
207221

@@ -371,7 +385,10 @@ def validation_step(
371385
"""
372386
if len(self.val_metrics) <= dataloader_idx:
373387
self.val_metrics.append(
374-
FlowMetrics(prefix="val/").to(device=batch["flows"].device)
388+
FlowMetrics(
389+
prefix="val/",
390+
interpolate_pred_to_target_size=self.metric_interpolate_pred_to_target_size,
391+
).to(device=batch["flows"].device)
375392
)
376393
self.val_dataset_names.append(None)
377394

0 commit comments

Comments
 (0)