[XPU] Implemented 8bit optimizers in triton #1692

Open · wants to merge 10 commits into base: main

Conversation

@Egor-Krivov Egor-Krivov (Contributor) commented Jul 1, 2025

Implemented 8-bit optimizers in triton for use on XPU devices.

Depends on the interface from #1706.

Tested with `BNB_TEST_DEVICE="xpu" pytest --show-capture=no -q tests/test_optim.py::test_optimizer8bit`

Benchmarked on essentially the same test, getting better performance than the torch optimizer.
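
For context, a minimal usage sketch (assuming a PyTorch build with XPU support): the optimizer class and `block_wise` flag are the existing bitsandbytes API, and with this PR the 8-bit step should dispatch to the triton kernels when the parameters live on XPU.

```python
# Minimal usage sketch, assuming a PyTorch build with XPU support.
import torch
import bitsandbytes as bnb

# Parameter placed on the XPU device.
p = torch.nn.Parameter(torch.randn(4096, 4096, device="xpu", dtype=torch.bfloat16))

# Existing bitsandbytes 8-bit blockwise Adam; on XPU the step should go
# through the triton kernels added in this PR.
opt = bnb.optim.Adam8bit([p], lr=1e-3, block_wise=True)

p.grad = torch.randn_like(p) * 0.01  # dummy gradient
opt.step()
opt.zero_grad()
```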

This PR contains 3 implementations:

  1. A pure torch implementation that materializes intermediate tensors during quantization, hence large memory usage. Can be used for testing purposes (a rough sketch follows this list).
  2. A combination of torch and triton kernels for quantization + dequantization.
  3. A pure triton implementation, the fastest, and the one currently used.
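
As a rough sketch of implementation 1 (not the actual code in this PR), blockwise 8-bit quantization can be written entirely with torch tensor ops; the normalized blocks and the full distance tensor to the quantization map get materialized, which is where the large memory usage comes from:

```python
# Rough sketch of the pure-torch path (implementation 1), not the PR's actual code.
import torch


def torch_quantize_blockwise(x: torch.Tensor, code: torch.Tensor, blocksize: int = 256):
    """Quantize x to uint8 indices into `code` (a 256-entry map); assumes x.numel() % blocksize == 0."""
    blocks = x.float().reshape(-1, blocksize)          # materialized copy of the blocks
    absmax = blocks.abs().amax(dim=1, keepdim=True)    # per-block scale
    normed = blocks / absmax.clamp(min=1e-12)          # materialized normalized blocks
    # Nearest code entry: materializes a (num_blocks, blocksize, 256) distance tensor.
    idx = (normed.unsqueeze(-1) - code.view(1, 1, -1)).abs().argmin(dim=-1)
    return idx.to(torch.uint8).reshape(x.shape), absmax.squeeze(1)


def torch_dequantize_blockwise(q: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int = 256):
    vals = code[q.reshape(-1, blocksize).long()]       # table lookup
    return (vals * absmax.unsqueeze(1)).reshape(q.shape)


# Usage, e.g. with the dynamic 8-bit map from bitsandbytes.functional.create_dynamic_map():
# q, absmax = torch_quantize_blockwise(state, code)
# state_hat = torch_dequantize_blockwise(q, absmax, code)
```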

@Egor-Krivov Egor-Krivov changed the title [Draft] Implemented 8bit optimizers on XPU [Draft][triton] Implemented 8bit optimizers on XPU Jul 1, 2025
@matthewdouglas matthewdouglas added the Intel and Optimizers labels Jul 1, 2025
@matthewdouglas matthewdouglas added this to the v0.47.0 milestone Jul 1, 2025
@matthewdouglas matthewdouglas self-requested a review July 1, 2025 18:35
@Egor-Krivov Egor-Krivov changed the title [Draft][triton] Implemented 8bit optimizers on XPU [XPU] Implemented 8bit optimizers in triton Jul 11, 2025
@Egor-Krivov Egor-Krivov (Contributor, Author) commented:

@matthewdouglas This PR is ready for review. The 8-bit optimizer interface was merged in #1706.

@Egor-Krivov Egor-Krivov (Contributor, Author) commented Jul 15, 2025

Benchmarking on 4096×4096 shapes, I get about a 2x performance gain from the 8-bit optimizers on a GPU Max 1100.

Torch is the 32-bit optimizer from torch; BNB is the 8-bit optimizer:

Torch step (eager): 2.972ms
BNB step: 1.325ms
Torch step (eager): 2.954ms
BNB step: 1.308ms
Torch step (eager): 2.985ms
BNB step: 1.290ms
Torch step (eager): 2.957ms
BNB step: 1.283ms
Torch step (eager): 2.951ms
BNB step: 1.320ms
Torch step (eager): 2.943ms
BNB step: 1.250ms
Torch step (eager): 2.959ms
BNB step: 1.318ms

For small shapes (1024×9) the difference is smaller:

Torch step (eager): 0.257ms
BNB step: 0.256ms
Torch step (eager): 0.253ms
BNB step: 0.260ms

The benchmark is based on the optimizer tests:

```python
import os
from os.path import join
import shutil
import time
import uuid

from lion_pytorch import Lion
import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F
from bitsandbytes.utils import sync_gpu

# optim_name = "momentum8bit_blockwise"
# optim_name = "rmsprop8bit_blockwise"
# optim_name = "adagrad8bit_blockwise"
# optim_name = "adam8bit_blockwise"
optim_name = "ademamix8bit_blockwise"
# optim_name = "lion8bit_blockwise"

str2optimizers = {}

k = 20
## TODO: maybe remove these three.
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)

str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["paged_adamw8bit_blockwise"] = (
    torch.optim.AdamW,
    lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)

str2optimizers["ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.AdEMAMix)
str2optimizers["ademamix8bit_blockwise"] = (
    bnb.optim.ademamix._ReferenceAdEMAMix,
    lambda pxx: bnb.optim.AdEMAMix8bit(pxx),
)
str2optimizers["paged_ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.PagedAdEMAMix)
str2optimizers["paged_ademamix8bit_blockwise"] = (
    bnb.optim.ademamix._ReferenceAdEMAMix,
    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx),
)
str2optimizers["ademamix_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
    lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["paged_ademamix_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
    lambda pxx: bnb.optim.PagedAdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["ademamix8bit_blockwise_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
    lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)

str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))

str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)

str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)


str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]

str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_adamw8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]

str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = str2statenames["paged_ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
    ("m1_m2", "state1", "qmap1", "absmax1"),
    ("nu", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_ademamix8bit_blockwise"] = [
    ("m1_m2", "state1", "qmap1", "absmax1"),
    ("nu", "state2", "qmap2", "absmax2"),
]


gtype = [torch.float32, torch.float16, torch.bfloat16][2]

dim2 = 4096
dim1 = 4096

# dim2 = 1024
# dim1 = 1024

device = "xpu"

check_precision = True

gradient_scale = 0.01
# gradient_scale = 1


def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    error_count = (idx == 0).sum().item()
    if error_count > max_error_count:
        print(f"Too many values not close: assert {error_count} < {max_error_count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


def get_temp_dir():
    path = f"/tmp/autoswap/{uuid.uuid4()}"
    os.makedirs(path, exist_ok=True)
    return path


def rm_path(path):
    shutil.rmtree(path)


def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
    torch.set_printoptions(precision=10)

    if dim1 == 1 and dim2 == 1:
        return

    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    blocksize = 256

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 3e-3, 1e-3
        # atol, rtol = 5e-3, 1e-3
        patol, prtol = 1e-5, 1e-3
    elif gtype == torch.bfloat16:
        atol, rtol = 3e-3, 1e-3
        # atol, rtol = 5e-3, 1e-3
        patol, prtol = 1e-4, 1e-2
    else:
        atol, rtol = 3e-3, 1e-3
        # atol, rtol = 5e-3, 1e-3
        patol, prtol = 1e-5, 1e-3

    errors = []
    relerrors = []

    for i in range(50):
        g = torch.randn(dim1, dim2, device=device, dtype=gtype) * gradient_scale
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        p_copy = p1.clone()

        sync_gpu(p1)
        start = time.time()
        torch_optimizer.step()
        sync_gpu(p1)
        stop = time.time()
        print(f"Torch step (eager): {1000 * (stop - start):.3f}ms")
        sync_gpu(p1)
        start = time.time()
        bnb_optimizer.step()
        sync_gpu(p1)
        stop = time.time()
        print(f"BNB step: {1000 * (stop - start):.3f}ms")

        # since Lion can have pretty noisy updates where things lie at the boundary
        if check_precision:
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            ## For AdEMAMix, [p2][name2][0] and [p2][name2][1] could be dequantized
            ## separately and stacked (qmap is shared, absmax is stacked); here a single
            ## blockwise dequantization over the stacked tensors is used instead.
            s2 = F.dequantize_blockwise(
                code=bnb_optimizer.state[p2][qmap],
                absmax=bnb_optimizer.state[p2][max_val],
                A=bnb_optimizer.state[p2][name2],
                blocksize=blocksize,
            )

            s1 = torch_optimizer.state[p1][name1]
            num_not_close = torch.isclose(s1, s2, atol=atol, rtol=rtol) == 0
            if check_precision:
                assert num_not_close.sum().item() < 20
            dequant_states.append(s2.clone())

        err = torch.abs(p1 - p2)
        relerr = err / (torch.abs(p1) + 1e-9)
        if g.dtype == torch.bfloat16 and check_precision:
            assert err.mean() <= 0.00017
            assert relerr.mean() <= 0.0016
        elif check_precision:
            assert err.mean() < 0.00006
            assert relerr.mean() < 0.0006

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                rm_path(path)
                torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

                ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
                ## separately and then stack them. The qmap is shared, but absmax is also stacked.
                if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
                    s2 = torch.stack(
                        (
                            F.dequantize_blockwise(
                                code=bnb_optimizer.state[p2][qmap],
                                absmax=bnb_optimizer.state[p2][max_val][0],
                                A=bnb_optimizer.state[p2][name2][0],
                                blocksize=blocksize,
                            ),
                            F.dequantize_blockwise(
                                code=bnb_optimizer.state[p2][qmap],
                                absmax=bnb_optimizer.state[p2][max_val][1],
                                A=bnb_optimizer.state[p2][name2][1],
                                blocksize=blocksize,
                            ),
                        )
                    )
                else:
                    s2 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                        blocksize=blocksize,
                    )

                torch.testing.assert_close(s1cpy, s2)

                num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s2, atol=atol, rtol=rtol) == 0
                if check_precision:
                    assert num_not_close.sum().item() < 20

            # Lion can have pretty noisy updates where things lie at the boundary
            assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)


test_optimizer8bit(dim1, dim2, gtype, optim_name, device)

```
