
[model] methods to learn kernel banks #7

@ryanhammonds

Description

This package currently solves for autoregressive kernels for each input image. Instead, we could explore sets of pre-defined, data-driven kernels. This would scale better at inference time: kernels would not need to be learned for each test image; instead, we would measure how well each pre-defined kernel performs. This requires extra functionality to generate sets of candidate kernels from a training set, and to compute features based on how well each of these kernels reconstructs an input.

Options for kernels:

  1. Learn kernels for each training image. Then, for each image class, reduce the kernels to a set that explains the most variance.
  2. Use a set of pre-defined kernels, e.g. Gabor filters
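A minimal sketch of option 2: assembling a bank of Gabor kernels. The `gabor_kernel` / `gabor_bank` helpers and the parameter grid (orientations, wavelengths) below are illustrative choices, not part of the package.

```python
import math
import torch

def gabor_kernel(size, theta, lam, sigma=2.0, gamma=0.5, psi=0.0):
    """Single real Gabor kernel of shape (size, size)."""
    half = size // 2
    ys, xs = torch.meshgrid(
        torch.arange(-half, half + 1, dtype=torch.float32),
        torch.arange(-half, half + 1, dtype=torch.float32),
        indexing="ij",
    )
    # rotate coordinates by theta
    x_r = xs * math.cos(theta) + ys * math.sin(theta)
    y_r = -xs * math.sin(theta) + ys * math.cos(theta)
    # Gaussian envelope modulated by a cosine carrier
    env = torch.exp(-(x_r**2 + (gamma * y_r) ** 2) / (2 * sigma**2))
    return env * torch.cos(2 * math.pi * x_r / lam + psi)

def gabor_bank(size=7, n_orient=8, wavelengths=(3.0, 5.0)):
    """Stack of Gabor kernels, shape (n_orient * len(wavelengths), 1, size, size)."""
    kernels = [
        gabor_kernel(size, k * math.pi / n_orient, lam)
        for lam in wavelengths
        for k in range(n_orient)
    ]
    return torch.stack(kernels).unsqueeze(1)

bank = gabor_bank()
print(bank.shape)  # torch.Size([16, 1, 7, 7])
```

The resulting tensor has the same layout as `conv2d` weights, so it can drop in wherever the learned `self.filters` are used below.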

Downstream features:

  1. Convolve the kernel set with input images
  2. Extract features based on how well each kernel reconstructs the image
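The two steps above could be sketched as a single function: convolve a (frozen) kernel bank with a batch of images and score each kernel by its reconstruction error. The `reconstruction_features` helper and the MSE choice are assumptions for illustration; any residual statistic would work.

```python
import torch
import torch.nn.functional as F

def reconstruction_features(images, bank, pad=1):
    """Per-kernel reconstruction error features.

    images: (N, 1, H, W); bank: (K, 1, h, w) with odd h, w.
    Returns an (N, K) matrix of MSEs (lower = better reconstruction).
    """
    # zero the center tap so each kernel predicts a pixel from its neighbors
    mask = torch.ones_like(bank, dtype=torch.bool)
    mask[:, :, bank.shape[2] // 2, bank.shape[3] // 2] = False
    pred = F.conv2d(images, bank * mask, padding=pad)  # (N, K, H, W)
    return ((images - pred) ** 2).mean(dim=(2, 3))

images = torch.randn(4, 1, 32, 32)
bank = torch.randn(8, 1, 3, 3)
feats = reconstruction_features(images, bank)
print(feats.shape)  # torch.Size([4, 8])
```

Since the bank is fixed, this requires no optimization at inference time; the (N, K) matrix feeds directly into any downstream classifier.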

Examples

Learning a set of kernels:

Below, replace torch.randn with something more informative/fixed. Learning could also be done on a per-class basis.

import torch
from torch import nn
import torch.nn.functional as F

class Model(nn.Module):
    N_IMAGES = 100
    W = 100
    H = 100
    N_FILTERS = 20
    W_FILTER = 3
    H_FILTER = 3
    PAD = 1
    def __init__(self):
        super().__init__()
        self.filters = nn.Parameter(torch.randn(self.N_FILTERS, 1, self.W_FILTER, self.H_FILTER))
        # register the mask as a buffer so it follows the model across devices
        mask = torch.ones_like(self.filters, dtype=torch.bool)
        mask[:, :, self.W_FILTER // 2, self.H_FILTER // 2] = False  # zero the center tap
        self.register_buffer("mask", mask)
    def forward(self, X):
        return F.conv2d(X, self.filters * self.mask, padding=self.PAD)

# model, loss, optimizer
model = Model()
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# random images
images = torch.randn(model.N_IMAGES, 1, model.W, model.H)

# train
for iepoch in range(10):
    images_pred = model(images)
    loss = loss_fn(images_pred, images)
    loss.backward()
    opt.step()
    opt.zero_grad()
    print(float(loss))
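One way to implement the reduction in option 1 — collapsing the learned kernels to a set that explains the most variance — is PCA over the flattened kernel stack, keeping the principal directions as the reduced bank. The `reduce_kernels` helper and the 95% threshold below are illustrative assumptions, not the package's method.

```python
import torch

def reduce_kernels(filters, var_explained=0.95):
    """Reduce a stack of kernels (K, 1, h, w) to the principal components
    that explain `var_explained` of the variance across kernels.

    Returns a reduced bank of shape (k, 1, h, w), with k <= K.
    """
    K, _, h, w = filters.shape
    flat = filters.reshape(K, -1)
    flat = flat - flat.mean(dim=0, keepdim=True)  # center across kernels
    # rows of Vh are the principal directions in kernel space
    _, S, Vh = torch.linalg.svd(flat, full_matrices=False)
    ratio = torch.cumsum(S**2, dim=0) / (S**2).sum()
    k = int(torch.searchsorted(ratio, torch.tensor(var_explained)).item()) + 1
    return Vh[:k].reshape(k, 1, h, w)

filters = torch.randn(20, 1, 3, 3)
reduced = reduce_kernels(filters)
print(reduced.shape[0] <= filters.shape[0])  # True
```

Run per class (on that class's learned kernels), this yields one compact bank per image class.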

End-to-End

One drawback of the approach above is that it requires optimizing two models: one to extract features and one to classify. These could be combined into a single model and optimized jointly:

  1. Define a set of candidate kernels
  2. Compute residuals between true and predicted
  3. Compute stats/features based on the residuals
  4. Use features to classify

import torch
from torch import nn
import torch.nn.functional as F

N_IMAGES  = 100
W         = 100
H         = 100
N_FILTERS = 64
W_FILTER  = 3
H_FILTER  = 3
PAD       = 1
N_CLASSES = 10

class ARModel(nn.Module):
    
    def __init__(self):
        super().__init__()

        # AR filters
        self.filters = nn.Parameter(
            torch.randn(N_FILTERS, 1, W_FILTER, H_FILTER) # replace
        )

        # Mask: zero the center tap, based on kernel size;
        # registered as a buffer so it follows the model across devices
        mask = torch.ones(N_FILTERS, 1, W_FILTER, H_FILTER, dtype=torch.bool)
        mask[:, :, W_FILTER // 2, H_FILTER // 2] = False
        self.register_buffer("mask", mask)

        # Residual metrics: global stats plus local block means
        self.metrics = lambda resid: torch.hstack((
            # global features
            resid.abs().mean(dim=(2, 3)),  # MAE
            (resid**2).mean(dim=(2, 3)),   # MSE
            resid.var(dim=(2, 3)),         # variance
            # local features: mean residual per 10x10 block
            F.avg_pool2d(resid, 10).flatten(start_dim=1),
            # todo: add more
        ))

        # Classifier; input size = 3 global stats + (W // 10) * (H // 10) blocks per filter
        n_features = N_FILTERS * (3 + (W // 10) * (H // 10))
        self.clf = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, N_CLASSES)
        )

    def forward(self, X):
        residuals = X - F.conv2d(X, self.filters * self.mask, padding=PAD)
        features = self.metrics(residuals)
        return self.clf(features)

# model, loss, optimizer
model = ARModel()
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# simulate random images and labels
images = torch.randn(N_IMAGES, 1, W, H)
y = torch.randint(0, N_CLASSES, size=(N_IMAGES,))

# train
for iepoch in range(10):
    logits = model(images)
    loss = loss_fn(logits, y)
    loss.backward()
    opt.step()
    opt.zero_grad()
    print(float(loss))

Related

Filter banks:

Generative AR models:

Gabor Basis
