[Draft] Qualcomm AI Engine Direct - Unexpected graph for mutable buffer after export during Quantization #11309


Draft · wants to merge 2 commits into main
Conversation

@shewu-quic (Collaborator) commented Jun 3, 2025

Issue

We observed that a copy node was not inserted for the mutable buffer after export during quantization.
The graph below can be reproduced with this PR using the following command:

python3 backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_index_put -b build-android/ -s <serial> -H <host> -m SM8650

[image: export.svg]

Background

As the GA models are being enabled, some of them use index_put/index_copy to update the key-value cache, similar to Llama in ExecuTorch.
In a previous PR, we observed that a copy node would be inserted after the mutable buffer (b_k_cache), even though the input of the index_put node was frozen as b__frozen_param0. Therefore, we added a workaround pass to replace the input of index_put.

In the past

[image: graph from the previous PR]

Apparent workaround

I am curious why the Llama model is not affected. I found that you applied a patch for Llama export, which seems to reproduce the previous result.
Is this the expected solution for the mutable buffer issue?

[image: export_with_patch.svg]

cc @haowhsu-quic @winskuo-quic @DannyYuyang-quic


@shewu-quic (Collaborator Author)

Hi @cccclai,

While enabling GA mode, we encountered an issue where the copy node was not inserted for the mutable buffer node after export during quantization. This results in the mutable buffer being treated as a constant buffer after compiling the graph. I noticed that you applied a patch to fix this issue in Llama.

I have two questions regarding this issue:

  1. Do you know what might be causing this problem?
  2. Is applying the patch the expected solution for the mutable buffer issue?

@JacobSzwejbka (Contributor) commented Jun 3, 2025

I think the general flow today is

export -> high-level non-functional IR

<quant happens here>

edge -> decomps to functional IR + smaller operator lib

to_executorch -> where we will try and reinject some mutation.

Export used to spit out functional IR, but it hasn't for some time. It might be possible for a quantizer to functionalize the graph on its own. Right now I believe the logic is coupled with the decompose API; not totally sure. Asking around.
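For reference, a minimal sketch of how that flow typically looks in user code; the model, example inputs, and the QnnQuantizer import path below are placeholders written from memory, and exact entry points can vary by version:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from executorch.exir import to_edge

model = MyModel().eval()                # placeholder module
example_inputs = (torch.randn(1, 16),)  # placeholder inputs

# export -> high-level, non-functional IR (in-place ops like index_put_ survive)
ep = torch.export.export(model, example_inputs, strict=True)

# quantization happens here, on the unlifted graph module
quantizer = QnnQuantizer()
m = prepare_pt2e(ep.module(), quantizer)
m(*example_inputs)                      # calibration
m = convert_pt2e(m)

# edge -> decompositions to functional IR + smaller operator lib
edge = to_edge(torch.export.export(m, example_inputs))

# to_executorch -> mutation gets re-injected here
et_program = edge.to_executorch()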

@cccclai (Contributor) commented Jun 3, 2025

Asking compiler team (PoC for torch.export) and quantization team (PoC for quantization) to help...

@shewu-quic (Collaborator Author)

I think the general flow today is

export -> high-level non-functional IR

edge -> decomps to functional IR + smaller operator lib

to_executorch -> where we will try and reinject some mutation.

Export used to spit out functional IR, but it hasn't for some time. It might be possible for a quantizer to functionalize the graph on its own. Right now I believe the logic is coupled with the decompose API; not totally sure. Asking around.

In my view, the main challenge is integrating the mutable buffer feature with the quantization flow.

export -> high-level non-functional IR (Issue: missing mutable buffer information in the graph signature)

prepare_pt2e
calibration
convert_pt2e (Issue: the mutable buffer gets frozen -> our workaround: replace the frozen param with the mutable buffer)

edge -> decomps to functional IR + smaller operator lib

to_executorch -> where we will try and reinject some mutation.

@cccclai (Contributor) commented Jun 4, 2025

I think the general flow today is
export -> high-level non-functional IR

edge -> decomps to functional IR + smaller operator lib
to_executorch -> where we will try and reinject some mutation.
Export used to spit out functional IR, but it hasn't for some time. It might be possible for a quantizer to functionalize the graph on its own. Right now I believe the logic is coupled with the decompose API; not totally sure. Asking around.

In my view, the main challenge is integrating the mutable buffer feature with the quantization flow.

export -> high-level non-functional IR (Issue: missing mutable buffer information in the graph signature)

prepare_pt2e, calibration, convert_pt2e (Issue: the mutable buffer gets frozen -> our workaround: replace the frozen param with the mutable buffer)

edge -> decomps to functional IR + smaller operator lib

to_executorch -> where we will try and reinject some mutation.

If we have a way to get a functional IR during quantization, does it solve the issue?

@shewu-quic (Collaborator Author)

I'm not very clear about the relationship between functional IR and the mutable buffer issue. As far as I know, functional IR refers to operators that do not mutate or alias inputs. However, there seem to be two issues here:

  1. After quantization export, the mutable buffer information is missing.
  2. After convert_pt2e, the input for index_put is replaced by a frozen param.
    If the graph becomes functional IR, which of the above issues would that solve?

Also, I'm curious whether other backends have encountered this problem.

@cccclai (Contributor) commented Jun 5, 2025

I'm not very clear about the relationship between functional IR and the mutable buffer issue. As far as I know, functional IR refers to operators that do not mutate or alias inputs. However, there seem to be two issues here:

  1. After quantization export, the mutable buffer information is missing.
  2. After convert_pt2e, the input for index_put is replaced by a frozen param.
    If the graph becomes functional IR, which of the above issues would that solve?

Also, I'm curious whether other backends have encountered this problem.

If we have a functional IR during quantization, then instead of index_put_ we will see index_put, with a copy op at the end. Other backends are not quantizing the whole model, just the linear layers, so they haven't hit the issue yet.
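A toy illustration of that difference (the module and names are made up for demonstration, and the exact graphs depend on the PyTorch version): plain export keeps the in-place index_put_, while run_decompositions({}) yields a functional index_put, with the buffer write-back showing up as a copy_ at the end of the unlifted module.

import torch

class ToyKVCache(torch.nn.Module):  # made-up module just for illustration
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(4, 8))

    def forward(self, pos, val):
        self.cache.index_put_((pos,), val)
        return self.cache + 0

m = ToyKVCache()
example_inputs = (torch.tensor([1]), torch.randn(1, 8))

ep = torch.export.export(m, example_inputs, strict=True)
print(ep.graph_module.graph)             # expect the in-place aten.index_put_ here

ep_functional = ep.run_decompositions({})
print(ep_functional.graph_module.graph)  # functional aten.index_put instead
print(ep_functional.module().graph)      # unlifted module: copy_ write-back at the end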

@shewu-quic (Collaborator Author) commented Jun 5, 2025

If we have a functional IR during quantization, then instead of index_put_ we will see index_put, with a copy op at the end. Other backends are not quantizing the whole model, just the linear layers, so they haven't hit the issue yet.

Got it. Is that similar to using the legacy export, i.e. the following patch? I currently use this patch to work around the mutable buffer issue, and it seems to work. But it still doesn't resolve the other issue, where the input of index_put is replaced by a frozen param after convert_pt2e. As a result, we cannot compare against the CPU results after convert_pt2e, since the input of index_put is fixed. Maybe I can try replacing the frozen param with the original mutable buffer after convert_pt2e to check whether that is workable.

with patch.object(
    torch._utils_internal,
    "export_training_ir_rollout_check",
    return_value=False,
):
    self.exported_whisper_encoder = torch.export.export(
        self.whisper_encoder,
        self.whisper_encoder.get_example_inputs(),
        strict=True,
    ).module()

The graph exported with the patch, after convert_pt2e:
[image]

@cccclai (Contributor) commented Jun 5, 2025

it still doesn't resolve the other issue, where the input of index_put is replaced by a frozen param after convert_pt2e

oh I remember that issue, and thought you sent a patch..

I would like to check whether this is for supporting Llama and similar decoder-only models in optimum. If so, maybe we can consider the proposed option of using static llama instead of the optimum model definition, given that the latter will be much slower...

@shewu-quic (Collaborator Author) commented Jun 5, 2025

oh I remember that issue, and thought you sent a patch..

Yes, I had a workaround PR for it before.

I would like to check whether this is for supporting Llama and similar decoder-only models in optimum. If so, maybe we can consider the proposed option of using static llama instead of the optimum model definition, given that the latter will be much slower...

Yes, it is for supporting LLMs and similar decoder models in the GA model list.
In the beginning, I just wanted to leverage your runner with the optimum model definition. 😊

Actually, we are also trying to use static llama instead of the Hugging Face model definition now.

@cccclai (Contributor) commented Jun 5, 2025

I see, I will bring it up to the team. Using static llama is likely the easier solution.

@shewu-quic (Collaborator Author)

I see, I will bring it up to the team.

Thanks a lot.

Using static llama is likely the easier solution.

Yes, or modify the model definition to move the KV cache to the model inputs, as in mimi, whisper, and T5, etc.

@cccclai (Contributor) commented Jun 5, 2025

The answer I got from the compiler team is:

torch.export.export(...).run_decompositions({}) is pretty much the same IR as the old functional IR minus autograd ops. If you are doing PTQ, it shouldn't matter

If that resolves it, then only the second issue needs to be solved.

@cccclai (Contributor) commented Jun 5, 2025

convert_pt2e

I also remember there is a way to selectively fuse operators, @jerryzh168 any chance you know?

@cccclai (Contributor) commented Jun 5, 2025

Just discussed with the team; it's still unclear whether it's okay, but if the effort to enable them in static llama is minimal, it would be great to have them working there first while we're trying to figure out this issue.

@cccclai (Contributor) commented Jun 5, 2025

Regarding the second issue, @jerryzh168 was proposing

don't quantize in-place ops or change in-place ops to non-in-place ops in transform_for_quantization

I feel like it's reasonable, does it work for you?
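A hedged sketch of what such a rewrite could look like as a graph pass over the pre-annotation module (the op mapping is illustrative only, and a real pass would still have to preserve the buffer write-back, e.g. via a copy_ or the graph signature):

import torch
from torch.fx import GraphModule

# Illustrative mapping from in-place ATen overloads to functional counterparts.
INPLACE_TO_FUNCTIONAL = {
    torch.ops.aten.index_put_.default: torch.ops.aten.index_put.default,
    torch.ops.aten.index_copy_.default: torch.ops.aten.index_copy.default,
}

def replace_inplace_ops(gm: GraphModule) -> GraphModule:
    """Swap in-place ops for functional ones before annotation (sketch only)."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in INPLACE_TO_FUNCTIONAL:
            # NOTE: this drops the mutation semantics; the write-back to the
            # mutable buffer still has to be re-established downstream.
            node.target = INPLACE_TO_FUNCTIONAL[node.target]
    gm.graph.lint()
    gm.recompile()
    return gm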

@cccclai (Contributor) commented Jun 9, 2025

@tugsbayasgalan made another suggestion

I think the best way is to label ep.run_decompositions({}) as a "pass" to convert in-place ops to functional ops. I don't think we can do local conversion safely in general

I feel like it's also reasonable, and less work. I will try to make a PR.

@cccclai (Contributor) commented Jun 10, 2025

Here is the example PR pytorch/ao#2345. I don't have a better idea than re-tracing, and it also looks like we need to set

m = convert_pt2e(m, fold_quantize=False)

such that they won't be frozen. @jerryzh168 may have a better idea on this issue.
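Roughly, the two pieces would slot into the PTQ flow like this (a hedged sketch; model, example_inputs, and quantizer are placeholders):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Re-trace through run_decompositions({}) to get a functional graph first.
ep = torch.export.export(model, example_inputs, strict=True).run_decompositions({})
m = ep.module()

m = prepare_pt2e(m, quantizer)  # backend-specific quantizer, e.g. QnnQuantizer
m(*example_inputs)              # calibration
# fold_quantize=False keeps quantize ops in the graph instead of folding them
# into frozen constants, so the mutable buffer is not frozen either.
m = convert_pt2e(m, fold_quantize=False)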

@shewu-quic (Collaborator Author)

Regarding the second issue, @jerryzh168 was proposing

don't quantize in-place ops or change in-place ops to non-in-place ops in transform_for_quantization

I feel like it's reasonable, does it work for you?

I have tried not quantizing the mutable buffer of in-place ops, and it seems to work well: avoid annotating the input node, because mutable buffers will be folded during the convert_pt2e process. During compilation, I can fill in the quant_attr of the mutable buffer from the quant_attr of the node, because index_copy should not affect quant_attr.
But I am still figuring out two questions:

  1. Is there a way to identify whether an input is a mutable buffer?
  2. Is there a problem for in-place binary operators such as mul_? How should the quant_attr of the mutable buffer be handled?
@register_annotator(
    [torch.ops.aten.index_copy.default, torch.ops.aten.index_copy_.default]
)
def annotate_index_copy(node: Node, quantization_config: QuantizationConfig) -> None:
    # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process. 
    value = node.args[3]

    input_qspec_map = {}
    input_qspec_map[value] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=SharedQuantizationSpec((value, node)),
        _annotated=True,
    )

@shewu-quic (Collaborator Author)

@tugsbayasgalan made another suggestion

I think the best way is to label ep.run_decompositions({}) as a "pass" to convert in-place ops to functional ops. I don't think we can do local conversion safely in general

I feel like it's also reasonable, and less work. I will try to make a PR.

I've also tried using run_decompositions({}), and it appears to yield the same results as with the patch. However, I'm uncertain whether this affects other test cases.
[image]

@cccclai (Contributor) commented Jun 10, 2025

I have tried not quantizing the mutable buffer of in-place ops, and it seems to work well: avoid annotating the input node, because mutable buffers will be folded during the convert_pt2e process. During compilation, I can fill in the quant_attr of the mutable buffer from the quant_attr of the node, because index_copy should not affect quant_attr.
But I am still figuring out two questions.

Do you mean that you avoid quantizing the index_put_ op? And you rely on the other ops next to the index_put_ to get the quant attr? I think in theory that's a proper way for index_put, because index_put is not a computation op but a data manipulation op, and it shouldn't affect accuracy. In the internal boltnn stack, we don't quantize these ops either.

For mul_, we can't do this, because mul is a computation op and we need to functionalize it. And during convert, this op shouldn't be folded either. When we run convert_pt2e, if we choose fold_quantize=False, then they won't be folded. I'm trying to figure out with @jerryzh168 how to avoid folding constant tensors.

Is there a way to identify whether an input is a mutable buffer?

I'm not quite following this question. If the buffer is not folded, like this:

opcode         name                             target                                              args                                                              kwargs
-------------  -------------------------------  --------------------------------------------------  ----------------------------------------------------------------  --------
get_attr       buf                              buf                                                 ()                                                                {}
call_function  quantize_per_tensor_default      quantized_decomposed.quantize_per_tensor.default    (buf, 1.0, 0, 0, 255, torch.uint8)                                {}
call_function  dequantize_per_tensor_default    quantized_decomposed.dequantize_per_tensor.default  (quantize_per_tensor_default, 1.0, 0, 0, 255, torch.uint8)        {}
placeholder    x                                x                                                   ()                                                                {}
call_function  quantize_per_tensor_default_1    quantized_decomposed.quantize_per_tensor.default    (x, 1.0, 0, 0, 255, torch.uint8)                                  {}
call_function  dequantize_per_tensor_default_1  quantized_decomposed.dequantize_per_tensor.default  (quantize_per_tensor_default_1, 1.0, 0, 0, 255, torch.uint8)      {}
call_function  add                              aten.add.Tensor                                     (dequantize_per_tensor_default_1, dequantize_per_tensor_default)  {}
call_function  quantize_per_tensor_default_2    quantized_decomposed.quantize_per_tensor.default    (add, 1.0, 0, 0, 255, torch.uint8)                                {}
call_function  dequantize_per_tensor_default_2  quantized_decomposed.dequantize_per_tensor.default  (quantize_per_tensor_default_2, 1.0, 0, 0, 255, torch.uint8)      {}
call_function  copy__default                    aten.copy_.default                                  (x, dequantize_per_tensor_default_2)                              {}
output         output                           output                                              ((copy__default,),)                                               {}

Then we can see that the get_attr is the mutable buffer input?
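If the buffer stays unfolded like that, one hedged way to pick it out of a plain graph module is to look at the copy_ nodes, whose first argument is the mutated buffer; a minimal sketch:

import torch
from torch.fx import GraphModule

def find_mutable_buffer_nodes(gm: GraphModule) -> set:
    """Sketch: collect the destination of every aten.copy_ as a mutable buffer."""
    buffers = set()
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.aten.copy_.default:
            buffers.add(node.args[0])  # first arg is the tensor being written back
    return buffers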

@shewu-quic (Collaborator Author) commented Jun 10, 2025

In fact, I am trying not to quantize the mutable buffer input (args[0]) of the index_put operation. Because I have annotated the output of the index_put, I assign the quant_attr of the index_put to the mutable buffer in the operation builder. Therefore, I'm wondering whether there's a way to identify that an input is a mutable buffer so I can skip it during annotation.

As you say, I think this approach cannot be used for computation ops.

class IndexPutVisitor(NodeVisitor):
    target = ["aten.index_put.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(self, node, *args, **kwargs):  # full signature elided in the original snippet
        input_node = self.get_node(node.args[0])
        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
            quant_attrs = quant_attrs.copy()
            # Propagate the index_put quant_attr onto the mutable buffer input
            input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
        input_tensor = self.get_tensor(input_node, node)
        ...

@shewu-quic (Collaborator Author)

I think the best way is to label ep.run_decompositions({}) as a "pass" to convert in-place ops to functional ops. I don't think we can do local conversion safely in general

Yes, I agree with using ep.run_decompositions({}) as a pass in transform_for_quantization. It would be easy for users to use.

@shewu-quic (Collaborator Author) commented Jun 10, 2025

By the way, does ExecuTorch initialize the mutable buffer to zero?
I tried the test below, but the results seem strange because of the different initial value of k_cache.

# test model
class IndexCopy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer(
            "k_cache",
            torch.ones((1, 1024, 12, 64), dtype=torch.float32),
            persistent=False
        )

    def forward(self, input_pos, k_val):
        k_out = self.k_cache
        k_out.index_copy_(1, input_pos, k_val)
        return k_out + 0

# test
    def test_qnn_backend_index_put(self):
        test_comb = [
            {
                QCOM_MODULE: IndexCopy(),  # noqa: F405
                QCOM_SAMPLE_INPUTS: [(
                    torch.tensor([2], dtype=torch.int64),
                    torch.randn([1, 1, 12, 64]),
                ),(
                    torch.tensor([1], dtype=torch.int64),
                    torch.randn([1, 1, 12, 64]),
                )],
            },
        ]
        for i, test in enumerate(test_comb):
            with self.subTest(i=i):
                module = self.get_qdq_module(
                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
                )
                # module.reset
                self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS])


# ref_output from nn.module:
tensor([[[[0.0218, 0.8506, 1.0905,  ..., 0.9814, 0.6325, 1.0905],
          [0.3271, 0.0000, 0.3708,  ..., 0.0000, 0.0000, 0.0000],
          [0.2835, 0.0000, 0.0436,  ..., 0.1963, 1.0905, 0.0000],
          ...,
          [1.0033, 1.0905, 0.5016,  ..., 0.2835, 0.0000, 1.0905],
          [0.0000, 1.0905, 1.0905,  ..., 0.0000, 0.0000, 1.0905],
          [0.2617, 0.6543, 0.6761,  ..., 0.0000, 1.0905, 0.0000]],

         [[0.2617, 0.0654, 0.3053,  ..., 0.0000, 0.0000, 0.7852],
          [0.7852, 0.3708, 0.0000,  ..., 0.2181, 0.4144, 1.0905],
          [0.0000, 1.0905, 0.0000,  ..., 0.0000, 1.0905, 1.0905],
          ...,
          [0.0000, 0.0000, 0.1090,  ..., 0.5234, 0.0000, 0.7197],
          [0.3490, 0.0000, 0.0000,  ..., 0.0218, 0.7197, 0.0000],
          [0.1745, 0.4798, 0.0000,  ..., 0.1309, 0.2399, 0.0000]],

         [[1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          ...,
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033]],

         [[1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          ...,
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033],
          [1.0033, 1.0033, 1.0033,  ..., 1.0033, 1.0033, 1.0033]]]])

# runner_output from qnn_executor_runner:
tensor([[[[0.0259, 0.8576, 1.0989,  ..., 0.9826, 0.6292, 1.0989],
          [0.3318, 0.0000, 0.3749,  ..., 0.0000, 0.0000, 0.0000],
          [0.2844, 0.0000, 0.0474,  ..., 0.2025, 1.0989, 0.0000],
          ...,
          [1.0041, 1.0989, 0.5042,  ..., 0.2758, 0.0000, 1.0989],
          [0.0000, 1.0989, 1.0989,  ..., 0.0000, 0.0000, 1.0946],
          [0.2543, 0.6594, 0.6809,  ..., 0.0000, 1.0989, 0.0000]],

         [[0.2715, 0.0646, 0.3146,  ..., 0.0000, 0.0000, 0.7800],
          [0.7887, 0.3663, 0.0000,  ..., 0.2284, 0.4094, 1.0989],
          [0.0000, 1.0989, 0.0000,  ..., 0.0000, 1.0989, 1.0989],
          ...,
          [0.0000, 0.0000, 0.1164,  ..., 0.5128, 0.0000, 0.7111],
          [0.3405, 0.0000, 0.0000,  ..., 0.0215, 0.7111, 0.0000],
          [0.1810, 0.4741, 0.0000,  ..., 0.1250, 0.2327, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]])

@shewu-quic (Collaborator Author) commented Jun 10, 2025

Let me try to summarize the current results.
We now have two approaches to fix the missing mutable buffer issue and the issue of quant folding freezing the mutable buffer.

Approach 1: Use run_decompositions as a pass before annotation, plus convert_pt2e(m, fold_quantize=False)

  • It seems to work well.
  • In my view, this approach should be more general.

Follow-up questions:

  1. How can we avoid re-tracing?
  2. Is it possible that the mutable buffer won't be frozen without fold_quantize=False?
    Or is fold_quantize=False the expected solution?

Approach 2: Avoid quantizing the mutable buffer

  • It works well for non-computation ops.
  • Computation ops such as mul_ would fail because of the inability to process quant_attr.
  • I feel like it is less work.

Would this be in line with your thoughts?

@cccclai (Contributor) commented Jun 10, 2025

Yes, it's aligned. Thank you for summarizing.

Yes, I agree with using ep.run_decompositions({}) as a pass in transform_for_quantization. It would be easy for users to use.

The pass will run under the hood, so users won't notice the change.

Therefore, I'm wondering whether there's a way to identify that an input is a mutable buffer so I can skip it during annotation.

I think the graph input will be a buffer instead of a placeholder in this case? We may need to trace up the graph.

How can we avoid re-tracing?

Yeah, agreed that re-tracing might not be the best idea; maybe we can add a pass to manually convert in-place ops to functional ops, just so we're not tied to re-tracing. It seems like the general recommendation from the compiler team is to re-trace to be safe, but I feel like with passes we have more control, though possibly more work (the effort to add the passes). What do you think?

Is it possible that the mutable buffer won't be frozen without fold_quantize=False? Or is fold_quantize=False the expected solution?

I'm still discussing it with Jerry; I think we should have a way to choose which ops to fold.

Approach 2: Avoid quantizing the mutable buffer

I think in general we want to avoid annotating non-compute ops, as it might slightly affect accuracy. So we should avoid annotating index_put_ anyway, and it will be an improvement. This will unblock us for decoder-only models.

For the general solution (approach 1), we haven't hit a GA model like this yet, so it is probably safe to go with approach 2 first, while we're working on a proper solution for in-place compute ops in approach 1.

@shewu-quic (Collaborator Author)

I think the graph input will be a buffer instead of a placeholder in this case? We may need to trace up the graph.

Are you suggesting that a buffer should be a get_attr node? If so, shouldn't weights or other constant nodes also be get_attr nodes?

Yeah, agreed that re-tracing might not be the best idea; maybe we can add a pass to manually convert in-place ops to functional ops, just so we're not tied to re-tracing. It seems like the general recommendation from the compiler team is to re-trace to be safe, but I feel like with passes we have more control, though possibly more work (the effort to add the passes). What do you think?

Yes, I also intend to use a pass to convert them. I think we need to put some effort into testing this pass.

  1. How do we identify the mutable buffer for an in-place op?
    In transform_for_annotation, we only have access to the graph_module. For two-input nodes like mul_, how can we determine whether the first input or the second is the mutable buffer? If we had the exported_program, it would be easier to identify using graph_signature (see the sketch below).
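For reference, a hedged sketch of what that graph_signature lookup could look like when an ExportedProgram is available (the attribute names follow my reading of torch.export's ExportGraphSignature and should be treated as an assumption to verify):

from torch.export import ExportedProgram

def mutable_buffer_placeholders(ep: ExportedProgram) -> dict:
    """Sketch: map placeholder names to the buffers the program actually mutates."""
    sig = ep.graph_signature
    mutated = set(sig.buffers_to_mutate.values())       # FQNs of buffers written back
    return {
        name: buf
        for name, buf in sig.inputs_to_buffers.items()  # placeholder name -> buffer FQN
        if buf in mutated
    }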

I'm still discussing it with Jerry; I think we should have a way to choose which ops to fold.

Thanks for your help. The main effort for this task is that with fold_quantize=False we need to make some changes in annotate_quant_attr.py and other related passes.

Approach 2: Avoid quantizing the mutable buffer

For the general solution (approach 1), we haven't hit a GA model like this yet, so it is probably safe to go with approach 2 first, while we're working on a proper solution for in-place compute ops in approach 1.

It makes sense to me. I will put up a PR to fix the mutable buffer issue with approach 2, including this change to improve performance.

By the way, we have good news: the Qwen model weights can be loaded into our static LLaMA structure. If other LLM-based models are compatible in the same way, it would be very helpful.
However, this approach is still useful for whisper or T5 models.

@cccclai (Contributor) commented Jun 16, 2025

@shewu-quic @haowhsu-quic
I updated the example in pytorch/ao#2345 with a patch to the constant folding, such that the mutable buffer won't be folded. The idea is basically to find the buffer by getting the input of the copy_ op; the first argument is always the mutable buffer. Does it work for you?

@shewu-quic (Collaborator Author) commented Jun 18, 2025

@shewu-quic @haowhsu-quic I updated the example in pytorch/ao#2345 with a patch to the constant folding, such that the mutable buffer won't be folded. The idea is basically to find the buffer by getting the input of the copy_ op; the first argument is always the mutable buffer. Does it work for you?

Sorry for my late reply. I have created a PR to go with approach 2 and added an option to choose whether or not to delegate the mutable buffer. Based on this PR, I can successfully enable the whisper model.

Thanks for your PR; I will give it a shot.
I have a quick question regarding this PR: we still need run_decompositions, correct? Do you have any plans to make it a pass so that re-tracing is not required?
