Spatially varying PSF #269

@ConnorStoneAstro

Description

Is your feature request related to a problem? Please describe.

PSFs are known to vary across an image. It would be very useful to have a PSF model that naturally depends on position.

Describe the solution you'd like

I have heard that one can express the PSF in a basis set, convolve the image with each basis element separately, and then combine the results with position-dependent weights to get an effectively spatially varying PSF convolution.
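A minimal numpy sketch of why that works (illustrative names, not from any particular library): if the local PSF at output pixel p is a weighted sum of fixed basis kernels, PSF_p = Σ_k c_k(p) B_k, then K stationary convolutions followed by a pixelwise weighted sum reproduce the spatially varying convolution exactly.

```python
import numpy as np

def conv_same(img, ker):
    """'Same'-size 2-D convolution (kernel flipped), zero padding."""
    kh, kw = ker.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(img, ((ph, ph), (pw, pw)))
    fk = ker[::-1, ::-1]
    out = np.zeros_like(img)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * fk)
    return out

rng = np.random.default_rng(0)
img = rng.normal(size=(8, 8))
B = rng.normal(size=(2, 3, 3))            # two basis kernels
yy, xx = np.mgrid[0:8, 0:8] / 7.0
coeffs = np.stack([1.0 - xx, xx])         # weights varying across the image

# Basis method: K stationary convolutions, then a pixelwise weighted sum.
basis_out = sum(coeffs[k] * conv_same(img, B[k]) for k in range(2))

# Direct method: build the local PSF at each output pixel and apply it there.
direct = np.zeros_like(img)
for i in range(8):
    for j in range(8):
        local_psf = coeffs[0, i, j] * B[0] + coeffs[1, i, j] * B[1]
        direct[i, j] = conv_same(img, local_psf)[i, j]

assert np.allclose(basis_out, direct)
```

The two methods agree to floating-point precision; the basis method just replaces a per-pixel kernel build with K fixed convolutions, which is what makes it fast on a GPU.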

Describe alternatives you've considered

Using caskade's functional parameter relations it is possible to make a PSF depend on position, but a new PSF model object is needed for each position.

Additional context

The following code, written by Nicolas Payot, achieves this, though I haven't examined it in detail.

import galsim.des
import torch
import torch.nn as nn
import torch.nn.functional as F


class PSFex(nn.Module):
    """
    Convolve an image with a spatially-varying PSFEx model whose basis
    images are the monomials' coefficients.

    Instantiate once, then call many times:

        out = psf_layer(img_crop, x0_pix, y0_pix)

    Parameters
    ----------
    psf_path  : str
    device    : 'cuda' | 'cpu'
    dtype     : torch.dtype   (float32 or float64)
    """

    def __init__(self, psf_path, device='cuda', dtype=torch.float32):
        super().__init__()
        dev = torch.device(device)

        self.des_psfex = galsim.des.DES_PSFEx(psf_path)

        # 1. polynomial metadata
        self.fit_order = int(self.des_psfex.fit_order)
        self.x0, self.y0 = float(self.des_psfex.x_zero), float(self.des_psfex.y_zero)
        self.xs, self.ys = float(self.des_psfex.x_scale), float(self.des_psfex.y_scale)

        # 2. exponent list  (K == #basis images)
        pairs = _poly_pairs(self.fit_order)
        if len(pairs) != len(self.des_psfex.basis):
            raise ValueError(
                f"PSFEx file has {len(self.des_psfex.basis)} basis images but "
                f"{len(pairs)} polynomial terms for fit_order={self.fit_order}."
            )
        self.register_buffer('poly_pairs',
                             _np_to_tensor(pairs, dtype=torch.int16, device=dev))  # (K,2)

        # 3. basis cube  (flip for true convolution)
        basis_np = self.des_psfex.basis  # (K,pH,pW)  big-endian FITS
        basis = _np_to_tensor(basis_np, dtype=dtype, device=dev)
        self.register_buffer('basis',
                             basis.flip(-1, -2).unsqueeze(1))  # (K,1,pH,pW)
        pH, pW = basis.shape[-2:]
        self.pad = (pH // 2, pW // 2)  # 'same' padding; F.conv2d expects (padH, padW)

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------
    def forward(self, image, x0_pix: int, y0_pix: int):
        """
        Parameters
        ----------
        image   : (H,W)  or  (B,1,H,W) tensor
        x0_pix  : column index (in full frame) of image[0,0]
        y0_pix  : row    index (in full frame) of image[0,0]
        """
        original_shape = image.shape
        if image.ndim == 2:
            image = image.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
        elif not (image.ndim == 4 and image.shape[1] == 1):
            raise ValueError("image must be (H,W) or (B,1,H,W)")

        B, _, H, W = image.shape
        dtype, dev = image.dtype, image.device

        # ---- 1. scaled coordinate grids for this crop
        yy, xx = torch.meshgrid(
            torch.arange(H, device=dev, dtype=dtype) + y0_pix,
            torch.arange(W, device=dev, dtype=dtype) + x0_pix,
            indexing='ij'
        )
        xt = (xx - self.x0) / self.xs
        yt = (yy - self.y0) / self.ys

        # ---- 2. pre-compute powers  xt^k , yt^k  up to fit_order
        max_d = self.fit_order
        xt_p = [torch.ones_like(xt)]
        yt_p = [torch.ones_like(yt)]
        for _ in range(max_d):
            xt_p.append(xt_p[-1] * xt)
            yt_p.append(yt_p[-1] * yt)

        xt_stack = torch.stack(xt_p)  # (d+1, H, W)
        yt_stack = torch.stack(yt_p)  # (d+1, H, W)

        # ---- 3. coefficient maps  c_k(x,y)  =  xt^nx * yt^ny
        nx = self.poly_pairs[:, 0].long()  # (K,)
        ny = self.poly_pairs[:, 1].long()  # (K,)
        coef_maps = xt_stack[nx] * yt_stack[ny]  # (K, H, W)

        # ---- 4. stationary convolutions and weighted sum
        filtered = F.conv2d(image, self.basis, padding=self.pad)  # (B,K,H,W)
        out = torch.sum(filtered * coef_maps.unsqueeze(0), dim=1, keepdim=True)

        return out.squeeze(0).squeeze(0) if original_shape == (H, W) else out
