45 commits
aa34aa0
created class BlobLoader and moved all related functions to a separate file
Feb 28, 2023
f745dfc
added type hints and deleted chore pyre-ignore
Feb 28, 2023
c3c5110
linter
Feb 28, 2023
9b431bd
linter
Feb 28, 2023
c74261d
Merge branch 'main' into main
salaxieb Feb 28, 2023
627e60f
deleted chore pyre-ignore
Feb 28, 2023
d0a2d4d
Merge branch 'main' of github.com:salaxieb/pytorch3d
Feb 28, 2023
0aa27a6
renamed load_blob to blob_loader
Mar 1, 2023
53823cf
sending whole seq_annotation to BlobLoader
Mar 1, 2023
d6f13eb
made blob_loader a dataclass to avoid boilerplate
Mar 1, 2023
86e64f7
documented that FrameData modification is done in place
Mar 1, 2023
2f17049
split JsonIndexDataset args into 2 groups: Metadata-related and Blo…
Mar 1, 2023
527ec09
code refactoring to delete chore pyre-ignore
Mar 1, 2023
24b731b
deleted chore function
Mar 6, 2023
f484a12
BlobLoader tests boilerplate
Mar 6, 2023
b8674ea
tests WIP (not tested)
Mar 7, 2023
faeffcf
tests typos and errors WIP
Mar 9, 2023
bc24e29
tests typos and errors WIP
Mar 9, 2023
e9c5969
solved errors and typos for test_bbox
Mar 9, 2023
44cfcfb
updating test_blob_loader WIP
Mar 9, 2023
11def0a
blob loader tests ready for review
Mar 9, 2023
bc52382
typo
Mar 9, 2023
0149377
typo
Mar 9, 2023
3bcbd01
linter
Mar 9, 2023
269cffa
all entry tests run through all frames
Mar 9, 2023
f930d71
assert .. == .. to self.assertEqual(.., ..)
Mar 10, 2023
dc7a702
testing only on 1 frame
Mar 10, 2023
fcd8d8b
instead of loading the whole dataset, loading only single frame annots
Mar 10, 2023
c3bd722
added default values to BlobLoader to ease initialisation
Mar 10, 2023
cb34c01
making tests run on a single loaded frame
Mar 10, 2023
04b7d15
made _resize_image a separate function (will ease use in pixar replay)
Mar 10, 2023
76f45aa
typo in function arguments
Mar 10, 2023
e5d3a2b
moved tests for _resize_image to test_bbox
Mar 10, 2023
1ba1a3a
np array instead of tensor to resize_image
Mar 10, 2023
cd9aa5c
setting default scale value to the correct one
Mar 13, 2023
ce9fd40
renamed function to load_ to make the inplace modification more obvious
Mar 14, 2023
f217eb1
moved crop_by_bbox to FrameData as method
Mar 14, 2023
664d35d
tests fix, typos, linter
Mar 14, 2023
5c249db
renamed crop to crop_ to show inplace modification
Mar 14, 2023
530b9a4
shifting camera according to bbox
Mar 14, 2023
e5500f3
delegated resize_image to FrameData, made bbox_xywh optional external p…
Mar 15, 2023
0fc3253
using safe_as_tensor for fg_probability
Mar 15, 2023
7c8d89d
made resizing only for loaded objects
Mar 15, 2023
3027cd7
fixing scale
Mar 15, 2023
7d570c1
fixing scale again..
Mar 15, 2023
349 changes: 349 additions & 0 deletions pytorch3d/implicitron/dataset/blob_loader.py
@@ -0,0 +1,349 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.utils import _get_bbox_from_mask
from pytorch3d.io import IO
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds


@dataclass
class BlobLoader:
"""
A loader for correctly (according to the given setup) loading blobs for FrameData.
Beware that modifications are done in place.

Args:
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
load_images: Enable loading the frame RGB data.
load_depths: Enable loading the frame depth maps.
load_depth_masks: Enable loading the frame depth map masks denoting the
depth values used for evaluation (the points consistent across views).
load_masks: Enable loading frame foreground masks.
load_point_clouds: Enable loading sequence-level point clouds.
max_points: Cap on the number of loaded points in the point cloud;
if reached, they are randomly sampled without replacement.
mask_images: Whether to mask the images with the loaded foreground masks;
0 value is used for background.
mask_depths: Whether to mask the depth maps with the loaded foreground
masks; 0 value is used for background.
image_height: The height of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
image_width: The width of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
box_crop: Enable cropping of the image around the bounding box inferred
from the foreground region of the loaded segmentation mask; masks
and depth maps are cropped accordingly; cameras are corrected.
box_crop_mask_thr: The threshold used to separate pixels into foreground
and background based on the foreground_probability mask; if no value
is greater than this threshold, the loader lowers it and repeats.
box_crop_context: The amount of additional padding added to each
dimension of the cropping bounding box, relative to box size.
"""

dataset_root: str = ""
load_images: bool = True
load_depths: bool = True
load_depth_masks: bool = True
load_masks: bool = True
load_point_clouds: bool = False
max_points: int = 0
mask_images: bool = False
mask_depths: bool = False
image_height: Optional[int] = 800
image_width: Optional[int] = 800
box_crop: bool = True
box_crop_mask_thr: float = 0.4
box_crop_context: float = 0.3
path_manager: Any = None

def load_(
self,
frame_data: FrameData,
entry: types.FrameAnnotation,
seq_annotation: types.SequenceAnnotation,
bbox_xywh: Optional[torch.Tensor] = None,
) -> FrameData:
"""Main method for loader.
FrameData modification done inplace
if bbox_xywh not provided bbox will be calculated from mask
"""
(
frame_data.fg_probability,
frame_data.mask_path,
frame_data.bbox_xywh,
) = self._load_fg_probability(entry, bbox_xywh)

if self.load_images and entry.image is not None:
# original image size
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)

(
frame_data.image_rgb,
frame_data.image_path,
) = self._load_images(entry, frame_data.fg_probability)

if self.load_depths and entry.depth is not None:
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(entry, frame_data.fg_probability)

if self.load_point_clouds and seq_annotation.point_cloud is not None:
pcl_path = self._fix_point_cloud_path(seq_annotation.point_cloud.path)
frame_data.sequence_point_cloud = _load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
frame_data.sequence_point_cloud_path = pcl_path

clamp_bbox_xyxy = None
if self.box_crop:
clamp_bbox_xyxy = frame_data.crop_by_bbox_(self.box_crop_context)

scale = (
min(
self.image_height / entry.image.size[0],
# pyre-ignore
self.image_width / entry.image.size[1],
)
if self.image_height is not None and self.image_width is not None
else 1.0
)

if self.image_height is not None and self.image_width is not None:
optional_scale = frame_data.resize_frame_(
self.image_height, self.image_width
)
scale = optional_scale or scale

# creating the camera, taking into account the bbox crop and resize scale
if entry.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(
entry, scale, clamp_bbox_xyxy
)
return frame_data

def _load_fg_probability(
self,
entry: types.FrameAnnotation,
bbox_xywh: Optional[torch.Tensor],
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:
fg_probability = None
full_path = None

if (self.load_masks) and entry.mask is not None:
full_path = os.path.join(self.dataset_root, entry.mask.path)
fg_probability = _load_mask(self._local_path(full_path))
# use the provided bbox_xywh or calculate it from the mask
if bbox_xywh is None:
bbox_xywh = _get_bbox_from_mask(fg_probability, self.box_crop_mask_thr)
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
)

return (
_safe_as_tensor(fg_probability, torch.float),
full_path,
_safe_as_tensor(bbox_xywh, torch.long),
)

def _load_images(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str]:
assert self.dataset_root is not None and entry.image is not None
path = os.path.join(self.dataset_root, entry.image.path)
image_rgb = _load_image(self._local_path(path))

if image_rgb.shape[-2:] != entry.image.size:
raise ValueError(
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
)

if self.mask_images:
assert fg_probability is not None
image_rgb *= fg_probability

return image_rgb, path

def _load_mask_depth(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)

if self.mask_depths:
assert fg_probability is not None
depth_map *= fg_probability

if self.load_depth_masks:
assert entry_depth.mask_path is not None
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = _load_depth_mask(self._local_path(mask_path))
else:
depth_mask = torch.ones_like(depth_map)

return torch.tensor(depth_map), path, torch.tensor(depth_mask)

def _get_pytorch3d_camera(
self,
entry: types.FrameAnnotation,
scale: float,
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None
# principal point and focal length
principal_point = torch.tensor(
entry_viewpoint.principal_point, dtype=torch.float
)
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)

half_image_size_wh_orig = (
torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
)

# first, we convert from the dataset's NDC convention to pixels
format = entry_viewpoint.intrinsics_format
if format.lower() == "ndc_norm_image_bounds":
# this is e.g. currently used in CO3D for storing intrinsics
rescale = half_image_size_wh_orig
elif format.lower() == "ndc_isotropic":
rescale = half_image_size_wh_orig.min()
else:
raise ValueError(f"Unknown intrinsics format: {format}")

# principal point and focal length in pixels
principal_point_px = half_image_size_wh_orig - principal_point * rescale
focal_length_px = focal_length * rescale
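
# worked example with illustrative numbers (not from the dataset): for an image
# of size (h, w) = (200, 400) stored as "ndc_norm_image_bounds",
# half_image_size_wh_orig = (200, 100); a principal_point of (0.1, 0.2) then
# maps to principal_point_px = (200 - 0.1 * 200, 100 - 0.2 * 100) = (180, 80),
# and focal_length = 2.0 maps to focal_length_px = (400, 200)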

# changing principal_point according to bbox_crop
if clamp_bbox_xyxy is not None:
principal_point_px -= clamp_bbox_xyxy[:2]

# now, convert from pixels to PyTorch3D v0.5+ NDC convention
if self.image_height is None or self.image_width is None:
out_size = list(reversed(entry.image.size))
else:
out_size = [self.image_width, self.image_height]

half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
half_min_image_size_output = half_image_size_output.min()

# rescaled principal point and focal length in ndc
principal_point = (
half_image_size_output - principal_point_px * scale
) / half_min_image_size_output
focal_length = focal_length_px * scale / half_min_image_size_output

return PerspectiveCameras(
focal_length=focal_length[None],
principal_point=principal_point[None],
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)

def _fix_point_cloud_path(self, path: str) -> str:
"""
Fix up a point cloud path from the dataset.
Some files in Co3Dv2 have an accidental absolute path stored.
"""
unwanted_prefix = (
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
return os.path.join(self.dataset_root, path)

def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
return self.path_manager.get_local_path(path)


def _load_image(path) -> np.ndarray:
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im


def _load_mask(path) -> np.ndarray:
with Image.open(path) as pil_im:
mask = np.array(pil_im)
mask = mask.astype(np.float32) / 255.0
return mask[None] # fake feature channel


def _load_depth(path, scale_adjustment) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth file name "%s"' % path)

d = _load_16big_png_depth(path) * scale_adjustment
d[~np.isfinite(d)] = 0.0
return d[None] # fake feature channel


def _load_16big_png_depth(depth_png) -> np.ndarray:
with Image.open(depth_png) as depth_pil:
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
# we cast it to uint16, then reinterpret as float16, then cast to float32
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth


def _load_1bit_png_mask(file: str) -> np.ndarray:
with Image.open(file) as pil_im:
mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
return mask


def _load_depth_mask(path: str) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth mask file name "%s"' % path)
m = _load_1bit_png_mask(path)
return m[None] # fake feature channel


def _safe_as_tensor(data, dtype):
return torch.tensor(data, dtype=dtype) if data is not None else None


# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
pcl = IO().load_pointcloud(pcl_path)
if max_points > 0:
pcl = pcl.subsample(max_points)

return pcl
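
For context while reviewing: a minimal usage sketch of the new loader. The dataset root is a placeholder, and entry / seq_annotation stand in for parsed types.FrameAnnotation / types.SequenceAnnotation objects; this illustrates the intended call pattern and is not code from the PR.

import torch

from pytorch3d.implicitron.dataset.blob_loader import BlobLoader, _safe_as_tensor
from pytorch3d.implicitron.dataset.dataset_base import FrameData

loader = BlobLoader(
    dataset_root="/path/to/co3d",  # placeholder path
    image_height=800,
    image_width=800,
    box_crop=True,
)
frame_data = FrameData(
    frame_number=_safe_as_tensor(entry.frame_number, torch.long),
    frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
    sequence_name=entry.sequence_name,
    sequence_category=seq_annotation.category,
)
# load_ fills fg_probability, image_rgb, depth_map, camera, etc. in place
# (and returns frame_data for convenience)
frame_data = loader.load_(frame_data, entry, seq_annotation)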
114 changes: 114 additions & 0 deletions pytorch3d/implicitron/dataset/dataset_base.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import (
@@ -23,6 +24,14 @@

import numpy as np
import torch
from pytorch3d.implicitron.dataset.utils import (
_bbox_xyxy_to_xywh,
_clamp_box_to_image_bounds_and_round,
_crop_around_box,
_get_clamp_bbox,
_rescale_bbox,
_resize_image,
)
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
@@ -90,6 +99,7 @@ class FrameData(Mapping[str, Any]):
frame_type: The type of the loaded frame specified in
`subset_lists_file`, if provided.
meta: A dict for storing additional frame information.
cropped: Whether the frame has already been cropped; guards against cropping FrameData twice.
"""

frame_number: Optional[torch.LongTensor]
@@ -116,6 +126,7 @@ class FrameData(Mapping[str, Any]):
sequence_point_cloud_idx: Optional[torch.Tensor] = None
frame_type: Union[str, List[str], None] = None # known | unseen
meta: dict = field(default_factory=lambda: {})
cropped: bool = False

def to(self, *args, **kwargs):
new_params = {}
@@ -144,6 +155,109 @@ def __getitem__(self, key):
def __len__(self):
return len(fields(self))

def crop_by_bbox_(self, box_crop_context) -> Optional[torch.Tensor]:
if self.cropped:
warnings.warn(
f"You called cropping on same frame twice "
f"sequence_name: {self.sequence_name}, skipping cropping"
)
return None

if (
self.bbox_xywh is None
or self.fg_probability is None
or self.mask_path is None
or self.image_path is None
):
warnings.warn(
"You called cropping without loading frame data"
"please call blob_loader.load_ first, skipping cropping"
)
return None

bbox_xyxy = _get_clamp_bbox(
self.bbox_xywh,
# pyre-ignore
image_path=self.image_path,
box_crop_context=box_crop_context,
)
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
bbox_xyxy,
# pyre-ignore
image_size_hw=tuple(self.image_size_hw),
)
self.crop_bbox_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)

self.fg_probability = _crop_around_box(
self.fg_probability,
clamp_bbox_xyxy,
# pyre-ignore
self.mask_path,
)
self.image_rgb = _crop_around_box(
self.image_rgb,
clamp_bbox_xyxy,
# pyre-ignore
self.image_path,
)

if self.depth_map is not None:
self.depth_map = _crop_around_box(
self.depth_map,
clamp_bbox_xyxy,
# pyre-ignore
self.depth_path,
)
if self.depth_mask is not None:
self.depth_mask = _crop_around_box(
self.depth_mask,
clamp_bbox_xyxy,
# pyre-ignore
self.mask_path,
)
self.cropped = True
return clamp_bbox_xyxy

def resize_frame_(self, image_height, image_width) -> Optional[float]:
if self.bbox_xywh is not None:
self.bbox_xywh = _rescale_bbox(
self.bbox_xywh,
np.array(self.image_size_hw),
# pyre-ignore
self.image_rgb.shape[-2:],
)

scale = None
if self.image_rgb is not None:
self.image_rgb, scale, self.mask_crop = _resize_image(
self.image_rgb, image_height=image_height, image_width=image_width
)

if self.fg_probability is not None:
self.fg_probability, _, _ = _resize_image(
self.fg_probability,
image_height=image_height,
image_width=image_width,
mode="nearest",
)

if self.depth_map is not None:
self.depth_map, _, _ = _resize_image(
self.depth_map,
image_height=image_height,
image_width=image_width,
mode="nearest",
)

if self.depth_mask is not None:
self.depth_mask, _, _ = _resize_image(
self.depth_mask,
image_height=image_height,
image_width=image_width,
mode="nearest",
)
return scale

@classmethod
def collate(cls, batch):
"""
562 changes: 81 additions & 481 deletions pytorch3d/implicitron/dataset/json_index_dataset.py

Large diffs are not rendered by default.
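
The json_index_dataset.py diff is not rendered, but per the commit history the dataset now builds a bare FrameData from the annotations and delegates all blob I/O to BlobLoader. A sketch of that delegation pattern follows; the attribute and method names are assumptions for illustration, not the exact PR code.

# illustrative only: assumed shape of the delegation inside JsonIndexDataset
def __getitem__(self, index: int) -> FrameData:
    entry = self.frame_annots[index]["frame_annotation"]
    seq_annotation = self.seq_annots[entry.sequence_name]
    frame_data = FrameData(
        frame_number=_safe_as_tensor(entry.frame_number, torch.long),
        frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
        sequence_name=entry.sequence_name,
        sequence_category=seq_annotation.category,
    )
    # the whole seq_annotation is passed so the loader can reach the point cloud
    return self.blob_loader.load_(frame_data, entry, seq_annotation)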

135 changes: 134 additions & 1 deletion pytorch3d/implicitron/dataset/utils.py
@@ -5,7 +5,10 @@
# LICENSE file in the root directory of this source tree.


from typing import List, Optional
import warnings
from typing import List, Optional, Tuple

import numpy as np

import torch

@@ -52,3 +55,133 @@ def is_train_frame(
dtype=torch.bool,
device=device,
)


def _get_bbox_from_mask(
mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
masks_for_box = (mask > thr).astype(np.float32)
thr -= decrease_quant
if thr <= 0.0:
warnings.warn(
f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1
)

x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))

return x0, y0, x1 - x0, y1 - y0


def _crop_around_box(tensor, bbox, impath: str = ""):
# bbox is xyxy, where the upper bound is corrected with +1
bbox = _clamp_box_to_image_bounds_and_round(
bbox,
image_size_hw=tensor.shape[-2:],
)
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
return tensor


def _clamp_box_to_image_bounds_and_round(
bbox_xyxy: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
bbox_xyxy = bbox_xyxy.clone()
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
if not isinstance(bbox_xyxy, torch.LongTensor):
bbox_xyxy = bbox_xyxy.round().long()
return bbox_xyxy # pyre-ignore [7]


def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
wh = xyxy[2:] - xyxy[:2]
xywh = torch.cat([xyxy[:2], wh])
return xywh


def _get_clamp_bbox(
bbox: torch.Tensor,
box_crop_context: float = 0.0,
image_path: str = "",
) -> torch.Tensor:
# box_crop_context: rate of expansion for bbox
# returns possibly expanded bbox xyxy as float

bbox = bbox.clone() # do not edit bbox in place

# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.float()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c

if (bbox[2:] <= 1.0).any():
raise ValueError(
f"squashed image {image_path}!! The bounding box contains no pixels."
)

bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)

return bbox_xyxy


def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
assert bbox is not None
assert np.prod(orig_res) > 1e-8
# average ratio of dimensions
rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
return bbox * rel_size


def _bbox_xywh_to_xyxy(
xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
xyxy = xywh.clone()
if clamp_size is not None:
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
xyxy[2:] += xyxy[:2]
return xyxy


def _get_1d_bounds(arr) -> Tuple[int, int]:
nz = np.flatnonzero(arr)
return nz[0], nz[-1] + 1


def _resize_image(
image, image_height, image_width, mode="bilinear"
) -> Tuple[torch.Tensor, float, torch.Tensor]:

if type(image) == np.ndarray:
image = torch.from_numpy(image)

if image_height is None or image_width is None:
# skip the resizing
return image, 1.0, torch.ones_like(image[:1])
# takes numpy array or tensor, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
image[None],
scale_factor=minscale,
mode=mode,
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
imre_ = torch.zeros(image.shape[0], image_height, image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
mask = torch.zeros(1, image_height, image_width)
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
return imre_, minscale, mask
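
A quick sanity sketch of the _resize_image contract above (shapes chosen arbitrarily): the smaller of the two height/width ratios is used as the scale, and the output is zero-padded with a matching validity mask.

import numpy as np
import torch

from pytorch3d.implicitron.dataset.utils import _resize_image

image = np.random.rand(3, 300, 500).astype(np.float32)
resized, scale, mask = _resize_image(image, image_height=200, image_width=250)

assert scale == min(200 / 300, 250 / 500)  # 0.5: the smaller ratio wins
assert resized.shape == (3, 200, 250)
# content fills the top-left 150x250 region; the remaining rows are zero padding
assert torch.equal(mask[:, :150, :], torch.ones(1, 150, 250))
assert torch.equal(mask[:, 150:, :], torch.zeros(1, 50, 250))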
66 changes: 65 additions & 1 deletion tests/implicitron/test_bbox.py
@@ -9,11 +9,19 @@
import numpy as np

import torch
from pytorch3d.implicitron.dataset.json_index_dataset import (

from pytorch3d.implicitron.dataset.utils import (
_bbox_xywh_to_xyxy,
_bbox_xyxy_to_xywh,
_clamp_box_to_image_bounds_and_round,
_crop_around_box,
_get_1d_bounds,
_get_bbox_from_mask,
_get_clamp_bbox,
_rescale_bbox,
_resize_image,
)

from tests.common_testing import TestCaseMixin


@@ -76,3 +84,59 @@ def test_mask_to_bbox(self):
expected_bbox_xywh = [2, 1, 2, 1]
bbox_xywh = _get_bbox_from_mask(mask, 0.5)
self.assertClose(bbox_xywh, expected_bbox_xywh)

def test_crop_around_box(self):
bbox = torch.LongTensor([0, 1, 2, 3]) # (x_min, y_min, x_max, y_max)
image = torch.LongTensor(
[
[0, 0, 10, 20],
[10, 20, 5, 1],
[10, 20, 1, 1],
[5, 4, 0, 1],
]
)
cropped = _crop_around_box(image, bbox)
self.assertClose(cropped, image[1:3, 0:2])

def test_clamp_box_to_image_bounds_and_round(self):
bbox = torch.LongTensor([0, 1, 10, 12])
image_size = (5, 6)
expected_clamped_bbox = torch.LongTensor([0, 1, image_size[1], image_size[0]])
clamped_bbox = _clamp_box_to_image_bounds_and_round(bbox, image_size)
self.assertClose(clamped_bbox, expected_clamped_bbox)

def test_get_clamp_bbox(self):
bbox_xywh = torch.LongTensor([1, 1, 4, 5])
clamped_bbox_xyxy = _get_clamp_bbox(bbox_xywh, box_crop_context=2)
# the box grows by box_crop_context (size * (1 + c)) and is converted to xyxy
self.assertClose(clamped_bbox_xyxy, torch.Tensor([-3, -4, 9, 11]))

def test_rescale_bbox(self):
bbox = torch.Tensor([0.0, 1.0, 3.0, 4.0])
original_resolution = (4, 4)
new_resolution = (8, 8) # twice bigger
rescaled_bbox = _rescale_bbox(bbox, original_resolution, new_resolution)
self.assertClose(bbox * 2, rescaled_bbox)

def test_get_1d_bounds(self):
array = [0, 1, 2]
bounds = _get_1d_bounds(array)
# bounds of the nonzero region; the upper bound is exclusive
self.assertClose(bounds, [1, 3])

def test_resize_image(self):
image = np.random.rand(3, 300, 500) # rgb image 300x500
expected_shape = (150, 250)

resized_image, scale, mask_crop = _resize_image(
image, image_height=expected_shape[0], image_width=expected_shape[1]
)

original_shape = image.shape[-2:]
expected_scale = min(
expected_shape[0] / original_shape[0], expected_shape[1] / original_shape[1]
)

self.assertEqual(scale, expected_scale)
self.assertEqual(resized_image.shape[-2:], expected_shape)
self.assertEqual(mask_crop.shape[-2:], expected_shape)
225 changes: 225 additions & 0 deletions tests/implicitron/test_blob_loader.py
@@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gzip
import os
import unittest
from typing import List

import numpy as np
import torch

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.blob_loader import (
_load_16big_png_depth,
_load_1bit_png_mask,
_load_depth,
_load_depth_mask,
_load_image,
_load_mask,
_safe_as_tensor,
BlobLoader,
)
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer.cameras import PerspectiveCameras

from tests.common_testing import TestCaseMixin
from tests.implicitron.common_resources import get_skateboard_data


class TestBlobLoader(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)

category = "skateboard"
stack = contextlib.ExitStack()
self.dataset_root, self.path_manager = stack.enter_context(
get_skateboard_data()
)
self.addCleanup(stack.close)
self.image_height = 768
self.image_width = 512

self.blob_loader = BlobLoader(
image_height=self.image_height,
image_width=self.image_width,
dataset_root=self.dataset_root,
path_manager=self.path_manager,
)

# loading single frame annotation of dataset (see JsonIndexDataset._load_frames())
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
local_file = self.path_manager.get_local_path(frame_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
frame_annots_list = types.load_dataclass(
zipfile, List[types.FrameAnnotation]
)
self.frame_annotation = frame_annots_list[0]

sequence_annotations_file = os.path.join(
self.dataset_root, category, "sequence_annotations.jgz"
)
local_file = self.path_manager.get_local_path(sequence_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
seq_annots_list = types.load_dataclass(
zipfile, List[types.SequenceAnnotation]
)
seq_annots = {entry.sequence_name: entry for entry in seq_annots_list}
self.seq_annotation = seq_annots[self.frame_annotation.sequence_name]

point_cloud = self.seq_annotation.point_cloud
self.frame_data = FrameData(
frame_number=_safe_as_tensor(
self.frame_annotation.frame_number, torch.long
),
frame_timestamp=_safe_as_tensor(
self.frame_annotation.frame_timestamp, torch.float
),
sequence_name=self.frame_annotation.sequence_name,
sequence_category=self.seq_annotation.category,
camera_quality_score=_safe_as_tensor(
self.seq_annotation.viewpoint_quality_score, torch.float
),
point_cloud_quality_score=_safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
)

def test_BlobLoader_args(self):
# test that BlobLoader works with get_default_args
get_default_args(BlobLoader)

def test_fix_point_cloud_path(self):
"""Some files in Co3Dv2 have an accidental absolute path stored."""
original_path = "some_file_path"
modified_path = self.blob_loader._fix_point_cloud_path(original_path)
assert original_path in modified_path
assert self.blob_loader.dataset_root in modified_path

def test_load_(self):
bbox_xywh = None
self.frame_data.image_size_hw = _safe_as_tensor(
self.frame_annotation.image.size, torch.long
)
(
self.frame_data.fg_probability,
self.frame_data.mask_path,
self.frame_data.bbox_xywh,
) = self.blob_loader._load_fg_probability(self.frame_annotation, bbox_xywh)

assert self.frame_data.mask_path
assert torch.is_tensor(self.frame_data.fg_probability)
assert torch.is_tensor(self.frame_data.bbox_xywh)
# assert bboxes shape
self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))
(
self.frame_data.image_rgb,
self.frame_data.image_path,
) = self.blob_loader._load_images(
self.frame_annotation, self.frame_data.fg_probability
)
self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
assert self.frame_data.image_path

(
self.frame_data.depth_map,
depth_path,
self.frame_data.depth_mask,
) = self.blob_loader._load_mask_depth(
self.frame_annotation,
self.frame_data.fg_probability,
)
assert torch.is_tensor(self.frame_data.depth_map)
assert depth_path
assert torch.is_tensor(self.frame_data.depth_mask)

clamp_bbox_xyxy = None
if self.blob_loader.box_crop:
clamp_bbox_xyxy = self.frame_data.crop_by_bbox_(
self.blob_loader.box_crop_context
)

# assert image and mask shapes after resize
scale = self.frame_data.resize_frame_(self.image_height, self.image_width)
assert scale
self.assertEqual(
self.frame_data.mask_crop.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.image_rgb.shape,
torch.Size([3, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.mask_crop.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.fg_probability.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.depth_map.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.depth_mask.shape,
torch.Size([1, self.image_height, self.image_width]),
)

self.frame_data.camera = self.blob_loader._get_pytorch3d_camera(
self.frame_annotation,
scale,
clamp_bbox_xyxy,
)
self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)

def test_load_image(self):
path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
local_path = self.path_manager.get_local_path(path)
image = _load_image(local_path)
self.assertEqual(image.dtype, np.float32)
assert np.max(image) <= 1.0
assert np.min(image) >= 0.0

def test_load_mask(self):
path = os.path.join(self.dataset_root, self.frame_annotation.mask.path)
mask = _load_mask(path)
self.assertEqual(mask.dtype, np.float32)
assert np.max(mask) <= 1.0
assert np.min(mask) >= 0.0

def test_load_depth(self):
path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
depth_map = _load_depth(path, self.frame_annotation.depth.scale_adjustment)
self.assertEqual(depth_map.dtype, np.float32)
self.assertEqual(len(depth_map.shape), 3)

def test_load_16big_png_depth(self):
path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
depth_map = _load_16big_png_depth(path)
self.assertEqual(depth_map.dtype, np.float32)
self.assertEqual(len(depth_map.shape), 2)

def test_load_1bit_png_mask(self):
mask_path = os.path.join(
self.dataset_root, self.frame_annotation.depth.mask_path
)
mask = _load_1bit_png_mask(mask_path)
self.assertEqual(mask.dtype, np.float32)
self.assertEqual(len(mask.shape), 2)

def test_load_depth_mask(self):
mask_path = os.path.join(
self.dataset_root, self.frame_annotation.depth.mask_path
)
mask = _load_depth_mask(mask_path)
self.assertEqual(mask.dtype, np.float32)
self.assertEqual(len(mask.shape), 3)