Unit Tests for On Device Sampling #463
Conversation
@quic-sanising Can you add a small feature description under the supported features section of /docs/source/quick_start.md? Also, please provide a link to the example script in the description.
Done
Please fix the lint error.
@quic-amitraj The lint failures were happening because the linter is installing […]. To fix the errors, we need to either install […].
Everything else LGTM; just see if we can use a single-layer model in our tests.
"sampling support is not available. Please check the QPC and try again." | ||
) | ||
|
||
if include_sampler and not self.include_sampler: |
Can we include this check in line 489 itself?
Could you clarify this?
```diff
- if len(logits.shape) == 2:
-     logits = np.expand_dims(logits, 1)
- next_token_id = logits.argmax(2)
+ next_token_id = self._fetch_next_token_id(outputs)
```
Since this is inside the decode loop, would it create any performance drop? Have you done any tests to check for performance deviation?
I do not follow. How will this cause a performance drop? Instead of performing argmax on the host CPU, we are simply reading the `next_token` provided by the QAIC device. In my opinion, this would lead to a performance improvement instead. Please let me know if you are talking about something else.
Due to the additional reshape, I just wanted to check whether it would create any performance deviation. It's good to have a performance number for a smaller model, just for reference.
I still do not follow. Are you talking about this reshape?

```python
return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])
```

If yes, this converts a 3D tensor of shape `(batch_size, 1, 1)` to a 2D tensor of shape `(batch_size, 1)`. This operation doesn't cause a drop in performance. Otherwise, if you are curious about the overall performance gains, please reach out to me for a complete performance report.
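For reference, a minimal NumPy sketch (the batch size of 4 is illustrative) showing that this kind of reshape returns a constant-time view rather than copying data:

```python
import numpy as np

# Illustrative only: a (batch_size, 1, 1) -> (batch_size, 1) reshape of a
# contiguous array is a metadata-only view; no tensor data is copied.
next_tokens = np.zeros((4, 1, 1), dtype=np.int64)  # stand-in for the QPC output
reshaped = next_tokens.reshape(next_tokens.shape[0], next_tokens.shape[1])

assert reshaped.shape == (4, 1)
assert reshaped.base is next_tokens  # shares memory with the original array
```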
```python
)

# Compare generated texts
golden_texts = {
```
Can we make this test more robust? The goal is to verify whether the random sampler is working as expected, right? One idea is to check if the sampler produces the same output when given the same random seed or random values, and different outputs when the seed or values change. Secondly, if we want to compare outputs with and without the sampler, can we use cosine similarity with a threshold instead of exact matching? I'm open to suggestions.
Thanks for the suggestion! The current test checks for an exact output match, which allows us to reproduce it with a fixed seed value.
As for cosine similarity, or other metrics like perplexity or ROUGE score, they are useful for semantic comparison but can be vague, as there are no clearly defined or universally accepted thresholds.
So, for this test, exact matching might help keep things deterministic and easier to maintain, in my opinion. Happy to explore alternatives if you want.
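As a sketch of the seed-reproducibility idea discussed above, using a hypothetical `generate_with_sampler(prompt, seed)` helper rather than the repo's actual test harness:

```python
def test_random_sampling_seed_reproducibility():
    # generate_with_sampler is a hypothetical helper that runs the on-device
    # sampler with a fixed seed and returns the generated token sequence.
    out_a = generate_with_sampler("My name is", seed=42)
    out_b = generate_with_sampler("My name is", seed=42)
    out_c = generate_with_sampler("My name is", seed=1337)

    assert out_a == out_b  # same seed -> identical outputs
    assert out_a != out_c  # different seed -> (almost surely) different outputs
```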
```python
# Load QPC
self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

# Validate sampler inputs for On-Device Sampling
self.include_sampler = validate_sampler_inputs(
```
Can we have something like `self.include_sampler = validate_sampler_inputs(set(self._session.input_names), include_sampler if include_sampler is not False else False)`?
I don't think we need to do this. `include_sampler` is a mandatory boolean variable with a default value of `False`. The only other possible value is `True`. Both of these scenarios are handled in the `validate_sampler_inputs()` function. Additionally, the function handles the case of `include_sampler=None` so that we can re-use the function in other places if we want to.
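To make the behavior described here concrete, a minimal sketch of what such validation logic could look like; the sampler input names and the exact handling in the repo's actual `validate_sampler_inputs()` may differ:

```python
from typing import Optional

# Illustrative sketch only; not the repo's actual implementation.
# SAMPLER_INPUT_NAMES is a hypothetical set of input names that a
# sampler-enabled QPC is assumed to expose.
SAMPLER_INPUT_NAMES = {"temperatures", "top_ks", "top_ps"}

def validate_sampler_inputs(session_input_names: set, include_sampler: Optional[bool] = None) -> bool:
    qpc_has_sampler = SAMPLER_INPUT_NAMES.issubset(session_input_names)
    if include_sampler and not qpc_has_sampler:
        raise ValueError(
            "Sampling support is not available. Please check the QPC and try again."
        )
    if include_sampler is None:
        # Fall back to whatever the QPC was compiled with.
        return qpc_has_sampler
    return bool(include_sampler)
```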
This PR adds the following unit tests for On Device Sampling:

- `test_sampler_transform`: Tests whether `SamplerTransform` adds nodes at the output of a `QEffForCausalLM` model to enable sampling of next tokens on the device (instead of the host) and returns the next tokens and/or probability distributions.
- `test_greedy_sampling`: Tests greedy sampling with QPC compiled with and without On Device Sampling (sketched below).
- `test_random_sampling`: Tests random sampling with QPC compiled with and without On Device Sampling.
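For illustration, the greedy-sampling comparison can be sketched as follows, using hypothetical `compile_qpc()` and `generate()` helpers in place of the repo's actual test harness. With greedy decoding, on-device token selection should match host-side argmax token-for-token, so exact comparison is appropriate here:

```python
def test_greedy_sampling_parity():
    # Hypothetical helpers; the actual test uses the repo's own compile/generate APIs.
    qpc_with_sampler = compile_qpc(include_sampler=True)
    qpc_without_sampler = compile_qpc(include_sampler=False)

    out_with = generate(qpc_with_sampler, "Hello world", greedy=True)
    out_without = generate(qpc_without_sampler, "Hello world", greedy=True)

    assert out_with == out_without  # greedy decoding is deterministic
```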