Unit Tests for On Device Sampling #463


Open
wants to merge 23 commits into base: main

Conversation

quic-sanising
Contributor

@quic-sanising quic-sanising commented Jun 18, 2025

This PR adds the following Unit Tests for On Device Sampling:

  1. test_sampler_transform: Tests whether SamplerTransform adds nodes at the output of a QEffForCausalLM model to enable sampling of next tokens on the device (instead of the host) and returns the next tokens and/or probability distributions.
  2. test_greedy_sampling: Tests greedy sampling with a QPC compiled with and without On Device Sampling (a rough sketch of this comparison appears after the list).
  3. test_random_sampling: Tests random sampling with a QPC compiled with and without On Device Sampling.
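
Not part of the PR itself, but for illustration, here is a minimal pytest-style sketch of the greedy comparison; generate_with_qpc is a hypothetical stand-in for running the compiled QPC (for example via QAICInferenceSession):

def generate_with_qpc(prompts, include_sampler):
    # Placeholder for illustration only: the real test runs the compiled QPC
    # with On Device Sampling enabled or disabled and returns generated token ids.
    return [[1, 2, 3] for _ in prompts]

def test_greedy_sampling_sketch():
    prompts = ["Hello", "Hi"]
    host_greedy = generate_with_qpc(prompts, include_sampler=False)
    device_greedy = generate_with_qpc(prompts, include_sampler=True)
    # Greedy decoding is deterministic, so both paths should produce identical token ids.
    assert host_greedy == device_greedy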

Signed-off-by: quic-sanising <[email protected]>
@quic-rishinr
Contributor

@quic-sanising Can you add a small feature description under the supported features section of /docs/source/quick_start.md? Also, please provide a link to the example script in the description.

@quic-sanising
Contributor Author

@quic-sanising Can you add a small feature description under the supported features section of /docs/source/quick_start.md? Also, please provide a link to the example script in the description.

Done

Contributor

@quic-amitraj quic-amitraj left a comment

Please fix lint error.

sanising added 3 commits July 3, 2025 13:44
@quic-sanising quic-sanising marked this pull request as ready for review July 3, 2025 19:08
Signed-off-by: sanising <[email protected]>
@quic-sanising
Contributor Author

quic-sanising commented Jul 3, 2025

Please fix lint error.

@quic-amitraj The lint failures were happening because the linter installs ruff v0.12.2, whereas the .pre-commit-config.yaml file pins an older version, v0.5.2.

To fix the errors, we need to either install ruff v0.5.2 in the linter or update the .pre-commit-config.yaml file to version v0.12.2.
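
For reference, the relevant hook entry in .pre-commit-config.yaml would look roughly like this (assuming the standard astral-sh/ruff-pre-commit hook; the exact repo and hook ids in this project may differ):

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.2  # bumped from v0.5.2 so the hook matches the ruff version the linter installs
    hooks:
      - id: ruff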

Contributor

@ochougul ochougul left a comment

Everything else LGTM; just see if we can use a single-layer model in our tests.

"sampling support is not available. Please check the QPC and try again."
)

if include_sampler and not self.include_sampler:
Contributor

Can we include this check in line 489 itself?

Contributor Author

Could you clarify this?

if len(logits.shape) == 2:
    logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)
next_token_id = self._fetch_next_token_id(outputs)
Contributor

Since this is inside the decode loop, would it create any performance drop? Have you done any tests to check for performance deviation?

Contributor Author

I do not follow. How would this cause a performance drop? Instead of performing argmax on the host CPU, we are simply reading the next_token provided by the QAIC device. In my opinion, this would lead to a performance improvement instead. Please let me know if you are talking about something else.

Contributor

Due to the additional reshape, I just wanted to check whether it would create any performance deviation. It's good to have a performance number for a smaller model, just for reference.

Contributor Author

@quic-sanising quic-sanising Aug 6, 2025

I still do not follow. Are you talking about this reshape?

return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])

If yes, this converts a 3D tensor of shape (batch_size, 1, 1) to a 2D tensor of shape (batch_size, 1). This operation doesn't cause a drop in performance.

Otherwise, if you are curious about the overall performance gains, please reach out to me for a complete performance report.
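
For reference, a tiny numpy illustration of that reshape (shapes are arbitrary and not taken from the PR):

import numpy as np

next_tokens = np.array([[[17]], [[42]]])  # (batch_size, 1, 1) buffer read from the device
flattened = next_tokens.reshape(next_tokens.shape[0], next_tokens.shape[1])
assert flattened.shape == (2, 1)  # (batch_size, 1); a view over the same data, no extra compute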

)

# Compare generated texts
golden_texts = {
Contributor

Can we make this test more robust? The goal is to verify whether the random sampler is working as expected, right? One idea is to check whether the sampler produces the same output when given the same random seed or random values, and different outputs when the seed or values change. Secondly, if we want to compare outputs with and without the sampler, can we use cosine similarity with a threshold instead of exact matching? I'm open to suggestions.

Contributor Author

@quic-sanising quic-sanising Aug 4, 2025

Thanks for the suggestion! The current test checks for an exact output match, which we can reproduce with a fixed seed value.

As for cosine similarity or other metrics like perplexity or ROUGE score, they are useful for semantic comparison, but they can be vague since there are no clearly defined or universally accepted thresholds.

So, for this test, exact matching keeps things deterministic and easier to maintain, in my opinion. Happy to explore alternatives if you'd like.
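
For illustration only, here is a rough sketch of the seed-based check described above, using numpy's Generator as a stand-in for the on-device random sampler:

import numpy as np

def sample_tokens(logits, seed, num_tokens=5):
    # Stand-in sampler: softmax over the logits, then draw token ids with a seeded RNG.
    rng = np.random.default_rng(seed)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return rng.choice(len(probs), size=num_tokens, p=probs)

logits = np.linspace(0.0, 1.0, 8)
assert np.array_equal(sample_tokens(logits, seed=0), sample_tokens(logits, seed=0))  # same seed, same tokens
assert not np.array_equal(sample_tokens(logits, seed=0), sample_tokens(logits, seed=1))  # different seed, (almost surely) different tokens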

Signed-off-by: sanising <[email protected]>

# Load QPC
self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

# Validate sampler inputs for On-Device Sampling
self.include_sampler = validate_sampler_inputs(
Contributor

Can we have something like the following?

self.include_sampler = validate_sampler_inputs(set(self._session.input_names), include_sampler if include_sampler is not False else False)

Contributor Author

@quic-sanising quic-sanising Aug 6, 2025

I don't think we need to do this. include_sampler is a mandatory boolean variable with a default value of False; the only other possible value is True. Both of these scenarios are handled in the validate_sampler_inputs() function. Additionally, the function handles the case of include_sampler=None, so that we can reuse it in other places if we want to.
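
For context, a rough sketch of the behavior described above (this is not the library's actual code, and the sampler input names are assumed for illustration):

from typing import Optional, Set

ASSUMED_SAMPLER_INPUTS = {"temperatures", "top_ks", "top_ps"}  # hypothetical input names

def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = False) -> bool:
    qpc_has_sampler = ASSUMED_SAMPLER_INPUTS.issubset(session_inputs)
    if include_sampler and not qpc_has_sampler:
        raise ValueError(
            "The QPC does not expose sampler inputs, so On Device "
            "Sampling support is not available. Please check the QPC and try again."
        )
    if include_sampler is None:
        # Reusable path: infer sampling support from the QPC when the caller does not specify.
        return qpc_has_sampler
    return bool(include_sampler)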
