
A modular and easy-to-use framework for Test-Time Training (TTT) and Test-Time Adaptation (TTA) in PyTorch, making your networks more generalizable with minimal effort ✨


torch-ttt

torch-ttt is a comprehensive PyTorch library for Test-Time Training (TTT) and Test-Time Adaptation (TTA) techniques. It helps make your neural networks more robust and generalizable under distribution shifts, corruptions, and out-of-distribution data, without requiring access to training data or labels at test time.

The library is designed to be modular, easy to integrate into existing PyTorch pipelines, and collaborative—we aim to include as many TTT methods as possible. If you've developed a TTT method, reach out to add yours!

Our webpage and documentation are available at torch-ttt.github.io.

torch-ttt is under active development. The API may change as we add new features and methods. Contributions are highly welcome! If you encounter any bugs or have feature requests, please submit an issue.

What is Test-Time Training?

Test-Time Training (TTT) is a paradigm where models adapt to test data during inference by optimizing self-supervised auxiliary objectives—without accessing training data or test labels. This helps models handle distribution shifts, corruptions, and out-of-distribution data.
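To make the paradigm concrete, the sketch below shows the rotation-prediction variant of TTT (Sun et al. 2020) in plain PyTorch rather than through torch-ttt. A shared encoder feeds a main classifier and an auxiliary head that predicts which of four rotations was applied; at test time only the label-free rotation loss is minimized, and the adapted encoder is then used for the main prediction. The toy architecture, the helper names `TwoHeadNet`, `rotation_batch` and `adapt_and_predict`, and the learning rate are illustrative assumptions. In practice the auxiliary head is trained jointly with the main task beforehand, which this sketch skips, and square inputs are assumed so that rotated copies can be batched together.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadNet(nn.Module):
    """Toy network (assumption): shared encoder plus a main head and a rotation head."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.main_head = nn.Linear(16, num_classes)  # task prediction
        self.rot_head = nn.Linear(16, 4)             # auxiliary: 0, 90, 180 or 270 degrees

    def forward(self, x):
        z = self.encoder(x)
        return self.main_head(z), self.rot_head(z)


def rotation_batch(x: torch.Tensor):
    """Rotate each square image by 0/90/180/270 degrees and build the rotation labels."""
    rotated = torch.cat([torch.rot90(x, k, dims=(2, 3)) for k in range(4)])
    labels = torch.arange(4).repeat_interleave(x.shape[0])
    return rotated, labels


def adapt_and_predict(model: TwoHeadNet, x: torch.Tensor, steps: int = 1, lr: float = 1e-3):
    """Adapt a copy of `model` on the unlabeled batch `x`, then predict with the copy."""
    adapted = copy.deepcopy(model)  # the original weights stay untouched
    # Only the shared encoder is updated in this sketch; both heads stay fixed.
    optimizer = torch.optim.SGD(adapted.encoder.parameters(), lr=lr)
    for _ in range(steps):
        images, rot_labels = rotation_batch(x)
        _, rot_logits = adapted(images)
        loss = F.cross_entropy(rot_logits, rot_labels)  # needs no ground-truth labels
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        logits, _ = adapted(x)
    return logits


torch.manual_seed(0)
net = TwoHeadNet()  # untrained toy weights, for shape illustration only
test_x = torch.randn(8, 3, 32, 32)
print(adapt_and_predict(net, test_x, steps=2).shape)  # torch.Size([8, 10])
```

Adapting a deep copy keeps each test batch independent of the previous ones; an online variant would instead keep updating the same model across batches.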

Test-Time Training Schema

torch-ttt implements TTT methods through a unified Engine abstraction. Each Engine encapsulates the complete adaptation logic of a specific TTT method, allowing you to:

  • Wrap any PyTorch model with a single line of code
  • Switch between different TTT methods seamlessly
  • Adapt models at inference time without modifying your existing pipeline

This modular design makes it easy to experiment with different adaptation strategies and find the best approach for your specific use case.
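To illustrate the wrapper pattern described above, here is a minimal sketch in plain PyTorch. It is not torch-ttt's actual `BaseEngine`: the class names `ToyEngine`, `EntropyEngine` and `ConsistencyEngine`, the `self_supervised_loss` hook, and the choice to adapt a fresh copy of the model on every evaluation-mode call are all hypothetical. The base class owns the shared adaptation loop and each method contributes only its label-free loss, which is one way to make methods interchangeable.

```python
import copy
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEngine(nn.Module, ABC):
    """Hypothetical engine: wraps a model and adapts a copy of it in eval-mode forward()."""

    def __init__(self, model: nn.Module, lr: float = 1e-3, num_steps: int = 1):
        super().__init__()
        self.model = model
        self.lr = lr
        self.num_steps = num_steps

    @abstractmethod
    def self_supervised_loss(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Method-specific objective computed without labels."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            return self.model(x)  # plain forward pass while training
        adapted = copy.deepcopy(self.model)  # the wrapped model is never modified
        optimizer = torch.optim.SGD(adapted.parameters(), lr=self.lr)
        with torch.enable_grad():  # adaptation needs gradients even under no_grad
            for _ in range(self.num_steps):
                loss = self.self_supervised_loss(adapted, x)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        return adapted(x)


class EntropyEngine(ToyEngine):
    """Toy method: minimize the mean prediction entropy on the test batch."""

    def self_supervised_loss(self, model, x):
        log_probs = F.log_softmax(model(x), dim=1)
        return -(log_probs.exp() * log_probs).sum(dim=1).mean()


class ConsistencyEngine(ToyEngine):
    """Toy method: penalize prediction changes under a horizontal flip."""

    def self_supervised_loss(self, model, x):
        return F.mse_loss(model(x), model(torch.flip(x, dims=[3])))


torch.manual_seed(0)
base = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
test_images = torch.randn(4, 3, 32, 32)
for engine_cls in (EntropyEngine, ConsistencyEngine):  # same interface, different method
    engine = engine_cls(base, lr=1e-2, num_steps=1).eval()
    print(engine_cls.__name__, engine(test_images).shape)
```

Keeping the loop in one base class means a new method only defines its loss; a full library would likely also need hooks for choosing which parameters to update, which this sketch leaves out.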

Key Features

torch-ttt provides a streamlined API through Engines—lightweight wrappers around your PyTorch models. All Engines follow the same interface, making them easy to use and highly modular.


You can add test-time adaptation with just a few lines of code, and switch between methods seamlessly. The library includes comprehensive tutorials and examples for every method, with efficient implementations suitable for both research and production deployment.

Check out the Quick Start guide or the API reference for more details.

Supported Methods

torch-ttt includes implementations of the following test-time training and adaptation methods:

| Method | Engine class | Paper | Description |
| --- | --- | --- | --- |
| TTT | TTTEngine | Sun et al. 2020 | Original test-time training with self-supervised rotation prediction |
| TTT++ | TTTPPEngine | Liu et al. 2021 | Improved TTT with a contrastive auxiliary task and online feature alignment |
| Masked TTT | MaskedTTTEngine | Gandelsman et al. 2022 | Self-supervised masked-autoencoder reconstruction for adaptation |
| TENT | TentEngine | Wang et al. 2021 | Entropy minimization for test-time adaptation |
| EATA | EataEngine | Niu et al. 2022 | Efficient anti-forgetting test-time adaptation |
| MEMO | MemoEngine | Zhang et al. 2022 | Marginal entropy minimization over augmentations of a single test point |
| ActMAD | ActMADEngine | Mirza et al. 2022 | Activation matching to align test-time feature distributions with training statistics |
| DeYO | DeYOEngine | Lee et al. 2024 | Entropy-based adaptation with sample selection via object-destructive transformations |
| IT3 | IT3Engine | Durasov et al. 2024 | Idempotent test-time training |
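Several methods in the table optimize some form of prediction entropy. As one example, the marginal-entropy objective behind MEMO averages the softmax outputs over several augmented copies of a single test input and minimizes the entropy of that average, so the model is pushed to be both confident and consistent across augmentations. The sketch below is a plain-PyTorch illustration, not the `MemoEngine` implementation; the toy augmentations (random horizontal flips plus small Gaussian noise, whereas the original paper uses AugMix), the single SGD step, and the helper names `augment`, `marginal_entropy` and `memo_predict` are assumptions.

```python
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def augment(x: torch.Tensor, n: int) -> torch.Tensor:
    """Toy augmentations (assumption): n randomly flipped, noisy copies of one (C, H, W) image."""
    copies = x.unsqueeze(0).repeat(n, 1, 1, 1)
    flip = torch.rand(n) < 0.5
    copies[flip] = torch.flip(copies[flip], dims=[3])
    return copies + 0.05 * torch.randn_like(copies)


def marginal_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Entropy of the class distribution averaged over the augmented copies (dim 0)."""
    log_probs = F.log_softmax(logits, dim=1)
    # Log of the averaged probabilities, computed stably with logsumexp.
    avg_log_probs = torch.logsumexp(log_probs, dim=0) - math.log(logits.shape[0])
    return -(avg_log_probs.exp() * avg_log_probs).sum()


def memo_predict(model: nn.Module, x: torch.Tensor, n_aug: int = 8, lr: float = 1e-3):
    """One marginal-entropy step on a copy of `model`, then predict on the clean input."""
    adapted = copy.deepcopy(model)
    optimizer = torch.optim.SGD(adapted.parameters(), lr=lr)
    loss = marginal_entropy(adapted(augment(x, n_aug)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        return adapted(x.unsqueeze(0))


# Two confident but disagreeing predictions give a high marginal entropy (about log 2).
print(marginal_entropy(torch.tensor([[20.0, 0.0], [0.0, 20.0]])))
```

Averaging probabilities before taking the entropy is what distinguishes this objective from averaging per-copy entropies: confident predictions that disagree across augmentations are still penalized.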

Want to see your method here? We welcome contributions!

Installation

Requirements: Python 3.10+, PyTorch 1.12+

Install from PyPI:

pip install torch-ttt

Install from source:

pip install git+https://github.com/nikitadurasov/torch-ttt.git

Quick Start

Here's a minimal example showing how to use torch-ttt to adapt a model at test time:

import torch
import torchvision.models as models
from torch_ttt.engine.tent_engine import TentEngine

# Load your pre-trained model (torchvision >= 0.13 weights API; `pretrained=True` is deprecated)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Wrap it with a TTT Engine (e.g., TENT for entropy minimization)
engine = TentEngine(
    model=model,
    optimization_parameters={
        "lr": 2e-3,
        "num_steps": 1
    }
)

# Switch to eval mode
engine.eval()

# At test time, adapt to new data
test_images = torch.randn(8, 3, 224, 224)

# The engine automatically adapts the model during forward
adapted_output = engine(test_images)
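The Quick Start delegates adaptation to `TentEngine`. For intuition about what test-time entropy minimization involves, here is an independent, from-scratch sketch of the TENT recipe (Wang et al. 2021) in plain PyTorch; it does not claim to mirror `TentEngine` internals. BatchNorm layers switch to test-batch statistics, every weight except the BatchNorm scale and shift is frozen, and one step minimizes the mean prediction entropy. The helper names `configure_tent`, `entropy` and `tent_step`, the SGD optimizer, the toy network, and the restriction to `nn.BatchNorm2d` are assumptions; the learning rate and single step simply mirror the Quick Start values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def configure_tent(model: nn.Module) -> list:
    """Prepare `model` TENT-style in place and return the parameters to optimize."""
    model.train()                # BatchNorm normalizes with batch statistics in train mode
    model.requires_grad_(False)  # freeze every parameter first
    params = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            # Always use the current test batch's statistics, never stored running averages.
            m.track_running_stats = False
            m.running_mean, m.running_var = None, None
            if m.affine:
                m.requires_grad_(True)  # unfreeze only the BN scale and shift
                params += [m.weight, m.bias]
    return params


def entropy(logits: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy of the softmax predictions over the batch."""
    log_probs = F.log_softmax(logits, dim=1)
    return -(log_probs.exp() * log_probs).sum(dim=1).mean()


def tent_step(model: nn.Module, optimizer: torch.optim.Optimizer, x: torch.Tensor) -> torch.Tensor:
    """One adaptation step on an unlabeled batch; returns the logits computed before the update."""
    logits = model(x)
    loss = entropy(logits)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return logits.detach()


torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
optimizer = torch.optim.SGD(configure_tent(net), lr=2e-3, momentum=0.9)
test_images = torch.randn(8, 3, 32, 32)
predictions = tent_step(net, optimizer, test_images)  # a single step (num_steps = 1)
print(predictions.shape)  # torch.Size([8, 10])
```

The same helpers should also apply to the torchvision ResNet-50 above, since its normalization layers are `nn.BatchNorm2d`.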

Documentation

Comprehensive documentation, including the Quick Start guide and the API reference, is available at torch-ttt.github.io.

Contributing

We welcome contributions! To add a new TTT method, report bugs, or improve documentation:

  1. Fork the repository
  2. Create a new engine inheriting from BaseEngine
  3. Add tests and documentation
  4. Submit a pull request

See GitHub Issues for bug reports and feature requests.

Citation

If you use torch-ttt in your research, please cite:

@software{durasov2024torchttt,
  author    = {Durasov, Nikita},
  title     = {torch-ttt: A Unified PyTorch Library for Test-Time Training},
  year      = {2024},
  doi       = {10.5281/zenodo.17620711},
  url       = {https://github.com/nikitadurasov/torch-ttt},
}

Also cite the original papers of the methods you use. See our Papers page.

License

MIT License - see the LICENSE file for details.
