Skip to content

Commit 9015703

Browse files
authored
Merge pull request #29 from hmorimitsu/docfixes
Docfixes
2 parents 08a3c9e + c72126b commit 9015703

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

.github/workflows/pytest.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
python-version: [3.6, 3.9]
20-
torch: [1.7.1, 1.9.0]
21-
pytorch-lightning: [1.1.8, 1.4.0]
20+
torch: [1.7.1, 1.10.0]
21+
pytorch-lightning: [1.1.8, 1.5.2]
2222

2323
steps:
2424
- uses: actions/checkout@v2

docs/source/starting/inference.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ The code below shows a way to do this:
7171
predictions = model(inputs)
7272
7373
# Remove extra padding that may have been added to the inputs
74-
predictions = io_adapter.unpad(predictions)
74+
predictions = io_adapter.unpad_and_unscale(predictions)
7575
7676
# The output is a dict with possibly several keys,
7777
# but it should always store the optical flow prediction in a key called 'flows'.
@@ -93,4 +93,12 @@ The code below shows a way to do this:
9393
cv.imshow('image1', images[0])
9494
cv.imshow('image2', images[1])
9595
cv.imshow('flow', flow_bgr_npy)
96-
cv.waitKey()
96+
cv.waitKey()
97+
98+
Inference on batches of images
99+
==============================
100+
101+
For simplicity, the base PTLFlow scripts do not provide a direct way to do inference on batches.
102+
However, it should be easy to extend the base scripts to your use case.
103+
One example of a workaround to work with batches can be found in
104+
`[this GitHub issue] <https://github.com/hmorimitsu/ptlflow/issues/28>`__.

ptlflow/models/base_model/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def configure_optimizers(self) -> Dict[str, Any]:
338338

339339
self.train_dataloader() # Just to initialize dataloader variables
340340

341-
if self.args.max_steps is None:
341+
if self.args.max_steps is None or self.args.max_steps <= 0:
342342
if self.args.max_epochs is None:
343343
self.args.max_epochs = 10
344344
logging.warning('--max_epochs is not set. It will be set to %d.', self.args.max_epochs)

0 commit comments

Comments
 (0)