
Conversation

@Qubitium
Contributor

@Qubitium Qubitium commented Oct 14, 2025

Remove AutoGPTQ clutter and AutoGPTQ-related configs that are not worth keeping backward compat for.

GPTQModel has a slight project name change to GPT-QModel, with a hyphen (the PyPI package and import name stay the same), as we have now added AWQ/AutoAWQ into our repo and will be making a PR soon to address AWQ loading using GPT-QModel.

GPTQConfig has the most important changes in this PR:

# New GPTQConfig property. Applicable to the sister Peft/Optimum PRs
act_group_aware (`bool`, *optional*, defaults to `True`):
    Use GAR (group aware activation order) during quantization. Has a measurable positive impact on quantization
    quality. Only applicable when `desc_act = False`. Will be forced to `False` when `desc_act = True`.
    
    
# Removed GPTQConfig Properties:
use_cuda_fp16
use_exllama
exllama_config

The 3 removed properties are all related to kernel selection. These 3 are a hot-potato mess and legacy from AutoGPTQ. GPT-QModel uses the unified (existing) backend property to select kernels. Compat code was written in 2024 to convert these 3 properties to backend behind the scenes, but it is no longer relevant for 2025.
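
For illustration, a minimal sketch of the resulting config surface, assuming the existing backend property and a placeholder checkpoint id:

# Sketch only: the checkpoint id is a placeholder and backend="auto" assumes
# the existing GPT-QModel backend naming.
from transformers import AutoModelForCausalLM, GPTQConfig

quant_config = GPTQConfig(
    bits=4,
    group_size=128,
    desc_act=False,
    act_group_aware=True,  # new property added in this PR (GAR)
    backend="auto",        # replaces use_cuda_fp16 / use_exllama / exllama_config
)
model = AutoModelForCausalLM.from_pretrained(
    "org/some-gptq-model",  # placeholder id
    quantization_config=quant_config,
)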

Note:

  • Transformers/Optimum/Peft CI tests should never check for kernel.QUANT_TYPE (str). GPT-QModel will return the best-performing kernel for the relevant module, and it may differ per module due to in/out features and other gptq/module properties in relation to device type + dtype + many other factors.
  • CI tests should only assert kernel.QUANT_TYPE if the test pins a specific kernel via backend selection, as in the sketch below.
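
A rough sketch of that test pattern; the module attribute path, the "marlin" backend/QUANT_TYPE values, and the checkpoint id are assumptions, not guaranteed API:

# Rough sketch only; attribute path and values are assumptions.
from transformers import AutoModelForCausalLM, GPTQConfig

def test_kernel_quant_type_only_when_pinned():
    # Pin the kernel explicitly so selection is deterministic.
    config = GPTQConfig(bits=4, backend="marlin")
    model = AutoModelForCausalLM.from_pretrained(
        "org/some-gptq-model",  # placeholder id
        quantization_config=config,
        device_map="auto",
    )
    qlinear = model.model.layers[0].self_attn.q_proj  # a quantized module
    # Safe only because the backend was pinned above; with auto selection the
    # chosen kernel may differ per module.
    assert qlinear.QUANT_TYPE == "marlin"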

@Rocketknight1
Member

cc @MekkCyber for quantization

@Qubitium Qubitium changed the title [WIP] Fully deprecate AutoGPTQ for GPT-QModel [WIP] Fully deprecate AutoGPTQ and AutoAWQ for GPT-QModel Nov 20, 2025
@Qubitium
Contributor Author

We have begun AutoAWQ deprecation as well.

  • Fused-module code has all been removed. AutoAWQ used to do quant-linear-level fusing, but I do not believe this is maintainable or good: if SGLang/vLLM adopt Transformers v5 for model loading, they will do their own auto fusing, and the quant module should not interfere with that.

  • IPEX is deprecated by Intel and we have a new AwqTorchFused kernel (based on the same Intel TorchFused kernel used for GPTQ), so any code/unit tests for IPEX now point to the AwqTorchFused kernel.

@MekkCyber
Contributor

Hi @Qubitium ! Thanks a lot for working on this! Quick question, what do you mean by AutoAWQ being part of GPT-QModel now? Did you integrate the entire library (including the transformers dependency, like AutoAWQ does), or did you just port over the linear layers, kernels, and related components?

@Qubitium
Contributor Author

Qubitium commented Nov 20, 2025

> Hi @Qubitium ! Thanks a lot for working on this! Quick question, what do you mean by AutoAWQ being part of GPT-QModel now? Did you integrate the entire library (including the transformers dependency, like AutoAWQ does), or did you just port over the linear layers, kernels, and related components?

Long story short: we folded AutoAWQ into GPT-QModel in multiple stages over the past few weeks. Stage 1: directly port/copy the AutoAWQ code over. Stage 2: refactor. Stage 3: fix bugs, add new kernels, and do a major refactor to align with the new internal life cycle in GPT-QModel v5.0+. So we are currently post-Stage 3, where the GPT-QModel base retains minimal original AutoAWQ code. Most of the AutoAWQ code has been refactored away.

Major Changes vs AutoAWQ:

  1. New kernels. We have 2 new kernels for AWQ: AwqTorch (pure Torch based) and AwqTorchFused (CPU optimized, based on work by Intel @jiqing-feng).
  2. Plan to add a 3rd new AWQ kernel based on BitBLAS, as most GPTQ kernels are compatible with AWQ with some small changes. The Marlin kernel is also synced with the GPTQ Marlin kernel for updated Marlin fixes/optimizations via the vLLM port.
  3. QuantLinear code has been rewritten/refactored.
  4. Quant logic is new due to the GPT-QModel 5.0+ life cycle, which is not compatible with AutoAWQ.

HF ecosystem compat:

Work on Peft integration is happening in a parallel PR by @LRL2-ModelCloud huggingface/peft#2917 in coordination with @BenjaminBossan huggingface/peft#2342 (comment)

The Peft PR will need to co-exist concurrently with this PR due to interdependency.

We will hold the Optimum change off until last if we can help it, or may have to open a parallel 3rd PR to Optimum as well if interdependency causes trouble there too.

The final goal of the 2 PRs is to remove dead AutoGPTQ code (no one uses it, or should use it frankly) and the nearly dead AutoAWQ code (the repo is read-only and no longer accepting bug fixes or new model support). Compat for loading old models that use these two packages will be maintained.

@SunMarc
Member

SunMarc commented Nov 20, 2025

Thanks for working on this @Qubitium. We are still debating whether this is something we should offload to GPT-QModel or whether we should start upstreaming some of the inference code directly into transformers + kernels. Here is a proposal from a contributor: #42256.
The goal would be to only upstream the GEMM path, but we can potentially leave the other kernels to GPT-QModel. WDYT?

About GPT-QModel, will it be possible to create AWQ quants for newer models that are compatible with other frameworks (e.g. vLLM), just like AutoAWQ did?

@Qubitium
Contributor Author

Qubitium commented Nov 20, 2025

> WDYT?

I think badly of this proposal.

> Thanks for working on this @Qubitium. We are still debating whether this is something we should offload to GPT-QModel or whether we should start upstreaming some of the inference code directly into transformers + kernels. Here is a proposal from a contributor: #42256. The goal would be to only upstream the GEMM path, but we can potentially leave the other kernels to GPT-QModel. WDYT?

> About GPT-QModel, will it be possible to create AWQ quants for newer models that are compatible with other frameworks (e.g. vLLM), just like AutoAWQ did?

I just checked the PR, and it has no code. I am not going to waste time arguing vaporware vs. what I have done with AWQ in GPT-QModel over the past 2 months: AWQ inference and quantization, full-stack complete, with new kernels, new model support, and full CI kernel and modeling validation. GPT-QModel should be viewed not as an AutoAWQ port but as a full point-release upgrade in every regard.

Edit: I have outlined in a prior post why fusing is a bad idea. It is not AWQ's job to fuse in 2025. Leave it to model makers and higher-level engines such as SGLang and vLLM, which HF v5.0 is targeting from my understanding.

@Qubitium
Contributor Author

Qubitium commented Nov 20, 2025

> About GPT-QModel, will it be possible to create AWQ quants for newer models that are compatible with other frameworks (e.g. vLLM), just like AutoAWQ did?

Our quantized models are more compatible with vLLM/SGLang than ones quantized with Optimum or AutoAWQ.

SGLang/vLLM compat is a number 1 target/design from day one so 100% yes.

@jiqing-feng
Contributor

It's a good chance to deprecate AutoAWQ as it's archived. I suppose the best way is to upstream this into the Transformers main code, just like we did for the AutoGPTQ replacement. For example, the IPEX linear in AutoAWQ is out of date; we need a new implementation for it. The new linear implementation is TorchFusedLinear in gptqmodel.

@Qubitium
Contributor Author

> It's a good chance to deprecate AutoAWQ as it's archived. I suppose the best way is to upstream this into the Transformers main code, just like we did for the AutoGPTQ replacement. For example, the IPEX linear in AutoAWQ is out of date; we need a new implementation for it. The new linear implementation is TorchFusedLinear in gptqmodel.

@jiqing-feng The AWQ version of the GPTQ TorchFused kernel has been added to gpt-qmodel as AwqTorchFusedKernel. Same underlying code, but with memory-layout tweaks to get it to work. AWQ kernel output tests are passing.

@SunMarc For the most part, the kernels for AWQ and GPTQ are shared. For example, we do not compile an extra AWQ-only Marlin kernel; the previously GPTQ-only Marlin kernel is synced from vLLM to run AWQ weights as well.

@Qubitium
Contributor Author

Qubitium commented Nov 21, 2025

CI Passing status using GPT-QModel main branch:

transformers/tests/quantization/autoawq/test_awq.py:
test_awq.py::AwqTest::test_quantized_model PASSED
test_awq.py::AwqTest::test_quantized_model_bf16 PASSED
test_awq.py::AwqTest::test_quantized_model_conversion PASSED
test_awq.py::AwqTest::test_quantized_model_exllama FAILED <-- Needs fixing. 
test_awq.py::AwqTest::test_quantized_model_multi_accelerator SKIPPED
test_awq.py::AwqTest::test_quantized_model_no_device_map PASSED
test_awq.py::AwqTest::test_save_pretrained PASSED
test_awq.py::AwqTest::test_raise_if_non_quantized PASSED
test_awq.py::AwqTest::test_quantized_model_no_k_proj_quantized PASSED
test_awq.py::AwqScaleTest::test_load_quantized_model PASSED
test_awq.py::AwqIPEXTest::test_quantized_model_ipex PASSED <-- test needs to be renamed to AwqTorchFused (ipex removed)
 
peft/tests/test_gpu_examples.py:
PeftAwqGPUTests PASSED
PeftGPTQGPUTests PASSED

@SunMarc The PR is in a working state and ready for a preliminary review. Look at the code diffs: we are eliminating roughly 5x more crud than we add for the new AWQ integration.

Member

@SunMarc SunMarc left a comment


Thanks, left a couple of comments. As I said, I'm happy to see that you are willing to fill the hole left by AutoAWQ, and I'm eager to see this PR merged. However, note that maybe in the future we will add a default working path for GEMM if gptq-model is not installed. As those libraries depend on kernels that require dealing with building + distribution for each new version of torch, we never know when this will suddenly stop.
Also, maybe it would be better to split this PR into 2: one for gptq and one for awq?

Comment on lines -132 to -137
if self.quantization_config.do_fuse:
    from ..integrations import fuse_awq_modules

    model = fuse_awq_modules(model, self.quantization_config)
    model._awq_is_fused = True  # TODO: consider storing this flag in model.config instead

Member


> Fused-module code has all been removed. AutoAWQ used to do quant-linear-level fusing, but I do not believe this is maintainable or good: if SGLang/vLLM adopt Transformers v5 for model loading, they will do their own auto fusing, and the quant module should not interfere with that.

Fusing only happens when the user specifies do_fuse in the config when loading the AWQ model using from_pretrained, so it shouldn't impact SGLang or vLLM at all. Also, you can't serialize the model if we fuse the modules. So I think we should still try to maintain that path if possible.
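
For reference, the path in question looks roughly like this today (the checkpoint id is a placeholder):

# Existing do_fuse usage that this PR removes; the checkpoint id is a placeholder.
from transformers import AutoModelForCausalLM, AwqConfig

quant_config = AwqConfig(do_fuse=True, fuse_max_seq_len=512)
model = AutoModelForCausalLM.from_pretrained(
    "org/some-awq-model",  # placeholder id
    quantization_config=quant_config,
    device_map="auto",
)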

Contributor Author

@Qubitium Qubitium Nov 22, 2025


Fusing has no place in AWQ. QKV/MLP fusing should be done at an upper level, just like what vLLM and SGLang are doing. This is the wrong code/logic for 2025. AutoAWQ did it to squeeze out some inference performance, but if you look at the code, it is a hot mess of static model-class mapping and unmaintainable. The number of new models coming out and the emergence of SGLang and vLLM make this obsolete. It is my understanding that HF v5 is going to be used as the model-loading foundation for SGLang/vLLM, which makes it even more so.

Those users that depend on AutoAWQ fusing will need to choose. We are not going to spend valuable energy supporting dead code.

Comment on lines +127 to +129
if not is_gptqmodel_available():
raise ValueError(
"AWQ (either `autoawq` or `llmawq`) is not available. Please install it with `pip install autoawq` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
"AWQ (either `llmawq`) is not available. Please install it with `pip install gptqmodel` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
Member


Just a note, but if this doesn't make it into v5, we will have to slowly deprecate autoawq.

Contributor Author


I will try to make this as clean as possible to reach that v5 goal. We are removing 75% of the code, not adding. Other than fusing, no features are deprecated; they are only improved, with zero compat issues: new kernels, better hardware compat, faster kernels, and even bug fixes. The current AWQ kernels are fp16 only and failed all our bf16 kernel output quality tests. We will make sure users do not execute in bf16, or at least warn when this happens (loading a model and executing in bf16 when the AWQ kernels are validated for fp16 only).
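
A minimal sketch of the kind of guard we have in mind; the helper name and warning text are illustrative, not the final implementation:

# Illustrative sketch only; helper name and message are not the final code.
import warnings

import torch

def warn_if_bf16_awq(dtype: torch.dtype) -> None:
    # AWQ kernels are currently validated for fp16 only; bf16 failed our
    # kernel output quality tests, so at minimum we warn loudly.
    if dtype == torch.bfloat16:
        warnings.warn(
            "AWQ kernels are validated for torch.float16 only; "
            "running in bfloat16 may degrade output quality."
        )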

from ..modeling_utils import PreTrainedModel

from ..utils import is_auto_gptq_available, is_gptqmodel_available, is_optimum_available, is_torch_available, logging
from ..utils import is_gptqmodel_available, is_optimum_available, is_torch_available, logging
Member


Can you do two separate PRs, one for gptq and one for awq? For the gptq one, I will be able to merge it quickly.

Contributor Author


huggingface/peft#2917 (comment)

GPTQ changes are minimal and mostly cosmetic. But this PR is required for huggingface/peft#2917 (comment) due to interdependency.

Comment on lines +134 to +136
from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear

target_cls = WQLinear_GEMM
target_cls = AwqGEMMQuantLinear
Member


As I said, we might have to replace this path with one handled by kernels at some point.

Contributor Author

@Qubitium Qubitium Nov 22, 2025


That's fine. At that point in the future, HF staff can override the auto kernel selection code from gpt-qmodel and return a specific AwqGEMM kernel from kernels. It will be a clean override and requires no changes from gpt-qmodel. My point is that it is unreasonable to further burden our task by imposing a future optional requirement that does not resolve the issue now.

Comment on lines +138 to 146
from gptqmodel.nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear

target_cls = WQLinear_GEMV
target_cls = AwqGEMVQuantLinear
elif quantization_config.version == AWQLinearVersion.EXLLAMA:
if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
from awq.modules.linear.exllama import WQLinear_Exllama
from gptqmodel.nn_modules.qlinear.awq_exllama import AwqExllamaQuantLinear

target_cls = WQLinear_Exllama
target_cls = AwqExllamaQuantLinear
elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
Member


Unlike gptq, the version selection is not automatic, is that right?

Contributor Author


Just double-checked this code, and it is incomplete. It needs to be changed to do auto kernel selection, just like the gptq kernel selection. The reason is that the same format (version) has multiple compatible kernels (GEMM can be mapped to [AwqTorch, AwqTorchFused, AwqGEMM, AwqMarlin]), and for the same reason it is unreasonable to expect users to manually pass backend to select kernels.

After the update, this entire block of manual kernel selection will be replaced by one line of hf_select_awq_kernel or something similar.

Note that we will be removing the old, flawed AWQ terminology of version (actually a format) and backend (no need for this; it is unique to llm-awq, and we will auto-compat during config loading for llm-awq, where there is no quant_method and only a version attribute). Backward compat will be maintained via config load/save mapping, roughly as sketched below.
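
A rough sketch of that load-time compat mapping; the exact key names and defaults in GPT-QModel may differ:

# Rough sketch only; exact key names/defaults may differ in GPT-QModel.
def normalize_legacy_awq_config(cfg: dict) -> dict:
    cfg = dict(cfg)
    # Old AutoAWQ/llm-awq checkpoints store the packing layout under `version`,
    # which is really a format.
    if "version" in cfg and "format" not in cfg:
        cfg["format"] = str(cfg.pop("version")).lower()
    # llm-awq configs carry no quant_method attribute at all, only `version`.
    cfg.setdefault("quant_method", "awq")
    return cfg

print(normalize_legacy_awq_config({"version": "GEMM", "bits": 4, "group_size": 128}))
# -> {'bits': 4, 'group_size': 128, 'format': 'gemm', 'quant_method': 'awq'}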

@SunMarc SunMarc requested a review from MekkCyber November 21, 2025 17:12
@Qubitium
Contributor Author

Qubitium commented Nov 22, 2025

@SunMarc @MekkCyber Hold off on review. I will ping once ready. I need to remove more code related to fuse and kernel selection.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: autoawq, gptq

@Qubitium
Contributor Author

Qubitium commented Nov 24, 2025

@SunMarc @MekkCyber Update: this PR will be updated once we finish a small refactor and add the same synced auto kernel selection we did for gptq in ModelCloud/GPTQModel#2214. Both gptq and awq kernel selection will be folded into a single hf_select_quant_linear_v2 interface for stability and a single entry point.

In addition, the original AwqGEMM kernel will be split into effectively 3 distinct kernels: TorchGEMM, CudaGEMM, and TritonGEMM. The AutoAWQ GEMM kernel was actually 3 kernels in one monolithic one. That sounds nice but is terrible for CI/kernel output regression/comparison tests, with zero performance benefit. GPT-QModel will auto-select the kernel based on system env, device_map, and kernel qualifications (method, format, etc.). This will knock off another layer of complexity in the existing HF code.

# public/stable API exposed to transformers/optimum
def hf_select_quant_linear_v2(
        bits: int,
        group_size: int,
        desc_act: bool,
        sym: bool,
        format: Union[str, FORMAT], # awq `version` should be pre-mapped to format
        quant_method: Union[str, METHOD], # awq llm-awq `version` should be pre-mapped to method
        zero_point: Optional[bool] = True, # awq only
        dtype: Optional[Union[str, torch.dtype]] = None,
        meta: Optional[Dict[str, Any]] = None,
        pack: Optional[bool] = True,
        device_map: Optional[Union[str, dict]] = None,
        backend: Optional[Union[str, BACKEND]] = None,
) -> Type[BaseQuantLinear]:
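
A hypothetical call for an AWQ GEMM checkpoint might look like the sketch below; the import path and the accepted string values for format/quant_method/backend are assumptions:

# Hypothetical usage sketch; import path and string values are assumptions.
import torch
from gptqmodel.utils.importer import hf_select_quant_linear_v2  # assumed path

qlinear_cls = hf_select_quant_linear_v2(
    bits=4,
    group_size=128,
    desc_act=False,
    sym=True,
    format="gemm",       # legacy awq `version` pre-mapped to a format
    quant_method="awq",
    zero_point=True,     # awq only
    dtype=torch.float16,
    pack=False,          # selecting for inference, not packing
    device_map="auto",
    backend=None,        # let GPT-QModel auto-select the best kernel
)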

@jiqing-feng
Contributor

> Thanks, left a couple of comments. As I said, I'm happy to see that you are willing to fill the hole left by AutoAWQ, and I'm eager to see this PR merged. However, note that maybe in the future we will add a default working path for GEMM if gptq-model is not installed. As those libraries depend on kernels that require dealing with building + distribution for each new version of torch, we never know when this will suddenly stop. Also, maybe it would be better to split this PR into 2: one for gptq and one for awq?

Does it mean we can upstream some specific ops for awq or gptq into the kernels community? In that case, could gptqmodel pull kernels from the community at runtime?
