Description
Is your feature request related to a problem? Please describe.
PSFs are known to vary across an image. It would be very useful to have a PSF model that naturally depends on position.
Describe the solution you'd like
I have heard that one can express the PSF in a basis set, convolve the image with each basis element, and then combine the results using position-dependent weights to obtain an effectively spatially varying PSF convolution.
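One way to see why this works is to write the PSF at position $\mathbf{x}$ as a weighted sum of $K$ fixed basis images $\phi_k$. This is the form a PSFEx model takes, with polynomial weights. Linearity then lets the weights be pulled out of the convolution sum:

$$P_{\mathbf{x}}(\mathbf{u}) = \sum_{k=1}^{K} c_k(\mathbf{x})\,\phi_k(\mathbf{u})$$

$$\sum_{\mathbf{u}} I(\mathbf{x}-\mathbf{u})\,P_{\mathbf{x}}(\mathbf{u}) = \sum_{k=1}^{K} c_k(\mathbf{x}) \sum_{\mathbf{u}} I(\mathbf{x}-\mathbf{u})\,\phi_k(\mathbf{u}) = \sum_{k=1}^{K} c_k(\mathbf{x})\,(I * \phi_k)(\mathbf{x})$$

So the spatially varying convolution costs $K$ ordinary convolutions plus a per-pixel weighted sum. For PSFEx, $c_k(x, y) = \tilde{x}^{n_x}\,\tilde{y}^{n_y}$ with $\tilde{x} = (x - x_0)/s_x$ and $\tilde{y} = (y - y_0)/s_y$. This version evaluates the weights at the output pixel. If the PSF should instead be indexed by the source position, the weights move inside the convolution, $\sum_{k} \big((c_k I) * \phi_k\big)(\mathbf{x})$, which has the same cost.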
Describe alternatives you've considered
Using caskade functional parameter relations, it is possible to make a PSF depend on position, but a new PSF model object is needed for each position.
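To make the trade-off concrete, here is a small NumPy sketch. It assumes a hypothetical PSF model $P_{\mathbf{x}} = \sum_k c_k(\mathbf{x})\,\phi_k$ with random basis images and weights `1, x, y` (none of these names or values come from a real PSF file). It checks that the per-position approach, which builds a separate PSF for every output pixel, gives the same result as K stationary convolutions followed by a per-pixel weighted sum.

```python
import numpy as np


def convolve_same(img, kern):
    """Direct 2-D convolution with zero padding; output has img's shape (odd kernel sizes)."""
    kh, kw = kern.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(img, ((ph, ph), (pw, pw)))
    flipped = kern[::-1, ::-1]  # convolution = correlation with the flipped kernel
    H, W = img.shape
    out = np.empty_like(img)
    for i in range(H):
        for j in range(W):
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * flipped)
    return out


def brute_force(img, basis, coeffs):
    """Per-position approach: build the PSF at every output pixel, then apply it there."""
    K, kh, kw = basis.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(img, ((ph, ph), (pw, pw)))
    H, W = img.shape
    out = np.empty_like(img)
    for i in range(H):
        for j in range(W):
            psf = np.tensordot(coeffs[:, i, j], basis, axes=1)  # (kh, kw) PSF at (i, j)
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * psf[::-1, ::-1])
    return out


def basis_trick(img, basis, coeffs):
    """Basis approach: K stationary convolutions, then a per-pixel weighted sum."""
    filtered = np.stack([convolve_same(img, b) for b in basis])  # (K, H, W)
    return np.sum(coeffs * filtered, axis=0)


rng = np.random.default_rng(0)
H, W, K = 12, 10, 3
img = rng.normal(size=(H, W))
basis = rng.normal(size=(K, 5, 5))            # hypothetical basis images
yy, xx = np.mgrid[0:H, 0:W] / 10.0            # row / column coordinates
coeffs = np.stack([np.ones_like(xx), xx, yy])  # c_k(x, y) = 1, x, y
print(np.allclose(brute_force(img, basis, coeffs), basis_trick(img, basis, coeffs)))  # True
```

The brute-force loop scales with the number of pixels times the kernel size for every PSF evaluation, whereas the basis version only needs K convolutions, which can use FFTs or GPU kernels.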
Additional context
The following code, written by Nicolas Payot, achieves this, though I haven't examined it in detail:
```python
import galsim
import galsim.des  # DES_PSFEx lives in the galsim.des submodule
import torch
import torch.nn as nn
import torch.nn.functional as F

# `_poly_pairs` and `_np_to_tensor` are module-level helpers that were not
# included in this snippet.


class PSFex(nn.Module):
    """
    Convolve an image with a spatially-varying PSFEx model whose basis
    images are the coefficients of the polynomial terms in (x, y).

    Instantiate once, then call many times:
        out = psf_layer(img_crop, x0_pix, y0_pix)

    Parameters
    ----------
    psf_path : str
    device : 'cuda' | 'cpu'
    dtype : torch.dtype (float32 or float64)
    """

    def __init__(self, psf_path, device='cuda', dtype=torch.float32):
        super().__init__()
        dev = torch.device(device)
        self.des_psfex = galsim.des.DES_PSFEx(psf_path)

        # 1. polynomial metadata
        self.fit_order = int(self.des_psfex.fit_order)
        self.x0, self.y0 = float(self.des_psfex.x_zero), float(self.des_psfex.y_zero)
        self.xs, self.ys = float(self.des_psfex.x_scale), float(self.des_psfex.y_scale)

        # 2. exponent list (K == number of basis images)
        pairs = _poly_pairs(self.fit_order)
        if len(pairs) != len(self.des_psfex.basis):
            raise ValueError(
                f"PSFEx file has {len(self.des_psfex.basis)} basis images but "
                f"{len(pairs)} polynomial terms for fit_order={self.fit_order}."
            )
        self.register_buffer('poly_pairs',
                             _np_to_tensor(pairs, dtype=torch.int16, device=dev))  # (K,2)

        # 3. basis cube, flipped because F.conv2d computes a cross-correlation.
        #    Assumes the basis is sampled at the image pixel scale.
        basis_np = self.des_psfex.basis  # (K,pH,pW), big-endian FITS data
        basis = _np_to_tensor(basis_np, dtype=dtype, device=dev)
        self.register_buffer('basis',
                             basis.flip(-1, -2).unsqueeze(1))  # (K,1,pH,pW)
        pH, pW = basis.shape[-2:]
        # F.conv2d expects padding as (padH, padW); this keeps the output the
        # same size as the input for odd kernel dimensions.
        self.pad = (pH // 2, pW // 2)

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------
    def forward(self, image, x0_pix: int, y0_pix: int):
        """
        Parameters
        ----------
        image : (H,W) or (B,1,H,W) tensor
        x0_pix : column index (in full frame) of image[0,0]
        y0_pix : row index (in full frame) of image[0,0]
        """
        original_shape = image.shape
        if image.ndim == 2:
            image = image.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
        elif not (image.ndim == 4 and image.shape[1] == 1):
            raise ValueError("image must be (H,W) or (B,1,H,W)")
        B, _, H, W = image.shape
        dtype, dev = image.dtype, image.device

        # ---- 1. scaled coordinate grids for this crop
        # PSFEx fits are usually made against SExtractor positions, which are
        # 1-based (FITS convention); check that x0_pix/y0_pix match.
        yy, xx = torch.meshgrid(
            torch.arange(H, device=dev, dtype=dtype) + y0_pix,
            torch.arange(W, device=dev, dtype=dtype) + x0_pix,
            indexing='ij'
        )
        xt = (xx - self.x0) / self.xs
        yt = (yy - self.y0) / self.ys

        # ---- 2. pre-compute powers xt^k, yt^k up to fit_order
        max_d = self.fit_order
        xt_p = [torch.ones_like(xt)]
        yt_p = [torch.ones_like(yt)]
        for _ in range(max_d):
            xt_p.append(xt_p[-1] * xt)
            yt_p.append(yt_p[-1] * yt)
        xt_stack = torch.stack(xt_p)  # (d+1, H, W)
        yt_stack = torch.stack(yt_p)  # (d+1, H, W)

        # ---- 3. coefficient maps c_k(x,y) = xt^nx * yt^ny
        nx = self.poly_pairs[:, 0].long()  # (K,)
        ny = self.poly_pairs[:, 1].long()  # (K,)
        coef_maps = xt_stack[nx] * yt_stack[ny]  # (K, H, W)

        # ---- 4. stationary convolutions and weighted sum
        #      (weights are evaluated at the output pixel)
        filtered = F.conv2d(image, self.basis, padding=self.pad)  # (B,K,H,W)
        out = torch.sum(filtered * coef_maps.unsqueeze(0), dim=1, keepdim=True)  # (B,1,H,W)
        return out.squeeze(0).squeeze(0) if original_shape == (H, W) else out
```