Open
Description
Thanks for your great library 👍
🐛 Bugs / Unexpected behaviors
I tried to write a custom PointsRenderer-style class to render a colored point cloud with an FoV orthographic camera (FoVOrthographicCameras). However, the tensor in the dists attribute computed by the rasterizer only contains -1 or values very close to zero (at most about 1e-4, see the output below). The zbuf attribute, on the other hand, looks correct.
The z_near and z_far arguments also don't seem to have any effect.
Instructions To Reproduce the Issue:
My specs:
pytorch3d==0.6.1
torch==1.10.0+cu111
torchaudio==0.10.0
torchvision==0.11.1+cu111
In the example below, the camera is located at (0, 2, 0) and looks towards -Y. There is a red plane at Y=-1 and a green plane at Y=1, so the green plane is 1 unit away from the camera and the red plane is 3 units away.
Even though I explicitly ask for z_near=2.0, I still see the green plane in the rendered image.
#!/usr/bin/python3
"""Image Renderer."""
import matplotlib.pyplot as plt
import numpy as np
import pytorch3d.structures as torch3d
import torch
from pytorch3d.renderer import NormWeightedCompositor
from pytorch3d.renderer import PointsRasterizationSettings
from pytorch3d.renderer import PointsRasterizer
from pytorch3d.renderer.cameras import FoVOrthographicCameras
from torch import nn
# Device used for every tensor in this script: the first CUDA GPU when one is
# available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def generate_colored_planes(n_pts: int) -> torch3d.Pointclouds:
    """Build a point cloud made of a red plane at y=-1 and a green plane at y=1.

    All coordinates are first drawn uniformly from [-1, 1); the y coordinate
    is then overwritten so that the first ``n_pts // 2`` points lie on the
    plane y=-1 (colored red) and the remaining points lie on the plane y=1
    (colored green).

    Args:
        n_pts: Total number of points to generate.

    Returns:
        A single-cloud ``Pointclouds`` whose features are the RGB colors,
        moved to ``DEVICE``.
    """
    points = torch.rand(n_pts, 3) * 2.0 - 1.0
    rgb = torch.zeros(n_pts, 3)
    split = n_pts // 2

    # Red floor at Y=-1: only the red channel is set.
    points[:split, 1] = -1
    rgb[:split, 0] = 1.0

    # Green ceiling at Y=1: only the green channel is set.
    points[split:, 1] = 1
    rgb[split:, 1] = 1.0

    cloud = torch3d.Pointclouds(points=[points], features=[rgb])
    return cloud.to(device=DEVICE)
class ImageRenderer(nn.Module):
    """Render point clouds by rasterizing them and compositing their features.

    The rasterizer is called on the point clouds to produce fragments; the
    fragment distances are turned into blending weights, and the compositor
    combines the packed point features using those weights.

    Args:
        rasterizer: Callable returning fragments with ``idx``, ``dists`` and
            ``zbuf`` attributes; must expose ``raster_settings.radius``.
        compositor: Callable taking ``(idx, weights, features, **kwargs)``.
    """

    def __init__(self, rasterizer, compositor):
        super().__init__()
        self.rasterizer = rasterizer
        self.compositor = compositor

    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
        """Rasterize ``point_clouds`` and return the composited images.

        Extra keyword arguments are forwarded to both the rasterizer and the
        compositor. The compositor output is permuted with ``(0, 2, 3, 1)``
        so that its second axis (presumably the channel axis, to be
        confirmed against the compositor) ends up last.
        """
        fragments = self.rasterizer(point_clouds, **kwargs)
        radius = self.rasterizer.raster_settings.radius

        # Debug output: range of the valid (non-negative) entries of the
        # distance and depth buffers.
        print(f"{torch.min(fragments.dists[fragments.dists >=0])=}")
        print(f"{torch.max(fragments.dists[fragments.dists >=0])=}")
        print(f"{torch.min(fragments.zbuf[fragments.zbuf >=0])=}")
        print(f"{torch.max(fragments.zbuf[fragments.zbuf >=0])=}")

        # The weights are a function of the rasterized distance of each point
        # relative to the squared radius. This is only one possible choice:
        # the weights could e.g. be predicted instead.
        weights = 1 - fragments.dists.permute(0, 3, 1, 2) / (radius * radius)

        point_indices = fragments.idx.long().permute(0, 3, 1, 2)
        features = point_clouds.features_packed().permute(1, 0)
        composited = self.compositor(point_indices, weights, features, **kwargs)

        # Move the second axis to the end so the image layout is channel-last.
        return composited.permute(0, 2, 3, 1)
def render_pcd(pcd: torch3d.Pointclouds,
               rotations: torch.Tensor,
               translations: torch.Tensor,
               z_near: float,
               z_far: float,
               focal_length: int = 100,
               width: int = 256):
    """Render a point cloud to a square image with an FoV orthographic camera.

    Args:
        pcd: Point cloud to render.
        rotations: Passed to the camera as ``R``.
        translations: Passed to the camera as ``T``.
        z_near: Passed to the camera as ``znear``.
        z_far: Passed to the camera as ``zfar``.
        focal_length: Together with ``width``, sets the camera's x/y scale
            to ``2 * focal_length / width``.
        width: Side length of the rendered image, in pixels.

    Returns:
        The ``ImageRenderer`` output with all size-1 dimensions squeezed out.
    """
    # The same scale is applied along x and y; z is left unscaled.
    xy_scale = 2 * focal_length / width
    cameras = FoVOrthographicCameras(
        znear=z_near,
        zfar=z_far,
        R=rotations,
        T=translations,
        device=DEVICE,
        scale_xyz=((xy_scale, xy_scale, 1.0),),
    )

    rasterizer = PointsRasterizer(
        cameras=cameras,
        raster_settings=PointsRasterizationSettings(
            image_size=(width, width),
            radius=0.01,
            points_per_pixel=3,
        ),
    )

    renderer = ImageRenderer(
        rasterizer=rasterizer,
        compositor=NormWeightedCompositor(),
    )
    rendered = renderer(pcd)
    return torch.squeeze(rendered)
# Scene: 100000 points split between the red plane (Y=-1) and the green
# plane (Y=1).
cube_pcd = generate_colored_planes(100000)

# Top view: looking towards -Y from (0, 2, 0).
rotations = torch.Tensor(
    [[[1., 0., 0.],
      [0., 0., -1.],
      [0., 1., 0.]]]
).to(device=DEVICE)
translations = torch.Tensor([[0., 0., 2.]]).to(device=DEVICE)

torch_img = render_pcd(
    cube_pcd,
    rotations,
    translations,
    z_near=2.0,  # expected to clip the green plane, which is 1 unit away
    z_far=10,
)

# Scale the rendered image to 0..255 and convert it to 8-bit for display.
scaled = torch_img.cpu().numpy() * 255
img = scaled.astype(np.uint8)

plt.ioff()
plt.figure(figsize=(10, 10))
plt.imshow(img)
plt.show()
Here's the output:
torch.min(fragments.dists[fragments.dists >=0])=tensor(9.8953e-10, device='cuda:0')
torch.max(fragments.dists[fragments.dists >=0])=tensor(9.9999e-05, device='cuda:0')
torch.min(fragments.zbuf[fragments.zbuf >=0])=tensor(1., device='cuda:0')
torch.max(fragments.zbuf[fragments.zbuf >=0])=tensor(3., device='cuda:0')